more changes

This commit is contained in:
Adil Hafeez 2025-09-15 13:16:02 -07:00
parent f5f33f6de2
commit a016212588
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
10 changed files with 157 additions and 129 deletions

View file

@ -50,11 +50,12 @@ properties:
endpoints:
type: object
patternProperties:
"^.*$":
"^[a-zA-Z][a-zA-Z0-9_]*$":
type: object
properties:
endpoint:
type: string
pattern: "^[a-zA-Z].*$"
connect_timeout:
type: string
protocol:

View file

@ -249,6 +249,8 @@ static_resources:
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
path: "/var/log/access_internal.log"
format: |
[%START_TIME%] "%REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %PROTOCOL%" %RESPONSE_CODE% %RESPONSE_FLAGS% %BYTES_RECEIVED% %BYTES_SENT% %DURATION% %RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)% "%REQ(X-FORWARDED-FOR)%" "%REQ(USER-AGENT)%" "%REQ(X-REQUEST-ID)%" "%REQ(:AUTHORITY)%" "%UPSTREAM_HOST%" "%UPSTREAM_CLUSTER%"
route_config:
name: local_routes
virtual_hosts:

View file

@ -68,15 +68,41 @@ def validate_and_render_schema():
endpoints = config_yaml.get("endpoints", {})
# Process agents section and convert to endpoints
agents = config_yaml.get("agents", [])
for agent in agents:
agent_name = agent.get("name")
agent_endpoint = agent.get("endpoint")
if agent_name and agent_endpoint:
urlparse_result = urlparse(agent_endpoint)
if urlparse_result.scheme and urlparse_result.hostname:
protocol = urlparse_result.scheme
port = urlparse_result.port
if port is None:
if protocol == "http":
port = 80
else:
port = 443
endpoints[agent_name] = {
"endpoint": urlparse_result.hostname,
"port": port,
"protocol": protocol,
}
# override the inferred clusters with the ones defined in the config
for name, endpoint_details in endpoints.items():
inferred_clusters[name] = endpoint_details
endpoint = inferred_clusters[name]["endpoint"]
protocol = inferred_clusters[name].get("protocol", "http")
(
inferred_clusters[name]["endpoint"],
inferred_clusters[name]["port"],
) = get_endpoint_and_port(endpoint, protocol)
# Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones
if "port" not in endpoint_details:
endpoint = inferred_clusters[name]["endpoint"]
protocol = inferred_clusters[name].get("protocol", "http")
(
inferred_clusters[name]["endpoint"],
inferred_clusters[name]["port"],
) = get_endpoint_and_port(endpoint, protocol)
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))

View file

@ -95,19 +95,19 @@ version: v0.2.0
agents:
- name: query_rewriter
kind: openai
endpoint: openai://localhost:10500
endpoint: http://localhost:10500
- name: context_builder
kind: openai
endpoint: openai://localhost:10501
endpoint: http://localhost:10501
- name: response_generator
kind: openai
endpoint: openai://localhost:10502
endpoint: http://localhost:10502
- name: research_agent
kind: openai
endpoint: https://localhost:10500
endpoint: http://localhost:10500
- name: input_guard_rails
kind: openai
endpoint: https://localhost:10503
endpoint: http://localhost:10503
listeners:
- name: tmobile

View file

