mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
add MCP raw filter support: body+path args, update mcp_filter demo handlers
This commit is contained in:
parent
d26abbfb9c
commit
b88bdb94f2
16 changed files with 226 additions and 116 deletions
|
|
@ -85,7 +85,7 @@ properties:
|
||||||
type: string
|
type: string
|
||||||
default:
|
default:
|
||||||
type: boolean
|
type: boolean
|
||||||
filter_chain:
|
input_filters:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
|
|
|
||||||
|
|
@ -332,15 +332,38 @@ async fn handle_agent_chat_inner(
|
||||||
"processing agent"
|
"processing agent"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Process the filter chain
|
// Process input filters — serialize current request as OpenAI chat completions body,
|
||||||
let chat_history = pipeline_processor
|
// pass raw bytes through each filter, then extract updated messages from the result.
|
||||||
.process_filter_chain(
|
let chat_history = if selected_agent
|
||||||
&current_messages,
|
.input_filters
|
||||||
selected_agent,
|
.as_ref()
|
||||||
&agent_map,
|
.map(|f| !f.is_empty())
|
||||||
&request_headers,
|
.unwrap_or(false)
|
||||||
)
|
{
|
||||||
.await?;
|
let filter_body = serde_json::json!({
|
||||||
|
"model": client_request.model(),
|
||||||
|
"messages": current_messages,
|
||||||
|
});
|
||||||
|
let filter_bytes =
|
||||||
|
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
|
||||||
|
|
||||||
|
let filtered_bytes = pipeline_processor
|
||||||
|
.process_raw_filter_chain(
|
||||||
|
&filter_bytes,
|
||||||
|
selected_agent,
|
||||||
|
&agent_map,
|
||||||
|
&request_headers,
|
||||||
|
"/v1/chat/completions",
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let filtered_body: serde_json::Value =
|
||||||
|
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
|
||||||
|
serde_json::from_value(filtered_body["messages"].clone())
|
||||||
|
.map_err(PipelineError::ParseError)?
|
||||||
|
} else {
|
||||||
|
current_messages.clone()
|
||||||
|
};
|
||||||
|
|
||||||
// Get agent details and invoke
|
// Get agent details and invoke
|
||||||
let agent = agent_map.get(&agent_name).unwrap();
|
let agent = agent_map.get(&agent_name).unwrap();
|
||||||
|
|
|
||||||
|
|
@ -187,7 +187,7 @@ mod tests {
|
||||||
id: name.to_string(),
|
id: name.to_string(),
|
||||||
description: Some(description.to_string()),
|
description: Some(description.to_string()),
|
||||||
default: Some(is_default),
|
default: Some(is_default),
|
||||||
filter_chain: Some(vec![name.to_string()]),
|
input_filters: Some(vec![name.to_string()]),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ mod tests {
|
||||||
|
|
||||||
let agent_pipeline = AgentFilterChain {
|
let agent_pipeline = AgentFilterChain {
|
||||||
id: "terminal-agent".to_string(),
|
id: "terminal-agent".to_string(),
|
||||||
filter_chain: Some(vec![
|
input_filters: Some(vec![
|
||||||
"filter-agent".to_string(),
|
"filter-agent".to_string(),
|
||||||
"terminal-agent".to_string(),
|
"terminal-agent".to_string(),
|
||||||
]),
|
]),
|
||||||
|
|
@ -110,7 +110,7 @@ mod tests {
|
||||||
// Create a pipeline with empty filter chain to avoid network calls
|
// Create a pipeline with empty filter chain to avoid network calls
|
||||||
let test_pipeline = AgentFilterChain {
|
let test_pipeline = AgentFilterChain {
|
||||||
id: "terminal-agent".to_string(),
|
id: "terminal-agent".to_string(),
|
||||||
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
|
input_filters: Some(vec![]), // Empty filter chain - no network calls needed
|
||||||
description: None,
|
description: None,
|
||||||
default: None,
|
default: None,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -279,7 +279,7 @@ async fn llm_chat_inner(
|
||||||
id: "model_listener".to_string(),
|
id: "model_listener".to_string(),
|
||||||
default: None,
|
default: None,
|
||||||
description: None,
|
description: None,
|
||||||
filter_chain: Some(fc.clone()),
|
input_filters: Some(fc.clone()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut pipeline_processor = PipelineProcessor::default();
|
let mut pipeline_processor = PipelineProcessor::default();
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ impl PipelineProcessor {
|
||||||
// #[instrument(
|
// #[instrument(
|
||||||
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
|
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
|
||||||
// fields(
|
// fields(
|
||||||
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0),
|
// filter_count = agent_filter_chain.input_filters.as_ref().map(|fc| fc.len()).unwrap_or(0),
|
||||||
// message_count = chat_history.len()
|
// message_count = chat_history.len()
|
||||||
// )
|
// )
|
||||||
// )]
|
// )]
|
||||||
|
|
@ -99,7 +99,7 @@ impl PipelineProcessor {
|
||||||
let mut chat_history_updated = chat_history.to_vec();
|
let mut chat_history_updated = chat_history.to_vec();
|
||||||
|
|
||||||
// If filter_chain is None or empty, proceed without filtering
|
// If filter_chain is None or empty, proceed without filtering
|
||||||
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
|
||||||
Some(fc) if !fc.is_empty() => fc,
|
Some(fc) if !fc.is_empty() => fc,
|
||||||
_ => return Ok(chat_history_updated),
|
_ => return Ok(chat_history_updated),
|
||||||
};
|
};
|
||||||
|
|
@ -283,6 +283,30 @@ impl PipelineProcessor {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
|
||||||
|
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
|
||||||
|
fn build_tool_call_request_with_body(
|
||||||
|
&self,
|
||||||
|
tool_name: &str,
|
||||||
|
body: &serde_json::Value,
|
||||||
|
path: &str,
|
||||||
|
) -> Result<JsonRpcRequest, PipelineError> {
|
||||||
|
let mut arguments = HashMap::new();
|
||||||
|
arguments.insert("body".to_string(), serde_json::to_value(body)?);
|
||||||
|
arguments.insert("path".to_string(), serde_json::to_value(path)?);
|
||||||
|
|
||||||
|
let mut params = HashMap::new();
|
||||||
|
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
|
||||||
|
params.insert("arguments".to_string(), serde_json::to_value(arguments)?);
|
||||||
|
|
||||||
|
Ok(JsonRpcRequest {
|
||||||
|
jsonrpc: JSON_RPC_VERSION.to_string(),
|
||||||
|
id: JsonRpcId::String(Uuid::new_v4().to_string()),
|
||||||
|
method: TOOL_CALL_METHOD.to_string(),
|
||||||
|
params: Some(params),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Send request to a specific agent and return the response content
|
/// Send request to a specific agent and return the response content
|
||||||
#[instrument(
|
#[instrument(
|
||||||
skip(self, messages, agent, request_headers),
|
skip(self, messages, agent, request_headers),
|
||||||
|
|
@ -406,6 +430,106 @@ impl PipelineProcessor {
|
||||||
Ok(messages)
|
Ok(messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Like execute_mcp_filter but passes the full raw body dict + path hint as MCP tool arguments.
|
||||||
|
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
|
||||||
|
async fn execute_mcp_filter_raw(
|
||||||
|
&mut self,
|
||||||
|
raw_bytes: &[u8],
|
||||||
|
agent: &Agent,
|
||||||
|
request_headers: &HeaderMap,
|
||||||
|
request_path: &str,
|
||||||
|
) -> Result<Bytes, PipelineError> {
|
||||||
|
set_service_name(operation_component::AGENT_FILTER);
|
||||||
|
use opentelemetry::trace::get_active_span;
|
||||||
|
get_active_span(|span| {
|
||||||
|
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
|
||||||
|
});
|
||||||
|
|
||||||
|
let body: serde_json::Value =
|
||||||
|
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
|
||||||
|
|
||||||
|
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
|
||||||
|
session_id.clone()
|
||||||
|
} else {
|
||||||
|
let session_id = self.get_new_session_id(&agent.id, request_headers).await;
|
||||||
|
self.agent_id_session_map
|
||||||
|
.insert(agent.id.clone(), session_id.clone());
|
||||||
|
session_id
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Using MCP session ID {} for agent {}",
|
||||||
|
mcp_session_id, agent.id
|
||||||
|
);
|
||||||
|
|
||||||
|
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||||
|
let json_rpc_request =
|
||||||
|
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
|
||||||
|
|
||||||
|
let agent_headers =
|
||||||
|
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.send_mcp_request(&json_rpc_request, &agent_headers, &agent.id)
|
||||||
|
.await?;
|
||||||
|
let http_status = response.status();
|
||||||
|
let response_bytes = response.bytes().await?;
|
||||||
|
|
||||||
|
if !http_status.is_success() {
|
||||||
|
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||||
|
return Err(if http_status.is_client_error() {
|
||||||
|
PipelineError::ClientError {
|
||||||
|
agent: agent.id.clone(),
|
||||||
|
status: http_status.as_u16(),
|
||||||
|
body: error_body,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
PipelineError::ServerError {
|
||||||
|
agent: agent.id.clone(),
|
||||||
|
status: http_status.as_u16(),
|
||||||
|
body: error_body,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
|
||||||
|
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
|
||||||
|
let response_result = response
|
||||||
|
.result
|
||||||
|
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
|
||||||
|
|
||||||
|
if response_result
|
||||||
|
.get("isError")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false)
|
||||||
|
{
|
||||||
|
let error_message = response_result
|
||||||
|
.get("content")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.and_then(|arr| arr.first())
|
||||||
|
.and_then(|v| v.get("text"))
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("unknown_error")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
return Err(PipelineError::ClientError {
|
||||||
|
agent: agent.id.clone(),
|
||||||
|
status: hyper::StatusCode::BAD_REQUEST.as_u16(),
|
||||||
|
body: error_message,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = response_result
|
||||||
|
.get("structuredContent")
|
||||||
|
.and_then(|v| v.get("result"))
|
||||||
|
.cloned()
|
||||||
|
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
|
||||||
|
|
||||||
|
Ok(Bytes::from(
|
||||||
|
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
/// Build an initialize JSON-RPC request
|
/// Build an initialize JSON-RPC request
|
||||||
fn build_initialize_request(&self) -> JsonRpcRequest {
|
fn build_initialize_request(&self) -> JsonRpcRequest {
|
||||||
JsonRpcRequest {
|
JsonRpcRequest {
|
||||||
|
|
@ -708,7 +832,7 @@ impl PipelineProcessor {
|
||||||
request_headers: &HeaderMap,
|
request_headers: &HeaderMap,
|
||||||
request_path: &str,
|
request_path: &str,
|
||||||
) -> Result<Bytes, PipelineError> {
|
) -> Result<Bytes, PipelineError> {
|
||||||
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
|
||||||
Some(fc) if !fc.is_empty() => fc,
|
Some(fc) if !fc.is_empty() => fc,
|
||||||
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
|
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
|
||||||
};
|
};
|
||||||
|
|
@ -722,16 +846,22 @@ impl PipelineProcessor {
|
||||||
.get(agent_name)
|
.get(agent_name)
|
||||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||||
|
|
||||||
|
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
|
||||||
info!(
|
info!(
|
||||||
agent = %agent_name,
|
agent = %agent_name,
|
||||||
url = %agent.url,
|
url = %agent.url,
|
||||||
|
agent_type = %agent_type,
|
||||||
bytes_len = current_bytes.len(),
|
bytes_len = current_bytes.len(),
|
||||||
"executing raw filter"
|
"executing raw filter"
|
||||||
);
|
);
|
||||||
|
|
||||||
current_bytes = self
|
current_bytes = if agent_type == "mcp" {
|
||||||
.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
|
self.execute_mcp_filter_raw(&current_bytes, agent, request_headers, request_path)
|
||||||
.await?;
|
.await?
|
||||||
|
} else {
|
||||||
|
self.execute_raw_filter(&current_bytes, agent, request_headers, request_path)
|
||||||
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
|
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
|
||||||
}
|
}
|
||||||
|
|
@ -812,7 +942,7 @@ mod tests {
|
||||||
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
|
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
|
||||||
AgentFilterChain {
|
AgentFilterChain {
|
||||||
id: "test-agent".to_string(),
|
id: "test-agent".to_string(),
|
||||||
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
|
input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
|
||||||
description: None,
|
description: None,
|
||||||
default: None,
|
default: None,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -308,7 +308,7 @@ where
|
||||||
id: "output_filter".to_string(),
|
id: "output_filter".to_string(),
|
||||||
default: None,
|
default: None,
|
||||||
description: None,
|
description: None,
|
||||||
filter_chain: Some(output_filters),
|
input_filters: Some(output_filters),
|
||||||
};
|
};
|
||||||
|
|
||||||
while let Some(item) = byte_stream.next().await {
|
while let Some(item) = byte_stream.next().await {
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ pub struct AgentFilterChain {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub default: Option<bool>,
|
pub default: Option<bool>,
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
pub filter_chain: Option<Vec<String>>,
|
pub input_filters: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ listeners:
|
||||||
agents:
|
agents:
|
||||||
- id: rag_agent
|
- id: rag_agent
|
||||||
description: virtual assistant for retrieval augmented generation tasks
|
description: virtual assistant for retrieval augmented generation tasks
|
||||||
filter_chain:
|
input_filters:
|
||||||
- input_guards
|
- input_guards
|
||||||
- query_rewriter
|
- query_rewriter
|
||||||
- context_builder
|
- context_builder
|
||||||
|
|
|
||||||
|
|
@ -195,11 +195,11 @@ async def augment_query_with_context(
|
||||||
load_knowledge_base()
|
load_knowledge_base()
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
@app.post("/{path:path}")
|
||||||
async def context_builder(
|
async def context_builder(path: str, request: Request) -> dict:
|
||||||
messages: List[ChatMessage], request: Request
|
|
||||||
) -> List[ChatMessage]:
|
|
||||||
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
||||||
|
body = await request.json()
|
||||||
|
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||||
|
|
||||||
# Get traceparent header from MCP request
|
# Get traceparent header from MCP request
|
||||||
|
|
@ -219,8 +219,7 @@ async def context_builder(
|
||||||
messages, traceparent_header, request_id
|
messages, traceparent_header, request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return as dict to minimize text serialization
|
return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
|
||||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
|
||||||
|
|
||||||
|
|
||||||
# Register MCP tool only if mcp is available
|
# Register MCP tool only if mcp is available
|
||||||
|
|
|
||||||
|
|
@ -126,14 +126,14 @@ Respond in JSON format:
|
||||||
|
|
||||||
|
|
||||||
# @mcp.tool
|
# @mcp.tool
|
||||||
@app.post("/")
|
@app.post("/{path:path}")
|
||||||
async def input_guards(
|
async def input_guards(path: str, request: Request) -> dict:
|
||||||
messages: List[ChatMessage], request: Request
|
|
||||||
) -> List[ChatMessage]:
|
|
||||||
"""Input guard that validates queries are within TechCorp's domain.
|
"""Input guard that validates queries are within TechCorp's domain.
|
||||||
|
|
||||||
If the query is out of scope, replaces the user message with a rejection notice.
|
If the query is out of scope, replaces the user message with a rejection notice.
|
||||||
"""
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||||
logger.info(f"Received request with {len(messages)} messages")
|
logger.info(f"Received request with {len(messages)} messages")
|
||||||
|
|
||||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||||
|
|
@ -164,7 +164,7 @@ async def input_guards(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Query validation passed - forwarding to next filter")
|
logger.info("Query validation passed - forwarding to next filter")
|
||||||
return messages
|
return body
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|
|
||||||
|
|
@ -81,11 +81,11 @@ Return only the rewritten query, nothing else."""
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
@app.post("/{path:path}")
|
||||||
async def query_rewriter_http(
|
async def query_rewriter_http(path: str, request: Request) -> dict:
|
||||||
messages: List[ChatMessage], request: Request
|
|
||||||
) -> List[ChatMessage]:
|
|
||||||
"""HTTP filter endpoint used by Plano (type: http)."""
|
"""HTTP filter endpoint used by Plano (type: http)."""
|
||||||
|
body = await request.json()
|
||||||
|
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||||
logger.info(f"Received request with {len(messages)} messages")
|
logger.info(f"Received request with {len(messages)} messages")
|
||||||
|
|
||||||
traceparent_header = request.headers.get("traceparent")
|
traceparent_header = request.headers.get("traceparent")
|
||||||
|
|
@ -99,25 +99,20 @@ async def query_rewriter_http(
|
||||||
rewritten_query = await rewrite_query_with_plano(
|
rewritten_query = await rewrite_query_with_plano(
|
||||||
messages, traceparent_header, request_id
|
messages, traceparent_header, request_id
|
||||||
)
|
)
|
||||||
# Create updated messages with the rewritten query
|
|
||||||
updated_messages = messages.copy()
|
|
||||||
|
|
||||||
# Find and update the last user message with the rewritten query
|
# Find and update the last user message with the rewritten query
|
||||||
|
updated_messages = [m.model_dump() for m in messages]
|
||||||
for i in range(len(updated_messages) - 1, -1, -1):
|
for i in range(len(updated_messages) - 1, -1, -1):
|
||||||
if updated_messages[i].role == "user":
|
if updated_messages[i]["role"] == "user":
|
||||||
original_query = updated_messages[i].content
|
original_query = updated_messages[i]["content"]
|
||||||
updated_messages[i] = ChatMessage(role="user", content=rewritten_query)
|
updated_messages[i]["content"] = rewritten_query
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Updated user query from '{original_query}' to '{rewritten_query}'"
|
f"Updated user query from '{original_query}' to '{rewritten_query}'"
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
updated_messages_data = [
|
|
||||||
{"role": msg.role, "content": msg.content} for msg in updated_messages
|
|
||||||
]
|
|
||||||
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
|
|
||||||
|
|
||||||
logger.info("Returning rewritten chat completion response")
|
logger.info("Returning rewritten chat completion response")
|
||||||
return updated_messages
|
return {**body, "messages": updated_messages}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ listeners:
|
||||||
agents:
|
agents:
|
||||||
- id: rag_agent
|
- id: rag_agent
|
||||||
description: virtual assistant for retrieval augmented generation tasks
|
description: virtual assistant for retrieval augmented generation tasks
|
||||||
filter_chain:
|
input_filters:
|
||||||
- input_guards
|
- input_guards
|
||||||
- query_rewriter
|
- query_rewriter
|
||||||
- context_builder
|
- context_builder
|
||||||
|
|
|
||||||
|
|
@ -195,9 +195,14 @@ async def augment_query_with_context(
|
||||||
load_knowledge_base()
|
load_knowledge_base()
|
||||||
|
|
||||||
|
|
||||||
async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
async def context_builder(body: dict, path: str) -> dict:
|
||||||
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
"""MCP tool that augments user queries with relevant context from the knowledge base.
|
||||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
|
||||||
|
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
|
||||||
|
Returns the body with the last user message augmented with retrieved context.
|
||||||
|
"""
|
||||||
|
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||||
|
logger.info(f"Received request with {len(messages)} messages at path {path}")
|
||||||
|
|
||||||
# Get traceparent header from MCP request
|
# Get traceparent header from MCP request
|
||||||
headers = get_http_headers()
|
headers = get_http_headers()
|
||||||
|
|
@ -215,8 +220,7 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||||
messages, traceparent_header, request_id
|
messages, traceparent_header, request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Return as dict to minimize text serialization
|
return {**body, "messages": [{"role": msg.role, "content": msg.content} for msg in updated_messages]}
|
||||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
|
||||||
|
|
||||||
|
|
||||||
# Register MCP tool only if mcp is available
|
# Register MCP tool only if mcp is available
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,12 @@ import json
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
import uuid
|
import uuid
|
||||||
from fastapi import FastAPI, Depends, Request
|
|
||||||
from fastmcp.exceptions import ToolError
|
from fastmcp.exceptions import ToolError
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
|
from .api import ChatMessage
|
||||||
from . import mcp
|
from . import mcp
|
||||||
from fastmcp.server.dependencies import get_http_headers
|
from fastmcp.server.dependencies import get_http_headers
|
||||||
|
|
||||||
|
|
@ -30,8 +29,6 @@ plano_client = AsyncOpenAI(
|
||||||
api_key="EMPTY", # Plano doesn't require a real API key
|
api_key="EMPTY", # Plano doesn't require a real API key
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
async def validate_query_scope(
|
async def validate_query_scope(
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
|
|
@ -127,12 +124,14 @@ Respond in JSON format:
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool
|
@mcp.tool
|
||||||
async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
|
async def input_guards(body: dict, path: str) -> dict:
|
||||||
"""Input guard that validates queries are within TechCorp's domain.
|
"""Input guard that validates queries are within TechCorp's domain.
|
||||||
|
|
||||||
If the query is out of scope, replaces the user message with a rejection notice.
|
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
|
||||||
|
If the query is out of scope, raises a ToolError to block the request.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Received request with {len(messages)} messages")
|
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||||
|
logger.info(f"Received request with {len(messages)} messages at path {path}")
|
||||||
|
|
||||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||||
headers = get_http_headers()
|
headers = get_http_headers()
|
||||||
|
|
@ -153,9 +152,8 @@ async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||||
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
|
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
|
||||||
logger.warning(f"Query rejected: {reason}")
|
logger.warning(f"Query rejected: {reason}")
|
||||||
|
|
||||||
# Throw ToolError
|
|
||||||
error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
|
error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
|
||||||
raise ToolError(error_message)
|
raise ToolError(error_message)
|
||||||
|
|
||||||
logger.info("Query validation passed - forwarding to next filter")
|
logger.info("Query validation passed - forwarding to next filter")
|
||||||
return messages
|
return body
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,11 @@ import json
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
import uuid
|
import uuid
|
||||||
from fastapi import FastAPI, Depends, Request
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
|
from .api import ChatMessage
|
||||||
from . import mcp
|
from . import mcp
|
||||||
from fastmcp.server.dependencies import get_http_headers
|
from fastmcp.server.dependencies import get_http_headers
|
||||||
|
|
||||||
|
|
@ -29,9 +28,6 @@ plano_client = AsyncOpenAI(
|
||||||
api_key="EMPTY", # Plano doesn't require a real API key
|
api_key="EMPTY", # Plano doesn't require a real API key
|
||||||
)
|
)
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
|
|
||||||
async def rewrite_query_with_plano(
|
async def rewrite_query_with_plano(
|
||||||
messages: List[ChatMessage],
|
messages: List[ChatMessage],
|
||||||
traceparent_header: str,
|
traceparent_header: str,
|
||||||
|
|
@ -87,12 +83,14 @@ async def rewrite_query_with_plano(
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
async def query_rewriter(body: dict, path: str) -> dict:
|
||||||
"""Chat completions endpoint that rewrites the last user query using Plano.
|
"""Rewrites the last user query in the request body using Plano.
|
||||||
|
|
||||||
Returns a dict with a 'messages' key containing the updated message list.
|
Receives the full request body dict and the API path hint (e.g. /v1/chat/completions).
|
||||||
|
Returns the body with the last user message rewritten for better retrieval.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
messages = [ChatMessage(**m) for m in body.get("messages", [])]
|
||||||
|
logger.info(f"Received request with {len(messages)} messages at path {path}")
|
||||||
|
|
||||||
# Get traceparent header from HTTP request using FastMCP's dependency function
|
# Get traceparent header from HTTP request using FastMCP's dependency function
|
||||||
headers = get_http_headers()
|
headers = get_http_headers()
|
||||||
|
|
@ -109,57 +107,20 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||||
messages, traceparent_header, request_id
|
messages, traceparent_header, request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create updated messages with the rewritten query
|
|
||||||
updated_messages = messages.copy()
|
|
||||||
|
|
||||||
# Find and update the last user message with the rewritten query
|
# Find and update the last user message with the rewritten query
|
||||||
|
updated_messages = [m.model_dump() for m in messages]
|
||||||
for i in range(len(updated_messages) - 1, -1, -1):
|
for i in range(len(updated_messages) - 1, -1, -1):
|
||||||
if updated_messages[i].role == "user":
|
if updated_messages[i]["role"] == "user":
|
||||||
original_query = updated_messages[i].content
|
|
||||||
updated_messages[i] = ChatMessage(role="user", content=rewritten_query)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Updated user query from '{original_query}' to '{rewritten_query}'"
|
f"Updated user query from '{updated_messages[i]['content']}' to '{rewritten_query}'"
|
||||||
)
|
)
|
||||||
|
updated_messages[i]["content"] = rewritten_query
|
||||||
break
|
break
|
||||||
|
|
||||||
# Return as dict to minimize text serialization
|
logger.info("Returning rewritten chat completion response")
|
||||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
return {**body, "messages": updated_messages}
|
||||||
|
|
||||||
|
|
||||||
# Register MCP tool only if mcp is available
|
# Register MCP tool only if mcp is available
|
||||||
if mcp is not None:
|
if mcp is not None:
|
||||||
mcp.tool()(query_rewriter)
|
mcp.tool()(query_rewriter)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
|
||||||
async def chat_completions_endpoint(
|
|
||||||
request_messages: List[ChatMessage], request: Request
|
|
||||||
) -> List[ChatMessage]:
|
|
||||||
"""FastAPI endpoint for chat completions with query rewriting."""
|
|
||||||
logger.info(
|
|
||||||
f"Received /v1/chat/completions request with {len(request_messages)} messages"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Extract traceparent header
|
|
||||||
traceparent_header = request.headers.get("traceparent")
|
|
||||||
if traceparent_header:
|
|
||||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
|
||||||
else:
|
|
||||||
logger.info("No traceparent header found")
|
|
||||||
|
|
||||||
# Call the query rewriter tool
|
|
||||||
updated_messages_data = await query_rewriter(request_messages)
|
|
||||||
|
|
||||||
# Convert back to ChatMessage objects
|
|
||||||
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
|
|
||||||
|
|
||||||
logger.info("Returning rewritten chat completion response")
|
|
||||||
return updated_messages
|
|
||||||
|
|
||||||
|
|
||||||
def start_server(host: str = "0.0.0.0", port: int = 10501):
|
|
||||||
"""Start the FastAPI server for query rewriter."""
|
|
||||||
import uvicorn
|
|
||||||
|
|
||||||
logger.info(f"Starting Query Rewriter REST server on {host}:{port}")
|
|
||||||
uvicorn.run(app, host=host, port=port)
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue