don't compute embeddings for names and other fixes see description (#126)

* serialize tools - 2

* fix int tests

* fix int test

* fix unit tests
This commit is contained in:
Adil Hafeez 2024-10-05 19:25:16 -07:00 committed by GitHub
parent 0e5ea3d6db
commit 2a747df7c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 125 additions and 86 deletions

View file

@ -69,14 +69,9 @@ static_resources:
clusters:
- name: openai
connect_timeout: 5s
dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: openai
endpoints:
@ -98,14 +93,9 @@ static_resources:
tls_maximum_protocol_version: TLSv1_3
- name: mistral
connect_timeout: 5s
dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
typed_extension_protocol_options:
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
explicit_http_config:
http2_protocol_options: {}
load_assignment:
cluster_name: mistral
endpoints:
@ -124,6 +114,7 @@ static_resources:
- name: model_server
connect_timeout: 5s
type: STRICT_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: model_server
@ -138,6 +129,7 @@ static_resources:
- name: mistral_7b_instruct
connect_timeout: 5s
type: STRICT_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: mistral_7b_instruct
@ -152,6 +144,7 @@ static_resources:
- name: arch_fc
connect_timeout: 5s
type: STRICT_DNS
dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: arch_fc

View file

@ -12,3 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
// pub const ARCH_STATE_HEADER: &str = "x-arch-state";

View file

@ -72,11 +72,6 @@ impl FilterContext {
fn process_prompt_targets(&self) {
for values in self.prompt_targets.iter() {
let prompt_target = values.1;
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.name,
EmbeddingType::Name,
);
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.description,

View file

@ -65,7 +65,7 @@ pub trait Client: Context {
}
Err(status) => Err(ClientError::DispatchError {
upstream_name: String::from(call_args.upstream),
internal_status: status.clone(),
internal_status: status,
}),
}
}

View file

@ -469,6 +469,7 @@ impl StreamContext {
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
metadata: None,
};
let msg_body = match serde_json::to_string(&chat_completions) {
@ -686,6 +687,7 @@ impl StreamContext {
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let json_string = match serde_json::to_string(&chat_completions_request) {
@ -875,6 +877,7 @@ impl StreamContext {
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp);

View file

@ -254,7 +254,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
@ -276,22 +276,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(102))
.expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(0))
.execute_and_expect(ReturnType::None)
.unwrap();
@ -335,31 +319,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_http_call_response(
filter_context,
102,
0,
embedding_response_str.len() as i32,
0,
)
.expect_log(
Some(LogLevel::Debug),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
102
)
.as_str(),
),
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embedding_response_str))
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
filter_context
}
@ -599,6 +558,7 @@ fn request_ratelimited() {
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
@ -712,6 +672,7 @@ fn request_not_ratelimited() {
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();

View file

@ -1,5 +1,6 @@
import json
import os
from openai import OpenAI
from openai import OpenAI, DefaultHttpxClient
import gradio as gr
import logging as log
from dotenv import load_dotenv
@ -13,11 +14,13 @@ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))
def predict(message, history):
def predict(message, state):
if 'history' not in state:
state['history'] = []
history = state.get("history")
history.append({"role": "user", "content": message})
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info("history: ", history)
# Custom headers
@ -27,34 +30,42 @@ def predict(message, history):
'x-arch-deterministic-provider': 'openai',
}
metadata = None
if 'arch_state' in state:
metadata = {"x-arch-state": state['arch_state']}
try:
response = client.chat.completions.create(model=MODEL_NAME,
messages= history,
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
messages = history,
temperature=1.0,
metadata=metadata,
extra_headers=custom_headers
)
except Exception as e:
log.info(e)
# remove last user message in case of exception
history.pop()
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
choices = response.choices
message = choices[0].message
content = message.content
history.append({"role": "assistant", "content": content})
history[-1]["model"] = response.model
response = raw_response.parse()
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
if arch_state:
state['arch_state'] = arch_state
content = response.choices[0].message.content
history.append({"role": "assistant", "content": content, "model": response.model})
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
return messages, history
return messages, state
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
print("Starting Demo...")
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
state = gr.State([])
state = gr.State({})
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)

