mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
don't compute embeddings for names and other fixes see description (#126)
* serialize tools - 2 * fix int tests * fix int test * fix unit tests
This commit is contained in:
parent
0e5ea3d6db
commit
2a747df7c0
16 changed files with 125 additions and 86 deletions
|
|
@ -69,14 +69,9 @@ static_resources:
|
|||
clusters:
|
||||
- name: openai
|
||||
connect_timeout: 5s
|
||||
dns_lookup_family: V4_ONLY
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
typed_extension_protocol_options:
|
||||
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
|
||||
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
|
||||
explicit_http_config:
|
||||
http2_protocol_options: {}
|
||||
load_assignment:
|
||||
cluster_name: openai
|
||||
endpoints:
|
||||
|
|
@ -98,14 +93,9 @@ static_resources:
|
|||
tls_maximum_protocol_version: TLSv1_3
|
||||
- name: mistral
|
||||
connect_timeout: 5s
|
||||
dns_lookup_family: V4_ONLY
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
typed_extension_protocol_options:
|
||||
envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
|
||||
"@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
|
||||
explicit_http_config:
|
||||
http2_protocol_options: {}
|
||||
load_assignment:
|
||||
cluster_name: mistral
|
||||
endpoints:
|
||||
|
|
@ -124,6 +114,7 @@ static_resources:
|
|||
- name: model_server
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: model_server
|
||||
|
|
@ -138,6 +129,7 @@ static_resources:
|
|||
- name: mistral_7b_instruct
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: mistral_7b_instruct
|
||||
|
|
@ -152,6 +144,7 @@ static_resources:
|
|||
- name: arch_fc
|
||||
connect_timeout: 5s
|
||||
type: STRICT_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
load_assignment:
|
||||
cluster_name: arch_fc
|
||||
|
|
|
|||
|
|
@ -12,3 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
|
|||
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
|
||||
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
|
||||
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
|
||||
// pub const ARCH_STATE_HEADER: &str = "x-arch-state";
|
||||
|
|
|
|||
|
|
@ -72,11 +72,6 @@ impl FilterContext {
|
|||
fn process_prompt_targets(&self) {
|
||||
for values in self.prompt_targets.iter() {
|
||||
let prompt_target = values.1;
|
||||
self.schedule_embeddings_call(
|
||||
&prompt_target.name,
|
||||
&prompt_target.name,
|
||||
EmbeddingType::Name,
|
||||
);
|
||||
self.schedule_embeddings_call(
|
||||
&prompt_target.name,
|
||||
&prompt_target.description,
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ pub trait Client: Context {
|
|||
}
|
||||
Err(status) => Err(ClientError::DispatchError {
|
||||
upstream_name: String::from(call_args.upstream),
|
||||
internal_status: status.clone(),
|
||||
internal_status: status,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -469,6 +469,7 @@ impl StreamContext {
|
|||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
|
|
@ -686,6 +687,7 @@ impl StreamContext {
|
|||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let json_string = match serde_json::to_string(&chat_completions_request) {
|
||||
|
|
@ -875,6 +877,7 @@ impl StreamContext {
|
|||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!("sending response back to default llm: {}", json_resp);
|
||||
|
|
|
|||
|
|
@ -254,7 +254,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
|||
module
|
||||
.call_proxy_on_configure(filter_context, config.len() as i32)
|
||||
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
||||
.returning(Some(&config))
|
||||
.returning(Some(config))
|
||||
.execute_and_expect(ReturnType::Bool(true))
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -276,22 +276,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
|||
)
|
||||
.returning(Some(101))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(102))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_set_tick_period_millis(Some(0))
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
|
@ -335,31 +319,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
|||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
filter_context,
|
||||
102,
|
||||
0,
|
||||
embedding_response_str.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_log(
|
||||
Some(LogLevel::Debug),
|
||||
Some(
|
||||
format!(
|
||||
"filter_context: on_http_call_response called with token_id: {:?}",
|
||||
102
|
||||
)
|
||||
.as_str(),
|
||||
),
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embedding_response_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
filter_context
|
||||
}
|
||||
|
||||
|
|
@ -599,6 +558,7 @@ fn request_ratelimited() {
|
|||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
|
|
@ -712,6 +672,7 @@ fn request_not_ratelimited() {
|
|||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
import os
|
||||
from openai import OpenAI
|
||||
from openai import OpenAI, DefaultHttpxClient
|
||||
import gradio as gr
|
||||
import logging as log
|
||||
from dotenv import load_dotenv
|
||||
|
|
@ -13,11 +14,13 @@ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
|||
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
|
||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
|
||||
client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))
|
||||
|
||||
def predict(message, history):
|
||||
def predict(message, state):
|
||||
if 'history' not in state:
|
||||
state['history'] = []
|
||||
history = state.get("history")
|
||||
history.append({"role": "user", "content": message})
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info("history: ", history)
|
||||
|
||||
# Custom headers
|
||||
|
|
@ -27,34 +30,42 @@ def predict(message, history):
|
|||
'x-arch-deterministic-provider': 'openai',
|
||||
}
|
||||
|
||||
metadata = None
|
||||
if 'arch_state' in state:
|
||||
metadata = {"x-arch-state": state['arch_state']}
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(model=MODEL_NAME,
|
||||
messages= history,
|
||||
raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
|
||||
messages = history,
|
||||
temperature=1.0,
|
||||
metadata=metadata,
|
||||
extra_headers=custom_headers
|
||||
)
|
||||
except Exception as e:
|
||||
log.info(e)
|
||||
# remove last user message in case of exception
|
||||
history.pop()
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
choices = response.choices
|
||||
message = choices[0].message
|
||||
content = message.content
|
||||
history.append({"role": "assistant", "content": content})
|
||||
history[-1]["model"] = response.model
|
||||
response = raw_response.parse()
|
||||
|
||||
# extract arch_state from metadata and store it in gradio session state
|
||||
# this state must be passed back to the gateway in the next request
|
||||
arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
|
||||
if arch_state:
|
||||
state['arch_state'] = arch_state
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
history.append({"role": "assistant", "content": content, "model": response.model})
|
||||
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
|
||||
return messages, history
|
||||
|
||||
return messages, state
|
||||
|
||||
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
|
||||
print("Starting Demo...")
|
||||
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
|
||||
state = gr.State([])
|
||||
state = gr.State({})
|
||||
with gr.Row():
|
||||
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,4 +5,4 @@ asyncio==3.4.3
|
|||
httpx==0.27.0
|
||||
python-dotenv==1.0.1
|
||||
pydantic==2.8.2
|
||||
openai==1.46.1
|
||||
openai==1.51.0
|
||||
|
|
|
|||
|
|
@ -21,10 +21,6 @@ llm_providers:
|
|||
provider: openai
|
||||
model: gpt-4
|
||||
default: true
|
||||
- name: mistral-large-latest
|
||||
access_key: MISTRAL_API_KEY
|
||||
provider: mistral
|
||||
model: mistral-large-latest
|
||||
|
||||
system_prompt: |
|
||||
You are a helpful assistant.
|
||||
|
|
|
|||
0
model_server/app/__init__.py
Normal file
0
model_server/app/__init__.py
Normal file
0
model_server/app/arch_fc/__init__.py
Normal file
0
model_server/app/arch_fc/__init__.py
Normal file
|
|
@ -3,11 +3,12 @@ import random
|
|||
from fastapi import FastAPI, Response
|
||||
from app.arch_fc.arch_handler import ArchHandler
|
||||
from app.arch_fc.bolt_handler import BoltHandler
|
||||
from app.arch_fc.common import ChatMessage
|
||||
from app.arch_fc.common import ChatMessage, Message
|
||||
import logging
|
||||
import yaml
|
||||
from openai import OpenAI
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
|
|
@ -51,14 +52,54 @@ logger.info(f"serving mode: {mode}")
|
|||
logger.info(f"using model: {chosen_model}")
|
||||
logger.info(f"using endpoint: {endpoint}")
|
||||
|
||||
def process_state(arch_state, history: list[Message]):
    """Rebuild a chat history with recorded tool calls and tool responses spliced back in.

    ``arch_state`` is JSON text holding a list of lists of tool-state objects. Each
    object carries a ``key`` and may carry ``tool_call`` and ``tool_response``. While
    walking ``history``, the contents of the user messages seen so far are
    JSON-encoded and hashed with SHA-256; when the hex digest equals an entry's
    ``key``, that entry is emitted right after the user message that produced it.

    Args:
        arch_state: JSON string as described above. ``None`` or an empty string is
            treated as "no recorded state" instead of failing in ``json.loads``.
        history: messages exposing ``role`` and ``content`` attributes.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts: every message of
        ``history`` in order, each matching user message followed by an assistant
        ``<tool_call>`` entry and/or a user ``<tool_response>`` entry.

    Raises:
        json.JSONDecodeError: ``arch_state`` is non-empty but not valid JSON.
        KeyError: a state entry has no ``key``.
    """
    # Use the logging framework (configured at module import) rather than print().
    log = logging.getLogger(__name__)
    log.info("state: %s", arch_state)

    state_json = json.loads(arch_state) if arch_state else []

    # Flatten the nested lists into key -> entry; a later duplicate key wins.
    state_map = {
        tool_state["key"]: tool_state
        for tools_state in state_json or []
        for tool_state in tools_state
    }
    log.info("state_map: %s", json.dumps(state_map))

    sha_history = []
    updated_history = []
    for hist in history:
        updated_history.append({"role": hist.role, "content": hist.content})
        if hist.role != "user":
            continue

        # The key covers all user contents up to and including this message.
        sha_history.append(hist.content)
        sha_key = hashlib.sha256(json.dumps(sha_history).encode()).hexdigest()
        log.info("sha_key: %s", sha_key)

        # pop() so this state can never be matched against any other message
        tool_call_state = state_map.pop(sha_key, None)
        if tool_call_state is None:
            continue

        if "tool_call" in tool_call_state:
            tool_call_str = json.dumps(tool_call_state["tool_call"])
            updated_history.append(
                {"role": "assistant", "content": f"<tool_call>\n{tool_call_str}\n</tool_call>"}
            )
        if "tool_response" in tool_call_state:
            tool_resp = tool_call_state["tool_response"]
            # TODO: try with role = user as well
            updated_history.append(
                {"role": "user", "content": f"<tool_response>\n{tool_resp}\n</tool_response>"}
            )

    return updated_history
|
||||
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
logger.info("starting request")
|
||||
tools_encoded = handler._format_system(req.tools)
|
||||
# append system prompt with tools to messages
|
||||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
for message in req.messages:
|
||||
messages.append({"role": message.role, "content": message.content})
|
||||
logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
|
||||
metadata = req.metadata
|
||||
arch_state = metadata.get("x-arch-state", "[]")
|
||||
updated_history = process_state(arch_state, req.messages)
|
||||
for message in updated_history:
|
||||
messages.append({"role": message["role"], "content": message["content"]})
|
||||
|
||||
logger.info(f"model_server => arch_fc: {chosen_model}, messages: {json.dumps(messages)}")
|
||||
completions_params = params["params"]
|
||||
resp = client.chat.completions.create(
|
||||
messages=messages,
|
||||
|
|
@ -80,6 +121,6 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
if tools:
|
||||
resp.choices[0].message.tool_calls = tool_calls
|
||||
resp.choices[0].message.content = None
|
||||
logger.info(f"response (tools): {json.dumps(tools)}")
|
||||
logger.info(f"response: {json.dumps(resp.to_dict())}")
|
||||
logger.info(f"model_server <= arch_fc: (tools): {json.dumps(tools)}")
|
||||
logger.info(f"model_server <= arch_fc: response body: {json.dumps(resp.to_dict())}")
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -10,3 +10,5 @@ class Message(BaseModel):
|
|||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
# todo: make it default none
|
||||
metadata: Dict[str, str] = {}
|
||||
|
|
|
|||
17
model_server/app/arch_fc/test_arch_fc.py
Normal file
17
model_server/app/arch_fc/test_arch_fc.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
import json
|
||||
import pytest
|
||||
from app.arch_fc.arch_fc import process_state
|
||||
from app.arch_fc.common import ChatMessage, Message
|
||||
# test process_state
|
||||
|
||||
arch_state = '[[{"key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424", "message": {"role": "user", "content": "how is the weather in chicago?"}, "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}}, "tool_response": "{\\"city\\":\\"Chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-05\\",\\"temperature\\":{\\"min\\":51,\\"max\\":70},\\"query_time\\":\\"2024-10-05 08:18:00.264171+00:00\\"},{\\"date\\":\\"2024-10-06\\",\\"temperature\\":{\\"min\\":77,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264186+00:00\\"},{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":66,\\"max\\":84},\\"query_time\\":\\"2024-10-05 08:18:00.264190+00:00\\"},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":77,\\"max\\":94},\\"query_time\\":\\"2024-10-05 08:18:00.264209+00:00\\"},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":76,\\"max\\":92},\\"query_time\\":\\"2024-10-05 08:18:00.264518+00:00\\"},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":56,\\"max\\":68},\\"query_time\\":\\"2024-10-05 08:18:00.264550+00:00\\"},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":73,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264559+00:00\\"}],\\"unit\\":\\"F\\"}"}]]'
|
||||
|
||||
|
||||
def test_process_state():
    """process_state echoes the history and splices in the recorded tool call and response."""
    import hashlib  # local import: only this test needs it

    question = "how is the weather in chicago?"
    history = [Message(role="user", content=question)]

    # Recorded fixture: the user turn always comes first, and every entry is a
    # plain role/content dict whatever gets spliced in after it.
    updated_history = process_state(arch_state, history)
    print(json.dumps(updated_history, indent=2))
    assert updated_history[0] == {"role": "user", "content": question}
    assert all(set(entry) == {"role", "content"} for entry in updated_history)

    # State keyed by the digest process_state derives for this history, so the
    # tool call and tool response must both be spliced in after the user turn.
    key = hashlib.sha256(json.dumps([question]).encode()).hexdigest()
    tool_call = {"name": "weather_forecast", "arguments": {"city": "Chicago"}}
    state = json.dumps([[{"key": key, "tool_call": tool_call, "tool_response": "sunny"}]])
    assert process_state(state, history) == [
        {"role": "user", "content": question},
        {"role": "assistant", "content": f"<tool_call>\n{json.dumps(tool_call)}\n</tool_call>"},
        {"role": "user", "content": "<tool_response>\nsunny\n</tool_response>"},
    ]

    # Empty state leaves the history untouched.
    assert process_state("[]", history) == [{"role": "user", "content": question}]
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
@ -17,3 +17,4 @@ openai
|
|||
pandas
|
||||
tf-keras
|
||||
onnx
|
||||
pytest
|
||||
|
|
|
|||
|
|
@ -50,6 +50,8 @@ pub mod open_ai {
|
|||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -209,11 +211,26 @@ pub mod open_ai {
|
|||
pub arguments: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ToolCallState {
|
||||
pub key: String,
|
||||
pub message: Option<Message>,
|
||||
pub tool_call: FunctionCallDetail,
|
||||
pub tool_response: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ArchState {
|
||||
ToolCall(Vec<ToolCallState>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Usage,
|
||||
pub choices: Vec<Choice>,
|
||||
pub model: String,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -360,6 +377,7 @@ mod test {
|
|||
stream_options: Some(super::open_ai::StreamOptions {
|
||||
include_usage: true,
|
||||
}),
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue