1use crate::arena::{Handle, HandleVec};
2use std::{fmt::Display, num::NonZeroU32, ops};
3
/// A power-of-two alignment stored as a non-zero `u32`.
///
/// Values produced by [`Alignment::new`] and the associated constants are
/// always powers of two; [`Alignment::is_aligned`] and
/// [`Alignment::round_up`] rely on that to work with bit masks.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct Alignment(NonZeroU32);

impl Alignment {
    // SAFETY (applies to each constant below): the literal handed to
    // `new_unchecked` is non-zero.
    pub const ONE: Self = Self(unsafe { NonZeroU32::new_unchecked(1) });
    pub const TWO: Self = Self(unsafe { NonZeroU32::new_unchecked(2) });
    pub const FOUR: Self = Self(unsafe { NonZeroU32::new_unchecked(4) });
    pub const EIGHT: Self = Self(unsafe { NonZeroU32::new_unchecked(8) });
    pub const SIXTEEN: Self = Self(unsafe { NonZeroU32::new_unchecked(16) });

    /// Same value as [`Alignment::SIXTEEN`].
    ///
    /// NOTE(review): the name suggests this is the minimum alignment used for
    /// uniform data; that cannot be confirmed from this file.
    pub const MIN_UNIFORM: Self = Self::SIXTEEN;

    /// Builds an alignment from `n`.
    ///
    /// Returns `None` unless `n` is a power of two; zero is not a power of
    /// two, so `new(0)` is `None`.
    pub const fn new(n: u32) -> Option<Self> {
        if !n.is_power_of_two() {
            return None;
        }
        // A power of two is never zero, so the `None` arm is not expected to
        // be taken here; going through `NonZeroU32::new` keeps this safe code.
        match NonZeroU32::new(n) {
            Some(value) => Some(Self(value)),
            None => None,
        }
    }

    /// Builds an alignment equal to `width`.
    ///
    /// # Panics
    /// Panics if `width` is not a power of two (this includes zero).
    pub fn from_width(width: u8) -> Self {
        Self::new(u32::from(width)).unwrap()
    }

    /// Returns `true` when `n` is a multiple of this alignment.
    pub const fn is_aligned(&self, n: u32) -> bool {
        // For a power-of-two alignment, the bits below it must all be zero.
        let low_bits = self.0.get() - 1;
        (n & low_bits) == 0
    }

