代码优化

This commit is contained in:
liuminhui 2024-07-22 18:38:44 +08:00
parent 758acf8ba6
commit f9d3a8c521
9 changed files with 57 additions and 65 deletions

View file

@@ -38,7 +38,7 @@ from metagpt.rag.factories import (
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.parser import OmniParse
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
@@ -106,7 +106,7 @@ class SimpleEngine(RetrieverQueryEngine):
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
file_extractor = cls._get_file_extractor(file_type=".pdf")
file_extractor = cls._get_file_extractor()
documents = SimpleDirectoryReader(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
).load_data()
@@ -312,29 +312,21 @@ class SimpleEngine(RetrieverQueryEngine):
return [SentenceSplitter()]
@staticmethod
def _get_file_extractor(file_type: str = None) -> dict[str:BaseReader]:
def _get_file_extractor() -> dict[str:BaseReader]:
"""
Get the file extractor for a specified file type.
If no file type is provided, return all available extractors.
Currently, only OmniParse PDF extraction is supported.
Args:
file_type: The type of file for which the extractor is needed. Defaults to None.
Get the file extractor.
Currently, only PDF files use OmniParse.
Returns:
dict[file_type: BaseReader]
"""
file_extractor_mapping: dict[str:BaseReader] = {}
file_extractor: dict[str:BaseReader] = {}
if config.omniparse.base_url:
pdf_parser = OmniParse(
api_key=config.omniparse.api_key,
base_url=config.omniparse.base_url,
parse_options=OmniParseOptions(parse_type=OmniParseType.PDF, result_type=ParseResultType.MD),
)
file_extractor_mapping[".pdf"] = pdf_parser
file_extractor[".pdf"] = pdf_parser
if file_type:
file_extractor = file_extractor_mapping.get(file_type)
return {file_type: file_extractor} if file_extractor else {}
return file_extractor_mapping
return file_extractor

View file

@@ -1,3 +0,0 @@
from metagpt.rag.parser.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@@ -0,0 +1,3 @@
from metagpt.rag.parsers.omniparse import OmniParse
__all__ = ["OmniParse"]

View file

@@ -232,7 +232,7 @@ class ParseResultType(str, Enum):
class OmniParseOptions(BaseModel):
"""OmniParse可选配置"""
"""OmniParse Options config"""
result_type: ParseResultType = Field(default=ParseResultType.MD, description="OmniParse result_type")
parse_type: OmniParseType = Field(default=OmniParseType.DOCUMENT, description="OmniParse parse_type")