From 5f1ca3ca7e2ffb38199e7ed956bc7c72928dde30 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Fri, 24 May 2024 17:18:56 +0800 Subject: [PATCH] use model_name in embedding --- metagpt/rag/factories/embedding.py | 4 ++-- tests/metagpt/rag/factories/test_embedding.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 1599b79d9..3613fd228 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -89,9 +89,9 @@ class RAGEmbeddingFactory(GenericFactory): return OllamaEmbedding(**params) def _try_set_model_and_batch_size(self, params: dict): - """Set the model and embed_batch_size only when they are specified.""" + """Set the model_name and embed_batch_size only when they are specified.""" if config.embedding.model: - params["model"] = config.embedding.model + params["model_name"] = config.embedding.model if config.embedding.embed_batch_size: params["embed_batch_size"] = config.embedding.embed_batch_size diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1fdceace8..1a9e9b2c9 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -66,7 +66,7 @@ class TestRAGEmbeddingFactory: @pytest.mark.parametrize( "model, embed_batch_size, expected_params", - [("test_model", 100, {"model": "test_model", "embed_batch_size": 100}), (None, None, {})], + [("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})], ) def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params): # Mock