Merge branch 'geekan:main' into main

This commit is contained in:
Leon Yee 2023-08-10 12:00:14 -07:00 committed by GitHub
commit 2b91ca3dd0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 1489 additions and 275 deletions

View file

@ -22,6 +22,7 @@ from metagpt.actions.write_code_review import WriteCodeReview
from metagpt.actions.write_prd import WritePRD
from metagpt.actions.write_prd_review import WritePRDReview
from metagpt.actions.write_test import WriteTest
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
class ActionType(Enum):
@ -40,3 +41,6 @@ class ActionType(Enum):
WRITE_TASKS = WriteTasks
ASSIGN_TASKS = AssignTasks
SEARCH_AND_SUMMARIZE = SearchAndSummarize
COLLECT_LINKS = CollectLinks
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
CONDUCT_RESEARCH = ConductResearch

277
metagpt/actions/research.py Normal file
View file

@ -0,0 +1,277 @@
#!/usr/bin/env python
from __future__ import annotations
import asyncio
import json
from typing import Callable
from pydantic import parse_obj_as
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.tools.search_engine import SearchEngine
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
LANG_PROMPT = "Please respond in {language}."
RESEARCH_BASE_SYSTEM = """You are an AI critical thinker research assistant. Your sole purpose is to write well \
written, critically acclaimed, objective and structured reports on the given text."""
RESEARCH_TOPIC_SYSTEM = "You are an AI researcher assistant, and your research topic is:\n#TOPIC#\n{topic}"
SEARCH_TOPIC_PROMPT = """Please provide up to 2 necessary keywords related to your research topic for Google search. \
Your response must be in JSON format, for example: ["keyword1", "keyword2"]."""
SUMMARIZE_SEARCH_PROMPT = """### Requirements
1. The keywords related to your research topic and the search results are shown in the "Search Result Information" section.
2. Provide up to {decomposition_nums} queries related to your research topic base on the search results.
3. Please respond in the following JSON format: ["query1", "query2", "query3", ...].
### Search Result Information
{search_results}
"""
COLLECT_AND_RANKURLS_PROMPT = """### Topic
{topic}
### Query
{query}
### The online search results
{results}
### Requirements
Please remove irrelevant search results that are not related to the query or topic. Then, sort the remaining search results \
based on the link credibility. If two results have equal credibility, prioritize them based on the relevance. Provide the
ranked results' indices in JSON format, like [0, 1, 3, 4, ...], without including other words.
"""
WEB_BROWSE_AND_SUMMARIZE_PROMPT = '''### Requirements
1. Utilize the text in the "Reference Information" section to respond to the question "{query}".
2. If the question cannot be directly answered using the text, but the text is related to the research topic, please provide \
a comprehensive summary of the text.
3. If the text is entirely unrelated to the research topic, please reply with a simple text "Not relevant."
4. Include all relevant factual information, numbers, statistics, etc., if available.
### Reference Information
{content}
'''
CONDUCT_RESEARCH_PROMPT = '''### Reference Information
{content}
### Requirements
Please provide a detailed research report in response to the following topic: "{topic}", using the information provided \
above. The report must meet the following requirements:
- Focus on directly addressing the chosen topic.
- Ensure a well-structured and in-depth presentation, incorporating relevant facts and figures where available.
- Present data and findings in an intuitive manner, utilizing feature comparative tables, if applicable.
- The report should have a minimum word count of 2,000 and be formatted with Markdown syntax following APA style guidelines.
- Include all source URLs in APA format at the end of the report.
'''
class CollectLinks(Action):
"""Action class to collect links from a search engine."""
def __init__(
self,
name: str = "",
*args,
rank_func: Callable[[list[str]], None] | None = None,
**kwargs,
):
super().__init__(name, *args, **kwargs)
self.desc = "Collect links from a search engine."
self.search_engine = SearchEngine()
self.rank_func = rank_func
async def run(
self,
topic: str,
decomposition_nums: int = 4,
url_per_query: int = 4,
system_text: str | None = None,
) -> dict[str, list[str]]:
"""Run the action to collect links.
Args:
topic: The research topic.
decomposition_nums: The number of search questions to generate.
url_per_query: The number of URLs to collect per search question.
system_text: The system text.
Returns:
A dictionary containing the search questions as keys and the collected URLs as values.
"""
system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
try:
keywords = json.loads(keywords)
keywords = parse_obj_as(list[str], keywords)
except Exception as e:
logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
keywords = [topic]
results = await asyncio.gather(*(self.search_engine.run(i, as_string=False) for i in keywords))
def gen_msg():
while True:
search_results = "\n".join(f"#### Keyword: {i}\n Search Result: {j}\n" for (i, j) in zip(keywords, results))
prompt = SUMMARIZE_SEARCH_PROMPT.format(decomposition_nums=decomposition_nums, search_results=search_results)
yield prompt
remove = max(results, key=len)
remove.pop()
if len(remove) == 0:
break
prompt = reduce_message_length(gen_msg(), self.llm.model, system_text, CONFIG.max_tokens_rsp)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:
queries = json.loads(queries)
queries = parse_obj_as(list[str], queries)
except Exception as e:
logger.exception(f"fail to break down the research question due to {e}")
queries = keywords
ret = {}
for query in queries:
ret[query] = await self._search_and_rank_urls(topic, query, url_per_query)
return ret
async def _search_and_rank_urls(self, topic: str, query: str, num_results: int = 4) -> list[str]:
"""Search and rank URLs based on a query.
Args:
topic: The research topic.
query: The search query.
num_results: The number of URLs to collect.
Returns:
A list of ranked URLs.
"""
max_results = max(num_results * 2, 6)
results = await self.search_engine.run(query, max_results=max_results, as_string=False)
_results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
logger.debug(prompt)
indices = await self._aask(prompt)
try:
indices = json.loads(indices)
assert all(isinstance(i, int) for i in indices)
except Exception as e:
logger.exception(f"fail to rank results for {e}")
indices = list(range(max_results))
results = [results[i] for i in indices]
if self.rank_func:
results = self.rank_func(results)
return [i["link"] for i in results[:num_results]]
class WebBrowseAndSummarize(Action):
"""Action class to explore the web and provide summaries of articles and webpages."""
def __init__(
self,
*args,
browse_func: Callable[[list[str]], None] | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
if CONFIG.model_for_researcher_summary:
self.llm.model = CONFIG.model_for_researcher_summary
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if browse_func else None,
run_func=browse_func,
)
self.desc = "Explore the web and provide summaries of articles and webpages."
async def run(
self,
url: str,
*urls: str,
query: str,
system_text: str = RESEARCH_BASE_SYSTEM,
) -> dict[str, str]:
"""Run the action to browse the web and provide summaries.
Args:
url: The main URL to browse.
urls: Additional URLs to browse.
query: The research question.
system_text: The system text.
Returns:
A dictionary containing the URLs as keys and their summaries as values.
"""
contents = await self.web_browser_engine.run(url, *urls)
if not urls:
contents = [contents]
summaries = {}
prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}")
for u, content in zip([url, *urls], contents):
content = content.inner_text
chunk_summaries = []
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, CONFIG.max_tokens_rsp):
logger.debug(prompt)
summary = await self._aask(prompt, [system_text])
if summary == "Not relevant.":
continue
chunk_summaries.append(summary)
if not chunk_summaries:
summaries[u] = None
continue
if len(chunk_summaries) == 1:
summaries[u] = chunk_summaries[0]
continue
content = "\n".join(chunk_summaries)
prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
summary = await self._aask(prompt, [system_text])
summaries[u] = summary
return summaries
class ConductResearch(Action):
"""Action class to conduct research and generate a research report."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if CONFIG.model_for_researcher_report:
self.llm.model = CONFIG.model_for_researcher_report
async def run(
self,
topic: str,
content: str,
system_text: str = RESEARCH_BASE_SYSTEM,
) -> str:
"""Run the action to conduct research and generate a research report.
Args:
topic: The research topic.
content: The content for research.
system_text: The system text.
Returns:
The generated research report.
"""
prompt = CONDUCT_RESEARCH_PROMPT.format(topic=topic, content=content)
logger.debug(prompt)
self.llm.auto_max_tokens = True
return await self._aask(prompt, [system_text])
def get_research_system_text(topic: str, language: str):
"""Get the system text for conducting research.
Args:
topic: The research topic.
language: The language for the system text.
Returns:
The system text for conducting research.
"""
return " ".join((RESEARCH_TOPIC_SYSTEM.format(topic=topic), LANG_PROMPT.format(language=language)))

View file

@ -27,7 +27,7 @@ Please summarize the cause of the errors and give correction instruction
Determine the ONE file to rewrite in order to fix the error, for example, xyz.py, or test_xyz.py
## Status:
Determine if all of the code works fine, if so write PASS, else FAIL,
WRITE ONLY ONE WORD, PASS OR FAIL, IN THI SECTION
WRITE ONLY ONE WORD, PASS OR FAIL, IN THIS SECTION
## Send To:
Please write Engineer if the errors are due to problematic development codes, and QaEngineer to problematic test codes, and NoOne if there are no errors,
WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION.

View file

@ -4,14 +4,14 @@
提供配置单例
"""
import os
import openai
import openai
import yaml
from metagpt.const import PROJECT_ROOT
from metagpt.logs import logger
from metagpt.utils.singleton import Singleton
from metagpt.tools import SearchEngineType, WebBrowserEngineType
from metagpt.utils.singleton import Singleton
class NotConfiguredException(Exception):
@ -46,7 +46,6 @@ class Config(metaclass=Singleton):
self.openai_api_key = self._get("OPENAI_API_KEY")
if not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key:
raise NotConfiguredException("Set OPENAI_API_KEY first")
self.openai_api_base = self._get("OPENAI_API_BASE")
if not self.openai_api_base or "YOUR_API_BASE" == self.openai_api_base:
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
@ -67,22 +66,22 @@ class Config(metaclass=Singleton):
self.google_api_key = self._get("GOOGLE_API_KEY")
self.google_cse_id = self._get("GOOGLE_CSE_ID")
self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE)
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright"))
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
self.long_term_memory = self._get('LONG_TERM_MEMORY', False)
if self.long_term_memory:
logger.warning("LONG_TERM_MEMORY is True")
self.max_budget = self._get("MAX_BUDGET", 10.0)
self.total_cost = 0.0
self.puppeteer_config = self._get("PUPPETEER_CONFIG","")
self.mmdc = self._get("MMDC","mmdc")
self.update_costs = self._get("UPDATE_COSTS",True)
self.calc_usage = self._get("CALC_USAGE",True)
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
self.mmdc = self._get("MMDC", "mmdc")
self.update_costs = self._get("UPDATE_COSTS", True)
self.calc_usage = self._get("CALC_USAGE", True)
self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
"""从config/key.yaml / config/config.yaml / env三处按优先级递减加载"""

