feat: load tasks into database

This commit is contained in:
Alexander Navarro 2025-05-07 16:42:25 -04:00
parent 1e3c235b78
commit 7be435332c
5 changed files with 89 additions and 38 deletions

1
.idea/sqldialects.xml generated
View file

@@ -2,7 +2,6 @@
<project version="4"> <project version="4">
<component name="SqlDialectMappings"> <component name="SqlDialectMappings">
<file url="file://$PROJECT_DIR$/migrations" dialect="SQLite" /> <file url="file://$PROJECT_DIR$/migrations" dialect="SQLite" />
<file url="file://$PROJECT_DIR$/migrations/002_statuses.sql" dialect="SQLite" />
<file url="PROJECT" dialect="SQLite" /> <file url="PROJECT" dialect="SQLite" />
</component> </component>
</project> </project>

View file

@@ -5,7 +5,7 @@ create table tasks
primary key autoincrement, primary key autoincrement,
payload_key ANY not null payload_key ANY not null
constraint tasks_payload_key constraint tasks_payload_key
unique, unique on conflict ignore,
payload TEXT not null, payload TEXT not null,
status_id integer not null status_id integer not null
constraint tasks_task_statuses_id_fk constraint tasks_task_statuses_id_fk

View file

@ -2,7 +2,7 @@ use std::fs::File;
use clap::Parser; use clap::Parser;
use readwise_bulk_upload::config::Args; use readwise_bulk_upload::config::Args;
use readwise_bulk_upload::readwise::DocumentPayload; use readwise_bulk_upload::readwise::DocumentPayload;
use readwise_bulk_upload::sql::get_database; use readwise_bulk_upload::sql::{TaskManager};
use readwise_bulk_upload::{Error, Result}; use readwise_bulk_upload::{Error, Result};
#[tokio::main] #[tokio::main]
@@ -17,7 +17,9 @@ async fn main() -> Result<()> {
let documents: Vec<DocumentPayload> = serde_json::from_reader(file)?; let documents: Vec<DocumentPayload> = serde_json::from_reader(file)?;
let db = get_database().await?; let task_manager = TaskManager::new().await?;
task_manager.load_tasks(documents).await?;
Ok(()) Ok(())
} }

View file

@@ -1,8 +1,9 @@
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};
use serde::{Deserialize, Deserializer, de}; use serde::{Deserialize, Deserializer, de, Serialize};
use serde_json::Value; use serde_json::Value;
use crate::sql::TaskPayload;
#[derive(Deserialize)] #[derive(Deserialize, Serialize, Debug)]
pub struct DocumentPayload { pub struct DocumentPayload {
title: String, title: String,
summary: Option<String>, summary: Option<String>,
@@ -13,20 +14,10 @@ pub struct DocumentPayload {
location: String, location: String,
} }
fn str_to_int<'de, D: Deserializer<'de>>(deserializer: D) -> Result<u64, D::Error> { impl TaskPayload for DocumentPayload {
Ok(match Value::deserialize(deserializer)? { fn get_key(&self) -> String {
Value::String(s) => s.parse().map_err(de::Error::custom)?, self.url.clone()
Value::Number(num) => num.as_u64().ok_or(de::Error::custom("Invalid number"))?,
_ => return Err(de::Error::custom("wrong type")),
})
} }
fn str_to_bool<'de, D: Deserializer<'de>>(deserializer: D) -> Result<bool, D::Error> {
Ok(match Value::deserialize(deserializer)? {
Value::String(s) => s.parse().map_err(de::Error::custom)?,
Value::Bool(b) => b,
_ => return Err(de::Error::custom("wrong type")),
})
} }
fn single_or_vec<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> { fn single_or_vec<'de, D: Deserializer<'de>>(deserializer: D) -> Result<Vec<String>, D::Error> {

View file

@@ -1,10 +1,37 @@
use crate::Error; use crate::Error;
use directories::ProjectDirs; use directories::ProjectDirs;
use serde::Serialize;
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
use sqlx::SqlitePool; use sqlx::{QueryBuilder, Sqlite, SqlitePool};
use tokio::fs; use tokio::fs;
pub async fn get_database() -> crate::Result<SqlitePool> { static SQLITE_QUERY_LIMIT: usize = 32766;
#[derive(sqlx::Type)]
#[repr(u8)]
pub enum TaskStatus {
Pending = 1,
InProgress = 2,
Completed = 3,
Failed = 4,
}
pub trait TaskPayload {
fn get_key(&self) -> String;
}
pub struct TaskManager {
pool: SqlitePool,
}
impl TaskManager {
pub async fn new() -> Result<TaskManager, Error> {
Ok(Self {
pool: Self::connect_database().await?,
})
}
async fn connect_database() -> crate::Result<SqlitePool> {
let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME")) let project_dir = ProjectDirs::from("", "", env!("CARGO_PKG_NAME"))
.ok_or(Error::Unhandled("Could not get standard directories"))?; .ok_or(Error::Unhandled("Could not get standard directories"))?;
@@ -23,3 +50,35 @@ pub async fn get_database() -> crate::Result<SqlitePool> {
Ok(pool) Ok(pool)
} }
pub async fn load_tasks<T>(&self, values: Vec<T>) -> crate::Result<()>
where
T: TaskPayload + Serialize,
{
let mut builder: QueryBuilder<'_, Sqlite> =
QueryBuilder::new("insert into tasks(payload_key, payload, status_id)");
let args: crate::Result<Vec<(String, String)>> = values
.iter()
.map(|value| Ok((value.get_key(), serde_json::to_string(value)?)))
.collect();
builder.push_values(
args?.into_iter().take(SQLITE_QUERY_LIMIT / 3),
|mut builder, item| {
builder
.push_bind(item.0)
.push_bind(item.1)
.push_bind(TaskStatus::Pending);
},
);
builder.push("ON conflict (payload_key) DO NOTHING");
let query = builder.build();
query.execute(&self.pool).await?;
Ok(())
}
}