Skip to content

Commit

Permalink
particular_derive improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
Canleskis committed Jun 2, 2023
1 parent 055b038 commit f9101a9
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions particular_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use syn::spanned::Spanned;

/// Derive macro generating an implementation of the trait `Particle`.
#[proc_macro_derive(Particle)]
pub fn particle_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand All @@ -11,14 +9,14 @@ pub fn particle_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStrea
fn impl_particle(input: syn::Result<syn::DeriveInput>) -> syn::Result<proc_macro::TokenStream> {
let input = input?;

let (pty, sty) = get_field_types(input.data)?;

let name = input.ident;
let (pty, sty) = get_field_types(input.data)?;
let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();

Ok(quote::quote! {
impl #impl_generics Particle for #name #ty_generics #where_clause {
type Scalar = #sty;

type Vector = #pty;

#[inline]
Expand All @@ -36,29 +34,28 @@ fn impl_particle(input: syn::Result<syn::DeriveInput>) -> syn::Result<proc_macro
}

fn get_field_types(data: syn::Data) -> syn::Result<(syn::Type, syn::Type)> {
match data {
syn::Data::Struct(struct_data) => get_type_of(&struct_data, "position")
.and_then(|pty| get_type_of(&struct_data, "mu").map(|mty| (pty, mty))),
match &data {
syn::Data::Struct(struct_data) => Ok((
get_type_of(struct_data, "position")?,
get_type_of(struct_data, "mu")?,
)),
syn::Data::Enum(enum_data) => Err(syn::Error::new_spanned(
enum_data.enum_token,
"an enum cannot represent a Particle",
"the `Particle` trait can only be derived for struct types",
)),
syn::Data::Union(union_data) => Err(syn::Error::new_spanned(
union_data.union_token,
"a union cannot represent a Particle",
"the `Particle` trait can only be derived for struct types",
)),
}
}

fn get_type_of(struct_data: &syn::DataStruct, field_name: &str) -> syn::Result<syn::Type> {
let fields_span = struct_data.fields.span();

struct_data
.fields
.iter()
.find_map(|field| match &field.ident {
Some(ident) => (ident == field_name).then_some(field.ty.clone()),
None => None,
.find_map(|field| (field.ident.as_ref()? == field_name).then(|| field.ty.clone()))
.ok_or_else(|| {
syn::Error::new_spanned(&struct_data.fields, format!("no {field_name} field"))
})
.ok_or_else(|| syn::Error::new(fields_span, format!("no {field_name} field")))
}

0 comments on commit f9101a9

Please sign in to comment.