From 0780fc2216aba1012b8936e55a080c871d83b3c8 Mon Sep 17 00:00:00 2001 From: Fabrice Drouin Date: Thu, 22 Aug 2019 18:58:13 +0200 Subject: [PATCH] Extended Queries: use TLV format for optional data (#1072) * Extended Queries: use TLV format for optional data Optional query extensions now use TLV instead of a custom format. Flags are encoded as varint instead of bytes as originally proposed. With the current proposal they will all fit on a single byte, but will be much easier to extends this way. * Move query message TLVs to their own namespace We add one new class for each TLV type, with specific TLV types, and encapsulate codecs. * Optional TLVs are represented as a list, not an optional list TLVs that extend regular LN messages can be represented as a TlvStream and not an Option[TlvStream] since we don't need to explicitely terminate the stream (either by preprending its length or using a specific terminator) as we do in Onion TLVs. No TLVs simply means that the TLV stream is empty. * Update to match BOLT PR Checksums in ReplyChannelRange now have the same encoding as short channel ids and timestamps: one byte for the encoding type (uncompressed or zlib) followed by encoded data. * TLV Stream: Implement a generic "get" method for TLV fields If a have a TLV stream of type MyTLV which is a subtype of TLV, and MyTLV1 and MYTLV2 are both subtypes of MyTLV then we can use stream.get[MyTLV1] to get the TLV record of type MYTLV1 (if any) in our TLV stream. * Extended range queries: Implement latest BOLT changes Checksums are just transmitted as a raw array, with optional compression as it would be useless here. * Use extended range queries on regtest and testnet We will use them on mainnet as soon as https://github.com/lightningnetwork/lightning-rfc/pull/557 has been merged. * Address review comments * Router: rework handling of ReplyChannelRange We remove the ugly and inefficient zipWithIndex we had before * NodeParams: move fee base check to its proper place * Router: minor cleanup --- .../scala/fr/acinq/eclair/NodeParams.scala | 9 +- .../main/scala/fr/acinq/eclair/io/Peer.scala | 13 +- .../scala/fr/acinq/eclair/router/Router.scala | 127 +++++++++++------- .../fr/acinq/eclair/wire/CommonCodecs.scala | 6 +- .../eclair/wire/LightningMessageCodecs.scala | 70 ++++------ .../eclair/wire/LightningMessageTypes.scala | 61 ++++----- .../eclair/wire/QueryChannelRangeTlv.scala | 37 +++++ .../eclair/wire/QueryShortChannelIdsTlv.scala | 41 ++++++ .../eclair/wire/ReplyChannelRangeTlv.scala | 64 +++++++++ .../scala/fr/acinq/eclair/wire/TlvTypes.scala | 13 +- .../scala/fr/acinq/eclair/io/PeerSpec.scala | 12 +- .../router/ChannelRangeQueriesSpec.scala | 35 ++--- .../acinq/eclair/router/RoutingSyncSpec.scala | 8 +- .../wire/ExtendedQueriesCodecsSpec.scala | 94 +++++++++++++ .../wire/LightningMessageCodecsSpec.scala | 104 +++++++++----- .../fr/acinq/eclair/wire/TlvCodecsSpec.scala | 6 + 16 files changed, 509 insertions(+), 191 deletions(-) create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryChannelRangeTlv.scala create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryShortChannelIdsTlv.scala create mode 100644 eclair-core/src/main/scala/fr/acinq/eclair/wire/ReplyChannelRangeTlv.scala create mode 100644 eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 2daa66510f..093d3b8a64 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -33,6 +33,7 @@ import fr.acinq.eclair.router.RouterConf import fr.acinq.eclair.tor.Socks5ProxyParams import fr.acinq.eclair.wire.{Color, NodeAddress} import scodec.bits.ByteVector + import scala.collection.JavaConversions._ import scala.concurrent.duration.FiniteDuration @@ -76,7 +77,6 @@ case class NodeParams(keyManager: KeyManager, routerConf: RouterConf, socksProxy_opt: Option[Socks5ProxyParams], maxPaymentAttempts: Int) { - val privateKey = keyManager.nodeKey.privateKey val nodeId = keyManager.nodeId } @@ -186,6 +186,11 @@ object NodeParams { claimMainBlockTarget = config.getInt("on-chain-fees.target-blocks.claim-main") ) + val feeBase = MilliSatoshi(config.getInt("fee-base-msat")) + // fee base is in msat but is encoded on 32 bits and not 64 in the BOLTs, which is why it has + // to be below 0x100000000 msat which is about 42 mbtc + require(feeBase <= MilliSatoshi(0xFFFFFFFFL), "fee-base-msat must be below 42 mbtc") + NodeParams( keyManager = keyManager, alias = nodeAlias, @@ -209,7 +214,7 @@ object NodeParams { toRemoteDelayBlocks = config.getInt("to-remote-delay-blocks"), maxToLocalDelayBlocks = config.getInt("max-to-local-delay-blocks"), minDepthBlocks = config.getInt("mindepth-blocks"), - feeBase = MilliSatoshi(config.getInt("fee-base-msat")), + feeBase = feeBase, feeProportionalMillionth = config.getInt("fee-proportional-millionths"), reserveToFundingRatio = config.getDouble("reserve-to-funding-ratio"), maxReserveToFundingRatio = config.getDouble("max-reserve-to-funding-ratio"), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 32c4d52cf2..6f76f5dfa6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -25,8 +25,7 @@ import akka.event.Logging.MDC import akka.util.Timeout import com.google.common.net.HostAndPort import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.bitcoin.{ByteVector32, DeterministicWallet, Protocol, Satoshi} -import fr.acinq.eclair +import fr.acinq.bitcoin.{Block, ByteVector32, DeterministicWallet, Protocol, Satoshi} import fr.acinq.eclair.blockchain.EclairWallet import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler @@ -155,7 +154,15 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A if (remoteHasChannelRangeQueriesOptional || remoteHasChannelRangeQueriesMandatory) { // if they support channel queries, always ask for their filter // NB: we always add extended info; if peer doesn't understand them it will ignore them - router ! SendChannelQuery(remoteNodeId, d.transport, flags_opt = Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS)) + + // README: for now we do not activate extended queries on mainnet + val flags_opt = nodeParams.chainHash match { + case Block.RegtestGenesisBlock.hash | Block.TestnetGenesisBlock.hash => + log.info("using extended range queries") + Some(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL)) + case _ => None + } + router ! SendChannelQuery(remoteNodeId, d.transport, flags_opt = flags_opt) } // let's bring existing/requested channels online diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index c6bf7cdb8b..3d80cba7e0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -16,7 +16,7 @@ package fr.acinq.eclair.router -import java.util.zip.Adler32 +import java.util.zip.CRC32 import akka.Done import akka.actor.{ActorRef, Props, Status} @@ -35,9 +35,9 @@ import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{RichWeight, WeightRatios} import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ - import shapeless.HNil +import scala.annotation.tailrec import scala.collection.immutable.{SortedMap, TreeMap} import scala.collection.{SortedSet, mutable} import scala.compat.Platform @@ -76,13 +76,13 @@ case class RouteResponse(hops: Seq[Hop], ignoreNodes: Set[PublicKey], ignoreChan } case class ExcludeChannel(desc: ChannelDesc) // this is used when we get a TemporaryChannelFailure, to give time for the channel to recover (note that exclusions are directed) case class LiftChannelExclusion(desc: ChannelDesc) -case class SendChannelQuery(remoteNodeId: PublicKey, to: ActorRef, flags_opt: Option[ExtendedQueryFlags]) +case class SendChannelQuery(remoteNodeId: PublicKey, to: ActorRef, flags_opt: Option[QueryChannelRangeTlv]) case object GetRoutingState case class RoutingState(channels: Iterable[ChannelAnnouncement], updates: Iterable[ChannelUpdate], nodes: Iterable[NodeAnnouncement]) case class Stash(updates: Map[ChannelUpdate, Set[ActorRef]], nodes: Map[NodeAnnouncement, Set[ActorRef]]) case class Rebroadcast(channels: Map[ChannelAnnouncement, Set[ActorRef]], updates: Map[ChannelUpdate, Set[ActorRef]], nodes: Map[NodeAnnouncement, Set[ActorRef]]) -case class ShortChannelIdAndFlag(shortChannelId: ShortChannelId, flag: Byte) +case class ShortChannelIdAndFlag(shortChannelId: ShortChannelId, flag: Long) case class Sync(pending: List[RoutingMessage], total: Int) @@ -431,7 +431,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ // ask for everything // we currently send only one query_channel_range message per peer, when we just (re)connected to it, so we don't // have to worry about sending a new query_channel_range when another query is still in progress - val query = QueryChannelRange(nodeParams.chainHash, firstBlockNum = 0, numberOfBlocks = Int.MaxValue, extendedQueryFlags_opt = flags_opt) + val query = QueryChannelRange(nodeParams.chainHash, firstBlockNum = 0L, numberOfBlocks = Int.MaxValue.toLong, TlvStream(flags_opt.toList)) log.info("sending query_channel_range={}", query) remote ! query @@ -508,37 +508,62 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ val shortChannelIds: SortedSet[ShortChannelId] = d.channels.keySet.filter(keep(firstBlockNum, numberOfBlocks, _)) log.info("replying with {} items for range=({}, {})", shortChannelIds.size, firstBlockNum, numberOfBlocks) split(shortChannelIds) - .foreach(chunk => - transport ! ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, + .foreach(chunk => { + val (timestamps, checksums) = routingMessage.queryFlags_opt match { + case Some(extension) if extension.wantChecksums | extension.wantTimestamps => + // we always compute timestamps and checksums even if we don't need both, overhead is negligible + val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(d.channels, d.updates)).unzip + val encodedTimestamps = if (extension.wantTimestamps) Some(ReplyChannelRangeTlv.EncodedTimestamps(EncodingType.UNCOMPRESSED, timestamps)) else None + val encodedChecksums = if (extension.wantChecksums) Some(ReplyChannelRangeTlv.EncodedChecksums(checksums)) else None + (encodedTimestamps, encodedChecksums) + case _ => (None, None) + } + val reply = ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, complete = 1, shortChannelIds = EncodedShortChannelIds(EncodingType.UNCOMPRESSED, chunk.shortChannelIds), - extendedQueryFlags_opt = extendedQueryFlags_opt, - extendedInfo_opt = extendedQueryFlags_opt map { - case ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS => ExtendedInfo(chunk.shortChannelIds.map(getChannelDigestInfo(d.channels, d.updates))) - })) + timestamps = timestamps, + checksums = checksums) + transport ! reply + }) stay - case Event(PeerRoutingMessage(transport, remoteNodeId, routingMessage@ReplyChannelRange(chainHash, _, _, _, shortChannelIds, extendedQueryFlags_opt, extendedInfo_opt)), d) => + case Event(PeerRoutingMessage(transport, remoteNodeId, routingMessage@ReplyChannelRange(chainHash, _, _, _, shortChannelIds, _)), d) => sender ! TransportHandler.ReadAck(routingMessage) - val shortChannelIdAndFlags = shortChannelIds.array - .zipWithIndex - .map { case (shortChannelId: ShortChannelId, idx) => ShortChannelIdAndFlag(shortChannelId, computeFlag(d.channels, d.updates)(shortChannelId, extendedInfo_opt.map(_.array(idx)))) } - .filter(_.flag != 0) + + @tailrec + def loop(ids: List[ShortChannelId], timestamps: List[ReplyChannelRangeTlv.Timestamps], checksums: List[ReplyChannelRangeTlv.Checksums], acc: List[ShortChannelIdAndFlag] = List.empty[ShortChannelIdAndFlag]): List[ShortChannelIdAndFlag] = { + ids match { + case Nil => acc.reverse + case head :: tail => + val flag = computeFlag(d.channels, d.updates)(head, timestamps.headOption, checksums.headOption) + // 0 means nothing to query, just don't include it + val acc1 = if (flag != 0) ShortChannelIdAndFlag(head, flag) :: acc else acc + loop(tail, timestamps.drop(1), checksums.drop(1), acc1) + } + } + + val timestamps_opt = routingMessage.timestamps_opt.map(_.timestamps).getOrElse(List.empty[ReplyChannelRangeTlv.Timestamps]) + val checksums_opt = routingMessage.checksums_opt.map(_.checksums).getOrElse(List.empty[ReplyChannelRangeTlv.Checksums]) + + val shortChannelIdAndFlags = loop(shortChannelIds.array, timestamps_opt, checksums_opt) + val (channelCount, updatesCount) = shortChannelIdAndFlags.foldLeft((0, 0)) { case ((c, u), ShortChannelIdAndFlag(_, flag)) => - val c1 = c + (if (QueryFlagTypes.includeAnnouncement(flag)) 1 else 0) - val u1 = u + (if (QueryFlagTypes.includeUpdate1(flag)) 1 else 0) + (if (QueryFlagTypes.includeUpdate2(flag)) 1 else 0) + val c1 = c + (if (QueryShortChannelIdsTlv.QueryFlagType.includeAnnouncement(flag)) 1 else 0) + val u1 = u + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate1(flag)) 1 else 0) + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate2(flag)) 1 else 0) (c1, u1) } - log.info(s"received reply_channel_range with {} channels, we're missing {} channel announcements and {} updates, format={} queryFlags=${extendedQueryFlags_opt.getOrElse("n/a")}", shortChannelIds.array.size, channelCount, updatesCount, shortChannelIds.encoding) + log.info(s"received reply_channel_range with {} channels, we're missing {} channel announcements and {} updates, format={}", shortChannelIds.array.size, channelCount, updatesCount, shortChannelIds.encoding) // we update our sync data to this node (there may be multiple channel range responses and we can only query one set of ids at a time) val replies = shortChannelIdAndFlags .grouped(SHORTID_WINDOW) .map(chunk => QueryShortChannelIds(chainHash, shortChannelIds = EncodedShortChannelIds(shortChannelIds.encoding, chunk.map(_.shortChannelId)), - queryFlags_opt = extendedQueryFlags_opt map { - case _ => EncodedQueryFlags(shortChannelIds.encoding, chunk.map(_.flag)) - })) + if (routingMessage.timestamps_opt.isDefined || routingMessage.checksums_opt.isDefined) + TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(shortChannelIds.encoding, chunk.map(_.flag))) + else + TlvStream.empty + )) .toList val (sync1, replynow_opt) = updateSync(d.sync, remoteNodeId, replies) // we only send a rely right away if there were no pending requests @@ -554,16 +579,16 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ case ((c, u), (shortChannelId, idx)) => var c1 = c var u1 = u - val flag = queryFlags_opt.map(_.array(idx)).getOrElse(QueryFlagTypes.INCLUDE_ALL) + val flag = routingMessage.queryFlags_opt.map(_.array(idx)).getOrElse(QueryShortChannelIdsTlv.QueryFlagType.INCLUDE_ALL) d.channels.get(shortChannelId) match { case None => log.warning("received query for shortChannelId={} that we don't have", shortChannelId) case Some(ca) => - if (QueryFlagTypes.includeAnnouncement(flag)) { + if (QueryShortChannelIdsTlv.QueryFlagType.includeAnnouncement(flag)) { transport ! ca c1 = c1 + 1 } - if (QueryFlagTypes.includeUpdate1(flag)) d.updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2)).foreach { u => transport ! u; u1 = u1 + 1 } - if (QueryFlagTypes.includeUpdate2(flag)) d.updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1)).foreach { u => transport ! u; u1 = u1 + 1 } + if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate1(flag)) d.updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2)).foreach { u => transport ! u; u1 = u1 + 1 } + if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate2(flag)) d.updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1)).foreach { u => transport ! u; u1 = u1 + 1 } } (c1, u1) } @@ -721,7 +746,7 @@ class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[ // when we're sending updates to ourselves (transport_opt, remoteNodeId_opt) match { case (Some(transport), Some(remoteNodeId)) => - val query = QueryShortChannelIds(u.chainHash, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(u.shortChannelId)), queryFlags_opt = None) + val query = QueryShortChannelIds(u.chainHash, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(u.shortChannelId)), TlvStream.empty) d.sync.get(remoteNodeId) match { case Some(sync) => // we already have a pending request to that node, let's add this channel to the list and we'll get it later @@ -828,25 +853,41 @@ object Router { height >= firstBlockNum && height <= (firstBlockNum + numberOfBlocks) } - def computeFlag(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])(shortChannelId: ShortChannelId, theirInfo_opt: Option[TimestampsAndChecksums]): Byte = { - var flag = 0 - theirInfo_opt match { - case Some(theirInfo) if channels.contains(shortChannelId) => - val ourInfo = Router.getChannelDigestInfo(channels, updates)(shortChannelId) + def computeFlag(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])( + shortChannelId: ShortChannelId, + timestamps_opt: Option[ReplyChannelRangeTlv.Timestamps], + checksums_opt: Option[ReplyChannelRangeTlv.Checksums]): Long = { + import QueryShortChannelIdsTlv.QueryFlagType + var flag = 0L + (timestamps_opt, checksums_opt) match { + case (Some(theirTimestamps), Some(theirChecksums)) if channels.contains(shortChannelId) => + val (ourTimestamps, ourChecksums) = Router.getChannelDigestInfo(channels, updates)(shortChannelId) // we request their channel_update if all those conditions are met: // - it is more recent than ours // - it is different from ours, or it is the same but ours is about to be stale // - it is not stale itself - if (ourInfo.timestamp1 < theirInfo.timestamp1 && (ourInfo.checksum1 != theirInfo.checksum1 || isAlmostStale(ourInfo.timestamp1)) && !isStale(theirInfo.timestamp1)) flag = flag | QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_1 - if (ourInfo.timestamp2 < theirInfo.timestamp2 && (ourInfo.checksum2 != theirInfo.checksum2 || isAlmostStale(ourInfo.timestamp1)) && !isStale(theirInfo.timestamp2)) flag = flag | QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_2 - case None if channels.contains(shortChannelId) => + if (ourTimestamps.timestamp1 < theirTimestamps.timestamp1 && (ourChecksums.checksum1 != theirChecksums.checksum1 || isAlmostStale(ourTimestamps.timestamp1)) && !isStale(theirTimestamps.timestamp1)) flag = flag | QueryFlagType.INCLUDE_CHANNEL_UPDATE_1 + if (ourTimestamps.timestamp2 < theirTimestamps.timestamp2 && (ourChecksums.checksum2 != theirChecksums.checksum2 || isAlmostStale(ourTimestamps.timestamp1)) && !isStale(theirTimestamps.timestamp2)) flag = flag | QueryFlagType.INCLUDE_CHANNEL_UPDATE_2 + case (Some(theirTimestamps), None) if channels.contains(shortChannelId) => + val (ourTimestamps, _) = Router.getChannelDigestInfo(channels, updates)(shortChannelId) + // we request their channel_update if all those conditions are met: + // - it is more recent than ours + // - it is not stale itself + if (ourTimestamps.timestamp1 < theirTimestamps.timestamp1 && !isStale(theirTimestamps.timestamp1)) flag = flag | QueryFlagType.INCLUDE_CHANNEL_UPDATE_1 + if (ourTimestamps.timestamp2 < theirTimestamps.timestamp2 && !isStale(theirTimestamps.timestamp2)) flag = flag | QueryFlagType.INCLUDE_CHANNEL_UPDATE_2 + case (None, Some(theirChecksums)) if channels.contains(shortChannelId) => + val (_, ourChecksums) = Router.getChannelDigestInfo(channels, updates)(shortChannelId) + // this should not happen as we will not ask for checksums without asking for timestamps too + if (ourChecksums.checksum1 != theirChecksums.checksum1 && theirChecksums.checksum1 != 0) flag = flag | QueryFlagType.INCLUDE_CHANNEL_UPDATE_1 + if (ourChecksums.checksum2 != theirChecksums.checksum2 && theirChecksums.checksum2 != 0) flag = flag | QueryFlagType.INCLUDE_CHANNEL_UPDATE_2 + case (None, None) if channels.contains(shortChannelId) => // we know this channel: we only request their channel updates - flag = QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_1 | QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_2 + flag = QueryFlagType.INCLUDE_CHANNEL_UPDATE_1 | QueryFlagType.INCLUDE_CHANNEL_UPDATE_2 case _ => // we don't know this channel: we request everything - flag = QueryFlagTypes.INCLUDE_CHANNEL_ANNOUNCEMENT | QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_1 | QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_2 + flag = QueryFlagType.INCLUDE_CHANNEL_ANNOUNCEMENT | QueryFlagType.INCLUDE_CHANNEL_UPDATE_1 | QueryFlagType.INCLUDE_CHANNEL_UPDATE_2 } - flag.toByte + flag } /** @@ -900,7 +941,7 @@ object Router { timestamp } - def getChannelDigestInfo(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])(shortChannelId: ShortChannelId): TimestampsAndChecksums = { + def getChannelDigestInfo(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])(shortChannelId: ShortChannelId): (ReplyChannelRangeTlv.Timestamps, ReplyChannelRangeTlv.Checksums) = { val c = channels(shortChannelId) val u1_opt = updates.get(ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2)) val u2_opt = updates.get(ChannelDesc(c.shortChannelId, c.nodeId2, c.nodeId1)) @@ -908,17 +949,13 @@ object Router { val timestamp2 = u2_opt.map(_.timestamp).getOrElse(0L) val checksum1 = u1_opt.map(getChecksum).getOrElse(0L) val checksum2 = u2_opt.map(getChecksum).getOrElse(0L) - TimestampsAndChecksums( - timestamp1 = timestamp1, - checksum1 = checksum1, - timestamp2 = timestamp2, - checksum2 = checksum2) + (ReplyChannelRangeTlv.Timestamps(timestamp1 = timestamp1, timestamp2 = timestamp2), ReplyChannelRangeTlv.Checksums(checksum1 = checksum1, checksum2 = checksum2)) } def getChecksum(u: ChannelUpdate): Long = { import u._ val data = serializationResult(LightningMessageCodecs.channelUpdateChecksumCodec.encode(shortChannelId :: messageFlags :: channelFlags :: cltvExpiryDelta :: htlcMinimumMsat :: feeBaseMsat :: feeProportionalMillionths :: htlcMaximumMsat :: HNil)) - val checksum = new Adler32() + val checksum = new CRC32() checksum.update(data.toArray) checksum.getValue } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index fc749c8fe8..b0cb2da720 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -19,8 +19,8 @@ package fr.acinq.eclair.wire import java.net.{Inet4Address, Inet6Address, InetAddress} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} -import fr.acinq.eclair.crypto.Mac32 import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} +import fr.acinq.eclair.crypto.Mac32 import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, UInt64} import org.apache.commons.codec.binary.Base32 import scodec.bits.{BitVector, ByteVector} @@ -57,6 +57,10 @@ object CommonCodecs { val satoshi: Codec[Satoshi] = uint64overflow.xmapc(l => Satoshi(l))(_.toLong) val millisatoshi: Codec[MilliSatoshi] = uint64overflow.xmapc(l => MilliSatoshi(l))(_.amount) + // this is needed because some millisatoshi values are encoded on 32 bits in the BOLTs + // this codec will fail if the amount does not fit on 32 bits + val millisatoshi32: Codec[MilliSatoshi] = uint32.xmapc(l => MilliSatoshi(l))(_.amount) + /** * We impose a minimal encoding on some values (such as varint and truncated int) to ensure that signed hashes can be * re-computed correctly. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index 0f185fbcb9..9c753925f3 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -16,8 +16,8 @@ package fr.acinq.eclair.wire +import fr.acinq.eclair.wire import fr.acinq.eclair.wire.CommonCodecs._ -import fr.acinq.eclair.{MilliSatoshi, wire} import scodec.Codec import scodec.codecs._ @@ -188,12 +188,11 @@ object LightningMessageCodecs { ("channelFlags" | byte) :: ("cltvExpiryDelta" | uint16) :: ("htlcMinimumMsat" | millisatoshi) :: - ("feeBaseMsat" | uint32.xmapc(l => MilliSatoshi(l))(_.amount)) :: + ("feeBaseMsat" | millisatoshi32) :: ("feeProportionalMillionths" | uint32) :: ("htlcMaximumMsat" | conditional((messageFlags & 1) != 0, millisatoshi)) }) - val channelUpdateWitnessCodec = ("chainHash" | bytes32) :: ("shortChannelId" | shortchannelid) :: @@ -202,7 +201,7 @@ object LightningMessageCodecs { ("channelFlags" | byte) :: ("cltvExpiryDelta" | uint16) :: ("htlcMinimumMsat" | millisatoshi) :: - ("feeBaseMsat" | uint32.xmapc(l => MilliSatoshi(l))(_.amount)) :: + ("feeBaseMsat" | millisatoshi32) :: ("feeProportionalMillionths" | uint32) :: ("htlcMaximumMsat" | conditional((messageFlags & 1) != 0, millisatoshi)) :: ("unknownFields" | bytes) @@ -217,51 +216,38 @@ object LightningMessageCodecs { .\(0) { case a@EncodedShortChannelIds(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(shortchannelid)).as[EncodedShortChannelIds]) .\(1) { case a@EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(shortchannelid))).as[EncodedShortChannelIds]) - val encodedQueryFlagsCodec: Codec[EncodedQueryFlags] = - discriminated[EncodedQueryFlags].by(byte) - .\(0) { case a@EncodedQueryFlags(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(byte)).as[EncodedQueryFlags]) - .\(1) { case a@EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(byte))).as[EncodedQueryFlags]) - - val queryShortChannelIdsCodec: Codec[QueryShortChannelIds] = ( - ("chainHash" | bytes32) :: - ("shortChannelIds" | variableSizeBytes(uint16, encodedShortChannelIdsCodec)) :: - ("queryFlags_opt" | optional(bitsRemaining, variableSizeBytes(uint16, encodedQueryFlagsCodec))) + val queryShortChannelIdsCodec: Codec[QueryShortChannelIds] = { + Codec( + ("chainHash" | bytes32) :: + ("shortChannelIds" | variableSizeBytes(uint16, encodedShortChannelIdsCodec)) :: + ("tlvStream" | QueryShortChannelIdsTlv.codec) ).as[QueryShortChannelIds] + } val replyShortChanelIdsEndCodec: Codec[ReplyShortChannelIdsEnd] = ( ("chainHash" | bytes32) :: ("complete" | byte) ).as[ReplyShortChannelIdsEnd] - val extendedQueryFlagsCodec: Codec[ExtendedQueryFlags] = - discriminated[ExtendedQueryFlags].by(byte) - .typecase(1, provide(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS)) - - val queryChannelRangeCodec: Codec[QueryChannelRange] = ( - ("chainHash" | bytes32) :: - ("firstBlockNum" | uint32) :: - ("numberOfBlocks" | uint32) :: - ("optionExtendedQueryFlags" | optional(bitsRemaining, extendedQueryFlagsCodec)) - ).as[QueryChannelRange] - - val timestampsAndChecksumsCodec: Codec[TimestampsAndChecksums] = ( - ("timestamp1" | uint32) :: - ("timestamp2" | uint32) :: - ("checksum1" | uint32) :: - ("checksum2" | uint32) - ).as[TimestampsAndChecksums] - - val extendedInfoCodec: Codec[ExtendedInfo] = list(timestampsAndChecksumsCodec).as[ExtendedInfo] - - val replyChannelRangeCodec: Codec[ReplyChannelRange] = ( - ("chainHash" | bytes32) :: - ("firstBlockNum" | uint32) :: - ("numberOfBlocks" | uint32) :: - ("complete" | byte) :: - ("shortChannelIds" | variableSizeBytes(uint16, encodedShortChannelIdsCodec)) :: - ("optionExtendedQueryFlags_opt" | optional(bitsRemaining, extendedQueryFlagsCodec)) :: - ("extendedInfo_opt" | optional(bitsRemaining, variableSizeBytes(uint16, extendedInfoCodec))) - ).as[ReplyChannelRange] + val queryChannelRangeCodec: Codec[QueryChannelRange] = { + Codec( + ("chainHash" | bytes32) :: + ("firstBlockNum" | uint32) :: + ("numberOfBlocks" | uint32) :: + ("tlvStream" | QueryChannelRangeTlv.codec) + ).as[QueryChannelRange] + } + + val replyChannelRangeCodec: Codec[ReplyChannelRange] = { + Codec( + ("chainHash" | bytes32) :: + ("firstBlockNum" | uint32) :: + ("numberOfBlocks" | uint32) :: + ("complete" | byte) :: + ("shortChannelIds" | variableSizeBytes(uint16, encodedShortChannelIdsCodec)) :: + ("tlvStream" | ReplyChannelRangeTlv.codec) + ).as[ReplyChannelRange] + } val gossipTimestampFilterCodec: Codec[GossipTimestampFilter] = ( ("chainHash" | bytes32) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala index b84e5e3972..a172c3d2e5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala @@ -20,8 +20,8 @@ import java.net.{Inet4Address, Inet6Address, InetAddress, InetSocketAddress} import java.nio.charset.StandardCharsets import com.google.common.base.Charsets -import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} import fr.acinq.eclair.{MilliSatoshi, ShortChannelId, UInt64} import scodec.bits.ByteVector @@ -233,61 +233,54 @@ object EncodingType { } // @formatter:on -case object QueryFlagTypes { - val INCLUDE_CHANNEL_ANNOUNCEMENT: Byte = 1 - val INCLUDE_CHANNEL_UPDATE_1: Byte = 2 - val INCLUDE_CHANNEL_UPDATE_2: Byte = 4 - val INCLUDE_ALL: Byte = (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2).toByte - - def includeAnnouncement(flag: Byte) = (flag & QueryFlagTypes.INCLUDE_CHANNEL_ANNOUNCEMENT) != 0 - - def includeUpdate1(flag: Byte) = (flag & QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_1) != 0 - - def includeUpdate2(flag: Byte) = (flag & QueryFlagTypes.INCLUDE_CHANNEL_UPDATE_2) != 0 -} case class EncodedShortChannelIds(encoding: EncodingType, array: List[ShortChannelId]) -case class EncodedQueryFlags(encoding: EncodingType, - array: List[Byte]) case class QueryShortChannelIds(chainHash: ByteVector32, shortChannelIds: EncodedShortChannelIds, - queryFlags_opt: Option[EncodedQueryFlags]) extends RoutingMessage with HasChainHash + tlvStream: TlvStream[QueryShortChannelIdsTlv] = TlvStream.empty) extends RoutingMessage with HasChainHash { + val queryFlags_opt: Option[QueryShortChannelIdsTlv.EncodedQueryFlags] = tlvStream.get[QueryShortChannelIdsTlv.EncodedQueryFlags] +} case class ReplyShortChannelIdsEnd(chainHash: ByteVector32, complete: Byte) extends RoutingMessage with HasChainHash -// @formatter:off -sealed trait ExtendedQueryFlags -object ExtendedQueryFlags { - case object TIMESTAMPS_AND_CHECKSUMS extends ExtendedQueryFlags -} -// @formatter:on case class QueryChannelRange(chainHash: ByteVector32, firstBlockNum: Long, numberOfBlocks: Long, - extendedQueryFlags_opt: Option[ExtendedQueryFlags]) extends RoutingMessage with HasChainHash + tlvStream: TlvStream[QueryChannelRangeTlv] = TlvStream.empty) extends RoutingMessage { + val queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags] = tlvStream.get[QueryChannelRangeTlv.QueryFlags] +} case class ReplyChannelRange(chainHash: ByteVector32, firstBlockNum: Long, numberOfBlocks: Long, complete: Byte, shortChannelIds: EncodedShortChannelIds, - extendedQueryFlags_opt: Option[ExtendedQueryFlags], - extendedInfo_opt: Option[ExtendedInfo]) extends RoutingMessage with HasChainHash { - extendedInfo_opt.foreach(extendedInfo => require(shortChannelIds.array.size == extendedInfo.array.size, s"shortChannelIds.size=${shortChannelIds.array.size} != extendedInfo.size=${extendedInfo.array.size}")) + tlvStream: TlvStream[ReplyChannelRangeTlv] = TlvStream.empty) extends RoutingMessage { + val timestamps_opt: Option[ReplyChannelRangeTlv.EncodedTimestamps] = tlvStream.get[ReplyChannelRangeTlv.EncodedTimestamps] + + val checksums_opt: Option[ReplyChannelRangeTlv.EncodedChecksums] = tlvStream.get[ReplyChannelRangeTlv.EncodedChecksums] } -case class GossipTimestampFilter(chainHash: ByteVector32, - firstTimestamp: Long, - timestampRange: Long) extends RoutingMessage with HasChainHash +object ReplyChannelRange { + def apply(chainHash: ByteVector32, + firstBlockNum: Long, + numberOfBlocks: Long, + complete: Byte, + shortChannelIds: EncodedShortChannelIds, + timestamps: Option[ReplyChannelRangeTlv.EncodedTimestamps], + checksums: Option[ReplyChannelRangeTlv.EncodedChecksums]) = { + timestamps.foreach(ts => require(ts.timestamps.length == shortChannelIds.array.length)) + checksums.foreach(cs => require(cs.checksums.length == shortChannelIds.array.length)) + new ReplyChannelRange(chainHash, firstBlockNum, numberOfBlocks, complete, shortChannelIds, TlvStream(timestamps.toList ::: checksums.toList)) + } +} -case class TimestampsAndChecksums(timestamp1: Long, - checksum1: Long, - timestamp2: Long, - checksum2: Long) -case class ExtendedInfo(array: List[TimestampsAndChecksums]) \ No newline at end of file +case class GossipTimestampFilter(chainHash: ByteVector32, + firstTimestamp: Long, + timestampRange: Long) extends RoutingMessage with HasChainHash \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryChannelRangeTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryChannelRangeTlv.scala new file mode 100644 index 0000000000..0dc5f57050 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryChannelRangeTlv.scala @@ -0,0 +1,37 @@ +package fr.acinq.eclair.wire + +import fr.acinq.eclair.UInt64 +import fr.acinq.eclair.wire.CommonCodecs.{shortchannelid, varint, varintoverflow} +import scodec.Codec +import scodec.codecs._ + +sealed trait QueryChannelRangeTlv extends Tlv + +object QueryChannelRangeTlv { + /** + * Optional query flag that is appended to QueryChannelRange + * @param flag bit 1 set means I want timestamps, bit 2 set means I want checksums + */ + case class QueryFlags(flag: Long) extends QueryChannelRangeTlv { + val wantTimestamps = QueryFlags.wantTimestamps(flag) + + val wantChecksums = QueryFlags.wantChecksums(flag) + } + + case object QueryFlags { + val WANT_TIMESTAMPS: Long = 1 + val WANT_CHECKSUMS: Long = 2 + val WANT_ALL: Long = (WANT_TIMESTAMPS | WANT_CHECKSUMS) + + def wantTimestamps(flag: Long) = (flag & WANT_TIMESTAMPS) != 0 + + def wantChecksums(flag: Long) = (flag & WANT_CHECKSUMS) != 0 + } + + val queryFlagsCodec: Codec[QueryFlags] = Codec(("flag" | varintoverflow)).as[QueryFlags] + + val codec: Codec[TlvStream[QueryChannelRangeTlv]] = TlvCodecs.tlvStream(discriminated.by(varint) + .typecase(UInt64(1), variableSizeBytesLong(varintoverflow, queryFlagsCodec)) + ) + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryShortChannelIdsTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryShortChannelIdsTlv.scala new file mode 100644 index 0000000000..3c878b2d7c --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryShortChannelIdsTlv.scala @@ -0,0 +1,41 @@ +package fr.acinq.eclair.wire + +import fr.acinq.eclair.UInt64 +import fr.acinq.eclair.wire.CommonCodecs.{shortchannelid, varint, varintoverflow} +import scodec.Codec +import scodec.codecs.{byte, discriminated, list, provide, variableSizeBytesLong, zlib} + +sealed trait QueryShortChannelIdsTlv extends Tlv + +object QueryShortChannelIdsTlv { + + /** + * Optional TLV-based query message that can be appended to QueryShortChannelIds + * @param encoding 0 means uncompressed, 1 means compressed with zlib + * @param array array of query flags, each flags specifies the info we want for a given channel + */ + case class EncodedQueryFlags(encoding: EncodingType, array: List[Long]) extends QueryShortChannelIdsTlv + + case object QueryFlagType { + val INCLUDE_CHANNEL_ANNOUNCEMENT: Long = 1 + val INCLUDE_CHANNEL_UPDATE_1: Long = 2 + val INCLUDE_CHANNEL_UPDATE_2: Long = 4 + val INCLUDE_ALL: Long = (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2) + + def includeAnnouncement(flag: Long) = (flag & INCLUDE_CHANNEL_ANNOUNCEMENT) != 0 + + def includeUpdate1(flag: Long) = (flag & INCLUDE_CHANNEL_UPDATE_1) != 0 + + def includeUpdate2(flag: Long) = (flag & INCLUDE_CHANNEL_UPDATE_2) != 0 + } + + val encodedQueryFlagsCodec: Codec[EncodedQueryFlags] = + discriminated[EncodedQueryFlags].by(byte) + .\(0) { case a@EncodedQueryFlags(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(varintoverflow)).as[EncodedQueryFlags]) + .\(1) { case a@EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(varintoverflow))).as[EncodedQueryFlags]) + + + val codec: Codec[TlvStream[QueryShortChannelIdsTlv]] = TlvCodecs.tlvStream(discriminated.by(varint) + .typecase(UInt64(1), variableSizeBytesLong(varintoverflow, encodedQueryFlagsCodec)) + ) +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/ReplyChannelRangeTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ReplyChannelRangeTlv.scala new file mode 100644 index 0000000000..bde4605551 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ReplyChannelRangeTlv.scala @@ -0,0 +1,64 @@ +package fr.acinq.eclair.wire + +import fr.acinq.eclair.{UInt64, wire} +import fr.acinq.eclair.wire.CommonCodecs.{varint, varintoverflow} +import scodec.Codec +import scodec.codecs._ + +sealed trait ReplyChannelRangeTlv extends Tlv + +object ReplyChannelRangeTlv { + + /** + * + * @param timestamp1 timestamp for node 1, or 0 + * @param timestamp2 timestamp for node 2, or 0 + */ + case class Timestamps(timestamp1: Long, timestamp2: Long) + + /** + * Optional timestamps TLV that can be appended to ReplyChannelRange + * + * @param encoding same convention as for short channel ids + * @param timestamps + */ + case class EncodedTimestamps(encoding: EncodingType, timestamps: List[Timestamps]) extends ReplyChannelRangeTlv + + /** + * + * @param checksum1 checksum for node 1, or 0 + * @param checksum2 checksum for node 2, or 0 + */ + case class Checksums(checksum1: Long, checksum2: Long) + + /** + * Optional checksums TLV that can be appended to ReplyChannelRange + * + * @param checksums + */ + case class EncodedChecksums(checksums: List[Checksums]) extends ReplyChannelRangeTlv + + val timestampsCodec: Codec[Timestamps] = ( + ("checksum1" | uint32) :: + ("checksum2" | uint32) + ).as[Timestamps] + + val encodedTimestampsCodec: Codec[EncodedTimestamps] = variableSizeBytesLong(varintoverflow, + discriminated[EncodedTimestamps].by(byte) + .\(0) { case a@EncodedTimestamps(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(timestampsCodec)).as[EncodedTimestamps]) + .\(1) { case a@EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(timestampsCodec))).as[EncodedTimestamps]) + ) + + val checksumsCodec: Codec[Checksums] = ( + ("checksum1" | uint32) :: + ("checksum2" | uint32) + ).as[Checksums] + + val encodedChecksumsCodec: Codec[EncodedChecksums] = variableSizeBytesLong(varintoverflow, list(checksumsCodec)).as[EncodedChecksums] + + val innerCodec = discriminated[ReplyChannelRangeTlv].by(varint) + .typecase(UInt64(1), encodedTimestampsCodec) + .typecase(UInt64(3), encodedChecksumsCodec) + + val codec: Codec[TlvStream[ReplyChannelRangeTlv]] = TlvCodecs.tlvStream(innerCodec) +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala index e4eea124f8..b87b1ebcb6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala @@ -19,6 +19,8 @@ package fr.acinq.eclair.wire import fr.acinq.eclair.UInt64 import scodec.bits.ByteVector +import scala.reflect.ClassTag + /** * Created by t-bast on 20/06/2019. */ @@ -45,9 +47,18 @@ case class GenericTlv(tag: UInt64, value: ByteVector) extends Tlv * @param unknown unknown tlv records. * @tparam T the stream namespace is a trait extending the top-level tlv trait. */ -case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil) +case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil) { + /** + * + * @tparam R input type parameter, must be a subtype of the main TLV type + * @return the TLV record of of type that matches the input type parameter if any (there can be at most one, since BOLTs specify + * that TLV records are supposed to be unique + */ + def get[R <: T : ClassTag]: Option[R] = records.collectFirst { case r: R => r } +} object TlvStream { + def empty[T <: Tlv] = TlvStream[T](Nil, Nil) def apply[T <: Tlv](records: T*): TlvStream[T] = TlvStream(records, Nil) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala index 074169d83c..e9f0414147 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala @@ -26,14 +26,13 @@ import fr.acinq.bitcoin.{Satoshi} import fr.acinq.eclair.TestConstants._ import fr.acinq.eclair._ import fr.acinq.eclair.blockchain.{EclairWallet, TestWallet} -import fr.acinq.eclair.channel.{ChannelCreated, HasCommitments} import fr.acinq.eclair.channel.states.StateTestsHelperMethods +import fr.acinq.eclair.channel.{ChannelCreated, HasCommitments} import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer._ -import fr.acinq.eclair.router.{Rebroadcast, RoutingSyncSpec} import fr.acinq.eclair.router.RoutingSyncSpec.makeFakeRoutingInfo -import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, NodeAddress, NodeAnnouncement, Ping, Pong} -import fr.acinq.eclair.{ShortChannelId, TestkitBaseClass, randomBytes, wire, _} +import fr.acinq.eclair.router.{Rebroadcast, RoutingSyncSpec} +import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, Tlv, TlvStream} import org.scalatest.{Outcome, Tag} import scodec.bits.ByteVector @@ -337,7 +336,10 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { val probe = TestProbe() connect(remoteNodeId, authenticator, watcher, router, relayer, connection, transport, peer) - val query = wire.QueryShortChannelIds(Alice.nodeParams.chainHash, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42000))), queryFlags_opt = None) + val query = QueryShortChannelIds( + Alice.nodeParams.chainHash, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42000))), + TlvStream.empty) // make sure that routing messages go through for (ann <- channels ++ updates) { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala index 9962ba50dd..065ee73ac3 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala @@ -16,8 +16,9 @@ package fr.acinq.eclair.router -import fr.acinq.eclair.{MilliSatoshi, randomKey} +import fr.acinq.eclair.wire.ReplyChannelRangeTlv._ import fr.acinq.eclair.wire._ +import fr.acinq.eclair.{MilliSatoshi, randomKey} import org.scalatest.FunSuite import scala.collection.immutable.SortedMap @@ -33,14 +34,14 @@ class ChannelRangeQueriesSpec extends FunSuite { val a = randomKey.publicKey val b = randomKey.publicKey val ab = RouteCalculationSpec.makeChannel(123466L, a, b) - val (ab1, uab1) = RouteCalculationSpec.makeUpdateShort(ab.shortChannelId, ab.nodeId1, ab.nodeId2, MilliSatoshi(0), 0, timestamp = now) - val (ab2, uab2) = RouteCalculationSpec.makeUpdateShort(ab.shortChannelId, ab.nodeId2, ab.nodeId1, MilliSatoshi(0), 0, timestamp = now) + val (ab1, uab1) = RouteCalculationSpec.makeUpdateShort(ab.shortChannelId, ab.nodeId1, ab.nodeId2, MilliSatoshi(0), 0, timestamp = now) + val (ab2, uab2) = RouteCalculationSpec.makeUpdateShort(ab.shortChannelId, ab.nodeId2, ab.nodeId1, MilliSatoshi(0), 0, timestamp = now) val c = randomKey.publicKey val d = randomKey.publicKey val cd = RouteCalculationSpec.makeChannel(451312L, c, d) - val (cd1, ucd1) = RouteCalculationSpec.makeUpdateShort(cd.shortChannelId, cd.nodeId1, cd.nodeId2, MilliSatoshi(0), 0, timestamp = now) - val (_, ucd2) = RouteCalculationSpec.makeUpdateShort(cd.shortChannelId, cd.nodeId2, cd.nodeId1, MilliSatoshi(0), 0, timestamp = now) + val (cd1, ucd1) = RouteCalculationSpec.makeUpdateShort(cd.shortChannelId, cd.nodeId1, cd.nodeId2, MilliSatoshi(0), 0, timestamp = now) + val (_, ucd2) = RouteCalculationSpec.makeUpdateShort(cd.shortChannelId, cd.nodeId2, cd.nodeId1, MilliSatoshi(0), 0, timestamp = now) val e = randomKey.publicKey val f = randomKey.publicKey @@ -57,28 +58,28 @@ class ChannelRangeQueriesSpec extends FunSuite { cd1 -> ucd1 ) - import fr.acinq.eclair.wire.QueryFlagTypes._ + import fr.acinq.eclair.wire.QueryShortChannelIdsTlv.QueryFlagType._ - assert(Router.getChannelDigestInfo(channels, updates)(ab.shortChannelId) == TimestampsAndChecksums(now, 714408668, now, 714408668)) + assert(Router.getChannelDigestInfo(channels, updates)(ab.shortChannelId) == (Timestamps(now, now), Checksums(3297511804L, 3297511804L))) // no extended info but we know the channel: we ask for the updates - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, None) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2).toByte) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, None, None) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2).toByte) // same checksums, newer timestamps: we don't ask anything - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(TimestampsAndChecksums(now + 1, 714408668, now + 1, 714408668))) === 0.toByte) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(3297511804L, 3297511804L))) === 0.toByte) // different checksums, newer timestamps: we ask for the updates - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(TimestampsAndChecksums(now + 1, 154654604, now, 714408668))) === INCLUDE_CHANNEL_UPDATE_1) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(TimestampsAndChecksums(now, 714408668, now + 1, 45664546))) === INCLUDE_CHANNEL_UPDATE_2) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(TimestampsAndChecksums(now + 1, 154654604, now + 1, 45664546+6))) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2).toByte) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now)), Some(Checksums(154654604, 3297511804L))) === INCLUDE_CHANNEL_UPDATE_1) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now, now + 1)), Some(Checksums(3297511804L, 45664546))) === INCLUDE_CHANNEL_UPDATE_2) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(154654604, 45664546+6))) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2).toByte) // different checksums, older timestamps: we don't ask anything - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(TimestampsAndChecksums(now - 1, 154654604, now, 714408668))) === 0.toByte) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(TimestampsAndChecksums(now, 714408668, now - 1, 45664546))) === 0.toByte) - assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(TimestampsAndChecksums(now - 1, 154654604, now - 1, 45664546))) === 0.toByte) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now - 1, now)), Some(Checksums(154654604, 3297511804L))) === 0.toByte) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now, now - 1)), Some(Checksums(3297511804L, 45664546))) === 0.toByte) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now - 1, now - 1)), Some(Checksums(154654604, 45664546))) === 0.toByte) // missing channel update: we ask for it - assert(Router.computeFlag(channels, updates)(cd.shortChannelId, Some(TimestampsAndChecksums(now, 714408668, now, 714408668))) === INCLUDE_CHANNEL_UPDATE_2) + assert(Router.computeFlag(channels, updates)(cd.shortChannelId, Some(Timestamps(now, now)), Some(Checksums(3297511804L, 3297511804L))) === INCLUDE_CHANNEL_UPDATE_2) // unknown channel: we ask everything - assert(Router.computeFlag(channels, updates)(ef.shortChannelId, None) === QueryFlagTypes.INCLUDE_ALL) + assert(Router.computeFlag(channels, updates)(ef.shortChannelId, None, None) === INCLUDE_ALL) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala index 22e19db62c..95f451ee45 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala @@ -67,7 +67,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { def counts = BasicSyncResult(ranges.size, queries.size, channels.size, updates.size) } - def sync(src: TestFSMRef[State, Data, Router], tgt: TestFSMRef[State, Data, Router], extendedQueryFlags_opt: Option[ExtendedQueryFlags]): SyncResult = { + def sync(src: TestFSMRef[State, Data, Router], tgt: TestFSMRef[State, Data, Router], extendedQueryFlags_opt: Option[QueryChannelRangeTlv]): SyncResult = { val sender = TestProbe() val pipe = TestProbe() pipe.ignoreMsg { @@ -161,7 +161,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { val bob = TestFSMRef(new Router(Bob.nodeParams, watcher)) val charlieId = randomKey.publicKey val sender = TestProbe() - val extendedQueryFlags_opt = Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS) + val extendedQueryFlags_opt = Some(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL)) // tell alice to sync with bob assert(BasicSyncResult(ranges = 1, queries = 0, channels = 0, updates = 0) === sync(alice, bob, extendedQueryFlags_opt).counts) @@ -227,7 +227,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, block1)) // router should ask for our first block of ids - assert(transport.expectMsgType[QueryShortChannelIds] === QueryShortChannelIds(chainHash, block1.shortChannelIds, None)) + assert(transport.expectMsgType[QueryShortChannelIds] === QueryShortChannelIds(chainHash, block1.shortChannelIds, TlvStream.empty)) // router should think that it is missing 100 channels, in one request val Some(sync) = router.stateData.sync.get(remoteNodeId) assert(sync.total == 1) @@ -241,7 +241,7 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { test("sync progress") { - def req = QueryShortChannelIds(Block.RegtestGenesisBlock.hash, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42))), None) + def req = QueryShortChannelIds(Block.RegtestGenesisBlock.hash, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42))), TlvStream.empty) val nodeidA = randomKey.publicKey val nodeidB = randomKey.publicKey diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala new file mode 100644 index 0000000000..6d74284fc1 --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala @@ -0,0 +1,94 @@ +package fr.acinq.eclair.wire + +import fr.acinq.bitcoin.Block +import fr.acinq.eclair.{ShortChannelId, UInt64} +import fr.acinq.eclair.wire.LightningMessageCodecs._ +import ReplyChannelRangeTlv._ +import org.scalatest.FunSuite +import scodec.bits.ByteVector + +class ExtendedQueriesCodecsSpec extends FunSuite { + test("encode query_short_channel_ids (no optional data)") { + val query_short_channel_id = QueryShortChannelIds( + Block.RegtestGenesisBlock.blockId, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream.empty) + + val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require + val decoded = queryShortChannelIdsCodec.decode(encoded).require.value + assert(decoded === query_short_channel_id) + } + + test("encode query_short_channel_ids (with optional data)") { + val query_short_channel_id = QueryShortChannelIds( + Block.RegtestGenesisBlock.blockId, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)))) + + val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require + val decoded = queryShortChannelIdsCodec.decode(encoded).require.value + assert(decoded === query_short_channel_id) + } + + test("encode query_short_channel_ids (with optional data including unknown data)") { + val query_short_channel_id = QueryShortChannelIds( + Block.RegtestGenesisBlock.blockId, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream( + QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)) :: Nil, + GenericTlv(UInt64(43), ByteVector.fromValidHex("deadbeef")) :: Nil + ) + ) + + val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require + val decoded = queryShortChannelIdsCodec.decode(encoded).require.value + assert(decoded === query_short_channel_id) + } + + test("encode reply_channel_range (no optional data)") { + val replyChannelRange = ReplyChannelRange( + Block.RegtestGenesisBlock.blockId, + 1, 100, + 1.toByte, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + None, None) + + val encoded = replyChannelRangeCodec.encode(replyChannelRange).require + val decoded = replyChannelRangeCodec.decode(encoded).require.value + assert(decoded === replyChannelRange) + } + + test("encode reply_channel_range (with optional timestamps)") { + val replyChannelRange = ReplyChannelRange( + Block.RegtestGenesisBlock.blockId, + 1, 100, + 1.toByte, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + Some(EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(1, 1), Timestamps(2, 2), Timestamps(3, 3)))), + None) + + val encoded = replyChannelRangeCodec.encode(replyChannelRange).require + val decoded = replyChannelRangeCodec.decode(encoded).require.value + assert(decoded === replyChannelRange) + } + + test("encode reply_channel_range (with optional timestamps, checksums, and unknown data)") { + val replyChannelRange = ReplyChannelRange( + Block.RegtestGenesisBlock.blockId, + 1, 100, + 1.toByte, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream( + List( + EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(1, 1), Timestamps(2, 2), Timestamps(3, 3))), + EncodedChecksums(List(Checksums(1, 1), Checksums(2, 2), Checksums(3, 3))) + ), + GenericTlv(UInt64(7), ByteVector.fromValidHex("deadbeef")) :: Nil + ) + ) + + val encoded = replyChannelRangeCodec.encode(replyChannelRange).require + val decoded = replyChannelRangeCodec.decode(encoded).require.value + assert(decoded === replyChannelRange) + } +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala index 3649943675..a8d7a05790 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala @@ -21,14 +21,9 @@ import java.net.{Inet4Address, InetAddress} import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, Satoshi} import fr.acinq.eclair._ -import fr.acinq.eclair.api._ -import fr.acinq.eclair.channel.State -import fr.acinq.eclair.crypto.Sphinx import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire.LightningMessageCodecs._ -import org.json4s.JsonAST.{JNothing, JString} -import org.json4s.{CustomSerializer, ShortTypeHints} -import org.json4s.jackson.Serialization +import ReplyChannelRangeTlv._ import org.scalatest.FunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -77,9 +72,18 @@ class LightningMessageCodecsSpec extends FunSuite { val channel_update = ChannelUpdate(randomBytes64, Block.RegtestGenesisBlock.hash, ShortChannelId(1), 2, 42, 0, 3, MilliSatoshi(4), MilliSatoshi(5), 6, None) val announcement_signatures = AnnouncementSignatures(randomBytes32, ShortChannelId(42), randomBytes64, randomBytes64) val gossip_timestamp_filter = GossipTimestampFilter(Block.RegtestGenesisBlock.blockId, 100000, 1500) - val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None) - val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS)) - val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, 1, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS), Some(ExtendedInfo(List(TimestampsAndChecksums(1, 1, 1, 1), TimestampsAndChecksums(2, 2, 2, 2), TimestampsAndChecksums(3, 3, 3, 3))))) + val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) + val unknownTlv = GenericTlv(UInt64(5), ByteVector.fromValidHex("deadbeef")) + val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, + 100000, + 1500, + TlvStream(QueryChannelRangeTlv.QueryFlags((QueryChannelRangeTlv.QueryFlags.WANT_ALL)) :: Nil, unknownTlv :: Nil)) + val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, 1, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream( + EncodedTimestamps(EncodingType.UNCOMPRESSED, List(Timestamps(1, 1), Timestamps(2, 2), Timestamps(3, 3))) :: EncodedChecksums(List(Checksums(1, 1), Checksums(2, 2), Checksums(3, 3))) :: Nil, + unknownTlv :: Nil) + ) val ping = Ping(100, bin(10, 1)) val pong = Pong(bin(10, 1)) val channel_reestablish = ChannelReestablish(randomBytes32, 242842L, 42L) @@ -100,11 +104,16 @@ class LightningMessageCodecsSpec extends FunSuite { test("non-reg encoding type") { val refs = Map( - hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4" -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None), - hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3" -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None), - hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4000400010204" -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), Some(EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1, 2, 4)))), - hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3000c01789c6364620100000e0008" -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), Some(EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4" + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3" + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4010400010204" + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1, 2, 4)))), + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3010c01789c6364620100000e0008" + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) ) + refs.forall { case (bin, obj) => lightningMessageCodec.decode(bin.toBitVector).require.value == obj && lightningMessageCodec.encode(obj).require == bin.toBitVector @@ -113,18 +122,33 @@ class LightningMessageCodecsSpec extends FunSuite { case class TestItem(msg: Any, hex: String) - ignore("test vectors") { - - val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, None) - val query_channel_range_timestamps_checksums = QueryChannelRange(Block.RegtestGenesisBlock.blockId, 35000, 100, Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS)) - val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 756230, 1500, 1, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS), None) - val reply_channel_range_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 1600, 110, 1, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(265462))), Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS), None) - val reply_channel_range_timestamps_checksums = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 122334, 1500, 1, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS), Some(ExtendedInfo(List(TimestampsAndChecksums(164545, 1111, 948165, 2222), TimestampsAndChecksums(489645, 3333, 4786864, 4444), TimestampsAndChecksums(46456, 5555, 9788415, 6666))))) - val reply_channel_range_timestamps_checksums_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 500, 100, 1, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(1234545), ShortChannelId(4897484), ShortChannelId(4564676))), Some(ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS), Some(ExtendedInfo(List(TimestampsAndChecksums(164545, 1111, 948165, 2222), TimestampsAndChecksums(489645, 3333, 4786864, 4444), TimestampsAndChecksums(46456, 5555, 9788415, 6666))))) - val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None) - val query_short_channel_id_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(4564), ShortChannelId(178622), ShortChannelId(4564676))), None) - val query_short_channel_id_flags = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12232), ShortChannelId(15556), ShortChannelId(4564676))), Some(EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) - val query_short_channel_id_flags_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(14200), ShortChannelId(46645), ShortChannelId(4564676))), Some(EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + test("test vectors") { + import org.json4s.{CustomSerializer, ShortTypeHints} + import org.json4s.JsonAST.JString + import org.json4s.jackson.Serialization + import fr.acinq.eclair.api._ + + val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, TlvStream.empty) + val query_channel_range_timestamps_checksums = QueryChannelRange(Block.RegtestGenesisBlock.blockId, + 35000, + 100, + TlvStream(QueryChannelRangeTlv.QueryFlags((QueryChannelRangeTlv.QueryFlags.WANT_ALL)))) + val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 756230, 1500, 1, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None, None) + val reply_channel_range_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 1600, 110, 1, + EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(265462))), None, None) + val reply_channel_range_timestamps_checksums = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 122334, 1500, 1, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), + Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, List(Timestamps(164545, 948165), Timestamps(489645, 4786864), Timestamps(46456, 9788415)))), + Some(EncodedChecksums(List(Checksums(1111, 2222), Checksums(3333, 4444), Checksums(5555, 6666))))) + val reply_channel_range_timestamps_checksums_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 122334, 1500, 1, + EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), + Some(EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(164545, 948165), Timestamps(489645, 4786864), Timestamps(46456, 9788415)))), + Some(EncodedChecksums(List(Checksums(1111, 2222), Checksums(3333, 4444), Checksums(5555, 6666))))) + val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) + val query_short_channel_id_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(4564), ShortChannelId(178622), ShortChannelId(4564676))), TlvStream.empty) + val query_short_channel_id_flags = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12232), ShortChannelId(15556), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + val query_short_channel_id_flags_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(14200), ShortChannelId(46645), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) val refs = List( query_channel_range, @@ -139,13 +163,18 @@ class LightningMessageCodecsSpec extends FunSuite { query_short_channel_id_flags_zlib ) - class EncodingTypeSerializer extends CustomSerializer[EncodingType](format => ({ null }, { + class EncodingTypeSerializer extends CustomSerializer[EncodingType](format => ( { + null + }, { case EncodingType.UNCOMPRESSED => JString("UNCOMPRESSED") case EncodingType.COMPRESSED_ZLIB => JString("COMPRESSED_ZLIB") })) - class ExtendedQueryFlagsSerializer extends CustomSerializer[ExtendedQueryFlags](format => ({ null }, { - case ExtendedQueryFlags.TIMESTAMPS_AND_CHECKSUMS => JString("TIMESTAMPS_AND_CHECKSUMS") + class ExtendedQueryFlagsSerializer extends CustomSerializer[QueryChannelRangeTlv.QueryFlags](format => ( { + null + }, { + case QueryChannelRangeTlv.QueryFlags(flag) => + JString(((if (QueryChannelRangeTlv.QueryFlags.wantTimestamps(flag)) List("WANT_TIMESTAMPS") else List()) ::: (if (QueryChannelRangeTlv.QueryFlags.wantChecksums(flag)) List("WANT_CHECKSUMS") else List())) mkString (" | ")) })) implicit val formats = org.json4s.DefaultFormats.withTypeHintFieldName("type") + @@ -174,16 +203,16 @@ class LightningMessageCodecsSpec extends FunSuite { new DirectionSerializer + new PaymentRequestSerializer + ShortTypeHints(List( - classOf[QueryChannelRange], - classOf[ReplyChannelRange], - classOf[QueryShortChannelIds])) - - refs.foreach { - obj => - val bin = lightningMessageCodec.encode(obj).require - println(Serialization.writePretty(TestItem(obj, bin.toHex))) - } + classOf[QueryChannelRange], + classOf[ReplyChannelRange], + classOf[QueryShortChannelIds])) + val items = refs.map { obj => + val bin = lightningMessageCodec.encode(obj).require + TestItem(obj, bin.toHex) + } + val json = Serialization.writePretty(items) + println(json) } test("decode channel_update with htlc_maximum_msat") { @@ -196,4 +225,5 @@ class LightningMessageCodecsSpec extends FunSuite { val bin2 = ByteVector(lightningMessageCodec.encode(update).require.toByteArray) assert(bin === bin2) } + } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index e2c5dbce1b..c21033085e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -294,6 +294,12 @@ class TlvCodecsSpec extends FunSuite { } } + test("get optional TLV field") { + val stream = TlvStream[TestTlv](Seq(TestType254(42), TestType1(42)), Seq(GenericTlv(13, hex"2a"), GenericTlv(11, hex"2b"))) + assert(stream.get[TestType254] == Some(TestType254(42))) + assert(stream.get[TestType1] == Some(TestType1(42))) + assert(stream.get[TestType2] == None) + } } object TlvCodecsSpec {