Skip to main content

freya_core/lifecycle/
task.rs

1use std::{
2    cell::RefCell,
3    pin::Pin,
4    rc::Rc,
5    sync::{
6        Arc,
7        atomic::Ordering,
8    },
9};
10
11use crate::{
12    current_context::CurrentContext,
13    prelude::current_scope_id,
14    runner::Message,
15    scope_id::ScopeId,
16};
17
18pub fn spawn_forever(future: impl Future<Output = ()> + 'static) -> TaskHandle {
19    CurrentContext::with(|context| {
20        let task_id = TaskId(context.task_id_counter.fetch_add(1, Ordering::Relaxed));
21        context.tasks.borrow_mut().insert(
22            task_id,
23            Rc::new(RefCell::new(Task {
24                scope_id: ScopeId::ROOT,
25                future: Box::pin(future),
26                waker: futures_util::task::waker(Arc::new(TaskWaker {
27                    task_id,
28                    sender: context.sender.clone(),
29                })),
30            })),
31        );
32        context
33            .sender
34            .unbounded_send(Message::PollTask(task_id))
35            .unwrap();
36        task_id.into()
37    })
38}
39
40pub fn spawn(future: impl Future<Output = ()> + 'static) -> TaskHandle {
41    CurrentContext::with(|context| {
42        let task_id = TaskId(context.task_id_counter.fetch_add(1, Ordering::Relaxed));
43        context.tasks.borrow_mut().insert(
44            task_id,
45            Rc::new(RefCell::new(Task {
46                scope_id: current_scope_id(),
47                future: Box::pin(future),
48                waker: futures_util::task::waker(Arc::new(TaskWaker {
49                    task_id,
50                    sender: context.sender.clone(),
51                })),
52            })),
53        );
54        context
55            .sender
56            .unbounded_send(Message::PollTask(task_id))
57            .unwrap();
58        task_id.into()
59    })
60}
61
62#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
63pub struct TaskHandle(TaskId);
64
65impl From<TaskId> for TaskHandle {
66    fn from(value: TaskId) -> Self {
67        TaskHandle(value)
68    }
69}
70
71impl TaskHandle {
72    pub fn cancel(&self) {
73        CurrentContext::with(|context| context.tasks.borrow_mut().remove(&self.0));
74    }
75
76    pub fn try_cancel(&self) {
77        CurrentContext::try_with(|context| context.tasks.borrow_mut().remove(&self.0));
78    }
79
80    /// Upgrade to an [OwnedTaskHandle] that cancels the task when the last clone is dropped.
81    pub fn owned(self) -> OwnedTaskHandle {
82        OwnedTaskHandle(Rc::new(InnerOwnedTaskHandle(self)))
83    }
84}
85
86struct InnerOwnedTaskHandle(TaskHandle);
87
88impl Drop for InnerOwnedTaskHandle {
89    fn drop(&mut self) {
90        self.0.try_cancel();
91    }
92}
93
94/// A task handle that cancels the task when the last clone is dropped.
95#[derive(Clone)]
96pub struct OwnedTaskHandle(Rc<InnerOwnedTaskHandle>);
97
98impl PartialEq for OwnedTaskHandle {
99    fn eq(&self, other: &Self) -> bool {
100        Rc::ptr_eq(&self.0, &other.0)
101    }
102}
103
104impl OwnedTaskHandle {
105    pub fn cancel(&self) {
106        self.0.0.cancel();
107    }
108
109    pub fn try_cancel(&self) {
110        self.0.0.try_cancel();
111    }
112
113    /// Downgrade to a non-owning [TaskHandle].
114    pub fn downgrade(&self) -> TaskHandle {
115        self.0.0
116    }
117}
118
119pub struct TaskWaker {
120    task_id: TaskId,
121    sender: futures_channel::mpsc::UnboundedSender<Message>,
122}
123
124impl futures_util::task::ArcWake for TaskWaker {
125    fn wake_by_ref(arc_self: &Arc<Self>) {
126        _ = arc_self
127            .sender
128            .unbounded_send(Message::PollTask(arc_self.task_id));
129    }
130}
131
132pub struct Task {
133    pub scope_id: ScopeId,
134    pub future: Pin<Box<dyn Future<Output = ()>>>,
135    /// Used to notify the runner that this task needs progress.
136    pub waker: futures_util::task::Waker,
137}
138
/// Opaque identifier of a spawned task, allocated from the context's task
/// counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(u64);