diff --git a/msm/pippenger.cuh b/msm/pippenger.cuh index 3261ecf..427abe2 100644 --- a/msm/pippenger.cuh +++ b/msm/pippenger.cuh @@ -449,8 +449,6 @@ class msm_t { // Compute various constants (stride length, window size) based on the number of scalars. // Also allocate scratch space. void setup_scratch(size_t nscalars) { - this->nscalars = nscalars; - uint32_t lg_n = lg2(nscalars + nscalars / 2); wbits = 17; @@ -469,10 +467,10 @@ class msm_t { d_buckets_sz *= sizeof(d_buckets[0]); size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t); - size_t batch = 1 << (std::max(lg_n, wbits) - wbits); + this->batch = 1 << (std::max(lg_n, wbits) - wbits); batch >>= 6; batch = batch ? batch : 1; - uint32_t stride = (nscalars + batch - 1) / batch; + this->stride = (nscalars + batch - 1) / batch; stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ); size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t)); @@ -496,48 +494,7 @@ class msm_t { RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) { // assert(this->nscalars <= nscalars); - - uint32_t lg_n = lg2(nscalars + nscalars / 2); - - wbits = 17; - if (nscalars > 192) { - wbits = std::min(lg_n, (uint32_t)18); - if (wbits < 10) - wbits = 10; - } else if (nscalars > 0) { - wbits = 10; - } - nwins = (scalar_t::bit_length() - 1) / wbits + 1; - - uint32_t row_sz = 1U << (wbits - 1); - - size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ); - d_buckets_sz *= sizeof(d_buckets[0]); - size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t); - - size_t batch = 1 << (std::max(lg_n, wbits) - wbits); - batch >>= 6; - batch = batch ? batch : 1; - uint32_t stride = (nscalars + batch - 1) / batch; - stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ); - - size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t)); - - size_t digits_sz = nwins * stride * sizeof(uint32_t); - - size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz; - - d_total_blob = reinterpret_cast(gpu.Dmalloc(d_blob_sz)); - size_t offset = 0; - d_buckets = reinterpret_cast(&d_total_blob[offset]); - offset += d_buckets_sz; - d_hist = vec2d_t((uint32_t *)&d_total_blob[offset], row_sz); - offset += d_hist_sz; - - d_temps = vec2d_t((uint2 *)&d_total_blob[offset], stride); - d_scalars = (scalar_t *)&d_total_blob[offset]; - offset += temp_sz; - d_digits = vec2d_t((uint32_t *)&d_total_blob[offset], stride); + setup_scratch(nscalars); std::vector res(nwins); std::vector ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);