Skip to content

Commit

Permalink
Update try_op.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
crStiv authored Feb 11, 2025
1 parent 39ce2bc commit af88429
Showing 1 changed file with 41 additions and 43 deletions.
84 changes: 41 additions & 43 deletions rig-core/src/pipeline/try_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -345,36 +345,33 @@ where
}
}

pub struct TryParallel<Op1, Op2> {
op1: Op1,
op2: Op2,
}

impl<Op1, Op2> TryParallel<Op1, Op2> {
pub fn new(op1: Op1, op2: Op2) -> Self {
Self { op1, op2 }
}
}

impl<Op1, Op2> TryOp for TryParallel<Op1, Op2>
where
Op1: TryOp,
Op1::Input: Clone,
Op2: TryOp<Input = Op1::Input, Error = Op1::Error>,
{
type Input = Op1::Input;
type Output = (Op1::Output, Op2::Output);
type Error = Op1::Error;

#[inline]
async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
use futures::try_join;
try_join!(
self.op1.try_call(input.clone()),
self.op2.try_call(input)
)
}
}
// TODO: Implement TryParallel
// pub struct TryParallel<Op1, Op2> {
// op1: Op1,
// op2: Op2,
// }

// impl<Op1, Op2> TryParallel<Op1, Op2> {
// pub fn new(op1: Op1, op2: Op2) -> Self {
// Self { op1, op2 }
// }
// }

// impl<Op1, Op2> TryOp for TryParallel<Op1, Op2>
// where
// Op1: TryOp,
// Op2: TryOp<Input = Op1::Input, Output = Op1::Output, Error = Op1::Error>,
// {
// type Input = Op1::Input;
// type Output = (Op1::Output, Op2::Output);
// type Error = Op1::Error;

// #[inline]
// async fn try_call(&self, input: Self::Input) -> Result<Self::Output, Self::Error> {
// let (output1, output2) = tokio::join!(self.op1.try_call(input.clone()), self.op2.try_call(input));
// Ok((output1?, output2?))
// }
// }

#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -478,8 +475,20 @@ mod tests {

#[tokio::test]
async fn test_try_parallel() {
let op1 = map(|x: i32| if x % 2 == 0 { Ok(x + 1) } else { Err("x is odd") });
let op2 = map(|x: i32| if x % 2 == 0 { Ok(x * 2) } else { Err("x is odd") });
let op1 = map(|x: i32| {
if x % 2 == 0 {
Ok(x + 1)
} else {
Err("x is odd")
}
});
let op2 = map(|x: i32| {
if x % 2 == 0 {
Ok(x * 2)
} else {
Err("x is odd")
}
});
let pipeline = TryParallel::new(op1, op2);

let result = pipeline.try_call(2).await;
Expand All @@ -488,15 +497,4 @@ mod tests {
let result = pipeline.try_call(1).await;
assert_eq!(result, Err("x is odd"));
}

#[tokio::test]
async fn test_try_parallel_nested() {
let op1 = map(|x: i32| Ok::<_, &str>(x + 1));
let op2 = map(|x: i32| Ok::<_, &str>(x * 2));
let op3 = map(|x: i32| Ok::<_, &str>(x * 3));
let pipeline = TryParallel::new(TryParallel::new(op1, op2), op3);

let result = pipeline.try_call(2).await;
assert_eq!(result, Ok(((3, 4), 6)));
}
}

0 comments on commit af88429

Please sign in to comment.