From 7aa4ffa4255edbefe06db30bab0f1ee84d8c7af0 Mon Sep 17 00:00:00 2001 From: Namekuji Date: Mon, 18 Sep 2023 00:32:41 -0400 Subject: [PATCH] fix: add semaphore --- .../native-utils/scylla-migration/src/cli.rs | 10 +- .../scylla-migration/src/setup.rs | 123 +++++++++--------- 2 files changed, 64 insertions(+), 69 deletions(-) diff --git a/packages/backend/native-utils/scylla-migration/src/cli.rs b/packages/backend/native-utils/scylla-migration/src/cli.rs index c43bd7a680..c772dcb504 100644 --- a/packages/backend/native-utils/scylla-migration/src/cli.rs +++ b/packages/backend/native-utils/scylla-migration/src/cli.rs @@ -41,9 +41,9 @@ pub async fn run_cli() -> Result<(), Error> { .down(num) .await? } - MigrationCommand::Setup { multi_thread } => { + MigrationCommand::Setup { threads } => { let initializer = Initializer::new(&scylla_conf, &config.db).await?; - initializer.setup(multi_thread).await?; + initializer.setup(threads).await?; } _ => {} }; @@ -120,10 +120,10 @@ pub(crate) enum MigrationCommand { value_parser, short, long, - default_value = "false", - help = "Enable multi-thread mode (WARNING: High memory consumption)", + default_value = "1", + help = "Number of threads used to copy", display_order = 41 )] - multi_thread: bool, + threads: u32, }, } diff --git a/packages/backend/native-utils/scylla-migration/src/setup.rs b/packages/backend/native-utils/scylla-migration/src/setup.rs index 2c75b07b81..9a5c55b323 100644 --- a/packages/backend/native-utils/scylla-migration/src/setup.rs +++ b/packages/backend/native-utils/scylla-migration/src/setup.rs @@ -10,6 +10,7 @@ use scylla::{ SessionBuilder, ValueList, }; use sea_orm::{entity::*, query::*, Database, DatabaseConnection}; +use tokio::sync::Semaphore; use urlencoding::encode; use crate::{ @@ -52,7 +53,7 @@ impl Initializer { }) } - pub(crate) async fn setup(&self, mt: bool) -> Result<(), Error> { + pub(crate) async fn setup(&self, threads: u32) -> Result<(), Error> { println!("Several tables in PostgreSQL are going to be moved to ScyllaDB."); let pool = Database::connect(&self.postgres_url).await?; @@ -77,7 +78,8 @@ impl Initializer { } println!("Copying data from PostgreSQL to ScyllaDB."); - self.copy(pool.clone(), mt).await?; + self.copy(pool.clone(), threads.try_into().unwrap_or(1)) + .await?; println!("Dropping constraints from PostgreSQL."); let fk_pairs = vec![ @@ -119,7 +121,7 @@ impl Initializer { Ok(()) } - async fn copy(&self, db: DatabaseConnection, multi_thread: bool) -> Result<(), Error> { + async fn copy(&self, db: DatabaseConnection, threads: usize) -> Result<(), Error> { let note_prepared = Arc::new(self.scylla.prepare(INSERT_NOTE).await?); let home_prepared = Arc::new(self.scylla.prepare(INSERT_HOME_TIMELINE).await?); let reaction_prepared = Arc::new(self.scylla.prepare(INSERT_REACTION).await?); @@ -188,30 +190,29 @@ impl Initializer { ); let mut tasks = Vec::new(); + let sem = Arc::new(Semaphore::new(threads)); let mut notes = note::Entity::find() .order_by_asc(note::Column::Id) .stream(&db) .await?; while let Some(note) = notes.try_next().await? { - let (s, d, n, h, p) = ( - self.clone(), - db.clone(), - note_prepared.clone(), - home_prepared.clone(), - note_pb.clone(), - ); - let f = async move { - if let Err(e) = s.copy_note(note, d, n, h).await { - p.println(format!("Note copy error: {e}")); - } - p.inc(1); - }; - if multi_thread { - let handler = tokio::spawn(f); + if let Ok(permit) = Arc::clone(&sem).acquire_owned().await { + let (s, d, n, h, p) = ( + self.clone(), + db.clone(), + note_prepared.clone(), + home_prepared.clone(), + note_pb.clone(), + ); + let handler = tokio::spawn(async move { + if let Err(e) = s.copy_note(note, d, n, h).await { + p.println(format!("Note copy error: {e}")); + } + p.inc(1); + drop(permit); + }); tasks.push(handler); - } else { - (|| f)().await; } } @@ -220,18 +221,16 @@ impl Initializer { .stream(&db) .await?; while let Some(reaction) = reactions.try_next().await? { - let (s, r, p) = (self.clone(), reaction_prepared.clone(), reaction_pb.clone()); - let f = async move { - if let Err(e) = s.copy_reaction(reaction, r).await { - p.println(format!("Reaction copy error: {e}")); - } - p.inc(1); - }; - if multi_thread { - let handler = tokio::spawn(f); + if let Ok(permit) = Arc::clone(&sem).acquire_owned().await { + let (s, r, p) = (self.clone(), reaction_prepared.clone(), reaction_pb.clone()); + let handler = tokio::spawn(async move { + if let Err(e) = s.copy_reaction(reaction, r).await { + p.println(format!("Reaction copy error: {e}")); + } + p.inc(1); + drop(permit); + }); tasks.push(handler); - } else { - (|| f)().await; } } @@ -240,24 +239,22 @@ impl Initializer { .stream(&db) .await?; while let Some(vote) = votes.try_next().await? { - let (s, d, sp, ip, p) = ( - self.clone(), - db.clone(), - vote_select_prepared.clone(), - vote_insert_prepared.clone(), - vote_pb.clone(), - ); - let f = async move { - if let Err(e) = s.copy_vote(vote, d, sp, ip).await { - p.println(format!("Vote copy error: {e}")); - } - p.inc(1); - }; - if multi_thread { - let handler = tokio::spawn(f); + if let Ok(permit) = Arc::clone(&sem).acquire_owned().await { + let (s, d, sp, ip, p) = ( + self.clone(), + db.clone(), + vote_select_prepared.clone(), + vote_insert_prepared.clone(), + vote_pb.clone(), + ); + let handler = tokio::spawn(async move { + if let Err(e) = s.copy_vote(vote, d, sp, ip).await { + p.println(format!("Vote copy error: {e}")); + } + p.inc(1); + drop(permit); + }); tasks.push(handler); - } else { - (|| f)().await; } } @@ -266,23 +263,21 @@ impl Initializer { .stream(&db) .await?; while let Some(n) = notifications.try_next().await? { - let (s, d, ps, p) = ( - self.clone(), - db.clone(), - notification_prepared.clone(), - notification_pb.clone(), - ); - let f = async move { - if let Err(e) = s.copy_notification(n, d, ps).await { - p.println(format!("Notification copy error: {e}")); - } - p.inc(1); - }; - if multi_thread { - let handler = tokio::spawn(f); + if let Ok(permit) = Arc::clone(&sem).acquire_owned().await { + let (s, d, ps, p) = ( + self.clone(), + db.clone(), + notification_prepared.clone(), + notification_pb.clone(), + ); + let handler = tokio::spawn(async move { + if let Err(e) = s.copy_notification(n, d, ps).await { + p.println(format!("Notification copy error: {e}")); + } + p.inc(1); + drop(permit); + }); tasks.push(handler); - } else { - (|| f)().await; } }