add more changes

This commit is contained in:
Adil Hafeez 2025-02-07 19:01:42 -08:00
parent cd90faf50c
commit 66a971b086
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
8 changed files with 35 additions and 9 deletions

View file

@ -79,6 +79,8 @@ properties:
properties:
prompt_target_intent_matching_threshold:
type: number
optimize_context_window:
type: boolean
system_prompt:
type: string
prompt_targets:

View file

@ -25,6 +25,7 @@ pub struct Configuration {
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct Overrides {
pub prompt_target_intent_matching_threshold: Option<f64>,
pub optimize_context_window: Option<bool>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]

View file

@ -137,9 +137,20 @@ impl HttpContext for StreamContext {
.map(|(_, pt)| pt.into())
.collect();
let mut metadata = deserialized_body.metadata.clone();
if let Some(overrides) = self.overrides.as_ref() {
if overrides.optimize_context_window.unwrap_or_default() {
if metadata.is_none() {
metadata = Some(HashMap::new());
}
metadata.as_mut().unwrap().insert("optimize_context_window".to_string(), "true".to_string());
}
}
let arch_fc_chat_completion_request = ChatCompletionsRequest {
messages: deserialized_body.messages.clone(),
metadata: deserialized_body.metadata.clone(),
metadata,
stream: deserialized_body.stream,
model: "--".to_string(),
stream_options: deserialized_body.stream_options.clone(),
@ -157,7 +168,7 @@ impl HttpContext for StreamContext {
};
debug!("sending request to model server");
trace!("request body: {}", json_data);
debug!("request body: {}", json_data);
let mut headers = vec![
(ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME),

View file

@ -46,7 +46,7 @@ pub struct StreamCallContext {
pub struct StreamContext {
system_prompt: Rc<Option<String>>,
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
_overrides: Rc<Option<Overrides>>,
pub overrides: Rc<Option<Overrides>>,
pub metrics: Rc<Metrics>,
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
pub context_id: u32,
@ -89,7 +89,7 @@ impl StreamContext {
streaming_response: false,
user_prompt: None,
is_chat_completions_request: false,
_overrides: overrides,
overrides: overrides,
request_id: None,
traceparent: None,
_tracing: tracing,

View file

@ -4,6 +4,9 @@ listener:
port: 8080 #If you configure port 443, you'll need to update the listener with tls_certificates
message_format: huggingface
overrides:
optimize_context_window: true
endpoints:
spotify:
endpoint: api.spotify.com

View file

@ -134,7 +134,7 @@ class ArchIntentHandler(ArchBaseHandler):
req.messages, req.tools, self.extra_instruction
)
logger.info(f"[request]: {json.dumps(messages)}")
logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")
model_response = self.client.chat.completions.create(
messages=messages,
@ -519,9 +519,11 @@ class ArchFunctionHandler(ArchBaseHandler):
"""
logger.info("[Arch-Function] - ChatCompletion")
messages = self._process_messages(req.messages, req.tools)
messages = self._process_messages(
req.messages, req.tools, metadata=req.metadata
)
logger.info(f"[request]: {json.dumps(messages)}")
logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
# always enable `stream=True` to collect model responses
response = self.client.chat.completions.create(

View file

@ -105,7 +105,7 @@ class ArchGuardHanlder:
raise NotImplementedError(f"{req.task} is not supported!")
logger.info("[Arch-Guard] - Prediction")
logger.info(f"[request]: {req.input}")
logger.info(f"[request arch-guard]: {req.input}")
if len(req.input.split()) < max_num_words:
result = self._predict_text(req.task, req.input)

View file

@ -16,6 +16,7 @@ class Message(BaseModel):
class ChatMessage(BaseModel):
messages: List[Message] = []
tools: List[Dict[str, Any]] = []
metadata: Optional[Dict[str, str]] = {}
class Choice(BaseModel):
@ -123,6 +124,7 @@ class ArchBaseHandler:
tools: List[Dict[str, Any]] = None,
extra_instruction: str = None,
max_tokens=4096,
metadata: Dict[str, str] = {},
):
"""
Processes a list of messages and formats them appropriately.
@ -157,7 +159,12 @@ class ArchBaseHandler:
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
elif role == "tool":
role = "user"
content = f"<tool_response>\n\n</tool_response>"
if metadata.get("optimize_context_window", "false").lower() == "true":
content = f"<tool_response>\n\n</tool_response>"
else:
content = (
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
)
processed_messages.append({"role": role, "content": content})