add preliminary support for llm agents (#432)

This commit is contained in:
Adil Hafeez 2025-03-19 15:21:34 -07:00 committed by GitHub
parent 8d66fefded
commit 84cd1df7bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 1388 additions and 121 deletions

View file

@ -4,7 +4,7 @@ use common::api::open_ai::{
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
ChatCompletionsResponse, Message, ToolCall,
};
use common::configuration::{Overrides, PromptTarget, Tracing};
use common::configuration::{Endpoint, Overrides, PromptTarget, Tracing};
use common::consts::{
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
@ -46,6 +46,7 @@ pub struct StreamCallContext {
pub struct StreamContext {
system_prompt: Rc<Option<String>>,
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
pub endpoints: Rc<Option<HashMap<String, Endpoint>>>,
pub overrides: Rc<Option<Overrides>>,
pub metrics: Rc<Metrics>,
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
@ -72,6 +73,7 @@ impl StreamContext {
metrics: Rc<Metrics>,
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
endpoints: Rc<Option<HashMap<String, Endpoint>>>,
overrides: Rc<Option<Overrides>>,
tracing: Rc<Option<Tracing>>,
) -> Self {
@ -80,6 +82,7 @@ impl StreamContext {
metrics,
system_prompt,
prompt_targets,
endpoints,
callouts: RefCell::new(HashMap::new()),
chat_completions_request: None,
tool_calls: None,
@ -312,12 +315,59 @@ impl StreamContext {
callout_context.prompt_target_name =
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
if let Some(overrides) = self.overrides.as_ref() {
if overrides.use_agent_orchestrator.unwrap_or_default() {
let mut metadata = HashMap::new();
metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
metadata.insert(
"agent-name".to_string(),
callout_context
.prompt_target_name
.as_ref()
.unwrap()
.to_string(),
);
if let Some(overrides) = self.overrides.as_ref() {
if overrides.optimize_context_window.unwrap_or_default() {
metadata.insert("optimize_context_window".to_string(), "true".to_string());
}
}
if let Some(overrides) = self.overrides.as_ref() {
if overrides.use_agent_orchestrator.unwrap_or_default() {
metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
}
}
let messages = self.construct_llm_messages(&callout_context);
let chat_completion_request = ChatCompletionsRequest {
model: callout_context.request_body.model.clone(),
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options.clone(),
metadata: Some(metadata),
};
let body_str = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending request to llm agent: {}", body_str);
self.set_http_request_body(0, self.request_body_size, body_str.as_bytes());
self.resume_http_request();
return;
}
}
self.schedule_api_call_request(callout_context);
}
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
// Construct messages early to avoid mutable borrow conflicts
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details
@ -329,7 +379,7 @@ impl StreamContext {
let http_method = endpoint_details.method.clone().unwrap_or_default();
let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();
let (path, body) = match compute_request_path_body(
let (path, api_call_body) = match compute_request_path_body(
&endpoint_path,
tool_params,
&prompt_target_params,
@ -346,6 +396,8 @@ impl StreamContext {
}
};
debug!("api call body {:?}", api_call_body);
let timeout_str = API_REQUEST_TIMEOUT_MS.to_string();
let http_method_str = http_method.to_string();
@ -375,11 +427,12 @@ impl StreamContext {
headers.insert(key.as_str(), value.as_str());
}
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
headers.into_iter().collect(),
body.as_deref().map(|s| s.as_bytes()),
api_call_body.as_deref().map(|s| s.as_bytes()),
vec![],
Duration::from_secs(5),
);
@ -406,6 +459,11 @@ impl StreamContext {
"developer api call response received: status code: {}",
http_status
);
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();
if http_status != StatusCode::OK.as_str() {
warn!(
"api server responded with non 2xx status code: {}",
@ -441,6 +499,40 @@ impl StreamContext {
}
};
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or(true)
{
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
let direct_response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(tool_call_response.clone()),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
to_server_events(chunks)
} else {
tool_call_response
};
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![],
Some(direct_response_str.as_bytes()),
);
}
let final_prompt = format!(
"{}\ncontext: {}",
user_message.content.unwrap(),
@ -565,7 +657,7 @@ impl StreamContext {
// check if the default target should be dispatched to the LLM provider
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or_default()
.unwrap_or(true)
{
let default_target_response_str = if self.streaming_response {
let chat_completion_response =