Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve CustomCommitmentsPlugin methods #1613

Merged
merged 5 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/PluginParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package fr.acinq.eclair

import akka.event.LoggingAdapter
import fr.acinq.bitcoin.ByteVector32
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.channel.Origin
Expand Down Expand Up @@ -51,12 +52,12 @@ trait CustomCommitmentsPlugin extends PluginParams {
* expire. If your plugin defines non-standard HTLCs, and they need to be automatically failed, they should be
* returned by this method.
*/
def getIncomingHtlcs: Seq[IncomingHtlc]
def getIncomingHtlcs(nodeParams: NodeParams, log: LoggingAdapter): Seq[IncomingHtlc]

/**
* Outgoing HTLC sets that are still pending may either succeed or fail: we need to watch them to properly forward the
* result upstream to preserve channels. If you have non-standard HTLCs that may be in this situation, they should be
* returned by this method.
*/
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]]
def getHtlcsRelayedOut(htlcsIn: Seq[IncomingHtlc], nodeParams: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]]
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial
// result upstream to preserve channels.
val brokenHtlcs: BrokenHtlcs = {
val channels = listLocalChannels(nodeParams.db.channels)
val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs }.flatten
val nonStandardIncomingHtlcs: Seq[IncomingHtlc] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getIncomingHtlcs(nodeParams, log) }.flatten
val htlcsIn: Seq[IncomingHtlc] = getIncomingHtlcs(channels, nodeParams.db.payments, nodeParams.privateKey) ++ nonStandardIncomingHtlcs
val nonStandardRelayedOutHtlcs: Map[Origin, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn) }.flatten.toMap
val nonStandardRelayedOutHtlcs: Map[Origin, Set[(ByteVector32, Long)]] = nodeParams.pluginParams.collect { case p: CustomCommitmentsPlugin => p.getHtlcsRelayedOut(htlcsIn, nodeParams, log) }.flatten.toMap
val relayedOut: Map[Origin, Set[(ByteVector32, Long)]] = getHtlcsRelayedOut(channels, htlcsIn) ++ nonStandardRelayedOutHtlcs

