Skip to content

Commit

Permalink
Major changes inbound
Browse files Browse the repository at this point in the history
  • Loading branch information
Parvfect committed Apr 9, 2024
1 parent 9a398d3 commit fad8876
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 69 deletions.
91 changes: 60 additions & 31 deletions coupon_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,30 +106,36 @@ def display_parameters(n_motifs, n_picks, dv, dc, k, n, motifs, symbols, Harr, H
print("The Codeword is \n{}\n".format(C))
return

def get_parameters(n_motifs, n_picks, dv, dc, k, n, ffdim, display=True, Harr=None, H=None, G=None):
def get_parameters(n_motifs, n_picks, dv, dc, k, n, ffdim, zero_codeword=False, display=True, Harr=None, H=None, G=None):
""" Returns the parameters for the simulation """

# Starting adresses from 1
motifs = np.arange(1, n_motifs+1)

symbols = choose_symbols(n_motifs, n_picks)

symbols.pop(-1)
symbols.pop(-2)
symbols.pop(-3)
symbols.pop()
symbols.pop()
symbols.pop()

symbol_keys = np.arange(0, ffdim)

#graph = VariableTannerGraph(dv, dc, k, n, ffdim=ffdim)
graph = TannerGraph(dv, dc, k, n, ffdim=ffdim)

if Harr is None:
Harr = r.get_H_arr(dv, dc, k, n)
H = r.get_H_Matrix(dv, dc, k, n, Harr)
G = r.parity_to_generator(H, ffdim=ffdim)

H = r.get_H_Matrix(dv, dc, k, n, Harr)
#print(H)

if zero_codeword:
G = np.zeros([k,n], dtype=int)
else:
G = r.parity_to_generator(H, ffdim=ffdim)

graph.establish_connections(Harr)



if np.any(np.dot(G, H.T) % ffdim != 0):
print("Matrices are not valid, aborting simulation")
exit()
Expand All @@ -144,9 +150,6 @@ def get_parameters(n_motifs, n_picks, dv, dc, k, n, ffdim, display=True, Harr=No
print("Codeword is not valid, aborting simulation")
exit()

if display:
display_parameters(n_motifs, n_picks, dv, dc, k, n, motifs, symbols, Harr, H, G, C, ffdim)

return graph, C, symbols, motifs

def get_parameters_sc_ldpc(n_motifs, n_picks, L, M, dv, dc, k, n, ffdim, display=True, Harr=None, H=None, G=None):
Expand All @@ -157,9 +160,9 @@ def get_parameters_sc_ldpc(n_motifs, n_picks, L, M, dv, dc, k, n, ffdim, display

symbols = choose_symbols(n_motifs, n_picks)

symbols.pop(-1)
symbols.pop(-2)
symbols.pop(-3)
symbols.pop()
symbols.pop()
symbols.pop()

symbol_keys = np.arange(0, ffdim)

Expand Down Expand Up @@ -218,7 +221,7 @@ def run_singular_decoding(graph, C, read_length, symbols, motifs, n_picks):
print("Decoding unsuccessful")
return None

def decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, decoding_failures_parameter=10, max_iterations=10000, iterations=50, uncoded=False, bec_decoder=False, label=None, code_class="", read_lengths=np.arange(1,20)):
def decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, decoding_failures_parameter=100, max_iterations=100, iterations=50, uncoded=False, masked = False, bec_decoder=False, label=None, code_class="", read_lengths=np.arange(1,20)):
""" Returns the frame error rate curve - for same H, same G, same C"""

frame_error_rate = []
Expand All @@ -228,10 +231,25 @@ def decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, decodi
for i in tqdm(read_lengths):
decoding_failures, iterations, counter = 0, 0, 0
for j in tqdm(range(max_iterations)):
# Assigning values to Variable Nodes after generating erasures in zero array
symbols_read = read_symbols(C, i, symbols, motifs, n_picks)

#print(C[:10])
if masked:
mask = [np.random.randint(ffdim) for i in range(n)]
C2 = [(C[i] + mask[i]) % ffdim for i in range(len(C))]
symbols_read = read_symbols(C2, i, symbols, motifs, n_picks)

# Unmasking
symbols_read = [[(i - mask[j]) % ffdim for i in symbols_read[j]] for j in range(len(symbols_read))]
#print(symbols_read[:10])
#exit()
else:
symbols_read = read_symbols(C, i, symbols, motifs, n_picks)
#print(symbols_read[:10])



if not uncoded:
graph.assign_values(read_symbols(C, i, symbols, motifs, n_picks))
graph.assign_values(symbols_read)
if bec_decoder:
decoded_values = graph.coupon_collector_erasure_decoder()
else:
Expand All @@ -240,6 +258,10 @@ def decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, decodi
decoded_values = symbols_read
# Getting the average error rates for iteration runs

#print(C[:10])
#print(decoded_values[:10])


# Would want to fix this ideally
if sum([len(i) for i in decoded_values]) == len(decoded_values):
if np.all(np.array(decoded_values).T[0] == C):
Expand All @@ -257,8 +279,8 @@ def decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, decodi
frame_error_rate.append(error_rate)


plt.plot(read_lengths, frame_error_rate, 'o')
plt.plot(read_lengths, frame_error_rate, label=label)
plt.plot(read_lengths, frame_error_rate, 'o', label=label)
plt.plot(read_lengths, frame_error_rate)
plt.title("Frame Error Rate for CC for {}{}-{} {}-{} for 8C4 Symbols".format(code_class, k, n, dv, dc))
plt.ylabel("Frame Error Rate")
plt.xlabel("Read Length")
Expand All @@ -272,41 +294,48 @@ def decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, decodi



def run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, code_class="", iterations=5, bec_decoder=False, uncoded=False, saved_code=False, singular_decoding=True, fer_errors=True, read_lengths=np.arange(1,20)):
def run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, code_class="", iterations=5, bec_decoder=False, uncoded=False, saved_code=False, singular_decoding=False, fer_errors=True, read_lengths=np.arange(1,20), zero_codeword=False, label="", Harr=None, masked=False):

