Skip to content

Commit

Permalink
Fix (backport/fx): remove pytree warning (#1144)
Browse files Browse the repository at this point in the history
* Fix: pytree warning

* Fix: backportability
  • Loading branch information
i-colbert authored Jan 3, 2025
1 parent 39ce837 commit 3aadf17
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/brevitas/backport/fx/immutable_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@

from typing import Any, Dict, List, Tuple

from torch.utils._pytree import _register_pytree_node
try:
from torch.utils._pytree import register_pytree_node
except:
# Deprecated as of 2.3, but keeping for backportability
from torch.utils._pytree import _register_pytree_node
register_pytree_node = _register_pytree_node
from torch.utils._pytree import Context

from ._compatibility import compatibility
Expand Down Expand Up @@ -111,5 +116,5 @@ def _immutable_list_unflatten(values: List[Any], context: Context) -> List[Any]:
return immutable_list(values)


_register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
_register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)
register_pytree_node(immutable_dict, _immutable_dict_flatten, _immutable_dict_unflatten)
register_pytree_node(immutable_list, _immutable_list_flatten, _immutable_list_unflatten)

0 comments on commit 3aadf17

Please sign in to comment.