Merge pull request #1001 from garylin2099/di_fixes

minor updates
This commit is contained in:
garylin2099 2024-03-13 16:53:58 +08:00 committed by GitHub
commit fb175c0012
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 24 additions and 26 deletions

View file

@ -1,7 +1,7 @@
# Data Interpreter (DI)
## What is Data Interpreter
Data Interpreter is an agent who solves problems through codes. It understands user requirements, makes plans, writes codes for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios, please check out the examples below.
Data Interpreter is an agent that solves data-related problems through code. It understands user requirements, makes plans, writes code for execution, and uses tools if necessary. These capabilities enable it to tackle a wide range of scenarios; please check out the examples below. For the overall design and technical details, please see our [paper](https://arxiv.org/abs/2402.18679).
## Example List
- Data visualization
@ -12,7 +12,9 @@ ## Example List
- Tool usage: web page imitation
- Tool usage: web crawling
- Tool usage: text2image
- Tool usage: email summarization and response
- Tool usage: email summarization and response\
- More on the way!
Please see [here](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html) for detailed explanation.
Please see the [docs](https://docs.deepwisdom.ai/main/en/guide/use_cases/agent/interpreter/intro.html) for a more detailed explanation.
We are continuously releasing codes, stay tuned!

View file

@ -50,17 +50,15 @@ class WriteAnalysisCode(Action):
)
working_memory = working_memory or []
context = [Message(content=structual_prompt, role="user")] + working_memory
context = process_message(context)
context = process_message([Message(content=structual_prompt, role="user")] + working_memory)
# LLM call
if not use_reflection:
if use_reflection:
code = await self._debug_with_reflection(context=context, working_memory=working_memory)
else:
rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs)
code = CodeParser.parse_code(block=None, text=rsp)
else:
code = await self._debug_with_reflection(context=context, working_memory=working_memory)
return code

View file

@ -39,7 +39,7 @@ class DataInterpreter(Role):
use_plan: bool = True
use_reflection: bool = False
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
tools: Union[str, list[str]] = []
tools: Union[str, list[str]] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: ToolRecommender = None
react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
max_react_loop: int = 10 # used for react mode
@ -53,6 +53,7 @@ class DataInterpreter(Role):
if self.tools:
self.tool_recommender = BM25ToolRecommender(tools=self.tools)
self.set_actions([WriteAnalysisCode])
self._set_state(0)
return self
@property
@ -140,13 +141,13 @@ class DataInterpreter(Role):
async def _write_code(
self,
counter,
plan_status="",
tool_info="",
counter: int,
plan_status: str = "",
tool_info: str = "",
):
todo = WriteAnalysisCode()
todo = self.rc.todo # todo is WriteAnalysisCode
logger.info(f"ready to {todo.name}")
use_reflection = counter > 0 and self.use_reflection
use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial
user_requirement = self.get_memories()[0].content
@ -176,7 +177,6 @@ class DataInterpreter(Role):
code = await CheckData().run(self.planner.plan)
if not code.strip():
return
success = False
result, success = await self.execute_code.run(code)
if success:
print(result)

View file

@ -8,7 +8,7 @@ import numpy as np
from pydantic import BaseModel, field_validator
from rank_bm25 import BM25Okapi
from metagpt.actions import Action
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.schema import Plan
from metagpt.tools import TOOL_REGISTRY
@ -49,11 +49,6 @@ Recommend up to {topk} tools from 'Available Tools' that can help solve the 'Use
"""
class RecommendTool(Action):
async def run(self, prompt):
return await self._aask(prompt)
class ToolRecommender(BaseModel):
"""
The default ToolRecommender:
@ -67,6 +62,7 @@ class ToolRecommender(BaseModel):
@field_validator("tools", mode="before")
@classmethod
def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
# One can use special symbol ["<all>"] to indicate use of all registered tools
if v == ["<all>"]:
return TOOL_REGISTRY.get_all_tools()
else:
@ -136,7 +132,7 @@ class ToolRecommender(BaseModel):
available_tools=available_tools,
topk=topk,
)
rsp = await RecommendTool().run(prompt)
rsp = await LLM().aask(prompt)
rsp = CodeParser.parse_code(block=None, text=rsp)
ranked_tools = json.loads(rsp)
@ -160,9 +156,11 @@ class TypeMatchToolRecommender(ToolRecommender):
task_type = plan.current_task.task_type
candidate_tools = TOOL_REGISTRY.get_tools_by_tag(task_type)
candidate_tool_names = set(self.tools.keys()) & candidate_tools.keys()
recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names]
recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names][:topk]
return recalled_tools[:topk]
logger.info(f"Recalled tools: \n{[tool.name for tool in recalled_tools]}")
return recalled_tools
class BM25ToolRecommender(ToolRecommender):

View file

@ -22,6 +22,7 @@ async def test_interpreter(mocker, auto_run):
assert len(finished_tasks[0].code) > 0 # check one task to see if code is recorded
@pytest.mark.asyncio
async def test_interpreter_react_mode(mocker):
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
mocker.patch("builtins.input", return_value="confirm")

View file

@ -45,7 +45,6 @@ class DummyClass:
pass
# def dummy_fn(df: pd.DataFrame, s: str, k: int = 5, type: Literal["a", "b", "c"] = "a") -> dict:
def dummy_fn(
df: pd.DataFrame,
s: str,