This makes the auth code slightly more involved, but it makes the rest of the application more straightforward. Fixes #10
104 lines
3.1 KiB
Rust
104 lines
3.1 KiB
Rust
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
|
|
use base64::Engine;
|
|
use tower_http::validate_request::ValidateRequestHeaderLayer;
|
|
use tracing::{trace, warn};
|
|
|
|
use crate::password;
|
|
|
|
pub fn auth_layer<'a, ResBody>(
|
|
user_pass_hash: &'a [u8],
|
|
salt: &'a str,
|
|
) -> ValidateRequestHeaderLayer<BasicAuth<'a, ResBody>> {
|
|
ValidateRequestHeaderLayer::custom(BasicAuth::new(user_pass_hash, salt))
|
|
}
|
|
|
|
#[derive(Copy)]
|
|
pub struct BasicAuth<'a, ResBody> {
|
|
pass: &'a [u8],
|
|
salt: &'a str,
|
|
_ty: std::marker::PhantomData<fn() -> ResBody>,
|
|
}
|
|
|
|
impl<ResBody> std::fmt::Debug for BasicAuth<'_, ResBody> {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("BasicAuth")
|
|
.field("pass", &self.pass)
|
|
.field("salt", &self.salt)
|
|
.field("_ty", &self._ty)
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl<ResBody> Clone for BasicAuth<'_, ResBody> {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
pass: self.pass,
|
|
salt: self.salt,
|
|
_ty: std::marker::PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, ResBody> BasicAuth<'a, ResBody> {
|
|
pub fn new(pass: &'a [u8], salt: &'a str) -> Self {
|
|
Self {
|
|
pass,
|
|
salt,
|
|
_ty: std::marker::PhantomData,
|
|
}
|
|
}
|
|
|
|
fn check_headers(&self, headers: &http::HeaderMap<http::HeaderValue>) -> bool {
|
|
let Some(auth) = headers.get(http::header::AUTHORIZATION) else {
|
|
return false;
|
|
};
|
|
|
|
// Poor man's split once: https://doc.rust-lang.org/std/primitive.slice.html#method.split_once
|
|
let Some(index) = auth.as_bytes().iter().position(|&c| c == b' ') else {
|
|
return false;
|
|
};
|
|
let user_pass = &auth.as_bytes()[index + 1..];
|
|
|
|
match base64::engine::general_purpose::URL_SAFE.decode(user_pass) {
|
|
Ok(user_pass) => {
|
|
let hashed = password::hash_basic_auth(&user_pass, self.salt);
|
|
if hashed.as_ref() == self.pass {
|
|
return true;
|
|
}
|
|
warn!("rejected update");
|
|
trace!(
|
|
"mismatched hashes:\nprovided: {}\nstored: {}",
|
|
URL_SAFE_NO_PAD.encode(hashed.as_ref()),
|
|
URL_SAFE_NO_PAD.encode(self.pass),
|
|
);
|
|
false
|
|
}
|
|
Err(err) => {
|
|
warn!("received invalid base64 when decoding Basic header: {err}");
|
|
false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<B, ResBody> tower_http::validate_request::ValidateRequest<B> for BasicAuth<'_, ResBody>
|
|
where
|
|
ResBody: Default,
|
|
{
|
|
type ResponseBody = ResBody;
|
|
|
|
fn validate(
|
|
&mut self,
|
|
request: &mut http::Request<B>,
|
|
) -> std::result::Result<(), http::Response<Self::ResponseBody>> {
|
|
if self.check_headers(request.headers()) {
|
|
return Ok(());
|
|
}
|
|
|
|
let mut res = http::Response::new(ResBody::default());
|
|
*res.status_mut() = http::status::StatusCode::UNAUTHORIZED;
|
|
res.headers_mut()
|
|
.insert(http::header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
|
|
Err(res)
|
|
}
|
|
}
|