#15836: Update reads, writes, and synchronize ttnn apis to take in sub device ids #15812

Merged 2 commits on Dec 10, 2024
72 changes: 58 additions & 14 deletions tests/ttnn/unit_tests/test_sub_device.py
@@ -7,7 +7,7 @@
import ttnn


def run_sub_devices(device):
def run_sub_devices(device, replicate_sub_devices=False):
tensix_cores0 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
@@ -26,16 +26,26 @@ def run_sub_devices(device)
)
sub_device_1 = ttnn.SubDevice([tensix_cores0])
sub_device_2 = ttnn.SubDevice([tensix_cores1])
sub_device_manager1 = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200)
sub_device_manager2 = device.create_sub_device_manager([sub_device_2], 3200)
sub_devices_1 = [sub_device_1, sub_device_2]
sub_devices_2 = [sub_device_2]
if replicate_sub_devices:
num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices()
sub_devices_1 = [sub_devices_1] * num_devices
sub_devices_2 = [sub_devices_2] * num_devices
sub_device_manager1 = device.create_sub_device_manager(sub_devices_1, 3200)
sub_device_manager2 = device.create_sub_device_manager(sub_devices_2, 3200)
device.load_sub_device_manager(sub_device_manager1)
ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(1)])
ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0), ttnn.SubDeviceId(1)])
ttnn.synchronize_devices(device)
device.load_sub_device_manager(sub_device_manager2)
ttnn.synchronize_devices(device, sub_device_ids=[ttnn.SubDeviceId(0)])
device.clear_loaded_sub_device_manager()
device.remove_sub_device_manager(sub_device_manager1)
device.remove_sub_device_manager(sub_device_manager2)


def run_sub_devices_program(device):
def run_sub_devices_program(device, replicate_sub_devices=False):
is_mesh_device = isinstance(device, ttnn.MeshDevice)
if is_mesh_device:
inputs_mesh_mapper = ttnn.ShardTensorToMesh(device, dim=0)
@@ -48,22 +58,26 @@ def run_sub_devices_program(device):
tensix_cores0 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(3, 3),
ttnn.CoreCoord(4, 4),
ttnn.CoreCoord(4, 4),
),
}
)
tensix_cores1 = ttnn.CoreRangeSet(
{
ttnn.CoreRange(
ttnn.CoreCoord(4, 4),
ttnn.CoreCoord(4, 4),
ttnn.CoreCoord(0, 0),
ttnn.CoreCoord(3, 3),
),
}
)
sub_device_1 = ttnn.SubDevice([tensix_cores0])
sub_device_2 = ttnn.SubDevice([tensix_cores1])
sub_device_manager = device.create_sub_device_manager([sub_device_1, sub_device_2], 3200)
sub_devices = [sub_device_1, sub_device_2]
if replicate_sub_devices:
num_devices = 1 if isinstance(device, ttnn.Device) else device.get_num_devices()
sub_devices = [sub_devices] * num_devices
sub_device_manager = device.create_sub_device_manager(sub_devices, 3200)
device.load_sub_device_manager(sub_device_manager)

x = torch.randn(num_devices, 1, 64, 64, dtype=torch.bfloat16)
@@ -74,20 +88,48 @@ def run_sub_devices_program(device):
device=device,
memory_config=ttnn.L1_MEMORY_CONFIG,
mesh_mapper=inputs_mesh_mapper,
sub_device_ids=[ttnn.SubDeviceId(0)],
)

xt_host = ttnn.from_torch(
x,
dtype=ttnn.bfloat16,
layout=ttnn.TILE_LAYOUT,
mesh_mapper=inputs_mesh_mapper,
sub_device_ids=[ttnn.SubDeviceId(1)],
)

ttnn.copy_host_to_device_tensor(xt_host, xt, sub_device_ids=[ttnn.SubDeviceId(1)])

grid_size = device.compute_with_storage_grid_size()
shard_size = [32, 64]
shard_scheme = ttnn.TensorMemoryLayout.HEIGHT_SHARDED
shard_orientation = ttnn.ShardOrientation.ROW_MAJOR
yt = ttnn.interleaved_to_sharded(
xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16
)
y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer)
y = ttnn.to_torch(yt, device=device, mesh_composer=output_mesh_composer, sub_device_ids=[ttnn.SubDeviceId(1)])

eq = torch.equal(x, y)
assert eq

y = ttnn.to_torch(yt.cpu(sub_device_ids=[ttnn.SubDeviceId(0)]), mesh_composer=output_mesh_composer)

eq = torch.equal(x, y)
assert eq

event = ttnn.create_event(device)

yt2 = ttnn.interleaved_to_sharded(
xt, grid_size, shard_size, shard_scheme, shard_orientation, output_dtype=ttnn.bfloat16
)
ttnn.record_event(0, event, [ttnn.SubDeviceId(1)])
ttnn.wait_for_event(0, event)
y2 = ttnn.to_torch(yt2, device=device, mesh_composer=output_mesh_composer, sub_device_ids=[ttnn.SubDeviceId(0)])

eq = torch.equal(x, y2)
assert eq

device.clear_loaded_sub_device_manager()
device.remove_sub_device_manager(sub_device_manager)

@@ -98,8 +140,9 @@ def test_sub_devices(device, enable_async_mode):


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_sub_devices_mesh(mesh_device, enable_async_mode):
run_sub_devices(mesh_device)
@pytest.mark.parametrize("replicate_sub_devices", (False, True))
def test_sub_devices_mesh(mesh_device, replicate_sub_devices, enable_async_mode):
run_sub_devices(mesh_device, replicate_sub_devices)


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
@@ -108,5 +151,6 @@ def test_sub_device_program(device, enable_async_mode):


@pytest.mark.parametrize("enable_async_mode", (False, True), indirect=True)
def test_sub_device_program_mesh(mesh_device, enable_async_mode):
run_sub_devices_program(mesh_device)
@pytest.mark.parametrize("replicate_sub_devices", (False, True))
def test_sub_device_program_mesh(mesh_device, replicate_sub_devices, enable_async_mode):
run_sub_devices_program(mesh_device, replicate_sub_devices)
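
For reference, a minimal standalone sketch of the replicated-configuration path exercised above (not part of this diff). ttnn.open_mesh_device, ttnn.MeshShape, and ttnn.close_mesh_device are assumed from the existing ttnn API, and the mesh shape is illustrative:

import ttnn

# Assumed setup: open a 2x4 mesh (shape is illustrative).
mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))

tensix_cores = ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(3, 3))})
sub_device = ttnn.SubDevice([tensix_cores])

# One sub-device list per device in the mesh (the replicated layout).
sub_devices = [[sub_device]] * mesh_device.get_num_devices()
manager = mesh_device.create_sub_device_manager(sub_devices, 3200)  # 3200 bytes of local L1
mesh_device.load_sub_device_manager(manager)

# Wait only on sub-device 0 instead of the entire chip.
ttnn.synchronize_devices(mesh_device, sub_device_ids=[ttnn.SubDeviceId(0)])

mesh_device.clear_loaded_sub_device_manager()
mesh_device.remove_sub_device_manager(manager)
ttnn.close_mesh_device(mesh_device)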
18 changes: 18 additions & 0 deletions tt_metal/distributed/mesh_device.cpp
@@ -489,6 +489,24 @@ MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(tt::stl::Span<const
}
return mesh_sub_device_manager_id;
}

MeshSubDeviceManagerId MeshDevice::create_sub_device_manager(const std::vector<std::vector<SubDevice>>& mesh_sub_devices, DeviceAddr local_l1_size) {
MeshSubDeviceManagerId mesh_sub_device_manager_id(*this);
TT_FATAL(mesh_sub_devices.size() == this->num_devices(), "Number of devices does not match number of sub-device configurations");
for (uint32_t i = 0; i < this->num_devices(); i++) {
auto* device = this->devices[i];
auto& sub_device_manager_id = mesh_sub_device_manager_id.sub_device_manager_ids[i];
tt::stl::Span<const SubDevice> sub_devices(mesh_sub_devices[i]);
device->push_work([device, sub_devices, local_l1_size, &sub_device_manager_id]() {
sub_device_manager_id = device->create_sub_device_manager(sub_devices, local_l1_size);
});
}
for (auto* device : this->devices) {
device->synchronize();
}
return mesh_sub_device_manager_id;
}

void MeshDevice::load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id) {
for (uint32_t i = 0; i < this->num_devices(); i++) {
auto* device = this->devices[i];
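
The new overload takes one sub-device list per device, so configurations no longer have to be identical across the mesh. A sketch of the Python-side call, reusing sub_device_1 and sub_device_2 from the test above and assuming a two-device mesh is already open as mesh_device:

# Device 0 gets two sub-devices, device 1 gets one. The TT_FATAL above rejects
# the call if the outer list length differs from mesh_device.get_num_devices().
per_device_configs = [[sub_device_1, sub_device_2], [sub_device_2]]
manager = mesh_device.create_sub_device_manager(per_device_configs, 3200)
mesh_device.load_sub_device_manager(manager)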
2 changes: 2 additions & 0 deletions tt_metal/distributed/mesh_device.hpp
@@ -137,6 +137,8 @@ class MeshDevice : public std::enable_shared_from_this<MeshDevice> {

MeshSubDeviceManagerId create_sub_device_manager(
tt::stl::Span<const SubDevice> sub_devices, DeviceAddr local_l1_size);
MeshSubDeviceManagerId create_sub_device_manager(
const std::vector<std::vector<SubDevice>>& mesh_sub_devices, DeviceAddr local_l1_size);
void load_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id);
void clear_loaded_sub_device_manager();
void remove_sub_device_manager(MeshSubDeviceManagerId mesh_sub_device_manager_id);
23 changes: 20 additions & 3 deletions ttnn/cpp/pybind11/device.cpp
@@ -100,6 +100,8 @@ void py_device_module_types(py::module& m_device) {

py::class_<SubDevice>(m_device, "SubDevice", "Class describing a sub-device of a Tenstorrent accelerator device.");

py::class_<SubDeviceId>(m_device, "SubDeviceId", "ID of a sub-device.");

py::class_<SubDeviceManagerId>(m_device, "SubDeviceManagerId", "ID of a sub-device manager.");
}

@@ -114,6 +116,14 @@ void device_module(py::module& m_device) {
The order of cores is Tensix, then Ethernet.
)doc");

auto pySubDeviceId = static_cast<py::class_<SubDeviceId>>(m_device.attr("SubDeviceId"));
pySubDeviceId.def(
py::init<uint8_t>(),
py::arg("id"),
R"doc(
Creates a SubDeviceId object with the given ID.
)doc");

auto pyDevice = static_cast<py::class_<Device, std::unique_ptr<Device, py::nodelete>>>(m_device.attr("Device"));
pyDevice
.def(
@@ -482,21 +492,25 @@ void device_module(py::module& m_device) {

m_device.def(
"synchronize_device",
[](Device* device, const std::optional<uint8_t> cq_id) {
[](Device* device, const std::optional<uint8_t> cq_id, const std::vector<SubDeviceId>& sub_device_ids) {
// Send finish command to issue queue through worker thread
// Worker thread will stall until the device is flushed.
device->push_work([device, cq_id]() mutable { Synchronize(device, cq_id); });
device->push_work(
[device, cq_id, &sub_device_ids]() mutable { Synchronize(device, cq_id, sub_device_ids); });
// Main thread stalls until worker is complete (full device and worker queue flush).
device->synchronize();
},
R"doc(
Synchronize the device with host by waiting for all operations to complete.
If cq_id is provided, only the operations associated with that command queue are waited on;
otherwise, operations on all command queues are waited on.
If the device has been configured with sub-devices, sub_device_ids can be provided to wait only
for the operations that ran on the specified sub-devices; otherwise, all sub-devices (the entire chip) are waited on.

Args:
device (ttnn.device.Device): The device to synchronize with.
cq_id (int, optional): The command queue ID to synchronize. Defaults to `None`.
sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to synchronize. Defaults to all sub-devices.

Returns:
`None`: The op ensures that all operations are completed.
Expand All @@ -508,7 +522,8 @@ void device_module(py::module& m_device) {
>>> ttnn.synchronize_device(device)
)doc",
py::arg("device"),
py::arg("cq_id") = std::nullopt);
py::arg("cq_id") = std::nullopt,
py::arg("sub_device_ids") = std::vector<SubDeviceId>());
m_device.def("SetLazyCommandQueueMode", &tt::tt_metal::detail::SetLazyCommandQueueMode, R"doc(
If set to true, the host does not notify the device that there are commands available other than
the FinishCommand. Once set to false, all subsequent commands will immediately notify the device
Expand All @@ -527,6 +542,8 @@ void device_module(py::module& m_device) {

m_device.attr("DEFAULT_L1_SMALL_SIZE") = py::int_(DEFAULT_L1_SMALL_SIZE);
m_device.attr("DEFAULT_TRACE_REGION_SIZE") = py::int_(DEFAULT_TRACE_REGION_SIZE);

m_device.attr("DefaultQueueId") = ttnn::DefaultQueueId;
}

void py_device_module(py::module& module) {
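
With the SubDeviceId binding in place, a scoped synchronize looks like the sketch below. ttnn.open_device and ttnn.close_device are assumed from the existing API; the CQ and sub-device IDs are illustrative:

device = ttnn.open_device(device_id=0)
# ... enqueue work on command queue 0 ...
# Block only until programs on sub-device 1 of CQ 0 have finished,
# rather than draining the whole chip.
ttnn.synchronize_device(device, cq_id=0, sub_device_ids=[ttnn.SubDeviceId(1)])
ttnn.close_device(device)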
8 changes: 6 additions & 2 deletions ttnn/cpp/pybind11/events.cpp
@@ -31,15 +31,17 @@ void py_module(py::module& module) {

module.def(
"record_event",
py::overload_cast<uint8_t, const std::shared_ptr<Event>&>(&record_event),
py::overload_cast<uint8_t, const std::shared_ptr<Event>&, const std::vector<SubDeviceId>&>(&record_event),
py::arg("cq_id"),
py::arg("event"),
py::arg("sub_device_ids") = std::vector<SubDeviceId>(),
R"doc(
Record the completion of commands issued on this CQ before this call.

Args:
cq_id (int): The Command Queue on which event completion will be recorded.
event (Event): The event used to record completion of preceding commands.
sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to record completion for. Defaults to all sub-devices.
)doc");

module.def(
@@ -69,9 +71,10 @@ void py_module(py::module& module) {

module.def(
"record_event",
py::overload_cast<uint8_t, const MultiDeviceEvent&>(&record_event),
py::overload_cast<uint8_t, const MultiDeviceEvent&, const std::vector<SubDeviceId>&>(&record_event),
py::arg("cq_id"),
py::arg("multi_device_event"),
py::arg("sub_device_ids") = std::vector<SubDeviceId>(),
R"doc(
Record the completion of commands issued on this CQ before this call.

@@ -91,6 +94,7 @@ void py_module(py::module& module) {
Args:
cq_id (int): The Command Queue on which event completion will be recorded.
multi_device_event (MultiDeviceEvent): The event used to record completion of preceding commands.
sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to record completion for. Defaults to all sub-devices.
)doc");
}

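
A short sketch of sub-device-scoped event recording, mirroring the test above (device is assumed open with a sub-device manager loaded):

event = ttnn.create_event(device)
# Signal the event once work on sub-device 1 of CQ 0 completes,
# then block CQ 0 until the event is recorded.
ttnn.record_event(0, event, [ttnn.SubDeviceId(1)])
ttnn.wait_for_event(0, event)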
31 changes: 25 additions & 6 deletions ttnn/cpp/pybind11/operations/core.hpp
@@ -65,26 +65,41 @@ void py_module(py::module& module) {

module.def(
"to_device",
py::overload_cast<const ttnn::Tensor&, Device*, const std::optional<MemoryConfig>&>(
&ttnn::operations::core::to_device),
py::overload_cast<
const ttnn::Tensor&,
Device*,
const std::optional<MemoryConfig>&,
uint8_t,
const std::vector<SubDeviceId>&>(&ttnn::operations::core::to_device),
py::arg("tensor"),
py::arg("device"),
py::arg("memory_config") = std::nullopt);
py::arg("memory_config") = std::nullopt,
py::arg("cq_id") = ttnn::DefaultQueueId,
py::arg("sub_device_ids") = std::vector<SubDeviceId>());

module.def(
"to_device",
py::overload_cast<const ttnn::Tensor&, MeshDevice*, const std::optional<MemoryConfig>&>(
&ttnn::operations::core::to_device),
py::overload_cast<
const ttnn::Tensor&,
MeshDevice*,
const std::optional<MemoryConfig>&,
uint8_t,
const std::vector<SubDeviceId>&>(&ttnn::operations::core::to_device),
py::arg("tensor"),
py::arg("device"),
py::arg("memory_config") = std::nullopt,
py::arg("cq_id") = ttnn::DefaultQueueId,
py::arg("sub_device_ids") = std::vector<SubDeviceId>(),
R"doc(
Copy tensor from host to device.

Args:
tensor (ttnn.Tensor): The tensor to be copied from host to device.
device (ttnn.Device | ttnn.MeshDevice): The target device where the tensor will be copied.
memory_config (ttnn.MemoryConfig, optional): The memory configuration to use. Defaults to `None`.
cq_id (int, optional): The command queue ID to use. Defaults to `0`.
sub_device_ids (List[ttnn.SubDeviceId], optional): The sub-device IDs to wait on before writing the tensor to device memory.
If not provided, the device stalls until all programs on the specified CQ have finished before writing the tensor to device memory.

Returns:
ttnn.Tensor: The device tensor copy.
@@ -103,6 +118,7 @@ void py_module(py::module& module) {
py::arg("blocking") = true,
py::kw_only(),
py::arg("cq_id") = ttnn::DefaultQueueId,
py::arg("sub_device_ids") = std::vector<SubDeviceId>(),
R"doc(
Copy tensor from device to host.

@@ -112,6 +128,8 @@ void py_module(py::module& module) {

Keyword args:
cq_id (int, optional): the command queue ID to use. Defaults to `0`.
sub_device_ids (List[ttnn.SubDeviceId], optional): the sub-device IDs to wait on before reading the tensor from device memory.
If not provided, the device stalls until all programs on the specified CQ have finished before reading the tensor from device memory.

Returns:
ttnn.Tensor: the host tensor copy.
@@ -243,7 +261,8 @@ void py_module(py::module& module) {
&ttnn::operations::core::copy_host_to_device_tensor,
py::arg("host_tensor"),
py::arg("device_tensor"),
py::arg("cq_id") = ttnn::DefaultQueueId);
py::arg("cq_id") = ttnn::DefaultQueueId,
py::arg("sub_device_ids") = std::vector<SubDeviceId>());

module.def(
"begin_trace_capture",
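
A round-trip read/write sketch using the extended signatures (single device; torch is imported and device is assumed open with a sub-device manager loaded, as in the earlier sketches):

import torch

x = torch.randn(1, 1, 64, 64, dtype=torch.bfloat16)
# Write: stall only on sub-device 0 before copying the tensor into L1.
xt = ttnn.from_torch(
    x,
    dtype=ttnn.bfloat16,
    layout=ttnn.TILE_LAYOUT,
    device=device,
    memory_config=ttnn.L1_MEMORY_CONFIG,
    sub_device_ids=[ttnn.SubDeviceId(0)],
)
# Read: stall only on sub-device 0 before copying back to host.
y = ttnn.to_torch(xt, sub_device_ids=[ttnn.SubDeviceId(0)])
assert torch.equal(x, y)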