Skip to content

Commit

Permalink
updated test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
SHoltzen committed Aug 23, 2019
1 parent f310f15 commit 0409021
Show file tree
Hide file tree
Showing 7 changed files with 187 additions and 6 deletions.
1 change: 1 addition & 0 deletions .#experiment.py
1 change: 1 addition & 0 deletions .#markov.py
1 change: 1 addition & 0 deletions .#my_graphs.py
105 changes: 104 additions & 1 deletion experiment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,106 @@
### experiments from the paper
### not all of these made it into the paper, but they give a good idea of how
### the code works

from sage.all import *
import numpy.random
from collections import deque
import random
import numpy as np
import my_bliss
from my_graphs import *
import cProfile, pstats, StringIO
from test import *
import itertools
import time

# holes pigeons, m holes
def mk_pigeonhole_fg(n, m, order=True):
w1 = 10000000
w2 = 100000
(g, (variables, factors)) = gen_pigeonhole_fg(n, m)
def potential(state):
total = 0.0
# to see every pigeon in exactly one hole
for p in range(0, n):
# check the holes for the pigeons
in_hole = False
for h in range(0, m):
if state[(p, h)]:
if in_hole:
return 0.000000001
else:
in_hole = True
if in_hole:
total += w1


# check to see no no hole has 2 pigeons
for h in range(0, m):
for (p1, p2) in findsubsets(range(0, n), 2):
if not state[(p1, h)] or not state[(p2, h)]:
total += w2

return total
return FactorGraph(g, variables, factors, potential)


### computes the total variation distance comparing different sampling methods
def complete_pairwise_dtv():
model = gen_complete_pairwise_factorgraph(6)
gibbs = model.gibbs_transition()
# print(gibbs)
# print(sum(gibbs))
within_orbit = model.orbit_transition()
orbitalmcmc = np.matmul(within_orbit, gibbs)
# unif = model.uniform_transition()
burnside = model.burnside_mh_transition(4)
M = orbitalmcmc

pv = model.brute_force_prob_vector()
start = np.zeros([2**len(model.variables)])
start[10] = 1
# print(np.linalg.matrix_power(M, 5))
print("-------------------")
print("pure gibbs")
model.total_variation(gibbs, start, 100)
print("-------------------")
print("lifted MCMC")
model.total_variation(orbitalmcmc, start, 100)
print("------------------")
print("orbit jump MCMC")
model.total_variation(burnside, start, 100)

def complete_pairwise_exact():
model = gen_complete_pairwise_factorgraph(6)
print("partition: %f" % model.partition())

def pigeonhole_dtv():
model = mk_pigeonhole_fg(2,5)
gibbs = model.gibbs_transition()
# print(gibbs)
# print(sum(gibbs))
within_orbit = model.orbit_transition()
orbitalmcmc = np.matmul(within_orbit, gibbs)
# unif = model.uniform_transition()
burnside = model.burnside_mh_transition(4)
M = orbitalmcmc

pv = model.brute_force_prob_vector()
start = np.zeros([2**len(model.variables)])
start[10] = 1
# print(np.linalg.matrix_power(M, 5))
print("-------------------")
print("pure gibbs")
model.total_variation(gibbs, start, 100)
print("-------------------")
print("lifted MCMC")
model.total_variation(orbitalmcmc, start, 100)
print("------------------")
print("orbit jump MCMC")
model.total_variation(burnside, start, 100)


if __name__ == "__main__":
None
# run your test here if you want
pigeonhole_dtv()
6 changes: 3 additions & 3 deletions factor.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ def burnside_transition(self):
state_to_idx[st] = idx
idx_to_state[idx] = st

print "states_to_idx: %s" % state_to_idx
transition = np.zeros([len(states),len(states)])
for (idx, s) in enumerate(states):
var_part = self.state_to_partition(dict(s))
Expand Down Expand Up @@ -456,7 +455,8 @@ def orbitjumpmcmc(self, n, query, burnsidesize=10, gamma=10, burn=100):
return query_count / n

