From 4e6102b3be6dd099d0c5180c4c4a9dd7f04b9095 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Sun, 8 Dec 2024 01:18:58 +0000 Subject: [PATCH 1/2] #15836: Update reads, writes, and synchronize ttnn apis to take in sub device ids --- tests/ttnn/unit_tests/test_sub_device.py | 42 +++++++- ttnn/cpp/pybind11/device.cpp | 23 +++- ttnn/cpp/pybind11/events.cpp | 8 +- ttnn/cpp/pybind11/operations/core.hpp | 31 ++++-- ttnn/cpp/pybind11/pytensor.cpp | 35 +++++- ttnn/cpp/ttnn/events.cpp | 11 +- ttnn/cpp/ttnn/events.hpp | 5 +- ttnn/cpp/ttnn/operations/core/core.cpp | 36 +++++-- ttnn/cpp/ttnn/operations/core/core.hpp | 26 ++++- .../cpp/ttnn/operations/reduction/moe/moe.cpp | 3 +- ttnn/cpp/ttnn/tensor/tensor.cpp | 41 ++++--- ttnn/cpp/ttnn/tensor/tensor.hpp | 27 +++-- ttnn/cpp/ttnn/tensor/tensor_impl.cpp | 102 +++++++++++------- ttnn/cpp/ttnn/tensor/tensor_impl.hpp | 17 ++- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 64 +++++++---- ttnn/cpp/ttnn/tensor/tensor_ops.hpp | 17 ++- ttnn/ttnn/__init__.py | 2 + ttnn/ttnn/device.py | 3 + ttnn/ttnn/distributed/distributed.py | 12 ++- ttnn/ttnn/operations/core.py | 14 ++- 20 files changed, 378 insertions(+), 141 deletions(-) diff --git a/tests/ttnn/unit_tests/test_sub_device.py b/tests/ttnn/unit_tests/test_sub_device.py index 7d3f93797a7..be2c8174870 100644 --- a/tests/ttnn/unit_tests/test_sub_device.py +++ b/tests/ttnn/unit_tests/test_sub_device.py @@ -29,7 +29,11 @@ def run_sub_devices(device): sub_device_manager1 = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200) sub_device_manager2 = device.create_sub_device_manager([sub_device_2], 3200) device.load_sub_device_manager(sub_device_manager1) + ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(1)]) + ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0), ttnn.SubDeviceId(1)]) + ttnn.synchronize_devices(device) device.load_sub_device_manager(sub_device_manager2) + ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0)]) device.clear_loaded_sub_device_manager() device.remove_sub_device_manager(sub_device_manager1) device.remove_sub_device_manager(sub_device_manager2) @@ -48,16 +52,16 @@ def run_sub_devices_program(device): tensix_cores0 = ttnn.CoreRangeSet( { ttnn.CoreRange( - ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(3, 3), + ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(4, 4), ), } ) tensix_cores1 = ttnn.CoreRangeSet( { ttnn.CoreRange( - ttnn.CoreCoord(4, 4), - ttnn.CoreCoord(4, 4), + ttnn.CoreCoord(0, 0), + ttnn.CoreCoord(3, 3), ), } ) @@ -74,8 +78,19 @@ def run_sub_devices_program(device): device=device, memory_config=ttnn.L1_MEMORY_CONFIG, mesh_mapper=inputs_mesh_mapper, + sub_device_ids=[ttnn.SubDeviceId(0)], ) + xt_host = ttnn.from_torch( + x, + dtype=ttnn.bfloat16, + layout=ttnn.TILE_LAYOUT, + mesh_mapper=inputs_mesh_mapper, + sub_device_ids=[ttnn.SubDeviceId(1)], + ) + + ttnn.copy_host_to_device_tensor(xt_host, xt, sub_device_ids=[ttnn.SubDeviceId(1)]) + grid_size = device.compute_with_storage_grid_size() shard_size = [32, 64] shard_scheme = ttnn.TensorMemoryLayout.HEIGHT_SHARDED @@ -83,11 +98,28 @@ def run_sub_devices_program(device): yt = ttnn.interleaved_to_sharded( xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16 ) - y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer) + y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer, sub_device_ids=[ttnn.SubDeviceId(1)]) + + eq = torch.equal(x, y) + assert eq + + y = ttnn.to_torch(yt.cpu(sub_device_ids=[ttnn.SubDeviceId(0)]), 
mesh_composer=output_mesh_composer) eq = torch.equal(x, y) assert eq + event = ttnn.create_event(device) + + yt2 = ttnn.interleaved_to_sharded( + xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16 + ) + ttnn.record_event(0, event, [ttnn.SubDeviceId(1)]) + ttnn.wait_for_event(0, event) + y2 = ttnn.to_torch(yt2, device=device, mesh_composer=output_mesh_composer, sub_device_ids=[ttnn.SubDeviceId(0)]) + + eq = torch.equal(x, y2) + assert eq + device.clear_loaded_sub_device_manager() device.remove_sub_device_manager(sub_device_manager) diff --git a/ttnn/cpp/pybind11/device.cpp b/ttnn/cpp/pybind11/device.cpp index b60a36ed7ad..1196c1f08e0 100644 --- a/ttnn/cpp/pybind11/device.cpp +++ b/ttnn/cpp/pybind11/device.cpp @@ -100,6 +100,8 @@ void py_device_module_types(py::module& m_device) { py::class_(m_device, "SubDevice", "Class describing a sub-device of a Tenstorrent accelerator device."); + py::class_(m_device, "SubDeviceId", "ID of a sub-device."); + py::class_(m_device, "SubDeviceManagerId", "ID of a sub-device manager."); } @@ -114,6 +116,14 @@ void device_module(py::module& m_device) { The order of cores is Tensix, then Ethernet. )doc"); + auto pySubDeviceId = static_cast>(m_device.attr("SubDeviceId")); + pySubDeviceId.def( + py::init(), + py::arg("id"), + R"doc( + Creates a SubDeviceId object with the given ID. + )doc"); + auto pyDevice = static_cast>>(m_device.attr("Device")); pyDevice .def( @@ -482,10 +492,11 @@ void device_module(py::module& m_device) { m_device.def( "synchronize_device", - [](Device* device, const std::optional cq_id) { + [](Device* device, const std::optional cq_id, const std::vector& sub_device_ids) { // Send finish command to issue queue through worker thread // Worker thread will stall until the device is flushed. - device->push_work([device, cq_id]() mutable { Synchronize(device, cq_id); }); + device->push_work( + [device, cq_id, &sub_device_ids]() mutable { Synchronize(device, cq_id, sub_device_ids); }); // Main thread stalls until worker is complete (full device and worker queue flush). device->synchronize(); }, @@ -493,10 +504,13 @@ void device_module(py::module& m_device) { Synchronize the device with host by waiting for all operations to complete. If cq_id is provided then only the operations associated with that cq_id are waited for, otherwise operations for all command queues are waited on. + If the device has been configured with sub-devices, then sub_device_ids can be provided to only wait + for the operations that ran on the specified sub-devices, otherwise all sub-devices (the entire chip) are waited on. Args: device (ttnn.device.Device): The device to synchronize with. cq_id (int, optional): The command queue ID to synchronize. Defaults to `None`. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to synchronize. Defaults to all sub-devices. Returns: `None`: The op ensures that all operations are completed. @@ -508,7 +522,8 @@ void device_module(py::module& m_device) { >>> ttnn.synchronize_device(device) )doc", py::arg("device"), - py::arg("cq_id") = std::nullopt); + py::arg("cq_id") = std::nullopt, + py::arg("sub_device_ids") = std::vector()); m_device.def("SetLazyCommandQueueMode", &tt::tt_metal::detail::SetLazyCommandQueueMode, R"doc( If set to true, the host does not notify the device that there are commands available other than the FinishCommand. 
Once set to false, all subsequent commands will immediately notify the device @@ -527,6 +542,8 @@ void device_module(py::module& m_device) { m_device.attr("DEFAULT_L1_SMALL_SIZE") = py::int_(DEFAULT_L1_SMALL_SIZE); m_device.attr("DEFAULT_TRACE_REGION_SIZE") = py::int_(DEFAULT_TRACE_REGION_SIZE); + + m_device.attr("DefaultQueueId") = ttnn::DefaultQueueId; } void py_device_module(py::module& module) { diff --git a/ttnn/cpp/pybind11/events.cpp b/ttnn/cpp/pybind11/events.cpp index fdb12668f63..4ce6d41e644 100644 --- a/ttnn/cpp/pybind11/events.cpp +++ b/ttnn/cpp/pybind11/events.cpp @@ -31,15 +31,17 @@ void py_module(py::module& module) { module.def( "record_event", - py::overload_cast&>(&record_event), + py::overload_cast&, const std::vector&>(&record_event), py::arg("cq_id"), py::arg("event"), + py::arg("sub_device_ids") = std::vector(), R"doc( Record the completion of commands on this CQ, preceeding this call. Args: cq_id (int): The Command Queue on which event completion will be recorded. event (event): The event used to record completion of preceeding commands. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to record completion for. Defaults to all sub-devices. )doc"); module.def( @@ -69,9 +71,10 @@ void py_module(py::module& module) { module.def( "record_event", - py::overload_cast(&record_event), + py::overload_cast&>(&record_event), py::arg("cq_id"), py::arg("multi_device_event"), + py::arg("sub_device_ids") = std::vector(), R"doc( Record the completion of commands on this CQ, preceeding this call. @@ -91,6 +94,7 @@ void py_module(py::module& module) { Args: cq_id (int): The Command Queue on which event completion will be recorded. event (event): The event used to record completion of preceeding commands. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to record completion for. Defaults to all sub-devices. )doc"); } diff --git a/ttnn/cpp/pybind11/operations/core.hpp b/ttnn/cpp/pybind11/operations/core.hpp index 489ee330c5e..db8a7a1970c 100644 --- a/ttnn/cpp/pybind11/operations/core.hpp +++ b/ttnn/cpp/pybind11/operations/core.hpp @@ -65,19 +65,31 @@ void py_module(py::module& module) { module.def( "to_device", - py::overload_cast&>( - &ttnn::operations::core::to_device), + py::overload_cast< + const ttnn::Tensor&, + Device*, + const std::optional&, + uint8_t, + const std::vector&>(&ttnn::operations::core::to_device), py::arg("tensor"), py::arg("device"), - py::arg("memory_config") = std::nullopt); + py::arg("memory_config") = std::nullopt, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector()); module.def( "to_device", - py::overload_cast&>( - &ttnn::operations::core::to_device), + py::overload_cast< + const ttnn::Tensor&, + MeshDevice*, + const std::optional&, + uint8_t, + const std::vector&>(&ttnn::operations::core::to_device), py::arg("tensor"), py::arg("device"), py::arg("memory_config") = std::nullopt, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), R"doc( Copy tensor from host to device. @@ -85,6 +97,9 @@ void py_module(py::module& module) { tensor (ttnn.Tensor): The tensor to be copied from host to device. device (ttnn.Device | ttnn.MeshDevice): The target device where the tensor will be copied. memory_config (ttnn.MemoryConfig, optional): The memory configuration to use. Defaults to `None`. + cq_id (int, optional): The command queue ID to use. Defaults to `0`. 
+ sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to wait on before writing the tensor to device memory. + If it is not provided, device will stall for all programs of the specified cq to finish before writing the tensor to device memory. Returns: ttnn.Tensor: The device tensor copy. @@ -103,6 +118,7 @@ void py_module(py::module& module) { py::arg("blocking") = true, py::kw_only(), py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), R"doc( Copy tensor from device to host. @@ -112,6 +128,8 @@ void py_module(py::module& module) { Keyword args: cq_id (int, optional): the command queue ID to use. Defaults to `0`. + sub_device_ids (List[ttnn.SubDeviceId], optional): the sub-device IDs to wait on before reading the tensor from device memory. + If it is not provided, device will stall for all programs of the specified cq to finish before reading the tensor from device memory. Returns: ttnn.Tensor: the host tensor copy. @@ -243,7 +261,8 @@ void py_module(py::module& module) { &ttnn::operations::core::copy_host_to_device_tensor, py::arg("host_tensor"), py::arg("device_tensor"), - py::arg("cq_id") = ttnn::DefaultQueueId); + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector()); module.def( "begin_trace_capture", diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 8cd2e3da094..48a360fb3cb 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -919,15 +919,22 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast&>( + &Tensor::to, py::const_), py::arg("device").noconvert(), py::arg("mem_config").noconvert() = MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED}, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), py::keep_alive<0, 2>(), R"doc( Move TT Tensor from host device to TT accelerator device. Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device. + ``sub_device_ids`` can be used to specify which specific sub devices to wait on before writing the tensor to device memory. + + If it is not provided, device will stall for all programs of the specified cq to finish before writing the tensor to device memory. + If ``arg1`` is not supplied, default ``MemoryConfig`` with ``interleaved`` set to ``True``. +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ @@ -937,6 +944,10 @@ void pytensor_module(py::module& m_tensor) { +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ | arg1 | MemoryConfig of tensor of TT accelerator device | ttnn.MemoryConfig | | No | +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg2 | CQ ID of TT accelerator device to use | uint8_t | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg3 | Sub device IDs to wait on before writing tensor | List[ttnn.SubDeviceId] | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ .. 
code-block:: python @@ -950,15 +961,22 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "to", - py::overload_cast(&Tensor::to, py::const_), + py::overload_cast&>( + &Tensor::to, py::const_), py::arg("mesh_device").noconvert(), py::arg("mem_config").noconvert() = MemoryConfig{.memory_layout = TensorMemoryLayout::INTERLEAVED}, + py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), py::keep_alive<0, 2>(), R"doc( Move TT Tensor from host device to TT accelerator device. Only BFLOAT16 (in ROW_MAJOR or TILE layout) and BFLOAT8_B, BFLOAT4_B (in TILE layout) are supported on device. + ``sub_device_ids`` can be used to specify which specific sub devices to wait on before writing the tensor to device memory. + + If it is not provided, device will stall for all programs of the specified cq to finish before writing the tensor to device memory. + If ``arg1`` is not supplied, default ``MemoryConfig`` with ``interleaved`` set to ``True``. +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ @@ -968,6 +986,10 @@ void pytensor_module(py::module& m_tensor) { +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ | arg1 | MemoryConfig of tensor of TT accelerator device | ttnn.MemoryConfig | | No | +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg2 | CQ ID of TT accelerator device to use | uint8_t | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ + | arg3 | Sub device IDs to wait before writing tensor | List[ttnn.SubDeviceId] | | No | + +-----------+-------------------------------------------------+----------------------------+-----------------------+----------+ .. code-block:: python @@ -1022,12 +1044,19 @@ void pytensor_module(py::module& m_tensor) { )doc") .def( "cpu", - [](const Tensor& self, bool blocking, uint8_t cq_id) { return self.cpu(blocking, cq_id); }, + [](const Tensor& self, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) { + return self.cpu(blocking, cq_id, sub_device_ids); + }, py::arg("blocking") = true, py::arg("cq_id") = ttnn::DefaultQueueId, + py::arg("sub_device_ids") = std::vector(), R"doc( Move TT Tensor from TT accelerator device to host device. + ``sub_device_ids`` can be used to specify which specific sub devices to wait on before reading the tensor from device memory. + + If it is not provided, device will stall waiting for all programs of the specified cq to finish before reading the tensor from device memory. + .. 
code-block:: python tt_tensor = tt_tensor.cpu() diff --git a/ttnn/cpp/ttnn/events.cpp b/ttnn/cpp/ttnn/events.cpp index 789cdd36c6e..a38da21fccb 100644 --- a/ttnn/cpp/ttnn/events.cpp +++ b/ttnn/cpp/ttnn/events.cpp @@ -29,9 +29,11 @@ std::shared_ptr create_event(Device* device) { return event; } -void record_event(uint8_t cq_id, const std::shared_ptr& event) { +void record_event(uint8_t cq_id, const std::shared_ptr& event, const std::vector& sub_device_ids) { Device* device = event->device; - device->push_work([device, event, cq_id] { EnqueueRecordEvent(device->command_queue(cq_id), event); }); + device->push_work([device, event, cq_id, sub_device_ids] { + EnqueueRecordEvent(device->command_queue(cq_id), event, sub_device_ids); + }); } void wait_for_event(uint8_t cq_id, const std::shared_ptr& event) { @@ -41,9 +43,10 @@ void wait_for_event(uint8_t cq_id, const std::shared_ptr& event) { MultiDeviceEvent create_event(MeshDevice* mesh_device) { return MultiDeviceEvent(mesh_device); } -void record_event(uint8_t cq_id, const MultiDeviceEvent& multi_device_event) { +void record_event( + uint8_t cq_id, const MultiDeviceEvent& multi_device_event, const std::vector& sub_device_ids) { for (auto& event : multi_device_event.events) { - record_event(cq_id, event); + record_event(cq_id, event, sub_device_ids); } } diff --git a/ttnn/cpp/ttnn/events.hpp b/ttnn/cpp/ttnn/events.hpp index 57405fa9526..d4c409338c6 100644 --- a/ttnn/cpp/ttnn/events.hpp +++ b/ttnn/cpp/ttnn/events.hpp @@ -16,11 +16,12 @@ struct MultiDeviceEvent { }; // Single Device APIs std::shared_ptr create_event(Device* device); -void record_event(uint8_t cq_id, const std::shared_ptr& event); +void record_event( + uint8_t cq_id, const std::shared_ptr& event, const std::vector& sub_device_ids = {}); void wait_for_event(uint8_t cq_id, const std::shared_ptr& event); // Multi Device APIs MultiDeviceEvent create_event(MeshDevice* mesh_device); -void record_event(uint8_t cq_id, const MultiDeviceEvent& event); +void record_event(uint8_t cq_id, const MultiDeviceEvent& event, const std::vector& sub_device_ids = {}); void wait_for_event(uint8_t cq_id, const MultiDeviceEvent& event); } // namespace ttnn::events diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp index 184f6e139f1..90fc3f34908 100644 --- a/ttnn/cpp/ttnn/operations/core/core.cpp +++ b/ttnn/cpp/ttnn/operations/core/core.cpp @@ -58,25 +58,34 @@ ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank) { return ttnn::reshape(tensor, shape.to_rank(rank)); } -ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional& memory_config) { +ttnn::Tensor to_device( + const ttnn::Tensor& tensor, + Device* device, + const std::optional& memory_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); if (mem_config.is_sharded() and (device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG); + auto interleaved_tensor = tensor.to(device, ttnn::DRAM_MEMORY_CONFIG, cq_id, sub_device_ids); return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); } else { - return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); + return tensor.to(device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG), cq_id, sub_device_ids); } } ttnn::Tensor to_device( - const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional& memory_config) { + 
const ttnn::Tensor& tensor, + MeshDevice* mesh_device, + const std::optional& memory_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); // Currently no direct sharded write support in BLACKHOLE due to alignment issue if (mem_config.is_sharded() and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) { - auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG); + auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG, cq_id, sub_device_ids); return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt); } else { - return tensor.to(mesh_device, mem_config); + return tensor.to(mesh_device, mem_config, cq_id, sub_device_ids); } } @@ -100,17 +109,22 @@ ttnn::Tensor allocate_tensor_on_device( shape, data_type, layout, mesh_device->get_devices(), memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); } -void copy_host_to_device_tensor(const ttnn::Tensor& host_tensor, ttnn::Tensor device_tensor, uint8_t cq_id) { - tt::tt_metal::write_tensor(std::move(host_tensor), std::move(device_tensor), cq_id); +void copy_host_to_device_tensor( + const ttnn::Tensor& host_tensor, + ttnn::Tensor device_tensor, + uint8_t cq_id, + const std::vector& sub_device_ids) { + tt::tt_metal::write_tensor(std::move(host_tensor), std::move(device_tensor), cq_id, sub_device_ids); } -ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id) { +ttnn::Tensor from_device( + const ttnn::Tensor& tensor, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) { // Currently no direct sharded read support in BLACKHOLE due to alignment issue if (tensor.is_sharded() and (tensor.device()->arch() == tt::ARCH::BLACKHOLE)) { auto interleaved_tensor = ttnn::sharded_to_interleaved(cq_id, tensor, ttnn::DRAM_MEMORY_CONFIG, std::nullopt); - return interleaved_tensor.cpu(blocking, cq_id); + return interleaved_tensor.cpu(blocking, cq_id, sub_device_ids); } else { - return tensor.cpu(blocking, cq_id); + return tensor.cpu(blocking, cq_id, sub_device_ids); } } diff --git a/ttnn/cpp/ttnn/operations/core/core.hpp b/ttnn/cpp/ttnn/operations/core/core.hpp index e269e7030b6..d3ce90a4e24 100644 --- a/ttnn/cpp/ttnn/operations/core/core.hpp +++ b/ttnn/cpp/ttnn/operations/core/core.hpp @@ -24,10 +24,19 @@ ttnn::Tensor unsqueeze_to_4D(const ttnn::Tensor& tensor); ttnn::Tensor squeeze_from_4D(const ttnn::Tensor& tensor, const int rank); -ttnn::Tensor to_device(const ttnn::Tensor& tensor, Device* device, const std::optional& memory_config); +ttnn::Tensor to_device( + const ttnn::Tensor& tensor, + Device* device, + const std::optional& memory_config, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& = {}); ttnn::Tensor to_device( - const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional& memory_config); + const ttnn::Tensor& tensor, + MeshDevice* mesh_device, + const std::optional& memory_config, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& = {}); ttnn::Tensor allocate_tensor_on_device( const Shape& shape, @@ -44,9 +53,16 @@ ttnn::Tensor allocate_tensor_on_device( const std::optional& memory_config); void copy_host_to_device_tensor( - const ttnn::Tensor& host_tensor, ttnn::Tensor device_tensor, uint8_t cq_id = ttnn::DefaultQueueId); - -ttnn::Tensor from_device(const ttnn::Tensor& tensor, bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId); + const ttnn::Tensor& host_tensor, + ttnn::Tensor device_tensor, + uint8_t cq_id = 
ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}); + +ttnn::Tensor from_device( + const ttnn::Tensor& tensor, + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}); void deallocate(Tensor& tensor, bool force = true); diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp b/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp index fcfd1f35e60..dbf98519483 100644 --- a/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/moe/moe.cpp @@ -39,9 +39,8 @@ auto MoeOperation::invoke( const uint16_t k, const std::optional& memory_config, std::optional optional_output_tensor) { - constexpr uint8_t DefaultQueueId = 0; return invoke( - DefaultQueueId, + ttnn::DefaultQueueId, input_tensor, expert_mask_tensor, topk_mask_tensor, diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 6f153e8e6b4..f4304d33c6a 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -559,24 +559,28 @@ const Storage& Tensor::get_storage() const { return this->tensor_attributes->storage; } -Tensor Tensor::to(CommandQueue& queue, const MemoryConfig& mem_config) const { - return tensor_ops::tensor_to(*this, queue.device(), mem_config); +Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config,uint8_t cq_id, + const std::vector& sub_device_ids) const { + return tensor_ops::tensor_to(*this, target_device, mem_config, cq_id, sub_device_ids); } -Tensor Tensor::to(Device* target_device, const MemoryConfig& mem_config) const { - return tensor_ops::tensor_to(*this, target_device, mem_config); -} - -Tensor Tensor::to(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config) const { +Tensor Tensor::to(distributed::MeshDevice* mesh_device, const MemoryConfig& mem_config,uint8_t cq_id, + const std::vector& sub_device_ids) const { std::vector workers_to_use = ttnn::distributed::get_mapped_devices(*this, *mesh_device); - return tensor_ops::tensor_to(*this, workers_to_use, mem_config); + return tensor_ops::tensor_to(*this, workers_to_use, mem_config, cq_id, sub_device_ids); } -Tensor Tensor::to(const std::vector& workers, const MemoryConfig& mem_config) const { - return tensor_ops::tensor_to(*this, workers, mem_config); +Tensor Tensor::to( + const std::vector& workers, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids) const { + return tensor_ops::tensor_to(*this, workers, mem_config, cq_id, sub_device_ids); } -Tensor Tensor::cpu(bool blocking, uint8_t cq_id) const { return tensor_ops::tensor_cpu(*this, blocking, cq_id); } +Tensor Tensor::cpu(bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) const { + return tensor_ops::tensor_cpu(*this, blocking, cq_id, sub_device_ids); +} Tensor Tensor::cpu_sharded() const { return tensor_ops::tensor_cpu_sharded(*this); } @@ -861,7 +865,8 @@ Tensor allocate_tensor_on_devices( return device_tensor; } -void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id) { +void write_tensor( + const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id, const std::vector& sub_device_ids) { // Top level wrapper to copy a host tensor to a preallocated device tensor TT_ASSERT(device_tensor.workers.size(), "Workers must be specified for device_tensor in write_tensor"); @@ -877,7 +882,7 @@ void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id for (int worker_index = 0; worker_index < device_tensor.workers.size(); ++worker_index) { auto& 
worker = device_tensor.workers[worker_index]; - worker->push_work([cq_id, worker, worker_index, async_safe_tensor, device_tensor]() mutable { + worker->push_work([cq_id, worker, worker_index, async_safe_tensor, device_tensor, sub_device_ids]() mutable { TT_FATAL( device_tensor.storage_type() == StorageType::DEVICE or device_tensor.storage_type() == StorageType::MULTI_DEVICE, @@ -889,7 +894,7 @@ void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id "Error"); std::visit( tt::stl::overloaded{ - [worker, worker_index, cq_id, &async_safe_tensor](const DeviceStorage& device_storage) { + [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids](const DeviceStorage& device_storage) { // Copying from host to a single device. void* host_data = std::visit( tt::stl::overloaded{ @@ -913,9 +918,10 @@ void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id worker->command_queue(cq_id), device_storage.get_buffer(), host_data, - /*blocking=*/false); + /*blocking=*/false, + sub_device_ids); }, - [worker, worker_index, cq_id, &async_safe_tensor](const MultiDeviceStorage& device_storage) { + [worker, worker_index, cq_id, &async_safe_tensor, sub_device_ids](const MultiDeviceStorage& device_storage) { // Copying from host to multi-device. TT_ASSERT( std::holds_alternative(async_safe_tensor.get_storage()), @@ -928,7 +934,8 @@ void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id worker->command_queue(cq_id), device_storage.get_buffer_for_device(worker), host_data, - /*blocking=*/false); + /*blocking=*/false, + sub_device_ids); }, [](auto&& s) { TT_THROW("Unreachable"); }}, device_tensor.get_storage()); diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index bf86cca99c1..b8b7a993b8a 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -141,19 +141,21 @@ struct Tensor { Tensor to( Device* target_device, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; + const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor to( distributed::MeshDevice* mesh_device, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; - - Tensor to( - CommandQueue& queue, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; + const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor to( const std::vector& workers, - const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) const; + const MemoryConfig& mem_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor to(Layout target_layout, Device* worker = nullptr) const; @@ -164,7 +166,10 @@ struct Tensor { const ttnn::SimpleShape& input_tensor_start, float pad_value) const; - Tensor cpu(bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId) const; + Tensor cpu( + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}) const; Tensor cpu_sharded() const; @@ -374,7 +379,11 @@ Tensor allocate_tensor_on_devices( const 
std::vector& devices, const MemoryConfig& memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, const std::optional& tile = std::nullopt); -void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id = ttnn::DefaultQueueId); +void write_tensor( + const Tensor& host_tensor, + Tensor device_tensor, + uint8_t cq_id = ttnn::DefaultQueueId, + const std::vector& sub_device_ids = {}); Tensor set_tensor_id(const Tensor& tensor); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index 0386f6e353c..dc7545ac0e5 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -565,7 +565,11 @@ std::string to_string(const Tensor& tensor, std::optional o // ====================================================================================== template -Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId) { +Tensor to_host_helper( + const Tensor& tensor, + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}) { TT_ASSERT(tensor.is_allocated(), "Buffer must be allocated on device!"); auto device_buffer = tensor.device_buffer(); auto device = tensor.device(); @@ -575,7 +579,8 @@ Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id const char* TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { data_vec.resize(size_in_bytes / sizeof(T)); - read_data_from_device_buffer(device->command_queue(cq_id), device_buffer, data_vec.data(), blocking); + read_data_from_device_buffer( + device->command_queue(cq_id), device_buffer, data_vec.data(), blocking, sub_device_ids); } else { read_data_from_device_buffer(device_buffer, data_vec); } @@ -584,9 +589,9 @@ Tensor to_host_helper(const Tensor& tensor, bool blocking = true, uint8_t cq_id } template -Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { +Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids) { if (tensor.storage_type() == StorageType::DEVICE) { - return to_host_helper(tensor, blocking, cq_id); + return to_host_helper(tensor, blocking, cq_id, sub_device_ids); } else if (tensor.storage_type() == StorageType::MULTI_DEVICE) { auto devices = get_devices(tensor); Tensor host_tensor(devices.size()); @@ -594,7 +599,7 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { for (int device_index = 0; device_index < devices.size(); ++device_index) { const auto& device = devices[device_index]; auto shard = get_shard_for_device(tensor, device); - shard = to_host_helper(shard, blocking, cq_id); + shard = to_host_helper(shard, blocking, cq_id, sub_device_ids); insert_buffer_and_shape_for_device(device, shard, host_tensor, device_index); } return host_tensor; @@ -603,21 +608,29 @@ Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { } } -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); -template Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, 
tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); +template Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids); template <> -Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { - return to_host(tensor, blocking, cq_id); +Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids) { + return to_host(tensor, blocking, cq_id, sub_device_ids); } template <> -Tensor to_host(const Tensor& tensor, bool blocking, uint8_t cq_id) { - return to_host(tensor, blocking, cq_id); +Tensor to_host( + const Tensor& tensor, bool blocking, uint8_t cq_id, tt::stl::Span sub_device_ids) { + return to_host(tensor, blocking, cq_id, sub_device_ids); } // ====================================================================================== @@ -662,7 +675,11 @@ Tensor to_host_sharded(const Tensor& tensor) { // ====================================================================================== template typename BufferType> -void write_data_to_device_buffer(CommandQueue& cq, const BufferType& host_buffer, DeviceBuffer device_buffer) { +void write_data_to_device_buffer( + CommandQueue& cq, + const BufferType& host_buffer, + DeviceBuffer device_buffer, + tt::stl::Span sub_device_ids) { ZoneScoped; // TODO(arakhmati): can we use generators in this function to go from `data_to_write` to `uint32_data`? // And effectively get rid of any additional allocation @@ -676,12 +693,12 @@ void write_data_to_device_buffer(CommandQueue& cq, const BufferType& host_buf const uint32_t* borrowed_buf_base = static_cast(host_buffer.data()); std::vector owned_copy_vec(borrowed_buf_base, borrowed_buf_base + borrowed_buf_size_words); owned_buffer::Buffer owned_copy(std::make_shared>(owned_copy_vec)); - EnqueueWriteBuffer(cq, device_buffer, owned_copy.get_ptr(), false); + EnqueueWriteBuffer(cq, device_buffer, owned_copy.get_ptr(), false, sub_device_ids); } else if constexpr (std::is_same_v, owned_buffer::Buffer>) { - EnqueueWriteBuffer(cq, device_buffer, host_buffer.get_ptr(), false); + EnqueueWriteBuffer(cq, device_buffer, host_buffer.get_ptr(), false, sub_device_ids); } } else { - EnqueueWriteBuffer(cq, device_buffer, host_buffer.data(), false); + EnqueueWriteBuffer(cq, device_buffer, host_buffer.data(), false, sub_device_ids); } } @@ -699,7 +716,8 @@ DeviceBuffer initialize_data_on_device( BufferType& data_to_write, Device* device, const TensorSpec& tensor_spec, - std::optional> queue = std::nullopt) { + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}) { ZoneScoped; TT_ASSERT(device != nullptr); @@ -707,8 +725,7 @@ DeviceBuffer initialize_data_on_device( const char* TT_METAL_SLOW_DISPATCH_MODE = std::getenv("TT_METAL_SLOW_DISPATCH_MODE"); if (TT_METAL_SLOW_DISPATCH_MODE == nullptr) { - write_data_to_device_buffer( - queue.has_value() ? 
queue.value().get() : device->command_queue(), data_to_write, device_buffer); + write_data_to_device_buffer(device->command_queue(cq_id), data_to_write, device_buffer, sub_device_ids); } else { write_data_to_device_buffer(data_to_write, *device_buffer); } @@ -720,13 +737,14 @@ DeviceBuffer to_device_buffer( const Storage& storage, Device* device, const TensorSpec& tensor_spec, - std::optional> queue) { + uint8_t cq_id, + tt::stl::Span sub_device_ids) { return std::visit( - [&device, &tensor_spec, &queue](auto&& storage) -> DeviceBuffer { + [&device, &tensor_spec, cq_id, sub_device_ids](auto&& storage) -> DeviceBuffer { using StorageType = std::decay_t; if constexpr (std::is_same_v or std::is_same_v) { auto data_to_write = host_buffer::get_as(storage.buffer); - return initialize_data_on_device(data_to_write, device, tensor_spec, queue); + return initialize_data_on_device(data_to_write, device, tensor_spec, cq_id, sub_device_ids); } else if constexpr (std::is_same_v) { TT_THROW("Device storage doesn't support to_device_buffer"); } else if constexpr (std::is_same_v) { @@ -749,7 +767,8 @@ Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue) { + uint8_t cq_id, + tt::stl::Span sub_device_ids) { TT_FATAL(tensor.storage_type() != StorageType::DEVICE, "Tensor is already on device!"); if (tensor.storage_type() == StorageType::OWNED) { TT_FATAL(tensor.is_allocated(), "Need host buffer on device to exist to copy data to device!"); @@ -759,7 +778,8 @@ Tensor to_device( TensorSpec tensor_spec( tensor.get_logical_shape(), tensor.get_tensor_spec().tensor_layout().with_memory_config(memory_config)); - auto device_buffer = tensor_impl::to_device_buffer(tensor.get_storage(), target_device, tensor_spec, queue); + auto device_buffer = + tensor_impl::to_device_buffer(tensor.get_storage(), target_device, tensor_spec, cq_id, sub_device_ids); return Tensor(DeviceStorage{device_buffer}, tensor_spec); } @@ -767,40 +787,47 @@ template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids); template <> Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue) { - return to_device(tensor, target_device, memory_config, queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids) { + return to_device(tensor, target_device, memory_config, cq_id, sub_device_ids); } template <> @@ -808,8 +835,9 @@ Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& 
memory_config, - std::optional> queue) { - return to_device(tensor, target_device, memory_config, queue); + uint8_t cq_id, + tt::stl::Span sub_device_ids) { + return to_device(tensor, target_device, memory_config, cq_id, sub_device_ids); } // ====================================================================================== diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index 5a0ec30ecdd..87c34bdb199 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -167,8 +167,12 @@ DeviceBuffer allocate_buffer_on_device(Device* device, const TensorSpec& tensor_ template inline void read_data_from_device_buffer( - CommandQueue& cq, DeviceBuffer device_buffer, void* host_buffer_data, bool blocking) { - EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking); + CommandQueue& cq, + DeviceBuffer device_buffer, + void* host_buffer_data, + bool blocking, + tt::stl::Span sub_device_ids = {}) { + EnqueueReadBuffer(cq, device_buffer, host_buffer_data, blocking, sub_device_ids); } template @@ -181,7 +185,11 @@ inline void read_data_from_device_buffer(DeviceBuffer device_buffer, std::vector // ====================================================================================== template -Tensor to_host(const Tensor& tensor, bool blocking = true, uint8_t cq_id = ttnn::DefaultQueueId); +Tensor to_host( + const Tensor& tensor, + bool blocking = true, + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}); template Tensor to_host_sharded(const Tensor& tensor); @@ -191,7 +199,8 @@ Tensor to_device( const Tensor& tensor, Device* target_device, const MemoryConfig& memory_config, - std::optional> queue); + uint8_t cq_id = ttnn::DefaultQueueId, + tt::stl::Span sub_device_ids = {}); template Tensor to_layout(const Tensor& tensor, Layout target_layout); diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 460f8b0d5db..f40690d2a44 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -24,7 +24,12 @@ namespace tt::tt_metal::tensor_ops { -Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const MemoryConfig& mem_config) { +Tensor tensor_to( + const Tensor& input_tensor, + Device* target_device, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, target_device, mem_config); // Tensor can be using borrowed storage. If so, when running in async mode, copy this tensor to owned storage. @@ -35,7 +40,12 @@ Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const Memory // Record main thread ref count for tensors before pushing to queue. 
uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); - target_device->push_work([async_safe_tensor, device_tensor, mem_config, target_device]() mutable { + target_device->push_work([async_safe_tensor, + device_tensor, + mem_config, + target_device, + cq_id, + sub_device_ids]() mutable { if (async_safe_tensor.storage_type() == StorageType::DEVICE) { TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); device_tensor.populate_buffers_and_metadata(async_safe_tensor); @@ -46,7 +56,7 @@ Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const Memory async_safe_tensor.get_dtype(), async_safe_tensor.get_layout()); auto local_tensor = - tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, std::nullopt); + tensor_impl::to_device_wrapper(async_safe_tensor, target_device, mem_config, cq_id, sub_device_ids); // Populate device tensor device_tensor.populate_buffers_and_metadata(local_tensor); } @@ -61,7 +71,12 @@ Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const Memory return device_tensor; } -Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers, const MemoryConfig& mem_config) { +Tensor tensor_to( + const Tensor& input_tensor, + const std::vector& workers, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::to", input_tensor, workers, mem_config); TT_FATAL( @@ -72,10 +87,17 @@ Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers uint32_t num_workers = workers.size(); for (int worker_index = 0; worker_index < workers.size(); ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work([worker, input_tensor, device_tensor, mem_config, num_workers, worker_index]() mutable { + worker->push_work([worker, + input_tensor, + device_tensor, + mem_config, + num_workers, + worker_index, + cq_id, + sub_device_ids]() mutable { auto shard = get_shard_for_device(input_tensor, worker, worker_index); if (shard.storage_type() == StorageType::OWNED) { - shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, std::nullopt); + shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, cq_id, sub_device_ids); } insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; @@ -93,7 +115,8 @@ Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers return device_tensor; } -Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { +Tensor tensor_cpu( + const Tensor& input_tensor, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids) { ZoneScoped; GraphTracker::instance().track_function_start("Tensor::cpu", input_tensor, blocking); auto workers = input_tensor.get_workers(blocking); @@ -111,19 +134,20 @@ Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id) { uint32_t original_tensor_ref_count = input_tensor.tensor_attributes->record_main_thread_ref_count(); for (int worker_index = 0; worker_index < workers.size(); worker_index++) { auto target_device = workers[worker_index]; - target_device->push_work([host_tensor, blocking, target_device, input_tensor, worker_index, cq_id]() mutable { - TT_ASSERT( - 
input_tensor.storage_type() == StorageType::DEVICE or - input_tensor.storage_type() == StorageType::MULTI_DEVICE, - "Can only use worker queue for cpu call if tensor is on device."); - auto shard = get_shard_for_device(input_tensor, target_device); - shard = tensor_impl::to_host_wrapper(shard, blocking, cq_id); - insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); - uint32_t num_workers_completed = (host_tensor.tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { - host_tensor.set_tensor_spec(input_tensor.get_tensor_spec()); - } - }); + target_device->push_work( + [host_tensor, blocking, target_device, input_tensor, worker_index, cq_id, sub_device_ids]() mutable { + TT_ASSERT( + input_tensor.storage_type() == StorageType::DEVICE or + input_tensor.storage_type() == StorageType::MULTI_DEVICE, + "Can only use worker queue for cpu call if tensor is on device."); + auto shard = get_shard_for_device(input_tensor, target_device); + shard = tensor_impl::to_host_wrapper(shard, blocking, cq_id, sub_device_ids); + insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); + uint32_t num_workers_completed = (host_tensor.tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { + host_tensor.set_tensor_spec(input_tensor.get_tensor_spec()); + } + }); } if (blocking) { diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index 98f8103c151..b8edff425f8 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -20,15 +20,26 @@ class Device; namespace tt::tt_metal::tensor_ops { -Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const MemoryConfig& mem_config); +Tensor tensor_to( + const Tensor& input_tensor, + Device* target_device, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids); -Tensor tensor_to(const Tensor& input_tensor, const std::vector& workers, const MemoryConfig& mem_config); +Tensor tensor_to( + const Tensor& input_tensor, + const std::vector& workers, + const MemoryConfig& mem_config, + uint8_t cq_id, + const std::vector& sub_device_ids); Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, Device* worker); Tensor tensor_to(const Tensor& input_tensor, Layout target_layout, distributed::MeshDevice* mesh_device); -Tensor tensor_cpu(const Tensor& input_tensor, bool blocking, uint8_t cq_id); +Tensor tensor_cpu( + const Tensor& input_tensor, bool blocking, uint8_t cq_id, const std::vector& sub_device_ids); Tensor tensor_cpu_sharded(const Tensor& input_tensor); diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 041de728018..4f613ca11ef 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -187,7 +187,9 @@ def manage_config(name, value): format_output_tensor, pad_to_tile_shape, SubDevice, + SubDeviceId, SubDeviceManagerId, + DefaultQueueId, init_device_compute_kernel_config, ) diff --git a/ttnn/ttnn/device.py b/ttnn/ttnn/device.py index b8de80cd87a..6cbfaa85ead 100644 --- a/ttnn/ttnn/device.py +++ b/ttnn/ttnn/device.py @@ -162,6 +162,9 @@ def is_blackhole(device=None): pad_to_tile_shape = ttnn._ttnn.device.pad_to_tile_shape SubDevice = ttnn._ttnn.device.SubDevice +SubDeviceId = ttnn._ttnn.device.SubDeviceId SubDeviceManagerId = ttnn._ttnn.device.SubDeviceManagerId +DefaultQueueId = ttnn._ttnn.device.DefaultQueueId + __all__ = [] diff --git a/ttnn/ttnn/distributed/distributed.py b/ttnn/ttnn/distributed/distributed.py index 
bda36c5b0f5..65a902d11cf 100644 --- a/ttnn/ttnn/distributed/distributed.py +++ b/ttnn/ttnn/distributed/distributed.py @@ -209,19 +209,23 @@ def create_mesh_device( close_mesh_device(mesh_device) -def synchronize_devices(devices: Union["ttnn.Device", "ttnn.MeshDevice"], queue_id: Optional[int] = None) -> None: +def synchronize_devices( + devices: Union["ttnn.Device", "ttnn.MeshDevice"], + queue_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], +) -> None: """ - synchronize_devices(devices: Union[ttnn.Device, ttnn.MeshDevice], queue_id: Optional[int] = None) -> None: + synchronize_devices(devices: Union[ttnn.Device, ttnn.MeshDevice], queue_id: Optional[int] = None, sub_device_ids: List[ttnn.SubDeviceId] = []) -> None: Synchronize the devices with host by waiting for all operations to complete. If queue_id is provided then only the operations associated with that queue_id are waited for, otherwise operations for all command queues are waited on. """ if isinstance(devices, ttnn.Device): - ttnn._ttnn.device.synchronize_device(devices, queue_id) + ttnn._ttnn.device.synchronize_device(devices, queue_id, sub_device_ids) else: for device in devices.get_device_ids(): - ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id) + ttnn._ttnn.device.synchronize_device(devices.get_device(device), queue_id, sub_device_ids) class TensorToMesh: diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 3eeda3a90b6..24480037a3f 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -4,7 +4,7 @@ import math import pathlib -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import ttnn.decorators @@ -158,6 +158,8 @@ def from_torch( device: Optional[ttnn.Device] = None, memory_config: Optional[ttnn.MemoryConfig] = None, mesh_mapper: Optional[ttnn.TensorToMesh] = None, + cq_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], ) -> ttnn.Tensor: """ Converts the `torch.Tensor` tensor into a `ttnn.Tensor`. For bfloat8_b or bfloat4_b format, the function itself is called twice, @@ -176,6 +178,8 @@ def from_torch( device (ttnn.Device, optional): the desired `ttnn` device. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): The desired `ttnn` memory configuration. Defaults to `None`. mesh_mapper (ttnn.TensorToMesh, optional): The desired `ttnn` mesh mapper. Defaults to `None`. + cq_id (int, optional): The command queue ID to use. Defaults to `0`. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to wait on. Defaults to all sub-devices. Returns: ttnn.Tensor: The resulting `ttnn` tensor. 
@@ -225,7 +229,7 @@ def from_torch( if device is not None: if memory_config is None: memory_config = ttnn.DRAM_MEMORY_CONFIG - tensor = ttnn.to_device(tensor, device, memory_config=memory_config) + tensor = ttnn.to_device(tensor, device, memory_config=memory_config, cq_id=cq_id, sub_device_ids=sub_device_ids) if shape_with_padding is not None and shape_with_padding != tensor.shape and mesh_mapper is None: tensor = ttnn.reshape(tensor, shape_with_padding) @@ -262,7 +266,8 @@ def to_torch( torch_rank: Optional[int] = None, mesh_composer: Optional[ttnn.MeshToTensor] = None, device: Optional[ttnn.Device] = None, - cq_id: Optional[int] = 0, + cq_id: Optional[int] = ttnn.DefaultQueueId, + sub_device_ids: List[ttnn.SubDeviceId] = [], ) -> "torch.Tensor": """ Converts the `ttnn.Tensor` tensor into a `torch.Tensor`. It does not call to_layout for bfloat8_b or bfloat4_b as we now convert @@ -278,6 +283,7 @@ def to_torch( mesh_composer (ttnn.MeshToTensor, optional): The desired `ttnn` mesh composer. Defaults to `None`. device (ttnn.Device, optional): The `ttnn` device of the input tensor. Defaults to `None`. cq_id (int, optional): The command queue ID to use. Defaults to `0`. + sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to wait on. Defaults to all sub-devices. Returns: torch.Tensor: The converted `torch` tensor. @@ -290,7 +296,7 @@ def to_torch( [ 0.9023, -0.5820, 0.5312]], dtype=torch.bfloat16) """ if ttnn.is_tensor_storage_on_device(tensor): - tensor = ttnn.from_device(tensor, cq_id=cq_id) + tensor = ttnn.from_device(tensor, cq_id=cq_id, sub_device_ids=sub_device_ids) if (tensor.layout != ttnn.ROW_MAJOR_LAYOUT) and not ( tensor.dtype == ttnn.bfloat8_b or tensor.dtype == ttnn.bfloat4_b From a6fe436414068aee2a1692a363193e5cc7c803e9 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Sun, 8 Dec 2024 20:01:25 +0000 Subject: [PATCH 2/2] #15836: Add api passing different sub device configurations per device in a mesh_device --- tests/ttnn/unit_tests/test_sub_device.py | 30 +++++++++++++------ tt_metal/distributed/mesh_device.cpp | 18 +++++++++++ tt_metal/distributed/mesh_device.hpp | 2 ++ .../ttnn/distributed/distributed_pybind.cpp | 20 +++++++++++++ 4 files changed, 61 insertions(+), 9 deletions(-) diff --git a/tests/ttnn/unit_tests/test_sub_device.py b/tests/ttnn/unit_tests/test_sub_device.py index be2c8174870..f7bfb20401a 100644 --- a/tests/ttnn/unit_tests/test_sub_device.py +++ b/tests/ttnn/unit_tests/test_sub_device.py @@ -7,7 +7,7 @@ import ttnn -def run_sub_devices(device): +def run_sub_devices(device, replicate_sub_devices=False): tensix_cores0 = ttnn.CoreRangeSet( { ttnn.CoreRange( @@ -26,8 +26,14 @@ def run_sub_devices(device): ) sub_device_1 = ttnn.SubDevice([tensix_cores0]) sub_device_2 = ttnn.SubDevice([tensix_cores1]) - sub_device_manager1 = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200) - sub_device_manager2 = device.create_sub_device_manager([sub_device_2], 3200) + sub_devices_1 = [sub_device_1, sub_device_2] + sub_devices_2 = [sub_device_2] + if replicate_sub_devices: + num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() + sub_devices_1 = [sub_devices_1] * num_devices + sub_devices_2 = [sub_devices_2] * num_devices + sub_device_manager1 = device.create_sub_device_manager(sub_devices_1, 3200) + sub_device_manager2 = device.create_sub_device_manager(sub_devices_2, 3200) device.load_sub_device_manager(sub_device_manager1) ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(1)]) 
ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0), ttnn.SubDeviceId(1)]) @@ -39,7 +45,7 @@ def run_sub_devices(device): device.remove_sub_device_manager(sub_device_manager2) -def run_sub_devices_program(device): +def run_sub_devices_program(device, replicate_sub_devices=False): is_mesh_device = isinstance(device, ttnn.MeshDevice) if is_mesh_device: inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0) @@ -67,7 +73,11 @@ def run_sub_devices_program(device): ) sub_device_1 = ttnn.SubDevice([tensix_cores0]) sub_device_2 = ttnn.SubDevice([tensix_cores1]) - sub_device_manager = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200) + sub_devices = [sub_device_1, sub_device_2] + if replicate_sub_devices: + num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices() + sub_devices = [sub_devices] * num_devices + sub_device_manager = device.create_sub_device_manager(sub_devices, 3200) device.load_sub_device_manager(sub_device_manager) x = torch.randn(num_devices, 1, 64, 64, dtype=torch.bfloat16) @@ -130,8 +140,9 @@ def test_sub_devices(device, enable_async_mode): @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) -def test_sub_devices_mesh(mesh_device, enable_async_mode): - run_sub_devices(mesh_device) +@pytest.mark.parametrize("replicate_sub_devices", (False, True)) +def test_sub_devices_mesh(mesh_device, replicate_sub_devices, enable_async_mode): + run_sub_devices(mesh_device, replicate_sub_devices) @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) @@ -140,5 +151,6 @@ def test_sub_device_program(device, enable_async_mode): @pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True) -def test_sub_device_program_mesh(mesh_device, enable_async_mode): - run_sub_devices_program(mesh_device) +@pytest.mark.parametrize("replicate_sub_devices", (False, True)) +def test_sub_device_program_mesh(mesh_device, replicate_sub_devices, enable_async_mode): + run_sub_devices_program(mesh_device, replicate_sub_devices) diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index dc0275cd26a..6971abd948e 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -489,6 +489,24 @@ MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(tt::stl::Span>& mesh_sub_devices, DeviceAddr local_l1_size) { + MeshSubDeviceManagerId mesh_sub_device_manager_id(*this); + TT_FATAL(mesh_sub_devices.size() == this->num_devices(), "Number of devices does not match number of sub-device configurations"); + for (uint32_t i = 0; i < this->num_devices(); i++) { + auto* device = this->devices[i]; + auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i]; + tt::stl::Span sub_devices(mesh_sub_devices[i]); + device->push_work([device, sub_devices, local_l1_size, &sub_device_manager_id]() { + sub_device_manager_id = device->create_sub_device_manager(sub_devices, local_l1_size); + }); + } + for (auto* device : this->devices) { + device->synchronize(); + } + return mesh_sub_device_manager_id; +} + void MeshDevice::load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) { for (uint32_t i = 0; i < this->num_devices(); i++) { auto* device = this->devices[i]; diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index e7a0ab22db8..a7727fb97bd 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -137,6 +137,8 @@ class MeshDevice : 
public std::enable_shared_from_this { MeshSubDeviceManagerId create_sub_device_manager( tt::stl::Span sub_devices, DeviceAddr local_l1_size); + MeshSubDeviceManagerId create_sub_device_manager( + const std::vector>& mesh_sub_devices, DeviceAddr local_l1_size); void load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id); void clear_loaded_sub_device_manager(); void remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id); diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 43ac6aa3574..ed946f23d9b 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -190,6 +190,26 @@ void py_module(py::module& module) { Args: sub_devices (List[ttnn.SubDevice]): The sub-devices to include in the sub-device manager. + This configuration will be used for each device in the MeshDevice. + local_l1_size (int): The size of the local allocators of each sub-device. The global allocator will be shrunk by this amount. + + Returns: + MeshSubDeviceManagerId: The ID of the created sub-device manager. + )doc") + .def( + "create_sub_device_manager", + [](MeshDevice& self, + const std::vector>& mesh_sub_devices, + DeviceAddr local_l1_size) { return self.create_sub_device_manager(mesh_sub_devices, local_l1_size); }, + py::arg("sub_devices"), + py::arg("local_l1_size"), + R"doc( + Creates a sub-device manager for the given mesh device. + + Args: + mesh_sub_devices (List[List[ttnn.SubDevice]]): The sub-devices to include in the sub-device manager. + Each element of the outer list will be used to configure the corresponding device in the MeshDevice. + This means that the individual devices in the MeshDevice may have different configurations. local_l1_size (int): The size of the local allocators of each sub-device. The global allocator will be shrunk by this amount. Returns:
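
For reference, a minimal end-to-end sketch of the APIs touched by these patches, modeled on the updated unit test above. The core ranges, the 3200-byte local L1 size, and the tensor shape are illustrative placeholder values, not requirements of the API; ttnn.open_device / ttnn.close_device are assumed to be the usual ttnn device helpers and are not part of this change.

    # Illustrative sketch (not part of the patch): exercising the new
    # sub_device_ids arguments on reads, writes, events, and synchronization.
    import torch
    import ttnn

    device = ttnn.open_device(device_id=0)  # assumed standard ttnn helper

    # Two sub-devices over disjoint Tensix core ranges (placeholder ranges).
    tensix_cores0 = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 3))})
    tensix_cores1 = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(4, 4), ttnn.CoreCoord(4, 4))})
    sub_device_manager = device.create_sub_device_manager(
        [ttnn.SubDevice([tensix_cores0]), ttnn.SubDevice([tensix_cores1])], 3200
    )
    device.load_sub_device_manager(sub_device_manager)

    x = torch.randn(1, 1, 64, 64, dtype=torch.bfloat16)

    # The write stalls only on programs running on sub-device 0 before copying
    # the tensor to device memory; omitting sub_device_ids waits on all sub-devices.
    xt = ttnn.from_torch(
        x,
        dtype=ttnn.bfloat16,
        layout=ttnn.TILE_LAYOUT,
        device=device,
        memory_config=ttnn.L1_MEMORY_CONFIG,
        sub_device_ids=[ttnn.SubDeviceId(0)],
    )

    # The read back waits only on sub-device 1 before pulling data from device memory.
    y = ttnn.to_torch(xt, sub_device_ids=[ttnn.SubDeviceId(1)])

    # Events can likewise record completion for a subset of sub-devices.
    event = ttnn.create_event(device)
    ttnn.record_event(0, event, [ttnn.SubDeviceId(1)])
    ttnn.wait_for_event(0, event)

    # Synchronize a subset of sub-devices, or the whole chip when the list is omitted.
    ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0), ttnn.SubDeviceId(1)])
    ttnn.synchronize_devices(device)

    device.clear_loaded_sub_device_manager()
    device.remove_sub_device_manager(sub_device_manager)
    ttnn.close_device(device)  # assumed standard ttnn helper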