diff --git a/mmseg/models/losses/cross_entropy_loss.py b/mmseg/models/losses/cross_entropy_loss.py
index 65553472c0..988fb789c1 100644
--- a/mmseg/models/losses/cross_entropy_loss.py
+++ b/mmseg/models/losses/cross_entropy_loss.py
@@ -63,8 +63,9 @@ def cross_entropy(pred,
         else:
             # the average factor should take the class weights into account
-            label_weights = torch.tensor([class_weight[cls] for cls in label],
-                                         device=class_weight.device)
+            label_weights = torch.stack([class_weight[cls] for cls in label
+                                         ]).to(device=class_weight.device)
+
             if avg_non_ignore:
                 label_weights[label == ignore_index] = 0
             avg_factor = label_weights.sum()

diff --git a/tests/test_models/test_losses/test_cross_entropy_loss.py b/tests/test_models/test_losses/test_cross_entropy_loss.py
new file mode 100644
index 0000000000..8c6b86d014
--- /dev/null
+++ b/tests/test_models/test_losses/test_cross_entropy_loss.py
@@ -0,0 +1,28 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn.functional as F
+
+from mmseg.models.losses import CrossEntropyLoss, weight_reduce_loss
+
+
+def test_cross_entropy_loss_class_weights():
+    loss_class = CrossEntropyLoss
+    pred = torch.rand((1, 10, 4, 4))
+    target = torch.randint(0, 10, (1, 4, 4))
+    class_weight = torch.ones(10)
+    avg_factor = target.numel()
+
+    cross_entropy_loss = F.cross_entropy(
+        pred, target, weight=class_weight, reduction='none', ignore_index=-100)
+
+    expected_loss = weight_reduce_loss(
+        cross_entropy_loss,
+        weight=None,
+        reduction='mean',
+        avg_factor=avg_factor)
+
+    # Test loss forward
+    loss = loss_class(class_weight=class_weight.tolist())(pred, target)
+
+    assert isinstance(loss, torch.Tensor)
+    assert expected_loss == loss