Skip to content

Commit

Permalink
Reverted to minimal model version without temporary code for debugging
Browse files Browse the repository at this point in the history
This reverts commits f68995d, 9cecb48, and 6439ada.
  • Loading branch information
oleschwen committed Nov 27, 2024
1 parent f68995d commit 9b81f7e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ workflows = [
path = "controller.SwarmServerController"
args {
# can also set aggregation clients and train clients, see class for all available args
num_rounds = 2
num_rounds = 1
}
}
]
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from models.base_model import BasicClassifier
import torch
import torch.nn as nn
import math


class MiniCNNForTesting(BasicClassifier):
def __init__(self,
Expand All @@ -15,25 +15,20 @@ def __init__(self,
lr_scheduler=None,
lr_scheduler_kwargs: dict = {},
aucroc_kwargs: dict = {"task": "binary"},
acc_kwargs: dict = {"task": "binary"},
acc_kwargs: dict = {"task": "binary"}
):
super().__init__(in_ch, out_ch, spatial_dims, loss, loss_kwargs, optimizer, optimizer_kwargs, lr_scheduler,
lr_scheduler_kwargs, aucroc_kwargs, acc_kwargs)

waste_of_memory = 16
linear_waste_of_memory = int(math.sqrt(waste_of_memory/4))

self.model = torch.nn.Sequential(
nn.Conv2d(1, 3, 3),
nn.ReLU(),
nn.MaxPool2d(4),
nn.ReLU(),
nn.Flatten(),
nn.Linear(3*4*4, linear_waste_of_memory), # temporary tests,
nn.Linear(linear_waste_of_memory, linear_waste_of_memory), # this should not be merged to main
nn.Linear(linear_waste_of_memory, 1)
nn.Linear(3*4*4, 1)
)
print(self.model)


def forward(self, x_in: torch.Tensor, **kwargs) -> torch.Tensor:
return self.model(x_in)
2 changes: 0 additions & 2 deletions docker_config/Dockerfile_testing
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,6 @@ RUN python3 -m pip install \
# additional package needed for [black]jupyter
RUN python3 -m pip install tokenize-rt==5.2.0

RUN python3 -m pip install torchinfo

RUN mkdir /scratch
RUN chmod a+rwx /scratch

Expand Down

0 comments on commit 9b81f7e

Please sign in to comment.