diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml
index 8aab7c6e..900e2065 100644
--- a/arch/envoy.template.yaml
+++ b/arch/envoy.template.yaml
@@ -69,14 +69,9 @@ static_resources:
   clusters:
   - name: openai
     connect_timeout: 5s
-    dns_lookup_family: V4_ONLY
     type: LOGICAL_DNS
+    dns_lookup_family: V4_ONLY
     lb_policy: ROUND_ROBIN
-    typed_extension_protocol_options:
-      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
-        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
-        explicit_http_config:
-          http2_protocol_options: {}
     load_assignment:
       cluster_name: openai
       endpoints:
@@ -98,14 +93,9 @@ static_resources:
             tls_maximum_protocol_version: TLSv1_3
   - name: mistral
     connect_timeout: 5s
-    dns_lookup_family: V4_ONLY
     type: LOGICAL_DNS
+    dns_lookup_family: V4_ONLY
     lb_policy: ROUND_ROBIN
-    typed_extension_protocol_options:
-      envoy.extensions.upstreams.http.v3.HttpProtocolOptions:
-        "@type": type.googleapis.com/envoy.extensions.upstreams.http.v3.HttpProtocolOptions
-        explicit_http_config:
-          http2_protocol_options: {}
     load_assignment:
       cluster_name: mistral
       endpoints:
@@ -124,6 +114,7 @@ static_resources:
   - name: model_server
     connect_timeout: 5s
     type: STRICT_DNS
+    dns_lookup_family: V4_ONLY
     lb_policy: ROUND_ROBIN
     load_assignment:
       cluster_name: model_server
@@ -138,6 +129,7 @@ static_resources:
  - name: mistral_7b_instruct
    connect_timeout: 5s
    type: STRICT_DNS
+   dns_lookup_family: V4_ONLY
    lb_policy: ROUND_ROBIN
    load_assignment:
      cluster_name: mistral_7b_instruct
@@ -152,6 +144,7 @@ static_resources:
  - name: arch_fc
    connect_timeout: 5s
    type: STRICT_DNS
+   dns_lookup_family: V4_ONLY
    lb_policy: ROUND_ROBIN
    load_assignment:
      cluster_name: arch_fc
diff --git a/arch/src/consts.rs b/arch/src/consts.rs
index 572bd2c3..9b14c532 100644
--- a/arch/src/consts.rs
+++ b/arch/src/consts.rs
@@ -12,3 +12,4 @@ pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider";
 pub const ARCH_MESSAGES_KEY: &str = "arch_messages";
 pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint";
 pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions";
+// pub const ARCH_STATE_HEADER: &str = "x-arch-state";
diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs
index 82ab4213..cb2eb732 100644
--- a/arch/src/filter_context.rs
+++ b/arch/src/filter_context.rs
@@ -72,11 +72,6 @@ impl FilterContext {
     fn process_prompt_targets(&self) {
         for values in self.prompt_targets.iter() {
             let prompt_target = values.1;
-            self.schedule_embeddings_call(
-                &prompt_target.name,
-                &prompt_target.name,
-                EmbeddingType::Name,
-            );
             self.schedule_embeddings_call(
                 &prompt_target.name,
                 &prompt_target.description,
diff --git a/arch/src/http.rs b/arch/src/http.rs
index dfa683f0..592e7c5f 100644
--- a/arch/src/http.rs
+++ b/arch/src/http.rs
@@ -65,7 +65,7 @@ pub trait Client: Context {
             }
             Err(status) => Err(ClientError::DispatchError {
                 upstream_name: String::from(call_args.upstream),
-                internal_status: status.clone(),
+                internal_status: status,
            }),
        }
    }
}
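A note on the Envoy changes above: every cluster now pins `dns_lookup_family: V4_ONLY`, and the OpenAI/Mistral clusters drop the explicit HTTP/2 upstream protocol options. Purely as an illustration of what `V4_ONLY` asks the resolver to do (this is not Envoy code, and the hostname is just a placeholder), a small Python sketch:

```python
import socket

# V4_ONLY roughly means: resolve the upstream host to IPv4 (A) records only,
# never AAAA, so the proxy never attempts an IPv6 connection.
def resolve_v4_only(host: str, port: int = 443) -> list[str]:
    infos = socket.getaddrinfo(host, port, family=socket.AF_INET, type=socket.SOCK_STREAM)
    return sorted({info[4][0] for info in infos})

if __name__ == "__main__":
    print(resolve_v4_only("api.openai.com"))
```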
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index c6a356c5..fdfe5be0 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -469,6 +469,7 @@ impl StreamContext {
             tools: Some(chat_completion_tools),
             stream: false,
             stream_options: None,
+            metadata: None,
         };

         let msg_body = match serde_json::to_string(&chat_completions) {
@@ -686,6 +687,7 @@ impl StreamContext {
             tools: None,
             stream: callout_context.request_body.stream,
             stream_options: callout_context.request_body.stream_options,
+            metadata: None,
         };

         let json_string = match serde_json::to_string(&chat_completions_request) {
@@ -875,6 +877,7 @@ impl StreamContext {
             tools: None,
             stream: callout_context.request_body.stream,
             stream_options: callout_context.request_body.stream_options,
+            metadata: None,
         };
         let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
         debug!("sending response back to default llm: {}", json_resp);
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index 7e1249cf..53fbf215 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -254,7 +254,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
     module
         .call_proxy_on_configure(filter_context, config.len() as i32)
         .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config))
+        .returning(Some(config))
         .execute_and_expect(ReturnType::Bool(true))
         .unwrap();

@@ -276,22 +276,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
         )
         .returning(Some(101))
         .expect_metric_increment("active_http_calls", 1)
-        .expect_log(Some(LogLevel::Debug), None)
-        .expect_http_call(
-            Some("model_server"),
-            Some(vec![
-                (":method", "POST"),
-                (":path", "/embeddings"),
-                (":authority", "model_server"),
-                ("content-type", "application/json"),
-                ("x-envoy-upstream-rq-timeout-ms", "60000"),
-            ]),
-            None,
-            None,
-            None,
-        )
-        .returning(Some(102))
-        .expect_metric_increment("active_http_calls", 1)
         .expect_set_tick_period_millis(Some(0))
         .execute_and_expect(ReturnType::None)
         .unwrap();
@@ -335,31 +319,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
         .execute_and_expect(ReturnType::None)
         .unwrap();

-    module
-        .call_proxy_on_http_call_response(
-            filter_context,
-            102,
-            0,
-            embedding_response_str.len() as i32,
-            0,
-        )
-        .expect_log(
-            Some(LogLevel::Debug),
-            Some(
-                format!(
-                    "filter_context: on_http_call_response called with token_id: {:?}",
-                    102
-                )
-                .as_str(),
-            ),
-        )
-        .expect_metric_increment("active_http_calls", -1)
-        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
-        .returning(Some(&embedding_response_str))
-        .expect_log(Some(LogLevel::Debug), None)
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-
     filter_context
 }

@@ -599,6 +558,7 @@ fn request_ratelimited() {
             },
         }],
         model: String::from("test"),
+        metadata: None,
     };

     let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
@@ -712,6 +672,7 @@ fn request_not_ratelimited() {
             },
         }],
         model: String::from("test"),
+        metadata: None,
     };

     let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
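The `metadata: None` fields added above pair with a `#[serde(skip_serializing_if = "Option::is_none")]` attribute on `ChatCompletionsRequest` (see the `public_types` diff at the end of this patch), so request bodies that carry no state serialize exactly as before. A minimal Python sketch of that omit-when-absent behavior, with field names borrowed from the Rust struct:

```python
import json
from typing import Optional

# Mirrors skip_serializing_if = "Option::is_none": the metadata key is
# dropped from the JSON body entirely when no state is attached.
def serialize_request(model: str, messages: list, stream: bool,
                      metadata: Optional[dict] = None) -> str:
    body = {"model": model, "messages": messages, "stream": stream}
    if metadata is not None:  # only emit the key when state is present
        body["metadata"] = metadata
    return json.dumps(body)

print(serialize_request("gpt-4", [{"role": "user", "content": "hi"}], False))
print(serialize_request("gpt-4", [], False, {"x-arch-state": "[]"}))
```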
diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py
index 02b89d3c..75f5f295 100644
--- a/chatbot_ui/app/run.py
+++ b/chatbot_ui/app/run.py
@@ -1,5 +1,6 @@
+import json
 import os
-from openai import OpenAI
+from openai import OpenAI, DefaultHttpxClient
 import gradio as gr
 import logging as log
 from dotenv import load_dotenv
@@ -13,11 +14,13 @@ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")

 log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)

-client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT)
+client = OpenAI(api_key=OPENAI_API_KEY, base_url=CHAT_COMPLETION_ENDPOINT, http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}))

-def predict(message, history):
+def predict(message, state):
+    if 'history' not in state:
+        state['history'] = []
+    history = state.get("history")
     history.append({"role": "user", "content": message})
-    log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
     log.info("history: ", history)

     # Custom headers
@@ -27,34 +30,42 @@ def predict(message, history):
         'x-arch-deterministic-provider': 'openai',
     }

+    metadata = None
+    if 'arch_state' in state:
+        metadata = {"x-arch-state": state['arch_state']}
+
     try:
-        response = client.chat.completions.create(model=MODEL_NAME,
-                        messages= history,
+        raw_response = client.chat.completions.with_raw_response.create(model=MODEL_NAME,
+                        messages = history,
                         temperature=1.0,
+                        metadata=metadata,
                         extra_headers=custom_headers
         )
     except Exception as e:
         log.info(e)
         # remove last user message in case of exception
         history.pop()
-        log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
         log.info("Error calling gateway API: {}".format(e.message))
         raise gr.Error("Error calling gateway API: {}".format(e.message))

-    choices = response.choices
-    message = choices[0].message
-    content = message.content
-    history.append({"role": "assistant", "content": content})
-    history[-1]["model"] = response.model
+    response = raw_response.parse()
+    # extract arch_state from metadata and store it in gradio session state
+    # this state must be passed back to the gateway in the next request
+    arch_state = json.loads(raw_response.text).get('metadata', {}).get('x-arch-state', None)
+    if arch_state:
+        state['arch_state'] = arch_state
+
+    content = response.choices[0].message.content
+
+    history.append({"role": "assistant", "content": content, "model": response.model})

     messages = [(history[i]["content"], history[i+1]["content"]) for i in range(0, len(history)-1, 2)]
-    return messages, history
-
+    return messages, state

 with gr.Blocks(fill_height=True, css="footer {visibility: hidden}") as demo:
     print("Starting Demo...")
     chatbot = gr.Chatbot(label="Arch Chatbot", scale=1)
-    state = gr.State([])
+    state = gr.State({})

     with gr.Row():
         txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter", scale=1, autofocus=True)
diff --git a/chatbot_ui/requirements.txt b/chatbot_ui/requirements.txt
index 26131d36..60a107fe 100644
--- a/chatbot_ui/requirements.txt
+++ b/chatbot_ui/requirements.txt
@@ -5,4 +5,4 @@ asyncio==3.4.3
 httpx==0.27.0
 python-dotenv==1.0.1
 pydantic==2.8.2
-openai==1.46.1
+openai==1.51.0
diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml
index c84d6b08..056fdc17 100644
--- a/demos/function_calling/arch_config.yaml
+++ b/demos/function_calling/arch_config.yaml
@@ -21,10 +21,6 @@ llm_providers:
     provider: openai
     model: gpt-4
     default: true
-  - name: mistral-large-latest
-    access_key: MISTRAL_API_KEY
-    provider: mistral
-    model: mistral-large-latest

 system_prompt: |
   You are a helpful assistant.
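For context on the `run.py` changes above: the client now round-trips conversation state through the `metadata` field, storing whatever `x-arch-state` value the gateway returns and handing it back on the next turn (which is also why `openai` is bumped to 1.51.0). A stripped-down sketch of that loop without the Gradio plumbing; the base URL here is an assumed placeholder, not taken from the patch:

```python
import json
from openai import OpenAI, DefaultHttpxClient

client = OpenAI(
    api_key="n/a",
    base_url="http://localhost:10000/v1",  # assumed gateway address
    http_client=DefaultHttpxClient(headers={"accept-encoding": "*"}),
)

state: dict = {}

def ask(prompt: str, history: list) -> str:
    history.append({"role": "user", "content": prompt})
    # return previously captured state to the gateway, if any
    metadata = {"x-arch-state": state["arch_state"]} if "arch_state" in state else None
    raw = client.chat.completions.with_raw_response.create(
        model="gpt-3.5-turbo",
        messages=history,
        metadata=metadata,
    )
    # the gateway echoes conversation state in response metadata;
    # stash it so the next request can send it back
    arch_state = json.loads(raw.text).get("metadata", {}).get("x-arch-state")
    if arch_state:
        state["arch_state"] = arch_state
    return raw.parse().choices[0].message.content
```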
diff --git a/model_server/app/__init__.py b/model_server/app/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model_server/app/arch_fc/__init__.py b/model_server/app/arch_fc/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/model_server/app/arch_fc/arch_fc.py b/model_server/app/arch_fc/arch_fc.py
index a0216294..7bea01ae 100644
--- a/model_server/app/arch_fc/arch_fc.py
+++ b/model_server/app/arch_fc/arch_fc.py
@@ -3,11 +3,12 @@ import random
 from fastapi import FastAPI, Response
 from app.arch_fc.arch_handler import ArchHandler
 from app.arch_fc.bolt_handler import BoltHandler
-from app.arch_fc.common import ChatMessage
+from app.arch_fc.common import ChatMessage, Message
 import logging
 import yaml
 from openai import OpenAI
 import os
+import hashlib

 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -51,14 +52,54 @@ logger.info(f"serving mode: {mode}")
 logger.info(f"using model: {chosen_model}")
 logger.info(f"using endpoint: {endpoint}")

+
+def process_state(arch_state, history: list[Message]):
+    print("state: {}".format(arch_state))
+    state_json = json.loads(arch_state)
+
+    state_map = {}
+    if state_json:
+        for tools_state in state_json:
+            for tool_state in tools_state:
+                state_map[tool_state['key']] = tool_state
+
+    print(f"state_map: {json.dumps(state_map)}")
+
+    sha_history = []
+    updated_history = []
+    for hist in history:
+        updated_history.append({"role": hist.role, "content": hist.content})
+        if hist.role == 'user':
+            sha_history.append(hist.content)
+            sha256_hash = hashlib.sha256()
+            sha256_hash.update(json.dumps(sha_history).encode())
+            sha_key = sha256_hash.hexdigest()
+            print(f"sha_key: {sha_key}")
+            if sha_key in state_map:
+                tool_call_state = state_map[sha_key]
+                if 'tool_call' in tool_call_state:
+                    tool_call_str = json.dumps(tool_call_state['tool_call'])
+                    updated_history.append({"role": "assistant", "content": f"<tool_call>\n{tool_call_str}\n</tool_call>"})
+                if 'tool_response' in tool_call_state:
+                    tool_resp = tool_call_state['tool_response']
+                    # TODO: try with role = user as well
+                    updated_history.append({"role": "user", "content": f"<tool_response>\n{tool_resp}\n</tool_response>"})
+                # we don't want to match this state with any other messages
+                del(state_map[sha_key])
+
+    return updated_history
+
+
 async def chat_completion(req: ChatMessage, res: Response):
     logger.info("starting request")
     tools_encoded = handler._format_system(req.tools)
     # append system prompt with tools to messages
     messages = [{"role": "system", "content": tools_encoded}]
-    for message in req.messages:
-        messages.append({"role": message.role, "content": message.content})
-    logger.info(f"request model: {chosen_model}, messages: {json.dumps(messages)}")
+    metadata = req.metadata
+    arch_state = metadata.get("x-arch-state", "[]")
+    updated_history = process_state(arch_state, req.messages)
+    for message in updated_history:
+        messages.append({"role": message["role"], "content": message["content"]})
+
+    logger.info(f"model_server => arch_fc: {chosen_model}, messages: {json.dumps(messages)}")
     completions_params = params["params"]
     resp = client.chat.completions.create(
         messages=messages,
@@ -80,6 +121,6 @@ async def chat_completion(req: ChatMessage, res: Response):
     if tools:
         resp.choices[0].message.tool_calls = tool_calls
         resp.choices[0].message.content = None
-    logger.info(f"response (tools): {json.dumps(tools)}")
-    logger.info(f"response: {json.dumps(resp.to_dict())}")
+    logger.info(f"model_server <= arch_fc: (tools): {json.dumps(tools)}")
+    logger.info(f"model_server <= arch_fc: response body: {json.dumps(resp.to_dict())}")
     return resp
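The `process_state` function above matches stored state entries to the conversation by hashing the running list of user messages: after each user turn it hashes the JSON-encoded list of all user messages seen so far, so an entry recorded after turn N only matches a history whose first N user messages are identical. A small sketch of that key derivation, mirroring the hashing code in the diff:

```python
import hashlib
import json

# process_state's lookup key: SHA-256 of the JSON-encoded list of all
# user messages seen so far in the conversation.
def state_key(user_messages: list) -> str:
    return hashlib.sha256(json.dumps(user_messages).encode()).hexdigest()

# If the fixture in test_arch_fc.py below was generated the same way, this
# should reproduce its "key" field for the one-message conversation.
print(state_key(["how is the weather in chicago?"]))
```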
diff --git a/model_server/app/arch_fc/common.py b/model_server/app/arch_fc/common.py
index c26e8422..e9d78ecb 100644
--- a/model_server/app/arch_fc/common.py
+++ b/model_server/app/arch_fc/common.py
@@ -10,3 +10,5 @@ class Message(BaseModel):
 class ChatMessage(BaseModel):
     messages: list[Message]
     tools: List[Dict[str, Any]]
+    # todo: make it default none
+    metadata: Dict[str, str] = {}
diff --git a/model_server/app/arch_fc/test_arch_fc.py b/model_server/app/arch_fc/test_arch_fc.py
new file mode 100644
index 00000000..0eb409d0
--- /dev/null
+++ b/model_server/app/arch_fc/test_arch_fc.py
@@ -0,0 +1,17 @@
+import json
+import pytest
+from app.arch_fc.arch_fc import process_state
+from app.arch_fc.common import ChatMessage, Message
+# test process_state
+
+arch_state = '[[{"key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424", "message": {"role": "user", "content": "how is the weather in chicago?"}, "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}}, "tool_response": "{\\"city\\":\\"Chicago\\",\\"temperature\\":[{\\"date\\":\\"2024-10-05\\",\\"temperature\\":{\\"min\\":51,\\"max\\":70},\\"query_time\\":\\"2024-10-05 08:18:00.264171+00:00\\"},{\\"date\\":\\"2024-10-06\\",\\"temperature\\":{\\"min\\":77,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264186+00:00\\"},{\\"date\\":\\"2024-10-07\\",\\"temperature\\":{\\"min\\":66,\\"max\\":84},\\"query_time\\":\\"2024-10-05 08:18:00.264190+00:00\\"},{\\"date\\":\\"2024-10-08\\",\\"temperature\\":{\\"min\\":77,\\"max\\":94},\\"query_time\\":\\"2024-10-05 08:18:00.264209+00:00\\"},{\\"date\\":\\"2024-10-09\\",\\"temperature\\":{\\"min\\":76,\\"max\\":92},\\"query_time\\":\\"2024-10-05 08:18:00.264518+00:00\\"},{\\"date\\":\\"2024-10-10\\",\\"temperature\\":{\\"min\\":56,\\"max\\":68},\\"query_time\\":\\"2024-10-05 08:18:00.264550+00:00\\"},{\\"date\\":\\"2024-10-11\\",\\"temperature\\":{\\"min\\":73,\\"max\\":88},\\"query_time\\":\\"2024-10-05 08:18:00.264559+00:00\\"}],\\"unit\\":\\"F\\"}"}]]'
+
+
+def test_process_state():
+    history = []
+    history.append(Message(role="user", content="how is the weather in chicago?"))
+    updated_history = process_state(arch_state, history)
+    print(json.dumps(updated_history, indent=2))
+
+if __name__ == "__main__":
+    pytest.main()
diff --git a/model_server/requirements.txt b/model_server/requirements.txt
index 79ec8e71..b0904be8 100644
--- a/model_server/requirements.txt
+++ b/model_server/requirements.txt
@@ -17,3 +17,4 @@ openai
 pandas
 tf-keras
 onnx
+pytest
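With the `metadata` field added to `ChatMessage` above, the model server can receive gateway state in the request body rather than a header. A hypothetical construction of such a request, assuming `Message` carries `role` and `content` as used in `test_arch_fc.py`, with the `"x-arch-state"` key matching what `arch_fc.py` reads:

```python
from app.arch_fc.common import ChatMessage, Message

# Sketch of the request body the model server now accepts; metadata
# defaults to {} when the gateway sends no state.
req = ChatMessage(
    messages=[Message(role="user", content="how is the weather in chicago?")],
    tools=[],
    metadata={"x-arch-state": "[]"},
)
print(req.metadata.get("x-arch-state", "[]"))
```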
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index 5b6bd794..4b338fc3 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -50,6 +50,8 @@ pub mod open_ai {
         pub stream: bool,
         #[serde(skip_serializing_if = "Option::is_none")]
         pub stream_options: Option<StreamOptions>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub metadata: Option<HashMap<String, String>>,
     }

     #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -209,11 +211,26 @@ pub mod open_ai {
         pub arguments: HashMap<String, String>,
     }

+    #[derive(Debug, Deserialize, Serialize)]
+    pub struct ToolCallState {
+        pub key: String,
+        pub message: Option<Message>,
+        pub tool_call: FunctionCallDetail,
+        pub tool_response: String,
+    }
+
+    #[derive(Debug, Deserialize, Serialize)]
+    #[serde(untagged)]
+    pub enum ArchState {
+        ToolCall(Vec<ToolCallState>),
+    }
+
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct ChatCompletionsResponse {
         pub usage: Usage,
         pub choices: Vec<Choice>,
         pub model: String,
+        pub metadata: Option<HashMap<String, String>>,
     }

     #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -360,6 +377,7 @@ mod test {
         stream_options: Some(super::open_ai::StreamOptions {
             include_usage: true,
         }),
+        metadata: None,
     };

     let serialized = serde_json::to_string_pretty(&chat_completions_request).unwrap();
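Taken together, `ToolCallState` above and the fixture in `test_arch_fc.py` imply the wire shape of `x-arch-state`: a JSON array of arrays of tool-call records, each carrying the hashed key, the originating user message, the tool call, and the raw tool response. A Python sketch of building and decoding one such payload (values trimmed from the fixture):

```python
import json

# One state entry per answered tool call, grouped per prompt target;
# field names mirror the ToolCallState struct.
state = [[{
    "key": "cafbda799879e1dce6cd3de3c3e8a40052a93addec457bda0b2f21f8c86b3424",
    "message": {"role": "user", "content": "how is the weather in chicago?"},
    "tool_call": {"name": "weather_forecast", "arguments": {"city": "Chicago"}},
    "tool_response": json.dumps({"city": "Chicago", "unit": "F"}),  # trimmed
}]]

# This is what the gateway would place in response metadata and what
# process_state() expects to receive back on the next turn.
encoded = json.dumps(state)
decoded = json.loads(encoded)
assert decoded[0][0]["tool_call"]["name"] == "weather_forecast"
```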