diff --git a/src/main.rs b/src/main.rs index bd7aa91..193df24 100644 --- a/src/main.rs +++ b/src/main.rs @@ -128,7 +128,7 @@ impl SavedIPs { .chain(self.ipv6.map(IpAddr::V6)) } - fn from_str(data: &str) -> miette::Result { + fn from_str(data: &str) -> Result { match data.parse::() { // Old format Ok(IpAddr::V4(ipv4)) => Ok(Self { @@ -145,7 +145,7 @@ impl SavedIPs { } impl AppState<'static> { - fn from_args(args: &Opts, config: &config::Config) -> miette::Result { + fn from_args(args: &Opts, config: &config::Config) -> Result { let Opts { verbosity: _, data_dir, @@ -180,7 +180,7 @@ impl AppState<'static> { // Load keyfile key_file: key_file .as_deref() - .map(|path| -> miette::Result<_> { + .map(|path| -> Result<_> { std::fs::File::open(path) .into_diagnostic() .wrap_err_with(|| { @@ -255,6 +255,18 @@ impl std::str::FromStr for Ipv6Prefix { } } +fn load_password(path: &Path) -> Result> { + let pass = std::fs::read_to_string(path).into_diagnostic()?; + + let pass: Box<[u8]> = URL_SAFE_NO_PAD + .decode(pass.trim().as_bytes()) + .into_diagnostic() + .wrap_err_with(|| format!("failed to decode password from {}", path.display()))? + .into(); + + Ok(pass) +} + #[tracing::instrument(err)] fn main() -> Result<()> { // set panic hook to pretty print with miette's formatter @@ -310,18 +322,8 @@ fn main() -> Result<()> { let password_hash = config .password .password_file - .map(|path| -> miette::Result<_> { - let path = path.as_path(); - let pass = std::fs::read_to_string(path).into_diagnostic()?; - - let pass: Box<[u8]> = URL_SAFE_NO_PAD - .decode(pass.trim().as_bytes()) - .into_diagnostic() - .wrap_err_with(|| format!("failed to decode password from {}", path.display()))? - .into(); - - Ok(pass) - }) + .as_deref() + .map(load_password) .transpose() .wrap_err("failed to load password hash")?; @@ -336,63 +338,70 @@ fn main() -> Result<()> { .into_diagnostic() .wrap_err("failed to start the tokio runtime")?; - rt.block_on(async { - // Update DNS record with previous IPs (if available) - let ips = state.last_ips.lock().await.clone(); + rt.block_on(async_main(state, config, password_hash)) + .wrap_err("failed to run main loop") +} - let mut actions = ips - .ips() - .filter(|ip| config.records.ip_type.valid_for_type(*ip)) - .flat_map(|ip| nsupdate::Action::from_records(ip, state.ttl, state.records)) - .peekable(); +#[tracing::instrument(err, skip(state, pass))] +async fn async_main( + state: AppState<'static>, + config: Config, + pass: Option>, +) -> Result<()> { + // Update DNS record with previous IPs (if available) + let ips = state.last_ips.lock().await.clone(); - if actions.peek().is_some() { - match nsupdate::nsupdate(state.key_file, actions).await { - Ok(status) => { - if !status.success() { - error!("nsupdate failed: code {status}"); - bail!("nsupdate returned with code {status}"); - } - } - Err(err) => { - error!("Failed to update records with previous IP: {err}"); - return Err(err) - .into_diagnostic() - .wrap_err("failed to update records with previous IP"); + let mut actions = ips + .ips() + .filter(|ip| config.records.ip_type.valid_for_type(*ip)) + .flat_map(|ip| nsupdate::Action::from_records(ip, state.ttl, state.records)) + .peekable(); + + if actions.peek().is_some() { + match nsupdate::nsupdate(state.key_file, actions).await { + Ok(status) => { + if !status.success() { + error!("nsupdate failed: code {status}"); + bail!("nsupdate returned with code {status}"); } } + Err(err) => { + error!("Failed to update records with previous IP: {err}"); + return Err(err) + .into_diagnostic() + .wrap_err("failed to update records with previous IP"); + } } + } - // Create services - let app = Router::new().route("/update", get(update_records)); - // if a password is provided, validate it - let app = if let Some(pass) = password_hash { - app.layer(auth::layer( - Box::leak(pass), - Box::leak(config.password.salt), - )) - } else { - app - } - .layer(config.records.ip_source.into_extension()) - .with_state(state); + // Create services + let app = Router::new().route("/update", get(update_records)); + // if a password is provided, validate it + let app = if let Some(pass) = pass { + app.layer(auth::layer( + Box::leak(pass), + Box::leak(config.password.salt), + )) + } else { + app + } + .layer(config.records.ip_source.into_extension()) + .with_state(state); - let config::Server { address } = config.server; + let config::Server { address } = config.server; - // Start services - info!("starting listener on {address}"); - let listener = tokio::net::TcpListener::bind(address) - .await - .into_diagnostic()?; - info!("listening on {address}"); - axum::serve( - listener, - app.into_make_service_with_connect_info::(), - ) + // Start services + info!("starting listener on {address}"); + let listener = tokio::net::TcpListener::bind(address) .await - .into_diagnostic() - }) - .wrap_err("failed to run main loop") + .into_diagnostic()?; + info!("listening on {address}"); + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) + .await + .into_diagnostic() } /// Serde deserialization decorator to map empty Strings to None,