diff --git a/README.md b/README.md
index 24a4891c..d3a40447 100644
--- a/README.md
+++ b/README.md
@@ -51,6 +51,7 @@ and register your tasks with it:
 ```rust
 let my_app = celery::app!(
     broker = AMQPBroker { std::env::var("AMQP_ADDR").unwrap() },
+    backend = RedisBackend { std::env::var("REDIS_ADDR").unwrap() },
     tasks = [add],
     task_routes = [
         "*" => "celery",
@@ -61,7 +62,7 @@ let my_app = celery::app!(
 Then send tasks to a queue with
 
 ```rust
-my_app.send_task(add::new(1, 2)).await?;
+let result = my_app.send_task(add::new(1, 2)).await?;
 ```
 
 And consume tasks as a worker from a queue with
@@ -70,6 +71,14 @@ And consume tasks as a worker from a queue with
 my_app.consume().await?;
 ```
 
+In a Rust client you can wait for the result of a task with
+
+
+``` rust
+let result = result.get().fetch().await.expect("get task");
+assert_eq!(3, result);
+```
+
 ## Examples
 
 The [`examples/`](https://github.com/rusty-celery/rusty-celery/tree/main/examples) directory contains:
@@ -152,7 +161,7 @@ And then you can consume tasks from Rust or Python as explained above.
 | Consumers | ✅ | |
 | Brokers | ✅ | |
 | Beat | ✅ | |
-| Backends | 🔴 | |
+| Backends | ⚠️ | |
 | [Baskets](https://github.com/rusty-celery/rusty-celery/issues/53) | 🔴 | |
 
 ### Brokers
@@ -167,4 +176,4 @@ And then you can consume tasks from Rust or Python as explained above.
 | | Status | Tracking |
 | ----------- |:------:| -------- |
 | RPC | 🔴 | [![](https://img.shields.io/github/issues/rusty-celery/rusty-celery/Backend%3A%20RPC?label=Issues)](https://github.com/rusty-celery/rusty-celery/labels/Backend%3A%20RPC) |
-| Redis | 🔴 | [![](https://img.shields.io/github/issues/rusty-celery/rusty-celery/Backend%3A%20Redis?label=Issues)](https://github.com/rusty-celery/rusty-celery/labels/Backend%3A%20Redis) |
+| Redis | ✅ | [![](https://img.shields.io/github/issues/rusty-celery/rusty-celery/Backend%3A%20Redis?label=Issues)](https://github.com/rusty-celery/rusty-celery/labels/Backend%3A%20Redis) |
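Putting the README snippets together, a minimal end-to-end client might look like the sketch below. It assumes the `add` task from the README, a worker consuming the `celery` queue (for example a process running `my_app.consume()`), and `AMQP_ADDR`/`REDIS_ADDR` pointing at a reachable broker and Redis instance:

```rust
use celery::prelude::*;

// The `add` task from the README.
#[celery::task]
fn add(x: i32, y: i32) -> TaskResult<i32> {
    Ok(x + y)
}

#[tokio::main]
async fn main() -> Result<(), CeleryError> {
    let my_app = celery::app!(
        broker = AMQPBroker { std::env::var("AMQP_ADDR").unwrap() },
        backend = RedisBackend { std::env::var("REDIS_ADDR").unwrap() },
        tasks = [add],
        task_routes = ["*" => "celery"],
    )
    .await?;

    // Queue the task; a worker (Rust or Python) consuming "celery" picks it up.
    let result = my_app.send_task(add::new(1, 2)).await?;

    // Poll the Redis backend until the worker has stored the result.
    let sum = result.get().fetch().await?;
    assert_eq!(3, sum);
    Ok(())
}
```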
diff --git a/examples/celery_app.rs b/examples/celery_app.rs
index b5858f45..5aa853f8 100644
--- a/examples/celery_app.rs
+++ b/examples/celery_app.rs
@@ -111,11 +111,17 @@ async fn main() -> Result<()> {
     } else {
         for task in tasks {
             match task.as_str() {
-                "add" => my_app.send_task(add::new(1, 2)).await?,
-                "bound_task" => my_app.send_task(bound_task::new()).await?,
-                "buggy_task" => my_app.send_task(buggy_task::new()).await?,
+                "add" => {
+                    my_app.send_task(add::new(1, 2)).await?;
+                }
+                "bound_task" => {
+                    my_app.send_task(bound_task::new()).await?;
+                }
+                "buggy_task" => {
+                    my_app.send_task(buggy_task::new()).await?;
+                }
                 "long_running_task" => {
-                    my_app.send_task(long_running_task::new(Some(3))).await?
+                    my_app.send_task(long_running_task::new(Some(3))).await?;
                 }
                 _ => panic!("unknown task"),
             };
diff --git a/src/app/mod.rs b/src/app/mod.rs
index 83c873ce..10b5add0 100644
--- a/src/app/mod.rs
+++ b/src/app/mod.rs
@@ -17,6 +17,7 @@ use tokio_stream::StreamMap;
 
 mod trace;
 
+use crate::backend::{backend_builder_from_url, Backend, BackendBuilder};
 use crate::broker::{
     broker_builder_from_url, build_and_connect, configure_task_routes, Broker, BrokerBuilder,
     Delivery,
@@ -31,6 +32,7 @@ struct Config {
     name: String,
     hostname: String,
     broker_builder: Box<dyn BrokerBuilder>,
+    backend_builder: Option<Box<dyn BackendBuilder>>,
     broker_connection_timeout: u32,
     broker_connection_retry: bool,
     broker_connection_max_retries: u32,
@@ -47,7 +49,7 @@ pub struct CeleryBuilder {
 
 impl CeleryBuilder {
     /// Get a [`CeleryBuilder`] for creating a [`Celery`] app with a custom configuration.
-    pub fn new(name: &str, broker_url: &str) -> Self {
+    pub fn new(name: &str, broker_url: &str, backend_url: Option<impl AsRef<str>>) -> Self {
         Self {
             config: Config {
                 name: name.into(),
@@ -60,6 +62,7 @@ impl CeleryBuilder {
                         .unwrap_or_else(|| "unknown".into())
                 ),
                 broker_builder: broker_builder_from_url(broker_url),
+                backend_builder: backend_url.map(backend_builder_from_url),
                 broker_connection_timeout: 2,
                 broker_connection_retry: true,
                 broker_connection_max_retries: 5,
@@ -189,6 +192,24 @@ impl CeleryBuilder {
         self
     }
 
+    /// Set connection timeout for backend.
+    pub fn backend_connection_timeout(mut self, timeout: u32) -> Self {
+        self.config.backend_builder = self
+            .config
+            .backend_builder
+            .map(|builder| builder.connection_timeout(timeout));
+        self
+    }
+
+    /// Set backend task meta collection name.
+    pub fn backend_taskmeta_collection(mut self, collection_name: &str) -> Self {
+        self.config.backend_builder = self
+            .config
+            .backend_builder
+            .map(|builder| builder.taskmeta_collection(collection_name));
+        self
+    }
+
     /// Construct a [`Celery`] app with the current configuration.
     pub async fn build(self) -> Result<Celery, CeleryError> {
         // Declare default queue to broker.
@@ -212,10 +233,16 @@ impl CeleryBuilder {
         )
         .await?;
 
+        let backend = match self.config.backend_builder {
+            Some(backend_builder) => Some(backend_builder.build().await?),
+            None => None,
+        };
+
         Ok(Celery {
             name: self.config.name,
             hostname: self.config.hostname,
             broker,
+            backend,
             default_queue: self.config.default_queue,
             task_options: self.config.task_options,
             task_routes,
@@ -240,6 +267,9 @@ pub struct Celery {
     /// The app's broker.
     pub broker: Box<dyn Broker>,
 
+    /// The backend to use for storing task results.
+    pub backend: Option<Arc<dyn Backend>>,
+
     /// The default queue to send and receive from.
     pub default_queue: String,
 
@@ -302,7 +332,7 @@ impl Celery {
     pub async fn send_task<T: Task>(
         &self,
         mut task_sig: Signature<T>,
-    ) -> Result<AsyncResult, CeleryError> {
+    ) -> Result<AsyncResult<T::Returns>, CeleryError> {
         task_sig.options.update(&self.task_options);
         let maybe_queue = task_sig.queue.take();
         let queue = maybe_queue.as_deref().unwrap_or_else(|| {
@@ -316,7 +346,15 @@ impl Celery {
             queue,
         );
         self.broker.send(&message, queue).await?;
-        Ok(AsyncResult::new(message.task_id()))
+
+        if let Some(backend) = &self.backend {
+            backend.add_task(message.task_id()).await?;
+        }
+
+        Ok(AsyncResult::<T::Returns>::new(
+            message.task_id(),
+            self.backend.clone(),
+        ))
     }
 
     /// Register a task.
@@ -428,12 +466,26 @@ impl Celery {
         // NOTE: we don't need to log errors from the trace here since the tracer
         // handles all errors at it's own level or the task level. In this function
         // we only log errors at the broker and delivery level.
-        if let Err(TraceError::Retry(retry_eta)) = tracer.trace().await {
-            // If retry error -> retry the task.
-            self.broker
-                .retry(delivery.as_ref(), retry_eta)
-                .await
-                .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+        match tracer.trace().await {
+            Err(TraceError::Retry(retry_eta)) => {
+                // If retry error -> retry the task.
+                self.broker
+                    .retry(delivery.as_ref(), retry_eta)
+                    .await
+                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+            }
+
+            result => {
+                if let Some(backend) = self.backend.as_ref() {
+                    backend
+                        .store_result(
+                            tracer.task_id(),
+                            result.map_err(|err| err.into_task_error()),
+                        )
+                        .await
+                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error + Send + Sync>)?;
+                }
+            }
         }
 
         // If we have not done it before, we have to acknowledge the message now.
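For programmatic configuration (without the `app!` macro), the builder now takes the optional backend URL as a third argument and exposes the two new `backend_*` setters added above. A rough sketch under those assumptions, with placeholder URLs and a hypothetical `build_app` helper:

```rust
use celery::{prelude::*, Celery, CeleryBuilder};

// Placeholder URLs; any reachable AMQP broker and Redis instance will do.
async fn build_app() -> Result<Celery, CeleryError> {
    CeleryBuilder::new(
        "my-app",
        "amqp://127.0.0.1:5672//",
        Some("redis://127.0.0.1:6379/"),
    )
    // Both setters are no-ops when no backend URL was given.
    .backend_connection_timeout(5)
    .backend_taskmeta_collection("celery-task-meta")
    .build()
    .await
}
```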
- error!("Failed sending task event"); + error!("Failed sending task event {err:?}"); }); let start = Instant::now(); @@ -73,7 +74,7 @@ where let duration = Duration::from_secs(secs as u64); time::timeout(duration, self.task.run(self.task.request().params.clone())) .await - .unwrap_or(Err(TaskError::TimeoutError)) + .unwrap_or(Err(TaskError::timeout())) } None => self.task.run(self.task.request().params.clone()).await, }; @@ -94,49 +95,27 @@ where self.event_tx .send(TaskEvent::StatusChange(TaskStatus::Finished)) - .unwrap_or_else(|_| { - error!("Failed sending task event"); + .unwrap_or_else(|err| { + error!("Failed sending task event {err:?}"); }); - Ok(()) + Ok(Box::new(returned)) } + Err(e) => { - let (should_retry, retry_eta) = match e { - TaskError::ExpectedError(ref reason) => { - warn!( - "Task {}[{}] failed with expected error: {}", - self.task.name(), - &self.task.request().id, - reason - ); - (true, None) - } - TaskError::UnexpectedError(ref reason) => { - error!( - "Task {}[{}] failed with unexpected error: {}", - self.task.name(), - &self.task.request().id, - reason - ); - (self.task.retry_for_unexpected(), None) - } - TaskError::TimeoutError => { - error!( - "Task {}[{}] timed out after {}s", - self.task.name(), - &self.task.request().id, - duration.as_secs_f32(), - ); - (true, None) - } - TaskError::Retry(eta) => { - error!( - "Task {}[{}] triggered retry", - self.task.name(), - &self.task.request().id, - ); - (true, eta) - } + error!( + "Task {}[{}] failed: {e}", + self.task.name(), + &self.task.request().id + ); + + let (should_retry, retry_eta) = match e.kind { + TaskErrorType::Retry(eta) => (true, eta), + TaskErrorType::MaxRetriesExceeded => (false, None), + TaskErrorType::Expected => (true, None), + TaskErrorType::Unexpected => (self.task.retry_for_unexpected(), None), + TaskErrorType::Timeout => (true, None), + TaskErrorType::Other => (true, None), }; // Run failure callback. @@ -144,8 +123,8 @@ where self.event_tx .send(TaskEvent::StatusChange(TaskStatus::Finished)) - .unwrap_or_else(|_| { - error!("Failed sending task event"); + .unwrap_or_else(|err| { + error!("Failed sending task event {err:?}"); }); if !should_retry { @@ -160,7 +139,10 @@ where self.task.name(), &self.task.request().id, ); - return Err(TraceError::TaskError(e)); + return Err(TraceError::TaskError(TaskError::max_retries_exceeded( + self.task.name(), + &self.task.request().id, + ))); } info!( "Task {}[{}] retrying ({} / {})", @@ -202,13 +184,17 @@ where fn acks_late(&self) -> bool { self.task.acks_late() } + + fn task_id(&self) -> &str { + &self.task.request().id + } } #[async_trait] pub(super) trait TracerTrait: Send + Sync { /// Wraps the execution of a task, catching and logging errors and then running /// the appropriate post-execution functions. - async fn trace(&mut self) -> Result<(), TraceError>; + async fn trace(&mut self) -> Result, TraceError>; /// Wait until the task is due. 
diff --git a/src/backend/mod.rs b/src/backend/mod.rs
index e69de29b..4b08c42a 100644
--- a/src/backend/mod.rs
+++ b/src/backend/mod.rs
@@ -0,0 +1,134 @@
+mod redis;
+
+use std::sync::Arc;
+
+use chrono::prelude::*;
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    prelude::{BackendError, TaskError},
+    task::{TaskResult, TaskReturns, TaskState},
+};
+pub use redis::RedisBackend;
+use redis::RedisBackendBuilder;
+
+use async_trait::async_trait;
+
+/// A [`BackendBuilder`] is used to create a results [`Backend`] with a custom configuration.
+#[async_trait]
+pub trait BackendBuilder {
+    /// Create a new `BackendBuilder`.
+    fn new(broker_url: &str) -> Self
+    where
+        Self: Sized;
+
+    /// Set database name.
+    fn database(self: Box<Self>, database: &str) -> Box<dyn BackendBuilder>;
+
+    /// Set the connection timeout.
+    fn connection_timeout(self: Box<Self>, timeout: u32) -> Box<dyn BackendBuilder>;
+
+    /// Set database collection name.
+    fn taskmeta_collection(self: Box<Self>, collection_name: &str) -> Box<dyn BackendBuilder>;
+
+    /// Construct the `Backend` with the given configuration.
+    async fn build(self: Box<Self>) -> Result<Arc<dyn Backend>, BackendError>;
+}
+
+#[async_trait]
+pub trait Backend: Send + Sync {
+    /// Update task state and result.
+    async fn store_task_meta(&self, task_id: &str, metadata: TaskMeta) -> Result<(), BackendError>;
+
+    /// Get task meta from backend.
+    async fn get_task_meta(&self, task_id: &str) -> Result<TaskMeta, BackendError>;
+
+    async fn store_result(
+        &self,
+        task_id: &str,
+        result: TaskResult<Box<dyn TaskReturns>>,
+    ) -> Result<(), BackendError> {
+        let (status, result) = match result {
+            Ok(val) => (TaskState::Success, val.to_json()?),
+            Err(err) => (TaskState::Failure, serde_json::to_value(err)?),
+        };
+
+        self.store_task_meta(
+            task_id,
+            TaskMeta {
+                task_id: task_id.to_string(),
+                status,
+                result: Some(result),
+                traceback: None,
+                date_done: NaiveDateTime::from_timestamp_opt(chrono::Utc::now().timestamp(), 0),
+            },
+        )
+        .await
+    }
+
+    /// Add a task to the collection.
+    async fn add_task(&self, task_id: &str) -> Result<(), BackendError> {
+        let metadata = TaskMeta {
+            task_id: task_id.to_string(),
+            status: TaskState::Pending,
+            result: None,
+            traceback: None,
+            date_done: None,
+        };
+        self.store_task_meta(task_id, metadata).await
+    }
+
+    /// Mark a task as started to trace it.
+    async fn mark_as_started(&self, task_id: &str) -> Result<(), BackendError> {
+        let metadata = TaskMeta {
+            task_id: task_id.to_string(),
+            status: TaskState::Started,
+            result: None,
+            traceback: None,
+            date_done: None,
+        };
+        self.store_task_meta(task_id, metadata).await
+    }
+
+    /// Mark a task as failed and save the error.
+    async fn mark_as_failure(
+        &self,
+        task_id: &str,
+        error: TaskError,
+        date_done: NaiveDateTime,
+    ) -> Result<(), BackendError> {
+        let metadata = TaskMeta {
+            task_id: task_id.to_string(),
+            status: TaskState::Failure,
+            result: Some(serde_json::to_value(error)?),
+            traceback: None,
+            date_done: Some(date_done),
+        };
+        self.store_task_meta(task_id, metadata).await
+    }
+}
+
+pub fn backend_builder_from_url(backend_url: impl AsRef<str>) -> Box<dyn BackendBuilder> {
+    let backend_url = backend_url.as_ref();
+    match backend_url.split_once("://") {
+        Some(("redis", _)) => Box::new(RedisBackendBuilder::new(backend_url)),
+        _ => panic!("Unsupported backend"),
+    }
+}
+
+/// Metadata of the task stored in the result backend.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TaskMeta {
+    /// Task's ID.
+    pub(crate) task_id: String,
+    /// Current status of the task.
+    pub(crate) status: TaskState,
+    /// Result of the task.
+    pub result: Option<serde_json::Value>,
+    /// Traceback of the task's error, if any.
+    pub(crate) traceback: Option<String>,
+    /// Date the task finished.
+    pub(crate) date_done: Option<NaiveDateTime>,
+    // TODO
+    // pub(crate) children: Option<Vec<String>>,
+}
diff --git a/src/backend/redis.rs b/src/backend/redis.rs
new file mode 100644
index 00000000..6052ac71
--- /dev/null
+++ b/src/backend/redis.rs
@@ -0,0 +1,111 @@
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use redis::{aio::ConnectionManager, Client};
+
+use super::{Backend, BackendBuilder, BackendError, TaskMeta};
+
+pub struct RedisBackendBuilder {
+    backend_url: String,
+    #[allow(dead_code)]
+    connection_timeout: Option<u32>,
+    taskmeta_collection: String,
+}
+
+#[async_trait]
+impl BackendBuilder for RedisBackendBuilder {
+    /// Create a new `RedisBackendBuilder`.
+    fn new(backend_url: &str) -> Self {
+        Self {
+            backend_url: backend_url.to_string(),
+            connection_timeout: None,
+            taskmeta_collection: "celery-task-meta".to_string(),
+        }
+    }
+
+    fn database(self: Box<Self>, _: &str) -> Box<dyn BackendBuilder> {
+        self
+    }
+
+    fn connection_timeout(self: Box<Self>, timeout: u32) -> Box<dyn BackendBuilder> {
+        Box::new(Self {
+            connection_timeout: Some(timeout),
+            ..*self
+        })
+    }
+
+    fn taskmeta_collection(self: Box<Self>, col: &str) -> Box<dyn BackendBuilder> {
+        Box::new(Self {
+            taskmeta_collection: col.to_string(),
+            ..*self
+        })
+    }
+
+    /// Create a new `RedisBackend`.
+    async fn build(self: Box<Self>) -> Result<Arc<dyn Backend>, BackendError> {
+        let Self {
+            backend_url,
+            connection_timeout: _,
+            taskmeta_collection,
+        } = *self;
+
+        let client = Client::open(backend_url.as_str())
+            .map_err(|_| BackendError::InvalidBrokerUrl(backend_url))?;
+
+        log::info!("Creating tokio manager");
+        let manager = client.get_tokio_connection_manager().await?;
+
+        Ok(Arc::new(RedisBackend {
+            _client: client,
+            manager,
+            taskmeta_collection,
+        }))
+    }
+}
+
+pub struct RedisBackend {
+    _client: Client,
+    manager: ConnectionManager,
+    taskmeta_collection: String,
+}
+
+#[async_trait]
+impl Backend for RedisBackend {
+    /// Store the task meta into redis and notify pubsub subscribers waiting for
+    /// the task id.
+    async fn store_task_meta(
+        &self,
+        task_id: &str,
+        task_meta: TaskMeta,
+    ) -> Result<(), BackendError> {
+        let task_meta = serde_json::to_string(&task_meta)?;
+        let key = format!("{}-{}", self.taskmeta_collection, task_id);
+
+        log::debug!("Storing task meta into {key}");
+        log::trace!("  task meta value {task_meta:#?}");
+
+        let _ret: () = redis::cmd("SET")
+            .arg(&key)
+            .arg(&task_meta)
+            .query_async(&mut self.manager.clone())
+            .await?;
+
+        let _ret: () = redis::cmd("PUBLISH")
+            .arg(key)
+            .arg(&task_meta)
+            .query_async(&mut self.manager.clone())
+            .await?;
+
+        Ok(())
+    }
+
+    /// Retrieve task metadata and deserialize the result value.
+    async fn get_task_meta(&self, task_id: &str) -> Result<TaskMeta, BackendError> {
+        let key = format!("{}-{}", self.taskmeta_collection, task_id);
+        let raw: String = redis::cmd("GET")
+            .arg(key)
+            .query_async(&mut self.manager.clone())
+            .await?;
+        Ok(serde_json::from_str(&raw)?)
+    }
+}
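The Redis backend stores each `TaskMeta` as a JSON string under the key `<taskmeta_collection>-<task_id>` (default prefix `celery-task-meta`) and publishes the same payload for any subscribers. Normally the `Celery` app drives this through `AsyncResult`, but the trait can also be used directly; a hedged sketch with a hypothetical `peek` helper and placeholder URL:

```rust
use celery::backend::{backend_builder_from_url, TaskMeta};
use celery::prelude::*;

// Hypothetical helper: inspect the stored state of a task by id.
async fn peek(task_id: &str) -> Result<TaskMeta, BackendError> {
    let backend = backend_builder_from_url("redis://127.0.0.1:6379/")
        .build()
        .await?;
    // Reads and deserializes the JSON blob stored under
    // "celery-task-meta-<task_id>".
    backend.get_task_meta(task_id).await
}
```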
diff --git a/src/codegen.rs b/src/codegen.rs
index 693a9ac9..4465630c 100644
--- a/src/codegen.rs
+++ b/src/codegen.rs
@@ -5,6 +5,7 @@ pub use celery_codegen::task;
 macro_rules! __app_internal {
     (
         $broker_type:ty { $broker_url:expr },
+        $backend_url:expr,
         [ $( $t:ty ),* ],
         [ $( $pattern:expr => $queue:expr ),* ],
         $( $x:ident = $y:expr, )*
@@ -22,7 +23,7 @@ macro_rules! __app_internal {
         let broker_url = $broker_url;
 
-        let mut builder = $crate::CeleryBuilder::new("celery", &broker_url);
+        let mut builder = $crate::CeleryBuilder::new("celery", &broker_url, $backend_url);
 
         $(
             builder = builder.$x($y);
@@ -158,6 +159,23 @@ macro_rules! app {
     ) => {
         $crate::__app_internal!(
             $broker_type { $broker_url },
+            Option::<&str>::None,
+            [ $( $t ),* ],
+            [ $( $pattern => $queue ),* ],
+            $( $x = $y, )*
+        );
+    };
+
+    (
+        broker = $broker_type:ty { $broker_url:expr },
+        backend = $backend_type:ty { $backend_url:expr },
+        tasks = [ $( $t:ty ),* $(,)? ],
+        task_routes = [ $( $pattern:expr => $queue:expr ),* $(,)? ]
+        $(, $x:ident = $y:expr )* $(,)?
+    ) => {
+        $crate::__app_internal!(
+            $broker_type { $broker_url },
+            Some($backend_url),
             [ $( $t ),* ],
             [ $( $pattern => $queue ),* ],
             $( $x = $y, )*
         );
     };
diff --git a/src/error.rs b/src/error.rs
index 8663f95c..0a9925cd 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -20,6 +20,11 @@ pub enum CeleryError {
     #[error("broker error")]
     BrokerError(#[from] BrokerError),
 
+    /// Any other backend-level error that could happen when initializing or with an open
+    /// connection.
+    #[error("backend error")]
+    BackendError(#[from] BackendError),
+
     /// Any other IO error that could occur.
     #[error("IO error")]
     IoError(#[from] std::io::Error),
@@ -28,6 +33,10 @@ pub enum CeleryError {
     #[error("protocol error")]
     ProtocolError(#[from] ProtocolError),
 
+    /// Task error.
+    #[error("task error")]
+    TaskError(#[from] TaskError),
+
     /// There is already a task registered to this name.
     #[error("there is already a task registered as '{0}'")]
     TaskRegistrationError(String),
@@ -61,8 +70,117 @@ pub enum ScheduleError {
 }
 
 /// Errors that can occur at the task level.
-#[derive(Error, Debug, Serialize, Deserialize)]
-pub enum TaskError {
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TaskError {
+    /// Distinguishes the kind of error on the Rust side.
+    #[serde(default)]
+    pub kind: TaskErrorType,
+    /// The Python exception type. For errors generated in Rust this will
+    /// default to `Exception`.
+    pub exc_type: String,
+    /// The module in which the exception type can be found. Will be `builtins` for
+    /// errors raised in Rust.
+    pub exc_module: String,
+    /// Error message.
+    pub exc_message: TaskErrorMessage,
+    pub exc_cause: Option<String>,
+    pub exc_traceback: Option<String>,
+}
+
+impl std::error::Error for TaskError {}
+
+impl std::fmt::Display for TaskError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let TaskError {
+            kind,
+            exc_type,
+            exc_message,
+            exc_module,
+            exc_cause: _,
+            exc_traceback,
+        } = self;
+
+        match kind {
+            TaskErrorType::Expected => writeln!(f, "expected error:")?,
+            TaskErrorType::Unexpected => writeln!(f, "unexpected error:")?,
+            TaskErrorType::MaxRetriesExceeded => writeln!(f, "max retries exceeded:")?,
+            TaskErrorType::Other => writeln!(f, "{exc_type}:")?,
+            _ => {}
+        }
+
+        if !exc_module.is_empty() && exc_module != "rust" {
+            writeln!(f, "in module {exc_module}")?;
+        }
+
+        exc_message.print(f, 0)?;
+
+        if let Some(trace) = exc_traceback {
+            writeln!(f, "exc traceback: {trace}")?;
+        }
+        Ok(())
+    }
+}
+
+impl TaskError {
+    pub fn expected(msg: impl ToString) -> Self {
+        TaskError {
+            kind: TaskErrorType::Expected,
+            exc_type: "Exception".to_string(),
+            exc_module: "builtins".to_string(),
+            exc_message: TaskErrorMessage::Text(msg.to_string()),
+            exc_cause: None,
+            exc_traceback: None,
+        }
+    }
+
+    pub fn unexpected(msg: impl ToString) -> Self {
+        TaskError {
+            kind: TaskErrorType::Unexpected,
+            exc_type: "Exception".to_string(),
+            exc_module: "builtins".to_string(),
+            exc_message: TaskErrorMessage::Text(msg.to_string()),
+            exc_cause: None,
+            exc_traceback: None,
+        }
+    }
+
+    pub fn timeout() -> Self {
+        TaskError {
+            kind: TaskErrorType::Timeout,
+            exc_type: "Exception".to_string(),
+            exc_module: "builtins".to_string(),
+            exc_message: TaskErrorMessage::Text("task timed out".to_string()),
+            exc_cause: None,
+            exc_traceback: None,
+        }
+    }
+
+    pub fn retry(eta: Option<DateTime<Utc>>) -> Self {
+        TaskError {
+            kind: TaskErrorType::Retry(eta),
+            exc_type: "Exception".to_string(),
+            exc_module: "builtins".to_string(),
+            exc_message: TaskErrorMessage::Text("task retry triggered".to_string()),
+            exc_cause: None,
+            exc_traceback: None,
+        }
+    }
+
+    pub fn max_retries_exceeded(name: &str, id: &str) -> Self {
+        TaskError {
+            kind: TaskErrorType::MaxRetriesExceeded,
+            exc_type: "MaxRetriesExceededError".to_string(),
+            exc_module: "celery.exceptions".to_string(),
+            exc_message: TaskErrorMessage::Text(format!("Can't retry {name}[{id}]")),
+            exc_cause: None,
+            exc_traceback: None,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
+#[serde(rename_all = "snake_case")]
+pub enum TaskErrorType {
     /// An error that is expected to happen every once in a while.
     ///
     /// These errors will only be logged at the `WARN` level and will always trigger a task
- #[error("task raised unexpected error: {0}")] - UnexpectedError(String), + Unexpected, /// Raised when a task runs over its time limit specified by the /// [`TaskOptions::time_limit`](../task/struct.TaskOptions.html#structfield.time_limit) setting. @@ -94,16 +209,52 @@ pub enum TaskError { /// Typically a task implementation doesn't need to return these errors directly /// because they will be raised automatically when the task runs over it's `time_limit`, /// provided the task yields control at some point (like with non-blocking IO). - #[error("task timed out")] - TimeoutError, + Timeout, /// A task can return this error variant to manually trigger a retry. /// /// This error variant should generally not be used directly. Instead, you should /// call the `Task::retry_with_countdown` or `Task::retry_with_eta` trait methods /// to manually trigger a retry from within a task. - #[error("task retry triggered")] Retry(Option>), + + MaxRetriesExceeded, + + #[default] + Other, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum TaskErrorMessage { + Text(String), + List(Vec), + Other(serde_json::Value), +} + +impl std::fmt::Display for TaskErrorMessage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.print(f, 0) + } +} + +impl TaskErrorMessage { + fn print(&self, f: &mut std::fmt::Formatter<'_>, indent: usize) -> std::fmt::Result { + match self { + TaskErrorMessage::Text(it) => { + writeln!(f, "{}{it}", " ".repeat(indent))?; + } + TaskErrorMessage::List(it) => { + for item in it { + item.print(f, indent + 2)?; + } + } + TaskErrorMessage::Other(it) => { + writeln!(f, "{}{it}", " ".repeat(indent))?; + } + } + Ok(()) + } } /// Errors that can occur while tracing a task. @@ -122,6 +273,16 @@ pub(crate) enum TraceError { Retry(Option>), } +impl TraceError { + pub(crate) fn into_task_error(self) -> TaskError { + match self { + TraceError::TaskError(e) => e, + TraceError::ExpirationError => TaskError::timeout(), + TraceError::Retry(_eta) => unreachable!("retry should not be returned as error"), + } + } +} + /// Errors that can occur at the broker level. #[derive(Error, Debug)] pub enum BrokerError { @@ -278,3 +439,17 @@ pub enum ContentTypeError { #[error("Unknown content type error")] Unknown, } + +/// Errors that can occur at the broker level. +#[derive(Error, Debug)] +pub enum BackendError { + #[error("invalid broker URL '{0}'")] + InvalidBrokerUrl(String), + + /// Any other Redis error that could happen. + #[error("Redis error \"{0}\"")] + RedisError(#[from] ::redis::RedisError), + + #[error("JSON serialization error")] + Json(#[from] serde_json::Error), +} diff --git a/src/lib.rs b/src/lib.rs index a0cfe109..a7aaf688 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -87,6 +87,7 @@ )] mod app; +pub mod backend; mod routing; pub use app::{Celery, CeleryBuilder}; pub mod beat; diff --git a/src/prelude.rs b/src/prelude.rs index e7e28b40..66a5831b 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -1,5 +1,6 @@ //! A "prelude" for users of the `celery` crate. 
@@ -122,6 +273,16 @@ pub(crate) enum TraceError {
     Retry(Option<DateTime<Utc>>),
 }
 
+impl TraceError {
+    pub(crate) fn into_task_error(self) -> TaskError {
+        match self {
+            TraceError::TaskError(e) => e,
+            TraceError::ExpirationError => TaskError::timeout(),
+            TraceError::Retry(_eta) => unreachable!("retry should not be returned as error"),
+        }
+    }
+}
+
 /// Errors that can occur at the broker level.
 #[derive(Error, Debug)]
 pub enum BrokerError {
@@ -278,3 +439,17 @@ pub enum ContentTypeError {
     #[error("Unknown content type error")]
     Unknown,
 }
+
+/// Errors that can occur at the backend level.
+#[derive(Error, Debug)]
+pub enum BackendError {
+    #[error("invalid backend URL '{0}'")]
+    InvalidBrokerUrl(String),
+
+    /// Any other Redis error that could happen.
+    #[error("Redis error \"{0}\"")]
+    RedisError(#[from] ::redis::RedisError),
+
+    #[error("JSON serialization error")]
+    Json(#[from] serde_json::Error),
+}
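The string-carrying `TaskError` variants are gone: task code now goes through the constructors (`expected`, `unexpected`, `timeout`, `retry`, `max_retries_exceeded`) or the `TaskResultExt` adapters, and consumers branch on the `kind` field. A sketch under those assumptions (the `read_config`/`describe` names are illustrative, not part of the change):

```rust
use celery::prelude::*;

// The old `TaskError::ExpectedError(msg)` / `UnexpectedError(msg)` values are now
// built through constructors that fill in the Python-compatible fields.
#[celery::task]
async fn read_config(path: String) -> TaskResult<String> {
    std::fs::read_to_string(&path)
        // Maps the error to TaskError::expected("...") with the cause appended.
        .with_expected_err(|| format!("could not read {path}"))
}

// On the client side, the `kind` field distinguishes the failure modes.
fn describe(err: &TaskError) -> &'static str {
    match err.kind {
        TaskErrorType::Timeout => "timed out",
        TaskErrorType::MaxRetriesExceeded => "gave up after retries",
        TaskErrorType::Expected | TaskErrorType::Unexpected => "task failed",
        _ => "other",
    }
}
```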
diff --git a/src/lib.rs b/src/lib.rs
index a0cfe109..a7aaf688 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -87,6 +87,7 @@
 )]
 
 mod app;
+pub mod backend;
 mod routing;
 pub use app::{Celery, CeleryBuilder};
 pub mod beat;
diff --git a/src/prelude.rs b/src/prelude.rs
index e7e28b40..66a5831b 100644
--- a/src/prelude.rs
+++ b/src/prelude.rs
@@ -1,5 +1,6 @@
 //! A "prelude" for users of the `celery` crate.
 
+pub use crate::backend::RedisBackend;
 pub use crate::broker::{AMQPBroker, RedisBroker};
 pub use crate::error::*;
 pub use crate::task::{Task, TaskResult, TaskResultExt};
diff --git a/src/task/async_result.rs b/src/task/async_result.rs
index a21a04b5..314db175 100644
--- a/src/task/async_result.rs
+++ b/src/task/async_result.rs
@@ -1,13 +1,149 @@
+use std::{
+    sync::Arc,
+    time::{Duration, Instant},
+};
+
+use crate::{backend::Backend, prelude::*};
+
+use super::TaskReturns;
+
+pub struct AsyncResultGetBuilder<T> {
+    task_id: String,
+    backend: Option<Arc<dyn Backend>>,
+    timeout: Option<Duration>,
+    interval: Duration,
+    _marker: std::marker::PhantomData<T>,
+}
+
+impl<T> AsyncResultGetBuilder<T>
+where
+    T: TaskReturns,
+{
+    pub fn new(task_id: String, backend: Option<Arc<dyn Backend>>) -> Self {
+        Self {
+            task_id,
+            backend,
+            timeout: None,
+            interval: Duration::from_millis(500),
+            _marker: std::marker::PhantomData,
+        }
+    }
+
+    /// How long to wait before the operation times out. This is the setting
+    /// for the publisher (celery client) and is different from the `timeout`
+    /// parameter of `@app.task`, which is the setting for the worker. The task
+    /// isn't terminated even if the timeout occurs.
+    #[must_use]
+    pub fn timeout(mut self, timeout: Duration) -> Self {
+        self.timeout = Some(timeout);
+        self
+    }
+
+    /// Time to wait before retrying to retrieve the result.
+    #[must_use]
+    pub fn interval(mut self, interval: Duration) -> Self {
+        self.interval = interval;
+        self
+    }
+
+    pub async fn fetch(self) -> Result<T, CeleryError> {
+        use super::TaskState::*;
+
+        let backend = self.backend.unwrap();
+        let start = Instant::now();
+        let meta = loop {
+            if let Some(timeout) = self.timeout {
+                if start.elapsed() > timeout {
+                    return Err(crate::error::TaskError::timeout().into());
+                }
+            }
+
+            // Note: The task meta gets set by the client before the request is
+            // sent, so this call should never error with "response was nil".
+            let meta = backend.get_task_meta(&self.task_id).await?;
+
+            match meta.status {
+                Pending | Started | Received => {
+                    tokio::time::sleep(self.interval).await;
+                }
+
+                _ => break meta,
+            }
+        };
+
+        match (meta.status, meta.result) {
+            (_, None) => {
+                log::error!("task {} has no result", self.task_id);
+                Err(TaskError::unexpected(
+                    "Task succeeded but did not provide a result value.".to_string(),
+                )
+                .into())
+            }
+
+            (Failure, Some(val)) => {
+                match (serde_json::from_value::<TaskError>(val), meta.traceback) {
+                    (Ok(mut err), Some(traceback)) => {
+                        err.exc_traceback = err
+                            .exc_traceback
+                            .map(|mut tb| {
+                                tb += "\n\n";
+                                tb += &traceback;
+                                tb
+                            })
+                            .or(Some(traceback));
+                        Err(err.into())
+                    }
+
+                    (Ok(err), _) => {
+                        log::trace!("task {} failed", self.task_id);
+                        Err(err.into())
+                    }
+
+                    (Err(err), _) => {
+                        log::error!("unable to deserialize task failure value: {:?}", err);
+                        Err(BackendError::Json(err).into())
+                    }
+                }
+            }
+
+            (_, Some(val)) => {
+                log::trace!("task {} succeeded", self.task_id);
+                T::from_json(val).map_err(|e| BackendError::Json(e).into())
+            }
+        }
+    }
+}
+
 /// An [`AsyncResult`] is a handle for the result of a task.
-#[derive(Debug, Clone)]
-pub struct AsyncResult {
+#[derive(Clone)]
+pub struct AsyncResult<T> {
     pub task_id: String,
+    pub backend: Option<Arc<dyn Backend>>,
+    _marker: std::marker::PhantomData<T>,
 }
 
-impl AsyncResult {
-    pub fn new(task_id: &str) -> Self {
+impl<T> std::fmt::Debug for AsyncResult<T> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("AsyncResult")
+            .field("task_id", &self.task_id)
+            .finish()
+    }
+}
+
+impl<T> AsyncResult<T>
+where
+    T: TaskReturns,
+{
+    pub fn new(task_id: &str, backend: Option<Arc<dyn Backend>>) -> Self {
         Self {
             task_id: task_id.into(),
+            backend,
+            _marker: std::marker::PhantomData,
         }
     }
+
+    /// Wait until the task is ready, and return its result.
+    pub fn get(&self) -> AsyncResultGetBuilder<T> {
+        AsyncResultGetBuilder::new(self.task_id.clone(), self.backend.clone())
+    }
 }
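Client-side waiting is configurable through the `get()` builder: `timeout` bounds how long `fetch` polls the backend (the worker keeps running regardless), and `interval` controls the polling period, which defaults to 500 ms. A small usage sketch with illustrative names:

```rust
use std::time::Duration;

use celery::{prelude::*, task::AsyncResult};

// `result` is the AsyncResult<i32> returned by `send_task` for a task
// that returns an i32, e.g. `add` from the README.
async fn wait_for_sum(result: AsyncResult<i32>) -> Result<i32, CeleryError> {
    result
        .get()
        // Give up on the client side after 10 seconds...
        .timeout(Duration::from_secs(10))
        // ...and poll the backend every 200 ms instead of the default 500 ms.
        .interval(Duration::from_millis(200))
        .fetch()
        .await
}
```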
diff --git a/src/task/mod.rs b/src/task/mod.rs
index 6e1a3d1b..bbd46867 100644
--- a/src/task/mod.rs
+++ b/src/task/mod.rs
@@ -33,6 +33,27 @@ where
     type Returns = R;
 }
 
+pub trait TaskReturns: std::fmt::Debug + Send + Sync + 'static {
+    fn to_json(&self) -> serde_json::Result<serde_json::Value>;
+
+    fn from_json(json: serde_json::Value) -> serde_json::Result<Self>
+    where
+        Self: Sized;
+}
+
+impl<T> TaskReturns for T
+where
+    T: serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + Send + Sync + 'static,
+{
+    fn to_json(&self) -> serde_json::Result<serde_json::Value> {
+        serde_json::to_value(self)
+    }
+
+    fn from_json(json: serde_json::Value) -> serde_json::Result<Self> {
+        serde_json::from_value(json)
+    }
+}
+
 /// A `Task` represents a unit of work that a `Celery` app can produce or consume.
 ///
 /// The recommended way to create tasks is through the [`task`](../attr.task.html) attribute macro, not by directly implementing
@@ -64,7 +85,7 @@ pub trait Task: Send + Sync + std::marker::Sized {
     type Params: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>;
 
     /// The return type of the task.
-    type Returns: Send + Sync + std::fmt::Debug;
+    type Returns: TaskReturns;
 
     /// Used to initialize a task instance from a request.
     fn from_request(request: Request<Self>, options: TaskOptions) -> Self;
@@ -116,21 +137,19 @@ pub trait Task: Send + Sync + std::marker::Sized {
                 Some(DateTime::<Utc>::from_naive_utc_and_offset(
                     NaiveDateTime::from_timestamp_opt(eta_secs as i64, now_millis * 1000)
                         .ok_or_else(|| {
-                            TaskError::UnexpectedError(format!(
-                                "Invalid countdown seconds {countdown}",
-                            ))
+                            TaskError::unexpected(format!("Invalid countdown seconds {countdown}"))
                         })?,
                     Utc,
                 ))
             }
             Err(_) => None,
         };
-        Err(TaskError::Retry(eta))
+        Err(TaskError::retry(eta))
     }
 
     /// This can be called from within a task function to trigger a retry at the specified `eta`.
     fn retry_with_eta(&self, eta: DateTime<Utc>) -> TaskResult<Self::Returns> {
-        Err(TaskError::Retry(Some(eta)))
+        Err(TaskError::retry(Some(eta)))
     }
 
     /// Get a future ETA at which time the task should be retried. By default this
@@ -219,6 +238,25 @@ pub(crate) enum TaskStatus {
     Finished,
 }
 
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "UPPERCASE")]
+pub enum TaskState {
+    /// The task is waiting for execution.
+    Pending,
+    /// The task has been started.
+    Started,
+    /// The task is to be retried, possibly because of failure.
+    Retry,
+    /// The task failed with an error, or has exceeded the retry limit.
+    Failure,
+    /// The task executed successfully.
+    Success,
+    /// The task was received by a worker (only used in events).
+    Received,
+    /// The task was revoked.
+    Revoked,
+}
+
 /// Extension methods for `Result` types within a task body.
 ///
 /// These methods can be used to convert a `Result<T, E>` to a `Result<T, TaskError>` with the
@@ -254,10 +292,10 @@ where
     C: std::fmt::Display + Send + Sync + 'static,
 {
     fn with_expected_err<F>(self, f: F) -> Result<T, TaskError> {
-        self.map_err(|e| TaskError::ExpectedError(format!("{} ➥ Cause: {:?}", f(), e)))
+        self.map_err(|e| TaskError::expected(format!("{} ➥ Cause: {:?}", f(), e)))
     }
 
     fn with_unexpected_err<F>(self, f: F) -> Result<T, TaskError> {
-        self.map_err(|e| TaskError::UnexpectedError(format!("{} ➥ Cause: {:?}", f(), e)))
+        self.map_err(|e| TaskError::unexpected(format!("{} ➥ Cause: {:?}", f(), e)))
     }
 }
diff --git a/tests/backends/mod.rs b/tests/backends/mod.rs
new file mode 100644
index 00000000..edb368e0
--- /dev/null
+++ b/tests/backends/mod.rs
@@ -0,0 +1 @@
+mod redis;
diff --git a/tests/backends/redis.rs b/tests/backends/redis.rs
new file mode 100644
index 00000000..ffc58fa8
--- /dev/null
+++ b/tests/backends/redis.rs
@@ -0,0 +1,94 @@
+use std::sync::Arc;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use tokio::time::{self, Duration};
+
+use celery::{prelude::*, task::AsyncResult, Celery};
+
+#[celery::task(name = "successful_task")]
+pub async fn successful_task(a: i32, b: i32) -> TaskResult<i32> {
+    Ok(a + b)
+}
+
+#[celery::task(name = "task_with_retry_failure", max_retries = 5)]
+pub async fn task_with_retry_failure() -> TaskResult<()> {
+    Err(TaskError::retry(Some(
+        chrono::Utc::now() + chrono::Duration::milliseconds(100),
+    )))
+}
+
+#[celery::task(name = "task_with_expires_failure")]
+pub async fn task_with_expires_failure() -> TaskResult<()> {
+    println!("running task_with_expires_failure");
+    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
+    Err(TaskError::expected("failure expected"))
+}
+
+async fn app() -> Result<Arc<Celery>, CeleryError> {
+    celery::app!(
+        broker = AMQPBroker { std::env::var("AMQP_ADDR").unwrap_or_else(|_| "amqp://127.0.0.1:5672//".into()) },
+        backend = RedisBackend { std::env::var("REDIS_ADDR").unwrap_or_else(|_| "redis://127.0.0.1:6379/".into()) },
+        tasks = [successful_task, task_with_retry_failure, task_with_expires_failure],
+        task_routes = ["*" => "celery"],
+        prefetch_count = 2
+    ).await
+}
+
+#[tokio::test]
+async fn test() -> Result<()> {
+    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init();
+    let app = app().await?;
+    let result: Result<AsyncResult<i32>, CeleryError> =
+        app.send_task(successful_task::new(1, 2)).await;
+    let Ok(result) = result else {
+        panic!("Failed to send task");
+    };
+    let _ = time::timeout(Duration::from_secs(1), app.consume()).await;
+    let result = result.get().fetch().await;
+    assert!(matches!(result, Ok(3)));
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_max_retries() -> Result<()> {
+    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init();
+    let app = app().await?;
+    let result: Result<AsyncResult<()>, CeleryError> =
+        app.send_task(task_with_retry_failure::new()).await;
+    let Ok(result) = result else {
+        panic!("Failed to send task");
+    };
+    let _ = time::timeout(Duration::from_secs(2), app.consume()).await;
+    let result = result.get().fetch().await;
+    assert!(matches!(
+        result,
+        Err(CeleryError::TaskError(TaskError {
+            kind: TaskErrorType::MaxRetriesExceeded,
+            ..
+        }))
+    ));
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_expires() -> Result<()> {
+    env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("trace")).init();
+    let app = app().await?;
+    let result: Result<AsyncResult<()>, CeleryError> = app
+        .send_task(task_with_expires_failure::new().with_expires_in(1))
+        .await;
+    let Ok(result) = result else {
+        panic!("Failed to send task");
+    };
+    let _ = time::timeout(Duration::from_secs(2), app.consume()).await;
+    let result = result.get().fetch().await;
+    assert!(matches!(
+        result,
+        Err(CeleryError::TaskError(TaskError {
+            kind: TaskErrorType::Timeout,
+            ..
+        }))
+    ));
+    Ok(())
+}
diff --git a/tests/integrations.rs b/tests/integrations.rs
index 580e24ae..1bce1e4f 100644
--- a/tests/integrations.rs
+++ b/tests/integrations.rs
@@ -1 +1,2 @@
+mod backends;
 mod brokers;