diff --git a/output_layers.json b/output_layers.json index 95e38f8..740787d 100644 --- a/output_layers.json +++ b/output_layers.json @@ -1 +1 @@ -{"f0_shifts": ["permute", "dense_1", "permute_1", "dense_2", "f0_shifts_decoder_conv_0", "f0_shifts_decoder_bn_0", "f0_shifts_decoder_act_0", "add_9", "f0_shifts_decoder_conv_1", "f0_shifts_decoder_bn_1", "f0_shifts_decoder_act_1", "add_10", "f0_shifts_decoder_conv_2", "f0_shifts_decoder_bn_2", "f0_shifts_decoder_act_2", "add_11", "f0_shifts_decoder_conv_3", "f0_shifts_decoder_bn_3", "f0_shifts_decoder_act_3", "add_12", "conv2d_15", "lambda_1", "dense_3", "layer_normalization_1", "dense_4", "add_13", "layer_normalization_2", "dense_5", "add_14", "layer_normalization_3", "dense_6"], "h_freq_shifts": ["permute_2", "dense_7", "permute_3", "dense_8", "h_freq_shifts_decoder_conv_0", "h_freq_shifts_decoder_bn_0", "h_freq_shifts_decoder_act_0", "add_15", "h_freq_shifts_decoder_conv_1", "h_freq_shifts_decoder_bn_1", "h_freq_shifts_decoder_act_1", "add_16", "h_freq_shifts_decoder_conv_2", "h_freq_shifts_decoder_bn_2", "h_freq_shifts_decoder_act_2", "add_17", "h_freq_shifts_decoder_conv_3", "h_freq_shifts_decoder_bn_3", "h_freq_shifts_decoder_act_3", "add_18", "conv2d_16", "lambda_2", "dense_9", "layer_normalization_4", "dense_10", "add_19", "layer_normalization_5", "dense_11", "add_20", "layer_normalization_6", "dense_12"], "mag_env": ["permute_4", "dense_13", "permute_5", "dense_14", "mag_env_decoder_conv_0", "mag_env_decoder_bn_0", "mag_env_decoder_act_0", "add_21", "mag_env_decoder_conv_1", "mag_env_decoder_bn_1", "mag_env_decoder_act_1", "add_22", "mag_env_decoder_conv_2", "mag_env_decoder_bn_2", "mag_env_decoder_act_2", "add_23", "mag_env_decoder_conv_3", "mag_env_decoder_bn_3", "mag_env_decoder_act_3", "add_24", "conv2d_17", "lambda_3", "dense_15", "layer_normalization_7", "dense_16", "add_25", "layer_normalization_8", "dense_17", "add_26", "layer_normalization_9", "dense_18"], "h_mag_dist": ["permute_6", "dense_19", "permute_7", "dense_20", "h_mag_dist_decoder_conv_0", "h_mag_dist_decoder_bn_0", "h_mag_dist_decoder_act_0", "add_27", "h_mag_dist_decoder_conv_1", "h_mag_dist_decoder_bn_1", "h_mag_dist_decoder_act_1", "add_28", "h_mag_dist_decoder_conv_2", "h_mag_dist_decoder_bn_2", "h_mag_dist_decoder_act_2", "add_29", "h_mag_dist_decoder_conv_3", "h_mag_dist_decoder_bn_3", "h_mag_dist_decoder_act_3", "add_30", "conv2d_18", "lambda_4", "dense_21", "layer_normalization_10", "dense_22", "add_31", "layer_normalization_11", "dense_23", "add_32", "layer_normalization_12", "dense_24"]} \ No newline at end of file +{"f0_shifts": ["permute", "dense_14", "permute_1", "dense_15", "f0_shifts_decoder_conv_0", "f0_shifts_decoder_bn_0", "f0_shifts_decoder_act_0", "add_13", "f0_shifts_decoder_conv_1", "f0_shifts_decoder_bn_1", "f0_shifts_decoder_act_1", "add_14", "f0_shifts_decoder_conv_2", "f0_shifts_decoder_bn_2", "f0_shifts_decoder_act_2", "add_15", "f0_shifts_decoder_conv_3", "f0_shifts_decoder_bn_3", "f0_shifts_decoder_act_3", "add_16", "conv2d_9", "lambda_1", "dense_16", "layer_normalization_7", "dense_17", "add_17", "layer_normalization_8", "dense_18", "add_18", "layer_normalization_9", "dense_19"], "h_freq_shifts": ["permute_2", "dense_20", "permute_3", "dense_21", "h_freq_shifts_decoder_conv_0", "h_freq_shifts_decoder_bn_0", "h_freq_shifts_decoder_act_0", "add_19", "h_freq_shifts_decoder_conv_1", "h_freq_shifts_decoder_bn_1", "h_freq_shifts_decoder_act_1", "add_20", "h_freq_shifts_decoder_conv_2", "h_freq_shifts_decoder_bn_2", "h_freq_shifts_decoder_act_2", "add_21", "h_freq_shifts_decoder_conv_3", "h_freq_shifts_decoder_bn_3", "h_freq_shifts_decoder_act_3", "add_22", "conv2d_10", "lambda_2", "dense_22", "layer_normalization_10", "dense_23", "add_23", "layer_normalization_11", "dense_24", "add_24", "layer_normalization_12", "dense_25"], "mag_env": ["permute_4", "dense_26", "permute_5", "dense_27", "mag_env_decoder_conv_0", "mag_env_decoder_bn_0", "mag_env_decoder_act_0", "add_25", "mag_env_decoder_conv_1", "mag_env_decoder_bn_1", "mag_env_decoder_act_1", "add_26", "mag_env_decoder_conv_2", "mag_env_decoder_bn_2", "mag_env_decoder_act_2", "add_27", "mag_env_decoder_conv_3", "mag_env_decoder_bn_3", "mag_env_decoder_act_3", "add_28", "conv2d_11", "lambda_3", "dense_28", "layer_normalization_13", "dense_29", "add_29", "layer_normalization_14", "dense_30", "add_30", "layer_normalization_15", "dense_31"], "h_mag_dist": ["permute_6", "dense_32", "permute_7", "dense_33", "h_mag_dist_decoder_conv_0", "h_mag_dist_decoder_bn_0", "h_mag_dist_decoder_act_0", "add_31", "h_mag_dist_decoder_conv_1", "h_mag_dist_decoder_bn_1", "h_mag_dist_decoder_act_1", "add_32", "h_mag_dist_decoder_conv_2", "h_mag_dist_decoder_bn_2", "h_mag_dist_decoder_act_2", "add_33", "h_mag_dist_decoder_conv_3", "h_mag_dist_decoder_bn_3", "h_mag_dist_decoder_act_3", "add_34", "conv2d_12", "lambda_4", "dense_34", "layer_normalization_16", "dense_35", "add_35", "layer_normalization_17", "dense_36", "add_36", "layer_normalization_18", "dense_37"]} \ No newline at end of file diff --git a/sound_generator.py b/sound_generator.py index 16f86cf..d5e5d87 100644 --- a/sound_generator.py +++ b/sound_generator.py @@ -1,5 +1,6 @@ import os import tensorflow as tf +from tensorflow import TensorShape import tsms import numpy as np from urllib.request import urlretrieve @@ -7,41 +8,45 @@ from pprint import pprint import warnings from typing import Dict, Any -from tcae import model, localconfig, train +from tcae import model, localconfig from tcae.dataset import heuristic_names warnings.simplefilter("ignore") -# qualities = { -# "bright": { -# "high": [], -# "high_mid": [] -# }, -# "dark": { -# -# } -# } -# +# Todo: Sound Generation prediction / output transform needs to be fixed -class HeuristicMeasures: - def __init__(self): - self.names = heuristic_names - for name in heuristic_names: - vars(self)[name] = 0. +def get_zero_batch(conf: localconfig.LocalConfig): + mask = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics]) + note_number = TensorShape([conf.batch_size, conf.num_pitches]) + velocity = TensorShape([conf.batch_size, conf.num_velocities]) + measures = TensorShape([conf.batch_size, conf.num_measures]) + f0_shifts = TensorShape([conf.batch_size, conf.harmonic_frame_steps, 1]) + mag_env = TensorShape([conf.batch_size, conf.harmonic_frame_steps, 1]) + h_freq_shifts = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics]) + h_mag_dist = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics]) + h_phase_diff = TensorShape([conf.batch_size, conf.harmonic_frame_steps, conf.max_num_harmonics]) - def __call__(self, **kwargs): - outputs = [] + _shapes = {} - for k, v in kwargs.items(): - if k in vars(self): - vars(self)[k] = v + _shapes.update({ + "mask": tf.zeros(mask), + "f0_shifts": tf.zeros(f0_shifts), + "mag_env": tf.zeros(mag_env), + "h_freq_shifts": tf.zeros(h_freq_shifts), + "h_mag_dist": tf.zeros(h_mag_dist), + "h_phase_diff": tf.zeros(h_phase_diff) + }) - for name in heuristic_names: - outputs.append(vars(self)[name]) - return outputs + if conf.use_note_number: + _shapes.update({"note_number": tf.zeros(note_number)}) + if conf.use_velocity: + _shapes.update({"velocity": tf.zeros(velocity)}) + if conf.use_heuristics: + _shapes.update({"measures": tf.zeros(measures)}) + return _shapes class SoundGenerator: @@ -119,8 +124,8 @@ def load_config(self) -> None: def load_model(self) -> None: assert os.path.isfile(self._checkpoint_path), f"No checkpoint at {self._checkpoint_path}" - self._model = model.MtVae(self._conf) - _ = self._model(train.get_zero_batch(self._conf)) + self._model = model.TCAEModel(self._conf) + _ = self._model(get_zero_batch(self._conf)) self._model.load_weights(self._checkpoint_path) logger.info("Model loaded") @@ -131,7 +136,7 @@ def _get_mask(self, note_number: int) -> np.ndarray: ) harmonics = tsms.core.get_number_harmonics(f0, self._conf.sample_rate) harmonics = np.squeeze(harmonics) - mask = np.zeros((1, self._conf.harmonic_frame_steps, 110)) + mask = np.zeros((1, self._conf.harmonic_frame_steps, self._conf.max_num_harmonics)) mask[:, :, :harmonics] = np.ones((1, self._conf.harmonic_frame_steps, harmonics)) return mask @@ -205,7 +210,7 @@ def get_prediction(self, data: Dict) -> Any: logger.info("Getting prediction") prediction = self._model.decoder(decoder_inputs) logger.info("Transforming prediction") - transformed = self._conf.data_handler.prediction_transform(prediction) + transformed = self._conf.data_handler.output_transform(prediction) logger.info("De-normalizing prediction") freq, mag, phase = self._conf.data_handler.denormalize( transformed, decoder_inputs["mask"], diff --git a/tcae/__init__.py b/tcae/__init__.py index 391d16c..a0b0e32 100644 --- a/tcae/__init__.py +++ b/tcae/__init__.py @@ -1 +1 @@ -from . import localconfig, dataset, model, train +from tcae import localconfig, dataset, model, train diff --git a/tcae/dataset.py b/tcae/dataset.py index eef5b3b..8e6704d 100644 --- a/tcae/dataset.py +++ b/tcae/dataset.py @@ -1,7 +1,14 @@ import os import tensorflow as tf -from .localconfig import LocalConfig -from . import heuristics +from tcae.localconfig import LocalConfig +from tcae import heuristics + + +heuristic_names = [ + "inharmonicity", "even_odd", "sparse_rich", "attack_rms", + "decay_rms", "attack_time", "decay_time", "bass", "mid", + "high_mid", "high" +] class CompleteTFRecordProvider: @@ -69,13 +76,6 @@ def features_dict(self): } -heuristic_names = [ - "inharmonicity", "even_odd", "sparse_rich", "attack_rms", - "decay_rms", "attack_time", "decay_time", "bass", "mid", - "high_mid", "high" -] - - def get_measures(h_freq, h_mag, conf: LocalConfig): inharmonic = heuristics.core.inharmonicity_measure( h_freq, h_mag, None, None, None, None) @@ -227,12 +227,12 @@ def get_dataset(conf: LocalConfig): test_dataset = create_dataset( test_path, map_func=map_features, batch_size=conf.batch_size) - train_dataset = train_dataset.apply( - tf.data.experimental.assert_cardinality(10542)) - valid_dataset = valid_dataset.apply( - tf.data.experimental.assert_cardinality(2906)) - test_dataset = test_dataset.apply( - tf.data.experimental.assert_cardinality(1588)) + # train_dataset = train_dataset.apply( + # tf.data.experimental.assert_cardinality(num_train_steps)) + # valid_dataset = valid_dataset.apply( + # tf.data.experimental.assert_cardinality(num_valid_steps)) + # test_dataset = test_dataset.apply( + # tf.data.experimental.assert_cardinality(num_test_steps)) if conf.dataset_modifier is not None: train_dataset, valid_dataset, test_dataset = conf.dataset_modifier( diff --git a/tcae/localconfig.py b/tcae/localconfig.py index f1c2b95..29962e5 100644 --- a/tcae/localconfig.py +++ b/tcae/localconfig.py @@ -1,7 +1,7 @@ import os import json from typing import Dict -from .data_handler import DataHandler +from tcae.data_handler import DataHandler class LocalConfig: @@ -18,53 +18,34 @@ class LocalConfig: use_encoder = True use_phase = False latent_dim = 16 - use_max_pool = True - strides = 2 - use_lstm_in_encoder = True + # strides = 2 use_note_number = True use_velocity = True use_instrument_id = False use_heuristics = True use_one_hot_conditioning = True hidden_dim = 256 - default_k = 3 - deep_decoder = False - add_z_to_decoder_blocks = True - check_decoder_hidden_dim = True print_model_summary = False - skip_channels = 32 lstm_dim = 256 lstm_dropout = 0.4 - harmonic_frame_steps = 1001 - max_harmonics = 110 - frame_size = 64 batch_size = 2 num_instruments = 74 num_measures = 7 + 4 starting_midi_pitch = 40 num_pitches = 49 num_velocities = 5 - max_num_harmonics = 98 + max_num_harmonics = 110 + frame_size = 64 + harmonic_frame_steps = 1024 row_dim = 1024 - col_dim = 128 padding = "same" epochs = 500 - num_train_steps = None - num_valid_steps = None early_stopping = 7 learning_rate = 2e-4 lr_plateau = 4 lr_factor = 0.5 - gradient_norm = 5. csv_log_file = "logs.csv" - final_conv_shape = (64, 8, 192) # ToDo: to be calculated dynamically - final_conv_units = 64 * 8 * 192 # ToDo: to be calculated dynamically - best_loss = 1e6 sample_rate = 16000 - log_steps = True - step_log_interval = 100 - is_variational = True - using_mt = True using_categorical = False use_embeddings = False scalar_embedding = False @@ -78,6 +59,7 @@ class LocalConfig: "h_mag_dist": {"shape": (row_dim, 64)}, "h_phase_diff": {"shape": (row_dim, 64)}, } + mt_outputs = { "f0_shifts": {"shape": (row_dim, 64, 16)}, "h_freq_shifts": {"shape": (row_dim, 110, 16)}, @@ -86,22 +68,13 @@ class LocalConfig: "h_phase_diff": {"shape": (row_dim, 110, 16)}, } - use_kl_anneal = False - kl_weight = 1. - kl_weight_max = 1. - kl_anneal_factor = 0.05 - kl_anneal_start = 20 - reconstruction_weight = 1. - st_var = (2.0 ** (1.0 / 12.0) - 1.0) - db_limit = -120 - encoder_type = "2d" # or "1d" - decoder_type = "cnn" freq_bands = { "bass": [60, 270], "mid": [270, 2000], "high_mid": [2000, 6000], "high": [6000, 20000] } + data_handler = None data_handler_properties = [] data_handler_type = "none" @@ -183,7 +156,7 @@ def save_config(self): to_save[p] = eval(f"self.data_handler.{p}") if "data_handler" in to_save: - # to_save is not a deep copy so, we pop item + # to_save will not save objects, we pop item to_save.pop("data_handler") with open(target_path, "w") as f: diff --git a/tcae/model.py b/tcae/model.py index 1a0e5e8..fda91a0 100644 --- a/tcae/model.py +++ b/tcae/model.py @@ -1,192 +1,5 @@ import tensorflow as tf -import json -from .localconfig import LocalConfig - - -def sample_from_latent_space(inputs): - z_mean, z_log_variance = inputs - batch_size = tf.shape(z_mean)[0] - epsilon = tf.random.normal(shape=(batch_size, LocalConfig().latent_dim)) - return z_mean + tf.exp(0.5 * z_log_variance) * epsilon - - -def create_encoder(conf: LocalConfig): - def conv_block(x, f, k, index): - x = tf.keras.layers.Conv2D( - f, k, padding=conf.padding, name=f"encoder_conv_{index}", - kernel_initializer=tf.initializers.glorot_uniform())(x) - x = tf.keras.layers.BatchNormalization(name=f"encoder_bn_{index}")(x) - x = tf.keras.layers.Activation("relu", name=f"encoder_act_{index}")(x) - return x - - if conf is None: - conf = LocalConfig() - - wrapper = {} - - encoder_input = tf.keras.layers.Input(shape=(conf.row_dim, conf.col_dim, 2), name="encoder_input") - - filters = [32] * 2 + [64] * 2 - kernels = [conf.default_k] * 4 - - filters_kernels = iter(zip(filters, kernels)) - - # conv block 1 - f, k = next(filters_kernels) - wrapper["block_1"] = conv_block(encoder_input, f, k, 0) - wrapper["out_1"] = tf.keras.layers.MaxPool2D(2, name="encoder_pool_0")(wrapper["block_1"]) - # remaining conv blocks with skip connections - for i in range(1, len(filters)): - f, k = next(filters_kernels) - wrapper[f"block_{i + 1}"] = conv_block(wrapper[f"out_{i}"], f, k, i) - wrapper[f"skip_{i + 1}"] = tf.keras.layers.concatenate([wrapper[f"block_{i + 1}"], wrapper[f"out_{i}"]]) - wrapper[f"out_{i + 1}"] = tf.keras.layers.MaxPool2D(2, name=f"encoder_pool_{i}")(wrapper[f"skip_{i + 1}"]) - - if conf.use_lstm_in_encoder: - flattened = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(wrapper[f"out_{len(filters)}"]) - hidden = tf.keras.layers.Bidirectional( - tf.keras.layers.LSTM(conf.lstm_dim, activation="tanh", recurrent_activation="sigmoid", - return_sequences=False, dropout=conf.lstm_dropout, - recurrent_dropout=0, unroll=False, use_bias=True))(flattened) - else: - hidden = tf.keras.layers.Flatten()(wrapper[f"out_{len(filters)}"]) - if conf.is_variational: - z_mean = tf.keras.layers.Dense(conf.latent_dim, name="z_mean")(hidden) - z_log_variance = tf.keras.layers.Dense(conf.latent_dim, name="z_log_variance")(hidden) - z = tf.keras.layers.Lambda(sample_from_latent_space)([z_mean, z_log_variance]) - outputs = [z, z_mean, z_log_variance] - else: - outputs = tf.keras.layers.Dense(conf.latent_dim, activation="relu")(hidden) - model = tf.keras.models.Model( - encoder_input, outputs, name="encoder" - ) - tf.keras.utils.plot_model(model, to_file="encoder.png", show_shapes=True) - if conf.print_model_summary: - print(model.summary()) - return model - - -def create_strides_encoder(conf: LocalConfig): - def conv_block(x, f, k, skip=False): - x = tf.keras.layers.Conv2D( - f, k, padding=conf.padding, strides=conf.strides)(x) - if not skip: - x = tf.keras.layers.BatchNormalization()(x) - x = tf.keras.layers.Activation("elu")(x) - return x - - if conf is None: - conf = LocalConfig() - - encoder_input = tf.keras.layers.Input(shape=(conf.row_dim, conf.col_dim, 2), name="encoder_input") - - filters = [32] * 2 + [64] * 2 - kernels = [conf.default_k] * 4 - max_blocks = len(filters) - - filters_kernels = iter(zip(filters, kernels)) - - wrapper = {} - wrapper["block_0_input"] = encoder_input - - for i, (f, k) in enumerate(filters_kernels): - wrapper[f"block_{i}_output"] = conv_block(wrapper[f"block_{i}_input"], f, k) - if i == 0: - wrapper[f"block_{i + 1}_input"] = wrapper[f"block_{i}_output"] - else: - wrapper[f"block_{i + 1}_input"] = tf.keras.layers.Add()( - [wrapper[f"block_{i}_output"], wrapper[f"block_{i}_skip"]] - ) - if i < max_blocks: - wrapper[f"block_{i + 1}_skip"] = conv_block(wrapper[f"block_{i}_output"], 1, 1, skip=True) - - if conf.use_lstm_in_encoder: - flattened = tf.keras.layers.TimeDistributed(tf.keras.layers.Flatten())(wrapper[f"block_{max_blocks}_input"]) - hidden = tf.keras.layers.Bidirectional( - tf.keras.layers.LSTM(conf.lstm_dim, activation="tanh", recurrent_activation="sigmoid", - return_sequences=False, dropout=conf.lstm_dropout, - recurrent_dropout=0, unroll=False, use_bias=True))(flattened) - else: - hidden = tf.keras.layers.Flatten()(wrapper[f"block_{max_blocks}_input"]) - - if conf.is_variational: - z_mean = tf.keras.layers.Dense(conf.latent_dim, - kernel_initializer=tf.initializers.glorot_uniform(), - name="z_mean")(hidden) - z_log_variance = tf.keras.layers.Dense(conf.latent_dim, - kernel_initializer=tf.initializers.glorot_uniform(), - name="z_log_variance")(hidden) - z = tf.keras.layers.Lambda(sample_from_latent_space)([z_mean, z_log_variance]) - outputs = [z, z_mean, z_log_variance] - else: - outputs = tf.keras.layers.Dense(conf.latent_dim, activation="elu")(hidden) - model = tf.keras.models.Model( - encoder_input, outputs, name="encoder" - ) - tf.keras.utils.plot_model(model, to_file="encoder.png", show_shapes=True) - if conf.print_model_summary: - print(model.summary()) - return model - - -def create_1d_encoder(conf: LocalConfig): - def conv1d_model(inputs): - filters = [256, 32, 32, 32, 64, 128] - kernels = [128, 16, 16, 16, 16, 16] - strides = [8, 2, 2, 2, 2, 2] - # ToDo: Find better values for kernels for both freq - - x = inputs - - for f, k, s in zip(filters, kernels, strides): - x = tf.keras.layers.Conv1D(f, k, strides=s, padding=conf.padding)(x) - x = tf.keras.layers.BatchNormalization()(x) - x = tf.keras.layers.Activation("elu")(x) - - x = tf.keras.layers.Flatten()(x) - return x - - if conf is None: - conf = LocalConfig() - - encoder_input = tf.keras.layers.Input(shape=(conf.row_dim, conf.col_dim, 2), name="encoder_input") - freq = tf.keras.layers.Lambda(lambda x: x[..., 0])(encoder_input) - mag = tf.keras.layers.Lambda(lambda x: x[..., 1])(encoder_input) - - freq_flattened = conv1d_model(freq) - mag_flattened = conv1d_model(mag) - - hidden_input = [freq_flattened, mag_flattened] - hidden = tf.keras.layers.concatenate(hidden_input) - - if conf.is_variational: - z_mean = tf.keras.layers.Dense(conf.latent_dim, name="z_mean")(hidden) - z_log_variance = tf.keras.layers.Dense(conf.latent_dim, name="z_log_variance")(hidden) - z = tf.keras.layers.Lambda(sample_from_latent_space)([z_mean, z_log_variance]) - outputs = [z, z_mean, z_log_variance] - else: - outputs = tf.keras.layers.Dense(conf.latent_dim, activation="elu")(hidden) - - model = tf.keras.models.Model(encoder_input, outputs) - tf.keras.utils.plot_model(model, to_file="encoder.png", show_shapes=True) - if conf.print_model_summary: - print(model.summary()) - return model - - -def decoder_inputs(conf: LocalConfig): - z_input = tf.keras.layers.Input(shape=(conf.latent_dim,), name="z_input") - note_number = tf.keras.layers.Input(shape=(conf.num_pitches,), name="note_number") - velocity = tf.keras.layers.Input(shape=(conf.num_velocities,), name="velocity") - return z_input, note_number, velocity - - -def reshape_z(block, z_input, conf: LocalConfig): - target_shape = (128 * 2**block, 16 * 2**block, conf.latent_dim) - n_repeats = target_shape[0] * target_shape[1] - current_z = tf.keras.layers.RepeatVector(n_repeats)(z_input) - current_z = tf.keras.layers.Reshape(target_shape=target_shape)(current_z) - return current_z +from tcae.localconfig import LocalConfig def embedding_layers(inputs, input_dims, embedding_size, conf: LocalConfig): @@ -201,183 +14,6 @@ def embedding_layers(inputs, input_dims, embedding_size, conf: LocalConfig): return numeric_val -def create_decoder(conf: LocalConfig): - if conf is None: - conf = LocalConfig() - z_input, note_number, velocity = decoder_inputs(conf) - heuristic_measures = tf.keras.layers.Input(shape=(conf.num_measures,), name="measures") - - if conf.use_encoder: - inputs_list = [z_input, note_number, velocity] - else: - inputs_list = [note_number, velocity] - - if conf.use_heuristics: - inputs_list += [heuristic_measures] - - if conf.hidden_dim < conf.latent_dim and conf.check_decoder_hidden_dim: - conf.hidden_dim = max(conf.hidden_dim, conf.latent_dim) - print("Decoder hidden dimension updated to", conf.hidden_dim) - - inputs = tf.keras.layers.concatenate(inputs_list) - - if not conf.deep_decoder: - hidden = tf.keras.layers.Dense(conf.hidden_dim, activation="relu")(inputs) - num_repeats = int((conf.final_conv_units // conf.hidden_dim) / 2) - repeat = tf.keras.layers.RepeatVector(num_repeats)(hidden) - conv_input = tf.keras.layers.Bidirectional( - tf.keras.layers.LSTM(conf.lstm_dim, activation="tanh", recurrent_activation="sigmoid", - return_sequences=True, dropout=conf.lstm_dropout, - recurrent_dropout=0, unroll=False, use_bias=True))(repeat) - f_repeat = 2 - k_repeat = 4 - final_conv_shape = conf.final_conv_shape - else: - if conf.add_z_to_decoder_blocks: - print("Deeper decoder can not add z to decoder blocks, disabling conf.add_z_to_decoder_blocks") - conf.add_z_to_decoder_blocks = False - units = int(conf.final_conv_units / 16) - conv_input = tf.keras.layers.Dense(units, activation="relu")(inputs) - f_repeat = 3 - k_repeat = 6 - final_conv_shape = [int(n/4) for n in conf.final_conv_shape[:-1]] + [192] - reshaped = tf.keras.layers.Reshape(final_conv_shape)(conv_input) - - filters = list(reversed([32] * f_repeat + [64] * f_repeat)) - kernels = list(reversed([conf.default_k] * k_repeat)) - filters_kernels = iter(zip(filters, kernels)) - wrapper = { - "up_in_0": reshaped - } - - up_conv_channels = conf.latent_dim if conf.add_z_to_decoder_blocks else conf.skip_channels - - for i in range(0, len(filters)): - f, k = next(filters_kernels) - wrapper[f"up_out_{i}"] = tf.keras.layers.UpSampling2D(2, name=f"decoder_up_{i}")(wrapper[f"up_in_{i}"]) - - wrapper[f"conv_out_{i}"] = tf.keras.layers.Conv2D( - f, k, padding=conf.padding, name=f"decoder_conv_{i}", - kernel_initializer=tf.initializers.glorot_uniform())(wrapper[f"up_out_{i}"]) - wrapper[f"bn_out_{i}"] = tf.keras.layers.BatchNormalization(name=f"decoder_bn_{i}")(wrapper[f"conv_out_{i}"]) - wrapper[f"act_{i}"] = tf.keras.layers.Activation("relu", name=f"decoder_act_{i}")(wrapper[f"bn_out_{i}"]) - if conf.use_encoder and conf.add_z_to_decoder_blocks: - wrapper[f"up_conv_{i}"] = tf.keras.layers.Conv2D( - up_conv_channels, 3, padding="same", name=f"decoder_up_conv_{i}" - )(wrapper[f"up_out_{i}"]) - current_z = reshape_z(i, z_input, conf) - z_added = tf.keras.layers.Add()([wrapper[f"up_conv_{i}"], current_z]) - wrapper[f"up_in_{i + 1}"] = tf.keras.layers.concatenate([ - wrapper[f"act_{i}"], z_added - ]) - else: - wrapper[f"up_out_{i}"] = tf.keras.layers.Conv2D(1, 1, padding="same")(wrapper[f"up_out_{i}"]) - wrapper[f"up_in_{i + 1}"] = tf.keras.layers.Add()([ - wrapper[f"act_{i}"], wrapper[f"up_out_{i}"] - ]) - - reconstructed = tf.keras.layers.Conv2D( - 2, 3, padding=conf.padding, activation="linear", name="decoder_output", - kernel_initializer=tf.initializers.glorot_uniform())(wrapper[f"up_in_{len(filters)}"]) - - model = tf.keras.models.Model( - inputs_list, - reconstructed, name="decoder" - ) - # tf.keras.utils.plot_model(model, to_file="decoder.png", show_shapes=True) - if conf.print_model_summary: - print(model.summary()) - return model - - -def create_rnn_decoder(conf: LocalConfig): - if conf is None: - conf = LocalConfig() - note_number = tf.keras.layers.Input(shape=(conf.num_pitches,), name="note_number") - velocity = tf.keras.layers.Input(shape=(conf.num_velocities,), name="velocity") - heuristic_measures = tf.keras.layers.Input(shape=(conf.num_measures,), name="measures") - inputs_list = [note_number, velocity] - - if conf.use_heuristics: - inputs_list += [heuristic_measures] - - inputs = tf.keras.layers.concatenate(inputs_list) - hidden = tf.keras.layers.Dense(256, activation="relu", - kernel_initializer=tf.initializers.glorot_uniform())(inputs) - num_repeats = 1024 - repeat = tf.keras.layers.RepeatVector(num_repeats)(hidden) - # Note: When using LSTM, only this specific configuration works with CUDA - lstm_1 = tf.keras.layers.Bidirectional( - tf.keras.layers.LSTM(64, activation="tanh", recurrent_activation="sigmoid", - return_sequences=True, dropout=conf.lstm_dropout, - recurrent_dropout=0, unroll=False, use_bias=True))(repeat) - lstm_2 = tf.keras.layers.Bidirectional( - tf.keras.layers.LSTM(64, activation="tanh", recurrent_activation="sigmoid", - return_sequences=True, dropout=conf.lstm_dropout, - recurrent_dropout=0, unroll=False, use_bias=True))(repeat) - - output_1 = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(lstm_1) - output_2 = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(lstm_2) - - output = tf.keras.layers.concatenate([output_1, output_2]) - model = tf.keras.models.Model( - inputs_list, - output, name="decoder" - ) - if conf.print_model_summary: - print(model.summary()) - return model - - -def create_vae(conf: LocalConfig): - if conf is None: - conf = LocalConfig() - if conf.use_max_pool: - encoder = create_encoder(conf) - else: - if conf.encoder_type == "2d": - encoder = create_strides_encoder(conf) - else: - encoder = create_1d_encoder(conf) - if conf.decoder_type == "rnn": - decoder = create_rnn_decoder(conf) - else: - decoder = create_decoder(conf) - - encoder_input = tf.keras.layers.Input(shape=(conf.row_dim, conf.col_dim, 2)) - note_number = tf.keras.layers.Input(shape=(conf.num_pitches,)) - heuristic_measures = tf.keras.layers.Input(shape=(conf.num_measures,), name="measures") - velocity = tf.keras.layers.Input(shape=(conf.num_velocities,)) - - if conf.is_variational: - z, z_mean, z_log_variance = encoder(encoder_input) - inputs = [z, note_number, velocity] - if conf.use_heuristics: - inputs += [heuristic_measures] - reconstruction = decoder(inputs) - outputs = [reconstruction, z_mean, z_log_variance] - else: - z = encoder(encoder_input) - inputs = [z, note_number, velocity] - if conf.use_heuristics: - inputs += [heuristic_measures] - reconstruction = decoder(inputs) - outputs = reconstruction - - model_input = [encoder_input, note_number, velocity] - if conf.use_heuristics: - model_input += [heuristic_measures] - - model = tf.keras.models.Model( - model_input, - outputs, - name="auto_encoder" - ) - if conf.print_model_summary: - print(model.summary()) - return model - - def bn_act(inputs, name=None): x = tf.keras.layers.BatchNormalization()(inputs) return tf.keras.layers.Activation("elu", name=name)(x) @@ -399,26 +35,15 @@ def conv_2d_encoder_block(inputs, filters, kernel, stride=1, use_act=True, name= strides=stride, name=name)(inputs) -def ffn_block(inputs, num_repeats, hidden_dim, save_names=False, k=None): - if not save_names: - x = tf.keras.layers.Dense(hidden_dim, activation="elu")(inputs) - x = tf.keras.layers.LayerNormalization()(x) +def ffn_block(inputs, num_repeats, hidden_dim): + x = tf.keras.layers.Dense(hidden_dim, activation="elu")(inputs) + x = tf.keras.layers.LayerNormalization()(x) - for i in range(0, num_repeats): - y = tf.keras.layers.Dense(hidden_dim, activation="elu")(x) - x = tf.keras.layers.Add()([x, y]) - x = tf.keras.layers.LayerNormalization()(x) - return x - else: - assert k is not None, "Need a key name to save output layer names" - x = name(k, tf.keras.layers.Dense(hidden_dim, activation="elu"))(inputs) - x = name(k, tf.keras.layers.LayerNormalization())(x) - - for i in range(0, num_repeats): - y = name(k, tf.keras.layers.Dense(hidden_dim, activation="elu"))(x) - x = name(k, tf.keras.layers.Add())([x, y]) - x = name(k, tf.keras.layers.LayerNormalization())(x) - return x + for i in range(0, num_repeats): + y = tf.keras.layers.Dense(hidden_dim, activation="elu")(x) + x = tf.keras.layers.Add()([x, y]) + x = tf.keras.layers.LayerNormalization()(x) + return x def get_encoder_inputs(inputs, conf: LocalConfig): @@ -477,27 +102,6 @@ def create_mt_encoder(inputs, conf: LocalConfig): return m -output_layers = {} - - -def name(key, layer): - global output_layers - - if key not in output_layers: - output_layers[key] = [] - - layer_name = str(layer.name).split("/")[0] - output_layers[key].append(layer_name) - return layer - - -def save_output_layers(): - global output_layers - - with open("output_layers.json", "w") as f: - json.dump(output_layers, f) - - def create_mt_decoder(inputs, conf: LocalConfig): if conf is None: conf = LocalConfig() @@ -559,10 +163,10 @@ def create_mt_decoder(inputs, conf: LocalConfig): cols = conf.mt_outputs[k]["shape"][1] channels = conf.mt_outputs[k]["shape"][2] - task_in = name(k, tf.keras.layers.Permute(dims=(1, 3, 2)))(shared_out) - task_in = name(k, tf.keras.layers.Dense(units=cols))(task_in) - task_in = name(k, tf.keras.layers.Permute(dims=(1, 3, 2)))(task_in) - task_in = name(k, tf.keras.layers.Dense(units=channels))(task_in) + task_in = tf.keras.layers.Permute(dims=(1, 3, 2))(shared_out) + task_in = tf.keras.layers.Dense(units=cols)(task_in) + task_in = tf.keras.layers.Permute(dims=(1, 3, 2))(task_in) + task_in = tf.keras.layers.Dense(units=channels)(task_in) repeats = 4 filters = [channels] * repeats @@ -573,43 +177,43 @@ def create_mt_decoder(inputs, conf: LocalConfig): for i in range(0, len(filters)): _f, _k = next(filters_kernels) - wrapper[f"conv_out_{i}"] = name(k, tf.keras.layers.Conv2D( - _f, _k, padding=conf.padding, name=f"{k}_decoder_conv_{i}"))(wrapper[f"up_out_{i}"]) - wrapper[f"bn_out_{i}"] = name(k, tf.keras.layers.BatchNormalization( - name=f"{k}_decoder_bn_{i}"))(wrapper[f"conv_out_{i}"]) - wrapper[f"act_{i}"] = name(k, tf.keras.layers.Activation( - "elu", name=f"{k}_decoder_act_{i}"))(wrapper[f"bn_out_{i}"]) + wrapper[f"conv_out_{i}"] = tf.keras.layers.Conv2D( + _f, _k, padding=conf.padding, name=f"{k}_decoder_conv_{i}")(wrapper[f"up_out_{i}"]) + wrapper[f"bn_out_{i}"] = tf.keras.layers.BatchNormalization( + name=f"{k}_decoder_bn_{i}")(wrapper[f"conv_out_{i}"]) + wrapper[f"act_{i}"] = tf.keras.layers.Activation( + "elu", name=f"{k}_decoder_act_{i}")(wrapper[f"bn_out_{i}"]) # wrapper[f"up_out_{i}"] = tf.keras.layers.Dense(units=f)(wrapper[f"up_out_{i}"]) - wrapper[f"up_out_{i + 1}"] = name(k, tf.keras.layers.Add()([ + wrapper[f"up_out_{i + 1}"] = tf.keras.layers.Add()([ wrapper[f"act_{i}"], wrapper[f"up_out_{i}"] - ])) + ]) conv2d_out = wrapper[f"up_out_{len(filters)}"] if conf.using_categorical: - task_out = name(k, tf.keras.layers.Permute(dims=(1, 3, 2)))(conv2d_out) - task_out = name(k, tf.keras.layers.Dense(v["size"], activation="elu"))(task_out) - task_out = name(k, tf.keras.layers.Permute(dims=(1, 3, 2)))(task_out) - task_out = name(k, tf.keras.layers.Conv2D(256, 1, padding="same"))(task_out) + task_out = tf.keras.layers.Permute(dims=(1, 3, 2))(conv2d_out) + task_out = tf.keras.layers.Dense(v["size"], activation="elu")(task_out) + task_out = tf.keras.layers.Permute(dims=(1, 3, 2))(task_out) + task_out = tf.keras.layers.Conv2D(256, 1, padding="same")(task_out) else: - conv2d_out = name(k, tf.keras.layers.Conv2D(1, 1, padding="same"))(conv2d_out) - task_out = name(k, tf.keras.layers.Lambda(lambda y: tf.squeeze(y, axis=-1)))(conv2d_out) - task_out = ffn_block(task_out, 2, cols, save_names=True, k=k) - task_out = name(k, tf.keras.layers.Dense(units=v["size"]))(task_out) + conv2d_out = tf.keras.layers.Conv2D(1, 1, padding="same")(conv2d_out) + task_out = tf.keras.layers.Lambda(lambda y: tf.squeeze(y, axis=-1))(conv2d_out) + task_out = ffn_block(task_out, 2, cols) + task_out = tf.keras.layers.Dense(units=v["size"])(task_out) outputs[k] = task_out m = tf.keras.models.Model( inputs, outputs ) - save_output_layers() + if conf.print_model_summary: tf.keras.utils.plot_model(m, to_file="decoder.png", show_shapes=True, show_layer_names=False) return m -class MtVae(tf.keras.Model): +class TCAEModel(tf.keras.Model): def __init__(self, conf: LocalConfig): - super(MtVae, self).__init__() + super(TCAEModel, self).__init__() self.conf = LocalConfig() if conf is None else conf self.encoder = None self.decoder = None @@ -672,13 +276,3 @@ def call(self, inputs, training=None, mask=None): def get_config(self): pass - - -def get_model_from_config(conf): - print("Creating Auto Encoder") - mt_vae = MtVae(conf) - return mt_vae - # print("Creating Decoder") - # if conf.decoder_type == "rnn": - # return create_rnn_decoder(conf) - # return create_decoder(conf) diff --git a/tcae/train.py b/tcae/train.py index bca283c..1991b02 100644 --- a/tcae/train.py +++ b/tcae/train.py @@ -2,9 +2,9 @@ from tensorflow.python.keras.engine import data_adapter import os import warnings -from .dataset import get_dataset -from .model import get_model_from_config -from .localconfig import LocalConfig +from tcae.dataset import get_dataset +from tcae.model import TCAEModel +from tcae.localconfig import LocalConfig warnings.simplefilter("ignore") @@ -58,7 +58,7 @@ def train(conf: LocalConfig): # create and compile model model = ModelWrapper( - get_model_from_config(conf), + TCAEModel(conf), conf.data_handler.loss) model.compile( @@ -67,23 +67,24 @@ def train(conf: LocalConfig): # build model x, y = next(iter(train_dataset)) - y_pred = model(x) + _ = model(x) if conf.print_model_summary: print(model.summary()) # load model checkpoint - checkpoint_file = os.path.join(conf.checkpoints_dir, "cp.ckpt") - - if os.path.isdir(conf.checkpoints_dir) and os.listdir(conf.checkpoints_dir): - model.load_weights(checkpoint_file) + if conf.pretrained_model_path is not None: + if os.path.isfile(conf.pretrained_model_path): + model.load_weights(conf.pretrained_model_path) + else: + print(f"No pretrained model found at {conf.pretrained_model_path}") # create training callbacks checkpoint = tf.keras.callbacks.ModelCheckpoint( - filepath=checkpoint_file, + filepath=os.path.join(conf.checkpoints_dir, conf.model_name + "_{epoch}_{val_loss:.5f}.h5"), monitor='val_loss', save_weights_only=True, - verbose=1, + verbose=True, save_best_only=True, save_freq='epoch') @@ -110,10 +111,3 @@ def train(conf: LocalConfig): # evaluate model model.evaluate(test_dataset) - - -# ------------------------------------------------------------------------------ - -if __name__ == "__main__": - conf = LocalConfig() - train(conf) diff --git a/train_on_colab.ipynb b/train_on_colab.ipynb index 31ab790..1767b30 100644 --- a/train_on_colab.ipynb +++ b/train_on_colab.ipynb @@ -166,52 +166,6 @@ "execution_count": null, "outputs": [] }, - { - "cell_type": "markdown", - "metadata": { - "id": "JuZWnvpid21O" - }, - "source": [ - "## Freeze Everything Except 1 Output" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "HCFTAFskd99x" - }, - "source": [ - "# Output layers json file is created when model is built in a previous run\n", - "# So, if you need to train from scratch, this code cell is irrelevant\n", - "with open(\"output_layers.json\", \"r\") as f:\n", - " output_layers = json.load(f)\n", - "\n", - "print(output_layers.keys())\n", - "# This is same as mt_outputs except the phase diff output\n", - "print(conf.mt_outputs.keys())" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [ - "# One of the outputs selected to be unfrozen\n", - "# Since layer names are named automatically, if you see a mismatch\n", - "# Restart runtime and train without frozen/ unfrozen\n", - "# And then use the output_layers.json\n", - "unfrozen_layers = output_layers.get(\"h_mag_dist\")\n", - "print(unfrozen_layers)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, { "cell_type": "markdown", "metadata": { @@ -228,10 +182,6 @@ "id": "IK8BksMivYyA" }, "source": [ - "# If training just an output submodel: (Only works when pretrained model path is specified)\n", - "# train.train(conf, unfrozen_layers)\n", - "\n", - "# If training from scratch\n", "train.train(conf)" ], "execution_count": null,