fix: add semaphore

This commit is contained in:
Namekuji 2023-09-18 00:32:41 -04:00
parent ebdc9821b0
commit 7aa4ffa425
No known key found for this signature in database
GPG key ID: 1D62332C07FBA532
2 changed files with 64 additions and 69 deletions

View file

@ -41,9 +41,9 @@ pub async fn run_cli() -> Result<(), Error> {
.down(num)
.await?
}
MigrationCommand::Setup { multi_thread } => {
MigrationCommand::Setup { threads } => {
let initializer = Initializer::new(&scylla_conf, &config.db).await?;
initializer.setup(multi_thread).await?;
initializer.setup(threads).await?;
}
_ => {}
};
@ -120,10 +120,10 @@ pub(crate) enum MigrationCommand {
value_parser,
short,
long,
default_value = "false",
help = "Enable multi-thread mode (WARNING: High memory consumption)",
default_value = "1",
help = "Number of threads used to copy",
display_order = 41
)]
multi_thread: bool,
threads: u32,
},
}

View file

@ -10,6 +10,7 @@ use scylla::{
SessionBuilder, ValueList,
};
use sea_orm::{entity::*, query::*, Database, DatabaseConnection};
use tokio::sync::Semaphore;
use urlencoding::encode;
use crate::{
@ -52,7 +53,7 @@ impl Initializer {
})
}
pub(crate) async fn setup(&self, mt: bool) -> Result<(), Error> {
pub(crate) async fn setup(&self, threads: u32) -> Result<(), Error> {
println!("Several tables in PostgreSQL are going to be moved to ScyllaDB.");
let pool = Database::connect(&self.postgres_url).await?;
@ -77,7 +78,8 @@ impl Initializer {
}
println!("Copying data from PostgreSQL to ScyllaDB.");
self.copy(pool.clone(), mt).await?;
self.copy(pool.clone(), threads.try_into().unwrap_or(1))
.await?;
println!("Dropping constraints from PostgreSQL.");
let fk_pairs = vec![
@ -119,7 +121,7 @@ impl Initializer {
Ok(())
}
async fn copy(&self, db: DatabaseConnection, multi_thread: bool) -> Result<(), Error> {
async fn copy(&self, db: DatabaseConnection, threads: usize) -> Result<(), Error> {
let note_prepared = Arc::new(self.scylla.prepare(INSERT_NOTE).await?);
let home_prepared = Arc::new(self.scylla.prepare(INSERT_HOME_TIMELINE).await?);
let reaction_prepared = Arc::new(self.scylla.prepare(INSERT_REACTION).await?);
@ -188,30 +190,29 @@ impl Initializer {
);
let mut tasks = Vec::new();
let sem = Arc::new(Semaphore::new(threads));
let mut notes = note::Entity::find()
.order_by_asc(note::Column::Id)
.stream(&db)
.await?;
while let Some(note) = notes.try_next().await? {
let (s, d, n, h, p) = (
self.clone(),
db.clone(),
note_prepared.clone(),
home_prepared.clone(),
note_pb.clone(),
);
let f = async move {
if let Err(e) = s.copy_note(note, d, n, h).await {
p.println(format!("Note copy error: {e}"));
}
p.inc(1);
};
if multi_thread {
let handler = tokio::spawn(f);
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, d, n, h, p) = (
self.clone(),
db.clone(),
note_prepared.clone(),
home_prepared.clone(),
note_pb.clone(),
);
let handler = tokio::spawn(async move {
if let Err(e) = s.copy_note(note, d, n, h).await {
p.println(format!("Note copy error: {e}"));
}
p.inc(1);
drop(permit);
});
tasks.push(handler);
} else {
(|| f)().await;
}
}
@ -220,18 +221,16 @@ impl Initializer {
.stream(&db)
.await?;
while let Some(reaction) = reactions.try_next().await? {
let (s, r, p) = (self.clone(), reaction_prepared.clone(), reaction_pb.clone());
let f = async move {
if let Err(e) = s.copy_reaction(reaction, r).await {
p.println(format!("Reaction copy error: {e}"));
}
p.inc(1);
};
if multi_thread {
let handler = tokio::spawn(f);
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, r, p) = (self.clone(), reaction_prepared.clone(), reaction_pb.clone());
let handler = tokio::spawn(async move {
if let Err(e) = s.copy_reaction(reaction, r).await {
p.println(format!("Reaction copy error: {e}"));
}
p.inc(1);
drop(permit);
});
tasks.push(handler);
} else {
(|| f)().await;
}
}
@ -240,24 +239,22 @@ impl Initializer {
.stream(&db)
.await?;
while let Some(vote) = votes.try_next().await? {
let (s, d, sp, ip, p) = (
self.clone(),
db.clone(),
vote_select_prepared.clone(),
vote_insert_prepared.clone(),
vote_pb.clone(),
);
let f = async move {
if let Err(e) = s.copy_vote(vote, d, sp, ip).await {
p.println(format!("Vote copy error: {e}"));
}
p.inc(1);
};
if multi_thread {
let handler = tokio::spawn(f);
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, d, sp, ip, p) = (
self.clone(),
db.clone(),
vote_select_prepared.clone(),
vote_insert_prepared.clone(),
vote_pb.clone(),
);
let handler = tokio::spawn(async move {
if let Err(e) = s.copy_vote(vote, d, sp, ip).await {
p.println(format!("Vote copy error: {e}"));
}
p.inc(1);
drop(permit);
});
tasks.push(handler);
} else {
(|| f)().await;
}
}
@ -266,23 +263,21 @@ impl Initializer {
.stream(&db)
.await?;
while let Some(n) = notifications.try_next().await? {
let (s, d, ps, p) = (
self.clone(),
db.clone(),
notification_prepared.clone(),
notification_pb.clone(),
);
let f = async move {
if let Err(e) = s.copy_notification(n, d, ps).await {
p.println(format!("Notification copy error: {e}"));
}
p.inc(1);
};
if multi_thread {
let handler = tokio::spawn(f);
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
let (s, d, ps, p) = (
self.clone(),
db.clone(),
notification_prepared.clone(),
notification_pb.clone(),
);
let handler = tokio::spawn(async move {
if let Err(e) = s.copy_notification(n, d, ps).await {
p.println(format!("Notification copy error: {e}"));
}
p.inc(1);
drop(permit);
});
tasks.push(handler);
} else {
(|| f)().await;
}
}