diff --git a/msm/pippenger.cuh b/msm/pippenger.cuh
index 0481abc..24c5f70 100644
--- a/msm/pippenger.cuh
+++ b/msm/pippenger.cuh
@@ -318,7 +318,7 @@ template <class bucket_t, class point_t, class affine_t, class scalar_t, class a
           class bucket_h = class bucket_t::mem_t>
 class msm_t {
     const gpu_t &gpu;
-    
+
     // main data
     bool owned;
     affine_h *d_points;
@@ -346,7 +346,7 @@ class msm_t {
 
       public:
         result_t() {}
-        inline operator decltype(ret) &() { return ret; }
+        inline operator decltype(ret) & () { return ret; }
         inline const bucket_t *operator[](size_t i) const { return ret[i]; }
     };
 
@@ -359,8 +359,7 @@ class msm_t {
 
   public:
     // Initialize the MSM by moving the points to the device
-    msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1) 
-    : gpu(select_gpu(device_id)) {
+    msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1) : gpu(select_gpu(device_id)) {
         // set default values for fields
         this->d_points = nullptr;
         this->d_scalars = nullptr;
@@ -375,8 +374,7 @@ class msm_t {
         CUDA_OK(cudaGetLastError());
     }
 
-    msm_t(affine_h *d_points, size_t npoints, int device_id = -1) 
-    : gpu(select_gpu(device_id)) {
+    msm_t(affine_h *d_points, size_t npoints, int device_id = -1) : gpu(select_gpu(device_id)) {
         // set default values for fields
         this->d_points = d_points;
         this->d_scalars = nullptr;
@@ -453,13 +451,57 @@ class msm_t {
     void setup_scratch(size_t nscalars) {
         this->nscalars = nscalars;
 
-        // nscalars = (nscalars + WARP_SZ - 1) & ~(WARP_SZ - 1);
+        // uint32_t lg_n = lg2(nscalars + nscalars / 2);
+
+        // wbits = 17;
+        // if (nscalars > 192) {
+        //     wbits = std::min(lg_n, (uint32_t)18);
+        //     if (wbits < 10)
+        //         wbits = 10;
+        // } else if (nscalars > 0) {
+        //     wbits = 10;
+        // }
+        // nwins = (scalar_t::bit_length() - 1) / wbits + 1;
+
+        // uint32_t row_sz = 1U << (wbits - 1);
+
+        // size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
+        // d_buckets_sz *= sizeof(d_buckets[0]);
+        // size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);
+        // size_t temp_sz = scalars ? sizeof(scalar_t) : 0;
+
+        // size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
+        // batch >>= 6;
+        // batch = batch ? batch : 1;
+        // uint32_t stride = (nscalars + batch - 1) / batch;
+        // stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);
+
+        // temp_sz = stride * std::max(2 * sizeof(uint2), temp_sz);
+        // size_t digits_sz = nwins * stride * sizeof(uint32_t);
+
+        // size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + digits_sz;
+
+        // d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
+        // size_t offset = 0;
+        // d_buckets = reinterpret_cast<decltype(d_buckets)>(&d_total_blob[offset]);
+        // offset += d_buckets_sz;
+        // d_hist = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], row_sz);
+        // offset += d_hist_sz;
+
+        // d_temps = vec2d_t<uint2>((uint2 *)&d_total_blob[offset], stride);
+        // d_scalars = (scalar_t *)&d_total_blob[offset];
+        // offset += temp_sz;
+        // d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
+    }
+
+    RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
+        // assert(this->nscalars <= nscalars);
+
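+        // All per-call scratch sizing and allocation now happens here; the
+        // body of setup_scratch() above is kept only as a commented-out
+        // reference for the layout.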
         uint32_t lg_n = lg2(nscalars + nscalars / 2);
 
-        // Compute window size
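+        // Choose the window width: clamped to [10, 18] bits for large
+        // inputs, 10 bits for small nonzero inputs, 17 by default.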
         wbits = 17;
         if (nscalars > 192) {
-            wbits = std::min(lg_n - 8, (uint32_t)18);
+            wbits = std::min(lg_n, (uint32_t)18);
             if (wbits < 10)
                 wbits = 10;
         } else if (nscalars > 0) {
@@ -467,71 +509,35 @@ class msm_t {
         }
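+        // Number of windows needed to cover the full scalar bit-length.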
         nwins = (scalar_t::bit_length() - 1) / wbits + 1;
 
-        // Allocate the buckets and histogram
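+        // Bucket storage: 1 << (wbits - 1) buckets per window, plus one
+        // per-warp accumulator for the batch-add kernel.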
         uint32_t row_sz = 1U << (wbits - 1);
+
         size_t d_buckets_sz = (nwins * row_sz) + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
-        d_buckets_sz *= sizeof(bucket_h);
+        d_buckets_sz *= sizeof(d_buckets[0]);
         size_t d_hist_sz = nwins * row_sz * sizeof(uint32_t);
+        size_t temp_sz = sizeof(scalar_t);
 
-        // Compute how big each batch should be
-        batch = 1 << (std::max(lg_n, wbits) - wbits);
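+        // Split the input into batches; stride is the per-batch scalar
+        // capacity, rounded up to a multiple of WARP_SZ.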
+        size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
         batch >>= 6;
         batch = batch ? batch : 1;
-        stride = (nscalars + batch - 1) / batch;
-        stride = (stride + WARP_SZ - 1) & ~(WARP_SZ - 1);
-        
-        // Allocate the memory required for each batch
-        size_t scalars_sz = stride * std::max(2 * sizeof(uint2), sizeof(scalar_t));
-        size_t pidx_sz = sizeof(uint32_t) * stride;
+        uint32_t stride = (nscalars + batch - 1) / batch;
+        stride = (stride + WARP_SZ - 1) & ((size_t)0 - WARP_SZ);
+
+        // temp_sz depends on stride, so it is only finalized here: the region
+        // must fit either the staged scalars or two uint2 entries per slot.
+        temp_sz = stride * std::max(2 * sizeof(uint2), temp_sz);
+
         size_t digits_sz = nwins * stride * sizeof(uint32_t);
 
-        size_t d_blob_sz = d_buckets_sz + d_hist_sz + scalars_sz + pidx_sz + digits_sz;
+        size_t pidx_sz = sizeof(uint32_t) * stride; // still needed for d_pidx below
+        size_t d_blob_sz = d_buckets_sz + d_hist_sz + temp_sz + pidx_sz + digits_sz;
+
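+        // A single device allocation backs buckets, histogram, scalar/temp
+        // staging, point indices, and digits; it is carved up by offset below.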
         d_total_blob = reinterpret_cast<char *>(gpu.Dmalloc(d_blob_sz));
         size_t offset = 0;
-
-        d_buckets = reinterpret_cast<decltype(d_buckets)>(d_total_blob);
+        d_buckets = reinterpret_cast<decltype(d_buckets)>(&d_total_blob[offset]);
         offset += d_buckets_sz;
-        d_hist = vec2d_t<uint32_t>(reinterpret_cast<uint32_t *>(&d_total_blob[offset]), row_sz);
+        d_hist = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], row_sz);
         offset += d_hist_sz;
 
-        d_temps = vec2d_t<uint2>(reinterpret_cast<uint2 *>(&d_total_blob[offset]), stride);
-        d_scalars = reinterpret_cast<scalar_t *>(&d_total_blob[offset]);
-        offset += scalars_sz;
-        d_pidx = reinterpret_cast<uint32_t *>(&d_total_blob[offset]);
-        offset += pidx_sz;
-        d_digits = vec2d_t<uint32_t>(reinterpret_cast<uint32_t *>(&d_total_blob[offset]), stride);
-        offset += digits_sz;
-    }
-
-    RustError invoke(point_t &out, const scalar_t *scalars, size_t nscalars, uint32_t pidx[], bool mont = true) {
-        // assert(this->nscalars <= nscalars);
-        
-        wbits = 17;
-        if (nscalars > 192) {
-            wbits = std::min(lg2(nscalars + nscalars/2) - 8, 18);
-            if (wbits < 10)
-                wbits = 10;
-        } else if (nscalars > 0) {
-            wbits = 10;
-        }
-        nwins = (scalar_t::bit_length() - 1) / wbits + 1;
-
-        uint32_t row_sz = 1U << (wbits-1);
-
-        size_t d_buckets_sz = (nwins * row_sz)
-                            + (gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
-        size_t d_blob_sz = (d_buckets_sz * sizeof(d_buckets[0]))
-                         + (nwins * row_sz * sizeof(uint32_t));
-
-        d_buckets = reinterpret_cast<decltype(d_buckets)>(gpu.Dmalloc(d_blob_sz));
-        d_hist = vec2d_t<uint32_t>(&d_buckets[d_buckets_sz], row_sz);
-
-        uint32_t lg_n = lg2(nscalars + nscalars/2);
-        size_t batch = 1 << (std::max(lg_n, wbits) - wbits);
-        batch >>= 6;
-        batch = batch ? batch : 1;
-        uint32_t stride = (nscalars + batch - 1) / batch;
-        stride = (stride+WARP_SZ-1) & ((size_t)0-WARP_SZ);
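+        // NB: d_temps and d_scalars intentionally share the same offset;
+        // temp_sz is sized for whichever of the two layouts is larger.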
+        d_temps = vec2d_t<uint2>((uint2 *)&d_total_blob[offset], stride);
+        d_scalars = (scalar_t *)&d_total_blob[offset];
+        offset += temp_sz;
+        // d_pidx must still be carved out of the blob: it is written by the
+        // HtoD copies below whenever |pidx| is supplied.
+        d_pidx = (uint32_t *)&d_total_blob[offset];
+        offset += pidx_sz;
+        d_digits = vec2d_t<uint32_t>((uint32_t *)&d_total_blob[offset], stride);
 
         std::vector<result_t> res(nwins);
         std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
@@ -540,27 +546,14 @@ class msm_t {
         point_t p;
 
         try {
-            // |scalars| being nullptr means the scalars are pre-loaded to
-            // |d_scalars|, otherwise allocate stride.
-            size_t temp_sz = scalars ? sizeof(scalar_t) : 0;
-            temp_sz = stride * std::max(2*sizeof(uint2), temp_sz);
-
-            size_t digits_sz = nwins * stride * sizeof(uint32_t);
-
-            dev_ptr_t<uint8_t> d_temp{temp_sz + digits_sz, gpu[2]};
-
-            vec2d_t<uint2> d_temps{&d_temp[0], stride};
-            vec2d_t<uint32_t> d_digits{&d_temp[temp_sz], stride};
-
-            scalar_t* d_scalars = scalars ? (scalar_t*)&d_temp[0]
-                                          : this->d_scalars;
             size_t d_off = 0; // device offset
             size_t h_off = 0; // host offset
             size_t num = stride > nscalars ? nscalars : stride;
             event_t ev;
 
             gpu[2].HtoD(&d_scalars[0], &scalars[h_off], num);
-            if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num);
+            if (pidx)
+                gpu[2].HtoD(&d_pidx[0], &pidx[h_off], num);
 
             digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
             gpu[2].record(ev);
@@ -587,7 +580,8 @@ class msm_t {
                     num = d_off + stride <= nscalars ? stride : nscalars - d_off;
 
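+                    // Overlap transfers with compute: upload the next batch
+                    // while the current one is still being accumulated.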
                     gpu[2].HtoD(&d_scalars[0], &scalars[d_off], num);
-                    if (pidx) gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num);
+                    if (pidx)
+                        gpu[2].HtoD(&d_pidx[0], &pidx[d_off], num);
 
                     gpu[2].wait(ev);
                     digits(&d_scalars[0], num, d_digits, d_temps, mont, d_pidx);
@@ -764,8 +758,8 @@ static RustError mult_pippenger(point_t *out, const affine_t points[], size_t np
 
 template <class bucket_t, class point_t, class affine_t, class scalar_t, class affine_h = class affine_t::mem_t,
           class bucket_h = class bucket_t::mem_t>
-static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_context,
-                                     const scalar_t scalars[], size_t nscalars, uint32_t pidx[], bool mont = true) {
+static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_context, const scalar_t scalars[],
+                                     size_t nscalars, uint32_t pidx[], bool mont = true) {
     try {
         msm_t<bucket_t, point_t, affine_t, scalar_t> msm{msm_context->d_points, msm_context->npoints};
         // msm.setup_scratch(nscalars);
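+        // setup_scratch() is no longer needed here: invoke() derives and
+        // allocates its own scratch from nscalars.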