Merge branch 'mgx_ops' into feat-exp-pool

This commit is contained in:
seehi 2024-06-03 11:20:29 +08:00
commit 46aa6f1975
33 changed files with 691 additions and 153 deletions

View file

@ -65,7 +65,7 @@ class ExecuteNbCode(Action):
"""execute notebook code block, return result to llm, and display it."""
nb: NotebookNode
nb_client: NotebookClient = None
nb_client: RealtimeOutputNotebookClient = None
console: Console
interaction: str
timeout: int = 600
@ -78,11 +78,15 @@ class ExecuteNbCode(Action):
interaction=("ipython" if self.is_ipython() else "terminal"),
)
self.reporter = NotebookReporter()
self.set_nb_client()
def set_nb_client(self):
self.nb_client = RealtimeOutputNotebookClient(
nb,
timeout=timeout,
self.nb,
timeout=self.timeout,
resources={"metadata": {"path": DEFAULT_WORKSPACE_ROOT}},
notebook_reporter=self.reporter,
coalesce_streams=True,
)
async def build(self):
@ -118,7 +122,7 @@ class ExecuteNbCode(Action):
# sleep 1s to wait for the kernel to be cleaned up completely
await asyncio.sleep(1)
await self.build()
self.nb_client = NotebookClient(self.nb, timeout=self.timeout)
self.set_nb_client()
def add_code_cell(self, code: str):
self.nb.cells.append(new_code_cell(source=code))

View file

@ -153,7 +153,7 @@ class WriteCode(Action):
root_path = self.context.src_workspace if self.context.src_workspace else ""
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
coding_context.code_doc.content = code
await reporter.async_report(self.repo.workdir / coding_context.code_doc.root_relative_path, "path")
await reporter.async_report(coding_context.code_doc, "document")
return coding_context
@staticmethod

View file

@ -17,6 +17,7 @@ from metagpt.const import REQUIREMENT_FILENAME
from metagpt.logs import logger
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser
from metagpt.utils.report import EditorReporter
PROMPT_TEMPLATE = """
# System
@ -128,16 +129,21 @@ class WriteCodeReview(Action):
i_context: CodingContext = Field(default_factory=CodingContext)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, doc):
filename = doc.filename
cr_rsp = await self._aask(context_prompt + cr_prompt)
result = CodeParser.parse_block("Code Review Result", cr_rsp)
if "LGTM" in result:
return result, None
# if LBTM, rewrite code
rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}"
code_rsp = await self._aask(rewrite_prompt)
code = CodeParser.parse_code(text=code_rsp)
async with EditorReporter(enable_llm_stream=True) as reporter:
await reporter.async_report({"type": "code", "filename": filename, "src_path": doc.root_relative_path}, "meta")
rewrite_prompt = f"{context_prompt}\n{cr_rsp}\n{REWRITE_CODE_TEMPLATE.format(filename=filename)}"
code_rsp = await self._aask(rewrite_prompt)
code = CodeParser.parse_code(text=code_rsp)
doc.content = code
await reporter.async_report(doc, "document")
return result, code
async def run(self, *args, **kwargs) -> CodingContext:
@ -182,7 +188,7 @@ class WriteCodeReview(Action):
f"len(self.i_context.code_doc.content)={len2}"
)
result, rewrited_code = await self.write_code_review_and_rewrite(
context_prompt, cr_prompt, self.i_context.code_doc.filename
context_prompt, cr_prompt, self.i_context.code_doc
)
if "LBTM" in result:
iterative_code = rewrited_code

View file

@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional
from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
@ -49,6 +50,9 @@ class Config(CLIParams, YamlModel):
# Key Parameters
llm: LLMConfig
# RAG Embedding
embedding: EmbeddingConfig = EmbeddingConfig()
# Global Proxy. Not used by LLM, but by other tools such as browsers.
proxy: str = ""

View file

@ -0,0 +1,50 @@
from enum import Enum
from typing import Optional
from pydantic import field_validator
from metagpt.utils.yaml_model import YamlModel
class EmbeddingType(Enum):
OPENAI = "openai"
AZURE = "azure"
GEMINI = "gemini"
OLLAMA = "ollama"
class EmbeddingConfig(YamlModel):
"""Config for Embedding.
Examples:
---------
api_type: "openai"
api_key: "YOU_API_KEY"
api_type: "azure"
api_key: "YOU_API_KEY"
base_url: "YOU_BASE_URL"
api_version: "YOU_API_VERSION"
api_type: "gemini"
api_key: "YOU_API_KEY"
api_type: "ollama"
base_url: "YOU_BASE_URL"
model: "YOU_MODEL"
"""
api_type: Optional[EmbeddingType] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
model: Optional[str] = None
embed_batch_size: Optional[int] = None
@field_validator("api_type", mode="before")
@classmethod
def check_api_type(cls, v):
if v == "":
return None
return v

View file

@ -4,7 +4,7 @@ import json
import os
from typing import Any, Optional, Union
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core import SimpleDirectoryReader
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
@ -63,7 +63,7 @@ class SimpleEngine(RetrieverQueryEngine):
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
transformations: Optional[list[TransformComponent]] = None,
) -> None:
super().__init__(
retriever=retriever,
@ -71,7 +71,7 @@ class SimpleEngine(RetrieverQueryEngine):
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
self._transformations = transformations or self._default_transformations()
@classmethod
def from_docs(
@ -103,12 +103,17 @@ class SimpleEngine(RetrieverQueryEngine):
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations = transformations or cls._default_transformations()
nodes = run_transformations(documents, transformations=transformations)
return cls._from_nodes(
nodes=nodes,
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_objs(
@ -137,12 +142,15 @@ class SimpleEngine(RetrieverQueryEngine):
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
return cls._from_nodes(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_index(
@ -161,6 +169,13 @@ class SimpleEngine(RetrieverQueryEngine):
"""Inplement tools.SearchInterface"""
return await self.aquery(content)
def retrieve(self, query: QueryType) -> list[NodeWithScore]:
query_bundle = QueryBundle(query) if isinstance(query, str) else query
nodes = super().retrieve(query_bundle)
self._try_reconstruct_obj(nodes)
return nodes
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
@ -176,7 +191,7 @@ class SimpleEngine(RetrieverQueryEngine):
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self.index._transformations)
nodes = run_transformations(documents, transformations=self._transformations)
self._save_nodes(nodes)
def add_objs(self, objs: list[RAGObject]):
@ -192,6 +207,29 @@ class SimpleEngine(RetrieverQueryEngine):
self._persist(str(persist_dir), **kwargs)
@classmethod
def _from_nodes(
cls,
nodes: list[BaseNode],
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
embed_model = cls._resolve_embed_model(embed_model, retriever_configs)
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, nodes=nodes, embed_model=embed_model)
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
transformations=transformations,
)
@classmethod
def _from_index(
cls,
@ -201,6 +239,7 @@ class SimpleEngine(RetrieverQueryEngine):
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
@ -208,7 +247,6 @@ class SimpleEngine(RetrieverQueryEngine):
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
def _ensure_retriever_modifiable(self):
@ -259,3 +297,7 @@ class SimpleEngine(RetrieverQueryEngine):
return MockEmbedding(embed_dim=1)
return embed_model or get_rag_embedding()
@staticmethod
def _default_transformations():
return [SentenceSplitter()]

View file

@ -26,6 +26,9 @@ class GenericFactory:
if creator:
return creator(**kwargs)
self._raise_for_key(key)
def _raise_for_key(self, key: Any):
raise ValueError(f"Creator not registered for key: {key}")
@ -33,19 +36,26 @@ class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""
def get_instance(self, key: Any, **kwargs) -> Any:
"""Key is config, such as a pydantic model.
"""Get instance by the type of key.
Call func by the type of key, and the key will be passed to func.
Key is config, such as a pydantic model, call func by the type of key, and the key will be passed to func.
Raise Exception if key not found.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)
self._raise_for_key(key)
def _raise_for_key(self, key: Any):
raise ValueError(f"Unknown config: `{type(key)}`, {key}")
@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs.
Return None if not found.
"""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
@ -54,6 +64,4 @@ class ConfigBasedFactory(GenericFactory):
if key in kwargs:
return kwargs[key]
raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
return None

