diff --git a/src/sm2/ecc.rs b/src/sm2/ecc.rs index 952ef82..6a2e355 100644 --- a/src/sm2/ecc.rs +++ b/src/sm2/ecc.rs @@ -28,7 +28,7 @@ pub struct EccCtx { inv2: FieldElem, } -#[derive(Clone)] +#[derive(Clone, PartialEq, Eq)] pub struct Point { pub x: FieldElem, pub y: FieldElem, @@ -176,7 +176,7 @@ impl EccCtx { rc } - pub fn new_point(&self, x: &FieldElem, y: &FieldElem) -> Result { + pub fn new_point(&self, x: &FieldElem, y: &FieldElem) -> Result { let ctx = &self.fctx; // Check if (x, y) is a valid point on the curve(affine projection) @@ -188,7 +188,7 @@ impl EccCtx { let rhs = ctx.add(&self.b, &ctx.add(&x_cubic, &ax)); if !lhs.eq(&rhs) { - return Err(String::from("invalid point")); + return Err(Sm2Error::NotOnCurve); } let p = Point { @@ -263,7 +263,7 @@ impl EccCtx { match self.new_point(&x, &y) { Ok(p) => p, - Err(m) => panic!(m), + Err(m) => panic!("{:?}", m), } } @@ -291,7 +291,7 @@ impl EccCtx { let neg_y = self.fctx.neg(&p.y); match self.new_jacobian(&p.x, &neg_y, &p.z) { Ok(neg_p) => neg_p, - Err(e) => panic!(e), + Err(e) => panic!("{}", e), } } @@ -304,9 +304,9 @@ impl EccCtx { let ctx = &self.fctx; - //if self.eq(&p1, &p2) { - // return self.double(p1); - //} + if p1 == p2 { + return self.double(p1); + } let lam1 = ctx.mul(&p1.x, &ctx.square(&p2.z)); let lam2 = ctx.mul(&p2.x, &ctx.square(&p1.z)); @@ -485,8 +485,7 @@ impl EccCtx { ret } - #[allow(clippy::result_unit_err)] - pub fn bytes_to_point(&self, b: &[u8]) -> Result { + pub fn bytes_to_point(&self, b: &[u8]) -> Result { let ctx = &self.fctx; if b.len() == 33 { @@ -496,7 +495,7 @@ impl EccCtx { } else if b[0] == 0x03 { y_q = 1 } else { - return Err(()); + return Err(Sm2Error::InvalidPublic); } let x = FieldElem::from_bytes(&b[1..]); @@ -506,26 +505,22 @@ impl EccCtx { let y_2 = ctx.add(&self.b, &ctx.add(&x_cubic, &ax)); let mut y = self.fctx.sqrt(&y_2)?; + if y.get_value(7) & 0x01 != y_q { y = self.fctx.neg(&y); } - match self.new_point(&x, &y) { - Ok(p) => Ok(p), - Err(_) => Err(()), - } + self.new_point(&x, &y) } else if b.len() == 65 { if b[0] != 0x04 { - return Err(()); + return Err(Sm2Error::InvalidPublic); } let x = FieldElem::from_bytes(&b[1..33]); let y = FieldElem::from_bytes(&b[33..65]); - match self.new_point(&x, &y) { - Ok(p) => Ok(p), - Err(_) => Err(()), - } + + self.new_point(&x, &y) } else { - Err(()) + Err(Sm2Error::InvalidPublic) } } } @@ -542,6 +537,7 @@ impl Point { } } +use sm2::error::Sm2Error; use std::fmt; impl fmt::Display for Point { @@ -572,6 +568,10 @@ mod tests { assert!(curve.eq(&g, &new_g)); assert!(zero.is_zero()); + + let double_g = curve.double(&g); // 2 * g + let add_g = curve.add(&g, &g); // g + g + assert!(curve.eq(&add_g, &double_g)); } #[test] diff --git a/src/sm2/error.rs b/src/sm2/error.rs new file mode 100644 index 0000000..0ce72b1 --- /dev/null +++ b/src/sm2/error.rs @@ -0,0 +1,41 @@ +// Copyright 2018 Cryptape Technology LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Formatter; + +pub enum Sm2Error { + NotOnCurve, + FieldSqrtError, + InvalidDer, + InvalidPublic, + InvalidPrivate, +} + +impl ::std::fmt::Debug for Sm2Error { + fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl From for &str { + fn from(e: Sm2Error) -> Self { + match e { + Sm2Error::NotOnCurve => "the point not on curve", + Sm2Error::FieldSqrtError => "field elem sqrt error", + Sm2Error::InvalidDer => "invalid der", + Sm2Error::InvalidPublic => "invalid public key", + Sm2Error::InvalidPrivate => "invalid private key", + } + } +} diff --git a/src/sm2/field.rs b/src/sm2/field.rs index 05e5bea..3176491 100644 --- a/src/sm2/field.rs +++ b/src/sm2/field.rs @@ -17,6 +17,7 @@ use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; use num_bigint::BigUint; use num_traits::Num; +use sm2::error::Sm2Error; use std::io::Cursor; pub struct FieldCtx { @@ -239,8 +240,7 @@ impl FieldCtx { } // Square root of a field element - #[allow(clippy::result_unit_err)] - pub fn sqrt(&self, g: &FieldElem) -> Result { + pub fn sqrt(&self, g: &FieldElem) -> Result { // p = 4 * u + 3 // u = u + 1 let u = BigUint::from_str_radix( @@ -253,7 +253,7 @@ impl FieldCtx { if self.square(&y) == *g { Ok(y) } else { - Err(()) + Err(Sm2Error::FieldSqrtError) } } } diff --git a/src/sm2/mod.rs b/src/sm2/mod.rs index 7676e27..0b8b393 100644 --- a/src/sm2/mod.rs +++ b/src/sm2/mod.rs @@ -13,5 +13,6 @@ // limitations under the License. pub mod ecc; +mod error; pub mod field; pub mod signature; diff --git a/src/sm2/signature.rs b/src/sm2/signature.rs index aa3fbe0..d6a919d 100644 --- a/src/sm2/signature.rs +++ b/src/sm2/signature.rs @@ -21,6 +21,7 @@ use sm3::hash::Sm3Hash; use yasna; use byteorder::{BigEndian, WriteBytesExt}; +use sm2::error::Sm2Error; pub type Pubkey = Point; pub type Seckey = BigUint; @@ -48,24 +49,23 @@ impl Signature { Ok(Signature { r, s }) } - #[allow(clippy::result_unit_err)] - pub fn der_decode_raw(buf: &[u8]) -> Result { + pub fn der_decode_raw(buf: &[u8]) -> Result { if buf[0] != 0x02 { - return Err(()); + return Err(Sm2Error::InvalidDer); } let r_len: usize = buf[1] as usize; if buf.len() <= r_len + 4 { - return Err(()); + return Err(Sm2Error::InvalidDer); } let r = BigUint::from_bytes_be(&buf[2..2 + r_len]); let buf = &buf[2 + r_len..]; if buf[0] != 0x02 { - return Err(()); + return Err(Sm2Error::InvalidDer); } let s_len: usize = buf[1] as usize; if buf.len() < s_len + 2 { - return Err(()); + return Err(Sm2Error::InvalidDer); } let s = BigUint::from_bytes_be(&buf[2..2 + s_len]); @@ -142,7 +142,7 @@ impl SigCtx { let mut prepended_msg: Vec = Vec::new(); prepended_msg.extend_from_slice(&z_a[..]); - prepended_msg.extend_from_slice(&msg[..]); + prepended_msg.extend_from_slice(msg); let mut hasher = Sm3Hash::new(&prepended_msg[..]); hasher.get_hash() @@ -187,7 +187,7 @@ impl SigCtx { let mut prepended_msg: Vec = Vec::new(); prepended_msg.extend_from_slice(&z_a[..]); - prepended_msg.extend_from_slice(&msg[..]); + prepended_msg.extend_from_slice(msg); prepended_msg } @@ -304,8 +304,7 @@ impl SigCtx { curve.mul(&sk, &curve.generator()) } - #[allow(clippy::result_unit_err)] - pub fn load_pubkey(&self, buf: &[u8]) -> Result { + pub fn load_pubkey(&self, buf: &[u8]) -> Result { self.curve.bytes_to_point(buf) } @@ -313,14 +312,13 @@ impl SigCtx { self.curve.point_to_bytes(p, compress) } - #[allow(clippy::result_unit_err)] - pub fn load_seckey(&self, buf: &[u8]) -> Result { + pub fn load_seckey(&self, buf: &[u8]) -> Result { if buf.len() != 32 { - return Err(()); + return Err(Sm2Error::InvalidPrivate); } let sk = BigUint::from_bytes_be(buf); if sk > *self.curve.get_n() { - Err(()) + Err(Sm2Error::InvalidPrivate) } else { Ok(sk) } diff --git a/src/sm4/cipher_mode.rs b/src/sm4/cipher_mode.rs index 4fbabc8..a8ad1e5 100644 --- a/src/sm4/cipher_mode.rs +++ b/src/sm4/cipher_mode.rs @@ -21,7 +21,7 @@ pub enum CipherMode { Cbc, } -pub struct SM4CipherMode { +pub struct Sm4CipherMode { cipher: Sm4Cipher, mode: CipherMode, } @@ -47,10 +47,10 @@ fn block_add_one(a: &mut [u8]) { } } -impl SM4CipherMode { - pub fn new(key: &[u8], mode: CipherMode) -> SM4CipherMode { +impl Sm4CipherMode { + pub fn new(key: &[u8], mode: CipherMode) -> Sm4CipherMode { let cipher = Sm4Cipher::new(key); - SM4CipherMode { cipher, mode } + Sm4CipherMode { cipher, mode } } pub fn encrypt(&self, data: &[u8], iv: &[u8]) -> Vec { @@ -286,7 +286,7 @@ mod tests { let key = rand_block(); let iv = rand_block(); - let cmode = SM4CipherMode::new(&key, mode); + let cmode = Sm4CipherMode::new(&key, mode); let pt = rand_data(10); let ct = cmode.encrypt(&pt[..], &iv); @@ -309,7 +309,7 @@ mod tests { let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap(); let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap(); - let cipher_mode = SM4CipherMode::new(&key, CipherMode::Ctr); + let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Ctr); let msg = b"hello world, this file is used for smx test\n"; let lhs = cipher_mode.encrypt(msg, &iv); let lhs: &[u8] = lhs.as_ref(); @@ -323,7 +323,7 @@ mod tests { let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap(); let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap(); - let cipher_mode = SM4CipherMode::new(&key, CipherMode::Cfb); + let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Cfb); let msg = b"hello world, this file is used for smx test\n"; let lhs = cipher_mode.encrypt(msg, &iv); let lhs: &[u8] = lhs.as_ref(); @@ -337,7 +337,7 @@ mod tests { let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap(); let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap(); - let cipher_mode = SM4CipherMode::new(&key, CipherMode::Ofb); + let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Ofb); let msg = b"hello world, this file is used for smx test\n"; let lhs = cipher_mode.encrypt(msg, &iv); let lhs: &[u8] = lhs.as_ref(); @@ -351,7 +351,7 @@ mod tests { let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap(); let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap(); - let cipher_mode = SM4CipherMode::new(&key, CipherMode::Cbc); + let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Cbc); let msg = b"hello world, this file is used for smx test\n"; let lhs = cipher_mode.encrypt(msg, &iv); let lhs: &[u8] = lhs.as_ref(); diff --git a/src/sm4/mod.rs b/src/sm4/mod.rs index 27ad94a..195ef9a 100644 --- a/src/sm4/mod.rs +++ b/src/sm4/mod.rs @@ -16,4 +16,4 @@ pub mod cipher; pub mod cipher_mode; pub type Mode = self::cipher_mode::CipherMode; -pub type Cipher = self::cipher_mode::SM4CipherMode; +pub type Cipher = self::cipher_mode::Sm4CipherMode;