-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fad7ea7
commit 0c6a9ba
Showing
1 changed file
with
124 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"colab": { | ||
"provenance": [] | ||
}, | ||
"kernelspec": { | ||
"name": "python3", | ||
"display_name": "Python 3" | ||
}, | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"%cd libs" | ||
], | ||
"metadata": { | ||
"colab": { | ||
"base_uri": "https://localhost:8080/" | ||
}, | ||
"id": "N9JGQR8jrHCZ", | ||
"outputId": "7838d1d1-9e78-4fe8-f6ac-f7f72f8b6555" | ||
}, | ||
"execution_count": 1, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"/content/libs\n" | ||
] | ||
} | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"source": [ | ||
"from torch.utils.data import Dataset, DataLoader, random_split\n", | ||
"import torch\n", | ||
"\n", | ||
"class RandomDataset(Dataset):\n", | ||
" def __init__(self):\n", | ||
" self.x = 10\n", | ||
" def __len__(self):\n", | ||
" return 100\n", | ||
" def __getitem__(self, idx):\n", | ||
" return torch.randn(128, 512), 1" | ||
], | ||
"metadata": { | ||
"id": "bxTHPqrTQ9jG" | ||
}, | ||
"execution_count": 2, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": { | ||
"id": "jZjGE-n6-6yw", | ||
"colab": { | ||
"base_uri": "https://localhost:8080/", | ||
"height": 567 | ||
}, | ||
"outputId": "327ad6cf-e91b-4aec-b8f9-85fa76c1a554" | ||
}, | ||
"outputs": [ | ||
{ | ||
"output_type": "stream", | ||
"name": "stderr", | ||
"text": [ | ||
"/usr/local/lib/python3.10/dist-packages/torch/nn/utils/weight_norm.py:134: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.\n", | ||
" WeightNorm.apply(module, name, dim)\n" | ||
] | ||
}, | ||
{ | ||
"output_type": "stream", | ||
"name": "stdout", | ||
"text": [ | ||
"Training on cpu\n", | ||
"Epoch 0, Iteration 0, loss = 2.2128\n", | ||
"Epoch 0, Iteration 1, loss = 0.8272\n", | ||
"Epoch 0, Iteration 2, loss = 12.6069\n", | ||
"Epoch 0, Iteration 3, loss = 9.4549\n", | ||
"Epoch 0, Iteration 4, loss = 11.6616\n", | ||
"Epoch 0, Iteration 5, loss = 1.4846\n", | ||
"Epoch 0, Iteration 6, loss = 6.6838\n", | ||
"Epoch 0, Iteration 7, loss = 0.0001\n", | ||
"Epoch 0, Iteration 8, loss = 0.0000\n", | ||
"Epoch 0, Iteration 9, loss = 34.3141\n" | ||
] | ||
}, | ||
{ | ||
"output_type": "error", | ||
"ename": "KeyboardInterrupt", | ||
"evalue": "", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | ||
"\u001b[0;32m<ipython-input-3-3a2d01619373>\u001b[0m in \u001b[0;36m<cell line: 7>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mtask\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mCPC\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mWav2Vec2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mtask\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'num_epochs'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'batch_size'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'print_every'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | ||
"\u001b[0;32m/content/libs/ssl_utils.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, model, train_params)\u001b[0m\n\u001b[1;32m 619\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mct\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfuture_embeddings\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnegative_embeddings\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 620\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 621\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 622\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 623\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 519\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 520\u001b[0m )\n\u001b[0;32m--> 521\u001b[0;31m torch.autograd.backward(\n\u001b[0m\u001b[1;32m 522\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 523\u001b[0m )\n", | ||
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 287\u001b[0m \u001b[0;31m# some Python versions print out the first line of a multi-line function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0;31m# calls in the traceback and some print out the last line\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 289\u001b[0;31m _engine_run_backward(\n\u001b[0m\u001b[1;32m 290\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | ||
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py\u001b[0m in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0munregister_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_register_logging_hooks_on_whole_graph\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt_outputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 767\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 768\u001b[0;31m return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n\u001b[0m\u001b[1;32m 769\u001b[0m \u001b[0mt_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 770\u001b[0m ) # Calls into the C++ engine to run the backward pass\n", | ||
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from ssl_model import Wav2Vec2\n", | ||
"from ssl_utils import MaskedContrastiveLearningTask, RelativePositioningTask, TemporalShufflingTask, CPC\n", | ||
"\n", | ||
"data = RandomDataset()\n", | ||
"task = CPC(data)\n", | ||
"model = Wav2Vec2([10, 3, 3, 3, 3, 2, 2], [2, 1, 1, 1, 1, 1, 1])\n", | ||
"task.train(model, {'num_epochs': 10, 'batch_size': 1, 'print_every': 1})" | ||
] | ||
} | ||
] | ||
} |