View file

@ -1,37 +1,103 @@
"""RAG Embedding Factory."""
from __future__ import annotations
from typing import Any
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory
class RAGEmbeddingFactory(GenericFactory):
"""Create LlamaIndex Embedding with MetaGPT's config."""
"""Create LlamaIndex Embedding with MetaGPT's embedding config."""
def __init__(self):
creators = {
EmbeddingType.OPENAI: self._create_openai,
EmbeddingType.AZURE: self._create_azure,
EmbeddingType.GEMINI: self._create_gemini,
EmbeddingType.OLLAMA: self._create_ollama,
# For backward compatibility
LLMType.OPENAI: self._create_openai,
LLMType.AZURE: self._create_azure,
}
super().__init__(creators)
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
"""Key is LLMType, default use config.llm.api_type."""
return super().get_instance(key or config.llm.api_type)
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
"""Key is EmbeddingType."""
return super().get_instance(key or self._resolve_embedding_type())
def _create_openai(self):
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
def _resolve_embedding_type(self) -> EmbeddingType | LLMType:
"""Resolves the embedding type.
def _create_azure(self):
return AzureOpenAIEmbedding(
azure_endpoint=config.llm.base_url,
api_key=config.llm.api_key,
api_version=config.llm.api_version,
If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE.
Raise TypeError if embedding type not found.
"""
if config.embedding.api_type:
return config.embedding.api_type
if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]:
return config.llm.api_type
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
def _create_openai(self) -> OpenAIEmbedding:
params = dict(
api_key=config.embedding.api_key or config.llm.api_key,
api_base=config.embedding.base_url or config.llm.base_url,
)
self._try_set_model_and_batch_size(params)
return OpenAIEmbedding(**params)
def _create_azure(self) -> AzureOpenAIEmbedding:
params = dict(
api_key=config.embedding.api_key or config.llm.api_key,
azure_endpoint=config.embedding.base_url or config.llm.base_url,
api_version=config.embedding.api_version or config.llm.api_version,
)
self._try_set_model_and_batch_size(params)
return AzureOpenAIEmbedding(**params)
def _create_gemini(self) -> GeminiEmbedding:
params = dict(
api_key=config.embedding.api_key,
api_base=config.embedding.base_url,
)
self._try_set_model_and_batch_size(params)
return GeminiEmbedding(**params)
def _create_ollama(self) -> OllamaEmbedding:
params = dict(
base_url=config.embedding.base_url,
)
self._try_set_model_and_batch_size(params)
return OllamaEmbedding(**params)
def _try_set_model_and_batch_size(self, params: dict):
"""Set the model_name and embed_batch_size only when they are specified."""
if config.embedding.model:
params["model_name"] = config.embedding.model
if config.embedding.embed_batch_size:
params["embed_batch_size"] = config.embedding.embed_batch_size
def _raise_for_key(self, key: Any):
raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}")
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding

View file

@ -48,7 +48,7 @@ class RAGIndexFactory(ConfigBasedFactory):
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)

View file

@ -1,5 +1,5 @@
"""RAG LLM."""
import asyncio
from typing import Any
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
@ -15,7 +15,7 @@ from pydantic import Field
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import run_coroutine_in_new_loop
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.token_counter import TOKEN_MAX
@ -39,7 +39,8 @@ class RAGLLM(CustomLLM):
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
NestAsyncio.apply_once()
return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs))
@llm_completion_callback()
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:

