From 08471d8adf2d868827ff71618c8f42a1ffec161d Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Wed, 17 Sep 2025 09:39:10 -0700 Subject: [PATCH] add streaming --- .github/workflows/rust_tests.yml | 3 - .../src/handlers/agent_chat_completions.rs | 120 +++++++++---- .../use_cases/rag_agent/src/rag_agent/api.py | 8 + .../src/rag_agent/response_generator_agent.py | 162 ++++++++++++++++-- .../rag_agent/response_generator_test.rest | 40 +++++ 5 files changed, 274 insertions(+), 59 deletions(-) create mode 100644 demos/use_cases/rag_agent/src/rag_agent/response_generator_test.rest diff --git a/.github/workflows/rust_tests.yml b/.github/workflows/rust_tests.yml index aa27fdca..9837531d 100644 --- a/.github/workflows/rust_tests.yml +++ b/.github/workflows/rust_tests.yml @@ -29,6 +29,3 @@ jobs: - name: Run unit tests run: cargo test --lib - - # - name: Run integration tests - # run: cargo test --test integration diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index 46af6c11..f362a14d 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -1,19 +1,14 @@ use std::sync::Arc; use bytes::Bytes; -use common::api::open_ai::{ChatCompletionsResponse, Choice}; use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference}; -use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER}; +use common::consts::ARCH_UPSTREAM_HOST_HEADER; use hermesllm::apis::openai::ChatCompletionsRequest; -use hermesllm::apis::{Role, Usage}; -use hermesllm::clients::SupportedAPIs; -use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; use hyper::header::{self}; -use hyper::{Request, Response, StatusCode, Uri}; -use serde::{ser::SerializeMap, Deserialize, Serialize}; +use hyper::{Request, Response, StatusCode}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; @@ -30,7 +25,7 @@ fn full>(chunk: T) -> BoxBody { pub async fn agent_chat( request: Request, router_service: Arc, - full_qualified_llm_provider_url: String, + _: String, agents_list: Arc>>>, listeners: Arc>>, ) -> Result>, hyper::Error> { @@ -52,7 +47,6 @@ pub async fn agent_chat( info!("Handling request for listener: {}", listener.name); - let request_path = request.uri().path().to_string(); let mut request_headers = request.headers().clone(); let chat_request_bytes = request.collect().await?.to_bytes(); @@ -163,9 +157,12 @@ pub async fn agent_chat( request_headers.remove(header::CONTENT_LENGTH); - for agent_name in agent_pipeline.filter_chain { + let filter_chain_without_terminal_agent = + &agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1]; + + for agent_name in filter_chain_without_terminal_agent { debug!("Processing agent: {}", agent_name); - let agent = agent_name_map.get(&agent_name).unwrap(); + let agent = agent_name_map.get(agent_name).unwrap(); debug!("Agent details: {:?}", agent); let mut request = chat_completions_request.clone(); @@ -223,41 +220,88 @@ pub async fn agent_chat( .clone() .unwrap(); - debug!( - "Received response from agent {}", - agent_name - ); + debug!("Received response from agent {}", agent_name); chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]); } - let last_response: Option = match chat_completions_history.last() { - Some(msg) => Some(msg.content.clone().to_string()), - None => None, + let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap(); + let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap(); + debug!("Processing terminal agent: {}", terminal_agent_name); + debug!("Terminal agent details: {:?}", terminal_agent); + + let mut request = chat_completions_request.clone(); + request.messages = chat_completions_history.clone(); + + let request_str = serde_json::to_string(&request).unwrap(); + debug!("Sending request to agent {}", terminal_agent_name); + + let mut agent_request_headers = request_headers.clone(); + agent_request_headers.insert( + ARCH_UPSTREAM_HOST_HEADER, + hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(), + ); + + let llm_response = match reqwest::Client::new() + .post("http://localhost:11000/v1/chat/completions") + .headers(agent_request_headers) + .body(request_str) + .send() + .await + { + Ok(res) => res, + Err(err) => { + let err_msg = format!("Failed to send request: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } }; - let chat_completion_response: hermesllm::apis::openai::ChatCompletionsResponse = - hermesllm::apis::openai::ChatCompletionsResponse { - model: "arch-agent".to_string(), - choices: vec![hermesllm::apis::openai::Choice { - message: { - hermesllm::apis::openai::ResponseMessage { - role: hermesllm::apis::openai::Role::Assistant, - content: last_response, - ..Default::default() - } - }, - ..Default::default() - }], - usage: hermesllm::apis::openai::Usage { - ..Default::default() - }, - ..Default::default() - }; + // copy over the headers from the original response + let response_headers = llm_response.headers().clone(); + let mut response = Response::builder(); + let headers = response.headers_mut().unwrap(); + for (header_name, header_value) in response_headers.iter() { + headers.insert(header_name, header_value.clone()); + } - let response_body = serde_json::to_string(&chat_completion_response).unwrap(); + // channel to create async stream + let (tx, rx) = mpsc::channel::(16); - return Ok(Response::new(full(response_body))); + // Spawn a task to send data as it becomes available + tokio::spawn(async move { + let mut byte_stream = llm_response.bytes_stream(); + + while let Some(item) = byte_stream.next().await { + let item = match item { + Ok(item) => item, + Err(err) => { + warn!("Error receiving chunk: {:?}", err); + break; + } + }; + + if tx.send(item).await.is_err() { + warn!("Receiver dropped"); + break; + } + } + }); + + let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); + + let stream_body = BoxBody::new(StreamBody::new(stream)); + + match response.body(stream_body) { + Ok(response) => Ok(response), + Err(err) => { + let err_msg = format!("Failed to create response: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(internal_error) + } + } } fn convert_agent_description_to_routing_preferences( diff --git a/demos/use_cases/rag_agent/src/rag_agent/api.py b/demos/use_cases/rag_agent/src/rag_agent/api.py index 292451c2..eb63ea99 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/api.py +++ b/demos/use_cases/rag_agent/src/rag_agent/api.py @@ -26,3 +26,11 @@ class ChatCompletionResponse(BaseModel): model: str choices: List[Dict[str, Any]] usage: Dict[str, int] + + +class ChatCompletionStreamResponse(BaseModel): + id: str + object: str = "chat.completion.chunk" + created: int + model: str + choices: List[Dict[str, Any]] diff --git a/demos/use_cases/rag_agent/src/rag_agent/response_generator_agent.py b/demos/use_cases/rag_agent/src/rag_agent/response_generator_agent.py index 3faaf4ef..f3f0c72d 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/response_generator_agent.py +++ b/demos/use_cases/rag_agent/src/rag_agent/response_generator_agent.py @@ -1,13 +1,19 @@ import json from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse from openai import AsyncOpenAI import os import logging import time import uuid import uvicorn +import asyncio -from .api import ChatCompletionRequest, ChatCompletionResponse +from .api import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamResponse, +) # Set up logging logging.basicConfig(level=logging.INFO) @@ -17,6 +23,18 @@ logger = logging.getLogger(__name__) LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:9000/v1") RESPONSE_MODEL = "gpt-4o" +# System prompt for response generation +SYSTEM_PROMPT = """You are a helpful assistant that generates coherent, contextual responses. + +Given a conversation history, generate a helpful and relevant response based on all the context available in the messages. +Your response should: +1. Be contextually aware of the entire conversation +2. Address the user's needs appropriately +3. Be helpful and informative +4. Maintain a natural conversational tone + +Generate a complete response to assist the user.""" + # Initialize OpenAI client for archgw archgw_client = AsyncOpenAI( base_url=LLM_GATEWAY_ENDPOINT, @@ -27,6 +45,17 @@ archgw_client = AsyncOpenAI( app = FastAPI(title="RAG Agent Response Generator", version="1.0.0") +def prepare_response_messages(request_body: ChatCompletionRequest): + """Prepare messages for response generation by adding system prompt.""" + response_messages = [{"role": "system", "content": SYSTEM_PROMPT}] + + # Add conversation history + for msg in request_body.messages: + response_messages.append({"role": msg.role, "content": msg.content}) + + return response_messages + + @app.post("/v1/chat/completions") async def chat_completions(request_body: ChatCompletionRequest, request: Request): """Chat completions endpoint that generates a coherent response based on all context.""" @@ -41,24 +70,121 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request else: logger.info("No traceparent header found") - # Prepare the system prompt for response generation - system_prompt = """You are a helpful assistant that generates coherent, contextual responses. + # Check if streaming is requested + if request_body.stream: + return StreamingResponse( + stream_chat_completions(request_body, traceparent_header), + media_type="text/plain", + ) + else: + return await non_streaming_chat_completions(request_body, traceparent_header) - Given a conversation history, generate a helpful and relevant response based on all the context available in the messages. - Your response should: - 1. Be contextually aware of the entire conversation - 2. Address the user's needs appropriately - 3. Be helpful and informative - 4. Maintain a natural conversational tone - - Generate a complete response to assist the user.""" +async def stream_chat_completions( + request_body: ChatCompletionRequest, traceparent_header: str = None +): + """Generate streaming chat completions.""" # Prepare messages for response generation - response_messages = [{"role": "system", "content": system_prompt}] + response_messages = prepare_response_messages(request_body) - # Add conversation history - for msg in request_body.messages: - response_messages.append({"role": msg.role, "content": msg.content}) + try: + # Call archgw using OpenAI client for streaming + logger.info( + f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate streaming response" + ) + + # Prepare extra headers if traceparent is provided + extra_headers = {} + if traceparent_header: + extra_headers["traceparent"] = traceparent_header + + response_stream = await archgw_client.chat.completions.create( + model=RESPONSE_MODEL, + messages=response_messages, + temperature=request_body.temperature or 0.7, + max_tokens=request_body.max_tokens or 1000, + stream=True, + extra_headers=extra_headers, + ) + + completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}" + created_time = int(time.time()) + collected_content = [] + + async for chunk in response_stream: + if chunk.choices and chunk.choices[0].delta.content: + content = chunk.choices[0].delta.content + collected_content.append(content) + + # Create streaming response chunk + stream_chunk = ChatCompletionStreamResponse( + id=completion_id, + created=created_time, + model=request_body.model, + choices=[ + { + "index": 0, + "delta": {"content": content}, + "finish_reason": None, + } + ], + ) + + yield f"data: {stream_chunk.model_dump_json()}\n\n" + + # Send final chunk with complete response in expected format + full_response = "".join(collected_content) + updated_history = [{"role": "assistant", "content": full_response}] + + final_chunk = ChatCompletionStreamResponse( + id=completion_id, + created=created_time, + model=request_body.model, + choices=[ + { + "index": 0, + "delta": {}, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": json.dumps(updated_history), + }, + } + ], + ) + + yield f"data: {final_chunk.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + + except Exception as e: + logger.error(f"Error generating streaming response: {e}") + + # Send error as streaming response + error_chunk = ChatCompletionStreamResponse( + id=f"chatcmpl-{uuid.uuid4().hex[:8]}", + created=int(time.time()), + model=request_body.model, + choices=[ + { + "index": 0, + "delta": { + "content": "I apologize, but I'm having trouble generating a response right now. Please try again." + }, + "finish_reason": "stop", + } + ], + ) + + yield f"data: {error_chunk.model_dump_json()}\n\n" + yield "data: [DONE]\n\n" + + +async def non_streaming_chat_completions( + request_body: ChatCompletionRequest, traceparent_header: str = None +): + """Generate non-streaming chat completions.""" + # Prepare messages for response generation + response_messages = prepare_response_messages(request_body) try: # Call archgw using OpenAI client @@ -116,7 +242,7 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request return ChatCompletionResponse( id=f"chatcmpl-{uuid.uuid4().hex[:8]}", created=int(time.time()), - model=request.model, + model=request_body.model, choices=[ { "index": 0, @@ -126,11 +252,11 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request ], usage={ "prompt_tokens": sum( - len(msg.content.split()) for msg in request.messages + len(msg.content.split()) for msg in request_body.messages ), "completion_tokens": len(fallback_message.split()), "total_tokens": sum( - len(msg.content.split()) for msg in request.messages + len(msg.content.split()) for msg in request_body.messages ) + len(fallback_message.split()), }, diff --git a/demos/use_cases/rag_agent/src/rag_agent/response_generator_test.rest b/demos/use_cases/rag_agent/src/rag_agent/response_generator_test.rest new file mode 100644 index 00000000..3509f043 --- /dev/null +++ b/demos/use_cases/rag_agent/src/rag_agent/response_generator_test.rest @@ -0,0 +1,40 @@ +@baseUrl = http://0.0.0.0:10502 +@model = gpt-4o + +### + +# Health Check +GET {{baseUrl}}/health + +### + +# Test 1: Simple Non-Streaming Chat Completion +POST {{baseUrl}}/v1/chat/completions +Content-Type: application/json + +{ + "model": "{{model}}", + "messages": [ + { + "role": "user", + "content": "Hello! Can you help me understand what machine learning is?" + } + ] +} + +### + +# Test 2: Simple Streaming Chat Completion +POST {{baseUrl}}/v1/chat/completions +Content-Type: application/json + +{ + "model": "{{model}}", + "messages": [ + { + "role": "user", + "content": "Explain the concept of artificial intelligence in simple terms." + } + ], + "stream": true +}