diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8fbd3c69..133f9ff8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,3 +25,8 @@ repos:
       # --lib is to only test the library, since when integration tests are made,
       # they will be in a seperate tests directory
       entry: bash -c "cd arch && cargo test -p intelligent-prompt-gateway --lib"
+  - repo: https://github.com/psf/black
+    rev: 23.1.0
+    hooks:
+      - id: black
+        language_version: python3
diff --git a/arch/tools/cli.py b/arch/tools/cli.py
index 50e4f03e..60c57e7e 100644
--- a/arch/tools/cli.py
+++ b/arch/tools/cli.py
@@ -16,18 +16,22 @@ logo = r"""
  /_/    \_\|_|     \___||_| |_|
 """

+
+
 @click.group(invoke_without_command=True)
 @click.pass_context
 def main(ctx):
     if ctx.invoked_subcommand is None:
-        click.echo( """Arch (The Intelligent Prompt Gateway) CLI""")
+        click.echo("""Arch (The Intelligent Prompt Gateway) CLI""")
         click.echo(logo)
         click.echo(ctx.get_help())

+
 # Command to build archgw and model_server Docker images
 ARCHGW_DOCKERFILE = "./arch/Dockerfile"
 MODEL_SERVER_BUILD_FILE = "./model_server/pyproject.toml"

+
 @click.command()
 def build():
     """Build Arch from source. Must be in root of cloned repo."""
@@ -35,7 +39,18 @@ def build():
     if os.path.exists(ARCHGW_DOCKERFILE):
         click.echo("Building archgw image...")
         try:
-            subprocess.run(["docker", "build", "-f", ARCHGW_DOCKERFILE, "-t", "archgw:latest", "."], check=True)
+            subprocess.run(
+                [
+                    "docker",
+                    "build",
+                    "-f",
+                    ARCHGW_DOCKERFILE,
+                    "-t",
+                    "archgw:latest",
+                    ".",
+                ],
+                check=True,
+            )
             click.echo("archgw image built successfully.")
         except subprocess.CalledProcessError as e:
             click.echo(f"Error building archgw image: {e}")
@@ -51,7 +66,11 @@ def build():
     if os.path.exists(MODEL_SERVER_BUILD_FILE):
         click.echo("Installing model server dependencies with Poetry...")
         try:
-            subprocess.run(["poetry", "install", "--no-cache"], cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE), check=True)
+            subprocess.run(
+                ["poetry", "install", "--no-cache"],
+                cwd=os.path.dirname(MODEL_SERVER_BUILD_FILE),
+                check=True,
+            )
             click.echo("Model server dependencies installed successfully.")
         except subprocess.CalledProcessError as e:
             click.echo(f"Error installing model server dependencies: {e}")
@@ -60,9 +79,12 @@ def build():
         click.echo(f"Error: pyproject.toml not found in {MODEL_SERVER_BUILD_FILE}")
         sys.exit(1)

+
 @click.command()
-@click.argument('file', required=False)  # Optional file argument
-@click.option('-path', default='.', help='Path to the directory containing arch_config.yml')
+@click.argument("file", required=False)  # Optional file argument
+@click.option(
+    "-path", default=".", help="Path to the directory containing arch_config.yml"
+)
 def up(file, path):
     """Starts Arch."""
     if file:
@@ -78,10 +100,15 @@ def up(file, path):
             return

     print(f"Validating {arch_config_file}")
-    arch_schema_config = pkg_resources.resource_filename(__name__, "config/arch_config_schema.yaml")
+    arch_schema_config = pkg_resources.resource_filename(
+        __name__, "config/arch_config_schema.yaml"
+    )

     try:
-        config_generator.validate_prompt_config(arch_config_file=arch_config_file, arch_config_schema_file=arch_schema_config)
+        config_generator.validate_prompt_config(
+            arch_config_file=arch_config_file,
+            arch_config_schema_file=arch_schema_config,
+        )
     except Exception as e:
         print("Exiting archgw up")
         sys.exit(1)
@@ -91,52 +118,67 @@ def up(file, path):
     # Set the ARCH_CONFIG_FILE environment variable
     env_stage = {}
     env = os.environ.copy()
-    #check if access_keys are preesnt in the config file
+    # check if access_keys are present in the config file
     access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
     if access_keys:
         if file:
-            app_env_file = os.path.join(os.path.dirname(os.path.abspath(file)), ".env") #check the .env file in the path
+            app_env_file = os.path.join(
+                os.path.dirname(os.path.abspath(file)), ".env"
+            )  # check the .env file in the path
         else:
             app_env_file = os.path.abspath(os.path.join(path, ".env"))

-        if not os.path.exists(app_env_file): #check to see if the environment variables in the current environment or not
+        if not os.path.exists(
+            app_env_file
+        ):  # check to see if the environment variables are in the current environment or not
            for access_key in access_keys:
                if env.get(access_key) is None:
-                    print (f"Access Key: {access_key} not found. Exiting Start")
+                    print(f"Access Key: {access_key} not found. Exiting Start")
                    sys.exit(1)
                else:
                    env_stage[access_key] = env.get(access_key)
-        else: #.env file exists, use that to send parameters to Arch
+        else:  # .env file exists, use that to send parameters to Arch
            env_file_dict = load_env_file_to_dict(app_env_file)
            for access_key in access_keys:
                if env_file_dict.get(access_key) is None:
-                    print (f"Access Key: {access_key} not found. Exiting Start")
+                    print(f"Access Key: {access_key} not found. Exiting Start")
                    sys.exit(1)
                else:
                    env_stage[access_key] = env_file_dict[access_key]

-    with open(pkg_resources.resource_filename(__name__, "config/stage.env"), 'w') as file:
+    with open(
+        pkg_resources.resource_filename(__name__, "config/stage.env"), "w"
+    ) as file:
         for key, value in env_stage.items():
             file.write(f"{key}={value}\n")

     env.update(env_stage)
-    env['ARCH_CONFIG_FILE'] = arch_config_file
+    env["ARCH_CONFIG_FILE"] = arch_config_file

     start_arch_modelserver()
     start_arch(arch_config_file, env)

+
 @click.command()
 def down():
     """Stops Arch."""
     stop_arch_modelserver()
     stop_arch()

+
 @click.command()
-@click.option('-f', '--file', type=click.Path(exists=True), required=True, help="Path to the Python file")
+@click.option(
+    "-f",
+    "--file",
+    type=click.Path(exists=True),
+    required=True,
+    help="Path to the Python file",
+)
 def generate_prompt_targets(file):
     """Generats prompt_targets from python methods.
-    Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
-    If you have a complex pydantic data type, you will have to flatten those manually until we add support for it."""
+    Note: This works for simple data types like ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
+    If you have a complex pydantic data type, you will have to flatten those manually until we add support for it.
+    """
     print(f"Processing file: {file}")

     if not file.endswith(".py"):
@@ -145,10 +187,11 @@ def generate_prompt_targets(file):

     targets.generate_prompt_targets(file)

+
 main.add_command(up)
 main.add_command(down)
 main.add_command(build)
 main.add_command(generate_prompt_targets)

-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
diff --git a/arch/tools/config_generator.py b/arch/tools/config_generator.py
index 46cfd93d..fde60526 100644
--- a/arch/tools/config_generator.py
+++ b/arch/tools/config_generator.py
@@ -3,36 +3,44 @@ from jinja2 import Environment, FileSystemLoader
 import yaml
 from jsonschema import validate

-ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
-ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
-ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
-ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
+ENVOY_CONFIG_TEMPLATE_FILE = os.getenv(
+    "ENVOY_CONFIG_TEMPLATE_FILE", "envoy.template.yaml"
+)
+ARCH_CONFIG_FILE = os.getenv("ARCH_CONFIG_FILE", "/config/arch_config.yaml")
+ENVOY_CONFIG_FILE_RENDERED = os.getenv(
+    "ENVOY_CONFIG_FILE_RENDERED", "/etc/envoy/envoy.yaml"
+)
+ARCH_CONFIG_SCHEMA_FILE = os.getenv(
+    "ARCH_CONFIG_SCHEMA_FILE", "arch_config_schema.yaml"
+)

-def add_secret_key_to_llm_providers(config_yaml) :
+
+def add_secret_key_to_llm_providers(config_yaml):
     llm_providers = []
     for llm_provider in config_yaml.get("llm_providers", []):
-        access_key_env_var = llm_provider.get('access_key', False)
+        access_key_env_var = llm_provider.get("access_key", False)
         access_key_value = os.getenv(access_key_env_var, False)
         if access_key_env_var and access_key_value:
-            llm_provider['access_key'] = access_key_value
+            llm_provider["access_key"] = access_key_value
         llm_providers.append(llm_provider)
     config_yaml["llm_providers"] = llm_providers
     return config_yaml

+
 def validate_and_render_schema():
-    env = Environment(loader=FileSystemLoader('./'))
-    template = env.get_template('envoy.template.yaml')
+    env = Environment(loader=FileSystemLoader("./"))
+    template = env.get_template("envoy.template.yaml")
     try:
         validate_prompt_config(ARCH_CONFIG_FILE, ARCH_CONFIG_SCHEMA_FILE)
     except Exception as e:
         print(e)
-        exit(1) # validate_prompt_config failed. Exit
+        exit(1)  # validate_prompt_config failed. Exit

-    with open(ARCH_CONFIG_FILE, 'r') as file:
+    with open(ARCH_CONFIG_FILE, "r") as file:
         arch_config = file.read()

-    with open(ARCH_CONFIG_SCHEMA_FILE, 'r') as file:
+    with open(ARCH_CONFIG_SCHEMA_FILE, "r") as file:
         arch_config_schema = file.read()

     config_yaml = yaml.safe_load(arch_config)
@@ -44,7 +52,7 @@ def validate_and_render_schema():
         if name not in inferred_clusters:
             inferred_clusters[name] = {
                 "name": name,
-                "port": 80, # default port
+                "port": 80,  # default port
             }

     print(inferred_clusters)
@@ -55,14 +63,13 @@ def validate_and_render_schema():
         if name in inferred_clusters:
             print("updating cluster", endpoint_details)
             inferred_clusters[name].update(endpoint_details)
-            endpoint = inferred_clusters[name]['endpoint']
-            if len(endpoint.split(':')) > 1:
-                inferred_clusters[name]['endpoint'] = endpoint.split(':')[0]
-                inferred_clusters[name]['port'] = int(endpoint.split(':')[1])
+            endpoint = inferred_clusters[name]["endpoint"]
+            if len(endpoint.split(":")) > 1:
+                inferred_clusters[name]["endpoint"] = endpoint.split(":")[0]
+                inferred_clusters[name]["port"] = int(endpoint.split(":")[1])
         else:
             inferred_clusters[name] = endpoint_details

-
     print("updated clusters", inferred_clusters)

     config_yaml = add_secret_key_to_llm_providers(config_yaml)
@@ -71,23 +78,24 @@ def validate_and_render_schema():
     arch_config_string = yaml.dump(config_yaml)

     data = {
-        'arch_config': arch_config_string,
-        'arch_clusters': inferred_clusters,
-        'arch_llm_providers': arch_llm_providers,
-        'arch_tracing': arch_tracing
+        "arch_config": arch_config_string,
+        "arch_clusters": inferred_clusters,
+        "arch_llm_providers": arch_llm_providers,
+        "arch_tracing": arch_tracing,
     }

     rendered = template.render(data)
     print(rendered)
     print(ENVOY_CONFIG_FILE_RENDERED)
-    with open(ENVOY_CONFIG_FILE_RENDERED, 'w') as file:
+    with open(ENVOY_CONFIG_FILE_RENDERED, "w") as file:
         file.write(rendered)

+
 def validate_prompt_config(arch_config_file, arch_config_schema_file):
-    with open(arch_config_file, 'r') as file:
+    with open(arch_config_file, "r") as file:
         arch_config = file.read()

-    with open(arch_config_schema_file, 'r') as file:
+    with open(arch_config_schema_file, "r") as file:
         arch_config_schema = file.read()

     config_yaml = yaml.safe_load(arch_config)
@@ -96,8 +104,11 @@ def validate_prompt_config(arch_config_file, arch_config_schema_file):
     try:
         validate(config_yaml, config_schema_yaml)
     except Exception as e:
-        print(f"Error validating arch_config file: {arch_config_file}, error: {e.message}")
+        print(
+            f"Error validating arch_config file: {arch_config_file}, error: {e.message}"
+        )
         raise e

-if __name__ == '__main__':
+
+if __name__ == "__main__":
     validate_and_render_schema()
diff --git a/arch/tools/core.py b/arch/tools/core.py
index f4732b92..fd93c589 100644
--- a/arch/tools/core.py
+++ b/arch/tools/core.py
@@ -5,6 +5,7 @@ import pkg_resources
 import select
 from utils import run_docker_compose_ps, print_service_status, check_services_state

+
 def start_arch(arch_config_file, env, log_timeout=120):
     """
     Start Docker Compose in detached mode and stream logs until services are healthy.
@@ -14,22 +15,35 @@ def start_arch(arch_config_file, env, log_timeout=120):
         log_timeout (int): Time in seconds to show logs before checking for healthy state.
""" - compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml') + compose_file = pkg_resources.resource_filename( + __name__, "config/docker-compose.yaml" + ) try: # Run the Docker Compose command in detached mode (-d) subprocess.run( - ["docker", "compose", "-p", "arch", "up", "-d",], - cwd=os.path.dirname(compose_file), # Ensure the Docker command runs in the correct path - env=env, # Pass the modified environment - check=True # Raise an exception if the command fails + [ + "docker", + "compose", + "-p", + "arch", + "up", + "-d", + ], + cwd=os.path.dirname( + compose_file + ), # Ensure the Docker command runs in the correct path + env=env, # Pass the modified environment + check=True, # Raise an exception if the command fails ) print(f"Arch docker-compose started in detached.") print("Monitoring `docker-compose ps` logs...") start_time = time.time() services_status = {} - services_running = False #assume that the services are not running at the moment + services_running = ( + False # assume that the services are not running at the moment + ) while True: current_time = time.time() @@ -40,16 +54,22 @@ def start_arch(arch_config_file, env, log_timeout=120): print(f"Stopping log monitoring after {log_timeout} seconds.") break - current_services_status = run_docker_compose_ps(compose_file=compose_file, env=env) + current_services_status = run_docker_compose_ps( + compose_file=compose_file, env=env + ) if not current_services_status: - print("Status for the services could not be detected. Something went wrong. Please run docker logs") + print( + "Status for the services could not be detected. Something went wrong. Please run docker logs" + ) break if not services_status: - services_status = current_services_status #set the first time - print_service_status(services_status) #print the services status and proceed. + services_status = current_services_status # set the first time + print_service_status( + services_status + ) # print the services status and proceed. - #check if anyone service is failed or exited state, if so print and break out + # check if anyone service is failed or exited state, if so print and break out unhealthy_states = ["unhealthy", "exit", "exited", "dead", "bad"] running_states = ["running", "up"] @@ -58,14 +78,23 @@ def start_arch(arch_config_file, env, log_timeout=120): break if check_services_state(current_services_status, unhealthy_states): - print("One or more Arch services are unhealthy. Please run `docker logs` for more information") - print_service_status(current_services_status) #print the services status and proceed. + print( + "One or more Arch services are unhealthy. Please run `docker logs` for more information" + ) + print_service_status( + current_services_status + ) # print the services status and proceed. break - #check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error + # check to see if the status of one of the services has changed from prior. Print and loop over until finish, or error for service_name in services_status.keys(): - if services_status[service_name]['State'] != current_services_status[service_name]['State']: - print("One or more Arch services have changed state. Printing current state") + if ( + services_status[service_name]["State"] + != current_services_status[service_name]["State"] + ): + print( + "One or more Arch services have changed state. 
Printing current state" + ) print_service_status(current_services_status) break @@ -82,7 +111,9 @@ def stop_arch(): Args: path (str): The path where the docker-compose.yml file is located. """ - compose_file = pkg_resources.resource_filename(__name__, 'config/docker-compose.yaml') + compose_file = pkg_resources.resource_filename( + __name__, "config/docker-compose.yaml" + ) try: # Run `docker-compose down` to shut down all services @@ -96,6 +127,7 @@ def stop_arch(): except subprocess.CalledProcessError as e: print(f"Failed to shut down services: {str(e)}") + def start_arch_modelserver(): """ Start the model server. This assumes that the archgw_modelserver package is installed locally @@ -103,15 +135,14 @@ def start_arch_modelserver(): """ try: subprocess.run( - ['archgw_modelserver', 'restart'], - check=True, - start_new_session=True + ["archgw_modelserver", "restart"], check=True, start_new_session=True ) print("Successfull run the archgw model_server") except subprocess.CalledProcessError as e: - print (f"Failed to start model_server. Please check archgw_modelserver logs") + print(f"Failed to start model_server. Please check archgw_modelserver logs") sys.exit(1) + def stop_arch_modelserver(): """ Stop the model server. This assumes that the archgw_modelserver package is installed locally @@ -119,10 +150,10 @@ def stop_arch_modelserver(): """ try: subprocess.run( - ['archgw_modelserver', 'stop'], + ["archgw_modelserver", "stop"], check=True, ) print("Successfull stopped the archgw model_server") except subprocess.CalledProcessError as e: - print (f"Failed to start model_server. Please check archgw_modelserver logs") + print(f"Failed to start model_server. Please check archgw_modelserver logs") sys.exit(1) diff --git a/arch/tools/setup.py b/arch/tools/setup.py index f1e30416..62051025 100644 --- a/arch/tools/setup.py +++ b/arch/tools/setup.py @@ -6,17 +6,29 @@ setup( description="Python-based CLI tool to manage Arch and generate targets.", author="Katanemo Labs, Inc.", packages=find_packages(), - py_modules = ['cli', 'core', 'targets', 'utils', 'config_generator'], + py_modules=["cli", "core", "targets", "utils", "config_generator"], include_package_data=True, # Specify to include the docker-compose.yml file package_data={ - '': ['config/docker-compose.yaml', 'config/arch_config_schema.yaml', 'config/stage.env'] #Specify to include the docker-compose.yml file + "": [ + "config/docker-compose.yaml", + "config/arch_config_schema.yaml", + "config/stage.env", + ] # Specify to include the docker-compose.yml file }, # Add dependencies here, e.g., 'PyYAML' for YAML processing - install_requires=['pyyaml', 'pydantic', 'click', 'jinja2','pyyaml','jsonschema', 'setuptools'], + install_requires=[ + "pyyaml", + "pydantic", + "click", + "jinja2", + "pyyaml", + "jsonschema", + "setuptools", + ], entry_points={ - 'console_scripts': [ - 'archgw=cli:main', + "console_scripts": [ + "archgw=cli:main", ], }, ) diff --git a/arch/tools/targets.py b/arch/tools/targets.py index 82cc770a..a25c1aad 100644 --- a/arch/tools/targets.py +++ b/arch/tools/targets.py @@ -18,14 +18,20 @@ def detect_framework(tree: Any) -> str: return "fastapi" return "unknown" + def get_route_decorators(node: Any, framework: str) -> list: """Extract route decorators based on the framework.""" decorators = [] for decorator in node.decorator_list: - if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute): + if isinstance(decorator, ast.Call) and isinstance( + decorator.func, ast.Attribute + ): if framework == 
"flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS: decorators.append(decorator.func.attr) - elif framework == "fastapi" and decorator.func.attr in FASTAPI_ROUTE_DECORATORS: + elif ( + framework == "fastapi" + and decorator.func.attr in FASTAPI_ROUTE_DECORATORS + ): decorators.append(decorator.func.attr) return decorators @@ -36,6 +42,7 @@ def get_route_path(node: Any, framework: str) -> str: if isinstance(decorator, ast.Call) and decorator.args: return decorator.args[0].s # Assuming it's a string literal + def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool: """Check if a given type annotation is a Pydantic model.""" # We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel @@ -47,6 +54,7 @@ def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool: return True return False + def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list: """Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values.""" fields = [] @@ -62,15 +70,26 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list: required = True # Assume the field is required initially # Check if the field uses Field() with required status and description - if stmt.value and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == 'Field': + if ( + stmt.value + and isinstance(stmt.value, ast.Call) + and isinstance(stmt.value.func, ast.Name) + and stmt.value.func.id == "Field" + ): # Extract the description argument inside the Field call for keyword in stmt.value.keywords: - if keyword.arg == 'description' and isinstance(keyword.value, ast.Str): + if keyword.arg == "description" and isinstance( + keyword.value, ast.Str + ): description = keyword.value.s - if keyword.arg == 'default': + if keyword.arg == "default": default_value = keyword.value # If Ellipsis (...) is used, it means the field is required - if stmt.value.args and isinstance(stmt.value.args[0], ast.Constant) and stmt.value.args[0].value is Ellipsis: + if ( + stmt.value.args + and isinstance(stmt.value.args[0], ast.Constant) + and stmt.value.args[0].value is Ellipsis + ): required = True else: required = False @@ -80,19 +99,28 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list: if isinstance(stmt.value, ast.Constant): # Set the default value from the assignment (e.g., name: str = "John Doe") default_value = stmt.value.value - required = False # Not required since it has a default value + required = ( + False # Not required since it has a default value + ) # Always extract the field type, even if there's a default value if isinstance(stmt.annotation, ast.Subscript): # Get the base type (list, tuple, set, dict) - base_type = stmt.annotation.value.id if isinstance(stmt.annotation.value, ast.Name) else "Unknown" + base_type = ( + stmt.annotation.value.id + if isinstance(stmt.annotation.value, ast.Name) + else "Unknown" + ) # Handle only list, tuple, set, dict and ignore the inner types - if base_type.lower() in ['list', 'tuple', 'set', 'dict']: + if base_type.lower() in ["list", "tuple", "set", "dict"]: field_type = base_type.lower() # Handle the ellipsis '...' for required fields if no Field() call - elif isinstance(stmt.value, ast.Constant) and stmt.value.value is Ellipsis: + elif ( + isinstance(stmt.value, ast.Constant) + and stmt.value.value is Ellipsis + ): required = True # Handle simple types like str, int, etc. 
@@ -100,16 +128,17 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list: field_type = stmt.annotation.id field_info = { - "name": stmt.target.id, - "type": field_type, # Always set the field type - "description": description, - "default": default_value, # Handle direct default values - "required": required + "name": stmt.target.id, + "type": field_type, # Always set the field type + "description": description, + "default": default_value, # Handle direct default values + "required": required, } fields.append(field_info) return fields + def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list: """Extract the parameters and their types from the function definition.""" parameters = [] @@ -119,40 +148,68 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list: arg_descriptions = extract_arg_descriptions_from_docstring(docstring) # Extract default values - defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + node.args.defaults # Align defaults with args + defaults = [None] * ( + len(node.args.args) - len(node.args.defaults) + ) + node.args.defaults # Align defaults with args for arg, default in zip(node.args.args, defaults): if arg.arg != "self": # Skip 'self' or 'cls' in class methods - param_info = {"name": arg.arg, "description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]")} + param_info = { + "name": arg.arg, + "description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]"), + } # Handle Pydantic model types - if hasattr(arg, 'annotation') and is_pydantic_model(arg.annotation, tree): + if hasattr(arg, "annotation") and is_pydantic_model(arg.annotation, tree): # Extract and flatten Pydantic model fields pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree) - parameters.extend(pydantic_fields) # Flatten the model fields into the parameters list + parameters.extend( + pydantic_fields + ) # Flatten the model fields into the parameters list continue # Skip adding the current param_info for the model since we expand the fields # Handle standard Python types (int, float, str, etc.) - elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Name): - if arg.annotation.id in ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']: + elif hasattr(arg, "annotation") and isinstance(arg.annotation, ast.Name): + if arg.annotation.id in [ + "int", + "float", + "bool", + "str", + "list", + "tuple", + "set", + "dict", + ]: param_info["type"] = arg.annotation.id else: param_info["type"] = "[UNKNOWN - PLEASE FIX]" # Handle generic subscript types (e.g., Optional, List[Type], etc.) - elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Subscript): - if isinstance(arg.annotation.value, ast.Name) and arg.annotation.value.id in ['list', 'tuple', 'set', 'dict']: - param_info["type"] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc. + elif hasattr(arg, "annotation") and isinstance( + arg.annotation, ast.Subscript + ): + if isinstance( + arg.annotation.value, ast.Name + ) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]: + param_info[ + "type" + ] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc. 
else: param_info["type"] = "[UNKNOWN - PLEASE FIX]" # Default for unknown types else: - param_info["type"] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type + param_info[ + "type" + ] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type # Handle default values if default is not None: - if isinstance(default, ast.Constant) or isinstance(default, ast.NameConstant): - param_info["default"] = default.value # Use the default value directly + if isinstance(default, ast.Constant) or isinstance( + default, ast.NameConstant + ): + param_info[ + "default" + ] = default.value # Use the default value directly else: param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type param_info["required"] = False # Optional since it has a default value @@ -164,6 +221,7 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list: return parameters + def get_function_docstring(node: Any) -> str: """Extract the function's docstring description if present.""" # Check if the first node is a docstring @@ -178,6 +236,7 @@ def get_function_docstring(node: Any) -> str: return "No description provided." + def extract_arg_descriptions_from_docstring(docstring: str) -> dict: """Extract descriptions for function parameters from the 'Args' section of the docstring.""" descriptions = {} @@ -195,14 +254,14 @@ def extract_arg_descriptions_from_docstring(docstring: str) -> dict: continue # Proceed to the next line after 'Args:' # End of 'Args' section if no indentation and no colon - if in_args_section and not line.startswith(" ") and ':' not in line: + if in_args_section and not line.startswith(" ") and ":" not in line: break # Stop processing if we reach a new section # Process lines in the 'Args' section if in_args_section: - if ':' in line: + if ":" in line: # Extract parameter name and description - param_name, description = line.split(':', 1) + param_name, description = line.split(":", 1) descriptions[param_name.strip()] = description.strip() current_param = param_name.strip() elif current_param and line.startswith(" "): @@ -230,43 +289,50 @@ def generate_prompt_targets(input_file_path: str) -> None: route_decorators = get_route_decorators(node, framework) if route_decorators: route_path = get_route_path(node, framework) - function_params = get_function_parameters(node, tree) # Get parameters for the route + function_params = get_function_parameters( + node, tree + ) # Get parameters for the route function_docstring = get_function_docstring(node) # Extract docstring - routes.append({ - 'name': node.name, - 'path': route_path, - 'methods': route_decorators, - 'parameters': function_params, # Add parameters to the route - 'description': function_docstring # Add the docstring as the description - }) + routes.append( + { + "name": node.name, + "path": route_path, + "methods": route_decorators, + "parameters": function_params, # Add parameters to the route + "description": function_docstring, # Add the docstring as the description + } + ) # Generate YAML structure - output_structure = { - "prompt_targets": [] - } + output_structure = {"prompt_targets": []} for route in routes: target = { - "name": route['name'], + "name": route["name"], "endpoint": [ { "name": "app_server", - "path": route['path'], + "path": route["path"], } ], - "description": route['description'], # Use extracted docstring + "description": route["description"], # Use extracted docstring "parameters": [ { - "name": param['name'], - "type": param['type'], + "name": param["name"], + "type": param["type"], "description": 
f"{param['description']}", - **({"default": param['default']} if "default" in param and param['default'] is not None else {}), # Only add default if it's set - "required": param['required'] - } for param in route['parameters'] - ] + **( + {"default": param["default"]} + if "default" in param and param["default"] is not None + else {} + ), # Only add default if it's set + "required": param["required"], + } + for param in route["parameters"] + ], } - if route['name'] == "default": + if route["name"] == "default": # Special case for `information_extraction` based on your YAML format target["type"] = "default" target["auto-llm-dispatch-on-response"] = True @@ -274,9 +340,12 @@ def generate_prompt_targets(input_file_path: str) -> None: output_structure["prompt_targets"].append(target) # Output as YAML - print(yaml.dump(output_structure, sort_keys=False,default_flow_style=False, indent=3)) + print( + yaml.dump(output_structure, sort_keys=False, default_flow_style=False, indent=3) + ) -if __name__ == '__main__': + +if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: python targets.py ") sys.exit(1) diff --git a/arch/tools/test/fastapi_test.py b/arch/tools/test/fastapi_test.py index 1f25a0e1..bedac8bd 100644 --- a/arch/tools/test/fastapi_test.py +++ b/arch/tools/test/fastapi_test.py @@ -4,12 +4,23 @@ from typing import List, Dict, Set app = FastAPI() + class User(BaseModel): - name: str = Field("John Doe", description="The name of the user.") # Default value and description for name + name: str = Field( + "John Doe", description="The name of the user." + ) # Default value and description for name location: int = None - age: int = Field(30, description="The age of the user.") # Default value and description for age - tags: Set[str] = Field(default_factory=set, description="A set of tags associated with the user.") # Default empty set and description for tags - metadata: Dict[str, int] = Field(default_factory=dict, description="A dictionary storing metadata about the user, with string keys and integer values.") # Default empty dict and description for metadata + age: int = Field( + 30, description="The age of the user." + ) # Default value and description for age + tags: Set[str] = Field( + default_factory=set, description="A set of tags associated with the user." + ) # Default empty set and description for tags + metadata: Dict[str, int] = Field( + default_factory=dict, + description="A dictionary storing metadata about the user, with string keys and integer values.", + ) # Default empty dict and description for metadata + @app.get("/agent/default") async def default(request: User): @@ -19,6 +30,7 @@ async def default(request: User): """ return {"info": f"Query: {request.name}, Count: {request.age}"} + @app.post("/agent/action") async def reboot_network_device(device_id: str, confirmation: str): """ diff --git a/arch/tools/utils.py b/arch/tools/utils.py index eb72870c..5ed3e0fc 100644 --- a/arch/tools/utils.py +++ b/arch/tools/utils.py @@ -6,6 +6,7 @@ import shlex import yaml import json + def run_docker_compose_ps(compose_file, env): """ Check if all Docker Compose services are in a healthy state. 
@@ -16,20 +17,31 @@ def run_docker_compose_ps(compose_file, env): try: # Run `docker-compose ps` to get the health status of each service ps_process = subprocess.Popen( - ["docker", "compose", "-p", "arch", "ps", "--format", "table{{.Service}}\t{{.State}}\t{{.Ports}}"], + [ + "docker", + "compose", + "-p", + "arch", + "ps", + "--format", + "table{{.Service}}\t{{.State}}\t{{.Ports}}", + ], cwd=os.path.dirname(compose_file), stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, start_new_session=True, - env=env + env=env, ) # Capture the output of `docker-compose ps` services_status, error_output = ps_process.communicate() # Check if there is any error output if error_output: - print(f"Error while checking service status:\n{error_output}", file=os.sys.stderr) + print( + f"Error while checking service status:\n{error_output}", + file=os.sys.stderr, + ) return {} services = parse_docker_compose_ps_output(services_status) @@ -39,26 +51,31 @@ def run_docker_compose_ps(compose_file, env): print(f"Failed to check service status. Error:\n{e.stderr}") return e -#Helper method to print service status + +# Helper method to print service status def print_service_status(services): print(f"{'Service Name':<25} {'State':<20} {'Ports'}") - print("="*72) + print("=" * 72) for service_name, info in services.items(): - status = info['STATE'] - ports = info['PORTS'] + status = info["STATE"] + ports = info["PORTS"] print(f"{service_name:<25} {status:<20} {ports}") -#check for states based on the states passed in + +# check for states based on the states passed in def check_services_state(services, states): for service_name, service_info in services.items(): - status = service_info['STATE'].lower() # Convert status to lowercase for easier comparison + status = service_info[ + "STATE" + ].lower() # Convert status to lowercase for easier comparison if any(state in status for state in states): return True return False + def get_llm_provider_access_keys(arch_config_file): - with open(arch_config_file, 'r') as file: + with open(arch_config_file, "r") as file: arch_config = file.read() arch_config_yaml = yaml.safe_load(arch_config) @@ -70,22 +87,23 @@ def get_llm_provider_access_keys(arch_config_file): return access_key_list + def load_env_file_to_dict(file_path): env_dict = {} # Open and read the .env file - with open(file_path, 'r') as file: + with open(file_path, "r") as file: for line in file: # Strip any leading/trailing whitespaces line = line.strip() # Skip empty lines and comments - if not line or line.startswith('#'): + if not line or line.startswith("#"): continue # Split the line into key and value at the first '=' sign - if '=' in line: - key, value = line.split('=', 1) + if "=" in line: + key, value = line.split("=", 1) key = key.strip() value = value.strip() @@ -94,6 +112,7 @@ def load_env_file_to_dict(file_path): return env_dict + def parse_docker_compose_ps_output(output): # Split the output into lines lines = output.strip().splitlines() @@ -111,10 +130,7 @@ def parse_docker_compose_ps_output(output): parts = line.split() # Create a dictionary entry using the header names - service_info = { - headers[1]: parts[1], # State - headers[2]: parts[2] # Ports - } + service_info = {headers[1]: parts[1], headers[2]: parts[2]} # State # Ports # Add to the result dictionary using the service name as the key services[parts[0]] = service_info diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index 63c034ae..1fe10e12 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -7,47 +7,53 @@ from 
 load_dotenv()

-OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
 CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
 MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")

-ARCH_STATE_HEADER = 'x-arch-state'
+ARCH_STATE_HEADER = "x-arch-state"

 log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
-client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))
+client = OpenAI(
+    api_key=OPENAI_API_KEY,
+    base_url=CHAT_COMPLETION_ENDPOINT,
+    http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
+)
+

 def predict(message, state):
-    if 'history' not in state:
-        state['history'] = []
+    if "history" not in state:
+        state["history"] = []

     history = state.get("history")
     history.append({"role": "user", "content": message})
     log.info("history: ", history)
     # Custom headers
     custom_headers = {
-        'x-arch-openai-api-key': f"{OPENAI_API_KEY}",
-        'x-arch-mistral-api-key': f"{MISTRAL_API_KEY}",
-        'x-arch-deterministic-provider': 'openai',
+        "x-arch-openai-api-key": f"{OPENAI_API_KEY}",
+        "x-arch-mistral-api-key": f"{MISTRAL_API_KEY}",
+        "x-arch-deterministic-provider": "openai",
     }

     metadata = None
-    if 'arch_state' in state:
-        metadata = {ARCH_STATE_HEADER: state['arch_state']}
+    if "arch_state" in state:
+        metadata = {ARCH_STATE_HEADER: state["arch_state"]}

     try:
-        raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
-            messages = history,
-            temperature=1.0,
-            metadata=metadata,
-            extra_headers=custom_headers
-        )
+        raw_response = client.chat.completions.with_raw_response.create(
+            model=MODEL_NAME,
+            messages=history,
+            temperature=1.0,
+            metadata=metadata,
+            extra_headers=custom_headers,
+        )
     except Exception as e:
-       log.info(e)
-       # remove last user message in case of exception
-       history.pop()
-       log.info("Error calling gateway API: {}".format(e.message))
-       raise gr.Error("Error calling gateway API: {}".format(e.message))
+        log.info(e)
+        # remove last user message in case of exception
+        history.pop()
+        log.info("Error calling gateway API: {}".format(e))
+        raise gr.Error("Error calling gateway API: {}".format(e))

     log.info("raw_response: ", raw_response.text)
     response = raw_response.parse()
@@ -57,24 +63,33 @@ def predict(message, state):
     response_json = json.loads(raw_response.text)
     arch_state = None
     if response_json:
-        metadata = response_json.get('metadata', {})
-        if metadata:
-            arch_state = metadata.get(ARCH_STATE_HEADER, None)
+        metadata = response_json.get("metadata", {})
+        if metadata:
+            arch_state = metadata.get(ARCH_STATE_HEADER, None)

     if arch_state:
-        state['arch_state'] = arch_state
+        state["arch_state"] = arch_state

     content = response.choices[0].message.content
     history.append({"role": "assistant", "content": content, "model": response.model})
-    messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
+    messages = [
+        (history[i]["content"], history[i + 1]["content"])
+        for i in range(0, len(history) - 1, 2)
+    ]
     return messages, state

+
 with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
     print("Starting Demo...")
     chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
     state = gr.State({})
     with gr.Row():
-        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
+        txt = gr.Textbox(
+            show_label=False,
+            placeholder="Enter text and press enter",
press enter", + scale=1, + autofocus=True, + ) txt.submit(predict, [txt, state], [chatbot, state]) diff --git a/chatbot_ui/app/run_stream.py b/chatbot_ui/app/run_stream.py index 458508d3..8be5a16b 100644 --- a/chatbot_ui/app/run_stream.py +++ b/chatbot_ui/app/run_stream.py @@ -5,26 +5,32 @@ from openai import OpenAI import gradio as gr api_key = os.getenv("OPENAI_API_KEY") -CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1") +CHAT_COMPLETION_ENDPOINT = os.getenv( + "CHAT_COMPLETION_ENDPOINT", "https://api.openai.com/v1" +) client = OpenAI(api_key=api_key, base_url=CHAT_COMPLETION_ENDPOINT) + def predict(message, history): history_openai_format = [] for human, assistant in history: - history_openai_format.append({"role": "user", "content": human }) - history_openai_format.append({"role": "assistant", "content":assistant}) + history_openai_format.append({"role": "user", "content": human}) + history_openai_format.append({"role": "assistant", "content": assistant}) history_openai_format.append({"role": "user", "content": message}) - response = client.chat.completions.create(model='gpt-3.5-turbo', - messages= history_openai_format, - temperature=1.0, - stream=True) + response = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=history_openai_format, + temperature=1.0, + stream=True, + ) partial_message = "" for chunk in response: if chunk.choices[0].delta.content is not None: - partial_message = partial_message + chunk.choices[0].delta.content - yield partial_message + partial_message = partial_message + chunk.choices[0].delta.content + yield partial_message + gr.ChatInterface(predict).launch(server_name="0.0.0.0", server_port=8081) diff --git a/demos/function_calling/api_server/app/main.py b/demos/function_calling/api_server/app/main.py index edd58a19..cde06795 100644 --- a/demos/function_calling/api_server/app/main.py +++ b/demos/function_calling/api_server/app/main.py @@ -6,56 +6,54 @@ import logging from pydantic import BaseModel -logger = logging.getLogger('uvicorn.error') +logger = logging.getLogger("uvicorn.error") logger.setLevel(logging.INFO) app = FastAPI() + @app.get("/healthz") async def healthz(): - return { - "status": "ok" - } + return {"status": "ok"} + class WeatherRequest(BaseModel): - city: str - days: int = 7 - units: str = "Farenheit" + city: str + days: int = 7 + units: str = "Farenheit" @app.post("/weather") async def weather(req: WeatherRequest, res: Response): - weather_forecast = { "city": req.city, "temperature": [], "units": req.units, } for i in range(7): - min_temp = random.randrange(50,90) - max_temp = random.randrange(min_temp+5, min_temp+20) - if req.units.lower() == "celsius" or req.units.lower() == "c": - min_temp = (min_temp - 32) * 5.0/9.0 - max_temp = (max_temp - 32) * 5.0/9.0 - weather_forecast["temperature"].append({ - "date": str(date.today() + timedelta(days=i)), - "temperature": { - "min": min_temp, - "max": max_temp - }, - "units": req.units, - "query_time": str(datetime.now(timezone.utc)) - }) + min_temp = random.randrange(50, 90) + max_temp = random.randrange(min_temp + 5, min_temp + 20) + if req.units.lower() == "celsius" or req.units.lower() == "c": + min_temp = (min_temp - 32) * 5.0 / 9.0 + max_temp = (max_temp - 32) * 5.0 / 9.0 + weather_forecast["temperature"].append( + { + "date": str(date.today() + timedelta(days=i)), + "temperature": {"min": min_temp, "max": max_temp}, + "units": req.units, + "query_time": str(datetime.now(timezone.utc)), + } + ) return weather_forecast class 
-    policy_number: str
+    policy_number: str
+

 @app.post("/insurance_claim_details")
 async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Response):
-
     claim_details = {
         "policy_number": req.policy_number,
         "claim_status": "Approved",
@@ -68,26 +66,25 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon


 class DefaultTargetRequest(BaseModel):
-    arch_messages: list
+    arch_messages: list
+

 @app.post("/default_target")
 async def default_target(req: DefaultTargetRequest, res: Response):
     logger.info(f"Received arch_messages: {req.arch_messages}")
     resp = {
-      "choices": [
-          {
+        "choices": [
+            {
                 "message": {
-                  "role": "assistant",
-                  "content": "hello world from api server"
+                    "role": "assistant",
+                    "content": "hello world from api server",
                 },
                 "finish_reason": "completed",
-                "index": 0
-          }
-      ],
-      "model": "api_server",
-      "usage": {
-        "completion_tokens": 0
+                "index": 0,
             }
-    }
+        ],
+        "model": "api_server",
+        "usage": {"completion_tokens": 0},
+    }
     logger.info(f"sending response: {json.dumps(resp)}")
     return resp
diff --git a/demos/insurance_agent/insurance_agent_main.py b/demos/insurance_agent/insurance_agent_main.py
index 6688da12..3143342f 100644
--- a/demos/insurance_agent/insurance_agent_main.py
+++ b/demos/insurance_agent/insurance_agent_main.py
@@ -3,28 +3,40 @@ from pydantic import BaseModel, Field

 app = FastAPI()

+
 class Conversation(BaseModel):
     arch_messages: list

+
 class PolicyCoverageRequest(BaseModel):
-    policy_type: str = Field(..., description="The type of a policy held by the customer For, e.g. car, boat, house, motorcycle)")
+    policy_type: str = Field(
+        ...,
+        description="The type of policy held by the customer (e.g. car, boat, house, motorcycle)",
+    )

+
 class PolicyInitiateRequest(PolicyCoverageRequest):
-    deductible: float = Field(..., description="The deductible amount set of the policy")
+    deductible: float = Field(
+        ..., description="The deductible amount set for the policy"
+    )
+

 class ClaimUpdate(BaseModel):
     claim_id: str
     notes: str  # Status or details of the claim

+
 class DeductibleUpdate(BaseModel):
     policy_id: str
     deductible: float

+
 class CoverageResponse(BaseModel):
     policy_type: str
     coverage: str  # Description of coverage
     premium: float  # The premium cost

+
 # Get information about policy coverage
 @app.post("/policy/coverage", response_model=CoverageResponse)
 async def get_policy_coverage(req: PolicyCoverageRequest):
@@ -32,10 +44,22 @@ async def get_policy_coverage(req: PolicyCoverageRequest):
     Retrieve the coverage details for a given policy type (car, boat, house, motorcycle).
""" policy_coverage = { - "car": {"coverage": "Full car coverage with collision, liability", "premium": 500.0}, - "boat": {"coverage": "Full boat coverage including theft and storm damage", "premium": 700.0}, - "house": {"coverage": "Full house coverage including fire, theft, flood", "premium": 1000.0}, - "motorcycle": {"coverage": "Full motorcycle coverage with liability", "premium": 400.0}, + "car": { + "coverage": "Full car coverage with collision, liability", + "premium": 500.0, + }, + "boat": { + "coverage": "Full boat coverage including theft and storm damage", + "premium": 700.0, + }, + "house": { + "coverage": "Full house coverage including fire, theft, flood", + "premium": 1000.0, + }, + "motorcycle": { + "coverage": "Full motorcycle coverage with liability", + "premium": 400.0, + }, } if req.policy_type not in policy_coverage: @@ -44,9 +68,10 @@ async def get_policy_coverage(req: PolicyCoverageRequest): return CoverageResponse( policy_type=req.policy_type, coverage=policy_coverage[req.policy_type]["coverage"], - premium=policy_coverage[req.policy_type]["premium"] + premium=policy_coverage[req.policy_type]["premium"], ) + # Initiate policy coverage @app.post("/policy/initiate") async def initiate_policy(policy_request: PolicyInitiateRequest): @@ -56,7 +81,11 @@ async def initiate_policy(policy_request: PolicyInitiateRequest): if policy_request.policy_type not in ["car", "boat", "house", "motorcycle"]: raise HTTPException(status_code=400, detail="Invalid policy type") - return {"message": f"Policy initiated for {policy_request.policy_type}", "deductible": policy_request.deductible} + return { + "message": f"Policy initiated for {policy_request.policy_type}", + "deductible": policy_request.deductible, + } + # Update claim details @app.post("/policy/claim") @@ -65,8 +94,11 @@ async def update_claim(req: ClaimUpdate): Update the status or details of a claim. """ # For simplicity, this is a mock update response - return {"message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated", - "update": claim_update.notes} + return { + "message": f"Claim {claim_update.claim_id} for policy {claim_update.claim_id} has been updated", + "update": claim_update.notes, + } + # Update deductible amount @app.post("/policy/deductible") @@ -75,8 +107,11 @@ async def update_deductible(deductible_update: DeductibleUpdate): Update the deductible amount for a specific policy. """ # For simplicity, this is a mock update response - return {"message": f"Deductible for policy {deductible_update.policy_id} has been updated", - "new_deductible": deductible_update.deductible} + return { + "message": f"Deductible for policy {deductible_update.policy_id} has been updated", + "new_deductible": deductible_update.deductible, + } + # Post method for policy Q/A @app.post("/policy/qa") @@ -86,21 +121,20 @@ async def policy_qa(conversation: Conversation): It forwards the conversation to the OpenAI client via a local proxy and returns the response. 
""" return { - "choices": [ - { + "choices": [ + { "message": { - "role": "assistant", - "content": "I am a helpful insurance agent, and can only help with insurance things" + "role": "assistant", + "content": "I am a helpful insurance agent, and can only help with insurance things", }, "finish_reason": "completed", - "index": 0 - } - ], - "model": "insurance_agent", - "usage": { - "completion_tokens": 0 + "index": 0, } - } + ], + "model": "insurance_agent", + "usage": {"completion_tokens": 0}, + } + # Run the app using: # uvicorn main:app --reload diff --git a/demos/network_agent/main.py b/demos/network_agent/main.py index 682f89ae..0f9a6ee0 100644 --- a/demos/network_agent/main.py +++ b/demos/network_agent/main.py @@ -4,10 +4,14 @@ from typing import List, Optional app = FastAPI() + # Define the request model class DeviceSummaryRequest(BaseModel): device_ids: List[int] - time_range: Optional[int] = Field(default=7, description="Time range in days, defaults to 7") + time_range: Optional[int] = Field( + default=7, description="Time range in days, defaults to 7" + ) + # Define the response model class DeviceStatistics(BaseModel): @@ -15,18 +19,23 @@ class DeviceStatistics(BaseModel): time_range: str data: str + class DeviceSummaryResponse(BaseModel): statistics: List[DeviceStatistics] # Request model for device reboot + + class DeviceRebootRequest(BaseModel): device_ids: List[int] + # Response model for the device reboot class CoverageResponse(BaseModel): status: str summary: dict + @app.post("/agent/device_reboot", response_model=CoverageResponse) def reboot_network_device(request_data: DeviceRebootRequest): """ @@ -38,20 +47,21 @@ def reboot_network_device(request_data: DeviceRebootRequest): # Validate 'device_ids' (This is already validated by Pydantic, but additional logic can be added if needed) if not device_ids: - raise HTTPException(status_code=400, detail="'device_ids' parameter is required") + raise HTTPException( + status_code=400, detail="'device_ids' parameter is required" + ) # Simulate reboot operation and return the response statistics = [] for device_id in device_ids: # Placeholder for actual data retrieval or device reboot logic - stats = { - "data": f"Device {device_id} has been successfully rebooted." - } + stats = {"data": f"Device {device_id} has been successfully rebooted."} statistics.append(stats) # Return the response with a summary return CoverageResponse(status="success", summary={"device_ids": device_ids}) + # Post method for device summary @app.post("/agent/device_summary", response_model=DeviceSummaryResponse) def get_device_summary(request: DeviceSummaryRequest): @@ -77,6 +87,7 @@ def get_device_summary(request: DeviceSummaryRequest): return DeviceSummaryResponse(statistics=statistics) + @app.post("/agent/network_summary") async def policy_qa(): """ @@ -84,21 +95,20 @@ async def policy_qa(): It forwards the conversation to the OpenAI client via a local proxy and returns the response. 
""" return { - "choices": [ - { + "choices": [ + { "message": { - "role": "assistant", - "content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them" + "role": "assistant", + "content": "I am a helpful networking agent, and I can help you get status for network devices or reboot them", }, "finish_reason": "completed", - "index": 0 - } - ], - "model": "network_agent", - "usage": { - "completion_tokens": 0 + "index": 0, } - } + ], + "model": "network_agent", + "usage": {"completion_tokens": 0}, + } + if __name__ == "__main__": app.run(debug=True) diff --git a/demos/network_agent/utils.py b/demos/network_agent/utils.py index f02fa2e5..18b782c5 100644 --- a/demos/network_agent/utils.py +++ b/demos/network_agent/utils.py @@ -11,6 +11,7 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) + def load_sql(): # Example Usage conn = sqlite3.connect(":memory:") @@ -26,6 +27,7 @@ def load_sql(): return conn + # Function to convert natural language time expressions to "X {time} ago" format def convert_to_ago_format(expression): # Define patterns for different time units diff --git a/docs/source/build_with_arch/includes/agent/parameter_handling.py b/docs/source/build_with_arch/includes/agent/parameter_handling.py index e99893dd..cf2aebc7 100644 --- a/docs/source/build_with_arch/includes/agent/parameter_handling.py +++ b/docs/source/build_with_arch/includes/agent/parameter_handling.py @@ -13,9 +13,10 @@ def get_device_summary(): # Validate 'device_ids' parameter device_ids = data.get("device_ids") if not device_ids or not isinstance(device_ids, list): - return jsonify( - {"error": "'device_ids' parameter is required and must be a list"} - ), 400 + return ( + jsonify({"error": "'device_ids' parameter is required and must be a list"}), + 400, + ) # Validate 'time_range' parameter (optional, defaults to 7) time_range = data.get("time_range", 7) diff --git a/docs/source/build_with_arch/includes/rag/intent_detection.py b/docs/source/build_with_arch/includes/rag/intent_detection.py index d4fea371..df6f1c0a 100644 --- a/docs/source/build_with_arch/includes/rag/intent_detection.py +++ b/docs/source/build_with_arch/includes/rag/intent_detection.py @@ -119,9 +119,10 @@ def process_rag(): intent_changed = True else: # Invalid value provided - return jsonify( - {"error": "Invalid value for x-arch-prompt-intent-change header"} - ), 400 + return ( + jsonify({"error": "Invalid value for x-arch-prompt-intent-change header"}), + 400, + ) # Update user conversation based on intent change memory = update_user_conversation(user_id, client_messages, intent_changed) diff --git a/docs/source/build_with_arch/includes/rag/parameter_handling.py b/docs/source/build_with_arch/includes/rag/parameter_handling.py index e99893dd..cf2aebc7 100644 --- a/docs/source/build_with_arch/includes/rag/parameter_handling.py +++ b/docs/source/build_with_arch/includes/rag/parameter_handling.py @@ -13,9 +13,10 @@ def get_device_summary(): # Validate 'device_ids' parameter device_ids = data.get("device_ids") if not device_ids or not isinstance(device_ids, list): - return jsonify( - {"error": "'device_ids' parameter is required and must be a list"} - ), 400 + return ( + jsonify({"error": "'device_ids' parameter is required and must be a list"}), + 400, + ) # Validate 'time_range' parameter (optional, defaults to 7) time_range = data.get("time_range", 7) diff --git a/docs/source/conf.py b/docs/source/conf.py index 431ce72d..00ca3311 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ 
 from sphinxawesome_theme import ThemeOptions
 from sphinxawesome_theme.postprocess import Icons

-project = 'Arch Docs'
-copyright = '2024, Katanemo Labs, Inc'
-author = 'Katanemo Labs, Inc'
-release = ' v0.1'
+project = "Arch Docs"
+copyright = "2024, Katanemo Labs, Inc"
+author = "Katanemo Labs, Inc"
+release = " v0.1"

 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

-root_doc = 'index'
+root_doc = "index"
 nitpicky = True
 add_module_names = False
@@ -33,23 +33,23 @@ extensions = [
     "sphinx.ext.extlinks",
     "sphinx.ext.viewcode",
     "sphinx_sitemap",
-    "sphinx_design"
+    "sphinx_design",
 ]

 # Paths that contain templates, relative to this directory.
-templates_path = ['_templates']
+templates_path = ["_templates"]

 # List of patterns, relative to source directory, that match files and directories
 # to ignore when looking for source files.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]

 # -- Options for HTML output -------------------------------------------------

-html_theme = 'sphinxawesome_theme' # You can change the theme to 'sphinx_rtd_theme' or another of your choice.
+html_theme = "sphinxawesome_theme"  # You can change the theme to 'sphinx_rtd_theme' or another of your choice.
 html_title = project + release
 html_permalinks_icon = Icons.permalinks_icon
-html_favicon = '_static/favicon.ico'
-html_logo = '_static/favicon.ico' # Specify the path to the logo image file (make sure the logo is in the _static directory)
+html_favicon = "_static/favicon.ico"
+html_logo = "_static/favicon.ico"  # Specify the path to the logo image file (make sure the logo is in the _static directory)
 html_last_updated_fmt = ""
 html_use_index = False  # Don't create index
 html_domain_indices = False  # Don't need module indices
@@ -57,10 +57,14 @@ html_copy_source = False  # Don't need sources

 html_show_sphinx = False

-html_baseurl = './docs'
+html_baseurl = "./docs"

 html_sidebars = {
-    "**": ['analytics.html', "sidebar_main_nav_links.html", "sidebar_toc.html", ]
+    "**": [
+        "analytics.html",
+        "sidebar_main_nav_links.html",
+        "sidebar_toc.html",
+    ]
 }

 theme_options = ThemeOptions(
@@ -102,7 +106,7 @@ html_theme_options = asdict(theme_options)
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = ['_static']
+html_static_path = ["_static"]

 pygments_style = "lovelace"
 pygments_style_dark = "github-dark"
@@ -111,10 +115,11 @@ sitemap_url_scheme = "{link}"

 # Add this configuration at the bottom of your conf.py
 html_context = {
-    'google_analytics_id': 'G-K2LXXSX6HB',  # Replace with your Google Analytics tracking ID
+    "google_analytics_id": "G-K2LXXSX6HB",  # Replace with your Google Analytics tracking ID
 }

-templates_path = ['_templates']
+templates_path = ["_templates"]
+

 # -- Register a :confval: interpreted text role ----------------------------------
 def setup(app: Sphinx) -> None:
@@ -138,4 +143,4 @@ def setup(app: Sphinx) -> None:
         ],
     )

-    app.add_css_file('_static/custom.css')
+    app.add_css_file("_static/custom.css")
diff --git a/model_server/app/__init__.py b/model_server/app/__init__.py
index 0e103e11..c2d4ff43 100644
--- a/model_server/app/__init__.py
+++ b/model_server/app/__init__.py
@@ -35,11 +35,13 @@ def start_server():
         print("Server is already running. Use 'model_server restart' to restart it.")
         sys.exit(1)

-    print(f"Starting Archgw Model Server - Loading some awesomeness, this may take a little time.)")
+    print(
+        "Starting Archgw Model Server - Loading some awesomeness, this may take a little time."
+    )
     process = subprocess.Popen(
         ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "51000"],
         start_new_session=True,
-        stdout=subprocess.DEVNULL, # Suppress standard output. There is a logger that model_server prints to
+        stdout=subprocess.DEVNULL,  # Suppress standard output. There is a logger that model_server prints to
         stderr=subprocess.DEVNULL,  # Suppress standard error. There is a logger that model_server prints to
     )
diff --git a/model_server/app/load_models.py b/model_server/app/load_models.py
index 60b62daf..a13578ff 100644
--- a/model_server/app/load_models.py
+++ b/model_server/app/load_models.py
@@ -7,7 +7,6 @@ from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForSequenc


 def get_device():
-
     if torch.cuda.is_available():
         device = "cuda"
     elif torch.backends.mps.is_available():
@@ -19,13 +18,15 @@ def get_device():
     return device


-def load_transformers(model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")):
+def load_transformers(
+    model_name=os.getenv("MODELS", "katanemo/bge-large-en-v1.5-onnx")
+):
     print("Loading Embedding Model")
     transformers = {}
     device = get_device()
     transformers["tokenizer"] = AutoTokenizer.from_pretrained(model_name)
     transformers["model"] = ORTModelForFeatureExtraction.from_pretrained(
-        model_name, device_map = device
+        model_name, device_map=device
     )
     transformers["model_name"] = model_name

@@ -62,7 +63,9 @@ def load_guard_model(
     return guard_model


-def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")):
+def load_zero_shot_models(
+    model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deberta-base-nli-onnx")
+):
     zero_shot_model = {}
     device = get_device()
     zero_shot_model["model"] = ORTModelForSequenceClassification.from_pretrained(
@@ -81,5 +84,6 @@ def load_zero_shot_models(model_name=os.getenv("ZERO_SHOT_MODELS", "katanemo/deb
     zero_shot_model["label_map"] = label_map

     return zero_shot_model

+
 if __name__ == "__main__":
     print(get_device())
diff --git a/model_server/app/main.py b/model_server/app/main.py
index 88db3701..107f20c2 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -7,7 +7,12 @@ from app.load_models import (
     get_device,
 )
 import os
-from app.utils import GuardHandler, split_text_into_chunks, load_yaml_config, get_model_server_logger
+from app.utils import (
+    GuardHandler,
+    split_text_into_chunks,
+    load_yaml_config,
+    get_model_server_logger,
+)
 import torch
 import yaml
 import string
@@ -39,6 +44,7 @@ guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)

 app = FastAPI()

+
 class EmbeddingRequest(BaseModel):
     input: str
     model: str
@@ -84,6 +90,7 @@ async def embedding(req: EmbeddingRequest, res: Response):
     }
     return {"data": data, "model": req.model, "object": "list", "usage": usage}

+
 class GuardRequest(BaseModel):
     input: str
     task: str
diff --git a/model_server/app/utils.py b/model_server/app/utils.py
index 2a3fe5c0..f521afd7 100644
--- a/model_server/app/utils.py
+++ b/model_server/app/utils.py
@@ -9,6 +9,7 @@ import logging

 logger_instance = None

+
 def load_yaml_config(file_name):
     # Load the YAML file from the package
     yaml_path = pkg_resources.resource_filename("app", file_name)
@@ -138,6 +139,7 @@ class GuardHandler:
         }
         return result_dict

+
 def get_model_server_logger():
     global logger_instance

@@ -164,8 +166,8 @@ def get_model_server_logger():
             level=logging.INFO,
             format="%(asctime)s - %(levelname)s - %(message)s",
             handlers=[
-                logging.FileHandler(log_file_path, mode='w'),  # Overwrite logs in file
-            ]
+                logging.FileHandler(log_file_path, mode="w"),  # Overwrite logs in file
+            ],
         )
     except (PermissionError, OSError) as e:
         # Dont' fallback to console logging if there are issues writing to the log file
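
The black hook added to .pre-commit-config.yaml at the top of this patch is what produces the mechanical reformatting seen throughout it (quote normalization, exploded call arguments with magic trailing commas, two blank lines between top-level definitions). A minimal sketch of reproducing the same pass locally — these are standard pre-commit/black CLI commands, assumed rather than taken from this diff:

    # one-time setup in a clone of the repo
    pip install pre-commit
    pre-commit install                    # register the hooks with git

    # one-off pass over the whole tree, equivalent to what this change applies
    pre-commit run black --all-files

With the hook installed, black 23.1.0 (pinned by `rev` above) rewrites staged Python files on every commit, so subsequent diffs should no longer mix formatting churn with functional changes.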