Skip to content

Commit

Permalink
[DAPHNE-daphne-eu#768] isNan: elementwise check for NaN elements (dap…
Browse files Browse the repository at this point in the history
…hne-eu#779)

- This commit introduces the isNan() function to DAPHNE, enabling element-wise NaN checks on matrices and scalars 
- Created an IsNanOp operation in DaphneIR as a subclass of EwUnaryOp (elementwise unary operation) (see src/ir/daphneir/DaphneOps.td).
- Created a DaphneDSL built-in function isNan() and documented it (see src/parser/daphnedsl/DaphneDSLBuiltins.cpp and doc/DaphneDSL/Builtins.md).
- Created a DaphneLib function isNan() for matrices and scalars and documented it (see src/api/python/daphne/operator/nodes/matrix.py, src/api/python/daphne/operator/nodes/scalar.py, and doc/DaphneLib/APIRef.md).
- Created a runtime kernel for isNan() by extending the existing ewUnaryMat and ewUnarySca kernel with a new op code (see src/runtime/local/kernels/EwUnaryMat.h).
- Added unit test cases for the kernel and script-level test cases for DaphneDSL and DaphneLib.
- Remove the remark on missing isNan() in the guidelines on porting from Numpy to DaphneLib (see docs/DaphneLib/Numpy2DaphneLib.py).
- Closes daphne-eu#768.
  • Loading branch information
StoeckOverflow authored Jul 11, 2024
1 parent 26a3d1e commit 29fd7b9
Show file tree
Hide file tree
Showing 19 changed files with 149 additions and 9 deletions.
6 changes: 6 additions & 0 deletions doc/DaphneDSL/Builtins.md
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ The following built-in functions all follow the same scheme:
| **`floor`** | round down |
| **`ceil`** | round up |

### Comparison

| function | meaning |
| ----- | ----- |
| **`isNan`** | `1` if argument is NaN, `0` otherwise |

## Elementwise binary

DaphneDSL supports various elementwise binary operations.
Expand Down
2 changes: 2 additions & 0 deletions doc/DaphneLib/APIRef.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ In the following, we describe only the latter.
- **`sinh`**`()`
- **`cosh`**`()`
- **`tanh`**`()`
- **`isNan`**`()`

**Elementwise binary:**

Expand Down Expand Up @@ -222,6 +223,7 @@ In the following, we describe only the latter.
- **`sinh`**`()`
- **`cosh`**`()`
- **`tanh`**`()`
- **`isNan`**`()`

**Elementwise binary:**

Expand Down
2 changes: 0 additions & 2 deletions doc/DaphneLib/Numpy2DaphneLib.md
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,6 @@ Porting for other Numpy versions may be possible by following similar lines of t

- `numpy.`**`isnan`**`(x, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature]) = <ufunc 'isnan'>`

*Note: `isNan()` is not supported in DAPHNE yet (see #768).*

*Parameters*

- `x`: supported
Expand Down
6 changes: 6 additions & 0 deletions src/api/python/daphne/operator/nodes/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ def sqrt(self) -> 'OperationNode':
"""
return Matrix(self.daphne_context,'sqrt', [self])

def isNan(self) -> 'OperationNode':
"""Elementwise check for NaN values in this matrix (resulting in 1 if the element is NaN, 0 otherwise).
:return: `Matrix` A node representing the isNan operation.
"""
return Matrix(self.daphne_context, 'isNan', [self])

def round(self) -> 'OperationNode':
return Matrix(self.daphne_context, 'round', [self])

Expand Down
3 changes: 3 additions & 0 deletions src/api/python/daphne/operator/nodes/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ def acos(self) -> 'Scalar':
def atan(self) -> 'Scalar':
return Scalar(self.daphne_context, 'atan', [self])

def isNan(self) -> 'Scalar':
return Scalar(self.daphne_context, 'isNan', [self])

def pow(self, other) -> 'Scalar':
return Scalar(self.daphne_context, 'pow', [self, other])

Expand Down
6 changes: 6 additions & 0 deletions src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ def Daphne_EwAsinOp : Daphne_EwUnaryOp<"ewAsin", NumScalar, [ValueTypeFromArgsFP
def Daphne_EwAcosOp : Daphne_EwUnaryOp<"ewAcos", NumScalar, [ValueTypeFromArgsFP]>;
def Daphne_EwAtanOp : Daphne_EwUnaryOp<"ewAtan", NumScalar, [ValueTypeFromArgsFP]>;

// ----------------------------------------------------------------------------
// Comparison
// ----------------------------------------------------------------------------

def Daphne_EwIsNanOp : Daphne_EwUnaryOp<"ewIsnan", NumScalar, [ValueTypeFromFirstArg]>;

// ****************************************************************************
// Elementwise binary
// ****************************************************************************
Expand Down
7 changes: 7 additions & 0 deletions src/parser/daphnedsl/DaphneDSLBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,13 @@ antlrcpp::Any DaphneDSLBuiltins::build(mlir::Location loc, const std::string & f
if(func == "atan")
return createUnaryOp<EwAtanOp>(loc, func, args);

// --------------------------------------------------------------------
// Comparison
// --------------------------------------------------------------------

if (func == "isNan")
return createUnaryOp<EwIsNanOp>(loc, func, args);

// ********************************************************************
// Elementwise binary
// ********************************************************************
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/local/kernels/EwUnarySca.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ EwUnaryScaFuncPtr<VTRes, VTArg> getEwUnaryScaFuncPtr(UnaryOpCode opCode) {
MAKE_CASE(UnaryOpCode::FLOOR)
MAKE_CASE(UnaryOpCode::CEIL)
MAKE_CASE(UnaryOpCode::ROUND)
// Comparison.
MAKE_CASE(UnaryOpCode::ISNAN)
#undef MAKE_CASE
default:
throw std::runtime_error("unknown UnaryOpCode");
Expand Down Expand Up @@ -167,6 +169,8 @@ MAKE_EW_UNARY_SCA(UnaryOpCode::TANH, tanh(arg));
MAKE_EW_UNARY_SCA(UnaryOpCode::FLOOR, floor(arg));
MAKE_EW_UNARY_SCA(UnaryOpCode::CEIL, std::ceil(arg));
MAKE_EW_UNARY_SCA(UnaryOpCode::ROUND, round(arg));
// Comparison.
MAKE_EW_UNARY_SCA(UnaryOpCode::ISNAN, std::isnan(arg));

#undef MAKE_EW_UNARY_SCA_CLOSED_DOMAIN_ERROR
#undef MAKE_EW_UNARY_SCA_OPEN_DOMAIN_ERROR
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/local/kernels/UnaryOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ enum class UnaryOpCode {
FLOOR,
CEIL,
ROUND,
// Comparison.
ISNAN
};

#endif //SRC_RUNTIME_LOCAL_KERNELS_UNARYOPCODE_H
4 changes: 2 additions & 2 deletions src/runtime/local/kernels/kernels.json
Original file line number Diff line number Diff line change
Expand Up @@ -2909,7 +2909,7 @@
[["DenseMatrix", "int64_t"],["DenseMatrix", "int64_t"]]
],
"opCodes": ["MINUS", "SIGN", "SQRT", "EXP", "ABS", "FLOOR", "CEIL", "ROUND", "LN",
"SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "SINH", "COSH", "TANH"]
"SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "SINH", "COSH", "TANH", "ISNAN"]
},
{
"kernelTemplate": {
Expand Down Expand Up @@ -2943,7 +2943,7 @@
["int64_t", "int64_t"]
],
"opCodes": ["MINUS", "SIGN", "SQRT", "EXP", "ABS", "FLOOR", "CEIL", "ROUND", "LN",
"SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "SINH", "COSH", "TANH"]
"SIN", "COS", "TAN", "ASIN", "ACOS", "ATAN", "SINH", "COSH", "TANH", "ISNAN"]
},
{
"kernelTemplate": {
Expand Down
1 change: 1 addition & 0 deletions test/api/cli/operations/OperationsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ MAKE_TEST_CASE("ctable", 1)
MAKE_TEST_CASE("gemv", 1)
MAKE_TEST_CASE("idxMax", 1)
MAKE_TEST_CASE("idxMin", 1)
MAKE_TEST_CASE("isNan", 1)
MAKE_TEST_CASE("mean", 1)
MAKE_TEST_CASE("operator_at", 2)
MAKE_TEST_CASE("operator_eq", 2)
Expand Down
4 changes: 4 additions & 0 deletions test/api/cli/operations/isNan_1.daphne
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
print(isNan(3));
print(isNan(nan));
X = t([1, nan, 0, inf, 99.9, nan]);
print(isNan(X));
4 changes: 4 additions & 0 deletions test/api/cli/operations/isNan_1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
0
1
DenseMatrix(1x6, double)
0 1 0 0 0 1
5 changes: 4 additions & 1 deletion test/api/python/matrix_ewunary.daphne
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,7 @@ print(acos(m));
print(atan(m));
print(sinh(m));
print(cosh(m));
print(tanh(m));
print(tanh(m));

print(isNan(m));
print(isNan([nan]));
6 changes: 5 additions & 1 deletion test/api/python/matrix_ewunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from daphne.context.daphne_context import DaphneContext
import math

dc = DaphneContext()

Expand Down Expand Up @@ -42,4 +43,7 @@
m.atan().print().compute()
m.sinh().print().compute()
m.cosh().print().compute()
m.tanh().print().compute()
m.tanh().print().compute()

m.isNan().print().compute()
dc.fill(math.nan, 1, 1).isNan().print().compute()
5 changes: 4 additions & 1 deletion test/api/python/scalar_ewunary.daphne
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,7 @@ print(acos(s));
print(atan(s));
print(sinh(s));
print(cosh(s));
print(tanh(s));
print(tanh(s));

print(isNan(s));
print(isNan(nan));
6 changes: 5 additions & 1 deletion test/api/python/scalar_ewunary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from daphne.context.daphne_context import DaphneContext
import math

dc = DaphneContext()

Expand Down Expand Up @@ -45,4 +46,7 @@
s.sum().atan().print().compute()
s.sum().sinh().print().compute()
s.sum().cosh().print().compute()
s.sum().tanh().print().compute()
s.sum().tanh().print().compute()

s.sum().isNan().print().compute()
dc.fill(math.nan, 1, 1).cbind(dc.fill(1.0, 1, 1)).sum().isNan().print().compute()
61 changes: 60 additions & 1 deletion test/runtime/local/kernels/EwUnaryMatTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,65 @@ TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("round, floating-point-specific"), TAG_KERN
DataObjectFactory::destroy(arg, exp);
}

// ****************************************************************************
// Comparison
// ****************************************************************************

TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("isNan"), TAG_KERNELS, (DATA_TYPES), (int32_t)) {
using DT = TestType;

auto arg = genGivenVals<DT>(4, {
1,
0,
99,
-99,
});

auto exp = genGivenVals<DT>(4, {
0,
0,
0,
0
});

checkEwUnaryMat(UnaryOpCode::ISNAN, arg, exp);

DataObjectFactory::destroy(arg, exp);
}

TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("isNan, floating-point specific"), TAG_KERNELS, (DATA_TYPES), (double)) {
using DT = TestType;
using VT = typename DT::VT;

auto arg = genGivenVals<DT>(9, {
1,
std::numeric_limits<VT>::quiet_NaN(),
0,
std::numeric_limits<VT>::infinity(),
-std::numeric_limits<VT>::infinity(),
99.9,
-99.9,
std::numeric_limits<VT>::quiet_NaN(),
std::numeric_limits<VT>::denorm_min()
});

auto exp = genGivenVals<DT>(9, {
0,
1,
0,
0,
0,
0,
0,
1,
0
});

checkEwUnaryMat(UnaryOpCode::ISNAN, arg, exp);

DataObjectFactory::destroy(arg, exp);
}

// ****************************************************************************
// Invalid op-code
// ****************************************************************************
Expand All @@ -560,4 +619,4 @@ TEMPLATE_PRODUCT_TEST_CASE(TEST_NAME("some invalid op-code"), TAG_KERNELS, (DATA
CHECK_THROWS(ewUnaryMat<DT, DT>(static_cast<UnaryOpCode>(999), exp, arg, nullptr));

DataObjectFactory::destroy(arg);
}
}
24 changes: 24 additions & 0 deletions test/runtime/local/kernels/EwUnaryScaTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,30 @@ TEMPLATE_TEST_CASE(TEST_NAME("round, floating-point-specific"), TAG_KERNELS, FP_
checkEwUnarySca<UnaryOpCode::ROUND, VT>(-0.5, -1);
}

// ****************************************************************************
// Comparison
// ****************************************************************************

TEMPLATE_TEST_CASE(TEST_NAME("isNan"), TAG_KERNELS, SI_VALUE_TYPES) {
using VT = TestType;
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(1, 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(99, 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(-99, 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(0, 0);
}

TEMPLATE_TEST_CASE(TEST_NAME("isNan, floating-point-specific"), TAG_KERNELS, FP_VALUE_TYPES) {
using VT = TestType;
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(1, 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(std::numeric_limits<VT>::quiet_NaN(), 1);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(std::numeric_limits<VT>::infinity(), 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(-std::numeric_limits<VT>::infinity(), 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(99.9, 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(-99.9, 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(0, 0);
checkEwUnarySca<UnaryOpCode::ISNAN, VT>(std::numeric_limits<VT>::denorm_min(), 0);
}

// ****************************************************************************
// Invalid op-code
// ****************************************************************************
Expand Down

0 comments on commit 29fd7b9

Please sign in to comment.