WIP: BPE Training ctc loss and label smooth loss #219
base: master
Conversation
    optimizer=optimizer,
)

total_objf += curr_batch_objf
I think this should be: total_objf += curr_batch_objf * curr_batch_num_utts, because you'll later be normalizing by epoch_num_utts.
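A minimal, self-contained sketch of the weighted accumulation being suggested; the `batches` list and its statistics are made-up stand-ins for the real dataloader and per-batch objective, not snowfall's actual code:

```python
# Hypothetical per-batch statistics standing in for the real training loop.
batches = [
    {"num_utts": 8,  "avg_objf": 1.25},
    {"num_utts": 12, "avg_objf": 1.10},
]

total_objf = 0.0
epoch_num_utts = 0
for b in batches:
    curr_batch_num_utts = b["num_utts"]   # utterances in this minibatch
    curr_batch_objf = b["avg_objf"]       # per-utterance average objective for the batch
    # Weight by the batch size so the final division gives a true per-utterance average.
    total_objf += curr_batch_objf * curr_batch_num_utts
    epoch_num_utts += curr_batch_num_utts

print(total_objf / epoch_num_utts)        # epoch-level objective per utterance
```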
fixed.
After fixing some bugs, WER on test-clean decreased from 3.97% to 3.32%, though it is still higher than espnet's 2.97%.
Great work!!
Sorry for my late response. First, I just want to know whether the difference comes from the training part (probably so) or other parts.
Also, could you point out the main script to me?
Thanks for your kind help! @sw005320
TensorBoard for the above screenshot
Not yet. I will compare the differences between espnet and snowfall.
Currently this PR covers the training part, and #227 focuses on the decoding part. For training, this shell script is the entry point.
Some hyper-parameters are hard-coded in bpe_ctc_att_conformer_train.py or in the constructor of class Conformer. Some of them are listed below:
For the decoding part, the latest decoding implementation is #217, and I plan to port it to espnet after it is approved and finally merged into snowfall. The entry point of decoding is here.
A core function of decoding in bpe_ctc_att_conformer_decode.py is here.
Thanks!
Did you tune it?
…/glynpu/snowfall into training_ctcLoss_labelSmoothLoss
Latest results are:
The result difference between the current PR and espnet is solved by tuning training hyper-parameters with the following modifications:
The reasons for the previous modifications are: as a matter of experience, a smaller batch_size goes with a smaller learning rate, so the learning rate is halved. The module feat_batch_norm also helps, improving WER from 3.32% to 3.17%. As for 35 epochs --> 50 epochs, I just set it arbitrarily to see what happens with more epochs. BTW, I failed to increase max_duration beyond 200 because a larger max_duration easily causes OOM; 200 seems to be the largest my GPUs can handle.
I feel that 80,000 warm-up steps are too many; it requires more epochs for training to converge. I think you can find an optimal point with fewer warm-up steps and comparable performance. Also, how about using 3000 batches?
80,000 is calculated from:
6000 batches --> 3000 batches means max_duration = 200 --> max_duration = 400;
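For context, here is a sketch of the Noam-style warm-up schedule from "Attention Is All You Need", assuming that is the schedule the warm-up steps above refer to; the `d_model` and `factor` values are illustrative, not the ones used in this PR:

```python
def noam_lr(step: int, warmup: int = 80000, d_model: int = 256, factor: float = 1.0) -> float:
    """Transformer learning-rate schedule:
    lr = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)."""
    step = max(step, 1)
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

# Fewer warm-up steps reach a higher peak learning rate earlier, which is why
# training may converge in fewer epochs.
print(noam_lr(8000, warmup=80000))
print(noam_lr(8000, warmup=25000))
```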
As we mentioned in person, I believe a problem with the current setup is that the transformer loss is being normalized (divided by the minibatch size) twice, once in a library function and once in the training script, while the CTC loss is only normalized once.
FYI, espnet did not normalize the CTC and attention loss by the length.
ctc_loss = ctc_loss.sum() / bno

if att_rate != 0.0:
    loss = ((1.0 - att_rate) * ctc_loss + att_rate * att_loss) * accum_grad
As we mentioned in person, I believe a problem with the current setup is that the transformer loss is being normalized (divided by the minibatch size) twice, once in a library function and once in the training script, while the CTC loss is only normalized once.
If we had logged the 2 objectives separately, we likely would have noticed this.
@danpovey I don't think att_loss is normalized twice. After it is computed at line 85, there is no extra normalization applied to att_loss.
BTW, the duplicated normalization of att_loss you mentioned may be about another piece of code (not this PR), which is:
if att_rate != 0.0:
    loss = (- (1.0 - att_rate) * tot_score + att_rate * att_loss) / (len(texts) * accum_grad)
Ah, OK.
In any case, IMO we shouldn't be normalizing even once. (so remove the "/ bno" above; and remove the same thing in the library function that computes the att_loss).
IMO, we should also remove the gradient-clipping step; my feeling is that if it was helping before, it was helping because it was compensating for the normalization that shouldn't have been happening. This setup is non-recurrent so gradient clipping should not be needed. (However, if we encounter instabilities we can revisit this).
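A sketch of what this suggested change could look like, assuming `ctc_loss` holds per-utterance CTC losses and `att_loss` is already a summed scalar; only the usual gradient-accumulation scaling is kept, and this is not the PR's actual code:

```python
import torch

def combine_losses(ctc_loss: torch.Tensor, att_loss: torch.Tensor,
                   att_rate: float, accum_grad: int) -> torch.Tensor:
    # Sum over utterances without dividing by the batch size ("/ bno" removed).
    ctc_loss = ctc_loss.sum()
    if att_rate != 0.0:
        loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss
    else:
        loss = ctc_loss
    # Scale only for gradient accumulation; no clip_grad_norm_ afterwards.
    return loss / accum_grad

# Dummy per-utterance CTC losses and a summed attention loss.
print(combine_losses(torch.tensor([12.3, 15.7]), torch.tensor(20.0), 0.7, 1))
```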
The model with gradient clipping and normalization removed is training now.
The loss curve is as follows; at the beginning of each epoch, the loss value increases suddenly. Did it encounter instabilities?
[TensorBoard loss curves: total loss | CTC loss | attention loss]
The loss values displayed in TensorBoard are normalized by the number of utterances.
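For illustration, a sketch of logging the two objectives separately, each normalized by the number of utterances before writing, as described above; the log directory, tag names, and values are made up:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="exp/tensorboard")   # hypothetical log directory

global_step = 100
num_utts = 16                                       # utterances in this batch
ctc_loss_sum, att_loss_sum = 250.0, 180.0           # dummy summed losses

# Normalize by the number of utterances so curves are comparable across batch sizes.
writer.add_scalar("train/ctc_loss", ctc_loss_sum / num_utts, global_step)
writer.add_scalar("train/att_loss", att_loss_sum / num_utts, global_step)
writer.add_scalar("train/total_loss", (ctc_loss_sum + att_loss_sum) / num_utts, global_step)
writer.close()
```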
sys.exit(-1)

# TODO(Liyong Guo) make this configurable.
lang_dir = Path('data/en_token_list/bpe_unigram5000/')
How do you generate the directory data/en_token_list/bpe_unigram5000/?
I can't find any code responsible for it.
I will add the code to generate the BPE-related files and make a PR to your branch. @glynpu
Currently, they are downloaded from the model zoo.
Please refer to bpe_run.sh, which contains the following downloading code:
git clone https://huggingface.co/GuoLiyong/snowfall_bpe_model
for sub_dir in data; do
ln -sf snowfall_bpe_model/$sub_dir ./
done
Actually, I deliberately didn't submit the code for training the BPE model, because this PR is mainly about the training pipeline.
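For reference, a minimal sketch of how a unigram-5000 token model could be trained with sentencepiece (the tool espnet's recipes rely on); the input path and options are illustrative, not the exact settings behind the downloaded data/en_token_list/bpe_unigram5000/ directory:

```python
import sentencepiece as spm

# Train a unigram model with a 5000-token vocabulary from a plain-text transcript
# file (one sentence per line); both paths below are hypothetical.
spm.SentencePieceTrainer.train(
    input="data/train/text",
    model_prefix="data/en_token_list/bpe_unigram5000/bpe",
    vocab_size=5000,
    model_type="unigram",
    character_coverage=1.0,
)

# Load the trained model and tokenize a sample sentence into subword pieces.
sp = spm.SentencePieceProcessor(model_file="data/en_token_list/bpe_unigram5000/bpe.model")
print(sp.encode("HELLO WORLD", out_type=str))
```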
@@ -87,6 +91,9 @@ def __init__(self, num_features: int, num_classes: int, subsampling_factor: int
        else:
            self.decoder_criterion = None

        # Reference: https://github.com/espnet/espnet/blob/master/espnet2/asr/ctc.py#L37
        self.ctc_loss_fn = torch.nn.CTCLoss(reduction='none')
Let's not retain this pattern; we can make the CTC loss part of the training code. This may not carry over so easily to k2 code, and in any case it is a little inflexible for our purposes.
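As a rough illustration of moving the CTC loss into the training code, here is a sketch using PyTorch's functional CTC loss with made-up shapes; this is not the k2-based CTC loss that a later commit adds:

```python
import torch
import torch.nn.functional as F

T, N, C = 50, 2, 5000 + 1                 # frames, batch size, BPE vocab + blank (illustrative)
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)          # encoder output, shape (T, N, C)
targets = torch.randint(1, C, (N, 20), dtype=torch.long)      # padded BPE token ids (no blanks)
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.full((N,), 20, dtype=torch.long)

# reduction='sum' keeps the loss un-normalized, matching the discussion above.
ctc_loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths,
                      blank=0, reduction='sum', zero_infinity=True)
print(ctc_loss.item())
```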
Does anyone have any pointers to visualizations of the decoder attention in the application of transformers to ASR? I want to get a feel for how it works.
add k2 ctcloss
As mentioned in #217, currently BPE training with ctcLoss and labelSmoothLoss in snowfall obtains a higher WER than that of espnet.
The PROBLEM I am facing is:
The WER of snowfall-trained models is still a little higher than that of the espnet-trained model: 3.32% vs. 2.97%. (Fixed by correcting a data preparation mistake.)
During espnet training, loss_att and loss_ctc always have the same order of magnitude, i.e. they decrease at the same pace. However, during snowfall training, loss_att decreases sharply to even below 1.0 while loss_ctc stays 30 to 100 times larger than loss_att.
espnet training log file: https://github.com/glynpu/bpe_training_log_files/blob/master/espnet-egs2-librispeech-asr1-exp-asr_train_asr_conformer7_n_fft512_hop_length256_raw_en_bpe5000_sp-train.log
snowfall training log file of the WER 3.97% experiment: https://github.com/glynpu/bpe_training_log_files/blob/master/snowfall-egs-librispeech-asr-simple_v1-train_log.txt
snowfall training log file of the WER 3.32% experiment: https://github.com/glynpu/bpe_training_log_files/blob/master/wer_3.32_June_26_snowfall_egs_librispeech-asr-simple_v1-train_log.txt
The things I have tried to make comparable between espnet and snowfall are: