From a366ed37d7e40713e5168ce3b5f91aaffc55b935 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Thu, 1 Aug 2024 10:18:36 +0800 Subject: [PATCH] QA with web search --- examples/search_enhanced_qa.py | 27 ++++ metagpt/actions/research.py | 69 +++++--- metagpt/actions/search_enhanced_qa.py | 220 ++++++++++++++++++++++++++ metagpt/utils/common.py | 36 +++++ 4 files changed, 328 insertions(+), 24 deletions(-) create mode 100644 examples/search_enhanced_qa.py create mode 100644 metagpt/actions/search_enhanced_qa.py diff --git a/examples/search_enhanced_qa.py b/examples/search_enhanced_qa.py new file mode 100644 index 000000000..9eb5449a4 --- /dev/null +++ b/examples/search_enhanced_qa.py @@ -0,0 +1,27 @@ +""" +This script demonstrates how to use the SearchEnhancedQA action to answer questions +by leveraging web search results. It showcases a simple example of querying about +the current weather in Beijing. + +The SearchEnhancedQA action combines web search capabilities with natural language +processing to provide informative answers to user queries. +""" + +import asyncio + +from metagpt.actions.search_enhanced_qa import SearchEnhancedQA + + +async def main(): + """Runs a sample query through SearchEnhancedQA and prints the result.""" + + action = SearchEnhancedQA() + + query = "What is the weather like in Beijing today?" + answer = await action.run(query) + + print(f"The answer to '{query}' is:\n\n{answer}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 2a99a8d99..b5373c069 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Coroutine, Optional, Union from pydantic import TypeAdapter, model_validator @@ -160,7 +160,7 @@ class CollectLinks(Action): A list of ranked URLs. 
""" max_results = max(num_results * 2, 6) - results = await self.search_engine.run(query, max_results=max_results, as_string=False) + results = await self._search_urls(query, max_results=max_results) _results = "\n".join(f"{i}: {j}" for i, j in zip(range(max_results), results)) prompt = COLLECT_AND_RANKURLS_PROMPT.format(topic=topic, query=query, results=_results) logger.debug(prompt) @@ -176,6 +176,9 @@ class CollectLinks(Action): results = self.rank_func(results) return [i["link"] for i in results[:num_results]] + async def _search_urls(self, query: str, max_results: int) -> list[str]: + return await self.search_engine.run(query, max_results=max_results, as_string=False) + class WebBrowseAndSummarize(Action): """Action class to explore the web and provide summaries of articles and webpages.""" @@ -202,6 +205,7 @@ class WebBrowseAndSummarize(Action): *urls: str, query: str, system_text: str = RESEARCH_BASE_SYSTEM, + use_concurrent_summarization: bool = False, ) -> dict[str, str]: """Run the action to browse the web and provide summaries. @@ -210,6 +214,7 @@ class WebBrowseAndSummarize(Action): urls: Additional URLs to browse. query: The research question. system_text: The system text. + use_concurrent_summarization: Whether to concurrently summarize the content of the webpage by LLM. Returns: A dictionary containing the URLs as keys and their summaries as values. 
@@ -218,31 +223,47 @@ class WebBrowseAndSummarize(Action): if not urls: contents = [contents] - summaries = {} + all_urls = [url] + list(urls) + summarize_tasks = [ + self._summarize_content(url, content, query, system_text) for url, content in zip(all_urls, contents) + ] + + summaries = await self._execute_summarize_tasks(summarize_tasks, use_concurrent_summarization) + + return dict(summaries) + + async def _summarize_content(self, url: str, content: str, query: str, system_text: str) -> tuple[str, str]: prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}") - for u, content in zip([url, *urls], contents): - content = content.inner_text - chunk_summaries = [] - for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096): - logger.debug(prompt) - summary = await self._aask(prompt, [system_text]) - if summary == "Not relevant.": - continue - chunk_summaries.append(summary) - if not chunk_summaries: - summaries[u] = None - continue - - if len(chunk_summaries) == 1: - summaries[u] = chunk_summaries[0] - continue - - content = "\n".join(chunk_summaries) - prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content) + content = content.inner_text + chunk_summaries = [] + for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096): + logger.debug(prompt) summary = await self._aask(prompt, [system_text]) - summaries[u] = summary - return summaries + if summary == "Not relevant.": + continue + chunk_summaries.append(summary) + + if not chunk_summaries: + return url, None + + if len(chunk_summaries) == 1: + return url, chunk_summaries[0] + + content = "\n".join(chunk_summaries) + prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content) + summary = await self._aask(prompt, [system_text]) + return url, summary + + async def _execute_summarize_tasks( + self, tasks: list[Coroutine[Any, Any, tuple[str, str]]], use_concurrent: bool + ) -> 
list[tuple[str, str]]: + """Execute summarize tasks either concurrently or sequentially.""" + + if use_concurrent: + return await asyncio.gather(*tasks) + + return [await task for task in tasks] class ConductResearch(Action): diff --git a/metagpt/actions/search_enhanced_qa.py b/metagpt/actions/search_enhanced_qa.py new file mode 100644 index 000000000..d44a7057a --- /dev/null +++ b/metagpt/actions/search_enhanced_qa.py @@ -0,0 +1,220 @@ +"""Enhancing question-answering capabilities through search engine augmentation.""" + +from __future__ import annotations + +import json + +from pydantic import Field + +from metagpt.actions import Action +from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize +from metagpt.logs import logger +from metagpt.utils.common import CodeParser + +REWRITE_QUERY_PROMPT = """ +Role: You are a highly efficient assistant that provides a better search query for web search engine to answer the given question. + +I will provide you with a question. Your task is to provide a better search query for web search engine. + +## Context +### Question +{q} + +## Format Example +```json +{{ + "query": "the better search query for web search engine." +}} +``` + +## Instructions +- Understand the question given by the user. +- Provide a better search query for web search engine to answer the given question, your answer must be written in the same language as the question. +- When rewriting, if you are unsure of the specific time, do not include the time. + +## Constraint +Format: Just print the result in json format like **Format Example**. + +## Action +Follow instructions, generate output and make sure it follows the **Constraint**. +""" + +SEARCH_ENHANCED_QA_SYSTEM_PROMPT = """ +You are a large language AI assistant built by MGX. You are given a user question, and please write clean, concise and accurate answer to the question.
You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context. + +Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context does not provide sufficient information. Do not include [citation] in your answer. + +Here are the set of contexts: + +{context} + +Remember, don't blindly repeat the contexts verbatim. And here is the user question: +""" + + +class SearchEnhancedQA(Action): + """Enhancing question-answering capabilities through search engine augmentation.""" + + name: str = "SearchEnhancedQA" + desc: str = "Integrating search engine results to answer the question." + + collect_links_action: CollectLinks = Field( + default=CollectLinks(), description="Action to collect relevant links from a search engine." + ) + web_browse_and_summarize_action: WebBrowseAndSummarize = Field( + default=WebBrowseAndSummarize(), + description="Action to explore the web and provide summaries of articles and webpages.", + ) + + async def run(self, query: str, rewrite_query: bool = True) -> str: + """Answer a query by leveraging web search results. + + Args: + query (str): The original user query. + rewrite_query (bool): Whether to rewrite the query for better web search results. Defaults to True. + + Returns: + str: A detailed answer based on web search results. + + Raises: + ValueError: If the query is invalid. + """ + + self._validate_query(query) + + processed_query = await self._process_query(query, rewrite_query) + context = await self._build_context(processed_query) + + return await self._generate_answer(processed_query, context) + + def _validate_query(self, query: str) -> None: + """Validate the input query.
+ + Args: + query (str): The query to validate. + + Raises: + ValueError: If the query is invalid. + """ + + if not query.strip(): + raise ValueError("Query cannot be empty or contain only whitespace.") + + async def _process_query(self, query: str, should_rewrite: bool) -> str: + """Process the query, optionally rewriting it.""" + + if should_rewrite: + return await self._rewrite_query(query) + + return query + + async def _rewrite_query(self, query: str) -> str: + """Write a better search query for web search engine. + + If the rewrite process fails, the original query is returned. + + Args: + query (str): The original search query. + + Returns: + str: The rewritten query if successful, otherwise the original query. + """ + + prompt = REWRITE_QUERY_PROMPT.format(q=query) + + try: + resp = await self._aask(prompt) + rewritten_query = self._extract_rewritten_query(resp) + + logger.info(f"Query rewritten: '{query}' -> '{rewritten_query}'") + return rewritten_query + except Exception as e: + logger.warning(f"Query rewrite failed. Returning original query. Error: {e}") + return query + + def _extract_rewritten_query(self, response: str) -> str: + """Extract the rewritten query from the LLM's JSON response.""" + + resp_json = json.loads(CodeParser.parse_code(response, lang="json")) + return resp_json["query"] + + async def _build_context(self, query: str) -> str: + """Construct a context string from web search citations. + + Args: + query (str): The search query. + + Returns: + str: Formatted context with numbered citations. + """ + + citations = await self._search_citations(query) + context = "\n\n".join([f"[[citation:{i+1}]] {c}" for i, c in enumerate(citations)]) + + return context + + async def _search_citations(self, query: str) -> list[str]: + """Perform web search and summarize relevant content. + + Args: + query (str): The search query. + + Returns: + list[str]: Summaries of relevant web content. 
+ """ + + relevant_urls = await self._collect_relevant_links(query) + if not relevant_urls: + logger.warning(f"No relevant URLs found for query: {query}") + return [] + + logger.info(f"The Relevant links are: {relevant_urls}") + + web_summaries = await self._summarize_web_content(relevant_urls, query) + if not web_summaries: + logger.warning(f"No summaries generated for query: {query}") + return [] + + citations = list(web_summaries.values()) + + return citations + + async def _collect_relevant_links(self, query: str) -> list[str]: + """Search and rank URLs relevant to the query. + + Args: + query (str): The search query. + + Returns: + list[str]: Ranked list of relevant URLs. + """ + + return await self.collect_links_action._search_and_rank_urls(topic=query, query=query) + + async def _summarize_web_content(self, urls: list[str], query: str) -> dict[str, str]: + """Fetch and summarize content from given URLs. + + Args: + urls (list[str]): List of URLs to summarize. + query (str): The original query for context. + + Returns: + dict[str, str]: Mapping of URLs to their summaries. + """ + + return await self.web_browse_and_summarize_action.run(*urls, query=query, use_concurrent_summarization=True) + + async def _generate_answer(self, query: str, context: str) -> str: + """Generate an answer using the query and context. + + Args: + query (str): The user's question. + context (str): Relevant information from web search. + + Returns: + str: Generated answer based on the context. 
+ """ + + system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context) + + return await self._aask(query, [system_prompt]) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index eea16bb2e..3eead9ed4 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -15,6 +15,7 @@ import ast import base64 import contextlib import csv +import functools import importlib import inspect import json @@ -23,7 +24,10 @@ import os import platform import re import sys +import time import traceback +from asyncio import iscoroutinefunction +from datetime import datetime from io import BytesIO from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union @@ -1044,3 +1048,35 @@ def tool2name(cls, methods: List[str], entry) -> Dict[str, Any]: if len(mappings) < 2: mappings[class_name] = entry return mappings + + +def log_time(method): + """A time-consuming decorator for printing execution duration.""" + + def before_call(): + start_time, cpu_start_time = time.perf_counter(), time.process_time() + logger.info(f"[{method.__name__}] started at: " f"{datetime.now().strftime('%Y-%m-%d %H:%m:%S')}") + return start_time, cpu_start_time + + def after_call(start_time, cpu_start_time): + end_time, cpu_end_time = time.perf_counter(), time.process_time() + logger.info( + f"[{method.__name__}] ended. " + f"Time elapsed: {end_time - start_time:.4} sec, CPU elapsed: {cpu_end_time - cpu_start_time:.4} sec" + ) + + @functools.wraps(method) + def timeit_wrapper(*args, **kwargs): + start_time, cpu_start_time = before_call() + result = method(*args, **kwargs) + after_call(start_time, cpu_start_time) + return result + + @functools.wraps(method) + async def timeit_wrapper_async(*args, **kwargs): + start_time, cpu_start_time = before_call() + result = await method(*args, **kwargs) + after_call(start_time, cpu_start_time) + return result + + return timeit_wrapper_async if iscoroutinefunction(method) else timeit_wrapper