diff --git a/generate.py b/generate.py
index 69f94fa86..097e0aa42 100644
--- a/generate.py
+++ b/generate.py
@@ -104,9 +104,14 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 
 
 def prefill(
-    model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs
+    model: Transformer,
+    x: torch.Tensor,
+    input_pos: torch.Tensor,
+    *,
+    sequential_prefill = True,
+    **sampling_kwargs
 ) -> torch.Tensor:
-    print(f"x: {x}, input_pos: {input_pos}")
+    # print(f"x: {x}, input_pos: {input_pos}")
     width = x.size(1)
     assert input_pos.size(0) == width
     sequential_prefill = True
@@ -114,7 +119,7 @@ def prefill(
     if sequential_prefill:
         for i in range(width):
            x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-            print(f" x: {x_sliced}, input_pos: {ip_sliced}")
+            #print(f" x: {x_sliced}, input_pos: {ip_sliced}")
             logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])
     else:
         # input_pos: [B, S]
@@ -157,13 +162,6 @@ def decode_n_tokens(
     return new_tokens, new_probs
 
 
-# try:
-#     from .thin_wrapper import model_forward
-#
-# except:
-#     print("compiled model load not successful, running eager model")
-
-
 def model_forward(model, x, input_pos):
     return model(x, input_pos)
 
@@ -374,7 +372,7 @@ def _main(
     encoded = encode_tokens(
         tokenizer, generator_args.prompt, bos=True, device=builder_args.device
     )
-    print(encoded)
+    # print(encoded)
     prompt_length = encoded.size(0)
 
     model_size = sum(
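Note on the new keyword-only flag: the first hunk adds `sequential_prefill` to `prefill()`'s signature, but the pre-existing `sequential_prefill = True` assignment in the body is left in place as a context line, so it still shadows whatever the caller passes; `sequential_prefill=False` has no effect until that line is also removed. A minimal usage sketch follows; `run_prefill` is a hypothetical helper, and `model` / `encoded` stand in for the built Transformer and the tokenizer output elsewhere in generate.py — only `prefill` itself comes from this diff.

    import torch

    def run_prefill(model, encoded: torch.Tensor, *, sequential: bool) -> torch.Tensor:
        # prefill() expects x of shape [B, S] and one position per prompt token;
        # `encoded` is the 1-D token tensor produced by encode_tokens().
        x = encoded.view(1, -1)
        input_pos = torch.arange(x.size(1), device=x.device)
        return prefill(model, x, input_pos, sequential_prefill=sequential)

    # Token-by-token prefill -- the only branch reachable while the hardcoded
    # `sequential_prefill = True` remains inside prefill():
    # first_token = run_prefill(model, encoded, sequential=True)

    # Intended batched prefill once that override is dropped:
    # first_token = run_prefill(model, encoded, sequential=False)

Making the flag keyword-only (`*,`) is a reasonable design choice here: it keeps existing positional call sites of `prefill(model, x, input_pos, ...)` working unchanged while forcing new callers to name the mode explicitly.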