Skip to content

Commit

Permalink
Multi Hot encoding
Browse files Browse the repository at this point in the history
No ground truth plotting and no real metrics yet. Just loss.

related to: #6
  • Loading branch information
anthonio9 committed Jan 27, 2024
1 parent c4f569b commit a787645
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 34 deletions.
33 changes: 33 additions & 0 deletions Notes.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
Try dropout and different regularization techniques.

== Divide the dataset (important)

* division for train and validation sets based on the guitar players

== Regularization techniques

* set the parameters of the techniques below manually (important)
* weight regularization, weight penalty, comes out of the box with Adam optimizer
* L1, L2 regularization
* Add noise to inputs: Gaussian noise of the same shape

== Hyperparameters tuning and the three sets (train, test, valid)


25.01.2024

== 1st experiment

multi hot piano, one array of 1440 instead of 6.

* use sigmoid instead of softmax
* use binary cross entropy instead of categorical cross entropy

set a threshold for the pitch recognition.

== 2nd experiment

take inspiration from here: https://arxiv.org/abs/1802.08435

* use 60 discrete values of MIDI per each string
* use a vector 1-60 for estimating the deviation
36 changes: 36 additions & 0 deletions config/fcnf0++-gset-multi-hot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Configuration for training on the gset dataset with a multi-hot target:
# the per-category one-hot pitch targets are merged into a single multi-hot
# vector and trained with binary cross-entropy.

MODULE = 'penn'

# Configuration name. Kept in sync with this file's name
# (fcnf0++-gset-multi-hot.py) so the run is not confused with the
# 'fcnf0++-gset-voiced' configuration it was derived from.
CONFIG = 'fcnf0++-gset-multi-hot'

# gset only
DATASETS = ['gset']

EVALUATION_DATASETS = ['gset']

# Number of training steps
STEPS = 50000

LOG_INTERVAL = 500  # steps -- NOTE(review): unit assumed, confirm

CHECKPOINT_INTERVAL = 5000  # steps

# audio parameters
SAMPLE_RATE = 11025

# the original hopsize is 256 samples, this is 4 times less than that
HOPSIZE = 64

# use only the voiced frames
VOICED_ONLY = True

# Multi-hot targets are trained with binary cross-entropy
LOSS = 'binary_cross_entropy'

# Combine the per-category one-hot targets into one multi-hot vector
LOSS_MULTI_HOT = True

GAUSSIAN_BLUR = False

# NOTE(review): presumably splits train/validation by guitar player -- confirm
GSET_SPLIT_PLAYERS = True

# Overrides the default (True).
# NOTE(review): presumably disables the evaluation stage -- confirm
EVALUATE = False

# Number of pitch categories the targets are chunked into
# NOTE(review): presumably one per guitar string -- confirm
PITCH_CATS = 6
6 changes: 6 additions & 0 deletions penn/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,10 @@
# Loss chunk
LOSS_CHUNKED = False

# Loss multi-hot

LOSS_MULTI_HOT = False

# Number of training steps
STEPS = 250000

Expand All @@ -188,3 +192,5 @@

# Whether to use the weight decay (L2 penalty)
WEIGHT_DECAY = None

EVALUATE = True
39 changes: 23 additions & 16 deletions penn/evaluate/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ def __init__(self):
self.pitch_metrics = PitchMetrics()

def __call__(self):
return (
{
'accuracy': self.accuracy(),
'loss': self.loss()
} |
self.f1() |
self.pitch_metrics())
if penn.LOSS_MULTI_HOT:
return (
{
'loss': self.loss()
})
else:
return (
{
'accuracy': self.accuracy(),
'loss': self.loss()
} |
self.f1() |
self.pitch_metrics())

def update(self, logits, bins, target, voiced):
# Detach from graph
Expand All @@ -42,18 +48,19 @@ def update(self, logits, bins, target, voiced):
# Update loss
self.loss.update(logits[:, :penn.PITCH_BINS], bins.T)

# Decode bins, pitch, and periodicity
with torchutil.time.context('decode'):
predicted, pitch, periodicity = penn.postprocess(logits)
if not penn.LOSS_MULTI_HOT:
# Decode bins, pitch, and periodicity
with torchutil.time.context('decode'):
predicted, pitch, periodicity = penn.postprocess(logits)

# Update bin accuracy
self.accuracy.update(predicted[voiced], bins[voiced])
# Update bin accuracy
self.accuracy.update(predicted[voiced], bins[voiced])

# Update pitch metrics
self.pitch_metrics.update(pitch, target, voiced)
# Update pitch metrics
self.pitch_metrics.update(pitch, target, voiced)

# Update periodicity metrics
self.f1.update(periodicity, voiced)
# Update periodicity metrics
self.f1.update(periodicity, voiced)