View file

@ -32,5 +32,6 @@ UT_PY_PATH = UT_PATH / "files/ut/"
API_QUESTIONS_PATH = UT_PATH / "files/question/"
YAPI_URL = "http://yapi.deepwisdomai.com/"
TMP = PROJECT_ROOT / 'tmp'
RESEARCH_PATH = DATA_PATH / "research"
MEM_TTL = 24 * 30 * 3600

View file

@ -1,4 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/5 23:08
@ -7,10 +6,11 @@
"""
import asyncio
import time
from functools import wraps
from typing import NamedTuple
import openai
from openai.error import APIConnectionError
from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
from metagpt.config import CONFIG
from metagpt.logs import logger
@ -20,33 +20,22 @@ from metagpt.utils.token_counter import (
TOKEN_COSTS,
count_message_tokens,
count_string_tokens,
get_max_completion_tokens,
)
def retry(max_retries):
def decorator(f):
@wraps(f)
async def wrapper(*args, **kwargs):
for i in range(max_retries):
try:
return await f(*args, **kwargs)
except Exception:
if i == max_retries - 1:
raise
await asyncio.sleep(2 ** i)
return wrapper
return decorator
class RateLimiter:
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
def __init__(self, rpm):
self.last_call_time = 0
self.interval = 1.1 * 60 / rpm # Here 1.1 is used because even if the calls are made strictly according to time, they will still be QOS'd; consider switching to simple error retry later
# Here 1.1 is used because even if the calls are made strictly according to time,
# they will still be QOS'd; consider switching to simple error retry later
self.interval = 1.1 * 60 / rpm
self.rpm = rpm
def split_batches(self, batch):
return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)]
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
async def wait_if_needed(self, num_requests):
current_time = time.time()
@ -69,6 +58,7 @@ class Costs(NamedTuple):
class CostManager(metaclass=Singleton):
"""计算使用接口的开销"""
def __init__(self):
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
@ -86,13 +76,12 @@ class CostManager(metaclass=Singleton):
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"]
+ completion_tokens * TOKEN_COSTS[model]["completion"]
) / 1000
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
self.total_cost += cost
logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}")
logger.info(
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
)
CONFIG.total_cost = self.total_cost
def get_total_prompt_tokens(self):
@ -127,14 +116,25 @@ class CostManager(metaclass=Singleton):
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning("""
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
""")
raise retry_state.outcome.exception()
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
Check https://platform.openai.com/examples for examples
"""
def __init__(self):
self.__init_openai(CONFIG)
self.llm = openai
self.model = CONFIG.openai_api_model
self.auto_max_tokens = False
self._cost_manager = CostManager()
RateLimiter.__init__(self, rpm=self.rpm)
@ -148,10 +148,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
self.rpm = int(config.get("RPM", 10))
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await openai.ChatCompletion.acreate(
**self._cons_kwargs(messages),
stream=True
)
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
# create variables to collect the stream of chunks
collected_chunks = []
@ -159,41 +156,42 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
chunk_message = chunk['choices'][0]['delta'] # extract the message
chunk_message = chunk["choices"][0]["delta"] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
print()
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
def _cons_kwargs(self, messages: list[dict]) -> dict:
if CONFIG.openai_api_type == 'azure':
if CONFIG.openai_api_type == "azure":
kwargs = {
"deployment_id": CONFIG.deployment_id,
"messages": messages,
"max_tokens": CONFIG.max_tokens_rsp,
"max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3
"temperature": 0.3,
}
else:
kwargs = {
"model": self.model,
"messages": messages,
"max_tokens": CONFIG.max_tokens_rsp,
"max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3
"temperature": 0.3,
}
kwargs["timeout"] = 3
return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict:
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
self._update_costs(rsp.get('usage'))
self._update_costs(rsp.get("usage"))
return rsp
def _chat_completion(self, messages: list[dict]) -> dict:
@ -211,7 +209,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
# messages = self.messages_to_dict(messages)
return await self._achat_completion(messages)
@retry(max_retries=6)
@retry(
stop=stop_after_attempt(3),
wait=wait_fixed(1),
after=after_log(logger, logger.level('WARNING').name),
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
"""when streaming, print each token in place."""
if stream:
@ -262,3 +266,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()
def get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)

