Skip to content

Commit

Permalink
Add fashion_mnist notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
kristinagrig06 committed Apr 23, 2021
1 parent ec5683a commit 1064b7a
Show file tree
Hide file tree
Showing 4 changed files with 698 additions and 211 deletions.
105 changes: 105 additions & 0 deletions activeloop/fashion_mnist/ReadMe.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install required libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install hub\n",
"!pip install matplotlib"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import hub"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = hub.load(\"activeloop/fashion_mnist_train\")\n",
"ds_test = hub.load(\"activeloop/fashion_mnist_test\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"img = ds[\"image\", 5].compute()\n",
"plt.imshow(img)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"img = ds_test[\"image\", 5].compute()\n",
"plt.imshow(img)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6.9 64-bit ('venv': venv)",
"language": "python",
"name": "python36964bitvenvvenv2f96fd8b646d44e3b9e0885369a63a85"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
289 changes: 289 additions & 0 deletions activeloop/fashion_mnist/Train.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install required libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"!pip install hub\n",
"!pip install matplotlib\n",
"!pip install torch"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import ConcatDataset\n",
"import torch\n",
"from torch.utils.data import random_split\n",
"import hub\n",
"from hub.compute.generic.ds_transforms import shift_scale_rotate, transpose\n",
"from hub.api.sharded_datasetview import ShardedDatasetView"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = hub.load(\"activeloop/fashion_mnist_train\")\n",
"ds_test = hub.load(\"activeloop/fashion_mnist_test\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"img = ds[\"image\", 5].compute()\n",
"plt.imshow(img)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Augment images and add to the original Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"ds_augmented = transpose(shift_scale_rotate(ds, keys=['image'], rotate_limit=0, shift_limit=0.1), keys=['image'], p=0.2)\n",
"ds_augmented = ds_augmented.store(\"/tmp/fashion_mnist_aug\")\n",
"ds_sharded = ShardedDatasetView([ds, ds_augmented])\n",
"\n",
"@hub.transform(schema=ds_sharded.schema, scheduler=\"threaded\", workers=8)\n",
"def transform_identity(sample):\n",
" return sample\n",
"\n",
"ds = transform_identity(ds_sharded).store('/tmp/fashion_mnist_all')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define a model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.nn import Module, Sequential, Conv2d, BatchNorm2d, ReLU, MaxPool2d, Linear\n",
"class Net(Module): \n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
"\n",
" self.cnn_layers = Sequential(\n",
" Conv2d(1, 4, kernel_size=3, stride=1, padding=1),\n",
" BatchNorm2d(4),\n",
" ReLU(inplace=True),\n",
" MaxPool2d(kernel_size=2, stride=2),\n",
" )\n",
"\n",
" self.linear_layers = Sequential(\n",
" Linear(784, 10)\n",
" )\n",
" \n",
" def forward(self, x):\n",
" x = self.cnn_layers(x)\n",
" x = x.view(x.size(0), -1)\n",
" x = self.linear_layers(x)\n",
" return x\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training and validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def train(trainloader: torch.utils.data.DataLoader, valloader: torch.utils.data.DataLoader, net: nn.Module):\n",
" criterion = torch.nn.CrossEntropyLoss()\n",
" optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)\n",
" for epoch in range(2):\n",
" print(f\"Epoch {epoch}\")\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" X, y = data\n",
" X = X.permute(0, 3, 1, 2)\n",
" optimizer.zero_grad()\n",
" outputs = net(X)\n",
" loss = criterion(outputs, y)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" running_loss += loss.item()\n",
" if not i % 1000:\n",
" print(f\"Loss {loss.item()}\")\n",
" validate(net, valloader)\n",
" print(\"Finished Training\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def validate(net, valloader):\n",
" correct_count, all_count = 0, 0\n",
" for i, data in enumerate(valloader):\n",
" X, y = data\n",
" if len(X.shape) != 4:\n",
" X = torch.unsqueeze(X, 0)\n",
" X = X.permute(0, 3, 1, 2)\n",
" with torch.no_grad():\n",
" outputs = net(X)\n",
" pred_label = outputs.argmax(1)\n",
" correct_count += np.sum(pred_label.numpy() == y.numpy())\n",
" all_count += len(pred_label)\n",
"\n",
" print(\"Number Of Images Tested =\", all_count)\n",
" print(\"\\nModel Accuracy =\", (correct_count/all_count))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Convert to PyTorch, split the data and train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def transform(data):\n",
" img = data['image'] / 255\n",
" label = data['label']\n",
" return img, label"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch_ds = ds.to_pytorch(transform=transform)\n",
"net = Net()\n",
"train_len = int(0.8 * len(torch_ds))\n",
"test_len = len(torch_ds) - train_len\n",
"train_ds, val_ds = random_split(torch_ds, [train_len, test_len])\n",
"train_dataloader = torch.utils.data.DataLoader(\n",
" train_ds,\n",
" batch_size=8,\n",
" shuffle=True,\n",
" num_workers=8\n",
" )\n",
"val_dataloader = torch.utils.data.DataLoader(\n",
" val_ds,\n",
" batch_size=8,\n",
" shuffle=False,\n",
" num_workers=8\n",
" )\n",
"train(train_dataloader, val_dataloader, net)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch_ds_test = ds_test.to_pytorch(transform=transform)\n",
"validate(net, torch_ds_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"torch.save(net, \"/tmp/model_fashion_mnist.pth\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6.9 64-bit ('venv': venv)",
"language": "python",
"name": "python36964bitvenvvenv2f96fd8b646d44e3b9e0885369a63a85"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 1064b7a

Please sign in to comment.