mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
lint + formating with black (#158)
* lint + formating with black * add black as pre commit
This commit is contained in:
parent
498e7f9724
commit
5c4a6bc8ff
22 changed files with 581 additions and 295 deletions
|
|
@ -25,3 +25,8 @@ repos:
|
|||
# --lib is to only test the library, since when integration tests are made,
|
||||
# they will be in a separate tests directory
|
||||
entry: bash -c "cd arch && cargo test -p intelligent-prompt-gateway --lib"
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
language_version: python3
|
||||
|
|
|
|||
|
|
@ -16,18 +16,22 @@ logo = r"""
|
|||
/_/ \_\|_| \___||_| |_|
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@click.group(invoke_without_command=True)
|
||||
@click.pass_context
|
||||
def main(ctx):
|
||||
if ctx.invoked_subcommand is None:
|
||||
click.echo( """Arch (The Intelligent Prompt Gateway) CLI""")
|
||||
click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
|
||||
click.echo(logo)
|
||||
click.echo(ctx.get_help())
|
||||
|
||||
|
||||
# Command to build archgw and model_server Docker images
|
||||
ARCHGW_DOCKERFILE = "./arch/Dockerfile"
|
||||
MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"
|
||||
|
||||
|
||||
@click.command()
|
||||
def build():
|
||||
"""Build Arch from source. Must be in root of cloned repo."""
|
||||
|
|
@ -35,7 +39,18 @@ def build():
|
|||
if os.path.exists(ARCHGW_DOCKERFILE):
|
||||
click.echo("Building archgw image...")
|
||||
try:
|
||||
subprocess.run(["docker", "build", "-f", ARCHGW_DOCKERFILE, "-t", "archgw:latest", "."], check=True)
|
||||
subprocess.run(
|
||||
[
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
ARCHGW_DOCKERFILE,
|
||||
"-t",
|
||||
"archgw:latest",
|
||||
".",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
click.echo("archgw image built successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"Error building archgw image: {e}")
|
||||
|
|
@ -51,7 +66,11 @@ def build():
|
|||
if os.path.exists(MODEL_SERVER_BUILD_FILE):
|
||||
click.echo("Installing model server dependencies with Poetry...")
|
||||
try:
|
||||
subprocess.run(["poetry", "install", "--no-cache"], cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE), check=True)
|
||||
subprocess.run(
|
||||
["poetry", "install", "--no-cache"],
|
||||
cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE),
|
||||
check=True,
|
||||
)
|
||||
click.echo("Model server dependencies installed successfully.")
|
||||
except subprocess.CalledProcessError as e:
|
||||
click.echo(f"Error installing model server dependencies: {e}")
|
||||
|
|
@ -60,9 +79,12 @@ def build():
|
|||
click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument('file', required=False) # Optional file argument
|
||||
@click.option('-path', default='.', help='Path to the directory containing arch_config.yml')
|
||||
@click.argument("file", required=False) # Optional file argument
|
||||
@click.option(
|
||||
"-path", default=".", help="Path to the directory containing arch_config.yml"
|
||||
)
|
||||
def up(file, path):
|
||||
"""Starts Arch."""
|
||||
if file:
|
||||
|
|
@ -78,10 +100,15 @@ def up(file, path):
|
|||
return
|
||||
|
||||
print(f"Validating {arch_config_file}")
|
||||
arch_schema_config = pkg_resources.resource_filename(__name__, "config/arch_config_schema.yaml")
|
||||
arch_schema_config = pkg_resources.resource_filename(
|
||||
__name__, "config/arch_config_schema.yaml"
|
||||
)
|
||||
|
||||
try:
|
||||
config_generator.validate_prompt_config(arch_config_file=arch_config_file, arch_config_schema_file=arch_schema_config)
|
||||
config_generator.validate_prompt_config(
|
||||
arch_config_file=arch_config_file,
|
||||
arch_config_schema_file=arch_schema_config,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Exiting archgw up")
|
||||
sys.exit(1)
|
||||
|
|
@ -91,52 +118,67 @@ def up(file, path):
|
|||
# Set the ARCH_CONFIG_FILE environment variable
|
||||
env_stage = {}
|
||||
env = os.environ.copy()
|
||||
#check if access_keys are preesnt in the config file
|
||||
# check if access_keys are present in the config file
|
||||
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
|
||||
if access_keys:
|
||||
if file:
|
||||
app_env_file = os.path.join(os.path.dirname(os.path.abspath(file)), ".env") #check the .env file in the path
|
||||
app_env_file = os.path.join(
|
||||
os.path.dirname(os.path.abspath(file)), ".env"
|
||||
) # check the .env file in the path
|
||||
else:
|
||||
app_env_file = os.path.abspath(os.path.join(path, ".env"))
|
||||
|
||||
if not os.path.exists(app_env_file): #check to see if the environment variables in the current environment or not
|
||||
if not os.path.exists(
|
||||
app_env_file
|
||||
): # check to see if the environment variables in the current environment or not
|
||||
for access_key in access_keys:
|
||||
if env.get(access_key) is None:
|
||||
print (f"Access Key: {access_key} not found. Exiting Start")
|
||||
print(f"Access Key: {access_key} not found. Exiting Start")
|
||||
sys.exit(1)
|
||||
else:
|
||||
env_stage[access_key] = env.get(access_key)
|
||||
else: #.env file exists, use that to send parameters to Arch
|
||||
else: # .env file exists, use that to send parameters to Arch
|
||||
env_file_dict = load_env_file_to_dict(app_env_file)
|
||||
for access_key in access_keys:
|
||||
if env_file_dict.get(access_key) is None:
|
||||
print (f"Access Key: {access_key} not found. Exiting Start")
|
||||
print(f"Access Key: {access_key} not found. Exiting Start")
|
||||
sys.exit(1)
|
||||
else:
|
||||
env_stage[access_key] = env_file_dict[access_key]
|
||||
|
||||
with open(pkg_resources.resource_filename(__name__, "config/stage.env"), 'w') as file:
|
||||
with open(
|
||||
pkg_resources.resource_filename(__name__, "config/stage.env"), "w"
|
||||
) as file:
|
||||
for key, value in env_stage.items():
|
||||
file.write(f"{key}={value}\n")
|
||||
|
||||
env.update(env_stage)
|
||||
env['ARCH_CONFIG_FILE'] = arch_config_file
|
||||
env["ARCH_CONFIG_FILE"] = arch_config_file
|
||||
|
||||
start_arch_modelserver()
|
||||
start_arch(arch_config_file, env)
|
||||
|
||||
|
||||
@click.command()
|
||||
def down():
|
||||
"""Stops Arch."""
|
||||
stop_arch_modelserver()
|
||||
stop_arch()
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('-f', '--file', type=click.Path(exists=True), required=True, help="Path to the Python file")
|
||||
@click.option(
|
||||
"-f",
|
||||
"--file",
|
||||
type=click.Path(exists=True),
|
||||
required=True,
|
||||
help="Path to the Python file",
|
||||
)
|
||||
def generate_prompt_targets(file):
|
||||
"""Generats prompt_targets from python methods.
|
||||
Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
|
||||
If you have a complex pydantic data type, you will have to flatten those manually until we add support for it."""
|
||||
Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
|
||||
If you have a complex pydantic data type, you will have to flatten those manually until we add support for it.
|
||||
"""
|
||||
|
||||
print(f"Processing file: {file}")
|
||||
if not file.endswith(".py"):
|
||||
|
|
@ -145,10 +187,11 @@ def generate_prompt_targets(file):
|
|||
|
||||
targets.generate_prompt_targets(file)
|
||||
|
||||
|
||||
main.add_command(up)
|
||||
main.add_command(down)
|
||||
main.add_command(build)
|
||||
main.add_command(generate_prompt_targets)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -3,36 +3,44 @@ from jinja2 import Environment, FileSystemLoader
|
|||
import yaml
|
||||
from jsonschema import validate
|
||||
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
|
||||
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
|
||||
"ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
|
||||
)
|
||||
ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/config/arch_config.yaml")
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv(
|
||||
"ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
|
||||
)
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv(
|
||||
"ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
|
||||
)
|
||||
|
||||
def add_secret_key_to_llm_providers(config_yaml) :
|
||||
|
||||
def add_secret_key_to_llm_providers(config_yaml):
|
||||
llm_providers = []
|
||||
for llm_provider in config_yaml.get("llm_providers", []):
|
||||
access_key_env_var = llm_provider.get('access_key', False)
|
||||
access_key_env_var = llm_provider.get("access_key", False)
|
||||
access_key_value = os.getenv(access_key_env_var, False)
|
||||
if access_key_env_var and access_key_value:
|
||||
llm_provider['access_key'] = access_key_value
|
||||
llm_provider["access_key"] = access_key_value
|
||||
llm_providers.append(llm_provider)
|
||||
config_yaml["llm_providers"] = llm_providers
|
||||
return config_yaml
|
||||
|
||||
|
||||
def validate_and_render_schema():
|
||||
env = Environment(loader=FileSystemLoader('./'))
|
||||
template = env.get_template('envoy.template.yaml')
|
||||
env = Environment(loader=FileSystemLoader("./"))
|
||||
template = env.get_template("envoy.template.yaml")
|
||||
|
||||
try:
|
||||
validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
exit(1) # validate_prompt_config failed. Exit
|
||||
exit(1) # validate_prompt_config failed. Exit
|
||||
|
||||
with open(ARCH_CONFIG_FILE, 'r') as file:
|
||||
with open(ARCH_CONFIG_FILE, "r") as file:
|
||||
arch_config = file.read()
|
||||
|
||||
with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
|
||||
with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
config_yaml = yaml.safe_load(arch_config)
|
||||
|
|
@ -44,7 +52,7 @@ def validate_and_render_schema():
|
|||
if name not in inferred_clusters:
|
||||
inferred_clusters[name] = {
|
||||
"name": name,
|
||||
"port": 80, # default port
|
||||
"port": 80, # default port
|
||||
}
|
||||
|
||||
print(inferred_clusters)
|
||||
|
|
@ -55,14 +63,13 @@ def validate_and_render_schema():
|
|||
if name in inferred_clusters:
|
||||
print("updating cluster", endpoint_details)
|
||||
inferred_clusters[name].update(endpoint_details)
|
||||
endpoint = inferred_clusters[name]['endpoint']
|
||||
if len(endpoint.split(':')) > 1:
|
||||
inferred_clusters[name]['endpoint'] = endpoint.split(':')[0]
|
||||
inferred_clusters[name]['port'] = int(endpoint.split(':')[1])
|
||||
endpoint = inferred_clusters[name]["endpoint"]
|
||||
if len(endpoint.split(":")) > 1:
|
||||
inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
|
||||
inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
|
||||
else:
|
||||
inferred_clusters[name] = endpoint_details
|
||||
|
||||
|
||||
print("updated clusters", inferred_clusters)
|
||||
|
||||
config_yaml = add_secret_key_to_llm_providers(config_yaml)
|
||||
|
|
@ -71,23 +78,24 @@ def validate_and_render_schema():
|
|||
arch_config_string = yaml.dump(config_yaml)
|
||||
|
||||
data = {
|
||||
'arch_config': arch_config_string,
|
||||
'arch_clusters': inferred_clusters,
|
||||
'arch_llm_providers': arch_llm_providers,
|
||||
'arch_tracing': arch_tracing
|
||||
"arch_config": arch_config_string,
|
||||
"arch_clusters": inferred_clusters,
|
||||
"arch_llm_providers": arch_llm_providers,
|
||||
"arch_tracing": arch_tracing,
|
||||
}
|
||||
|
||||
rendered = template.render(data)
|
||||
print(rendered)
|
||||
print(ENVOY_CONFIG_FILE_RENDERED)
|
||||
with open(ENVOY_CONFIG_FILE_RENDERED, 'w') as file:
|
||||
with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
|
||||
file.write(rendered)
|
||||
|
||||
|
||||
def validate_prompt_config(arch_config_file, arch_config_schema_file):
|
||||
with open(arch_config_file, 'r') as file:
|
||||
with open(arch_config_file, "r") as file:
|
||||
arch_config = file.read()
|
||||
|
||||
with open(arch_config_schema_file, 'r') as file:
|
||||
with open(arch_config_schema_file, "r") as file:
|
||||
arch_config_schema = file.read()
|
||||
|
||||
config_yaml = yaml.safe_load(arch_config)
|
||||
|
|
@ -96,8 +104,11 @@ def validate_prompt_config(arch_config_file, arch_config_schema_file):
|
|||
try:
|
||||
validate(config_yaml, config_schema_yaml)
|
||||
except Exception as e:
|
||||
print(f"Error validating arch_config file: {arch_config_file}, error: {e.message}")
|
||||
print(
|
||||
f"Error validating arch_config file: {arch_config_file}, error: {e.message}"
|
||||
)
|
||||
raise e
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
validate_and_render_schema()
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import pkg_resources
|
|||
import select
|
||||
from utils import run_docker_compose_ps, print_service_status, check_services_state
|
||||
|
||||
|
||||
def start_arch(arch_config_file, env, log_timeout=120):
|
||||
"""
|
||||
Start Docker Compose in detached mode and stream logs until services are healthy.
|
||||
|
|
@ -14,22 +15,35 @@ def start_arch(arch_config_file, env, log_timeout=120):
|
|||
log_timeout (int): Time in seconds to show logs before checking for healthy state.
|
||||
"""
|
||||
|
||||
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
|
||||
compose_file = pkg_resources.resource_filename(
|
||||
__name__, "config/docker-compose.yaml"
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the Docker Compose command in detached mode (-d)
|
||||
subprocess.run(
|
||||
["docker", "compose", "-p", "arch", "up", "-d",],
|
||||
cwd=os.path.dirname(compose_file), # Ensure the Docker command runs in the correct path
|
||||
env=env, # Pass the modified environment
|
||||
check=True # Raise an exception if the command fails
|
||||
[
|
||||
"docker",
|
||||
"compose",
|
||||
"-p",
|
||||
"arch",
|
||||
"up",
|
||||
"-d",
|
||||
],
|
||||
cwd=os.path.dirname(
|
||||
compose_file
|
||||
), # Ensure the Docker command runs in the correct path
|
||||
env=env, # Pass the modified environment
|
||||
check=True, # Raise an exception if the command fails
|
||||
)
|
||||
print(f"Arch docker-compose started in detached.")
|
||||
print("Monitoring `docker-compose ps` logs...")
|
||||
|
||||
start_time = time.time()
|
||||
services_status = {}
|
||||
services_running = False #assume that the services are not running at the moment
|
||||
services_running = (
|
||||
False # assume that the services are not running at the moment
|
||||
)
|
||||
|
||||
while True:
|
||||
current_time = time.time()
|
||||
|
|
@ -40,16 +54,22 @@ def start_arch(arch_config_file, env, log_timeout=120):
|
|||
print(f"Stopping log monitoring after {log_timeout} seconds.")
|
||||
break
|
||||
|
||||
current_services_status = run_docker_compose_ps(compose_file=compose_file, env=env)
|
||||
current_services_status = run_docker_compose_ps(
|
||||
compose_file=compose_file, env=env
|
||||
)
|
||||
if not current_services_status:
|
||||
print("Status for the services could not be detected. Something went wrong. Please run docker logs")
|
||||
print(
|
||||
"Status for the services could not be detected. Something went wrong. Please run docker logs"
|
||||
)
|
||||
break
|
||||
|
||||
if not services_status:
|
||||
services_status = current_services_status #set the first time
|
||||
print_service_status(services_status) #print the services status and proceed.
|
||||
services_status = current_services_status # set the first time
|
||||
print_service_status(
|
||||
services_status
|
||||
) # print the services status and proceed.
|
||||
|
||||
#check if anyone service is failed or exited state, if so print and break out
|
||||
# check if any service is in a failed or exited state; if so, print and break out
|
||||
unhealthy_states = ["unhealthy", "exit", "exited", "dead", "bad"]
|
||||
running_states = ["running", "up"]
|
||||
|
||||
|
|
@ -58,14 +78,23 @@ def start_arch(arch_config_file, env, log_timeout=120):
|
|||
break
|
||||
|
||||
if check_services_state(current_services_status, unhealthy_states):
|
||||
print("One or more Arch services are unhealthy. Please run `docker logs` for more information")
|
||||
print_service_status(current_services_status) #print the services status and proceed.
|
||||
print(
|
||||
"One or more Arch services are unhealthy. Please run `docker logs` for more information"
|
||||
)
|
||||
print_service_status(
|
||||
current_services_status
|
||||
) # print the services status and proceed.
|
||||
break
|
||||
|
||||
#check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error
|
||||
# check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error
|
||||
for service_name in services_status.keys():
|
||||
if services_status[service_name]['State'] != current_services_status[service_name]['State']:
|
||||
print("One or more Arch services have changed state. Printing current state")
|
||||
if (
|
||||
services_status[service_name]["State"]
|
||||
!= current_services_status[service_name]["State"]
|
||||
):
|
||||
print(
|
||||
"One or more Arch services have changed state. Printing current state"
|
||||
)
|
||||
print_service_status(current_services_status)
|
||||
break
|
||||
|
||||
|
|
@ -82,7 +111,9 @@ def stop_arch():
|
|||
Args:
|
||||
path (str): The path where the docker-compose.yml file is located.
|
||||
"""
|
||||
compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml')
|
||||
compose_file = pkg_resources.resource_filename(
|
||||
__name__, "config/docker-compose.yaml"
|
||||
)
|
||||
|
||||
try:
|
||||
# Run `docker-compose down` to shut down all services
|
||||
|
|
@ -96,6 +127,7 @@ def stop_arch():
|
|||
except subprocess.CalledProcessError as e:
|
||||
print(f"Failed to shut down services: {str(e)}")
|
||||
|
||||
|
||||
def start_arch_modelserver():
|
||||
"""
|
||||
Start the model server. This assumes that the archgw_modelserver package is installed locally
|
||||
|
|
@ -103,15 +135,14 @@ def start_arch_modelserver():
|
|||
"""
|
||||
try:
|
||||
subprocess.run(
|
||||
['archgw_modelserver', 'restart'],
|
||||
check=True,
|
||||
start_new_session=True
|
||||
["archgw_modelserver", "restart"], check=True, start_new_session=True
|
||||
)
|
||||
print("Successfull run the archgw model_server")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print (f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
print(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def stop_arch_modelserver():
|
||||
"""
|
||||
Stop the model server. This assumes that the archgw_modelserver package is installed locally
|
||||
|
|
@ -119,10 +150,10 @@ def stop_arch_modelserver():
|
|||
"""
|
||||
try:
|
||||
subprocess.run(
|
||||
['archgw_modelserver', 'stop'],
|
||||
["archgw_modelserver", "stop"],
|
||||
check=True,
|
||||
)
|
||||
print("Successfull stopped the archgw model_server")
|
||||
except subprocess.CalledProcessError as e:
|
||||
print (f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
print(f"Failed to start model_server. Please check archgw_modelserver logs")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -6,17 +6,29 @@ setup(
|
|||
description="Python-based CLI tool to manage Arch and generate targets.",
|
||||
author="Katanemo Labs, Inc.",
|
||||
packages=find_packages(),
|
||||
py_modules = ['cli', 'core', 'targets', 'utils', 'config_generator'],
|
||||
py_modules=["cli", "core", "targets", "utils", "config_generator"],
|
||||
include_package_data=True,
|
||||
# Specify to include the docker-compose.yml file
|
||||
package_data={
|
||||
'': ['config/docker-compose.yaml', 'config/arch_config_schema.yaml', 'config/stage.env'] #Specify to include the docker-compose.yml file
|
||||
"": [
|
||||
"config/docker-compose.yaml",
|
||||
"config/arch_config_schema.yaml",
|
||||
"config/stage.env",
|
||||
] # Specify to include the docker-compose.yml file
|
||||
},
|
||||
# Add dependencies here, e.g., 'PyYAML' for YAML processing
|
||||
install_requires=['pyyaml', 'pydantic', 'click', 'jinja2','pyyaml','jsonschema', 'setuptools'],
|
||||
install_requires=[
|
||||
"pyyaml",
|
||||
"pydantic",
|
||||
"click",
|
||||
"jinja2",
|
||||
"pyyaml",
|
||||
"jsonschema",
|
||||
"setuptools",
|
||||
],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'archgw=cli:main',
|
||||
"console_scripts": [
|
||||
"archgw=cli:main",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,14 +18,20 @@ def detect_framework(tree: Any) -> str:
|
|||
return "fastapi"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def get_route_decorators(node: Any, framework: str) -> list:
|
||||
"""Extract route decorators based on the framework."""
|
||||
decorators = []
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute):
|
||||
if isinstance(decorator, ast.Call) and isinstance(
|
||||
decorator.func, ast.Attribute
|
||||
):
|
||||
if framework == "flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS:
|
||||
decorators.append(decorator.func.attr)
|
||||
elif framework == "fastapi" and decorator.func.attr in FASTAPI_ROUTE_DECORATORS:
|
||||
elif (
|
||||
framework == "fastapi"
|
||||
and decorator.func.attr in FASTAPI_ROUTE_DECORATORS
|
||||
):
|
||||
decorators.append(decorator.func.attr)
|
||||
return decorators
|
||||
|
||||
|
|
@ -36,6 +42,7 @@ def get_route_path(node: Any, framework: str) -> str:
|
|||
if isinstance(decorator, ast.Call) and decorator.args:
|
||||
return decorator.args[0].s # Assuming it's a string literal
|
||||
|
||||
|
||||
def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
|
||||
"""Check if a given type annotation is a Pydantic model."""
|
||||
# We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel
|
||||
|
|
@ -47,6 +54,7 @@ def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
||||
"""Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values."""
|
||||
fields = []
|
||||
|
|
@ -62,15 +70,26 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
|||
required = True # Assume the field is required initially
|
||||
|
||||
# Check if the field uses Field() with required status and description
|
||||
if stmt.value and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == 'Field':
|
||||
if (
|
||||
stmt.value
|
||||
and isinstance(stmt.value, ast.Call)
|
||||
and isinstance(stmt.value.func, ast.Name)
|
||||
and stmt.value.func.id == "Field"
|
||||
):
|
||||
# Extract the description argument inside the Field call
|
||||
for keyword in stmt.value.keywords:
|
||||
if keyword.arg == 'description' and isinstance(keyword.value, ast.Str):
|
||||
if keyword.arg == "description" and isinstance(
|
||||
keyword.value, ast.Str
|
||||
):
|
||||
description = keyword.value.s
|
||||
if keyword.arg == 'default':
|
||||
if keyword.arg == "default":
|
||||
default_value = keyword.value
|
||||
# If Ellipsis (...) is used, it means the field is required
|
||||
if stmt.value.args and isinstance(stmt.value.args[0], ast.Constant) and stmt.value.args[0].value is Ellipsis:
|
||||
if (
|
||||
stmt.value.args
|
||||
and isinstance(stmt.value.args[0], ast.Constant)
|
||||
and stmt.value.args[0].value is Ellipsis
|
||||
):
|
||||
required = True
|
||||
else:
|
||||
required = False
|
||||
|
|
@ -80,19 +99,28 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
|||
if isinstance(stmt.value, ast.Constant):
|
||||
# Set the default value from the assignment (e.g., name: str = "John Doe")
|
||||
default_value = stmt.value.value
|
||||
required = False # Not required since it has a default value
|
||||
required = (
|
||||
False # Not required since it has a default value
|
||||
)
|
||||
|
||||
# Always extract the field type, even if there's a default value
|
||||
if isinstance(stmt.annotation, ast.Subscript):
|
||||
# Get the base type (list, tuple, set, dict)
|
||||
base_type = stmt.annotation.value.id if isinstance(stmt.annotation.value, ast.Name) else "Unknown"
|
||||
base_type = (
|
||||
stmt.annotation.value.id
|
||||
if isinstance(stmt.annotation.value, ast.Name)
|
||||
else "Unknown"
|
||||
)
|
||||
|
||||
# Handle only list, tuple, set, dict and ignore the inner types
|
||||
if base_type.lower() in ['list', 'tuple', 'set', 'dict']:
|
||||
if base_type.lower() in ["list", "tuple", "set", "dict"]:
|
||||
field_type = base_type.lower()
|
||||
|
||||
# Handle the ellipsis '...' for required fields if no Field() call
|
||||
elif isinstance(stmt.value, ast.Constant) and stmt.value.value is Ellipsis:
|
||||
elif (
|
||||
isinstance(stmt.value, ast.Constant)
|
||||
and stmt.value.value is Ellipsis
|
||||
):
|
||||
required = True
|
||||
|
||||
# Handle simple types like str, int, etc.
|
||||
|
|
@ -100,16 +128,17 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
|||
field_type = stmt.annotation.id
|
||||
|
||||
field_info = {
|
||||
"name": stmt.target.id,
|
||||
"type": field_type, # Always set the field type
|
||||
"description": description,
|
||||
"default": default_value, # Handle direct default values
|
||||
"required": required
|
||||
"name": stmt.target.id,
|
||||
"type": field_type, # Always set the field type
|
||||
"description": description,
|
||||
"default": default_value, # Handle direct default values
|
||||
"required": required,
|
||||
}
|
||||
fields.append(field_info)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
||||
"""Extract the parameters and their types from the function definition."""
|
||||
parameters = []
|
||||
|
|
@ -119,40 +148,68 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
|||
arg_descriptions = extract_arg_descriptions_from_docstring(docstring)
|
||||
|
||||
# Extract default values
|
||||
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + node.args.defaults # Align defaults with args
|
||||
defaults = [None] * (
|
||||
len(node.args.args) - len(node.args.defaults)
|
||||
) + node.args.defaults # Align defaults with args
|
||||
for arg, default in zip(node.args.args, defaults):
|
||||
if arg.arg != "self": # Skip 'self' or 'cls' in class methods
|
||||
param_info = {"name": arg.arg, "description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]")}
|
||||
param_info = {
|
||||
"name": arg.arg,
|
||||
"description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]"),
|
||||
}
|
||||
|
||||
# Handle Pydantic model types
|
||||
if hasattr(arg, 'annotation') and is_pydantic_model(arg.annotation, tree):
|
||||
if hasattr(arg, "annotation") and is_pydantic_model(arg.annotation, tree):
|
||||
# Extract and flatten Pydantic model fields
|
||||
pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree)
|
||||
parameters.extend(pydantic_fields) # Flatten the model fields into the parameters list
|
||||
parameters.extend(
|
||||
pydantic_fields
|
||||
) # Flatten the model fields into the parameters list
|
||||
continue # Skip adding the current param_info for the model since we expand the fields
|
||||
|
||||
# Handle standard Python types (int, float, str, etc.)
|
||||
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Name):
|
||||
if arg.annotation.id in ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
|
||||
elif hasattr(arg, "annotation") and isinstance(arg.annotation, ast.Name):
|
||||
if arg.annotation.id in [
|
||||
"int",
|
||||
"float",
|
||||
"bool",
|
||||
"str",
|
||||
"list",
|
||||
"tuple",
|
||||
"set",
|
||||
"dict",
|
||||
]:
|
||||
param_info["type"] = arg.annotation.id
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
|
||||
|
||||
# Handle generic subscript types (e.g., Optional, List[Type], etc.)
|
||||
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Subscript):
|
||||
if isinstance(arg.annotation.value, ast.Name) and arg.annotation.value.id in ['list', 'tuple', 'set', 'dict']:
|
||||
param_info["type"] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
|
||||
elif hasattr(arg, "annotation") and isinstance(
|
||||
arg.annotation, ast.Subscript
|
||||
):
|
||||
if isinstance(
|
||||
arg.annotation.value, ast.Name
|
||||
) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]:
|
||||
param_info[
|
||||
"type"
|
||||
] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
|
||||
|
||||
# Default for unknown types
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
|
||||
param_info[
|
||||
"type"
|
||||
] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
|
||||
|
||||
# Handle default values
|
||||
if default is not None:
|
||||
if isinstance(default, ast.Constant) or isinstance(default, ast.NameConstant):
|
||||
param_info["default"] = default.value # Use the default value directly
|
||||
if isinstance(default, ast.Constant) or isinstance(
|
||||
default, ast.NameConstant
|
||||
):
|
||||
param_info[
|
||||
"default"
|
||||
] = default.value # Use the default value directly
|
||||
else:
|
||||
param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type
|
||||
param_info["required"] = False # Optional since it has a default value
|
||||
|
|
@ -164,6 +221,7 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
|||
|
||||
return parameters
|
||||
|
||||
|
||||
def get_function_docstring(node: Any) -> str:
|
||||
"""Extract the function's docstring description if present."""
|
||||
# Check if the first node is a docstring
|
||||
|
|
@ -178,6 +236,7 @@ def get_function_docstring(node: Any) -> str:
|
|||
|
||||
return "No description provided."
|
||||
|
||||
|
||||
def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
|
||||
"""Extract descriptions for function parameters from the 'Args' section of the docstring."""
|
||||
descriptions = {}
|
||||
|
|
@ -195,14 +254,14 @@ def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
|
|||
continue # Proceed to the next line after 'Args:'
|
||||
|
||||
# End of 'Args' section if no indentation and no colon
|
||||
if in_args_section and not line.startswith(" ") and ':' not in line:
|
||||
if in_args_section and not line.startswith(" ") and ":" not in line:
|
||||
break # Stop processing if we reach a new section
|
||||
|
||||
# Process lines in the 'Args' section
|
||||
if in_args_section:
|
||||
if ':' in line:
|
||||
if ":" in line:
|
||||
# Extract parameter name and description
|
||||
param_name, description = line.split(':', 1)
|
||||
param_name, description = line.split(":", 1)
|
||||
descriptions[param_name.strip()] = description.strip()
|
||||
current_param = param_name.strip()
|
||||
elif current_param and line.startswith(" "):
|
||||
|
|
@ -230,43 +289,50 @@ def generate_prompt_targets(input_file_path: str) -> None:
|
|||
route_decorators = get_route_decorators(node, framework)
|
||||
if route_decorators:
|
||||
route_path = get_route_path(node, framework)
|
||||
function_params = get_function_parameters(node, tree) # Get parameters for the route
|
||||
function_params = get_function_parameters(
|
||||
node, tree
|
||||
) # Get parameters for the route
|
||||
function_docstring = get_function_docstring(node) # Extract docstring
|
||||
routes.append({
|
||||
'name': node.name,
|
||||
'path': route_path,
|
||||
'methods': route_decorators,
|
||||
'parameters': function_params, # Add parameters to the route
|
||||
'description': function_docstring # Add the docstring as the description
|
||||
})
|
||||
routes.append(
|
||||
{
|
||||
"name": node.name,
|
||||
"path": route_path,
|
||||
"methods": route_decorators,
|
||||
"parameters": function_params, # Add parameters to the route
|
||||
"description": function_docstring, # Add the docstring as the description
|
||||
}
|
||||
)
|
||||
|
||||
# Generate YAML structure
|
||||
output_structure = {
|
||||
"prompt_targets": []
|
||||
}
|
||||
output_structure = {"prompt_targets": []}
|
||||
|
||||
for route in routes:
|
||||
target = {
|
||||
"name": route['name'],
|
||||
"name": route["name"],
|
||||
"endpoint": [
|
||||
{
|
||||
"name": "app_server",
|
||||
"path": route['path'],
|
||||
"path": route["path"],
|
||||
}
|
||||
],
|
||||
"description": route['description'], # Use extracted docstring
|
||||
"description": route["description"], # Use extracted docstring
|
||||
"parameters": [
|
||||
{
|
||||
"name": param['name'],
|
||||
"type": param['type'],
|
||||
"name": param["name"],
|
||||
"type": param["type"],
|
||||
"description": f"{param['description']}",
|
||||
**({"default": param['default']} if "default" in param and param['default'] is not None else {}), # Only add default if it's set
|
||||
"required": param['required']
|
||||
} for param in route['parameters']
|
||||
]
|
||||
**(
|
||||
{"default": param["default"]}
|
||||
if "default" in param and param["default"] is not None
|
||||
else {}
|
||||
), # Only add default if it's set
|
||||
"required": param["required"],
|
||||
}
|
||||
for param in route["parameters"]
|
||||
],
|
||||
}
|
||||
|
||||
if route['name'] == "default":
|
||||
if route["name"] == "default":
|
||||
# Special case for `information_extraction` based on your YAML format
|
||||
target["type"] = "default"
|
||||
target["auto-llm-dispatch-on-response"] = True
|
||||
|
|
@ -274,9 +340,12 @@ def generate_prompt_targets(input_file_path: str) -> None:
|
|||
output_structure["prompt_targets"].append(target)
|
||||
|
||||
# Output as YAML
|
||||
print(yaml.dump(output_structure, sort_keys=False,default_flow_style=False, indent=3))
|
||||
print(
|
||||
yaml.dump(output_structure, sort_keys=False, default_flow_style=False, indent=3)
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python targets.py <input_file>")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
|
|
@ -4,12 +4,23 @@ from typing import List, Dict, Set
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
name: str = Field("John Doe", description="The name of the user.") # Default value and description for name
|
||||
name: str = Field(
|
||||
"John Doe", description="The name of the user."
|
||||
) # Default value and description for name
|
||||
location: int = None
|
||||
age: int = Field(30, description="The age of the user.") # Default value and description for age
|
||||
tags: Set[str] = Field(default_factory=set, description="A set of tags associated with the user.") # Default empty set and description for tags
|
||||
metadata: Dict[str, int] = Field(default_factory=dict, description="A dictionary storing metadata about the user, with string keys and integer values.") # Default empty dict and description for metadata
|
||||
age: int = Field(
|
||||
30, description="The age of the user."
|
||||
) # Default value and description for age
|
||||
tags: Set[str] = Field(
|
||||
default_factory=set, description="A set of tags associated with the user."
|
||||
) # Default empty set and description for tags
|
||||
metadata: Dict[str, int] = Field(
|
||||
default_factory=dict,
|
||||
description="A dictionary storing metadata about the user, with string keys and integer values.",
|
||||
) # Default empty dict and description for metadata
|
||||
|
||||
|
||||
@app.get("/agent/default")
|
||||
async def default(request: User):
|
||||
|
|
@ -19,6 +30,7 @@ async def default(request: User):
|
|||
"""
|
||||
return {"info": f"Query: {request.name}, Count: {request.age}"}
|
||||
|
||||
|
||||
@app.post("/agent/action")
|
||||
async def reboot_network_device(device_id: str, confirmation: str):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import shlex
|
|||
import yaml
|
||||
import json
|
||||
|
||||
|
||||
def run_docker_compose_ps(compose_file, env):
|
||||
"""
|
||||
Check if all Docker Compose services are in a healthy state.
|
||||
|
|
@ -16,20 +17,31 @@ def run_docker_compose_ps(compose_file, env):
|
|||
try:
|
||||
# Run `docker-compose ps` to get the health status of each service
|
||||
ps_process = subprocess.Popen(
|
||||
["docker", "compose", "-p", "arch", "ps", "--format", "table{{.Service}}\t{{.State}}\t{{.Ports}}"],
|
||||
[
|
||||
"docker",
|
||||
"compose",
|
||||
"-p",
|
||||
"arch",
|
||||
"ps",
|
||||
"--format",
|
||||
"table{{.Service}}\t{{.State}}\t{{.Ports}}",
|
||||
],
|
||||
cwd=os.path.dirname(compose_file),
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
start_new_session=True,
|
||||
env=env
|
||||
env=env,
|
||||
)
|
||||
# Capture the output of `docker-compose ps`
|
||||
services_status, error_output = ps_process.communicate()
|
||||
|
||||
# Check if there is any error output
|
||||
if error_output:
|
||||
print(f"Error while checking service status:\n{error_output}", file=os.sys.stderr)
|
||||
print(
|
||||
f"Error while checking service status:\n{error_output}",
|
||||
file=os.sys.stderr,
|
||||
)
|
||||
return {}
|
||||
|
||||
services = parse_docker_compose_ps_output(services_status)
|
||||
|
|
@ -39,26 +51,31 @@ def run_docker_compose_ps(compose_file, env):
|
|||
print(f"Failed to check service status. Error:\n{e.stderr}")
|
||||
return e
|
||||
|
||||
#Helper method to print service status
|
||||
|
||||
# Helper method to print service status
|
||||
def print_service_status(services):
    """Print a fixed-width table of service states and ports to stdout.

    Args:
        services: Mapping of service name to a dict that holds at least the
            keys ``"STATE"`` and ``"PORTS"``.
    """
    # Header row followed by a separator rule.
    print(f"{'Service Name':<25} {'State':<20} {'Ports'}")
    print("=" * 72)
    for service_name, info in services.items():
        # Name and state are left-aligned/padded; ports are printed as-is.
        print(f"{service_name:<25} {info['STATE']:<20} {info['PORTS']}")
|
||||
|
||||
#check for states based on the states passed in
|
||||
|
||||
# check for states based on the states passed in
|
||||
def check_services_state(services, states):
    """Report whether any service's state mentions one of the given states.

    Args:
        services: Mapping of service name to a dict with a ``"STATE"`` entry.
        states: Collection of substrings to look for. Matching is done
            against the lower-cased state, so these should be lower-case.

    Returns:
        True as soon as one service state contains any entry of ``states``;
        False otherwise (including when ``services`` or ``states`` is empty).
    """
    # Only the per-service info is needed; the service name is irrelevant here.
    for service_info in services.values():
        # Lower-case once so the comparison is case-insensitive on the state.
        status = service_info["STATE"].lower()
        if any(state in status for state in states):
            return True

    return False
|
||||
|
||||
|
||||
def get_llm_provider_access_keys(arch_config_file):
|
||||
with open(arch_config_file, 'r') as file:
|
||||
with open(arch_config_file, "r") as file:
|
||||
arch_config = file.read()
|
||||
arch_config_yaml = yaml.safe_load(arch_config)
|
||||
|
||||
|
|
@ -70,22 +87,23 @@ def get_llm_provider_access_keys(arch_config_file):
|
|||
|
||||
return access_key_list
|
||||
|
||||
|
||||
def load_env_file_to_dict(file_path):
|
||||
env_dict = {}
|
||||
|
||||
# Open and read the .env file
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
for line in file:
|
||||
# Strip any leading/trailing whitespaces
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines and comments
|
||||
if not line or line.startswith('#'):
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
# Split the line into key and value at the first '=' sign
|
||||
if '=' in line:
|
||||
key, value = line.split('=', 1)
|
||||
if "=" in line:
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
|
|
@ -94,6 +112,7 @@ def load_env_file_to_dict(file_path):
|
|||
|
||||
return env_dict
|
||||
|
||||
|
||||
def parse_docker_compose_ps_output(output):
|
||||
# Split the output into lines
|
||||
lines = output.strip().splitlines()
|
||||
|
|
@ -111,10 +130,7 @@ def parse_docker_compose_ps_output(output):
|
|||
parts = line.split()
|
||||
|
||||
# Create a dictionary entry using the header names
|
||||
service_info = {
|
||||
headers[1]: parts[1], # State
|
||||
headers[2]: parts[2] # Ports
|
||||
}
|
||||
service_info = {headers[1]: parts[1], headers[2]: parts[2]} # State # Ports
|
||||
|
||||
# Add to the result dictionary using the service name as the key
|
||||
services[parts[0]] = service_info
|
||||
|
|
|
|||
|
|
@ -7,47 +7,53 @@ from dotenv import load_dotenv
|
|||
|
||||
load_dotenv()
|
||||
|
||||
OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
|
||||
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
||||
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
||||
ARCH_STATE_HEADER = 'x-arch-state'
|
||||
ARCH_STATE_HEADER = "x-arch-state"
|
||||
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))
|
||||
client = OpenAI(
|
||||
api_key=OPENAI_API_KEY,
|
||||
base_url=CHAT_COMPLETION_ENDPOINT,
|
||||
http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
|
||||
)
|
||||
|
||||
|
||||
def predict(message, state):
|
||||
if 'history' not in state:
|
||||
state['history'] = []
|
||||
if "history" not in state:
|
||||
state["history"] = []
|
||||
history = state.get("history")
|
||||
history.append({"role": "user", "content": message})
|
||||
log.info("history: ", history)
|
||||
|
||||
# Custom headers
|
||||
custom_headers = {
|
||||
'x-arch-openai-api-key': f"{OPENAI_API_KEY}",
|
||||
'x-arch-mistral-api-key': f"{MISTRAL_API_KEY}",
|
||||
'x-arch-deterministic-provider': 'openai',
|
||||
"x-arch-openai-api-key": f"{OPENAI_API_KEY}",
|
||||
"x-arch-mistral-api-key": f"{MISTRAL_API_KEY}",
|
||||
"x-arch-deterministic-provider": "openai",
|
||||
}
|
||||
|
||||
metadata = None
|
||||
if 'arch_state' in state:
|
||||
metadata = {ARCH_STATE_HEADER: state['arch_state']}
|
||||
if "arch_state" in state:
|
||||
metadata = {ARCH_STATE_HEADER: state["arch_state"]}
|
||||
|
||||
try:
|
||||
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
|
||||
messages = history,
|
||||
temperature=1.0,
|
||||
metadata=metadata,
|
||||
extra_headers=custom_headers
|
||||
)
|
||||
raw_response = client.chat.completions.with_raw_response.create(
|
||||
model=MODEL_NAME,
|
||||
messages=history,
|
||||
temperature=1.0,
|
||||
metadata=metadata,
|
||||
extra_headers=custom_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
log.info(e)
|
||||
# remove last user message in case of exception
|
||||
history.pop()
|
||||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
log.info(e)
|
||||
# remove last user message in case of exception
|
||||
history.pop()
|
||||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
log.info("raw_response: ", raw_response.text)
|
||||
response = raw_response.parse()
|
||||
|
|
@ -57,24 +63,33 @@ def predict(message, state):
|
|||
response_json = json.loads(raw_response.text)
|
||||
arch_state = None
|
||||
if response_json:
|
||||
metadata = response_json.get('metadata', {})
|
||||
if metadata:
|
||||
arch_state = metadata.get(ARCH_STATE_HEADER, None)
|
||||
metadata = response_json.get("metadata", {})
|
||||
if metadata:
|
||||
arch_state = metadata.get(ARCH_STATE_HEADER, None)
|
||||
if arch_state:
|
||||
state['arch_state'] = arch_state
|
||||
state["arch_state"] = arch_state
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
history.append({"role": "assistant", "content": content, "model": response.model})
|
||||
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
|
||||
messages = [
|
||||
(history[i]["content"], history[i + 1]["content"])
|
||||
for i in range(0, len(history) - 1, 2)
|
||||
]
|
||||
return messages, state
|
||||
|
||||
|
||||
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
|
||||
print("Starting Demo...")
|
||||
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
|
||||
state = gr.State({})
|
||||
with gr.Row():
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
|
||||
txt = gr.Textbox(
|
||||
show_label=False,
|
||||
placeholder="Enter text and press enter",
|
||||
scale=1,
|
||||
autofocus=True,
|
||||
)
|
||||
|
||||
txt.submit(predict, [txt, state], [chatbot, state])
|
||||
|
||||
|
|
|
|||
|
|
@ -5,26 +5,32 @@ from openai import OpenAI
|
|||
import gradio as gr
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1")
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv(
|
||||
"CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1"
|
||||
)
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
|
||||
def predict(message, history):
    """Stream a chat completion for ``message`` given the prior chat history.

    Args:
        message: The new user message.
        history: Iterable of ``(human, assistant)`` pairs from earlier turns.

    Yields:
        The assistant reply accumulated so far, once per streamed chunk that
        carries content.
    """
    # Flatten the (human, assistant) pairs into the role/content message list.
    history_openai_format = []
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({"role": "assistant", "content": assistant})
    history_openai_format.append({"role": "user", "content": message})

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=history_openai_format,
        temperature=1.0,
        stream=True,
    )

    partial_message = ""
    for chunk in response:
        delta_text = chunk.choices[0].delta.content
        # Chunks without content (None) are skipped and yield nothing.
        if delta_text is not None:
            partial_message = partial_message + delta_text
            yield partial_message
