From 0cfa85197df4461294d8d01487ee3c1ee513b75b Mon Sep 17 00:00:00 2001 From: sup39 Date: Fri, 12 Apr 2024 14:38:44 +0900 Subject: [PATCH] refactor (backend-rs): rewrite the function to get db connection Co-authored-by: naskya --- Cargo.lock | 1 + Cargo.toml | 1 + packages/backend-rs/Cargo.toml | 1 + packages/backend-rs/src/database/mod.rs | 31 +++++++++++++++++-------- 4 files changed, 24 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 09c4946e17..97190cdc6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -212,6 +212,7 @@ dependencies = [ "serde_yaml", "thiserror", "tokio", + "urlencoding", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index fcb5be2f0b..c983a9de36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ serde_yaml = "0.9.34" syn = "2.0.58" thiserror = "1.0.58" tokio = "1.37.0" +urlencoding = "2.1.3" [profile.release] lto = true diff --git a/packages/backend-rs/Cargo.toml b/packages/backend-rs/Cargo.toml index 1429548a67..0c0fdd16ab 100644 --- a/packages/backend-rs/Cargo.toml +++ b/packages/backend-rs/Cargo.toml @@ -34,6 +34,7 @@ serde_json = { workspace = true } serde_yaml = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["full"] } +urlencoding = { workspace = true } [dev-dependencies] pretty_assertions = { workspace = true } diff --git a/packages/backend-rs/src/database/mod.rs b/packages/backend-rs/src/database/mod.rs index 739f39bb64..eb4935f6e9 100644 --- a/packages/backend-rs/src/database/mod.rs +++ b/packages/backend-rs/src/database/mod.rs @@ -1,26 +1,37 @@ pub mod error; +use crate::config::server::SERVER_CONFIG; use error::Error; use sea_orm::{Database, DbConn}; static DB_CONN: once_cell::sync::OnceCell = once_cell::sync::OnceCell::new(); -pub async fn init_database(conn_uri: impl Into) -> Result<(), Error> { - let conn = Database::connect(conn_uri.into()).await?; - DB_CONN.get_or_init(move || conn); - Ok(()) +async fn init_database() -> Result<&'static DbConn, Error> { + let database_uri = format!( + "postgres://{}:{}@{}:{}/{}", + SERVER_CONFIG.db.user, + urlencoding::encode(&SERVER_CONFIG.db.pass), + SERVER_CONFIG.db.host, + SERVER_CONFIG.db.port, + SERVER_CONFIG.db.db, + ); + let conn = Database::connect(database_uri).await?; + Ok(DB_CONN.get_or_init(move || conn)) } -pub fn get_database() -> Result<&'static DbConn, Error> { - DB_CONN.get().ok_or(Error::Uninitialized) +pub async fn db_conn() -> Result<&'static DbConn, Error> { + match DB_CONN.get() { + Some(conn) => Ok(conn), + None => init_database().await, + } } #[cfg(test)] mod unit_test { - use super::{error::Error, get_database}; + use super::db_conn; - #[test] - fn error_uninitialized() { - assert_eq!(get_database().unwrap_err(), Error::Uninitialized); + #[tokio::test] + async fn connect_test() { + assert!(db_conn().await.is_ok()); } }