diff --git a/Cargo.toml b/Cargo.toml index dd4bde5..f435dbf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,10 +27,10 @@ clap-verbosity-flag = { version = "3", default-features = false, features = [ http = "1" miette = { version = "7", features = ["fancy"] } ring = { version = "0.17", features = ["std"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0.137" tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] } -tower-http = { version = "0.6", features = ["validate-request"] } +tower-http = { version = "0.6.2", features = ["validate-request"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/flake.lock b/flake.lock index 10a4609..31c94b6 100644 --- a/flake.lock +++ b/flake.lock @@ -37,11 +37,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1738680400, - "narHash": "sha256-ooLh+XW8jfa+91F1nhf9OF7qhuA/y1ChLx6lXDNeY5U=", + "lastModified": 1738546358, + "narHash": "sha256-nLivjIygCiqLp5QcL7l56Tca/elVqM9FG1hGd9ZSsrg=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "799ba5bffed04ced7067a91798353d360788b30d", + "rev": "c6e957d81b96751a3d5967a0fd73694f303cc914", "type": "github" }, "original": { diff --git a/src/main.rs b/src/main.rs index 7718175..bea4268 100644 --- a/src/main.rs +++ b/src/main.rs @@ -366,24 +366,24 @@ fn main() -> Result<()> { rt.block_on(async { // Update DNS record with previous IPs (if available) let ips = state.last_ips.lock().await.clone(); - - let actions = ips - .ips() - .filter(|ip| ip_type.valid_for_type(*ip)) - .flat_map(|ip| nsupdate::Action::from_records(ip, state.ttl, state.records)); - - match nsupdate::nsupdate(state.key_file, actions).await { - Ok(status) => { - if !status.success() { - error!("nsupdate failed: code {status}"); - bail!("nsupdate returned with code {status}"); - } + for ip in ips.ips() { + if !ip_type.valid_for_type(ip) { + continue; } - Err(err) => { - error!("Failed to update records with previous IP: {err}"); - return Err(err) - .into_diagnostic() - .wrap_err("failed to update records with previous IP"); + + match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { + Ok(status) => { + if !status.success() { + error!("nsupdate failed: code {status}"); + bail!("nsupdate returned with code {status}"); + } + } + Err(err) => { + error!("Failed to update records with previous IP: {err}"); + return Err(err) + .into_diagnostic() + .wrap_err("failed to update records with previous IP"); + } } } @@ -541,8 +541,7 @@ async fn trigger_update( ip: IpAddr, state: &AppState<'static>, ) -> axum::response::Result<&'static str> { - let actions = nsupdate::Action::from_records(ip, state.ttl, state.records); - match nsupdate::nsupdate(state.key_file, actions).await { + match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { Ok(status) if status.success() => { let ips = { // Update state diff --git a/src/nsupdate.rs b/src/nsupdate.rs index 74397fa..62395b7 100644 --- a/src/nsupdate.rs +++ b/src/nsupdate.rs @@ -9,51 +9,12 @@ use std::{ use tokio::io::AsyncWriteExt; use tracing::{debug, warn}; -pub enum Action<'a> { - // Reassign a domain to a different IP - Reassign { - domain: &'a str, - to: IpAddr, - ttl: Duration, - }, -} - -impl<'a> Action<'a> { - /// Create a set of [`Action`]s reassigning the domains in `records` to the specified - /// [`IpAddr`] - pub fn from_records( - to: IpAddr, - ttl: Duration, - records: &'a [&'a str], - ) -> impl IntoIterator + 'a { - records - .iter() - .map(move |&domain| Action::Reassign { domain, to, ttl }) - } -} - -impl std::fmt::Display for Action<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Action::Reassign { domain, to, ttl } => { - let ttl = ttl.as_secs(); - let typ = match to { - IpAddr::V4(_) => "A", - IpAddr::V6(_) => "AAAA", - }; - // Delete previous record of type `typ` - writeln!(f, "update delete {domain} {ttl} IN {typ}")?; - // Add record with new IP - writeln!(f, "update add {domain} {ttl} IN {typ} {to}") - } - } - } -} - -#[tracing::instrument(level = "trace", skip(actions), ret(level = "warn"))] +#[tracing::instrument(level = "trace", ret(level = "warn"))] pub async fn nsupdate( + ip: IpAddr, + ttl: Duration, key_file: Option<&Path>, - actions: impl IntoIterator>, + records: &[&str], ) -> std::io::Result { let mut cmd = tokio::process::Command::new("nsupdate"); if let Some(key_file) = key_file { @@ -66,13 +27,10 @@ pub async fn nsupdate( .inspect_err(|err| warn!("failed to spawn child: {err}"))?; let mut stdin = child.stdin.take().expect("stdin not present"); debug!("sending update request"); - let mut buf = Vec::new(); - update_ns_records(&mut buf, actions).unwrap(); stdin - .write_all(&buf) + .write_all(update_ns_records(ip, ttl, records).as_bytes()) .await .inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?; - debug!("closing stdin"); stdin .shutdown() @@ -85,16 +43,21 @@ pub async fn nsupdate( .inspect_err(|err| warn!("failed to wait for child: {err}")) } -fn update_ns_records<'a>( - mut buf: impl std::io::Write, - actions: impl IntoIterator>, -) -> std::io::Result<()> { - writeln!(buf, "server 127.0.0.1")?; - for action in actions { - writeln!(buf, "{action}")?; +fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String { + use std::fmt::Write; + let ttl_s: u64 = ttl.as_secs(); + + let rec_type = match ip { + IpAddr::V4(_) => "A", + IpAddr::V6(_) => "AAAA", + }; + let mut cmds = String::from("server 127.0.0.1\n"); + for &record in records { + writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap(); + writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap(); } - writeln!(buf, "send")?; - writeln!(buf, "quit") + writeln!(cmds, "send\nquit").unwrap(); + cmds } #[cfg(test)] @@ -103,21 +66,17 @@ mod test { use insta::assert_snapshot; - use super::{update_ns_records, Action}; + use super::update_ns_records; use crate::DEFAULT_TTL; #[test] #[allow(non_snake_case)] fn expected_update_string_A() { - let mut buf = Vec::new(); - let actions = Action::from_records( - IpAddr::V4(Ipv4Addr::LOCALHOST), - DEFAULT_TTL, - &["example.com.", "example.org.", "example.net."], - ); - update_ns_records(&mut buf, actions).unwrap(); - - assert_snapshot!(String::from_utf8(buf).unwrap(), @r###" + assert_snapshot!(update_ns_records( + IpAddr::V4(Ipv4Addr::LOCALHOST), + DEFAULT_TTL, + &["example.com.", "example.org.", "example.net."], + ), @r###" server 127.0.0.1 update delete example.com. 60 IN A update add example.com. 60 IN A 127.0.0.1 @@ -133,15 +92,11 @@ mod test { #[test] #[allow(non_snake_case)] fn expected_update_string_AAAA() { - let mut buf = Vec::new(); - let actions = Action::from_records( - IpAddr::V6(Ipv6Addr::LOCALHOST), - DEFAULT_TTL, - &["example.com.", "example.org.", "example.net."], - ); - update_ns_records(&mut buf, actions).unwrap(); - - assert_snapshot!(String::from_utf8(buf).unwrap(), @r###" + assert_snapshot!(update_ns_records( + IpAddr::V6(Ipv6Addr::LOCALHOST), + DEFAULT_TTL, + &["example.com.", "example.org.", "example.net."], + ), @r###" server 127.0.0.1 update delete example.com. 60 IN AAAA update add example.com. 60 IN AAAA ::1