Skip to content

Commit

Permalink
Cleanup the code
Browse files Browse the repository at this point in the history
  • Loading branch information
Hinton committed Dec 4, 2023
1 parent 8610f30 commit 79edca4
Showing 1 changed file with 133 additions and 115 deletions.
248 changes: 133 additions & 115 deletions crates/bitwarden/src/vault/totp.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, str::FromStr};

use crate::error::{Error, Result};
use chrono::{DateTime, Utc};
Expand All @@ -24,70 +24,160 @@ pub struct TotpResponse {
pub period: u32,
}

/// Generate a OATH or RFC 6238 TOTP code from a provided key.
///
/// <https://datatracker.ietf.org/doc/html/rfc6238>
///
/// Key can be either:
/// - A base32 encoded string
/// - OTP Auth URI
/// - Steam URI
///
/// Supports providing an optional time, and defaults to current system time if none is provided.
///
/// Arguments:
/// - `key` - The key to generate the TOTP code from
/// - `time` - The time in UTC to generate the TOTP code for, defaults to current system time
pub async fn generate_totp(key: String, time: Option<DateTime<Utc>>) -> Result<TotpResponse> {
let params: Totp = key.parse()?;

let time = time.unwrap_or_else(Utc::now);

let otp = params.derive_otp(time.timestamp())?;

Ok(TotpResponse {
code: otp,
period: params.period,
})
}

