-
Notifications
You must be signed in to change notification settings - Fork 453
/
utils.py
142 lines (121 loc) · 4.33 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
#
import miditoolkit
import math
pos_resolution = 4  # positions per beat (quarter note)
bar_max = 256
velocity_quant = 4  # MIDI velocities are quantized into bins of this width (enc_vel / dec_vel)
tempo_quant = 12  # tempo tokens are log-spaced: one token step = tempo ratio 2 ** (1 / 12)
min_tempo = 16  # enc_tpo clamps tempos to [min_tempo, max_tempo]; min_tempo encodes to 0
max_tempo = 256
duration_max = 4  # longest encodable duration, in beats (enc_dur caps at duration_max * pos_resolution positions)
max_ts_denominator = 6  # denominators 2 ** 0 ... 2 ** 6, i.e. x/1 x/2 x/4 ... x/64
max_notes_per_bar = 2  # numerators run 1 ... 2 * denominator, e.g. 1/64 ... 128/64
beat_note_factor = 4  # a whole note spans 4 beats (quarter notes); used to turn x/denominator into beats
deduplicate = True
filter_symbolic = False
filter_symbolic_ppl = 16
trunc_pos = 2 ** 16  # 65536 positions = 16384 beats at pos_resolution = 4 (4096 bars of 4/4)
sample_len_max = 200  # window length max
sample_overlap_rate = 4
ts_filter = False
min_pitch = 48
max_pitch = 72
# Chord-kind suffix -> offsets above the chord root. The values match the
# standard chord spellings when read as semitones ('' is the plain major triad).
_CHORD_KIND_PITCHES = dict([
    ('', [0, 4, 7]),          # major triad
    ('m', [0, 3, 7]),         # minor triad
    ('+', [0, 4, 8]),         # augmented triad
    ('dim', [0, 3, 6]),       # diminished triad
    ('7', [0, 4, 7, 10]),     # dominant seventh
    ('maj7', [0, 4, 7, 11]),  # major seventh
    ('m7', [0, 3, 7, 10]),    # minor seventh
    ('m7b5', [0, 3, 6, 10]),  # half-diminished seventh
])
# Vocabulary of supported time signatures.
# Denominators are 2 ** 0 ... 2 ** max_ts_denominator (1 ~ 64); for a
# denominator d the numerators run 1 ... d * max_notes_per_bar.
# ts_list maps token -> (numerator, denominator); ts_dict is its inverse.
ts_list = [
    (numerator, 2 ** exponent)
    for exponent in range(max_ts_denominator + 1)
    for numerator in range(1, 2 ** exponent * max_notes_per_bar + 1)
]
ts_dict = {signature: token for token, signature in enumerate(ts_list)}
def enc_ts(x):
    """Encode a time signature as its integer token.

    Args:
        x: A ``(numerator, denominator)`` tuple, e.g. ``(4, 4)``.

    Returns:
        The token assigned to *x* in ``ts_dict``.

    Raises:
        AssertionError: If *x* is not a supported time signature.
    """
    token = ts_dict.get(x)  # tokens are ints, never None
    if token is None:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type is kept for existing callers.
        raise AssertionError('unsupported time signature: ' + str(x))
    return token
def dec_ts(x):
    """Decode a time-signature token into its ``(numerator, denominator)`` tuple.

    Inverse of ``enc_ts``: a plain lookup into ``ts_list``.
    """
    time_signature = ts_list[x]
    return time_signature
def enc_dur(x):
    """Clamp a note duration (in positions) to the longest encodable value.

    The cap is ``duration_max`` beats, i.e. ``duration_max * pos_resolution``
    positions; shorter durations are returned unchanged.
    """
    limit = duration_max * pos_resolution
    # Equivalent to ``min(x, limit)``, including which operand is returned on ties.
    return limit if limit < x else x
def dec_dur(x):
    """Decode a duration token.

    Durations are stored directly in positions, so decoding is the identity.
    """
    duration_in_positions = x
    return duration_in_positions
def enc_vel(x):
    """Quantize a MIDI velocity into a bin index of width ``velocity_quant``."""
    bin_index, _remainder = divmod(x, velocity_quant)
    return bin_index
def dec_vel(x):
    """Map a velocity bin index back to the (integer) centre of that bin."""
    half_bin = velocity_quant // 2
    return x * velocity_quant + half_bin
def enc_tpo(x):
    """Encode a tempo as a log-scale integer token.

    The tempo is clamped to ``[min_tempo, max_tempo]``. The token is
    ``round(log2(x / min_tempo) * tempo_quant)``, so ``min_tempo`` maps to 0
    and each token step is a tempo ratio of ``2 ** (1 / tempo_quant)``.
    """
    clamped = min(max(x, min_tempo), max_tempo)
    return round(math.log2(clamped / min_tempo) * tempo_quant)
def dec_tpo(x):
    """Decode a tempo token from ``enc_tpo`` (the inverse, up to its rounding)."""
    exponent = x / tempo_quant
    return 2 ** exponent * min_tempo
