diff --git a/examples/block_mnist.py b/examples/block_mnist.py index 2ffd8f2..3b3b14a 100644 --- a/examples/block_mnist.py +++ b/examples/block_mnist.py @@ -33,10 +33,10 @@ } grid_x, grid_y = torch.meshgrid( - (torch.arange(28) - 14) // 14, (torch.arange(28) - 14) // 14, indexing="ij" + (torch.arange(28) - 14) / 14, (torch.arange(28) - 14) / 14, indexing="ij" ) grid = torch.stack([grid_x, grid_y]) - +print('grid', grid) def collate_fn(batch):