bevy_reflect/enums/
variants.rs

1use crate::{
2    attributes::{impl_custom_attribute_methods, CustomAttributes},
3    NamedField, UnnamedField,
4};
5use alloc::boxed::Box;
6use bevy_platform::collections::HashMap;
7use bevy_platform::sync::Arc;
8use core::slice::Iter;
9use thiserror::Error;
10
/// The shape an enum variant can take.
///
/// There is one case per variant syntax: named fields, positional fields,
/// or no fields at all.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum VariantType {
    /// A variant whose fields are named:
    ///
    /// ```
    /// enum MyEnum {
    ///   A {
    ///     foo: usize
    ///   }
    /// }
    /// ```
    Struct,
    /// A variant whose fields are positional:
    ///
    /// ```
    /// enum MyEnum {
    ///   A(usize)
    /// }
    /// ```
    Tuple,
    /// A variant that carries no fields:
    ///
    /// ```
    /// enum MyEnum {
    ///   A
    /// }
    /// ```
    Unit,
}
41
42/// A [`VariantInfo`]-specific error.
43#[derive(Debug, Error)]
44pub enum VariantInfoError {
45    /// Caused when a variant was expected to be of a certain [type], but was not.
46    ///
47    /// [type]: VariantType
48    #[error("variant type mismatch: expected {expected:?}, received {received:?}")]
49    TypeMismatch {
50        /// Expected variant type.
51        expected: VariantType,
52        /// Received variant type.
53        received: VariantType,
54    },
55}
56
57/// A container for compile-time enum variant info.
58#[derive(Clone, Debug)]
59pub enum VariantInfo {
60    /// Struct enums take the form:
61    ///
62    /// ```
63    /// enum MyEnum {
64    ///   A {
65    ///     foo: usize
66    ///   }
67    /// }
68    /// ```
69    Struct(StructVariantInfo),
70    /// Tuple enums take the form:
71    ///
72    /// ```
73    /// enum MyEnum {
74    ///   A(usize)
75    /// }
76    /// ```
77    Tuple(TupleVariantInfo),
78    /// Unit enums take the form:
79    ///
80    /// ```
81    /// enum MyEnum {
82    ///   A
83    /// }
84    /// ```
85    Unit(UnitVariantInfo),
86}
87
88impl VariantInfo {
89    /// The name of the enum variant.
90    pub fn name(&self) -> &'static str {
91        match self {
92            Self::Struct(info) => info.name(),
93            Self::Tuple(info) => info.name(),
94            Self::Unit(info) => info.name(),
95        }
96    }
97
98    /// The docstring of the underlying variant, if any.
99    #[cfg(feature = "documentation")]
100    pub fn docs(&self) -> Option<&str> {
101        match self {
102            Self::Struct(info) => info.docs(),
103            Self::Tuple(info) => info.docs(),
104            Self::Unit(info) => info.docs(),
105        }
106    }
107
108    /// Returns the [type] of this variant.
109    ///
110    /// [type]: VariantType
111    pub fn variant_type(&self) -> VariantType {
112        match self {
113            Self::Struct(_) => VariantType::Struct,
114            Self::Tuple(_) => VariantType::Tuple,
115            Self::Unit(_) => VariantType::Unit,
116        }
117    }
118
119    impl_custom_attribute_methods!(
120        self,
121        match self {
122            Self::Struct(info) => info.custom_attributes(),
123            Self::Tuple(info) => info.custom_attributes(),
124            Self::Unit(info) => info.custom_attributes(),
125        },
126        "variant"
127    );
128}
129
130macro_rules! impl_cast_method {
131    ($name:ident : $kind:ident => $info:ident) => {
132        #[doc = concat!("Attempts a cast to [`", stringify!($info), "`].")]
133        #[doc = concat!("\n\nReturns an error if `self` is not [`VariantInfo::", stringify!($kind), "`].")]
134        pub fn $name(&self) -> Result<&$info, VariantInfoError> {
135            match self {
136                Self::$kind(info) => Ok(info),
137                _ => Err(VariantInfoError::TypeMismatch {
138                    expected: VariantType::$kind,
139                    received: self.variant_type(),
140                }),
141            }
142        }
143    };
144}
145
146/// Conversion convenience methods for [`VariantInfo`].
147impl VariantInfo {
148    impl_cast_method!(as_struct_variant: Struct => StructVariantInfo);
149    impl_cast_method!(as_tuple_variant: Tuple => TupleVariantInfo);
150    impl_cast_method!(as_unit_variant: Unit => UnitVariantInfo);
151}
152
153/// Type info for struct variants.
154#[derive(Clone, Debug)]
155pub struct StructVariantInfo {
156    name: &'static str,
157    fields: Box<[NamedField]>,
158    field_names: Box<[&'static str]>,
159    field_indices: HashMap<&'static str, usize>,
160    custom_attributes: Arc<CustomAttributes>,
161    #[cfg(feature = "documentation")]
162    docs: Option<&'static str>,
163}
164
165impl StructVariantInfo {
166    /// Create a new [`StructVariantInfo`].
167    pub fn new(name: &'static str, fields: &[NamedField]) -> Self {
168        let field_indices = Self::collect_field_indices(fields);
169        let field_names = fields.iter().map(NamedField::name).collect();
170        Self {
171            name,
172            fields: fields.to_vec().into_boxed_slice(),
173            field_names,
174            field_indices,
175            custom_attributes: Arc::new(CustomAttributes::default()),
176            #[cfg(feature = "documentation")]
177            docs: None,
178        }
179    }
180
181    /// Sets the docstring for this variant.
182    #[cfg(feature = "documentation")]
183    pub fn with_docs(self, docs: Option<&'static str>) -> Self {
184        Self { docs, ..self }
185    }
186
187    /// Sets the custom attributes for this variant.
188    pub fn with_custom_attributes(self, custom_attributes: CustomAttributes) -> Self {
189        Self {
190            custom_attributes: Arc::new(custom_attributes),
191            ..self
192        }
193    }
194
195    /// The name of this variant.
196    pub fn name(&self) -> &'static str {
197        self.name
198    }
199
200    /// A slice containing the names of all fields in order.
201    pub fn field_names(&self) -> &[&'static str] {
202        &self.field_names
203    }
204
205    /// Get the field with the given name.
206    pub fn field(&self, name: &str) -> Option<&NamedField> {
207        self.field_indices
208            .get(name)
209            .map(|index| &self.fields[*index])
210    }
211
212    /// Get the field at the given index.
213    pub fn field_at(&self, index: usize) -> Option<&NamedField> {
214        self.fields.get(index)
215    }
216
217    /// Get the index of the field with the given name.
218    pub fn index_of(&self, name: &str) -> Option<usize> {
219        self.field_indices.get(name).copied()
220    }
221
222    /// Iterate over the fields of this variant.
223    pub fn iter(&self) -> Iter<'_, NamedField> {
224        self.fields.iter()
225    }
226
227    /// The total number of fields in this variant.
228    pub fn field_len(&self) -> usize {
229        self.fields.len()
230    }
231
232    fn collect_field_indices(fields: &[NamedField]) -> HashMap<&'static str, usize> {
233        fields
234            .iter()
235            .enumerate()
236            .map(|(index, field)| (field.name(), index))
237            .collect()
238    }
239
240    /// The docstring of this variant, if any.
241    #[cfg(feature = "documentation")]
242    pub fn docs(&self) -> Option<&'static str> {
243        self.docs
244    }
245
246    impl_custom_attribute_methods!(self.custom_attributes, "variant");
247}
248
249/// Type info for tuple variants.
250#[derive(Clone, Debug)]
251pub struct TupleVariantInfo {
252    name: &'static str,
253    fields: Box<[UnnamedField]>,
254    custom_attributes: Arc<CustomAttributes>,
255    #[cfg(feature = "documentation")]
256    docs: Option<&'static str>,
257}
258
259impl TupleVariantInfo {
260    /// Create a new [`TupleVariantInfo`].
261    pub fn new(name: &'static str, fields: &[UnnamedField]) -> Self {
262        Self {
263            name,
264            fields: fields.to_vec().into_boxed_slice(),
265            custom_attributes: Arc::new(CustomAttributes::default()),
266            #[cfg(feature = "documentation")]
267            docs: None,
268        }
269    }
270
271    /// Sets the docstring for this variant.
272    #[cfg(feature = "documentation")]
273    pub fn with_docs(self, docs: Option<&'static str>) -> Self {
274        Self { docs, ..self }
275    }
276
277    /// Sets the custom attributes for this variant.
278    pub fn with_custom_attributes(self, custom_attributes: CustomAttributes) -> Self {
279        Self {
280            custom_attributes: Arc::new(custom_attributes),
281            ..self
282        }
283    }
284
285    /// The name of this variant.
286    pub fn name(&self) -> &'static str {
287        self.name
288    }
289
290    /// Get the field at the given index.
291    pub fn field_at(&self, index: usize) -> Option<&UnnamedField> {
292        self.fields.get(index)
293    }
294
295    /// Iterate over the fields of this variant.
296    pub fn iter(&self) -> Iter<'_, UnnamedField> {
297        self.fields.iter()
298    }
299
300    /// The total number of fields in this variant.
301    pub fn field_len(&self) -> usize {
302        self.fields.len()
303    }
304
305    /// The docstring of this variant, if any.
306    #[cfg(feature = "documentation")]
307    pub fn docs(&self) -> Option<&'static str> {
308        self.docs
309    }
310
311    impl_custom_attribute_methods!(self.custom_attributes, "variant");
312}
313
314/// Type info for unit variants.
315#[derive(Clone, Debug)]
316pub struct UnitVariantInfo {
317    name: &'static str,
318    custom_attributes: Arc<CustomAttributes>,
319    #[cfg(feature = "documentation")]
320    docs: Option<&'static str>,
321}
322
323impl UnitVariantInfo {
324    /// Create a new [`UnitVariantInfo`].
325    pub fn new(name: &'static str) -> Self {
326        Self {
327            name,
328            custom_attributes: Arc::new(CustomAttributes::default()),
329            #[cfg(feature = "documentation")]
330            docs: None,
331        }
332    }
333
334    /// Sets the docstring for this variant.
335    #[cfg(feature = "documentation")]
336    pub fn with_docs(self, docs: Option<&'static str>) -> Self {
337        Self { docs, ..self }
338    }
339
340    /// Sets the custom attributes for this variant.
341    pub fn with_custom_attributes(self, custom_attributes: CustomAttributes) -> Self {
342        Self {
343            custom_attributes: Arc::new(custom_attributes),
344            ..self
345        }
346    }
347
348    /// The name of this variant.
349    pub fn name(&self) -> &'static str {
350        self.name
351    }
352
353    /// The docstring of this variant, if any.
354    #[cfg(feature = "documentation")]
355    pub fn docs(&self) -> Option<&'static str> {
356        self.docs
357    }
358
359    impl_custom_attribute_methods!(self.custom_attributes, "variant");
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Reflect, Typed};

    /// Casting a unit variant to a tuple variant must fail with a
    /// `TypeMismatch` that names both the requested and the actual type.
    #[test]
    fn should_return_error_on_invalid_cast() {
        #[derive(Reflect)]
        enum Foo {
            Bar,
        }

        let enum_info = Foo::type_info().as_enum().unwrap();
        // `Bar` is the only variant, so it sits at index 0.
        let bar = enum_info.variant_at(0).unwrap();
        let cast_result = bar.as_tuple_variant();

        assert!(matches!(
            cast_result,
            Err(VariantInfoError::TypeMismatch {
                expected: VariantType::Tuple,
                received: VariantType::Unit
            })
        ));
    }
}