mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add Researcher role
This commit is contained in:
parent
ede23b2fe9
commit
25d2621198
11 changed files with 690 additions and 32 deletions
|
|
@ -65,4 +65,8 @@ SD_T2I_API: "/sdapi/v1/txt2img"
|
|||
|
||||
### for update_costs & calc_usage
|
||||
UPDATE_COSTS: false
|
||||
CALC_USAGE: false
|
||||
CALC_USAGE: false
|
||||
|
||||
### for Research
|
||||
MODEL_FOR_RESEARCHER_SUMMARY: gpt-3.5-turbo
|
||||
MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ from metagpt.actions.write_code_review import WriteCodeReview
|
|||
from metagpt.actions.write_prd import WritePRD
|
||||
from metagpt.actions.write_prd_review import WritePRDReview
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
|
|
@ -40,3 +41,6 @@ class ActionType(Enum):
|
|||
WRITE_TASKS = WriteTasks
|
||||
ASSIGN_TASKS = AssignTasks
|
||||
SEARCH_AND_SUMMARIZE = SearchAndSummarize
|
||||
COLLECT_LINKS = CollectLinks
|
||||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
CONDUCT_RESEARCH = ConductResearch
|
||||
|
|
|
|||
288
metagpt/actions/research.py
Normal file
288
metagpt/actions/research.py
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
|
||||
from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
|
||||
|
||||
# Appended to a system prompt to pin the language of the model's answer.
LANG_PROMPT = "Please respond in {language}."

# Default system prompt for the summarising and report-writing actions.
RESEARCH_BASE_SYSTEM = """You are an AI critical thinker research assistant. Your sole purpose is to write well \
written, critically acclaimed, objective and structured reports on given text."""

# System prompt that tells the model which topic it is researching.
RESEARCH_TOPIC_SYSTEM = "You are an AI researcher assistant, and your research topic is \"{topic}\"."

# Asks for the initial search keywords; the answer is parsed as a JSON list of strings.
SEARCH_TOPIC_PROMPT = """Please provide up to 2 necessary keywords related to your research topic that require Google search. \
Your response must be in JSON format, for example: ["keyword1", "keyword2"]."""

# Turns the keyword search results into follow-up queries (JSON list of strings).
SUMMARIZE_SEARCH_PROMPT = """### Requirements
1. The keywords related to your research topic and the search results are shown in the "Reference Information" section.
2. Provide up to {decomposition_nums} queries related to your research topic based on the search results.
3. Please respond in JSON format as follows: ["query1", "query2", "query3", ...].

### Reference Information
{search}
"""

# Breaks a topic into sub-questions (JSON list of strings).
DECOMPOSITION_PROMPT = """You are a researcher, and before delving into a topic, you break it down into several \
sub-questions. These sub-questions can be researched through online searches to gather objective opinions about the given \
topic.

---
The topic is: {topic}

---
Now, please break down the provided topic into {decomposition_nums} search questions. You should respond with an array of \
strings in JSON format like ["question1", "question2", ...].
"""

# Asks the model to filter and rank numbered search results; the answer is a JSON list of indices.
COLLECT_AND_RANKURLS_PROMPT = """### Reference Information
1. Topic: "{topic}"
2. Query: "{query}"
3. The online search results: {results}

---
Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \
based on link credibility. If two results have equal credibility, prioritize them based on relevance. Provide the ranked \
results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words.
"""

# Summarises one page (or page chunk) with respect to a query; the literal reply
# "Not relevant." is what WebBrowseAndSummarize treats as "skip this chunk".
WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements
1. Utilize the text in the "Reference Information" section to respond to the question "{query}".
2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \
a comprehensive summary of the text.
3. If the text is entirely unrelated to the research topic, please reply with a simple text "Not relevant."
4. Include all relevant factual information, numbers, statistics, etc., if available.

### Reference Information
{content}
'''


# Produces the final report from the collected summaries.
CONDUCT_RESEARCH_PROMPT = '''### Reference Information
{content}

### Requirements
Please provide a detailed research report in response to the following topic: "{topic}", using the information provided \
above. The report must adhere to the following requirements:

