From bf9ebb6746117b74b1578ea526de8f5ca1b3593f Mon Sep 17 00:00:00 2001 From: Jochen Bartl Date: Wed, 12 Oct 2022 21:52:41 +0200 Subject: [PATCH 1/6] use Bytes.t as data type for IPv6 addresses * this addresses issue #16 --- lib/ipaddr.ml | 513 +++++++++++++++++++++-------------- lib_test/test_ipaddr_b128.ml | 184 +++++++++++-- 2 files changed, 473 insertions(+), 224 deletions(-) diff --git a/lib/ipaddr.ml b/lib/ipaddr.ml index 5a50e02..6a3660c 100644 --- a/lib/ipaddr.ml +++ b/lib/ipaddr.ml @@ -442,109 +442,256 @@ module V4 = struct end module B128 = struct - type t = int32 * int32 * int32 * int32 + let int_of_hex_char c = + match c with + | '0' .. '9' -> Char.code c - 48 + | 'a' .. 'f' -> Char.code c - 87 + | 'A' .. 'F' -> Char.code c - 55 + | _ -> invalid_arg "char is not a valid hex digit" + + exception Overflow + + type t = Bytes.t + + let zero () = Bytes.make 16 '\x00' + let min_int () = zero () + let max_int () = Bytes.make 16 '\xff' + let equal = Bytes.equal + let compare = Bytes.compare + + let fold_left f a b = + let a' = ref a in + for i = 0 to 15 do + let x' = Bytes.get_uint8 b i in + a' := f !a' x' + done; + !a' + + let iteri_right2 f x y = + for i = 15 downto 0 do + let x' = Bytes.get_uint8 x i in + let y' = Bytes.get_uint8 y i in + f i x' y' + done - let of_int64 (a, b) = - Int64. - ( to_int32 (shift_right_logical a 32), - to_int32 a, - to_int32 (shift_right_logical b 32), - to_int32 b ) - - let to_int64 (a, b, c, d) = - Int64. - ( logor (shift_left (of_int32 a) 32) (of_int32 b), - logor (shift_left (of_int32 c) 32) (of_int32 d) ) + let of_string_exn s = + let l = String.length s in + if l != 32 then invalid_arg "not 32 chars long" + else + let b = zero () in + let bi = ref 15 in + let i = ref (l - 1) in + while !i >= 0 do + let x = int_of_hex_char (String.get s !i) in + let y = int_of_hex_char (String.get s (!i - 1)) in + Bytes.set_uint8 b !bi ((y lsl 4) + x); + i := !i - 2; + bi := !bi - 1 + done; + b + + let of_string s = try Some (of_string_exn s) with Invalid_argument _ -> None + + let to_string b = + let l = ref [] in + for i = 15 downto 0 do + l := Printf.sprintf "%.2x" (Bytes.get_uint8 b i) :: !l + done; + String.concat "" !l + + let pp ppf b = Format.fprintf ppf "%s" (to_string b) - let of_int32 x = x - let to_int32 x = x + let of_int64 (a, b) = + let b' = zero () in + Bytes.set_int64_be b' 0 a; + Bytes.set_int64_be b' 8 b; + b' + + let to_int64 b = (Bytes.get_int64_be b 0, Bytes.get_int64_be b 8) + + let of_int32 (a, b, c, d) = + let b' = zero () in + Bytes.set_int32_be b' 0 a; + Bytes.set_int32_be b' 4 b; + Bytes.set_int32_be b' 8 c; + Bytes.set_int32_be b' 12 d; + b' + + let to_int32 b = + ( Bytes.get_int32_be b 0, + Bytes.get_int32_be b 4, + Bytes.get_int32_be b 8, + Bytes.get_int32_be b 12 ) let of_int16 (a, b, c, d, e, f, g, h) = - ( V4.of_int16 (a, b), - V4.of_int16 (c, d), - V4.of_int16 (e, f), - V4.of_int16 (g, h) ) - - let to_int16 (x, y, z, t) = - let a, b = V4.to_int16 x - and c, d = V4.to_int16 y - and e, f = V4.to_int16 z - and g, h = V4.to_int16 t in - (a, b, c, d, e, f, g, h) - - let write_octets_exn ?(off = 0) (a, b, c, d) byte = - V4.write_octets_exn ~off a byte; - V4.write_octets_exn ~off:(off + 4) b byte; - V4.write_octets_exn ~off:(off + 8) c byte; - V4.write_octets_exn ~off:(off + 12) d byte - - let compare (a1, b1, c1, d1) (a2, b2, c2, d2) = - match V4.compare a1 a2 with - | 0 -> ( - match V4.compare b1 b2 with - | 0 -> ( match V4.compare c1 c2 with 0 -> V4.compare d1 d2 | n -> n) - | n -> n) - | n -> n - - let logand (a1, b1, c1, d1) (a2, b2, c2, d2) = - (a1 &&& a2, b1 &&& b2, c1 &&& c2, d1 &&& d2) - - let logor (a1, b1, c1, d1) (a2, b2, c2, d2) = - (a1 ||| a2, b1 ||| b2, c1 ||| c2, d1 ||| d2) - - let lognot (a, b, c, d) = Int32.(lognot a, lognot b, lognot c, lognot d) - - let succ (a, b, c, d) = - let cb (n, tl) v = - match n with - | 0l -> (0l, v :: tl) - | n -> - let n = if Int32.equal v 0xFF_FF_FF_FFl then n else 0l in - (n, Int32.succ v :: tl) - in - match List.fold_left cb (1l, []) [ d; c; b; a ] with - | 0l, [ a; b; c; d ] -> Ok (of_int32 (a, b, c, d)) - | n, [ _; _; _; _ ] when n > 0l -> - Error (`Msg "Ipaddr: highest address has been reached") - | _ -> Error (`Msg "Ipaddr: unexpected error with B128") - - let pred (a, b, c, d) = - let cb (n, tl) v = - match n with - | 0l -> (0l, v :: tl) - | n -> - let n = if v = 0x00_00_00_00l then n else 0l in - (n, Int32.pred v :: tl) - in - match List.fold_left cb (-1l, []) [ d; c; b; a ] with - | 0l, [ a; b; c; d ] -> Ok (of_int32 (a, b, c, d)) - | n, [ _; _; _; _ ] when n < 0l -> - Error (`Msg "Ipaddr: lowest address has been reached") - | _ -> Error (`Msg "Ipaddr: unexpected error with B128") - - (* result is unspecified if sz < 0 *) - let shift_right (a, b, c, d) sz = - if sz < 0 || sz > 128 then - Error (`Msg "Ipaddr: unexpected argument sz (must be >= 0 and < 128)") + let b' = zero () in + Bytes.set_uint16_be b' 0 a; + Bytes.set_uint16_be b' 2 b; + Bytes.set_uint16_be b' 4 c; + Bytes.set_uint16_be b' 6 d; + Bytes.set_uint16_be b' 8 e; + Bytes.set_uint16_be b' 10 f; + Bytes.set_uint16_be b' 12 g; + Bytes.set_uint16_be b' 14 h; + b' + + let to_int16 b = + ( Bytes.get_uint16_be b 0, + Bytes.get_uint16_be b 2, + Bytes.get_uint16_be b 4, + Bytes.get_uint16_be b 6, + Bytes.get_uint16_be b 8, + Bytes.get_uint16_be b 10, + Bytes.get_uint16_be b 12, + Bytes.get_uint16_be b 14 ) + + let add_exn x y = + let b = zero () in + let carry = ref 0 in + iteri_right2 + (fun i x' y' -> + let sum = x' + y' + !carry in + if sum >= 256 then ( + carry := 1; + Bytes.set_uint8 b i (sum - 256)) + else ( + carry := 0; + Bytes.set_uint8 b i sum)) + x y; + if !carry <> 0 then raise Overflow else b + + let add x y = try Some (add_exn x y) with Overflow -> None + + let sub_exn x y = + if Bytes.compare x y = -1 then raise Overflow else - let rec loop (a, b, c, d) sz = - if sz < 32 then (sz, (a, b, c, d)) else loop (0l, a, b, c) (sz - 32) - in - let sz, (a, b, c, d) = loop (a, b, c, d) sz in - let fn (saved, tl) part = - let new_saved = Int32.logand part (0xFF_FF_FF_FFl >|> sz) in - let new_part = part >|> sz ||| (saved <|< 32 - sz) in - (new_saved, new_part :: tl) - in - match List.fold_left fn (0l, []) [ a; b; c; d ] with - | _, [ d; c; b; a ] -> Ok (of_int32 (a, b, c, d)) - | _ -> Error (`Msg "Ipaddr: unexpected error with B128.shift_right") + let b = zero () in + let carry = ref 0 in + iteri_right2 + (fun i x' y' -> + if x' < y' then ( + Bytes.set_uint8 b i (256 + x' - y' - !carry); + carry := 1) + else ( + Bytes.set_uint8 b i (x' - y' - !carry); + carry := 0)) + x y; + if !carry <> 0 then raise Overflow else b + + let sub x y = + try Some (sub_exn x y) with Overflow -> None | Invalid_argument _ -> None + + let logand x y = + let b = zero () in + iteri_right2 (fun i x y -> Bytes.set_uint8 b i (x land y)) x y; + b + + let logor x y = + let b = zero () in + iteri_right2 (fun i x y -> Bytes.set_uint8 b i (x lor y)) x y; + b + + let logxor x y = + let b = zero () in + iteri_right2 (fun i x y -> Bytes.set_uint8 b i (x lxor y)) x y; + b + + let lognot x = + let b = zero () in + Bytes.iteri (fun i _ -> Bytes.set_uint8 b i (lnot (Bytes.get_uint8 x i))) x; + b + + module Byte = struct + (* Extract the [n] least significant bits from [i] *) + let get_lsbits n i = + if n <= 0 || n > 8 then invalid_arg "out of bounds"; + i land ((1 lsl n) - 1) + + (* Extract the [n] most significant bits from [i] *) + let get_msbits n i = + if n <= 0 || n > 8 then invalid_arg "out of bounds"; + (i land (255 lsl (8 - n))) lsr (8 - n) + + (* Set value [x] in [i]'s [n] most significant bits *) + let set_msbits n x i = + if n < 0 || n > 8 then raise (Invalid_argument "n must be >= 0 && <= 8") + else if n = 0 then i + else if n = 8 then x + else (x lsl (8 - n)) lor i + + (* set bits are represented as true *) + let fold_left f a i = + let bitmask = ref 0b1000_0000 in + let a' = ref a in + for _ = 0 to 7 do + a' := f !a' (i land !bitmask > 0); + bitmask := !bitmask lsr 1 + done; + !a' + end + + let shift_right x n = + match n with + | 0 -> x + | 128 -> zero () + | n when n > 0 && n < 128 -> + let b = zero () in + let shift_bytes, shift_bits = (n / 8, n mod 8) in + (if shift_bits = 0 then Bytes.blit x 0 b shift_bytes (16 - shift_bytes) + else + let carry = ref 0 in + for i = 0 to 15 - shift_bytes do + let x' = Bytes.get_uint8 x i in + let new_carry = Byte.get_lsbits shift_bits x' in + let shifted_value = x' lsr shift_bits in + let new_value = Byte.set_msbits shift_bits !carry shifted_value in + Bytes.set_uint8 b (i + shift_bytes) new_value; + carry := new_carry + done); + b + | _ -> raise (Invalid_argument "n must be >= 0 && <= 128") + + let shift_left x n = + match n with + | 0 -> x + | 128 -> zero () + | n when n > 0 && n < 128 -> + let b = zero () in + let shift_bytes, shift_bits = (n / 8, n mod 8) in + (if shift_bits = 0 then Bytes.blit x shift_bytes b 0 (16 - shift_bytes) + else + let carry = ref 0 in + for i = 15 downto 0 + shift_bytes do + let x' = Bytes.get_uint8 x i in + let new_carry = Byte.get_msbits shift_bits x' in + let shifted_value = x' lsl shift_bits in + let new_value = shifted_value lor !carry in + Bytes.set_uint8 b (i - shift_bytes) new_value; + carry := new_carry + done); + b + | _ -> raise (Invalid_argument "n must be >= 0 && <= 128") + + let write_octets_exn ?(off = 0) b' byte = + if Bytes.length b' + off > Bytes.length byte then + raise + (Parse_error + ("larger including offset than target bytes", String.of_bytes b')) + else Bytes.blit b' 0 byte off (Bytes.length b') + + let succ b = + try Ok (add_exn b (of_string_exn "00000000000000000000000000000001")) + with Overflow -> Error (`Msg "Ipaddr: highest address has been reached") + + let pred b = + try Ok (sub_exn b (of_string_exn "00000000000000000000000000000001")) + with Overflow | Invalid_argument _ -> + Error (`Msg "Ipaddr: lowest address has been reached") end module V6 = struct include B128 - (* TODO: Perhaps represent with bytestring? *) let make a b c d e f g h = of_int16 (a, b, c, d, e, f, g, h) (* parsing *) @@ -716,102 +863,75 @@ module V6 = struct (* byte conversion *) let of_octets_exn ?(off = 0) bs = - (* TODO : from cstruct *) - let hihi = V4.of_octets_exn ~off bs in - let hilo = V4.of_octets_exn ~off:(off + 4) bs in - let lohi = V4.of_octets_exn ~off:(off + 8) bs in - let lolo = V4.of_octets_exn ~off:(off + 12) bs in - of_int32 (hihi, hilo, lohi, lolo) + if String.length bs - off < 16 then raise (need_more bs) + else + let b = B128.zero () in + Bytes.blit_string bs off b 0 16; + b let of_octets ?off bs = try_with_result (of_octets_exn ?off) bs let write_octets ?off i bs = try_with_result (write_octets_exn ?off i) bs - - let to_octets i = - let b = Bytes.create 16 in - write_octets_exn i b; - Bytes.to_string b + let to_octets = Bytes.to_string (* MAC *) (* {{:https://tools.ietf.org/html/rfc2464#section-7}RFC 2464}. *) - let multicast_to_mac i = - let _, _, _, i = to_int32 i in - let macb = Bytes.create 6 in - Bytes.set macb 0 (Char.chr 0x33); - Bytes.set macb 1 (Char.chr 0x33); - Bytes.set macb 2 (Char.chr (( |~ ) (i >! 24))); - Bytes.set macb 3 (Char.chr (( |~ ) (i >! 16))); - Bytes.set macb 4 (Char.chr (( |~ ) (i >! 8))); - Bytes.set macb 5 (Char.chr (( |~ ) (i >! 0))); + let multicast_to_mac b = + let macb = Bytes.make 6 (Char.chr 0x33) in + Bytes.blit b 12 macb 2 4; Macaddr.of_octets_exn (Bytes.to_string macb) (* Host *) - let to_domain_name (a, b, c, d) = - let name = - [ - hex_string_of_int32 (d >|> 0 &&& 0xF_l); - hex_string_of_int32 (d >|> 4 &&& 0xF_l); - hex_string_of_int32 (d >|> 8 &&& 0xF_l); - hex_string_of_int32 (d >|> 12 &&& 0xF_l); - hex_string_of_int32 (d >|> 16 &&& 0xF_l); - hex_string_of_int32 (d >|> 20 &&& 0xF_l); - hex_string_of_int32 (d >|> 24 &&& 0xF_l); - hex_string_of_int32 (d >|> 28 &&& 0xF_l); - hex_string_of_int32 (c >|> 0 &&& 0xF_l); - hex_string_of_int32 (c >|> 4 &&& 0xF_l); - hex_string_of_int32 (c >|> 8 &&& 0xF_l); - hex_string_of_int32 (c >|> 12 &&& 0xF_l); - hex_string_of_int32 (c >|> 16 &&& 0xF_l); - hex_string_of_int32 (c >|> 20 &&& 0xF_l); - hex_string_of_int32 (c >|> 24 &&& 0xF_l); - hex_string_of_int32 (c >|> 28 &&& 0xF_l); - hex_string_of_int32 (b >|> 0 &&& 0xF_l); - hex_string_of_int32 (b >|> 4 &&& 0xF_l); - hex_string_of_int32 (b >|> 8 &&& 0xF_l); - hex_string_of_int32 (b >|> 12 &&& 0xF_l); - hex_string_of_int32 (b >|> 16 &&& 0xF_l); - hex_string_of_int32 (b >|> 20 &&& 0xF_l); - hex_string_of_int32 (b >|> 24 &&& 0xF_l); - hex_string_of_int32 (b >|> 28 &&& 0xF_l); - hex_string_of_int32 (a >|> 0 &&& 0xF_l); - hex_string_of_int32 (a >|> 4 &&& 0xF_l); - hex_string_of_int32 (a >|> 8 &&& 0xF_l); - hex_string_of_int32 (a >|> 12 &&& 0xF_l); - hex_string_of_int32 (a >|> 16 &&& 0xF_l); - hex_string_of_int32 (a >|> 20 &&& 0xF_l); - hex_string_of_int32 (a >|> 24 &&& 0xF_l); - hex_string_of_int32 (a >|> 28 &&& 0xF_l); - "ip6"; - "arpa"; - ] + let to_domain_name b = + let hexstr_of_int = Printf.sprintf "%x" in + let rec aux_fold_left a i = + if i = 16 then a + else + let x = hexstr_of_int (Bytes.get_uint8 b i land ((1 lsl 4) - 1)) in + let y = hexstr_of_int (Bytes.get_uint8 b i lsr 4) in + aux_fold_left (x :: y :: a) (i + 1) in + let name = aux_fold_left [ "ip6"; "arpa" ] 0 in Domain_name.(host_exn (of_strings_exn name)) let of_domain_name n = - let open Domain_name in - if count_labels n = 34 then - let ip6 = get_label_exn n 32 and arpa = get_label_exn n 33 in - if equal_label ip6 "ip6" && equal_label arpa "arpa" then - let rev = true in - let n' = drop_label_exn ~rev ~amount:2 n in - let d = drop_label_exn ~rev ~amount:24 n' - and c = drop_label_exn ~amount:8 (drop_label_exn ~rev ~amount:16 n') - and b = drop_label_exn ~amount:16 (drop_label_exn ~rev ~amount:8 n') - and a = drop_label_exn ~amount:24 n' in - let t b d = - let v = Int32.of_int (parse_hex_int d (ref 0)) in - if v > 0xFl then raise (Parse_error ("number in label too big", d)) - else v <|< b - in - let f d = - List.fold_left - (fun (acc, b) d -> (Int32.add acc (t b d), b + 4)) - (0l, 0) (to_strings d) - in - try - let a', _ = f a and b', _ = f b and c', _ = f c and d', _ = f d in - Some (a', b', c', d') - with Parse_error _ -> None - else None + let int_of_char_string = function + | "0" -> 0 + | "1" -> 1 + | "2" -> 2 + | "3" -> 3 + | "4" -> 4 + | "5" -> 5 + | "6" -> 6 + | "7" -> 7 + | "8" -> 8 + | "9" -> 9 + | "a" -> 10 + | "b" -> 11 + | "c" -> 12 + | "d" -> 13 + | "e" -> 14 + | "f" -> 15 + | _ -> failwith "int_of_char_string: invalid hexadecimal string" + in + let labels = Domain_name.to_array n in + if + Array.length labels = 34 + && Domain_name.equal_label labels.(0) "arpa" + && Domain_name.equal_label labels.(1) "ip6" + then + let b = B128.zero () in + let bi = ref 0 in + let i = ref 2 in + try + while !i <= 32 do + let x = int_of_char_string labels.(!i) in + let y = int_of_char_string labels.(!i + 1) in + Bytes.set_uint8 b !bi (Int.logor (Int.shift_left x 4) y); + bi := !bi + 1; + i := !i + 2 + done; + Some b + with Failure _ -> None else None (* constant *) @@ -833,14 +953,7 @@ module V6 = struct if c = 0 then Stdlib.compare sz sz' else c let ip = make - - let _full = - let f = 0x0_FFFF_FFFF_l in - (f, f, f, f) - - let mask sz = - V4.Prefix.(mask (sz - 0), mask (sz - 32), mask (sz - 64), mask (sz - 96)) - + let mask sz = shift_left (max_int ()) (128 - sz) let prefix (pre, sz) = (logand pre (mask sz), sz) let make sz pre = (pre, sz) @@ -871,19 +984,23 @@ module V6 = struct let of_string s = try_with_result of_string_exn s let _of_netmask_exn ~netmask address = - let nm = - let bits netmask = - V4.Prefix.bits (V4.Prefix.of_netmask_exn ~netmask ~address:V4.any) - in - match netmask with - | 0_l, 0_l, 0_l, 0_l -> 0 - | lsw, 0_l, 0_l, 0_l -> bits lsw - | -1_l, lsw, 0_l, 0_l -> bits lsw + 32 - | -1_l, -1_l, lsw, 0_l -> bits lsw + 64 - | -1_l, -1_l, -1_l, lsw -> bits lsw + 96 - | _ -> raise (Parse_error ("invalid netmask", to_string netmask)) + let count_bits bits is_last_bit_set i = + B128.Byte.fold_left + (fun (a, is_last_bit_set) e -> + match (is_last_bit_set, e) with + | true, false | false, false -> (a, false) + | true, true -> (a + 1, true) + | false, true -> + (* netmask is not contiguous *) + raise (Parse_error ("invalid netmask", to_string netmask))) + (bits, is_last_bit_set) i + in + let nm_bits_set, _ = + B128.fold_left + (fun (a, is_last_bit_set) e -> count_bits a is_last_bit_set e) + (0, true) netmask in - make nm address + make nm_bits_set address let of_netmask_exn ~netmask ~address = _of_netmask_exn ~netmask address @@ -923,8 +1040,8 @@ module V6 = struct if sz > 126 then network cidr else network cidr |> succ |> failwith_msg let last ((_, sz) as cidr) = - let ffff = ip 0xffff 0xffff 0xffff 0xffff 0xffff 0xffff 0xffff 0xffff in - logor (network cidr) (shift_right ffff sz |> failwith_msg) + let ffff = B128.max_int () in + logor (network cidr) (B128.shift_right ffff sz) end (* TODO: This could be optimized with something trie-like *) diff --git a/lib_test/test_ipaddr_b128.ml b/lib_test/test_ipaddr_b128.ml index 19c9491..153b438 100644 --- a/lib_test/test_ipaddr_b128.ml +++ b/lib_test/test_ipaddr_b128.ml @@ -16,35 +16,167 @@ *) open OUnit +module B128 = Ipaddr_internal.B128 + +(* copied from test_ipaddr.ml *) +let assert_raises ~msg exn test_fn = + assert_raises ~msg exn (fun () -> + try test_fn () + with rtexn -> + if exn <> rtexn then ( + Printf.eprintf "Stacktrace for '%s':\n%!" msg; + Printexc.print_backtrace stderr); + raise rtexn) + +let assert_equal = assert_equal ~printer:Ipaddr_internal.B128.to_string + +let test_addition () = + (* simple addition *) + let d1 = B128.zero () in + let d2 = B128.of_string_exn "00000000000000000000000000000001" in + assert_equal ~msg:"adding one to zero is one" d2 (B128.add_exn d1 d2); + + (* addition carry *) + let d1 = B128.of_string_exn "000000000000000000ff000000000000" in + let d2 = B128.of_string_exn "00000000000000000001000000000000" in + let d3 = B128.of_string_exn "00000000000000000100000000000000" in + assert_equal ~msg:"test addition carry over" d3 (B128.add_exn d1 d2); + + (* adding one to max_int overflows *) + let d1 = B128.max_int () in + let d2 = B128.of_string_exn "00000000000000000000000000000001" in + assert_raises ~msg:"adding one to max_int overflows" B128.Overflow (fun () -> + B128.add_exn d1 d2) + +let test_subtraction () = + (* simple subtraction *) + let d1 = B128.of_string_exn "00000000000000000000000000000001" in + let d2 = B128.of_string_exn "00000000000000000000000000000001" in + let d3 = B128.zero () in + assert_equal ~msg:"subtracting one from one is zero" d3 (B128.sub_exn d1 d2); + + (* subtract carry *) + let d1 = B128.of_string_exn "00000000000000000000000000000300" in + let d2 = B128.of_string_exn "0000000000000000000000000000002a" in + let d3 = B128.of_string_exn "000000000000000000000000000002d6" in + assert_equal ~msg:"test subtraction carry over" d3 (B128.sub_exn d1 d2); + + (* subtracting one from min_int overflows *) + let d1 = B128.min_int () in + let d2 = B128.of_string_exn "00000000000000000000000000000001" in + assert_raises ~msg:"subtracting one from min_int overflows" B128.Overflow + (fun () -> B128.sub_exn d1 d2) + +let test_of_to_string () = + let s = "ff000000000000004200000000000001" in + OUnit.assert_equal ~msg:"input of of_string is equal to output of to_string" s + (B128.of_string_exn s |> B128.to_string) + +let test_lognot () = + let d1 = B128.of_string_exn "00000000000000000000000000000001" in + let d2 = B128.of_string_exn "fffffffffffffffffffffffffffffffe" in + assert_equal ~msg:"lognot inverts bits" d2 (B128.lognot d1) + +let test_shift_left () = + (* bit shift count, input, expected output *) + let test_shifts = + [ + (1, "f0000000000000000000000000000000", "e0000000000000000000000000000000"); + (1, "0000000000000000000000000000000f", "0000000000000000000000000000001e"); + (1, "00000000000000000000000000000001", "00000000000000000000000000000002"); + (2, "f0000000000000000000000000000000", "c0000000000000000000000000000000"); + (2, "0000000000000000000000000000ffff", "0000000000000000000000000003fffc"); + (8, "00000000000000000000000000000100", "00000000000000000000000000010000"); + (9, "f0000000000000000000000000000000", "00000000000000000000000000000000"); + ( 64, + "00000000000000000000000000000001", + "00000000000000010000000000000000" ); + ( 127, + "00000000000000000000000000000001", + "80000000000000000000000000000000" ); + ( 128, + "00000000000000000000000000000001", + "00000000000000000000000000000000" ); + ] + in + List.iter + (fun (bits, input_value, expected_output) -> + assert_equal + ~msg:(Printf.sprintf "shift left by %i" bits) + (B128.of_string_exn expected_output) + (B128.shift_left (B128.of_string_exn input_value) bits)) + test_shifts let test_shift_right () = - let open Ipaddr_internal in - let open V6 in - let printer = function - | Ok v -> Printf.sprintf "Ok %s" (to_string v) - | Error (`Msg e) -> Printf.sprintf "Error `Msg \"%s\"" e + (* (bit shift count, input, expected output) *) + let test_shifts = + [ + (1, "f0000000000000000000000000000000", "78000000000000000000000000000000"); + (2, "f0000000000000000000000000000000", "3c000000000000000000000000000000"); + (2, "0000000000000000000000000000ffff", "00000000000000000000000000003fff"); + (2, "000000000000000000000000000ffff0", "0000000000000000000000000003fffc"); + (8, "00000000000000000000000000000100", "00000000000000000000000000000001"); + (9, "f0000000000000000000000000000000", "00780000000000000000000000000000"); + ( 32, + "000000000000000000000000ffffffff", + "00000000000000000000000000000000" ); + ( 32, + "0000000000000000aaaabbbbffffffff", + "000000000000000000000000aaaabbbb" ); + ( 40, + "0000000000000000aaaabbbbffffffff", + "00000000000000000000000000aaaabb" ); + ( 64, + "01000000000000000000000000000000", + "00000000000000000100000000000000" ); + ( 120, + "aaaabbbbccccdddd0000000000000000", + "000000000000000000000000000000aa" ); + ( 127, + "80000000000000000000000000000000", + "00000000000000000000000000000001" ); + ( 128, + "ffff0000000000000000000000000000", + "00000000000000000000000000000000" ); + ] in - let assert_equal = assert_equal ~printer in - assert_equal ~msg:":: >> 32" (of_string "::") - (B128.shift_right (of_string_exn "::ffff:ffff") 32); - assert_equal ~msg:"::aaaa:bbbb:ffff:ffff >> 32" (of_string "::aaaa:bbbb") - (B128.shift_right (of_string_exn "::aaaa:bbbb:ffff:ffff") 32); - assert_equal ~msg:"::aaaa:bbbb:ffff:ffff >> 40" (of_string "::aa:aabb") - (B128.shift_right (of_string_exn "::aaaa:bbbb:ffff:ffff") 40); - assert_equal ~msg:"::ffff >> 2" (of_string "::3fff") - (B128.shift_right (of_string_exn "::ffff") 2); - assert_equal ~msg:"ffff:: >> 128" (of_string "::") - (B128.shift_right (of_string_exn "ffff::") 128); - assert_equal ~msg:"aaaa:bbbb:cccc:dddd:: >> 120" (of_string "::aa") - (B128.shift_right (of_string_exn "aaaa:bbbb:cccc:dddd::") 120); - assert_equal ~msg:"ffff:: >> 140" - (Error (`Msg "Ipaddr: unexpected argument sz (must be >= 0 and < 128)")) - (B128.shift_right (of_string_exn "ffff::") 140); - assert_equal ~msg:"::ffff:ffff >> -8" - (Error (`Msg "Ipaddr: unexpected argument sz (must be >= 0 and < 128)")) - (B128.shift_right (of_string_exn "::ffff:ffff") (-8)) - -let suite = "Test B128 module" >::: [ "shift_right" >:: test_shift_right ];; + List.iter + (fun (bits, input_value, expected_output) -> + assert_equal + ~msg:(Printf.sprintf "shift right by %i" bits) + (B128.of_string_exn expected_output) + (B128.shift_right (B128.of_string_exn input_value) bits)) + test_shifts + +let test_byte_module () = + let assert_equal = OUnit2.assert_equal ~printer:(Printf.sprintf "0x%x") in + assert_equal ~msg:"get 3 lsb" 0x00 (B128.Byte.get_lsbits 3 0x00); + assert_equal ~msg:"get 4 lsb" 0x0f (B128.Byte.get_lsbits 4 0xff); + assert_equal ~msg:"get 5 lsb" 0x10 (B128.Byte.get_lsbits 5 0x10); + assert_equal ~msg:"get 8 lsb" 0xff (B128.Byte.get_lsbits 8 0xff); + + assert_equal ~msg:"get 3 msb" 0x0 (B128.Byte.get_msbits 3 0x00); + assert_equal ~msg:"get 4 msb" 0xf (B128.Byte.get_msbits 4 0xff); + assert_equal ~msg:"get 5 msb" 0x2 (B128.Byte.get_msbits 5 0x10); + assert_equal ~msg:"get 8 msb" 0xff (B128.Byte.get_msbits 8 0xff); + + assert_equal ~msg:"set 3 msb" 0x20 (B128.Byte.set_msbits 3 0x1 0x00); + assert_equal ~msg:"set 4 msb" 0xa0 (B128.Byte.set_msbits 4 0xa 0x00); + assert_equal ~msg:"set 5 msb" 0x98 (B128.Byte.set_msbits 5 0x13 0x00); + assert_equal ~msg:"set 8 msb" 0xff (B128.Byte.set_msbits 8 0xff 0x00) + +let suite = + "Test B128 module" + >::: [ + "addition" >:: test_addition; + "subtraction" >:: test_subtraction; + "of_to_string" >:: test_of_to_string; + "lognot" >:: test_lognot; + "shift_left" >:: test_shift_left; + "shift_right" >:: test_shift_right; + "byte_module" >:: test_byte_module; + ] +;; let _results = run_test_tt_main suite in () From c60bae53a290e33b99a67b307188bb85b65f0ca6 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Sun, 12 Mar 2023 21:41:31 +0100 Subject: [PATCH 2/6] remove superfluous code --- lib/ipaddr.ml | 39 +----------------------------------- lib_test/test_ipaddr_b128.ml | 4 ++-- 2 files changed, 3 insertions(+), 40 deletions(-) diff --git a/lib/ipaddr.ml b/lib/ipaddr.ml index 6a3660c..91798ab 100644 --- a/lib/ipaddr.ml +++ b/lib/ipaddr.ml @@ -108,27 +108,6 @@ let reject_octal s i = if s.[!i] == '0' && is_number 10 (int_of_char s.[!i + 1]) then raise (octal_notation s) -let hex_char_of_int = function - | 0 -> '0' - | 1 -> '1' - | 2 -> '2' - | 3 -> '3' - | 4 -> '4' - | 5 -> '5' - | 6 -> '6' - | 7 -> '7' - | 8 -> '8' - | 9 -> '9' - | 10 -> 'a' - | 11 -> 'b' - | 12 -> 'c' - | 13 -> 'd' - | 14 -> 'e' - | 15 -> 'f' - | _ -> raise (Invalid_argument "not a hex int") - -let hex_string_of_int32 i = String.make 1 (hex_char_of_int (Int32.to_int i)) - module V4 = struct type t = int32 @@ -454,9 +433,7 @@ module B128 = struct type t = Bytes.t let zero () = Bytes.make 16 '\x00' - let min_int () = zero () let max_int () = Bytes.make 16 '\xff' - let equal = Bytes.equal let compare = Bytes.compare let fold_left f a b = @@ -490,16 +467,12 @@ module B128 = struct done; b - let of_string s = try Some (of_string_exn s) with Invalid_argument _ -> None - let to_string b = let l = ref [] in for i = 15 downto 0 do l := Printf.sprintf "%.2x" (Bytes.get_uint8 b i) :: !l done; - String.concat "" !l - - let pp ppf b = Format.fprintf ppf "%s" (to_string b) + String.concat "" !l[@@ocaml.warning "-32"] (* used in the tests *) let of_int64 (a, b) = let b' = zero () in @@ -560,8 +533,6 @@ module B128 = struct x y; if !carry <> 0 then raise Overflow else b - let add x y = try Some (add_exn x y) with Overflow -> None - let sub_exn x y = if Bytes.compare x y = -1 then raise Overflow else @@ -578,9 +549,6 @@ module B128 = struct x y; if !carry <> 0 then raise Overflow else b - let sub x y = - try Some (sub_exn x y) with Overflow -> None | Invalid_argument _ -> None - let logand x y = let b = zero () in iteri_right2 (fun i x y -> Bytes.set_uint8 b i (x land y)) x y; @@ -591,11 +559,6 @@ module B128 = struct iteri_right2 (fun i x y -> Bytes.set_uint8 b i (x lor y)) x y; b - let logxor x y = - let b = zero () in - iteri_right2 (fun i x y -> Bytes.set_uint8 b i (x lxor y)) x y; - b - let lognot x = let b = zero () in Bytes.iteri (fun i _ -> Bytes.set_uint8 b i (lnot (Bytes.get_uint8 x i))) x; diff --git a/lib_test/test_ipaddr_b128.ml b/lib_test/test_ipaddr_b128.ml index 153b438..425e54d 100644 --- a/lib_test/test_ipaddr_b128.ml +++ b/lib_test/test_ipaddr_b128.ml @@ -61,8 +61,8 @@ let test_subtraction () = let d3 = B128.of_string_exn "000000000000000000000000000002d6" in assert_equal ~msg:"test subtraction carry over" d3 (B128.sub_exn d1 d2); - (* subtracting one from min_int overflows *) - let d1 = B128.min_int () in + (* subtracting one from zero overflows *) + let d1 = B128.zero () in let d2 = B128.of_string_exn "00000000000000000000000000000001" in assert_raises ~msg:"subtracting one from min_int overflows" B128.Overflow (fun () -> B128.sub_exn d1 d2) From 4bf4aa518bb7001bc97e0c790955251e83cac996 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Mon, 13 Mar 2023 11:01:49 +0100 Subject: [PATCH 3/6] String.of_bytes is only available since 4.13, use Bytes.to_string instead --- lib/ipaddr.ml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/ipaddr.ml b/lib/ipaddr.ml index 91798ab..1e2f1b3 100644 --- a/lib/ipaddr.ml +++ b/lib/ipaddr.ml @@ -639,7 +639,7 @@ module B128 = struct if Bytes.length b' + off > Bytes.length byte then raise (Parse_error - ("larger including offset than target bytes", String.of_bytes b')) + ("larger including offset than target bytes", Bytes.to_string b')) else Bytes.blit b' 0 byte off (Bytes.length b') let succ b = From a199bd75b63f721d7d9c22b99819e355b9c4fb84 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Mon, 13 Mar 2023 11:04:19 +0100 Subject: [PATCH 4/6] auto-format, and upgrade to ocamlformat 0.25.1 --- .ocamlformat | 2 +- lib/ipaddr.ml | 44 +++++++++++++++++++++++--------------------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/.ocamlformat b/.ocamlformat index 3bcdafe..e767a63 100644 --- a/.ocamlformat +++ b/.ocamlformat @@ -1,4 +1,4 @@ -version = 0.20.0 +version = 0.25.1 profile = conventional break-infix = fit-or-vertical parse-docstrings = true diff --git a/lib/ipaddr.ml b/lib/ipaddr.ml index 1e2f1b3..84afa37 100644 --- a/lib/ipaddr.ml +++ b/lib/ipaddr.ml @@ -472,7 +472,9 @@ module B128 = struct for i = 15 downto 0 do l := Printf.sprintf "%.2x" (Bytes.get_uint8 b i) :: !l done; - String.concat "" !l[@@ocaml.warning "-32"] (* used in the tests *) + String.concat "" !l + [@@ocaml.warning "-32"] + (* used in the tests *) let of_int64 (a, b) = let b' = zero () in @@ -601,16 +603,16 @@ module B128 = struct let b = zero () in let shift_bytes, shift_bits = (n / 8, n mod 8) in (if shift_bits = 0 then Bytes.blit x 0 b shift_bytes (16 - shift_bytes) - else - let carry = ref 0 in - for i = 0 to 15 - shift_bytes do - let x' = Bytes.get_uint8 x i in - let new_carry = Byte.get_lsbits shift_bits x' in - let shifted_value = x' lsr shift_bits in - let new_value = Byte.set_msbits shift_bits !carry shifted_value in - Bytes.set_uint8 b (i + shift_bytes) new_value; - carry := new_carry - done); + else + let carry = ref 0 in + for i = 0 to 15 - shift_bytes do + let x' = Bytes.get_uint8 x i in + let new_carry = Byte.get_lsbits shift_bits x' in + let shifted_value = x' lsr shift_bits in + let new_value = Byte.set_msbits shift_bits !carry shifted_value in + Bytes.set_uint8 b (i + shift_bytes) new_value; + carry := new_carry + done); b | _ -> raise (Invalid_argument "n must be >= 0 && <= 128") @@ -622,16 +624,16 @@ module B128 = struct let b = zero () in let shift_bytes, shift_bits = (n / 8, n mod 8) in (if shift_bits = 0 then Bytes.blit x shift_bytes b 0 (16 - shift_bytes) - else - let carry = ref 0 in - for i = 15 downto 0 + shift_bytes do - let x' = Bytes.get_uint8 x i in - let new_carry = Byte.get_msbits shift_bits x' in - let shifted_value = x' lsl shift_bits in - let new_value = shifted_value lor !carry in - Bytes.set_uint8 b (i - shift_bytes) new_value; - carry := new_carry - done); + else + let carry = ref 0 in + for i = 15 downto 0 + shift_bytes do + let x' = Bytes.get_uint8 x i in + let new_carry = Byte.get_msbits shift_bits x' in + let shifted_value = x' lsl shift_bits in + let new_value = shifted_value lor !carry in + Bytes.set_uint8 b (i - shift_bytes) new_value; + carry := new_carry + done); b | _ -> raise (Invalid_argument "n must be >= 0 && <= 128") From 5c4a09b0282cb68852bfa78cc4787583b8723ee0 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Mon, 13 Mar 2023 11:19:29 +0100 Subject: [PATCH 5/6] add test case for #113 --- lib_test/test_ipaddr.ml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/lib_test/test_ipaddr.ml b/lib_test/test_ipaddr.ml index b87a77f..a6f4e6b 100644 --- a/lib_test/test_ipaddr.ml +++ b/lib_test/test_ipaddr.ml @@ -633,6 +633,15 @@ module Test_v6 = struct V6.(to_int32 (of_int32 addr)) addr + let test_int64_rt () = + let ((a, b) as addr) = + 0x2a01_04f9_c011_87adL, 0x0_0_0_0L + in + assert_equal + ~msg:(Printf.sprintf "%016Lx %016Lx" a b) + V6.(to_int64 (of_int64 addr)) + addr + let test_prefix_string_rt () = let subnets = [ @@ -917,6 +926,7 @@ module Test_v6 = struct "cstruct_rt" >:: test_cstruct_rt; "cstruct_rt_bad" >:: test_cstruct_rt_bad; "int32_rt" >:: test_int32_rt; + "int64_rt" >:: test_int64_rt; "prefix_string_rt" >:: test_prefix_string_rt; "prefix_string_rt_bad" >:: test_prefix_string_rt_bad; "network_address_rt" >:: test_network_address_rt; From 4b17291cb21c05d557e3851ec70386e955474186 Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Mon, 13 Mar 2023 12:54:27 +0100 Subject: [PATCH 6/6] add test for V6.to_int64/of_int64 for 0:0:8000:: --- lib_test/test_ipaddr.ml | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/lib_test/test_ipaddr.ml b/lib_test/test_ipaddr.ml index a6f4e6b..0d980bf 100644 --- a/lib_test/test_ipaddr.ml +++ b/lib_test/test_ipaddr.ml @@ -634,13 +634,19 @@ module Test_v6 = struct addr let test_int64_rt () = - let ((a, b) as addr) = - 0x2a01_04f9_c011_87adL, 0x0_0_0_0L + let tests = + [ + (0x2a01_04f9_c011_87adL, 0x0_0_0_0L); + (0x0000_0000_8000_0000L, 0x0_0_0_0L); + ] in - assert_equal - ~msg:(Printf.sprintf "%016Lx %016Lx" a b) - V6.(to_int64 (of_int64 addr)) - addr + List.iter + (fun ((a, b) as addr) -> + assert_equal + ~msg:(Printf.sprintf "%016Lx %016Lx" a b) + V6.(to_int64 (of_int64 addr)) + addr) + tests let test_prefix_string_rt () = let subnets =