implement backend integration + redis backend
- Care was put into being able to run Rust Celery both as a worker and as a client alongside Python workers and clients.
  - To support this, TaskError now has a format that captures Python exceptions and lets Python report Rust errors (see the task-meta sketch below).
- Only the application/json content type is supported for now.
- The interface for retrieving task results via AsyncResult is minimal; only a timeout option is currently supported.
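
As a rough illustration of what that interop implies, the sketch below constructs the kind of task-meta entry Python Celery writes to a Redis result backend for a failed task (stored under a key like `celery-task-meta-<task_id>`). The field names follow stock Python Celery's JSON exception serialization; whether this commit reproduces exactly these fields is an assumption, not something shown in the diff.

```rust
// Sketch only: approximate shape of a FAILURE task-meta entry as Python
// Celery stores it. Field names follow Python Celery's JSON exception
// format; the exact payload written by this commit is an assumption.
use serde_json::json;

fn example_failure_meta() -> serde_json::Value {
    json!({
        "status": "FAILURE",
        "task_id": "a1b2c3d4-0000-0000-0000-000000000000", // hypothetical id
        "result": {
            "exc_type": "ValueError",    // Python exception class
            "exc_module": "builtins",
            "exc_message": ["boom"],     // exception args
        },
        "traceback": null,
        "children": [],
    })
}
```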

To use the Redis backend, a Celery app can be instantiated like this:

```rust
celery::app!(
    broker = RedisBroker { std::env::var("REDIS_ADDR").unwrap() },
    backend = RedisBackend { std::env::var("REDIS_ADDR").unwrap() },
    tasks = [add, expected_failure, unexpected_failure, task_with_timeout],
    task_routes = [
        "*" => "celery",
    ],
)
.await
```
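
The same app instance can then run either as a worker or as a client. A minimal sketch of the two modes, using the `consume` and `send_task` calls that also appear in the README diff below; the `app` binding and the `tasks::add` path follow the fetch example further down and are assumptions about the surrounding example code:

```rust
// As a worker: consume tasks from the configured "celery" queue.
// (Worker and client are typically separate processes/binaries.)
app.consume().await?;

// As a client: enqueue a task for a Rust or Python worker to pick up.
app.send_task(tasks::add::new(1, 2)).await?;
```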

The result can then be fetched using AsyncResult:

```rust
// Task definition (registered with the app via `tasks = [add]` above).
#[celery::task]
pub async fn add(x: i32, y: i32) -> TaskResult<i32> {
    info!("adding {} + {}", x, y);
    Ok(x + y)
}

// In the client: send the task, then fetch its result from the backend.
let task = tasks::add::new(1, 2);
let result = app.send_task(task).await.expect("send task");
assert_eq!(3, result.get().fetch().await.expect("get task"));
```
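
Since only a timeout option is currently supported when fetching, bounding the wait presumably looks something like the sketch below; the `with_timeout` method name on the builder returned by `get()` is an assumption, not confirmed by this diff:

```rust
use std::time::Duration;

// Sketch: bound how long the client waits for the result to show up in the
// Redis backend. `with_timeout` is an assumed name for the single option the
// commit message says is supported.
let result = app.send_task(tasks::add::new(1, 2)).await.expect("send task");
let sum = result
    .get()
    .with_timeout(Duration::from_secs(5)) // assumption: exact setter name not shown here
    .fetch()
    .await
    .expect("get task");
assert_eq!(3, sum);
```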
rksm committed Dec 26, 2023
1 parent 8348f72 commit ff102ad
Showing 16 changed files with 854 additions and 89 deletions.
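
Besides the `celery::app!` macro, the diff below also threads the backend through `CeleryBuilder`. A sketch of how the new builder parameter and options might be used; the URLs and collection name are placeholders, while `CeleryBuilder::new`'s optional backend argument, `backend_connection_timeout`, and `backend_taskmeta_collection` come from the src/app/mod.rs changes below:

```rust
// Sketch based on the builder changes in src/app/mod.rs: the third argument
// to CeleryBuilder::new is the optional backend URL; the backend_* options
// are added in this commit. URLs and names here are placeholders.
let app = CeleryBuilder::new(
        "my-app",
        "amqp://127.0.0.1:5672/",         // broker URL (placeholder)
        Some("redis://127.0.0.1:6379/0"), // backend URL (placeholder)
    )
    .backend_connection_timeout(2)
    .backend_taskmeta_collection("celery_taskmeta")
    .build()
    .await?;
```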
15 changes: 12 additions & 3 deletions README.md
@@ -51,6 +51,7 @@ and register your tasks with it:
```rust
let my_app = celery::app!(
broker = AMQPBroker { std::env::var("AMQP_ADDR").unwrap() },
backend = RedisBackend { std::env::var("REDIS_ADDR").unwrap() },
tasks = [add],
task_routes = [
"*" => "celery",
@@ -61,7 +62,7 @@ let my_app = celery::app!(
Then send tasks to a queue with

```rust
my_app.send_task(add::new(1, 2)).await?;
let result = my_app.send_task(add::new(1, 2)).await?;
```

And consume tasks as a worker from a queue with
@@ -70,6 +71,14 @@ And consume tasks as a worker from a queue with
my_app.consume().await?;
```

In a Rust client you can wait for the result of a task with


``` rust
let result = result.get().fetch().await.expect("get task");
assert_eq!(3, result);
```

## Examples

The [`examples/`](https://github.com/rusty-celery/rusty-celery/tree/main/examples) directory contains:
@@ -152,7 +161,7 @@ And then you can consume tasks from Rust or Python as explained above.
| Consumers || |
| Brokers || |
| Beat || |
| Backends | 🔴 | |
| Backends | ⚠️ | |
| [Baskets](https://github.com/rusty-celery/rusty-celery/issues/53) | 🔴 | |

### Brokers
@@ -167,4 +176,4 @@ And then you can consume tasks from Rust or Python as explained above.
| | Status | Tracking |
| ----------- |:------:| -------- |
| RPC | 🔴 | [![](https://img.shields.io/github/issues/rusty-celery/rusty-celery/Backend%3A%20RPC?label=Issues)](https://github.com/rusty-celery/rusty-celery/labels/Backend%3A%20RPC) |
| Redis | 🔴 | [![](https://img.shields.io/github/issues/rusty-celery/rusty-celery/Backend%3A%20Redis?label=Issues)](https://github.com/rusty-celery/rusty-celery/labels/Backend%3A%20Redis) |
| Redis | | [![](https://img.shields.io/github/issues/rusty-celery/rusty-celery/Backend%3A%20Redis?label=Issues)](https://github.com/rusty-celery/rusty-celery/labels/Backend%3A%20Redis) |
14 changes: 10 additions & 4 deletions examples/celery_app.rs
@@ -111,11 +111,17 @@ async fn main() -> Result<()> {
} else {
for task in tasks {
match task.as_str() {
"add" => my_app.send_task(add::new(1, 2)).await?,
"bound_task" => my_app.send_task(bound_task::new()).await?,
"buggy_task" => my_app.send_task(buggy_task::new()).await?,
"add" => {
my_app.send_task(add::new(1, 2)).await?;
}
"bound_task" => {
my_app.send_task(bound_task::new()).await?;
}
"buggy_task" => {
my_app.send_task(buggy_task::new()).await?;
}
"long_running_task" => {
my_app.send_task(long_running_task::new(Some(3))).await?
my_app.send_task(long_running_task::new(Some(3))).await?;
}
_ => panic!("unknown task"),
};
70 changes: 61 additions & 9 deletions src/app/mod.rs
@@ -17,6 +17,7 @@ use tokio_stream::StreamMap;

mod trace;

use crate::backend::{backend_builder_from_url, Backend, BackendBuilder};
use crate::broker::{
broker_builder_from_url, build_and_connect, configure_task_routes, Broker, BrokerBuilder,
Delivery,
@@ -31,6 +32,7 @@ struct Config {
name: String,
hostname: String,
broker_builder: Box<dyn BrokerBuilder>,
backend_builder: Option<Box<dyn BackendBuilder>>,
broker_connection_timeout: u32,
broker_connection_retry: bool,
broker_connection_max_retries: u32,
@@ -47,7 +49,7 @@ pub struct CeleryBuilder {

impl CeleryBuilder {
/// Get a [`CeleryBuilder`] for creating a [`Celery`] app with a custom configuration.
pub fn new(name: &str, broker_url: &str) -> Self {
pub fn new(name: &str, broker_url: &str, backend_url: Option<impl AsRef<str>>) -> Self {
Self {
config: Config {
name: name.into(),
@@ -60,6 +62,7 @@ impl CeleryBuilder {
.unwrap_or_else(|| "unknown".into())
),
broker_builder: broker_builder_from_url(broker_url),
backend_builder: backend_url.map(backend_builder_from_url),
broker_connection_timeout: 2,
broker_connection_retry: true,
broker_connection_max_retries: 5,
@@ -189,6 +192,24 @@ impl CeleryBuilder {
self
}

/// Set connection timeout for backend.
pub fn backend_connection_timeout(mut self, timeout: u32) -> Self {
self.config.backend_builder = self
.config
.backend_builder
.map(|builder| builder.connection_timeout(timeout));
self
}

/// Set backend task meta collection name.
pub fn backend_taskmeta_collection(mut self, collection_name: &str) -> Self {
self.config.backend_builder = self
.config
.backend_builder
.map(|builder| builder.taskmeta_collection(collection_name));
self
}

/// Construct a [`Celery`] app with the current configuration.
pub async fn build(self) -> Result<Celery, CeleryError> {
// Declare default queue to broker.
@@ -212,10 +233,16 @@
)
.await?;

let backend = match self.config.backend_builder {
Some(backend_builder) => Some(backend_builder.build().await?),
None => None,
};

Ok(Celery {
name: self.config.name,
hostname: self.config.hostname,
broker,
backend,
default_queue: self.config.default_queue,
task_options: self.config.task_options,
task_routes,
@@ -240,6 +267,9 @@ pub struct Celery {
/// The app's broker.
pub broker: Box<dyn Broker>,

/// The backend to use for storing task results.
pub backend: Option<Arc<dyn Backend>>,

/// The default queue to send and receive from.
pub default_queue: String,

@@ -302,7 +332,7 @@ impl Celery {
pub async fn send_task<T: Task>(
&self,
mut task_sig: Signature<T>,
) -> Result<AsyncResult, CeleryError> {
) -> Result<AsyncResult<T::Returns>, CeleryError> {
task_sig.options.update(&self.task_options);
let maybe_queue = task_sig.queue.take();
let queue = maybe_queue.as_deref().unwrap_or_else(|| {
@@ -316,7 +346,15 @@
queue,
);
self.broker.send(&message, queue).await?;
Ok(AsyncResult::new(message.task_id()))

if let Some(backend) = &self.backend {
backend.add_task(message.task_id()).await?;
}

Ok(AsyncResult::<T::Returns>::new(
message.task_id(),
self.backend.clone(),
))
}

/// Register a task.
@@ -428,12 +466,26 @@ impl Celery {
// NOTE: we don't need to log errors from the trace here since the tracer
// handles all errors at its own level or the task level. In this function
// we only log errors at the broker and delivery level.
if let Err(TraceError::Retry(retry_eta)) = tracer.trace().await {
// If retry error -> retry the task.
self.broker
.retry(delivery.as_ref(), retry_eta)
.await
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
match tracer.trace().await {
Err(TraceError::Retry(retry_eta)) => {
// If retry error -> retry the task.
self.broker
.retry(delivery.as_ref(), retry_eta)
.await
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
}

result => {
if let Some(backend) = self.backend.as_ref() {
backend
.store_result(
tracer.task_id(),
result.map_err(|err| err.into_task_error()),
)
.await
.map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
}
}
}

// If we have not done it before, we have to acknowledge the message now.
4 changes: 2 additions & 2 deletions src/app/tests.rs
@@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
use std::time::SystemTime;

async fn build_basic_app() -> Celery {
let celery = CeleryBuilder::new("mock-app", "mock://localhost:8000")
let celery = CeleryBuilder::new("mock-app", "mock://localhost:8000", Option::<&str>::None)
.build()
.await
.unwrap();
@@ -18,7 +18,7 @@ async fn build_basic_app() -> Celery {
}

async fn build_configured_app() -> Celery {
let celery = CeleryBuilder::new("mock-app", "mock://localhost:8000")
let celery = CeleryBuilder::new("mock-app", "mock://localhost:8000", Option::<&str>::None)
.task_time_limit(10)
.task_max_retries(100)
.task_content_type(MessageContentType::Yaml)
84 changes: 36 additions & 48 deletions src/app/trace.rs
@@ -5,8 +5,9 @@ use tokio::sync::mpsc::UnboundedSender;
use tokio::time::{self, Duration, Instant};

use crate::error::{ProtocolError, TaskError, TraceError};
use crate::prelude::TaskErrorType;
use crate::protocol::Message;
use crate::task::{Request, Task, TaskEvent, TaskOptions, TaskStatus};
use crate::task::{Request, Task, TaskEvent, TaskOptions, TaskReturns, TaskStatus};
use crate::Celery;

/// A `Tracer` provides the API through which a `Celery` application interacts with its tasks.
@@ -48,7 +49,7 @@ impl<T> TracerTrait for Tracer<T>
where
T: Task,
{
async fn trace(&mut self) -> Result<(), TraceError> {
async fn trace(&mut self) -> Result<Box<dyn TaskReturns>, TraceError> {
if self.is_expired() {
warn!(
"Task {}[{}] expired, discarding",
@@ -60,10 +61,10 @@

self.event_tx
.send(TaskEvent::StatusChange(TaskStatus::Pending))
.unwrap_or_else(|_| {
.unwrap_or_else(|err| {
// This really shouldn't happen. If it does, there's probably much
// bigger things to worry about like running out of memory.
error!("Failed sending task event");
error!("Failed sending task event {err:?}");
});

let start = Instant::now();
@@ -73,7 +74,7 @@
let duration = Duration::from_secs(secs as u64);
time::timeout(duration, self.task.run(self.task.request().params.clone()))
.await
.unwrap_or(Err(TaskError::TimeoutError))
.unwrap_or(Err(TaskError::timeout()))
}
None => self.task.run(self.task.request().params.clone()).await,
};
Expand All @@ -94,58 +95,36 @@ where

self.event_tx
.send(TaskEvent::StatusChange(TaskStatus::Finished))
.unwrap_or_else(|_| {
error!("Failed sending task event");
.unwrap_or_else(|err| {
error!("Failed sending task event {err:?}");
});

Ok(())
Ok(Box::new(returned))
}

Err(e) => {
let (should_retry, retry_eta) = match e {
TaskError::ExpectedError(ref reason) => {
warn!(
"Task {}[{}] failed with expected error: {}",
self.task.name(),
&self.task.request().id,
reason
);
(true, None)
}
TaskError::UnexpectedError(ref reason) => {
error!(
"Task {}[{}] failed with unexpected error: {}",
self.task.name(),
&self.task.request().id,
reason
);
(self.task.retry_for_unexpected(), None)
}
TaskError::TimeoutError => {
error!(
"Task {}[{}] timed out after {}s",
self.task.name(),
&self.task.request().id,
duration.as_secs_f32(),
);
(true, None)
}
TaskError::Retry(eta) => {
error!(
"Task {}[{}] triggered retry",
self.task.name(),
&self.task.request().id,
);
(true, eta)
}
error!(
"Task {}[{}] failed: {e}",
self.task.name(),
&self.task.request().id
);

let (should_retry, retry_eta) = match e.kind {
TaskErrorType::Retry(eta) => (true, eta),
TaskErrorType::MaxRetriesExceeded => (false, None),
TaskErrorType::Expected => (true, None),
TaskErrorType::Unexpected => (self.task.retry_for_unexpected(), None),
TaskErrorType::Timeout => (true, None),
TaskErrorType::Other => (true, None),
};

// Run failure callback.
self.task.on_failure(&e).await;

self.event_tx
.send(TaskEvent::StatusChange(TaskStatus::Finished))
.unwrap_or_else(|_| {
error!("Failed sending task event");
.unwrap_or_else(|err| {
error!("Failed sending task event {err:?}");
});

if !should_retry {
Expand All @@ -160,7 +139,10 @@ where
self.task.name(),
&self.task.request().id,
);
return Err(TraceError::TaskError(e));
return Err(TraceError::TaskError(TaskError::max_retries_exceeded(
self.task.name(),
&self.task.request().id,
)));
}
info!(
"Task {}[{}] retrying ({} / {})",
@@ -202,13 +184,17 @@
fn acks_late(&self) -> bool {
self.task.acks_late()
}

fn task_id(&self) -> &str {
&self.task.request().id
}
}

#[async_trait]
pub(super) trait TracerTrait: Send + Sync {
/// Wraps the execution of a task, catching and logging errors and then running
/// the appropriate post-execution functions.
async fn trace(&mut self) -> Result<(), TraceError>;
async fn trace(&mut self) -> Result<Box<dyn TaskReturns>, TraceError>;

/// Wait until the task is due.
async fn wait(&self);
Expand All @@ -218,6 +204,8 @@ pub(super) trait TracerTrait: Send + Sync {
fn is_expired(&self) -> bool;

fn acks_late(&self) -> bool;

fn task_id(&self) -> &str;
}

pub(super) type TraceBuilderResult = Result<Box<dyn TracerTrait>, ProtocolError>;