fix graph tracing errors in test_convert_python, fix extraneous tt_fatal type error, improve bfloat16 conversions
jjiangTT committed Jan 28, 2025
1 parent ff70a26 commit cc8f9e0
Showing 3 changed files with 3 additions and 27 deletions.
20 changes: 0 additions & 20 deletions tests/ttnn/unit_tests/tensor/test_convert_python_tensor.py
@@ -20,45 +20,25 @@
def test_convert_python_tensor(device, size, mode, dtype):
torch.manual_seed(0)

ttnn.graph.begin_graph_capture(mode)
torch_input_tensor = torch.rand((size,), dtype=dtype)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device)
output_tensor = ttnn.to_torch(input_tensor, torch_rank=1)
captured_graph = ttnn.graph.end_graph_capture()
calltrace = ttnn.graph.extract_calltrace(captured_graph)

assert output_tensor == input_tensor

# note: change this test case if force_disable_borrow is exposed to user
assert output_tensor.storage_type() == ttnn.StorageType.BORROWED

assert "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" in calltrace
assert captured_graph[0]["node_type"] == "capture_start"
assert captured_graph[1]["node_type"] == "function_start"
assert captured_graph[1]["params"]["name"] == "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor"
assert captured_graph[-2]["node_type"] == "buffer_deallocate"
assert captured_graph[-1]["node_type"] == "capture_end"


@pytest.mark.parametrize("size", [64])
@pytest.mark.parametrize("mode", [ttnn.graph.RunMode.NO_DISPATCH, ttnn.graph.RunMode.NORMAL])
@pytest.mark.parametrize("dtype", [ttnn.bfloat4_b, ttnn.bfloat8_b])
def test_convert_python_tensor_bfp_b(device, size, mode, dtype):
torch.manual_seed(0)

ttnn.graph.begin_graph_capture(mode)
torch_input_tensor = torch.rand((size,), dtype=torch.float)
input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, dtype=dtype)
output_tensor = ttnn.to_torch(input_tensor, torch_rank=1)
captured_graph = ttnn.graph.end_graph_capture()
calltrace = ttnn.graph.extract_calltrace(captured_graph)

assert output_tensor == input_tensor
assert output_tensor.storage_type() != ttnn.StorageType.BORROWED

assert "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" in calltrace
assert captured_graph[0]["node_type"] == "capture_start"
assert captured_graph[1]["node_type"] == "function_start"
assert captured_graph[1]["params"]["name"] == "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor"
assert captured_graph[-2]["node_type"] == "buffer_deallocate"
assert captured_graph[-1]["node_type"] == "capture_end"
2 changes: 1 addition & 1 deletion ttnn/cpp/pybind11/pytensor.cpp
@@ -145,7 +145,7 @@ Tensor create_tt_tensor_from_py_data(
}
case DataType::BFLOAT8_B:
case DataType::BFLOAT4_B: {
-return create_typed_tt_tensor_from_py_data<bfloat16>(
+return create_typed_tt_tensor_from_py_data<float>(
py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow);
}
case DataType::INVALID: {
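The template argument change matters because BFLOAT8_B/BFLOAT4_B tensors arrive from Python as 32-bit floats (the test above builds them from a torch.float tensor) and are presumably packed into the block format further down the pipeline, so reading py_data_ptr as bfloat16 halves the apparent element width. A minimal standalone sketch of that size mismatch (plain C++, not the tt-metal API; the 64-element count mirrors the test's parametrized size, and uint16_t stands in for bfloat16's 2-byte width):

// Standalone illustration only: the same 64-element host buffer reports half
// the bytes when read with the wrong element type, which is the kind of size
// mismatch a TT_FATAL check would flag.
#include <cstdint>
#include <cstdio>

int main() {
    const std::size_t num_elements = 64;  // matches the test's parametrized size
    std::printf("as float:    %zu bytes\n", num_elements * sizeof(float));          // 256
    std::printf("as bfloat16: %zu bytes\n", num_elements * sizeof(std::uint16_t));  // 128
    return 0;
}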
8 changes: 2 additions & 6 deletions ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -612,12 +612,8 @@ Tensor Tensor::from_span<float>(
return create_owned_tensor_from_row_major_data(
std::vector<float>(buffer.begin(), buffer.end()), spec, device);
case DataType::BFLOAT16: {
-std::vector<bfloat16> bfloat16_data;
-bfloat16_data.reserve(buffer.size());
-std::transform(std::begin(buffer), std::end(buffer), std::back_inserter(bfloat16_data), [](float value) {
-    return bfloat16(value);
-});
-return create_owned_tensor_from_row_major_data(std::move(bfloat16_data), spec, device);
+return create_owned_tensor_from_row_major_data(
+    std::vector<bfloat16>(buffer.begin(), buffer.end()), spec, device);
}
case DataType::BFLOAT8_B:
case DataType::BFLOAT4_B: {
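The simplified BFLOAT16 branch relies on std::vector's iterator-range constructor converting each float through bfloat16's float constructor, i.e. the same per-element conversion the removed std::transform performed, just without the reserve/back_inserter plumbing. A minimal sketch of that equivalence; the toy bfloat16 below is an assumption for illustration only, not the tt-metal type:

// Standalone sketch: a range-constructed std::vector<bfloat16> converts each
// float via bfloat16(float), matching the removed transform loop.
#include <cstdint>
#include <cstring>
#include <vector>

struct bfloat16 {
    std::uint16_t bits;
    bfloat16(float value) {  // truncating float32 -> bfloat16, illustration only
        std::uint32_t u;
        std::memcpy(&u, &value, sizeof(u));
        bits = static_cast<std::uint16_t>(u >> 16);
    }
};

int main() {
    std::vector<float> buffer = {1.0f, 2.5f, -3.25f};
    // Equivalent to the removed loop: each float goes through bfloat16(float).
    std::vector<bfloat16> bfloat16_data(buffer.begin(), buffer.end());
    return bfloat16_data.size() == buffer.size() ? 0 : 1;
}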
