// bevy_reflect/enums/variants.rs

1use crate::{
2    attributes::{impl_custom_attribute_methods, CustomAttributes},
3    NamedField, UnnamedField,
4};
5use bevy_utils::HashMap;
6use core::slice::Iter;
7
8use alloc::sync::Arc;
9use derive_more::derive::{Display, Error};
10
/// The shape of an enum variant: named fields, positional fields, or none.
///
/// This carries only the kind of a variant; the per-variant field data lives
/// in the corresponding [`VariantInfo`] case.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub enum VariantType {
    /// A variant with named fields, such as `A` below:
    ///
    /// ```
    /// enum MyEnum {
    ///   A {
    ///     foo: usize
    ///   }
    /// }
    /// ```
    Struct,
    /// A variant with positional (unnamed) fields, such as `A` below:
    ///
    /// ```
    /// enum MyEnum {
    ///   A(usize)
    /// }
    /// ```
    Tuple,
    /// A variant with no fields at all, such as `A` below:
    ///
    /// ```
    /// enum MyEnum {
    ///   A
    /// }
    /// ```
    Unit,
}
41
42/// A [`VariantInfo`]-specific error.
43#[derive(Debug, Error, Display)]
44pub enum VariantInfoError {
45    /// Caused when a variant was expected to be of a certain [type], but was not.
46    ///
47    /// [type]: VariantType
48    #[display("variant type mismatch: expected {expected:?}, received {received:?}")]
49    TypeMismatch {
50        expected: VariantType,
51        received: VariantType,
52    },
53}
54
55/// A container for compile-time enum variant info.
56#[derive(Clone, Debug)]
57pub enum VariantInfo {
58    /// Struct enums take the form:
59    ///
60    /// ```
61    /// enum MyEnum {
62    ///   A {
63    ///     foo: usize
64    ///   }
65    /// }
66    /// ```
67    Struct(StructVariantInfo),
68    /// Tuple enums take the form:
69    ///
70    /// ```
71    /// enum MyEnum {
72    ///   A(usize)
73    /// }
74    /// ```
75    Tuple(TupleVariantInfo),
76    /// Unit enums take the form:
77    ///
78    /// ```
79    /// enum MyEnum {
80    ///   A
81    /// }
82    /// ```
83    Unit(UnitVariantInfo),
84}
85
86impl VariantInfo {
87    pub fn name(&self) -> &'static str {
88        match self {
89            Self::Struct(info) => info.name(),
90            Self::Tuple(info) => info.name(),
91            Self::Unit(info) => info.name(),
92        }
93    }
94
95    /// The docstring of the underlying variant, if any.
96    #[cfg(feature = "documentation")]
97    pub fn docs(&self) -> Option<&str> {
98        match self {
99            Self::Struct(info) => info.docs(),
100            Self::Tuple(info) => info.docs(),
101            Self::Unit(info) => info.docs(),
102        }
103    }
104
105    /// Returns the [type] of this variant.
106    ///
107    /// [type]: VariantType
108    pub fn variant_type(&self) -> VariantType {
109        match self {
110            Self::Struct(_) => VariantType::Struct,
111            Self::Tuple(_) => VariantType::Tuple,
112            Self::Unit(_) => VariantType::Unit,
113        }
114    }
115
116    impl_custom_attribute_methods!(
117        self,
118        match self {
119            Self::Struct(info) => info.custom_attributes(),
120            Self::Tuple(info) => info.custom_attributes(),
121            Self::Unit(info) => info.custom_attributes(),
122        },
123        "variant"
124    );
125}
126
/// Generates a fallible accessor that borrows the inner info of one
/// [`VariantInfo`] case.
///
/// Invocation shape: `impl_cast_method!(method_name: Variant => InnerInfo);`.
/// The generated method yields `Ok(&InnerInfo)` when `self` is
/// `Self::Variant`, and otherwise a [`VariantInfoError::TypeMismatch`] naming
/// both the requested and the actual [`VariantType`].
macro_rules! impl_cast_method {
    ($name:ident : $kind:ident => $info:ident) => {
        #[doc = concat!("Attempts a cast to [`", stringify!($info), "`].")]
        #[doc = concat!("\n\nReturns an error if `self` is not [`VariantInfo::", stringify!($kind), "`].")]
        pub fn $name(&self) -> Result<&$info, VariantInfoError> {
            if let Self::$kind(inner) = self {
                Ok(inner)
            } else {
                // Report what was asked for alongside what is actually stored.
                Err(VariantInfoError::TypeMismatch {
                    expected: VariantType::$kind,
                    received: self.variant_type(),
                })
            }
        }
    };
}
142
143/// Conversion convenience methods for [`VariantInfo`].
144impl VariantInfo {
145    impl_cast_method!(as_struct_variant: Struct => StructVariantInfo);
146    impl_cast_method!(as_tuple_variant: Tuple => TupleVariantInfo);
147    impl_cast_method!(as_unit_variant: Unit => UnitVariantInfo);
148}
149
150/// Type info for struct variants.
151#[derive(Clone, Debug)]
152pub struct StructVariantInfo {
153    name: &'static str,
154    fields: Box<[NamedField]>,
155    field_names: Box<[&'static str]>,
156    field_indices: HashMap<&'static str, usize>,
157    custom_attributes: Arc<CustomAttributes>,
158    #[cfg(feature = "documentation")]
159    docs: Option<&'static str>,
160}
161
162impl StructVariantInfo {
163    /// Create a new [`StructVariantInfo`].
164    pub fn new(name: &'static str, fields: &[NamedField]) -> Self {
165        let field_indices = Self::collect_field_indices(fields);
166        let field_names = fields.iter().map(NamedField::name).collect();
167        Self {
168            name,
169            fields: fields.to_vec().into_boxed_slice(),
170            field_names,
171            field_indices,
172            custom_attributes: Arc::new(CustomAttributes::default()),
173            #[cfg(feature = "documentation")]
174            docs: None,
175        }
176    }
177
178    /// Sets the docstring for this variant.
179    #[cfg(feature = "documentation")]
180    pub fn with_docs(self, docs: Option<&'static str>) -> Self {
181        Self { docs, ..self }
182    }
183
184    /// Sets the custom attributes for this variant.
185    pub fn with_custom_attributes(self, custom_attributes: CustomAttributes) -> Self {
186        Self {
187            custom_attributes: Arc::new(custom_attributes),
188            ..self
189        }
190    }
191
192    /// The name of this variant.
193    pub fn name(&self) -> &'static str {
194        self.name
195    }
196
197    /// A slice containing the names of all fields in order.
198    pub fn field_names(&self) -> &[&'static str] {
199        &self.field_names
200    }
201
202    /// Get the field with the given name.
203    pub fn field(&self, name: &str) -> Option<&NamedField> {
204        self.field_indices
205            .get(name)
206            .map(|index| &self.fields[*index])
207    }
208
209    /// Get the field at the given index.
210    pub fn field_at(&self, index: usize) -> Option<&NamedField> {
211        self.fields.get(index)
212    }
213
214    /// Get the index of the field with the given name.
215    pub fn index_of(&self, name: &str) -> Option<usize> {
216        self.field_indices.get(name).copied()
217    }
218
219    /// Iterate over the fields of this variant.
220    pub fn iter(&self) -> Iter<'_, NamedField> {
221        self.fields.iter()
222    }
223
224    /// The total number of fields in this variant.
225    pub fn field_len(&self) -> usize {
226        self.fields.len()
227    }
228
229    fn collect_field_indices(fields: &[NamedField]) -> HashMap<&'static str, usize> {
230        fields
231            .iter()
232            .enumerate()
233            .map(|(index, field)| (field.name(), index))
234            .collect()
235    }
236
237    /// The docstring of this variant, if any.
238    #[cfg(feature = "documentation")]
239    pub fn docs(&self) -> Option<&'static str> {
240        self.docs
241    }
242
243    impl_custom_attribute_methods!(self.custom_attributes, "variant");
244}
245
246/// Type info for tuple variants.
247#[derive(Clone, Debug)]
248pub struct TupleVariantInfo {
249    name: &'static str,
250    fields: Box<[UnnamedField]>,
251    custom_attributes: Arc<CustomAttributes>,
252    #[cfg(feature = "documentation")]
253    docs: Option<&'static str>,
254}
255
256impl TupleVariantInfo {
257    /// Create a new [`TupleVariantInfo`].
258    pub fn new(name: &'static str, fields: &[UnnamedField]) -> Self {
259        Self {
260            name,
261            fields: fields.to_vec().into_boxed_slice(),
262            custom_attributes: Arc::new(CustomAttributes::default()),
263            #[cfg(feature = "documentation")]
264            docs: None,
265        }
266    }
267
268    /// Sets the docstring for this variant.
269    #[cfg(feature = "documentation")]
270    pub fn with_docs(self, docs: Option<&'static str>) -> Self {
271        Self { docs, ..self }
272    }
273
274    /// Sets the custom attributes for this variant.
275    pub fn with_custom_attributes(self, custom_attributes: CustomAttributes) -> Self {
276        Self {
277            custom_attributes: Arc::new(custom_attributes),
278            ..self
279        }
280    }
281
282    /// The name of this variant.
283    pub fn name(&self) -> &'static str {
284        self.name
285    }
286
287    /// Get the field at the given index.
288    pub fn field_at(&self, index: usize) -> Option<&UnnamedField> {
289        self.fields.get(index)
290    }
291
292    /// Iterate over the fields of this variant.
293    pub fn iter(&self) -> Iter<'_, UnnamedField> {
294        self.fields.iter()
295    }
296
297    /// The total number of fields in this variant.
298    pub fn field_len(&self) -> usize {
299        self.fields.len()
300    }
301
302    /// The docstring of this variant, if any.
303    #[cfg(feature = "documentation")]
304    pub fn docs(&self) -> Option<&'static str> {
305        self.docs
306    }
307
308    impl_custom_attribute_methods!(self.custom_attributes, "variant");
309}
310
311/// Type info for unit variants.
312#[derive(Clone, Debug)]
313pub struct UnitVariantInfo {
314    name: &'static str,
315    custom_attributes: Arc<CustomAttributes>,
316    #[cfg(feature = "documentation")]
317    docs: Option<&'static str>,
318}
319
320impl UnitVariantInfo {
321    /// Create a new [`UnitVariantInfo`].
322    pub fn new(name: &'static str) -> Self {
323        Self {
324            name,
325            custom_attributes: Arc::new(CustomAttributes::default()),
326            #[cfg(feature = "documentation")]
327            docs: None,
328        }
329    }
330
331    /// Sets the docstring for this variant.
332    #[cfg(feature = "documentation")]
333    pub fn with_docs(self, docs: Option<&'static str>) -> Self {
334        Self { docs, ..self }
335    }
336
337    /// Sets the custom attributes for this variant.
338    pub fn with_custom_attributes(self, custom_attributes: CustomAttributes) -> Self {
339        Self {
340            custom_attributes: Arc::new(custom_attributes),
341            ..self
342        }
343    }
344
345    /// The name of this variant.
346    pub fn name(&self) -> &'static str {
347        self.name
348    }
349
350    /// The docstring of this variant, if any.
351    #[cfg(feature = "documentation")]
352    pub fn docs(&self) -> Option<&'static str> {
353        self.docs
354    }
355
356    impl_custom_attribute_methods!(self.custom_attributes, "variant");
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate as bevy_reflect;
    use crate::{Reflect, Typed};

    #[test]
    fn should_return_error_on_invalid_cast() {
        #[derive(Reflect)]
        enum Foo {
            Bar,
        }

        let info = Foo::type_info().as_enum().unwrap();
        let variant = info.variant_at(0).unwrap();
        assert!(matches!(
            variant.as_tuple_variant(),
            Err(VariantInfoError::TypeMismatch {
                expected: VariantType::Tuple,
                received: VariantType::Unit
            })
        ));
    }

    /// Covers the success path of every cast plus the per-shape field accessors,
    /// which the mismatch test above does not exercise.
    #[test]
    fn should_cast_to_matching_variant_info() {
        #[derive(Reflect)]
        enum Foo {
            A,
            B(usize),
            C { value: usize },
        }

        let info = Foo::type_info().as_enum().unwrap();

        let unit = info.variant_at(0).unwrap();
        assert_eq!(VariantType::Unit, unit.variant_type());
        assert_eq!("A", unit.name());
        assert_eq!("A", unit.as_unit_variant().unwrap().name());

        let tuple_variant = info.variant_at(1).unwrap();
        assert_eq!(VariantType::Tuple, tuple_variant.variant_type());
        assert!(matches!(
            tuple_variant.as_struct_variant(),
            Err(VariantInfoError::TypeMismatch {
                expected: VariantType::Struct,
                received: VariantType::Tuple
            })
        ));
        let tuple = tuple_variant.as_tuple_variant().unwrap();
        assert_eq!("B", tuple.name());
        assert_eq!(1, tuple.field_len());
        assert!(tuple.field_at(0).is_some());
        assert!(tuple.field_at(1).is_none());
        assert_eq!(1, tuple.iter().count());

        let structure = info.variant_at(2).unwrap().as_struct_variant().unwrap();
        assert_eq!("C", structure.name());
        assert_eq!(&["value"][..], structure.field_names());
        assert_eq!(Some(0), structure.index_of("value"));
        assert_eq!(
            Some("value"),
            structure.field("value").map(|field| field.name())
        );
        assert!(structure.field("missing").is_none());
        assert_eq!(None, structure.index_of("missing"));
        assert_eq!(1, structure.field_len());
        assert_eq!(1, structure.iter().count());
    }
}
382}