From 48a8eacf2c0cdcacac1e22e69e60136b892d5cb8 Mon Sep 17 00:00:00 2001 From: accesstbilq <157091400+accesstbilq@users.noreply.github.com> Date: Fri, 27 Dec 2024 18:04:16 +0530 Subject: [PATCH] Update hparams.py Update hparams.py --- jukebox/hparams.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/jukebox/hparams.py b/jukebox/hparams.py index eb74584aa1..3870509d18 100644 --- a/jukebox/hparams.py +++ b/jukebox/hparams.py @@ -207,6 +207,13 @@ def setup_hparams(hparam_set_names, kwargs): ) HPARAMS_REGISTRY["small_vqvae"] = small_vqvae +custom_vqvae = Hyperparams( + restore_vqvae="https://genxx.s3.us-east-1.amazonaws.com/small_vqvae/checkpoint_step_200001.pth.tar", +) +custom_vqvae.update(small_vqvae) +HPARAMS_REGISTRY["custom_vqvae"] = custom_vqvae + + small_prior = Hyperparams( n_ctx=8192, prior_width=1024, @@ -219,6 +226,18 @@ def setup_hparams(hparam_set_names, kwargs): ) HPARAMS_REGISTRY["small_prior"] = small_prior +custom_prior = Hyperparams( + restore_prior="https://genxx.s3.us-east-1.amazonaws.com/small_prior/checkpoint_latest.pth.tar", + level=2, + labels=False, + alignment_layer=None, + alignment_head=None, +) +custom_prior.update(small_prior) +HPARAMS_REGISTRY["custom_prior"] = custom_prior + + + small_labelled_prior = Hyperparams( labels=True, labels_v3=True, @@ -231,6 +250,8 @@ def setup_hparams(hparam_set_names, kwargs): small_labelled_prior.update(small_prior) HPARAMS_REGISTRY["small_labelled_prior"] = small_labelled_prior + + small_single_enc_dec_prior = Hyperparams( n_ctx=6144, prior_width=1024, @@ -303,6 +324,15 @@ def setup_hparams(hparam_set_names, kwargs): HPARAMS_REGISTRY["small_upsampler"] = small_upsampler + +custom_upsampler = Hyperparams( + restore_prior="https://genxx.s3.us-east-1.amazonaws.com/small_upsampler/checkpoint_latest.pth.tar", + level=0, + labels=False, +) +custom_upsampler.update(small_upsampler) +HPARAMS_REGISTRY["custom_upsampler"] = custom_upsampler + all_fp16 = Hyperparams( fp16=True, fp16_params=True, @@ -484,7 +514,7 @@ def setup_hparams(hparam_set_names, kwargs): ) DEFAULTS["opt"] = Hyperparams( - epochs=10000, + epochs=50, lr=0.0003, clip=1.0, beta1=0.9,