fix test_debug_code

This commit is contained in:
yzlin 2024-01-30 22:20:34 +08:00
parent ede04f20f6
commit 274747e72f
3 changed files with 40 additions and 5 deletions

View file

@ -24,9 +24,10 @@ class ToolRegistry(BaseModel):
tool_types: dict = {}
tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
def register_tool_type(self, tool_type: ToolType):
def register_tool_type(self, tool_type: ToolType, verbose: bool = False):
self.tool_types[tool_type.name] = tool_type
logger.info(f"tool type {tool_type.name} registered")
if verbose:
logger.info(f"tool type {tool_type.name} registered")
def register_tool(
self,
@ -38,6 +39,7 @@ class ToolRegistry(BaseModel):
tool_source_object=None,
include_functions=[],
make_schema_if_not_exists=True,
verbose=False,
):
if self.has_tool(tool_name):
return
@ -68,7 +70,8 @@ class ToolRegistry(BaseModel):
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
self.tools[tool_name] = tool
self.tools_by_types[tool_type][tool_name] = tool
logger.info(f"{tool_name} registered")
if verbose:
logger.info(f"{tool_name} registered")
def has_tool(self, key: str) -> Tool:
return key in self.tools

File diff suppressed because one or more lines are too long

View file

@ -48,7 +48,7 @@ def sort_array(arr):
async def test_debug_code():
debug_context = Message(content=DebugContext)
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
assert "def sort_array(arr)" in new_code
assert "def sort_array(arr)" in new_code["code"]
def test_messages_to_str():