mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add tests for config generator
This commit is contained in:
parent
0a724ebfd6
commit
1a220b4634
5 changed files with 290 additions and 136 deletions
|
|
@ -5,19 +5,6 @@ import yaml
|
|||
from jsonschema import validate
|
||||
from urllib.parse import urlparse
|
||||
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
)
|
||||
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
|
||||
ARCH_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ARCH_CONFIG_FILE_RENDERED", "/app/arch_config_rendered.yaml"
|
||||
)
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
|
||||
)
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
||||
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
|
||||
)
|
||||
|
||||
SUPPORTED_PROVIDERS = [
|
||||
"arch",
|
||||
|
|
@ -45,8 +32,22 @@ def get_endpoint_and_port(endpoint, protocol):
|
|||
|
||||
|
||||
def validate_and_render_schema():
|
||||
env = Environment(loader=FileSystemLoader("./"))
|
||||
template = env.get_template("envoy.template.yaml")
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
)
|
||||
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
|
||||
ARCH_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ARCH_CONFIG_FILE_RENDERED", "/app/arch_config_rendered.yaml"
|
||||
)
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
|
||||
)
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
||||
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
|
||||
)
|
||||
|
||||
env = Environment(loader=FileSystemLoader(os.getenv("TEMPLATE_ROOT", "./")))
|
||||
template = env.get_template(ENVOY_CONFIG_TEMPLATE_FILE)
|
||||
|
||||
try:
|
||||
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
|
||||
|
|
@ -248,6 +249,7 @@ def validate_and_render_schema():
|
|||
agent_orchestrator = list(endpoints.keys())[0]
|
||||
|
||||
print("agent_orchestrator: ", agent_orchestrator)
|
||||
|
||||
data = {
|
||||
"prompt_gateway_listener": prompt_gateway_listener,
|
||||
"llm_gateway_listener": llm_gateway_listener,
|
||||
|
|
@ -284,7 +286,7 @@ def validate_prompt_config(arch_config_file, arch_config_schema_file):
|
|||
validate(config_yaml, config_schema_yaml)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {e.message}"
|
||||
f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
|
|
|||
|
|
@ -1,45 +0,0 @@
|
|||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict, Set
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
name: str = Field(
|
||||
"John Doe", description="The name of the user."
|
||||
) # Default value and description for name
|
||||
location: int = None
|
||||
age: int = Field(
|
||||
30, description="The age of the user."
|
||||
) # Default value and description for age
|
||||
tags: Set[str] = Field(
|
||||
default_factory=set, description="A set of tags associated with the user."
|
||||
) # Default empty set and description for tags
|
||||
metadata: Dict[str, int] = Field(
|
||||
default_factory=dict,
|
||||
description="A dictionary storing metadata about the user, with string keys and integer values.",
|
||||
) # Default empty dict and description for metadata
|
||||
|
||||
|
||||
@app.get("/agent/default")
|
||||
async def default(request: User):
|
||||
"""
|
||||
This endpoint handles information extraction queries.
|
||||
It can summarize, extract details, and perform various other information-related tasks.
|
||||
"""
|
||||
return {"info": f"Query: {request.name}, Count: {request.age}"}
|
||||
|
||||
|
||||
@app.post("/agent/action")
|
||||
async def reboot_network_device(device_id: str, confirmation: str):
|
||||
"""
|
||||
This endpoint reboots a network device based on the device ID.
|
||||
Confirmation is required to proceed with the reboot.
|
||||
|
||||
Args:
|
||||
device_id: The device_id that you want to reboot.
|
||||
confirmation: The confirmation that the user wants to reboot.
|
||||
metadata: Ignore this parameter
|
||||
"""
|
||||
return {"status": "Device rebooted", "device_id": device_id}
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
prompt_targets:
|
||||
- name: default
|
||||
path: /agent/default
|
||||
description: "This endpoint handles information extraction queries.\n It can\
|
||||
\ summarize, extract details, and perform various other information-related tasks."
|
||||
parameters:
|
||||
- name: query
|
||||
type: str
|
||||
description: Field from Pydantic model DefaultRequest
|
||||
default_value: null
|
||||
required: false
|
||||
- name: count
|
||||
type: int
|
||||
description: Field from Pydantic model DefaultRequest
|
||||
default_value: null
|
||||
required: false
|
||||
type: default
|
||||
auto-llm-dispatch-on-response: true
|
||||
- name: reboot_network_device
|
||||
path: /agent/action
|
||||
description: "This endpoint reboots a network device based on the device ID.\n \
|
||||
\ Confirmation is required to proceed with the reboot."
|
||||
parameters:
|
||||
- name: device_id
|
||||
type: str
|
||||
description: Description for device_id
|
||||
default_value: ''
|
||||
required: true
|
||||
- name: confirmation
|
||||
type: int
|
||||
description: Description for confirmation
|
||||
default_value: ''
|
||||
required: true
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from tools.cli.main import main # Import your CLI's entry point
|
||||
import importlib.metadata
|
||||
|
||||
|
||||
def get_version():
|
||||
"""Helper function to fetch the version."""
|
||||
try:
|
||||
version = importlib.metadata.version("archgw")
|
||||
return version
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
"""Fixture to create a Click test runner."""
|
||||
return CliRunner()
|
||||
|
||||
|
||||
def test_version_option(runner):
|
||||
"""Test the --version option."""
|
||||
result = runner.invoke(main, ["--version"])
|
||||
assert result.exit_code == 0
|
||||
expected_version = get_version()
|
||||
assert f"archgw cli version: {expected_version}" in result.output
|
||||
|
||||
|
||||
def test_default_behavior(runner):
|
||||
"""Test the default behavior when no command is provided."""
|
||||
result = runner.invoke(main)
|
||||
assert result.exit_code == 0
|
||||
assert "Arch (The Intelligent Prompt Gateway) CLI" in result.output
|
||||
assert "Usage:" in result.output # Ensure help text is shown
|
||||
|
||||
|
||||
def test_invalid_command(runner):
|
||||
"""Test that an invalid command returns an appropriate error message."""
|
||||
result = runner.invoke(main, ["invalid_command"])
|
||||
assert result.exit_code != 0 # Non-zero exit code for invalid command
|
||||
assert "Error: No such command 'invalid_command'" in result.output
|
||||
272
arch/tools/test/test_config_generator.py
Normal file
272
arch/tools/test/test_config_generator.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
import pytest
|
||||
from unittest import mock
|
||||
import sys
|
||||
from cli.config_generator import validate_and_render_schema
|
||||
|
||||
# Patch sys.path to allow import from cli/
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "cli"))
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_env(monkeypatch):
|
||||
# Clean up environment variables and mocks after each test
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
|
||||
|
||||
def test_validate_and_render_happy_path(monkeypatch):
|
||||
monkeypatch.setenv("ARCH_CONFIG_FILE", "fake_arch_config.yaml")
|
||||
monkeypatch.setenv("ARCH_CONFIG_SCHEMA_FILE", "fake_arch_config_schema.yaml")
|
||||
monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", "./envoy.template.yaml")
|
||||
monkeypatch.setenv("ARCH_CONFIG_FILE_RENDERED", "fake_arch_config_rendered.yaml")
|
||||
monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml")
|
||||
monkeypatch.setenv("TEMPLATE_ROOT", "../")
|
||||
|
||||
arch_config = """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: understand and explain existing code snippets, functions, or libraries
|
||||
|
||||
- model: openai/gpt-4.1
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code generation
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
"""
|
||||
arch_config_schema = ""
|
||||
with open("../arch_config_schema.yaml", "r") as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
m_open = mock.mock_open()
|
||||
# Provide enough file handles for all open() calls in validate_and_render_schema
|
||||
m_open.side_effect = [
|
||||
mock.mock_open(read_data="").return_value,
|
||||
mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE
|
||||
mock.mock_open(
|
||||
read_data=arch_config_schema
|
||||
).return_value, # ARCH_CONFIG_SCHEMA_FILE
|
||||
mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE
|
||||
mock.mock_open(
|
||||
read_data=arch_config_schema
|
||||
).return_value, # ARCH_CONFIG_SCHEMA_FILE
|
||||
mock.mock_open().return_value, # ENVOY_CONFIG_FILE_RENDERED (write)
|
||||
mock.mock_open().return_value, # ARCH_CONFIG_FILE_RENDERED (write)
|
||||
]
|
||||
with mock.patch("builtins.open", m_open):
|
||||
with mock.patch("config_generator.Environment"):
|
||||
validate_and_render_schema()
|
||||
|
||||
|
||||
arch_config_test_cases = [
|
||||
{
|
||||
"id": "duplicate_provider_name",
|
||||
"expected_error": "Duplicate llm_provider name",
|
||||
"arch_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- name: test1
|
||||
model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
|
||||
- name: test1
|
||||
model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "provider_interface_with_model_id",
|
||||
"expected_error": "Please provide provider interface as part of model name",
|
||||
"arch_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider_interface: openai
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "duplicate_model_id",
|
||||
"expected_error": "Duplicate model_id",
|
||||
"arch_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
|
||||
- model: mistral/gpt-4o
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "custom_provider_base_url",
|
||||
"expected_error": "Must provide base_url and provider_interface",
|
||||
"arch_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- model: custom/gpt-4o
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "base_url_no_prefix",
|
||||
"expected_error": "Please provide base_url without path",
|
||||
"arch_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- model: custom/gpt-4o
|
||||
base_url: "http://custom.com/test"
|
||||
provider_interface: openai
|
||||
|
||||
""",
|
||||
},
|
||||
{
|
||||
"id": "duplicate_routeing_preference_name",
|
||||
"expected_error": "Duplicate routing preference name",
|
||||
"arch_config": """
|
||||
version: v0.1.0
|
||||
|
||||
listeners:
|
||||
egress_traffic:
|
||||
address: 0.0.0.0
|
||||
port: 12000
|
||||
message_format: openai
|
||||
timeout: 30s
|
||||
|
||||
llm_providers:
|
||||
|
||||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
|
||||
- model: openai/gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: understand and explain existing code snippets, functions, or libraries
|
||||
|
||||
- model: openai/gpt-4.1
|
||||
access_key: $OPENAI_API_KEY
|
||||
routing_preferences:
|
||||
- name: code understanding
|
||||
description: generating new code snippets, functions, or boilerplate based on user prompts or requirements
|
||||
|
||||
tracing:
|
||||
random_sampling: 100
|
||||
|
||||
""",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"arch_config_test_case",
|
||||
arch_config_test_cases,
|
||||
ids=[case["id"] for case in arch_config_test_cases],
|
||||
)
|
||||
def test_validate_and_render_schema_tests(monkeypatch, arch_config_test_case):
|
||||
monkeypatch.setenv("ARCH_CONFIG_FILE", "fake_arch_config.yaml")
|
||||
monkeypatch.setenv("ARCH_CONFIG_SCHEMA_FILE", "fake_arch_config_schema.yaml")
|
||||
monkeypatch.setenv("ENVOY_CONFIG_TEMPLATE_FILE", "./envoy.template.yaml")
|
||||
monkeypatch.setenv("ARCH_CONFIG_FILE_RENDERED", "fake_arch_config_rendered.yaml")
|
||||
monkeypatch.setenv("ENVOY_CONFIG_FILE_RENDERED", "fake_envoy.yaml")
|
||||
monkeypatch.setenv("TEMPLATE_ROOT", "../")
|
||||
|
||||
arch_config = arch_config_test_case["arch_config"]
|
||||
expected_error = arch_config_test_case["expected_error"]
|
||||
test_id = arch_config_test_case["id"]
|
||||
|
||||
arch_config_schema = ""
|
||||
with open("../arch_config_schema.yaml", "r") as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
m_open = mock.mock_open()
|
||||
# Provide enough file handles for all open() calls in validate_and_render_schema
|
||||
m_open.side_effect = [
|
||||
mock.mock_open(read_data="").return_value,
|
||||
mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE
|
||||
mock.mock_open(
|
||||
read_data=arch_config_schema
|
||||
).return_value, # ARCH_CONFIG_SCHEMA_FILE
|
||||
mock.mock_open(read_data=arch_config).return_value, # ARCH_CONFIG_FILE
|
||||
mock.mock_open(
|
||||
read_data=arch_config_schema
|
||||
).return_value, # ARCH_CONFIG_SCHEMA_FILE
|
||||
mock.mock_open().return_value, # ENVOY_CONFIG_FILE_RENDERED (write)
|
||||
mock.mock_open().return_value, # ARCH_CONFIG_FILE_RENDERED (write)
|
||||
]
|
||||
with mock.patch("builtins.open", m_open):
|
||||
with mock.patch("config_generator.Environment"):
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
validate_and_render_schema()
|
||||
assert expected_error in str(excinfo.value)
|
||||
Loading…
Add table
Add a link
Reference in a new issue