lint + formatting with black (#158)

* lint + formatting with black

* add black as pre commit
This commit is contained in:
Co Tran 2024-10-09 11:25:07 -07:00 committed by GitHub
parent 498e7f9724
commit 5c4a6bc8ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 581 additions and 295 deletions

View file

@ -25,3 +25,8 @@ repos:
# --lib is to only test the library, since when integration tests are made,
# they will be in a seperate tests directory
entry: bash -c "cd arch && cargo test -p intelligent-prompt-gateway --lib"
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
language_version: python3

View file

@ -16,18 +16,22 @@ logo = r"""
/_/ \_\|_| \___||_| |_|
"""
@click.group(invoke_without_command=True)
@click.pass_context
def main(ctx):
if ctx.invoked_subcommand is None:
click.echo( """Arch (The Intelligent Prompt Gateway) CLI""")
click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
click.echo(logo)
click.echo(ctx.get_help())
# Command to build archgw and model_server Docker images
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
@click.command()
def build():
"""Build Arch from source. Must be in root of cloned repo."""
@ -35,7 +39,18 @@ def build():
if os.path.exists(ARCHGW_DOCKERFILE):
click.echo("Building archgw image...")
try:
subprocess.run(["docker", "build", "-f", ARCHGW_DOCKERFILE, "-t", "archgw:latest", "."], check=True)
subprocess.run(
[
"docker",
"build",
"-f",
ARCHGW_DOCKERFILE,
"-t",
"archgw:latest",
".",
],
check=True,
)
click.echo("archgw image built successfully.")
except subprocess.CalledProcessError as e:
click.echo(f"Error building archgw image: {e}")
@ -51,7 +66,11 @@ def build():
if os.path.exists(MODEL_SERVER_BUILD_FILE):
click.echo("Installing model server dependencies with Poetry...")
try:
subprocess.run(["poetry", "install", "--no-cache"], cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE), check=True)
subprocess.run(
["poetry", "install", "--no-cache"],
cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE),
check=True,
)
click.echo("Model server dependencies installed successfully.")
except subprocess.CalledProcessError as e:
click.echo(f"Error installing model server dependencies: {e}")
@ -60,9 +79,12 @@ def build():
click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}")
sys.exit(1)
@click.command()
@click.argument('file', required=False) # Optional file argument
@click.option('-path', default='.', help='Path to the directory containing arch_config.yml')
@click.argument("file", required=False) # Optional file argument
@click.option(
"-path", default=".", help="Path to the directory containing arch_config.yml"
)
def up(file, path):
"""Starts Arch."""
if file:
@ -78,10 +100,15 @@ def up(file, path):
return
print(f"Validating {arch_config_file}")
arch_schema_config = pkg_resources.resource_filename(__name__, "config/arch_config_schema.yaml")
arch_schema_config = pkg_resources.resource_filename(
__name__, "config/arch_config_schema.yaml"
)
try:
config_generator.validate_prompt_config(arch_config_file=arch_config_file, arch_config_schema_file=arch_schema_config)
config_generator.validate_prompt_config(
arch_config_file=arch_config_file,
arch_config_schema_file=arch_schema_config,
)
except Exception as e:
print("Exiting archgw up")
sys.exit(1)
@ -91,52 +118,67 @@ def up(file, path):
# Set the ARCH_CONFIG_FILE environment variable
env_stage = {}
env = os.environ.copy()
#check if access_keys are preesnt in the config file
# check if access_keys are preesnt in the config file
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
if access_keys:
if file:
app_env_file = os.path.join(os.path.dirname(os.path.abspath(file)), ".env") #check the .env file in the path
app_env_file = os.path.join(
os.path.dirname(os.path.abspath(file)), ".env"
) # check the .env file in the path
else:
app_env_file = os.path.abspath(os.path.join(path, ".env"))
if not os.path.exists(app_env_file): #check to see if the environment variables in the current environment or not
if not os.path.exists(
app_env_file
): # check to see if the environment variables in the current environment or not
for access_key in access_keys:
if env.get(access_key) is None:
print (f"Access Key: {access_key} not found. Exiting Start")
print(f"Access Key: {access_key} not found. Exiting Start")
sys.exit(1)
else:
env_stage[access_key] = env.get(access_key)
else: #.env file exists, use that to send parameters to Arch
else: # .env file exists, use that to send parameters to Arch
env_file_dict = load_env_file_to_dict(app_env_file)
for access_key in access_keys:
if env_file_dict.get(access_key) is None:
print (f"Access Key: {access_key} not found. Exiting Start")
print(f"Access Key: {access_key} not found. Exiting Start")
sys.exit(1)
else:
env_stage[access_key] = env_file_dict[access_key]
with open(pkg_resources.resource_filename(__name__, "config/stage.env"), 'w') as file:
with open(
pkg_resources.resource_filename(__name__, "config/stage.env"), "w"
) as file:
for key, value in env_stage.items():
file.write(f"{key}={value}\n")
env.update(env_stage)
env['ARCH_CONFIG_FILE'] = arch_config_file
env["ARCH_CONFIG_FILE"] = arch_config_file
start_arch_modelserver()
start_arch(arch_config_file, env)
@click.command()
def down():
"""Stops Arch."""
stop_arch_modelserver()
stop_arch()
@click.command()
@click.option('-f', '--file', type=click.Path(exists=True), required=True, help="Path to the Python file")
@click.option(
"-f",
"--file",
type=click.Path(exists=True),
required=True,
help="Path to the Python file",
)
def generate_prompt_targets(file):
"""Generats prompt_targets from python methods.
Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
If you have a complex pydantic data type, you will have to flatten those manually until we add support for it."""
Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
If you have a complex pydantic data type, you will have to flatten those manually until we add support for it.
"""
print(f"Processing file: {file}")
if not file.endswith(".py"):
@ -145,10 +187,11 @@ def generate_prompt_targets(file):
targets.generate_prompt_targets(file)
main.add_command(up)
main.add_command(down)
main.add_command(build)
main.add_command(generate_prompt_targets)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View file

@ -3,36 +3,44 @@ from jinja2 import Environment, FileSystemLoader
import yaml
from jsonschema import validate
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
)
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/config/arch_config.yaml")
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
)
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
)
def add_secret_key_to_llm_providers(config_yaml) :
def add_secret_key_to_llm_providers(config_yaml):
llm_providers = []
for llm_provider in config_yaml.get("llm_providers", []):
access_key_env_var = llm_provider.get('access_key', False)
access_key_env_var = llm_provider.get("access_key", False)
access_key_value = os.getenv(access_key_env_var, False)
if access_key_env_var and access_key_value:
llm_provider['access_key'] = access_key_value
llm_provider["access_key"] = access_key_value
llm_providers.append(llm_provider)
config_yaml["llm_providers"] = llm_providers
return config_yaml
def validate_and_render_schema():
env = Environment(loader=FileSystemLoader('./'))
template = env.get_template('envoy.template.yaml')
env = Environment(loader=FileSystemLoader("./"))
template = env.get_template("envoy.template.yaml")
try:
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
except Exception as e:
print(e)
exit(1) # validate_prompt_config failed. Exit
exit(1) # validate_prompt_config failed. Exit
with open(ARCH_CONFIG_FILE, 'r') as file:
with open(ARCH_CONFIG_FILE, "r") as file:
arch_config = file.read()
with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file:
arch_config_schema = file.read()
config_yaml = yaml.safe_load(arch_config)
@ -44,7 +52,7 @@ def validate_and_render_schema():
if name not in inferred_clusters:
inferred_clusters[name] = {
"name": name,
"port": 80, # default port
"port": 80, # default port
}
print(inferred_clusters)
@ -55,14 +63,13 @@ def validate_and_render_schema():
if name in inferred_clusters:
print("updating cluster", endpoint_details)
inferred_clusters[name].update(endpoint_details)
endpoint = inferred_clusters[name]['endpoint']
if len(endpoint.split(':')) > 1:
inferred_clusters[name]['endpoint'] = endpoint.split(':')[0]
inferred_clusters[name]['port'] = int(endpoint.split(':')[1])
endpoint = inferred_clusters[name]["endpoint"]
if len(endpoint.split(":")) > 1:
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
else:
inferred_clusters[name] = endpoint_details
print("updated clusters", inferred_clusters)
config_yaml = add_secret_key_to_llm_providers(config_yaml)
@ -71,23 +78,24 @@ def validate_and_render_schema():
arch_config_string = yaml.dump(config_yaml)
data = {
'arch_config': arch_config_string,
'arch_clusters': inferred_clusters,
'arch_llm_providers': arch_llm_providers,
'arch_tracing': arch_tracing
"arch_config": arch_config_string,
"arch_clusters": inferred_clusters,
"arch_llm_providers": arch_llm_providers,
"arch_tracing": arch_tracing,
}
rendered = template.render(data)
print(rendered)
print(ENVOY_CONFIG_FILE_RENDERED)
with open(ENVOY_CONFIG_FILE_RENDERED, 'w') as file:
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
file.write(rendered)
def validate_prompt_config(arch_config_file, arch_config_schema_file):
with open(arch_config_file, 'r') as file:
with open(arch_config_file, "r") as file:
arch_config = file.read()
with open(arch_config_schema_file, 'r') as file:
with open(arch_config_schema_file, "r") as file:
arch_config_schema = file.read()
config_yaml = yaml.safe_load(arch_config)
@ -96,8 +104,11 @@ def validate_prompt_config(arch_config_file, arch_config_schema_file):
try:
validate(config_yaml, config_schema_yaml)
except Exception as e:
print(f"Error validating arch_config file: {arch_config_file}, error: {e.message}")
print(
f"Error validating arch_config file: {arch_config_file}, error: {e.message}"
)
raise e
if __name__ == '__main__':
if __name__ == "__main__":
validate_and_render_schema()

View file

@ -5,6 +5,7 @@ import pkg_resources
import select
from utils import run_docker_compose_ps, print_service_status, check_services_state
def start_arch(arch_config_file, env, log_timeout=120):
"""
Start Docker Compose in detached mode and stream logs until services are healthy.
@ -14,22 +15,35 @@ def start_arch(arch_config_file, env, log_timeout=120):
log_timeout (int): Time in seconds to show logs before checking for healthy state.
"""
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
compose_file = pkg_resources.resource_filename(
__name__, "config/docker-compose.yaml"
)
try:
# Run the Docker Compose command in detached mode (-d)
subprocess.run(
["docker", "compose", "-p", "arch", "up", "-d",],
cwd=os.path.dirname(compose_file), # Ensure the Docker command runs in the correct path
env=env, # Pass the modified environment
check=True # Raise an exception if the command fails
[
"docker",
"compose",
"-p",
"arch",
"up",
"-d",
],
cwd=os.path.dirname(
compose_file
), # Ensure the Docker command runs in the correct path
env=env, # Pass the modified environment
check=True, # Raise an exception if the command fails
)
print(f"Arch docker-compose started in detached.")
print("Monitoring `docker-compose ps` logs...")
start_time = time.time()
services_status = {}
services_running = False #assume that the services are not running at the moment
services_running = (
False # assume that the services are not running at the moment
)
while True:
current_time = time.time()
@ -40,16 +54,22 @@ def start_arch(arch_config_file, env, log_timeout=120):
print(f"Stopping log monitoring after {log_timeout} seconds.")
break
current_services_status = run_docker_compose_ps(compose_file=compose_file, env=env)
current_services_status = run_docker_compose_ps(
compose_file=compose_file, env=env
)
if not current_services_status:
print("Status for the services could not be detected. Something went wrong. Please run docker logs")
print(
"Status for the services could not be detected. Something went wrong. Please run docker logs"
)
break
if not services_status:
services_status = current_services_status #set the first time
print_service_status(services_status) #print the services status and proceed.
services_status = current_services_status # set the first time
print_service_status(
services_status
) # print the services status and proceed.
#check if anyone service is failed or exited state, if so print and break out
# check if anyone service is failed or exited state, if so print and break out
unhealthy_states = ["unhealthy", "exit", "exited", "dead", "bad"]
running_states = ["running", "up"]
@ -58,14 +78,23 @@ def start_arch(arch_config_file, env, log_timeout=120):
break
if check_services_state(current_services_status, unhealthy_states):
print("One or more Arch services are unhealthy. Please run `docker logs` for more information")
print_service_status(current_services_status) #print the services status and proceed.
print(
"One or more Arch services are unhealthy. Please run `docker logs` for more information"
)
print_service_status(
current_services_status
) # print the services status and proceed.
break
#check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error
# check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error
for service_name in services_status.keys():
if services_status[service_name]['State'] != current_services_status[service_name]['State']:
print("One or more Arch services have changed state. Printing current state")
if (
services_status[service_name]["State"]
!= current_services_status[service_name]["State"]
):
print(
"One or more Arch services have changed state. Printing current state"
)
print_service_status(current_services_status)
break
@ -82,7 +111,9 @@ def stop_arch():
Args:
path (str): The path where the docker-compose.yml file is located.
"""
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
compose_file = pkg_resources.resource_filename(
__name__, "config/docker-compose.yaml"
)
try:
# Run `docker-compose down` to shut down all services
@ -96,6 +127,7 @@ def stop_arch():
except subprocess.CalledProcessError as e:
print(f"Failed to shut down services: {str(e)}")
def start_arch_modelserver():
"""
Start the model server. This assumes that the archgw_modelserver package is installed locally
@ -103,15 +135,14 @@ def start_arch_modelserver():
"""
try:
subprocess.run(
['archgw_modelserver', 'restart'],
check=True,
start_new_session=True
["archgw_modelserver", "restart"], check=True, start_new_session=True
)
print("Successfull run the archgw model_server")
except subprocess.CalledProcessError as e:
print (f"Failed to start model_server. Please check archgw_modelserver logs")
print(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)
def stop_arch_modelserver():
"""
Stop the model server. This assumes that the archgw_modelserver package is installed locally
@ -119,10 +150,10 @@ def stop_arch_modelserver():
"""
try:
subprocess.run(
['archgw_modelserver', 'stop'],
["archgw_modelserver", "stop"],
check=True,
)
print("Successfull stopped the archgw model_server")
except subprocess.CalledProcessError as e:
print (f"Failed to start model_server. Please check archgw_modelserver logs")
print(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1)

