This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

fixed bugs #180

Merged · 1 commit · Jun 13, 2024
4 changes: 2 additions & 2 deletions cpp/src/wholememory/embedding_optimizer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2024, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -80,7 +80,6 @@ wholememory_tensor_t embedding_optimizer_impl_base::get_optimizer_state(
   WHOLEMEMORY_CHECK_NOTHROW(optimizer_state != nullptr);
   WHOLEMEMORY_CHECK_NOTHROW(state_names_.size() == optimizer_state->cachable_states.size() +
                               optimizer_state->uncachable_states.size() + 1);
-  WHOLEMEMORY_FAIL_NOTHROW("optimizer state name %s not found for %s", state_name, name_);
   for (size_t i = 0; i < optimizer_state->cachable_states.size(); i++) {
     if (strcmp(state_name, optimizer_state->cachable_states[i].name.c_str()) == 0) {
       WHOLEMEMORY_CHECK_NOTHROW(strcmp(state_name, state_names_[i]) == 0);
@@ -94,6 +93,7 @@ wholememory_tensor_t embedding_optimizer_impl_base::get_optimizer_state(
       return optimizer_state->uncachable_states[i].global_raw_sub_tensor;
     }
   }
+  WHOLEMEMORY_FAIL_NOTHROW("optimizer state name %s not found for %s", state_name, name_);
   return nullptr;
 }

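Note on the C++ change above: `WHOLEMEMORY_FAIL_NOTHROW` previously ran before either lookup loop, so `get_optimizer_state` aborted unconditionally even when the requested state existed. Moving the call after both loops means it fires only when no cachable or uncachable state name matches. A minimal Python sketch of the same control-flow fix (hypothetical names, not the WholeMemory API):

```python
# Hypothetical sketch: raising before the search (the old code path)
# fails every lookup; raising after the search (the fix) only reports
# genuinely missing names.

def get_optimizer_state_buggy(states: dict, state_name: str):
    raise KeyError(f"optimizer state name {state_name} not found")
    for name, tensor in states.items():  # never reached
        if name == state_name:
            return tensor

def get_optimizer_state_fixed(states: dict, state_name: str):
    for name, tensor in states.items():
        if name == state_name:
            return tensor
    # Reached only when no state matched the requested name.
    raise KeyError(f"optimizer state name {state_name} not found")
```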
@@ -828,9 +828,10 @@ cdef class PyWholeMemoryEmbedding:
         result = []
         cdef const char * const * state_names
         state_names = wholememory_embedding_get_optimizer_state_names(self.wm_embedding)
-        while state_names[i] != NULL:
-            result.append(<object> PyUnicode_FromString(state_names[i]))
-            i += 1
+        if state_names != NULL:
+            while state_names[i] != NULL:
+                result.append(<object> PyUnicode_FromString(state_names[i]))
+                i += 1
         return result

     def get_optimizer_state(self,
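Note on the Cython change above: `wholememory_embedding_get_optimizer_state_names` can return a null pointer (e.g. when the embedding has no optimizer attached), and the old loop dereferenced it immediately. The fix guards the walk over the NULL-terminated name array. A small Python analogue of the guard (hypothetical function, not the real binding):

```python
from typing import List, Optional

# Hypothetical analogue: `raw_names is None` plays the role of the
# `state_names != NULL` check added in the Cython code.
def get_optimizer_state_names(raw_names: Optional[List[str]]) -> List[str]:
    result: List[str] = []
    if raw_names is not None:
        for name in raw_names:  # the C array is NULL-terminated
            result.append(name)
    return result
```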
2 changes: 1 addition & 1 deletion python/pylibwholegraph/pylibwholegraph/torch/embedding.py
@@ -267,7 +267,7 @@ def __init__(
         super().__init__()
         self.wmb_embedding = wmb_embedding
         self.embedding_tensor = None
-        self.optimizer_states = None
+        self.optimizer_states = dict()

         self.wmb_optimizer = wmb_optimizer
         self.wmb_cache_policy = wmb_cache_policy
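Note on the Python change above: initializing `optimizer_states` to an empty dict rather than `None` lets later code populate states by key without a `TypeError`. A hypothetical sketch of the failure mode (class and method names are illustrative, not the pylibwholegraph API):

```python
class EmbeddingStateHolder:
    def __init__(self):
        self.optimizer_states = dict()  # was None before this fix

    def set_state(self, name, tensor):
        # With optimizer_states = None, this assignment would raise:
        #   TypeError: 'NoneType' object does not support item assignment
        self.optimizer_states[name] = tensor
```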