Skip to content

Commit

Permalink
Use length-delimited byte-aligned codecs (#1442)
Browse files Browse the repository at this point in the history
Legacy codecs are isolated in a separate file, with a visibility restricted to "package" in order to reduce the risk of using those codecs. Also codecs are restricted to `decodeOnly` for the same reason.
  • Loading branch information
pm47 committed Jun 22, 2020
1 parent d5ec6a5 commit 6c81f95
Show file tree
Hide file tree
Showing 7 changed files with 502 additions and 207 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,10 @@ import fr.acinq.bitcoin.{ByteVector32, Satoshi}
import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid, RemoteError}
import fr.acinq.eclair.db._
import fr.acinq.eclair.payment._
import fr.acinq.eclair.wire.ChannelCodecs
import fr.acinq.eclair.{LongToBtcAmount, MilliSatoshi}
import grizzled.slf4j.Logging

import scala.collection.immutable.Queue
import scala.compat.Platform

class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging {

Expand All @@ -44,7 +42,8 @@ class SqliteAuditDb(sqlite: Connection) extends AuditDb with Logging {
using(sqlite.createStatement(), inTransaction = true) { statement =>

def migration12(statement: Statement): Int = {
statement.executeUpdate(s"ALTER TABLE sent ADD id BLOB DEFAULT '${ChannelCodecs.UNKNOWN_UUID.toString}' NOT NULL")
val ZERO_UUID = UUID.fromString("00000000-0000-0000-0000-000000000000")
statement.executeUpdate(s"ALTER TABLE sent ADD id BLOB DEFAULT '${ZERO_UUID.toString}' NOT NULL")
}

def migration23(statement: Statement): Int = {
Expand Down
222 changes: 76 additions & 146 deletions eclair-core/src/main/scala/fr/acinq/eclair/wire/ChannelCodecs.scala

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ object CommonCodecs {

val varsizebinarydata: Codec[ByteVector] = variableSizeBytes(uint16, bytes)

def mapCodec[K, V](keyCodec: Codec[K], valueCodec: Codec[V]): Codec[Map[K, V]] = listOfN(uint16, keyCodec ~ valueCodec).xmap(_.toMap, _.toList)

def setCodec[T](codec: Codec[T]): Codec[Set[T]] = listOfN(uint16, codec).xmap(_.toSet, _.toList)

val listofsignatures: Codec[List[ByteVector64]] = listOfN(uint16, bytes64)

val ipv4address: Codec[Inet4Address] = bytes(4).xmap(b => InetAddress.getByAddress(b.toArray).asInstanceOf[Inet4Address], a => ByteVector(a.getAddress))
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,14 @@
package fr.acinq.eclair.blockchain.electrum.db.sqlite

import fr.acinq.bitcoin.{Block, BlockHeader, OutPoint, Satoshi, Transaction, TxIn, TxOut}
import fr.acinq.eclair.{TestConstants, randomBytes, randomBytes32}
import fr.acinq.eclair.blockchain.electrum.ElectrumClient
import fr.acinq.eclair.blockchain.electrum.ElectrumClient.GetMerkleResponse
import fr.acinq.eclair.blockchain.electrum.ElectrumWallet.PersistentData
import fr.acinq.eclair.blockchain.electrum.db.sqlite.SqliteWalletDb.version
import fr.acinq.eclair.wire.ChannelCodecs.txCodec
import fr.acinq.eclair.{TestConstants, randomBytes, randomBytes32}
import fr.acinq.eclair.wire.CommonCodecs.setCodec
import org.scalatest.funsuite.AnyFunSuite
import scodec.Codec
import scodec.bits.BitVector
import scodec.codecs.{constant, listOfN, provide, uint16}

import scala.util.Random

Expand Down Expand Up @@ -105,9 +103,9 @@ class SqliteWalletDbSpec extends AnyFunSuite {
}

test("read old persistent data") {
import scodec.codecs._
import SqliteWalletDb._
import fr.acinq.eclair.wire.ChannelCodecs._
import scodec.codecs._

val oldPersistentDataCodec: Codec[PersistentData] = (
("version" | constant(BitVector.fromInt(version))) ::
Expand All @@ -119,7 +117,7 @@ class SqliteWalletDbSpec extends AnyFunSuite {
("history" | historyCodec) ::
("proofs" | proofsCodec) ::
("pendingTransactions" | listOfN(uint16, txCodec)) ::
("locks" | setCodec(txCodec))).as[PersistentData]
("locks" | setCodec(txCodec))).as[PersistentData]

for (i <- 0 until 50) {
val data = randomPersistentData
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@ import fr.acinq.eclair.channel.{ChannelErrorOccurred, LocalError, NetworkFeePaid
import fr.acinq.eclair.db.sqlite.SqliteAuditDb
import fr.acinq.eclair.db.sqlite.SqliteUtils.{getVersion, using}
import fr.acinq.eclair.payment._
import fr.acinq.eclair.wire.ChannelCodecs
import org.scalatest.Tag
import org.scalatest.funsuite.AnyFunSuite

import scala.compat.Platform
import scala.concurrent.duration._
import scala.util.Random

class SqliteAuditDbSpec extends AnyFunSuite {

val ZERO_UUID = UUID.fromString("00000000-0000-0000-0000-000000000000")

test("init sqlite 2 times in a row") {
val sqlite = TestConstants.sqliteInMemory()
val db1 = new SqliteAuditDb(sqlite)
Expand All @@ -45,7 +45,7 @@ class SqliteAuditDbSpec extends AnyFunSuite {
val sqlite = TestConstants.sqliteInMemory()
val db = new SqliteAuditDb(sqlite)

val e1 = PaymentSent(ChannelCodecs.UNKNOWN_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ChannelCodecs.UNKNOWN_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val e1 = PaymentSent(ZERO_UUID, randomBytes32, randomBytes32, 40000 msat, randomKey.publicKey, PaymentSent.PartialPayment(ZERO_UUID, 42000 msat, 1000 msat, randomBytes32, None) :: Nil)
val pp2a = PaymentReceived.PartialPayment(42000 msat, randomBytes32)
val pp2b = PaymentReceived.PartialPayment(42100 msat, randomBytes32)
val e2 = PaymentReceived(randomBytes32, pp2a :: pp2b :: Nil)
Expand Down Expand Up @@ -203,7 +203,7 @@ class SqliteAuditDbSpec extends AnyFunSuite {
}

// existing rows in the 'sent' table will use id=00000000-0000-0000-0000-000000000000 as default
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, parts = Seq(ps.parts.head.copy(id = ChannelCodecs.UNKNOWN_UUID)))))
assert(migratedDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID)))))

val postMigrationDb = new SqliteAuditDb(connection)

Expand All @@ -216,7 +216,7 @@ class SqliteAuditDbSpec extends AnyFunSuite {
postMigrationDb.add(e2)

// the old record will have the UNKNOWN_UUID but the new ones will have their actual id
val expected = Seq(ps.copy(id = ChannelCodecs.UNKNOWN_UUID, parts = Seq(ps.parts.head.copy(id = ChannelCodecs.UNKNOWN_UUID))), ps1)
val expected = Seq(ps.copy(id = ZERO_UUID, parts = Seq(ps.parts.head.copy(id = ZERO_UUID))), ps1)
assert(postMigrationDb.listSent(0, (System.currentTimeMillis.milliseconds + 1.minute).toMillis) === expected)
}

Expand Down
Loading

0 comments on commit 6c81f95

Please sign in to comment.