Skip to content

Commit

Permalink
upload code
Browse files Browse the repository at this point in the history
  • Loading branch information
davidtellez committed Mar 14, 2020
1 parent cdc6e95 commit 7a992f7
Show file tree
Hide file tree
Showing 15 changed files with 1,302 additions and 2 deletions.
49 changes: 47 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,47 @@
# contrastive-predictive-coding-images
Keras implementation of Representation Learning with Contrastive Predictive Coding for images
### Representation Learning with Contrastive Predictive Coding for images

This repository contains a Keras implementation of contrastive-predictive-coding for **images**, an algorithm fully described in:
* [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748).
* [Data-Efficient Image Recognition with Contrastive Predictive Coding](https://arxiv.org/abs/1905.09272).

The goal of unsupervised representation learning is to capture semantic information about the world, recognizing regular patterns in the data without using annotations. This paper presents a new method called Contrastive Predictive Coding (CPC) that can do so across multiple applications. This implementation covers the case of CPC applied to images (vision).

In a nutshell, an input image is divided into a grid of overlapping patches, and each patch is embedded using an encoding network. From these embeddings, the model computes a context vector at each position using masked convolutions (do not have access to future pixels). These context vectors propagate and integrate spatial information. Given a row in the grid of vectors, the model predicts entire rows below at different offsets (2 rows below, 3 rows below, etc.). These predictions should match the embedded vectors computed at the very beginning. The model is optimized to find the correct embedding when vectors from other patches are considered.

My code is optimized for readability and it is meant to be used as a resource to understand how CPC works. Therefore, I would like to explain a few concepts that I found challenging to understand in the papers.


<p align="center">
<img src="/resources/context.png" alt="CPC algorithm - context" height="150">
</p>

In this figure, horizontal lines in the left represent embedded input image rows (7 embedding vectors), and triangles correspond to masked convolutions. We can see how all the orange rows (0 to 4) contribute to the context vector of row 3 (center). Although masked convolutions prevent information from lower rows to flow to the context, notice how input row 4 makes its way to the end. This is because patches are extracted with overlapping, that is, patches encoded in row 3 (left) contain pixels from row 4 below (in yellow). For this reason, we should never optimize the CPC model to predict just one row below, since the information would flow directly from the input preventing it from learning anything useful.

Once the context vectors are computed, we can proceed with the actual row predictions. We cannot predict the row below, but we can predict the rest. Depending on how far we would like to predict (offset), we will use different prediction networks. In total, there are 5 prediction networks.


<p align="center">
<img src="/resources/offsets.png" alt="CPC algorithm - offset" height="300">
</p>

Given a set of context vectors, we apply each prediction network to all rows, however, not all predictions will be used. These predictions are aligned with the embedded input image rows taking into account the row offset used during the prediction (see bottom of the figure for a detailed mapping). This operation is performed in ```cpc_model.py > CPCLayer > align_embeddings()```.

To train the CPC algorithm, I have created a toy dataset based on 64x64 color MNIST.

<p align="center">
<img src="/resources/samples.png" alt="CPC algorithm - samples" height="100">
</p>

Disclaimer: this code is provided *as is*, if you encounter a bug please report it as an issue. Your help will be much welcomed!

### Usage

- Execute ```python train_cpc.py``` to train the CPC model.
- Execute ```python train_classifier.py``` to train a classifier on top of the CPC encoder.

### Requisites

- [Anaconda Python 3.5.3](https://www.continuum.io/downloads)
- [Keras 2.0.6](https://keras.io/)
- [Tensorflow 1.4.0](https://www.tensorflow.org/)
- GPU for fast training.
66 changes: 66 additions & 0 deletions classifier_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""
This module implements the image classifier model that uses a pretrained CPC to extract features.
"""

import keras

from cpc_model import network_encoder, get_custom_objects_cpc


def network_classifier(encoder_path, crop_shape, n_crops, code_size, lr, n_classes):
"""
Builds a Keras model that make predictions of image crops using a pretrained CPC encoder to extract features.
:param encoder_path: path to pretrained CPC encoder model.
:param crop_shape: size of the image crops/patches (16, 16, 3).
:param n_crops: resulting number of image crops (for example 7 for a 7x7 grid of crops).
:param code_size: length of embedding vector.
:param lr: optimizer's learning rate during training.
:param n_classes: number of possible predicted classes.
:return: compiled Keras model.
"""

if encoder_path is not None:
print('Reading encoder from disk and freezing weights.', flush=True)
encoder_model = keras.models.load_model(encoder_path, custom_objects=get_custom_objects_cpc())
encoder_model.trainable = False
for l in encoder_model.layers:
l.trainable = False
else:
encoder_input = keras.layers.Input(crop_shape)
encoder_output = network_encoder(encoder_input, code_size)
encoder_model = keras.models.Model(encoder_input, encoder_output, name='encoder')
encoder_model.summary()

# Crops feature extraction
x_input = keras.layers.Input((n_crops, n_crops) + crop_shape)
x = keras.layers.Reshape((n_crops * n_crops, ) + crop_shape)(x_input)
x = keras.layers.TimeDistributed(encoder_model)(x)
x = keras.layers.Reshape((n_crops, n_crops, code_size))(x)

# # Define the classifier
# x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(x) #
# x = LayerNormalization()(x)
# x = keras.layers.LeakyReLU()(x)
# x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=1, activation='linear')(x) #
# x = LayerNormalization()(x)
# x = keras.layers.LeakyReLU()(x)

x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(units=code_size, activation='linear')(x)
x = keras.layers.LeakyReLU()(x)
x = keras.layers.Dense(units=n_classes, activation='softmax')(x)

# Model
model = keras.models.Model(inputs=x_input, outputs=x)

# Compile model
model.compile(
optimizer=keras.optimizers.Adam(lr=lr),
loss='categorical_crossentropy',
metrics=['categorical_accuracy']
)
model.summary()

return model

Loading

0 comments on commit 7a992f7

Please sign in to comment.