Skip to content

Commit

Permalink
what
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jan 21, 2024
1 parent ea47079 commit b91ecfc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/haliax/nn/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ def scan(self, init, *args, **kwargs):

@named_call(name="Stacked.fold")
def fold(self, init, *args, **kwargs):
print(f"FOLD! {self.gradient_checkpointing} {self.prevent_cse}", flush=True)
if self.gradient_checkpointing:
do_block = filter_checkpoint(self._do_block, prevent_cse=self.prevent_cse)
# determine a checkpoint block size, should be roughly sqrt(self.Block.size)
size = int(math.sqrt(self.Block.size))
num_blocks = int(math.ceil(self.Block.size / size))
print(f"FOLD! ${num_blocks}")

return haliax.fold(
do_block, self.Block, grad_checkpointing=self.gradient_checkpointing, checkpoint_blocks=[num_blocks]
Expand Down

0 comments on commit b91ecfc

Please sign in to comment.