def encoding_to_midi(encoding):
    """Rebuild a ``miditoolkit`` MIDI file from a sequence of note tokens.

    Each element of *encoding* is indexed positionally as
    ``(bar, pos, program, pitch, duration, velocity, time_sig, tempo)``:

    * ``bar`` / ``pos`` -- bar index and position inside that bar
      (``pos_resolution`` positions per beat);
    * ``program`` -- 0-127 for melodic instruments, 128 for the drum track
      (drum pitches are stored offset by +128);
    * ``duration`` / ``velocity`` / ``time_sig`` / ``tempo`` -- tokens decoded
      with ``dec_dur`` / ``dec_vel`` / ``dec_ts`` / ``dec_tpo``.

    Every bar takes the most frequent time-signature token among its notes.
    Empty bars inherit the previous bar's signature; the first bar defaults
    to 4/4. Every position takes the rounded mean tempo token of the notes
    starting there. Empty positions inherit the previous tempo; the first
    defaults to ``enc_tpo(80.0)``. Time-signature and tempo events are
    emitted only where the value changes.

    Args:
        encoding: A non-empty iterable of 8-element tokens.

    Returns:
        A ``miditoolkit.midi.parser.MidiFile`` that keeps only instruments
        that received at least one note.

    Raises:
        ValueError: If *encoding* is empty.
    """
    # The tokens are traversed several times below, so materialize them once.
    encoding = list(encoding)
    if not encoding:
        # Without this guard max() fails with an opaque "empty sequence" error.
        raise ValueError('cannot convert an empty encoding to MIDI')

    # --- Time signature of each bar --------------------------------------
    num_bars = max(token[0] for token in encoding) + 1
    bar_to_timesig = [list() for _ in range(num_bars)]
    for token in encoding:
        bar_to_timesig[token[0]].append(token[6])
    # Majority vote per bar; ties resolve by set iteration order.
    bar_to_timesig = [max(set(votes), key=votes.count) if len(votes) > 0 else None
                      for votes in bar_to_timesig]
    for bar in range(num_bars):
        if bar_to_timesig[bar] is None:
            # Empty bar: inherit the previous bar's signature (4/4 for bar 0).
            bar_to_timesig[bar] = enc_ts((4, 4)) if bar == 0 else bar_to_timesig[bar - 1]

    # --- Absolute start position of each bar -----------------------------
    bar_to_pos = [None] * num_bars
    cur_pos = 0
    for bar in range(num_bars):
        bar_to_pos[bar] = cur_pos
        numerator, denominator = dec_ts(bar_to_timesig[bar])
        # Bar length in positions. The integer division truncates signatures
        # finer than pos_resolution can represent (e.g. 1/32 -> 0 positions).
        cur_pos += numerator * beat_note_factor * pos_resolution // denominator

    # --- Tempo at each absolute position ---------------------------------
    # The trailing "+ 1" slot handles a zero-length final bar: then the index
    # bar_to_pos[bar] + pos can reach cur_pos + max(pos), which previously
    # raised IndexError. Otherwise the extra slot is always empty and just
    # repeats the previous tempo, so it never emits an extra tempo change.
    num_pos = cur_pos + max(token[1] for token in encoding) + 1
    pos_to_tempo = [list() for _ in range(num_pos)]
    for token in encoding:
        pos_to_tempo[bar_to_pos[token[0]] + token[1]].append(token[7])
    pos_to_tempo = [round(sum(tempos) / len(tempos)) if len(tempos) > 0 else None
                    for tempos in pos_to_tempo]
    for pos in range(num_pos):
        if pos_to_tempo[pos] is None:
            pos_to_tempo[pos] = enc_tpo(80.0) if pos == 0 else pos_to_tempo[pos - 1]

    midi_obj = miditoolkit.midi.parser.MidiFile()

    def get_tick(bar, pos):
        # Absolute MIDI tick of position `pos` counted from the start of `bar`.
        return (bar_to_pos[bar] + pos) * midi_obj.ticks_per_beat // pos_resolution

    # --- Notes -----------------------------------------------------------
    # One instrument slot per program 0-127, plus slot 128 for drums.
    midi_obj.instruments = [
        miditoolkit.containers.Instrument(
            program=(0 if prog == 128 else prog), is_drum=(prog == 128), name=str(prog))
        for prog in range(128 + 1)]
    for token in encoding:
        start = get_tick(token[0], token[1])
        program = token[2]
        pitch = (token[3] - 128 if program == 128 else token[3])
        duration = get_tick(0, dec_dur(token[4]))
        end = start + duration
        velocity = dec_vel(token[5])
        midi_obj.instruments[program].notes.append(miditoolkit.containers.Note(
            start=start, end=end, pitch=pitch, velocity=velocity))
    midi_obj.instruments = [inst for inst in midi_obj.instruments if len(inst.notes) > 0]

    # --- Time-signature events (only where the signature changes) --------
    cur_ts = None
    for bar, ts_token in enumerate(bar_to_timesig):
        if ts_token != cur_ts:
            numerator, denominator = dec_ts(ts_token)
            midi_obj.time_signature_changes.append(miditoolkit.containers.TimeSignature(
                numerator=numerator, denominator=denominator, time=get_tick(bar, 0)))
            cur_ts = ts_token

    # --- Tempo events (only where the tempo changes) ---------------------
    cur_tp = None
    for pos, tp_token in enumerate(pos_to_tempo):
        if tp_token != cur_tp:
            tempo = dec_tpo(tp_token)
            midi_obj.tempo_changes.append(
                miditoolkit.containers.TempoChange(tempo=tempo, time=get_tick(0, pos)))
            cur_tp = tp_token
    return midi_obj