View file

@ -6,17 +6,29 @@ setup(
description="Python-based CLI tool to manage Arch and generate targets.",
author="Katanemo Labs, Inc.",
packages=find_packages(),
py_modules = ['cli', 'core', 'targets', 'utils', 'config_generator'],
py_modules=["cli", "core", "targets", "utils", "config_generator"],
include_package_data=True,
# Specify to include the docker-compose.yml file
package_data={
'': ['config/docker-compose.yaml', 'config/arch_config_schema.yaml', 'config/stage.env'] #Specify to include the docker-compose.yml file
"": [
"config/docker-compose.yaml",
"config/arch_config_schema.yaml",
"config/stage.env",
] # Specify to include the docker-compose.yml file
},
# Add dependencies here, e.g., 'PyYAML' for YAML processing
install_requires=['pyyaml', 'pydantic', 'click', 'jinja2','pyyaml','jsonschema', 'setuptools'],
install_requires=[
"pyyaml",
"pydantic",
"click",
"jinja2",
"pyyaml",
"jsonschema",
"setuptools",
],
entry_points={
'console_scripts': [
'archgw=cli:main',
"console_scripts": [
"archgw=cli:main",
],
},
)

View file

@ -18,14 +18,20 @@ def detect_framework(tree: Any) -> str:
return "fastapi"
return "unknown"
def get_route_decorators(node: Any, framework: str) -> list:
"""Extract route decorators based on the framework."""
decorators = []
for decorator in node.decorator_list:
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute):
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Attribute
):
if framework == "flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS:
decorators.append(decorator.func.attr)
elif framework == "fastapi" and decorator.func.attr in FASTAPI_ROUTE_DECORATORS:
elif (
framework == "fastapi"
and decorator.func.attr in FASTAPI_ROUTE_DECORATORS
):
decorators.append(decorator.func.attr)
return decorators
@ -36,6 +42,7 @@ def get_route_path(node: Any, framework: str) -> str:
if isinstance(decorator, ast.Call) and decorator.args:
return decorator.args[0].s # Assuming it's a string literal
def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
"""Check if a given type annotation is a Pydantic model."""
# We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel
@ -47,6 +54,7 @@ def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
return True
return False
def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
"""Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values."""
fields = []
@ -62,15 +70,26 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
required = True # Assume the field is required initially
# Check if the field uses Field() with required status and description
if stmt.value and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == 'Field':
if (
stmt.value
and isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id == "Field"
):
# Extract the description argument inside the Field call
for keyword in stmt.value.keywords:
if keyword.arg == 'description' and isinstance(keyword.value, ast.Str):
if keyword.arg == "description" and isinstance(
keyword.value, ast.Str
):
description = keyword.value.s
if keyword.arg == 'default':
if keyword.arg == "default":
default_value = keyword.value
# If Ellipsis (...) is used, it means the field is required
if stmt.value.args and isinstance(stmt.value.args[0], ast.Constant) and stmt.value.args[0].value is Ellipsis:
if (
stmt.value.args
and isinstance(stmt.value.args[0], ast.Constant)
and stmt.value.args[0].value is Ellipsis
):
required = True
else:
required = False
@ -80,19 +99,28 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
if isinstance(stmt.value, ast.Constant):
# Set the default value from the assignment (e.g., name: str = "John Doe")
default_value = stmt.value.value
required = False # Not required since it has a default value
required = (
False # Not required since it has a default value
)
# Always extract the field type, even if there's a default value
if isinstance(stmt.annotation, ast.Subscript):
# Get the base type (list, tuple, set, dict)
base_type = stmt.annotation.value.id if isinstance(stmt.annotation.value, ast.Name) else "Unknown"
base_type = (
stmt.annotation.value.id
if isinstance(stmt.annotation.value, ast.Name)
else "Unknown"
)
# Handle only list, tuple, set, dict and ignore the inner types
if base_type.lower() in ['list', 'tuple', 'set', 'dict']:
if base_type.lower() in ["list", "tuple", "set", "dict"]:
field_type = base_type.lower()
# Handle the ellipsis '...' for required fields if no Field() call
elif isinstance(stmt.value, ast.Constant) and stmt.value.value is Ellipsis:
elif (
isinstance(stmt.value, ast.Constant)
and stmt.value.value is Ellipsis
):
required = True
# Handle simple types like str, int, etc.
@ -100,16 +128,17 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
field_type = stmt.annotation.id
field_info = {
"name": stmt.target.id,
"type": field_type, # Always set the field type
"description": description,
"default": default_value, # Handle direct default values
"required": required
"name": stmt.target.id,
"type": field_type, # Always set the field type
"description": description,
"default": default_value, # Handle direct default values
"required": required,
}
fields.append(field_info)
return fields
def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
"""Extract the parameters and their types from the function definition."""
parameters = []
@ -119,40 +148,68 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
arg_descriptions = extract_arg_descriptions_from_docstring(docstring)
# Extract default values
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + node.args.defaults # Align defaults with args
defaults = [None] * (
len(node.args.args) - len(node.args.defaults)
) + node.args.defaults # Align defaults with args
for arg, default in zip(node.args.args, defaults):
if arg.arg != "self": # Skip 'self' or 'cls' in class methods
param_info = {"name": arg.arg, "description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]")}
param_info = {
"name": arg.arg,
"description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]"),
}
# Handle Pydantic model types
if hasattr(arg, 'annotation') and is_pydantic_model(arg.annotation, tree):
if hasattr(arg, "annotation") and is_pydantic_model(arg.annotation, tree):
# Extract and flatten Pydantic model fields
pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree)
parameters.extend(pydantic_fields) # Flatten the model fields into the parameters list
parameters.extend(
pydantic_fields
) # Flatten the model fields into the parameters list
continue # Skip adding the current param_info for the model since we expand the fields
# Handle standard Python types (int, float, str, etc.)
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Name):
if arg.annotation.id in ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
elif hasattr(arg, "annotation") and isinstance(arg.annotation, ast.Name):
if arg.annotation.id in [
"int",
"float",
"bool",
"str",
"list",
"tuple",
"set",
"dict",
]:
param_info["type"] = arg.annotation.id
else:
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
# Handle generic subscript types (e.g., Optional, List[Type], etc.)
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Subscript):
if isinstance(arg.annotation.value, ast.Name) and arg.annotation.value.id in ['list', 'tuple', 'set', 'dict']:
param_info["type"] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
elif hasattr(arg, "annotation") and isinstance(
arg.annotation, ast.Subscript
):
if isinstance(
arg.annotation.value, ast.Name
) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]:
param_info[
"type"
] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
else:
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
# Default for unknown types
else:
param_info["type"] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
param_info[
"type"
] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
# Handle default values
if default is not None:
if isinstance(default, ast.Constant) or isinstance(default, ast.NameConstant):
param_info["default"] = default.value # Use the default value directly
if isinstance(default, ast.Constant) or isinstance(
default, ast.NameConstant
):
param_info[
"default"
] = default.value # Use the default value directly
else:
param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type
param_info["required"] = False # Optional since it has a default value
@ -164,6 +221,7 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
return parameters
def get_function_docstring(node: Any) -> str:
"""Extract the function's docstring description if present."""
# Check if the first node is a docstring
@ -178,6 +236,7 @@ def get_function_docstring(node: Any) -> str:
return "No description provided."
def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
"""Extract descriptions for function parameters from the 'Args' section of the docstring."""
descriptions = {}
@ -195,14 +254,14 @@ def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
continue # Proceed to the next line after 'Args:'
# End of 'Args' section if no indentation and no colon
if in_args_section and not line.startswith(" ") and ':' not in line:
if in_args_section and not line.startswith(" ") and ":" not in line:
break # Stop processing if we reach a new section
# Process lines in the 'Args' section
if in_args_section:
if ':' in line:
if ":" in line:
# Extract parameter name and description
param_name, description = line.split(':', 1)
param_name, description = line.split(":", 1)
descriptions[param_name.strip()] = description.strip()
current_param = param_name.strip()
elif current_param and line.startswith(" "):
@ -230,43 +289,50 @@ def generate_prompt_targets(input_file_path: str) -> None:
route_decorators = get_route_decorators(node, framework)
if route_decorators:
route_path = get_route_path(node, framework)
function_params = get_function_parameters(node, tree) # Get parameters for the route
function_params = get_function_parameters(
node, tree
) # Get parameters for the route
function_docstring = get_function_docstring(node) # Extract docstring
routes.append({
'name': node.name,
'path': route_path,
'methods': route_decorators,
'parameters': function_params, # Add parameters to the route
'description': function_docstring # Add the docstring as the description
})
routes.append(
{
"name": node.name,
"path": route_path,
"methods": route_decorators,
"parameters": function_params, # Add parameters to the route
"description": function_docstring, # Add the docstring as the description
}
)
# Generate YAML structure
output_structure = {
"prompt_targets": []
}
output_structure = {"prompt_targets": []}
for route in routes:
target = {
"name": route['name'],
"name": route["name"],
"endpoint": [
{
"name": "app_server",
"path": route['path'],
"path": route["path"],
}
],
"description": route['description'], # Use extracted docstring
"description": route["description"], # Use extracted docstring
"parameters": [
{
"name": param['name'],
"type": param['type'],
"name": param["name"],
"type": param["type"],
"description": f"{param['description']}",
**({"default": param['default']} if "default" in param and param['default'] is not None else {}), # Only add default if it's set
"required": param['required']
} for param in route['parameters']
]
**(
{"default": param["default"]}
if "default" in param and param["default"] is not None
else {}
), # Only add default if it's set
"required": param["required"],
}
for param in route["parameters"]
],
}
if route['name'] == "default":
if route["name"] == "default":
# Special case for `information_extraction` based on your YAML format
target["type"] = "default"
target["auto-llm-dispatch-on-response"] = True
@ -274,9 +340,12 @@ def generate_prompt_targets(input_file_path: str) -> None:
output_structure["prompt_targets"].append(target)
# Output as YAML
print(yaml.dump(output_structure, sort_keys=False,default_flow_style=False, indent=3))
print(
yaml.dump(output_structure, sort_keys=False, default_flow_style=False, indent=3)
)
if __name__ == '__main__':
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python targets.py <input_file>")
sys.exit(1)

