QA with web search

This commit is contained in:
seehi 2024-08-01 10:18:36 +08:00
parent e11d03fe5b
commit a366ed37d7
4 changed files with 328 additions and 24 deletions

View file

@ -0,0 +1,27 @@
"""
This script demonstrates how to use the SearchEnhancedQA action to answer questions
by leveraging web search results. It showcases a simple example of querying about
the current weather in Beijing.
The SearchEnhancedQA action combines web search capabilities with natural language
processing to provide informative answers to user queries.
"""
import asyncio
from metagpt.actions.search_enhanced_qa import SearchEnhancedQA
async def main():
"""Runs a sample query through SearchEnhancedQA and prints the result."""
action = SearchEnhancedQA()
query = "What is the weather like in Beijing today?"
answer = await action.run(query)
print(f"The answer to '{query}' is:\n\n{answer}")
if __name__ == "__main__":
asyncio.run(main())

View file

@ -3,7 +3,7 @@
from __future__ import annotations
import asyncio
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Coroutine, Optional, Union
from pydantic import TypeAdapter, model_validator
@ -160,7 +160,7 @@ class CollectLinks(Action):
A list of ranked URLs.
"""
max_results = max(num_results * 2, 6)
results = await self.search_engine.run(query, max_results=max_results, as_string=False)
results = await self._search_urls(query, max_results=max_results)
_results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results))
prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results)
logger.debug(prompt)
@ -176,6 +176,9 @@ class CollectLinks(Action):
results = self.rank_func(results)
return [i["link"] for i in results[:num_results]]
async def _search_urls(self, query: str, max_results: int) -> list[str]:
return await self.search_engine.run(query, max_results=max_results, as_string=False)
class WebBrowseAndSummarize(Action):
"""Action class to explore the web and provide summaries of articles and webpages."""
@ -202,6 +205,7 @@ class WebBrowseAndSummarize(Action):
*urls: str,
query: str,
system_text: str = RESEARCH_BASE_SYSTEM,
use_concurrent_summarization: bool = False,
) -> dict[str, str]:
"""Run the action to browse the web and provide summaries.
@ -210,6 +214,7 @@ class WebBrowseAndSummarize(Action):
urls: Additional URLs to browse.
query: The research question.
system_text: The system text.
use_concurrent_summarization: Whether to concurrently summarize the content of the webpage by LLM.
Returns:
A dictionary containing the URLs as keys and their summaries as values.
@ -218,31 +223,47 @@ class WebBrowseAndSummarize(Action):
if not urls:
contents = [contents]
summaries = {}
all_urls = [url] + list(urls)
summarize_tasks = [
self._summarize_content(url, content, query, system_text) for url, content in zip(all_urls, contents)
]
summaries = await self._execute_summarize_tasks(summarize_tasks, use_concurrent_summarization)
return dict(summaries)
async def _summarize_content(self, url: str, content: str, query: str, system_text: str) -> tuple[str, str]:
prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}")
for u, content in zip([url, *urls], contents):
content = content.inner_text
chunk_summaries = []
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096):
logger.debug(prompt)
summary = await self._aask(prompt, [system_text])
if summary == "Not relevant.":
continue
chunk_summaries.append(summary)
if not chunk_summaries:
summaries[u] = None
continue
if len(chunk_summaries) == 1:
summaries[u] = chunk_summaries[0]
continue
content = "\n".join(chunk_summaries)
prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
content = content.inner_text
chunk_summaries = []
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096):
logger.debug(prompt)
summary = await self._aask(prompt, [system_text])
summaries[u] = summary
return summaries
if summary == "Not relevant.":
continue
chunk_summaries.append(summary)
if not chunk_summaries:
return url, None
if len(chunk_summaries) == 1:
return url, chunk_summaries[0]
content = "\n".join(chunk_summaries)
prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
summary = await self._aask(prompt, [system_text])
return url, summary
async def _execute_summarize_tasks(
self, tasks: list[Coroutine[Any, Any, tuple[str, str]]], use_concurrent: bool
) -> list[tuple[str, str]]:
"""Execute summarize tasks either concurrently or sequentially."""
if use_concurrent:
return await asyncio.gather(*tasks)
return [await task for task in tasks]
class ConductResearch(Action):

View file

@ -0,0 +1,220 @@
"""Enhancing question-answering capabilities through search engine augmentation."""
from __future__ import annotations
import json
from pydantic import Field
from metagpt.actions import Action
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize
from metagpt.logs import logger
from metagpt.utils.common import CodeParser
REWRITE_QUERY_PROMPT = """
Role: You are a highly efficient assistant that provide a better search query for web search engine to answer the given question.
I will provide you with a question. Your task is to provide a better search query for web search engine.
## Context
### Question
{q}
## Format Example
```json
{{
"query": "the better search query for web search engine.",
}}
```
## Instructions
- Understand the question given by the user.
- Provide a better search query for web search engine to answer the given question, your answer must be written in the same language as the question.
- When rewriting, if you are unsure of the specific time, do not include the time.
## Constraint
Format: Just print the result in json format like **Format Example**.
## Action
Follow instructions, generate output and make sure it follows the **Constraint**.
"""
SEARCH_ENHANCED_QA_SYSTEM_PROMPT = """
You are a large language AI assistant built by MGX. You are given a user question, and please write clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context.
Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context do not provide sufficient information. Do not include [citation] in your anwser.
Here are the set of contexts:
{context}
Remember, don't blindly repeat the contexts verbatim. And here is the user question:
"""
class SearchEnhancedQA(Action):
"""Enhancing question-answering capabilities through search engine augmentation."""
name: str = "SearchEnhancedQA"
desc: str = "Integrating search engine results to anwser the question."
collect_links_action: CollectLinks = Field(
default=CollectLinks(), description="Action to collect relevant links from a search engine."
)
web_browse_and_summarize_action: WebBrowseAndSummarize = Field(
default=WebBrowseAndSummarize(),
description="Action to explore the web and provide summaries of articles and webpages.",
)
async def run(self, query: str, rewrite_query: bool = True) -> str:
"""Answer a query by leveraging web search results.
Args:
query (str): The original user query.
rewrite_query (bool): Whether to rewrite the query for better web search results. Defaults to True.
Returns:
str: A detailed answer based on web search results.
Raises:
ValueError: If the query is invalid.
"""
self._validate_query(query)
processed_query = await self._process_query(query, rewrite_query)
context = await self._build_context(processed_query)
return await self._generate_answer(processed_query, context)
def _validate_query(self, query: str) -> None:
"""Validate the input query.
Args:
query (str): The query to validate.
Raises:
ValueError: If the query is invalid.
"""
if not query.strip():
raise ValueError("Query cannot be empty or contain only whitespace.")
async def _process_query(self, query: str, should_rewrite: bool) -> str:
"""Process the query, optionally rewriting it."""
if should_rewrite:
return await self._rewrite_query(query)
return query
async def _rewrite_query(self, query: str) -> str:
"""Write a better search query for web search engine.
If the rewrite process fails, the original query is returned.
Args:
query (str): The original search query.
Returns:
str: The rewritten query if successful, otherwise the original query.
"""
prompt = REWRITE_QUERY_PROMPT.format(q=query)
try:
resp = await self._aask(prompt)
rewritten_query = self._extract_rewritten_query(resp)
logger.info(f"Query rewritten: '{query}' -> '{rewritten_query}'")
return rewritten_query
except Exception as e:
logger.warning(f"Query rewrite failed. Returning original query. Error: {e}")
return query
def _extract_rewritten_query(self, response: str) -> str:
"""Extract the rewritten query from the LLM's JSON response."""
resp_json = json.loads(CodeParser.parse_code(response, lang="json"))
return resp_json["query"]
async def _build_context(self, query: str) -> str:
"""Construct a context string from web search citations.
Args:
query (str): The search query.
Returns:
str: Formatted context with numbered citations.
"""
citations = await self._search_citations(query)
context = "\n\n".join([f"[[citation:{i+1}]] {c}" for i, c in enumerate(citations)])
return context
async def _search_citations(self, query: str) -> list[str]:
"""Perform web search and summarize relevant content.
Args:
query (str): The search query.
Returns:
list[str]: Summaries of relevant web content.
"""
relevant_urls = await self._collect_relevant_links(query)
if not relevant_urls:
logger.warning(f"No relevant URLs found for query: {query}")
return []
logger.info(f"The Relevant links are: {relevant_urls}")
web_summaries = await self._summarize_web_content(relevant_urls, query)
if not web_summaries:
logger.warning(f"No summaries generated for query: {query}")
return []
citations = list(web_summaries.values())
return citations
async def _collect_relevant_links(self, query: str) -> list[str]:
"""Search and rank URLs relevant to the query.
Args:
query (str): The search query.
Returns:
list[str]: Ranked list of relevant URLs.
"""
return await self.collect_links_action._search_and_rank_urls(topic=query, query=query)
async def _summarize_web_content(self, urls: list[str], query: str) -> dict[str, str]:
"""Fetch and summarize content from given URLs.
Args:
urls (list[str]): List of URLs to summarize.
query (str): The original query for context.
Returns:
dict[str, str]: Mapping of URLs to their summaries.
"""
return await self.web_browse_and_summarize_action.run(*urls, query=query, use_concurrent_summarization=True)
async def _generate_answer(self, query: str, context: str) -> str:
"""Generate an answer using the query and context.
Args:
query (str): The user's question.
context (str): Relevant information from web search.
Returns:
str: Generated answer based on the context.
"""
system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context)
return await self._aask(query, [system_prompt])

View file

@ -15,6 +15,7 @@ import ast
import base64
import contextlib
import csv
import functools
import importlib
import inspect
import json
@ -23,7 +24,10 @@ import os
import platform
import re
import sys
import time
import traceback
from asyncio import iscoroutinefunction
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
@ -1044,3 +1048,35 @@ def tool2name(cls, methods: List[str], entry) -> Dict[str, Any]:
if len(mappings) < 2:
mappings[class_name] = entry
return mappings
def log_time(method):
"""A time-consuming decorator for printing execution duration."""
def before_call():
start_time, cpu_start_time = time.perf_counter(), time.process_time()
logger.info(f"[{method.__name__}] started at: " f"{datetime.now().strftime('%Y-%m-%d %H:%m:%S')}")
return start_time, cpu_start_time
def after_call(start_time, cpu_start_time):
end_time, cpu_end_time = time.perf_counter(), time.process_time()
logger.info(
f"[{method.__name__}] ended. "
f"Time elapsed: {end_time - start_time:.4} sec, CPU elapsed: {cpu_end_time - cpu_start_time:.4} sec"
)
@functools.wraps(method)
def timeit_wrapper(*args, **kwargs):
start_time, cpu_start_time = before_call()
result = method(*args, **kwargs)
after_call(start_time, cpu_start_time)
return result
@functools.wraps(method)
async def timeit_wrapper_async(*args, **kwargs):
start_time, cpu_start_time = before_call()
result = await method(*args, **kwargs)
after_call(start_time, cpu_start_time)
return result
return timeit_wrapper_async if iscoroutinefunction(method) else timeit_wrapper