Refine model_server

This commit is contained in:
Shuguang Chen 2024-12-05 15:19:41 -08:00
parent a5bd005411
commit 4fcfd83639
6 changed files with 149 additions and 64 deletions

View file

@ -38,7 +38,7 @@ class ArchBaseHandler:
client: OpenAI,
model_name: str,
task_prompt: str,
tool_prompt: str,
tool_prompt_template: str,
format_prompt: str,
generation_params: Dict,
):
@ -59,7 +59,7 @@ class ArchBaseHandler:
self.model_name = model_name
self.task_prompt = task_prompt
self.tool_prompt = tool_prompt
self.tool_prompt_template = tool_prompt_template
self.format_prompt = format_prompt
self.generation_params = generation_params
@ -78,7 +78,7 @@ class ArchBaseHandler:
raise NotImplementedError()
@final
def _format_system(self, tools: List[Dict[str, Any]]) -> str:
def _format_system_prompt(self, tools: List[Dict[str, Any]]) -> str:
"""
Formats the system prompt using provided tools.
@ -94,7 +94,7 @@ class ArchBaseHandler:
system_prompt = (
self.task_prompt
+ "\n\n"
+ self.tool_prompt.format(tool_text=tool_text)
+ self.tool_prompt_template.format(tool_text=tool_text)
+ "\n\n"
+ self.format_prompt
)