Skip to content

Commit

Permalink
Index trampoline payments by hash and secret (#1770)
Browse files Browse the repository at this point in the history
We need to group incoming HTLCs together by payment_hash and payment_secret,
otherwise we will reject valid payments that are split into multiple distinct
trampoline parts (same payment_hash but different payment_secret).

Fixes #1723
  • Loading branch information
t-bast committed May 4, 2021
1 parent 9e4042f commit 90fbcd3
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC}
import fr.acinq.eclair.db.PendingRelayDb
import fr.acinq.eclair.payment.IncomingPacket.NodeRelayPacket
import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.payment.OutgoingPacket.Upstream
import fr.acinq.eclair.payment._
Expand Down Expand Up @@ -75,13 +76,29 @@ object NodeRelay {
}
}

def apply(nodeParams: NodeParams, parent: akka.actor.typed.ActorRef[NodeRelayer.Command], register: ActorRef, relayId: UUID, paymentHash: ByteVector32, outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] =
def apply(nodeParams: NodeParams,
parent: akka.actor.typed.ActorRef[NodeRelayer.Command],
register: ActorRef,
relayId: UUID,
nodeRelayPacket: NodeRelayPacket,
paymentSecret: ByteVector32,
outgoingPaymentFactory: OutgoingPaymentFactory): Behavior[Command] =
Behaviors.setup { context =>
val paymentHash = nodeRelayPacket.add.paymentHash
val totalAmountIn = nodeRelayPacket.outerPayload.totalAmount
Behaviors.withMdc(Logs.mdc(
category_opt = Some(Logs.LogCategory.PAYMENT),
parentPaymentId_opt = Some(relayId), // for a node relay, we use the same identifier for the whole relay itself, and the outgoing payment
paymentHash_opt = Some(paymentHash))) {
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, context, outgoingPaymentFactory)()
context.log.info("relaying payment relayId={}", relayId)
val mppFsmAdapters = {
context.messageAdapter[MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]](WrappedMultiPartExtraPaymentReceived)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentFailed](WrappedMultiPartPaymentFailed)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentSucceeded](WrappedMultiPartPaymentSucceeded)
}.toClassic
val incomingPaymentHandler = context.actorOf(MultiPartPaymentFSM.props(nodeParams, paymentHash, totalAmountIn, mppFsmAdapters))
new NodeRelay(nodeParams, parent, register, relayId, paymentHash, paymentSecret, context, outgoingPaymentFactory)
.receiving(Queue.empty, nodeRelayPacket.innerPayload, nodeRelayPacket.nextPacket, incomingPaymentHandler)
}
}

Expand Down Expand Up @@ -144,66 +161,37 @@ class NodeRelay private(nodeParams: NodeParams,
register: ActorRef,
relayId: UUID,
paymentHash: ByteVector32,
paymentSecret: ByteVector32,
context: ActorContext[NodeRelay.Command],
outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory) {

import NodeRelay._

private val mppFsmAdapters = {
context.messageAdapter[MultiPartPaymentFSM.ExtraPaymentReceived[HtlcPart]](WrappedMultiPartExtraPaymentReceived)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentFailed](WrappedMultiPartPaymentFailed)
context.messageAdapter[MultiPartPaymentFSM.MultiPartPaymentSucceeded](WrappedMultiPartPaymentSucceeded)
}.toClassic
private val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic

def apply(): Behavior[Command] =
Behaviors.receiveMessagePartial {
// We make sure we receive all payment parts before forwarding to the next trampoline node.
case Relay(IncomingPacket.NodeRelayPacket(add, outer, inner, next)) => outer.paymentSecret match {
case None =>
// TODO: @pm: maybe those checks should be done later in the flow (by the mpp FSM?)
context.log.warn("rejecting htlcId={}: missing payment secret", add.id)
rejectHtlc(add.id, add.channelId, add.amountMsat)
stopping()
case Some(secret) =>
import akka.actor.typed.scaladsl.adapter._
context.log.info("relaying payment relayId={}", relayId)
val mppFsm = context.actorOf(MultiPartPaymentFSM.props(nodeParams, add.paymentHash, outer.totalAmount, mppFsmAdapters))
context.log.debug("forwarding incoming htlc to the payment FSM")
mppFsm ! MultiPartPaymentFSM.HtlcPart(outer.totalAmount, add)
receiving(Queue(add), secret, inner, next, mppFsm)
}
}

