From af26a22ca4bd4dad90861c2bba3fe13ee04f84dd Mon Sep 17 00:00:00 2001 From: hahnec Date: Wed, 26 Jul 2023 12:00:18 +0200 Subject: [PATCH] fix(lambdas): use cfg object instead self --- train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index e249191..23ebc4b 100644 --- a/train.py +++ b/train.py @@ -118,7 +118,7 @@ def train_model( grad_scaler = torch.cuda.amp.GradScaler(enabled=amp) criterion = nn.MSELoss(reduction='mean') l1loss = nn.L1Loss(reduction='mean') - lambda_value = 0.01 if cfg.model.__contains__('unet') and self.input_type == 'iq' else cfg.lambda1 + lambda_value = 0.01 if cfg.model.__contains__('unet') and cfg.input_type == 'iq' else cfg.lambda1 train_step = 0 val_step = 0 @@ -126,7 +126,7 @@ def train_model( psf_heatmap = torch.from_numpy(matlab_style_gauss2D(shape=(7,7),sigma=1)) gfilter = torch.reshape(psf_heatmap, [1, 1, 7, 7]) gfilter = gfilter.to(cfg.device) - amplitude = 50 if cfg.model.__contains__('mspcn') and self.input_type == 'iq' else cfg.lambda0 + amplitude = 50 if cfg.model.__contains__('mspcn') and cfg.input_type == 'iq' else cfg.lambda0 # variable init for coordinate transformation gt_samples_list, gt_points_list = [], []