Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 17 additions & 7 deletions bin/alisa/src/web_service/auth/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,22 @@ use axum::{
Form,
};
use chrono::Duration;
use log::info;
use log::{error, info};
use serde::Deserialize;
use url::Url;

use super::token::{create_token_with_expiration_in, TokenType};
use super::token::{create_token_with_expiration_in, TokenError, TokenType};

pub async fn authorize(Form(credentials): Form<Credentials<'_>>) -> impl IntoResponse {
if verify_credentials(&credentials) {
let redirect_url = get_redirect_url_from_params(credentials).unwrap();
let redirect_url = match get_redirect_url_from_params(credentials) {
Ok(Some(redirect_url)) => redirect_url,
Ok(None) => return (StatusCode::BAD_REQUEST, HeaderMap::new()),
Err(err) => {
error!("failed to create authorization code: {}", err);
return (StatusCode::INTERNAL_SERVER_ERROR, HeaderMap::new());
}
};

info!("received credentials, generating an authorization code");

Expand All @@ -27,15 +34,18 @@ pub async fn authorize(Form(credentials): Form<Credentials<'_>>) -> impl IntoRes
}
}

fn get_redirect_url_from_params(auth: Credentials) -> Option<Url> {
let mut url = Url::parse(auth.redirect_uri.as_ref()).ok()?;
fn get_redirect_url_from_params(auth: Credentials) -> Result<Option<Url>, TokenError> {
let mut url = match Url::parse(auth.redirect_uri.as_ref()) {
Ok(url) => url,
Err(_) => return Ok(None),
};
Comment on lines +38 to +41

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

The redirect_uri parameter is parsed and used as the base for a redirect without validation against an allow-list of trusted domains. This constitutes an Open Redirect vulnerability. An attacker can craft a malicious link that, once the user authenticates, redirects them to an attacker-controlled site, potentially leaking the authorization code and state parameters.

To remediate this, you should validate the redirect_uri against a list of allowed URIs, similar to the check performed in issue_token.rs.

Suggested change
let mut url = match Url::parse(auth.redirect_uri.as_ref()) {
Ok(url) => url,
Err(_) => return Ok(None),
};
let mut url = match Url::parse(auth.redirect_uri.as_ref()) {
Ok(url) if url.as_str() == "https://social.yandex.net/broker/redirect" => url,
_ => return Ok(None),
};


let code = create_token_with_expiration_in(Duration::seconds(30), TokenType::Code);
let code = create_token_with_expiration_in(Duration::seconds(30), TokenType::Code)?;
url.query_pairs_mut()
.append_pair("state", &auth.state)
.append_pair("code", &code);

Some(url)
Ok(Some(url))
}

#[derive(Debug, Deserialize)]
Expand Down
36 changes: 27 additions & 9 deletions bin/alisa/src/web_service/auth/issue_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ use std::borrow::Cow;

use axum::{http::StatusCode, response::IntoResponse, Form, Json};
use chrono::Duration;
use log::debug;
use log::{debug, error};
use serde::{Deserialize, Serialize};

use super::token::{create_token_with_expiration_in, is_valid_token, TokenType};
use super::token::{create_token_with_expiration_in, is_valid_token, TokenError, TokenType};

pub async fn issue_token(Form(client_creds): Form<Creds<'_>>) -> impl IntoResponse {
if !validate_client_creds(&client_creds) {
Expand All @@ -21,7 +21,16 @@ pub async fn issue_token(Form(client_creds): Form<Creds<'_>>) -> impl IntoRespon
// TODO: save token version

debug!("received a valid authorization code, generating access and refresh tokens");
(StatusCode::OK, Json(Response::success()))
match Response::success() {
Ok(response) => (StatusCode::OK, Json(response)),
Err(err) => {
error!("failed to issue tokens from auth code: {}", err);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(Response::failure("internal server error".to_string())),
)
}
}
} else {
debug!("received an invalid authorization code");

Expand All @@ -37,7 +46,16 @@ pub async fn issue_token(Form(client_creds): Form<Creds<'_>>) -> impl IntoRespon

debug!("received a valid refresh token, generating new access and refresh tokens");

(StatusCode::OK, Json(Response::success()))
match Response::success() {
Ok(response) => (StatusCode::OK, Json(response)),
Err(err) => {
error!("failed to issue tokens from refresh token: {}", err);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(Response::failure("internal server error".to_string())),
)
}
}
} else {
debug!("received an invalid refresh token");

Expand Down Expand Up @@ -118,19 +136,19 @@ pub enum Response {
}

impl Response {
fn success() -> Response {
Response::Success {
fn success() -> Result<Response, TokenError> {
Ok(Response::Success {
access_token: create_token_with_expiration_in(
ACCESS_TOKEN_EXPIRATION,
TokenType::Access,
),
)?,
refresh_token: create_token_with_expiration_in(
REFRESH_TOKEN_EXPIRATION,
TokenType::Refresh,
),
)?,
token_type: "Bearer".to_string(),
expires_in: ACCESS_TOKEN_EXPIRATION,
}
})
}

fn failure(error: String) -> Response {
Expand Down
70 changes: 55 additions & 15 deletions bin/alisa/src/web_service/auth/token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@ impl fmt::Display for TokenType {
}
}

#[derive(Debug)]
pub enum TokenError {
InvalidExpiration,
Encoding(jsonwebtoken::errors::Error),
}

impl fmt::Display for TokenError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::InvalidExpiration => write!(f, "invalid token expiration"),
Self::Encoding(err) => write!(f, "token encoding failed: {err}"),
}
}
}

impl std::error::Error for TokenError {}
Comment on lines +20 to +35

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For more idiomatic error handling, you can implement the From trait for jsonwebtoken::errors::Error. This will allow you to use the ? operator to automatically convert the error type at call sites.

After applying this suggestion, you can simplify the code at lines 129-134 to:

    encode(
        &header,
        &claims,
        &EncodingKey::from_secret(secret.as_bytes()),
    )?
#[derive(Debug)]
pub enum TokenError {
    InvalidExpiration,
    Encoding(jsonwebtoken::errors::Error),
}

impl From<jsonwebtoken::errors::Error> for TokenError {
    fn from(err: jsonwebtoken::errors::Error) -> Self {
        Self::Encoding(err)
    }
}

impl fmt::Display for TokenError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidExpiration => write!(f, "invalid token expiration"),
            Self::Encoding(err) => write!(f, "token encoding failed: {err}"),
        }
    }
}

