Skip to content

Commit

Permalink
Merge pull request #67 from rhymes-ai/gptfast_imend
Browse files Browse the repository at this point in the history
feat: add stop strings support for gptfast
  • Loading branch information
xffxff authored Nov 13, 2024
2 parents 5191428 + 278e9d4 commit a2838e8
Showing 1 changed file with 43 additions and 8 deletions.
51 changes: 43 additions & 8 deletions gptfast/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,11 @@ def decode_n_tokens(
next_token, next_prob = next_token.clone(), next_prob.clone()
input_pos += 1
new_tokens.append(next_token)
callback(new_tokens[-1])
generation_done = callback(new_tokens)
new_probs.append(next_prob)
cur_token = next_token.view(1, -1)

if generation_done is True:
break
return new_tokens, new_probs


Expand Down Expand Up @@ -250,9 +251,6 @@ def _load_model(checkpoint_path, device, precision):
return model.eval()


B_INST, E_INST = "[INST]", "[/INST]"


def recommended_inductor_config_setter():
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
Expand Down Expand Up @@ -335,18 +333,45 @@ def setup_model_compilation(


def process_generation(
model, inputs, tokenizer, i, num_samples, profile, device, **generation_kwargs
model,
inputs,
tokenizer,
i,
num_samples,
profile,
device,
stop_strings=None,
**generation_kwargs,
):
t0 = time.perf_counter()

# Encode stop strings once at the start
stop_sequences = None
if stop_strings:
stop_sequences = [
torch.tensor(tokenizer.encode(stop), dtype=torch.int, device=device)
for stop in stop_strings
]

prof = (
torch.profiler.profile(with_stack=True)
if i == num_samples - 1 and profile
else contextlib.nullcontext()
)

with prof:
output = generate(model, **inputs, callback=lambda x: x, **generation_kwargs)

def callback(new_tokens):
if stop_sequences:
generated = torch.cat(new_tokens)
return any(
generated.size(0) >= stop_seq.size(0)
and torch.equal(generated[-stop_seq.size(0) :], stop_seq)
for stop_seq in stop_sequences
)
return False

output = generate(model, **inputs, callback=callback, **generation_kwargs)

if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
Expand Down Expand Up @@ -390,6 +415,7 @@ def main(
memory_profile: Optional[Path] = None,
device=default_device,
precision=torch.bfloat16,
stop_strings: Optional[list] = None,
) -> None:
recommended_inductor_config_setter()
assert checkpoint_path.is_file(), checkpoint_path
Expand Down Expand Up @@ -425,6 +451,7 @@ def main(
"top_k": top_k,
"cache_size": cache_size,
"linear_causal_mask": linear_causal_mask,
"stop_strings": stop_strings,
}

for i in range(start, num_samples):
Expand Down Expand Up @@ -483,7 +510,7 @@ def main(
parser.add_argument(
"--prompt",
type=str,
default="Explain what is the meaning of life in 500 words",
default="Explain what is the meaning of life",
help="Input prompt.",
)
parser.add_argument("--image_path", type=str, default=None, help="Image path.")
Expand Down Expand Up @@ -532,6 +559,13 @@ def main(
default=torch.bfloat16,
help="dtype precision to use",
)
parser.add_argument(
"--stop_strings",
type=str,
nargs="+",
default=["<|im_end|>"],
help="List of strings that will stop generation when encountered at the end",
)

args = parser.parse_args()
main(
Expand All @@ -551,4 +585,5 @@ def main(
args.memory_profile,
args.device,
args.precision,
args.stop_strings,
)

0 comments on commit a2838e8

Please sign in to comment.