mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
refactor: decompose orchestrator handler, deduplicate headers, fix unwraps
This commit is contained in:
parent
2c7d3a9c6c
commit
dd74df6ca8
6 changed files with 166 additions and 134 deletions
|
|
@ -89,7 +89,7 @@ pub async fn agent_chat(
|
||||||
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
response.headers_mut().insert(
|
response.headers_mut().insert(
|
||||||
hyper::header::CONTENT_TYPE,
|
hyper::header::CONTENT_TYPE,
|
||||||
"application/json".parse().unwrap(),
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
return Ok(response);
|
return Ok(response);
|
||||||
}
|
}
|
||||||
|
|
@ -102,16 +102,22 @@ pub async fn agent_chat(
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_agent_chat_inner(
|
/// Parsed and validated agent request data.
|
||||||
|
struct AgentRequest {
|
||||||
|
client_request: ProviderRequestType,
|
||||||
|
messages: Vec<OpenAIMessage>,
|
||||||
|
request_headers: hyper::HeaderMap,
|
||||||
|
request_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse the incoming HTTP request, resolve the listener, and extract messages.
|
||||||
|
async fn parse_agent_request(
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
state: Arc<AppState>,
|
state: &AppState,
|
||||||
request_id: String,
|
request_id: &str,
|
||||||
custom_attrs: std::collections::HashMap<String, String>,
|
custom_attrs: &std::collections::HashMap<String, String>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> {
|
||||||
// Initialize services
|
|
||||||
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
|
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
|
||||||
let mut pipeline_processor = PipelineProcessor::default();
|
|
||||||
let response_handler = ResponseHandler::new();
|
|
||||||
|
|
||||||
// Extract listener name from headers
|
// Extract listener name from headers
|
||||||
let listener_name = request
|
let listener_name = request
|
||||||
|
|
@ -129,7 +135,7 @@ async fn handle_agent_chat_inner(
|
||||||
|
|
||||||
get_active_span(|span| {
|
get_active_span(|span| {
|
||||||
span.update_name(listener.name.to_string());
|
span.update_name(listener.name.to_string());
|
||||||
for (key, value) in &custom_attrs {
|
for (key, value) in custom_attrs {
|
||||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
@ -147,12 +153,10 @@ async fn handle_agent_chat_inner(
|
||||||
let mut headers = request.headers().clone();
|
let mut headers = request.headers().clone();
|
||||||
headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER);
|
headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER);
|
||||||
|
|
||||||
// Set the request_id in headers if not already present
|
|
||||||
if !headers.contains_key(common::consts::REQUEST_ID_HEADER) {
|
if !headers.contains_key(common::consts::REQUEST_ID_HEADER) {
|
||||||
headers.insert(
|
if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
|
||||||
common::consts::REQUEST_ID_HEADER,
|
headers.insert(common::consts::REQUEST_ID_HEADER, val);
|
||||||
hyper::header::HeaderValue::from_str(&request_id).unwrap(),
|
}
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
headers
|
headers
|
||||||
|
|
@ -165,7 +169,6 @@ async fn handle_agent_chat_inner(
|
||||||
"received request body"
|
"received request body"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Determine the API type from the endpoint
|
|
||||||
let api_type =
|
let api_type =
|
||||||
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
|
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
|
||||||
let err_msg = format!("Unsupported endpoint: {}", request_path);
|
let err_msg = format!("Unsupported endpoint: {}", request_path);
|
||||||
|
|
@ -173,25 +176,48 @@ async fn handle_agent_chat_inner(
|
||||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
|
let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type))
|
||||||
Ok(request) => request,
|
.map_err(|err| {
|
||||||
Err(err) => {
|
|
||||||
warn!("failed to parse request as ProviderRequestType: {}", err);
|
warn!("failed to parse request as ProviderRequestType: {}", err);
|
||||||
let err_msg = format!("Failed to parse request: {}", err);
|
AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(
|
||||||
return Err(AgentFilterChainError::RequestParsing(
|
"Failed to parse request: {}",
|
||||||
serde_json::Error::custom(err_msg),
|
err
|
||||||
));
|
)))
|
||||||
}
|
})?;
|
||||||
};
|
|
||||||
|
|
||||||
let message: Vec<OpenAIMessage> = client_request.get_messages();
|
let messages: Vec<OpenAIMessage> = client_request.get_messages();
|
||||||
|
|
||||||
let request_id = request_headers
|
let request_id = request_headers
|
||||||
.get(common::consts::REQUEST_ID_HEADER)
|
.get(common::consts::REQUEST_ID_HEADER)
|
||||||
.and_then(|val| val.to_str().ok())
|
.and_then(|val| val.to_str().ok())
|
||||||
.map(|s| s.to_string());
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
// Create agent map for pipeline processing and agent selection
|
Ok((
|
||||||
|
AgentRequest {
|
||||||
|
client_request,
|
||||||
|
messages,
|
||||||
|
request_headers,
|
||||||
|
request_id,
|
||||||
|
},
|
||||||
|
listener,
|
||||||
|
agent_selector,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Select agents via the orchestrator model and record selection metrics.
|
||||||
|
async fn select_and_build_agent_map(
|
||||||
|
agent_selector: &AgentSelector,
|
||||||
|
state: &AppState,
|
||||||
|
messages: &[OpenAIMessage],
|
||||||
|
listener: &common::configuration::Listener,
|
||||||
|
request_id: Option<String>,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
Vec<common::configuration::AgentFilterChain>,
|
||||||
|
std::collections::HashMap<String, common::configuration::Agent>,
|
||||||
|
),
|
||||||
|
AgentFilterChainError,
|
||||||
|
> {
|
||||||
let agent_map = {
|
let agent_map = {
|
||||||
let agents = state.agents_list.read().await;
|
let agents = state.agents_list.read().await;
|
||||||
let agents = agents.as_ref().ok_or_else(|| {
|
let agents = agents.as_ref().ok_or_else(|| {
|
||||||
|
|
@ -200,13 +226,11 @@ async fn handle_agent_chat_inner(
|
||||||
agent_selector.create_agent_map(agents)
|
agent_selector.create_agent_map(agents)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Select appropriate agents using arch orchestrator llm model
|
|
||||||
let selection_start = Instant::now();
|
let selection_start = Instant::now();
|
||||||
let selected_agents = agent_selector
|
let selected_agents = agent_selector
|
||||||
.select_agents(&message, &listener, request_id.clone())
|
.select_agents(messages, listener, request_id)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Record selection attributes on the current orchestrator span
|
|
||||||
let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0;
|
let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0;
|
||||||
get_active_span(|span| {
|
get_active_span(|span| {
|
||||||
span.set_attribute(opentelemetry::KeyValue::new(
|
span.set_attribute(opentelemetry::KeyValue::new(
|
||||||
|
|
@ -236,12 +260,25 @@ async fn handle_agent_chat_inner(
|
||||||
"selected agents for execution"
|
"selected agents for execution"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Execute agents sequentially, passing output from one to the next
|
Ok((selected_agents, agent_map))
|
||||||
let mut current_messages = message.clone();
|
}
|
||||||
|
|
||||||
|
/// Execute the agent chain: run each selected agent sequentially, streaming
|
||||||
|
/// the final agent's response back to the client.
|
||||||
|
async fn execute_agent_chain(
|
||||||
|
selected_agents: &[common::configuration::AgentFilterChain],
|
||||||
|
agent_map: &std::collections::HashMap<String, common::configuration::Agent>,
|
||||||
|
client_request: ProviderRequestType,
|
||||||
|
messages: Vec<OpenAIMessage>,
|
||||||
|
request_headers: &hyper::HeaderMap,
|
||||||
|
custom_attrs: &std::collections::HashMap<String, String>,
|
||||||
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||||
|
let mut pipeline_processor = PipelineProcessor::default();
|
||||||
|
let response_handler = ResponseHandler::new();
|
||||||
|
let mut current_messages = messages;
|
||||||
let agent_count = selected_agents.len();
|
let agent_count = selected_agents.len();
|
||||||
|
|
||||||
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
|
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
|
||||||
// Get agent name
|
|
||||||
let agent_name = selected_agent.id.clone();
|
let agent_name = selected_agent.id.clone();
|
||||||
let is_last_agent = agent_index == agent_count - 1;
|
let is_last_agent = agent_index == agent_count - 1;
|
||||||
|
|
||||||
|
|
@ -252,17 +289,15 @@ async fn handle_agent_chat_inner(
|
||||||
"processing agent"
|
"processing agent"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Process the filter chain
|
|
||||||
let chat_history = pipeline_processor
|
let chat_history = pipeline_processor
|
||||||
.process_filter_chain(
|
.process_filter_chain(
|
||||||
¤t_messages,
|
¤t_messages,
|
||||||
selected_agent,
|
selected_agent,
|
||||||
&agent_map,
|
agent_map,
|
||||||
&request_headers,
|
request_headers,
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Get agent details and invoke
|
|
||||||
let agent = agent_map.get(&agent_name).ok_or_else(|| {
|
let agent = agent_map.get(&agent_name).ok_or_else(|| {
|
||||||
AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(
|
AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(
|
||||||
"Selected agent '{}' not found in configuration",
|
"Selected agent '{}' not found in configuration",
|
||||||
|
|
@ -282,7 +317,7 @@ async fn handle_agent_chat_inner(
|
||||||
set_service_name(operation_component::AGENT);
|
set_service_name(operation_component::AGENT);
|
||||||
get_active_span(|span| {
|
get_active_span(|span| {
|
||||||
span.update_name(format!("{} /v1/chat/completions", agent_name));
|
span.update_name(format!("{} /v1/chat/completions", agent_name));
|
||||||
for (key, value) in &custom_attrs {
|
for (key, value) in custom_attrs {
|
||||||
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
@ -292,28 +327,25 @@ async fn handle_agent_chat_inner(
|
||||||
&chat_history,
|
&chat_history,
|
||||||
client_request.clone(),
|
client_request.clone(),
|
||||||
agent,
|
agent,
|
||||||
&request_headers,
|
request_headers,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
.instrument(agent_span.clone())
|
.instrument(agent_span.clone())
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// If this is the last agent, return the streaming response
|
|
||||||
if is_last_agent {
|
if is_last_agent {
|
||||||
info!(
|
info!(
|
||||||
agent = %agent_name,
|
agent = %agent_name,
|
||||||
"completed agent chain, returning response"
|
"completed agent chain, returning response"
|
||||||
);
|
);
|
||||||
// Capture the orchestrator span (parent of the agent span) so it
|
|
||||||
// stays open for the full streaming duration alongside the agent span.
|
|
||||||
let orchestrator_span = tracing::Span::current();
|
let orchestrator_span = tracing::Span::current();
|
||||||
return async {
|
return async {
|
||||||
response_handler
|
response_handler
|
||||||
.create_streaming_response(
|
.create_streaming_response(
|
||||||
llm_response,
|
llm_response,
|
||||||
tracing::Span::current(), // agent span (inner)
|
tracing::Span::current(),
|
||||||
orchestrator_span, // orchestrator span (outer)
|
orchestrator_span,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(AgentFilterChainError::from)
|
.map_err(AgentFilterChainError::from)
|
||||||
|
|
@ -322,7 +354,6 @@ async fn handle_agent_chat_inner(
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// For intermediate agents, collect the full response and pass to next agent
|
|
||||||
debug!(agent = %agent_name, "collecting response from intermediate agent");
|
debug!(agent = %agent_name, "collecting response from intermediate agent");
|
||||||
let response_text = async { response_handler.collect_full_response(llm_response).await }
|
let response_text = async { response_handler.collect_full_response(llm_response).await }
|
||||||
.instrument(agent_span)
|
.instrument(agent_span)
|
||||||
|
|
@ -334,14 +365,11 @@ async fn handle_agent_chat_inner(
|
||||||
"agent completed, passing response to next agent"
|
"agent completed, passing response to next agent"
|
||||||
);
|
);
|
||||||
|
|
||||||
// remove last message and add new one at the end
|
|
||||||
let Some(last_message) = current_messages.pop() else {
|
let Some(last_message) = current_messages.pop() else {
|
||||||
warn!(agent = %agent_name, "no messages in conversation history");
|
warn!(agent = %agent_name, "no messages in conversation history");
|
||||||
break;
|
break;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create a new message with the agent's response as assistant message
|
|
||||||
// and add it to the conversation history
|
|
||||||
current_messages.push(OpenAIMessage {
|
current_messages.push(OpenAIMessage {
|
||||||
role: hermesllm::apis::openai::Role::Assistant,
|
role: hermesllm::apis::openai::Role::Assistant,
|
||||||
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
|
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
|
||||||
|
|
@ -353,6 +381,34 @@ async fn handle_agent_chat_inner(
|
||||||
current_messages.push(last_message);
|
current_messages.push(last_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
// This should never be reached since we return in the last agent iteration
|
|
||||||
unreachable!("Agent execution loop should have returned a response")
|
unreachable!("Agent execution loop should have returned a response")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn handle_agent_chat_inner(
|
||||||
|
request: Request<hyper::body::Incoming>,
|
||||||
|
state: Arc<AppState>,
|
||||||
|
request_id: String,
|
||||||
|
custom_attrs: std::collections::HashMap<String, String>,
|
||||||
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
|
||||||
|
let (agent_req, listener, agent_selector) =
|
||||||
|
parse_agent_request(request, &state, &request_id, &custom_attrs).await?;
|
||||||
|
|
||||||
|
let (selected_agents, agent_map) = select_and_build_agent_map(
|
||||||
|
&agent_selector,
|
||||||
|
&state,
|
||||||
|
&agent_req.messages,
|
||||||
|
&listener,
|
||||||
|
agent_req.request_id,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
execute_agent_chain(
|
||||||
|
&selected_agents,
|
||||||
|
&agent_map,
|
||||||
|
agent_req.client_request,
|
||||||
|
agent_req.messages,
|
||||||
|
&agent_req.request_headers,
|
||||||
|
&custom_attrs,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -141,12 +141,11 @@ impl PipelineProcessor {
|
||||||
Ok(chat_history_updated)
|
Ok(chat_history_updated)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build common MCP headers for requests
|
/// Prepare headers shared by all agent/filter requests: removes
|
||||||
fn build_mcp_headers(
|
/// content-length, injects trace context, sets upstream host and retry.
|
||||||
&self,
|
fn build_agent_headers(
|
||||||
request_headers: &HeaderMap,
|
request_headers: &HeaderMap,
|
||||||
agent_id: &str,
|
agent_id: &str,
|
||||||
session_id: Option<&str>,
|
|
||||||
) -> Result<HeaderMap, PipelineError> {
|
) -> Result<HeaderMap, PipelineError> {
|
||||||
let mut headers = request_headers.clone();
|
let mut headers = request_headers.clone();
|
||||||
headers.remove(hyper::header::CONTENT_LENGTH);
|
headers.remove(hyper::header::CONTENT_LENGTH);
|
||||||
|
|
@ -167,24 +166,34 @@ impl PipelineProcessor {
|
||||||
|
|
||||||
headers.insert(
|
headers.insert(
|
||||||
ENVOY_RETRY_HEADER,
|
ENVOY_RETRY_HEADER,
|
||||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
hyper::header::HeaderValue::from_static("3"),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
Ok(headers)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build headers for MCP requests (adds Accept, Content-Type, optional session id).
|
||||||
|
fn build_mcp_headers(
|
||||||
|
&self,
|
||||||
|
request_headers: &HeaderMap,
|
||||||
|
agent_id: &str,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> Result<HeaderMap, PipelineError> {
|
||||||
|
let mut headers = Self::build_agent_headers(request_headers, agent_id)?;
|
||||||
|
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"Accept",
|
"Accept",
|
||||||
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
|
||||||
);
|
);
|
||||||
|
|
||||||
headers.insert(
|
headers.insert(
|
||||||
"Content-Type",
|
"Content-Type",
|
||||||
hyper::header::HeaderValue::from_static("application/json"),
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(sid) = session_id {
|
if let Some(sid) = session_id {
|
||||||
headers.insert(
|
if let Ok(val) = hyper::header::HeaderValue::from_str(sid) {
|
||||||
"mcp-session-id",
|
headers.insert("mcp-session-id", val);
|
||||||
hyper::header::HeaderValue::from_str(sid).unwrap(),
|
}
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(headers)
|
Ok(headers)
|
||||||
|
|
@ -530,33 +539,11 @@ impl PipelineProcessor {
|
||||||
});
|
});
|
||||||
|
|
||||||
// Build headers
|
// Build headers
|
||||||
let mut agent_headers = request_headers.clone();
|
let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?;
|
||||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
|
||||||
|
|
||||||
// Inject OpenTelemetry trace context automatically
|
|
||||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
|
||||||
global::get_text_map_propagator(|propagator| {
|
|
||||||
let cx =
|
|
||||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
|
||||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
|
||||||
});
|
|
||||||
|
|
||||||
agent_headers.insert(
|
|
||||||
ARCH_UPSTREAM_HOST_HEADER,
|
|
||||||
hyper::header::HeaderValue::from_str(&agent.id)
|
|
||||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
|
||||||
);
|
|
||||||
|
|
||||||
agent_headers.insert(
|
|
||||||
ENVOY_RETRY_HEADER,
|
|
||||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
agent_headers.insert(
|
agent_headers.insert(
|
||||||
"Accept",
|
"Accept",
|
||||||
hyper::header::HeaderValue::from_static("application/json"),
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
|
|
||||||
agent_headers.insert(
|
agent_headers.insert(
|
||||||
"Content-Type",
|
"Content-Type",
|
||||||
hyper::header::HeaderValue::from_static("application/json"),
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
|
|
@ -629,27 +616,7 @@ impl PipelineProcessor {
|
||||||
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
|
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
|
||||||
debug!("sending request to terminal agent {}", terminal_agent.id);
|
debug!("sending request to terminal agent {}", terminal_agent.id);
|
||||||
|
|
||||||
let mut agent_headers = request_headers.clone();
|
let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?;
|
||||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
|
||||||
|
|
||||||
// Inject OpenTelemetry trace context automatically
|
|
||||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
|
||||||
global::get_text_map_propagator(|propagator| {
|
|
||||||
let cx =
|
|
||||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
|
||||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
|
||||||
});
|
|
||||||
|
|
||||||
agent_headers.insert(
|
|
||||||
ARCH_UPSTREAM_HOST_HEADER,
|
|
||||||
hyper::header::HeaderValue::from_str(&terminal_agent.id)
|
|
||||||
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
|
|
||||||
);
|
|
||||||
|
|
||||||
agent_headers.insert(
|
|
||||||
ENVOY_RETRY_HEADER,
|
|
||||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
|
|
|
||||||
|
|
@ -497,13 +497,16 @@ async fn send_upstream(
|
||||||
"Routing to upstream"
|
"Routing to upstream"
|
||||||
);
|
);
|
||||||
|
|
||||||
request_headers.insert(
|
if let Ok(val) = header::HeaderValue::from_str(resolved_model) {
|
||||||
ARCH_PROVIDER_HINT_HEADER,
|
request_headers.insert(ARCH_PROVIDER_HINT_HEADER, val);
|
||||||
header::HeaderValue::from_str(resolved_model).unwrap(),
|
}
|
||||||
);
|
|
||||||
request_headers.insert(
|
request_headers.insert(
|
||||||
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
|
||||||
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
|
header::HeaderValue::from_static(if is_streaming_request {
|
||||||
|
"true"
|
||||||
|
} else {
|
||||||
|
"false"
|
||||||
|
}),
|
||||||
);
|
);
|
||||||
request_headers.remove(header::CONTENT_LENGTH);
|
request_headers.remove(header::CONTENT_LENGTH);
|
||||||
|
|
||||||
|
|
@ -535,9 +538,10 @@ async fn send_upstream(
|
||||||
let response_headers = llm_response.headers().clone();
|
let response_headers = llm_response.headers().clone();
|
||||||
let upstream_status = llm_response.status();
|
let upstream_status = llm_response.status();
|
||||||
let mut response = Response::builder().status(upstream_status);
|
let mut response = Response::builder().status(upstream_status);
|
||||||
let headers = response.headers_mut().unwrap();
|
if let Some(headers) = response.headers_mut() {
|
||||||
for (name, value) in response_headers.iter() {
|
for (name, value) in response_headers.iter() {
|
||||||
headers.insert(name, value.clone());
|
headers.insert(name, value.clone());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let byte_stream = llm_response.bytes_stream();
|
let byte_stream = llm_response.bytes_stream();
|
||||||
|
|
@ -637,8 +641,9 @@ async fn get_upstream_path(
|
||||||
) -> String {
|
) -> String {
|
||||||
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
|
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
|
||||||
|
|
||||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
|
let Some(client_api) = SupportedAPIsFromClient::from_endpoint(request_path) else {
|
||||||
.expect("Should have valid API endpoint");
|
return request_path.to_string();
|
||||||
|
};
|
||||||
|
|
||||||
client_api.target_endpoint_for_provider(
|
client_api.target_endpoint_for_provider(
|
||||||
&provider_id,
|
&provider_id,
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ impl ResponseHandler {
|
||||||
*response.status_mut() = hyper::StatusCode::BAD_REQUEST;
|
*response.status_mut() = hyper::StatusCode::BAD_REQUEST;
|
||||||
response.headers_mut().insert(
|
response.headers_mut().insert(
|
||||||
hyper::header::CONTENT_TYPE,
|
hyper::header::CONTENT_TYPE,
|
||||||
"application/json".parse().unwrap(),
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
@ -55,7 +55,7 @@ impl ResponseHandler {
|
||||||
*response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
|
*response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
|
||||||
response.headers_mut().insert(
|
response.headers_mut().insert(
|
||||||
hyper::header::CONTENT_TYPE,
|
hyper::header::CONTENT_TYPE,
|
||||||
"application/json".parse().unwrap(),
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,9 @@ impl RouterService {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
|
if usage_preferences
|
||||||
|
.as_ref()
|
||||||
|
.is_none_or(|prefs| prefs.len() < 2)
|
||||||
&& !self.llm_usage_defined
|
&& !self.llm_usage_defined
|
||||||
{
|
{
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
|
|
@ -108,18 +110,18 @@ impl RouterService {
|
||||||
header::CONTENT_TYPE,
|
header::CONTENT_TYPE,
|
||||||
header::HeaderValue::from_static("application/json"),
|
header::HeaderValue::from_static("application/json"),
|
||||||
);
|
);
|
||||||
headers.insert(
|
if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) {
|
||||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
headers.insert(
|
||||||
header::HeaderValue::from_str(&self.routing_provider_name).unwrap(),
|
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||||
);
|
val,
|
||||||
headers.insert(
|
);
|
||||||
header::HeaderName::from_static(TRACE_PARENT_HEADER),
|
}
|
||||||
header::HeaderValue::from_str(traceparent).unwrap(),
|
if let Ok(val) = header::HeaderValue::from_str(traceparent) {
|
||||||
);
|
headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val);
|
||||||
headers.insert(
|
}
|
||||||
header::HeaderName::from_static(REQUEST_ID_HEADER),
|
if let Ok(val) = header::HeaderValue::from_str(request_id) {
|
||||||
header::HeaderValue::from_str(request_id).unwrap(),
|
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
|
||||||
);
|
}
|
||||||
headers.insert(
|
headers.insert(
|
||||||
header::HeaderName::from_static("model"),
|
header::HeaderName::from_static("model"),
|
||||||
header::HeaderValue::from_static("arch-router"),
|
header::HeaderValue::from_static("arch-router"),
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,10 @@ impl OrchestratorService {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
|
if usage_preferences
|
||||||
|
.as_ref()
|
||||||
|
.is_none_or(|prefs| prefs.is_empty())
|
||||||
|
{
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -85,7 +88,7 @@ impl OrchestratorService {
|
||||||
);
|
);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
|
||||||
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
|
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Inject OpenTelemetry trace context from current span
|
// Inject OpenTelemetry trace context from current span
|
||||||
|
|
@ -96,10 +99,9 @@ impl OrchestratorService {
|
||||||
});
|
});
|
||||||
|
|
||||||
if let Some(ref request_id) = request_id {
|
if let Some(ref request_id) = request_id {
|
||||||
headers.insert(
|
if let Ok(val) = header::HeaderValue::from_str(request_id) {
|
||||||
header::HeaderName::from_static(REQUEST_ID_HEADER),
|
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
|
||||||
header::HeaderValue::from_str(request_id).unwrap(),
|
}
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
headers.insert(
|
headers.insert(
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue