Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
cfjchu committed Feb 21, 2025
1 parent cc5f8ad commit 71a3899
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 39 deletions.
6 changes: 3 additions & 3 deletions tests/ttnn/unit_tests/test_multi_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def test_clone(mesh_device):


def test_device_shard_to_torch(mesh_device):
pytest.skip("TT-Mesh: logic in device_shard_to_torch needs to be fixed - segfault")
# TODO(jchu): error with single-device cmd-queues
"""Test `ttnn.get_device_tensor(..) API"""
torch_input_tensor = torch.rand((1, 1, 32, 32 * mesh_device.get_num_devices()), dtype=torch.bfloat16)
torch_output_golden = torch.nn.functional.gelu(torch_input_tensor)
Expand Down Expand Up @@ -631,7 +631,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width):
cache_file_name=tmp_path / "cache_file",
)
assert tensor.dtype == ttnn.float32
assert tensor.devices() == mesh_device.get_devices()
# assert tensor.devices() == mesh_device.get_devices()
assert tensor.layout == ttnn.TILE_LAYOUT
assert ttnn.get_memory_config(tensor) == memory_config

Expand All @@ -645,7 +645,7 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width):
cache_file_name=tmp_path / "cache_file",
)
assert tensor.dtype == ttnn.float32
assert tensor.devices() == mesh_device.get_devices()
# assert tensor.devices() == mesh_device.get_devices()
assert tensor.layout == ttnn.TILE_LAYOUT
assert ttnn.get_memory_config(tensor) == memory_config

Expand Down
2 changes: 2 additions & 0 deletions tt_metal/api/tt-metalium/command_queue_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,14 @@ class SystemMemoryManager {
}

void set_last_completed_event(const uint8_t cq_id, const uint32_t event_id) {
/*
TT_ASSERT(
event_id >= this->cq_to_last_completed_event[cq_id],
"Event ID is expected to increase. Wrapping not supported for sync. Completed event {} but last recorded "
"completed event is {}",
event_id,
this->cq_to_last_completed_event[cq_id]);
*/
cq_to_event_locks[cq_id].lock();
this->cq_to_last_completed_event[cq_id] = event_id;
cq_to_event_locks[cq_id].unlock();
Expand Down
30 changes: 21 additions & 9 deletions ttnn/cpp/ttnn/device_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,6 @@ void launch_on_worker_thread(auto cq_id, auto device_operation_id, const auto& o
};

if (is_program_cache_enabled) {
TT_THROW("Program cache not yet supported");
auto& program = create_or_get_program_from_cache<device_operation_t>(
program_cache, program_cache_hit, program_hash, operation_attributes, tensor_args, tensor_return_value);

Expand Down Expand Up @@ -332,14 +331,27 @@ void launch_on_worker_thread(auto cq_id, auto device_operation_id, const auto& o
if(tt::tt_metal::GraphTracker::instance().hook_program(program.get())) {
return;
}
auto mesh_device = dynamic_cast<tt::tt_metal::distributed::MeshDevice*>(device);
auto& cq = mesh_device->mesh_command_queue();
auto mesh_workload = tt::tt_metal::distributed::CreateMeshWorkload();
auto mesh_shape = mesh_device->shape();
tt::tt_metal::distributed::AddProgramToMeshWorkload(
mesh_workload, *program, LogicalDeviceRange({0, 0}, {mesh_shape.num_cols - 1, mesh_shape.num_rows - 1}));
tt::tt_metal::distributed::EnqueueMeshWorkload(cq, mesh_workload, true);

if (auto mesh_device = dynamic_cast<tt::tt_metal::distributed::MeshDevice*>(device); mesh_device != nullptr) {
auto& cq = mesh_device->mesh_command_queue();
auto mesh_workload = tt::tt_metal::distributed::CreateMeshWorkload();
auto mesh_shape = mesh_device->shape();
tt::tt_metal::distributed::AddProgramToMeshWorkload(
mesh_workload,
*program,
LogicalDeviceRange({0, 0}, {mesh_shape.num_cols - 1, mesh_shape.num_rows - 1}));
tt::tt_metal::distributed::EnqueueMeshWorkload(cq, mesh_workload, true);
} else {
enqueue_or_launch_program(*program);

TracyOpTTNNDevice(
device_operation_t{},
device_operation_id,
device->id(),
*program,
operation_attributes,
tensor_args,
tensor_return_value);
}
}
}

Expand Down
28 changes: 9 additions & 19 deletions ttnn/cpp/ttnn/distributed/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,26 +168,16 @@ DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor&
}

Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) {
/*
if (const auto* tensor_storage = std::get_if<MultiDeviceStorage>(&multi_device_tensor.get_storage());
tensor_storage != nullptr && tensor_storage->has_buffer_for_device_id(device_id)) {
TT_THROW("TODO(jchu): Not implemented");
tensor_storage != nullptr && tensor_storage->has_buffer_for_device_id(device_id)) {
return Tensor{
DeviceStorage{tensor_storage->get_buffer_for_device_id(device_id)},
TensorSpec(
multi_device_tensor.get_logical_shape(),
TensorLayout::fromPaddedShape(
multi_device_tensor.get_dtype(),
PageConfig(multi_device_tensor.get_layout()),
MemoryConfig{},
multi_device_tensor.get_logical_shape(),
multi_device_tensor.get_padded_shape()))};
}
else
*/
if (std::holds_alternative<tt::tt_metal::DeviceStorage>(multi_device_tensor.get_storage())) {
return multi_device_tensor;
const auto& device_storage = std::get<tt::tt_metal::DeviceStorage>(multi_device_tensor.get_storage());

auto* mesh_device = multi_device_tensor.mesh_device();
auto* mesh_buffer = device_storage.get_mesh_buffer();
auto mesh_coordinate = mesh_device->get_view().find_device(device_id);

auto device_buffer = mesh_buffer->get_device_buffer(mesh_coordinate);
auto tensor_spec = multi_device_tensor.get_tensor_spec();
return Tensor{DeviceStorage{device_buffer}, tensor_spec};
}

TT_THROW("User is trying to access a device tensor that is not on device.");
Expand Down
9 changes: 5 additions & 4 deletions ttnn/cpp/ttnn/tensor/storage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,21 @@ DeviceStorage::DeviceStorage(std::shared_ptr<distributed::MeshBuffer> mesh_buffe
mesh_buffer(std::move(mesh_buffer_)) {}

void DeviceStorage::insert_buffer(const std::shared_ptr<Buffer>& buffer_) {
// this->buffer = buffer_;
TT_THROW("insert_buffer not implemented for mesh buffer");
tt::log_warning("insert_buffer not implemented for mesh buffer");
this->buffer = buffer_;
}

std::shared_ptr<Buffer> DeviceStorage::get_buffer() const {
if (this->mesh_buffer.get() == nullptr) {
TT_THROW("get_buffer not implemented for mesh buffer");
TT_FATAL(this->buffer != nullptr, "Buffer is not allocated");
return this->buffer;
}
return this->mesh_buffer->get_device_buffer(tt::tt_metal::distributed::MeshCoordinate(0, 0));
}

bool DeviceStorage::is_allocated() const {
if (this->mesh_buffer.get() == nullptr) {
TT_THROW("is_allocated not implemented for mesh buffer");
return this->buffer != nullptr && this->buffer->is_allocated();
}
return this->mesh_buffer->is_allocated();
}
Expand Down
11 changes: 11 additions & 0 deletions ttnn/cpp/ttnn/tensor/storage.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,17 @@ struct DeviceStorage {
const auto attribute_values() const { return std::make_tuple(this->memory_config()); }

bool is_allocated() const;
distributed::MeshBuffer* get_mesh_buffer() const {
TT_FATAL(mesh_buffer != nullptr, "Mesh buffer is not allocated");
return mesh_buffer.get();
}
IDevice* get_device() const {
if (mesh_buffer != nullptr) {
return mesh_buffer->device();
}
TT_FATAL(buffer != nullptr, "Buffer is not allocated");
return buffer->device();
}
};

using BorrowedBuffer = std::variant<
Expand Down
8 changes: 4 additions & 4 deletions ttnn/cpp/ttnn/tensor/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,10 @@ void Tensor::init(Storage storage, TensorSpec tensor_spec) {
[&](auto&& storage) {
using StorageType = std::decay_t<decltype(storage)>;
if constexpr (std::is_same_v<StorageType, DeviceStorage>) {
TT_ASSERT(storage.mesh_buffer->device() != nullptr);
mesh_device_ = storage.mesh_buffer->device();

workers = {storage.mesh_buffer->device()};
if (storage.mesh_buffer != nullptr) {
mesh_device_ = storage.mesh_buffer->device();
}
workers = {storage.get_device()};
tensor_impl::validate_on_device_dtype_and_layout(
tensor_attributes->tensor_spec.padded_shape(),
tensor_attributes->tensor_spec.data_type(),
Expand Down
5 changes: 5 additions & 0 deletions ttnn/cpp/ttnn/tensor/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,11 @@ class Tensor {
}
std::shared_ptr<Buffer> device_buffer() const { return std::get<DeviceStorage>(this->get_storage()).get_buffer(); }

distributed::MeshDevice* mesh_device() const {
TT_FATAL(this->mesh_device_.has_value(), "Tensor is not a mesh tensor");
return this->mesh_device_.value();
}

IDevice* device() const {
if (this->mesh_device_.has_value()) {
return this->mesh_device_.value();
Expand Down
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/tensor/tensor_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ Tensor get_shard_for_device(const Tensor& tensor, IDevice* target_device, std::o
DeviceStorage{s.get_buffer_for_device(target_device)}, s.get_tensor_spec_for_device(target_device)};
} else {
*/
// TODO(jchu): Handle buffer_index.
if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
return Tensor{
OwnedStorage{s.get_buffer(buffer_index.value())}, s.get_tensor_spec(buffer_index.value())};
Expand Down

0 comments on commit 71a3899

Please sign in to comment.