wip: add test to task_manager

refs: #5
This commit is contained in:
Alexander Navarro 2025-05-20 16:49:41 -04:00
parent d87843614a
commit 2c47226dc9
9 changed files with 149 additions and 51 deletions

View file

@ -1,5 +1,6 @@
<component name="ProjectCodeStyleConfiguration"> <component name="ProjectCodeStyleConfiguration">
<state> <state>
<option name="USE_PER_PROJECT_SETTINGS" value="true" /> <option name="USE_PER_PROJECT_SETTINGS" value="true" />
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
</state> </state>
</component> </component>

1
Cargo.lock generated
View file

@ -1066,6 +1066,7 @@ dependencies = [
"tabled", "tabled",
"thiserror", "thiserror",
"tokio", "tokio",
"tokio-stream",
"tracing", "tracing",
"tracing-core", "tracing-core",
"tracing-subscriber", "tracing-subscriber",

View file

@ -6,6 +6,7 @@ edition = "2024"
[dependencies] [dependencies]
directories = "6.0.0" directories = "6.0.0"
tokio = { version = "1.45.0", features = ["default", "rt", "rt-multi-thread", "macros"] } tokio = { version = "1.45.0", features = ["default", "rt", "rt-multi-thread", "macros"] }
tokio-stream = "0.1.17"
sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "chrono", "migrate", "uuid"] } sqlx = { version = "0.8", features = ["runtime-tokio", "sqlite", "chrono", "migrate", "uuid"] }
clap = { version = "4.5.37", features = ["derive"] } clap = { version = "4.5.37", features = ["derive"] }
serde = { version = "1.0.219", features = ["derive"] } serde = { version = "1.0.219", features = ["derive"] }

View file

