minor updates

This commit is contained in:
yzlin 2024-03-13 16:50:19 +08:00
parent 93663784ab
commit 5fc711ae82
5 changed files with 19 additions and 23 deletions

View file

@ -50,17 +50,15 @@ class WriteAnalysisCode(Action):
)
working_memory = working_memory or []
context = [Message(content=structual_prompt, role="user")] + working_memory
context = process_message(context)
context = process_message([Message(content=structual_prompt, role="user")] + working_memory)
# LLM call
if not use_reflection:
if use_reflection:
code = await self._debug_with_reflection(context=context, working_memory=working_memory)
else:
rsp = await self.llm.aask(context, system_msgs=[INTERPRETER_SYSTEM_MSG], **kwargs)
code = CodeParser.parse_code(block=None, text=rsp)
else:
code = await self._debug_with_reflection(context=context, working_memory=working_memory)
return code

View file

@ -39,7 +39,7 @@ class DataInterpreter(Role):
use_plan: bool = True
use_reflection: bool = False
execute_code: ExecuteNbCode = Field(default_factory=ExecuteNbCode, exclude=True)
tools: Union[str, list[str]] = []
tools: Union[str, list[str]] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
tool_recommender: ToolRecommender = None
react_mode: Literal["plan_and_act", "react"] = "plan_and_act"
max_react_loop: int = 10 # used for react mode
@ -53,6 +53,7 @@ class DataInterpreter(Role):
if self.tools:
self.tool_recommender = BM25ToolRecommender(tools=self.tools)
self.set_actions([WriteAnalysisCode])
self._set_state(0)
return self
@property
@ -140,13 +141,13 @@ class DataInterpreter(Role):
async def _write_code(
self,
counter,
plan_status="",
tool_info="",
counter: int,
plan_status: str = "",
tool_info: str = "",
):
todo = WriteAnalysisCode()
todo = self.rc.todo # todo is WriteAnalysisCode
logger.info(f"ready to {todo.name}")
use_reflection = counter > 0 and self.use_reflection
use_reflection = counter > 0 and self.use_reflection # only use reflection after the first trial
user_requirement = self.get_memories()[0].content
@ -176,7 +177,6 @@ class DataInterpreter(Role):
code = await CheckData().run(self.planner.plan)
if not code.strip():
return
success = False
result, success = await self.execute_code.run(code)
if success:
print(result)

View file

@ -8,7 +8,7 @@ import numpy as np
from pydantic import BaseModel, field_validator
from rank_bm25 import BM25Okapi
from metagpt.actions import Action
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.schema import Plan
from metagpt.tools import TOOL_REGISTRY
@ -49,11 +49,6 @@ Recommend up to {topk} tools from 'Available Tools' that can help solve the 'Use
"""
class RecommendTool(Action):
async def run(self, prompt):
return await self._aask(prompt)
class ToolRecommender(BaseModel):
"""
The default ToolRecommender:
@ -67,6 +62,7 @@ class ToolRecommender(BaseModel):
@field_validator("tools", mode="before")
@classmethod
def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
# One can use special symbol ["<all>"] to indicate use of all registered tools
if v == ["<all>"]:
return TOOL_REGISTRY.get_all_tools()
else:
@ -136,7 +132,7 @@ class ToolRecommender(BaseModel):
available_tools=available_tools,
topk=topk,
)
rsp = await RecommendTool().run(prompt)
rsp = await LLM().aask(prompt)
rsp = CodeParser.parse_code(block=None, text=rsp)
ranked_tools = json.loads(rsp)
@ -160,9 +156,11 @@ class TypeMatchToolRecommender(ToolRecommender):
task_type = plan.current_task.task_type
candidate_tools = TOOL_REGISTRY.get_tools_by_tag(task_type)
candidate_tool_names = set(self.tools.keys()) & candidate_tools.keys()
recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names]
recalled_tools = [candidate_tools[tool_name] for tool_name in candidate_tool_names][:topk]
return recalled_tools[:topk]
logger.info(f"Recalled tools: \n{[tool.name for tool in recalled_tools]}")
return recalled_tools
class BM25ToolRecommender(ToolRecommender):

View file

@ -22,6 +22,7 @@ async def test_interpreter(mocker, auto_run):
assert len(finished_tasks[0].code) > 0 # check one task to see if code is recorded
@pytest.mark.asyncio
async def test_interpreter_react_mode(mocker):
mocker.patch("metagpt.actions.di.execute_nb_code.ExecuteNbCode.run", return_value=("a successful run", True))
mocker.patch("builtins.input", return_value="confirm")

View file

@ -45,7 +45,6 @@ class DummyClass:
pass
# def dummy_fn(df: pd.DataFrame, s: str, k: int = 5, type: Literal["a", "b", "c"] = "a") -> dict:
def dummy_fn(
df: pd.DataFrame,
s: str,