@@ -10,24 +10,24 @@ class BaseResponse(BaseModel):
10
10
11
11
class GenerateRequest (BaseModel ):
12
12
text : List [str ] = None
13
- min_length : int = None
14
- do_sample : bool = None
15
- early_stopping : bool = None
16
- temperature : float = None
17
- top_k : int = None
18
- top_p : float = None
19
- typical_p : float = None
20
- repetition_penalty : float = None
13
+ min_length : int = 0
14
+ do_sample : bool = False
15
+ early_stopping : bool = False
16
+ temperature : float = 1
17
+ top_k : int = 50
18
+ top_p : float = 1
19
+ typical_p : float = 1
20
+ repetition_penalty : float = 1
21
21
bos_token_id : int = None
22
22
pad_token_id : int = None
23
23
eos_token_id : int = None
24
- length_penalty : float = None
25
- no_repeat_ngram_size : int = None
26
- encoder_no_repeat_ngram_size : int = None
24
+ length_penalty : float = 1
25
+ no_repeat_ngram_size : int = 0
26
+ encoder_no_repeat_ngram_size : int = 0
27
27
max_time : float = None
28
28
max_new_tokens : int = None
29
29
decoder_start_token_id : int = None
30
- diversity_penalty : float = None
30
+ diversity_penalty : float = 0
31
31
forced_bos_token_id : int = None
32
32
forced_eos_token_id : int = None
33
33
exponential_decay_length_penalty : float = None
@@ -89,32 +89,51 @@ def parse_field(kwargs: dict, field: str, dtype: type, default_value: Any = None
89
89
90
90
def create_generate_request (text : List [str ], generate_kwargs : dict ) -> GenerateRequest :
91
91
# get user generate_kwargs as json and parse it
92
+ default_request = GenerateRequest ()
93
+
92
94
return GenerateRequest (
93
95
text = text ,
94
- min_length = parse_field (generate_kwargs , "min_length" , int ),
95
- do_sample = parse_field (generate_kwargs , "do_sample" , bool ),
96
- early_stopping = parse_field (generate_kwargs , "early_stopping" , bool ),
97
- num_beams = parse_field (generate_kwargs , "num_beams" , int ),
98
- temperature = parse_field (generate_kwargs , "temperature" , float ),
99
- top_k = parse_field (generate_kwargs , "top_k" , int ),
100
- top_p = parse_field (generate_kwargs , "top_p" , float ),
101
- typical_p = parse_field (generate_kwargs , "typical_p" , float ),
102
- repetition_penalty = parse_field (generate_kwargs , "repetition_penalty" , float ),
103
- bos_token_id = parse_field (generate_kwargs , "bos_token_id" , int ),
104
- pad_token_id = parse_field (generate_kwargs , "pad_token_id" , int ),
105
- eos_token_id = parse_field (generate_kwargs , "eos_token_id" , int ),
106
- length_penalty = parse_field (generate_kwargs , "length_penalty" , float ),
107
- no_repeat_ngram_size = parse_field (generate_kwargs , "no_repeat_ngram_size" , int ),
108
- encoder_no_repeat_ngram_size = parse_field (generate_kwargs , "encoder_no_repeat_ngram_size" , int ),
109
- max_time = parse_field (generate_kwargs , "max_time" , float ),
110
- max_new_tokens = parse_field (generate_kwargs , "max_new_tokens" , int ),
111
- decoder_start_token_id = parse_field (generate_kwargs , "decoder_start_token_id" , int ),
112
- num_beam_group = parse_field (generate_kwargs , "num_beam_group" , int ),
113
- diversity_penalty = parse_field (generate_kwargs , "diversity_penalty" , float ),
114
- forced_bos_token_id = parse_field (generate_kwargs , "forced_bos_token_id" , int ),
115
- forced_eos_token_id = parse_field (generate_kwargs , "forced_eos_token_id" , int ),
116
- exponential_decay_length_penalty = parse_field (generate_kwargs , "exponential_decay_length_penalty" , float ),
117
- remove_input_from_output = parse_field (generate_kwargs , "remove_input_from_output" , bool , False ),
96
+ min_length = parse_field (generate_kwargs , "min_length" , int , default_request .min_length ),
97
+ do_sample = parse_field (generate_kwargs , "do_sample" , bool , default_request .do_sample ),
98
+ early_stopping = parse_field (generate_kwargs , "early_stopping" , bool , default_request .early_stopping ),
99
+ temperature = parse_field (generate_kwargs , "temperature" , float , default_request .temperature ),
100
+ top_k = parse_field (generate_kwargs , "top_k" , int , default_request .top_k ),
101
+ top_p = parse_field (generate_kwargs , "top_p" , float , default_request .top_p ),
102
+ typical_p = parse_field (generate_kwargs , "typical_p" , float , default_request .typical_p ),
103
+ repetition_penalty = parse_field (
104
+ generate_kwargs , "repetition_penalty" , float , default_request .repetition_penalty
105
+ ),
106
+ bos_token_id = parse_field (generate_kwargs , "bos_token_id" , int , default_request .bos_token_id ),
107
+ pad_token_id = parse_field (generate_kwargs , "pad_token_id" , int , default_request .pad_token_id ),
108
+ eos_token_id = parse_field (generate_kwargs , "eos_token_id" , int , default_request .eos_token_id ),
109
+ length_penalty = parse_field (generate_kwargs , "length_penalty" , float , default_request .length_penalty ),
110
+ no_repeat_ngram_size = parse_field (
111
+ generate_kwargs , "no_repeat_ngram_size" , int , default_request .no_repeat_ngram_size
112
+ ),
113
+ encoder_no_repeat_ngram_size = parse_field (
114
+ generate_kwargs , "encoder_no_repeat_ngram_size" , int , default_request .encoder_no_repeat_ngram_size
115
+ ),
116
+ max_time = parse_field (generate_kwargs , "max_time" , float , default_request .max_time ),
117
+ max_new_tokens = parse_field (generate_kwargs , "max_new_tokens" , int , default_request .max_new_tokens ),
118
+ decoder_start_token_id = parse_field (
119
+ generate_kwargs , "decoder_start_token_id" , int , default_request .decoder_start_token_id
120
+ ),
121
+ diversity_penalty = parse_field (generate_kwargs , "diversity_penalty" , float , default_request .diversity_penalty ),
122
+ forced_bos_token_id = parse_field (
123
+ generate_kwargs , "forced_bos_token_id" , int , default_request .forced_bos_token_id
124
+ ),
125
+ forced_eos_token_id = parse_field (
126
+ generate_kwargs , "forced_eos_token_id" , int , default_request .forced_eos_token_id
127
+ ),
128
+ exponential_decay_length_penalty = parse_field (
129
+ generate_kwargs ,
130
+ "exponential_decay_length_penalty" ,
131
+ float ,
132
+ default_request .exponential_decay_length_penalty ,
133
+ ),
134
+ remove_input_from_output = parse_field (
135
+ generate_kwargs , "remove_input_from_output" , bool , default_request .remove_input_from_output
136
+ ),
118
137
)
119
138
120
139
0 commit comments