def reset(self):
self.accuracy.reset()
Expand Down
34 changes: 18 additions & 16 deletions penn/plot/logits/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def process_logits(logits: torch.Tensor):
# NOTE - Except in multi-hot mode (which uses sigmoid), we use softmax
# even if the loss is BCE for more comparable visualization. Otherwise,
# the variance of models trained with BCE looks erroneously lower.
distributions = torch.nn.functional.softmax(logits, dim=1)

if penn.LOSS_MULTI_HOT:
distributions = torch.nn.functional.sigmoid(logits)
else:
distributions = torch.nn.functional.softmax(logits, dim=1)

# Take the log again for display
distributions = torch.log(distributions)
Expand Down Expand Up @@ -77,7 +81,8 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None):

distributions, figsize = process_logits(logits)

predicted_bins, pitch, periodicity = penn.postprocess(logits)
if not penn.LOSS_MULTI_HOT:
predicted_bins, pitch, periodicity = penn.postprocess(logits)

# Change font size
matplotlib.rcParams.update({'font.size': 5})
Expand Down Expand Up @@ -108,31 +113,30 @@ def logits_matplotlib(logits, bins=None, voiced=None, stem=None):
axis.set_ylabel('Frequency (Hz)')
axis.set_title(f"track: {stem}")

if bins is not None and voiced is not None:
if bins is not None and voiced is not None and not penn.LOSS_MULTI_HOT:
nbins = bins.detach().cpu().numpy()
nvoiced = voiced.detach().cpu().numpy()

npredicted_bins = predicted_bins.detach().cpu().numpy()

nbins = nbins.squeeze().T
npredicted_bins = npredicted_bins.squeeze().T
nvoiced = nvoiced.squeeze().T

offset = np.arange(0, penn.PITCH_CATS)*penn.PITCH_BINS

nbins += offset
npredicted_bins += offset

nbins_masked = np.ma.MaskedArray(nbins, np.logical_not(nvoiced))
npredicted_bins_masked = np.ma.MaskedArray(npredicted_bins, np.logical_not(nvoiced))

axis.plot(nbins_masked, 'r--', linewidth=2)
axis.plot(npredicted_bins_masked, 'b:', linewidth=2)

# Plot pitch posteriorgram
# if len(distributions.shape) == 4:
# axis.imshow(new_distributions, extent=[0,100,0,1], aspect=80, origin='lower')
# else:
# axis.imshow(new_distributions, aspect='auto', origin='lower')
if predicted_bins is not None:
npredicted_bins = predicted_bins.detach().cpu().numpy()
npredicted_bins = npredicted_bins.squeeze().T

npredicted_bins += offset
npredicted_bins_masked = np.ma.MaskedArray(npredicted_bins, np.logical_not(nvoiced))

axis.plot(npredicted_bins_masked, 'b:', linewidth=2)

axis.imshow(distributions, aspect='auto', origin='lower')

return figure
Expand Down Expand Up @@ -244,8 +248,6 @@ def from_file_to_file(audio_file=None, output_file=None, checkpoint=None, gpu=No
else:
figure = from_testset(checkpoint, gpu)

breakpoint()

# Save to disk
if output_file is not None:
figure.savefig(output_file, bbox_inches='tight', pad_inches=0, dpi=900)
17 changes: 15 additions & 2 deletions penn/train/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def evaluate(directory, step, model, gpu, condition, loader, log_wandb):
# Forward pass
logits = model(audio.to(device))

if len(logits.shape) == 4:
if len(logits.shape) == 4 or penn.LOSS_MULTI_HOT:
binsT = bins.permute(*torch.arange(bins.ndim - 1, -1, -1))
pitchT = pitch.permute(*torch.arange(pitch.ndim - 1, -1, -1))
voicedT = voiced.permute(*torch.arange(voiced.ndim - 1, -1, -1))
Expand Down Expand Up @@ -302,7 +302,7 @@ def evaluate(directory, step, model, gpu, condition, loader, log_wandb):
# Write to tensorboard
torchutil.tensorboard.update(directory, step, scalars=scalars)

return scalars[f'accuracy/{condition}']
return scalars[f'loss/{condition}']


###############################################################################
Expand Down Expand Up @@ -362,9 +362,22 @@ def get_bins(bins):
bins_chunks = bins.chunk(penn.PITCH_CATS, dim=1)
bins_chunks = [get_bins(chunk) for chunk in bins_chunks]
bins = torch.stack(bins_chunks)

elif penn.LOSS_MULTI_HOT:
bins_chunks = bins.chunk(penn.PITCH_CATS, dim=1)
bins_chunks = [get_bins(chunk) for chunk in bins_chunks]
bins = torch.stack(bins_chunks)

# combine all one-hot vectors into a single multi-hot vector
bins = torch.sum(bins, dim=0)

# it may happen that two strings are playing the same note, in that case interpret it as a single note
bins[bins > 1] = 1

else:
bins = get_bins(bins)


if penn.LOSS == 'binary_cross_entropy':

# Compute binary cross-entropy loss
Expand Down

0 comments on commit a787645

Please sign in to comment.