mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-14 15:25:17 +02:00
commit
1f8dd58f3b
14 changed files with 159 additions and 101 deletions
|
|
@ -45,3 +45,22 @@ entities:
|
|||
returns:
|
||||
type: string
|
||||
format: base64
|
||||
|
||||
- name: web_search
|
||||
description: Perform Google searches to provide real-time information.
|
||||
id: web_search.web_search
|
||||
x-prerequisite:
|
||||
- name: SEARCH_ENGINE
|
||||
description: "Supported values: serpapi/google/serper/ddg"
|
||||
- name: SERPER_API_KEY
|
||||
description: "SERPER API KEY, For more details, checkout: `https://serper.dev/api-key`"
|
||||
arguments:
|
||||
query: 'The search query. Required.'
|
||||
max_results: 'The number of search results to retrieve. Default value: 6.'
|
||||
examples:
|
||||
- ask: 'Search for information about artificial intelligence'
|
||||
answer: 'web_search(query="Search for information about artificial intelligence", max_results=6)'
|
||||
- ask: 'Find news articles about climate change'
|
||||
answer: 'web_search(query="Find news articles about climate change", max_results=6)'
|
||||
returns:
|
||||
type: string
|
||||
|
|
@ -55,7 +55,7 @@ class FaissStore(LocalStore):
|
|||
store.index = index
|
||||
|
||||
def search(self, query, expand_cols=False, sep='\n', *args, k=5, **kwargs):
|
||||
rsp = self.store.similarity_search(query, k=k)
|
||||
rsp = self.store.similarity_search(query, k=k, **kwargs)
|
||||
logger.debug(rsp)
|
||||
if expand_cols:
|
||||
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@
|
|||
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
from metagpt.learn.text_to_speech import text_to_speech
|
||||
from metagpt.learn.google_search import google_search
|
||||
|
||||
__all__ = [
|
||||
"text_to_image",
|
||||
"text_to_speech",
|
||||
]
|
||||
__all__ = ["text_to_image", "text_to_speech", "google_search"]
|
||||
|
|
|
|||
12
metagpt/learn/google_search.py
Normal file
12
metagpt/learn/google_search.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
async def google_search(query: str, max_results: int = 6, **kwargs):
|
||||
"""Perform a web search and retrieve search results.
|
||||
|
||||
:param query: The search query.
|
||||
:param max_results: The number of search results to retrieve
|
||||
:return: The web search results in markdown format.
|
||||
"""
|
||||
resluts = await SearchEngine().run(query, max_results=max_results, as_string=False)
|
||||
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(resluts, 1))
|
||||
|
|
@ -38,13 +38,13 @@ class BaseGPTAPI(BaseChatbot):
|
|||
rsp = self.completion(message)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str:
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, generator: bool = False) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
else:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
try:
|
||||
rsp = await self.acompletion_text(message, stream=True)
|
||||
rsp = await self.acompletion_text(message, stream=True, generator=generator)
|
||||
except Exception as e:
|
||||
logger.exception(f"{e}")
|
||||
logger.info(f"ask:{msg}, error:{e}")
|
||||
|
|
|
|||
|
|
@ -77,41 +77,21 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.auto_max_tokens = False
|
||||
self.rpm = int(CONFIG.get("RPM", 10))
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openai(self, config):
|
||||
openai.api_key = config.openai_api_key
|
||||
if config.openai_api_base:
|
||||
openai.api_base = config.openai_api_base
|
||||
if config.openai_api_type:
|
||||
openai.api_type = config.openai_api_type
|
||||
openai.api_version = config.openai_api_version
|
||||
self.rpm = int(config.get("RPM", 10))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await self.async_retry_call(
|
||||
openai.ChatCompletion.acreate, **self._cons_kwargs(messages), stream=True
|
||||
)
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
collected_messages = []
|
||||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
chunk_message = chunk["choices"][0]["delta"] # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
print(chunk_message["content"], end="")
|
||||
print()
|
||||
|
||||
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
yield chunk_message["content"]
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict]) -> dict:
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
|
|
@ -133,6 +113,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
"temperature": 0.3,
|
||||
}
|
||||
kwargs["timeout"] = 3
|
||||
kwargs["api_base"] = CONFIG.openai_api_base
|
||||
kwargs["api_key"] = CONFIG.openai_api_key
|
||||
kwargs["api_type"] = CONFIG.openai_api_type
|
||||
kwargs["api_version"] = CONFIG.openai_api_version
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
|
|
@ -162,10 +146,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = self._achat_completion_stream(messages)
|
||||
if generator:
|
||||
return resp
|
||||
|
||||
collected_messages = []
|
||||
async for i in resp:
|
||||
print(i, end="")
|
||||
collected_messages.append(i)
|
||||
|
||||
full_reply_content = "".join(collected_messages)
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
rsp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
|
|
@ -226,18 +223,18 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
||||
async def get_summary(self, text: str, max_words=200):
|
||||
async def get_summary(self, text: str, max_words=200, keep_language: bool = False):
|
||||
max_token_count = DEFAULT_MAX_TOKENS
|
||||
max_count = 100
|
||||
while max_count > 0:
|
||||
if len(text) < max_token_count:
|
||||
return await self._get_summary(text, max_words=max_words)
|
||||
return await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
|
||||
|
||||
padding_size = 20 if max_token_count > 20 else 0
|
||||
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
|
||||
summaries = []
|
||||
for ws in text_windows:
|
||||
response = await self._get_summary(ws, max_words=max_words)
|
||||
response = await self._get_summary(text=ws, max_words=max_words, keep_language=keep_language)
|
||||
summaries.append(response)
|
||||
if len(summaries) == 1:
|
||||
return summaries[0]
|
||||
|
|
@ -248,11 +245,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
max_count -= 1 # safeguard
|
||||
raise openai.error.InvalidRequestError("text too long")
|
||||
|
||||
async def _get_summary(self, text: str, max_words=20):
|
||||
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
|
||||
"""Generate text summary"""
|
||||
if len(text) < max_words:
|
||||
return text
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
if keep_language:
|
||||
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
|
||||
else:
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
msg = text + "\n\n" + command
|
||||
logger.info(f"summary ask:{msg}")
|
||||
response = await self.aask(msg=msg, system_msgs=[])
|
||||
|
|
|
|||
|
|
@ -4,12 +4,11 @@
|
|||
@Time : 2023/5/6 20:15
|
||||
@Author : alexanderwu
|
||||
@File : search_engine.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, overload, Dict
|
||||
from typing import Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
|
@ -28,23 +27,23 @@ class SearchEngine:
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: SearchEngineType | None = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None
|
||||
self,
|
||||
engine: SearchEngineType | None = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
|
||||
):
|
||||
engine = engine or CONFIG.search_engine
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serpapi"
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper(**CONFIG.options).run
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper().run
|
||||
elif engine == SearchEngineType.SERPER_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serper"
|
||||
run_func = importlib.import_module(module).SerperWrapper(**CONFIG.options).run
|
||||
run_func = importlib.import_module(module).SerperWrapper().run
|
||||
elif engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_googleapi"
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper(**CONFIG.options).run
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper().run
|
||||
elif engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
module = "metagpt.tools.search_engine_ddg"
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper(**CONFIG.options).run
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper().run
|
||||
elif engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
pass # run_func = run_func
|
||||
else:
|
||||
|
|
@ -54,19 +53,19 @@ class SearchEngine:
|
|||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
|
|
|
|||
|
|
@ -1,14 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from typing import Literal, overload, Optional
|
||||
from typing import Literal, overload
|
||||
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
|
|
@ -18,6 +15,8 @@ except ImportError:
|
|||
"You can install it by running the command: `pip install -e.[search-ddg]`"
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class DDGAPIWrapper:
|
||||
"""Wrapper around duckduckgo_search API.
|
||||
|
|
@ -26,44 +25,43 @@ class DDGAPIWrapper:
|
|||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
global_proxy: Optional[str] = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
kwargs = {}
|
||||
if global_proxy:
|
||||
kwargs["proxies"] = global_proxy
|
||||
if CONFIG.global_proxy:
|
||||
kwargs["proxies"] = CONFIG.global_proxy
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
self.ddgs = DDGS(**kwargs)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
focus: list[str] | None = None,
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
focus: list[str] | None = None,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
focus: list[str] | None = None,
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
focus: list[str] | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: bool = True,
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: bool = True,
|
||||
) -> str | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,5 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -14,6 +11,7 @@ from urllib.parse import urlparse
|
|||
import httplib2
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
try:
|
||||
|
|
@ -29,7 +27,6 @@ except ImportError:
|
|||
class GoogleAPIWrapper(BaseModel):
|
||||
google_api_key: Optional[str] = None
|
||||
google_cse_id: Optional[str] = None
|
||||
global_proxy: Optional[str] = None
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
|
||||
|
|
@ -39,6 +36,7 @@ class GoogleAPIWrapper(BaseModel):
|
|||
@validator("google_api_key", always=True)
|
||||
@classmethod
|
||||
def check_google_api_key(cls, val: str):
|
||||
val = val or CONFIG.google_api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
|
||||
|
|
@ -49,7 +47,8 @@ class GoogleAPIWrapper(BaseModel):
|
|||
|
||||
@validator("google_cse_id", always=True)
|
||||
@classmethod
|
||||
def check_google_cse_id(cls, val):
|
||||
def check_google_cse_id(cls, val: str):
|
||||
val = val or CONFIG.google_cse_id
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
|
||||
|
|
@ -61,8 +60,8 @@ class GoogleAPIWrapper(BaseModel):
|
|||
@property
|
||||
def google_api_client(self):
|
||||
build_kwargs = {"developerKey": self.google_api_key}
|
||||
if self.global_proxy:
|
||||
parse_result = urlparse(self.global_proxy)
|
||||
if CONFIG.global_proxy:
|
||||
parse_result = urlparse(CONFIG.global_proxy)
|
||||
proxy_type = parse_result.scheme
|
||||
if proxy_type == "https":
|
||||
proxy_type = "http"
|
||||
|
|
|
|||
|
|
@ -4,14 +4,13 @@
|
|||
@Time : 2023/5/23 18:27
|
||||
@Author : alexanderwu
|
||||
@File : search_engine_serpapi.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
|
||||
|
|
@ -33,6 +32,7 @@ class SerpAPIWrapper(BaseModel):
|
|||
@validator("serpapi_api_key", always=True)
|
||||
@classmethod
|
||||
def check_serpapi_api_key(cls, val: str):
|
||||
val = val or CONFIG.serpapi_api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
|
||||
|
|
@ -112,4 +112,4 @@ class SerpAPIWrapper(BaseModel):
|
|||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerpAPIWrapper(Config().runtime_options).run)
|
||||
fire.Fire(SerpAPIWrapper().run)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
@Time : 2023/5/23 18:27
|
||||
@Author : alexanderwu
|
||||
@File : search_engine_serpapi.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
|
@ -12,6 +11,8 @@ from typing import Any, Dict, Optional, Tuple
|
|||
import aiohttp
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerperWrapper(BaseModel):
|
||||
search_engine: Any #: :meta private:
|
||||
|
|
@ -25,6 +26,7 @@ class SerperWrapper(BaseModel):
|
|||
@validator("serper_api_key", always=True)
|
||||
@classmethod
|
||||
def check_serper_api_key(cls, val: str):
|
||||
val = val or CONFIG.serper_api_key
|
||||
if not val:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
|
||||
|
|
|
|||
|
|
@ -44,8 +44,9 @@ class S3:
|
|||
"""
|
||||
try:
|
||||
async with self.session.client(**self.auth_config) as client:
|
||||
with open(local_path, "rb") as file:
|
||||
await client.put_object(Body=file, Bucket=bucket, Key=object_name)
|
||||
async with aiofiles.open(local_path, mode="rb") as reader:
|
||||
body = await reader.read()
|
||||
await client.put_object(Body=body, Bucket=bucket, Key=object_name)
|
||||
logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
|
||||
|
|
@ -119,12 +120,12 @@ class S3:
|
|||
async with self.session.client(**self.auth_config) as client:
|
||||
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
|
||||
stream = s3_object["Body"]
|
||||
with open(local_path, "wb") as local_file:
|
||||
async with aiofiles.open(local_path, mode="wb") as writer:
|
||||
while True:
|
||||
file_data = await stream.read(chunk_size)
|
||||
if not file_data:
|
||||
break
|
||||
local_file.write(file_data)
|
||||
await writer.write(file_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download the file from S3: {e}")
|
||||
raise e
|
||||
|
|
|
|||
27
tests/metagpt/learn/test_google_search.py
Normal file
27
tests/metagpt/learn/test_google_search.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.learn.google_search import google_search
|
||||
|
||||
|
||||
async def mock_google_search():
|
||||
class Input(BaseModel):
|
||||
input: str
|
||||
|
||||
inputs = [{"input": "ai agent"}]
|
||||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
result = await google_search(seed.input)
|
||||
assert result != ""
|
||||
|
||||
|
||||
def test_suite():
|
||||
loop = asyncio.get_event_loop()
|
||||
task = loop.create_task(mock_google_search())
|
||||
loop.run_until_complete(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_suite()
|
||||
|
|
@ -4,13 +4,11 @@
|
|||
@Time : 2023/5/2 17:46
|
||||
@Author : alexanderwu
|
||||
@File : test_search_engine.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
|
@ -18,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine
|
|||
|
||||
class MockSearchEnine:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
|
||||
rets = [
|
||||
{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)
|
||||
]
|
||||
return "\n".join(rets) if as_string else rets
|
||||
|
||||
|
||||
|
|
@ -36,13 +36,16 @@ class MockSearchEnine:
|
|||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
|
||||
],
|
||||
)
|
||||
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string):
|
||||
conf = Config()
|
||||
search_engine = SearchEngine(options=conf.runtime_options, engine=search_engine_typpe, run_func=run_func)
|
||||
rsp = await search_engine.run(query="metagpt", max_results=max_results, as_string=as_string)
|
||||
async def test_search_engine(
|
||||
search_engine_typpe,
|
||||
run_func,
|
||||
max_results,
|
||||
as_string,
|
||||
):
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
assert isinstance(rsp, str)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue