add no-confirm flag for scripting

This commit is contained in:
Namekuji 2023-09-26 17:31:21 -04:00
parent 156ba277fd
commit 93978bbe83
No known key found for this signature in database
GPG key ID: 1D62332C07FBA532
2 changed files with 54 additions and 32 deletions

View file

@ -44,10 +44,13 @@ pub async fn run_cli() -> Result<(), Error> {
MigrationCommand::Setup { MigrationCommand::Setup {
threads, threads,
note_since, note_since,
note_skip note_skip,
no_confirm,
} => { } => {
let initializer = Initializer::new(&scylla_conf, &config.db).await?; let initializer =
initializer.setup(threads, note_skip, note_since).await?; Initializer::new(&scylla_conf, &config.db, note_since, note_skip, no_confirm)
.await?;
initializer.setup(threads).await?;
} }
_ => {} _ => {}
}; };
@ -144,5 +147,13 @@ pub(crate) enum MigrationCommand {
display_order = 43 display_order = 43
)] )]
note_skip: u64, note_skip: u64,
#[clap(
value_parser,
long,
default_value = "false",
help = "Does not confirm before the process begins",
display_order = 44
)]
no_confirm: bool,
}, },
} }

View file

@ -25,12 +25,18 @@ use crate::{
pub(crate) struct Initializer { pub(crate) struct Initializer {
scylla: Arc<Session>, scylla: Arc<Session>,
postgres_url: String, postgres_url: String,
note_since: Option<String>,
note_skip: u64,
no_confirm: bool,
} }
impl Initializer { impl Initializer {
pub(crate) async fn new( pub(crate) async fn new(
scylla_conf: &ScyllaConfig, scylla_conf: &ScyllaConfig,
postgres_conf: &DbConfig, postgres_conf: &DbConfig,
note_since: Option<String>,
note_skip: u64,
no_confirm: bool,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let mut builder = SessionBuilder::new().known_nodes(&scylla_conf.nodes); let mut builder = SessionBuilder::new().known_nodes(&scylla_conf.nodes);
@ -56,15 +62,13 @@ impl Initializer {
Ok(Self { Ok(Self {
scylla: Arc::new(session), scylla: Arc::new(session),
postgres_url: conn_url, postgres_url: conn_url,
note_since,
note_skip,
no_confirm,
}) })
} }
pub(crate) async fn setup( pub(crate) async fn setup(&self, threads: u32) -> Result<(), Error> {
&self,
threads: u32,
skip: u64,
since: Option<String>,
) -> Result<(), Error> {
println!("Several tables in PostgreSQL are going to be moved to ScyllaDB."); println!("Several tables in PostgreSQL are going to be moved to ScyllaDB.");
let pool = Database::connect(&self.postgres_url).await?; let pool = Database::connect(&self.postgres_url).await?;
@ -78,6 +82,7 @@ impl Initializer {
.bold() .bold()
); );
if !self.no_confirm {
let confirm = Confirm::with_theme(&ColorfulTheme::default()) let confirm = Confirm::with_theme(&ColorfulTheme::default())
.with_prompt("This process may take a while. Do you want to continue?") .with_prompt("This process may take a while. Do you want to continue?")
.interact() .interact()
@ -87,9 +92,10 @@ impl Initializer {
println!("Cancelled."); println!("Cancelled.");
return Ok(()); return Ok(());
} }
}
println!("Copying data from PostgreSQL to ScyllaDB."); println!("Copying data from PostgreSQL to ScyllaDB.");
self.copy(pool.clone(), threads.try_into().unwrap_or(1), skip, since) self.copy(pool.clone(), threads.try_into().unwrap_or(1))
.await?; .await?;
println!("Dropping constraints from PostgreSQL."); println!("Dropping constraints from PostgreSQL.");
@ -132,13 +138,7 @@ impl Initializer {
Ok(()) Ok(())
} }
async fn copy( async fn copy(&self, db: DatabaseConnection, threads: usize) -> Result<(), Error> {
&self,
db: DatabaseConnection,
threads: usize,
note_skip: u64,
note_since: Option<String>,
) -> Result<(), Error> {
let note_prepared = Arc::new(self.scylla.prepare(INSERT_NOTE).await?); let note_prepared = Arc::new(self.scylla.prepare(INSERT_NOTE).await?);
let home_prepared = Arc::new(self.scylla.prepare(INSERT_HOME_TIMELINE).await?); let home_prepared = Arc::new(self.scylla.prepare(INSERT_HOME_TIMELINE).await?);
let reaction_prepared = Arc::new(self.scylla.prepare(INSERT_REACTION).await?); let reaction_prepared = Arc::new(self.scylla.prepare(INSERT_REACTION).await?);
@ -148,7 +148,7 @@ impl Initializer {
let mut num_notes = note::Entity::find(); let mut num_notes = note::Entity::find();
if let Some(since) = note_since.clone() { if let Some(since) = self.note_since.clone() {
num_notes = num_notes.filter(note::Column::Id.gt(&since)); num_notes = num_notes.filter(note::Column::Id.gt(&since));
} }
@ -159,7 +159,7 @@ impl Initializer {
.one(&db) .one(&db)
.await? .await?
.unwrap_or_default(); .unwrap_or_default();
num_notes -= note_skip as i64; num_notes -= self.note_skip as i64;
println!("Posts: {num_notes}"); println!("Posts: {num_notes}");
let num_reactions: i64 = note_reaction::Entity::find() let num_reactions: i64 = note_reaction::Entity::find()
.select_only() .select_only()
@ -217,7 +217,7 @@ impl Initializer {
let sem = Arc::new(Semaphore::new(threads)); let sem = Arc::new(Semaphore::new(threads));
let mut notes = note::Entity::find().order_by_asc(note::Column::Id); let mut notes = note::Entity::find().order_by_asc(note::Column::Id);
if let Some(since_id) = note_since { if let Some(since_id) = self.note_since.clone() {
notes = notes.filter(note::Column::Id.gt(&since_id)); notes = notes.filter(note::Column::Id.gt(&since_id));
} }
let mut notes = notes.stream(&db).await?; let mut notes = notes.stream(&db).await?;
@ -226,7 +226,7 @@ impl Initializer {
while let Some(note) = notes.try_next().await? { while let Some(note) = notes.try_next().await? {
copied += 1; copied += 1;
if copied <= note_skip { if copied <= self.note_skip {
continue; continue;
} }
if let Ok(permit) = Arc::clone(&sem).acquire_owned().await { if let Ok(permit) = Arc::clone(&sem).acquire_owned().await {
@ -249,10 +249,13 @@ impl Initializer {
if tasks.len() > 1000 { if tasks.len() > 1000 {
future::join_all(tasks).await; future::join_all(tasks).await;
tasks = Vec::new() tasks = Vec::new();
} }
} }
future::join_all(tasks).await;
tasks = Vec::new();
let mut reactions = note_reaction::Entity::find() let mut reactions = note_reaction::Entity::find()
.order_by_asc(note_reaction::Column::Id) .order_by_asc(note_reaction::Column::Id)
.stream(&db) .stream(&db)
@ -272,10 +275,13 @@ impl Initializer {
if tasks.len() > 1000 { if tasks.len() > 1000 {
future::join_all(tasks).await; future::join_all(tasks).await;
tasks = Vec::new() tasks = Vec::new();
} }
} }
future::join_all(tasks).await;
tasks = Vec::new();
let mut votes = poll_vote::Entity::find() let mut votes = poll_vote::Entity::find()
.order_by_asc(poll_vote::Column::Id) .order_by_asc(poll_vote::Column::Id)
.stream(&db) .stream(&db)
@ -301,10 +307,13 @@ impl Initializer {
if tasks.len() > 1000 { if tasks.len() > 1000 {
future::join_all(tasks).await; future::join_all(tasks).await;
tasks = Vec::new() tasks = Vec::new();
} }
} }
future::join_all(tasks).await;
tasks = Vec::new();
let mut notifications = notification::Entity::find() let mut notifications = notification::Entity::find()
.order_by_asc(notification::Column::Id) .order_by_asc(notification::Column::Id)
.stream(&db) .stream(&db)
@ -329,10 +338,12 @@ impl Initializer {
if tasks.len() > 1000 { if tasks.len() > 1000 {
future::join_all(tasks).await; future::join_all(tasks).await;
tasks = Vec::new() tasks = Vec::new();
} }
} }
future::join_all(tasks).await;
Ok(()) Ok(())
} }