fix: add semaphore

This commit is contained in:
Namekuji 2023-09-18 00:32:41 -04:00
parent ebdc9821b0
commit 7aa4ffa425
No known key found for this signature in database
GPG key ID: 1D62332C07FBA532
2 changed files with 64 additions and 69 deletions

View file

@ -41,9 +41,9 @@ pub async fn run_cli() -> Result<(), Error> {
.down(num) .down(num)
.await? .await?
} }
MigrationCommand::Setup { multi_thread } => { MigrationCommand::Setup { threads } => {
let initializer = Initializer::new(&scylla_conf, &config.db).await?; let initializer = Initializer::new(&scylla_conf, &config.db).await?;
initializer.setup(multi_thread).await?; initializer.setup(threads).await?;
} }
_ => {} _ => {}
}; };
@ -120,10 +120,10 @@ pub(crate) enum MigrationCommand {
value_parser, value_parser,
short, short,
long, long,
default_value = "false", default_value = "1",
help = "Enable multi-thread mode (WARNING: High memory consumption)", help = "Number of threads used to copy",
display_order = 41 display_order = 41
)] )]
multi_thread: bool, threads: u32,
}, },
} }

View file

@ -10,6 +10,7 @@ use scylla::{
SessionBuilder, ValueList, SessionBuilder, ValueList,
}; };
use sea_orm::{entity::*, query::*, Database, DatabaseConnection}; use sea_orm::{entity::*, query::*, Database, DatabaseConnection};
use tokio::sync::Semaphore;
use urlencoding::encode; use urlencoding::encode;
use crate::{ use crate::{
@ -52,7 +53,7 @@ impl Initializer {
}) })
} }
pub(crate) async fn setup(&self, mt: bool) -> Result<(), Error> { pub(crate) async fn setup(&self, threads: u32) -> Result<(), Error> {
println!("Several tables in PostgreSQL are going to be moved to ScyllaDB."); println!("Several tables in PostgreSQL are going to be moved to ScyllaDB.");
let pool = Database::connect(&self.postgres_url).await?; let pool = Database::connect(&self.postgres_url).await?;
@ -77,7 +78,8 @@ impl Initializer {
} }
println!("Copying data from PostgreSQL to ScyllaDB."); println!("Copying data from PostgreSQL to ScyllaDB.");
self.copy(pool.clone(), mt).await?; self.copy(pool.clone(), threads.try_into().unwrap_or(1))
.await?;
println!("Dropping constraints from PostgreSQL."); println!("Dropping constraints from PostgreSQL.");
let fk_pairs = vec![ let fk_pairs = vec![
@ -119,7 +121,7 @@ impl Initializer {
Ok(()) Ok(())
} }
async fn copy(&self, db: DatabaseConnection, multi_thread: bool) -> Result<(), Error> { async fn copy(&self, db: DatabaseConnection, threads: usize) -> Result<(), Error> {
let note_prepared = Arc::new(self.scylla.prepare(INSERT_NOTE).await?); let note_prepared = Arc::new(self.scylla.prepare(INSERT_NOTE).await?);
let home_prepared = Arc::new(self.scylla.prepare(INSERT_HOME_TIMELINE).await?); let home_prepared = Arc::new(self.scylla.prepare(INSERT_HOME_TIMELINE).await?);
let reaction_prepared = Arc::new(self.scylla.prepare(INSERT_REACTION).await?); let reaction_prepared = Arc::new(self.scylla.prepare(INSERT_REACTION).await?);
@ -188,12 +190,14 @@ impl Initializer {
); );
let mut tasks = Vec::new(); let mut tasks = Vec::new();
let sem = Arc::new(Semaphore::new(threads));
let mut notes = note::Entity::find() let mut notes = note::Entity::find()
.order_by_asc(note::Column::Id) .order_by_asc(note::Column::Id)
.stream(&db) .stream(&db)
.await?; .await?;
while let Some(note) = notes.try_next().await? { while let Some(note) = notes.try_next().await? {
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, d, n, h, p) = ( let (s, d, n, h, p) = (
self.clone(), self.clone(),
db.clone(), db.clone(),
@ -201,17 +205,14 @@ impl Initializer {
home_prepared.clone(), home_prepared.clone(),
note_pb.clone(), note_pb.clone(),
); );
let f = async move { let handler = tokio::spawn(async move {
if let Err(e) = s.copy_note(note, d, n, h).await { if let Err(e) = s.copy_note(note, d, n, h).await {
p.println(format!("Note copy error: {e}")); p.println(format!("Note copy error: {e}"));
} }
p.inc(1); p.inc(1);
}; drop(permit);
if multi_thread { });
let handler = tokio::spawn(f);
tasks.push(handler); tasks.push(handler);
} else {
(|| f)().await;
} }
} }
@ -220,18 +221,16 @@ impl Initializer {
.stream(&db) .stream(&db)
.await?; .await?;
while let Some(reaction) = reactions.try_next().await? { while let Some(reaction) = reactions.try_next().await? {
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, r, p) = (self.clone(), reaction_prepared.clone(), reaction_pb.clone()); let (s, r, p) = (self.clone(), reaction_prepared.clone(), reaction_pb.clone());
let f = async move { let handler = tokio::spawn(async move {
if let Err(e) = s.copy_reaction(reaction, r).await { if let Err(e) = s.copy_reaction(reaction, r).await {
p.println(format!("Reaction copy error: {e}")); p.println(format!("Reaction copy error: {e}"));
} }
p.inc(1); p.inc(1);
}; drop(permit);
if multi_thread { });
let handler = tokio::spawn(f);
tasks.push(handler); tasks.push(handler);
} else {
(|| f)().await;
} }
} }
@ -240,6 +239,7 @@ impl Initializer {
.stream(&db) .stream(&db)
.await?; .await?;
while let Some(vote) = votes.try_next().await? { while let Some(vote) = votes.try_next().await? {
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, d, sp, ip, p) = ( let (s, d, sp, ip, p) = (
self.clone(), self.clone(),
db.clone(), db.clone(),
@ -247,17 +247,14 @@ impl Initializer {
vote_insert_prepared.clone(), vote_insert_prepared.clone(),
vote_pb.clone(), vote_pb.clone(),
); );
let f = async move { let handler = tokio::spawn(async move {
if let Err(e) = s.copy_vote(vote, d, sp, ip).await { if let Err(e) = s.copy_vote(vote, d, sp, ip).await {
p.println(format!("Vote copy error: {e}")); p.println(format!("Vote copy error: {e}"));
} }
p.inc(1); p.inc(1);
}; drop(permit);
if multi_thread { });
let handler = tokio::spawn(f);
tasks.push(handler); tasks.push(handler);
} else {
(|| f)().await;
} }
} }
@ -266,23 +263,21 @@ impl Initializer {
.stream(&db) .stream(&db)
.await?; .await?;
while let Some(n) = notifications.try_next().await? { while let Some(n) = notifications.try_next().await? {
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, d, ps, p) = ( let (s, d, ps, p) = (
self.clone(), self.clone(),
db.clone(), db.clone(),
notification_prepared.clone(), notification_prepared.clone(),
notification_pb.clone(), notification_pb.clone(),
); );
let f = async move { let handler = tokio::spawn(async move {
if let Err(e) = s.copy_notification(n, d, ps).await { if let Err(e) = s.copy_notification(n, d, ps).await {
p.println(format!("Notification copy error: {e}")); p.println(format!("Notification copy error: {e}"));
} }
p.inc(1); p.inc(1);
}; drop(permit);
if multi_thread { });
let handler = tokio::spawn(f);
tasks.push(handler); tasks.push(handler);
} else {
(|| f)().await;
} }
} }