update make tools.

This commit is contained in:
刘棒棒 2023-12-18 22:30:14 +08:00
parent ea7e11665d
commit 52052c8244

View file

@ -220,3 +220,65 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
rsp = await self.llm.aask_code(prompt, **tool_config)
context = [Message(content=prompt, role="user")]
return context, rsp["code"]
class MakeTools(WriteCodeByGenerate):
DEFAULT_SYSTEM_MSG = """Please Create a very General Function Code startswith `def` from any codes you got.\n
**Notice:
1. Your code must contain a general function start with `def`.
2. Refactor your code to get the most efficient implementation for large input data in the shortest amount of time.
3. Use Google style for function annotations.
4. Write example code after `if __name__ == '__main__':`by using old varibales in old code,
and make sure it could be execute in the user's machine.
5. Dont have missing package references.**
"""
def __init__(self, name: str = '', context: list[Message] = None, llm: LLM = None, workspace: str = None):
"""
:param str name: name, defaults to ''
:param list[Message] context: context, defaults to None
:param LLM llm: llm, defaults to None
:param str workspace: tools code saved file path dir, defaults to None
"""
super().__init__(name, context, llm)
self.workspace = workspace or str(Path(__file__).parents[1].joinpath("./tools/functions/libs/udf"))
self.file_suffix: str = '.py'
def parse_function_name(self, function_code: str) -> str:
# 定义正则表达式模式
pattern = r'\bdef\s+([a-zA-Z_]\w*)\s*\('
# 在代码中搜索匹配的模式
match = re.search(pattern, function_code)
# 如果找到匹配项则返回匹配的函数名否则返回None
if match:
return match.group(1)
else:
return None
def save(self, tool_code: str) -> None:
func_name = self.parse_function_name(tool_code)
if func_name is None:
raise ValueError(f"No function name found in {tool_code}")
saved_path = Path(self.workspace).joinpath(func_name+self.file_suffix)
logger.info(f"Saved tool_code {func_name} in {str(saved_path)}.")
saved_path.write_text(tool_code, encoding='utf-8')
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
async def run(self, code_message: List[Message | Dict], **kwargs) -> str:
msgs = self.process_msg(code_message)
logger.info(f"\n\nAsk to Make tools:\n{'-'*60}\n {msgs[-1]}")
tool_code = await self.llm.aask_code(msgs, **kwargs)
max_tries, current_try = 3, 1
func_name = self.parse_function_name(tool_code['code'])
while current_try < max_tries and func_name is None:
logger.info(f"\n\nTools Respond\n{'-'*60}\n: {tool_code}")
logger.warning(f"No function name found in code, we will retry make tools. \n\n{tool_code['code']}\n")
msgs.append({'role': 'assistant', 'content': 'We need a general function in above code,but not found function.'})
tool_code = await self.llm.aask_code(msgs, **kwargs)
current_try += 1
func_name = self.parse_function_name(tool_code['code'])
if func_name is not None:
break
logger.info(f"\n\nTools Respond\n{'-'*60}\n: {tool_code}")
self.save(tool_code['code'])
return tool_code["code"]