webnsupdate/src/auth.rs
Jalil David Salamé Messina 750cbbff93
All checks were successful
/ build (push) Successful in 3s
/ check (push) Successful in 7s
feat: replace axum-auth with tower_http
Slightly more involved in the auth code, but it makes the rest of the
application more straightforward.

Fixes #10
2024-11-23 20:39:06 +01:00

104 lines
3.1 KiB
Rust

use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use tower_http::validate_request::ValidateRequestHeaderLayer;
use tracing::{trace, warn};
use crate::password;
/// Builds the request-validation layer that guards the service with [`BasicAuth`].
///
/// `user_pass_hash` and `salt` are borrowed for `'a` and handed unchanged to
/// [`BasicAuth::new`]; the resulting validator is wrapped with
/// [`ValidateRequestHeaderLayer::custom`].
pub fn auth_layer<'a, ResBody>(
    user_pass_hash: &'a [u8],
    salt: &'a str,
) -> ValidateRequestHeaderLayer<BasicAuth<'a, ResBody>> {
    let validator = BasicAuth::new(user_pass_hash, salt);
    ValidateRequestHeaderLayer::custom(validator)
}
/// Request validator that stores the expected credential hash and its salt.
///
/// `ResBody` is only the body type of the rejection response; no value of that
/// type is stored, so the validator is `Copy`/`Clone` for every `ResBody`.
pub struct BasicAuth<'a, ResBody> {
    /// Expected hash; the hash of the credentials sent in a request is compared against it.
    pass: &'a [u8],
    /// Salt passed to the hashing routine together with the request credentials.
    salt: &'a str,
    /// Ties the validator to `ResBody` without owning one.
    _ty: std::marker::PhantomData<fn() -> ResBody>,
}

// Implemented by hand: `#[derive(Copy)]` would add a `ResBody: Copy` bound even
// though every field (two shared references and a `PhantomData`) is always `Copy`.
impl<ResBody> Copy for BasicAuth<'_, ResBody> {}

impl<ResBody> std::fmt::Debug for BasicAuth<'_, ResBody> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicAuth")
            .field("pass", &self.pass)
            .field("salt", &self.salt)
            .field("_ty", &self._ty)
            .finish()
    }
}

impl<ResBody> Clone for BasicAuth<'_, ResBody> {
    fn clone(&self) -> Self {
        // The type is unconditionally `Copy`, so cloning is a plain bitwise copy.
        *self
    }
}
impl<'a, ResBody> BasicAuth<'a, ResBody> {
    /// Creates a validator that accepts requests whose credentials hash to `pass`
    /// when hashed together with `salt`.
    pub fn new(pass: &'a [u8], salt: &'a str) -> Self {
        Self {
            pass,
            salt,
            _ty: std::marker::PhantomData,
        }
    }

    /// Compares two byte strings without stopping at the first differing byte,
    /// so the time taken does not reveal how long the matching prefix is.
    fn constant_time_eq(lhs: &[u8], rhs: &[u8]) -> bool {
        lhs.len() == rhs.len()
            && lhs
                .iter()
                .zip(rhs)
                .fold(0u8, |diff, (a, b)| diff | (a ^ b))
                == 0
    }

    /// Returns `true` when `headers` carries an `Authorization: Basic <base64>`
    /// value whose decoded credentials hash to the stored hash.
    ///
    /// Returns `false` when the header is missing, has no credentials part, uses
    /// a scheme other than `Basic`, is not valid base64, or hashes to a
    /// different value.
    fn check_headers(&self, headers: &http::HeaderMap<http::HeaderValue>) -> bool {
        let Some(auth) = headers.get(http::header::AUTHORIZATION) else {
            return false;
        };

        // Split into `<scheme> <credentials>` at the first space.
        let mut parts = auth.as_bytes().splitn(2, |&c| c == b' ');
        let (Some(scheme), Some(user_pass)) = (parts.next(), parts.next()) else {
            return false;
        };

        // Authentication scheme names are case-insensitive; anything other than
        // `Basic` must not be interpreted as Basic credentials.
        if !scheme.eq_ignore_ascii_case(b"Basic") {
            return false;
        }

        // Basic credentials use the standard base64 alphabet (`+`, `/`). The
        // URL-safe alphabet is still accepted as a fallback so previously
        // accepted requests keep working.
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(user_pass)
            .or_else(|_| base64::engine::general_purpose::URL_SAFE.decode(user_pass));

        let user_pass = match decoded {
            Ok(user_pass) => user_pass,
            Err(err) => {
                warn!("received invalid base64 when decoding Basic header: {err}");
                return false;
            }
        };

        let hashed = password::hash_basic_auth(&user_pass, self.salt);
        let provided: &[u8] = hashed.as_ref();
        if Self::constant_time_eq(provided, self.pass) {
            return true;
        }

        warn!("rejected update");
        trace!(
            "mismatched hashes:\nprovided: {}\nstored: {}",
            URL_SAFE_NO_PAD.encode(provided),
            URL_SAFE_NO_PAD.encode(self.pass),
        );
        false
    }
}
/// Plugs [`BasicAuth`] into `tower_http`'s request validation.
impl<B, ResBody> tower_http::validate_request::ValidateRequest<B> for BasicAuth<'_, ResBody>
where
    ResBody: Default,
{
    type ResponseBody = ResBody;

    /// Lets the request through when its headers pass `check_headers`.
    ///
    /// Otherwise returns a `401 Unauthorized` response with a default body and a
    /// `WWW-Authenticate: Basic` header.
    fn validate(
        &mut self,
        request: &mut http::Request<B>,
    ) -> std::result::Result<(), http::Response<Self::ResponseBody>> {
        if self.check_headers(request.headers()) {
            return Ok(());
        }

        let mut res = http::Response::new(ResBody::default());
        *res.status_mut() = http::status::StatusCode::UNAUTHORIZED;
        // `from_static` builds the constant header value directly, avoiding a
        // runtime parse and the `unwrap` that came with it.
        res.headers_mut().insert(
            http::header::WWW_AUTHENTICATE,
            http::HeaderValue::from_static("Basic"),
        );
        Err(res)
    }
}