View file

@ -4,12 +4,23 @@ from typing import List, Dict, Set
app = FastAPI()
class User(BaseModel):
name: str = Field("John Doe", description="The name of the user.") # Default value and description for name
name: str = Field(
"John Doe", description="The name of the user."
) # Default value and description for name
location: int = None
age: int = Field(30, description="The age of the user.") # Default value and description for age
tags: Set[str] = Field(default_factory=set, description="A set of tags associated with the user.") # Default empty set and description for tags
metadata: Dict[str, int] = Field(default_factory=dict, description="A dictionary storing metadata about the user, with string keys and integer values.") # Default empty dict and description for metadata
age: int = Field(
30, description="The age of the user."
) # Default value and description for age
tags: Set[str] = Field(
default_factory=set, description="A set of tags associated with the user."
) # Default empty set and description for tags
metadata: Dict[str, int] = Field(
default_factory=dict,
description="A dictionary storing metadata about the user, with string keys and integer values.",
) # Default empty dict and description for metadata
@app.get("/agent/default")
async def default(request: User):
@ -19,6 +30,7 @@ async def default(request: User):
"""
return {"info": f"Query: {request.name}, Count: {request.age}"}
@app.post("/agent/action")
async def reboot_network_device(device_id: str, confirmation: str):
"""

View file

@ -6,6 +6,7 @@ import shlex
import yaml
import json
def run_docker_compose_ps(compose_file, env):
"""
Check if all Docker Compose services are in a healthy state.
@ -16,20 +17,31 @@ def run_docker_compose_ps(compose_file, env):
try:
# Run `docker-compose ps` to get the health status of each service
ps_process = subprocess.Popen(
["docker", "compose", "-p", "arch", "ps", "--format", "table{{.Service}}\t{{.State}}\t{{.Ports}}"],
[
"docker",
"compose",
"-p",
"arch",
"ps",
"--format",
"table{{.Service}}\t{{.State}}\t{{.Ports}}",
],
cwd=os.path.dirname(compose_file),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
start_new_session=True,
env=env
env=env,
)
# Capture the output of `docker-compose ps`
services_status, error_output = ps_process.communicate()
# Check if there is any error output
if error_output:
print(f"Error while checking service status:\n{error_output}", file=os.sys.stderr)
print(
f"Error while checking service status:\n{error_output}",
file=os.sys.stderr,
)
return {}
services = parse_docker_compose_ps_output(services_status)
@ -39,26 +51,31 @@ def run_docker_compose_ps(compose_file, env):
print(f"Failed to check service status. Error:\n{e.stderr}")
return e
#Helper method to print service status
# Helper method to print service status
def print_service_status(services):
print(f"{'Service Name':<25} {'State':<20} {'Ports'}")
print("="*72)
print("=" * 72)
for service_name, info in services.items():
status = info['STATE']
ports = info['PORTS']
status = info["STATE"]
ports = info["PORTS"]
print(f"{service_name:<25} {status:<20} {ports}")
#check for states based on the states passed in
# check for states based on the states passed in
def check_services_state(services, states):
for service_name, service_info in services.items():
status = service_info['STATE'].lower() # Convert status to lowercase for easier comparison
status = service_info[
"STATE"
].lower() # Convert status to lowercase for easier comparison
if any(state in status for state in states):
return True
return False
def get_llm_provider_access_keys(arch_config_file):
with open(arch_config_file, 'r') as file:
with open(arch_config_file, "r") as file:
arch_config = file.read()
arch_config_yaml = yaml.safe_load(arch_config)
@ -70,22 +87,23 @@ def get_llm_provider_access_keys(arch_config_file):
return access_key_list
def load_env_file_to_dict(file_path):
env_dict = {}
# Open and read the .env file
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
for line in file:
# Strip any leading/trailing whitespaces
line = line.strip()
# Skip empty lines and comments
if not line or line.startswith('#'):
if not line or line.startswith("#"):
continue
# Split the line into key and value at the first '=' sign
if '=' in line:
key, value = line.split('=', 1)
if "=" in line:
key, value = line.split("=", 1)
key = key.strip()
value = value.strip()
@ -94,6 +112,7 @@ def load_env_file_to_dict(file_path):
return env_dict
def parse_docker_compose_ps_output(output):
# Split the output into lines
lines = output.strip().splitlines()
@ -111,10 +130,7 @@ def parse_docker_compose_ps_output(output):
parts = line.split()
# Create a dictionary entry using the header names
service_info = {
headers[1]: parts[1], # State
headers[2]: parts[2] # Ports
}
service_info = {headers[1]: parts[1], headers[2]: parts[2]} # State # Ports
# Add to the result dictionary using the service name as the key
services[parts[0]] = service_info

View file

@ -7,47 +7,53 @@ from dotenv import load_dotenv
load_dotenv()
OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
ARCH_STATE_HEADER = 'x-arch-state'
ARCH_STATE_HEADER = "x-arch-state"
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))
client = OpenAI(
api_key=OPENAI_API_KEY,
base_url=CHAT_COMPLETION_ENDPOINT,
http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
)
def predict(message, state):
if 'history' not in state:
state['history'] = []
if "history" not in state:
state["history"] = []
history = state.get("history")
history.append({"role": "user", "content": message})
log.info("history: ", history)
# Custom headers
custom_headers = {
'x-arch-openai-api-key': f"{OPENAI_API_KEY}",
'x-arch-mistral-api-key': f"{MISTRAL_API_KEY}",
'x-arch-deterministic-provider': 'openai',
"x-arch-openai-api-key": f"{OPENAI_API_KEY}",
"x-arch-mistral-api-key": f"{MISTRAL_API_KEY}",
"x-arch-deterministic-provider": "openai",
}
metadata = None
if 'arch_state' in state:
metadata = {ARCH_STATE_HEADER: state['arch_state']}
if "arch_state" in state:
metadata = {ARCH_STATE_HEADER: state["arch_state"]}
try:
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
messages = history,
temperature=1.0,
metadata=metadata,
extra_headers=custom_headers
)
raw_response = client.chat.completions.with_raw_response.create(
model=MODEL_NAME,
messages=history,
temperature=1.0,
metadata=metadata,
extra_headers=custom_headers,
)
except Exception as e:
log.info(e)
# remove last user message in case of exception
history.pop()
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
log.info(e)
# remove last user message in case of exception
history.pop()
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
log.info("raw_response: ", raw_response.text)
response = raw_response.parse()
@ -57,24 +63,33 @@ def predict(message, state):
response_json = json.loads(raw_response.text)
arch_state = None
if response_json:
metadata = response_json.get('metadata', {})
if metadata:
arch_state = metadata.get(ARCH_STATE_HEADER, None)
metadata = response_json.get("metadata", {})
if metadata:
arch_state = metadata.get(ARCH_STATE_HEADER, None)
if arch_state:
state['arch_state'] = arch_state
state["arch_state"] = arch_state
content = response.choices[0].message.content
history.append({"role": "assistant", "content": content, "model": response.model})
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
messages = [
(history[i]["content"], history[i + 1]["content"])
for i in range(0, len(history) - 1, 2)
]
return messages, state
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
print("Starting Demo...")
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
state = gr.State({})
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
txt = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
scale=1,
autofocus=True,
)
txt.submit(predict, [txt, state], [chatbot, state])

View file

@ -5,26 +5,32 @@ from openai import OpenAI
import gradio as gr
api_key = os.getenv("OPENAI_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1")
CHAT_COMPLETION_ENDPOINT = os.getenv(
"CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1"
)
client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT)
def predict(message, history):
history_openai_format = []
for human, assistant in history:
history_openai_format.append({"role": "user", "content": human })
history_openai_format.append({"role": "assistant", "content":assistant})
history_openai_format.append({"role": "user", "content": human})
history_openai_format.append({"role": "assistant", "content": assistant})
history_openai_format.append({"role": "user", "content": message})
response = client.chat.completions.create(model='gpt-3.5-turbo',
messages= history_openai_format,
temperature=1.0,
stream=True)
response = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=history_openai_format,
temperature=1.0,
stream=True,
)
partial_message = ""
for chunk in response:
if chunk.choices[0].delta.content is not None:
partial_message = partial_message + chunk.choices[0].delta.content
yield partial_message
partial_message = partial_message + chunk.choices[0].delta.content
yield partial_message
gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8081)

View file

@ -6,56 +6,54 @@ import logging
from pydantic import BaseModel
logger = logging.getLogger('uvicorn.error')
logger = logging.getLogger("uvicorn.error")
logger.setLevel(logging.INFO)
app = FastAPI()
@app.get("/healthz")
async def healthz():
return {
"status": "ok"
}
return {"status": "ok"}
class WeatherRequest(BaseModel):
city: str
days: int = 7
units: str = "Farenheit"
city: str
days: int = 7
units: str = "Farenheit"
@app.post("/weather")
async def weather(req: WeatherRequest, res: Response):
weather_forecast = {
"city": req.city,
"temperature": [],
"units": req.units,
}
for i in range(7):
min_temp = random.randrange(50,90)
max_temp = random.randrange(min_temp+5, min_temp+20)
if req.units.lower() == "celsius" or req.units.lower() == "c":
min_temp = (min_temp - 32) * 5.0/9.0
max_temp = (max_temp - 32) * 5.0/9.0
weather_forecast["temperature"].append({
"date": str(date.today() + timedelta(days=i)),
"temperature": {
"min": min_temp,
"max": max_temp
},
"units": req.units,
"query_time": str(datetime.now(timezone.utc))
})
min_temp = random.randrange(50, 90)
max_temp = random.randrange(min_temp + 5, min_temp + 20)
if req.units.lower() == "celsius" or req.units.lower() == "c":
min_temp = (min_temp - 32) * 5.0 / 9.0
max_temp = (max_temp - 32) * 5.0 / 9.0
weather_forecast["temperature"].append(
{
"date": str(date.today() + timedelta(days=i)),
"temperature": {"min": min_temp, "max": max_temp},
"units": req.units,
"query_time": str(datetime.now(timezone.utc)),
}
)
return weather_forecast
class InsuranceClaimDetailsRequest(BaseModel):
policy_number: str
policy_number: str
@app.post("/insurance_claim_details")
async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response):
claim_details = {
"policy_number": req.policy_number,
"claim_status": "Approved",
@ -68,26 +66,25 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
class DefaultTargetRequest(BaseModel):
arch_messages: list
arch_messages: list
@app.post("/default_target")
async def default_target(req: DefaultTargetRequest, res: Response):
logger.info(f"Received arch_messages: {req.arch_messages}")
resp = {
"choices": [
{
"choices": [
{
"message": {
"role": "assistant",
"content": "hello world from api server"
"role": "assistant",
"content": "hello world from api server",
},
"finish_reason": "completed",
"index": 0
}
],
"model": "api_server",
"usage": {
"completion_tokens": 0
"index": 0,
}
}
],
"model": "api_server",
"usage": {"completion_tokens": 0},
}
logger.info(f"sending response: {json.dumps(resp)}")
return resp

View file

@ -3,28 +3,40 @@ from pydantic import BaseModel, Field
app = FastAPI()
class Conversation(BaseModel):
arch_messages: list
class PolicyCoverageRequest(BaseModel):
policy_type: str = Field(..., description="The type of a policy held by the customer For, e.g. car, boat, house, motorcycle)")
policy_type: str = Field(
...,
description="The type of a policy held by the customer For, e.g. car, boat, house, motorcycle)",
)
class PolicyInitiateRequest(PolicyCoverageRequest):
deductible: float = Field(..., description="The deductible amount set of the policy")
deductible: float = Field(
..., description="The deductible amount set of the policy"
)
class ClaimUpdate(BaseModel):
claim_id: str
notes: str # Status or details of the claim
class DeductibleUpdate(BaseModel):
policy_id: str
deductible: float
class CoverageResponse(BaseModel):
policy_type: str
coverage: str # Description of coverage
premium: float # The premium cost
# Get information about policy coverage
@app.post("/policy/coverage", response_model=CoverageResponse)
async def get_policy_coverage(req: PolicyCoverageRequest):
@ -32,10 +44,22 @@ async def get_policy_coverage(req: PolicyCoverageRequest):
Retrieve the coverage details for a given policy type (car, boat, house, motorcycle).
"""
policy_coverage = {
"car": {"coverage": "Full car coverage with collision, liability", "premium": 500.0},
"boat": {"coverage": "Full boat coverage including theft and storm damage", "premium": 700.0},
"house": {"coverage": "Full house coverage including fire, theft, flood", "premium": 1000.0},
"motorcycle": {"coverage": "Full motorcycle coverage with liability", "premium": 400.0},
"car": {
"coverage": "Full car coverage with collision, liability",
"premium": 500.0,
},
"boat": {
"coverage": "Full boat coverage including theft and storm damage",
"premium": 700.0,
},
"house": {
"coverage": "Full house coverage including fire, theft, flood",
"premium": 1000.0,
},
"motorcycle": {
"coverage": "Full motorcycle coverage with liability",
"premium": 400.0,
},
}
if req.policy_type not in policy_coverage:
@ -44,9 +68,10 @@ async def get_policy_coverage(req: PolicyCoverageRequest):
return CoverageResponse(
policy_type=req.policy_type,
coverage=policy_coverage[req.policy_type]["coverage"],
premium=policy_coverage[req.policy_type]["premium"]
premium=policy_coverage[req.policy_type]["premium"],
)
# Initiate policy coverage
@app.post("/policy/initiate")
async def initiate_policy(policy_request: PolicyInitiateRequest):
@ -56,7 +81,11 @@ async def initiate_policy(policy_request: PolicyInitiateRequest):
if policy_request.policy_type not in ["car", "boat", "house", "motorcycle"]:
raise HTTPException(status_code=400, detail="Invalid policy type")
return {"message": f"Policy initiated for {policy_request.policy_type}", "deductible": policy_request.deductible}
return {
"message": f"Policy initiated for {policy_request.policy_type}",
"deductible": policy_request.deductible,
}
# Update claim details
@app.post("/policy/claim")
@ -65,8 +94,11 @@ async def update_claim(req: ClaimUpdate):
Update the status or details of a claim.
"""
# For simplicity, this is a mock update response
return {"message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated",
"update": claim_update.notes}
return {
"message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated",
"update": claim_update.notes,
}
# Update deductible amount
@app.post("/policy/deductible")
@ -75,8 +107,11 @@ async def update_deductible(deductible_update: DeductibleUpdate):
Update the deductible amount for a specific policy.
"""
# For simplicity, this is a mock update response
return {"message": f"Deductible for policy {deductible_update.policy_id} has been updated",
"new_deductible": deductible_update.deductible}
return {
"message": f"Deductible for policy {deductible_update.policy_id} has been updated",
"new_deductible": deductible_update.deductible,
}
# Post method for policy Q/A
@app.post("/policy/qa")
@ -86,21 +121,20 @@ async def policy_qa(conversation: Conversation):
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
"""
return {
"choices": [
{
"choices": [
{
"message": {
"role": "assistant",
"content": "I am a helpful insurance agent, and can only help with insurance things"
"role": "assistant",
"content": "I am a helpful insurance agent, and can only help with insurance things",
},
"finish_reason": "completed",
"index": 0
}
],
"model": "insurance_agent",
"usage": {
"completion_tokens": 0
"index": 0,
}
}
],
"model": "insurance_agent",
"usage": {"completion_tokens": 0},
}
# Run the app using:
# uvicorn main:app --reload

View file

@ -4,10 +4,14 @@ from typing import List, Optional
app = FastAPI()
# Define the request model
class DeviceSummaryRequest(BaseModel):
device_ids: List[int]
time_range: Optional[int] = Field(default=7, description="Time range in days, defaults to 7")
time_range: Optional[int] = Field(
default=7, description="Time range in days, defaults to 7"
)
# Define the response model
class DeviceStatistics(BaseModel):
@ -15,18 +19,23 @@ class DeviceStatistics(BaseModel):
time_range: str
data: str
class DeviceSummaryResponse(BaseModel):
statistics: List[DeviceStatistics]
# Request model for device reboot
class DeviceRebootRequest(BaseModel):
device_ids: List[int]
# Response model for the device reboot
class CoverageResponse(BaseModel):
status: str
summary: dict
@app.post("/agent/device_reboot", response_model=CoverageResponse)
def reboot_network_device(request_data: DeviceRebootRequest):
"""
@ -38,20 +47,21 @@ def reboot_network_device(request_data: DeviceRebootRequest):
# Validate 'device_ids' (This is already validated by Pydantic, but additional logic can be added if needed)
if not device_ids:
raise HTTPException(status_code=400, detail="'device_ids' parameter is required")
raise HTTPException(
status_code=400, detail="'device_ids' parameter is required"
)
# Simulate reboot operation and return the response
statistics = []
for device_id in device_ids:
# Placeholder for actual data retrieval or device reboot logic
stats = {
"data": f"Device {device_id} has been successfully rebooted."
}
stats = {"data": f"Device {device_id} has been successfully rebooted."}
statistics.append(stats)
# Return the response with a summary
return CoverageResponse(status="success", summary={"device_ids": device_ids})
# Post method for device summary
@app.post("/agent/device_summary", response_model=DeviceSummaryResponse)
def get_device_summary(request: DeviceSummaryRequest):
@ -77,6 +87,7 @@ def get_device_summary(request: DeviceSummaryRequest):
return DeviceSummaryResponse(statistics=statistics)
@app.post("/agent/network_summary")
async def policy_qa():
"""
@ -84,21 +95,20 @@ async def policy_qa():
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
"""
return {
"choices": [
{
"choices": [
{
"message": {
"role": "assistant",
"content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them"
"role": "assistant",
"content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them",
},
"finish_reason": "completed",
"index": 0
}
],
"model": "network_agent",
"usage": {
"completion_tokens": 0
"index": 0,
}
}
],
"model": "network_agent",
"usage": {"completion_tokens": 0},
}
if __name__ == "__main__":
app.run(debug=True)

