forked from gwinndr/MusicTransformer-Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess_midi.py
125 lines (102 loc) · 3.67 KB
/
preprocess_midi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
from code import interact
import os
import pickle
import json
import third_party.midi_processor.processor as midi_processor
JSON_FILE = "maestro-v2.0.0.json"
# prep_midi
def prep_midi(maestro_root, output_dir, interval = False, logscale=False, octave = False, fusion=False, absolute=False):
"""
----------
Author: Damon Gwinn
----------
Pre-processes the maestro dataset, putting processed midi data (train, eval, test) into the
given output folder
----------
"""
train_dir = os.path.join(output_dir, "train")
os.makedirs(train_dir, exist_ok=True)
val_dir = os.path.join(output_dir, "val")
os.makedirs(val_dir, exist_ok=True)
test_dir = os.path.join(output_dir, "test")
os.makedirs(test_dir, exist_ok=True)
maestro_json_file = os.path.join(maestro_root, JSON_FILE)
if(not os.path.isfile(maestro_json_file)):
print("ERROR: Could not find file:", maestro_json_file)
return False
maestro_json = json.load(open(maestro_json_file, "r"))
print("Found", len(maestro_json), "pieces")
print("Preprocessing...")
total_count = 0
train_count = 0
val_count = 0
test_count = 0
for piece in maestro_json:
mid = os.path.join(maestro_root, piece["midi_filename"])
split_type = piece["split"]
f_name = mid.split("/")[-1] + ".pickle"
if(split_type == "train"):
o_file = os.path.join(train_dir, f_name)
train_count += 1
elif(split_type == "validation"):
o_file = os.path.join(val_dir, f_name)
val_count += 1
elif(split_type == "test"):
o_file = os.path.join(test_dir, f_name)
test_count += 1
else:
print("ERROR: Unrecognized split type:", split_type)
return False
if octave or interval or logscale or fusion or absolute:
prepped = midi_processor.encode_midi_JE(mid, interval=interval, logscale=logscale, octave=octave, fusion=fusion, absolute=absolute)
else:
prepped = midi_processor.encode_midi(mid)
o_stream = open(o_file, "wb")
pickle.dump(prepped, o_stream)
o_stream.close()
total_count += 1
if(total_count % 50 == 0):
print(total_count, "/", len(maestro_json))
print("Num Train:", train_count)
print("Num Val:", val_count)
print("Num Test:", test_count)
return True
# parse_args
def parse_args():
"""
----------
Author: Damon Gwinn
----------
Parses arguments for preprocess_midi using argparse
----------
"""
parser = argparse.ArgumentParser()
parser.add_argument("maestro_root", type=str, help="Root folder for the Maestro dataset")
parser.add_argument("-output_dir", type=str, default="./dataset/e_piano", help="Output folder to put the preprocessed midi into")
return parser.parse_args()
# main
def main():
"""
----------
Author: Damon Gwinn
----------
Entry point. Preprocesses maestro and saved midi to specified output folder.
----------
"""
args = parse_args()
maestro_root = args.maestro_root
output_dir = args.output_dir
#====================================================#
interval = False # interval preprocessing
logscale = True
octave =False
fusion = False
absolute = False
#====================================================#
print("Preprocessing midi files and saving to", output_dir)
prep_midi(maestro_root, output_dir, interval=interval, logscale=logscale, octave = octave, fusion=fusion, absolute=absolute)
print("Done!")
print("")
if __name__ == "__main__":
main()