mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
ArchFC endpoint integration (#94)
* integration * mopdify docker file * add params and fix python lint * fix empty context and tool calls * address comments * revert port * fix bug merge * fix environment * fix bug * fix compose * fix merge
This commit is contained in:
parent
1a7c1ad0a5
commit
17a643c410
9 changed files with 98 additions and 41 deletions
|
|
@ -4,7 +4,7 @@ import yaml
|
|||
from jsonschema import validate
|
||||
|
||||
ENVOY_CONFIG_TEMPLATE_FILE = os.getenv('ENVOY_CONFIG_TEMPLATE_FILE', 'envoy.template.yaml')
|
||||
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/config/arch_config.yaml')
|
||||
ARCH_CONFIG_FILE = os.getenv('ARCH_CONFIG_FILE', '/root/arch_config.yaml')
|
||||
ENVOY_CONFIG_FILE_RENDERED = os.getenv('ENVOY_CONFIG_FILE_RENDERED', '/etc/envoy/envoy.yaml')
|
||||
ARCH_CONFIG_SCHEMA_FILE = os.getenv('ARCH_CONFIG_SCHEMA_FILE', 'arch_config_schema.yaml')
|
||||
|
||||
|
|
|
|||
|
|
@ -473,7 +473,9 @@ impl StreamContext {
|
|||
|
||||
let model_resp = &arch_fc_response.choices[0];
|
||||
|
||||
if model_resp.message.tool_calls.is_none() {
|
||||
if model_resp.message.tool_calls.is_none()
|
||||
|| model_resp.message.tool_calls.as_ref().unwrap().is_empty()
|
||||
{
|
||||
// This means that Arch FC did not have enough information to resolve the function call
|
||||
// Arch FC probably responded with a message asking for more information.
|
||||
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
|
||||
|
|
@ -488,12 +490,6 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
|
||||
if tool_calls.is_empty() {
|
||||
return self.send_server_error(
|
||||
"No tool calls found in function resolver response".to_string(),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
|
||||
debug!("tool_call_details: {:?}", tool_calls);
|
||||
// extract all tool names
|
||||
|
|
|
|||
|
|
@ -1,3 +1,9 @@
|
|||
|
||||
x-variables: &common-vars
|
||||
environment:
|
||||
- MODE=${MODE:-cloud} # Set the default mode to 'cloud', others values are local-gpu, local-cpu
|
||||
|
||||
|
||||
services:
|
||||
|
||||
arch:
|
||||
|
|
@ -11,7 +17,10 @@ services:
|
|||
- ./generated/envoy.yaml:/etc/envoy/envoy.yaml
|
||||
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
|
||||
- ./arch_log:/var/log/
|
||||
- ./arch_config.yaml:/root/arch_config.yaml
|
||||
depends_on:
|
||||
# config_generator:
|
||||
# condition: service_completed_successfully
|
||||
model_server:
|
||||
condition: service_healthy
|
||||
environment:
|
||||
|
|
@ -30,14 +39,15 @@ services:
|
|||
volumes:
|
||||
- ~/.cache/huggingface:/root/.cache/huggingface
|
||||
- ./arch_config.yaml:/root/arch_config.yaml
|
||||
<< : *common-vars
|
||||
environment:
|
||||
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
|
||||
- FC_URL=${FC_URL:-empty}
|
||||
- OLLAMA_MODEL=Arch-Function-Calling-3B-Q4_K_M
|
||||
# use ollama endpoint that is hosted by host machine (no virtualization)
|
||||
- MODE=${MODE:-cloud}
|
||||
# uncomment following line to use ollama endpoint that is hosted by docker
|
||||
# - OLLAMA_ENDPOINT=ollama
|
||||
# - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
|
||||
|
||||
api_server:
|
||||
build:
|
||||
context: api_server
|
||||
|
|
|
|||
|
|
@ -5,30 +5,52 @@ from app.arch_fc.arch_handler import ArchHandler
|
|||
from app.arch_fc.bolt_handler import BoltHandler
|
||||
from app.arch_fc.common import ChatMessage
|
||||
import logging
|
||||
import yaml
|
||||
from openai import OpenAI
|
||||
import os
|
||||
|
||||
|
||||
with open("openai_params.yaml") as f:
|
||||
params = yaml.safe_load(f)
|
||||
|
||||
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
|
||||
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
|
||||
logger = logging.getLogger('uvicorn.error')
|
||||
fc_url = os.getenv("FC_URL", ollama_endpoint)
|
||||
mode = os.getenv("MODE", "cloud")
|
||||
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
arch_api_key = os.getenv("ARCH_API_KEY", "")
|
||||
logger = logging.getLogger("uvicorn.error")
|
||||
|
||||
handler = None
|
||||
if ollama_model.startswith("Arch"):
|
||||
handler = ArchHandler()
|
||||
handler = ArchHandler()
|
||||
else:
|
||||
handler = BoltHandler()
|
||||
|
||||
logger.info(f"using model: {ollama_model}")
|
||||
logger.info(f"using ollama endpoint: {ollama_endpoint}")
|
||||
|
||||
# app = FastAPI()
|
||||
|
||||
client = OpenAI(
|
||||
base_url='http://{}:11434/v1/'.format(ollama_endpoint),
|
||||
if mode == "cloud":
|
||||
client = OpenAI(
|
||||
base_url=fc_url,
|
||||
api_key="EMPTY",
|
||||
)
|
||||
models = client.models.list()
|
||||
model = models.data[0].id
|
||||
chosen_model = model
|
||||
endpoint = fc_url
|
||||
else:
|
||||
client = OpenAI(
|
||||
base_url="http://{}:11434/v1/".format(ollama_endpoint),
|
||||
api_key="ollama",
|
||||
)
|
||||
chosen_model = ollama_model
|
||||
endpoint = ollama_endpoint
|
||||
logger.info(f"serving mode: {mode}")
|
||||
logger.info(f"using model: {chosen_model}")
|
||||
logger.info(f"using endpoint: {endpoint}")
|
||||
|
||||
# required but ignored
|
||||
api_key='ollama',
|
||||
)
|
||||
|
||||
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
|
|
@ -38,23 +60,28 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
for message in req.messages:
|
||||
messages.append({"role": message.role, "content": message.content})
|
||||
logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
|
||||
resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
|
||||
logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
|
||||
completions_params = params["params"]
|
||||
resp = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=chosen_model,
|
||||
stream=False,
|
||||
extra_body=completions_params,
|
||||
)
|
||||
tools = handler.extract_tools(resp.choices[0].message.content)
|
||||
tool_calls = []
|
||||
for tool in tools:
|
||||
for tool_name, tool_args in tool.items():
|
||||
tool_calls.append({
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": tool_args
|
||||
}
|
||||
})
|
||||
for tool_name, tool_args in tool.items():
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {"name": tool_name, "arguments": tool_args},
|
||||
}
|
||||
)
|
||||
if tools:
|
||||
resp.choices[0].message.tool_calls = tool_calls
|
||||
resp.choices[0].message.content = None
|
||||
resp.choices[0].message.tool_calls = tool_calls
|
||||
resp.choices[0].message.content = None
|
||||
logger.info(f"response (tools): {json.dumps(tools)}")
|
||||
logger.info(f"response: {json.dumps(resp.to_dict())}")
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -89,7 +89,9 @@ class BoltHandler:
|
|||
if isinstance(tool_call, dict):
|
||||
try:
|
||||
if not executable:
|
||||
extracted_tools.append({tool_call["name"]: tool_call["arguments"]})
|
||||
extracted_tools.append(
|
||||
{tool_call["name"]: tool_call["arguments"]}
|
||||
)
|
||||
else:
|
||||
name, arguments = (
|
||||
tool_call.get("name", ""),
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
from typing import Any, Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
|
|
|||
|
|
@ -9,7 +9,5 @@ print("installing transformers")
|
|||
load_transformers()
|
||||
print("installing ner models")
|
||||
load_ner_models()
|
||||
print("installing toxic models")
|
||||
load_toxic_model()
|
||||
print("installing jailbreak models")
|
||||
load_jailbreak_model()
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from app.load_models import (
|
|||
load_guard_model,
|
||||
load_zero_shot_models,
|
||||
)
|
||||
import os
|
||||
from app.utils import GuardHandler, split_text_into_chunks
|
||||
import torch
|
||||
import yaml
|
||||
|
|
@ -25,14 +26,27 @@ zero_shot_models = load_zero_shot_models()
|
|||
|
||||
with open("guard_model_config.yaml") as f:
|
||||
guard_model_config = yaml.safe_load(f)
|
||||
with open('/root/arch_config.yaml') as f:
|
||||
config = yaml.safe_load(f)
|
||||
mode = os.getenv("MODE", "cloud")
|
||||
logger.info(f"Serving model mode: {mode}")
|
||||
if mode not in ['cloud', 'local-gpu', 'local-cpu']:
|
||||
raise ValueError(f"Invalid mode: {mode}")
|
||||
if mode == 'local-cpu':
|
||||
hardware = 'cpu'
|
||||
else:
|
||||
hardware = "gpu" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
task = "both"
|
||||
hardware = "gpu" if torch.cuda.is_available() else "cpu"
|
||||
jailbreak_model = load_guard_model(
|
||||
guard_model_config["jailbreak"][hardware], hardware
|
||||
)
|
||||
if "prompt_guards" in config.keys():
|
||||
task = list(config["prompt_guards"]["input_guards"].keys())[0]
|
||||
|
||||
guard_handler = GuardHandler(toxic_model=None, jailbreak_model=jailbreak_model)
|
||||
hardware = "gpu" if torch.cuda.is_available() else "cpu"
|
||||
jailbreak_model = load_guard_model(
|
||||
guard_model_config["jailbreak"][hardware], hardware
|
||||
)
|
||||
toxic_model = None
|
||||
|
||||
guard_handler = GuardHandler(toxic_model=toxic_model, jailbreak_model=jailbreak_model)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
|
|
|||
8
model_server/openai_params.yaml
Normal file
8
model_server/openai_params.yaml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
params:
|
||||
temperature: 0.0001
|
||||
top_p : 0.5
|
||||
repetition_penalty: 1.0
|
||||
top_k: 50
|
||||
max_tokens: 128
|
||||
stop: ["<|im_start|>", "<|im_end|>"]
|
||||
stop_token_ids: [151645, 151643]
|
||||
Loading…
Add table
Add a link
Reference in a new issue