144 changes: 76 additions & 68 deletions examples/server/server.cpp
@@ -556,6 +556,7 @@ struct slot_params {
std::vector<std::string> antiprompt;

bool timings_per_token = false;
bool post_sampling_probs = false;
json input_prefix;
json input_suffix;

@@ -1545,6 +1546,8 @@ struct server_context {
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);

slot.params.post_sampling_probs = json_value(data, "post_sampling_probs", default_params.post_sampling_probs);

// speculative decoding parameters
slot.params.speculative.n_max = json_value(data, "speculative.n_max", params.n_draft);
slot.params.speculative.n_min = json_value(data, "speculative.n_min", params.n_draft_min);
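For illustration only, a /completion request body that exercises the new option could look like the sketch below, written with the same nlohmann json type the server uses. The "n_probs" and "post_sampling_probs" fields correspond to the json_value() lookups above; "prompt" and "n_predict" are the usual completion fields and the values are placeholders.

// Illustrative sketch, not part of the diff: ask for the top-5 candidates per generated
// token and have their probabilities reported after the sampler chain has run.
json request = {
    {"prompt",              "Hello"},
    {"n_predict",           8},
    {"n_probs",             5},
    {"post_sampling_probs", true},
};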
@@ -1947,26 +1950,7 @@ struct server_context {
}

// check if there is incomplete UTF-8 character at the end
bool incomplete = false;
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
unsigned char c = slot.generated_text[slot.generated_text.size() - i];
if ((c & 0xC0) == 0x80) {
// continuation byte: 10xxxxxx
continue;
}
if ((c & 0xE0) == 0xC0) {
// 2-byte character: 110xxxxx ...
incomplete = i < 2;
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character: 1110xxxx ...
incomplete = i < 3;
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character: 11110xxx ...
incomplete = i < 4;
}
// else 1-byte character or invalid byte
break;
}
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();

if (!incomplete) {
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
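The hand-rolled trailing-byte check above is collapsed into a single validate_utf8() call. As a hedged sketch (the real helper added by this PR lives elsewhere and may be implemented differently), a function consistent with the removed inline logic would return the length of the longest prefix that ends on a complete UTF-8 character, so validate_utf8(s) < s.size() signals a truncated multi-byte sequence:

// Sketch only, not part of the diff: mirrors the removed byte-pattern check.
static size_t validate_utf8(const std::string & text) {
    const size_t len = text.size();
    // scan at most the last 4 bytes for the start byte of a multi-byte sequence
    for (size_t i = 1; i <= 4 && i <= len; ++i) {
        const unsigned char c = text[len - i];
        if ((c & 0xC0) == 0x80) {
            continue; // continuation byte 10xxxxxx, keep scanning backwards
        }
        if ((c & 0xE0) == 0xC0) { return i < 2 ? len - i : len; } // 110xxxxx: 2-byte sequence
        if ((c & 0xF0) == 0xE0) { return i < 3 ? len - i : len; } // 1110xxxx: 3-byte sequence
        if ((c & 0xF8) == 0xF0) { return i < 4 ? len - i : len; } // 11110xxx: 4-byte sequence
        break; // single-byte character (or stray byte): nothing is truncated
    }
    return len;
}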
@@ -2062,6 +2046,56 @@ struct server_context {
return slot.has_next_token; // continue
}

void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.sparams.n_probs;
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));

if (post_sampling) {
const auto * cur_p = llama_sampling_get_candidates(slot.ctx_sampling);
const size_t max_probs = cur_p->size;

// set probability for sampled token
for (size_t i = 0; i < max_probs; i++) {
if (cur_p->data[i].id == result.tok) {
result.prob = cur_p->data[i].p;
break;
}
}

// set probability for top n_probs tokens
result.probs.reserve(max_probs);
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
result.probs.push_back({
cur_p->data[i].id,
llama_detokenize(ctx, {cur_p->data[i].id}, special),
cur_p->data[i].p
});
}
} else {
// TODO: optimize this with min-p optimization
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx, n_probs);

// set probability for sampled token
for (size_t i = 0; i < n_vocab; i++) {
// set probability for sampled token
if (cur[i].id == result.tok) {
result.prob = cur[i].p;
break;
}
}

// set probability for top n_probs tokens
result.probs.reserve(n_probs);
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
result.probs.push_back({
cur[i].id,
llama_detokenize(ctx, {cur[i].id}, special),
cur[i].p
});
}
}
}

json get_formated_generation(const server_slot & slot) const {
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
@@ -2159,6 +2193,7 @@ struct server_context {
res.stop = false;
res.stream = slot.params.stream;
res.content = tkn.text_to_send;
res.post_sampling_probs = slot.params.post_sampling_probs;
res.oaicompat = slot.params.oaicompat;
res.oaicompat_model = slot.params.oaicompat_model;
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2171,26 +2206,18 @@ struct server_context {
{"multimodal", false}
};
slot.update_chat_msg(res.oaicompat_msg_diffs);
if (slot.sparams.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());

std::vector<completion_token_output> probs_output;
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin() + probs_pos,
slot.generated_token_probs.begin() + probs_stop_pos);
}
slot.n_sent_token_probs = probs_stop_pos;

res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
// populate res.probs_output
if (slot.sparams.n_probs > 0) {
res.probs_output = {tkn}; // copy the token probs
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
}

if (slot.oaicompat) {
res.data["oaicompat_token_ctr"] = slot.n_decoded;
res.data["model"] = slot.oaicompat_model;
}

// populate timings if this is final response or timings_per_token is enabled
if (slot.params.timings_per_token) {
res.timings = slot.get_timings();
@@ -2207,6 +2234,8 @@ struct server_context {
res.stop = true; // to do: set value
res.stream = slot.params.stream;
res.content = slot.generated_text;
res.timings = slot.get_timings();
res.post_sampling_probs = slot.params.post_sampling_probs;
res.oaicompat = slot.params.oaicompat;
res.oaicompat_model = slot.params.oaicompat_model;
res.oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
@@ -2234,26 +2263,23 @@ struct server_context {
//{"oaicompat_chat_format", slot.params.oaicompat_chat_format},
};

// populate res.probs_output
if (slot.sparams.n_probs > 0) {
std::vector<completion_token_output> probs;
if (!slot.params.stream && slot.stopped_word) {
const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);

size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
probs = std::vector<completion_token_output>(
res.probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end() - safe_offset);
} else {
probs = std::vector<completion_token_output>(
res.probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end());
}
//res.generation_params = slot.params;
res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
res.data["completion_probabilities"] = probs_vector_to_json(ctx, res.probs_output);
}

res.timings = slot.get_timings();

if (slot.oaicompat) {
res.data["oaicompat_token_ctr"] = slot.n_decoded;
res.data["model"] = slot.oaicompat_model;
@@ -3194,7 +3220,8 @@ struct server_context {
}

completion_token_output result;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
const int tok_idx = slot.i_batch - i;
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, tok_idx);

llama_sampling_accept(slot.ctx_sampling, ctx, id, true);

@@ -3210,35 +3237,12 @@ struct server_context {

slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
result.tok = id;
result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
result.text_to_send = llama_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok));

const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
if (n_probs > 0) {
const size_t n_valid = slot.ctx_sampling->n_valid;

// Make sure at least n_probs top tokens are at the front of the vector:
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
}

if (slot.sparams.temp == 0.0f) {
// With greedy sampling the probabilities have possibly not been calculated.
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id,llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
i == 0 ? 1.0f : 0.0f
});
}
} else {
for (size_t i = 0; i < n_probs; ++i) {
result.probs.push_back({
cur_p.data[i].id, llama_detokenize(ctx, {cur_p.data[i].id}, params.special),
i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
});
}
}
if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, tok_idx);
}

if (!process_token(result, slot)) {
@@ -3343,7 +3347,11 @@ struct server_context {

result.tok = ids[i];
result.text_to_send = llama_token_to_piece(ctx, result.tok, params.special);
// result.prob = 1.0f; // set later
result.prob = 1.0f; // set later

if (slot.sparams.n_probs > 0) {
populate_token_probs(slot, result, slot.params.post_sampling_probs, params.special, i);
}

if (!process_token(result, slot)) {
// release slot because of stop condition
32 changes: 28 additions & 4 deletions examples/server/utils.hpp
@@ -406,7 +406,6 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
return out;
}


//
// OAI utils
//
@@ -616,13 +615,12 @@ static json oaicompat_chat_params_parse(

// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but it seems like no one is really using it; we may need to fix it in the future
if (body.contains("logprobs")) {
if (json_value(body, "logprobs", false)) {
if (has_tools && stream) {
throw std::runtime_error("logprobs is not supported with tools + stream");
}
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
}
else if (body.contains("top_logprobs")) {
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
}

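For reference, a hedged sketch of an OAI-style request body that reaches this branch; the message content is a placeholder and only the two logprobs fields are taken from the code above.

// Illustrative sketch, not part of the diff: with the change above, "logprobs" must be a
// true boolean (not merely present), and "top_logprobs" is mapped onto n_probs.
json body = json::parse(R"({
    "messages": [ { "role": "user", "content": "Hi" } ],
    "logprobs": true,
    "top_logprobs": 5
})");
// after parsing, llama_params["n_probs"] == 5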
@@ -715,3 +713,29 @@ static json format_error_response(const std::string & message, const enum error_
{"type", type_str},
};
}

static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx, int n_sorted) {
const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
n_sorted = std::min(n_sorted, n_vocab);

std::vector<std::pair<float, llama_token>> sorted(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) sorted[token_id] = {logits[token_id], token_id};

std::partial_sort(sorted.begin(), sorted.begin() + n_sorted, sorted.end(), std::greater<std::pair<float,llama_token>>{});

float max_l = sorted.front().first;
float cum_sum = 0.0f;
std::vector<llama_token_data> cur(n_sorted);
for (int i = 0; i < n_sorted; ++i) {
float p = expf(sorted[i].first - max_l);
cum_sum += p;
cur[i] = {sorted[i].second, sorted[i].first, p};
}
for (int i = n_sorted; i < n_vocab; ++i) cum_sum += expf(sorted[i].first - max_l);

float inv_cum_sum = 1/cum_sum;
for (int i = 0; i < n_sorted; ++i) cur[i].p *= inv_cum_sum;

return cur;
}
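A note on get_token_probabilities(): std::partial_sort orders only the first n_sorted entries, but the vector still contains every vocabulary token, so cum_sum accumulates exp(logit - max_l) over the full vocabulary and the returned p values are exact softmax probabilities rather than values renormalized over just the top-n. A hypothetical call site (llama_token_data carries the token id, its raw logit, and its probability) might look like:

// Hypothetical usage sketch, not part of the diff: fetch the top-10 softmax probabilities
// at batch position idx, independent of whatever sampler chain the slot is using.
const std::vector<llama_token_data> top = get_token_probabilities(ctx, idx, 10);
for (const auto & td : top) {
    printf("token %6d  logit % .3f  p %.4f\n", td.id, td.logit, td.p);
}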