diff --git a/modules/apps/transfer/keeper/forwarding.go b/modules/apps/transfer/keeper/forwarding.go index b99f3b06713..346f3b56423 100644 --- a/modules/apps/transfer/keeper/forwarding.go +++ b/modules/apps/transfer/keeper/forwarding.go @@ -44,35 +44,6 @@ func (k Keeper) forwardPacket(ctx sdk.Context, data types.FungibleTokenPacketDat return nil } -// ackForwardPacketSuccess writes a successful async acknowledgement for the prevPacket -func (k Keeper) ackForwardPacketSuccess(ctx sdk.Context, prevPacket, forwardedPacket channeltypes.Packet) error { - forwardAck := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) - return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardedPacket, forwardAck) -} - -// ackForwardPacketError reverts the receive packet logic that occurs in the middle chain and writes the async ack for the prevPacket -func (k Keeper) ackForwardPacketError(ctx sdk.Context, prevPacket, forwardedPacket channeltypes.Packet, failedPacketData types.FungibleTokenPacketDataV2) error { - // the forwarded packet has failed, thus the funds have been refunded to the intermediate address. - // we must revert the changes that came from successfully receiving the tokens on our chain - // before propagating the error acknowledgement back to original sender chain - if err := k.revertForwardedPacket(ctx, prevPacket, failedPacketData); err != nil { - return err - } - - forwardAck := channeltypes.NewErrorAcknowledgement(types.ErrForwardedPacketFailed) - return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardedPacket, forwardAck) -} - -// ackForwardPacketTimeout reverts the receive packet logic that occurs in the middle chain and writes a failed async ack for the prevPacket -func (k Keeper) ackForwardPacketTimeout(ctx sdk.Context, prevPacket, forwardedPacket channeltypes.Packet, timeoutPacketData types.FungibleTokenPacketDataV2) error { - if err := k.revertForwardedPacket(ctx, prevPacket, timeoutPacketData); err != nil { - return err - } - - forwardAck := channeltypes.NewErrorAcknowledgement(types.ErrForwardedPacketTimedOut) - return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardedPacket, forwardAck) -} - // acknowledgeForwardedPacket writes the async acknowledgement for packet func (k Keeper) acknowledgeForwardedPacket(ctx sdk.Context, packet, forwardedPacket channeltypes.Packet, ack channeltypes.Acknowledgement) error { capability, ok := k.scopedKeeper.GetCapability(ctx, host.ChannelCapabilityPath(packet.DestinationPort, packet.DestinationChannel)) diff --git a/modules/apps/transfer/keeper/relay.go b/modules/apps/transfer/keeper/relay.go index b3bd61c3cad..5ee59d0a370 100644 --- a/modules/apps/transfer/keeper/relay.go +++ b/modules/apps/transfer/keeper/relay.go @@ -301,7 +301,9 @@ func (k Keeper) OnAcknowledgementPacket(ctx sdk.Context, packet channeltypes.Pac switch ack.Response.(type) { case *channeltypes.Acknowledgement_Result: if isForwarded { - return k.ackForwardPacketSuccess(ctx, prevPacket, packet) + // Write a successful async ack for the prevPacket + forwardAck := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) + return k.acknowledgeForwardedPacket(ctx, prevPacket, packet, forwardAck) } // the acknowledgement succeeded on the receiving chain so nothing @@ -313,7 +315,15 @@ func (k Keeper) OnAcknowledgementPacket(ctx sdk.Context, packet channeltypes.Pac return err } if isForwarded { - return k.ackForwardPacketError(ctx, prevPacket, packet, data) + // the forwarded packet has failed, thus the funds have been refunded to the intermediate address. + // we must revert the changes that came from successfully receiving the tokens on our chain + // before propagating the error acknowledgement back to original sender chain + if err := k.revertForwardedPacket(ctx, prevPacket, data); err != nil { + return err + } + + forwardAck := channeltypes.NewErrorAcknowledgement(types.ErrForwardedPacketFailed) + return k.acknowledgeForwardedPacket(ctx, prevPacket, packet, forwardAck) } return nil @@ -332,7 +342,12 @@ func (k Keeper) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet, dat prevPacket, isForwarded := k.getForwardedPacket(ctx, packet.SourcePort, packet.SourceChannel, packet.Sequence) if isForwarded { - return k.ackForwardPacketTimeout(ctx, prevPacket, packet, data) + if err := k.revertForwardedPacket(ctx, prevPacket, data); err != nil { + return err + } + + forwardAck := channeltypes.NewErrorAcknowledgement(types.ErrForwardedPacketTimedOut) + return k.acknowledgeForwardedPacket(ctx, prevPacket, packet, forwardAck) } return nil