-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from galactic-ai/implement-dann
initial implementation of DANN
- Loading branch information
Showing
7 changed files
with
649 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "ece31ab8", | ||
"metadata": {}, | ||
"source": [ | ||
"# Actual Implementation" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "fa05e35a", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from haloflow.dann.data_loader import SimulationDataset\n", | ||
"from haloflow.dann import model as M\n", | ||
"from haloflow.dann import train as T\n", | ||
"from haloflow.dann import evalutate as E\n", | ||
"from haloflow.dann import visualise as V\n", | ||
"\n", | ||
"from haloflow import config as C" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "34c68ded", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"id": "b91cf1a2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Configuration\n", | ||
"config = {\n", | ||
" 'sims': ['TNG50', 'TNG100', 'Eagle100', 'Simba100', 'TNG_ALL'],\n", | ||
" 'obs': 'mags',\n", | ||
" 'dat_dir': C.get_dat_dir(),\n", | ||
" 'input_dim': None, # Will be inferred from data\n", | ||
" 'num_domains': 4,\n", | ||
" 'batch_size': 128,\n", | ||
" 'num_epochs': 100,\n", | ||
" 'lr': 0.001,\n", | ||
" 'alpha': 0.5,\n", | ||
"}" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 27, | ||
"id": "5fc0cda4", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dataset = SimulationDataset(config['sims'], config['obs'], config['dat_dir'])\n", | ||
"train_loader, test_loader = dataset.get_train_test_loaders(\n", | ||
" train_sims=config['sims'][:-1], # First 4 sims for training\n", | ||
" test_sim=config['sims'][-1], # Last sim (TNG_ALL) for testing\n", | ||
" batch_size=config['batch_size']\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 28, | ||
"id": "8a1c0109", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Infer input dimension from data\n", | ||
"sample_X, _, _ = next(iter(train_loader))\n", | ||
"config['input_dim'] = sample_X.shape[1]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 29, | ||
"id": "cd0d8406", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Initialize model\n", | ||
"model = M.DANN(input_dim=config['input_dim'], \n", | ||
" num_domains=config['num_domains'], \n", | ||
" alpha=config['alpha']\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "484a59a2", | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Train\n", | ||
"T.train_dann(\n", | ||
" model, \n", | ||
" train_loader, \n", | ||
" test_loader, \n", | ||
" num_epochs=config['num_epochs'], \n", | ||
" lr=config['lr'], \n", | ||
" device='cuda' if torch.cuda.is_available() else 'cpu'\n", | ||
")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"id": "3725b6a6", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"Evaluating Regression Performance:\n", | ||
"MSE: 0.0955, RMSE: 0.3090, R²: 0.7535\n", | ||
"\n", | ||
"Evaluating Domain Accuracy:\n", | ||
"Domain Accuracy: 0.3846\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"0.38458414554905784" | ||
] | ||
}, | ||
"execution_count": 31, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"# Evaluate\n", | ||
"print(\"\\nEvaluating Regression Performance:\")\n", | ||
"E.evaluate_regression(model, test_loader, 'cpu')\n", | ||
"\n", | ||
"print(\"\\nEvaluating Domain Accuracy:\")\n", | ||
"E.domain_accuracy(model, train_loader, 'cpu')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "57b0beb1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python (haloflow_venv)", | ||
"language": "python", | ||
"name": "myenv" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.12" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
from re import M | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import DataLoader, TensorDataset | ||
from sklearn.preprocessing import StandardScaler | ||
from .. import data as D | ||
|
||
|
||
class SimulationDataset:
    """
    Load per-simulation train/test arrays and expose them as DataLoaders.

    Parameters
    ----------
    sims : list of str
        Names of the simulations to load.
    obs : str
        Observable set identifier forwarded to ``D.hf2_centrals``.
    data_dir : str
        Data directory. It is stored on the instance but not passed to
        ``D.hf2_centrals`` by this class.
        NOTE(review): presumably the loader resolves its own path -- confirm.
    """

    def __init__(self, sims, obs, data_dir):
        self.sims = sims
        self.obs = obs
        self.data_dir = data_dir
        # Eagerly load every requested simulation; keyed by simulation name.
        self.data = self._load_data()

    def _load_data(self):
        """
        Read the train and test splits of every simulation in ``self.sims``.

        Returns
        -------
        dict
            Maps each simulation name to a dict with the keys ``"X_train"``,
            ``"Y_train"``, ``"X_test"`` and ``"Y_test"``.
        """
        # impose mass priors (already in log space)
        # TODO: need to revisit later
        # NOTE(review): column 0 / column 1 of Y are presumably stellar and
        # halo mass (from the ``sm`` / ``hm`` naming) -- confirm.
        mass_range_sm = [10.0, 13.]
        mass_range_hm = [10.7, 15.]

        data = {}
        for sim in self.sims:
            # hf2_centrals returns the targets first, then the features.
            Y_train, X_train = D.hf2_centrals("train", self.obs, sim=sim)
            Y_test, X_test = D.hf2_centrals("test", self.obs, sim=sim)

            mask_sm = (Y_train[:, 0] > mass_range_sm[0]) & (Y_train[:, 0] < mass_range_sm[1])
            mask_hm = (Y_train[:, 1] > mass_range_hm[0]) & (Y_train[:, 1] < mass_range_hm[1])
            # The priors are applied to the training split only; the test
            # split is kept as loaded.
            keep = mask_sm & mask_hm
            Y_train = Y_train[keep]
            X_train = X_train[keep]

            data[sim] = {
                "X_train": X_train,
                "Y_train": Y_train,
                "X_test": X_test,
                "Y_test": Y_test,
            }
        return data

    def get_train_test_loaders(self, train_sims, test_sim, batch_size=64):
        """
        Get DataLoaders for training and testing.

        Parameters
        ----------
        train_sims : list of str
            Simulations whose training splits are concatenated. The position
            of a simulation in this list is used as its domain label.
        test_sim : str
            Simulation whose test split is used for evaluation.
        batch_size : int
            Batch size for both loaders.

        Returns
        -------
        tuple
            ``(train_loader, test_loader)``. The training loader is shuffled
            and yields ``(X, Y, domain_label)`` batches; the test loader is
            not shuffled and yields ``(X, Y)`` batches.
        """
        # Combine training data from specified simulations
        X_train = np.concatenate([self.data[sim]["X_train"] for sim in train_sims])
        Y_train = np.concatenate([self.data[sim]["Y_train"] for sim in train_sims])
        domain_labels = np.concatenate(
            [
                np.full(len(self.data[sim]["X_train"]), i, dtype=np.int64)
                for i, sim in enumerate(train_sims)
            ]
        )

        # Get test data
        X_test = self.data[test_sim]["X_test"]
        Y_test = self.data[test_sim]["Y_test"]

        # Fit the scaler on the training features only and reuse those
        # statistics for the test features. Re-fitting on the test set would
        # standardise it with different statistics than the ones the model
        # was trained on.
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        # Convert to tensors
        X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
        Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32)
        domain_labels_tensor = torch.tensor(domain_labels, dtype=torch.long)

        X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
        Y_test_tensor = torch.tensor(Y_test, dtype=torch.float32)

        # Create datasets
        train_dataset = TensorDataset(
            X_train_tensor, Y_train_tensor, domain_labels_tensor
        )
        test_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

        # Create DataLoaders
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        return train_loader, test_loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import numpy as np | ||
import torch | ||
from sklearn.metrics import mean_squared_error, r2_score | ||
|
||
|
||
def evaluate_regression(model, dataloader, device="cuda"):
    """
    Score the regression head of a model with MSE, RMSE and R².

    Parameters
    ----------
    model : torch.nn.Module
        Trained model. Its forward pass must return a pair whose first
        element is the regression prediction.
    dataloader : torch.utils.data.DataLoader
        Yields ``(X, y)`` batches of the evaluation set.
    device : str
        Device each batch is moved to before the forward pass.

    Returns
    -------
    dict
        The scores under the keys ``"mse"``, ``"rmse"`` and ``"r2"``.
    """
    model.eval()

    targets = []
    predictions = []
    with torch.no_grad():
        for features, labels in dataloader:
            features = features.to(device)
            labels = labels.to(device)
            # Only the first output is scored; the second one is discarded.
            regression_out, _ = model(features)
            targets.append(labels.cpu().numpy())
            predictions.append(regression_out.cpu().numpy())

    # Stitch the per-batch arrays back into full-length arrays.
    y_true = np.concatenate(targets)
    y_pred = np.concatenate(predictions)

    mse = mean_squared_error(y_true, y_pred)
    rmse = np.sqrt(mse)
    r2 = r2_score(y_true, y_pred)

    print(f"MSE: {mse:.4f}, RMSE: {rmse:.4f}, R²: {r2:.4f}")
    return {"mse": mse, "rmse": rmse, "r2": r2}
|
||
|
||
def domain_accuracy(model, dataloader, device="cuda"):
    """
    Evaluate the domain classifier's accuracy.

    Parameters
    ----------
    model : torch.nn.Module
        Trained model. Its forward pass must return a pair whose second
        element holds the per-domain scores, one column per domain.
    dataloader : torch.utils.data.DataLoader
        Yields ``(X, y, domain_label)`` batches; ``y`` is ignored.
    device : str
        Device each batch is moved to before the forward pass.

    Returns
    -------
    float
        Fraction of samples whose highest-scoring domain equals the true
        domain label, or ``nan`` when the dataloader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for X_batch, _, domain_batch in dataloader:
            X_batch, domain_batch = X_batch.to(device), domain_batch.to(device)
            _, domain_pred = model(X_batch)
            # The predicted domain is the column with the highest score.
            preds = domain_pred.argmax(dim=1)
            correct += (preds == domain_batch).sum().item()
            total += domain_batch.size(0)

    # Accuracy is undefined without samples; report nan rather than failing
    # with a ZeroDivisionError.
    acc = correct / total if total else float("nan")
    print(f"Domain Accuracy: {acc:.4f}")
    return acc
Oops, something went wrong.