Skip to content

Commit

Permalink
Matrice qspa is broken, still some diff in perf for zero vs normal C
Browse files Browse the repository at this point in the history
  • Loading branch information
Parvfect committed Apr 3, 2024
1 parent 1ef4eb0 commit 9a398d3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 19 deletions.
16 changes: 11 additions & 5 deletions distracted_coupon_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import row_echleon as r
import numpy as np
from itertools import combinations
import sys
from tqdm import tqdm
from protograph_interface import get_Harr_sc_ldpc, get_dv_dc
from tanner import VariableTannerGraph
Expand Down Expand Up @@ -196,7 +197,7 @@ def simulate_reads(C, symbols, read_length, P, n_motifs, n_picks):

return likelihood_arr

def decoding_errors_fer(k, n, dv, dc, P, H, G, GF, graph, C, symbols, n_motifs, n_picks, decoder=None, decoding_failures_parameter=20, max_iterations=50, iterations=50, uncoded=False, bec_decoder=False, label=None, code_class="", read_lengths=np.arange(1,20)):
def decoding_errors_fer(k, n, dv, dc, P, H, G, GF, graph, C, symbols, n_motifs, n_picks, decoder=None, decoding_failures_parameter=200, max_iterations=20000, iterations=50, uncoded=False, bec_decoder=False, label=None, code_class="", read_lengths=np.arange(1,20)):

frame_error_rate = []
max_iterations = max_iterations
Expand All @@ -207,13 +208,16 @@ def decoding_errors_fer(k, n, dv, dc, P, H, G, GF, graph, C, symbols, n_motifs,
for j in tqdm(range(max_iterations)):
symbol_likelihoods_arr = np.array(simulate_reads(C, symbols, i, P, n_motifs, n_picks))

#symbol_likelihoods_arr = [[1/67]*67]*n
#print(symbol_likelihoods_arr)

if not decoder:
z = graph.qspa_decode(symbol_likelihoods_arr, H, GF)
else:
z = decoder.decode(symbol_likelihoods_arr, max_iter=20)

#print(C)
#print(z)
print(z)

if np.array_equal(C, z):
counter += 1
Expand Down Expand Up @@ -262,7 +266,7 @@ def run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, P, code_class="", iter

plt.legend()
plt.grid()
plt.show()
#plt.show()


if __name__ == "__main__":
Expand All @@ -272,11 +276,13 @@ def run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, P, code_class="", iter
k, n = 22, 33
L, M = 12, 51
read_length = 6
read_lengths = np.arange(7, 13)
read_lengths = np.arange(9, 10)


run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, P, code_class="", uncoded=False, zero_codeword=False, bec_decoder=False, graph_decoding=False, read_lengths=read_lengths)
#run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, P, code_class="", uncoded=False, zero_codeword=False, bec_decoder=False, graph_decoding=False, read_lengths=read_lengths)
run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, P, code_class="", uncoded=False, zero_codeword=False, bec_decoder=False, graph_decoding=True, read_lengths=read_lengths)
run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, P, code_class="", uncoded=False, zero_codeword=True, bec_decoder=False, graph_decoding=True, read_lengths=read_lengths)
plt.show()
(
Stats(prof)
.strip_dirs()
Expand Down
8 changes: 6 additions & 2 deletions qspa_conv.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from collections import defaultdict
from more_itertools import distinct_permutations

import random
import galois
import numpy as np
from scipy.stats import multinomial
from math import factorial
import utils
import math

def get_max_symbol(prob_arr):
max_val = np.max(prob_arr)
max_indices = [i for i, val in enumerate(prob_arr) if val == max_val]
return random.choice(max_indices)

def multinomial_def(n, x, p):
return (factorial(n)/(np.prod([factorial(i) for i in x])))*(np.prod([p[j]**x[j] for j in range(len(x))]))
Expand Down Expand Up @@ -330,7 +334,7 @@ def decode_hard(self, P, S):
for a in range(self.GF.order):
for i in idxs:
probs[a] *= S[i, j, a]
z[j] = np.argmax(probs)
z[j] = get_max_symbol(probs)
z = self.GF(z.astype(int))
return z

Expand Down
51 changes: 39 additions & 12 deletions tanner_qspa.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@

from tanner import VariableTannerGraph, conv_circ
import numpy as np
import random

def get_max_symbol(prob_arr):
max_val = np.max(prob_arr)
max_indices = [i for i, val in enumerate(prob_arr) if val == max_val]
#print(prob_arr)
#print(max_indices)
return random.choice(max_indices)

class TannerQSPA(VariableTannerGraph):

Expand Down Expand Up @@ -47,19 +55,30 @@ def qspa_decode(self, symbol_likelihood_arr, H, GF, max_iterations=20):
#for i in range(max_iterations):
while(True):


self.cn_update()

#print()
#print(f"Iteration {iterations+1}")
#print()
max_prob_codeword = self.get_max_prob_codeword(symbol_likelihood_arr, GF)
#print(max_prob_codeword)

#print(sum(random.choice(list(self.cn_links.items()))[1]))

parity = not np.matmul(H, max_prob_codeword).any()
if parity:
print("Decoding converges")
return max_prob_codeword

self.vn_update(symbol_likelihood_arr)
#print(sum(random.choice(list(self.vn_links.items()))[1]))


if np.array_equal(max_prob_codeword, prev_max_prob_codeword) or iterations > max_iterations:
if iterations > max_iterations:
break
#if np.array_equal(max_prob_codeword, prev_max_prob_codeword) or iterations > max_iterations:
# break

prev_max_prob_codeword = max_prob_codeword

Expand Down Expand Up @@ -95,7 +114,9 @@ def get_max_prob_codeword(self, P, GF):
probs[a] *= self.cn_links[(cn, vn_index)][a]

# Most likely symbol is the Symbol with the highest probability
z[vn_index] = np.argmax(probs)
#print(probs)
z[vn_index] = get_max_symbol(probs)
#print(z)

return GF(z.astype(int))

Expand Down Expand Up @@ -124,32 +145,38 @@ def cn_update(self):

# Updating the CN Link weight with the conv value
self.cn_links[(cn_index, vn)] = pdf[self.idx_shuffle]
#print(sum(self.cn_links[(cn_index, vn)]))

def vn_update(self, P):
""" Updates the CN as per the QSPA Decoding. Conditional Probability of a Symbol being favoured yadayada """

# Use the CN links to update the VN links by taking the favoured probabilities

# Iterating through all the Symbols
for a in range(self.GF.order):

# For each VN
for vn in self.vns:
vn_index = vn.identifier

# For each VN
for vn in self.vns:
vn_index = vn.identifier

for cn in vn.links:
for a in range(self.GF.order):

for cn in vn.links:

self.vn_links[(cn, vn_index)][a] = P[vn_index][a]

for t in vn.links:

# Iterating through all the other cns besides selected
if t == cn:
continue

self.vn_links[(cn, vn_index)][a] *= self.cn_links[(t, vn_index)][a]

sum_copy_links = np.einsum('i->', self.vn_links[(cn, vn_index)])
self.vn_links[(cn, vn_index)] = self.vn_links[(cn, vn_index)]/sum_copy_links

# Normalizing
#sum_copy_links = np.einsum('i->', self.vn_links[(cn, vn_index)])
sum_copy_links = np.sum(self.vn_links[(cn, vn_index)])
self.vn_links[(cn, vn_index)] = self.vn_links[(cn, vn_index)]/sum_copy_links





0 comments on commit 9a398d3

Please sign in to comment.