Merge pull request #812 from shenchucheng/fix-google-search-ut-error

fix tests/metagpt/learn/test_google_search.py error
This commit is contained in:
garylin2099 2024-01-31 16:24:06 +08:00 committed by GitHub
commit 324a81d4fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 135 additions and 17 deletions

View file

@ -8,5 +8,5 @@ async def google_search(query: str, max_results: int = 6, **kwargs):
:param max_results: The number of search results to retrieve
:return: The web search results in markdown format.
"""
results = await SearchEngine().run(query, max_results=max_results, as_string=False)
results = await SearchEngine(**kwargs).run(query, max_results=max_results, as_string=False)
return "\n".join(f"{i}. [{j['title']}]({j['link']}): {j['snippet']}" for i, j in enumerate(results, 1))

View file

@ -61,9 +61,11 @@ class SerpAPIWrapper(BaseModel):
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get(url, params=params) as response:
response.raise_for_status()
res = await response.json()
return res

View file

@ -55,9 +55,11 @@ class SerperWrapper(BaseModel):
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=payloads, headers=headers) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
response.raise_for_status()
res = await response.json()
return res

File diff suppressed because one or more lines are too long

View file

@ -1,27 +1,21 @@
import asyncio
import pytest
from pydantic import BaseModel
from metagpt.learn.google_search import google_search
from metagpt.tools import SearchEngineType
async def mock_google_search():
@pytest.mark.asyncio
async def test_google_search(search_engine_mocker):
class Input(BaseModel):
input: str
inputs = [{"input": "ai agent"}]
for i in inputs:
seed = Input(**i)
result = await google_search(seed.input)
result = await google_search(
seed.input,
engine=SearchEngineType.SERPER_GOOGLE,
serper_api_key="mock-serper-key",
)
assert result != ""
def test_suite():
loop = asyncio.get_event_loop()
task = loop.create_task(mock_google_search())
loop.run_until_complete(task)
if __name__ == "__main__":
test_suite()

View file

@ -39,3 +39,7 @@ class MockAioResponse:
data = await self.response.json(*args, **kwargs)
self.rsp_cache[self.key] = data
return data
def raise_for_status(self):
if self.response:
self.response.raise_for_status()