From 9688fcf36355d41263f9b1caac1be33e271a9022 Mon Sep 17 00:00:00 2001
From: Thilina Rajapakse
Date: Sat, 5 Dec 2020 20:51:46 +0530
Subject: [PATCH] Fixed issue with class weights

---
 .../classification/classification_model.py   |   3 -
 .../transformer_models/albert_model.py       |   7 +-
 .../transformer_models/bert_model.py         |   7 +-
 .../transformer_models/distilbert_model.py   |  14 ++-
 .../transformer_models/electra_model.py      |   7 +-
 .../transformer_models/flaubert_model.py     |   7 +-
 .../transformer_models/layoutlm_model.py     |   7 +-
 .../transformer_models/mmbt_model.py         |   7 +-
 .../transformer_models/roberta_model.py      |   7 +-
 .../transformer_models/xlm_model.py          |   7 +-
 .../transformer_models/xlnet_model.py        |   7 +-
 train.txt                                    | 100 ------------------
 12 files changed, 57 insertions(+), 123 deletions(-)
 delete mode 100644 train.txt

diff --git a/simpletransformers/classification/classification_model.py b/simpletransformers/classification/classification_model.py
index d142dc11..02f486e8 100755
--- a/simpletransformers/classification/classification_model.py
+++ b/simpletransformers/classification/classification_model.py
@@ -1549,9 +1549,6 @@ def _get_inputs_dict(self, batch):
         if self.args.model_type == "layoutlm":
             inputs["bbox"] = batch[4]
 
-        if self.weight is not None:
-            inputs["class_weights"] = self.weight
-
         return inputs
 
     def _get_last_metrics(self, metric_values):
diff --git a/simpletransformers/classification/transformer_models/albert_model.py b/simpletransformers/classification/transformer_models/albert_model.py
index 52df9119..625c90f2 100755
--- a/simpletransformers/classification/transformer_models/albert_model.py
+++ b/simpletransformers/classification/transformer_models/albert_model.py
@@ -51,7 +51,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
 
         outputs = self.albert(
@@ -76,7 +75,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/bert_model.py b/simpletransformers/classification/transformer_models/bert_model.py
index 0cda431e..ca20dd68 100755
--- a/simpletransformers/classification/transformer_models/bert_model.py
+++ b/simpletransformers/classification/transformer_models/bert_model.py
@@ -52,7 +52,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
 
         outputs = self.bert(
@@ -77,7 +76,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/distilbert_model.py b/simpletransformers/classification/transformer_models/distilbert_model.py
index d6b755d0..568d8bef 100755
--- a/simpletransformers/classification/transformer_models/distilbert_model.py
+++ b/simpletransformers/classification/transformer_models/distilbert_model.py
@@ -44,7 +44,13 @@ def __init__(self, config, weight=None):
         self.init_weights()
 
     def forward(
-        self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, class_weights=None,
+        self,
+        input_ids=None,
+        attention_mask=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+        class_weights=None,
     ):
         distilbert_output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask)
         hidden_state = distilbert_output[0]  # (bs, seq_len, dim)
@@ -60,7 +66,11 @@
                 loss_fct = nn.MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/electra_model.py b/simpletransformers/classification/transformer_models/electra_model.py
index 43336630..f82607db 100755
--- a/simpletransformers/classification/transformer_models/electra_model.py
+++ b/simpletransformers/classification/transformer_models/electra_model.py
@@ -54,7 +54,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
 
         discriminator_hidden_states = self.electra(
@@ -71,7 +70,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         output = (logits,) + discriminator_hidden_states[1:]
diff --git a/simpletransformers/classification/transformer_models/flaubert_model.py b/simpletransformers/classification/transformer_models/flaubert_model.py
index 01a1d528..74026455 100644
--- a/simpletransformers/classification/transformer_models/flaubert_model.py
+++ b/simpletransformers/classification/transformer_models/flaubert_model.py
@@ -55,7 +55,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
         transformer_outputs = self.transformer(
             input_ids,
@@ -79,7 +78,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/layoutlm_model.py b/simpletransformers/classification/transformer_models/layoutlm_model.py
index b3b6441e..562f2407 100644
--- a/simpletransformers/classification/transformer_models/layoutlm_model.py
+++ b/simpletransformers/classification/transformer_models/layoutlm_model.py
@@ -25,7 +25,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
 
         outputs = self.bert(
@@ -50,7 +49,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/mmbt_model.py b/simpletransformers/classification/transformer_models/mmbt_model.py
index 32771534..06815085 100644
--- a/simpletransformers/classification/transformer_models/mmbt_model.py
+++ b/simpletransformers/classification/transformer_models/mmbt_model.py
@@ -54,7 +54,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
 
         outputs = self.mmbt(
@@ -84,7 +83,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/roberta_model.py b/simpletransformers/classification/transformer_models/roberta_model.py
index b5ae036d..6d107905 100755
--- a/simpletransformers/classification/transformer_models/roberta_model.py
+++ b/simpletransformers/classification/transformer_models/roberta_model.py
@@ -59,7 +59,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
         outputs = self.roberta(
             input_ids,
@@ -78,7 +77,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/xlm_model.py b/simpletransformers/classification/transformer_models/xlm_model.py
index 0a8fa45a..6bb9fca4 100755
--- a/simpletransformers/classification/transformer_models/xlm_model.py
+++ b/simpletransformers/classification/transformer_models/xlm_model.py
@@ -54,7 +54,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
         transformer_outputs = self.transformer(
             input_ids,
@@ -78,7 +77,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/simpletransformers/classification/transformer_models/xlnet_model.py b/simpletransformers/classification/transformer_models/xlnet_model.py
index 6335b024..8de7754a 100755
--- a/simpletransformers/classification/transformer_models/xlnet_model.py
+++ b/simpletransformers/classification/transformer_models/xlnet_model.py
@@ -60,7 +60,6 @@ def forward(
         head_mask=None,
         inputs_embeds=None,
         labels=None,
-        class_weights=None,
     ):
         transformer_outputs = self.transformer(
             input_ids,
@@ -85,7 +84,11 @@
                 loss_fct = MSELoss()
                 loss = loss_fct(logits.view(-1), labels.view(-1))
             else:
-                loss_fct = CrossEntropyLoss(weight=class_weights)
+                if self.weight is not None:
+                    weight = self.weight.to(labels.device)
+                else:
+                    weight = None
+                loss_fct = CrossEntropyLoss(weight=weight)
                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
 
         outputs = (loss,) + outputs
diff --git a/train.txt b/train.txt
deleted file mode 100644
index 721c1e19..00000000
--- a/train.txt
+++ /dev/null
@@ -1,100 +0,0 @@
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
-Hello world with Simple Transformers!
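
Note on the fix: the trainer previously injected the class-weight tensor into the
inputs dict on every batch (inputs["class_weights"] = self.weight), and each model
built its loss as CrossEntropyLoss(weight=class_weights). With this patch, each
wrapped model instead keeps the weight tensor it received at construction time
(self.weight) and moves it to the label tensor's device only when the loss is
built, so the weights always end up on the same device as the batch. (The
reformatted DistilBERT signature still lists class_weights=None; that keyword is
now unused, since the trainer no longer passes it.) Below is a minimal,
self-contained sketch of the pattern; WeightedClassificationHead and its sizes are
hypothetical stand-ins for illustration, not simpletransformers code.

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss


class WeightedClassificationHead(nn.Module):
    """Toy stand-in for the patched model classes: the class-weight tensor
    is stored at construction time instead of being passed through
    forward() on every batch."""

    def __init__(self, hidden_size, num_labels, weight=None):
        super().__init__()
        self.num_labels = num_labels
        self.weight = weight  # may live on CPU until the first loss is built
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, pooled_output, labels=None):
        logits = self.classifier(pooled_output)
        if labels is None:
            return (logits,)
        if self.num_labels == 1:
            # Regression path, mirroring the MSE branch in the patch.
            loss = MSELoss()(logits.view(-1), labels.view(-1))
        else:
            # The core of the fix: move the stored weights to whatever
            # device the labels are on before constructing the loss.
            if self.weight is not None:
                weight = self.weight.to(labels.device)
            else:
                weight = None
            loss_fct = CrossEntropyLoss(weight=weight)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return (loss, logits)


# Hypothetical usage: a 3-class head with the rare class up-weighted.
head = WeightedClassificationHead(hidden_size=768, num_labels=3,
                                  weight=torch.tensor([0.5, 1.0, 2.0]))
loss, logits = head(torch.randn(4, 768), torch.tensor([0, 2, 1, 2]))

From the user's side nothing changes: class weights are still supplied once, as a
list with one weight per label, when constructing ClassificationModel (for
example, weight=[0.5, 1.0, 2.0] for a three-label task).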