spotify demo with optimized context window code change (#397)

Adil Hafeez authored on 2025-02-07 19:14:15 -08:00, committed by GitHub
parent b3c95a6698
commit 8de6eacfbd
11 changed files with 265 additions and 8 deletions


@@ -134,7 +134,7 @@ class ArchIntentHandler(ArchBaseHandler):
             req.messages, req.tools, self.extra_instruction
         )
-        logger.info(f"[request]: {json.dumps(messages)}")
+        logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")
         model_response = self.client.chat.completions.create(
             messages=messages,
@@ -519,9 +519,11 @@ class ArchFunctionHandler(ArchBaseHandler):
         """
         logger.info("[Arch-Function] - ChatCompletion")
-        messages = self._process_messages(req.messages, req.tools)
-        logger.info(f"[request]: {json.dumps(messages)}")
+        messages = self._process_messages(
+            req.messages, req.tools, metadata=req.metadata
+        )
+        logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
         # always enable `stream=True` to collect model responses
         response = self.client.chat.completions.create(
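With req.metadata forwarded into _process_messages, a client opts into the trimmed prompts on a per-request basis. A minimal sketch of such a request, assuming the handler sits behind an OpenAI-style chat-completions route (the URL, port, and message content here are hypothetical):

import requests

# Hypothetical endpoint; adjust to wherever the Arch model server listens.
payload = {
    "messages": [{"role": "user", "content": "queue up my discover weekly"}],
    "tools": [],
    # _process_messages compares this exact key against the string "true".
    "metadata": {"optimize_context_window": "true"},
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())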


@@ -105,7 +105,7 @@ class ArchGuardHanlder:
             raise NotImplementedError(f"{req.task} is not supported!")
         logger.info("[Arch-Guard] - Prediction")
-        logger.info(f"[request]: {req.input}")
+        logger.info(f"[request arch-guard]: {req.input}")
         if len(req.input.split()) < max_num_words:
             result = self._predict_text(req.task, req.input)


@@ -16,6 +16,7 @@ class Message(BaseModel):
 class ChatMessage(BaseModel):
     messages: List[Message] = []
     tools: List[Dict[str, Any]] = []
+    metadata: Optional[Dict[str, str]] = {}

 class Choice(BaseModel):
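The new field rides on the existing ChatMessage model. A self-contained sketch of how it validates, assuming Optional is already imported from typing in this module and that Message carries the usual role/content pair:

from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class Message(BaseModel):
    role: str
    content: str


class ChatMessage(BaseModel):
    messages: List[Message] = []
    tools: List[Dict[str, Any]] = []
    # Pydantic copies mutable defaults per instance, so the {} default is
    # not shared between requests the way a plain function default would be.
    metadata: Optional[Dict[str, str]] = {}


req = ChatMessage(
    messages=[Message(role="user", content="hi")],
    metadata={"optimize_context_window": "true"},
)
print(req.metadata["optimize_context_window"])  # true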
@@ -123,6 +124,7 @@ class ArchBaseHandler:
         tools: List[Dict[str, Any]] = None,
         extra_instruction: str = None,
         max_tokens=4096,
+        metadata: Dict[str, str] = {},
     ):
         """
         Processes a list of messages and formats them appropriately.
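One note on the signature as committed: metadata: Dict[str, str] = {} is a mutable default, which Python evaluates once at definition time. That is harmless here because _process_messages only reads from the dict, but the usual caveat applies if it were ever mutated; a short illustration:

def read_flag(metadata={}):  # the default dict is created once, at def time
    # Safe as long as the function only reads; mutating `metadata` here
    # would leak state into every later call that relies on the default.
    return metadata.get("optimize_context_window", "false").lower() == "true"


print(read_flag())                                     # False
print(read_flag({"optimize_context_window": "TRUE"}))  # True (case-insensitive)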
@@ -157,7 +159,12 @@
                 content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
             elif role == "tool":
                 role = "user"
-                content = f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
+                if metadata.get("optimize_context_window", "false").lower() == "true":
+                    content = f"<tool_response>\n\n</tool_response>"
+                else:
+                    content = (
+                        f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
+                    )
             processed_messages.append({"role": role, "content": content})
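This branch is the heart of the change: when the flag is set, prior tool results are replaced by an empty <tool_response> envelope instead of being serialized back into the prompt, so a bulky tool output (say, a full Spotify search result) stops consuming context on every follow-up turn. A standalone sketch of just this branch, with hypothetical sample data:

import json


def format_tool_response(content, metadata=None):
    # Mirrors the diff above: keep the envelope, optionally drop the body.
    metadata = metadata or {}
    if metadata.get("optimize_context_window", "false").lower() == "true":
        return "<tool_response>\n\n</tool_response>"
    return f"<tool_response>\n{json.dumps(content)}\n</tool_response>"


# Hypothetical tool output standing in for a large API response.
tracks = {"tracks": [{"name": f"song {i}"} for i in range(500)]}
print(len(format_tool_response(tracks)))                                       # thousands of chars
print(len(format_tool_response(tracks, {"optimize_context_window": "true"})))  # 33 chars, envelope only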