feat: replace axum-auth with tower_http
Slightly more involde in the auth code, but it makes the rest of the application more straight forward. Fixes #10
This commit is contained in:
parent
60aed649b1
commit
c43ca438e6
5 changed files with 166 additions and 55 deletions
16
Cargo.lock
generated
16
Cargo.lock
generated
|
@ -1044,6 +1044,21 @@ dependencies = [
|
||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower-http"
|
||||||
|
version = "0.6.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "403fa3b783d4b626a8ad51d766ab03cb6d2dbfc46b1c5d4448395e6628dc9697"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"bytes",
|
||||||
|
"http",
|
||||||
|
"mime",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower-layer"
|
name = "tower-layer"
|
||||||
version = "0.3.3"
|
version = "0.3.3"
|
||||||
|
@ -1175,6 +1190,7 @@ dependencies = [
|
||||||
"miette",
|
"miette",
|
||||||
"ring",
|
"ring",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
]
|
]
|
||||||
|
|
|
@ -8,9 +8,6 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
axum = "0.7"
|
axum = "0.7"
|
||||||
axum-auth = { version = "0.7", default-features = false, features = [
|
|
||||||
"auth-basic",
|
|
||||||
] }
|
|
||||||
axum-client-ip = "0.6"
|
axum-client-ip = "0.6"
|
||||||
base64 = "0.22"
|
base64 = "0.22"
|
||||||
clap = { version = "4", features = ["derive", "env"] }
|
clap = { version = "4", features = ["derive", "env"] }
|
||||||
|
@ -21,6 +18,7 @@ http = "1"
|
||||||
miette = { version = "7", features = ["fancy"] }
|
miette = { version = "7", features = ["fancy"] }
|
||||||
ring = { version = "0.17", features = ["std"] }
|
ring = { version = "0.17", features = ["std"] }
|
||||||
tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] }
|
tokio = { version = "1", features = ["macros", "rt", "process", "io-util"] }
|
||||||
|
tower-http = { version = "0.6.2", features = ["validate-request"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
|
||||||
|
|
105
src/auth.rs
Normal file
105
src/auth.rs
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
||||||
|
use base64::Engine;
|
||||||
|
use tower_http::validate_request::ValidateRequestHeaderLayer;
|
||||||
|
use tracing::{trace, warn};
|
||||||
|
|
||||||
|
use crate::password;
|
||||||
|
|
||||||
|
pub fn auth_layer<'a, ResBody>(
|
||||||
|
user_pass_hash: &'a [u8],
|
||||||
|
salt: &'a str,
|
||||||
|
) -> ValidateRequestHeaderLayer<BasicAuth<'a, ResBody>> {
|
||||||
|
ValidateRequestHeaderLayer::custom(BasicAuth::new(user_pass_hash, salt))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Copy)]
|
||||||
|
pub struct BasicAuth<'a, ResBody> {
|
||||||
|
pass: &'a [u8],
|
||||||
|
salt: &'a str,
|
||||||
|
_ty: std::marker::PhantomData<fn() -> ResBody>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<ResBody> std::fmt::Debug for BasicAuth<'_, ResBody> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("BasicAuth")
|
||||||
|
.field("pass", &self.pass)
|
||||||
|
.field("salt", &self.salt)
|
||||||
|
.field("_ty", &self._ty)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<ResBody> Clone for BasicAuth<'_, ResBody> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
pass: self.pass,
|
||||||
|
salt: self.salt,
|
||||||
|
_ty: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, ResBody> BasicAuth<'a, ResBody> {
|
||||||
|
pub fn new(pass: &'a [u8], salt: &'a str) -> Self {
|
||||||
|
Self {
|
||||||
|
pass,
|
||||||
|
salt,
|
||||||
|
_ty: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_headers(&self, headers: &http::HeaderMap<http::HeaderValue>) -> bool {
|
||||||
|
let Some(auth) = headers.get(http::header::AUTHORIZATION) else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Poor man's split once: https://doc.rust-lang.org/std/primitive.slice.html#method.split_once
|
||||||
|
let Some(index) = auth.as_bytes().iter().position(|&c| c == b' ') else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
let user_pass = &auth.as_bytes()[index + 1..];
|
||||||
|
|
||||||
|
match base64::engine::general_purpose::URL_SAFE.decode(user_pass) {
|
||||||
|
Ok(user_pass) => {
|
||||||
|
let hashed = password::hash_basic_auth(&user_pass, self.salt);
|
||||||
|
if hashed.as_ref() == self.pass {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
warn!("rejected update");
|
||||||
|
trace!(
|
||||||
|
"mismatched hashes:\nprovided: {}\nstored: {}",
|
||||||
|
URL_SAFE_NO_PAD.encode(hashed.as_ref()),
|
||||||
|
URL_SAFE_NO_PAD.encode(self.pass),
|
||||||
|
);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!("received invalid base64 when decoding Basic header: {err}");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B, ResBody> tower_http::validate_request::ValidateRequest<B> for BasicAuth<'_, ResBody>
|
||||||
|
where
|
||||||
|
ResBody: Default,
|
||||||
|
{
|
||||||
|
type ResponseBody = ResBody;
|
||||||
|
|
||||||
|
fn validate(
|
||||||
|
&mut self,
|
||||||
|
request: &mut http::Request<B>,
|
||||||
|
) -> std::result::Result<(), http::Response<Self::ResponseBody>> {
|
||||||
|
if self.check_headers(request.headers()) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut res = http::Response::new(ResBody::default());
|
||||||
|
*res.status_mut() = http::status::StatusCode::UNAUTHORIZED;
|
||||||
|
res.headers_mut()
|
||||||
|
.insert(http::header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
|
||||||
|
Err(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
80
src/main.rs
80
src/main.rs
|
@ -7,8 +7,7 @@ use std::{
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{extract::State, routing::get, Json, Router};
|
use axum::{extract::State, routing::get, Router};
|
||||||
use axum_auth::AuthBasic;
|
|
||||||
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
|
use axum_client_ip::{SecureClientIp, SecureClientIpSource};
|
||||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
|
@ -16,9 +15,10 @@ use clap_verbosity_flag::Verbosity;
|
||||||
use http::StatusCode;
|
use http::StatusCode;
|
||||||
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
|
use miette::{bail, ensure, Context, IntoDiagnostic, Result};
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
use tracing::{debug, error, info, trace, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
use tracing_subscriber::EnvFilter;
|
use tracing_subscriber::EnvFilter;
|
||||||
|
|
||||||
|
mod auth;
|
||||||
mod password;
|
mod password;
|
||||||
mod records;
|
mod records;
|
||||||
|
|
||||||
|
@ -108,18 +108,12 @@ struct AppState<'a> {
|
||||||
/// TTL set on the Zonefile
|
/// TTL set on the Zonefile
|
||||||
ttl: Duration,
|
ttl: Duration,
|
||||||
|
|
||||||
/// Salt added to the password
|
|
||||||
salt: &'a str,
|
|
||||||
|
|
||||||
/// The IN A/AAAA records that should have their IPs updated
|
/// The IN A/AAAA records that should have their IPs updated
|
||||||
records: &'a [&'a str],
|
records: &'a [&'a str],
|
||||||
|
|
||||||
/// The TSIG key file
|
/// The TSIG key file
|
||||||
key_file: Option<&'a Path>,
|
key_file: Option<&'a Path>,
|
||||||
|
|
||||||
/// The password hash
|
|
||||||
password_hash: Option<&'a [u8]>,
|
|
||||||
|
|
||||||
/// The file where the last IP is stored
|
/// The file where the last IP is stored
|
||||||
ip_file: &'a Path,
|
ip_file: &'a Path,
|
||||||
}
|
}
|
||||||
|
@ -195,9 +189,23 @@ fn main() -> Result<()> {
|
||||||
// Use last registered IP address if available
|
// Use last registered IP address if available
|
||||||
let ip_file = data_dir.join("last-ip");
|
let ip_file = data_dir.join("last-ip");
|
||||||
|
|
||||||
|
// Load password hash
|
||||||
|
let password_hash = password_file
|
||||||
|
.map(|path| -> miette::Result<_> {
|
||||||
|
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
|
||||||
|
|
||||||
|
let pass: Box<[u8]> = URL_SAFE_NO_PAD
|
||||||
|
.decode(pass.trim().as_bytes())
|
||||||
|
.into_diagnostic()
|
||||||
|
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
|
||||||
|
.into();
|
||||||
|
|
||||||
|
Ok(pass)
|
||||||
|
})
|
||||||
|
.transpose()?;
|
||||||
|
|
||||||
let state = AppState {
|
let state = AppState {
|
||||||
ttl,
|
ttl,
|
||||||
salt: salt.leak(),
|
|
||||||
// Load DNS records
|
// Load DNS records
|
||||||
records: records::load_no_verify(&records)?,
|
records: records::load_no_verify(&records)?,
|
||||||
// Load keyfile
|
// Load keyfile
|
||||||
|
@ -212,25 +220,11 @@ fn main() -> Result<()> {
|
||||||
Ok(&*Box::leak(key_file.into_boxed_path()))
|
Ok(&*Box::leak(key_file.into_boxed_path()))
|
||||||
})
|
})
|
||||||
.transpose()?,
|
.transpose()?,
|
||||||
// Load password hash
|
|
||||||
password_hash: password_file
|
|
||||||
.map(|path| -> miette::Result<_> {
|
|
||||||
let pass = std::fs::read_to_string(path.as_path()).into_diagnostic()?;
|
|
||||||
|
|
||||||
let pass: Box<[u8]> = URL_SAFE_NO_PAD
|
|
||||||
.decode(pass.trim().as_bytes())
|
|
||||||
.into_diagnostic()
|
|
||||||
.wrap_err_with(|| format!("failed to decode password from {}", path.display()))?
|
|
||||||
.into();
|
|
||||||
|
|
||||||
Ok(&*Box::leak(pass))
|
|
||||||
})
|
|
||||||
.transpose()?,
|
|
||||||
ip_file: Box::leak(ip_file.into_boxed_path()),
|
ip_file: Box::leak(ip_file.into_boxed_path()),
|
||||||
};
|
};
|
||||||
|
|
||||||
ensure!(
|
ensure!(
|
||||||
state.password_hash.is_some() || insecure,
|
password_hash.is_some() || insecure,
|
||||||
"a password must be used"
|
"a password must be used"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -270,11 +264,18 @@ fn main() -> Result<()> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Create services
|
||||||
|
let app = Router::new().route("/update", get(update_records));
|
||||||
|
// if a password is provided, validate it
|
||||||
|
let app = if let Some(pass) = password_hash {
|
||||||
|
app.layer(auth::auth_layer(Box::leak(pass), String::leak(salt)))
|
||||||
|
} else {
|
||||||
|
app
|
||||||
|
}
|
||||||
|
.layer(ip_source.into_extension())
|
||||||
|
.with_state(state);
|
||||||
|
|
||||||
// Start services
|
// Start services
|
||||||
let app = Router::new()
|
|
||||||
.route("/update", get(update_records))
|
|
||||||
.layer(ip_source.into_extension())
|
|
||||||
.with_state(state);
|
|
||||||
info!("starting listener on {ip}:{port}");
|
info!("starting listener on {ip}:{port}");
|
||||||
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
|
let listener = tokio::net::TcpListener::bind(SocketAddr::new(ip, port))
|
||||||
.await
|
.await
|
||||||
|
@ -289,31 +290,12 @@ fn main() -> Result<()> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(state, pass), level = "trace", ret(level = "info"))]
|
#[tracing::instrument(skip(state), level = "trace", ret(level = "info"))]
|
||||||
async fn update_records(
|
async fn update_records(
|
||||||
State(state): State<AppState<'static>>,
|
State(state): State<AppState<'static>>,
|
||||||
AuthBasic((username, pass)): AuthBasic,
|
|
||||||
SecureClientIp(ip): SecureClientIp,
|
SecureClientIp(ip): SecureClientIp,
|
||||||
) -> axum::response::Result<&'static str> {
|
) -> axum::response::Result<&'static str> {
|
||||||
debug!("received update request from {ip}");
|
debug!("received update request from {ip}");
|
||||||
let Some(pass) = pass else {
|
|
||||||
return Err((StatusCode::UNAUTHORIZED, Json::from("no password provided")).into());
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(stored_pass) = state.password_hash {
|
|
||||||
let password = pass.trim().to_string();
|
|
||||||
let pass_hash = password::hash_identity(&username, &password, state.salt);
|
|
||||||
if pass_hash.as_ref() != stored_pass {
|
|
||||||
warn!("rejected update");
|
|
||||||
trace!(
|
|
||||||
"mismatched hashes:\n{}\n{}",
|
|
||||||
URL_SAFE_NO_PAD.encode(pass_hash.as_ref()),
|
|
||||||
URL_SAFE_NO_PAD.encode(stored_pass),
|
|
||||||
);
|
|
||||||
return Err((StatusCode::UNAUTHORIZED, "invalid identity").into());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
info!("accepted update");
|
info!("accepted update");
|
||||||
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
|
match nsupdate(ip, state.ttl, state.key_file, state.records).await {
|
||||||
Ok(status) if status.success() => {
|
Ok(status) if status.success() => {
|
||||||
|
|
|
@ -28,10 +28,20 @@ impl Mkpasswd {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn hash_basic_auth(user_pass: &[u8], salt: &str) -> Digest {
|
||||||
|
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
|
||||||
|
context.update(user_pass);
|
||||||
|
context.update(salt.as_bytes());
|
||||||
|
context.finish()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest {
|
pub fn hash_identity(username: &str, password: &str, salt: &str) -> Digest {
|
||||||
let mut data = Vec::with_capacity(username.len() + password.len() + salt.len() + 1);
|
let mut context = ring::digest::Context::new(&ring::digest::SHA256);
|
||||||
write!(data, "{username}:{password}{salt}").unwrap();
|
context.update(username.as_bytes());
|
||||||
ring::digest::digest(&ring::digest::SHA256, &data)
|
context.update(b":");
|
||||||
|
context.update(password.as_bytes());
|
||||||
|
context.update(salt.as_bytes());
|
||||||
|
context.finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mkpasswd(
|
pub fn mkpasswd(
|
||||||
|
|
Loading…
Reference in a new issue