View file

@ -0,0 +1,93 @@
#!/usr/bin/env python
import asyncio
from pydantic import BaseModel
from metagpt.actions import CollectLinks, ConductResearch, WebBrowseAndSummarize
from metagpt.actions.research import get_research_system_text
from metagpt.const import RESEARCH_PATH
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
class Report(BaseModel):
topic: str
links: dict[str, list[str]] = None
summaries: list[tuple[str, str]] = None
content: str = ""
class Researcher(Role):
def __init__(
self,
name: str = "David",
profile: str = "Researcher",
goal: str = "Gather information and conduct research",
constraints: str = "Ensure accuracy and relevance of information",
language: str = "en-us",
**kwargs,
):
super().__init__(name, profile, goal, constraints, **kwargs)
self._init_actions([CollectLinks(name), WebBrowseAndSummarize(name), ConductResearch(name)])
self.language = language
if language not in ("en-us", "zh-cn"):
logger.warning(f"The language `{language}` has not been tested, it may not work.")
async def _think(self) -> None:
if self._rc.todo is None:
self._set_state(0)
return
if self._rc.state + 1 < len(self._states):
self._set_state(self._rc.state + 1)
else:
self._rc.todo = None
async def _act(self) -> Message:
logger.info(f"{self._setting}: ready to {self._rc.todo}")
todo = self._rc.todo
msg = self._rc.memory.get(k=1)[0]
if isinstance(msg.instruct_content, Report):
instruct_content = msg.instruct_content
topic = instruct_content.topic
else:
topic = msg.content
research_system_text = get_research_system_text(topic, self.language)
if isinstance(todo, CollectLinks):
links = await todo.run(topic, 4, 4)
ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo))
elif isinstance(todo, WebBrowseAndSummarize):
links = instruct_content.links
todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items())
summaries = await asyncio.gather(*todos)
summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary)
ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=type(todo))
else:
summaries = instruct_content.summaries
summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=type(self._rc.todo))
self._rc.memory.add(ret)
return ret
async def _react(self) -> Message:
while True:
await self._think()
if self._rc.todo is None:
break
msg = await self._act()
report = msg.instruct_content
self.write_report(report.topic, report.content)
return msg
def write_report(self, topic: str, content: str):
filepath = RESEARCH_PATH / f"{topic}.md"
filepath.write_text(content)
if __name__ == "__main__":
role = Researcher(language="en-us")
asyncio.run(role.run("dataiku vs. datarobot"))

