diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 53dd90cd..c8e75c96 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -67,7 +67,7 @@ jobs: #---------------------------------------------- - name: Test with coverage run: | - poetry run pytest --cov=pgmax --cov-report=xml + poetry run pytest --cov=pgmax --cov-report=xml --cov-report=term-missing:skip-covered #---------------------------------------------- # upload coverage report to codecov #---------------------------------------------- diff --git a/examples/heretic_example.py b/examples/heretic_example.py index 1df57561..e9d64217 100644 --- a/examples/heretic_example.py +++ b/examples/heretic_example.py @@ -8,10 +8,12 @@ # format_version: '1.3' # jupytext_version: 1.11.4 # kernelspec: -# display_name: 'Python 3.7.11 64-bit (''pgmax-zIh0MZVc-py3.7'': venv)' -# name: python371164bitpgmaxzih0mzvcpy37venve540bb1b5cdf4292a3f5a12c4904cc40 +# display_name: Python 3 +# language: python +# name: python3 # --- +# %% from timeit import default_timer as timer from typing import Any, List, Tuple @@ -91,14 +93,10 @@ # Create the factor graph fg = graph.FactorGraph((pixel_vars, hidden_vars)) -# Assign evidence to pixel vars -fg.set_evidence(0, np.array(bXn_evidence)) -fg.set_evidence(1, np.array(bHn_evidence)) - - # %% [markdown] # # Add all Factors to graph via constructing FactorGroups + # %% tags=[] def binary_connected_variables( num_hidden_rows, num_hidden_cols, kernel_row, kernel_col @@ -118,7 +116,7 @@ def binary_connected_variables( W_pot = W_orig.swapaxes(0, 1) for k_row in range(3): for k_col in range(3): - fg.add_factors( + fg.add_factor( factor_factory=groups.PairwiseFactorGroup, connected_var_keys=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], @@ -174,16 +172,22 @@ def custom_flatten_ordering(Mdown, Mup): reshaped_Mdown = Mdown.reshape(3, 3, 3, 28, 28) reshaped_Mup = Mup.reshape(17, 3, 3, 28, 28) -init_msgs = jax.device_put( - custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup)) -) - # %% [markdown] # # Run Belief Propagation and Retrieve MAP Estimate # %% tags=[] # Run BP +init_msgs = fg.get_init_msgs() +init_msgs.ftov = graph.FToVMessages( + factor_graph=fg, + init_value=jax.device_put( + custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup)) + ), +) +init_msgs.evidence[0] = np.array(bXn_evidence) +init_msgs.evidence[1] = np.array(bHn_evidence) bp_start_time = timer() +# Assign evidence to pixel vars final_msgs = fg.run_bp( 500, 0.5, diff --git a/examples/ising_model.py b/examples/ising_model.py new file mode 100644 index 00000000..9b58b37f --- /dev/null +++ b/examples/ising_model.py @@ -0,0 +1,87 @@ +# --- +# jupyter: +# jupytext: +# formats: ipynb,py:percent +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.4 +# kernelspec: +# display_name: Python 3 +# language: python +# name: python3 +# --- + +# %% +# %matplotlib inline +import matplotlib.pyplot as plt +import numpy as np + +from pgmax.fg import graph, groups + +# %% [markdown] +# ### Construct variable grid, initialize factor graph, and add factors + +# %% +variables = groups.NDVariableArray(variable_size=2, shape=(50, 50)) +fg = graph.FactorGraph(variables=variables, evidence_default_mode="random") +connected_var_keys = [] +for ii in range(50): + for jj in range(50): + kk = (ii + 1) % 50 + ll = (jj + 1) % 50 + connected_var_keys.append([(ii, jj), (kk, jj)]) + connected_var_keys.append([(ii, jj), (ii, ll)]) + +fg.add_factor( + factor_factory=groups.PairwiseFactorGroup, + connected_var_keys=connected_var_keys, + log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), + name="factors", +) + +# %% [markdown] +# ### Run inference and visualize results + +# %% +msgs = fg.run_bp(3000, 0.5) +map_states = fg.decode_map_states(msgs) +img = np.zeros((50, 50)) +for key in map_states: + img[key] = map_states[key] + +fig, ax = plt.subplots(1, 1, figsize=(10, 10)) +ax.imshow(img) + +# %% [markdown] +# ### Message and evidence manipulation + +# %% +# Query evidence for variable (0, 0) +msgs.evidence[0, 0] + +# %% +# Set evidence for variable (0, 0) +msgs.evidence[0, 0] = np.array([1.0, 1.0]) +msgs.evidence[0, 0] + +# %% +# Set evidence for all variables using an array +evidence = np.random.randn(50, 50, 2) +msgs.evidence[:] = evidence +msgs.evidence[10, 10] == evidence[10, 10] + +# %% +# Query messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0) +msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] + +# %% +# Set messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0) +msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] = np.array([1.0, 1.0]) +msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] + +# %% +# Uniformly spread expected belief at a variable to all connected factors +msgs.ftov[0, 0] = np.array([1.0, 1.0]) +msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] diff --git a/examples/sanity_check_example.py b/examples/sanity_check_example.py index 6b991fcf..cc519bd2 100644 --- a/examples/sanity_check_example.py +++ b/examples/sanity_check_example.py @@ -8,7 +8,8 @@ # format_version: '1.3' # jupytext_version: 1.11.4 # kernelspec: -# display_name: 'Python 3.8.5 64-bit (''pgmax-JcKb81GE-py3.8'': poetry)' +# display_name: Python 3 +# language: python # name: python3 # --- @@ -154,18 +155,12 @@ # %% # Create the factor graph -fg = graph.FactorGraph( - variable_groups=composite_grid_group, -) - -# Set the evidence -fg.set_evidence("grid_vars", grid_evidence_arr) -fg.set_evidence("additional_vars", additional_vars_evidence_dict) - +fg = graph.FactorGraph(variables=composite_grid_group) # %% [markdown] # ### Create Valid Configuration Arrays + # %% # Helper function to easily generate a list of valid configurations for a given suppression diameter def create_valid_suppression_config_arr(suppression_diameter): @@ -268,7 +263,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ("additional_vars", 1, row + 1, col), ] - fg.add_factors( + fg.add_factor( curr_keys, valid_configs_non_supp, np.zeros(valid_configs_non_supp.shape[0], dtype=float), @@ -312,21 +307,18 @@ def create_valid_suppression_config_arr(suppression_diameter): for c in range(start_col, start_col + SUPPRESSION_DIAMETER) ] ) -horz_suppression_group = groups.EnumerationFactorGroup( - composite_grid_group, horz_suppression_keys, valid_configs_supp -) # %% [markdown] # ### Add FactorGroups Remaining to FactorGraph # %% -fg.add_factors( +fg.add_factor( factor_factory=groups.EnumerationFactorGroup, connected_var_keys=vert_suppression_keys, factor_configs=valid_configs_supp, ) -fg.add_factors( +fg.add_factor( factor_factory=groups.EnumerationFactorGroup, connected_var_keys=horz_suppression_keys, factor_configs=valid_configs_supp, @@ -337,8 +329,12 @@ def create_valid_suppression_config_arr(suppression_diameter): # %% # Run BP +# Set the evidence +init_msgs = fg.get_init_msgs() +init_msgs.evidence["grid_vars"] = grid_evidence_arr +init_msgs.evidence["additional_vars"] = additional_vars_evidence_dict bp_start_time = timer() -final_msgs = fg.run_bp(1000, 0.5) +final_msgs = fg.run_bp(1000, 0.5, init_msgs=init_msgs) bp_end_time = timer() print(f"time taken for bp {bp_end_time - bp_start_time}") diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 6d63b9a4..55fcf778 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -1,8 +1,11 @@ """A module containing the core class to specify a Factor Graph.""" +from __future__ import annotations + +import typing from dataclasses import dataclass from types import MappingProxyType -from typing import Any, Dict, List, Mapping, Sequence, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -17,55 +20,54 @@ class FactorGraph: """Class for representing a factor graph Args: - variable_groups: A container containing multiple VariableGroups, or a CompositeVariableGroup. - If not a CompositeVariableGroup, supported containers include mapping, sequence and single - VariableGroup. + variables: A single VariableGroup or a container containing variable groups. + If not a single VariableGroup, supported containers include mapping and sequence. For a mapping, the keys of the mapping are used to index the variable groups. For a sequence, the indices of the sequence are used to index the variable groups. - Note that a CompositeVariableGroup will be created from this input, and the individual - VariableGroups will need to be accessed by indexing this. - evidence_default: string representing a setting that specifies the default evidence value for - any variable whose evidence was not explicitly specified using 'set_evidence' + Note that if not a single VariableGroup, a CompositeVariableGroup will be created from + this input, and the individual VariableGroups will need to be accessed by indexing. + messages_default_mode: default mode for initializing messages. + Allowed values are "zeros" and "random". + evidence_default_mode: default mode for initializing evidence. + Allowed values are "zeros" and "random". + Any variable whose evidence was not explicitly specified using 'set_evidence' Attributes: - _composite_variable_group: CompositeVariableGroup. contains all involved VariableGroups + _variable_group: VariableGroup. contains all involved VariableGroups _factor_groups: List of added factor groups num_var_states: int. represents the sum of all variable states of all variables in the FactorGraph _vars_to_starts: MappingProxyType[nodes.Variable, int]. maps every variable to an int representing an index in the evidence array at which the first entry of the evidence for that particular variable should be placed. - _vars_to_evidence: Dict[nodes.Variable, np.ndarray]. maps every variable to an np.ndarray - representing the evidence for that variable _vargroups_set: Set[groups.VariableGroup]. keeps track of all the VariableGroup's that have been added to this FactorGraph + _named_factor_groups: Dict[Hashable, groups.FactorGroup]. A dictionary mapping the names of + named factor groups to the corresponding factor groups. + We only support setting messages from factors within explicitly named factor groups + to connected variables. + _total_factor_num_states: int. Current total number of edge states for the added factors. + _factor_group_to_starts: Dict[groups.FactorGroup, int]. Maps a factor group to its + corresponding starting index in the flat message array. """ - variable_groups: Union[ - Mapping[Any, groups.VariableGroup], + variables: Union[ + Mapping[Hashable, groups.VariableGroup], Sequence[groups.VariableGroup], groups.VariableGroup, ] + messages_default_mode: str = "zeros" evidence_default_mode: str = "zeros" def __post_init__(self): - if isinstance(self.variable_groups, groups.CompositeVariableGroup): - self._composite_variable_group = self.variable_groups - elif isinstance(self.variable_groups, groups.VariableGroup): - self._composite_variable_group = groups.CompositeVariableGroup( - [self.variable_groups] - ) + if isinstance(self.variables, groups.VariableGroup): + self._variable_group = self.variables else: - self._composite_variable_group = groups.CompositeVariableGroup( - self.variable_groups - ) + self._variable_group = groups.CompositeVariableGroup(self.variables) vars_num_states_cumsum = np.insert( np.array( - [ - variable.num_states - for variable in self._composite_variable_group.variables - ], + [variable.num_states for variable in self._variable_group.variables], dtype=int, ).cumsum(), 0, @@ -74,21 +76,21 @@ def __post_init__(self): self._vars_to_starts = MappingProxyType( { variable: vars_num_states_cumsum[vv] - for vv, variable in enumerate(self._composite_variable_group.variables) + for vv, variable in enumerate(self._variable_group.variables) } ) self.num_var_states = vars_num_states_cumsum[-1] - - self._vars_to_evidence: Dict[nodes.Variable, np.ndarray] = {} - self._factor_groups: List[groups.FactorGroup] = [] + self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} + self._total_factor_num_states: int = 0 + self._factor_group_to_starts: Dict[groups.FactorGroup, int] = {} - def add_factors( + def add_factor( self, *args, **kwargs, ) -> None: - """Function to add factors to this FactorGraph. + """Function to add factor/factor group to this FactorGraph. Args: *args: optional sequence of arguments. If specified, and if there is no @@ -105,29 +107,83 @@ def add_factors( indices of variables ot be indexed to create the EnumerationFactor). If there is a "factor_factory" key, then these args are taken to specify the arguments to be used to construct the class specified by the - "factor_factory" argument. Note that either *args or **kwargs must be - specified. + "factor_factory" argument. + If there is a "name" key, we add the added factor/factor group to the list + of named factors within the factor graph. + Note that either *args or **kwargs must be specified. """ + name = kwargs.pop("name", None) + if name in self._named_factor_groups: + raise ValueError( + f"A factor group with the name {name} already exists. Please choose a different name!" + ) + factor_factory = kwargs.pop("factor_factory", None) if factor_factory is not None: - factor_group = factor_factory( - self._composite_variable_group, *args, **kwargs - ) + factor_group = factor_factory(self._variable_group, *args, **kwargs) else: if len(args) > 0: new_args = list(args) new_args[0] = [args[0]] factor_group = groups.EnumerationFactorGroup( - self._composite_variable_group, *new_args, **kwargs + self._variable_group, *new_args, **kwargs ) else: keys = kwargs.pop("keys") kwargs["connected_var_keys"] = [keys] factor_group = groups.EnumerationFactorGroup( - self._composite_variable_group, **kwargs + self._variable_group, **kwargs ) self._factor_groups.append(factor_group) + self._factor_group_to_starts[factor_group] = self._total_factor_num_states + self._total_factor_num_states += np.sum(factor_group.factor_num_states) + if name is not None: + self._named_factor_groups[name] = factor_group + + def get_factor(self, key: Any) -> Tuple[nodes.EnumerationFactor, int]: + """Function to get an individual factor and start index + + Args: + key: the key for the factor. + The queried factor must be part of an named factor group. + + Returns: + A tuple of length 2, containing the queried factor and its corresponding + start index in the flat message array. + """ + if key in self._named_factor_groups: + if len(self._named_factor_groups[key].factors) != 1: + raise ValueError( + f"Invalid factor key {key}. " + "Please provide a key for an individual factor, " + "not a factor group" + ) + + factor_group = self._named_factor_groups[key] + factor = factor_group.factors[0] + start = self._factor_group_to_starts[factor_group] + else: + if not ( + isinstance(key, tuple) + and len(key) == 2 + and key[0] in self._named_factor_groups + ): + raise ValueError( + f"Invalid factor key {key}. " + "Please provide a key either for an individual named factor, " + "or a tuple of length 2 specifying name of the factor group " + "and index of individual factors" + ) + + factor_group = self._named_factor_groups[key[0]] + factor = factor_group[key[1]] + + start = self._factor_group_to_starts[factor_group] + np.sum( + factor_group.factor_num_states[: factor_group.factors.index(factor)] + ) + + return factor, start @property def wiring(self) -> nodes.EnumerationWiring: @@ -162,118 +218,54 @@ def factor_configs_log_potentials(self) -> np.ndarray: ] ) - @property - def evidence(self) -> np.ndarray: - """Function to generate evidence array. Need to be overwritten for concrete factor graphs - - Returns: - Array of shape (num_var_states,) representing the flattened evidence for each variable - - Raises: - NotImplementedError: if self.evidence_default is a string that is not listed - """ - if self.evidence_default_mode == "zeros": - evidence = np.zeros(self.num_var_states) - elif self.evidence_default_mode == "random": - evidence = np.random.gumbel(size=self.num_var_states) - else: - raise NotImplementedError( - f"evidence_default_mode {self.evidence_default_mode} is not yet implemented" - ) - - for var, evidence_val in self._vars_to_evidence.items(): - start_index = self._vars_to_starts[var] - evidence[start_index : start_index + var.num_states] = evidence_val - - return evidence - @property def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: """List of individual factors in the factor graph""" return sum([factor_group.factors for factor_group in self._factor_groups], ()) - def get_init_msgs(self, context: Any = None): + def get_init_msgs(self) -> Messages: """Function to initialize messages. - By default it initializes all messages to 0. Can be overwritten to support - customized initialization schemes - - Args: - context: Optional context for initializing messages - Returns: - array of shape (num_edge_state,) representing initialized factor to variable - messages - """ - return jnp.zeros(self.wiring.var_states_for_edges.shape[0]) - - def set_evidence( - self, - key: Union[Tuple[Any, ...], Any], - evidence: Union[Dict[Any, np.ndarray], np.ndarray], - ) -> None: - """Function to update the evidence for variables in the FactorGraph. - - Args: - key: tuple that represents the index into the CompositeVariableGroup - (self._composite_variable_group) that is created when the FactorGraph is instantiated. Note that - this can be an index referring to an entire VariableGroup (in which case, the evidence - is set for the entire VariableGroup at once), or to an individual Variable within the - CompositeVariableGroup. - evidence: a container for np.ndarrays representing the evidence - Currently supported containers are: - - an np.ndarray: if key indexes an NDVariableArray, then evidence_values - can simply be an np.ndarray with num_var_array_dims + 1 dimensions where - num_var_array_dims is the number of dimensions of the NDVariableArray, and the - +1 represents a dimension (that should be the final dimension) for the evidence. - Note that the size of the final dimension should be the same as - variable_group.variable_size. if key indexes a particular variable, then this array - must be of the same size as variable.num_states - - a dictionary: if key indexes a GenericVariableGroup, then evidence_values - must be a dictionary mapping keys of variable_group to np.ndarrays of evidence values. - Note that each np.ndarray in the dictionary values must have the same size as - variable_group.variable_size. + Initialized messages """ - if key in self._composite_variable_group.container_keys: - self._vars_to_evidence.update( - self._composite_variable_group.variable_group_container[ - key - ].get_vars_to_evidence(evidence) - ) - else: - self._vars_to_evidence[self._composite_variable_group[key]] = evidence + return Messages( + ftov=FToVMessages( + factor_graph=self, default_mode=self.messages_default_mode + ), + evidence=Evidence( + factor_graph=self, default_mode=self.evidence_default_mode + ), + ) def run_bp( self, num_iters: int, damping_factor: float, - init_msgs: jnp.ndarray = None, - msgs_context: Any = None, - ) -> jnp.ndarray: + init_msgs: Optional[Messages] = None, + ) -> Messages: """Function to perform belief propagation. - Specifically, belief propagation is run on messages obtained from the self.get_init_msgs - method for num_iters iterations and returns the resulting messages. + Specifically, belief propagation is run for num_iters iterations and + returns the resulting messages. Args: num_iters: The number of iterations for which to perform message passing damping_factor: The damping factor to use for message updates between one timestep and the next - init_msgs: array of shape (num_edge_state,) representing the initial messaged on which to perform - belief propagation. If this argument is none, messages are generated by calling self.get_init_msgs() - msgs_context: Optional context for initializing messages + init_msgs: Initial messages to start the belief propagation. + If None, construct init_msgs by calling self.get_init_msgs() Returns: - an array of shape (num_edge_state,) that contains the message values after running BP for num_iters iterations + ftov messages after running BP for num_iters iterations """ # Retrieve the necessary data structures from the compiled self.wiring and # convert these to jax arrays. - if init_msgs is not None: - msgs = init_msgs - else: - msgs = self.get_init_msgs(msgs_context) + if init_msgs is None: + init_msgs = self.get_init_msgs() + msgs = jax.device_put(init_msgs.ftov.value) + evidence = jax.device_put(init_msgs.evidence.value) wiring = jax.device_put(self.wiring) - evidence = jax.device_put(self.evidence) factor_configs_log_potentials = jax.device_put( self.factor_configs_log_potentials ) @@ -314,31 +306,333 @@ def message_passing_step(msgs, _): return msgs, None msgs_after_bp, _ = jax.lax.scan(message_passing_step, msgs, None, num_iters) + return Messages( + ftov=FToVMessages(factor_graph=self, init_value=msgs_after_bp), + evidence=init_msgs.evidence, + ) - return msgs_after_bp - - def decode_map_states(self, msgs: jnp.ndarray) -> Dict[Tuple[Any, ...], int]: + def decode_map_states(self, msgs: Messages) -> Dict[Tuple[Any, ...], int]: """Function to computes the output of MAP inference on input messages. The final states are computed based on evidence obtained from the self.get_evidence method as well as the internal wiring. Args: - msgs: an array of shape (num_edge_state,) that correspond to messages to perform inference - upon + msgs: ftov messages for deciding MAP states Returns: a dictionary mapping each variable key to the MAP states of the corresponding variable """ var_states_for_edges = jax.device_put(self.wiring.var_states_for_edges) - evidence = jax.device_put(self.evidence) - final_var_states = evidence.at[var_states_for_edges].add(msgs) + evidence = jax.device_put(msgs.evidence.value) + final_var_states = evidence.at[var_states_for_edges].add(msgs.ftov.value) var_key_to_map_dict: Dict[Tuple[Any, ...], int] = {} final_var_states_np = np.array(final_var_states) - for var_key in self._composite_variable_group.keys: - var = self._composite_variable_group[var_key] + for var_key in self._variable_group.keys: + var = self._variable_group[var_key] start_index = self._vars_to_starts[var] var_key_to_map_dict[var_key] = np.argmax( final_var_states_np[start_index : start_index + var.num_states] ) return var_key_to_map_dict + + +@dataclass +class FToVMessages: + """Class for storing and manipulating factor to variable messages. + + Args: + factor_graph: associated factor graph + default_mode: default mode for initializing ftov messages. + Allowed values include "zeros" and "random" + If init_value is None, defaults to "zeros" + init_value: Optionally specify initial value for ftov messages + + Attributes: + _message_updates: Dict[int, jnp.ndarray]. A dictionary containing + the message updates to make on top of initial message values. + Maps starting indices to the message values to update with. + """ + + factor_graph: FactorGraph + default_mode: Optional[str] = None + init_value: Optional[Union[np.ndarray, jnp.ndarray]] = None + + def __post_init__(self): + self._message_updates: Dict[int, jnp.ndarray] = {} + if self.default_mode is not None and self.init_value is not None: + raise ValueError("Should specify only one of default_mode and init_value.") + + if self.default_mode is None and self.init_value is None: + self.default_mode = "zeros" + + if self.init_value is None: + if self.default_mode == "zeros": + self.init_value = np.zeros(self.factor_graph._total_factor_num_states) + elif self.default_mode == "random": + self.init_value = np.random.gumbel( + size=(self.factor_graph._total_factor_num_states,) + ) + else: + raise ValueError( + f"Unsupported default message mode {self.default_mode}. " + "Supported default modes are zeros or random" + ) + + def __getitem__(self, keys: Tuple[Any, Any]) -> jnp.ndarray: + """Function to query messages from a factor to a variable + + Args: + keys: a tuple of length 2, with keys[0] being the key for + factor, and keys[1] being the key for variable + + Returns: + An array containing the current ftov messages from factor + keys[0] to variable keys[1] + """ + if not ( + isinstance(keys, tuple) + and len(keys) == 2 + and keys[1] in self.factor_graph._variable_group.keys + ): + raise ValueError( + f"Invalid keys {keys}. Please specify a tuple of factor, variable " + "keys to get the messages from a named factor to a variable" + ) + + factor, start = self.factor_graph.get_factor(keys[0]) + if start in self._message_updates: + msgs = self._message_updates[start] + else: + variable = self.factor_graph._variable_group[keys[1]] + msgs = jax.device_put(self.init_value)[start : start + variable.num_states] + + return jax.device_put(msgs) + + @typing.overload + def __setitem__( + self, + keys: Tuple[Any, Any], + data: Union[np.ndarray, jnp.ndarray], + ) -> None: + """Setting messages from a factor to a variable + + Args: + keys: A tuple of length 2 + keys[0] is the key of the factor + keys[1] is the key of the variable + data: An array containing messages from factor keys[0] + to variable keys[1] + """ + + @typing.overload + def __setitem__( + self, + keys: Any, + data: Union[np.ndarray, jnp.ndarray], + ) -> None: + """Spreading beliefs at a variable to all connected factors + + Args: + keys: The key of the variable + data: An array containing the beliefs to be spread uniformly + across all factor to variable messages involving this + variable. + """ + + def __setitem__(self, keys, data) -> None: + if ( + isinstance(keys, tuple) + and len(keys) == 2 + and keys[1] in self.factor_graph._variable_group.keys + ): + factor, start = self.factor_graph.get_factor(keys[0]) + variable = self.factor_graph._variable_group[keys[1]] + if data.shape != (variable.num_states,): + raise ValueError( + f"Given message shape {data.shape} does not match expected " + f"shape f{(variable.num_states,)} from factor {keys[0]} " + f"to variable {keys[1]}." + ) + + self._message_updates[ + start + + np.sum(factor.edges_num_states[: factor.variables.index(variable)]) + ] = data + elif keys in self.factor_graph._variable_group.keys: + variable = self.factor_graph._variable_group[keys] + if data.shape != (variable.num_states,): + raise ValueError( + f"Given belief shape {data.shape} does not match expected " + f"shape f{(variable.num_states,)} for variable {keys}." + ) + + starts = np.nonzero( + self.factor_graph.wiring.var_states_for_edges + == self.factor_graph._vars_to_starts[variable] + )[0] + for start in starts: + self._message_updates[start] = data / starts.shape[0] + else: + raise ValueError( + "Invalid keys for setting messages. " + "Supported keys include a tuple of length 2 with factor " + "and variable keys for directly setting factor to variable " + "messages, or a valid variable key for spreading expected " + "beliefs at a variable" + ) + + @property + def value(self) -> jnp.ndarray: + """Functin to get the current flat message array + + Returns: + The flat message array after initializing (according to default_mode + or init_value) and applying all message updates. + """ + init_value = jax.device_put(self.init_value) + if not init_value.shape == (self.factor_graph._total_factor_num_states,): + raise ValueError( + f"Expected messages shape {(self.factor_graph._total_factor_num_states,)}. " + f"Got {init_value.shape}." + ) + + msgs = init_value + for start in self._message_updates: + data = self._message_updates[start] + msgs = msgs.at[start : start + data.shape[0]].set(data) + + return msgs + + +@dataclass +class Evidence: + """Class for storing and manipulating evidence + + Args: + factor_graph: associated factor graph + default_mode: default mode for initializing evidence. + Allowed values include "zeros" and "random" + If init_value is None, defaults to "zeros" + init_value: Optionally specify initial value for evidence + + Attributes: + _evidence_updates: Dict[nodes.Variable, np.ndarray]. maps every variable to an np.ndarray + representing the evidence for that variable + """ + + factor_graph: FactorGraph + default_mode: Optional[str] = None + init_value: Optional[Union[np.ndarray, jnp.ndarray]] = None + + def __post_init__(self): + self._evidence_updates: Dict[ + nodes.Variable, Union[np.ndarray, jnp.ndarray] + ] = {} + if self.default_mode is not None and self.init_value is not None: + raise ValueError("Should specify only one of default_mode and init_value.") + + if self.default_mode is None and self.init_value is None: + self.default_mode = "zeros" + + if self.init_value is None and self.default_mode not in ("zeros", "random"): + raise ValueError( + f"Unsupported default evidence mode {self.default_mode}. " + "Supported default modes are zeros or random" + ) + + if self.init_value is None: + if self.default_mode == "zeros": + self.init_value = jnp.zeros(self.factor_graph.num_var_states) + else: + self.init_value = jax.device_put( + np.random.gumbel(size=(self.factor_graph.num_var_states,)) + ) + + def __getitem__(self, key: Any) -> jnp.ndarray: + """Function to query evidence for a variable + + Args: + key: key for the variable + + Returns: + evidence for the queried variable + """ + variable = self.factor_graph._variable_group[key] + if self.factor_graph._variable_group[key] in self._evidence_updates: + evidence = jax.device_put(self._evidence_updates[variable]) + else: + start = self.factor_graph._vars_to_starts[variable] + evidence = jax.device_put(self.init_value)[ + start : start + variable.num_states + ] + + return evidence + + def __setitem__( + self, + key: Any, + evidence: Union[Dict[Hashable, np.ndarray], np.ndarray], + ) -> None: + """Function to update the evidence for variables + + Args: + key: tuple that represents the index into the VariableGroup + (self.factor_graph._variable_group) that is created when the FactorGraph is instantiated. Note that + this can be an index referring to an entire VariableGroup (in which case, the evidence + is set for the entire VariableGroup at once), or to an individual Variable within the + VariableGroup. + evidence: a container for np.ndarrays representing the evidence + Currently supported containers are: + - an np.ndarray: if key indexes an NDVariableArray, then evidence_values + can simply be an np.ndarray with num_var_array_dims + 1 dimensions where + num_var_array_dims is the number of dimensions of the NDVariableArray, and the + +1 represents a dimension (that should be the final dimension) for the evidence. + Note that the size of the final dimension should be the same as + variable_group.variable_size. if key indexes a particular variable, then this array + must be of the same size as variable.num_states + - a dictionary: if key indexes a GenericVariableGroup, then evidence_values + must be a dictionary mapping keys of variable_group to np.ndarrays of evidence values. + Note that each np.ndarray in the dictionary values must have the same size as + variable_group.variable_size. + """ + if key in self.factor_graph._variable_group.container_keys: + if key == slice(None): + variable_group = self.factor_graph._variable_group + else: + variable_group = ( + self.factor_graph._variable_group.variable_group_container[key] + ) + + self._evidence_updates.update(variable_group.get_vars_to_evidence(evidence)) + else: + self._evidence_updates[self.factor_graph._variable_group[key]] = evidence + + @property + def value(self) -> jnp.ndarray: + """Function to generate evidence array + + Returns: + Array of shape (num_var_states,) representing the flattened evidence for each variable + """ + evidence = jax.device_put(self.init_value) + for var, evidence_val in self._evidence_updates.items(): + start_index = self.factor_graph._vars_to_starts[var] + evidence = evidence.at[start_index : start_index + var.num_states].set( + evidence_val + ) + + return evidence + + +@dataclass +class Messages: + """Container class for factor to variable messages and evidence. + + Args: + ftov: factor to variable messages + evidence: evidence + """ + + ftov: FToVMessages + evidence: Evidence diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 2b27d480..d90835e2 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -23,12 +23,12 @@ class VariableGroup: _keys_to_vars: A private, immutable mapping from keys to variables """ - _keys_to_vars: Mapping[Any, nodes.Variable] = field(init=False) + _keys_to_vars: Mapping[Hashable, nodes.Variable] = field(init=False) def __post_init__(self) -> None: """Initialize a private, immutable mapping from keys to variables.""" object.__setattr__( - self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars()) + self, "_keys_to_vars", MappingProxyType(self._get_keys_to_vars()) ) @typing.overload @@ -71,7 +71,7 @@ def __getitem__(self, key): else: return vars_list[0] - def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: + def _get_keys_to_vars(self) -> Dict[Any, nodes.Variable]: """Function that generates a dictionary mapping keys to variables. Returns: @@ -91,7 +91,7 @@ def get_vars_to_evidence(self, evidence: Any) -> Dict[nodes.Variable, np.ndarray "Please subclass the VariableGroup class and override this method" ) - @property + @cached_property def keys(self) -> Tuple[Any, ...]: """Function to return a tuple of all keys in the group. @@ -100,7 +100,7 @@ def keys(self) -> Tuple[Any, ...]: """ return tuple(self._keys_to_vars.keys()) - @property + @cached_property def variables(self) -> Tuple[nodes.Variable, ...]: """Function to return a tuple of all variables in the group. @@ -109,6 +109,13 @@ def variables(self) -> Tuple[nodes.Variable, ...]: """ return tuple(self._keys_to_vars.values()) + @cached_property + def container_keys(self) -> Tuple: + """Placeholder function. Returns a tuple containing slice(None) for all variable groups + other than a composite variable group + """ + return (slice(None),) + @dataclass(frozen=True, eq=False) class CompositeVariableGroup(VariableGroup): @@ -130,12 +137,12 @@ class CompositeVariableGroup(VariableGroup): """ variable_group_container: Union[ - Mapping[Any, VariableGroup], Sequence[VariableGroup] + Mapping[Hashable, VariableGroup], Sequence[VariableGroup] ] def __post_init__(self): object.__setattr__( - self, "_keys_to_vars", MappingProxyType(self._set_keys_to_vars()) + self, "_keys_to_vars", MappingProxyType(self._get_keys_to_vars()) ) @typing.overload @@ -180,13 +187,13 @@ def __getitem__(self, key): else: return vars_list[0] - def _set_keys_to_vars(self) -> Dict[Any, nodes.Variable]: + def _get_keys_to_vars(self) -> Dict[Hashable, nodes.Variable]: """Function that generates a dictionary mapping keys to variables. Returns: a dictionary mapping all possible keys to different variables. """ - keys_to_vars = {} + keys_to_vars: Dict[Hashable, nodes.Variable] = {} for container_key in self.container_keys: for variable_group_key in self.variable_group_container[container_key].keys: if isinstance(variable_group_key, tuple): @@ -250,7 +257,7 @@ class NDVariableArray(VariableGroup): variable_size: int shape: Tuple[int, ...] - def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: + def _get_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: """Function that generates a dictionary mapping keys to variables. Returns: @@ -302,7 +309,7 @@ class GenericVariableGroup(VariableGroup): variable_size: int key_tuple: Tuple[Any, ...] - def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: + def _get_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: """Function that generates a dictionary mapping keys to variables. Returns: @@ -314,7 +321,7 @@ def _set_keys_to_vars(self) -> Dict[Tuple[int, ...], nodes.Variable]: return keys_to_vars def get_vars_to_evidence( - self, evidence: Mapping[Any, np.ndarray] + self, evidence: Mapping[Hashable, np.ndarray] ) -> Dict[nodes.Variable, np.ndarray]: """Function that turns input evidence into a dictionary mapping variables to evidence. @@ -356,25 +363,38 @@ class FactorGroup: variable_group: either a VariableGroup or - if the elements of more than one VariableGroup are connected to this FactorGroup - then a CompositeVariableGroup. This holds all the variables that are connected to this FactorGroup - connected_var_keys: A list of list of tuples, where each innermost tuple contains a - key into variable_group. Each list within the outer list is taken to contain the keys of variables - neighboring a particular factor to be added. + + Attributes: + _keys_to_factors: maps factor keys to the corresponding factors Raises: ValueError: if connected_var_keys is an empty list """ variable_group: Union[CompositeVariableGroup, VariableGroup] - connected_var_keys: List[List[Tuple[Any, ...]]] + _keys_to_factors: Mapping[Hashable, nodes.EnumerationFactor] = field(init=False) def __post_init__(self) -> None: """Initializes a tuple of all the factors contained within this FactorGroup.""" - if len(self.connected_var_keys) == 0: - raise ValueError("self.connected_var_keys is empty") + object.__setattr__( + self, "_keys_to_factors", MappingProxyType(self._get_keys_to_factors()) + ) - @cached_property - def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: - raise NotImplementedError("Needs to be overriden by subclass") + def __getitem__(self, key: Hashable) -> nodes.EnumerationFactor: + """Function to query individual factors in the factor group + + Args: + key: a key used to query an individual factor in the factor group + + Returns: + A queried individual factor + """ + if key not in self.keys: + raise ValueError( + f"The queried factor {key} is not present in the factor group" + ) + + return self._keys_to_factors[key] def compile_wiring( self, vars_to_starts: Mapping[nodes.Variable, int] @@ -405,6 +425,34 @@ def factor_group_log_potentials(self) -> np.ndarray: [factor.factor_configs_log_potentials for factor in self.factors] ) + def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: + """Function that generates a dictionary mapping keys to factors. + + Returns: + a dictionary mapping all possible keys to different factors. + """ + raise NotImplementedError( + "Please subclass the VariableGroup class and override this method" + ) + + @cached_property + def keys(self) -> Tuple[Hashable, ...]: + """Returns all keys in the factor group.""" + return tuple(self._keys_to_factors.keys()) + + @cached_property + def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: + """Returns all factors in the factor group.""" + return tuple(self._keys_to_factors.values()) + + @cached_property + def factor_num_states(self) -> np.ndarray: + """Returns the list of total number of edge states for factors in the factor group.""" + factor_num_states = np.array( + [np.sum(factor.edges_num_states) for factor in self.factors], dtype=int + ) + return factor_num_states + @dataclass(frozen=True, eq=False) class EnumerationFactorGroup(FactorGroup): @@ -415,24 +463,30 @@ class EnumerationFactorGroup(FactorGroup): uniform 0 unless the inheriting class includes a factor_configs_log_potentials argument. Args: + connected_var_keys: A list of list of tuples, where each innermost tuple contains a + key into variable_group. Each list within the outer list is taken to contain the keys of variables + neighboring a particular factor to be added. factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations factor_configs_log_potentials: Optional array of shape (num_val_configs,). If specified, it contains the log of the potential value for every possible configuration. If none, it is assumed the log potential is uniform 0 and such an array is automatically initialized. - - Attributes: - factors: a tuple of all the factors belonging to this group. These are constructed - internally by invoking the _get_connected_var_keys_for_factors method. """ + connected_var_keys: Union[ + Sequence[List[Tuple[Hashable, ...]]], + Mapping[Any, List[Tuple[Hashable, ...]]], + ] factor_configs: np.ndarray factor_configs_log_potentials: Optional[np.ndarray] = None - @cached_property - def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: - """Returns a tuple of all the factors contained within this FactorGroup.""" + def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: + """Function that generates a dictionary mapping keys to factors. + + Returns: + a dictionary mapping all possible keys to different factors. + """ if self.factor_configs_log_potentials is None: factor_configs_log_potentials = np.zeros( self.factor_configs.shape[0], dtype=float @@ -440,16 +494,26 @@ def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: else: factor_configs_log_potentials = self.factor_configs_log_potentials - return tuple( - [ - nodes.EnumerationFactor( - tuple(self.variable_group[keys_list]), + if isinstance(self.connected_var_keys, Sequence): + keys_to_factors: Dict[Hashable, nodes.EnumerationFactor] = { + frozenset(self.connected_var_keys[ii]): nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[ii]]), self.factor_configs, factor_configs_log_potentials, ) - for keys_list in self.connected_var_keys - ] - ) + for ii in range(len(self.connected_var_keys)) + } + else: + keys_to_factors = { + key: nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[key]]), + self.factor_configs, + factor_configs_log_potentials, + ) + for key in self.connected_var_keys + } + + return keys_to_factors @dataclass(frozen=True, eq=False) @@ -463,27 +527,37 @@ class PairwiseFactorGroup(FactorGroup): one CompositeVariableGroup. Args: + connected_var_keys: A list of list of tuples, where each innermost tuple contains a + key into variable_group. Each list within the outer list is taken to contain the keys of variables + neighboring a particular factor to be added. log_potential_matrix: array of shape (var1.variable_size, var2.variable_size), where var1 and var2 are the 2 VariableGroups (that may refer to the same VariableGroup) whose keys are present in each sub-list from self.connected_var_keys. - - Attributes: - factors: a tuple of all the factors belonging to this group. These are constructed - internally using self.connected_var_keys - factor_configs_log_potentials: array of shape (num_val_configs,), where - num_val_configs = var1.variable_size* var2.variable_size. This flattened array - contains the log of the potential value for every possible configuration. - - Raises: - ValueError: if every sub-list within self.connected_var_keys has len != 2, or if the shape of the - log_potential_matrix is not the same as the variable sizes for each variable referenced in - each sub-list of self.connected_var_keys """ + connected_var_keys: Union[ + Sequence[List[Tuple[Hashable, ...]]], + Mapping[Any, List[Tuple[Hashable, ...]]], + ] log_potential_matrix: np.ndarray - def __post_init__(self) -> None: - for fac_list in self.connected_var_keys: + def _get_keys_to_factors(self) -> Dict[Hashable, nodes.EnumerationFactor]: + """Function that generates a dictionary mapping keys to factors. + + Returns: + a dictionary mapping all possible keys to different factors. + + Raises: + ValueError: if every sub-list within self.connected_var_keys has len != 2, or if the shape of the + log_potential_matrix is not the same as the variable sizes for each variable referenced in + each sub-list of self.connected_var_keys + """ + if isinstance(self.connected_var_keys, Sequence): + connected_var_keys = self.connected_var_keys + else: + connected_var_keys = tuple(self.connected_var_keys.values()) + + for fac_list in connected_var_keys: if len(fac_list) != 2: raise ValueError( "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to" @@ -502,9 +576,6 @@ def __post_init__(self) -> None: + f"based on self.connected_var_keys. Instead, it has shape {self.log_potential_matrix.shape}" ) - @cached_property - def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: - """Returns a tuple of all the factors contained within this FactorGroup.""" factor_configs = np.array( np.meshgrid( np.arange(self.log_potential_matrix.shape[0]), @@ -514,13 +585,23 @@ def factors(self) -> Tuple[nodes.EnumerationFactor, ...]: factor_configs_log_potentials = self.log_potential_matrix[ factor_configs[:, 0], factor_configs[:, 1] ] - return tuple( - [ - nodes.EnumerationFactor( - tuple(self.variable_group[keys_list]), + if isinstance(self.connected_var_keys, Sequence): + keys_to_factors: Dict[Hashable, nodes.EnumerationFactor] = { + frozenset(self.connected_var_keys[ii]): nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[ii]]), factor_configs, factor_configs_log_potentials, ) - for keys_list in self.connected_var_keys - ] - ) + for ii in range(len(self.connected_var_keys)) + } + else: + keys_to_factors = { + key: nodes.EnumerationFactor( + tuple(self.variable_group[self.connected_var_keys[key]]), + factor_configs, + factor_configs_log_potentials, + ) + for key in self.connected_var_keys + } + + return keys_to_factors diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index f64673ea..2f77395f 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -120,10 +120,10 @@ def edges_num_states(self) -> np.ndarray: Array of shape (num_edges,) Number of states for the variables connected to each edge """ - edge_num_states = np.array( + edges_num_states = np.array( [variable.num_states for variable in self.variables], dtype=int ) - return edge_num_states + return edges_num_states @utils.cached_property def factor_configs_edge_states(self) -> np.ndarray: diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 581c583b..3707639c 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -1,7 +1,55 @@ +import numpy as np +import pytest + from pgmax.fg import graph, groups def test_onevar_graph(): - v_group = groups.GenericVariableGroup(15, tuple([0])) + v_group = groups.GenericVariableGroup(15, (0,)) fg = graph.FactorGraph(v_group) - assert fg._composite_variable_group[0, 0].num_states == 15 + assert fg._variable_group[0].num_states == 15 + with pytest.raises(ValueError) as verror: + graph.FToVMessages( + factor_graph=fg, default_mode="zeros", init_value=np.zeros(1) + ) + + assert "Should specify only" in str(verror.value) + with pytest.raises(ValueError) as verror: + graph.FToVMessages(factor_graph=fg, default_mode="test") + + assert "Unsupported default message mode" in str(verror.value) + with pytest.raises(ValueError) as verror: + graph.Evidence(factor_graph=fg, default_mode="zeros", init_value=np.zeros(1)) + + assert "Should specify only" in str(verror.value) + with pytest.raises(ValueError) as verror: + graph.Evidence(factor_graph=fg, default_mode="test") + + assert "Unsupported default evidence mode" in str(verror.value) + fg.add_factor([0], np.arange(15)[:, None], name="test") + with pytest.raises(ValueError) as verror: + fg.add_factor([0], np.arange(15)[:, None], name="test") + + assert "A factor group with the name" in str(verror.value) + init_msgs = fg.get_init_msgs() + init_msgs.evidence[:] = {0: np.ones(15)} + with pytest.raises(ValueError) as verror: + init_msgs.ftov["test", 1] + + assert "Invalid keys" in str(verror.value) + with pytest.raises(ValueError) as verror: + init_msgs.ftov["test", 0] = np.zeros(1) + + assert "Given message shape" in str(verror.value) + with pytest.raises(ValueError) as verror: + init_msgs.ftov[0] = np.zeros(1) + + assert "Given belief shape" in str(verror.value) + with pytest.raises(ValueError) as verror: + init_msgs.ftov[1] = np.zeros(1) + + assert "Invalid keys for setting messages" in str(verror.value) + with pytest.raises(ValueError) as verror: + graph.FToVMessages(factor_graph=fg, init_value=np.zeros(1)).value + + assert "Expected messages shape" in str(verror.value) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 16133b40..63e07d28 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -20,6 +20,7 @@ def test_composite_vargroup_valueerror(): def test_composite_vargroup_evidence(): v_group1 = groups.GenericVariableGroup(3, tuple([0, 1, 2])) + v_group1.container_keys v_group2 = groups.GenericVariableGroup(3, tuple([0, 1, 2])) comp_var_group = groups.CompositeVariableGroup(tuple([v_group1, v_group2])) vars_to_evidence = comp_var_group.get_vars_to_evidence( @@ -39,13 +40,6 @@ def test_ndvararray_evidence_error(): assert "Input evidence" in str(verror.value) -def test_facgroup_errors(): - v_group = groups.NDVariableArray(3, (2, 2)) - with pytest.raises(ValueError) as verror: - groups.FactorGroup(v_group, []) - assert "self.connected_var_keys is empty" == str(verror.value) - - def test_pairwisefacgroup_errors(): v_group = groups.NDVariableArray(3, (2, 2)) with pytest.raises(ValueError) as verror: @@ -59,6 +53,12 @@ def test_pairwisefacgroup_errors(): v_group, [[(0, 0), (1, 1)]], np.zeros((1,), dtype=float) ) assert "self.log_potential_matrix must" in str(verror.value) + factor_group = groups.PairwiseFactorGroup( + v_group, {0: [(0, 0), (1, 1)]}, np.zeros((3, 3), dtype=float) + ) + with pytest.raises(ValueError) as verror: + factor_group[1] + assert "The queried factor" in str(verror.value) def test_generic_evidence_errors(): diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 55fe5b0b..30b42b87 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -3,6 +3,7 @@ import jax import jax.numpy as jnp import numpy as np +import pytest from numpy.random import default_rng from scipy.ndimage import gaussian_filter @@ -255,13 +256,7 @@ def create_valid_suppression_config_arr(suppression_diameter): pass # Create the factor graph - fg = graph.FactorGraph( - variable_groups=composite_grid_group, - ) - - # Set the evidence - fg.set_evidence("grid_vars", grid_evidence_arr) - fg.set_evidence("additional_vars", additional_vars_evidence_dict) + fg = graph.FactorGraph(variables=composite_grid_group) # Imperatively add EnumerationFactorGroups (each consisting of just one EnumerationFactor) to # the graph! @@ -298,20 +293,24 @@ def create_valid_suppression_config_arr(suppression_diameter): ("additional_vars", 1, row + 1, col), ] if row % 2 == 0: - fg.add_factors( + fg.add_factor( curr_keys, valid_configs_non_supp, np.zeros(valid_configs_non_supp.shape[0], dtype=float), + name=(row, col), ) else: - fg.add_factors( + fg.add_factor( keys=curr_keys, factor_configs=valid_configs_non_supp, factor_configs_log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float ), + name=(row, col), ) + assert fg.get_factor((0, 0))[1] == 0 + # Create an EnumerationFactorGroup for vertical suppression factors vert_suppression_keys: List[List[Tuple[Any, ...]]] = [] for col in range(N): @@ -350,12 +349,14 @@ def create_valid_suppression_config_arr(suppression_diameter): ) # Add the suppression factors to the graph via kwargs - fg.add_factors( + fg.add_factor( factor_factory=groups.EnumerationFactorGroup, - connected_var_keys=vert_suppression_keys, + connected_var_keys={ + idx: keys for idx, keys in enumerate(vert_suppression_keys) + }, factor_configs=valid_configs_supp, ) - fg.add_factors( + fg.add_factor( factor_factory=groups.EnumerationFactorGroup, connected_var_keys=horz_suppression_keys, factor_configs=valid_configs_supp, @@ -365,11 +366,16 @@ def create_valid_suppression_config_arr(suppression_diameter): ) # Run BP - one_step_msgs = fg.run_bp(1, 0.5) + # Set the evidence + init_msgs = fg.get_init_msgs() + init_msgs.evidence["grid_vars"] = grid_evidence_arr + init_msgs.evidence["additional_vars"] = additional_vars_evidence_dict + fg.run_bp(1, 0.5) + one_step_msgs = fg.run_bp(1, 0.5, init_msgs=init_msgs) final_msgs = fg.run_bp(99, 0.5, one_step_msgs) # Test that the output messages are close to the true messages - assert jnp.allclose(final_msgs, true_final_msgs_output, atol=1e-06) + assert jnp.allclose(final_msgs.ftov.value, true_final_msgs_output, atol=1e-06) assert fg.decode_map_states(final_msgs) == true_map_state_output @@ -385,12 +391,7 @@ def test_e2e_heretic(): bXn = np.zeros((30, 30, 3)) # Create the factor graph - fg = graph.FactorGraph((pixel_vars, hidden_vars)) - - # Assign evidence to pixel vars - fg.set_evidence(0, np.array(bXn)) - fg.set_evidence(tuple([0, 0, 0]), np.array([0.0, 0.0, 0.0])) - fg.evidence_default_mode = "random" + fg = graph.FactorGraph((pixel_vars, hidden_vars), evidence_default_mode="random") def binary_connected_variables( num_hidden_rows, num_hidden_cols, kernel_row, kernel_col @@ -409,12 +410,47 @@ def binary_connected_variables( W_pot = np.zeros((17, 3, 3, 3), dtype=float) for k_row in range(3): for k_col in range(3): - fg.add_factors( + fg.add_factor( factor_factory=groups.PairwiseFactorGroup, connected_var_keys=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], + name=(k_row, k_col), ) - assert isinstance(fg.evidence, np.ndarray) - + # Assign evidence to pixel vars + init_msgs = fg.get_init_msgs() + init_msgs.evidence[0] = np.array(bXn) + init_msgs.evidence[0, 0, 0] = np.array([0.0, 0.0, 0.0]) + init_msgs.evidence[0, 0, 0] + init_msgs.evidence[1, 0, 0] + with pytest.raises(ValueError) as verror: + fg.get_factor((0, 0)) + + assert "Invalid factor key" in str(verror.value) + with pytest.raises(ValueError) as verror: + fg.get_factor((((0, 0), 0), (10, 20, 30))) + + assert "Invalid factor key" in str(verror.value) + assert isinstance(init_msgs.evidence.value, jnp.ndarray) assert len(fg.factors) == 7056 + evidence = graph.Evidence(factor_graph=fg) + for ftov_msgs in [ + graph.FToVMessages(factor_graph=fg), + graph.FToVMessages(factor_graph=fg, default_mode="random"), + ]: + ftov_msgs[((0, 0), frozenset([(1, 0, 0), (0, 0, 0)])), (0, 0, 0)] + ftov_msgs[((1, 1), frozenset([(1, 0, 0), (0, 1, 1)])), (1, 0, 0)] = np.ones(17) + assert np.all( + ftov_msgs[((1, 1), frozenset([(1, 0, 0), (0, 1, 1)])), (1, 0, 0)] == 1.0 + ) + ftov_msgs[1, 0, 0] = np.ones(17) + assert np.all( + ftov_msgs[((0, 0), frozenset([(1, 0, 0), (0, 0, 0)])), (1, 0, 0)] == 1.0 / 9 + ) + assert np.all( + ftov_msgs[((1, 1), frozenset([(1, 0, 0), (0, 1, 1)])), (1, 0, 0)] == 1.0 / 9 + ) + msgs = fg.run_bp( + 1, 0.5, init_msgs=graph.Messages(ftov=ftov_msgs, evidence=evidence) + ) + msgs.ftov[((0, 0), frozenset([(1, 0, 0), (0, 0, 0)])), (0, 0, 0)]