use std::{ io::ErrorKind, net::{IpAddr, SocketAddr}, path::{Path, PathBuf}, time::Duration, }; use axum::{extract::State, routing::get, Router}; use axum_client_ip::{SecureClientIp, SecureClientIpSource}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use clap::{Parser, Subcommand}; use clap_verbosity_flag::Verbosity; use http::StatusCode; use miette::{bail, ensure, Context, IntoDiagnostic, Result}; use tracing::{debug, error, info}; use tracing_subscriber::EnvFilter; mod auth; mod nsupdate; mod password; mod records; const DEFAULT_TTL: Duration = Duration::from_secs(60); const DEFAULT_SALT: &str = "UpdateMyDNS"; #[derive(Debug, Parser)] struct Opts { #[command(flatten)] verbosity: Verbosity, /// Ip address of the server #[arg(long, default_value = "127.0.0.1")] address: IpAddr, /// Port of the server #[arg(long, default_value_t = 5353)] port: u16, /// File containing password to match against /// /// Should be of the format `username:password` and contain a single password #[arg(long)] password_file: Option, /// Salt to get more unique hashed passwords and prevent table based attacks #[arg(long, default_value = DEFAULT_SALT)] salt: String, /// Time To Live (in seconds) to set on the DNS records #[arg(long, default_value_t = DEFAULT_TTL.as_secs())] ttl: u64, /// Data directory #[arg(long, default_value = ".")] data_dir: PathBuf, /// File containing the records that should be updated when an update request is made /// /// There should be one record per line: /// /// ```text /// example.com. /// mail.example.com. /// ``` #[arg(long)] records: PathBuf, /// Keyfile `nsupdate` should use /// /// If specified, then `webnsupdate` must have read access to the file #[arg(long)] key_file: Option, /// Allow not setting a password #[arg(long)] insecure: bool, /// Set client IP source /// /// see: #[clap(long, default_value = "RightmostXForwardedFor")] ip_source: SecureClientIpSource, #[clap(subcommand)] subcommand: Option, } #[derive(Debug, Subcommand)] enum Cmd { Mkpasswd(password::Mkpasswd), /// Verify the records file Verify, } impl Cmd { pub fn process(self, args: &Opts) -> Result<()> { match self { Cmd::Mkpasswd(mkpasswd) => mkpasswd.process(args), Cmd::Verify => records::load(&args.records).map(drop), } } } #[derive(Clone)] struct AppState<'a> { /// TTL set on the Zonefile ttl: Duration, /// The IN A/AAAA records that should have their IPs updated records: &'a [&'a str], /// The TSIG key file key_file: Option<&'a Path>, /// The file where the last IP is stored ip_file: &'a Path, } impl AppState<'static> { fn from_args(args: &Opts) -> miette::Result { let Opts { verbosity: _, address: _, port: _, password_file: _, data_dir, key_file, insecure, subcommand: _, records, salt: _, ttl, ip_source: _, } = args; // Set state let ttl = Duration::from_secs(*ttl); // Use last registered IP address if available let ip_file = data_dir.join("last-ip"); let state = AppState { ttl, // Load DNS records records: records::load_no_verify(records)?, // Load keyfile key_file: key_file .as_deref() .map(|path| -> miette::Result<_> { std::fs::File::open(path) .into_diagnostic() .wrap_err_with(|| { format!("{} is not readable by the current user", path.display()) })?; Ok(&*Box::leak(path.into())) }) .transpose()?, ip_file: Box::leak(ip_file.into_boxed_path()), }; ensure!( state.key_file.is_some() || *insecure, "a key file must be used" ); Ok(state) } } fn load_ip(path: &Path) -> Result> { debug!("loading last IP from {}", path.display()); let data = match std::fs::read_to_string(path) { Ok(ip) => ip, Err(err) => { return match err.kind() { ErrorKind::NotFound => Ok(None), _ => Err(err).into_diagnostic().wrap_err_with(|| { format!("failed to load last ip address from {}", path.display()) }), } } }; Ok(Some( data.parse() .into_diagnostic() .wrap_err("failed to parse last ip address")?, )) } #[tracing::instrument(err)] fn main() -> Result<()> { // set panic hook to pretty print with miette's formatter miette::set_panic_hook(); // parse cli arguments let mut args = Opts::parse(); // configure logger let subscriber = tracing_subscriber::FmtSubscriber::builder() .without_time() .with_env_filter( EnvFilter::builder() .with_default_directive(args.verbosity.tracing_level_filter().into()) .from_env_lossy(), ) .finish(); tracing::subscriber::set_global_default(subscriber) .into_diagnostic() .wrap_err("failed to set global tracing subscriber")?; debug!("{args:?}"); // process subcommand if let Some(cmd) = args.subcommand.take() { return cmd.process(&args); } // Initialize state let state = AppState::from_args(&args)?; let Opts { verbosity: _, address: ip, port, password_file, data_dir: _, key_file: _, insecure, subcommand: _, records: _, salt, ttl: _, ip_source, } = args; info!("checking environment"); // Load password hash let password_hash = password_file .map(|path| -> miette::Result<_> { let path = path.as_path(); let pass = std::fs::read_to_string(path).into_diagnostic()?; let pass: Box<[u8]> = URL_SAFE_NO_PAD .decode(pass.trim().as_bytes()) .into_diagnostic() .wrap_err_with(|| format!("failed to decode password from {}", path.display()))? .into(); Ok(pass) }) .transpose() .wrap_err("failed to load password hash")?; ensure!( password_hash.is_some() || insecure, "a password must be used" ); let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() .into_diagnostic() .wrap_err("failed to start the tokio runtime")?; rt.block_on(async { // Load previous IP and update DNS record to point to it (if available) match load_ip(state.ip_file) { Ok(Some(ip)) => { match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { Ok(status) => { if !status.success() { error!("nsupdate failed: code {status}"); bail!("nsupdate returned with code {status}"); } } Err(err) => { error!("Failed to update records with previous IP: {err}"); return Err(err) .into_diagnostic() .wrap_err("failed to update records with previous IP"); } } } Ok(None) => info!("No previous IP address set"), Err(err) => error!("Ignoring previous IP due to: {err}"), }; // Create services let app = Router::new().route("/update", get(update_records)); // if a password is provided, validate it let app = if let Some(pass) = password_hash { app.layer(auth::layer(Box::leak(pass), String::leak(salt))) } else { app } .layer(ip_source.into_extension()) .with_state(state); // Start services info!("starting listener on {ip}:{port}"); let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port)) .await .into_diagnostic()?; info!("listening on {ip}:{port}"); axum::serve( listener, app.into_make_service_with_connect_info::(), ) .await .into_diagnostic() }) .wrap_err("failed to run main loop") } #[tracing::instrument(skip(state), level = "trace", ret(level = "info"))] async fn update_records( State(state): State>, SecureClientIp(ip): SecureClientIp, ) -> axum::response::Result<&'static str> { info!("accepted update from {ip}"); match nsupdate::nsupdate(ip, state.ttl, state.key_file, state.records).await { Ok(status) if status.success() => { tokio::task::spawn_blocking(move || { info!("updating last ip to {ip}"); if let Err(err) = std::fs::write(state.ip_file, format!("{ip}")) { error!("Failed to update last IP: {err}"); } info!("updated last ip to {ip}"); }); Ok("successful update") } Ok(status) => { error!("nsupdate failed with code {status}"); Err(( StatusCode::INTERNAL_SERVER_ERROR, "nsupdate failed, check server logs", ) .into()) } Err(error) => Err(( StatusCode::INTERNAL_SERVER_ERROR, format!("failed to update records: {error}"), ) .into()), } }