diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index 8aab7c6e..900e2065 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -69,14 +69,9 @@ static_resources:
clusters:
- name: openai
connect_timeout: 5s
- dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
+ dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
- typed_extension_protocol_options:
- envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
- "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
- explicit_http_config:
- http2_protocol_options: {}
load_assignment:
cluster_name: openai
endpoints:
@@ -98,14 +93,9 @@ static_resources:
tls_maximum_protocol_version: TLSv1_3
- name: mistral
connect_timeout: 5s
- dns_lookup_family: V4_ONLY
type: LOGICAL_DNS
+ dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
- typed_extension_protocol_options:
- envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
- "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
- explicit_http_config:
- http2_protocol_options: {}
load_assignment:
cluster_name: mistral
endpoints:
@@ -124,6 +114,7 @@ static_resources:
- name: model_server
connect_timeout: 5s
type: STRICT_DNS
+ dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: model_server
@@ -138,6 +129,7 @@ static_resources:
- name: mistral_7b_instruct
connect_timeout: 5s
type: STRICT_DNS
+ dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: mistral_7b_instruct
@@ -152,6 +144,7 @@ static_resources:
- name: arch_fc
connect_timeout: 5s
type: STRICT_DNS
+ dns_lookup_family: V4_ONLY
lb_policy: ROUND_ROBIN
load_assignment:
cluster_name: arch_fc
diff --git a/arch/src/consts.rs b/arch/src/consts.rs
index 572bd2c3..9b14c532 100644
--- a/arch/src/consts.rs
+++ b/arch/src/consts.rs
@@ -12,3 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
+// pub const ARCH_STATE_HEADER: &str = "x-arch-state";
diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs
index 82ab4213..cb2eb732 100644
--- a/arch/src/filter_context.rs
+++ b/arch/src/filter_context.rs
@@ -72,11 +72,6 @@ impl FilterContext {
fn process_prompt_targets(&self) {
for values in self.prompt_targets.iter() {
let prompt_target = values.1;
- self.schedule_embeddings_call(
- &prompt_target.name,
- &prompt_target.name,
- EmbeddingType::Name,
- );
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.description,
diff --git a/arch/src/http.rs b/arch/src/http.rs
index dfa683f0..592e7c5f 100644
--- a/arch/src/http.rs
+++ b/arch/src/http.rs
@@ -65,7 +65,7 @@ pub trait Client: Context {
}
Err(status) => Err(ClientError::DispatchError {
upstream_name: String::from(call_args.upstream),
- internal_status: status.clone(),
+ internal_status: status,
}),
}
}
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index c6a356c5..fdfe5be0 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -469,6 +469,7 @@ impl StreamContext {
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
+ metadata: None,
};
let msg_body = match serde_json::to_string(&chat_completions) {
@@ -686,6 +687,7 @@ impl StreamContext {
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
+ metadata: None,
};
let json_string = match serde_json::to_string(&chat_completions_request) {
@@ -875,6 +877,7 @@ impl StreamContext {
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
+ metadata: None,
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp);
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index 7e1249cf..53fbf215 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -254,7 +254,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
- .returning(Some(&config))
+ .returning(Some(config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
@@ -276,22 +276,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
- .expect_log(Some(LogLevel::Debug), None)
- .expect_http_call(
- Some("model_server"),
- Some(vec![
- (":method", "POST"),
- (":path", "/embeddings"),
- (":authority", "model_server"),
- ("content-type", "application/json"),
- ("x-envoy-upstream-rq-timeout-ms", "60000"),
- ]),
- None,
- None,
- None,
- )
- .returning(Some(102))
- .expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(0))
.execute_and_expect(ReturnType::None)
.unwrap();
@@ -335,31 +319,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
.execute_and_expect(ReturnType::None)
.unwrap();
- module
- .call_proxy_on_http_call_response(
- filter_context,
- 102,
- 0,
- embedding_response_str.len() as i32,
- 0,
- )
- .expect_log(
- Some(LogLevel::Debug),
- Some(
- format!(
- "filter_context: on_http_call_response called with token_id: {:?}",
- 102
- )
- .as_str(),
- ),
- )
- .expect_metric_increment("active_http_calls", -1)
- .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
- .returning(Some(&embedding_response_str))
- .expect_log(Some(LogLevel::Debug), None)
- .execute_and_expect(ReturnType::None)
- .unwrap();
-
filter_context
}
@@ -599,6 +558,7 @@ fn request_ratelimited() {
},
}],
model: String::from("test"),
+ metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
@@ -712,6 +672,7 @@ fn request_not_ratelimited() {
},
}],
model: String::from("test"),
+ metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py
index 02b89d3c..75f5f295 100644
--- a/chatbot_ui/app/run.py
+++ b/chatbot_ui/app/run.py
@@ -1,5 +1,6 @@
+import json
import os
-from openai import OpenAI
+from openai import OpenAI, DefaultHttpxClient
import gradio as gr
import logging as log
from dotenv import load_dotenv
@@ -13,11 +14,13 @@ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
-client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
+client = OpenAI(
+    api_key=OPENAI_API_KEY,
+    base_url=CHAT_COMPLETION_ENDPOINT,
+    http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
+)
-def predict(message, history):
+def predict(message, state):
+    history = state.setdefault("history", [])
history.append({"role": "user", "content": message})
- log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info("history: ", history)
# Custom headers
@@ -27,34 +30,42 @@ def predict(message, history):
'x-arch-deterministic-provider': 'openai',
}
+    metadata = None
+    if "arch_state" in state:
+        metadata = {"x-arch-state": state["arch_state"]}
+
try:
- response = client.chat.completions.create(model=MODEL_NAME,
- messages= history,
+ raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
+                                              messages=history,
temperature=1.0,
+ metadata=metadata,
extra_headers=custom_headers
)
except Exception as e:
log.info(e)
# remove last user message in case of exception
history.pop()
- log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
- choices = response.choices
- message = choices[0].message
- content = message.content
- history.append({"role": "assistant", "content": content})
- history[-1]["model"] = response.model
+ response = raw_response.parse()
+ # extract arch_state from metadata and store it in gradio session state
+ # this state must be passed back to the gateway in the next request
+ arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
+ if arch_state:
+ state['arch_state'] = arch_state
+
+ content = response.choices[0].message.content
+
+ history.append({"role": "assistant", "content": content, "model": response.model})
messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
- return messages, history
-
+ return messages, state
with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
print("Starting Demo...")
chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
- state = gr.State([])
+ state = gr.State({})
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
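
Note: the run.py changes above thread the gateway's `x-arch-state` metadata through Gradio session state. Below is a minimal sketch of that round-trip outside Gradio; the base URL and model name are assumptions (the app reads them from the environment), and the call shapes mirror the diff:

```python
import json
from openai import OpenAI, DefaultHttpxClient

# assumed local gateway endpoint; the app reads CHAT_COMPLETION_ENDPOINT from the env
client = OpenAI(
    api_key="n/a",
    base_url="http://127.0.0.1:10000/v1",
    http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
)

state = {}

def ask(text: str) -> str:
    history = state.setdefault("history", [])
    history.append({"role": "user", "content": text})

    # replay the gateway's opaque state from the previous turn, if any
    metadata = None
    if "arch_state" in state:
        metadata = {"x-arch-state": state["arch_state"]}

    raw = client.chat.completions.with_raw_response.create(
        model="gpt-3.5-turbo",
        messages=history,
        metadata=metadata,
    )
    # the gateway returns its state in the response metadata; save it for the next turn
    arch_state = json.loads(raw.text).get("metadata", {}).get("x-arch-state")
    if arch_state:
        state["arch_state"] = arch_state

    reply = raw.parse().choices[0].message.content
    history.append({"role": "assistant", "content": reply})
    return reply
```
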
diff --git a/chatbot_ui/requirements.txt b/chatbot_ui/requirements.txt
index 26131d36..60a107fe 100644
--- a/chatbot_ui/requirements.txt
+++ b/chatbot_ui/requirements.txt
@@ -5,4 +5,4 @@ asyncio==3.4.3
httpx==0.27.0
python-dotenv==1.0.1
pydantic==2.8.2
-openai==1.46.1
+openai==1.51.0
diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml
index c84d6b08..056fdc17 100644
--- a/demos/function_calling/arch_config.yaml
+++ b/demos/function_calling/arch_config.yaml
@@ -21,10 +21,6 @@ llm_providers:
provider: openai
model: gpt-4
default: true
- - name: mistral-large-latest
- access_key: MISTRAL_API_KEY
- provider: mistral
- model: mistral-large-latest
system_prompt: |
You are a helpful assistant.
diff --git a/model_server/app/__init__.py b/model_server/app/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model_server/app/arch_fc/__init__.py b/model_server/app/arch_fc/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model_server/app/arch_fc/arch_fc.py b/model_server/app/arch_fc/arch_fc.py
index a0216294..7bea01ae 100644
--- a/model_server/app/arch_fc/arch_fc.py
+++ b/model_server/app/arch_fc/arch_fc.py
@@ -3,11 +3,12 @@ import random
from fastapi import FastAPI, Response
from app.arch_fc.arch_handler import ArchHandler
from app.arch_fc.bolt_handler import BoltHandler
-from app.arch_fc.common import ChatMessage
+from app.arch_fc.common import ChatMessage, Message
import logging
import yaml
from openai import OpenAI
import os
+import hashlib
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -51,14 +52,54 @@ logger.info(f"serving mode: {mode}")
logger.info(f"using model: {chosen_model}")
logger.info(f"using endpoint: {endpoint}")
+def process_state(arch_state: str, history: list[Message]):
+    logger.info(f"state: {arch_state}")
+    state_json = json.loads(arch_state)
+
+    state_map = {}
+    if state_json:
+        for tools_state in state_json:
+            for tool_state in tools_state:
+                state_map[tool_state["key"]] = tool_state
+
+    logger.info(f"state_map: {json.dumps(state_map)}")
+
+    sha_history = []
+    updated_history = []
+    for hist in history:
+        updated_history.append({"role": hist.role, "content": hist.content})
+        if hist.role == "user":
+            sha_history.append(hist.content)
+            sha256_hash = hashlib.sha256()
+            sha256_hash.update(json.dumps(sha_history).encode())
+            sha_key = sha256_hash.hexdigest()
+            logger.info(f"sha_key: {sha_key}")
+            if sha_key in state_map:
+                tool_call_state = state_map[sha_key]
+                if "tool_call" in tool_call_state:
+                    tool_call_str = json.dumps(tool_call_state["tool_call"])
+                    updated_history.append({"role": "assistant", "content": f"<tool_call>\n{tool_call_str}\n</tool_call>"})
+                if "tool_response" in tool_call_state:
+                    tool_resp = tool_call_state["tool_response"]
+                    # TODO: try with role = user as well
+                    updated_history.append({"role": "user", "content": f"<tool_response>\n{tool_resp}\n</tool_response>"})
+                # we don't want to match this state with any other messages
+                del state_map[sha_key]
+
+    return updated_history
+
async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request")
tools_encoded = handler._format_system(req.tools)
# append system prompt with tools to messages
messages = [{"role": "system", "content": tools_encoded}]
- for message in req.messages:
- messages.append({"role": message.role, "content": message.content})
- logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
+ metadata = req.metadata
+ arch_state = metadata.get("x-arch-state", "[]")
+ updated_history = process_state(arch_state, req.messages)
+ for message in updated_history:
+ messages.append({"role": message["role"], "content": message["content"]})
+
+ logger.info(f"model_server => arch_fc: {chosen_model}, messages: {json.dumps(messages)}")
completions_params = params["params"]
resp = client.chat.completions.create(
messages=messages,
@@ -80,6 +121,6 @@ async def chat_completion(req: ChatMessage, res: Response):
if tools:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
- logger.info(f"response (tools): {json.dumps(tools)}")
- logger.info(f"response: {json.dumps(resp.to_dict())}")
+ logger.info(f"model_server <= arch_fc: (tools): {json.dumps(tools)}")
+ logger.info(f"model_server <= arch_fc: response body: {json.dumps(resp.to_dict())}")
return resp
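
Note: `process_state` keys saved tool activity by hashing the running list of user messages, so every prefix of the conversation gets its own key. A standalone sketch of that derivation, with illustrative names:

```python
import hashlib
import json

def state_key(user_messages: list[str]) -> str:
    # sha256 over the JSON-encoded list of user messages seen so far,
    # matching the hashing loop in process_state above
    h = hashlib.sha256()
    h.update(json.dumps(user_messages).encode())
    return h.hexdigest()

# each user turn extends the list, producing a distinct key per turn
state_key(["how is the weather in chicago?"])
state_key(["how is the weather in chicago?", "and in seattle?"])
```
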
diff --git a/model_server/app/arch_fc/common.py b/model_server/app/arch_fc/common.py
index c26e8422..e9d78ecb 100644
--- a/model_server/app/arch_fc/common.py
+++ b/model_server/app/arch_fc/common.py
@@ -10,3 +10,5 @@ class Message(BaseModel):
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
+    # TODO: make the default None
+ metadata: Dict[str, str] = {}
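
Note: with the optional `metadata` field added, a request body like the following (values illustrative) validates against `ChatMessage` under pydantic v2:

```python
from app.arch_fc.common import ChatMessage

req = ChatMessage.model_validate({
    "messages": [{"role": "user", "content": "how is the weather in chicago?"}],
    "tools": [],
    "metadata": {"x-arch-state": "[]"},
})
assert req.metadata.get("x-arch-state") == "[]"
```
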
diff --git a/model_server/app/arch_fc/test_arch_fc.py b/model_server/app/arch_fc/test_arch_fc.py
new file mode 100644
index 00000000..0eb409d0
--- /dev/null
+++ b/model_server/app/arch_fc/test_arch_fc.py
@@ -0,0 +1,17 @@
+import json
+import pytest
+from app.arch_fc.arch_fc import process_state
+from app.arch_fc.common import Message
+# test process_state
+
+arch_state = '[[{"key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424", "message": {"role": "user", "content": "how is the weather in chicago?"}, "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}}, "tool_response": "{\\"city\\":\\"Chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-05\\",\\"temperature\\":{\\"min\\":51,\\"max\\":70},\\"query_time\\":\\"2024-10-05 08:18:00.264171+00:00\\"},{\\"date\\":\\"2024-10-06\\",\\"temperature\\":{\\"min\\":77,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264186+00:00\\"},{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":66,\\"max\\":84},\\"query_time\\":\\"2024-10-05 08:18:00.264190+00:00\\"},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":77,\\"max\\":94},\\"query_time\\":\\"2024-10-05 08:18:00.264209+00:00\\"},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":76,\\"max\\":92},\\"query_time\\":\\"2024-10-05 08:18:00.264518+00:00\\"},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":56,\\"max\\":68},\\"query_time\\":\\"2024-10-05 08:18:00.264550+00:00\\"},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":73,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264559+00:00\\"}],\\"unit\\":\\"F\\"}"}]]'
+
+
+def test_process_state():
+    history = [Message(role="user", content="how is the weather in chicago?")]
+    updated_history = process_state(arch_state, history)
+    print(json.dumps(updated_history, indent=2))
+    # the fixture key matches this single-message history, so the saved
+    # tool call and tool response are appended after the user turn
+    assert len(updated_history) == 3
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/model_server/requirements.txt b/model_server/requirements.txt
index 79ec8e71..b0904be8 100644
--- a/model_server/requirements.txt
+++ b/model_server/requirements.txt
@@ -17,3 +17,4 @@ openai
pandas
tf-keras
onnx
+pytest
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index 5b6bd794..4b338fc3 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -50,6 +50,8 @@ pub mod open_ai {
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
        pub stream_options: Option<StreamOptions>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -209,11 +211,26 @@ pub mod open_ai {
        pub arguments: HashMap<String, Value>,
}
+ #[derive(Debug, Deserialize, Serialize)]
+ pub struct ToolCallState {
+ pub key: String,
+        pub message: Option<Message>,
+ pub tool_call: FunctionCallDetail,
+ pub tool_response: String,
+ }
+
+ #[derive(Debug, Deserialize, Serialize)]
+ #[serde(untagged)]
+ pub enum ArchState {
+        ToolCall(Vec<ToolCallState>),
+ }
+
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Usage,
        pub choices: Vec<Choice>,
pub model: String,
+        pub metadata: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -360,6 +377,7 @@ mod test {
stream_options: Some(super::open_ai::StreamOptions {
include_usage: true,
}),
+ metadata: None,
};
let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
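
Note: the `x-arch-state` payload modeled by `ArchState`/`ToolCallState` is a list of tool-state lists, each entry keyed by the user-message hash. A trimmed Python rendering of the shape, based on the test fixture above (the key and tool response are truncated here):

```python
import json

arch_state = [[{
    "key": "cafbda79...",  # sha256 of the user-message history (truncated)
    "message": {"role": "user", "content": "how is the weather in chicago?"},
    "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}},
    "tool_response": "{\"city\": \"Chicago\", ...}",  # JSON string, trimmed
}]]
print(json.dumps(arch_state, indent=2))
```
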