From 66a971b086f844d1d066635176887b330fed58ef Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 7 Feb 2025 19:01:42 -0800 Subject: [PATCH] add more changes --- arch/arch_config_schema.yaml | 2 ++ crates/common/src/configuration.rs | 1 + crates/prompt_gateway/src/http_context.rs | 15 +++++++++++++-- crates/prompt_gateway/src/stream_context.rs | 4 ++-- demos/spotify/arch_config.yaml | 3 +++ model_server/src/core/function_calling.py | 8 +++++--- model_server/src/core/guardrails.py | 2 +- model_server/src/core/utils/model_utils.py | 9 ++++++++- 8 files changed, 35 insertions(+), 9 deletions(-) diff --git a/arch/arch_config_schema.yaml b/arch/arch_config_schema.yaml index f2b2c8a5..1b32b730 100644 --- a/arch/arch_config_schema.yaml +++ b/arch/arch_config_schema.yaml @@ -79,6 +79,8 @@ properties: properties: prompt_target_intent_matching_threshold: type: number + optimize_context_window: + type: boolean system_prompt: type: string prompt_targets: diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index f1250499..069695ba 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -25,6 +25,7 @@ pub struct Configuration { #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct Overrides { pub prompt_target_intent_matching_threshold: Option, + pub optimize_context_window: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index e7d920f1..4b9c8015 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -137,9 +137,20 @@ impl HttpContext for StreamContext { .map(|(_, pt)| pt.into()) .collect(); + let mut metadata = deserialized_body.metadata.clone(); + + if let Some(overrides) = self.overrides.as_ref() { + if overrides.optimize_context_window.unwrap_or_default() { + if metadata.is_none() { + metadata = Some(HashMap::new()); + } + metadata.as_mut().unwrap().insert("optimize_context_window".to_string(), "true".to_string()); + } + } + let arch_fc_chat_completion_request = ChatCompletionsRequest { messages: deserialized_body.messages.clone(), - metadata: deserialized_body.metadata.clone(), + metadata, stream: deserialized_body.stream, model: "--".to_string(), stream_options: deserialized_body.stream_options.clone(), @@ -157,7 +168,7 @@ impl HttpContext for StreamContext { }; debug!("sending request to model server"); - trace!("request body: {}", json_data); + debug!("request body: {}", json_data); let mut headers = vec![ (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 2704e8d8..e6db7f59 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -46,7 +46,7 @@ pub struct StreamCallContext { pub struct StreamContext { system_prompt: Rc>, pub prompt_targets: Rc>, - _overrides: Rc>, + pub overrides: Rc>, pub metrics: Rc, pub callouts: RefCell>, pub context_id: u32, @@ -89,7 +89,7 @@ impl StreamContext { streaming_response: false, user_prompt: None, is_chat_completions_request: false, - _overrides: overrides, + overrides: overrides, request_id: None, traceparent: None, _tracing: tracing, diff --git a/demos/spotify/arch_config.yaml b/demos/spotify/arch_config.yaml index 75ade01f..dbed4c8c 100644 --- a/demos/spotify/arch_config.yaml +++ b/demos/spotify/arch_config.yaml @@ -4,6 +4,9 @@ listener: port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates message_format: huggingface +overrides: + optimize_context_window: true + endpoints: spotify: endpoint: api.spotify.com diff --git a/model_server/src/core/function_calling.py b/model_server/src/core/function_calling.py index 25c83818..99dd29ba 100644 --- a/model_server/src/core/function_calling.py +++ b/model_server/src/core/function_calling.py @@ -134,7 +134,7 @@ class ArchIntentHandler(ArchBaseHandler): req.messages, req.tools, self.extra_instruction ) - logger.info(f"[request]: {json.dumps(messages)}") + logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}") model_response = self.client.chat.completions.create( messages=messages, @@ -519,9 +519,11 @@ class ArchFunctionHandler(ArchBaseHandler): """ logger.info("[Arch-Function] - ChatCompletion") - messages = self._process_messages(req.messages, req.tools) + messages = self._process_messages( + req.messages, req.tools, metadata=req.metadata + ) - logger.info(f"[request]: {json.dumps(messages)}") + logger.info(f"[request to arch-fc]: {json.dumps(messages)}") # always enable `stream=True` to collect model responses response = self.client.chat.completions.create( diff --git a/model_server/src/core/guardrails.py b/model_server/src/core/guardrails.py index 0d2f34fc..fae4e5ba 100644 --- a/model_server/src/core/guardrails.py +++ b/model_server/src/core/guardrails.py @@ -105,7 +105,7 @@ class ArchGuardHanlder: raise NotImplementedError(f"{req.task} is not supported!") logger.info("[Arch-Guard] - Prediction") - logger.info(f"[request]: {req.input}") + logger.info(f"[request arch-guard]: {req.input}") if len(req.input.split()) < max_num_words: result = self._predict_text(req.task, req.input) diff --git a/model_server/src/core/utils/model_utils.py b/model_server/src/core/utils/model_utils.py index a0d56d50..7dc71acf 100644 --- a/model_server/src/core/utils/model_utils.py +++ b/model_server/src/core/utils/model_utils.py @@ -16,6 +16,7 @@ class Message(BaseModel): class ChatMessage(BaseModel): messages: List[Message] = [] tools: List[Dict[str, Any]] = [] + metadata: Optional[Dict[str, str]] = {} class Choice(BaseModel): @@ -123,6 +124,7 @@ class ArchBaseHandler: tools: List[Dict[str, Any]] = None, extra_instruction: str = None, max_tokens=4096, + metadata: Dict[str, str] = {}, ): """ Processes a list of messages and formats them appropriately. @@ -157,7 +159,12 @@ class ArchBaseHandler: content = f"\n{json.dumps(tool_calls[0]['function'])}\n" elif role == "tool": role = "user" - content = f"\n\n" + if metadata.get("optimize_context_window", "false").lower() == "true": + content = f"\n\n" + else: + content = ( + f"\n{json.dumps(content)}\n" + ) processed_messages.append({"role": role, "content": content})