[fix] everything: various bugs found in production
This commit is contained in:
parent
68658bf83f
commit
15e2d2da06
7 changed files with 138 additions and 178 deletions
74
src/main.rs
74
src/main.rs
|
@ -8,19 +8,16 @@ use std::{
|
|||
time::Duration,
|
||||
};
|
||||
|
||||
use axum::{
|
||||
extract::{ConnectInfo, State},
|
||||
routing::get,
|
||||
Json, Router,
|
||||
};
|
||||
use axum::{extract::State, routing::get, Json, Router};
|
||||
use axum_auth::AuthBasic;
|
||||
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use http::StatusCode;
|
||||
use miette::{ensure, miette, Context, IntoDiagnostic, LabeledSpan, NamedSource, Result};
|
||||
use ring::digest::Digest;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tracing::{info, level_filters::LevelFilter, warn};
|
||||
use tracing::{debug, error, info, level_filters::LevelFilter, trace, warn};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
const DEFAULT_TTL: Duration = Duration::from_secs(60);
|
||||
|
@ -60,9 +57,14 @@ struct Opts {
|
|||
/// If specified, then `webnsupdate` must have read access to the file
|
||||
#[arg(long)]
|
||||
key_file: Option<PathBuf>,
|
||||
/// Allow not setting a password when the server is exposed to the network
|
||||
/// Allow not setting a password
|
||||
#[arg(long)]
|
||||
insecure: bool,
|
||||
/// Set client IP source
|
||||
///
|
||||
/// see: https://docs.rs/axum-client-ip/latest/axum_client_ip/enum.SecureClientIpSource.html
|
||||
#[clap(long, default_value = "RightmostXForwardedFor")]
|
||||
ip_source: SecureClientIpSource,
|
||||
#[clap(subcommand)]
|
||||
subcommand: Option<Cmd>,
|
||||
}
|
||||
|
@ -112,6 +114,7 @@ async fn main() -> Result<()> {
|
|||
records,
|
||||
salt,
|
||||
ttl,
|
||||
ip_source,
|
||||
} = Opts::parse();
|
||||
let subscriber = tracing_subscriber::FmtSubscriber::builder()
|
||||
.without_time()
|
||||
|
@ -144,9 +147,14 @@ async fn main() -> Result<()> {
|
|||
key_file: None,
|
||||
password_hash: None,
|
||||
};
|
||||
if let Some(password_file) = password_file {
|
||||
let pass = std::fs::read_to_string(password_file).into_diagnostic()?;
|
||||
let pass: Box<[u8]> = pass.trim().as_bytes().into();
|
||||
if let Some(path) = password_file {
|
||||
let pass = std::fs::read_to_string(&path).into_diagnostic()?;
|
||||
|
||||
let pass: Box<[u8]> = URL_SAFE_NO_PAD
|
||||
.decode(pass.trim().as_bytes())
|
||||
.into_diagnostic()
|
||||
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
|
||||
.into();
|
||||
state.password_hash = Some(Box::leak(pass));
|
||||
} else {
|
||||
ensure!(insecure, "a password must be used");
|
||||
|
@ -174,6 +182,7 @@ async fn main() -> Result<()> {
|
|||
// Start services
|
||||
let app = Router::new()
|
||||
.route("/update", get(update_records))
|
||||
.layer(ip_source.into_extension())
|
||||
.with_state(state);
|
||||
info!("starting listener on {ip}:{port}");
|
||||
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
|
||||
|
@ -188,29 +197,35 @@ async fn main() -> Result<()> {
|
|||
.into_diagnostic()
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(state), level = "trace", ret(level = "warn"))]
|
||||
#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))]
|
||||
async fn update_records(
|
||||
State(state): State<AppState<'static>>,
|
||||
AuthBasic((username, pass)): AuthBasic,
|
||||
ConnectInfo(client): ConnectInfo<SocketAddr>,
|
||||
SecureClientIp(ip): SecureClientIp,
|
||||
) -> axum::response::Result<&'static str> {
|
||||
let Some(pass) = pass else {
|
||||
return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into());
|
||||
};
|
||||
if let Some(stored_pass) = state.password_hash {
|
||||
let password = pass.trim().to_string();
|
||||
|
||||
if hash_identity(&username, &password, state.salt).as_ref() != stored_pass {
|
||||
warn!("rejected update from {username}@{client}");
|
||||
let pass_hash = hash_identity(&username, &password, state.salt);
|
||||
if pass_hash.as_ref() != stored_pass {
|
||||
warn!("rejected update");
|
||||
trace!(
|
||||
"mismatched hashes:\n{}\n{}",
|
||||
URL_SAFE_NO_PAD.encode(pass_hash.as_ref()),
|
||||
URL_SAFE_NO_PAD.encode(stored_pass.as_ref()),
|
||||
);
|
||||
return Err((StatusCode::UNAUTHORIZED, "invalid identity").into());
|
||||
}
|
||||
}
|
||||
let ip = client.ip();
|
||||
info!("accepted update");
|
||||
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
|
||||
Ok(status) => {
|
||||
if status.success() {
|
||||
Ok("successful update")
|
||||
} else {
|
||||
error!("nsupdate failed");
|
||||
Err((
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"nsupdate failed, check server logs",
|
||||
|
@ -226,6 +241,7 @@ async fn update_records(
|
|||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(level = "trace", ret(level = "warn"))]
|
||||
async fn nsupdate(
|
||||
ip: IpAddr,
|
||||
ttl: Duration,
|
||||
|
@ -236,13 +252,27 @@ async fn nsupdate(
|
|||
if let Some(key_file) = key_file {
|
||||
cmd.args([OsStr::new("-k"), key_file.as_os_str()]);
|
||||
}
|
||||
cmd.stdin(Stdio::piped());
|
||||
let mut child = cmd.spawn()?;
|
||||
debug!("spawning new process");
|
||||
let mut child = cmd
|
||||
.stdin(Stdio::piped())
|
||||
.spawn()
|
||||
.inspect_err(|err| warn!("failed to spawn child: {err}"))?;
|
||||
let mut stdin = child.stdin.take().expect("stdin not present");
|
||||
debug!("sending update request");
|
||||
stdin
|
||||
.write_all(update_ns_records(ip, ttl, records).as_bytes())
|
||||
.await?;
|
||||
child.wait().await
|
||||
.await
|
||||
.inspect_err(|err| warn!("failed to write to the stdin of nsupdate: {err}"))?;
|
||||
debug!("closing stdin");
|
||||
stdin
|
||||
.shutdown()
|
||||
.await
|
||||
.inspect_err(|err| warn!("failed to close stdin to nsupdate: {err}"))?;
|
||||
debug!("waiting for nsupdate to exit");
|
||||
child
|
||||
.wait()
|
||||
.await
|
||||
.inspect_err(|err| warn!("failed to wait for child: {err}"))
|
||||
}
|
||||
|
||||
fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
|
||||
|
@ -258,7 +288,7 @@ fn update_ns_records(ip: IpAddr, ttl: Duration, records: &[&str]) -> String {
|
|||
writeln!(cmds, "update delete {record} {ttl_s} IN {rec_type}").unwrap();
|
||||
writeln!(cmds, "update add {record} {ttl_s} IN {rec_type} {ip}").unwrap();
|
||||
}
|
||||
writeln!(cmds, "send").unwrap();
|
||||
writeln!(cmds, "send\nquit").unwrap();
|
||||
cmds
|
||||
}
|
||||
|
||||
|
@ -423,6 +453,7 @@ mod test {
|
|||
update delete example.net. 60 IN A
|
||||
update add example.net. 60 IN A 127.0.0.1
|
||||
send
|
||||
quit
|
||||
"###);
|
||||
}
|
||||
|
||||
|
@ -442,6 +473,7 @@ mod test {
|
|||
update delete example.net. 60 IN AAAA
|
||||
update add example.net. 60 IN AAAA ::1
|
||||
send
|
||||
quit
|
||||
"###);
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue