1use alloc::{collections::VecDeque, sync::Arc};
4use bevy_input_focus::InputFocus;
5use std::sync::Mutex;
6use winit::event_loop::ActiveEventLoop;
7
8use accesskit::{
9 ActionHandler, ActionRequest, ActivationHandler, DeactivationHandler, Node, NodeId, Role, Tree,
10 TreeUpdate,
11};
12use accesskit_winit::Adapter;
13use bevy_a11y::{
14 AccessibilityNode, AccessibilityRequested, AccessibilitySystem,
15 ActionRequest as ActionRequestWrapper, ManageAccessibilityUpdates,
16};
17use bevy_app::{App, Plugin, PostUpdate};
18use bevy_derive::{Deref, DerefMut};
19use bevy_ecs::{entity::EntityHashMap, prelude::*};
20use bevy_window::{PrimaryWindow, Window, WindowClosed};
21
/// Maps window entities to their `AccessKit` [`Adapter`]s.
///
/// Stored as a non-send resource: it is registered with
/// `init_non_send_resource` and accessed through `NonSendMut` in this file.
#[derive(Default, Deref, DerefMut)]
pub struct AccessKitAdapters(pub EntityHashMap<Adapter>);
25
/// Maps window entities to the shared queue of [`ActionRequest`]s received for
/// that window.
///
/// Each queue is wrapped in `Arc<Mutex<_>>` because the same queue is also
/// handed to the window's action handler when the adapter is created.
#[derive(Resource, Default, Deref, DerefMut)]
pub struct WinitActionRequestHandlers(pub EntityHashMap<Arc<Mutex<WinitActionRequestHandler>>>);
29
30#[derive(Clone, Default, Deref, DerefMut)]
32pub struct WinitActionRequestHandler(pub VecDeque<ActionRequest>);
33
34impl WinitActionRequestHandler {
35 fn new() -> Arc<Mutex<Self>> {
36 Arc::new(Mutex::new(Self(VecDeque::new())))
37 }
38}
39
40struct AccessKitState {
41 name: String,
42 entity: Entity,
43 requested: AccessibilityRequested,
44}
45
46impl AccessKitState {
47 fn new(
48 name: impl Into<String>,
49 entity: Entity,
50 requested: AccessibilityRequested,
51 ) -> Arc<Mutex<Self>> {
52 let name = name.into();
53
54 Arc::new(Mutex::new(Self {
55 name,
56 entity,
57 requested,
58 }))
59 }
60
61 fn build_root(&mut self) -> Node {
62 let mut node = Node::new(Role::Window);
63 node.set_label(self.name.clone());
64 node
65 }
66
67 fn build_initial_tree(&mut self) -> TreeUpdate {
68 let root = self.build_root();
69 let accesskit_window_id = NodeId(self.entity.to_bits());
70 let tree = Tree::new(accesskit_window_id);
71 self.requested.set(true);
72
73 TreeUpdate {
74 nodes: vec![(accesskit_window_id, root)],
75 tree: Some(tree),
76 focus: accesskit_window_id,
77 }
78 }
79}
80
81struct WinitActivationHandler(Arc<Mutex<AccessKitState>>);
82
83impl ActivationHandler for WinitActivationHandler {
84 fn request_initial_tree(&mut self) -> Option<TreeUpdate> {
85 Some(self.0.lock().unwrap().build_initial_tree())
86 }
87}
88
89impl WinitActivationHandler {
90 pub fn new(state: Arc<Mutex<AccessKitState>>) -> Self {
91 Self(state)
92 }
93}
94
95#[derive(Clone, Default)]
96struct WinitActionHandler(Arc<Mutex<WinitActionRequestHandler>>);
97
98impl ActionHandler for WinitActionHandler {
99 fn do_action(&mut self, request: ActionRequest) {
100 let mut requests = self.0.lock().unwrap();
101 requests.push_back(request);
102 }
103}
104
105impl WinitActionHandler {
106 pub fn new(handler: Arc<Mutex<WinitActionRequestHandler>>) -> Self {
107 Self(handler)
108 }
109}
110
/// Deactivation handler passed to the `AccessKit` adapter; it carries no state.
struct WinitDeactivationHandler;

