diff --git a/gptfast/generate.py b/gptfast/generate.py
index b5c3cc2..1e8446e 100644
--- a/gptfast/generate.py
+++ b/gptfast/generate.py
@@ -134,10 +134,11 @@ def decode_n_tokens(
         next_token, next_prob = next_token.clone(), next_prob.clone()
         input_pos += 1
         new_tokens.append(next_token)
-        callback(new_tokens[-1])
+        generation_done = callback(new_tokens)
         new_probs.append(next_prob)
         cur_token = next_token.view(1, -1)
-
+        if generation_done is True:
+            break
     return new_tokens, new_probs
 
 
@@ -250,9 +251,6 @@ def _load_model(checkpoint_path, device, precision):
     return model.eval()
 
 
-B_INST, E_INST = "[INST]", "[/INST]"
-
-
 def recommended_inductor_config_setter():
     torch._inductor.config.coordinate_descent_tuning = True
     torch._inductor.config.coordinate_descent_check_all_directions = True
@@ -335,10 +333,26 @@ def setup_model_compilation(
 
 
 def process_generation(
-    model, inputs, tokenizer, i, num_samples, profile, device, **generation_kwargs
+    model,
+    inputs,
+    tokenizer,
+    i,
+    num_samples,
+    profile,
+    device,
+    stop_strings=None,
+    **generation_kwargs,
 ):
     t0 = time.perf_counter()
 
+    # Encode stop strings once at the start
+    stop_sequences = None
+    if stop_strings:
+        stop_sequences = [
+            torch.tensor(tokenizer.encode(stop), dtype=torch.int, device=device)
+            for stop in stop_strings
+        ]
+
     prof = (
         torch.profiler.profile(with_stack=True)
         if i == num_samples - 1 and profile
@@ -346,7 +360,18 @@ def process_generation(
     )
 
     with prof:
-        output = generate(model, **inputs, callback=lambda x: x, **generation_kwargs)
+
+        def callback(new_tokens):
+            if stop_sequences:
+                generated = torch.cat(new_tokens)
+                return any(
+                    generated.size(0) >= stop_seq.size(0)
+                    and torch.equal(generated[-stop_seq.size(0) :], stop_seq)
+                    for stop_seq in stop_sequences
+                )
+            return False
+
+        output = generate(model, **inputs, callback=callback, **generation_kwargs)
     if i == -1:
         print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
 
@@ -390,6 +415,7 @@ def main(
     memory_profile: Optional[Path] = None,
     device=default_device,
     precision=torch.bfloat16,
+    stop_strings: Optional[list] = None,
 ) -> None:
     recommended_inductor_config_setter()
     assert checkpoint_path.is_file(), checkpoint_path
@@ -425,6 +451,7 @@ def main(
         "top_k": top_k,
        "cache_size": cache_size,
         "linear_causal_mask": linear_causal_mask,
+        "stop_strings": stop_strings,
     }
 
     for i in range(start, num_samples):
@@ -483,7 +510,7 @@ def main(
     parser.add_argument(
         "--prompt",
         type=str,
-        default="Explain what is the meaning of life in 500 words",
+        default="Explain what is the meaning of life",
         help="Input prompt.",
     )
     parser.add_argument("--image_path", type=str, default=None, help="Image path.")
@@ -532,6 +559,13 @@ def main(
         default=torch.bfloat16,
         help="dtype precision to use",
     )
+    parser.add_argument(
+        "--stop_strings",
+        type=str,
+        nargs="+",
+        default=["<|im_end|>"],
+        help="List of strings that will stop generation when encountered at the end",
+    )
 
     args = parser.parse_args()
     main(
@@ -551,4 +585,5 @@ def main(
         args.memory_profile,
         args.device,
         args.precision,
+        args.stop_strings,
     )
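
Usage sketch (not part of the patch): with this change, stop strings can be supplied on the command line through the new --stop_strings flag added above. The --prompt flag already exists in this script; the second stop string below is purely illustrative, and any other required arguments (for example the model checkpoint) are assumed to be passed as usual.

    python gptfast/generate.py \
        --prompt "Explain what is the meaning of life" \
        --stop_strings "<|im_end|>" "</s>"

Because the argument is declared with nargs="+", several stop strings can be listed after the flag; decoding stops as soon as the tail of the generated token sequence matches the token encoding of any one of them, since the per-step callback then returns True and decode_n_tokens breaks out of its loop.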