View file

@ -11,6 +11,7 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
def load_sql():
# Example Usage
conn = sqlite3.connect(":memory:")
@ -26,6 +27,7 @@ def load_sql():
return conn
# Function to convert natural language time expressions to "X {time} ago" format
def convert_to_ago_format(expression):
# Define patterns for different time units

View file

@ -13,9 +13,10 @@ def get_device_summary():
# Validate 'device_ids' parameter
device_ids = data.get("device_ids")
if not device_ids or not isinstance(device_ids, list):
return jsonify(
{"error": "'device_ids' parameter is required and must be a list"}
), 400
return (
jsonify({"error": "'device_ids' parameter is required and must be a list"}),
400,
)
# Validate 'time_range' parameter (optional, defaults to 7)
time_range = data.get("time_range", 7)

View file

@ -119,9 +119,10 @@ def process_rag():
intent_changed = True
else:
# Invalid value provided
return jsonify(
{"error": "Invalid value for x-arch-prompt-intent-change header"}
), 400
return (
jsonify({"error": "Invalid value for x-arch-prompt-intent-change header"}),
400,
)
# Update user conversation based on intent change
memory = update_user_conversation(user_id, client_messages, intent_changed)

View file

@ -13,9 +13,10 @@ def get_device_summary():
# Validate 'device_ids' parameter
device_ids = data.get("device_ids")
if not device_ids or not isinstance(device_ids, list):
return jsonify(
{"error": "'device_ids' parameter is required and must be a list"}
), 400
return (
jsonify({"error": "'device_ids' parameter is required and must be a list"}),
400,
)
# Validate 'time_range' parameter (optional, defaults to 7)
time_range = data.get("time_range", 7)

