Skip to main content

freya_core/accessibility/
tree.rs

1use std::any::Any;
2
3use accesskit::{
4    Action,
5    Affine,
6    Node,
7    Rect,
8    Role,
9    TreeId,
10    TreeUpdate,
11};
12use ragnarok::ProcessedEvents;
13use rustc_hash::{
14    FxHashMap,
15    FxHashSet,
16};
17use torin::prelude::{
18    CursorPoint,
19    LayoutNode,
20};
21
22use crate::{
23    accessibility::{
24        focus_strategy::AccessibilityFocusStrategy,
25        focusable::Focusable,
26        id::AccessibilityId,
27    },
28    data::Overflow,
29    elements::{
30        label::LabelElement,
31        paragraph::ParagraphElement,
32    },
33    events::emittable::EmmitableEvent,
34    integration::{
35        EventName,
36        EventsChunk,
37    },
38    node_id::NodeId,
39    prelude::{
40        AccessibilityFocusMovement,
41        Color,
42        EventType,
43        FontSlant,
44        TextAlign,
45        TextDecoration,
46        WheelEventData,
47        WheelSource,
48    },
49    tree::Tree,
50};
51
52pub const ACCESSIBILITY_ROOT_ID: AccessibilityId = AccessibilityId(0);
53
54pub struct AccessibilityTree {
55    pub map: FxHashMap<AccessibilityId, NodeId>,
56    // Current focused Accessibility Node.
57    pub focused_id: AccessibilityId,
58}
59
60impl Default for AccessibilityTree {
61    fn default() -> Self {
62        Self::new(ACCESSIBILITY_ROOT_ID)
63    }
64}
65
66impl AccessibilityTree {
67    pub fn new(focused_id: AccessibilityId) -> Self {
68        Self {
69            focused_id,
70            map: FxHashMap::default(),
71        }
72    }
73
74    pub fn focused_node_id(&self) -> Option<NodeId> {
75        self.map.get(&self.focused_id).cloned()
76    }
77
78    /// Initialize the Accessibility Tree
79    pub fn init(&mut self, tree: &mut Tree) -> TreeUpdate {
80        tree.accessibility_diff.clear();
81
82        let mut nodes = vec![];
83
84        tree.traverse_depth(|node_id| {
85            let accessibility_state = tree.accessibility_state.get(&node_id).unwrap();
86            let layout_node = tree.layout.get(&node_id).unwrap();
87            let accessibility_node = Self::create_node(node_id, layout_node, tree);
88            nodes.push((accessibility_state.a11y_id, accessibility_node));
89            self.map.insert(accessibility_state.a11y_id, node_id);
90        });
91
92        #[cfg(debug_assertions)]
93        tracing::info!(
94            "Initialized the Accessibility Tree with {} nodes.",
95            nodes.len()
96        );
97
98        if !self.map.contains_key(&self.focused_id) {
99            self.focused_id = ACCESSIBILITY_ROOT_ID;
100        }
101
102        TreeUpdate {
103            tree_id: TreeId::ROOT,
104            nodes,
105            tree: Some(accesskit::Tree::new(ACCESSIBILITY_ROOT_ID)),
106            focus: self.focused_id,
107        }
108    }
109
110    /// Process any pending Accessibility Tree update
111    #[cfg_attr(feature = "hotpath", hotpath::measure)]
112    pub fn process_updates(
113        &mut self,
114        tree: &mut Tree,
115        events_sender: &futures_channel::mpsc::UnboundedSender<EventsChunk>,
116    ) -> TreeUpdate {
117        let requested_focus = tree.accessibility_diff.requested_focus.take();
118        let removed_ids = tree
119            .accessibility_diff
120            .removed
121            .drain()
122            .collect::<FxHashMap<_, _>>();
123        let mut added_or_updated_ids = tree
124            .accessibility_diff
125            .added_or_updated
126            .drain()
127            .collect::<FxHashSet<_>>();
128
129        #[cfg(debug_assertions)]
130        if !removed_ids.is_empty() || !added_or_updated_ids.is_empty() {
131            tracing::info!(
132                "Updating the Accessibility Tree with {} removals and {} additions/modifications",
133                removed_ids.len(),
134                added_or_updated_ids.len()
135            );
136        }
137
138        // Remove all the removed nodes from the update list
139        for (node_id, _) in removed_ids.iter() {
140            added_or_updated_ids.remove(node_id);
141            self.map.retain(|_, id| id != node_id);
142        }
143
144        // Mark the parent of the removed nodes as updated
145        for (_, parent_id) in removed_ids.iter() {
146            if !removed_ids.contains_key(parent_id) {
147                added_or_updated_ids.insert(*parent_id);
148            }
149        }
150
151        // Register the created/updated nodes
152        for node_id in added_or_updated_ids.clone() {
153            let accessibility_state = tree.accessibility_state.get(&node_id).unwrap();
154            self.map.insert(accessibility_state.a11y_id, node_id);
155
156            let node_parent_id = tree.parents.get(&node_id).unwrap_or(&NodeId::ROOT);
157            added_or_updated_ids.insert(*node_parent_id);
158        }
159
160        // Create the updated nodes
161        let mut nodes = Vec::new();
162        for node_id in added_or_updated_ids {
163            let accessibility_state = tree.accessibility_state.get(&node_id).unwrap();
164            let layout_node = tree.layout.get(&node_id).unwrap();
165            let accessibility_node = Self::create_node(node_id, layout_node, tree);
166            nodes.push((accessibility_state.a11y_id, accessibility_node));
167        }
168
169        let has_request_focus = requested_focus.is_some();
170
171        // Fallback the focused id to the root if the focused node no longer exists
172        if !self.map.contains_key(&self.focused_id) {
173            self.focused_id = ACCESSIBILITY_ROOT_ID;
174        }
175
176        // Focus the requested node id if there is one
177        if let Some(requested_focus) = requested_focus {
178            self.focus_node_with_strategy(requested_focus, tree);
179        }
180
181        if let Some(node_id) = self.focused_node_id()
182            && has_request_focus
183        {
184            self.scroll_to(node_id, tree, events_sender);
185        }
186
187        TreeUpdate {
188            tree_id: TreeId::ROOT,
189            nodes,
190            tree: Some(accesskit::Tree::new(ACCESSIBILITY_ROOT_ID)),
191            focus: self.focused_id,
192        }
193    }
194
195    /// Focus a Node given the strategy.
196    pub fn focus_node_with_strategy(
197        &mut self,
198        strategy: AccessibilityFocusStrategy,
199        tree: &mut Tree,
200    ) {
201        if let AccessibilityFocusStrategy::Node(id) = strategy {
202            if self.map.contains_key(&id) {
203                self.focused_id = id;
204            }
205            return;
206        }
207
208        let (navigable_nodes, focused_id) = if strategy.mode()
209            == Some(AccessibilityFocusMovement::InsideGroup)
210        {
211            // Get all accessible nodes in the current group
212            let mut group_nodes = Vec::new();
213
214            let node_id = self.map.get(&self.focused_id).unwrap();
215            let accessibility_state = tree.accessibility_state.get(node_id).unwrap();
216            let member_accessibility_id = accessibility_state.a11y_member_of;
217            if let Some(member_accessibility_id) = member_accessibility_id {
218                group_nodes = tree
219                    .accessibility_groups
220                    .get(&member_accessibility_id)
221                    .cloned()
222                    .unwrap_or_default()
223                    .into_iter()
224                    .filter(|id| {
225                        let node_id = self.map.get(id).unwrap();
226                        let accessibility_state = tree.accessibility_state.get(node_id).unwrap();
227                        accessibility_state.a11y_focusable == Focusable::Enabled
228                    })
229                    .collect();
230            }
231            (group_nodes, self.focused_id)
232        } else {
233            let mut nodes = Vec::new();
234
235            tree.traverse_depth(|node_id| {
236                let accessibility_state = tree.accessibility_state.get(&node_id).unwrap();
237                let member_accessibility_id = accessibility_state.a11y_member_of;
238
239                // Exclude nodes that are members of groups except for the parent of the group
240                if let Some(member_accessibility_id) = member_accessibility_id
241                    && member_accessibility_id != accessibility_state.a11y_id
242                {
243                    return;
244                }
245                if accessibility_state.a11y_focusable == Focusable::Enabled {
246                    nodes.push(accessibility_state.a11y_id);
247                }
248            });
249
250            (nodes, self.focused_id)
251        };
252
253        let node_index = navigable_nodes
254            .iter()
255            .position(|accessibility_id| *accessibility_id == focused_id);
256
257        let target_node = match strategy {
258            AccessibilityFocusStrategy::Forward(_) => {
259                // Find the next Node
260                if let Some(node_index) = node_index {
261                    if node_index == navigable_nodes.len() - 1 {
262                        navigable_nodes.first().cloned()
263                    } else {
264                        navigable_nodes.get(node_index + 1).cloned()
265                    }
266                } else {
267                    navigable_nodes.first().cloned()
268                }
269            }
270            AccessibilityFocusStrategy::Backward(_) => {
271                // Find the previous Node
272                if let Some(node_index) = node_index {
273                    if node_index == 0 {
274                        navigable_nodes.last().cloned()
275                    } else {
276                        navigable_nodes.get(node_index - 1).cloned()
277                    }
278                } else {
279                    navigable_nodes.last().cloned()
280                }
281            }
282            _ => unreachable!(),
283        };
284
285        self.focused_id = target_node.unwrap_or(focused_id);
286
287        #[cfg(debug_assertions)]
288        tracing::info!("Focused {:?} node.", self.focused_id);
289    }
290
291    /// Send the necessary wheel events to scroll views so that the given focused [NodeId] is visible on screen.
292    fn scroll_to(
293        &self,
294        node_id: NodeId,
295        tree: &mut Tree,
296        events_sender: &futures_channel::mpsc::UnboundedSender<EventsChunk>,
297    ) {
298        let Some(effect_state) = tree.effect_state.get(&node_id) else {
299            return;
300        };
301        let mut target_node = node_id;
302        let mut emmitable_events = Vec::new();
303        // Iterate over the inherited scrollables from the closes to the farthest
304        for closest_scrollable in effect_state.scrollables.iter().rev() {
305            // Every scrollable has a target node, the first scrollable target is the focused node that we want to make visible,
306            // the rest scrollables will in the other hand just have the previous scrollable as target
307            let target_layout_node = tree.layout.get(&target_node).unwrap();
308            let target_area = target_layout_node.area;
309            let scrollable_layout_node = tree.layout.get(closest_scrollable).unwrap();
310            let scrollable_target_area = scrollable_layout_node.area;
311
312            // We only want to scroll if it is not visible
313            if !effect_state.is_visible(&tree.layout, &target_area) {
314                let element = tree.elements.get(closest_scrollable).unwrap();
315                let scroll_x = element
316                    .accessibility()
317                    .builder
318                    .scroll_x()
319                    .unwrap_or_default() as f32;
320                let scroll_y = element
321                    .accessibility()
322                    .builder
323                    .scroll_y()
324                    .unwrap_or_default() as f32;
325
326                // Get the relative diff from where the scrollable scroll starts
327                let diff_x = target_area.min_x() - scrollable_target_area.min_x() - scroll_x;
328                let diff_y = target_area.min_y() - scrollable_target_area.min_y() - scroll_y;
329
330                // And get the distance it needs to scroll in order to make the target visible
331                let delta_y = -(scroll_y + diff_y);
332                let delta_x = -(scroll_x + diff_x);
333                emmitable_events.push(EmmitableEvent {
334                    name: EventName::Wheel,
335                    source_event: EventName::Wheel,
336                    node_id: *closest_scrollable,
337                    data: EventType::Wheel(WheelEventData::new(
338                        delta_x as f64,
339                        delta_y as f64,
340                        WheelSource::Custom,
341                        CursorPoint::default(),
342                        CursorPoint::default(),
343                    )),
344                    bubbles: false,
345                });
346                // Change the target to the current scrollable, so that the next scrollable makes sure this one is visible
347                target_node = *closest_scrollable;
348            }
349        }
350        events_sender
351            .unbounded_send(EventsChunk::Processed(ProcessedEvents {
352                emmitable_events,
353                ..Default::default()
354            }))
355            .unwrap();
356    }
357
358    /// Create an accessibility node
359    pub fn create_node(node_id: NodeId, layout_node: &LayoutNode, tree: &Tree) -> Node {
360        let element = tree.elements.get(&node_id).unwrap();
361        let mut accessibility_data = element.accessibility().into_owned();
362
363        if node_id == NodeId::ROOT {
364            accessibility_data.builder.set_role(Role::Window);
365        }
366
367        // Set children
368        let children = tree
369            .children
370            .get(&node_id)
371            .cloned()
372            .unwrap_or_default()
373            .into_iter()
374            .map(|child| tree.accessibility_state.get(&child).unwrap().a11y_id)
375            .collect::<Vec<_>>();
376        accessibility_data.builder.set_children(children);
377
378        // Set the area
379        let area = layout_node.area.to_f64();
380        accessibility_data.builder.set_bounds(Rect {
381            x0: area.min_x(),
382            x1: area.max_x(),
383            y0: area.min_y(),
384            y1: area.max_y(),
385        });
386
387        // Set inner text
388        if let Some(children) = tree.children.get(&node_id) {
389            for child in children {
390                let child_element = tree.elements.get(child).unwrap().as_ref() as &dyn Any;
391                if let Some(label) = child_element.downcast_ref::<LabelElement>() {
392                    accessibility_data.builder.set_label(label.text.as_ref());
393                } else if let Some(paragraph) = child_element.downcast_ref::<ParagraphElement>() {
394                    accessibility_data.builder.set_label(
395                        paragraph
396                            .spans
397                            .iter()
398                            .map(|span| span.text.as_ref())
399                            .collect::<String>(),
400                    );
401                }
402            }
403        }
404
405        // Set focusable action
406        // This will cause assistive technology to offer the user an option
407        // to focus the current element if it supports it.
408        if accessibility_data.a11y_focusable.is_enabled() {
409            accessibility_data.builder.add_action(Action::Focus);
410        }
411
412        let builder = &mut accessibility_data.builder;
413
414        if let Some(effect_state) = tree.effect_state.get(&node_id) {
415            if let Some(rotation) = effect_state.rotation {
416                let rotation = (rotation as f64).to_radians();
417                let (sin, cos) = rotation.sin_cos();
418                builder.set_transform(Affine::new([cos, sin, -sin, cos, 0.0, 0.0]));
419            }
420
421            if effect_state.overflow == Overflow::Clip {
422                builder.set_clips_children();
423            }
424        }
425
426        if let Some(background) = element.style().background.as_color() {
427            builder.set_background_color(color_to_accesskit(background));
428        }
429
430        let element = element.as_ref() as &dyn Any;
431        let is_text_element = element.is::<LabelElement>() || element.is::<ParagraphElement>();
432        if !is_text_element {
433            builder.set_is_line_breaking_object();
434        }
435
436        if let Some(text_style) = tree.text_style_state.get(&node_id) {
437            if let Some(color) = text_style.color.as_color() {
438                builder.set_foreground_color(color_to_accesskit(color));
439            }
440
441            builder.set_font_size(f32::from(text_style.font_size));
442            builder.set_font_weight(f32::from(text_style.font_weight));
443            builder.set_font_family(text_style.font_families.join(", "));
444
445            if matches!(
446                text_style.font_slant,
447                FontSlant::Italic | FontSlant::Oblique
448            ) {
449                builder.set_italic();
450            }
451
452            builder.set_text_align(match text_style.text_align {
453                TextAlign::Center => accesskit::TextAlign::Center,
454                TextAlign::Justify => accesskit::TextAlign::Justify,
455                // TODO: change the representation of `Start` and `End` once writing modes are supported.
456                TextAlign::Left | TextAlign::Start => accesskit::TextAlign::Left,
457                TextAlign::Right | TextAlign::End => accesskit::TextAlign::Right,
458            });
459
460            // TODO: adjust this once text directions other than left to right are supported.
461            builder.set_text_direction(accesskit::TextDirection::LeftToRight);
462
463            let decoration = accesskit::TextDecoration {
464                style: accesskit::TextDecorationStyle::Solid,
465                color: text_style
466                    .color
467                    .as_color()
468                    .map(color_to_accesskit)
469                    .unwrap_or(color_to_accesskit(Color::BLACK)),
470            };
471            match text_style.text_decoration {
472                TextDecoration::Underline => builder.set_underline(decoration),
473                TextDecoration::Overline => builder.set_overline(decoration),
474                TextDecoration::LineThrough => builder.set_strikethrough(decoration),
475                TextDecoration::None => {}
476            }
477        }
478
479        accessibility_data.builder
480    }
481}
482
483/// Convert a Freya [Color] into its [accesskit::Color] equivalent.
484fn color_to_accesskit(color: Color) -> accesskit::Color {
485    accesskit::Color {
486        red: color.r(),
487        green: color.g(),
488        blue: color.b(),
489        alpha: color.a(),
490    }
491}