    /// Rounds `n` up to the next multiple of this alignment.
    ///
    /// The intermediate `n + (alignment - 1)` uses plain `+`, so it overflows
    /// when `n` is within `alignment - 1` of `u32::MAX`.
    pub const fn round_up(&self, n: u32) -> u32 {
        let low_bits = self.0.get() - 1;
        let bumped = n + low_bits;
        bumped & !low_bits
    }
}
51
52impl Display for Alignment {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 self.0.get().fmt(f)
55 }
56}
57
58impl ops::Mul<u32> for Alignment {
59 type Output = u32;
60
61 fn mul(self, rhs: u32) -> Self::Output {
62 self.0.get() * rhs
63 }
64}
65
66impl ops::Mul for Alignment {
67 type Output = Alignment;
68
69 fn mul(self, rhs: Alignment) -> Self::Output {
70 Self(unsafe { NonZeroU32::new_unchecked(self.0.get() * rhs.0.get()) })
72 }
73}
74
75impl From<crate::VectorSize> for Alignment {
76 fn from(size: crate::VectorSize) -> Self {
77 match size {
78 crate::VectorSize::Bi => Alignment::TWO,
79 crate::VectorSize::Tri => Alignment::FOUR,
80 crate::VectorSize::Quad => Alignment::FOUR,
81 }
82 }
83}
84
85#[derive(Clone, Copy, Debug, Hash, PartialEq)]
87#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
88#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
89pub struct TypeLayout {
90 pub size: u32,
91 pub alignment: Alignment,
92}
93
94impl TypeLayout {
95 pub const fn to_stride(&self) -> u32 {
97 self.alignment.round_up(self.size)
98 }
99}
100
/// Holds the [`TypeLayout`] computed for each type handle.
///
/// Layouts are filled in by `Layouter::update` and read back by indexing the
/// layouter with a `Handle<crate::Type>`.
#[derive(Debug, Default)]
pub struct Layouter {
    /// Computed layouts, keyed by type handle.
    layouts: HandleVec<crate::Type, TypeLayout>,
}
115
116impl ops::Index<Handle<crate::Type>> for Layouter {
117 type Output = TypeLayout;
118 fn index(&self, handle: Handle<crate::Type>) -> &TypeLayout {
119 &self.layouts[handle]
120 }
121}
122
/// The specific reason a type's layout could not be computed.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
pub enum LayoutErrorInner {
    /// Reported by `Layouter::update` when an array's base type handle does
    /// not come before the array type's own handle.
    #[error("Array element type {0:?} doesn't exist")]
    InvalidArrayElementType(Handle<crate::Type>),
    /// Reported by `Layouter::update` when a struct member's type handle does
    /// not come before the struct type's own handle. Carries the member index
    /// and the offending handle.
    #[error("Struct member[{0}] type {1:?} doesn't exist")]
    InvalidStructMemberType(u32, Handle<crate::Type>),
    /// A scalar width was rejected by `Alignment::new` because it is not a
    /// power of two.
    #[error("Type width must be a power of two")]
    NonPowerOfTwoWidth,
}
132
/// A layout failure paired with the handle of the type being laid out.
#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)]
#[error("Error laying out type {ty:?}: {inner}")]
pub struct LayoutError {
    /// Handle of the type whose layout computation failed.
    pub ty: Handle<crate::Type>,
    /// The underlying reason for the failure.
    pub inner: LayoutErrorInner,
}
139
140impl LayoutErrorInner {
141 const fn with(self, ty: Handle<crate::Type>) -> LayoutError {
142 LayoutError { ty, inner: self }
143 }
144}
145
146impl Layouter {
147 pub fn clear(&mut self) {
149 self.layouts.clear();
150 }
151
152 #[allow(clippy::or_fun_call)]
166 pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> {
167 use crate::TypeInner as Ti;
168
169 for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) {
170 let size = ty.inner.size(gctx);
171 let layout = match ty.inner {
172 Ti::Scalar(scalar) | Ti::Atomic(scalar) => {
173 let alignment = Alignment::new(scalar.width as u32)
174 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
175 TypeLayout { size, alignment }
176 }
177 Ti::Vector {
178 size: vec_size,
179 scalar,
180 } => {
181 let alignment = Alignment::new(scalar.width as u32)
182 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
183 TypeLayout {
184 size,
185 alignment: Alignment::from(vec_size) * alignment,
186 }
187 }
188 Ti::Matrix {
189 columns: _,
190 rows,
191 scalar,
192 } => {
193 let alignment = Alignment::new(scalar.width as u32)
194 .ok_or(LayoutErrorInner::NonPowerOfTwoWidth.with(ty_handle))?;
195 TypeLayout {
196 size,
197 alignment: Alignment::from(rows) * alignment,
198 }
199 }
200 Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout {
201 size,
202 alignment: Alignment::ONE,
203 },
204 Ti::Array {
205 base,
206 stride: _,
207 size: _,
208 } => TypeLayout {
209 size,
210 alignment: if base < ty_handle {
211 self[base].alignment
212 } else {
213 return Err(LayoutErrorInner::InvalidArrayElementType(base).with(ty_handle));
214 },
215 },
216 Ti::Struct { span, ref members } => {
217 let mut alignment = Alignment::ONE;
218 for (index, member) in members.iter().enumerate() {
219 alignment = if member.ty < ty_handle {
220 alignment.max(self[member.ty].alignment)
221 } else {
222 return Err(LayoutErrorInner::InvalidStructMemberType(
223 index as u32,
224 member.ty,
225 )
226 .with(ty_handle));
227 };
228 }
229 TypeLayout {
230 size: span,
231 alignment,
232 }
233 }
234 Ti::Image { .. }
235 | Ti::Sampler { .. }
236 | Ti::AccelerationStructure
237 | Ti::RayQuery
238 | Ti::BindingArray { .. } => TypeLayout {
239 size,
240 alignment: Alignment::ONE,
241 },
242 };
243 debug_assert!(size <= layout.size);
244 self.layouts.insert(ty_handle, layout);
245 }
246
247 Ok(())
248 }
249}