mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
Merge branch 'fix-rag' into 'mgx_ops'
Fix rag See merge request pub/MetaGPT!317
This commit is contained in:
commit
aa8e2fa8c3
5 changed files with 34 additions and 16 deletions
|
|
@ -8,7 +8,7 @@ import json
|
|||
from pathlib import Path
|
||||
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.exp_pool import exp_manager
|
||||
from metagpt.exp_pool import get_exp_manager
|
||||
from metagpt.exp_pool.schema import EntryType, Experience, Metric, Score
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import aread
|
||||
|
|
@ -45,7 +45,7 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None):
|
|||
tag=tag,
|
||||
metric=metric or Metric(score=Score(val=10, reason="Manual")),
|
||||
)
|
||||
|
||||
exp_manager = get_exp_manager()
|
||||
exp_manager.config.exp_pool.enable_write = True
|
||||
exp_manager.create_exp(exp)
|
||||
logger.info(f"New experience created for the request `{req[:10]}`.")
|
||||
|
|
@ -79,7 +79,7 @@ async def add_exps_from_file(tag: str, filepath: Path):
|
|||
|
||||
def query_exps_count():
|
||||
"""Queries and logs the total count of experiences in the pool."""
|
||||
|
||||
exp_manager = get_exp_manager()
|
||||
count = exp_manager.get_exps_count()
|
||||
logger.info(f"Experiences Count: {count}")
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ This script creates a new experience, logs its creation, and then queries for ex
|
|||
|
||||
import asyncio
|
||||
|
||||
from metagpt.exp_pool import exp_manager
|
||||
from metagpt.exp_pool import get_exp_manager
|
||||
from metagpt.exp_pool.schema import EntryType, Experience
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
|
@ -15,6 +15,7 @@ async def main():
|
|||
# Define the simple request and response
|
||||
req = "Simple req"
|
||||
resp = "Simple resp"
|
||||
exp_manager = get_exp_manager()
|
||||
|
||||
# Add the new experience
|
||||
exp = Experience(req=req, resp=resp, entry_type=EntryType.MANUAL)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class RAGEmbeddingFactory(GenericFactory):
|
|||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
self.config = config if self.config else Config.default()
|
||||
self.config = config if config else Config.default()
|
||||
|
||||
def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding:
|
||||
"""Key is EmbeddingType."""
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from llama_index.core.llms import (
|
|||
LLMMetadata,
|
||||
)
|
||||
from llama_index.core.llms.callbacks import llm_completion_callback
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.llm import LLM
|
||||
|
|
@ -30,19 +30,30 @@ class RAGLLM(CustomLLM):
|
|||
num_output: int = -1
|
||||
model_name: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def update_from_config(self):
|
||||
def __init__(
|
||||
self,
|
||||
model_infer: BaseLLM,
|
||||
context_window: int = -1,
|
||||
num_output: int = -1,
|
||||
model_name: str = "",
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
config = Config.default()
|
||||
if self.context_window < 0:
|
||||
self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
|
||||
if context_window < 0:
|
||||
context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
|
||||
|
||||
if self.num_output < 0:
|
||||
self.num_output = config.llm.max_token
|
||||
if num_output < 0:
|
||||
num_output = config.llm.max_token
|
||||
|
||||
if not self.model_name:
|
||||
self.model_name = config.llm.model
|
||||
if not model_name:
|
||||
model_name = config.llm.model
|
||||
|
||||
return self
|
||||
self.model_infer = model_infer
|
||||
self.context_window = context_window
|
||||
self.num_output = num_output
|
||||
self.model_name = model_name
|
||||
|
||||
@property
|
||||
def metadata(self) -> LLMMetadata:
|
||||
|
|
|
|||
|
|
@ -155,7 +155,10 @@ class TestExpCache:
|
|||
|
||||
@pytest.fixture
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.exp_pool.decorator.config")
|
||||
config = Config.default().model_copy(deep=True)
|
||||
default = mocker.patch("metagpt.config2.Config.default")
|
||||
default.return_value = config
|
||||
return config
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_disabled(self, mock_config, mock_exp_manager):
|
||||
|
|
@ -171,7 +174,9 @@ class TestExpCache:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer):
|
||||
mock_config.exp_pool.enabled = True
|
||||
mock_config.exp_pool.enable_read = True
|
||||
mock_config.exp_pool.enable_write = True
|
||||
mock_exp_manager.query_exps.return_value = []
|
||||
|
||||
@exp_cache(manager=mock_exp_manager, scorer=mock_scorer)
|
||||
|
|
@ -185,6 +190,7 @@ class TestExpCache:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge):
|
||||
mock_config.exp_pool.enabled = True
|
||||
mock_config.exp_pool.enable_read = True
|
||||
perfect_exp = Experience(req="test", resp="perfect_result")
|
||||
mock_exp_manager.query_exps.return_value = [perfect_exp]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue