Mirror of https://github.com/katanemo/plano.git, synced 2026-05-10 16:22:42 +02:00
add preliminary support for llm agents (#432)
commit 84cd1df7bf (parent 8d66fefded)
29 changed files with 1388 additions and 121 deletions
@@ -4,7 +4,7 @@ use common::api::open_ai::{
     to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
     ChatCompletionsResponse, Message, ToolCall,
 };
-use common::configuration::{Overrides, PromptTarget, Tracing};
+use common::configuration::{Endpoint, Overrides, PromptTarget, Tracing};
 use common::consts::{
     API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
     ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
@@ -46,6 +46,7 @@ pub struct StreamCallContext {
 pub struct StreamContext {
     system_prompt: Rc<Option<String>>,
     pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
+    pub endpoints: Rc<Option<HashMap<String, Endpoint>>>,
     pub overrides: Rc<Option<Overrides>>,
     pub metrics: Rc<Metrics>,
     pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
@@ -72,6 +73,7 @@ impl StreamContext {
         metrics: Rc<Metrics>,
         system_prompt: Rc<Option<String>>,
         prompt_targets: Rc<HashMap<String, PromptTarget>>,
+        endpoints: Rc<Option<HashMap<String, Endpoint>>>,
         overrides: Rc<Option<Overrides>>,
         tracing: Rc<Option<Tracing>>,
     ) -> Self {
@@ -80,6 +82,7 @@ impl StreamContext {
             metrics,
             system_prompt,
             prompt_targets,
+            endpoints,
             callouts: RefCell::new(HashMap::new()),
             chat_completions_request: None,
             tool_calls: None,
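The three hunks above thread an endpoints map from the parsed configuration into StreamContext. Because the field is Rc-wrapped, each per-request stream context shares one copy of the map rather than cloning it. A minimal sketch of that sharing, using a stand-in Endpoint type (the real one lives in common::configuration and its fields are not shown in this diff; the "api_server" key is a made-up example):

    use std::collections::HashMap;
    use std::rc::Rc;

    // Stand-in for common::configuration::Endpoint; fields omitted here.
    #[derive(Clone)]
    struct Endpoint;

    fn main() {
        let endpoints: Rc<Option<HashMap<String, Endpoint>>> =
            Rc::new(Some(HashMap::from([("api_server".to_string(), Endpoint)])));

        // Each new StreamContext receives a cheap pointer copy of the shared
        // configuration, not a deep copy of the map itself.
        let for_stream = Rc::clone(&endpoints);
        assert_eq!(Rc::strong_count(&endpoints), 2);
        drop(for_stream);
    }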
@@ -312,12 +315,59 @@ impl StreamContext {
         callout_context.prompt_target_name =
             Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());

+        if let Some(overrides) = self.overrides.as_ref() {
+            if overrides.use_agent_orchestrator.unwrap_or_default() {
+                let mut metadata = HashMap::new();
+                metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
+
+                metadata.insert(
+                    "agent-name".to_string(),
+                    callout_context
+                        .prompt_target_name
+                        .as_ref()
+                        .unwrap()
+                        .to_string(),
+                );
+
+                if let Some(overrides) = self.overrides.as_ref() {
+                    if overrides.optimize_context_window.unwrap_or_default() {
+                        metadata.insert("optimize_context_window".to_string(), "true".to_string());
+                    }
+                }
+
+                if let Some(overrides) = self.overrides.as_ref() {
+                    if overrides.use_agent_orchestrator.unwrap_or_default() {
+                        metadata.insert("use_agent_orchestrator".to_string(), "true".to_string());
+                    }
+                }
+
+                let messages = self.construct_llm_messages(&callout_context);
+
+                let chat_completion_request = ChatCompletionsRequest {
+                    model: callout_context.request_body.model.clone(),
+                    messages,
+                    tools: None,
+                    stream: callout_context.request_body.stream,
+                    stream_options: callout_context.request_body.stream_options.clone(),
+                    metadata: Some(metadata),
+                };
+
+                let body_str = serde_json::to_string(&chat_completion_request).unwrap();
+                debug!("sending request to llm agent: {}", body_str);
+                self.set_http_request_body(0, self.request_body_size, body_str.as_bytes());
+                self.resume_http_request();
+                return;
+            }
+        }
+
+        self.schedule_api_call_request(callout_context);
+    }
+
+    fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) {
         // Construct messages early to avoid mutable borrow conflicts

         let tools_call_name = self.tool_calls.as_ref().unwrap()[0].function.name.clone();
-        let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap();
+        let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
        let tool_params = &self.tool_calls.as_ref().unwrap()[0].function.arguments;
         let endpoint_details = prompt_target.endpoint.as_ref().unwrap();
         let endpoint_path: String = endpoint_details
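When the agent orchestrator path is taken, the rewritten request body carries the routing hints in metadata. A rough sketch of the serialized shape, assuming ChatCompletionsRequest serializes its fields under the names used in the struct literal above (the model and agent names below are placeholders, and the exact wire format depends on the struct's serde attributes):

    use serde_json::json;

    fn main() {
        // Hypothetical payload shape for the llm-agent callout.
        let body = json!({
            "model": "some-model",                        // from the incoming request
            "messages": [{"role": "user", "content": "..."}],
            "stream": true,
            "metadata": {
                "use_agent_orchestrator": "true",
                "agent-name": "some-prompt-target"        // resolved from the tool call
            }
        });
        println!("{body}");
    }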
@@ -329,7 +379,7 @@ impl StreamContext {
         let http_method = endpoint_details.method.clone().unwrap_or_default();
         let prompt_target_params = prompt_target.parameters.clone().unwrap_or_default();

-        let (path, body) = match compute_request_path_body(
+        let (path, api_call_body) = match compute_request_path_body(
             &endpoint_path,
             tool_params,
             &prompt_target_params,
@@ -346,6 +396,8 @@ impl StreamContext {
             }
         };

+        debug!("api call body {:?}", api_call_body);
+
         let timeout_str = API_REQUEST_TIMEOUT_MS.to_string();

         let http_method_str = http_method.to_string();
@@ -375,11 +427,12 @@ impl StreamContext {
             headers.insert(key.as_str(), value.as_str());
         }

+
         let call_args = CallArgs::new(
             ARCH_INTERNAL_CLUSTER_NAME,
             &path,
             headers.into_iter().collect(),
-            body.as_deref().map(|s| s.as_bytes()),
+            api_call_body.as_deref().map(|s| s.as_bytes()),
             vec![],
             Duration::from_secs(5),
         );
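CallArgs bundles the arguments for the outbound callout, and the hunk above renames the body binding it captures from body to api_call_body. Assuming the wrapper bottoms out in proxy-wasm's dispatch_http_call (the wrapper's internals are not part of this diff, and the "arch_internal" string is a stand-in for whatever ARCH_INTERNAL_CLUSTER_NAME resolves to), the call is roughly equivalent to:

    use std::time::Duration;
    use proxy_wasm::traits::Context;
    use proxy_wasm::types::Status;

    // Hypothetical expansion of the CallArgs-based callout inside any
    // proxy-wasm Context impl; argument order mirrors the diff above.
    fn dispatch_sketch<C: Context>(
        ctx: &C,
        headers: Vec<(&str, &str)>,
        api_call_body: Option<&str>,
    ) -> Result<u32, Status> {
        ctx.dispatch_http_call(
            "arch_internal",                        // upstream cluster (value assumed)
            headers,                                // collected header pairs
            api_call_body.map(|s| s.as_bytes()),    // optional request body
            vec![],                                 // no trailers
            Duration::from_secs(5),                 // hard-coded per-call timeout
        )
    }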
@@ -406,6 +459,11 @@ impl StreamContext {
             "developer api call response received: status code: {}",
             http_status
         );
+        let prompt_target = self
+            .prompt_targets
+            .get(callout_context.prompt_target_name.as_ref().unwrap())
+            .unwrap()
+            .clone();
         if http_status != StatusCode::OK.as_str() {
             warn!(
                 "api server responded with non 2xx status code: {}",
@@ -441,6 +499,40 @@ impl StreamContext {
             }
         };

+        if !prompt_target
+            .auto_llm_dispatch_on_response
+            .unwrap_or(true)
+        {
+            let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
+
+            let direct_response_str = if self.streaming_response {
+                let chunks = vec![
+                    ChatCompletionStreamResponse::new(
+                        None,
+                        Some(ASSISTANT_ROLE.to_string()),
+                        Some(ARCH_FC_MODEL_NAME.to_owned()),
+                        None,
+                    ),
+                    ChatCompletionStreamResponse::new(
+                        Some(tool_call_response.clone()),
+                        None,
+                        Some(ARCH_FC_MODEL_NAME.to_owned()),
+                        None,
+                    ),
+                ];
+
+                to_server_events(chunks)
+            } else {
+                tool_call_response
+            };
+
+            return self.send_http_response(
+                StatusCode::OK.as_u16().into(),
+                vec![],
+                Some(direct_response_str.as_bytes()),
+            );
+        }
+
         let final_prompt = format!(
             "{}\ncontext: {}",
             user_message.content.unwrap(),
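to_server_events, imported at the top of this file, turns the two synthetic chunks into a server-sent-events body so the direct response reads like a normal streaming completion. Its body is not part of this diff; a minimal stand-in that follows the OpenAI streaming convention (the [DONE] terminator is an assumption here) would be:

    use serde::Serialize;

    // Hypothetical stand-in for common::api::open_ai::to_server_events:
    // frame each chunk as an SSE `data:` event and terminate the stream.
    fn to_server_events_sketch<T: Serialize>(chunks: Vec<T>) -> String {
        let mut out = String::new();
        for chunk in chunks {
            out.push_str("data: ");
            out.push_str(&serde_json::to_string(&chunk).unwrap());
            out.push_str("\n\n");
        }
        out.push_str("data: [DONE]\n\n");
        out
    }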
@@ -565,7 +657,7 @@ impl StreamContext {
         // check if the default target should be dispatched to the LLM provider
         if !prompt_target
             .auto_llm_dispatch_on_response
-            .unwrap_or_default()
+            .unwrap_or(true)
         {
             let default_target_response_str = if self.streaming_response {
                 let chat_completion_response =
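The one-line change in this hunk (and the matching .unwrap_or(true) in the direct-response hunk above) flips what happens when auto_llm_dispatch_on_response is unset: unwrap_or_default() on an Option<bool> yields false, so unset targets previously took the direct-response path, while unwrap_or(true) makes dispatch to the LLM the default. In isolation:

    fn main() {
        let unset: Option<bool> = None;
        assert_eq!(unset.unwrap_or_default(), false); // old default: skip LLM dispatch
        assert_eq!(unset.unwrap_or(true), true);      // new default: dispatch to the LLM
    }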