-
Notifications
You must be signed in to change notification settings - Fork 3
/
test.py
45 lines (34 loc) · 1.32 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
import sys
sys.path.append("/root/Speech2Intent/s2i-baselines")
import torch
import torch.nn as nn
import torch.nn.functional as F
# choose the model
from trainer_whisper import LightningModel
# from trainer_hubert import LightningModel
# from trainer_wav2vec2 import LightningModel
from dataset import S2IMELDataset, collate_fn
from tqdm import tqdm
from sklearn.metrics import f1_score, accuracy_score
# Evaluation script: run a checkpointed model over the test CSV and report
# accuracy and weighted F1 on the predicted intent labels.
dataset = S2IMELDataset(
    csv_path="/root/Speech2Intent/dataset/speech-to-intent/test.csv",
    wav_dir_path="/root/Speech2Intent/dataset/speech-to-intent/",
)

# Change this path to the checkpoint that matches the trainer imported above.
model = LightningModel.load_from_checkpoint(
    "/root/Speech2Intent/s2i-baselines/checkpoints/whisper_asr_small.ckpt"
)
model.to("cuda")
model.eval()

trues = []
preds = []

# Inference only: eval() switches layer behaviour (dropout, batch norm) but
# does not stop autograd from recording a graph, so disable gradient tracking
# explicitly to avoid wasting GPU memory and time on every forward pass.
with torch.no_grad():
    for x, label in tqdm(dataset):
        # Items are fed one at a time; unsqueeze(0) adds the leading batch
        # dimension before the forward pass.
        x_tensor = x.to("cuda").unsqueeze(0)
        logits = model(x_tensor)

        # Softmax is monotonic, so the arg-max of the raw logits is the
        # predicted class; no probabilities need to be materialised.
        # reshape(1, -1) keeps the (1, num_classes) layout without
        # hard-coding the number of classes.
        pred = logits.reshape(1, -1).argmax(dim=1).item()

        trues.append(label)
        # Store a plain int so the metrics receive a flat 1-D label list
        # rather than a list of shape-(1,) arrays.
        preds.append(pred)

print(f"Accuracy Score = {accuracy_score(trues, preds)}\nF1-Score = {f1_score(trues, preds, average='weighted')}")