Add a lot of newlines to stop cells merging together
cmalinmayor committed Aug 18, 2024
1 parent c201d2a commit e75177f
Showing 1 changed file with 28 additions and 4 deletions.
32 changes: 28 additions & 4 deletions solution.py
@@ -106,7 +106,7 @@

# <div class="alert alert-info"><h4>
# Task 1.1: </h4>
- # We have locally changed images of 7s artificially for this exercise. What are some examples of ways that images can be corrupted or tainted during real-life data colleciton, for example in a hospital imaging environment or microscopy lab?
+ # We have locally changed images of 7s artificially for this exercise. What are some examples of ways that images can be corrupted or tainted during real-life data collection, for example in a hospital imaging environment or microscopy lab?
# </div>

# + [markdown] tags=["solution"]
@@ -237,7 +237,10 @@
# Prevention is easier than fixing after generation!
# - PCA on metadata <3 to help detect such issues
# - Randomization of data generation (blind yourself to your samples, don't always put certain classes in certain wells, etc.)
#
# -

#
# <div class="alert alert-info"><h4>
# Task 1.4:</h4>
# Given the changes we made to generate the tainted dataset, do you think a digit classification network trained on the tainted data will converge? Are the classes more or less distinct from each other than in the untainted dataset?
@@ -252,13 +255,17 @@
# **1.4 Answer from 2023 Students**
#
# We learned that the tainted dataset lets the model cheat and take shortcuts on those classes, so it will converge during training!
#
# -

#
# <div class="alert alert-success"><h3>
# Checkpoint 1</h3>
#
# Post to the course chat when you have reached Checkpoint 1. We will discuss all the questions and make more predictions!
# </div>

#
# <div class="alert alert-block alert-warning"><h3>
# Bonus Questions:</h3>
# Note that we only added the white dot to the images of 7s and the grid to images of 4s, not all classes.
@@ -286,7 +293,7 @@
# Now we will train the neural network. A training function is provided below - this should be familiar, but make sure you look it over and understand what is happening in the training loop.

# +
- from tqdm import tqdm
+ from tqdm.auto import tqdm

# Training function:
def train_mnist(model, train_loader, batch_size, criterion, optimizer, history):
@@ -303,6 +310,8 @@ def train_mnist(model, train_loader, batch_size, criterion, optimizer, history):
history.append(loss.item())
pbar.update(1)
return history


# -

# We have to choose hyperparameters for our model. We have chosen to train for two epochs, with a batch size of 64 for training and 1000 for testing. We are using the cross entropy loss, a standard multi-class classification loss.
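As a rough illustration of the loss named above, here is a minimal, standard-library-only sketch of cross entropy for a single sample. The constant names (`N_EPOCHS`, `BATCH_SIZE_TRAIN`, `BATCH_SIZE_TEST`) and the helper `cross_entropy` are hypothetical and are not taken from solution.py; the exercise itself presumably relies on the PyTorch loss, which takes raw logits and averages over the batch by default.

```python
import math

# Hypothetical names for the hyperparameters described in the text.
N_EPOCHS = 2
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 1000


def cross_entropy(logits, target_index):
    """Cross entropy for one sample: -log(softmax(logits)[target_index])."""
    # Subtract the largest logit before exponentiating for numerical stability.
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(z - shift) for z in logits))
    return log_sum_exp - logits[target_index]


# A network that is undecided between 10 digit classes pays log(10).
uniform_loss = cross_entropy([0.0] * 10, target_index=7)

# A network that puts a large logit on the correct class pays much less.
confident_loss = cross_entropy([0.0] * 7 + [5.0] + [0.0] * 2, target_index=7)
```

The loss only depends on how much probability the model assigns to the true class, which is why a shortcut feature that reliably identifies a class can drive it down quickly.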
@@ -481,6 +490,8 @@ def predict(model, dataset):
dataset_groundtruth.append(y_true)

return np.array(dataset_prediction), np.array(dataset_groundtruth)


# -

# Now we call the predict method with the clean and tainted models on the clean and tainted datasets.
@@ -669,6 +680,8 @@ def apply_integrated_gradients(test_input, model):
)

return attributions


# -
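To make the idea behind integrated gradients concrete, here is a toy sketch in plain Python. It is not the implementation used in `apply_integrated_gradients` (only the tail of that function is visible here); `toy_model`, `toy_gradient`, `WEIGHTS`, and `integrated_gradients` are hypothetical names, and the model is a simple weighted sum of squares whose gradient is known in closed form. The sketch averages gradients along the straight path from a baseline to the input and scales by the input difference.

```python
WEIGHTS = [1.0, -2.0, 0.5]  # hypothetical toy parameters


def toy_model(x):
    """Toy differentiable model: f(x) = sum_i w_i * x_i^2."""
    return sum(w * xi ** 2 for w, xi in zip(WEIGHTS, x))


def toy_gradient(x):
    """Analytic gradient of toy_model with respect to each input."""
    return [2.0 * w * xi for w, xi in zip(WEIGHTS, x)]


def integrated_gradients(grad_fn, x, baseline, n_steps=100):
    """Approximate the path integral of gradients with a midpoint Riemann sum."""
    diffs = [xi - bi for xi, bi in zip(x, baseline)]
    totals = [0.0] * len(x)
    for step in range(n_steps):
        alpha = (step + 0.5) / n_steps  # midpoint of each sub-interval
        point = [bi + alpha * di for bi, di in zip(baseline, diffs)]
        grads = grad_fn(point)
        totals = [t + g for t, g in zip(totals, grads)]
    # Average gradient along the path, scaled by (input - baseline).
    return [di * t / n_steps for di, t in zip(diffs, totals)]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(toy_gradient, x, baseline)

# Completeness: attributions should sum to f(x) - f(baseline).
completeness_gap = sum(attributions) - (toy_model(x) - toy_model(baseline))
```

The completeness check at the end is a useful sanity test in general: if the attributions do not sum to roughly the difference in model output, the number of integration steps is probably too small.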

# Next we provide a function to visualize the output of integrated gradients, using the function above to actually run the algorithm.
@@ -705,6 +718,8 @@ def visualize_integrated_gradients(test_input, model, plot_title):
use_pyplot=False)
figure.suptitle(plot_title, y=0.95)
plt.tight_layout()


# -

# To start examining the results, we will call `visualize_integrated_gradients` with the tainted and clean models on the tainted and clean sevens.
@@ -831,6 +846,8 @@ def visualize_integrated_gradients(test_input, model, plot_title):
# A simple function to add noise to tensors:
def add_noise(tensor, power=1.5):
    return tensor * torch.rand(tensor.size()).to(tensor.device) ** power + 0.75*torch.randn(tensor.size()).to(tensor.device)


# -
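The noise model in `add_noise` combines two effects: each pixel is multiplied by a uniform random factor in [0, 1) raised to `power`, and then zero-mean Gaussian noise with standard deviation 0.75 is added. The following standard-library sketch mirrors that formula on a flat list of pixel values; `add_noise_pixels` and its `rng` parameter are hypothetical and only meant to make the arithmetic explicit, not to replace the tensor version.

```python
import random


def add_noise_pixels(pixels, power=1.5, rng=random):
    """Apply multiplicative uniform noise and additive Gaussian noise per pixel."""
    noisy = []
    for p in pixels:
        scale = rng.random() ** power          # uniform in [0, 1), raised to `power`
        offset = 0.75 * rng.gauss(0.0, 1.0)    # Gaussian noise, standard deviation 0.75
        noisy.append(p * scale + offset)
    return noisy


clean = [0.0, 0.25, 0.5, 1.0]

# Seeded generators make the sketch reproducible.
noisy_a = add_noise_pixels(clean, rng=random.Random(0))
noisy_b = add_noise_pixels(clean, rng=random.Random(0))
```

Because the additive term is independent of the pixel value, even fully black pixels end up noisy, while bright pixels are additionally dimmed by the multiplicative factor.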

# Next we will visualize a couple of MNIST examples with and without noise.
@@ -906,6 +923,8 @@ def train_denoising_model(train_loader, model, criterion, optimizer, history):
# updates progress bar:
pbar.update(1)
return history


# -

# Here we choose hyperparameters and initialize the model and data loaders.
@@ -1120,7 +1139,7 @@ def visualize_denoising(model, dataset, index):
# **5.4 Answer:**
#
# The new denoiser has been trained on both MNIST and FashionMNIST, and as a result, it no longer insists on reshaping objects from the FashionMNIST dataset into digits. However, it seems to be performing slightly worse on the original MNIST (some of the digits are hardly recognisable).

#
# ### Train the denoiser on both MNIST and FashionMNIST, shuffling the training data
#
# We previously performed the training sequentially, on the MNIST data first, followed by the FashionMNIST data. Now, we ask for the training data to be shuffled and observe the impact on performance. (Note the `shuffle=True` in the lines below.)
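The difference between sequential and shuffled training comes down to the order in which samples are presented. The sketch below uses plain Python lists as stand-ins for the two datasets; `epoch_order` and the tuple-based samples are hypothetical and only roughly mimic what a data loader does when shuffling over a concatenated dataset.

```python
import random

# Stand-in samples: (source dataset, index). Real samples would be images.
mnist = [("mnist", i) for i in range(5)]
fashion = [("fashion", i) for i in range(5)]
combined = mnist + fashion


def epoch_order(samples, shuffle, seed=0):
    """Return the order in which samples would be visited in one epoch."""
    order = list(samples)
    if shuffle:
        # Seeded shuffle so the sketch is reproducible.
        random.Random(seed).shuffle(order)
    return order


# Without shuffling, every MNIST sample comes before any FashionMNIST sample.
sequential = epoch_order(combined, shuffle=False)

# With shuffling, the same samples are visited in a mixed order.
shuffled = epoch_order(combined, shuffle=True)
```

In the sequential order the last updates of training only ever see FashionMNIST, which is consistent with the degraded MNIST performance noted in the 5.4 answer; mixing the two sources within every batch avoids that.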
@@ -1171,18 +1190,23 @@ def visualize_denoising(model, dataset, index):
# **5.5 Answer:**
#
# The denoiser trained on shuffled data performs well across both MNIST and FashionMNIST, without having any particular issue with either of the two datasets.
#
# -

#
# <div class="alert alert-block alert-success"><h3>
# Checkpoint 5</h3>
# <ol>
# Congrats on reaching the final checkpoint! Let us know on Element, and we'll discuss the questions once we reach critical mass.
# </ol>
# </div>

#
# <div class="alert alert-block alert-warning"><h3>
# Bonus Questions</h3>
# <ol>
# <li>Try training a FashionMNIST denoising network and applying it to MNIST. Or, try training a denoising network on both datasets and see how it works on each.</li>
# <li>Go back to Part 4 and try another attribution method, such as <a href="https://captum.ai/api/saliency.html">Saliency</a>, and see how the results differ.</li>
# </ol>
# </div>

#
