Skip to content

Commit

Permalink
Move more generator args to use dataclass (#233)
Browse files Browse the repository at this point in the history
* prompt

* chat_mode, num_samples

* move more args

* more gen args

* update

* args

* undo some changes

* typos
  • Loading branch information
mikekgfb authored Apr 17, 2024
1 parent 55aa360 commit 1ea7739
Showing 1 changed file with 8 additions and 13 deletions.
21 changes: 8 additions & 13 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
encoded_prompt: Optional[torch.Tensor] = None
chat_mode: bool = False
gui_mode: bool = False
num_samples: int = 1
Expand All @@ -45,6 +46,7 @@ class GeneratorArgs:
def from_args(cls, args): # -> GeneratorArgs:
return cls(
prompt=args.prompt,
encoded_prompt=None,
chat_mode=args.chat,
gui_mode=args.gui,
num_samples=args.num_samples,
Expand Down Expand Up @@ -305,7 +307,7 @@ def generate(
return seq, generate_stats


def encode_tokens(tokenizer, string, bos=True, device="cuda"):
def encode_tokens(tokenizer, string, bos=True, device="cpu"):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
Expand All @@ -317,13 +319,9 @@ def _main(
speculative_builder_args: BuilderArgs,
tokenizer_args: TokenizerArgs,
generator_args: GeneratorArgs,
max_new_tokens: int = 100,
top_k: int = 200,
temperature: float = 0.8,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
speculate_k: int = 5,
quantize=None,
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
Expand Down Expand Up @@ -436,6 +434,7 @@ def callback(x):
t0 = time.perf_counter()
import contextlib

generator_args.encoded_prompt = encoded
if (i != generator_args.num_samples - 1 or not profile) or (use_tp and rank != 0):
prof = contextlib.nullcontext()
else:
Expand All @@ -445,13 +444,13 @@ def callback(x):
y, metrics = generate(
model,
encoded,
max_new_tokens,
generator_args.max_new_tokens,
draft_model=draft_model,
speculate_k=speculate_k,
speculate_k=generator_args.speculate_k,
chat_mode=generator_args.chat_mode,
callback=callback,
temperature=temperature,
top_k=top_k,
temperature=generator_args.temperature,
top_k=generator_args.top_k,
)
aggregate_metrics["accept_counts"].append(metrics["accept_counts"])
if i == -1:
Expand Down Expand Up @@ -502,13 +501,9 @@ def main(args):
speculative_builder_args,
tokenizer_args,
generator_args,
args.max_new_tokens,
args.top_k,
args.temperature,
args.compile,
args.compile_prefill,
args.profile,
args.speculate_k,
args.quantize,
)

Expand Down

0 comments on commit 1ea7739

Please sign in to comment.