Commit

it works for a single sequence
dbogunowicz committed Nov 20, 2023
1 parent 51c4ee6 commit fa96efb
Showing 2 changed files with 8 additions and 11 deletions.
15 changes: 6 additions & 9 deletions src/deepsparse/v2/text_generation/nl_engine_operator.py
@@ -16,9 +16,9 @@
 import os
 from typing import Any, List, Tuple
 
+import numpy
 from pydantic import BaseModel, Field
 
-from deepsparse.transformers.helpers import overwrite_transformer_onnx_model_inputs
 from deepsparse.utils.onnx import (
     CACHE_INPUT_PREFIX,
     overwrite_onnx_model_inputs_for_kv_cache_models,
@@ -213,12 +213,7 @@ class NlEngineOperatorNoCache(EngineOperator):
     input_schema = NlEngineInputNoCache
     output_schema = None
 
-    def __init__(self, sequence_length, **kwargs):
-        model_path, *_ = overwrite_transformer_onnx_model_inputs(
-            path=kwargs.get("model_path"),
-            max_length=sequence_length,
-            batch_size=kwargs.get("batch_size", 1),
-        )
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
     def run(self, inp: NlEngineInputNoCache, **kwargs) -> Any:
@@ -228,11 +223,13 @@ def run(self, inp: NlEngineInputNoCache, **kwargs) -> Any:
             .run(EngineOperatorInputs(engine_inputs=engine_inputs), **kwargs)
             .get("engine_outputs")
         )
+
+        logits = numpy.compress(inp.attention_mask[0], logits[0], axis=1)
         return {
-            "logits": logits,
+            "logits": [logits],
             "logits_shape": None,
             "deterministic": None,
             "kv_cache": None,
             "tokens": None,
             "sampling_temperature": None,
-        }, {"prompt_logits": logits}
+        }, {"prompt_logits": [logits]}
4 changes: 2 additions & 2 deletions src/deepsparse/v2/text_generation/pipeline.py
@@ -60,6 +60,7 @@ def __init__(
         ) = setup_transformers_pipeline(
             model_path,
             sequence_length,
+            tokenizer_padding_side="right",
             onnx_model_name=onnx_model_name,
             engine_kwargs=engine_kwargs,
         )
@@ -73,14 +74,13 @@ def __init__(
                 sequence_length=sequence_length,
                 tokenizer=self.tokenizer,
             ),
-            NlEngineOperatorNoCache(sequence_length=sequence_length, **engine_kwargs),
+            NlEngineOperatorNoCache(**engine_kwargs),
             PrepareGeneration(
                 sequence_length=sequence_length,
                 prompt_sequence_length=1,
                 token_generator=token_generator,
             ),
             GenerateNewTokenOperator(tokenizer=self.tokenizer, force_max_tokens=True),
             CompileGeneratedTokens(),
             CompileGenerations(),
             JoinOutput(tokenizer=self.tokenizer),
             ProcessOutputs(tokenizer=self.tokenizer),
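The padding side requested here matters for the compress step in the engine operator: with right padding, the real tokens sit at the front of the sequence, so masking out the padded positions leaves a contiguous block of logits that still lines up with the prompt tokens. A rough sketch using a Hugging Face tokenizer (the model name is illustrative, not taken from the commit):

from transformers import AutoTokenizer

# Any tokenizer with a pad token behaves the same way; gpt2 is just an example.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # what tokenizer_padding_side="right" requests

encoded = tokenizer(
    "it works for a single sequence",
    padding="max_length",
    max_length=16,
    return_tensors="np",
)
# With right padding the attention mask is a run of 1s followed by 0s, so
# compressing the logits by this mask keeps a prefix aligned with the tokens.
print(encoded["attention_mask"])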
