diff --git a/Cargo.toml b/Cargo.toml index 9613e82..d65ae0f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,5 @@ members = [ "macro", "core", "ormlite", - "cli" -] \ No newline at end of file + "cli", +] diff --git a/macro/Cargo.toml b/macro/Cargo.toml index a76c645..b497d9e 100644 --- a/macro/Cargo.toml +++ b/macro/Cargo.toml @@ -30,3 +30,4 @@ sqlx = "0.8.2" lazy_static = "1" once_cell = "1" itertools = "0.13.0" +heck = "0.5" diff --git a/macro/src/codegen/insert_model.rs b/macro/src/codegen/insert_model.rs index 1acafdf..c358d5c 100644 --- a/macro/src/codegen/insert_model.rs +++ b/macro/src/codegen/insert_model.rs @@ -23,7 +23,7 @@ pub fn struct_InsertModel(ast: &DeriveInput, attr: &ModelMeta) -> TokenStream { #vis struct #insert_model { #(#struct_fields,)* } - } + } } else { quote! { #[derive(Debug)] diff --git a/macro/src/lib.rs b/macro/src/lib.rs index 03f5aa8..55c385c 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -2,6 +2,7 @@ #![allow(non_snake_case)] use codegen::insert::impl_Insert; +use heck::ToSnakeCase; use ormlite_attr::InsertMeta; use proc_macro::TokenStream; use std::borrow::Borrow; @@ -9,6 +10,7 @@ use std::collections::HashMap; use std::env; use std::env::var; use std::ops::Deref; +use syn::DataEnum; use once_cell::sync::OnceCell; use quote::quote; @@ -250,3 +252,79 @@ pub fn expand_derive_into_arguments(input: TokenStream) -> TokenStream { pub fn expand_derive_manual_type(input: TokenStream) -> TokenStream { TokenStream::new() } + +#[proc_macro_derive(Enum)] +pub fn derive_ormlite_enum(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + let enum_name = input.ident; + + let variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + _ => panic!("#[derive(OrmliteEnum)] is only supported on enums"), + }; + + // Collect variant names and strings into vectors + let variant_names: Vec<_> = variants.iter().map(|v| &v.ident).collect(); + let variant_strings: Vec<_> = variant_names.iter().map(|v| v.to_string().to_snake_case()).collect(); + + let gen = quote! { + impl std::fmt::Display for #enum_name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + #(Self::#variant_names => write!(f, "{}", #variant_strings)),* + } + } + } + + impl std::str::FromStr for #enum_name { + type Err = String; + fn from_str(s: &str) -> Result::Err> { + match s { + #(#variant_strings => Ok(Self::#variant_names)),*, + _ => Err(format!("Invalid {} value: {}", stringify!(#enum_name), s)) + } + } + } + + impl std::convert::TryFrom<&str> for #enum_name { + type Error = String; + fn try_from(value: &str) -> Result { + ::from_str(value) + } + } + + impl sqlx::Decode<'_, sqlx::Postgres> for #enum_name { + fn decode( + value: sqlx::postgres::PgValueRef<'_>, + ) -> Result { + let s = value.as_str()?; + ::from_str(s).map_err(|e| sqlx::error::BoxDynError::from( + std::io::Error::new(std::io::ErrorKind::InvalidData, e) + )) + } + } + + impl sqlx::Encode<'_, sqlx::Postgres> for #enum_name { + fn encode_by_ref( + &self, + buf: &mut sqlx::postgres::PgArgumentBuffer + ) -> Result { + let s = self.to_string(); + >::encode(s, buf) + } + } + + impl sqlx::Type for #enum_name { + fn type_info() -> ::TypeInfo { + sqlx::postgres::PgTypeInfo::with_name("VARCHAR") + } + + fn compatible(ty: &::TypeInfo) -> bool { + ty.to_string() == "VARCHAR" + } + } + }; + + gen.into() +} diff --git a/ormlite/src/lib.rs b/ormlite/src/lib.rs index 4a9dc7b..f0261c9 100644 --- a/ormlite/src/lib.rs +++ b/ormlite/src/lib.rs @@ -1,11 +1,12 @@ #![cfg_attr(docsrs, feature(doc_cfg))] -pub use ::sqlx::{Column, ColumnIndex, Database, Decode, Row}; pub use model::{FromRow, Insert, IntoArguments, Model, TableMeta}; pub use ormlite_core::BoxFuture; pub use ormlite_core::{Error, Result}; +pub use ormlite_macro::Enum; +pub use sqlx::{Column, ColumnIndex, Database, Decode, Row}; -pub use ::sqlx::pool::PoolOptions; -pub use ::sqlx::{ +pub use sqlx::pool::PoolOptions; +pub use sqlx::{ query, query_as, query_as_with, query_with, Acquire, Arguments, ConnectOptions, Connection, Encode, Executor, Pool, };