[fix] everything: various bugs found in production

This commit is contained in:
Jalil David Salamé Messina 2024-05-09 00:59:43 +02:00
parent 68658bf83f
commit 15e2d2da06
Signed by: jalil
GPG key ID: F016B9E770737A0B
7 changed files with 138 additions and 178 deletions

View file

@ -8,19 +8,16 @@ use std::{
time::Duration,
};
use axum::{
extract::{ConnectInfo, State},
routing::get,
Json, Router,
};
use axum::{extract::State, routing::get, Json, Router};
use axum_auth::AuthBasic;
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use clap::{Args, Parser, Subcommand};
use http::StatusCode;
use miette::{ensure, miette, Context, IntoDiagnostic, LabeledSpan, NamedSource, Result};
use ring::digest::Digest;
use tokio::io::AsyncWriteExt;
use tracing::{info, level_filters::LevelFilter, warn};
use tracing::{debug, error, info, level_filters::LevelFilter, trace, warn};
use tracing_subscriber::EnvFilter;
const DEFAULT_TTL: Duration = Duration::from_secs(60);
@ -60,9 +57,14 @@ struct Opts {
/// If specified, then `webnsupdate` must have read access to the file
#[arg(long)]
key_file: Option<PathBuf>,
/// Allow not setting a password when the server is exposed to the network
/// Allow not setting a password
#[arg(long)]
insecure: bool,
/// Set client IP source
///
/// see: https://docs.rs/axum-client-ip/latest/axum_client_ip/enum.SecureClientIpSource.html
#[clap(long, default_value = "RightmostXForwardedFor")]
ip_source: SecureClientIpSource,
#[clap(subcommand)]
subcommand: Option<Cmd>,
}
@ -112,6 +114,7 @@ async fn main() -> Result<()> {
records,
salt,
ttl,
ip_source,
} = Opts::parse();
let subscriber = tracing_subscriber::FmtSubscriber::builder()
.without_time()
@ -144,9 +147,14 @@ async fn main() -> Result<()> {
key_file: None,
password_hash: None,
};
if let Some(password_file) = password_file {
let pass = std::fs::read_to_string(password_file).into_diagnostic()?;
let pass: Box<[u8]> = pass.trim().as_bytes().into();
if let Some(path) = password_file {
let pass = std::fs::read_to_string(&path).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes())
.into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into();
state.password_hash = Some(Box::leak(pass));
} else {
ensure!(insecure, "a password must be used");
@ -174,6 +182,7 @@ async fn main() -> Result<()> {
// Start services
let app = Router::new()
.route("/update", get(update_records))
.layer(ip_source.into_extension())
.with_state(state);
info!("starting listener on {ip}:{port}");
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
@ -188,29 +197,35 @@ async fn main() -> Result<()> {
.into_diagnostic()
}
#[tracing::instrument(skip(state), level = "trace", ret(level = "warn"))]
#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))]
async fn update_records(
State(state): State<AppState<'static>>,
AuthBasic((username, pass)): AuthBasic,
ConnectInfo(client): ConnectInfo<SocketAddr>,
SecureClientIp(ip): SecureClientIp,
) -> axum::response::Result<&'static str> {
let Some(pass) = pass else {
return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into());
};
if let Some(stored_pass) = state.password_hash {
let password = pass.trim().to_string();
if hash_identity(&username, &password, state.salt).as_ref() != stored_pass {
warn!("rejected update from {username}@{client}");
let pass_hash = hash_identity(&username, &password, state.salt);
if pass_hash.as_ref() != stored_pass {
warn!("rejected update");
trace!(
"mismatched hashes:\n{}\n{}",
URL_SAFE_NO_PAD.encode(pass_hash.as_ref()),
URL_SAFE_NO_PAD.encode(stored_pass.as_ref()),
);
return Err((StatusCode::UNAUTHORIZED, "invalid identity").into());
}
}
let ip = client.ip();
info!("accepted update");
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) => {
if status.success() {
Ok("successful update")
} else {
error!("nsupdate failed");
Err((
StatusCode::INTERNAL_SERVER_ERROR,
"nsupdate failed, check server logs",
@ -226,6 +241,7 @@ async fn update_records(
}
}
#[tracing::instrument(level = "trace", ret(level = "warn"))]
async fn nsupdate(
ip: IpAddr,
ttl: Duration,
@ -236,13 +252,27 @@ async fn nsupdate(
if let Some(key_file) = key_file {
cmd.args([OsStr::new("-k"), key_file.as_os_str()]);
}
cmd.stdin(Stdio::piped());
let mut child = cmd.spawn()?;
debug!("spawning new process");
let mut child = cmd
.stdin(Stdio::piped())
.spawn()
.inspect_err(|err| warn!("failed to spawn child: {err}"))?;
let mut stdin = child.stdin.take().expect("stdin not present");
debug!("sending update request");
stdin
.write_all(update_ns_records(ip, ttl, records).as_bytes())
.await?;
child.wait().await
.await
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
debug!("closing stdin");
stdin
.shutdown()
.await
.inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?;
debug!("waiting for nsupdate to exit");
child
.wait()
.await
.inspect_err(|err| warn!("failed to wait for child: {err}"))
}
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
@ -258,7 +288,7 @@ fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap();
writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap();
}
writeln!(cmds, "send").unwrap();
writeln!(cmds, "send\nquit").unwrap();
cmds
}
@ -423,6 +453,7 @@ mod test {
update delete example.net. 60 IN A
update add example.net. 60 IN A 127.0.0.1
send
quit
"###);
}
@ -442,6 +473,7 @@ mod test {
update delete example.net. 60 IN AAAA
update add example.net. 60 IN AAAA ::1
send
quit
"###);
}