diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index e5d8b88a..59276589 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -52,8 +52,6 @@ properties: - https http_host: type: string - agent_orchestrator: - type: boolean additionalProperties: false required: - endpoint diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 615a3df5..44ceda27 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -48,7 +48,7 @@ def validate_and_render_schema(): arch_config_schema = file.read() config_yaml = yaml.safe_load(arch_config) - config_schema_yaml = yaml.safe_load(arch_config_schema) + _ = yaml.safe_load(arch_config_schema) inferred_clusters = {} endpoints = config_yaml.get("endpoints", {}) @@ -150,12 +150,26 @@ def validate_and_render_schema(): if llm_gateway_listener.get("timeout") == None: llm_gateway_listener["timeout"] = "10s" - agent_orchestrator = None - for name, endpoint_details in endpoints.items(): - if endpoint_details.get("agent_orchestrator", False): - agent_orchestrator = name - break + use_agent_orchestrator = config_yaml.get("overrides", {}).get( + "use_agent_orchestrator", False + ) + agent_orchestrator = None + if use_agent_orchestrator: + print("Using agent orchestrator") + + if len(endpoints) == 0: + raise Exception( + "Please provide agent orchestrator in the endpoints section in your arch_config.yaml file" + ) + elif len(endpoints) > 1: + raise Exception( + "Please provide single agent orchestrator in the endpoints section in your arch_config.yaml file" + ) + else: + agent_orchestrator = list(endpoints.keys())[0] + + print("agent_orchestrator: ", agent_orchestrator) data = { "prompt_gateway_listener": prompt_gateway_listener, "llm_gateway_listener": llm_gateway_listener, diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index a956e71c..2065b1aa 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -177,7 +177,6 @@ impl Display for LlmProvider { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Endpoint { pub endpoint: Option, - pub agent_orchestrator: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index ba3e71bc..a26af5be 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -37,20 +37,19 @@ impl HttpContext for StreamContext { if overrides.use_agent_orchestrator.unwrap_or_default() { // get endpoint that has agent_orchestrator set to true if let Some(endpoints) = self.endpoints.as_ref() { - let agent_orchestrator = endpoints - .iter() - .find(|(_, endpoint)| endpoint.agent_orchestrator.unwrap_or_default()) - .map(|(name, _)| name.clone()); - if let Some(agent_orchestrator_name) = agent_orchestrator { - debug!( - "Setting ARCH_PROVIDER_HINT_HEADER to {}", - agent_orchestrator_name + if endpoints.len() == 1 { + let (name, _) = endpoints.iter().next().unwrap(); + debug!("Setting ARCH_PROVIDER_HINT_HEADER to {}", name); + self.set_http_request_header(ARCH_ROUTING_HEADER, Some(&name)); + } else { + warn!("Need single endpoint when use_agent_orchestrator is set"); + self.send_server_error( + ServerError::LogicError( + "Need single endpoint when use_agent_orchestrator is set".to_string(), + ), + None, ); - self.set_http_request_header( - ARCH_ROUTING_HEADER, - Some(&agent_orchestrator_name), - ); - }; + } } } } diff --git a/demos/use_cases/orchestrating_agents/arch_config.yaml b/demos/use_cases/orchestrating_agents/arch_config.yaml index 9e90f158..7cffa101 100644 --- a/demos/use_cases/orchestrating_agents/arch_config.yaml +++ b/demos/use_cases/orchestrating_agents/arch_config.yaml @@ -18,7 +18,6 @@ overrides: endpoints: agent_gateway: - agent_orchestrator: true endpoint: host.docker.internal:18083 connect_timeout: 0.005s diff --git a/demos/use_cases/orchestrating_agents/main.py b/demos/use_cases/orchestrating_agents/main.py index 13f04647..c72a0d70 100644 --- a/demos/use_cases/orchestrating_agents/main.py +++ b/demos/use_cases/orchestrating_agents/main.py @@ -75,9 +75,11 @@ class ChatCompletionStreamResponse(BaseModel): choices: List[ChunkChoice] -client = openai.OpenAI(base_url="http://host.docker.internal:12000/v1", api_key="--") +openai_client = openai.OpenAI( + base_url="http://host.docker.internal:12000/v1", api_key="--" +) -agent_map = { +agents_definition = { "sales_agent": { "role": "sales agent", "instructions": "You are a sales agent for ACME Inc." @@ -88,8 +90,7 @@ agent_map = { " - Don't mention price.\n" "3. Once the user is bought in, drop a ridiculous price.\n" "4. Only after everything, and if the user says yes, " - "tell them a crazy caveat and execute their order.\n" - "", + "tell them a crazy caveat and execute their order.\n", }, "issues_and_repairs": { "role": "issues and repairs agent", @@ -100,8 +101,7 @@ agent_map = { " - unless the user has already provided a reason.\n" "2. Propose a fix (make one up).\n" "3. ONLY if not satisfied, offer a refund.\n" - "4. If accepted, search for the ID and then execute refund." - "", + "4. If accepted, search for the ID and then execute refund.", }, "escalate_to_human": { "role": "human agent", @@ -109,11 +109,22 @@ agent_map = { }, "unknown agent": { "role": "llm agent", - "instructions": "You are an LLM agent. You can do anything you want.", + "instructions": "You are a helpful LLM agent.", }, } +def construct_llm_messages(agent_name, messages): + agent_role = agents_definition.get(agent_name)["role"] + agent_instructions = agents_definition.get(agent_name)["instructions"] + system_prompt = "You are a " + agent_role + ". " + agent_instructions + + updated_messages = [{"role": "system", "content": system_prompt}] + for message in messages: + updated_messages.append({"role": message.role, "content": message.content}) + return updated_messages + + @app.post("/v1/chat/completions") async def completion_api(req: ChatCompletionsRequest): logger.info(f"request: {req}") @@ -121,16 +132,10 @@ async def completion_api(req: ChatCompletionsRequest): req.metadata = {} agent_name = req.metadata.get("agent-name", "unknown agent") logger.info(f"agent: {agent_name}") - - agent_role = agent_map.get(agent_name)["role"] - agent_instructions = agent_map.get(agent_name)["instructions"] - system_prompt = "You are a " + agent_role + ". " + agent_instructions - messages = [{"role": "system", "content": system_prompt}] - for message in req.messages: - messages.append({"role": message.role, "content": message.content}) + messages = construct_llm_messages(agent_name, req.messages) logger.info("messages: " + str(messages)) - completion = client.chat.completions.create( - model="--", + completion = openai_client.chat.completions.create( + model="None", messages=messages, stream=req.stream, ) diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 457aced8..8c01493a 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -601,6 +601,7 @@ class ArchFunctionHandler(ArchBaseHandler): if len(extracted["result"]): verified = {} if use_agent_orchestrator: + # skip tool call verification if using agent orchestrator verified = {"status": True, "message": ""} else: verified = self._verify_tool_calls(