From c37b3d15aef1ee5ae764458d1a06b3c0b528c356 Mon Sep 17 00:00:00 2001
From: Stanislav Minakov
Date: Thu, 27 Feb 2025 07:22:18 +0000
Subject: [PATCH 1/3] Fix accidentally disabled async mode for some t3k tests
 (#18381)

### Ticket
https://github.com/tenstorrent/tt-metal/issues/18360

### Problem description
We recently disabled async mode for single devices by ignoring the enable_async call for them, on the assumption that multi-device customers call MeshDevice::enable_async instead. However, in some places, including our tests, we still iterate over each individual device in the mesh and call enable_async on it, and that call is now silently ignored.

### What's changed
Make a single call to MeshDevice::enable_async instead of iterating over the individual devices and calling Device::enable_async on each of them.

### Checklist
- [ ] [All post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553947437)
- [x] [T3K demo tests CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/13553950838)
- [x] New/Existing tests provide coverage for changes

(cherry picked from commit 69a36b8d74d3f59b6309a3a224ab9a6c24d4249e)
---
 conftest.py                                                | 8 ++++----
 .../demos/t3000/falcon40b/tests/test_falcon_end_to_end.py  | 3 +--
 tests/ttnn/distributed/test_multidevice_TG.py              | 4 +---
 3 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/conftest.py b/conftest.py
index 9e94913a18f..05cd3c67536 100644
--- a/conftest.py
+++ b/conftest.py
@@ -342,13 +342,13 @@ def get_devices(request):
     elif "pcie_devices" in request.fixturenames:
         devices = request.getfixturevalue("pcie_devices")
     elif "mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("mesh_device").get_devices()
+        devices = [request.getfixturevalue("mesh_device")]
     elif "n300_mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("n300_mesh_device").get_devices()
+        devices = [request.getfixturevalue("n300_mesh_device")]
     elif "t3k_mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("t3k_mesh_device").get_devices()
+        devices = [request.getfixturevalue("t3k_mesh_device")]
     elif "pcie_mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("pcie_mesh_device").get_devices()
+        devices = [request.getfixturevalue("pcie_mesh_device")]
     else:
         devices = []
     return devices
diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py b/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
index c686d2dda7e..2ea829068a1 100644
--- a/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
+++ b/models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
@@ -526,8 +526,7 @@ def test_FalconCausalLM_end_to_end_with_program_cache(
     model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
     devices = t3k_mesh_device.get_devices()
     # Set async mode
-    for device in devices:
-        device.enable_async(async_mode)
+    t3k_mesh_device.enable_async(async_mode)
     compute_grid_size = devices[0].compute_with_storage_grid_size()
     if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
         pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
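To make the fix concrete outside of the Python tests, here is the same before/after pattern as a minimal C++ sketch. MeshDevice::enable_async and Device::enable_async are named in the commit message; the header path and the get_devices() accessor are assumptions for illustration, not code from this patch.

```cpp
// Minimal sketch of the pattern this patch fixes, assuming the MeshDevice /
// Device interfaces named in the commit message. Header path is an assumption.
#include <tt-metalium/mesh_device.hpp>

void set_async_mode(tt::tt_metal::distributed::MeshDevice& mesh_device, bool async_mode) {
    // Broken pattern: Device::enable_async is now a no-op for individual
    // devices, so this loop silently leaves async mode disabled.
    //
    // for (auto* device : mesh_device.get_devices()) {
    //     device->enable_async(async_mode);
    // }

    // Fixed pattern: one call on the mesh itself, which is honored.
    mesh_device.enable_async(async_mode);
}
```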
diff --git a/tests/ttnn/distributed/test_multidevice_TG.py b/tests/ttnn/distributed/test_multidevice_TG.py
index 6c1c84c5dd9..82b4381c4aa 100644
--- a/tests/ttnn/distributed/test_multidevice_TG.py
+++ b/tests/ttnn/distributed/test_multidevice_TG.py
@@ -1448,9 +1448,7 @@ def test_device_line_all_gather_8x4_data(mesh_device, cluster_axis: int, dim: in
     Every device will have the shape: [4, 1, 32, 32]
     """
     if async_mode:
-        for i in mesh_device.get_device_ids():
-            device = mesh_device.get_device(i)
-            device.enable_async(True)
+        mesh_device.enable_async(True)
 
     (rows, cols), tile_size = mesh_device.shape, 32
     full_tensor = torch.zeros((1, 1, tile_size * rows, tile_size * cols), dtype=torch.bfloat16)

From a62788134410bdf8612a543dfdd558ee51ee952d Mon Sep 17 00:00:00 2001
From: Joseph Chu
Date: Fri, 28 Feb 2025 06:32:26 +0000
Subject: [PATCH 2/3] Enable Caching Mechanism for MeshWorkload

---
 tests/ttnn/unit_tests/test_multi_device.py    |  21 ++
 tt_metal/api/tt-metalium/device.hpp           |   5 +-
 tt_metal/api/tt-metalium/program_cache.hpp    |  74 ++++++-
 tt_metal/api/tt-metalium/unique_any.hpp       |   1 +
 tt_metal/impl/buffers/dispatch.cpp            |   2 +
 tt_metal/impl/device/device.cpp               |   1 +
 tt_metal/impl/event/dispatch.cpp              |   1 +
 .../impl/lightmetal/lightmetal_replay.hpp     |   1 +
 tt_metal/impl/trace/dispatch.cpp              |   2 +-
 ttnn/cpp/ttnn/device_operation.hpp            | 199 +++++++++++++-----
 .../device/moreh_matmul_device_operation.hpp  |   4 +-
 11 files changed, 248 insertions(+), 63 deletions(-)

diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py
index 21e46e4383a..15694cee129 100644
--- a/tests/ttnn/unit_tests/test_multi_device.py
+++ b/tests/ttnn/unit_tests/test_multi_device.py
@@ -247,6 +247,27 @@ def test_multi_device_single_op_unary(mesh_device):
     assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.999)
 
 
+def test_multi_device_single_op_unary_with_cache(mesh_device):
+    """Multidevice API test: Running tensor-parallel multi-device single-op unary with cache"""
+    mesh_device.enable_program_cache()
+
+    torch_input_tensor = torch.rand((1, 1, 32, 32 * mesh_device.get_num_devices()), dtype=torch.bfloat16)
+    torch_output_golden = torch.nn.functional.gelu(torch_input_tensor)
+    torch_golden = torch.nn.functional.gelu(torch_output_golden)
+
+    ttnn_input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        layout=ttnn.TILE_LAYOUT,
+        mesh_mapper=ShardTensorToMesh(mesh_device, dim=3),
+        device=mesh_device,
+    )
+    ttnn_output_tensor = ttnn.gelu(ttnn_input_tensor)
+    final_output_tensor = ttnn.gelu(ttnn_output_tensor)
+
+    ttnn_torch_output_tensor = ttnn.to_torch(final_output_tensor, mesh_composer=ConcatMeshToTensor(mesh_device, dim=3))
+    assert_with_pcc(ttnn_torch_output_tensor, torch_golden, pcc=0.999)
+
+
 @pytest.mark.parametrize(
     "device_params",
     [{"dispatch_core_axis": ttnn.DispatchCoreAxis.ROW}, {"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}],
diff --git a/tt_metal/api/tt-metalium/device.hpp b/tt_metal/api/tt-metalium/device.hpp
index fdc1cbef87d..7719f8de203 100644
--- a/tt_metal/api/tt-metalium/device.hpp
+++ b/tt_metal/api/tt-metalium/device.hpp
@@ -20,11 +20,14 @@
 #include "sub_device_manager.hpp"
 #include "sub_device_types.hpp"
 #include "span.hpp"
-#include "program_cache.hpp"
 
 namespace tt {
 
 namespace tt_metal {
+
+namespace program_cache::detail {
+class ProgramCache;
+}
 /*
     MemoryBlockTable is a list of memory blocks in the following format:
     [{"blockID": "0", "address": "0", "size": "0", "prevID": "0", "nextID": "0", "allocated": true}]
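The device.hpp hunk above swaps a heavy include for a forward declaration. A minimal sketch of why that works, under the assumption that the header only names ProgramCache by reference or pointer (class and member names below are illustrative, not the real ones):

```cpp
// Sketch: a forward declaration suffices when a header only refers to a type
// by pointer or reference; the full definition is needed only in translation
// units that access its members.
namespace tt::tt_metal::program_cache::detail {
class ProgramCache;  // forward declaration: breaks the include dependency
}

class DeviceSketch {
public:
    // Declaring a reference-returning accessor needs only the forward
    // declaration above, not program_cache.hpp itself.
    tt::tt_metal::program_cache::detail::ProgramCache& get_program_cache();
};

// Any .cpp that actually dereferences the cache must now include
// program_cache.hpp directly — which is why the later hunks in this patch
// add includes to dispatch.cpp, device.cpp, and friends.
```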
diff --git a/tt_metal/api/tt-metalium/program_cache.hpp b/tt_metal/api/tt-metalium/program_cache.hpp
index ef45052c61c..3ad96024620 100644
--- a/tt_metal/api/tt-metalium/program_cache.hpp
+++ b/tt_metal/api/tt-metalium/program_cache.hpp
@@ -8,8 +8,18 @@
 #include "program_impl.hpp"
 #include "unique_any.hpp"
+#include "mesh_workload.hpp"
 
 namespace tt::tt_metal::program_cache::detail {
 
+template <typename shared_variables_t>
+struct CachedMeshWorkload {
+    tt::tt_metal::distributed::MeshWorkload workload;
+    // Shared variables between create and override_runtime_arguments functions
+    shared_variables_t shared_variables;
+
+    CachedMeshWorkload(tt::tt_metal::distributed::MeshWorkload&& workload, shared_variables_t&& shared_variables) :
+        workload{std::move(workload)}, shared_variables{std::forward<shared_variables_t>(shared_variables)} {}
+};
 
 template <typename shared_variables_t>
 struct CachedProgram {
@@ -21,6 +31,68 @@ struct CachedProgram {
         program{std::move(program)}, shared_variables{std::forward<shared_variables_t>(shared_variables)} {}
 };
 
+// Adapter that provides a unified interface for both CachedProgram and CachedMeshWorkload
+template <typename shared_variables_t>
+class ProgramAdapter {
+private:
+    using CachedObject = std::variant<CachedMeshWorkload<shared_variables_t>, CachedProgram<shared_variables_t>>;
+    CachedObject cached_object_;
+
+    // Helper to retrieve the first program from a mesh workload
+    static tt::tt_metal::Program& get_first_program(CachedMeshWorkload<shared_variables_t>& cached_mesh_workload) {
+        // Get the programs map from the workload
+        auto& programs = cached_mesh_workload.workload.get_programs();
+
+        // There must be at least one program in the workload
+        TT_FATAL(!programs.empty(), "Mesh workload must have at least one program");
+
+        // Return the first program in the workload
+        auto& first_program_pair = *programs.begin();
+        return first_program_pair.second;
+    }
+
+public:
+    // These are references to the original objects
+    tt::tt_metal::Program& program;
+    shared_variables_t& shared_variables;
+
+    // Constructor for CachedProgram
+    ProgramAdapter(CachedProgram<shared_variables_t>&& cached_program) :
+        cached_object_(std::move(cached_program)),
+        program(std::get<CachedProgram<shared_variables_t>>(cached_object_).program),
+        shared_variables(std::get<CachedProgram<shared_variables_t>>(cached_object_).shared_variables) {}
+
+    // Constructor for Program and shared variables
+    ProgramAdapter(tt::tt_metal::Program&& program, shared_variables_t&& shared_vars) :
+        ProgramAdapter(CachedProgram<shared_variables_t>{std::move(program), std::move(shared_vars)}) {}
+
+    // Constructor for CachedMeshWorkload
+    ProgramAdapter(CachedMeshWorkload<shared_variables_t>&& cached_mesh_workload) :
+        cached_object_(std::move(cached_mesh_workload)),
+        program(get_first_program(std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_))),
+        shared_variables(std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_).shared_variables) {}
+
+    ProgramAdapter(ProgramAdapter&& other) noexcept :
+        cached_object_{std::move(other.cached_object_)},
+        program{
+            (cached_object_.index() == 0)
+                ? get_first_program(std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_))
+                : std::get<CachedProgram<shared_variables_t>>(cached_object_).program},
+        shared_variables{
+            (cached_object_.index() == 0)
+                ? std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_).shared_variables
+                : std::get<CachedProgram<shared_variables_t>>(cached_object_).shared_variables} {}
+
+    // Get the CachedMeshWorkload (throws if not a mesh workload)
+    CachedMeshWorkload<shared_variables_t>& get_cached_mesh_workload() {
+        return std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_);
+    }
+
+    // Get the CachedProgram (throws if not a program)
+    CachedProgram<shared_variables_t>& get_cached_program() {
+        return std::get<CachedProgram<shared_variables_t>>(cached_object_);
+    }
+};
+
 struct CachedProgramFactory {
     static constexpr auto MAX_SIZE = 4096;
     static constexpr auto ALIGNMENT = 32;
@@ -30,7 +102,7 @@ struct CachedProgramFactory {
     std::size_t program_factory_index;
 
     template <typename shared_variables_t>
-    CachedProgramFactory(CachedProgram<shared_variables_t>&& cached_program, std::size_t program_factory_index) :
+    CachedProgramFactory(ProgramAdapter<shared_variables_t>&& cached_program, std::size_t program_factory_index) :
         cached_program{std::move(cached_program)}, program_factory_index{program_factory_index} {}
 };
diff --git a/tt_metal/api/tt-metalium/unique_any.hpp b/tt_metal/api/tt-metalium/unique_any.hpp
index 01854f0fdbb..6490fbc52af 100644
--- a/tt_metal/api/tt-metalium/unique_any.hpp
+++ b/tt_metal/api/tt-metalium/unique_any.hpp
@@ -54,6 +54,7 @@ struct unique_any final {
             }
             this->delete_storage = other.delete_storage;
             this->move_storage = other.move_storage;
+            other.pointer = nullptr;
         }
         return *this;
     }
diff --git a/tt_metal/impl/buffers/dispatch.cpp b/tt_metal/impl/buffers/dispatch.cpp
index f1de42f22e9..396a6d8ea86 100644
--- a/tt_metal/impl/buffers/dispatch.cpp
+++ b/tt_metal/impl/buffers/dispatch.cpp
@@ -7,7 +7,9 @@
 #include "assert.hpp"
 #include "dispatch.hpp"
 #include
+#include
 #include
+#include
 
 #include "tt_cluster.hpp"
diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp
index 71cae1833ab..c8ae50d1e51 100644
--- a/tt_metal/impl/device/device.cpp
+++ b/tt_metal/impl/device/device.cpp
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include
 
 #include "impl/dispatch/topology.hpp"
 #include "impl/dispatch/hardware_command_queue.hpp"
diff --git a/tt_metal/impl/event/dispatch.cpp b/tt_metal/impl/event/dispatch.cpp
index dad0f24cb7e..689d6e0e495 100644
--- a/tt_metal/impl/event/dispatch.cpp
+++ b/tt_metal/impl/event/dispatch.cpp
@@ -4,6 +4,7 @@
 #include "tt_metal/impl/event/dispatch.hpp"
 #include
+#include
 #include "tt_metal/impl/dispatch/dispatch_query_manager.hpp"
 #include
diff --git a/tt_metal/impl/lightmetal/lightmetal_replay.hpp b/tt_metal/impl/lightmetal/lightmetal_replay.hpp
index 87e94bb2c86..ea2cdd99385 100644
--- a/tt_metal/impl/lightmetal/lightmetal_replay.hpp
+++ b/tt_metal/impl/lightmetal/lightmetal_replay.hpp
@@ -10,6 +10,7 @@
 #include
 #include
+#include
 #include
 
 // Forward decl for trace_buffer.hpp
diff --git a/tt_metal/impl/trace/dispatch.cpp b/tt_metal/impl/trace/dispatch.cpp
index 19d08460004..fb4c983c1f3 100644
--- a/tt_metal/impl/trace/dispatch.cpp
+++ b/tt_metal/impl/trace/dispatch.cpp
@@ -4,7 +4,7 @@
 #include "tt_metal/impl/trace/dispatch.hpp"
 #include "tt_metal/impl/dispatch/dispatch_query_manager.hpp"
-
+#include
 namespace tt::tt_metal::trace_dispatch {
 
 void reset_host_dispatch_state_for_trace(
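Since ProgramAdapter is the pivot of this patch, here is a minimal usage sketch built only from the declarations in the program_cache.hpp hunk above. SharedVars, the function names, and the literal arguments are illustrative placeholders, not code from the repository:

```cpp
// Usage sketch for ProgramAdapter, assuming the declarations above.
// SharedVars stands in for a program factory's shared_variables_t.
struct SharedVars {
    int num_cores = 0;
};

using Adapter = tt::tt_metal::program_cache::detail::ProgramAdapter<SharedVars>;

// Single-program path: wrap a Program plus its shared variables.
Adapter make_program_adapter(tt::tt_metal::Program&& program) {
    return Adapter{std::move(program), SharedVars{/*num_cores=*/8}};
}

// Mesh path: wrap a CachedMeshWorkload. adapter.program then aliases the
// first program inside the workload, so code written against the old
// CachedProgram interface keeps working unchanged.
Adapter make_mesh_adapter(tt::tt_metal::distributed::MeshWorkload&& workload) {
    tt::tt_metal::program_cache::detail::CachedMeshWorkload<SharedVars> cached{
        std::move(workload), SharedVars{/*num_cores=*/8}};
    return Adapter{std::move(cached)};
}
```

The design choice here is that both cache entry kinds expose the same two members, `program` and `shared_variables`, so the per-operation program factories do not need to know whether they are backed by a single Program or a MeshWorkload.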
diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp
index a139f7899d7..5cb27cf8572 100644
--- a/ttnn/cpp/ttnn/device_operation.hpp
+++ b/ttnn/cpp/ttnn/device_operation.hpp
@@ -23,7 +23,7 @@ namespace ttnn {
 namespace device_operation {
 
 template <typename shared_variables_t>
-using CachedProgram = tt::tt_metal::program_cache::detail::CachedProgram<shared_variables_t>;
+using CachedProgram = tt::tt_metal::program_cache::detail::ProgramAdapter<shared_variables_t>;
 
 template <typename program_factory_t>
 concept ProgramFactoryConcept = requires {
@@ -132,7 +132,8 @@ inline auto& create_or_get_program_from_cache(
         program_cache.insert(
             program_hash,
             tt::tt_metal::program_cache::detail::CachedProgramFactory{
-                program_factory_t::create(operation_attributes, tensor_args, tensor_return_value),
+                tt::tt_metal::program_cache::detail::ProgramAdapter(
+                    program_factory_t::create(operation_attributes, tensor_args, tensor_return_value)),
                 program_factory_index});
         auto& cached_program_factory = program_cache.get(program_hash);
         auto& cached_program = cached_program_factory.cached_program.template get<cached_program_t>();
@@ -168,6 +169,119 @@ inline auto& create_or_get_program_from_cache(
     }
 }
 
+template <typename device_operation_t>
+inline auto& create_or_get_meshworkload_from_cache(
+    auto& program_cache,
+    auto program_cache_hit,
+    auto program_hash,
+    const typename device_operation_t::operation_attributes_t& operation_attributes,
+    const typename device_operation_t::tensor_args_t& tensor_args,
+    typename device_operation_t::tensor_return_value_t& tensor_return_value,
+    tt::tt_metal::distributed::MeshDevice* mesh_device,
+    uint64_t device_operation_id) {
+    if (!program_cache_hit) {
+        tt::log_info("CACHE MISS: Creating mesh workload from cache");
+        auto program_factory = device_operation_t::select_program_factory(operation_attributes, tensor_args);
+        auto program_factory_index = program_factory.index();
+
+        auto& mesh_workload = std::visit(
+            [&program_factory_index,
+             &program_hash,
+             &program_cache,
+             &operation_attributes,
+             &tensor_args,
+             &tensor_return_value,
+             &device_operation_id,
+             &mesh_device](auto&& program_factory) -> auto& {
+                using program_factory_t = std::decay_t<decltype(program_factory)>;
+                using cached_program_t =
+                    decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
+
+                // Create a cached program (contains both program and shared variables)
+                auto cached_program = program_factory_t::create(operation_attributes, tensor_args, tensor_return_value);
+
+                // Set runtime ID here before moving the program
+                cached_program.program.set_runtime_id(device_operation_id);
+
+                // Create a new mesh workload
+                auto mesh_workload = tt::tt_metal::distributed::CreateMeshWorkload();
+
+                // Move the program from the cached_program into the mesh workload
+                tt::tt_metal::distributed::AddProgramToMeshWorkload(
+                    mesh_workload,
+                    std::move(cached_program.program),  // Move the program
+                    tt::tt_metal::distributed::MeshCoordinateRange(
+                        {0, 0}, {mesh_device->num_rows() - 1, mesh_device->num_cols() - 1}));
+
+                // Create a cached mesh workload with the mesh workload and shared variables
+                auto cached_mesh_workload = tt::tt_metal::program_cache::detail::CachedMeshWorkload<
+                    typename program_factory_t::shared_variables_t>(
+                    std::move(mesh_workload), std::move(cached_program.shared_variables));
+
+                // Create a program adapter to wrap the cached mesh workload
+                tt::tt_metal::program_cache::detail::ProgramAdapter<typename program_factory_t::shared_variables_t>
+                    adapter(std::move(cached_mesh_workload));
+
+                // Insert the cached program factory into the cache
+                program_cache.insert(
+                    program_hash,
+                    tt::tt_metal::program_cache::detail::CachedProgramFactory{
+                        std::move(adapter), program_factory_index});
+
+                // Return the mesh workload from the cached factory
+                auto& cached_program_factory = program_cache.get(program_hash);
+                // Get the program adapter from the cached factory
+                auto& cached_adapter =
+                    cached_program_factory.cached_program
+                        .template get<tt::tt_metal::program_cache::detail::ProgramAdapter<
+                            typename program_factory_t::shared_variables_t>>();
+                return cached_adapter.get_cached_mesh_workload().workload;
+            },
+            program_factory);
+        return mesh_workload;
+    } else {
+        tt::log_info("CACHE HIT: Creating mesh workload from cache");
+        auto& cached_program_factory = program_cache.get(program_hash);
+        auto program_factory_index = cached_program_factory.program_factory_index;
+
+        // Reconstruct the program factory variant based on the stored index
+        using program_factory_variant_t =
+            decltype(device_operation_t::select_program_factory(operation_attributes, tensor_args));
+        auto program_factory = map_index_to_variant(program_factory_index, program_factory_variant_t{});
+
+        // Use std::visit to override runtime arguments using the selected factory
+        auto& mesh_workload = std::visit(
+            [&program_factory_index,
+             &operation_attributes,
+             &tensor_args,
+             &tensor_return_value,
+             &device_operation_id,
+             &mesh_device,
+             &cached_program_factory](auto&& program_factory) -> auto& {
+                using program_factory_t = std::decay_t<decltype(program_factory)>;
+                using cached_program_t =
+                    decltype(program_factory_t::create(operation_attributes, tensor_args, tensor_return_value));
+
+                // Get the program adapter from the cached factory
+                auto& adapter =
+                    cached_program_factory.cached_program
+                        .template get<tt::tt_metal::program_cache::detail::ProgramAdapter<
+                            typename program_factory_t::shared_variables_t>>();
+
+                // Override runtime arguments through the adapter
+                program_factory_t::override_runtime_arguments(
+                    adapter, operation_attributes, tensor_args, tensor_return_value);
+
+                adapter.program.set_runtime_id(device_operation_id);
+
+                tt::tt_metal::GraphTracker::instance().track_program(&adapter.program, mesh_device);
+
+                // Return the mesh workload from the cached factory
+                return adapter.get_cached_mesh_workload().workload;
+            },
+            program_factory);
+        return mesh_workload;
+    }
+}
+
 struct CheckDeviceBufferIsAllocated {
     std::size_t index = 0;
 
@@ -259,8 +373,7 @@ void launch_on_mesh_device(
     auto program_hash = 0;
     bool program_cache_hit = false;
 
-    // auto is_program_cache_enabled = program_cache.is_enabled();
-    auto is_program_cache_enabled = false;
+    auto is_program_cache_enabled = program_cache.is_enabled();
     if (is_program_cache_enabled) {
         program_hash = compute_program_hash<device_operation_t>(operation_attributes, tensor_args);
         program_cache_hit = program_cache.contains(program_hash);
@@ -279,38 +392,19 @@ void launch_on_mesh_device(
         device_operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args);
     }
 
-    const auto enqueue_or_launch_program = [=](tt::tt_metal::Program& program) {
-        if (USE_FAST_DISPATCH) {
-            ZoneScopedN("EnqueueProgram");
-            auto& queue = device->command_queue(*cq_id);
-            tt::tt_metal::EnqueueProgram(queue, program, false);
-        } else {
-            ZoneScopedN("LaunchProgram");
-            tt::tt_metal::detail::LaunchProgram(device, program);
-        }
-    };
-
     if (is_program_cache_enabled) {
-        auto& program = create_or_get_program_from_cache<device_operation_t>(
-            program_cache, program_cache_hit, program_hash, operation_attributes, tensor_args, tensor_return_value);
-
-        program.set_runtime_id(device_operation_id);
-
-        tt::tt_metal::GraphTracker::instance().track_program(&program, device);
-        if (tt::tt_metal::GraphTracker::instance().hook_program(&program)) {
-            return;
-        }
-
-        enqueue_or_launch_program(program);
-
-        TracyOpTTNNDevice(
-            device_operation_t{},
-            device_operation_id,
-            device->id(),
-            program,
+        tt::log_info("Creating mesh workload from cache");
+        auto& mesh_workload = create_or_get_meshworkload_from_cache<device_operation_t>(
+            program_cache,
+            program_cache_hit,
+            program_hash,
             operation_attributes,
             tensor_args,
-            tensor_return_value);
+            tensor_return_value,
+            device,
+            device_operation_id);
+
+        tt::tt_metal::distributed::EnqueueMeshWorkload(device->mesh_command_queue(), mesh_workload, true);
     } else {
         auto program_factory = device_operation_t::select_program_factory(operation_attributes, tensor_args);
 
@@ -318,8 +412,8 @@ void launch_on_mesh_device(
         auto program = std::visit(
             [&](auto&& program_factory) {
                 using program_factory_t = std::decay_t<decltype(program_factory)>;
-                return std::make_shared<tt::tt_metal::Program>(
-                    program_factory_t::create(operation_attributes, tensor_args, tensor_return_value).program);
+                auto cached_program = program_factory_t::create(operation_attributes, tensor_args, tensor_return_value);
+                return std::make_shared<tt::tt_metal::Program>(std::move(cached_program.program));
             },
             program_factory);
 
@@ -329,27 +423,16 @@ void launch_on_mesh_device(
         if (tt::tt_metal::GraphTracker::instance().hook_program(program.get())) {
             return;
         }
-        if (auto mesh_device = dynamic_cast<tt::tt_metal::distributed::MeshDevice*>(device); mesh_device != nullptr) {
-            auto& cq = mesh_device->mesh_command_queue();
-            auto mesh_workload = tt::tt_metal::distributed::CreateMeshWorkload();
-            tt::tt_metal::distributed::AddProgramToMeshWorkload(
-                mesh_workload,
-                std::move(*program),
-                tt::tt_metal::distributed::MeshCoordinateRange(
-                    {0, 0}, {mesh_device->num_rows() - 1, mesh_device->num_cols() - 1}));
-            tt::tt_metal::distributed::EnqueueMeshWorkload(cq, mesh_workload, true);
-        } else {
-            enqueue_or_launch_program(*program);
-
-            TracyOpTTNNDevice(
-                device_operation_t{},
-                device_operation_id,
-                device->id(),
-                *program,
-                operation_attributes,
-                tensor_args,
-                tensor_return_value);
-        }
+        auto mesh_device = dynamic_cast<tt::tt_metal::distributed::MeshDevice*>(device);
+        TT_FATAL(mesh_device != nullptr, "Device is not a MeshDevice");
+        auto& cq = mesh_device->mesh_command_queue();
+        auto mesh_workload = tt::tt_metal::distributed::CreateMeshWorkload();
+        tt::tt_metal::distributed::AddProgramToMeshWorkload(
+            mesh_workload,
+            std::move(*program),
+            tt::tt_metal::distributed::MeshCoordinateRange(
+                {0, 0}, {mesh_device->num_rows() - 1, mesh_device->num_cols() - 1}));
+        tt::tt_metal::distributed::EnqueueMeshWorkload(cq, mesh_workload, true);
     }
 }
 
@@ -432,8 +515,8 @@ void launch_on_worker_thread(
     auto program = std::visit(
         [&](auto&& program_factory) {
             using program_factory_t = std::decay_t<decltype(program_factory)>;
-            return std::make_shared<tt::tt_metal::Program>(
-                program_factory_t::create(operation_attributes, tensor_args, tensor_return_value).program);
+            auto cached_program = program_factory_t::create(operation_attributes, tensor_args, tensor_return_value);
+            return std::make_shared<tt::tt_metal::Program>(std::move(cached_program.program));
         },
         program_factory);
 
diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp
index e61e5cf51c8..c568179ed42 100644
--- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp
@@ -34,14 +34,14 @@ struct MorehMatmulOperation {
     using tensor_return_value_t = Tensor;
 
     struct MultiCoreProgramFactory {
-        struct shared_variable_t {
+        struct shared_variables_t {
             KernelHandle reader_kernel_id;
             KernelHandle writer_kernel_id;
             std::size_t num_cores;
             std::size_t num_cores_y;
         };
 
-        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variable_t>;
+        using cached_program_t = ttnn::device_operation::CachedProgram<shared_variables_t>;
 
         static cached_program_t create(
             const operation_attributes_t& operation_attributes,
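The moreh_matmul rename above shows the factory-side contract implied by this patch: on a cache hit, device_operation.hpp calls `override_runtime_arguments(adapter, ...)`, so a factory only ever sees the CachedProgram-style interface, whether it is backed by a Program or by a MeshWorkload. A sketch of such a member, modeled on MultiCoreProgramFactory above (the body and the kernel-arg update are illustrative, not the real implementation):

```cpp
// Inside a program factory such as MorehMatmulOperation::MultiCoreProgramFactory:
static void override_runtime_arguments(
    cached_program_t& cached_program,
    const operation_attributes_t& operation_attributes,
    const tensor_args_t& tensor_args,
    tensor_return_value_t& tensor_return_value) {
    // cached_program is now a ProgramAdapter. cached_program.program aliases
    // either the standalone Program or the first program of the cached
    // MeshWorkload; either way the factory updates runtime arguments through
    // the same reference, using the handles it stashed in shared_variables
    // (reader_kernel_id, writer_kernel_id, num_cores, num_cores_y) at create
    // time.
    auto& shared = cached_program.shared_variables;
    (void)shared;  // illustrative: real code would call GetRuntimeArgs here
}
```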
From ad8be43a9d37371f78435c94e9eaa82157562288 Mon Sep 17 00:00:00 2001
From: asaigal
Date: Fri, 28 Feb 2025 17:14:12 +0000
Subject: [PATCH 3/3] #0: Remove TTNN synchronization/blocking when executing
 workloads

---
 ttnn/cpp/ttnn/device_operation.hpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp
index 5cb27cf8572..74f2512d40e 100644
--- a/ttnn/cpp/ttnn/device_operation.hpp
+++ b/ttnn/cpp/ttnn/device_operation.hpp
@@ -404,7 +404,7 @@ void launch_on_mesh_device(
             device,
             device_operation_id);
 
-        tt::tt_metal::distributed::EnqueueMeshWorkload(device->mesh_command_queue(), mesh_workload, true);
+        tt::tt_metal::distributed::EnqueueMeshWorkload(device->mesh_command_queue(), mesh_workload, false);
     } else {
         auto program_factory = device_operation_t::select_program_factory(operation_attributes, tensor_args);
 
@@ -432,7 +432,7 @@ void launch_on_mesh_device(
             std::move(*program),
             tt::tt_metal::distributed::MeshCoordinateRange(
                 {0, 0}, {mesh_device->num_rows() - 1, mesh_device->num_cols() - 1}));
-        tt::tt_metal::distributed::EnqueueMeshWorkload(cq, mesh_workload, true);
+        tt::tt_metal::distributed::EnqueueMeshWorkload(cq, mesh_workload, false);
     }
 }
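Patch 3 flips the blocking flag of EnqueueMeshWorkload from true to false, so the host no longer waits for each workload to finish before returning. A minimal sketch of the resulting execution pattern, using only calls that appear in the diffs above; the Finish-style barrier in the final comment is an assumed synchronization API, named by analogy with single-device command queues, not something this patch adds:

```cpp
// Sketch of the non-blocking enqueue pattern introduced by patch 3.
void run_workload(tt::tt_metal::distributed::MeshDevice* mesh_device, tt::tt_metal::Program&& program) {
    auto& cq = mesh_device->mesh_command_queue();

    auto mesh_workload = tt::tt_metal::distributed::CreateMeshWorkload();
    tt::tt_metal::distributed::AddProgramToMeshWorkload(
        mesh_workload,
        std::move(program),
        tt::tt_metal::distributed::MeshCoordinateRange(
            {0, 0}, {mesh_device->num_rows() - 1, mesh_device->num_cols() - 1}));

    // blocking=false: the host thread returns immediately instead of waiting
    // for the workload to complete on every device in the mesh.
    tt::tt_metal::distributed::EnqueueMeshWorkload(cq, mesh_workload, /*blocking=*/false);

    // Completion is deferred to whatever later operation needs the results:
    // commands on the same queue execute in order, so a blocking readback or
    // an explicit barrier (assumed API) provides the synchronization point:
    // tt::tt_metal::distributed::Finish(cq);
}
```

Correctness of the non-blocking mode relies on ordering within a single mesh command queue: later commands observe the effects of earlier ones without the host having to block in between.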