bevy_render/render_graph/
context.rs

1use crate::{
2    render_graph::{NodeState, RenderGraph, SlotInfos, SlotLabel, SlotType, SlotValue},
3    render_resource::{Buffer, Sampler, TextureView},
4};
5use alloc::borrow::Cow;
6use bevy_ecs::{entity::Entity, intern::Interned};
7use thiserror::Error;
8
9use super::{InternedRenderSubGraph, RenderLabel, RenderSubGraph};
10
11/// A command that signals the graph runner to run the sub graph corresponding to the `sub_graph`
12/// with the specified `inputs` next.
13pub struct RunSubGraph {
14    pub sub_graph: InternedRenderSubGraph,
15    pub inputs: Vec<SlotValue>,
16    pub view_entity: Option<Entity>,
17    pub debug_group: Option<String>,
18}
19
20/// The context with all graph information required to run a [`Node`](super::Node).
21/// This context is created for each node by the render graph runner.
22///
23/// The slot input can be read from here and the outputs must be written back to the context for
24/// passing them onto the next node.
25///
26/// Sub graphs can be queued for running by adding a [`RunSubGraph`] command to the context.
27/// After the node has finished running the graph runner is responsible for executing the sub graphs.
28pub struct RenderGraphContext<'a> {
29    graph: &'a RenderGraph,
30    node: &'a NodeState,
31    inputs: &'a [SlotValue],
32    outputs: &'a mut [Option<SlotValue>],
33    run_sub_graphs: Vec<RunSubGraph>,
34    /// The `view_entity` associated with the render graph being executed
35    /// This is optional because you aren't required to have a `view_entity` for a node.
36    /// For example, compute shader nodes don't have one.
37    /// It should always be set when the [`RenderGraph`] is running on a View.
38    view_entity: Option<Entity>,
39}
40
41impl<'a> RenderGraphContext<'a> {
42    /// Creates a new render graph context for the `node`.
43    pub fn new(
44        graph: &'a RenderGraph,
45        node: &'a NodeState,
46        inputs: &'a [SlotValue],
47        outputs: &'a mut [Option<SlotValue>],
48    ) -> Self {
49        Self {
50            graph,
51            node,
52            inputs,
53            outputs,
54            run_sub_graphs: Vec::new(),
55            view_entity: None,
56        }
57    }
58
59    /// Returns the input slot values for the node.
60    #[inline]
61    pub fn inputs(&self) -> &[SlotValue] {
62        self.inputs
63    }
64
65    /// Returns the [`SlotInfos`] of the inputs.
66    pub fn input_info(&self) -> &SlotInfos {
67        &self.node.input_slots
68    }
69
70    /// Returns the [`SlotInfos`] of the outputs.
71    pub fn output_info(&self) -> &SlotInfos {
72        &self.node.output_slots
73    }
74
75    /// Retrieves the input slot value referenced by the `label`.
76    pub fn get_input(&self, label: impl Into<SlotLabel>) -> Result<&SlotValue, InputSlotError> {
77        let label = label.into();
78        let index = self
79            .input_info()
80            .get_slot_index(label.clone())
81            .ok_or(InputSlotError::InvalidSlot(label))?;
82        Ok(&self.inputs[index])
83    }
84
85    // TODO: should this return an Arc or a reference?
86    /// Retrieves the input slot value referenced by the `label` as a [`TextureView`].
87    pub fn get_input_texture(
88        &self,
89        label: impl Into<SlotLabel>,
90    ) -> Result<&TextureView, InputSlotError> {
91        let label = label.into();
92        match self.get_input(label.clone())? {
93            SlotValue::TextureView(value) => Ok(value),
94            value => Err(InputSlotError::MismatchedSlotType {
95                label,
96                actual: value.slot_type(),
97                expected: SlotType::TextureView,
98            }),
99        }
100    }
101
102    /// Retrieves the input slot value referenced by the `label` as a [`Sampler`].
103    pub fn get_input_sampler(
104        &self,
105        label: impl Into<SlotLabel>,
106    ) -> Result<&Sampler, InputSlotError> {
107        let label = label.into();
108        match self.get_input(label.clone())? {
109            SlotValue::Sampler(value) => Ok(value),
110            value => Err(InputSlotError::MismatchedSlotType {
111                label,
112                actual: value.slot_type(),
113                expected: SlotType::Sampler,
114            }),
115        }
116    }
117
118    /// Retrieves the input slot value referenced by the `label` as a [`Buffer`].
119    pub fn get_input_buffer(&self, label: impl Into<SlotLabel>) -> Result<&Buffer, InputSlotError> {
120        let label = label.into();
121        match self.get_input(label.clone())? {
122            SlotValue::Buffer(value) => Ok(value),
123            value => Err(InputSlotError::MismatchedSlotType {
124                label,
125                actual: value.slot_type(),
126                expected: SlotType::Buffer,
127            }),
128        }
129    }
130
131    /// Retrieves the input slot value referenced by the `label` as an [`Entity`].
132    pub fn get_input_entity(&self, label: impl Into<SlotLabel>) -> Result<Entity, InputSlotError> {
133        let label = label.into();
134        match self.get_input(label.clone())? {
135            SlotValue::Entity(value) => Ok(*value),
136            value => Err(InputSlotError::MismatchedSlotType {
137                label,
138                actual: value.slot_type(),
139                expected: SlotType::Entity,
140            }),
141        }
142    }
143
144    /// Sets the output slot value referenced by the `label`.
145    pub fn set_output(
146        &mut self,
147        label: impl Into<SlotLabel>,
148        value: impl Into<SlotValue>,
149    ) -> Result<(), OutputSlotError> {
150        let label = label.into();
151        let value = value.into();
152        let slot_index = self
153            .output_info()
154            .get_slot_index(label.clone())
155            .ok_or_else(|| OutputSlotError::InvalidSlot(label.clone()))?;
156        let slot = self
157            .output_info()
158            .get_slot(slot_index)
159            .expect("slot is valid");
160        if value.slot_type() != slot.slot_type {
161            return Err(OutputSlotError::MismatchedSlotType {
162                label,
163                actual: slot.slot_type,
164                expected: value.slot_type(),
165            });
166        }
167        self.outputs[slot_index] = Some(value);
168        Ok(())
169    }
170
171    pub fn view_entity(&self) -> Entity {
172        self.view_entity.unwrap()
173    }
174
175    pub fn get_view_entity(&self) -> Option<Entity> {
176        self.view_entity
177    }
178
179    pub fn set_view_entity(&mut self, view_entity: Entity) {
180        self.view_entity = Some(view_entity);
181    }
182
183    /// Queues up a sub graph for execution after the node has finished running.
184    pub fn run_sub_graph(
185        &mut self,
186        name: impl RenderSubGraph,
187        inputs: Vec<SlotValue>,
188        view_entity: Option<Entity>,
189        debug_group: Option<String>,
190    ) -> Result<(), RunSubGraphError> {
191        let name = name.intern();
192        let sub_graph = self
193            .graph
194            .get_sub_graph(name)
195            .ok_or(RunSubGraphError::MissingSubGraph(name))?;
196        if let Some(input_node) = sub_graph.get_input_node() {
197            for (i, input_slot) in input_node.input_slots.iter().enumerate() {
198                if let Some(input_value) = inputs.get(i) {
199                    if input_slot.slot_type != input_value.slot_type() {
200                        return Err(RunSubGraphError::MismatchedInputSlotType {
201                            graph_name: name,
202                            slot_index: i,
203                            actual: input_value.slot_type(),
204                            expected: input_slot.slot_type,
205                            label: input_slot.name.clone().into(),
206                        });
207                    }
208                } else {
209                    return Err(RunSubGraphError::MissingInput {
210                        slot_index: i,
211                        slot_name: input_slot.name.clone(),
212                        graph_name: name,
213                    });
214                }
215            }
216        } else if !inputs.is_empty() {
217            return Err(RunSubGraphError::SubGraphHasNoInputs(name));
218        }
219
220        self.run_sub_graphs.push(RunSubGraph {
221            sub_graph: name,
222            inputs,
223            view_entity,
224            debug_group,
225        });
226
227        Ok(())
228    }
229
230    /// Returns a human-readable label for this node, for debugging purposes.
231    pub fn label(&self) -> Interned<dyn RenderLabel> {
232        self.node.label
233    }
234
235    /// Finishes the context for this [`Node`](super::Node) by
236    /// returning the sub graphs to run next.
237    pub fn finish(self) -> Vec<RunSubGraph> {
238        self.run_sub_graphs
239    }
240}
241
242#[derive(Error, Debug, Eq, PartialEq)]
243pub enum RunSubGraphError {
244    #[error("attempted to run sub-graph `{0:?}`, but it does not exist")]
245    MissingSubGraph(InternedRenderSubGraph),
246    #[error("attempted to pass inputs to sub-graph `{0:?}`, which has no input slots")]
247    SubGraphHasNoInputs(InternedRenderSubGraph),
248    #[error("sub graph (name: `{graph_name:?}`) could not be run because slot `{slot_name}` at index {slot_index} has no value")]
249    MissingInput {
250        slot_index: usize,
251        slot_name: Cow<'static, str>,
252        graph_name: InternedRenderSubGraph,
253    },
254    #[error("attempted to use the wrong type for input slot")]
255    MismatchedInputSlotType {
256        graph_name: InternedRenderSubGraph,
257        slot_index: usize,
258        label: SlotLabel,
259        expected: SlotType,
260        actual: SlotType,
261    },
262}
263
264#[derive(Error, Debug, Eq, PartialEq)]
265pub enum OutputSlotError {
266    #[error("output slot `{0:?}` does not exist")]
267    InvalidSlot(SlotLabel),
268    #[error("attempted to output a value of type `{actual}` to output slot `{label:?}`, which has type `{expected}`")]
269    MismatchedSlotType {
270        label: SlotLabel,
271        expected: SlotType,
272        actual: SlotType,
273    },
274}
275
276#[derive(Error, Debug, Eq, PartialEq)]
277pub enum InputSlotError {
278    #[error("input slot `{0:?}` does not exist")]
279    InvalidSlot(SlotLabel),
280    #[error("attempted to retrieve a value of type `{actual}` from input slot `{label:?}`, which has type `{expected}`")]
281    MismatchedSlotType {
282        label: SlotLabel,
283        expected: SlotType,
284        actual: SlotType,
285    },
286}