optimize db api

This commit is contained in:
yggverse 2026-01-10 01:41:20 +02:00
parent f48e256fad
commit ee083dfc45
11 changed files with 215 additions and 192 deletions

View file

@ -14,7 +14,7 @@ anyhow = "1.0.100"
chrono = "0.4.42" chrono = "0.4.42"
clap = { version = "4.5.54", features = ["derive"] } clap = { version = "4.5.54", features = ["derive"] }
log = "0.4.29" log = "0.4.29"
mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transactional"], path = "../mysql" } mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transaction"], path = "../mysql" }
reqwest = { version = "0.13.1", features = ["blocking"] } reqwest = { version = "0.13.1", features = ["blocking"] }
rss = "2.0.12" rss = "2.0.12"
scraper = { version = "0.25.0", features = ["serde"] } scraper = { version = "0.25.0", features = ["serde"] }

View file

@ -3,11 +3,9 @@ mod config;
use anyhow::Result; use anyhow::Result;
use log::{debug, info, warn}; use log::{debug, info, warn};
use mysql::Transactional;
use reqwest::blocking::get; use reqwest::blocking::get;
fn main() -> Result<()> { fn main() -> Result<()> {
use argument::Argument;
use chrono::Local; use chrono::Local;
use clap::Parser; use clap::Parser;
use std::{env::var, fs::read_to_string}; use std::{env::var, fs::read_to_string};
@ -26,29 +24,29 @@ fn main() -> Result<()> {
.init() .init()
} }
let argument = Argument::parse(); let argument = argument::Argument::parse();
let config: config::Config = toml::from_str(&read_to_string(argument.config)?)?; let config: config::Config = toml::from_str(&read_to_string(argument.config)?)?;
let db = mysql::Database::pool(
&config.mysql.host,
config.mysql.port,
&config.mysql.user,
&config.mysql.password,
&config.mysql.database,
)?;
info!("Crawler started"); info!("Crawler started");
loop { loop {
debug!("Begin new crawl queue..."); debug!("Begin new crawl queue...");
{ for c in &config.channel {
// disconnect from the database immediately when exiting this scope, debug!("Update `{}`...", c.url);
// in case the `update` queue is enabled and pending for a while. let mut tx = db.transaction()?;
let mut db = Transactional::connect( match crawl(&mut tx, c) {
&config.mysql.host, Ok(()) => tx.commit()?,
config.mysql.port, Err(e) => {
&config.mysql.user, warn!("Channel `{}` update failed: `{e}`", c.url);
&config.mysql.password, tx.rollback()?
&config.mysql.database,
)?;
for c in &config.channel {
debug!("Update `{}`...", c.url);
if let Err(e) = crawl(&mut db, c) {
warn!("Channel `{}` update failed: `{e}`", c.url)
} }
} }
db.commit()?
} }
debug!("Crawl queue completed"); debug!("Crawl queue completed");
if let Some(update) = config.update { if let Some(update) = config.update {
@ -60,7 +58,7 @@ fn main() -> Result<()> {
} }
} }
fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()> { fn crawl(tx: &mut mysql::Transaction, channel_config: &config::Channel) -> Result<()> {
use rss::Channel; use rss::Channel;
use scraper::Selector; use scraper::Selector;
@ -87,9 +85,9 @@ fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()>
let channel_items_limit = channel_config.items_limit.unwrap_or(channel_items.len()); let channel_items_limit = channel_config.items_limit.unwrap_or(channel_items.len());
let channel_id = match db.channel_id_by_url(&channel_url)? { let channel_id = match tx.channel_id_by_url(&channel_url)? {
Some(channel_id) => channel_id, Some(channel_id) => channel_id,
None => db.insert_channel(&channel_url)?, None => tx.insert_channel(&channel_url)?,
}; };
for channel_item in channel_items.iter().take(channel_items_limit) { for channel_item in channel_items.iter().take(channel_items_limit) {
@ -120,10 +118,10 @@ fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()>
continue; continue;
} }
}; };
if db.channel_items_total_by_channel_id_guid(channel_id, guid)? > 0 { if tx.channel_items_total_by_channel_id_guid(channel_id, guid)? > 0 {
continue; // skip next steps as processed continue; // skip next steps as processed
} }
let channel_item_id = db.insert_channel_item( let channel_item_id = tx.insert_channel_item(
channel_id, channel_id,
pub_date, pub_date,
guid, guid,
@ -188,7 +186,7 @@ fn crawl(db: &mut Transactional, channel_config: &config::Channel) -> Result<()>
} }
}, },
}; };
let _content_id = db.insert_content(channel_item_id, None, &title, &description)?; let _content_id = tx.insert_content(channel_item_id, None, &title, &description)?;
// @TODO preload media // @TODO preload media
} }
Ok(()) Ok(())

