mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
add preliminary support for llm agents
This commit is contained in:
parent
ffb8566c36
commit
8104eac596
17 changed files with 1508 additions and 79 deletions
|
|
@ -189,7 +189,7 @@ pub struct ToolCall {
|
|||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, Value>,
|
||||
pub arguments: Option<HashMap<String, Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ pub struct Configuration {
|
|||
pub struct Overrides {
|
||||
pub prompt_target_intent_matching_threshold: Option<f64>,
|
||||
pub optimize_context_window: Option<bool>,
|
||||
pub use_agent_orchestrator: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -217,6 +218,7 @@ pub struct EndpointDetails {
|
|||
#[serde(rename = "http_method")]
|
||||
pub method: Option<HttpMethod>,
|
||||
pub http_headers: Option<HashMap<String, String>>,
|
||||
pub pass_context: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
|
|||
|
|
@ -217,6 +217,12 @@ impl HttpContext for StreamContext {
|
|||
fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
// Let the client send the gateway all the data before sending to the LLM_provider.
|
||||
// TODO: consider a streaming API.
|
||||
trace!(
|
||||
"on_http_request_body [S={}] bytes={} end_stream={}",
|
||||
self.context_id,
|
||||
body_size,
|
||||
end_of_stream
|
||||
);
|
||||
|
||||
if self.request_body_sent_time.is_none() {
|
||||
self.request_body_sent_time = Some(current_time_ns());
|
||||
|
|
@ -230,33 +236,37 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
let body_bytes = match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => body_bytes,
|
||||
None => {
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
)),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(e) => {
|
||||
debug!("body str: {}", String::from_utf8_lossy(&body_bytes));
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
)),
|
||||
None,
|
||||
ServerError::Deserialization(e),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
// remove metadata from the request body
|
||||
//TODO: move this to prompt gateway
|
||||
deserialized_body.metadata = None;
|
||||
// delete model key from message array
|
||||
for message in deserialized_body.messages.iter_mut() {
|
||||
|
|
|
|||
|
|
@ -152,6 +152,18 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(overrides) = self.overrides.as_ref() {
|
||||
if overrides.use_agent_orchestrator.unwrap_or_default() {
|
||||
if metadata.is_none() {
|
||||
metadata = Some(HashMap::new());
|
||||
}
|
||||
metadata
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.insert("use_agent_orchestrator".to_string(), "true".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let arch_fc_chat_completion_request = ChatCompletionsRequest {
|
||||
messages: deserialized_body.messages.clone(),
|
||||
metadata,
|
||||
|
|
|
|||
|
|
@ -316,8 +316,10 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
|
||||
// Construct messages early to avoid mutable borrow conflicts
|
||||
|
||||
let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap();
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
|
||||
let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
|
||||
let endpoint_path: String = endpoint_details
|
||||
|
|
@ -361,6 +363,25 @@ impl StreamContext {
|
|||
.into_iter()
|
||||
.collect();
|
||||
|
||||
let api_call_body = match endpoint_details.pass_context.unwrap_or_default() {
|
||||
true => {
|
||||
let messages = self.construct_llm_messages(&callout_context);
|
||||
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: callout_context.request_body.model.clone(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options.clone(),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let body_str = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
Some(body_str)
|
||||
}
|
||||
false => body,
|
||||
};
|
||||
|
||||
if self.request_id.is_some() {
|
||||
headers.insert(REQUEST_ID_HEADER, self.request_id.as_ref().unwrap());
|
||||
}
|
||||
|
|
@ -375,11 +396,13 @@ impl StreamContext {
|
|||
headers.insert(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
debug!("api call body string: {}", api_call_body.as_ref().unwrap());
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&path,
|
||||
headers.into_iter().collect(),
|
||||
body.as_deref().map(|s| s.as_bytes()),
|
||||
api_call_body.as_deref().map(|s| s.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
|
@ -406,6 +429,11 @@ impl StreamContext {
|
|||
"developer api call response received: status code: {}",
|
||||
http_status
|
||||
);
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
.unwrap()
|
||||
.clone();
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
warn!(
|
||||
"api server responded with non 2xx status code: {}",
|
||||
|
|
@ -441,6 +469,40 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
if !prompt_target
|
||||
.auto_llm_dispatch_on_response
|
||||
.unwrap_or_default()
|
||||
{
|
||||
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
|
||||
|
||||
let direct_response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(tool_call_response.clone()),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
to_server_events(chunks)
|
||||
} else {
|
||||
tool_call_response
|
||||
};
|
||||
|
||||
return self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![],
|
||||
Some(direct_response_str.as_bytes()),
|
||||
);
|
||||
}
|
||||
|
||||
let final_prompt = format!(
|
||||
"{}\ncontext: {}",
|
||||
user_message.content.unwrap(),
|
||||
|
|
|
|||
|
|
@ -4,8 +4,11 @@ use std::collections::HashMap;
|
|||
use serde_yaml::Value;
|
||||
|
||||
// only add params that are of string, number and bool type
|
||||
pub fn filter_tool_params(tool_params: &HashMap<String, Value>) -> HashMap<String, String> {
|
||||
tool_params
|
||||
pub fn filter_tool_params(tool_params: &Option<HashMap<String, Value>>) -> HashMap<String, String> {
|
||||
if tool_params.is_none() {
|
||||
return HashMap::new();
|
||||
}
|
||||
tool_params.as_ref().unwrap()
|
||||
.iter()
|
||||
.filter(|(_, value)| value.is_number() || value.is_string() || value.is_bool())
|
||||
.map(|(key, value)| match value {
|
||||
|
|
@ -22,7 +25,7 @@ pub fn filter_tool_params(tool_params: &HashMap<String, Value>) -> HashMap<Strin
|
|||
|
||||
pub fn compute_request_path_body(
|
||||
endpoint_path: &str,
|
||||
tool_params: &HashMap<String, Value>,
|
||||
tool_params: &Option<HashMap<String, Value>>,
|
||||
prompt_target_params: &[Parameter],
|
||||
http_method: &HttpMethod,
|
||||
) -> Result<(String, Option<String>), String> {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue