Skip to content

Commit

Permalink
refactored model.py and localconfig.py
Browse files: browse the repository at this point in the history
  • Loading branch information
am1tyadav committed Aug 7, 2021
1 parent 0a6eecd commit bdf8fa4
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 586 deletions.
2 changes: 1 addition & 1 deletion output_layers.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"f0_shifts": ["permute", "dense_1", "permute_1", "dense_2", "f0_shifts_decoder_conv_0", "f0_shifts_decoder_bn_0", "f0_shifts_decoder_act_0", "add_9", "f0_shifts_decoder_conv_1", "f0_shifts_decoder_bn_1", "f0_shifts_decoder_act_1", "add_10", "f0_shifts_decoder_conv_2", "f0_shifts_decoder_bn_2", "f0_shifts_decoder_act_2", "add_11", "f0_shifts_decoder_conv_3", "f0_shifts_decoder_bn_3", "f0_shifts_decoder_act_3", "add_12", "conv2d_15", "lambda_1", "dense_3", "layer_normalization_1", "dense_4", "add_13", "layer_normalization_2", "dense_5", "add_14", "layer_normalization_3", "dense_6"], "h_freq_shifts": ["permute_2", "dense_7", "permute_3", "dense_8", "h_freq_shifts_decoder_conv_0", "h_freq_shifts_decoder_bn_0", "h_freq_shifts_decoder_act_0", "add_15", "h_freq_shifts_decoder_conv_1", "h_freq_shifts_decoder_bn_1", "h_freq_shifts_decoder_act_1", "add_16", "h_freq_shifts_decoder_conv_2", "h_freq_shifts_decoder_bn_2", "h_freq_shifts_decoder_act_2", "add_17", "h_freq_shifts_decoder_conv_3", "h_freq_shifts_decoder_bn_3", "h_freq_shifts_decoder_act_3", "add_18", "conv2d_16", "lambda_2", "dense_9", "layer_normalization_4", "dense_10", "add_19", "layer_normalization_5", "dense_11", "add_20", "layer_normalization_6", "dense_12"], "mag_env": ["permute_4", "dense_13", "permute_5", "dense_14", "mag_env_decoder_conv_0", "mag_env_decoder_bn_0", "mag_env_decoder_act_0", "add_21", "mag_env_decoder_conv_1", "mag_env_decoder_bn_1", "mag_env_decoder_act_1", "add_22", "mag_env_decoder_conv_2", "mag_env_decoder_bn_2", "mag_env_decoder_act_2", "add_23", "mag_env_decoder_conv_3", "mag_env_decoder_bn_3", "mag_env_decoder_act_3", "add_24", "conv2d_17", "lambda_3", "dense_15", "layer_normalization_7", "dense_16", "add_25", "layer_normalization_8", "dense_17", "add_26", "layer_normalization_9", "dense_18"], "h_mag_dist": ["permute_6", "dense_19", "permute_7", "dense_20", "h_mag_dist_decoder_conv_0", "h_mag_dist_decoder_bn_0", "h_mag_dist_decoder_act_0", "add_27", "h_mag_dist_decoder_conv_1", 
"h_mag_dist_decoder_bn_1", "h_mag_dist_decoder_act_1", "add_28", "h_mag_dist_decoder_conv_2", "h_mag_dist_decoder_bn_2", "h_mag_dist_decoder_act_2", "add_29", "h_mag_dist_decoder_conv_3", "h_mag_dist_decoder_bn_3", "h_mag_dist_decoder_act_3", "add_30", "conv2d_18", "lambda_4", "dense_21", "layer_normalization_10", "dense_22", "add_31", "layer_normalization_11", "dense_23", "add_32", "layer_normalization_12", "dense_24"]}
{"f0_shifts": ["permute", "dense_14", "permute_1", "dense_15", "f0_shifts_decoder_conv_0", "f0_shifts_decoder_bn_0", "f0_shifts_decoder_act_0", "add_13", "f0_shifts_decoder_conv_1", "f0_shifts_decoder_bn_1", "f0_shifts_decoder_act_1", "add_14", "f0_shifts_decoder_conv_2", "f0_shifts_decoder_bn_2", "f0_shifts_decoder_act_2", "add_15", "f0_shifts_decoder_conv_3", "f0_shifts_decoder_bn_3", "f0_shifts_decoder_act_3", "add_16", "conv2d_9", "lambda_1", "dense_16", "layer_normalization_7", "dense_17", "add_17", "layer_normalization_8", "dense_18", "add_18", "layer_normalization_9", "dense_19"], "h_freq_shifts": ["permute_2", "dense_20", "permute_3", "dense_21", "h_freq_shifts_decoder_conv_0", "h_freq_shifts_decoder_bn_0", "h_freq_shifts_decoder_act_0", "add_19", "h_freq_shifts_decoder_conv_1", "h_freq_shifts_decoder_bn_1", "h_freq_shifts_decoder_act_1", "add_20", "h_freq_shifts_decoder_conv_2", "h_freq_shifts_decoder_bn_2", "h_freq_shifts_decoder_act_2", "add_21", "h_freq_shifts_decoder_conv_3", "h_freq_shifts_decoder_bn_3", "h_freq_shifts_decoder_act_3", "add_22", "conv2d_10", "lambda_2", "dense_22", "layer_normalization_10", "dense_23", "add_23", "layer_normalization_11", "dense_24", "add_24", "layer_normalization_12", "dense_25"], "mag_env": ["permute_4", "dense_26", "permute_5", "dense_27", "mag_env_decoder_conv_0", "mag_env_decoder_bn_0", "mag_env_decoder_act_0", "add_25", "mag_env_decoder_conv_1", "mag_env_decoder_bn_1", "mag_env_decoder_act_1", "add_26", "mag_env_decoder_conv_2", "mag_env_decoder_bn_2", "mag_env_decoder_act_2", "add_27", "mag_env_decoder_conv_3", "mag_env_decoder_bn_3", "mag_env_decoder_act_3", "add_28", "conv2d_11", "lambda_3", "dense_28", "layer_normalization_13", "dense_29", "add_29", "layer_normalization_14", "dense_30", "add_30", "layer_normalization_15", "dense_31"], "h_mag_dist": ["permute_6", "dense_32", "permute_7", "dense_33", "h_mag_dist_decoder_conv_0", "h_mag_dist_decoder_bn_0", "h_mag_dist_decoder_act_0", "add_31", 
"h_mag_dist_decoder_conv_1", "h_mag_dist_decoder_bn_1", "h_mag_dist_decoder_act_1", "add_32", "h_mag_dist_decoder_conv_2", "h_mag_dist_decoder_bn_2", "h_mag_dist_decoder_act_2", "add_33", "h_mag_dist_decoder_conv_3", "h_mag_dist_decoder_bn_3", "h_mag_dist_decoder_act_3", "add_34", "conv2d_12", "lambda_4", "dense_34", "layer_normalization_16", "dense_35", "add_35", "layer_normalization_17", "dense_36", "add_36", "layer_normalization_18", "dense_37"]}
61 changes: 33 additions & 28 deletions sound_generator.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,52 @@
import os
import tensorflow as tf
from tensorflow import TensorShape
import tsms
import numpy as np
from urllib.request import urlretrieve
from loguru import logger
from pprint import pprint
import warnings
from typing import Dict, Any
from tcae import model, localconfig, train
from tcae import model, localconfig
from tcae.dataset import heuristic_names


