Skip to content

Commit

Permalink
Ignore tokens per sec from jit_compile iteration (#1378)
Browse files Browse the repository at this point in the history
* Remove tokens per sec in aggregate_metrics when jit_compile

* Add warning to user

* Update

---------

Co-authored-by: Jack-Khuu <[email protected]>
  • Loading branch information
yanbing-j and Jack-Khuu authored Nov 19, 2024
1 parent 5da240a commit a6a6e61
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,9 +1149,11 @@ def callback(x, *, done_generating=False):
print(
f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
)
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
else:
# Metrics from the jit_compile iteration are not appended to aggregate_metrics, so the reported averages exclude compilation.
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)

logging.info(
f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
Expand Down Expand Up @@ -1205,7 +1207,8 @@ def callback(x, *, done_generating=False):
or torch.isnan(torch.tensor(avg_next_tokens_sec))
):
print(
f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \
f"\nWarning: Excluding compile in calculations \
\n Average tokens/sec (total): {avg_tokens_sec:.2f} \
\nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \
\nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\
"
Expand Down

0 comments on commit a6a6e61

Please sign in to comment.