
WIP: Preventing the loss from being computed when the input token is EOS Token #878

Draft
wants to merge 22 commits into main
Conversation

ShashankMosaicML
Contributor

@ShashankMosaicML commented Jan 17, 2024

The model should not be trained to predict the token that follows the eos_token, because that token belongs to a different sequence. This PR implements this logic.

TODO: Experimental verification.
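
A minimal sketch of the masking step (not the actual diff in this PR), assuming targets are already aligned so that `targets[t]` is the token that should follow `input_ids[t]`, and that `-100` is the cross-entropy ignore index; the helper name `mask_loss_after_eos` is hypothetical:

```python
import torch

def mask_loss_after_eos(input_ids: torch.Tensor,
                        targets: torch.Tensor,
                        eos_token_id: int,
                        ignore_index: int = -100) -> torch.Tensor:
    """Set the target to ignore_index wherever the *input* token is EOS,
    so no loss is computed for predicting the token after an EOS boundary
    (i.e. the first token of the next, unrelated sequence)."""
    targets = targets.clone()
    targets[input_ids == eos_token_id] = ignore_index
    return targets
```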

@ShashankMosaicML changed the title Preventing the loss from being computed when the input token is EOS Token WIP: Preventing the loss from being computed when the input token is EOS Token Jan 17, 2024
@samhavens
Contributor

I think having this option is good; some users almost certainly want it.

However, I think this should be optional, as I am not convinced the model shouldn't learn to predict the token after EOS. I'd expect the model to learn that, after EOS (if sequences are joined randomly), it can disregard all context and pick from the distribution of tokens that begin sequences. This is a different distribution from raw unigram frequencies, which are the probabilities it should use when picking a token not conditioned on EOS.
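
As a toy illustration of that distinction (hypothetical corpus, not from this repo): the distribution of sequence-initial tokens is what P(t|EOS) should approximate when sequences are concatenated randomly, and it differs from the raw unigram distribution.

```python
from collections import Counter

# Hypothetical, already-tokenized corpus of three short sequences.
sequences = [["The", "cat", "sat"], ["The", "dog", "ran"], ["A", "bird", "flew"]]

# What P(t | EOS) should approximate when sequences are joined randomly:
start_counts = Counter(seq[0] for seq in sequences)
# What a raw unigram model P(t) would use:
unigram_counts = Counter(tok for seq in sequences for tok in seq)

print(start_counts)    # Counter({'The': 2, 'A': 1})
print(unigram_counts)  # Counter({'The': 2, 'cat': 1, 'sat': 1, 'dog': 1, 'ran': 1, 'A': 1, 'bird': 1, 'flew': 1})
```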

Then, if sequences are not joined randomly, as in that TSP NN method, we definitely want to compute loss.

@ShashankMosaicML
Contributor Author

ShashankMosaicML commented Jan 18, 2024

> Then, if sequences are not joined randomly, as in that TSP NN method, we definitely want to compute loss.

Thanks for your comment! Yes, what you said makes sense. This is still very much a work in progress, and I just wanted to run some experimental tests initially as a sanity check.
Also, this is mainly for the case where we do sequence-id-based masking. In that case, the eos token is still part of the previous sequence, but its target is the first token of the next sequence.
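
To illustrate with a toy packed batch (hypothetical token ids, with `eos_id = 0`): the EOS position carries the previous sequence's id, but its next-token target comes from the following sequence.

```python
import torch

eos_id = 0
# Two packed sequences: [5, 6, eos] and [7, 8, eos] (hypothetical token ids).
input_ids = torch.tensor([5, 6, eos_id, 7, 8, eos_id])
sequence_id = torch.tensor([0, 0, 0, 1, 1, 1])

# Next-token targets for each position (the last position has no target).
inputs = input_ids[:-1]    # [5, 6, eos, 7, 8]
targets = input_ids[1:]    # [6, eos, 7, 8, eos]

# The EOS at position 2 belongs to sequence 0 (per sequence_id), yet its
# target is 7, the first token of sequence 1: a cross-sequence prediction.
crosses_boundary = inputs == eos_id
print(crosses_boundary)            # tensor([False, False,  True, False, False])
print(targets[crosses_boundary])   # tensor([7])
```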

@vchiley
Contributor

vchiley commented Jan 18, 2024

@samhavens should we also add the option to not predict BOS (assuming the previous token is the end of the previous sequence)?

@samhavens
Contributor

@vchiley for models which have both EOS and BOS, are you saying we shouldn't learn that BOS comes after EOS? It isn't worth learning, true, but also... we'll always stop generating at EOS, so it wouldn't matter... or am I misunderstanding?

@samhavens
Contributor

As discussed on Slack, I think that:

  • EOS is effectively a BOS token, so we want P(t|EOS) to be different from P(t), and therefore we don't want to mask this loss.
  • However, when doing seq-id masking, we currently mask EOS for every token other than the first, so we learn P(t_0|EOS), P(t_1|t_0), P(t_2|t_0, t_1), ...
  • So @ShashankMosaicML will confirm this and, if it is happening, shift the mask so that EOS is visible after t_0 (see the sketch below).
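
A sketch of that shift, assuming 1-D `input_ids` and `sequence_id` tensors (the helper names are hypothetical; this is not the actual llm-foundry code): reassign each separator EOS to the sequence that follows it, so the mask built from sequence ids lets t_0, t_1, ... attend to it.

```python
import torch

def shift_eos_to_next_sequence(input_ids: torch.Tensor,
                               sequence_id: torch.Tensor,
                               eos_token_id: int) -> torch.Tensor:
    """Give each separator EOS the sequence id of the sequence that follows it,
    so the EOS acts like the next sequence's BOS under sequence-id masking."""
    sequence_id = sequence_id.clone()
    is_eos = input_ids[:-1] == eos_token_id
    at_boundary = sequence_id[1:] != sequence_id[:-1]
    shift = is_eos & at_boundary
    sequence_id[:-1][shift] = sequence_id[1:][shift]
    return sequence_id

def sequence_id_causal_mask(sequence_id: torch.Tensor) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask: causal AND within the same sequence id."""
    n = sequence_id.shape[0]
    causal = torch.tril(torch.ones(n, n, dtype=torch.bool))
    same_seq = sequence_id.unsqueeze(0) == sequence_id.unsqueeze(1)
    return causal & same_seq
```

With the shifted ids, `sequence_id_causal_mask` still blocks attention into the rest of the previous sequence, but the separator EOS becomes visible to every token of the next sequence.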
