From c8720f33bd95df647597d6b3576619ba38786687 Mon Sep 17 00:00:00 2001 From: Hanting Zhang Date: Thu, 8 Feb 2024 22:50:55 +0000 Subject: [PATCH] no pidx --- msm/pippenger.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/msm/pippenger.cuh b/msm/pippenger.cuh index 2911a38..52be562 100644 --- a/msm/pippenger.cuh +++ b/msm/pippenger.cuh @@ -448,7 +448,7 @@ class msm_t { public: // Compute various constants (stride length, window size) based on the number of scalars. // Also allocate scratch space. - void setup_scratch(size_t nscalars) { + void setup_scratch(size_t nscalars, uint32_t *pidx) { this->nscalars = nscalars; uint32_t lg_n = lg2(nscalars + nscalars / 2); @@ -477,7 +477,7 @@ class msm_t { size_t temp_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t)); size_t digits_sz = nwins * stride * sizeof(uint32_t); - size_t pidx_sz = stride * sizeof(uint32_t); + size_t pidx_sz = pidx ? stride * sizeof(uint32_t) : 0; size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + pidx_sz + digits_sz; @@ -498,7 +498,7 @@ class msm_t { RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) { // assert(this->nscalars <= nscalars); - setup_scratch(nscalars); + setup_scratch(nscalars, pidx); std::vector res(nwins); std::vector ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);