mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
more changes
This commit is contained in:
parent
f5f33f6de2
commit
a016212588
10 changed files with 157 additions and 129 deletions
|
|
@ -50,11 +50,12 @@ properties:
|
|||
endpoints:
|
||||
type: object
|
||||
patternProperties:
|
||||
"^.*$":
|
||||
"^[a-zA-Z][a-zA-Z0-9_]*$":
|
||||
type: object
|
||||
properties:
|
||||
endpoint:
|
||||
type: string
|
||||
pattern: "^[a-zA-Z].*$"
|
||||
connect_timeout:
|
||||
type: string
|
||||
protocol:
|
||||
|
|
|
|||
|
|
@ -249,6 +249,8 @@ static_resources:
|
|||
typed_config:
|
||||
"@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog
|
||||
path: "/var/log/access_internal.log"
|
||||
format: |
|
||||
[%START_TIME%] "%REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %PROTOCOL%" %RESPONSE_CODE% %RESPONSE_FLAGS% %BYTES_RECEIVED% %BYTES_SENT% %DURATION% %RESP(X-ENVOY-UPSTREAM-SERVICE-TIME)% "%REQ(X-FORWARDED-FOR)%" "%REQ(USER-AGENT)%" "%REQ(X-REQUEST-ID)%" "%REQ(:AUTHORITY)%" "%UPSTREAM_HOST%" "%UPSTREAM_CLUSTER%"
|
||||
route_config:
|
||||
name: local_routes
|
||||
virtual_hosts:
|
||||
|
|
|
|||
|
|
@ -68,15 +68,41 @@ def validate_and_render_schema():
|
|||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# Process agents section and convert to endpoints
|
||||
agents = config_yaml.get("agents", [])
|
||||
for agent in agents:
|
||||
agent_name = agent.get("name")
|
||||
agent_endpoint = agent.get("endpoint")
|
||||
|
||||
if agent_name and agent_endpoint:
|
||||
urlparse_result = urlparse(agent_endpoint)
|
||||
if urlparse_result.scheme and urlparse_result.hostname:
|
||||
protocol = urlparse_result.scheme
|
||||
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
|
||||
endpoints[agent_name] = {
|
||||
"endpoint": urlparse_result.hostname,
|
||||
"port": port,
|
||||
"protocol": protocol,
|
||||
}
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
inferred_clusters[name] = endpoint_details
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
protocol = inferred_clusters[name].get("protocol", "http")
|
||||
(
|
||||
inferred_clusters[name]["endpoint"],
|
||||
inferred_clusters[name]["port"],
|
||||
) = get_endpoint_and_port(endpoint, protocol)
|
||||
# Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones
|
||||
if "port" not in endpoint_details:
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
protocol = inferred_clusters[name].get("protocol", "http")
|
||||
(
|
||||
inferred_clusters[name]["endpoint"],
|
||||
inferred_clusters[name]["port"],
|
||||
) = get_endpoint_and_port(endpoint, protocol)
|
||||
|
||||
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
|
||||
|
||||
|
|
|
|||
|
|
@ -95,19 +95,19 @@ version: v0.2.0
|
|||
agents:
|
||||
- name: query_rewriter
|
||||
kind: openai
|
||||
endpoint: openai://localhost:10500
|
||||
endpoint: http://localhost:10500
|
||||
- name: context_builder
|
||||
kind: openai
|
||||
endpoint: openai://localhost:10501
|
||||
endpoint: http://localhost:10501
|
||||
- name: response_generator
|
||||
kind: openai
|
||||
endpoint: openai://localhost:10502
|
||||
endpoint: http://localhost:10502
|
||||
- name: research_agent
|
||||
kind: openai
|
||||
endpoint: https://localhost:10500
|
||||
endpoint: http://localhost:10500
|
||||
- name: input_guard_rails
|
||||
kind: openai
|
||||
endpoint: https://localhost:10503
|
||||
endpoint: http://localhost:10503
|
||||
|
||||
listeners:
|
||||
- name: tmobile
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use std::sync::Arc;
|
|||
use bytes::Bytes;
|
||||
use common::api::open_ai::{ChatCompletionsResponse, Choice};
|
||||
use common::configuration::ModelUsagePreference;
|
||||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER};
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::apis::{Role, Usage};
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
|
|
@ -12,7 +12,7 @@ use http_body_util::combinators::BoxBody;
|
|||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use hyper::{Request, Response, StatusCode, Uri};
|
||||
use serde::{ser::SerializeMap, Deserialize, Serialize};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
|
@ -83,26 +83,42 @@ pub async fn agent_chat(
|
|||
debug!("Processing agent pipeline: {}", agent_pipeline.name);
|
||||
|
||||
let mut chat_completions_history = chat_completions_request.messages.clone();
|
||||
let mut last_response: Option<String> = None;
|
||||
|
||||
// let trace_parent = request_headers
|
||||
// .iter()
|
||||
// .find(|(ty, _)| ty.as_str() == "traceparent")
|
||||
// .map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
// if let Some(trace_parent) = trace_parent {
|
||||
// request_headers.insert(
|
||||
// header::HeaderName::from_static("traceparent"),
|
||||
// header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
// );
|
||||
// }
|
||||
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
// request_headers.remove("traceparent");
|
||||
|
||||
for agent_name in agent_pipeline.filter_chain {
|
||||
debug!("Processing agent: {}", agent_name);
|
||||
let agent = agent_name_map.get(&agent_name).unwrap();
|
||||
debug!("Agent details: {:?}", agent);
|
||||
|
||||
let path = format!(
|
||||
"{}/v1/chat/completions",
|
||||
agent.endpoint.trim_end_matches('/')
|
||||
);
|
||||
|
||||
let mut request = chat_completions_request.clone();
|
||||
request.messages = chat_completions_history.clone();
|
||||
|
||||
let request_str = serde_json::to_string(&request).unwrap();
|
||||
debug!("Sending request to agent {}: {}", agent_name, request_str);
|
||||
|
||||
let mut agent_request_headers = request_headers.clone();
|
||||
agent_request_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(agent.name.as_str()).unwrap(),
|
||||
);
|
||||
|
||||
let response = match reqwest::Client::new()
|
||||
.post(path)
|
||||
.post("http://localhost:11000/v1/chat/completions")
|
||||
.headers(agent_request_headers)
|
||||
.body(request_str)
|
||||
.send()
|
||||
.await
|
||||
|
|
@ -149,14 +165,6 @@ pub async fn agent_chat(
|
|||
);
|
||||
|
||||
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
|
||||
|
||||
// chat_completions_history.append(&mut vec![hermesllm::apis::openai::Message {
|
||||
// role: hermesllm::apis::openai::Role::Assistant,
|
||||
// content: hermesllm::apis::openai::MessageContent::Text(response_str),
|
||||
// name: Some(agent_name.clone()),
|
||||
// tool_calls: None,
|
||||
// tool_call_id: None,
|
||||
// }]);
|
||||
}
|
||||
|
||||
let last_response: Option<String> = match chat_completions_history.last() {
|
||||
|
|
@ -200,79 +208,4 @@ pub async fn agent_chat(
|
|||
let response_body = serde_json::to_string(&chat_completion_response).unwrap();
|
||||
|
||||
return Ok(Response::new(full(response_body)));
|
||||
|
||||
// request_headers.insert(
|
||||
// ARCH_PROVIDER_HINT_HEADER,
|
||||
// header::HeaderValue::from_str(&model_name).unwrap(),
|
||||
// );
|
||||
|
||||
// if let Some(trace_parent) = trace_parent {
|
||||
// request_headers.insert(
|
||||
// header::HeaderName::from_static("traceparent"),
|
||||
// header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
// );
|
||||
// }
|
||||
// // remove content-length header if it exists
|
||||
// request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
// let llm_response = match reqwest::Client::new()
|
||||
// .post(full_qualified_llm_provider_url)
|
||||
// .headers(request_headers)
|
||||
// .body(client_request_bytes_for_upstream)
|
||||
// .send()
|
||||
// .await
|
||||
// {
|
||||
// Ok(res) => res,
|
||||
// Err(err) => {
|
||||
// let err_msg = format!("Failed to send request: {}", err);
|
||||
// let mut internal_error = Response::new(full(err_msg));
|
||||
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
// return Ok(internal_error);
|
||||
// }
|
||||
// };
|
||||
|
||||
// // copy over the headers from the original response
|
||||
// let response_headers = llm_response.headers().clone();
|
||||
// let mut response = Response::builder();
|
||||
// let headers = response.headers_mut().unwrap();
|
||||
// for (header_name, header_value) in response_headers.iter() {
|
||||
// headers.insert(header_name, header_value.clone());
|
||||
// }
|
||||
|
||||
// // channel to create async stream
|
||||
// let (tx, rx) = mpsc::channel::<Bytes>(16);
|
||||
|
||||
// // Spawn a task to send data as it becomes available
|
||||
// tokio::spawn(async move {
|
||||
// let mut byte_stream = llm_response.bytes_stream();
|
||||
|
||||
// while let Some(item) = byte_stream.next().await {
|
||||
// let item = match item {
|
||||
// Ok(item) => item,
|
||||
// Err(err) => {
|
||||
// warn!("Error receiving chunk: {:?}", err);
|
||||
// break;
|
||||
// }
|
||||
// };
|
||||
|
||||
// if tx.send(item).await.is_err() {
|
||||
// warn!("Receiver dropped");
|
||||
// break;
|
||||
// }
|
||||
// }
|
||||
// });
|
||||
|
||||
// let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
|
||||
// let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
// match response.body(stream_body) {
|
||||
// Ok(response) => Ok(response),
|
||||
// Err(err) => {
|
||||
// let err_msg = format!("Failed to create response: {}", err);
|
||||
// let mut internal_error = Response::new(full(err_msg));
|
||||
// *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
// Ok(internal_error)
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,14 +29,21 @@ listeners:
|
|||
filter_chain:
|
||||
- research_agent
|
||||
- response_generator
|
||||
protocol: openai
|
||||
address: 0.0.0.0
|
||||
port: 8001
|
||||
|
||||
- name: egress_traffic
|
||||
description: llm provider configuration
|
||||
port: 12000
|
||||
protocol: openai
|
||||
llm_providers:
|
||||
- access_key: $OPENAI_API_KEY
|
||||
model: openai/gpt-4o
|
||||
- access_key: $OPENAI_API_KEY
|
||||
model: openai/gpt-4o-mini
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
trace_arch_internal: true
|
||||
|
|
|
|||
8
demos/use_cases/rag_agent/docker-compose.yaml
Normal file
8
demos/use_cases/rag_agent/docker-compose.yaml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
services:
|
||||
jaeger:
|
||||
build:
|
||||
context: ../../shared/jaeger
|
||||
ports:
|
||||
- "16686:16686"
|
||||
- "4317:4317"
|
||||
- "4318:4318"
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
|
|
@ -55,7 +55,9 @@ def load_knowledge_base():
|
|||
knowledge_base = []
|
||||
|
||||
|
||||
async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, str]]:
|
||||
async def find_relevant_passages(
|
||||
query: str, traceparent: Optional[str] = None, top_k: int = 3
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Use the LLM to find the most relevant passages from the knowledge base."""
|
||||
|
||||
if not knowledge_base:
|
||||
|
|
@ -87,11 +89,18 @@ async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, s
|
|||
try:
|
||||
# Call archgw to select relevant passages
|
||||
logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
|
||||
|
||||
# Prepare extra headers if traceparent is provided
|
||||
extra_headers = {}
|
||||
if traceparent:
|
||||
extra_headers["traceparent"] = traceparent
|
||||
|
||||
response = await archgw_client.chat.completions.create(
|
||||
model=RAG_MODEL,
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
temperature=0.1,
|
||||
max_tokens=50,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
result = response.choices[0].message.content.strip()
|
||||
|
|
@ -118,7 +127,9 @@ async def find_relevant_passages(query: str, top_k: int = 3) -> List[Dict[str, s
|
|||
return []
|
||||
|
||||
|
||||
async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMessage]:
|
||||
async def augment_query_with_context(
|
||||
messages: List[ChatMessage], traceparent: Optional[str] = None
|
||||
) -> List[ChatMessage]:
|
||||
"""Extract user query, find relevant context, and augment the messages."""
|
||||
|
||||
# Find the last user message
|
||||
|
|
@ -138,7 +149,7 @@ async def augment_query_with_context(messages: List[ChatMessage]) -> List[ChatMe
|
|||
logger.info(f"Processing user query: '{last_user_message}'")
|
||||
|
||||
# Find relevant passages
|
||||
relevant_passages = await find_relevant_passages(last_user_message)
|
||||
relevant_passages = await find_relevant_passages(last_user_message, traceparent)
|
||||
|
||||
if not relevant_passages:
|
||||
logger.info("No relevant passages found, returning original messages")
|
||||
|
|
@ -178,23 +189,34 @@ app = FastAPI(title="RAG Content Builder Agent", version="1.0.0")
|
|||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
async def chat_completions(
|
||||
request_body: ChatCompletionRequest, request: Request
|
||||
) -> ChatCompletionResponse:
|
||||
"""Chat completions endpoint that augments user queries with relevant context from the knowledge base."""
|
||||
import time
|
||||
import uuid
|
||||
|
||||
logger.info(
|
||||
f"Received chat completion request with {len(request.messages)} messages"
|
||||
f"Received chat completion request with {len(request_body.messages)} messages"
|
||||
)
|
||||
|
||||
# Read traceparent header if present
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
# Augment the user query with relevant context
|
||||
updated_messages = await augment_query_with_context(request.messages)
|
||||
updated_messages = await augment_query_with_context(
|
||||
request_body.messages, traceparent_header
|
||||
)
|
||||
messages_history_json = json.dumps([msg.dict() for msg in updated_messages])
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
|
|
@ -26,7 +26,9 @@ archgw_client = AsyncOpenAI(
|
|||
)
|
||||
|
||||
|
||||
async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str:
|
||||
async def rewrite_query_with_archgw(
|
||||
messages: List[ChatMessage], traceparent_header: str
|
||||
) -> str:
|
||||
# Prepare the system prompt for query rewriting
|
||||
system_prompt = """You are a query rewriter that improves user queries for better retrieval.
|
||||
|
||||
|
|
@ -48,12 +50,16 @@ async def rewrite_query_with_archgw(messages: List[ChatMessage]) -> str:
|
|||
|
||||
try:
|
||||
# Call archgw using OpenAI client
|
||||
extra_headers = {}
|
||||
if traceparent_header:
|
||||
extra_headers["traceparent"] = traceparent_header
|
||||
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")
|
||||
response = await archgw_client.chat.completions.create(
|
||||
model=QUERY_REWRITE_MODEL,
|
||||
messages=rewrite_messages,
|
||||
temperature=0.3,
|
||||
max_tokens=200,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
rewritten_query = response.choices[0].message.content.strip()
|
||||
|
|
@ -81,20 +87,29 @@ app = FastAPI(title="RAG Agent Query Parser", version="1.0.0")
|
|||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
async def chat_completions(request_body: ChatCompletionRequest, request: Request):
|
||||
"""Chat completions endpoint that rewrites the last user query using archgw."""
|
||||
import time
|
||||
import uuid
|
||||
|
||||
logger.info(
|
||||
f"Received chat completion request with {len(request.messages)} messages"
|
||||
f"Received chat completion request with {len(request_body.messages)} messages"
|
||||
)
|
||||
|
||||
# Read traceparent header if present
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
# Call archgw to rewrite the last user query
|
||||
rewritten_query = await rewrite_query_with_archgw(request.messages)
|
||||
rewritten_query = await rewrite_query_with_archgw(
|
||||
request_body.messages, traceparent_header
|
||||
)
|
||||
|
||||
# Create updated messages with the rewritten query
|
||||
updated_messages = request.messages.copy()
|
||||
updated_messages = request_body.messages.copy()
|
||||
|
||||
# Find and update the last user message with the rewritten query
|
||||
for i in range(len(updated_messages) - 1, -1, -1):
|
||||
|
|
@ -111,7 +126,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||
response = ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import json
|
||||
from fastapi import FastAPI
|
||||
from fastapi import FastAPI, Request
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
import logging
|
||||
|
|
@ -28,12 +28,19 @@ app = FastAPI(title="RAG Agent Response Generator", version="1.0.0")
|
|||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completions(request: ChatCompletionRequest):
|
||||
async def chat_completions(request_body: ChatCompletionRequest, request: Request):
|
||||
"""Chat completions endpoint that generates a coherent response based on all context."""
|
||||
logger.info(
|
||||
f"Received chat completion request with {len(request.messages)} messages"
|
||||
f"Received chat completion request with {len(request_body.messages)} messages"
|
||||
)
|
||||
|
||||
# Read traceparent header if present
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
# Prepare the system prompt for response generation
|
||||
system_prompt = """You are a helpful assistant that generates coherent, contextual responses.
|
||||
|
||||
|
|
@ -50,17 +57,24 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||
response_messages = [{"role": "system", "content": system_prompt}]
|
||||
|
||||
# Add conversation history
|
||||
for msg in request.messages:
|
||||
for msg in request_body.messages:
|
||||
response_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
try:
|
||||
# Call archgw using OpenAI client
|
||||
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate response")
|
||||
|
||||
# Prepare extra headers if traceparent is provided
|
||||
extra_headers = {}
|
||||
if traceparent_header:
|
||||
extra_headers["traceparent"] = traceparent_header
|
||||
|
||||
response = await archgw_client.chat.completions.create(
|
||||
model=RESPONSE_MODEL,
|
||||
messages=response_messages,
|
||||
temperature=request.temperature or 0.7,
|
||||
max_tokens=request.max_tokens or 1000,
|
||||
temperature=request_body.temperature or 0.7,
|
||||
max_tokens=request_body.max_tokens or 1000,
|
||||
extra_headers=extra_headers,
|
||||
)
|
||||
|
||||
generated_response = response.choices[0].message.content.strip()
|
||||
|
|
@ -71,7 +85,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||
return ChatCompletionResponse(
|
||||
id=f"chatcmpl-{uuid.uuid4().hex[:8]}",
|
||||
created=int(time.time()),
|
||||
model=request.model,
|
||||
model=request_body.model,
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
|
|
@ -84,11 +98,11 @@ async def chat_completions(request: ChatCompletionRequest):
|
|||
],
|
||||
usage={
|
||||
"prompt_tokens": sum(
|
||||
len(msg.content.split()) for msg in request.messages
|
||||
len(msg.content.split()) for msg in request_body.messages
|
||||
),
|
||||
"completion_tokens": len(generated_response.split()),
|
||||
"total_tokens": sum(
|
||||
len(msg.content.split()) for msg in request.messages
|
||||
len(msg.content.split()) for msg in request_body.messages
|
||||
)
|
||||
+ len(generated_response.split()),
|
||||
},
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue