diff --git a/config/config.yaml b/config/config.yaml index ceab18854..590ef2561 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -65,4 +65,8 @@ SD_T2I_API: "/sdapi/v1/txt2img" ### for update_costs & calc_usage UPDATE_COSTS: false -CALC_USAGE: false \ No newline at end of file +CALC_USAGE: false + +### for Research +MODEL_FOR_RESEARCHER_SUMMARY: gpt-3.5-turbo +MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k diff --git a/examples/research.py b/examples/research.py new file mode 100644 index 000000000..344f8d0e9 --- /dev/null +++ b/examples/research.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python + +import asyncio + +from metagpt.roles.researcher import RESEARCH_PATH, Researcher + + +async def main(): + topic = "dataiku vs. datarobot" + role = Researcher(language="en-us") + await role.run(topic) + print(f"save report to {RESEARCH_PATH / f'{topic}.md'}.") + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index 0c861aa69..c56f25e31 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -22,6 +22,7 @@ from metagpt.actions.write_code_review import WriteCodeReview from metagpt.actions.write_prd import WritePRD from metagpt.actions.write_prd_review import WritePRDReview from metagpt.actions.write_test import WriteTest +from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch class ActionType(Enum): @@ -40,3 +41,6 @@ class ActionType(Enum): WRITE_TASKS = WriteTasks ASSIGN_TASKS = AssignTasks SEARCH_AND_SUMMARIZE = SearchAndSummarize + COLLECT_LINKS = CollectLinks + WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize + CONDUCT_RESEARCH = ConductResearch diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py new file mode 100644 index 000000000..81eb876dd --- /dev/null +++ b/metagpt/actions/research.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import asyncio +import json +from typing import 
Callable + +from pydantic import parse_obj_as + +from metagpt.actions import Action +from metagpt.config import CONFIG +from metagpt.logs import logger +from metagpt.tools.search_engine import SearchEngine +from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType +from metagpt.utils.text import generate_prompt_chunk, reduce_message_length + +LANG_PROMPT = "Please respond in {language}." + +RESEARCH_BASE_SYSTEM = """You are an AI critical thinker research assistant. Your sole purpose is to write well \ +written, critically acclaimed, objective and structured reports on the given text.""" + +RESEARCH_TOPIC_SYSTEM = "You are an AI researcher assistant, and your research topic is:\n#TOPIC#\n{topic}" + +SEARCH_TOPIC_PROMPT = """Please provide up to 2 necessary keywords related to your research topic for Google search. \ +Your response must be in JSON format, for example: ["keyword1", "keyword2"].""" + +SUMMARIZE_SEARCH_PROMPT = """### Requirements +1. The keywords related to your research topic and the search results are shown in the "Search Result Information" section. +2. Provide up to {decomposition_nums} queries related to your research topic base on the search results. +3. Please respond in the following JSON format: ["query1", "query2", "query3", ...]. + +### Search Result Information +{search_results} +""" + +COLLECT_AND_RANKURLS_PROMPT = """### Topic +{topic} +### Query +{query} + +### The online search results +{results} + +### Requirements +Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \ +based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. Provide the +ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words. +""" + +WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements +1. 
class CollectLinks(Action):
    """Action class to collect links from a search engine.

    The action asks the LLM for search keywords, searches them, asks the LLM to
    break the topic down into queries based on those search results, and finally
    searches every query and lets the LLM rank the hits.

    Attributes:
        desc: Human-readable description of the action.
        search_engine: The ``SearchEngine`` used for every search.
        rank_func: Optional callable applied to the LLM-ranked results of a query;
            whatever it returns replaces the result list before links are extracted.
    """

    def __init__(
        self,
        name: str = "",
        *args,
        rank_func: Callable[[list[dict]], list[dict]] | None = None,
        **kwargs,
    ):
        super().__init__(name, *args, **kwargs)
        self.desc = "Collect links from a search engine."
        self.search_engine = SearchEngine()
        self.rank_func = rank_func

    async def run(
        self,
        topic: str,
        decomposition_nums: int = 4,
        url_per_query: int = 4,
        system_text: str | None = None,
    ) -> dict[str, list[str]]:
        """Run the action to collect links.

        Args:
            topic: The research topic.
            decomposition_nums: The number of search questions to generate.
            url_per_query: The number of URLs to collect per search question.
            system_text: The system text. Defaults to the research-topic system prompt.

        Returns:
            A dictionary containing the search questions as keys and the collected URLs as values.
        """
        system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
        keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
        try:
            keywords = json.loads(keywords)
            keywords = parse_obj_as(list[str], keywords)
        except Exception as e:
            logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
            keywords = [topic]
        if not keywords:
            # "[]" is valid JSON but leaves nothing to search; fall back to the topic itself.
            keywords = [topic]
        results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))

        def gen_msg():
            """Yield ever shorter prompts by dropping one hit from the longest result list each time."""
            while True:
                search_results = "\n".join(
                    f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results)
                )
                yield SUMMARIZE_SEARCH_PROMPT.format(
                    decomposition_nums=decomposition_nums,
                    search_results=search_results,
                )
                longest = max(results, key=len, default=None)
                if not longest:
                    # Nothing left to trim (no result lists, or all of them already empty).
                    break
                longest.pop()
                if not longest:
                    break

        prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
        logger.debug(prompt)
        queries = await self._aask(prompt, [system_text])
        try:
            queries = json.loads(queries)
            queries = parse_obj_as(list[str], queries)
        except Exception as e:
            logger.exception(f"fail to break down the research question due to {e}")
            queries = keywords
        ret = {}
        for query in queries:
            ret[query] = await self._search_and_rank_urls(topic, query, url_per_query)
        return ret

    async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]:
        """Search and rank URLs based on a query.

        Args:
            topic: The research topic.
            query: The search query.
            num_results: The number of URLs to collect.

        Returns:
            A list of ranked URLs, at most ``num_results`` long.
        """
        max_results = max(num_results * 2, 6)
        results = await self.search_engine.run(query, max_results=max_results, as_string=False)
        # Only the hits shown to the LLM can be ranked, so the index space must match the prompt.
        results = results[:max_results]
        _results = "\n".join(f"{i}: {j}" for i, j in enumerate(results))
        prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
        logger.debug(prompt)
        indices = await self._aask(prompt)
        try:
            indices = self._parse_ranked_indices(indices)
        except (TypeError, ValueError) as e:
            logger.exception(f"fail to rank results for {e}")
            # Keep the engine's own order, limited to the hits that actually exist.
            indices = range(len(results))
        # The LLM may repeat an index or invent one; keep the first valid occurrence of each.
        ranked = [results[i] for i in dict.fromkeys(indices) if 0 <= i < len(results)]
        if self.rank_func:
            ranked = self.rank_func(ranked)
        return [i["link"] for i in ranked[:num_results]]

    @staticmethod
    def _parse_ranked_indices(raw: str) -> list[int]:
        """Parse the LLM answer into a list of integer indices.

        Raises:
            ValueError: If ``raw`` is not JSON or is not a JSON list of integers.
            TypeError: If ``raw`` is not a string.
        """
        indices = json.loads(raw)
        if not isinstance(indices, list) or not all(
            isinstance(i, int) and not isinstance(i, bool) for i in indices
        ):
            raise ValueError(f"expected a JSON list of integers, got {raw!r}")
        return indices
