Merge pull request #3 from send18/dev

Dev
This commit is contained in:
Guess 2023-09-04 20:22:17 +08:00 committed by GitHub
commit 1f8dd58f3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 159 additions and 101 deletions

View file

@ -45,3 +45,22 @@ entities:
returns:
type: string
format: base64
- name: web_search
description: Perform Google searches to provide real-time information.
id: web_search.web_search
x-prerequisite:
- name: SEARCH_ENGINE
description: "Supported values: serpapi/google/serper/ddg"
- name: SERPER_API_KEY
description: "SERPER API KEY, For more details, checkout: `https://serper.dev/api-key`"
arguments:
query: 'The search query. Required.'
max_results: 'The number of search results to retrieve. Default value: 6.'
examples:
- ask: 'Search for information about artificial intelligence'
answer: 'web_search(query="Search for information about artificial intelligence", max_results=6)'
- ask: 'Find news articles about climate change'
answer: 'web_search(query="Find news articles about climate change", max_results=6)'
returns:
type: string

View file

@ -55,7 +55,7 @@ class FaissStore(LocalStore):
store.index = index
def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs):
rsp = self.store.similarity_search(query, k=k)
rsp = self.store.similarity_search(query, k=k, **kwargs)
logger.debug(rsp)
if expand_cols:
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))

View file

@ -8,8 +8,6 @@
from metagpt.learn.text_to_image import text_to_image
from metagpt.learn.text_to_speech import text_to_speech
from metagpt.learn.google_search import google_search
__all__ = [
"text_to_image",
"text_to_speech",
]
__all__ = ["text_to_image", "text_to_speech", "google_search"]

View file

@ -0,0 +1,12 @@
from metagpt.tools.search_engine import SearchEngine
async def google_search(query: str, max_results: int = 6, **kwargs):
"""Perform a web search and retrieve search results.
:param query: The search query.
:param max_results: The number of search results to retrieve
:return: The web search results in markdown format.
"""
resluts = await SearchEngine().run(query, max_results=max_results, as_string=False)
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(resluts, 1))

View file

@ -38,13 +38,13 @@ class BaseGPTAPI(BaseChatbot):
rsp = self.completion(message)
return self.get_choice_text(rsp)
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, generator: bool = False) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
message = [self._default_system_msg(), self._user_msg(msg)]
try:
rsp = await self.acompletion_text(message, stream=True)
rsp = await self.acompletion_text(message, stream=True, generator=generator)
except Exception as e:
logger.exception(f"{e}")
logger.info(f"ask:{msg}, error:{e}")

View file

