Skip to content

Commit 31ab9f1

Browse files
torymurrlouf
authored andcommitted
Allow threads on Index init
1 parent 3d01211 commit 31ab9f1

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

benchmarks/bench_regex_guide.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from concurrent.futures import ThreadPoolExecutor
2+
3+
import psutil
14
from outlines_core.fsm.guide import RegexGuide
25

36
from .common import setup_tokenizer
@@ -25,6 +28,28 @@ def setup(self, pattern_name):
2528
def time_regex_to_guide(self, pattern_name):
2629
RegexGuide.from_regex(self.pattern, self.tokenizer)
2730

31+
def time_regex_to_guide_parallel(self, pattern_name):
32+
# Default GIL switch interval is 5ms (0.005), which isn't helpful for cpu heavy tasks,
33+
# this parallel case should be relatively close in runtime to one thread, but it is not,
34+
# because of the GIL.
35+
core_count = psutil.cpu_count(logical=False)
36+
with ThreadPoolExecutor(max_workers=core_count) as executor:
37+
list(executor.map(self._from_regex, [pattern_name] * core_count))
38+
39+
def time_regex_to_guide_parallel_with_custom_switch_interval(self, pattern_name):
40+
# This test is to show, that if GIL's switch interval is set to be longer, then the parallel
41+
# test's runtime on physical cores will be much closer to the one-threaded case.
42+
import sys
43+
44+
sys.setswitchinterval(5)
45+
46+
core_count = psutil.cpu_count(logical=False)
47+
with ThreadPoolExecutor(max_workers=core_count) as executor:
48+
list(executor.map(self._from_regex, [pattern_name] * core_count))
49+
50+
def _from_regex(self, pattern_name):
51+
RegexGuide.from_regex(self.pattern, self.tokenizer)
52+
2853

2954
class MemoryRegexGuideBenchmark:
3055
params = ["simple_phone", "complex_span_constrained_relation_extraction"]

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ test = [
4949
"datasets",
5050
"pillow",
5151
"asv",
52+
"psutil",
5253
"setuptools-rust",
5354
]
5455

src/python_bindings/mod.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,17 @@ pub struct PyIndex(Index);
8080
impl PyIndex {
8181
#[new]
8282
fn new(
83+
py: Python<'_>,
8384
fsm_info: &PyFSMInfo,
8485
vocabulary: &PyVocabulary,
8586
eos_token_id: u32,
8687
frozen_tokens: FxHashSet<String>,
8788
) -> PyResult<Self> {
88-
Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
89-
.map(PyIndex)
90-
.map_err(Into::into)
89+
py.allow_threads(|| {
90+
Index::new(&fsm_info.into(), &vocabulary.0, eos_token_id, frozen_tokens)
91+
.map(PyIndex)
92+
.map_err(Into::into)
93+
})
9194
}
9295

9396
fn __reduce__(&self) -> PyResult<(PyObject, (Vec<u8>,))> {

0 commit comments

Comments
 (0)