-
Notifications
You must be signed in to change notification settings - Fork 2
/
main.py
34 lines (21 loc) · 839 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from src.models.model import SwarModel
from src.file_format.fileformat import BitPacker, BitUnpacker
import torch
import io
import torchaudio

# Script: encode an audio file with a pretrained SwarModel checkpoint and
# store the resulting tokens, bit-packed, in a custom ".nac" file.
# The "<...>" strings below are placeholders that must be replaced before use.

# Build the device once and reuse it for both the model and the input tensor.
device = torch.device('cuda')

model = SwarModel.build_model(24).to(device)
# Placeholder quoted so the script parses; substitute the checkpoint file name.
path = "./models/" + "<model_name>"
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])

wf, sr = torchaudio.load('<path-to-audio-file>')
model.set_target_bandwidth(24)  # 24 means the target bandwidth is 24 kbps

# Inference only: gradients are not needed, so skip autograd bookkeeping.
with torch.no_grad():
    # The waveform is reshaped to (1, 1, num_samples) before encoding; only the
    # first element of the returned value is used.
    output = model._encode_frame(wf.view(1, 1, wf.shape[-1]).to(device))[0]
output = torch.flatten(output).tolist()

fo = io.BytesIO()
# NOTE(review): 10 is presumably the number of bits per token -- confirm
# against BitPacker's signature.
packer = BitPacker(10, fo)
for token in output:
    packer.push(token)
packer.flush()
# Rewind the in-memory buffer. The original called seek() with no offset,
# which raises TypeError because the offset argument is required.
fo.seek(0)

# Save the compressed audio in the custom file format.
with open('<file-name>.nac', 'wb') as outputfile:
    outputfile.write(fo.getbuffer())