This repository was archived by the owner on Oct 9, 2024. It is now read-only.

Commit 114b912

fix broken generate in 4.26.1
1 parent a17b7d3 commit 114b912

4 files changed (+59 -39 lines)

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,11 +1,11 @@
 repos:
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         name: isort (python)
   - repo: https://github.com/psf/black
-    rev: 22.8.0
+    rev: 23.1.0
     hooks:
       - id: black
         args: [--line-length=119,--target-version=py35]
```

bloom-inference-scripts/bloom-accelerate-inference.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -61,12 +61,14 @@ def print_rank0(*msg):
     device_map="auto",
 )
 
+
 def get_world_size() -> int:
     if dist.is_initialized():
         return dist.get_world_size()
     else:
         return 1
 
+
 # balanced_low_0 - because it allows a larger batch size with multiple GPUs
 if get_world_size() > 1:
     kwargs["device_map"] = "balanced_low_0"
```

inference_server/model_handler/grpc_utils/pb/generation_pb2.py

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default.

inference_server/utils/requests.py

Lines changed: 55 additions & 36 deletions
```diff
@@ -10,24 +10,24 @@ class BaseResponse(BaseModel):
 
 class GenerateRequest(BaseModel):
     text: List[str] = None
-    min_length: int = None
-    do_sample: bool = None
-    early_stopping: bool = None
-    temperature: float = None
-    top_k: int = None
-    top_p: float = None
-    typical_p: float = None
-    repetition_penalty: float = None
+    min_length: int = 0
+    do_sample: bool = False
+    early_stopping: bool = False
+    temperature: float = 1
+    top_k: int = 50
+    top_p: float = 1
+    typical_p: float = 1
+    repetition_penalty: float = 1
     bos_token_id: int = None
     pad_token_id: int = None
     eos_token_id: int = None
-    length_penalty: float = None
-    no_repeat_ngram_size: int = None
-    encoder_no_repeat_ngram_size: int = None
+    length_penalty: float = 1
+    no_repeat_ngram_size: int = 0
+    encoder_no_repeat_ngram_size: int = 0
     max_time: float = None
     max_new_tokens: int = None
     decoder_start_token_id: int = None
-    diversity_penalty: float = None
+    diversity_penalty: float = 0
     forced_bos_token_id: int = None
     forced_eos_token_id: int = None
     exponential_decay_length_penalty: float = None
```
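This hunk is the core of the fix: the sampling and penalty knobs now default to concrete values rather than None. The new defaults match the stock transformers GenerationConfig defaults (do_sample False, temperature 1, top_k 50, top_p 1, repetition_penalty 1, and so on), while token-id style fields stay None. Presumably, forwarding explicit None values for these knobs into model.generate() is what broke under transformers 4.26.1, per the commit title. A quick hypothetical check of the new behavior (the import path mirrors the file's location and is an assumption):

```python
# Hypothetical sanity check; GenerateRequest is the pydantic model
# edited above, assumed importable from its file path.
from inference_server.utils.requests import GenerateRequest

req = GenerateRequest(text=["Hello world"])

# Sampling knobs now carry transformers-style defaults instead of None...
assert req.do_sample is False and req.temperature == 1 and req.top_k == 50
assert req.min_length == 0 and req.no_repeat_ngram_size == 0
# ...while token-id fields still default to None and are filled elsewhere.
assert req.bos_token_id is None and req.eos_token_id is None
```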
```diff
@@ -89,32 +89,51 @@ def parse_field(kwargs: dict, field: str, dtype: type, default_value: Any = None
 
 
 def create_generate_request(text: List[str], generate_kwargs: dict) -> GenerateRequest:
     # get user generate_kwargs as json and parse it
+    default_request = GenerateRequest()
+
     return GenerateRequest(
         text=text,
-        min_length=parse_field(generate_kwargs, "min_length", int),
-        do_sample=parse_field(generate_kwargs, "do_sample", bool),
-        early_stopping=parse_field(generate_kwargs, "early_stopping", bool),
-        num_beams=parse_field(generate_kwargs, "num_beams", int),
-        temperature=parse_field(generate_kwargs, "temperature", float),
-        top_k=parse_field(generate_kwargs, "top_k", int),
-        top_p=parse_field(generate_kwargs, "top_p", float),
-        typical_p=parse_field(generate_kwargs, "typical_p", float),
-        repetition_penalty=parse_field(generate_kwargs, "repetition_penalty", float),
-        bos_token_id=parse_field(generate_kwargs, "bos_token_id", int),
-        pad_token_id=parse_field(generate_kwargs, "pad_token_id", int),
-        eos_token_id=parse_field(generate_kwargs, "eos_token_id", int),
-        length_penalty=parse_field(generate_kwargs, "length_penalty", float),
-        no_repeat_ngram_size=parse_field(generate_kwargs, "no_repeat_ngram_size", int),
-        encoder_no_repeat_ngram_size=parse_field(generate_kwargs, "encoder_no_repeat_ngram_size", int),
-        max_time=parse_field(generate_kwargs, "max_time", float),
-        max_new_tokens=parse_field(generate_kwargs, "max_new_tokens", int),
-        decoder_start_token_id=parse_field(generate_kwargs, "decoder_start_token_id", int),
-        num_beam_group=parse_field(generate_kwargs, "num_beam_group", int),
-        diversity_penalty=parse_field(generate_kwargs, "diversity_penalty", float),
-        forced_bos_token_id=parse_field(generate_kwargs, "forced_bos_token_id", int),
-        forced_eos_token_id=parse_field(generate_kwargs, "forced_eos_token_id", int),
-        exponential_decay_length_penalty=parse_field(generate_kwargs, "exponential_decay_length_penalty", float),
-        remove_input_from_output=parse_field(generate_kwargs, "remove_input_from_output", bool, False),
+        min_length=parse_field(generate_kwargs, "min_length", int, default_request.min_length),
+        do_sample=parse_field(generate_kwargs, "do_sample", bool, default_request.do_sample),
+        early_stopping=parse_field(generate_kwargs, "early_stopping", bool, default_request.early_stopping),
+        temperature=parse_field(generate_kwargs, "temperature", float, default_request.temperature),
+        top_k=parse_field(generate_kwargs, "top_k", int, default_request.top_k),
+        top_p=parse_field(generate_kwargs, "top_p", float, default_request.top_p),
+        typical_p=parse_field(generate_kwargs, "typical_p", float, default_request.typical_p),
+        repetition_penalty=parse_field(
+            generate_kwargs, "repetition_penalty", float, default_request.repetition_penalty
+        ),
+        bos_token_id=parse_field(generate_kwargs, "bos_token_id", int, default_request.bos_token_id),
+        pad_token_id=parse_field(generate_kwargs, "pad_token_id", int, default_request.pad_token_id),
+        eos_token_id=parse_field(generate_kwargs, "eos_token_id", int, default_request.eos_token_id),
+        length_penalty=parse_field(generate_kwargs, "length_penalty", float, default_request.length_penalty),
+        no_repeat_ngram_size=parse_field(
+            generate_kwargs, "no_repeat_ngram_size", int, default_request.no_repeat_ngram_size
+        ),
+        encoder_no_repeat_ngram_size=parse_field(
+            generate_kwargs, "encoder_no_repeat_ngram_size", int, default_request.encoder_no_repeat_ngram_size
+        ),
+        max_time=parse_field(generate_kwargs, "max_time", float, default_request.max_time),
+        max_new_tokens=parse_field(generate_kwargs, "max_new_tokens", int, default_request.max_new_tokens),
+        decoder_start_token_id=parse_field(
+            generate_kwargs, "decoder_start_token_id", int, default_request.decoder_start_token_id
+        ),
+        diversity_penalty=parse_field(generate_kwargs, "diversity_penalty", float, default_request.diversity_penalty),
+        forced_bos_token_id=parse_field(
+            generate_kwargs, "forced_bos_token_id", int, default_request.forced_bos_token_id
+        ),
+        forced_eos_token_id=parse_field(
+            generate_kwargs, "forced_eos_token_id", int, default_request.forced_eos_token_id
+        ),
+        exponential_decay_length_penalty=parse_field(
+            generate_kwargs,
+            "exponential_decay_length_penalty",
+            float,
+            default_request.exponential_decay_length_penalty,
+        ),
+        remove_input_from_output=parse_field(
+            generate_kwargs, "remove_input_from_output", bool, default_request.remove_input_from_output
+        ),
     )
 
 
```
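This hunk threads those defaults through: create_generate_request builds a default GenerateRequest once and hands each field's default to parse_field, so a key missing from the user's generate_kwargs now falls back to a usable value instead of None; the old num_beams and num_beam_group arguments are dropped entirely. parse_field itself appears only as a signature in the hunk header; a minimal sketch consistent with that signature might look like this (the body is an assumption, not the repository's actual code):

```python
from typing import Any


def parse_field(kwargs: dict, field: str, dtype: type, default_value: Any = None) -> Any:
    # Sketch only: read `field` from the user's generate_kwargs and coerce
    # it to `dtype`; when the key is absent (or explicitly None), fall back
    # to default_value, which after this commit is the matching
    # GenerateRequest default rather than None.
    if field in kwargs and kwargs[field] is not None:
        return dtype(kwargs[field])
    return default_value
```

Either way the effect is the same: requests that leave a knob unspecified inherit the transformers defaults instead of clobbering them with None.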