if saved_code:
Harr, H, G = get_saved_code(dv, dc, k, n, L, M, code_class=code_class)

if code_class == "sc_":
graph, C, symbols, motifs = get_parameters_sc_ldpc(n_motifs, n_picks, L, M, dv, dc, k, n, ffdim, display=False)
else:
graph, C, symbols, motifs = get_parameters(n_motifs, n_picks, dv, dc, k, n, ffdim, display=True)
graph, C, symbols, motifs = get_parameters(n_motifs, n_picks, dv, dc, k, n, ffdim, display=False, zero_codeword=zero_codeword, Harr=Harr)

if singular_decoding:
run_singular_decoding(graph, C, 8, symbols, motifs, n_picks)

if bec_decoder:
elif bec_decoder:
print(decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, iterations=iterations, bec_decoder=True, label='Erasure Decoder', code_class=code_class, read_lengths=read_lengths))

if uncoded:
print(decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, iterations=iterations, uncoded=True, label='Uncoded', code_class=code_class, read_lengths=read_lengths))
print(decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, iterations=iterations, label=f'CC Decoder', code_class=code_class, read_lengths=read_lengths))
elif uncoded:
print(decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, iterations=iterations, uncoded=True, label=f'{label} Uncoded', code_class=code_class, read_lengths=read_lengths))

print(decoding_errors_fer(k, n, dv, dc, graph, C, symbols, motifs, n_picks, iterations=iterations, label=f'{label} CC Decoder', code_class=code_class, read_lengths=read_lengths, masked=masked))

plt.grid()
plt.legend()
plt.show()
#plt.show()


if __name__ == "__main__":
with Profile() as prof:
n_motifs, n_picks = 8, 4
dv, dc, ffdim = 3, 9, 67
k, n = 100, 150
k, n = 30, 45
L, M = 50, 1002
read_length = 6
read_lengths = np.arange(4,5)
run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, code_class="sc_", saved_code=False, uncoded=False, bec_decoder=False, read_lengths=read_lengths)
read_lengths = np.arange(5,12)

Harr = r.get_H_arr(dv, dc, k, n)
masked = True

run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, code_class="", saved_code=False, uncoded=True, bec_decoder=False, read_lengths=read_lengths, zero_codeword=True, label="ZeroCW", Harr=Harr, masked=masked)

run_fer(n_motifs, n_picks, dv, dc, k, n, L, M, ffdim, code_class="", saved_code=False, uncoded=True, bec_decoder=False, read_lengths=read_lengths, zero_codeword=False, label="FullCW", Harr=Harr, masked=masked)
plt.show()
(
Stats(prof)
.strip_dirs()
Expand Down
Loading

0 comments on commit fad8876

Please sign in to comment.