warnings.simplefilter("ignore")


# qualities = {
# "bright": {
# "high": [],
# "high_mid": []
# },
# "dark": {
#
# }
# }
#
# Todo: Sound Generation prediction / output transform needs to be fixed

class HeuristicMeasures:
def __init__(self):
self.names = heuristic_names

for name in heuristic_names:
vars(self)[name] = 0.
def get_zero_batch(conf: localconfig.LocalConfig):
mask = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics])
note_number = TensorShape([conf.batch_size, conf.num_pitches])
velocity = TensorShape([conf.batch_size, conf.num_velocities])
measures = TensorShape([conf.batch_size, conf.num_measures])
f0_shifts = TensorShape([conf.batch_size, conf.harmonic_frame_steps, 1])
mag_env = TensorShape([conf.batch_size, conf.harmonic_frame_steps, 1])
h_freq_shifts = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics])
h_mag_dist = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics])
h_phase_diff = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics])

def __call__(self, **kwargs):
outputs = []
_shapes = {}

for k, v in kwargs.items():
if k in vars(self):
vars(self)[k] = v
_shapes.update({
"mask": tf.zeros(mask),
"f0_shifts": tf.zeros(f0_shifts),
"mag_env": tf.zeros(mag_env),
"h_freq_shifts": tf.zeros(h_freq_shifts),
"h_mag_dist": tf.zeros(h_mag_dist),
"h_phase_diff": tf.zeros(h_phase_diff)
})