@ -77,41 +77,21 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
def __init__(self):
self.__init_openai(CONFIG)
self.llm = openai
self.model = CONFIG.openai_api_model
self.auto_max_tokens = False
self.rpm = int(CONFIG.get("RPM", 10))
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openai(self, config):
openai.api_key = config.openai_api_key
if config.openai_api_base:
openai.api_base = config.openai_api_base
if config.openai_api_type:
openai.api_type = config.openai_api_type
openai.api_version = config.openai_api_version
self.rpm = int(config.get("RPM", 10))
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response = await self.async_retry_call(
openai.ChatCompletion.acreate, **self._cons_kwargs(messages), stream=True
)
# create variables to collect the stream of chunks
collected_chunks = []
collected_messages = []
# iterate through the stream of events
async for chunk in response:
collected_chunks.append(chunk) # save the event response
chunk_message = chunk["choices"][0]["delta"] # extract the message
collected_messages.append(chunk_message) # save the message
if "content" in chunk_message:
print(chunk_message["content"], end="")
print()
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
yield chunk_message["content"]
def _cons_kwargs(self, messages: list[dict]) -> dict:
if CONFIG.openai_api_type == "azure":
@ -133,6 +113,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"temperature": 0.3,
}
kwargs["timeout"] = 3
kwargs["api_base"] = CONFIG.openai_api_base
kwargs["api_key"] = CONFIG.openai_api_key
kwargs["api_type"] = CONFIG.openai_api_type
kwargs["api_version"] = CONFIG.openai_api_version
return kwargs
async def _achat_completion(self, messages: list[dict]) -> dict:
@ -162,10 +146,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)
resp = self._achat_completion_stream(messages)
if generator:
return resp
collected_messages = []
async for i in resp:
print(i, end="")
collected_messages.append(i)
full_reply_content = "".join(collected_messages)
usage = self._calc_usage(messages, full_reply_content)
self._update_costs(usage)
return full_reply_content
rsp = await self._achat_completion(messages)
return self.get_choice_text(rsp)
@ -226,18 +223,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
async def get_summary(self, text: str, max_words=200):
async def get_summary(self, text: str, max_words=200, keep_language: bool = False):
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
while max_count > 0:
if len(text) < max_token_count:
return await self._get_summary(text, max_words=max_words)
return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
summaries = []
for ws in text_windows:
response = await self._get_summary(ws, max_words=max_words)
response = await self._get_summary(text=ws, max_words=max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
return summaries[0]
@ -248,11 +245,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
max_count -= 1 # safeguard
raise openai.error.InvalidRequestError("text too long")
async def _get_summary(self, text: str, max_words=20):
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
"""Generate text summary"""
if len(text) < max_words:
return text
command = f"Translate the above content into a summary of less than {max_words} words."
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.info(f"summary ask:{msg}")
response = await self.aask(msg=msg, system_msgs=[])

View file

@ -4,12 +4,11 @@
@Time : 2023/5/6 20:15
@Author : alexanderwu
@File : search_engine.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import importlib
from typing import Callable, Coroutine, Literal, overload, Dict
from typing import Callable, Coroutine, Literal, overload
from metagpt.config import CONFIG
from metagpt.tools import SearchEngineType
@ -28,23 +27,23 @@ class SearchEngine:
"""
def __init__(
self,
engine: SearchEngineType | None = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None
self,
engine: SearchEngineType | None = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
):
engine = engine or CONFIG.search_engine
if engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).SerpAPIWrapper().run
elif engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).SerperWrapper().run
elif engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).GoogleAPIWrapper().run
elif engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).DDGAPIWrapper().run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
else:
@ -54,19 +53,19 @@ class SearchEngine:
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
) -> str:
...
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
) -> list[dict[str, str]]:
...

View file

@ -1,14 +1,11 @@
#!/usr/bin/env python
"""
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Literal, overload, Optional
from typing import Literal, overload
try:
from duckduckgo_search import DDGS
@ -18,6 +15,8 @@ except ImportError:
"You can install it by running the command: `pip install -e.[search-ddg]`"
)
from metagpt.config import CONFIG
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
@ -26,44 +25,43 @@ class DDGAPIWrapper:
"""
def __init__(
self,
*,
global_proxy: Optional[str] = None,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
kwargs = {}
if global_proxy:
kwargs["proxies"] = global_proxy
if CONFIG.global_proxy:
kwargs["proxies"] = CONFIG.global_proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
focus: list[str] | None = None,
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
focus: list[str] | None = None,
) -> str:
...
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
focus: list[str] | None = None,
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
focus: list[str] | None = None,
) -> list[dict[str, str]]:
...
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
self,
query: str,
max_results: int = 8,
as_string: bool = True,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API

View file

@ -1,8 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import asyncio
@ -14,6 +11,7 @@ from urllib.parse import urlparse
import httplib2
from pydantic import BaseModel, validator
from metagpt.config import CONFIG
from metagpt.logs import logger
try:
@ -29,7 +27,6 @@ except ImportError:
class GoogleAPIWrapper(BaseModel):
google_api_key: Optional[str] = None
google_cse_id: Optional[str] = None
global_proxy: Optional[str] = None
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
@ -39,6 +36,7 @@ class GoogleAPIWrapper(BaseModel):
@validator("google_api_key", always=True)
@classmethod
def check_google_api_key(cls, val: str):
val = val or CONFIG.google_api_key
if not val:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
@ -49,7 +47,8 @@ class GoogleAPIWrapper(BaseModel):
@validator("google_cse_id", always=True)
@classmethod
def check_google_cse_id(cls, val):
def check_google_cse_id(cls, val: str):
val = val or CONFIG.google_cse_id
if not val:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
@ -61,8 +60,8 @@ class GoogleAPIWrapper(BaseModel):
@property
def google_api_client(self):
build_kwargs = {"developerKey": self.google_api_key}
if self.global_proxy:
parse_result = urlparse(self.global_proxy)
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"

View file

@ -4,14 +4,13 @@
@Time : 2023/5/23 18:27
@Author : alexanderwu
@File : search_engine_serpapi.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field, validator
from metagpt.config import Config
from metagpt.config import CONFIG
class SerpAPIWrapper(BaseModel):
@ -33,6 +32,7 @@ class SerpAPIWrapper(BaseModel):
@validator("serpapi_api_key", always=True)
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or CONFIG.serpapi_api_key
if not val:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
@ -112,4 +112,4 @@ class SerpAPIWrapper(BaseModel):
if __name__ == "__main__":
import fire
fire.Fire(SerpAPIWrapper(Config().runtime_options).run)
fire.Fire(SerpAPIWrapper().run)

View file

@ -4,7 +4,6 @@
@Time : 2023/5/23 18:27
@Author : alexanderwu
@File : search_engine_serpapi.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import json
from typing import Any, Dict, Optional, Tuple
@ -12,6 +11,8 @@ from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field, validator
from metagpt.config import CONFIG
class SerperWrapper(BaseModel):
search_engine: Any #: :meta private:
@ -25,6 +26,7 @@ class SerperWrapper(BaseModel):
@validator("serper_api_key", always=True)
@classmethod
def check_serper_api_key(cls, val: str):
val = val or CONFIG.serper_api_key
if not val:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "

View file

@ -44,8 +44,9 @@ class S3:
"""
try:
async with self.session.client(**self.auth_config) as client:
with open(local_path, "rb") as file:
await client.put_object(Body=file, Bucket=bucket, Key=object_name)
async with aiofiles.open(local_path, mode="rb") as reader:
body = await reader.read()
await client.put_object(Body=body, Bucket=bucket, Key=object_name)
logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
except Exception as e:
logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
@ -119,12 +120,12 @@ class S3:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
stream = s3_object["Body"]
with open(local_path, "wb") as local_file:
async with aiofiles.open(local_path, mode="wb") as writer:
while True:
file_data = await stream.read(chunk_size)
if not file_data:
break
local_file.write(file_data)
await writer.write(file_data)
except Exception as e:
logger.error(f"Failed to download the file from S3: {e}")
raise e

