feat: replace axum-auth with tower_http
Some checks failed
/ build (push) Failing after 4s
/ check (push) Has been skipped

Slightly more involde in the auth code, but it makes the rest of the
application more straight forward.

Fixes #10
This commit is contained in:
Jalil David Salamé Messina 2024-11-23 20:36:38 +01:00
parent 60aed649b1
commit c43ca438e6
Signed by: jalil
GPG key ID: F016B9E770737A0B
5 changed files with 166 additions and 55 deletions

16
Cargo.lock generated
View file

@ -1044,6 +1044,21 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-http"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697"
dependencies = [
"bitflags",
"bytes",
"http",
"mime",
"pin-project-lite",
"tower-layer",
"tower-service",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
@ -1175,6 +1190,7 @@ dependencies = [
"miette",
"ring",
"tokio",
"tower-http",
"tracing",
"tracing-subscriber",
]

View file

@ -8,9 +8,6 @@ edition = "2021"
[dependencies]
axum = "0.7"
axum-auth = { version = "0.7", default-features = false, features = [
"auth-basic",
] }
axum-client-ip = "0.6"
base64 = "0.22"
clap = { version = "4", features = ["derive", "env"] }
@ -21,6 +18,7 @@ http = "1"
miette = { version = "7", features = ["fancy"] }
ring = { version = "0.17", features = ["std"] }
tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] }
tower-http = { version = "0.6.2", features = ["validate-request"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

105
src/auth.rs Normal file
View file

@ -0,0 +1,105 @@
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use tower_http::validate_request::ValidateRequestHeaderLayer;
use tracing::{trace, warn};
use crate::password;
pub fn auth_layer<'a, ResBody>(
user_pass_hash: &'a [u8],
salt: &'a str,
) -> ValidateRequestHeaderLayer<BasicAuth<'a, ResBody>> {
ValidateRequestHeaderLayer::custom(BasicAuth::new(user_pass_hash, salt))
}
#[derive(Copy)]
pub struct BasicAuth<'a, ResBody> {
pass: &'a [u8],
salt: &'a str,
_ty: std::marker::PhantomData<fn() -> ResBody>,
}
impl<ResBody> std::fmt::Debug for BasicAuth<'_, ResBody> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("BasicAuth")
.field("pass", &self.pass)
.field("salt", &self.salt)
.field("_ty", &self._ty)
.finish()
}
}
impl<ResBody> Clone for BasicAuth<'_, ResBody> {
fn clone(&self) -> Self {
Self {
pass: self.pass,
salt: self.salt,
_ty: std::marker::PhantomData,
}
}
}
impl<'a, ResBody> BasicAuth<'a, ResBody> {
pub fn new(pass: &'a [u8], salt: &'a str) -> Self {
Self {
pass,
salt,
_ty: std::marker::PhantomData,
}
}
fn check_headers(&self, headers: &http::HeaderMap<http::HeaderValue>) -> bool {
let Some(auth) = headers.get(http::header::AUTHORIZATION) else {
return false;
};
// Poor man's split once: https://doc.rust-lang.org/std/primitive.slice.html#method.split_once
let Some(index) = auth.as_bytes().iter().position(|&c| c == b' ') else {
return false;
};
let user_pass = &auth.as_bytes()[index + 1..];
match base64::engine::general_purpose::URL_SAFE.decode(user_pass) {
Ok(user_pass) => {
let hashed = password::hash_basic_auth(&user_pass, self.salt);
if hashed.as_ref() == self.pass {
return true;
}
warn!("rejected update");
trace!(
"mismatched hashes:\nprovided: {}\nstored: {}",
URL_SAFE_NO_PAD.encode(hashed.as_ref()),
URL_SAFE_NO_PAD.encode(self.pass),
);
false
}
Err(err) => {
warn!("received invalid base64 when decoding Basic header: {err}");
false
}
}
}
}
impl<B, ResBody> tower_http::validate_request::ValidateRequest<B> for BasicAuth<'_, ResBody>
where
ResBody: Default,
{
type ResponseBody = ResBody;
fn validate(
&mut self,
request: &mut http::Request<B>,
) -> std::result::Result<(), http::Response<Self::ResponseBody>> {
if self.check_headers(request.headers()) {
return Ok(());
}
let mut res = http::Response::new(ResBody::default());
*res.status_mut() = http::status::StatusCode::UNAUTHORIZED;
res.headers_mut()
.insert(http::header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
Err(res)
}
}

View file

@ -7,8 +7,7 @@ use std::{
time::Duration,
};
use axum::{extract::State, routing::get, Json, Router};
use axum_auth::AuthBasic;
use axum::{extract::State, routing::get, Router};
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
use clap::{Parser, Subcommand};
@ -16,9 +15,10 @@ use clap_verbosity_flag::Verbosity;
use http::StatusCode;
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
use tokio::io::AsyncWriteExt;
use tracing::{debug, error, info, trace, warn};
use tracing::{debug, error, info, warn};
use tracing_subscriber::EnvFilter;
mod auth;
mod password;
mod records;
@ -108,18 +108,12 @@ struct AppState<'a> {
/// TTL set on the Zonefile
ttl: Duration,
/// Salt added to the password
salt: &'a str,
/// The IN A/AAAA records that should have their IPs updated
records: &'a [&'a str],
/// The TSIG key file
key_file: Option<&'a Path>,
/// The password hash
password_hash: Option<&'a [u8]>,
/// The file where the last IP is stored
ip_file: &'a Path,
}
@ -195,9 +189,23 @@ fn main() -> Result<()> {
// Use last registered IP address if available
let ip_file = data_dir.join("last-ip");
// Load password hash
let password_hash = password_file
.map(|path| -> miette::Result<_> {
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes())
.into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into();
Ok(pass)
})
.transpose()?;
let state = AppState {
ttl,
salt: salt.leak(),
// Load DNS records
records: records::load_no_verify(&records)?,
// Load keyfile
@ -212,25 +220,11 @@ fn main() -> Result<()> {
Ok(&*Box::leak(key_file.into_boxed_path()))
})
.transpose()?,
// Load password hash
password_hash: password_file
.map(|path| -> miette::Result<_> {
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
let pass: Box<[u8]> = URL_SAFE_NO_PAD
.decode(pass.trim().as_bytes())
.into_diagnostic()
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
.into();
Ok(&*Box::leak(pass))
})
.transpose()?,
ip_file: Box::leak(ip_file.into_boxed_path()),
};
ensure!(
state.password_hash.is_some() || insecure,
password_hash.is_some() || insecure,
"a password must be used"
);
@ -270,11 +264,18 @@ fn main() -> Result<()> {
}
};
// Start services
let app = Router::new()
.route("/update", get(update_records))
// Create services
let app = Router::new().route("/update", get(update_records));
// if a password is provided, validate it
let app = if let Some(pass) = password_hash {
app.layer(auth::auth_layer(Box::leak(pass), String::leak(salt)))
} else {
app
}
.layer(ip_source.into_extension())
.with_state(state);
// Start services
info!("starting listener on {ip}:{port}");
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
.await
@ -289,31 +290,12 @@ fn main() -> Result<()> {
})
}
#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))]
#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
async fn update_records(
State(state): State<AppState<'static>>,
AuthBasic((username, pass)): AuthBasic,
SecureClientIp(ip): SecureClientIp,
) -> axum::response::Result<&'static str> {
debug!("received update request from {ip}");
let Some(pass) = pass else {
return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into());
};
if let Some(stored_pass) = state.password_hash {
let password = pass.trim().to_string();
let pass_hash = password::hash_identity(&username, &password, state.salt);
if pass_hash.as_ref() != stored_pass {
warn!("rejected update");
trace!(
"mismatched hashes:\n{}\n{}",
URL_SAFE_NO_PAD.encode(pass_hash.as_ref()),
URL_SAFE_NO_PAD.encode(stored_pass),
);
return Err((StatusCode::UNAUTHORIZED, "invalid identity").into());
}
}
info!("accepted update");
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
Ok(status) if status.success() => {

View file

@ -28,10 +28,20 @@ impl Mkpasswd {
}
}
pub fn hash_basic_auth(user_pass: &[u8], salt: &str) -> Digest {
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
context.update(user_pass);
context.update(salt.as_bytes());
context.finish()
}
pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest {
let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1);
write!(data, "{username}:{password}{salt}").unwrap();
ring::digest::digest(&ring::digest::SHA256, &data)
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
context.update(username.as_bytes());
context.update(b":");
context.update(password.as_bytes());
context.update(salt.as_bytes());
context.finish()
}
pub fn mkpasswd(