class ConductResearch(Action):
    """Action that turns collected reference material into a research report."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # A dedicated report model, when configured, overrides the LLM's model.
        report_model = CONFIG.model_for_researcher_report
        if report_model:
            self.llm.model = report_model

    async def run(
        self,
        topic: str,
        content: str,
        system_text: str = RESEARCH_BASE_SYSTEM,
    ) -> str:
        """Generate the research report for a topic.

        Args:
            topic: The research topic.
            content: The reference material the report is written from.
            system_text: The system text sent along with the prompt.

        Returns:
            The generated research report.
        """
        report_prompt = CONDUCT_RESEARCH_PROMPT.format(content=content, topic=topic)
        logger.debug(report_prompt)
        # Switch the LLM to automatic max-token selection before asking for the report.
        self.llm.auto_max_tokens = True
        report = await self._aask(report_prompt, [system_text])
        return report
+ """ + return " ".join((RESEARCH_TOPIC_SYSTEM.format(topic=topic), LANG_PROMPT.format(language=language))) diff --git a/metagpt/config.py b/metagpt/config.py index d53571468..41c1f8645 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -4,14 +4,14 @@ 提供配置,单例 """ import os -import openai +import openai import yaml from metagpt.const import PROJECT_ROOT from metagpt.logs import logger -from metagpt.utils.singleton import Singleton from metagpt.tools import SearchEngineType, WebBrowserEngineType +from metagpt.utils.singleton import Singleton class NotConfiguredException(Exception): @@ -77,12 +77,12 @@ class Config(metaclass=Singleton): logger.warning("LONG_TERM_MEMORY is True") self.max_budget = self._get("MAX_BUDGET", 10.0) self.total_cost = 0.0 - self.puppeteer_config = self._get("PUPPETEER_CONFIG","") - self.mmdc = self._get("MMDC","mmdc") - self.update_costs = self._get("UPDATE_COSTS",True) - self.calc_usage = self._get("CALC_USAGE",True) - - + self.puppeteer_config = self._get("PUPPETEER_CONFIG", "") + self.mmdc = self._get("MMDC", "mmdc") + self.update_costs = self._get("UPDATE_COSTS", True) + self.calc_usage = self._get("CALC_USAGE", True) + self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY") + self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT") def _init_with_config_files_and_env(self, configs: dict, yaml_file): """从config/key.yaml / config/config.yaml / env三处按优先级递减加载""" diff --git a/metagpt/const.py b/metagpt/const.py index abbfb40e0..505eebd46 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -32,5 +32,6 @@ UT_PY_PATH = UT_PATH / "files/ut/" API_QUESTIONS_PATH = UT_PATH / "files/question/" YAPI_URL = "http://yapi.deepwisdomai.com/" TMP = PROJECT_ROOT / 'tmp' +RESEARCH_PATH = DATA_PATH / "research" MEM_TTL = 24 * 30 * 3600 diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index fe9532d43..e10c78c8f 100644 --- a/metagpt/provider/openai_api.py +++ 
b/metagpt/provider/openai_api.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # -*- coding: utf-8 -*- """ @Time : 2023/5/5 23:08 @@ -20,6 +19,7 @@ from metagpt.utils.token_counter import ( TOKEN_COSTS, count_message_tokens, count_string_tokens, + get_max_completion_tokens, ) @@ -33,20 +33,25 @@ def retry(max_retries): except Exception: if i == max_retries - 1: raise - await asyncio.sleep(2 ** i) + await asyncio.sleep(2**i) + return wrapper + return decorator class RateLimiter: """Rate control class, each call goes through wait_if_needed, sleep if rate control is needed""" + def __init__(self, rpm): self.last_call_time = 0 - self.interval = 1.1 * 60 / rpm # Here 1.1 is used because even if the calls are made strictly according to time, they will still be QOS'd; consider switching to simple error retry later + # Here 1.1 is used because even if the calls are made strictly according to time, + # they will still be QOS'd; consider switching to simple error retry later + self.interval = 1.1 * 60 / rpm self.rpm = rpm def split_batches(self, batch): - return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)] + return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)] async def wait_if_needed(self, num_requests): current_time = time.time() @@ -69,6 +74,7 @@ class Costs(NamedTuple): class CostManager(metaclass=Singleton): """计算使用接口的开销""" + def __init__(self): self.total_prompt_tokens = 0 self.total_completion_tokens = 0 @@ -86,13 +92,12 @@ class CostManager(metaclass=Singleton): """ self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens - cost = ( - prompt_tokens * TOKEN_COSTS[model]["prompt"] - + completion_tokens * TOKEN_COSTS[model]["completion"] - ) / 1000 + cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000 self.total_cost += cost - logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | " - f"Current 
cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}") + logger.info( + f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | " + f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + ) CONFIG.total_cost = self.total_cost def get_total_prompt_tokens(self): @@ -131,10 +136,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ Check https://platform.openai.com/examples for examples """ + def __init__(self): self.__init_openai(CONFIG) self.llm = openai self.model = CONFIG.openai_api_model + self.auto_max_tokens = False self._cost_manager = CostManager() RateLimiter.__init__(self, rpm=self.rpm) @@ -148,10 +155,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self.rpm = int(config.get("RPM", 10)) async def _achat_completion_stream(self, messages: list[dict]) -> str: - response = await openai.ChatCompletion.acreate( - **self._cons_kwargs(messages), - stream=True - ) + response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True) # create variables to collect the stream of chunks collected_chunks = [] @@ -159,41 +163,41 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): # iterate through the stream of events async for chunk in response: collected_chunks.append(chunk) # save the event response - chunk_message = chunk['choices'][0]['delta'] # extract the message + chunk_message = chunk["choices"][0]["delta"] # extract the message collected_messages.append(chunk_message) # save the message if "content" in chunk_message: print(chunk_message["content"], end="") print() - full_reply_content = ''.join([m.get('content', '') for m in collected_messages]) + full_reply_content = "".join([m.get("content", "") for m in collected_messages]) usage = self._calc_usage(messages, full_reply_content) self._update_costs(usage) return full_reply_content def _cons_kwargs(self, messages: list[dict]) -> dict: - if CONFIG.openai_api_type == 'azure': + if CONFIG.openai_api_type == 
class Researcher(Role):
    """Role that researches a topic and saves the resulting report as Markdown.

    The role runs its three actions in order: ``CollectLinks``,
    ``WebBrowseAndSummarize`` and ``ConductResearch``. Each step stores a
    ``Report`` in memory that the next step reads.
    """

    def __init__(
        self,
        name: str = "David",
        profile: str = "Researcher",
        goal: str = "Gather information and conduct research",
        constraints: str = "Ensure accuracy and relevance of information",
        language: str = "en-us",
        **kwargs,
    ):
        super().__init__(name, profile, goal, constraints, **kwargs)
        self._init_actions([CollectLinks(name), WebBrowseAndSummarize(name), ConductResearch(name)])
        self.language = language
        if language not in ("en-us", "zh-cn"):
            logger.warning(f"The language `{language}` has not been tested, it may not work.")

    async def _think(self) -> None:
        """Select the next action: start at state 0, advance by one, or clear ``todo`` when done."""
        if self._rc.todo is None:
            self._set_state(0)
            return

        if self._rc.state + 1 < len(self._states):
            self._set_state(self._rc.state + 1)
        else:
            self._rc.todo = None

    async def _act(self) -> Message:
        """Run the current action on the latest message in memory and store its result.

        Returns:
            The message carrying the ``Report`` produced by this step.
        """
        logger.info(f"{self._setting}: ready to {self._rc.todo}")
        todo = self._rc.todo
        msg = self._rc.memory.get(k=1)[0]
        # Only a Report carries links/summaries; a plain message just supplies the topic text.
        instruct_content = msg.instruct_content if isinstance(msg.instruct_content, Report) else None
        topic = instruct_content.topic if instruct_content is not None else msg.content

        research_system_text = get_research_system_text(topic, self.language)
        if isinstance(todo, CollectLinks):
            links = await todo.run(topic, 4, 4)
            ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo))
        elif isinstance(todo, WebBrowseAndSummarize):
            links = (instruct_content.links if instruct_content is not None else None) or {}
            todos = (
                todo.run(*urls, query=query, system_text=research_system_text)
                for (query, urls) in links.items()
                if urls  # run() requires at least one URL; a query without hits is skipped
            )
            summaries = await asyncio.gather(*todos)
            summaries = [(url, summary) for i in summaries for (url, summary) in i.items() if summary]
            ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=type(todo))
        else:
            summaries = (instruct_content.summaries if instruct_content is not None else None) or []
            summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
            content = await todo.run(topic, summary_text, system_text=research_system_text)
            ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=type(todo))
        self._rc.memory.add(ret)
        return ret

    async def _react(self) -> Message:
        """Run every action in turn, then write the final report to disk.

        Returns:
            The message produced by the last action.
        """
        while True:
            await self._think()
            if self._rc.todo is None:
                break
            msg = await self._act()
        report = msg.instruct_content
        self.write_report(report.topic, report.content)
        return msg

    def write_report(self, topic: str, content: str):
        """Write ``content`` to ``RESEARCH_PATH / "<topic>.md"`` as UTF-8.

        Args:
            topic: The research topic, used as the file name. Path separators are
                replaced with ``_`` so the file always lands inside ``RESEARCH_PATH``.
            content: The report text.
        """
        filename = topic.replace("/", "_").replace("\\", "_")
        filepath = RESEARCH_PATH / f"{filename}.md"
        # The research directory is not guaranteed to exist on a fresh checkout.
        filepath.parent.mkdir(parents=True, exist_ok=True)
        filepath.write_text(content, encoding="utf-8")
