From ba9ae8b9ff3e8919b8db05fe661844a9ab7b93ea Mon Sep 17 00:00:00 2001 From: Vladimir Brkic Date: Thu, 28 Nov 2024 08:55:49 +0000 Subject: [PATCH] use float32 as dtype [skip ci] --- forge/test/random/rgg/pytorch/generated_model.jinja2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/forge/test/random/rgg/pytorch/generated_model.jinja2 b/forge/test/random/rgg/pytorch/generated_model.jinja2 index b3fee57e1..ec09b8b8d 100644 --- a/forge/test/random/rgg/pytorch/generated_model.jinja2 +++ b/forge/test/random/rgg/pytorch/generated_model.jinja2 @@ -18,7 +18,7 @@ class GeneratedTestModel_{{ test_index }}_{{ random_seed }}(torch.nn.Module): {% for node in graph.nodes %}{% if node.operator.is_layer %} self.{{ node.layer_name }} = {{ node.operator.full_name }}({{ constructor_kwargs(node=node) }}){% endif %}{% endfor %} {% for constant_node in graph.constant_nodes %} - self.{{ constant_node.out_value }} = torch.randn({{ reduce_microbatch_size(constant_node.input_shape) }}){% endfor %} + self.{{ constant_node.out_value }} = torch.randn({{ reduce_microbatch_size(constant_node.input_shape) }}, dtype=torch.float32){% endfor %} def forward(self{% for node in graph.input_nodes %}, {{ node.out_value }}: torch.Tensor{% endfor %}