mirror of
https://github.com/katanemo/plano.git
synced 2026-06-20 15:28:07 +02:00
restructure cli
- move /arch/tools => /cli - rename /arch -> /config - update all refs in tests to use /config - update planoai code with new references
This commit is contained in:
parent
a56bb9d190
commit
425acecccc
45 changed files with 137 additions and 102 deletions
0
cli/planoai/__init__.py
Normal file
0
cli/planoai/__init__.py
Normal file
425
cli/planoai/config_generator.py
Normal file
425
cli/planoai/config_generator.py
Normal file
|
|
@ -0,0 +1,425 @@
|
|||
import json
|
||||
import os
|
||||
from planoai.utils import convert_legacy_listeners
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
import yaml
|
||||
from jsonschema import validate
|
||||
from urllib.parse import urlparse
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
SUPPORTED_PROVIDERS_WITH_BASE_URL = [
|
||||
"azure_openai",
|
||||
"ollama",
|
||||
"qwen",
|
||||
"amazon_bedrock",
|
||||
"arch",
|
||||
]
|
||||
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL = [
|
||||
"deepseek",
|
||||
"groq",
|
||||
"mistral",
|
||||
"openai",
|
||||
"gemini",
|
||||
"anthropic",
|
||||
"together_ai",
|
||||
"xai",
|
||||
"moonshotai",
|
||||
"zhipu",
|
||||
]
|
||||
|
||||
SUPPORTED_PROVIDERS = (
|
||||
SUPPORTED_PROVIDERS_WITHOUT_BASE_URL + SUPPORTED_PROVIDERS_WITH_BASE_URL
|
||||
)
|
||||
|
||||
|
||||
def get_endpoint_and_port(endpoint, protocol):
|
||||
endpoint_tokens = endpoint.split(":")
|
||||
if len(endpoint_tokens) > 1:
|
||||
endpoint = endpoint_tokens[0]
|
||||
port = int(endpoint_tokens[1])
|
||||
return endpoint, port
|
||||
else:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
return endpoint, port
|
||||
|
||||
|
||||
def validate_and_render_schema():
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
)
|
||||
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/app/arch_config.yaml")
|
||||
ARCH_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ARCH_CONFIG_FILE_RENDERED", "/app/arch_config_rendered.yaml"
|
||||
)
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
|
||||
)
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
||||
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
|
||||
)
|
||||
|
||||
env = Environment(loader=FileSystemLoader(os.getenv("TEMPLATE_ROOT", "./")))
|
||||
template = env.get_template(ENVOY_CONFIG_TEMPLATE_FILE)
|
||||
|
||||
try:
|
||||
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
exit(1) # validate_prompt_config failed. Exit
|
||||
|
||||
with open(ARCH_CONFIG_FILE, "r") as file:
|
||||
arch_config = file.read()
|
||||
|
||||
with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
config_yaml = yaml.safe_load(arch_config)
|
||||
_ = yaml.safe_load(arch_config_schema)
|
||||
inferred_clusters = {}
|
||||
|
||||
# Convert legacy llm_providers to model_providers
|
||||
if "llm_providers" in config_yaml:
|
||||
if "model_providers" in config_yaml:
|
||||
raise Exception(
|
||||
"Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead"
|
||||
)
|
||||
config_yaml["model_providers"] = config_yaml["llm_providers"]
|
||||
del config_yaml["llm_providers"]
|
||||
|
||||
listeners, llm_gateway, prompt_gateway = convert_legacy_listeners(
|
||||
config_yaml.get("listeners"), config_yaml.get("model_providers")
|
||||
)
|
||||
|
||||
config_yaml["listeners"] = listeners
|
||||
|
||||
endpoints = config_yaml.get("endpoints", {})
|
||||
|
||||
# Process agents section and convert to endpoints
|
||||
agents = config_yaml.get("agents", [])
|
||||
filters = config_yaml.get("filters", [])
|
||||
agents_combined = agents + filters
|
||||
agent_id_keys = set()
|
||||
|
||||
for agent in agents_combined:
|
||||
agent_id = agent.get("id")
|
||||
if agent_id in agent_id_keys:
|
||||
raise Exception(
|
||||
f"Duplicate agent id {agent_id}, please provide unique id for each agent"
|
||||
)
|
||||
agent_id_keys.add(agent_id)
|
||||
agent_endpoint = agent.get("url")
|
||||
|
||||
if agent_id and agent_endpoint:
|
||||
urlparse_result = urlparse(agent_endpoint)
|
||||
if urlparse_result.scheme and urlparse_result.hostname:
|
||||
protocol = urlparse_result.scheme
|
||||
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
|
||||
endpoints[agent_id] = {
|
||||
"endpoint": urlparse_result.hostname,
|
||||
"port": port,
|
||||
"protocol": protocol,
|
||||
}
|
||||
|
||||
# override the inferred clusters with the ones defined in the config
|
||||
for name, endpoint_details in endpoints.items():
|
||||
inferred_clusters[name] = endpoint_details
|
||||
# Only call get_endpoint_and_port for manually defined endpoints, not agent-derived ones
|
||||
if "port" not in endpoint_details:
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
protocol = inferred_clusters[name].get("protocol", "http")
|
||||
(
|
||||
inferred_clusters[name]["endpoint"],
|
||||
inferred_clusters[name]["port"],
|
||||
) = get_endpoint_and_port(endpoint, protocol)
|
||||
|
||||
print("defined clusters from arch_config.yaml: ", json.dumps(inferred_clusters))
|
||||
|
||||
if "prompt_targets" in config_yaml:
|
||||
for prompt_target in config_yaml["prompt_targets"]:
|
||||
name = prompt_target.get("endpoint", {}).get("name", None)
|
||||
if not name:
|
||||
continue
|
||||
if name not in inferred_clusters:
|
||||
raise Exception(
|
||||
f"Unknown endpoint {name}, please add it in endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
|
||||
arch_tracing = config_yaml.get("tracing", {})
|
||||
|
||||
llms_with_endpoint = []
|
||||
llms_with_endpoint_cluster_names = set()
|
||||
updated_model_providers = []
|
||||
model_provider_name_set = set()
|
||||
llms_with_usage = []
|
||||
model_name_keys = set()
|
||||
model_usage_name_keys = set()
|
||||
|
||||
print("listeners: ", listeners)
|
||||
|
||||
for listener in listeners:
|
||||
if (
|
||||
listener.get("model_providers") is None
|
||||
or listener.get("model_providers") == []
|
||||
):
|
||||
continue
|
||||
print("Processing listener with model_providers: ", listener)
|
||||
name = listener.get("name", None)
|
||||
|
||||
for model_provider in listener.get("model_providers", []):
|
||||
if model_provider.get("usage", None):
|
||||
llms_with_usage.append(model_provider["name"])
|
||||
if model_provider.get("name") in model_provider_name_set:
|
||||
raise Exception(
|
||||
f"Duplicate model_provider name {model_provider.get('name')}, please provide unique name for each model_provider"
|
||||
)
|
||||
|
||||
model_name = model_provider.get("model")
|
||||
print("Processing model_provider: ", model_provider)
|
||||
if model_name in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model name {model_name}, please provide unique model name for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_name)
|
||||
if model_provider.get("name") is None:
|
||||
model_provider["name"] = model_name
|
||||
|
||||
model_provider_name_set.add(model_provider.get("name"))
|
||||
|
||||
model_name_tokens = model_name.split("/")
|
||||
if len(model_name_tokens) < 2:
|
||||
raise Exception(
|
||||
f"Invalid model name {model_name}. Please provide model name in the format <provider>/<model_id>."
|
||||
)
|
||||
provider = model_name_tokens[0]
|
||||
|
||||
# Validate azure_openai and ollama provider requires base_url
|
||||
if (provider in SUPPORTED_PROVIDERS_WITH_BASE_URL) and model_provider.get(
|
||||
"base_url"
|
||||
) is None:
|
||||
raise Exception(
|
||||
f"Provider '{provider}' requires 'base_url' to be set for model {model_name}"
|
||||
)
|
||||
|
||||
model_id = "/".join(model_name_tokens[1:])
|
||||
if provider not in SUPPORTED_PROVIDERS:
|
||||
if (
|
||||
model_provider.get("base_url", None) is None
|
||||
or model_provider.get("provider_interface", None) is None
|
||||
):
|
||||
raise Exception(
|
||||
f"Must provide base_url and provider_interface for unsupported provider {provider} for model {model_name}. Supported providers are: {', '.join(SUPPORTED_PROVIDERS)}"
|
||||
)
|
||||
provider = model_provider.get("provider_interface", None)
|
||||
elif model_provider.get("provider_interface", None) is not None:
|
||||
raise Exception(
|
||||
f"Please provide provider interface as part of model name {model_name} using the format <provider>/<model_id>. For example, use 'openai/gpt-3.5-turbo' instead of 'gpt-3.5-turbo' "
|
||||
)
|
||||
|
||||
if model_id in model_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate model_id {model_id}, please provide unique model_id for each model_provider"
|
||||
)
|
||||
model_name_keys.add(model_id)
|
||||
|
||||
for routing_preference in model_provider.get("routing_preferences", []):
|
||||
if routing_preference.get("name") in model_usage_name_keys:
|
||||
raise Exception(
|
||||
f"Duplicate routing preference name \"{routing_preference.get('name')}\", please provide unique name for each routing preference"
|
||||
)
|
||||
model_usage_name_keys.add(routing_preference.get("name"))
|
||||
|
||||
model_provider["model"] = model_id
|
||||
model_provider["provider_interface"] = provider
|
||||
model_provider_name_set.add(model_provider.get("name"))
|
||||
if model_provider.get("provider") and model_provider.get(
|
||||
"provider_interface"
|
||||
):
|
||||
raise Exception(
|
||||
"Please provide either provider or provider_interface, not both"
|
||||
)
|
||||
if model_provider.get("provider"):
|
||||
provider = model_provider["provider"]
|
||||
model_provider["provider_interface"] = provider
|
||||
del model_provider["provider"]
|
||||
updated_model_providers.append(model_provider)
|
||||
|
||||
if model_provider.get("base_url", None):
|
||||
base_url = model_provider["base_url"]
|
||||
urlparse_result = urlparse(base_url)
|
||||
base_url_path_prefix = urlparse_result.path
|
||||
if base_url_path_prefix and base_url_path_prefix != "/":
|
||||
# we will now support base_url_path_prefix. This means that the user can provide base_url like http://example.com/path and we will extract /path as base_url_path_prefix
|
||||
model_provider["base_url_path_prefix"] = base_url_path_prefix
|
||||
|
||||
if urlparse_result.scheme == "" or urlparse_result.scheme not in [
|
||||
"http",
|
||||
"https",
|
||||
]:
|
||||
raise Exception(
|
||||
"Please provide a valid URL with scheme (http/https) in base_url"
|
||||
)
|
||||
protocol = urlparse_result.scheme
|
||||
port = urlparse_result.port
|
||||
if port is None:
|
||||
if protocol == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
endpoint = urlparse_result.hostname
|
||||
model_provider["endpoint"] = endpoint
|
||||
model_provider["port"] = port
|
||||
model_provider["protocol"] = protocol
|
||||
cluster_name = (
|
||||
provider + "_" + endpoint
|
||||
) # make name unique by appending endpoint
|
||||
model_provider["cluster_name"] = cluster_name
|
||||
# Only add if cluster_name is not already present to avoid duplicates
|
||||
if cluster_name not in llms_with_endpoint_cluster_names:
|
||||
llms_with_endpoint.append(model_provider)
|
||||
llms_with_endpoint_cluster_names.add(cluster_name)
|
||||
|
||||
if len(model_usage_name_keys) > 0:
|
||||
routing_model_provider = config_yaml.get("routing", {}).get(
|
||||
"model_provider", None
|
||||
)
|
||||
if (
|
||||
routing_model_provider
|
||||
and routing_model_provider not in model_provider_name_set
|
||||
):
|
||||
raise Exception(
|
||||
f"Routing model_provider {routing_model_provider} is not defined in model_providers"
|
||||
)
|
||||
if (
|
||||
routing_model_provider is None
|
||||
and "arch-router" not in model_provider_name_set
|
||||
):
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-router",
|
||||
"provider_interface": "arch",
|
||||
"model": config_yaml.get("routing", {}).get("model", "Arch-Router"),
|
||||
}
|
||||
)
|
||||
|
||||
# Always add arch-function model provider if not already defined
|
||||
if "arch-function" not in model_provider_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "arch-function",
|
||||
"provider_interface": "arch",
|
||||
"model": "Arch-Function",
|
||||
}
|
||||
)
|
||||
|
||||
if "plano-orchestrator" not in model_provider_name_set:
|
||||
updated_model_providers.append(
|
||||
{
|
||||
"name": "plano-orchestrator",
|
||||
"provider_interface": "arch",
|
||||
"model": "Plano-Orchestrator",
|
||||
}
|
||||
)
|
||||
|
||||
config_yaml["model_providers"] = deepcopy(updated_model_providers)
|
||||
|
||||
listeners_with_provider = 0
|
||||
for listener in listeners:
|
||||
print("Processing listener: ", listener)
|
||||
model_providers = listener.get("model_providers", None)
|
||||
if model_providers is not None:
|
||||
listeners_with_provider += 1
|
||||
if listeners_with_provider > 1:
|
||||
raise Exception(
|
||||
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||
)
|
||||
|
||||
# Validate model aliases if present
|
||||
if "model_aliases" in config_yaml:
|
||||
model_aliases = config_yaml["model_aliases"]
|
||||
for alias_name, alias_config in model_aliases.items():
|
||||
target = alias_config.get("target")
|
||||
if target not in model_name_keys:
|
||||
raise Exception(
|
||||
f"Model alias 2 - '{alias_name}' targets '{target}' which is not defined as a model. Available models: {', '.join(sorted(model_name_keys))}"
|
||||
)
|
||||
|
||||
arch_config_string = yaml.dump(config_yaml)
|
||||
arch_llm_config_string = yaml.dump(config_yaml)
|
||||
|
||||
use_agent_orchestrator = config_yaml.get("overrides", {}).get(
|
||||
"use_agent_orchestrator", False
|
||||
)
|
||||
|
||||
agent_orchestrator = None
|
||||
if use_agent_orchestrator:
|
||||
print("Using agent orchestrator")
|
||||
|
||||
if len(endpoints) == 0:
|
||||
raise Exception(
|
||||
"Please provide agent orchestrator in the endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
elif len(endpoints) > 1:
|
||||
raise Exception(
|
||||
"Please provide single agent orchestrator in the endpoints section in your arch_config.yaml file"
|
||||
)
|
||||
else:
|
||||
agent_orchestrator = list(endpoints.keys())[0]
|
||||
|
||||
print("agent_orchestrator: ", agent_orchestrator)
|
||||
|
||||
data = {
|
||||
"prompt_gateway_listener": prompt_gateway,
|
||||
"llm_gateway_listener": llm_gateway,
|
||||
"arch_config": arch_config_string,
|
||||
"arch_llm_config": arch_llm_config_string,
|
||||
"arch_clusters": inferred_clusters,
|
||||
"arch_model_providers": updated_model_providers,
|
||||
"arch_tracing": arch_tracing,
|
||||
"local_llms": llms_with_endpoint,
|
||||
"agent_orchestrator": agent_orchestrator,
|
||||
"listeners": listeners,
|
||||
}
|
||||
|
||||
rendered = template.render(data)
|
||||
print(ENVOY_CONFIG_FILE_RENDERED)
|
||||
print(rendered)
|
||||
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(rendered)
|
||||
|
||||
with open(ARCH_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(arch_config_string)
|
||||
|
||||
|
||||
def validate_prompt_config(arch_config_file, arch_config_schema_file):
|
||||
with open(arch_config_file, "r") as file:
|
||||
arch_config = file.read()
|
||||
|
||||
with open(arch_config_schema_file, "r") as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
config_yaml = yaml.safe_load(arch_config)
|
||||
config_schema_yaml = yaml.safe_load(arch_config_schema)
|
||||
|
||||
try:
|
||||
validate(config_yaml, config_schema_yaml)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error validating arch_config file: {arch_config_file}, schema file: {arch_config_schema_file}, error: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
validate_and_render_schema()
|
||||
5
cli/planoai/consts.py
Normal file
5
cli/planoai/consts.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
import os
|
||||
|
||||
SERVICE_NAME_ARCHGW = "plano"
|
||||
PLANO_DOCKER_NAME = "plano"
|
||||
PLANO_DOCKER_IMAGE = os.getenv("PLANO_DOCKER_IMAGE", "katanemo/plano:0.4.0")
|
||||
234
cli/planoai/core.py
Normal file
234
cli/planoai/core.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
import json
|
||||
import subprocess
|
||||
import os
|
||||
import time
|
||||
import sys
|
||||
|
||||
import yaml
|
||||
from planoai.utils import convert_legacy_listeners, getLogger
|
||||
from planoai.consts import (
|
||||
PLANO_DOCKER_IMAGE,
|
||||
PLANO_DOCKER_NAME,
|
||||
)
|
||||
import subprocess
|
||||
from planoai.docker_cli import (
|
||||
docker_container_status,
|
||||
docker_remove_container,
|
||||
docker_start_plano_detached,
|
||||
docker_stop_container,
|
||||
health_check_endpoint,
|
||||
stream_gateway_logs,
|
||||
)
|
||||
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def _get_gateway_ports(arch_config_file: str) -> list[int]:
|
||||
PROMPT_GATEWAY_DEFAULT_PORT = 10000
|
||||
LLM_GATEWAY_DEFAULT_PORT = 12000
|
||||
|
||||
# parse arch_config_file yaml file and get prompt_gateway_port
|
||||
arch_config_dict = {}
|
||||
with open(arch_config_file) as f:
|
||||
arch_config_dict = yaml.safe_load(f)
|
||||
|
||||
listeners, _, _ = convert_legacy_listeners(
|
||||
arch_config_dict.get("listeners"), arch_config_dict.get("llm_providers")
|
||||
)
|
||||
|
||||
all_ports = [listener.get("port") for listener in listeners]
|
||||
|
||||
# unique ports
|
||||
all_ports = list(set(all_ports))
|
||||
|
||||
return all_ports
|
||||
|
||||
|
||||
def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
|
||||
"""
|
||||
Start Docker Compose in detached mode and stream logs until services are healthy.
|
||||
|
||||
Args:
|
||||
path (str): The path where the prompt_config.yml file is located.
|
||||
log_timeout (int): Time in seconds to show logs before checking for healthy state.
|
||||
"""
|
||||
log.info(
|
||||
f"Starting arch gateway, image name: {PLANO_DOCKER_NAME}, tag: {PLANO_DOCKER_IMAGE}"
|
||||
)
|
||||
|
||||
try:
|
||||
plano_container_status = docker_container_status(PLANO_DOCKER_NAME)
|
||||
if plano_container_status != "not found":
|
||||
log.info("plano found in docker, stopping and removing it")
|
||||
docker_stop_container(PLANO_DOCKER_NAME)
|
||||
docker_remove_container(PLANO_DOCKER_NAME)
|
||||
|
||||
gateway_ports = _get_gateway_ports(arch_config_file)
|
||||
|
||||
return_code, _, plano_stderr = docker_start_plano_detached(
|
||||
arch_config_file,
|
||||
env,
|
||||
gateway_ports,
|
||||
)
|
||||
if return_code != 0:
|
||||
log.info("Failed to start plano gateway: " + str(return_code))
|
||||
log.info("stderr: " + plano_stderr)
|
||||
sys.exit(1)
|
||||
|
||||
start_time = time.time()
|
||||
while True:
|
||||
all_listeners_healthy = True
|
||||
for port in gateway_ports:
|
||||
health_check_status = health_check_endpoint(
|
||||
f"http://localhost:{port}/healthz"
|
||||
)
|
||||
if not health_check_status:
|
||||
all_listeners_healthy = False
|
||||
|
||||
plano_status = docker_container_status(PLANO_DOCKER_NAME)
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - start_time
|
||||
|
||||
if plano_status == "exited":
|
||||
log.info("plano container exited unexpectedly.")
|
||||
stream_gateway_logs(follow=False)
|
||||
sys.exit(1)
|
||||
|
||||
# Check if timeout is reached
|
||||
if elapsed_time > log_timeout:
|
||||
log.info(f"stopping log monitoring after {log_timeout} seconds.")
|
||||
stream_gateway_logs(follow=False)
|
||||
sys.exit(1)
|
||||
|
||||
if all_listeners_healthy:
|
||||
log.info("plano is running and is healthy!")
|
||||
break
|
||||
else:
|
||||
health_check_status_str = (
|
||||
"healthy" if health_check_status else "not healthy"
|
||||
)
|
||||
log.info(
|
||||
f"plano status: {plano_status}, health status: {health_check_status_str}"
|
||||
)
|
||||
time.sleep(1)
|
||||
|
||||
if foreground:
|
||||
stream_gateway_logs(follow=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt received, stopping arch gateway service.")
|
||||
stop_docker_container()
|
||||
|
||||
|
||||
def stop_docker_container(service=PLANO_DOCKER_NAME):
|
||||
"""
|
||||
Shutdown all Docker Compose services by running `docker-compose down`.
|
||||
|
||||
Args:
|
||||
path (str): The path where the docker-compose.yml file is located.
|
||||
"""
|
||||
log.info(f"Shutting down {service} service.")
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "stop", service],
|
||||
)
|
||||
subprocess.run(
|
||||
["docker", "rm", service],
|
||||
)
|
||||
|
||||
log.info(f"Successfully shut down {service} service.")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to shut down services: {str(e)}")
|
||||
|
||||
|
||||
def start_cli_agent(arch_config_file=None, settings_json="{}"):
|
||||
"""Start a CLI client connected to Arch."""
|
||||
|
||||
with open(arch_config_file, "r") as file:
|
||||
arch_config = file.read()
|
||||
arch_config_yaml = yaml.safe_load(arch_config)
|
||||
|
||||
# Get egress listener configuration
|
||||
egress_config = arch_config_yaml.get("listeners", {}).get("egress_traffic", {})
|
||||
host = egress_config.get("host", "127.0.0.1")
|
||||
port = egress_config.get("port", 12000)
|
||||
|
||||
# Parse additional settings from command line
|
||||
try:
|
||||
additional_settings = json.loads(settings_json) if settings_json else {}
|
||||
except json.JSONDecodeError:
|
||||
log.error("Settings must be valid JSON")
|
||||
sys.exit(1)
|
||||
|
||||
# Set up environment variables
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"ANTHROPIC_AUTH_TOKEN": "test", # Use test token for arch
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_BASE_URL": f"http://{host}:{port}",
|
||||
"NO_PROXY": host,
|
||||
"DISABLE_TELEMETRY": "true",
|
||||
"DISABLE_COST_WARNINGS": "true",
|
||||
"API_TIMEOUT_MS": "600000",
|
||||
}
|
||||
)
|
||||
|
||||
# Set ANTHROPIC_SMALL_FAST_MODEL from additional_settings or model alias
|
||||
if "ANTHROPIC_SMALL_FAST_MODEL" in additional_settings:
|
||||
env["ANTHROPIC_SMALL_FAST_MODEL"] = additional_settings[
|
||||
"ANTHROPIC_SMALL_FAST_MODEL"
|
||||
]
|
||||
else:
|
||||
# Check if arch.claude.code.small.fast alias exists in model_aliases
|
||||
model_aliases = arch_config_yaml.get("model_aliases", {})
|
||||
if "arch.claude.code.small.fast" in model_aliases:
|
||||
env["ANTHROPIC_SMALL_FAST_MODEL"] = "arch.claude.code.small.fast"
|
||||
else:
|
||||
log.info(
|
||||
"Tip: Set an alias 'arch.claude.code.small.fast' in your model_aliases config to set a small fast model Claude Code"
|
||||
)
|
||||
log.info("Or provide ANTHROPIC_SMALL_FAST_MODEL in --settings JSON")
|
||||
|
||||
# Non-interactive mode configuration from additional_settings only
|
||||
if additional_settings.get("NON_INTERACTIVE_MODE", False):
|
||||
env.update(
|
||||
{
|
||||
"CI": "true",
|
||||
"FORCE_COLOR": "0",
|
||||
"NODE_NO_READLINE": "1",
|
||||
"TERM": "dumb",
|
||||
}
|
||||
)
|
||||
|
||||
# Build claude command arguments
|
||||
claude_args = []
|
||||
|
||||
# Add settings if provided, excluding those already handled as environment variables
|
||||
if additional_settings:
|
||||
# Filter out settings that are already processed as environment variables
|
||||
claude_settings = {
|
||||
k: v
|
||||
for k, v in additional_settings.items()
|
||||
if k not in ["ANTHROPIC_SMALL_FAST_MODEL", "NON_INTERACTIVE_MODE"]
|
||||
}
|
||||
if claude_settings:
|
||||
claude_args.append(f"--settings={json.dumps(claude_settings)}")
|
||||
|
||||
# Use claude from PATH
|
||||
claude_path = "claude"
|
||||
log.info(f"Connecting Claude Code Agent to Arch at {host}:{port}")
|
||||
|
||||
try:
|
||||
subprocess.run([claude_path] + claude_args, env=env, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.error(f"Error starting claude: {e}")
|
||||
sys.exit(1)
|
||||
except FileNotFoundError:
|
||||
log.error(
|
||||
f"{claude_path} not found. Make sure Claude Code is installed: npm install -g @anthropic-ai/claude-code"
|
||||
)
|
||||
sys.exit(1)
|
||||
136
cli/planoai/docker_cli.py
Normal file
136
cli/planoai/docker_cli.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
import subprocess
|
||||
import json
|
||||
import sys
|
||||
import requests
|
||||
|
||||
from planoai.consts import (
|
||||
PLANO_DOCKER_IMAGE,
|
||||
PLANO_DOCKER_NAME,
|
||||
)
|
||||
from planoai.utils import getLogger
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def docker_container_status(container: str) -> str:
|
||||
result = subprocess.run(
|
||||
["docker", "inspect", "--type=container", container],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
return "not found"
|
||||
|
||||
container_status = json.loads(result.stdout)[0]
|
||||
return container_status.get("State", {}).get("Status", "")
|
||||
|
||||
|
||||
def docker_stop_container(container: str) -> str:
|
||||
result = subprocess.run(
|
||||
["docker", "stop", container], capture_output=True, text=True, check=False
|
||||
)
|
||||
return result.returncode
|
||||
|
||||
|
||||
def docker_remove_container(container: str) -> str:
|
||||
result = subprocess.run(
|
||||
["docker", "rm", container], capture_output=True, text=True, check=False
|
||||
)
|
||||
return result.returncode
|
||||
|
||||
|
||||
def docker_start_plano_detached(
|
||||
arch_config_file: str,
|
||||
env: dict,
|
||||
gateway_ports: list[int],
|
||||
) -> str:
|
||||
env_args = [item for key, value in env.items() for item in ["-e", f"{key}={value}"]]
|
||||
|
||||
port_mappings = [
|
||||
f"{12001}:{12001}",
|
||||
"19901:9901",
|
||||
]
|
||||
|
||||
for port in gateway_ports:
|
||||
port_mappings.append(f"{port}:{port}")
|
||||
|
||||
port_mappings_args = [item for port in port_mappings for item in ("-p", port)]
|
||||
|
||||
volume_mappings = [
|
||||
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
||||
]
|
||||
volume_mappings_args = [
|
||||
item for volume in volume_mappings for item in ("-v", volume)
|
||||
]
|
||||
|
||||
options = [
|
||||
"docker",
|
||||
"run",
|
||||
"-d",
|
||||
"--name",
|
||||
PLANO_DOCKER_NAME,
|
||||
*port_mappings_args,
|
||||
*volume_mappings_args,
|
||||
*env_args,
|
||||
"--add-host",
|
||||
"host.docker.internal:host-gateway",
|
||||
PLANO_DOCKER_IMAGE,
|
||||
]
|
||||
|
||||
result = subprocess.run(options, capture_output=True, text=True, check=False)
|
||||
return result.returncode, result.stdout, result.stderr
|
||||
|
||||
|
||||
def health_check_endpoint(endpoint: str) -> bool:
|
||||
try:
|
||||
response = requests.get(endpoint)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except requests.RequestException as e:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def stream_gateway_logs(follow, service="plano"):
|
||||
"""
|
||||
Stream logs from the plano gateway service.
|
||||
"""
|
||||
log.info("Logs from plano gateway service.")
|
||||
|
||||
options = ["docker", "logs"]
|
||||
if follow:
|
||||
options.append("-f")
|
||||
options.append(service)
|
||||
try:
|
||||
# Run `docker-compose logs` to stream logs from the gateway service
|
||||
subprocess.run(
|
||||
options,
|
||||
check=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
log.info(f"Failed to stream logs: {str(e)}")
|
||||
|
||||
|
||||
def docker_validate_plano_schema(arch_config_file):
|
||||
result = subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"run",
|
||||
"--rm",
|
||||
"-v",
|
||||
f"{arch_config_file}:/app/arch_config.yaml:ro",
|
||||
"--entrypoint",
|
||||
"python",
|
||||
PLANO_DOCKER_IMAGE,
|
||||
"-m",
|
||||
"planoai.config_generator",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
return result.returncode, result.stdout, result.stderr
|
||||
302
cli/planoai/main.py
Normal file
302
cli/planoai/main.py
Normal file
|
|
@ -0,0 +1,302 @@
|
|||
import click
|
||||
import os
|
||||
import sys
|
||||
import subprocess
|
||||
import multiprocessing
|
||||
import importlib.metadata
|
||||
import json
|
||||
from planoai import targets
|
||||
from planoai.docker_cli import (
|
||||
docker_validate_plano_schema,
|
||||
stream_gateway_logs,
|
||||
docker_container_status,
|
||||
)
|
||||
from planoai.utils import (
|
||||
getLogger,
|
||||
get_llm_provider_access_keys,
|
||||
has_ingress_listener,
|
||||
load_env_file_to_dict,
|
||||
stream_access_logs,
|
||||
find_config_file,
|
||||
find_repo_root,
|
||||
)
|
||||
from planoai.core import (
|
||||
start_arch,
|
||||
stop_docker_container,
|
||||
start_cli_agent,
|
||||
)
|
||||
from planoai.consts import (
|
||||
PLANO_DOCKER_IMAGE,
|
||||
PLANO_DOCKER_NAME,
|
||||
SERVICE_NAME_ARCHGW,
|
||||
)
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
# ref https://patorjk.com/software/taag/#p=display&f=Doom&t=Plano&x=none&v=4&h=4&w=80&we=false
|
||||
logo = r"""
|
||||
______ _
|
||||
| ___ \ |
|
||||
| |_/ / | __ _ _ __ ___
|
||||
| __/| |/ _` | '_ \ / _ \
|
||||
| | | | (_| | | | | (_) |
|
||||
\_| |_|\__,_|_| |_|\___/
|
||||
|
||||
"""
|
||||
|
||||
# Command to build plano Docker images
|
||||
ARCHGW_DOCKERFILE = "./Dockerfile"
|
||||
|
||||
|
||||
def get_version():
|
||||
try:
|
||||
version = importlib.metadata.version("planoai")
|
||||
return version
|
||||
except importlib.metadata.PackageNotFoundError:
|
||||
return "version not found"
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.option("--version", is_flag=True, help="Show the plano cli version and exit.")
|
||||
@click.pass_context
|
||||
def main(ctx, version):
|
||||
if version:
|
||||
click.echo(f"plano cli version: {get_version()}")
|
||||
ctx.exit()
|
||||
|
||||
log.info(f"Starting plano cli version: {get_version()}")
|
||||
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
|
||||
click.echo(logo)
|
||||
click.echo(ctx.get_help())
|
||||
|
||||
|
||||
@click.command()
|
||||
def build():
|
||||
"""Build Arch from source. Works from any directory within the repo."""
|
||||
|
||||
# Find the repo root
|
||||
repo_root = find_repo_root()
|
||||
if not repo_root:
|
||||
click.echo(
|
||||
"Error: Could not find repository root. Make sure you're inside the plano repository."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
dockerfile_path = os.path.join(repo_root, "Dockerfile")
|
||||
|
||||
if not os.path.exists(dockerfile_path):
|
||||
click.echo(f"Error: Dockerfile not found at {dockerfile_path}")
|
||||
sys.exit(1)
|
||||
|
||||
click.echo(f"Building plano image from {repo_root}...")
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
dockerfile_path,
|
||||
"-t",
|
||||
f"{PLANO_DOCKER_IMAGE}",
|
||||
repo_root,
|
||||
"--add-host=host.docker.internal:host-gateway",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
click.echo("archgw image built successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"Error building plano image: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("file", required=False) # Optional file argument
|
||||
@click.option(
|
||||
"--path", default=".", help="Path to the directory containing arch_config.yaml"
|
||||
)
|
||||
@click.option(
|
||||
"--foreground",
|
||||
default=False,
|
||||
help="Run Arch in the foreground. Default is False",
|
||||
is_flag=True,
|
||||
)
|
||||
def up(file, path, foreground):
|
||||
"""Starts Arch."""
|
||||
# Use the utility function to find config file
|
||||
arch_config_file = find_config_file(path, file)
|
||||
|
||||
# Check if the file exists
|
||||
if not os.path.exists(arch_config_file):
|
||||
log.info(f"Error: {arch_config_file} does not exist.")
|
||||
return
|
||||
|
||||
log.info(f"Validating {arch_config_file}")
|
||||
(
|
||||
validation_return_code,
|
||||
validation_stdout,
|
||||
validation_stderr,
|
||||
) = docker_validate_plano_schema(arch_config_file)
|
||||
if validation_return_code != 0:
|
||||
log.info(f"Error: Validation failed. Exiting")
|
||||
log.info(f"Validation stdout: {validation_stdout}")
|
||||
log.info(f"Validation stderr: {validation_stderr}")
|
||||
sys.exit(1)
|
||||
|
||||
# Set the ARCH_CONFIG_FILE environment variable
|
||||
env_stage = {
|
||||
"OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
|
||||
}
|
||||
env = os.environ.copy()
|
||||
# Remove PATH variable if present
|
||||
env.pop("PATH", None)
|
||||
# check if access_keys are preesnt in the config file
|
||||
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
|
||||
|
||||
# remove duplicates
|
||||
access_keys = set(access_keys)
|
||||
# remove the $ from the access_keys
|
||||
access_keys = [item[1:] if item.startswith("$") else item for item in access_keys]
|
||||
|
||||
if access_keys:
|
||||
if file:
|
||||
app_env_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(file)), ".env"
|
||||
) # check the .env file in the path
|
||||
else:
|
||||
app_env_file = os.path.abspath(os.path.join(path, ".env"))
|
||||
|
||||
if not os.path.exists(
|
||||
app_env_file
|
||||
): # check to see if the environment variables in the current environment or not
|
||||
for access_key in access_keys:
|
||||
if env.get(access_key) is None:
|
||||
log.info(f"Access Key: {access_key} not found. Exiting Start")
|
||||
sys.exit(1)
|
||||
else:
|
||||
env_stage[access_key] = env.get(access_key)
|
||||
else: # .env file exists, use that to send parameters to Arch
|
||||
env_file_dict = load_env_file_to_dict(app_env_file)
|
||||
for access_key in access_keys:
|
||||
if env_file_dict.get(access_key) is None:
|
||||
log.info(f"Access Key: {access_key} not found. Exiting Start")
|
||||
sys.exit(1)
|
||||
else:
|
||||
env_stage[access_key] = env_file_dict[access_key]
|
||||
|
||||
env.update(env_stage)
|
||||
start_arch(arch_config_file, env, foreground=foreground)
|
||||
|
||||
|
||||
@click.command()
|
||||
def down():
|
||||
"""Stops Arch."""
|
||||
stop_docker_container()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--f",
|
||||
"--file",
|
||||
type=click.Path(exists=True),
|
||||
required=True,
|
||||
help="Path to the Python file",
|
||||
)
|
||||
def generate_prompt_targets(file):
|
||||
"""Generats prompt_targets from python methods.
|
||||
Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
|
||||
If you have a complex pydantic data type, you will have to flatten those manually until we add support for it.
|
||||
"""
|
||||
|
||||
print(f"Processing file: {file}")
|
||||
if not file.endswith(".py"):
|
||||
print("Error: Input file must be a .py file")
|
||||
sys.exit(1)
|
||||
|
||||
targets.generate_prompt_targets(file)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
"--debug",
|
||||
help="For detailed debug logs to trace calls from plano <> api_server, etc",
|
||||
is_flag=True,
|
||||
)
|
||||
@click.option("--follow", help="Follow the logs", is_flag=True)
|
||||
def logs(debug, follow):
|
||||
"""Stream logs from access logs services."""
|
||||
|
||||
archgw_process = None
|
||||
try:
|
||||
if debug:
|
||||
archgw_process = multiprocessing.Process(
|
||||
target=stream_gateway_logs, args=(follow,)
|
||||
)
|
||||
archgw_process.start()
|
||||
|
||||
archgw_access_logs_process = multiprocessing.Process(
|
||||
target=stream_access_logs, args=(follow,)
|
||||
)
|
||||
archgw_access_logs_process.start()
|
||||
archgw_access_logs_process.join()
|
||||
|
||||
if archgw_process:
|
||||
archgw_process.join()
|
||||
except KeyboardInterrupt:
|
||||
log.info("KeyboardInterrupt detected. Exiting.")
|
||||
if archgw_access_logs_process.is_alive():
|
||||
archgw_access_logs_process.terminate()
|
||||
if archgw_process and archgw_process.is_alive():
|
||||
archgw_process.terminate()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("type", type=click.Choice(["claude"]), required=True)
|
||||
@click.argument("file", required=False) # Optional file argument
|
||||
@click.option(
|
||||
"--path", default=".", help="Path to the directory containing arch_config.yaml"
|
||||
)
|
||||
@click.option(
|
||||
"--settings",
|
||||
default="{}",
|
||||
help="Additional settings as JSON string for the CLI agent.",
|
||||
)
|
||||
def cli_agent(type, file, path, settings):
|
||||
"""Start a CLI agent connected to Arch.
|
||||
|
||||
CLI_AGENT: The type of CLI agent to start (currently only 'claude' is supported)
|
||||
"""
|
||||
|
||||
# Check if plano docker container is running
|
||||
archgw_status = docker_container_status(PLANO_DOCKER_NAME)
|
||||
if archgw_status != "running":
|
||||
log.error(f"archgw docker container is not running (status: {archgw_status})")
|
||||
log.error("Please start plano using the 'planoai up' command.")
|
||||
sys.exit(1)
|
||||
|
||||
# Determine arch_config.yaml path
|
||||
arch_config_file = find_config_file(path, file)
|
||||
if not os.path.exists(arch_config_file):
|
||||
log.error(f"Config file not found: {arch_config_file}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
start_cli_agent(arch_config_file, settings)
|
||||
except SystemExit:
|
||||
# Re-raise SystemExit to preserve exit codes
|
||||
raise
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
main.add_command(up)
|
||||
main.add_command(down)
|
||||
main.add_command(build)
|
||||
main.add_command(logs)
|
||||
main.add_command(cli_agent)
|
||||
main.add_command(generate_prompt_targets)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
365
cli/planoai/targets.py
Normal file
365
cli/planoai/targets.py
Normal file
|
|
@ -0,0 +1,365 @@
|
|||
import ast
|
||||
import sys
|
||||
import yaml
|
||||
from typing import Any
|
||||
|
||||
FLASK_ROUTE_DECORATORS = ["route", "get", "post", "put", "delete", "patch"]
|
||||
FASTAPI_ROUTE_DECORATORS = ["get", "post", "put", "delete", "patch"]
|
||||
|
||||
|
||||
def detect_framework(tree: Any) -> str:
|
||||
"""Detect whether the file is using Flask or FastAPI based on imports."""
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ImportFrom):
|
||||
if node.module == "flask":
|
||||
return "flask"
|
||||
elif node.module == "fastapi":
|
||||
return "fastapi"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def get_route_decorators(node: Any, framework: str) -> list:
|
||||
"""Extract route decorators based on the framework."""
|
||||
decorators = []
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Call) and isinstance(
|
||||
decorator.func, ast.Attribute
|
||||
):
|
||||
if framework == "flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS:
|
||||
decorators.append(decorator.func.attr)
|
||||
elif (
|
||||
framework == "fastapi"
|
||||
and decorator.func.attr in FASTAPI_ROUTE_DECORATORS
|
||||
):
|
||||
decorators.append(decorator.func.attr)
|
||||
return decorators
|
||||
|
||||
|
||||
def get_route_path(node: Any, framework: str) -> str:
|
||||
"""Extract route path based on the framework."""
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Call) and decorator.args:
|
||||
return decorator.args[0].s # Assuming it's a string literal
|
||||
|
||||
|
||||
def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
|
||||
"""Check if a given type annotation is a Pydantic model."""
|
||||
# We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel
|
||||
if isinstance(annotation, ast.Name):
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == annotation.id:
|
||||
for base in node.bases:
|
||||
if isinstance(base, ast.Name) and base.id == "BaseModel":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
||||
"""Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values."""
|
||||
fields = []
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.ClassDef) and node.name == model_name:
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.AnnAssign):
|
||||
# Initialize the default field description
|
||||
field_type = "Unknown: Please Fix This!"
|
||||
description = "Field, description not present. Please fix."
|
||||
default_value = None
|
||||
required = True # Assume the field is required initially
|
||||
|
||||
# Check if the field uses Field() with required status and description
|
||||
if (
|
||||
stmt.value
|
||||
and isinstance(stmt.value, ast.Call)
|
||||
and isinstance(stmt.value.func, ast.Name)
|
||||
and stmt.value.func.id == "Field"
|
||||
):
|
||||
# Extract the description argument inside the Field call
|
||||
for keyword in stmt.value.keywords:
|
||||
if keyword.arg == "description" and isinstance(
|
||||
keyword.value, ast.Str
|
||||
):
|
||||
description = keyword.value.s
|
||||
if keyword.arg == "default":
|
||||
default_value = keyword.value
|
||||
# If Ellipsis (...) is used, it means the field is required
|
||||
if (
|
||||
stmt.value.args
|
||||
and isinstance(stmt.value.args[0], ast.Constant)
|
||||
and stmt.value.args[0].value is Ellipsis
|
||||
):
|
||||
required = True
|
||||
else:
|
||||
required = False
|
||||
|
||||
# Handle direct default values (e.g., name: str = "John Doe")
|
||||
elif stmt.value is not None:
|
||||
if isinstance(stmt.value, ast.Constant):
|
||||
# Set the default value from the assignment (e.g., name: str = "John Doe")
|
||||
default_value = stmt.value.value
|
||||
required = (
|
||||
False # Not required since it has a default value
|
||||
)
|
||||
|
||||
# Always extract the field type, even if there's a default value
|
||||
if isinstance(stmt.annotation, ast.Subscript):
|
||||
# Get the base type (list, tuple, set, dict)
|
||||
base_type = (
|
||||
stmt.annotation.value.id
|
||||
if isinstance(stmt.annotation.value, ast.Name)
|
||||
else "Unknown"
|
||||
)
|
||||
|
||||
# Handle only list, tuple, set, dict and ignore the inner types
|
||||
if base_type.lower() in ["list", "tuple", "set", "dict"]:
|
||||
field_type = base_type.lower()
|
||||
|
||||
# Handle the ellipsis '...' for required fields if no Field() call
|
||||
elif (
|
||||
isinstance(stmt.value, ast.Constant)
|
||||
and stmt.value.value is Ellipsis
|
||||
):
|
||||
required = True
|
||||
|
||||
# Handle simple types like str, int, etc.
|
||||
if isinstance(stmt.annotation, ast.Name):
|
||||
field_type = stmt.annotation.id
|
||||
|
||||
field_info = {
|
||||
"name": stmt.target.id,
|
||||
"type": field_type, # Always set the field type
|
||||
"description": description,
|
||||
"default": default_value, # Handle direct default values
|
||||
"required": required,
|
||||
}
|
||||
fields.append(field_info)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
||||
"""Extract the parameters and their types from the function definition."""
|
||||
parameters = []
|
||||
|
||||
# Extract docstring to find descriptions
|
||||
docstring = ast.get_docstring(node)
|
||||
arg_descriptions = extract_arg_descriptions_from_docstring(docstring)
|
||||
|
||||
# Extract default values
|
||||
defaults = [None] * (
|
||||
len(node.args.args) - len(node.args.defaults)
|
||||
) + node.args.defaults # Align defaults with args
|
||||
for arg, default in zip(node.args.args, defaults):
|
||||
if arg.arg != "self": # Skip 'self' or 'cls' in class methods
|
||||
param_info = {
|
||||
"name": arg.arg,
|
||||
"description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]"),
|
||||
}
|
||||
|
||||
# Handle Pydantic model types
|
||||
if hasattr(arg, "annotation") and is_pydantic_model(arg.annotation, tree):
|
||||
# Extract and flatten Pydantic model fields
|
||||
pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree)
|
||||
parameters.extend(
|
||||
pydantic_fields
|
||||
) # Flatten the model fields into the parameters list
|
||||
continue # Skip adding the current param_info for the model since we expand the fields
|
||||
|
||||
# Handle standard Python types (int, float, str, etc.)
|
||||
elif hasattr(arg, "annotation") and isinstance(arg.annotation, ast.Name):
|
||||
if arg.annotation.id in [
|
||||
"int",
|
||||
"float",
|
||||
"bool",
|
||||
"str",
|
||||
"list",
|
||||
"tuple",
|
||||
"set",
|
||||
"dict",
|
||||
]:
|
||||
param_info["type"] = arg.annotation.id
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
|
||||
|
||||
# Handle generic subscript types (e.g., Optional, List[Type], etc.)
|
||||
elif hasattr(arg, "annotation") and isinstance(
|
||||
arg.annotation, ast.Subscript
|
||||
):
|
||||
if isinstance(
|
||||
arg.annotation.value, ast.Name
|
||||
) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]:
|
||||
param_info[
|
||||
"type"
|
||||
] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
|
||||
|
||||
# Default for unknown types
|
||||
else:
|
||||
param_info[
|
||||
"type"
|
||||
] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
|
||||
|
||||
# Handle default values
|
||||
if default is not None:
|
||||
if isinstance(default, ast.Constant) or isinstance(
|
||||
default, ast.NameConstant
|
||||
):
|
||||
param_info[
|
||||
"default"
|
||||
] = default.value # Use the default value directly
|
||||
else:
|
||||
param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type
|
||||
param_info["required"] = False # Optional since it has a default value
|
||||
else:
|
||||
param_info["default"] = None
|
||||
param_info["required"] = True # Required if no default value
|
||||
|
||||
parameters.append(param_info)
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def get_function_docstring(node: Any) -> str:
|
||||
"""Extract the function's docstring description if present."""
|
||||
# Check if the first node is a docstring
|
||||
if isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value, ast.Str):
|
||||
# Get the entire docstring
|
||||
full_docstring = node.body[0].value.s.strip()
|
||||
|
||||
# Split the docstring by double newlines (to separate description from fields like Args)
|
||||
description = full_docstring.split("\n\n")[0].strip()
|
||||
|
||||
return description
|
||||
|
||||
return "No description provided."
|
||||
|
||||
|
||||
def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
|
||||
"""Extract descriptions for function parameters from the 'Args' section of the docstring."""
|
||||
descriptions = {}
|
||||
if not docstring:
|
||||
return descriptions
|
||||
|
||||
in_args_section = False
|
||||
current_param = None
|
||||
for line in docstring.splitlines():
|
||||
line = line.strip()
|
||||
|
||||
# Detect the start of the 'Args' section
|
||||
if line.startswith("Args:"):
|
||||
in_args_section = True
|
||||
continue # Proceed to the next line after 'Args:'
|
||||
|
||||
# End of 'Args' section if no indentation and no colon
|
||||
if in_args_section and not line.startswith(" ") and ":" not in line:
|
||||
break # Stop processing if we reach a new section
|
||||
|
||||
# Process lines in the 'Args' section
|
||||
if in_args_section:
|
||||
if ":" in line:
|
||||
# Extract parameter name and description
|
||||
param_name, description = line.split(":", 1)
|
||||
descriptions[param_name.strip()] = description.strip()
|
||||
current_param = param_name.strip()
|
||||
elif current_param and line.startswith(" "):
|
||||
# Handle multiline descriptions (indented lines)
|
||||
descriptions[current_param] += f" {line.strip()}"
|
||||
|
||||
return descriptions
|
||||
|
||||
|
||||
def generate_prompt_targets(input_file_path: str) -> None:
|
||||
"""Introspect routes and generate YAML for either Flask or FastAPI."""
|
||||
with open(input_file_path, "r") as source:
|
||||
tree = ast.parse(source.read())
|
||||
|
||||
# Detect the framework (Flask or FastAPI)
|
||||
framework = detect_framework(tree)
|
||||
if framework == "unknown":
|
||||
print("Could not detect Flask or FastAPI in the file.")
|
||||
return
|
||||
|
||||
# Extract routes
|
||||
routes = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
||||
route_decorators = get_route_decorators(node, framework)
|
||||
if route_decorators:
|
||||
route_path = get_route_path(node, framework)
|
||||
function_params = get_function_parameters(
|
||||
node, tree
|
||||
) # Get parameters for the route
|
||||
function_docstring = get_function_docstring(node) # Extract docstring
|
||||
routes.append(
|
||||
{
|
||||
"name": node.name,
|
||||
"path": route_path,
|
||||
"methods": route_decorators,
|
||||
"parameters": function_params, # Add parameters to the route
|
||||
"description": function_docstring, # Add the docstring as the description
|
||||
}
|
||||
)
|
||||
|
||||
# Generate YAML structure
|
||||
output_structure = {"prompt_targets": []}
|
||||
|
||||
for route in routes:
|
||||
target = {
|
||||
"name": route["name"],
|
||||
"endpoint": [
|
||||
{
|
||||
"name": "app_server",
|
||||
"path": route["path"],
|
||||
}
|
||||
],
|
||||
"description": route["description"], # Use extracted docstring
|
||||
"parameters": [
|
||||
{
|
||||
"name": param["name"],
|
||||
"type": param["type"],
|
||||
"description": f"{param['description']}",
|
||||
**(
|
||||
{"default": param["default"]}
|
||||
if "default" in param and param["default"] is not None
|
||||
else {}
|
||||
), # Only add default if it's set
|
||||
"required": param["required"],
|
||||
}
|
||||
for param in route["parameters"]
|
||||
],
|
||||
}
|
||||
|
||||
if route["name"] == "default":
|
||||
# Special case for `information_extraction` based on your YAML format
|
||||
target["type"] = "default"
|
||||
target["auto-llm-dispatch-on-response"] = True
|
||||
|
||||
output_structure["prompt_targets"].append(target)
|
||||
|
||||
# Output as YAML
|
||||
print(
|
||||
yaml.dump(output_structure, sort_keys=False, default_flow_style=False, indent=3)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python targets.py <input_file>")
|
||||
sys.exit(1)
|
||||
|
||||
input_file = sys.argv[1]
|
||||
|
||||
# Automatically generate the output file name
|
||||
if input_file.endswith(".py"):
|
||||
output_file = input_file.replace(".py", "_prompt_targets.yml")
|
||||
else:
|
||||
print("Error: Input file must be a .py file")
|
||||
sys.exit(1)
|
||||
|
||||
# Call the function with the input and generated output file names
|
||||
generate_prompt_targets(input_file, output_file)
|
||||
|
||||
# Example usage:
|
||||
# python targets.py api.yaml
|
||||
263
cli/planoai/utils.py
Normal file
263
cli/planoai/utils.py
Normal file
|
|
@ -0,0 +1,263 @@
|
|||
import glob
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import yaml
|
||||
import logging
|
||||
from planoai.consts import PLANO_DOCKER_NAME
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
|
||||
def getLogger(name="cli"):
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
return logger
|
||||
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def find_repo_root(start_path=None):
|
||||
"""Find the repository root by looking for Dockerfile or .git directory."""
|
||||
if start_path is None:
|
||||
start_path = os.getcwd()
|
||||
|
||||
current = os.path.abspath(start_path)
|
||||
|
||||
while current != os.path.dirname(current): # Stop at filesystem root
|
||||
# Check for markers that indicate repo root
|
||||
if (
|
||||
os.path.exists(os.path.join(current, "Dockerfile"))
|
||||
and os.path.exists(os.path.join(current, "crates"))
|
||||
and os.path.exists(os.path.join(current, "config"))
|
||||
):
|
||||
return current
|
||||
|
||||
# Also check for .git as fallback
|
||||
if os.path.exists(os.path.join(current, ".git")):
|
||||
# Verify it's the right repo by checking for expected structure
|
||||
if os.path.exists(os.path.join(current, "crates")):
|
||||
return current
|
||||
|
||||
current = os.path.dirname(current)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def has_ingress_listener(arch_config_file):
|
||||
"""Check if the arch config file has ingress_traffic listener configured."""
|
||||
try:
|
||||
with open(arch_config_file) as f:
|
||||
arch_config_dict = yaml.safe_load(f)
|
||||
|
||||
ingress_traffic = arch_config_dict.get("listeners", {}).get(
|
||||
"ingress_traffic", {}
|
||||
)
|
||||
|
||||
return bool(ingress_traffic)
|
||||
except Exception as e:
|
||||
log.error(f"Error reading config file {arch_config_file}: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def convert_legacy_listeners(
|
||||
listeners: dict | list, model_providers: list | None
|
||||
) -> tuple[list, dict | None, dict | None]:
|
||||
llm_gateway_listener = {
|
||||
"name": "egress_traffic",
|
||||
"type": "model_listener",
|
||||
"port": 12000,
|
||||
"address": "0.0.0.0",
|
||||
"timeout": "30s",
|
||||
"model_providers": model_providers or [],
|
||||
}
|
||||
|
||||
prompt_gateway_listener = {
|
||||
"name": "ingress_traffic",
|
||||
"type": "prompt_listener",
|
||||
"port": 10000,
|
||||
"address": "0.0.0.0",
|
||||
"timeout": "30s",
|
||||
}
|
||||
|
||||
# Handle None case
|
||||
if listeners is None:
|
||||
return [llm_gateway_listener], llm_gateway_listener, prompt_gateway_listener
|
||||
|
||||
if isinstance(listeners, dict):
|
||||
# legacy listeners
|
||||
# check if type is array or object
|
||||
# if its dict its legacy format let's convert it to array
|
||||
updated_listeners = []
|
||||
ingress_traffic = listeners.get("ingress_traffic", {})
|
||||
egress_traffic = listeners.get("egress_traffic", {})
|
||||
|
||||
llm_gateway_listener["port"] = egress_traffic.get(
|
||||
"port", llm_gateway_listener["port"]
|
||||
)
|
||||
llm_gateway_listener["address"] = egress_traffic.get(
|
||||
"address", llm_gateway_listener["address"]
|
||||
)
|
||||
llm_gateway_listener["timeout"] = egress_traffic.get(
|
||||
"timeout", llm_gateway_listener["timeout"]
|
||||
)
|
||||
if model_providers is None or model_providers == []:
|
||||
raise ValueError("model_providers cannot be empty when using legacy format")
|
||||
|
||||
llm_gateway_listener["model_providers"] = model_providers
|
||||
updated_listeners.append(llm_gateway_listener)
|
||||
|
||||
if ingress_traffic and ingress_traffic != {}:
|
||||
prompt_gateway_listener["port"] = ingress_traffic.get(
|
||||
"port", prompt_gateway_listener["port"]
|
||||
)
|
||||
prompt_gateway_listener["address"] = ingress_traffic.get(
|
||||
"address", prompt_gateway_listener["address"]
|
||||
)
|
||||
prompt_gateway_listener["timeout"] = ingress_traffic.get(
|
||||
"timeout", prompt_gateway_listener["timeout"]
|
||||
)
|
||||
updated_listeners.append(prompt_gateway_listener)
|
||||
|
||||
return updated_listeners, llm_gateway_listener, prompt_gateway_listener
|
||||
|
||||
model_provider_set = False
|
||||
for listener in listeners:
|
||||
if listener.get("type") == "model_listener":
|
||||
if model_provider_set:
|
||||
raise ValueError(
|
||||
"Currently only one listener can have model_providers set"
|
||||
)
|
||||
listener["model_providers"] = model_providers or []
|
||||
model_provider_set = True
|
||||
llm_gateway_listener = listener
|
||||
if not model_provider_set:
|
||||
listeners.append(llm_gateway_listener)
|
||||
|
||||
return listeners, llm_gateway_listener, prompt_gateway_listener
|
||||
|
||||
|
||||
def get_llm_provider_access_keys(arch_config_file):
|
||||
with open(arch_config_file, "r") as file:
|
||||
arch_config = file.read()
|
||||
arch_config_yaml = yaml.safe_load(arch_config)
|
||||
|
||||
access_key_list = []
|
||||
|
||||
# Convert legacy llm_providers to model_providers
|
||||
if "llm_providers" in arch_config_yaml:
|
||||
if "model_providers" in arch_config_yaml:
|
||||
raise Exception(
|
||||
"Please provide either llm_providers or model_providers, not both. llm_providers is deprecated, please use model_providers instead"
|
||||
)
|
||||
arch_config_yaml["model_providers"] = arch_config_yaml["llm_providers"]
|
||||
del arch_config_yaml["llm_providers"]
|
||||
|
||||
listeners, _, _ = convert_legacy_listeners(
|
||||
arch_config_yaml.get("listeners"), arch_config_yaml.get("model_providers")
|
||||
)
|
||||
|
||||
for prompt_target in arch_config_yaml.get("prompt_targets", []):
|
||||
for k, v in prompt_target.get("endpoint", {}).get("http_headers", {}).items():
|
||||
if k.lower() == "authorization":
|
||||
print(
|
||||
f"found auth header: {k} for prompt_target: {prompt_target.get('name')}/{prompt_target.get('endpoint').get('name')}"
|
||||
)
|
||||
auth_tokens = v.split(" ")
|
||||
if len(auth_tokens) > 1:
|
||||
access_key_list.append(auth_tokens[1])
|
||||
else:
|
||||
access_key_list.append(v)
|
||||
|
||||
for listener in listeners:
|
||||
for llm_provider in listener.get("model_providers", []):
|
||||
access_key = llm_provider.get("access_key")
|
||||
if access_key is not None:
|
||||
access_key_list.append(access_key)
|
||||
|
||||
# Extract environment variables from state_storage.connection_string
|
||||
state_storage = arch_config_yaml.get("state_storage_v1_responses")
|
||||
if state_storage:
|
||||
connection_string = state_storage.get("connection_string")
|
||||
if connection_string and isinstance(connection_string, str):
|
||||
# Extract all $VAR and ${VAR} patterns from connection string
|
||||
import re
|
||||
|
||||
# Match both $VAR and ${VAR} patterns
|
||||
pattern = r"\$\{?([A-Z_][A-Z0-9_]*)\}?"
|
||||
matches = re.findall(pattern, connection_string)
|
||||
for var in matches:
|
||||
access_key_list.append(f"${var}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid connection string received in state_storage_v1_responses"
|
||||
)
|
||||
|
||||
return access_key_list
|
||||
|
||||
|
||||
def load_env_file_to_dict(file_path):
|
||||
env_dict = {}
|
||||
|
||||
# Open and read the .env file
|
||||
with open(file_path, "r") as file:
|
||||
for line in file:
|
||||
# Strip any leading/trailing whitespaces
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Split the line into key and value at the first '=' sign
|
||||
if "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
# Add key-value pair to the dictionary
|
||||
env_dict[key] = value
|
||||
|
||||
return env_dict
|
||||
|
||||
|
||||
def find_config_file(path=".", file=None):
|
||||
"""Find the appropriate config file path."""
|
||||
if file:
|
||||
# If a file is provided, process that file
|
||||
return os.path.abspath(file)
|
||||
else:
|
||||
# If no file is provided, use the path and look for arch_config.yaml first, then config.yaml for convenience
|
||||
arch_config_file = os.path.abspath(os.path.join(path, "config.yaml"))
|
||||
if not os.path.exists(arch_config_file):
|
||||
arch_config_file = os.path.abspath(os.path.join(path, "arch_config.yaml"))
|
||||
return arch_config_file
|
||||
|
||||
|
||||
def stream_access_logs(follow):
|
||||
"""
|
||||
Get the archgw access logs
|
||||
"""
|
||||
|
||||
follow_arg = "-f" if follow else ""
|
||||
|
||||
stream_command = [
|
||||
"docker",
|
||||
"exec",
|
||||
PLANO_DOCKER_NAME,
|
||||
"sh",
|
||||
"-c",
|
||||
f"tail {follow_arg} /var/log/access_*.log",
|
||||
]
|
||||
|
||||
subprocess.run(
|
||||
stream_command,
|
||||
check=True,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue