Problem statement
I'd like to implement a directed acyclic computation graph framework in async Rust, i.e. an interconnected graph of computation "nodes", each of which takes inputs from predecessor nodes and produces outputs for successor nodes. I was planning to implement this by spawning a collection of Futures, one for each computation node, while allowing dependencies among futures. However, in implementing this framework using async I've become hopelessly lost in compiler errors.
Minimal example
Here's an attempt at a minimal example of what I want to do. There's a single input list of floats, values, and the task is to make a new list, output, where output[i] = values[i] + output[i - 2] (and output[i] = values[i] for i < 2). This is what I've tried:
use std::sync;

fn some_complicated_expensive_fn(val1: f32, val2: f32) -> f32 {
    val1 + val2
}

fn example_async(values: &Vec<f32>) -> Vec<f32> {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    let join_handles = sync::Arc::new(sync::Mutex::new(Vec::<tokio::task::JoinHandle<f32>>::new()));
    for (i, value) in values.iter().enumerate() {
        let future = {
            let join_handles = join_handles.clone();
            async move {
                if i < 2 {
                    *value
                } else {
                    let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
                    some_complicated_expensive_fn(*value, prev_value)
                }
            }
        };
        join_handles.lock().unwrap().push(runtime.spawn(future));
    }
    join_handles
        .lock()
        .unwrap()
        .iter_mut()
        .map(|join_handle| runtime.block_on(join_handle).unwrap())
        .collect()
}

#[cfg(test)]
mod tests {
    #[test]
    fn test_example() {
        let values = vec![1., 2., 3., 4., 5., 6.];
        println!("{:?}", super::example_async(&values));
    }
}
I get an error about the MutexGuard (the locked Mutex) not being Send:
error: future cannot be sent between threads safely
  --> sim/src/compsim/runtime.rs:23:51
   |
23 |         join_handles.lock().unwrap().push(runtime.spawn(future));
   |                                                   ^^^^^ future created by async block is not `Send`
   |
   = help: within `impl Future`, the trait `Send` is not implemented for `std::sync::MutexGuard<'_, Vec<tokio::task::JoinHandle<f32>>>`
note: future is not `Send` as this value is used across an await
  --> sim/src/compsim/runtime.rs:18:38
   |
18 |                     let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
   |                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ first, await occurs here, with `join_handles.lock().unwrap()` maybe used later...
note: `join_handles.lock().unwrap()` is later dropped here
  --> sim/src/compsim/runtime.rs:18:88
   |
18 |                     let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
   |                                      ----------------------------                      ^
   |                                      |
   |                                      has type `std::sync::MutexGuard<'_, Vec<tokio::task::JoinHandle<f32>>>` which is not `Send`
help: consider moving this into a `let` binding to create a shorter lived borrow
  --> sim/src/compsim/runtime.rs:18:38
   |
18 |                     let prev_value = join_handles.lock().unwrap()[i - 2].await.unwrap();
   |                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This makes sense, and I see in the Tokio docs that you can use a tokio::sync::Mutex instead, but a) I'm not sure how, and b) I'm wondering if there's a better overall approach that I'm missing. Help greatly appreciated! Thanks.
Answer:
The compiler is complaining that the future holds the MutexGuard for join_handles across an await point. You could resolve this by making the lock shorter-lived, e.g. by keeping the handles in a Mutex<Vec<Option<JoinHandle<f32>>>>, taking the handle out, awaiting it with the lock released, and putting it back afterwards. But then you run into the issue that awaiting a JoinHandle consumes it: you receive the value that was returned by the task, and you lose the handle, so you can't return it to the vector. (This is a consequence of Rust's move semantics: once you have the value, the handle no longer has it, and it's as good as dead.)
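The "take the handle, lose the handle" problem can be sketched without any async machinery. The snippet below is only an analogy: it uses std::thread::JoinHandle as a stand-in for Tokio's JoinHandle (both are consumed when you wait on them by value), and take_and_join and read_twice are hypothetical helper names, not part of the question's code.

```rust
use std::sync::Mutex;
use std::thread::{self, JoinHandle};

/// Takes the handle out of slot `i` and waits for its result.
/// Returns `None` if an earlier caller already took the handle.
fn take_and_join(slots: &Mutex<Vec<Option<JoinHandle<f32>>>>, i: usize) -> Option<f32> {
    // The guard is a temporary dropped at the end of this statement,
    // so the lock is not held while waiting on the handle.
    let handle = slots.lock().unwrap()[i].take()?;
    // `join` takes `self`: the handle is consumed and cannot be put back.
    Some(handle.join().unwrap())
}

/// Spawns one thread that produces `value`, then tries to read its result twice.
fn read_twice(value: f32) -> (Option<f32>, Option<f32>) {
    let slots = Mutex::new(vec![Some(thread::spawn(move || value))]);
    let first = take_and_join(&slots, 0);
    let second = take_and_join(&slots, 0);
    (first, second)
}
```

Calling read_twice(3.0) gives (Some(3.0), None): the first reader gets the value, and the second finds an empty slot. In the pipeline each result has two readers (the successor node and the final collection), which is why a plain handle is not enough.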
A handle basically works like a one-shot channel for the result of the spawned task. But you need channels that can be read twice: once by another member of the pipeline, and again by example_async() itself. Such a channel requires clonable values, and has to hold on to each value until all subscribers have observed it. Fortunately, Tokio provides exactly such a channel: tokio::sync::broadcast. Instead of returning its result, each task sends it on its own broadcast channel, which is read once by the dependent task and once more when collecting the results.
use std::sync::{Arc, Mutex};

fn example_async(values: &Vec<f32>) -> Vec<f32> {
    let runtime = tokio::runtime::Runtime::new().unwrap();
    // One broadcast channel per node; rxs1 holds the receivers used to collect results.
    let (txs, rxs1): (Vec<_>, Vec<_>) = (0..values.len())
        .map(|_| tokio::sync::broadcast::channel(1))
        .unzip();
    let txs = Arc::new(txs);
    // A second receiver per node, to be taken by the task that depends on that node.
    let rxs2: Arc<Vec<_>> = Arc::new(
        txs.iter()
            .map(|tx| Mutex::new(Some(tx.subscribe())))
            .collect(),
    );
    for (i, value) in values.iter().copied().enumerate() {
        let future = {
            let rxs2 = Arc::clone(&rxs2);
            let txs = Arc::clone(&txs);
            async move {
                let result = if i < 2 {
                    value
                } else {
                    // The guard is dropped at the end of this statement, before the await.
                    let mut prev_rx = rxs2[i - 2].lock().unwrap().take().unwrap();
                    let prev_value = prev_rx.recv().await.unwrap();
                    some_complicated_expensive_fn(value, prev_value)
                };
                // Publish the result to both subscribers of node i.
                txs[i].send(result).unwrap();
            }
        };
        runtime.spawn(future);
    }
    // Collect the results in node order.
    rxs1.into_iter()
        .map(|mut rx| runtime.block_on(rx.recv()).unwrap())
        .collect()
}
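To sanity-check the concurrent version, it can help to compare its output against a plain sequential reference. This is a minimal sketch that assumes the combining function is a plain sum and that the first two outputs simply copy the inputs; combine and example_sequential are hypothetical names introduced here for illustration.

```rust
/// Stand-in for the expensive function; assumed here to be a plain sum.
fn combine(val1: f32, val2: f32) -> f32 {
    val1 + val2
}

/// Sequential reference for the recurrence:
/// output[i] = values[i] for i < 2, otherwise combine(values[i], output[i - 2]).
fn example_sequential(values: &[f32]) -> Vec<f32> {
    let mut output = Vec::with_capacity(values.len());
    for (i, &value) in values.iter().enumerate() {
        let result = if i < 2 {
            value
        } else {
            combine(value, output[i - 2])
        };
        output.push(result);
    }
    output
}
```

With the input from the question's test, example_sequential(&[1., 2., 3., 4., 5., 6.]) gives [1.0, 2.0, 4.0, 6.0, 9.0, 12.0], which is what example_async should produce under the same assumption.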
