From a693b65b9e2abce6008d181e71f52d67de6bd28c Mon Sep 17 00:00:00 2001 From: Jackson Nie Date: Tue, 25 Feb 2025 01:40:23 +0000 Subject: [PATCH] Support Int32 datatype --- tt_torch/csrc/bindings.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tt_torch/csrc/bindings.cpp b/tt_torch/csrc/bindings.cpp index 9b61b409..87572ae7 100644 --- a/tt_torch/csrc/bindings.cpp +++ b/tt_torch/csrc/bindings.cpp @@ -22,8 +22,10 @@ static tt::target::DataType torch_scalar_type_to_dt(torch::ScalarType st) { return tt::target::DataType::UInt8; case torch::ScalarType::Short: return tt::target::DataType::UInt16; - case torch::ScalarType::Int: + case torch::ScalarType::UInt32: return tt::target::DataType::UInt32; + case torch::ScalarType::Int: + return tt::target::DataType::Int32; case torch::ScalarType::Long: return tt::target::DataType::UInt32; case torch::ScalarType::Half: @@ -54,6 +56,8 @@ static torch::ScalarType dt_to_torch_scalar_type(tt::target::DataType df) { case tt::target::DataType::UInt16: return torch::ScalarType::Short; case tt::target::DataType::UInt32: + return torch::ScalarType::UInt32; + case tt::target::DataType::Int32: return torch::ScalarType::Int; case tt::target::DataType::Float16: return torch::ScalarType::Half;