diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala index 5a7db96d7b..731a88c3be 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala @@ -35,7 +35,7 @@ import fr.acinq.eclair.channel.TxPublisher.{PublishRawTx, PublishTx, SetChannelI import fr.acinq.eclair.crypto.ShaChain import fr.acinq.eclair.crypto.keymanager.ChannelKeyManager import fr.acinq.eclair.db.DbEventHandler.ChannelEvent.EventType -import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.db.pg.PgUtils.PgLock.logger import fr.acinq.eclair.io.Peer import fr.acinq.eclair.payment.PaymentSettlingOnChain @@ -1875,7 +1875,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId onTransition { case _ -> CLOSING => - PendingRelayDb.getPendingFailsAndFulfills(nodeParams.db.pendingRelay, nextStateData.asInstanceOf[HasCommitments].channelId) match { + PendingCommandsDb.getSettlementCommands(nodeParams.db.pendingCommands, nextStateData.asInstanceOf[HasCommitments].channelId) match { case Nil => log.debug("nothing to replay") case cmds => @@ -1883,7 +1883,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId cmds.foreach(self ! _) // they all have commit = false } case SYNCING -> (NORMAL | SHUTDOWN) => - PendingRelayDb.getPendingFailsAndFulfills(nodeParams.db.pendingRelay, nextStateData.asInstanceOf[HasCommitments].channelId) match { + PendingCommandsDb.getSettlementCommands(nodeParams.db.pendingCommands, nextStateData.asInstanceOf[HasCommitments].channelId) match { case Nil => log.debug("nothing to replay") case cmds => @@ -2109,7 +2109,7 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId } else { // There might be pending fulfill commands that we haven't relayed yet. // Since this involves a DB call, we only want to check it if all the previous checks failed (this is the slow path). - val pendingRelayFulfills = nodeParams.db.pendingRelay.listPendingRelay(d.channelId).collect { case c: CMD_FULFILL_HTLC => c.id } + val pendingRelayFulfills = nodeParams.db.pendingCommands.listSettlementCommands(d.channelId).collect { case c: CMD_FULFILL_HTLC => c.id } val offendingPendingRelayFulfills = almostTimedOutIncoming.filter(htlc => pendingRelayFulfills.contains(htlc.id)) if (offendingPendingRelayFulfills.nonEmpty) { handleLocalError(HtlcsWillTimeoutUpstream(d.channelId, offendingPendingRelayFulfills), d, Some(c)) @@ -2520,13 +2520,13 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId */ def acking(channelId: ByteVector32, cmd: HtlcSettlementCommand): FSM.State[fr.acinq.eclair.channel.State, Data] = { log.debug("scheduling acknowledgement of cmd id={}", cmd.id) - context.system.scheduler.scheduleOnce(10 seconds)(PendingRelayDb.ackCommand(nodeParams.db.pendingRelay, channelId, cmd))(context.system.dispatcher) + context.system.scheduler.scheduleOnce(10 seconds)(PendingCommandsDb.ackSettlementCommand(nodeParams.db.pendingCommands, channelId, cmd))(context.system.dispatcher) state } def acking(updates: List[UpdateMessage]): FSM.State[fr.acinq.eclair.channel.State, Data] = { log.debug("scheduling acknowledgement of cmds ids={}", updates.collect { case s: HtlcSettlementMessage => s.id }.mkString(",")) - context.system.scheduler.scheduleOnce(10 seconds)(PendingRelayDb.ackPendingFailsAndFulfills(nodeParams.db.pendingRelay, updates))(context.system.dispatcher) + context.system.scheduler.scheduleOnce(10 seconds)(PendingCommandsDb.ackSettlementCommands(nodeParams.db.pendingCommands, updates))(context.system.dispatcher) state } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala index cb8179f877..b80da9ef82 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/Databases.scala @@ -38,7 +38,7 @@ trait Databases { def channels: ChannelsDb def peers: PeersDb def payments: PaymentsDb - def pendingRelay: PendingRelayDb + def pendingCommands: PendingCommandsDb //@formatter:on } @@ -59,7 +59,7 @@ object Databases extends Logging { channels: SqliteChannelsDb, peers: SqlitePeersDb, payments: SqlitePaymentsDb, - pendingRelay: SqlitePendingRelayDb, + pendingCommands: SqlitePendingCommandsDb, private val backupConnection: Connection) extends Databases with FileBackup { override def backup(backupFile: File): Unit = SqliteUtils.using(backupConnection.createStatement()) { statement => { @@ -75,7 +75,7 @@ object Databases extends Logging { channels = new SqliteChannelsDb(eclairJdbc), peers = new SqlitePeersDb(eclairJdbc), payments = new SqlitePaymentsDb(eclairJdbc), - pendingRelay = new SqlitePendingRelayDb(eclairJdbc), + pendingCommands = new SqlitePendingCommandsDb(eclairJdbc), backupConnection = eclairJdbc ) } @@ -85,7 +85,7 @@ object Databases extends Logging { channels: PgChannelsDb, peers: PgPeersDb, payments: PgPaymentsDb, - pendingRelay: PgPendingRelayDb, + pendingCommands: PgPendingCommandsDb, dataSource: HikariDataSource, lock: PgLock) extends Databases with ExclusiveLock { override def obtainExclusiveLock(): Unit = lock.obtainExclusiveLock(dataSource) @@ -119,7 +119,7 @@ object Databases extends Logging { channels = new PgChannelsDb, peers = new PgPeersDb, payments = new PgPaymentsDb, - pendingRelay = new PgPendingRelayDb, + pendingCommands = new PgPendingCommandsDb, dataSource = ds, lock = lock) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingCommandsDb.scala similarity index 63% rename from eclair-core/src/main/scala/fr/acinq/eclair/db/PendingRelayDb.scala rename to eclair-core/src/main/scala/fr/acinq/eclair/db/PendingCommandsDb.scala index f2f60f52f5..47bf04f163 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/PendingCommandsDb.scala @@ -16,14 +16,14 @@ package fr.acinq.eclair.db -import java.io.Closeable - -import akka.actor.{ActorContext, ActorRef} +import akka.actor.ActorRef import akka.event.LoggingAdapter import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel._ import fr.acinq.eclair.wire.protocol.{UpdateFailHtlc, UpdateFailMalformedHtlc, UpdateFulfillHtlc, UpdateMessage} +import java.io.Closeable + /** * This database stores CMD_FULFILL_HTLC and CMD_FAIL_HTLC that we have received from downstream * (either directly via UpdateFulfillHtlc or by extracting the value from the @@ -36,48 +36,48 @@ import fr.acinq.eclair.wire.protocol.{UpdateFailHtlc, UpdateFailMalformedHtlc, U * to handle all corner cases. * */ -trait PendingRelayDb extends Closeable { +trait PendingCommandsDb extends Closeable { - def addPendingRelay(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit + def addSettlementCommand(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit - def removePendingRelay(channelId: ByteVector32, htlcId: Long): Unit + def removeSettlementCommand(channelId: ByteVector32, htlcId: Long): Unit - def listPendingRelay(channelId: ByteVector32): Seq[HtlcSettlementCommand] + def listSettlementCommands(channelId: ByteVector32): Seq[HtlcSettlementCommand] - def listPendingRelay(): Set[(ByteVector32, Long)] + def listSettlementCommands(): Seq[(ByteVector32, HtlcSettlementCommand)] } -object PendingRelayDb { +object PendingCommandsDb { /** * We store [[CMD_FULFILL_HTLC]]/[[CMD_FAIL_HTLC]]/[[CMD_FAIL_MALFORMED_HTLC]] * in a database because we don't want to lose preimages, or to forget to fail * incoming htlcs, which would lead to unwanted channel closings. */ - def safeSend(register: ActorRef, db: PendingRelayDb, channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = { + def safeSend(register: ActorRef, db: PendingCommandsDb, channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = { // htlc settlement commands don't have replyTo register ! Register.Forward(ActorRef.noSender, channelId, cmd) // we store the command in a db (note that this happens *after* forwarding the command to the channel, so we don't add latency) - db.addPendingRelay(channelId, cmd) + db.addSettlementCommand(channelId, cmd) } - def ackCommand(db: PendingRelayDb, channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = { - db.removePendingRelay(channelId, cmd.id) + def ackSettlementCommand(db: PendingCommandsDb, channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = { + db.removeSettlementCommand(channelId, cmd.id) } - def ackPendingFailsAndFulfills(db: PendingRelayDb, updates: List[UpdateMessage])(implicit log: LoggingAdapter): Unit = updates.collect { + def ackSettlementCommands(db: PendingCommandsDb, updates: List[UpdateMessage])(implicit log: LoggingAdapter): Unit = updates.collect { case u: UpdateFulfillHtlc => log.debug(s"fulfill acked for htlcId=${u.id}") - db.removePendingRelay(u.channelId, u.id) + db.removeSettlementCommand(u.channelId, u.id) case u: UpdateFailHtlc => log.debug(s"fail acked for htlcId=${u.id}") - db.removePendingRelay(u.channelId, u.id) + db.removeSettlementCommand(u.channelId, u.id) case u: UpdateFailMalformedHtlc => log.debug(s"fail-malformed acked for htlcId=${u.id}") - db.removePendingRelay(u.channelId, u.id) + db.removeSettlementCommand(u.channelId, u.id) } - def getPendingFailsAndFulfills(db: PendingRelayDb, channelId: ByteVector32)(implicit log: LoggingAdapter): Seq[HtlcSettlementCommand] = { - db.listPendingRelay(channelId) + def getSettlementCommands(db: PendingCommandsDb, channelId: ByteVector32)(implicit log: LoggingAdapter): Seq[HtlcSettlementCommand] = { + db.listSettlementCommands(channelId) } } \ No newline at end of file diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala index e76deb7910..445b3856fb 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgChannelsDb.scala @@ -126,7 +126,7 @@ class PgChannelsDb(implicit ds: DataSource, lock: PgLock) extends ChannelsDb wit override def removeChannel(channelId: ByteVector32): Unit = withMetrics("channels/remove-channel", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("DELETE FROM pending_relay WHERE channel_id=?")) { statement => + using(pg.prepareStatement("DELETE FROM pending_settlement_commands WHERE channel_id=?")) { statement => statement.setString(1, channelId.toHex) statement.executeUpdate() } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala similarity index 54% rename from eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingRelayDb.scala rename to eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala index a4d9e40275..62a2e1ea8b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgPendingCommandsDb.scala @@ -21,28 +21,38 @@ import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel.{Command, HtlcSettlementCommand} import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends -import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.db.pg.PgUtils._ import fr.acinq.eclair.wire.internal.CommandCodecs.cmdCodec +import grizzled.slf4j.Logging +import java.sql.Statement import javax.sql.DataSource import scala.collection.immutable.Queue -class PgPendingRelayDb(implicit ds: DataSource, lock: PgLock) extends PendingRelayDb { +class PgPendingCommandsDb(implicit ds: DataSource, lock: PgLock) extends PendingCommandsDb with Logging { import PgUtils.ExtendedResultSet._ import PgUtils._ import lock._ val DB_NAME = "pending_relay" - val CURRENT_VERSION = 1 + val CURRENT_VERSION = 2 inTransaction { pg => using(pg.createStatement()) { statement => + + def migration12(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE pending_relay RENAME TO pending_settlement_commands") + } + getVersion(statement, DB_NAME) match { case None => // note: should we use a foreign key to local_channels table here? - statement.executeUpdate("CREATE TABLE pending_relay (channel_id TEXT NOT NULL, htlc_id BIGINT NOT NULL, data BYTEA NOT NULL, PRIMARY KEY(channel_id, htlc_id))") + statement.executeUpdate("CREATE TABLE pending_settlement_commands (channel_id TEXT NOT NULL, htlc_id BIGINT NOT NULL, data BYTEA NOT NULL, PRIMARY KEY(channel_id, htlc_id))") + case Some(v@1) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration12(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } @@ -50,9 +60,9 @@ class PgPendingRelayDb(implicit ds: DataSource, lock: PgLock) extends PendingRel } } - override def addPendingRelay(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = withMetrics("pending-relay/add", DbBackends.Postgres) { + override def addSettlementCommand(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = withMetrics("pending-relay/add", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("INSERT INTO pending_relay VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement => + using(pg.prepareStatement("INSERT INTO pending_settlement_commands VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement => statement.setString(1, channelId.toHex) statement.setLong(2, cmd.id) statement.setBytes(3, cmdCodec.encode(cmd).require.toByteArray) @@ -61,9 +71,9 @@ class PgPendingRelayDb(implicit ds: DataSource, lock: PgLock) extends PendingRel } } - override def removePendingRelay(channelId: ByteVector32, htlcId: Long): Unit = withMetrics("pending-relay/remove", DbBackends.Postgres) { + override def removeSettlementCommand(channelId: ByteVector32, htlcId: Long): Unit = withMetrics("pending-relay/remove", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("DELETE FROM pending_relay WHERE channel_id=? AND htlc_id=?")) { statement => + using(pg.prepareStatement("DELETE FROM pending_settlement_commands WHERE channel_id=? AND htlc_id=?")) { statement => statement.setString(1, channelId.toHex) statement.setLong(2, htlcId) statement.executeUpdate() @@ -71,9 +81,9 @@ class PgPendingRelayDb(implicit ds: DataSource, lock: PgLock) extends PendingRel } } - override def listPendingRelay(channelId: ByteVector32): Seq[HtlcSettlementCommand] = withMetrics("pending-relay/list-channel", DbBackends.Postgres) { + override def listSettlementCommands(channelId: ByteVector32): Seq[HtlcSettlementCommand] = withMetrics("pending-relay/list-channel", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT htlc_id, data FROM pending_relay WHERE channel_id=?")) { statement => + using(pg.prepareStatement("SELECT htlc_id, data FROM pending_settlement_commands WHERE channel_id=?")) { statement => statement.setString(1, channelId.toHex) val rs = statement.executeQuery() codecSequence(rs, cmdCodec) @@ -81,15 +91,15 @@ class PgPendingRelayDb(implicit ds: DataSource, lock: PgLock) extends PendingRel } } - override def listPendingRelay(): Set[(ByteVector32, Long)] = withMetrics("pending-relay/list", DbBackends.Postgres) { + override def listSettlementCommands(): Seq[(ByteVector32, HtlcSettlementCommand)] = withMetrics("pending-relay/list", DbBackends.Postgres) { withLock { pg => - using(pg.prepareStatement("SELECT channel_id, htlc_id FROM pending_relay")) { statement => + using(pg.prepareStatement("SELECT channel_id, data FROM pending_settlement_commands")) { statement => val rs = statement.executeQuery() - var q: Queue[(ByteVector32, Long)] = Queue() + var q: Queue[(ByteVector32, HtlcSettlementCommand)] = Queue() while (rs.next()) { - q = q :+ (rs.getByteVector32FromHex("channel_id"), rs.getLong("htlc_id")) + q = q :+ (rs.getByteVector32FromHex("channel_id"), cmdCodec.decode(rs.getByteVector("data").bits).require.value) } - q.toSet + q } } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala index 70a62846c3..cc442bafa0 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqliteChannelsDb.scala @@ -117,7 +117,7 @@ class SqliteChannelsDb(sqlite: Connection) extends ChannelsDb with Logging { } override def removeChannel(channelId: ByteVector32): Unit = withMetrics("channels/remove-channel", DbBackends.Sqlite) { - using(sqlite.prepareStatement("DELETE FROM pending_relay WHERE channel_id=?")) { statement => + using(sqlite.prepareStatement("DELETE FROM pending_settlement_commands WHERE channel_id=?")) { statement => statement.setBytes(1, channelId.toArray) statement.executeUpdate() } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala similarity index 53% rename from eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala rename to eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala index eccbc9bfe8..974a53dd68 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingRelayDb.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/db/sqlite/SqlitePendingCommandsDb.scala @@ -20,33 +20,42 @@ import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel.HtlcSettlementCommand import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics import fr.acinq.eclair.db.Monitoring.Tags.DbBackends -import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.wire.internal.CommandCodecs.cmdCodec +import grizzled.slf4j.Logging -import java.sql.Connection +import java.sql.{Connection, Statement} import scala.collection.immutable.Queue -class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb { +class SqlitePendingCommandsDb(sqlite: Connection) extends PendingCommandsDb with Logging { import SqliteUtils.ExtendedResultSet._ import SqliteUtils._ val DB_NAME = "pending_relay" - val CURRENT_VERSION = 1 + val CURRENT_VERSION = 2 using(sqlite.createStatement(), inTransaction = true) { statement => + + def migration12(statement: Statement): Unit = { + statement.executeUpdate("ALTER TABLE pending_relay RENAME TO pending_settlement_commands") + } + getVersion(statement, DB_NAME) match { case None => // note: should we use a foreign key to local_channels table here? - statement.executeUpdate("CREATE TABLE pending_relay (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))") + statement.executeUpdate("CREATE TABLE pending_settlement_commands (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))") + case Some(v@1) => + logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION") + migration12(statement) case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion") } setVersion(statement, DB_NAME, CURRENT_VERSION) } - override def addPendingRelay(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = withMetrics("pending-relay/add", DbBackends.Sqlite) { - using(sqlite.prepareStatement("INSERT OR IGNORE INTO pending_relay VALUES (?, ?, ?)")) { statement => + override def addSettlementCommand(channelId: ByteVector32, cmd: HtlcSettlementCommand): Unit = withMetrics("pending-relay/add", DbBackends.Sqlite) { + using(sqlite.prepareStatement("INSERT OR IGNORE INTO pending_settlement_commands VALUES (?, ?, ?)")) { statement => statement.setBytes(1, channelId.toArray) statement.setLong(2, cmd.id) statement.setBytes(3, cmdCodec.encode(cmd).require.toByteArray) @@ -54,30 +63,30 @@ class SqlitePendingRelayDb(sqlite: Connection) extends PendingRelayDb { } } - override def removePendingRelay(channelId: ByteVector32, htlcId: Long): Unit = withMetrics("pending-relay/remove", DbBackends.Sqlite) { - using(sqlite.prepareStatement("DELETE FROM pending_relay WHERE channel_id=? AND htlc_id=?")) { statement => + override def removeSettlementCommand(channelId: ByteVector32, htlcId: Long): Unit = withMetrics("pending-relay/remove", DbBackends.Sqlite) { + using(sqlite.prepareStatement("DELETE FROM pending_settlement_commands WHERE channel_id=? AND htlc_id=?")) { statement => statement.setBytes(1, channelId.toArray) statement.setLong(2, htlcId) statement.executeUpdate() } } - override def listPendingRelay(channelId: ByteVector32): Seq[HtlcSettlementCommand] = withMetrics("pending-relay/list-channel", DbBackends.Sqlite) { - using(sqlite.prepareStatement("SELECT data FROM pending_relay WHERE channel_id=?")) { statement => + override def listSettlementCommands(channelId: ByteVector32): Seq[HtlcSettlementCommand] = withMetrics("pending-relay/list-channel", DbBackends.Sqlite) { + using(sqlite.prepareStatement("SELECT data FROM pending_settlement_commands WHERE channel_id=?")) { statement => statement.setBytes(1, channelId.toArray) val rs = statement.executeQuery() codecSequence(rs, cmdCodec) } } - override def listPendingRelay(): Set[(ByteVector32, Long)] = withMetrics("pending-relay/list", DbBackends.Sqlite) { - using(sqlite.prepareStatement("SELECT channel_id, htlc_id FROM pending_relay")) { statement => + override def listSettlementCommands(): Seq[(ByteVector32, HtlcSettlementCommand)] = withMetrics("pending-relay/list", DbBackends.Sqlite) { + using(sqlite.prepareStatement("SELECT channel_id, data FROM pending_settlement_commands")) { statement => val rs = statement.executeQuery() - var q: Queue[(ByteVector32, Long)] = Queue() + var q: Queue[(ByteVector32, HtlcSettlementCommand)] = Queue() while (rs.next()) { - q = q :+ (rs.getByteVector32("channel_id"), rs.getLong("htlc_id")) + q = q :+ (rs.getByteVector32("channel_id"), cmdCodec.decode(rs.getByteVector("data").bits).require.value) } - q.toSet + q } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala index 26994644d9..4a60a07da9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/receive/MultiPartHandler.scala @@ -79,7 +79,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP case Some(record) => validatePayment(nodeParams, p, record) match { case Some(cmdFail) => Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, Tags.FailureType(cmdFail)).increment() - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, p.add.channelId, cmdFail) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.add.channelId, cmdFail) case None => log.info("received payment for amount={} totalAmount={}", p.add.amountMsat, p.payload.totalAmount) pendingPayments.get(p.add.paymentHash) match { @@ -110,7 +110,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP case _ => Metrics.PaymentFailed.withTag(Tags.Direction, Tags.Directions.Received).withTag(Tags.Failure, "InvoiceNotFound").increment() val cmdFail = CMD_FAIL_HTLC(p.add.id, Right(IncorrectOrUnknownPaymentDetails(p.payload.totalAmount, nodeParams.currentBlockHeight)), commit = true) - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, p.add.channelId, cmdFail) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.add.channelId, cmdFail) } } } @@ -121,7 +121,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP log.warning("payment with paidAmount={} failed ({})", parts.map(_.amount).sum, failure) pendingPayments.get(paymentHash).foreach { case (_, handler: ActorRef) => handler ! PoisonPill } parts.collect { - case p: MultiPartPaymentFSM.HtlcPart => PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, Right(failure), commit = true)) + case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, Right(failure), commit = true)) } pendingPayments = pendingPayments - paymentHash } @@ -141,13 +141,13 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP Logs.withMdc(log)(Logs.mdc(paymentHash_opt = Some(paymentHash))) { failure match { case Some(failure) => p match { - case p: MultiPartPaymentFSM.HtlcPart => PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, Right(failure), commit = true)) + case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FAIL_HTLC(p.htlc.id, Right(failure), commit = true)) } case None => p match { // NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more // expensive code path by fetching the preimage from DB. case p: MultiPartPaymentFSM.HtlcPart => db.getIncomingPayment(paymentHash).foreach(record => { - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true)) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, record.paymentPreimage, commit = true)) val received = PaymentReceived(paymentHash, PaymentReceived.PartialPayment(p.amount, p.htlc.channelId) :: Nil) db.receiveIncomingPayment(paymentHash, p.amount, received.timestamp) ctx.system.eventStream.publish(received) @@ -164,7 +164,7 @@ class MultiPartHandler(nodeParams: NodeParams, register: ActorRef, db: IncomingP }) db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp) parts.collect { - case p: MultiPartPaymentFSM.HtlcPart => PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true)) + case p: MultiPartPaymentFSM.HtlcPart => PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, p.htlc.channelId, CMD_FULFILL_HTLC(p.htlc.id, preimage, commit = true)) } postFulfill(received) ctx.system.eventStream.publish(received) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala index dd8a40a20b..0070522885 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/ChannelRelay.scala @@ -25,7 +25,7 @@ import akka.actor.typed.scaladsl.adapter.TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel._ -import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.relay.Relayer.OutgoingChannel import fr.acinq.eclair.payment.{ChannelPaymentRelayed, IncomingPacket} @@ -158,7 +158,7 @@ class ChannelRelay private(nodeParams: NodeParams, def safeSendAndStop(channelId: ByteVector32, cmd: channel.Command with channel.HtlcSettlementCommand): Behavior[Command] = { // NB: we are not using an adapter here because we are stopping anyway so we won't be there to get the result - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, channelId, cmd) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) Behaviors.stopped } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala index fb0b2c75c6..e562309ad8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelay.scala @@ -23,7 +23,7 @@ import akka.actor.typed.scaladsl.adapter.{TypedActorContextOps, TypedActorRefOps import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC} -import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.payment.IncomingPacket.NodeRelayPacket import fr.acinq.eclair.payment.Monitoring.{Metrics, Tags} import fr.acinq.eclair.payment.OutgoingPacket.Upstream @@ -323,7 +323,7 @@ class NodeRelay private(nodeParams: NodeParams, private def rejectHtlc(htlcId: Long, channelId: ByteVector32, amount: MilliSatoshi, failure: Option[FailureMessage] = None): Unit = { val failureMessage = failure.getOrElse(IncorrectOrUnknownPaymentDetails(amount, nodeParams.currentBlockHeight)) val cmd = CMD_FAIL_HTLC(htlcId, Right(failureMessage), commit = true) - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, channelId, cmd) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, cmd) } private def rejectPayment(upstream: Upstream.Trampoline, failure: Option[FailureMessage]): Unit = { @@ -333,7 +333,7 @@ class NodeRelay private(nodeParams: NodeParams, private def fulfillPayment(upstream: Upstream.Trampoline, paymentPreimage: ByteVector32): Unit = upstream.adds.foreach(add => { val cmd = CMD_FULFILL_HTLC(add.id, paymentPreimage, commit = true) - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, add.channelId, cmd) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmd) }) private def success(upstream: Upstream.Trampoline, fulfilledUpstream: Boolean, paymentSent: PaymentSent): Unit = { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala index 05a8425c58..1a6ebbb519 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/NodeRelayer.scala @@ -20,7 +20,7 @@ import akka.actor.typed.scaladsl.Behaviors import akka.actor.typed.{ActorRef, Behavior} import fr.acinq.bitcoin.ByteVector32 import fr.acinq.eclair.channel.CMD_FAIL_HTLC -import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.payment._ import fr.acinq.eclair.wire.protocol.IncorrectOrUnknownPaymentDetails import fr.acinq.eclair.{Logs, NodeParams} @@ -85,7 +85,7 @@ object NodeRelayer { context.log.warn("rejecting htlc #{} from channel {}: missing payment secret", htlcIn.id, htlcIn.channelId) val failureMessage = IncorrectOrUnknownPaymentDetails(htlcIn.amountMsat, nodeParams.currentBlockHeight) val cmd = CMD_FAIL_HTLC(htlcIn.id, Right(failureMessage), commit = true) - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, htlcIn.channelId, cmd) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, htlcIn.channelId, cmd) Behaviors.same } case RelayComplete(childHandler, paymentHash, paymentSecret) => diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala index a6469129d4..ad10121868 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/PostRestartHtlcCleaner.scala @@ -71,7 +71,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial val relayedOut: Map[Origin, Set[(ByteVector32, Long)]] = getHtlcsRelayedOut(channels, htlcsIn) ++ nonStandardRelayedOutHtlcs val notRelayed = htlcsIn.filterNot(htlcIn => relayedOut.keys.exists(origin => matchesOrigin(htlcIn.add, origin))) - cleanupRelayDb(htlcsIn, nodeParams.db.pendingRelay) + cleanupRelayDb(htlcsIn, nodeParams.db.pendingCommands) log.info(s"htlcsIn=${htlcsIn.length} notRelayed=${notRelayed.length} relayedOut=${relayedOut.values.flatten.size}") log.info("notRelayed={}", notRelayed.map(htlc => (htlc.add.channelId, htlc.add.id))) @@ -170,7 +170,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial if (relayedOut != Set((fulfilledHtlc.channelId, fulfilledHtlc.id))) { log.error(s"unexpected channel relay downstream HTLCs: expected (${fulfilledHtlc.channelId},${fulfilledHtlc.id}), found $relayedOut") } - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, originChannelId, CMD_FULFILL_HTLC(originHtlcId, paymentPreimage, commit = true)) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, originChannelId, CMD_FULFILL_HTLC(originHtlcId, paymentPreimage, commit = true)) context.system.eventStream.publish(ChannelPaymentRelayed(amountIn, amountOut, fulfilledHtlc.paymentHash, originChannelId, fulfilledHtlc.channelId)) Metrics.PendingRelayedOut.decrement() context become main(brokenHtlcs.copy(relayedOut = brokenHtlcs.relayedOut - origin)) @@ -181,7 +181,7 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial log.info(s"received preimage for paymentHash=${fulfilledHtlc.paymentHash}: fulfilling ${origins.length} HTLCs upstream") origins.foreach { case (channelId, htlcId) => Metrics.Resolved.withTag(Tags.Success, value = true).withTag(Metrics.Relayed, value = true).increment() - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, channelId, CMD_FULFILL_HTLC(htlcId, paymentPreimage, commit = true)) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, CMD_FULFILL_HTLC(htlcId, paymentPreimage, commit = true)) } } val relayedOut1 = relayedOut diff Set((fulfilledHtlc.channelId, fulfilledHtlc.id)) @@ -225,14 +225,14 @@ class PostRestartHtlcCleaner(nodeParams: NodeParams, register: ActorRef, initial log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing 1 HTLC upstream") Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = true).increment() val cmd = ChannelRelay.translateRelayFailure(originHtlcId, fail) - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, originChannelId, cmd) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, originChannelId, cmd) case Origin.TrampolineRelayedCold(origins) => log.warning(s"payment failed for paymentHash=${failedHtlc.paymentHash}: failing ${origins.length} HTLCs upstream") origins.foreach { case (channelId, htlcId) => Metrics.Resolved.withTag(Tags.Success, value = false).withTag(Metrics.Relayed, value = true).increment() // We don't bother decrypting the downstream failure to forward a more meaningful error upstream, it's // very likely that it won't be actionable anyway because of our node restart. - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, channelId, CMD_FAIL_HTLC(htlcId, Right(TemporaryNodeFailure), commit = true)) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, channelId, CMD_FAIL_HTLC(htlcId, Right(TemporaryNodeFailure), commit = true)) } } } @@ -390,7 +390,7 @@ object PostRestartHtlcCleaner { /** * We store [[CMD_FULFILL_HTLC]]/[[CMD_FAIL_HTLC]]/[[CMD_FAIL_MALFORMED_HTLC]] in a database - * (see [[fr.acinq.eclair.db.PendingRelayDb]]) because we don't want to lose preimages, or to forget to fail + * (see [[fr.acinq.eclair.db.PendingCommandsDb]]) because we don't want to lose preimages, or to forget to fail * incoming htlcs, which would lead to unwanted channel closings. * * Because of the way our watcher works, in a scenario where a downstream channel has gone to the blockchain, it may @@ -398,17 +398,17 @@ object PostRestartHtlcCleaner { * * That's why we need to periodically clean up the pending relay db. */ - private def cleanupRelayDb(htlcsIn: Seq[IncomingHtlc], relayDb: PendingRelayDb)(implicit log: LoggingAdapter): Unit = { + private def cleanupRelayDb(htlcsIn: Seq[IncomingHtlc], relayDb: PendingCommandsDb)(implicit log: LoggingAdapter): Unit = { // We are interested in incoming HTLCs, that have been *cross-signed* (otherwise they wouldn't have been relayed). // If the HTLC is not in their commitment, it means that we have already fulfilled/failed it and that we can remove // the command from the pending relay db. val channel2Htlc: Seq[(ByteVector32, Long)] = htlcsIn.map { case IncomingHtlc(add, _) => (add.channelId, add.id) } - val pendingRelay: Set[(ByteVector32, Long)] = relayDb.listPendingRelay() + val pendingRelay: Set[(ByteVector32, Long)] = relayDb.listSettlementCommands().map { case (channelId, cmd) => (channelId, cmd.id) }.toSet val toClean = pendingRelay -- channel2Htlc toClean.foreach { case (channelId, htlcId) => log.info(s"cleaning up channelId=$channelId htlcId=$htlcId from relay db") - relayDb.removePendingRelay(channelId, htlcId) + relayDb.removeSettlementCommand(channelId, htlcId) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala index c834fdae83..9deb4378dc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/payment/relay/Relayer.scala @@ -25,7 +25,7 @@ import akka.event.Logging.MDC import akka.event.LoggingAdapter import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.eclair.channel._ -import fr.acinq.eclair.db.PendingRelayDb +import fr.acinq.eclair.db.PendingCommandsDb import fr.acinq.eclair.payment._ import fr.acinq.eclair.wire.protocol._ import fr.acinq.eclair.{Logs, MilliSatoshi, NodeParams, ShortChannelId} @@ -69,7 +69,7 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym case Right(r: IncomingPacket.NodeRelayPacket) => if (!nodeParams.enableTrampolinePayment) { log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} to nodeId=${r.innerPayload.outgoingNodeId} reason=trampoline disabled") - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, add.channelId, CMD_FAIL_HTLC(add.id, Right(RequiredNodeFeatureMissing), commit = true)) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, CMD_FAIL_HTLC(add.id, Right(RequiredNodeFeatureMissing), commit = true)) } else { nodeRelayer ! NodeRelayer.Relay(r) } @@ -77,11 +77,11 @@ class Relayer(nodeParams: NodeParams, router: ActorRef, register: ActorRef, paym log.warning(s"couldn't parse onion: reason=${badOnion.message}") val cmdFail = CMD_FAIL_MALFORMED_HTLC(add.id, badOnion.onionHash, badOnion.code, commit = true) log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} reason=malformed onionHash=${cmdFail.onionHash} failureCode=${cmdFail.failureCode}") - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, add.channelId, cmdFail) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmdFail) case Left(failure) => log.warning(s"rejecting htlc #${add.id} from channelId=${add.channelId} reason=$failure") val cmdFail = CMD_FAIL_HTLC(add.id, Right(failure), commit = true) - PendingRelayDb.safeSend(register, nodeParams.db.pendingRelay, add.channelId, cmdFail) + PendingCommandsDb.safeSend(register, nodeParams.db.pendingCommands, add.channelId, cmdFail) } case r: RES_ADD_SETTLED[_, _] => r.origin match { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala index 34279ebea6..648a4eeca0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestDatabases.scala @@ -28,7 +28,7 @@ sealed trait TestDatabases extends Databases { override def channels: ChannelsDb = db.channels override def peers: PeersDb = db.peers override def payments: PaymentsDb = db.payments - override def pendingRelay: PendingRelayDb = db.pendingRelay + override def pendingCommands: PendingCommandsDb = db.pendingCommands def close(): Unit // @formatter:on } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala index e873813e30..5cbc7078c9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/NormalStateSpec.scala @@ -1225,12 +1225,12 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val initialState = bob.stateData.asInstanceOf[DATA_NORMAL] val c = CMD_FULFILL_HTLC(htlc.id, r, replyTo_opt = Some(sender.ref)) // this would be done automatically when the relayer calls safeSend - bob.underlyingActor.nodeParams.db.pendingRelay.addPendingRelay(initialState.channelId, c) + bob.underlyingActor.nodeParams.db.pendingCommands.addSettlementCommand(initialState.channelId, c) bob ! c bob2alice.expectMsgType[UpdateFulfillHtlc] bob ! CMD_SIGN(replyTo_opt = Some(sender.ref)) bob2alice.expectMsgType[CommitSig] - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } test("recv CMD_FULFILL_HTLC (acknowledge in case of failure)") { f => @@ -1241,7 +1241,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val c = CMD_FULFILL_HTLC(42, randomBytes32(), replyTo_opt = Some(sender.ref)) sender.send(bob, c) // this will fail sender.expectMsg(RES_FAILURE(c, UnknownHtlcId(channelId(bob), 42))) - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } private def testUpdateFulfillHtlc(f: FixtureParam): Unit = { @@ -1372,7 +1372,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val c = CMD_FAIL_HTLC(42, Right(PermanentChannelFailure), replyTo_opt = Some(sender.ref)) sender.send(bob, c) // this will fail sender.expectMsg(RES_FAILURE(c, UnknownHtlcId(channelId(bob), 42))) - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } test("recv CMD_FAIL_MALFORMED_HTLC") { f => @@ -1418,7 +1418,7 @@ class NormalStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with val c = CMD_FAIL_MALFORMED_HTLC(42, ByteVector32.Zeroes, FailureMessageCodecs.BADONION, replyTo_opt = Some(sender.ref)) sender.send(bob, c) // this will fail sender.expectMsg(RES_FAILURE(c, UnknownHtlcId(channelId(bob), 42))) - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } private def testUpdateFailHtlc(f: FixtureParam): Unit = { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala index e184b3f31f..64fbdefa94 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/e/OfflineStateSpec.scala @@ -409,7 +409,7 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with disconnect(alice, bob) // We simulate a pending fulfill - bob.underlyingActor.nodeParams.db.pendingRelay.addPendingRelay(initialState.channelId, CMD_FULFILL_HTLC(htlc.id, r, commit = true)) + bob.underlyingActor.nodeParams.db.pendingCommands.addSettlementCommand(initialState.channelId, CMD_FULFILL_HTLC(htlc.id, r, commit = true)) // then we reconnect them reconnect(alice, bob, alice2bob, bob2alice) @@ -440,7 +440,7 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with disconnect(alice, bob) // We simulate a pending fulfill - bob.underlyingActor.nodeParams.db.pendingRelay.addPendingRelay(initialState.channelId, CMD_FULFILL_HTLC(htlc.id, r, commit = true)) + bob.underlyingActor.nodeParams.db.pendingCommands.addSettlementCommand(initialState.channelId, CMD_FULFILL_HTLC(htlc.id, r, commit = true)) // then we reconnect them reconnect(alice, bob, alice2bob, bob2alice) @@ -477,7 +477,7 @@ class OfflineStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with // We simulate a pending fulfill on that HTLC but not relayed. // When it is close to expiring upstream, we should close the channel. - bob.underlyingActor.nodeParams.db.pendingRelay.addPendingRelay(initialState.channelId, CMD_FULFILL_HTLC(htlc.id, r, commit = true)) + bob.underlyingActor.nodeParams.db.pendingCommands.addSettlementCommand(initialState.channelId, CMD_FULFILL_HTLC(htlc.id, r, commit = true)) bob ! CurrentBlockCount((htlc.cltvExpiry - bob.underlyingActor.nodeParams.fulfillSafetyBeforeTimeout).toLong) val ChannelErrorOccurred(_, _, _, _, LocalError(err), isFatal) = listener.expectMsgType[ChannelErrorOccurred] diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala index 6fcc6a0988..94890c9965 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/f/ShutdownStateSpec.scala @@ -151,12 +151,12 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit val initialState = bob.stateData.asInstanceOf[DATA_SHUTDOWN] val c = CMD_FULFILL_HTLC(0, r1, replyTo_opt = Some(sender.ref)) // this would be done automatically when the relayer calls safeSend - bob.underlyingActor.nodeParams.db.pendingRelay.addPendingRelay(initialState.channelId, c) + bob.underlyingActor.nodeParams.db.pendingCommands.addSettlementCommand(initialState.channelId, c) bob ! c bob2alice.expectMsgType[UpdateFulfillHtlc] bob ! CMD_SIGN(replyTo_opt = Some(sender.ref)) bob2alice.expectMsgType[CommitSig] - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } test("recv CMD_FULFILL_HTLC (acknowledge in case of failure)") { f => @@ -167,7 +167,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit val c = CMD_FULFILL_HTLC(42, randomBytes32(), replyTo_opt = Some(sender.ref)) sender.send(bob, c) // this will fail sender.expectMsg(RES_FAILURE(c, UnknownHtlcId(channelId(bob), 42))) - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } test("recv UpdateFulfillHtlc") { f => @@ -232,7 +232,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit val c = CMD_FAIL_HTLC(42, Right(PermanentChannelFailure), replyTo_opt = Some(sender.ref)) sender.send(bob, c) // this will fail sender.expectMsg(RES_FAILURE(c, UnknownHtlcId(channelId(bob), 42))) - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } test("recv CMD_FAIL_MALFORMED_HTLC") { f => @@ -272,7 +272,7 @@ class ShutdownStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wit val c = CMD_FAIL_MALFORMED_HTLC(42, randomBytes32(), FailureMessageCodecs.BADONION, replyTo_opt = Some(sender.ref)) sender.send(bob, c) // this will fail sender.expectMsg(RES_FAILURE(c, UnknownHtlcId(channelId(bob), 42))) - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(initialState.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(initialState.channelId).isEmpty) } test("recv UpdateFailHtlc") { f => diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/h/ClosingStateSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/h/ClosingStateSpec.scala index f2efccab33..9733ec58fa 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/h/ClosingStateSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/channel/states/h/ClosingStateSpec.scala @@ -102,7 +102,7 @@ class ClosingStateSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike with relayerA.expectMsgType[RES_ADD_SETTLED[Origin, HtlcResult.Fulfill]] crossSign(bob, alice, bob2alice, alice2bob) // bob confirms that it has forwarded the fulfill to alice - awaitCond(bob.underlyingActor.nodeParams.db.pendingRelay.listPendingRelay(htlc.channelId).isEmpty) + awaitCond(bob.underlyingActor.nodeParams.db.pendingCommands.listSettlementCommands(htlc.channelId).isEmpty) val bobCommitTx2 = bob.stateData.asInstanceOf[DATA_NORMAL].commitments.localCommit.publishableTxs bobCommitTx1 :: bobCommitTx2 :: Nil }).flatten diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala index cad456ee03..468836f3c4 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/ChannelsDbSpec.scala @@ -56,7 +56,7 @@ class ChannelsDbSpec extends AnyFunSuite { test("add/remove/list channels") { forAllDbs { dbs => val db = dbs.channels - dbs.pendingRelay // needed by db.removeChannel + dbs.pendingCommands // needed by db.removeChannel val channel1 = ChannelCodecsSpec.normal val channel2a = ChannelCodecsSpec.normal.modify(_.commitments.channelId).setTo(randomBytes32()) diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingCommandsDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingCommandsDbSpec.scala new file mode 100644 index 0000000000..1efa35b7de --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingCommandsDbSpec.scala @@ -0,0 +1,145 @@ +/* + * Copyright 2019 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.db + +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases} +import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC, HtlcSettlementCommand} +import fr.acinq.eclair.db.pg.PgPendingCommandsDb +import fr.acinq.eclair.db.sqlite.SqlitePendingCommandsDb +import fr.acinq.eclair.db.sqlite.SqliteUtils.{setVersion, using} +import fr.acinq.eclair.randomBytes32 +import fr.acinq.eclair.wire.internal.CommandCodecs.cmdCodec +import fr.acinq.eclair.wire.protocol.{FailureMessageCodecs, UnknownNextPeer} +import org.scalatest.funsuite.AnyFunSuite + +import scala.util.Random + + +class PendingCommandsDbSpec extends AnyFunSuite { + + import PendingCommandsDbSpec._ + import fr.acinq.eclair.TestDatabases.{forAllDbs, migrationCheck} + + + test("init database two times in a row") { + forAllDbs { + case sqlite: TestSqliteDatabases => + new SqlitePendingCommandsDb(sqlite.connection) + new SqlitePendingCommandsDb(sqlite.connection) + case pg: TestPgDatabases => + new PgPendingCommandsDb()(pg.datasource, pg.lock) + new PgPendingCommandsDb()(pg.datasource, pg.lock) + } + } + + test("add/remove/list messages") { + forAllDbs { dbs => + val db = dbs.pendingCommands + + val channelId1 = randomBytes32() + val channelId2 = randomBytes32() + val msg0 = CMD_FULFILL_HTLC(0, randomBytes32()) + val msg1 = CMD_FULFILL_HTLC(1, randomBytes32()) + val msg2 = CMD_FAIL_HTLC(2, Left(randomBytes32())) + val msg3 = CMD_FAIL_HTLC(3, Left(randomBytes32())) + val msg4 = CMD_FAIL_MALFORMED_HTLC(4, randomBytes32(), FailureMessageCodecs.BADONION) + + assert(db.listSettlementCommands(channelId1).toSet === Set.empty) + db.addSettlementCommand(channelId1, msg0) + db.addSettlementCommand(channelId1, msg0) // duplicate + db.addSettlementCommand(channelId1, msg1) + db.addSettlementCommand(channelId1, msg2) + db.addSettlementCommand(channelId1, msg3) + db.addSettlementCommand(channelId1, msg4) + db.addSettlementCommand(channelId2, msg0) // same messages but for different channel + db.addSettlementCommand(channelId2, msg1) + assert(db.listSettlementCommands(channelId1).toSet === Set(msg0, msg1, msg2, msg3, msg4)) + assert(db.listSettlementCommands(channelId2).toSet === Set(msg0, msg1)) + assert(db.listSettlementCommands().toSet === Set((channelId1, msg0), (channelId1, msg1), (channelId1, msg2), (channelId1, msg3), (channelId1, msg4), (channelId2, msg0), (channelId2, msg1))) + db.removeSettlementCommand(channelId1, msg1.id) + assert(db.listSettlementCommands().toSet === Set((channelId1, msg0), (channelId1, msg2), (channelId1, msg3), (channelId1, msg4), (channelId2, msg0), (channelId2, msg1))) + } + } + + test("migrate database v1->v2") { + forAllDbs { + case dbs: TestPgDatabases => + migrationCheck( + dbs = dbs, + initializeTables = connection => { + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE pending_relay (channel_id TEXT NOT NULL, htlc_id BIGINT NOT NULL, data BYTEA NOT NULL, PRIMARY KEY(channel_id, htlc_id))") + setVersion(statement, "pending_relay", 1) + } + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT INTO pending_relay VALUES (?, ?, ?) ON CONFLICT DO NOTHING")) { statement => + statement.setString(1, testCase.channelId.toHex) + statement.setLong(2, testCase.cmd.id) + statement.setBytes(3, cmdCodec.encode(testCase.cmd).require.toByteArray) + statement.executeUpdate() + } + } + }, + dbName = "pending_relay", + targetVersion = 2, + postCheck = _ => + assert(dbs.pendingCommands.listSettlementCommands().toSet === testCases.map(tc => tc.channelId -> tc.cmd)) + ) + case dbs: TestSqliteDatabases => + migrationCheck( + dbs = dbs, + initializeTables = connection => { + using(connection.createStatement()) { statement => + statement.executeUpdate("CREATE TABLE pending_relay (channel_id BLOB NOT NULL, htlc_id INTEGER NOT NULL, data BLOB NOT NULL, PRIMARY KEY(channel_id, htlc_id))") + setVersion(statement, "pending_relay", 1) + } + testCases.foreach { testCase => + using(connection.prepareStatement("INSERT OR IGNORE INTO pending_relay VALUES (?, ?, ?)")) { statement => + statement.setBytes(1, testCase.channelId.toArray) + statement.setLong(2, testCase.cmd.id) + statement.setBytes(3, cmdCodec.encode(testCase.cmd).require.toByteArray) + statement.executeUpdate() + } + } + }, + dbName = "pending_relay", + targetVersion = 2, + postCheck = _ => + assert(dbs.pendingCommands.listSettlementCommands().toSet === testCases.map(tc => tc.channelId -> tc.cmd)) + ) + } + } + +} + +object PendingCommandsDbSpec { + + case class TestCase(channelId: ByteVector32, + cmd: HtlcSettlementCommand) + + val testCases: Set[TestCase] = (0 until 100).flatMap { _ => + val channelId = randomBytes32() + val cmds = (0 until Random.nextInt(5)).map { _ => + Random.nextInt(2) match { + case 0 => CMD_FULFILL_HTLC(Random.nextLong(100_000), randomBytes32()) + case 1 => CMD_FAIL_HTLC(Random.nextLong(100_000), Right(UnknownNextPeer)) + } + } + cmds.map(cmd => TestCase(channelId, cmd)) + }.toSet +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingRelayDbSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingRelayDbSpec.scala deleted file mode 100644 index 5f7eadffad..0000000000 --- a/eclair-core/src/test/scala/fr/acinq/eclair/db/PendingRelayDbSpec.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright 2019 ACINQ SAS - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package fr.acinq.eclair.db - -import fr.acinq.eclair.TestDatabases.{TestPgDatabases, TestSqliteDatabases} -import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FAIL_MALFORMED_HTLC, CMD_FULFILL_HTLC} -import fr.acinq.eclair.db.pg.PgPendingRelayDb -import fr.acinq.eclair.db.sqlite.SqlitePendingRelayDb -import fr.acinq.eclair.randomBytes32 -import fr.acinq.eclair.wire.protocol.FailureMessageCodecs -import org.scalatest.funsuite.AnyFunSuite - - -class PendingRelayDbSpec extends AnyFunSuite { - - import fr.acinq.eclair.TestDatabases.forAllDbs - - test("init database two times in a row") { - forAllDbs { - case sqlite: TestSqliteDatabases => - new SqlitePendingRelayDb(sqlite.connection) - new SqlitePendingRelayDb(sqlite.connection) - case pg: TestPgDatabases => - new PgPendingRelayDb()(pg.datasource, pg.lock) - new PgPendingRelayDb()(pg.datasource, pg.lock) - } - } - - test("add/remove/list messages") { - forAllDbs { dbs => - val db = dbs.pendingRelay - - val channelId1 = randomBytes32() - val channelId2 = randomBytes32() - val msg0 = CMD_FULFILL_HTLC(0, randomBytes32()) - val msg1 = CMD_FULFILL_HTLC(1, randomBytes32()) - val msg2 = CMD_FAIL_HTLC(2, Left(randomBytes32())) - val msg3 = CMD_FAIL_HTLC(3, Left(randomBytes32())) - val msg4 = CMD_FAIL_MALFORMED_HTLC(4, randomBytes32(), FailureMessageCodecs.BADONION) - - assert(db.listPendingRelay(channelId1).toSet === Set.empty) - db.addPendingRelay(channelId1, msg0) - db.addPendingRelay(channelId1, msg0) // duplicate - db.addPendingRelay(channelId1, msg1) - db.addPendingRelay(channelId1, msg2) - db.addPendingRelay(channelId1, msg3) - db.addPendingRelay(channelId1, msg4) - db.addPendingRelay(channelId2, msg0) // same messages but for different channel - db.addPendingRelay(channelId2, msg1) - assert(db.listPendingRelay(channelId1).toSet === Set(msg0, msg1, msg2, msg3, msg4)) - assert(db.listPendingRelay(channelId2).toSet === Set(msg0, msg1)) - assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg1.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id))) - db.removePendingRelay(channelId1, msg1.id) - assert(db.listPendingRelay === Set((channelId1, msg0.id), (channelId1, msg2.id), (channelId1, msg3.id), (channelId1, msg4.id), (channelId2, msg0.id), (channelId2, msg1.id))) - } - } - -} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala index bf03bd6346..76d0476d0f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/payment/PostRestartHtlcCleanerSpec.scala @@ -645,11 +645,11 @@ class PostRestartHtlcCleanerSpec extends TestKitBaseClass with FixtureAnyFunSuit val cmd1 = CMD_FAIL_HTLC(id = 0L, reason = Left(ByteVector.empty), replyTo_opt = None) val cmd2 = CMD_FAIL_HTLC(id = 1L, reason = Left(ByteVector.empty), replyTo_opt = None) val nodeParams1 = nodeParams.copy(pluginParams = List(pluginParams)) - nodeParams1.db.pendingRelay.addPendingRelay(channelId_ab_1, cmd1) - nodeParams1.db.pendingRelay.addPendingRelay(channelId_ab_1, cmd2) + nodeParams1.db.pendingCommands.addSettlementCommand(channelId_ab_1, cmd1) + nodeParams1.db.pendingCommands.addSettlementCommand(channelId_ab_1, cmd2) f.createRelayer(nodeParams1) register.expectNoMsg(100 millis) - awaitCond(Seq(cmd1) == nodeParams1.db.pendingRelay.listPendingRelay(channelId_ab_1)) + awaitCond(Seq(cmd1) == nodeParams1.db.pendingCommands.listSettlementCommands(channelId_ab_1)) } }