From 2d69d3154146a8bee96ef9549ce86eec99560536 Mon Sep 17 00:00:00 2001 From: Nir Gazit Date: Tue, 12 Nov 2024 23:32:00 +0200 Subject: [PATCH] feat: implement pipeline steering logic (#5) --- src/routes.rs | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/routes.rs b/src/routes.rs index 431d468..9390f7a 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -1,16 +1,30 @@ use crate::{pipelines::pipeline::create_pipeline, state::AppState}; use axum::{extract::Request, routing::get, Router}; +use std::collections::HashMap; use std::sync::Arc; use tower::steer::Steer; pub fn create_router(state: Arc) -> Router { - let routers = state - .config - .pipelines - .iter() - .map(|pipeline| create_pipeline(pipeline, &state.model_registry)) - .collect::>(); - let pipeline_router = Steer::new(routers, |_req: &Request, _services: &[_]| 0); + let mut pipeline_idxs = HashMap::new(); + let mut routers = Vec::new(); + + // Sort pipelines to ensure default is first + let mut sorted_pipelines: Vec<_> = state.config.pipelines.clone(); + sorted_pipelines.sort_by_key(|p| p.name != "default"); // "default" will come first since false < true + + for pipeline in sorted_pipelines { + let name = pipeline.name.clone(); + pipeline_idxs.insert(name, routers.len()); + routers.push(create_pipeline(&pipeline, &state.model_registry)); + } + + let pipeline_router = Steer::new(routers, move |req: &Request, _services: &[_]| { + *req.headers() + .get("x-traceloop-pipeline") + .and_then(|h| h.to_str().ok()) + .and_then(|name| pipeline_idxs.get(name)) + .unwrap_or(&0) + }); Router::new() .nest_service("/api/v1", pipeline_router)