feat: replace axum-auth with tower_http
All checks were successful
/ build (push) Successful in 3s
/ check (push) Successful in 7s

Slightly more involde in the auth code, but it makes the rest of the
application more straight forward.

Fixes #10
This commit is contained in:
Jalil David Salamé Messina 2024-11-23 20:36:38 +01:00
parent 60aed649b1
commit 750cbbff93
Signed by: jalil
GPG key ID: F016B9E770737A0B
5 changed files with 166 additions and 75 deletions

37
Cargo.lock generated
View file

@ -120,18 +120,6 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "axum-auth"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8169113a185f54f68614fcfc3581df585d30bf8542bcb99496990e1025e4120a"
dependencies = [
"async-trait",
"axum-core",
"base64 0.21.7",
"http",
]
[[package]] [[package]]
name = "axum-client-ip" name = "axum-client-ip"
version = "0.6.1" version = "0.6.1"
@ -188,12 +176,6 @@ dependencies = [
"backtrace", "backtrace",
] ]
[[package]]
name = "base64"
version = "0.21.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567"
[[package]] [[package]]
name = "base64" name = "base64"
version = "0.22.1" version = "0.22.1"
@ -1044,6 +1026,21 @@ dependencies = [
"tracing", "tracing",
] ]
[[package]]
name = "tower-http"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697"
dependencies = [
"bitflags",
"bytes",
"http",
"mime",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]] [[package]]
name = "tower-layer" name = "tower-layer"
version = "0.3.3" version = "0.3.3"
@ -1165,9 +1162,8 @@ name = "webnsupdate"
version = "0.3.2-dev" version = "0.3.2-dev"
dependencies = [ dependencies = [
"axum", "axum",
"axum-auth",
"axum-client-ip", "axum-client-ip",
"base64 0.22.1", "base64",
"clap", "clap",
"clap-verbosity-flag", "clap-verbosity-flag",
"http", "http",
@ -1175,6 +1171,7 @@ dependencies = [
"miette", "miette",
"ring", "ring",
"tokio", "tokio",
"tower-http",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
] ]

View file

@ -8,9 +8,6 @@ edition = "2021"
[dependencies] [dependencies]
axum = "0.7" axum = "0.7"
axum-auth = { version = "0.7", default-features = false, features = [
"auth-basic",
] }
axum-client-ip = "0.6" axum-client-ip = "0.6"
base64 = "0.22" base64 = "0.22"
clap = { version = "4", features = ["derive", "env"] } clap = { version = "4", features = ["derive", "env"] }
@ -21,6 +18,7 @@ http = "1"
miette = { version = "7", features = ["fancy"] } miette = { version = "7", features = ["fancy"] }
ring = { version = "0.17", features = ["std"] } ring = { version = "0.17", features = ["std"] }
tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] } tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] }
tower-http = { version = "0.6.2", features = ["validate-request"] }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }

104
src/auth.rs Normal file
View file

