diff --git a/Cargo.lock b/Cargo.lock
index 150d540e9..e26b0614f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -512,6 +512,7 @@ dependencies = [
  "tracing",
  "tracing-subscriber",
  "typed-builder",
+ "typetag",
 ]

 [[package]]
@@ -2317,6 +2318,15 @@ version = "1.0.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"

+[[package]]
+name = "erased-serde"
+version = "0.4.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b73807008a3c7f171cc40312f37d95ef0396e048b5848d775f54b1a4dd4a0d3"
+dependencies = [
+ "serde",
+]
+
 [[package]]
 name = "errno"
 version = "0.3.8"
@@ -3467,6 +3477,12 @@ dependencies = [
  "cfg-if",
 ]

+[[package]]
+name = "inventory"
+version = "0.3.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f958d3d68f4167080a18141e10381e7634563984a537f2a49a30fd8e53ac5767"
+
 [[package]]
 name = "io-extras"
 version = "0.18.2"
@@ -8455,6 +8471,30 @@ version = "1.17.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"

+[[package]]
+name = "typetag"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "661d18414ec032a49ece2d56eee03636e43c4e8d577047ab334c0ba892e29aaf"
+dependencies = [
+ "erased-serde",
+ "inventory",
+ "once_cell",
+ "serde",
+ "typetag-impl",
+]
+
+[[package]]
+name = "typetag-impl"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac73887f47b9312552aa90ef477927ff014d63d1920ca8037c6c1951eab64bb1"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.60",
+]
+
 [[package]]
 name = "ucd-trie"
 version = "0.1.6"
diff --git a/Cargo.toml b/Cargo.toml
index 9f2cfbc43..3854a05a6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,4 @@
 [profile.dev.package]
-backtrace = { opt-level = 3 }
-num-bigint-dig = { opt-level = 3 }
 taplo = { debug-assertions = false } # A debug assertion will make the xtask panic with too long trailing comments

 # The profile that 'cargo dist' will build with
diff --git a/kitsune-job-runner/src/lib.rs b/kitsune-job-runner/src/lib.rs
index 4a25544c8..f48d37556 100644
--- a/kitsune-job-runner/src/lib.rs
+++ b/kitsune-job-runner/src/lib.rs
@@ -88,7 +88,6 @@ pub async fn run_dispatcher(
         },
     });

-    let job_queue = Arc::new(job_queue);
     let job_tracker = TaskTracker::new();
     job_tracker.close();

@@ -99,13 +98,13 @@
             let job_tracker = job_tracker.clone();

             async move {
-                job_queue
-                    .spawn_jobs(
-                        num_job_workers - job_tracker.len(),
-                        Arc::clone(&ctx),
-                        &job_tracker,
-                    )
-                    .await
+                athena::spawn_jobs(
+                    &job_queue,
+                    num_job_workers - job_tracker.len(),
+                    Arc::clone(&ctx),
+                    &job_tracker,
+                )
+                .await
             }
         })
         .retry(just_retry::backoff_policy())
diff --git a/lib/athena/Cargo.toml b/lib/athena/Cargo.toml
index b1701387b..b0668e97e 100644
--- a/lib/athena/Cargo.toml
+++ b/lib/athena/Cargo.toml
@@ -10,12 +10,12 @@ name = "basic_queue"
 required-features = ["redis"]

 [dependencies]
-ahash = { version = "0.8.11", optional = true }
+ahash = "0.8.11"
 async-trait = "0.1.80"
 either = { version = "1.11.0", default-features = false, optional = true }
 futures-util = { version = "0.3.30", default-features = false }
 iso8601-timestamp = "0.2.17"
-just-retry = { path = "../just-retry", optional = true }
+just-retry = { path = "../just-retry" }
 multiplex-pool = { path = "../multiplex-pool", optional = true }
 once_cell = { version = "1.19.0", optional = true }
 rand = { version = "0.8.5", optional = true }
@@ -26,7 +26,7 @@ redis = { version = "0.25.3", default-features = false, features = [
     "streams",
     "tokio-comp",
 ], optional = true }
-serde = { version = "1.0.199", features = ["derive"], optional = true }
+serde = { version = "1.0.199", features = ["derive"] }
 simd-json = { version = "0.13.10", optional = true }
 smol_str = "0.2.1"
 speedy-uuid = { path = "../speedy-uuid", features = ["redis", "serde"] }
@@ -35,17 +35,15 @@ tokio = { version = "1.37.0", features = ["macros", "rt", "sync"] }
 tokio-util = { version = "0.7.10", features = ["rt"] }
 tracing = "0.1.40"
 typed-builder = "0.18.2"
+typetag = "0.2.16"

 [features]
 redis = [
-    "dep:ahash",
     "dep:either",
-    "dep:just-retry",
     "dep:multiplex-pool",
     "dep:once_cell",
     "dep:rand",
     "dep:redis",
-    "dep:serde",
     "dep:simd-json",
 ]
diff --git a/lib/athena/examples/basic_queue.rs b/lib/athena/examples/basic_queue.rs
index f060f6aaf..4b1b32bf0 100644
--- a/lib/athena/examples/basic_queue.rs
+++ b/lib/athena/examples/basic_queue.rs
@@ -102,7 +102,7 @@ async fn main() {
     loop {
         if tokio::time::timeout(
             Duration::from_secs(5),
-            queue.spawn_jobs(20, Arc::new(()), &jobs),
+            athena::spawn_jobs(&queue, 20, Arc::new(()), &jobs),
         )
         .await
         .is_err()
diff --git a/lib/athena/src/common.rs b/lib/athena/src/common.rs
index da30e272f..e47949e01 100644
--- a/lib/athena/src/common.rs
+++ b/lib/athena/src/common.rs
@@ -1,19 +1,18 @@
 use crate::{
+    consts::MIN_IDLE_TIME,
     error::{Error, Result},
-    JobContextRepository, JobQueue, Runnable,
+    JobContextRepository, JobData, JobQueue, JobResult, Outcome, Runnable,
 };
 use ahash::AHashMap;
+use futures_util::TryStreamExt;
+use just_retry::RetryExt;
 use speedy_uuid::Uuid;
 use std::{sync::Arc, time::Duration};
 use tokio::time::Instant;
 use tokio_util::task::TaskTracker;

-const BLOCK_TIME: Duration = Duration::from_secs(2);
-const MAX_RETRIES: u32 = 10;
-const MIN_IDLE_TIME: Duration = Duration::from_secs(10 * 60);
-
-type ContextFor<CR> =
-    <<CR as JobContextRepository>::JobContext as Runnable>::Context;
+type ContextFor<Q> =
+    <<<Q as JobQueue>::ContextRepository as JobContextRepository>::JobContext as Runnable>::Context;

 pub async fn spawn_jobs<Q>(
     queue: &Q,
@@ -25,9 +24,11 @@
 where
     Q: JobQueue + Clone,
 {
     let job_data = queue.fetch_job_data(max_jobs).await?;
+    let job_ids: Vec<Uuid> = job_data.iter().map(|data| data.job_id).collect();
+
     let context_stream = queue
         .context_repository()
-        .fetch_context(job_data.clone().map(|data| data.meta.job_id))
+        .fetch_context(job_ids.into_iter())
         .await
         .map_err(|err| Error::ContextRepository(err.into()))?;

@@ -36,14 +37,14 @@
     // Collect all the job data into a hashmap indexed by the job ID
     // This is because we don't enforce an ordering with the batch fetching
     let job_data = job_data
-        .map(|data| (data.meta.job_id, data))
+        .into_iter()
+        .map(|data| (data.job_id, data))
         .collect::<AHashMap<Uuid, JobData>>();
     let job_data = Arc::new(job_data);

     while let Some((job_id, job_ctx)) = context_stream
-        .next()
+        .try_next()
         .await
-        .transpose()
         .map_err(|err| Error::ContextRepository(err.into()))?
     {
         let queue = queue.clone();
@@ -73,15 +74,18 @@
         let job_state = if let Err(error) = result {
             error!(error = ?error.into(), "Failed run job");

-            JobState::Failed {
-                fail_count: job_data.meta.fail_count,
+            JobResult {
+                outcome: Outcome::Fail {
+                    fail_count: job_data.fail_count,
+                },
                 job_id,
-                stream_id: &job_data.stream_id,
+                ctx: &job_data.ctx,
             }
         } else {
-            JobState::Succeeded {
+            JobResult {
+                outcome: Outcome::Success,
                 job_id,
-                stream_id: &job_data.stream_id,
+                ctx: &job_data.ctx,
             }
         };
diff --git a/lib/athena/src/consts.rs b/lib/athena/src/consts.rs
new file mode 100644
index 000000000..fbe032dab
--- /dev/null
+++ b/lib/athena/src/consts.rs
@@ -0,0 +1,5 @@
+use std::time::Duration;
+
+pub const BLOCK_TIME: Duration = Duration::from_secs(2);
+pub const MAX_RETRIES: u32 = 10;
+pub const MIN_IDLE_TIME: Duration = Duration::from_secs(10 * 60);
diff --git a/lib/athena/src/lib.rs b/lib/athena/src/lib.rs
index 0e3f99da2..449d298fe 100644
--- a/lib/athena/src/lib.rs
+++ b/lib/athena/src/lib.rs
@@ -1,23 +1,27 @@
-#[cfg(feature = "redis")]
 #[macro_use]
 extern crate tracing;

 use self::error::{BoxError, Result};
-use ahash::AHashMap;
 use async_trait::async_trait;
 use futures_util::{Future, Stream};
 use iso8601_timestamp::Timestamp;
+use serde::{Deserialize, Serialize};
 use speedy_uuid::Uuid;
-use std::sync::Arc;
+use std::{
+    any::{Any, TypeId},
+    sync::Arc,
+};
 use typed_builder::TypedBuilder;

 pub use self::error::Error;
 pub use tokio_util::task::TaskTracker;
+pub use self::common::spawn_jobs;

 #[cfg(feature = "redis")]
 pub use self::redis::JobQueue as RedisJobQueue;

 mod common;
+mod consts;
 mod error;
 mod macros;
 #[cfg(feature = "redis")]
@@ -36,9 +40,86 @@ pub struct JobDetails<C> {
     pub run_at: Option<Timestamp>,
 }

+#[typetag::serde]
+pub trait Keepable: Any + Send + Sync + 'static {}
+
+// Hack around because it's not stable yet.
+// So I had to implement trait downcasting myself.
+//
+// TODO: Remove this once is stabilized.
+#[inline]
+fn downcast_to<T>(obj: &dyn Keepable) -> Option<&T>
+where
+    T: 'static,
+{
+    if obj.type_id() == TypeId::of::<T>() {
+        #[allow(unsafe_code)]
+        // SAFETY: the `TypeId` equality check ensures this type cast is correct
+        Some(unsafe { &*(obj as *const dyn Keepable).cast::<T>() })
+    } else {
+        None
+    }
+}
+
+#[typetag::serde]
+impl Keepable for String {}
+
+#[derive(Deserialize, Serialize)]
+#[serde(transparent)]
+pub struct KeeperOfTheSecrets {
+    inner: Option<Box<dyn Keepable>>,
+}
+
+impl KeeperOfTheSecrets {
+    #[inline]
+    #[must_use]
+    pub fn empty() -> Self {
+        Self { inner: None }
+    }
+
+    #[inline]
+    pub fn new<T>(inner: T) -> Self
+    where
+        T: Keepable,
+    {
+        Self {
+            inner: Some(Box::new(inner)),
+        }
+    }
+
+    #[inline]
+    #[must_use]
+    pub fn get<T>(&self) -> Option<&T>
+    where
+        T: 'static,
+    {
+        self.inner
+            .as_ref()
+            .and_then(|item| downcast_to(item.as_ref()))
+    }
+}
+
+pub enum Outcome {
+    Success,
+    Fail { fail_count: u32 },
+}
+
+pub struct JobResult<'a> {
+    outcome: Outcome,
+    job_id: Uuid,
+    ctx: &'a KeeperOfTheSecrets,
+}
+
+#[derive(Deserialize, Serialize)]
+pub struct JobData {
+    job_id: Uuid,
+    fail_count: u32,
+    ctx: KeeperOfTheSecrets,
+}
+
 #[async_trait]
 pub trait JobQueue: Send + Sync + 'static {
-    type ContextRepository: JobContextRepository;
+    type ContextRepository: JobContextRepository + 'static;

     fn context_repository(&self) -> &Self::ContextRepository;

@@ -46,6 +127,43 @@
         &self,
         job_details: JobDetails<<Self::ContextRepository as JobContextRepository>::JobContext>,
     ) -> Result<()>;
+
+    async fn fetch_job_data(&self, max_jobs: usize) -> Result<Vec<JobData>>;
+
+    async fn reclaim_job(&self, job_data: &JobData) -> Result<()>;
+
+    async fn complete_job(&self, state: &JobResult<'_>) -> Result<()>;
+}
+
+#[async_trait]
+impl<CR> JobQueue for Arc<dyn JobQueue<ContextRepository = CR> + '_>
+where
+    CR: JobContextRepository + 'static,
+{
+    type ContextRepository = CR;
+
+    fn context_repository(&self) -> &Self::ContextRepository {
+        (**self).context_repository()
+    }
+
+    async fn enqueue(
+        &self,
+        job_details: JobDetails<<Self::ContextRepository as JobContextRepository>::JobContext>,
+    ) -> Result<()> {
+        (**self).enqueue(job_details).await
+    }
+
+    async fn fetch_job_data(&self, max_jobs: usize) -> Result<Vec<JobData>> {
+        (**self).fetch_job_data(max_jobs).await
+    }
+
+    async fn reclaim_job(&self, job_data: &JobData) -> Result<()> {
+        (**self).reclaim_job(job_data).await
+    }
+
+    async fn complete_job(&self, state: &JobResult<'_>) -> Result<()> {
+        (**self).complete_job(state).await
+    }
 }

 pub trait Runnable {
diff --git a/lib/athena/src/redis/mod.rs b/lib/athena/src/redis/mod.rs
index 25d87dbec..5e7ba06b8 100644
--- a/lib/athena/src/redis/mod.rs
+++ b/lib/athena/src/redis/mod.rs
@@ -1,31 +1,25 @@
 use self::{scheduled::ScheduledJobActor, util::StreamAutoClaimReply};
-use crate::{error::Result, impl_to_redis_args, Error, JobContextRepository, JobDetails, Runnable};
-use ahash::AHashMap;
+use crate::{
+    consts::{BLOCK_TIME, MAX_RETRIES, MIN_IDLE_TIME},
+    error::Result,
+    Error, JobContextRepository, JobData, JobDetails, JobResult, KeeperOfTheSecrets, Outcome,
+};
 use async_trait::async_trait;
 use either::Either;
-use futures_util::StreamExt;
 use iso8601_timestamp::Timestamp;
 use just_retry::{
     retry_policies::{policies::ExponentialBackoff, Jitter},
-    JustRetryPolicy, RetryExt, StartTime,
+    JustRetryPolicy, StartTime,
 };
 use redis::{
     aio::ConnectionLike,
     streams::{StreamReadOptions, StreamReadReply},
     AsyncCommands, RedisResult,
 };
-use serde::{Deserialize, Serialize};
 use smol_str::SmolStr;
 use speedy_uuid::Uuid;
-use std::{
-    ops::ControlFlow,
-    pin::pin,
-    str::FromStr,
-    sync::Arc,
-    time::{Duration, SystemTime},
-};
-use tokio::{sync::OnceCell, time::Instant};
-use tokio_util::task::TaskTracker;
+use std::{ops::ControlFlow, str::FromStr, sync::Arc, time::SystemTime};
+use tokio::sync::OnceCell;
 use typed_builder::TypedBuilder;

 mod scheduled;
@@ -33,45 +27,6 @@
 mod util;

 type Pool = multiplex_pool::Pool<redis::aio::ConnectionManager>;

-enum JobState<'a> {
-    Succeeded {
-        job_id: Uuid,
-        stream_id: &'a str,
-    },
-    Failed {
-        fail_count: u32,
-        job_id: Uuid,
-        stream_id: &'a str,
-    },
-}
-
-impl JobState<'_> {
-    fn job_id(&self) -> Uuid {
-        match self {
-            Self::Succeeded { job_id, .. } | Self::Failed { job_id, .. } => *job_id,
-        }
-    }
-
-    fn stream_id(&self) -> &str {
-        match self {
-            Self::Succeeded { stream_id, .. } | Self::Failed { stream_id, .. } => stream_id,
-        }
-    }
-}
-
-struct JobData {
-    stream_id: String,
-    meta: JobMeta,
-}
-
-impl_to_redis_args! {
-    #[derive(Deserialize, Serialize)]
-    struct JobMeta {
-        job_id: Uuid,
-        fail_count: u32,
-    }
-}
-
 #[derive(TypedBuilder)]
 pub struct JobQueue<CR> {
     #[builder(default = "athena-job-runners".into(), setter(into))]
@@ -136,7 +91,7 @@ where
     fn enqueue_redis_cmd(
         &self,
-        job_meta: &JobMeta,
+        job_meta: &JobData,
         run_at: Option<Timestamp>,
     ) -> Result<redis::Cmd> {
         let cmd = if let Some(run_at) = run_at {
@@ -148,17 +103,53 @@
             )
         } else {
             let mut cmd = redis::cmd("XADD");
-            cmd.arg(self.queue_name.as_str()).arg("*").arg(job_meta);
+            cmd.arg(self.queue_name.as_str())
+                .arg("*")
+                .arg("job_id")
+                .arg(job_meta.job_id)
+                .arg("fail_count")
+                .arg(job_meta.fail_count);
+
             cmd
         };

         Ok(cmd)
     }
+}

-    async fn fetch_job_data(
-        &self,
-        max_jobs: usize,
-    ) -> Result<impl Iterator<Item = JobData> + Clone> {
+#[async_trait]
+impl<CR> crate::JobQueue for JobQueue<CR>
+where
+    CR: JobContextRepository + Send + Sync + 'static,
+{
+    type ContextRepository = CR;
+
+    #[inline]
+    fn context_repository(&self) -> &Self::ContextRepository {
+        &self.context_repository
+    }
+
+    async fn enqueue(&self, job_details: JobDetails<CR::JobContext>) -> Result<()> {
+        let job_data = JobData {
+            job_id: job_details.job_id,
+            fail_count: job_details.fail_count,
+            ctx: KeeperOfTheSecrets::empty(),
+        };
+
+        self.context_repository
+            .store_context(job_data.job_id, job_details.context)
+            .await
+            .map_err(|err| Error::ContextRepository(err.into()))?;
+
+        let mut redis_conn = self.redis_pool.get();
+        self.enqueue_redis_cmd(&job_data, job_details.run_at)?
+            .query_async(&mut redis_conn)
+            .await?;
+
+        Ok(())
+    }
+
+    async fn fetch_job_data(&self, max_jobs: usize) -> Result<Vec<JobData>> {
         let mut redis_conn = self.redis_pool.get();
         self.initialise_group(&mut redis_conn).await?;

@@ -198,23 +189,31 @@
             Either::Right(claimed_ids.into_iter().chain(read_ids))
         };

-        let job_data_iterator = claimed_ids.map(|id| {
-            let job_id: String =
-                redis::from_redis_value(&id.map["job_id"]).expect("[Bug] Malformed Job ID");
-            let job_id = Uuid::from_str(&job_id).expect("[Bug] Job ID is not a UUID");
-            let fail_count: u32 =
-                redis::from_redis_value(&id.map["fail_count"]).expect("[Bug] Malformed fail count");
-
-            JobData {
-                stream_id: id.id,
-                meta: JobMeta { job_id, fail_count },
-            }
-        });
+        let job_data = claimed_ids
+            .map(|id| {
+                let job_id: String =
+                    redis::from_redis_value(&id.map["job_id"]).expect("[Bug] Malformed Job ID");
+                let job_id = Uuid::from_str(&job_id).expect("[Bug] Job ID is not a UUID");
+                let fail_count: u32 = redis::from_redis_value(&id.map["fail_count"])
+                    .expect("[Bug] Malformed fail count");
+
+                JobData {
+                    ctx: KeeperOfTheSecrets::new(id.id),
+                    job_id,
+                    fail_count,
+                }
+            })
+            .collect();

-        Ok(job_data_iterator)
+        Ok(job_data)
     }

-    async fn complete_job(&self, state: &JobState<'_>) -> Result<()> {
+    async fn complete_job(&self, state: &JobResult<'_>) -> Result<()> {
+        let stream_id = state
+            .ctx
+            .get::<String>()
+            .expect("[Bug] Not a string in the context");
+
         let mut pipeline = redis::pipe();
         pipeline
             .atomic()
             .xack(
                 self.queue_name.as_str(),
                 self.consumer_group.as_str(),
-                &[state.stream_id()],
+                &[stream_id],
             )
-            .xdel(self.queue_name.as_str(), &[state.stream_id()]);
+            .xdel(self.queue_name.as_str(), &[stream_id]);

-        let remove_context = match state {
-            JobState::Failed {
-                fail_count, job_id, ..
-            } => {
+        let remove_context = match state.outcome {
+            Outcome::Fail { fail_count } => {
                 let backoff = ExponentialBackoff::builder()
                     .jitter(Jitter::Bounded)
                     .build_with_max_retries(self.max_retries);

                 if let ControlFlow::Continue(delta) =
-                    backoff.should_retry(StartTime::Irrelevant, *fail_count)
+                    backoff.should_retry(StartTime::Irrelevant, fail_count)
                 {
-                    let job_meta = JobMeta {
-                        job_id: *job_id,
+                    let job_data = JobData {
+                        job_id: state.job_id,
                         fail_count: fail_count + 1,
+                        ctx: KeeperOfTheSecrets::empty(),
                     };

                     let backoff_timestamp = Timestamp::from(SystemTime::now() + delta);
-                    let enqueue_cmd = self.enqueue_redis_cmd(&job_meta, Some(backoff_timestamp))?;
+                    let enqueue_cmd = self.enqueue_redis_cmd(&job_data, Some(backoff_timestamp))?;

                     pipeline.add_command(enqueue_cmd);

@@ -252,7 +250,7 @@
                     true // We hit the maximum amount of retries, we won't re-enqueue the job, so we can just remove the context
                 }
             }
-            JobState::Succeeded { .. } => true, // Execution succeeded, we don't need the context anymore
+            Outcome::Success => true, // Execution succeeded, we don't need the context anymore
         };

         {
@@ -262,7 +260,7 @@
         if remove_context {
             self.context_repository
-                .remove_context(state.job_id())
+                .remove_context(state.job_id)
                 .await
                 .map_err(|err| Error::ContextRepository(err.into()))?;
         }
@@ -271,13 +269,18 @@
     }

     async fn reclaim_job(&self, job_data: &JobData) -> Result<()> {
+        let stream_id = job_data
+            .ctx
+            .get::<String>()
+            .expect("[Bug] Not a string in the context");
+
         let mut conn = self.redis_pool.get();
         conn.xclaim(
             self.queue_name.as_str(),
             self.consumer_group.as_str(),
             self.consumer_name.as_str(),
             0,
-            &[job_data.stream_id.as_str()],
+            &[stream_id],
         )
         .await?;
@@ -285,38 +288,6 @@
         Ok(())
     }
 }

-#[async_trait]
-impl<CR> crate::JobQueue for JobQueue<CR>
-where
-    CR: JobContextRepository + Send + Sync + 'static,
-{
-    type ContextRepository = CR;
-
-    #[inline]
-    fn context_repository(&self) -> &Self::ContextRepository {
-        &self.context_repository
-    }
-
-    async fn enqueue(&self, job_details: JobDetails<CR::JobContext>) -> Result<()> {
-        let job_meta = JobMeta {
-            job_id: job_details.job_id,
-            fail_count: job_details.fail_count,
-        };
-
-        self.context_repository
-            .store_context(job_meta.job_id, job_details.context)
-            .await
-            .map_err(|err| Error::ContextRepository(err.into()))?;
-
-        let mut redis_conn = self.redis_pool.get();
-        self.enqueue_redis_cmd(&job_meta, job_details.run_at)?
-            .query_async(&mut redis_conn)
-            .await?;
-
-        Ok(())
-    }
-}
-
 impl<CR> Clone for JobQueue<CR> {
     fn clone(&self) -> Self {
         Self {
diff --git a/lib/athena/tests/redis.rs b/lib/athena/tests/redis.rs
index 2bd602f51..686c6ee25 100644
--- a/lib/athena/tests/redis.rs
+++ b/lib/athena/tests/redis.rs
@@ -40,9 +40,10 @@
     async fn fetch_context<I>(&self, job_ids: I) -> Result<Self::Stream, Self::Error>
     where
-        I: Iterator<Item = Uuid> + Send + 'static,
+        I: Iterator<Item = Uuid> + Send,
     {
-        let stream = stream::iter(job_ids).map(|id| Ok((id, JobCtx)));
+        let vec: Vec<_> = job_ids.collect();
+        let stream = stream::iter(vec).map(|id| Ok((id, JobCtx)));
         Ok(stream.boxed())
     }

@@ -75,7 +76,11 @@
     let jobs = TaskTracker::new();
     jobs.close();

-    queue.spawn_jobs(1, Arc::new(()), &jobs).await.unwrap();
+
+    athena::spawn_jobs(&queue, 1, Arc::new(()), &jobs)
+        .await
+        .unwrap();
+
     jobs.wait().await;

     assert!(DID_RUN.load(Ordering::Acquire));
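
The core of this change: the Redis-specific `JobMeta`/`JobState` pair is gone, and the `JobQueue` trait now speaks backend-agnostic `JobData`/`JobResult`, with backend-private state (here the Redis stream ID) carried inside `KeeperOfTheSecrets` — a serializable, type-erased slot built on `typetag`, with a hand-rolled downcast because trait-object upcasting (`dyn Keepable` → `dyn Any`) wasn't stable when this was written. Below is a minimal self-contained sketch of that pattern; it is not athena code — it assumes `serde` (with `derive`), `serde_json`, and `typetag` as dependencies, and the `serde_json` round-trip plus the stream-ID value are purely illustrative.

```rust
// Standalone sketch of the KeeperOfTheSecrets pattern from this diff.
use std::any::{Any, TypeId};

// `#[typetag::serde]` registers every impl so a `Box<dyn Keepable>` can be
// (de)serialized with the concrete type's name written out as a tag.
#[typetag::serde]
trait Keepable: Any + Send + Sync + 'static {}

#[typetag::serde]
impl Keepable for String {}

#[derive(serde::Deserialize, serde::Serialize)]
#[serde(transparent)]
struct KeeperOfTheSecrets {
    inner: Option<Box<dyn Keepable>>,
}

impl KeeperOfTheSecrets {
    fn new<T: Keepable>(inner: T) -> Self {
        Self {
            inner: Some(Box::new(inner)),
        }
    }

    // No stable upcast from `dyn Keepable` to `dyn Any`, so the downcast is
    // manual: `type_id()` dispatches through the vtable to the concrete type,
    // and the equality check is what makes the pointer cast sound.
    fn get<T: 'static>(&self) -> Option<&T> {
        self.inner.as_ref().and_then(|item| {
            let obj: &dyn Keepable = item.as_ref();
            (obj.type_id() == TypeId::of::<T>())
                .then(|| unsafe { &*(obj as *const dyn Keepable).cast::<T>() })
        })
    }
}

fn main() {
    // A made-up Redis stream ID, stored the way fetch_job_data stores `id.id`.
    let ctx = KeeperOfTheSecrets::new("1714000000000-0".to_string());

    // The tag survives a round-trip through any self-describing format...
    let json = serde_json::to_string(&ctx).unwrap();
    let restored: KeeperOfTheSecrets = serde_json::from_str(&json).unwrap();

    // ...and the concrete type is recoverable afterwards, but only as itself.
    assert_eq!(
        restored.get::<String>().map(String::as_str),
        Some("1714000000000-0")
    );
    assert!(restored.get::<u32>().is_none());
}
```

`typetag` records the concrete type name next to the payload so the trait object can be rebuilt on deserialization, and the `TypeId` comparison then recovers a typed reference at the use site — which is how `complete_job` and `reclaim_job` above pull the `String` stream ID back out of an otherwise generic `JobData`.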