Skip to content

Commit

Permalink
add pybindings for custom 1D fabric ctx switch intervals (#18239)
Browse files Browse the repository at this point in the history
### Problem description
There is currently no one-size-fits-all context switch interval for 1D
fabric on Wormhole. In some use cases (e.g. test suites with many
back-to-back tests) we want smaller intervals so teardown is quick. In
other cases (real workloads), we want a longer interval since there may
be longer gaps between subsequent ops using a given fabric link.

### What's changed
Added pybindings for context switch interval override. By default, if a
user does not provide an override, the fabric will use the
implementation default, which is more favourable to test environments
and faster teardown times.

To override the context switch check interval, a user can pass an
override to either `create_and_load_sub_device_manager_with_fabric_interface`
or `ttnn.initialize_edm_fabric`. In both cases, the keyword-only
(`kw_only`) argument `context_switch_interval_override` is used to
override the interval. The current default is `10000`. For
performance-oriented workloads, it is recommended to start in the
100k-200k range and tweak from there.
  • Loading branch information
SeanNijjar authored Feb 25, 2025
1 parent e963fa4 commit bc10e86
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 4 deletions.
8 changes: 7 additions & 1 deletion tests/ttnn/unit_tests/operations/ccl/test_ccl_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def create_and_load_sub_device_manager_with_fabric_interface(
local_allocator_size,
enable_persistent_fabric=True,
wrap_fabric_around_mesh=False,
context_switch_interval_override=None,
):
assert ccl_worker_sub_device_id < len(worker_sub_devices)
mesh_sub_device_manager_id, fabric_subdevice_id = mesh_device.create_sub_device_manager_with_fabric(
Expand All @@ -21,11 +22,16 @@ def create_and_load_sub_device_manager_with_fabric_interface(
# fabric sub-device id can also be queried from device, no need to explicitly pass it in
mesh_device.load_sub_device_manager(mesh_sub_device_manager_id)
if enable_persistent_fabric:
ttnn.initialize_edm_fabric(mesh_device, wrap_fabric_around_mesh=wrap_fabric_around_mesh)
ttnn.initialize_edm_fabric(
mesh_device,
wrap_fabric_around_mesh=wrap_fabric_around_mesh,
context_switch_interval_override=context_switch_interval_override,
)
return mesh_sub_device_manager_id


def teardown_fabric_interface(mesh_device):
    """Tear down the EDM fabric on ``mesh_device`` and then synchronize its devices.

    A debug message is logged first, because teardown may take a while when
    the context switch interval is large.

    Args:
        mesh_device: Device handle forwarded unchanged to
            ``ttnn.teardown_edm_fabric`` and ``ttnn.synchronize_devices``.
    """
    # Plain string literal: the message has no placeholders, so no f-prefix is needed.
    logger.debug("Tearing down fabric (this may take a while if context switch interval is large)")
    ttnn.teardown_edm_fabric(mesh_device)
    # NOTE(review): presumably waits for the teardown to finish on the devices — confirm.
    ttnn.synchronize_devices(mesh_device)

Expand Down
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/ccl/ccl_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ void py_bind_common(pybind11::module& module) {
&ttnn::ccl::initialize_edm_fabric,
py::arg("mesh_device"),
py::kw_only(),
py::arg("wrap_fabric_around_mesh") = false);
py::arg("wrap_fabric_around_mesh") = false,
py::arg("context_switch_interval_override") = std::nullopt);

module.def("teardown_edm_fabric", &ttnn::ccl::teardown_edm_fabric, py::arg("mesh_device"), py::kw_only());
}
Expand Down
14 changes: 13 additions & 1 deletion ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,10 @@ void EdmLineFabricOpInterface::set_firmware_context_switch_interval(size_t inter
}
}

void initialize_edm_fabric(distributed::MeshDevice* mesh_device, bool wrap_fabric_around_mesh) {
void initialize_edm_fabric(
distributed::MeshDevice* mesh_device,
bool wrap_fabric_around_mesh,
std::optional<size_t> context_switch_interval_override) {
if (wrap_fabric_around_mesh) {
auto devices = mesh_device->get_view().get_ring_devices();
std::vector<Program*> program_ptrs;
Expand All @@ -835,6 +838,9 @@ void initialize_edm_fabric(distributed::MeshDevice* mesh_device, bool wrap_fabri
std::transform(
programs.begin(), programs.end(), std::back_inserter(program_ptrs), [](Program& p) { return &p; });
EdmLineFabricOpInterface fabric_device_builders = EdmLineFabricOpInterface(devices, program_ptrs, true);
if (context_switch_interval_override.has_value()) {
fabric_device_builders.set_firmware_context_switch_interval(context_switch_interval_override.value());
}
fabric_device_builders.build_kernels();

for (size_t i = 0; i < devices.size(); i++) {
Expand Down Expand Up @@ -865,6 +871,9 @@ void initialize_edm_fabric(distributed::MeshDevice* mesh_device, bool wrap_fabri
});
row_fabric_lines.push_back(
EdmLineFabricOpInterface(mesh_device->get_view().get_row_views()[i], program_ptrs, true));
if (context_switch_interval_override.has_value()) {
row_fabric_lines.back().set_firmware_context_switch_interval(context_switch_interval_override.value());
}
}

for (size_t i = 0; i < num_cols; i++) {
Expand All @@ -875,6 +884,9 @@ void initialize_edm_fabric(distributed::MeshDevice* mesh_device, bool wrap_fabri
}
col_fabric_lines.push_back(
EdmLineFabricOpInterface(mesh_device->get_view().get_column_views()[i], program_ptrs, true));
if (context_switch_interval_override.has_value()) {
col_fabric_lines.back().set_firmware_context_switch_interval(context_switch_interval_override.value());
}
}

std::for_each(row_fabric_lines.begin(), row_fabric_lines.end(), [](auto& line) { line.build_kernels(); });
Expand Down
5 changes: 4 additions & 1 deletion ttnn/cpp/ttnn/operations/ccl/erisc_datamover_builder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,10 @@ class EdmLineFabricOpInterface {
size_t firmware_context_switch_interval = FabricEriscDatamoverBuilder::default_firmware_context_switch_interval;
};

void initialize_edm_fabric(distributed::MeshDevice* mesh_device, bool wrap_fabric_around_mesh = false);
void initialize_edm_fabric(
distributed::MeshDevice* mesh_device,
bool wrap_fabric_around_mesh = false,
std::optional<size_t> context_switch_interval_override = std::nullopt);
void teardown_edm_fabric(distributed::MeshDevice* mesh_device);

}; // namespace ccl
Expand Down

0 comments on commit bc10e86

Please sign in to comment.