if __name__ == "__main__":
model = gen_complete_pairwise_factorgraph(6)
model = gen_complete_pairwise_factorgraph_half(6)
print(model)
gibbs = model.gibbs_transition()
# print(gibbs)
# print(sum(gibbs))
Expand All @@ -472,7 +472,7 @@ def orbitjumpmcmc(self, n, query, burnsidesize=10, gamma=10, burn=100):
# print(np.linalg.matrix_power(M, 5))
print("-------------------")
print("pure gibbs")
model.total_variation(gibbs, start, 100)
model.total_variation(orbitalmcmc, start, 100)
print("------------------")
print("pure jump")
model.total_variation(burnside, start, 100)
Expand Down
37 changes: 37 additions & 0 deletions my_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,43 @@ def gen_pigeonhole(n,m):
return g


def gen_pigeonhole_fg(n,m):
# n = holes, m = pigeons
g = Graph()
v = []
e = []
pigeon_factors = []
hole_factors = []
# generate vertices
for x in range(0,n):
for y in range(0,m):
v += [(x,y)]

count = 0
# generate edges
# generate fully connected graph in n
for x in range(0,n):
for y in findsubsets(range(0, m), 2):
count += 1
pigeon_factors.append(count)
e.append(tuple([(x, y[0]), count]))
e.append(tuple([count, (x, y[1])]))
# connect between n and m
for x in range(0,n-1):
for y in range (0,m):
count += 1
hole_factors.append(count)
e.append(tuple([(x, y), count]))
e.append(tuple([count, (((x + 1) % n), y)]))
g.add_vertices(v)
g.add_vertices(pigeon_factors)
g.add_vertices(hole_factors)
g.add_edges(e)
return (g, (v, [hole_factors, pigeon_factors]))




# generates a complete graph with n vertices with some extra nodes on each
# vertex
def gen_complete_extra(n):
Expand Down
42 changes: 40 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,43 @@
### defines a test suite for testing basic functionality
### invoke `sage test.py`

from sage.all import *
import orbitgen
from my_graphs import *
import unittest
from factor import *

# holes pigeons, m holes
def mk_pigeonhole_fg(n, m, order=True):
w1 = 10000000
w2 = 100000
(g, (variables, factors)) = gen_pigeonhole_fg(n, m)
def potential(state):
total = 0.0
# to see every pigeon in exactly one hole
for p in range(0, n):
# check the holes for the pigeons
in_hole = False
for h in range(0, m):
if state[(p, h)]:
if in_hole:
return 0.000000001
else:
in_hole = True
if in_hole:
total += w1


# check to see no no hole has 2 pigeons
for h in range(0, m):
for (p1, p2) in findsubsets(range(0, n), 2):
if not state[(p1, h)] or not state[(p2, h)]:
total += w2

return total
return FactorGraph(g, variables, factors, potential)


# generate a complete pairwise factor graph with half the factors different colors
def gen_complete_pairwise_factorgraph_half(n):
(g, (v, factors)) = gen_complete_pairwise_factor(n)
Expand Down Expand Up @@ -78,13 +110,19 @@ def test_aug_complete(self):

def test_pairwise(self):
fg = gen_complete_pairwise_factorgraph(10)
self.assertEqual(g.partition(), fg.brute_force_partition)
self.assertAlmostEqual(fg.partition(), fg.brute_force_partition())

def test_half_pairwise(self):
fg = gen_complete_pairwise_factorgraph_half(10)
self.assertEqual(g.partition(), fg.brute_force_partition)
self.assertAlmostEqual(fg.partition(), fg.brute_force_partition())

def test_pigeonhole(self):
fg = mk_pigeonhole_fg(2,4)
self.assertAlmostEqual(fg.partition(), fg.brute_force_partition())

def test_pigeonhole_2(self):
fg = mk_pigeonhole_fg(3,6)
self.assertEqual(int(fg.partition()), int(fg.brute_force_partition()))


if __name__ == '__main__':
Expand Down

0 comments on commit 0409021

Please sign in to comment.