Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement derives for generic wrapper types #37

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 36 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,12 +232,13 @@ impl NumTraits {
pub fn from_primitive(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);

let import = NumTraits::new(&ast);

let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
quote! {
impl #import::FromPrimitive for #name {
impl #impl_ #import::FromPrimitive for #name #type_ #where_ #inner_ty: #import::FromPrimitive {
fn from_i64(n: i64) -> Option<Self> {
<#inner_ty as #import::FromPrimitive>::from_i64(n).map(#name)
}
Expand Down Expand Up @@ -320,7 +321,7 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
};

quote! {
impl #import::FromPrimitive for #name {
impl #impl_ #import::FromPrimitive for #name #type_ #where_ {
#[allow(trivial_numeric_casts)]
fn from_i64(#from_i64_var: i64) -> Option<Self> {
#(#clauses else)* {
Expand Down Expand Up @@ -390,12 +391,13 @@ pub fn from_primitive(input: TokenStream) -> TokenStream {
pub fn to_primitive(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);

let import = NumTraits::new(&ast);

let impl_ = if let Some(inner_ty) = newtype_inner(&ast.data) {
quote! {
impl #import::ToPrimitive for #name {
impl #impl_ #import::ToPrimitive for #name #type_ #where_ #inner_ty: #import::ToPrimitive {
fn to_i64(&self) -> Option<i64> {
<#inner_ty as #import::ToPrimitive>::to_i64(&self.0)
}
Expand Down Expand Up @@ -481,7 +483,7 @@ pub fn to_primitive(input: TokenStream) -> TokenStream {
};

quote! {
impl #import::ToPrimitive for #name {
impl #impl_ #import::ToPrimitive for #name #type_ #where_ {
#[allow(trivial_numeric_casts)]
fn to_i64(&self) -> Option<i64> {
#match_expr
Expand Down Expand Up @@ -511,33 +513,34 @@ const NEWTYPE_ONLY: &str = "This trait can only be derived for newtypes";
pub fn num_ops(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
let impl_ = quote! {
impl ::std::ops::Add for #name {
impl #impl_ ::std::ops::Add for #name #type_ #where_ #inner_ty: ::std::ops::Add<Output = #inner_ty> {
type Output = Self;
fn add(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Add>::add(self.0, other.0))
}
}
impl ::std::ops::Sub for #name {
impl #impl_ ::std::ops::Sub for #name #type_ #where_ #inner_ty: ::std::ops::Sub<Output = #inner_ty> {
type Output = Self;
fn sub(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Sub>::sub(self.0, other.0))
}
}
impl ::std::ops::Mul for #name {
impl #impl_ ::std::ops::Mul for #name #type_ #where_ #inner_ty: ::std::ops::Mul<Output = #inner_ty> {
type Output = Self;
fn mul(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Mul>::mul(self.0, other.0))
}
}
impl ::std::ops::Div for #name {
impl #impl_ ::std::ops::Div for #name #type_ #where_ #inner_ty: ::std::ops::Div<Output = #inner_ty> {
type Output = Self;
fn div(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Div>::div(self.0, other.0))
}
}
impl ::std::ops::Rem for #name {
impl #impl_ ::std::ops::Rem for #name #type_ #where_ #inner_ty: ::std::ops::Rem<Output = #inner_ty> {
type Output = Self;
fn rem(self, other: Self) -> Self {
#name(<#inner_ty as ::std::ops::Rem>::rem(self.0, other.0))
Expand All @@ -555,13 +558,16 @@ pub fn num_ops(input: TokenStream) -> TokenStream {
pub fn num_cast(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);
let fn_param = proc_macro2::Ident::new("FROM_T", name.span());

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::NumCast for #name {
fn from<T: #import::ToPrimitive>(n: T) -> Option<Self> {
impl #impl_ #import::NumCast for #name #type_ #where_ #inner_ty: #import::NumCast {
#[allow(non_camel_case_types)]
fn from<#fn_param: #import::ToPrimitive>(n: #fn_param) -> Option<Self> {
<#inner_ty as #import::NumCast>::from(n).map(#name)
}
}
Expand All @@ -577,12 +583,13 @@ pub fn num_cast(input: TokenStream) -> TokenStream {
pub fn zero(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::Zero for #name {
impl #impl_ #import::Zero for #name #type_ #where_ #inner_ty: #import::Zero {
fn zero() -> Self {
#name(<#inner_ty as #import::Zero>::zero())
}
Expand All @@ -602,12 +609,13 @@ pub fn zero(input: TokenStream) -> TokenStream {
pub fn one(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::One for #name {
impl #impl_ #import::One for #name #type_ #where_ #inner_ty: #import::One + PartialEq {
fn one() -> Self {
#name(<#inner_ty as #import::One>::one())
}
Expand All @@ -620,19 +628,31 @@ pub fn one(input: TokenStream) -> TokenStream {
import.wrap("One", &name, impl_).into()
}

fn split_for_impl(
generics: &syn::Generics,
) -> (syn::ImplGenerics, syn::TypeGenerics, impl quote::ToTokens) {
let (impl_, type_, where_) = generics.split_for_impl();
let where_ = match where_ {
Some(where_) => quote! { #where_, },
None => quote! { where },
};
(impl_, type_, where_)
}

/// Derives [`num_traits::Num`][num] for newtypes. The inner type must already implement `Num`.
///
/// [num]: https://docs.rs/num-traits/0.2/num_traits/trait.Num.html
#[proc_macro_derive(Num, attributes(num_traits))]
pub fn num(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::Num for #name {
impl #impl_ #import::Num for #name #type_ #where_ #inner_ty: #import::Num {
type FromStrRadixErr = <#inner_ty as #import::Num>::FromStrRadixErr;
fn from_str_radix(s: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
<#inner_ty as #import::Num>::from_str_radix(s, radix).map(#name)
Expand All @@ -651,12 +671,13 @@ pub fn num(input: TokenStream) -> TokenStream {
pub fn float(input: TokenStream) -> TokenStream {
let ast: syn::DeriveInput = syn::parse(input).unwrap();
let name = &ast.ident;
let (impl_, type_, where_) = split_for_impl(&ast.generics);
let inner_ty = newtype_inner(&ast.data).expect(NEWTYPE_ONLY);

let import = NumTraits::new(&ast);

let impl_ = quote! {
impl #import::Float for #name {
impl #impl_ #import::Float for #name #type_ #where_ #inner_ty: #import::Float {
fn nan() -> Self {
#name(<#inner_ty as #import::Float>::nan())
}
Expand Down
97 changes: 97 additions & 0 deletions tests/generic_newtype.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
extern crate num as num_renamed;
#[macro_use]
extern crate num_derive;

use crate::num_renamed::{Float, FromPrimitive, Num, NumCast, One, ToPrimitive, Zero};
use std::ops::Neg;

#[derive(
Debug,
Clone,
Copy,
PartialEq,
PartialOrd,
ToPrimitive,
FromPrimitive,
NumOps,
NumCast,
One,
Zero,
Num,
Float,
)]
struct MyThing<T: Cake>(T)
where
T: Lie;

trait Cake {}
trait Lie {}

impl Cake for f32 {}
impl Lie for f32 {}

impl<T: Neg<Output = T> + Cake + Lie> Neg for MyThing<T> {
type Output = Self;
fn neg(self) -> Self {
MyThing(self.0.neg())
}
}

#[test]
fn test_from_primitive() {
assert_eq!(MyThing::from_u32(25), Some(MyThing(25.0)));
}

#[test]
fn test_from_primitive_128() {
assert_eq!(
MyThing::from_i128(std::i128::MIN),
Some(MyThing((-2.0).powi(127)))
);
}

#[test]
fn test_to_primitive() {
assert_eq!(MyThing(25.0).to_u32(), Some(25));
}

#[test]
fn test_to_primitive_128() {
let f: MyThing<f32> = MyThing::from_f32(std::f32::MAX).unwrap();
assert_eq!(f.to_i128(), None);
assert_eq!(f.to_u128(), Some(0xffff_ff00_0000_0000_0000_0000_0000_0000));
}

#[test]
fn test_num_ops() {
assert_eq!(MyThing(25.0) + MyThing(10.0), MyThing(35.0));
assert_eq!(MyThing(25.0) - MyThing(10.0), MyThing(15.0));
assert_eq!(MyThing(25.0) * MyThing(2.0), MyThing(50.0));
assert_eq!(MyThing(25.0) / MyThing(10.0), MyThing(2.5));
assert_eq!(MyThing(25.0) % MyThing(10.0), MyThing(5.0));
}

#[test]
fn test_num_cast() {
assert_eq!(<MyThing<f32> as NumCast>::from(25u8), Some(MyThing(25.0)));
}

#[test]
fn test_zero() {
assert_eq!(MyThing::zero(), MyThing(0.0));
}

#[test]
fn test_one() {
assert_eq!(MyThing::one(), MyThing(1.0));
}

#[test]
fn test_num() {
assert_eq!(MyThing::from_str_radix("25", 10).ok(), Some(MyThing(25.0)));
}

#[test]
fn test_float() {
assert_eq!(MyThing(4.0).log(MyThing(2.0)), MyThing(2.0));
}
44 changes: 44 additions & 0 deletions tests/generic_newtype2.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
extern crate num as num_renamed;
#[macro_use]
extern crate num_derive;

use crate::num_renamed::{FromPrimitive};
use std::ops::Neg;
use std::num::Wrapping;

#[derive(
Debug,
Clone,
Copy,
PartialOrd,
ToPrimitive,
FromPrimitive,
NumOps,
NumCast,
One,
Zero,
Num,
)]
struct MyThing<T>(Wrapping<T>);

impl<T> PartialEq for MyThing<T> where Wrapping<T>: PartialEq {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since Num requires PartialEq, and the libstd derive causes PartialEq to get a T: PartialEq bound instead of a Wrapping<T>: PartialEq bound, this workaround is required to use the new num derives from this PR with such types.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a meaningful difference? The std implementation looks like this:

impl<T> PartialEq<Wrapping<T>> for Wrapping<T>
where
    T: PartialEq<T>,

You can't impl PartialEq for Wrapping<LocalType>, so it seem like T: PartialEq and Wrapping<T>: PartialEq should be equivalent.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why there's a difference, but there definitely is one. I'll look into it

fn eq(&self, other: &Self) -> bool {
self.0.eq(&other.0)
}
}

impl<T: Neg<Output = T>> Neg for MyThing<T> where Wrapping<T>: Neg<Output = Wrapping<T>> {
type Output = Self;
fn neg(self) -> Self {
MyThing(self.0.neg())
}
}

#[test]
fn test_num_ops() {
assert_eq!(MyThing::<u32>::from_i128(25).unwrap() + MyThing::<u32>::from_i128(10).unwrap(), MyThing::<u32>::from_i128(35).unwrap());
assert_eq!(MyThing::<u32>::from_i128(25).unwrap() - MyThing::<u32>::from_i128(10).unwrap(), MyThing::<u32>::from_i128(15).unwrap());
assert_eq!(MyThing::<u32>::from_i128(25).unwrap() * MyThing::<u32>::from_i128(2).unwrap(), MyThing::<u32>::from_i128(50).unwrap());
assert_eq!(MyThing::<u32>::from_i128(25).unwrap() / MyThing::<u32>::from_i128(10).unwrap(), MyThing::<u32>::from_i128(2).unwrap());
assert_eq!(MyThing::<u32>::from_i128(25).unwrap() % MyThing::<u32>::from_i128(10).unwrap(), MyThing::<u32>::from_i128(5).unwrap());
}