Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 709162113
Change-Id: I0f6e2ae742dd1b574d1b6d0fb2f8e9807c685fe1
  • Loading branch information
Akshaya Purohit authored and copybara-github committed Dec 23, 2024
1 parent 884aac9 commit 3f187ea
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
73 changes: 73 additions & 0 deletions qkeras/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright 2024 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""General purpose registy for registering classes or functions.
The registry can be used along with decorators to record any class/function.
Sample usage:
# Setup registry with decorator.
_REGISTRY = registry.Registry()
def register(cls):
_REGISTRY.register(cls)
def lookup(name):
return _REGISTRY.lookup(name)
# Register instances.
@register
def foo_task():
...
@register
def bar_task():
...
# Retrieve instances.
def my_executor():
...
my_task = lookup("foo_task")
...
"""


class Registry(object):
"""A registry class to record class representations or function objects."""

def __init__(self):
"""Initializes the registry."""
self._container = {}

def register(self, item, name=None):
"""Register an item.
Args:
item: Python item to be recorded.
name: Optional name to be used for recording item. If not provided,
item.__name__ is used.
"""
if not name:
name = item.__name__
self._container[name] = item

def lookup(self, name):
"""Retrieves an item from the registry.
Args:
name: Name of the item to lookup.
Returns:
Registered item from the registry.
"""
return self._container[name]
82 changes: 82 additions & 0 deletions tests/registry_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2024 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for registry."""

from numpy.testing import assert_equal
from numpy.testing import assert_raises
import pytest

from qkeras import registry


def sample_function(arg):
"""Sample function for testing."""
return arg


class SampleClass(object):
"""Sample class for testing."""

def __init__(self, arg):
self._arg = arg

def get_arg(self):
return self._arg


def test_register_function():
reg = registry.Registry()
reg.register(sample_function)
registered_function = reg.lookup('sample_function')
# Call the function to validate.
assert_equal(registered_function, sample_function)


def test_register_class():
reg = registry.Registry()
reg.register(SampleClass)
registered_class = reg.lookup('SampleClass')
# Create and call class object to validate.
assert_equal(SampleClass, registered_class)


def test_register_with_name():
reg = registry.Registry()
name = 'NewSampleClass'
reg.register(SampleClass, name=name)
registered_class = reg.lookup(name)
# Create and call class object to validate.
assert_equal(SampleClass, registered_class)


def test_lookup_missing_item():
reg = registry.Registry()
assert_raises(KeyError, reg.lookup, 'foo')


def test_lookup_missing_name():
reg = registry.Registry()
sample_class = SampleClass(arg=1)
# objects don't have a default __name__ attribute.
assert_raises(AttributeError, reg.register, sample_class)

# check that the object can be retrieved with a registered name.
reg.register(sample_class, 'sample_class')
assert_equal(sample_class, reg.lookup('sample_class'))


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit 3f187ea

Please sign in to comment.