diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala index eb44f4a7a6..d04a5105f5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala @@ -365,7 +365,7 @@ class Setup(val datadir: File, txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcher, bitcoinClient) channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, bitcoinClient, txPublisherFactory) pendingChannelsRateLimiter = system.spawn(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, channels)).onFailure(typed.SupervisorStrategy.resume), name = "pending-channels-rate-limiter") - peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter) + peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register) switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume)) _ = switchboard ! Switchboard.Init(channels) @@ -376,7 +376,7 @@ class Setup(val datadir: File, balanceActor = system.spawn(BalanceActor(nodeParams.db, bitcoinClient, channelsListener, nodeParams.balanceCheckInterval), name = "balance-actor") - postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped, router.toTyped, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") + postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard, router.toTyped, register, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") kit = Kit( nodeParams = nodeParams, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Register.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Register.scala index 5dad716c18..a40f0428bb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Register.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Register.scala @@ -17,6 +17,7 @@ package fr.acinq.eclair.channel import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed import akka.actor.{Actor, ActorLogging, ActorRef, Props} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey @@ -77,6 +78,9 @@ class Register() extends Actor with ActorLogging { case Symbol("channelsTo") => sender() ! channelsTo + case GetNextNodeId(replyTo, shortChannelId) => + replyTo ! shortIds.get(shortChannelId).flatMap(cid => channelsTo.get(cid)) + case fwd@Forward(replyTo, channelId, msg) => // for backward compatibility with legacy ask, we use the replyTo as sender val compatReplyTo = if (replyTo == null) sender() else replyTo.toClassic @@ -106,4 +110,6 @@ object Register { case class ForwardFailure[T](fwd: Forward[T]) case class ForwardShortIdFailure[T](fwd: ForwardShortId[T]) // @formatter:on + + case class GetNextNodeId(replyTo: typed.ActorRef[Option[PublicKey]], shortChannelId: ShortChannelId) } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala index 51fc4f7367..29e3322938 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala @@ -17,11 +17,13 @@ package fr.acinq.eclair.io import akka.actor.typed.Behavior -import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.TypedActorRefOps +import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.{ActorRef, typed} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey +import fr.acinq.eclair.ShortChannelId +import fr.acinq.eclair.channel.Register import fr.acinq.eclair.io.Peer.{PeerInfo, PeerInfoResponse} import fr.acinq.eclair.io.Switchboard.GetPeerInfo import fr.acinq.eclair.wire.protocol.OnionMessage @@ -29,9 +31,17 @@ import fr.acinq.eclair.wire.protocol.OnionMessage object MessageRelay { // @formatter:off sealed trait Command - case class RelayMessage(messageId: ByteVector32, switchboard: ActorRef, prevNodeId: PublicKey, nextNodeId: PublicKey, msg: OnionMessage, policy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[Status]]) extends Command + case class RelayMessage(messageId: ByteVector32, + switchboard: ActorRef, + register: ActorRef, + prevNodeId: PublicKey, + nextNode: Either[ShortChannelId, PublicKey], + msg: OnionMessage, + policy: RelayPolicy, + replyTo_opt: Option[typed.ActorRef[Status]]) extends Command case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command + case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command sealed trait Status { val messageId: ByteVector32 @@ -41,12 +51,15 @@ object MessageRelay { case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure { override def toString: String = s"Relay prevented by policy $policy" } - case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure{ + case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure { override def toString: String = s"Can't connect to peer: ${failure.toString}" } - case class Disconnected(messageId: ByteVector32) extends Failure{ + case class Disconnected(messageId: ByteVector32) extends Failure { override def toString: String = "Peer is not connected" } + case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure { + override def toString: String = s"Unknown outgoing channel: $outgoingChannelId" + } sealed trait RelayPolicy case object RelayChannelsOnly extends RelayPolicy @@ -55,7 +68,37 @@ object MessageRelay { def apply(): Behavior[Command] = { Behaviors.receivePartial { - case (context, RelayMessage(messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)) => + case (context, RelayMessage(messageId, switchboard, register, prevNodeId, Left(outgoingChannelId), msg, policy, replyTo_opt)) => + register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId) + waitForNextNodeId(messageId, switchboard, prevNodeId, outgoingChannelId, msg, policy, replyTo_opt) + case (context, RelayMessage(messageId, switchboard, _, prevNodeId, Right(nextNodeId), msg, policy, replyTo_opt)) => + withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt) + } + } + + def waitForNextNodeId(messageId: ByteVector32, + switchboard: ActorRef, + prevNodeId: PublicKey, + outgoingChannelId: ShortChannelId, + msg: OnionMessage, + policy: RelayPolicy, + replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = + Behaviors.receivePartial { + case (_, WrappedOptionalNodeId(None)) => + replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId)) + Behaviors.stopped + case (context, WrappedOptionalNodeId(Some(nextNodeId))) => + withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt) + } + + def withNextNodeId(context: ActorContext[Command], + messageId: ByteVector32, + switchboard: ActorRef, + prevNodeId: PublicKey, + nextNodeId: PublicKey, + msg: OnionMessage, + policy: RelayPolicy, + replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = policy match { case RelayChannelsOnly => switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId) @@ -64,8 +107,6 @@ object MessageRelay { switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false) waitForConnection(messageId, msg, replyTo_opt) } - } - } def waitForPreviousPeer(messageId: ByteVector32, switchboard: ActorRef, nextNodeId: PublicKey, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = { Behaviors.receivePartial { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 73e78c1849..7f2cc36a2f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -36,14 +36,11 @@ import fr.acinq.eclair.io.MessageRelay.Status import fr.acinq.eclair.io.Monitoring.Metrics import fr.acinq.eclair.io.OpenChannelInterceptor.{OpenChannelInitiator, OpenChannelNonInitiator} import fr.acinq.eclair.io.PeerConnection.KillReason -import fr.acinq.eclair.io.Switchboard.RelayMessage import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, OnionMessage, RoutingMessage, UnknownMessage, Warning} -import scala.concurrent.duration.DurationInt - /** * This actor represents a logical peer. There is one [[Peer]] per unique remote node id at all time. * @@ -54,7 +51,7 @@ import scala.concurrent.duration.DurationInt * * Created by PM on 26/08/2016. */ -class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, switchboard: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] { +class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] { import Peer._ @@ -280,8 +277,10 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP OnionMessages.process(nodeParams.privateKey, msg) match { case OnionMessages.DropMessage(reason) => log.debug("dropping message from {}: {}", remoteNodeId.value.toHex, reason.toString) - case OnionMessages.SendMessage(nextNodeId, message) if nodeParams.features.hasFeature(Features.OnionMessages) => - switchboard ! RelayMessage(randomBytes32(), Some(remoteNodeId), nextNodeId, message, nodeParams.onionMessageConfig.relayPolicy, None) + case OnionMessages.SendMessage(nextNode, message) if nodeParams.features.hasFeature(Features.OnionMessages) => + val messageId = randomBytes32() + val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") + relay ! MessageRelay.RelayMessage(messageId, switchboard, register, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None) case OnionMessages.SendMessage(_, _) => log.debug("dropping message from {}: relaying onion messages is disabled", remoteNodeId.value.toHex) case received: OnionMessages.ReceiveMessage => @@ -459,7 +458,7 @@ object Peer { context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, txPublisherFactory)) } - def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, pendingChannelsRateLimiter)) + def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, pendingChannelsRateLimiter)) // @formatter:off diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index c008dec020..ca49947034 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -20,17 +20,14 @@ import akka.actor.typed.receptionist.{Receptionist, ServiceKey} import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.scaladsl.adapter.{ClassicActorContextOps, ClassicActorRefOps, ClassicActorSystemOps, TypedActorRefOps} import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, OneForOneStrategy, Props, Stash, Status, SupervisorStrategy, typed} -import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.blockchain.OnchainPubkeyCache import fr.acinq.eclair.channel.Helpers.Closing import fr.acinq.eclair.channel._ import fr.acinq.eclair.io.IncomingConnectionsTracker.TrackIncomingConnection -import fr.acinq.eclair.io.MessageRelay.RelayPolicy import fr.acinq.eclair.io.Peer.{PeerInfoResponse, PeerNotFound} import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.router.Router.RouterConf -import fr.acinq.eclair.wire.protocol.OnionMessage import fr.acinq.eclair.{NodeParams, SubscriptionsComplete} /** @@ -122,10 +119,6 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory) } case GetRouterPeerConf => sender() ! RouterPeerConf(nodeParams.routerConf, nodeParams.peerConnectionConf) - - case RelayMessage(messageId, prevNodeId, nextNodeId, dataToRelay, relayPolicy, replyTo) => - val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") - relay ! MessageRelay.RelayMessage(messageId, self, prevNodeId.getOrElse(nodeParams.nodeId), nextNodeId, dataToRelay, relayPolicy, replyTo) } /** @@ -166,9 +159,9 @@ object Switchboard { def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef } - case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends PeerFactory { + case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef) extends PeerFactory { override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef = - context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId)) + context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId)) } def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory)) @@ -183,8 +176,6 @@ object Switchboard { case object GetRouterPeerConf extends RemoteTypes case class RouterPeerConf(routerConf: RouterConf, peerConf: PeerConnection.Conf) extends RemoteTypes - - case class RelayMessage(messageId: ByteVector32, prevNodeId: Option[PublicKey], nextNodeId: PublicKey, message: OnionMessage, relayPolicy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[MessageRelay.Status]]) // @formatter:on } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala index e6a83b3d19..397cc46a66 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/OnionMessages.scala @@ -17,6 +17,7 @@ package fr.acinq.eclair.message import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey} +import fr.acinq.eclair.ShortChannelId import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.io.MessageRelay.RelayPolicy import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, IntermediatePayload} @@ -43,7 +44,14 @@ object OnionMessages { timeout: FiniteDuration, maxAttempts: Int) - case class IntermediateNode(nodeId: PublicKey, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) + case class IntermediateNode(nodeId: PublicKey, outgoingChannel_opt: Option[ShortChannelId] = None, padding: Option[ByteVector] = None, customTlvs: Set[GenericTlv] = Set.empty) { + def toTlvStream(nextNodeId: PublicKey, nextBlinding_opt: Option[PublicKey] = None): TlvStream[RouteBlindingEncryptedDataTlv] = + TlvStream(Set[Option[RouteBlindingEncryptedDataTlv]]( + padding.map(Padding), + outgoingChannel_opt.map(OutgoingChannelId).orElse(Some(OutgoingNodeId(nextNodeId))), + nextBlinding_opt.map(NextBlinding) + ).flatten, customTlvs) + } // @formatter:off sealed trait Destination @@ -63,20 +71,20 @@ object OnionMessages { } // @formatter:on - private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], nextTlvs: Set[RouteBlindingEncryptedDataTlv]): Seq[ByteVector] = { + private def buildIntermediatePayloads(intermediateNodes: Seq[IntermediateNode], lastNodeId: PublicKey, lastBlinding_opt: Option[PublicKey] = None): Seq[ByteVector] = { if (intermediateNodes.isEmpty) { Nil } else { - (intermediateNodes.tail.map(node => Set(OutgoingNodeId(node.nodeId))) :+ nextTlvs) - .zip(intermediateNodes).map { case (tlvs, hop) => TlvStream(hop.padding.map(Padding).toSet[RouteBlindingEncryptedDataTlv] ++ tlvs, hop.customTlvs) } - .map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes) + val intermediatePayloads = intermediateNodes.dropRight(1).zip(intermediateNodes.tail).map { case (hop, nextNode) => hop.toTlvStream(nextNode.nodeId) } + val lastPayload = intermediateNodes.last.toTlvStream(lastNodeId, lastBlinding_opt) + (intermediatePayloads :+ lastPayload).map(tlvs => RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(tlvs).require.bytes) } } def buildRoute(blindingSecret: PrivateKey, intermediateNodes: Seq[IntermediateNode], recipient: Recipient): Sphinx.RouteBlinding.BlindedRoute = { - val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, Set(OutgoingNodeId(recipient.nodeId))) + val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, recipient.nodeId) val tlvs: Set[RouteBlindingEncryptedDataTlv] = Set(recipient.padding.map(Padding), recipient.pathId.map(PathId)).flatten val lastPayload = RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec.encode(TlvStream(tlvs, recipient.customTlvs)).require.bytes Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ recipient.nodeId, intermediatePayloads :+ lastPayload).route @@ -100,7 +108,7 @@ object OnionMessages { } case BlindedPath(route) if intermediateNodes.isEmpty => Some(route) case BlindedPath(route) => - val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, Set(OutgoingNodeId(route.introductionNodeId), NextBlinding(route.blindingKey))) + val intermediatePayloads = buildIntermediatePayloads(intermediateNodes, route.introductionNodeId, Some(route.blindingKey)) val routePrefix = Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId), intermediatePayloads).route Some(Sphinx.RouteBlinding.BlindedRoute(routePrefix.introductionNodeId, routePrefix.blindingKey, routePrefix.blindedNodes ++ route.blindedNodes)) } @@ -153,7 +161,7 @@ object OnionMessages { // @formatter:off sealed trait Action case class DropMessage(reason: DropReason) extends Action - case class SendMessage(nextNodeId: PublicKey, message: OnionMessage) extends Action + case class SendMessage(nextNode: Either[ShortChannelId, PublicKey], message: OnionMessage) extends Action case class ReceiveMessage(finalPayload: FinalPayload) extends Action sealed trait DropReason @@ -199,7 +207,8 @@ object OnionMessages { case Left(f) => DropMessage(f) case Right(DecodedEncryptedData(blindedPayload, nextBlinding)) => nextPacket_opt match { case Some(nextPacket) => validateRelayPayload(payload, blindedPayload, nextBlinding, nextPacket) match { - case SendMessage(nextNodeId, nextMsg) if nextNodeId == privateKey.publicKey => process(privateKey, nextMsg) + case SendMessage(Right(nextNodeId), nextMsg) if nextNodeId == privateKey.publicKey => process(privateKey, nextMsg) + case SendMessage(Left(outgoingChannelId), nextMsg) if outgoingChannelId == ShortChannelId.toSelf => process(privateKey, nextMsg) case action => action } case None => validateFinalPayload(payload, blindedPayload) @@ -216,7 +225,7 @@ object OnionMessages { private def validateRelayPayload(payload: TlvStream[OnionMessagePayloadTlv], blindedPayload: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey, nextPacket: OnionRoutingPacket): Action = { IntermediatePayload.validate(payload, blindedPayload, nextBlinding) match { case Left(f) => DropMessage(CannotDecodeBlindedPayload(f.failureMessage.message)) - case Right(relayPayload) => SendMessage(relayPayload.nextNodeId, OnionMessage(nextBlinding, nextPacket)) + case Right(relayPayload) => SendMessage(relayPayload.nextNode, OnionMessage(nextBlinding, nextPacket)) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala index 759f3f325a..268ee2beb6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/message/Postman.scala @@ -22,14 +22,15 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.scalacompat.ByteVector32 import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.io.{MessageRelay, Switchboard} +import fr.acinq.eclair.io.MessageRelay +import fr.acinq.eclair.io.MessageRelay.RelayPolicy import fr.acinq.eclair.message.OnionMessages.{Destination, RoutingStrategy} import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.router.Router import fr.acinq.eclair.router.Router.{MessageRoute, MessageRouteNotFound, MessageRouteResponse} import fr.acinq.eclair.wire.protocol.MessageOnion.{FinalPayload, InvoiceRequestPayload} -import fr.acinq.eclair.wire.protocol.{OnionMessagePayloadTlv, TlvStream} -import fr.acinq.eclair.{NodeParams, randomBytes32, randomKey} +import fr.acinq.eclair.wire.protocol.{OnionMessage, OnionMessagePayloadTlv, TlvStream} +import fr.acinq.eclair.{NodeParams, ShortChannelId, randomBytes32, randomKey} import scala.collection.mutable @@ -62,7 +63,7 @@ object Postman { case class MessageFailed(reason: String) extends MessageStatus // @formatter:on - def apply(nodeParams: NodeParams, switchboard: ActorRef[Switchboard.RelayMessage], router: ActorRef[Router.MessageRouteRequest], offerManager: typed.ActorRef[OfferManager.RequestInvoice]): Behavior[Command] = { + def apply(nodeParams: NodeParams, switchboard: akka.actor.ActorRef, router: ActorRef[Router.MessageRouteRequest], register: akka.actor.ActorRef, offerManager: typed.ActorRef[OfferManager.RequestInvoice]): Behavior[Command] = { Behaviors.setup(context => { context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[OnionMessages.ReceiveMessage](r => WrappedMessage(r.finalPayload))) @@ -85,7 +86,7 @@ object Postman { } Behaviors.same case SendMessage(destination, routingStrategy, messageContent, expectsReply, replyTo) => - val child = context.spawnAnonymous(SendingMessage(nodeParams, switchboard, router, context.self, destination, messageContent, routingStrategy, expectsReply, replyTo)) + val child = context.spawnAnonymous(SendingMessage(nodeParams, router, context.self, switchboard, register, destination, messageContent, routingStrategy, expectsReply, replyTo)) child ! SendingMessage.SendMessage Behaviors.same case Subscribe(pathId, replyTo) => @@ -112,25 +113,27 @@ object SendingMessage { // @formatter:on def apply(nodeParams: NodeParams, - switchboard: ActorRef[Switchboard.RelayMessage], router: ActorRef[Router.MessageRouteRequest], postman: ActorRef[Postman.Command], + switchboard: akka.actor.ActorRef, + register: akka.actor.ActorRef, destination: Destination, message: TlvStream[OnionMessagePayloadTlv], routingStrategy: RoutingStrategy, expectsReply: Boolean, replyTo: ActorRef[Postman.OnionMessageResponse]): Behavior[Command] = { Behaviors.setup(context => { - val actor = new SendingMessage(nodeParams, switchboard, router, postman, destination, message, routingStrategy, expectsReply, replyTo, context) + val actor = new SendingMessage(nodeParams, router, postman, switchboard, register, destination, message, routingStrategy, expectsReply, replyTo, context) actor.start() }) } } private class SendingMessage(nodeParams: NodeParams, - switchboard: ActorRef[Switchboard.RelayMessage], router: ActorRef[Router.MessageRouteRequest], postman: ActorRef[Postman.Command], + switchboard: akka.actor.ActorRef, + register: akka.actor.ActorRef, destination: Destination, message: TlvStream[OnionMessagePayloadTlv], routingStrategy: RoutingStrategy, @@ -193,7 +196,8 @@ private class SendingMessage(nodeParams: NodeParams, replyTo ! Postman.MessageFailed(failure.toString) Behaviors.stopped case Right((nextNodeId, message)) => - switchboard ! Switchboard.RelayMessage(messageId, None, nextNodeId, message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus))) + val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId") + relay ! MessageRelay.RelayMessage(messageId, switchboard, register, nodeParams.nodeId, Right(nextNodeId), message, MessageRelay.RelayAll, Some(context.messageAdapter[MessageRelay.Status](SendingStatus))) waitForSent() } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala index 3ff6daf813..618583120c 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/MessageOnion.scala @@ -17,7 +17,7 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey -import fr.acinq.eclair.UInt64 +import fr.acinq.eclair.{ShortChannelId, UInt64} import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.payment.Bolt12Invoice import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, InvalidTlvPayload, MissingRequiredTlv} @@ -73,7 +73,9 @@ object MessageOnion { /** Per-hop payload for an intermediate node. */ case class IntermediatePayload(records: TlvStream[OnionMessagePayloadTlv], blindedRecords: TlvStream[RouteBlindingEncryptedDataTlv], nextBlinding: PublicKey) extends PerHopPayload { - val nextNodeId: PublicKey = blindedRecords.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].get.nodeId + val nextNode: Either[ShortChannelId, PublicKey] = + blindedRecords.get[RouteBlindingEncryptedDataTlv.OutgoingNodeId].map(outgoingNodeId => Right(outgoingNodeId.nodeId)) + .getOrElse(Left(blindedRecords.get[RouteBlindingEncryptedDataTlv.OutgoingChannelId].get.shortChannelId)) } object IntermediatePayload { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala index 7fa8c9ed0e..178901f3e6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/RouteBlinding.scala @@ -73,7 +73,8 @@ object BlindedRouteData { import RouteBlindingEncryptedDataTlv._ def validateMessageRelayData(records: TlvStream[RouteBlindingEncryptedDataTlv]): Either[InvalidTlvPayload, TlvStream[RouteBlindingEncryptedDataTlv]] = { - if (records.get[OutgoingNodeId].isEmpty) return Left(MissingRequiredTlv(UInt64(4))) + if (records.get[OutgoingNodeId].isEmpty && records.get[OutgoingChannelId].isEmpty) return Left(MissingRequiredTlv(UInt64(4))) + if (records.get[OutgoingNodeId].isDefined && records.get[OutgoingChannelId].isDefined) return Left(ForbiddenTlv(UInt64(4))) if (records.get[PathId].isDefined) return Left(ForbiddenTlv(UInt64(6))) if (records.get[PaymentRelay].isDefined) return Left(ForbiddenTlv(UInt64(10))) if (records.get[PaymentConstraints].isDefined) return Left(ForbiddenTlv(UInt64(12))) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala index aec45f29d2..4df68766cc 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/fixtures/MinimalNodeFixture.scala @@ -100,12 +100,12 @@ object MinimalNodeFixture extends Assertions with Eventually with IntegrationPat val txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcherTyped, bitcoinClient) val channelFactory = Peer.SimpleChannelFactory(nodeParams, watcherTyped, relayer, wallet, txPublisherFactory) val pendingChannelsRateLimiter = system.spawnAnonymous(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, Seq())).onFailure(typed.SupervisorStrategy.resume)) - val peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory, pendingChannelsRateLimiter) + val peerFactory = Switchboard.SimplePeerFactory(nodeParams, wallet, channelFactory, pendingChannelsRateLimiter, register) val switchboard = system.actorOf(Switchboard.props(nodeParams, peerFactory), "switchboard") val paymentFactory = PaymentInitiator.SimplePaymentFactory(nodeParams, router, register) val paymentInitiator = system.actorOf(PaymentInitiator.props(nodeParams, paymentFactory), "payment-initiator") val channels = nodeParams.db.channels.listLocalChannels() - val postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped, router.toTyped, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") + val postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard, router.toTyped, register, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman") switchboard ! Switchboard.Init(channels) relayer ! PostRestartHtlcCleaner.Init(channels) readyListener.expectMsgAllOf( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala index a11c907caa..fb57e6b07c 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/integration/basic/payment/OfferPaymentSpec.scala @@ -27,18 +27,19 @@ import fr.acinq.eclair.FeatureSupport.Optional import fr.acinq.eclair.Features.{KeySend, RouteBlinding} import fr.acinq.eclair.channel.{DATA_NORMAL, RealScidStatus} import fr.acinq.eclair.integration.basic.fixtures.MinimalNodeFixture -import fr.acinq.eclair.integration.basic.fixtures.MinimalNodeFixture.{connect, getChannelData, getRouterData, knownFundingTxs, nodeParamsFor, openChannel, watcherAutopilot} +import fr.acinq.eclair.integration.basic.fixtures.MinimalNodeFixture.{connect, getChannelData, getPeerChannels, getRouterData, knownFundingTxs, nodeParamsFor, openChannel, watcherAutopilot} import fr.acinq.eclair.integration.basic.fixtures.composite.ThreeNodesFixture import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient, buildRoute} import fr.acinq.eclair.payment._ import fr.acinq.eclair.payment.offer.OfferManager import fr.acinq.eclair.payment.receive.MultiPartHandler.{DummyBlindedHop, ReceivingRoute} import fr.acinq.eclair.payment.send.PaymentInitiator.{SendPaymentToNode, SendSpontaneousPayment} -import fr.acinq.eclair.payment.send.{OfferPayment, PaymentLifecycle} +import fr.acinq.eclair.payment.send.{ClearRecipient, OfferPayment, PaymentLifecycle} +import fr.acinq.eclair.router.Router import fr.acinq.eclair.testutils.FixtureSpec import fr.acinq.eclair.wire.protocol.OfferTypes.{Offer, OfferPaths} import fr.acinq.eclair.wire.protocol.{IncorrectOrUnknownPaymentDetails, InvalidOnionBlinding} -import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, randomBytes32, randomKey} +import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, Features, MilliSatoshi, MilliSatoshiLong, ShortChannelId, randomBytes32, randomKey} import org.scalatest.concurrent.IntegrationPatience import org.scalatest.{Tag, TestData} import scodec.bits.HexStringSyntax @@ -138,8 +139,9 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { val recipientKey = randomKey() val pathId = randomBytes32() val offerPaths = routes.map(route => { - route.nodes.dropRight(1).map(IntermediateNode(_)) - buildRoute(randomKey(), route.nodes.dropRight(1).map(IntermediateNode(_)), Recipient(route.nodes.last, Some(pathId))) + val ourNodeId = route.nodes.last + val intermediateNodes = route.nodes.dropRight(1).map(IntermediateNode(_)) ++ route.dummyHops.map(_ => IntermediateNode(ourNodeId)) + buildRoute(randomKey(), intermediateNodes, Recipient(ourNodeId, Some(pathId))) }) val offer = Offer(None, "test", recipientKey.publicKey, Features.empty, recipient.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(offerPaths))) val handler = recipient.system.spawnAnonymous(offerHandler(amount, routes)) @@ -350,4 +352,28 @@ class OfferPaymentSpec extends FixtureSpec with IntegrationPatience { assert(failure.t == PaymentLifecycle.UpdateMalformedException) } + test("send payment a->b->c compact offer") { f => + import f._ + + val probe = TestProbe() + val amount = 25_000_000 msat + val recipientKey = randomKey() + val pathId = randomBytes32() + + val blindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId), IntermediateNode(carol.nodeId)), Recipient(carol.nodeId, Some(pathId))) + val offer = Offer(None, "test", recipientKey.publicKey, Features.empty, carol.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(Seq(blindedRoute)))) + val scid_bc = getPeerChannels(bob, carol.nodeId).head.data.asInstanceOf[DATA_NORMAL].shortIds.real.toOption.get + val compactBlindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(bob.nodeId, Some(scid_bc)), IntermediateNode(carol.nodeId, Some(ShortChannelId.toSelf))), Recipient(carol.nodeId, Some(pathId))) + val compactOffer = Offer(None, "test", recipientKey.publicKey, Features.empty, carol.nodeParams.chainHash, additionalTlvs = Set(OfferPaths(Seq(compactBlindedRoute)))) + assert(compactOffer.toString.length < offer.toString.length) + + val receivingRoute = ReceivingRoute(Seq(bob.nodeId, carol.nodeId), maxFinalExpiryDelta) + val handler = carol.system.spawnAnonymous(offerHandler(amount, Seq(receivingRoute))) + carol.offerManager ! OfferManager.RegisterOffer(compactOffer, recipientKey, Some(pathId), handler) + val offerPayment = alice.system.spawnAnonymous(OfferPayment(alice.nodeParams, alice.postman, alice.paymentInitiator)) + val sendPaymentConfig = OfferPayment.SendPaymentConfig(None, connectDirectly = false, maxAttempts = 1, alice.routeParams, blocking = true) + offerPayment ! OfferPayment.PayOffer(probe.ref, compactOffer, amount, 1, sendPaymentConfig) + val payment = verifyPaymentSuccess(compactOffer, amount, probe.expectMsgType[PaymentEvent]) + assert(payment.parts.length == 1) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala index 035302789f..d85ffc782a 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/MessageRelaySpec.scala @@ -23,13 +23,14 @@ import akka.testkit.TestProbe import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.TestConstants.{Alice, Bob} +import fr.acinq.eclair.channel.Register import fr.acinq.eclair.io.MessageRelay._ import fr.acinq.eclair.io.Peer.{PeerInfo, PeerNotFound} import fr.acinq.eclair.io.Switchboard.GetPeerInfo import fr.acinq.eclair.message.OnionMessages import fr.acinq.eclair.message.OnionMessages.{IntermediateNode, Recipient} import fr.acinq.eclair.wire.protocol.TlvStream -import fr.acinq.eclair.{randomBytes32, randomKey} +import fr.acinq.eclair.{ShortChannelId, randomBytes32, randomKey} import org.scalatest.Outcome import org.scalatest.funsuite.FixtureAnyFunSuiteLike @@ -40,16 +41,17 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val aliceId: PublicKey = Alice.nodeParams.nodeId val bobId: PublicKey = Bob.nodeParams.nodeId - case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) + case class FixtureParam(relay: ActorRef[Command], switchboard: TestProbe, register: TestProbe, peerConnection: TypedProbe[Nothing], peer: TypedProbe[Peer.RelayOnionMessage], probe: TypedProbe[Status]) override def withFixture(test: OneArgTest): Outcome = { val switchboard = TestProbe("switchboard")(system.classicSystem) + val register = TestProbe("register")(system.classicSystem) val peerConnection = TypedProbe[Nothing]("peerConnection") val peer = TypedProbe[Peer.RelayOnionMessage]("peer") val probe = TypedProbe[Status]("probe") val relay = testKit.spawn(MessageRelay()) try { - withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, peerConnection, peer, probe))) + withFixture(test.toNoArgTest(FixtureParam(relay, switchboard, register, peerConnection, peer, probe))) } finally { testKit.stop(relay) } @@ -60,7 +62,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() - relay ! RelayMessage(messageId, switchboard.ref, randomKey().publicKey, bobId, message, RelayAll, None) + relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Right(bobId), message, RelayAll, None) val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] assert(connectToNextPeer.nodeId == bobId) @@ -73,7 +75,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() - relay ! RelayMessage(messageId, switchboard.ref, randomKey().publicKey, bobId, message, RelayAll, None) + relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Right(bobId), message, RelayAll, None) val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] assert(connectToNextPeer.nodeId == bobId) @@ -86,7 +88,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() - relay ! RelayMessage(messageId, switchboard.ref, randomKey().publicKey, bobId, message, RelayAll, Some(probe.ref)) + relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Right(bobId), message, RelayAll, Some(probe.ref)) val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] assert(connectToNextPeer.nodeId == bobId) @@ -100,7 +102,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey - relay ! RelayMessage(messageId, switchboard.ref, previousNodeId, bobId, message, RelayChannelsOnly, Some(probe.ref)) + relay ! RelayMessage(messageId, switchboard.ref, register.ref, previousNodeId, Right(bobId), message, RelayChannelsOnly, Some(probe.ref)) val getPeerInfo = switchboard.expectMsgType[GetPeerInfo] assert(getPeerInfo.remoteNodeId == previousNodeId) @@ -116,7 +118,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey - relay ! RelayMessage(messageId, switchboard.ref, previousNodeId, bobId, message, RelayChannelsOnly, Some(probe.ref)) + relay ! RelayMessage(messageId, switchboard.ref, register.ref, previousNodeId, Right(bobId), message, RelayChannelsOnly, Some(probe.ref)) val getPeerInfo1 = switchboard.expectMsgType[GetPeerInfo] assert(getPeerInfo1.remoteNodeId == previousNodeId) @@ -136,7 +138,7 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) val messageId = randomBytes32() val previousNodeId = randomKey().publicKey - relay ! RelayMessage(messageId, switchboard.ref, previousNodeId, bobId, message, RelayChannelsOnly, None) + relay ! RelayMessage(messageId, switchboard.ref, register.ref, previousNodeId, Right(bobId), message, RelayChannelsOnly, None) val getPeerInfo1 = switchboard.expectMsgType[GetPeerInfo] assert(getPeerInfo1.remoteNodeId == previousNodeId) @@ -148,4 +150,22 @@ class MessageRelaySpec extends ScalaTestWithActorTestKit(ConfigFactory.load("app assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) } + + test("next node specified with channel id") { f => + import f._ + + val Right((_, message)) = OnionMessages.buildMessage(randomKey(), randomKey(), randomKey(), Seq(IntermediateNode(aliceId)), Recipient(bobId, None), TlvStream.empty) + val messageId = randomBytes32() + val scid = ShortChannelId(123456L) + relay ! RelayMessage(messageId, switchboard.ref, register.ref, randomKey().publicKey, Left(scid), message, RelayAll, None) + + val getNextNodeId = register.expectMsgType[Register.GetNextNodeId] + assert(getNextNodeId.shortChannelId == scid) + getNextNodeId.replyTo ! Some(bobId) + + val connectToNextPeer = switchboard.expectMsgType[Peer.Connect] + assert(connectToNextPeer.nodeId == bobId) + connectToNextPeer.replyTo ! PeerConnection.ConnectionResult.AlreadyConnected(peerConnection.ref.toClassic, peer.ref.toClassic) + assert(peer.expectMessageType[Peer.RelayOnionMessage].msg == message) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 452a8bb9f6..d1545d2621 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -50,7 +50,7 @@ class PeerSpec extends FixtureSpec { override implicit val patienceConfig: PatienceConfig = PatienceConfig(timeout = 30 seconds, interval = 1 second) - case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, system: ActorSystem, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, channel: TestProbe, switchboard: TestProbe, mockLimiter: ActorRef) { + case class FixtureParam(nodeParams: NodeParams, remoteNodeId: PublicKey, system: ActorSystem, peer: TestFSMRef[Peer.State, Peer.Data, Peer], peerConnection: TestProbe, channel: TestProbe, switchboard: TestProbe, register: TestProbe, mockLimiter: ActorRef) { implicit val implicitSystem: ActorSystem = system def cleanup(): Unit = TestKit.shutdownActorSystem(system) @@ -63,6 +63,7 @@ class PeerSpec extends FixtureSpec { val peerConnection = TestProbe() val channel = TestProbe() val switchboard = TestProbe() + val register = TestProbe() import com.softwaremill.quicklens._ val aliceParams = TestConstants.Alice.nodeParams @@ -98,9 +99,9 @@ class PeerSpec extends FixtureSpec { case _ => KeepRunning }) - val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, wallet, FakeChannelFactory(channel), switchboard.ref, mockLimiter.ref)) + val peer: TestFSMRef[Peer.State, Peer.Data, Peer] = TestFSMRef(new Peer(aliceParams, remoteNodeId, wallet, FakeChannelFactory(channel), switchboard.ref, register.ref, mockLimiter.ref)) - FixtureParam(aliceParams, remoteNodeId, system, peer, peerConnection, channel, switchboard, mockLimiter.ref) + FixtureParam(aliceParams, remoteNodeId, system, peer, peerConnection, channel, switchboard, register, mockLimiter.ref) } def cleanupFixture(fixture: FixtureParam): Unit = fixture.cleanup() diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala index 6fe3d4b415..7f2101ab31 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/OnionMessagesSpec.scala @@ -27,7 +27,7 @@ import fr.acinq.eclair.wire.protocol.OnionMessagePayloadTlv.EncryptedData import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataCodecs.blindedRouteDataCodec import fr.acinq.eclair.wire.protocol.RouteBlindingEncryptedDataTlv._ import fr.acinq.eclair.wire.protocol.{GenericTlv, OnionMessage, OnionMessagePayloadTlv, OnionRoutingCodecs, RouteBlindingEncryptedDataCodecs, RouteBlindingEncryptedDataTlv, TlvStream} -import fr.acinq.eclair.{UInt64, randomBytes, randomKey} +import fr.acinq.eclair.{ShortChannelId, UInt64, randomBytes, randomKey} import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalatest.funsuite.AnyFunSuite @@ -115,13 +115,13 @@ class OnionMessagesSpec extends AnyFunSuite { // Checking that the onion is relayed properly process(alice, onionForAlice) match { - case SendMessage(nextNodeId, onionForBob) => + case SendMessage(Right(nextNodeId), onionForBob) => assert(nextNodeId == bob.publicKey) process(bob, onionForBob) match { - case SendMessage(nextNodeId, onionForCarol) => + case SendMessage(Right(nextNodeId), onionForCarol) => assert(nextNodeId == carol.publicKey) process(carol, onionForCarol) match { - case SendMessage(nextNodeId, onionForDave) => + case SendMessage(Right(nextNodeId), onionForDave) => assert(nextNodeId == dave.publicKey) process(dave, onionForDave) match { case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(hex"01234567")) @@ -234,10 +234,10 @@ class OnionMessagesSpec extends AnyFunSuite { // Checking that the onion is relayed properly process(alice, messageForAlice) match { - case SendMessage(nextNodeId, onionForBob) => + case SendMessage(Right(nextNodeId), onionForBob) => assert(nextNodeId == bob.publicKey) process(bob, onionForBob) match { - case SendMessage(nextNodeId, onionForCarol) => + case SendMessage(Right(nextNodeId), onionForCarol) => assert(nextNodeId == carol.publicKey) process(carol, onionForCarol) match { case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(pathId)) @@ -286,7 +286,7 @@ class OnionMessagesSpec extends AnyFunSuite { Recipient(nodeKey.publicKey, Some(ByteVector.fromValidHex((json \ "path_id").extract[String])), (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) def makeIntermediateNode(nodeKey: PrivateKey, json: JValue): IntermediateNode = - IntermediateNode(nodeKey.publicKey, (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) + IntermediateNode(nodeKey.publicKey, None, (json \ "padding").extract[Option[String]].map(ByteVector.fromValidHex(_)), getCustomTlvs(json)) val blindingSecretBob = PrivateKey(ByteVector32.fromValidHex(((testVector \ "generate" \ "hops")(1) \ "blinding_secret").extract[String])) val pathId = ByteVector.fromValidHex(((testVector \ "generate" \ "hops")(3) \ "tlvs" \ "path_id").extract[String]) @@ -328,13 +328,13 @@ class OnionMessagesSpec extends AnyFunSuite { // Checking that the onion is relayed properly process(alice, message) match { - case SendMessage(nextNodeId, onionForBob) => + case SendMessage(Right(nextNodeId), onionForBob) => assert(nextNodeId == bob.publicKey) process(bob, onionForBob) match { - case SendMessage(nextNodeId, onionForCarol) => + case SendMessage(Right(nextNodeId), onionForCarol) => assert(nextNodeId == carol.publicKey) process(carol, onionForCarol) match { - case SendMessage(nextNodeId, onionForDave) => + case SendMessage(Right(nextNodeId), onionForDave) => assert(nextNodeId == dave.publicKey) process(dave, onionForDave) match { case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(pathId)) @@ -347,4 +347,33 @@ class OnionMessagesSpec extends AnyFunSuite { case x => fail(x.toString) } } + + test("route with channel ids") { + val nodeKey = randomKey() + val alice = randomKey() + val alice2bob = ShortChannelId(1) + val bob = randomKey() + val bob2carol = ShortChannelId(2) + val carol = randomKey() + val sessionKey = randomKey() + val blindingSecret = randomKey() + val pathId = randomBytes(64) + val Right((_, messageForAlice)) = buildMessage(nodeKey, sessionKey, blindingSecret, IntermediateNode(alice.publicKey, outgoingChannel_opt = Some(alice2bob)) :: IntermediateNode(bob.publicKey, outgoingChannel_opt = Some(bob2carol)) :: Nil, Recipient(carol.publicKey, Some(pathId)), TlvStream.empty) + + // Checking that the onion is relayed properly + process(alice, messageForAlice) match { + case SendMessage(Left(outgoingChannelId), onionForBob) => + assert(outgoingChannelId == alice2bob) + process(bob, onionForBob) match { + case SendMessage(Left(outgoingChannelId), onionForCarol) => + assert(outgoingChannelId == bob2carol) + process(carol, onionForCarol) match { + case ReceiveMessage(finalPayload) => assert(finalPayload.pathId_opt.contains(pathId)) + case x => fail(x.toString) + } + case x => fail(x.toString) + } + case x => fail(x.toString) + } + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala index db1e093900..e9c48d4c41 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/message/PostmanSpec.scala @@ -19,11 +19,13 @@ package fr.acinq.eclair.message import akka.actor.testkit.typed.scaladsl.{ScalaTestWithActorTestKit, TestProbe} import akka.actor.typed.ActorRef import akka.actor.typed.eventstream.EventStream +import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import com.typesafe.config.ConfigFactory import fr.acinq.bitcoin.scalacompat.Block +import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute} import fr.acinq.eclair.io.MessageRelay.{Disconnected, Sent} -import fr.acinq.eclair.io.Switchboard.RelayMessage +import fr.acinq.eclair.io.{Peer, PeerConnection} import fr.acinq.eclair.message.OnionMessages.RoutingStrategy.FindRoute import fr.acinq.eclair.message.OnionMessages.{BlindedPath, IntermediateNode, ReceiveMessage, Recipient, buildMessage, buildRoute} import fr.acinq.eclair.message.Postman._ @@ -39,15 +41,16 @@ import scodec.bits.HexStringSyntax class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike { - case class FixtureParam(postman: ActorRef[Command], nodeParams: NodeParams, messageSender: TestProbe[OnionMessageResponse], switchboard: TestProbe[RelayMessage], offerManager: TestProbe[RequestInvoice], router: TestProbe[MessageRouteRequest]) + case class FixtureParam(postman: ActorRef[Command], nodeParams: NodeParams, messageSender: TestProbe[OnionMessageResponse], switchboard: TestProbe[Any], offerManager: TestProbe[RequestInvoice], router: TestProbe[MessageRouteRequest]) override def withFixture(test: OneArgTest): Outcome = { val nodeParams = TestConstants.Alice.nodeParams val messageSender = TestProbe[OnionMessageResponse]("messageSender") - val switchboard = TestProbe[RelayMessage]("switchboard") + val switchboard = TestProbe[Any]("switchboard") val offerManager = TestProbe[RequestInvoice]("offerManager") val router = TestProbe[MessageRouteRequest]("router") - val postman = testKit.spawn(Postman(nodeParams, switchboard.ref, router.ref, offerManager.ref)) + val register = TestProbe[Any]("register") + val postman = testKit.spawn(Postman(nodeParams, switchboard.ref.toClassic, router.ref, register.ref.toClassic, offerManager.ref)) try { withFixture(test.toNoArgTest(FixtureParam(postman, nodeParams, messageSender, switchboard, offerManager, router))) } finally { @@ -55,6 +58,15 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat } } + private def expectRelayToConnected(switchboard: TestProbe[Any], recipientKey: PublicKey): Peer.RelayOnionMessage = { + val Peer.Connect(nextNodeId, _, replyTo, _) = switchboard.expectMessageType[Peer.Connect] + assert(nextNodeId == recipientKey) + val peerConnection = TestProbe[Any]() + val peer = TestProbe[Any]() + replyTo ! PeerConnection.ConnectionResult.AlreadyConnected(peerConnection.ref.toClassic, peer.ref.toClassic) + peer.expectMessageType[Peer.RelayOnionMessage] + } + test("message forwarded only once") { f => import f._ @@ -67,8 +79,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(target == recipientKey.publicKey) waitingForRoute ! MessageRoute(Seq.empty, target) - val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] - assert(nextNodeId == recipientKey.publicKey) + val Peer.RelayOnionMessage(messageId, message, Some(replyTo)) = expectRelayToConnected(switchboard, recipientKey.publicKey) replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) @@ -96,8 +107,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(target == recipientKey.publicKey) waitingForRoute ! MessageRoute(Seq.empty, target) - val RelayMessage(messageId, _, nextNodeId, _, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] - assert(nextNodeId == recipientKey.publicKey) + val Peer.RelayOnionMessage(messageId, _, Some(replyTo)) = expectRelayToConnected(switchboard, recipientKey.publicKey) replyTo ! Disconnected(messageId) messageSender.expectMessage(MessageFailed("Peer is not connected")) @@ -116,8 +126,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(target == recipientKey.publicKey) waitingForRoute ! MessageRoute(Seq.empty, target) - val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] - assert(nextNodeId == recipientKey.publicKey) + val Peer.RelayOnionMessage(messageId, message, Some(replyTo)) = expectRelayToConnected(switchboard, recipientKey.publicKey) replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) @@ -144,8 +153,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(target == recipientKey.publicKey) waitingForRoute ! MessageRoute(Seq.empty, target) - val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] - assert(nextNodeId == recipientKey.publicKey) + val Peer.RelayOnionMessage(messageId, message, Some(replyTo)) = expectRelayToConnected(switchboard, recipientKey.publicKey) replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) @@ -163,8 +171,7 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val blindedRoute = buildRoute(randomKey(), Seq(IntermediateNode(nodeParams.nodeId)), Recipient(recipientKey.publicKey, None)) postman ! SendMessage(BlindedPath(blindedRoute), FindRoute, TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(33), hex"abcd"))), expectsReply = false, messageSender.ref) - val RelayMessage(messageId, _, nextNodeId, message, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] - assert(nextNodeId == recipientKey.publicKey) + val Peer.RelayOnionMessage(messageId, message, Some(replyTo)) = expectRelayToConnected(switchboard, recipientKey.publicKey) replyTo ! Sent(messageId) val ReceiveMessage(finalPayload) = OnionMessages.process(recipientKey, message) assert(finalPayload.records.unknown == Set(GenericTlv(UInt64(33), hex"abcd"))) @@ -199,14 +206,13 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat assert(target == d.publicKey) waitingForRoute ! MessageRoute(Seq(a.publicKey, b.publicKey, c.publicKey), target) - val RelayMessage(messageId, _, next1, message1, _, Some(replyTo)) = switchboard.expectMessageType[RelayMessage] - assert(next1 == a.publicKey) + val Peer.RelayOnionMessage(messageId, message1, Some(replyTo)) = expectRelayToConnected(switchboard, a.publicKey) replyTo ! Sent(messageId) - val OnionMessages.SendMessage(next2, message2) = OnionMessages.process(a, message1) + val OnionMessages.SendMessage(Right(next2), message2) = OnionMessages.process(a, message1) assert(next2 == b.publicKey) - val OnionMessages.SendMessage(next3, message3) = OnionMessages.process(b, message2) + val OnionMessages.SendMessage(Right(next3), message3) = OnionMessages.process(b, message2) assert(next3 == c.publicKey) - val OnionMessages.SendMessage(next4, message4) = OnionMessages.process(c, message3) + val OnionMessages.SendMessage(Right(next4), message4) = OnionMessages.process(c, message3) assert(next4 == d.publicKey) val OnionMessages.ReceiveMessage(payload) = OnionMessages.process(d, message4) assert(payload.records.unknown == Set(GenericTlv(UInt64(11), hex"012345"))) @@ -218,11 +224,11 @@ class PostmanSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("applicat val Right((next5, reply)) = OnionMessages.buildMessage(d, randomKey(), randomKey(), Nil, OnionMessages.BlindedPath(replyPath), TlvStream(Set.empty[OnionMessagePayloadTlv], Set(GenericTlv(UInt64(13), hex"6789")))) assert(next5 == c.publicKey) - val OnionMessages.SendMessage(next6, message6) = OnionMessages.process(c, reply) + val OnionMessages.SendMessage(Right(next6), message6) = OnionMessages.process(c, reply) assert(next6 == b.publicKey) - val OnionMessages.SendMessage(next7, message7) = OnionMessages.process(b, message6) + val OnionMessages.SendMessage(Right(next7), message7) = OnionMessages.process(b, message6) assert(next7 == a.publicKey) - val OnionMessages.SendMessage(next8, message8) = OnionMessages.process(a, message7) + val OnionMessages.SendMessage(Right(next8), message8) = OnionMessages.process(a, message7) assert(next8 == nodeParams.nodeId) val OnionMessages.ReceiveMessage(replyPayload) = OnionMessages.process(nodeParams.privateKey, message8) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala index c34c3fec41..db27a01ddd 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/MessageOnionCodecsSpec.scala @@ -30,7 +30,7 @@ class MessageOnionCodecsSpec extends AnyFunSuiteLike { assert(decoded == expected) val nextNodeId = randomKey().publicKey val Right(payload) = IntermediatePayload.validate(decoded, TlvStream(RouteBlindingEncryptedDataTlv.OutgoingNodeId(nextNodeId)), randomKey().publicKey) - assert(payload.nextNodeId == nextNodeId) + assert(payload.nextNode == Right(nextNodeId)) val encoded = perHopPayloadCodec.encode(expected).require.bytes assert(encoded == bin) }