refactor: reorganize main.rs #20

Merged
jalil merged 1 commit from refactor-main into main 2024-11-23 21:10:15 +01:00
2 changed files with 189 additions and 154 deletions
Showing only changes of commit 846a0675d1 - Show all commits

View file

@ -1,9 +1,7 @@
use std::{ use std::{
ffi::OsStr,
io::ErrorKind, io::ErrorKind,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
path::{Path, PathBuf}, path::{Path, PathBuf},
process::{ExitStatus, Stdio},
time::Duration, time::Duration,
}; };
@ -14,11 +12,11 @@ use clap::{Parser, Subcommand};
use clap_verbosity_flag::Verbosity; use clap_verbosity_flag::Verbosity;
use http::StatusCode; use http::StatusCode;
use miette::{bail, ensure, Context, IntoDiagnostic, Result}; use miette::{bail, ensure, Context, IntoDiagnostic, Result};
use tokio::io::AsyncWriteExt; use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
mod auth; mod auth;
mod nsupdate;
mod password; mod password;
mod records; mod records;
@ -118,6 +116,57 @@ struct AppState<'a> {
ip_file: &'a Path, ip_file: &'a Path,
} }
impl AppState<'static> {
fn from_args(args: &Opts) -> miette::Result<Self> {
let Opts {
verbosity: _,
address: _,
port: _,
password_file: _,
data_dir,
key_file,
insecure,
subcommand: _,
records,
salt: _,
ttl,
ip_source: _,
} = args;
// Set state
let ttl = Duration::from_secs(*ttl);
// Use last registered IP address if available
let ip_file = data_dir.join("last-ip");
let state = AppState {
ttl,
// Load DNS records
records: records::load_no_verify(records)?,
// Load keyfile
key_file: key_file
.as_deref()
.map(|path| -> miette::Result<_> {
std::fs::File::open(path)
.into_diagnostic()
.wrap_err_with(|| {
format!("{} is not readable by the current user", path.display())
})?;
Ok(&*Box::leak(path.into()))
})
.transpose()?,
ip_file: Box::leak(ip_file.into_boxed_path()),
};
ensure!(
state.key_file.is_some() || *insecure,
"a key file must be used"
);
Ok(state)
}
}
fn load_ip(path: &Path) -> Result<Option<IpAddr>> { fn load_ip(path: &Path) -> Result<Option<IpAddr>> {
debug!("loading last IP from {}", path.display()); debug!("loading last IP from {}", path.display());
let data = match std::fs::read_to_string(path) { let data = match std::fs::read_to_string(path) {
@ -166,33 +215,31 @@ fn main() -> Result<()> {
return cmd.process(&args); return cmd.process(&args);
} }
// Initialize state
let state = AppState::from_args(&args)?;
let Opts { let Opts {
verbosity: _, verbosity: _,
address: ip, address: ip,
port, port,
password_file, password_file,
data_dir, data_dir: _,
key_file, key_file: _,
insecure, insecure,
subcommand: _, subcommand: _,
records, records: _,
salt, salt,
ttl, ttl: _,
ip_source, ip_source,
} = args; } = args;
info!("checking environment"); info!("checking environment");
// Set state
let ttl = Duration::from_secs(ttl);
// Use last registered IP address if available
let ip_file = data_dir.join("last-ip");
// Load password hash // Load password hash
let password_hash = password_file let password_hash = password_file
.map(|path| -> miette::Result<_> { .map(|path| -> miette::Result<_> {
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?; let path = path.as_path();
let pass = std::fs::read_to_string(path).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes()) .decode(pass.trim().as_bytes())
@ -204,35 +251,11 @@ fn main() -> Result<()> {
}) })
.transpose()?; .transpose()?;
let state = AppState {
ttl,
// Load DNS records
records: records::load_no_verify(&records)?,
// Load keyfile
key_file: key_file
.map(|key_file| -> miette::Result<_> {
let path = key_file.as_path();
std::fs::File::open(path)
.into_diagnostic()
.wrap_err_with(|| {
format!("{} is not readable by the current user", path.display())
})?;
Ok(&*Box::leak(key_file.into_boxed_path()))
})
.transpose()?,
ip_file: Box::leak(ip_file.into_boxed_path()),
};
ensure!( ensure!(
password_hash.is_some() || insecure, password_hash.is_some() || insecure,
"a password must be used" "a password must be used"
); );
ensure!(
state.key_file.is_some() || insecure,
"a key file must be used"
);
let rt = tokio::runtime::Builder::new_current_thread() let rt = tokio::runtime::Builder::new_current_thread()
.enable_all() .enable_all()
.build() .build()
@ -242,20 +265,22 @@ fn main() -> Result<()> {
rt.block_on(async { rt.block_on(async {
// Load previous IP and update DNS record to point to it (if available) // Load previous IP and update DNS record to point to it (if available)
match load_ip(state.ip_file) { match load_ip(state.ip_file) {
Ok(Some(ip)) => match nsupdate(ip, ttl, state.key_file, state.records).await { Ok(Some(ip)) => {
Ok(status) => { match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await {
if !status.success() { Ok(status) => {
error!("nsupdate failed: code {status}"); if !status.success() {
bail!("nsupdate returned with code {status}"); error!("nsupdate failed: code {status}");
bail!("nsupdate returned with code {status}");
}
}
Err(err) => {
error!("Failed to update records with previous IP: {err}");
return Err(err)
.into_diagnostic()
.wrap_err("failed to update records with previous IP");
} }
} }
Err(err) => { }
error!("Failed to update records with previous IP: {err}");
return Err(err)
.into_diagnostic()
.wrap_err("failed to update records with previous IP");
}
},
Ok(None) => { Ok(None) => {
info!("No previous IP address set"); info!("No previous IP address set");
} }
@ -295,9 +320,8 @@ async fn update_records(
State(state): State<AppState<'static>>, State(state): State<AppState<'static>>,
SecureClientIp(ip): SecureClientIp, SecureClientIp(ip): SecureClientIp,
) -> axum::response::Result<&'static str> { ) -> axum::response::Result<&'static str> {
debug!("received update request from {ip}"); info!("accepted update from {ip}");
info!("accepted update"); match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await {
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) if status.success() => { Ok(status) if status.success() => {
tokio::task::spawn_blocking(move || { tokio::task::spawn_blocking(move || {
info!("updating last ip to {ip}"); info!("updating last ip to {ip}");
@ -323,103 +347,3 @@ async fn update_records(
.into()), .into()),
} }
} }
#[tracing::instrument(level = "trace", ret(level = "warn"))]
async fn nsupdate(
ip: IpAddr,
ttl: Duration,
key_file: Option<&Path>,
records: &[&str],
) -> std::io::Result<ExitStatus> {
let mut cmd = tokio::process::Command::new("nsupdate");
if let Some(key_file) = key_file {
cmd.args([OsStr::new("-k"), key_file.as_os_str()]);
}
debug!("spawning new process");
let mut child = cmd
.stdin(Stdio::piped())
.spawn()
.inspect_err(|err| warn!("failed to spawn child: {err}"))?;
let mut stdin = child.stdin.take().expect("stdin not present");
debug!("sending update request");
stdin
.write_all(update_ns_records(ip, ttl, records).as_bytes())
.await
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
debug!("closing stdin");
stdin
.shutdown()
.await
.inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?;
debug!("waiting for nsupdate to exit");
child
.wait()
.await
.inspect_err(|err| warn!("failed to wait for child: {err}"))
}
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
use std::fmt::Write;
let ttl_s: u64 = ttl.as_secs();
let rec_type = match ip {
IpAddr::V4(_) => "A",
IpAddr::V6(_) => "AAAA",
};
let mut cmds = String::from("server 127.0.0.1\n");
for &record in records {
writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap();
writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap();
}
writeln!(cmds, "send\nquit").unwrap();
cmds
}
#[cfg(test)]
mod test {
use insta::assert_snapshot;
use crate::{update_ns_records, DEFAULT_TTL};
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
#[test]
#[allow(non_snake_case)]
fn expected_update_string_A() {
assert_snapshot!(update_ns_records(
IpAddr::V4(Ipv4Addr::LOCALHOST),
DEFAULT_TTL,
&["example.com.", "example.org.", "example.net."],
), @r###"
server 127.0.0.1
update delete example.com. 60 IN A
update add example.com. 60 IN A 127.0.0.1
update delete example.org. 60 IN A
update add example.org. 60 IN A 127.0.0.1
update delete example.net. 60 IN A
update add example.net. 60 IN A 127.0.0.1
send
quit
"###);
}
#[test]
#[allow(non_snake_case)]
fn expected_update_string_AAAA() {
assert_snapshot!(update_ns_records(
IpAddr::V6(Ipv6Addr::LOCALHOST),
DEFAULT_TTL,
&["example.com.", "example.org.", "example.net."],
), @r###"
server 127.0.0.1
update delete example.com. 60 IN AAAA
update add example.com. 60 IN AAAA ::1
update delete example.org. 60 IN AAAA
update add example.org. 60 IN AAAA ::1
update delete example.net. 60 IN AAAA
update add example.net. 60 IN AAAA ::1
send
quit
"###);
}
}

111
src/nsupdate.rs Normal file
View file

@ -0,0 +1,111 @@
use std::{
ffi::OsStr,
net::IpAddr,
path::Path,
process::{ExitStatus, Stdio},
time::Duration,
};
use tokio::io::AsyncWriteExt;
use tracing::{debug, warn};
#[tracing::instrument(level = "trace", ret(level = "warn"))]
pub async fn nsupdate(
ip: IpAddr,
ttl: Duration,
key_file: Option<&Path>,
records: &[&str],
) -> std::io::Result<ExitStatus> {
let mut cmd = tokio::process::Command::new("nsupdate");
if let Some(key_file) = key_file {
cmd.args([OsStr::new("-k"), key_file.as_os_str()]);
}
debug!("spawning new process");
let mut child = cmd
.stdin(Stdio::piped())
.spawn()
.inspect_err(|err| warn!("failed to spawn child: {err}"))?;
let mut stdin = child.stdin.take().expect("stdin not present");
debug!("sending update request");
stdin
.write_all(update_ns_records(ip, ttl, records).as_bytes())
.await
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
debug!("closing stdin");
stdin
.shutdown()
.await
.inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?;
debug!("waiting for nsupdate to exit");
child
.wait()
.await
.inspect_err(|err| warn!("failed to wait for child: {err}"))
}
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
use std::fmt::Write;
let ttl_s: u64 = ttl.as_secs();
let rec_type = match ip {
IpAddr::V4(_) => "A",
IpAddr::V6(_) => "AAAA",
};
let mut cmds = String::from("server 127.0.0.1\n");
for &record in records {
writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap();
writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap();
}
writeln!(cmds, "send\nquit").unwrap();
cmds
}
#[cfg(test)]
mod test {
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use insta::assert_snapshot;
use super::update_ns_records;
use crate::DEFAULT_TTL;
#[test]
#[allow(non_snake_case)]
fn expected_update_string_A() {
assert_snapshot!(update_ns_records(
IpAddr::V4(Ipv4Addr::LOCALHOST),
DEFAULT_TTL,
&["example.com.", "example.org.", "example.net."],
), @r###"
server 127.0.0.1
update delete example.com. 60 IN A
update add example.com. 60 IN A 127.0.0.1
update delete example.org. 60 IN A
update add example.org. 60 IN A 127.0.0.1
update delete example.net. 60 IN A
update add example.net. 60 IN A 127.0.0.1
send
quit
"###);
}
#[test]
#[allow(non_snake_case)]
fn expected_update_string_AAAA() {
assert_snapshot!(update_ns_records(
IpAddr::V6(Ipv6Addr::LOCALHOST),
DEFAULT_TTL,
&["example.com.", "example.org.", "example.net."],
), @r###"
server 127.0.0.1
update delete example.com. 60 IN AAAA
update add example.com. 60 IN AAAA ::1
update delete example.org. 60 IN AAAA
update add example.org. 60 IN AAAA ::1
update delete example.net. 60 IN AAAA
update add example.net. 60 IN AAAA ::1
send
quit
"###);
}
}