fix tests

This commit is contained in:
Adil Hafeez 2025-03-14 17:02:09 -07:00
parent 2179b5a162
commit 898136470a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
7 changed files with 65 additions and 63 deletions

View file

@ -170,8 +170,6 @@ properties:
type: object
additionalProperties:
type: string
pass_context:
type: boolean
additionalProperties: false
required:
- name

View file

@ -219,7 +219,6 @@ pub struct EndpointDetails {
#[serde(rename = "http_method")]
pub method: Option<HttpMethod>,
pub http_headers: Option<HashMap<String, String>>,
pub pass_context: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -379,7 +379,7 @@ impl StreamContext {
let http_method = endpoint_details.method.clone().unwrap_or_default();
let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();
let (path, body) = match compute_request_path_body(
let (path, api_call_body) = match compute_request_path_body(
&endpoint_path,
tool_params,
&prompt_target_params,
@ -396,6 +396,8 @@ impl StreamContext {
}
};
debug!("api call body {:?}", api_call_body);
let timeout_str = API_REQUEST_TIMEOUT_MS.to_string();
let http_method_str = http_method.to_string();
@ -411,25 +413,6 @@ impl StreamContext {
.into_iter()
.collect();
let api_call_body = match endpoint_details.pass_context.unwrap_or_default() {
true => {
let messages = self.construct_llm_messages(&callout_context);
let chat_completion_request = ChatCompletionsRequest {
model: callout_context.request_body.model.clone(),
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options.clone(),
metadata: None,
};
let body_str = serde_json::to_string(&chat_completion_request).unwrap();
Some(body_str)
}
false => body,
};
if self.request_id.is_some() {
headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap());
}
@ -444,7 +427,6 @@ impl StreamContext {
headers.insert(key.as_str(), value.as_str());
}
debug!("api call body string: {}", api_call_body.as_ref().unwrap());
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
@ -519,7 +501,7 @@ impl StreamContext {
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or_default()
.unwrap_or(true)
{
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
@ -675,7 +657,7 @@ impl StreamContext {
// check if the default target should be dispatched to the LLM provider
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or_default()
.unwrap_or(true)
{
let default_target_response_str = if self.streaming_response {
let chat_completion_response =

View file

@ -0,0 +1,19 @@
# Non-streaming chat completion: send a single user message and validate the JSON reply.
POST http://localhost:10000/v1/chat/completions
Content-Type: application/json
{
"messages": [
{
"role": "user",
"content": "I want to sell red shoes"
}
]
}
# Expect a successful JSON response.
HTTP 200
[Asserts]
header "content-type" == "application/json"
# The reported model name must start with "gpt-4o-mini".
jsonpath "$.model" matches /^gpt-4o-mini/
# The reply must carry x-arch-state metadata and usage data.
jsonpath "$.metadata.x-arch-state" != null
jsonpath "$.usage" != null
# The first choice must be an assistant message with non-null content.
jsonpath "$.choices[0].message.content" != null
jsonpath "$.choices[0].message.role" == "assistant"

View file

@ -0,0 +1,16 @@
# Streaming chat completion: same user message as a plain request, but with "stream": true.
POST http://localhost:10000/v1/chat/completions
Content-Type: application/json
{
"messages": [
{
"role": "user",
"content": "I want to sell red shoes"
}
],
"stream": true
}
# Expect a successful server-sent-events response.
HTTP 200
[Asserts]
header "content-type" matches /text\/event-stream/
# The body must begin with a "data: " line that mentions sales_agent.
body matches /^data: .*?sales_agent.*?\n/

View file

@ -54,6 +54,7 @@ class ChatCompletionsRequest(BaseModel):
messages: List[Message]
model: str
metadata: Dict[str, Any] = None
stream: bool = False
class Choice(BaseModel):
@ -115,48 +116,35 @@ agent_map = {
@app.post("/v1/chat/completions")
async def completion_api(req: ChatCompletionsRequest):
logger.info(f"request: {req}")
if req.metadata is None:
req.metadata = {}
agent_name = req.metadata.get("Agent-Name", "unknown agent")
logger.info(f"agent: {agent_name}")
def stream():
agent_role = agent_map.get(agent_name)["role"]
agent_instructions = agent_map.get(agent_name)["instructions"]
system_prompt = "You are a " + agent_role + ". " + agent_instructions
messages = [{"role": "system", "content": system_prompt}]
for message in req.messages:
messages.append({"role": message.role, "content": message.content})
completion = client.chat.completions.create(
model="--",
messages=messages,
stream=True,
)
for line in completion:
if line.choices and len(line.choices) > 0 and line.choices[0].delta:
chunk_response_str = json.dumps(line.model_dump())
yield "data: " + chunk_response_str + "\n\n"
yield "data: [DONE]" + "\n\n"
agent_role = agent_map.get(agent_name)["role"]
agent_instructions = agent_map.get(agent_name)["instructions"]
system_prompt = "You are a " + agent_role + ". " + agent_instructions
messages = [{"role": "system", "content": system_prompt}]
for message in req.messages:
messages.append({"role": message.role, "content": message.content})
logger.info("messages: " + str(messages))
completion = client.chat.completions.create(
model="--",
messages=messages,
stream=req.stream,
)
# content = agent_map.get(agent_name)
if req.stream:
# for c in content:
# resp = ChatCompletionStreamResponse(
# model="--",
# choices=[
# ChunkChoice(
# delta=Message(
# role="assistant",
# content=c,
# )
# )
# ],
# )
# # random sleep between 10ms and 50ms
# time.sleep(random.randint(10, 50) / 1000)
def stream():
for line in completion:
if line.choices and len(line.choices) > 0 and line.choices[0].delta:
chunk_response_str = json.dumps(line.model_dump())
yield "data: " + chunk_response_str + "\n\n"
yield "data: [DONE]" + "\n\n"
# yield "data: " + json.dumps(resp.model_dump()) + "\n\n"
return StreamingResponse(stream(), media_type="text/event-stream")
# yield "data: [DONE]" + "\n\n"
return StreamingResponse(stream(), media_type="text/event-stream")
else:
return completion

View file

@ -547,7 +547,7 @@ class ArchFunctionHandler(ArchBaseHandler):
messages=messages,
model=self.model_name,
stream=True,
extra_body={"temperature": 0.01, "logprobs": True},
extra_body=self.generation_params,
)
use_agent_orchestrator = req.metadata.get("use_agent_orchestrator", False)