mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Merge branch 'main' into adil/update_arch_config_format
This commit is contained in:
commit
0977bf5b34
16 changed files with 2197 additions and 316 deletions
|
|
@ -8,7 +8,7 @@ RUN cd prompt_gateway && cargo build --release --target wasm32-wasip1
|
|||
RUN cd llm_gateway && cargo build --release --target wasm32-wasip1
|
||||
|
||||
# copy built filter into envoy image
|
||||
FROM envoyproxy/envoy:v1.32-latest as envoy
|
||||
FROM docker.io/envoyproxy/envoy:v1.32-latest as envoy
|
||||
|
||||
#Build config generator, so that we have a single build image for both Rust and Python
|
||||
FROM python:3.12-slim as arch
|
||||
|
|
|
|||
|
|
@ -2,112 +2,26 @@ import subprocess
|
|||
import os
|
||||
import time
|
||||
import sys
|
||||
import glob
|
||||
import docker
|
||||
from docker.errors import DockerException
|
||||
from cli.utils import getLogger, update_docker_host_env
|
||||
from cli.utils import getLogger
|
||||
from cli.consts import (
|
||||
ARCHGW_DOCKER_IMAGE,
|
||||
ARCHGW_DOCKER_NAME,
|
||||
KATANEMO_LOCAL_MODEL_LIST,
|
||||
MODEL_SERVER_LOG_FILE,
|
||||
ACCESS_LOG_FILES,
|
||||
)
|
||||
from huggingface_hub import snapshot_download
|
||||
from dotenv import dotenv_values
|
||||
import yaml
|
||||
import subprocess
|
||||
from cli.docker_cli import (
|
||||
docker_container_status,
|
||||
docker_remove_container,
|
||||
docker_start_archgw_detached,
|
||||
docker_stop_container,
|
||||
health_check_endpoint,
|
||||
stream_gateway_logs,
|
||||
)
|
||||
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def start_archgw_docker(
    client, arch_config_file, env, prompt_gateway_port, llm_gateway_port
):
    """
    Launch the arch gateway container in detached mode.

    The container is started through ``client.containers.run`` and whatever
    that call returns is handed back to the caller. Entries in ``env`` are
    applied after the built-in defaults, so they win on key collisions.
    """
    # Host directory that is mounted at /var/log inside the container.
    host_logs_dir = os.path.expanduser("~/archgw_logs")

    published_ports = {
        f"{prompt_gateway_port}/tcp": prompt_gateway_port,
        "10001/tcp": 10001,
        "11000/tcp": 11000,
        f"{llm_gateway_port}/tcp": llm_gateway_port,
        "9901/tcp": 19901,
    }

    mounts = {
        f"{arch_config_file}": {"bind": "/app/arch_config.yaml", "mode": "ro"},
        "/etc/ssl/cert.pem": {"bind": "/etc/ssl/cert.pem", "mode": "ro"},
        host_logs_dir: {"bind": "/var/log"},
    }

    # Defaults first, caller-supplied values layered on top.
    container_env = {
        "OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
        "MODEL_SERVER_PORT": os.getenv("MODEL_SERVER_PORT", "51000"),
    }
    container_env.update(env)

    probe = {
        "test": [
            "CMD",
            "curl",
            "-f",
            f"http://localhost:{prompt_gateway_port}/healthz",
        ],
        "interval": 5000000000,  # 5 seconds
        "timeout": 1000000000,  # 1 seconds
        "retries": 3,
    }

    return client.containers.run(
        name=ARCHGW_DOCKER_NAME,
        image=ARCHGW_DOCKER_IMAGE,
        detach=True,  # Run in detached mode
        ports=published_ports,
        volumes=mounts,
        environment=container_env,
        extra_hosts={"host.docker.internal": "host-gateway"},
        healthcheck=probe,
    )
|
||||
|
||||
|
||||
def stream_gateway_logs(follow):
    """
    Stream logs from the arch gateway service.

    Runs ``docker logs`` for the gateway container, wiring its output to this
    process's stdout/stderr. When ``follow`` is truthy, ``-f`` is appended so
    the command keeps streaming new output. A non-zero exit from docker is
    logged rather than raised.
    """
    log.info("Logs from arch gateway service.")

    # Use the shared constant so the container name cannot drift from the one
    # used when the container is started and stopped.
    options = ["docker", "logs", ARCHGW_DOCKER_NAME]
    if follow:
        options.append("-f")
    try:
        # Run `docker logs` to stream logs from the gateway container
        subprocess.run(
            options,
            check=True,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )
    except subprocess.CalledProcessError as e:
        log.info(f"Failed to stream logs: {str(e)}")
|
||||
|
||||
|
||||
def stream_access_logs(follow):
    """
    Get the archgw access logs

    Expands ``ACCESS_LOG_FILES`` (``~`` and glob wildcards) and runs ``tail``
    over the matching files, adding ``-f`` when ``follow`` is truthy. If no
    file matches, a message is logged and nothing is executed.

    Raises:
        subprocess.CalledProcessError: if ``tail`` exits with a non-zero status.
    """
    log_file_pattern_expanded = os.path.expanduser(ACCESS_LOG_FILES)
    log_files = glob.glob(log_file_pattern_expanded)

    if not log_files:
        # With no file operands `tail` reads from stdin and would appear to
        # hang, so bail out early instead.
        log.info(f"No access log files found matching {log_file_pattern_expanded}")
        return

    stream_command = ["tail"]
    if follow:
        stream_command.append("-f")
    stream_command.extend(log_files)

    subprocess.run(
        stream_command,
        check=True,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
|
||||
|
||||
|
||||
def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
|
||||
"""
|
||||
Start Docker Compose in detached mode and stream logs until services are healthy.
|
||||
|
|
@ -119,73 +33,47 @@ def start_arch(arch_config_file, env, log_timeout=120, foreground=False):
|
|||
log.info("Starting arch gateway")
|
||||
|
||||
try:
|
||||
try:
|
||||
client = docker.from_env()
|
||||
except DockerException as e:
|
||||
# try setting up the docker host environment variable and retry
|
||||
update_docker_host_env()
|
||||
client = docker.from_env()
|
||||
archgw_container_status = docker_container_status(ARCHGW_DOCKER_NAME)
|
||||
if archgw_container_status != "not found":
|
||||
log.info("archgw found in docker, stopping and removing it")
|
||||
docker_stop_container(ARCHGW_DOCKER_NAME)
|
||||
docker_remove_container(ARCHGW_DOCKER_NAME)
|
||||
|
||||
try:
|
||||
container = client.containers.get("archgw")
|
||||
log.info("archgw container found in docker, stopping and removing it")
|
||||
# ensure that previous docker container is stopped and removed
|
||||
container.stop()
|
||||
container.remove()
|
||||
log.info("Stopped and removed archgw container")
|
||||
except docker.errors.NotFound as e:
|
||||
pass
|
||||
|
||||
# parse arch_config_file yaml file and get prompt_gateway_port
|
||||
arch_config_dict = {}
|
||||
with open(arch_config_file) as f:
|
||||
arch_config_dict = yaml.safe_load(f)
|
||||
|
||||
prompt_gateway_port = (
|
||||
arch_config_dict.get("listeners", {})
|
||||
.get("ingress_traffic", {})
|
||||
.get("port", 10000)
|
||||
)
|
||||
llm_gateway_port = (
|
||||
arch_config_dict.get("listeners", {})
|
||||
.get("egress_traffic", {})
|
||||
.get("port", 12000)
|
||||
)
|
||||
|
||||
container = start_archgw_docker(
|
||||
client, arch_config_file, env, prompt_gateway_port, llm_gateway_port
|
||||
return_code, _, archgw_stderr = docker_start_archgw_detached(
|
||||
arch_config_file, os.path.expanduser("~/archgw_logs"), env
|
||||
)
|
||||
if return_code != 0:
|
||||
log.info("Failed to start arch gateway: " + str(return_code))
|
||||
log.info("stderr: " + archgw_stderr)
|
||||
sys.exit(1)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
while True:
|
||||
container = client.containers.get(container.id)
|
||||
health_check_status = health_check_endpoint(
|
||||
"http://localhost:10000/healthz"
|
||||
)
|
||||
archgw_status = docker_container_status(ARCHGW_DOCKER_NAME)
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - start_time
|
||||
|
||||
# Check if timeout is reached
|
||||
if elapsed_time > log_timeout:
|
||||
log.info(f"Stopping log monitoring after {log_timeout} seconds.")
|
||||
log.info(f"stopping log monitoring after {log_timeout} seconds.")
|
||||
break
|
||||
|
||||
container_status = container.attrs["State"]["Health"]["Status"]
|
||||
|
||||
if container_status == "healthy":
|
||||
log.info("Container is healthy!")
|
||||
if health_check_status:
|
||||
log.info("archgw is running and is healthy!")
|
||||
break
|
||||
else:
|
||||
log.info(f"Container health status: {container_status}")
|
||||
log.info(f"archgw status: {archgw_status}, health status: starting")
|
||||
time.sleep(1)
|
||||
|
||||
if foreground:
|
||||
for line in container.logs(stream=True):
|
||||
print(line.decode("utf-8").strip("\n"))
|
||||
stream_gateway_logs(follow=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
log.info("Keyboard interrupt received, stopping arch gateway service.")
|
||||
stop_arch()
|
||||
except docker.errors.APIError as e:
|
||||
log.info(f"Failed to start Arch: {str(e)}")
|
||||
|
||||
|
||||
def stop_arch():
|
||||
|
|
@ -199,10 +87,10 @@ def stop_arch():
|
|||
|
||||
try:
|
||||
subprocess.run(
|
||||
["docker", "stop", "archgw"],
|
||||
["docker", "stop", ARCHGW_DOCKER_NAME],
|
||||
)
|
||||
subprocess.run(
|
||||
["docker", "remove", "archgw"],
|
||||
["docker", "remove", ARCHGW_DOCKER_NAME],
|
||||
)
|
||||
|
||||
log.info("Successfully shut down arch gateway service.")
|
||||
|
|
|
|||
118
arch/tools/cli/docker_cli.py
Normal file
118
arch/tools/cli/docker_cli.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
import subprocess
|
||||
import json
|
||||
import sys
|
||||
import requests # Add this import
|
||||
|
||||
from cli.consts import ARCHGW_DOCKER_IMAGE, ARCHGW_DOCKER_NAME
|
||||
from cli.utils import getLogger
|
||||
|
||||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def docker_container_status(container: str) -> str:
    """
    Return the status string ``docker inspect`` reports for ``container``.

    If ``docker inspect`` exits with a non-zero code, the literal string
    ``"not found"`` is returned instead.
    """
    inspect_cmd = ["docker", "inspect", container]
    completed = subprocess.run(
        inspect_cmd, capture_output=True, text=True, check=False
    )
    if completed.returncode == 0:
        # `docker inspect` prints a JSON array; the first entry is used.
        inspected = json.loads(completed.stdout)
        return inspected[0]["State"]["Status"]
    return "not found"
|
||||
|
||||
|
||||
def docker_stop_container(container: str) -> int:
    """
    Run ``docker stop`` for ``container`` and return the command's exit code.

    Output is captured and discarded; a non-zero exit code is returned to the
    caller rather than raised.
    """
    result = subprocess.run(
        ["docker", "stop", container], capture_output=True, text=True, check=False
    )
    return result.returncode
|
||||
|
||||
|
||||
def docker_remove_container(container: str) -> int:
    """
    Remove ``container`` with the docker CLI and return the command's exit code.

    Output is captured and discarded; a non-zero exit code is returned to the
    caller rather than raised.
    """
    # `rm` is the canonical subcommand; the `remove` spelling is only an alias
    # and is not guaranteed to exist in every docker CLI release.
    result = subprocess.run(
        ["docker", "rm", container], capture_output=True, text=True, check=False
    )
    return result.returncode
|
||||
|
||||
|
||||
def docker_start_archgw_detached(
    arch_config_file: str, logs_path_abs: str, env: dict
) -> tuple[int, str, str]:
    """
    Start the arch gateway container in detached mode via ``docker run -d``.

    Args:
        arch_config_file: Host path mounted read-only at /app/arch_config.yaml.
        logs_path_abs: Host directory mounted read-write at /var/log.
        env: Environment variables passed to the container, one ``-e KEY=VALUE``
            pair per entry.

    Returns:
        A ``(returncode, stdout, stderr)`` tuple from the ``docker run`` command.
        Failures are reported through the return code, not raised.
    """
    env_args = [arg for key, value in env.items() for arg in ("-e", f"{key}={value}")]

    # `-p` takes HOST:CONTAINER, so container port 9901 is published on host
    # port 19901 (not the other way around).
    port_mappings = ["10000:10000", "12000:12000", "19901:9901"]
    port_mappings_args = [arg for port in port_mappings for arg in ("-p", port)]

    volume_mappings = [
        f"{logs_path_abs}:/var/log:rw",
        f"{arch_config_file}:/app/arch_config.yaml:ro",
    ]
    volume_mappings_args = [
        arg for volume in volume_mappings for arg in ("-v", volume)
    ]

    options = [
        "docker",
        "run",
        "-d",
        "--name",
        ARCHGW_DOCKER_NAME,
        *port_mappings_args,
        *volume_mappings_args,
        *env_args,
        "--add-host",
        "host.docker.internal:host-gateway",
        ARCHGW_DOCKER_IMAGE,
    ]

    result = subprocess.run(options, capture_output=True, text=True, check=False)
    return result.returncode, result.stdout, result.stderr
|
||||
|
||||
|
||||
def health_check_endpoint(endpoint: str, timeout: float = 5.0) -> bool:
    """
    Report whether a GET request to ``endpoint`` answers with HTTP 200.

    Args:
        endpoint: URL to probe.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection could block the caller indefinitely.

    Returns:
        True when the response status code is 200; False for any other status
        or when the request fails with a ``requests.RequestException``.
    """
    try:
        response = requests.get(endpoint, timeout=timeout)
    except requests.RequestException:
        # Connection errors and timeouts simply mean "not healthy yet".
        return False
    return response.status_code == 200
|
||||
|
||||
|
||||
def stream_gateway_logs(follow):
    """
    Stream logs from the arch gateway service.

    Invokes ``docker logs`` for the gateway container with output attached to
    this process's stdout/stderr; ``-f`` is added when ``follow`` is truthy.
    A non-zero exit from docker is logged instead of being raised.
    """
    log.info("Logs from arch gateway service.")

    follow_flag = ["-f"] if follow else []
    command = ["docker", "logs", ARCHGW_DOCKER_NAME, *follow_flag]
    try:
        # Run `docker logs` to stream logs from the gateway container
        subprocess.run(command, stdout=sys.stdout, stderr=sys.stderr, check=True)
    except subprocess.CalledProcessError as err:
        log.info(f"Failed to stream logs: {str(err)}")
|
||||
|
||||
|
||||
def docker_validate_archgw_schema(arch_config_file):
    """
    Run ``config_generator.py`` inside a throwaway gateway container.

    ``arch_config_file`` is mounted read-only at /app/arch_config.yaml and the
    image entrypoint is replaced with ``python``. The container is started with
    ``--rm``.

    Returns:
        A ``(returncode, stdout, stderr)`` tuple from the ``docker run`` command.
    """
    config_mount = f"{arch_config_file}:/app/arch_config.yaml:ro"
    command = [
        "docker",
        "run",
        "--rm",
        "-v",
        config_mount,
        "--entrypoint",
        "python",
        ARCHGW_DOCKER_IMAGE,
        "config_generator.py",
    ]
    completed = subprocess.run(command, capture_output=True, text=True, check=False)
    return completed.returncode, completed.stdout, completed.stderr
|
||||
|
|
@ -5,11 +5,12 @@ import subprocess
|
|||
import multiprocessing
|
||||
import importlib.metadata
|
||||
from cli import targets
|
||||
from cli.docker_cli import docker_validate_archgw_schema, stream_gateway_logs
|
||||
from cli.utils import (
|
||||
getLogger,
|
||||
get_llm_provider_access_keys,
|
||||
load_env_file_to_dict,
|
||||
validate_schema,
|
||||
stream_access_logs,
|
||||
)
|
||||
from cli.core import (
|
||||
start_arch_modelserver,
|
||||
|
|
@ -17,12 +18,9 @@ from cli.core import (
|
|||
start_arch,
|
||||
stop_arch,
|
||||
download_models_from_hf,
|
||||
stream_access_logs,
|
||||
stream_gateway_logs,
|
||||
)
|
||||
from cli.consts import (
|
||||
KATANEMO_DOCKERHUB_REPO,
|
||||
KATANEMO_LOCAL_MODEL_LIST,
|
||||
SERVICE_NAME_ARCHGW,
|
||||
SERVICE_NAME_MODEL_SERVER,
|
||||
SERVICE_ALL,
|
||||
|
|
@ -174,17 +172,24 @@ def up(file, path, service, foreground):
|
|||
|
||||
log.info(f"Validating {arch_config_file}")
|
||||
|
||||
try:
|
||||
validate_schema(arch_config_file)
|
||||
except Exception as e:
|
||||
log.info(f"Exiting archgw up: validation failed")
|
||||
log.info(f"Error: {str(e)}")
|
||||
(
|
||||
validation_return_code,
|
||||
validation_stdout,
|
||||
validation_stderr,
|
||||
) = docker_validate_archgw_schema(arch_config_file)
|
||||
if validation_return_code != 0:
|
||||
log.info(f"Error: Validation failed. Exiting")
|
||||
log.info(f"Validation stdout: {validation_stdout}")
|
||||
log.info(f"Validation stderr: {validation_stderr}")
|
||||
sys.exit(1)
|
||||
|
||||
log.info("Starting arch model server and arch gateway")
|
||||
|
||||
# Set the ARCH_CONFIG_FILE environment variable
|
||||
env_stage = {}
|
||||
env_stage = {
|
||||
"OTEL_TRACING_HTTP_ENDPOINT": "http://host.docker.internal:4318/v1/traces",
|
||||
"MODEL_SERVER_PORT": os.getenv("MODEL_SERVER_PORT", "51000"),
|
||||
}
|
||||
env = os.environ.copy()
|
||||
# check if access_keys are present in the config file
|
||||
access_keys = get_llm_provider_access_keys(arch_config_file=arch_config_file)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import ast
|
|||
import sys
|
||||
import yaml
|
||||
from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
FLASK_ROUTE_DECORATORS = ["route", "get", "post", "put", "delete", "patch"]
|
||||
FASTAPI_ROUTE_DECORATORS = ["get", "post", "put", "delete", "patch"]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import glob
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import yaml
|
||||
import logging
|
||||
import docker
|
||||
from docker.errors import DockerException
|
||||
|
||||
from cli.consts import ARCHGW_DOCKER_IMAGE, ARCHGW_DOCKER_NAME
|
||||
from cli.consts import ACCESS_LOG_FILES
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -21,63 +22,6 @@ def getLogger(name="cli"):
|
|||
log = getLogger(__name__)
|
||||
|
||||
|
||||
def update_docker_host_env():
    """
    Update DOCKER_HOST environment variable to use the local Docker socket

    Does nothing when DOCKER_HOST is already set, or when the default socket
    (``DEFAULT_DOCKER_SOCKET`` env var, else /var/run/docker.sock) exists.
    Otherwise DOCKER_HOST is pointed at ``~/.docker/run/docker.sock``.
    """
    if os.getenv("DOCKER_HOST"):
        return

    default_docker_socket = os.getenv("DEFAULT_DOCKER_SOCKET", "/var/run/docker.sock")
    if os.path.exists(default_docker_socket):
        return

    # expanduser honours $HOME when set and still resolves a home directory when
    # it is not, avoiding a bogus "unix://None/..." address.
    home_dir = os.path.expanduser("~")
    docker_host = f"unix://{home_dir}/.docker/run/docker.sock"
    log.info(
        f"Default docker socket {default_docker_socket} not found, using {docker_host}"
    )
    os.environ["DOCKER_HOST"] = docker_host
|
||||
|
||||
|
||||
def validate_schema(arch_config_file: str) -> None:
    """
    Validate ``arch_config_file`` by running ``config_generator.py`` in a container.

    The config file is mounted read-only at /app/arch_config.yaml. The
    validation container is removed once its result has been collected.

    Raises:
        ValueError: if the container exits with a non-zero status (the message
            includes the container logs), or if the Docker API reports an error.
    """
    try:
        try:
            client = docker.from_env()
        except DockerException:
            # try setting up the docker host environment variable and retry
            update_docker_host_env()
            client = docker.from_env()

        container = client.containers.run(
            image=ARCHGW_DOCKER_IMAGE,
            volumes={
                f"{arch_config_file}": {
                    "bind": "/app/arch_config.yaml",
                    "mode": "ro",
                },
            },
            entrypoint=["python", "config_generator.py"],
            detach=True,
        )

        try:
            # Wait for the container to finish and get the exit code
            wait_result = container.wait()
            status_code = wait_result["StatusCode"]

            if status_code != 0:
                # Validation failed (non-zero exit code)
                logs = container.logs().decode()  # Get container logs for debugging
                raise ValueError(
                    f"Validation failed. Container exited with code {status_code}.\nLogs:\n{logs}"
                )

            # Successful validation (exit code 0)
            log.info("Schema validation successful!")
        finally:
            # The container was started detached, so it is not cleaned up
            # automatically; remove it so repeated validations do not pile up.
            try:
                container.remove(force=True)
            except docker.errors.APIError as cleanup_error:
                # Cleanup is best-effort and must not mask the validation result.
                log.info(f"Failed to remove validation container: {cleanup_error}")

    except docker.errors.APIError as e:
        # Handle container creation error
        raise ValueError(f"Failed to create container: {e}") from e
|
||||
|
||||
|
||||
def get_llm_provider_access_keys(arch_config_file):
|
||||
with open(arch_config_file, "r") as file:
|
||||
arch_config = file.read()
|
||||
|
|
@ -127,3 +71,23 @@ def load_env_file_to_dict(file_path):
|
|||
env_dict[key] = value
|
||||
|
||||
return env_dict
|
||||
|
||||
|
||||
def stream_access_logs(follow):
    """
    Get the archgw access logs

    Expands ``ACCESS_LOG_FILES`` (``~`` and glob wildcards) and runs ``tail``
    over the matching files, adding ``-f`` when ``follow`` is truthy. If no
    file matches, a message is logged and nothing is executed.

    Raises:
        subprocess.CalledProcessError: if ``tail`` exits with a non-zero status.
    """
    log_file_pattern_expanded = os.path.expanduser(ACCESS_LOG_FILES)
    log_files = glob.glob(log_file_pattern_expanded)

    if not log_files:
        # With no file operands `tail` reads from stdin and would appear to
        # hang, so bail out early instead.
        log.info(f"No access log files found matching {log_file_pattern_expanded}")
        return

    stream_command = ["tail"]
    if follow:
        stream_command.append("-f")
    stream_command.extend(log_files)

    subprocess.run(
        stream_command,
        check=True,
        stdout=sys.stdout,
        stderr=sys.stderr,
    )
|
||||
|
|
|
|||
2023
arch/tools/poetry.lock
generated
2023
arch/tools/poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,4 +1,4 @@
|
|||
FROM jaegertracing/all-in-one:1.62.0
|
||||
FROM jaegertracing/jaeger:2.3.0
|
||||
HEALTHCHECK \
|
||||
--interval=1s \
|
||||
--timeout=1s \
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ logger = get_model_server_logger()
|
|||
|
||||
|
||||
# Define the client
|
||||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://api.fc.archgw.com/v1")
|
||||
ARCH_ENDPOINT = os.getenv("ARCH_ENDPOINT", "https://archfc.katanemo.dev/v1")
|
||||
ARCH_API_KEY = "EMPTY"
|
||||
ARCH_CLIENT = OpenAI(base_url=ARCH_ENDPOINT, api_key=ARCH_API_KEY)
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ HALLUCINATION_THRESHOLD_DICT = {
|
|||
},
|
||||
MaskToken.PARAMETER_VALUE.value: {
|
||||
"entropy": 0.28,
|
||||
"varentropy": 1.2,
|
||||
"varentropy": 1.4,
|
||||
"probability": 0.8,
|
||||
},
|
||||
}
|
||||
|
|
@ -60,7 +60,7 @@ def check_threshold(entropy: float, varentropy: float, thd: Dict) -> bool:
|
|||
thd (dict): A dictionary containing the threshold values with keys 'entropy' and 'varentropy'.
|
||||
|
||||
Returns:
|
||||
bool: True if either the entropy or varentropy exceeds their respective thresholds, False otherwise.
|
||||
bool: True if both the entropy and varentropy exceed their respective thresholds, False otherwise.
|
||||
"""
|
||||
return entropy > thd["entropy"] and varentropy > thd["varentropy"]
|
||||
|
||||
|
|
@ -82,7 +82,7 @@ def calculate_uncertainty(log_probs: List[float]) -> Tuple[float, float]:
|
|||
token_probs = torch.exp(log_probs)
|
||||
entropy = -torch.sum(log_probs * token_probs, dim=-1) / math.log(2, math.e)
|
||||
varentropy = torch.sum(
|
||||
token_probs * (log_probs / math.log(2, math.e)) + entropy.unsqueeze(-1) ** 2,
|
||||
token_probs * (log_probs / math.log(2, math.e) + entropy.unsqueeze(-1)) ** 2,
|
||||
dim=-1,
|
||||
)
|
||||
return entropy.item(), varentropy.item(), token_probs[0].item()
|
||||
|
|
@ -303,22 +303,30 @@ class HallucinationState:
|
|||
self.mask.append(MaskToken.PARAMETER_VALUE)
|
||||
|
||||
# checking if the parameter doesn't have enum and the token is the first parameter value token
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||
and is_parameter_required(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
# check if function name is in function properties
|
||||
if self.function_name in self.function_properties:
|
||||
if (
|
||||
len(self.mask) > 1
|
||||
and self.mask[-2] != MaskToken.PARAMETER_VALUE
|
||||
and is_parameter_required(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
)
|
||||
and not is_parameter_property(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
"enum",
|
||||
)
|
||||
):
|
||||
if self.parameter_name[-1] not in self.check_parameter_name:
|
||||
self._check_logprob()
|
||||
self.check_parameter_name[self.parameter_name[-1]] = True
|
||||
else:
|
||||
self._check_logprob()
|
||||
self.error_message = f"Function name {self.function_name} not found in function properties"
|
||||
logger.warning(
|
||||
f"Function name {self.function_name} not found in function properties"
|
||||
)
|
||||
and not is_parameter_property(
|
||||
self.function_properties[self.function_name],
|
||||
self.parameter_name[-1],
|
||||
"enum",
|
||||
)
|
||||
):
|
||||
if self.parameter_name[-1] not in self.check_parameter_name:
|
||||
self._check_logprob()
|
||||
self.check_parameter_name[self.parameter_name[-1]] = True
|
||||
else:
|
||||
self.mask.append(MaskToken.NOT_USED)
|
||||
# if the state is parameter value and the token is an end token, change the state
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import time
|
|||
import logging
|
||||
import src.commons.utils as utils
|
||||
|
||||
from src.commons.globals import handler_map
|
||||
from src.commons.globals import ARCH_ENDPOINT, handler_map
|
||||
from src.core.utils.model_utils import (
|
||||
ChatMessage,
|
||||
ChatCompletionResponse,
|
||||
|
|
@ -51,6 +51,8 @@ logging.getLogger("opentelemetry.exporter.otlp.proto.grpc.exporter").setLevel(
|
|||
app = FastAPI()
|
||||
FastAPIInstrumentor().instrument_app(app)
|
||||
|
||||
logger.info(f"using archfc endpoint: {ARCH_ENDPOINT}")
|
||||
|
||||
|
||||
@app.get("/healthz")
|
||||
async def healthz():
|
||||
|
|
|
|||
|
|
@ -54,20 +54,6 @@ def get_hallucination_data_complex():
|
|||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_easy():
    """
    Build a single-turn weather request fixture.

    Returns the request followed by three ``True`` flags, in the same tuple
    layout as the other ``get_*_data`` fixtures visible in this file.
    """
    # Single user turn asking about the weather
    user_turn = Message(role="user", content="How is the weather in Seattle?")

    # Request exposing only the weather tool
    request = ChatMessage(messages=[user_turn], tools=[get_weather_api])

    # model will hallucinate
    return request, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in?")
|
||||
|
|
@ -142,7 +128,6 @@ def get_greeting_data():
|
|||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data_complex,
|
||||
get_hallucination_data_easy,
|
||||
get_complete_data,
|
||||
get_irrelevant_data,
|
||||
get_complete_data_2,
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ test_cases:
|
|||
- role: "assistant"
|
||||
content: "Can you please provide me the days for the weather forecast?"
|
||||
- role: "user"
|
||||
content: "los angeles in 5 days"
|
||||
content: "5 days"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
|
|
@ -82,7 +82,7 @@ test_cases:
|
|||
required: ["location", "days"]
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: true
|
||||
hallucination: false
|
||||
|
||||
- id: "[WEATHER AGENT] - multi turn, single tool, clarification"
|
||||
input:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
@model_server_endpoint = http://localhost:51000
|
||||
@archfc_endpoint = https://api.fc.archgw.com
|
||||
|
||||
@archfc_endpoint = https://archfc.katanemo.dev
|
||||
|
||||
### talk to function calling endpoint
|
||||
POST {{model_server_endpoint}}/function_calling HTTP/1.1
|
||||
|
|
@ -119,7 +118,7 @@ Content-Type: application/json
|
|||
}
|
||||
|
||||
### talk to Arch-Intent directly for completion
|
||||
POST {{archfc_endpoint}}/v1/chat/completions HTTP/1.1
|
||||
POST {{archfc_endpoint}}/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
@model_server_endpoint = http://localhost:51000
|
||||
@archfc_endpoint = https://api.fc.archgw.com
|
||||
@archfc_endpoint = https://archfc.katanemo.dev
|
||||
|
||||
### multi turn conversation with intent, except parameter gathering
|
||||
|
||||
|
|
@ -55,7 +55,7 @@ Content-Type: application/json
|
|||
]
|
||||
}
|
||||
### talk to Arch-Intent directly for completion
|
||||
POST https://api.fc.archgw.com/v1/chat/completions HTTP/1.1
|
||||
POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
|
|
@ -126,7 +126,7 @@ Content-Type: application/json
|
|||
]
|
||||
}
|
||||
### talk to Arch-Intent directly for completion, expect No
|
||||
POST https://api.fc.archgw.com/v1/chat/completions HTTP/1.1
|
||||
POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
@model_server_endpoint = http://localhost:51000
|
||||
@archfc_endpoint = https://api.fc.archgw.com
|
||||
@archfc_endpoint = https://archfc.katanemo.dev
|
||||
|
||||
### single turn function calling all parameters insurance agent summary
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue