Skip to content

Commit

Permalink
Merge pull request #24 from citahub/fix_point_add
Browse files Browse the repository at this point in the history
fix: same affine point add
  • Loading branch information
rink1969 authored May 8, 2021
2 parents 6656219 + eb62fa3 commit fafd9f2
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 48 deletions.
42 changes: 21 additions & 21 deletions src/sm2/ecc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ pub struct EccCtx {
inv2: FieldElem,
}

#[derive(Clone)]
#[derive(Clone, PartialEq, Eq)]
pub struct Point {
pub x: FieldElem,
pub y: FieldElem,
Expand Down Expand Up @@ -176,7 +176,7 @@ impl EccCtx {
rc
}

pub fn new_point(&self, x: &FieldElem, y: &FieldElem) -> Result<Point, String> {
pub fn new_point(&self, x: &FieldElem, y: &FieldElem) -> Result<Point, Sm2Error> {
let ctx = &self.fctx;

// Check if (x, y) is a valid point on the curve(affine projection)
Expand All @@ -188,7 +188,7 @@ impl EccCtx {
let rhs = ctx.add(&self.b, &ctx.add(&x_cubic, &ax));

if !lhs.eq(&rhs) {
return Err(String::from("invalid point"));
return Err(Sm2Error::NotOnCurve);
}

let p = Point {
Expand Down Expand Up @@ -263,7 +263,7 @@ impl EccCtx {

match self.new_point(&x, &y) {
Ok(p) => p,
Err(m) => panic!(m),
Err(m) => panic!("{:?}", m),
}
}

Expand Down Expand Up @@ -291,7 +291,7 @@ impl EccCtx {
let neg_y = self.fctx.neg(&p.y);
match self.new_jacobian(&p.x, &neg_y, &p.z) {
Ok(neg_p) => neg_p,
Err(e) => panic!(e),
Err(e) => panic!("{}", e),
}
}

Expand All @@ -304,9 +304,9 @@ impl EccCtx {

let ctx = &self.fctx;

//if self.eq(&p1, &p2) {
// return self.double(p1);
//}
if p1 == p2 {
return self.double(p1);
}

let lam1 = ctx.mul(&p1.x, &ctx.square(&p2.z));
let lam2 = ctx.mul(&p2.x, &ctx.square(&p1.z));
Expand Down Expand Up @@ -485,8 +485,7 @@ impl EccCtx {
ret
}

#[allow(clippy::result_unit_err)]
pub fn bytes_to_point(&self, b: &[u8]) -> Result<Point, ()> {
pub fn bytes_to_point(&self, b: &[u8]) -> Result<Point, Sm2Error> {
let ctx = &self.fctx;

if b.len() == 33 {
Expand All @@ -496,7 +495,7 @@ impl EccCtx {
} else if b[0] == 0x03 {
y_q = 1
} else {
return Err(());
return Err(Sm2Error::InvalidPublic);
}

let x = FieldElem::from_bytes(&b[1..]);
Expand All @@ -506,26 +505,22 @@ impl EccCtx {
let y_2 = ctx.add(&self.b, &ctx.add(&x_cubic, &ax));

let mut y = self.fctx.sqrt(&y_2)?;

if y.get_value(7) & 0x01 != y_q {
y = self.fctx.neg(&y);
}

match self.new_point(&x, &y) {
Ok(p) => Ok(p),
Err(_) => Err(()),
}
self.new_point(&x, &y)
} else if b.len() == 65 {
if b[0] != 0x04 {
return Err(());
return Err(Sm2Error::InvalidPublic);
}
let x = FieldElem::from_bytes(&b[1..33]);
let y = FieldElem::from_bytes(&b[33..65]);
match self.new_point(&x, &y) {
Ok(p) => Ok(p),
Err(_) => Err(()),
}

self.new_point(&x, &y)
} else {
Err(())
Err(Sm2Error::InvalidPublic)
}
}
}
Expand All @@ -542,6 +537,7 @@ impl Point {
}
}

use sm2::error::Sm2Error;
use std::fmt;

impl fmt::Display for Point {
Expand Down Expand Up @@ -572,6 +568,10 @@ mod tests {

assert!(curve.eq(&g, &new_g));
assert!(zero.is_zero());

let double_g = curve.double(&g); // 2 * g
let add_g = curve.add(&g, &g); // g + g
assert!(curve.eq(&add_g, &double_g));
}

#[test]
Expand Down
41 changes: 41 additions & 0 deletions src/sm2/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2018 Cryptape Technology LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Formatter;

pub enum Sm2Error {
NotOnCurve,
FieldSqrtError,
InvalidDer,
InvalidPublic,
InvalidPrivate,
}

impl ::std::fmt::Debug for Sm2Error {
fn fmt(&self, f: &mut Formatter<'_>) -> ::std::fmt::Result {
write!(f, "{:?}", self)
}
}

impl From<Sm2Error> for &str {
fn from(e: Sm2Error) -> Self {
match e {
Sm2Error::NotOnCurve => "the point not on curve",
Sm2Error::FieldSqrtError => "field elem sqrt error",
Sm2Error::InvalidDer => "invalid der",
Sm2Error::InvalidPublic => "invalid public key",
Sm2Error::InvalidPrivate => "invalid private key",
}
}
}
6 changes: 3 additions & 3 deletions src/sm2/field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use num_bigint::BigUint;
use num_traits::Num;
use sm2::error::Sm2Error;
use std::io::Cursor;

pub struct FieldCtx {
Expand Down Expand Up @@ -239,8 +240,7 @@ impl FieldCtx {
}

// Square root of a field element
#[allow(clippy::result_unit_err)]
pub fn sqrt(&self, g: &FieldElem) -> Result<FieldElem, ()> {
pub fn sqrt(&self, g: &FieldElem) -> Result<FieldElem, Sm2Error> {
// p = 4 * u + 3
// u = u + 1
let u = BigUint::from_str_radix(
Expand All @@ -253,7 +253,7 @@ impl FieldCtx {
if self.square(&y) == *g {
Ok(y)
} else {
Err(())
Err(Sm2Error::FieldSqrtError)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/sm2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@
// limitations under the License.

pub mod ecc;
mod error;
pub mod field;
pub mod signature;
26 changes: 12 additions & 14 deletions src/sm2/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use sm3::hash::Sm3Hash;
use yasna;

use byteorder::{BigEndian, WriteBytesExt};
use sm2::error::Sm2Error;

pub type Pubkey = Point;
pub type Seckey = BigUint;
Expand Down Expand Up @@ -48,24 +49,23 @@ impl Signature {
Ok(Signature { r, s })
}

#[allow(clippy::result_unit_err)]
pub fn der_decode_raw(buf: &[u8]) -> Result<Signature, ()> {
pub fn der_decode_raw(buf: &[u8]) -> Result<Signature, Sm2Error> {
if buf[0] != 0x02 {
return Err(());
return Err(Sm2Error::InvalidDer);
}
let r_len: usize = buf[1] as usize;
if buf.len() <= r_len + 4 {
return Err(());
return Err(Sm2Error::InvalidDer);
}
let r = BigUint::from_bytes_be(&buf[2..2 + r_len]);

let buf = &buf[2 + r_len..];
if buf[0] != 0x02 {
return Err(());
return Err(Sm2Error::InvalidDer);
}
let s_len: usize = buf[1] as usize;
if buf.len() < s_len + 2 {
return Err(());
return Err(Sm2Error::InvalidDer);
}
let s = BigUint::from_bytes_be(&buf[2..2 + s_len]);

Expand Down Expand Up @@ -142,7 +142,7 @@ impl SigCtx {

let mut prepended_msg: Vec<u8> = Vec::new();
prepended_msg.extend_from_slice(&z_a[..]);
prepended_msg.extend_from_slice(&msg[..]);
prepended_msg.extend_from_slice(msg);

let mut hasher = Sm3Hash::new(&prepended_msg[..]);
hasher.get_hash()
Expand Down Expand Up @@ -187,7 +187,7 @@ impl SigCtx {

let mut prepended_msg: Vec<u8> = Vec::new();
prepended_msg.extend_from_slice(&z_a[..]);
prepended_msg.extend_from_slice(&msg[..]);
prepended_msg.extend_from_slice(msg);

prepended_msg
}
Expand Down Expand Up @@ -304,23 +304,21 @@ impl SigCtx {
curve.mul(&sk, &curve.generator())
}

#[allow(clippy::result_unit_err)]
pub fn load_pubkey(&self, buf: &[u8]) -> Result<Point, ()> {
pub fn load_pubkey(&self, buf: &[u8]) -> Result<Point, Sm2Error> {
self.curve.bytes_to_point(buf)
}

pub fn serialize_pubkey(&self, p: &Point, compress: bool) -> Vec<u8> {
self.curve.point_to_bytes(p, compress)
}

#[allow(clippy::result_unit_err)]
pub fn load_seckey(&self, buf: &[u8]) -> Result<BigUint, ()> {
pub fn load_seckey(&self, buf: &[u8]) -> Result<BigUint, Sm2Error> {
if buf.len() != 32 {
return Err(());
return Err(Sm2Error::InvalidPrivate);
}
let sk = BigUint::from_bytes_be(buf);
if sk > *self.curve.get_n() {
Err(())
Err(Sm2Error::InvalidPrivate)
} else {
Ok(sk)
}
Expand Down
18 changes: 9 additions & 9 deletions src/sm4/cipher_mode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub enum CipherMode {
Cbc,
}

pub struct SM4CipherMode {
pub struct Sm4CipherMode {
cipher: Sm4Cipher,
mode: CipherMode,
}
Expand All @@ -47,10 +47,10 @@ fn block_add_one(a: &mut [u8]) {
}
}

impl SM4CipherMode {
pub fn new(key: &[u8], mode: CipherMode) -> SM4CipherMode {
impl Sm4CipherMode {
pub fn new(key: &[u8], mode: CipherMode) -> Sm4CipherMode {
let cipher = Sm4Cipher::new(key);
SM4CipherMode { cipher, mode }
Sm4CipherMode { cipher, mode }
}

pub fn encrypt(&self, data: &[u8], iv: &[u8]) -> Vec<u8> {
Expand Down Expand Up @@ -286,7 +286,7 @@ mod tests {
let key = rand_block();
let iv = rand_block();

let cmode = SM4CipherMode::new(&key, mode);
let cmode = Sm4CipherMode::new(&key, mode);

let pt = rand_data(10);
let ct = cmode.encrypt(&pt[..], &iv);
Expand All @@ -309,7 +309,7 @@ mod tests {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();

let cipher_mode = SM4CipherMode::new(&key, CipherMode::Ctr);
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Ctr);
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv);
let lhs: &[u8] = lhs.as_ref();
Expand All @@ -323,7 +323,7 @@ mod tests {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();

let cipher_mode = SM4CipherMode::new(&key, CipherMode::Cfb);
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Cfb);
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv);
let lhs: &[u8] = lhs.as_ref();
Expand All @@ -337,7 +337,7 @@ mod tests {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();

let cipher_mode = SM4CipherMode::new(&key, CipherMode::Ofb);
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Ofb);
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv);
let lhs: &[u8] = lhs.as_ref();
Expand All @@ -351,7 +351,7 @@ mod tests {
let key = hex::decode("1234567890abcdef1234567890abcdef").unwrap();
let iv = hex::decode("fedcba0987654321fedcba0987654321").unwrap();

let cipher_mode = SM4CipherMode::new(&key, CipherMode::Cbc);
let cipher_mode = Sm4CipherMode::new(&key, CipherMode::Cbc);
let msg = b"hello world, this file is used for smx test\n";
let lhs = cipher_mode.encrypt(msg, &iv);
let lhs: &[u8] = lhs.as_ref();
Expand Down
2 changes: 1 addition & 1 deletion src/sm4/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ pub mod cipher;
pub mod cipher_mode;

pub type Mode = self::cipher_mode::CipherMode;
pub type Cipher = self::cipher_mode::SM4CipherMode;
pub type Cipher = self::cipher_mode::Sm4CipherMode;

0 comments on commit fafd9f2

Please sign in to comment.