Skip to content

Commit

Permalink
[daphneir] fix CastOp::inferTypes bug (daphne-eu#861)
Browse files Browse the repository at this point in the history
This patch fixes a bug in the CastOp::inferTypes implementation which,
for a matrix type of unknown value type, resets the matrix type by
calling our custom builder MatrixType::get(ctx, valueType).

This means that for example, casts from
daphne.Matrix<?x?x!daphne.Unknown> to
daphne.Matrix<?x?x!daphne.Unknown:rep[sparse]>, the matrix representation
would be and reset.

Closes daphne-eu#861
  • Loading branch information
philipportner committed Oct 15, 2024
1 parent 9858fe8 commit bb9c158
Showing 1 changed file with 33 additions and 37 deletions.
70 changes: 33 additions & 37 deletions src/ir/daphneir/DaphneInferTypesOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,45 +63,41 @@ Type getFrameColumnTypeByLabel(Operation *op, daphne::FrameType ft, Value labelV
// ****************************************************************************

std::vector<Type> daphne::CastOp::inferTypes() {
Type argTy = getArg().getType();
Type resTy = getRes().getType();
auto mtArg = argTy.dyn_cast<daphne::MatrixType>();
auto ftArg = argTy.dyn_cast<daphne::FrameType>();
auto mtRes = resTy.dyn_cast<daphne::MatrixType>();

// If the result type is a matrix with so far unknown value type, then we
// infer the value type from the argument.
if (mtRes && llvm::isa<daphne::UnknownType>(mtRes.getElementType())) {
Type resVt;

if (mtArg)
// The argument is a matrix; we use its value type for the result.
resVt = mtArg.getElementType();
else if (ftArg) {
// The argument is a frame, we use the value type of its only
// column for the results; if the argument has more than one
// column, we throw an exception.
std::vector<Type> ctsArg = ftArg.getColumnTypes();
if (ctsArg.size() == 1)
resVt = ctsArg[0];
else
// TODO We could use the most general of the column types.
throw ErrorHandler::compilerError(getLoc(), "InferTypesOpInterface (daphne::CastOp::inferTypes)",
"currently CastOp cannot infer the value type of its "
"output matrix, if the input is a multi-column frame");
} else
// The argument is a scalar, we use its type for the value type
// of the result.
// TODO double-check if it is really a scalar
resVt = argTy;

return {daphne::MatrixType::get(getContext(), resVt)};
}

// Otherwise, we leave the result type as it is. We do not reset it to
Type argumentType = getArg().getType();
Type resultType = getRes().getType();
auto matrixArgument = argumentType.dyn_cast<daphne::MatrixType>();
auto frameArgument = argumentType.dyn_cast<daphne::FrameType>();
auto matrixResult = resultType.dyn_cast<daphne::MatrixType>();

// If the result type is not a matrix or a matrix with so far unknown value type, then we
// we leave the result type as it is. We do not reset it to
// unknown, since this could drop information that was explicitly
// encoded in the CastOp.
return {resTy};
if (!matrixResult || !llvm::isa<daphne::UnknownType>(matrixResult.getElementType()))
return {resultType};

// The argument is a matrix, result is a matrix; we use its value type for the result.
if (matrixArgument)
return {matrixResult.withElementType(matrixArgument.getElementType())};

// The argument is a frame, result is a matrix; we use the value type of its only
// column for the results; if the argument has more than one
// column, we throw an exception.
if (frameArgument) {
auto argumentColumnTypes = frameArgument.getColumnTypes();
if (argumentColumnTypes.size() != 1) {
// TODO We could use the most general of the column types.
throw ErrorHandler::compilerError(getLoc(), "InferTypesOpInterface (daphne::CastOp::inferTypes)",
"currently CastOp cannot infer the value type of its "
"output matrix, if the input is a multi-column frame");
}
return {matrixResult.withElementType(argumentColumnTypes[0])};
}

// The argument is a scalar, result is a matrix; we use its type for the value type
// of the result.
// TODO double-check if it is really a scalar
return {daphne::MatrixType::get(getContext(), argumentType)};
}

std::vector<Type> daphne::ExtractColOp::inferTypes() {
Expand Down

0 comments on commit bb9c158

Please sign in to comment.