add support for default target (#111)

* add support for default target

* add more fixes
This commit is contained in:
Adil Hafeez 2024-10-02 20:43:16 -07:00 committed by GitHub
parent c8d0dbec26
commit 1b57a49c9d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 215 additions and 88 deletions

View file

@ -67,6 +67,8 @@ properties:
type: boolean type: boolean
description: description:
type: string type: string
auto_llm_dispatch_on_response:
type: boolean
parameters: parameters:
type: array type: array
items: items:

View file

@ -39,6 +39,7 @@ enum ResponseHandlerType {
FunctionCall, FunctionCall,
ZeroShotIntent, ZeroShotIntent,
ArchGuard, ArchGuard,
DefaultTarget,
} }
pub struct CallContext { pub struct CallContext {
@ -179,12 +180,16 @@ impl StreamContext {
let prompt_target_names = prompt_targets let prompt_target_names = prompt_targets
.iter() .iter()
// exclude default target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(name, _)| name.clone()) .map(|(name, _)| name.clone())
.collect(); .collect();
let similarity_scores: Vec<(String, f64)> = prompt_targets let similarity_scores: Vec<(String, f64)> = prompt_targets
.iter() .iter()
.map(|(prompt_name, _prompt_target)| { // exclude default prompt target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(prompt_name, _)| {
let default_embeddings = HashMap::new(); let default_embeddings = HashMap::new();
let pte = prompt_target_embeddings let pte = prompt_target_embeddings
.get(prompt_name) .get(prompt_name)
@ -331,34 +336,84 @@ impl StreamContext {
// check to ensure that the prompt target similarity score is above the threshold // check to ensure that the prompt target similarity score is above the threshold
if prompt_target_similarity_score < prompt_target_intent_matching_threshold if prompt_target_similarity_score < prompt_target_intent_matching_threshold
&& !arch_assistant || arch_assistant
{ {
debug!("intent score is low or arch assistant is handling the conversation");
// if arch fc responded to the user message, then we don't need to check the similarity score // if arch fc responded to the user message, then we don't need to check the similarity score
// it may be that arch fc is handling the conversation for parameter collection // it may be that arch fc is handling the conversation for parameter collection
if arch_assistant { if arch_assistant {
info!("arch assistant is handling the conversation"); info!("arch assistant is handling the conversation");
} else { } else {
info!( debug!("checking for default prompt target");
"prompt target below limit: {:.3}, threshold: {:.3}, continue conversation with user", if let Some(default_prompt_target) = self
prompt_target_similarity_score, .prompt_targets
prompt_target_intent_matching_threshold .read()
); .unwrap()
.values()
.find(|pt| pt.default.unwrap_or(false))
{
debug!("default prompt target found");
let endpoint = default_prompt_target.endpoint.clone().unwrap();
let upstream_path: String = endpoint.path.unwrap_or(String::from("/"));
let upstream_endpoint = endpoint.name;
let mut params = HashMap::new();
params.insert(
ARCH_MESSAGES_KEY.to_string(),
callout_context.request_body.messages.clone(),
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
debug!("no prompt target found with similarity score above threshold, using default prompt target");
let token_id = match self.dispatch_http_call(
&upstream_endpoint,
vec![
(":method", "POST"),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
(
"x-envoy-upstream-rq-timeout-ms",
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
),
],
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg =
format!("Error dispatching HTTP call for default-target: {:?}", e);
return self
.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
self.metrics.active_http_calls.increment(1);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
return;
}
self.resume_http_request(); self.resume_http_request();
return; return;
} }
} }
let prompt_target = self let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
.prompt_targets Some(prompt_target) => prompt_target.clone(),
.read() None => {
.unwrap() return self.send_server_error(
.get(&prompt_target_name) format!("Prompt target not found: {}", prompt_target_name),
.unwrap() None,
.clone(); );
}
};
info!("prompt_target name: {:?}", prompt_target_name); info!("prompt_target name: {:?}", prompt_target_name);
//TODO: handle default function resolver type
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new(); let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.read().unwrap().values() { for pt in self.prompt_targets.read().unwrap().values() {
// only extract entity names // only extract entity names
@ -761,6 +816,83 @@ impl StreamContext {
) )
} }
} }
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(callout_context.prompt_target_name.as_ref().unwrap())
.unwrap()
.clone();
debug!(
"response received for default target: {}",
prompt_target.name
);
// check if the default target should be dispatched to the LLM provider
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
let default_target_response_str = String::from_utf8(body).unwrap();
debug!(
"sending response back to developer: {}",
default_target_response_str
);
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(default_target_response_str.as_bytes()),
);
// self.resume_http_request();
return;
}
debug!("default_target: sending api response to default llm");
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
return self.send_server_error(
format!("Error deserializing default target response: {:?}", e),
None,
);
}
};
let api_resp = chat_completions_resp.choices[0]
.message
.content
.as_ref()
.unwrap();
let mut messages = callout_context.request_body.messages;
// add system prompt
match prompt_target.system_prompt.as_ref() {
None => {}
Some(system_prompt) => {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
};
messages.push(system_prompt_message);
}
}
messages.push(Message {
role: USER_ROLE.to_string(),
content: Some(api_resp.clone()),
model: None,
tool_calls: None,
});
let chat_completion_request = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp);
self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
self.resume_http_request();
}
} }
// HttpContext is the trait that allows the Rust code to interact with HTTP objects. // HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@ -1067,6 +1199,9 @@ impl Context for StreamContext {
self.function_call_response_handler(body, callout_context) self.function_call_response_handler(body, callout_context)
} }
ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context),
ResponseHandlerType::DefaultTarget => {
self.default_target_handler(body, callout_context)
}
} }
} else { } else {
self.send_server_error( self.send_server_error(

View file

@ -16,4 +16,4 @@ COPY --from=builder /runtime /usr/local
COPY /app /app COPY /app /app
WORKDIR /app WORKDIR /app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80"] CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "info"]

View file

@ -1,12 +1,13 @@
import json
import random import random
from fastapi import FastAPI, Response from fastapi import FastAPI, Response
from datetime import datetime, date, timedelta, timezone from datetime import datetime, date, timedelta, timezone
import logging import logging
from pydantic import BaseModel from pydantic import BaseModel
import pytz
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__) logger = logging.getLogger('uvicorn.error')
logger.setLevel(logging.INFO)
app = FastAPI() app = FastAPI()
@ -58,18 +59,28 @@ async def insurance_claim_details(req: InsuranceClaimDetailsRequest, res: Respon
return claim_details return claim_details
@app.get("/current_time")
async def current_time(timezone: str): class DefaultTargetRequest(BaseModel):
tz = None arch_messages: list
try:
timezone.strip('"') @app.post("/default_target")
tz = pytz.timezone(timezone) async def default_target(req: DefaultTargetRequest, res: Response):
except pytz.exceptions.UnknownTimeZoneError: logger.info(f"Received arch_messages: {req.arch_messages}")
return { resp = {
"error": "Invalid timezone: {}".format(timezone) "choices": [
} {
current_time = datetime.now(tz) "message": {
return { "role": "assistant",
"timezone": timezone, "content": "hello world from api server"
"current_time": current_time.strftime("%Y-%m-%d %H:%M:%S %Z") },
} "finish_reason": "completed",
"index": 0
}
],
"model": "api_server",
"usage": {
"completion_tokens": 0
}
}
logger.info(f"sending response: {json.dumps(resp)}")
return resp

View file

@ -1,3 +1,3 @@
fastapi fastapi
uvicorn uvicorn
pytz pyyaml

View file

@ -47,21 +47,6 @@ prompt_targets:
- Use farenheight for temperature - Use farenheight for temperature
- Use miles per hour for wind speed - Use miles per hour for wind speed
- name: system_time
description: This function provides the current system time.
parameters:
- name: timezone
description: The city for which the weather forecast is requested.
default: US/Pacific
type: string
endpoint:
name: api_server
path: /current_time
system_prompt: |
You are a helpful system time provider. Use system time data that is provided to you. Please following following guidelines when responding to user queries:
- Use 12 hour time format
- Use AM/PM for time
- name: insurance_claim_details - name: insurance_claim_details
description: This function resolver provides insurance claim details for a given policy number. description: This function resolver provides insurance claim details for a given policy number.
parameters: parameters:
@ -80,6 +65,18 @@ prompt_targets:
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries: You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
- Use policy number to retrieve insurance claim details - Use policy number to retrieve insurance claim details
- name: default_target
default: true
description: This is the default target for all unmatched prompts.
endpoint:
name: api_server
path: /default_target
system_prompt: |
You are a helpful assistant. Use the information that is provided to you.
# if it is set to false arch will send response that it received from this prompt target to the user
# if true arch will forward the response to the default LLM
auto_llm_dispatch_on_response: true
ratelimits: ratelimits:
- provider: gpt-3.5-turbo - provider: gpt-3.5-turbo
selector: selector:

View file

@ -1,41 +1,25 @@
FROM python:3.10 AS base FROM python:3.10 AS builder
# COPY requirements.txt .
# builder RUN pip install --prefix=/runtime -r requirements.txt
#
FROM base AS builder
WORKDIR /src
RUN pip install --upgrade pip
# Install git (needed for cloning the repository)
RUN apt-get update && apt-get install -y git && apt-get clean
COPY requirements.txt /src/
RUN pip install --prefix=/runtime --force-reinstall -r requirements.txt
COPY . /src
#
# output
#
FROM python:3.10-slim AS output FROM python:3.10-slim AS output
# curl is needed for health check in docker-compose
RUN apt-get update && apt-get install -y curl && apt-get clean && rm -rf /var/lib/apt/lists/*
COPY --from=builder /runtime /usr/local
WORKDIR /src
# specify list of models that will go into the image as a comma separated list # specify list of models that will go into the image as a comma separated list
# following models have been tested to work with this image # following models have been tested to work with this image
# "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small" # "sentence-transformers/all-MiniLM-L6-v2,sentence-transformers/all-mpnet-base-v2,thenlper/gte-base,thenlper/gte-large,thenlper/gte-small"
ENV MODELS="BAAI/bge-large-en-v1.5" ENV MODELS="BAAI/bge-large-en-v1.5"
COPY --from=builder /runtime /usr/local COPY ./app ./app
COPY ./guard_model_config.yaml .
COPY ./ /app COPY ./openai_params.yaml .
WORKDIR /app
RUN apt-get update && apt-get install -y \
curl \
&& rm -rf /var/lib/apt/lists/*
# comment it out for now as we don't want to download the model every time we build the image # comment it out for now as we don't want to download the model every time we build the image
# we will mount host cache to docker image to avoid downloading the model every time # we will mount host cache to docker image to avoid downloading the model every time

View file

@ -9,6 +9,10 @@ import yaml
from openai import OpenAI from openai import OpenAI
import os import os
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
with open("openai_params.yaml") as f: with open("openai_params.yaml") as f:
params = yaml.safe_load(f) params = yaml.safe_load(f)
@ -20,7 +24,6 @@ mode = os.getenv("MODE", "cloud")
if mode not in ["cloud", "local-gpu", "local-cpu"]: if mode not in ["cloud", "local-gpu", "local-cpu"]:
raise ValueError(f"Invalid mode: {mode}") raise ValueError(f"Invalid mode: {mode}")
arch_api_key = os.getenv("ARCH_API_KEY", "vllm") arch_api_key = os.getenv("ARCH_API_KEY", "vllm")
logger = logging.getLogger("uvicorn.error")
handler = None handler = None
if ollama_model.startswith("Arch"): if ollama_model.startswith("Arch"):
@ -28,17 +31,12 @@ if ollama_model.startswith("Arch"):
else: else:
handler = BoltHandler() handler = BoltHandler()
# app = FastAPI()
if mode == "cloud": if mode == "cloud":
client = OpenAI( client = OpenAI(
base_url=fc_url, base_url=fc_url,
api_key="EMPTY", api_key="EMPTY",
) )
models = client.models.list() chosen_model = "fc-cloud"
model = models.data[0].id
chosen_model = model
endpoint = fc_url endpoint = fc_url
else: else:
client = OpenAI( client = OpenAI(
@ -47,12 +45,12 @@ else:
) )
chosen_model = ollama_model chosen_model = ollama_model
endpoint = ollama_endpoint endpoint = ollama_endpoint
logger.info(f"serving mode: {mode}") logger.info(f"serving mode: {mode}")
logger.info(f"using model: {chosen_model}") logger.info(f"using model: {chosen_model}")
logger.info(f"using endpoint: {endpoint}") logger.info(f"using endpoint: {endpoint}")
async def chat_completion(req: ChatMessage, res: Response): async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request") logger.info("starting request")
tools_encoded = handler._format_system(req.tools) tools_encoded = handler._format_system(req.tools)