impl std::error::Error for TokenError {}


pub fn is_valid_token<T: AsRef<str>>(token: T, token_type: TokenType) -> bool {
let secret = extract_secret_from_env();
is_valid_token_with_secret(token, token_type, &secret)
Expand Down Expand Up @@ -51,10 +68,13 @@ fn is_valid_token_with_secret_at<T: AsRef<str>>(
&validation,
) {
Ok(decoded) => decoded,
Err(_) => return false,
Err(err) => {
log::debug!("token decoding failed: {}", err);
return false;
}
};

decoded.claims.exp >= now_timestamp
decoded.claims.exp > now_timestamp
}

#[derive(Debug, Serialize, Deserialize)]
Expand All @@ -64,7 +84,10 @@ struct Claims {
aud: Vec<String>,
}

pub fn create_token_with_expiration_in(expiration: Duration, token_type: TokenType) -> String {
pub fn create_token_with_expiration_in(
expiration: Duration,
token_type: TokenType,
) -> Result<String, TokenError> {
let secret = extract_secret_from_env();
create_token_with_expiration_in_with_secret(expiration, token_type, &secret)
}
Expand All @@ -73,7 +96,7 @@ fn create_token_with_expiration_in_with_secret(
expiration: Duration,
token_type: TokenType,
secret: &str,
) -> String {
) -> Result<String, TokenError> {
create_token_with_expiration_in_with_secret_at(
expiration,
token_type,
Expand All @@ -87,13 +110,13 @@ fn create_token_with_expiration_in_with_secret_at(
token_type: TokenType,
secret: &str,
now_timestamp: i64,
) -> String {
) -> Result<String, TokenError> {
use jsonwebtoken::{encode, Algorithm, EncodingKey, Header};

let expiration = now_timestamp
.checked_add(expiration.num_seconds())
.expect("valid timestamp");
let expiration = u64::try_from(expiration).expect("non-negative timestamp");
.ok_or(TokenError::InvalidExpiration)?;
let expiration = u64::try_from(expiration).map_err(|_| TokenError::InvalidExpiration)?;

let claims = Claims {
sub: "yandex".to_owned(),
Expand All @@ -108,7 +131,7 @@ fn create_token_with_expiration_in_with_secret_at(
&claims,
&EncodingKey::from_secret(secret.as_bytes()),
)
.unwrap()
.map_err(TokenError::Encoding)
}

fn extract_secret_from_env() -> String {
Expand Down Expand Up @@ -139,7 +162,8 @@ mod tests {
TokenType::Access,
SECRET,
NOW,
);
)
.unwrap();

assert!(is_valid_token_with_secret_at(
token,
Expand All @@ -156,7 +180,8 @@ mod tests {
TokenType::Access,
SECRET,
NOW,
);
)
.unwrap();

assert!(!is_valid_token_with_secret_at(
token,
Expand All @@ -173,7 +198,8 @@ mod tests {
TokenType::Access,
SECRET,
NOW,
);
)
.unwrap();

assert!(!is_valid_token_with_secret_at(
token,
Expand All @@ -190,7 +216,8 @@ mod tests {
TokenType::Access,
"secret-a",
NOW,
);
)
.unwrap();

assert!(!is_valid_token_with_secret_at(
token,
Expand All @@ -207,19 +234,32 @@ mod tests {
TokenType::Access,
SECRET,
NOW,
);
)
.unwrap();

assert!(is_valid_token_with_secret_at(
&token,
TokenType::Access,
SECRET,
(NOW + 30) as u64
(NOW + 29) as u64
));
assert!(!is_valid_token_with_secret_at(
token,
TokenType::Access,
SECRET,
(NOW + 31) as u64
(NOW + 30) as u64
));
}

#[test]
fn invalid_expiration_returns_error() {
let result = create_token_with_expiration_in_with_secret_at(
Duration::seconds(1),
TokenType::Access,
SECRET,
i64::MAX,
);

assert!(matches!(result, Err(super::TokenError::InvalidExpiration)));
}
}