- Focus on directly addressing the chosen topic.
- Ensure a well-structured and in-depth presentation, incorporating relevant facts and figures where available.
- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable.
- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines.
- Include all source URLs in APA format at the end of the report.
'''
|
||||
|
||||
|
||||
class CollectLinks(Action):
    """Action class to collect links from a search engine."""

    def __init__(
        self,
        name: str = "",
        *args,
        rank_func: Callable[[list], list] | None = None,
        **kwargs,
    ):
        """Create the action.

        Args:
            name: The action name, forwarded to ``Action``.
            rank_func: Optional callable that receives the ranked search results and
                returns the list to use instead (its return value replaces the results).
        """
        super().__init__(name, *args, **kwargs)
        self.desc = "Collect links from a search engine."
        self.search_engine = SearchEngine()
        self.rank_func = rank_func

    async def run(
        self,
        topic: str,
        decomposition_nums: int = 4,
        url_per_query: int = 4,
        system_text: str | None = None,
    ) -> dict[str, list[str]]:
        """Run the action to collect links.

        Args:
            topic: The research topic.
            decomposition_nums: The number of search questions to generate.
            url_per_query: The number of URLs to collect per search question.
            system_text: The system text.

        Returns:
            A dictionary containing the search questions as keys and the collected URLs as values.
        """
        system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
        search_topic_prompt = SEARCH_TOPIC_PROMPT.format(topic=topic)
        logger.debug(search_topic_prompt)
        keywords = await self._aask(search_topic_prompt, [system_text])
        try:
            keywords = json.loads(keywords)
            keywords = parse_obj_as(list[str], keywords)
        except Exception as e:
            logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
            keywords = [topic]
        if not keywords:
            # A valid but empty JSON list would leave nothing to search for.
            keywords = [topic]
        results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))

        def gen_msg():
            # Yield progressively shorter prompts: after each attempt one hit is dropped
            # from the keyword that currently has the most hits.
            while True:
                search = "\n".join(f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results))
                prompt = SUMMARIZE_SEARCH_PROMPT.format(decomposition_nums=decomposition_nums, search=search)
                yield prompt
                remove = max(results, key=len)
                if len(remove) == 0:
                    # Every result list is already empty; nothing left to trim.
                    break
                remove.pop()
                if len(remove) == 0:
                    break

        prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
        logger.debug(prompt)
        queries = await self._aask(prompt, [system_text])
        try:
            queries = json.loads(queries)
            queries = parse_obj_as(list[str], queries)
        except Exception as e:
            logger.exception(f"fail to break down the research question for {e}")
            queries = keywords
        ret = {}
        for query in queries:
            ret[query] = await self._search_and_rank_urls(topic, query, url_per_query)
        return ret

    async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]:
        """Search and rank URLs based on a query.

        Args:
            topic: The research topic.
            query: The search query.
            num_results: The number of URLs to collect.

        Returns:
            A list of ranked URLs.
        """
        max_results = max(num_results * 2, 6)
        results = await self.search_engine.run(query, max_results=max_results, as_string=False)
        _results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
        prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
        logger.debug(prompt)
        indices = await self._aask(prompt)
        total = len(results)
        try:
            indices = json.loads(indices)
            # Explicit check instead of ``assert`` so it still runs under ``python -O``.
            if not isinstance(indices, list) or not all(isinstance(i, int) for i in indices):
                raise ValueError(f"expected a JSON list of integers, got {indices!r}")
        except Exception as e:
            logger.exception(f"fail to rank results for {e}")
            # The engine may return fewer than ``max_results`` hits, so index only what exists.
            indices = list(range(total))

        # Ignore indices the model invented (out of range) or repeated.
        seen = set()
        ranked = []
        for i in indices:
            if 0 <= i < total and i not in seen:
                seen.add(i)
                ranked.append(results[i])
        results = ranked
        if self.rank_func:
            results = self.rank_func(results)
        return [i["link"] for i in results[:num_results]]
|
||||
|
||||
|
||||
class WebBrowseAndSummarize(Action):
    """Action class to explore the web and provide summaries of articles and webpages."""

    def __init__(
        self,
        *args,
        browse_func: Callable[[list[str]], None] | None = None,
        **kwargs,
    ):
        """Create the action.

        Args:
            browse_func: Optional custom browse callable; when given, the web browser
                engine is created with ``WebBrowserEngineType.CUSTOM`` and this function.
        """
        super().__init__(*args, **kwargs)
        if CONFIG.model_for_researcher_summary:
            self.llm.model = CONFIG.model_for_researcher_summary
        self.web_browser_engine = WebBrowserEngine(
            engine=WebBrowserEngineType.CUSTOM if browse_func else None,
            run_func=browse_func,
        )
        self.desc = "Explore the web and provide summaries of articles and webpages."

    async def run(
        self,
        url: str,
        *urls: str,
        query: str,
        system_text: str = RESEARCH_BASE_SYSTEM,
    ) -> dict[str, str | None]:
        """Run the action to browse the web and provide summaries.

        Args:
            url: The main URL to browse.
            urls: Additional URLs to browse.
            query: The research question.
            system_text: The system text.

        Returns:
            A dictionary containing the URLs as keys and their summaries as values.
            The value is ``None`` for a URL whose every chunk was answered with
            "Not relevant.".
        """
        contents = await self.web_browser_engine.run(url, *urls)
        if not urls:
            # With a single URL the engine result is wrapped so it can be zipped below.
            contents = [contents]

        summaries = {}
        # The template is formatted a second time (per chunk) by generate_prompt_chunk,
        # so braces coming from the query must be doubled to survive that pass.
        escaped_query = query.replace("{", "{{").replace("}", "}}")
        prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=escaped_query, content="{}")
        for u, content in zip([url, *urls], contents):
            content = content.inner_text
            chunk_summaries = []
            for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp):
                logger.debug(prompt)
                summary = await self._aask(prompt, [system_text])
                if summary == "Not relevant.":
                    continue
                chunk_summaries.append(summary)

            if not chunk_summaries:
                summaries[u] = None
                continue

            if len(chunk_summaries) == 1:
                summaries[u] = chunk_summaries[0]
                continue

            # Several chunks produced summaries: merge them with one more request.
            content = "\n".join(chunk_summaries)
            prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
            summary = await self._aask(prompt, [system_text])
            summaries[u] = summary
        return summaries
|
||||
|
||||
|
||||
class ConductResearch(Action):
    """Action that writes the final research report from collected material."""

    def __init__(self, *args, **kwargs):
        """Create the action; a configured report model overrides the LLM's model."""
        super().__init__(*args, **kwargs)
        report_model = CONFIG.model_for_researcher_report
        if report_model:
            self.llm.model = report_model

    async def run(
        self,
        topic: str,
        content: str,
        system_text: str = RESEARCH_BASE_SYSTEM,
    ) -> str:
        """Generate a research report.

        Args:
            topic: The research topic.
            content: The reference material the report is based on.
            system_text: The system text.

        Returns:
            The generated research report.
        """
        report_prompt = CONDUCT_RESEARCH_PROMPT.format(topic=topic, content=content)
        logger.debug(report_prompt)
        # Enabled before asking so the request is built with auto_max_tokens switched on.
        self.llm.auto_max_tokens = True
        report = await self._aask(report_prompt, [system_text])
        return report
|
||||
|
||||
|
||||
def get_research_system_text(topic: str, language: str):
    """Build the system text used while researching a topic.

    Args:
        topic: The research topic.
        language: The language the model is asked to respond in.

    Returns:
        The topic system prompt followed by the language instruction, separated by one space.
    """
    topic_part = RESEARCH_TOPIC_SYSTEM.format(topic=topic)
    language_part = LANG_PROMPT.format(language=language)
    return f"{topic_part} {language_part}"
|
||||
|
|
@ -4,14 +4,14 @@
|
|||
提供配置,单例
|
||||
"""
|
||||
import os
|
||||
import openai
|
||||
|
||||
import openai
|
||||
import yaml
|
||||
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
class NotConfiguredException(Exception):
|
||||
|
|
@ -77,12 +77,12 @@ class Config(metaclass=Singleton):
|
|||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.total_cost = 0.0
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG","")
|
||||
self.mmdc = self._get("MMDC","mmdc")
|
||||
self.update_costs = self._get("UPDATE_COSTS",True)
|
||||
self.calc_usage = self._get("CALC_USAGE",True)
|
||||
|
||||
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
|
||||
self.mmdc = self._get("MMDC", "mmdc")
|
||||
self.update_costs = self._get("UPDATE_COSTS", True)
|
||||
self.calc_usage = self._get("CALC_USAGE", True)
|
||||
self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
|
||||
self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
|
||||
|
||||
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
|
||||
"""从config/key.yaml / config/config.yaml / env三处按优先级递减加载"""
|
||||
|
|
|
|||
|
|
@ -32,5 +32,6 @@ UT_PY_PATH = UT_PATH / "files/ut/"
|
|||
API_QUESTIONS_PATH = UT_PATH / "files/question/"
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
TMP = PROJECT_ROOT / 'tmp'
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
|
||||
MEM_TTL = 24 * 30 * 3600
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/5 23:08
|
||||
|
|
@ -20,6 +19,7 @@ from metagpt.utils.token_counter import (
|
|||
TOKEN_COSTS,
|
||||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
get_max_completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -33,20 +33,25 @@ def retry(max_retries):
|
|||
except Exception:
|
||||
if i == max_retries - 1:
|
||||
raise
|
||||
await asyncio.sleep(2 ** i)
|
||||
await asyncio.sleep(2**i)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
|
||||
|
||||
def __init__(self, rpm):
|
||||
self.last_call_time = 0
|
||||
self.interval = 1.1 * 60 / rpm # Here 1.1 is used because even if the calls are made strictly according to time, they will still be QOS'd; consider switching to simple error retry later
|
||||
# Here 1.1 is used because even if the calls are made strictly according to time,
|
||||
# they will still be QOS'd; consider switching to simple error retry later
|
||||
self.interval = 1.1 * 60 / rpm
|
||||
self.rpm = rpm
|
||||
|
||||
def split_batches(self, batch):
|
||||
return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
|
||||
async def wait_if_needed(self, num_requests):
|
||||
current_time = time.time()
|
||||
|
|
@ -69,6 +74,7 @@ class Costs(NamedTuple):
|
|||
|
||||
class CostManager(metaclass=Singleton):
|
||||
"""计算使用接口的开销"""
|
||||
|
||||
def __init__(self):
|
||||
self.total_prompt_tokens = 0
|
||||
self.total_completion_tokens = 0
|
||||
|
|
@ -86,13 +92,12 @@ class CostManager(metaclass=Singleton):
|
|||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
cost = (
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"]
|
||||
+ completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
) / 1000
|
||||
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}")
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
def get_total_prompt_tokens(self):
|
||||
|
|
@ -131,10 +136,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
|
|
@ -148,10 +155,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
self.rpm = int(config.get("RPM", 10))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
**self._cons_kwargs(messages),
|
||||
stream=True
|
||||
)
|
||||
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
|
||||
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
|
|
@ -159,41 +163,41 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
chunk_message = chunk['choices'][0]['delta'] # extract the message
|
||||
chunk_message = chunk["choices"][0]["delta"] # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
print(chunk_message["content"], end="")
|
||||
print()
|
||||
|
||||
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
|
||||
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict]) -> dict:
|
||||
if CONFIG.openai_api_type == 'azure':
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
kwargs = {
|
||||
"deployment_id": CONFIG.deployment_id,
|
||||
"messages": messages,
|
||||
"max_tokens": CONFIG.max_tokens_rsp,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3
|
||||
"temperature": 0.3,
|
||||
}
|
||||
else:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": CONFIG.max_tokens_rsp,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3
|
||||
"temperature": 0.3,
|
||||
}
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.get('usage'))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
|
||||
def _chat_completion(self, messages: list[dict]) -> dict:
|
||||
|
|
@ -262,3 +266,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
||||
def get_max_tokens(self, messages: list[dict]):
    """Return the ``max_tokens`` value to send with a completion request.

    Args:
        messages: The chat messages that will be sent.

    Returns:
        The value from ``get_max_completion_tokens`` for these messages and the
        current model when ``auto_max_tokens`` is set, otherwise
        ``CONFIG.max_tokens_rsp``.
    """
    if self.auto_max_tokens:
        return get_max_completion_tokens(messages, self.model)
    return CONFIG.max_tokens_rsp
|
||||
|
|
|
|||
91
metagpt/roles/researcher.py
Normal file
91
metagpt/roles/researcher.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import CollectLinks, ConductResearch, WebBrowseAndSummarize
|
||||
from metagpt.actions.research import get_research_system_text
|
||||
from metagpt.const import RESEARCH_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class Report(BaseModel):
    """Payload passed between the Researcher's steps via ``Message.instruct_content``."""

    # The research topic; the only required field.
    topic: str
    # Search query -> URLs collected for it.
    links: dict[str, list[str]] = None
    # (url, summary) pairs.
    summaries: list[tuple[str, str]] = None
    # The report text.
    content: str = ""
|
||||
|
||||
|
||||
class Researcher(Role):
    """Role that researches a topic in three steps and writes a Markdown report.

    The actions run in order: ``CollectLinks`` -> ``WebBrowseAndSummarize`` ->
    ``ConductResearch``. Each step stores a ``Report`` in memory for the next one.
    """

    def __init__(
        self,
        name: str = "Bob",
        profile: str = "Researcher",
        goal: str = "Gather information and conduct research",
        constraints: str = "Ensure accuracy and relevance of information",
        language: str = "en-us",
        **kwargs,
    ):
        """Create the role.

        Args:
            name: The role name, also passed to each action.
            profile: The role profile.
            goal: The role goal.
            constraints: The role constraints.
            language: The language the model is asked to respond in.
        """
        super().__init__(name, profile, goal, constraints, **kwargs)
        self._init_actions([CollectLinks(name), WebBrowseAndSummarize(name), ConductResearch(name)])
        self.language = language

    async def _think(self) -> None:
        """Select the next action: start at state 0, then advance until the states run out."""
        if self._rc.todo is None:
            self._set_state(0)
            return

        if self._rc.state + 1 < len(self._states):
            self._set_state(self._rc.state + 1)
        else:
            self._rc.todo = None

    async def _act(self) -> Message:
        """Run the current action on the latest message in memory.

        Returns:
            A message whose ``instruct_content`` is the ``Report`` produced by this step.
        """
        logger.info(f"{self._setting}: ready to {self._rc.todo}")
        todo = self._rc.todo
        msg = self._rc.memory.get(k=1)[0]
        # Always bind instruct_content so later branches cannot hit an unbound local.
        instruct_content = msg.instruct_content if isinstance(msg.instruct_content, Report) else None
        topic = instruct_content.topic if instruct_content is not None else msg.content

        research_system_text = get_research_system_text(topic, self.language)
        if isinstance(todo, CollectLinks):
            links = await todo.run(topic, 4, 4)
            ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo))
        elif isinstance(todo, WebBrowseAndSummarize):
            links = (instruct_content.links if instruct_content is not None else None) or {}
            # Queries without any URL are skipped: run() requires at least one URL.
            todos = (
                todo.run(*url, query=query, system_text=research_system_text)
                for (query, url) in links.items()
                if url
            )
            summaries = await asyncio.gather(*todos)
            summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary)
            ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=type(todo))
        else:
            summaries = (instruct_content.summaries if instruct_content is not None else None) or []
            summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
            content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
            ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=type(self._rc.todo))
        self._rc.memory.add(ret)
        return ret

    async def _react(self) -> Message:
        """Run every step in sequence, write the report to disk and return the last message."""
        while True:
            await self._think()
            if self._rc.todo is None:
                break
            msg = await self._act()
        report = msg.instruct_content
        self.write_report(report.topic, report.content)
        return msg

    def write_report(self, topic: str, content: str):
        """Write the report to ``RESEARCH_PATH / "<topic>.md"``.

        Args:
            topic: The research topic; characters that are not allowed in file names
                are replaced by a space.
            content: The report text.
        """
        import re

        # Topics are free text and may contain path separators or other illegal characters.
        filename = re.sub(r'[\\/:"*?<>|\r\n]+', " ", topic)
        # The research directory is not guaranteed to exist yet.
        RESEARCH_PATH.mkdir(parents=True, exist_ok=True)
        filepath = RESEARCH_PATH / f"{filename}.md"
        # Explicit encoding: model output is often non-ASCII and the platform default may not cover it.
        filepath.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke run of the whole research pipeline on a sample topic.
    asyncio.run(Researcher(language="en-us").run("dataiku .vs datarobot"))
|
||||
124
metagpt/utils/text.py
Normal file
124
metagpt/utils/text.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from typing import Generator, Sequence
|
||||
|
||||
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
|
||||
|
||||
|
||||
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
    """Pick the first message that fits into the model's remaining token budget.

    Args:
        msgs: A generator of candidate messages; it is consumed only up to the first fit.
        model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
        system_text: The system prompts, whose tokens are subtracted from the budget.
        reserved: The number of reserved tokens, also subtracted from the budget.

    Returns:
        The first message whose token count is strictly below the budget.

    Raises:
        RuntimeError: If no message fits.
    """
    # Models missing from TOKEN_MAX fall back to a 2048-token window.
    window = TOKEN_MAX.get(model_name, 2048)
    budget = window - count_string_tokens(system_text, model_name) - reserved
    fitting = next((candidate for candidate in msgs if count_string_tokens(candidate, model_name) < budget), None)
    if fitting is None:
        raise RuntimeError("fail to reduce message length")
    return fitting
|
||||
|
||||
|
||||
def generate_prompt_chunk(
    text: str,
    prompt_template: str,
    model_name: str,
    system_text: str,
    reserved: int = 0,
) -> Generator[str, None, None]:
    """Split the text into chunks of a maximum token size.

    Lines are packed greedily into a chunk; a line that is too large on its own is
    split with ``split_paragraph`` and its pieces are processed in place.

    Args:
        text: The text to split.
        prompt_template: The template for the prompt; each chunk is inserted with
            ``prompt_template.format(chunk)``.
        model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
        system_text: The system prompts.
        reserved: The number of reserved tokens.

    Yields:
        The prompt for each chunk of text.

    Raises:
        ValueError: If a piece that cannot be split any further still exceeds the budget.
    """
    from collections import deque

    # deque gives O(1) pops and pushes at the front of the work queue.
    pending = deque(text.splitlines(keepends=True))
    current_token = 0
    current_lines = []

    reserved = reserved + count_string_tokens(prompt_template + system_text, model_name)
    # 100 is a magic number to ensure the maximum context length is not exceeded
    max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100

    while pending:
        paragraph = pending.popleft()
        token = count_string_tokens(paragraph, model_name)
        if current_token + token <= max_token:
            current_lines.append(paragraph)
            current_token += token
        elif token > max_token:
            if len(paragraph) <= 1:
                # Splitting can make no further progress; fail instead of looping forever.
                raise ValueError(
                    f"cannot fit text into a token budget of {max_token} for model {model_name!r}"
                )
            # list() tolerates split_paragraph returning any iterable; reversed() keeps
            # the pieces in their original order at the front of the queue.
            pending.extendleft(reversed(list(split_paragraph(paragraph))))
        else:
            yield prompt_template.format("".join(current_lines))
            current_lines = [paragraph]
            current_token = token

    if current_lines:
        yield prompt_template.format("".join(current_lines))
|
||||
|
||||
|
||||
def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
    """Split a paragraph into multiple parts.

    Each character of ``sep`` is tried in order; the first one that breaks the
    paragraph into more than one sentence is used, and the sentences are grouped
    into ``count`` parts. If no separator applies, the raw text is cut into
    ``count`` slices.

    Args:
        paragraph: The paragraph to split.
        sep: The separator characters, tried one at a time.
        count: The number of parts to split the paragraph into.

    Returns:
        A list of split parts of the paragraph.
    """
    for i in sep:
        sentences = list(_split_text_with_ends(paragraph, i))
        if len(sentences) <= 1:
            continue
        return ["".join(j) for j in _split_by_count(sentences, count)]
    # _split_by_count is a generator; materialise it so the declared list[str]
    # return type holds and callers can concatenate the result with other lists.
    return list(_split_by_count(paragraph, count))
|
||||
|
||||
|
||||
def decode_unicode_escape(text: str) -> str:
    """Decode a text with unicode escape sequences.

    Escape sequences such as ``\\n``, ``\\t`` or ``\\u0020`` written out literally in
    ``text`` are turned into the characters they denote. Characters that are already
    non-ASCII are preserved.

    Args:
        text: The text to decode.

    Returns:
        The decoded text.
    """
    # The unicode_escape codec reads bytes as Latin-1, so encoding with UTF-8 first
    # would garble every non-ASCII character. Latin-1 round-trips U+0000-U+00FF and
    # backslashreplace rewrites the rest as \uXXXX / \UXXXXXXXX escapes, which the
    # decode step then restores.
    return text.encode("latin-1", "backslashreplace").decode("unicode_escape", "ignore")
|
||||
|
||||
|
||||
def _split_by_count(lst: Sequence , count: int):
|
||||
avg = len(lst) // count
|
||||
remainder = len(lst) % count
|
||||
start = 0
|
||||
for i in range(count):
|
||||
end = start + avg + (1 if i < remainder else 0)
|
||||
yield lst[start:end]
|
||||
start = end
|
||||
|
||||
|
||||
def _split_text_with_ends(text: str, sep: str = "."):
|
||||
parts = []
|
||||
for i in text:
|
||||
parts.append(i)
|
||||
if i == sep:
|
||||
yield "".join(parts)
|
||||
parts = []
|
||||
if parts:
|
||||
yield "".join(parts)
|
||||
|
|
@ -25,6 +25,21 @@ TOKEN_COSTS = {
|
|||
}
|
||||
|
||||
|
||||
# Context window size, in tokens, per model name.
TOKEN_MAX = {
    "gpt-3.5-turbo": 4096,
    "gpt-3.5-turbo-0301": 4096,
    "gpt-3.5-turbo-0613": 4096,
    "gpt-3.5-turbo-16k": 16384,
    "gpt-3.5-turbo-16k-0613": 16384,
    "gpt-4-0314": 8192,
    "gpt-4": 8192,
    "gpt-4-32k": 32768,
    "gpt-4-32k-0314": 32768,
    "gpt-4-0613": 8192,
    # Listed as a known model by count_message_tokens, so it needs a window size here too.
    "gpt-4-32k-0613": 32768,
    "text-embedding-ada-002": 8192,
}
|
||||
|
||||
|
||||
def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
try:
|
||||
|
|
@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
}:
|
||||
}:
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
|
|
@ -79,3 +94,16 @@ def count_string_tokens(string: str, model_name: str) -> int:
|
|||
"""
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
return len(encoding.encode(string))
|
||||
|
||||
|
||||
def get_max_completion_tokens(messages: list[dict], model: str):
    """Calculate how many completion tokens remain for a model after the given messages.

    Args:
        messages: A list of messages.
        model: The model name.

    Returns:
        The model's context window (4096 when the model is not in ``TOKEN_MAX``)
        minus the token count of ``messages``.
    """
    context_window = TOKEN_MAX.get(model, 4096)
    # NOTE(review): the count is made with count_message_tokens' default model, not ``model``.
    prompt_tokens = count_message_tokens(messages)
    return context_window - prompt_tokens
|
||||
|
|
|
|||
32
tests/metagpt/roles/test_researcher.py
Normal file
32
tests/metagpt/roles/test_researcher.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from pathlib import Path
|
||||
from random import random
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.roles import researcher
|
||||
|
||||
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
    """Stand-in for the LLM ``aask`` call that answers by recognising the prompt.

    Args:
        prompt: The prompt that would be sent to the model.
        system_msgs: The system messages (unused).

    Returns:
        A canned reply matching the pipeline step the prompt belongs to, or an
        empty string for an unrecognised prompt.
    """
    if "Please provide up to 2 necessary keywords" in prompt:
        return '["dataiku", "datarobot"]'
    if "Provide up to 4 queries related to your research topic" in prompt:
        return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
            '"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
    if "sort the remaining search results" in prompt:
        return '[1,2]'
    if "Not relevant." in prompt:
        # Must equal the exact reply the summarising action compares against
        # (with the trailing period), otherwise the "skip this chunk" path is never taken.
        return "Not relevant." if random() > 0.5 else prompt[-100:]
    if "provide a detailed research report" in prompt:
        return f"# Research Report\n## Introduction\n{prompt}"
    return ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_researcher(mocker):
    """Run the Researcher with a mocked LLM and check that the report file is written."""
    topic = "dataiku .vs datarobot"
    with TemporaryDirectory() as dirname:
        mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
        # Redirect report output into the temporary directory.
        researcher.RESEARCH_PATH = Path(dirname)
        await researcher.Researcher().run(topic)
        report_file = researcher.RESEARCH_PATH / f"{topic}.md"
        assert report_file.read_text().startswith("# Research Report")
|
||||
77
tests/metagpt/utils/test_text.py
Normal file
77
tests/metagpt/utils/test_text.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.utils.text import (
|
||||
decode_unicode_escape,
|
||||
generate_prompt_chunk,
|
||||
reduce_message_length,
|
||||
split_paragraph,
|
||||
)
|
||||
|
||||
|
||||
def _msgs():
|
||||
length = 20
|
||||
while length:
|
||||
yield "Hello," * 1000 * length
|
||||
length -= 1
|
||||
|
||||
|
||||
def _paragraphs(n):
|
||||
return " ".join("Hello World." for _ in range(n))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "msgs, model_name, system_text, reserved, expected",
    [
        (_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
        (_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
        (_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
        (_msgs(), "gpt-4", "System", 2000, 3),
        (_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
        (_msgs(), "gpt-4-32k", "System", 4000, 14),
        (_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
    ]
)
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
    """The selected message must be "Hello," repeated ``expected`` thousand times."""
    reduced = reduce_message_length(msgs, model_name, system_text, reserved)
    thousands_of_repeats = len(reduced) / (len("Hello,")) / 1000
    assert thousands_of_repeats == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "text, prompt_template, model_name, system_text, reserved, expected",
    [
        (_paragraphs(1000), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
        (_paragraphs(1000), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
        (_paragraphs(4000), "Prompt: {}", "gpt-4", "System", 2000, 2),
        (_paragraphs(8000), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
    ]
)
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
    """The text must be split into exactly ``expected`` prompts."""
    chunks = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
    assert len(chunks) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "paragraph, sep, count, expected",
    [
        (_paragraphs(10), ".", 2, [_paragraphs(5), f" {_paragraphs(5)}"]),
        (_paragraphs(10), ".", 3, [_paragraphs(4), f" {_paragraphs(3)}", f" {_paragraphs(3)}"]),
        (f"{_paragraphs(5)}\n{_paragraphs(3)}", "\n.", 2, [f"{_paragraphs(5)}\n", _paragraphs(3)]),
        ("......", ".", 2, ["...", "..."]),
        ("......", ".", 3, ["..", "..", ".."]),
        (".......", ".", 2, ["....", "..."]),
    ]
)
def test_split_paragraph(paragraph, sep, count, expected):
    """Splitting on ``sep`` into ``count`` parts must give exactly ``expected``."""
    parts = split_paragraph(paragraph, sep, count)
    assert parts == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "text, expected",
    [
        ("Hello\\nWorld", "Hello\nWorld"),
        ("Hello\\tWorld", "Hello\tWorld"),
        ("Hello\\u0020World", "Hello World"),
    ]
)
def test_decode_unicode_escape(text, expected):
    """Literal escape sequences must be decoded into the characters they denote."""
    decoded = decode_unicode_escape(text)
    assert decoded == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue