Add inference api eval wrapper (#494)
* add subset num batches

* add subset num batches

* remove tiktoken

* remove openai import

* remove bad line

* foo

* add training callback

* modify yamls

* implement train

* fix indexing to get most recent eval result

* finish

* finish

* finish

* finish

* finish

* foo

* foo

* working on debugging changes

* [wip] removing logger dependency from model gauntlet

* remove logger from eval

* remove logger from eval

* remove logger from eval

* debug

* debug

* debug

* debug

* fix

* finish?

* fix bug

* merge main

* fix bug

* Revert "ignore empty outputs"

This reverts commit e0d77bb282c82daa2db686450a551d671d715f27.

* fix pyright

* fix pyright

* update versions

* fix

* merge updates

* remove info from yamls

* remove load in 8bit

* address comments

* address comments

* address comments

* add monkeypatch

* add back in bsz

* add back in bsz

* add openai reqs

* remove branch

* fix conditional import

* Update llmfoundry/models/inference_api_wrapper/interface.py

Co-authored-by: Daniel King <[email protected]>

* Update llmfoundry/models/inference_api_wrapper/openai_causal_lm.py

Co-authored-by: Daniel King <[email protected]>

* Update llmfoundry/models/inference_api_wrapper/interface.py

Co-authored-by: Daniel King <[email protected]>

* fix comments

* fix comments

* Update tests/test_inference_api_eval_wrapper.py

* Update tests/test_inference_api_eval_wrapper.py

* Update tests/test_inference_api_eval_wrapper.py

* Update tests/test_inference_api_eval_wrapper.py

* Update tests/test_inference_api_eval_wrapper.py

* pyright ignore

* Update tests/test_inference_api_eval_wrapper.py

* Update tests/test_inference_api_eval_wrapper.py

* Update tests/test_inference_api_eval_wrapper.py

* Update tests/test_inference_api_eval_wrapper.py

---------

Co-authored-by: Daniel King <[email protected]>
bmosaicml and dakinggg authored Sep 18, 2023
1 parent 1e7f909 commit c369a68
Showing 11 changed files with 763 additions and 13 deletions.
14 changes: 14 additions & 0 deletions llmfoundry/models/inference_api_wrapper/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.models.inference_api_wrapper.interface import \
    InferenceAPIEvalWrapper
from llmfoundry.models.inference_api_wrapper.openai_causal_lm import (
    OpenAICausalLMEvalWrapper, OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper)

__all__ = [
    'OpenAICausalLMEvalWrapper',
    'OpenAIChatAPIEvalWrapper',
    'OpenAITokenizerWrapper',
    'InferenceAPIEvalWrapper',
]
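
For orientation, downstream code can pull the new wrappers straight from this subpackage. The import below only restates the names listed in __all__ above; nothing beyond that is assumed.

# Illustrative import of the subpackage's public API (names taken from __all__ above).
from llmfoundry.models.inference_api_wrapper import (
    InferenceAPIEvalWrapper, OpenAICausalLMEvalWrapper,
    OpenAIChatAPIEvalWrapper, OpenAITokenizerWrapper)
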
110 changes: 110 additions & 0 deletions llmfoundry/models/inference_api_wrapper/interface.py
@@ -0,0 +1,110 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, Optional

import torch
from composer.core.types import Batch
from composer.metrics import InContextLearningMetric
from composer.metrics.nlp import (InContextLearningLMAccuracy,
                                  InContextLearningLMExpectedCalibrationError,
                                  InContextLearningMCExpectedCalibrationError,
                                  InContextLearningMultipleChoiceAccuracy,
                                  InContextLearningQAAccuracy,
                                  LanguageCrossEntropy, LanguagePerplexity)
from composer.models import ComposerModel
from torchmetrics import Metric
from transformers import AutoTokenizer


class InferenceAPIEvalWrapper(ComposerModel):

    def __init__(self, model_cfg: Dict, tokenizer: AutoTokenizer):
        self.tokenizer = tokenizer
        self.labels = None
        # set up training and eval metrics
        eval_metrics = [
            LanguageCrossEntropy(),
            LanguagePerplexity(),
            InContextLearningLMAccuracy(),
            InContextLearningMultipleChoiceAccuracy(),
            InContextLearningQAAccuracy(),
            InContextLearningLMExpectedCalibrationError(),
            InContextLearningMCExpectedCalibrationError()
        ]
        self.eval_metrics = {
            metric.__class__.__name__: metric for metric in eval_metrics
        }
        super().__init__()

    def get_metrics(self, is_train: bool = False):
        if is_train:
            raise NotImplementedError(
                'You cannot use inference wrappers for training')
        else:
            metrics = self.eval_metrics

        return metrics if metrics else {}

    def get_next_token_logit_tensor(self,
                                    prompt: str) -> Optional[torch.Tensor]:
        raise NotImplementedError

    def rebatch(self, batch: Batch):
        # default is a no-op, but Chat API modifies these
        return batch

    def eval_forward(self, batch: Batch, outputs: Optional[Any] = None):
        # Query the API one continuation token at a time and assemble a
        # pseudo-logit tensor per example: prompt positions get one-hot rows,
        # continuation positions get the API's next-token logits, and the
        # remainder is padded with the pad token before stacking the batch.
        output_logits_batch = []
        for tokens, cont_idxs in zip(batch['input_ids'],
                                     batch['continuation_indices']):

            seqlen = tokens.shape[0]
            tokens = tokens.tolist()
            cont_idxs = cont_idxs.tolist()
            expected_cont_tokens = tokens[cont_idxs[0]:cont_idxs[-1] + 1]
            output_logits = torch.nn.functional.one_hot(
                torch.tensor(tokens[1:cont_idxs[0]]),
                num_classes=self.tokenizer.vocab_size)
            for i in range(len(expected_cont_tokens)):
                # decode one token at a time
                prompt = self.tokenizer.decode(tokens[:cont_idxs[0]] +
                                               expected_cont_tokens[0:i])
                next_logit_tensor = self.get_next_token_logit_tensor(prompt)
                if next_logit_tensor is None:
                    continue
                output_logits = torch.cat(
                    [output_logits,
                     next_logit_tensor.reshape(1, -1)])
            padding = torch.nn.functional.one_hot(
                torch.full((seqlen - output_logits.shape[0],),
                           self.tokenizer.pad_token_id),
                num_classes=self.tokenizer.vocab_size)
            output_logits = torch.cat([output_logits, padding])
            output_logits_batch.append(output_logits)

        return torch.stack(output_logits_batch).to(batch['input_ids'].device)

    def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
        batch = self.rebatch(batch)
        self.labels = batch.pop('labels')
        self.labels[:, :-1] = self.labels[:, 1:].clone()
        self.labels[:, -1] = -100
        if isinstance(metric, InContextLearningMetric) and batch.get(
                'mode', None) == 'icl_task':
            assert self.labels is not None
            metric.update(batch, outputs, self.labels)
        else:
            raise NotImplementedError(
                'Inference API wrapper only supports InContextLearningMetrics and mode=icl_task'
            )

    def forward(self):
        raise NotImplementedError(
            "Inference API wrapper doesn't support forward")

    def loss(self):
        raise NotImplementedError("Inference API wrapper doesn't support loss")