diff --git a/src/deepsparse/operators/engine_operator.py b/src/deepsparse/operators/engine_operator.py
index 3c4e001293..af7f2c6242 100644
--- a/src/deepsparse/operators/engine_operator.py
+++ b/src/deepsparse/operators/engine_operator.py
@@ -19,7 +19,7 @@
 
 from deepsparse.benchmark import ORTEngine
 from deepsparse.engine import Context as EngineContext
-from deepsparse.engine import Engine, MultiModelEngine, Scheduler
+from deepsparse.engine import DebugAnalysisEngine, KVCacheParams, Engine, MultiModelEngine, Scheduler
 from deepsparse.operators import Operator
 from deepsparse.utils import join_engine_outputs, model_to_path, split_engine_inputs
 
@@ -169,7 +169,14 @@ def create_engine(
                 **engine_args,
             )
         engine_args.pop("cache_output_bools", None)
-        return Engine(onnx_file_path, **engine_args)
+        engine_args.pop("num_streams", None)
+        cached_outputs = engine_args.pop("cached_outputs", None)
+        if not cached_outputs:
+            raise ValueError("cached_outputs must be provided to create a DebugAnalysisEngine")
+        engine_args["kv_cache_params"] = KVCacheParams(cached_outputs, 0, 0)
+        engine_args["num_warmup_iterations"] = 0
+        engine_args["num_iterations"] = 1
+        return DebugAnalysisEngine(onnx_file_path, **engine_args)
 
     if engine_type == ORT_ENGINE:
         return ORTEngine(onnx_file_path, **engine_args)
diff --git a/src/deepsparse/pipeline.py b/src/deepsparse/pipeline.py
index e2a1beeab1..e7e1d7502c 100644
--- a/src/deepsparse/pipeline.py
+++ b/src/deepsparse/pipeline.py
@@ -139,8 +139,9 @@ def create(cls, task: str, **kwargs) -> "Pipeline":
                     "Pipeline was not created for the given task. The "
                     "provided task should be registered using the OperatorRegistry"
                 )
-        except Exception:
-            _LOGGER.warning(f"Could not create v2 '{task}' pipeline, trying legacy")
+        except Exception as e:
+            _LOGGER.warning(f"Could not create v2 '{task}' pipeline, with error: {e}")
+            _LOGGER.warning("Attempting to create the legacy pipeline")
             from deepsparse.legacy import Pipeline
 
             pipeline = Pipeline.create(task=task, **kwargs)
diff --git a/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py b/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py
index 0f5174e1c4..df4e587df3 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/autoregressive_preprocess_operator.py
@@ -51,6 +51,12 @@ def can_operate(self, inp: Any) -> bool:
         if inp.get("in_generation"):
             return True
 
+        if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
+            raise RuntimeError(
+                "Not enough kv_cache capacity to run generation. Please use a larger "
+                "sequence_length or a shorter prompt"
+            )
+
         remaining_tokens = len(tokens) - kv_cache.total_num_processed_tokens
         can_process = (
             remaining_tokens > 0 and remaining_tokens < self.prompt_sequence_length
diff --git a/src/deepsparse/transformers/pipelines/text_generation/compile_generations.py b/src/deepsparse/transformers/pipelines/text_generation/compile_generations.py
index 2187e525a1..71c60a5936 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/compile_generations.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/compile_generations.py
@@ -17,7 +17,6 @@
 from pydantic import BaseModel, Field
 
 from deepsparse.operators import Operator
-from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.utils import InferenceState
 
 
@@ -43,9 +42,6 @@ def run(self, inference_state: InferenceState, **kwargs):
         generated_logits = inference_state.current_state.get("generated_logits")
         finished_reason = inference_state.current_state.get("finished_reason")
 
-        if len(finished_reason) == 0:
-            finished_reason.append(FinishReason.LENGTH)
-
         generated_tokens = numpy.array([generated_tokens])
         generated_logits = numpy.concatenate(generated_logits, axis=1)
         return {
diff --git a/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py b/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py
index 4b32722590..97f5393129 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/generate_new_token.py
@@ -19,10 +19,7 @@
 from deepsparse.transformers.pipelines.text_generation.nl_engine_operator import (
     NLEngineOutputs,
 )
-from deepsparse.transformers.schemas.text_generation_schemas import (
-    FinishReason,
-    PromptLogitsNoKVCacheInference,
-)
+from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.utils import InferenceState
 
 
@@ -36,14 +33,16 @@ def __init__(
         self.force_max_tokens = force_max_tokens
         self.tokenizer = tokenizer
 
-    def can_operate(self, inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs]):
+    def can_operate(
+        self, inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"]  # noqa: F821
+    ):
         if inp.in_generation:
             return True
         return False
 
     def run(
         self,
-        inp: Union[PromptLogitsNoKVCacheInference, NLEngineOutputs],
+        inp: Union[NLEngineOutputs, "PrepareForGenerationOutput"],  # noqa: F821
         inference_state: InferenceState,
         **kwargs,
     ):
@@ -52,21 +51,26 @@ def run(
             if isinstance(inp, NLEngineOutputs)
             else inp.prompt_logits
         )
-        kv_cache = inp.kv_cache if isinstance(inp, NLEngineOutputs) else None
+        kv_cache = inp.kv_cache
+
+        max_tokens = inference_state.current_state.get("max_tokens")
+        length_finish_reason = inference_state.current_state.get("length_finish_reason")
+        generated_tokens = inference_state.current_state.get("generated_tokens")
+        num_generated_tokens = len(generated_tokens)
 
         token_generator = inference_state.current_state.get("token_generator")
         token = token_generator.generate(logits=logits[0, -1, :])
         finish_reason = None
 
-        callback = inference_state.current_state.get("callback")
-        stop = inference_state.current_state.get("stop")
-
         if (
             kv_cache is not None
             and kv_cache.total_num_processed_tokens >= kv_cache.capacity
         ):
             finish_reason = FinishReason.CAPACITY
 
+        callback = inference_state.current_state.get("callback")
+        stop = inference_state.current_state.get("stop")
+
         if token == self.tokenizer.eos_token_id and not self.force_max_tokens:
             finish_reason = FinishReason.STOP
 
@@ -84,9 +88,11 @@
                 )
                 finish_reason = FinishReason.CALLBACK
 
-        max_tokens = inference_state.current_state.get("max_tokens")
-        if len(inference_state.current_state.get("generated_tokens")) + 1 >= max_tokens:
-            finish_reason = inference_state.current_state.get("length_finish_reason")
+        # Note: this is +1 as the inference state variable keeping track of all the
+        # generated tokens has not yet been updated with the most recently generated
+        # token from this operator
+        if num_generated_tokens + 1 == max_tokens:
+            finish_reason = length_finish_reason
 
         state_update = {
             "token_generator": token_generator,
diff --git a/src/deepsparse/transformers/pipelines/text_generation/multi_engine_prefill_operator.py b/src/deepsparse/transformers/pipelines/text_generation/multi_engine_prefill_operator.py
index dca4fc3ff9..76383028f9 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/multi_engine_prefill_operator.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/multi_engine_prefill_operator.py
@@ -42,6 +42,12 @@ def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
         tokens = inp.get("tokens")
 
+        if kv_cache.total_num_processed_tokens >= kv_cache.capacity:
+            raise RuntimeError(
+                "Not enough kv_cache capacity to run generation. Please use a larger "
+                "sequence_length or a shorter prompt"
+            )
+
         if len(tokens) < self.prompt_sequence_length:
             return False
 
diff --git a/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py
index 1f631573ae..a38240766c 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/nl_engine_operator.py
@@ -31,7 +31,6 @@
     overwrite_onnx_model_inputs_for_kv_cache_models,
 )
 
-
 __all__ = ["NLEngineOperator", "NLEngineInputs", "NLEngineOutputs"]
 
 
@@ -105,7 +104,6 @@ def split(self) -> List["NLEngineOutputs"]:
 
 
 class NLEngineOperator(EngineOperator):
-
     """
     Operator for the NL Decoder Engine. This Operator inherits from the EngineOperator.
     Specific updates to engine attributes are made through this operator, as well
@@ -117,22 +115,23 @@ class NLEngineOperator(EngineOperator):
     output_schema = NLEngineOutputs
 
     def __init__(
-        self,
-        sequence_length: int,
-        input_ids_length: int,
-        internal_kv_cache: bool = False,
-        **kwargs,
+            self,
+            sequence_length: int,
+            input_ids_length: int,
+            internal_kv_cache: bool = False,
+            **kwargs,
     ):
         self.sequence_length = sequence_length
         self.input_ids_length = input_ids_length
         self.internal_kv_cache = internal_kv_cache
         self.kv_cache_data_type = None
+        self.inference_index = 0
 
         super().__init__(**kwargs)
 
     def create_engine(
-        self, batch_size: Optional[int] = None, engine_kwargs: Optional[dict] = None
+            self, batch_size: Optional[int] = None, engine_kwargs: Optional[dict] = None
     ):
         batch_size = batch_size if batch_size is not None else self._batch_size
 
@@ -164,10 +163,10 @@ def create_engine(
         return super().create_engine(**kwargs, **engine_kwargs)
 
     def override_model_inputs(
-        self,
-        model_path: Union[str, Path],
-        batch_size: int,
-        return_additional_outputs=False,
+            self,
+            model_path: Union[str, Path],
+            batch_size: int,
+            return_additional_outputs=False,
     ):
         """
         Override the model based on the provided batch_size, sequence_length,
@@ -219,10 +218,31 @@ def run(self, inp: NLEngineInputs, **kwargs) -> NLEngineOutputs:
             # we skip the validation
             internal_kv_cache = [x.engine_internal_cache for x in kv_cache]
+            # if inp.engine:
+            #     out = inp.engine._eng_net.execute_list_out(inputs, internal_kv_cache)
+            # else:
+            #     out = self.engine._eng_net.execute_list_out(inputs, internal_kv_cache)
             if inp.engine:
-                out = inp.engine._eng_net.execute_list_out(inputs, internal_kv_cache)
+                out, bench_info = inp.engine._eng_net.benchmark_execute(
+                    inputs, internal_kv_cache
+                )
             else:
-                out = self.engine._eng_net.execute_list_out(inputs, internal_kv_cache)
+                out, bench_info = self.engine._eng_net.benchmark_execute(
+                    inputs, internal_kv_cache
+                )
+
+            import os
+            inference_type = 'prefill' if self.prefill else 'decode'
+            filename = f"analysis-{inference_type}-{self.inference_index}.pickle"
+            if 'WAND_BENCH_ANALYSIS_DIR' in os.environ:
+                filename = os.path.join(os.environ['WAND_BENCH_ANALYSIS_DIR'], filename)
+
+            print(f"Saving text generation inference analysis to {filename}")
+            import pickle
+            with open(filename, 'wb') as f:
+                pickle.dump(bench_info, f, pickle.HIGHEST_PROTOCOL)
+            self.inference_index += 1
+            out = [v for v in out.values()]
         else:
             # run the engine without the LIB.kv_cache object
diff --git a/src/deepsparse/transformers/pipelines/text_generation/pipeline.py b/src/deepsparse/transformers/pipelines/text_generation/pipeline.py
index dede3d4a90..bcb854de91 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/pipeline.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/pipeline.py
@@ -184,6 +184,9 @@ def __init__(
             **engine_kwargs,
         )
 
+        single_engine_operator.prefill = False
+        multi_engine_operator.prefill = True
+
         # NOTE: Currently using pipeline state. Can swap to simply pass in the
         # attributes to the specific Operator that need them, as class attributes.
         pipeline_state_vals[
@@ -239,7 +242,6 @@ def __init__(
             sequence_length=sequence_length,
             prompt_sequence_length=prompt_sequence_length,
             token_generator=token_generator,
-            process_output_operator=process_output,
         )
 
         # TODO: do we want to support lists for different engines?
@@ -286,7 +288,7 @@ def __init__(
                 "compile_logits",
                 "generate_new_token",
             ],
-            "prep_for_generation": "autoregressive_preprocess",
+            "prep_for_generation": "generate_new_token",
             "generate_new_token": "compile_generated_tokens",
         }
 
diff --git a/src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py b/src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py
index d58766c2b1..7f6cb9db5f 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/pipeline_no_kv_cache.py
@@ -19,6 +19,7 @@
 from deepsparse.routers import GraphRouter
 from deepsparse.schedulers import OperatorScheduler
 from deepsparse.transformers.pipelines.text_generation import (
+    CompileGeneratedTokens,
     CompileGenerations,
     GenerateNewTokenOperator,
     JoinOutput,
@@ -73,6 +74,7 @@ def __init__(
             tokenizer=self.tokenizer, force_max_tokens=True
         )
         compile_generations = CompileGenerations()
+        compile_generated_tokens = CompileGeneratedTokens()
         join_output = JoinOutput(tokenizer=self.tokenizer)
         process_outputs = ProcessOutputs(tokenizer=self.tokenizer)
 
@@ -82,6 +84,7 @@ def __init__(
             "engine_operator": engine_operator,
             "prepare_generation": prepare_generation,
             "generate_new_token": generate_new_token,
+            "compile_generated_tokens": compile_generated_tokens,
             "compile_generations": compile_generations,
             "join_output": join_output,
             "process_outputs": process_outputs,
@@ -92,7 +95,8 @@ def __init__(
             "SPLIT": "engine_operator",
             "engine_operator": "prepare_generation",
             "prepare_generation": "generate_new_token",
-            "generate_new_token": "compile_generations",
+            "generate_new_token": "compile_generated_tokens",
+            "compile_generated_tokens": "compile_generations",
             "compile_generations": "JOIN",
             "JOIN": "join_output",
             "join_output": "process_outputs",
diff --git a/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py b/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py
index 6b97b432ae..3318ec88c5 100644
--- a/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py
+++ b/src/deepsparse/transformers/pipelines/text_generation/prep_for_generation.py
@@ -15,35 +15,38 @@
 from typing import Any, Optional
 
 import numpy
+from pydantic import BaseModel, Field
 
 from deepsparse.operators import Operator
-from deepsparse.subgraph_execute import StreamingOutput
 from deepsparse.transformers.pipelines.text_generation import TokenGeneratorOperator
-from deepsparse.transformers.schemas.text_generation_schemas import (
-    FinishReason,
-    PromptLogitsNoKVCacheInference,
-)
+from deepsparse.transformers.schemas.text_generation_schemas import FinishReason
 from deepsparse.transformers.utils.helpers import set_generated_length
 from deepsparse.utils import InferenceState
 
 
-__all__ = ["PrepareGeneration"]
+__all__ = ["PrepareGeneration", "PrepareForGenerationOutput"]
+
+
+class PrepareForGenerationOutput(BaseModel):
+    prompt_logits: Any = Field(
+        description="A set of prompt logits generated during prefill"
+    )
+    kv_cache: Optional[Any] = Field(description="kv cache")
+    in_generation: Optional[bool] = Field(description="in_generation flag")
 
 
 class PrepareGeneration(Operator):
+    output_schema = PrepareForGenerationOutput
+
     def __init__(
         self,
         token_generator: TokenGeneratorOperator,
         prompt_sequence_length: int,
         sequence_length: int,
-        process_output_operator: Optional[Operator] = None,
     ):
         self.sequence_length = sequence_length
         self.token_generator_creator = token_generator
         self.prompt_sequence_length = prompt_sequence_length
-        # Needed for streaming as currently both setting up generation and generating
-        # Will split this up soon
-        self.process_output_operator = process_output_operator
 
     def can_operate(self, inp: Any):
         kv_cache = inp.get("kv_cache")
@@ -79,7 +82,6 @@ def run(
             **inference_state.current_state,
         )
         token_generator = token_generator_creator_output.get("token_generator")
-        token_generator.generate(prompt_logits[0, -1, :])
 
         max_tokens, length_finish_reason = set_generated_length(
             max_length=generation_config.max_length,
@@ -93,43 +95,21 @@ def run(
         state_update = {
             "max_tokens": max_tokens,
             "length_finish_reason": length_finish_reason,
-            "generated_tokens": [token_generator.tokens[-1]],
-            "generated_logits": [prompt_logits]
+            "generated_tokens": [],
+            "generated_logits": [prompt_logits[:, 0:-1, :]]
             if include_prompt_logits
-            else [numpy.expand_dims(prompt_logits[:, -1, :], 0)],
+            else [],
             "finished_reason": [],
             "token_generator": token_generator,
         }
+
         if kv_cache is None:
-            output = PromptLogitsNoKVCacheInference(prompt_logits=prompt_logits)
+            output = {"prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0)}
         else:
             output = {
-                "tokens": token_generator.tokens,
                 "kv_cache": kv_cache,
                 "in_generation": True,
+                "prompt_logits": numpy.expand_dims(prompt_logits[:, -1, :], 0),
             }
 
-        # TODO: maybe break this operator up since it is both generating and setting
-        # up values needed for generation? Holding off on this as this will change
-        # routes slighty and want to confirm wont break anything for non-kv cache
-        if inference_state.current_state.get("streaming") and max_tokens >= 1:
-            finished_reason = [length_finish_reason] if max_tokens == 1 else [None]
-
-            if self.process_output_operator is None:
-                raise ValueError(
-                    "An operator must be provided to process outputs"
-                    "while streaming."
-                )
-            data_to_yield = self.process_output_operator.run(
-                generated_tokens=state_update.get("generated_tokens"),
-                finished_reason=finished_reason,
-                inference_state=inference_state,
-                generated_logits=prompt_logits[0, -1, :],
-            )
-            output = StreamingOutput(
-                data_to_yield=self.process_output_operator.output_schema(
-                    **data_to_yield
-                ),
-                data_to_return=output,
-            )
         return output, state_update
diff --git a/src/deepsparse/transformers/schemas/text_generation_schemas.py b/src/deepsparse/transformers/schemas/text_generation_schemas.py
index b1a15fc67c..32a7675694 100644
--- a/src/deepsparse/transformers/schemas/text_generation_schemas.py
+++ b/src/deepsparse/transformers/schemas/text_generation_schemas.py
@@ -166,11 +166,3 @@ class TextGenerationOutput(BaseModel):
     class Config:
         arbitrary_types_allowed = True
         extra = "allow"
-
-
-class PromptLogitsNoKVCacheInference(BaseModel):
-    prompt_logits: Any = Field(
-        description="A set of prompt logits generated "
-        "during the inference pass with a "
-        "non-kv cache model"
-    )
diff --git a/src/deepsparse/transformers/utils/helpers.py b/src/deepsparse/transformers/utils/helpers.py
index 581a2b0d41..7ee77f6c28 100644
--- a/src/deepsparse/transformers/utils/helpers.py
+++ b/src/deepsparse/transformers/utils/helpers.py
@@ -104,8 +104,10 @@ def set_generated_length(
     :param max_new_tokens: the max_new_tokens attribute, which may be provided
         as part of the input during inference
     """
-    if max_length:
+    if max_length is not None:
         # if max_length provided, use that to cap total tokens generated
+        if max_length == 0:
+            raise ValueError("max_length must be greater than 0")
         max_tokens = max_length
         finish_reason = finish_reason_choices.LENGTH
     else:
diff --git a/tests/deepsparse/transformers/pipelines/test_text_generation.py b/tests/deepsparse/transformers/pipelines/test_text_generation.py
index ec848961b4..825dafee62 100644
--- a/tests/deepsparse/transformers/pipelines/test_text_generation.py
+++ b/tests/deepsparse/transformers/pipelines/test_text_generation.py
@@ -143,6 +143,7 @@ def test_stop_inference_kv_cache_full(prompt):
         expected_generated_tokens_length=max_new_tokens_plus_one,
         expected_finished_reason="capacity",
     )
+
     """
    Check the following structure ok the kv cache:
    minus_one | full | plus_one | plus_two
@@ -152,6 +153,7 @@
    [row B]   | [row C] | [row D]  | [row D]
    ...       | ...     | ...      | ...
""" + # check for the "free" space in the kv cache assert kv_cache_state_full_minus_one["past_key_values.0.key"][:, :, 0, :].sum() == 0 # check for the row A @@ -282,3 +284,53 @@ def test_streaming_non_streaming_generate_same_tokens(pipeline, prompt): tokens.append(g.generations[0].text) output_2 = "".join(tokens) assert output_1 == output_2 + + +def test_edge_cases(pipeline, prompt): + # total length of the generated sequence is just 1 token; this should just use + # the last prompt logit + output = pipeline(prompt=prompt, max_length=1, output_scores=True) + assert len(output.generations[0].score) == 1 + + output = pipeline( + prompt=prompt, max_length=1, output_scores=True, include_prompt_logits=True + ) + assert len(output.generations[0].score) == 11 + + # max_new_tokens == 0 and max_length == 1 should result in the same behaviour + # the generation is only dependent on the prompt logit, not any new generated logit + output = pipeline(prompt=prompt, max_new_tokens=0, output_scores=True) + assert len(output.generations[0].score) == 1 + + output = pipeline( + prompt=prompt, max_new_tokens=0, output_scores=True, include_prompt_logits=True + ) + assert len(output.generations[0].score) == 11 + + # expect total scores/length of the generation to be 2: 1 for the token generated + # from the last prompt logit and the rest generated from the value provided + # using the max_new_tokens argument (which in this case is 1) + output = pipeline(prompt=prompt, max_new_tokens=1, output_scores=True) + assert len(output.generations[0].score) == 2 + + output = pipeline( + prompt=prompt, max_new_tokens=1, output_scores=True, include_prompt_logits=True + ) + assert len(output.generations[0].score) == 12 + + # dont support max_length == 0; raise value error + with pytest.raises(ValueError): + pipeline(prompt=prompt, max_length=0) + + +def test_kv_cache_too_small_for_prefill(prompt): + for i in range(10): + prompt += prompt + + pipeline = Pipeline.create( + task="text_generation", + model_path="hf:mgoin/TinyStories-1M-deepsparse", + sequence_length=25, + ) + with pytest.raises(RuntimeError): + pipeline(prompt=prompt) diff --git a/tests/deepsparse/transformers/text_generation/unit/text_generation/test_pipeline_no_kv_cache.py b/tests/deepsparse/transformers/text_generation/unit/text_generation/test_pipeline_no_kv_cache.py index 7b219ee080..88f79edb36 100644 --- a/tests/deepsparse/transformers/text_generation/unit/text_generation/test_pipeline_no_kv_cache.py +++ b/tests/deepsparse/transformers/text_generation/unit/text_generation/test_pipeline_no_kv_cache.py @@ -14,6 +14,8 @@ import copy +import numpy + from deepsparse import TextGeneration from deepsparse.transformers.pipelines.text_generation.pipeline import ( TextGenerationPipeline, @@ -41,19 +43,24 @@ def test_assert_same_outputs_regardless_of_kv_cache_support(model_attributes): # make sure that kv cache support does not change the output prompt = "Hello, how are you doing today?" 
     _prompt = copy.deepcopy(prompt)
-    max_new_tokens = 16
+    max_length = 16
     pipeline = TextGeneration(
         model_path=model_attributes[1], onnx_model_name="model-orig.onnx"
     )
-    for _ in range(max_new_tokens):
+    non_kv_cache_logits = []
+    for _ in range(max_length):
         # simulate autoregressive generation with non-kv cache pipeline
-        out = pipeline(prompt=_prompt)
+        out = pipeline(prompt=_prompt, output_scores=True)
+        non_kv_cache_logits.append(out.generations[0].score)
         new_token = pipeline.tokenizer.encode(out.generations[0].text)
         old_tokens = pipeline.tokenizer.encode(_prompt)
         _prompt = pipeline.tokenizer.decode(old_tokens + new_token)
-    # max_new_tokens reduced by one, because the pipeline always grabs
-    # the first generated token from the prefill
-    out = TextGeneration(model_path=model_attributes[1])(
-        prompt=prompt, max_new_tokens=max_new_tokens - 1
-    )
+
+    pipeline_kv_cache = TextGeneration(model_path=model_attributes[1])
+    out = pipeline_kv_cache(prompt=prompt, max_length=max_length, output_scores=True)
+    kv_cache_scores = out.generations[0].score
+
+    non_kv_cache_logits = numpy.concatenate(non_kv_cache_logits, axis=0)
+
+    assert numpy.allclose(non_kv_cache_logits, kv_cache_scores, atol=0.001)
     assert _prompt == prompt + out.generations[0].text
diff --git a/tests/deepsparse/transformers/text_generation/unit/text_generation/test_token_generation.py b/tests/deepsparse/transformers/text_generation/unit/text_generation/test_token_generation.py
index ab126ff72e..0c0ad1e703 100644
--- a/tests/deepsparse/transformers/text_generation/unit/text_generation/test_token_generation.py
+++ b/tests/deepsparse/transformers/text_generation/unit/text_generation/test_token_generation.py
@@ -52,18 +52,23 @@ def test_prep_for_generation(
     prompt_logits = [numpy.random.rand(1, len(mock_tokens_multiple), len(tokenizer))]
 
     mock_inference_state.update_state(
-        {"prompt_logits": prompt_logits, "max_tokens": 10}
+        {
+            "prompt_logits": prompt_logits,
+            "include_prompt_logits": True,
+            "max_tokens": 10,
+        }
     )
 
     outputs, state = prep_for_generation.run(
         tokens=mock_tokens_multiple,
         kv_cache=mock_kv_cache_three_tokens_processed,
         inference_state=mock_inference_state,
     )
-    assert len(outputs.get("tokens")) == len(mock_tokens_multiple) + 1
     assert outputs.get("in_generation")
     assert numpy.array_equal(
-        state.get("generated_logits")[0],
-        numpy.expand_dims(prompt_logits[0][:, -1, :], 0),
+        numpy.concatenate(
+            (state.get("generated_logits")[0], outputs.get("prompt_logits")), axis=1
+        ),
+        prompt_logits[0],
     )
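
Note: the patched NLEngineOperator.run above writes one pickle per engine call, named analysis-prefill-<N>.pickle or analysis-decode-<N>.pickle, optionally placed under WAND_BENCH_ANALYSIS_DIR. A minimal sketch for reading those files back is shown below; it assumes only the filename pattern and environment variable used in this patch. The structure of the pickled bench_info object is whatever benchmark_execute returns and is not defined here, so only generic inspection is attempted, and the helper name is hypothetical.

    # inspect_analysis.py: hypothetical helper, not included in this patch
    import glob
    import os
    import pickle

    analysis_dir = os.environ.get("WAND_BENCH_ANALYSIS_DIR", ".")
    for path in sorted(glob.glob(os.path.join(analysis_dir, "analysis-*.pickle"))):
        with open(path, "rb") as f:
            bench_info = pickle.load(f)
        # bench_info is the raw object dumped by the engine; print a generic summary
        print(path, type(bench_info))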