Skip to content

Commit

Permalink
use float32 as dtype [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
vbrkicTT committed Feb 13, 2025
1 parent 15db8e6 commit ba9ae8b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion forge/test/random/rgg/pytorch/generated_model.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class GeneratedTestModel_{{ test_index }}_{{ random_seed }}(torch.nn.Module):
{% for node in graph.nodes %}{% if node.operator.is_layer %}
self.{{ node.layer_name }} = {{ node.operator.full_name }}({{ constructor_kwargs(node=node) }}){% endif %}{% endfor %}
{% for constant_node in graph.constant_nodes %}
self.{{ constant_node.out_value }} = torch.randn({{ reduce_microbatch_size(constant_node.input_shape) }}){% endfor %}
self.{{ constant_node.out_value }} = torch.randn({{ reduce_microbatch_size(constant_node.input_shape) }}, dtype = torch.float32){% endfor %}

def forward(self{% for node in graph.input_nodes %},
{{ node.out_value }}: torch.Tensor{% endfor %}
Expand Down

0 comments on commit ba9ae8b

Please sign in to comment.