Skip to content

Commit

Permalink
Adding Missing includes and explicitly declaring Tensor in aten names…
Browse files Browse the repository at this point in the history
…pace.

Summary:
X-link: facebookresearch/FBGEMM#713

Build failures otherwise.

Reviewed By: daulet-askarov, jiayulu

Differential Revision: D68858972

fbshipit-source-id: 4427dfd82cf1aef49efe0ec89e63335625e9a349
  • Loading branch information
pradeepfn authored and facebook-github-bot committed Jan 30, 2025
1 parent 5def5ca commit 3049ebe
Showing 1 changed file with 12 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#pragma once

#include "kv_tensor_wrapper.h"
#include "ssd_table_batched_embeddings.h"

namespace ssd {

Expand Down Expand Up @@ -58,30 +59,34 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
enable_async_update)) {}

void set_cuda(
Tensor indices,
Tensor weights,
Tensor count,
at::Tensor indices,
at::Tensor weights,
at::Tensor count,
int64_t timestep,
bool is_bwd) {
return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
}

void get_cuda(Tensor indices, Tensor weights, Tensor count) {
void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
return impl_->get_cuda(indices, weights, count);
}

void set(Tensor indices, Tensor weights, Tensor count) {
void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
return impl_->set(indices, weights, count);
}

void set_range_to_storage(
const Tensor& weights,
const at::Tensor& weights,
const int64_t start,
const int64_t length) {
return impl_->set_range_to_storage(weights, start, length);
}

void get(Tensor indices, Tensor weights, Tensor count, int64_t sleep_ms) {
void get(
at::Tensor indices,
at::Tensor weights,
at::Tensor count,
int64_t sleep_ms) {
return impl_->get(indices, weights, count, sleep_ms);
}

Expand Down

0 comments on commit 3049ebe

Please sign in to comment.