View file

@ -11,7 +11,7 @@ use config::Config;
use feed::Feed; use feed::Feed;
use global::Global; use global::Global;
use meta::Meta; use meta::Meta;
use mysql::{Pollable, pollable::Sort}; use mysql::{Database, table::Sort};
use rocket::{State, http::Status, response::content::RawXml, serde::Serialize}; use rocket::{State, http::Status, response::content::RawXml, serde::Serialize};
use rocket_dyn_templates::{Template, context}; use rocket_dyn_templates::{Template, context};
@ -19,7 +19,7 @@ use rocket_dyn_templates::{Template, context};
fn index( fn index(
search: Option<&str>, search: Option<&str>,
page: Option<usize>, page: Option<usize>,
db: &State<Pollable>, db: &State<Database>,
meta: &State<Meta>, meta: &State<Meta>,
global: &State<Global>, global: &State<Global>,
) -> Result<Template, Status> { ) -> Result<Template, Status> {
@ -32,7 +32,11 @@ fn index(
time: String, time: String,
title: String, title: String,
} }
let total = db let mut conn = db.connection().map_err(|e| {
error!("Could not connect database: `{e}`");
Status::InternalServerError
})?;
let total = conn
.contents_total_by_provider_id(global.provider_id, search) .contents_total_by_provider_id(global.provider_id, search)
.map_err(|e| { .map_err(|e| {
error!("Could not get contents total: `{e}`"); error!("Could not get contents total: `{e}`");
@ -65,19 +69,24 @@ fn index(
back: page.map(|p| uri!(index(search, if p > 2 { Some(p - 1) } else { None }))), back: page.map(|p| uri!(index(search, if p > 2 { Some(p - 1) } else { None }))),
next: if page.unwrap_or(1) * global.list_limit >= total { None } next: if page.unwrap_or(1) * global.list_limit >= total { None }
else { Some(uri!(index(search, Some(page.map_or(2, |p| p + 1))))) }, else { Some(uri!(index(search, Some(page.map_or(2, |p| p + 1))))) },
rows: db.contents_by_provider_id(global.provider_id, search, Sort::Desc, Some(global.list_limit)).map_err(|e| { rows: conn.contents_by_provider_id(
error!("Could not get contents: `{e}`"); global.provider_id,
Status::InternalServerError search,
})? Sort::Desc,
Some(global.list_limit)
).map_err(|e| {
error!("Could not get contents: `{e}`");
Status::InternalServerError
})?
.into_iter() .into_iter()
.map(|c| { .map(|content| {
let channel_item = db.channel_item(c.channel_item_id).unwrap().unwrap(); let channel_item = conn.channel_item(content.channel_item_id).unwrap().unwrap();
Content { Content {
content_id: c.content_id, content_id: content.content_id,
description: c.description, description: content.description,
link: channel_item.link, link: channel_item.link,
time: time(channel_item.pub_date).format(&global.format_time).to_string(), time: time(channel_item.pub_date).format(&global.format_time).to_string(),
title: c.title, title: content.title,
} }
}) })
.collect::<Vec<Content>>(), .collect::<Vec<Content>>(),
@ -92,25 +101,38 @@ fn index(
#[get("/<content_id>")] #[get("/<content_id>")]
fn info( fn info(
content_id: u64, content_id: u64,
db: &State<Pollable>, db: &State<Database>,
meta: &State<Meta>, meta: &State<Meta>,
global: &State<Global>, global: &State<Global>,
) -> Result<Template, Status> { ) -> Result<Template, Status> {
match db.content(content_id).map_err(|e| { let mut conn = db.connection().map_err(|e| {
error!("Could not connect database: `{e}`");
Status::InternalServerError
})?;
match conn.content(content_id).map_err(|e| {
error!("Could not get content `{content_id}`: `{e}`"); error!("Could not get content `{content_id}`: `{e}`");
Status::InternalServerError Status::InternalServerError
})? { })? {
Some(c) => { Some(content) => {
let i = db.channel_item(c.channel_item_id).unwrap().unwrap(); let channel_item = conn
.channel_item(content.channel_item_id)
.map_err(|e| {
error!("Could not get requested channel item: `{e}`");
Status::InternalServerError
})?
.ok_or_else(|| {
error!("Could not find requested channel item");
Status::NotFound
})?;
Ok(Template::render( Ok(Template::render(
"info", "info",
context! { context! {
description: c.description, description: content.description,
link: i.link, link: channel_item.link,
meta: meta.inner(), meta: meta.inner(),
title: format!("{}{S}{}", c.title, meta.title), title: format!("{}{S}{}", content.title, meta.title),
name: c.title, name: content.title,
time: time(i.pub_date).format(&global.format_time).to_string(), time: time(channel_item.pub_date).format(&global.format_time).to_string(),
}, },
)) ))
} }
@ -123,30 +145,43 @@ fn rss(
search: Option<&str>, search: Option<&str>,
global: &State<Global>, global: &State<Global>,
meta: &State<Meta>, meta: &State<Meta>,
db: &State<Pollable>, db: &State<Database>,
) -> Result<RawXml<String>, Status> { ) -> Result<RawXml<String>, Status> {
let mut f = Feed::new( let mut feed = Feed::new(
&meta.title, &meta.title,
meta.description.as_deref(), meta.description.as_deref(),
1024, // @TODO 1024, // @TODO
); );
for c in db let mut conn = db.connection().map_err(|e| {
error!("Could not connect database: `{e}`");
Status::InternalServerError
})?;
for content in conn
.contents_by_provider_id(global.provider_id, search, Sort::Desc, Some(20)) // @TODO .contents_by_provider_id(global.provider_id, search, Sort::Desc, Some(20)) // @TODO
.map_err(|e| { .map_err(|e| {
error!("Could not load channel item contents: `{e}`"); error!("Could not load channel item contents: `{e}`");
Status::InternalServerError Status::InternalServerError
})? })?
{ {
let channel_item = db.channel_item(c.channel_item_id).unwrap().unwrap(); let channel_item = conn
f.push( .channel_item(content.channel_item_id)
c.channel_item_id, .map_err(|e| {
error!("Could not get requested channel item: `{e}`");
Status::InternalServerError
})?
.ok_or_else(|| {
error!("Could not find requested channel item");
Status::NotFound
})?;
feed.push(
content.channel_item_id,
time(channel_item.pub_date), time(channel_item.pub_date),
channel_item.link, channel_item.link,
c.title, content.title,
c.description, content.description,
) )
} }
Ok(RawXml(f.commit())) Ok(RawXml(feed.commit()))
} }
#[launch] #[launch]
@ -165,7 +200,7 @@ fn rocket() -> _ {
} }
}) })
.manage( .manage(
Pollable::connect( Database::pool(
&config.mysql_host, &config.mysql_host,
config.mysql_port, config.mysql_port,
&config.mysql_username, &config.mysql_username,

View file

@ -15,6 +15,6 @@ chrono = "0.4.42"
clap = { version = "4.5.54", features = ["derive"] } clap = { version = "4.5.54", features = ["derive"] }
lancor = "0.1.1" lancor = "0.1.1"
log = "0.4.29" log = "0.4.29"
mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transactional"], path = "../mysql" } mysql = { package = "rssto-mysql", version = "0.1.0", features = ["transaction"], path = "../mysql" }
tokio = { version = "1.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] }
tracing-subscriber = { version = "0.3.22", features = ["env-filter"] } tracing-subscriber = { version = "0.3.22", features = ["env-filter"] }

View file

@ -2,7 +2,7 @@ mod argument;
use anyhow::Result; use anyhow::Result;
use argument::Argument; use argument::Argument;
use mysql::Transactional; use mysql::Database;
#[tokio::main] #[tokio::main]
async fn main() -> Result<()> { async fn main() -> Result<()> {
@ -32,12 +32,17 @@ async fn main() -> Result<()> {
"{}://{}:{}", "{}://{}:{}",
arg.llm_scheme, arg.llm_host, arg.llm_port arg.llm_scheme, arg.llm_host, arg.llm_port
))?; ))?;
let db = Database::pool(
&arg.mysql_host,
arg.mysql_port,
&arg.mysql_username,
&arg.mysql_password,
&arg.mysql_database,
)?;
// find existing ID by model name or create a new one
// * this feature should be moved to a separate CLI tool
let provider_id = { let provider_id = {
let mut db = tx(&arg)?; let mut conn = db.connection()?;
match db.provider_id_by_name(&arg.llm_model)? { match conn.provider_id_by_name(&arg.llm_model)? {
Some(provider_id) => { Some(provider_id) => {
debug!( debug!(
"Use existing DB provider #{} matches model name `{}`", "Use existing DB provider #{} matches model name `{}`",
@ -46,12 +51,11 @@ async fn main() -> Result<()> {
provider_id provider_id
} }
None => { None => {
let provider_id = db.insert_provider(&arg.llm_model)?; let provider_id = conn.insert_provider(&arg.llm_model)?;
info!( info!(
"Provider `{}` not found in database, created new one with ID `{provider_id}`", "Provider `{}` not found in database, created new one with ID `{provider_id}`",
&arg.llm_model &arg.llm_model
); );
db.commit()?;
provider_id provider_id
} }
} }
@ -60,42 +64,38 @@ async fn main() -> Result<()> {
info!("Daemon started"); info!("Daemon started");
loop { loop {
debug!("New queue begin..."); debug!("New queue begin...");
{ let mut tx = db.transaction()?;
// disconnect from the database immediately when exiting this scope, for source in tx.contents_queue_for_provider_id(provider_id)? {
// in case the `update` queue is enabled and pending for a while. debug!(
let mut db = tx(&arg)?; "Begin generating `content_id` #{} using `provider_id` #{provider_id}.",
for source in db.contents_queue_for_provider_id(provider_id)? { source.content_id
debug!( );
"Begin generating `content_id` #{} using `provider_id` #{provider_id}.",
source.content_id
);
let title = llm let title =
.chat_completion(ChatCompletionRequest::new(&arg.llm_model).message( llm.chat_completion(ChatCompletionRequest::new(&arg.llm_model).message(
Message::user(format!("{}\n{}", arg.llm_message, source.title)), Message::user(format!("{}\n{}", arg.llm_message, source.title)),
)) ))
.await?; .await?;
let description = llm let description =
.chat_completion(ChatCompletionRequest::new(&arg.llm_model).message( llm.chat_completion(ChatCompletionRequest::new(&arg.llm_model).message(
Message::user(format!("{}\n{}", arg.llm_message, source.description)), Message::user(format!("{}\n{}", arg.llm_message, source.description)),
)) ))
.await?; .await?;
let content_id = db.insert_content( let content_id = tx.insert_content(
source.channel_item_id, source.channel_item_id,
Some(provider_id), Some(provider_id),
&title.choices[0].message.content, &title.choices[0].message.content,
&description.choices[0].message.content, &description.choices[0].message.content,
)?; )?;
debug!( debug!(
"Created `content_id` #{content_id} using `content_id` #{} source by `provider_id` #{provider_id}.", "Created `content_id` #{content_id} using `content_id` #{} source by `provider_id` #{provider_id}.",
source.content_id source.content_id
) )
}
db.commit()?
} }
tx.commit()?;
debug!("Queue completed"); debug!("Queue completed");
if let Some(update) = arg.update { if let Some(update) = arg.update {
debug!("Wait {update} seconds to continue..."); debug!("Wait {update} seconds to continue...");
@ -105,15 +105,3 @@ async fn main() -> Result<()> {
} }
} }
} }
// in fact, there is no need for a transaction at this moment,
// as there are no related table updates, but who knows what the future holds
fn tx(arg: &Argument) -> Result<Transactional> {
Ok(Transactional::connect(
&arg.mysql_host,
arg.mysql_port,
&arg.mysql_username,
&arg.mysql_password,
&arg.mysql_database,
)?)
}

View file

@ -10,9 +10,8 @@ keywords = ["rssto", "database", "mysql", "library", "driver", "api"]
repository = "https://github.com/YGGverse/rssto" repository = "https://github.com/YGGverse/rssto"
[features] [features]
default = ["pollable"] default = []
pollable = [] transaction = []
transactional = []
[dependencies] [dependencies]
mysql = "26.0.1" mysql = "26.0.1"

View file

@ -1,32 +1,20 @@
pub mod sort;
pub use sort::Sort;
use crate::table::*; use crate::table::*;
use mysql::{Error, Pool, prelude::Queryable}; use mysql::{Error, Pool, PooledConn, prelude::Queryable};
/// Safe, read-only operations used in client apps like `rssto-http` /// Safe, read-only operations used in client apps like `rssto-http`
pub struct Pollable { pub struct Connection {
pool: Pool, conn: PooledConn,
} }
impl Pollable { impl Connection {
pub fn connect( pub fn create(pool: &Pool) -> Result<Self, Error> {
host: &str,
port: u16,
user: &str,
password: &str,
database: &str,
) -> Result<Self, Error> {
Ok(Self { Ok(Self {
pool: mysql::Pool::new( conn: pool.get_conn()?,
format!("mysql://{user}:{password}@{host}:{port}/{database}").as_str(),
)?,
}) })
} }
pub fn channel_item(&self, channel_item_id: u64) -> Result<Option<ChannelItem>, Error> { pub fn channel_item(&mut self, channel_item_id: u64) -> Result<Option<ChannelItem>, Error> {
self.pool.get_conn()?.exec_first( self.conn.exec_first(
"SELECT `channel_item_id`, "SELECT `channel_item_id`,
`channel_id`, `channel_id`,
`pub_date`, `pub_date`,
@ -38,8 +26,8 @@ impl Pollable {
) )
} }
pub fn content(&self, content_id: u64) -> Result<Option<Content>, Error> { pub fn content(&mut self, content_id: u64) -> Result<Option<Content>, Error> {
self.pool.get_conn()?.exec_first( self.conn.exec_first(
"SELECT `content_id`, "SELECT `content_id`,
`channel_item_id`, `channel_item_id`,
`provider_id`, `provider_id`,
@ -50,11 +38,11 @@ impl Pollable {
} }
pub fn contents_total_by_provider_id( pub fn contents_total_by_provider_id(
&self, &mut self,
provider_id: Option<u64>, provider_id: Option<u64>,
keyword: Option<&str>, keyword: Option<&str>,
) -> Result<usize, Error> { ) -> Result<usize, Error> {
let total: Option<usize> = self.pool.get_conn()?.exec_first( let total: Option<usize> = self.conn.exec_first(
"SELECT COUNT(*) FROM `content` WHERE `provider_id` <=> ? AND `title` LIKE ?", "SELECT COUNT(*) FROM `content` WHERE `provider_id` <=> ? AND `title` LIKE ?",
(provider_id, like(keyword)), (provider_id, like(keyword)),
)?; )?;
@ -62,13 +50,13 @@ impl Pollable {
} }
pub fn contents_by_provider_id( pub fn contents_by_provider_id(
&self, &mut self,
provider_id: Option<u64>, provider_id: Option<u64>,
keyword: Option<&str>, keyword: Option<&str>,
sort: Sort, sort: Sort,
limit: Option<usize>, limit: Option<usize>,
) -> Result<Vec<Content>, Error> { ) -> Result<Vec<Content>, Error> {
self.pool.get_conn()?.exec(format!( self.conn.exec(format!(
"SELECT `content_id`, "SELECT `content_id`,
`channel_item_id`, `channel_item_id`,
`provider_id`, `provider_id`,
@ -79,8 +67,8 @@ impl Pollable {
(provider_id, like(keyword), )) (provider_id, like(keyword), ))
} }
pub fn content_image(&self, content_image_id: u64) -> Result<Option<ContentImage>, Error> { pub fn content_image(&mut self, content_image_id: u64) -> Result<Option<ContentImage>, Error> {
self.pool.get_conn()?.exec_first( self.conn.exec_first(
"SELECT `content_image_id`, "SELECT `content_image_id`,
`content_id`, `content_id`,
`image_id`, `image_id`,
@ -92,17 +80,24 @@ impl Pollable {
) )
} }
pub fn images(&self, limit: Option<usize>) -> Result<Vec<Image>, Error> { pub fn images(&mut self, limit: Option<usize>) -> Result<Vec<Image>, Error> {
self.pool.get_conn()?.query(format!( self.conn.query(format!(
"SELECT `image_id`, `source`, `data` FROM `image` LIMIT {}", "SELECT `image_id`, `source`, `data` FROM `image` LIMIT {}",
limit.unwrap_or(DEFAULT_LIMIT) limit.unwrap_or(DEFAULT_LIMIT)
)) ))
} }
pub fn insert_provider(&self, name: &str) -> Result<u64, Error> { pub fn provider_id_by_name(&mut self, name: &str) -> Result<Option<u64>, Error> {
let mut c = self.pool.get_conn()?; self.conn.exec_first(
c.exec_drop("INSERT INTO `provider` SET `name` = ?", (name,))?; "SELECT `provider_id` FROM `provider` WHERE `name` = ?",
Ok(c.last_insert_id()) (name,),
)
}
pub fn insert_provider(&mut self, name: &str) -> Result<u64, Error> {
self.conn
.exec_drop("INSERT INTO `provider` SET `name` = ?", (name,))?;
Ok(self.conn.last_insert_id())
} }
} }

View file

@ -1,13 +1,36 @@
#[cfg(feature = "pollable")] mod connection;
pub mod pollable;
pub mod table; pub mod table;
#[cfg(feature = "transaction")]
mod transaction;
#[cfg(feature = "transactional")] pub use connection::Connection;
pub mod transactional; #[cfg(feature = "transaction")]
pub use transaction::Transaction;
pub struct Database {
pool: mysql::Pool,
}
#[cfg(feature = "pollable")] impl Database {
pub use pollable::Pollable; pub fn pool(
host: &str,
port: u16,
user: &str,
password: &str,
database: &str,
) -> Result<Self, mysql::Error> {
Ok(Self {
pool: mysql::Pool::new(
format!("mysql://{user}:{password}@{host}:{port}/{database}").as_str(),
)?,
})
}
#[cfg(feature = "transactional")] pub fn connection(&self) -> Result<Connection, mysql::Error> {
pub use transactional::Transactional; Connection::create(&self.pool)
}
#[cfg(feature = "transaction")]
pub fn transaction(&self) -> Result<Transaction, mysql::Error> {
Transaction::create(&self.pool)
}
}

View file

@ -1,13 +0,0 @@
pub enum Sort {
Asc,
Desc,
}
impl std::fmt::Display for Sort {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Asc => write!(f, "ASC"),
Self::Desc => write!(f, "DESC"),
}
}
}

View file

@ -51,3 +51,17 @@ pub struct ContentImage {
pub data: Vec<u8>, pub data: Vec<u8>,
pub source: String, pub source: String,
} }
pub enum Sort {
Asc,
Desc,
}
impl std::fmt::Display for Sort {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Self::Asc => write!(f, "ASC"),
Self::Desc => write!(f, "DESC"),
}
}
}

View file

@ -1,24 +1,17 @@
use crate::table::*; use crate::table::*;
use mysql::{Error, Pool, Transaction, TxOpts, prelude::Queryable}; use mysql::{Error, Pool, TxOpts, prelude::Queryable};
/// Safe, optimized read/write operations /// Safe, optimized read/write operations
/// mostly required by the `rssto-crawler` and `rssto-llm` /// mostly required by the `rssto-crawler` and `rssto-llm`
/// * all members implementation requires `commit` action /// * all members implementation requires `commit` action
pub struct Transactional { pub struct Transaction {
tx: Transaction<'static>, tx: mysql::Transaction<'static>,
} }
impl Transactional { impl Transaction {
pub fn connect( pub fn create(pool: &Pool) -> Result<Self, Error> {
host: &str,
port: u16,
user: &str,
password: &str,
database: &str,
) -> Result<Self, Error> {
Ok(Self { Ok(Self {
tx: Pool::new(format!("mysql://{user}:{password}@{host}:{port}/{database}").as_str())? tx: pool.start_transaction(TxOpts::default())?,
.start_transaction(TxOpts::default())?,
}) })
} }
@ -26,6 +19,10 @@ impl Transactional {
self.tx.commit() self.tx.commit()
} }
pub fn rollback(self) -> Result<(), Error> {
self.tx.rollback()
}
pub fn channel_id_by_url(&mut self, url: &str) -> Result<Option<u64>, Error> { pub fn channel_id_by_url(&mut self, url: &str) -> Result<Option<u64>, Error> {
self.tx.exec_first( self.tx.exec_first(
"SELECT `channel_id` FROM `channel` WHERE `url` = ? LIMIT 1", "SELECT `channel_id` FROM `channel` WHERE `url` = ? LIMIT 1",
@ -132,17 +129,4 @@ impl Transactional {
)?; )?;
Ok(self.tx.last_insert_id().unwrap()) Ok(self.tx.last_insert_id().unwrap())
} }
pub fn provider_id_by_name(&mut self, name: &str) -> Result<Option<u64>, Error> {
self.tx.exec_first(
"SELECT `provider_id` FROM `provider` WHERE `name` = ?",
(name,),
)
}
pub fn insert_provider(&mut self, name: &str) -> Result<u64, Error> {
self.tx
.exec_drop("INSERT INTO `provider` SET `name` = ?", (name,))?;
Ok(self.tx.last_insert_id().unwrap())
}
} }