Commit a6e2f30 (parent: a7cdb9e) to a repository generated from intsystems/ProjectTemplate — 9 changed files with 653 additions and 389 deletions. One file was renamed without changes, and large diffs are not rendered by default.
New file (+47 lines) — src/mylib/models/eeg_encoders.py (path per the imports in the model file below):

import torch.nn as nn


class BaselineEEGEncoder(nn.Module):
    """Encoder for EEG."""

    def __init__(self, in_channels=8, dilation_filters=16, kernel_size=3, layers=3):
        super(BaselineEEGEncoder, self).__init__()

        self.eeg_convos = nn.Sequential()

        for layer_index in range(layers):
            # The first layer maps the input channels; later layers map dilation_filters -> dilation_filters.
            self.eeg_convos.add_module(
                f"conv1d_lay{layer_index}",
                nn.Conv1d(
                    in_channels=dilation_filters * (layer_index != 0) + (layer_index == 0) * in_channels,
                    out_channels=dilation_filters,
                    kernel_size=kernel_size,
                    dilation=kernel_size ** layer_index,
                    bias=True))
            self.eeg_convos.add_module(f"relu_lay{layer_index}", nn.ReLU())

    def forward(self, eeg):
        return self.eeg_convos(eeg)


class MultiheadAttentionEEGEncoder(nn.Module):
    """EEG encoder using a transformer block."""

    def __init__(self, embed_dim, ff_dim):
        super(MultiheadAttentionEEGEncoder, self).__init__()

        self.mha_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=2)
        self.ffn = nn.Sequential(nn.Linear(embed_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, embed_dim))
        self.layer_norm1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.layer_norm2 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.dropout1 = nn.Dropout(p=0.5)
        self.dropout2 = nn.Dropout(p=0.5)

    def forward(self, x):
        # Pre-softmax self-attention followed by a feed-forward block, each with residual + layer norm.
        attn_output, _ = self.mha_attention(x, x, x)
        attn_output = self.dropout1(attn_output)
        out1 = self.layer_norm1(attn_output + x)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output)
        out = self.layer_norm2(out1 + ffn_output)
        return out
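For orientation, a minimal shape-check sketch (not part of the commit; the batch size, channel counts, and 320-sample window are illustrative assumptions):

import torch

from src.mylib.models.eeg_encoders import BaselineEEGEncoder, MultiheadAttentionEEGEncoder

eeg = torch.randn(4, 8, 320)   # [batch, spatial filters, time]
print(BaselineEEGEncoder(in_channels=8)(eeg).shape)                     # torch.Size([4, 16, 294]); each dilated conv trims the time axis
x = torch.randn(4, 320, 64)    # [batch, time, embed_dim]
print(MultiheadAttentionEEGEncoder(embed_dim=64, ff_dim=32)(x).shape)   # torch.Size([4, 320, 64])

One caveat: nn.MultiheadAttention defaults to batch_first=False and expects (seq, batch, embed), so with a (batch, time, embed) input the attention is computed along the first axis; if per-time-step attention over the window is intended, batch_first=True (or transposed inputs) may be needed.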
New file (+88 lines) — the model classes that combine the EEG and stimulus encoders:

import torch
import torch.nn as nn
from src.mylib.models.eeg_encoders import BaselineEEGEncoder, MultiheadAttentionEEGEncoder
from src.mylib.models.stimulus_encoders import BaselineStimulusEncoder


class BaselineModel(nn.Module):
    """Baseline model."""

    def __init__(self,
                 layers=3,
                 kernel_size=3,
                 spatial_filters=8,
                 dilation_filters=16):
        super(BaselineModel, self).__init__()

        # EEG spatial transformation: 1x1 convolution from 64 EEG channels to spatial_filters.
        self.spatial_transformation = nn.Conv1d(
            in_channels=64,
            out_channels=spatial_filters,
            kernel_size=1,
            bias=True
        )

        args = {"dilation_filters": dilation_filters, "kernel_size": kernel_size, "layers": layers}

        # EEG encoder
        self.eeg_encoder = BaselineEEGEncoder(in_channels=spatial_filters, **args)

        # Stimulus encoder
        self.stimulus_encoder = BaselineStimulusEncoder(**args)

        self.fc = nn.Linear(in_features=dilation_filters * dilation_filters,
                            out_features=1,
                            bias=True)

    def forward(self, eeg, stimuli):
        eeg = self.spatial_transformation(eeg)
        eeg = self.eeg_encoder(eeg)

        # Shared weights for all stimulus candidates (matched and mismatched).
        for i in range(len(stimuli)):
            stimuli[i] = self.stimulus_encoder(stimuli[i])

        # Similarity between the EEG embedding and each stimulus embedding, projected to one logit per candidate.
        cosine_sim = []
        for stimulus in stimuli:
            cosine_sim.append(eeg @ stimulus.transpose(-1, -2))
        sim_projections = [self.fc(torch.flatten(sim, start_dim=1)) for sim in cosine_sim]
        return torch.cat(sim_projections, dim=1)


class MHAModel(nn.Module):
    """Model with a transformer block as the spatial transformation."""

    def __init__(self,
                 layers=3,
                 kernel_size=3,
                 dilation_filters=16):
        super(MHAModel, self).__init__()

        # EEG spatial transformation
        self.spatial_transformation = MultiheadAttentionEEGEncoder(embed_dim=64, ff_dim=32)

        args = {"dilation_filters": dilation_filters, "kernel_size": kernel_size, "layers": layers}

        # EEG encoder
        self.eeg_encoder = BaselineEEGEncoder(in_channels=64, **args)

        # Stimulus encoder
        self.stimulus_encoder = BaselineStimulusEncoder(**args)

        self.fc = nn.Linear(in_features=dilation_filters * dilation_filters,
                            out_features=1,
                            bias=True)

    def forward(self, eeg, stimuli):
        # The attention block expects [batch, time, channels]; the convolutions expect [batch, channels, time].
        eeg = self.spatial_transformation(eeg.transpose(1, 2))
        eeg = self.eeg_encoder(eeg.transpose(1, 2))

        # Shared weights for all stimulus candidates.
        for i in range(len(stimuli)):
            stimuli[i] = self.stimulus_encoder(stimuli[i])

        cosine_sim = []
        for stimulus in stimuli:
            cosine_sim.append(eeg @ stimulus.transpose(-1, -2))
        sim_projections = [self.fc(torch.flatten(sim, start_dim=1)) for sim in cosine_sim]
        return torch.cat(sim_projections, dim=1)
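A minimal forward-pass sketch for the baseline (not in the commit; batch size 4, a 320-sample window and 4 mismatched envelopes are illustrative assumptions, and the import path of BaselineModel is not shown in the diff):

import torch

from src.mylib.models.models import BaselineModel   # module path assumed

model = BaselineModel()                                 # 64-ch EEG -> 8 spatial filters -> 16 dilation filters
eeg = torch.randn(4, 64, 320)                           # [batch, EEG channels, time]
stimuli = [torch.randn(4, 1, 320) for _ in range(5)]    # true envelope + 4 imposters, encoder weights shared
logits = model(eeg, stimuli)                            # [4, 5]: one similarity logit per candidate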
New file (+22 lines) — src/mylib/models/stimulus_encoders.py (path per the import in the model file above):

import torch.nn as nn


class BaselineStimulusEncoder(nn.Module):
    """Stimulus encoder from the baseline solution."""

    def __init__(self, dilation_filters=16, kernel_size=3, layers=3):
        super(BaselineStimulusEncoder, self).__init__()

        self.env_convos = nn.Sequential()
        for layer_index in range(layers):
            # The first layer takes the single-channel envelope; later layers take dilation_filters channels.
            self.env_convos.add_module(
                f"conv1d_lay{layer_index}",
                nn.Conv1d(
                    in_channels=dilation_filters * (layer_index != 0) + (layer_index == 0),
                    out_channels=dilation_filters,
                    kernel_size=kernel_size,
                    dilation=kernel_size ** layer_index,
                    bias=True))
            self.env_convos.add_module(f"relu_lay{layer_index}", nn.ReLU())

    def forward(self, stimulus):
        return self.env_convos(stimulus)
Changed file (132 → 150 lines) — mylib.train. The template code (SyntheticBernuliDataset, the sklearn-based Trainer, and cv_parameters) is removed and replaced by a PyTorch Trainer:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os

import torch
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

from src.mylib.utils.data import TaskDataset
from sklearn.metrics import classification_report


class Trainer(object):
    r"""Base class for all trainers."""

    def __init__(self, model, train_files, val_files, test_files, args, optimizer, loss_fn):
        r"""Constructor method

        :param train_files: paths to train files
        :type train_files: list
        :param val_files: paths to validation files
        :type val_files: list
        :param test_files: paths to test files
        :type test_files: list
        """
        self.model = model
        self.args = args
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.test_files = test_files
        self.initialize_dataloaders(train_files, val_files, test_files)

    def initialize_dataloaders(self, train_files, val_files, test_files):
        r"""Initialize dataloaders."""

        conf = {"window_length": self.args["window_length"], "hop_length": self.args["hop_length"],
                "number_of_mismatch": self.args["number_of_mismatch"], "max_files": self.args["max_files"]}
        self.train_dataloader = torch.utils.data.DataLoader(TaskDataset(train_files, **conf),
                                                            batch_size=self.args["batch_size"])
        self.val_dataloader = torch.utils.data.DataLoader(TaskDataset(val_files, **conf),
                                                          batch_size=self.args["batch_size"])
        self.test_dataloader = torch.utils.data.DataLoader(TaskDataset(test_files, **conf),
                                                           batch_size=1)

    def train_one_epoch(self, epoch_index, writer):
        r"""Train one epoch."""

        running_loss = 0
        last_loss = 0

        for i, data in enumerate(self.train_dataloader):
            inputs, labels = data

            self.optimizer.zero_grad()
            outputs = self.model(inputs[0], inputs[1:])

            # TODO: CLASSIFICATION METRIC DURING TRAINING
            # probs = (torch.nn.functional.softmax(outputs.data, dim=1) >= 0.5)
            # _, predicted = torch.max(probs.data, 1)

            loss = self.loss_fn(outputs, labels)
            loss.backward()

            self.optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                # Report the running loss every 100 batches.
                last_loss = running_loss / 100
                print('  batch {} loss: {}'.format(i + 1, last_loss))
                x = epoch_index * len(self.train_dataloader) + i + 1
                writer.add_scalar('Loss/train', last_loss, x)
                running_loss = 0

        return last_loss

    def train_model(self, epochs, run_name):
        r"""Train the model for the given number of epochs."""

        writer = SummaryWriter(f"runs/{run_name}_{self.model.__class__.__name__}")

        best_vloss = 1_000_000
        if not os.path.isdir("saved_models"):
            os.makedirs("saved_models")

        for epoch in range(epochs):
            print(f"EPOCH {epoch + 1}:")
            self.model.train()
            avg_loss = self.train_one_epoch(epoch + 1, writer)

            running_vloss = 0.0
            self.model.eval()
            with torch.no_grad():
                for i, vdata in enumerate(self.val_dataloader):
                    vinputs, vlabels = vdata
                    voutputs = self.model(vinputs[0], vinputs[1:])
                    vloss = self.loss_fn(voutputs, vlabels)
                    running_vloss += vloss.item()

            avg_vloss = running_vloss / (i + 1)
            print("LOSS train {} valid {}".format(avg_loss, avg_vloss))

            writer.add_scalars("Training vs. Validation Loss",
                               {"Training": avg_loss, "Validation": avg_vloss},
                               epoch + 1)
            writer.flush()

            # Checkpoint whenever the validation loss improves.
            if avg_vloss < best_vloss:
                best_vloss = avg_vloss
                model_path = f"saved_models/{self.model.__class__.__name__}_{epoch}"
                torch.save(self.model.state_dict(), model_path)

    def eval(self):
        r"""Evaluate model for the initial validation dataset."""
        pass

    def test(self):
        r"""Evaluate model for the test dataset."""

        total = 0
        self.model.eval()
        y_pred = []
        y_true = []
        subjects = list(set([os.path.basename(x).split("_-_")[1] for x in self.test_files]))
        loss_fn = nn.functional.cross_entropy
        with torch.no_grad():
            # Per-subject loss and accuracy (not yet reported).
            for sub in subjects:
                sub_test_files = [f for f in self.test_files if sub in os.path.basename(f)]
                test_dataloader = torch.utils.data.DataLoader(
                    TaskDataset(sub_test_files,
                                self.args["window_length"],
                                self.args["hop_length"],
                                self.args["number_of_mismatch"]))
                loss = 0
                correct = 0
                for inputs, label in test_dataloader:
                    outputs = self.model(inputs[0], inputs[1:])

                    loss += loss_fn(outputs, label).item()
                    probs = (torch.nn.functional.softmax(outputs.data, dim=1) >= 0.5)
                    _, predicted = torch.max(probs.data, 1)

            # Overall classification report over the full test dataloader.
            for data in self.test_dataloader:
                inputs, labels = data

                outputs = self.model(inputs[0], inputs[1:])
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)

                y_pred.extend(predicted.tolist())
                y_true.extend(labels.tolist())

                correct += (predicted == labels).sum().item()

        return classification_report(y_true, y_pred)
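A hedged sketch of how this Trainer might be wired up (not shown in the commit; the args keys follow initialize_dataloaders, while the model, optimizer, file lists, glob patterns, and import paths are illustrative assumptions):

import glob

import torch
import torch.nn as nn

from src.mylib.models.models import BaselineModel   # module path assumed
from src.mylib.train import Trainer                  # module path assumed

args = {"window_length": 320, "hop_length": 320, "number_of_mismatch": 4,
        "max_files": 100, "batch_size": 64}

train_files = glob.glob("split_data/train_-_*.npy")  # naming scheme assumed
val_files = glob.glob("split_data/val_-_*.npy")
test_files = glob.glob("split_data/test_-_*.npy")

model = BaselineModel()
trainer = Trainer(model, train_files, val_files, test_files, args,
                  optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
                  loss_fn=nn.functional.cross_entropy)
trainer.train_model(epochs=10, run_name="baseline")
print(trainer.test())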
Changed file (8 lines) — the dataset configuration JSON: the dataset_folder placeholder is filled in and the test_folder entry is replaced by a stimuli entry.

 {
-  "dataset_folder": "--absolute path to dataset folder--",
+  "dataset_folder": "/home/bukkacha/Desktop/EEGDataset",
   "derivatives_folder": "derivatives",
   "preprocessed_eeg_folder": "preprocessed_eeg",
   "preprocessed_stimuli_folder": "preprocessed_stimuli",
   "split_folder": "split_data",
-  "test_folder": "test_set"
+  "stimuli": "stimuli"
 }
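For context, a minimal sketch of reading this config (the file name config.json and the joining of folders under dataset_folder are assumptions, not shown in the commit):

import json
import os

with open("config.json") as f:   # file name assumed
    config = json.load(f)

split_folder = os.path.join(config["dataset_folder"], config["split_folder"])
stimuli_folder = os.path.join(config["dataset_folder"], config["stimuli"])
print(split_folder, stimuli_folder)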
New file (+71 lines) — src/mylib/utils/data.py (path per the import in mylib.train):

import itertools
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class TaskDataset(Dataset):
    """Generate data for the Match/Mismatch task."""

    def __init__(self, files, window_length, hop_length, number_of_mismatch, max_files=100):
        self.labels = []
        assert number_of_mismatch != 0
        self.window_length = window_length
        self.hop_length = hop_length
        self.number_of_mismatch = number_of_mismatch
        self.files = files
        self.max_files = max_files
        self.group_recordings()
        self.frame_recordings()
        self.create_imposter_segments()
        self.create_labels_randomize_positions()

    def group_recordings(self):
        # Pair the EEG and envelope .npy files that belong to the same recording.
        new_files = []
        grouped = itertools.groupby(sorted(self.files), lambda x: "_-_".join(os.path.basename(x).split("_-_")[:3]))

        for recording_name, feature_paths in grouped:
            sub_recordings = sorted(feature_paths, key=lambda x: "0" if x == "eeg" else x)
            eeg, envelope = np.load(sub_recordings[0]), np.load(sub_recordings[1])  # eeg [L, C], env [L, 1]
            new_files += [[torch.tensor(eeg.T).float(), torch.tensor(envelope.T).float()]]

            if self.max_files is not None and len(new_files) == self.max_files:
                break

        self.files = new_files

    def frame_recordings(self):
        # Cut every recording into windows of window_length samples with step hop_length.
        new_files = []
        for i in range(len(self.files)):
            self.files[i][0] = self.files[i][0].unfold(
                1, self.window_length, self.hop_length).transpose(0, 1)  # [num_of_frames, C, window_length]
            self.files[i][1] = self.files[i][1].unfold(
                1, self.window_length, self.hop_length).transpose(0, 1)  # [num_of_frames, 1, window_length]
            eegs = list(torch.tensor_split(self.files[i][0], self.files[i][0].shape[0], dim=0))
            envs = list(torch.tensor_split(self.files[i][1], self.files[i][1].shape[0], dim=0))
            for eeg, env in zip(eegs, envs):
                new_files.append([eeg.squeeze(), env.squeeze(dim=0)])
        self.files = new_files

    def create_imposter_segments(self):
        # Mismatched candidates are random permutations of the matched envelope.
        for i in range(len(self.files)):
            for _ in range(self.number_of_mismatch):
                t = self.files[i][-1].view(-1)
                t = t[torch.randperm(t.shape[-1])].view(self.files[i][-1].shape)
                self.files[i].append(t)

    def create_labels_randomize_positions(self):
        # Shuffle the candidates and store the position where the true envelope ended up.
        for i in range(len(self.files)):
            idx_permutation = torch.randperm(self.number_of_mismatch + 1) + 1
            permuted = []
            for idx in idx_permutation:
                permuted.append(self.files[i][idx])
            self.files[i][1:] = permuted
            self.labels.append(torch.argmax((idx_permutation == 1).long()))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return self.files[idx], self.labels[idx]
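A hedged usage sketch (not part of the commit): the glob pattern and parameter values are illustrative, and the .npy naming is assumed to follow the "_-_"-separated scheme that group_recordings and Trainer.test expect.

import glob

import torch

from src.mylib.utils.data import TaskDataset

files = glob.glob("split_data/train_-_sub-001_-_audiobook_-_*.npy")   # e.g. ..._-_eeg.npy and ..._-_envelope.npy
dataset = TaskDataset(files, window_length=320, hop_length=320, number_of_mismatch=4)
loader = torch.utils.data.DataLoader(dataset, batch_size=64)

inputs, label = dataset[0]
# inputs[0]: EEG frame [C, 320]; inputs[1:]: 5 envelope candidates [1, 320]; label: index of the true envelope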