Skip to content

Commit

Permalink
use ffi_affine_sz
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanting Zhang committed Feb 19, 2024
1 parent c05be74 commit 1dc6da2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
21 changes: 11 additions & 10 deletions msm/pippenger.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class msm_t {

public:
// Initialize the MSM by moving the points to the device
msm_t(const affine_t points[], size_t npoints, bool owned, int device_id = -1) : gpu(select_gpu(device_id)) {
msm_t(const affine_t points[], size_t npoints, bool owned, size_t ffi_affine_sz = sizeof(affine_t), int device_id = -1) : gpu(select_gpu(device_id)) {
// set default values for fields
this->d_points = nullptr;
this->d_scalars = nullptr;
Expand All @@ -377,7 +377,7 @@ public:

if (points) {
d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
gpu.HtoD(d_points, points, npoints, ffi_affine_sz);
CUDA_OK(cudaGetLastError());
}
}
Expand Down Expand Up @@ -460,7 +460,7 @@ private:
public:
// Compute various constants (stride length, window size) based on the number of scalars.
// Also allocate scratch space.
void setup_scratch(const affine_t *&points, size_t npoints, size_t nscalars, uint32_t *pidx) {
void setup_scratch(const affine_t *&points, size_t npoints, size_t nscalars, uint32_t *pidx, size_t ffi_affine_sz = sizeof(affine_t)) {
this->npoints = npoints;
this->nscalars = nscalars;

Expand All @@ -473,7 +473,7 @@ public:
// if both are not null, then we move all the points onto the GPU at once,
// at a performance penalty
d_points = reinterpret_cast<decltype(d_points)>(gpu.Dmalloc(npoints * sizeof(d_points[0])));
gpu.HtoD(d_points, points, npoints, sizeof(affine_h));
gpu.HtoD(d_points, points, npoints, ffi_affine_sz);
CUDA_OK(cudaGetLastError());
points = nullptr;
}
Expand Down Expand Up @@ -532,10 +532,11 @@ public:
d_points = (affine_h *)&d_total_blob[offset];
}

RustError invoke(point_t &out, const affine_t points[], size_t npoints,
RustError invoke(point_t &out, const affine_t points_[], size_t npoints,
const scalar_t *scalars, size_t nscalars,
uint32_t pidx[], bool mont = true, size_t ffi_affine_sz = sizeof(affine_t)) {
setup_scratch(points, npoints, nscalars, pidx);
setup_scratch(points_, npoints, nscalars, pidx, ffi_affine_sz);
const char* points = reinterpret_cast<const char*>(points_);

std::vector<result_t> res(nwins);
std::vector<bucket_t> ones(gpu.sm_count() * BATCH_ADD_BLOCK_SIZE / WARP_SZ);
Expand Down Expand Up @@ -791,11 +792,11 @@ static RustError mult_pippenger_init(const affine_t points[], size_t npoints, ms

template <class bucket_t, class point_t, class affine_t, class scalar_t>
static RustError mult_pippenger(point_t *out, const affine_t points[], size_t npoints, const scalar_t scalars[],
size_t nscalars, uint32_t pidx[], bool mont = true) {
size_t nscalars, uint32_t pidx[], bool mont = true, size_t ffi_affine_sz = sizeof(affine_t)) {
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{nullptr, npoints, false};
// msm.setup_scratch(nscalars);
return msm.invoke(*out, points, npoints, scalars, nscalars, pidx, mont);
return msm.invoke(*out, points, npoints, scalars, nscalars, pidx, mont, ffi_affine_sz);
} catch (const cuda_error &e) {
out->inf();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
Expand All @@ -809,11 +810,11 @@ static RustError mult_pippenger(point_t *out, const affine_t points[], size_t np
template <class bucket_t, class point_t, class affine_t, class scalar_t, class affine_h = class affine_t::mem_t,
class bucket_h = class bucket_t::mem_t>
static RustError mult_pippenger_with(point_t *out, msm_context_t<affine_h> *msm_context, const scalar_t scalars[],
size_t nscalars, uint32_t pidx[], bool mont = true) {
size_t nscalars, uint32_t pidx[], bool mont = true, size_t ffi_affine_sz = sizeof(affine_t)) {
try {
msm_t<bucket_t, point_t, affine_t, scalar_t> msm{msm_context->d_points, msm_context->npoints};
// msm.setup_scratch(nscalars);
return msm.invoke(*out, nullptr, msm_context->npoints, scalars, nscalars, pidx, mont);
return msm.invoke(*out, nullptr, msm_context->npoints, scalars, nscalars, pidx, mont, ffi_affine_sz);
} catch (const cuda_error &e) {
out->inf();
#ifdef TAKE_RESPONSIBILITY_FOR_ERROR_MESSAGE
Expand Down
4 changes: 2 additions & 2 deletions poc/msm-cuda/cuda/pippenger_inf.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ RustError::by_value mult_pippenger_inf(point_t* out, const affine_t points[],
size_t npoints, const scalar_t scalars[],
size_t ffi_affine_sz)
{
return mult_pippenger<bucket_t>(out, points, npoints, scalars, npoints, nullptr, false);
return mult_pippenger<bucket_t>(out, points, npoints, scalars, npoints, nullptr, false, ffi_affine_sz);
}

#if defined(FEATURE_BLS12_381) || defined(FEATURE_BLS12_377)
Expand All @@ -43,7 +43,7 @@ RustError::by_value mult_pippenger_fp2_inf(point_fp2_t* out, const affine_fp2_t
size_t npoints, const scalar_t scalars[],
size_t ffi_affine_sz)
{
return mult_pippenger<bucket_fp2_t>(out, points, npoints, scalars, npoints, nullptr, false);
return mult_pippenger<bucket_fp2_t>(out, points, npoints, scalars, npoints, nullptr, false, ffi_affine_sz);

}
#endif

0 comments on commit 1dc6da2

Please sign in to comment.