View file

@ -12,16 +12,16 @@ from sphinx.util.docfields import Field
from sphinxawesome_theme import ThemeOptions
from sphinxawesome_theme.postprocess import Icons
project = 'Arch Docs'
copyright = '2024, Katanemo Labs, Inc'
author = 'Katanemo Labs, Inc'
release = ' v0.1'
project = "Arch Docs"
copyright = "2024, Katanemo Labs, Inc"
author = "Katanemo Labs, Inc"
release = " v0.1"
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
root_doc = 'index'
root_doc = "index"
nitpicky = True
add_module_names = False
@ -33,23 +33,23 @@ extensions = [
"sphinx.ext.extlinks",
"sphinx.ext.viewcode",
"sphinx_sitemap",
"sphinx_design"
"sphinx_design",
]
# Paths that contain templates, relative to this directory.
templates_path = ['_templates']
templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and directories
# to ignore when looking for source files.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output -------------------------------------------------
html_theme = 'sphinxawesome_theme' # You can change the theme to 'sphinx_rtd_theme' or another of your choice.
html_theme = "sphinxawesome_theme" # You can change the theme to 'sphinx_rtd_theme' or another of your choice.
html_title = project + release
html_permalinks_icon = Icons.permalinks_icon
html_favicon = '_static/favicon.ico'
html_logo = '_static/favicon.ico' # Specify the path to the logo image file (make sure the logo is in the _static directory)
html_favicon = "_static/favicon.ico"
html_logo = "_static/favicon.ico" # Specify the path to the logo image file (make sure the logo is in the _static directory)
html_last_updated_fmt = ""
html_use_index = False # Don't create index
html_domain_indices = False # Don't need module indices
@ -57,10 +57,14 @@ html_copy_source = False # Don't need sources
html_show_sphinx = False
html_baseurl = './docs'
html_baseurl = "./docs"
html_sidebars = {
"**": ['analytics.html', "sidebar_main_nav_links.html", "sidebar_toc.html", ]
"**": [
"analytics.html",
"sidebar_main_nav_links.html",
"sidebar_toc.html",
]
}
theme_options = ThemeOptions(
@ -102,7 +106,7 @@ html_theme_options = asdict(theme_options)
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
html_static_path = ["_static"]
pygments_style = "lovelace"
pygments_style_dark = "github-dark"
@ -111,10 +115,11 @@ sitemap_url_scheme = "{link}"
# Add this configuration at the bottom of your conf.py
html_context = {
'google_analytics_id': 'G-K2LXXSX6HB', # Replace with your Google Analytics tracking ID
"google_analytics_id": "G-K2LXXSX6HB", # Replace with your Google Analytics tracking ID
}
templates_path = ['_templates']
templates_path = ["_templates"]
# -- Register a :confval: interpreted text role ----------------------------------
def setup(app: Sphinx) -> None:
@ -138,4 +143,4 @@ def setup(app: Sphinx) -> None:
],
)
app.add_css_file('_static/custom.css')
app.add_css_file("_static/custom.css")

