mirror of
https://github.com/katanemo/plano.git
synced 2026-05-11 08:42:48 +02:00
spotify demo with optimized context window code change (#397)
This commit is contained in:
parent
b3c95a6698
commit
8de6eacfbd
11 changed files with 265 additions and 8 deletions
|
|
@ -134,7 +134,7 @@ class ArchIntentHandler(ArchBaseHandler):
|
|||
req.messages, req.tools, self.extra_instruction
|
||||
)
|
||||
|
||||
logger.info(f"[request]: {json.dumps(messages)}")
|
||||
logger.info(f"[request to arch-fc (intent)]: {json.dumps(messages)}")
|
||||
|
||||
model_response = self.client.chat.completions.create(
|
||||
messages=messages,
|
||||
|
|
@ -519,9 +519,11 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
"""
|
||||
logger.info("[Arch-Function] - ChatCompletion")
|
||||
|
||||
messages = self._process_messages(req.messages, req.tools)
|
||||
messages = self._process_messages(
|
||||
req.messages, req.tools, metadata=req.metadata
|
||||
)
|
||||
|
||||
logger.info(f"[request]: {json.dumps(messages)}")
|
||||
logger.info(f"[request to arch-fc]: {json.dumps(messages)}")
|
||||
|
||||
# always enable `stream=True` to collect model responses
|
||||
response = self.client.chat.completions.create(
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ class ArchGuardHanlder:
|
|||
raise NotImplementedError(f"{req.task} is not supported!")
|
||||
|
||||
logger.info("[Arch-Guard] - Prediction")
|
||||
logger.info(f"[request]: {req.input}")
|
||||
logger.info(f"[request arch-guard]: {req.input}")
|
||||
|
||||
if len(req.input.split()) < max_num_words:
|
||||
result = self._predict_text(req.task, req.input)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class Message(BaseModel):
|
|||
class ChatMessage(BaseModel):
|
||||
messages: List[Message] = []
|
||||
tools: List[Dict[str, Any]] = []
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
|
||||
|
||||
class Choice(BaseModel):
|
||||
|
|
@ -123,6 +124,7 @@ class ArchBaseHandler:
|
|||
tools: List[Dict[str, Any]] = None,
|
||||
extra_instruction: str = None,
|
||||
max_tokens=4096,
|
||||
metadata: Dict[str, str] = {},
|
||||
):
|
||||
"""
|
||||
Processes a list of messages and formats them appropriately.
|
||||
|
|
@ -157,7 +159,12 @@ class ArchBaseHandler:
|
|||
content = f"<tool_call>\n{json.dumps(tool_calls[0]['function'])}\n</tool_call>"
|
||||
elif role == "tool":
|
||||
role = "user"
|
||||
content = f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
|
||||
if metadata.get("optimize_context_window", "false").lower() == "true":
|
||||
content = f"<tool_response>\n\n</tool_response>"
|
||||
else:
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
|
||||
)
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue