Skip to content

Commit

Permalink
Fixed issue with class weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Thilina Rajapakse committed Dec 5, 2020
1 parent 4a16e77 commit 9688fcf
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 123 deletions.
3 changes: 0 additions & 3 deletions simpletransformers/classification/classification_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1549,9 +1549,6 @@ def _get_inputs_dict(self, batch):
if self.args.model_type == "layoutlm":
inputs["bbox"] = batch[4]

if self.weight is not None:
inputs["class_weights"] = self.weight

return inputs

def _get_last_metrics(self, metric_values):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):

outputs = self.albert(
Expand All @@ -76,7 +75,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):

outputs = self.bert(
Expand All @@ -77,7 +76,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def __init__(self, config, weight=None):
self.init_weights()

def forward(
self, input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, labels=None, class_weights=None,
self,
input_ids=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):
distilbert_output = self.distilbert(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask)
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
Expand All @@ -60,7 +66,11 @@ def forward(
loss_fct = nn.MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):

discriminator_hidden_states = self.electra(
Expand All @@ -71,7 +70,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

output = (logits,) + discriminator_hidden_states[1:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):
transformer_outputs = self.transformer(
input_ids,
Expand All @@ -79,7 +78,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):

outputs = self.bert(
Expand All @@ -50,7 +49,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):

outputs = self.mmbt(
Expand Down Expand Up @@ -84,7 +83,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):
outputs = self.roberta(
input_ids,
Expand All @@ -78,7 +77,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):
transformer_outputs = self.transformer(
input_ids,
Expand All @@ -78,7 +77,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def forward(
head_mask=None,
inputs_embeds=None,
labels=None,
class_weights=None,
):
transformer_outputs = self.transformer(
input_ids,
Expand All @@ -85,7 +84,11 @@ def forward(
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.weight is not None:
weight = self.weight.to(labels.device)
else:
weight = None
loss_fct = CrossEntropyLoss(weight=weight)
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs

Expand Down
100 changes: 0 additions & 100 deletions train.txt

This file was deleted.

0 comments on commit 9688fcf

Please sign in to comment.