Skip to content

Commit

Permalink
some fixes in inference class
Browse files Browse the repository at this point in the history
  • Loading branch information
einrone committed Jan 9, 2025
1 parent 431cc01 commit f1c0633
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions bris/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(
self.precision = precision
self._device = device

+ torch.set_float32_matmul_precision("high")
+
@property
def device(self) -> str:
if self._device is None:
Expand All @@ -60,7 +62,7 @@ def strategy(self):
self.config.dataloader.get(
"read_group_size", self.config.hardware.num_gpus_per_model
),
- static_graph=not self.checkpoint.config.training.accum_grad_batches > 1,
+ static_graph=False, #not self.checkpoint.config.training.accum_grad_batches > 1,
)
else:
LOGGER.info(
Expand All @@ -69,7 +71,7 @@ def strategy(self):
from bris.data.legacy.distributed.strategy import DDPGroupStrategy
return DDPGroupStrategy(
self.config.hardware.num_gpus_per_model,
- static_graph=not self.checkpoint.config.training.accum_grad_batches > 1,
+ static_graph=False, #not self.checkpoint.config.training.accum_grad_batches > 1,
)


Expand All @@ -78,7 +80,7 @@ def trainer(self) -> pl.Trainer:
trainer = pl.Trainer(
logger=False,
accelerator=self.device,
- deterministic=self.deterministic,
+ deterministic=False,
detect_anomaly=False,
strategy=self.strategy,
devices=self.config.hardware.num_gpus_per_node,
Expand Down

0 comments on commit f1c0633

Please sign in to comment.