diff --git a/arch/tools/cli/config_generator.py b/arch/tools/cli/config_generator.py index 806d5113..90160fb3 100644 --- a/arch/tools/cli/config_generator.py +++ b/arch/tools/cli/config_generator.py @@ -113,15 +113,21 @@ def validate_and_render_schema(): "port": ingress_traffic.get("port", 10000), "address": ingress_traffic.get("address", "0.0.0.0"), "timeout": ingress_traffic.get("timeout", "30s"), + "protocol": "openai", } config_yaml["listeners"].append(prompt_gateway_listener) if egress_traffic: + llm_providers = [] + if config_yaml.get("llm_providers"): + llm_providers = config_yaml["llm_providers"] + del config_yaml["llm_providers"] llm_gateway_listener = { "name": "egress_traffic", "port": egress_traffic.get("port", 12000), "address": egress_traffic.get("address", "0.0.0.0"), "timeout": egress_traffic.get("timeout", "30s"), - "llm_providers": config_yaml.get("llm_providers", []), + "llm_providers": llm_providers, + "protocol": "openai", } config_yaml["listeners"].append(llm_gateway_listener) @@ -237,7 +243,9 @@ def validate_and_render_schema(): } ) - config_yaml["llm_providers"] = updated_llm_providers + for listener in config_yaml["listeners"]: + if listener.get("name") == "egress_traffic": + listener["llm_providers"] = updated_llm_providers arch_config_string = yaml.dump(config_yaml) arch_llm_config_string = yaml.dump(config_yaml) @@ -279,7 +287,7 @@ def validate_and_render_schema(): "arch_config": arch_config_string, "arch_llm_config": arch_llm_config_string, "arch_clusters": inferred_clusters, - "arch_llm_providers": config_yaml["llm_providers"], + "arch_llm_providers": updated_llm_providers, "arch_tracing": arch_tracing, "local_llms": llms_with_endpoint, "agent_orchestrator": agent_orchestrator, diff --git a/demos/use_cases/rag_agent/src/rag_agent/__init__.py b/demos/use_cases/rag_agent/src/rag_agent/__init__.py index 318972e9..b26f3e20 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/__init__.py +++ b/demos/use_cases/rag_agent/src/rag_agent/__init__.py @@ -3,11 +3,12 @@ from mcp.server.fastmcp import FastMCP mcp = None + @click.command() -@click.option('--transport', 'transport', default='stdio') -@click.option('--host', 'host', default='localhost') -@click.option('--port', 'port', default=10101) -@click.option('--agent', 'agent', default=None) +@click.option("--transport", "transport", default="stdio") +@click.option("--host", "host", default="localhost") +@click.option("--port", "port", default=10101) +@click.option("--agent", "agent", default=None) def main(host, port, agent, transport): print(f"Starting agent(s): {agent if agent else 'all'}") global mcp @@ -26,5 +27,6 @@ def main(host, port, agent, transport): print("All agents loaded.") mcp.run(transport=transport) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/demos/use_cases/rag_agent/src/rag_agent/document_store.py b/demos/use_cases/rag_agent/src/rag_agent/document_store.py index 86ab092f..93dfc228 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/document_store.py +++ b/demos/use_cases/rag_agent/src/rag_agent/document_store.py @@ -1,6 +1,7 @@ from pydantic import BaseModel from . import mcp + class QueryRequest(BaseModel): query: str metadata: dict | None = None @@ -10,6 +11,7 @@ class QueryResponse(BaseModel): query: str results: list + @mcp.tool() def query_rag_store(request: QueryRequest): """Query the RAG document store.""" diff --git a/demos/use_cases/rag_agent/src/rag_agent/query_parser.py b/demos/use_cases/rag_agent/src/rag_agent/query_parser.py index 20305f0b..1bc54548 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/query_parser.py +++ b/demos/use_cases/rag_agent/src/rag_agent/query_parser.py @@ -1,13 +1,13 @@ from pydantic import BaseModel from . import mcp + class Response(BaseModel): query: str metadata: dict + @mcp.tool() def parse_query(query): """Parse the user query and returns metadata extracted from query.""" - return Response(query=query, metadata={ - "is_valid": True - }) + return Response(query=query, metadata={"is_valid": True}) diff --git a/demos/use_cases/rag_agent/src/rag_agent/response_generator.py b/demos/use_cases/rag_agent/src/rag_agent/response_generator.py index 7a71d2d5..a612c626 100644 --- a/demos/use_cases/rag_agent/src/rag_agent/response_generator.py +++ b/demos/use_cases/rag_agent/src/rag_agent/response_generator.py @@ -1,6 +1,11 @@ from . import mcp + @mcp.tool() def generate_response(query, context): """Generate a response based on the user query and context.""" - return {"query": query, "context": context, "response": "This is a generated response."} + return { + "query": query, + "context": context, + "response": "This is a generated response.", + } diff --git a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml index 503f6a80..e0b00f62 100644 --- a/docs/source/resources/includes/arch_config_full_reference_rendered.yaml +++ b/docs/source/resources/includes/arch_config_full_reference_rendered.yaml @@ -10,33 +10,33 @@ endpoints: endpoint: 127.0.0.1 port: 8001 listeners: - egress_traffic: - address: 0.0.0.0 - message_format: openai - port: 12000 - timeout: 5s - ingress_traffic: - address: 0.0.0.0 - message_format: openai - port: 10000 - timeout: 5s -llm_providers: -- access_key: $OPENAI_API_KEY - default: true - model: gpt-4o - name: openai/gpt-4o - provider_interface: openai -- access_key: $MISTRAL_API_KEY - model: mistral-8x7b - name: mistral/mistral-8x7b - provider_interface: mistral -- base_url: http://mistral_local - endpoint: mistral_local - model: mistral-7b-instruct - name: mistral/mistral-7b-instruct - port: 80 - protocol: http - provider_interface: mistral +- address: 0.0.0.0 + name: ingress_traffic + port: 10000 + protocol: openai + timeout: 5s +- address: 0.0.0.0 + llm_providers: + - access_key: $OPENAI_API_KEY + default: true + model: gpt-4o + name: openai/gpt-4o + provider_interface: openai + - access_key: $MISTRAL_API_KEY + model: mistral-8x7b + name: mistral/mistral-8x7b + provider_interface: mistral + - base_url: http://mistral_local + endpoint: mistral_local + model: mistral-7b-instruct + name: mistral/mistral-7b-instruct + port: 80 + protocol: http + provider_interface: mistral + name: egress_traffic + port: 12000 + protocol: openai + timeout: 5s overrides: prompt_target_intent_matching_threshold: 0.6 prompt_guards: