lint + formating with black (#158)

* lint + formating with black

* add black as pre commit
This commit is contained in:
Co Tran 2024-10-09 11:25:07 -07:00 committed by GitHub
parent 498e7f9724
commit 5c4a6bc8ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 581 additions and 295 deletions

View file

@ -25,3 +25,8 @@ repos:
# --lib is to only test the library, since when integration tests are made, # --lib is to only test the library, since when integration tests are made,
# they will be in a seperate tests directory # they will be in a seperate tests directory
entry: bash -c "cd arch && cargo test -p intelligent-prompt-gateway --lib" entry: bash -c "cd arch && cargo test -p intelligent-prompt-gateway --lib"
- repo: https://github.com/psf/black
rev: 23.1.0
hooks:
- id: black
language_version: python3

View file

@ -16,18 +16,22 @@ logo = r"""
/_/ \_\|_| \___||_| |_| /_/ \_\|_| \___||_| |_|
""" """
@click.group(invoke_without_command=True) @click.group(invoke_without_command=True)
@click.pass_context @click.pass_context
def main(ctx): def main(ctx):
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
click.echo( """Arch (The Intelligent Prompt Gateway) CLI""") click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
click.echo(logo) click.echo(logo)
click.echo(ctx.get_help()) click.echo(ctx.get_help())
# Command to build archgw and model_server Docker images # Command to build archgw and model_server Docker images
ARCHGW_DOCKERFILE = "./arch/Dockerfile" ARCHGW_DOCKERFILE = "./arch/Dockerfile"
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml" MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
@click.command() @click.command()
def build(): def build():
"""Build Arch from source. Must be in root of cloned repo.""" """Build Arch from source. Must be in root of cloned repo."""
@ -35,7 +39,18 @@ def build():
if os.path.exists(ARCHGW_DOCKERFILE): if os.path.exists(ARCHGW_DOCKERFILE):
click.echo("Building archgw image...") click.echo("Building archgw image...")
try: try:
subprocess.run(["docker", "build", "-f", ARCHGW_DOCKERFILE, "-t", "archgw:latest", "."], check=True) subprocess.run(
[
"docker",
"build",
"-f",
ARCHGW_DOCKERFILE,
"-t",
"archgw:latest",
".",
],
check=True,
)
click.echo("archgw image built successfully.") click.echo("archgw image built successfully.")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
click.echo(f"Error building archgw image: {e}") click.echo(f"Error building archgw image: {e}")
@ -51,7 +66,11 @@ def build():
if os.path.exists(MODEL_SERVER_BUILD_FILE): if os.path.exists(MODEL_SERVER_BUILD_FILE):
click.echo("Installing model server dependencies with Poetry...") click.echo("Installing model server dependencies with Poetry...")
try: try:
subprocess.run(["poetry", "install", "--no-cache"], cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE), check=True) subprocess.run(
["poetry", "install", "--no-cache"],
cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE),
check=True,
)
click.echo("Model server dependencies installed successfully.") click.echo("Model server dependencies installed successfully.")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
click.echo(f"Error installing model server dependencies: {e}") click.echo(f"Error installing model server dependencies: {e}")
@ -60,9 +79,12 @@ def build():
click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}") click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}")
sys.exit(1) sys.exit(1)
@click.command() @click.command()
@click.argument('file', required=False) # Optional file argument @click.argument("file", required=False) # Optional file argument
@click.option('-path', default='.', help='Path to the directory containing arch_config.yml') @click.option(
"-path", default=".", help="Path to the directory containing arch_config.yml"
)
def up(file, path): def up(file, path):
"""Starts Arch.""" """Starts Arch."""
if file: if file:
@ -78,10 +100,15 @@ def up(file, path):
return return
print(f"Validating {arch_config_file}") print(f"Validating {arch_config_file}")
arch_schema_config = pkg_resources.resource_filename(__name__, "config/arch_config_schema.yaml") arch_schema_config = pkg_resources.resource_filename(
__name__, "config/arch_config_schema.yaml"
)
try: try:
config_generator.validate_prompt_config(arch_config_file=arch_config_file, arch_config_schema_file=arch_schema_config) config_generator.validate_prompt_config(
arch_config_file=arch_config_file,
arch_config_schema_file=arch_schema_config,
)
except Exception as e: except Exception as e:
print("Exiting archgw up") print("Exiting archgw up")
sys.exit(1) sys.exit(1)
@ -91,52 +118,67 @@ def up(file, path):
# Set the ARCH_CONFIG_FILE environment variable # Set the ARCH_CONFIG_FILE environment variable
env_stage = {} env_stage = {}
env = os.environ.copy() env = os.environ.copy()
#check if access_keys are preesnt in the config file # check if access_keys are preesnt in the config file
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file) access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
if access_keys: if access_keys:
if file: if file:
app_env_file = os.path.join(os.path.dirname(os.path.abspath(file)), ".env") #check the .env file in the path app_env_file = os.path.join(
os.path.dirname(os.path.abspath(file)), ".env"
) # check the .env file in the path
else: else:
app_env_file = os.path.abspath(os.path.join(path, ".env")) app_env_file = os.path.abspath(os.path.join(path, ".env"))
if not os.path.exists(app_env_file): #check to see if the environment variables in the current environment or not if not os.path.exists(
app_env_file
): # check to see if the environment variables in the current environment or not
for access_key in access_keys: for access_key in access_keys:
if env.get(access_key) is None: if env.get(access_key) is None:
print (f"Access Key: {access_key} not found. Exiting Start") print(f"Access Key: {access_key} not found. Exiting Start")
sys.exit(1) sys.exit(1)
else: else:
env_stage[access_key] = env.get(access_key) env_stage[access_key] = env.get(access_key)
else: #.env file exists, use that to send parameters to Arch else: # .env file exists, use that to send parameters to Arch
env_file_dict = load_env_file_to_dict(app_env_file) env_file_dict = load_env_file_to_dict(app_env_file)
for access_key in access_keys: for access_key in access_keys:
if env_file_dict.get(access_key) is None: if env_file_dict.get(access_key) is None:
print (f"Access Key: {access_key} not found. Exiting Start") print(f"Access Key: {access_key} not found. Exiting Start")
sys.exit(1) sys.exit(1)
else: else:
env_stage[access_key] = env_file_dict[access_key] env_stage[access_key] = env_file_dict[access_key]
with open(pkg_resources.resource_filename(__name__, "config/stage.env"), 'w') as file: with open(
pkg_resources.resource_filename(__name__, "config/stage.env"), "w"
) as file:
for key, value in env_stage.items(): for key, value in env_stage.items():
file.write(f"{key}={value}\n") file.write(f"{key}={value}\n")
env.update(env_stage) env.update(env_stage)
env['ARCH_CONFIG_FILE'] = arch_config_file env["ARCH_CONFIG_FILE"] = arch_config_file
start_arch_modelserver() start_arch_modelserver()
start_arch(arch_config_file, env) start_arch(arch_config_file, env)
@click.command() @click.command()
def down(): def down():
"""Stops Arch.""" """Stops Arch."""
stop_arch_modelserver() stop_arch_modelserver()
stop_arch() stop_arch()
@click.command() @click.command()
@click.option('-f', '--file', type=click.Path(exists=True), required=True, help="Path to the Python file") @click.option(
"-f",
"--file",
type=click.Path(exists=True),
required=True,
help="Path to the Python file",
)
def generate_prompt_targets(file): def generate_prompt_targets(file):
"""Generats prompt_targets from python methods. """Generats prompt_targets from python methods.
Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']: Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
If you have a complex pydantic data type, you will have to flatten those manually until we add support for it.""" If you have a complex pydantic data type, you will have to flatten those manually until we add support for it.
"""
print(f"Processing file: {file}") print(f"Processing file: {file}")
if not file.endswith(".py"): if not file.endswith(".py"):
@ -145,10 +187,11 @@ def generate_prompt_targets(file):
targets.generate_prompt_targets(file) targets.generate_prompt_targets(file)
main.add_command(up) main.add_command(up)
main.add_command(down) main.add_command(down)
main.add_command(build) main.add_command(build)
main.add_command(generate_prompt_targets) main.add_command(generate_prompt_targets)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

View file

@ -3,36 +3,44 @@ from jinja2 import Environment, FileSystemLoader
import yaml import yaml
from jsonschema import validate from jsonschema import validate
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml') ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml') "ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml') )
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml') ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/config/arch_config.yaml")
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
)
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
)
def add_secret_key_to_llm_providers(config_yaml) :
def add_secret_key_to_llm_providers(config_yaml):
llm_providers = [] llm_providers = []
for llm_provider in config_yaml.get("llm_providers", []): for llm_provider in config_yaml.get("llm_providers", []):
access_key_env_var = llm_provider.get('access_key', False) access_key_env_var = llm_provider.get("access_key", False)
access_key_value = os.getenv(access_key_env_var, False) access_key_value = os.getenv(access_key_env_var, False)
if access_key_env_var and access_key_value: if access_key_env_var and access_key_value:
llm_provider['access_key'] = access_key_value llm_provider["access_key"] = access_key_value
llm_providers.append(llm_provider) llm_providers.append(llm_provider)
config_yaml["llm_providers"] = llm_providers config_yaml["llm_providers"] = llm_providers
return config_yaml return config_yaml
def validate_and_render_schema(): def validate_and_render_schema():
env = Environment(loader=FileSystemLoader('./')) env = Environment(loader=FileSystemLoader("./"))
template = env.get_template('envoy.template.yaml') template = env.get_template("envoy.template.yaml")
try: try:
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE) validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
except Exception as e: except Exception as e:
print(e) print(e)
exit(1) # validate_prompt_config failed. Exit exit(1) # validate_prompt_config failed. Exit
with open(ARCH_CONFIG_FILE, 'r') as file: with open(ARCH_CONFIG_FILE, "r") as file:
arch_config = file.read() arch_config = file.read()
with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file: with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file:
arch_config_schema = file.read() arch_config_schema = file.read()
config_yaml = yaml.safe_load(arch_config) config_yaml = yaml.safe_load(arch_config)
@ -44,7 +52,7 @@ def validate_and_render_schema():
if name not in inferred_clusters: if name not in inferred_clusters:
inferred_clusters[name] = { inferred_clusters[name] = {
"name": name, "name": name,
"port": 80, # default port "port": 80, # default port
} }
print(inferred_clusters) print(inferred_clusters)
@ -55,14 +63,13 @@ def validate_and_render_schema():
if name in inferred_clusters: if name in inferred_clusters:
print("updating cluster", endpoint_details) print("updating cluster", endpoint_details)
inferred_clusters[name].update(endpoint_details) inferred_clusters[name].update(endpoint_details)
endpoint = inferred_clusters[name]['endpoint'] endpoint = inferred_clusters[name]["endpoint"]
if len(endpoint.split(':')) > 1: if len(endpoint.split(":")) > 1:
inferred_clusters[name]['endpoint'] = endpoint.split(':')[0] inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
inferred_clusters[name]['port'] = int(endpoint.split(':')[1]) inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
else: else:
inferred_clusters[name] = endpoint_details inferred_clusters[name] = endpoint_details
print("updated clusters", inferred_clusters) print("updated clusters", inferred_clusters)
config_yaml = add_secret_key_to_llm_providers(config_yaml) config_yaml = add_secret_key_to_llm_providers(config_yaml)
@ -71,23 +78,24 @@ def validate_and_render_schema():
arch_config_string = yaml.dump(config_yaml) arch_config_string = yaml.dump(config_yaml)
data = { data = {
'arch_config': arch_config_string, "arch_config": arch_config_string,
'arch_clusters': inferred_clusters, "arch_clusters": inferred_clusters,
'arch_llm_providers': arch_llm_providers, "arch_llm_providers": arch_llm_providers,
'arch_tracing': arch_tracing "arch_tracing": arch_tracing,
} }
rendered = template.render(data) rendered = template.render(data)
print(rendered) print(rendered)
print(ENVOY_CONFIG_FILE_RENDERED) print(ENVOY_CONFIG_FILE_RENDERED)
with open(ENVOY_CONFIG_FILE_RENDERED, 'w') as file: with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
file.write(rendered) file.write(rendered)
def validate_prompt_config(arch_config_file, arch_config_schema_file): def validate_prompt_config(arch_config_file, arch_config_schema_file):
with open(arch_config_file, 'r') as file: with open(arch_config_file, "r") as file:
arch_config = file.read() arch_config = file.read()
with open(arch_config_schema_file, 'r') as file: with open(arch_config_schema_file, "r") as file:
arch_config_schema = file.read() arch_config_schema = file.read()
config_yaml = yaml.safe_load(arch_config) config_yaml = yaml.safe_load(arch_config)
@ -96,8 +104,11 @@ def validate_prompt_config(arch_config_file, arch_config_schema_file):
try: try:
validate(config_yaml, config_schema_yaml) validate(config_yaml, config_schema_yaml)
except Exception as e: except Exception as e:
print(f"Error validating arch_config file: {arch_config_file}, error: {e.message}") print(
f"Error validating arch_config file: {arch_config_file}, error: {e.message}"
)
raise e raise e
if __name__ == '__main__':
if __name__ == "__main__":
validate_and_render_schema() validate_and_render_schema()

View file

@ -5,6 +5,7 @@ import pkg_resources
import select import select
from utils import run_docker_compose_ps, print_service_status, check_services_state from utils import run_docker_compose_ps, print_service_status, check_services_state
def start_arch(arch_config_file, env, log_timeout=120): def start_arch(arch_config_file, env, log_timeout=120):
""" """
Start Docker Compose in detached mode and stream logs until services are healthy. Start Docker Compose in detached mode and stream logs until services are healthy.
@ -14,22 +15,35 @@ def start_arch(arch_config_file, env, log_timeout=120):
log_timeout (int): Time in seconds to show logs before checking for healthy state. log_timeout (int): Time in seconds to show logs before checking for healthy state.
""" """
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml') compose_file = pkg_resources.resource_filename(
__name__, "config/docker-compose.yaml"
)
try: try:
# Run the Docker Compose command in detached mode (-d) # Run the Docker Compose command in detached mode (-d)
subprocess.run( subprocess.run(
["docker", "compose", "-p", "arch", "up", "-d",], [
cwd=os.path.dirname(compose_file), # Ensure the Docker command runs in the correct path "docker",
env=env, # Pass the modified environment "compose",
check=True # Raise an exception if the command fails "-p",
"arch",
"up",
"-d",
],
cwd=os.path.dirname(
compose_file
), # Ensure the Docker command runs in the correct path
env=env, # Pass the modified environment
check=True, # Raise an exception if the command fails
) )
print(f"Arch docker-compose started in detached.") print(f"Arch docker-compose started in detached.")
print("Monitoring `docker-compose ps` logs...") print("Monitoring `docker-compose ps` logs...")
start_time = time.time() start_time = time.time()
services_status = {} services_status = {}
services_running = False #assume that the services are not running at the moment services_running = (
False # assume that the services are not running at the moment
)
while True: while True:
current_time = time.time() current_time = time.time()
@ -40,16 +54,22 @@ def start_arch(arch_config_file, env, log_timeout=120):
print(f"Stopping log monitoring after {log_timeout} seconds.") print(f"Stopping log monitoring after {log_timeout} seconds.")
break break
current_services_status = run_docker_compose_ps(compose_file=compose_file, env=env) current_services_status = run_docker_compose_ps(
compose_file=compose_file, env=env
)
if not current_services_status: if not current_services_status:
print("Status for the services could not be detected. Something went wrong. Please run docker logs") print(
"Status for the services could not be detected. Something went wrong. Please run docker logs"
)
break break
if not services_status: if not services_status:
services_status = current_services_status #set the first time services_status = current_services_status # set the first time
print_service_status(services_status) #print the services status and proceed. print_service_status(
services_status
) # print the services status and proceed.
#check if anyone service is failed or exited state, if so print and break out # check if anyone service is failed or exited state, if so print and break out
unhealthy_states = ["unhealthy", "exit", "exited", "dead", "bad"] unhealthy_states = ["unhealthy", "exit", "exited", "dead", "bad"]
running_states = ["running", "up"] running_states = ["running", "up"]
@ -58,14 +78,23 @@ def start_arch(arch_config_file, env, log_timeout=120):
break break
if check_services_state(current_services_status, unhealthy_states): if check_services_state(current_services_status, unhealthy_states):
print("One or more Arch services are unhealthy. Please run `docker logs` for more information") print(
print_service_status(current_services_status) #print the services status and proceed. "One or more Arch services are unhealthy. Please run `docker logs` for more information"
)
print_service_status(
current_services_status
) # print the services status and proceed.
break break
#check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error # check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error
for service_name in services_status.keys(): for service_name in services_status.keys():
if services_status[service_name]['State'] != current_services_status[service_name]['State']: if (
print("One or more Arch services have changed state. Printing current state") services_status[service_name]["State"]
!= current_services_status[service_name]["State"]
):
print(
"One or more Arch services have changed state. Printing current state"
)
print_service_status(current_services_status) print_service_status(current_services_status)
break break
@ -82,7 +111,9 @@ def stop_arch():
Args: Args:
path (str): The path where the docker-compose.yml file is located. path (str): The path where the docker-compose.yml file is located.
""" """
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml') compose_file = pkg_resources.resource_filename(
__name__, "config/docker-compose.yaml"
)
try: try:
# Run `docker-compose down` to shut down all services # Run `docker-compose down` to shut down all services
@ -96,6 +127,7 @@ def stop_arch():
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Failed to shut down services: {str(e)}") print(f"Failed to shut down services: {str(e)}")
def start_arch_modelserver(): def start_arch_modelserver():
""" """
Start the model server. This assumes that the archgw_modelserver package is installed locally Start the model server. This assumes that the archgw_modelserver package is installed locally
@ -103,15 +135,14 @@ def start_arch_modelserver():
""" """
try: try:
subprocess.run( subprocess.run(
['archgw_modelserver', 'restart'], ["archgw_modelserver", "restart"], check=True, start_new_session=True
check=True,
start_new_session=True
) )
print("Successfull run the archgw model_server") print("Successfull run the archgw model_server")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print (f"Failed to start model_server. Please check archgw_modelserver logs") print(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1) sys.exit(1)
def stop_arch_modelserver(): def stop_arch_modelserver():
""" """
Stop the model server. This assumes that the archgw_modelserver package is installed locally Stop the model server. This assumes that the archgw_modelserver package is installed locally
@ -119,10 +150,10 @@ def stop_arch_modelserver():
""" """
try: try:
subprocess.run( subprocess.run(
['archgw_modelserver', 'stop'], ["archgw_modelserver", "stop"],
check=True, check=True,
) )
print("Successfull stopped the archgw model_server") print("Successfull stopped the archgw model_server")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print (f"Failed to start model_server. Please check archgw_modelserver logs") print(f"Failed to start model_server. Please check archgw_modelserver logs")
sys.exit(1) sys.exit(1)

View file

@ -6,17 +6,29 @@ setup(
description="Python-based CLI tool to manage Arch and generate targets.", description="Python-based CLI tool to manage Arch and generate targets.",
author="Katanemo Labs, Inc.", author="Katanemo Labs, Inc.",
packages=find_packages(), packages=find_packages(),
py_modules = ['cli', 'core', 'targets', 'utils', 'config_generator'], py_modules=["cli", "core", "targets", "utils", "config_generator"],
include_package_data=True, include_package_data=True,
# Specify to include the docker-compose.yml file # Specify to include the docker-compose.yml file
package_data={ package_data={
'': ['config/docker-compose.yaml', 'config/arch_config_schema.yaml', 'config/stage.env'] #Specify to include the docker-compose.yml file "": [
"config/docker-compose.yaml",
"config/arch_config_schema.yaml",
"config/stage.env",
] # Specify to include the docker-compose.yml file
}, },
# Add dependencies here, e.g., 'PyYAML' for YAML processing # Add dependencies here, e.g., 'PyYAML' for YAML processing
install_requires=['pyyaml', 'pydantic', 'click', 'jinja2','pyyaml','jsonschema', 'setuptools'], install_requires=[
"pyyaml",
"pydantic",
"click",
"jinja2",
"pyyaml",
"jsonschema",
"setuptools",
],
entry_points={ entry_points={
'console_scripts': [ "console_scripts": [
'archgw=cli:main', "archgw=cli:main",
], ],
}, },
) )

View file

@ -18,14 +18,20 @@ def detect_framework(tree: Any) -> str:
return "fastapi" return "fastapi"
return "unknown" return "unknown"
def get_route_decorators(node: Any, framework: str) -> list: def get_route_decorators(node: Any, framework: str) -> list:
"""Extract route decorators based on the framework.""" """Extract route decorators based on the framework."""
decorators = [] decorators = []
for decorator in node.decorator_list: for decorator in node.decorator_list:
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute): if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Attribute
):
if framework == "flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS: if framework == "flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS:
decorators.append(decorator.func.attr) decorators.append(decorator.func.attr)
elif framework == "fastapi" and decorator.func.attr in FASTAPI_ROUTE_DECORATORS: elif (
framework == "fastapi"
and decorator.func.attr in FASTAPI_ROUTE_DECORATORS
):
decorators.append(decorator.func.attr) decorators.append(decorator.func.attr)
return decorators return decorators
@ -36,6 +42,7 @@ def get_route_path(node: Any, framework: str) -> str:
if isinstance(decorator, ast.Call) and decorator.args: if isinstance(decorator, ast.Call) and decorator.args:
return decorator.args[0].s # Assuming it's a string literal return decorator.args[0].s # Assuming it's a string literal
def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool: def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
"""Check if a given type annotation is a Pydantic model.""" """Check if a given type annotation is a Pydantic model."""
# We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel # We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel
@ -47,6 +54,7 @@ def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
return True return True
return False return False
def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list: def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
"""Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values.""" """Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values."""
fields = [] fields = []
@ -62,15 +70,26 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
required = True # Assume the field is required initially required = True # Assume the field is required initially
# Check if the field uses Field() with required status and description # Check if the field uses Field() with required status and description
if stmt.value and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == 'Field': if (
stmt.value
and isinstance(stmt.value, ast.Call)
and isinstance(stmt.value.func, ast.Name)
and stmt.value.func.id == "Field"
):
# Extract the description argument inside the Field call # Extract the description argument inside the Field call
for keyword in stmt.value.keywords: for keyword in stmt.value.keywords:
if keyword.arg == 'description' and isinstance(keyword.value, ast.Str): if keyword.arg == "description" and isinstance(
keyword.value, ast.Str
):
description = keyword.value.s description = keyword.value.s
if keyword.arg == 'default': if keyword.arg == "default":
default_value = keyword.value default_value = keyword.value
# If Ellipsis (...) is used, it means the field is required # If Ellipsis (...) is used, it means the field is required
if stmt.value.args and isinstance(stmt.value.args[0], ast.Constant) and stmt.value.args[0].value is Ellipsis: if (
stmt.value.args
and isinstance(stmt.value.args[0], ast.Constant)
and stmt.value.args[0].value is Ellipsis
):
required = True required = True
else: else:
required = False required = False
@ -80,19 +99,28 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
if isinstance(stmt.value, ast.Constant): if isinstance(stmt.value, ast.Constant):
# Set the default value from the assignment (e.g., name: str = "John Doe") # Set the default value from the assignment (e.g., name: str = "John Doe")
default_value = stmt.value.value default_value = stmt.value.value
required = False # Not required since it has a default value required = (
False # Not required since it has a default value
)
# Always extract the field type, even if there's a default value # Always extract the field type, even if there's a default value
if isinstance(stmt.annotation, ast.Subscript): if isinstance(stmt.annotation, ast.Subscript):
# Get the base type (list, tuple, set, dict) # Get the base type (list, tuple, set, dict)
base_type = stmt.annotation.value.id if isinstance(stmt.annotation.value, ast.Name) else "Unknown" base_type = (
stmt.annotation.value.id
if isinstance(stmt.annotation.value, ast.Name)
else "Unknown"
)
# Handle only list, tuple, set, dict and ignore the inner types # Handle only list, tuple, set, dict and ignore the inner types
if base_type.lower() in ['list', 'tuple', 'set', 'dict']: if base_type.lower() in ["list", "tuple", "set", "dict"]:
field_type = base_type.lower() field_type = base_type.lower()
# Handle the ellipsis '...' for required fields if no Field() call # Handle the ellipsis '...' for required fields if no Field() call
elif isinstance(stmt.value, ast.Constant) and stmt.value.value is Ellipsis: elif (
isinstance(stmt.value, ast.Constant)
and stmt.value.value is Ellipsis
):
required = True required = True
# Handle simple types like str, int, etc. # Handle simple types like str, int, etc.
@ -100,16 +128,17 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
field_type = stmt.annotation.id field_type = stmt.annotation.id
field_info = { field_info = {
"name": stmt.target.id, "name": stmt.target.id,
"type": field_type, # Always set the field type "type": field_type, # Always set the field type
"description": description, "description": description,
"default": default_value, # Handle direct default values "default": default_value, # Handle direct default values
"required": required "required": required,
} }
fields.append(field_info) fields.append(field_info)
return fields return fields
def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list: def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
"""Extract the parameters and their types from the function definition.""" """Extract the parameters and their types from the function definition."""
parameters = [] parameters = []
@ -119,40 +148,68 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
arg_descriptions = extract_arg_descriptions_from_docstring(docstring) arg_descriptions = extract_arg_descriptions_from_docstring(docstring)
# Extract default values # Extract default values
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + node.args.defaults # Align defaults with args defaults = [None] * (
len(node.args.args) - len(node.args.defaults)
) + node.args.defaults # Align defaults with args
for arg, default in zip(node.args.args, defaults): for arg, default in zip(node.args.args, defaults):
if arg.arg != "self": # Skip 'self' or 'cls' in class methods if arg.arg != "self": # Skip 'self' or 'cls' in class methods
param_info = {"name": arg.arg, "description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]")} param_info = {
"name": arg.arg,
"description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]"),
}
# Handle Pydantic model types # Handle Pydantic model types
if hasattr(arg, 'annotation') and is_pydantic_model(arg.annotation, tree): if hasattr(arg, "annotation") and is_pydantic_model(arg.annotation, tree):
# Extract and flatten Pydantic model fields # Extract and flatten Pydantic model fields
pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree) pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree)
parameters.extend(pydantic_fields) # Flatten the model fields into the parameters list parameters.extend(
pydantic_fields
) # Flatten the model fields into the parameters list
continue # Skip adding the current param_info for the model since we expand the fields continue # Skip adding the current param_info for the model since we expand the fields
# Handle standard Python types (int, float, str, etc.) # Handle standard Python types (int, float, str, etc.)
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Name): elif hasattr(arg, "annotation") and isinstance(arg.annotation, ast.Name):
if arg.annotation.id in ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']: if arg.annotation.id in [
"int",
"float",
"bool",
"str",
"list",
"tuple",
"set",
"dict",
]:
param_info["type"] = arg.annotation.id param_info["type"] = arg.annotation.id
else: else:
param_info["type"] = "[UNKNOWN - PLEASE FIX]" param_info["type"] = "[UNKNOWN - PLEASE FIX]"
# Handle generic subscript types (e.g., Optional, List[Type], etc.) # Handle generic subscript types (e.g., Optional, List[Type], etc.)
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Subscript): elif hasattr(arg, "annotation") and isinstance(
if isinstance(arg.annotation.value, ast.Name) and arg.annotation.value.id in ['list', 'tuple', 'set', 'dict']: arg.annotation, ast.Subscript
param_info["type"] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc. ):
if isinstance(
arg.annotation.value, ast.Name
) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]:
param_info[
"type"
] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
else: else:
param_info["type"] = "[UNKNOWN - PLEASE FIX]" param_info["type"] = "[UNKNOWN - PLEASE FIX]"
# Default for unknown types # Default for unknown types
else: else:
param_info["type"] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type param_info[
"type"
] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
# Handle default values # Handle default values
if default is not None: if default is not None:
if isinstance(default, ast.Constant) or isinstance(default, ast.NameConstant): if isinstance(default, ast.Constant) or isinstance(
param_info["default"] = default.value # Use the default value directly default, ast.NameConstant
):
param_info[
"default"
] = default.value # Use the default value directly
else: else:
param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type
param_info["required"] = False # Optional since it has a default value param_info["required"] = False # Optional since it has a default value
@ -164,6 +221,7 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
return parameters return parameters
def get_function_docstring(node: Any) -> str: def get_function_docstring(node: Any) -> str:
"""Extract the function's docstring description if present.""" """Extract the function's docstring description if present."""
# Check if the first node is a docstring # Check if the first node is a docstring
@ -178,6 +236,7 @@ def get_function_docstring(node: Any) -> str:
return "No description provided." return "No description provided."
def extract_arg_descriptions_from_docstring(docstring: str) -> dict: def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
"""Extract descriptions for function parameters from the 'Args' section of the docstring.""" """Extract descriptions for function parameters from the 'Args' section of the docstring."""
descriptions = {} descriptions = {}
@ -195,14 +254,14 @@ def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
continue # Proceed to the next line after 'Args:' continue # Proceed to the next line after 'Args:'
# End of 'Args' section if no indentation and no colon # End of 'Args' section if no indentation and no colon
if in_args_section and not line.startswith(" ") and ':' not in line: if in_args_section and not line.startswith(" ") and ":" not in line:
break # Stop processing if we reach a new section break # Stop processing if we reach a new section
# Process lines in the 'Args' section # Process lines in the 'Args' section
if in_args_section: if in_args_section:
if ':' in line: if ":" in line:
# Extract parameter name and description # Extract parameter name and description
param_name, description = line.split(':', 1) param_name, description = line.split(":", 1)
descriptions[param_name.strip()] = description.strip() descriptions[param_name.strip()] = description.strip()
current_param = param_name.strip() current_param = param_name.strip()
elif current_param and line.startswith(" "): elif current_param and line.startswith(" "):
@ -230,43 +289,50 @@ def generate_prompt_targets(input_file_path: str) -> None:
route_decorators = get_route_decorators(node, framework) route_decorators = get_route_decorators(node, framework)
if route_decorators: if route_decorators:
route_path = get_route_path(node, framework) route_path = get_route_path(node, framework)
function_params = get_function_parameters(node, tree) # Get parameters for the route function_params = get_function_parameters(
node, tree
) # Get parameters for the route
function_docstring = get_function_docstring(node) # Extract docstring function_docstring = get_function_docstring(node) # Extract docstring
routes.append({ routes.append(
'name': node.name, {
'path': route_path, "name": node.name,
'methods': route_decorators, "path": route_path,
'parameters': function_params, # Add parameters to the route "methods": route_decorators,
'description': function_docstring # Add the docstring as the description "parameters": function_params, # Add parameters to the route
}) "description": function_docstring, # Add the docstring as the description
}
)
# Generate YAML structure # Generate YAML structure
output_structure = { output_structure = {"prompt_targets": []}
"prompt_targets": []
}
for route in routes: for route in routes:
target = { target = {
"name": route['name'], "name": route["name"],
"endpoint": [ "endpoint": [
{ {
"name": "app_server", "name": "app_server",
"path": route['path'], "path": route["path"],
} }
], ],
"description": route['description'], # Use extracted docstring "description": route["description"], # Use extracted docstring
"parameters": [ "parameters": [
{ {
"name": param['name'], "name": param["name"],
"type": param['type'], "type": param["type"],
"description": f"{param['description']}", "description": f"{param['description']}",
**({"default": param['default']} if "default" in param and param['default'] is not None else {}), # Only add default if it's set **(
"required": param['required'] {"default": param["default"]}
} for param in route['parameters'] if "default" in param and param["default"] is not None
] else {}
), # Only add default if it's set
"required": param["required"],
}
for param in route["parameters"]
],
} }
if route['name'] == "default": if route["name"] == "default":
# Special case for `information_extraction` based on your YAML format # Special case for `information_extraction` based on your YAML format
target["type"] = "default" target["type"] = "default"
target["auto-llm-dispatch-on-response"] = True target["auto-llm-dispatch-on-response"] = True
@ -274,9 +340,12 @@ def generate_prompt_targets(input_file_path: str) -> None:
output_structure["prompt_targets"].append(target) output_structure["prompt_targets"].append(target)
# Output as YAML # Output as YAML
print(yaml.dump(output_structure, sort_keys=False,default_flow_style=False, indent=3)) print(
yaml.dump(output_structure, sort_keys=False, default_flow_style=False, indent=3)
)
if __name__ == '__main__':
if __name__ == "__main__":
if len(sys.argv) != 2: if len(sys.argv) != 2:
print("Usage: python targets.py <input_file>") print("Usage: python targets.py <input_file>")
sys.exit(1) sys.exit(1)

View file

@ -4,12 +4,23 @@ from typing import List, Dict, Set
app = FastAPI() app = FastAPI()
class User(BaseModel): class User(BaseModel):
name: str = Field("John Doe", description="The name of the user.") # Default value and description for name name: str = Field(
"John Doe", description="The name of the user."
) # Default value and description for name
location: int = None location: int = None
age: int = Field(30, description="The age of the user.") # Default value and description for age age: int = Field(
tags: Set[str] = Field(default_factory=set, description="A set of tags associated with the user.") # Default empty set and description for tags 30, description="The age of the user."
metadata: Dict[str, int] = Field(default_factory=dict, description="A dictionary storing metadata about the user, with string keys and integer values.") # Default empty dict and description for metadata ) # Default value and description for age
tags: Set[str] = Field(
default_factory=set, description="A set of tags associated with the user."
) # Default empty set and description for tags
metadata: Dict[str, int] = Field(
default_factory=dict,
description="A dictionary storing metadata about the user, with string keys and integer values.",
) # Default empty dict and description for metadata
@app.get("/agent/default") @app.get("/agent/default")
async def default(request: User): async def default(request: User):
@ -19,6 +30,7 @@ async def default(request: User):
""" """
return {"info": f"Query: {request.name}, Count: {request.age}"} return {"info": f"Query: {request.name}, Count: {request.age}"}
@app.post("/agent/action") @app.post("/agent/action")
async def reboot_network_device(device_id: str, confirmation: str): async def reboot_network_device(device_id: str, confirmation: str):
""" """

View file

@ -6,6 +6,7 @@ import shlex
import yaml import yaml
import json import json
def run_docker_compose_ps(compose_file, env): def run_docker_compose_ps(compose_file, env):
""" """
Check if all Docker Compose services are in a healthy state. Check if all Docker Compose services are in a healthy state.
@ -16,20 +17,31 @@ def run_docker_compose_ps(compose_file, env):
try: try:
# Run `docker-compose ps` to get the health status of each service # Run `docker-compose ps` to get the health status of each service
ps_process = subprocess.Popen( ps_process = subprocess.Popen(
["docker", "compose", "-p", "arch", "ps", "--format", "table{{.Service}}\t{{.State}}\t{{.Ports}}"], [
"docker",
"compose",
"-p",
"arch",
"ps",
"--format",
"table{{.Service}}\t{{.State}}\t{{.Ports}}",
],
cwd=os.path.dirname(compose_file), cwd=os.path.dirname(compose_file),
stdout=subprocess.PIPE, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, stderr=subprocess.PIPE,
text=True, text=True,
start_new_session=True, start_new_session=True,
env=env env=env,
) )
# Capture the output of `docker-compose ps` # Capture the output of `docker-compose ps`
services_status, error_output = ps_process.communicate() services_status, error_output = ps_process.communicate()
# Check if there is any error output # Check if there is any error output
if error_output: if error_output:
print(f"Error while checking service status:\n{error_output}", file=os.sys.stderr) print(
f"Error while checking service status:\n{error_output}",
file=os.sys.stderr,
)
return {} return {}
services = parse_docker_compose_ps_output(services_status) services = parse_docker_compose_ps_output(services_status)
@ -39,26 +51,31 @@ def run_docker_compose_ps(compose_file, env):
print(f"Failed to check service status. Error:\n{e.stderr}") print(f"Failed to check service status. Error:\n{e.stderr}")
return e return e
#Helper method to print service status
# Helper method to print service status
def print_service_status(services): def print_service_status(services):
print(f"{'Service Name':<25} {'State':<20} {'Ports'}") print(f"{'Service Name':<25} {'State':<20} {'Ports'}")
print("="*72) print("=" * 72)
for service_name, info in services.items(): for service_name, info in services.items():
status = info['STATE'] status = info["STATE"]
ports = info['PORTS'] ports = info["PORTS"]
print(f"{service_name:<25} {status:<20} {ports}") print(f"{service_name:<25} {status:<20} {ports}")
#check for states based on the states passed in
# check for states based on the states passed in
def check_services_state(services, states): def check_services_state(services, states):
for service_name, service_info in services.items(): for service_name, service_info in services.items():
status = service_info['STATE'].lower() # Convert status to lowercase for easier comparison status = service_info[
"STATE"
].lower() # Convert status to lowercase for easier comparison
if any(state in status for state in states): if any(state in status for state in states):
return True return True
return False return False
def get_llm_provider_access_keys(arch_config_file): def get_llm_provider_access_keys(arch_config_file):
with open(arch_config_file, 'r') as file: with open(arch_config_file, "r") as file:
arch_config = file.read() arch_config = file.read()
arch_config_yaml = yaml.safe_load(arch_config) arch_config_yaml = yaml.safe_load(arch_config)
@ -70,22 +87,23 @@ def get_llm_provider_access_keys(arch_config_file):
return access_key_list return access_key_list
def load_env_file_to_dict(file_path): def load_env_file_to_dict(file_path):
env_dict = {} env_dict = {}
# Open and read the .env file # Open and read the .env file
with open(file_path, 'r') as file: with open(file_path, "r") as file:
for line in file: for line in file:
# Strip any leading/trailing whitespaces # Strip any leading/trailing whitespaces
line = line.strip() line = line.strip()
# Skip empty lines and comments # Skip empty lines and comments
if not line or line.startswith('#'): if not line or line.startswith("#"):
continue continue
# Split the line into key and value at the first '=' sign # Split the line into key and value at the first '=' sign
if '=' in line: if "=" in line:
key, value = line.split('=', 1) key, value = line.split("=", 1)
key = key.strip() key = key.strip()
value = value.strip() value = value.strip()
@ -94,6 +112,7 @@ def load_env_file_to_dict(file_path):
return env_dict return env_dict
def parse_docker_compose_ps_output(output): def parse_docker_compose_ps_output(output):
# Split the output into lines # Split the output into lines
lines = output.strip().splitlines() lines = output.strip().splitlines()
@ -111,10 +130,7 @@ def parse_docker_compose_ps_output(output):
parts = line.split() parts = line.split()
# Create a dictionary entry using the header names # Create a dictionary entry using the header names
service_info = { service_info = {headers[1]: parts[1], headers[2]: parts[2]} # State # Ports
headers[1]: parts[1], # State
headers[2]: parts[2] # Ports
}
# Add to the result dictionary using the service name as the key # Add to the result dictionary using the service name as the key
services[parts[0]] = service_info services[parts[0]] = service_info

View file

@ -7,47 +7,53 @@ from dotenv import load_dotenv
load_dotenv() load_dotenv()
OPENAI_API_KEY=os.getenv("OPENAI_API_KEY") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT") CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo") MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
ARCH_STATE_HEADER = 'x-arch-state' ARCH_STATE_HEADER = "x-arch-state"
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT) log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"})) client = OpenAI(
api_key=OPENAI_API_KEY,
base_url=CHAT_COMPLETION_ENDPOINT,
http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
)
def predict(message, state): def predict(message, state):
if 'history' not in state: if "history" not in state:
state['history'] = [] state["history"] = []
history = state.get("history") history = state.get("history")
history.append({"role": "user", "content": message}) history.append({"role": "user", "content": message})
log.info("history: ", history) log.info("history: ", history)
# Custom headers # Custom headers
custom_headers = { custom_headers = {
'x-arch-openai-api-key': f"{OPENAI_API_KEY}", "x-arch-openai-api-key": f"{OPENAI_API_KEY}",
'x-arch-mistral-api-key': f"{MISTRAL_API_KEY}", "x-arch-mistral-api-key": f"{MISTRAL_API_KEY}",
'x-arch-deterministic-provider': 'openai', "x-arch-deterministic-provider": "openai",
} }
metadata = None metadata = None
if 'arch_state' in state: if "arch_state" in state:
metadata = {ARCH_STATE_HEADER: state['arch_state']} metadata = {ARCH_STATE_HEADER: state["arch_state"]}
try: try:
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME, raw_response = client.chat.completions.with_raw_response.create(
messages = history, model=MODEL_NAME,
temperature=1.0, messages=history,
metadata=metadata, temperature=1.0,
extra_headers=custom_headers metadata=metadata,
) extra_headers=custom_headers,
)
except Exception as e: except Exception as e:
log.info(e) log.info(e)
# remove last user message in case of exception # remove last user message in case of exception
history.pop() history.pop()
log.info("Error calling gateway API: {}".format(e.message)) log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message)) raise gr.Error("Error calling gateway API: {}".format(e.message))
log.info("raw_response: ", raw_response.text) log.info("raw_response: ", raw_response.text)
response = raw_response.parse() response = raw_response.parse()
@ -57,24 +63,33 @@ def predict(message, state):
response_json = json.loads(raw_response.text) response_json = json.loads(raw_response.text)
arch_state = None arch_state = None
if response_json: if response_json:
metadata = response_json.get('metadata', {}) metadata = response_json.get("metadata", {})
if metadata: if metadata:
arch_state = metadata.get(ARCH_STATE_HEADER, None) arch_state = metadata.get(ARCH_STATE_HEADER, None)
if arch_state: if arch_state:
state['arch_state'] = arch_state state["arch_state"] = arch_state
content = response.choices[0].message.content content = response.choices[0].message.content
history.append({"role": "assistant", "content": content, "model": response.model}) history.append({"role": "assistant", "content": content, "model": response.model})
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)] messages = [
(history[i]["content"], history[i + 1]["content"])
for i in range(0, len(history) - 1, 2)
]
return messages, state return messages, state
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo: with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
print("Starting Demo...") print("Starting Demo...")
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1) chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
state = gr.State({}) state = gr.State({})
with gr.Row(): with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True) txt = gr.Textbox(
show_label=False,
placeholder="Enter text and press enter",
scale=1,
autofocus=True,
)
txt.submit(predict, [txt, state], [chatbot, state]) txt.submit(predict, [txt, state], [chatbot, state])

View file

@ -5,26 +5,32 @@ from openai import OpenAI
import gradio as gr import gradio as gr
api_key = os.getenv("OPENAI_API_KEY") api_key = os.getenv("OPENAI_API_KEY")
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1") CHAT_COMPLETION_ENDPOINT = os.getenv(
"CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1"
)
client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT) client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT)
def predict(message, history): def predict(message, history):
history_openai_format = [] history_openai_format = []
for human, assistant in history: for human, assistant in history:
history_openai_format.append({"role": "user", "content": human }) history_openai_format.append({"role": "user", "content": human})
history_openai_format.append({"role": "assistant", "content":assistant}) history_openai_format.append({"role": "assistant", "content": assistant})
history_openai_format.append({"role": "user", "content": message}) history_openai_format.append({"role": "user", "content": message})
response = client.chat.completions.create(model='gpt-3.5-turbo', response = client.chat.completions.create(
messages= history_openai_format, model="gpt-3.5-turbo",
temperature=1.0, messages=history_openai_format,
stream=True) temperature=1.0,
stream=True,
)
partial_message = "" partial_message = ""
for chunk in response: for chunk in response:
if chunk.choices[0].delta.content is not None: if chunk.choices[0].delta.content is not None:
partial_message = partial_message + chunk.choices[0].delta.content partial_message = partial_message + chunk.choices[0].delta.content
yield partial_message yield partial_message
gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8081) gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8081)

View file

@ -6,56 +6,54 @@ import logging
from pydantic import BaseModel from pydantic import BaseModel
logger = logging.getLogger('uvicorn.error') logger = logging.getLogger("uvicorn.error")
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
app = FastAPI() app = FastAPI()
@app.get("/healthz") @app.get("/healthz")
async def healthz(): async def healthz():
return { return {"status": "ok"}
"status": "ok"
}
class WeatherRequest(BaseModel): class WeatherRequest(BaseModel):
city: str city: str
days: int = 7 days: int = 7
units: str = "Farenheit" units: str = "Farenheit"
@app.post("/weather") @app.post("/weather")
async def weather(req: WeatherRequest, res: Response): async def weather(req: WeatherRequest, res: Response):
weather_forecast = { weather_forecast = {
"city": req.city, "city": req.city,
"temperature": [], "temperature": [],
"units": req.units, "units": req.units,
} }
for i in range(7): for i in range(7):
min_temp = random.randrange(50,90) min_temp = random.randrange(50, 90)
max_temp = random.randrange(min_temp+5, min_temp+20) max_temp = random.randrange(min_temp + 5, min_temp + 20)
if req.units.lower() == "celsius" or req.units.lower() == "c": if req.units.lower() == "celsius" or req.units.lower() == "c":
min_temp = (min_temp - 32) * 5.0/9.0 min_temp = (min_temp - 32) * 5.0 / 9.0
max_temp = (max_temp - 32) * 5.0/9.0 max_temp = (max_temp - 32) * 5.0 / 9.0
weather_forecast["temperature"].append({ weather_forecast["temperature"].append(
"date": str(date.today() + timedelta(days=i)), {
"temperature": { "date": str(date.today() + timedelta(days=i)),
"min": min_temp, "temperature": {"min": min_temp, "max": max_temp},
"max": max_temp "units": req.units,
}, "query_time": str(datetime.now(timezone.utc)),
"units": req.units, }
"query_time": str(datetime.now(timezone.utc)) )
})
return weather_forecast return weather_forecast
class InsuranceClaimDetailsRequest(BaseModel): class InsuranceClaimDetailsRequest(BaseModel):
policy_number: str policy_number: str
@app.post("/insurance_claim_details") @app.post("/insurance_claim_details")
async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response): async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response):
claim_details = { claim_details = {
"policy_number": req.policy_number, "policy_number": req.policy_number,
"claim_status": "Approved", "claim_status": "Approved",
@ -68,26 +66,25 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
class DefaultTargetRequest(BaseModel): class DefaultTargetRequest(BaseModel):
arch_messages: list arch_messages: list
@app.post("/default_target") @app.post("/default_target")
async def default_target(req: DefaultTargetRequest, res: Response): async def default_target(req: DefaultTargetRequest, res: Response):
logger.info(f"Received arch_messages: {req.arch_messages}") logger.info(f"Received arch_messages: {req.arch_messages}")
resp = { resp = {
"choices": [ "choices": [
{ {
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "hello world from api server" "content": "hello world from api server",
}, },
"finish_reason": "completed", "finish_reason": "completed",
"index": 0 "index": 0,
}
],
"model": "api_server",
"usage": {
"completion_tokens": 0
} }
} ],
"model": "api_server",
"usage": {"completion_tokens": 0},
}
logger.info(f"sending response: {json.dumps(resp)}") logger.info(f"sending response: {json.dumps(resp)}")
return resp return resp

View file

@ -3,28 +3,40 @@ from pydantic import BaseModel, Field
app = FastAPI() app = FastAPI()
class Conversation(BaseModel): class Conversation(BaseModel):
arch_messages: list arch_messages: list
class PolicyCoverageRequest(BaseModel): class PolicyCoverageRequest(BaseModel):
policy_type: str = Field(..., description="The type of a policy held by the customer For, e.g. car, boat, house, motorcycle)") policy_type: str = Field(
...,
description="The type of a policy held by the customer For, e.g. car, boat, house, motorcycle)",
)
class PolicyInitiateRequest(PolicyCoverageRequest): class PolicyInitiateRequest(PolicyCoverageRequest):
deductible: float = Field(..., description="The deductible amount set of the policy") deductible: float = Field(
..., description="The deductible amount set of the policy"
)
class ClaimUpdate(BaseModel): class ClaimUpdate(BaseModel):
claim_id: str claim_id: str
notes: str # Status or details of the claim notes: str # Status or details of the claim
class DeductibleUpdate(BaseModel): class DeductibleUpdate(BaseModel):
policy_id: str policy_id: str
deductible: float deductible: float
class CoverageResponse(BaseModel): class CoverageResponse(BaseModel):
policy_type: str policy_type: str
coverage: str # Description of coverage coverage: str # Description of coverage
premium: float # The premium cost premium: float # The premium cost
# Get information about policy coverage # Get information about policy coverage
@app.post("/policy/coverage", response_model=CoverageResponse) @app.post("/policy/coverage", response_model=CoverageResponse)
async def get_policy_coverage(req: PolicyCoverageRequest): async def get_policy_coverage(req: PolicyCoverageRequest):
@ -32,10 +44,22 @@ async def get_policy_coverage(req: PolicyCoverageRequest):
Retrieve the coverage details for a given policy type (car, boat, house, motorcycle). Retrieve the coverage details for a given policy type (car, boat, house, motorcycle).
""" """
policy_coverage = { policy_coverage = {
"car": {"coverage": "Full car coverage with collision, liability", "premium": 500.0}, "car": {
"boat": {"coverage": "Full boat coverage including theft and storm damage", "premium": 700.0}, "coverage": "Full car coverage with collision, liability",
"house": {"coverage": "Full house coverage including fire, theft, flood", "premium": 1000.0}, "premium": 500.0,
"motorcycle": {"coverage": "Full motorcycle coverage with liability", "premium": 400.0}, },
"boat": {
"coverage": "Full boat coverage including theft and storm damage",
"premium": 700.0,
},
"house": {
"coverage": "Full house coverage including fire, theft, flood",
"premium": 1000.0,
},
"motorcycle": {
"coverage": "Full motorcycle coverage with liability",
"premium": 400.0,
},
} }
if req.policy_type not in policy_coverage: if req.policy_type not in policy_coverage:
@ -44,9 +68,10 @@ async def get_policy_coverage(req: PolicyCoverageRequest):
return CoverageResponse( return CoverageResponse(
policy_type=req.policy_type, policy_type=req.policy_type,
coverage=policy_coverage[req.policy_type]["coverage"], coverage=policy_coverage[req.policy_type]["coverage"],
premium=policy_coverage[req.policy_type]["premium"] premium=policy_coverage[req.policy_type]["premium"],
) )
# Initiate policy coverage # Initiate policy coverage
@app.post("/policy/initiate") @app.post("/policy/initiate")
async def initiate_policy(policy_request: PolicyInitiateRequest): async def initiate_policy(policy_request: PolicyInitiateRequest):
@ -56,7 +81,11 @@ async def initiate_policy(policy_request: PolicyInitiateRequest):
if policy_request.policy_type not in ["car", "boat", "house", "motorcycle"]: if policy_request.policy_type not in ["car", "boat", "house", "motorcycle"]:
raise HTTPException(status_code=400, detail="Invalid policy type") raise HTTPException(status_code=400, detail="Invalid policy type")
return {"message": f"Policy initiated for {policy_request.policy_type}", "deductible": policy_request.deductible} return {
"message": f"Policy initiated for {policy_request.policy_type}",
"deductible": policy_request.deductible,
}
# Update claim details # Update claim details
@app.post("/policy/claim") @app.post("/policy/claim")
@ -65,8 +94,11 @@ async def update_claim(req: ClaimUpdate):
Update the status or details of a claim. Update the status or details of a claim.
""" """
# For simplicity, this is a mock update response # For simplicity, this is a mock update response
return {"message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated", return {
"update": claim_update.notes} "message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated",
"update": claim_update.notes,
}
# Update deductible amount # Update deductible amount
@app.post("/policy/deductible") @app.post("/policy/deductible")
@ -75,8 +107,11 @@ async def update_deductible(deductible_update: DeductibleUpdate):
Update the deductible amount for a specific policy. Update the deductible amount for a specific policy.
""" """
# For simplicity, this is a mock update response # For simplicity, this is a mock update response
return {"message": f"Deductible for policy {deductible_update.policy_id} has been updated", return {
"new_deductible": deductible_update.deductible} "message": f"Deductible for policy {deductible_update.policy_id} has been updated",
"new_deductible": deductible_update.deductible,
}
# Post method for policy Q/A # Post method for policy Q/A
@app.post("/policy/qa") @app.post("/policy/qa")
@ -86,21 +121,20 @@ async def policy_qa(conversation: Conversation):
It forwards the conversation to the OpenAI client via a local proxy and returns the response. It forwards the conversation to the OpenAI client via a local proxy and returns the response.
""" """
return { return {
"choices": [ "choices": [
{ {
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "I am a helpful insurance agent, and can only help with insurance things" "content": "I am a helpful insurance agent, and can only help with insurance things",
}, },
"finish_reason": "completed", "finish_reason": "completed",
"index": 0 "index": 0,
}
],
"model": "insurance_agent",
"usage": {
"completion_tokens": 0
} }
} ],
"model": "insurance_agent",
"usage": {"completion_tokens": 0},
}
# Run the app using: # Run the app using:
# uvicorn main:app --reload # uvicorn main:app --reload

View file

@ -4,10 +4,14 @@ from typing import List, Optional
app = FastAPI() app = FastAPI()
# Define the request model # Define the request model
class DeviceSummaryRequest(BaseModel): class DeviceSummaryRequest(BaseModel):
device_ids: List[int] device_ids: List[int]
time_range: Optional[int] = Field(default=7, description="Time range in days, defaults to 7") time_range: Optional[int] = Field(
default=7, description="Time range in days, defaults to 7"
)
# Define the response model # Define the response model
class DeviceStatistics(BaseModel): class DeviceStatistics(BaseModel):
@ -15,18 +19,23 @@ class DeviceStatistics(BaseModel):
time_range: str time_range: str
data: str data: str
class DeviceSummaryResponse(BaseModel): class DeviceSummaryResponse(BaseModel):
statistics: List[DeviceStatistics] statistics: List[DeviceStatistics]
# Request model for device reboot # Request model for device reboot
class DeviceRebootRequest(BaseModel): class DeviceRebootRequest(BaseModel):
device_ids: List[int] device_ids: List[int]
# Response model for the device reboot # Response model for the device reboot
class CoverageResponse(BaseModel): class CoverageResponse(BaseModel):
status: str status: str
summary: dict summary: dict
@app.post("/agent/device_reboot", response_model=CoverageResponse) @app.post("/agent/device_reboot", response_model=CoverageResponse)
def reboot_network_device(request_data: DeviceRebootRequest): def reboot_network_device(request_data: DeviceRebootRequest):
""" """
@ -38,20 +47,21 @@ def reboot_network_device(request_data: DeviceRebootRequest):
# Validate 'device_ids' (This is already validated by Pydantic, but additional logic can be added if needed) # Validate 'device_ids' (This is already validated by Pydantic, but additional logic can be added if needed)
if not device_ids: if not device_ids:
raise HTTPException(status_code=400, detail="'device_ids' parameter is required") raise HTTPException(
status_code=400, detail="'device_ids' parameter is required"
)
# Simulate reboot operation and return the response # Simulate reboot operation and return the response
statistics = [] statistics = []
for device_id in device_ids: for device_id in device_ids:
# Placeholder for actual data retrieval or device reboot logic # Placeholder for actual data retrieval or device reboot logic
stats = { stats = {"data": f"Device {device_id} has been successfully rebooted."}
"data": f"Device {device_id} has been successfully rebooted."
}
statistics.append(stats) statistics.append(stats)
# Return the response with a summary # Return the response with a summary
return CoverageResponse(status="success", summary={"device_ids": device_ids}) return CoverageResponse(status="success", summary={"device_ids": device_ids})
# Post method for device summary # Post method for device summary
@app.post("/agent/device_summary", response_model=DeviceSummaryResponse) @app.post("/agent/device_summary", response_model=DeviceSummaryResponse)
def get_device_summary(request: DeviceSummaryRequest): def get_device_summary(request: DeviceSummaryRequest):
@ -77,6 +87,7 @@ def get_device_summary(request: DeviceSummaryRequest):
return DeviceSummaryResponse(statistics=statistics) return DeviceSummaryResponse(statistics=statistics)
@app.post("/agent/network_summary") @app.post("/agent/network_summary")
async def policy_qa(): async def policy_qa():
""" """
@ -84,21 +95,20 @@ async def policy_qa():
It forwards the conversation to the OpenAI client via a local proxy and returns the response. It forwards the conversation to the OpenAI client via a local proxy and returns the response.
""" """
return { return {
"choices": [ "choices": [
{ {
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them" "content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them",
}, },
"finish_reason": "completed", "finish_reason": "completed",
"index": 0 "index": 0,
}
],
"model": "network_agent",
"usage": {
"completion_tokens": 0
} }
} ],
"model": "network_agent",
"usage": {"completion_tokens": 0},
}
if __name__ == "__main__": if __name__ == "__main__":
app.run(debug=True) app.run(debug=True)

View file

@ -11,6 +11,7 @@ logging.basicConfig(
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def load_sql(): def load_sql():
# Example Usage # Example Usage
conn = sqlite3.connect(":memory:") conn = sqlite3.connect(":memory:")
@ -26,6 +27,7 @@ def load_sql():
return conn return conn
# Function to convert natural language time expressions to "X {time} ago" format # Function to convert natural language time expressions to "X {time} ago" format
def convert_to_ago_format(expression): def convert_to_ago_format(expression):
# Define patterns for different time units # Define patterns for different time units

View file

@ -13,9 +13,10 @@ def get_device_summary():
# Validate 'device_ids' parameter # Validate 'device_ids' parameter
device_ids = data.get("device_ids") device_ids = data.get("device_ids")
if not device_ids or not isinstance(device_ids, list): if not device_ids or not isinstance(device_ids, list):
return jsonify( return (
{"error": "'device_ids' parameter is required and must be a list"} jsonify({"error": "'device_ids' parameter is required and must be a list"}),
), 400 400,
)
# Validate 'time_range' parameter (optional, defaults to 7) # Validate 'time_range' parameter (optional, defaults to 7)
time_range = data.get("time_range", 7) time_range = data.get("time_range", 7)

View file

@ -119,9 +119,10 @@ def process_rag():
intent_changed = True intent_changed = True
else: else:
# Invalid value provided # Invalid value provided
return jsonify( return (
{"error": "Invalid value for x-arch-prompt-intent-change header"} jsonify({"error": "Invalid value for x-arch-prompt-intent-change header"}),
), 400 400,
)
# Update user conversation based on intent change # Update user conversation based on intent change
memory = update_user_conversation(user_id, client_messages, intent_changed) memory = update_user_conversation(user_id, client_messages, intent_changed)

View file

@ -13,9 +13,10 @@ def get_device_summary():
# Validate 'device_ids' parameter # Validate 'device_ids' parameter
device_ids = data.get("device_ids") device_ids = data.get("device_ids")
if not device_ids or not isinstance(device_ids, list): if not device_ids or not isinstance(device_ids, list):
return jsonify( return (
{"error": "'device_ids' parameter is required and must be a list"} jsonify({"error": "'device_ids' parameter is required and must be a list"}),
), 400 400,
)
# Validate 'time_range' parameter (optional, defaults to 7) # Validate 'time_range' parameter (optional, defaults to 7)
time_range = data.get("time_range", 7) time_range = data.get("time_range", 7)

View file

@ -12,16 +12,16 @@ from sphinx.util.docfields import Field
from sphinxawesome_theme import ThemeOptions from sphinxawesome_theme import ThemeOptions
from sphinxawesome_theme.postprocess import Icons from sphinxawesome_theme.postprocess import Icons
project = 'Arch Docs' project = "Arch Docs"
copyright = '2024, Katanemo Labs, Inc' copyright = "2024, Katanemo Labs, Inc"
author = 'Katanemo Labs, Inc' author = "Katanemo Labs, Inc"
release = ' v0.1' release = " v0.1"
# -- General configuration --------------------------------------------------- # -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
root_doc = 'index' root_doc = "index"
nitpicky = True nitpicky = True
add_module_names = False add_module_names = False
@ -33,23 +33,23 @@ extensions = [
"sphinx.ext.extlinks", "sphinx.ext.extlinks",
"sphinx.ext.viewcode", "sphinx.ext.viewcode",
"sphinx_sitemap", "sphinx_sitemap",
"sphinx_design" "sphinx_design",
] ]
# Paths that contain templates, relative to this directory. # Paths that contain templates, relative to this directory.
templates_path = ['_templates'] templates_path = ["_templates"]
# List of patterns, relative to source directory, that match files and directories # List of patterns, relative to source directory, that match files and directories
# to ignore when looking for source files. # to ignore when looking for source files.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
# -- Options for HTML output ------------------------------------------------- # -- Options for HTML output -------------------------------------------------
html_theme = 'sphinxawesome_theme' # You can change the theme to 'sphinx_rtd_theme' or another of your choice. html_theme = "sphinxawesome_theme" # You can change the theme to 'sphinx_rtd_theme' or another of your choice.
html_title = project + release html_title = project + release
html_permalinks_icon = Icons.permalinks_icon html_permalinks_icon = Icons.permalinks_icon
html_favicon = '_static/favicon.ico' html_favicon = "_static/favicon.ico"
html_logo = '_static/favicon.ico' # Specify the path to the logo image file (make sure the logo is in the _static directory) html_logo = "_static/favicon.ico" # Specify the path to the logo image file (make sure the logo is in the _static directory)
html_last_updated_fmt = "" html_last_updated_fmt = ""
html_use_index = False # Don't create index html_use_index = False # Don't create index
html_domain_indices = False # Don't need module indices html_domain_indices = False # Don't need module indices
@ -57,10 +57,14 @@ html_copy_source = False # Don't need sources
html_show_sphinx = False html_show_sphinx = False
html_baseurl = './docs' html_baseurl = "./docs"
html_sidebars = { html_sidebars = {
"**": ['analytics.html', "sidebar_main_nav_links.html", "sidebar_toc.html", ] "**": [
"analytics.html",
"sidebar_main_nav_links.html",
"sidebar_toc.html",
]
} }
theme_options = ThemeOptions( theme_options = ThemeOptions(
@ -102,7 +106,7 @@ html_theme_options = asdict(theme_options)
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ["_static"]
pygments_style = "lovelace" pygments_style = "lovelace"
pygments_style_dark = "github-dark" pygments_style_dark = "github-dark"
@ -111,10 +115,11 @@ sitemap_url_scheme = "{link}"
# Add this configuration at the bottom of your conf.py # Add this configuration at the bottom of your conf.py
html_context = { html_context = {
'google_analytics_id': 'G-K2LXXSX6HB', # Replace with your Google Analytics tracking ID "google_analytics_id": "G-K2LXXSX6HB", # Replace with your Google Analytics tracking ID
} }
templates_path = ['_templates'] templates_path = ["_templates"]
# -- Register a :confval: interpreted text role ---------------------------------- # -- Register a :confval: interpreted text role ----------------------------------
def setup(app: Sphinx) -> None: def setup(app: Sphinx) -> None:
@ -138,4 +143,4 @@ def setup(app: Sphinx) -> None:
], ],
) )
app.add_css_file('_static/custom.css') app.add_css_file("_static/custom.css")

View file

@ -35,11 +35,13 @@ def start_server():
print("Server is already running. Use 'model_server restart' to restart it.") print("Server is already running. Use 'model_server restart' to restart it.")
sys.exit(1) sys.exit(1)
print(f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)") print(
f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)"
)
process = subprocess.Popen( process = subprocess.Popen(
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"], ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
start_new_session=True, start_new_session=True,
stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
stderr=subprocess.DEVNULL, # Suppress standard error. There is a logger that model_server prints to stderr=subprocess.DEVNULL, # Suppress standard error. There is a logger that model_server prints to
) )

View file

@ -7,7 +7,6 @@ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenc
def get_device(): def get_device():
if torch.cuda.is_available(): if torch.cuda.is_available():
device = "cuda" device = "cuda"
elif torch.backends.mps.is_available(): elif torch.backends.mps.is_available():
@ -19,13 +18,15 @@ def get_device():
return device return device
def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")): def load_transformers(
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")
):
print("Loading Embedding Model") print("Loading Embedding Model")
transformers = {} transformers = {}
device = get_device() device = get_device()
transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name) transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
transformers["model"] = ORTModelForFeatureExtraction.from_pretrained( transformers["model"] = ORTModelForFeatureExtraction.from_pretrained(
model_name, device_map = device model_name, device_map=device
) )
transformers["model_name"] = model_name transformers["model_name"] = model_name
@ -62,7 +63,9 @@ def load_guard_model(
return guard_model return guard_model
def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")): def load_zero_shot_models(
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
):
zero_shot_model = {} zero_shot_model = {}
device = get_device() device = get_device()
zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained( zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
@ -81,5 +84,6 @@ def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deb
return zero_shot_model return zero_shot_model
if __name__ == "__main__": if __name__ == "__main__":
print(get_device()) print(get_device())

View file

@ -7,7 +7,12 @@ from app.load_models import (
get_device, get_device,
) )
import os import os
from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config, get_model_server_logger from app.utils import (
GuardHandler,
split_text_into_chunks,
load_yaml_config,
get_model_server_logger,
)
import torch import torch
import yaml import yaml
import string import string
@ -39,6 +44,7 @@ guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
app = FastAPI() app = FastAPI()
class EmbeddingRequest(BaseModel): class EmbeddingRequest(BaseModel):
input: str input: str
model: str model: str
@ -84,6 +90,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
} }
return {"data": data, "model": req.model, "object": "list", "usage": usage} return {"data": data, "model": req.model, "object": "list", "usage": usage}
class GuardRequest(BaseModel): class GuardRequest(BaseModel):
input: str input: str
task: str task: str

View file

@ -9,6 +9,7 @@ import logging
logger_instance = None logger_instance = None
def load_yaml_config(file_name): def load_yaml_config(file_name):
# Load the YAML file from the package # Load the YAML file from the package
yaml_path = pkg_resources.resource_filename("app", file_name) yaml_path = pkg_resources.resource_filename("app", file_name)
@ -138,6 +139,7 @@ class GuardHandler:
} }
return result_dict return result_dict
def get_model_server_logger(): def get_model_server_logger():
global logger_instance global logger_instance
@ -164,8 +166,8 @@ def get_model_server_logger():
level=logging.INFO, level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s", format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[ handlers=[
logging.FileHandler(log_file_path, mode='w'), # Overwrite logs in file logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
] ],
) )
except (PermissionError, OSError) as e: except (PermissionError, OSError) as e:
# Dont' fallback to console logging if there are issues writing to the log file # Dont' fallback to console logging if there are issues writing to the log file