diff --git a/bris/inference.py b/bris/inference.py index 14a0aca..35b052e 100644 --- a/bris/inference.py +++ b/bris/inference.py @@ -36,6 +36,8 @@ def __init__( self.precision = precision self._device = device + torch.set_float32_matmul_precision("high") + @property def device(self) -> str: if self._device is None: @@ -60,7 +62,7 @@ def strategy(self): self.config.dataloader.get( "read_group_size", self.config.hardware.num_gpus_per_model ), - static_graph=not self.checkpoint.config.training.accum_grad_batches > 1, + static_graph=False, #not self.checkpoint.config.training.accum_grad_batches > 1, ) else: LOGGER.info( @@ -69,7 +71,7 @@ def strategy(self): from bris.data.legacy.distributed.strategy import DDPGroupStrategy return DDPGroupStrategy( self.config.hardware.num_gpus_per_model, - static_graph=not self.checkpoint.config.training.accum_grad_batches > 1, + static_graph=False, #not self.checkpoint.config.training.accum_grad_batches > 1, ) @@ -78,7 +80,7 @@ def trainer(self) -> pl.Trainer: trainer = pl.Trainer( logger=False, accelerator=self.device, - deterministic=self.deterministic, + deterministic=False, detect_anomaly=False, strategy=self.strategy, devices=self.config.hardware.num_gpus_per_node,