Skip to content
This repository has been archived by the owner on Dec 5, 2024. It is now read-only.

Commit

Permalink
Hotfix (#144)
Browse files Browse the repository at this point in the history
* Hotfix

* Unit test + bump

* Text
  • Loading branch information
antoine-dedieu authored May 19, 2022
1 parent 1570a1f commit 7c71b14
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 187 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
author = "Guangyao Zhou, Nishanth Kumar, Antoine Dedieu, Miguel Lazaro-Gredilla, Shrinu Kushagra, Dileep George"

# The full version, including alpha/beta/rc tags
release = "0.4.0"
release = "0.4.1"


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion pgmax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""A container package for the entire PGMax library."""

__version__ = "0.4.0"
__version__ = "0.4.1"
2 changes: 1 addition & 1 deletion pgmax/vgroup/varray.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
data = flat_data.reshape(self.shape)
elif flat_data.size == self.num_states.sum():
data = jnp.full(
shape=self.shape + (self.num_states.max(),), fill_value=jnp.nan
shape=self.shape + (self.num_states.max(),), fill_value=-jnp.inf
)
data = data.at[np.arange(data.shape[-1]) < self.num_states[..., None]].set(
flat_data
Expand Down
2 changes: 1 addition & 1 deletion pgmax/vgroup/vdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __post_init__(self):
elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
self.num_states.dtype, int
):
if self.num_states.shape != len(self.variable_names):
if self.num_states.shape != (len(self.variable_names),):
raise ValueError(
f"Expected num_states shape ({len(self.variable_names)},). Got {self.num_states.shape}."
)
Expand Down
374 changes: 192 additions & 182 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pgmax"
version = "0.4.0"
version = "0.4.1"
description = "Loopy belief propagation for factor graphs on discrete variables, in JAX!"
authors = ["Stannis Zhou <[email protected]>", "Nishanth Kumar <[email protected]>", "Antoine Dedieu <[email protected]>", "Miguel Lazaro-Gredilla <[email protected]>", "Dileep George <[email protected]>"]
maintainers = ["Stannis Zhou <[email protected]>", "Nishanth Kumar <[email protected]>", "Antoine Dedieu <[email protected]>"]
Expand Down
43 changes: 43 additions & 0 deletions tests/fgraph/test_fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,46 @@ def test_bp():
bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10)))
bp_state = bp.to_bp_state(bp_arrays)
assert bp_state.fg_state == fg.fg_state


def test_bp_different_num_states():
# Build factor graph where VarDict and NDVarArray both have different number of states
num_states = np.array([2, 3, 4])
vdict = vgroup.VarDict(variable_names=tuple(["a", "b", "c"]), num_states=num_states)
varray = vgroup.NDVarArray(shape=(3,), num_states=num_states)
fg = fgraph.FactorGraph([vdict, varray])

# Add factors: we enforce the variables with same number of states to be in the same state
for var_dict, var_arr, num_state in zip(["a", "b", "c"], [0, 1, 2], num_states):
enum_factor = factor.EnumFactor(
variables=[vdict[var_dict], varray[var_arr]],
factor_configs=np.array([[idx, idx] for idx in range(num_state)]),
log_potentials=np.zeros(num_state),
)
fg.add_factors(enum_factor)

# BP functions
bp = infer.BP(fg.bp_state, temperature=0)

# Evidence for both VarDict and NDVarArray
vdict_evidence = {var: np.random.gumbel(size=(var[1],)) for var in vdict.variables}
bp_arrays = bp.init(evidence_updates=vdict_evidence)

varray_evidence = {
varray: np.random.gumbel(size=(num_states.shape[0], num_states.max()))
}
bp_arrays = bp.update(bp_arrays=bp_arrays, evidence_updates=varray_evidence)

assert np.all(bp_arrays.evidence != 0)

# Run BP
bp_arrays = bp.run_bp(bp_arrays, num_iters=50)
beliefs = bp.get_beliefs(bp_arrays)
map_states = infer.decode_map_states(beliefs)

vdict_states = map_states[vdict]
varray_states = map_states[varray]

# Verify that variables with same number of states are in the same state
for var_dict, var_arr in zip(["a", "b", "c"], [0, 1, 2]):
assert vdict_states[var_dict] == varray_states[var_arr]

0 comments on commit 7c71b14

Please sign in to comment.