Skip to content

Commit

Permalink
Allow using outgoing channel id for onion messages (#2762)
Browse files Browse the repository at this point in the history
Allow using a short channel id instead of a public key to designate the next node a message should be relayed to.
  • Loading branch information
thomash-acinq committed Nov 2, 2023
1 parent 2879a54 commit ca3f681
Show file tree
Hide file tree
Showing 16 changed files with 236 additions and 101 deletions.
4 changes: 2 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class Setup(val datadir: File,
txPublisherFactory = Channel.SimpleTxPublisherFactory(nodeParams, watcher, bitcoinClient)
channelFactory = Peer.SimpleChannelFactory(nodeParams, watcher, relayer, bitcoinClient, txPublisherFactory)
pendingChannelsRateLimiter = system.spawn(Behaviors.supervise(PendingChannelsRateLimiter(nodeParams, router.toTyped, channels)).onFailure(typed.SupervisorStrategy.resume), name = "pending-channels-rate-limiter")
peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter)
peerFactory = Switchboard.SimplePeerFactory(nodeParams, bitcoinClient, channelFactory, pendingChannelsRateLimiter, register)

switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, peerFactory), "switchboard", SupervisorStrategy.Resume))
_ = switchboard ! Switchboard.Init(channels)
Expand All @@ -376,7 +376,7 @@ class Setup(val datadir: File,

balanceActor = system.spawn(BalanceActor(nodeParams.db, bitcoinClient, channelsListener, nodeParams.balanceCheckInterval), name = "balance-actor")

postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard.toTyped, router.toTyped, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman")
postman = system.spawn(Behaviors.supervise(Postman(nodeParams, switchboard, router.toTyped, register, offerManager)).onFailure(typed.SupervisorStrategy.restart), name = "postman")

kit = Kit(
nodeParams = nodeParams,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package fr.acinq.eclair.channel

import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed
import akka.actor.{Actor, ActorLogging, ActorRef, Props}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
Expand Down Expand Up @@ -77,6 +78,9 @@ class Register() extends Actor with ActorLogging {

case Symbol("channelsTo") => sender() ! channelsTo

case GetNextNodeId(replyTo, shortChannelId) =>
replyTo ! shortIds.get(shortChannelId).flatMap(cid => channelsTo.get(cid))

case fwd@Forward(replyTo, channelId, msg) =>
// for backward compatibility with legacy ask, we use the replyTo as sender
val compatReplyTo = if (replyTo == null) sender() else replyTo.toClassic
Expand Down Expand Up @@ -106,4 +110,6 @@ object Register {
case class ForwardFailure[T](fwd: Forward[T])
case class ForwardShortIdFailure[T](fwd: ForwardShortId[T])
// @formatter:on

case class GetNextNodeId(replyTo: typed.ActorRef[Option[PublicKey]], shortChannelId: ShortChannelId)
}
55 changes: 48 additions & 7 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/MessageRelay.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,31 @@
package fr.acinq.eclair.io

import akka.actor.typed.Behavior
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.scaladsl.adapter.TypedActorRefOps
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.{ActorRef, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.channel.Register
import fr.acinq.eclair.io.Peer.{PeerInfo, PeerInfoResponse}
import fr.acinq.eclair.io.Switchboard.GetPeerInfo
import fr.acinq.eclair.wire.protocol.OnionMessage

object MessageRelay {
// @formatter:off
sealed trait Command
case class RelayMessage(messageId: ByteVector32, switchboard: ActorRef, prevNodeId: PublicKey, nextNodeId: PublicKey, msg: OnionMessage, policy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[Status]]) extends Command
case class RelayMessage(messageId: ByteVector32,
switchboard: ActorRef,
register: ActorRef,
prevNodeId: PublicKey,
nextNode: Either[ShortChannelId, PublicKey],
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]) extends Command
case class WrappedPeerInfo(peerInfo: PeerInfoResponse) extends Command
case class WrappedConnectionResult(result: PeerConnection.ConnectionResult) extends Command
case class WrappedOptionalNodeId(nodeId_opt: Option[PublicKey]) extends Command

sealed trait Status {
val messageId: ByteVector32
Expand All @@ -41,12 +51,15 @@ object MessageRelay {
case class AgainstPolicy(messageId: ByteVector32, policy: RelayPolicy) extends Failure {
override def toString: String = s"Relay prevented by policy $policy"
}
case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure{
case class ConnectionFailure(messageId: ByteVector32, failure: PeerConnection.ConnectionResult.Failure) extends Failure {
override def toString: String = s"Can't connect to peer: ${failure.toString}"
}
case class Disconnected(messageId: ByteVector32) extends Failure{
case class Disconnected(messageId: ByteVector32) extends Failure {
override def toString: String = "Peer is not connected"
}
case class UnknownOutgoingChannel(messageId: ByteVector32, outgoingChannelId: ShortChannelId) extends Failure {
override def toString: String = s"Unknown outgoing channel: $outgoingChannelId"
}

sealed trait RelayPolicy
case object RelayChannelsOnly extends RelayPolicy
Expand All @@ -55,7 +68,37 @@ object MessageRelay {

def apply(): Behavior[Command] = {
Behaviors.receivePartial {
case (context, RelayMessage(messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)) =>
case (context, RelayMessage(messageId, switchboard, register, prevNodeId, Left(outgoingChannelId), msg, policy, replyTo_opt)) =>
register ! Register.GetNextNodeId(context.messageAdapter(WrappedOptionalNodeId), outgoingChannelId)
waitForNextNodeId(messageId, switchboard, prevNodeId, outgoingChannelId, msg, policy, replyTo_opt)
case (context, RelayMessage(messageId, switchboard, _, prevNodeId, Right(nextNodeId), msg, policy, replyTo_opt)) =>
withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)
}
}

def waitForNextNodeId(messageId: ByteVector32,
switchboard: ActorRef,
prevNodeId: PublicKey,
outgoingChannelId: ShortChannelId,
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] =
Behaviors.receivePartial {
case (_, WrappedOptionalNodeId(None)) =>
replyTo_opt.foreach(_ ! UnknownOutgoingChannel(messageId, outgoingChannelId))
Behaviors.stopped
case (context, WrappedOptionalNodeId(Some(nextNodeId))) =>
withNextNodeId(context, messageId, switchboard, prevNodeId, nextNodeId, msg, policy, replyTo_opt)
}

def withNextNodeId(context: ActorContext[Command],
messageId: ByteVector32,
switchboard: ActorRef,
prevNodeId: PublicKey,
nextNodeId: PublicKey,
msg: OnionMessage,
policy: RelayPolicy,
replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] =
policy match {
case RelayChannelsOnly =>
switchboard ! GetPeerInfo(context.messageAdapter(WrappedPeerInfo), prevNodeId)
Expand All @@ -64,8 +107,6 @@ object MessageRelay {
switchboard ! Peer.Connect(nextNodeId, None, context.messageAdapter(WrappedConnectionResult).toClassic, isPersistent = false)
waitForConnection(messageId, msg, replyTo_opt)
}
}
}

def waitForPreviousPeer(messageId: ByteVector32, switchboard: ActorRef, nextNodeId: PublicKey, msg: OnionMessage, replyTo_opt: Option[typed.ActorRef[Status]]): Behavior[Command] = {
Behaviors.receivePartial {
Expand Down
13 changes: 6 additions & 7 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,11 @@ import fr.acinq.eclair.io.MessageRelay.Status
import fr.acinq.eclair.io.Monitoring.Metrics
import fr.acinq.eclair.io.OpenChannelInterceptor.{OpenChannelInitiator, OpenChannelNonInitiator}
import fr.acinq.eclair.io.PeerConnection.KillReason
import fr.acinq.eclair.io.Switchboard.RelayMessage
import fr.acinq.eclair.message.OnionMessages
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, OnionMessage, RoutingMessage, UnknownMessage, Warning}

import scala.concurrent.duration.DurationInt

/**
* This actor represents a logical peer. There is one [[Peer]] per unique remote node id at all time.
*
Expand All @@ -54,7 +51,7 @@ import scala.concurrent.duration.DurationInt
*
* Created by PM on 26/08/2016.
*/
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, switchboard: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {
class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends FSMDiagnosticActorLogging[Peer.State, Peer.Data] {

import Peer._

Expand Down Expand Up @@ -280,8 +277,10 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainP
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(reason) =>
log.debug("dropping message from {}: {}", remoteNodeId.value.toHex, reason.toString)
case OnionMessages.SendMessage(nextNodeId, message) if nodeParams.features.hasFeature(Features.OnionMessages) =>
switchboard ! RelayMessage(randomBytes32(), Some(remoteNodeId), nextNodeId, message, nodeParams.onionMessageConfig.relayPolicy, None)
case OnionMessages.SendMessage(nextNode, message) if nodeParams.features.hasFeature(Features.OnionMessages) =>
val messageId = randomBytes32()
val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, switchboard, register, remoteNodeId, nextNode, message, nodeParams.onionMessageConfig.relayPolicy, None)
case OnionMessages.SendMessage(_, _) =>
log.debug("dropping message from {}: relaying onion messages is disabled", remoteNodeId.value.toHex)
case received: OnionMessages.ReceiveMessage =>
Expand Down Expand Up @@ -459,7 +458,7 @@ object Peer {
context.actorOf(Channel.props(nodeParams, wallet, remoteNodeId, watcher, relayer, txPublisherFactory))
}

def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, pendingChannelsRateLimiter))
def props(nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnchainPubkeyCache, channelFactory: ChannelFactory, switchboard: ActorRef, register: ActorRef, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]): Props = Props(new Peer(nodeParams, remoteNodeId, wallet, channelFactory, switchboard, register, pendingChannelsRateLimiter))

// @formatter:off

Expand Down
13 changes: 2 additions & 11 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ import akka.actor.typed.receptionist.{Receptionist, ServiceKey}
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.scaladsl.adapter.{ClassicActorContextOps, ClassicActorRefOps, ClassicActorSystemOps, TypedActorRefOps}
import akka.actor.{Actor, ActorContext, ActorLogging, ActorRef, OneForOneStrategy, Props, Stash, Status, SupervisorStrategy, typed}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.blockchain.OnchainPubkeyCache
import fr.acinq.eclair.channel.Helpers.Closing
import fr.acinq.eclair.channel._
import fr.acinq.eclair.io.IncomingConnectionsTracker.TrackIncomingConnection
import fr.acinq.eclair.io.MessageRelay.RelayPolicy
import fr.acinq.eclair.io.Peer.{PeerInfoResponse, PeerNotFound}
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.router.Router.RouterConf
import fr.acinq.eclair.wire.protocol.OnionMessage
import fr.acinq.eclair.{NodeParams, SubscriptionsComplete}

/**
Expand Down Expand Up @@ -122,10 +119,6 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory)
}

case GetRouterPeerConf => sender() ! RouterPeerConf(nodeParams.routerConf, nodeParams.peerConnectionConf)

case RelayMessage(messageId, prevNodeId, nextNodeId, dataToRelay, relayPolicy, replyTo) =>
val relay = context.spawn(Behaviors.supervise(MessageRelay()).onFailure(typed.SupervisorStrategy.stop), s"relay-message-$messageId")
relay ! MessageRelay.RelayMessage(messageId, self, prevNodeId.getOrElse(nodeParams.nodeId), nextNodeId, dataToRelay, relayPolicy, replyTo)
}

/**
Expand Down Expand Up @@ -166,9 +159,9 @@ object Switchboard {
def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef
}

case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command]) extends PeerFactory {
case class SimplePeerFactory(nodeParams: NodeParams, wallet: OnchainPubkeyCache, channelFactory: Peer.ChannelFactory, pendingChannelsRateLimiter: typed.ActorRef[PendingChannelsRateLimiter.Command], register: ActorRef) extends PeerFactory {
override def spawn(context: ActorContext, remoteNodeId: PublicKey): ActorRef =
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId))
context.actorOf(Peer.props(nodeParams, remoteNodeId, wallet, channelFactory, context.self, register, pendingChannelsRateLimiter), name = peerActorName(remoteNodeId))
}

def props(nodeParams: NodeParams, peerFactory: PeerFactory) = Props(new Switchboard(nodeParams, peerFactory))
Expand All @@ -183,8 +176,6 @@ object Switchboard {

case object GetRouterPeerConf extends RemoteTypes
case class RouterPeerConf(routerConf: RouterConf, peerConf: PeerConnection.Conf) extends RemoteTypes

case class RelayMessage(messageId: ByteVector32, prevNodeId: Option[PublicKey], nextNodeId: PublicKey, message: OnionMessage, relayPolicy: RelayPolicy, replyTo_opt: Option[typed.ActorRef[MessageRelay.Status]])
// @formatter:on

}
Loading

0 comments on commit ca3f681

Please sign in to comment.