for name in heuristic_names:
outputs.append(vars(self)[name])
return outputs
if conf.use_note_number:
_shapes.update({"note_number": tf.zeros(note_number)})
if conf.use_velocity:
_shapes.update({"velocity": tf.zeros(velocity)})
if conf.use_heuristics:
_shapes.update({"measures": tf.zeros(measures)})
return _shapes


class SoundGenerator:
Expand Down Expand Up @@ -119,8 +124,8 @@ def load_config(self) -> None:

def load_model(self) -> None:
assert os.path.isfile(self._checkpoint_path), f"No checkpoint at {self._checkpoint_path}"
self._model = model.MtVae(self._conf)
_ = self._model(train.get_zero_batch(self._conf))
self._model = model.TCAEModel(self._conf)
_ = self._model(get_zero_batch(self._conf))
self._model.load_weights(self._checkpoint_path)
logger.info("Model loaded")

Expand All @@ -131,7 +136,7 @@ def _get_mask(self, note_number: int) -> np.ndarray:
)
harmonics = tsms.core.get_number_harmonics(f0, self._conf.sample_rate)
harmonics = np.squeeze(harmonics)
mask = np.zeros((1, self._conf.harmonic_frame_steps, 110))
mask = np.zeros((1, self._conf.harmonic_frame_steps, self._conf.max_num_harmonics))
mask[:, :, :harmonics] = np.ones((1, self._conf.harmonic_frame_steps, harmonics))
return mask

Expand Down Expand Up @@ -205,7 +210,7 @@ def get_prediction(self, data: Dict) -> Any:
logger.info("Getting prediction")
prediction = self._model.decoder(decoder_inputs)
logger.info("Transforming prediction")
transformed = self._conf.data_handler.prediction_transform(prediction)
transformed = self._conf.data_handler.output_transform(prediction)
logger.info("De-normalizing prediction")
freq, mag, phase = self._conf.data_handler.denormalize(
transformed, decoder_inputs["mask"],
Expand Down
2 changes: 1 addition & 1 deletion tcae/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import localconfig, dataset, model, train
from tcae import localconfig, dataset, model, train
30 changes: 15 additions & 15 deletions tcae/dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import os
import tensorflow as tf
from .localconfig import LocalConfig
from . import heuristics
from tcae.localconfig import LocalConfig
from tcae import heuristics


heuristic_names = [
"inharmonicity", "even_odd", "sparse_rich", "attack_rms",
"decay_rms", "attack_time", "decay_time", "bass", "mid",
"high_mid", "high"
]


class CompleteTFRecordProvider:
Expand Down Expand Up @@ -69,13 +76,6 @@ def features_dict(self):
}


heuristic_names = [
"inharmonicity", "even_odd", "sparse_rich", "attack_rms",
"decay_rms", "attack_time", "decay_time", "bass", "mid",
"high_mid", "high"
]


def get_measures(h_freq, h_mag, conf: LocalConfig):
inharmonic = heuristics.core.inharmonicity_measure(
h_freq, h_mag, None, None, None, None)
Expand Down Expand Up @@ -227,12 +227,12 @@ def get_dataset(conf: LocalConfig):
test_dataset = create_dataset(
test_path, map_func=map_features, batch_size=conf.batch_size)

