Skip to content

Commit

Permalink
More safety parsing the header and simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
bisho committed Nov 14, 2023
1 parent 3f6509b commit 5d05e92
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 59 deletions.
2 changes: 0 additions & 2 deletions src/impl_build_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ use crate::RequestFlags;
const MAX_KEY_LEN: usize = 250;
const MAX_BIN_KEY_LEN: usize = 187; // 250 * 3 / 4 due to b64 encoding

// type KeyHasher = Blake2b<U18>;

pub fn impl_build_cmd(
cmd: &[u8],
key: &[u8],
Expand Down
91 changes: 87 additions & 4 deletions src/impl_parse_header_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ mod tests {
}
#[test]
fn test_value_response() {
let data = b"VA 1234 c1234567 h0 l1111 t2222 f1 Z s3333 MORE_SPACES_ARE_OK_TOO Ofoobar UNKNOWN FLAGS\r\n";
let data = b"VA 1234 c1234567 h0 l1111 t2222 f1 Z s3333 MORE_SPACES_ARE_OK_TOO Ofooonly UNKNOWN FLAGS Ofoobar\r\n";
let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap();
assert_eq!(end_pos, data.len());
assert_eq!(response_type, Some(RESPONSE_VALUE));
Expand All @@ -29,6 +29,89 @@ mod tests {
assert_eq!(flags.opaque, Some(b"foobar".to_vec()));
}

#[test]
fn test_size_int_overflow() {
let data = b"VA 12345678901234567890 c123456789001234567890 l111 t12345678901234567890\r\n";
let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap();
assert_eq!(end_pos, data.len());
assert_eq!(end_pos, data.len());
assert!(response_type.is_none());
assert!(size.is_none());
assert!(flags.is_none());
}

#[test]
fn test_flags_int_overflow() {
let data = b"VA 1234 c123456789001234567890 l111 t12345678901234567890\r\n";
let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap();
assert_eq!(end_pos, data.len());
assert_eq!(response_type, Some(RESPONSE_VALUE));
assert_eq!(size, Some(1234));
assert!(flags.is_some());
let flags = flags.unwrap();
assert!(flags.cas_token.is_none());
assert!(flags.fetched.is_none());
assert_eq!(flags.last_access, Some(111));
assert!(flags.ttl.is_none());
assert!(flags.client_flag.is_none());
assert!(flags.win.is_none());
assert_eq!(flags.stale, false);
assert!(flags.size.is_none());
assert!(flags.opaque.is_none());
}

#[test]
fn test_bad_ttls() {
let data = b"VA 1234 c111 t\r\n";
let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap();
assert_eq!(end_pos, data.len());
assert_eq!(response_type, Some(RESPONSE_VALUE));
assert_eq!(size, Some(1234));
assert!(flags.is_some());
let flags = flags.unwrap();
assert_eq!(flags.cas_token, Some(111));
assert!(flags.fetched.is_none());
assert!(flags.last_access.is_none());
assert!(flags.ttl.is_none());
assert!(flags.client_flag.is_none());
assert!(flags.win.is_none());
assert_eq!(flags.stale, false);
assert!(flags.size.is_none());
assert!(flags.opaque.is_none());
let data = b"VA 1234 t-999 c111\r\n";
let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap();
assert_eq!(end_pos, data.len());
assert_eq!(response_type, Some(RESPONSE_VALUE));
assert_eq!(size, Some(1234));
assert!(flags.is_some());
let flags = flags.unwrap();
assert_eq!(flags.cas_token, Some(111));
assert!(flags.fetched.is_none());
assert!(flags.last_access.is_none());
assert_eq!(flags.ttl, Some(-1));
assert!(flags.client_flag.is_none());
assert!(flags.win.is_none());
assert_eq!(flags.stale, false);
assert!(flags.size.is_none());
assert!(flags.opaque.is_none());
let data = b"VA 1234 t- c111\r\n";
let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap();
assert_eq!(end_pos, data.len());
assert_eq!(response_type, Some(RESPONSE_VALUE));
assert_eq!(size, Some(1234));
assert!(flags.is_some());
let flags = flags.unwrap();
assert_eq!(flags.cas_token, Some(111));
assert!(flags.fetched.is_none());
assert!(flags.last_access.is_none());
assert_eq!(flags.ttl, Some(-1));
assert!(flags.client_flag.is_none());
assert!(flags.win.is_none());
assert_eq!(flags.stale, false);
assert!(flags.size.is_none());
assert!(flags.opaque.is_none());
}

#[test]
fn test_value_response_no_flags() {
let data = b"VA 1234\r\n";
Expand Down Expand Up @@ -61,7 +144,7 @@ mod tests {

#[test]
fn test_success_reponse() {
let data = b"HD c1234567 h0 l1111 t2222 f1 X W s3333 Ofoobar UNKNOWN FLAGS\r\nOK\r\n";
let data = b"HD c1234567 h0 l1111 t-1 f1 X W s2222 Ofoobar UNKNOWN FLAGS\r\nOK\r\n";
let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap();
assert_eq!(end_pos, data.len() - 4);
assert_eq!(response_type, Some(RESPONSE_SUCCESS));
Expand All @@ -71,11 +154,11 @@ mod tests {
assert_eq!(flags.cas_token, Some(1234567));
assert_eq!(flags.fetched, Some(false));
assert_eq!(flags.last_access, Some(1111));
assert_eq!(flags.ttl, Some(2222));
assert_eq!(flags.ttl, Some(-1));
assert_eq!(flags.client_flag, Some(1));
assert_eq!(flags.win, Some(true));
assert_eq!(flags.stale, true);
assert_eq!(flags.size, Some(3333));
assert_eq!(flags.size, Some(2222));
assert_eq!(flags.opaque, Some(b"foobar".to_vec()));
let (end_pos, response_type, size, flags) =
impl_parse_header(data, data.len() - 4, data.len()).unwrap();
Expand Down
84 changes: 84 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,96 @@ pub fn build_cmd(
}
}

#[pyfunction]
#[pyo3(
signature = (
key,
request_flags=None,
),
text_signature = "(key: bytes, request_flags: Optional[RequestFlags])",
)]
pub fn build_meta_get(
py: Python,
key: &[u8],
request_flags: Option<&RequestFlags>,
) -> PyResult<Py<PyBytes>> {
match impl_build_cmd(b"mg", key, None, request_flags, false) {
Some(buf) => Ok(PyBytes::new(py, &buf).into()),
None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")),
}
}

#[pyfunction]
#[pyo3(
signature = (
key,
size,
request_flags=None,
legacy_size_format=false,
),
text_signature = "(key: bytes, size: int, request_flags: Optional[RequestFlags], legacy_size_format: bool = False)",
)]
pub fn build_meta_set(
py: Python,
key: &[u8],
size: u32,
request_flags: Option<&RequestFlags>,
legacy_size_format: bool,
) -> PyResult<Py<PyBytes>> {
match impl_build_cmd(b"ms", key, Some(size), request_flags, legacy_size_format) {
Some(buf) => Ok(PyBytes::new(py, &buf).into()),
None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")),
}
}

#[pyfunction]
#[pyo3(
signature = (
key,
request_flags=None,
),
text_signature = "(key: bytes, request_flags: Optional[RequestFlags])",
)]
pub fn build_meta_delete(
py: Python,
key: &[u8],
request_flags: Option<&RequestFlags>,
) -> PyResult<Py<PyBytes>> {
match impl_build_cmd(b"md", key, None, request_flags, false) {
Some(buf) => Ok(PyBytes::new(py, &buf).into()),
None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")),
}
}

