Skip to content

Commit

Permalink
refactor deriv of FFT
Browse files Browse the repository at this point in the history
currently multiplier set to 1 to start testing stuff
  • Loading branch information
mofeing committed Jan 10, 2025
1 parent e6dd644 commit 03a79ae
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -935,10 +935,7 @@ def FftMultiplier : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
ret_constant;
}]>;

def FftIsIRFFT : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, [{
auto cond = op.getFftType() == FftType::IRFFT;
builder.create<ConstantOp>(op.getLoc(), builder.getDenseBoolArrayAttr(ArrayRef<bool>({cond})));
}]>;
def SelectIfIRFFT : StaticSelect<"op.getFftType() == FftType::IRFFT">;

// Derivative rules
def : HLODerivative<"AddOp", (Op $x, $y),
Expand Down Expand Up @@ -1039,14 +1036,19 @@ def : HLODerivative<"Expm1Op", (Op $x), [(CheckedMul (DiffeRet), (Exp $x))]>;
def : HLODerivative<"FftOp", (Op $x),
[
(Mul
(FftMultiplier), // TODO fix this
// multiplier
// (FftMultiplier), // TODO fix this
(HLOConstantFP<"1">),
// inverse fft
(Fft
(Select
(FftIsIRFFT), // if IRFFT
(Real (DiffeRet)), // call real(diff)
(DiffeRet),
(SelectIfIRFFT
(Real (DiffeRet)), // IRFFT is complex to real, so reverse-mode needs to pass a real diff
(DiffeRet)
),
(FftTypeInverse),
(FftLength))))
(FftLength) // TODO revise
)
)
],
(Fft (Shadow $x), (FftType), (FftLength))
>;
Expand Down

0 comments on commit 03a79ae

Please sign in to comment.