@ -3,7 +3,7 @@ use std::sync::Arc;
use bytes::Bytes;
use common::api::open_ai::{ChatCompletionsResponse, Choice};
use common::configuration::ModelUsagePreference;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER};
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::apis::{Role, Usage};
use hermesllm::clients::SupportedAPIs;
@ -12,7 +12,7 @@ use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use hyper::{Request, Response, StatusCode, Uri};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
@ -83,26 +83,42 @@ pub async fn agent_chat(
debug!("Processing agent pipeline: {}", agent_pipeline.name);
let mut chat_completions_history = chat_completions_request.messages.clone();
let mut last_response: Option<String> = None;
// let trace_parent = request_headers
// .iter()
// .find(|(ty, _)| ty.as_str() == "traceparent")
// .map(|(_, value)| value.to_str().unwrap_or_default().to_string());
// if let Some(trace_parent) = trace_parent {
// request_headers.insert(
// header::HeaderName::from_static("traceparent"),
// header::HeaderValue::from_str(&trace_parent).unwrap(),
// );
// }
request_headers.remove(header::CONTENT_LENGTH);
// request_headers.remove("traceparent");
for agent_name in agent_pipeline.filter_chain {
debug!("Processing agent: {}", agent_name);
let agent = agent_name_map.get(&agent_name).unwrap();
debug!("Agent details: {:?}", agent);
let path = format!(
"{}/v1/chat/completions",
agent.endpoint.trim_end_matches('/')
);
let mut request = chat_completions_request.clone();
request.messages = chat_completions_history.clone();
let request_str = serde_json::to_string(&request).unwrap();
debug!("Sending request to agent {}: {}", agent_name, request_str);
let mut agent_request_headers = request_headers.clone();
agent_request_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(agent.name.as_str()).unwrap(),
);
let response = match reqwest::Client::new()
.post(path)
.post("http://localhost:11000/v1/chat/completions")
.headers(agent_request_headers)
.body(request_str)
.send()
.await
@ -149,14 +165,6 @@ pub async fn agent_chat(
);
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
// chat_completions_history.append(&mut vec![hermesllm::apis::openai::Message {
// role: hermesllm::apis::openai::Role::Assistant,
// content: hermesllm::apis::openai::MessageContent::Text(response_str),
// name: Some(agent_name.clone()),
// tool_calls: None,
// tool_call_id: None,
// }]);
}
let last_response: Option<String> = match chat_completions_history.last() {
@ -200,79 +208,4 @@ pub async fn agent_chat(
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
return Ok(Response::new(full(response_body)));
// request_headers.insert(
// ARCH_PROVIDER_HINT_HEADER,
// header::HeaderValue::from_str(&model_name).unwrap(),
// );
// if let Some(trace_parent) = trace_parent {
// request_headers.insert(
// header::HeaderName::from_static("traceparent"),
// header::HeaderValue::from_str(&trace_parent).unwrap(),
// );
// }
// // remove content-length header if it exists
// request_headers.remove(header::CONTENT_LENGTH);
// let llm_response = match reqwest::Client::new()
// .post(full_qualified_llm_provider_url)
// .headers(request_headers)
// .body(client_request_bytes_for_upstream)
// .send()
// .await
// {
// Ok(res) => res,
// Err(err) => {
// let err_msg = format!("Failed to send request: {}", err);
// let mut internal_error = Response::new(full(err_msg));
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
// return Ok(internal_error);
// }
// };
// // copy over the headers from the original response
// let response_headers = llm_response.headers().clone();
// let mut response = Response::builder();
// let headers = response.headers_mut().unwrap();
// for (header_name, header_value) in response_headers.iter() {
// headers.insert(header_name, header_value.clone());
// }
// // channel to create async stream
// let (tx, rx) = mpsc::channel::<Bytes>(16);
// // Spawn a task to send data as it becomes available
// tokio::spawn(async move {
// let mut byte_stream = llm_response.bytes_stream();
// while let Some(item) = byte_stream.next().await {
// let item = match item {
// Ok(item) => item,
// Err(err) => {
// warn!("Error receiving chunk: {:?}", err);
// break;
// }
// };
// if tx.send(item).await.is_err() {
// warn!("Receiver dropped");
// break;
// }
// }
// });
// let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
// let stream_body = BoxBody::new(StreamBody::new(stream));
// match response.body(stream_body) {
// Ok(response) => Ok(response),
// Err(err) => {
// let err_msg = format!("Failed to create response: {}", err);
// let mut internal_error = Response::new(full(err_msg));
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
// Ok(internal_error)
// }
// }
}

View file

@ -29,14 +29,21 @@ listeners:
filter_chain:
- research_agent
- response_generator
protocol: openai
address: 0.0.0.0
port: 8001
- name: egress_traffic
description: llm provider configuration
port: 12000
protocol: openai
llm_providers:
- access_key: $OPENAI_API_KEY
model: openai/gpt-4o
- access_key: $OPENAI_API_KEY
model: openai/gpt-4o-mini
address: 0.0.0.0
port: 12000
tracing:
random_sampling: 100
trace_arch_internal: true

View file

@ -0,0 +1,8 @@
services:
jaeger:
build:
context: ../../shared/jaeger
ports:
- "16686:16686"
- "4317:4317"
- "4318:4318"

View file

@ -1,7 +1,7 @@
import json
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from openai import AsyncOpenAI
import os
import logging
@ -55,7 +55,9 @@ def load_knowledge_base():
knowledge_base = []
async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, str]]:
async def find_relevant_passages(
query: str, traceparent: Optional[str] = None, top_k: int = 3
) -> List[Dict[str, str]]:
"""Use the LLM to find the most relevant passages from the knowledge base."""
if not knowledge_base:
@ -87,11 +89,18 @@ async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, s
try:
# Call archgw to select relevant passages
logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
# Prepare extra headers if traceparent is provided
extra_headers = {}
if traceparent:
extra_headers["traceparent"] = traceparent
response = await archgw_client.chat.completions.create(
model=RAG_MODEL,
messages=[{"role": "system", "content": system_prompt}],
temperature=0.1,
max_tokens=50,
extra_headers=extra_headers,
)
result = response.choices[0].message.content.strip()
@ -118,7 +127,9 @@ async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, s
return []
async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMessage]:
async def augment_query_with_context(
messages: List[ChatMessage], traceparent: Optional[str] = None
) -> List[ChatMessage]:
"""Extract user query, find relevant context, and augment the messages."""
# Find the last user message
@ -138,7 +149,7 @@ async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMe
logger.info(f"Processing user query: '{last_user_message}'")
# Find relevant passages
relevant_passages = await find_relevant_passages(last_user_message)
relevant_passages = await find_relevant_passages(last_user_message, traceparent)
if not relevant_passages:
logger.info("No relevant passages found, returning original messages")
@ -178,23 +189,34 @@ app = FastAPI(title="RAG Content Builder Agent", version="1.0.0")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResponse:
async def chat_completions(
request_body: ChatCompletionRequest, request: Request
) -> ChatCompletionResponse:
"""Chat completions endpoint that augments user queries with relevant context from the knowledge base."""
import time
import uuid
logger.info(
f"Received chat completion request with {len(request.messages)} messages"
f"Received chat completion request with {len(request_body.messages)} messages"
)
# Read traceparent header if present
traceparent_header = request.headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Augment the user query with relevant context
updated_messages = await augment_query_with_context(request.messages)
updated_messages = await augment_query_with_context(
request_body.messages, traceparent_header
)
messages_history_json = json.dumps([msg.dict() for msg in updated_messages])
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
created=int(time.time()),
model=request.model,
model=request_body.model,
choices=[
{
"index": 0,

View file

@ -1,7 +1,7 @@
import json
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from openai import AsyncOpenAI
import os
import logging
@ -26,7 +26,9 @@ archgw_client = AsyncOpenAI(
)
async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str:
async def rewrite_query_with_archgw(
messages: List[ChatMessage], traceparent_header: str
) -> str:
# Prepare the system prompt for query rewriting
system_prompt = """You are a query rewriter that improves user queries for better retrieval.
@ -48,12 +50,16 @@ async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str:
try:
# Call archgw using OpenAI client
extra_headers = {}
if traceparent_header:
extra_headers["traceparent"] = traceparent_header
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")
response = await archgw_client.chat.completions.create(
model=QUERY_REWRITE_MODEL,
messages=rewrite_messages,
temperature=0.3,
max_tokens=200,
extra_headers=extra_headers,
)
rewritten_query = response.choices[0].message.content.strip()
@ -81,20 +87,29 @@ app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
async def chat_completions(request_body: ChatCompletionRequest, request: Request):
"""Chat completions endpoint that rewrites the last user query using archgw."""
import time
import uuid
logger.info(
f"Received chat completion request with {len(request.messages)} messages"
f"Received chat completion request with {len(request_body.messages)} messages"
)
# Read traceparent header if present
traceparent_header = request.headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Call archgw to rewrite the last user query
rewritten_query = await rewrite_query_with_archgw(request.messages)
rewritten_query = await rewrite_query_with_archgw(
request_body.messages, traceparent_header
)
# Create updated messages with the rewritten query
updated_messages = request.messages.copy()
updated_messages = request_body.messages.copy()
# Find and update the last user message with the rewritten query
for i in range(len(updated_messages) - 1, -1, -1):
@ -111,7 +126,7 @@ async def chat_completions(request: ChatCompletionRequest):
response = ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
created=int(time.time()),
model=request.model,
model=request_body.model,
choices=[
{
"index": 0,

View file

@ -1,5 +1,5 @@
import json
from fastapi import FastAPI
from fastapi import FastAPI, Request
from openai import AsyncOpenAI
import os
import logging
@ -28,12 +28,19 @@ app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
async def chat_completions(request_body: ChatCompletionRequest, request: Request):
"""Chat completions endpoint that generates a coherent response based on all context."""
logger.info(
f"Received chat completion request with {len(request.messages)} messages"
f"Received chat completion request with {len(request_body.messages)} messages"
)
# Read traceparent header if present
traceparent_header = request.headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Prepare the system prompt for response generation
system_prompt = """You are a helpful assistant that generates coherent, contextual responses.
@ -50,17 +57,24 @@ async def chat_completions(request: ChatCompletionRequest):
response_messages = [{"role": "system", "content": system_prompt}]
# Add conversation history
for msg in request.messages:
for msg in request_body.messages:
response_messages.append({"role": msg.role, "content": msg.content})
try:
# Call archgw using OpenAI client
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate response")
# Prepare extra headers if traceparent is provided
extra_headers = {}
if traceparent_header:
extra_headers["traceparent"] = traceparent_header
response = await archgw_client.chat.completions.create(
model=RESPONSE_MODEL,
messages=response_messages,
temperature=request.temperature or 0.7,
max_tokens=request.max_tokens or 1000,
temperature=request_body.temperature or 0.7,
max_tokens=request_body.max_tokens or 1000,
extra_headers=extra_headers,
)
generated_response = response.choices[0].message.content.strip()
@ -71,7 +85,7 @@ async def chat_completions(request: ChatCompletionRequest):
return ChatCompletionResponse(
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
created=int(time.time()),
model=request.model,
model=request_body.model,
choices=[
{
"index": 0,
@ -84,11 +98,11 @@ async def chat_completions(request: ChatCompletionRequest):
],
usage={
"prompt_tokens": sum(
len(msg.content.split()) for msg in request.messages
len(msg.content.split()) for msg in request_body.messages
),
"completion_tokens": len(generated_response.split()),
"total_tokens": sum(
len(msg.content.split()) for msg in request.messages
len(msg.content.split()) for msg in request_body.messages
)
+ len(generated_response.split()),
},