rename schema to schemas to avoid pydantic warning

This commit is contained in:
yzlin 2024-01-16 19:18:03 +08:00
parent c8858cd8d4
commit 9dc421b122
3 changed files with 9 additions and 9 deletions

View file

@ -110,7 +110,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
if TOOL_REGISTRY.has_tool(tool_name):
valid_tools.append(TOOL_REGISTRY.get_tool(tool_name))
tool_catalog = {tool.name: tool.schema for tool in valid_tools}
tool_catalog = {tool.name: tool.schemas for tool in valid_tools}
return tool_catalog
async def _tool_recommendation(
@ -158,7 +158,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
tool_catalog = {}
if available_tools:
available_tools = {tool_name: tool.schema["description"] for tool_name, tool in available_tools.items()}
available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
recommend_tools = await self._tool_recommendation(
plan.current_task.instruction, code_steps, available_tools
@ -199,7 +199,7 @@ class WriteCodeWithToolsML(WriteCodeWithTools):
code_context = "\n\n".join(code_context)
if available_tools:
available_tools = {tool_name: tool.schema["description"] for tool_name, tool in available_tools.items()}
available_tools = {tool_name: tool.schemas["description"] for tool_name, tool in available_tools.items()}
recommend_tools = await self._tool_recommendation(
plan.current_task.instruction, code_steps, available_tools

View file

@ -29,5 +29,5 @@ class ToolSchema(BaseModel):
class Tool(BaseModel):
name: str
path: str
schema: dict = {}
schemas: dict = {}
code: str = ""

View file

@ -25,7 +25,7 @@ class ToolRegistry:
def register_tool_type(self, tool_type: ToolType):
self.tool_types[tool_type.name] = tool_type
logger.info(f"{tool_type.name} registered")
logger.info(f"tool type {tool_type.name} registered")
def register_tool(
self,
@ -51,16 +51,16 @@ class ToolRegistry:
with open(schema_path, "r", encoding="utf-8") as f:
schema_dict = yaml.safe_load(f)
schema = schema_dict.get(tool_name) or dict(schema_dict.values())
schema["tool_path"] = tool_path # corresponding code file path of the tool
schemas = schema_dict.get(tool_name) or dict(schema_dict.values())
schemas["tool_path"] = tool_path # corresponding code file path of the tool
try:
ToolSchema(**schema) # validation
ToolSchema(**schemas) # validation
except Exception:
pass
# logger.warning(
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
# )
tool = Tool(name=tool_name, path=tool_path, schema=schema, code=tool_code)
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
self.tools[tool_name] = tool
self.tools_by_types[tool_type_name][tool_name] = tool
logger.info(f"{tool_name} registered")