Add preliminary support for LLM agents

This commit is contained in:
Adil Hafeez 2025-03-12 15:45:05 -07:00
parent ffb8566c36
commit 8104eac596
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
17 changed files with 1508 additions and 79 deletions

View file

@ -189,7 +189,7 @@ pub struct ToolCall {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
pub name: String,
pub arguments: HashMap<String, Value>,
pub arguments: Option<HashMap<String, Value>>,
}
#[derive(Debug, Deserialize, Serialize)]

View file

@ -25,6 +25,7 @@ pub struct Configuration {
pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
pub optimize_context_window: Option<bool>,
pub use_agent_orchestrator: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -217,6 +218,7 @@ pub struct EndpointDetails {
#[serde(rename = "http_method")]
pub method: Option<HttpMethod>,
pub http_headers: Option<HashMap<String, String>>,
pub pass_context: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -217,6 +217,12 @@ impl HttpContext for StreamContext {
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
// Let the client send the gateway all the data before sending to the LLM_provider.
// TODO: consider a streaming API.
trace!(
"on_http_request_body [S={}] bytes={} end_stream={}",
self.context_id,
body_size,
end_of_stream
);
if self.request_body_sent_time.is_none() {
self.request_body_sent_time = Some(current_time_ns());
@ -230,33 +236,37 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
let body_bytes = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => body_bytes,
None => {
self.send_server_error(
ServerError::LogicError(format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
)),
None,
);
return Action::Pause;
}
};
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
self.send_server_error(
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
},
None => {
match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(e) => {
debug!("body str: {}", String::from_utf8_lossy(&body_bytes));
self.send_server_error(
ServerError::LogicError(format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
)),
None,
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
};
// remove metadata from the request body
// TODO: move this to the prompt gateway
deserialized_body.metadata = None;
// delete model key from message array
for message in deserialized_body.messages.iter_mut() {

View file

@ -152,6 +152,18 @@ impl HttpContext for StreamContext {
}
}
if let Some(overrides) = self.overrides.as_ref() {
if overrides.use_agent_orchestrator.unwrap_or_default() {
if metadata.is_none() {
metadata = Some(HashMap::new());
}
metadata
.as_mut()
.unwrap()
.insert("use_agent_orchestrator".to_string(), "true".to_string());
}
}
let arch_fc_chat_completion_request = ChatCompletionsRequest {
messages: deserialized_body.messages.clone(),
metadata,

View file

@ -316,8 +316,10 @@ impl StreamContext {
}
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
// Construct messages early to avoid mutable borrow conflicts
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap();
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
let endpoint_path: String = endpoint_details
@ -361,6 +363,25 @@ impl StreamContext {
.into_iter()
.collect();
let api_call_body = match endpoint_details.pass_context.unwrap_or_default() {
true => {
let messages = self.construct_llm_messages(&callout_context);
let chat_completion_request = ChatCompletionsRequest {
model: callout_context.request_body.model.clone(),
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options.clone(),
metadata: None,
};
let body_str = serde_json::to_string(&chat_completion_request).unwrap();
Some(body_str)
}
false => body,
};
if self.request_id.is_some() {
headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap());
}
@ -375,11 +396,13 @@ impl StreamContext {
headers.insert(key.as_str(), value.as_str());
}
debug!("api call body string: {}", api_call_body.as_ref().unwrap());
let call_args = CallArgs::new(
ARCH_INTERNAL_CLUSTER_NAME,
&path,
headers.into_iter().collect(),
body.as_deref().map(|s| s.as_bytes()),
api_call_body.as_deref().map(|s| s.as_bytes()),
vec![],
Duration::from_secs(5),
);
@ -406,6 +429,11 @@ impl StreamContext {
"developer api call response received: status code: {}",
http_status
);
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();
if http_status != StatusCode::OK.as_str() {
warn!(
"api server responded with non 2xx status code: {}",
@ -441,6 +469,40 @@ impl StreamContext {
}
};
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or_default()
{
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
let direct_response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(tool_call_response.clone()),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
to_server_events(chunks)
} else {
tool_call_response
};
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![],
Some(direct_response_str.as_bytes()),
);
}
let final_prompt = format!(
"{}\ncontext: {}",
user_message.content.unwrap(),

View file

@ -4,8 +4,11 @@ use std::collections::HashMap;
use serde_yaml::Value;
// only add params that are of string, number and bool type
pub fn filter_tool_params(tool_params: &HashMap<String, Value>) -> HashMap<String, String> {
tool_params
pub fn filter_tool_params(tool_params: &Option<HashMap<String, Value>>) -> HashMap<String, String> {
if tool_params.is_none() {
return HashMap::new();
}
tool_params.as_ref().unwrap()
.iter()
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
.map(|(key, value)| match value {
@ -22,7 +25,7 @@ pub fn filter_tool_params(tool_params: &HashMap<String, Value>) -> HashMap<Strin
pub fn compute_request_path_body(
endpoint_path: &str,
tool_params: &HashMap<String, Value>,
tool_params: &Option<HashMap<String, Value>>,
prompt_target_params: &[Parameter],
http_method: &HttpMethod,
) -> Result<(String, Option<String>), String> {