mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
add streaming
This commit is contained in:
parent
4588787427
commit
08471d8adf
5 changed files with 274 additions and 59 deletions
|
|
@ -26,3 +26,11 @@ class ChatCompletionResponse(BaseModel):
|
|||
model: str
|
||||
choices: List[Dict[str, Any]]
|
||||
usage: Dict[str, int]
|
||||
|
||||
|
||||
class ChatCompletionStreamResponse(BaseModel):
|
||||
id: str
|
||||
object: str = "chat.completion.chunk"
|
||||
created: int
|
||||
model: str
|
||||
choices: List[Dict[str, Any]]
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
import json
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import uvicorn
|
||||
import asyncio
|
||||
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse
|
||||
from .api import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
)
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
|
@ -17,6 +23,18 @@ logger = logging.getLogger(__name__)
|
|||
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:9000/v1")
|
||||
RESPONSE_MODEL = "gpt-4o"
|
||||
|
||||
# System prompt for response generation
|
||||
SYSTEM_PROMPT = """You are a helpful assistant that generates coherent, contextual responses.
|
||||
|
||||
Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
|
||||
Your response should:
|
||||
1. Be contextually aware of the entire conversation
|
||||
2. Address the user's needs appropriately
|
||||
3. Be helpful and informative
|
||||
4. Maintain a natural conversational tone
|
||||
|
||||
Generate a complete response to assist the user."""
|
||||
|
||||
# Initialize OpenAI client for archgw
|
||||
archgw_client = AsyncOpenAI(
|
||||
base_url=LLM_GATEWAY_ENDPOINT,
|
||||
|
|
@ -27,6 +45,17 @@ archgw_client = AsyncOpenAI(
|
|||
app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
|
||||
|
||||
|
||||
def prepare_response_messages(request_body: ChatCompletionRequest):
|
||||
"""Prepare messages for response generation by adding system prompt."""
|
||||
response_messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
||||
|
||||
# Add conversation history
|
||||
for msg in request_body.messages:
|
||||
response_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
return response_messages
|
||||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request_body: ChatCompletionRequest, request: Request):
|
||||
"""Chat completions endpoint that generates a coherent response based on all context."""
|
||||
|
|
@ -41,24 +70,121 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
|
|||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
# Prepare the system prompt for response generation
|
||||
system_prompt = """You are a helpful assistant that generates coherent, contextual responses.
|
||||
# Check if streaming is requested
|
||||
if request_body.stream:
|
||||
return StreamingResponse(
|
||||
stream_chat_completions(request_body, traceparent_header),
|
||||
media_type="text/plain",
|
||||
)
|
||||
else:
|
||||
return await non_streaming_chat_completions(request_body, traceparent_header)
|
||||
|
||||
Given a conversation history, generate a helpful and relevant response based on all the context available in the messages.
|
||||
Your response should:
|
||||
1. Be contextually aware of the entire conversation
|
||||
2. Address the user's needs appropriately
|
||||
3. Be helpful and informative
|
||||
4. Maintain a natural conversational tone
|
||||
|
||||
Generate a complete response to assist the user."""
|
||||
|
||||
async def stream_chat_completions(
|
||||
request_body: ChatCompletionRequest, traceparent_header: str = None
|
||||
):
|
||||
"""Generate streaming chat completions."""
|
||||
# Prepare messages for response generation
|
||||
response_messages = [{"role": "system", "content": system_prompt}]
|
||||
response_messages = prepare_response_messages(request_body)
|
||||
|
||||
# Add conversation history
|
||||
for msg in request_body.messages:
|
||||
response_messages.append({"role": msg.role, "content": msg.content})
|
||||
try:
|
||||
# Call archgw using OpenAI client for streaming
|
||||
logger.info(
|
||||
f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate streaming response"
|
||||
)
|
||||
|
||||
# Prepare extra headers if traceparent is provided
|
||||
extra_headers = {}
|
||||
if traceparent_header:
|
||||
extra_headers["traceparent"] = traceparent_header
|
||||
|
||||
response_stream = await archgw_client.chat.completions.create(
|
||||
model=RESPONSE_MODEL,
|
||||
messages=response_messages,
|
||||
temperature=request_body.temperature or 0.7,
|
||||
max_tokens=request_body.max_tokens or 1000,
|
||||
stream=True,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
completion_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
|
||||
created_time = int(time.time())
|
||||
collected_content = []
|
||||
|
||||
async for chunk in response_stream:
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
content = chunk.choices[0].delta.content
|
||||
collected_content.append(content)
|
||||
|
||||
# Create streaming response chunk
|
||||
stream_chunk = ChatCompletionStreamResponse(
|
||||
id=completion_id,
|
||||
created=created_time,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": content},
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {stream_chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Send final chunk with complete response in expected format
|
||||
full_response = "".join(collected_content)
|
||||
updated_history = [{"role": "assistant", "content": full_response}]
|
||||
|
||||
final_chunk = ChatCompletionStreamResponse(
|
||||
id=completion_id,
|
||||
created=created_time,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": json.dumps(updated_history),
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {final_chunk.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating streaming response: {e}")
|
||||
|
||||
# Send error as streaming response
|
||||
error_chunk = ChatCompletionStreamResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": "I apologize, but I'm having trouble generating a response right now. Please try again."
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
yield f"data: {error_chunk.model_dump_json()}\n\n"
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
async def non_streaming_chat_completions(
|
||||
request_body: ChatCompletionRequest, traceparent_header: str = None
|
||||
):
|
||||
"""Generate non-streaming chat completions."""
|
||||
# Prepare messages for response generation
|
||||
response_messages = prepare_response_messages(request_body)
|
||||
|
||||
try:
|
||||
# Call archgw using OpenAI client
|
||||
|
|
@ -116,7 +242,7 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
|
|||
return ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
|
|
@ -126,11 +252,11 @@ async def chat_completions(request_body: ChatCompletionRequest, request: Request
|
|||
],
|
||||
usage={
|
||||
"prompt_tokens": sum(
|
||||
len(msg.content.split()) for msg in request.messages
|
||||
len(msg.content.split()) for msg in request_body.messages
|
||||
),
|
||||
"completion_tokens": len(fallback_message.split()),
|
||||
"total_tokens": sum(
|
||||
len(msg.content.split()) for msg in request.messages
|
||||
len(msg.content.split()) for msg in request_body.messages
|
||||
)
|
||||
+ len(fallback_message.split()),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
@baseUrl = http://0.0.0.0:10502
|
||||
@model = gpt-4o
|
||||
|
||||
###
|
||||
|
||||
# Health Check
|
||||
GET {{baseUrl}}/health
|
||||
|
||||
###
|
||||
|
||||
# Test 1: Simple Non-Streaming Chat Completion
|
||||
POST {{baseUrl}}/v1/chat/completions
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "{{model}}",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello! Can you help me understand what machine learning is?"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
###
|
||||
|
||||
# Test 2: Simple Streaming Chat Completion
|
||||
POST {{baseUrl}}/v1/chat/completions
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "{{model}}",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Explain the concept of artificial intelligence in simple terms."
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue