Implementation of RTL layer #4

Merged
merged 5 commits on Dec 6, 2023
Changes from 3 commits
4 changes: 3 additions & 1 deletion .vscode/settings.json
@@ -1,6 +1,8 @@
 {
   "files.autoSave": "onFocusChange",
-  "editor.rulers": [88],
+  "editor.rulers": [
+    88
+  ],
   "editor.formatOnSaveMode": "file",
   "editor.formatOnSave": true,
   "files.insertFinalNewline": true,
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -71,6 +71,9 @@ indent-style = "space"
 skip-magic-trailing-comma = false
 line-ending = "auto"
 
+[tool.mypy]
+exclude = ["venv"]
+
 [build-system]
 requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
1 change: 0 additions & 1 deletion pytorch_lattice/enums.py
@@ -67,7 +67,6 @@ class CategoricalCalibratorInit(_Enum):
 class Monotonicity(_Enum):
     """Type of monotonicity constraint.
 
-    - NONE: no monotonicity constraint.
     - INCREASING: increasing monotonicity i.e. increasing input increases output.
     - DECREASING: decreasing monotonicity i.e. increasing input decreases output.
     """
1 change: 1 addition & 0 deletions pytorch_lattice/layers/__init__.py
@@ -3,3 +3,4 @@
 from .lattice import Lattice
 from .linear import Linear
 from .numerical_calibrator import NumericalCalibrator
+from .rtl import RTL
274 changes: 274 additions & 0 deletions pytorch_lattice/layers/rtl.py
@@ -0,0 +1,274 @@
"""Random Tiny Lattice module for use in calibrated modeling.
willbakst marked this conversation as resolved.
Show resolved Hide resolved

PyTorch implementation of a RTL layer.
This layer takes several inputs which would otherwise be slow to run on a single lattice
and runs random subsets on an assortment of Random Tiny Lattices as an optimization.
"""
import logging
from typing import List, Optional

import numpy as np
import torch

from ..enums import Interpolation, LatticeInit, Monotonicity
from .lattice import Lattice


class RTL(torch.nn.Module):
    """An RTL Module.

    The layer takes a number of features that would otherwise be too many to
    assign to a single lattice and instead assigns small random subsets of the
    features to an ensemble of smaller lattices. The features are shuffled and
    uniformly repeated if there are more slots in the RTL than features.

    Attributes:
        - All `__init__` arguments.
        - _lattice_layers: `dict` of form `{monotonic_count: (lattice, groups)}`
            which keeps track of the RTL structure. Features are indexed, then
            randomly grouped together to be assigned to a lattice. Groups with
            the same number of monotonic features can be put into the same
            multi-unit lattice for further optimization, and are thus stored
            together in the dict according to `monotonic_count`.

    Example:
    ```python
    inputs = torch.tensor(...)  # shape: (batch_size, D)
    monotonicities = [...]  # List[Monotonicity] of length D
    rtl1 = RTL(
        monotonicities,
        num_lattices=5,
        lattice_rank=3,  # num_lattices * lattice_rank must be >= D
    )
    output1 = rtl1(inputs)

    # If you want to pass through consecutive RTLs
    rtl2 = RTL(
        monotonicities=rtl1.output_monotonicities(),  # len: rtl1.num_lattices
        ...
    )
    output2 = rtl2(output1)
    ```
    """

    def __init__(
        self,
        monotonicities: List[Monotonicity],
        num_lattices: int,
        lattice_rank: int,
        lattice_size: int = 2,
        output_min: Optional[float] = None,
        output_max: Optional[float] = None,
        kernel_init: LatticeInit = LatticeInit.LINEAR,
        clip_inputs: bool = True,
        interpolation: Interpolation = Interpolation.HYPERCUBE,
        average_outputs: bool = False,
        random_seed: int = 42,
    ) -> None:
        """Initializes an instance of `RTL`.

        Args:
            monotonicities: `List` of `Monotonicity.INCREASING` or `None`
                indicating monotonicities of input features, ordered respectively.
            num_lattices: number of lattices in the RTL structure.
            lattice_rank: number of inputs for each lattice in the RTL structure.
            lattice_size: size (number of vertices) of each lattice dimension.
            output_min: minimum output of each lattice in the RTL.
            output_max: maximum output of each lattice in the RTL.
            kernel_init: initialization scheme to use for lattices.
            clip_inputs: whether inputs should be clipped to the range of each
                lattice.
            interpolation: interpolation scheme for each lattice in the RTL.
            average_outputs: whether to average the outputs of every lattice in
                the RTL.
            random_seed: seed used for shuffling.

        Raises:
            ValueError: If the size of the RTL, determined by
                `num_lattices * lattice_rank`, is too small to support the
                number of input features.
        """
        super().__init__()

        if len(monotonicities) > num_lattices * lattice_rank:
            raise ValueError(
                f"RTL with {num_lattices}x{lattice_rank}D structure cannot support "
                + f"{len(monotonicities)} input features."
            )
        self.monotonicities = monotonicities
        self.num_lattices = num_lattices
        self.lattice_rank = lattice_rank
        self.lattice_size = lattice_size
        self.output_min = output_min
        self.output_max = output_max
        self.kernel_init = kernel_init
        self.clip_inputs = clip_inputs
        self.interpolation = interpolation
        self.average_outputs = average_outputs
        self.random_seed = random_seed

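        # Tile feature indices across all num_lattices * lattice_rank slots,
        # repeating features uniformly when there are more slots than features,
        # then shuffle them for random assignment.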
        rtl_indices = np.array(
            [i % len(self.monotonicities) for i in range(num_lattices * lattice_rank)]
        )
        np.random.seed(self.random_seed)
        np.random.shuffle(rtl_indices)
        split_rtl_indices = [list(arr) for arr in np.split(rtl_indices, num_lattices)]
        swapped_rtl_indices = self._ensure_unique_sublattices(split_rtl_indices)
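        # Group the index groupings by their count of monotonic features so
        # that groupings sharing a monotonicity pattern can be fused into a
        # single multi-unit Lattice below.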
        monotonicity_groupings = {}
        for lattice_indices in swapped_rtl_indices:
            monotonic_count = sum(
                1
                for idx in lattice_indices
                if self.monotonicities[idx] == Monotonicity.INCREASING
            )
            if monotonic_count not in monotonicity_groupings:
                monotonicity_groupings[monotonic_count] = [lattice_indices]
            else:
                monotonicity_groupings[monotonic_count].append(lattice_indices)
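        # Within each group, order indices so monotonic features come first,
        # matching the monotonicities list passed to each Lattice below.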
        for monotonic_count, groups in monotonicity_groupings.items():
            for i, lattice_indices in enumerate(groups):
                sorted_indices = sorted(
                    lattice_indices,
                    key=lambda x: (self.monotonicities[x] is None),
                    reverse=False,
                )
                groups[i] = sorted_indices

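        # Construct one multi-unit Lattice per monotonic_count; each unit
        # handles one grouping of feature indices.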
        self._lattice_layers = {}
        for monotonic_count, groups in monotonicity_groupings.items():
            self._lattice_layers[monotonic_count] = (
                Lattice(
                    lattice_sizes=[self.lattice_size] * self.lattice_rank,
                    output_min=self.output_min,
                    output_max=self.output_max,
                    kernel_init=self.kernel_init,
                    monotonicities=[Monotonicity.INCREASING] * monotonic_count
                    + [None] * (lattice_rank - monotonic_count),
                    clip_inputs=self.clip_inputs,
                    interpolation=self.interpolation,
                    units=len(groups),
                ),
                groups,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward method computed by using the forward method of each lattice in the RTL.

        Args:
            x: input tensor of feature values with shape `(batch_size, num_features)`.

        Returns:
            `torch.Tensor` containing the outputs of each lattice within the RTL
            structure. If `average_outputs == True`, then all outputs are averaged
            into a tensor of shape `(batch_size, 1)`. If `average_outputs == False`,
            the shape of the tensor is `(batch_size, num_lattices)`.
        """
        forward_results = []
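        # Iterate in sorted monotonic_count order so output columns line up
        # with output_monotonicities(); a multi-unit lattice takes stacked
        # inputs of shape (batch_size, units, lattice_rank).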
        for _, (lattice, group) in sorted(self._lattice_layers.items()):
            if len(group) > 1:
                lattice_input = torch.stack([x[:, idx] for idx in group], dim=-2)
            else:
                lattice_input = x[:, group[0]]
            forward_results.append(lattice.forward(lattice_input))
        result = torch.cat(forward_results, dim=-1)
        if not self.average_outputs:
            return result
        result = torch.mean(result, dim=-1, keepdim=True)

        return result

    @torch.no_grad()
    def output_monotonicities(self) -> List[Monotonicity]:
        """Gives the monotonicities of the outputs of the RTL layer.

        Returns:
            List of `Monotonicity` corresponding to each output of the RTL layer,
            in the same order as the outputs.
        """
        monotonicities = []
        for monotonic_count, (lattice, _) in sorted(self._lattice_layers.items()):
            if monotonic_count:
                monotonicity = Monotonicity.INCREASING
            else:
                monotonicity = None
            for _ in range(lattice.units):
                monotonicities.append(monotonicity)

        return monotonicities

    @torch.no_grad()
    def constrain(self) -> None:
        """Enforces constraints for each lattice in the RTL."""
        for lattice, _ in self._lattice_layers.values():
            lattice.constrain()

    @torch.no_grad()
    def assert_constraints(self, eps: float = 1e-6) -> List[List[str]]:
        """Asserts that each Lattice in the RTL satisfies all constraints.

        Args:
            eps: allowed magnitude of constraint violations.

        Returns:
            List of lists, each containing the constraint violations for an
            individual Lattice.
        """
        return [
            lattice.assert_constraints(eps=eps)
            for lattice, _ in self._lattice_layers.values()
        ]

    @staticmethod
    def _ensure_unique_sublattices(
        rtl_indices: List[List[int]],
        max_swaps: int = 10000,
    ) -> List[List[int]]:
        """Attempts to ensure that every lattice in the RTL structure contains unique features.

        Args:
            rtl_indices: list of lists where the inner lists are groupings of
                indices of input features to the RTL layer.
            max_swaps: maximum number of swaps to perform before giving up.

        Returns:
            List of lists where elements between the inner lists have been swapped
            in an attempt to remove duplicates from every grouping.
        """
        swaps = 0
        num_sublattices = len(rtl_indices)

        def find_swap_candidate(current_index, element):
            """Helper function to find the next sublattice not containing element."""
            for offset in range(1, num_sublattices):
                candidate_index = (current_index + offset) % num_sublattices
                if element not in rtl_indices[candidate_index]:
                    return candidate_index
            return None

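        # Greedily scan each sublattice; when a duplicate is found, swap it
        # with an element from another sublattice that contains neither value.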
        for i, sublattice in enumerate(rtl_indices):
            unique_elements = set()
            for element in sublattice:
                if element in unique_elements:
                    swap_with = find_swap_candidate(i, element)
                    if swap_with is not None:
                        for swap_element in rtl_indices[swap_with]:
                            if swap_element not in sublattice:
                                # Perform the swap
                                idx_to_swap = rtl_indices[swap_with].index(swap_element)
                                idx_duplicate = sublattice.index(element)
                                (
                                    rtl_indices[swap_with][idx_to_swap],
                                    sublattice[idx_duplicate],
                                ) = element, swap_element
                                swaps += 1
                                break
                    else:
                        logging.info(
                            "Some lattices in RTL may use the same feature multiple "
                            "times."
                        )
                        return rtl_indices
                else:
                    unique_elements.add(element)
                if swaps >= max_swaps:
                    logging.info(
                        "Some lattices in RTL may use the same feature multiple times."
                    )
                    return rtl_indices
        return rtl_indices