|
||||
|
||||
|
||||
gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8081)
|
||||
|
|
|
|||
|
|
@ -6,56 +6,54 @@ import logging
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
logger = logging.getLogger('uvicorn.error')
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/healthz")
async def healthz():
    """Health-check endpoint: always returns a static ``{"status": "ok"}``."""
    return {"status": "ok"}
|
||||
|
||||
|
||||
class WeatherRequest(BaseModel):
    """Request body accepted by the ``/weather`` endpoint."""

    city: str
    days: int = 7
    # Spelling of the default is kept as-is: it is a runtime value that is
    # echoed back to clients in the response.
    units: str = "Farenheit"
|
||||
|
||||
|
||||
@app.post("/weather")
async def weather(req: WeatherRequest, res: Response):
    """Return a randomly generated daily forecast for ``req.city``.

    One entry is produced per requested day (``req.days``, 7 by default),
    starting at today's date. Temperatures are drawn in Fahrenheit: the
    minimum from ``[50, 90)`` and the maximum 5 to 19 degrees above it. They
    are converted to Celsius when ``req.units`` is ``"celsius"`` or ``"c"``
    (case-insensitive). ``req.units`` is echoed back unchanged.

    Args:
        req: Parsed request body.
        res: Response object; not used by this handler.

    Returns:
        Dict with ``city``, ``units`` and a ``temperature`` list of per-day
        dicts (``date``, ``temperature`` min/max, ``units``, ``query_time``).
    """
    # The unit check does not depend on the day, so evaluate it once.
    use_celsius = req.units.lower() in ("celsius", "c")

    weather_forecast = {
        "city": req.city,
        "temperature": [],
        "units": req.units,
    }
    # Honour the requested number of days instead of a hard-coded 7.
    for i in range(req.days):
        min_temp = random.randrange(50, 90)
        max_temp = random.randrange(min_temp + 5, min_temp + 20)
        if use_celsius:
            min_temp = (min_temp - 32) * 5.0 / 9.0
            max_temp = (max_temp - 32) * 5.0 / 9.0
        weather_forecast["temperature"].append(
            {
                "date": str(date.today() + timedelta(days=i)),
                "temperature": {"min": min_temp, "max": max_temp},
                "units": req.units,
                "query_time": str(datetime.now(timezone.utc)),
            }
        )

    return weather_forecast
