From 4dc2910c4e865d100de3ddb0d28616c6b201f6ed Mon Sep 17 00:00:00 2001 From: Pierre-Marie Padiou Date: Tue, 25 May 2021 19:03:17 +0200 Subject: [PATCH] Make result set an iterable (#1823) This allows us to use the full power of scala collections, to iterate over results, convert to options, etc. while staying purely functional and immutable. There is a catch though: the iterator is lazy, it must be materialized before the result set is closed, by converting the end result in a collection or an option. In other words, database methods must never return an `Iterable` or `Iterator`. --- .../fr/acinq/eclair/db/DbEventHandler.scala | 2 +- .../scala/fr/acinq/eclair/db/FeeratesDb.scala | 4 +- .../acinq/eclair/db/FileBackupHandler.scala | 5 +- .../scala/fr/acinq/eclair/db/NetworkDb.scala | 3 +- .../scala/fr/acinq/eclair/db/PaymentsDb.scala | 6 +- .../scala/fr/acinq/eclair/db/PeersDb.scala | 4 +- .../fr/acinq/eclair/db/jdbc/JdbcUtils.scala | 42 +++-- .../fr/acinq/eclair/db/pg/PgAuditDb.scala | 162 ++++++++-------- .../fr/acinq/eclair/db/pg/PgChannelsDb.scala | 15 +- .../fr/acinq/eclair/db/pg/PgNetworkDb.scala | 32 ++-- .../fr/acinq/eclair/db/pg/PgPaymentsDb.scala | 111 ++++------- .../fr/acinq/eclair/db/pg/PgPeersDb.scala | 20 +- .../eclair/db/pg/PgPendingCommandsDb.scala | 16 +- .../scala/fr/acinq/eclair/db/pg/PgUtils.scala | 17 +- .../eclair/db/sqlite/SqliteAuditDb.scala | 176 +++++++++--------- .../eclair/db/sqlite/SqliteChannelsDb.scala | 14 +- .../eclair/db/sqlite/SqliteFeeratesDb.scala | 35 ++-- .../eclair/db/sqlite/SqliteNetworkDb.scala | 32 ++-- .../eclair/db/sqlite/SqlitePaymentsDb.scala | 115 ++++-------- .../eclair/db/sqlite/SqlitePeersDb.scala | 22 +-- .../db/sqlite/SqlitePendingCommandsDb.scala | 14 +- .../acinq/eclair/db/sqlite/SqliteUtils.scala | 4 +- .../eclair/io/ReconnectionTaskSpec.scala | 2 +- 23 files changed, 363 insertions(+), 490 deletions(-) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala index 38a8f6eaab..e89d3db24f 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/DbEventHandler.scala @@ -20,7 +20,7 @@ import akka.actor.{Actor, ActorLogging, Props} import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.NodeParams -import fr.acinq.eclair.channel.Helpers.Closing.{ClosingType, CurrentRemoteClose, LocalClose, MutualClose, NextRemoteClose, RecoveryClose, RevokedClose} +import fr.acinq.eclair.channel.Helpers.Closing._ import fr.acinq.eclair.channel.Monitoring.{Metrics => ChannelMetrics, Tags => ChannelTags} import fr.acinq.eclair.channel._ import fr.acinq.eclair.db.DbEventHandler.ChannelEvent diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/FeeratesDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/FeeratesDb.scala index 1bbf791987..f8343eaef7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/FeeratesDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/FeeratesDb.scala @@ -16,10 +16,10 @@ package fr.acinq.eclair.db -import java.io.Closeable - import fr.acinq.eclair.blockchain.fee.FeeratesPerKB +import java.io.Closeable + /** * This database stores the fee rates retrieved by a [[fr.acinq.eclair.blockchain.fee.FeeProvider]]. */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/FileBackupHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/FileBackupHandler.scala index 9b8db9b7bb..900e634a12 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/FileBackupHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/FileBackupHandler.scala @@ -16,9 +16,6 @@ package fr.acinq.eclair.db -import java.io.File -import java.nio.file.{Files, StandardCopyOption} - import akka.actor.{Actor, ActorLogging, Props} import akka.dispatch.{BoundedMessageQueueSemantics, RequiresMessageQueue} import fr.acinq.eclair.KamonExt @@ -26,6 +23,8 @@ import fr.acinq.eclair.channel.ChannelPersisted import fr.acinq.eclair.db.Databases.FileBackup import fr.acinq.eclair.db.Monitoring.Metrics +import java.io.File +import java.nio.file.{Files, StandardCopyOption} import scala.sys.process.Process import scala.util.{Failure, Success, Try} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala index 966512d59c..b50815a024 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/NetworkDb.scala @@ -16,14 +16,13 @@ package fr.acinq.eclair.db -import java.io.Closeable - import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, Satoshi} import fr.acinq.eclair.ShortChannelId import fr.acinq.eclair.router.Router.PublicChannel import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement} +import java.io.Closeable import scala.collection.immutable.SortedMap trait NetworkDb extends Closeable { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala index d4dd10f3e9..d6c35f4dfa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PaymentsDb.scala @@ -16,15 +16,15 @@ package fr.acinq.eclair.db -import java.io.Closeable -import java.util.UUID - import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.payment._ import fr.acinq.eclair.router.Router.{ChannelHop, Hop, NodeHop} import fr.acinq.eclair.{MilliSatoshi, ShortChannelId} +import java.io.Closeable +import java.util.UUID + trait PaymentsDb extends IncomingPaymentsDb with OutgoingPaymentsDb with PaymentsOverviewDb with Closeable trait IncomingPaymentsDb { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala index dcf1b7117b..4d71e0b3d6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PeersDb.scala @@ -16,11 +16,11 @@ package fr.acinq.eclair.db -import java.io.Closeable - import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.wire.protocol.NodeAddress +import java.io.Closeable + trait PeersDb extends Closeable { def addOrUpdatePeer(nodeId: PublicKey, address: NodeAddress): Unit diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala index a405c96efe..a6c5afc1a7 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/jdbc/JdbcUtils.scala @@ -19,16 +19,17 @@ package fr.acinq.eclair.db.jdbc import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.MilliSatoshi import org.sqlite.SQLiteConnection -import scodec.Codec +import scodec.Decoder import scodec.bits.{BitVector, ByteVector} import java.sql.{Connection, ResultSet, Statement, Timestamp} import java.util.UUID import javax.sql.DataSource -import scala.collection.immutable.Queue trait JdbcUtils { + import ExtendedResultSet._ + def withConnection[T](f: Connection => T)(implicit dataSource: DataSource): T = { val connection = dataSource.getConnection() try { @@ -72,15 +73,16 @@ trait JdbcUtils { def getVersion(statement: Statement, db_name: String): Option[Int] = { createVersionTable(statement) // if there was a previous version installed, this will return a different value from current version - val rs = statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'") - if (rs.next()) Some(rs.getInt("version")) else None + statement.executeQuery(s"SELECT version FROM versions WHERE db_name='$db_name'") + .map(rs => rs.getInt("version")) + .headOption } /** * Updates the version for a particular logical database, it will overwrite the previous version. * * NB: we could define this method in [[fr.acinq.eclair.db.sqlite.SqliteUtils]] and [[fr.acinq.eclair.db.pg.PgUtils]] - * but it would make testing more complicated because we need to use one or the other depending on the backend. + * but it would make testing more complicated because we need to use one or the other depending on the backend. */ def setVersion(statement: Statement, db_name: String, newVersion: Int): Unit = { createVersionTable(statement) @@ -96,20 +98,25 @@ trait JdbcUtils { } } - /** - * This helper assumes that there is a "data" column available, decodable with the provided codec - * - * TODO: we should use an scala.Iterator instead - */ - def codecSequence[T](rs: ResultSet, codec: Codec[T]): Seq[T] = { - var q: Queue[T] = Queue() - while (rs.next()) { - q = q :+ codec.decode(BitVector(rs.getBytes("data"))).require.value + case class ExtendedResultSet(rs: ResultSet) extends Iterable[ResultSet] { + + /** + * Iterates over all rows of a result set. + * + * Careful: the iterator is lazy, it must be materialized before the [[ResultSet]] is closed, by converting the end + * result in a collection or an option. + */ + override def iterator: Iterator[ResultSet] = { + // @formatter:off + new Iterator[ResultSet] { + def hasNext: Boolean = rs.next() + def next(): ResultSet = rs + } + // @formatter:on } - q - } - case class ExtendedResultSet(rs: ResultSet) { + /** This helper assumes that there is a "data" column available, that can be decoded with the provided codec */ + def mapCodec[T](codec: Decoder[T]): Iterable[T] = rs.map(rs => codec.decode(BitVector(rs.getBytes("data"))).require.value) def getByteVectorFromHex(columnLabel: String): ByteVector = { val s = rs.getString(columnLabel).stripPrefix("\\x") @@ -166,7 +173,6 @@ trait JdbcUtils { val result = rs.getTimestamp(label) if (rs.wasNull()) None else Some(result) } - } object ExtendedResultSet { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala index 2e44c80ccc..44852abdb8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgAuditDb.scala @@ -33,7 +33,6 @@ import java.sql.{Statement, Timestamp} import java.time.Instant import java.util.UUID import javax.sql.DataSource -import scala.collection.immutable.Queue class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { @@ -215,30 +214,28 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { using(pg.prepareStatement("SELECT * FROM sent WHERE timestamp BETWEEN ? AND ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) - val rs = statement.executeQuery() - var sentByParentId = Map.empty[UUID, PaymentSent] - while (rs.next()) { - val parentId = UUID.fromString(rs.getString("parent_payment_id")) - val part = PaymentSent.PartialPayment( - UUID.fromString(rs.getString("payment_id")), - MilliSatoshi(rs.getLong("amount_msat")), - MilliSatoshi(rs.getLong("fees_msat")), - rs.getByteVector32FromHex("to_channel_id"), - None, // we don't store the route in the audit DB - rs.getTimestamp("timestamp").getTime) - val sent = sentByParentId.get(parentId) match { - case Some(s) => s.copy(parts = s.parts :+ part) - case None => PaymentSent( - parentId, - rs.getByteVector32FromHex("payment_hash"), - rs.getByteVector32FromHex("payment_preimage"), - MilliSatoshi(rs.getLong("recipient_amount_msat")), - PublicKey(rs.getByteVectorFromHex("recipient_node_id")), - Seq(part)) - } - sentByParentId = sentByParentId + (parentId -> sent) - } - sentByParentId.values.toSeq.sortBy(_.timestamp) + statement.executeQuery() + .foldLeft(Map.empty[UUID, PaymentSent]) { (sentByParentId, rs) => + val parentId = UUID.fromString(rs.getString("parent_payment_id")) + val part = PaymentSent.PartialPayment( + UUID.fromString(rs.getString("payment_id")), + MilliSatoshi(rs.getLong("amount_msat")), + MilliSatoshi(rs.getLong("fees_msat")), + rs.getByteVector32FromHex("to_channel_id"), + None, // we don't store the route in the audit DB + rs.getTimestamp("timestamp").getTime) + val sent = sentByParentId.get(parentId) match { + case Some(s) => s.copy(parts = s.parts :+ part) + case None => PaymentSent( + parentId, + rs.getByteVector32FromHex("payment_hash"), + rs.getByteVector32FromHex("payment_preimage"), + MilliSatoshi(rs.getLong("recipient_amount_msat")), + PublicKey(rs.getByteVectorFromHex("recipient_node_id")), + Seq(part)) + } + sentByParentId + (parentId -> sent) + }.values.toSeq.sortBy(_.timestamp) } } @@ -247,70 +244,66 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { using(pg.prepareStatement("SELECT * FROM received WHERE timestamp BETWEEN ? AND ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) - val rs = statement.executeQuery() - var receivedByHash = Map.empty[ByteVector32, PaymentReceived] - while (rs.next()) { - val paymentHash = rs.getByteVector32FromHex("payment_hash") - val part = PaymentReceived.PartialPayment( - MilliSatoshi(rs.getLong("amount_msat")), - rs.getByteVector32FromHex("from_channel_id"), - rs.getTimestamp("timestamp").getTime) - val received = receivedByHash.get(paymentHash) match { - case Some(r) => r.copy(parts = r.parts :+ part) - case None => PaymentReceived(paymentHash, Seq(part)) - } - receivedByHash = receivedByHash + (paymentHash -> received) - } - receivedByHash.values.toSeq.sortBy(_.timestamp) + statement.executeQuery() + .foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) => + val paymentHash = rs.getByteVector32FromHex("payment_hash") + val part = PaymentReceived.PartialPayment( + MilliSatoshi(rs.getLong("amount_msat")), + rs.getByteVector32FromHex("from_channel_id"), + rs.getTimestamp("timestamp").getTime) + val received = receivedByHash.get(paymentHash) match { + case Some(r) => r.copy(parts = r.parts :+ part) + case None => PaymentReceived(paymentHash, Seq(part)) + } + receivedByHash + (paymentHash -> received) + }.values.toSeq.sortBy(_.timestamp) } } override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = inTransaction { pg => - var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)] - using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => + val trampolineByHash = using(pg.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp BETWEEN ? and ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) - val rs = statement.executeQuery() - while (rs.next()) { - val paymentHash = rs.getByteVector32FromHex("payment_hash") - val amount = MilliSatoshi(rs.getLong("amount_msat")) - val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id")) - trampolineByHash += (paymentHash -> (amount, nodeId)) - } + statement.executeQuery() + .foldLeft(Map.empty[ByteVector32, (MilliSatoshi, PublicKey)]) { (trampolineByHash, rs) => + val paymentHash = rs.getByteVector32FromHex("payment_hash") + val amount = MilliSatoshi(rs.getLong("amount_msat")) + val nodeId = PublicKey(rs.getByteVectorFromHex("next_node_id")) + trampolineByHash + (paymentHash -> (amount, nodeId)) + } } - using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement => + val relayedByHash = using(pg.prepareStatement("SELECT * FROM relayed WHERE timestamp BETWEEN ? and ?")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) - val rs = statement.executeQuery() - var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]] - while (rs.next()) { - val paymentHash = rs.getByteVector32FromHex("payment_hash") - val part = RelayedPart( - rs.getByteVector32FromHex("channel_id"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getString("direction"), - rs.getString("relay_type"), - rs.getTimestamp("timestamp").getTime) - relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part)) - } - relayedByHash.flatMap { - case (paymentHash, parts) => - // We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel). - // NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch. - val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) - val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) - parts.headOption match { - case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map { - case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp) - } - case Some(RelayedPart(_, _, _, "trampoline", timestamp)) => - val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey)) - TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil - case _ => Nil - } - }.toSeq.sortBy(_.timestamp) + statement.executeQuery() + .foldLeft(Map.empty[ByteVector32, Seq[RelayedPart]]) { (relayedByHash, rs) => + val paymentHash = rs.getByteVector32FromHex("payment_hash") + val part = RelayedPart( + rs.getByteVector32FromHex("channel_id"), + MilliSatoshi(rs.getLong("amount_msat")), + rs.getString("direction"), + rs.getString("relay_type"), + rs.getTimestamp("timestamp").getTime) + relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part)) + } } + relayedByHash.flatMap { + case (paymentHash, parts) => + // We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel). + // NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch. + val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) + val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) + parts.headOption match { + case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map { + case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp) + } + case Some(RelayedPart(_, _, _, "trampoline", timestamp)) => + val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey)) + TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil + case _ => Nil + } + }.toSeq.sortBy(_.timestamp) } override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] = @@ -318,27 +311,24 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging { using(pg.prepareStatement("SELECT * FROM network_fees WHERE timestamp BETWEEN ? and ? ORDER BY timestamp")) { statement => statement.setTimestamp(1, Timestamp.from(Instant.ofEpochMilli(from))) statement.setTimestamp(2, Timestamp.from(Instant.ofEpochMilli(to))) - val rs = statement.executeQuery() - var q: Queue[NetworkFee] = Queue() - while (rs.next()) { - q = q :+ NetworkFee( + statement.executeQuery().map { rs => + NetworkFee( remoteNodeId = PublicKey(rs.getByteVectorFromHex("node_id")), channelId = rs.getByteVector32FromHex("channel_id"), txId = rs.getByteVector32FromHex("tx_id"), fee = Satoshi(rs.getLong("fee_sat")), txType = rs.getString("tx_type"), timestamp = rs.getTimestamp("timestamp").getTime) - } - q + }.toSeq } } override def stats(from: Long, to: Long): Seq[Stats] = { - val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) => + val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) => feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) } case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String) - val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) => + val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { (previous, e) => // NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones. val current = e match { case c: ChannelPaymentRelayed => Map( diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala index 445b3856fb..6c11001428 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala @@ -31,7 +31,6 @@ import grizzled.slf4j.Logging import java.sql.{Statement, Timestamp} import java.time.Instant import javax.sql.DataSource -import scala.collection.immutable.Queue class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb with Logging { @@ -146,8 +145,8 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit override def listLocalChannels(): Seq[HasCommitments] = withMetrics("channels/list-local-channels", DbBackends.Postgres) { withLock { pg => using(pg.createStatement) { statement => - val rs = statement.executeQuery("SELECT data FROM local_channels WHERE is_closed=FALSE") - codecSequence(rs, stateDataCodec) + statement.executeQuery("SELECT data FROM local_channels WHERE is_closed=FALSE") + .mapCodec(stateDataCodec).toSeq } } } @@ -169,12 +168,10 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit using(pg.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement => statement.setString(1, channelId.toHex) statement.setLong(2, commitmentNumber) - val rs = statement.executeQuery - var q: Queue[(ByteVector32, CltvExpiry)] = Queue() - while (rs.next()) { - q = q :+ (ByteVector32(rs.getByteVector32FromHex("payment_hash")), CltvExpiry(rs.getLong("cltv_expiry"))) - } - q + statement.executeQuery + .map { rs => + (ByteVector32(rs.getByteVector32FromHex("payment_hash")), CltvExpiry(rs.getLong("cltv_expiry"))) + }.toSeq } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala index 719f95f901..b42cb927aa 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala @@ -75,8 +75,9 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { inTransaction { pg => using(pg.prepareStatement("SELECT data FROM nodes WHERE node_id=?")) { statement => statement.setString(1, nodeId.value.toHex) - val rs = statement.executeQuery() - codecSequence(rs, nodeAnnouncementCodec).headOption + statement.executeQuery() + .mapCodec(nodeAnnouncementCodec) + .headOption } } } @@ -93,8 +94,8 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def listNodes(): Seq[NodeAnnouncement] = withMetrics("network/list-nodes", DbBackends.Postgres) { inTransaction { pg => using(pg.createStatement()) { statement => - val rs = statement.executeQuery("SELECT data FROM nodes") - codecSequence(rs, nodeAnnouncementCodec) + statement.executeQuery("SELECT data FROM nodes") + .mapCodec(nodeAnnouncementCodec).toSeq } } } @@ -125,17 +126,15 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { override def listChannels(): SortedMap[ShortChannelId, PublicChannel] = withMetrics("network/list-channels", DbBackends.Postgres) { inTransaction { pg => using(pg.createStatement()) { statement => - val rs = statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels") - var m = SortedMap.empty[ShortChannelId, PublicChannel] - while (rs.next()) { - val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value - val txId = ByteVector32.fromValidHex(rs.getString("txid")) - val capacity = rs.getLong("capacity_sat") - val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) - val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) - m = m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None)) - } - m + statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels") + .foldLeft(SortedMap.empty[ShortChannelId, PublicChannel]) { (m, rs) => + val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value + val txId = ByteVector32.fromValidHex(rs.getString("txid")) + val capacity = rs.getLong("capacity_sat") + val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) + val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) + m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None)) + } } } } @@ -182,8 +181,7 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging { inTransaction { pg => using(pg.prepareStatement("SELECT short_channel_id from pruned WHERE short_channel_id=?")) { statement => statement.setLong(1, shortChannelId.toLong) - val rs = statement.executeQuery() - rs.next() + statement.executeQuery().nonEmpty } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala index ed60a01ea8..c8a44174a2 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPaymentsDb.scala @@ -33,7 +33,6 @@ import scodec.codecs._ import java.sql.ResultSet import java.util.UUID import javax.sql.DataSource -import scala.collection.immutable.Queue import scala.concurrent.duration._ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb with Logging { @@ -168,12 +167,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit withLock { pg => using(pg.prepareStatement("SELECT * FROM sent_payments WHERE id = ?")) { statement => statement.setString(1, id.toString) - val rs = statement.executeQuery() - if (rs.next()) { - Some(parseOutgoingPayment(rs)) - } else { - None - } + statement.executeQuery().map(parseOutgoingPayment).headOption } } } @@ -182,12 +176,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit withLock { pg => using(pg.prepareStatement("SELECT * FROM sent_payments WHERE parent_id = ? ORDER BY created_at")) { statement => statement.setString(1, parentId.toString) - val rs = statement.executeQuery() - var q: Queue[OutgoingPayment] = Queue() - while (rs.next()) { - q = q :+ parseOutgoingPayment(rs) - } - q + statement.executeQuery().map(parseOutgoingPayment).toSeq } } } @@ -196,12 +185,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit withLock { pg => using(pg.prepareStatement("SELECT * FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement => statement.setString(1, paymentHash.toHex) - val rs = statement.executeQuery() - var q: Queue[OutgoingPayment] = Queue() - while (rs.next()) { - q = q :+ parseOutgoingPayment(rs) - } - q + statement.executeQuery().map(parseOutgoingPayment).toSeq } } } @@ -211,12 +195,9 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit using(pg.prepareStatement("SELECT * FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var q: Queue[OutgoingPayment] = Queue() - while (rs.next()) { - q = q :+ parseOutgoingPayment(rs) - } - q + statement.executeQuery().map { rs => + parseOutgoingPayment(rs) + }.toSeq } } } @@ -271,12 +252,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit withLock { pg => using(pg.prepareStatement("SELECT * FROM received_payments WHERE payment_hash = ?")) { statement => statement.setString(1, paymentHash.toHex) - val rs = statement.executeQuery() - if (rs.next()) { - Some(parseIncomingPayment(rs)) - } else { - None - } + statement.executeQuery().map(parseIncomingPayment).headOption } } } @@ -286,12 +262,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit using(pg.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } } @@ -301,12 +272,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit using(pg.prepareStatement("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } } @@ -317,12 +283,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, System.currentTimeMillis) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } } @@ -333,12 +294,7 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, System.currentTimeMillis) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } } @@ -388,31 +344,28 @@ class PgPaymentsDb(implicit ds: DataSource, lock: PgLock) extends PaymentsDb wit """.stripMargin )) { statement => statement.setInt(1, limit) - val rs = statement.executeQuery() - var q: Queue[PlainPayment] = Queue() - while (rs.next()) { - val parentId = rs.getUUIDNullable("parent_id") - val externalId_opt = rs.getStringNullable("external_id") - val paymentHash = rs.getByteVector32FromHex("payment_hash") - val paymentType = rs.getString("payment_type") - val paymentRequest_opt = rs.getStringNullable("payment_request") - val amount_opt = rs.getMilliSatoshiNullable("final_amount") - val createdAt = rs.getLong("created_at") - val completedAt_opt = rs.getLongNullable("completed_at") - val expireAt_opt = rs.getLongNullable("expire_at") + statement.executeQuery() + .map { rs => + val parentId = rs.getUUIDNullable("parent_id") + val externalId_opt = rs.getStringNullable("external_id") + val paymentHash = rs.getByteVector32FromHex("payment_hash") + val paymentType = rs.getString("payment_type") + val paymentRequest_opt = rs.getStringNullable("payment_request") + val amount_opt = rs.getMilliSatoshiNullable("final_amount") + val createdAt = rs.getLong("created_at") + val completedAt_opt = rs.getLongNullable("completed_at") + val expireAt_opt = rs.getLongNullable("expire_at") - val p = if (rs.getString("type") == "received") { - val status: IncomingPaymentStatus = buildIncomingPaymentStatus(amount_opt, paymentRequest_opt, completedAt_opt) - PlainIncomingPayment(paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt, expireAt_opt) - } else { - val preimage_opt = rs.getByteVector32Nullable("payment_preimage") - // note that the resulting status will not contain any details (routes, failures...) - val status: OutgoingPaymentStatus = buildOutgoingPaymentStatus(preimage_opt, None, None, completedAt_opt, None) - PlainOutgoingPayment(parentId, externalId_opt, paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt) - } - q = q :+ p - } - q + if (rs.getString("type") == "received") { + val status: IncomingPaymentStatus = buildIncomingPaymentStatus(amount_opt, paymentRequest_opt, completedAt_opt) + PlainIncomingPayment(paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt, expireAt_opt) + } else { + val preimage_opt = rs.getByteVector32Nullable("payment_preimage") + // note that the resulting status will not contain any details (routes, failures...) + val status: OutgoingPaymentStatus = buildOutgoingPaymentStatus(preimage_opt, None, None, completedAt_opt, None) + PlainOutgoingPayment(parentId, externalId_opt, paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt) + } + }.toSeq } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala index c65cd19e10..a26b1c93c5 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPeersDb.scala @@ -78,8 +78,9 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { withLock { pg => using(pg.prepareStatement("SELECT data FROM peers WHERE node_id=?")) { statement => statement.setString(1, nodeId.value.toHex) - val rs = statement.executeQuery() - codecSequence(rs, CommonCodecs.nodeaddress).headOption + statement.executeQuery() + .mapCodec(CommonCodecs.nodeaddress) + .headOption } } } @@ -87,14 +88,13 @@ class PgPeersDb(implicit ds: DataSource, lock: PgLock) extends PeersDb { override def listPeers(): Map[PublicKey, NodeAddress] = withMetrics("peers/list", DbBackends.Postgres) { withLock { pg => using(pg.createStatement()) { statement => - val rs = statement.executeQuery("SELECT node_id, data FROM peers") - var m: Map[PublicKey, NodeAddress] = Map() - while (rs.next()) { - val nodeid = PublicKey(rs.getByteVectorFromHex("node_id")) - val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value - m += (nodeid -> nodeaddress) - } - m + statement.executeQuery("SELECT node_id, data FROM peers") + .map { rs => + val nodeid = PublicKey(rs.getByteVectorFromHex("node_id")) + val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value + nodeid -> nodeaddress + } + .toMap } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala index 62a2e1ea8b..dee8347273 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala @@ -18,7 +18,7 @@ package fr.acinq.eclair.db.pg import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.eclair.channel.{Command, HtlcSettlementCommand} +import fr.acinq.eclair.channel.HtlcSettlementCommand import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends import fr.acinq.eclair.db.PendingCommandsDb @@ -28,7 +28,6 @@ import grizzled.slf4j.Logging import java.sql.Statement import javax.sql.DataSource -import scala.collection.immutable.Queue class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends PendingCommandsDb with Logging { @@ -85,8 +84,8 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending withLock { pg => using(pg.prepareStatement("SELECT htlc_id, data FROM pending_settlement_commands WHERE channel_id=?")) { statement => statement.setString(1, channelId.toHex) - val rs = statement.executeQuery() - codecSequence(rs, cmdCodec) + statement.executeQuery() + .mapCodec(cmdCodec).toSeq } } } @@ -94,12 +93,9 @@ class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends Pending override def listSettlementCommands(): Seq[(ByteVector32, HtlcSettlementCommand)] = withMetrics("pending-relay/list", DbBackends.Postgres) { withLock { pg => using(pg.prepareStatement("SELECT channel_id, data FROM pending_settlement_commands")) { statement => - val rs = statement.executeQuery() - var q: Queue[(ByteVector32, HtlcSettlementCommand)] = Queue() - while (rs.next()) { - q = q :+ (rs.getByteVector32FromHex("channel_id"), cmdCodec.decode(rs.getByteVector("data").bits).require.value) - } - q + statement.executeQuery() + .map(rs => (rs.getByteVector32FromHex("channel_id"), cmdCodec.decode(rs.getByteVector("data").bits).require.value)) + .toSeq } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala index 8d1359df8b..12fb0e42d6 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgUtils.scala @@ -20,6 +20,7 @@ import com.zaxxer.hikari.util.IsolationLevel import fr.acinq.eclair.db.Monitoring.Metrics._ import fr.acinq.eclair.db.Monitoring.Tags import fr.acinq.eclair.db.jdbc.JdbcUtils +import fr.acinq.eclair.db.jdbc.JdbcUtils.ExtendedResultSet._ import fr.acinq.eclair.db.pg.PgUtils.PgLock.LockFailureHandler.LockException import grizzled.slf4j.Logging import org.postgresql.util.{PGInterval, PSQLException} @@ -215,14 +216,14 @@ object PgUtils extends JdbcUtils { private def getCurrentLease(implicit connection: Connection): Option[LockLease] = { using(connection.createStatement()) { statement => - val rs = statement.executeQuery(s"SELECT expires_at, instance, now() > expires_at AS expired FROM $LeaseTable WHERE id = 1") - if (rs.next()) - Some(LockLease( - expiresAt = rs.getTimestamp("expires_at"), - instanceId = UUID.fromString(rs.getString("instance")), - expired = rs.getBoolean("expired"))) - else - None + statement.executeQuery(s"SELECT expires_at, instance, now() > expires_at AS expired FROM $LeaseTable WHERE id = 1") + .map { rs => + LockLease( + expiresAt = rs.getTimestamp("expires_at"), + instanceId = UUID.fromString(rs.getString("instance")), + expired = rs.getBoolean("expired")) + } + .headOption } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala index 4b3c4c8dce..52593cad35 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteAuditDb.scala @@ -31,7 +31,6 @@ import grizzled.slf4j.Logging import java.sql.{Connection, Statement} import java.util.UUID -import scala.collection.immutable.Queue class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { @@ -232,124 +231,117 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging { using(sqlite.prepareStatement("SELECT * FROM sent WHERE timestamp >= ? AND timestamp < ?")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var sentByParentId = Map.empty[UUID, PaymentSent] - while (rs.next()) { - val parentId = UUID.fromString(rs.getString("parent_payment_id")) - val part = PaymentSent.PartialPayment( - UUID.fromString(rs.getString("payment_id")), - MilliSatoshi(rs.getLong("amount_msat")), - MilliSatoshi(rs.getLong("fees_msat")), - rs.getByteVector32("to_channel_id"), - None, // we don't store the route in the audit DB - rs.getLong("timestamp")) - val sent = sentByParentId.get(parentId) match { - case Some(s) => s.copy(parts = s.parts :+ part) - case None => PaymentSent( - parentId, - rs.getByteVector32("payment_hash"), - rs.getByteVector32("payment_preimage"), - MilliSatoshi(rs.getLong("recipient_amount_msat")), - PublicKey(rs.getByteVector("recipient_node_id")), - Seq(part)) - } - sentByParentId = sentByParentId + (parentId -> sent) - } - sentByParentId.values.toSeq.sortBy(_.timestamp) + statement.executeQuery() + .foldLeft(Map.empty[UUID, PaymentSent]) { (sentByParentId, rs) => + val parentId = UUID.fromString(rs.getString("parent_payment_id")) + val part = PaymentSent.PartialPayment( + UUID.fromString(rs.getString("payment_id")), + MilliSatoshi(rs.getLong("amount_msat")), + MilliSatoshi(rs.getLong("fees_msat")), + rs.getByteVector32("to_channel_id"), + None, // we don't store the route in the audit DB + rs.getLong("timestamp")) + val sent = sentByParentId.get(parentId) match { + case Some(s) => s.copy(parts = s.parts :+ part) + case None => PaymentSent( + parentId, + rs.getByteVector32("payment_hash"), + rs.getByteVector32("payment_preimage"), + MilliSatoshi(rs.getLong("recipient_amount_msat")), + PublicKey(rs.getByteVector("recipient_node_id")), + Seq(part)) + } + sentByParentId + (parentId -> sent) + }.values.toSeq.sortBy(_.timestamp) } override def listReceived(from: Long, to: Long): Seq[PaymentReceived] = using(sqlite.prepareStatement("SELECT * FROM received WHERE timestamp >= ? AND timestamp < ?")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var receivedByHash = Map.empty[ByteVector32, PaymentReceived] - while (rs.next()) { - val paymentHash = rs.getByteVector32("payment_hash") - val part = PaymentReceived.PartialPayment( - MilliSatoshi(rs.getLong("amount_msat")), - rs.getByteVector32("from_channel_id"), - rs.getLong("timestamp")) - val received = receivedByHash.get(paymentHash) match { - case Some(r) => r.copy(parts = r.parts :+ part) - case None => PaymentReceived(paymentHash, Seq(part)) - } - receivedByHash = receivedByHash + (paymentHash -> received) - } - receivedByHash.values.toSeq.sortBy(_.timestamp) + statement.executeQuery() + .foldLeft(Map.empty[ByteVector32, PaymentReceived]) { (receivedByHash, rs) => + val paymentHash = rs.getByteVector32("payment_hash") + val part = PaymentReceived.PartialPayment( + MilliSatoshi(rs.getLong("amount_msat")), + rs.getByteVector32("from_channel_id"), + rs.getLong("timestamp")) + val received = receivedByHash.get(paymentHash) match { + case Some(r) => r.copy(parts = r.parts :+ part) + case None => PaymentReceived(paymentHash, Seq(part)) + } + receivedByHash + (paymentHash -> received) + }.values.toSeq.sortBy(_.timestamp) } override def listRelayed(from: Long, to: Long): Seq[PaymentRelayed] = { - var trampolineByHash = Map.empty[ByteVector32, (MilliSatoshi, PublicKey)] - using(sqlite.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement => + val trampolineByHash = using(sqlite.prepareStatement("SELECT * FROM relayed_trampoline WHERE timestamp >= ? AND timestamp < ?")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - while (rs.next()) { - val paymentHash = rs.getByteVector32("payment_hash") - val amount = MilliSatoshi(rs.getLong("amount_msat")) - val nodeId = PublicKey(rs.getByteVector("next_node_id")) - trampolineByHash += (paymentHash -> (amount, nodeId)) - } + statement.executeQuery() + .map { rs => + val paymentHash = rs.getByteVector32("payment_hash") + val amount = MilliSatoshi(rs.getLong("amount_msat")) + val nodeId = PublicKey(rs.getByteVector("next_node_id")) + paymentHash -> (amount, nodeId) + } + .toMap } - using(sqlite.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ?")) { statement => + val relayedByHash = using(sqlite.prepareStatement("SELECT * FROM relayed WHERE timestamp >= ? AND timestamp < ?")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var relayedByHash = Map.empty[ByteVector32, Seq[RelayedPart]] - while (rs.next()) { - val paymentHash = rs.getByteVector32("payment_hash") - val part = RelayedPart( - rs.getByteVector32("channel_id"), - MilliSatoshi(rs.getLong("amount_msat")), - rs.getString("direction"), - rs.getString("relay_type"), - rs.getLong("timestamp")) - relayedByHash = relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part)) - } - relayedByHash.flatMap { - case (paymentHash, parts) => - // We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel). - // NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch. - val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) - val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) - parts.headOption match { - case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map { - case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp) - } - case Some(RelayedPart(_, _, _, "trampoline", timestamp)) => - val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey)) - TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil - case _ => Nil - } - }.toSeq.sortBy(_.timestamp) + statement.executeQuery() + .foldLeft(Map.empty[ByteVector32, Seq[RelayedPart]]) { (relayedByHash, rs) => + val paymentHash = rs.getByteVector32("payment_hash") + val part = RelayedPart( + rs.getByteVector32("channel_id"), + MilliSatoshi(rs.getLong("amount_msat")), + rs.getString("direction"), + rs.getString("relay_type"), + rs.getLong("timestamp")) + relayedByHash + (paymentHash -> (relayedByHash.getOrElse(paymentHash, Nil) :+ part)) + } } + relayedByHash.flatMap { + case (paymentHash, parts) => + // We may have been routing multiple payments for the same payment_hash (MPP) in both cases (trampoline and channel). + // NB: we may link the wrong in-out parts, but the overall sum will be correct: we sort by amounts to minimize the risk of mismatch. + val incoming = parts.filter(_.direction == "IN").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) + val outgoing = parts.filter(_.direction == "OUT").map(p => PaymentRelayed.Part(p.amount, p.channelId)).sortBy(_.amount) + parts.headOption match { + case Some(RelayedPart(_, _, _, "channel", timestamp)) => incoming.zip(outgoing).map { + case (in, out) => ChannelPaymentRelayed(in.amount, out.amount, paymentHash, in.channelId, out.channelId, timestamp) + } + case Some(RelayedPart(_, _, _, "trampoline", timestamp)) => + val (nextTrampolineAmount, nextTrampolineNodeId) = trampolineByHash.getOrElse(paymentHash, (0 msat, PlaceHolderPubKey)) + TrampolinePaymentRelayed(paymentHash, incoming, outgoing, nextTrampolineNodeId, nextTrampolineAmount, timestamp) :: Nil + case _ => Nil + } + }.toSeq.sortBy(_.timestamp) } override def listNetworkFees(from: Long, to: Long): Seq[NetworkFee] = using(sqlite.prepareStatement("SELECT * FROM network_fees WHERE timestamp >= ? AND timestamp < ? ORDER BY timestamp")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var q: Queue[NetworkFee] = Queue() - while (rs.next()) { - q = q :+ NetworkFee( - remoteNodeId = PublicKey(rs.getByteVector("node_id")), - channelId = rs.getByteVector32("channel_id"), - txId = rs.getByteVector32("tx_id"), - fee = Satoshi(rs.getLong("fee_sat")), - txType = rs.getString("tx_type"), - timestamp = rs.getLong("timestamp")) - } - q + statement.executeQuery() + .map { rs => + NetworkFee( + remoteNodeId = PublicKey(rs.getByteVector("node_id")), + channelId = rs.getByteVector32("channel_id"), + txId = rs.getByteVector32("tx_id"), + fee = Satoshi(rs.getLong("fee_sat")), + txType = rs.getString("tx_type"), + timestamp = rs.getLong("timestamp")) + }.toSeq } override def stats(from: Long, to: Long): Seq[Stats] = { - val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { case (feeByChannelId, f) => + val networkFees = listNetworkFees(from, to).foldLeft(Map.empty[ByteVector32, Satoshi]) { (feeByChannelId, f) => feeByChannelId + (f.channelId -> (feeByChannelId.getOrElse(f.channelId, 0 sat) + f.fee)) } case class Relayed(amount: MilliSatoshi, fee: MilliSatoshi, direction: String) - val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { case (previous, e) => + val relayed = listRelayed(from, to).foldLeft(Map.empty[ByteVector32, Seq[Relayed]]) { (previous, e) => // NB: we must avoid counting the fee twice: we associate it to the outgoing channels rather than the incoming ones. val current = e match { case c: ChannelPaymentRelayed => Map( diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index cc442bafa0..732c6fe603 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -27,7 +27,6 @@ import fr.acinq.eclair.wire.internal.channel.ChannelCodecs.stateDataCodec import grizzled.slf4j.Logging import java.sql.{Connection, Statement} -import scala.collection.immutable.Queue class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { @@ -135,8 +134,8 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { override def listLocalChannels(): Seq[HasCommitments] = withMetrics("channels/list-local-channels", DbBackends.Sqlite) { using(sqlite.createStatement) { statement => - val rs = statement.executeQuery("SELECT data FROM local_channels WHERE is_closed=0") - codecSequence(rs, stateDataCodec) + statement.executeQuery("SELECT data FROM local_channels WHERE is_closed=0") + .mapCodec(stateDataCodec).toSeq } } @@ -154,12 +153,9 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { using(sqlite.prepareStatement("SELECT payment_hash, cltv_expiry FROM htlc_infos WHERE channel_id=? AND commitment_number=?")) { statement => statement.setBytes(1, channelId.toArray) statement.setLong(2, commitmentNumber) - val rs = statement.executeQuery - var q: Queue[(ByteVector32, CltvExpiry)] = Queue() - while (rs.next()) { - q = q :+ (ByteVector32(rs.getByteVector32("payment_hash")), CltvExpiry(rs.getLong("cltv_expiry"))) - } - q + statement.executeQuery + .map(rs => (ByteVector32(rs.getByteVector32("payment_hash")), CltvExpiry(rs.getLong("cltv_expiry")))) + .toSeq } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteFeeratesDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteFeeratesDb.scala index 4fed233e9c..faa8d3f1da 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteFeeratesDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteFeeratesDb.scala @@ -16,14 +16,16 @@ package fr.acinq.eclair.db.sqlite -import java.sql.{Connection, Statement} import fr.acinq.bitcoin.Satoshi import fr.acinq.eclair.blockchain.fee.{FeeratePerKB, FeeratesPerKB} import fr.acinq.eclair.db.FeeratesDb import grizzled.slf4j.Logging +import java.sql.{Connection, Statement} + class SqliteFeeratesDb(sqlite: Connection) extends FeeratesDb with Logging { + import SqliteUtils.ExtendedResultSet._ import SqliteUtils._ val DB_NAME = "feerates" @@ -89,22 +91,21 @@ class SqliteFeeratesDb(sqlite: Connection) extends FeeratesDb with Logging { override def getFeerates(): Option[FeeratesPerKB] = { using(sqlite.prepareStatement("SELECT rate_block_1, rate_blocks_2, rate_blocks_6, rate_blocks_12, rate_blocks_36, rate_blocks_72, rate_blocks_144, rate_blocks_1008 FROM feerates_per_kb")) { statement => - val rs = statement.executeQuery() - if (rs.next()) { - Some(FeeratesPerKB( - // NB: we don't bother storing this value in the DB, because it's unused on mobile. - mempoolMinFee = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_1008"))), - block_1 = FeeratePerKB(Satoshi(rs.getLong("rate_block_1"))), - blocks_2 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_2"))), - blocks_6 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_6"))), - blocks_12 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_12"))), - blocks_36 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_36"))), - blocks_72 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_72"))), - blocks_144 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_144"))), - blocks_1008 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_1008"))))) - } else { - None - } + statement.executeQuery() + .map { rs => + FeeratesPerKB( + // NB: we don't bother storing this value in the DB, because it's unused on mobile. + mempoolMinFee = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_1008"))), + block_1 = FeeratePerKB(Satoshi(rs.getLong("rate_block_1"))), + blocks_2 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_2"))), + blocks_6 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_6"))), + blocks_12 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_12"))), + blocks_36 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_36"))), + blocks_72 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_72"))), + blocks_144 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_144"))), + blocks_1008 = FeeratePerKB(Satoshi(rs.getLong("rate_blocks_1008")))) + } + .headOption } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala index 7152f8cc46..c0f3a1d488 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteNetworkDb.scala @@ -82,8 +82,9 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { override def getNode(nodeId: Crypto.PublicKey): Option[NodeAnnouncement] = withMetrics("network/get-node", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT data FROM nodes WHERE node_id=?")) { statement => statement.setBytes(1, nodeId.value.toArray) - val rs = statement.executeQuery() - codecSequence(rs, nodeAnnouncementCodec).headOption + statement.executeQuery() + .mapCodec(nodeAnnouncementCodec) + .headOption } } @@ -96,8 +97,8 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { override def listNodes(): Seq[NodeAnnouncement] = withMetrics("network/list-nodes", DbBackends.Sqlite) { using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT data FROM nodes") - codecSequence(rs, nodeAnnouncementCodec) + statement.executeQuery("SELECT data FROM nodes") + .mapCodec(nodeAnnouncementCodec).toSeq } } @@ -122,17 +123,15 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { override def listChannels(): SortedMap[ShortChannelId, PublicChannel] = withMetrics("network/list-channels", DbBackends.Sqlite) { using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels") - var m = SortedMap.empty[ShortChannelId, PublicChannel] - while (rs.next()) { - val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value - val txId = ByteVector32.fromValidHex(rs.getString("txid")) - val capacity = rs.getLong("capacity_sat") - val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) - val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) - m = m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None)) - } - m + statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels") + .foldLeft(SortedMap.empty[ShortChannelId, PublicChannel]) { (m, rs) => + val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value + val txId = ByteVector32.fromValidHex(rs.getString("txid")) + val capacity = rs.getLong("capacity_sat") + val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value) + val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value) + m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None)) + } } } @@ -171,8 +170,7 @@ class SqliteNetworkDb(sqlite: Connection) extends NetworkDb with Logging { override def isPruned(shortChannelId: ShortChannelId): Boolean = withMetrics("network/is-pruned", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT short_channel_id from pruned WHERE short_channel_id=?")) { statement => statement.setLong(1, shortChannelId.toLong) - val rs = statement.executeQuery() - rs.next() + statement.executeQuery().nonEmpty } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala index 921f451d9a..2ecdec7550 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePaymentsDb.scala @@ -16,8 +16,6 @@ package fr.acinq.eclair.db.sqlite -import java.sql.{Connection, ResultSet, Statement} -import java.util.UUID import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} import fr.acinq.eclair.MilliSatoshi @@ -32,7 +30,8 @@ import scodec.Attempt import scodec.bits.BitVector import scodec.codecs._ -import scala.collection.immutable.Queue +import java.sql.{Connection, ResultSet, Statement} +import java.util.UUID import scala.concurrent.duration._ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { @@ -217,36 +216,21 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { override def getOutgoingPayment(id: UUID): Option[OutgoingPayment] = withMetrics("payments/get-outgoing", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE id = ?")) { statement => statement.setString(1, id.toString) - val rs = statement.executeQuery() - if (rs.next()) { - Some(parseOutgoingPayment(rs)) - } else { - None - } + statement.executeQuery().map(parseOutgoingPayment).headOption } } override def listOutgoingPayments(parentId: UUID): Seq[OutgoingPayment] = withMetrics("payments/list-outgoing-by-parent-id", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE parent_id = ? ORDER BY created_at")) { statement => statement.setString(1, parentId.toString) - val rs = statement.executeQuery() - var q: Queue[OutgoingPayment] = Queue() - while (rs.next()) { - q = q :+ parseOutgoingPayment(rs) - } - q + statement.executeQuery().map(parseOutgoingPayment).toSeq } } override def listOutgoingPayments(paymentHash: ByteVector32): Seq[OutgoingPayment] = withMetrics("payments/list-outgoing-by-payment-hash", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE payment_hash = ? ORDER BY created_at")) { statement => statement.setBytes(1, paymentHash.toArray) - val rs = statement.executeQuery() - var q: Queue[OutgoingPayment] = Queue() - while (rs.next()) { - q = q :+ parseOutgoingPayment(rs) - } - q + statement.executeQuery().map(parseOutgoingPayment).toSeq } } @@ -254,12 +238,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { using(sqlite.prepareStatement("SELECT * FROM sent_payments WHERE created_at >= ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var q: Queue[OutgoingPayment] = Queue() - while (rs.next()) { - q = q :+ parseOutgoingPayment(rs) - } - q + statement.executeQuery().map(parseOutgoingPayment).toSeq } } @@ -308,12 +287,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { override def getIncomingPayment(paymentHash: ByteVector32): Option[IncomingPayment] = withMetrics("payments/get-incoming", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE payment_hash = ?")) { statement => statement.setBytes(1, paymentHash.toArray) - val rs = statement.executeQuery() - if (rs.next()) { - Some(parseIncomingPayment(rs)) - } else { - None - } + statement.executeQuery().map(parseIncomingPayment).headOption } } @@ -321,12 +295,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -334,12 +303,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { using(sqlite.prepareStatement("SELECT * FROM received_payments WHERE received_msat > 0 AND created_at > ? AND created_at < ? ORDER BY created_at")) { statement => statement.setLong(1, from) statement.setLong(2, to) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -348,12 +312,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, System.currentTimeMillis) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -362,12 +321,7 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { statement.setLong(1, from) statement.setLong(2, to) statement.setLong(3, System.currentTimeMillis) - val rs = statement.executeQuery() - var q: Queue[IncomingPayment] = Queue() - while (rs.next()) { - q = q :+ parseIncomingPayment(rs) - } - q + statement.executeQuery().map(parseIncomingPayment).toSeq } } @@ -415,31 +369,28 @@ class SqlitePaymentsDb(sqlite: Connection) extends PaymentsDb with Logging { """.stripMargin )) { statement => statement.setInt(1, limit) - val rs = statement.executeQuery() - var q: Queue[PlainPayment] = Queue() - while (rs.next()) { - val parentId = rs.getUUIDNullable("parent_id") - val externalId_opt = rs.getStringNullable("external_id") - val paymentHash = rs.getByteVector32("payment_hash") - val paymentType = rs.getString("payment_type") - val paymentRequest_opt = rs.getStringNullable("payment_request") - val amount_opt = rs.getMilliSatoshiNullable("final_amount") - val createdAt = rs.getLong("created_at") - val completedAt_opt = rs.getLongNullable("completed_at") - val expireAt_opt = rs.getLongNullable("expire_at") - - val p = if (rs.getString("type") == "received") { - val status: IncomingPaymentStatus = buildIncomingPaymentStatus(amount_opt, paymentRequest_opt, completedAt_opt) - PlainIncomingPayment(paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt, expireAt_opt) - } else { - val preimage_opt = rs.getByteVector32Nullable("payment_preimage") - // note that the resulting status will not contain any details (routes, failures...) - val status: OutgoingPaymentStatus = buildOutgoingPaymentStatus(preimage_opt, None, None, completedAt_opt, None) - PlainOutgoingPayment(parentId, externalId_opt, paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt) - } - q = q :+ p - } - q + statement.executeQuery() + .map { rs => + val parentId = rs.getUUIDNullable("parent_id") + val externalId_opt = rs.getStringNullable("external_id") + val paymentHash = rs.getByteVector32("payment_hash") + val paymentType = rs.getString("payment_type") + val paymentRequest_opt = rs.getStringNullable("payment_request") + val amount_opt = rs.getMilliSatoshiNullable("final_amount") + val createdAt = rs.getLong("created_at") + val completedAt_opt = rs.getLongNullable("completed_at") + val expireAt_opt = rs.getLongNullable("expire_at") + + if (rs.getString("type") == "received") { + val status: IncomingPaymentStatus = buildIncomingPaymentStatus(amount_opt, paymentRequest_opt, completedAt_opt) + PlainIncomingPayment(paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt, expireAt_opt) + } else { + val preimage_opt = rs.getByteVector32Nullable("payment_preimage") + // note that the resulting status will not contain any details (routes, failures...) + val status: OutgoingPaymentStatus = buildOutgoingPaymentStatus(preimage_opt, None, None, completedAt_opt, None) + PlainOutgoingPayment(parentId, externalId_opt, paymentHash, paymentType, amount_opt, paymentRequest_opt, status, createdAt, completedAt_opt) + } + }.toSeq } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala index cbd520efcf..9125b1daee 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePeersDb.scala @@ -21,7 +21,7 @@ import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends import fr.acinq.eclair.db.PeersDb -import fr.acinq.eclair.db.sqlite.SqliteUtils.{codecSequence, getVersion, setVersion, using} +import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, setVersion, using} import fr.acinq.eclair.wire.protocol._ import scodec.bits.BitVector @@ -69,21 +69,21 @@ class SqlitePeersDb(sqlite: Connection) extends PeersDb { override def getPeer(nodeId: PublicKey): Option[NodeAddress] = withMetrics("peers/get", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT data FROM peers WHERE node_id=?")) { statement => statement.setBytes(1, nodeId.value.toArray) - val rs = statement.executeQuery() - codecSequence(rs, CommonCodecs.nodeaddress).headOption + statement.executeQuery() + .mapCodec(CommonCodecs.nodeaddress) + .headOption } } override def listPeers(): Map[PublicKey, NodeAddress] = withMetrics("peers/list", DbBackends.Sqlite) { using(sqlite.createStatement()) { statement => - val rs = statement.executeQuery("SELECT node_id, data FROM peers") - var m: Map[PublicKey, NodeAddress] = Map() - while (rs.next()) { - val nodeid = PublicKey(rs.getByteVector("node_id")) - val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value - m += (nodeid -> nodeaddress) - } - m + statement.executeQuery("SELECT node_id, data FROM peers") + .map { rs => + val nodeid = PublicKey(rs.getByteVector("node_id")) + val nodeaddress = CommonCodecs.nodeaddress.decode(BitVector(rs.getBytes("data"))).require.value + nodeid -> nodeaddress + } + .toMap } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala index 974a53dd68..cbd2965c3d 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala @@ -25,7 +25,6 @@ import fr.acinq.eclair.wire.internal.CommandCodecs.cmdCodec import grizzled.slf4j.Logging import java.sql.{Connection, Statement} -import scala.collection.immutable.Queue class SqlitePendingCommandsDb(sqlite: Connection) extends PendingCommandsDb with Logging { @@ -74,19 +73,16 @@ class SqlitePendingCommandsDb(sqlite: Connection) extends PendingCommandsDb with override def listSettlementCommands(channelId: ByteVector32): Seq[HtlcSettlementCommand] = withMetrics("pending-relay/list-channel", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT data FROM pending_settlement_commands WHERE channel_id=?")) { statement => statement.setBytes(1, channelId.toArray) - val rs = statement.executeQuery() - codecSequence(rs, cmdCodec) + statement.executeQuery() + .mapCodec(cmdCodec).toSeq } } override def listSettlementCommands(): Seq[(ByteVector32, HtlcSettlementCommand)] = withMetrics("pending-relay/list", DbBackends.Sqlite) { using(sqlite.prepareStatement("SELECT channel_id, data FROM pending_settlement_commands")) { statement => - val rs = statement.executeQuery() - var q: Queue[(ByteVector32, HtlcSettlementCommand)] = Queue() - while (rs.next()) { - q = q :+ (rs.getByteVector32("channel_id"), cmdCodec.decode(rs.getByteVector("data").bits).require.value) - } - q + statement.executeQuery() + .map(rs => (rs.getByteVector32("channel_id"), cmdCodec.decode(rs.getByteVector("data").bits).require.value)) + .toSeq } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala index c54903c4b0..44fccac7b9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteUtils.scala @@ -16,10 +16,10 @@ package fr.acinq.eclair.db.sqlite -import java.sql.{Connection, Statement} - import fr.acinq.eclair.db.jdbc.JdbcUtils +import java.sql.Connection + object SqliteUtils extends JdbcUtils { /** diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala index 21fb548483..ee8e32c96b 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/ReconnectionTaskSpec.scala @@ -216,7 +216,7 @@ class ReconnectionTaskSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike peer.send(reconnectionTask, Peer.Connect(remoteNodeId, None)) // assert our mock server got an incoming connection (the client was spawned with the address from node_announcement) - awaitCond(mockServer.accept() != null, max = 30 seconds, interval = 1 second) + awaitCond(mockServer.accept() != null, max = 60 seconds, interval = 1 second) mockServer.close() }