View file

@ -14,6 +14,7 @@ class SearchEngineType(Enum):
SERPAPI_GOOGLE = auto()
DIRECT_GOOGLE = auto()
SERPER_GOOGLE = auto()
DUCK_DUCK_GO = auto()
CUSTOM_ENGINE = auto()

View file

@ -7,122 +7,76 @@
"""
from __future__ import annotations
import json
import importlib
from typing import Callable, Coroutine, Literal, overload
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
from metagpt.tools.search_engine_serper import SerperWrapper
config = Config()
from metagpt.config import CONFIG
from metagpt.tools import SearchEngineType
class SearchEngine:
"""
TODO: 合入Google Search 并进行反代
这里Google需要挂Proxifier或者类似全局代理
- DDG: https://pypi.org/project/duckduckgo-search/
- GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
"""
def __init__(self, engine=None, run_func=None):
self.config = Config()
self.run_func = run_func
self.engine = engine or self.config.search_engine
"""Class representing a search engine.
@classmethod
def run_google(cls, query, max_results=8):
# results = ddg(query, max_results=max_results)
results = google_official_search(query, num_results=max_results)
logger.info(results)
return results
Args:
engine: The search engine type. Defaults to the search engine specified in the config.
run_func: The function to run the search. Defaults to None.
async def run(self, query: str, max_results=8):
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
api = SerpAPIWrapper()
rsp = await api.run(query)
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
rsp = SearchEngine.run_google(query, max_results)
elif self.engine == SearchEngineType.SERPER_GOOGLE:
api = SerperWrapper()
rsp = await api.run(query)
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
rsp = self.run_func(query)
Attributes:
run_func: The function to run the search.
engine: The search engine type.
"""
def __init__(
self,
engine: SearchEngineType | None = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
):
engine = engine or CONFIG.search_engine
if engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper().run
elif engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper().run
elif engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper().run
elif engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper().run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
else:
raise NotImplementedError
return rsp
self.engine = engine
self.run_func = run_func
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
) -> str:
...
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
"""Return the results of a Google search using the official Google API
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
) -> list[dict[str, str]]:
...
Args:
query (str): The search query.
num_results (int): The number of results to return.
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
"""Run a search query.
Returns:
str: The results of the search.
"""
Args:
query: The search query.
max_results: The maximum number of results to return. Defaults to 8.
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
try:
api_key = config.google_api_key
custom_search_engine_id = config.google_cse_id
with build("customsearch", "v1", developerKey=api_key) as service:
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)
.execute()
)
logger.info(result)
# Extract the search result items from the response
search_results = result.get("items", [])
# Create a list of only the URLs from the search results
search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
except HttpError as e:
# Handle errors in the API call
error_details = json.loads(e.content.decode())
# Check if the error is related to an invalid or missing API key
if error_details.get("error", {}).get(
"code"
) == 403 and "invalid API key" in error_details.get("error", {}).get(
"message", ""
):
return "Error: The provided Google API key is invalid or missing."
else:
return f"Error: {e}"
# google_result can be a list or a string depending on the search results
# Return the list of search result URLs
return search_results_details
def safe_google_results(results: str | list) -> str:
"""
Return the results of a google search in a safe format.
Args:
results (str | list): The search results.
Returns:
str: The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps(
# FIXME: # .encode("utf-8", "ignore") 这里去掉了但是AutoGPT里有很奇怪
[result for result in results]
)
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message
if __name__ == '__main__':
SearchEngine.run(query='wtf')
Returns:
The search results as a string or a list of dictionaries.
"""
return await self.run_func(query, max_results=max_results, as_string=as_string)

