use base64::engine::general_purpose::URL_SAFE_NO_PAD; use base64::Engine; use tower_http::validate_request::ValidateRequestHeaderLayer; use tracing::{trace, warn}; use crate::password; pub fn auth_layer<'a, ResBody>( user_pass_hash: &'a [u8], salt: &'a str, ) -> ValidateRequestHeaderLayer> { ValidateRequestHeaderLayer::custom(BasicAuth::new(user_pass_hash, salt)) } #[derive(Copy)] pub struct BasicAuth<'a, ResBody> { pass: &'a [u8], salt: &'a str, _ty: std::marker::PhantomData ResBody>, } impl std::fmt::Debug for BasicAuth<'_, ResBody> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("BasicAuth") .field("pass", &self.pass) .field("salt", &self.salt) .field("_ty", &self._ty) .finish() } } impl Clone for BasicAuth<'_, ResBody> { fn clone(&self) -> Self { Self { pass: self.pass, salt: self.salt, _ty: std::marker::PhantomData, } } } impl<'a, ResBody> BasicAuth<'a, ResBody> { pub fn new(pass: &'a [u8], salt: &'a str) -> Self { Self { pass, salt, _ty: std::marker::PhantomData, } } fn check_headers(&self, headers: &http::HeaderMap) -> bool { let Some(auth) = headers.get(http::header::AUTHORIZATION) else { return false; }; // Poor man's split once: https://doc.rust-lang.org/std/primitive.slice.html#method.split_once let Some(index) = auth.as_bytes().iter().position(|&c| c == b' ') else { return false; }; let user_pass = &auth.as_bytes()[index + 1..]; match base64::engine::general_purpose::URL_SAFE.decode(user_pass) { Ok(user_pass) => { let hashed = password::hash_basic_auth(&user_pass, self.salt); if hashed.as_ref() == self.pass { return true; } warn!("rejected update"); trace!( "mismatched hashes:\nprovided: {}\nstored: {}", URL_SAFE_NO_PAD.encode(hashed.as_ref()), URL_SAFE_NO_PAD.encode(self.pass), ); false } Err(err) => { warn!("received invalid base64 when decoding Basic header: {err}"); false } } } } impl tower_http::validate_request::ValidateRequest for BasicAuth<'_, ResBody> where ResBody: Default, { type ResponseBody = ResBody; fn validate( &mut self, request: &mut http::Request, ) -> std::result::Result<(), http::Response> { if self.check_headers(request.headers()) { return Ok(()); } let mut res = http::Response::new(ResBody::default()); *res.status_mut() = http::status::StatusCode::UNAUTHORIZED; res.headers_mut() .insert(http::header::WWW_AUTHENTICATE, "Basic".parse().unwrap()); Err(res) } }