mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
commit
2e1e647628
27 changed files with 1282 additions and 246 deletions
|
|
@ -65,4 +65,8 @@ SD_T2I_API: "/sdapi/v1/txt2img"
|
|||
|
||||
### for update_costs & calc_usage
|
||||
UPDATE_COSTS: false
|
||||
CALC_USAGE: false
|
||||
CALC_USAGE: false
|
||||
|
||||
### for Research
|
||||
MODEL_FOR_RESEARCHER_SUMMARY: gpt-3.5-turbo
|
||||
MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
|
||||
|
|
|
|||
16
examples/research.py
Normal file
16
examples/research.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import asyncio
|
||||
|
||||
from metagpt.roles.researcher import RESEARCH_PATH, Researcher
|
||||
|
||||
|
||||
async def main():
|
||||
topic = "dataiku vs. datarobot"
|
||||
role = Researcher(language="en-us")
|
||||
await role.run(topic)
|
||||
print(f"save report to {RESEARCH_PATH / f'{topic}.md'}.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
|
@ -22,6 +22,7 @@ from metagpt.actions.write_code_review import WriteCodeReview
|
|||
from metagpt.actions.write_prd import WritePRD
|
||||
from metagpt.actions.write_prd_review import WritePRDReview
|
||||
from metagpt.actions.write_test import WriteTest
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
|
|
@ -40,3 +41,6 @@ class ActionType(Enum):
|
|||
WRITE_TASKS = WriteTasks
|
||||
ASSIGN_TASKS = AssignTasks
|
||||
SEARCH_AND_SUMMARIZE = SearchAndSummarize
|
||||
COLLECT_LINKS = CollectLinks
|
||||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
CONDUCT_RESEARCH = ConductResearch
|
||||
|
|
|
|||
277
metagpt/actions/research.py
Normal file
277
metagpt/actions/research.py
Normal file
|
|
@ -0,0 +1,277 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Callable
|
||||
|
||||
from pydantic import parse_obj_as
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
|
||||
from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
|
||||
|
||||
LANG_PROMPT = "Please respond in {language}."
|
||||
|
||||
RESEARCH_BASE_SYSTEM = """You are an AI critical thinker research assistant. Your sole purpose is to write well \
|
||||
written, critically acclaimed, objective and structured reports on the given text."""
|
||||
|
||||
RESEARCH_TOPIC_SYSTEM = "You are an AI researcher assistant, and your research topic is:\n#TOPIC#\n{topic}"
|
||||
|
||||
SEARCH_TOPIC_PROMPT = """Please provide up to 2 necessary keywords related to your research topic for Google search. \
|
||||
Your response must be in JSON format, for example: ["keyword1", "keyword2"]."""
|
||||
|
||||
SUMMARIZE_SEARCH_PROMPT = """### Requirements
|
||||
1. The keywords related to your research topic and the search results are shown in the "Search Result Information" section.
|
||||
2. Provide up to {decomposition_nums} queries related to your research topic base on the search results.
|
||||
3. Please respond in the following JSON format: ["query1", "query2", "query3", ...].
|
||||
|
||||
### Search Result Information
|
||||
{search_results}
|
||||
"""
|
||||
|
||||
COLLECT_AND_RANKURLS_PROMPT = """### Topic
|
||||
{topic}
|
||||
### Query
|
||||
{query}
|
||||
|
||||
### The online search results
|
||||
{results}
|
||||
|
||||
### Requirements
|
||||
Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \
|
||||
based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. Provide the
|
||||
ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words.
|
||||
"""
|
||||
|
||||
WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements
|
||||
1. Utilize the text in the "Reference Information" section to respond to the question "{query}".
|
||||
2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \
|
||||
a comprehensive summary of the text.
|
||||
3. If the text is entirely unrelated to the research topic, please reply with a simple text "Not relevant."
|
||||
4. Include all relevant factual information, numbers, statistics, etc., if available.
|
||||
|
||||
### Reference Information
|
||||
{content}
|
||||
'''
|
||||
|
||||
|
||||
CONDUCT_RESEARCH_PROMPT = '''### Reference Information
|
||||
{content}
|
||||
|
||||
### Requirements
|
||||
Please provide a detailed research report in response to the following topic: "{topic}", using the information provided \
|
||||
above. The report must meet the following requirements:
|
||||
|
||||
- Focus on directly addressing the chosen topic.
|
||||
- Ensure a well-structured and in-depth presentation, incorporating relevant facts and figures where available.
|
||||
- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable.
|
||||
- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines.
|
||||
- Include all source URLs in APA format at the end of the report.
|
||||
'''
|
||||
|
||||
|
||||
class CollectLinks(Action):
|
||||
"""Action class to collect links from a search engine."""
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "",
|
||||
*args,
|
||||
rank_func: Callable[[list[str]], None] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(name, *args, **kwargs)
|
||||
self.desc = "Collect links from a search engine."
|
||||
self.search_engine = SearchEngine()
|
||||
self.rank_func = rank_func
|
||||
|
||||
async def run(
|
||||
self,
|
||||
topic: str,
|
||||
decomposition_nums: int = 4,
|
||||
url_per_query: int = 4,
|
||||
system_text: str | None = None,
|
||||
) -> dict[str, list[str]]:
|
||||
"""Run the action to collect links.
|
||||
|
||||
Args:
|
||||
topic: The research topic.
|
||||
decomposition_nums: The number of search questions to generate.
|
||||
url_per_query: The number of URLs to collect per search question.
|
||||
system_text: The system text.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the search questions as keys and the collected URLs as values.
|
||||
"""
|
||||
system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
|
||||
keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
|
||||
try:
|
||||
keywords = json.loads(keywords)
|
||||
keywords = parse_obj_as(list[str], keywords)
|
||||
except Exception as e:
|
||||
logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
|
||||
keywords = [topic]
|
||||
results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))
|
||||
|
||||
def gen_msg():
|
||||
while True:
|
||||
search_results = "\n".join(f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results))
|
||||
prompt = SUMMARIZE_SEARCH_PROMPT.format(decomposition_nums=decomposition_nums, search_results=search_results)
|
||||
yield prompt
|
||||
remove = max(results, key=len)
|
||||
remove.pop()
|
||||
if len(remove) == 0:
|
||||
break
|
||||
prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
|
||||
logger.debug(prompt)
|
||||
queries = await self._aask(prompt, [system_text])
|
||||
try:
|
||||
queries = json.loads(queries)
|
||||
queries = parse_obj_as(list[str], queries)
|
||||
except Exception as e:
|
||||
logger.exception(f"fail to break down the research question due to {e}")
|
||||
queries = keywords
|
||||
ret = {}
|
||||
for query in queries:
|
||||
ret[query] = await self._search_and_rank_urls(topic, query, url_per_query)
|
||||
return ret
|
||||
|
||||
async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]:
|
||||
"""Search and rank URLs based on a query.
|
||||
|
||||
Args:
|
||||
topic: The research topic.
|
||||
query: The search query.
|
||||
num_results: The number of URLs to collect.
|
||||
|
||||
Returns:
|
||||
A list of ranked URLs.
|
||||
"""
|
||||
max_results = max(num_results * 2, 6)
|
||||
results = await self.search_engine.run(query, max_results=max_results, as_string=False)
|
||||
_results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
|
||||
prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
|
||||
logger.debug(prompt)
|
||||
indices = await self._aask(prompt)
|
||||
try:
|
||||
indices = json.loads(indices)
|
||||
assert all(isinstance(i, int) for i in indices)
|
||||
except Exception as e:
|
||||
logger.exception(f"fail to rank results for {e}")
|
||||
indices = list(range(max_results))
|
||||
results = [results[i] for i in indices]
|
||||
if self.rank_func:
|
||||
results = self.rank_func(results)
|
||||
return [i["link"] for i in results[:num_results]]
|
||||
|
||||
|
||||
class WebBrowseAndSummarize(Action):
|
||||
"""Action class to explore the web and provide summaries of articles and webpages."""
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
browse_func: Callable[[list[str]], None] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
if CONFIG.model_for_researcher_summary:
|
||||
self.llm.model = CONFIG.model_for_researcher_summary
|
||||
self.web_browser_engine = WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if browse_func else None,
|
||||
run_func=browse_func,
|
||||
)
|
||||
self.desc = "Explore the web and provide summaries of articles and webpages."
|
||||
|
||||
async def run(
|
||||
self,
|
||||
url: str,
|
||||
*urls: str,
|
||||
query: str,
|
||||
system_text: str = RESEARCH_BASE_SYSTEM,
|
||||
) -> dict[str, str]:
|
||||
"""Run the action to browse the web and provide summaries.
|
||||
|
||||
Args:
|
||||
url: The main URL to browse.
|
||||
urls: Additional URLs to browse.
|
||||
query: The research question.
|
||||
system_text: The system text.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the URLs as keys and their summaries as values.
|
||||
"""
|
||||
contents = await self.web_browser_engine.run(url, *urls)
|
||||
if not urls:
|
||||
contents = [contents]
|
||||
|
||||
summaries = {}
|
||||
prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}")
|
||||
for u, content in zip([url, *urls], contents):
|
||||
content = content.inner_text
|
||||
chunk_summaries = []
|
||||
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp):
|
||||
logger.debug(prompt)
|
||||
summary = await self._aask(prompt, [system_text])
|
||||
if summary == "Not relevant.":
|
||||
continue
|
||||
chunk_summaries.append(summary)
|
||||
|
||||
if not chunk_summaries:
|
||||
summaries[u] = None
|
||||
continue
|
||||
|
||||
if len(chunk_summaries) == 1:
|
||||
summaries[u] = chunk_summaries[0]
|
||||
continue
|
||||
|
||||
content = "\n".join(chunk_summaries)
|
||||
prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
|
||||
summary = await self._aask(prompt, [system_text])
|
||||
summaries[u] = summary
|
||||
return summaries
|
||||
|
||||
|
||||
class ConductResearch(Action):
|
||||
"""Action class to conduct research and generate a research report."""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if CONFIG.model_for_researcher_report:
|
||||
self.llm.model = CONFIG.model_for_researcher_report
|
||||
|
||||
async def run(
|
||||
self,
|
||||
topic: str,
|
||||
content: str,
|
||||
system_text: str = RESEARCH_BASE_SYSTEM,
|
||||
) -> str:
|
||||
"""Run the action to conduct research and generate a research report.
|
||||
|
||||
Args:
|
||||
topic: The research topic.
|
||||
content: The content for research.
|
||||
system_text: The system text.
|
||||
|
||||
Returns:
|
||||
The generated research report.
|
||||
"""
|
||||
prompt = CONDUCT_RESEARCH_PROMPT.format(topic=topic, content=content)
|
||||
logger.debug(prompt)
|
||||
self.llm.auto_max_tokens = True
|
||||
return await self._aask(prompt, [system_text])
|
||||
|
||||
|
||||
def get_research_system_text(topic: str, language: str):
|
||||
"""Get the system text for conducting research.
|
||||
|
||||
Args:
|
||||
topic: The research topic.
|
||||
language: The language for the system text.
|
||||
|
||||
Returns:
|
||||
The system text for conducting research.
|
||||
"""
|
||||
return " ".join((RESEARCH_TOPIC_SYSTEM.format(topic=topic), LANG_PROMPT.format(language=language)))
|
||||
|
|
@ -4,14 +4,14 @@
|
|||
提供配置,单例
|
||||
"""
|
||||
import os
|
||||
import openai
|
||||
|
||||
import openai
|
||||
import yaml
|
||||
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
class NotConfiguredException(Exception):
|
||||
|
|
@ -77,12 +77,12 @@ class Config(metaclass=Singleton):
|
|||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.total_cost = 0.0
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG","")
|
||||
self.mmdc = self._get("MMDC","mmdc")
|
||||
self.update_costs = self._get("UPDATE_COSTS",True)
|
||||
self.calc_usage = self._get("CALC_USAGE",True)
|
||||
|
||||
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
|
||||
self.mmdc = self._get("MMDC", "mmdc")
|
||||
self.update_costs = self._get("UPDATE_COSTS", True)
|
||||
self.calc_usage = self._get("CALC_USAGE", True)
|
||||
self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
|
||||
self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
|
||||
|
||||
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
|
||||
"""从config/key.yaml / config/config.yaml / env三处按优先级递减加载"""
|
||||
|
|
|
|||
|
|
@ -32,5 +32,6 @@ UT_PY_PATH = UT_PATH / "files/ut/"
|
|||
API_QUESTIONS_PATH = UT_PATH / "files/question/"
|
||||
YAPI_URL = "http://yapi.deepwisdomai.com/"
|
||||
TMP = PROJECT_ROOT / 'tmp'
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
|
||||
MEM_TTL = 24 * 30 * 3600
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/5 23:08
|
||||
|
|
@ -20,6 +19,7 @@ from metagpt.utils.token_counter import (
|
|||
TOKEN_COSTS,
|
||||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
get_max_completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -33,20 +33,25 @@ def retry(max_retries):
|
|||
except Exception:
|
||||
if i == max_retries - 1:
|
||||
raise
|
||||
await asyncio.sleep(2 ** i)
|
||||
await asyncio.sleep(2**i)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
|
||||
|
||||
def __init__(self, rpm):
|
||||
self.last_call_time = 0
|
||||
self.interval = 1.1 * 60 / rpm # Here 1.1 is used because even if the calls are made strictly according to time, they will still be QOS'd; consider switching to simple error retry later
|
||||
# Here 1.1 is used because even if the calls are made strictly according to time,
|
||||
# they will still be QOS'd; consider switching to simple error retry later
|
||||
self.interval = 1.1 * 60 / rpm
|
||||
self.rpm = rpm
|
||||
|
||||
def split_batches(self, batch):
|
||||
return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
|
||||
async def wait_if_needed(self, num_requests):
|
||||
current_time = time.time()
|
||||
|
|
@ -69,6 +74,7 @@ class Costs(NamedTuple):
|
|||
|
||||
class CostManager(metaclass=Singleton):
|
||||
"""计算使用接口的开销"""
|
||||
|
||||
def __init__(self):
|
||||
self.total_prompt_tokens = 0
|
||||
self.total_completion_tokens = 0
|
||||
|
|
@ -86,13 +92,12 @@ class CostManager(metaclass=Singleton):
|
|||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
cost = (
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"]
|
||||
+ completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
) / 1000
|
||||
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}")
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
def get_total_prompt_tokens(self):
|
||||
|
|
@ -131,10 +136,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
|
|
@ -148,10 +155,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
self.rpm = int(config.get("RPM", 10))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
**self._cons_kwargs(messages),
|
||||
stream=True
|
||||
)
|
||||
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
|
||||
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
|
|
@ -159,41 +163,41 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
chunk_message = chunk['choices'][0]['delta'] # extract the message
|
||||
chunk_message = chunk["choices"][0]["delta"] # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
print(chunk_message["content"], end="")
|
||||
print()
|
||||
|
||||
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
|
||||
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict]) -> dict:
|
||||
if CONFIG.openai_api_type == 'azure':
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
kwargs = {
|
||||
"deployment_id": CONFIG.deployment_id,
|
||||
"messages": messages,
|
||||
"max_tokens": CONFIG.max_tokens_rsp,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3
|
||||
"temperature": 0.3,
|
||||
}
|
||||
else:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": CONFIG.max_tokens_rsp,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3
|
||||
"temperature": 0.3,
|
||||
}
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.get('usage'))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
|
||||
def _chat_completion(self, messages: list[dict]) -> dict:
|
||||
|
|
@ -262,3 +266,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
||||
def get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
|
|
|||
93
metagpt/roles/researcher.py
Normal file
93
metagpt/roles/researcher.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import CollectLinks, ConductResearch, WebBrowseAndSummarize
|
||||
from metagpt.actions.research import get_research_system_text
|
||||
from metagpt.const import RESEARCH_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class Report(BaseModel):
|
||||
topic: str
|
||||
links: dict[str, list[str]] = None
|
||||
summaries: list[tuple[str, str]] = None
|
||||
content: str = ""
|
||||
|
||||
|
||||
class Researcher(Role):
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "David",
|
||||
profile: str = "Researcher",
|
||||
goal: str = "Gather information and conduct research",
|
||||
constraints: str = "Ensure accuracy and relevance of information",
|
||||
language: str = "en-us",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(name, profile, goal, constraints, **kwargs)
|
||||
self._init_actions([CollectLinks(name), WebBrowseAndSummarize(name), ConductResearch(name)])
|
||||
self.language = language
|
||||
if language not in ("en-us", "zh-cn"):
|
||||
logger.warning(f"The language `{language}` has not been tested, it may not work.")
|
||||
|
||||
async def _think(self) -> None:
|
||||
if self._rc.todo is None:
|
||||
self._set_state(0)
|
||||
return
|
||||
|
||||
if self._rc.state + 1 < len(self._states):
|
||||
self._set_state(self._rc.state + 1)
|
||||
else:
|
||||
self._rc.todo = None
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: ready to {self._rc.todo}")
|
||||
todo = self._rc.todo
|
||||
msg = self._rc.memory.get(k=1)[0]
|
||||
if isinstance(msg.instruct_content, Report):
|
||||
instruct_content = msg.instruct_content
|
||||
topic = instruct_content.topic
|
||||
else:
|
||||
topic = msg.content
|
||||
|
||||
research_system_text = get_research_system_text(topic, self.language)
|
||||
if isinstance(todo, CollectLinks):
|
||||
links = await todo.run(topic, 4, 4)
|
||||
ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo))
|
||||
elif isinstance(todo, WebBrowseAndSummarize):
|
||||
links = instruct_content.links
|
||||
todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
|
||||
summaries = await asyncio.gather(*todos)
|
||||
summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary)
|
||||
ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=type(todo))
|
||||
else:
|
||||
summaries = instruct_content.summaries
|
||||
summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
|
||||
content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
|
||||
ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=type(self._rc.todo))
|
||||
self._rc.memory.add(ret)
|
||||
return ret
|
||||
|
||||
async def _react(self) -> Message:
|
||||
while True:
|
||||
await self._think()
|
||||
if self._rc.todo is None:
|
||||
break
|
||||
msg = await self._act()
|
||||
report = msg.instruct_content
|
||||
self.write_report(report.topic, report.content)
|
||||
return msg
|
||||
|
||||
def write_report(self, topic: str, content: str):
|
||||
filepath = RESEARCH_PATH / f"{topic}.md"
|
||||
filepath.write_text(content)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
role = Researcher(language="en-us")
|
||||
asyncio.run(role.run("dataiku vs. datarobot"))
|
||||
|
|
@ -14,6 +14,7 @@ class SearchEngineType(Enum):
|
|||
SERPAPI_GOOGLE = auto()
|
||||
DIRECT_GOOGLE = auto()
|
||||
SERPER_GOOGLE = auto()
|
||||
DUCK_DUCK_GO = auto()
|
||||
CUSTOM_ENGINE = auto()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,122 +7,76 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
|
||||
from metagpt.tools.search_engine_serper import SerperWrapper
|
||||
|
||||
config = Config()
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""
|
||||
TODO: 合入Google Search 并进行反代
|
||||
注:这里Google需要挂Proxifier或者类似全局代理
|
||||
- DDG: https://pypi.org/project/duckduckgo-search/
|
||||
- GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
|
||||
"""
|
||||
def __init__(self, engine=None, run_func=None):
|
||||
self.config = Config()
|
||||
self.run_func = run_func
|
||||
self.engine = engine or self.config.search_engine
|
||||
"""Class representing a search engine.
|
||||
|
||||
@classmethod
|
||||
def run_google(cls, query, max_results=8):
|
||||
# results = ddg(query, max_results=max_results)
|
||||
results = google_official_search(query, num_results=max_results)
|
||||
logger.info(results)
|
||||
return results
|
||||
Args:
|
||||
engine: The search engine type. Defaults to the search engine specified in the config.
|
||||
run_func: The function to run the search. Defaults to None.
|
||||
|
||||
async def run(self, query: str, max_results=8):
|
||||
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
api = SerpAPIWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
rsp = SearchEngine.run_google(query, max_results)
|
||||
elif self.engine == SearchEngineType.SERPER_GOOGLE:
|
||||
api = SerperWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
rsp = self.run_func(query)
|
||||
Attributes:
|
||||
run_func: The function to run the search.
|
||||
engine: The search engine type.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
engine: SearchEngineType | None = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
|
||||
):
|
||||
engine = engine or CONFIG.search_engine
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serpapi"
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper().run
|
||||
elif engine == SearchEngineType.SERPER_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serper"
|
||||
run_func = importlib.import_module(module).SerperWrapper().run
|
||||
elif engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_googleapi"
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper().run
|
||||
elif engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
module = "metagpt.tools.search_engine_ddg"
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper().run
|
||||
elif engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
pass # run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return rsp
|
||||
self.engine = engine
|
||||
self.run_func = run_func
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
"""Run a search query.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The maximum number of results to return. Defaults to 8.
|
||||
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
try:
|
||||
api_key = config.google_api_key
|
||||
custom_search_engine_id = config.google_cse_id
|
||||
|
||||
with build("customsearch", "v1", developerKey=api_key) as service:
|
||||
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
logger.info(result)
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
):
|
||||
return "Error: The provided Google API key is invalid or missing."
|
||||
else:
|
||||
return f"Error: {e}"
|
||||
# google_result can be a list or a string depending on the search results
|
||||
|
||||
# Return the list of search result URLs
|
||||
return search_results_details
|
||||
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a google search in a safe format.
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
# FIXME: # .encode("utf-8", "ignore") 这里去掉了,但是AutoGPT里有,很奇怪
|
||||
[result for result in results]
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
SearchEngine.run(query='wtf')
|
||||
Returns:
|
||||
The search results as a string or a list of dictionaries.
|
||||
"""
|
||||
return await self.run_func(query, max_results=max_results, as_string=as_string)
|
||||
|
|
|
|||
107
metagpt/tools/search_engine_ddg.py
Normal file
107
metagpt/tools/search_engine_ddg.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from typing import Literal, overload
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class DDGAPIWrapper:
|
||||
"""Wrapper around duckduckgo_search API.
|
||||
|
||||
To use this module, you should have the `duckduckgo_search` Python package installed.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
kwargs = {}
|
||||
if CONFIG.global_proxy:
|
||||
kwargs["proxies"] = CONFIG.global_proxy
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
self.ddgs = DDGS(**kwargs)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
focus: list[str] | None = None,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
focus: list[str] | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: bool = True,
|
||||
) -> str | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The number of results to return.
|
||||
as_string: A boolean flag to determine the return type of the results. If True, the function will
|
||||
return a formatted string with the search results. If False, it will return a list of dictionaries
|
||||
containing detailed information about each search result.
|
||||
|
||||
Returns:
|
||||
The results of the search.
|
||||
"""
|
||||
loop = self.loop or asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
self.executor,
|
||||
self._search_from_ddgs,
|
||||
query,
|
||||
max_results,
|
||||
)
|
||||
try:
|
||||
search_results = await future
|
||||
# Extract the search result items from the response
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
|
||||
# Return the list of search result URLs
|
||||
if as_string:
|
||||
return json.dumps(search_results, ensure_ascii=False)
|
||||
return search_results
|
||||
|
||||
def _search_from_ddgs(self, query: str, max_results: int):
|
||||
return [
|
||||
{
|
||||
"link": i["href"],
|
||||
"snippet": i["body"],
|
||||
"title": i["title"]
|
||||
} for (_, i) in zip(range(max_results), self.ddgs.text(query))
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(DDGAPIWrapper().run)
|
||||
117
metagpt/tools/search_engine_googleapi.py
Normal file
117
metagpt/tools/search_engine_googleapi.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httplib2
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class GoogleAPIWrapper:
|
||||
"""Wrapper around GoogleAPI.
|
||||
|
||||
To use this module, you should have the `google-api-python-client` Python package installed
|
||||
and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
|
||||
https://programmablesearchengine.google.com/controlpanel/all.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
build_kwargs = {"developerKey": CONFIG.google_api_key}
|
||||
if CONFIG.global_proxy:
|
||||
parse_result = urlparse(CONFIG.global_proxy)
|
||||
proxy_type = parse_result.scheme
|
||||
if proxy_type == "https":
|
||||
proxy_type = "http"
|
||||
build_kwargs["http"] = httplib2.Http(
|
||||
proxy_info=httplib2.ProxyInfo(
|
||||
getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"),
|
||||
parse_result.hostname,
|
||||
parse_result.port,
|
||||
),
|
||||
)
|
||||
service = build("customsearch", "v1", **build_kwargs)
|
||||
self.google_api_client = service.cse()
|
||||
self.custom_search_engine_id = CONFIG.google_cse_id
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: bool = True,
|
||||
focus: list[str] | None = None,
|
||||
) -> str | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The number of results to return.
|
||||
as_string: A boolean flag to determine the return type of the results. If True, the function will
|
||||
return a formatted string with the search results. If False, it will return a list of dictionaries
|
||||
containing detailed information about each search result.
|
||||
focus: Specific information to be focused on from each search result.
|
||||
|
||||
Returns:
|
||||
The results of the search.
|
||||
"""
|
||||
loop = self.loop or asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
self.executor,
|
||||
self.google_api_client.list(
|
||||
q=query,
|
||||
num=max_results,
|
||||
cx=self.custom_search_engine_id
|
||||
).execute
|
||||
)
|
||||
try:
|
||||
result = await future
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
|
||||
focus = focus or ["snippet", "link", "title"]
|
||||
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
# Return the list of search result URLs
|
||||
if as_string:
|
||||
return safe_google_results(details)
|
||||
|
||||
return details
|
||||
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""Return the results of a google search in a safe format.
|
||||
|
||||
Args:
|
||||
results: The search results.
|
||||
|
||||
Returns:
|
||||
The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps([result for result in results])
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(GoogleAPIWrapper().run)
|
||||
|
|
@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
return self._process_response(await self.results(query))
|
||||
return self._process_response(await self.results(query, max_results), as_string=as_string)
|
||||
|
||||
async def results(self, query: str) -> dict:
|
||||
async def results(self, query: str, max_results: int) -> dict:
|
||||
"""Use aiohttp to run query through SerpAPI and return the results async."""
|
||||
|
||||
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
|
||||
params = self.get_params(query)
|
||||
params["source"] = "python"
|
||||
params["num"] = max_results
|
||||
if self.serpapi_api_key:
|
||||
params["serp_api_key"] = self.serpapi_api_key
|
||||
params["output"] = "json"
|
||||
|
|
@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
return params
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
focus = ["title", "snippet", "link"]
|
||||
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
|
||||
|
||||
if "error" in res.keys():
|
||||
|
|
@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
toret = res["answer_box"]["answer"]
|
||||
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet"]
|
||||
elif (
|
||||
"answer_box" in res.keys()
|
||||
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||
):
|
||||
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||
elif (
|
||||
"sports_results" in res.keys()
|
||||
and "game_spotlight" in res["sports_results"].keys()
|
||||
):
|
||||
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
|
||||
toret = res["sports_results"]["game_spotlight"]
|
||||
elif (
|
||||
"knowledge_graph" in res.keys()
|
||||
and "description" in res["knowledge_graph"].keys()
|
||||
):
|
||||
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
|
||||
toret = res["knowledge_graph"]["description"]
|
||||
elif "snippet" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["snippet"]
|
||||
|
|
@ -112,4 +104,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
if res.get("organic_results"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic_results")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerpAPIWrapper().run)
|
||||
|
|
|
|||
|
|
@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through Serper and parse result async."""
|
||||
queries = query.split("\n")
|
||||
return "\n".join([self._process_response(res) for res in await self.results(queries)])
|
||||
if isinstance(query, str):
|
||||
return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
|
||||
else:
|
||||
results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
|
||||
return "\n".join(results) if as_string else results
|
||||
|
||||
async def results(self, queries: list[str]) -> dict:
|
||||
async def results(self, queries: list[str], max_results: int = 8) -> dict:
|
||||
"""Use aiohttp to run query through Serper and return the results async."""
|
||||
|
||||
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
|
||||
payloads = self.get_payloads(queries)
|
||||
payloads = self.get_payloads(queries, max_results)
|
||||
url = "https://google.serper.dev/search"
|
||||
headers = self.get_headers()
|
||||
return url, payloads, headers
|
||||
|
|
@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
|
|||
|
||||
return res
|
||||
|
||||
def get_payloads(self, queries: list[str]) -> Dict[str, str]:
|
||||
def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
|
||||
"""Get payloads for Serper."""
|
||||
payloads = []
|
||||
for query in queries:
|
||||
_payload = {
|
||||
"q": query,
|
||||
"num": max_results,
|
||||
}
|
||||
payloads.append({**self.payload, **_payload})
|
||||
return json.dumps(payloads, sort_keys=True)
|
||||
|
|
@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
|
|||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool = False) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
|
|
@ -117,4 +121,10 @@ class SerperWrapper(BaseModel):
|
|||
if res.get("organic"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerperWrapper().run)
|
||||
|
|
|
|||
|
|
@ -1,22 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import importlib
|
||||
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
import importlib
|
||||
from typing import Any, Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from bs4 import BeautifulSoup
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class WebBrowserEngine:
|
||||
def __init__(
|
||||
self,
|
||||
engine: WebBrowserEngineType | None = None,
|
||||
run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None,
|
||||
parse_func: Callable[[str], str] | None = None,
|
||||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
engine = engine or CONFIG.web_browser_engine
|
||||
|
||||
|
|
@ -30,30 +28,25 @@ class WebBrowserEngine:
|
|||
run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.parse_func = parse_func or get_page_content
|
||||
self.run_func = run_func
|
||||
self.engine = engine
|
||||
|
||||
@overload
|
||||
async def run(self, url: str) -> str:
|
||||
async def run(self, url: str) -> WebPage:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def run(self, url: str, *urls: str) -> list[str]:
|
||||
async def run(self, url: str, *urls: str) -> list[WebPage]:
|
||||
...
|
||||
|
||||
async def run(self, url: str, *urls: str) -> str | list[str]:
|
||||
page = await self.run_func(url, *urls)
|
||||
if isinstance(page, str):
|
||||
return self.parse_func(page)
|
||||
return [self.parse_func(i) for i in page]
|
||||
|
||||
|
||||
def get_page_content(page: str):
|
||||
soup = BeautifulSoup(page, "html.parser")
|
||||
return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"]))
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
return await self.run_func(url, *urls)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/"))
|
||||
print(text)
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
|
||||
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -2,12 +2,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class PlaywrightWrapper:
|
||||
|
|
@ -16,7 +19,7 @@ class PlaywrightWrapper:
|
|||
To use this module, you should have the `playwright` Python package installed and ensure that
|
||||
the required browsers are also installed. You can install playwright by running the command
|
||||
`pip install metagpt[playwright]` and download the necessary browser binaries by running the
|
||||
command `playwright install` for the first time."
|
||||
command `playwright install` for the first time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -40,27 +43,30 @@ class PlaywrightWrapper:
|
|||
self._context_kwargs = context_kwargs
|
||||
self._has_run_precheck = False
|
||||
|
||||
async def run(self, url: str, *urls: str) -> str | list[str]:
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
async with async_playwright() as ap:
|
||||
browser_type = getattr(ap, self.browser_type)
|
||||
await self._run_precheck(browser_type)
|
||||
browser = await browser_type.launch(**self.launch_kwargs)
|
||||
|
||||
async def _scrape(url):
|
||||
context = await browser.new_context(**self._context_kwargs)
|
||||
page = await context.new_page()
|
||||
async with page:
|
||||
try:
|
||||
await page.goto(url)
|
||||
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||
content = await page.content()
|
||||
return content
|
||||
except Exception as e:
|
||||
return f"Fail to load page content for {e}"
|
||||
_scrape = self._scrape
|
||||
|
||||
if urls:
|
||||
return await asyncio.gather(_scrape(url), *(_scrape(i) for i in urls))
|
||||
return await _scrape(url)
|
||||
return await asyncio.gather(_scrape(browser, url), *(_scrape(browser, i) for i in urls))
|
||||
return await _scrape(browser, url)
|
||||
|
||||
async def _scrape(self, browser, url):
|
||||
context = await browser.new_context(**self._context_kwargs)
|
||||
page = await context.new_page()
|
||||
async with page:
|
||||
try:
|
||||
await page.goto(url)
|
||||
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||
html = await page.content()
|
||||
inner_text = await page.evaluate("() => document.body.innerText")
|
||||
except Exception as e:
|
||||
inner_text = f"Fail to load page content for {e}"
|
||||
html = ""
|
||||
return WebPage(inner_text=inner_text, html=html, url=url)
|
||||
|
||||
async def _run_precheck(self, browser_type):
|
||||
if self._has_run_precheck:
|
||||
|
|
@ -72,6 +78,10 @@ class PlaywrightWrapper:
|
|||
if CONFIG.global_proxy:
|
||||
kwargs["env"] = {"ALL_PROXY": CONFIG.global_proxy}
|
||||
await _install_browsers(self.browser_type, **kwargs)
|
||||
|
||||
if self._has_run_precheck:
|
||||
return
|
||||
|
||||
if not executable_path.exists():
|
||||
parts = executable_path.parts
|
||||
available_paths = list(Path(*parts[:-3]).glob(f"{self.browser_type}-*"))
|
||||
|
|
@ -85,25 +95,37 @@ class PlaywrightWrapper:
|
|||
self._has_run_precheck = True
|
||||
|
||||
|
||||
def _get_install_lock():
|
||||
global _install_lock
|
||||
if _install_lock is None:
|
||||
_install_lock = asyncio.Lock()
|
||||
return _install_lock
|
||||
|
||||
|
||||
async def _install_browsers(*browsers, **kwargs) -> None:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
"-m",
|
||||
"playwright",
|
||||
"install",
|
||||
*browsers,
|
||||
"--with-deps",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
**kwargs,
|
||||
)
|
||||
async with _get_install_lock():
|
||||
browsers = [i for i in browsers if i not in _install_cache]
|
||||
if not browsers:
|
||||
return
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
"-m",
|
||||
"playwright",
|
||||
"install",
|
||||
*browsers,
|
||||
# "--with-deps",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning))
|
||||
await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning))
|
||||
|
||||
if await process.wait() == 0:
|
||||
logger.info(f"Install browser for playwright successfully.")
|
||||
else:
|
||||
logger.warning(f"Fail to install browser for playwright.")
|
||||
if await process.wait() == 0:
|
||||
logger.info("Install browser for playwright successfully.")
|
||||
else:
|
||||
logger.warning("Fail to install browser for playwright.")
|
||||
_install_cache.update(browsers)
|
||||
|
||||
|
||||
async def _log_stream(sr, log_func):
|
||||
|
|
@ -114,8 +136,14 @@ async def _log_stream(sr, log_func):
|
|||
log_func(f"[playwright install browser]: {line.decode().strip()}")
|
||||
|
||||
|
||||
_install_lock: asyncio.Lock = None
|
||||
_install_cache = set()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
for i in ("chromium", "firefox", "webkit"):
|
||||
text = asyncio.run(PlaywrightWrapper(i).run("https://httpbin.org/ip"))
|
||||
print(text)
|
||||
print(i)
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs):
|
||||
return await PlaywrightWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -2,16 +2,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
import importlib
|
||||
from concurrent import futures
|
||||
from copy import deepcopy
|
||||
from typing import Literal
|
||||
from metagpt.config import CONFIG
|
||||
import asyncio
|
||||
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from concurrent import futures
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class SeleniumWrapper:
|
||||
|
|
@ -48,7 +49,7 @@ class SeleniumWrapper:
|
|||
self.loop = loop
|
||||
self.executor = executor
|
||||
|
||||
async def run(self, url: str, *urls: str) -> str | list[str]:
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
await self._run_precheck()
|
||||
|
||||
_scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url)
|
||||
|
|
@ -69,9 +70,15 @@ class SeleniumWrapper:
|
|||
|
||||
def _scrape_website(self, url):
|
||||
with self._get_driver() as driver:
|
||||
driver.get(url)
|
||||
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
|
||||
return driver.page_source
|
||||
try:
|
||||
driver.get(url)
|
||||
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
|
||||
inner_text = driver.execute_script("return document.body.innerText;")
|
||||
html = driver.page_source
|
||||
except Exception as e:
|
||||
inner_text = f"Fail to load page content for {e}"
|
||||
html = ""
|
||||
return WebPage(inner_text=inner_text, html=html, url=url)
|
||||
|
||||
|
||||
_webdriver_manager_types = {
|
||||
|
|
@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
def _get_driver():
|
||||
options = Options()
|
||||
options.add_argument("--headless")
|
||||
options.add_argument("--enable-javascript")
|
||||
if browser_type == "chrome":
|
||||
options.add_argument("--no-sandbox")
|
||||
for i in args:
|
||||
|
|
@ -107,5 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/"))
|
||||
print(text)
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
|
||||
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
57
metagpt/utils/parse_html.py
Normal file
57
metagpt/utils/parse_html.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Generator, Optional
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class WebPage(BaseModel):
|
||||
inner_text: str
|
||||
html: str
|
||||
url: str
|
||||
|
||||
class Config:
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
_soup : Optional[BeautifulSoup] = None
|
||||
_title: Optional[str] = None
|
||||
|
||||
@property
|
||||
def soup(self) -> BeautifulSoup:
|
||||
if self._soup is None:
|
||||
self._soup = BeautifulSoup(self.html, "html.parser")
|
||||
return self._soup
|
||||
|
||||
@property
|
||||
def title(self):
|
||||
if self._title is None:
|
||||
title_tag = self.soup.find("title")
|
||||
self._title = title_tag.text.strip() if title_tag is not None else ""
|
||||
return self._title
|
||||
|
||||
def get_links(self) -> Generator[str, None, None]:
|
||||
for i in self.soup.find_all("a", href=True):
|
||||
url = i["href"]
|
||||
result = urlparse(url)
|
||||
if not result.scheme and result.path:
|
||||
yield urljoin(self.url, url)
|
||||
elif url.startswith(("http://", "https://")):
|
||||
yield urljoin(self.url, url)
|
||||
|
||||
|
||||
def get_html_content(page: str, base: str):
|
||||
soup = _get_soup(page)
|
||||
|
||||
return soup.get_text(strip=True)
|
||||
|
||||
|
||||
def _get_soup(page: str):
|
||||
soup = BeautifulSoup(page, "html.parser")
|
||||
# https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
|
||||
for s in soup(["style", "script", "[document]", "head", "title"]):
|
||||
s.extract()
|
||||
|
||||
return soup
|
||||
124
metagpt/utils/text.py
Normal file
124
metagpt/utils/text.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
from typing import Generator, Sequence
|
||||
|
||||
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
|
||||
|
||||
|
||||
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
|
||||
"""Reduce the length of concatenated message segments to fit within the maximum token size.
|
||||
|
||||
Args:
|
||||
msgs: A generator of strings representing progressively shorter valid prompts.
|
||||
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
|
||||
system_text: The system prompts.
|
||||
reserved: The number of reserved tokens.
|
||||
|
||||
Returns:
|
||||
The concatenated message segments reduced to fit within the maximum token size.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If it fails to reduce the concatenated message length.
|
||||
"""
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
|
||||
for msg in msgs:
|
||||
if count_string_tokens(msg, model_name) < max_token:
|
||||
return msg
|
||||
|
||||
raise RuntimeError("fail to reduce message length")
|
||||
|
||||
|
||||
def generate_prompt_chunk(
|
||||
text: str,
|
||||
prompt_template: str,
|
||||
model_name: str,
|
||||
system_text: str,
|
||||
reserved: int = 0,
|
||||
) -> Generator[str, None, None]:
|
||||
"""Split the text into chunks of a maximum token size.
|
||||
|
||||
Args:
|
||||
text: The text to split.
|
||||
prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}".
|
||||
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
|
||||
system_text: The system prompts.
|
||||
reserved: The number of reserved tokens.
|
||||
|
||||
Yields:
|
||||
The chunk of text.
|
||||
"""
|
||||
paragraphs = text.splitlines(keepends=True)
|
||||
current_token = 0
|
||||
current_lines = []
|
||||
|
||||
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
|
||||
# 100 is a magic number to ensure the maximum context length is not exceeded
|
||||
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
|
||||
|
||||
while paragraphs:
|
||||
paragraph = paragraphs.pop(0)
|
||||
token = count_string_tokens(paragraph, model_name)
|
||||
if current_token + token <= max_token:
|
||||
current_lines.append(paragraph)
|
||||
current_token += token
|
||||
elif token > max_token:
|
||||
paragraphs = split_paragraph(paragraph) + paragraphs
|
||||
continue
|
||||
else:
|
||||
yield prompt_template.format("".join(current_lines))
|
||||
current_lines = [paragraph]
|
||||
current_token = token
|
||||
|
||||
if current_lines:
|
||||
yield prompt_template.format("".join(current_lines))
|
||||
|
||||
|
||||
def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
|
||||
"""Split a paragraph into multiple parts.
|
||||
|
||||
Args:
|
||||
paragraph: The paragraph to split.
|
||||
sep: The separator character.
|
||||
count: The number of parts to split the paragraph into.
|
||||
|
||||
Returns:
|
||||
A list of split parts of the paragraph.
|
||||
"""
|
||||
for i in sep:
|
||||
sentences = list(_split_text_with_ends(paragraph, i))
|
||||
if len(sentences) <= 1:
|
||||
continue
|
||||
ret = ["".join(j) for j in _split_by_count(sentences, count)]
|
||||
return ret
|
||||
return _split_by_count(paragraph, count)
|
||||
|
||||
|
||||
def decode_unicode_escape(text: str) -> str:
|
||||
"""Decode a text with unicode escape sequences.
|
||||
|
||||
Args:
|
||||
text: The text to decode.
|
||||
|
||||
Returns:
|
||||
The decoded text.
|
||||
"""
|
||||
return text.encode("utf-8").decode("unicode_escape", "ignore")
|
||||
|
||||
|
||||
def _split_by_count(lst: Sequence , count: int):
|
||||
avg = len(lst) // count
|
||||
remainder = len(lst) % count
|
||||
start = 0
|
||||
for i in range(count):
|
||||
end = start + avg + (1 if i < remainder else 0)
|
||||
yield lst[start:end]
|
||||
start = end
|
||||
|
||||
|
||||
def _split_text_with_ends(text: str, sep: str = "."):
|
||||
parts = []
|
||||
for i in text:
|
||||
parts.append(i)
|
||||
if i == sep:
|
||||
yield "".join(parts)
|
||||
parts = []
|
||||
if parts:
|
||||
yield "".join(parts)
|
||||
|
|
@ -25,6 +25,21 @@ TOKEN_COSTS = {
|
|||
}
|
||||
|
||||
|
||||
TOKEN_MAX = {
|
||||
"gpt-3.5-turbo": 4096,
|
||||
"gpt-3.5-turbo-0301": 4096,
|
||||
"gpt-3.5-turbo-0613": 4096,
|
||||
"gpt-3.5-turbo-16k": 16384,
|
||||
"gpt-3.5-turbo-16k-0613": 16384,
|
||||
"gpt-4-0314": 8192,
|
||||
"gpt-4": 8192,
|
||||
"gpt-4-32k": 32768,
|
||||
"gpt-4-32k-0314": 32768,
|
||||
"gpt-4-0613": 8192,
|
||||
"text-embedding-ada-002": 8192,
|
||||
}
|
||||
|
||||
|
||||
def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
||||
"""Return the number of tokens used by a list of messages."""
|
||||
try:
|
||||
|
|
@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
|
|||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
}:
|
||||
}:
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
|
|
@ -79,3 +94,18 @@ def count_string_tokens(string: str, model_name: str) -> int:
|
|||
"""
|
||||
encoding = tiktoken.encoding_for_model(model_name)
|
||||
return len(encoding.encode(string))
|
||||
|
||||
|
||||
def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int:
|
||||
"""Calculate the maximum number of completion tokens for a given model and list of messages.
|
||||
|
||||
Args:
|
||||
messages: A list of messages.
|
||||
model: The model name.
|
||||
|
||||
Returns:
|
||||
The maximum number of completion tokens.
|
||||
"""
|
||||
if model not in TOKEN_MAX:
|
||||
return default
|
||||
return TOKEN_MAX[model] - count_message_tokens(messages)
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -44,7 +44,7 @@ setup(
|
|||
install_requires=requirements,
|
||||
extras_require={
|
||||
"playwright": ["playwright>=1.26", "beautifulsoup4"],
|
||||
"selenium": ["selenium>4", "webdriver_manager<3.9", "beautifulsoup4"],
|
||||
"selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"],
|
||||
},
|
||||
cmdclass={
|
||||
"install_mermaid": InstallMermaidCLI,
|
||||
|
|
|
|||
32
tests/metagpt/roles/test_researcher.py
Normal file
32
tests/metagpt/roles/test_researcher.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
from pathlib import Path
|
||||
from random import random
|
||||
from tempfile import TemporaryDirectory
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.roles import researcher
|
||||
|
||||
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["dataiku", "datarobot"]'
|
||||
elif "Provide up to 4 queries related to your research topic" in prompt:
|
||||
return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \
|
||||
'"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]'
|
||||
elif "sort the remaining search results" in prompt:
|
||||
return '[1,2]'
|
||||
elif "Not relevant." in prompt:
|
||||
return "Not relevant" if random() > 0.5 else prompt[-100:]
|
||||
elif "provide a detailed research report" in prompt:
|
||||
return f"# Research Report\n## Introduction\n{prompt}"
|
||||
return ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_researcher(mocker):
|
||||
with TemporaryDirectory() as dirname:
|
||||
topic = "dataiku vs. datarobot"
|
||||
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
await researcher.Researcher().run(topic)
|
||||
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
|
||||
|
|
@ -5,24 +5,44 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_search_engine.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
class MockSearchEnine:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
|
||||
return "\n".join(rets) if as_string else rets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("llm_api")
|
||||
async def test_search_engine(llm_api):
|
||||
search_engine = SearchEngine()
|
||||
poetries = [
|
||||
# ("北京美食", "北京"),
|
||||
("屈臣氏", "屈臣氏")
|
||||
]
|
||||
for i, j in poetries:
|
||||
rsp = await search_engine.run(i)
|
||||
# rsp = context.llm.ask_batch([prompt])
|
||||
logger.info(rsp)
|
||||
# assert any(j in k['body'] for k in rsp)
|
||||
assert len(rsp) > 0
|
||||
@pytest.mark.parametrize(
|
||||
("search_engine_typpe", "run_func", "max_results", "as_string"),
|
||||
[
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 4, False),
|
||||
(SearchEngineType.DIRECT_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.DIRECT_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.SERPER_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
|
||||
],
|
||||
)
|
||||
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ):
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
assert isinstance(rsp, str)
|
||||
else:
|
||||
assert isinstance(rsp, list)
|
||||
assert len(rsp) == max_results
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import web_browser_engine_playwright
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
|
|||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
assert "Deepwisdom" in result
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import web_browser_engine_selenium
|
||||
|
||||
|
|
@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
|
|||
CONFIG.global_proxy = proxy
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type)
|
||||
result = await browser.run(url)
|
||||
result = result.inner_text
|
||||
assert isinstance(result, str)
|
||||
assert "Deepwisdom" in result
|
||||
|
||||
|
|
@ -27,7 +29,7 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd)
|
|||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("Deepwisdom" in i) for i in results)
|
||||
assert all(("Deepwisdom" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
|
|
|
|||
68
tests/metagpt/utils/test_parse_html.py
Normal file
68
tests/metagpt/utils/test_parse_html.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from metagpt.utils import parse_html
|
||||
|
||||
PAGE = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Random HTML Example</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>This is a Heading</h1>
|
||||
<p>This is a paragraph with <a href="test">a link</a> and some <em>emphasized</em> text.</p>
|
||||
<ul>
|
||||
<li>Item 1</li>
|
||||
<li>Item 2</li>
|
||||
<li>Item 3</li>
|
||||
</ul>
|
||||
<ol>
|
||||
<li>Numbered Item 1</li>
|
||||
<li>Numbered Item 2</li>
|
||||
<li>Numbered Item 3</li>
|
||||
</ol>
|
||||
<table>
|
||||
<tr>
|
||||
<th>Header 1</th>
|
||||
<th>Header 2</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Row 1, Cell 1</td>
|
||||
<td>Row 1, Cell 2</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>Row 2, Cell 1</td>
|
||||
<td>Row 2, Cell 2</td>
|
||||
</tr>
|
||||
</table>
|
||||
<img src="image.jpg" alt="Sample Image">
|
||||
<form action="/submit" method="post">
|
||||
<label for="name">Name:</label>
|
||||
<input type="text" id="name" name="name" required>
|
||||
<label for="email">Email:</label>
|
||||
<input type="email" id="email" name="email" required>
|
||||
<button type="submit">Submit</button>
|
||||
</form>
|
||||
<div class="box">
|
||||
<p>This is a div with a class "box".</p>
|
||||
<p><a href="https://metagpt.com">a link</a></p>
|
||||
<p><a href="#section2"></a></p>
|
||||
<p><a href="ftp://192.168.1.1:8080"></a></p>
|
||||
<p><a href="javascript:alert('Hello');"></a></p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
|
||||
'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
|
||||
'with a class "box".a link'
|
||||
|
||||
|
||||
def test_web_page():
|
||||
page = parse_html.WebPage(inner_text=CONTENT, html=PAGE, url="http://example.com")
|
||||
assert page.title == "Random HTML Example"
|
||||
assert list(page.get_links()) == ["http://example.com/test", "https://metagpt.com"]
|
||||
|
||||
|
||||
def test_get_page_content():
|
||||
ret = parse_html.get_html_content(PAGE, "http://example.com")
|
||||
assert ret == CONTENT
|
||||
77
tests/metagpt/utils/test_text.py
Normal file
77
tests/metagpt/utils/test_text.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.utils.text import (
|
||||
decode_unicode_escape,
|
||||
generate_prompt_chunk,
|
||||
reduce_message_length,
|
||||
split_paragraph,
|
||||
)
|
||||
|
||||
|
||||
def _msgs():
|
||||
length = 20
|
||||
while length:
|
||||
yield "Hello," * 1000 * length
|
||||
length -= 1
|
||||
|
||||
|
||||
def _paragraphs(n):
|
||||
return " ".join("Hello World." for _ in range(n))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"msgs, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
|
||||
(_msgs(), "gpt-4", "System", 2000, 3),
|
||||
(_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
|
||||
(_msgs(), "gpt-4-32k", "System", 4000, 14),
|
||||
(_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
|
||||
]
|
||||
)
|
||||
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
|
||||
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, prompt_template, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
|
||||
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
|
||||
]
|
||||
)
|
||||
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
|
||||
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
|
||||
assert len(ret) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"paragraph, sep, count, expected",
|
||||
[
|
||||
(_paragraphs(10), ".", 2, [_paragraphs(5), f" {_paragraphs(5)}"]),
|
||||
(_paragraphs(10), ".", 3, [_paragraphs(4), f" {_paragraphs(3)}", f" {_paragraphs(3)}"]),
|
||||
(f"{_paragraphs(5)}\n{_paragraphs(3)}", "\n.", 2, [f"{_paragraphs(5)}\n", _paragraphs(3)]),
|
||||
("......", ".", 2, ["...", "..."]),
|
||||
("......", ".", 3, ["..", "..", ".."]),
|
||||
(".......", ".", 2, ["....", "..."]),
|
||||
]
|
||||
)
|
||||
def test_split_paragraph(paragraph, sep, count, expected):
|
||||
ret = split_paragraph(paragraph, sep, count)
|
||||
assert ret == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, expected",
|
||||
[
|
||||
("Hello\\nWorld", "Hello\nWorld"),
|
||||
("Hello\\tWorld", "Hello\tWorld"),
|
||||
("Hello\\u0020World", "Hello World"),
|
||||
]
|
||||
)
|
||||
def test_decode_unicode_escape(text, expected):
|
||||
assert decode_unicode_escape(text) == expected
|
||||
Loading…
Add table
Add a link
Reference in a new issue