Skip to content

Commit

Permalink
Adding couple more APIs to KVTensorWrapper to bring partiy with torch…
Browse files Browse the repository at this point in the history
…::Tensor (#3645)

Summary:
X-link: facebookresearch/FBGEMM#721


Differential Revision: D68934783
  • Loading branch information
pradeepfn authored and facebook-github-bot committed Feb 3, 2025
1 parent 98d54f7 commit f80d4ca
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
const int64_t length,
const at::Tensor& weights);

c10::IntArrayRef size();
c10::IntArrayRef sizes();

c10::IntArrayRef strides();

c10::ScalarType dtype();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,16 @@ void KVTensorWrapper::set_range(
FBEXCEPTION("Not implemented");
}

c10::IntArrayRef KVTensorWrapper::size() {
c10::IntArrayRef KVTensorWrapper::sizes() {
FBEXCEPTION("Not implemented");
return shape_;
}

c10::IntArrayRef KVTensorWrapper::strides() {
FBEXCEPTION("Not implemented");
return shape_; // make linter happy.
}

c10::ScalarType KVTensorWrapper::dtype() {
FBEXCEPTION("Not implemented");
return options_.dtype().toScalarType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,20 @@ void KVTensorWrapper::set_range(
}
}

c10::IntArrayRef KVTensorWrapper::size() {
c10::IntArrayRef KVTensorWrapper::sizes() {
return shape_;
}

c10::IntArrayRef KVTensorWrapper::strides() {
// Assume contiguous tensor.
std::vector<int64_t> strides(shape_.size(), 1);
for (int i = shape_.size() - 2; i > -1; i--) {
int prev = i + 1;
strides[i] = strides[prev] * std::max<int64_t>(shape_[prev], 1);
}
return strides;
}

c10::ScalarType KVTensorWrapper::dtype() {
return options_.dtype().toScalarType();
}
Expand Down Expand Up @@ -500,7 +510,7 @@ static auto kv_tensor_wrapper =
.def_property("layout_str", &KVTensorWrapper::layout_str)
.def_property(
"shape",
&KVTensorWrapper::size,
&KVTensorWrapper::sizes,
std::string(
"Returns the shape of the original tensor. Only the narrowed part is materialized."));

Expand Down

0 comments on commit f80d4ca

Please sign in to comment.