#[derive(Clone, Copy, Debug)]
enum TotpAlgorithm {
enum Algorithm {
Sha1,
Sha256,
Sha512,
Steam,
}

impl Algorithm {
// Derive the HMAC hash for the given algorithm
fn derive_hash(&self, key: &[u8], time: &[u8]) -> Result<Vec<u8>> {
fn compute_digest<D: Mac>(mut digest: D, time: &[u8]) -> Vec<u8> {
digest.update(time);
digest.finalize().into_bytes().to_vec()
}

Ok(match self {
Algorithm::Sha1 => compute_digest(HmacSha1::new_from_slice(key)?, time),
Algorithm::Sha256 => compute_digest(HmacSha256::new_from_slice(key)?, time),
Algorithm::Sha512 => compute_digest(HmacSha512::new_from_slice(key)?, time),
Algorithm::Steam => compute_digest(HmacSha1::new_from_slice(key)?, time),
})
}
}

#[derive(Debug)]
struct TotpParams {
algorithm: TotpAlgorithm,
struct Totp {
algorithm: Algorithm,
digits: u32,
period: u32,
secret: String,
}

impl Default for TotpParams {
impl Default for Totp {
fn default() -> Self {
Self {
algorithm: TotpAlgorithm::Sha1,
algorithm: Algorithm::Sha1,
digits: 6,
period: 30,
secret: "".to_string(),
}
}
}

/// Generate a OATH or RFC 6238 TOTP code from a provided key.
///
/// <https://datatracker.ietf.org/doc/html/rfc6238>
///
/// Key can be either:
/// - A base32 encoded string
/// - OTP Auth URI
/// - Steam URI
///
/// Supports providing an optional time, and defaults to current system time if none is provided.
pub async fn generate_totp(key: String, time: Option<DateTime<Utc>>) -> Result<TotpResponse> {
let params = get_params(key)?;

// TODO: Should we swap the expected time to timestamp?
let time = time.unwrap_or_else(Utc::now);
print!("{:?}", params);

let t = time.timestamp() / params.period as i64;
let secret = BASE32.decode(params.secret.as_ref()).map_err(|e| {
println!("{:?}", e);
Error::Internal("Unable to decode secret")
})?;

let hash = derive_hash(params.algorithm, &secret, t.to_be_bytes().as_ref())?;
let binary = derive_binary(hash);

let otp = if let TotpAlgorithm::Steam = params.algorithm {
derive_steam_otp(binary, params.digits)
} else {
let otp = binary % 10_u32.pow(params.digits);
format!("{1:00$}", params.digits as usize, otp)
};
impl Totp {
fn derive_otp(&self, time: i64) -> Result<String> {
let time = time / self.period as i64;

let secret = BASE32.decode(self.secret.as_ref()).map_err(|e| {
println!("{:?}", e);
Error::Internal("Unable to decode secret")
})?;

let hash = self
.algorithm
.derive_hash(&secret, time.to_be_bytes().as_ref())?;
let binary = derive_binary(hash);

Ok(if let Algorithm::Steam = self.algorithm {
derive_steam_otp(binary, self.digits)
} else {
let otp = binary % 10_u32.pow(self.digits);
format!("{1:00$}", self.digits as usize, otp)
})
}
}

Ok(TotpResponse {
code: otp,
period: params.period,
})
impl FromStr for Totp {
type Err = Error;

/// Parses the provided key and returns the corresponding `TotpParams`.
///
/// Key can be either:
/// - A base32 encoded string
/// - OTP Auth URI
/// - Steam URI
fn from_str(key: &str) -> Result<Self> {
let params = if key.starts_with("otpauth://") {
let url = Url::parse(key).map_err(|_| Error::Internal("Unable to parse URL"))?;
let parts: HashMap<_, _> = url.query_pairs().collect();

let defaults = Totp::default();

Totp {
algorithm: parts
.get("algorithm")
.and_then(|v| match v.to_uppercase().as_ref() {
"SHA1" => Some(Algorithm::Sha1),
"SHA256" => Some(Algorithm::Sha256),
"SHA512" => Some(Algorithm::Sha512),
_ => None,
})
.unwrap_or(defaults.algorithm),
digits: parts
.get("digits")
.and_then(|v| v.parse().ok())
.map(|v: u32| v.clamp(0, 10))
.unwrap_or(defaults.digits),
period: parts
.get("period")
.and_then(|v| v.parse().ok())
.map(|v: u32| v.max(1))
.unwrap_or(defaults.period),
secret: parts
.get("secret")
.map(|v| v.to_string())
.unwrap_or(defaults.secret),
}
} else if key.starts_with("steam://") {
Totp {
algorithm: Algorithm::Steam,
digits: 5,
secret: key
.strip_prefix("steam://")
.expect("Prefix is defined")
.to_string(),
..Totp::default()
}
} else {
Totp {
secret: key.to_string(),
..Totp::default()
}
};

Ok(params)
}
}

/// Derive the Steam OTP from the hash with the given number of digits.
Expand All @@ -108,64 +198,6 @@ fn derive_steam_otp(binary: u32, digits: u32) -> String {
otp
}

/// Parses the provided key and returns the corresponding `TotpParams`.
///
/// Key can be either:
/// - A base32 encoded string
/// - OTP Auth URI
/// - Steam URI
fn get_params(key: String) -> Result<TotpParams> {
let params = if key.starts_with("otpauth://") {
let url = Url::parse(&key).map_err(|_| Error::Internal("Unable to parse URL"))?;
let parts: HashMap<_, _> = url.query_pairs().collect();

let defaults = TotpParams::default();

TotpParams {
algorithm: parts
.get("algorithm")
.and_then(|v| match v.to_uppercase().as_ref() {
"SHA1" => Some(TotpAlgorithm::Sha1),
"SHA256" => Some(TotpAlgorithm::Sha256),
"SHA512" => Some(TotpAlgorithm::Sha512),
_ => None,
})
.unwrap_or(defaults.algorithm),
digits: parts
.get("digits")
.and_then(|v| v.parse().ok())
.map(|v: u32| v.clamp(0, 10))
.unwrap_or(defaults.digits),
period: parts
.get("period")
.and_then(|v| v.parse().ok())
.map(|v: u32| v.max(1))
.unwrap_or(defaults.period),
secret: parts
.get("secret")
.map(|v| v.to_string())
.unwrap_or(defaults.secret),
}
} else if key.starts_with("steam://") {
TotpParams {
algorithm: TotpAlgorithm::Steam,
digits: 5,
secret: key
.strip_prefix("steam://")
.expect("Prefix is defined")
.to_string(),
..TotpParams::default()
}
} else {
TotpParams {
secret: key,
..TotpParams::default()
}
};

Ok(params)
}

/// Derive the OTP from the hash with the given number of digits.
fn derive_binary(hash: Vec<u8>) -> u32 {
let offset = (hash.last().unwrap_or(&0) & 15) as usize;
Expand All @@ -176,27 +208,13 @@ fn derive_binary(hash: Vec<u8>) -> u32 {
| hash[offset + 3] as u32
}

/// Convert from the dependency `InvalidLength` error into this crate's `Error`.
impl From<aes::cipher::InvalidLength> for Error {
fn from(_: aes::cipher::InvalidLength) -> Self {
Error::Internal("Invalid length")
}
}

// Derive the HMAC hash for the given algorithm
fn derive_hash(algorithm: TotpAlgorithm, key: &[u8], time: &[u8]) -> Result<Vec<u8>> {
fn compute_digest<D: Mac>(mut digest: D, time: &[u8]) -> Vec<u8> {
digest.update(time);
digest.finalize().into_bytes().to_vec()
}

Ok(match algorithm {
TotpAlgorithm::Sha1 => compute_digest(HmacSha1::new_from_slice(key)?, time),
TotpAlgorithm::Sha256 => compute_digest(HmacSha256::new_from_slice(key)?, time),
TotpAlgorithm::Sha512 => compute_digest(HmacSha512::new_from_slice(key)?, time),
TotpAlgorithm::Steam => compute_digest(HmacSha1::new_from_slice(key)?, time),
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down

0 comments on commit 79edca4

Please sign in to comment.