A simple class override for ensembling T5 models from huggingface/transformers at inference time. Works with the Hugging Face Trainer.
```python
from transformers import AutoModelForSeq2SeqLM

# EnsembledT5ForConditionalGeneration is the wrapper class provided by this repo.
model1 = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
model2 = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
num_beams = 3
model = EnsembledT5ForConditionalGeneration(model1, model2, num_beams)
```
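A minimal generation sketch continuing from the snippet above; the tokenizer checkpoint, example sentence, and max_length are illustrative assumptions, not taken from this repo.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base")  # assumed to match the ensembled checkpoints

# Illustrative input; any task the checkpoints were trained for works here.
inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")

# Beam search, matching the num_beams passed to the wrapper above.
output_ids = model.generate(**inputs, num_beams=num_beams, max_length=64)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```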
Limitations:
- Currently works with two models only.
- Inference only, via either model.generate() or the Trainer with the flag predict_with_generate=True (see the sketch after this list).
- The models must have the same architecture and config files; ideally sister checkpoints.
- Currently tested only with beam-search generation.
- Works on a single GPU only.
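A hedged sketch of the Trainer path from the list above, continuing from the earlier snippets and assuming the standard Seq2SeqTrainer / Seq2SeqTrainingArguments API; the output directory, batch size, and the tiny test_dataset are placeholders, not part of this repo.

```python
from datasets import Dataset
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments

# Tiny illustrative eval set; in practice, tokenize your own data.
raw = Dataset.from_dict({"text": ["translate English to German: Good morning.",
                                  "translate English to German: Thank you."]})
test_dataset = raw.map(lambda ex: tokenizer(ex["text"]), remove_columns=["text"])

args = Seq2SeqTrainingArguments(
    output_dir="ensemble-eval",      # placeholder output directory
    per_device_eval_batch_size=8,
    predict_with_generate=True,      # makes predict() call model.generate()
)

trainer = Seq2SeqTrainer(
    model=model,                     # the ensembled model built above
    args=args,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model1),
    tokenizer=tokenizer,
)

predictions = trainer.predict(test_dataset)
print(tokenizer.batch_decode(predictions.predictions, skip_special_tokens=True))
```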
Demo notebooks:
- Colab: https://colab.research.google.com/drive/1JE5kBpwK5qFY8JDtIX3oVWcwC-_BX74n?usp=sharing
- Kaggle: https://www.kaggle.com/sameen53/ensemblet5-demo