impl DeactivationHandler for WinitDeactivationHandler {
    /// No-op: the body is empty, so nothing is changed on deactivation.
    fn deactivate_accessibility(&mut self) {}
}
116
117pub(crate) fn prepare_accessibility_for_window(
119 event_loop: &ActiveEventLoop,
120 winit_window: &winit::window::Window,
121 entity: Entity,
122 name: String,
123 accessibility_requested: AccessibilityRequested,
124 adapters: &mut AccessKitAdapters,
125 handlers: &mut WinitActionRequestHandlers,
126) {
127 let state = AccessKitState::new(name, entity, accessibility_requested);
128 let activation_handler = WinitActivationHandler::new(Arc::clone(&state));
129
130 let action_request_handler = WinitActionRequestHandler::new();
131 let action_handler = WinitActionHandler::new(Arc::clone(&action_request_handler));
132 let deactivation_handler = WinitDeactivationHandler;
133
134 let adapter = Adapter::with_direct_handlers(
135 event_loop,
136 winit_window,
137 activation_handler,
138 action_handler,
139 deactivation_handler,
140 );
141
142 adapters.insert(entity, adapter);
143 handlers.insert(entity, action_request_handler);
144}
145
146fn window_closed(
147 mut adapters: NonSendMut<AccessKitAdapters>,
148 mut handlers: ResMut<WinitActionRequestHandlers>,
149 mut events: EventReader<WindowClosed>,
150) {
151 for WindowClosed { window, .. } in events.read() {
152 adapters.remove(window);
153 handlers.remove(window);
154 }
155}
156
157fn poll_receivers(
158 handlers: Res<WinitActionRequestHandlers>,
159 mut actions: EventWriter<ActionRequestWrapper>,
160) {
161 for (_id, handler) in handlers.iter() {
162 let mut handler = handler.lock().unwrap();
163 while let Some(event) = handler.pop_front() {
164 actions.write(ActionRequestWrapper(event));
165 }
166 }
167}
168
169fn should_update_accessibility_nodes(
170 accessibility_requested: Res<AccessibilityRequested>,
171 manage_accessibility_updates: Res<ManageAccessibilityUpdates>,
172) -> bool {
173 accessibility_requested.get() && manage_accessibility_updates.get()
174}
175
176fn update_accessibility_nodes(
177 mut adapters: NonSendMut<AccessKitAdapters>,
178 focus: Option<Res<InputFocus>>,
179 primary_window: Query<(Entity, &Window), With<PrimaryWindow>>,
180 nodes: Query<(
181 Entity,
182 &AccessibilityNode,
183 Option<&Children>,
184 Option<&ChildOf>,
185 )>,
186 node_entities: Query<Entity, With<AccessibilityNode>>,
187) {
188 let Ok((primary_window_id, primary_window)) = primary_window.single() else {
189 return;
190 };
191 let Some(adapter) = adapters.get_mut(&primary_window_id) else {
192 return;
193 };
194 let Some(focus) = focus else {
195 return;
196 };
197 if focus.is_changed() || !nodes.is_empty() {
198 if let Some(focused_entity) = focus.0 {
201 if !node_entities.contains(focused_entity) {
202 return;
203 }
204 }
205
206 adapter.update_if_active(|| {
207 update_adapter(
208 nodes,
209 node_entities,
210 primary_window,
211 primary_window_id,
212 focus,
213 )
214 });
215 }
216}
217
218fn update_adapter(
219 nodes: Query<(
220 Entity,
221 &AccessibilityNode,
222 Option<&Children>,
223 Option<&ChildOf>,
224 )>,
225 node_entities: Query<Entity, With<AccessibilityNode>>,
226 primary_window: &Window,
227 primary_window_id: Entity,
228 focus: Res<InputFocus>,
229) -> TreeUpdate {
230 let mut to_update = vec![];
231 let mut window_children = vec![];
232 for (entity, node, children, child_of) in &nodes {
233 let mut node = (**node).clone();
234 queue_node_for_update(entity, child_of, &node_entities, &mut window_children);
235 add_children_nodes(children, &node_entities, &mut node);
236 let node_id = NodeId(entity.to_bits());
237 to_update.push((node_id, node));
238 }
239 let mut window_node = Node::new(Role::Window);
240 if primary_window.focused {
241 let title = primary_window.title.clone();
242 window_node.set_label(title.into_boxed_str());
243 }
244 window_node.set_children(window_children);
245 let node_id = NodeId(primary_window_id.to_bits());
246 let window_update = (node_id, window_node);
247 to_update.insert(0, window_update);
248 TreeUpdate {
249 nodes: to_update,
250 tree: None,
251 focus: NodeId(focus.0.unwrap_or(primary_window_id).to_bits()),
252 }
253}
254
255#[inline]
256fn queue_node_for_update(
257 node_entity: Entity,
258 child_of: Option<&ChildOf>,
259 node_entities: &Query<Entity, With<AccessibilityNode>>,
260 window_children: &mut Vec<NodeId>,
261) {
262 let should_push = if let Some(child_of) = child_of {
263 !node_entities.contains(child_of.parent())
264 } else {
265 true
266 };
267 if should_push {
268 window_children.push(NodeId(node_entity.to_bits()));
269 }
270}
271
272#[inline]
273fn add_children_nodes(
274 children: Option<&Children>,
275 node_entities: &Query<Entity, With<AccessibilityNode>>,
276 node: &mut Node,
277) {
278 let Some(children) = children else {
279 return;
280 };
281 for child in children {
282 if node_entities.contains(*child) {
283 node.push_child(NodeId(child.to_bits()));
284 }
285 }
286}
287
288pub struct AccessKitPlugin;
290
291impl Plugin for AccessKitPlugin {
292 fn build(&self, app: &mut App) {
293 app.init_non_send_resource::<AccessKitAdapters>()
294 .init_resource::<WinitActionRequestHandlers>()
295 .add_event::<ActionRequestWrapper>()
296 .add_systems(
297 PostUpdate,
298 (
299 poll_receivers,
300 update_accessibility_nodes.run_if(should_update_accessibility_nodes),
301 window_closed
302 .before(poll_receivers)
303 .before(update_accessibility_nodes),
304 )
305 .in_set(AccessibilitySystem::Update),
306 );
307 }
308}