Skip to content

Commit

Permalink
[DAPHNE-daphne-eu#613] Support missing aggregation functions in aggre…
Browse files Browse the repository at this point in the history
…gation kernels (daphne-eu#638)

- The aggregation function VAR is not supported for any of the AggAll, AggRow, AggCol kernels, and STDDEV is not supported for full aggregation (AggAll) and row-wise aggregation (AggRow).
- This pull request proposes implementations for the above aggregation functions following the implementation of the MEAN aggregation function in each instance.
- Changes in files AggAll.h, AggRow.h, AggCol.h related to the implementation of STDDEV and VAR.
- Updated file AggOpCode.h to add VAR opCode.
- Updated file kernels.json to add STDDEV and VAR opCodes where needed.
- Include test cases to ensure the validity of the implementation. There are tests for both the STDDEV and VAR aggregation functions for all three kernels (AggAll, AggRow, AggCol), which can be found in the related test cases (AggAllTest.cpp, AggRowTest.cpp, AggColTest.cpp). Furthermore, these tests now use approximate comparisons to avoid test failures due to minor floating-point deviations.
- Updated DaphneLib script-level test cases matrix_agg.daphne and matrix_agg.py.
- Closes daphne-eu#613.
  • Loading branch information
inikokali authored Nov 17, 2023
1 parent 681ad58 commit ae1d021
Show file tree
Hide file tree
Showing 10 changed files with 307 additions and 55 deletions.
64 changes: 51 additions & 13 deletions src/runtime/local/kernels/AggAll.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ struct AggAll<VTRes, DenseMatrix<VTArg>> {
const VTArg * valuesArg = arg->getValues();

EwBinaryScaFuncPtr<VTRes, VTRes, VTRes> func;
VTRes agg;
VTRes agg, stddev;
if (AggOpCodeUtils::isPureBinaryReduction(opCode)) {
func = getEwBinaryScaFuncPtr<VTRes, VTRes, VTRes>(AggOpCodeUtils::getBinaryOpCode(opCode));
agg = AggOpCodeUtils::template getNeutral<VTRes>(opCode);
}
else {
// TODO Setting the function pointer yields the correct result.
// However, since MEAN and STDDEV are not sparse-safe, the program
// However, since MEAN, VAR, and STDDEV are not sparse-safe, the program
// does not take the same path for doing the summation, and is less
// efficient.
// for MEAN and STDDDEV, we need to sum
// for MEAN, VAR, and STDDEV, we need to sum
func = getEwBinaryScaFuncPtr<VTRes, VTRes, VTRes>(AggOpCodeUtils::getBinaryOpCode(AggOpCode::SUM));
agg = VTRes(0);
}
Expand All @@ -84,14 +84,32 @@ struct AggAll<VTRes, DenseMatrix<VTArg>> {
if (AggOpCodeUtils::isPureBinaryReduction(opCode))
return agg;

// The op-code is either MEAN or STDDEV.
agg /= arg->getNumCols() * arg->getNumRows();
// The op-code is either MEAN or STDDEV or VAR.
if (opCode == AggOpCode::MEAN) {
agg /= arg->getNumCols() * arg->getNumRows();
return agg;
}
// else op-code is STDDEV
// TODO STDDEV
throw std::runtime_error("unsupported AggOpCode in AggAll for DenseMatrix");
// else op-code is STDDEV or VAR
stddev=0;
valuesArg = arg->getValues();
for(size_t r = 0; r < numRows; r++) {
for(size_t c = 0; c < numCols; c++) {
VTRes val = static_cast<VTRes>(valuesArg[c]) - agg;
stddev = stddev + val * val;
}
valuesArg += arg->getRowSkip();
}

stddev /= arg->getNumCols() * arg->getNumRows();

//Variance --> stddev before sqrt() is variance
if (opCode == AggOpCode::VAR){
VTRes var = stddev;
return var;
}

stddev = sqrt(stddev);
return stddev;
}
};

Expand Down Expand Up @@ -131,7 +149,7 @@ struct AggAll<VTRes, CSRMatrix<VTArg>> {
ctx
);
}
else { // The op-code is either MEAN or STDDEV.
else { // The op-code is either MEAN or STDDEV or VAR.
EwBinaryScaFuncPtr<VTRes, VTRes, VTRes> func = getEwBinaryScaFuncPtr<VTRes, VTRes, VTRes>(AggOpCodeUtils::getBinaryOpCode(AggOpCode::SUM));
auto agg = aggArray(
arg->getValues(0),
Expand All @@ -142,11 +160,31 @@ struct AggAll<VTRes, CSRMatrix<VTArg>> {
VTRes(0),
ctx
);
agg = agg / (arg->getNumRows() * arg->getNumCols());
if (opCode == AggOpCode::MEAN)
return agg / (arg->getNumRows() * arg->getNumCols());

// TODO STDDEV
throw std::runtime_error("unsupported AggOpCode in AggAll for CSRMatrix");
return agg;
else{
//STDDEV-VAR
VTRes stddev=0;

const VTArg * valuesArg = arg->getValues(0);
for(size_t i = 0; i < arg->getNumNonZeros(); i++) {
VTRes val = static_cast<VTRes>((valuesArg[i])) - agg;
stddev = stddev + val * val;
}
stddev += ((arg->getNumRows() * arg->getNumCols()) - arg->getNumNonZeros())*agg*agg;
stddev /= (arg->getNumRows() * arg->getNumCols());

//Variance --> stddev before sqrt() is variance
if (opCode == AggOpCode::VAR){
VTRes var = stddev;
return var;
}

stddev = sqrt(stddev);
return stddev;

}
}
}
};
Expand Down
18 changes: 11 additions & 7 deletions src/runtime/local/kernels/AggCol.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,12 @@ struct AggCol<DenseMatrix<VTRes>, DenseMatrix<VTArg>> {
if(AggOpCodeUtils::isPureBinaryReduction(opCode))
return;

// The op-code is either MEAN or STDDEV.
// The op-code is either MEAN or STDDEV or VAR.

for(size_t c = 0; c < numCols; c++)
valuesRes[c] /= numRows;

if(opCode != AggOpCode::STDDEV)
if(opCode == AggOpCode::MEAN)
return;

auto tmp = DataObjectFactory::create<DenseMatrix<VTRes>>(1, numCols, true);
Expand All @@ -165,9 +165,12 @@ struct AggCol<DenseMatrix<VTRes>, DenseMatrix<VTArg>> {

for(size_t c = 0; c < numCols; c++) {
valuesT[c] /= numRows;
valuesT[c] = sqrt(valuesT[c]);
if (opCode == AggOpCode::STDDEV)
valuesT[c] = sqrt(valuesT[c]);
}



// TODO We could avoid copying by returning tmp and destroying res. But
// that might be wrong if res was not nullptr initially.
memcpy(valuesRes, valuesT, numCols * sizeof(VTRes));
Expand Down Expand Up @@ -240,12 +243,12 @@ struct AggCol<DenseMatrix<VTRes>, CSRMatrix<VTArg>> {
if(AggOpCodeUtils::isPureBinaryReduction(opCode))
return;

// The op-code is either MEAN or STDDEV.
// The op-code is either MEAN or STDDEV or VAR.

for(size_t c = 0; c < numCols; c++)
valuesRes[c] /= arg->getNumRows();

if(opCode != AggOpCode::STDDEV)
if(opCode == AggOpCode::MEAN)
return;

auto tmp = DataObjectFactory::create<DenseMatrix<VTRes>>(1, numCols, true);
Expand All @@ -264,7 +267,8 @@ struct AggCol<DenseMatrix<VTRes>, CSRMatrix<VTArg>> {
valuesT[c] += (valuesRes[c] * valuesRes[c]) * (numRows - nnzCol[c]);
// Finish computation of stddev.
valuesT[c] /= numRows;
valuesT[c] = sqrt(valuesT[c]);
if (opCode == AggOpCode::STDDEV)
valuesT[c] = sqrt(valuesT[c]);
}

delete[] nnzCol;
Expand All @@ -277,4 +281,4 @@ struct AggCol<DenseMatrix<VTRes>, CSRMatrix<VTArg>> {
}
};

#endif //SRC_RUNTIME_LOCAL_KERNELS_AGGCOL_H
#endif //SRC_RUNTIME_LOCAL_KERNELS_AGGCOL_H
3 changes: 3 additions & 0 deletions src/runtime/local/kernels/AggOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ enum class AggOpCode {
IDXMAX,
MEAN,
STDDEV,
VAR,
};

struct AggOpCodeUtils {
Expand All @@ -45,6 +46,7 @@ struct AggOpCodeUtils {
return true;
case AggOpCode::MEAN:
case AggOpCode::STDDEV:
case AggOpCode::VAR:
return false;
default:
throw std::runtime_error("unsupported AggOpCode");
Expand Down Expand Up @@ -85,6 +87,7 @@ struct AggOpCodeUtils {
case AggOpCode::MAX:
case AggOpCode::MEAN:
case AggOpCode::STDDEV:
case AggOpCode::VAR:
return false;
default:
throw std::runtime_error("unsupported AggOpCode");
Expand Down
73 changes: 62 additions & 11 deletions src/runtime/local/kernels/AggRow.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@

#include <cassert>
#include <cstddef>
#include <cstring>
#include <cmath>
#include <typeinfo>

// ****************************************************************************
// Struct for partial template specialization
Expand Down Expand Up @@ -108,8 +111,9 @@ struct AggRow<DenseMatrix<VTRes>, DenseMatrix<VTArg>> {

for(size_t r = 0; r < numRows; r++) {
VTRes agg = static_cast<VTRes>(*valuesArg);
for(size_t c = 1; c < numCols; c++)
for(size_t c = 1; c < numCols; c++){
agg = func(agg, static_cast<VTRes>(valuesArg[c]), ctx);
}
*valuesRes = static_cast<VTRes>(agg);
valuesArg += arg->getRowSkip();
valuesRes += res->getRowSkip();
Expand All @@ -118,18 +122,45 @@ struct AggRow<DenseMatrix<VTRes>, DenseMatrix<VTArg>> {
if(AggOpCodeUtils::isPureBinaryReduction(opCode))
return;

// The op-code is either MEAN or STDDEV
// The op-code is either MEAN or STDDEV or VAR
valuesRes = res->getValues();
// valuesArg = arg->getValues();
for(size_t r = 0; r < numRows; r++) {
*valuesRes = (*valuesRes) / numCols;
valuesRes += res->getRowSkip();
}

if(opCode == AggOpCode::MEAN)
return;

// else op-code is STDDEV
// TODO STDDEV
throw std::runtime_error("unsupported AggOpCode in AggRow for DenseMatrix");
// else op-code is STDDEV or VAR

// Create a temporary matrix to store the resulting standard deviations for each row
auto tmp = DataObjectFactory::create<DenseMatrix<VTRes>>(numRows, 1, true);
VTRes * valuesT = tmp->getValues();
valuesArg = arg->getValues();
valuesRes = res->getValues();
for(size_t r = 0; r < numRows; r++) {
for(size_t c = 0; c < numCols; c++) {
VTRes val = static_cast<VTRes>(valuesArg[c]) - (*valuesRes);
valuesT[r] = valuesT[r] + val * val;
}
valuesArg += arg->getRowSkip();
valuesRes += res->getRowSkip();

}
valuesRes = res->getValues();
for(size_t c = 0; c < numRows; c++) {
valuesT[c] /= numCols;
if(opCode == AggOpCode::STDDEV)
*valuesRes = sqrt(valuesT[c]);
else
*valuesRes = valuesT[c];
valuesRes += res->getRowSkip();
}

DataObjectFactory::destroy<DenseMatrix<VTRes>>(tmp);

}
}
};
Expand Down Expand Up @@ -169,10 +200,13 @@ struct AggRow<DenseMatrix<VTRes>, CSRMatrix<VTArg>> {
valuesRes += res->getRowSkip();
}
}
else { // The op-code is either MEAN or STDDEV
else { // The op-code is either MEAN or STDDEV or VAR
// get sum for each row
size_t ctr = 0 ;
const VTRes neutral = VTRes(0);
const bool isSparseSafe = true;
auto tmp = DataObjectFactory::create<DenseMatrix<VTRes>>(numRows, 1, true);
VTRes * valuesT = tmp->getValues();
EwBinaryScaFuncPtr<VTRes, VTRes, VTRes> func = getEwBinaryScaFuncPtr<VTRes, VTRes, VTRes>(AggOpCodeUtils::getBinaryOpCode(AggOpCode::SUM));
for (size_t r = 0; r < numRows; r++){
*valuesRes = AggAll<VTRes, CSRMatrix<VTArg>>::aggArray(
Expand All @@ -184,14 +218,31 @@ struct AggRow<DenseMatrix<VTRes>, CSRMatrix<VTArg>> {
neutral,
ctx
);
if (opCode == AggOpCode::MEAN)
*valuesRes = *valuesRes / numCols;
else
throw std::runtime_error("unsupported AggOpCode in AggRow for CSRMatrix");
const VTArg * valuesArg = arg->getValues(0);
const size_t numNonZeros = arg->getNumNonZeros(r);
*valuesRes = *valuesRes / numCols;
if (opCode != AggOpCode::MEAN){
for(size_t i = ctr; i < ctr+numNonZeros; i++) {
VTRes val = static_cast<VTRes>((valuesArg[i])) - (*valuesRes);
valuesT[r] = valuesT[r] + val * val;
}

ctr+=numNonZeros;
valuesT[r] += (numCols - numNonZeros) * (*valuesRes)*(*valuesRes);
valuesT[r] /= numCols;
if(opCode == AggOpCode::STDDEV)
*valuesRes = sqrt(valuesT[r]);
else
*valuesRes = valuesT[r];
}
valuesRes += res->getRowSkip();
}
valuesRes = res->getValues();
DataObjectFactory::destroy<DenseMatrix<VTRes>>(tmp);

}
}
};

#endif //SRC_RUNTIME_LOCAL_KERNELS_AGGROW_H

#endif //SRC_RUNTIME_LOCAL_KERNELS_AGGROW_H
6 changes: 3 additions & 3 deletions src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
["double", ["CSRMatrix", "int64_t"]],
["float", ["CSRMatrix", "int64_t"]]
],
"opCodes": ["SUM", "MIN", "MAX", "MEAN"]
"opCodes": ["SUM", "MIN", "MAX", "MEAN", "STDDEV", "VAR"]
}
]
},
Expand Down Expand Up @@ -108,7 +108,7 @@
[["DenseMatrix", "double"], ["CSRMatrix", "int64_t"]],
[["DenseMatrix", "float"], ["CSRMatrix", "int64_t"]]
],
"opCodes": ["SUM", "MIN", "MAX", "MEAN", "STDDEV", "IDXMIN", "IDXMAX"]
"opCodes": ["SUM", "MIN", "MAX", "MEAN", "STDDEV", "VAR", "IDXMIN", "IDXMAX"]
}
]
},
Expand Down Expand Up @@ -207,7 +207,7 @@
[["DenseMatrix", "float"], ["CSRMatrix", "int64_t"]],
[["DenseMatrix", "double"], ["CSRMatrix", "int64_t"]]
],
"opCodes": ["SUM", "MIN", "MAX", "MEAN", "IDXMIN", "IDXMAX"]
"opCodes": ["SUM", "MIN", "MAX", "MEAN", "STDDEV", "VAR", "IDXMIN", "IDXMAX"]
}
]
},
Expand Down
12 changes: 5 additions & 7 deletions test/api/python/matrix_agg.daphne
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,19 @@

m = reshape(seq(1, 12, 1), 4, 3);

# TODO The commented out functions below are not supported yet (see #613).

# Full aggregation.
print(sum(m));
print(mean(m));
# print(var(m));
# print(stddev(m));
print(var(m));
print(stddev(m));
print(aggMin(m));
print(aggMax(m));

# Row-wise aggregation.
print(sum(m, 0));
print(mean(m, 0));
# print(var(m, 0));
# print(stddev(m, 0));
print(var(m, 0));
print(stddev(m, 0));
print(aggMin(m, 0));
print(aggMax(m, 0));
print(idxMin(m, 0));
Expand All @@ -37,7 +35,7 @@ print(idxMax(m, 0));
# Column-wise aggregation.
print(sum(m, 1));
print(mean(m, 1));
# print(var(m, 1));
print(var(m, 1));
print(stddev(m, 1));
print(aggMin(m, 1));
print(aggMax(m, 1));
Expand Down
Loading

0 comments on commit ae1d021

Please sign in to comment.