mirror of
https://github.com/katanemo/plano.git
synced 2026-07-02 15:51:02 +02:00
add support for default target (#111)
* add support for default target * add more fixes
This commit is contained in:
parent
c8d0dbec26
commit
1b57a49c9d
8 changed files with 215 additions and 88 deletions
|
|
@ -67,6 +67,8 @@ properties:
|
||||||
type: boolean
|
type: boolean
|
||||||
description:
|
description:
|
||||||
type: string
|
type: string
|
||||||
|
auto_llm_dispatch_on_response:
|
||||||
|
type: boolean
|
||||||
parameters:
|
parameters:
|
||||||
type: array
|
type: array
|
||||||
items:
|
items:
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ enum ResponseHandlerType {
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
ZeroShotIntent,
|
ZeroShotIntent,
|
||||||
ArchGuard,
|
ArchGuard,
|
||||||
|
DefaultTarget,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct CallContext {
|
pub struct CallContext {
|
||||||
|
|
@ -179,12 +180,16 @@ impl StreamContext {
|
||||||
|
|
||||||
let prompt_target_names = prompt_targets
|
let prompt_target_names = prompt_targets
|
||||||
.iter()
|
.iter()
|
||||||
|
// exclude default target
|
||||||
|
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||||
.map(|(name, _)| name.clone())
|
.map(|(name, _)| name.clone())
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let similarity_scores: Vec<(String, f64)> = prompt_targets
|
let similarity_scores: Vec<(String, f64)> = prompt_targets
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(prompt_name, _prompt_target)| {
|
// exclude default prompt target
|
||||||
|
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||||
|
.map(|(prompt_name, _)| {
|
||||||
let default_embeddings = HashMap::new();
|
let default_embeddings = HashMap::new();
|
||||||
let pte = prompt_target_embeddings
|
let pte = prompt_target_embeddings
|
||||||
.get(prompt_name)
|
.get(prompt_name)
|
||||||
|
|
@ -331,34 +336,84 @@ impl StreamContext {
|
||||||
|
|
||||||
// check to ensure that the prompt target similarity score is above the threshold
|
// check to ensure that the prompt target similarity score is above the threshold
|
||||||
if prompt_target_similarity_score < prompt_target_intent_matching_threshold
|
if prompt_target_similarity_score < prompt_target_intent_matching_threshold
|
||||||
&& !arch_assistant
|
|| arch_assistant
|
||||||
{
|
{
|
||||||
|
debug!("intent score is low or arch assistant is handling the conversation");
|
||||||
// if arch fc responded to the user message, then we don't need to check the similarity score
|
// if arch fc responded to the user message, then we don't need to check the similarity score
|
||||||
// it may be that arch fc is handling the conversation for parameter collection
|
// it may be that arch fc is handling the conversation for parameter collection
|
||||||
if arch_assistant {
|
if arch_assistant {
|
||||||
info!("arch assistant is handling the conversation");
|
info!("arch assistant is handling the conversation");
|
||||||
} else {
|
} else {
|
||||||
info!(
|
debug!("checking for default prompt target");
|
||||||
"prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user",
|
if let Some(default_prompt_target) = self
|
||||||
prompt_target_similarity_score,
|
.prompt_targets
|
||||||
prompt_target_intent_matching_threshold
|
.read()
|
||||||
);
|
.unwrap()
|
||||||
|
.values()
|
||||||
|
.find(|pt| pt.default.unwrap_or(false))
|
||||||
|
{
|
||||||
|
debug!("default prompt target found");
|
||||||
|
let endpoint = default_prompt_target.endpoint.clone().unwrap();
|
||||||
|
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
|
||||||
|
|
||||||
|
let upstream_endpoint = endpoint.name;
|
||||||
|
let mut params = HashMap::new();
|
||||||
|
params.insert(
|
||||||
|
ARCH_MESSAGES_KEY.to_string(),
|
||||||
|
callout_context.request_body.messages.clone(),
|
||||||
|
);
|
||||||
|
let arch_messages_json = serde_json::to_string(¶ms).unwrap();
|
||||||
|
debug!("no prompt target found with similarity score above threshold, using default prompt target");
|
||||||
|
let token_id = match self.dispatch_http_call(
|
||||||
|
&upstream_endpoint,
|
||||||
|
vec![
|
||||||
|
(":method", "POST"),
|
||||||
|
(":path", &upstream_path),
|
||||||
|
(":authority", &upstream_endpoint),
|
||||||
|
("content-type", "application/json"),
|
||||||
|
("x-envoy-max-retries", "3"),
|
||||||
|
(
|
||||||
|
"x-envoy-upstream-rq-timeout-ms",
|
||||||
|
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
Some(arch_messages_json.as_bytes()),
|
||||||
|
vec![],
|
||||||
|
Duration::from_secs(5),
|
||||||
|
) {
|
||||||
|
Ok(token_id) => token_id,
|
||||||
|
Err(e) => {
|
||||||
|
let error_msg =
|
||||||
|
format!("Error dispatching HTTP call for default-target: {:?}", e);
|
||||||
|
return self
|
||||||
|
.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
self.metrics.active_http_calls.increment(1);
|
||||||
|
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
|
||||||
|
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
|
||||||
|
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||||
|
panic!("duplicate token_id")
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
self.resume_http_request();
|
self.resume_http_request();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let prompt_target = self
|
let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
|
||||||
.prompt_targets
|
Some(prompt_target) => prompt_target.clone(),
|
||||||
.read()
|
None => {
|
||||||
.unwrap()
|
return self.send_server_error(
|
||||||
.get(&prompt_target_name)
|
format!("Prompt target not found: {}", prompt_target_name),
|
||||||
.unwrap()
|
None,
|
||||||
.clone();
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
info!("prompt_target name: {:?}", prompt_target_name);
|
info!("prompt_target name: {:?}", prompt_target_name);
|
||||||
|
|
||||||
//TODO: handle default function resolver type
|
|
||||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||||
for pt in self.prompt_targets.read().unwrap().values() {
|
for pt in self.prompt_targets.read().unwrap().values() {
|
||||||
// only extract entity names
|
// only extract entity names
|
||||||
|
|
@ -761,6 +816,83 @@ impl StreamContext {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
|
||||||
|
let prompt_target = self
|
||||||
|
.prompt_targets
|
||||||
|
.read()
|
||||||
|
.unwrap()
|
||||||
|
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
debug!(
|
||||||
|
"response received for default target: {}",
|
||||||
|
prompt_target.name
|
||||||
|
);
|
||||||
|
// check if the default target should be dispatched to the LLM provider
|
||||||
|
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
|
||||||
|
let default_target_response_str = String::from_utf8(body).unwrap();
|
||||||
|
debug!(
|
||||||
|
"sending response back to developer: {}",
|
||||||
|
default_target_response_str
|
||||||
|
);
|
||||||
|
self.send_http_response(
|
||||||
|
StatusCode::OK.as_u16().into(),
|
||||||
|
vec![("Powered-By", "Katanemo")],
|
||||||
|
Some(default_target_response_str.as_bytes()),
|
||||||
|
);
|
||||||
|
// self.resume_http_request();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
debug!("default_target: sending api response to default llm");
|
||||||
|
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||||
|
Ok(chat_completions_resp) => chat_completions_resp,
|
||||||
|
Err(e) => {
|
||||||
|
return self.send_server_error(
|
||||||
|
format!("Error deserializing default target response: {:?}", e),
|
||||||
|
None,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let api_resp = chat_completions_resp.choices[0]
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.as_ref()
|
||||||
|
.unwrap();
|
||||||
|
let mut messages = callout_context.request_body.messages;
|
||||||
|
|
||||||
|
// add system prompt
|
||||||
|
match prompt_target.system_prompt.as_ref() {
|
||||||
|
None => {}
|
||||||
|
Some(system_prompt) => {
|
||||||
|
let system_prompt_message = Message {
|
||||||
|
role: SYSTEM_ROLE.to_string(),
|
||||||
|
content: Some(system_prompt.clone()),
|
||||||
|
model: None,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
messages.push(system_prompt_message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages.push(Message {
|
||||||
|
role: USER_ROLE.to_string(),
|
||||||
|
content: Some(api_resp.clone()),
|
||||||
|
model: None,
|
||||||
|
tool_calls: None,
|
||||||
|
});
|
||||||
|
let chat_completion_request = ChatCompletionsRequest {
|
||||||
|
model: GPT_35_TURBO.to_string(),
|
||||||
|
messages,
|
||||||
|
tools: None,
|
||||||
|
stream: callout_context.request_body.stream,
|
||||||
|
stream_options: callout_context.request_body.stream_options,
|
||||||
|
};
|
||||||
|
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||||
|
debug!("sending response back to default llm: {}", json_resp);
|
||||||
|
self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
|
||||||
|
self.resume_http_request();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||||
|
|
@ -1067,6 +1199,9 @@ impl Context for StreamContext {
|
||||||
self.function_call_response_handler(body, callout_context)
|
self.function_call_response_handler(body, callout_context)
|
||||||
}
|
}
|
||||||
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
|
||||||
|
ResponseHandlerType::DefaultTarget => {
|
||||||
|
self.default_target_handler(body, callout_context)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.send_server_error(
|
self.send_server_error(
|
||||||
|
|
|
||||||
|
|
@ -16,4 +16,4 @@ COPY --from=builder /runtime /usr/local
|
||||||
COPY /app /app
|
COPY /app /app
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"]
|
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,13 @@
|
||||||
|
import json
|
||||||
import random
|
import random
|
||||||
from fastapi import FastAPI, Response
|
from fastapi import FastAPI, Response
|
||||||
from datetime import datetime, date, timedelta, timezone
|
from datetime import datetime, date, timedelta, timezone
|
||||||
import logging
|
import logging
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
import pytz
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger('uvicorn.error')
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
@ -58,18 +59,28 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
|
||||||
|
|
||||||
return claim_details
|
return claim_details
|
||||||
|
|
||||||
@app.get("/current_time")
|
|
||||||
async def current_time(timezone: str):
|
class DefaultTargetRequest(BaseModel):
|
||||||
tz = None
|
arch_messages: list
|
||||||
try:
|
|
||||||
timezone.strip('"')
|
@app.post("/default_target")
|
||||||
tz = pytz.timezone(timezone)
|
async def default_target(req: DefaultTargetRequest, res: Response):
|
||||||
except pytz.exceptions.UnknownTimeZoneError:
|
logger.info(f"Received arch_messages: {req.arch_messages}")
|
||||||
return {
|
resp = {
|
||||||
"error": "Invalid timezone: {}".format(timezone)
|
"choices": [
|
||||||
}
|
{
|
||||||
current_time = datetime.now(tz)
|
"message": {
|
||||||
return {
|
"role": "assistant",
|
||||||
"timezone": timezone,
|
"content": "hello world from api server"
|
||||||
"current_time": current_time.strftime("%Y-%m-%d %H:%M:%S %Z")
|
},
|
||||||
}
|
"finish_reason": "completed",
|
||||||
|
"index": 0
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"model": "api_server",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.info(f"sending response: {json.dumps(resp)}")
|
||||||
|
return resp
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
fastapi
|
fastapi
|
||||||
uvicorn
|
uvicorn
|
||||||
pytz
|
pyyaml
|
||||||
|
|
|
||||||
|
|
@ -47,21 +47,6 @@ prompt_targets:
|
||||||
- Use farenheight for temperature
|
- Use farenheight for temperature
|
||||||
- Use miles per hour for wind speed
|
- Use miles per hour for wind speed
|
||||||
|
|
||||||
- name: system_time
|
|
||||||
description: This function provides the current system time.
|
|
||||||
parameters:
|
|
||||||
- name: timezone
|
|
||||||
description: The city for which the weather forecast is requested.
|
|
||||||
default: US/Pacific
|
|
||||||
type: string
|
|
||||||
endpoint:
|
|
||||||
name: api_server
|
|
||||||
path: /current_time
|
|
||||||
system_prompt: |
|
|
||||||
You are a helpful system time provider. Use system time data that is provided to you. Please following following guidelines when responding to user queries:
|
|
||||||
- Use 12 hour time format
|
|
||||||
- Use AM/PM for time
|
|
||||||
|
|
||||||
- name: insurance_claim_details
|
- name: insurance_claim_details
|
||||||
description: This function resolver provides insurance claim details for a given policy number.
|
description: This function resolver provides insurance claim details for a given policy number.
|
||||||
parameters:
|
parameters:
|
||||||
|
|
@ -80,6 +65,18 @@ prompt_targets:
|
||||||
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
|
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
|
||||||
- Use policy number to retrieve insurance claim details
|
- Use policy number to retrieve insurance claim details
|
||||||
|
|
||||||
|
- name: default_target
|
||||||
|
default: true
|
||||||
|
description: This is the default target for all unmatched prompts.
|
||||||
|
endpoint:
|
||||||
|
name: api_server
|
||||||
|
path: /default_target
|
||||||
|
system_prompt: |
|
||||||
|
You are a helpful assistant. Use the information that is provided to you.
|
||||||
|
# if it is set to false arch will send response that it received from this prompt target to the user
|
||||||
|
# if true arch will forward the response to the default LLM
|
||||||
|
auto_llm_dispatch_on_response: true
|
||||||
|
|
||||||
ratelimits:
|
ratelimits:
|
||||||
- provider: gpt-3.5-turbo
|
- provider: gpt-3.5-turbo
|
||||||
selector:
|
selector:
|
||||||
|
|
|
||||||
|
|
@ -1,41 +1,25 @@
|
||||||
FROM python:3.10 AS base
|
FROM python:3.10 AS builder
|
||||||
|
|
||||||
#
|
COPY requirements.txt .
|
||||||
# builder
|
RUN pip install --prefix=/runtime -r requirements.txt
|
||||||
#
|
|
||||||
FROM base AS builder
|
|
||||||
|
|
||||||
WORKDIR /src
|
|
||||||
RUN pip install --upgrade pip
|
|
||||||
|
|
||||||
# Install git (needed for cloning the repository)
|
|
||||||
RUN apt-get update && apt-get install -y git && apt-get clean
|
|
||||||
|
|
||||||
COPY requirements.txt /src/
|
|
||||||
|
|
||||||
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
|
|
||||||
|
|
||||||
COPY . /src
|
|
||||||
|
|
||||||
#
|
|
||||||
# output
|
|
||||||
#
|
|
||||||
|
|
||||||
FROM python:3.10-slim AS output
|
FROM python:3.10-slim AS output
|
||||||
|
|
||||||
|
# curl is needed for health check in docker-compose
|
||||||
|
RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
COPY --from=builder /runtime /usr/local
|
||||||
|
|
||||||
|
WORKDIR /src
|
||||||
|
|
||||||
# specify list of models that will go into the image as a comma separated list
|
# specify list of models that will go into the image as a comma separated list
|
||||||
# following models have been tested to work with this image
|
# following models have been tested to work with this image
|
||||||
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
|
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
|
||||||
ENV MODELS="BAAI/bge-large-en-v1.5"
|
ENV MODELS="BAAI/bge-large-en-v1.5"
|
||||||
|
|
||||||
COPY --from=builder /runtime /usr/local
|
COPY ./app ./app
|
||||||
|
COPY ./guard_model_config.yaml .
|
||||||
COPY ./ /app
|
COPY ./openai_params.yaml .
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
RUN apt-get update && apt-get install -y \
|
|
||||||
curl \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# comment it out for now as we don't want to download the model every time we build the image
|
# comment it out for now as we don't want to download the model every time we build the image
|
||||||
# we will mount host cache to docker image to avoid downloading the model every time
|
# we will mount host cache to docker image to avoid downloading the model every time
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,10 @@ import yaml
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
with open("openai_params.yaml") as f:
|
with open("openai_params.yaml") as f:
|
||||||
params = yaml.safe_load(f)
|
params = yaml.safe_load(f)
|
||||||
|
|
@ -20,7 +24,6 @@ mode = os.getenv("MODE", "cloud")
|
||||||
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
if mode not in ["cloud", "local-gpu", "local-cpu"]:
|
||||||
raise ValueError(f"Invalid mode: {mode}")
|
raise ValueError(f"Invalid mode: {mode}")
|
||||||
arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
|
arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
|
||||||
logger = logging.getLogger("uvicorn.error")
|
|
||||||
|
|
||||||
handler = None
|
handler = None
|
||||||
if ollama_model.startswith("Arch"):
|
if ollama_model.startswith("Arch"):
|
||||||
|
|
@ -28,17 +31,12 @@ if ollama_model.startswith("Arch"):
|
||||||
else:
|
else:
|
||||||
handler = BoltHandler()
|
handler = BoltHandler()
|
||||||
|
|
||||||
|
|
||||||
# app = FastAPI()
|
|
||||||
|
|
||||||
if mode == "cloud":
|
if mode == "cloud":
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
base_url=fc_url,
|
base_url=fc_url,
|
||||||
api_key="EMPTY",
|
api_key="EMPTY",
|
||||||
)
|
)
|
||||||
models = client.models.list()
|
chosen_model = "fc-cloud"
|
||||||
model = models.data[0].id
|
|
||||||
chosen_model = model
|
|
||||||
endpoint = fc_url
|
endpoint = fc_url
|
||||||
else:
|
else:
|
||||||
client = OpenAI(
|
client = OpenAI(
|
||||||
|
|
@ -47,12 +45,12 @@ else:
|
||||||
)
|
)
|
||||||
chosen_model = ollama_model
|
chosen_model = ollama_model
|
||||||
endpoint = ollama_endpoint
|
endpoint = ollama_endpoint
|
||||||
|
|
||||||
logger.info(f"serving mode: {mode}")
|
logger.info(f"serving mode: {mode}")
|
||||||
logger.info(f"using model: {chosen_model}")
|
logger.info(f"using model: {chosen_model}")
|
||||||
logger.info(f"using endpoint: {endpoint}")
|
logger.info(f"using endpoint: {endpoint}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion(req: ChatMessage, res: Response):
|
async def chat_completion(req: ChatMessage, res: Response):
|
||||||
logger.info("starting request")
|
logger.info("starting request")
|
||||||
tools_encoded = handler._format_system(req.tools)
|
tools_encoded = handler._format_system(req.tools)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue