-
Notifications
You must be signed in to change notification settings - Fork 0
/
outputs.py
102 lines (85 loc) · 3.82 KB
/
outputs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
"""
from dataclasses import dataclass
from typing import Optional, Sequence
import numpy as np
from torch_ecg.components.outputs import ClassificationOutput
from cfg import BaseCfg
__all__ = ["CINC2023Outputs"]


# Reverse lookups: mapped integer index -> original label.
cpc_map_inv = dict(zip(BaseCfg.cpc_map.values(), BaseCfg.cpc_map.keys()))
outcome_map_inv = dict(zip(BaseCfg.outcome_map.values(), BaseCfg.outcome_map.keys()))

# Mapped CPC index (int) -> mapped outcome index (int), obtained by going
# through the label-level ``BaseCfg.cpc2outcome_map``.
# Example value noted by the original author: {0: 0, 1: 0, 2: 1, 3: 1, 4: 1}
cpc2outcome_map = {
    cpc_idx: BaseCfg.outcome_map[BaseCfg.cpc2outcome_map[cpc_label]]
    for cpc_idx, cpc_label in cpc_map_inv.items()
}

# Mapped outcome index (int) -> list of mapped CPC indices that lead to it.
# Example value noted by the original author: {0: [0, 1], 1: [2, 3, 4]}
cpc2outcome_map_inv = {
    outcome_idx: [
        cpc_idx
        for cpc_idx, mapped_outcome in cpc2outcome_map.items()
        if outcome_idx == mapped_outcome
    ]
    for outcome_idx in set(cpc2outcome_map.values())
}
@dataclass
class CINC2023Outputs:
    """Output class for CinC2023.

    Attributes
    ----------
    cpc_output : ClassificationOutput
        Output container for CPC, containing the predicted classes, probabilities, etc.
    cpc_loss : Sequence[float]
        Loss for CPC.
    cpc_value : Sequence[float]
        CPC value (float), i.e., the predicted value of the CPC.
    outcome_output : ClassificationOutput
        Output container for outcome, containing the predicted classes, probabilities, etc.
    outcome_loss : Sequence[float]
        Loss for outcome.
    outcome : Sequence[str]
        Outcome, i.e., the predicted class names (str) of the outcome.

    Raises
    ------
    AssertionError
        If neither `cpc_output` nor `outcome_output` is provided,
        or if a provided `outcome` / `cpc_value` is not consistent with
        the corresponding output container.

    .. note::

        - If `outcome_output` is not provided, it will be inferred from `cpc_output`.
        - `outcome` will be inferred from `outcome_output` if `outcome` is not provided.
          Otherwise, consistency check will be performed between `outcome` and `outcome_output`.
        - `cpc_value` will be inferred from `cpc_output` if `cpc_value` is not provided.
          Otherwise, consistency check will be performed between `cpc_value` and `cpc_output`.

    """

    cpc_output: Optional[ClassificationOutput] = None
    cpc_loss: Optional[Sequence[float]] = None
    cpc_value: Optional[Sequence[float]] = None
    outcome_output: Optional[ClassificationOutput] = None
    outcome_loss: Optional[Sequence[float]] = None
    outcome: Optional[Sequence[str]] = None

    def __post_init__(self):
        """Fill in the derived fields and validate the user-provided ones."""
        # Explicit raises (instead of ``assert`` statements) keep the validation
        # active under ``python -O``; the exception type and messages are unchanged.
        if self.cpc_output is None and self.outcome_output is None:
            raise AssertionError("At least one output should be provided")

        if self.outcome_output is None:
            # ``cpc_output`` is necessarily available here, derive the outcome from it
            prob = np.zeros((len(self.cpc_output.pred), len(BaseCfg.outcome)))
            # merge the probabilities of the same outcome via max
            for k, v in cpc2outcome_map_inv.items():
                prob[:, k] = self.cpc_output.prob[:, v].max(axis=1)
            # apply the softmax so that each row sums to 1 again
            exp_prob = np.exp(prob)
            prob = exp_prob / exp_prob.sum(axis=1, keepdims=True)
            self.outcome_output = ClassificationOutput(
                classes=BaseCfg.outcome,
                pred=np.array([cpc2outcome_map[p] for p in self.cpc_output.pred]),
                prob=prob,
            )

        # ``outcome_output`` is always set at this point (provided or inferred above)
        outcome = [outcome_map_inv[p] for p in self.outcome_output.pred]
        if self.outcome is None:
            self.outcome = outcome
        elif len(self.outcome) != len(outcome) or not all(
            o1 == o2 for o1, o2 in zip(self.outcome, outcome)
        ):
            raise AssertionError("the provided outcome is not consistent with the outcome_output")

        if self.cpc_output is not None:
            # map the predicted (mapped) CPC indices back to their original labels, as floats
            cpc_value = [float(cpc_map_inv[p]) for p in self.cpc_output.pred]
            if self.cpc_value is None:
                self.cpc_value = cpc_value
            elif len(self.cpc_value) != len(cpc_value) or not all(
                v1 == v2 for v1, v2 in zip(self.cpc_value, cpc_value)
            ):
                raise AssertionError("the provided cpc_value is not consistent with the cpc_output")