Skip to content

Commit

Permalink
server : simplify
Browse files: browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Nov 22, 2024
1 parent 3fbecf8 commit 7dc6ae5
Showing 1 changed file with 42 additions and 49 deletions.
91 changes: 42 additions & 49 deletions examples/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2251,6 +2251,8 @@ struct server_context {

id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);

slot.i_batch = -1;

common_sampler_accept(slot.smpl, id, true);

slot.n_decoded += 1;
Expand All @@ -2277,73 +2279,64 @@ struct server_context {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
continue;
}
}

slot.i_batch = -1;

if (slot.ctx_dft) {
struct common_speculative_params params_spec;
params_spec.n_draft = params.n_draft;
params_spec.n_reuse = 256;
params_spec.p_min = 0.9f;

llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
// check if the slot supports speculative decoding
if (!slot.ctx_dft) {
continue;
}

if (draft.size() > params.n_draft_min) {
common_batch_clear(slot.batch_spec);
common_batch_add(slot.batch_spec, id, slot.n_past++, { slot.id }, true);
// TODO: configurable through requests
struct common_speculative_params params_spec;
params_spec.n_draft = params.n_draft;
params_spec.n_reuse = 256;
params_spec.p_min = 0.9f;

for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + i, { slot.id }, true);
}
llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);

llama_decode(ctx, slot.batch_spec);

const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
if (params.n_draft_min > (int) draft.size()) {
continue;
}

slot.n_past += ids.size() - 1;
// construct the speculation batch
common_batch_clear(slot.batch_spec);
common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);

slot.cache_tokens.push_back(id);
for (size_t i = 0; i < draft.size(); ++i) {
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
}

for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;
llama_decode(ctx, slot.batch_spec);

id = ids[i];
// the accepted tokens from the speculation
const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);

common_sampler_accept(slot.smpl, id, true);
slot.n_past += ids.size();
slot.n_decoded += ids.size();

slot.n_decoded += 1;
if (slot.n_decoded == 1) {
slot.t_start_generation = ggml_time_us();
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}
slot.cache_tokens.push_back(id);
slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);

result.tok = id;
llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);

const auto * cur_p = common_sampler_get_candidates(slot.smpl);
for (size_t i = 0; i < ids.size(); ++i) {
completion_token_output result;

for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
result.probs.push_back({
cur_p->data[i].id,
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
});
}
id = ids[i];

if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
common_sampler_accept(slot.smpl, id, true);

llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
result.tok = id;

slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
if (!process_token(result, slot)) {
// release slot because of stop condition
slot.release();
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
break;
}
}
}
Expand Down

0 comments on commit 7dc6ae5

Please sign in to comment.