cal_loss.py
import torch
import torch.nn.functional as F
from torch.autograd import Variable  # deprecated in recent PyTorch; kept to preserve the original behavior


def cal_loss(outputs, labels, loss_func):
    """Classification loss; averages the loss over branches when `outputs` is a list of logits."""
    if isinstance(outputs, list):
        loss = 0
        for out in outputs:
            loss += loss_func(out, labels)
        loss = loss / len(outputs)
    else:
        loss = loss_func(outputs, labels)
    return loss


def cal_kl_loss(outputs, outputs2, loss_func):
    """KL-divergence loss between two sets of logits, with `outputs2` used as the target.

    `loss_func` is expected to be a KL-divergence loss such as nn.KLDivLoss, which takes
    log-probabilities as input and probabilities as target.
    """
    if isinstance(outputs, list):
        loss = 0
        for i in range(len(outputs)):
            loss += loss_func(F.log_softmax(outputs[i], dim=1),
                              F.softmax(Variable(outputs2[i]), dim=1))
        loss = loss / len(outputs)
    else:
        loss = loss_func(F.log_softmax(outputs, dim=1),
                         F.softmax(Variable(outputs2), dim=1))
    return loss


def cal_triplet_loss(outputs, outputs2, labels, loss_func, split_num=8):
    """Triplet loss over both feature sets, concatenated along the batch dimension.

    Both inputs share the same labels, so the labels are duplicated to match the
    concatenated batch. `loss_func` must accept (features, labels). `split_num` is
    kept for interface compatibility but is unused here.
    """
    if isinstance(outputs, list):
        loss = 0
        for i in range(len(outputs)):
            out_concat = torch.cat((outputs[i], outputs2[i]), dim=0)
            labels_concat = torch.cat((labels, labels), dim=0)
            loss += loss_func(out_concat, labels_concat)
        loss = loss / len(outputs)
    else:
        out_concat = torch.cat((outputs, outputs2), dim=0)
        labels_concat = torch.cat((labels, labels), dim=0)
        loss = loss_func(out_concat, labels_concat)
    return loss
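
A minimal usage sketch, assuming two-branch outputs and standard PyTorch loss modules; the batch size, class count, and the pytorch-metric-learning triplet loss below are illustrative assumptions, not part of this file (any triplet loss taking (features, labels) would fit):

import torch
import torch.nn as nn
from pytorch_metric_learning import losses  # assumed choice for a (features, labels) triplet loss

# Hypothetical two-branch outputs: 4-class logits / 4-dim features for a batch of 8.
outputs = [torch.randn(8, 4, requires_grad=True) for _ in range(2)]
outputs2 = [torch.randn(8, 4, requires_grad=True) for _ in range(2)]
labels = torch.randint(0, 4, (8,))

ce = cal_loss(outputs, labels, nn.CrossEntropyLoss())
kl = cal_kl_loss(outputs, outputs2, nn.KLDivLoss(reduction='batchmean'))
tri = cal_triplet_loss(outputs, outputs2, labels, losses.TripletMarginLoss())
total = ce + kl + tri
total.backward()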