
Commit a9c570e

Hot fixes (#8)
* improve docker build process
* fix streaming issue
* bug fixes
1 parent 74f8c65 commit a9c570e

File tree

17 files changed: +440, -69 lines

docker/Dockerfile.aarch64-cuda

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ COPY . /scratchpad
 
 RUN pip install https://filedn.eu/lougUsdPvd1uJK2jfOYWogH/pypi/flashinfer-0.1.6-cp310-cp310-linux_aarch64.whl
 RUN pip install https://filedn.eu/lougUsdPvd1uJK2jfOYWogH/pypi/triteia-0.1.0-cp310-cp310-linux_aarch64.whl
-RUN pip install -r requirements-extra.txt
+RUN pip install -r meta/requirements-extra.txt
 RUN pip install .
 # todo(xiaozhe): figure out why pynvml is installed in the first place. We should use nvidia-ml-py instead.
 RUN pip uninstall pynvml -y

docker/Dockerfile.x86_64-cuda

Lines changed: 3 additions & 3 deletions
@@ -1,4 +1,4 @@
-FROM nvcr.io/nvidia/pytorch:24.09-py3 AS base
+FROM nvcr.io/nvidia/pytorch:24.07-py3 AS base
 
 LABEL org.opencontainers.image.source=https://github.com/xiaozheyao/Scratchpad
 LABEL org.opencontainers.image.description="Scratchpad: Adaptive Serving of LMs"
@@ -15,8 +15,8 @@ WORKDIR /scratchpad
 
 COPY . /scratchpad
 
-RUN pip install https://filedn.eu/lougUsdPvd1uJK2jfOYWogH/pypi/flashinfer-0.1.6-cp310-cp310-linux_x86_64.whl
+RUN pip install flashinfer -i https://flashinfer.ai/whl/cu124/torch2.4/
 RUN pip install https://filedn.eu/lougUsdPvd1uJK2jfOYWogH/pypi/triteia-0.1.0-cp310-cp310-linux_x86_64.whl
-RUN pip install -r requirements-extra.txt
+RUN pip install -r meta/requirements-extra.txt
 RUN pip install .
 RUN pip uninstall pynvml -y

docker/build_image.sh

Lines changed: 2 additions & 2 deletions
@@ -7,5 +7,5 @@ if [ -z "$version" ]; then
   exit 1
 fi
 echo "Building image for $arch, version $version"
-podman build -f docker/Dockerfile.$arch-cuda . -t ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch --build-arg ARCH=$arch
-podman push ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch
+docker build -f docker/Dockerfile.$arch-cuda . -t ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch --build-arg ARCH=$arch
+docker push ghcr.io/xiaozheyao/scratchpad:${version}dev-$arch

docs/examples/mllama_request.py

Lines changed: 8 additions & 1 deletion
@@ -4,6 +4,7 @@
 
 prompt = "What is in this image?"
 img_url = "https://images.unsplash.com/photo-1692350914621-f0ca2d206368?q=80&w=3000&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"
+stream = True
 
 
 def image_url_to_base64(url):
@@ -41,5 +42,11 @@ def image_url_to_base64(url):
             ],
         }
     ],
+    stream=stream,
 )
-print(response.choices[0].message.content)
+if stream:
+    for chunk in response:
+        if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
+            print(chunk.choices[0].delta.content, end="", flush=True)
+else:
+    print(response.choices[0].message.content)
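
Note: the `len(chunk.choices) > 0` guard matters because OpenAI-compatible servers can emit chunks whose `choices` list is empty (for example, a final usage-only chunk). A minimal sketch, not part of this commit, of collecting the streamed deltas back into one string from the same `response` iterator:

# Hypothetical helper, not in the commit: reassemble streamed deltas.
def collect_stream(response) -> str:
    parts = []
    for chunk in response:
        # Skip chunks with no choices or an empty delta.
        if chunk.choices and chunk.choices[0].delta.content:
            parts.append(chunk.choices[0].delta.content)
    return "".join(parts)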

scratchpad/constrained/__init__.py

Lines changed: 18 additions & 0 deletions
@@ -34,6 +34,21 @@ def build_regex_from_object(
     return build_regex_from_schema(schema, whitespace_pattern)
 
 
+try:
+    from xgrammar import (
+        GrammarMatcher,
+        GrammarMatcherInitContext,
+        GrammarMatcherInitContextCache,
+    )
+except ImportError as e:
+
+    class Dummy:
+        pass
+
+    GrammarMatcher = Dummy
+    GrammarMatcherInitContext = Dummy
+    GrammarMatcherInitContextCache = Dummy
+
 __all__ = [
     "RegexGuide",
     "FSMInfo",
@@ -43,4 +58,7 @@ def build_regex_from_object(
     "disk_cache",
     "disable_cache",
     "make_byte_level_fsm",
+    "GrammarMatcher",
+    "GrammarMatcherInitContext",
+    "GrammarMatcherInitContextCache",
 ]
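
The try/except above keeps `import scratchpad.constrained` working when xgrammar isn't installed; all three names are then bound to the same `Dummy` placeholder and only fail at the point of use. A minimal sketch, assuming only what the diff shows, of how a caller could detect the fallback (`xgrammar_available` is a hypothetical helper, not part of this commit):

from scratchpad.constrained import GrammarMatcher

def xgrammar_available() -> bool:
    # After a failed import, GrammarMatcher is bound to the Dummy class.
    return GrammarMatcher.__name__ != "Dummy"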
scratchpad/constrained/bnf_cache.py

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+"""Cache for the compressed finite state machine."""
+
+from typing import Tuple
+
+from transformers import AutoTokenizer
+
+from scratchpad.constrained import (
+    GrammarMatcher,
+    GrammarMatcherInitContext,
+    GrammarMatcherInitContextCache,
+)
+
+MAX_ROLLBACK_TOKENS = 10
+
+
+class BNFCache:
+    grammar_cache: GrammarMatcherInitContextCache
+
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        skip_tokenizer_init=False,
+        whitespace_patterns=None,
+    ):
+        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
+        if skip_tokenizer_init:
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        self.grammar_cache = GrammarMatcherInitContextCache(
+            tokenizer_or_vocab=tokenizer
+        )
+
+    def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
+        key_type, key_string = key
+        if key_type == "json":
+            return self.grammar_cache.get_init_context_for_json_schema(key_string)
+        elif key_type == "regex":
+            raise ValueError(f"regex hasn't been supported by xgrammar yet")
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
+        ctx = self.get_context(key)
+        return GrammarMatcher(
+            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
+        )
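
A usage sketch for the new cache, assuming the file lands as scratchpad/constrained/bnf_cache.py (grammar.py below imports it under that name) and that xgrammar is installed; the model id, schema, and vocabulary size are placeholders:

from scratchpad.constrained.bnf_cache import BNFCache

cache = BNFCache(
    tokenizer_path="meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder
    tokenizer_args_dict={},
)
schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
# query() builds (and caches) an init context for the schema, then wraps
# it in a GrammarMatcher bounded to the model's vocabulary size.
matcher = cache.query(("json", schema), vocab_size=128256)  # placeholder size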

scratchpad/constrained/fsm_cache.py

Lines changed: 26 additions & 4 deletions
@@ -3,6 +3,17 @@
 from .base_tool_cache import BaseToolCache
 from . import RegexGuide, TransformerTokenizer
 
+import logging
+
+from interegular import InvalidSyntax, parse_pattern
+from outlines.fsm.json_schema import build_regex_from_schema
+from transformers import AutoTokenizer
+
+from scratchpad.constrained import RegexGuide, TransformerTokenizer
+from .base_tool_cache import BaseToolCache
+
+logger = logging.getLogger(__name__)
+
 
 
 class FSMCache(BaseToolCache):
@@ -51,12 +62,23 @@ def fset(self, value):
     def init_value(self, key):
         key_type, key_string = key
         if key_type == "json":
-            regex = build_regex_from_schema(
-                key_string, whitespace_pattern=self.constrained_json_whitespace_pattern
-            )
+            try:
+                regex = build_regex_from_schema(
+                    key_string,
+                    whitespace_pattern=self.constrained_json_whitespace_pattern,
+                )
+            except NotImplementedError as e:
+                logger.warning(
+                    f"skip invalid json schema: json_schema={key_string}, {e=}"
+                )
+                return None, key_string
         elif key_type == "regex":
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
-
+        try:
+            parse_pattern(regex)
+        except InvalidSyntax as e:
+            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
+            return None, regex
         return RegexGuide(regex, self.outlines_tokenizer), regex
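
The contract of `init_value` changes here: instead of raising on a schema that outlines can't translate or a regex that interegular can't parse, it logs a warning and returns `(None, regex)`. Callers therefore need a fallback path. A minimal sketch of the implied calling pattern, assuming `fsm_cache` is an FSMCache instance and that `query` hands back `init_value`'s pair, as `GrammarCache.query` below unpacks it:

guide, regex = fsm_cache.query(("json", schema_str))
if guide is None:
    # The schema/regex was rejected and logged; fall back to
    # unconstrained decoding rather than failing the request.
    pass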

scratchpad/constrained/grammar.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+"""Cache for the compressed finite state machine."""
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from scratchpad.constrained import GrammarMatcher, RegexGuide
+from .bnf_cache import BNFCache
+from .fsm_cache import FSMCache
+from .jump_forward import JumpForwardCache, JumpForwardMap
+
+# from sglang.srt.managers.schedule_batch import Req
+
+logger = logging.getLogger(__name__)
+
+INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
+
+
+class XGrammarJump:
+    pass
+
+
+class JumpHelper:
+    data: Union[List, str]
+    state: int
+    suffix_ids: List[int]
+
+    def __init__(
+        self, data: Union[List, str] = "", state: int = -1, suffix_ids=[]
+    ) -> None:
+        self.data = data
+        self.state = state
+        self.suffix_ids = suffix_ids
+
+    def can_jump(self):
+        return len(self.data) > 0
+
+
+class Grammar:
+    grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]]
+    jump_map: Union[XGrammarJump, JumpForwardMap, None]
+
+    def __init__(
+        self,
+        grammar: Union[GrammarMatcher, Tuple[RegexGuide, int]],
+        jump_map: Union[XGrammarJump, JumpForwardMap, None],
+    ) -> None:
+        self.grammar = grammar
+        self.jump_map = jump_map
+
+    def accept_token(self, token: int):
+        if isinstance(self.grammar, GrammarMatcher):
+            assert self.grammar.accept_token(token)
+        else:
+            guide, state = self.grammar
+            self.grammar = guide, guide.get_next_state(state, token)
+
+    def try_jump(self, tokenizer) -> JumpHelper:
+        if isinstance(self.jump_map, XGrammarJump):
+            assert isinstance(self.grammar, GrammarMatcher)
+            return JumpHelper(self.grammar.find_jump_forward_string())
+        elif isinstance(self.jump_map, JumpForwardMap):
+            assert isinstance(self.grammar, Tuple)
+
+            _, state = self.grammar
+            jump_forward_bytes = self.jump_map.jump_forward_byte(state)
+            if jump_forward_bytes is None or len(jump_forward_bytes) == 0:
+                return JumpHelper()  # can't jump
+
+            # preprocess the jump forward string
+            suffix_bytes = []
+            continuation_range = range(0x80, 0xC0)
+            cur_state = state
+            while (
+                len(jump_forward_bytes)
+                and jump_forward_bytes[0][0] in continuation_range
+            ):
+                # continuation bytes
+                byte_edge = jump_forward_bytes.pop(0)
+                suffix_bytes.append(byte_edge[0])
+                cur_state = byte_edge[1]
+
+            suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+            suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
+            return JumpHelper(suffix_ids, cur_state, suffix_bytes)
+        else:
+            return JumpHelper()  # can't jump
+
+    def jump_forward_str_state(self, helper: JumpHelper) -> Tuple[str, int]:
+        if isinstance(helper.data, str):
+            return helper.data, -1
+        else:
+            assert isinstance(self.jump_map, JumpForwardMap)
+            return self.jump_map.jump_forward_symbol(helper.state)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        if isinstance(self.grammar, GrammarMatcher):
+            k = 0
+            for i, old_id in enumerate(old_output_ids):
+                if old_id == new_output_ids[i]:
+                    k = i + 1
+                else:
+                    break
+
+            # rollback to the last token that is the same
+            if k < len(old_output_ids):
+                self.grammar.rollback(len(old_output_ids) - k)
+
+            for i in range(k, len(new_output_ids)):
+                assert self.grammar.accept_token(new_output_ids[i])
+        else:
+            self.grammar = self.grammar[0], next_state
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, vocab_size: int):
+        if isinstance(self.grammar, GrammarMatcher):
+            # Note that this bitmask is a bitset, not bool
+            bitmask = self.grammar.find_next_token_bitmask()
+            # Mask the tokens that are not allowed
+            vocab_mask[
+                self.grammar.get_rejected_tokens_from_bitmask(bitmask, vocab_size)
+            ] = 1
+        else:
+            guide, state = self.grammar
+            vocab_mask.fill_(1)
+            vocab_mask[guide.get_next_instruction(state).tokens] = 0
+
+
+class GrammarCache:
+    grammar_cache: Union[BNFCache, FSMCache]
+    jump_cache: Union[XGrammarJump, JumpForwardCache, None]
+
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        skip_tokenizer_init=False,
+        whitespace_patterns=None,
+        backend=None,
+        allow_jump=False,
+    ):
+        if backend == "xgrammar":
+            self.grammar_cache = BNFCache(
+                tokenizer_path=tokenizer_path,
+                tokenizer_args_dict=tokenizer_args_dict,
+                skip_tokenizer_init=skip_tokenizer_init,
+                whitespace_patterns=whitespace_patterns,
+            )
+            self.jump_cache = XGrammarJump() if allow_jump else None
+        else:
+            assert backend == "outlines"
+            self.grammar_cache = FSMCache(
+                tokenizer_path=tokenizer_path,
+                tokenizer_args_dict=tokenizer_args_dict,
+                skip_tokenizer_init=skip_tokenizer_init,
+                constrained_json_whitespace_pattern=whitespace_patterns,
+                enable=True,
+            )
+            self.jump_cache = JumpForwardCache() if allow_jump else None
+
+    def query(self, key: Tuple[str, str], vocab_size: int) -> Grammar:
+        if isinstance(self.grammar_cache, BNFCache):
+            assert not isinstance(self.jump_cache, JumpForwardCache)
+            return Grammar(self.grammar_cache.query(key, vocab_size), self.jump_cache)
+        else:
+            jump_map = None
+            guide, regex = self.grammar_cache.query(key)
+            if isinstance(self.jump_cache, JumpForwardCache):
+                jump_map = self.jump_cache.query(regex)
+            return Grammar((guide, 0), jump_map)
+
+    def reset(self):
+        if isinstance(self.grammar_cache, FSMCache):
+            self.grammar_cache.reset()
+        if isinstance(self.jump_cache, JumpForwardCache):
+            self.jump_cache.reset()
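
To tie the pieces together, a sketch of one constrained decoding step driven by GrammarCache with the outlines backend; the model id, schema, vocabulary size, and `logits` tensor are illustrative stand-ins, not from this commit:

import torch

from scratchpad.constrained.grammar import GrammarCache

VOCAB_SIZE = 128256  # placeholder vocabulary size
logits = torch.randn(VOCAB_SIZE)  # stand-in for the model's next-token logits

grammar_cache = GrammarCache(
    tokenizer_path="meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder
    tokenizer_args_dict={},
    backend="outlines",  # "xgrammar" would route through BNFCache instead
)
grammar = grammar_cache.query(("json", '{"type": "object"}'), VOCAB_SIZE)

vocab_mask = torch.zeros(VOCAB_SIZE, dtype=torch.bool)
grammar.fill_vocab_mask(vocab_mask, VOCAB_SIZE)  # 1 marks disallowed tokens
logits[vocab_mask] = float("-inf")
next_token = int(torch.argmax(logits))
grammar.accept_token(next_token)  # advance the guide state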

scratchpad/managers/tokenizer.py

Lines changed: 0 additions & 1 deletion
@@ -440,7 +440,6 @@ def create_handle_loop(self):
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
-            print("sigterm_watchdog")
             await asyncio.sleep(5)
 
         # drain requests
