From ee083dfc45d91a4dc92ebe7d546352f419ca6e02 Mon Sep 17 00:00:00 2001 From: yggverse Date: Sat, 10 Jan 2026 01:41:20 +0200 Subject: [PATCH] optimize db api --- crates/crawler/Cargo.toml | 2 +- crates/crawler/src/main.rs | 46 +++++---- crates/http/src/main.rs | 97 +++++++++++++------ crates/llm/Cargo.toml | 2 +- crates/llm/src/main.rs | 88 ++++++++--------- crates/mysql/Cargo.toml | 5 +- .../mysql/src/{pollable.rs => connection.rs} | 63 ++++++------ crates/mysql/src/lib.rs | 41 ++++++-- crates/mysql/src/pollable/sort.rs | 13 --- crates/mysql/src/table.rs | 14 +++ .../src/{transactional.rs => transaction.rs} | 36 ++----- 11 files changed, 215 insertions(+), 192 deletions(-) rename crates/mysql/src/{pollable.rs => connection.rs} (63%) delete mode 100644 crates/mysql/src/pollable/sort.rs rename crates/mysql/src/{transactional.rs => transaction.rs} (81%) diff --git a/crates/crawler/Cargo.toml b/crates/crawler/Cargo.toml index 1de8c3d..d531744 100644 --- a/crates/crawler/Cargo.toml +++ b/crates/crawler/Cargo.toml @@ -14,7 +14,7 @@ anyhow = "1.0.100" chrono = "0.4.42" clap = { version = "4.5.54", features = ["derive"] } log = "0.4.29" -mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transactional"], path = "../mysql" } +mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transaction"], path = "../mysql" } reqwest = { version = "0.13.1", features = ["blocking"] } rss = "2.0.12" scraper = { version = "0.25.0", features = ["serde"] } diff --git a/crates/crawler/src/main.rs b/crates/crawler/src/main.rs index 25c8279..0d0867a 100644 --- a/crates/crawler/src/main.rs +++ b/crates/crawler/src/main.rs @@ -3,11 +3,9 @@ mod config; use anyhow::Result; use log::{debug, info, warn}; -use mysql::Transactional; use reqwest::blocking::get; fn main() -> Result<()> { - use argument::Argument; use chrono::Local; use clap::Parser; use std::{env::var, fs::read_to_string}; @@ -26,29 +24,29 @@ fn main() -> Result<()> { .init() } - let argument = Argument::parse(); + let argument = argument::Argument::parse(); let config: config::Config = toml::from_str(&read_to_string(argument.config)?)?; + let db = mysql::Database::pool( + &config.mysql.host, + config.mysql.port, + &config.mysql.user, + &config.mysql.password, + &config.mysql.database, + )?; info!("Crawler started"); loop { debug!("Begin new crawl queue..."); - { - // disconnect from the database immediately when exiting this scope, - // in case the `update` queue is enabled and pending for a while. - let mut db = Transactional::connect( - &config.mysql.host, - config.mysql.port, - &config.mysql.user, - &config.mysql.password, - &config.mysql.database, - )?; - for c in &config.channel { - debug!("Update `{}`...", c.url); - if let Err(e) = crawl(&mut db, c) { - warn!("Channel `{}` update failed: `{e}`", c.url) + for c in &config.channel { + debug!("Update `{}`...", c.url); + let mut tx = db.transaction()?; + match crawl(&mut tx, c) { + Ok(()) => tx.commit()?, + Err(e) => { + warn!("Channel `{}` update failed: `{e}`", c.url); + tx.rollback()? } } - db.commit()? } debug!("Crawl queue completed"); if let Some(update) = config.update { @@ -60,7 +58,7 @@ fn main() -> Result<()> { } } -fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()> { +fn crawl(tx: &mut mysql::Transaction, channel_config: &config::Channel) -> Result<()> { use rss::Channel; use scraper::Selector; @@ -87,9 +85,9 @@ fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()> let channel_items_limit = channel_config.items_limit.unwrap_or(channel_items.len()); - let channel_id = match db.channel_id_by_url(&channel_url)? { + let channel_id = match tx.channel_id_by_url(&channel_url)? { Some(channel_id) => channel_id, - None => db.insert_channel(&channel_url)?, + None => tx.insert_channel(&channel_url)?, }; for channel_item in channel_items.iter().take(channel_items_limit) { @@ -120,10 +118,10 @@ fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()> continue; } }; - if db.channel_items_total_by_channel_id_guid(channel_id, guid)? > 0 { + if tx.channel_items_total_by_channel_id_guid(channel_id, guid)? > 0 { continue; // skip next steps as processed } - let channel_item_id = db.insert_channel_item( + let channel_item_id = tx.insert_channel_item( channel_id, pub_date, guid, @@ -188,7 +186,7 @@ fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()> } }, }; - let _content_id = db.insert_content(channel_item_id, None, &title, &description)?; + let _content_id = tx.insert_content(channel_item_id, None, &title, &description)?; // @TODO preload media } Ok(()) diff --git a/crates/http/src/main.rs b/crates/http/src/main.rs index 3964aea..0803379 100644 --- a/crates/http/src/main.rs +++ b/crates/http/src/main.rs @@ -11,7 +11,7 @@ use config::Config; use feed::Feed; use global::Global; use meta::Meta; -use mysql::{Pollable, pollable::Sort}; +use mysql::{Database, table::Sort}; use rocket::{State, http::Status, response::content::RawXml, serde::Serialize}; use rocket_dyn_templates::{Template, context}; @@ -19,7 +19,7 @@ use rocket_dyn_templates::{Template, context}; fn index( search: Option<&str>, page: Option, - db: &State, + db: &State, meta: &State, global: &State, ) -> Result { @@ -32,7 +32,11 @@ fn index( time: String, title: String, } - let total = db + let mut conn = db.connection().map_err(|e| { + error!("Could not connect database: `{e}`"); + Status::InternalServerError + })?; + let total = conn .contents_total_by_provider_id(global.provider_id, search) .map_err(|e| { error!("Could not get contents total: `{e}`"); @@ -65,19 +69,24 @@ fn index( back: page.map(|p| uri!(index(search, if p > 2 { Some(p - 1) } else { None }))), next: if page.unwrap_or(1) * global.list_limit >= total { None } else { Some(uri!(index(search, Some(page.map_or(2, |p| p + 1))))) }, - rows: db.contents_by_provider_id(global.provider_id, search, Sort::Desc, Some(global.list_limit)).map_err(|e| { - error!("Could not get contents: `{e}`"); - Status::InternalServerError - })? + rows: conn.contents_by_provider_id( + global.provider_id, + search, + Sort::Desc, + Some(global.list_limit) + ).map_err(|e| { + error!("Could not get contents: `{e}`"); + Status::InternalServerError + })? .into_iter() - .map(|c| { - let channel_item = db.channel_item(c.channel_item_id).unwrap().unwrap(); + .map(|content| { + let channel_item = conn.channel_item(content.channel_item_id).unwrap().unwrap(); Content { - content_id: c.content_id, - description: c.description, + content_id: content.content_id, + description: content.description, link: channel_item.link, time: time(channel_item.pub_date).format(&global.format_time).to_string(), - title: c.title, + title: content.title, } }) .collect::>(), @@ -92,25 +101,38 @@ fn index( #[get("/")] fn info( content_id: u64, - db: &State, + db: &State, meta: &State, global: &State, ) -> Result { - match db.content(content_id).map_err(|e| { + let mut conn = db.connection().map_err(|e| { + error!("Could not connect database: `{e}`"); + Status::InternalServerError + })?; + match conn.content(content_id).map_err(|e| { error!("Could not get content `{content_id}`: `{e}`"); Status::InternalServerError })? { - Some(c) => { - let i = db.channel_item(c.channel_item_id).unwrap().unwrap(); + Some(content) => { + let channel_item = conn + .channel_item(content.channel_item_id) + .map_err(|e| { + error!("Could not get requested channel item: `{e}`"); + Status::InternalServerError + })? + .ok_or_else(|| { + error!("Could not find requested channel item"); + Status::NotFound + })?; Ok(Template::render( "info", context! { - description: c.description, - link: i.link, + description: content.description, + link: channel_item.link, meta: meta.inner(), - title: format!("{}{S}{}", c.title, meta.title), - name: c.title, - time: time(i.pub_date).format(&global.format_time).to_string(), + title: format!("{}{S}{}", content.title, meta.title), + name: content.title, + time: time(channel_item.pub_date).format(&global.format_time).to_string(), }, )) } @@ -123,30 +145,43 @@ fn rss( search: Option<&str>, global: &State, meta: &State, - db: &State, + db: &State, ) -> Result, Status> { - let mut f = Feed::new( + let mut feed = Feed::new( &meta.title, meta.description.as_deref(), 1024, // @TODO ); - for c in db + let mut conn = db.connection().map_err(|e| { + error!("Could not connect database: `{e}`"); + Status::InternalServerError + })?; + for content in conn .contents_by_provider_id(global.provider_id, search, Sort::Desc, Some(20)) // @TODO .map_err(|e| { error!("Could not load channel item contents: `{e}`"); Status::InternalServerError })? { - let channel_item = db.channel_item(c.channel_item_id).unwrap().unwrap(); - f.push( - c.channel_item_id, + let channel_item = conn + .channel_item(content.channel_item_id) + .map_err(|e| { + error!("Could not get requested channel item: `{e}`"); + Status::InternalServerError + })? + .ok_or_else(|| { + error!("Could not find requested channel item"); + Status::NotFound + })?; + feed.push( + content.channel_item_id, time(channel_item.pub_date), channel_item.link, - c.title, - c.description, + content.title, + content.description, ) } - Ok(RawXml(f.commit())) + Ok(RawXml(feed.commit())) } #[launch] @@ -165,7 +200,7 @@ fn rocket() -> _ { } }) .manage( - Pollable::connect( + Database::pool( &config.mysql_host, config.mysql_port, &config.mysql_username, diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index a5fa968..3e371e8 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -15,6 +15,6 @@ chrono = "0.4.42" clap = { version = "4.5.54", features = ["derive"] } lancor = "0.1.1" log = "0.4.29" -mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transactional"], path = "../mysql" } +mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transaction"], path = "../mysql" } tokio = { version = "1.0", features = ["full"] } tracing-subscriber = { version = "0.3.22", features = ["env-filter"] } \ No newline at end of file diff --git a/crates/llm/src/main.rs b/crates/llm/src/main.rs index dbcfd59..1774e00 100644 --- a/crates/llm/src/main.rs +++ b/crates/llm/src/main.rs @@ -2,7 +2,7 @@ mod argument; use anyhow::Result; use argument::Argument; -use mysql::Transactional; +use mysql::Database; #[tokio::main] async fn main() -> Result<()> { @@ -32,12 +32,17 @@ async fn main() -> Result<()> { "{}://{}:{}", arg.llm_scheme, arg.llm_host, arg.llm_port ))?; + let db = Database::pool( + &arg.mysql_host, + arg.mysql_port, + &arg.mysql_username, + &arg.mysql_password, + &arg.mysql_database, + )?; - // find existing ID by model name or create a new one - // * this feature should be moved to a separate CLI tool let provider_id = { - let mut db = tx(&arg)?; - match db.provider_id_by_name(&arg.llm_model)? { + let mut conn = db.connection()?; + match conn.provider_id_by_name(&arg.llm_model)? { Some(provider_id) => { debug!( "Use existing DB provider #{} matches model name `{}`", @@ -46,12 +51,11 @@ async fn main() -> Result<()> { provider_id } None => { - let provider_id = db.insert_provider(&arg.llm_model)?; + let provider_id = conn.insert_provider(&arg.llm_model)?; info!( "Provider `{}` not found in database, created new one with ID `{provider_id}`", &arg.llm_model ); - db.commit()?; provider_id } } @@ -60,42 +64,38 @@ async fn main() -> Result<()> { info!("Daemon started"); loop { debug!("New queue begin..."); - { - // disconnect from the database immediately when exiting this scope, - // in case the `update` queue is enabled and pending for a while. - let mut db = tx(&arg)?; - for source in db.contents_queue_for_provider_id(provider_id)? { - debug!( - "Begin generating `content_id` #{} using `provider_id` #{provider_id}.", - source.content_id - ); + let mut tx = db.transaction()?; + for source in tx.contents_queue_for_provider_id(provider_id)? { + debug!( + "Begin generating `content_id` #{} using `provider_id` #{provider_id}.", + source.content_id + ); - let title = llm - .chat_completion(ChatCompletionRequest::new(&arg.llm_model).message( - Message::user(format!("{}\n{}", arg.llm_message, source.title)), - )) - .await?; + let title = + llm.chat_completion(ChatCompletionRequest::new(&arg.llm_model).message( + Message::user(format!("{}\n{}", arg.llm_message, source.title)), + )) + .await?; - let description = llm - .chat_completion(ChatCompletionRequest::new(&arg.llm_model).message( - Message::user(format!("{}\n{}", arg.llm_message, source.description)), - )) - .await?; + let description = + llm.chat_completion(ChatCompletionRequest::new(&arg.llm_model).message( + Message::user(format!("{}\n{}", arg.llm_message, source.description)), + )) + .await?; - let content_id = db.insert_content( - source.channel_item_id, - Some(provider_id), - &title.choices[0].message.content, - &description.choices[0].message.content, - )?; + let content_id = tx.insert_content( + source.channel_item_id, + Some(provider_id), + &title.choices[0].message.content, + &description.choices[0].message.content, + )?; - debug!( - "Created `content_id` #{content_id} using `content_id` #{} source by `provider_id` #{provider_id}.", - source.content_id - ) - } - db.commit()? + debug!( + "Created `content_id` #{content_id} using `content_id` #{} source by `provider_id` #{provider_id}.", + source.content_id + ) } + tx.commit()?; debug!("Queue completed"); if let Some(update) = arg.update { debug!("Wait {update} seconds to continue..."); @@ -105,15 +105,3 @@ async fn main() -> Result<()> { } } } - -// in fact, there is no need for a transaction at this moment, -// as there are no related table updates, but who knows what the future holds -fn tx(arg: &Argument) -> Result { - Ok(Transactional::connect( - &arg.mysql_host, - arg.mysql_port, - &arg.mysql_username, - &arg.mysql_password, - &arg.mysql_database, - )?) -} diff --git a/crates/mysql/Cargo.toml b/crates/mysql/Cargo.toml index 7aeb4af..253d787 100644 --- a/crates/mysql/Cargo.toml +++ b/crates/mysql/Cargo.toml @@ -10,9 +10,8 @@ keywords = ["rssto", "database", "mysql", "library", "driver", "api"] repository = "https://github.com/YGGverse/rssto" [features] -default = ["pollable"] -pollable = [] -transactional = [] +default = [] +transaction = [] [dependencies] mysql = "26.0.1" \ No newline at end of file diff --git a/crates/mysql/src/pollable.rs b/crates/mysql/src/connection.rs similarity index 63% rename from crates/mysql/src/pollable.rs rename to crates/mysql/src/connection.rs index 474c427..c59e2df 100644 --- a/crates/mysql/src/pollable.rs +++ b/crates/mysql/src/connection.rs @@ -1,32 +1,20 @@ -pub mod sort; - -pub use sort::Sort; - use crate::table::*; -use mysql::{Error, Pool, prelude::Queryable}; +use mysql::{Error, Pool, PooledConn, prelude::Queryable}; /// Safe, read-only operations used in client apps like `rssto-http` -pub struct Pollable { - pool: Pool, +pub struct Connection { + conn: PooledConn, } -impl Pollable { - pub fn connect( - host: &str, - port: u16, - user: &str, - password: &str, - database: &str, - ) -> Result { +impl Connection { + pub fn create(pool: &Pool) -> Result { Ok(Self { - pool: mysql::Pool::new( - format!("mysql://{user}:{password}@{host}:{port}/{database}").as_str(), - )?, + conn: pool.get_conn()?, }) } - pub fn channel_item(&self, channel_item_id: u64) -> Result, Error> { - self.pool.get_conn()?.exec_first( + pub fn channel_item(&mut self, channel_item_id: u64) -> Result, Error> { + self.conn.exec_first( "SELECT `channel_item_id`, `channel_id`, `pub_date`, @@ -38,8 +26,8 @@ impl Pollable { ) } - pub fn content(&self, content_id: u64) -> Result, Error> { - self.pool.get_conn()?.exec_first( + pub fn content(&mut self, content_id: u64) -> Result, Error> { + self.conn.exec_first( "SELECT `content_id`, `channel_item_id`, `provider_id`, @@ -50,11 +38,11 @@ impl Pollable { } pub fn contents_total_by_provider_id( - &self, + &mut self, provider_id: Option, keyword: Option<&str>, ) -> Result { - let total: Option = self.pool.get_conn()?.exec_first( + let total: Option = self.conn.exec_first( "SELECT COUNT(*) FROM `content` WHERE `provider_id` <=> ? AND `title` LIKE ?", (provider_id, like(keyword)), )?; @@ -62,13 +50,13 @@ impl Pollable { } pub fn contents_by_provider_id( - &self, + &mut self, provider_id: Option, keyword: Option<&str>, sort: Sort, limit: Option, ) -> Result, Error> { - self.pool.get_conn()?.exec(format!( + self.conn.exec(format!( "SELECT `content_id`, `channel_item_id`, `provider_id`, @@ -79,8 +67,8 @@ impl Pollable { (provider_id, like(keyword), )) } - pub fn content_image(&self, content_image_id: u64) -> Result, Error> { - self.pool.get_conn()?.exec_first( + pub fn content_image(&mut self, content_image_id: u64) -> Result, Error> { + self.conn.exec_first( "SELECT `content_image_id`, `content_id`, `image_id`, @@ -92,17 +80,24 @@ impl Pollable { ) } - pub fn images(&self, limit: Option) -> Result, Error> { - self.pool.get_conn()?.query(format!( + pub fn images(&mut self, limit: Option) -> Result, Error> { + self.conn.query(format!( "SELECT `image_id`, `source`, `data` FROM `image` LIMIT {}", limit.unwrap_or(DEFAULT_LIMIT) )) } - pub fn insert_provider(&self, name: &str) -> Result { - let mut c = self.pool.get_conn()?; - c.exec_drop("INSERT INTO `provider` SET `name` = ?", (name,))?; - Ok(c.last_insert_id()) + pub fn provider_id_by_name(&mut self, name: &str) -> Result, Error> { + self.conn.exec_first( + "SELECT `provider_id` FROM `provider` WHERE `name` = ?", + (name,), + ) + } + + pub fn insert_provider(&mut self, name: &str) -> Result { + self.conn + .exec_drop("INSERT INTO `provider` SET `name` = ?", (name,))?; + Ok(self.conn.last_insert_id()) } } diff --git a/crates/mysql/src/lib.rs b/crates/mysql/src/lib.rs index c316798..53ef7d6 100644 --- a/crates/mysql/src/lib.rs +++ b/crates/mysql/src/lib.rs @@ -1,13 +1,36 @@ -#[cfg(feature = "pollable")] -pub mod pollable; - +mod connection; pub mod table; +#[cfg(feature = "transaction")] +mod transaction; -#[cfg(feature = "transactional")] -pub mod transactional; +pub use connection::Connection; +#[cfg(feature = "transaction")] +pub use transaction::Transaction; +pub struct Database { + pool: mysql::Pool, +} -#[cfg(feature = "pollable")] -pub use pollable::Pollable; +impl Database { + pub fn pool( + host: &str, + port: u16, + user: &str, + password: &str, + database: &str, + ) -> Result { + Ok(Self { + pool: mysql::Pool::new( + format!("mysql://{user}:{password}@{host}:{port}/{database}").as_str(), + )?, + }) + } -#[cfg(feature = "transactional")] -pub use transactional::Transactional; + pub fn connection(&self) -> Result { + Connection::create(&self.pool) + } + + #[cfg(feature = "transaction")] + pub fn transaction(&self) -> Result { + Transaction::create(&self.pool) + } +} diff --git a/crates/mysql/src/pollable/sort.rs b/crates/mysql/src/pollable/sort.rs deleted file mode 100644 index d8b121d..0000000 --- a/crates/mysql/src/pollable/sort.rs +++ /dev/null @@ -1,13 +0,0 @@ -pub enum Sort { - Asc, - Desc, -} - -impl std::fmt::Display for Sort { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Self::Asc => write!(f, "ASC"), - Self::Desc => write!(f, "DESC"), - } - } -} diff --git a/crates/mysql/src/table.rs b/crates/mysql/src/table.rs index 5df3348..3ee92ce 100644 --- a/crates/mysql/src/table.rs +++ b/crates/mysql/src/table.rs @@ -51,3 +51,17 @@ pub struct ContentImage { pub data: Vec, pub source: String, } + +pub enum Sort { + Asc, + Desc, +} + +impl std::fmt::Display for Sort { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Asc => write!(f, "ASC"), + Self::Desc => write!(f, "DESC"), + } + } +} diff --git a/crates/mysql/src/transactional.rs b/crates/mysql/src/transaction.rs similarity index 81% rename from crates/mysql/src/transactional.rs rename to crates/mysql/src/transaction.rs index ce80305..a39e290 100644 --- a/crates/mysql/src/transactional.rs +++ b/crates/mysql/src/transaction.rs @@ -1,24 +1,17 @@ use crate::table::*; -use mysql::{Error, Pool, Transaction, TxOpts, prelude::Queryable}; +use mysql::{Error, Pool, TxOpts, prelude::Queryable}; /// Safe, optimized read/write operations /// mostly required by the `rssto-crawler` and `rssto-llm` /// * all members implementation requires `commit` action -pub struct Transactional { - tx: Transaction<'static>, +pub struct Transaction { + tx: mysql::Transaction<'static>, } -impl Transactional { - pub fn connect( - host: &str, - port: u16, - user: &str, - password: &str, - database: &str, - ) -> Result { +impl Transaction { + pub fn create(pool: &Pool) -> Result { Ok(Self { - tx: Pool::new(format!("mysql://{user}:{password}@{host}:{port}/{database}").as_str())? - .start_transaction(TxOpts::default())?, + tx: pool.start_transaction(TxOpts::default())?, }) } @@ -26,6 +19,10 @@ impl Transactional { self.tx.commit() } + pub fn rollback(self) -> Result<(), Error> { + self.tx.rollback() + } + pub fn channel_id_by_url(&mut self, url: &str) -> Result, Error> { self.tx.exec_first( "SELECT `channel_id` FROM `channel` WHERE `url` = ? LIMIT 1", @@ -132,17 +129,4 @@ impl Transactional { )?; Ok(self.tx.last_insert_id().unwrap()) } - - pub fn provider_id_by_name(&mut self, name: &str) -> Result, Error> { - self.tx.exec_first( - "SELECT `provider_id` FROM `provider` WHERE `name` = ?", - (name,), - ) - } - - pub fn insert_provider(&mut self, name: &str) -> Result { - self.tx - .exec_drop("INSERT INTO `provider` SET `name` = ?", (name,))?; - Ok(self.tx.last_insert_id().unwrap()) - } }