Merge branch 'main' into dev

This commit is contained in:
geekan 2023-12-28 17:27:10 +08:00
commit 6707e9f1b9
16 changed files with 858 additions and 10 deletions

View file

@ -54,8 +54,8 @@ # Step 2: Clone the repository to your local machine for latest version, and ins
# Step 3: setup your OPENAI_API_KEY, or make sure it existed in the env
mkdir ~/.metagpt
cp config/config.yaml ~/.metagpt/key.yaml
vim ~/.metagpt/key.yaml
cp config/config.yaml ~/.metagpt/config.yaml
vim ~/.metagpt/config.yaml
# Step 4: run metagpt cli
metagpt "Create a 2048 game in python"

View file

@ -86,7 +86,7 @@ class CollectLinks(Action):
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
rank_func: Union[Callable[[list[str]], None], None] = None
rank_func: Optional[Callable[[list[str]], None]] = None
async def run(
self,
@ -181,18 +181,18 @@ class WebBrowseAndSummarize(Action):
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: WebBrowserEngine = Field(
default_factory=lambda: WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None,
run_func=WebBrowseAndSummarize.browse_func,
)
)
web_browser_engine: Optional[WebBrowserEngine] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
if CONFIG.model_for_researcher_summary:
self.llm.model = CONFIG.model_for_researcher_summary
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
run_func=self.browse_func,
)
async def run(
self,
url: str,

View file

@ -6,6 +6,7 @@
"""
import asyncio
import re
from pydantic import BaseModel
@ -107,9 +108,11 @@ class Researcher(Role):
return msg
def write_report(self, topic: str, content: str):
filename = re.sub(r'[\\/:"*?<>|]+', " ", topic)
filename = filename.replace("\n", "")
if not RESEARCH_PATH.exists():
RESEARCH_PATH.mkdir(parents=True)
filepath = RESEARCH_PATH / f"{topic}.md"
filepath = RESEARCH_PATH / f"{filename}.md"
filepath.write_text(content)

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

108
metagpt/strategy/base.py Normal file
View file

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 9:16 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from typing import List
from anytree import Node, RenderTree
from pydantic import BaseModel
class BaseParser(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def propose(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def sample(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def value(self, input: str, **kwargs) -> str:
raise NotImplementedError
class BaseEvaluator(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def status_verify(self, *args, **kwargs):
raise NotImplementedError
class ThoughtNode(Node):
"""A node representing a thought in the thought tree."""
name: str = ""
value: int = 0
id: int = 0
valid_status: bool = True
def update_value(self, value) -> None:
"""Update the value of the thought node."""
self.value = value
def update_valid_status(self, status) -> None:
"""Update the validity status of the thought node."""
self.valid_status = status
class ThoughtTree(RenderTree):
"""A tree structure to represent thoughts."""
@property
def all_nodes(self) -> List[ThoughtNode]:
"""
Get a list of all nodes in the thought tree.
Returns:
List[ThoughtNode]: A list containing all nodes in the thought tree.
"""
all_nodes = [node for _, _, node in self]
return all_nodes
def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
"""
Update the tree with new thoughts.
Args:
thought (List[dict]): A list of dictionaries representing thought information.
current_node (ThoughtNode): The current node under which new thoughts will be added.
Returns:
List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
"""
nodes = []
for node_info in thought:
node = ThoughtNode(
name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
)
nodes.append(node)
return nodes
def parse_node_path(self, node) -> List[str]:
"""
Parse and retrieve the hierarchical path of the given thought node.
This method traverses the parent nodes of the provided 'node' and constructs
the full path from the root node to the given node.
Args:
node: The thought node for which the hierarchical path needs to be parsed.
Returns:
List[str]: A list representing the full hierarchical path of the given thought node.
The list is ordered from the root node to the provided node.
"""
full_node_path = []
while node is not None:
full_node_path.append(node.name)
node = node.parent
full_node_path.reverse()
return full_node_path
def show(self) -> None:
"""Print the updated tree."""
print("\nUpdated Tree:")
for pre, _, node in self:
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/26/2023 3:32 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:06 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from metagpt.strategy.prompt_templates.creative_writing import cot_prompt, vote_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
class TextGenParser(BaseParser):
propose_prompt: str = cot_prompt
value_prompt: str = vote_prompt
def __call__(self, input_text: str) -> str:
return input_text
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
# node_result = self(input)
id = kwargs.get("node_id", "0")
return self.value_prompt + f"Choice {id}:\n{input}\n"
class TextGenEvaluator(BaseEvaluator):
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
value = 0
node_id = kwargs.get("node_id", "0")
pattern = r".*best choice is .*(\d+).*"
match = re.match(pattern, evaluation, re.DOTALL)
if match:
vote = int(match.groups()[0])
print(vote)
if vote == int(node_id):
value = 1
except:
value = 0
return value
def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status
if __name__ == "__main__":
import asyncio
initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didnt like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""
parser = TextGenParser()
evaluator = TextGenEvaluator()
config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)
tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot_base.solve(init_prompt=initial_prompt))

View file

@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:36 AM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from metagpt.strategy.prompt_templates.game24 import propose_prompt, value_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
class Game24Parser(BaseParser):
propose_prompt: str = propose_prompt
value_prompt: str = value_prompt
def __call__(self, input_text: str) -> str:
last_line = input_text.strip().split("\n")[-1]
return last_line.split("left: ")[-1].split(")")[0]
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
node_result = self(input)
return self.value_prompt.format(input=node_result)
class Game24Evaluator(BaseEvaluator):
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
value = self.value_map[matches[0]]
except:
value = 0.001
return value
def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status
if __name__ == "__main__":
import asyncio
initial_prompt = """4 5 6 10"""
parser = Game24Parser()
evaluator = Game24Evaluator()
config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)
tot = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot.solve(init_prompt=initial_prompt))

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 5:21 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,25 @@
standard_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
"""
cot_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
Make a plan then write. Your output should be of the following format:
Plan:
Your plan here.
Passage:
Your passage here.
"""
vote_prompt = """Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
"""
compare_prompt = """Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
"""
score_prompt = """Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
"""

View file

@ -0,0 +1,139 @@
# 5-shot
standard_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
"""
# 5-shot
cot_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
"""
# 1-shot
propose_prompt = """Here is an Example for 1 input and 8 possible thoughts:
Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 / 2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Here is my task for 1 input and {n_generate_sample} possible thoughts:
Input: {input}
Possible next steps:
"""
value_prompt = """Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{input}
"""
value_last_step_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge:
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge:
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge:
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge:
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge:
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge:
impossible
Input: {input}
Answer: {answer}
Judge:"""

272
metagpt/strategy/tot.py Normal file
View file

@ -0,0 +1,272 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import asyncio
from typing import Any, List
from pydantic import BaseModel, Field
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.strategy.base import ThoughtNode, ThoughtTree
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
from metagpt.utils.common import CodeParser
OUTPUT_FORMAT = """
Output a list of jsons following the format:
```json
[
{
"node_id": str = "unique identifier for a solution, can be an ordinal",
"node_state_instruction": "specified sample of solution",
},
...
]
```
"""
class ThoughtSolverBase(BaseModel):
thought_tree: str = ""
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.llm.use_system_prompt = False
async def solve(self, init_prompt):
"""
Solve method for subclasses to implement.
"""
raise NotImplementedError("Subclasses must implement the solve method")
async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
"""
Generate children thoughts based on the current state.
Args:
current_state (str): The current state for which thoughts are generated.
current_node (ThoughtNode): The current node in the thought tree.
Returns:
List[ThoughtNode]: List of nodes representing the generated thoughts.
"""
state_prompt = self.config.parser.propose(
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
)
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
thoughts = CodeParser.parse_code(block=None, text=rsp)
thoughts = eval(thoughts)
# fixme 避免不跟随生成过多nodes
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
return self.thought_tree.update_node(thoughts, current_node=current_node)
async def evaluate_node(self, node, parent_value) -> None:
"""
Evaluate a node and update its status and value.
Args:
node (ThoughtNode): The node to be evaluated.
parent_value (float): The parent node's value.
Returns:
None
"""
eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
evaluation = await self.llm.aask(msg=eval_prompt)
value = self.config.evaluator(evaluation, **{"node_id": node.id})
status = self.config.evaluator.status_verify(value)
node.update_valid_status(status=status)
# 累计分数
node.update_value(parent_value + value)
def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
"""
Select nodes based on the configured selection method.
Args:
thought_nodes (List[ThoughtNode]): List of nodes to be selected.
Returns:
List[ThoughtNode]: List of selected nodes.
"""
# selection
if self.config.method_select == MethodSelect.SAMPLE:
raise NotImplementedError
elif self.config.method_select == MethodSelect.GREEDY:
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
for node in thought_nodes:
if node not in select_nodes:
node.parent = None # 从树中删除节点
return select_nodes
def update_solution(self):
"""
Select the result with the highest score.
Returns:
- List[ThoughtNode]: List of nodes representing the best solution.
- List[str]: List of node names forming the best solution path.
"""
best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
best_solution_path = self.thought_tree.parse_node_path(best_node)
return [best_node], best_solution_path
class BFSSolver(ThoughtSolverBase):
async def solve(self, init_prompt=""):
"""
Solve the problem using Breadth-First Search (BFS) strategy.
Args:
init_prompt (str): The initial prompt for the solver.
Returns:
List[str]: The best solution path obtained through BFS.
"""
root = ThoughtNode(init_prompt)
self.thought_tree = ThoughtTree(root)
current_nodes = [root]
for step in range(self.config.max_steps):
solutions = await self._bfs_build(current_nodes)
selected_nodes = self.select_nodes(solutions)
current_nodes = selected_nodes
self.thought_tree.show()
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
async def _bfs_build(self, current_nodes):
"""
Build the thought tree using Breadth-First Search (BFS) strategy.
Args:
current_nodes (List[ThoughtNode]): Current nodes to expand.
Returns:
List[ThoughtNode]: The solutions obtained after expanding the current nodes.
"""
tasks = []
for node in current_nodes:
current_state = self.config.parser(node.name)
current_value = node.value
tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))
thought_nodes_list = await asyncio.gather(*tasks)
solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
return solutions
async def generate_and_evaluate_nodes(self, current_state, current_value, node):
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
await asyncio.gather(
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
)
return thought_nodes
class DFSSolver(ThoughtSolverBase):
async def _dfs(self, root_node):
"""
Perform Depth-First Search (DFS) on the thought tree.
Args:
root_node (ThoughtNode): The root node of the thought tree.
Returns:
List[str]: The solution path obtained through DFS.
"""
impossible_state_cnt = 0
node = root_node
for step in range(self.max_steps):
current_state = self.config.parser(node.name)
current_value = node.value
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
await self.evaluate_node(thought_nodes[0], parent_value=current_value)
if thought_nodes[0].valid_status is False:
impossible_state_cnt += 1
if impossible_state_cnt >= 2:
logger.info("impossible state reached, break")
break
node = thought_nodes[0]
_solution_path = self.thought_tree.parse_node_path(node)
self.thought_tree.show()
return _solution_path
async def solve(self, init_prompt="", root=ThoughtNode("")):
"""
Solve the problem using Depth-First Search (DFS) strategy.
Args:
init_prompt (str): The initial prompt for the solver.
Returns:
List[str]: The best solution path obtained through DFS.
"""
root = ThoughtNode(init_prompt)
self.thought_tree = ThoughtTree(root)
for n in range(self.config.n_solution_sample):
# fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索
await self._dfs(root)
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
class MCTSSolver(ThoughtSolverBase):
async def solve(self, init_prompt=""):
raise NotImplementedError
class TreeofThought(BaseModel):
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
strategy: Strategy = Field(default=Strategy.BFS)
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._initialize_solver(self.strategy)
def _initialize_solver(self, strategy):
"""
Initialize the solver based on the chosen strategy.
Args:
strategy (Strategy): The strategy to use for solving.
Returns:
ThoughtSolverBase: An instance of the appropriate solver.
"""
if strategy == Strategy.BFS:
self.solver = BFSSolver(config=self.config)
elif strategy == Strategy.DFS:
self.solver = DFSSolver(config=self.config)
elif strategy == Strategy.MCTS:
self.solver = MCTSSolver(config=self.config)
else:
raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")
async def solve(self, init_prompt=""):
"""
Solve the problem using the specified strategy.
Args:
init_prompt (str): The initial prompt for the solver.
strategy (str): The strategy to use for solving.
Returns:
Any: The solution obtained using the selected strategy.
"""
await self.solver.solve(init_prompt)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 9:14 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from enum import Enum
from pydantic import BaseModel, Field
from metagpt.strategy.base import BaseEvaluator, BaseParser
class MethodSelect(Enum):
SAMPLE = "sample"
GREEDY = "greedy"
class Strategy(Enum):
BFS = "BFS"
DFS = "DFS"
MCTS = "MCTS"
class ThoughtSolverConfig(BaseModel):
max_steps: int = 3
method_select: str = MethodSelect.GREEDY # ["sample"/"greedy"]
n_generate_sample: int = 5 # per node
n_select_sample: int = 3 # per path
n_solution_sample: int = 5 # only for dfs
parser: BaseParser = Field(default_factory=BaseParser)
evaluator: BaseEvaluator = Field(default_factory=BaseEvaluator)

