From 393cb0e737995cedd4cb3836b83cfe920d8f4249 Mon Sep 17 00:00:00 2001 From: Eran Boodnero Date: Thu, 5 Dec 2024 20:22:31 -0800 Subject: [PATCH] wtf is going on --- attr/src/derive.rs | 2 +- attr/src/ident.rs | 4 +-- attr/src/metadata/column.rs | 9 ++++--- attr/src/metadata/model.rs | 23 ++++++++++------ attr/src/metadata/table.rs | 2 +- cli/src/command/down.rs | 4 +-- cli/src/command/migrate.rs | 8 +++--- cli/src/schema.rs | 22 ++++++++------- core/src/query_builder/args.rs | 6 ++++- core/src/query_builder/select.rs | 5 +++- core/src/query_builder/util.rs | 39 +++++++++++++-------------- core/src/schema.rs | 21 ++++++++------- macro/src/codegen/common.rs | 6 ++--- macro/src/codegen/from_row.rs | 4 +-- macro/src/codegen/insert.rs | 14 +++++++--- macro/src/codegen/insert_model.rs | 2 +- macro/src/codegen/join_description.rs | 2 +- macro/src/codegen/model.rs | 2 +- macro/src/codegen/update.rs | 2 +- macro/src/lib.rs | 15 ++++++----- 20 files changed, 110 insertions(+), 82 deletions(-) diff --git a/attr/src/derive.rs b/attr/src/derive.rs index 3366cb7..a4ed66b 100644 --- a/attr/src/derive.rs +++ b/attr/src/derive.rs @@ -221,6 +221,6 @@ pub enum Privacy { panic!() }; let attr = DeriveParser::from_attributes(&item.attrs); - assert_eq!(attr.has_derive("ormlite", "ManualType"), true); + assert!(attr.has_derive("ormlite", "ManualType")); } } diff --git a/attr/src/ident.rs b/attr/src/ident.rs index 7417e77..93ea3d7 100644 --- a/attr/src/ident.rs +++ b/attr/src/ident.rs @@ -5,8 +5,8 @@ use quote::TokenStreamExt; #[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct Ident(String); -impl Ident { - pub fn as_ref(&self) -> &String { +impl AsRef for Ident { + fn as_ref(&self) -> &String { &self.0 } } diff --git a/attr/src/metadata/column.rs b/attr/src/metadata/column.rs index 1f4d34c..065095d 100644 --- a/attr/src/metadata/column.rs +++ b/attr/src/metadata/column.rs @@ -42,7 +42,10 @@ impl ColumnMeta { } pub fn from_fields<'a>(fields: impl Iterator) -> Vec { - fields.map(|f| ColumnMeta::from_field(f)).collect() + fn fun_name(f: &Field) -> ColumnMeta { + ColumnMeta::from_field(f) + } + fields.map(fun_name).collect() } pub fn from_syn(ident: &syn::Ident, ty: &syn::Type) -> Self { @@ -268,8 +271,8 @@ pub name: String let column = ColumnMeta::from_field(field); assert_eq!(column.name, "name"); assert_eq!(column.ty, "String"); - assert_eq!(column.marked_primary_key, false); - assert_eq!(column.has_database_default, false); + assert!(!column.marked_primary_key); + assert!(!column.has_database_default); assert_eq!(column.rust_default, Some("\"foo\".to_string()".to_string())); assert_eq!(column.ident, "name"); } diff --git a/attr/src/metadata/model.rs b/attr/src/metadata/model.rs index ed7867a..6be3e05 100644 --- a/attr/src/metadata/model.rs +++ b/attr/src/metadata/model.rs @@ -28,10 +28,12 @@ impl ModelMeta { pub fn from_derive(ast: &DeriveInput) -> Self { let attrs = TableAttr::from_attrs(&ast.attrs); let table = TableMeta::new(ast, &attrs); - let pkey = table.pkey.as_deref().expect(&format!( - "No column marked with #[ormlite(primary_key)], and no column named id, uuid, {0}_id, or {0}_uuid", - table.name, - )); + let pkey = table.pkey.as_deref().unwrap_or_else(|| { + panic!( + "No column marked with #[ormlite(primary_key)], and no column named id, uuid, {0}_id, or {0}_uuid", + table.name + ) + }); let mut insert_struct = None; let mut extra_derives: Option> = None; for attr in attrs { @@ -48,13 +50,18 @@ impl ModelMeta { } } let pkey = table.columns.iter().find(|&c| c.name == pkey).unwrap().clone(); - let insert_struct = insert_struct.map(|v| Ident::from(v)); - let extra_derives = extra_derives.take().map(|vec| vec.into_iter().map(|v| v.to_string()).map(Ident::from).collect()); - + fn fun_name(v: String) -> Ident { + Ident::from(v) + } + let insert_struct = insert_struct.map(fun_name); + let extra_derives = extra_derives + .take() + .map(|vec| vec.into_iter().map(|v| v.to_string()).map(Ident::from).collect()); + Self { table, insert_struct, - extra_derives, + extra_derives, pkey, } } diff --git a/attr/src/metadata/table.rs b/attr/src/metadata/table.rs index c538239..253120d 100644 --- a/attr/src/metadata/table.rs +++ b/attr/src/metadata/table.rs @@ -31,7 +31,7 @@ impl TableMeta { let mut pkey = columns .iter() .find(|&c| c.marked_primary_key) - .map(|c| c.clone()) + .cloned() .map(|c| c.name.clone()); if pkey.is_none() { let candidates = sqlmo::util::pkey_column_names(&name); diff --git a/cli/src/command/down.rs b/cli/src/command/down.rs index 74d50ce..f3ec6f9 100644 --- a/cli/src/command/down.rs +++ b/cli/src/command/down.rs @@ -7,12 +7,12 @@ use std::path::Path; use crate::command::{get_executed_migrations, get_pending_migrations, MigrationType}; use crate::util::{create_runtime, CommandSuccess}; +use anyhow::anyhow; use ormlite::postgres::{PgArguments, PgConnection}; use ormlite::Arguments; use ormlite::{Acquire, Connection, Executor}; use ormlite_core::config::{get_var_database_url, get_var_migration_folder, get_var_snapshot_folder}; use url::Url; -use anyhow::anyhow; #[derive(Parser, Debug)] pub struct Down { @@ -68,7 +68,7 @@ impl Down { let target = if let Some(target) = self.target { target } else if executed.len() > 1 { - executed.iter().nth(1).unwrap().name.clone() + executed.get(1).unwrap().name.clone() } else if executed.len() == 1 { "0_empty".to_string() } else { diff --git a/cli/src/command/migrate.rs b/cli/src/command/migrate.rs index 7fafcc2..442f20b 100644 --- a/cli/src/command/migrate.rs +++ b/cli/src/command/migrate.rs @@ -172,9 +172,9 @@ fn check_for_pending_migrations( fn check_reversible_compatibility(reversible: bool, migration_environment: Option) -> Result<()> { if let Some(migration_environment) = migration_environment { - if reversible && migration_environment == MigrationType::Simple { - return Err(anyhow!("You cannot mix reversible and non-reversible migrations")); - } else if !reversible && migration_environment != MigrationType::Simple { + if (reversible && migration_environment == MigrationType::Simple) + || (!reversible && migration_environment != MigrationType::Simple) + { return Err(anyhow!("You cannot mix reversible and non-reversible migrations")); } } @@ -209,7 +209,7 @@ fn autogenerate_migration( let mut current = runtime.block_on(Schema::try_from_postgres(conn, "public"))?; current.tables.retain(|t| t.name != "_sqlx_migrations"); - let mut desired = schema_from_ormlite_project(codebase_path, &c)?; + let mut desired = schema_from_ormlite_project(codebase_path, c)?; experimental_modifications_to_schema(&mut desired)?; let migration = current.migrate_to( diff --git a/cli/src/schema.rs b/cli/src/schema.rs index 6e2f270..566e8aa 100644 --- a/cli/src/schema.rs +++ b/cli/src/schema.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; -use std::path::Path; -use sqlmo::{Constraint, Schema, Table}; +use crate::config::Config; +use anyhow::Result as AnyResult; use ormlite_attr::{schema_from_filepaths, Ident, InnerType, Type}; use ormlite_core::schema::FromMeta; -use anyhow::Result as AnyResult; -use crate::config::Config; +use sqlmo::{Constraint, Schema, Table}; +use std::collections::HashMap; +use std::path::Path; pub fn schema_from_ormlite_project(paths: &[&Path], c: &Config) -> AnyResult { let mut schema = Schema::default(); @@ -29,7 +29,7 @@ pub fn schema_from_ormlite_project(paths: &[&Path], c: &Config) -> AnyResult AnyResult = - schema.tables.iter().map(|t| (t.name.clone(), (t.name.clone(), t.primary_key().unwrap().name.clone()))).collect(); + let mut table_names: HashMap = schema + .tables + .iter() + .map(|t| (t.name.clone(), (t.name.clone(), t.primary_key().unwrap().name.clone()))) + .collect(); for (alias, real) in &c.table.aliases { let Some(real) = table_names.get(real) else { continue; @@ -63,4 +66,5 @@ pub fn schema_from_ormlite_project(paths: &[&Path], c: &Config) -> AnyResult(pub Box>, usize); +pub struct QueryBuilderArgs<'q, DB: Database>(pub Box>, pub usize); impl<'q, DB: Database> QueryBuilderArgs<'q, DB> { pub fn add + sqlx::Type>(&mut self, arg: T) { @@ -12,6 +12,10 @@ impl<'q, DB: Database> QueryBuilderArgs<'q, DB> { pub fn len(&self) -> usize { self.1 } + + pub fn is_empty(&self) -> bool { + self.1 == 0 + } } impl<'q, DB: Database> IntoArguments<'q, DB> for QueryBuilderArgs<'q, DB> { diff --git a/core/src/query_builder/select.rs b/core/src/query_builder/select.rs index 8c6746c..2ac156b 100644 --- a/core/src/query_builder/select.rs +++ b/core/src/query_builder/select.rs @@ -212,7 +212,10 @@ where let q = self.query.to_sql(DB::dialect()); let args = self.arguments; let (q, placeholder_count) = util::replace_placeholders(&q, &mut self.gen)?; - if placeholder_count != args.len() { + if placeholder_count != { + let this = &args; + this.1 + } { return Err(Error::OrmliteError(format!( "Failing to build query. {} placeholders were found in the query, but \ {} arguments were provided.", diff --git a/core/src/query_builder/util.rs b/core/src/query_builder/util.rs index eb80635..c2f6ac3 100644 --- a/core/src/query_builder/util.rs +++ b/core/src/query_builder/util.rs @@ -22,30 +22,27 @@ pub fn replace_placeholders>( buf.push_str(&placeholder_generator.next().unwrap()); placeholder_count += 1; } - Token::Char(c) => { - match c { - '?' => { - buf.push_str(&placeholder_generator.next().unwrap()); - placeholder_count += 1; - } - '$' => { - let next_tok = it.next(); - if let Some(next_tok) = next_tok { - match next_tok { - Token::Number(text, _) => { - let n = text.parse::().map_err(|_| Error::OrmliteError( - format!("Failed to parse number after a $ during query tokenization. Value was: {text}" - )))?; - buf.push_str(&format!("${next_tok}")); - placeholder_count = std::cmp::max(placeholder_count, n); - } - _ => {} - } + Token::Char(c) => match c { + '?' => { + buf.push_str(&placeholder_generator.next().unwrap()); + placeholder_count += 1; + } + '$' => { + let next_tok = it.next(); + if let Some(next_tok) = next_tok { + if let Token::Number(text, _) = next_tok { + let n = text.parse::().map_err(|_| { + Error::OrmliteError(format!( + "Failed to parse number after a $ during query tokenization. Value was: {text}" + )) + })?; + buf.push_str(&format!("${next_tok}")); + placeholder_count = std::cmp::max(placeholder_count, n); } } - _ => buf.push(*c), } - } + _ => buf.push(*c), + }, _ => buf.push_str(&tok.to_string()), } } diff --git a/core/src/schema.rs b/core/src/schema.rs index cd1d19d..20697d7 100644 --- a/core/src/schema.rs +++ b/core/src/schema.rs @@ -1,11 +1,11 @@ -use std::collections::HashMap; -use std::path::Path; -use ormlite_attr::{schema_from_filepaths, ColumnMeta, Ident, InnerType}; +use crate::config::Config; +use anyhow::Result as AnyResult; use ormlite_attr::ModelMeta; use ormlite_attr::Type; +use ormlite_attr::{schema_from_filepaths, ColumnMeta, Ident, InnerType}; use sqlmo::{schema::Column, Constraint, Schema, Table}; -use anyhow::Result as AnyResult; -use crate::config::Config; +use std::collections::HashMap; +use std::path::Path; pub fn schema_from_ormlite_project(paths: &[&Path], c: &Config) -> AnyResult { let mut schema = Schema::default(); @@ -30,7 +30,7 @@ pub fn schema_from_ormlite_project(paths: &[&Path], c: &Config) -> AnyResult AnyResult = - schema.tables.iter().map(|t| (t.name.clone(), (t.name.clone(), t.primary_key().unwrap().name.clone()))).collect(); + let mut table_names: HashMap = schema + .tables + .iter() + .map(|t| (t.name.clone(), (t.name.clone(), t.primary_key().unwrap().name.clone()))) + .collect(); for (alias, real) in &c.table.aliases { let Some(real) = table_names.get(real) else { continue; @@ -207,10 +210,10 @@ impl Nullable { #[cfg(test)] mod tests { use super::*; + use anyhow::Result; use assert_matches::assert_matches; use ormlite_attr::Type; use syn::parse_str; - use anyhow::Result; #[test] fn test_convert_type() -> Result<()> { diff --git a/macro/src/codegen/common.rs b/macro/src/codegen/common.rs index f233024..2922755 100644 --- a/macro/src/codegen/common.rs +++ b/macro/src/codegen/common.rs @@ -75,8 +75,7 @@ fn recursive_primitive_types<'a>(table: &'a ModelMeta, cache: &'a MetadataCache) table .columns .iter() - .map(|c| recursive_primitive_types_ty(&c.ty, cache)) - .flatten() + .flat_map(|c| recursive_primitive_types_ty(&c.ty, cache)) .collect() } @@ -85,8 +84,7 @@ pub(crate) fn table_primitive_types<'a>(attr: &'a TableMeta, cache: &'a Metadata .iter() .filter(|c| !c.skip) .filter(|c| !c.json) - .map(|c| recursive_primitive_types_ty(&c.ty, cache)) - .flatten() + .flat_map(|c| recursive_primitive_types_ty(&c.ty, cache)) .unique() .collect() } diff --git a/macro/src/codegen/from_row.rs b/macro/src/codegen/from_row.rs index dfc37db..e25c432 100644 --- a/macro/src/codegen/from_row.rs +++ b/macro/src/codegen/from_row.rs @@ -1,8 +1,8 @@ use crate::codegen::common::{from_row_bounds, OrmliteCodegen}; use crate::MetadataCache; -use ormlite_attr::{ColumnMeta, Type}; use ormlite_attr::Ident; use ormlite_attr::TableMeta; +use ormlite_attr::{ColumnMeta, Type}; use proc_macro2::TokenStream; use quote::quote; @@ -96,7 +96,7 @@ pub fn impl_from_row_using_aliases( ) -> TokenStream { let row = db.row(); let fields = attr.all_fields(); - let bounds = from_row_bounds(db, attr, &metadata_cache); + let bounds = from_row_bounds(db, attr, metadata_cache); let mut incrementer = 0usize..; let columns = attr .columns diff --git a/macro/src/codegen/insert.rs b/macro/src/codegen/insert.rs index e5d971b..1ee7f10 100644 --- a/macro/src/codegen/insert.rs +++ b/macro/src/codegen/insert.rs @@ -22,13 +22,16 @@ pub fn impl_Model__insert(db: &dyn OrmliteCodegen, attr: &ModelMeta, metadata_ca placeholder.next().unwrap() } }); + fn fun_name(c: &ColumnMeta) -> TokenStream { + insertion_binding(c) + } let query_bindings = attr .database_columns() .filter(|c| attr.pkey.name == c.name || !c.has_database_default) - .map(|c| insertion_binding(c)); + .map(fun_name); - let insert_join = attr.many_to_one_joins().map(|c| insert_join(c)); + let insert_join = attr.many_to_one_joins().map(insert_join); let late_bind = attr.many_to_one_joins().map(|c| { let id = &c.ident; @@ -141,15 +144,18 @@ pub fn impl_Insert(db: &dyn OrmliteCodegen, meta: &TableMeta, model: &Ident, ret ); let query_bindings = meta.database_columns().filter(|&c| !c.has_database_default).map(|c| { if let Some(rust_default) = &c.rust_default { - let default: syn::Expr = syn::parse_str(&rust_default).expect("Failed to parse default_value"); + let default: syn::Expr = syn::parse_str(rust_default).expect("Failed to parse default_value"); return quote! { q = q.bind(#default); }; } insertion_binding(c) }); + fn fun_name(c: &ColumnMeta) -> TokenStream { + insert_join(c) + } - let insert_join = meta.many_to_one_joins().map(|c| insert_join(c)); + let insert_join = meta.many_to_one_joins().map(fun_name); let late_bind = meta.many_to_one_joins().map(|c| { let id = &c.ident; diff --git a/macro/src/codegen/insert_model.rs b/macro/src/codegen/insert_model.rs index 1acafdf..c358d5c 100644 --- a/macro/src/codegen/insert_model.rs +++ b/macro/src/codegen/insert_model.rs @@ -23,7 +23,7 @@ pub fn struct_InsertModel(ast: &DeriveInput, attr: &ModelMeta) -> TokenStream { #vis struct #insert_model { #(#struct_fields,)* } - } + } } else { quote! { #[derive(Debug)] diff --git a/macro/src/codegen/join_description.rs b/macro/src/codegen/join_description.rs index bdc0bf2..617ed9e 100644 --- a/macro/src/codegen/join_description.rs +++ b/macro/src/codegen/join_description.rs @@ -11,7 +11,7 @@ pub fn static_join_descriptions(attr: &TableMeta, metadata_cache: &MetadataCache let struct_name = c.joined_struct_name().unwrap(); let joined_table = metadata_cache .get(&struct_name) - .expect(&format!("Did not find metadata for joined struct: {}", struct_name)); + .unwrap_or_else(|| panic!("Did not find metadata for joined struct: {}", struct_name)); let column_name = c.many_to_one_column_name.as_ref().unwrap(); let foreign_key = &joined_table.pkey.name; diff --git a/macro/src/codegen/model.rs b/macro/src/codegen/model.rs index 2876072..e4d7bcc 100644 --- a/macro/src/codegen/model.rs +++ b/macro/src/codegen/model.rs @@ -11,7 +11,7 @@ pub fn impl_Model(db: &dyn OrmliteCodegen, attr: &ModelMeta, metadata_cache: &Me let model = &attr.ident; let partial_model = attr.builder_struct(); - let impl_Model__insert = impl_Model__insert(db, &attr, metadata_cache); + let impl_Model__insert = impl_Model__insert(db, attr, metadata_cache); let impl_Model__update_all_fields = impl_Model__update_all_fields(db, attr); let impl_Model__delete = impl_Model__delete(db, attr); let impl_Model__fetch_one = impl_Model__fetch_one(db, attr); diff --git a/macro/src/codegen/update.rs b/macro/src/codegen/update.rs index 7b76012..ca86ec6 100644 --- a/macro/src/codegen/update.rs +++ b/macro/src/codegen/update.rs @@ -25,7 +25,7 @@ pub fn impl_Model__update_all_fields(db: &dyn OrmliteCodegen, attr: &ModelMeta) query.push_str(" RETURNING *"); let id = &attr.pkey.ident; - let query_bindings = attr.database_columns_except_pkey().map(|c| insertion_binding(c)); + let query_bindings = attr.database_columns_except_pkey().map(insertion_binding); let unwind_joins = attr.many_to_one_joins().map(|c| { let id = &c.ident; diff --git a/macro/src/lib.rs b/macro/src/lib.rs index 03f5aa8..eaa2666 100644 --- a/macro/src/lib.rs +++ b/macro/src/lib.rs @@ -39,7 +39,10 @@ pub(crate) type MetadataCache = HashMap; static TABLES: OnceCell = OnceCell::new(); fn get_tables() -> &'static MetadataCache { - TABLES.get_or_init(|| load_metadata_cache()) + fn fun_name() -> HashMap { + load_metadata_cache() + } + TABLES.get_or_init(fun_name) } fn load_metadata_cache() -> MetadataCache { @@ -131,10 +134,10 @@ pub fn expand_ormlite_model(input: TokenStream) -> TokenStream { let db = first.as_ref(); let impl_TableMeta = impl_TableMeta(&meta.table, Some(meta.pkey.name.as_str())); let impl_JoinMeta = impl_JoinMeta(&meta); - let static_join_descriptions = static_join_descriptions(&meta.table, &tables); + let static_join_descriptions = static_join_descriptions(&meta.table, tables); let impl_Model = impl_Model(db, &meta, tables); - let impl_FromRow = impl_FromRow(db, &meta.table, &tables); - let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta.table, &tables); + let impl_FromRow = impl_FromRow(db, &meta.table, tables); + let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta.table, tables); let struct_ModelBuilder = struct_ModelBuilder(&ast, &meta); let impl_ModelBuilder = impl_ModelBuilder(db, &meta); @@ -197,8 +200,8 @@ pub fn expand_derive_fromrow(input: TokenStream) -> TokenStream { let expanded = databases.iter().map(|db| { let db = db.as_ref(); - let impl_FromRow = impl_FromRow(db, &meta, &tables); - let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta, &tables); + let impl_FromRow = impl_FromRow(db, &meta, tables); + let impl_from_row_using_aliases = impl_from_row_using_aliases(db, &meta, tables); quote! { #impl_FromRow #impl_from_row_using_aliases