train_dataset = train_dataset.apply(
tf.data.experimental.assert_cardinality(10542))
valid_dataset = valid_dataset.apply(
tf.data.experimental.assert_cardinality(2906))
test_dataset = test_dataset.apply(
tf.data.experimental.assert_cardinality(1588))
# train_dataset = train_dataset.apply(
# tf.data.experimental.assert_cardinality(num_train_steps))
# valid_dataset = valid_dataset.apply(
# tf.data.experimental.assert_cardinality(num_valid_steps))
# test_dataset = test_dataset.apply(
# tf.data.experimental.assert_cardinality(num_test_steps))

if conf.dataset_modifier is not None:
train_dataset, valid_dataset, test_dataset = conf.dataset_modifier(
Expand Down
43 changes: 8 additions & 35 deletions tcae/localconfig.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import json
from typing import Dict
from .data_handler import DataHandler
from tcae.data_handler import DataHandler


class LocalConfig:
Expand All @@ -18,53 +18,34 @@ class LocalConfig:
use_encoder = True
use_phase = False
latent_dim = 16
use_max_pool = True
strides = 2
use_lstm_in_encoder = True
# strides = 2
use_note_number = True
use_velocity = True
use_instrument_id = False
use_heuristics = True
use_one_hot_conditioning = True
hidden_dim = 256
default_k = 3
deep_decoder = False
add_z_to_decoder_blocks = True
check_decoder_hidden_dim = True
print_model_summary = False
skip_channels = 32
lstm_dim = 256
lstm_dropout = 0.4
harmonic_frame_steps = 1001
max_harmonics = 110
frame_size = 64
batch_size = 2
num_instruments = 74
num_measures = 7 + 4
starting_midi_pitch = 40
num_pitches = 49
num_velocities = 5
max_num_harmonics = 98
max_num_harmonics = 110
frame_size = 64
harmonic_frame_steps = 1024
row_dim = 1024
col_dim = 128
padding = "same"
epochs = 500
num_train_steps = None
num_valid_steps = None
early_stopping = 7
learning_rate = 2e-4
lr_plateau = 4
lr_factor = 0.5
gradient_norm = 5.
csv_log_file = "logs.csv"
final_conv_shape = (64, 8, 192) # ToDo: to be calculated dynamically
final_conv_units = 64 * 8 * 192 # ToDo: to be calculated dynamically
best_loss = 1e6
sample_rate = 16000
log_steps = True
step_log_interval = 100
is_variational = True
using_mt = True
using_categorical = False
use_embeddings = False
scalar_embedding = False
Expand All @@ -78,6 +59,7 @@ class LocalConfig:
"h_mag_dist": {"shape": (row_dim, 64)},
"h_phase_diff": {"shape": (row_dim, 64)},
}

mt_outputs = {
"f0_shifts": {"shape": (row_dim, 64, 16)},
"h_freq_shifts": {"shape": (row_dim, 110, 16)},
Expand All @@ -86,22 +68,13 @@ class LocalConfig:
"h_phase_diff": {"shape": (row_dim, 110, 16)},
}

use_kl_anneal = False
kl_weight = 1.
kl_weight_max = 1.
kl_anneal_factor = 0.05
kl_anneal_start = 20
reconstruction_weight = 1.
st_var = (2.0 ** (1.0 / 12.0) - 1.0)
db_limit = -120
encoder_type = "2d" # or "1d"
decoder_type = "cnn"
freq_bands = {
"bass": [60, 270],
"mid": [270, 2000],
"high_mid": [2000, 6000],
"high": [6000, 20000]
}

data_handler = None
data_handler_properties = []
data_handler_type = "none"
Expand Down Expand Up @@ -183,7 +156,7 @@ def save_config(self):
to_save[p] = eval(f"self.data_handler.{p}")

if "data_handler" in to_save:
# to_save is not a deep copy so, we pop item
# to_save will not save objects, we pop item
to_save.pop("data_handler")

with open(target_path, "w") as f:
Expand Down
Loading

0 comments on commit bdf8fa4

Please sign in to comment.