Skip to content

Commit

Permalink
fix(lambdas): use cfg object instead self
Browse files Browse the repository at this point in the history
  • Loading branch information
hahnec committed Jul 26, 2023
1 parent 00aed1d commit af26a22
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,15 +118,15 @@ def train_model(
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.MSELoss(reduction='mean')
l1loss = nn.L1Loss(reduction='mean')
lambda_value = 0.01 if cfg.model.__contains__('unet') and self.input_type == 'iq' else cfg.lambda1
lambda_value = 0.01 if cfg.model.__contains__('unet') and cfg.input_type == 'iq' else cfg.lambda1
train_step = 0
val_step = 0

# mSPCN Gaussian
psf_heatmap = torch.from_numpy(matlab_style_gauss2D(shape=(7,7),sigma=1))
gfilter = torch.reshape(psf_heatmap, [1, 1, 7, 7])
gfilter = gfilter.to(cfg.device)
amplitude = 50 if cfg.model.__contains__('mspcn') and self.input_type == 'iq' else cfg.lambda0
amplitude = 50 if cfg.model.__contains__('mspcn') and cfg.input_type == 'iq' else cfg.lambda0

# variable init for coordinate transformation
gt_samples_list, gt_points_list = [], []
Expand Down

0 comments on commit af26a22

Please sign in to comment.