diff --git a/Cargo.lock b/Cargo.lock index 859dee6..a6267f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1139,6 +1139,8 @@ dependencies = [ "insta", "miette", "ring", + "serde", + "serde_json", "tokio", "tower-http", "tracing", diff --git a/Cargo.toml b/Cargo.toml index 0bafe90..89197a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,8 @@ clap-verbosity-flag = { version = "3", default-features = false, features = [ http = "1" miette = { version = "7", features = ["fancy"] } ring = { version = "0.17", features = ["std"] } +serde = { version = "1.0.217", features = ["derive"] } +serde_json = "1.0.137" tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] } tower-http = { version = "0.6.2", features = ["validate-request"] } tracing = "0.1" diff --git a/src/main.rs b/src/main.rs index d2b9136..d909d3c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ use std::{ io::ErrorKind, - net::{IpAddr, SocketAddr}, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::{Path, PathBuf}, time::Duration, }; @@ -114,6 +114,48 @@ struct AppState<'a> { /// The file where the last IP is stored ip_file: &'a Path, + + /// Last recorded IPs + last_ips: std::sync::Arc>, +} + +#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)] +struct SavedIPs { + #[serde(skip_serializing_if = "Option::is_none")] + ipv4: Option, + #[serde(skip_serializing_if = "Option::is_none")] + ipv6: Option, +} + +impl SavedIPs { + fn update(&mut self, ip: IpAddr) { + match ip { + IpAddr::V4(ipv4_addr) => self.ipv4 = Some(ipv4_addr), + IpAddr::V6(ipv6_addr) => self.ipv6 = Some(ipv6_addr), + } + } + + fn ips(&self) -> impl Iterator { + self.ipv4 + .map(IpAddr::V4) + .into_iter() + .chain(self.ipv6.map(IpAddr::V6)) + } + + fn from_str(data: &str) -> miette::Result { + match data.parse::() { + // Old format + Ok(IpAddr::V4(ipv4)) => Ok(Self { + ipv4: Some(ipv4), + ipv6: None, + }), + Ok(IpAddr::V6(ipv6)) => Ok(Self { + ipv4: None, + ipv6: Some(ipv6), + }), + Err(_) => serde_json::from_str(data).into_diagnostic(), + } + } } impl AppState<'static> { @@ -137,7 +179,7 @@ impl AppState<'static> { let ttl = Duration::from_secs(*ttl); // Use last registered IP address if available - let ip_file = data_dir.join("last-ip"); + let ip_file = Box::leak(data_dir.join("last-ip").into_boxed_path()); let state = AppState { ttl, @@ -155,7 +197,10 @@ impl AppState<'static> { Ok(&*Box::leak(path.into())) }) .transpose()?, - ip_file: Box::leak(ip_file.into_boxed_path()), + ip_file, + last_ips: std::sync::Arc::new(tokio::sync::Mutex::new( + load_ip(ip_file)?.unwrap_or_default(), + )), }; ensure!( @@ -167,7 +212,7 @@ impl AppState<'static> { } } -fn load_ip(path: &Path) -> Result> { +fn load_ip(path: &Path) -> Result> { debug!("loading last IP from {}", path.display()); let data = match std::fs::read_to_string(path) { Ok(ip) => ip, @@ -181,11 +226,9 @@ fn load_ip(path: &Path) -> Result> { } }; - Ok(Some( - data.parse() - .into_diagnostic() - .wrap_err("failed to parse last ip address")?, - )) + SavedIPs::from_str(&data) + .wrap_err_with(|| format!("failed to load last ip address from {}", path.display())) + .map(Some) } #[tracing::instrument(err)] @@ -266,28 +309,24 @@ fn main() -> Result<()> { .wrap_err("failed to start the tokio runtime")?; rt.block_on(async { - // Load previous IP and update DNS record to point to it (if available) - match load_ip(state.ip_file) { - Ok(Some(ip)) => { - match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { - Ok(status) => { - if !status.success() { - error!("nsupdate failed: code {status}"); - bail!("nsupdate returned with code {status}"); - } - } - Err(err) => { - error!("Failed to update records with previous IP: {err}"); - return Err(err) - .into_diagnostic() - .wrap_err("failed to update records with previous IP"); + // Update DNS record with previous IPs (if available) + let ips = state.last_ips.lock().await.clone(); + for ip in ips.ips() { + match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { + Ok(status) => { + if !status.success() { + error!("nsupdate failed: code {status}"); + bail!("nsupdate returned with code {status}"); } } + Err(err) => { + error!("Failed to update records with previous IP: {err}"); + return Err(err) + .into_diagnostic() + .wrap_err("failed to update records with previous IP"); + } } - Ok(None) => info!("No previous IP address set"), - - Err(err) => error!("Ignoring previous IP due to: {err}"), - }; + } // Create services let app = Router::new().route("/update", get(update_records)); @@ -324,13 +363,22 @@ async fn update_records( info!("accepted update from {ip}"); match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { Ok(status) if status.success() => { + let ips = { + // Update state + let mut ips = state.last_ips.lock().await; + ips.update(ip); + ips.clone() + }; + tokio::task::spawn_blocking(move || { - info!("updating last ip to {ip}"); - if let Err(err) = std::fs::write(state.ip_file, format!("{ip}")) { + info!("updating last ips to {ips:?}"); + let data = serde_json::to_vec(&ips).expect("invalid serialization impl"); + if let Err(err) = std::fs::write(state.ip_file, data) { error!("Failed to update last IP: {err}"); } - info!("updated last ip to {ip}"); + info!("updated last ips to {ips:?}"); }); + Ok("successful update") } Ok(status) => {