Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.

Commit

Permalink
Refactoring factor groups; Support setting initial messages (#60)
Browse files Browse the repository at this point in the history
* Refactor

* Fix tests

* Checkpoint

* Outline for message manipulation

* Set messages given a pair of factor and variable

* Implement beliefs spreading

* Implement getitem

* Implement value

* Refactor

* Fix examples and tests

* Update coverage reporting

* Tests

* Tests

* Test message manipulation

* Improve coverage

* Full coverage

* Docstrings

* Simplify behavior for single variable group

* Change default naming convention

* Add ising model example; Fix crash

* Standalone evidence

* Update examples

* Support getitem for evidence

* Docstrings

* Fix crash; full coverage

* Update ising model example notebook

* Check duplicate names

* Support setting evidence for single factor group

* Comments
  • Loading branch information
StannisZhou authored Sep 6, 2021
1 parent 975750f commit bf874f0
Show file tree
Hide file tree
Showing 10 changed files with 801 additions and 255 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
#----------------------------------------------
- name: Test with coverage
run: |
poetry run pytest --cov=pgmax --cov-report=xml
poetry run pytest --cov=pgmax --cov-report=xml --cov-report=term-missing:skip-covered
#----------------------------------------------
# upload coverage report to codecov
#----------------------------------------------
Expand Down
28 changes: 16 additions & 12 deletions examples/heretic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
# format_version: '1.3'
# jupytext_version: 1.11.4
# kernelspec:
# display_name: 'Python 3.7.11 64-bit (''pgmax-zIh0MZVc-py3.7'': venv)'
# name: python371164bitpgmaxzih0mzvcpy37venve540bb1b5cdf4292a3f5a12c4904cc40
# display_name: Python 3
# language: python
# name: python3
# ---

# %%
from timeit import default_timer as timer
from typing import Any, List, Tuple

Expand Down Expand Up @@ -91,14 +93,10 @@
# Create the factor graph
fg = graph.FactorGraph((pixel_vars, hidden_vars))

# Assign evidence to pixel vars
fg.set_evidence(0, np.array(bXn_evidence))
fg.set_evidence(1, np.array(bHn_evidence))


# %% [markdown]
# # Add all Factors to graph via constructing FactorGroups


# %% tags=[]
def binary_connected_variables(
num_hidden_rows, num_hidden_cols, kernel_row, kernel_col
Expand All @@ -118,7 +116,7 @@ def binary_connected_variables(
W_pot = W_orig.swapaxes(0, 1)
for k_row in range(3):
for k_col in range(3):
fg.add_factors(
fg.add_factor(
factor_factory=groups.PairwiseFactorGroup,
connected_var_keys=binary_connected_variables(28, 28, k_row, k_col),
log_potential_matrix=W_pot[:, :, k_row, k_col],
Expand Down Expand Up @@ -174,16 +172,22 @@ def custom_flatten_ordering(Mdown, Mup):
reshaped_Mdown = Mdown.reshape(3, 3, 3, 28, 28)
reshaped_Mup = Mup.reshape(17, 3, 3, 28, 28)

init_msgs = jax.device_put(
custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup))
)

# %% [markdown]
# # Run Belief Propagation and Retrieve MAP Estimate

# %% tags=[]
# Run BP
init_msgs = fg.get_init_msgs()
init_msgs.ftov = graph.FToVMessages(
factor_graph=fg,
init_value=jax.device_put(
custom_flatten_ordering(np.array(reshaped_Mdown), np.array(reshaped_Mup))
),
)
init_msgs.evidence[0] = np.array(bXn_evidence)
init_msgs.evidence[1] = np.array(bHn_evidence)
bp_start_time = timer()
# Assign evidence to pixel vars
final_msgs = fg.run_bp(
500,
0.5,
Expand Down
87 changes: 87 additions & 0 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.4
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---

# %%
# %matplotlib inline
import matplotlib.pyplot as plt
import numpy as np

from pgmax.fg import graph, groups

# %% [markdown]
# ### Construct variable grid, initialize factor graph, and add factors

# %%
variables = groups.NDVariableArray(variable_size=2, shape=(50, 50))
fg = graph.FactorGraph(variables=variables, evidence_default_mode="random")
connected_var_keys = []
for ii in range(50):
for jj in range(50):
kk = (ii + 1) % 50
ll = (jj + 1) % 50
connected_var_keys.append([(ii, jj), (kk, jj)])
connected_var_keys.append([(ii, jj), (ii, ll)])

fg.add_factor(
factor_factory=groups.PairwiseFactorGroup,
connected_var_keys=connected_var_keys,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
name="factors",
)

# %% [markdown]
# ### Run inference and visualize results

# %%
msgs = fg.run_bp(3000, 0.5)
map_states = fg.decode_map_states(msgs)
img = np.zeros((50, 50))
for key in map_states:
img[key] = map_states[key]

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)

# %% [markdown]
# ### Message and evidence manipulation

# %%
# Query evidence for variable (0, 0)
msgs.evidence[0, 0]

# %%
# Set evidence for variable (0, 0)
msgs.evidence[0, 0] = np.array([1.0, 1.0])
msgs.evidence[0, 0]

# %%
# Set evidence for all variables using an array
evidence = np.random.randn(50, 50, 2)
msgs.evidence[:] = evidence
msgs.evidence[10, 10] == evidence[10, 10]

# %%
# Query messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]

# %%
# Set messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)] = np.array([1.0, 1.0])
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]

# %%
# Uniformly spread expected belief at a variable to all connected factors
msgs.ftov[0, 0] = np.array([1.0, 1.0])
msgs.ftov[("factors", frozenset([(0, 0), (0, 1)])), (0, 0)]
28 changes: 12 additions & 16 deletions examples/sanity_check_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
# format_version: '1.3'
# jupytext_version: 1.11.4
# kernelspec:
# display_name: 'Python 3.8.5 64-bit (''pgmax-JcKb81GE-py3.8'': poetry)'
# display_name: Python 3
# language: python
# name: python3
# ---

Expand Down Expand Up @@ -154,18 +155,12 @@

# %%
# Create the factor graph
fg = graph.FactorGraph(
variable_groups=composite_grid_group,
)

# Set the evidence
fg.set_evidence("grid_vars", grid_evidence_arr)
fg.set_evidence("additional_vars", additional_vars_evidence_dict)

fg = graph.FactorGraph(variables=composite_grid_group)

# %% [markdown]
# ### Create Valid Configuration Arrays


# %%
# Helper function to easily generate a list of valid configurations for a given suppression diameter
def create_valid_suppression_config_arr(suppression_diameter):
Expand Down Expand Up @@ -268,7 +263,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
("additional_vars", 1, row + 1, col),
]

fg.add_factors(
fg.add_factor(
curr_keys,
valid_configs_non_supp,
np.zeros(valid_configs_non_supp.shape[0], dtype=float),
Expand Down Expand Up @@ -312,21 +307,18 @@ def create_valid_suppression_config_arr(suppression_diameter):
for c in range(start_col, start_col + SUPPRESSION_DIAMETER)
]
)
horz_suppression_group = groups.EnumerationFactorGroup(
composite_grid_group, horz_suppression_keys, valid_configs_supp
)


# %% [markdown]
# ### Add FactorGroups Remaining to FactorGraph

# %%
fg.add_factors(
fg.add_factor(
factor_factory=groups.EnumerationFactorGroup,
connected_var_keys=vert_suppression_keys,
factor_configs=valid_configs_supp,
)
fg.add_factors(
fg.add_factor(
factor_factory=groups.EnumerationFactorGroup,
connected_var_keys=horz_suppression_keys,
factor_configs=valid_configs_supp,
Expand All @@ -337,8 +329,12 @@ def create_valid_suppression_config_arr(suppression_diameter):

# %%
# Run BP
# Set the evidence
init_msgs = fg.get_init_msgs()
init_msgs.evidence["grid_vars"] = grid_evidence_arr
init_msgs.evidence["additional_vars"] = additional_vars_evidence_dict
bp_start_time = timer()
final_msgs = fg.run_bp(1000, 0.5)
final_msgs = fg.run_bp(1000, 0.5, init_msgs=init_msgs)
bp_end_time = timer()
print(f"time taken for bp {bp_end_time - bp_start_time}")

Expand Down
Loading

0 comments on commit bf874f0

Please sign in to comment.