EnsembleT5


A simple class override that allows ensembling T5 models from huggingface/transformers during inference. Works with the Hugging Face Trainer.

Usage

from transformers import AutoModelForSeq2SeqLM

# EnsembledT5ForConditionalGeneration is the class defined in this repository.
model1 = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
model2 = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
num_beams = 3
model = EnsembledT5ForConditionalGeneration(model1, model2, num_beams)
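
A minimal generation sketch, assuming the ensembled model exposes the standard generate() API and that the num_beams passed to generate() matches the value given at construction (the tokenizer and input text below are illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")
inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")

# Beam-search decoding through the ensemble; both member models
# contribute to the scores at every decoding step.
output_ids = model.generate(**inputs, num_beams=num_beams, max_length=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))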

Current Limitations

  • Currently supports exactly two models
  • Inference only, via model.generate() or the Trainer with predict_with_generate=True (see the sketch after this list)
  • The models must share the same architecture and config files; ideally they are sister checkpoints of the same base model
  • Currently tested only for beam-search generation
  • Works on a single GPU only
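
A minimal sketch of the Trainer path, assuming the ensembled model drops into Seq2SeqTrainer like a regular T5 model (output_dir is a placeholder, and eval_dataset stands in for a tokenized seq2seq dataset):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="out",                # placeholder
    predict_with_generate=True,      # route evaluation through generate()
    generation_num_beams=num_beams,  # match the ensemble's beam width
    per_device_eval_batch_size=8,
)
trainer = Seq2SeqTrainer(model=model, args=args, tokenizer=tokenizer)
predictions = trainer.predict(eval_dataset)  # placeholder dataset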

Demo (Colab/Kaggle)

  • Colab: https://colab.research.google.com/drive/1JE5kBpwK5qFY8JDtIX3oVWcwC-_BX74n?usp=sharing
  • Kaggle: https://www.kaggle.com/sameen53/ensemblet5-demo
