1 parent cdc6e95, commit 7a992f7
Showing 15 changed files with 1,302 additions and 2 deletions.
@@ -1,2 +1,47 @@
# contrastive-predictive-coding-images
Keras implementation of Representation Learning with Contrastive Predictive Coding for images
### Representation Learning with Contrastive Predictive Coding for images

This repository contains a Keras implementation of contrastive-predictive-coding for **images**, an algorithm fully described in:
* [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/abs/1807.03748).
* [Data-Efficient Image Recognition with Contrastive Predictive Coding](https://arxiv.org/abs/1905.09272).

The goal of unsupervised representation learning is to capture semantic information about the world, recognizing regular patterns in the data without using annotations. The original CPC paper presents a method called Contrastive Predictive Coding (CPC) that can do so across multiple applications. This implementation covers the case of CPC applied to images (vision).

In a nutshell, an input image is divided into a grid of overlapping patches, and each patch is embedded using an encoding network. From these embeddings, the model computes a context vector at each position using masked convolutions, which do not have access to future pixels. These context vectors propagate and integrate spatial information. Given a row in the grid of vectors, the model predicts entire rows below it at different offsets (2 rows below, 3 rows below, etc.). These predictions should match the embedded vectors computed at the very beginning. The model is optimized to pick out the correct embedding from among vectors taken from other patches.
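
The patch grid described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the repository's data pipeline: the 64x64 image size comes from the toy dataset and the 16x16 crop size from the classifier docstring, while the stride of 8 pixels (50% overlap) and the helper name `extract_patch_grid` are assumptions, chosen because they yield a 7x7 grid.

```python
import numpy as np


def extract_patch_grid(image, crop_size=16, stride=8):
    """Cut an (H, W, C) image into a grid of overlapping square patches.

    Hypothetical helper: returns an array of shape
    (n_rows, n_cols, crop_size, crop_size, C).
    """
    height, width, _ = image.shape
    n_rows = (height - crop_size) // stride + 1
    n_cols = (width - crop_size) // stride + 1
    rows = []
    for r in range(n_rows):
        cols = []
        for c in range(n_cols):
            top, left = r * stride, c * stride
            cols.append(image[top:top + crop_size, left:left + crop_size])
        rows.append(cols)
    return np.array(rows)


image = np.random.rand(64, 64, 3)    # stand-in for one 64x64 color image
patches = extract_patch_grid(image)  # stride of 8 is an assumption
print(patches.shape)                 # (7, 7, 16, 16, 3)
```

Each of the 49 patches would then be passed through the encoder to obtain one embedding vector per grid position.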

My code is optimized for readability and it is meant to be used as a resource to understand how CPC works. Therefore, I would like to explain a few concepts that I found challenging to understand in the papers.

<p align="center">
<img src="/resources/context.png" alt="CPC algorithm - context" height="150">
</p>

In this figure, the horizontal lines on the left represent the rows of the embedded input image (7 rows of embedding vectors), and the triangles correspond to masked convolutions. We can see how all the orange rows (0 to 4) contribute to the context vector of row 3 (center). Although masked convolutions prevent information from lower rows from flowing into the context, notice how input row 4 still makes its way to the end. This is because patches are extracted with overlap: the patches encoded in row 3 (left) contain pixels from row 4 below (in yellow). For this reason, we should never optimize the CPC model to predict just one row below, since the information would flow directly from the input and prevent the model from learning anything useful.
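
To make the leak concrete, the following sketch lists which pixel rows each patch row covers. It assumes 16-pixel-high patches extracted with a stride of 8 pixels; the stride is an assumption, not a value stated in this README, and `pixel_rows` is a hypothetical helper.

```python
CROP_SIZE = 16  # patch height in pixels (from the classifier docstring)
STRIDE = 8      # assumed step between consecutive patch rows
N_ROWS = 7      # rows in the patch grid


def pixel_rows(patch_row):
    """Return the set of image pixel rows covered by a given patch row."""
    top = patch_row * STRIDE
    return set(range(top, top + CROP_SIZE))


# For each patch row, count the pixel rows shared with the row directly
# below (offset 1) and with the row two positions below (offset 2).
for row in range(N_ROWS - 2):
    shared_next = pixel_rows(row) & pixel_rows(row + 1)
    shared_skip = pixel_rows(row) & pixel_rows(row + 2)
    print(row, len(shared_next), len(shared_skip))
```

Under these assumed numbers, adjacent patch rows share 8 pixel rows, while rows two apart share none, so an offset of 2 is the smallest one for which the target row has no pixels in common with the row that produced the context.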

Once the context vectors are computed, we can proceed with the actual row predictions. We cannot predict the row immediately below, but we can predict the rows further down. Each prediction distance (offset) uses its own prediction network. In total, there are 5 prediction networks.

<p align="center">
<img src="/resources/offsets.png" alt="CPC algorithm - offset" height="300">
</p>

Given a set of context vectors, we apply each prediction network to all rows; however, not all predictions will be used. These predictions are aligned with the embedded input image rows, taking into account the row offset used during the prediction (see the bottom of the figure for a detailed mapping). This operation is performed in ```cpc_model.py > CPCLayer > align_embeddings()```.
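
The alignment and the contrastive objective can be illustrated with a small NumPy sketch. This is not the repository's `align_embeddings()`: the offsets 2 through 6 (five offsets for a 7-row grid), the linear prediction networks, the random stand-in tensors, and the choice of using all other positions in the same image as negatives are assumptions made for illustration.

```python
import numpy as np

N_ROWS, N_COLS, CODE_SIZE = 7, 7, 8
OFFSETS = (2, 3, 4, 5, 6)  # assumed: five offsets, skipping the row directly below

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(N_ROWS, N_COLS, CODE_SIZE))  # stand-in encoder outputs
context = rng.normal(size=(N_ROWS, N_COLS, CODE_SIZE))     # stand-in masked-conv outputs
# One linear prediction network per offset (a plain weight matrix here).
weights = {k: rng.normal(size=(CODE_SIZE, CODE_SIZE)) for k in OFFSETS}


def align(predictions, targets, offset):
    """Pair the prediction made at row r with the embedding at row r + offset."""
    return predictions[:-offset], targets[offset:]


def contrastive_loss(pred, target):
    """Softmax cross-entropy in which each prediction must pick its own target
    among all targets passed in (the remaining ones act as negatives)."""
    pred = pred.reshape(-1, pred.shape[-1])
    target = target.reshape(-1, target.shape[-1])
    logits = pred @ target.T                         # (n, n) similarity scores
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))              # correct match is on the diagonal


for k in OFFSETS:
    predictions = context @ weights[k]               # apply network k to every row
    pred_k, target_k = align(predictions, embeddings, k)
    print(k, pred_k.shape[0], contrastive_loss(pred_k, target_k))
```

With offset `k`, only the first `7 - k` rows of predictions have a matching target row, which is why some predictions go unused.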

To train the CPC algorithm, I have created a toy dataset based on 64x64 color MNIST.

<p align="center">
<img src="/resources/samples.png" alt="CPC algorithm - samples" height="100">
</p>

Disclaimer: this code is provided *as is*. If you encounter a bug, please report it as an issue. Your help will be very welcome!

### Usage

- Execute ```python train_cpc.py``` to train the CPC model.
- Execute ```python train_classifier.py``` to train a classifier on top of the CPC encoder.

### Requisites

- [Anaconda Python 3.5.3](https://www.continuum.io/downloads)
- [Keras 2.0.6](https://keras.io/)
- [Tensorflow 1.4.0](https://www.tensorflow.org/)
- GPU for fast training.
@@ -0,0 +1,66 @@
"""
This module implements the image classifier model that uses a pretrained CPC to extract features.
"""

import keras

from cpc_model import network_encoder, get_custom_objects_cpc


def network_classifier(encoder_path, crop_shape, n_crops, code_size, lr, n_classes):
    """
    Builds a Keras model that makes predictions from image crops using a pretrained CPC encoder to extract features.
    :param encoder_path: path to pretrained CPC encoder model.
    :param crop_shape: size of the image crops/patches (16, 16, 3).
    :param n_crops: resulting number of image crops (for example 7 for a 7x7 grid of crops).
    :param code_size: length of embedding vector.
    :param lr: optimizer's learning rate during training.
    :param n_classes: number of possible predicted classes.
    :return: compiled Keras model.
    """

    if encoder_path is not None:
        print('Reading encoder from disk and freezing weights.', flush=True)
        encoder_model = keras.models.load_model(encoder_path, custom_objects=get_custom_objects_cpc())
        encoder_model.trainable = False
        for layer in encoder_model.layers:
            layer.trainable = False
    else:
        # No pretrained encoder given: build a new, trainable encoder.
        encoder_input = keras.layers.Input(crop_shape)
        encoder_output = network_encoder(encoder_input, code_size)
        encoder_model = keras.models.Model(encoder_input, encoder_output, name='encoder')
    encoder_model.summary()

    # Crops feature extraction: flatten the grid, encode every crop, restore the grid
    x_input = keras.layers.Input((n_crops, n_crops) + crop_shape)
    x = keras.layers.Reshape((n_crops * n_crops, ) + crop_shape)(x_input)
    x = keras.layers.TimeDistributed(encoder_model)(x)
    x = keras.layers.Reshape((n_crops, n_crops, code_size))(x)

    # # Define the classifier
    # x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear')(x)
    # x = LayerNormalization()(x)
    # x = keras.layers.LeakyReLU()(x)
    # x = keras.layers.Conv2D(filters=64, kernel_size=3, strides=1, activation='linear')(x)
    # x = LayerNormalization()(x)
    # x = keras.layers.LeakyReLU()(x)

    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.Dense(units=code_size, activation='linear')(x)
    x = keras.layers.LeakyReLU()(x)
    x = keras.layers.Dense(units=n_classes, activation='softmax')(x)

    # Model
    model = keras.models.Model(inputs=x_input, outputs=x)

    # Compile model
    model.compile(
        optimizer=keras.optimizers.Adam(lr=lr),
        loss='categorical_crossentropy',
        metrics=['categorical_accuracy']
    )
    model.summary()

    return model
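
The reshape, encode, reshape, and pool steps in `network_classifier` can be followed with plain NumPy shapes. This is only a shape walk-through under stated assumptions: `fake_encoder` is a hypothetical stand-in for the CPC encoder, and `code_size = 128` and the batch size of 2 are arbitrary values chosen for illustration.

```python
import numpy as np

n_crops, crop_shape, code_size = 7, (16, 16, 3), 128  # code_size is an assumed value


def fake_encoder(crops):
    """Hypothetical stand-in for the CPC encoder: maps each crop to a code vector."""
    flat = crops.reshape(crops.shape[0], -1)
    projection = np.ones((flat.shape[1], code_size)) / flat.shape[1]
    return flat @ projection


batch = np.random.rand(2, n_crops, n_crops, *crop_shape)  # (2, 7, 7, 16, 16, 3)
x = batch.reshape(2, n_crops * n_crops, *crop_shape)      # (2, 49, 16, 16, 3)
x = np.stack([fake_encoder(sample) for sample in x])      # (2, 49, 128), like TimeDistributed
x = x.reshape(2, n_crops, n_crops, code_size)             # (2, 7, 7, 128)
pooled = x.mean(axis=(1, 2))                              # (2, 128), like GlobalAveragePooling2D
print(pooled.shape)
```

The pooled vector is what the two dense layers in the classifier head operate on.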