View file

@ -35,11 +35,13 @@ def start_server():
print("Server is already running. Use 'model_server restart' to restart it.")
sys.exit(1)
print(f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)")
print(
f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)"
)
process = subprocess.Popen(
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
start_new_session=True,
stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
stderr=subprocess.DEVNULL, # Suppress standard error. There is a logger that model_server prints to
)

View file

@ -7,7 +7,6 @@ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenc
def get_device():
if torch.cuda.is_available():
device = "cuda"
elif torch.backends.mps.is_available():
@ -19,13 +18,15 @@ def get_device():
return device
def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")):
def load_transformers(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")
):
print("Loading Embedding Model")
transformers = {}
device = get_device()
transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
transformers["model"] = ORTModelForFeatureExtraction.from_pretrained(
model_name, device_map = device
model_name, device_map=device
)
transformers["model_name"] = model_name
@ -62,7 +63,9 @@ def load_guard_model(
return guard_model
def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")):
def load_zero_shot_models(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
):
zero_shot_model = {}
device = get_device()
zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
@ -81,5 +84,6 @@ def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deb
return zero_shot_model
if __name__ == "__main__":
print(get_device())

View file

@ -7,7 +7,12 @@ from app.load_models import (
get_device,
)
import os
from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config, get_model_server_logger
from app.utils import (
GuardHandler,
split_text_into_chunks,
load_yaml_config,
get_model_server_logger,
)
import torch
import yaml
import string
@ -39,6 +44,7 @@ guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
app = FastAPI()
class EmbeddingRequest(BaseModel):
input: str
model: str
@ -84,6 +90,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
}
return {"data": data, "model": req.model, "object": "list", "usage": usage}
class GuardRequest(BaseModel):
input: str
task: str

View file

@ -9,6 +9,7 @@ import logging
logger_instance = None
def load_yaml_config(file_name):
# Load the YAML file from the package
yaml_path = pkg_resources.resource_filename("app", file_name)
@ -138,6 +139,7 @@ class GuardHandler:
}
return result_dict
def get_model_server_logger():
global logger_instance
@ -164,8 +166,8 @@ def get_model_server_logger():
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler(log_file_path, mode='w'), # Overwrite logs in file
]
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
],
)
except (PermissionError, OSError) as e:
# Dont' fallback to console logging if there are issues writing to the log file