View file

@ -0,0 +1,107 @@
#!/usr/bin/env python
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Literal, overload
from duckduckgo_search import DDGS
from googleapiclient.errors import HttpError
from metagpt.config import CONFIG
from metagpt.logs import logger
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
To use this module, you should have the `duckduckgo_search` Python package installed.
"""
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
kwargs = {}
if CONFIG.global_proxy:
kwargs["proxies"] = CONFIG.global_proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
focus: list[str] | None = None,
) -> str:
...
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
focus: list[str] | None = None,
) -> list[dict[str, str]]:
...
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API
Args:
query: The search query.
max_results: The number of results to return.
as_string: A boolean flag to determine the return type of the results. If True, the function will
return a formatted string with the search results. If False, it will return a list of dictionaries
containing detailed information about each search result.
Returns:
The results of the search.
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor,
self._search_from_ddgs,
query,
max_results,
)
try:
search_results = await future
# Extract the search result items from the response
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
# Return the list of search result URLs
if as_string:
return json.dumps(search_results, ensure_ascii=False)
return search_results
def _search_from_ddgs(self, query: str, max_results: int):
return [
{
"link": i["href"],
"snippet": i["body"],
"title": i["title"]
} for (_, i) in zip(range(max_results), self.ddgs.text(query))
]
if __name__ == "__main__":
import fire
fire.Fire(DDGAPIWrapper().run)

View file

@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from urllib.parse import urlparse
import httplib2
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from metagpt.config import CONFIG
from metagpt.logs import logger
class GoogleAPIWrapper:
"""Wrapper around GoogleAPI.
To use this module, you should have the `google-api-python-client` Python package installed
and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
https://programmablesearchengine.google.com/controlpanel/all.
"""
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
build_kwargs = {"developerKey": CONFIG.google_api_key}
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"
build_kwargs["http"] = httplib2.Http(
proxy_info=httplib2.ProxyInfo(
getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"),
parse_result.hostname,
parse_result.port,
),
)
service = build("customsearch", "v1", **build_kwargs)
self.google_api_client = service.cse()
self.custom_search_engine_id = CONFIG.google_cse_id
self.loop = loop
self.executor = executor
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
focus: list[str] | None = None,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API.
Args:
query: The search query.
max_results: The number of results to return.
as_string: A boolean flag to determine the return type of the results. If True, the function will
return a formatted string with the search results. If False, it will return a list of dictionaries
containing detailed information about each search result.
focus: Specific information to be focused on from each search result.
Returns:
The results of the search.
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor,
self.google_api_client.list(
q=query,
num=max_results,
cx=self.custom_search_engine_id
).execute
)
try:
result = await future
# Extract the search result items from the response
search_results = result.get("items", [])
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
focus = focus or ["snippet", "link", "title"]
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
# Return the list of search result URLs
if as_string:
return safe_google_results(details)
return details
def safe_google_results(results: str | list) -> str:
"""Return the results of a google search in a safe format.
Args:
results: The search results.
Returns:
The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps([result for result in results])
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message
if __name__ == "__main__":
import fire
fire.Fire(GoogleAPIWrapper().run)

View file

