mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
supply per agent model
This commit is contained in:
parent
96e857a682
commit
0ba7d73284
5 changed files with 38 additions and 27 deletions
|
|
@ -101,7 +101,9 @@ impl RatelimitMap {
|
|||
) -> Result<(), Error> {
|
||||
trace!(
|
||||
"Checking limit for provider={}, with selector={:?}, consuming tokens={:?}",
|
||||
provider, selector, tokens_used
|
||||
provider,
|
||||
selector,
|
||||
tokens_used
|
||||
);
|
||||
|
||||
let provider_limits = match self.datastore.get(&provider) {
|
||||
|
|
|
|||
|
|
@ -300,25 +300,31 @@ impl HttpContext for StreamContext {
|
|||
.cloned();
|
||||
|
||||
let model_name = match self.llm_provider.as_ref() {
|
||||
Some(llm_provider) => match llm_provider.model.as_ref() {
|
||||
Some(model) => Some(model),
|
||||
None => None,
|
||||
},
|
||||
Some(llm_provider) => llm_provider.model.as_ref(),
|
||||
None => None,
|
||||
};
|
||||
|
||||
let use_agent_orchestrator = match self.overrides.as_ref() {
|
||||
Some(overrides) => overrides.use_agent_orchestrator.unwrap_or_default(),
|
||||
None => false,
|
||||
};
|
||||
|
||||
let model_requested = deserialized_body.model.clone();
|
||||
if deserialized_body.model.is_empty() || deserialized_body.model.to_lowercase() == "none" {
|
||||
deserialized_body.model = match model_name {
|
||||
Some(model_name) => model_name.clone(),
|
||||
None => {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: "No model specified in request and couldn't determine model name from arch_config".to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
if use_agent_orchestrator {
|
||||
"agent_orchestrator".to_string()
|
||||
} else {
|
||||
self.send_server_error(
|
||||
ServerError::BadRequest {
|
||||
why: format!("No model specified in request and couldn't determine model name from arch_config. Model name in req: {}, arch_config, provider: {}, model: {:?}", deserialized_body.model, self.llm_provider().name, self.llm_provider().model).to_string(),
|
||||
},
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,7 +45,8 @@ impl HttpContext for StreamContext {
|
|||
warn!("Need single endpoint when use_agent_orchestrator is set");
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(
|
||||
"Need single endpoint when use_agent_orchestrator is set".to_string(),
|
||||
"Need single endpoint when use_agent_orchestrator is set"
|
||||
.to_string(),
|
||||
),
|
||||
None,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -427,7 +427,6 @@ impl StreamContext {
|
|||
headers.insert(key.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
|
||||
let call_args = CallArgs::new(
|
||||
ARCH_INTERNAL_CLUSTER_NAME,
|
||||
&path,
|
||||
|
|
@ -499,10 +498,7 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
if !prompt_target
|
||||
.auto_llm_dispatch_on_response
|
||||
.unwrap_or(true)
|
||||
{
|
||||
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
|
||||
let tool_call_response = self.tool_call_response.as_ref().unwrap().clone();
|
||||
|
||||
let direct_response_str = if self.streaming_response {
|
||||
|
|
@ -655,10 +651,7 @@ impl StreamContext {
|
|||
.clone();
|
||||
|
||||
// check if the default target should be dispatched to the LLM provider
|
||||
if !prompt_target
|
||||
.auto_llm_dispatch_on_response
|
||||
.unwrap_or(true)
|
||||
{
|
||||
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(true) {
|
||||
let default_target_response_str = if self.streaming_response {
|
||||
let chat_completion_response =
|
||||
match serde_json::from_slice::<ChatCompletionsResponse>(&body) {
|
||||
|
|
|
|||
|
|
@ -31,9 +31,10 @@ openai_client = openai.OpenAI(
|
|||
)
|
||||
|
||||
|
||||
def call_openai(messages: List[Dict[str, str]], stream: bool):
|
||||
def call_openai(messages: List[Dict[str, str]], stream: bool, model: str):
|
||||
logger.info(f"llm agent model: {model}")
|
||||
completion = openai_client.chat.completions.create(
|
||||
model="None", # archgw picks the default LLM configured in the config file
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
)
|
||||
|
|
@ -53,14 +54,19 @@ def call_openai(messages: List[Dict[str, str]], stream: bool):
|
|||
|
||||
|
||||
class Agent:
|
||||
def __init__(self, role: str, instructions: str):
|
||||
def __init__(self, role: str, instructions: str, model: str = ""):
|
||||
self.model = model
|
||||
self.system_prompt = f"You are a {role}.\n{instructions}"
|
||||
|
||||
def handle(self, req: ChatCompletionsRequest):
|
||||
messages = [{"role": "system", "content": self.get_system_prompt()}] + [
|
||||
message.model_dump() for message in req.messages
|
||||
]
|
||||
return call_openai(messages, req.stream)
|
||||
|
||||
model = req.model
|
||||
if self.model:
|
||||
model = self.model
|
||||
return call_openai(messages, req.stream, model)
|
||||
|
||||
def get_system_prompt(self) -> str:
|
||||
return self.system_prompt
|
||||
|
|
@ -77,13 +83,16 @@ AGENTS = {
|
|||
"2. Quote ridiculous price\n"
|
||||
"3. Reveal caveat if user agrees."
|
||||
),
|
||||
model="gpt-4o-mini",
|
||||
),
|
||||
"issues_and_repairs": Agent(
|
||||
role="issues and repairs agent",
|
||||
instructions="Propose a solution, offer refund if necessary.",
|
||||
model="gpt-4o",
|
||||
),
|
||||
"escalate_to_human": Agent(
|
||||
role="human escalation agent", instructions="Escalate issues to a human."
|
||||
role="human escalation agent",
|
||||
instructions="Escalate issues to a human.",
|
||||
),
|
||||
"unknown_agent": Agent(
|
||||
role="general assistant", instructions="Assist the user in general queries."
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue