ArchFC endpoint integration (#94)

* integration

* mopdify docker file

* add params and fix python lint

* fix empty context and tool calls

* address comments

* revert port

* fix bug merge

* fix environment

* fix bug

* fix compose

* fix merge
This commit is contained in:
Co Tran 2024-10-01 12:47:26 -07:00 committed by GitHub
parent 1a7c1ad0a5
commit 17a643c410
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 98 additions and 41 deletions

View file

@ -4,7 +4,7 @@ import yaml
from jsonschema import validate
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/root/arch_config.yaml')
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')

View file

@ -473,7 +473,9 @@ impl StreamContext {
let model_resp = &arch_fc_response.choices[0];
if model_resp.message.tool_calls.is_none() {
if model_resp.message.tool_calls.is_none()
|| model_resp.message.tool_calls.as_ref().unwrap().is_empty()
{
// This means that Arch FC did not have enough information to resolve the function call
// Arch FC probably responded with a message asking for more information.
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
@ -488,12 +490,6 @@ impl StreamContext {
}
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
if tool_calls.is_empty() {
return self.send_server_error(
"No tool calls found in function resolver response".to_string(),
Some(StatusCode::BAD_REQUEST),
);
}
debug!("tool_call_details: {:?}", tool_calls);
// extract all tool names

View file

@ -1,3 +1,9 @@
x-variables: &common-vars
environment:
- MODE=${MODE:-cloud} # Set the default mode to 'cloud', others values are local-gpu, local-cpu
services:
arch:
@ -11,7 +17,10 @@ services:
- ./generated/envoy.yaml:/etc/envoy/envoy.yaml
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
- ./arch_log:/var/log/
- ./arch_config.yaml:/root/arch_config.yaml
depends_on:
# config_generator:
# condition: service_completed_successfully
model_server:
condition: service_healthy
environment:
@ -30,14 +39,15 @@ services:
volumes:
- ~/.cache/huggingface:/root/.cache/huggingface
- ./arch_config.yaml:/root/arch_config.yaml
<< : *common-vars
environment:
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
- FC_URL=${FC_URL:-empty}
- OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
# use ollama endpoint that is hosted by host machine (no virtualization)
- MODE=${MODE:-cloud}
# uncomment following line to use ollama endpoint that is hosted by docker
# - OLLAMA_ENDPOINT=ollama
# - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
api_server:
build:
context: api_server

View file

@ -5,30 +5,52 @@ from app.arch_fc.arch_handler import ArchHandler
from app.arch_fc.bolt_handler import BoltHandler
from app.arch_fc.common import ChatMessage
import logging
import yaml
from openai import OpenAI
import os
with open("openai_params.yaml") as f:
params = yaml.safe_load(f)
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
logger = logging.getLogger('uvicorn.error')
fc_url = os.getenv("FC_URL", ollama_endpoint)
mode = os.getenv("MODE", "cloud")
if mode not in ["cloud", "local-gpu", "local-cpu"]:
raise ValueError(f"Invalid mode: {mode}")
arch_api_key = os.getenv("ARCH_API_KEY", "")
logger = logging.getLogger("uvicorn.error")
handler = None
if ollama_model.startswith("Arch"):
handler = ArchHandler()
handler = ArchHandler()
else:
handler = BoltHandler()
logger.info(f"using model: {ollama_model}")
logger.info(f"using ollama endpoint: {ollama_endpoint}")
# app = FastAPI()
client = OpenAI(
base_url='http://{}:11434/v1/'.format(ollama_endpoint),
if mode == "cloud":
client = OpenAI(
base_url=fc_url,
api_key="EMPTY",
)
models = client.models.list()
model = models.data[0].id
chosen_model = model
endpoint = fc_url
else:
client = OpenAI(
base_url="http://{}:11434/v1/".format(ollama_endpoint),
api_key="ollama",
)
chosen_model = ollama_model
endpoint = ollama_endpoint
logger.info(f"serving mode: {mode}")
logger.info(f"using model: {chosen_model}")
logger.info(f"using endpoint: {endpoint}")
# required but ignored
api_key='ollama',
)
async def chat_completion(req: ChatMessage, res: Response):
@ -38,23 +60,28 @@ async def chat_completion(req: ChatMessage, res: Response):
messages = [{"role": "system", "content": tools_encoded}]
for message in req.messages:
messages.append({"role": message.role, "content": message.content})
logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
completions_params = params["params"]
resp = client.chat.completions.create(
messages=messages,
model=chosen_model,
stream=False,
extra_body=completions_params,
)
tools = handler.extract_tools(resp.choices[0].message.content)
tool_calls = []
for tool in tools:
for tool_name, tool_args in tool.items():
tool_calls.append({
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_name,
"arguments": tool_args
}
})
for tool_name, tool_args in tool.items():
tool_calls.append(
{
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {"name": tool_name, "arguments": tool_args},
}
)
if tools:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
logger.info(f"response (tools): {json.dumps(tools)}")
logger.info(f"response: {json.dumps(resp.to_dict())}")
return resp

View file

@ -89,7 +89,9 @@ class BoltHandler:
if isinstance(tool_call, dict):
try:
if not executable:
extracted_tools.append({tool_call["name"]: tool_call["arguments"]})
extracted_tools.append(
{tool_call["name"]: tool_call["arguments"]}
)
else:
name, arguments = (
tool_call.get("name", ""),

View file

@ -1,10 +1,12 @@
from typing import Any, Dict, List
from pydantic import BaseModel
class Message(BaseModel):
role: str
content: str
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]

View file

@ -9,7 +9,5 @@ print("installing transformers")
load_transformers()
print("installing ner models")
load_ner_models()
print("installing toxic models")
load_toxic_model()
print("installing jailbreak models")
load_jailbreak_model()

View file

@ -6,6 +6,7 @@ from app.load_models import (
load_guard_model,
load_zero_shot_models,
)
import os
from app.utils import GuardHandler, split_text_into_chunks
import torch
import yaml
@ -25,14 +26,27 @@ zero_shot_models = load_zero_shot_models()
with open("guard_model_config.yaml") as f:
guard_model_config = yaml.safe_load(f)
with open('/root/arch_config.yaml') as f:
config = yaml.safe_load(f)
mode = os.getenv("MODE", "cloud")
logger.info(f"Serving model mode: {mode}")
if mode not in ['cloud', 'local-gpu', 'local-cpu']:
raise ValueError(f"Invalid mode: {mode}")
if mode == 'local-cpu':
hardware = 'cpu'
else:
hardware = "gpu" if torch.cuda.is_available() else "cpu"
task = "both"
hardware = "gpu" if torch.cuda.is_available() else "cpu"
jailbreak_model = load_guard_model(
guard_model_config["jailbreak"][hardware], hardware
)
if "prompt_guards" in config.keys():
task = list(config["prompt_guards"]["input_guards"].keys())[0]
guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
hardware = "gpu" if torch.cuda.is_available() else "cpu"
jailbreak_model = load_guard_model(
guard_model_config["jailbreak"][hardware], hardware
)
toxic_model = None
guard_handler = GuardHandler(toxic_model=toxic_model, jailbreak_model=jailbreak_model)
app = FastAPI()

View file

@ -0,0 +1,8 @@
params:
temperature: 0.0001
top_p : 0.5
repetition_penalty: 1.0
top_k: 50
max_tokens: 128
stop: ["<|im_start|>", "<|im_end|>"]
stop_token_ids: [151645, 151643]