Merge pull request #62 from LeonZh0u/dev

[bugfix] fix serper integration bug to support batch queries
This commit is contained in:
geekan 2023-07-28 21:30:31 +08:00 committed by GitHub
commit af884b461f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 18 additions and 18 deletions

View file

@@ -37,7 +37,7 @@ class SearchEngine:
logger.info(results)
return results
async def run(self, query, max_results=8):
async def run(self, query: str, max_results=8):
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
api = SerpAPIWrapper()
rsp = await api.run(query)
@@ -45,10 +45,7 @@ class SearchEngine:
rsp = SearchEngine.run_google(query, max_results)
elif self.engine == SearchEngineType.SERPER_GOOGLE:
api = SerperWrapper()
if isinstance(query, list):
rsp = await api.run(query)
elif isinstance(query, str):
rsp = await api.run([query])
rsp = await api.run(query)
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
rsp = self.run_func(query)
else:
@@ -74,15 +71,15 @@ def google_official_search(query: str, num_results: int = 8, focus=['snippet', '
api_key = config.google_api_key
custom_search_engine_id = config.google_cse_id
service = build("customsearch", "v1", developerKey=api_key)
with build("customsearch", "v1", developerKey=api_key) as service:
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)
.execute()
)
# Extract the search result items from the response
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)
.execute()
)
logger.info(result)
# Extract the search result items from the response
search_results = result.get("items", [])
# Create a list of only the URLs from the search results

View file

@@ -38,7 +38,8 @@ class SerperWrapper(BaseModel):
async def run(self, query: str, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
return ";".join([self._process_response(res) for res in await self.results(query)])
queries = query.split("\n")
return "\n".join([self._process_response(res) for res in await self.results(queries)])
async def results(self, queries: list[str]) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""