Skip to content

Commit

Permalink
Merge pull request encryptogroup#126 from lenerd/lenerd/phasing-threads
Browse files Browse the repository at this point in the history
Update PSI Phasing Example
  • Loading branch information
dd23 authored Mar 13, 2019
2 parents 6eb1030 + 0fe30d2 commit d61d1d9
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 57 deletions.
42 changes: 21 additions & 21 deletions src/examples/psi_phasing/common/hashing/cuckoo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
*/

#include <iostream>
#include <thread>
#include <vector>
#include "cuckoo.h"

//returns a cuckoo hash table with the first dimension being the bins and the second dimension being the pointer to the elements
Expand All @@ -22,11 +24,8 @@ cuckoo_hashing(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_t bitle
uint8_t* hash_table;
cuckoo_entry_ctx** cuckoo_table;
cuckoo_entry_ctx** cuckoo_stash;
cuckoo_entry_ctx* cuckoo_entries;
uint32_t i, j, stashctr=0, elebytelen;
uint32_t *perm_ptr;
pthread_t* entry_gen_tasks;
cuckoo_entry_gen_ctx* ctx;
hs_t hs;
elebytelen = ceil_divide(bitlen, 8);

Expand All @@ -42,28 +41,34 @@ cuckoo_hashing(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_t bitle
cuckoo_table = (cuckoo_entry_ctx**) calloc(nbins, sizeof(cuckoo_entry_ctx*));
cuckoo_stash = (cuckoo_entry_ctx**) calloc(maxstashsize, sizeof(cuckoo_entry_ctx*));

cuckoo_entries = (cuckoo_entry_ctx*) malloc(neles * sizeof(cuckoo_entry_ctx));
entry_gen_tasks = (pthread_t*) malloc(sizeof(pthread_t) * ntasks);
ctx = (cuckoo_entry_gen_ctx*) malloc(sizeof(cuckoo_entry_gen_ctx) * ntasks);
std::vector<cuckoo_entry_ctx> cuckoo_entries(neles);
std::vector<std::thread> entry_gen_tasks(ntasks);
std::vector<cuckoo_entry_gen_ctx> ctx(ntasks);

#ifndef TEST_UTILIZATION
for(i = 0; i < ntasks; i++) {
ctx[i].elements = elements;
ctx[i].cuckoo_entries = cuckoo_entries;
ctx[i].cuckoo_entries = cuckoo_entries.data();
ctx[i].hs = &hs;
ctx[i].startpos = i * ceil_divide(neles, ntasks);
ctx[i].endpos = std::min(ctx[i].startpos + ceil_divide(neles, ntasks), neles);
//std::cout << "Thread " << i << " starting from " << ctx[i].startpos << " going to " << ctx[i].endpos << " for " << neles << " elements" << std::endl;
if(pthread_create(entry_gen_tasks+i, NULL, gen_cuckoo_entries, (void*) (ctx+i))) {
std::cerr << "Error in creating new pthread at cuckoo hashing!" << std::endl;
exit(0);
try {
entry_gen_tasks[i] = std::thread(gen_cuckoo_entries, &ctx[i]);
} catch (const std::system_error& e) {
std::cerr << "Error in creating new thread at cuckoo hashing!\n"
<< e.what() << std::endl;
exit(1);
}
}

for(i = 0; i < ntasks; i++) {
if(pthread_join(entry_gen_tasks[i], NULL)) {
std::cerr << "Error in joining pthread at cuckoo hashing!" << std::endl;
exit(0);
try {
entry_gen_tasks[i].join();
} catch (const std::system_error& e) {
std::cerr << "Error in joining thread at cuckoo hashing!\n"
<< e.what() << std::endl;
exit(1);
}
}
#else
Expand All @@ -79,7 +84,7 @@ cuckoo_hashing(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_t bitle
//}
//insert all elements into the cuckoo hash table
for(i = 0; i < neles; i++) {
if(!(insert_element(cuckoo_table, cuckoo_entries + i, neles, hs.nhashfuns))) {
if(!(insert_element(cuckoo_table, &cuckoo_entries[i], neles, hs.nhashfuns))) {
#ifdef COUNT_FAILS
fails++;
/*std::cout << "insertion failed for element " << (hex) << (*(((uint32_t*) elements)+i)) << ", inserting to address: ";
Expand All @@ -90,7 +95,7 @@ cuckoo_hashing(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_t bitle
#else
if(stashctr < maxstashsize) {
std::cout << "Insertion not successful for element " << i <<", putting it on the stash" << std::endl;
cuckoo_stash[stashctr] = cuckoo_entries+i;
cuckoo_stash[stashctr] = &cuckoo_entries[i];
stashctr++;
} else {
std::cerr << "Stash exceeded maximum stash size of " << maxstashsize << ", terminating program" << std::endl;
Expand Down Expand Up @@ -147,11 +152,8 @@ cuckoo_hashing(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_t bitle
free(cuckoo_entries[i].address);
}
#endif
free(cuckoo_entries);
free(cuckoo_table);
free(cuckoo_stash);
free(entry_gen_tasks);
free(ctx);

free_hashing_state(&hs);

Expand All @@ -163,8 +165,7 @@ cuckoo_hashing(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_t bitle
}


void *gen_cuckoo_entries(void *ctx_void) {
cuckoo_entry_gen_ctx* ctx = (cuckoo_entry_gen_ctx*) ctx_void;
void gen_cuckoo_entries(cuckoo_entry_gen_ctx* ctx) {
hs_t* hs = ctx->hs;
uint32_t i, inbytelen = ceil_divide(hs->inbitlen, 8);
uint8_t* eleptr = ctx->elements + inbytelen * ctx->startpos;
Expand All @@ -174,7 +175,6 @@ void *gen_cuckoo_entries(void *ctx_void) {
for(i = ctx->startpos; i < ctx->endpos; i++, eleptr+=inbytelen) {
gen_cuckoo_entry(eleptr, ctx->cuckoo_entries + i, hs, i);
}
return nullptr;
}


Expand Down
2 changes: 1 addition & 1 deletion src/examples/psi_phasing/common/hashing/cuckoo.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ cuckoo_hashing(uint8_t* elements, uint32_t neles, uint32_t nbins, uint32_t bitle
uint32_t* perm, uint32_t ntasks, uint8_t** stash_elements, uint32_t maxstashsize, uint32_t** stashperm, uint32_t nhashfuns,
prf_state_ctx* prf_state);
//routine for generating the entries, is invoked by the threads
void *gen_cuckoo_entries(void *ctx);
void gen_cuckoo_entries(cuckoo_entry_gen_ctx* ctx);
inline void gen_cuckoo_entry(uint8_t* in, cuckoo_entry_ctx* out, hs_t* hs, uint32_t ele_id);
inline bool insert_element(cuckoo_entry_ctx** ctable, cuckoo_entry_ctx* element, uint32_t max_iterations, uint32_t nhashfuns);
inline uint32_t compute_stash_size(uint32_t nbins, uint32_t neles);
Expand Down
61 changes: 30 additions & 31 deletions src/examples/psi_phasing/common/hashing/simple_hashing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,24 @@
*/

#include <iostream>
#include <thread>
#include <vector>
#include "simple_hashing.h"

uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint32_t *outbitlen, uint32_t* nelesinbin, uint32_t nbins,
uint32_t* maxbinsize, uint32_t ntasks, uint32_t nhashfuns, prf_state_ctx* prf_state) {
sht_ctx* table;
//uint8_t** bin_content;
uint8_t *eleptr, *bin_ptr, *result, *res_bins;
uint32_t i, j, tmpneles;
sheg_ctx* ctx;
pthread_t* entry_gen_tasks;
hs_t hs;

init_hashing_state(&hs, neles, bitlen, nbins, nhashfuns, prf_state);
//Set the output bit-length of the hashed elements
*outbitlen = hs.outbitlen;

entry_gen_tasks = (pthread_t*) malloc(sizeof(pthread_t) * ntasks);
ctx = (sheg_ctx*) malloc(sizeof(sheg_ctx) * ntasks);
table = (sht_ctx*) malloc(sizeof(sht_ctx) * ntasks);
std::vector<std::thread> entry_gen_tasks(ntasks);
std::vector<sheg_ctx> ctx(ntasks);
std::vector<sht_ctx> table(ntasks);


//in case no maxbinsize is specified, compute based on Eq3 in eprint 2016/930
Expand All @@ -35,7 +34,7 @@ uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint
}

for(i = 0; i < ntasks; i++) {
init_hash_table(table + i, ceil_divide(neles, ntasks), &hs, *maxbinsize);
init_hash_table(&table[i], ceil_divide(neles, ntasks), &hs, *maxbinsize);
}

//for(i = 0; i < nbins; i++)
Expand All @@ -45,26 +44,32 @@ uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint

for(i = 0; i < ntasks; i++) {
ctx[i].elements = elements;
ctx[i].table = table + i;
ctx[i].table = &table[i];
ctx[i].startpos = i * ceil_divide(neles, ntasks);
ctx[i].endpos = std::min(ctx[i].startpos + ceil_divide(neles, ntasks), neles);
ctx[i].hs = &hs;

//std::cout << "Thread " << i << " starting from " << ctx[i].startpos << " going to " << ctx[i].endpos << " for " << neles << " elements" << std::endl;
if(pthread_create(entry_gen_tasks+i, NULL, gen_entries, (void*) (ctx+i))) {
std::cerr << "Error in creating new pthread at simple hashing!" << std::endl;
exit(0);
try {
entry_gen_tasks[i] = std::thread(gen_entries, &ctx[i]);
} catch (const std::system_error& e) {
std::cerr << "Error in creating new thread at simple hashing!\n"
<< e.what() << std::endl;
exit(1);
}
}

for(i = 0; i < ntasks; i++) {
if(pthread_join(entry_gen_tasks[i], NULL)) {
std::cerr << "Error in joining pthread at simple hashing!" << std::endl;
exit(0);
try {
entry_gen_tasks[i].join();
} catch (const std::system_error& e) {
std::cerr << "Error in joining thread at simple hashing!\n"
<< e.what() << std::endl;
exit(1);
}
}

*maxbinsize = table->maxbinsize;
*maxbinsize = table[0].maxbinsize;

//for(i = 0, eleptr=elements; i < neles; i++, eleptr+=inbytelen) {
// insert_element(table, eleptr, tmpbuf);
Expand All @@ -81,20 +86,17 @@ uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint
for(i = 0; i < hs.nbins; i++) {
nelesinbin[i] = 0;
for(j = 0; j < ntasks; j++) {
tmpneles = (table +j)->bins[i].nvals;
tmpneles = table[j].bins[i].nvals;
nelesinbin[i] += tmpneles;
//bin_content[i] = (uint8_t*) malloc(nelesinbin[i] * table->outbytelen);
memcpy(bin_ptr, (table + j)->bins[i].values, tmpneles * hs.outbytelen);
memcpy(bin_ptr, table[j].bins[i].values, tmpneles * hs.outbytelen);
bin_ptr += (tmpneles * hs.outbytelen);
}
//right now only the number of elements in each bin is copied instead of the max bin size
}

for(j = 0; j < ntasks; j++)
free_hash_table(table + j);
free(table);
free(entry_gen_tasks);
free(ctx);
free_hash_table(&table[j]);

//for(i = 0; i < nbins; i++)
// pthread_mutex_destroy(locks+i);
Expand All @@ -105,23 +107,20 @@ uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint
return res_bins;
}

void *gen_entries(void *ctx_tmp) {
void gen_entries(sheg_ctx* ctx) {
//Insert elements in parallel, use lock to communicate
uint8_t *tmpbuf, *eleptr;
sheg_ctx* ctx = (sheg_ctx*) ctx_tmp;
uint32_t i, inbytelen, *address;
uint8_t *eleptr;
uint32_t i, inbytelen;

address = (uint32_t*) malloc(ctx->hs->nhashfuns * sizeof(uint32_t));
tmpbuf = (uint8_t*) calloc(ceil_divide(ctx->hs->outbitlen, 8), sizeof(uint8_t)); //for(i = 0; i < NUM_HASH_FUNCTIONS; i++) {
std::vector<uint32_t> address(ctx->hs->nhashfuns);
std::vector<uint8_t> tmpbuf(ceil_divide(ctx->hs->outbitlen, 8), 0);
//for(i = 0; i < NUM_HASH_FUNCTIONS; i++) {
// tmpbuf[i] = (uint8_t*) malloc(ceil_divide(ctx->hs->outbitlen, 8));
//}

for(i = ctx->startpos, eleptr=ctx->elements, inbytelen=ctx->hs->inbytelen; i < ctx->endpos; i++, eleptr+=inbytelen) {
insert_element(ctx->table, eleptr, address, tmpbuf, ctx->hs);
insert_element(ctx->table, eleptr, address.data(), tmpbuf.data(), ctx->hs);
}
free(tmpbuf);
free(address);
return nullptr;
}

inline void insert_element(sht_ctx* table, uint8_t* element, uint32_t* address, uint8_t* tmpbuf, hs_t* hs) {
Expand Down
2 changes: 1 addition & 1 deletion src/examples/psi_phasing/common/hashing/simple_hashing.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ typedef struct simple_hash_entry_gen_ctx {
uint8_t* simple_hashing(uint8_t* elements, uint32_t neles, uint32_t bitlen, uint32_t* outbitlen, uint32_t* nelesinbin, uint32_t nbins,
uint32_t* maxbinsize, uint32_t ntasks, uint32_t nhashfuns, prf_state_ctx* prf_state);
//routine for generating the entries, is invoked by the threads
void *gen_entries(void *ctx);
void gen_entries(sheg_ctx *ctx);
void init_hash_table(sht_ctx* table, uint32_t nelements, hs_t* hs, uint32_t maxbinsize);
void increase_max_bin_size(sht_ctx* table, uint32_t valbytelen);
void free_hash_table(sht_ctx* table);
Expand Down
17 changes: 14 additions & 3 deletions src/examples/psi_phasing/common/phasing_circuit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,11 @@ int32_t test_phasing_circuit(e_role role, const std::string& address, uint16_t p
}

shr_stash_out = BuildPhasingStashCircuit(shr_srv_set, shr_cli_stash, server_neles, bitlen, maxstashsize, circ);
shr_stash_out = circ->PutOUTGate(shr_stash_out, CLIENT);
{
auto tmp = shr_stash_out;
shr_stash_out = circ->PutOUTGate(shr_stash_out, CLIENT);
delete tmp;
}
party->ExecCircuit();

//Only the client obtains the outputs and performs the checks
Expand Down Expand Up @@ -189,6 +193,7 @@ int32_t test_phasing_circuit(e_role role, const std::string& address, uint16_t p
for(uint32_t i = 0; i < circ_inter_ctr; i++) {
assert(ver_intersect[i] == circ_intersect[i]);
}
free(output);
}

#ifdef BATCH
Expand All @@ -200,12 +205,18 @@ int32_t test_phasing_circuit(e_role role, const std::string& address, uint16_t p
free(srv_set);
free(cli_set);
free(shr_srv_hash_table);
free(shr_cli_hash_table);
free(shr_out);
free(shr_cli_stash);
delete shr_cli_hash_table;
delete shr_out;
delete shr_stash_out;
free(ver_intersect);
free(circ_intersect);
free(inv_perm);
free(stash);
free(client_hash_table);
free(server_hash_table);
free(stashperm);
delete crypt;

return 0;
}
Expand Down

0 comments on commit d61d1d9

Please sign in to comment.