- @classmethod - def run_google(cls, query, max_results=8): - # results = ddg(query, max_results=max_results) - results = google_official_search(query, num_results=max_results) - logger.info(results) - return results + Args: + engine: The search engine type. Defaults to the search engine specified in the config. + run_func: The function to run the search. Defaults to None. - async def run(self, query: str, max_results=8): - if self.engine == SearchEngineType.SERPAPI_GOOGLE: - api = SerpAPIWrapper() - rsp = await api.run(query) - elif self.engine == SearchEngineType.DIRECT_GOOGLE: - rsp = SearchEngine.run_google(query, max_results) - elif self.engine == SearchEngineType.SERPER_GOOGLE: - api = SerperWrapper() - rsp = await api.run(query) - elif self.engine == SearchEngineType.CUSTOM_ENGINE: - rsp = self.run_func(query) + Attributes: + run_func: The function to run the search. + engine: The search engine type. + """ + def __init__( + self, + engine: SearchEngineType | None = None, + run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None, + ): + engine = engine or CONFIG.search_engine + if engine == SearchEngineType.SERPAPI_GOOGLE: + module = "metagpt.tools.search_engine_serpapi" + run_func = importlib.import_module(module).SerpAPIWrapper().run + elif engine == SearchEngineType.SERPER_GOOGLE: + module = "metagpt.tools.search_engine_serper" + run_func = importlib.import_module(module).SerperWrapper().run + elif engine == SearchEngineType.DIRECT_GOOGLE: + module = "metagpt.tools.search_engine_googleapi" + run_func = importlib.import_module(module).GoogleAPIWrapper().run + elif engine == SearchEngineType.DUCK_DUCK_GO: + module = "metagpt.tools.search_engine_ddg" + run_func = importlib.import_module(module).DDGAPIWrapper().run + elif engine == SearchEngineType.CUSTOM_ENGINE: + pass # run_func = run_func else: raise NotImplementedError - return rsp + self.engine = engine + self.run_func = run_func + @overload + def run( + self, + query: str, + 
max_results: int = 8, + as_string: Literal[True] = True, + ) -> str: + ... -def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]: - """Return the results of a Google search using the official Google API + @overload + def run( + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, + ) -> list[dict[str, str]]: + ... - Args: - query (str): The search query. - num_results (int): The number of results to return. + async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: + """Run a search query. - Returns: - str: The results of the search. - """ + Args: + query: The search query. + max_results: The maximum number of results to return. Defaults to 8. + as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True. - from googleapiclient.discovery import build - from googleapiclient.errors import HttpError - - try: - api_key = config.google_api_key - custom_search_engine_id = config.google_cse_id - - with build("customsearch", "v1", developerKey=api_key) as service: - - result = ( - service.cse() - .list(q=query, cx=custom_search_engine_id, num=num_results) - .execute() - ) - logger.info(result) - # Extract the search result items from the response - search_results = result.get("items", []) - - # Create a list of only the URLs from the search results - search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] - - except HttpError as e: - # Handle errors in the API call - error_details = json.loads(e.content.decode()) - - # Check if the error is related to an invalid or missing API key - if error_details.get("error", {}).get( - "code" - ) == 403 and "invalid API key" in error_details.get("error", {}).get( - "message", "" - ): - return "Error: The provided Google API key is invalid or missing." 
class DDGAPIWrapper:
    """Wrapper around duckduckgo_search API.

    To use this module, you should have the `duckduckgo_search` Python package installed.

    Attributes:
        loop: Event loop used to run the blocking search; the running loop when ``None``.
        executor: Executor the blocking search runs in; the loop's default when ``None``.
        ddgs: The ``DDGS`` client, created with ``CONFIG.global_proxy`` as proxy when set.
    """

    def __init__(
        self,
        *,
        loop: asyncio.AbstractEventLoop | None = None,
        executor: futures.Executor | None = None,
    ):
        kwargs = {}
        if CONFIG.global_proxy:
            kwargs["proxies"] = CONFIG.global_proxy
        self.loop = loop
        self.executor = executor
        self.ddgs = DDGS(**kwargs)

    @overload
    async def run(
        self,
        query: str,
        max_results: int = 8,
        as_string: Literal[True] = True,
        focus: list[str] | None = None,
    ) -> str:
        ...

    @overload
    async def run(
        self,
        query: str,
        max_results: int = 8,
        as_string: Literal[False] = False,
        focus: list[str] | None = None,
    ) -> list[dict[str, str]]:
        ...

    async def run(
        self,
        query: str,
        max_results: int = 8,
        as_string: bool = True,
        focus: list[str] | None = None,
    ) -> str | list[dict]:
        """Return the results of a DuckDuckGo text search.

        Args:
            query: The search query.
            max_results: The number of results to return.
            as_string: A boolean flag to determine the return type of the results. If True, the function will
                return a JSON string with the search results. If False, it will return a list of dictionaries
                containing detailed information about each search result.
            focus: Keys to keep in each result (any of ``"link"``, ``"snippet"``, ``"title"``).
                All three are kept when omitted.

        Returns:
            The results of the search; an empty result set when the search fails.
        """
        loop = self.loop or asyncio.get_running_loop()
        future = loop.run_in_executor(
            self.executor,
            self._search_from_ddgs,
            query,
            max_results,
        )
        try:
            search_results = await future
        except Exception as e:
            # Best effort: a failed search is logged and treated as "no results".
            logger.exception(f"fail to search {query} for {e}")
            search_results = []

        if focus:
            search_results = [{k: v for k, v in item.items() if k in focus} for item in search_results]

        if as_string:
            return json.dumps(search_results, ensure_ascii=False)
        return search_results

    def _search_from_ddgs(self, query: str, max_results: int):
        """Blocking search: take at most ``max_results`` hits and normalize their keys."""
        return [
            {
                "link": i["href"],
                "snippet": i["body"],
                "title": i["title"],
            }
            # zip() stops on the exhausted range first, so no extra hit is fetched.
            for (_, i) in zip(range(max_results), self.ddgs.text(query))
        ]
+ """ + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop | None = None, + executor: futures.Executor | None = None, + ): + build_kwargs = {"developerKey": CONFIG.google_api_key} + if CONFIG.global_proxy: + parse_result = urlparse(CONFIG.global_proxy) + proxy_type = parse_result.scheme + if proxy_type == "https": + proxy_type = "http" + build_kwargs["http"] = httplib2.Http( + proxy_info=httplib2.ProxyInfo( + getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"), + parse_result.hostname, + parse_result.port, + ), + ) + service = build("customsearch", "v1", **build_kwargs) + self.google_api_client = service.cse() + self.custom_search_engine_id = CONFIG.google_cse_id + self.loop = loop + self.executor = executor + + async def run( + self, + query: str, + max_results: int = 8, + as_string: bool = True, + focus: list[str] | None = None, + ) -> str | list[dict]: + """Return the results of a Google search using the official Google API. + + Args: + query: The search query. + max_results: The number of results to return. + as_string: A boolean flag to determine the return type of the results. If True, the function will + return a formatted string with the search results. If False, it will return a list of dictionaries + containing detailed information about each search result. + focus: Specific information to be focused on from each search result. + + Returns: + The results of the search. 
+ """ + loop = self.loop or asyncio.get_event_loop() + future = loop.run_in_executor( + self.executor, + self.google_api_client.list( + q=query, + num=max_results, + cx=self.custom_search_engine_id + ).execute + ) + try: + result = await future + # Extract the search result items from the response + search_results = result.get("items", []) + + except HttpError as e: + # Handle errors in the API call + logger.exception(f"fail to search {query} for {e}") + search_results = [] + + focus = focus or ["snippet", "link", "title"] + details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] + # Return the list of search result URLs + if as_string: + return safe_google_results(details) + + return details + + +def safe_google_results(results: str | list) -> str: + """Return the results of a google search in a safe format. + + Args: + results: The search results. + + Returns: + The results of the search. + """ + if isinstance(results, list): + safe_message = json.dumps([result for result in results]) + else: + safe_message = results.encode("utf-8", "ignore").decode("utf-8") + return safe_message + + +if __name__ == "__main__": + import fire + + fire.Fire(GoogleAPIWrapper().run) diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 28033f237..3d2d7cfe4 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel): class Config: arbitrary_types_allowed = True - async def run(self, query: str, **kwargs: Any) -> str: + async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" - return self._process_response(await self.results(query)) + return self._process_response(await self.results(query, max_results), as_string=as_string) - async def results(self, query: str) -> dict: + async def results(self, query: str, 
max_results: int) -> dict: """Use aiohttp to run query through SerpAPI and return the results async.""" def construct_url_and_params() -> Tuple[str, Dict[str, str]]: params = self.get_params(query) params["source"] = "python" + params["num"] = max_results if self.serpapi_api_key: params["serp_api_key"] = self.serpapi_api_key params["output"] = "json" @@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel): return params @staticmethod - def _process_response(res: dict) -> str: + def _process_response(res: dict, as_string: bool) -> str: """Process response from SerpAPI.""" # logger.debug(res) - focus = ['title', 'snippet', 'link'] + focus = ["title", "snippet", "link"] get_focused = lambda x: {i: j for i, j in x.items() if i in focus} if "error" in res.keys(): @@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel): toret = res["answer_box"]["answer"] elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): toret = res["answer_box"]["snippet"] - elif ( - "answer_box" in res.keys() - and "snippet_highlighted_words" in res["answer_box"].keys() - ): + elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys(): toret = res["answer_box"]["snippet_highlighted_words"][0] - elif ( - "sports_results" in res.keys() - and "game_spotlight" in res["sports_results"].keys() - ): + elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys(): toret = res["sports_results"]["game_spotlight"] - elif ( - "knowledge_graph" in res.keys() - and "description" in res["knowledge_graph"].keys() - ): + elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): toret = res["knowledge_graph"]["description"] elif "snippet" in res["organic_results"][0].keys(): toret = res["organic_results"][0]["snippet"] @@ -112,4 +104,10 @@ class SerpAPIWrapper(BaseModel): if res.get("organic_results"): toret_l += [get_focused(i) for i in res.get("organic_results")] - return str(toret) + '\n' + str(toret_l) + 
return str(toret) + '\n' + str(toret_l) if as_string else toret_l + + +if __name__ == "__main__": + import fire + + fire.Fire(SerpAPIWrapper().run) diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 80c2f8001..2ae2c3b7d 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -36,16 +36,19 @@ class SerperWrapper(BaseModel): class Config: arbitrary_types_allowed = True - async def run(self, query: str, **kwargs: Any) -> str: + async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through Serper and parse result async.""" - queries = query.split("\n") - return "\n".join([self._process_response(res) for res in await self.results(queries)]) + if isinstance(query, str): + return self._process_response((await self.results([query], max_results))[0], as_string=as_string) + else: + results = [self._process_response(res, as_string) for res in await self.results(query, max_results)] + return "\n".join(results) if as_string else results - async def results(self, queries: list[str]) -> dict: + async def results(self, queries: list[str], max_results: int = 8) -> dict: """Use aiohttp to run query through Serper and return the results async.""" def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]: - payloads = self.get_payloads(queries) + payloads = self.get_payloads(queries, max_results) url = "https://google.serper.dev/search" headers = self.get_headers() return url, payloads, headers @@ -61,12 +64,13 @@ class SerperWrapper(BaseModel): return res - def get_payloads(self, queries: list[str]) -> Dict[str, str]: + def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]: """Get payloads for Serper.""" payloads = [] for query in queries: _payload = { "q": query, + "num": max_results, } payloads.append({**self.payload, **_payload}) return json.dumps(payloads, sort_keys=True) @@ -79,7 +83,7 @@ 
class SerperWrapper(BaseModel): return headers @staticmethod - def _process_response(res: dict) -> str: + def _process_response(res: dict, as_string: bool = False) -> str: """Process response from SerpAPI.""" # logger.debug(res) focus = ['title', 'snippet', 'link'] @@ -117,4 +121,10 @@ class SerperWrapper(BaseModel): if res.get("organic"): toret_l += [get_focused(i) for i in res.get("organic")] - return str(toret) + '\n' + str(toret_l) + return str(toret) + '\n' + str(toret_l) if as_string else toret_l + + +if __name__ == "__main__": + import fire + + fire.Fire(SerperWrapper().run) diff --git a/metagpt/tools/web_browser_engine.py b/metagpt/tools/web_browser_engine.py index d1f83934f..453d87f31 100644 --- a/metagpt/tools/web_browser_engine.py +++ b/metagpt/tools/web_browser_engine.py @@ -1,22 +1,20 @@ #!/usr/bin/env python from __future__ import annotations -import asyncio -import importlib -from typing import Any, Callable, Coroutine, overload +import importlib +from typing import Any, Callable, Coroutine, Literal, overload from metagpt.config import CONFIG from metagpt.tools import WebBrowserEngineType -from bs4 import BeautifulSoup +from metagpt.utils.parse_html import WebPage class WebBrowserEngine: def __init__( self, engine: WebBrowserEngineType | None = None, - run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None, - parse_func: Callable[[str], str] | None = None, + run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None, ): engine = engine or CONFIG.web_browser_engine @@ -30,30 +28,25 @@ class WebBrowserEngine: run_func = run_func else: raise NotImplementedError - self.parse_func = parse_func or get_page_content self.run_func = run_func self.engine = engine @overload - async def run(self, url: str) -> str: + async def run(self, url: str) -> WebPage: ... @overload - async def run(self, url: str, *urls: str) -> list[str]: + async def run(self, url: str, *urls: str) -> list[WebPage]: ... 
- async def run(self, url: str, *urls: str) -> str | list[str]: - page = await self.run_func(url, *urls) - if isinstance(page, str): - return self.parse_func(page) - return [self.parse_func(i) for i in page] - - -def get_page_content(page: str): - soup = BeautifulSoup(page, "html.parser") - return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"])) + async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: + return await self.run_func(url, *urls) if __name__ == "__main__": - text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/")) - print(text) + import fire + + async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs): + return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls) + + fire.Fire(main) diff --git a/metagpt/tools/web_browser_engine_playwright.py b/metagpt/tools/web_browser_engine_playwright.py index ae8644cce..030e7701b 100644 --- a/metagpt/tools/web_browser_engine_playwright.py +++ b/metagpt/tools/web_browser_engine_playwright.py @@ -2,12 +2,15 @@ from __future__ import annotations import asyncio -from pathlib import Path import sys +from pathlib import Path from typing import Literal + from playwright.async_api import async_playwright + from metagpt.config import CONFIG from metagpt.logs import logger +from metagpt.utils.parse_html import WebPage class PlaywrightWrapper: @@ -16,7 +19,7 @@ class PlaywrightWrapper: To use this module, you should have the `playwright` Python package installed and ensure that the required browsers are also installed. You can install playwright by running the command `pip install metagpt[playwright]` and download the necessary browser binaries by running the - command `playwright install` for the first time." + command `playwright install` for the first time. 
""" def __init__( @@ -40,27 +43,30 @@ class PlaywrightWrapper: self._context_kwargs = context_kwargs self._has_run_precheck = False - async def run(self, url: str, *urls: str) -> str | list[str]: + async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: async with async_playwright() as ap: browser_type = getattr(ap, self.browser_type) await self._run_precheck(browser_type) browser = await browser_type.launch(**self.launch_kwargs) - - async def _scrape(url): - context = await browser.new_context(**self._context_kwargs) - page = await context.new_page() - async with page: - try: - await page.goto(url) - await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") - content = await page.content() - return content - except Exception as e: - return f"Fail to load page content for {e}" + _scrape = self._scrape if urls: - return await asyncio.gather(_scrape(url), *(_scrape(i) for i in urls)) - return await _scrape(url) + return await asyncio.gather(_scrape(browser, url), *(_scrape(browser, i) for i in urls)) + return await _scrape(browser, url) + + async def _scrape(self, browser, url): + context = await browser.new_context(**self._context_kwargs) + page = await context.new_page() + async with page: + try: + await page.goto(url) + await page.evaluate("window.scrollTo(0, document.body.scrollHeight)") + html = await page.content() + inner_text = await page.evaluate("() => document.body.innerText") + except Exception as e: + inner_text = f"Fail to load page content for {e}" + html = "" + return WebPage(inner_text=inner_text, html=html, url=url) async def _run_precheck(self, browser_type): if self._has_run_precheck: @@ -72,6 +78,10 @@ class PlaywrightWrapper: if CONFIG.global_proxy: kwargs["env"] = {"ALL_PROXY": CONFIG.global_proxy} await _install_browsers(self.browser_type, **kwargs) + + if self._has_run_precheck: + return + if not executable_path.exists(): parts = executable_path.parts available_paths = 
list(Path(*parts[:-3]).glob(f"{self.browser_type}-*")) @@ -85,25 +95,37 @@ class PlaywrightWrapper: self._has_run_precheck = True +def _get_install_lock(): + global _install_lock + if _install_lock is None: + _install_lock = asyncio.Lock() + return _install_lock + + async def _install_browsers(*browsers, **kwargs) -> None: - process = await asyncio.create_subprocess_exec( - sys.executable, - "-m", - "playwright", - "install", - *browsers, - "--with-deps", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - **kwargs, - ) + async with _get_install_lock(): + browsers = [i for i in browsers if i not in _install_cache] + if not browsers: + return + process = await asyncio.create_subprocess_exec( + sys.executable, + "-m", + "playwright", + "install", + *browsers, + # "--with-deps", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + **kwargs, + ) - await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning)) + await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning)) - if await process.wait() == 0: - logger.info(f"Install browser for playwright successfully.") - else: - logger.warning(f"Fail to install browser for playwright.") + if await process.wait() == 0: + logger.info("Install browser for playwright successfully.") + else: + logger.warning("Fail to install browser for playwright.") + _install_cache.update(browsers) async def _log_stream(sr, log_func): @@ -114,8 +136,14 @@ async def _log_stream(sr, log_func): log_func(f"[playwright install browser]: {line.decode().strip()}") +_install_lock: asyncio.Lock = None +_install_cache = set() + + if __name__ == "__main__": - for i in ("chromium", "firefox", "webkit"): - text = asyncio.run(PlaywrightWrapper(i).run("https://httpbin.org/ip")) - print(text) - print(i) + import fire + + async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs): + return await 
PlaywrightWrapper(browser_type, **kwargs).run(url, *urls) + + fire.Fire(main) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index bd8a456ea..d727709b8 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -2,16 +2,17 @@ from __future__ import annotations import asyncio -from copy import deepcopy import importlib +from concurrent import futures +from copy import deepcopy from typing import Literal -from metagpt.config import CONFIG -import asyncio from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.wait import WebDriverWait -from concurrent import futures + +from metagpt.config import CONFIG +from metagpt.utils.parse_html import WebPage class SeleniumWrapper: @@ -48,7 +49,7 @@ class SeleniumWrapper: self.loop = loop self.executor = executor - async def run(self, url: str, *urls: str) -> str | list[str]: + async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: await self._run_precheck() _scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url) @@ -69,9 +70,15 @@ class SeleniumWrapper: def _scrape_website(self, url): with self._get_driver() as driver: - driver.get(url) - WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) - return driver.page_source + try: + driver.get(url) + WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) + inner_text = driver.execute_script("return document.body.innerText;") + html = driver.page_source + except Exception as e: + inner_text = f"Fail to load page content for {e}" + html = "" + return WebPage(inner_text=inner_text, html=html, url=url) _webdriver_manager_types = { @@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): def _get_driver(): options = Options() 
options.add_argument("--headless") + options.add_argument("--enable-javascript") if browser_type == "chrome": options.add_argument("--no-sandbox") for i in args: @@ -107,5 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): if __name__ == "__main__": - text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/")) - print(text) + import fire + + async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs): + return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls) + + fire.Fire(main) diff --git a/metagpt/utils/parse_html.py b/metagpt/utils/parse_html.py new file mode 100644 index 000000000..62de26541 --- /dev/null +++ b/metagpt/utils/parse_html.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +from __future__ import annotations + +from typing import Generator, Optional +from urllib.parse import urljoin, urlparse + +from bs4 import BeautifulSoup +from pydantic import BaseModel + + +class WebPage(BaseModel): + inner_text: str + html: str + url: str + + class Config: + underscore_attrs_are_private = True + + _soup : Optional[BeautifulSoup] = None + _title: Optional[str] = None + + @property + def soup(self) -> BeautifulSoup: + if self._soup is None: + self._soup = BeautifulSoup(self.html, "html.parser") + return self._soup + + @property + def title(self): + if self._title is None: + title_tag = self.soup.find("title") + self._title = title_tag.text.strip() if title_tag is not None else "" + return self._title + + def get_links(self) -> Generator[str, None, None]: + for i in self.soup.find_all("a", href=True): + url = i["href"] + result = urlparse(url) + if not result.scheme and result.path: + yield urljoin(self.url, url) + elif url.startswith(("http://", "https://")): + yield urljoin(self.url, url) + + +def get_html_content(page: str, base: str): + soup = _get_soup(page) + + return soup.get_text(strip=True) + + +def _get_soup(page: str): + soup = BeautifulSoup(page, "html.parser") + # 
https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup + for s in soup(["style", "script", "[document]", "head", "title"]): + s.extract() + + return soup diff --git a/metagpt/utils/text.py b/metagpt/utils/text.py new file mode 100644 index 000000000..be3c52edd --- /dev/null +++ b/metagpt/utils/text.py @@ -0,0 +1,124 @@ +from typing import Generator, Sequence + +from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens + + +def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str: + """Reduce the length of concatenated message segments to fit within the maximum token size. + + Args: + msgs: A generator of strings representing progressively shorter valid prompts. + model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo") + system_text: The system prompts. + reserved: The number of reserved tokens. + + Returns: + The concatenated message segments reduced to fit within the maximum token size. + + Raises: + RuntimeError: If it fails to reduce the concatenated message length. + """ + max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved + for msg in msgs: + if count_string_tokens(msg, model_name) < max_token: + return msg + + raise RuntimeError("fail to reduce message length") + + +def generate_prompt_chunk( + text: str, + prompt_template: str, + model_name: str, + system_text: str, + reserved: int = 0, +) -> Generator[str, None, None]: + """Split the text into chunks of a maximum token size. + + Args: + text: The text to split. + prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}". + model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo") + system_text: The system prompts. + reserved: The number of reserved tokens. + + Yields: + The chunk of text. 
+ """ + paragraphs = text.splitlines(keepends=True) + current_token = 0 + current_lines = [] + + reserved = reserved + count_string_tokens(prompt_template+system_text, model_name) + # 100 is a magic number to ensure the maximum context length is not exceeded + max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100 + + while paragraphs: + paragraph = paragraphs.pop(0) + token = count_string_tokens(paragraph, model_name) + if current_token + token <= max_token: + current_lines.append(paragraph) + current_token += token + elif token > max_token: + paragraphs = split_paragraph(paragraph) + paragraphs + continue + else: + yield prompt_template.format("".join(current_lines)) + current_lines = [paragraph] + current_token = token + + if current_lines: + yield prompt_template.format("".join(current_lines)) + + +def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]: + """Split a paragraph into multiple parts. + + Args: + paragraph: The paragraph to split. + sep: The separator character. + count: The number of parts to split the paragraph into. + + Returns: + A list of split parts of the paragraph. + """ + for i in sep: + sentences = list(_split_text_with_ends(paragraph, i)) + if len(sentences) <= 1: + continue + ret = ["".join(j) for j in _split_by_count(sentences, count)] + return ret + return _split_by_count(paragraph, count) + + +def decode_unicode_escape(text: str) -> str: + """Decode a text with unicode escape sequences. + + Args: + text: The text to decode. + + Returns: + The decoded text. 
+ """ + return text.encode("utf-8").decode("unicode_escape", "ignore") + + +def _split_by_count(lst: Sequence , count: int): + avg = len(lst) // count + remainder = len(lst) % count + start = 0 + for i in range(count): + end = start + avg + (1 if i < remainder else 0) + yield lst[start:end] + start = end + + +def _split_text_with_ends(text: str, sep: str = "."): + parts = [] + for i in text: + parts.append(i) + if i == sep: + yield "".join(parts) + parts = [] + if parts: + yield "".join(parts) diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 99ae5e176..591bb60f0 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -25,6 +25,21 @@ TOKEN_COSTS = { } +TOKEN_MAX = { + "gpt-3.5-turbo": 4096, + "gpt-3.5-turbo-0301": 4096, + "gpt-3.5-turbo-0613": 4096, + "gpt-3.5-turbo-16k": 16384, + "gpt-3.5-turbo-16k-0613": 16384, + "gpt-4-0314": 8192, + "gpt-4": 8192, + "gpt-4-32k": 32768, + "gpt-4-32k-0314": 32768, + "gpt-4-0613": 8192, + "text-embedding-ada-002": 8192, +} + + def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): """Return the number of tokens used by a list of messages.""" try: @@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"): "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613", - }: + }: tokens_per_message = 3 tokens_per_name = 1 elif model == "gpt-3.5-turbo-0301": @@ -79,3 +94,18 @@ def count_string_tokens(string: str, model_name: str) -> int: """ encoding = tiktoken.encoding_for_model(model_name) return len(encoding.encode(string)) + + +def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int: + """Calculate the maximum number of completion tokens for a given model and list of messages. + + Args: + messages: A list of messages. + model: The model name. + + Returns: + The maximum number of completion tokens. 
+ """ + if model not in TOKEN_MAX: + return default + return TOKEN_MAX[model] - count_message_tokens(messages) diff --git a/setup.py b/setup.py index e65696901..2a8edaae7 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ setup( install_requires=requirements, extras_require={ "playwright": ["playwright>=1.26", "beautifulsoup4"], - "selenium": ["selenium>4", "webdriver_manager<3.9", "beautifulsoup4"], + "selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"], }, cmdclass={ "install_mermaid": InstallMermaidCLI, diff --git a/tests/metagpt/roles/test_researcher.py b/tests/metagpt/roles/test_researcher.py new file mode 100644 index 000000000..01b5dae3b --- /dev/null +++ b/tests/metagpt/roles/test_researcher.py @@ -0,0 +1,32 @@ +from pathlib import Path +from random import random +from tempfile import TemporaryDirectory + +import pytest + +from metagpt.roles import researcher + + +async def mock_llm_ask(self, prompt: str, system_msgs): + if "Please provide up to 2 necessary keywords" in prompt: + return '["dataiku", "datarobot"]' + elif "Provide up to 4 queries related to your research topic" in prompt: + return '["Dataiku machine learning platform", "DataRobot AI platform comparison", ' \ + '"Dataiku vs DataRobot features", "Dataiku and DataRobot use cases"]' + elif "sort the remaining search results" in prompt: + return '[1,2]' + elif "Not relevant." in prompt: + return "Not relevant" if random() > 0.5 else prompt[-100:] + elif "provide a detailed research report" in prompt: + return f"# Research Report\n## Introduction\n{prompt}" + return "" + + +@pytest.mark.asyncio +async def test_researcher(mocker): + with TemporaryDirectory() as dirname: + topic = "dataiku vs. 
datarobot" + mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask) + researcher.RESEARCH_PATH = Path(dirname) + await researcher.Researcher().run(topic) + assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report") diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 2418c7b26..a7fe063a6 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -5,24 +5,44 @@ @Author : alexanderwu @File : test_search_engine.py """ +from __future__ import annotations import pytest from metagpt.logs import logger +from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine +class MockSearchEnine: + async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: + rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)] + return "\n".join(rets) if as_string else rets + + @pytest.mark.asyncio -@pytest.mark.usefixtures("llm_api") -async def test_search_engine(llm_api): - search_engine = SearchEngine() - poetries = [ - # ("北京美食", "北京"), - ("屈臣氏", "屈臣氏") - ] - for i, j in poetries: - rsp = await search_engine.run(i) - # rsp = context.llm.ask_batch([prompt]) - logger.info(rsp) - # assert any(j in k['body'] for k in rsp) - assert len(rsp) > 0 +@pytest.mark.parametrize( + ("search_engine_typpe", "run_func", "max_results", "as_string"), + [ + (SearchEngineType.SERPAPI_GOOGLE, None, 8, True), + (SearchEngineType.SERPAPI_GOOGLE, None, 4, False), + (SearchEngineType.DIRECT_GOOGLE, None, 8, True), + (SearchEngineType.DIRECT_GOOGLE, None, 6, False), + (SearchEngineType.SERPER_GOOGLE, None, 8, True), + (SearchEngineType.SERPER_GOOGLE, None, 6, False), + (SearchEngineType.DUCK_DUCK_GO, None, 8, True), + (SearchEngineType.DUCK_DUCK_GO, None, 6, False), + (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 
8, False), + (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False), + + ], +) +async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ): + search_engine = SearchEngine(search_engine_typpe, run_func) + rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string) + logger.info(rsp) + if as_string: + assert isinstance(rsp, str) + else: + assert isinstance(rsp, list) + assert len(rsp) == max_results diff --git a/tests/metagpt/tools/test_web_browser_engine_playwright.py b/tests/metagpt/tools/test_web_browser_engine_playwright.py index 908f92112..69e1339e7 100644 --- a/tests/metagpt/tools/test_web_browser_engine_playwright.py +++ b/tests/metagpt/tools/test_web_browser_engine_playwright.py @@ -1,4 +1,5 @@ import pytest + from metagpt.config import CONFIG from metagpt.tools import web_browser_engine_playwright @@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy CONFIG.global_proxy = proxy browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type, **kwagrs) result = await browser.run(url) + result = result.inner_text assert isinstance(result, str) assert "Deepwisdom" in result diff --git a/tests/metagpt/tools/test_web_browser_engine_selenium.py b/tests/metagpt/tools/test_web_browser_engine_selenium.py index 5ea1e3083..ce322f7bd 100644 --- a/tests/metagpt/tools/test_web_browser_engine_selenium.py +++ b/tests/metagpt/tools/test_web_browser_engine_selenium.py @@ -1,4 +1,5 @@ import pytest + from metagpt.config import CONFIG from metagpt.tools import web_browser_engine_selenium @@ -20,6 +21,7 @@ async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd) CONFIG.global_proxy = proxy browser = web_browser_engine_selenium.SeleniumWrapper(browser_type) result = await browser.run(url) + result = result.inner_text assert isinstance(result, str) assert "Deepwisdom" in result @@ -27,7 +29,7 @@ async def test_scrape_web_page(browser_type, 
use_proxy, url, urls, proxy, capfd) results = await browser.run(url, *urls) assert isinstance(results, list) assert len(results) == len(urls) + 1 - assert all(("Deepwisdom" in i) for i in results) + assert all(("Deepwisdom" in i.inner_text) for i in results) if use_proxy: assert "Proxy:" in capfd.readouterr().out finally: diff --git a/tests/metagpt/utils/test_parse_html.py b/tests/metagpt/utils/test_parse_html.py new file mode 100644 index 000000000..42be416a6 --- /dev/null +++ b/tests/metagpt/utils/test_parse_html.py @@ -0,0 +1,68 @@ +from metagpt.utils import parse_html + +PAGE = """ + + +
+This is a paragraph with a link and some emphasized text.
+| Header 1 | +Header 2 | +
|---|---|
| Row 1, Cell 1 | +Row 1, Cell 2 | +
| Row 2, Cell 1 | +Row 2, Cell 2 | +
+
+
+
+
+"""
+
+CONTENT = 'This is a HeadingThis is a paragraph witha linkand someemphasizedtext.Item 1Item 2Item 3Numbered Item 1Numbered '\
+'Item 2Numbered Item 3Header 1Header 2Row 1, Cell 1Row 1, Cell 2Row 2, Cell 1Row 2, Cell 2Name:Email:SubmitThis is a div '\
+'with a class "box".a link'
+
+
+def test_web_page():
+    web_page = parse_html.WebPage(url="http://example.com", html=PAGE, inner_text=CONTENT)
+    assert web_page.title == "Random HTML Example"  # stripped text of the <title> tag
+    assert [*web_page.get_links()] == ["http://example.com/test", "https://metagpt.com"]  # joined with url
+
+
+def test_get_page_content():
+    """The text that get_html_content extracts from PAGE must equal CONTENT."""
+    assert parse_html.get_html_content(PAGE, "http://example.com") == CONTENT
diff --git a/tests/metagpt/utils/test_text.py b/tests/metagpt/utils/test_text.py
new file mode 100644
index 000000000..0caf8abaa
--- /dev/null
+++ b/tests/metagpt/utils/test_text.py
@@ -0,0 +1,77 @@
+import pytest
+
+from metagpt.utils.text import (
+ decode_unicode_escape,
+ generate_prompt_chunk,
+ reduce_message_length,
+ split_paragraph,
+)
+
+
+def _msgs():
+    """Yield "Hello," repeated 20000, 19000, ... down to 1000 times, longest first."""
+    unit = "Hello," * 1000
+    for scale in range(20, 0, -1):
+        yield unit * scale
+
+
+def _paragraphs(n):
+    return " ".join(["Hello World."] * n)  # n copies of the sentence, separated by single spaces
+
+
+@pytest.mark.parametrize("msgs, model_name, system_text, reserved, expected", [
+    (_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
+    (_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
+    (_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
+    (_msgs(), "gpt-4", "System", 2000, 3),
+    (_msgs(), "gpt-4", "Hello," * 1000, 2000, 2),
+    (_msgs(), "gpt-4-32k", "System", 4000, 14),
+    (_msgs(), "gpt-4-32k", "Hello," * 2000, 4000, 12),
+])
+def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
+    """Check the size of the returned message, in thousands of "Hello," repeats."""
+    kept = reduce_message_length(msgs, model_name, system_text, reserved)
+    repeats = len(kept) / len("Hello,")
+    assert repeats / 1000 == expected
+
+
+@pytest.mark.parametrize("text, prompt_template, model_name, system_text, reserved, expected", [
+    (_paragraphs(1000), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
+    (_paragraphs(1000), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
+    (_paragraphs(4000), "Prompt: {}", "gpt-4", "System", 2000, 2),
+    (_paragraphs(8000), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
+])
+def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
+    """Verify that the text is split into exactly `expected` prompt chunks."""
+    chunk_total = 0
+    for _ in generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved):
+        chunk_total += 1
+    assert chunk_total == expected
+
+
+@pytest.mark.parametrize("paragraph, sep, count, expected", [
+    (_paragraphs(10), ".", 2, [_paragraphs(5), f" {_paragraphs(5)}"]),
+    (_paragraphs(10), ".", 3, [_paragraphs(4), f" {_paragraphs(3)}", f" {_paragraphs(3)}"]),
+    (f"{_paragraphs(5)}\n{_paragraphs(3)}", "\n.", 2, [f"{_paragraphs(5)}\n", _paragraphs(3)]),
+    ("......", ".", 2, ["...", "..."]),
+    ("......", ".", 3, ["..", "..", ".."]),
+    (".......", ".", 2, ["....", "..."]),
+])
+def test_split_paragraph(paragraph, sep, count, expected):
+    """Verify the parts produced for each separator set and part count.
+
+    Separators are tried in the order they appear in `sep`.
+    """
+    assert split_paragraph(paragraph, sep, count) == expected
+
+
+@pytest.mark.parametrize("text, expected", [
+    ("Hello\\nWorld", "Hello\nWorld"),
+    ("Hello\\tWorld", "Hello\tWorld"),
+    ("Hello\\u0020World", "Hello World"),
+])
+def test_decode_unicode_escape(text, expected):
+    """Verify that backslash escapes spelled out in `text` are decoded."""
+    # Each input holds a literal backslash followed by the escape code, not the decoded character.
+    decoded = decode_unicode_escape(text)
+    assert decoded == expected