Fix logprobs #787
Conversation
Previously, I had a more straightforward but partial cherry-pick of ggml-org/llama.cpp#10783 at sayap/ik_llama.cpp@a8e39f6e, but that can't be applied anymore after #723.
examples/server/utils.hpp (Outdated)
}

// sort tokens by logits
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
Do these need to be sorted?
I think so, the caller expects the vector to be sorted when returning the top n_probs tokens.
But is this just to be able to display the logprobs in the UI? If so, wouldn't it be better to have a command line argument to enable/disable this relatively expensive operation to avoid the performance penalty?
This is controlled by the per-request parameters logprobs and top_logprobs. For example, I will use:

"top_k": 1,
"logprobs": true,
"top_logprobs": 3,

to choose the most probable token, and also to retrieve the top 3 logprobs for comparison, e.g. to make sure that the prompt is clear enough such that the most probable token has >= 80% probability.

When I don't need logprobs, I will just go with the defaults of "logprobs": false and "top_logprobs": 0. There is no performance penalty when n_probs is 0.
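As an aside, a minimal sketch of that >= 80% check on the client side (the logprob value is made up; it only assumes the returned logprobs are natural logs):

```cpp
// Hedged sketch: convert the top returned logprob back to a probability and
// check that it clears an 80% threshold.
#include <cmath>
#include <cstdio>

int main() {
    const double top_logprob = -0.105;      // hypothetical top_logprobs[0] value from a response
    const double p = std::exp(top_logprob); // logprobs are natural logs, so p = e^logprob, ~0.90 here
    std::printf("p = %.2f -> %s\n", p, p >= 0.80 ? "prompt is clear enough" : "prompt needs work");
    return 0;
}
```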
In that case, and especially if top_logprobs is typically much smaller than the vocabulary size, one could speed it up by using a partial sort. E.g.
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, int n_sorted) {
    const auto * logits = llama_get_logits_ith(ctx, idx);
    const int n_vocab = llama_n_vocab(llama_get_model(ctx));

    n_sorted = std::min(n_sorted, n_vocab);

    // pair each logit with its token id, then sort only the top n_sorted entries
    std::vector<std::pair<float, llama_token>> sorted(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        sorted[token_id] = {logits[token_id], token_id};
    }
    std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(),
            std::greater<std::pair<float, llama_token>>{});

    // softmax over the full vocabulary, but only the top n_sorted entries are kept
    float max_l   = sorted.front().first;
    float cum_sum = 0.0f;
    std::vector<llama_token_data> cur(n_sorted);
    for (int i = 0; i < n_sorted; ++i) {
        float p = expf(sorted[i].first - max_l);
        cum_sum += p;
        cur[i] = {sorted[i].second, sorted[i].first, p};
    }
    // the remaining (unsorted) entries still contribute to the normalization
    for (int i = n_sorted; i < n_vocab; ++i) {
        cum_sum += expf(sorted[i].first - max_l);
    }
    float inv_cum_sum = 1/cum_sum;
    for (int i = 0; i < n_sorted; ++i) {
        cur[i].p *= inv_cum_sum;
    }
    return cur;
}
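A minimal usage sketch of the helper above, assuming a valid ctx/idx and the llama_token_data layout from llama.h (the wrapper name top_logprobs_for is made up for illustration):

```cpp
// Hedged sketch, not the PR's actual response-building code: collect the top
// n_probs (token, logprob) pairs from the partially sorted result.
static std::vector<std::pair<llama_token, float>> top_logprobs_for(llama_context * ctx, int idx, int n_probs) {
    std::vector<std::pair<llama_token, float>> out;
    const auto cur = get_token_probabilities(ctx, idx, n_probs);
    out.reserve(cur.size());
    for (const auto & td : cur) {
        out.emplace_back(td.id, logf(td.p)); // logprob = ln(p); p is already normalized over the vocab
    }
    return out;
}
```

Since std::partial_sort is roughly O(n_vocab * log n_sorted), with n_sorted = 3 the cost is close to a single linear pass over the logits, which is consistent with the small drop measured further down in the thread.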
Ah, very nice. Let me try that.
Updated the commit to do partial sort. The performance penalty is much smaller now 🎉
partial sort:
- no draft, no logprobs: 12.87 tok/s
- no draft, with top 3 logprobs: 12.61 tok/s (2.0% drop)
- with draft, no logprobs: 36.74 tok/s
- with draft, with top 3 logprobs: 36.12 tok/s (1.7% drop)
full sort:
- no draft, no logprobs: 12.81 tok/s
- no draft, with top 3 logprobs: 12.02 tok/s (6.2% drop)
- with draft, no logprobs: 36.59 tok/s
- with draft, with top 3 logprobs: 29.08 tok/s (20.5% drop)
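To illustrate where that difference comes from, here is a small standalone microbenchmark sketch (not part of the PR; the vocabulary size and logit distribution are made up) comparing a full sort against a partial sort of a vocabulary-sized array:

```cpp
// Hedged microbenchmark sketch: full sort vs partial sort for the top-3
// entries of a vocabulary-sized (logit, token) array.
#include <algorithm>
#include <chrono>
#include <cstdio>
#include <functional>
#include <random>
#include <utility>
#include <vector>

int main() {
    const int n_vocab = 150000; // illustrative vocabulary size
    const int n_top   = 3;

    std::mt19937 rng(42);
    std::normal_distribution<float> dist(0.0f, 4.0f);
    std::vector<std::pair<float, int>> logits(n_vocab);
    for (int i = 0; i < n_vocab; ++i) logits[i] = {dist(rng), i};

    auto bench = [&](const char * name, auto && sort_fn) {
        auto data = logits;
        const auto t0 = std::chrono::steady_clock::now();
        sort_fn(data);
        const auto t1 = std::chrono::steady_clock::now();
        std::printf("%-12s %8.3f ms\n", name,
                    std::chrono::duration<double, std::milli>(t1 - t0).count());
    };

    bench("full sort",    [&](auto & v) { std::sort(v.begin(), v.end(), std::greater<>{}); });
    bench("partial sort", [&](auto & v) { std::partial_sort(v.begin(), v.begin() + n_top, v.end(), std::greater<>{}); });
    return 0;
}
```

The full sort costs O(n_vocab * log n_vocab) per generated token, while the partial sort only pays about O(n_vocab * log n_top), which matches the drop shrinking from ~20% to ~2% in the drafted case above.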
{"content", !slot.params.stream ? slot.generated_text : ""}, | ||
{"generated_text", slot.generated_text}, // Always include full text for finish_reason logic | ||
{"id_slot", slot.id}, | ||
{"stop", true}, |
Don't delete anything from res.data in send partial and final response. It will break the /completions and /completion endpoints.
I see. Indeed I didn't test those endpoints. Let me check how it is handled in mainline.
> I see. Indeed I didn't test those endpoints. Let me check how it is handled in mainline.
Do not follow mainline when it comes to logprobs and those endpoints. They needlessly broke things by changing the output stream when n_probs is requested on those endpoints. Not only that, it also put a performance cost on an API parameter (n_probs) that did not have one before by default (you now have to pass post_sampling_probs to get rid of the performance penalty, and even then the output stream is still modified: probs vs top_probs).
Following the OAI spec is a good thing for the v1 endpoints, as that is their purpose, but /completions and /completion are not based on that spec, and there are third-party front ends that make use of n_probs on the non-OAI endpoints in a non-OAI way.
> I see. Indeed I didn't test those endpoints.
Also if you are looking for a frontend to test the /completion endpoint instead of just sending network requests over the console, the WebUI added in #558 uses it.
Just a note: the performance numbers in that endpoint are measured differently than in the other WebUI, because it measures from the user's perspective, which accounts for any network/browser latency, and does not use the timings sent from the server. (Both are useful and valid, and I may end up adding an option to display both, or just the differential, which would be the network/browser overhead.) I just thought it should be noted, especially if you try to directly compare them.
This commit is mostly a cherry-pick of ggml-org/llama.cpp#10783, plus an optimization to do a partial sort when sorting the logits. That mainline PR and friends were partially cherry-picked by ikawrakow#723, but weren't really in a working state yet.

A couple of additional changes:

* Include timing information in the response, which was (unintentionally?) done in mainline since ggml-org/llama.cpp#10643.
* Also return the actual logprobs for accepted draft tokens. This is still a TODO in mainline [1].

Note that there is a TG performance penalty to return the logprobs, as we need to sort the logits. By doing a partial sort, the penalty is quite small. Here are some numbers I got with Qwen2.5-Coder-32B-Instruct, using the same prompt:

This PR with partial sort:

* no draft, no logprobs: 12.87 tok/s
* no draft, with logprobs: 12.61 tok/s (2.0% drop)
* with draft, no logprobs: 36.74 tok/s
* with draft, with logprobs: 36.12 tok/s (1.7% drop)

If cherry-picking the full sort from the mainline PR:

* no draft, no logprobs: 12.81 tok/s
* no draft, with logprobs: 12.02 tok/s (6.2% drop)
* with draft, no logprobs: 36.59 tok/s
* with draft, with logprobs: 29.08 tok/s (20.5% drop)

[1] https://github.com/ggml-org/llama.cpp/blob/b6548/tools/server/server.cpp#L4019

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>