View file

@ -8,6 +8,8 @@ from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import (
BaseRankerConfig,
BGERerankConfig,
CohereRerankConfig,
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory):
LLMRankerConfig: self._create_llm_ranker,
ColbertRerankConfig: self._create_colbert_ranker,
ObjectRankerConfig: self._create_object_ranker,
CohereRerankConfig: self._create_cohere_rerank,
BGERerankConfig: self._create_bge_rerank,
}
super().__init__(creators)
@ -45,6 +49,26 @@ class RankerFactory(ConfigBasedFactory):
)
return ColbertRerank(**config.model_dump())
def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank:
try:
from llama_index.postprocessor.cohere_rerank import CohereRerank
except ImportError:
raise ImportError(
"`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`"
)
return CohereRerank(**config.model_dump())
def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank:
try:
from llama_index.postprocessor.flag_embedding_reranker import (
FlagEmbeddingReranker,
)
except ImportError:
raise ImportError(
"`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`"
)
return FlagEmbeddingReranker(**config.model_dump())
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
return ObjectSortPostprocessor(**config.model_dump())

View file

@ -1,10 +1,13 @@
"""RAG Retriever Factory."""
import copy
from functools import wraps
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import BaseNode
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
@ -24,10 +27,25 @@ from metagpt.rag.schema import (
ElasticsearchKeywordRetrieverConfig,
ElasticsearchRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
def get_or_build_index(build_index_func):
"""Decorator to get or build an index.
Get index using `_extract_index` method, if not found, using build_index_func.
"""
@wraps(build_index_func)
def wrapper(self, config, **kwargs):
index = self._extract_index(config, **kwargs)
if index is not None:
return index
return build_index_func(self, config, **kwargs)
return wrapper
class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
@ -54,48 +72,79 @@ class RetrieverFactory(ConfigBasedFactory):
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
index = self._extract_index(None, **kwargs) or self._build_default_index(**kwargs)
return index.as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_faiss_index(config, **kwargs)
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
index = self._extract_index(config, **kwargs)
nodes = list(index.docstore.docs.values()) if index else self._extract_nodes(config, **kwargs)
return DynamicBM25Retriever(nodes=list(config.index.docstore.docs.values()), **config.model_dump())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_chroma_index(config, **kwargs)
return ChromaRetriever(**config.model_dump())
def _create_es_retriever(self, config: ElasticsearchRetrieverConfig, **kwargs) -> ElasticsearchRetriever:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
config.index = self._build_es_index(config, **kwargs)
return ElasticsearchRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
def _extract_nodes(self, config: BaseRetrieverConfig = None, **kwargs) -> list[BaseNode]:
return self._val_from_config_or_kwargs("nodes", config, **kwargs)
def _extract_embed_model(self, config: BaseRetrieverConfig = None, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
def _build_default_index(self, **kwargs) -> VectorStoreIndex:
index = VectorStoreIndex(
nodes=self._extract_nodes(**kwargs),
embed_model=self._extract_embed_model(**kwargs),
)
return index
@get_or_build_index
def _build_faiss_index(self, config: FAISSRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_chroma_index(self, config: ChromaRetrieverConfig, **kwargs) -> VectorStoreIndex:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
return self._build_index_from_vector_store(config, vector_store, **kwargs)
@get_or_build_index
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
vector_store = ElasticsearchStore(**config.store_config.model_dump())
return self._build_index_from_vector_store(config, vector_store, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
self, config: BaseRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
index = VectorStoreIndex(
nodes=self._extract_nodes(config, **kwargs),
storage_context=storage_context,
embed_model=old_index._embed_model,
embed_model=self._extract_embed_model(config, **kwargs),
)
return new_index
return index
get_retriever = RetrieverFactory().get_retriever

View file

@ -40,8 +40,10 @@ class DynamicBM25Retriever(BM25Retriever):
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)
self._index.insert_nodes(nodes, **kwargs)
if self._index:
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)
if self._index:
self._index.storage_context.persist(persist_dir)

View file

@ -1,14 +1,17 @@
"""RAG schemas."""
from pathlib import Path
from typing import Any, Literal, Union
from typing import Any, ClassVar, Literal, Optional, Union
from chromadb.api.types import CollectionMetadata
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from metagpt.config2 import config
from metagpt.configs.embedding_config import EmbeddingType
from metagpt.rag.interface import RAGObject
@ -31,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig):
class FAISSRetrieverConfig(IndexRetrieverConfig):
"""Config for FAISS-based retrievers."""
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.")
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
EmbeddingType.GEMINI: 768,
EmbeddingType.OLLAMA: 4096,
}
@model_validator(mode="after")
def check_dimensions(self):
if self.dimensions == 0:
self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536)
return self
class BM25RetrieverConfig(IndexRetrieverConfig):
@ -45,6 +60,9 @@ class ChromaRetrieverConfig(IndexRetrieverConfig):
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class ElasticsearchStoreConfig(BaseModel):
@ -101,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig):
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
class CohereRerankConfig(BaseRankerConfig):
model: str = Field(default="rerank-english-v3.0")
api_key: str = Field(default="YOUR_COHERE_API")
class BGERerankConfig(BaseRankerConfig):
model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.")
use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")
class ObjectRankerConfig(BaseRankerConfig):
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
@ -130,6 +158,9 @@ class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
metadata: Optional[CollectionMetadata] = Field(
default=None, description="Optional metadata to associate with the collection"
)
class BM25IndexConfig(BaseIndexConfig):

