From a6a6e61a93dad56c2eac6b0d4a9402398b6971e3 Mon Sep 17 00:00:00 2001 From: YanbingJiang Date: Wed, 20 Nov 2024 00:09:02 +0800 Subject: [PATCH] Ignore tokens per sec from jit_compile iteration (#1378) * Remove tokens per sec in aggregate_metrics when jit_compile * Add warning to user * Update --------- Co-authored-by: Jack-Khuu --- torchchat/generate.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 66f26ff9f..9b4c6430a 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1149,9 +1149,11 @@ def callback(x, *, done_generating=False): print( f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds" ) - aggregate_metrics["tokens_per_sec"].append(tokens_sec) - aggregate_metrics["first_token_per_sec"].append(first_token_sec) - aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) + else: + # aggregate_metrics will not append when is jit_compile, which will affect the average numbers. + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["first_token_per_sec"].append(first_token_sec) + aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) logging.info( f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ @@ -1205,7 +1207,8 @@ def callback(x, *, done_generating=False): or torch.isnan(torch.tensor(avg_next_tokens_sec)) ): print( - f"\n Average tokens/sec (total): {avg_tokens_sec:.2f} \ + f"\nWarning: Excluding compile in calculations \ + \n Average tokens/sec (total): {avg_tokens_sec:.2f} \ \nAverage tokens/sec (first token): {avg_first_token_sec:.2f} \ \nAverage tokens/sec (next tokens): {avg_next_tokens_sec:.2f} \n\ "