
Commit 1df0df1

Implement repetition-penalty for generation
1 parent dcc9309 commit 1df0df1

3 files changed: +29 -3 lines

README.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -134,8 +134,8 @@ Other configurable options include the top-p (nucleus sampling) probability, and
 To test generation latency (e.g. batch size = 1) with different sampling strategies:
 
 ```
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
-python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
+python benchmarks/benchmark_generation_mamba_simple.py --model-name "EleutherAI/pythia-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
 ```
 
 To test generation throughput with random prompts (e.g. large batch size):
````
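The latency these commands report is wall-clock time around the whole generate call (prompt processing plus decoding). As a minimal sketch of how such a measurement can be taken, assuming a CUDA device and with `fn` standing in for the wrapped generate call as in the benchmark script:

```python
# Sketch only: CUDA-event timing around a generation call.
# `fn` is assumed to wrap model.generate(...) as in the benchmark script.
import torch

torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
out = fn()
end.record()
torch.cuda.synchronize()
print(f"generation latency: {start.elapsed_time(end):.0f} ms")
```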

benchmarks/benchmark_generation_mamba_simple.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@
 parser.add_argument("--temperature", type=float, default=1.0)
 parser.add_argument("--topk", type=int, default=1)
 parser.add_argument("--topp", type=float, default=1.0)
+parser.add_argument("--repetition-penalty", type=float, default=1.0)
 parser.add_argument("--batch", type=int, default=1)
 args = parser.parse_args()
 
@@ -61,6 +62,7 @@
         temperature=args.temperature,
         top_k=args.topk,
         top_p=args.topp,
+        repetition_penalty=args.repetition_penalty,
     )
 else:
     fn = lambda: model.generate(
@@ -73,6 +75,7 @@
         temperature=args.temperature,
         top_k=args.topk,
         top_p=args.topp,
+        repetition_penalty=args.repetition_penalty,
     )
 out = fn()
 if args.prompt is not None:
```
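The same option can be used programmatically. Below is a hedged sketch mirroring the benchmark script; the checkpoint and prompt come from the README commands above, and `EleutherAI/gpt-neox-20b` is the tokenizer the benchmark script loads for Mamba checkpoints:

```python
# Sketch only: programmatic equivalent of the CLI command, assuming a CUDA device.
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device="cuda", dtype=torch.float16)

prompt = "My cat wrote all this CUDA code for a new language model and"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

out = model.generate(
    input_ids=input_ids,
    max_length=100,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.2,  # the new option: values > 1.0 down-weight already-generated tokens
    return_dict_in_generate=True,
)
print(tokenizer.batch_decode(out.sequences.tolist()))
```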

mamba_ssm/utils/generation.py

Lines changed: 24 additions & 1 deletion
```diff
@@ -60,6 +60,20 @@ def modify_logits_for_top_p_filtering(logits, top_p):
     logits.masked_fill_(indices_to_remove, float("-inf"))
 
 
+def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
+    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
+    logits: (batch_size, vocab_size)
+    prev_output_tokens: (batch_size, seq_len)
+    """
+    if repetition_penalty == 1.0:
+        return logits
+    score = torch.gather(logits, 1, prev_output_tokens)
+    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+    logits.scatter_(1, prev_output_tokens, score)
+    return logits
+
+
 def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
     """Sample from top-k logits.
     Arguments:
@@ -97,6 +111,7 @@ def decode(
     top_k=1,
     top_p=0.0,
     temperature=1.0,
+    repetition_penalty=1.0,
     eos_token_id=None,
     teacher_outputs=None,
     vocab_size=None,
@@ -186,10 +201,18 @@ def should_stop(current_token, inference_params):
     if enable_timing:
         start.record()
     scores, sequences = [], [input_ids]
+    sequences_cat = input_ids
     while not should_stop(sequences[-1], inference_params):
         scores.append(get_logits(sequences[-1], inference_params))
         inference_params.seqlen_offset += sequences[-1].shape[1]
-        sampled_tokens = sample_tokens(scores[-1], inference_params)
+        if repetition_penalty == 1.0:
+            sampled_tokens = sample_tokens(scores[-1], inference_params)
+        else:
+            logits = modify_logit_for_repetition_penalty(
+                scores[-1].clone(), sequences_cat, repetition_penalty
+            )
+            sampled_tokens = sample_tokens(logits, inference_params)
+            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
         sequences.append(sampled_tokens)
         if streamer is not None:
             streamer.put(sampled_tokens.cpu())
```
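To make the gather/where/scatter logic concrete, here is a self-contained toy run of the same rule (the numbers are made up). Positive logits of previously generated tokens are divided by the penalty and negative ones multiplied, so both directions make the token less likely; in `decode`, `scores[-1]` is first `.clone()`d because the function writes back into its `logits` argument in place, which keeps the raw scores list unmodified:

```python
import torch

# Toy illustration of the repetition-penalty rule above (made-up values).
logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])   # (batch_size=1, vocab_size=4)
prev_output_tokens = torch.tensor([[0, 1]])      # tokens 0 and 1 were generated earlier
repetition_penalty = 1.2

score = torch.gather(logits, 1, prev_output_tokens)  # tensor([[ 2.0, -1.0]])
# Divide positive scores, multiply negative ones: both become less likely.
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
logits.scatter_(1, prev_output_tokens, score)        # in-place write-back

print(logits)  # tensor([[ 1.6667, -1.2000,  0.5000,  3.0000]])
```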
