bevy_ecs_macros/
lib.rs

1#![expect(missing_docs, reason = "Not all docs are written yet, see #3492.")]
2#![cfg_attr(docsrs, feature(doc_auto_cfg))]
3
4extern crate proc_macro;
5
6mod component;
7mod query_data;
8mod query_filter;
9mod states;
10mod world_query;
11
12use crate::{
13    component::map_entities, query_data::derive_query_data_impl,
14    query_filter::derive_query_filter_impl,
15};
16use bevy_macro_utils::{derive_label, ensure_no_collision, get_struct_fields, BevyManifest};
17use proc_macro::TokenStream;
18use proc_macro2::{Ident, Span};
19use quote::{format_ident, quote};
20use syn::{
21    parse_macro_input, parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma,
22    ConstParam, Data, DataStruct, DeriveInput, GenericParam, Index, TypeParam,
23};
24
/// How `#[derive(Bundle)]` treats one field of the annotated struct.
enum BundleFieldKind {
    /// The field is itself a `Bundle`; its components become part of the derived bundle.
    Component,
    /// The field carries `#[bundle(ignore)]`: it contributes no components and is filled
    /// with `Default::default()` when the bundle is rebuilt from components.
    Ignore,
}

/// Name of the helper attribute understood by `#[derive(Bundle)]`, i.e. `#[bundle(...)]`.
const BUNDLE_ATTRIBUTE_NAME: &str = "bundle";
/// The only option accepted inside `#[bundle(...)]`; marks a field as not part of the bundle.
const BUNDLE_ATTRIBUTE_IGNORE_NAME: &str = "ignore";
32
33#[proc_macro_derive(Bundle, attributes(bundle))]
34pub fn derive_bundle(input: TokenStream) -> TokenStream {
35    let ast = parse_macro_input!(input as DeriveInput);
36    let ecs_path = bevy_ecs_path();
37
38    let named_fields = match get_struct_fields(&ast.data) {
39        Ok(fields) => fields,
40        Err(e) => return e.into_compile_error().into(),
41    };
42
43    let mut field_kind = Vec::with_capacity(named_fields.len());
44
45    for field in named_fields {
46        for attr in field
47            .attrs
48            .iter()
49            .filter(|a| a.path().is_ident(BUNDLE_ATTRIBUTE_NAME))
50        {
51            if let Err(error) = attr.parse_nested_meta(|meta| {
52                if meta.path.is_ident(BUNDLE_ATTRIBUTE_IGNORE_NAME) {
53                    field_kind.push(BundleFieldKind::Ignore);
54                    Ok(())
55                } else {
56                    Err(meta.error(format!(
57                        "Invalid bundle attribute. Use `{BUNDLE_ATTRIBUTE_IGNORE_NAME}`"
58                    )))
59                }
60            }) {
61                return error.into_compile_error().into();
62            }
63        }
64
65        field_kind.push(BundleFieldKind::Component);
66    }
67
68    let field = named_fields
69        .iter()
70        .map(|field| field.ident.as_ref())
71        .collect::<Vec<_>>();
72
73    let field_type = named_fields
74        .iter()
75        .map(|field| &field.ty)
76        .collect::<Vec<_>>();
77
78    let mut field_component_ids = Vec::new();
79    let mut field_get_component_ids = Vec::new();
80    let mut field_get_components = Vec::new();
81    let mut field_from_components = Vec::new();
82    let mut field_required_components = Vec::new();
83    for (((i, field_type), field_kind), field) in field_type
84        .iter()
85        .enumerate()
86        .zip(field_kind.iter())
87        .zip(field.iter())
88    {
89        match field_kind {
90            BundleFieldKind::Component => {
91                field_component_ids.push(quote! {
92                <#field_type as #ecs_path::bundle::Bundle>::component_ids(components, &mut *ids);
93                });
94                field_required_components.push(quote! {
95                    <#field_type as #ecs_path::bundle::Bundle>::register_required_components(components, required_components);
96                });
97                field_get_component_ids.push(quote! {
98                    <#field_type as #ecs_path::bundle::Bundle>::get_component_ids(components, &mut *ids);
99                });
100                match field {
101                    Some(field) => {
102                        field_get_components.push(quote! {
103                            self.#field.get_components(&mut *func);
104                        });
105                        field_from_components.push(quote! {
106                            #field: <#field_type as #ecs_path::bundle::BundleFromComponents>::from_components(ctx, &mut *func),
107                        });
108                    }
109                    None => {
110                        let index = Index::from(i);
111                        field_get_components.push(quote! {
112                            self.#index.get_components(&mut *func);
113                        });
114                        field_from_components.push(quote! {
115                            #index: <#field_type as #ecs_path::bundle::BundleFromComponents>::from_components(ctx, &mut *func),
116                        });
117                    }
118                }
119            }
120
121            BundleFieldKind::Ignore => {
122                field_from_components.push(quote! {
123                    #field: ::core::default::Default::default(),
124                });
125            }
126        }
127    }
128    let generics = ast.generics;
129    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
130    let struct_name = &ast.ident;
131
132    TokenStream::from(quote! {
133        // SAFETY:
134        // - ComponentId is returned in field-definition-order. [get_components] uses field-definition-order
135        // - `Bundle::get_components` is exactly once for each member. Rely's on the Component -> Bundle implementation to properly pass
136        //   the correct `StorageType` into the callback.
137        #[allow(deprecated)]
138        unsafe impl #impl_generics #ecs_path::bundle::Bundle for #struct_name #ty_generics #where_clause {
139            fn component_ids(
140                components: &mut #ecs_path::component::ComponentsRegistrator,
141                ids: &mut impl FnMut(#ecs_path::component::ComponentId)
142            ){
143                #(#field_component_ids)*
144            }
145
146            fn get_component_ids(
147                components: &#ecs_path::component::Components,
148                ids: &mut impl FnMut(Option<#ecs_path::component::ComponentId>)
149            ){
150                #(#field_get_component_ids)*
151            }
152
153            fn register_required_components(
154                components: &mut #ecs_path::component::ComponentsRegistrator,
155                required_components: &mut #ecs_path::component::RequiredComponents
156            ){
157                #(#field_required_components)*
158            }
159        }
160
161        // SAFETY:
162        // - ComponentId is returned in field-definition-order. [from_components] uses field-definition-order
163        #[allow(deprecated)]
164        unsafe impl #impl_generics #ecs_path::bundle::BundleFromComponents for #struct_name #ty_generics #where_clause {
165            #[allow(unused_variables, non_snake_case)]
166            unsafe fn from_components<__T, __F>(ctx: &mut __T, func: &mut __F) -> Self
167            where
168                __F: FnMut(&mut __T) -> #ecs_path::ptr::OwningPtr<'_>
169            {
170                Self{
171                    #(#field_from_components)*
172                }
173            }
174        }
175
176        #[allow(deprecated)]
177        impl #impl_generics #ecs_path::bundle::DynamicBundle for #struct_name #ty_generics #where_clause {
178            type Effect = ();
179            #[allow(unused_variables)]
180            #[inline]
181            fn get_components(
182                self,
183                func: &mut impl FnMut(#ecs_path::component::StorageType, #ecs_path::ptr::OwningPtr<'_>)
184            ) {
185                #(#field_get_components)*
186            }
187        }
188    })
189}
190
191#[proc_macro_derive(MapEntities, attributes(entities))]
192pub fn derive_map_entities(input: TokenStream) -> TokenStream {
193    let ast = parse_macro_input!(input as DeriveInput);
194    let ecs_path = bevy_ecs_path();
195    let map_entities_impl = map_entities(
196        &ast.data,
197        Ident::new("self", Span::call_site()),
198        false,
199        false,
200    );
201    let struct_name = &ast.ident;
202    let (impl_generics, type_generics, where_clause) = &ast.generics.split_for_impl();
203    TokenStream::from(quote! {
204        impl #impl_generics #ecs_path::entity::MapEntities for #struct_name #type_generics #where_clause {
205            fn map_entities<M: #ecs_path::entity::EntityMapper>(&mut self, mapper: &mut M) {
206                #map_entities_impl
207            }
208        }
209    })
210}
211
/// Implement `SystemParam` to use a struct as a parameter in a system
///
/// The generated `SystemParam` impl keeps the state of all fields in a single tuple
/// (folded into nested tuples when there are more than 16 fields) and forwards each trait
/// method to the `SystemParam` impl of that tuple. A `ReadOnlySystemParam` impl is also
/// emitted, bounded on every field type implementing `ReadOnlySystemParam`.
///
/// Accepted attributes:
/// - `#[system_param(builder)]` on the struct: also generates a `<StructName>Builder`
///   type implementing `SystemParamBuilder` for the struct.
/// - `#[system_param(validation_message = ...)]` on a field: replaces the message reported
///   when that field fails validation (by default the field's own error message is used).
///
/// Only structs are supported, and lifetime parameters must be named `'w` or `'s`.
#[proc_macro_derive(SystemParam, attributes(system_param))]
pub fn derive_system_param(input: TokenStream) -> TokenStream {
    // Unparsed copy of the input, handed to `ensure_no_collision` below when choosing the
    // names of generated helper items.
    let token_stream = input.clone();
    let ast = parse_macro_input!(input as DeriveInput);
    let Data::Struct(DataStruct {
        fields: field_definitions,
        ..
    }) = ast.data
    else {
        return syn::Error::new(
            ast.span(),
            "Invalid `SystemParam` type: expected a `struct`",
        )
        .into_compile_error()
        .into();
    };
    let path = bevy_ecs_path();

    // Per-field data, all indexed by field position:
    // - `field_locals`: local binding names `f0`, `f1`, ... used in generated patterns;
    // - `field_names`: `"::<field>"` strings passed to validation errors;
    // - `fields`: member tokens (the identifier, or the index for tuple structs);
    // - `field_types`: the declared field types;
    // - `field_messages`: the validation-message expression for each field.
    let mut field_locals = Vec::new();
    let mut field_names = Vec::new();
    let mut fields = Vec::new();
    let mut field_types = Vec::new();
    let mut field_messages = Vec::new();
    for (i, field) in field_definitions.iter().enumerate() {
        field_locals.push(format_ident!("f{i}"));
        let i = Index::from(i);
        let field_value = field
            .ident
            .as_ref()
            .map(|f| quote! { #f })
            .unwrap_or_else(|| quote! { #i });
        field_names.push(format!("::{}", field_value));
        fields.push(field_value);
        field_types.push(&field.ty);
        // `#[system_param(validation_message = <expr>)]` overrides the message for this field.
        let mut field_message = None;
        for meta in field
            .attrs
            .iter()
            .filter(|a| a.path().is_ident("system_param"))
        {
            if let Err(e) = meta.parse_nested_meta(|nested| {
                if nested.path.is_ident("validation_message") {
                    field_message = Some(nested.value()?.parse()?);
                    Ok(())
                } else {
                    Err(nested.error("Unsupported attribute"))
                }
            }) {
                return e.into_compile_error().into();
            }
        }
        // Without an override, reuse the message of the field's own validation error.
        field_messages.push(field_message.unwrap_or_else(|| quote! { err.message }));
    }

    let generics = ast.generics;

    // Emit an error if there's any unrecognized lifetime names.
    for lt in generics.lifetimes() {
        let ident = &lt.lifetime.ident;
        let w = format_ident!("w");
        let s = format_ident!("s");
        if ident != &w && ident != &s {
            return syn::Error::new_spanned(
                lt,
                r#"invalid lifetime name: expected `'w` or `'s`
 'w -- refers to data stored in the World.
 's -- refers to data stored in the SystemParam's state.'"#,
            )
            .into_compile_error()
            .into();
        }
    }

    let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();

    // Type and const parameters only; lifetimes are handled separately below.
    let lifetimeless_generics: Vec<_> = generics
        .params
        .iter()
        .filter(|g| !matches!(g, GenericParam::Lifetime(_)))
        .collect();

    // One `'_` per declared lifetime, used to name the struct in the `SystemParam` impl header.
    let shadowed_lifetimes: Vec<_> = generics.lifetimes().map(|_| quote!('_)).collect();

    // Type/const parameters with their defaults stripped, for use in `impl<...>` headers.
    let mut punctuated_generics = Punctuated::<_, Comma>::new();
    punctuated_generics.extend(lifetimeless_generics.iter().map(|g| match g {
        GenericParam::Type(g) => GenericParam::Type(TypeParam {
            default: None,
            ..g.clone()
        }),
        GenericParam::Const(g) => GenericParam::Const(ConstParam {
            default: None,
            ..g.clone()
        }),
        _ => unreachable!(),
    }));

    // Bare parameter names, for use as generic arguments.
    let mut punctuated_generic_idents = Punctuated::<_, Comma>::new();
    punctuated_generic_idents.extend(lifetimeless_generics.iter().map(|g| match g {
        GenericParam::Type(g) => &g.ident,
        GenericParam::Const(g) => &g.ident,
        _ => unreachable!(),
    }));

    // Type parameters with their trait bounds removed, for the type alias declared below.
    let punctuated_generics_no_bounds: Punctuated<_, Comma> = lifetimeless_generics
        .iter()
        .map(|&g| match g.clone() {
            GenericParam::Type(mut g) => {
                g.bounds.clear();
                GenericParam::Type(g)
            }
            g => g,
        })
        .collect();

    let mut tuple_types: Vec<_> = field_types.iter().map(|x| quote! { #x }).collect();
    let mut tuple_patterns: Vec<_> = field_locals.iter().map(|x| quote! { #x }).collect();

    // If the number of fields exceeds the 16-parameter limit,
    // fold the fields into tuples of tuples until we are below the limit.
    // Types and patterns are folded identically so they keep matching each other.
    const LIMIT: usize = 16;
    while tuple_types.len() > LIMIT {
        let end = Vec::from_iter(tuple_types.drain(..LIMIT));
        tuple_types.push(parse_quote!( (#(#end,)*) ));

        let end = Vec::from_iter(tuple_patterns.drain(..LIMIT));
        tuple_patterns.push(parse_quote!( (#(#end,)*) ));
    }

    // Create a where clause for the `ReadOnlySystemParam` impl.
    // Ensure that each field implements `ReadOnlySystemParam`.
    let mut read_only_generics = generics.clone();
    let read_only_where_clause = read_only_generics.make_where_clause();
    for field_type in &field_types {
        read_only_where_clause
            .predicates
            .push(syn::parse_quote!(#field_type: #path::system::ReadOnlySystemParam));
    }

    let fields_alias =
        ensure_no_collision(format_ident!("__StructFieldsAlias"), token_stream.clone());

    let struct_name = &ast.ident;
    let state_struct_visibility = &ast.vis;
    let state_struct_name = ensure_no_collision(format_ident!("FetchState"), token_stream);

    // `#[system_param(builder)]` on the struct requests a `<StructName>Builder` type.
    let mut builder_name = None;
    for meta in ast
        .attrs
        .iter()
        .filter(|a| a.path().is_ident("system_param"))
    {
        if let Err(e) = meta.parse_nested_meta(|nested| {
            if nested.path.is_ident("builder") {
                builder_name = Some(format_ident!("{struct_name}Builder"));
                Ok(())
            } else {
                Err(nested.error("Unsupported attribute"))
            }
        }) {
            return e.into_compile_error().into();
        }
    }

    // The builder holds one generic sub-builder `B<i>` per field and builds the state by
    // handing them, as a tuple, to the tuple `SystemParamBuilder`.
    // NOTE(review): the builder struct is emitted with `field: B<i>` syntax, which does not
    // look valid for tuple structs — confirm whether `builder` is meant for named fields only.
    let builder = builder_name.map(|builder_name| {
        let builder_type_parameters: Vec<_> = (0..fields.len()).map(|i| format_ident!("B{i}")).collect();
        let builder_doc_comment = format!("A [`SystemParamBuilder`] for a [`{struct_name}`].");
        let builder_struct = quote! {
            #[doc = #builder_doc_comment]
            struct #builder_name<#(#builder_type_parameters,)*> {
                #(#fields: #builder_type_parameters,)*
            }
        };
        let lifetimes: Vec<_> = generics.lifetimes().collect();
        let generic_struct = quote!{ #struct_name <#(#lifetimes,)* #punctuated_generic_idents> };
        let builder_impl = quote!{
            // SAFETY: This delegates to the `SystemParamBuilder` for tuples.
            unsafe impl<
                #(#lifetimes,)*
                #(#builder_type_parameters: #path::system::SystemParamBuilder<#field_types>,)*
                #punctuated_generics
            > #path::system::SystemParamBuilder<#generic_struct> for #builder_name<#(#builder_type_parameters,)*>
                #where_clause
            {
                fn build(self, world: &mut #path::world::World, meta: &mut #path::system::SystemMeta) -> <#generic_struct as #path::system::SystemParam>::State {
                    let #builder_name { #(#fields: #field_locals,)* } = self;
                    #state_struct_name {
                        state: #path::system::SystemParamBuilder::build((#(#tuple_patterns,)*), world, meta)
                    }
                }
            }
        };
        (builder_struct, builder_impl)
    });
    let (builder_struct, builder_impl) = builder.unzip();

    TokenStream::from(quote! {
        // We define the FetchState struct in an anonymous scope to avoid polluting the user namespace.
        // The struct can still be accessed via SystemParam::State, e.g. EventReaderState can be accessed via
        // <EventReader<'static, 'static, T> as SystemParam>::State
        const _: () = {
            // Allows rebinding the lifetimes of each field type.
            type #fields_alias <'w, 's, #punctuated_generics_no_bounds> = (#(#tuple_types,)*);

            #[doc(hidden)]
            #state_struct_visibility struct #state_struct_name <#(#lifetimeless_generics,)*>
            #where_clause {
                state: <#fields_alias::<'static, 'static, #punctuated_generic_idents> as #path::system::SystemParam>::State,
            }

            // SAFETY: every method forwards to the `SystemParam` impl of the tuple of field
            // types (or, in `validate_param`, to each field's impl), so this impl registers the
            // same access and produces the same values as those impls.
            unsafe impl<#punctuated_generics> #path::system::SystemParam for
                #struct_name <#(#shadowed_lifetimes,)* #punctuated_generic_idents> #where_clause
            {
                type State = #state_struct_name<#punctuated_generic_idents>;
                type Item<'w, 's> = #struct_name #ty_generics;

                fn init_state(world: &mut #path::world::World, system_meta: &mut #path::system::SystemMeta) -> Self::State {
                    #state_struct_name {
                        state: <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::init_state(world, system_meta),
                    }
                }

                unsafe fn new_archetype(state: &mut Self::State, archetype: &#path::archetype::Archetype, system_meta: &mut #path::system::SystemMeta) {
                    // SAFETY: The caller ensures that `archetype` is from the World the state was initialized from in `init_state`.
                    unsafe { <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::new_archetype(&mut state.state, archetype, system_meta) }
                }

                fn apply(state: &mut Self::State, system_meta: &#path::system::SystemMeta, world: &mut #path::world::World) {
                    <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::apply(&mut state.state, system_meta, world);
                }

                fn queue(state: &mut Self::State, system_meta: &#path::system::SystemMeta, world: #path::world::DeferredWorld) {
                    <#fields_alias::<'_, '_, #punctuated_generic_idents> as #path::system::SystemParam>::queue(&mut state.state, system_meta, world);
                }

                // Validates each field in declaration order and stops at the first failure,
                // re-wrapping that error via `SystemParamValidationError::new::<Self>` with the
                // inner `skipped` flag, the field's message, and its `::<field>` name.
                #[inline]
                unsafe fn validate_param<'w, 's>(
                    state: &'s Self::State,
                    _system_meta: &#path::system::SystemMeta,
                    _world: #path::world::unsafe_world_cell::UnsafeWorldCell<'w>,
                ) -> Result<(), #path::system::SystemParamValidationError> {
                    let #state_struct_name { state: (#(#tuple_patterns,)*) } = state;
                    #(
                        <#field_types as #path::system::SystemParam>::validate_param(#field_locals, _system_meta, _world)
                            .map_err(|err| #path::system::SystemParamValidationError::new::<Self>(err.skipped, #field_messages, #field_names))?;
                    )*
                    Ok(())
                }

                // Fetches all fields through the tuple impl, then moves them into the struct.
                #[inline]
                unsafe fn get_param<'w, 's>(
                    state: &'s mut Self::State,
                    system_meta: &#path::system::SystemMeta,
                    world: #path::world::unsafe_world_cell::UnsafeWorldCell<'w>,
                    change_tick: #path::component::Tick,
                ) -> Self::Item<'w, 's> {
                    let (#(#tuple_patterns,)*) = <
                        (#(#tuple_types,)*) as #path::system::SystemParam
                    >::get_param(&mut state.state, system_meta, world, change_tick);
                    #struct_name {
                        #(#fields: #field_locals,)*
                    }
                }
            }

            // SAFETY: Each field is `ReadOnlySystemParam`, so this can only read from the `World`
            unsafe impl<'w, 's, #punctuated_generics> #path::system::ReadOnlySystemParam for #struct_name #ty_generics #read_only_where_clause {}

            #builder_impl
        };

        #builder_struct
    })
}
486
487/// Implement `QueryData` to use a struct as a data parameter in a query
488#[proc_macro_derive(QueryData, attributes(query_data))]
489pub fn derive_query_data(input: TokenStream) -> TokenStream {
490    derive_query_data_impl(input)
491}
492
493/// Implement `QueryFilter` to use a struct as a filter parameter in a query
494#[proc_macro_derive(QueryFilter, attributes(query_filter))]
495pub fn derive_query_filter(input: TokenStream) -> TokenStream {
496    derive_query_filter_impl(input)
497}
498
499/// Derive macro generating an impl of the trait `ScheduleLabel`.
500///
501/// This does not work for unions.
502#[proc_macro_derive(ScheduleLabel)]
503pub fn derive_schedule_label(input: TokenStream) -> TokenStream {
504    let input = parse_macro_input!(input as DeriveInput);
505    let mut trait_path = bevy_ecs_path();
506    trait_path.segments.push(format_ident!("schedule").into());
507    let mut dyn_eq_path = trait_path.clone();
508    trait_path
509        .segments
510        .push(format_ident!("ScheduleLabel").into());
511    dyn_eq_path.segments.push(format_ident!("DynEq").into());
512    derive_label(input, "ScheduleLabel", &trait_path, &dyn_eq_path)
513}
514
515/// Derive macro generating an impl of the trait `SystemSet`.
516///
517/// This does not work for unions.
518#[proc_macro_derive(SystemSet)]
519pub fn derive_system_set(input: TokenStream) -> TokenStream {
520    let input = parse_macro_input!(input as DeriveInput);
521    let mut trait_path = bevy_ecs_path();
522    trait_path.segments.push(format_ident!("schedule").into());
523    let mut dyn_eq_path = trait_path.clone();
524    trait_path.segments.push(format_ident!("SystemSet").into());
525    dyn_eq_path.segments.push(format_ident!("DynEq").into());
526    derive_label(input, "SystemSet", &trait_path, &dyn_eq_path)
527}
528
529pub(crate) fn bevy_ecs_path() -> syn::Path {
530    BevyManifest::shared().get_path("bevy_ecs")
531}
532
533#[proc_macro_derive(Event, attributes(event))]
534pub fn derive_event(input: TokenStream) -> TokenStream {
535    component::derive_event(input)
536}
537
538#[proc_macro_derive(Resource)]
539pub fn derive_resource(input: TokenStream) -> TokenStream {
540    component::derive_resource(input)
541}
542
543#[proc_macro_derive(
544    Component,
545    attributes(component, require, relationship, relationship_target, entities)
546)]
547pub fn derive_component(input: TokenStream) -> TokenStream {
548    component::derive_component(input)
549}
550
551#[proc_macro_derive(States)]
552pub fn derive_states(input: TokenStream) -> TokenStream {
553    states::derive_states(input)
554}
555
556#[proc_macro_derive(SubStates, attributes(source))]
557pub fn derive_substates(input: TokenStream) -> TokenStream {
558    states::derive_substates(input)
559}
560
561#[proc_macro_derive(FromWorld, attributes(from_world))]
562pub fn derive_from_world(input: TokenStream) -> TokenStream {
563    let bevy_ecs_path = bevy_ecs_path();
564    let ast = parse_macro_input!(input as DeriveInput);
565    let name = ast.ident;
566    let (impl_generics, ty_generics, where_clauses) = ast.generics.split_for_impl();
567
568    let (fields, variant_ident) = match &ast.data {
569        Data::Struct(data) => (&data.fields, None),
570        Data::Enum(data) => {
571            match data.variants.iter().find(|variant| {
572                variant
573                    .attrs
574                    .iter()
575                    .any(|attr| attr.path().is_ident("from_world"))
576            }) {
577                Some(variant) => (&variant.fields, Some(&variant.ident)),
578                None => {
579                    return syn::Error::new(
580                        Span::call_site(),
581                        "No variant found with the `#[from_world]` attribute",
582                    )
583                    .into_compile_error()
584                    .into();
585                }
586            }
587        }
588        Data::Union(_) => {
589            return syn::Error::new(
590                Span::call_site(),
591                "#[derive(FromWorld)]` does not support unions",
592            )
593            .into_compile_error()
594            .into();
595        }
596    };
597
598    let field_init_expr = quote!(#bevy_ecs_path::world::FromWorld::from_world(world));
599    let members = fields.members();
600
601    let field_initializers = match variant_ident {
602        Some(variant_ident) => quote!( Self::#variant_ident {
603            #(#members: #field_init_expr),*
604        }),
605        None => quote!( Self {
606            #(#members: #field_init_expr),*
607        }),
608    };
609
610    TokenStream::from(quote! {
611            impl #impl_generics #bevy_ecs_path::world::FromWorld for #name #ty_generics #where_clauses {
612                fn from_world(world: &mut #bevy_ecs_path::world::World) -> Self {
613                    #field_initializers
614                }
615            }
616    })
617}