Skip to content

Commit

Permalink
refactor(rust): move the logic of how to sort and apply migrations in…
Browse files Browse the repository at this point in the history
…to the `NextMigration` enum
  • Loading branch information
adrianbenavides committed Mar 4, 2024
1 parent 605f2c7 commit efbae4a
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 131 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use ockam_core::compat::collections::{HashMap, HashSet};
use ockam_core::compat::collections::HashSet;
use ockam_core::compat::time::now;
use ockam_core::errcode::{Kind, Origin};
use sqlx::migrate::{Migrate, Migration};
use sqlx::migrate::{AppliedMigration, Migrate, Migration as SqlxMigration};
use sqlx::sqlite::SqliteRow;
use sqlx::{query, Row, SqliteConnection, SqlitePool};
use std::cmp::Ordering;
Expand Down Expand Up @@ -67,29 +67,10 @@ impl Migrator {
}

impl Migrator {
fn is_inside_version_range(
version: i64,
from_version: i64, // including
to_version: i64,
including_to: bool,
) -> bool {
if from_version <= version && version < to_version {
return true;
}

if version == to_version && including_to {
return true;
}

false
}

async fn run_migrations(
&self,
connection: &mut SqliteConnection,
from_version: i64, // including
to_version: i64, // not including
run_last_sql: bool, // will run sql migration with verion == to_version
up_to: Version,
) -> Result<()> {
connection.ensure_migrations_table().await.into_core()?;

Expand All @@ -102,6 +83,26 @@ impl Migrator {
));
}

let migrations = {
let sql_iterator = self.sql_migrator.migrations.iter().filter_map(|m| {
if m.version <= up_to {
Some(NextMigration::Sql(m))
} else {
None
}
});
let rust_iterator = self.rust_migrations.iter().filter_map(|m| {
if m.version() <= up_to {
Some(NextMigration::Rust(m.as_ref()))
} else {
None
}
});
let mut migrations: Vec<NextMigration> = sql_iterator.chain(rust_iterator).collect();
migrations.sort();
migrations
};

// sqlx Migrator also optionally checks for missing migrations (ones that had been run and
// marked as migrated in the db but now don't exist). Skipping that check for now.
// WARNING: the check if re-enabled can potentially fail because of renaming
Expand All @@ -110,92 +111,18 @@ impl Migrator {
// before the _rust_migrations table existed
let applied_migrations = connection.list_applied_migrations().await.into_core()?;

let applied_migrations: HashMap<_, _> = applied_migrations
.into_iter()
.map(|m| (m.version, m))
.collect();

enum NextMigration<'a> {
Sql(&'a Migration),
#[allow(clippy::borrowed_box)]
Rust(&'a Box<dyn RustMigration>),
}

impl NextMigration<'_> {
fn is_sql(&self) -> bool {
match self {
NextMigration::Sql(_) => true,
NextMigration::Rust(_) => false,
}
}
}

let sql_iterator = self.sql_migrator.migrations.iter().filter_map(|m| {
let version = m.version;

if !Self::is_inside_version_range(version, from_version, to_version, run_last_sql) {
return None;
}

Some((version, NextMigration::Sql(m)))
});
let rust_iterator = self.rust_migrations.iter().filter_map(|m| {
let version = m.version();

if !Self::is_inside_version_range(version, from_version, to_version, false) {
return None;
}

Some((version, NextMigration::Rust(m)))
});

let mut all_migrations: Vec<(i64, NextMigration)> =
sql_iterator.chain(rust_iterator).collect();
all_migrations.sort_by(|m1, m2| match m1.0.cmp(&m2.0) {
Ordering::Less => Ordering::Less,
Ordering::Equal => {
// Sql migrations go first
if m1.1.is_sql() {
Ordering::Less
} else {
Ordering::Greater
}
}
Ordering::Greater => Ordering::Greater,
});

for migration in all_migrations.iter().map(|(_version, m)| m) {
for migration in migrations.into_iter() {
match migration {
NextMigration::Sql(sql_migration) => {
if sql_migration.migration_type.is_down_migration() {
return Ok(());
}

match applied_migrations.get(&sql_migration.version) {
Some(applied_migration) => {
if sql_migration.checksum != applied_migration.checksum {
return Err(ockam_core::Error::new(
Origin::Node,
Kind::Conflict,
format!(
"Checksum mismatch for sql migration for version {}",
sql_migration.version
),
));
}
}
None => {
connection.apply(sql_migration).await.into_core()?;
}
}
NextMigration::apply_sql_migration(
sql_migration,
connection,
&applied_migrations,
)
.await?;
}
NextMigration::Rust(rust_migration) => {
if Self::has_migrated(connection, rust_migration.name()).await? {
continue;
}
if rust_migration.migrate(connection).await? {
Self::mark_as_migrated(connection, rust_migration.name()).await?;
}
NextMigration::apply_rust_migration(rust_migration, connection).await?;
}
}
}
Expand Down Expand Up @@ -240,22 +167,14 @@ impl Migrator {
}

impl Migrator {
/// Run migrations
pub async fn migrate_partial(
&self,
pool: &SqlitePool,
from_version: i64, // including
to_version: i64, // not including
run_last_sql: bool, // Will run `to_version` version of the sql migration
) -> Result<()> {
/// Run migrations up to the specified version (inclusive)
pub(crate) async fn migrate_up_to(&self, pool: &SqlitePool, up_to: Version) -> Result<()> {
let mut connection = pool.acquire().await.into_core()?;

// Apparently does nothing for sqlite...
connection.lock().await.into_core()?;

let res = self
.run_migrations(&mut connection, from_version, to_version, run_last_sql)
.await;
let res = self.run_migrations(&mut connection, up_to).await;

connection.unlock().await.into_core()?;

Expand All @@ -266,19 +185,192 @@ impl Migrator {

/// Run all migrations
pub async fn migrate(&self, pool: &SqlitePool) -> Result<()> {
self.migrate_partial(pool, 0, i64::MAX, false).await
self.migrate_up_to(pool, i64::MAX).await
}
}

#[cfg(test)]
impl Migrator {
/// Migrate the schema of the database right before the specified version
pub(crate) async fn migrate_before(
&self,
/// Run migrations up to the specified version (inclusive) but skip the last rust migration
pub(crate) async fn migrate_up_to_skip_last_rust_migration(
mut self,
pool: &SqlitePool,
version: i64, // not including
run_last_sql: bool,
up_to: Version,
) -> Result<()> {
self.rust_migrations.retain(|m| m.version() < up_to);
self.migrate_up_to(pool, up_to).await
}
}

type Version = i64;

#[derive(Debug)]
enum NextMigration<'a> {
Sql(&'a SqlxMigration),
Rust(&'a dyn RustMigration),
}

impl NextMigration<'_> {
fn is_sql(&self) -> bool {
matches!(self, Self::Sql(_))
}

fn version(&self) -> Version {
match self {
Self::Sql(m) => m.version,
Self::Rust(m) => m.version(),
}
}

async fn apply_sql_migration<'a>(
migration: &'a SqlxMigration,
connection: &mut SqliteConnection,
applied_migrations: &[AppliedMigration],
) -> Result<()> {
self.migrate_partial(pool, 0, version, run_last_sql).await
if migration.migration_type.is_down_migration() {
return Ok(());
}
match applied_migrations
.iter()
.find(|m| m.version == migration.version)
{
Some(applied_migration) => {
if migration.checksum != applied_migration.checksum {
return Err(ockam_core::Error::new(
Origin::Node,
Kind::Conflict,
format!(
"Checksum mismatch for sql migration for version {}",
migration.version
),
));
}
}
None => {
connection.apply(migration).await.into_core()?;
}
}
Ok(())
}

async fn apply_rust_migration(
migration: &dyn RustMigration,
connection: &mut SqliteConnection,
) -> Result<()> {
if Migrator::has_migrated(connection, migration.name()).await? {
return Ok(());
}
if migration.migrate(connection).await? {
Migrator::mark_as_migrated(connection, migration.name()).await?;
}
Ok(())
}
}

impl Eq for NextMigration<'_> {}

impl PartialEq<Self> for NextMigration<'_> {
fn eq(&self, other: &Self) -> bool {
let same_kind = matches!(
(self, other),
(Self::Sql(_), Self::Sql(_)) | (Self::Rust(_), Self::Rust(_))
);
same_kind && self.version() == other.version()
}
}

impl Ord for NextMigration<'_> {
fn cmp(&self, other: &Self) -> Ordering {
match self.version().cmp(&other.version()) {
Ordering::Equal => {
// Sql migrations go first
match (self.is_sql(), other.is_sql()) {
(true, true) => Ordering::Equal,
(true, false) => Ordering::Less,
(false, true) => Ordering::Greater,
_ => unreachable!(),
}
}
ord => ord,
}
}
}

impl PartialOrd for NextMigration<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

#[cfg(test)]
mod tests {
use super::*;
use ockam_core::async_trait;
use sqlx::migrate::MigrationType;

#[test]
fn ordering_of_migrations() {
let sql_1 = SqlxMigration::new(1, "sql_1".into(), MigrationType::Simple, "1".into());
let sql_2 = SqlxMigration::new(2, "sql_2".into(), MigrationType::Simple, "2".into());
let rust_1: Box<dyn RustMigration> = Box::new(DummyRustMigration::new(1));
let rust_2: Box<dyn RustMigration> = Box::new(DummyRustMigration::new(2));
let rust_3: Box<dyn RustMigration> = Box::new(DummyRustMigration::new(3));

let mut migrations = vec![
NextMigration::Sql(&sql_2),
NextMigration::Sql(&sql_1),
NextMigration::Rust(rust_1.as_ref()),
NextMigration::Rust(rust_3.as_ref()),
NextMigration::Rust(rust_2.as_ref()),
];
migrations.sort();

for m in &migrations {
match m {
NextMigration::Sql(_) => {
assert!(m.is_sql());
}
NextMigration::Rust(_) => {
assert!(!m.is_sql());
}
}
}

assert_eq!(
migrations,
vec![
NextMigration::Sql(&sql_1),
NextMigration::Rust(rust_1.as_ref()),
NextMigration::Sql(&sql_2),
NextMigration::Rust(rust_2.as_ref()),
NextMigration::Rust(rust_3.as_ref())
]
);
}

#[derive(Debug)]
struct DummyRustMigration {
version: Version,
}

impl DummyRustMigration {
fn new(version: Version) -> Self {
Self { version }
}
}

#[async_trait]
impl RustMigration for DummyRustMigration {
fn name(&self) -> &str {
"DummyRustMigration"
}

fn version(&self) -> Version {
self.version
}

async fn migrate(&self, _connection: &mut SqliteConnection) -> Result<bool> {
Ok(true)
}
}
}
Loading

0 comments on commit efbae4a

Please sign in to comment.