@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
async def run(self, query: str, **kwargs: Any) -> str:
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
return self._process_response(await self.results(query))
return self._process_response(await self.results(query, max_results), as_string=as_string)
async def results(self, query: str) -> dict:
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
if self.serpapi_api_key:
params["serp_api_key"] = self.serpapi_api_key
params["output"] = "json"
@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
return params
@staticmethod
def _process_response(res: dict) -> str:
def _process_response(res: dict, as_string: bool) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
focus = ["title", "snippet", "link"]
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
if "error" in res.keys():
@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
toret = res["sports_results"]["game_spotlight"]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
toret = res["knowledge_graph"]["description"]
elif "snippet" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["snippet"]
@ -112,4 +104,10 @@ class SerpAPIWrapper(BaseModel):
if res.get("organic_results"):
toret_l += [get_focused(i) for i in res.get("organic_results")]
return str(toret) + '\n' + str(toret_l)
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
if __name__ == "__main__":
import fire
fire.Fire(SerpAPIWrapper().run)

View file

@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
async def run(self, query: str, **kwargs: Any) -> str:
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
queries = query.split("\n")
return "\n".join([self._process_response(res) for res in await self.results(queries)])
if isinstance(query, str):
return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
else:
results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
return "\n".join(results) if as_string else results
async def results(self, queries: list[str]) -> dict:
async def results(self, queries: list[str], max_results: int = 8) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
payloads = self.get_payloads(queries)
payloads = self.get_payloads(queries, max_results)
url = "https://google.serper.dev/search"
headers = self.get_headers()
return url, payloads, headers
@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
return res
def get_payloads(self, queries: list[str]) -> Dict[str, str]:
def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
"""Get payloads for Serper."""
payloads = []
for query in queries:
_payload = {
"q": query,
"num": max_results,
}
payloads.append({**self.payload, **_payload})
return json.dumps(payloads, sort_keys=True)
@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
return headers
@staticmethod
def _process_response(res: dict) -> str:
def _process_response(res: dict, as_string: bool = False) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
@ -117,4 +121,10 @@ class SerperWrapper(BaseModel):
if res.get("organic"):
toret_l += [get_focused(i) for i in res.get("organic")]
return str(toret) + '\n' + str(toret_l)
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
if __name__ == "__main__":
import fire
fire.Fire(SerperWrapper().run)

View file

@ -1,22 +1,20 @@
#!/usr/bin/env python
from __future__ import annotations
import asyncio
import importlib
from typing import Any, Callable, Coroutine, overload
import importlib
from typing import Any, Callable, Coroutine, Literal, overload
from metagpt.config import CONFIG
from metagpt.tools import WebBrowserEngineType
from bs4 import BeautifulSoup
from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
engine: WebBrowserEngineType | None = None,
run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None,
parse_func: Callable[[str], str] | None = None,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
engine = engine or CONFIG.web_browser_engine
@ -30,30 +28,25 @@ class WebBrowserEngine:
run_func = run_func
else:
raise NotImplementedError
self.parse_func = parse_func or get_page_content
self.run_func = run_func
self.engine = engine
@overload
async def run(self, url: str) -> str:
async def run(self, url: str) -> WebPage:
...
@overload
async def run(self, url: str, *urls: str) -> list[str]:
async def run(self, url: str, *urls: str) -> list[WebPage]:
...
async def run(self, url: str, *urls: str) -> str | list[str]:
page = await self.run_func(url, *urls)
if isinstance(page, str):
return self.parse_func(page)
return [self.parse_func(i) for i in page]
def get_page_content(page: str):
soup = BeautifulSoup(page, "html.parser")
return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"]))
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
return await self.run_func(url, *urls)
if __name__ == "__main__":
text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/"))
print(text)
import fire
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
fire.Fire(main)

View file

@ -2,12 +2,15 @@
from __future__ import annotations
import asyncio
from pathlib import Path
import sys
from pathlib import Path
from typing import Literal
from playwright.async_api import async_playwright
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.utils.parse_html import WebPage
class PlaywrightWrapper:
@ -16,7 +19,7 @@ class PlaywrightWrapper:
To use this module, you should have the `playwright` Python package installed and ensure that
the required browsers are also installed. You can install playwright by running the command
`pip install metagpt[playwright]` and download the necessary browser binaries by running the
command `playwright install` for the first time."
command `playwright install` for the first time.
"""
def __init__(
@ -40,27 +43,30 @@ class PlaywrightWrapper:
self._context_kwargs = context_kwargs
self._has_run_precheck = False
async def run(self, url: str, *urls: str) -> str | list[str]:
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
async with async_playwright() as ap:
browser_type = getattr(ap, self.browser_type)
await self._run_precheck(browser_type)
browser = await browser_type.launch(**self.launch_kwargs)
async def _scrape(url):
context = await browser.new_context(**self._context_kwargs)
page = await context.new_page()
async with page:
try:
await page.goto(url)
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
content = await page.content()
return content
except Exception as e:
return f"Fail to load page content for {e}"
_scrape = self._scrape
if urls:
return await asyncio.gather(_scrape(url), *(_scrape(i) for i in urls))
return await _scrape(url)
return await asyncio.gather(_scrape(browser, url), *(_scrape(browser, i) for i in urls))
return await _scrape(browser, url)
async def _scrape(self, browser, url):
context = await browser.new_context(**self._context_kwargs)
page = await context.new_page()
async with page:
try:
await page.goto(url)
await page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
html = await page.content()
inner_text = await page.evaluate("() => document.body.innerText")
except Exception as e:
inner_text = f"Fail to load page content for {e}"
html = ""
return WebPage(inner_text=inner_text, html=html, url=url)
async def _run_precheck(self, browser_type):
if self._has_run_precheck:
@ -72,6 +78,10 @@ class PlaywrightWrapper:
if CONFIG.global_proxy:
kwargs["env"] = {"ALL_PROXY": CONFIG.global_proxy}
await _install_browsers(self.browser_type, **kwargs)
if self._has_run_precheck:
return
if not executable_path.exists():
parts = executable_path.parts
available_paths = list(Path(*parts[:-3]).glob(f"{self.browser_type}-*"))
@ -85,25 +95,37 @@ class PlaywrightWrapper:
self._has_run_precheck = True
def _get_install_lock():
global _install_lock
if _install_lock is None:
_install_lock = asyncio.Lock()
return _install_lock
async def _install_browsers(*browsers, **kwargs) -> None:
process = await asyncio.create_subprocess_exec(
sys.executable,
"-m",
"playwright",
"install",
*browsers,
"--with-deps",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
**kwargs,
)
async with _get_install_lock():
browsers = [i for i in browsers if i not in _install_cache]
if not browsers:
return
process = await asyncio.create_subprocess_exec(
sys.executable,
"-m",
"playwright",
"install",
*browsers,
# "--with-deps",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
**kwargs,
)
await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning))
await asyncio.gather(_log_stream(process.stdout, logger.info), _log_stream(process.stderr, logger.warning))
if await process.wait() == 0:
logger.info(f"Install browser for playwright successfully.")
else:
logger.warning(f"Fail to install browser for playwright.")
if await process.wait() == 0:
logger.info("Install browser for playwright successfully.")
else:
logger.warning("Fail to install browser for playwright.")
_install_cache.update(browsers)
async def _log_stream(sr, log_func):
@ -114,8 +136,14 @@ async def _log_stream(sr, log_func):
log_func(f"[playwright install browser]: {line.decode().strip()}")
_install_lock: asyncio.Lock = None
_install_cache = set()
if __name__ == "__main__":
for i in ("chromium", "firefox", "webkit"):
text = asyncio.run(PlaywrightWrapper(i).run("https://httpbin.org/ip"))
print(text)
print(i)
import fire
async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs):
return await PlaywrightWrapper(browser_type, **kwargs).run(url, *urls)
fire.Fire(main)

View file

@ -2,16 +2,17 @@
from __future__ import annotations
import asyncio
from copy import deepcopy
import importlib
from concurrent import futures
from copy import deepcopy
from typing import Literal
from metagpt.config import CONFIG
import asyncio
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from concurrent import futures
from metagpt.config import CONFIG
from metagpt.utils.parse_html import WebPage
class SeleniumWrapper:
@ -48,7 +49,7 @@ class SeleniumWrapper:
self.loop = loop
self.executor = executor
async def run(self, url: str, *urls: str) -> str | list[str]:
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
await self._run_precheck()
_scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url)
@ -69,9 +70,15 @@ class SeleniumWrapper:
def _scrape_website(self, url):
with self._get_driver() as driver:
driver.get(url)
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
return driver.page_source
try:
driver.get(url)
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
inner_text = driver.execute_script("return document.body.innerText;")
html = driver.page_source
except Exception as e:
inner_text = f"Fail to load page content for {e}"
html = ""
return WebPage(inner_text=inner_text, html=html, url=url)
_webdriver_manager_types = {
@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
def _get_driver():
options = Options()
options.add_argument("--headless")
options.add_argument("--enable-javascript")
if browser_type == "chrome":
options.add_argument("--no-sandbox")
for i in args:
@ -107,5 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
if __name__ == "__main__":
text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/"))
print(text)
import fire
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
fire.Fire(main)

View file

@ -0,0 +1,57 @@
#!/usr/bin/env python
from __future__ import annotations
from typing import Generator, Optional
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pydantic import BaseModel
class WebPage(BaseModel):
inner_text: str
html: str
url: str
class Config:
underscore_attrs_are_private = True
_soup : Optional[BeautifulSoup] = None
_title: Optional[str] = None
@property
def soup(self) -> BeautifulSoup:
if self._soup is None:
self._soup = BeautifulSoup(self.html, "html.parser")
return self._soup
@property
def title(self):
if self._title is None:
title_tag = self.soup.find("title")
self._title = title_tag.text.strip() if title_tag is not None else ""
return self._title
def get_links(self) -> Generator[str, None, None]:
for i in self.soup.find_all("a", href=True):
url = i["href"]
result = urlparse(url)
if not result.scheme and result.path:
yield urljoin(self.url, url)
elif url.startswith(("http://", "https://")):
yield urljoin(self.url, url)
def get_html_content(page: str, base: str):
soup = _get_soup(page)
return soup.get_text(strip=True)
def _get_soup(page: str):
soup = BeautifulSoup(page, "html.parser")
# https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
for s in soup(["style", "script", "[document]", "head", "title"]):
s.extract()
return soup

124
metagpt/utils/text.py Normal file
View file

@ -0,0 +1,124 @@
from typing import Generator, Sequence
from metagpt.utils.token_counter import TOKEN_MAX, count_string_tokens
def reduce_message_length(msgs: Generator[str, None, None], model_name: str, system_text: str, reserved: int = 0,) -> str:
"""Reduce the length of concatenated message segments to fit within the maximum token size.
Args:
msgs: A generator of strings representing progressively shorter valid prompts.
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
system_text: The system prompts.
reserved: The number of reserved tokens.
Returns:
The concatenated message segments reduced to fit within the maximum token size.
Raises:
RuntimeError: If it fails to reduce the concatenated message length.
"""
max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
for msg in msgs:
if count_string_tokens(msg, model_name) < max_token:
return msg
raise RuntimeError("fail to reduce message length")
def generate_prompt_chunk(
text: str,
prompt_template: str,
model_name: str,
system_text: str,
reserved: int = 0,
) -> Generator[str, None, None]:
"""Split the text into chunks of a maximum token size.
Args:
text: The text to split.
prompt_template: The template for the prompt, containing a single `{}` placeholder. For example, "### Reference\n{}".
model_name: The name of the encoding to use. (e.g., "gpt-3.5-turbo")
system_text: The system prompts.
reserved: The number of reserved tokens.
Yields:
The chunk of text.
"""
paragraphs = text.splitlines(keepends=True)
current_token = 0
current_lines = []
reserved = reserved + count_string_tokens(prompt_template+system_text, model_name)
# 100 is a magic number to ensure the maximum context length is not exceeded
max_token = TOKEN_MAX.get(model_name, 2048) - reserved - 100
while paragraphs:
paragraph = paragraphs.pop(0)
token = count_string_tokens(paragraph, model_name)
if current_token + token <= max_token:
current_lines.append(paragraph)
current_token += token
elif token > max_token:
paragraphs = split_paragraph(paragraph) + paragraphs
continue
else:
yield prompt_template.format("".join(current_lines))
current_lines = [paragraph]
current_token = token
if current_lines:
yield prompt_template.format("".join(current_lines))
def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str]:
"""Split a paragraph into multiple parts.
Args:
paragraph: The paragraph to split.
sep: The separator character.
count: The number of parts to split the paragraph into.
Returns:
A list of split parts of the paragraph.
"""
for i in sep:
sentences = list(_split_text_with_ends(paragraph, i))
if len(sentences) <= 1:
continue
ret = ["".join(j) for j in _split_by_count(sentences, count)]
return ret
return _split_by_count(paragraph, count)
def decode_unicode_escape(text: str) -> str:
"""Decode a text with unicode escape sequences.
Args:
text: The text to decode.
Returns:
The decoded text.
"""
return text.encode("utf-8").decode("unicode_escape", "ignore")
def _split_by_count(lst: Sequence , count: int):
avg = len(lst) // count
remainder = len(lst) % count
start = 0
for i in range(count):
end = start + avg + (1 if i < remainder else 0)
yield lst[start:end]
start = end
def _split_text_with_ends(text: str, sep: str = "."):
parts = []
for i in text:
parts.append(i)
if i == sep:
yield "".join(parts)
parts = []
if parts:
yield "".join(parts)

View file

@ -25,6 +25,21 @@ TOKEN_COSTS = {
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
"gpt-3.5-turbo-0613": 4096,
"gpt-3.5-turbo-16k": 16384,
"gpt-3.5-turbo-16k-0613": 16384,
"gpt-4-0314": 8192,
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-32k-0314": 32768,
"gpt-4-0613": 8192,
"text-embedding-ada-002": 8192,
}
def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"""Return the number of tokens used by a list of messages."""
try:
@ -39,7 +54,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0613"):
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
@ -79,3 +94,18 @@ def count_string_tokens(string: str, model_name: str) -> int:
"""
encoding = tiktoken.encoding_for_model(model_name)
return len(encoding.encode(string))
def get_max_completion_tokens(messages: list[dict], model: str, default: int) -> int:
"""Calculate the maximum number of completion tokens for a given model and list of messages.
Args:
messages: A list of messages.
model: The model name.
Returns:
The maximum number of completion tokens.
"""
if model not in TOKEN_MAX:
return default
return TOKEN_MAX[model] - count_message_tokens(messages)