Skip to content

Commit

Permalink
lgc: add CreateFDot2 to expose amdgcn_fdot2 to the client
Browse files Browse the repository at this point in the history
  • Loading branch information
xazhangAMD committed Oct 27, 2023
1 parent 1816e93 commit 0594c05
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 0 deletions.
18 changes: 18 additions & 0 deletions lgc/builder/ArithBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,24 @@ Value *BuilderImpl::CreateMsad4(Value *src, Value *ref, Value *accum, const Twin
return result;
}


// =====================================================================================================================
// Create "fdot2" operation, returning an 32-bit float result of the sum of dot2 of 2 half vec2 and a float scalar.
//
// @param a : Vector of 2xhalf A.
// @param b : Vector of 2xhalf B.
// @param scalar : A float scalar.
// @param clamp : Whether the dot2 result should be clamped.
Value *BuilderImpl::CreateFDot2(Value *a, Value *b, Value *scalar, Value *clamp, const Twine &instName) {
assert(a->getType()->getScalarType()->isHalfTy() && b->getType()->getScalarType()->isHalfTy());
assert(scalar->getType()->isFloatTy());
assert(clamp->getType()->isIntegerTy() && clamp->getType()->getIntegerBitWidth() == 1);

Value *result = CreateIntrinsic(scalar->getType(), Intrinsic::amdgcn_fdot2, {a, b, scalar, clamp});
result->setName(instName);
return result;
}

// =====================================================================================================================
// Create "fmix" operation, returning ( 1 - A ) * X + A * Y. Result would be FP scalar or vector value.
// Returns scalar, if and only if "pX", "pY" and "pA" are all scalars.
Expand Down
14 changes: 14 additions & 0 deletions lgc/builder/BuilderRecorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ StringRef BuilderRecorder::getCallName(BuilderOpcode opcode) {
return "fmix";
case BuilderOpcode::Msad4:
return "msad4";
case BuilderOpcode::FDot2:
return "fdot2";
case BuilderOpcode::LoadBufferDesc:
return "load.buffer.desc";
case BuilderOpcode::GetDescStride:
Expand Down Expand Up @@ -1039,6 +1041,17 @@ Value *Builder::CreateMsad4(Value *src, Value *ref, Value *accum, const Twine &i
return record(BuilderOpcode::Msad4, src->getType(), {src, ref, accum}, instName);
}

// =====================================================================================================================
// Create "fdot2" operation, returning an 32-bit float result of the sum of dot2 of 2 half vec2 and a float scalar.
//
// @param a : Vector of 2xhalf A.
// @param b : Vector of 2xhalf B.
// @param scalar : A float scalar.
// @param clamp : Whether the dot2 result should be clamped.
Value *Builder::CreateFDot2(Value *a, Value *b, Value *scalar, Value *clamp, const Twine &instName) {
return record(BuilderOpcode::FDot2, scalar->getType(), {a, b, scalar, clamp}, instName);
}

// =====================================================================================================================
// Create a load of a buffer descriptor.
//
Expand Down Expand Up @@ -2017,6 +2030,7 @@ Instruction *Builder::record(BuilderOpcode opcode, Type *resultTy, ArrayRef<Valu
case BuilderOpcode::FindSMsb:
case BuilderOpcode::CountLeadingSignBits:
case BuilderOpcode::Msad4:
case BuilderOpcode::FDot2:
case BuilderOpcode::Fma:
case BuilderOpcode::FpTruncWithRounding:
case BuilderOpcode::Fract:
Expand Down
1 change: 1 addition & 0 deletions lgc/builder/BuilderRecorder.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ enum BuilderOpcode : unsigned {
CountLeadingSignBits,
FMix,
Msad4,
FDot2,

// Descriptor
LoadBufferDesc,
Expand Down
4 changes: 4 additions & 0 deletions lgc/builder/BuilderReplayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,10 @@ Value *BuilderReplayer::processCall(unsigned opcode, CallInst *call) {
return m_builder->CreateMsad4(args[0], args[1], args[2]);
}

case BuilderOpcode::FDot2: {
return m_builder->CreateFDot2(args[0], args[1], args[2], args[3]);
}

// Replayer implementations of DescBuilder methods
case BuilderOpcode::LoadBufferDesc: {
return m_builder->CreateLoadBufferDesc(cast<ConstantInt>(args[0])->getZExtValue(), // descSet
Expand Down
4 changes: 4 additions & 0 deletions lgc/include/lgc/builder/BuilderImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ class BuilderImpl : public BuilderDefs {
// Create "Masked Sum of Absolute Differences" operation.
llvm::Value *CreateMsad4(llvm::Value *src, llvm::Value *ref, llvm::Value *accum, const llvm::Twine &instName = "");

// Create fdot2_f16 + f32 operation.
llvm::Value *CreateFDot2(llvm::Value *a, llvm::Value *b, llvm::Value *scalar, llvm::Value *clamp,
const llvm::Twine &instName = "");

// Create "fmix" operation.
llvm::Value *createFMix(llvm::Value *x, llvm::Value *y, llvm::Value *a, const llvm::Twine &instName = "");

Expand Down
9 changes: 9 additions & 0 deletions lgc/interface/lgc/Builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,15 @@ class Builder : public BuilderDefs {
// @param accum : A 32-bit unsigned integer, providing an existing accumulation.
llvm::Value *CreateMsad4(llvm::Value *src, llvm::Value *ref, llvm::Value *accum, const llvm::Twine &instName = "");

// Create "fdot2" operation, returning an 32-bit float result of the sum of dot2 of 2 half vec2 and a float scalar.
//
// @param a : Vector of 2xhalf A.
// @param b : Vector of 2xhalf B.
// @param scalar : A float scalar.
// @param clamp : Whether the dot2 result should be clamped.
llvm::Value *CreateFDot2(llvm::Value *a, llvm::Value *b, llvm::Value *scalar, llvm::Value *clamp,
const llvm::Twine &instName = "");

// Create "fmix" operation, returning ( 1 - A ) * X + A * Y. Result would be FP scalar or vector value.
// Returns scalar, if and only if "pX", "pY" and "pA" are all scalars.
// Returns vector, if "pX" and "pY" are vector but "pA" is a scalar, under such condition, "pA" will be splatted.
Expand Down

0 comments on commit 0594c05

Please sign in to comment.