Commit 63ab0d5

Add interface to Guide object to update masks in place, and associated kernels. (#183)
1 parent bc79d02 · commit 63ab0d5

22 files changed: +1038 −15 lines
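
The new `Guide.write_mask_into(data_ptr, numel, element_size)` interface takes a raw pointer, an element count, and an element size, so the same method can fill a torch tensor (via `data_ptr()`) or a numpy array (via `ctypes.data`). A minimal end-to-end sketch using the torch kernel added in this commit, mirroring `benchmarks/bench_torch_e2e.py` (the ssn pattern and gpt2 vocabulary are illustrative; `_apply_token_bitmask_inplace_kernel` is the internal kernel the benchmarks exercise):

import torch
from outlines_core import Guide, Index, Vocabulary
from outlines_core.kernels.torch import (
    _apply_token_bitmask_inplace_kernel,
    allocate_token_bitmask,
)

vocabulary = Vocabulary.from_pretrained("gpt2")
guide = Guide(Index(r"\d{3}-\d{2}-\d{4}", vocabulary))

# One int32 word holds the allow/deny bits for 32 token IDs.
mask = allocate_token_bitmask(len(vocabulary))
logits = torch.randn(1, len(vocabulary))

# Write the bitmask for the guide's current state into the
# preallocated buffer, then mask the logits in place.
guide.write_mask_into(mask.data_ptr(), mask.numel(), mask.element_size())
_apply_token_bitmask_inplace_kernel(logits, mask)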

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ __pycache__
 *.pyd
 *.so
 benchmarks/results
+benchmarks/env
 build
 Cargo.lock
 dist

benchmarks/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
+import sys
+
+# There is a conflict between asv.statistics and the standard library's statistics module.
+# This is a workaround to use the standard library's median function.
+if "asv.statistics" in sys.modules:
+
+    def median(data):
+        import statistics
+
+        return statistics.median(data)
+
+    asv_statistics = sys.modules["asv.statistics"]
+    asv_statistics.median = median  # type: ignore

benchmarks/asv.conf.json

Lines changed: 8 additions & 3 deletions
@@ -4,12 +4,17 @@
     "project_url": "https://dottxt-ai.github.io/outlines-core/",
     "repo": "..",
     "branches": [
-        "HEAD",
+        "HEAD"
     ],
     "build_command": [
-        "python -mpip install .[test]",
-        "PIP_NO_BUILD_ISOLATION=false python -mpip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}",
+        "python -m pip install .[test]",
+        "PIP_NO_BUILD_ISOLATION=false python -m pip wheel --no-deps --no-index -w {build_cache_dir} {build_dir}"
     ],
+    "matrix": {
+        "torch": ["2.4.0"],
+        "numpy": ["2.2.3"],
+        "numba": ["0.60.0"]
+    },
     "environment_type": "virtualenv",
     "show_commit_url": "https://github.com/dottxt-ai/outlines-core/commit/",
     "benchmark_dir": ".",

benchmarks/bench_kernels.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+import random
+
+import numpy as np
+import torch
+
+from outlines_core.kernels.numpy import (
+    _apply_token_bitmask_inplace_kernel as numpy_kernel,
+)
+from outlines_core.kernels.torch import (
+    _apply_token_bitmask_inplace_kernel as torch_kernel,
+)
+
+
+def generate_sparse_mask(batch, vocab, allowed_count=1000):
+    mask_shape = (batch, (vocab + 31) // 32)
+    mask = np.zeros(mask_shape, dtype=np.uint32)
+    allowed_indices = random.sample(range(vocab), allowed_count)
+    for idx in allowed_indices:
+        group = idx // 32
+        shift = idx % 32
+        bit_mask = np.uint32(1) << np.uint32(shift)
+        mask[0, group] |= bit_mask
+    return mask
+
+
+class TorchBitmaskApplyBenchmark:
+    params = [[10, 100, 1_000, 10_000, 100_000], [1, 2, 4, 8]]
+    param_names = ["allowed_tokens", "batch"]
+    number = 10
+
+    def setup(self, allowed_tokens, batch):
+        self.device = "cpu"
+        self.allowed_tokens = allowed_tokens
+        self.vocab = 128000
+        self.batch = batch
+
+        self.logits = torch.randn(self.batch, self.vocab, device=self.device)
+
+        mask = torch.from_numpy(
+            generate_sparse_mask(
+                self.batch, self.vocab, allowed_count=self.allowed_tokens
+            )
+        )
+        self.mask = mask.to(self.device)
+
+        self.kernel = torch_kernel
+
+        for _ in range(4):
+            self.kernel(self.logits, self.mask)
+
+    def time_kernel(self, allowed_tokens, batch):
+        self.kernel(self.logits, self.mask)
+
+
+class NumpyBitmaskApplyBenchmark:
+    params = [[10, 100, 1_000, 10_000, 100_000], [1, 2, 4, 8]]
+    param_names = ["allowed_tokens", "batch"]
+    number = 10
+
+    def setup(self, allowed_tokens, batch):
+        self.allowed_tokens = allowed_tokens
+        self.vocab = 128000
+        self.batch = batch
+
+        self.logits = np.random.randn(self.batch, self.vocab).astype(np.float32)
+
+        self.mask = generate_sparse_mask(
+            self.batch, self.vocab, allowed_count=self.allowed_tokens
+        )
+
+        self.kernel = numpy_kernel
+
+        for _ in range(4):
+            self.kernel(self.logits, self.mask)
+
+    def time_kernel(self, allowed_tokens, batch):
+        self.kernel(self.logits, self.mask)
+
+
+class MlxBitmaskApplyBenchmark:
+    params = [[10, 100, 1_000, 10_000, 100_000], [1, 2, 4, 8]]
+    param_names = ["allowed_tokens", "batch"]
+    number = 10
+
+    def setup(self, allowed_tokens, batch):
+        try:
+            import mlx.core as mx
+
+            from outlines_core.kernels.mlx import (
+                _apply_token_bitmask_kernel as mlx_kernel,
+            )
+        except ImportError:
+            raise NotImplementedError
+
+        self.allowed_tokens = allowed_tokens
+        self.vocab = 128000
+        self.batch = batch
+
+        self.logits = mx.array(
+            np.random.randn(self.batch, self.vocab).astype(np.float32)
+        )
+
+        self.mask = mx.array(
+            generate_sparse_mask(
+                self.batch, self.vocab, allowed_count=self.allowed_tokens
+            )
+        )
+
+        self.kernel = mlx_kernel
+
+        # warm up / compile
+        for _ in range(4):
+            self.kernel(self.logits, self.mask)
+
+    def time_kernel(self, allowed_tokens, batch):
+        self.kernel(self.logits, self.mask)

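A note on the layout used by `generate_sparse_mask` and the kernels above: each 32-bit word packs the allow/deny bits for 32 consecutive token IDs, with token `idx` living at bit `idx % 32` of word `idx // 32`. A small sketch of setting and reading a bit back (the `token_id` and `is_allowed` names are illustrative):

import numpy as np

vocab = 128000
mask = np.zeros((1, (vocab + 31) // 32), dtype=np.uint32)

# Allow token 70: set bit 70 % 32 == 6 of word 70 // 32 == 2.
token_id = 70
mask[0, token_id // 32] |= np.uint32(1) << np.uint32(token_id % 32)

# Reading the bit back mirrors the kernels' test ((word >> (idx & 31)) & 1).
is_allowed = (mask[0, token_id // 32] >> np.uint32(token_id % 32)) & 1
assert is_allowed == 1
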
benchmarks/bench_regex_guide.py

Lines changed: 19 additions & 0 deletions
@@ -2,6 +2,7 @@
 from concurrent.futures import ThreadPoolExecutor
 
 import psutil
+
 from outlines_core import Guide, Index, Vocabulary
 
 regex_samples = {
@@ -83,3 +84,21 @@ def peakmem_guides_per_index(self, num_guides):
 
         assert len(objects) == num_guides
         assert final - initial < 5
+
+
+class WriteMaskIntoBenchmark:
+    params = list(regex_samples.keys())
+    param_names = ["regex_key"]
+
+    def setup(self, regex_key):
+        from outlines_core.kernels.torch import allocate_token_bitmask
+
+        self.vocab = Vocabulary.from_pretrained("gpt2")
+        self.mask = allocate_token_bitmask(len(self.vocab))
+        self.index = Index(regex_samples[regex_key], self.vocab)
+        self.guide = Guide(self.index)
+
+    def time_write_mask_into(self, regex_key):
+        self.guide.write_mask_into(
+            self.mask.data_ptr(), self.mask.numel(), self.mask.element_size()
+        )

benchmarks/bench_torch_e2e.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+import torch
+
+from outlines_core import Guide, Index, Vocabulary
+from outlines_core.kernels.torch import (
+    _apply_token_bitmask_inplace_kernel,
+    allocate_token_bitmask,
+)
+
+regex_samples = {
+    "email": r"[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?",
+    "complex_phone": "\\+?\\d{1,4}?[-.\\s]?\\(?\\d{1,3}?\\)?[-.\\s]?\\d{1,4}[-.\\s]?\\d{1,4}[-.\\s]?\\d{1,9}",
+    "simple_phone": "\\+?[1-9][0-9]{7,14}",
+    "date": r"([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])|([0-9][0-9]|19[0-9][0-9]|20[0-9][0-9])(\.|-|/)([1-9]|0[1-9]|1[0-2])(\.|-|/)([1-9]|0[1-9]|1[0-9]|2[0-9]|3[0-1])",
+    "time": r"(0?[1-9]|1[0-2]):[0-5]\d\s?(am|pm)?",
+    "ip": r"(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)",
+    "url": r"(https?:\/\/)?([\da-z\.-]+)\.([a-z\.]{2,6})([\/\w \.-]*)*\/?",
+    "ssn": r"\d{3}-\d{2}-\d{4}",
+    "complex_span_constrained_relation_extraction": "(['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ ,]?\\s\\|\\s([^|\\(\\)\n]{1,})\\s\\|\\s['\"\\ ,]?((?:of|resulting|case|which|cultures|a|core|extreme|selflessness|spiritual|various|However|both|vary|in|other|secular|the|religious|among|moral|and|It|object|worldviews|altruism|traditional|material|aspect|or|life|beings|virtue|is|however|opposite|concern|an|practice|it|for|s|quality|religions|In|Altruism|animals|happiness|many|become|principle|human|selfishness|may|synonym)['\"\\ ,]?)+['\"\\ ,]?(\\s\\|\\s\\(([^|\\(\\)\n]{1,})\\s\\|\\s([^|\\(\\)\n]{1,})\\))*\\n)*",
+}
+
+
+class TorchE2EBenchmark:
+    params = regex_samples.keys()
+
+    def setup(self, pattern_name):
+        self.vocabulary = Vocabulary.from_pretrained("gpt2")
+        self.pattern = regex_samples[pattern_name]
+        self.guide = Guide(Index(self.pattern, self.vocabulary))
+
+        self.mask = allocate_token_bitmask(len(self.vocabulary))
+        self.logits = torch.randn(1, len(self.vocabulary))
+
+    def time_write_mask_and_apply(self, pattern_name):
+        self.guide.write_mask_into(
+            self.mask.data_ptr(), self.mask.numel(), self.mask.element_size()
+        )
+
+        _apply_token_bitmask_inplace_kernel(self.logits, self.mask)

outlines_core/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .outlines_core import Guide, Index, Vocabulary

outlines_core/json_schema.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+from .outlines_core import (  # noqa: F401
+    BOOLEAN,
+    DATE,
+    DATE_TIME,
+    EMAIL,
+    INTEGER,
+    NULL,
+    NUMBER,
+    STRING,
+    STRING_INNER,
+    TIME,
+    URI,
+    UUID,
+    WHITESPACE,
+    build_regex_from_schema,
+)

outlines_core/kernels/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+"""Token Masking kernel implementations for various backends."""

outlines_core/kernels/mlx.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+from outlines_core import Guide
+
+try:
+    import mlx.core as mx
+    import numpy as np
+except ImportError as e:
+    missing_dep = "numpy" if "numpy" in str(e) else "mlx"
+    raise ImportError(
+        f"To use the kernels in `outlines_core.kernels.mlx`, {missing_dep} must be installed. You can install it with `pip install {missing_dep}`"
+    ) from e
+
+
+def allocate_token_bitmask(vocab_size: int) -> np.ndarray:
+    return np.full(
+        (1, (vocab_size + 31) // 32),
+        -1,
+        dtype=np.int32,
+    )
+
+
+_KERNEL_SOURCE = r"""
+// Batch index
+uint batch = thread_position_in_grid.y;
+// Element index
+uint elem = thread_position_in_grid.x;
+
+uint bit = ((elem >> 5) < mask_shape[1]) &&
+           ((mask[batch * mask_shape[1] + (elem >> 5)] >> (elem & 31)) & 1);
+
+out[batch * inp_shape[1] + elem] = bit ? inp[batch * inp_shape[1] + elem] : -INFINITY;
+"""
+
+_KERNEL = mx.fast.metal_kernel(
+    name="bitmask_apply_batched",
+    input_names=["inp", "mask"],
+    output_names=["out"],
+    source=_KERNEL_SOURCE,
+)
+
+
+@mx.compile
+def _apply_token_bitmask_kernel(data: mx.array, mask: mx.array) -> mx.array:
+    return _KERNEL(
+        inputs=[data, mask],
+        template=[("T", data.dtype)],
+        grid=(data.shape[1], data.shape[0], 1),
+        threadgroup=(256, 1, 1),
+        output_shapes=[data.shape],
+        output_dtypes=[data.dtype],
+    )[0]
+
+
+def apply_token_bitmask(logits: mx.array, mask_np: np.ndarray) -> mx.array:
+    """
+    Apply a token bitmask to the logits, setting the logits of invalid
+    tokens to -infinity.
+
+    Arguments:
+        logits (mx.array): The logits tensor.
+
+        mask_np (np.ndarray): The token bitmask representing the validity of
+            each token in the logits tensor.
+
+    Raises:
+        ValueError: If any of the following conditions are not met:
+            - `mask.dtype` is not `mx.int32`
+            - `mask` is not a 2D array
+            - `logits` is not a 2D array
+            - `mask.shape[0]` does not match `logits.shape[0]`
+
+    Returns:
+        mx.array: The masked logits (a new array; mlx arrays are not modified in place).
+    """
+    # makes a copy - non consuming
+    mask = mx.array(mask_np)
+
+    logits = logits if len(logits.shape) != 1 else mx.expand_dims(logits, axis=0)
+    mask = mask if len(mask.shape) != 1 else mx.expand_dims(mask, axis=0)
+
+    if mask.dtype != mx.int32:
+        raise ValueError(
+            f"Invalid mask dtype: Expected `mx.int32`, but got `{mask.dtype}`."
+        )
+    elif len(mask.shape) != 2:
+        raise ValueError(
+            f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D."
+        )
+    elif len(logits.shape) != 2:
+        raise ValueError(
+            f"Invalid logits dimensions: Expected a 2D array, but got {logits.ndim}D."
+        )
+    elif mask.shape[0] != logits.shape[0]:
+        raise ValueError(
+            f"Invalid batch size: Expected `mask.shape[0]` ({mask.shape[0]}) to match `logits.shape[0]` ({logits.shape[0]})."
+        )
+    return _apply_token_bitmask_kernel(logits, mask)
+
+
+def fill_next_token_bitmask(guide: Guide, mask: np.ndarray) -> None:
+    """
+    Writes a bitmask to represent the tokens permissible by the current state of the `guide`.
+    Each bit in the bitmask corresponds to a token ID, with a bit value of 1 indicating that
+    the token is allowed and 0 indicating that it is disallowed. This function directly modifies
+    the `mask` array in-place.
+
+    Arguments:
+        guide (Guide): An instance of the `Guide` class that provides the current guidance state.
+        mask (np.ndarray): A 2D array of type `np.int32` where the bitmask will
+            be written. The array must be C-contiguous in memory and have a
+            single batch dimension (shape[0] == 1); `write_mask_into` writes
+            through the array's raw pointer, so these checks guard that call.
+
+    Raises:
+        ValueError: If any of the following conditions are not met:
+            - `mask.dtype` is not `np.int32`
+            - `mask` is not a 2D array
+            - `mask` does not have a single batch dimension (shape[0] != 1)
+            - `mask` is not contiguous in memory
+
+    Returns:
+        None: Modifies the `mask` array in-place.
+    """
+    if mask.dtype != np.int32:
+        raise ValueError(
+            f"Invalid mask dtype: Expected `np.int32`, but got `{mask.dtype}`."
+        )
+    elif mask.ndim != 2:
+        raise ValueError(
+            f"Invalid mask dimensions: Expected a 2D array, but got {mask.ndim}D."
+        )
+    elif mask.shape[0] != 1:
+        raise ValueError(
+            f"Invalid batch size: Batch mask writes are not supported. Expected shape[0] == 1, but got shape {mask.shape}."
+        )
+    elif not mask.flags["C_CONTIGUOUS"]:
+        raise ValueError(
+            "Mask array must be contiguous in memory. Use `np.ascontiguousarray(mask)`."
+        )
+
+    return guide.write_mask_into(mask.ctypes.data, mask.size, mask.itemsize)
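
A minimal usage sketch tying this mlx backend together (the gpt2 vocabulary and ssn-style pattern are illustrative, mirroring the benchmarks above):

import mlx.core as mx
import numpy as np

from outlines_core import Guide, Index, Vocabulary
from outlines_core.kernels.mlx import (
    allocate_token_bitmask,
    apply_token_bitmask,
    fill_next_token_bitmask,
)

vocabulary = Vocabulary.from_pretrained("gpt2")
guide = Guide(Index(r"\d{3}-\d{2}-\d{4}", vocabulary))

# int32 buffer with one bit per token, initialized to all-allowed (-1).
mask = allocate_token_bitmask(len(vocabulary))
logits = mx.array(np.random.randn(1, len(vocabulary)).astype(np.float32))

# Write the guide's allowed-token bits into the numpy buffer, then
# compute masked logits; apply_token_bitmask returns a new mx.array.
fill_next_token_bitmask(guide, mask)
masked_logits = apply_token_bitmask(logits, mask)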