|
||||
|
||||
|
||||
class InsuranceClaimDetailsRequest(BaseModel):
    """Request body accepted by the ``/insurance_claim_details`` endpoint."""

    policy_number: str
|
||||
|
||||
|
||||
@app.post("/insurance_claim_details")
|
||||
async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response):
|
||||
|
||||
claim_details = {
|
||||
"policy_number": req.policy_number,
|
||||
"claim_status": "Approved",
|
||||
|
|
@ -68,26 +66,25 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
|
|||
|
||||
|
||||
class DefaultTargetRequest(BaseModel):
    """Request body accepted by the ``/default_target`` endpoint."""

    arch_messages: list
|
||||
|
||||
|
||||
@app.post("/default_target")
async def default_target(req: DefaultTargetRequest, res: Response):
    """Log the incoming messages and return a canned chat-completion reply.

    Args:
        req: Parsed request body carrying ``arch_messages``.
        res: Response object; not used by this handler.

    Returns:
        A dict with ``choices`` / ``model`` / ``usage`` keys holding a fixed
        assistant message.
    """
    logger.info("Received arch_messages: %s", req.arch_messages)
    resp = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "hello world from api server",
                },
                "finish_reason": "completed",
                "index": 0,
            }
        ],
        "model": "api_server",
        "usage": {"completion_tokens": 0},
    }
    logger.info("sending response: %s", json.dumps(resp))
    return resp
|
||||
|
|
|
|||
|
|
@ -3,28 +3,40 @@ from pydantic import BaseModel, Field
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class Conversation(BaseModel):
|
||||
arch_messages: list
|
||||
|
||||
|
||||
class PolicyCoverageRequest(BaseModel):
|
||||
policy_type: str = Field(..., description="The type of a policy held by the customer For, e.g. car, boat, house, motorcycle)")
|
||||
policy_type: str = Field(
|
||||
...,
|
||||
description="The type of a policy held by the customer For, e.g. car, boat, house, motorcycle)",
|
||||
)
|
||||
|
||||
|
||||
class PolicyInitiateRequest(PolicyCoverageRequest):
|
||||
deductible: float = Field(..., description="The deductible amount set of the policy")
|
||||
deductible: float = Field(
|
||||
..., description="The deductible amount set of the policy"
|
||||
)
|
||||
|
||||
|
||||
class ClaimUpdate(BaseModel):
|
||||
claim_id: str
|
||||
notes: str # Status or details of the claim
|
||||
|
||||
|
||||
class DeductibleUpdate(BaseModel):
|
||||
policy_id: str
|
||||
deductible: float
|
||||
|
||||
|
||||
class CoverageResponse(BaseModel):
|
||||
policy_type: str
|
||||
coverage: str # Description of coverage
|
||||
premium: float # The premium cost
|
||||
|
||||
|
||||
# Get information about policy coverage
|
||||
@app.post("/policy/coverage", response_model=CoverageResponse)
|
||||
async def get_policy_coverage(req: PolicyCoverageRequest):
|
||||
|
|
@ -32,10 +44,22 @@ async def get_policy_coverage(req: PolicyCoverageRequest):
|
|||
Retrieve the coverage details for a given policy type (car, boat, house, motorcycle).
|
||||
"""
|
||||
policy_coverage = {
|
||||
"car": {"coverage": "Full car coverage with collision, liability", "premium": 500.0},
|
||||
"boat": {"coverage": "Full boat coverage including theft and storm damage", "premium": 700.0},
|
||||
"house": {"coverage": "Full house coverage including fire, theft, flood", "premium": 1000.0},
|
||||
"motorcycle": {"coverage": "Full motorcycle coverage with liability", "premium": 400.0},
|
||||
"car": {
|
||||
"coverage": "Full car coverage with collision, liability",
|
||||
"premium": 500.0,
|
||||
},
|
||||
"boat": {
|
||||
"coverage": "Full boat coverage including theft and storm damage",
|
||||
"premium": 700.0,
|
||||
},
|
||||
"house": {
|
||||
"coverage": "Full house coverage including fire, theft, flood",
|
||||
"premium": 1000.0,
|
||||
},
|
||||
"motorcycle": {
|
||||
"coverage": "Full motorcycle coverage with liability",
|
||||
"premium": 400.0,
|
||||
},
|
||||
}
|
||||
|
||||
if req.policy_type not in policy_coverage:
|
||||
|
|
@ -44,9 +68,10 @@ async def get_policy_coverage(req: PolicyCoverageRequest):
|
|||
return CoverageResponse(
|
||||
policy_type=req.policy_type,
|
||||
coverage=policy_coverage[req.policy_type]["coverage"],
|
||||
premium=policy_coverage[req.policy_type]["premium"]
|
||||
premium=policy_coverage[req.policy_type]["premium"],
|
||||
)
|
||||
|
||||
|
||||
# Initiate policy coverage
|
||||
@app.post("/policy/initiate")
|
||||
async def initiate_policy(policy_request: PolicyInitiateRequest):
|
||||
|
|
@ -56,7 +81,11 @@ async def initiate_policy(policy_request: PolicyInitiateRequest):
|
|||
if policy_request.policy_type not in ["car", "boat", "house", "motorcycle"]:
|
||||
raise HTTPException(status_code=400, detail="Invalid policy type")
|
||||
|
||||
return {"message": f"Policy initiated for {policy_request.policy_type}", "deductible": policy_request.deductible}
|
||||
return {
|
||||
"message": f"Policy initiated for {policy_request.policy_type}",
|
||||
"deductible": policy_request.deductible,
|
||||
}
|
||||
|
||||
|
||||
# Update claim details
|
||||
@app.post("/policy/claim")
|
||||
|
|
@ -65,8 +94,11 @@ async def update_claim(req: ClaimUpdate):
|
|||
Update the status or details of a claim.
|
||||
"""
|
||||
# For simplicity, this is a mock update response
|
||||
return {"message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated",
|
||||
"update": claim_update.notes}
|
||||
return {
|
||||
"message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated",
|
||||
"update": claim_update.notes,
|
||||
}
|
||||
|
||||
|
||||
# Update deductible amount
|
||||
@app.post("/policy/deductible")
|
||||
|
|
@ -75,8 +107,11 @@ async def update_deductible(deductible_update: DeductibleUpdate):
|
|||
Update the deductible amount for a specific policy.
|
||||
"""
|
||||
# For simplicity, this is a mock update response
|
||||
return {"message": f"Deductible for policy {deductible_update.policy_id} has been updated",
|
||||
"new_deductible": deductible_update.deductible}
|
||||
return {
|
||||
"message": f"Deductible for policy {deductible_update.policy_id} has been updated",
|
||||
"new_deductible": deductible_update.deductible,
|
||||
}
|
||||
|
||||
|
||||
# Post method for policy Q/A
|
||||
@app.post("/policy/qa")
|
||||
|
|
@ -86,21 +121,20 @@ async def policy_qa(conversation: Conversation):
|
|||
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
|
||||
"""
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am a helpful insurance agent, and can only help with insurance things"
|
||||
"role": "assistant",
|
||||
"content": "I am a helpful insurance agent, and can only help with insurance things",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"model": "insurance_agent",
|
||||
"usage": {
|
||||
"completion_tokens": 0
|
||||
"index": 0,
|
||||
}
|
||||
}
|
||||
],
|
||||
"model": "insurance_agent",
|
||||
"usage": {"completion_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
# Run the app using:
|
||||
# uvicorn main:app --reload
|
||||
|
|
|
|||
|
|
@ -4,10 +4,14 @@ from typing import List, Optional
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
# Define the request model
|
||||
class DeviceSummaryRequest(BaseModel):
|
||||
device_ids: List[int]
|
||||
time_range: Optional[int] = Field(default=7, description="Time range in days, defaults to 7")
|
||||
time_range: Optional[int] = Field(
|
||||
default=7, description="Time range in days, defaults to 7"
|
||||
)
|
||||
|
||||
|
||||
# Define the response model
|
||||
class DeviceStatistics(BaseModel):
|
||||
|
|
@ -15,18 +19,23 @@ class DeviceStatistics(BaseModel):
|
|||
time_range: str
|
||||
data: str
|
||||
|
||||
|
||||
class DeviceSummaryResponse(BaseModel):
|
||||
statistics: List[DeviceStatistics]
|
||||
|
||||
# Request model for device reboot
|
||||
|
||||
|
||||
class DeviceRebootRequest(BaseModel):
|
||||
device_ids: List[int]
|
||||
|
||||
|
||||
# Response model for the device reboot
|
||||
class CoverageResponse(BaseModel):
|
||||
status: str
|
||||
summary: dict
|
||||
|
||||
|
||||
@app.post("/agent/device_reboot", response_model=CoverageResponse)
|
||||
def reboot_network_device(request_data: DeviceRebootRequest):
|
||||
"""
|
||||
|
|
@ -38,20 +47,21 @@ def reboot_network_device(request_data: DeviceRebootRequest):
|
|||
|
||||
# Validate 'device_ids' (This is already validated by Pydantic, but additional logic can be added if needed)
|
||||
if not device_ids:
|
||||
raise HTTPException(status_code=400, detail="'device_ids' parameter is required")
|
||||
raise HTTPException(
|
||||
status_code=400, detail="'device_ids' parameter is required"
|
||||
)
|
||||
|
||||
# Simulate reboot operation and return the response
|
||||
statistics = []
|
||||
for device_id in device_ids:
|
||||
# Placeholder for actual data retrieval or device reboot logic
|
||||
stats = {
|
||||
"data": f"Device {device_id} has been successfully rebooted."
|
||||
}
|
||||
stats = {"data": f"Device {device_id} has been successfully rebooted."}
|
||||
statistics.append(stats)
|
||||
|
||||
# Return the response with a summary
|
||||
return CoverageResponse(status="success", summary={"device_ids": device_ids})
|
||||
|
||||
|
||||
# Post method for device summary
|
||||
@app.post("/agent/device_summary", response_model=DeviceSummaryResponse)
|
||||
def get_device_summary(request: DeviceSummaryRequest):
|
||||
|
|
@ -77,6 +87,7 @@ def get_device_summary(request: DeviceSummaryRequest):
|
|||
|
||||
return DeviceSummaryResponse(statistics=statistics)
|
||||
|
||||
|
||||
@app.post("/agent/network_summary")
|
||||
async def policy_qa():
|
||||
"""
|
||||
|
|
@ -84,21 +95,20 @@ async def policy_qa():
|
|||
It forwards the conversation to the OpenAI client via a local proxy and returns the response.
|
||||
"""
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them"
|
||||
"role": "assistant",
|
||||
"content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"model": "network_agent",
|
||||
"usage": {
|
||||
"completion_tokens": 0
|
||||
"index": 0,
|
||||
}
|
||||
}
|
||||
],
|
||||
"model": "network_agent",
|
||||
"usage": {"completion_tokens": 0},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(debug=True)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_sql():
|
||||
# Example Usage
|
||||
conn = sqlite3.connect(":memory:")
|
||||
|
|
@ -26,6 +27,7 @@ def load_sql():
|
|||
|
||||
return conn
|
||||
|
||||
|
||||
# Function to convert natural language time expressions to "X {time} ago" format
|
||||
def convert_to_ago_format(expression):
|
||||
# Define patterns for different time units
|
||||
|
|
|
|||
|
|
@ -13,9 +13,10 @@ def get_device_summary():
|
|||
# Validate 'device_ids' parameter
|
||||
device_ids = data.get("device_ids")
|
||||
if not device_ids or not isinstance(device_ids, list):
|
||||
return jsonify(
|
||||
{"error": "'device_ids' parameter is required and must be a list"}
|
||||
), 400
|
||||
return (
|
||||
jsonify({"error": "'device_ids' parameter is required and must be a list"}),
|
||||
400,
|
||||
)
|
||||
|
||||
# Validate 'time_range' parameter (optional, defaults to 7)
|
||||
time_range = data.get("time_range", 7)
|
||||
|
|
|
|||
|
|
@ -119,9 +119,10 @@ def process_rag():
|
|||
intent_changed = True
|
||||
else:
|
||||
# Invalid value provided
|
||||
return jsonify(
|
||||
{"error": "Invalid value for x-arch-prompt-intent-change header"}
|
||||
), 400
|
||||
return (
|
||||
jsonify({"error": "Invalid value for x-arch-prompt-intent-change header"}),
|
||||
400,
|
||||
)
|
||||
|
||||
# Update user conversation based on intent change
|
||||
memory = update_user_conversation(user_id, client_messages, intent_changed)
|
||||
|
|
|
|||
|
|
@ -13,9 +13,10 @@ def get_device_summary():
|
|||
# Validate 'device_ids' parameter
|
||||
device_ids = data.get("device_ids")
|
||||
if not device_ids or not isinstance(device_ids, list):
|
||||
return jsonify(
|
||||
{"error": "'device_ids' parameter is required and must be a list"}
|
||||
), 400
|
||||
return (
|
||||
jsonify({"error": "'device_ids' parameter is required and must be a list"}),
|
||||
400,
|
||||
)
|
||||
|
||||
# Validate 'time_range' parameter (optional, defaults to 7)
|
||||
time_range = data.get("time_range", 7)
|
||||
|
|
|
|||
|
|
@ -12,16 +12,16 @@ from sphinx.util.docfields import Field
|
|||
from sphinxawesome_theme import ThemeOptions
|
||||
from sphinxawesome_theme.postprocess import Icons
|
||||
|
||||
project = 'Arch Docs'
|
||||
copyright = '2024, Katanemo Labs, Inc'
|
||||
author = 'Katanemo Labs, Inc'
|
||||
release = ' v0.1'
|
||||
project = "Arch Docs"
|
||||
copyright = "2024, Katanemo Labs, Inc"
|
||||
author = "Katanemo Labs, Inc"
|
||||
release = " v0.1"
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
|
||||
root_doc = 'index'
|
||||
root_doc = "index"
|
||||
|
||||
nitpicky = True
|
||||
add_module_names = False
|
||||
|
|
@ -33,23 +33,23 @@ extensions = [
|
|||
"sphinx.ext.extlinks",
|
||||
"sphinx.ext.viewcode",
|
||||
"sphinx_sitemap",
|
||||
"sphinx_design"
|
||||
"sphinx_design",
|
||||
]
|
||||
|
||||
# Paths that contain templates, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
templates_path = ["_templates"]
|
||||
|
||||
# List of patterns, relative to source directory, that match files and directories
|
||||
# to ignore when looking for source files.
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
html_theme = 'sphinxawesome_theme' # You can change the theme to 'sphinx_rtd_theme' or another of your choice.
|
||||
html_theme = "sphinxawesome_theme" # You can change the theme to 'sphinx_rtd_theme' or another of your choice.
|
||||
html_title = project + release
|
||||
html_permalinks_icon = Icons.permalinks_icon
|
||||
html_favicon = '_static/favicon.ico'
|
||||
html_logo = '_static/favicon.ico' # Specify the path to the logo image file (make sure the logo is in the _static directory)
|
||||
html_favicon = "_static/favicon.ico"
|
||||
html_logo = "_static/favicon.ico" # Specify the path to the logo image file (make sure the logo is in the _static directory)
|
||||
html_last_updated_fmt = ""
|
||||
html_use_index = False # Don't create index
|
||||
html_domain_indices = False # Don't need module indices
|
||||
|
|
@ -57,10 +57,14 @@ html_copy_source = False # Don't need sources
|
|||
html_show_sphinx = False
|
||||
|
||||
|
||||
html_baseurl = './docs'
|
||||
html_baseurl = "./docs"
|
||||
|
||||
html_sidebars = {
|
||||
"**": ['analytics.html', "sidebar_main_nav_links.html", "sidebar_toc.html", ]
|
||||
"**": [
|
||||
"analytics.html",
|
||||
"sidebar_main_nav_links.html",
|
||||
"sidebar_toc.html",
|
||||
]
|
||||
}
|
||||
|
||||
theme_options = ThemeOptions(
|
||||
|
|
@ -102,7 +106,7 @@ html_theme_options = asdict(theme_options)
|
|||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
html_static_path = ["_static"]
|
||||
|
||||
pygments_style = "lovelace"
|
||||
pygments_style_dark = "github-dark"
|
||||
|
|
@ -111,10 +115,11 @@ sitemap_url_scheme = "{link}"
|
|||
# Add this configuration at the bottom of your conf.py
|
||||
|
||||
html_context = {
|
||||
'google_analytics_id': 'G-K2LXXSX6HB', # Replace with your Google Analytics tracking ID
|
||||
"google_analytics_id": "G-K2LXXSX6HB", # Replace with your Google Analytics tracking ID
|
||||
}
|
||||
|
||||
templates_path = ['_templates']
|
||||
templates_path = ["_templates"]
|
||||
|
||||
|
||||
# -- Register a :confval: interpreted text role ----------------------------------
|
||||
def setup(app: Sphinx) -> None:
|
||||
|
|
@ -138,4 +143,4 @@ def setup(app: Sphinx) -> None:
|
|||
],
|
||||
)
|
||||
|
||||
app.add_css_file('_static/custom.css')
|
||||
app.add_css_file("_static/custom.css")
|
||||
|
|
|
|||
|
|
@ -35,11 +35,13 @@ def start_server():
|
|||
print("Server is already running. Use 'model_server restart' to restart it.")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)")
|
||||
print(
|
||||
f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)"
|
||||
)
|
||||
process = subprocess.Popen(
|
||||
["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
|
||||
start_new_session=True,
|
||||
stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
|
||||
stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
|
||||
stderr=subprocess.DEVNULL, # Suppress standard error. There is a logger that model_server prints to
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenc
|
|||
|
||||
|
||||
def get_device():
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
|
|
@ -19,13 +18,15 @@ def get_device():
|
|||
return device
|
||||
|
||||
|
||||
def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")):
|
||||
def load_transformers(
|
||||
model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")
|
||||
):
|
||||
print("Loading Embedding Model")
|
||||
transformers = {}
|
||||
device = get_device()
|
||||
transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
|
||||
transformers["model"] = ORTModelForFeatureExtraction.from_pretrained(
|
||||
model_name, device_map = device
|
||||
model_name, device_map=device
|
||||
)
|
||||
transformers["model_name"] = model_name
|
||||
|
||||
|
|
@ -62,7 +63,9 @@ def load_guard_model(
|
|||
return guard_model
|
||||
|
||||
|
||||
def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")):
|
||||
def load_zero_shot_models(
|
||||
model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
|
||||
):
|
||||
zero_shot_model = {}
|
||||
device = get_device()
|
||||
zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
|
||||
|
|
@ -81,5 +84,6 @@ def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deb
|
|||
|
||||
return zero_shot_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(get_device())
|
||||
|
|
|
|||
|
|
@ -7,7 +7,12 @@ from app.load_models import (
|
|||
get_device,
|
||||
)
|
||||
import os
|
||||
from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config, get_model_server_logger
|
||||
from app.utils import (
|
||||
GuardHandler,
|
||||
split_text_into_chunks,
|
||||
load_yaml_config,
|
||||
get_model_server_logger,
|
||||
)
|
||||
import torch
|
||||
import yaml
|
||||
import string
|
||||
|
|
@ -39,6 +44,7 @@ guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
|
|||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
class EmbeddingRequest(BaseModel):
|
||||
input: str
|
||||
model: str
|
||||
|
|
@ -84,6 +90,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
|
|||
}
|
||||
return {"data": data, "model": req.model, "object": "list", "usage": usage}
|
||||
|
||||
|
||||
class GuardRequest(BaseModel):
|
||||
input: str
|
||||
task: str
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import logging
|
|||
|
||||
logger_instance = None
|
||||
|
||||
|
||||
def load_yaml_config(file_name):
|
||||
# Load the YAML file from the package
|
||||
yaml_path = pkg_resources.resource_filename("app", file_name)
|
||||
|
|
@ -138,6 +139,7 @@ class GuardHandler:
|
|||
}
|
||||
return result_dict
|
||||
|
||||
|
||||
def get_model_server_logger():
|
||||
global logger_instance
|
||||
|
||||
|
|
@ -164,8 +166,8 @@ def get_model_server_logger():
|
|||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.FileHandler(log_file_path, mode='w'), # Overwrite logs in file
|
||||
]
|
||||
logging.FileHandler(log_file_path, mode="w"), # Overwrite logs in file
|
||||
],
|
||||
)
|
||||
except (PermissionError, OSError) as e:
|
||||
# Don't fall back to console logging if there are issues writing to the log file
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue