add more changes

This commit is contained in:
Adil Hafeez 2025-03-18 15:58:27 -07:00
parent 6d357364a3
commit b73ff5bc5b
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
7 changed files with 54 additions and 39 deletions

View file

@ -52,8 +52,6 @@ properties:
- https
http_host:
type: string
agent_orchestrator:
type: boolean
additionalProperties: false
required:
- endpoint

View file

@ -48,7 +48,7 @@ def validate_and_render_schema():
arch_config_schema = file.read()
config_yaml = yaml.safe_load(arch_config)
config_schema_yaml = yaml.safe_load(arch_config_schema)
_ = yaml.safe_load(arch_config_schema)
inferred_clusters = {}
endpoints = config_yaml.get("endpoints", {})
@ -150,12 +150,26 @@ def validate_and_render_schema():
if llm_gateway_listener.get("timeout") == None:
llm_gateway_listener["timeout"] = "10s"
agent_orchestrator = None
for name, endpoint_details in endpoints.items():
if endpoint_details.get("agent_orchestrator", False):
agent_orchestrator = name
break
use_agent_orchestrator = config_yaml.get("overrides", {}).get(
"use_agent_orchestrator", False
)
agent_orchestrator = None
if use_agent_orchestrator:
print("Using agent orchestrator")
if len(endpoints) == 0:
raise Exception(
"Please provide agent orchestrator in the endpoints section in your arch_config.yaml file"
)
elif len(endpoints) > 1:
raise Exception(
"Please provide single agent orchestrator in the endpoints section in your arch_config.yaml file"
)
else:
agent_orchestrator = list(endpoints.keys())[0]
print("agent_orchestrator: ", agent_orchestrator)
data = {
"prompt_gateway_listener": prompt_gateway_listener,
"llm_gateway_listener": llm_gateway_listener,

View file

@ -177,7 +177,6 @@ impl Display for LlmProvider {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Endpoint {
pub endpoint: Option<String>,
pub agent_orchestrator: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -37,20 +37,19 @@ impl HttpContext for StreamContext {
if overrides.use_agent_orchestrator.unwrap_or_default() {
// get endpoint that has agent_orchestrator set to true
if let Some(endpoints) = self.endpoints.as_ref() {
let agent_orchestrator = endpoints
.iter()
.find(|(_, endpoint)| endpoint.agent_orchestrator.unwrap_or_default())
.map(|(name, _)| name.clone());
if let Some(agent_orchestrator_name) = agent_orchestrator {
debug!(
"Setting ARCH_PROVIDER_HINT_HEADER to {}",
agent_orchestrator_name
if endpoints.len() == 1 {
let (name, _) = endpoints.iter().next().unwrap();
debug!("Setting ARCH_PROVIDER_HINT_HEADER to {}", name);
self.set_http_request_header(ARCH_ROUTING_HEADER, Some(&name));
} else {
warn!("Need single endpoint when use_agent_orchestrator is set");
self.send_server_error(
ServerError::LogicError(
"Need single endpoint when use_agent_orchestrator is set".to_string(),
),
None,
);
self.set_http_request_header(
ARCH_ROUTING_HEADER,
Some(&agent_orchestrator_name),
);
};
}
}
}
}

View file

@ -18,7 +18,6 @@ overrides:
endpoints:
agent_gateway:
agent_orchestrator: true
endpoint: host.docker.internal:18083
connect_timeout: 0.005s

View file

@ -75,9 +75,11 @@ class ChatCompletionStreamResponse(BaseModel):
choices: List[ChunkChoice]
client = openai.OpenAI(base_url="http://host.docker.internal:12000/v1", api_key="--")
openai_client = openai.OpenAI(
base_url="http://host.docker.internal:12000/v1", api_key="--"
)
agent_map = {
agents_definition = {
"sales_agent": {
"role": "sales agent",
"instructions": "You are a sales agent for ACME Inc."
@ -88,8 +90,7 @@ agent_map = {
" - Don't mention price.\n"
"3. Once the user is bought in, drop a ridiculous price.\n"
"4. Only after everything, and if the user says yes, "
"tell them a crazy caveat and execute their order.\n"
"",
"tell them a crazy caveat and execute their order.\n",
},
"issues_and_repairs": {
"role": "issues and repairs agent",
@ -100,8 +101,7 @@ agent_map = {
" - unless the user has already provided a reason.\n"
"2. Propose a fix (make one up).\n"
"3. ONLY if not satisfied, offer a refund.\n"
"4. If accepted, search for the ID and then execute refund."
"",
"4. If accepted, search for the ID and then execute refund.",
},
"escalate_to_human": {
"role": "human agent",
@ -109,11 +109,22 @@ agent_map = {
},
"unknown agent": {
"role": "llm agent",
"instructions": "You are an LLM agent. You can do anything you want.",
"instructions": "You are a helpful LLM agent.",
},
}
def construct_llm_messages(agent_name, messages):
    """Build the chat message list sent to the LLM for the given agent.

    The agent's entry in ``agents_definition`` supplies a ``role`` and
    ``instructions``; both are combined into a system prompt that is placed
    first, followed by every incoming message reduced to its ``role`` and
    ``content``.

    Args:
        agent_name: Key into ``agents_definition``. A name with no entry
            falls back to the ``"unknown agent"`` definition.
        messages: Iterable of objects exposing ``role`` and ``content``
            attributes.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, system prompt
        first, then the incoming messages in their original order.
    """
    # dict.get() returns None for a missing key, so subscripting its result
    # directly raised an opaque TypeError for any unrecognised agent name.
    agent = agents_definition.get(agent_name)
    if agent is None:
        agent = agents_definition["unknown agent"]

    system_prompt = "You are a " + agent["role"] + ". " + agent["instructions"]

    updated_messages = [{"role": "system", "content": system_prompt}]
    updated_messages.extend(
        {"role": message.role, "content": message.content} for message in messages
    )
    return updated_messages
@app.post("/v1/chat/completions")
async def completion_api(req: ChatCompletionsRequest):
logger.info(f"request: {req}")
@ -121,16 +132,10 @@ async def completion_api(req: ChatCompletionsRequest):
req.metadata = {}
agent_name = req.metadata.get("agent-name", "unknown agent")
logger.info(f"agent: {agent_name}")
agent_role = agent_map.get(agent_name)["role"]
agent_instructions = agent_map.get(agent_name)["instructions"]
system_prompt = "You are a " + agent_role + ". " + agent_instructions
messages = [{"role": "system", "content": system_prompt}]
for message in req.messages:
messages.append({"role": message.role, "content": message.content})
messages = construct_llm_messages(agent_name, req.messages)
logger.info("messages: " + str(messages))
completion = client.chat.completions.create(
model="--",
completion = openai_client.chat.completions.create(
model="None",
messages=messages,
stream=req.stream,
)

View file

@ -601,6 +601,7 @@ class ArchFunctionHandler(ArchBaseHandler):
if len(extracted["result"]):
verified = {}
if use_agent_orchestrator:
# skip tool call verification if using agent orchestrator
verified = {"status": True, "message": ""}
else:
verified = self._verify_tool_calls(