Commit

Merge remote-tracking branch 'origin/jchu/ttnn-integration-with-mesh' into sminakov/all-mesh2
sminakov-tt committed Feb 28, 2025
2 parents 6ef2fd8 + ad8be43 · commit d298eef
Showing 14 changed files with 254 additions and 72 deletions.
8 changes: 4 additions & 4 deletions conftest.py
@@ -342,13 +342,13 @@ def get_devices(request):
    elif "pcie_devices" in request.fixturenames:
        devices = request.getfixturevalue("pcie_devices")
    elif "mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("mesh_device").get_devices()
+        devices = [request.getfixturevalue("mesh_device")]
    elif "n300_mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("n300_mesh_device").get_devices()
+        devices = [request.getfixturevalue("n300_mesh_device")]
    elif "t3k_mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("t3k_mesh_device").get_devices()
+        devices = [request.getfixturevalue("t3k_mesh_device")]
    elif "pcie_mesh_device" in request.fixturenames:
-        devices = request.getfixturevalue("pcie_mesh_device").get_devices()
+        devices = [request.getfixturevalue("pcie_mesh_device")]
    else:
        devices = []
    return devices
3 changes: 1 addition & 2 deletions models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
@@ -526,8 +526,7 @@ def test_FalconCausalLM_end_to_end_with_program_cache(
    model_config = get_model_config(model_config_str, llm_mode, input_shape, num_devices)
    devices = t3k_mesh_device.get_devices()
    # Set async mode
-    for device in devices:
-        device.enable_async(async_mode)
+    t3k_mesh_device.enable_async(async_mode)
    compute_grid_size = devices[0].compute_with_storage_grid_size()
    if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
        pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")
4 changes: 1 addition & 3 deletions tests/ttnn/distributed/test_multidevice_TG.py
@@ -1448,9 +1448,7 @@ def test_device_line_all_gather_8x4_data(mesh_device, cluster_axis: int, dim: in
    - Every device will have the shape: [4, 1, 32, 32]
    """
    if async_mode:
-        for i in mesh_device.get_device_ids():
-            device = mesh_device.get_device(i)
-            device.enable_async(True)
+        mesh_device.enable_async(True)

    (rows, cols), tile_size = mesh_device.shape, 32
    full_tensor = torch.zeros((1, 1, tile_size * rows, tile_size * cols), dtype=torch.bfloat16)
21 changes: 21 additions & 0 deletions tests/ttnn/unit_tests/test_multi_device.py
@@ -247,6 +247,27 @@ def test_multi_device_single_op_unary(mesh_device):
    assert_with_pcc(ttnn_torch_output_tensor, torch_output_golden, pcc=0.999)


+def test_multi_device_single_op_unary_with_cache(mesh_device):
+    """Multidevice API test: Running tensor-parallel multi-device single-op unary with cache"""
+    mesh_device.enable_program_cache()
+
+    torch_input_tensor = torch.rand((1, 1, 32, 32 * mesh_device.get_num_devices()), dtype=torch.bfloat16)
+    torch_output_golden = torch.nn.functional.gelu(torch_input_tensor)
+    torch_golden = torch.nn.functional.gelu(torch_output_golden)
+
+    ttnn_input_tensor = ttnn.from_torch(
+        torch_input_tensor,
+        layout=ttnn.TILE_LAYOUT,
+        mesh_mapper=ShardTensorToMesh(mesh_device, dim=3),
+        device=mesh_device,
+    )
+    ttnn_output_tensor = ttnn.gelu(ttnn_input_tensor)
+    final_output_tensor = ttnn.gelu(ttnn_output_tensor)
+
+    ttnn_torch_output_tensor = ttnn.to_torch(final_output_tensor, mesh_composer=ConcatMeshToTensor(mesh_device, dim=3))
+    assert_with_pcc(ttnn_torch_output_tensor, torch_golden, pcc=0.999)
+
+
@pytest.mark.parametrize(
    "device_params",
    [{"dispatch_core_axis": ttnn.DispatchCoreAxis.ROW}, {"dispatch_core_axis": ttnn.DispatchCoreAxis.COL}],
5 changes: 4 additions & 1 deletion tt_metal/api/tt-metalium/device.hpp
@@ -20,11 +20,14 @@
#include "sub_device_manager.hpp"
#include "sub_device_types.hpp"
#include "span.hpp"
-#include "program_cache.hpp"

namespace tt {

namespace tt_metal {

+namespace program_cache::detail {
+class ProgramCache;
+}
/*
MemoryBlockTable is a list of memory blocks in the following format:
[{"blockID": "0", "address": "0", "size": "0", "prevID": "0", "nextID": "0", "allocated": true}]
74 changes: 73 additions & 1 deletion tt_metal/api/tt-metalium/program_cache.hpp
@@ -8,8 +8,18 @@

#include "program_impl.hpp"
#include "unique_any.hpp"
+#include "mesh_workload.hpp"

namespace tt::tt_metal::program_cache::detail {
+template <typename shared_variables_t>
+struct CachedMeshWorkload {
+    tt::tt_metal::distributed::MeshWorkload workload;
+    // Shared variables between create and override_runtime_arguments functions
+    shared_variables_t shared_variables;
+
+    CachedMeshWorkload(tt::tt_metal::distributed::MeshWorkload&& workload, shared_variables_t&& shared_variables) :
+        workload{std::move(workload)}, shared_variables{std::forward<shared_variables_t>(shared_variables)} {}
+};

template <typename shared_variables_t>
struct CachedProgram {
@@ -21,6 +31,68 @@ struct CachedProgram {
        program{std::move(program)}, shared_variables{std::forward<shared_variables_t>(shared_variables)} {}
};

+// Adapter that provides a unified interface for both CachedProgram and CachedMeshWorkload
+template <typename shared_variables_t>
+class ProgramAdapter {
+private:
+    using CachedObject = std::variant<CachedMeshWorkload<shared_variables_t>, CachedProgram<shared_variables_t>>;
+    CachedObject cached_object_;
+    // Helper to retrieve the first program from a mesh workload
+    static tt::tt_metal::Program& get_first_program(CachedMeshWorkload<shared_variables_t>& cached_mesh_workload) {
+        // Get the programs map from the workload
+        auto& programs = cached_mesh_workload.workload.get_programs();
+
+        // There must be at least one program in the workload
+        TT_FATAL(!programs.empty(), "Mesh workload must have at least one program");
+
+        // Return the first program in the workload
+        auto& first_program_pair = *programs.begin();
+        return first_program_pair.second;
+    }
+
+public:
+    // These are references to the original objects
+    tt::tt_metal::Program& program;
+    shared_variables_t& shared_variables;
+
+    // Constructor for CachedProgram
+    ProgramAdapter(CachedProgram<shared_variables_t>&& cached_program) :
+        cached_object_(std::move(cached_program)),
+        program(std::get<CachedProgram<shared_variables_t>>(cached_object_).program),
+        shared_variables(std::get<CachedProgram<shared_variables_t>>(cached_object_).shared_variables) {}
+
+    // Constructor for Program and shared variables
+    ProgramAdapter(tt::tt_metal::Program&& program, shared_variables_t&& shared_vars) :
+        ProgramAdapter(CachedProgram<shared_variables_t>{std::move(program), std::move(shared_vars)}) {}
+
+    // Constructor for CachedMeshWorkload
+    ProgramAdapter(CachedMeshWorkload<shared_variables_t>&& cached_mesh_workload) :
+        cached_object_(std::move(cached_mesh_workload)),
+        program(get_first_program(std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_))),
+        shared_variables(std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_).shared_variables) {}
+
+    ProgramAdapter(ProgramAdapter&& other) noexcept :
+        cached_object_{std::move(other.cached_object_)},
+        program{
+            (cached_object_.index() == 0)
+                ? get_first_program(std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_))
+                : std::get<CachedProgram<shared_variables_t>>(cached_object_).program},
+        shared_variables{
+            (cached_object_.index() == 0)
+                ? std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_).shared_variables
+                : std::get<CachedProgram<shared_variables_t>>(cached_object_).shared_variables} {}
+
+    // Get the CachedMeshWorkload (throws if not a mesh workload)
+    CachedMeshWorkload<shared_variables_t>& get_cached_mesh_workload() {
+        return std::get<CachedMeshWorkload<shared_variables_t>>(cached_object_);
+    }
+
+    // Get the CachedProgram (throws if not a program)
+    CachedProgram<shared_variables_t>& get_cached_program() {
+        return std::get<CachedProgram<shared_variables_t>>(cached_object_);
+    }
+};
+
struct CachedProgramFactory {
    static constexpr auto MAX_SIZE = 4096;
    static constexpr auto ALIGNMENT = 32;
@@ -30,7 +102,7 @@ struct CachedProgramFactory {
    std::size_t program_factory_index;

    template <typename shared_variables_t>
-    CachedProgramFactory(CachedProgram<shared_variables_t>&& cached_program, std::size_t program_factory_index) :
+    CachedProgramFactory(ProgramAdapter<shared_variables_t>&& cached_program, std::size_t program_factory_index) :
        cached_program{std::move(cached_program)}, program_factory_index{program_factory_index} {}
};

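Note on usage (commentary, not part of the diff): the adapter exposes the same program and shared_variables members whether the cache entry wraps a single-device Program or a MeshWorkload; for a mesh workload, program is bound to the first program inside it, so factory code written against a single Program keeps working unchanged. A minimal sketch of how an op-level program factory might construct either flavor; the SharedVariables struct and the make_* helpers below are hypothetical, and only ProgramAdapter, CachedMeshWorkload, and CachedProgramFactory come from this header:

// Hypothetical usage sketch; SharedVariables and the make_* helpers are illustrative only.
#include <cstddef>
#include <utility>
#include <tt-metalium/program_cache.hpp>

namespace cache = tt::tt_metal::program_cache::detail;

// Per-op state shared between create() and override_runtime_arguments() (assumed shape).
struct SharedVariables {
    std::size_t num_tiles = 0;
};

// Single-device path: the adapter wraps a CachedProgram internally.
cache::ProgramAdapter<SharedVariables> make_single_device_entry(tt::tt_metal::Program&& program, SharedVariables&& vars) {
    return cache::ProgramAdapter<SharedVariables>{std::move(program), std::move(vars)};
}

// Mesh path: the adapter wraps a CachedMeshWorkload; adapter.program then refers
// to the first program inside the workload.
cache::ProgramAdapter<SharedVariables> make_mesh_entry(tt::tt_metal::distributed::MeshWorkload&& workload, SharedVariables&& vars) {
    return cache::ProgramAdapter<SharedVariables>{
        cache::CachedMeshWorkload<SharedVariables>{std::move(workload), std::move(vars)}};
}

// Either adapter can then be stored through the updated CachedProgramFactory constructor, e.g.
//   cache::CachedProgramFactory entry{make_mesh_entry(std::move(workload), {}), /*program_factory_index=*/0};
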
1 change: 1 addition & 0 deletions tt_metal/api/tt-metalium/unique_any.hpp
@@ -54,6 +54,7 @@ struct unique_any final {
            }
            this->delete_storage = other.delete_storage;
            this->move_storage = other.move_storage;
+            other.pointer = nullptr;
        }
        return *this;
    }
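
For context on the one-line fix above (commentary, not part of the diff): without clearing other.pointer, the moved-from and the moved-to object would later both run cleanup against the same storage. A standalone sketch of that failure shape with a simplified Owner type; it does not reuse the real unique_any internals:

// Simplified stand-in for a pointer-owning type; illustrative only.
#include <cstdio>
#include <utility>

struct Owner {
    int* pointer = nullptr;

    Owner() = default;
    explicit Owner(int v) : pointer(new int(v)) {}
    ~Owner() { delete pointer; }

    Owner& operator=(Owner&& other) noexcept {
        if (this != &other) {
            delete pointer;
            pointer = other.pointer;
            other.pointer = nullptr;  // analogue of the added line: the moved-from object gives up ownership
        }
        return *this;
    }
};

int main() {
    Owner a(1);
    Owner b(2);
    b = std::move(a);                 // without the null-out, ~a and ~b would both delete the same allocation
    std::printf("%d\n", *b.pointer);  // prints 1
    return 0;
}
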
2 changes: 2 additions & 0 deletions tt_metal/impl/buffers/dispatch.cpp
@@ -7,7 +7,9 @@
#include "assert.hpp"
#include "dispatch.hpp"
#include <tt-metalium/command_queue_interface.hpp>
+#include <tt-metalium/device_command.hpp>
#include <tt-metalium/dispatch_settings.hpp>
+#include <tt-metalium/program_impl.hpp>

#include "tt_cluster.hpp"
1 change: 1 addition & 0 deletions tt_metal/impl/device/device.cpp
@@ -32,6 +32,7 @@
#include <sub_device_types.hpp>
#include <span.hpp>
#include <types.hpp>
+#include <tt-metalium/program_cache.hpp>

#include "impl/dispatch/topology.hpp"
#include "impl/dispatch/hardware_command_queue.hpp"
1 change: 1 addition & 0 deletions tt_metal/impl/event/dispatch.cpp
@@ -4,6 +4,7 @@

#include "tt_metal/impl/event/dispatch.hpp"
#include <tt-metalium/dispatch_settings.hpp>
+#include <tt-metalium/program_impl.hpp>
#include "tt_metal/impl/dispatch/dispatch_query_manager.hpp"
#include <tt_align.hpp>
1 change: 1 addition & 0 deletions tt_metal/impl/lightmetal/lightmetal_replay.hpp
@@ -10,6 +10,7 @@
#include <optional>
#include <lightmetal_binary.hpp>

+#include <tt-metalium/program_impl.hpp>
#include <tt-metalium/device.hpp>

// Forward decl for trace_buffer.hpp
2 changes: 1 addition & 1 deletion tt_metal/impl/trace/dispatch.cpp
@@ -4,7 +4,7 @@

#include "tt_metal/impl/trace/dispatch.hpp"
#include "tt_metal/impl/dispatch/dispatch_query_manager.hpp"
-
+#include <tt-metalium/device_command.hpp>
namespace tt::tt_metal::trace_dispatch {

void reset_host_dispatch_state_for_trace(