@ -0,0 +1,104 @@
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use tower_http::validate_request::ValidateRequestHeaderLayer;
use tracing::{trace, warn};
use crate::password;
pub fn auth_layer<'a, ResBody>(
user_pass_hash: &'a [u8],
salt: &'a str,
) -> ValidateRequestHeaderLayer<BasicAuth<'a, ResBody>> {
ValidateRequestHeaderLayer::custom(BasicAuth::new(user_pass_hash, salt))
}
#[derive(Copy)]
pub struct BasicAuth<'a, ResBody> {
pass: &'a [u8],
salt: &'a str,
_ty: std::marker::PhantomData<fn() -> ResBody>,
}
impl<ResBody> std::fmt::Debug for BasicAuth<'_, ResBody> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BasicAuth")
.field("pass", &self.pass)
.field("salt", &self.salt)
.field("_ty", &self._ty)
.finish()
}
}
impl<ResBody> Clone for BasicAuth<'_, ResBody> {
fn clone(&self) -> Self {
Self {
pass: self.pass,
salt: self.salt,
_ty: std::marker::PhantomData,
}
}
}
impl<'a, ResBody> BasicAuth<'a, ResBody> {
pub fn new(pass: &'a [u8], salt: &'a str) -> Self {
Self {
pass,
salt,
_ty: std::marker::PhantomData,
}
}
fn check_headers(&self, headers: &http::HeaderMap<http::HeaderValue>) -> bool {
let Some(auth) = headers.get(http::header::AUTHORIZATION) else {
return false;
};
// Poor man's split once: https://doc.rust-lang.org/std/primitive.slice.html#method.split_once
let Some(index) = auth.as_bytes().iter().position(|&c| c == b' ') else {
return false;
};
let user_pass = &auth.as_bytes()[index + 1..];
match base64::engine::general_purpose::URL_SAFE.decode(user_pass) {
Ok(user_pass) => {
let hashed = password::hash_basic_auth(&user_pass, self.salt);
if hashed.as_ref() == self.pass {
return true;
}
warn!("rejected update");
trace!(
"mismatched hashes:\nprovided: {}\nstored: {}",
URL_SAFE_NO_PAD.encode(hashed.as_ref()),
URL_SAFE_NO_PAD.encode(self.pass),
);
false
}
Err(err) => {
warn!("received invalid base64 when decoding Basic header: {err}");
false
}
}
}
}
impl<B, ResBody> tower_http::validate_request::ValidateRequest<B> for BasicAuth<'_, ResBody>
where
ResBody: Default,
{
type ResponseBody = ResBody;
fn validate(
&mut self,
request: &mut http::Request<B>,
) -> std::result::Result<(), http::Response<Self::ResponseBody>> {
if self.check_headers(request.headers()) {
return Ok(());
}
let mut res = http::Response::new(ResBody::default());
*res.status_mut() = http::status::StatusCode::UNAUTHORIZED;
res.headers_mut()
.insert(http::header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
Err(res)
}
}

View file

@ -7,8 +7,7 @@ use std::{
time::Duration, time::Duration,
}; };
use axum::{extract::State, routing::get, Json, Router}; use axum::{extract::State, routing::get, Router};
use axum_auth::AuthBasic;
use axum_client_ip::{SecureClientIp, SecureClientIpSource}; use axum_client_ip::{SecureClientIp, SecureClientIpSource};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
@ -16,9 +15,10 @@ use clap_verbosity_flag::Verbosity;
use http::StatusCode; use http::StatusCode;
use miette::{bail, ensure, Context, IntoDiagnostic, Result}; use miette::{bail, ensure, Context, IntoDiagnostic, Result};
use tokio::io::AsyncWriteExt; use tokio::io::AsyncWriteExt;
use tracing::{debug, error, info, trace, warn}; use tracing::{debug, error, info, warn};
use tracing_subscriber::EnvFilter; use tracing_subscriber::EnvFilter;
mod auth;
mod password; mod password;
mod records; mod records;
@ -108,18 +108,12 @@ struct AppState<'a> {
/// TTL set on the Zonefile /// TTL set on the Zonefile
ttl: Duration, ttl: Duration,
/// Salt added to the password
salt: &'a str,
/// The IN A/AAAA records that should have their IPs updated /// The IN A/AAAA records that should have their IPs updated
records: &'a [&'a str], records: &'a [&'a str],
/// The TSIG key file /// The TSIG key file
key_file: Option<&'a Path>, key_file: Option<&'a Path>,
/// The password hash
password_hash: Option<&'a [u8]>,
/// The file where the last IP is stored /// The file where the last IP is stored
ip_file: &'a Path, ip_file: &'a Path,
} }
@ -195,9 +189,23 @@ fn main() -> Result<()> {
// Use last registered IP address if available // Use last registered IP address if available
let ip_file = data_dir.join("last-ip"); let ip_file = data_dir.join("last-ip");
// Load password hash
let password_hash = password_file
.map(|path| -> miette::Result<_> {
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes())
.into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into();
Ok(pass)
})
.transpose()?;
let state = AppState { let state = AppState {
ttl, ttl,
salt: salt.leak(),
// Load DNS records // Load DNS records
records: records::load_no_verify(&records)?, records: records::load_no_verify(&records)?,
// Load keyfile // Load keyfile
@ -212,25 +220,11 @@ fn main() -> Result<()> {
Ok(&*Box::leak(key_file.into_boxed_path())) Ok(&*Box::leak(key_file.into_boxed_path()))
}) })
.transpose()?, .transpose()?,
// Load password hash
password_hash: password_file
.map(|path| -> miette::Result<_> {
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes())
.into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into();
Ok(&*Box::leak(pass))
})
.transpose()?,
ip_file: Box::leak(ip_file.into_boxed_path()), ip_file: Box::leak(ip_file.into_boxed_path()),
}; };
ensure!( ensure!(
state.password_hash.is_some() || insecure, password_hash.is_some() || insecure,
"a password must be used" "a password must be used"
); );
@ -270,11 +264,18 @@ fn main() -> Result<()> {
} }
}; };
// Start services // Create services
let app = Router::new() let app = Router::new().route("/update", get(update_records));
.route("/update", get(update_records)) // if a password is provided, validate it
let app = if let Some(pass) = password_hash {
app.layer(auth::auth_layer(Box::leak(pass), String::leak(salt)))
} else {
app
}
.layer(ip_source.into_extension()) .layer(ip_source.into_extension())
.with_state(state); .with_state(state);
// Start services
info!("starting listener on {ip}:{port}"); info!("starting listener on {ip}:{port}");
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port)) let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
.await .await
@ -289,31 +290,12 @@ fn main() -> Result<()> {
}) })
} }
#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))] #[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
async fn update_records( async fn update_records(
State(state): State<AppState<'static>>, State(state): State<AppState<'static>>,
AuthBasic((username, pass)): AuthBasic,
SecureClientIp(ip): SecureClientIp, SecureClientIp(ip): SecureClientIp,
) -> axum::response::Result<&'static str> { ) -> axum::response::Result<&'static str> {
debug!("received update request from {ip}"); debug!("received update request from {ip}");
let Some(pass) = pass else {
return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into());
};
if let Some(stored_pass) = state.password_hash {
let password = pass.trim().to_string();
let pass_hash = password::hash_identity(&username, &password, state.salt);
if pass_hash.as_ref() != stored_pass {
warn!("rejected update");
trace!(
"mismatched hashes:\n{}\n{}",
URL_SAFE_NO_PAD.encode(pass_hash.as_ref()),
URL_SAFE_NO_PAD.encode(stored_pass),
);
return Err((StatusCode::UNAUTHORIZED, "invalid identity").into());
}
}
info!("accepted update"); info!("accepted update");
match nsupdate(ip, state.ttl, state.key_file, state.records).await { match nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) if status.success() => { Ok(status) if status.success() => {

View file

@ -28,10 +28,20 @@ impl Mkpasswd {
} }
} }
pub fn hash_basic_auth(user_pass: &[u8], salt: &str) -> Digest {
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
context.update(user_pass);
context.update(salt.as_bytes());
context.finish()
}
pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest { pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest {
let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1); let mut context = ring::digest::Context::new(&ring::digest::SHA256);
write!(data, "{username}:{password}{salt}").unwrap(); context.update(username.as_bytes());
ring::digest::digest(&ring::digest::SHA256, &data) context.update(b":");
context.update(password.as_bytes());
context.update(salt.as_bytes());
context.finish()
} }
pub fn mkpasswd( pub fn mkpasswd(