-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 709162113 Change-Id: I0f6e2ae742dd1b574d1b6d0fb2f8e9807c685fe1
- Loading branch information
1 parent
884aac9
commit 3f187ea
Showing
2 changed files
with
155 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""General purpose registy for registering classes or functions. | ||
The registry can be used along with decorators to record any class/function. | ||
Sample usage: | ||
# Setup registry with decorator. | ||
_REGISTRY = registry.Registry() | ||
def register(cls): | ||
_REGISTRY.register(cls) | ||
def lookup(name): | ||
return _REGISTRY.lookup(name) | ||
# Register instances. | ||
@register | ||
def foo_task(): | ||
... | ||
@register | ||
def bar_task(): | ||
... | ||
# Retrieve instances. | ||
def my_executor(): | ||
... | ||
my_task = lookup("foo_task") | ||
... | ||
""" | ||
|
||
|
||
class Registry(object): | ||
"""A registry class to record class representations or function objects.""" | ||
|
||
def __init__(self): | ||
"""Initializes the registry.""" | ||
self._container = {} | ||
|
||
def register(self, item, name=None): | ||
"""Register an item. | ||
Args: | ||
item: Python item to be recorded. | ||
name: Optional name to be used for recording item. If not provided, | ||
item.__name__ is used. | ||
""" | ||
if not name: | ||
name = item.__name__ | ||
self._container[name] = item | ||
|
||
def lookup(self, name): | ||
"""Retrieves an item from the registry. | ||
Args: | ||
name: Name of the item to lookup. | ||
Returns: | ||
Registered item from the registry. | ||
""" | ||
return self._container[name] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Unit tests for registry.""" | ||
|
||
from numpy.testing import assert_equal | ||
from numpy.testing import assert_raises | ||
import pytest | ||
|
||
from qkeras import registry | ||
|
||
|
||
def sample_function(arg): | ||
"""Sample function for testing.""" | ||
return arg | ||
|
||
|
||
class SampleClass(object): | ||
"""Sample class for testing.""" | ||
|
||
def __init__(self, arg): | ||
self._arg = arg | ||
|
||
def get_arg(self): | ||
return self._arg | ||
|
||
|
||
def test_register_function(): | ||
reg = registry.Registry() | ||
reg.register(sample_function) | ||
registered_function = reg.lookup('sample_function') | ||
# Call the function to validate. | ||
assert_equal(registered_function, sample_function) | ||
|
||
|
||
def test_register_class(): | ||
reg = registry.Registry() | ||
reg.register(SampleClass) | ||
registered_class = reg.lookup('SampleClass') | ||
# Create and call class object to validate. | ||
assert_equal(SampleClass, registered_class) | ||
|
||
|
||
def test_register_with_name(): | ||
reg = registry.Registry() | ||
name = 'NewSampleClass' | ||
reg.register(SampleClass, name=name) | ||
registered_class = reg.lookup(name) | ||
# Create and call class object to validate. | ||
assert_equal(SampleClass, registered_class) | ||
|
||
|
||
def test_lookup_missing_item(): | ||
reg = registry.Registry() | ||
assert_raises(KeyError, reg.lookup, 'foo') | ||
|
||
|
||
def test_lookup_missing_name(): | ||
reg = registry.Registry() | ||
sample_class = SampleClass(arg=1) | ||
# objects don't have a default __name__ attribute. | ||
assert_raises(AttributeError, reg.register, sample_class) | ||
|
||
# check that the object can be retrieved with a registered name. | ||
reg.register(sample_class, 'sample_class') | ||
assert_equal(sample_class, reg.lookup('sample_class')) | ||
|
||
|
||
if __name__ == '__main__': | ||
pytest.main([__file__]) |