#[pyfunction]
#[pyo3(
signature = (
key,
request_flags=None,
),
text_signature = "(key: bytes, request_flags: Optional[RequestFlags])",
)]
pub fn build_meta_arithmetic(
py: Python,
key: &[u8],
request_flags: Option<&RequestFlags>,
) -> PyResult<Py<PyBytes>> {
match impl_build_cmd(b"ma", key, None, request_flags, false) {
Some(buf) => Ok(PyBytes::new(py, &buf).into()),
None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")),
}
}

#[pymodule]
fn meta_socket(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_class::<ResponseFlags>()?;
module.add_class::<RequestFlags>()?;
module.add_function(wrap_pyfunction!(parse_header, module)?)?;
module.add_function(wrap_pyfunction!(build_cmd, module)?)?;
module.add_function(wrap_pyfunction!(build_meta_get, module)?)?;
module.add_function(wrap_pyfunction!(build_meta_set, module)?)?;
module.add_function(wrap_pyfunction!(build_meta_delete, module)?)?;
module.add_function(wrap_pyfunction!(build_meta_arithmetic, module)?)?;
module.add("RESPONSE_VALUE", RESPONSE_VALUE)?;
module.add("RESPONSE_SUCCESS", RESPONSE_SUCCESS)?;
module.add("RESPONSE_NOT_STORED", RESPONSE_NOT_STORED)?;
Expand Down
90 changes: 37 additions & 53 deletions src/response_flags.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use atoi::FromRadix10;
use atoi::FromRadix10Checked;
use pyo3::prelude::*;

#[inline]
Expand All @@ -13,6 +13,22 @@ fn find_space_or_end(header: &[u8], start: usize) -> usize {
n
}

#[inline]
fn get_u32_value(header: &[u8], start: usize) -> (Option<u32>, usize) {
match u32::from_radix_10_checked(&header[start..]) {
(Some(v), len) if len > 0 => (Some(v), start + len),
_ => (None, find_space_or_end(&header, start)),
}
}

#[inline]
fn get_i32_value(header: &[u8], start: usize) -> (Option<i32>, usize) {
match i32::from_radix_10_checked(&header[start..]) {
(Some(v), len) if len > 0 => (Some(v), start + len),
_ => (None, find_space_or_end(&header, start)),
}
}

#[pyclass]
pub struct ResponseFlags {
#[pyo3(get)]
Expand Down Expand Up @@ -125,12 +141,14 @@ impl ResponseFlags {
if header.len() < size_start + 1 {
return None;
}
let (size, pos) = u32::from_radix_10(&header[size_start..]);
if pos == 0 {
return None;
// let (size, pos) = u32::from_radix_10_checked(&header[size_start..]);
match u32::from_radix_10_checked(&header[size_start..]) {
(Some(size), pos) if pos > 0 => {
let flags = ResponseFlags::parse_flags(header, size_start + pos);
Some((size, flags))
}
_ => None,
}
let flags = ResponseFlags::parse_flags(header, size_start + pos);
Some((size, flags))
}

#[staticmethod]
Expand All @@ -157,61 +175,33 @@ impl ResponseFlags {
}
b'c' => {
// cas_token flag (u32)
let (v, len) = u32::from_radix_10(&header[n..]);
if len > 0 {
cas_token = Some(v);
n += len;
} else {
n = find_space_or_end(&header, n);
}
(cas_token, n) = get_u32_value(&header, n);
}
b'h' => {
// fetched flag (bool) encoded as 1 or 0
match header[n] {
b'1' => {
fetched = Some(true);
}
b'0' => {
fetched = Some(false);
}
_ => {}
}
n += 1;
fetched = match header[n] {
b'1' => Some(true),
b'0' => Some(false),
_ => None,
};
n = find_space_or_end(&header, n + 1);
}
b'l' => {
// last_access flag (u32)
let (v, len) = u32::from_radix_10(&header[n..]);
if len > 0 {
last_access = Some(v);
n += len;
} else {
n = find_space_or_end(&header, n);
}
(last_access, n) = get_u32_value(&header, n);
}
b't' => {
// ttl flag (i32) encoded as -1 (for no ttl) or a positive number
if header[n + 1] == b'-' {
if n < header.len() && header[n] == b'-' {
ttl = Some(-1);
n += 2;
n = find_space_or_end(&header, n)
} else {
let (v, len) = i32::from_radix_10(&header[n..]);
if len > 0 {
ttl = Some(v);
n += len;
} else {
n = find_space_or_end(&header, n);
}
(ttl, n) = get_i32_value(&header, n);
}
}
b'f' => {
// client_flag flag (u32)
let (v, len) = u32::from_radix_10(&header[n..]);
if len > 0 {
client_flag = Some(v);
n += len;
} else {
n = find_space_or_end(&header, n);
}
(client_flag, n) = get_u32_value(&header, n);
}
b'W' => {
// win flag (bool), no value
Expand All @@ -227,13 +217,7 @@ impl ResponseFlags {
}
b's' => {
// size flag (u32)
let (v, len) = u32::from_radix_10(&header[n..]);
if len > 0 {
size = Some(v);
n += len;
} else {
n = find_space_or_end(&header, n);
}
(size, n) = get_u32_value(&header, n);
}
b'O' => {
// opaque flag (bytes)
Expand Down

0 comments on commit 5d05e92

Please sign in to comment.