
Commit

(apichard) improve device dataloading where possible and add auto checkpoints
AlfredPichard committed Jan 18, 2024
1 parent 5177fc4 commit ed6be7e
Showing 4 changed files with 16 additions and 7 deletions.
6 changes: 3 additions & 3 deletions src/diffusion.py
@@ -32,7 +32,7 @@ def __init__(self, n_channels = 128, alpha_steps = 100, time_emb_dim=16, start_c
 
         self.encodec = WrappedEncodec().to(device)
         self.alpha_steps = alpha_steps
-        self._config_prior(torch.zeros(n_channels, 1).to(device), torch.ones(n_channels, 1).to(device))
+        self._config_prior(torch.zeros(n_channels, 1, device=device), torch.ones(n_channels, 1, device=device))
         self.device = device
 
     def _config_prior(self, mean, std):
@@ -52,7 +52,7 @@ def inference(self, x_0 = None, n_batch = 1, n_frames = 1024, T = None):
         if self.inference_bpm is not None:
             conditioner = torch.from_numpy(np.array([phasor_from_bpm(self.inference_bpm) for _ in range(n_batch)]))[:,None,:].float().to(self.device)
         denoised_samples = [x_0]
-        alpha = torch.arange(T).to(self.device)/T
+        alpha = torch.arange(T, device=self.device)/T
 
         for t in range(1,T,1):
             a = alpha[t]*torch.ones(n_batch, 1, device=self.device)
@@ -62,4 +62,4 @@ def inference(self, x_0 = None, n_batch = 1, n_frames = 1024, T = None):
         return self.encodec.decode(denoised_samples[-1])
 
     def sample(self, n_batch, n_frames):
-        return self.mean + self.std * torch.randn((n_batch, self.n_channels, n_frames)).to(self.device)
+        return self.mean + self.std * torch.randn((n_batch, self.n_channels, n_frames), device=self.device)
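
The pattern applied throughout this file: passing device= to a tensor constructor allocates the tensor directly on the target device, while the old .to(device) chain first materializes it on the CPU and then copies it over. A minimal standalone sketch of the difference (not this repo's code):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Before: allocated on the CPU, then copied to `device`.
    x = torch.randn((16, 128, 1024)).to(device)

    # After: allocated directly on `device`, skipping the host allocation and transfer.
    y = torch.randn((16, 128, 1024), device=device)

The same keyword works for torch.zeros, torch.ones, and torch.arange, which is exactly what the three hunks above switch to.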
7 changes: 7 additions & 0 deletions src/main.py
@@ -26,6 +26,7 @@
 dataset = ds.SimpleDataset(path=DATA_PATH, keys=['encodec', 'metadata'], transforms=None, readonly=True)
 
 batch = parser.args.batch
+save_checkpoint_epochs = parser.args.save_checkpoint_epochs
 
 train_dataset, valid_dataset= torch.utils.data.random_split(dataset, [0.8, 0.2])
 train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch, shuffle = True, collate_fn = ds.collate_fn_padd, drop_last=True)
@@ -55,6 +56,9 @@
             trainer.train_one_epoch(log = True)
             if log:
                 trainer.validate()
+            if trainer.epoch % save_checkpoint_epochs == 0:
+                print("Autosaving checkpoint...")
+                trainer.checkpoint()
         print('Training stopping, saving model')
         trainer.checkpoint()
         sys.exit()
@@ -64,6 +68,9 @@
             trainer.train_one_epoch(log = True)
             if log:
                 trainer.validate()
+            if trainer.epoch % save_checkpoint_epochs == 0:
+                print("Autosaving checkpoint...")
+                trainer.checkpoint()
     except KeyboardInterrupt:
         print('Training stopping, saving model')
         trainer.checkpoint()
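
Both training paths (fixed epoch count and run-until-interrupted) now share the same periodic autosave. A hedged sketch of the combined control flow; trainer, its epoch counter, and checkpoint() are assumed from this repo, and the loop structure is an illustration rather than the exact code:

    try:
        while True:
            trainer.train_one_epoch(log=True)
            if trainer.epoch % save_checkpoint_epochs == 0:
                print("Autosaving checkpoint...")
                trainer.checkpoint()  # periodic autosave
    except KeyboardInterrupt:
        print('Training stopping, saving model')
        trainer.checkpoint()  # final save on manual interrupt

One behavior worth noting: since 0 % n == 0, the check also fires at epoch zero if the counter starts there, depending on where the trainer increments it.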
8 changes: 4 additions & 4 deletions src/train.py
@@ -32,8 +32,8 @@ def train_one_epoch(self, log = True):
                 data=np.array([ds.phasor(metadata[k]['BT_beat']) for k in range(batch_size)]))
             bt_conditioner = torch.from_numpy(numpy_bt_conditioner)[:,None,:].float().to(self.DEVICE)
             z_1 = data['encodec'].to(self.DEVICE)
-            a = torch.rand((batch_size)).view(batch_size, 1, 1).to(self.DEVICE)
-            z_0 = self.model.sample(batch_size, z_1.shape[-1]).to(self.DEVICE)
+            a = torch.rand((batch_size), device=self.DEVICE).view(batch_size, 1, 1)
+            z_0 = self.model.sample(batch_size, z_1.shape[-1])
             z_a = ((1 - a)*z_0 + a*z_1)
 
             diff_pred = self.model(z_a, a, bt_conditioner)
@@ -63,8 +63,8 @@ def validate(self):
             metadata = data['metadata']
             bt_conditioner = torch.from_numpy(np.array([ds.phasor(metadata[k]['BT_beat']) for k in range(batch_size)]))[:,None,:].float().to(self.DEVICE)
             z_1 = data['encodec'].to(self.DEVICE)
-            a = torch.rand((batch_size)).view(batch_size, 1, 1).to(self.DEVICE)
-            z_0 = self.model.sample(batch_size, z_1.shape[-1]).to(self.DEVICE)
+            a = torch.rand((batch_size), device=self.DEVICE).view(batch_size, 1, 1)
+            z_0 = self.model.sample(batch_size, z_1.shape[-1])
             z_a = ((1 - a)*z_0 + a*z_1)
 
             diff_pred = self.model(z_a, a, bt_conditioner)
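
Two things change in these hunks: a is now drawn directly on the device, and the .to(self.DEVICE) on z_0 is dropped because sample() (src/diffusion.py above) already returns noise on the device. For context, the unchanged z_a line linearly mixes prior noise and data latents at a per-example level a; a standalone sketch of the broadcasting involved, with shapes assumed from the defaults elsewhere in the repo:

    import torch

    batch_size, n_channels, n_frames = 16, 128, 1024
    a = torch.rand(batch_size).view(batch_size, 1, 1)    # one mix level per example
    z_0 = torch.randn(batch_size, n_channels, n_frames)  # prior noise
    z_1 = torch.randn(batch_size, n_channels, n_frames)  # stand-in for data latents
    z_a = (1 - a) * z_0 + a * z_1                        # a broadcasts over channels and frames
    assert z_a.shape == z_1.shape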
2 changes: 2 additions & 0 deletions src/utils.py
@@ -168,6 +168,7 @@ class Parser:
     # Default values
     EPOCHS = None
     LOG_EPOCHS = 10
+    SAVE_CHECKPOINT_EPOCHS = 100
     MODEL = None
     TRAINING = 'default'
     BATCH = 16
@@ -178,6 +179,7 @@ def __init__(self):
         self.parser.add_argument('-c', '--checkpoint', required=False, dest='checkpoint', action='store_const', const=True, default=False, help='If possible start training from last checkpoint')
         self.parser.add_argument('-e', '--epochs', type=int, default=self.EPOCHS, dest='epochs', required=False, nargs='?', help='Specify training epochs, default is infinite')
         self.parser.add_argument('-l', '--log', type=int, default=self.LOG_EPOCHS, dest='epochs_log', required=False, nargs='?', help=f'Specify after how many epochs to log, default is {self.LOG_EPOCHS}')
+        self.parser.add_argument('-s', '--save_checkpoint', type=int, default=self.SAVE_CHECKPOINT_EPOCHS, dest='save_checkpoint_epochs', required=False, nargs='?', help=f'Specify after how many epochs to autosave a checkpoint, default is {self.SAVE_CHECKPOINT_EPOCHS}')
         self.parser.add_argument('-m', '--model', type=str, default=self.MODEL, dest='model', required=False, nargs='?', help='Model configuration to load')
         self.parser.add_argument('-t', '--training', type=str, default=self.TRAINING, dest='train', required=False, nargs='?', help='Training configuration to load')
         self.parser.add_argument('-b', '--batch', type=int, default=self.BATCH, dest='batch', required=False, nargs='?', help='Training batch size')
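
A standalone sketch of how the new flag flows through argparse; it mirrors the Parser class above but is an illustration, not the repo's code:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--save_checkpoint', type=int, default=100,
                        dest='save_checkpoint_epochs', required=False, nargs='?',
                        help='Autosave a checkpoint every N epochs (default 100)')

    args = parser.parse_args(['-s', '50'])
    print(args.save_checkpoint_epochs)  # -> 50

With main.py reading parser.args.save_checkpoint_epochs, an invocation like python src/main.py -s 50 (assuming that entry point) would autosave every 50 epochs; omitting the flag falls back to the default of 100.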
