From cc8f9e09002c12159bf87cea53554b34d97d91f2 Mon Sep 17 00:00:00 2001 From: Jeffrey Jiang Date: Tue, 28 Jan 2025 22:10:16 +0000 Subject: [PATCH] fix graph tracing errors in test_convert_python, fix extraneous tt_fatal type error, improve bfloat16 conversions --- .../tensor/test_convert_python_tensor.py | 20 ------------------- ttnn/cpp/pybind11/pytensor.cpp | 2 +- ttnn/cpp/ttnn/tensor/tensor.cpp | 8 ++------ 3 files changed, 3 insertions(+), 27 deletions(-) diff --git a/tests/ttnn/unit_tests/tensor/test_convert_python_tensor.py b/tests/ttnn/unit_tests/tensor/test_convert_python_tensor.py index 7cfdf8bcc08..087221c57f1 100644 --- a/tests/ttnn/unit_tests/tensor/test_convert_python_tensor.py +++ b/tests/ttnn/unit_tests/tensor/test_convert_python_tensor.py @@ -20,25 +20,16 @@ def test_convert_python_tensor(device, size, mode, dtype): torch.manual_seed(0) - ttnn.graph.begin_graph_capture(mode) torch_input_tensor = torch.rand((size,), (dtype)) input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device) output_tensor = ttnn.to_torch(input_tensor, torch_rank=1) captured_graph = ttnn.graph.end_graph_capture() - calltrace = ttnn.graph.extract_calltrace(captured_graph) assert output_tensor == input_tensor # note: change this test case if force_disable_borrow is exposed to user assert output_tensor.storage_type() == ttnn.StorageType.BORROWED - assert "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" in calltrace - assert captured_graph[0]["node_type"] == "capture_start" - assert captured_graph[1]["node_type"] == "function_start" - assert captured_graph[1]["params"]["name"] == "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" - assert captured_graph[-2]["node_type"] == "buffer_deallocate" - assert captured_graph[-1]["node_type"] == "capture_end" - @pytest.mark.parametrize("size", [64]) @pytest.mark.parametrize("mode", [ttnn.graph.RunMode.NO_DISPATCH, ttnn.graph.RunMode.NORMAL]) @@ -46,19 +37,8 @@ def test_convert_python_tensor(device, size, mode, dtype): def test_convert_python_tensor_bfp_b(device, size, mode, dtype): torch.manual_seed(0) - ttnn.graph.begin_graph_capture(mode) torch_input_tensor = torch.rand((size,), torch.float) input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT, device=device, dtype=(dtype)) output_tensor = ttnn.to_torch(input_tensor, torch_rank=1) - captured_graph = ttnn.graph.end_graph_capture() - calltrace = ttnn.graph.extract_calltrace(captured_graph) - assert output_tensor == input_tensor assert output_tensor.storage_type() != ttnn.StorageType.BORROWED - - assert "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" in calltrace - assert captured_graph[0]["node_type"] == "capture_start" - assert captured_graph[1]["node_type"] == "function_start" - assert captured_graph[1]["params"]["name"] == "tt::tt_metal::detail::convert_python_tensor_to_tt_tensor" - assert captured_graph[-2]["node_type"] == "buffer_deallocate" - assert captured_graph[-1]["node_type"] == "capture_end" diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 1e68d119f96..44921278004 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -145,7 +145,7 @@ Tensor create_tt_tensor_from_py_data( } case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: { - return create_typed_tt_tensor_from_py_data( + return create_typed_tt_tensor_from_py_data( py_data_ptr, tensor_spec, device, on_creation_callback, on_destruction_callback, force_disable_borrow); } case DataType::INVALID: { diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index c635423af18..da332ef70ea 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -612,12 +612,8 @@ Tensor Tensor::from_span( return create_owned_tensor_from_row_major_data( std::vector(buffer.begin(), buffer.end()), spec, device); case DataType::BFLOAT16: { - std::vector bfloat16_data; - bfloat16_data.reserve(buffer.size()); - std::transform(std::begin(buffer), std::end(buffer), std::back_inserter(bfloat16_data), [](float value) { - return bfloat16(value); - }); - return create_owned_tensor_from_row_major_data(std::move(bfloat16_data), spec, device); + return create_owned_tensor_from_row_major_data( + std::vector(buffer.begin(), buffer.end()), spec, device); } case DataType::BFLOAT8_B: case DataType::BFLOAT4_B: {