Merge branch 'fix-rag' into 'mgx_ops'

Fix rag

See merge request pub/MetaGPT!317
This commit is contained in:
张雷 2024-08-15 13:12:09 +00:00
commit aa8e2fa8c3
5 changed files with 34 additions and 16 deletions

View file

@ -8,7 +8,7 @@ import json
from pathlib import Path
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.exp_pool import exp_manager
from metagpt.exp_pool import get_exp_manager
from metagpt.exp_pool.schema import EntryType, Experience, Metric, Score
from metagpt.logs import logger
from metagpt.utils.common import aread
@ -45,7 +45,7 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None):
tag=tag,
metric=metric or Metric(score=Score(val=10, reason="Manual")),
)
exp_manager = get_exp_manager()
exp_manager.config.exp_pool.enable_write = True
exp_manager.create_exp(exp)
logger.info(f"New experience created for the request `{req[:10]}`.")
@ -79,7 +79,7 @@ async def add_exps_from_file(tag: str, filepath: Path):
def query_exps_count():
"""Queries and logs the total count of experiences in the pool."""
exp_manager = get_exp_manager()
count = exp_manager.get_exps_count()
logger.info(f"Experiences Count: {count}")

View file

@ -6,7 +6,7 @@ This script creates a new experience, logs its creation, and then queries for ex
import asyncio
from metagpt.exp_pool import exp_manager
from metagpt.exp_pool import get_exp_manager
from metagpt.exp_pool.schema import EntryType, Experience
from metagpt.logs import logger
@ -15,6 +15,7 @@ async def main():
# Define the simple request and response
req = "Simple req"
resp = "Simple resp"
exp_manager = get_exp_manager()
# Add the new experience
exp = Experience(req=req, resp=resp, entry_type=EntryType.MANUAL)

View file

@ -29,7 +29,7 @@ class RAGEmbeddingFactory(GenericFactory):
LLMType.AZURE: self._create_azure,
}
super().__init__(creators)
self.config = config if self.config else Config.default()
self.config = config if config else Config.default()
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
"""Key is EmbeddingType."""

View file

@ -10,7 +10,7 @@ from llama_index.core.llms import (
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field, model_validator
from pydantic import Field
from metagpt.config2 import Config
from metagpt.llm import LLM
@ -30,19 +30,30 @@ class RAGLLM(CustomLLM):
num_output: int = -1
model_name: str = ""
@model_validator(mode="after")
def update_from_config(self):
def __init__(
self,
model_infer: BaseLLM,
context_window: int = -1,
num_output: int = -1,
model_name: str = "",
*args,
**kwargs
):
super().__init__(*args, **kwargs)
config = Config.default()
if self.context_window < 0:
self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
if context_window < 0:
context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
if self.num_output < 0:
self.num_output = config.llm.max_token
if num_output < 0:
num_output = config.llm.max_token
if not self.model_name:
self.model_name = config.llm.model
if not model_name:
model_name = config.llm.model
return self
self.model_infer = model_infer
self.context_window = context_window
self.num_output = num_output
self.model_name = model_name
@property
def metadata(self) -> LLMMetadata:

View file

@ -155,7 +155,10 @@ class TestExpCache:
@pytest.fixture
def mock_config(self, mocker):
return mocker.patch("metagpt.exp_pool.decorator.config")
config = Config.default().model_copy(deep=True)
default = mocker.patch("metagpt.config2.Config.default")
default.return_value = config
return config
@pytest.mark.asyncio
async def test_exp_cache_disabled(self, mock_config, mock_exp_manager):
@ -171,7 +174,9 @@ class TestExpCache:
@pytest.mark.asyncio
async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer):
mock_config.exp_pool.enabled = True
mock_config.exp_pool.enable_read = True
mock_config.exp_pool.enable_write = True
mock_exp_manager.query_exps.return_value = []
@exp_cache(manager=mock_exp_manager, scorer=mock_scorer)
@ -185,6 +190,7 @@ class TestExpCache:
@pytest.mark.asyncio
async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge):
mock_config.exp_pool.enabled = True
mock_config.exp_pool.enable_read = True
perfect_exp = Experience(req="test", resp="perfect_result")
mock_exp_manager.query_exps.return_value = [perfect_exp]