/**
* We start by aggregating an incoming HTLC set. Once we received the whole set, we will compute a route to the next
* trampoline node and forward the payment.
*
* @param htlcs received incoming HTLCs for this set.
* @param secret all incoming HTLCs in this set must have the same secret to protect against probing / fee theft.
* @param nextPayload relay instructions (should be identical across HTLCs in this set).
* @param nextPacket trampoline onion to relay to the next trampoline node.
* @param handler actor handling the aggregation of the incoming HTLC set.
*/
private def receiving(htlcs: Queue[UpdateAddHtlc], secret: ByteVector32, nextPayload: Onion.NodeRelayPayload, nextPacket: OnionRoutingPacket, handler: ActorRef): Behavior[Command] =
private def receiving(htlcs: Queue[UpdateAddHtlc], nextPayload: Onion.NodeRelayPayload, nextPacket: OnionRoutingPacket, handler: ActorRef): Behavior[Command] =
Behaviors.receiveMessagePartial {
case Relay(IncomingPacket.NodeRelayPacket(add, outer, _, _)) => outer.paymentSecret match {
// TODO: @pm: maybe those checks should be done by the mpp FSM?
case None =>
context.log.warn("rejecting htlcId={}: missing payment secret", add.id)
context.log.warn("rejecting htlc #{} from channel {}: missing payment secret", add.id, add.channelId)
rejectHtlc(add.id, add.channelId, add.amountMsat)
Behaviors.same
case Some(incomingSecret) if incomingSecret != secret =>
context.log.warn("rejecting htlcId={}: payment secret doesn't match other HTLCs in the set", add.id)
case Some(incomingSecret) if incomingSecret != paymentSecret =>
context.log.warn("rejecting htlc #{} from channel {}: payment secret doesn't match other HTLCs in the set", add.id, add.channelId)
rejectHtlc(add.id, add.channelId, add.amountMsat)
Behaviors.same
case Some(incomingSecret) if incomingSecret == secret =>
context.log.debug("forwarding incoming htlc to the payment FSM")
case Some(incomingSecret) if incomingSecret == paymentSecret =>
context.log.debug("forwarding incoming htlc #{} from channel {} to the payment FSM", add.id, add.channelId)
handler ! MultiPartPaymentFSM.HtlcPart(outer.totalAmount, add)
receiving(htlcs :+ add, secret, nextPayload, nextPacket, handler)
receiving(htlcs :+ add, nextPayload, nextPacket, handler)
}
case WrappedMultiPartPaymentFailed(MultiPartPaymentFSM.MultiPartPaymentFailed(_, failure, parts)) =>
context.log.warn("could not complete incoming multi-part payment (parts={} paidAmount={} failure={})", parts.size, parts.map(_.amount).sum, failure)
Expand Down Expand Up @@ -267,14 +255,20 @@ class NodeRelay private(nodeParams: NodeParams,
* Once the downstream payment is settled (fulfilled or failed), we reject new upstream payments while we wait for our parent to stop us.
*/
private def stopping(): Behavior[Command] = {
parent ! NodeRelayer.RelayComplete(context.self, paymentHash)
parent ! NodeRelayer.RelayComplete(context.self, paymentHash, paymentSecret)
Behaviors.receiveMessagePartial {
rejectExtraHtlcPartialFunction orElse {
case Stop => Behaviors.stopped
}
}
}

private val payFsmAdapters = {
context.messageAdapter[PreimageReceived](WrappedPreimageReceived)
context.messageAdapter[PaymentSent](WrappedPaymentSent)
context.messageAdapter[PaymentFailed](WrappedPaymentFailed)
}.toClassic

private def relay(upstream: Upstream.Trampoline, payloadOut: Onion.NodeRelayPayload, packetOut: OnionRoutingPacket): ActorRef = {
val paymentCfg = SendPaymentConfig(relayId, relayId, None, paymentHash, payloadOut.amountToForward, payloadOut.outgoingNodeId, upstream, None, storeInDb = false, publishEvent = false, Nil)
val routeParams = computeRouteParams(nodeParams, upstream.amountIn, upstream.expiryIn, payloadOut.amountToForward, payloadOut.outgoingCltv)
Expand Down Expand Up @@ -322,7 +316,7 @@ class NodeRelay private(nodeParams: NodeParams,
}

private def rejectExtraHtlc(add: UpdateAddHtlc): Unit = {
context.log.warn("rejecting extra htlcId={}", add.id)
context.log.warn("rejecting extra htlc #{} from channel {}", add.id, add.channelId)
rejectHtlc(add.id, add.channelId, add.amountMsat)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ package fr.acinq.eclair.payment.relay
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.{ActorRef, Behavior}
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.channel.CMD_FAIL_HTLC
import fr.acinq.eclair.db.PendingRelayDb
import fr.acinq.eclair.payment._
import fr.acinq.eclair.wire.protocol.IncorrectOrUnknownPaymentDetails
import fr.acinq.eclair.{Logs, NodeParams}

import java.util.UUID
Expand All @@ -29,16 +32,16 @@ import java.util.UUID
*/

/**
* The [[NodeRelayer]] relays an upstream payment to a downstream remote node (which is not necessarily a direct peer). It
* doesn't do the job itself, instead it dispatches each individual payment (which can be multi-in, multi-out) to a child
* actor of type [[NodeRelay]].
* The [[NodeRelayer]] relays an upstream payment to a downstream remote node (which is not necessarily a direct peer).
* It doesn't do the job itself, instead it dispatches each individual payment (which can be multi-in, multi-out) to a
* child actor of type [[NodeRelay]].
*/
object NodeRelayer {

// @formatter:off
sealed trait Command
case class Relay(nodeRelayPacket: IncomingPacket.NodeRelayPacket) extends Command
case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32) extends Command
case class RelayComplete(childHandler: ActorRef[NodeRelay.Command], paymentHash: ByteVector32, paymentSecret: ByteVector32) extends Command
private[relay] case class GetPendingPayments(replyTo: akka.actor.ActorRef) extends Command
// @formatter:on

Expand All @@ -48,34 +51,47 @@ object NodeRelayer {
case _: GetPendingPayments => Logs.mdc()
}

case class PaymentKey(paymentHash: ByteVector32, paymentSecret: ByteVector32)

/**
* @param children a map of current in-process payments, indexed by payment hash and purposefully *not* by payment id,
* because that is how we aggregate payment parts (when the incoming payment uses MPP).
* @param children a map of pending payments. We must index by both payment hash and payment secret because we may
* need to independently relay multiple parts of the same payment using distinct payment secrets.
* NB: the payment secret used here is different from the invoice's payment secret and ensures we can
* group together HTLCs that the previous trampoline node sent in the same MPP.
*/
def apply(nodeParams: NodeParams, router: akka.actor.ActorRef, register: akka.actor.ActorRef, children: Map[ByteVector32, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
def apply(nodeParams: NodeParams, register: akka.actor.ActorRef, outgoingPaymentFactory: NodeRelay.OutgoingPaymentFactory, children: Map[PaymentKey, ActorRef[NodeRelay.Command]] = Map.empty): Behavior[Command] =
Behaviors.setup { context =>
Behaviors.withMdc(Logs.mdc(category_opt = Some(Logs.LogCategory.PAYMENT)), mdc) {
Behaviors.receiveMessage {
case Relay(nodeRelayPacket) =>
import nodeRelayPacket.add.paymentHash
children.get(paymentHash) match {
case Some(handler) =>
context.log.debug("forwarding incoming htlc to existing handler")
handler ! NodeRelay.Relay(nodeRelayPacket)
Behaviors.same
val htlcIn = nodeRelayPacket.add
nodeRelayPacket.outerPayload.paymentSecret match {
case Some(paymentSecret) =>
val childKey = PaymentKey(htlcIn.paymentHash, paymentSecret)
children.get(childKey) match {
case Some(handler) =>
context.log.debug("forwarding incoming htlc #{} from channel {} to existing handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket)
Behaviors.same
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, nodeRelayPacket, childKey.paymentSecret, outgoingPaymentFactory), relayId.toString)
context.log.debug("forwarding incoming htlc #{} from channel {} to new handler", htlcIn.id, htlcIn.channelId)
handler ! NodeRelay.Relay(nodeRelayPacket)
apply(nodeParams, register, outgoingPaymentFactory, children + (childKey -> handler))
}
case None =>
val relayId = UUID.randomUUID()
context.log.debug(s"spawning a new handler with relayId=$relayId")
val outgoingPaymentFactory = NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register)
val handler = context.spawn(NodeRelay.apply(nodeParams, context.self, register, relayId, paymentHash, outgoingPaymentFactory), relayId.toString)
context.log.debug("forwarding incoming htlc to new handler")
handler ! NodeRelay.Relay(nodeRelayPacket)
apply(nodeParams, router, register, children + (paymentHash -> handler))
context.log.warn("rejecting htlc #{} from channel {}: missing payment secret", htlcIn.id, htlcIn.channelId)
val failureMessage = IncorrectOrUnknownPaymentDetails(htlcIn.amountMsat, nodeParams.currentBlockHeight)
val cmd = CMD_FAIL_HTLC(htlcIn.id, Right(failureMessage), commit = true)
PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, htlcIn.channelId, cmd)
Behaviors.same
}
case RelayComplete(childHandler, paymentHash) =>
case RelayComplete(childHandler, paymentHash, paymentSecret) =>
// we do a back-and-forth between parent and child before stopping the child to prevent a race condition
childHandler ! NodeRelay.Stop
apply(nodeParams, router, register, children - paymentHash)
apply(nodeParams, register, outgoingPaymentFactory, children - PaymentKey(paymentHash, paymentSecret))
case GetPendingPayments(replyTo) =>
replyTo ! children
Behaviors.same
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym

private val postRestartCleaner = context.actorOf(PostRestartHtlcCleaner.props(nodeParams, register, initialized), "post-restart-htlc-cleaner")
private val channelRelayer = context.spawn(Behaviors.supervise(ChannelRelayer(nodeParams, register)).onFailure(SupervisorStrategy.resume), "channel-relayer")
private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, router, register)).onFailure(SupervisorStrategy.resume), name = "node-relayer")
private val nodeRelayer = context.spawn(Behaviors.supervise(NodeRelayer(nodeParams, register, NodeRelay.SimpleOutgoingPaymentFactory(nodeParams, router, register))).onFailure(SupervisorStrategy.resume), name = "node-relayer")

def receive: Receive = {
case RelayForward(add) =>
Expand Down
Loading

0 comments on commit 90fbcd3

Please sign in to comment.