Resolve PR issues

This commit is contained in:
Stitch-z 2023-09-20 12:02:46 +08:00
parent c99f4bffe9
commit 97ab38ac7c
5 changed files with 18 additions and 17 deletions

View file

@ -111,7 +111,7 @@ class CollectLinks(Action):
system_text = system_text if system_text else RESEARCH_TOPIC_SYSTEM.format(topic=topic)
keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
try:
keywords = OutputParser.extract_struct(keywords, "list")
keywords = OutputParser.extract_struct(keywords, list)
keywords = parse_obj_as(list[str], keywords)
except Exception as e:
logger.exception(f"fail to get keywords related to the research topic \"{topic}\" for {e}")
@ -131,7 +131,7 @@ class CollectLinks(Action):
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])
try:
queries = OutputParser.extract_struct(queries, "list")
queries = OutputParser.extract_struct(queries, list)
queries = parse_obj_as(list[str], queries)
except Exception as e:
logger.exception(f"fail to break down the research question due to {e}")
@ -159,7 +159,7 @@ class CollectLinks(Action):
logger.debug(prompt)
indices = await self._aask(prompt)
try:
indices = OutputParser.extract_struct(indices, "list")
indices = OutputParser.extract_struct(indices, list)
assert all(isinstance(i, int) for i in indices)
except Exception as e:
logger.exception(f"fail to rank results for {e}")

View file

@ -37,7 +37,7 @@ class WriteDirectory(Action):
"""
prompt = DIRECTORY_PROMPT.format(topic=topic, language=self.language)
resp = await self._aask(prompt=prompt)
return OutputParser.extract_struct(resp, "dict")
return OutputParser.extract_struct(resp, dict)
class WriteContent(Action):

View file

@ -11,7 +11,7 @@ import inspect
import os
import platform
import re
from typing import List, Tuple
from typing import List, Tuple, Union
from metagpt.logs import logger
@ -152,7 +152,7 @@ class OutputParser:
return parsed_data
@classmethod
def extract_struct(cls, text: str, data_type: str) -> Tuple[list, dict]:
def extract_struct(cls, text: str, data_type: Union[type(list), type(dict)]) -> Union[list, dict]:
"""Extracts and parses a specified type of structure (dictionary or list) from the given text.
The text only contains a list or dictionary, which may have nested structures.
@ -176,8 +176,8 @@ class OutputParser:
>>> # Output: {"x": 1, "y": {"a": 2, "b": {"c": 3}}}
"""
# Find the first "[" or "{" and the last "]" or "}"
start_index = text.find("[" if data_type == "list" else "{")
end_index = text.rfind("]" if data_type == "list" else "}")
start_index = text.find("[" if data_type is list else "{")
end_index = text.rfind("]" if data_type is list else "}")
if start_index != -1 and end_index != -1:
# Extract the structure part
@ -188,7 +188,7 @@ class OutputParser:
result = ast.literal_eval(structure_text)
# Ensure the result matches the specified data type
if (data_type == "list" and isinstance(result, list)) or (data_type == "dict" and isinstance(result, dict)):
if isinstance(result, list) or isinstance(result, dict):
return result
raise ValueError(f"The extracted structure is not a {data_type}.")