1use alloc::{collections::VecDeque, sync::Arc};
4use std::sync::Mutex;
5
6use accesskit::{
7 ActionHandler, ActionRequest, ActivationHandler, DeactivationHandler, Node, NodeId, Role, Tree,
8 TreeUpdate,
9};
10use accesskit_winit::Adapter;
11use bevy_a11y::{
12 AccessibilityNode, AccessibilityRequested, AccessibilitySystem,
13 ActionRequest as ActionRequestWrapper, Focus, ManageAccessibilityUpdates,
14};
15use bevy_app::{App, Plugin, PostUpdate};
16use bevy_derive::{Deref, DerefMut};
17use bevy_ecs::{
18 entity::EntityHashMap,
19 prelude::{DetectChanges, Entity, EventReader, EventWriter},
20 query::With,
21 schedule::IntoSystemConfigs,
22 system::{NonSendMut, Query, Res, ResMut, Resource},
23};
24use bevy_hierarchy::{Children, Parent};
25use bevy_window::{PrimaryWindow, Window, WindowClosed};
26
27#[derive(Default, Deref, DerefMut)]
29pub struct AccessKitAdapters(pub EntityHashMap<Adapter>);
30
31#[derive(Resource, Default, Deref, DerefMut)]
33pub struct WinitActionRequestHandlers(pub EntityHashMap<Arc<Mutex<WinitActionRequestHandler>>>);
34
35#[derive(Clone, Default, Deref, DerefMut)]
37pub struct WinitActionRequestHandler(pub VecDeque<ActionRequest>);
38
39impl WinitActionRequestHandler {
40 fn new() -> Arc<Mutex<Self>> {
41 Arc::new(Mutex::new(Self(VecDeque::new())))
42 }
43}
44
45struct AccessKitState {
46 name: String,
47 entity: Entity,
48 requested: AccessibilityRequested,
49}
50
51impl AccessKitState {
52 fn new(
53 name: impl Into<String>,
54 entity: Entity,
55 requested: AccessibilityRequested,
56 ) -> Arc<Mutex<Self>> {
57 let name = name.into();
58
59 Arc::new(Mutex::new(Self {
60 name,
61 entity,
62 requested,
63 }))
64 }
65
66 fn build_root(&mut self) -> Node {
67 let mut node = Node::new(Role::Window);
68 node.set_label(self.name.clone());
69 node
70 }
71
72 fn build_initial_tree(&mut self) -> TreeUpdate {
73 let root = self.build_root();
74 let accesskit_window_id = NodeId(self.entity.to_bits());
75 let mut tree = Tree::new(accesskit_window_id);
76 tree.app_name = Some(self.name.clone());
77 self.requested.set(true);
78
79 TreeUpdate {
80 nodes: vec![(accesskit_window_id, root)],
81 tree: Some(tree),
82 focus: accesskit_window_id,
83 }
84 }
85}
86
87struct WinitActivationHandler(Arc<Mutex<AccessKitState>>);
88
89impl ActivationHandler for WinitActivationHandler {
90 fn request_initial_tree(&mut self) -> Option<TreeUpdate> {
91 Some(self.0.lock().unwrap().build_initial_tree())
92 }
93}
94
95impl WinitActivationHandler {
96 pub fn new(state: Arc<Mutex<AccessKitState>>) -> Self {
97 Self(state)
98 }
99}
100
101#[derive(Clone, Default)]
102struct WinitActionHandler(Arc<Mutex<WinitActionRequestHandler>>);
103
104impl ActionHandler for WinitActionHandler {
105 fn do_action(&mut self, request: ActionRequest) {
106 let mut requests = self.0.lock().unwrap();
107 requests.push_back(request);
108 }
109}
110
111impl WinitActionHandler {
112 pub fn new(handler: Arc<Mutex<WinitActionRequestHandler>>) -> Self {
113 Self(handler)
114 }
115}
116
117struct WinitDeactivationHandler;
118
119impl DeactivationHandler for WinitDeactivationHandler {
120 fn deactivate_accessibility(&mut self) {}
121}
122
123pub(crate) fn prepare_accessibility_for_window(
125 winit_window: &winit::window::Window,
126 entity: Entity,
127 name: String,
128 accessibility_requested: AccessibilityRequested,
129 adapters: &mut AccessKitAdapters,
130 handlers: &mut WinitActionRequestHandlers,
131) {
132 let state = AccessKitState::new(name, entity, accessibility_requested);
133 let activation_handler = WinitActivationHandler::new(Arc::clone(&state));
134
135 let action_request_handler = WinitActionRequestHandler::new();
136 let action_handler = WinitActionHandler::new(Arc::clone(&action_request_handler));
137 let deactivation_handler = WinitDeactivationHandler;
138
139 let adapter = Adapter::with_direct_handlers(
140 winit_window,
141 activation_handler,
142 action_handler,
143 deactivation_handler,
144 );
145
146 adapters.insert(entity, adapter);
147 handlers.insert(entity, action_request_handler);
148}
149
150fn window_closed(
151 mut adapters: NonSendMut<AccessKitAdapters>,
152 mut handlers: ResMut<WinitActionRequestHandlers>,
153 mut events: EventReader<WindowClosed>,
154) {
155 for WindowClosed { window, .. } in events.read() {
156 adapters.remove(window);
157 handlers.remove(window);
158 }
159}
160
161fn poll_receivers(
162 handlers: Res<WinitActionRequestHandlers>,
163 mut actions: EventWriter<ActionRequestWrapper>,
164) {
165 for (_id, handler) in handlers.iter() {
166 let mut handler = handler.lock().unwrap();
167 while let Some(event) = handler.pop_front() {
168 actions.send(ActionRequestWrapper(event));
169 }
170 }
171}
172
173fn should_update_accessibility_nodes(
174 accessibility_requested: Res<AccessibilityRequested>,
175 manage_accessibility_updates: Res<ManageAccessibilityUpdates>,
176) -> bool {
177 accessibility_requested.get() && manage_accessibility_updates.get()
178}
179
180fn update_accessibility_nodes(
181 mut adapters: NonSendMut<AccessKitAdapters>,
182 focus: Res<Focus>,
183 primary_window: Query<(Entity, &Window), With<PrimaryWindow>>,
184 nodes: Query<(
185 Entity,
186 &AccessibilityNode,
187 Option<&Children>,
188 Option<&Parent>,
189 )>,
190 node_entities: Query<Entity, With<AccessibilityNode>>,
191) {
192 let Ok((primary_window_id, primary_window)) = primary_window.get_single() else {
193 return;
194 };
195 let Some(adapter) = adapters.get_mut(&primary_window_id) else {
196 return;
197 };
198 if focus.is_changed() || !nodes.is_empty() {
199 adapter.update_if_active(|| {
200 update_adapter(
201 nodes,
202 node_entities,
203 primary_window,
204 primary_window_id,
205 focus,
206 )
207 });
208 }
209}
210
211fn update_adapter(
212 nodes: Query<(
213 Entity,
214 &AccessibilityNode,
215 Option<&Children>,
216 Option<&Parent>,
217 )>,
218 node_entities: Query<Entity, With<AccessibilityNode>>,
219 primary_window: &Window,
220 primary_window_id: Entity,
221 focus: Res<Focus>,
222) -> TreeUpdate {
223 let mut to_update = vec![];
224 let mut window_children = vec![];
225 for (entity, node, children, parent) in &nodes {
226 let mut node = (**node).clone();
227 queue_node_for_update(entity, parent, &node_entities, &mut window_children);
228 add_children_nodes(children, &node_entities, &mut node);
229 let node_id = NodeId(entity.to_bits());
230 to_update.push((node_id, node));
231 }
232 let mut window_node = Node::new(Role::Window);
233 if primary_window.focused {
234 let title = primary_window.title.clone();
235 window_node.set_label(title.into_boxed_str());
236 }
237 window_node.set_children(window_children);
238 let node_id = NodeId(primary_window_id.to_bits());
239 let window_update = (node_id, window_node);
240 to_update.insert(0, window_update);
241 TreeUpdate {
242 nodes: to_update,
243 tree: None,
244 focus: NodeId(focus.unwrap_or(primary_window_id).to_bits()),
245 }
246}
247
248#[inline]
249fn queue_node_for_update(
250 node_entity: Entity,
251 parent: Option<&Parent>,
252 node_entities: &Query<Entity, With<AccessibilityNode>>,
253 window_children: &mut Vec<NodeId>,
254) {
255 let should_push = if let Some(parent) = parent {
256 !node_entities.contains(parent.get())
257 } else {
258 true
259 };
260 if should_push {
261 window_children.push(NodeId(node_entity.to_bits()));
262 }
263}
264
265#[inline]
266fn add_children_nodes(
267 children: Option<&Children>,
268 node_entities: &Query<Entity, With<AccessibilityNode>>,
269 node: &mut Node,
270) {
271 let Some(children) = children else {
272 return;
273 };
274 for child in children {
275 if node_entities.contains(*child) {
276 node.push_child(NodeId(child.to_bits()));
277 }
278 }
279}
280
281pub struct AccessKitPlugin;
283
284impl Plugin for AccessKitPlugin {
285 fn build(&self, app: &mut App) {
286 app.init_non_send_resource::<AccessKitAdapters>()
287 .init_resource::<WinitActionRequestHandlers>()
288 .add_event::<ActionRequestWrapper>()
289 .add_systems(
290 PostUpdate,
291 (
292 poll_receivers,
293 update_accessibility_nodes.run_if(should_update_accessibility_nodes),
294 window_closed
295 .before(poll_receivers)
296 .before(update_accessibility_nodes),
297 )
298 .in_set(AccessibilitySystem::Update),
299 );
300 }
301}