diff --git a/models/demos/convnet_mnist/tt/convnet_mnist.py b/models/demos/convnet_mnist/tt/convnet_mnist.py index a38aa60a770c..a699cbcbd380 100644 --- a/models/demos/convnet_mnist/tt/convnet_mnist.py +++ b/models/demos/convnet_mnist/tt/convnet_mnist.py @@ -33,7 +33,7 @@ def convnet_mnist( ) x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) - [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + x = ttnn.conv2d( input_tensor=x, weight_tensor=parameters.conv1.weight, in_channels=1, @@ -50,6 +50,8 @@ def convnet_mnist( conv_op_cache={}, debug=True, groups=1, + return_output_dim=False, + return_weights_and_bias=False, ) x = ttnn.relu(x) @@ -76,7 +78,7 @@ def convnet_mnist( dilation=[1, 1], ) - [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + x, [out_height, out_width] = ttnn.conv2d( input_tensor=x, weight_tensor=parameters.conv2.weight, in_channels=32, @@ -93,6 +95,8 @@ def convnet_mnist( conv_op_cache={}, debug=False, groups=1, + return_output_dim=True, + return_weights_and_bias=False, ) x = ttnn.relu(x) diff --git a/models/demos/segformer/tt/common.py b/models/demos/segformer/tt/common.py index 5f52fe0e5072..85029cfda4b1 100644 --- a/models/demos/segformer/tt/common.py +++ b/models/demos/segformer/tt/common.py @@ -57,7 +57,7 @@ def __call__(self, device, input_tensor): if self.act_block_h is not None: conv_config.act_block_h_override = self.act_block_h - [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d( + output_tensor, [_out_height, _out_width] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.weights, bias_tensor=self.bias, @@ -72,6 +72,8 @@ def __call__(self, device, input_tensor): input_width=input_tensor.shape[2], conv_config=conv_config, groups=self.groups, + return_output_dim=True, + return_weights_and_bias=False, ) return output_tensor, _out_height, _out_width diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py index cfe555d0367f..a18159b828f9 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_large_new_conv_api.py @@ -167,7 +167,7 @@ def run_downsample_if_req( shard_layout = ( ttnn.TensorMemoryLayout.HEIGHT_SHARDED if height_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED ) - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -190,6 +190,8 @@ def run_downsample_if_req( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -214,7 +216,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -238,6 +240,8 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -277,7 
+281,7 @@ def __call__( ) # if ds_out_mem_config and ds_out_mem_config != ttnn.get_memory_config(out): # out = ttnn.to_memory_config(out, ds_out_mem_config) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -304,11 +308,13 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -331,6 +337,8 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_weights_and_bias=True, + return_output_dim=False, ) if not self.run_downsample_before_conv2: @@ -546,7 +554,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config ) - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -569,6 +577,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -857,7 +867,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config ) - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -880,6 +890,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 44d90cb0f348..fefc05924fdf 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -160,7 +160,7 @@ def run_downsample_if_req( ): if self.downsample: logger.debug(f"Running downsample") - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -195,6 +195,8 @@ def run_downsample_if_req( enable_subblock_padding=enable_subblock_padding, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = 
ttnn.reallocate(ds_out) @@ -226,7 +228,7 @@ def __call__( # conv1 is 1x1 conv logger.debug(f"Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -252,6 +254,8 @@ def __call__( packer_l1_accum_enabled=packer_l1_acc, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -307,7 +311,7 @@ def __call__( reallocate_halo_output = batch_size == 20 logger.debug(f"Running conv2") - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -340,6 +344,8 @@ def __call__( enable_subblock_padding=enable_subblock_padding, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) logger.debug( @@ -358,7 +364,7 @@ def __call__( # conv3 is 1x1 conv logger.debug(f"Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -383,6 +389,8 @@ def __call__( packer_l1_accum_enabled=packer_l1_acc, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) if not run_downsample_before_conv2: @@ -719,7 +727,7 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt logger.debug(f"==== first conv") # first conv - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=fold_output_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -734,6 +742,8 @@ def run(self, input_tensor, device, ops_parallel_config, conv_op_cache={}) -> tt input_width=self.conv1_input_width, conv_config=self.conv1_config, conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 if self.batch_size == 20: diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py index 5c0750003c16..a57b84e89927 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api.py @@ -162,7 +162,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -187,6 +187,8 @@ def run_downsample_if_req( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -209,7 +211,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") 
module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -233,6 +235,8 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -270,7 +274,7 @@ def __call__( # self.conv1_input_channels == 256 and # self.downsample ) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -297,11 +301,13 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -324,6 +330,8 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) if not self.run_downsample_before_conv2: @@ -516,7 +524,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt elif batch_size == 20: act_block_h_override = 640 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -539,6 +547,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -819,7 +829,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c else: act_block_h_override = 0 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -842,6 +852,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py index f2e266e1d8b6..f9e755d6df3b 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xlarge_new_conv_api_24.py @@ -164,7 +164,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, 
[self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -189,6 +189,8 @@ def run_downsample_if_req( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -211,7 +213,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -235,6 +237,8 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) act_block_h_override = 0 @@ -273,7 +277,7 @@ def __call__( logger.info( f"Running conv2 with reallocate_halo_output={reallocate_halo_output}, input_height={input_height}, conv2_output_channels={self.conv2_output_channels}" ) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -300,11 +304,13 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -327,6 +333,8 @@ def __call__( reshard_if_not_optimal=reshard_if_not_optimal, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) if not self.run_downsample_before_conv2: @@ -541,7 +549,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt input_tensor, device=device, memory_config=self.grayskull_conv1_input_memory_config ) - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -564,6 +572,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -872,7 +882,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c elif batch_size == 20: act_block_h_override = 640 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -895,6 +905,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, +
return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py index 45d93ebf6859..ad0f2a6844d3 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_xxlarge_new_conv_api.py @@ -163,7 +163,7 @@ def run_downsample_if_req( height_sharding=None, ): if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -189,6 +189,8 @@ def run_downsample_if_req( transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) ds_out = ttnn.reallocate(ds_out) @@ -216,7 +218,7 @@ def __call__( # conv1 is 1x1 conv # print("Running conv1") module_input_height = input_height - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -241,6 +243,8 @@ def __call__( transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if is_wormhole_b0(): @@ -321,7 +325,7 @@ def __call__( # self.conv1_input_channels == 256 and # self.downsample ) - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -349,11 +353,13 @@ def __call__( transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -377,6 +383,8 @@ def __call__( transpose_shards=height_sharding, ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) if not self.run_downsample_before_conv2: @@ -581,7 +589,7 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt else: act_block_h_override = 0 - x, x_height, x_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -605,6 +613,8 @@ def first_run(self, input_tensor, device, batch_size, ops_parallel_config) -> tt act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 @@ -915,7 +925,7 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c elif batch_size == 1: act_block_h_override = 256 - x, x_height, x_width, 
self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + x, [x_height, x_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -938,6 +948,8 @@ def optimized_run(self, input_tensor, device, batch_size, ops_parallel_config, c act_block_h_override=act_block_h_override, ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # Relu is fused with conv1 diff --git a/models/demos/vgg/tt/ttnn_vgg.py b/models/demos/vgg/tt/ttnn_vgg.py index 4cb986c27304..bdf190f9c3d9 100644 --- a/models/demos/vgg/tt/ttnn_vgg.py +++ b/models/demos/vgg/tt/ttnn_vgg.py @@ -112,7 +112,7 @@ def ttnn_vgg16( tt_bias = parameters.features[conv_feature_ids[iter_conv_id]].bias # Call ttnn.conv conv_op_cache = {} - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_x, weight_tensor=tt_weight, in_channels=conv_ttnn_params[iter_conv_id][0], @@ -127,6 +127,8 @@ def ttnn_vgg16( input_width=conv_ttnn_params[iter_conv_id][3], conv_config=conv_config, conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) tt_x = ttnn.from_device(tt_output_tensor_on_device) ttnn.deallocate(tt_output_tensor_on_device) @@ -234,7 +236,7 @@ def ttnn_vgg11( # Call ttnn.conv conv_op_cache = {} - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_x, weight_tensor=tt_weight, in_channels=conv_ttnn_params_2[iter_conv_id][0], @@ -249,6 +251,8 @@ def ttnn_vgg11( input_width=conv_ttnn_params_2[iter_conv_id][3], conv_config=conv_config, conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) tt_x = ttnn.from_device(tt_output_tensor_on_device) ttnn.deallocate(tt_output_tensor_on_device) diff --git a/models/demos/wormhole/mamba/tt/mamba_conv.py b/models/demos/wormhole/mamba/tt/mamba_conv.py index a2700198f835..52401de47d24 100644 --- a/models/demos/wormhole/mamba/tt/mamba_conv.py +++ b/models/demos/wormhole/mamba/tt/mamba_conv.py @@ -87,7 +87,7 @@ def __call__(self, input_tensor): input_tensor_splits = self.prepare_input(input_tensor) output_tensor_splits = [] for i in range(self.config.channels_split_factor): - [tt_output_tensor_on_device, out_length, weights_device, _] = ttnn.Conv1d( + tt_output_tensor_on_device, out_length, [weights_device, _] = ttnn.Conv1d( input_tensor=input_tensor_splits[i], weight_tensor=self.tt_weight_tensor_splits[i], in_channels=self.config.input_channels // self.config.channels_split_factor, @@ -103,6 +103,8 @@ def __call__(self, input_tensor): conv_op_cache={}, debug=False, groups=self.config.groups // self.config.channels_split_factor, + return_output_dim=True, + return_weights_and_bias=True, ) self.tt_weight_tensor_splits[i] = weights_device output_tensor_splits.append(ttnn.sharded_to_interleaved(tt_output_tensor_on_device)) diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py index 2ad02078d718..13a8ddc0f581 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py +++ 
b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_downsample_2d_new_conv.py @@ -143,7 +143,7 @@ def __call__( if self.conv_config_override and "act_block_h" in self.conv_config_override: conv_config.act_block_h_override = self.conv_config_override["act_block_h"] - [hidden_states, _out_height, _out_width, self.conv_weights, self.conv_bias] = ttnn.conv2d( + [hidden_states, [self.conv_weights, self.conv_bias]] = ttnn.conv2d( input_tensor=hidden_states, in_channels=self.in_channels, out_channels=self.out_channels, @@ -158,6 +158,8 @@ def __call__( bias_tensor=self.conv_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) # hidden_states = run_ttnn_conv_with_pre_and_post_tensor_formatting( # self.device, diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py index cdcea7056263..844ffbcbe868 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_resnetblock2d_new_conv.py @@ -469,7 +469,7 @@ def __call__( ) if self.conv1_config_override and "act_block_h" in self.conv2_config_override: conv_config.act_block_h_override = self.conv1_config_override["act_block_h"] - [hidden_states, _out_height, _out_width, self.conv1s_weights[0], self.conv1s_bias[0]] = ttnn.conv2d( + [hidden_states, [self.conv1s_weights[0], self.conv1s_bias[0]]] = ttnn.conv2d( input_tensor=hidden_states, weight_tensor=self.conv1s_weights[0], in_channels=self.conv1_in_channels, @@ -484,6 +484,8 @@ def __call__( input_width=self.conv1_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) else: @@ -543,10 +545,8 @@ def __call__( [ split_hidden_states[i], - _out_height, - _out_width, - self.conv1s_weights[i], - self.conv1s_bias[i], + [_out_height, _out_width], + [self.conv1s_weights[i], self.conv1s_bias[i]], ] = ttnn.conv2d( input_tensor=split_hidden_states[i], weight_tensor=self.conv1s_weights[i], @@ -562,6 +562,8 @@ def __call__( input_width=self.conv1_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if i != 0: split_hidden_states[i] = ttnn.add( @@ -668,7 +670,7 @@ def __call__( ) if self.conv2_config_override and "act_block_h" in self.conv2_config_override: conv_config.act_block_h_override = self.conv2_config_override["act_block_h"] - [hidden_states, _out_height, _out_width, self.conv2_weights, self.conv2_bias] = ttnn.conv2d( + [hidden_states, [_out_height, _out_width], [self.conv2_weights, self.conv2_bias]] = ttnn.conv2d( input_tensor=hidden_states, weight_tensor=self.conv2_weights, bias_tensor=self.conv2_bias, @@ -683,6 +685,8 @@ def __call__( input_width=self.conv2_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) use_in_shortcut = in_channels != out_channels if use_in_shortcut is None else use_in_shortcut @@ -710,7 +714,11 @@ def __call__( transpose_shards=False, reshard_if_not_optimal=False, ) - [input_tensor, _out_height, _out_width, self.conv_shortcut_weights, self.conv_shortcut_bias] = ttnn.conv2d( + [ + input_tensor, + [_out_height, _out_width], + [self.conv_shortcut_weights, self.conv_shortcut_bias], + ] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.conv_shortcut_weights, 
in_channels=self.conv_shortcut_in_channels, @@ -725,6 +733,8 @@ def __call__( input_width=self.conv_shortcut_input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if ttnn.get_memory_config(input_tensor) != ttnn.get_memory_config(hidden_states): diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py index 12e4d5432071..8972ba32b54b 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_transformer_2d_new_conv.py @@ -249,7 +249,7 @@ def __call__( fp32_dest_acc_enabled=self.compute_kernel_config.fp32_dest_acc_en, transpose_shards=False, ) - [hidden_states, _out_height, _out_width, self.proj_in_conv_weights, self.proj_in_conv_bias] = ttnn.conv2d( + [hidden_states, [self.proj_in_conv_weights, self.proj_in_conv_bias]] = ttnn.conv2d( input_tensor=hidden_states, in_channels=self.proj_in_in_channels, out_channels=self.proj_in_out_channels, @@ -264,6 +264,8 @@ def __call__( bias_tensor=self.proj_in_conv_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) inner_dim = hidden_states.shape[-1] @@ -293,10 +295,8 @@ def __call__( # hidden_states = ttnn.to_memory_config(hidden_states, self.proj_out.conv.input_sharded_memory_config) [ hidden_states, - _out_height, - _out_width, - self.proj_out_conv_weights, - self.proj_out_conv_bias, + [_out_height, _out_width], + [self.proj_out_conv_weights, self.proj_out_conv_bias], ] = ttnn.conv2d( input_tensor=hidden_states, in_channels=self.proj_out_in_channels, @@ -312,6 +312,8 @@ def __call__( bias_tensor=self.proj_out_conv_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if output_bfloat16: diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py index 9cbdfff2f486..26851dfa1769 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_unet_2d_condition_model_new_conv.py @@ -394,7 +394,7 @@ def __call__( reshard_if_not_optimal=True, ) - [sample, _out_height, _out_width, self.conv_in_weights, self.conv_in_bias] = ttnn.conv2d( + [sample, [self.conv_in_weights, self.conv_in_bias]] = ttnn.conv2d( input_tensor=sample, weight_tensor=self.conv_in_weights, bias_tensor=self.conv_in_bias, @@ -409,6 +409,8 @@ def __call__( input_width=self.input_width, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) sample = ttnn.reallocate(sample) # TODO: Test remove @@ -657,7 +659,7 @@ def __call__( transpose_shards=False, reshard_if_not_optimal=True, ) - [sample, _out_height, _out_width, self.conv_out_weights, self.conv_out_bias] = ttnn.conv2d( + [sample, [self.conv_out_weights, self.conv_out_bias]] = ttnn.conv2d( input_tensor=sample, in_channels=self.conv_out_in_channels, out_channels=self.conv_out_out_channels, @@ -672,6 +674,8 @@ def __call__( bias_tensor=self.conv_out_bias, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) sample = ttnn.to_memory_config(sample, 
ttnn.L1_MEMORY_CONFIG) sample = ttnn.clone(sample, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=ttnn.bfloat16) diff --git a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py index 622a63065db4..d890230689b2 100644 --- a/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py +++ b/models/demos/wormhole/stable_diffusion/tt/ttnn_functional_upsample_2d_new_conv.py @@ -103,7 +103,7 @@ def __call__(self, input, in_channels, out_channels): ) if self.conv_config_override and "act_block_h" in self.conv_config_override: conv_config.act_block_h_override = self.conv_config_override["act_block_h"] - [tt_out, _out_height, _out_width, self.conv_weight_tensor, self.conv_bias_tensor] = ttnn.conv2d( + [tt_out, [self.conv_weight_tensor, self.conv_bias_tensor]] = ttnn.conv2d( input_tensor=tt_out, in_channels=self.conv_in_channels, out_channels=self.conv_out_channels, @@ -118,5 +118,7 @@ def __call__(self, input, in_channels, out_channels): bias_tensor=self.conv_bias_tensor, conv_config=conv_config, conv_op_cache=conv_cache, + return_output_dim=False, + return_weights_and_bias=True, ) return tt_out diff --git a/models/demos/yolov4/ttnn/common.py b/models/demos/yolov4/ttnn/common.py index b293a6db751c..b1c736accc29 100644 --- a/models/demos/yolov4/ttnn/common.py +++ b/models/demos/yolov4/ttnn/common.py @@ -99,7 +99,7 @@ def __call__(self, device, input_tensor): if self.act_block_h is not None: conv_config.act_block_h_override = self.act_block_h - [output_tensor, _out_height, _out_width, self.weights, self.bias] = ttnn.conv2d( + output_tensor, [self.weights, self.bias] = ttnn.conv2d( input_tensor=input_tensor, weight_tensor=self.weights, bias_tensor=self.bias, @@ -113,5 +113,7 @@ def __call__(self, device, input_tensor): input_height=self.input_params[1], input_width=self.input_params[2], conv_config=conv_config, + return_output_dim=False, + return_weights_and_bias=True, ) return output_tensor diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 215399ea23ba..cbb393634db2 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -143,7 +143,7 @@ def __init__( self.bias = ttnn.from_torch(bias, dtype=ttnn.float32, mesh_mapper=mesh_mapper) def __call__(self, x): - x, _, _, self.weight, self.bias = ttnn.conv2d( + x, [self.weight, self.bias] = ttnn.conv2d( input_tensor=x, weight_tensor=self.weight, bias_tensor=self.bias, @@ -159,6 +159,8 @@ def __call__(self, x): conv_config=self.conv_config, conv_op_cache=self.cache, groups=2, + return_output_dim=False, + return_weights_and_bias=True, ) return x diff --git a/tests/sweep_framework/sweep_utils/conv2d_common.py b/tests/sweep_framework/sweep_utils/conv2d_common.py index 55769adb9842..6906aea343d4 100644 --- a/tests/sweep_framework/sweep_utils/conv2d_common.py +++ b/tests/sweep_framework/sweep_utils/conv2d_common.py @@ -137,7 +137,7 @@ def run_full( {ttnn.CoreRange(core_grid[0], core_grid[1]), ttnn.CoreRange(core_grid[2], core_grid[3])} ) start_time = start_measuring_time() - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ 
-153,6 +153,8 @@ def run_full( input_width=input_width, conv_config=conv_config, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -220,7 +222,7 @@ def run_short( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) start_time = start_measuring_time() - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -235,6 +237,8 @@ def run_short( input_height=input_height, input_width=input_width, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/ttnn/unit_tests/operations/test_conv1d.py b/tests/ttnn/unit_tests/operations/test_conv1d.py index 3e7a1496c634..22c0276d948a 100644 --- a/tests/ttnn/unit_tests/operations/test_conv1d.py +++ b/tests/ttnn/unit_tests/operations/test_conv1d.py @@ -104,7 +104,7 @@ def run_conv( conv_config.override_sharding_config = True print("Setting num_cores_nhw to 98") - [tt_output_tensor_on_device, out_length, weights_device, bias_device] = ttnn.Conv1d( + [tt_output_tensor_on_device, out_length, [weights_device, bias_device]] = ttnn.Conv1d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -120,6 +120,8 @@ def run_conv( conv_op_cache=reader_patterns_cache, debug=debug, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index cfe4c0f143ad..e0e9b55a8182 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -162,7 +162,7 @@ def run_conv( conv_config.override_sharding_config = True print("Setting num_cores_nhw to 98") - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -181,6 +181,8 @@ def run_conv( debug=debug, groups=groups, memory_config=memory_config, + return_weights_and_bias=True, + return_output_dim=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -306,7 +308,7 @@ def run_conv_with_split( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) # tt_input_tensor_on_device = convs[i].copy_input_to_device(tt_input_tensor) # tt_output_tensor_on_device = convs[i](tt_input_tensor_on_device) - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=split_input_channels, @@ -321,6 +323,8 @@ def run_conv_with_split( input_width=input_width, conv_config=conv_config, conv_op_cache=reader_patterns_cache, + return_output_dim=True, + return_weights_and_bias=True, ) tt_conv_output_tensor = ttnn.from_device(tt_output_tensor_on_device) torch_conv_output_tensor = ttnn.to_torch(tt_conv_output_tensor) @@ -638,7 +642,7 @@ def test_conv_ws( 
act_block_w_div=act_block_w_div, act_block_h_override=32, ) - [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + tt_output_tensor_on_device, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_input_tensor, weight_tensor=tt_weight_tensor, in_channels=input_channels, @@ -655,6 +659,8 @@ def test_conv_ws( conv_op_cache=reader_patterns_cache, debug=debug, groups=groups, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_output_tensor_on_device) @@ -2644,7 +2650,7 @@ def test_shallow_conv_with_tiled_input(device): tt_input = ttnn.reshape(tt_input, (1, 1, batch_size * img_h * img_w, in_channels)) - tt_out, out_height, out_width, _, _ = ttnn.conv2d( + tt_out, [out_height, out_width], [weights_device, bias_device] = ttnn.conv2d( input_tensor=tt_input, weight_tensor=tt_kernel, in_channels=in_channels, @@ -2660,6 +2666,8 @@ def test_shallow_conv_with_tiled_input(device): input_width=img_w, groups=1, memory_config=ttnn.DRAM_MEMORY_CONFIG, + return_output_dim=True, + return_weights_and_bias=True, ) tt_output_tensor = ttnn.from_device(tt_out) diff --git a/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py b/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py index 84ee4d5d9729..71331834332b 100644 --- a/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py +++ b/tests/ttnn/unit_tests/operations/test_small_resnet50_block.py @@ -103,7 +103,7 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac # logger.info("This module input shape - ", self.module_input_shape) # conv1 is 1x1 conv # print("Running conv1") - x, input_height, input_width, self.identity_conv_weight_tensor, _ = ttnn.conv2d( + x, [input_height, input_width], [self.identity_conv_weight_tensor, _] = ttnn.conv2d( input_tensor=x, weight_tensor=self.identity_conv_weight_tensor, in_channels=self.conv1_input_channels, @@ -121,9 +121,11 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) - out, input_height, input_width, self.conv1_weight_tensor, self.conv1_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv1_weight_tensor, self.conv1_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.conv1_weight_tensor, in_channels=self.conv1_input_channels, @@ -143,10 +145,12 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac activation="relu", ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) if self.downsample: - ds_out, _, _, self.ds_conv_weight_tensor, self.ds_conv_bias_tensor = ttnn.conv2d( + ds_out, [self.ds_conv_weight_tensor, self.ds_conv_bias_tensor] = ttnn.conv2d( input_tensor=x, weight_tensor=self.ds_conv_weight_tensor, in_channels=self.ds_conv_input_channels, @@ -165,13 +169,15 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) ttnn.deallocate(x) else: ds_out = x # print("Running conv2") - out, input_height, input_width, self.conv2_weight_tensor, self.conv2_bias_tensor = ttnn.conv2d( + out, [input_height, input_width], [self.conv2_weight_tensor, self.conv2_bias_tensor] = ttnn.conv2d( input_tensor=out, 
weight_tensor=self.conv2_weight_tensor, in_channels=self.conv2_input_channels, @@ -191,11 +197,13 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac activation="relu", ), conv_op_cache=conv_op_cache, + return_output_dim=True, + return_weights_and_bias=True, ) # conv3 is 1x1 conv # print("Running conv3") - out, _, _, self.conv3_weight_tensor, self.conv3_bias_tensor = ttnn.conv2d( + out, [self.conv3_weight_tensor, self.conv3_bias_tensor] = ttnn.conv2d( input_tensor=out, weight_tensor=self.conv3_weight_tensor, in_channels=self.conv3_input_channels, @@ -214,6 +222,8 @@ def __call__(self, x, device, batch_size, input_height, input_width, conv_op_cac math_fidelity=self.model_config["MATH_FIDELITY"], ), conv_op_cache=conv_op_cache, + return_output_dim=False, + return_weights_and_bias=True, ) # underscore version is in_place = True diff --git a/ttnn/ttnn/operations/conv1d.py b/ttnn/ttnn/operations/conv1d.py index e979a12b21d1..41ded1753438 100644 --- a/ttnn/ttnn/operations/conv1d.py +++ b/ttnn/ttnn/operations/conv1d.py @@ -30,6 +30,8 @@ def Conv1d( conv_config: Conv1dConfig = None, # config overrides by user conv_op_cache={}, # basic conv object caching in python needed for intermediate refactoring. Not needed after full op refactoring in C++. debug=False, + return_output_dim=False, + return_weights_and_bias=False, ) -> Tuple[ttnn.Tensor, int, int, ttnn.Tensor, ttnn.Tensor]: # Reshape the input and weight tensors to 4D for conv2d operation # Should be no-op as input_tensor is in RM layout @@ -62,12 +64,14 @@ def Conv1d( conv_config=conv_config, ) - return ( - output_tensor_new, - output_length_new, - weight_tensor_on_dev_new, - bias_tensor_on_dev_new, - ) + if return_output_dim and return_weights_and_bias: + return output_tensor_new, output_length_new, [weight_tensor_on_dev_new, bias_tensor_on_dev_new] + elif return_weights_and_bias: + return output_tensor_new, [weight_tensor_on_dev_new, bias_tensor_on_dev_new] + elif return_output_dim: + return output_tensor_new, output_length_new + else: + return output_tensor_new __all__ = [] diff --git a/ttnn/ttnn/operations/conv2d.py b/ttnn/ttnn/operations/conv2d.py index b46a7e1fbf79..6a1f8e5291d1 100644 --- a/ttnn/ttnn/operations/conv2d.py +++ b/ttnn/ttnn/operations/conv2d.py @@ -103,8 +103,16 @@ def conv2d( memory_config: ttnn.MemoryConfig = None, # memory config overrides by user conv_op_cache={}, # basic conv object caching in python needed for intermediate refactoring. Not needed after full op refactoring in C++. debug=False, # ignored + return_output_dim=False, + return_weights_and_bias=False, ) -> Tuple[ttnn.Tensor, int, int, ttnn.Tensor, ttnn.Tensor]: - return ttnn._ttnn.operations.conv.conv2d( + ( + conv_output, + output_height, + output_width, + prepared_device_weight, + prepared_device_bias, + ) = ttnn._ttnn.operations.conv.conv2d( input_tensor=input_tensor, weight_tensor=weight_tensor, device=device, @@ -123,5 +131,14 @@ def conv2d( memory_config=memory_config, ) + if return_output_dim and return_weights_and_bias: + return conv_output, [output_height, output_width], [prepared_device_weight, prepared_device_bias] + elif return_weights_and_bias: + return conv_output, [prepared_device_weight, prepared_device_bias] + elif return_output_dim: + return conv_output, [output_height, output_width] + else: + return conv_output + __all__ = []
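For quick reference, here is a minimal sketch (not part of the patch) of the four unpacking patterns that the new `return_output_dim` / `return_weights_and_bias` flags produce at `ttnn.conv2d` call sites. The device handle, tensors, and `conv_kwargs` below are hypothetical placeholders standing in for the per-model arguments shown in the diffs above.

```python
import ttnn


def conv2d_return_patterns(device, x, w, b, conv_kwargs):
    # `x`, `w`, `b` are assumed to be prepared ttnn tensors; `conv_kwargs` is assumed to
    # carry the usual shape arguments (in_channels, out_channels, batch_size, kernel_size,
    # stride, padding, input_height, input_width, ...).
    common = dict(input_tensor=x, weight_tensor=w, bias_tensor=b, device=device, **conv_kwargs)

    # Default: only the output tensor (both flags False).
    out = ttnn.conv2d(**common)

    # Output plus spatial dims, as in segformer's common.py.
    out, [out_h, out_w] = ttnn.conv2d(return_output_dim=True, **common)

    # Output plus prepared on-device weight/bias, as in the downsample convs that cache weights.
    out, [w_dev, b_dev] = ttnn.conv2d(return_weights_and_bias=True, **common)

    # Both flags: output, [height, width], [weight, bias].
    out, [out_h, out_w], [w_dev, b_dev] = ttnn.conv2d(
        return_output_dim=True, return_weights_and_bias=True, **common
    )
    return out
```

`ttnn.Conv1d` follows the same convention except that `return_output_dim=True` yields a scalar `out_length` rather than an `[out_height, out_width]` pair, which is why `mamba_conv.py` and `test_conv1d.py` unpack three top-level values when both flags are set.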