View file

@ -20,6 +20,7 @@ from metagpt.strategy.thinking_command import (
)
from metagpt.tools.tool_recommend import BM25ToolRecommender
from metagpt.utils.common import CodeParser
from metagpt.utils.report import ThoughtReporter
class DataAnalyst(DataInterpreter):
@ -82,8 +83,8 @@ class DataAnalyst(DataInterpreter):
available_commands=prepare_command_prompt(self.available_commands),
)
context = self.llm.format_msg(self.working_memory.get() + [Message(content=prompt, role="user")])
rsp = await self.llm.aask(context)
async with ThoughtReporter():
rsp = await self.llm.aask(context)
self.commands = json.loads(CodeParser.parse_code(block=None, text=rsp))
self.rc.memory.add(Message(content=rsp, role="assistant"))

View file

@ -15,6 +15,7 @@ from metagpt.schema import Message, Task, TaskResult
from metagpt.strategy.task_type import TaskType
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.utils.common import CodeParser
from metagpt.utils.report import ThoughtReporter
REACT_THINK_PROMPT = """
# User Requirement
@ -73,7 +74,8 @@ class DataInterpreter(Role):
return True
prompt = REACT_THINK_PROMPT.format(user_requirement=self.user_requirement, context=context)
rsp = await self.llm.aask(prompt)
async with ThoughtReporter():
rsp = await self.llm.aask(prompt)
rsp_dict = json.loads(CodeParser.parse_code(text=rsp))
self.working_memory.add(Message(content=rsp_dict["thoughts"], role="assistant"))
need_action = rsp_dict["state"]

View file

@ -20,6 +20,7 @@ from metagpt.strategy.thinking_command import (
run_commands,
)
from metagpt.utils.common import CodeParser
from metagpt.utils.report import ThoughtReporter
class TeamLeader(Role):
@ -69,7 +70,8 @@ class TeamLeader(Role):
)
context = self.llm.format_msg(self.get_memories(k=10) + [Message(content=prompt, role="user")])
rsp = await self.llm.aask(context, system_msgs=[SYSTEM_PROMPT])
async with ThoughtReporter():
rsp = await self.llm.aask(context, system_msgs=[SYSTEM_PROMPT])
self.commands = json.loads(CodeParser.parse_code(text=rsp))
self.rc.memory.add(Message(content=rsp, role="assistant"))

View file

@ -1,9 +1,12 @@
from __future__ import annotations
import contextlib
from playwright.async_api import async_playwright
from metagpt.utils.file import MemoryFileSystem
from uuid import uuid4
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.parse_html import simplify_html
from metagpt.utils.report import BrowserReporter
@ -35,16 +38,49 @@ class Browser:
print("Now on page ", url)
await self._view()
async def open_new_page(self, url: str):
async def open_new_page(self, url: str, timeout: float = 30000):
"""open a new page in the browser and view the page"""
async with self.reporter as reporter:
page = await self.browser.new_page()
await reporter.async_report(url, "url")
await page.goto(url)
await page.goto(url, timeout=timeout)
self.pages[url] = page
await self._set_current_page(page, url)
await reporter.async_report(page, "page")
async def view_page_element_to_scrape(self, requirement: str, keep_links: bool = False) -> None:
"""view the HTML content of current page to understand the structure. When executed, the content will be printed out
Args:
requirement (str): Providing a clear and detailed requirement helps in focusing the inspection on the desired elements.
keep_links (bool): Whether to keep the hyperlinks in the HTML content. Set to True if links are required
"""
html = await self.current_page.content()
html = simplify_html(html, url=self.current_page.url, keep_links=keep_links)
mem_fs = MemoryFileSystem()
filename = f"{uuid4().hex}.html"
with mem_fs.open(filename, "w") as f:
f.write(html)
# Since RAG is an optional optimization, if it fails, the simplified HTML can be used as a fallback.
with contextlib.suppress(Exception):
from metagpt.rag.engines import SimpleEngine # avoid circular import
# TODO make `from_docs` asynchronous
engine = SimpleEngine.from_docs(input_files=[filename], fs=mem_fs)
nodes = await engine.aretrieve(requirement)
html = "\n".join(i.text for i in nodes)
mem_fs.rm_file(filename)
print(html)
async def get_page_content(self) -> str:
"""Get the HTML content of current page."""
html = await self.current_page.content()
html_content = html.strip()
return html_content
async def switch_page(self, url: str):
"""switch to an opened page in the browser and view the page"""
if url in self.pages:
@ -152,8 +188,8 @@ class Browser:
async def _view(self, keep_len: int = 5000) -> str:
"""simulate human viewing the current page, return the visible text with links"""
visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS)
print("The visible text and their links (if any): ", visible_text_with_links[:keep_len])
# visible_text_with_links = await self.current_page.evaluate(VIEW_CONTENT_JS)
# print("The visible text and their links (if any): ", visible_text_with_links[:keep_len])
# html_content = await self._view_page_html(keep_len=keep_len)
# print("The html content: ", html_content)

View file

@ -100,7 +100,7 @@ class Editor:
file_path=file_path,
block_content=block_content,
)
self.resource.report(result.file_path, "path")
self.resource.report(result.file_path, "path", extra={"type": "search", "line": i, "symbol": symbol})
return result
return None

View file

@ -9,7 +9,6 @@ from github.Issue import Issue
from github.PullRequest import PullRequest
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.git_repository import GitBranch, GitRepository
@register_tool(tags=["software development", "git", "Commit the changes and push to remote git repository."])
@ -18,7 +17,7 @@ async def git_push(
access_token: str,
comments: str = "Commit",
new_branch: str = "",
) -> GitBranch:
) -> "GitBranch":
"""
Pushes changes from a local Git repository to its remote counterpart.
@ -49,6 +48,8 @@ async def git_push(
base branch:'master', head branch:'feature/new', repo_name:'iorisa/snake-game'
"""
from metagpt.utils.git_repository import GitRepository
if not GitRepository.is_git_dir(local_path):
raise ValueError("Invalid local git repository")

View file

@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any:
new_loop.call_soon_threadsafe(new_loop.stop)
t.join()
new_loop.close()
class NestAsyncio:
    """Make the asyncio event loop reentrant."""

    # Tracks whether nest_asyncio.apply() has already been invoked for this process.
    is_applied = False

    @classmethod
    def apply_once(cls):
        """Ensures `nest_asyncio.apply()` is called only once."""
        if cls.is_applied:
            return
        import nest_asyncio

        nest_asyncio.apply()
        cls.is_applied = True

View file

@ -646,7 +646,7 @@ def role_raise_decorator(func):
raise Exception(format_trackback_info(limit=None))
except Exception as e:
if self.latest_observed_msg:
logger.warning(
logger.exception(
"There is a exception in role's execution, in order to resume, "
"we delete the newest role communication message in the role's memory."
)

View file

@ -9,6 +9,7 @@
from pathlib import Path
import aiofiles
from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
@ -68,3 +69,10 @@ class File:
content = b"".join(chunks)
logger.debug(f"Successfully read file, the path of file: {file_path}")
return content
class MemoryFileSystem(_MemoryFileSystem):
    """fsspec in-memory filesystem that also accepts `pathlib.Path` (or any
    str-convertible) paths, not just plain strings."""

    @classmethod
    def _strip_protocol(cls, path):
        # Upstream expects `path` to be a str; coerce it first so callers can
        # pass pathlib.Path objects transparently.
        return super()._strip_protocol(str(path))

View file

@ -7,6 +7,8 @@ from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pydantic import BaseModel, PrivateAttr
import htmlmin
class WebPage(BaseModel):
inner_text: str
@ -38,6 +40,22 @@ class WebPage(BaseModel):
elif url.startswith(("http://", "https://")):
yield urljoin(self.url, url)
def get_slim_soup(self, keep_links: bool = False):
    """Parse `self.html` and slim it down: keep only `class` attributes
    (plus `href` when `keep_links` is True) and drop media elements."""
    soup = _get_soup(self.html)
    allowed_attrs = ["class", "href"] if keep_links else ["class"]
    for tag in soup.find_all(True):
        # Snapshot the attr names: we mutate tag.attrs while iterating.
        for attr in list(tag.attrs):
            # Only truthy attribute values are removed, matching original behavior.
            if tag[attr] and attr not in allowed_attrs:
                del tag[attr]
    for media in soup.find_all(["svg", "img", "video", "audio"]):
        media.decompose()
    return soup
def get_html_content(page: str, base: str):
soup = _get_soup(page)
@ -48,7 +66,12 @@ def get_html_content(page: str, base: str):
def _get_soup(page: str):
    """Parse `page` into a BeautifulSoup tree and remove elements that never
    contribute visible page text."""
    soup = BeautifulSoup(page, "html.parser")
    # https://stackoverflow.com/questions/1936466/how-to-scrape-only-visible-webpage-text-with-beautifulsoup
    for s in soup(["style", "script", "[document]", "head", "title", "footer"]):
        s.extract()
    return soup
def simplify_html(html: str, url: str, keep_links: bool = False):
    """Reduce raw HTML to a minified skeleton: strip attributes/media via
    `WebPage.get_slim_soup`, then minify away comments and empty whitespace."""
    slim_soup = WebPage(inner_text="", html=html, url=url).get_slim_soup(keep_links)
    simplified = slim_soup.decode()
    return htmlmin.minify(simplified, remove_comments=True, remove_empty_space=True)

View file

@ -39,6 +39,7 @@ class BlockType(str, Enum):
GALLERY = "Gallery"
NOTEBOOK = "Notebook"
DOCS = "Docs"
THOUGHT = "Thought"
END_MARKER_NAME = "end_marker"
@ -55,23 +56,23 @@ class ResourceReporter(BaseModel):
callback_url: str = Field(METAGPT_REPORTER_DEFAULT_URL, description="The URL to which the report should be sent")
_llm_task: Optional[asyncio.Task] = PrivateAttr(None)
def report(self, value: Any, name: str, extra: Optional[dict] = None):
    """Synchronously report resource observation data.

    Args:
        value: The data to report.
        name: The type name of the data.
        extra: Optional additional metadata attached to the report payload.
    """
    return self._report(value, name, extra)

async def async_report(self, value: Any, name: str, extra: Optional[dict] = None):
    """Asynchronously report resource observation data.

    Args:
        value: The data to report.
        name: The type name of the data.
        extra: Optional additional metadata attached to the report payload.
    """
    return await self._async_report(value, name, extra)
@classmethod
def set_report_fn(cls, fn: Callable):
@ -100,20 +101,20 @@ class ResourceReporter(BaseModel):
"""
cls._async_report = fn
def _report(self, value: Any, name: str):
def _report(self, value: Any, name: str, extra: Optional[dict] = None):
if not self.callback_url:
return
data = self._format_data(value, name)
data = self._format_data(value, name, extra)
resp = requests.post(self.callback_url, json=data)
resp.raise_for_status()
return resp.text
async def _async_report(self, value: Any, name: str):
async def _async_report(self, value: Any, name: str, extra: Optional[dict] = None):
if not self.callback_url:
return
data = self._format_data(value, name)
data = self._format_data(value, name, extra)
url = self.callback_url
_result = urlparse(url)
sessiion_kwargs = {}
@ -129,9 +130,16 @@ class ResourceReporter(BaseModel):
resp.raise_for_status()
return await resp.text()
def _format_data(self, value, name):
def _format_data(self, value, name, extra):
data = self.model_dump(mode="json", exclude=("callback_url", "llm_stream"))
data["value"] = str(value) if isinstance(value, Path) else value
if isinstance(value, BaseModel):
value = value.model_dump(mode="json")
elif isinstance(value, Path):
value = str(value)
if name == "path":
value = os.path.abspath(value)
data["value"] = value
data["name"] = name
role = CURRENT_ROLE.get(None)
if role:
@ -139,6 +147,8 @@ class ResourceReporter(BaseModel):
else:
role_name = os.environ.get("METAGPT_ROLE")
data["role"] = role_name
if extra:
data["extra"] = extra
return data
def __enter__(self):
@ -252,6 +262,16 @@ class TaskReporter(ObjectReporter):
block: Literal[BlockType.TASK] = BlockType.TASK
class ThoughtReporter(ObjectReporter):
    """Reporter for object resources to the Thought block."""

    block: Literal[BlockType.THOUGHT] = BlockType.THOUGHT

    async def __aenter__(self):
        # Emit an empty payload up front so the Thought block is created on the
        # receiving side before any real data is streamed into it.
        await self.async_report({})
        return await super().__aenter__()
class FileReporter(ResourceReporter):
"""File resource callback for reporting complete file paths.
@ -259,13 +279,23 @@ class FileReporter(ResourceReporter):
if the file can be partially output for display first, use streaming callback.
"""
def report(
    self,
    value: Union[Path, dict, Any],
    name: Literal["path", "meta", "content"] = "path",
    extra: Optional[dict] = None,
):
    """Report file resource synchronously.

    Args:
        value: File path, metadata dict, or file content to report.
        name: Which facet of the file is being reported.
        extra: Optional additional metadata forwarded to the callback.
    """
    return super().report(value, name, extra)
async def async_report(
    self,
    value: Union[Path, dict, Any],
    name: Literal["path", "meta", "content"] = "path",
    extra: Optional[dict] = None,
):
    """Report file resource asynchronously.

    Args:
        value: File path, metadata dict, or file content to report.
        name: Which facet of the file is being reported.
        extra: Optional additional metadata forwarded to the callback.
    """
    return await super().async_report(value, name, extra)
class NotebookReporter(FileReporter):