View file

@ -0,0 +1,105 @@
import pytest
from metagpt.actions import research
@pytest.mark.asyncio
async def test_collect_links(mocker):
async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", '
'"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@pytest.mark.asyncio
async def test_collect_links_with_rank_func(mocker):
rank_before = []
rank_after = []
url_per_query = 4
def rank_func(results):
results = results[:url_per_query]
rank_before.append(results)
results = results[::-1]
rank_after.append(results)
return results
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
assert [i["link"] for i in y] == z
@pytest.mark.asyncio
async def test_web_browse_and_summarize(mocker):
async def mock_llm_ask(*args, **kwargs):
return "metagpt"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
url = "https://github.com/geekan/MetaGPT"
url2 = "https://github.com/trending"
query = "What's new in metagpt"
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] == "metagpt"
resp = await research.WebBrowseAndSummarize().run(url, url2, query=query)
assert len(resp) == 2
async def mock_llm_ask(*args, **kwargs):
return "Not relevant."
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] is None
@pytest.mark.asyncio
async def test_conduct_research(mocker):
data = None
async def mock_llm_ask(*args, **kwargs):
nonlocal data
data = f"# Research Report\n## Introduction\n{args} {kwargs}"
return data
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
content = (
"MetaGPT takes a one line requirement as input and "
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
)
resp = await research.ConductResearch().run("The application of MetaGPT", content)
assert resp == data
async def mock_collect_links_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", ' '"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
return ""

View file

@ -53,6 +53,7 @@ async def test_zhipuai_acompletion(mocker):
assert resp == resp_content
def test_zhipuai_proxy(mocker):
import openai

View file

@ -32,3 +32,19 @@ async def test_researcher(mocker):
researcher.RESEARCH_PATH = Path(dirname)
await researcher.Researcher().run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
def test_write_report(mocker):
with TemporaryDirectory() as dirname:
for i, topic in enumerate(
[
("1./metagpt"),
('2.:"metagpt'),
("3.*?<>|metagpt"),
("4. metagpt\n"),
]
):
researcher.RESEARCH_PATH = Path(dirname)
content = "# Research Report"
researcher.Researcher().write_report(topic, content)
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")