diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index 5f9613051e..f2f90444e0 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -307,6 +307,18 @@ eclair { channel-query-chunk-size = 100 // max number of short_channel_ids in query_short_channel_ids *do not change this unless you know what you are doing* } + message-path-finding { + max-route-length = 6 + ratios { + // The next three weights must sum to one. + base = 0.6 // when computing the weight for a channel, proportion that stays the same for all channels + channel-age = 0.1 // when computing the weight for a channel, consider its AGE in this proportion + channel-capacity = 0.3 // when computing the weight for a channel, consider its CAPACITY in this proportion + + disabled-multiplier = 2.5 // How much we prefer relaying a message along an active channel instead of a disabled one. + } + } + path-finding { default { randomize-route-selection = true // when computing a route for a payment we randomize the final selection @@ -478,8 +490,11 @@ eclair { max-per-peer-per-second = 10 + # Minimum number of hops before our node to hide it in the reply paths that we build + min-intermediate-hops = 6 + # Consider a message to be lost if we haven't received a reply after that amount of time - reply-timeout = 5 seconds + reply-timeout = 15 seconds # If we expect a reply but do not get one, retry until we reach this number of attempts max-attempts = 3 diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala index e2ab5ccc48..b5e9f61221 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Eclair.scala @@ -172,11 +172,11 @@ trait Eclair { def verifyMessage(message: ByteVector, recoverableSignature: ByteVector): VerifiedMessage - def sendOnionMessage(intermediateNodes: Seq[PublicKey], destination: Either[PublicKey, Sphinx.RouteBlinding.BlindedRoute], replyPath: Option[Seq[PublicKey]], userCustomContent: ByteVector)(implicit timeout: Timeout): Future[SendOnionMessageResponse] + def sendOnionMessage(intermediateNodes_opt: Option[Seq[PublicKey]], destination: Either[PublicKey, Sphinx.RouteBlinding.BlindedRoute], expectsReply: Boolean, userCustomContent: ByteVector)(implicit timeout: Timeout): Future[SendOnionMessageResponse] - def payOffer(offer: Offer, amount: MilliSatoshi, quantity: Long, externalId_opt: Option[String] = None, maxAttempts_opt: Option[Int] = None, maxFeeFlat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None, pathFindingExperimentName_opt: Option[String] = None)(implicit timeout: Timeout): Future[UUID] + def payOffer(offer: Offer, amount: MilliSatoshi, quantity: Long, externalId_opt: Option[String] = None, maxAttempts_opt: Option[Int] = None, maxFeeFlat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None, pathFindingExperimentName_opt: Option[String] = None, connectDirectly: Boolean = false)(implicit timeout: Timeout): Future[UUID] - def payOfferBlocking(offer: Offer, amount: MilliSatoshi, quantity: Long, externalId_opt: Option[String] = None, maxAttempts_opt: Option[Int] = None, maxFeeFlat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None, pathFindingExperimentName_opt: Option[String] = None)(implicit timeout: Timeout): Future[PaymentEvent] + def payOfferBlocking(offer: Offer, amount: MilliSatoshi, quantity: Long, externalId_opt: Option[String] = None, maxAttempts_opt: Option[Int] = None, maxFeeFlat_opt: Option[Satoshi] = None, maxFeePct_opt: Option[Double] = None, pathFindingExperimentName_opt: Option[String] = None, connectDirectly: Boolean = false)(implicit timeout: Timeout): Future[PaymentEvent] def stop(): Future[Unit] } @@ -621,22 +621,22 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { } } - override def sendOnionMessage(intermediateNodes: Seq[PublicKey], + override def sendOnionMessage(intermediateNodes_opt: Option[Seq[PublicKey]], recipient: Either[PublicKey, Sphinx.RouteBlinding.BlindedRoute], - replyPath: Option[Seq[PublicKey]], + expectsReply: Boolean, userCustomContent: ByteVector)(implicit timeout: Timeout): Future[SendOnionMessageResponse] = { - if (replyPath.nonEmpty && (replyPath.get.isEmpty || replyPath.get.last != appKit.nodeParams.nodeId)) { - return Future.failed(new Exception("Reply path must end at our node.")) - } TlvCodecs.tlvStream(MessageOnionCodecs.onionTlvCodec).decode(userCustomContent.bits) match { case Attempt.Successful(DecodeResult(userTlvs, _)) => val destination = recipient match { case Left(key) => OnionMessages.Recipient(key, None) case Right(route) => OnionMessages.BlindedPath(route) } - appKit.postman.ask(ref => Postman.SendMessage(intermediateNodes, destination, replyPath, userTlvs, ref, appKit.nodeParams.onionMessageConfig.timeout)).map { - case Postman.Response(payload) => - SendOnionMessageResponse(sent = true, None, Some(SendOnionMessageResponsePayload(payload.records))) + val routingStrategy = intermediateNodes_opt match { + case Some(intermediateNodes) => OnionMessages.RoutingStrategy.UseRoute(intermediateNodes) + case None => OnionMessages.RoutingStrategy.FindRoute + } + appKit.postman.ask(ref => Postman.SendMessage(destination, routingStrategy, userTlvs, expectsReply, ref)).map { + case Postman.Response(payload) => SendOnionMessageResponse(sent = true, None, Some(SendOnionMessageResponsePayload(payload.records))) case Postman.NoReply => SendOnionMessageResponse(sent = true, Some("No response"), None) case Postman.MessageSent => SendOnionMessageResponse(sent = true, None, None) case Postman.MessageFailed(failure: String) => SendOnionMessageResponse(sent = false, Some(failure), None) @@ -653,6 +653,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { maxFeeFlat_opt: Option[Satoshi], maxFeePct_opt: Option[Double], pathFindingExperimentName_opt: Option[String], + connectDirectly: Boolean, blocking: Boolean)(implicit timeout: Timeout): Future[Any] = { if (externalId_opt.exists(_.length > externalIdMaxLength)) { return Future.failed(new IllegalArgumentException(s"externalId is too long: cannot exceed $externalIdMaxLength characters")) @@ -664,7 +665,7 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { .modify(_.boundaries.maxFeeFlat).setToIfDefined(maxFeeFlat_opt.map(_.toMilliSatoshi)) case Left(t) => return Future.failed(t) } - val sendPaymentConfig = OfferPayment.SendPaymentConfig(externalId_opt, maxAttempts_opt.getOrElse(appKit.nodeParams.maxPaymentAttempts), routeParams, blocking) + val sendPaymentConfig = OfferPayment.SendPaymentConfig(externalId_opt, connectDirectly, maxAttempts_opt.getOrElse(appKit.nodeParams.maxPaymentAttempts), routeParams, blocking) val offerPayment = appKit.system.spawnAnonymous(OfferPayment(appKit.nodeParams, appKit.postman, appKit.paymentInitiator)) offerPayment.ask((ref: typed.ActorRef[Any]) => OfferPayment.PayOffer(ref.toClassic, offer, amount, quantity, sendPaymentConfig)).flatMap { case f: OfferPayment.Failure => Future.failed(new Exception(f.toString)) @@ -679,8 +680,9 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { maxAttempts_opt: Option[Int], maxFeeFlat_opt: Option[Satoshi], maxFeePct_opt: Option[Double], - pathFindingExperimentName_opt: Option[String])(implicit timeout: Timeout): Future[UUID] = { - payOfferInternal(offer, amount, quantity, externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt, blocking = false).mapTo[UUID] + pathFindingExperimentName_opt: Option[String], + connectDirectly: Boolean)(implicit timeout: Timeout): Future[UUID] = { + payOfferInternal(offer, amount, quantity, externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt, connectDirectly, blocking = false).mapTo[UUID] } override def payOfferBlocking(offer: Offer, @@ -690,8 +692,9 @@ class EclairImpl(appKit: Kit) extends Eclair with Logging { maxAttempts_opt: Option[Int], maxFeeFlat_opt: Option[Satoshi], maxFeePct_opt: Option[Double], - pathFindingExperimentName_opt: Option[String])(implicit timeout: Timeout): Future[PaymentEvent] = { - payOfferInternal(offer, amount, quantity, externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt, blocking = true).mapTo[PaymentEvent] + pathFindingExperimentName_opt: Option[String], + connectDirectly: Boolean)(implicit timeout: Timeout): Future[PaymentEvent] = { + payOfferInternal(offer, amount, quantity, externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt, connectDirectly, blocking = true).mapTo[PaymentEvent] } override def stop(): Future[Unit] = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index fa9c5f98fc..5ee319fa53 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -33,8 +33,8 @@ import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} import fr.acinq.eclair.router.Announcements.AddressException import fr.acinq.eclair.router.Graph.{HeuristicsConstants, WeightRatios} -import fr.acinq.eclair.router.PathFindingExperimentConf -import fr.acinq.eclair.router.Router.{MultiPartParams, PathFindingConf, RouterConf, SearchBoundaries} +import fr.acinq.eclair.router.{Graph, PathFindingExperimentConf} +import fr.acinq.eclair.router.Router.{MessageRouteParams, MultiPartParams, PathFindingConf, RouterConf, SearchBoundaries} import fr.acinq.eclair.tor.Socks5ProxyParams import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging @@ -422,6 +422,15 @@ object NodeParams extends Logging { PathFindingExperimentConf(experiments.toMap) } + def getMessageRouteParams(config: Config): MessageRouteParams = { + val maxRouteLength = config.getInt("max-route-length") + val ratioBase = config.getDouble("ratios.base") + val ratioAge = config.getDouble("ratios.channel-age") + val ratioCapacity = config.getDouble("ratios.channel-capacity") + val disabledMultiplier = config.getDouble("ratios.disabled-multiplier") + MessageRouteParams(maxRouteLength, Graph.MessagePath.WeightRatios(ratioBase, ratioAge, ratioCapacity, disabledMultiplier)) + } + val unhandledExceptionStrategy = config.getString("channel.unhandled-exception-strategy") match { case "local-close" => UnhandledExceptionStrategy.LocalClose case "stop" => UnhandledExceptionStrategy.Stop @@ -557,6 +566,7 @@ object NodeParams extends Logging { channelRangeChunkSize = config.getInt("router.sync.channel-range-chunk-size"), channelQueryChunkSize = config.getInt("router.sync.channel-query-chunk-size"), pathFindingExperimentConf = getPathFindingExperimentConf(config.getConfig("router.path-finding.experiments")), + messageRouteParams = getMessageRouteParams(config.getConfig("router.message-path-finding")), balanceEstimateHalfLife = FiniteDuration(config.getDuration("router.balance-estimate-half-life").getSeconds, TimeUnit.SECONDS), ), socksProxy_opt = socksProxy_opt, @@ -568,6 +578,7 @@ object NodeParams extends Logging { blockchainWatchdogSources = config.getStringList("blockchain-watchdog.sources").asScala.toSeq, onionMessageConfig = OnionMessageConfig( relayPolicy = onionMessageRelayPolicy, + minIntermediateHops = config.getInt("onion-messages.min-intermediate-hops"), timeout = FiniteDuration(config.getDuration("onion-messages.reply-timeout").getSeconds, TimeUnit.SECONDS), maxAttempts = config.getInt("onion-messages.max-attempts"), ), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index 1727be255d..4716c49a29 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -368,7 +368,7 @@ class Setup(val datadir: File, balanceActor = system.spawn(BalanceActor(nodeParams.db, bitcoinClient, channelsListener, nodeParams.balanceCheckInterval), name = "balance-actor") - postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") + postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped, router.toTyped, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") kit = Kit( nodeParams = nodeParams, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala index 1ec0fc25f1..e6a83b3d19 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala @@ -31,7 +31,15 @@ import scala.concurrent.duration.FiniteDuration object OnionMessages { + /** + * @param relayPolicy When to relay onion messages (always, never, only along existing channels). + * @param minIntermediateHops For routes we build to us, minimum number of hops before our node. Dummy hops are added + * if needed to hide our position in the network. + * @param timeout Time after which we consider that the message has been lost and stop waiting for a reply. + * @param maxAttempts Maximum number of attempts for sending a message. + */ case class OnionMessageConfig(relayPolicy: RelayPolicy, + minIntermediateHops: Int, timeout: FiniteDuration, maxAttempts: Int) @@ -43,6 +51,18 @@ object OnionMessages { case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) extends Destination // @formatter:on + // @formatter:off + sealed trait RoutingStrategy + object RoutingStrategy { + /** Use the provided route to reach the recipient or the blinded path's introduction node. */ + case class UseRoute(intermediateNodes: Seq[PublicKey]) extends RoutingStrategy + /** Directly connect to the recipient or the blinded path's introduction node. */ + val connectDirectly: UseRoute = UseRoute(Nil) + /** Use path-finding to find a route to reach the recipient or the blinded path's introduction node. */ + case object FindRoute extends RoutingStrategy + } + // @formatter:on + private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], nextTlvs: Set[RouteBlindingEncryptedDataTlv]): Seq[ByteVector] = { if (intermediateNodes.isEmpty) { Nil @@ -63,9 +83,9 @@ object OnionMessages { } private[message] def buildRouteFrom(originKey: PrivateKey, - blindingSecret: PrivateKey, - intermediateNodes: Seq[IntermediateNode], - destination: Destination): Option[Sphinx.RouteBlinding.BlindedRoute] = { + blindingSecret: PrivateKey, + intermediateNodes: Seq[IntermediateNode], + destination: Destination): Option[Sphinx.RouteBlinding.BlindedRoute] = { destination match { case recipient: Recipient => Some(buildRoute(blindingSecret, intermediateNodes, recipient)) case BlindedPath(route) if route.introductionNodeId == originKey.publicKey => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala index b4a405f5c8..759f3f325a 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala @@ -18,42 +18,41 @@ package fr.acinq.eclair.message import akka.actor.typed import akka.actor.typed.eventstream.EventStream -import akka.actor.typed.scaladsl.Behaviors +import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.io.{MessageRelay, Switchboard} -import fr.acinq.eclair.message.OnionMessages.Destination +import fr.acinq.eclair.message.OnionMessages.{Destination, RoutingStrategy} import fr.acinq.eclair.payment.offer.OfferManager +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteNotFound, MessageRouteResponse} import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoiceRequestPayload} import fr.acinq.eclair.wire.protocol.{OnionMessagePayloadTlv, TlvStream} import fr.acinq.eclair.{NodeParams, randomBytes32, randomKey} import scala.collection.mutable -import scala.concurrent.duration.FiniteDuration object Postman { // @formatter:off sealed trait Command - - /** Builds a message packet and send it to the destination using the provided path. + /** + * Builds a message packet and send it to the destination using the provided path. * - * @param intermediateNodes Extra hops to use between us and the destination - * @param destination Recipient of the message - * @param replyPath Hops to use for the reply (including our own node as the last hop) or None if not expecting a reply - * @param message Content of the message to send - * @param replyTo Actor to send the status and reply to - * @param timeout When expecting a reply, maximum delay to wait for it + * @param destination Recipient of the message + * @param routingStrategy How to reach the destination (recipient or blinded path introduction node). + * @param message Content of the message to send + * @param expectsReply Whether the message expects a reply + * @param replyTo Actor to send the status and reply to */ - case class SendMessage(intermediateNodes: Seq[PublicKey], - destination: Destination, - replyPath: Option[Seq[PublicKey]], + case class SendMessage(destination: Destination, + routingStrategy: RoutingStrategy, message: TlvStream[OnionMessagePayloadTlv], - replyTo: ActorRef[OnionMessageResponse], - timeout: FiniteDuration) extends Command + expectsReply: Boolean, + replyTo: ActorRef[OnionMessageResponse]) extends Command + case class Subscribe(pathId: ByteVector32, replyTo: ActorRef[OnionMessageResponse]) extends Command private case class Unsubscribe(pathId: ByteVector32) extends Command case class WrappedMessage(finalPayload: FinalPayload) extends Command - case class SendingStatus(status: MessageRelay.Status) extends Command sealed trait OnionMessageResponse case object NoReply extends OnionMessageResponse @@ -63,19 +62,14 @@ object Postman { case class MessageFailed(reason: String) extends MessageStatus // @formatter:on - def apply(nodeParams: NodeParams, switchboard: ActorRef[Switchboard.RelayMessage], offerManager: typed.ActorRef[OfferManager.RequestInvoice]): Behavior[Command] = { + def apply(nodeParams: NodeParams, switchboard: ActorRef[Switchboard.RelayMessage], router: ActorRef[Router.MessageRouteRequest], offerManager: typed.ActorRef[OfferManager.RequestInvoice]): Behavior[Command] = { Behaviors.setup(context => { context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[OnionMessages.ReceiveMessage](r => WrappedMessage(r.finalPayload))) - val relayMessageStatusAdapter = context.messageAdapter[MessageRelay.Status](SendingStatus) - // For messages expecting a reply, send reply or failure to send val subscribed = new mutable.HashMap[ByteVector32, ActorRef[OnionMessageResponse]]() - // For messages not expecting a reply, send success or failure to send - val sendStatusTo = new mutable.HashMap[ByteVector32, ActorRef[OnionMessageResponse]]() - - Behaviors.receiveMessagePartial { + Behaviors.receiveMessage { case WrappedMessage(invoiceRequestPayload: InvoiceRequestPayload) => offerManager ! OfferManager.RequestInvoice(invoiceRequestPayload, context.self) Behaviors.same @@ -90,31 +84,13 @@ object Postman { case _ => // ignoring message with invalid or missing pathId } Behaviors.same - case SendMessage(intermediateNodes, destination, replyPath, messageContent, replyTo, timeout) => - val messageId = randomBytes32() - val replyRoute = replyPath.map(replyHops => { - val intermediateHops = replyHops.dropRight(1).map(OnionMessages.IntermediateNode(_)) - val lastHop = OnionMessages.Recipient(replyHops.last, Some(messageId)) - OnionMessages.buildRoute(randomKey(), intermediateHops, lastHop) - }) - OnionMessages.buildMessage( - nodeParams.privateKey, - randomKey(), - randomKey(), - intermediateNodes.map(OnionMessages.IntermediateNode(_)), - destination, - TlvStream(replyRoute.map(OnionMessagePayloadTlv.ReplyPath).toSet ++ messageContent.records, messageContent.unknown)) match { - case Left(failure) => - replyTo ! MessageFailed(failure.toString) - case Right((nextNodeId, message)) => - if (replyPath.isEmpty) { // not expecting reply - sendStatusTo += (messageId -> replyTo) - } else { // expecting reply - subscribed += (messageId -> replyTo) - context.scheduleOnce(timeout, context.self, Unsubscribe(messageId)) - } - switchboard ! Switchboard.RelayMessage(messageId, None, nextNodeId, message, MessageRelay.RelayAll, Some(relayMessageStatusAdapter)) - } + case SendMessage(destination, routingStrategy, messageContent, expectsReply, replyTo) => + val child = context.spawnAnonymous(SendingMessage(nodeParams, switchboard, router, context.self, destination, messageContent, routingStrategy, expectsReply, replyTo)) + child ! SendingMessage.SendMessage + Behaviors.same + case Subscribe(pathId, replyTo) => + subscribed += (pathId -> replyTo) + context.scheduleOnce(nodeParams.onionMessageConfig.timeout, context.self, Unsubscribe(pathId)) Behaviors.same case Unsubscribe(pathId) => subscribed.get(pathId).foreach(ref => { @@ -122,23 +98,119 @@ object Postman { ref ! NoReply }) Behaviors.same - case SendingStatus(MessageRelay.Sent(messageId)) => - sendStatusTo.get(messageId).foreach(ref => { - sendStatusTo -= messageId - ref ! MessageSent - }) - Behaviors.same - case SendingStatus(status: MessageRelay.Failure) => - sendStatusTo.get(status.messageId).foreach(ref => { - sendStatusTo -= status.messageId - ref ! MessageFailed(status.toString) - }) - subscribed.get(status.messageId).foreach(ref => { - subscribed -= status.messageId - ref ! MessageFailed(status.toString) - }) - Behaviors.same } }) } } + +object SendingMessage { + // @formatter:off + sealed trait Command + case object SendMessage extends Command + private case class SendingStatus(status: MessageRelay.Status) extends Command + private case class WrappedMessageRouteResponse(response: MessageRouteResponse) extends Command + // @formatter:on + + def apply(nodeParams: NodeParams, + switchboard: ActorRef[Switchboard.RelayMessage], + router: ActorRef[Router.MessageRouteRequest], + postman: ActorRef[Postman.Command], + destination: Destination, + message: TlvStream[OnionMessagePayloadTlv], + routingStrategy: RoutingStrategy, + expectsReply: Boolean, + replyTo: ActorRef[Postman.OnionMessageResponse]): Behavior[Command] = { + Behaviors.setup(context => { + val actor = new SendingMessage(nodeParams, switchboard, router, postman, destination, message, routingStrategy, expectsReply, replyTo, context) + actor.start() + }) + } +} + +private class SendingMessage(nodeParams: NodeParams, + switchboard: ActorRef[Switchboard.RelayMessage], + router: ActorRef[Router.MessageRouteRequest], + postman: ActorRef[Postman.Command], + destination: Destination, + message: TlvStream[OnionMessagePayloadTlv], + routingStrategy: RoutingStrategy, + expectsReply: Boolean, + replyTo: ActorRef[Postman.OnionMessageResponse], + context: ActorContext[SendingMessage.Command]) { + + import SendingMessage._ + + def start(): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case SendMessage => + val targetNodeId = destination match { + case OnionMessages.BlindedPath(route) => route.introductionNodeId + case OnionMessages.Recipient(nodeId, _, _, _) => nodeId + } + routingStrategy match { + case RoutingStrategy.UseRoute(intermediateNodes) => sendToRoute(intermediateNodes, targetNodeId) + case RoutingStrategy.FindRoute if targetNodeId == nodeParams.nodeId => + context.self ! WrappedMessageRouteResponse(MessageRoute(Nil, targetNodeId)) + waitForRouteFromRouter() + case RoutingStrategy.FindRoute => + router ! Router.MessageRouteRequest(context.messageAdapter(WrappedMessageRouteResponse), nodeParams.nodeId, targetNodeId, Set.empty) + waitForRouteFromRouter() + } + } + } + + private def waitForRouteFromRouter(): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case WrappedMessageRouteResponse(MessageRoute(intermediateNodes, targetNodeId)) => + context.log.debug("Found route: {}", (intermediateNodes :+ targetNodeId).mkString(" -> ")) + sendToRoute(intermediateNodes, targetNodeId) + case WrappedMessageRouteResponse(MessageRouteNotFound(targetNodeId)) => + context.log.debug("No route found to {}", targetNodeId) + replyTo ! Postman.MessageFailed("No route found") + Behaviors.stopped + } + } + + private def sendToRoute(intermediateNodes: Seq[PublicKey], targetNodeId: PublicKey): Behavior[Command] = { + val messageId = randomBytes32() + val replyRoute = + if (expectsReply) { + val numHopsToAdd = 0.max(nodeParams.onionMessageConfig.minIntermediateHops - intermediateNodes.length - 1) + val intermediateHops = (Seq(targetNodeId) ++ intermediateNodes.reverse ++ Seq.fill(numHopsToAdd)(nodeParams.nodeId)).map(OnionMessages.IntermediateNode(_)) + val lastHop = OnionMessages.Recipient(nodeParams.nodeId, Some(messageId)) + Some(OnionMessages.buildRoute(randomKey(), intermediateHops, lastHop)) + } else { + None + } + OnionMessages.buildMessage( + nodeParams.privateKey, + randomKey(), + randomKey(), + intermediateNodes.map(OnionMessages.IntermediateNode(_)), + destination, + TlvStream(message.records ++ replyRoute.map(OnionMessagePayloadTlv.ReplyPath).toSet, message.unknown)) match { + case Left(failure) => + replyTo ! Postman.MessageFailed(failure.toString) + Behaviors.stopped + case Right((nextNodeId, message)) => + switchboard ! Switchboard.RelayMessage(messageId, None, nextNodeId, message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus))) + waitForSent() + } + } + + private def waitForSent(): Behavior[Command] = { + Behaviors.receiveMessagePartial { + case SendingStatus(MessageRelay.Sent(messageId)) => + if (expectsReply) { + postman ! Postman.Subscribe(messageId, replyTo) + } else { + replyTo ! Postman.MessageSent + } + Behaviors.stopped + case SendingStatus(status: MessageRelay.Failure) => + replyTo ! Postman.MessageFailed(status.toString) + Behaviors.stopped + } + } + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala index 10e946ff5c..4c2d511c6c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/offer/OfferManager.scala @@ -38,8 +38,6 @@ import scala.concurrent.duration.FiniteDuration * Created by thomash-acinq on 13/01/2023. */ -import scala.concurrent.duration.DurationInt - object OfferManager { sealed trait Command @@ -192,7 +190,7 @@ object OfferManager { def waitForHandler(): Behavior[Command] = { Behaviors.receiveMessagePartial { case RejectRequest(error) => - postman ! Postman.SendMessage(Nil, pathToSender, None, TlvStream(OnionMessagePayloadTlv.InvoiceError(TlvStream(OfferTypes.Error(error)))), context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse), 0 seconds) + postman ! Postman.SendMessage(pathToSender, OnionMessages.RoutingStrategy.FindRoute, TlvStream(OnionMessagePayloadTlv.InvoiceError(TlvStream(OfferTypes.Error(error)))), expectsReply = false, context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse)) waitForSent() case ApproveRequest(amount, routes, pluginData_opt, additionalTlvs, customTlvs) => val preimage = randomBytes32() @@ -210,7 +208,7 @@ object OfferManager { case WrappedInvoiceResponse(invoiceResponse) => invoiceResponse match { case CreateInvoiceActor.InvoiceCreated(invoice) => - postman ! Postman.SendMessage(Nil, pathToSender, None, TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)), context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse), 0 seconds) + postman ! Postman.SendMessage(pathToSender, OnionMessages.RoutingStrategy.FindRoute, TlvStream(OnionMessagePayloadTlv.Invoice(invoice.records)), expectsReply = false, context.messageAdapter[Postman.OnionMessageResponse](WrappedOnionMessageResponse)) waitForSent() case f: CreateInvoiceActor.InvoiceCreationFailed => context.log.debug("invoice creation failed: {}", f.message) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala index 03fd7c2df4..0a0ab233c0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/send/OfferPayment.scala @@ -60,6 +60,7 @@ object OfferPayment { case class WrappedMessageResponse(response: OnionMessageResponse) extends Command case class SendPaymentConfig(externalId_opt: Option[String], + connectDirectly: Boolean, maxAttempts: Int, routeParams: RouteParams, blocking: Boolean) @@ -109,10 +110,9 @@ object OfferPayment { case Right(nodeId) => OnionMessages.Recipient(nodeId, None, None) } - // TODO: Find a path made of channels as some nodes may refuse to relay messages to nodes with which they don't have a channel. - val intermediateNodesToRecipient = Nil val messageContent = TlvStream[OnionMessagePayloadTlv](OnionMessagePayloadTlv.InvoiceRequest(request.records)) - postman ! SendMessage(intermediateNodesToRecipient, destination, Some((nodeParams.nodeId +: intermediateNodesToRecipient).reverse), messageContent, context.messageAdapter(WrappedMessageResponse), nodeParams.onionMessageConfig.timeout) + val routingStrategy = if (sendPaymentConfig.connectDirectly) OnionMessages.RoutingStrategy.connectDirectly else OnionMessages.RoutingStrategy.FindRoute + postman ! SendMessage(destination, routingStrategy, messageContent, expectsReply = true, context.messageAdapter(WrappedMessageResponse)) waitForInvoice(nodeParams, postman, paymentInitiator, context, request, payerKey, replyTo, attemptNumber + 1, sendPaymentConfig) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala index 06caeb1510..e0dc572a68 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/remote/EclairInternalsSerializer.scala @@ -86,6 +86,13 @@ object EclairInternalsSerializer { "experiments" | listOfN(int32, pathFindingConfCodec).xmap[Map[String, PathFindingConf]](_.map(e => e.experimentName -> e).toMap, _.values.toList) ).as[PathFindingExperimentConf] + val messageRouteParamsCodec: Codec[MessageRouteParams] = ( + ("maxRouteLength" | int32) :: + (("baseFactor" | double) :: + ("ageFactor" | double) :: + ("capacityFactor" | double) :: + ("disabledMultiplier" | double)).as[Graph.MessagePath.WeightRatios]).as[MessageRouteParams] + val routerConfCodec: Codec[RouterConf] = ( ("watchSpentWindow" | finiteDurationCodec) :: ("channelExcludeDuration" | finiteDurationCodec) :: @@ -97,6 +104,7 @@ object EclairInternalsSerializer { ("channelRangeChunkSize" | int32) :: ("channelQueryChunkSize" | int32) :: ("pathFindingExperimentConf" | pathFindingExperimentConfCodec) :: + ("messageRouteParams" | messageRouteParamsCodec) :: ("balanceEstimateHalfLife" | finiteDurationCodec)).as[RouterConf] val overrideFeaturesListCodec: Codec[List[(PublicKey, Features[Feature])]] = listOfN(uint16, publicKey ~ lengthPrefixedFeaturesCodec) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala index 18ea2a810d..c05a33be5e 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/BalanceEstimate.scala @@ -18,8 +18,9 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Satoshi, SatoshiLong} -import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} +import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, ActiveEdge} import fr.acinq.eclair.router.Router.{ChannelDesc, ChannelHop, Route} +import fr.acinq.eclair.wire.protocol.NodeAnnouncement import fr.acinq.eclair.{MilliSatoshi, MilliSatoshiLong, ShortChannelId, TimestampSecond, TimestampSecondLong, ToMilliSatoshiConversion} import scala.concurrent.duration.{DurationInt, FiniteDuration} @@ -178,7 +179,7 @@ case class BalanceEstimate private(low: MilliSatoshi, def didReceive(amount: MilliSatoshi, timestamp: TimestampSecond): BalanceEstimate = otherSide.didSend(amount, timestamp).otherSide - def addEdge(edge: GraphEdge): BalanceEstimate = copy( + def addEdge(edge: ActiveEdge): BalanceEstimate = copy( high = high.max(edge.capacity.toMilliSatoshi), capacities = capacities.updated(edge.desc.shortChannelId, edge.capacity) ) @@ -234,7 +235,7 @@ object BalanceEstimate { case class BalancesEstimates(balances: Map[(PublicKey, PublicKey), BalanceEstimate], defaultHalfLife: FiniteDuration) { private def get(a: PublicKey, b: PublicKey): Option[BalanceEstimate] = balances.get((a, b)) - def addEdge(edge: GraphEdge): BalancesEstimates = BalancesEstimates( + def addEdge(edge: ActiveEdge): BalancesEstimates = BalancesEstimates( balances.updatedWith((edge.desc.a, edge.desc.b))(balance => Some(balance.getOrElse(BalanceEstimate.empty(defaultHalfLife)).addEdge(edge)) ), @@ -284,13 +285,17 @@ case class BalancesEstimates(balances: Map[(PublicKey, PublicKey), BalanceEstima } case class GraphWithBalanceEstimates(graph: DirectedGraph, private val balances: BalancesEstimates) { - def addEdge(edge: GraphEdge): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.addEdge(edge), balances.addEdge(edge)) + def addOrUpdateVertex(ann: NodeAnnouncement): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.addOrUpdateVertex(ann), balances) - def removeEdge(desc: ChannelDesc): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.removeEdge(desc), balances.removeEdge(desc)) + def addEdge(edge: ActiveEdge): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.addEdge(edge), balances.addEdge(edge)) - def removeEdges(descList: Iterable[ChannelDesc]): GraphWithBalanceEstimates = GraphWithBalanceEstimates( - graph.removeEdges(descList), - descList.foldLeft(balances)((acc, edge) => acc.removeEdge(edge)), + def disableEdge(desc: ChannelDesc): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.disableEdge(desc), balances.removeEdge(desc)) + + def removeChannel(desc: ChannelDesc): GraphWithBalanceEstimates = GraphWithBalanceEstimates(graph.removeChannel(desc), balances.removeEdge(desc).removeEdge(desc.reversed)) + + def removeChannels(descList: Iterable[ChannelDesc]): GraphWithBalanceEstimates = GraphWithBalanceEstimates( + graph.removeChannels(descList), + descList.foldLeft(balances)((acc, edge) => acc.removeEdge(edge).removeEdge(edge.reversed)), ) def routeCouldRelay(route: Route): GraphWithBalanceEstimates = { @@ -313,7 +318,7 @@ case class GraphWithBalanceEstimates(graph: DirectedGraph, private val balances: GraphWithBalanceEstimates(graph, balances.channelCouldNotSend(hop, amount)) } - def canSend(amount: MilliSatoshi, edge: GraphEdge): Double = { + def canSend(amount: MilliSatoshi, edge: ActiveEdge): Double = { balances.balances.get((edge.desc.a, edge.desc.b)) match { case Some(estimate) => estimate.canSend(amount) case None => BalanceEstimate.empty(1 hour).addEdge(edge).canSend(amount) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala index 24e445052e..e09561be6c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala @@ -18,15 +18,14 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Btc, BtcDouble, MilliBtc, Satoshi} -import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair._ import fr.acinq.eclair.payment.Invoice -import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} +import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair.router.Graph.GraphStructure.{ActiveEdge, DirectedGraph} import fr.acinq.eclair.router.Router._ -import fr.acinq.eclair.wire.protocol.ChannelUpdate -import fr.acinq.eclair.{RealShortChannelId, _} +import fr.acinq.eclair.wire.protocol.{ChannelUpdate, NodeAnnouncement} import scala.annotation.tailrec -import scala.collection.immutable.SortedMap import scala.collection.mutable object Graph { @@ -70,7 +69,7 @@ object Graph { case class WeightedNode(key: PublicKey, weight: RichWeight) - case class WeightedPath(path: Seq[GraphEdge], weight: RichWeight) + case class WeightedPath(path: Seq[ActiveEdge], weight: RichWeight) /** * This comparator must be consistent with the "equals" behavior, thus for two weighted nodes with @@ -89,9 +88,9 @@ object Graph { override def compare(x: WeightedPath, y: WeightedPath): Int = y.weight.compare(x.weight) } - case class InfiniteLoop(path: Seq[GraphEdge]) extends Exception + case class InfiniteLoop(path: Seq[ActiveEdge]) extends Exception - case class NegativeProbability(edge: GraphEdge, weight: RichWeight, heuristicsConstants: HeuristicsConstants) extends Exception + case class NegativeProbability(edge: ActiveEdge, weight: RichWeight, heuristicsConstants: HeuristicsConstants) extends Exception /** * Yen's algorithm to find the k-shortest (loop-less) paths in a graph, uses dijkstra as search algo. Is guaranteed to @@ -116,7 +115,7 @@ object Graph { amount: MilliSatoshi, ignoredEdges: Set[ChannelDesc], ignoredVertices: Set[PublicKey], - extraEdges: Set[GraphEdge], + extraEdges: Set[ActiveEdge], pathsToFind: Int, wr: Either[WeightRatios, HeuristicsConstants], currentBlockHeight: BlockHeight, @@ -206,12 +205,12 @@ object Graph { targetNode: PublicKey, ignoredEdges: Set[ChannelDesc], ignoredVertices: Set[PublicKey], - extraEdges: Set[GraphEdge], + extraEdges: Set[ActiveEdge], initialWeight: RichWeight, boundaries: RichWeight => Boolean, currentBlockHeight: BlockHeight, wr: Either[WeightRatios, HeuristicsConstants], - includeLocalChannelCost: Boolean): Seq[GraphEdge] = { + includeLocalChannelCost: Boolean): Seq[ActiveEdge] = { // the graph does not contain source/destination nodes val sourceNotInGraph = !g.containsVertex(sourceNode) && !extraEdges.exists(_.desc.a == sourceNode) val targetNotInGraph = !g.containsVertex(targetNode) && !extraEdges.exists(_.desc.b == targetNode) @@ -223,7 +222,7 @@ object Graph { // because in the worst case scenario we will insert all the vertices. val initialCapacity = 100 val bestWeights = mutable.HashMap.newBuilder[PublicKey, RichWeight](initialCapacity, mutable.HashMap.defaultLoadFactor).result() - val bestEdges = mutable.HashMap.newBuilder[PublicKey, GraphEdge](initialCapacity, mutable.HashMap.defaultLoadFactor).result() + val bestEdges = mutable.HashMap.newBuilder[PublicKey, ActiveEdge](initialCapacity, mutable.HashMap.defaultLoadFactor).result() // NB: we want the elements with smallest weight first, hence the `reverse`. val toExplore = mutable.PriorityQueue.empty[WeightedNode](NodeComparator.reverse) val visitedNodes = mutable.HashSet[PublicKey]() @@ -243,7 +242,7 @@ object Graph { val neighborEdges = { val extraNeighbors = extraEdges.filter(_.desc.b == current.key) // the resulting set must have only one element per shortChannelId; we prioritize extra edges - g.getIncomingEdgesOf(current.key).filterNot(e => extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId)) ++ extraNeighbors + g.getIncomingEdgesOf(current.key).collect{case e: ActiveEdge if !extraNeighbors.exists(_.desc.shortChannelId == e.desc.shortChannelId) => e} ++ extraNeighbors } neighborEdges.foreach { edge => val neighbor = edge.desc.a @@ -274,7 +273,7 @@ object Graph { } if (targetFound) { - val edgePath = new mutable.ArrayBuffer[GraphEdge](RouteCalculation.ROUTE_MAX_LENGTH) + val edgePath = new mutable.ArrayBuffer[ActiveEdge](RouteCalculation.ROUTE_MAX_LENGTH) var current = bestEdges.get(sourceNode) while (current.isDefined) { edgePath += current.get @@ -285,7 +284,7 @@ object Graph { } edgePath.toSeq } else { - Seq.empty[GraphEdge] + Seq.empty[ActiveEdge] } } @@ -299,7 +298,7 @@ object Graph { * @param weightRatios ratios used to 'weight' edges when searching for the shortest path * @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel */ - private def addEdgeWeight(sender: PublicKey, edge: GraphEdge, prev: RichWeight, currentBlockHeight: BlockHeight, weightRatios: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = { + private def addEdgeWeight(sender: PublicKey, edge: ActiveEdge, prev: RichWeight, currentBlockHeight: BlockHeight, weightRatios: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = { val totalAmount = if (edge.desc.a == sender && !includeLocalChannelCost) prev.amount else addEdgeFees(edge, prev.amount) val fee = totalAmount - prev.amount val totalFees = prev.fees + fee @@ -368,15 +367,15 @@ object Graph { * @param amountToForward the value that this edge will have to carry along * @return the new amount updated with the necessary fees for this edge */ - private def addEdgeFees(edge: GraphEdge, amountToForward: MilliSatoshi): MilliSatoshi = { + private def addEdgeFees(edge: ActiveEdge, amountToForward: MilliSatoshi): MilliSatoshi = { amountToForward + edge.params.fee(amountToForward) } /** Validate that all edges along the path can relay the amount with fees. */ - def validatePath(path: Seq[GraphEdge], amount: MilliSatoshi): Boolean = validateReversePath(path.reverse, amount) + def validatePath(path: Seq[ActiveEdge], amount: MilliSatoshi): Boolean = validateReversePath(path.reverse, amount) @tailrec - private def validateReversePath(path: Seq[GraphEdge], amount: MilliSatoshi): Boolean = path.headOption match { + private def validateReversePath(path: Seq[ActiveEdge], amount: MilliSatoshi): Boolean = path.headOption match { case None => true case Some(edge) => val canRelayAmount = amount <= edge.capacity && @@ -397,7 +396,7 @@ object Graph { * @param wr ratios used to 'weight' edges when searching for the shortest path * @param includeLocalChannelCost if the path is for relaying and we need to include the cost of the local channel */ - def pathWeight(sender: PublicKey, path: Seq[GraphEdge], amount: MilliSatoshi, currentBlockHeight: BlockHeight, wr: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = { + def pathWeight(sender: PublicKey, path: Seq[ActiveEdge], amount: MilliSatoshi, currentBlockHeight: BlockHeight, wr: Either[WeightRatios, HeuristicsConstants], includeLocalChannelCost: Boolean): RichWeight = { path.foldRight(RichWeight(amount, 0, CltvExpiryDelta(0), 1.0, 0 msat, 0 msat, 0.0)) { (edge, prev) => addEdgeWeight(sender, edge, prev, currentBlockHeight, wr, includeLocalChannelCost) } @@ -429,8 +428,165 @@ object Graph { } + object MessagePath { + /** + * The cumulative weight of a set of edges (path in the graph). + * + * @param length number of edges in the path + * @param weight cost multiplied by a factor based on heuristics (see [[WeightRatios]]). + */ + case class RichWeight(length: Int, weight: Double) extends Ordered[RichWeight] { + override def compare(that: RichWeight): Int = this.weight.compareTo(that.weight) + } + + object RichWeight { + def zero: RichWeight = RichWeight(0, 0.0) + } + + case class WeightRatios(baseFactor: Double, ageFactor: Double, capacityFactor: Double, disabledMultiplier: Double) { + require(baseFactor + ageFactor + capacityFactor == 1, "The sum of heuristics ratios must be 1") + require(baseFactor >= 0.0, "ratio-base must be nonnegative") + require(ageFactor >= 0.0, "ratio-channel-age must be nonnegative") + require(capacityFactor >= 0.0, "ratio-channel-capacity must be nonnegative") + require(disabledMultiplier >= 1.0, "disabled-multiplier must be at least 1") + } + + case class WeightedNode(key: PublicKey, weight: RichWeight) + + /** + * This comparator must be consistent with the "equals" behavior, thus for two weighted nodes with + * the same weight we distinguish them by their public key. + * See https://docs.oracle.com/javase/8/docs/api/java/util/Comparator.html + */ + object NodeComparator extends Ordering[WeightedNode] { + override def compare(x: WeightedNode, y: WeightedNode): Int = { + val weightCmp = x.weight.compareTo(y.weight) + if (weightCmp == 0) x.key.toString().compareTo(y.key.toString()) + else weightCmp + } + } + + /** + * Add the given edge to the path and compute the new weight. + * + * @param desc the edge we want to cross + * @param prev weight of the rest of the path + * @param currentBlockHeight the height of the chain tip (latest block). + * @param weightRatios ratios used to 'weight' edges when searching for the shortest path + */ + private def addEdgeWeight(desc: ChannelDesc, capacity: Satoshi, isActive: Boolean, prev: RichWeight, currentBlockHeight: BlockHeight, weightRatios: WeightRatios): RichWeight = { + import RoutingHeuristics._ + + // Every edge is weighted by funding block height where older blocks add less weight. The window considered is 1 year. + val ageFactor = desc.shortChannelId match { + case real: RealShortChannelId => normalize(real.blockHeight.toDouble, min = (currentBlockHeight - BLOCK_TIME_ONE_YEAR).toDouble, max = currentBlockHeight.toDouble) + // for local channels or route hints we don't easily have access to the channel block height, but we want to + // give them the best score anyway + case _: Alias => 1 + case _: UnspecifiedShortChannelId => 1 + } + + // Every edge is weighted by channel capacity, larger channels add less weight + val capFactor = 1 - normalize(capacity.toMilliSatoshi.toLong.toDouble, CAPACITY_CHANNEL_LOW.toLong.toDouble, CAPACITY_CHANNEL_HIGH.toLong.toDouble) + + val multiplier = if (isActive) 1 else weightRatios.disabledMultiplier + + val totalWeight = prev.weight + (weightRatios.baseFactor + (ageFactor * weightRatios.ageFactor) + (capFactor * weightRatios.capacityFactor)) * multiplier + RichWeight(prev.length + 1, totalWeight) + } + + /** + * Finds the shortest path in the graph, uses a modified version of Dijkstra's algorithm that computes the shortest + * path from the target to the source (this is because we want to calculate the weight of the edges correctly). The + * graph @param g is optimized for querying the incoming edges given a vertex. + * + * @param g the graph on which will be performed the search + * @param sourceNode the starting node of the path we're looking for (payer) + * @param targetNode the destination node of the path + * @param ignoredVertices nodes that should be avoided + * @param boundaries a predicate function that can be used to impose limits on the outcome of the search + * @param currentBlockHeight the height of the chain tip (latest block) + * @param wr ratios used to 'weight' edges when searching for the shortest path + */ + def dijkstraMessagePath(g: DirectedGraph, + sourceNode: PublicKey, + targetNode: PublicKey, + ignoredVertices: Set[PublicKey], + boundaries: RichWeight => Boolean, + currentBlockHeight: BlockHeight, + wr: WeightRatios): Option[Seq[ChannelDesc]] = { + // the graph does not contain source/destination nodes + val sourceNotInGraph = !g.containsVertex(sourceNode) + val targetNotInGraph = !g.containsVertex(targetNode) + if (sourceNotInGraph || targetNotInGraph) { + return None + } + + // conservative estimation to avoid over-allocating memory: this is not the actual optimal size for the maps, + // because in the worst case scenario we will insert all the vertices. + val initialCapacity = 100 + val bestWeights = mutable.HashMap.newBuilder[PublicKey, RichWeight](initialCapacity, mutable.HashMap.defaultLoadFactor).result() + val bestEdges = mutable.HashMap.newBuilder[PublicKey, ChannelDesc](initialCapacity, mutable.HashMap.defaultLoadFactor).result() + // NB: we want the elements with smallest weight first, hence the `reverse`. + val toExplore = mutable.PriorityQueue.empty[WeightedNode](NodeComparator.reverse) + val visitedNodes = mutable.HashSet[PublicKey]() + + // initialize the queue and cost array with the initial weight + bestWeights.put(targetNode, RichWeight.zero) + toExplore.enqueue(WeightedNode(targetNode, RichWeight.zero)) + + var targetFound = false + while (toExplore.nonEmpty && !targetFound) { + // node with the smallest distance from the target + val current = toExplore.dequeue() // O(log(n)) + targetFound = current.key == sourceNode + if (!targetFound && !visitedNodes.contains(current.key)) { + visitedNodes += current.key + g.getIncomingEdgesOf(current.key).foreach { edge => + val neighbor = edge.desc.a + if (neighbor == sourceNode || (g.getVertexFeatures(neighbor).hasFeature(Features.OnionMessages) && !ignoredVertices.contains(neighbor))) { + val neighborWeight = addEdgeWeight(edge.desc, edge.capacity, edge.isInstanceOf[ActiveEdge], current.weight, currentBlockHeight, wr) + if (boundaries(neighborWeight)) { + val previousNeighborWeight = bestWeights.getOrElse(neighbor, RichWeight(Int.MaxValue, Double.MaxValue)) + // if this path between neighbor and the target has a shorter distance than previously known, we select it + if (neighborWeight < previousNeighborWeight) { + // update the best edge for this vertex + bestEdges.put(neighbor, edge.desc) + // add this updated node to the list for further exploration + toExplore.enqueue(WeightedNode(neighbor, neighborWeight)) // O(1) + // update the minimum known distance array + bestWeights.put(neighbor, neighborWeight) + } + } + } + } + } + } + + if (targetFound) { + val edgePath = new mutable.ArrayBuffer[ChannelDesc](RouteCalculation.ROUTE_MAX_LENGTH) + var current = bestEdges.get(sourceNode) + while (current.isDefined) { + edgePath += current.get + current = bestEdges.get(current.get.b) + if (edgePath.length > RouteCalculation.ROUTE_MAX_LENGTH) { + throw InfiniteLoop(Nil) + } + } + Some(edgePath.toSeq) + } else { + None + } + } + } + object GraphStructure { + sealed trait GraphEdge { + val desc: ChannelDesc + val capacity: Satoshi + } + /** * Representation of an edge of the graph * @@ -439,7 +595,7 @@ object Graph { * @param capacity channel capacity * @param balance_opt (optional) available balance that can be sent through this edge */ - case class GraphEdge private(desc: ChannelDesc, params: HopRelayParams, capacity: Satoshi, balance_opt: Option[MilliSatoshi]) { + case class ActiveEdge private(desc: ChannelDesc, params: HopRelayParams, capacity: Satoshi, balance_opt: Option[MilliSatoshi]) extends GraphEdge { def maxHtlcAmount(reservedCapacity: MilliSatoshi): MilliSatoshi = Seq( balance_opt.map(balance => balance - reservedCapacity), @@ -450,37 +606,41 @@ object Graph { def fee(amount: MilliSatoshi): MilliSatoshi = params.fee(amount) } - object GraphEdge { - def apply(u: ChannelUpdate, pc: PublicChannel): GraphEdge = GraphEdge( + object ActiveEdge { + def apply(u: ChannelUpdate, pc: PublicChannel): ActiveEdge = ActiveEdge( desc = ChannelDesc(u, pc.ann), params = HopRelayParams.FromAnnouncement(u), capacity = pc.capacity, balance_opt = pc.getBalanceSameSideAs(u) ) - def apply(u: ChannelUpdate, pc: PrivateChannel): GraphEdge = GraphEdge( + def apply(u: ChannelUpdate, pc: PrivateChannel): ActiveEdge = ActiveEdge( desc = ChannelDesc(u, pc), params = HopRelayParams.FromAnnouncement(u), capacity = pc.capacity, balance_opt = pc.getBalanceSameSideAs(u) ) - def apply(e: Invoice.ExtraEdge): GraphEdge = { - val maxBtc = 21e6.btc - GraphEdge( - desc = ChannelDesc(e.shortChannelId, e.sourceNodeId, e.targetNodeId), - params = HopRelayParams.FromHint(e), - // Routing hints don't include the channel's capacity, so we assume it's big enough. - capacity = maxBtc.toSatoshi, - balance_opt = None, - ) + def apply(e: Invoice.ExtraEdge): ActiveEdge = { + val maxBtc = 21e6.btc + ActiveEdge( + desc = ChannelDesc(e.shortChannelId, e.sourceNodeId, e.targetNodeId), + params = HopRelayParams.FromHint(e), + // Routing hints don't include the channel's capacity, so we assume it's big enough. + capacity = maxBtc.toSatoshi, + balance_opt = None, + ) } } + case class DisabledEdge(desc: ChannelDesc, capacity: Satoshi) extends GraphEdge + + case class Vertex(features: Features[NodeFeature], incomingEdges: Map[ChannelDesc, GraphEdge]) + /** A graph data structure that uses an adjacency list, stores the incoming edges of the neighbors */ - case class DirectedGraph(private val vertices: Map[PublicKey, List[GraphEdge]]) { + case class DirectedGraph(private val vertices: Map[PublicKey, Vertex]) { - def addEdges(edges: Iterable[GraphEdge]): DirectedGraph = edges.foldLeft(this)((acc, edge) => acc.addEdge(edge)) + def addEdges(edges: Iterable[ActiveEdge]): DirectedGraph = edges.foldLeft(this)((acc, edge) => acc.addEdge(edge)) /** * Adds an edge to the graph. If one of the two vertices is not found it will be created. @@ -488,57 +648,75 @@ object Graph { * @param edge the edge that is going to be added to the graph * @return a new graph containing this edge */ - def addEdge(edge: GraphEdge): DirectedGraph = { - val vertexIn = edge.desc.a - val vertexOut = edge.desc.b - // the graph is allowed to have multiple edges between the same vertices but only one per channel - if (containsEdge(edge.desc)) { - removeEdge(edge.desc).addEdge(edge) // the recursive call will have the original params - } else { - val withVertices = addVertex(vertexIn).addVertex(vertexOut) - DirectedGraph(withVertices.vertices.updated(vertexOut, edge +: withVertices.vertices(vertexOut))) - } + def addEdge(edge: ActiveEdge): DirectedGraph = { + val vertexA = vertices.getOrElse(edge.desc.a, Vertex(Features.empty, Map.empty)) + val updatedVertexA = + if (vertexA.incomingEdges.contains(edge.desc.reversed)) { + vertexA + } else { + // If the reversed edge is not already in the graph, we add it disabled. + vertexA.copy(incomingEdges = vertexA.incomingEdges + (edge.desc.reversed -> DisabledEdge(edge.desc.reversed, edge.capacity))) + } + val vertexB = vertices.getOrElse(edge.desc.b, Vertex(Features.empty, Map.empty)) + val updatedVertexB = vertexB.copy(incomingEdges = vertexB.incomingEdges + (edge.desc -> edge)) + DirectedGraph(vertices.updated(edge.desc.a, updatedVertexA).updated(edge.desc.b, updatedVertexB)) } /** - * Removes the edge corresponding to the given pair channel-desc/channel-update, + * Disables the edge corresponding to the given channel-desc. * NB: this operation does NOT remove any vertex * - * @param desc the channel description associated to the edge that will be removed - * @return a new graph without this edge + * @param desc the channel description associated to the edge that will be disabled + * @return a new graph with this edge disabled */ - def removeEdge(desc: ChannelDesc): DirectedGraph = { - if (containsEdge(desc)) { - DirectedGraph(vertices.updated(desc.b, vertices(desc.b).filterNot(_.desc == desc))) - } else { - this - } + def disableEdge(desc: ChannelDesc): DirectedGraph = { + val updatedVertices = vertices.updatedWith(desc.b)(_.map(vertex => { + val updatedEdges = vertex.incomingEdges.updatedWith(desc)(_.map(edge => DisabledEdge(desc, edge.capacity))) + vertex.copy(incomingEdges = updatedEdges) + })) + DirectedGraph(updatedVertices) + } + + /** + * Removes the edges corresponding to the given channel-desc, + * both edges (corresponding to both directions) are removed. + * NB: this operation does NOT remove any vertex + * + * @param desc the channel description for the channel to remove + * @return a new graph without this channel + */ + def removeChannel(desc: ChannelDesc): DirectedGraph = { + val updatedVertices = + vertices + .updatedWith(desc.b)(_.map(vertexB => vertexB.copy(incomingEdges = vertexB.incomingEdges - desc))) + .updatedWith(desc.a)(_.map(vertexA => vertexA.copy(incomingEdges = vertexA.incomingEdges - desc.reversed))) + DirectedGraph(updatedVertices) } - def removeEdges(descList: Iterable[ChannelDesc]): DirectedGraph = { - descList.foldLeft(this)((acc, edge) => acc.removeEdge(edge)) + def removeChannels(descList: Iterable[ChannelDesc]): DirectedGraph = { + descList.foldLeft(this)((acc, edge) => acc.removeChannel(edge)) } /** * @return For edges to be considered equal they must have the same in/out vertices AND same shortChannelId */ - def getEdge(edge: GraphEdge): Option[GraphEdge] = getEdge(edge.desc) + def getEdge(edge: ActiveEdge): Option[ActiveEdge] = getEdge(edge.desc) - def getEdge(desc: ChannelDesc): Option[GraphEdge] = { - vertices.get(desc.b).flatMap { adj => - adj.find(e => e.desc.shortChannelId == desc.shortChannelId && e.desc.a == desc.a) + def getEdge(desc: ChannelDesc): Option[ActiveEdge] = + vertices.get(desc.b).flatMap(_.incomingEdges.get(desc)) match { + case Some(e: ActiveEdge) => Some(e) + case None | Some(_: DisabledEdge) => None } - } /** * @param keyA the key associated with the starting vertex * @param keyB the key associated with the ending vertex * @return all the edges going from keyA --> keyB (there might be more than one if there are multiple channels) */ - def getEdgesBetween(keyA: PublicKey, keyB: PublicKey): Seq[GraphEdge] = { + def getEdgesBetween(keyA: PublicKey, keyB: PublicKey): Iterable[ActiveEdge] = { vertices.get(keyB) match { - case None => Seq.empty - case Some(adj) => adj.filter(e => e.desc.a == keyA) + case None => Iterable.empty + case Some(vertex) => vertex.incomingEdges.collect { case (desc, edge: ActiveEdge) if desc.a == keyA => edge } } } @@ -546,33 +724,41 @@ object Graph { * @param keyB the key associated with the target vertex * @return all edges incoming to that vertex */ - def getIncomingEdgesOf(keyB: PublicKey): Seq[GraphEdge] = { - vertices.getOrElse(keyB, List.empty) + def getIncomingEdgesOf(keyB: PublicKey): Iterable[GraphEdge] = { + vertices.get(keyB).map(_.incomingEdges.values).getOrElse(Iterable.empty) } + def getVertexFeatures(key: PublicKey): Features[NodeFeature] = vertices.get(key).map(_.features).getOrElse(Features.empty) + /** * Removes a vertex and all its associated edges (both incoming and outgoing) */ def removeVertex(key: PublicKey): DirectedGraph = { - DirectedGraph(removeEdges(getIncomingEdgesOf(key).map(_.desc)).vertices - key) + val channels = getIncomingEdgesOf(key).map(_.desc) + DirectedGraph(removeChannels(channels).vertices - key) } + def removeVertices(nodeIds: Iterable[PublicKey]): DirectedGraph = nodeIds.foldLeft(this)((acc, nodeId) => acc.removeVertex(nodeId)) + /** - * Adds a new vertex to the graph, starting with no edges + * Adds a new vertex to the graph, starting with no edges. + * Or update the node features if the vertex is already present. */ - def addVertex(key: PublicKey): DirectedGraph = { - vertices.get(key) match { - case None => DirectedGraph(vertices + (key -> List.empty)) - case _ => this - } + def addOrUpdateVertex(ann: NodeAnnouncement): DirectedGraph = { + DirectedGraph(vertices.updatedWith(ann.nodeId) { + case Some(vertex) => Some(vertex.copy(features = ann.features.nodeAnnouncementFeatures())) + case None => Some(Vertex(ann.features.nodeAnnouncementFeatures(), Map.empty)) + }) } + def addVertices(announcements: Iterable[NodeAnnouncement]): DirectedGraph = announcements.foldLeft(this)((acc, ann) => acc.addOrUpdateVertex(ann)) + /** * Note this operation will traverse all edges in the graph (expensive) * * @return a list of the outgoing edges of the given vertex. If the vertex doesn't exists an empty list is returned. */ - def edgesOf(key: PublicKey): Seq[GraphEdge] = { + def edgesOf(key: PublicKey): Seq[ActiveEdge] = { edgeSet().filter(_.desc.a == key).toSeq } @@ -584,7 +770,7 @@ object Graph { /** * @return an iterator of all the edges in this graph */ - def edgeSet(): Iterable[GraphEdge] = vertices.values.flatten + def edgeSet(): Iterable[ActiveEdge] = vertices.values.flatMap(_.incomingEdges.collect { case (_, edge: ActiveEdge) => edge }) /** * @return true if this graph contain a vertex with this key, false otherwise @@ -592,18 +778,15 @@ object Graph { def containsVertex(key: PublicKey): Boolean = vertices.contains(key) /** - * @return true if this edge desc is in the graph. For edges to be considered equal they must have the same in/out vertices AND same shortChannelId + * @return true if this edge desc is in the graph and not disabled. For edges to be considered equal they must have the same in/out vertices AND same shortChannelId */ def containsEdge(desc: ChannelDesc): Boolean = { vertices.get(desc.b) match { case None => false - case Some(adj) => adj.exists(neighbor => neighbor.desc.shortChannelId == desc.shortChannelId && neighbor.desc.a == desc.a) - } - } - - def prettyPrint(): String = { - vertices.foldLeft("") { case (acc, (vertex, adj)) => - acc + s"[${vertex.toString().take(5)}]: ${adj.map("-> " + _.desc.b.toString().take(5))} \n" + case Some(vertex) => vertex.incomingEdges.get(desc) match { + case None | Some(_: DisabledEdge) => false + case Some(_: ActiveEdge) => true + } } } } @@ -611,10 +794,10 @@ object Graph { object DirectedGraph { // @formatter:off - def apply(): DirectedGraph = new DirectedGraph(Map()) - def apply(key: PublicKey): DirectedGraph = new DirectedGraph(Map(key -> List.empty)) - def apply(edge: GraphEdge): DirectedGraph = DirectedGraph().addEdge(edge) - def apply(edges: Seq[GraphEdge]): DirectedGraph = DirectedGraph().addEdges(edges) + def apply(): DirectedGraph = new DirectedGraph(Map.empty) + def apply(key: PublicKey): DirectedGraph = new DirectedGraph(Map(key -> Vertex(Features.empty, Map.empty))) + def apply(edge: ActiveEdge): DirectedGraph = DirectedGraph().addEdge(edge) + def apply(edges: Seq[ActiveEdge]): DirectedGraph = DirectedGraph().addEdges(edges) // @formatter:on /** @@ -625,27 +808,16 @@ object Graph { * * @param channels map of all known public channels in the network. */ - def makeGraph(channels: SortedMap[RealShortChannelId, PublicChannel]): DirectedGraph = { - // initialize the map with the appropriate size to avoid resizing during the graph initialization - val mutableMap = new mutable.HashMap[PublicKey, List[GraphEdge]](initialCapacity = channels.size + 1, mutable.HashMap.defaultLoadFactor) - - // add all the vertices and edges in one go - channels.values.foreach { channel => - channel.update_1_opt.foreach(u1 => addToMap(GraphEdge(u1, channel))) - channel.update_2_opt.foreach(u2 => addToMap(GraphEdge(u2, channel))) - } - - def addToMap(edge: GraphEdge): Unit = { - mutableMap.put(edge.desc.b, edge +: mutableMap.getOrElse(edge.desc.b, List.empty[GraphEdge])) - if (!mutableMap.contains(edge.desc.a)) { - mutableMap += edge.desc.a -> List.empty[GraphEdge] - } - } + def makeGraph(channels: Map[RealShortChannelId, PublicChannel], nodes: Seq[NodeAnnouncement]): DirectedGraph = { + val edges = channels.values.flatMap(channel => Seq( + channel.update_1_opt.collect { case u1 if u1.channelFlags.isEnabled => ActiveEdge(u1, channel) }, + channel.update_2_opt.collect { case u2 if u2.channelFlags.isEnabled => ActiveEdge(u2, channel) }, + ).flatten) - new DirectedGraph(mutableMap.toMap) + DirectedGraph().addVertices(nodes).addEdges(edges) } - def graphEdgeToHop(graphEdge: GraphEdge): ChannelHop = ChannelHop(graphEdge.desc.shortChannelId, graphEdge.desc.a, graphEdge.desc.b, graphEdge.params) + def graphEdgeToHop(graphEdge: ActiveEdge): ChannelHop = ChannelHop(graphEdge.desc.shortChannelId, graphEdge.desc.a, graphEdge.desc.b, graphEdge.params) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala index 40e6f61a5b..14e27fbe00 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/RouteCalculation.scala @@ -22,10 +22,11 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.Logs.LogCategory import fr.acinq.eclair._ +import fr.acinq.eclair.message.SendingMessage import fr.acinq.eclair.payment.send._ import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop -import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} -import fr.acinq.eclair.router.Graph.{InfiniteLoop, NegativeProbability, RichWeight} +import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, ActiveEdge} +import fr.acinq.eclair.router.Graph.{InfiniteLoop, MessagePath, NegativeProbability, RichWeight} import fr.acinq.eclair.router.Monitoring.{Metrics, Tags} import fr.acinq.eclair.router.Router._ import kamon.tag.TagSet @@ -36,7 +37,7 @@ import scala.util.{Failure, Random, Success, Try} object RouteCalculation { - private def getEdgeRelayScid(d: Data, localNodeId: PublicKey, e: GraphEdge): ShortChannelId = { + private def getEdgeRelayScid(d: Data, localNodeId: PublicKey, e: ActiveEdge): ShortChannelId = { if (e.desc.b == localNodeId) { // We are the destination of that edge: local graph edges always use either the local alias or the real scid. // We want to use the remote alias when available, because our peer won't understand our local alias. @@ -65,8 +66,8 @@ object RouteCalculation { paymentHash_opt = fr.paymentContext.map(_.paymentHash))) { implicit val sender: ActorRef = ctx.self // necessary to preserve origin when sending messages to other actors - val extraEdges = fr.extraEdges.map(GraphEdge(_)) - val g = extraEdges.foldLeft(d.graphWithBalances.graph) { case (g: DirectedGraph, e: GraphEdge) => g.addEdge(e) } + val extraEdges = fr.extraEdges.map(ActiveEdge(_)) + val g = extraEdges.foldLeft(d.graphWithBalances.graph) { case (g: DirectedGraph, e: ActiveEdge) => g.addEdge(e) } fr.route match { case PredefinedNodeRoute(amount, hops, maxFee_opt) => @@ -128,7 +129,7 @@ object RouteCalculation { * * The routes found must then be post-processed by calling [[addFinalHop]]. */ - private def computeTarget(r: RouteRequest, ignoredEdges: Set[ChannelDesc]): (PublicKey, MilliSatoshi, MilliSatoshi, Set[GraphEdge]) = { + private def computeTarget(r: RouteRequest, ignoredEdges: Set[ChannelDesc]): (PublicKey, MilliSatoshi, MilliSatoshi, Set[ActiveEdge]) = { val pendingAmount = r.pendingPayments.map(_.amount).sum val totalMaxFee = r.routeParams.getMaxFee(r.target.totalAmount) val pendingChannelFee = r.pendingPayments.map(_.channelFee(r.routeParams.includeLocalChannelCost)).sum @@ -139,7 +140,7 @@ object RouteCalculation { val maxFee = totalMaxFee - pendingChannelFee val extraEdges = recipient.extraEdges .filter(_.sourceNodeId != r.source) // we ignore routing hints for our own channels, we have more accurate information - .map(GraphEdge(_)) + .map(ActiveEdge(_)) .filterNot(e => ignoredEdges.contains(e.desc)) .toSet (targetNodeId, amountToSend, maxFee, extraEdges) @@ -155,7 +156,7 @@ object RouteCalculation { .map(_.copy(targetNodeId = targetNodeId)) .filterNot(e => ignoredEdges.exists(_.shortChannelId == e.shortChannelId)) // For blinded routes, the maximum htlc field is used to indicate the maximum amount that can be sent through the route. - .map(e => GraphEdge(e).copy(balance_opt = e.htlcMaximum_opt)) + .map(e => ActiveEdge(e).copy(balance_opt = e.htlcMaximum_opt)) .toSet val amountToSend = recipient.totalAmount - pendingAmount // When we are the introduction node and includeLocalChannelCost is false, we cannot easily remove the fee for @@ -238,6 +239,23 @@ object RouteCalculation { } } + def handleMessageRouteRequest(d: Data, currentBlockHeight: BlockHeight, r: MessageRouteRequest, routeParams: MessageRouteParams)(implicit ctx: ActorContext, log: DiagnosticLoggingAdapter): Data = { + val boundaries: MessagePath.RichWeight => Boolean = { weight => + weight.length <= routeParams.maxRouteLength && weight.length <= ROUTE_MAX_LENGTH + } + log.info("finding route for onion messages {} -> {}", r.source, r.target) + Graph.MessagePath.dijkstraMessagePath(d.graphWithBalances.graph, r.source, r.target, r.ignoredNodes, boundaries, currentBlockHeight, routeParams.ratios) match { + case Some(path) => + val intermediateNodes = path.map(_.a).drop(1) + log.info("found route for onion messages {}", (r.source +: intermediateNodes :+ r.target).mkString(" -> ")) + r.replyTo ! MessageRoute(intermediateNodes, r.target) + case None => + log.info("no route found for onion messages {} -> {}", r.source, r.target) + r.replyTo ! MessageRouteNotFound(r.target) + } + d + } + /** This method is used after a payment failed, and we want to exclude some nodes that we know are failing */ def getIgnoredChannelDesc(channels: Map[ShortChannelId, PublicChannel], ignoreNodes: Set[PublicKey]): Iterable[ChannelDesc] = { val desc = if (ignoreNodes.isEmpty) { @@ -282,7 +300,7 @@ object RouteCalculation { amount: MilliSatoshi, maxFee: MilliSatoshi, numRoutes: Int, - extraEdges: Set[GraphEdge] = Set.empty, + extraEdges: Set[ActiveEdge] = Set.empty, ignoredEdges: Set[ChannelDesc] = Set.empty, ignoredVertices: Set[PublicKey] = Set.empty, routeParams: RouteParams, @@ -300,7 +318,7 @@ object RouteCalculation { amount: MilliSatoshi, maxFee: MilliSatoshi, numRoutes: Int, - extraEdges: Set[GraphEdge] = Set.empty, + extraEdges: Set[ActiveEdge] = Set.empty, ignoredEdges: Set[ChannelDesc] = Set.empty, ignoredVertices: Set[PublicKey] = Set.empty, routeParams: RouteParams, @@ -357,7 +375,7 @@ object RouteCalculation { targetNodeId: PublicKey, amount: MilliSatoshi, maxFee: MilliSatoshi, - extraEdges: Set[GraphEdge] = Set.empty, + extraEdges: Set[ActiveEdge] = Set.empty, ignoredEdges: Set[ChannelDesc] = Set.empty, ignoredVertices: Set[PublicKey] = Set.empty, pendingHtlcs: Seq[Route] = Nil, @@ -381,7 +399,7 @@ object RouteCalculation { targetNodeId: PublicKey, amount: MilliSatoshi, maxFee: MilliSatoshi, - extraEdges: Set[GraphEdge] = Set.empty, + extraEdges: Set[ActiveEdge] = Set.empty, ignoredEdges: Set[ChannelDesc] = Set.empty, ignoredVertices: Set[PublicKey] = Set.empty, pendingHtlcs: Seq[Route] = Nil, @@ -394,8 +412,8 @@ object RouteCalculation { val directChannels = g.getEdgesBetween(localNodeId, targetNodeId).collect { // We should always have balance information available for local channels. // NB: htlcMinimumMsat is set by our peer and may be 0 msat (even though it's not recommended). - case GraphEdge(_, params, _, Some(balance)) => DirectChannel(balance, balance <= 0.msat || balance < params.htlcMinimum) - } + case ActiveEdge(_, params, _, Some(balance)) => DirectChannel(balance, balance <= 0.msat || balance < params.htlcMinimum) + }.toSeq // If we have direct channels to the target, we can use them all. // We also count empty channels, which allows replacing them with a non-direct route (multiple hops). val numRoutes = routeParams.mpp.maxParts.max(directChannels.length) @@ -447,7 +465,7 @@ object RouteCalculation { } /** Compute the maximum amount that we can send through the given route. */ - private def computeRouteMaxAmount(route: Seq[GraphEdge], usedCapacity: mutable.Map[ShortChannelId, MilliSatoshi]): Route = { + private def computeRouteMaxAmount(route: Seq[ActiveEdge], usedCapacity: mutable.Map[ShortChannelId, MilliSatoshi]): Route = { val firstHopMaxAmount = route.head.maxHtlcAmount(usedCapacity.getOrElse(route.head.desc.shortChannelId, 0 msat)) val amount = route.drop(1).foldLeft(firstHopMaxAmount) { case (amount, edge) => // We compute fees going forward instead of backwards. That means we will slightly overestimate the fees of some diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index 7950dec574..e5ff4d4654 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -37,8 +37,8 @@ import fr.acinq.eclair.payment.relay.Relayer import fr.acinq.eclair.payment.send.Recipient import fr.acinq.eclair.payment.{Bolt11Invoice, Invoice} import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes -import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph -import fr.acinq.eclair.router.Graph.{HeuristicsConstants, WeightRatios} +import fr.acinq.eclair.router.Graph.GraphStructure.{ActiveEdge, DirectedGraph} +import fr.acinq.eclair.router.Graph.{HeuristicsConstants, MessagePath, WeightRatios} import fr.acinq.eclair.router.Monitoring.Metrics import fr.acinq.eclair.wire.protocol._ @@ -74,17 +74,17 @@ class Router(val nodeParams: NodeParams, watcher: typed.ActorRef[ZmqWatcher.Comm { log.info("loading network announcements from db...") val (pruned, channels) = db.listChannels().partition { case (_, pc) => pc.isStale(nodeParams.currentBlockHeight) } - val nodes = db.listNodes().map(n => n.nodeId -> n).toMap + val nodes = db.listNodes() Metrics.Nodes.withoutTags().update(nodes.size) Metrics.Channels.withoutTags().update(channels.size) log.info("loaded from db: channels={} nodes={}", channels.size, nodes.size) log.info("{} pruned channels at blockHeight={}", pruned.size, nodeParams.currentBlockHeight) // this will be used to calculate routes - val graph = DirectedGraph.makeGraph(channels) + val graph = DirectedGraph.makeGraph(channels, nodes) // send events for remaining channels/nodes context.system.eventStream.publish(ChannelsDiscovered(channels.values.map(pc => SingleChannelDiscovered(pc.ann, pc.capacity, pc.update_1_opt, pc.update_2_opt)))) context.system.eventStream.publish(ChannelUpdatesReceived(channels.values.flatMap(pc => pc.update_1_opt ++ pc.update_2_opt ++ Nil))) - context.system.eventStream.publish(NodesDiscovered(nodes.values)) + context.system.eventStream.publish(NodesDiscovered(nodes)) // watch the funding tx of all these channels // note: some of them may already have been spent, in that case we will receive the watch event immediately @@ -105,7 +105,7 @@ class Router(val nodeParams: NodeParams, watcher: typed.ActorRef[ZmqWatcher.Comm log.info(s"initialization completed, ready to process messages") Try(initialized.map(_.success(Done))) val data = Data( - nodes, channels, pruned, + nodes.map(n => n.nodeId -> n).toMap, channels, pruned, Stash(Map.empty, Map.empty), rebroadcast = Rebroadcast(channels = Map.empty, updates = Map.empty, nodes = Map.empty), awaiting = Map.empty, @@ -205,7 +205,7 @@ class Router(val nodeParams: NodeParams, watcher: typed.ActorRef[ZmqWatcher.Comm d.nodes.get(nodeId) match { case Some(announcement) => // This only provides a lower bound on the number of channels this peer has: disabled channels will be filtered out. - val activeChannels = d.graphWithBalances.graph.getIncomingEdgesOf(nodeId) + val activeChannels = d.graphWithBalances.graph.getIncomingEdgesOf(nodeId).collect{case e: ActiveEdge => e} val totalCapacity = activeChannels.map(_.capacity).sum replyTo ! PublicNode(announcement, activeChannels.size, totalCapacity) case None => @@ -240,6 +240,9 @@ class Router(val nodeParams: NodeParams, watcher: typed.ActorRef[ZmqWatcher.Comm case Event(r: RouteRequest, d) => stay() using RouteCalculation.handleRouteRequest(d, nodeParams.currentBlockHeight, r) + case Event(r: MessageRouteRequest, d) => + stay() using RouteCalculation.handleMessageRouteRequest(d, nodeParams.currentBlockHeight, r, nodeParams.routerConf.messageRouteParams) + // Warning: order matters here, this must be the first match for HasChainHash messages ! case Event(PeerRoutingMessage(_, _, routingMessage: HasChainHash), _) if routingMessage.chainHash != nodeParams.chainHash => sender() ! TransportHandler.ReadAck(routingMessage) @@ -356,13 +359,16 @@ object Router { channelRangeChunkSize: Int, channelQueryChunkSize: Int, pathFindingExperimentConf: PathFindingExperimentConf, + messageRouteParams: MessageRouteParams, balanceEstimateHalfLife: FiniteDuration) { require(channelRangeChunkSize <= Sync.MAXIMUM_CHUNK_SIZE, "channel range chunk size exceeds the size of a lightning message") require(channelQueryChunkSize <= Sync.MAXIMUM_CHUNK_SIZE, "channel query chunk size exceeds the size of a lightning message") } // @formatter:off - case class ChannelDesc private(shortChannelId: ShortChannelId, a: PublicKey, b: PublicKey) + case class ChannelDesc private(shortChannelId: ShortChannelId, a: PublicKey, b: PublicKey){ + def reversed: ChannelDesc = ChannelDesc(shortChannelId, b, a) + } object ChannelDesc { def apply(u: ChannelUpdate, ann: ChannelAnnouncement): ChannelDesc = { // the least significant bit tells us if it is node1 or node2 @@ -555,6 +561,8 @@ object Router { } } + case class MessageRouteParams(maxRouteLength: Int, ratios: MessagePath.WeightRatios) + case class Ignore(nodes: Set[PublicKey], channels: Set[ChannelDesc]) { // @formatter:off def +(ignoreNode: PublicKey): Ignore = copy(nodes = nodes + ignoreNode) @@ -581,6 +589,17 @@ object Router { extraEdges: Seq[ExtraEdge] = Nil, paymentContext: Option[PaymentContext] = None) + case class MessageRouteRequest(replyTo: typed.ActorRef[MessageRouteResponse], + source: PublicKey, + target: PublicKey, + ignoredNodes: Set[PublicKey]) + + // @formatter:off + sealed trait MessageRouteResponse { def target: PublicKey } + case class MessageRoute(intermediateNodes: Seq[PublicKey], target: PublicKey) extends MessageRouteResponse + case class MessageRouteNotFound(target: PublicKey) extends MessageRouteResponse + // @formatter:on + /** * Useful for having appropriate logging context at hand when finding routes */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala index a6d228fc9f..8500491699 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/StaleChannels.scala @@ -44,8 +44,8 @@ object StaleChannels { ctx.system.eventStream.publish(ChannelLost(shortChannelId)) } - val staleChannelsToRemove = staleChannels.flatMap(pc => Seq(ChannelDesc(pc.ann.shortChannelId, pc.ann.nodeId1, pc.ann.nodeId2), ChannelDesc(pc.ann.shortChannelId, pc.ann.nodeId2, pc.ann.nodeId1))) - val graphWithBalances1 = d.graphWithBalances.removeEdges(staleChannelsToRemove) + val staleChannelsToRemove = staleChannels.map(pc => ChannelDesc(pc.ann.shortChannelId, pc.ann.nodeId1, pc.ann.nodeId2)) + val graphWithBalances1 = d.graphWithBalances.removeChannels(staleChannelsToRemove) staleNodes.foreach { nodeId => log.info("pruning nodeId={} (stale)", nodeId) db.removeNode(nodeId) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala index d5c8e70220..8f10738095 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Validation.scala @@ -28,7 +28,7 @@ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{UtxoStatus, ValidateReque import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.db.NetworkDb -import fr.acinq.eclair.router.Graph.GraphStructure.GraphEdge +import fr.acinq.eclair.router.Graph.GraphStructure.ActiveEdge import fr.acinq.eclair.router.Monitoring.Metrics import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.transactions.Scripts @@ -179,11 +179,10 @@ object Validation { // mutable variable is simpler here var graph = d.graphWithBalances // remove previous private edges - pubChan.update_1_opt.foreach(u => graph = graph.removeEdge(ChannelDesc(u, privateChannel))) - pubChan.update_2_opt.foreach(u => graph = graph.removeEdge(ChannelDesc(u, privateChannel))) + graph = graph.removeChannel(ChannelDesc(privateChannel.shortIds.localAlias, privateChannel.nodeId1, privateChannel.nodeId2)) // add new public edges - pubChan.update_1_opt.foreach(u => graph = graph.addEdge(GraphEdge(u, pubChan))) - pubChan.update_2_opt.foreach(u => graph = graph.addEdge(GraphEdge(u, pubChan))) + pubChan.update_1_opt.foreach(u => graph = graph.addEdge(ActiveEdge(u, pubChan))) + pubChan.update_2_opt.foreach(u => graph = graph.addEdge(ActiveEdge(u, pubChan))) graph case None => d.graphWithBalances } @@ -228,8 +227,7 @@ object Validation { db.removeChannel(shortChannelId) // NB: this also removes channel updates // we also need to remove updates from the graph val graphWithBalances1 = d.graphWithBalances - .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId1, lostChannel.nodeId2)) - .removeEdge(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId2, lostChannel.nodeId1)) + .removeChannel(ChannelDesc(lostChannel.shortChannelId, lostChannel.nodeId1, lostChannel.nodeId2)) // we notify front nodes ctx.system.eventStream.publish(ChannelLost(shortChannelId)) lostNodes.foreach { @@ -282,13 +280,13 @@ object Validation { remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(n))) ctx.system.eventStream.publish(NodeUpdated(n)) db.updateNode(n) - d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes ++ rebroadcastNode)) + d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes ++ rebroadcastNode), graphWithBalances = d.graphWithBalances.addOrUpdateVertex(n)) } else if (d.channels.values.exists(c => isRelatedTo(c.ann, n.nodeId))) { log.debug("added node nodeId={}", n.nodeId) remoteOrigins.foreach(sendDecision(_, GossipDecision.Accepted(n))) ctx.system.eventStream.publish(NodesDiscovered(n :: Nil)) db.addNode(n) - d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes ++ rebroadcastNode)) + d.copy(nodes = d.nodes + (n.nodeId -> n), rebroadcast = d.rebroadcast.copy(nodes = d.rebroadcast.nodes ++ rebroadcastNode), graphWithBalances = d.graphWithBalances.addOrUpdateVertex(n)) } else if (d.awaiting.keys.exists(c => isRelatedTo(c, n.nodeId))) { log.debug("stashing {}", n) d.copy(stash = d.stash.copy(nodes = d.stash.nodes + (n -> origins))) @@ -325,7 +323,7 @@ object Validation { case Left(_) => // NB: we update the channels because the balances may have changed even if the channel_update is the same. val pc1 = pc.applyChannelUpdate(update) - val graphWithBalances1 = d.graphWithBalances.addEdge(GraphEdge(u, pc1)) + val graphWithBalances1 = d.graphWithBalances.addEdge(ActiveEdge(u, pc1)) d.copy(rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins1)), channels = d.channels + (pc.shortChannelId -> pc1), graphWithBalances = graphWithBalances1) case Right(_) => d.copy(rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins1))) @@ -341,7 +339,7 @@ object Validation { case Left(_) => // NB: we update the graph because the balances may have changed even if the channel_update is the same. val pc1 = pc.applyChannelUpdate(update) - val graphWithBalances1 = d.graphWithBalances.addEdge(GraphEdge(u, pc1)) + val graphWithBalances1 = d.graphWithBalances.addEdge(ActiveEdge(u, pc1)) d.copy(channels = d.channels + (pc.shortChannelId -> pc1), graphWithBalances = graphWithBalances1) case Right(_) => d } @@ -359,10 +357,10 @@ object Validation { val pc1 = pc.applyChannelUpdate(update) val graphWithBalances1 = if (u.channelFlags.isEnabled) { update.left.foreach(_ => log.info("added local shortChannelId={} public={} to the network graph", u.shortChannelId, publicChannel)) - d.graphWithBalances.addEdge(GraphEdge(u, pc1)) + d.graphWithBalances.addEdge(ActiveEdge(u, pc1)) } else { - update.left.foreach(_ => log.info("removed local shortChannelId={} public={} from the network graph", u.shortChannelId, publicChannel)) - d.graphWithBalances.removeEdge(ChannelDesc(u, pc1.ann)) + update.left.foreach(_ => log.info("disabled local shortChannelId={} public={} in the network graph", u.shortChannelId, publicChannel)) + d.graphWithBalances.disableEdge(ChannelDesc(u, pc1.ann)) } d.copy(channels = d.channels + (pc.shortChannelId -> pc1), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins)), graphWithBalances = graphWithBalances1) } else { @@ -372,7 +370,7 @@ object Validation { db.updateChannel(u) // we also need to update the graph val pc1 = pc.applyChannelUpdate(update) - val graphWithBalances1 = d.graphWithBalances.addEdge(GraphEdge(u, pc1)) + val graphWithBalances1 = d.graphWithBalances.addEdge(ActiveEdge(u, pc1)) update.left.foreach(_ => log.info("added local shortChannelId={} public={} to the network graph", u.shortChannelId, publicChannel)) d.copy(channels = d.channels + (pc.shortChannelId -> pc1), rebroadcast = d.rebroadcast.copy(updates = d.rebroadcast.updates + (u -> origins)), graphWithBalances = graphWithBalances1) } @@ -399,10 +397,10 @@ object Validation { val pc1 = pc.applyChannelUpdate(update) val graphWithBalances1 = if (u.channelFlags.isEnabled) { update.left.foreach(_ => log.info("added local channelId={} public={} to the network graph", pc.channelId, publicChannel)) - d.graphWithBalances.addEdge(GraphEdge(u, pc1)) + d.graphWithBalances.addEdge(ActiveEdge(u, pc1)) } else { - update.left.foreach(_ => log.info("removed local channelId={} public={} from the network graph", pc.channelId, publicChannel)) - d.graphWithBalances.removeEdge(ChannelDesc(u, pc1)) + update.left.foreach(_ => log.info("disabled local channelId={} public={} in the network graph", pc.channelId, publicChannel)) + d.graphWithBalances.disableEdge(ChannelDesc(u, pc1)) } d.copy(privateChannels = d.privateChannels + (pc.channelId -> pc1), graphWithBalances = graphWithBalances1) } else { @@ -411,7 +409,7 @@ object Validation { ctx.system.eventStream.publish(ChannelUpdatesReceived(u :: Nil)) // we also need to update the graph val pc1 = pc.applyChannelUpdate(update) - val graphWithBalances1 = d.graphWithBalances.addEdge(GraphEdge(u, pc1)) + val graphWithBalances1 = d.graphWithBalances.addEdge(ActiveEdge(u, pc1)) update.left.foreach(_ => log.info("added local channelId={} public={} to the network graph", pc.channelId, publicChannel)) d.copy(privateChannels = d.privateChannels + (pc.channelId -> pc1), graphWithBalances = graphWithBalances1) } @@ -461,7 +459,7 @@ object Validation { ctx.system.eventStream.publish(ChannelUpdatesReceived(channelUpdates.keys)) // We update the graph. val graphWithBalances1 = channelUpdates.keys.foldLeft(d.graphWithBalances) { - case (currentGraph, currentUpdate) if currentUpdate.channelFlags.isEnabled => currentGraph.addEdge(GraphEdge(currentUpdate, pc1)) + case (currentGraph, currentUpdate) if currentUpdate.channelFlags.isEnabled => currentGraph.addEdge(ActiveEdge(currentUpdate, pc1)) case (currentGraph, _) => currentGraph } d.copy(channels = d.channels + (pc1.shortChannelId -> pc1), prunedChannels = d.prunedChannels - pc1.shortChannelId, rebroadcast = rebroadcast1, graphWithBalances = graphWithBalances1) @@ -556,8 +554,7 @@ object Validation { log.info("removing private local channel and channel_update for channelId={} localAlias={}", channelId, localAlias) // we remove the corresponding updates from the graph val graphWithBalances1 = d.graphWithBalances - .removeEdge(ChannelDesc(localAlias, localNodeId, remoteNodeId)) - .removeEdge(ChannelDesc(localAlias, remoteNodeId, localNodeId)) + .removeChannel(ChannelDesc(localAlias, localNodeId, remoteNodeId)) // and we remove the channel and channel_update from our state d.copy(privateChannels = d.privateChannels - channelId, scid2PrivateChannels = scid2PrivateChannels1, graphWithBalances = graphWithBalances1) } else { @@ -571,7 +568,7 @@ object Validation { val pc1 = pc.updateBalances(e.commitments) log.debug("public channel balance updated: {}", pc1) val update_opt = if (e.commitments.localNodeId == pc1.ann.nodeId1) pc1.update_1_opt else pc1.update_2_opt - val graphWithBalances1 = update_opt.map(u => d.graphWithBalances.addEdge(GraphEdge(u, pc1))).getOrElse(d.graphWithBalances) + val graphWithBalances1 = update_opt.map(u => d.graphWithBalances.addEdge(ActiveEdge(u, pc1))).getOrElse(d.graphWithBalances) (d.channels + (pc.ann.shortChannelId -> pc1), graphWithBalances1) case None => (d.channels, d.graphWithBalances) @@ -581,7 +578,7 @@ object Validation { val pc1 = pc.updateBalances(e.commitments) log.debug("private channel balance updated: {}", pc1) val update_opt = if (e.commitments.localNodeId == pc1.nodeId1) pc1.update_1_opt else pc1.update_2_opt - val graphWithBalances2 = update_opt.map(u => graphWithBalances1.addEdge(GraphEdge(u, pc1))).getOrElse(graphWithBalances1) + val graphWithBalances2 = update_opt.map(u => graphWithBalances1.addEdge(ActiveEdge(u, pc1))).getOrElse(graphWithBalances1) (d.privateChannels + (e.channelId -> pc1), graphWithBalances2) case None => (d.privateChannels, graphWithBalances1) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index 2f69cb7c70..fd634a0024 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -27,9 +27,9 @@ import fr.acinq.eclair.io.MessageRelay.RelayAll import fr.acinq.eclair.io.{OpenChannelInterceptor, PeerConnection} import fr.acinq.eclair.message.OnionMessages.OnionMessageConfig import fr.acinq.eclair.payment.relay.Relayer.{AsyncPaymentsParams, RelayFees, RelayParams} -import fr.acinq.eclair.router.Graph.WeightRatios +import fr.acinq.eclair.router.Graph.{MessagePath, WeightRatios} import fr.acinq.eclair.router.PathFindingExperimentConf -import fr.acinq.eclair.router.Router.{MultiPartParams, PathFindingConf, RouterConf, SearchBoundaries} +import fr.acinq.eclair.router.Router.{MessageRouteParams, MultiPartParams, PathFindingConf, RouterConf, SearchBoundaries} import fr.acinq.eclair.wire.protocol.{Color, EncodingType, NodeAddress, OnionRoutingPacket} import org.scalatest.Tag import scodec.bits.{ByteVector, HexStringSyntax} @@ -200,6 +200,7 @@ object TestConstants { ), experimentName = "alice-test-experiment", experimentPercentage = 100))), + messageRouteParams = MessageRouteParams(8, MessagePath.WeightRatios(0.7, 0.1, 0.2, 1.5)), balanceEstimateHalfLife = 1 day, ), socksProxy_opt = None, @@ -212,7 +213,8 @@ object TestConstants { blockchainWatchdogSources = blockchainWatchdogSources, onionMessageConfig = OnionMessageConfig( relayPolicy = RelayAll, - timeout = 1 minute, + minIntermediateHops = 9, + timeout = 200 millis, maxAttempts = 2, ), purgeInvoicesInterval = None @@ -356,6 +358,7 @@ object TestConstants { ), experimentName = "bob-test-experiment", experimentPercentage = 100))), + messageRouteParams = MessageRouteParams(9, MessagePath.WeightRatios(0.5, 0.2, 0.3, 3.14)), balanceEstimateHalfLife = 1 day, ), socksProxy_opt = None, @@ -368,7 +371,8 @@ object TestConstants { blockchainWatchdogSources = blockchainWatchdogSources, onionMessageConfig = OnionMessageConfig( relayPolicy = RelayAll, - timeout = 1 minute, + minIntermediateHops = 8, + timeout = 100 millis, maxAttempts = 2, ), purgeInvoicesInterval = None diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala index 72bb4935a2..fefdef46e2 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/MessageIntegrationSpec.scala @@ -58,7 +58,7 @@ class MessageIntegrationSpec extends IntegrationSpec { test("try to reach unknown node") { val alice = new EclairImpl(nodes("A")) val probe = TestProbe() - alice.sendOnionMessage(Nil, Left(nodes("B").nodeParams.nodeId), None, ByteVector.empty).pipeTo(probe.ref) + alice.sendOnionMessage(Some(Nil), Left(nodes("B").nodeParams.nodeId), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) val result = probe.expectMsgType[SendOnionMessageResponse] assert(!result.sent) } @@ -70,7 +70,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("B").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(Nil, Left(nodes("B").nodeParams.nodeId), None, ByteVector.empty).pipeTo(probe.ref) + alice.sendOnionMessage(Some(Nil), Left(nodes("B").nodeParams.nodeId), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) } @@ -85,7 +85,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val blindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(nodes("A").nodeParams.nodeId), IntermediateNode(nodes("B").nodeParams.nodeId), IntermediateNode(nodes("B").nodeParams.nodeId)), Recipient(nodes("B").nodeParams.nodeId, None)) assert(blindedRoute.introductionNodeId == nodes("A").nodeParams.nodeId) - alice.sendOnionMessage(Nil, Right(blindedRoute), None, ByteVector.empty).pipeTo(probe.ref) + alice.sendOnionMessage(None, Right(blindedRoute), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) } @@ -96,11 +96,13 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("B").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(Nil, Left(nodes("B").nodeParams.nodeId), Some(Seq(nodes("A").nodeParams.nodeId)), hex"3f00").pipeTo(probe.ref) + alice.sendOnionMessage(Some(Nil), Left(nodes("B").nodeParams.nodeId), expectsReply = true, hex"3f00").pipeTo(probe.ref) val recv = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) assert(recv.finalPayload.records.get[ReplyPath].nonEmpty) - bob.sendOnionMessage(Nil, Right(recv.finalPayload.records.get[ReplyPath].get.blindedRoute), None, hex"1d01ab") + val replyPath = recv.finalPayload.records.get[ReplyPath].get.blindedRoute + assert(replyPath.introductionNodeId == nodes("B").nodeParams.nodeId) + bob.sendOnionMessage(Some(Nil), Right(replyPath), expectsReply = false, hex"1d01ab") val res = probe.expectMsgType[SendOnionMessageResponse] assert(res.failureMessage.isEmpty) @@ -113,7 +115,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("A").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - bob.sendOnionMessage(Nil, Left(nodes("A").nodeParams.nodeId), Some(Seq(nodes("B").nodeParams.nodeId)), hex"3f00").pipeTo(probe.ref) + bob.sendOnionMessage(Some(Nil), Left(nodes("A").nodeParams.nodeId), expectsReply = true, hex"3f00").pipeTo(probe.ref) val recv = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) assert(recv.finalPayload.records.get[ReplyPath].nonEmpty) @@ -129,7 +131,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("A").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - eve.sendOnionMessage(Nil, Left(nodes("A").nodeParams.nodeId), None, ByteVector.empty).pipeTo(probe.ref) + eve.sendOnionMessage(Some(Nil), Left(nodes("A").nodeParams.nodeId), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) } @@ -141,7 +143,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("A").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - fabrice.sendOnionMessage(Nil, Left(nodes("A").nodeParams.nodeId), None, ByteVector.empty).pipeTo(probe.ref) + fabrice.sendOnionMessage(Some(Nil), Left(nodes("A").nodeParams.nodeId), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) } @@ -152,7 +154,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("F").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(Nil, Left(nodes("F").nodeParams.nodeId), None, ByteVector.empty).pipeTo(probe.ref) + alice.sendOnionMessage(Some(Nil), Left(nodes("F").nodeParams.nodeId), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) } @@ -164,7 +166,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(nodes("B").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, hex"710301020375020102").pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("B").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, hex"710301020375020102").pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) @@ -181,7 +183,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(nodes("B").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, encodedBytes).pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("B").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, encodedBytes).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) @@ -198,7 +200,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(nodes("B").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, encodedBytes).pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("B").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, encodedBytes).pipeTo(probe.ref) assert(!probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectNoMessage() @@ -211,7 +213,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(nodes("E").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, hex"710301020375020102").pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("E").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, hex"710301020375020102").pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectNoMessage() @@ -224,7 +226,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(nodes("F").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, hex"710301020375020102").pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("F").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, hex"710301020375020102").pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectNoMessage() @@ -290,7 +292,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(nodes("E").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, hex"710301020375020102").pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("E").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, hex"710301020375020102").pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) @@ -303,7 +305,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val probe = TestProbe() val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(nodes("F").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, hex"710301020375020102").pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("F").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, hex"710301020375020102").pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectNoMessage() @@ -315,7 +317,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) - alice.sendOnionMessage(Nil, Left(nodes("C").nodeParams.nodeId), None, ByteVector.empty).pipeTo(probe.ref) + alice.sendOnionMessage(Some(Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, ByteVector.empty).pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) @@ -367,7 +369,7 @@ class MessageIntegrationSpec extends IntegrationSpec { val eventListener = TestProbe() nodes("C").system.eventStream.subscribe(eventListener.ref, classOf[OnionMessages.ReceiveMessage]) waitEventStreamSynced(nodes("C").system.eventStream) - alice.sendOnionMessage(nodes("B").nodeParams.nodeId :: Nil, Left(nodes("C").nodeParams.nodeId), None, hex"7300").pipeTo(probe.ref) + alice.sendOnionMessage(Some(nodes("B").nodeParams.nodeId :: Nil), Left(nodes("C").nodeParams.nodeId), expectsReply = false, hex"7300").pipeTo(probe.ref) assert(probe.expectMsgType[SendOnionMessageResponse].sent) val r = eventListener.expectMsgType[OnionMessages.ReceiveMessage](max = 60 seconds) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala index c71781db25..25344f6038 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala @@ -103,7 +103,7 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat val paymentFactory = PaymentInitiator.SimplePaymentFactory(nodeParams, router, register) val paymentInitiator = system.actorOf(PaymentInitiator.props(nodeParams, paymentFactory), "payment-initiator") val channels = nodeParams.db.channels.listLocalChannels() - val postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") + val postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped, router.toTyped, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") switchboard ! Switchboard.Init(channels) relayer ! PostRestartHtlcCleaner.Init(channels) readyListener.expectMsgAllOf( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala index 668e5abf12..cb82c33046 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala @@ -29,15 +29,16 @@ import fr.acinq.eclair.channel.{DATA_NORMAL, RealScidStatus} import fr.acinq.eclair.integration.basic.fixtures.MinimalNodeFixture import fr.acinq.eclair.integration.basic.fixtures.MinimalNodeFixture.{connect, getChannelData, getRouterData, knownFundingTxs, nodeParamsFor, openChannel, watcherAutopilot} import fr.acinq.eclair.integration.basic.fixtures.composite.ThreeNodesFixture +import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient, buildRoute} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.payment.receive.MultiPartHandler.{DummyBlindedHop, ReceivingRoute} import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendSpontaneousPayment} import fr.acinq.eclair.payment.send.{OfferPayment, PaymentLifecycle} import fr.acinq.eclair.testutils.FixtureSpec -import fr.acinq.eclair.wire.protocol.OfferTypes.Offer +import fr.acinq.eclair.wire.protocol.OfferTypes.{Offer, OfferPaths} import fr.acinq.eclair.wire.protocol.{IncorrectOrUnknownPaymentDetails, InvalidOnionBlinding} -import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomBytes32} +import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomBytes32, randomKey} import org.scalatest.concurrent.IntegrationPatience import org.scalatest.{Tag, TestData} import scodec.bits.HexStringSyntax @@ -70,7 +71,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { .modify(_.channelConf.channelFlags.announceChannel).setTo(!testData.tags.contains(PrivateChannels)) val f = ThreeNodesFixture(aliceParams, bobParams, carolParams, testData.name) - createChannels(f) + createChannels(f, testData) f } @@ -78,7 +79,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { fixture.cleanup() } - private def createChannels(f: FixtureParam): Unit = { + private def createChannels(f: FixtureParam, testData: TestData): Unit = { import f._ alice.watcher.setAutoPilot(watcherAutopilot(knownFundingTxs(alice, bob))) @@ -87,7 +88,6 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { connect(alice, bob) connect(bob, carol) - connect(alice, carol) // TODO: remove once finding routes for invoice requests has been implemented val channelId_ab = openChannel(alice, bob, 500_000 sat).channelId val channelId_bc_1 = openChannel(bob, carol, 100_000 sat).channelId @@ -97,6 +97,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { assert(getChannelData(alice, channelId_ab).asInstanceOf[DATA_NORMAL].shortIds.real.isInstanceOf[RealScidStatus.Final]) assert(getChannelData(bob, channelId_bc_1).asInstanceOf[DATA_NORMAL].shortIds.real.isInstanceOf[RealScidStatus.Final]) assert(getChannelData(bob, channelId_bc_2).asInstanceOf[DATA_NORMAL].shortIds.real.isInstanceOf[RealScidStatus.Final]) + assert(getRouterData(alice).channels.size == 3 || testData.tags.contains(PrivateChannels)) // Carol must have received Bob's alias to create usable blinded routes to herself. assert(getRouterData(carol).privateChannels.values.forall(_.shortIds.remoteAlias_opt.nonEmpty)) } @@ -121,7 +122,26 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) recipient.offerManager ! OfferManager.RegisterOffer(offer, recipient.nodeParams.privateKey, None, handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.paymentInitiator)) - val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, maxAttempts = 1, payer.routeParams, blocking = true) + val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, payer.routeParams, blocking = true) + offerPayment ! OfferPayment.PayOffer(sender.ref, offer, amount, 1, sendPaymentConfig) + (offer, sender.expectMsgType[PaymentEvent]) + } + + def sendPrivateOfferPayment(f: FixtureParam, payer: MinimalNodeFixture, recipient: MinimalNodeFixture, amount: MilliSatoshi, routes: Seq[ReceivingRoute]): (Offer, PaymentEvent) = { + import f._ + + val sender = TestProbe("sender") + val recipientKey = randomKey() + val pathId = randomBytes32() + val offerPaths = routes.map(route => { + route.nodes.dropRight(1).map(IntermediateNode(_)) + buildRoute(randomKey(), route.nodes.dropRight(1).map(IntermediateNode(_)), Recipient(route.nodes.last, Some(pathId))) + }) + val offer = Offer(None, "test", recipientKey.publicKey, Features.empty, recipient.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(offerPaths))) + val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) + recipient.offerManager ! OfferManager.RegisterOffer(offer, recipientKey, Some(pathId), handler) + val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, payer.paymentInitiator)) + val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, payer.routeParams, blocking = true) offerPayment ! OfferPayment.PayOffer(sender.ref, offer, amount, 1, sendPaymentConfig) (offer, sender.expectMsgType[PaymentEvent]) } @@ -135,7 +155,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val handler = recipient.system.spawnAnonymous(offerHandler(recipientAmount, routes)) recipient.offerManager ! OfferManager.RegisterOffer(offer, recipient.nodeParams.privateKey, None, handler) val offerPayment = payer.system.spawnAnonymous(OfferPayment(payer.nodeParams, payer.postman, paymentInterceptor.ref)) - val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, maxAttempts = 1, payer.routeParams, blocking = true) + val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, payer.routeParams, blocking = true) offerPayment ! OfferPayment.PayOffer(sender.ref, offer, recipientAmount, 1, sendPaymentConfig) // We intercept the payment and modify it to use a different amount. val payment = paymentInterceptor.expectMsgType[SendPaymentToNode] @@ -193,7 +213,7 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val amount = 50_000_000 msat val routes = Seq(ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta)) - val (offer, result) = sendOfferPayment(f, alice, carol, amount, routes) + val (offer, result) = sendPrivateOfferPayment(f, alice, carol, amount, routes) verifyPaymentSuccess(offer, amount, result) } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala index db5998001b..db1e093900 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala @@ -24,31 +24,32 @@ import fr.acinq.bitcoin.scalacompat.Block import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.io.MessageRelay.{Disconnected, Sent} import fr.acinq.eclair.io.Switchboard.RelayMessage +import fr.acinq.eclair.message.OnionMessages.RoutingStrategy.FindRoute import fr.acinq.eclair.message.OnionMessages.{BlindedPath, IntermediateNode, ReceiveMessage, Recipient, buildMessage, buildRoute} import fr.acinq.eclair.message.Postman._ import fr.acinq.eclair.payment.offer.OfferManager.RequestInvoice +import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteRequest} import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.{InvoiceRequest, ReplyPath} import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv.PathId import fr.acinq.eclair.wire.protocol.{GenericTlv, MessageOnion, OfferTypes, OnionMessagePayloadTlv, TlvStream} -import fr.acinq.eclair.{Features, MilliSatoshiLong, NodeParams, TestConstants, UInt64, randomBytes32, randomKey} +import fr.acinq.eclair.{Features, MilliSatoshiLong, NodeParams, TestConstants, UInt64, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike import scodec.bits.HexStringSyntax -import scala.concurrent.duration._ - class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { - case class FixtureParam(postman: ActorRef[Command], nodeParams: NodeParams, messageRecipient: TestProbe[OnionMessageResponse], switchboard: TestProbe[RelayMessage], offerManager: TestProbe[RequestInvoice]) + case class FixtureParam(postman: ActorRef[Command], nodeParams: NodeParams, messageSender: TestProbe[OnionMessageResponse], switchboard: TestProbe[RelayMessage], offerManager: TestProbe[RequestInvoice], router: TestProbe[MessageRouteRequest]) override def withFixture(test: OneArgTest): Outcome = { val nodeParams = TestConstants.Alice.nodeParams - val messageRecipient = TestProbe[OnionMessageResponse]("messageRecipient") + val messageSender = TestProbe[OnionMessageResponse]("messageSender") val switchboard = TestProbe[RelayMessage]("switchboard") val offerManager = TestProbe[RequestInvoice]("offerManager") - val postman = testKit.spawn(Postman(nodeParams, switchboard.ref, offerManager.ref)) + val router = TestProbe[MessageRouteRequest]("router") + val postman = testKit.spawn(Postman(nodeParams, switchboard.ref, router.ref, offerManager.ref)) try { - withFixture(test.toNoArgTest(FixtureParam(postman, nodeParams, messageRecipient, switchboard, offerManager))) + withFixture(test.toNoArgTest(FixtureParam(postman, nodeParams, messageSender, switchboard, offerManager, router))) } finally { testKit.stop(postman) } @@ -57,66 +58,78 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat test("message forwarded only once") { f => import f._ - val ourKey = randomKey() val recipientKey = randomKey() - postman ! SendMessage(Nil, Recipient(recipientKey.publicKey, None), Some(Seq(ourKey.publicKey)), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), messageRecipient.ref, 100 millis) + postman ! SendMessage(Recipient(recipientKey.publicKey, None), FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = true, messageSender.ref) + + val MessageRouteRequest(waitingForRoute, source, target, _) = router.expectMessageType[MessageRouteRequest] + assert(source == nodeParams.nodeId) + assert(target == recipientKey.publicKey) + waitingForRoute ! MessageRoute(Seq.empty, target) - val RelayMessage(messageId, _, nextNodeId, message, _, _) = switchboard.expectMessageType[RelayMessage] + val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] assert(nextNodeId == recipientKey.publicKey) - postman ! SendingStatus(Sent(messageId)) + replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) val replyPath = finalPayload.records.get[ReplyPath].get.blindedRoute val Right((_, reply)) = buildMessage(recipientKey, randomKey(), randomKey(), Nil, BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(55), hex"1234")))) - val ReceiveMessage(replyPayload) = OnionMessages.process(ourKey, reply) + val ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, reply) testKit.system.eventStream ! EventStream.Publish(ReceiveMessage(replyPayload)) testKit.system.eventStream ! EventStream.Publish(ReceiveMessage(replyPayload)) - messageRecipient.expectMessage(Response(replyPayload)) - messageRecipient.expectNoMessage() + messageSender.expectMessage(Response(replyPayload)) + messageSender.expectNoMessage() } test("sending failure") { f => import f._ - val ourKey = randomKey() val recipientKey = randomKey() - postman ! SendMessage(Nil, Recipient(recipientKey.publicKey, None), Some(Seq(ourKey.publicKey)), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), messageRecipient.ref, 100 millis) + postman ! SendMessage(Recipient(recipientKey.publicKey, None), FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = true, messageSender.ref) + + val MessageRouteRequest(waitingForRoute, source, target, _) = router.expectMessageType[MessageRouteRequest] + assert(source == nodeParams.nodeId) + assert(target == recipientKey.publicKey) + waitingForRoute ! MessageRoute(Seq.empty, target) - val RelayMessage(messageId, _, nextNodeId, _, _, _) = switchboard.expectMessageType[RelayMessage] + val RelayMessage(messageId, _, nextNodeId, _, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] assert(nextNodeId == recipientKey.publicKey) - postman ! SendingStatus(Disconnected(messageId)) + replyTo ! Disconnected(messageId) - messageRecipient.expectMessage(MessageFailed("Peer is not connected")) - messageRecipient.expectNoMessage() + messageSender.expectMessage(MessageFailed("Peer is not connected")) + messageSender.expectNoMessage() } test("timeout") { f => import f._ - val ourKey = randomKey() val recipientKey = randomKey() - postman ! SendMessage(Nil, Recipient(recipientKey.publicKey, None), Some(Seq(ourKey.publicKey)), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), messageRecipient.ref, 1 millis) + postman ! SendMessage(Recipient(recipientKey.publicKey, None), FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = true, messageSender.ref) - val RelayMessage(messageId, _, nextNodeId, message, _, _) = switchboard.expectMessageType[RelayMessage] + val MessageRouteRequest(waitingForRoute, source, target, _) = router.expectMessageType[MessageRouteRequest] + assert(source == nodeParams.nodeId) + assert(target == recipientKey.publicKey) + waitingForRoute ! MessageRoute(Seq.empty, target) + + val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] assert(nextNodeId == recipientKey.publicKey) - postman ! SendingStatus(Sent(messageId)) + replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) - messageRecipient.expectMessage(NoReply) + messageSender.expectMessage(NoReply) val replyPath = finalPayload.records.get[ReplyPath].get.blindedRoute val Right((_, reply)) = buildMessage(recipientKey, randomKey(), randomKey(), Nil, BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(55), hex"1234")))) - val ReceiveMessage(replyPayload) = OnionMessages.process(ourKey, reply) + val ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, reply) testKit.system.eventStream ! EventStream.Publish(ReceiveMessage(replyPayload)) - messageRecipient.expectNoMessage() + messageSender.expectNoMessage() } test("do not expect reply") { f => @@ -124,17 +137,22 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val recipientKey = randomKey() - postman ! SendMessage(Nil, Recipient(recipientKey.publicKey, None), None, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), messageRecipient.ref, 100 millis) + postman ! SendMessage(Recipient(recipientKey.publicKey, None), FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = false, messageSender.ref) + + val MessageRouteRequest(waitingForRoute, source, target, _) = router.expectMessageType[MessageRouteRequest] + assert(source == nodeParams.nodeId) + assert(target == recipientKey.publicKey) + waitingForRoute ! MessageRoute(Seq.empty, target) - val RelayMessage(messageId, _, nextNodeId, message, _, _) = switchboard.expectMessageType[RelayMessage] + val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] assert(nextNodeId == recipientKey.publicKey) - postman ! SendingStatus(Sent(messageId)) + replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) assert(finalPayload.records.get[ReplyPath].isEmpty) - messageRecipient.expectMessage(MessageSent) - messageRecipient.expectNoMessage() + messageSender.expectMessage(MessageSent) + messageSender.expectNoMessage() } test("send to route that starts at ourselves") { f => @@ -143,17 +161,17 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val recipientKey = randomKey() val blindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(nodeParams.nodeId)), Recipient(recipientKey.publicKey, None)) - postman ! SendMessage(Nil, BlindedPath(blindedRoute), None, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), messageRecipient.ref, 100 millis) + postman ! SendMessage(BlindedPath(blindedRoute), FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = false, messageSender.ref) - val RelayMessage(messageId, _, nextNodeId, message, _, _) = switchboard.expectMessageType[RelayMessage] + val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] assert(nextNodeId == recipientKey.publicKey) - postman ! SendingStatus(Sent(messageId)) + replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) assert(finalPayload.records.get[ReplyPath].isEmpty) - messageRecipient.expectMessage(MessageSent) - messageRecipient.expectNoMessage() + messageSender.expectMessage(MessageSent) + messageSender.expectNoMessage() } test("forward invoice request to offer manager") { f => @@ -168,4 +186,47 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val request = offerManager.expectMessageType[RequestInvoice] assert(request.messagePayload.pathId_opt.contains(hex"abcd")) } + + test("reply path") {f => + import f._ + + val (a, b, c, d) = (randomKey(), randomKey(), randomKey(), randomKey()) + + postman ! SendMessage(Recipient(d.publicKey, None), FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(11), hex"012345"))), expectsReply = true, messageSender.ref) + + val MessageRouteRequest(waitingForRoute, source, target, _) = router.expectMessageType[MessageRouteRequest] + assert(source == nodeParams.nodeId) + assert(target == d.publicKey) + waitingForRoute ! MessageRoute(Seq(a.publicKey, b.publicKey, c.publicKey), target) + + val RelayMessage(messageId, _, next1, message1, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] + assert(next1 == a.publicKey) + replyTo ! Sent(messageId) + val OnionMessages.SendMessage(next2, message2) = OnionMessages.process(a, message1) + assert(next2 == b.publicKey) + val OnionMessages.SendMessage(next3, message3) = OnionMessages.process(b, message2) + assert(next3 == c.publicKey) + val OnionMessages.SendMessage(next4, message4) = OnionMessages.process(c, message3) + assert(next4 == d.publicKey) + val OnionMessages.ReceiveMessage(payload) = OnionMessages.process(d, message4) + assert(payload.records.unknown == Set(GenericTlv(UInt64(11), hex"012345"))) + assert(payload.records.get[ReplyPath].nonEmpty) + val replyPath = payload.records.get[ReplyPath].get.blindedRoute + assert(replyPath.introductionNodeId == d.publicKey) + assert(replyPath.length >= nodeParams.onionMessageConfig.minIntermediateHops) + assert(nodeParams.onionMessageConfig.minIntermediateHops > 5) + + val Right((next5, reply)) = OnionMessages.buildMessage(d, randomKey(), randomKey(), Nil, OnionMessages.BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(13), hex"6789")))) + assert(next5 == c.publicKey) + val OnionMessages.SendMessage(next6, message6) = OnionMessages.process(c, reply) + assert(next6 == b.publicKey) + val OnionMessages.SendMessage(next7, message7) = OnionMessages.process(b, message6) + assert(next7 == a.publicKey) + val OnionMessages.SendMessage(next8, message8) = OnionMessages.process(a, message7) + assert(next8 == nodeParams.nodeId) + val OnionMessages.ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, message8) + + postman ! WrappedMessage(replyPayload) + assert(replyPayload.records.unknown == Set(GenericTlv(UInt64(13), hex"6789"))) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala index da8549fd00..a6eb63626b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/send/OfferPaymentSpec.scala @@ -24,6 +24,7 @@ import akka.testkit.TestProbe import com.typesafe.config.ConfigFactory import fr.acinq.eclair.crypto.Sphinx.RouteBlinding import fr.acinq.eclair.message.OnionMessages.Recipient +import fr.acinq.eclair.message.OnionMessages.RoutingStrategy.FindRoute import fr.acinq.eclair.message.Postman import fr.acinq.eclair.payment.send.OfferPayment._ import fr.acinq.eclair.payment.send.PaymentInitiator.SendPaymentToNode @@ -65,10 +66,11 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val merchantKey = randomKey() val offer = Offer(None, "amountless offer", merchantKey.publicKey, Features.empty, nodeParams.chainHash) - offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, 1, routeParams, blocking = false)) - val Postman.SendMessage(_, Recipient(recipientId, _, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage] + offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, connectDirectly = false, 1, routeParams, blocking = false)) + val Postman.SendMessage(Recipient(recipientId, _, _, _), FindRoute, message, expectsReply, replyTo) = postman.expectMessageType[Postman.SendMessage] assert(recipientId == merchantKey.publicKey) assert(message.get[OnionMessagePayloadTlv.InvoiceRequest].nonEmpty) + assert(expectsReply) val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs) val preimage = randomBytes32() @@ -88,11 +90,12 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val merchantKey = randomKey() val offer = Offer(None, "amountless offer", merchantKey.publicKey, Features.empty, nodeParams.chainHash) - offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, 1, routeParams, blocking = false)) + offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, connectDirectly = false, 1, routeParams, blocking = false)) for (_ <- 1 to nodeParams.onionMessageConfig.maxAttempts) { - val Postman.SendMessage(_, Recipient(recipientId, _, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage] + val Postman.SendMessage(Recipient(recipientId, _, _, _), FindRoute, message, expectsReply, replyTo) = postman.expectMessageType[Postman.SendMessage] assert(recipientId == merchantKey.publicKey) assert(message.get[OnionMessagePayloadTlv.InvoiceRequest].nonEmpty) + assert(expectsReply) val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs) assert(invoiceRequest.isValid) assert(invoiceRequest.offer == offer) @@ -110,10 +113,11 @@ class OfferPaymentSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val merchantKey = randomKey() val offer = Offer(None, "amountless offer", merchantKey.publicKey, Features.empty, nodeParams.chainHash) - offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, 1, routeParams, blocking = false)) - val Postman.SendMessage(_, Recipient(recipientId, _, _, _), _, message, replyTo, _) = postman.expectMessageType[Postman.SendMessage] + offerPayment ! PayOffer(probe.ref, offer, 40_000_000 msat, 1, SendPaymentConfig(None, connectDirectly = false, 1, routeParams, blocking = false)) + val Postman.SendMessage(Recipient(recipientId, _, _, _), FindRoute, message, expectsReply, replyTo) = postman.expectMessageType[Postman.SendMessage] assert(recipientId == merchantKey.publicKey) assert(message.get[OnionMessagePayloadTlv.InvoiceRequest].nonEmpty) + assert(expectsReply) val Right(invoiceRequest) = InvoiceRequest.validate(message.get[OnionMessagePayloadTlv.InvoiceRequest].get.tlvs) val preimage = randomBytes32() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala index 9784f93345..733b6fb5b6 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/BalanceEstimateSpec.scala @@ -19,7 +19,7 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Satoshi, SatoshiLong} import fr.acinq.eclair.payment.Invoice -import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} +import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, ActiveEdge} import fr.acinq.eclair.router.Router.{ChannelDesc, HopRelayParams} import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, TimestampSecond, randomKey} import org.scalactic.Tolerance.convertNumericToPlusOrMinusWrapper @@ -35,13 +35,13 @@ class BalanceEstimateSpec extends AnyFunSuite { balance.high <= balance.maxCapacity } - def makeEdge(nodeId1: PublicKey, nodeId2: PublicKey, channelId: Long, capacity: Satoshi): GraphEdge = - GraphEdge( + def makeEdge(nodeId1: PublicKey, nodeId2: PublicKey, channelId: Long, capacity: Satoshi): ActiveEdge = + ActiveEdge( ChannelDesc(ShortChannelId(channelId), nodeId1, nodeId2), HopRelayParams.FromHint(Invoice.ExtraEdge(nodeId1, nodeId2, ShortChannelId(channelId), 0 msat, 0, CltvExpiryDelta(0), 0 msat, None)), capacity, None) - def makeEdge(channelId: Long, capacity: Satoshi): GraphEdge = + def makeEdge(channelId: Long, capacity: Satoshi): ActiveEdge = makeEdge(randomKey().publicKey, randomKey().publicKey, channelId, capacity) test("no balance information") { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRouterIntegrationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRouterIntegrationSpec.scala index f39b64c75e..70b6a6717f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRouterIntegrationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRouterIntegrationSpec.scala @@ -6,7 +6,7 @@ import fr.acinq.eclair.blockchain.bitcoind.ZmqWatcher.{WatchExternalChannelSpent import fr.acinq.eclair.channel.states.{ChannelStateTestsBase, ChannelStateTestsTags} import fr.acinq.eclair.channel.{CMD_CLOSE, DATA_NORMAL} import fr.acinq.eclair.io.Peer.PeerRoutingMessage -import fr.acinq.eclair.router.Graph.GraphStructure.GraphEdge +import fr.acinq.eclair.router.Graph.GraphStructure.ActiveEdge import fr.acinq.eclair.router.Router._ import fr.acinq.eclair.wire.protocol.{AnnouncementSignatures, ChannelUpdate, Shutdown} import fr.acinq.eclair.{BlockHeight, TestKitBaseClass} @@ -82,7 +82,7 @@ class ChannelRouterIntegrationSpec extends TestKitBaseClass with FixtureAnyFunSu // router graph contains a single channel assert(router.stateData.graphWithBalances.graph.vertexSet() == Set(aliceNodeId, bobNodeId)) - assert(router.stateData.graphWithBalances.graph.edgeSet().toSet == Set(GraphEdge(aliceInitialChannelUpdate, privateChannel), GraphEdge(bobChannelUpdate1, privateChannel))) + assert(router.stateData.graphWithBalances.graph.edgeSet().toSet == Set(ActiveEdge(aliceInitialChannelUpdate, privateChannel), ActiveEdge(bobChannelUpdate1, privateChannel))) if (testTags.contains(ChannelStateTestsTags.ChannelsPublic)) { // this is a public channel @@ -151,7 +151,7 @@ class ChannelRouterIntegrationSpec extends TestKitBaseClass with FixtureAnyFunSu // router graph contains a single channel assert(router.stateData.graphWithBalances.graph.vertexSet() == Set(aliceNodeId, bobNodeId)) assert(router.stateData.graphWithBalances.graph.edgeSet().size == 2) - assert(router.stateData.graphWithBalances.graph.edgeSet().toSet == Set(GraphEdge(aliceChannelUpdate2, publicChannel), GraphEdge(bobChannelUpdate2, publicChannel))) + assert(router.stateData.graphWithBalances.graph.edgeSet().toSet == Set(ActiveEdge(aliceChannelUpdate2, publicChannel), ActiveEdge(bobChannelUpdate2, publicChannel))) } else { // this is a private channel // funding tx reaches 6 blocks, no announcements are exchanged because the channel is private @@ -166,7 +166,7 @@ class ChannelRouterIntegrationSpec extends TestKitBaseClass with FixtureAnyFunSu // router graph contains a single channel assert(router.stateData.graphWithBalances.graph.vertexSet() == Set(aliceNodeId, bobNodeId)) - assert(router.stateData.graphWithBalances.graph.edgeSet().toSet == Set(GraphEdge(aliceInitialChannelUpdate, privateChannel), GraphEdge(bobChannelUpdate1, privateChannel))) + assert(router.stateData.graphWithBalances.graph.edgeSet().toSet == Set(ActiveEdge(aliceInitialChannelUpdate, privateChannel), ActiveEdge(bobChannelUpdate1, privateChannel))) } // channel closes channels.alice ! CMD_CLOSE(TestProbe().ref, scriptPubKey = None, feerates = None) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala index 092fb0547e..5c370838f5 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/GraphSpec.scala @@ -17,28 +17,34 @@ package fr.acinq.eclair.router import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.bitcoin.scalacompat.SatoshiLong +import fr.acinq.bitcoin.scalacompat.{ByteVector32, ByteVector64, SatoshiLong} import fr.acinq.eclair.payment.relay.Relayer.RelayFees -import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} -import fr.acinq.eclair.router.Graph.{HeuristicsConstants, WeightRatios, yenKshortestPaths} +import fr.acinq.eclair.router.Announcements.makeNodeAnnouncement +import fr.acinq.eclair.router.Graph.GraphStructure.{ActiveEdge, DirectedGraph, DisabledEdge} +import fr.acinq.eclair.router.Graph.{HeuristicsConstants, MessagePath, WeightRatios, yenKshortestPaths} import fr.acinq.eclair.router.RouteCalculationSpec._ -import fr.acinq.eclair.router.Router.ChannelDesc -import fr.acinq.eclair.{BlockHeight, MilliSatoshiLong, ShortChannelId} +import fr.acinq.eclair.router.Router.{ChannelDesc, PublicChannel} +import fr.acinq.eclair.wire.protocol.{ChannelUpdate, Color} +import fr.acinq.eclair.{BlockHeight, CltvExpiryDelta, FeatureSupport, Features, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TimestampSecondLong, randomKey} import org.scalactic.Tolerance.convertNumericToPlusOrMinusWrapper import org.scalatest.funsuite.AnyFunSuite -import scodec.bits._ +import scodec.bits.HexStringSyntax + +import scala.collection.immutable.SortedMap class GraphSpec extends AnyFunSuite { - val (a, b, c, d, e, f, g, h) = ( - PublicKey(hex"02999fa724ec3c244e4da52b4a91ad421dc96c9a810587849cd4b2469313519c73"), //a - PublicKey(hex"03f1cb1af20fe9ccda3ea128e27d7c39ee27375c8480f11a87c17197e97541ca6a"), //b - PublicKey(hex"0358e32d245ff5f5a3eb14c78c6f69c67cea7846bdf9aeeb7199e8f6fbb0306484"), //c - PublicKey(hex"029e059b6780f155f38e83601969919aae631ddf6faed58fe860c72225eb327d7c"), //d - PublicKey(hex"02f38f4e37142cc05df44683a83e22dea608cf4691492829ff4cf99888c5ec2d3a"), //e - PublicKey(hex"03fc5b91ce2d857f146fd9b986363374ffe04dc143d8bcd6d7664c8873c463cdfc"), //f - PublicKey(hex"03864ef025fde8fb587d989186ce6a4a186895ee44a926bfc370e2c366597a3f8f"), //g - PublicKey(hex"03bfddd2253b42fe12edd37f9071a3883830ed61a4bc347eeac63421629cf032b5") //h + val (priv_a, priv_b, priv_c, priv_d, priv_e, priv_f, priv_g, priv_h) = (randomKey(), randomKey(), randomKey(), randomKey(), randomKey(), randomKey(), randomKey(), randomKey()) + val (a, b, c, d, e, f, g, h) = (priv_a.publicKey, priv_b.publicKey, priv_c.publicKey, priv_d.publicKey, priv_e.publicKey, priv_f.publicKey, priv_g.publicKey, priv_h.publicKey) + val (annA, annB, annC, annD, annE, annF, annG, annH) = ( + makeNodeAnnouncement(priv_a, "A", Color(0, 0, 0), Nil, Features.empty), + makeNodeAnnouncement(priv_b, "B", Color(0, 0, 0), Nil, Features.empty), + makeNodeAnnouncement(priv_c, "C", Color(0, 0, 0), Nil, Features.empty), + makeNodeAnnouncement(priv_d, "D", Color(0, 0, 0), Nil, Features.empty), + makeNodeAnnouncement(priv_e, "E", Color(0, 0, 0), Nil, Features.empty), + makeNodeAnnouncement(priv_f, "F", Color(0, 0, 0), Nil, Features.empty), + makeNodeAnnouncement(priv_g, "G", Color(0, 0, 0), Nil, Features.empty), + makeNodeAnnouncement(priv_h, "H", Color(0, 0, 0), Nil, Features.empty), ) // +---- D -------+ @@ -58,15 +64,15 @@ class GraphSpec extends AnyFunSuite { test("instantiate a graph, with vertices and then add edges") { val graph = DirectedGraph(a) - .addVertex(b) - .addVertex(c) - .addVertex(d) - .addVertex(e) + .addOrUpdateVertex(annB) + .addOrUpdateVertex(annC) + .addOrUpdateVertex(annD) + .addOrUpdateVertex(annE) assert(graph.containsVertex(a) && graph.containsVertex(e)) assert(graph.vertexSet().size == 5) - val otherGraph = graph.addVertex(a) // adding the same vertex twice! + val otherGraph = graph.addOrUpdateVertex(annA) // adding the same vertex twice! assert(otherGraph.vertexSet().size == 5) // add some edges to the graph @@ -89,7 +95,7 @@ class GraphSpec extends AnyFunSuite { assert(graphWithEdges.edgesOf(d).size == 1) assert(graphWithEdges.edgesOf(e).size == 0) - val withRemovedEdges = graphWithEdges.removeEdge(edgeAD.desc) + val withRemovedEdges = graphWithEdges.disableEdge(edgeAD.desc) assert(withRemovedEdges.edgesOf(d).size == 1) } @@ -110,7 +116,7 @@ class GraphSpec extends AnyFunSuite { assert(graph.vertexSet().size == 5) assert(graph.edgesOf(c).size == 1) - assert(graph.getIncomingEdgesOf(c).size == 2) + assert(graph.getIncomingEdgesOf(c).collect{case e: ActiveEdge => e}.size == 2) assert(graph.edgeSet().size == 6) } @@ -144,13 +150,13 @@ class GraphSpec extends AnyFunSuite { assert(graph.edgeSet().size == 6) assert(graph.containsEdge(edgeBE.desc)) - val withRemovedEdge = graph.removeEdge(edgeBE.desc) + val withRemovedEdge = graph.disableEdge(edgeBE.desc) assert(withRemovedEdge.edgeSet().size == 5) - val withRemovedList = graph.removeEdges(Seq(edgeAD.desc, edgeDC.desc)) + val withRemovedList = graph.removeChannels(Seq(edgeAD.desc, edgeDC.desc)) assert(withRemovedList.edgeSet().size == 4) - val withoutAnyIncomingEdgeInE = graph.removeEdges(Seq(edgeBE.desc, edgeCE.desc)) + val withoutAnyIncomingEdgeInE = graph.removeChannels(Seq(edgeBE.desc, edgeCE.desc)) assert(withoutAnyIncomingEdgeInE.containsVertex(e)) assert(withoutAnyIncomingEdgeInE.edgesOf(e).isEmpty) } @@ -167,7 +173,7 @@ class GraphSpec extends AnyFunSuite { assert(edgesAB.head.desc.a == a) assert(edgesAB.head.desc.b == b) - val bIncoming = graph.getIncomingEdgesOf(b) + val bIncoming = graph.getIncomingEdgesOf(b).collect{case e: ActiveEdge => e} assert(bIncoming.size == 1) assert(bIncoming.exists(_.desc.a == a)) // there should be an edge a --> b assert(bIncoming.exists(_.desc.b == b)) @@ -222,23 +228,23 @@ class GraphSpec extends AnyFunSuite { val graph = DirectedGraph(Seq(edgeAB, edgeAD, edgeBC, edgeDC)) assert(graph.edgesOf(a).toSet == Set(edgeAB, edgeAD)) - assert(graph.getIncomingEdgesOf(a) == Nil) + assert(graph.getIncomingEdgesOf(a).collect{case e: ActiveEdge => e}.toSeq == Nil) assert(graph.edgesOf(c) == Nil) - assert(graph.getIncomingEdgesOf(c).toSet == Set(edgeBC, edgeDC)) + assert(graph.getIncomingEdgesOf(c).collect{case e: ActiveEdge => e}.toSet == Set(edgeBC, edgeDC)) val edgeAB1 = edgeAB.copy(balance_opt = Some(200000 msat)) val edgeBC1 = edgeBC.copy(balance_opt = Some(150000 msat)) val graph1 = graph.addEdge(edgeAB1).addEdge(edgeBC1) assert(graph1.edgesOf(a).toSet == Set(edgeAB1, edgeAD)) - assert(graph1.getIncomingEdgesOf(a) == Nil) + assert(graph1.getIncomingEdgesOf(a).collect{case e: ActiveEdge => e}.toSeq == Nil) assert(graph1.edgesOf(c) == Nil) - assert(graph1.getIncomingEdgesOf(c).toSet == Set(edgeBC1, edgeDC)) + assert(graph1.getIncomingEdgesOf(c).collect{case e: ActiveEdge => e}.toSet == Set(edgeBC1, edgeDC)) } def descFromNodes(shortChannelId: Long, a: PublicKey, b: PublicKey): ChannelDesc = makeEdge(shortChannelId, a, b, 0 msat, 0).desc - def edgeFromNodes(shortChannelId: Long, a: PublicKey, b: PublicKey): GraphEdge = makeEdge(shortChannelId, a, b, 0 msat, 0) + def edgeFromNodes(shortChannelId: Long, a: PublicKey, b: PublicKey): ActiveEdge = makeEdge(shortChannelId, a, b, 0 msat, 0) test("amount with fees larger than channel capacity for C->D") { /* @@ -387,4 +393,147 @@ class GraphSpec extends AnyFunSuite { BlockHeight(714930), _ => true, includeLocalChannelCost = true) assert(paths.head.path == Seq(edgeAB)) } + + test("route for messages") { + /* + A -- B -- C -- D + \____ E _____/ + */ + val graph = DirectedGraph(Seq( + makeEdge(1L, a, b, 0 msat, 0, capacity = 100000000 sat, minHtlc = 0 msat, maxHtlc = Some(100 msat)), + makeEdge(1L, b, a, 1 msat, 1, capacity = 100000000 sat, minHtlc = 100 msat, maxHtlc = Some(200 msat)), + makeEdge(2L, b, c, 2 msat, 2, capacity = 100000000 sat, minHtlc = 200 msat, maxHtlc = Some(300 msat)), + makeEdge(2L, c, b, 3 msat, 3, capacity = 100000000 sat, minHtlc = 300 msat, maxHtlc = Some(400 msat)), + makeEdge(3L, c, d, 4 msat, 4, capacity = 100000000 sat, minHtlc = 400 msat, maxHtlc = Some(500 msat)), + makeEdge(3L, d, c, 5 msat, 5, capacity = 100000000 sat, minHtlc = 500 msat, maxHtlc = Some(600 msat)), + makeEdge(4L, a, e, 6 msat, 6, capacity = 1000 sat, minHtlc = 600 msat, maxHtlc = Some(700 msat)), + makeEdge(4L, e, a, 7 msat, 7, capacity = 1000 sat, minHtlc = 700 msat, maxHtlc = Some(800 msat)), + makeEdge(5L, d, e, 8 msat, 8, capacity = 1000 sat, minHtlc = 800 msat, maxHtlc = Some(900 msat)), + makeEdge(5L, e, d, 9 msat, 9, capacity = 1000 sat, minHtlc = 900 msat, maxHtlc = Some(1000 msat)), + )).addOrUpdateVertex(makeNodeAnnouncement(priv_a, "A", Color(0, 0, 0), Nil, Features(Features.OnionMessages -> FeatureSupport.Optional))) + .addOrUpdateVertex(makeNodeAnnouncement(priv_b, "B", Color(0, 0, 0), Nil, Features(Features.OnionMessages -> FeatureSupport.Optional))) + .addOrUpdateVertex(makeNodeAnnouncement(priv_c, "C", Color(0, 0, 0), Nil, Features(Features.OnionMessages -> FeatureSupport.Optional))) + .addOrUpdateVertex(makeNodeAnnouncement(priv_d, "D", Color(0, 0, 0), Nil, Features(Features.OnionMessages -> FeatureSupport.Optional))) + .addOrUpdateVertex(makeNodeAnnouncement(priv_e, "E", Color(0, 0, 0), Nil, Features(Features.OnionMessages -> FeatureSupport.Optional))) + + { + // All nodes can relay messages, same weight for each channel. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 1.0) + val Some(path) = MessagePath.dijkstraMessagePath(graph, a, d, Set.empty, boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(4, 5)) + } + { + // Source and target don't relay messages but they can still emit and receive. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 1.0) + val g = graph.addOrUpdateVertex(makeNodeAnnouncement(priv_a, "A", Color(0, 0, 0), Nil, Features.empty)) + .addOrUpdateVertex(makeNodeAnnouncement(priv_d, "D", Color(0, 0, 0), Nil, Features.empty)) + val Some(path) = MessagePath.dijkstraMessagePath(g, a, d, Set.empty, boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(4, 5)) + } + { + // E doesn't relay messages. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 1.0) + val g = graph.addOrUpdateVertex(makeNodeAnnouncement(priv_e, "E", Color(0, 0, 0), Nil, Features.empty)) + val Some(path) = MessagePath.dijkstraMessagePath(g, a, d, Set.empty, boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(1, 2, 3)) + } + { + // Message can take disabled edges. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 1.0) + val g = graph.disableEdge(ChannelDesc(ShortChannelId(4L), a, e)) + .disableEdge(ChannelDesc(ShortChannelId(5L), e, d)) + val Some(path) = MessagePath.dijkstraMessagePath(g, a, d, Set.empty, boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(4, 5)) + } + { + // Disabled edges are penalized. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 2.0) + val g = graph.disableEdge(ChannelDesc(ShortChannelId(4L), a, e)) + .disableEdge(ChannelDesc(ShortChannelId(5L), e, d)) + val Some(path) = MessagePath.dijkstraMessagePath(g, a, d, Set.empty, boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(1, 2, 3)) + } + { + // Disabled edges are penalized but we limit the maximum length of the path. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 2 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 2.0) + val g = graph.disableEdge(ChannelDesc(ShortChannelId(4L), a, e)) + .disableEdge(ChannelDesc(ShortChannelId(5L), e, d)) + val Some(path) = MessagePath.dijkstraMessagePath(g, a, d, Set.empty, boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(4, 5)) + } + { + // Prefer high-capacity channels. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(0.0, 0.0, 1.0, 1.0) + val Some(path) = MessagePath.dijkstraMessagePath(graph, a, d, Set.empty, boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(1, 2, 3)) + } + { + // We ignore E. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 1.0) + val Some(path) = MessagePath.dijkstraMessagePath(graph, a, d, Set(e), boundaries, BlockHeight(793397), wr) + assert(path.map(_.shortChannelId.toLong) == Seq(1, 2, 3)) + } + { + // Target not in graph. + val boundaries = (w: MessagePath.RichWeight) => w.length <= 8 + val wr = MessagePath.WeightRatios(1.0, 0.0, 0.0, 1.0) + assert(MessagePath.dijkstraMessagePath(graph, a, f, Set.empty, boundaries, BlockHeight(793397), wr).isEmpty) + } + } + + test("makeGraph with disabled channels") { + val scid1 = RealShortChannelId(BlockHeight(565643), 1216, 0) + val scid2 = RealShortChannelId(BlockHeight(542280), 2156, 0) + val scid3 = RealShortChannelId(BlockHeight(565779), 2711, 0) + val a = PublicKey(hex"024655b768ef40951b20053a5c4b951606d4d86085d51238f2c67c7dec29c792ca") + val b = PublicKey(hex"036d65409c41ab7380a43448f257809e7496b52bf92057c09c4f300cbd61c50d96") + val c = PublicKey(hex"03cb7983dc247f9f81a0fa2dfa3ce1c255365f7279c8dd143e086ca333df10e278") + val updates = SortedMap( + scid1 -> PublicChannel( + ann = makeChannel(scid1.toLong, a, b), + fundingTxid = ByteVector32.Zeroes, + capacity = DEFAULT_CAPACITY, + // Both directions are enabled. + update_1_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, scid1, 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = true, isNode1 = true), CltvExpiryDelta(14), htlcMinimumMsat = 1 msat, feeBaseMsat = 1000 msat, 10, 4_294_967_295L msat)), + update_2_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, scid1, 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = true, isNode1 = false), CltvExpiryDelta(144), htlcMinimumMsat = 0 msat, feeBaseMsat = 1000 msat, 100, 15_000_000_000L msat)), + meta_opt = None + ), + scid2 -> PublicChannel( + ann = makeChannel(scid2.toLong, b, c), + fundingTxid = ByteVector32.Zeroes, + capacity = DEFAULT_CAPACITY, + // Only one direction is enabled. + update_1_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, scid2, 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = false, isNode1 = true), CltvExpiryDelta(144), htlcMinimumMsat = 1000 msat, feeBaseMsat = 1000 msat, 100, 16_777_000_000L msat)), + update_2_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, scid2, 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = true, isNode1 = false), CltvExpiryDelta(144), htlcMinimumMsat = 1 msat, feeBaseMsat = 667 msat, 1, 16_777_000_000L msat)), + meta_opt = None + ), + scid3 -> PublicChannel( + ann = makeChannel(scid3.toLong, a, c), + fundingTxid = ByteVector32.Zeroes, + capacity = DEFAULT_CAPACITY, + // Both directions are disabled. + update_1_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, scid3, 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = false, isNode1 = true), CltvExpiryDelta(144), htlcMinimumMsat = 1 msat, feeBaseMsat = 1000 msat, 100, 230_000_000L msat)), + update_2_opt = Some(ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, scid3, 0 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isEnabled = false, isNode1 = false), CltvExpiryDelta(144), htlcMinimumMsat = 1 msat, feeBaseMsat = 1000 msat, 100, 230_000_000L msat)), + meta_opt = None + ) + ) + val g = DirectedGraph.makeGraph(updates, Seq.empty) + val edgesOfA = g.getIncomingEdgesOf(a).toSeq + assert(edgesOfA.size == 1) + assert(edgesOfA.head.isInstanceOf[ActiveEdge]) + val edgesOfB = g.getIncomingEdgesOf(b).toSeq + assert(edgesOfB.size == 2) + assert(edgesOfB.forall(_.isInstanceOf[ActiveEdge])) + val edgesOfC = g.getIncomingEdgesOf(c).toSeq + assert(edgesOfC.size == 1) + assert(edgesOfC.head.isInstanceOf[DisabledEdge]) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 50ce90911f..9030dac868 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -20,9 +20,10 @@ import com.softwaremill.quicklens.ModifyPimp import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.bitcoin.scalacompat.{Block, ByteVector32, ByteVector64, Satoshi, SatoshiLong} import fr.acinq.eclair.payment.relay.Relayer.RelayFees +import fr.acinq.eclair.router.Announcements.makeNodeAnnouncement import fr.acinq.eclair.router.BaseRouterSpec.channelHopFromUpdate import fr.acinq.eclair.router.Graph.GraphStructure.DirectedGraph.graphEdgeToHop -import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} +import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, ActiveEdge} import fr.acinq.eclair.router.Graph.{HeuristicsConstants, RichWeight, WeightRatios} import fr.acinq.eclair.router.RouteCalculation._ import fr.acinq.eclair.router.Router._ @@ -160,7 +161,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val Success(route1 :: Nil) = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) assert(route2Ids(route1) == 1 :: 2 :: 3 :: 4 :: Nil) - val graphWithRemovedEdge = g.removeEdge(ChannelDesc(ShortChannelId(3L), c, d)) + val graphWithRemovedEdge = g.disableEdge(ChannelDesc(ShortChannelId(3L), c, d)) val route2 = findRoute(graphWithRemovedEdge, a, e, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) assert(route2 == Failure(RouteNotFound)) } @@ -315,10 +316,17 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { } test("route not found (source OR target node not connected)") { + val priv_a = randomKey() + val a = priv_a.publicKey + val annA = makeNodeAnnouncement(priv_a, "A", Color(0, 0, 0), Nil, Features.empty) + val priv_e = randomKey() + val e = priv_e.publicKey + val annE = makeNodeAnnouncement(priv_e, "E", Color(0, 0, 0), Nil, Features.empty) + val g = DirectedGraph(List( makeEdge(2L, b, c, 0 msat, 0), makeEdge(4L, c, d, 0 msat, 0) - )).addVertex(a).addVertex(e) + )).addOrUpdateVertex(annA).addOrUpdateVertex(annE) assert(findRoute(g, a, d, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) == Failure(RouteNotFound)) assert(findRoute(g, b, e, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 1, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = BlockHeight(400000)) == Failure(RouteNotFound)) @@ -428,14 +436,14 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { val ued = ChannelUpdate(DUMMY_SIG, Block.RegtestGenesisBlock.hash, ShortChannelId(4L), 1 unixsec, ChannelUpdate.MessageFlags(dontForward = false), ChannelUpdate.ChannelFlags(isNode1 = false, isEnabled = false), CltvExpiryDelta(1), 49 msat, 2507 msat, 147, DEFAULT_CAPACITY.toMilliSatoshi) val edges = Seq( - GraphEdge(ChannelDesc(ShortChannelId(1L), a, b), HopRelayParams.FromAnnouncement(uab), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(1L), b, a), HopRelayParams.FromAnnouncement(uba), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(2L), b, c), HopRelayParams.FromAnnouncement(ubc), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(2L), c, b), HopRelayParams.FromAnnouncement(ucb), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(3L), c, d), HopRelayParams.FromAnnouncement(ucd), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(3L), d, c), HopRelayParams.FromAnnouncement(udc), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(4L), d, e), HopRelayParams.FromAnnouncement(ude), DEFAULT_CAPACITY, None), - GraphEdge(ChannelDesc(ShortChannelId(4L), e, d), HopRelayParams.FromAnnouncement(ued), DEFAULT_CAPACITY, None) + ActiveEdge(ChannelDesc(ShortChannelId(1L), a, b), HopRelayParams.FromAnnouncement(uab), DEFAULT_CAPACITY, None), + ActiveEdge(ChannelDesc(ShortChannelId(1L), b, a), HopRelayParams.FromAnnouncement(uba), DEFAULT_CAPACITY, None), + ActiveEdge(ChannelDesc(ShortChannelId(2L), b, c), HopRelayParams.FromAnnouncement(ubc), DEFAULT_CAPACITY, None), + ActiveEdge(ChannelDesc(ShortChannelId(2L), c, b), HopRelayParams.FromAnnouncement(ucb), DEFAULT_CAPACITY, None), + ActiveEdge(ChannelDesc(ShortChannelId(3L), c, d), HopRelayParams.FromAnnouncement(ucd), DEFAULT_CAPACITY, None), + ActiveEdge(ChannelDesc(ShortChannelId(3L), d, c), HopRelayParams.FromAnnouncement(udc), DEFAULT_CAPACITY, None), + ActiveEdge(ChannelDesc(ShortChannelId(4L), d, e), HopRelayParams.FromAnnouncement(ude), DEFAULT_CAPACITY, None), + ActiveEdge(ChannelDesc(ShortChannelId(4L), e, d), HopRelayParams.FromAnnouncement(ued), DEFAULT_CAPACITY, None) ) val g = DirectedGraph(edges) @@ -919,7 +927,7 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { ) ) - val g = DirectedGraph.makeGraph(updates) + val g = DirectedGraph.makeGraph(updates, Seq.empty) val params = DEFAULT_ROUTE_PARAMS .modify(_.boundaries.maxCltv).setTo(CltvExpiryDelta(1008)) .modify(_.heuristics).setTo(Left(WeightRatios(baseFactor = 0, cltvDeltaFactor = 0.15, ageFactor = 0.35, capacityFactor = 0.5, hopCost = RelayFees(0 msat, 0)))) @@ -1648,12 +1656,12 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { // \ / \ / \ / // +----> N ---> N N ---> N ----+ - def makeEdges(n: Int): Seq[GraphEdge] = { + def makeEdges(n: Int): Seq[ActiveEdge] = { val nodes = new Array[(PublicKey, PublicKey)](n) for (i <- nodes.indices) { nodes(i) = (randomKey().publicKey, randomKey().publicKey) } - val q = new mutable.Queue[GraphEdge] + val q = new mutable.Queue[ActiveEdge] // One path is shorter to maximise the overlap between the n-shortest paths, they will all be like the shortest path with a single hop changed. q.enqueue(makeEdge(1L, a, nodes(0)._1, 100 msat, 90)) q.enqueue(makeEdge(2L, a, nodes(0)._2, 100 msat, 100)) @@ -1681,12 +1689,12 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution { // \ / \ / \ / // +----> N ---> N N ---> N ----+ - def makeEdges(n: Int): Seq[GraphEdge] = { + def makeEdges(n: Int): Seq[ActiveEdge] = { val nodes = new Array[(PublicKey, PublicKey)](n) for (i <- nodes.indices) { nodes(i) = (randomKey().publicKey, randomKey().publicKey) } - val q = new mutable.Queue[GraphEdge] + val q = new mutable.Queue[ActiveEdge] q.enqueue(makeEdge(1L, a, nodes(0)._1, 100 msat, 100)) q.enqueue(makeEdge(2L, a, nodes(0)._2, 100 msat, 100)) for (i <- 0 until (n - 1)) { @@ -1939,9 +1947,9 @@ object RouteCalculationSpec { maxHtlc: Option[MilliSatoshi] = None, cltvDelta: CltvExpiryDelta = CltvExpiryDelta(0), capacity: Satoshi = DEFAULT_CAPACITY, - balance_opt: Option[MilliSatoshi] = None): GraphEdge = { + balance_opt: Option[MilliSatoshi] = None): ActiveEdge = { val update = makeUpdateShort(ShortChannelId(shortChannelId), nodeId1, nodeId2, feeBase, feeProportionalMillionth, minHtlc, maxHtlc, cltvDelta) - GraphEdge(ChannelDesc(RealShortChannelId(shortChannelId), nodeId1, nodeId2), HopRelayParams.FromAnnouncement(update), capacity, balance_opt) + ActiveEdge(ChannelDesc(RealShortChannelId(shortChannelId), nodeId1, nodeId2), HopRelayParams.FromAnnouncement(update), capacity, balance_opt) } def makeUpdateShort(shortChannelId: ShortChannelId, nodeId1: PublicKey, nodeId2: PublicKey, feeBase: MilliSatoshi, feeProportionalMillionth: Int, minHtlc: MilliSatoshi = DEFAULT_AMOUNT_MSAT, maxHtlc: Option[MilliSatoshi] = None, cltvDelta: CltvExpiryDelta = CltvExpiryDelta(0), timestamp: TimestampSecond = 0 unixsec): ChannelUpdate = @@ -1965,7 +1973,7 @@ object RouteCalculationSpec { def routes2Ids(routes: Seq[Route]): Set[Seq[Long]] = routes.map(route2Ids).toSet - def route2Edges(route: Route): Seq[GraphEdge] = route.hops.map(hop => GraphEdge(ChannelDesc(hop.shortChannelId, hop.nodeId, hop.nextNodeId), hop.params, 0 sat, None)) + def route2Edges(route: Route): Seq[ActiveEdge] = route.hops.map(hop => ActiveEdge(ChannelDesc(hop.shortChannelId, hop.nodeId, hop.nextNodeId), hop.params, 0 sat, None)) def route2Nodes(route: Route): Seq[(PublicKey, PublicKey)] = route.hops.map(hop => (hop.nodeId, hop.nextNodeId)) diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Message.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Message.scala index 13eaec8e69..edbe16c277 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Message.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Message.scala @@ -46,19 +46,21 @@ trait Message { "recipientNode".as[PublicKey](publicKeyUnmarshaller).?, "recipientBlindedRoute".as[Sphinx.RouteBlinding.BlindedRoute](blindedRouteUnmarshaller).?, "intermediateNodes".as[List[PublicKey]](pubkeyListUnmarshaller).?, - "replyPath".as[List[PublicKey]](pubkeyListUnmarshaller).?, + "expectsReply".as[Boolean], "content".as[ByteVector](bytesUnmarshaller)) { - case (Some(recipientNode), None, intermediateNodes, replyPath, userCustomContent) => + case (Some(recipientNode), None, intermediateNodes_opt, expectsReply, userCustomContent) => complete( - eclairApi.sendOnionMessage(intermediateNodes.getOrElse(Nil), + eclairApi.sendOnionMessage( + intermediateNodes_opt, Left(recipientNode), - replyPath, + expectsReply, userCustomContent)) - case (None, Some(recipientBlindedRoute), intermediateNodes, replyPath, userCustomContent) => + case (None, Some(recipientBlindedRoute), intermediateNodes_opt, expectsReply, userCustomContent) => complete( - eclairApi.sendOnionMessage(intermediateNodes.getOrElse(Nil), + eclairApi.sendOnionMessage( + intermediateNodes_opt, Right(recipientBlindedRoute), - replyPath, + expectsReply, userCustomContent)) case (None, None, _, _, _) => reject(MalformedFormFieldRejection("recipientNode", "You must provide recipientNode or recipientBlindedRoute")) diff --git a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala index 66e650f055..e0528dcc4b 100644 --- a/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala +++ b/eclair-node/src/main/scala/fr/acinq/eclair/api/handlers/Payment.scala @@ -104,11 +104,11 @@ trait Payment { } val payOffer: Route = postRequest("payoffer") { implicit t => - formFields(offerFormParam, amountMsatFormParam, "quantity".as[Long].?, "maxAttempts".as[Int].?, "maxFeeFlatSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".?, "pathFindingExperimentName".?, "blocking".as[Boolean].?) { - case (offer, amountMsat, quantity_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, externalId_opt, pathFindingExperimentName_opt, blocking_opt) => + formFields(offerFormParam, amountMsatFormParam, "quantity".as[Long].?, "maxAttempts".as[Int].?, "maxFeeFlatSat".as[Satoshi].?, "maxFeePct".as[Double].?, "externalId".?, "pathFindingExperimentName".?, "connectDirectly".as[Boolean].?, "blocking".as[Boolean].?) { + case (offer, amountMsat, quantity_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, externalId_opt, pathFindingExperimentName_opt, connectDirectly, blocking_opt) => blocking_opt match { - case Some(true) => complete(eclairApi.payOfferBlocking(offer, amountMsat, quantity_opt.getOrElse(1), externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt)) - case _ => complete(eclairApi.payOffer(offer, amountMsat, quantity_opt.getOrElse(1), externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt)) + case Some(true) => complete(eclairApi.payOfferBlocking(offer, amountMsat, quantity_opt.getOrElse(1), externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt, connectDirectly.getOrElse(false))) + case _ => complete(eclairApi.payOffer(offer, amountMsat, quantity_opt.getOrElse(1), externalId_opt, maxAttempts_opt, maxFeeFlat_opt, maxFeePct_opt, pathFindingExperimentName_opt, connectDirectly.getOrElse(false))) } } }