bevy_render/render_graph/
context.rs

1use crate::{
2    render_graph::{NodeState, RenderGraph, SlotInfos, SlotLabel, SlotType, SlotValue},
3    render_resource::{Buffer, Sampler, TextureView},
4};
5use alloc::borrow::Cow;
6use bevy_ecs::entity::Entity;
7use derive_more::derive::{Display, Error};
8
9use super::{InternedRenderSubGraph, RenderSubGraph};
10
11/// A command that signals the graph runner to run the sub graph corresponding to the `sub_graph`
12/// with the specified `inputs` next.
13pub struct RunSubGraph {
14    pub sub_graph: InternedRenderSubGraph,
15    pub inputs: Vec<SlotValue>,
16    pub view_entity: Option<Entity>,
17}
18
19/// The context with all graph information required to run a [`Node`](super::Node).
20/// This context is created for each node by the render graph runner.
21///
22/// The slot input can be read from here and the outputs must be written back to the context for
23/// passing them onto the next node.
24///
25/// Sub graphs can be queued for running by adding a [`RunSubGraph`] command to the context.
26/// After the node has finished running the graph runner is responsible for executing the sub graphs.
27pub struct RenderGraphContext<'a> {
28    graph: &'a RenderGraph,
29    node: &'a NodeState,
30    inputs: &'a [SlotValue],
31    outputs: &'a mut [Option<SlotValue>],
32    run_sub_graphs: Vec<RunSubGraph>,
33    /// The `view_entity` associated with the render graph being executed
34    /// This is optional because you aren't required to have a `view_entity` for a node.
35    /// For example, compute shader nodes don't have one.
36    /// It should always be set when the [`RenderGraph`] is running on a View.
37    view_entity: Option<Entity>,
38}
39
40impl<'a> RenderGraphContext<'a> {
41    /// Creates a new render graph context for the `node`.
42    pub fn new(
43        graph: &'a RenderGraph,
44        node: &'a NodeState,
45        inputs: &'a [SlotValue],
46        outputs: &'a mut [Option<SlotValue>],
47    ) -> Self {
48        Self {
49            graph,
50            node,
51            inputs,
52            outputs,
53            run_sub_graphs: Vec::new(),
54            view_entity: None,
55        }
56    }
57
58    /// Returns the input slot values for the node.
59    #[inline]
60    pub fn inputs(&self) -> &[SlotValue] {
61        self.inputs
62    }
63
64    /// Returns the [`SlotInfos`] of the inputs.
65    pub fn input_info(&self) -> &SlotInfos {
66        &self.node.input_slots
67    }
68
69    /// Returns the [`SlotInfos`] of the outputs.
70    pub fn output_info(&self) -> &SlotInfos {
71        &self.node.output_slots
72    }
73
74    /// Retrieves the input slot value referenced by the `label`.
75    pub fn get_input(&self, label: impl Into<SlotLabel>) -> Result<&SlotValue, InputSlotError> {
76        let label = label.into();
77        let index = self
78            .input_info()
79            .get_slot_index(label.clone())
80            .ok_or(InputSlotError::InvalidSlot(label))?;
81        Ok(&self.inputs[index])
82    }
83
84    // TODO: should this return an Arc or a reference?
85    /// Retrieves the input slot value referenced by the `label` as a [`TextureView`].
86    pub fn get_input_texture(
87        &self,
88        label: impl Into<SlotLabel>,
89    ) -> Result<&TextureView, InputSlotError> {
90        let label = label.into();
91        match self.get_input(label.clone())? {
92            SlotValue::TextureView(value) => Ok(value),
93            value => Err(InputSlotError::MismatchedSlotType {
94                label,
95                actual: value.slot_type(),
96                expected: SlotType::TextureView,
97            }),
98        }
99    }
100
101    /// Retrieves the input slot value referenced by the `label` as a [`Sampler`].
102    pub fn get_input_sampler(
103        &self,
104        label: impl Into<SlotLabel>,
105    ) -> Result<&Sampler, InputSlotError> {
106        let label = label.into();
107        match self.get_input(label.clone())? {
108            SlotValue::Sampler(value) => Ok(value),
109            value => Err(InputSlotError::MismatchedSlotType {
110                label,
111                actual: value.slot_type(),
112                expected: SlotType::Sampler,
113            }),
114        }
115    }
116
117    /// Retrieves the input slot value referenced by the `label` as a [`Buffer`].
118    pub fn get_input_buffer(&self, label: impl Into<SlotLabel>) -> Result<&Buffer, InputSlotError> {
119        let label = label.into();
120        match self.get_input(label.clone())? {
121            SlotValue::Buffer(value) => Ok(value),
122            value => Err(InputSlotError::MismatchedSlotType {
123                label,
124                actual: value.slot_type(),
125                expected: SlotType::Buffer,
126            }),
127        }
128    }
129
130    /// Retrieves the input slot value referenced by the `label` as an [`Entity`].
131    pub fn get_input_entity(&self, label: impl Into<SlotLabel>) -> Result<Entity, InputSlotError> {
132        let label = label.into();
133        match self.get_input(label.clone())? {
134            SlotValue::Entity(value) => Ok(*value),
135            value => Err(InputSlotError::MismatchedSlotType {
136                label,
137                actual: value.slot_type(),
138                expected: SlotType::Entity,
139            }),
140        }
141    }
142
143    /// Sets the output slot value referenced by the `label`.
144    pub fn set_output(
145        &mut self,
146        label: impl Into<SlotLabel>,
147        value: impl Into<SlotValue>,
148    ) -> Result<(), OutputSlotError> {
149        let label = label.into();
150        let value = value.into();
151        let slot_index = self
152            .output_info()
153            .get_slot_index(label.clone())
154            .ok_or_else(|| OutputSlotError::InvalidSlot(label.clone()))?;
155        let slot = self
156            .output_info()
157            .get_slot(slot_index)
158            .expect("slot is valid");
159        if value.slot_type() != slot.slot_type {
160            return Err(OutputSlotError::MismatchedSlotType {
161                label,
162                actual: slot.slot_type,
163                expected: value.slot_type(),
164            });
165        }
166        self.outputs[slot_index] = Some(value);
167        Ok(())
168    }
169
170    pub fn view_entity(&self) -> Entity {
171        self.view_entity.unwrap()
172    }
173
174    pub fn get_view_entity(&self) -> Option<Entity> {
175        self.view_entity
176    }
177
178    pub fn set_view_entity(&mut self, view_entity: Entity) {
179        self.view_entity = Some(view_entity);
180    }
181
182    /// Queues up a sub graph for execution after the node has finished running.
183    pub fn run_sub_graph(
184        &mut self,
185        name: impl RenderSubGraph,
186        inputs: Vec<SlotValue>,
187        view_entity: Option<Entity>,
188    ) -> Result<(), RunSubGraphError> {
189        let name = name.intern();
190        let sub_graph = self
191            .graph
192            .get_sub_graph(name)
193            .ok_or(RunSubGraphError::MissingSubGraph(name))?;
194        if let Some(input_node) = sub_graph.get_input_node() {
195            for (i, input_slot) in input_node.input_slots.iter().enumerate() {
196                if let Some(input_value) = inputs.get(i) {
197                    if input_slot.slot_type != input_value.slot_type() {
198                        return Err(RunSubGraphError::MismatchedInputSlotType {
199                            graph_name: name,
200                            slot_index: i,
201                            actual: input_value.slot_type(),
202                            expected: input_slot.slot_type,
203                            label: input_slot.name.clone().into(),
204                        });
205                    }
206                } else {
207                    return Err(RunSubGraphError::MissingInput {
208                        slot_index: i,
209                        slot_name: input_slot.name.clone(),
210                        graph_name: name,
211                    });
212                }
213            }
214        } else if !inputs.is_empty() {
215            return Err(RunSubGraphError::SubGraphHasNoInputs(name));
216        }
217
218        self.run_sub_graphs.push(RunSubGraph {
219            sub_graph: name,
220            inputs,
221            view_entity,
222        });
223
224        Ok(())
225    }
226
227    /// Finishes the context for this [`Node`](super::Node) by
228    /// returning the sub graphs to run next.
229    pub fn finish(self) -> Vec<RunSubGraph> {
230        self.run_sub_graphs
231    }
232}
233
234#[derive(Error, Display, Debug, Eq, PartialEq)]
235pub enum RunSubGraphError {
236    #[display("attempted to run sub-graph `{_0:?}`, but it does not exist")]
237    #[error(ignore)]
238    MissingSubGraph(InternedRenderSubGraph),
239    #[display("attempted to pass inputs to sub-graph `{_0:?}`, which has no input slots")]
240    #[error(ignore)]
241    SubGraphHasNoInputs(InternedRenderSubGraph),
242    #[display("sub graph (name: `{graph_name:?}`) could not be run because slot `{slot_name}` at index {slot_index} has no value")]
243    MissingInput {
244        slot_index: usize,
245        slot_name: Cow<'static, str>,
246        graph_name: InternedRenderSubGraph,
247    },
248    #[display("attempted to use the wrong type for input slot")]
249    MismatchedInputSlotType {
250        graph_name: InternedRenderSubGraph,
251        slot_index: usize,
252        label: SlotLabel,
253        expected: SlotType,
254        actual: SlotType,
255    },
256}
257
258#[derive(Error, Display, Debug, Eq, PartialEq)]
259pub enum OutputSlotError {
260    #[display("output slot `{_0:?}` does not exist")]
261    #[error(ignore)]
262    InvalidSlot(SlotLabel),
263    #[display("attempted to output a value of type `{actual}` to output slot `{label:?}`, which has type `{expected}`")]
264    MismatchedSlotType {
265        label: SlotLabel,
266        expected: SlotType,
267        actual: SlotType,
268    },
269}
270
271#[derive(Error, Display, Debug, Eq, PartialEq)]
272pub enum InputSlotError {
273    #[display("input slot `{_0:?}` does not exist")]
274    #[error(ignore)]
275    InvalidSlot(SlotLabel),
276    #[display("attempted to retrieve a value of type `{actual}` from input slot `{label:?}`, which has type `{expected}`")]
277    MismatchedSlotType {
278        label: SlotLabel,
279        expected: SlotType,
280        actual: SlotType,
281    },
282}