View file

@ -0,0 +1,27 @@
import asyncio
from pydantic import BaseModel
from metagpt.learn.google_search import google_search
async def mock_google_search():
class Input(BaseModel):
input: str
inputs = [{"input": "ai agent"}]
for i in inputs:
seed = Input(**i)
result = await google_search(seed.input)
assert result != ""
def test_suite():
loop = asyncio.get_event_loop()
task = loop.create_task(mock_google_search())
loop.run_until_complete(task)
if __name__ == "__main__":
test_suite()

View file

@ -4,13 +4,11 @@
@Time : 2023/5/2 17:46
@Author : alexanderwu
@File : test_search_engine.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import pytest
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@ -18,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine
class MockSearchEnine:
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
rets = [
{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)
]
return "\n".join(rets) if as_string else rets
@ -36,13 +36,16 @@ class MockSearchEnine:
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string):
conf = Config()
search_engine = SearchEngine(options=conf.runtime_options, engine=search_engine_typpe, run_func=run_func)
rsp = await search_engine.run(query="metagpt", max_results=max_results, as_string=as_string)
async def test_search_engine(
search_engine_typpe,
run_func,
max_results,
as_string,
):
search_engine = SearchEngine(search_engine_typpe, run_func)
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
logger.info(rsp)
if as_string:
assert isinstance(rsp, str)