View file

@ -5,4 +5,4 @@ asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1
pydantic==2.8.2
openai==1.46.1
openai==1.51.0

View file

@ -21,10 +21,6 @@ llm_providers:
provider: openai
model: gpt-4
default: true
- name: mistral-large-latest
access_key: MISTRAL_API_KEY
provider: mistral
model: mistral-large-latest
system_prompt: |
You are a helpful assistant.

View file

View file

View file

@ -3,11 +3,12 @@ import random
from fastapi import FastAPI, Response
from app.arch_fc.arch_handler import ArchHandler
from app.arch_fc.bolt_handler import BoltHandler
from app.arch_fc.common import ChatMessage
from app.arch_fc.common import ChatMessage, Message
import logging
import yaml
from openai import OpenAI
import os
import hashlib
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@ -51,14 +52,54 @@ logger.info(f"serving mode: {mode}")
logger.info(f"using model: {chosen_model}")
logger.info(f"using endpoint: {endpoint}")
# Rebuild the chat history that is sent to the function-calling model by
# re-inserting previously recorded tool calls/responses from `arch_state`.
#
# arch_state: JSON string; the nested loops below expect it to decode to a
#             list of lists of dicts, each dict carrying a 'key' entry.
# history:    messages exposing `.role` and `.content`.
# Returns a list of {"role", "content"} dicts: every input message, plus an
# assistant <tool_call> message and a user <tool_response> message wherever
# a recorded state entry matches.
def process_state(arch_state, history: list[Message]):
print("state: {}".format(arch_state))
state_json = json.loads(arch_state)
# Index every recorded tool state by its 'key' for direct lookup below.
state_map = {}
if state_json:
for tools_state in state_json:
for tool_state in tools_state:
state_map[tool_state['key']] = tool_state
print(f"state_map: {json.dumps(state_map)}")
# sha_history collects user message contents only; updated_history is the
# list that is returned.
sha_history = []
updated_history = []
for hist in history:
updated_history.append({"role": hist.role, "content": hist.content})
if hist.role == 'user':
sha_history.append(hist.content)
# Lookup key: hex SHA-256 of the JSON dump of the user contents collected.
# NOTE(review): presumably recomputed after each user message so the key
# identifies the conversation prefix; confirm against the indented source.
# NOTE(review): the producer of 'key' must serialize the list exactly as
# json.dumps does (including separators) for the digests to match; verify.
sha256_hash = hashlib.sha256()
sha256_hash.update(json.dumps(sha_history).encode())
sha_key = sha256_hash.hexdigest()
print(f"sha_key: {sha_key}")
if sha_key in state_map:
tool_call_state = state_map[sha_key]
if 'tool_call' in tool_call_state:
tool_call_str = json.dumps(tool_call_state['tool_call'])
updated_history.append({"role": "assistant", "content": f"<tool_call>\n{tool_call_str}\n</tool_call>"})
if 'tool_response' in tool_call_state:
tool_resp = tool_call_state['tool_response']
#TODO: try with role = user as well
# NOTE(review): the append below already uses role "user"; the TODO above
# may be stale.
updated_history.append({"role": "user", "content": f"<tool_response>\n{tool_resp}\n</tool_response>"})
# we don't want to match this state with any other messages
del(state_map[sha_key])
return updated_history
async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request")
tools_encoded = handler._format_system(req.tools)
# append system prompt with tools to messages
messages = [{"role": "system", "content": tools_encoded}]
for message in req.messages:
messages.append({"role": message.role, "content": message.content})
logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
metadata = req.metadata
arch_state = metadata.get("x-arch-state", "[]")
updated_history = process_state(arch_state, req.messages)
for message in updated_history:
messages.append({"role": message["role"], "content": message["content"]})
logger.info(f"model_server => arch_fc: {chosen_model}, messages: {json.dumps(messages)}")
completions_params = params["params"]
resp = client.chat.completions.create(
messages=messages,
@ -80,6 +121,6 @@ async def chat_completion(req: ChatMessage, res: Response):
if tools:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
logger.info(f"response (tools): {json.dumps(tools)}")
logger.info(f"response: {json.dumps(resp.to_dict())}")
logger.info(f"model_server <= arch_fc: (tools): {json.dumps(tools)}")
logger.info(f"model_server <= arch_fc: response body: {json.dumps(resp.to_dict())}")
return resp

View file

@ -10,3 +10,5 @@ class Message(BaseModel):
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
# todo: make it default none
metadata: Dict[str, str] = {}

View file

@ -0,0 +1,17 @@
import json
import pytest
from app.arch_fc.arch_fc import process_state
from app.arch_fc.common import ChatMessage, Message
# test process_state
arch_state = '[[{"key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424", "message": {"role": "user", "content": "how is the weather in chicago?"}, "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}}, "tool_response": "{\\"city\\":\\"Chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-05\\",\\"temperature\\":{\\"min\\":51,\\"max\\":70},\\"query_time\\":\\"2024-10-05 08:18:00.264171+00:00\\"},{\\"date\\":\\"2024-10-06\\",\\"temperature\\":{\\"min\\":77,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264186+00:00\\"},{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":66,\\"max\\":84},\\"query_time\\":\\"2024-10-05 08:18:00.264190+00:00\\"},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":77,\\"max\\":94},\\"query_time\\":\\"2024-10-05 08:18:00.264209+00:00\\"},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":76,\\"max\\":92},\\"query_time\\":\\"2024-10-05 08:18:00.264518+00:00\\"},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":56,\\"max\\":68},\\"query_time\\":\\"2024-10-05 08:18:00.264550+00:00\\"},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":73,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264559+00:00\\"}],\\"unit\\":\\"F\\"}"}]]'
# Smoke test: runs process_state on a single user message together with the
# canned module-level `arch_state` fixture and prints the resulting history.
# NOTE(review): there are no assertions, so this only fails if process_state
# raises; consider asserting on the expected tool_call/tool_response entries.
def test_process_state():
history = []
history.append(Message(role="user", content="how is the weather in chicago?"))
updated_history = process_state(arch_state, history)
print(json.dumps(updated_history, indent=2))
if __name__ == "__main__":
pytest.main()

View file

@ -17,3 +17,4 @@ openai
pandas
tf-keras
onnx
pytest

View file

@ -50,6 +50,8 @@ pub mod open_ai {
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -209,11 +211,26 @@ pub mod open_ai {
pub arguments: HashMap<String, Value>,
}
/// One recorded tool invocation carried inside the arch state.
///
/// The field names match the keys of the JSON objects that the model
/// server's `process_state` reads (`key`, `tool_call`, `tool_response`).
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCallState {
/// Lookup key for this entry. The model server compares it against a
/// SHA-256 hex digest derived from the user messages (presumably produced
/// the same way on this side; verify).
pub key: String,
/// Message associated with this entry, if any.
pub message: Option<Message>,
/// The function call that was made.
pub tool_call: FunctionCallDetail,
/// The tool's response, kept as a string.
pub tool_response: String,
}
/// A single element of the arch state.
///
/// Because the enum is `#[serde(untagged)]`, a `ToolCall` value is
/// serialized as the bare JSON array of its [`ToolCallState`] entries, with
/// no variant name wrapping it.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ArchState {
/// A list of recorded tool invocations.
ToolCall(Vec<ToolCallState>),
}
/// Chat-completions response body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Usage,
pub choices: Vec<Choice>,
pub model: String,
/// Optional string-to-string metadata attached to the response.
///
/// NOTE(review): unlike the request struct's `metadata`, this field has no
/// `skip_serializing_if`, so `None` is serialized as `"metadata": null`.
/// Confirm that consumers tolerate a null value here.
pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -360,6 +377,7 @@ mod test {
stream_options: Some(super::open_ai::StreamOptions {
include_usage: true,
}),
metadata: None,
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();