From 07a1d229cf08f89595c10f7d198ca9aa6b0e550d Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Sat, 2 Sep 2023 18:03:31 +0800 Subject: [PATCH 1/8] restoresearch engine code --- metagpt/tools/search_engine.py | 33 ++++++++-------- metagpt/tools/search_engine_ddg.py | 48 +++++++++++------------ metagpt/tools/search_engine_googleapi.py | 13 +++--- metagpt/tools/search_engine_serpapi.py | 6 +-- metagpt/tools/search_engine_serper.py | 4 +- tests/metagpt/tools/test_search_engine.py | 19 +++++---- 6 files changed, 62 insertions(+), 61 deletions(-) diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py index 5b8b7f046..db8c091d1 100644 --- a/metagpt/tools/search_engine.py +++ b/metagpt/tools/search_engine.py @@ -4,12 +4,11 @@ @Time : 2023/5/6 20:15 @Author : alexanderwu @File : search_engine.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from __future__ import annotations import importlib -from typing import Callable, Coroutine, Literal, overload, Dict +from typing import Callable, Coroutine, Literal, overload from metagpt.config import CONFIG from metagpt.tools import SearchEngineType @@ -28,23 +27,23 @@ class SearchEngine: """ def __init__( - self, - engine: SearchEngineType | None = None, - run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None + self, + engine: SearchEngineType | None = None, + run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None, ): engine = engine or CONFIG.search_engine if engine == SearchEngineType.SERPAPI_GOOGLE: module = "metagpt.tools.search_engine_serpapi" - run_func = importlib.import_module(module).SerpAPIWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).SerpAPIWrapper().run elif engine == SearchEngineType.SERPER_GOOGLE: module = "metagpt.tools.search_engine_serper" - run_func = importlib.import_module(module).SerperWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).SerperWrapper().run elif engine == SearchEngineType.DIRECT_GOOGLE: module = "metagpt.tools.search_engine_googleapi" - run_func = importlib.import_module(module).GoogleAPIWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).GoogleAPIWrapper().run elif engine == SearchEngineType.DUCK_DUCK_GO: module = "metagpt.tools.search_engine_ddg" - run_func = importlib.import_module(module).DDGAPIWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).DDGAPIWrapper().run elif engine == SearchEngineType.CUSTOM_ENGINE: pass # run_func = run_func else: @@ -54,19 +53,19 @@ class SearchEngine: @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[True] = True, + self, + query: str, + max_results: int = 8, + as_string: Literal[True] = True, ) -> str: ... @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[False] = False, + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, ) -> list[dict[str, str]]: ... diff --git a/metagpt/tools/search_engine_ddg.py b/metagpt/tools/search_engine_ddg.py index 78562c77e..57bc61b82 100644 --- a/metagpt/tools/search_engine_ddg.py +++ b/metagpt/tools/search_engine_ddg.py @@ -1,14 +1,11 @@ #!/usr/bin/env python -""" -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. -""" from __future__ import annotations import asyncio import json from concurrent import futures -from typing import Literal, overload, Optional +from typing import Literal, overload try: from duckduckgo_search import DDGS @@ -18,6 +15,8 @@ except ImportError: "You can install it by running the command: `pip install -e.[search-ddg]`" ) +from metagpt.config import CONFIG + class DDGAPIWrapper: """Wrapper around duckduckgo_search API. @@ -26,44 +25,43 @@ class DDGAPIWrapper: """ def __init__( - self, - *, - global_proxy: Optional[str] = None, - loop: asyncio.AbstractEventLoop | None = None, - executor: futures.Executor | None = None, + self, + *, + loop: asyncio.AbstractEventLoop | None = None, + executor: futures.Executor | None = None, ): kwargs = {} - if global_proxy: - kwargs["proxies"] = global_proxy + if CONFIG.global_proxy: + kwargs["proxies"] = CONFIG.global_proxy self.loop = loop self.executor = executor self.ddgs = DDGS(**kwargs) @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[True] = True, - focus: list[str] | None = None, + self, + query: str, + max_results: int = 8, + as_string: Literal[True] = True, + focus: list[str] | None = None, ) -> str: ... @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[False] = False, - focus: list[str] | None = None, + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, + focus: list[str] | None = None, ) -> list[dict[str, str]]: ... async def run( - self, - query: str, - max_results: int = 8, - as_string: bool = True, + self, + query: str, + max_results: int = 8, + as_string: bool = True, ) -> str | list[dict]: """Return the results of a Google search using the official Google API diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py index b5aeb5875..b9faf2ced 100644 --- a/metagpt/tools/search_engine_googleapi.py +++ b/metagpt/tools/search_engine_googleapi.py @@ -1,8 +1,5 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. -""" from __future__ import annotations import asyncio @@ -14,6 +11,7 @@ from urllib.parse import urlparse import httplib2 from pydantic import BaseModel, validator +from metagpt.config import CONFIG from metagpt.logs import logger try: @@ -29,7 +27,6 @@ except ImportError: class GoogleAPIWrapper(BaseModel): google_api_key: Optional[str] = None google_cse_id: Optional[str] = None - global_proxy: Optional[str] = None loop: Optional[asyncio.AbstractEventLoop] = None executor: Optional[futures.Executor] = None @@ -39,6 +36,7 @@ class GoogleAPIWrapper(BaseModel): @validator("google_api_key", always=True) @classmethod def check_google_api_key(cls, val: str): + val = val or CONFIG.google_api_key if not val: raise ValueError( "To use, make sure you provide the google_api_key when constructing an object. Alternatively, " @@ -49,7 +47,8 @@ class GoogleAPIWrapper(BaseModel): @validator("google_cse_id", always=True) @classmethod - def check_google_cse_id(cls, val): + def check_google_cse_id(cls, val: str): + val = val or CONFIG.google_cse_id if not val: raise ValueError( "To use, make sure you provide the google_cse_id when constructing an object. Alternatively, " @@ -61,8 +60,8 @@ class GoogleAPIWrapper(BaseModel): @property def google_api_client(self): build_kwargs = {"developerKey": self.google_api_key} - if self.global_proxy: - parse_result = urlparse(self.global_proxy) + if CONFIG.global_proxy: + parse_result = urlparse(CONFIG.global_proxy) proxy_type = parse_result.scheme if proxy_type == "https": proxy_type = "http" diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 1b93a91e9..750184198 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -4,14 +4,13 @@ @Time : 2023/5/23 18:27 @Author : alexanderwu @File : search_engine_serpapi.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from typing import Any, Dict, Optional, Tuple import aiohttp from pydantic import BaseModel, Field, validator -from metagpt.config import Config +from metagpt.config import CONFIG class SerpAPIWrapper(BaseModel): @@ -33,6 +32,7 @@ class SerpAPIWrapper(BaseModel): @validator("serpapi_api_key", always=True) @classmethod def check_serpapi_api_key(cls, val: str): + val = val or CONFIG.serpapi_api_key if not val: raise ValueError( "To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, " @@ -112,4 +112,4 @@ class SerpAPIWrapper(BaseModel): if __name__ == "__main__": import fire - fire.Fire(SerpAPIWrapper(Config().runtime_options).run) + fire.Fire(SerpAPIWrapper().run) diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 849839f05..0eec2694b 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -4,7 +4,6 @@ @Time : 2023/5/23 18:27 @Author : alexanderwu @File : search_engine_serpapi.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ import json from typing import Any, Dict, Optional, Tuple @@ -12,6 +11,8 @@ from typing import Any, Dict, Optional, Tuple import aiohttp from pydantic import BaseModel, Field, validator +from metagpt.config import CONFIG + class SerperWrapper(BaseModel): search_engine: Any #: :meta private: @@ -25,6 +26,7 @@ class SerperWrapper(BaseModel): @validator("serper_api_key", always=True) @classmethod def check_serper_api_key(cls, val: str): + val = val or CONFIG.serper_api_key if not val: raise ValueError( "To use, make sure you provide the serper_api_key when constructing an object. Alternatively, " diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 35ccdf78b..25bce124a 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -4,13 +4,11 @@ @Time : 2023/5/2 17:46 @Author : alexanderwu @File : test_search_engine.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from __future__ import annotations import pytest -from metagpt.config import Config from metagpt.logs import logger from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -18,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine class MockSearchEnine: async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: - rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)] + rets = [ + {"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results) + ] return "\n".join(rets) if as_string else rets @@ -36,13 +36,16 @@ class MockSearchEnine: (SearchEngineType.DUCK_DUCK_GO, None, 6, False), (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False), (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False), - ], ) -async def test_search_engine(search_engine_typpe, run_func, max_results, as_string): - conf = Config() - search_engine = SearchEngine(options=conf.runtime_options, engine=search_engine_typpe, run_func=run_func) - rsp = await search_engine.run(query="metagpt", max_results=max_results, as_string=as_string) +async def test_search_engine( + search_engine_typpe, + run_func, + max_results, + as_string, +): + search_engine = SearchEngine(search_engine_typpe, run_func) + rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string) logger.info(rsp) if as_string: assert isinstance(rsp, str) From 7bd62b6a498543d4fdf95e62e643eebed8743c3f Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Sat, 2 Sep 2023 21:04:51 +0800 Subject: [PATCH 2/8] add google search skill --- .well-known/skills.yaml | 19 ++++++++++++++++ metagpt/learn/__init__.py | 6 ++--- metagpt/learn/google_search.py | 12 ++++++++++ tests/metagpt/learn/test_google_search.py | 27 +++++++++++++++++++++++ 4 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 metagpt/learn/google_search.py create mode 100644 tests/metagpt/learn/test_google_search.py diff --git a/.well-known/skills.yaml b/.well-known/skills.yaml index 06b9ffd0c..009368dbe 100644 --- a/.well-known/skills.yaml +++ b/.well-known/skills.yaml @@ -45,3 +45,22 @@ entities: returns: type: string format: base64 + + - name: web_search + description: Perform Google searches to provide real-time information. + id: web_search.web_search + x-prerequisite: + - name: SEARCH_ENGINE + description: "Supported values: serpapi/google/serper/ddg" + - name: SERPER_API_KEY + description: "SERPER API KEY, For more details, checkout: `https://serper.dev/api-key`" + arguments: + query: 'The search query. Required.' + max_results: 'The number of search results to retrieve. Default value: 6.' + examples: + - ask: 'Search for information about artificial intelligence' + answer: 'web_search(query="Search for information about artificial intelligence", max_results=6)' + - ask: 'Find news articles about climate change' + answer: 'web_search(query="Find news articles about climate change", max_results=6)' + returns: + type: string \ No newline at end of file diff --git a/metagpt/learn/__init__.py b/metagpt/learn/__init__.py index c8270dbfb..bab9f3e37 100644 --- a/metagpt/learn/__init__.py +++ b/metagpt/learn/__init__.py @@ -8,8 +8,6 @@ from metagpt.learn.text_to_image import text_to_image from metagpt.learn.text_to_speech import text_to_speech +from metagpt.learn.google_search import google_search -__all__ = [ - "text_to_image", - "text_to_speech", -] \ No newline at end of file +__all__ = ["text_to_image", "text_to_speech", "google_search"] diff --git a/metagpt/learn/google_search.py b/metagpt/learn/google_search.py new file mode 100644 index 000000000..ef099fe94 --- /dev/null +++ b/metagpt/learn/google_search.py @@ -0,0 +1,12 @@ +from metagpt.tools.search_engine import SearchEngine + + +async def google_search(query: str, max_results: int = 6, **kwargs): + """Perform a web search and retrieve search results. + + :param query: The search query. + :param max_results: The number of search results to retrieve + :return: The web search results in markdown format. + """ + resluts = await SearchEngine().run(query, max_results=max_results, as_string=False) + return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(resluts, 1)) diff --git a/tests/metagpt/learn/test_google_search.py b/tests/metagpt/learn/test_google_search.py new file mode 100644 index 000000000..da32e8923 --- /dev/null +++ b/tests/metagpt/learn/test_google_search.py @@ -0,0 +1,27 @@ +import asyncio + +from pydantic import BaseModel + +from metagpt.learn.google_search import google_search + + +async def mock_google_search(): + class Input(BaseModel): + input: str + + inputs = [{"input": "ai agent"}] + + for i in inputs: + seed = Input(**i) + result = await google_search(seed.input) + assert result != "" + + +def test_suite(): + loop = asyncio.get_event_loop() + task = loop.create_task(mock_google_search()) + loop.run_until_complete(task) + + +if __name__ == "__main__": + test_suite() From 2856acb3f343b7a4d14643c52352ed2da6bc3119 Mon Sep 17 00:00:00 2001 From: hongjiongteng Date: Sun, 3 Sep 2023 17:22:36 +0800 Subject: [PATCH 3/8] faiss search kwargs --- metagpt/document_store/faiss_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index 051bc2507..b034f40b2 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -51,7 +51,7 @@ class FaissStore(LocalStore): store.index = index def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs): - rsp = self.store.similarity_search(query, k=k) + rsp = self.store.similarity_search(query, k=k, **kwargs) logger.debug(rsp) if expand_cols: return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp])) From b036b5d22ee17c59f0d01124dea98c34e8ff0a99 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Sun, 3 Sep 2023 22:22:26 +0800 Subject: [PATCH 4/8] remove openai global settings --- metagpt/provider/openai_api.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index b2a0faca5..844cd4c1c 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -77,21 +77,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ def __init__(self): - self.__init_openai(CONFIG) self.llm = openai self.model = CONFIG.openai_api_model self.auto_max_tokens = False + self.rpm = int(CONFIG.get("RPM", 10)) RateLimiter.__init__(self, rpm=self.rpm) - def __init_openai(self, config): - openai.api_key = config.openai_api_key - if config.openai_api_base: - openai.api_base = config.openai_api_base - if config.openai_api_type: - openai.api_type = config.openai_api_type - openai.api_version = config.openai_api_version - self.rpm = int(config.get("RPM", 10)) - async def _achat_completion_stream(self, messages: list[dict]) -> str: response = await self.async_retry_call( openai.ChatCompletion.acreate, **self._cons_kwargs(messages), stream=True @@ -133,6 +124,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): "temperature": 0.3, } kwargs["timeout"] = 3 + kwargs["api_base"] = CONFIG.openai_api_base + kwargs["api_key"] = CONFIG.openai_api_key + kwargs["api_type"] = CONFIG.openai_api_type + kwargs["api_version"] = CONFIG.openai_api_version return kwargs async def _achat_completion(self, messages: list[dict]) -> dict: From 87f4c22b6050ea7b951498b03d3cc9149dc54fb9 Mon Sep 17 00:00:00 2001 From: Stitch-z <284618289@qq.com> Date: Mon, 4 Sep 2023 10:48:48 +0800 Subject: [PATCH 5/8] update: aioboto3 client async open file --- metagpt/utils/s3.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/metagpt/utils/s3.py b/metagpt/utils/s3.py index 74c3f1654..96b457972 100644 --- a/metagpt/utils/s3.py +++ b/metagpt/utils/s3.py @@ -44,8 +44,9 @@ class S3: """ try: async with self.session.client(**self.auth_config) as client: - with open(local_path, "rb") as file: - await client.put_object(Body=file, Bucket=bucket, Key=object_name) + async with aiofiles.open(local_path, mode="rb") as reader: + body = await reader.read() + await client.put_object(Body=body, Bucket=bucket, Key=object_name) logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.") except Exception as e: logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}") @@ -119,12 +120,12 @@ class S3: async with self.session.client(**self.auth_config) as client: s3_object = await client.get_object(Bucket=bucket, Key=object_name) stream = s3_object["Body"] - with open(local_path, "wb") as local_file: + async with aiofiles.open(local_path, mode="wb") as writer: while True: file_data = await stream.read(chunk_size) if not file_data: break - local_file.write(file_data) + await writer.write(file_data) except Exception as e: logger.error(f"Failed to download the file from S3: {e}") raise e From d4878f23a0042bf983c1fef8947c649f7d4f4878 Mon Sep 17 00:00:00 2001 From: zhanglei Date: Mon, 4 Sep 2023 10:50:21 +0800 Subject: [PATCH 6/8] =?UTF-8?q?update:=E4=BF=AE=E6=94=B9get=5Fsummary?= =?UTF-8?q?=EF=BC=8C=E5=8A=A0=E4=B8=8A=E6=98=AF=E5=90=A6=E4=BF=9D=E6=8C=81?= =?UTF-8?q?=E8=AF=AD=E8=A8=80=E7=9A=84=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/provider/openai_api.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 844cd4c1c..26929575c 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -221,18 +221,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return CONFIG.max_tokens_rsp return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) - async def get_summary(self, text: str, max_words=200): + async def get_summary(self, text: str, max_words=200, keep_language: bool = False): max_token_count = DEFAULT_MAX_TOKENS max_count = 100 while max_count > 0: if len(text) < max_token_count: - return await self._get_summary(text, max_words=max_words) + return await self._get_summary(text=text, max_words=max_words,keep_language=keep_language) padding_size = 20 if max_token_count > 20 else 0 text_windows = self.split_texts(text, window_size=max_token_count - padding_size) summaries = [] for ws in text_windows: - response = await self._get_summary(ws, max_words=max_words) + response = await self._get_summary(text=ws, max_words=max_words,keep_language=keep_language) summaries.append(response) if len(summaries) == 1: return summaries[0] @@ -243,11 +243,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): max_count -= 1 # safeguard raise openai.error.InvalidRequestError("text too long") - async def _get_summary(self, text: str, max_words=20): + async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): """Generate text summary""" if len(text) < max_words: return text - command = f"Translate the above content into a summary of less than {max_words} words." + if keep_language: + command = f".Translate the above content into a summary of less than {max_words} words in language of the content." + else: + command = f"Translate the above content into a summary of less than {max_words} words." msg = text + "\n\n" + command logger.info(f"summary ask:{msg}") response = await self.aask(msg=msg, system_msgs=[]) From 9cc85d631ad15fe369f1cd647a4071ca31bd6a94 Mon Sep 17 00:00:00 2001 From: zhanglei Date: Mon, 4 Sep 2023 11:50:22 +0800 Subject: [PATCH 7/8] =?UTF-8?q?update:=E4=BF=AE=E6=94=B9get=5Fsummary?= =?UTF-8?q?=EF=BC=8C=E5=8A=A0=E4=B8=8A=E6=98=AF=E5=90=A6=E4=BF=9D=E6=8C=81?= =?UTF-8?q?=E8=AF=AD=E8=A8=80=E7=9A=84=E9=85=8D=E7=BD=AE,=E5=BC=BA?= =?UTF-8?q?=E8=B0=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/provider/openai_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 26929575c..5c11ed7a6 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -248,7 +248,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): if len(text) < max_words: return text if keep_language: - command = f".Translate the above content into a summary of less than {max_words} words in language of the content." + command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." else: command = f"Translate the above content into a summary of less than {max_words} words." msg = text + "\n\n" + command From 32c604a002e78e924d43a732e4b4bd7e3bce1faf Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Mon, 4 Sep 2023 17:21:21 +0800 Subject: [PATCH 8/8] add llm.aask generator --- metagpt/provider/base_gpt_api.py | 4 ++-- metagpt/provider/openai_api.py | 34 +++++++++++++++++--------------- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index af0cf2ec0..7351e6916 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -38,13 +38,13 @@ class BaseGPTAPI(BaseChatbot): rsp = self.completion(message) return self.get_choice_text(rsp) - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str: + async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, generator: bool = False) -> str: if system_msgs: message = self._system_msgs(system_msgs) + [self._user_msg(msg)] else: message = [self._default_system_msg(), self._user_msg(msg)] try: - rsp = await self.acompletion_text(message, stream=True) + rsp = await self.acompletion_text(message, stream=True, generator=generator) except Exception as e: logger.exception(f"{e}") logger.info(f"ask:{msg}, error:{e}") diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 5c11ed7a6..d0dd5b9d8 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -87,22 +87,11 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): response = await self.async_retry_call( openai.ChatCompletion.acreate, **self._cons_kwargs(messages), stream=True ) - # create variables to collect the stream of chunks - collected_chunks = [] - collected_messages = [] # iterate through the stream of events async for chunk in response: - collected_chunks.append(chunk) # save the event response chunk_message = chunk["choices"][0]["delta"] # extract the message - collected_messages.append(chunk_message) # save the message if "content" in chunk_message: - print(chunk_message["content"], end="") - print() - - full_reply_content = "".join([m.get("content", "") for m in collected_messages]) - usage = self._calc_usage(messages, full_reply_content) - self._update_costs(usage) - return full_reply_content + yield chunk_message["content"] def _cons_kwargs(self, messages: list[dict]) -> dict: if CONFIG.openai_api_type == "azure": @@ -157,10 +146,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False) -> str: """when streaming, print each token in place.""" if stream: - return await self._achat_completion_stream(messages) + resp = self._achat_completion_stream(messages) + if generator: + return resp + + collected_messages = [] + async for i in resp: + print(i, end="") + collected_messages.append(i) + + full_reply_content = "".join(collected_messages) + usage = self._calc_usage(messages, full_reply_content) + self._update_costs(usage) + return full_reply_content + rsp = await self._achat_completion(messages) return self.get_choice_text(rsp) @@ -226,13 +228,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): max_count = 100 while max_count > 0: if len(text) < max_token_count: - return await self._get_summary(text=text, max_words=max_words,keep_language=keep_language) + return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) padding_size = 20 if max_token_count > 20 else 0 text_windows = self.split_texts(text, window_size=max_token_count - padding_size) summaries = [] for ws in text_windows: - response = await self._get_summary(text=ws, max_words=max_words,keep_language=keep_language) + response = await self._get_summary(text=ws, max_words=max_words, keep_language=keep_language) summaries.append(response) if len(summaries) == 1: return summaries[0]