Skip to content

Commit 2bad79f

Browse files
committed
Fix the EOS checking
The secondary EOS token is usually `<end_of_turn>`, which can legitimately appear inside the prompt itself, so we must only check for it in generated tokens after the prompt has been consumed.
1 parent 6300c12 commit 2bad79f

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

gemma/gemma-inl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,7 +1427,7 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
14271427
// Sanity check: prompts should not be empty, nor start with EOS.
14281428
for (size_t query_idx = 0; query_idx < queries_prompt.size(); ++query_idx) {
14291429
const PromptTokens& prompt = queries_prompt[query_idx];
1430-
HWY_ASSERT(prompt.size() != 0 && !model.Config().IsEOS(prompt[0]));
1430+
HWY_ASSERT(prompt.size() != 0 && prompt[0] != runtime_config.eos_id);
14311431
}
14321432

14331433
const size_t num_queries = queries_prompt.size();
@@ -1615,4 +1615,4 @@ void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
16151615
} // namespace gcpp
16161616
HWY_AFTER_NAMESPACE();
16171617

1618-
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_
1618+
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_INL_H_

gemma/run.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,12 +118,6 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
118118
// callback function invoked for each generated token.
119119
auto stream_token = [&](int token, float) {
120120
++abs_pos;
121-
if (model.GetModelConfig().IsEOS(token)) {
122-
if (app.verbosity >= 2) {
123-
std::cout << "\n[ End ]\n";
124-
}
125-
return true;
126-
}
127121
const bool in_prompt = tokens_generated_this_turn < prompt_size;
128122
const bool first_response_token = tokens_generated_this_turn == prompt_size;
129123
++tokens_generated_this_turn;
@@ -132,6 +126,11 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
132126
std::cerr << "." << std::flush;
133127
}
134128
return true;
129+
} else if (model.GetModelConfig().IsEOS(token)) {
130+
if (app.verbosity >= 2) {
131+
std::cout << "\n[ End ]\n";
132+
}
133+
return true;
135134
}
136135
std::string token_text;
137136
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));

0 commit comments

Comments
 (0)