@ -46,15 +46,12 @@ impl TaskPagination {
pub struct TasksPage<T: TaskPayload> { pub struct TasksPage<T: TaskPayload> {
tasks: Vec<Task<T>>, tasks: Vec<Task<T>>,
page: TaskPagination page: TaskPagination,
} }
impl<T: TaskPayload> TasksPage<T> { impl<T: TaskPayload> TasksPage<T> {
fn new(tasks: Vec<Task<T>>, page: TaskPagination) -> Self { fn new(tasks: Vec<Task<T>>, page: TaskPagination) -> Self {
Self { Self { tasks, page }
tasks,
page
}
} }
pub fn next(&self) -> TaskPagination { pub fn next(&self) -> TaskPagination {
@ -67,11 +64,13 @@ impl<T: TaskPayload> TasksPage<T> {
} }
pub trait TaskStorage<T: TaskPayload> { pub trait TaskStorage<T: TaskPayload> {
async fn insert_tasks<'a, I: IntoIterator<Item=&'a Task<T>>>(&self, tasks: I) -> crate::Result<()>; async fn insert_tasks<'a, I: IntoIterator<Item = &'a Task<T>>>(
&self,
tasks: I,
) -> crate::Result<()>;
fn get_tasks(&self, task_status: TaskStatus) -> impl Stream<Item = crate::Result<Task<T>>>; fn get_tasks(&self, task_status: TaskStatus) -> impl Stream<Item = crate::Result<Task<T>>>;
async fn listen_tasks(&self, task_status: TaskStatus) -> crate::Result<()>; fn listen_tasks(&self, task_status: TaskStatus) -> impl Stream<Item = crate::Result<Task<T>>>;
async fn get_paginated_tasks(&self, page: TaskPagination) -> crate::Result<TasksPage<T>>; async fn get_paginated_tasks(&self, page: TaskPagination) -> crate::Result<TasksPage<T>>;
} }

View file

@ -108,10 +108,11 @@ impl<T: TaskPayload> TaskStorage<T> for Sqlite {
query.fetch(&self.pool).err_into::<crate::Error>() query.fetch(&self.pool).err_into::<crate::Error>()
} }
async fn listen_tasks(&self, task_status: TaskStatus) -> crate::error::Result<()> { fn listen_tasks(&self, task_status: TaskStatus) -> impl Stream<Item=crate::error::Result<Task<T>>> {
todo!() futures::stream::empty()
} }
async fn get_paginated_tasks(&self, page: TaskPagination) -> crate::Result<TasksPage<T>> { async fn get_paginated_tasks(&self, page: TaskPagination) -> crate::Result<TasksPage<T>> {
let mut builder: QueryBuilder<'_, sqlx::Sqlite> = QueryBuilder::new( let mut builder: QueryBuilder<'_, sqlx::Sqlite> = QueryBuilder::new(
"select id, payload_key, payload, status_id, created_at, updated_at from tasks ", "select id, payload_key, payload, status_id, created_at, updated_at from tasks ",

View file

@ -6,6 +6,8 @@ use tabled::Tabled;
mod manager; mod manager;
mod jobs; mod jobs;
mod worker;
mod bus;
#[derive(sqlx::Type, Debug, Clone)] #[derive(sqlx::Type, Debug, Clone)]
#[repr(u8)] #[repr(u8)]
@ -75,9 +77,7 @@ impl<T: TaskPayload> Task<T> {
pub fn payload(&self) -> &T { pub fn payload(&self) -> &T {
&self.payload &self.payload
} }
}
impl<T: TaskPayload> Task<T> {
pub fn get_key(&self) -> String { pub fn get_key(&self) -> String {
self.payload_key.clone() self.payload_key.clone()
} }

View file

@ -0,0 +1,4 @@
#[derive(Clone)]
pub enum Bus {
Local,
}

View file

@ -1,9 +1,14 @@
use crate::database::TaskStorage; use crate::database::TaskStorage;
use crate::tasks::bus::Bus;
use crate::tasks::jobs::TaskJob; use crate::tasks::jobs::TaskJob;
use crate::tasks::{Task, TaskPayload, TaskStatus}; use crate::tasks::{Task, TaskPayload, TaskStatus};
use futures::StreamExt; use futures::StreamExt;
use futures::stream::FuturesOrdered;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::pin::pin;
use tokio::sync::mpsc::Receiver;
use tokio::sync::{mpsc, oneshot};
use tokio::sync::oneshot::Sender;
use crate::tasks::worker::TaskMessage;
pub enum RateLimit { pub enum RateLimit {
Buffer(usize), Buffer(usize),
@ -12,11 +17,12 @@ pub enum RateLimit {
None, None,
} }
pub struct ExecuteOptions { pub struct ManagerOptions {
rate_limit: RateLimit, rate_limit: RateLimit,
bus: Bus,
} }
impl ExecuteOptions { impl ManagerOptions {
pub fn new() -> Self { pub fn new() -> Self {
Self::default() Self::default()
} }
@ -27,22 +33,23 @@ impl ExecuteOptions {
} }
} }
impl Default for ExecuteOptions { impl Default for ManagerOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
rate_limit: RateLimit::None, rate_limit: RateLimit::None,
bus: Bus::Local,
} }
} }
} }
struct TaskManager<S: TaskPayload, T: TaskStorage<S>> { struct TaskManager<S: TaskPayload, T: TaskStorage<S>> {
storage: T, storage: T,
options: ExecuteOptions, options: ManagerOptions,
_marker: PhantomData<S>, _marker: PhantomData<S>,
} }
impl<S: TaskPayload, T: TaskStorage<S>> TaskManager<S, T> { impl<S: TaskPayload, T: TaskStorage<S>> TaskManager<S, T> {
pub fn new(storage: T, options: ExecuteOptions) -> Self { pub fn new(storage: T, options: ManagerOptions) -> Self {
Self { Self {
storage, storage,
options, options,
@ -50,18 +57,31 @@ impl<S: TaskPayload, T: TaskStorage<S>> TaskManager<S, T> {
} }
} }
pub async fn run_tasks(&self, func: TaskJob<S>) -> crate::Result<()> { pub async fn run_tasks(&self, mut task_sink: TaskMessage<S>) -> crate::Result<()> {
let rows = self.storage.get_tasks(TaskStatus::Pending); let rows = self.storage.get_tasks(TaskStatus::Pending);
let listener = self.storage.listen_tasks(TaskStatus::Pending);
let result: Vec<(Task<S>, TaskStatus)> = rows let mut queue = pin!(rows.chain(listener));
.map(async |x| {
let task = x.unwrap();
let status = func(&task);
(task, status) while let Some(task) = queue.next().await {
}) let task = match task {
.collect() Ok(task) => task,
.await; Err(e) => {
continue
}
};
let sink = match task_sink.recv().await {
Some(s) => s,
None => break, // sink has stoped requesting tasks
};
if let Err(_) = sink.send(task) {
continue;
}
// (task, status)
}
Ok(()) Ok(())
} }
@ -71,12 +91,17 @@ impl<S: TaskPayload, T: TaskStorage<S>> TaskManager<S, T> {
mod tests { mod tests {
use super::*; use super::*;
use crate::database::{TaskPagination, TasksPage}; use crate::database::{TaskPagination, TasksPage};
use async_stream::stream;
use fake::{Dummy, Fake, Faker}; use fake::{Dummy, Fake, Faker};
use futures::{Stream, StreamExt}; use futures::Stream;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::Row;
use sqlx::types::Uuid; use sqlx::types::Uuid;
use sync::mpsc;
use tokio::sync;
use tokio_stream::wrappers::ReceiverStream;
use tracing_test::traced_test; use tracing_test::traced_test;
use crate::error::Error;
use crate::tasks::worker::{Worker, WorkerManager};
#[derive(Dummy, Serialize, Deserialize, Debug)] #[derive(Dummy, Serialize, Deserialize, Debug)]
struct DummyTaskPayload { struct DummyTaskPayload {
@ -98,7 +123,7 @@ mod tests {
fn get_tasks( fn get_tasks(
&self, &self,
task_status: TaskStatus, task_status: TaskStatus,
) -> impl Stream<Item = crate::error::Result<Task<DummyTaskPayload>>> { ) -> impl Stream<Item = crate::Result<Task<DummyTaskPayload>>> {
let payloads: Vec<DummyTaskPayload> = Faker.fake(); let payloads: Vec<DummyTaskPayload> = Faker.fake();
let tasks = payloads.into_iter().enumerate().map(move |(i, item)| { let tasks = payloads.into_iter().enumerate().map(move |(i, item)| {
@ -108,24 +133,27 @@ mod tests {
futures::stream::iter(tasks) futures::stream::iter(tasks)
} }
async fn listen_tasks(&self, task_status: TaskStatus) -> crate::error::Result<()> { fn listen_tasks(
todo!() &self,
} task_status: TaskStatus,
async fn listen_tasks2(&self, task_status: TaskStatus) -> FuturesOrdered<impl Future<Output=Task<DummyTaskPayload>> + Sized> { ) -> impl Stream<Item = crate::Result<Task<DummyTaskPayload>>> {
let mut fifo = FuturesOrdered::new(); let (tx, rx) = mpsc::channel::<crate::Result<Task<DummyTaskPayload>>>(10);
tokio::spawn(async move { tokio::spawn(async move {
loop { for _ in 0..10 {
tokio::time::sleep(std::time::Duration::from_millis(250)).await; tokio::time::sleep(std::time::Duration::from_millis(250)).await;
let payload: DummyTaskPayload = Faker.fake(); let payload: DummyTaskPayload = Faker.fake();
let task_status: TaskStatus = task_status.clone(); let task_status: TaskStatus = task_status.clone();
fifo.push_back(async move { let task = Ok(Task::new(payload.key.to_string(), payload, task_status));
Task::new(payload.key.to_string(), payload, task_status)
}); if let Err(_) = tx.send(task).await {
break;
}
} }
}); });
fifo ReceiverStream::new(rx)
} }
async fn get_paginated_tasks( async fn get_paginated_tasks(
@ -136,12 +164,27 @@ mod tests {
} }
} }
struct DummyWorker;
impl Worker<DummyTaskPayload> for DummyWorker {
fn process_job(task: &Task<DummyTaskPayload>) -> crate::error::Result<()> {
println!("{:#?}", task);
Ok(())
}
async fn on_job_failure(task: &Task<DummyTaskPayload>, error: Error) -> crate::error::Result<()> {
println!("{:#?} {:?}", task, error);
Ok(())
}
}
#[tokio::test] #[tokio::test]
#[traced_test] #[traced_test]
async fn manager_runs() { async fn manager_runs() {
let execute_options = ExecuteOptions::new(); let execute_options = ManagerOptions::new();
let manager = TaskManager::new(DummyTaskStorage {}, execute_options); let local_worker_sink = WorkerManager::get_listener_sink::<DummyTaskPayload, DummyWorker>(execute_options.bus.clone());
let task_manager = TaskManager::new(DummyTaskStorage {}, execute_options);
manager.run_tasks(|_| TaskStatus::Completed).await.unwrap(); task_manager.run_tasks(local_worker_sink).await.unwrap()
} }
} }

View file

@ -0,0 +1,48 @@
use crate::error::Error;
use crate::tasks::bus::Bus;
use crate::tasks::{Task, TaskPayload};
use tokio::sync::mpsc::Receiver;
use tokio::sync::oneshot::Sender;
use tokio::sync::{mpsc, oneshot};
pub type TaskMessage<T> = Receiver<Sender<Task<T>>>;
pub struct WorkerManager;
impl WorkerManager {
pub fn get_listener_sink<T: TaskPayload, W: Worker<T>>(bus: Bus) -> TaskMessage<T> {
match bus {
Bus::Local => {
let (bus_tx, bus_rx) = mpsc::channel(100);
tokio::spawn(async move {
loop {
// TODO: properly catch errors
let (tx, rx) = oneshot::channel();
// Request a task
bus_tx.send(tx).await.unwrap();
// Wait for a task to be returned
let task = rx.await.unwrap();
W::process_job(&task).unwrap();
}
});
bus_rx
}
}
}
}
pub trait Worker<T: TaskPayload> {
async fn pre_process_job(task: &Task<T>) -> crate::Result<()> {
Ok(())
}
fn process_job(task: &Task<T>) -> crate::Result<()>;
async fn post_process_job(task: &Task<T>) -> crate::Result<()> {
Ok(())
}
async fn on_job_failure(task: &Task<T>, error: Error) -> crate::Result<()>;
}