From 846a0675d1769fad4d559c66734255c636c390f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jalil=20David=20Salam=C3=A9=20Messina?= Date: Sat, 23 Nov 2024 20:56:03 +0100 Subject: [PATCH] refactor: reorganize main.rs --- src/main.rs | 232 ++++++++++++++++-------------------------------- src/nsupdate.rs | 111 +++++++++++++++++++++++ 2 files changed, 189 insertions(+), 154 deletions(-) create mode 100644 src/nsupdate.rs diff --git a/src/main.rs b/src/main.rs index b850e02..07c06f5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,7 @@ use std::{ - ffi::OsStr, io::ErrorKind, net::{IpAddr, SocketAddr}, path::{Path, PathBuf}, - process::{ExitStatus, Stdio}, time::Duration, }; @@ -14,11 +12,11 @@ use clap::{Parser, Subcommand}; use clap_verbosity_flag::Verbosity; use http::StatusCode; use miette::{bail, ensure, Context, IntoDiagnostic, Result}; -use tokio::io::AsyncWriteExt; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, info}; use tracing_subscriber::EnvFilter; mod auth; +mod nsupdate; mod password; mod records; @@ -118,6 +116,57 @@ struct AppState<'a> { ip_file: &'a Path, } +impl AppState<'static> { + fn from_args(args: &Opts) -> miette::Result { + let Opts { + verbosity: _, + address: _, + port: _, + password_file: _, + data_dir, + key_file, + insecure, + subcommand: _, + records, + salt: _, + ttl, + ip_source: _, + } = args; + + // Set state + let ttl = Duration::from_secs(*ttl); + + // Use last registered IP address if available + let ip_file = data_dir.join("last-ip"); + + let state = AppState { + ttl, + // Load DNS records + records: records::load_no_verify(records)?, + // Load keyfile + key_file: key_file + .as_deref() + .map(|path| -> miette::Result<_> { + std::fs::File::open(path) + .into_diagnostic() + .wrap_err_with(|| { + format!("{} is not readable by the current user", path.display()) + })?; + Ok(&*Box::leak(path.into())) + }) + .transpose()?, + ip_file: Box::leak(ip_file.into_boxed_path()), + }; + + ensure!( + state.key_file.is_some() || *insecure, + "a key file must be used" + ); + + Ok(state) + } +} + fn load_ip(path: &Path) -> Result> { debug!("loading last IP from {}", path.display()); let data = match std::fs::read_to_string(path) { @@ -166,33 +215,31 @@ fn main() -> Result<()> { return cmd.process(&args); } + // Initialize state + let state = AppState::from_args(&args)?; + let Opts { verbosity: _, address: ip, port, password_file, - data_dir, - key_file, + data_dir: _, + key_file: _, insecure, subcommand: _, - records, + records: _, salt, - ttl, + ttl: _, ip_source, } = args; info!("checking environment"); - // Set state - let ttl = Duration::from_secs(ttl); - - // Use last registered IP address if available - let ip_file = data_dir.join("last-ip"); - // Load password hash let password_hash = password_file .map(|path| -> miette::Result<_> { - let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?; + let path = path.as_path(); + let pass = std::fs::read_to_string(path).into_diagnostic()?; let pass: Box<[u8]> = URL_SAFE_NO_PAD .decode(pass.trim().as_bytes()) @@ -204,35 +251,11 @@ fn main() -> Result<()> { }) .transpose()?; - let state = AppState { - ttl, - // Load DNS records - records: records::load_no_verify(&records)?, - // Load keyfile - key_file: key_file - .map(|key_file| -> miette::Result<_> { - let path = key_file.as_path(); - std::fs::File::open(path) - .into_diagnostic() - .wrap_err_with(|| { - format!("{} is not readable by the current user", path.display()) - })?; - Ok(&*Box::leak(key_file.into_boxed_path())) - }) - .transpose()?, - ip_file: Box::leak(ip_file.into_boxed_path()), - }; - ensure!( password_hash.is_some() || insecure, "a password must be used" ); - ensure!( - state.key_file.is_some() || insecure, - "a key file must be used" - ); - let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -242,20 +265,22 @@ fn main() -> Result<()> { rt.block_on(async { // Load previous IP and update DNS record to point to it (if available) match load_ip(state.ip_file) { - Ok(Some(ip)) => match nsupdate(ip, ttl, state.key_file, state.records).await { - Ok(status) => { - if !status.success() { - error!("nsupdate failed: code {status}"); - bail!("nsupdate returned with code {status}"); + Ok(Some(ip)) => { + match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { + Ok(status) => { + if !status.success() { + error!("nsupdate failed: code {status}"); + bail!("nsupdate returned with code {status}"); + } + } + Err(err) => { + error!("Failed to update records with previous IP: {err}"); + return Err(err) + .into_diagnostic() + .wrap_err("failed to update records with previous IP"); } } - Err(err) => { - error!("Failed to update records with previous IP: {err}"); - return Err(err) - .into_diagnostic() - .wrap_err("failed to update records with previous IP"); - } - }, + } Ok(None) => { info!("No previous IP address set"); } @@ -295,9 +320,8 @@ async fn update_records( State(state): State>, SecureClientIp(ip): SecureClientIp, ) -> axum::response::Result<&'static str> { - debug!("received update request from {ip}"); - info!("accepted update"); - match nsupdate(ip, state.ttl, state.key_file, state.records).await { + info!("accepted update from {ip}"); + match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { Ok(status) if status.success() => { tokio::task::spawn_blocking(move || { info!("updating last ip to {ip}"); @@ -323,103 +347,3 @@ async fn update_records( .into()), } } - -#[tracing::instrument(level = "trace", ret(level = "warn"))] -async fn nsupdate( - ip: IpAddr, - ttl: Duration, - key_file: Option<&Path>, - records: &[&str], -) -> std::io::Result { - let mut cmd = tokio::process::Command::new("nsupdate"); - if let Some(key_file) = key_file { - cmd.args([OsStr::new("-k"), key_file.as_os_str()]); - } - debug!("spawning new process"); - let mut child = cmd - .stdin(Stdio::piped()) - .spawn() - .inspect_err(|err| warn!("failed to spawn child: {err}"))?; - let mut stdin = child.stdin.take().expect("stdin not present"); - debug!("sending update request"); - stdin - .write_all(update_ns_records(ip, ttl, records).as_bytes()) - .await - .inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?; - debug!("closing stdin"); - stdin - .shutdown() - .await - .inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?; - debug!("waiting for nsupdate to exit"); - child - .wait() - .await - .inspect_err(|err| warn!("failed to wait for child: {err}")) -} - -fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String { - use std::fmt::Write; - let ttl_s: u64 = ttl.as_secs(); - - let rec_type = match ip { - IpAddr::V4(_) => "A", - IpAddr::V6(_) => "AAAA", - }; - let mut cmds = String::from("server 127.0.0.1\n"); - for &record in records { - writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap(); - writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap(); - } - writeln!(cmds, "send\nquit").unwrap(); - cmds -} - -#[cfg(test)] -mod test { - use insta::assert_snapshot; - - use crate::{update_ns_records, DEFAULT_TTL}; - - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - - #[test] - #[allow(non_snake_case)] - fn expected_update_string_A() { - assert_snapshot!(update_ns_records( - IpAddr::V4(Ipv4Addr::LOCALHOST), - DEFAULT_TTL, - &["example.com.", "example.org.", "example.net."], - ), @r###" - server 127.0.0.1 - update delete example.com. 60 IN A - update add example.com. 60 IN A 127.0.0.1 - update delete example.org. 60 IN A - update add example.org. 60 IN A 127.0.0.1 - update delete example.net. 60 IN A - update add example.net. 60 IN A 127.0.0.1 - send - quit - "###); - } - - #[test] - #[allow(non_snake_case)] - fn expected_update_string_AAAA() { - assert_snapshot!(update_ns_records( - IpAddr::V6(Ipv6Addr::LOCALHOST), - DEFAULT_TTL, - &["example.com.", "example.org.", "example.net."], - ), @r###" - server 127.0.0.1 - update delete example.com. 60 IN AAAA - update add example.com. 60 IN AAAA ::1 - update delete example.org. 60 IN AAAA - update add example.org. 60 IN AAAA ::1 - update delete example.net. 60 IN AAAA - update add example.net. 60 IN AAAA ::1 - send - quit - "###); - } -} diff --git a/src/nsupdate.rs b/src/nsupdate.rs new file mode 100644 index 0000000..62395b7 --- /dev/null +++ b/src/nsupdate.rs @@ -0,0 +1,111 @@ +use std::{ + ffi::OsStr, + net::IpAddr, + path::Path, + process::{ExitStatus, Stdio}, + time::Duration, +}; + +use tokio::io::AsyncWriteExt; +use tracing::{debug, warn}; + +#[tracing::instrument(level = "trace", ret(level = "warn"))] +pub async fn nsupdate( + ip: IpAddr, + ttl: Duration, + key_file: Option<&Path>, + records: &[&str], +) -> std::io::Result { + let mut cmd = tokio::process::Command::new("nsupdate"); + if let Some(key_file) = key_file { + cmd.args([OsStr::new("-k"), key_file.as_os_str()]); + } + debug!("spawning new process"); + let mut child = cmd + .stdin(Stdio::piped()) + .spawn() + .inspect_err(|err| warn!("failed to spawn child: {err}"))?; + let mut stdin = child.stdin.take().expect("stdin not present"); + debug!("sending update request"); + stdin + .write_all(update_ns_records(ip, ttl, records).as_bytes()) + .await + .inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?; + debug!("closing stdin"); + stdin + .shutdown() + .await + .inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?; + debug!("waiting for nsupdate to exit"); + child + .wait() + .await + .inspect_err(|err| warn!("failed to wait for child: {err}")) +} + +fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String { + use std::fmt::Write; + let ttl_s: u64 = ttl.as_secs(); + + let rec_type = match ip { + IpAddr::V4(_) => "A", + IpAddr::V6(_) => "AAAA", + }; + let mut cmds = String::from("server 127.0.0.1\n"); + for &record in records { + writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap(); + writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap(); + } + writeln!(cmds, "send\nquit").unwrap(); + cmds +} + +#[cfg(test)] +mod test { + use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; + + use insta::assert_snapshot; + + use super::update_ns_records; + use crate::DEFAULT_TTL; + + #[test] + #[allow(non_snake_case)] + fn expected_update_string_A() { + assert_snapshot!(update_ns_records( + IpAddr::V4(Ipv4Addr::LOCALHOST), + DEFAULT_TTL, + &["example.com.", "example.org.", "example.net."], + ), @r###" + server 127.0.0.1 + update delete example.com. 60 IN A + update add example.com. 60 IN A 127.0.0.1 + update delete example.org. 60 IN A + update add example.org. 60 IN A 127.0.0.1 + update delete example.net. 60 IN A + update add example.net. 60 IN A 127.0.0.1 + send + quit + "###); + } + + #[test] + #[allow(non_snake_case)] + fn expected_update_string_AAAA() { + assert_snapshot!(update_ns_records( + IpAddr::V6(Ipv6Addr::LOCALHOST), + DEFAULT_TTL, + &["example.com.", "example.org.", "example.net."], + ), @r###" + server 127.0.0.1 + update delete example.com. 60 IN AAAA + update add example.com. 60 IN AAAA ::1 + update delete example.org. 60 IN AAAA + update add example.org. 60 IN AAAA ::1 + update delete example.net. 60 IN AAAA + update add example.net. 60 IN AAAA ::1 + send + quit + "###); + } +}