val notRelayed = htlcsIn.filterNot(htlcIn => relayedOut.keys.exists(origin => matchesOrigin(htlcIn.add, origin)))
Expand Down Expand Up @@ -329,9 +329,24 @@ object PostRestartHtlcCleaner {
private def isPendingUpstream(channelId: ByteVector32, htlcId: Long, htlcsIn: Seq[IncomingHtlc]): Boolean =
htlcsIn.exists(htlc => htlc.add.channelId == channelId && htlc.add.id == htlcId)

def groupByOrigin(htlcsOut: Seq[(Origin, ByteVector32, Long)], htlcsIn: Seq[IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] =
htlcsOut
.groupBy { case (origin, _, _) => origin }
.mapValues(_.map { case (_, channelId, htlcId) => (channelId, htlcId) }.toSet)
// We are only interested in HTLCs that are pending upstream (not fulfilled nor failed yet).
// It may be the case that we have unresolved HTLCs downstream that have been resolved upstream when the downstream
// channel is closing (e.g. due to an HTLC timeout) because cooperatively failing the HTLC downstream will be
// instant whereas the uncooperative close of the downstream channel will take time.
.filterKeys {
case _: Origin.Local => true
case o: Origin.ChannelRelayed => isPendingUpstream(o.originChannelId, o.originHtlcId, htlcsIn)
case o: Origin.TrampolineRelayed => o.htlcs.exists { case (channelId, htlcId) => isPendingUpstream(channelId, htlcId, htlcsIn) }
}
.toMap

/** @return pending outgoing HTLCs, grouped by their upstream origin. */
private def getHtlcsRelayedOut(channels: Seq[HasCommitments], htlcsIn: Seq[IncomingHtlc])(implicit log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = {
channels
val htlcsOut = channels
.flatMap { c =>
// Filter out HTLCs that will never reach the blockchain or have already been timed-out on-chain.
val htlcsToIgnore: Set[Long] = c match {
Expand Down Expand Up @@ -361,18 +376,7 @@ object PostRestartHtlcCleaner {
}
c.commitments.originChannels.collect { case (outgoingHtlcId, origin) if !htlcsToIgnore.contains(outgoingHtlcId) => (origin, c.channelId, outgoingHtlcId) }
}
.groupBy { case (origin, _, _) => origin }
.mapValues(_.map { case (_, channelId, htlcId) => (channelId, htlcId) }.toSet)
// We are only interested in HTLCs that are pending upstream (not fulfilled nor failed yet).
// It may be the case that we have unresolved HTLCs downstream that have been resolved upstream when the downstream
// channel is closing (e.g. due to an HTLC timeout) because cooperatively failing the HTLC downstream will be
// instant whereas the uncooperative close of the downstream channel will take time.
.filterKeys {
case _: Origin.Local => true
case o: Origin.ChannelRelayed => isPendingUpstream(o.originChannelId, o.originHtlcId, htlcsIn)
case o: Origin.TrampolineRelayed => o.htlcs.exists { case (channelId, htlcId) => isPendingUpstream(channelId, htlcId, htlcsIn) }
}
.toMap
groupByOrigin(htlcsOut, htlcsIn)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import java.util.UUID

import akka.Done
import akka.actor.ActorRef
import akka.event.LoggingAdapter
import akka.testkit.TestProbe
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{Block, ByteVector32, Crypto, Satoshi}
Expand Down Expand Up @@ -555,9 +556,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
val nonRelayedHtlc2In = buildHtlcIn(1L, channelId_ab_1, relayedPaymentHash)

val pluginParams = new CustomCommitmentsPlugin {
def name = "test with incoming HTLC from remote"
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None), PostRestartHtlcCleaner.IncomingHtlc(nonRelayedHtlc2In.add, None))
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
override def name = "test with incoming HTLC from remote"
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None), PostRestartHtlcCleaner.IncomingHtlc(nonRelayedHtlc2In.add, None))
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
}

val nodeParams1 = nodeParams.copy(pluginParams = List(pluginParams))
Expand Down Expand Up @@ -602,9 +603,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
val nonRelayedHtlcIn = buildHtlcIn(1L, channelId_ab_2, relayedPaymentHash)

val pluginParams = new CustomCommitmentsPlugin {
def name = "test with outgoing HTLC to remote"
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List.empty
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map(trampolineRelayed -> Set((channelId_ab_1, 10L)))
override def name = "test with outgoing HTLC to remote"
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List.empty
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map(trampolineRelayed -> Set((channelId_ab_1, 10L)))
}

val nodeParams1 = nodeParams.copy(pluginParams = List(pluginParams))
Expand All @@ -628,9 +629,9 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit
val relayedHtlc1In = buildHtlcIn(0L, channelId_ab_1, trampolineRelayedPaymentHash)

val pluginParams = new CustomCommitmentsPlugin {
def name = "test with incoming HTLC from remote"
def getIncomingHtlcs: Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None))
def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc]): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
override def name = "test with incoming HTLC from remote"
override def getIncomingHtlcs(np: NodeParams, log: LoggingAdapter): Seq[PostRestartHtlcCleaner.IncomingHtlc] = List(PostRestartHtlcCleaner.IncomingHtlc(relayedHtlc1In.add, None))
override def getHtlcsRelayedOut(htlcsIn: Seq[PostRestartHtlcCleaner.IncomingHtlc], np: NodeParams, log: LoggingAdapter): Map[Origin, Set[(ByteVector32, Long)]] = Map.empty
}

val cmd1 = CMD_FAIL_HTLC(id = 0L, reason = Left(ByteVector.empty), replyTo_opt = None)
Expand Down