diff --git a/src/impl_build_cmd.rs b/src/impl_build_cmd.rs index f09f94e..0409055 100644 --- a/src/impl_build_cmd.rs +++ b/src/impl_build_cmd.rs @@ -5,8 +5,6 @@ use crate::RequestFlags; const MAX_KEY_LEN: usize = 250; const MAX_BIN_KEY_LEN: usize = 187; // 250 * 3 / 4 due to b64 encoding -// type KeyHasher = Blake2b; - pub fn impl_build_cmd( cmd: &[u8], key: &[u8], diff --git a/src/impl_parse_header_tests.rs b/src/impl_parse_header_tests.rs index 8fb3a40..6e5c778 100644 --- a/src/impl_parse_header_tests.rs +++ b/src/impl_parse_header_tests.rs @@ -11,7 +11,7 @@ mod tests { } #[test] fn test_value_response() { - let data = b"VA 1234 c1234567 h0 l1111 t2222 f1 Z s3333 MORE_SPACES_ARE_OK_TOO Ofoobar UNKNOWN FLAGS\r\n"; + let data = b"VA 1234 c1234567 h0 l1111 t2222 f1 Z s3333 MORE_SPACES_ARE_OK_TOO Ofooonly UNKNOWN FLAGS Ofoobar\r\n"; let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap(); assert_eq!(end_pos, data.len()); assert_eq!(response_type, Some(RESPONSE_VALUE)); @@ -29,6 +29,89 @@ mod tests { assert_eq!(flags.opaque, Some(b"foobar".to_vec())); } + #[test] + fn test_size_int_overflow() { + let data = b"VA 12345678901234567890 c123456789001234567890 l111 t12345678901234567890\r\n"; + let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap(); + assert_eq!(end_pos, data.len()); + assert_eq!(end_pos, data.len()); + assert!(response_type.is_none()); + assert!(size.is_none()); + assert!(flags.is_none()); + } + + #[test] + fn test_flags_int_overflow() { + let data = b"VA 1234 c123456789001234567890 l111 t12345678901234567890\r\n"; + let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap(); + assert_eq!(end_pos, data.len()); + assert_eq!(response_type, Some(RESPONSE_VALUE)); + assert_eq!(size, Some(1234)); + assert!(flags.is_some()); + let flags = flags.unwrap(); + assert!(flags.cas_token.is_none()); + assert!(flags.fetched.is_none()); + assert_eq!(flags.last_access, Some(111)); + assert!(flags.ttl.is_none()); + assert!(flags.client_flag.is_none()); + assert!(flags.win.is_none()); + assert_eq!(flags.stale, false); + assert!(flags.size.is_none()); + assert!(flags.opaque.is_none()); + } + + #[test] + fn test_bad_ttls() { + let data = b"VA 1234 c111 t\r\n"; + let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap(); + assert_eq!(end_pos, data.len()); + assert_eq!(response_type, Some(RESPONSE_VALUE)); + assert_eq!(size, Some(1234)); + assert!(flags.is_some()); + let flags = flags.unwrap(); + assert_eq!(flags.cas_token, Some(111)); + assert!(flags.fetched.is_none()); + assert!(flags.last_access.is_none()); + assert!(flags.ttl.is_none()); + assert!(flags.client_flag.is_none()); + assert!(flags.win.is_none()); + assert_eq!(flags.stale, false); + assert!(flags.size.is_none()); + assert!(flags.opaque.is_none()); + let data = b"VA 1234 t-999 c111\r\n"; + let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap(); + assert_eq!(end_pos, data.len()); + assert_eq!(response_type, Some(RESPONSE_VALUE)); + assert_eq!(size, Some(1234)); + assert!(flags.is_some()); + let flags = flags.unwrap(); + assert_eq!(flags.cas_token, Some(111)); + assert!(flags.fetched.is_none()); + assert!(flags.last_access.is_none()); + assert_eq!(flags.ttl, Some(-1)); + assert!(flags.client_flag.is_none()); + assert!(flags.win.is_none()); + assert_eq!(flags.stale, false); + assert!(flags.size.is_none()); + assert!(flags.opaque.is_none()); + let data = b"VA 1234 t- c111\r\n"; + let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap(); + assert_eq!(end_pos, data.len()); + assert_eq!(response_type, Some(RESPONSE_VALUE)); + assert_eq!(size, Some(1234)); + assert!(flags.is_some()); + let flags = flags.unwrap(); + assert_eq!(flags.cas_token, Some(111)); + assert!(flags.fetched.is_none()); + assert!(flags.last_access.is_none()); + assert_eq!(flags.ttl, Some(-1)); + assert!(flags.client_flag.is_none()); + assert!(flags.win.is_none()); + assert_eq!(flags.stale, false); + assert!(flags.size.is_none()); + assert!(flags.opaque.is_none()); + } + #[test] fn test_value_response_no_flags() { let data = b"VA 1234\r\n"; @@ -61,7 +144,7 @@ mod tests { #[test] fn test_success_reponse() { - let data = b"HD c1234567 h0 l1111 t2222 f1 X W s3333 Ofoobar UNKNOWN FLAGS\r\nOK\r\n"; + let data = b"HD c1234567 h0 l1111 t-1 f1 X W s2222 Ofoobar UNKNOWN FLAGS\r\nOK\r\n"; let (end_pos, response_type, size, flags) = impl_parse_header(data, 0, data.len()).unwrap(); assert_eq!(end_pos, data.len() - 4); assert_eq!(response_type, Some(RESPONSE_SUCCESS)); @@ -71,11 +154,11 @@ mod tests { assert_eq!(flags.cas_token, Some(1234567)); assert_eq!(flags.fetched, Some(false)); assert_eq!(flags.last_access, Some(1111)); - assert_eq!(flags.ttl, Some(2222)); + assert_eq!(flags.ttl, Some(-1)); assert_eq!(flags.client_flag, Some(1)); assert_eq!(flags.win, Some(true)); assert_eq!(flags.stale, true); - assert_eq!(flags.size, Some(3333)); + assert_eq!(flags.size, Some(2222)); assert_eq!(flags.opaque, Some(b"foobar".to_vec())); let (end_pos, response_type, size, flags) = impl_parse_header(data, data.len() - 4, data.len()).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 8b9d564..841185f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -65,12 +65,96 @@ pub fn build_cmd( } } +#[pyfunction] +#[pyo3( + signature = ( + key, + request_flags=None, + ), + text_signature = "(key: bytes, request_flags: Optional[RequestFlags])", +)] +pub fn build_meta_get( + py: Python, + key: &[u8], + request_flags: Option<&RequestFlags>, +) -> PyResult> { + match impl_build_cmd(b"mg", key, None, request_flags, false) { + Some(buf) => Ok(PyBytes::new(py, &buf).into()), + None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), + } +} + +#[pyfunction] +#[pyo3( + signature = ( + key, + size, + request_flags=None, + legacy_size_format=false, + ), + text_signature = "(key: bytes, size: int, request_flags: Optional[RequestFlags], legacy_size_format: bool = False)", +)] +pub fn build_meta_set( + py: Python, + key: &[u8], + size: u32, + request_flags: Option<&RequestFlags>, + legacy_size_format: bool, +) -> PyResult> { + match impl_build_cmd(b"ms", key, Some(size), request_flags, legacy_size_format) { + Some(buf) => Ok(PyBytes::new(py, &buf).into()), + None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), + } +} + +#[pyfunction] +#[pyo3( + signature = ( + key, + request_flags=None, + ), + text_signature = "(key: bytes, request_flags: Optional[RequestFlags])", +)] +pub fn build_meta_delete( + py: Python, + key: &[u8], + request_flags: Option<&RequestFlags>, +) -> PyResult> { + match impl_build_cmd(b"md", key, None, request_flags, false) { + Some(buf) => Ok(PyBytes::new(py, &buf).into()), + None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), + } +} + +#[pyfunction] +#[pyo3( + signature = ( + key, + request_flags=None, + ), + text_signature = "(key: bytes, request_flags: Optional[RequestFlags])", +)] +pub fn build_meta_arithmetic( + py: Python, + key: &[u8], + request_flags: Option<&RequestFlags>, +) -> PyResult> { + match impl_build_cmd(b"ma", key, None, request_flags, false) { + Some(buf) => Ok(PyBytes::new(py, &buf).into()), + None => Err(pyo3::exceptions::PyValueError::new_err("Key is too long")), + } +} + #[pymodule] fn meta_socket(_py: Python, module: &PyModule) -> PyResult<()> { module.add_class::()?; module.add_class::()?; module.add_function(wrap_pyfunction!(parse_header, module)?)?; module.add_function(wrap_pyfunction!(build_cmd, module)?)?; + module.add_function(wrap_pyfunction!(build_meta_get, module)?)?; + module.add_function(wrap_pyfunction!(build_meta_set, module)?)?; + module.add_function(wrap_pyfunction!(build_meta_delete, module)?)?; + module.add_function(wrap_pyfunction!(build_meta_arithmetic, module)?)?; module.add("RESPONSE_VALUE", RESPONSE_VALUE)?; module.add("RESPONSE_SUCCESS", RESPONSE_SUCCESS)?; module.add("RESPONSE_NOT_STORED", RESPONSE_NOT_STORED)?; diff --git a/src/response_flags.rs b/src/response_flags.rs index 3c876e5..0496701 100644 --- a/src/response_flags.rs +++ b/src/response_flags.rs @@ -1,4 +1,4 @@ -use atoi::FromRadix10; +use atoi::FromRadix10Checked; use pyo3::prelude::*; #[inline] @@ -13,6 +13,22 @@ fn find_space_or_end(header: &[u8], start: usize) -> usize { n } +#[inline] +fn get_u32_value(header: &[u8], start: usize) -> (Option, usize) { + match u32::from_radix_10_checked(&header[start..]) { + (Some(v), len) if len > 0 => (Some(v), start + len), + _ => (None, find_space_or_end(&header, start)), + } +} + +#[inline] +fn get_i32_value(header: &[u8], start: usize) -> (Option, usize) { + match i32::from_radix_10_checked(&header[start..]) { + (Some(v), len) if len > 0 => (Some(v), start + len), + _ => (None, find_space_or_end(&header, start)), + } +} + #[pyclass] pub struct ResponseFlags { #[pyo3(get)] @@ -125,12 +141,14 @@ impl ResponseFlags { if header.len() < size_start + 1 { return None; } - let (size, pos) = u32::from_radix_10(&header[size_start..]); - if pos == 0 { - return None; + // let (size, pos) = u32::from_radix_10_checked(&header[size_start..]); + match u32::from_radix_10_checked(&header[size_start..]) { + (Some(size), pos) if pos > 0 => { + let flags = ResponseFlags::parse_flags(header, size_start + pos); + Some((size, flags)) + } + _ => None, } - let flags = ResponseFlags::parse_flags(header, size_start + pos); - Some((size, flags)) } #[staticmethod] @@ -157,61 +175,33 @@ impl ResponseFlags { } b'c' => { // cas_token flag (u32) - let (v, len) = u32::from_radix_10(&header[n..]); - if len > 0 { - cas_token = Some(v); - n += len; - } else { - n = find_space_or_end(&header, n); - } + (cas_token, n) = get_u32_value(&header, n); } b'h' => { // fetched flag (bool) encoded as 1 or 0 - match header[n] { - b'1' => { - fetched = Some(true); - } - b'0' => { - fetched = Some(false); - } - _ => {} - } - n += 1; + fetched = match header[n] { + b'1' => Some(true), + b'0' => Some(false), + _ => None, + }; + n = find_space_or_end(&header, n + 1); } b'l' => { // last_access flag (u32) - let (v, len) = u32::from_radix_10(&header[n..]); - if len > 0 { - last_access = Some(v); - n += len; - } else { - n = find_space_or_end(&header, n); - } + (last_access, n) = get_u32_value(&header, n); } b't' => { // ttl flag (i32) encoded as -1 (for no ttl) or a positive number - if header[n + 1] == b'-' { + if n < header.len() && header[n] == b'-' { ttl = Some(-1); - n += 2; + n = find_space_or_end(&header, n) } else { - let (v, len) = i32::from_radix_10(&header[n..]); - if len > 0 { - ttl = Some(v); - n += len; - } else { - n = find_space_or_end(&header, n); - } + (ttl, n) = get_i32_value(&header, n); } } b'f' => { // client_flag flag (u32) - let (v, len) = u32::from_radix_10(&header[n..]); - if len > 0 { - client_flag = Some(v); - n += len; - } else { - n = find_space_or_end(&header, n); - } + (client_flag, n) = get_u32_value(&header, n); } b'W' => { // win flag (bool), no value @@ -227,13 +217,7 @@ impl ResponseFlags { } b's' => { // size flag (u32) - let (v, len) = u32::from_radix_10(&header[n..]); - if len > 0 { - size = Some(v); - n += len; - } else { - n = find_space_or_end(&header, n); - } + (size, n) = get_u32_value(&header, n); } b'O' => { // opaque flag (bytes)