mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 05:42:37 +02:00
disentangle planner and tool module, optimize tool module, add react mode
This commit is contained in:
parent
0a2273c7a0
commit
0116de01b9
20 changed files with 554 additions and 354 deletions
|
|
@ -16,9 +16,8 @@ from sklearn.preprocessing import (
|
|||
)
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
TOOL_TYPE = ToolType.DATA_PREPROCESS.type_name
|
||||
TAGS = ["data preprocessing", "machine learning"]
|
||||
|
||||
|
||||
class MLProcess:
|
||||
|
|
@ -85,7 +84,7 @@ class DataPreprocessTool(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class FillMissingValue(DataPreprocessTool):
|
||||
"""
|
||||
Completing missing values with simple strategies.
|
||||
|
|
@ -106,7 +105,7 @@ class FillMissingValue(DataPreprocessTool):
|
|||
self.model = SimpleImputer(strategy=strategy, fill_value=fill_value)
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class MinMaxScale(DataPreprocessTool):
|
||||
"""
|
||||
Transform features by scaling each feature to a range, which is (0, 1).
|
||||
|
|
@ -117,7 +116,7 @@ class MinMaxScale(DataPreprocessTool):
|
|||
self.model = MinMaxScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class StandardScale(DataPreprocessTool):
|
||||
"""
|
||||
Standardize features by removing the mean and scaling to unit variance.
|
||||
|
|
@ -128,7 +127,7 @@ class StandardScale(DataPreprocessTool):
|
|||
self.model = StandardScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class MaxAbsScale(DataPreprocessTool):
|
||||
"""
|
||||
Scale each feature by its maximum absolute value.
|
||||
|
|
@ -139,7 +138,7 @@ class MaxAbsScale(DataPreprocessTool):
|
|||
self.model = MaxAbsScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class RobustScale(DataPreprocessTool):
|
||||
"""
|
||||
Apply the RobustScaler to scale features using statistics that are robust to outliers.
|
||||
|
|
@ -150,7 +149,7 @@ class RobustScale(DataPreprocessTool):
|
|||
self.model = RobustScaler()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class OrdinalEncode(DataPreprocessTool):
|
||||
"""
|
||||
Encode categorical features as ordinal integers.
|
||||
|
|
@ -161,7 +160,7 @@ class OrdinalEncode(DataPreprocessTool):
|
|||
self.model = OrdinalEncoder()
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class OneHotEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply one-hot encoding to specified categorical columns, the original columns will be dropped.
|
||||
|
|
@ -180,7 +179,7 @@ class OneHotEncode(DataPreprocessTool):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class LabelEncode(DataPreprocessTool):
|
||||
"""
|
||||
Apply label encoding to specified categorical columns in-place.
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
from imap_tools import MailBox
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
# Define a dictionary mapping email domains to their IMAP server addresses
|
||||
IMAP_SERVERS = {
|
||||
|
|
@ -24,7 +23,7 @@ IMAP_SERVERS = {
|
|||
}
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolType.EMAIL_LOGIN.type_name)
|
||||
@register_tool()
|
||||
def email_login_imap(email_address, email_password):
|
||||
"""
|
||||
Use imap_tools package to log in to your email (the email that supports IMAP protocol) to verify and return the account object.
|
||||
|
|
|
|||
|
|
@ -19,12 +19,11 @@ from sklearn.preprocessing import KBinsDiscretizer, PolynomialFeatures
|
|||
|
||||
from metagpt.tools.libs.data_preprocess import MLProcess
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
TOOL_TYPE = ToolType.FEATURE_ENGINEERING.type_name
|
||||
TAGS = ["feature engineering", "machine learning"]
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class PolynomialExpansion(MLProcess):
|
||||
"""
|
||||
Add polynomial and interaction features from selected numeric columns to input DataFrame.
|
||||
|
|
@ -67,7 +66,7 @@ class PolynomialExpansion(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class CatCount(MLProcess):
|
||||
"""
|
||||
Add value counts of a categorical column as new feature.
|
||||
|
|
@ -92,7 +91,7 @@ class CatCount(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class TargetMeanEncoder(MLProcess):
|
||||
"""
|
||||
Encode a categorical column by the mean of the label column, and adds the result as a new feature.
|
||||
|
|
@ -119,7 +118,7 @@ class TargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class KFoldTargetMeanEncoder(MLProcess):
|
||||
"""
|
||||
Add a new feature to the DataFrame by k-fold mean encoding of a categorical column using the label column.
|
||||
|
|
@ -159,7 +158,7 @@ class KFoldTargetMeanEncoder(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class CatCross(MLProcess):
|
||||
"""
|
||||
Add pairwise crossed features and convert them to numerical features.
|
||||
|
|
@ -216,7 +215,7 @@ class CatCross(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class GroupStat(MLProcess):
|
||||
"""
|
||||
Aggregate specified column in a DataFrame grouped by another column, adding new features named '<agg_col>_<agg_func>_by_<group_col>'.
|
||||
|
|
@ -248,7 +247,7 @@ class GroupStat(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class SplitBins(MLProcess):
|
||||
"""
|
||||
Inplace binning of continuous data into intervals, returning integer-encoded bin identifiers directly.
|
||||
|
|
@ -276,7 +275,7 @@ class SplitBins(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
# @register_tool(tool_type=TOOL_TYPE)
|
||||
# @register_tool(tags=TAGS)
|
||||
class ExtractTimeComps(MLProcess):
|
||||
"""
|
||||
Extract time components from a datetime column and add them as new features.
|
||||
|
|
@ -316,7 +315,7 @@ class ExtractTimeComps(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class GeneralSelection(MLProcess):
|
||||
"""
|
||||
Drop all nan feats and feats with only one unique value.
|
||||
|
|
@ -349,7 +348,7 @@ class GeneralSelection(MLProcess):
|
|||
|
||||
|
||||
# skip for now because lgb is needed
|
||||
# @register_tool(tool_type=TOOL_TYPE)
|
||||
# @register_tool(tags=TAGS)
|
||||
class TreeBasedSelection(MLProcess):
|
||||
"""
|
||||
Select features based on tree-based model and remove features with low importance.
|
||||
|
|
@ -403,7 +402,7 @@ class TreeBasedSelection(MLProcess):
|
|||
return new_df
|
||||
|
||||
|
||||
@register_tool(tool_type=TOOL_TYPE)
|
||||
@register_tool(tags=TAGS)
|
||||
class VarianceBasedSelection(MLProcess):
|
||||
"""
|
||||
Select features based on variance and remove features with low variance.
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ from pathlib import Path
|
|||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.utils.common import encode_image
|
||||
|
||||
ANALYZE_LAYOUT_PROMPT = """You are now a UI/UX designer, please generate layout information for this image:
|
||||
|
|
@ -28,9 +27,7 @@ As the design pays tribute to large companies, sometimes it is normal for some c
|
|||
Now, please generate the corresponding webpage code including HTML, CSS and JavaScript:"""
|
||||
|
||||
|
||||
@register_tool(
|
||||
tool_type=ToolType.IMAGE2WEBPAGE.type_name, include_functions=["__init__", "generate_webpages", "save_webpages"]
|
||||
)
|
||||
@register_tool(include_functions=["__init__", "generate_webpages", "save_webpages"])
|
||||
class GPTvGenerator:
|
||||
"""Class for generating webpages at once.
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,6 @@ from PIL import Image, PngImagePlugin
|
|||
from metagpt.const import SD_OUTPUT_FILE_REPO, SOURCE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
|
|
@ -55,7 +54,7 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
|||
|
||||
|
||||
@register_tool(
|
||||
tool_type=ToolType.STABLE_DIFFUSION.type_name,
|
||||
tags=["text2image", "multimodal"],
|
||||
include_functions=["__init__", "simple_run_t2i", "run_t2i", "construct_payload", "save"],
|
||||
)
|
||||
class SDEngine:
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.tools.web_browser_engine_playwright import PlaywrightWrapper
|
||||
|
||||
|
||||
@register_tool(tool_type=ToolType.WEBSCRAPING.type_name)
|
||||
@register_tool(tags=["web scraping", "web"])
|
||||
async def scrape_web_playwright(url):
|
||||
"""
|
||||
Asynchronously Scrape and save the HTML structure and inner text content of a web page using Playwright.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import inspect
|
|||
from metagpt.utils.parse_docstring import GoogleDocstringParser, remove_spaces
|
||||
|
||||
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = []):
|
||||
def convert_code_to_tool_schema(obj, include: list[str] = None):
|
||||
docstring = inspect.getdoc(obj)
|
||||
assert docstring, "no docstring found for the objects, skip registering"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,6 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolTypeDef(BaseModel):
|
||||
name: str
|
||||
desc: str = ""
|
||||
usage_prompt: str = ""
|
||||
|
||||
|
||||
class ToolSchema(BaseModel):
|
||||
description: str
|
||||
|
||||
|
|
@ -16,3 +10,4 @@ class Tool(BaseModel):
|
|||
path: str
|
||||
schemas: dict = {}
|
||||
code: str = ""
|
||||
tags: list[str] = []
|
||||
|
|
|
|||
196
metagpt/tools/tool_recommend.py
Normal file
196
metagpt/tools/tool_recommend.py
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import jieba
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, field_validator
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Plan
|
||||
from metagpt.tools import TOOL_REGISTRY
|
||||
from metagpt.tools.tool_data_type import Tool
|
||||
from metagpt.tools.tool_registry import validate_tool_names
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
TOOL_INFO_PROMPT = """
|
||||
## Capabilities
|
||||
- You can utilize pre-defined tools in any code lines from 'Available Tools' in the form of Python class or function.
|
||||
- You can freely combine the use of any other public packages, like sklearn, numpy, pandas, etc..
|
||||
|
||||
## Available Tools:
|
||||
Each tool is described in JSON format. When you call a tool, import the tool from its path first.
|
||||
{tool_schemas}
|
||||
"""
|
||||
|
||||
|
||||
TOOL_RECOMMENDATION_PROMPT = """
|
||||
## User Requirement:
|
||||
{current_task}
|
||||
|
||||
## Task
|
||||
Recommend up to {topk} tools from 'Available Tools' that can help solve the 'User Requirement'.
|
||||
|
||||
## Available Tools:
|
||||
{available_tools}
|
||||
|
||||
## Tool Selection and Instructions:
|
||||
- Select tools most relevant to completing the 'User Requirement'.
|
||||
- If you believe that no tools are suitable, indicate with an empty list.
|
||||
- Only list the names of the tools, not the full schema of each tool.
|
||||
- Ensure selected tools are listed in 'Available Tools'.
|
||||
- Output a json list of tool names:
|
||||
```json
|
||||
["tool_name1", "tool_name2", ...]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class RecommendTool(Action):
|
||||
async def run(self, prompt):
|
||||
return await self._aask(prompt)
|
||||
|
||||
|
||||
class ToolRecommender(BaseModel):
|
||||
"""
|
||||
The default ToolRecommender:
|
||||
1. Recall: If plan exists, use exact match between task type and tool type to recall tools;
|
||||
If plan doesn't exist (e.g. we use ReAct), return all user-specified tools;
|
||||
2. Rank: Use LLM to select final candidates from recalled set.
|
||||
"""
|
||||
|
||||
tools: dict[str, Tool] = {}
|
||||
force: bool = False
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
|
||||
if v == ["<all>"]:
|
||||
return TOOL_REGISTRY.get_all_tools()
|
||||
else:
|
||||
return validate_tool_names(v)
|
||||
|
||||
async def recommend_tools(
|
||||
self, context: str = "", plan: Plan = None, recall_topk: int = 20, topk: int = 5
|
||||
) -> list[Tool]:
|
||||
"""
|
||||
Recommends a list of tools based on the given context and plan. The recommendation process includes two stages: recall from a large pool and rank the recalled tools to select the final set.
|
||||
|
||||
Args:
|
||||
context (str): The context for tool recommendation.
|
||||
plan (Plan): The plan for tool recommendation.
|
||||
recall_topk (int): The number of tools to recall in the initial step.
|
||||
topk (int): The number of tools to return after rank as final recommendations.
|
||||
|
||||
Returns:
|
||||
list[Tool]: A list of recommended tools.
|
||||
"""
|
||||
|
||||
if not self.tools:
|
||||
return []
|
||||
|
||||
if self.force or (not context and not plan):
|
||||
# directly use what users have specified as result for forced recommendation;
|
||||
# directly use the whole set if there is no useful information
|
||||
return list(self.tools.values())
|
||||
|
||||
recalled_tools = await self.recall_tools(context=context, plan=plan, topk=recall_topk)
|
||||
if not recalled_tools:
|
||||
return []
|
||||
|
||||
ranked_tools = await self.rank_tools(recalled_tools=recalled_tools, context=context, plan=plan, topk=topk)
|
||||
|
||||
logger.info(f"Recommended tools: \n{[tool.name for tool in ranked_tools]}")
|
||||
|
||||
return ranked_tools
|
||||
|
||||
async def get_recommended_tool_info(self, **kwargs) -> str:
|
||||
"""
|
||||
Wrap recommended tools with their info in a string, which can be used directly in a prompt.
|
||||
"""
|
||||
recommended_tools = await self.recommend_tools(**kwargs)
|
||||
if not recommended_tools:
|
||||
return ""
|
||||
tool_schemas = {tool.name: tool.schemas for tool in recommended_tools}
|
||||
return TOOL_INFO_PROMPT.format(tool_schemas=tool_schemas)
|
||||
|
||||
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
|
||||
"""
|
||||
Retrieves a list of relevant tools from a large pool, based on the given context and plan.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def rank_tools(
|
||||
self, recalled_tools: list[Tool], context: str = "", plan: Plan = None, topk: int = 5
|
||||
) -> list[Tool]:
|
||||
"""
|
||||
Default rank methods for a ToolRecommender. Use LLM to rank the recalled tools based on the given context, plan, and topk value.
|
||||
"""
|
||||
current_task = plan.current_task.instruction if plan else context
|
||||
|
||||
available_tools = {tool.name: tool.schemas["description"] for tool in recalled_tools}
|
||||
prompt = TOOL_RECOMMENDATION_PROMPT.format(
|
||||
current_task=current_task,
|
||||
available_tools=available_tools,
|
||||
topk=topk,
|
||||
)
|
||||
rsp = await RecommendTool().run(prompt)
|
||||
rsp = CodeParser.parse_code(block=None, text=rsp)
|
||||
ranked_tools = json.loads(rsp)
|
||||
|
||||
valid_tools = validate_tool_names(ranked_tools)
|
||||
|
||||
return list(valid_tools.values())[:topk]
|
||||
|
||||
|
||||
class BM25ToolRecommender(ToolRecommender):
|
||||
"""
|
||||
A ToolRecommender using BM25 at the recall stage:
|
||||
1. Recall: Querying tool descriptions with task instruction if plan exists. Otherwise, return all user-specified tools;
|
||||
2. Rank: LLM rank, the same as the default ToolRecommender.
|
||||
"""
|
||||
|
||||
bm25: Any = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._init_corpus()
|
||||
|
||||
def _init_corpus(self):
|
||||
corpus = [f"{tool.name} {tool.tags}: {tool.schemas['description']}" for tool in self.tools.values()]
|
||||
tokenized_corpus = [self._tokenize(doc) for doc in corpus]
|
||||
self.bm25 = BM25Okapi(tokenized_corpus)
|
||||
|
||||
def _tokenize(self, text):
|
||||
return jieba.lcut(text) # FIXME: needs more sophisticated tokenization
|
||||
|
||||
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
|
||||
query = plan.current_task.instruction if plan else context
|
||||
|
||||
query_tokens = self._tokenize(query)
|
||||
doc_scores = self.bm25.get_scores(query_tokens)
|
||||
top_indexes = np.argsort(doc_scores)[::-1][:topk]
|
||||
recalled_tools = [list(self.tools.values())[index] for index in top_indexes]
|
||||
print([doc_scores[index] for index in top_indexes])
|
||||
print([recalled_tools[i].name for i in range(len(recalled_tools))])
|
||||
print([recalled_tools[i].schemas["description"] for i in range(len(recalled_tools))])
|
||||
|
||||
return recalled_tools
|
||||
|
||||
|
||||
class EmbeddingToolRecommender(ToolRecommender):
|
||||
"""
|
||||
NOTE: To be implemented.
|
||||
A ToolRecommender using embeddings at the recall stage:
|
||||
1. Recall: Use embeddings to calculate the similarity between query and tool info;
|
||||
2. Rank: LLM rank, the same as the default ToolRecommender.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def recall_tools(self, context: str = "", plan: Plan = None, topk: int = 20) -> list[Tool]:
|
||||
pass
|
||||
|
|
@ -10,26 +10,20 @@ from __future__ import annotations
|
|||
import inspect
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, field_validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import TOOL_SCHEMA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema, ToolTypeDef
|
||||
from metagpt.tools.tool_type import ToolType
|
||||
from metagpt.tools.tool_data_type import Tool, ToolSchema
|
||||
|
||||
|
||||
class ToolRegistry(BaseModel):
|
||||
tools: dict = {}
|
||||
tool_types: dict = {}
|
||||
tools_by_types: dict = defaultdict(dict) # two-layer k-v, {tool_type: {tool_name: {...}, ...}, ...}
|
||||
|
||||
@field_validator("tool_types", mode="before")
|
||||
@classmethod
|
||||
def init_tool_types(cls, tool_types: ToolType):
|
||||
return {tool_type.type_name: tool_type.value for tool_type in tool_types}
|
||||
tools_by_tags: dict = defaultdict(dict) # two-layer k-v, {tag: {tool_name: {...}, ...}, ...}
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
|
|
@ -37,25 +31,15 @@ class ToolRegistry(BaseModel):
|
|||
tool_path,
|
||||
schema_path="",
|
||||
tool_code="",
|
||||
tool_type="other",
|
||||
tags=None,
|
||||
tool_source_object=None,
|
||||
include_functions=[],
|
||||
include_functions=None,
|
||||
verbose=False,
|
||||
):
|
||||
if self.has_tool(tool_name):
|
||||
return
|
||||
|
||||
if tool_type not in self.tool_types:
|
||||
# register new tool type on the fly
|
||||
logger.warning(
|
||||
f"{tool_type} not previously defined, will create a temporary tool type with just a name. This tool type is only effective during this runtime. You may consider add this tool type with more configs permanently at metagpt.tools.tool_type"
|
||||
)
|
||||
temp_tool_type_obj = ToolTypeDef(name=tool_type)
|
||||
self.tool_types[tool_type] = temp_tool_type_obj
|
||||
if verbose:
|
||||
logger.info(f"tool type {tool_type} registered")
|
||||
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / tool_type / f"{tool_name}.yml"
|
||||
schema_path = schema_path or TOOL_SCHEMA_PATH / f"{tool_name}.yml"
|
||||
|
||||
schemas = make_schema(tool_source_object, include_functions, schema_path)
|
||||
|
||||
|
|
@ -70,10 +54,11 @@ class ToolRegistry(BaseModel):
|
|||
# logger.warning(
|
||||
# f"{tool_name} schema not conforms to required format, but will be used anyway. Mismatch: {e}"
|
||||
# )
|
||||
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code)
|
||||
tags = tags or []
|
||||
tool = Tool(name=tool_name, path=tool_path, schemas=schemas, code=tool_code, tags=tags)
|
||||
self.tools[tool_name] = tool
|
||||
self.tools_by_types[tool_type][tool_name] = tool
|
||||
for tag in tags:
|
||||
self.tools_by_tags[tag].update({tool_name: tool})
|
||||
if verbose:
|
||||
logger.info(f"{tool_name} registered")
|
||||
logger.info(f"schema made at {str(schema_path)}, can be used for checking")
|
||||
|
|
@ -84,24 +69,24 @@ class ToolRegistry(BaseModel):
|
|||
def get_tool(self, key) -> Tool:
|
||||
return self.tools.get(key)
|
||||
|
||||
def get_tools_by_type(self, key) -> dict[str, Tool]:
|
||||
return self.tools_by_types.get(key, {})
|
||||
def get_tools_by_tag(self, key) -> dict[str, Tool]:
|
||||
return self.tools_by_tags.get(key, {})
|
||||
|
||||
def has_tool_type(self, key) -> bool:
|
||||
return key in self.tool_types
|
||||
def get_all_tools(self) -> dict[str, Tool]:
|
||||
return self.tools
|
||||
|
||||
def get_tool_type(self, key) -> ToolType:
|
||||
return self.tool_types.get(key)
|
||||
def has_tool_tag(self, key) -> bool:
|
||||
return key in self.tools_by_tags
|
||||
|
||||
def get_tool_types(self) -> dict[str, ToolType]:
|
||||
return self.tool_types
|
||||
def get_tool_tags(self) -> list[str]:
|
||||
return list(self.tools_by_tags.keys())
|
||||
|
||||
|
||||
# Registry instance
|
||||
TOOL_REGISTRY = ToolRegistry(tool_types=ToolType)
|
||||
TOOL_REGISTRY = ToolRegistry()
|
||||
|
||||
|
||||
def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
|
||||
def register_tool(tags: list[str] = None, schema_path: str = "", **kwargs):
|
||||
"""register a tool to registry"""
|
||||
|
||||
def decorator(cls):
|
||||
|
|
@ -117,7 +102,7 @@ def register_tool(tool_type: str = "other", schema_path: str = "", **kwargs):
|
|||
tool_path=file_path,
|
||||
schema_path=schema_path,
|
||||
tool_code=source_code,
|
||||
tool_type=tool_type,
|
||||
tags=tags,
|
||||
tool_source_object=cls,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
@ -142,14 +127,15 @@ def make_schema(tool_source_object, include, path):
|
|||
return schema
|
||||
|
||||
|
||||
def validate_tool_names(tools: list[str], return_tool_object=False) -> list[str]:
|
||||
valid_tools = []
|
||||
for tool_name in tools:
|
||||
if not TOOL_REGISTRY.has_tool(tool_name):
|
||||
logger.warning(
|
||||
f"Specified tool {tool_name} not found and was skipped. Check if you have registered it properly"
|
||||
)
|
||||
def validate_tool_names(tools: Union[list[str], str]) -> str:
|
||||
assert isinstance(tools, list), "tools must be a list of str"
|
||||
valid_tools = {}
|
||||
for key in tools:
|
||||
# one can define either tool names or tool type names, take union to get the whole set
|
||||
if TOOL_REGISTRY.has_tool(key):
|
||||
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
|
||||
elif TOOL_REGISTRY.tool_tool_tag(key):
|
||||
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
|
||||
else:
|
||||
valid_tool = TOOL_REGISTRY.get_tool(tool_name) if return_tool_object else tool_name
|
||||
valid_tools.append(valid_tool)
|
||||
logger.warning(f"invalid tool name or tool type name: {key}, skipped")
|
||||
return valid_tools
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue