webnsupdate/src/main.rs

426 lines
12 KiB
Rust
Raw Normal View History

2024-05-03 20:29:10 +02:00
use std::{
ffi::OsStr,
io::ErrorKind,
2024-05-03 20:29:10 +02:00
net::{IpAddr, SocketAddr},
path::{Path, PathBuf},
process::{ExitStatus, Stdio},
time::Duration,
};
use axum::{extract::State, routing::get, Router};
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
2024-05-03 20:29:10 +02:00
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use clap::{Parser, Subcommand};
2024-10-26 13:23:22 +02:00
use clap_verbosity_flag::Verbosity;
2024-05-03 20:29:10 +02:00
use http::StatusCode;
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
2024-05-03 20:29:10 +02:00
use tokio::io::AsyncWriteExt;
use tracing::{debug, error, info, warn};
2024-05-03 20:29:10 +02:00
use tracing_subscriber::EnvFilter;
mod auth;
mod password;
mod records;
2024-05-03 20:29:10 +02:00
/// Default time to live of the updated DNS records (default of `--ttl`)
const DEFAULT_TTL: Duration = Duration::from_secs(60);
/// Default salt used when hashing passwords (default of `--salt`)
const DEFAULT_SALT: &str = "UpdateMyDNS";
#[derive(Debug, Parser)]
struct Opts {
2024-10-26 13:23:22 +02:00
#[command(flatten)]
verbosity: Verbosity<clap_verbosity_flag::WarnLevel>,
2024-10-26 13:23:22 +02:00
2024-05-03 20:29:10 +02:00
/// Ip address of the server
#[arg(long, default_value = "127.0.0.1")]
address: IpAddr,
2024-05-03 20:29:10 +02:00
/// Port of the server
#[arg(long, default_value_t = 5353)]
port: u16,
2024-05-03 20:29:10 +02:00
/// File containing password to match against
///
/// Should be of the format `username:password` and contain a single password
#[arg(long)]
password_file: Option<PathBuf>,
2024-05-03 20:29:10 +02:00
/// Salt to get more unique hashed passwords and prevent table based attacks
#[arg(long, default_value = DEFAULT_SALT)]
salt: String,
2024-05-03 20:29:10 +02:00
/// Time To Live (in seconds) to set on the DNS records
#[arg(long, default_value_t = DEFAULT_TTL.as_secs())]
ttl: u64,
/// Data directory
#[arg(long, default_value = ".")]
data_dir: PathBuf,
2024-05-03 20:29:10 +02:00
/// File containing the records that should be updated when an update request is made
///
/// There should be one record per line:
///
/// ```text
/// example.com.
/// mail.example.com.
/// ```
#[arg(long)]
records: PathBuf,
2024-05-03 20:29:10 +02:00
/// Keyfile `nsupdate` should use
///
/// If specified, then `webnsupdate` must have read access to the file
#[arg(long)]
key_file: Option<PathBuf>,
/// Allow not setting a password
2024-05-03 20:29:10 +02:00
#[arg(long)]
insecure: bool,
/// Set client IP source
///
/// see: https://docs.rs/axum-client-ip/latest/axum_client_ip/enum.SecureClientIpSource.html
#[clap(long, default_value = "RightmostXForwardedFor")]
ip_source: SecureClientIpSource,
2024-05-03 20:29:10 +02:00
#[clap(subcommand)]
subcommand: Option<Cmd>,
}
#[derive(Debug, Subcommand)]
enum Cmd {
Mkpasswd(password::Mkpasswd),
2024-05-03 20:29:10 +02:00
/// Verify the records file
Verify,
}
impl Cmd {
pub fn process(self, args: &Opts) -> Result<()> {
match self {
Cmd::Mkpasswd(mkpasswd) => mkpasswd.process(args),
Cmd::Verify => records::load(&args.records).map(drop),
}
}
}
2024-05-03 20:29:10 +02:00
/// Configuration shared with every request handler.
///
/// Every field is a plain value or a shared reference, so cloning is cheap.
#[derive(Clone)]
struct AppState<'a> {
    /// Time to live written on every record that is (re)created
    ttl: Duration,

    /// Names of the IN A/AAAA records whose address is replaced on update
    records: &'a [&'a str],

    /// Optional TSIG key file, handed to `nsupdate -k`
    key_file: Option<&'a Path>,

    /// Location where the most recently accepted client IP is persisted
    ip_file: &'a Path,
}
fn load_ip(path: &Path) -> Result<Option<IpAddr>> {
2024-10-26 13:23:22 +02:00
debug!("loading last IP from {}", path.display());
let data = match std::fs::read_to_string(path) {
Ok(ip) => ip,
Err(err) => {
return match err.kind() {
ErrorKind::NotFound => Ok(None),
_ => Err(err).into_diagnostic().wrap_err_with(|| {
format!("failed to load last ip address from {}", path.display())
}),
}
}
};
Ok(Some(
data.parse()
.into_diagnostic()
.wrap_err("failed to parse last ip address")?,
))
2024-05-03 20:29:10 +02:00
}
fn main() -> Result<()> {
// set panic hook to pretty print with miette's formatter
2024-05-03 20:29:10 +02:00
miette::set_panic_hook();
// parse cli arguments
let mut args = Opts::parse();
// configure logger
2024-05-03 20:29:10 +02:00
let subscriber = tracing_subscriber::FmtSubscriber::builder()
.without_time()
.with_env_filter(
EnvFilter::builder()
.with_default_directive(args.verbosity.tracing_level_filter().into())
2024-05-03 20:29:10 +02:00
.from_env_lossy(),
)
.finish();
tracing::subscriber::set_global_default(subscriber)
.into_diagnostic()
.wrap_err("setting global tracing subscriber")?;
debug!("{args:?}");
// process subcommand
if let Some(cmd) = args.subcommand.take() {
return cmd.process(&args);
2024-05-03 20:29:10 +02:00
}
let Opts {
2024-10-26 13:23:22 +02:00
verbosity: _,
address: ip,
port,
password_file,
data_dir,
key_file,
insecure,
subcommand: _,
records,
salt,
ttl,
ip_source,
} = args;
2024-05-03 20:29:10 +02:00
info!("checking environment");
2024-05-03 20:29:10 +02:00
// Set state
let ttl = Duration::from_secs(ttl);
// Use last registered IP address if available
let ip_file = data_dir.join("last-ip");
// Load password hash
let password_hash = password_file
.map(|path| -> miette::Result<_> {
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes())
.into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into();
Ok(pass)
})
.transpose()?;
let state = AppState {
2024-05-03 20:29:10 +02:00
ttl,
// Load DNS records
records: records::load_no_verify(&records)?,
// Load keyfile
key_file: key_file
.map(|key_file| -> miette::Result<_> {
let path = key_file.as_path();
std::fs::File::open(path)
.into_diagnostic()
.wrap_err_with(|| {
format!("{} is not readable by the current user", path.display())
})?;
Ok(&*Box::leak(key_file.into_boxed_path()))
})
.transpose()?,
ip_file: Box::leak(ip_file.into_boxed_path()),
2024-05-03 20:29:10 +02:00
};
ensure!(
password_hash.is_some() || insecure,
"a password must be used"
);
ensure!(
state.key_file.is_some() || insecure,
"a key file must be used"
);
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
2024-05-03 20:29:10 +02:00
.into_diagnostic()
.wrap_err("failed to start the tokio runtime")?;
rt.block_on(async {
// Load previous IP and update DNS record to point to it (if available)
match load_ip(state.ip_file) {
Ok(Some(ip)) => match nsupdate(ip, ttl, state.key_file, state.records).await {
Ok(status) => {
if !status.success() {
error!("nsupdate failed: code {status}");
bail!("nsupdate returned with code {status}");
}
}
Err(err) => {
error!("Failed to update records with previous IP: {err}");
return Err(err)
.into_diagnostic()
.wrap_err("failed to update records with previous IP");
}
},
Ok(None) => {
info!("No previous IP address set");
}
Err(err) => {
error!("Failed to load last ip address: {err}")
}
};
// Create services
let app = Router::new().route("/update", get(update_records));
// if a password is provided, validate it
let app = if let Some(pass) = password_hash {
app.layer(auth::auth_layer(Box::leak(pass), String::leak(salt)))
} else {
app
}
.layer(ip_source.into_extension())
.with_state(state);
// Start services
info!("starting listener on {ip}:{port}");
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
.await
.into_diagnostic()?;
info!("listening on {ip}:{port}");
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
2024-05-03 20:29:10 +02:00
.await
.into_diagnostic()
})
2024-05-03 20:29:10 +02:00
}
#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
2024-05-03 20:29:10 +02:00
async fn update_records(
State(state): State<AppState<'static>>,
SecureClientIp(ip): SecureClientIp,
2024-05-03 20:29:10 +02:00
) -> axum::response::Result<&'static str> {
2024-10-26 13:23:22 +02:00
debug!("received update request from {ip}");
info!("accepted update");
2024-05-03 20:29:10 +02:00
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) if status.success() => {
tokio::task::spawn_blocking(move || {
2024-10-26 13:23:22 +02:00
info!("updating last ip to {ip}");
if let Err(err) = std::fs::write(state.ip_file, format!("{ip}")) {
error!("Failed to update last IP: {err}");
}
2024-10-26 13:23:22 +02:00
info!("updated last ip to {ip}");
});
Ok("successful update")
}
2024-05-03 20:29:10 +02:00
Ok(status) => {
error!("nsupdate failed with code {status}");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"nsupdate failed, check server logs",
)
.into())
2024-05-03 20:29:10 +02:00
}
Err(error) => Err((
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to update records: {error}"),
)
.into()),
}
}
#[tracing::instrument(level = "trace", ret(level = "warn"))]
2024-05-03 20:29:10 +02:00
async fn nsupdate(
ip: IpAddr,
ttl: Duration,
key_file: Option<&Path>,
records: &[&str],
) -> std::io::Result<ExitStatus> {
let mut cmd = tokio::process::Command::new("nsupdate");
if let Some(key_file) = key_file {
cmd.args([OsStr::new("-k"), key_file.as_os_str()]);
}
debug!("spawning new process");
let mut child = cmd
.stdin(Stdio::piped())
.spawn()
.inspect_err(|err| warn!("failed to spawn child: {err}"))?;
2024-05-03 20:29:10 +02:00
let mut stdin = child.stdin.take().expect("stdin not present");
debug!("sending update request");
2024-05-03 20:29:10 +02:00
stdin
.write_all(update_ns_records(ip, ttl, records).as_bytes())
.await
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
debug!("closing stdin");
stdin
.shutdown()
.await
.inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?;
debug!("waiting for nsupdate to exit");
child
.wait()
.await
.inspect_err(|err| warn!("failed to wait for child: {err}"))
2024-05-03 20:29:10 +02:00
}
/// Build the `nsupdate` script that re-points every name in `records` at `ip`.
///
/// IPv4 addresses produce `A` records and IPv6 addresses produce `AAAA`
/// records. The TTL is written in whole seconds, truncating any sub-second
/// part. The script targets the server at `127.0.0.1`. For each name it
/// deletes the existing records of that type and adds the new one, then ends
/// with `send` and `quit`.
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
    use std::fmt::Write;

    let secs = ttl.as_secs();
    let kind = if ip.is_ipv4() { "A" } else { "AAAA" };

    let mut script = records
        .iter()
        .fold(String::from("server 127.0.0.1\n"), |mut acc, name| {
            // Formatting into a `String` cannot fail.
            writeln!(acc, "update delete {name} {secs} IN {kind}").unwrap();
            writeln!(acc, "update add {name} {secs} IN {kind} {ip}").unwrap();
            acc
        });
    script.push_str("send\nquit\n");
    script
}
#[cfg(test)]
mod test {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};

    use insta::assert_snapshot;

    use crate::{update_ns_records, DEFAULT_TTL};

    // Test names end in the DNS record type (`A`/`AAAA`), which is upper case.
    #[test]
    #[allow(non_snake_case)]
    fn expected_update_string_A() {
        assert_snapshot!(update_ns_records(
            IpAddr::V4(Ipv4Addr::LOCALHOST),
            DEFAULT_TTL,
            &["example.com.", "example.org.", "example.net."],
        ), @r###"
        server 127.0.0.1
        update delete example.com. 60 IN A
        update add example.com. 60 IN A 127.0.0.1
        update delete example.org. 60 IN A
        update add example.org. 60 IN A 127.0.0.1
        update delete example.net. 60 IN A
        update add example.net. 60 IN A 127.0.0.1
        send
        quit
        "###);
    }

    #[test]
    #[allow(non_snake_case)]
    fn expected_update_string_AAAA() {
        assert_snapshot!(update_ns_records(
            IpAddr::V6(Ipv6Addr::LOCALHOST),
            DEFAULT_TTL,
            &["example.com.", "example.org.", "example.net."],
        ), @r###"
        server 127.0.0.1
        update delete example.com. 60 IN AAAA
        update add example.com. 60 IN AAAA ::1
        update delete example.org. 60 IN AAAA
        update add example.org. 60 IN AAAA ::1
        update delete example.net. 60 IN AAAA
        update add example.net. 60 IN AAAA ::1
        send
        quit
        "###);
    }

    /// With no records the script still selects the server, sends, and quits.
    #[test]
    fn expected_update_string_without_records() {
        assert_eq!(
            update_ns_records(IpAddr::V4(Ipv4Addr::LOCALHOST), DEFAULT_TTL, &[]),
            "server 127.0.0.1\nsend\nquit\n"
        );
    }
}