diff --git a/README.md b/README.md index 44fcfab18..8f5cc5393 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ # Check https://docs.deepwisdom.ai/main/en/guide/get_started/configuration.html ```yaml llm: api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options - model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview + model: "gpt-4-turbo" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview base_url: "https://api.openai.com/v1" # or forward url / other llm url api_key: "YOUR_API_KEY" ``` diff --git a/config/config2.example.yaml b/config/config2.example.yaml index c5454ec32..7cfd70347 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -13,6 +13,16 @@ llm: # - gpt-4 8k: "gpt-4" # See for more: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ +# RAG Embedding. +# For backward compatibility, if the embedding is not set and the llm's api_type is either openai or azure, the llm's config will be used. +embedding: + api_type: "" # openai / azure / gemini / ollama etc. Check EmbeddingType for more options. + base_url: "" + api_key: "" + model: "" + api_version: "" + embed_batch_size: 100 + repair_llm_output: true # when the output is not a valid json, try to repair it proxy: "YOUR_PROXY" # for tools like requests, playwright, selenium, etc. diff --git a/config/config2.yaml b/config/config2.yaml index 8e5825b57..ba071e804 100644 --- a/config/config2.yaml +++ b/config/config2.yaml @@ -2,6 +2,6 @@ # Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py llm: api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options - model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview + model: "gpt-4-turbo" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview base_url: "https://api.openai.com/v1" # or forward url / other llm url api_key: "YOUR_API_KEY" \ No newline at end of file diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index b5111b75c..1687d556b 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -8,7 +8,6 @@ from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ( - BM25RetrieverConfig, ChromaIndexConfig, ChromaRetrieverConfig, ElasticsearchIndexConfig, @@ -51,7 +50,7 @@ class RAGExample: if not self._engine: self._engine = SimpleEngine.from_docs( input_files=[DOC_PATH], - retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], + retriever_configs=[FAISSRetrieverConfig()], ranker_configs=[LLMRankerConfig()], ) return self._engine @@ -61,7 +60,7 @@ class RAGExample: self._engine = value async def run_pipeline(self, question=QUESTION, print_title=True): - """This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like: + """This example run rag pipeline, use faiss retriever and llm ranker, will print something like: Retrieve Result: 0. Productivi..., 10.0 diff --git a/examples/stream_output_via_api.py b/examples/stream_output_via_api.py new file mode 100644 index 000000000..5961f3a08 --- /dev/null +++ b/examples/stream_output_via_api.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/3/27 9:44 +@Author : leiwu30 +@File : stream_output_via_api.py +@Description : Stream log information and communicate over the network via web api. +""" +import asyncio +import json +import socket +import threading +from contextvars import ContextVar + +from flask import Flask, Response, jsonify, request, send_from_directory + +from metagpt.const import TUTORIAL_PATH +from metagpt.logs import logger, set_llm_stream_logfunc +from metagpt.roles.tutorial_assistant import TutorialAssistant +from metagpt.utils.stream_pipe import StreamPipe + +app = Flask(__name__) + + +def stream_pipe_log(content): + print(content, end="") + stream_pipe = stream_pipe_var.get(None) + if stream_pipe: + stream_pipe.set_message(content) + + +def write_tutorial(message): + async def main(idea, stream_pipe): + stream_pipe_var.set(stream_pipe) + role = TutorialAssistant() + await role.run(idea) + + def thread_run(idea: str, stream_pipe: StreamPipe = None): + """ + Convert asynchronous function to thread function + """ + asyncio.run(main(idea, stream_pipe)) + + stream_pipe = StreamPipe() + thread = threading.Thread( + target=thread_run, + args=( + message["content"], + stream_pipe, + ), + ) + thread.start() + + while thread.is_alive(): + msg = stream_pipe.get_message() + yield stream_pipe.msg2stream(msg) + + +@app.route("/v1/chat/completions", methods=["POST"]) +def completions(): + """ + data: { + "model": "write_tutorial", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + } + """ + + data = json.loads(request.data) + logger.info(json.dumps(data, indent=4, ensure_ascii=False)) + + # Non-streaming interfaces are not supported yet + stream_type = True if data.get("stream") else False + if not stream_type: + return jsonify({"status": 400, "msg": "Non-streaming requests are not supported, please use `stream=True`."}) + + # Only accept the last user information + # openai['model'] ~ MetaGPT['agent'] + last_message = data["messages"][-1] + model = data["model"] + + # write_tutorial + if model == "write_tutorial": + return Response(write_tutorial(last_message), mimetype="text/plain") + else: + return jsonify({"status": 400, "msg": "No suitable agent found."}) + + +@app.route("/download/") +def download_file(filename): + return send_from_directory(TUTORIAL_PATH, filename, as_attachment=True) + + +if __name__ == "__main__": + """ + curl https://$server_address:$server_port/v1/chat/completions -X POST -d '{ + "model": "write_tutorial", + "stream": true, + "messages": [ + { + "role": "user", + "content": "Write a tutorial about MySQL" + } + ] + }' + """ + server_port = 7860 + server_address = socket.gethostbyname(socket.gethostname()) + + set_llm_stream_logfunc(stream_pipe_log) + stream_pipe_var: ContextVar[StreamPipe] = ContextVar("stream_pipe") + app.run(port=server_port, host=server_address) diff --git a/metagpt/config2.py b/metagpt/config2.py index ed68b4db2..58a99c920 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Literal, Optional from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig +from metagpt.configs.embedding_config import EmbeddingConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -47,6 +48,9 @@ class Config(CLIParams, YamlModel): # Key Parameters llm: LLMConfig + # RAG Embedding + embedding: EmbeddingConfig = EmbeddingConfig() + # Global Proxy. Will be used if llm.proxy is not set proxy: str = "" diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py new file mode 100644 index 000000000..20de47999 --- /dev/null +++ b/metagpt/configs/embedding_config.py @@ -0,0 +1,50 @@ +from enum import Enum +from typing import Optional + +from pydantic import field_validator + +from metagpt.utils.yaml_model import YamlModel + + +class EmbeddingType(Enum): + OPENAI = "openai" + AZURE = "azure" + GEMINI = "gemini" + OLLAMA = "ollama" + + +class EmbeddingConfig(YamlModel): + """Config for Embedding. + + Examples: + --------- + api_type: "openai" + api_key: "YOU_API_KEY" + + api_type: "azure" + api_key: "YOU_API_KEY" + base_url: "YOU_BASE_URL" + api_version: "YOU_API_VERSION" + + api_type: "gemini" + api_key: "YOU_API_KEY" + + api_type: "ollama" + base_url: "YOU_BASE_URL" + model: "YOU_MODEL" + """ + + api_type: Optional[EmbeddingType] = None + api_key: Optional[str] = None + base_url: Optional[str] = None + api_version: Optional[str] = None + + model: Optional[str] = None + embed_batch_size: Optional[int] = None + + @field_validator("api_type", mode="before") + @classmethod + def check_api_type(cls, v): + if v == "": + return None + return v diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 2db441991..acac44aaf 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -43,7 +43,15 @@ class ZhiPuAILLM(BaseLLM): self.llm = ZhiPuModelAPI(api_key=self.api_key) def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: - kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} + max_tokens = self.config.max_token if self.config.max_token > 0 else 1024 + temperature = self.config.temperature if self.config.temperature > 0.0 else 0.3 + kwargs = { + "model": self.model, + "max_tokens": max_tokens, + "messages": messages, + "stream": stream, + "temperature": temperature, + } return kwargs def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) -> dict: diff --git a/metagpt/rag/factories/base.py b/metagpt/rag/factories/base.py index fbdfbf1a8..fcfec03ec 100644 --- a/metagpt/rag/factories/base.py +++ b/metagpt/rag/factories/base.py @@ -26,6 +26,9 @@ class GenericFactory: if creator: return creator(**kwargs) + self._raise_for_key(key) + + def _raise_for_key(self, key: Any): raise ValueError(f"Creator not registered for key: {key}") diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 4247db256..3613fd228 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -1,37 +1,103 @@ """RAG Embedding Factory.""" +from __future__ import annotations + +from typing import Any from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding +from llama_index.embeddings.gemini import GeminiEmbedding +from llama_index.embeddings.ollama import OllamaEmbedding from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.base import GenericFactory class RAGEmbeddingFactory(GenericFactory): - """Create LlamaIndex Embedding with MetaGPT's config.""" + """Create LlamaIndex Embedding with MetaGPT's embedding config.""" def __init__(self): creators = { + EmbeddingType.OPENAI: self._create_openai, + EmbeddingType.AZURE: self._create_azure, + EmbeddingType.GEMINI: self._create_gemini, + EmbeddingType.OLLAMA: self._create_ollama, + # For backward compatibility LLMType.OPENAI: self._create_openai, LLMType.AZURE: self._create_azure, } super().__init__(creators) - def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding: - """Key is LLMType, default use config.llm.api_type.""" - return super().get_instance(key or config.llm.api_type) + def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding: + """Key is EmbeddingType.""" + return super().get_instance(key or self._resolve_embedding_type()) - def _create_openai(self): - return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url) + def _resolve_embedding_type(self) -> EmbeddingType | LLMType: + """Resolves the embedding type. - def _create_azure(self): - return AzureOpenAIEmbedding( - azure_endpoint=config.llm.base_url, - api_key=config.llm.api_key, - api_version=config.llm.api_version, + If the embedding type is not specified, for backward compatibility, it checks if the LLM API type is either OPENAI or AZURE. + Raise TypeError if embedding type not found. + """ + if config.embedding.api_type: + return config.embedding.api_type + + if config.llm.api_type in [LLMType.OPENAI, LLMType.AZURE]: + return config.llm.api_type + + raise TypeError("To use RAG, please set your embedding in config2.yaml.") + + def _create_openai(self) -> OpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + api_base=config.embedding.base_url or config.llm.base_url, ) + self._try_set_model_and_batch_size(params) + + return OpenAIEmbedding(**params) + + def _create_azure(self) -> AzureOpenAIEmbedding: + params = dict( + api_key=config.embedding.api_key or config.llm.api_key, + azure_endpoint=config.embedding.base_url or config.llm.base_url, + api_version=config.embedding.api_version or config.llm.api_version, + ) + + self._try_set_model_and_batch_size(params) + + return AzureOpenAIEmbedding(**params) + + def _create_gemini(self) -> GeminiEmbedding: + params = dict( + api_key=config.embedding.api_key, + api_base=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return GeminiEmbedding(**params) + + def _create_ollama(self) -> OllamaEmbedding: + params = dict( + base_url=config.embedding.base_url, + ) + + self._try_set_model_and_batch_size(params) + + return OllamaEmbedding(**params) + + def _try_set_model_and_batch_size(self, params: dict): + """Set the model_name and embed_batch_size only when they are specified.""" + if config.embedding.model: + params["model_name"] = config.embedding.model + + if config.embedding.embed_batch_size: + params["embed_batch_size"] = config.embedding.embed_batch_size + + def _raise_for_key(self, key: Any): + raise ValueError(f"The embedding type is currently not supported: `{type(key)}`, {key}") + get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 17c499b76..9fd19cab5 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -1,5 +1,5 @@ """RAG LLM.""" - +import asyncio from typing import Any from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW @@ -15,7 +15,7 @@ from pydantic import Field from metagpt.config2 import config from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM -from metagpt.utils.async_helper import run_coroutine_in_new_loop +from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -39,7 +39,8 @@ class RAGLLM(CustomLLM): @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs)) + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(self.acomplete(prompt, **kwargs)) @llm_completion_callback() async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse: diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 581815321..c00486c82 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -1,15 +1,17 @@ """RAG schemas.""" from pathlib import Path -from typing import Any, Literal, Optional, Union +from typing import Any, ClassVar, Literal, Optional, Union from chromadb.api.types import CollectionMetadata from llama_index.core.embeddings import BaseEmbedding from llama_index.core.indices.base import BaseIndex from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from metagpt.config2 import config +from metagpt.configs.embedding_config import EmbeddingType from metagpt.rag.interface import RAGObject @@ -32,7 +34,19 @@ class IndexRetrieverConfig(BaseRetrieverConfig): class FAISSRetrieverConfig(IndexRetrieverConfig): """Config for FAISS-based retrievers.""" - dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.") + dimensions: int = Field(default=0, description="Dimensionality of the vectors for FAISS index construction.") + + _embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = { + EmbeddingType.GEMINI: 768, + EmbeddingType.OLLAMA: 4096, + } + + @model_validator(mode="after") + def check_dimensions(self): + if self.dimensions == 0: + self.dimensions = self._embedding_type_to_dimensions.get(config.embedding.api_type, 1536) + + return self class BM25RetrieverConfig(IndexRetrieverConfig): diff --git a/metagpt/utils/async_helper.py b/metagpt/utils/async_helper.py index ee440ef44..cecb20c5d 100644 --- a/metagpt/utils/async_helper.py +++ b/metagpt/utils/async_helper.py @@ -20,3 +20,18 @@ def run_coroutine_in_new_loop(coroutine) -> Any: new_loop.call_soon_threadsafe(new_loop.stop) t.join() new_loop.close() + + +class NestAsyncio: + """Make asyncio event loop reentrant.""" + + is_applied = False + + @classmethod + def apply_once(cls): + """Ensures `nest_asyncio.apply()` is called only once.""" + if not cls.is_applied: + import nest_asyncio + + nest_asyncio.apply() + cls.is_applied = True diff --git a/metagpt/utils/stream_pipe.py b/metagpt/utils/stream_pipe.py new file mode 100644 index 000000000..4c4485158 --- /dev/null +++ b/metagpt/utils/stream_pipe.py @@ -0,0 +1,40 @@ +# -*- coding: utf-8 -*- +# @Time : 2024/3/27 10:00 +# @Author : leiwu30 +# @File : stream_pipe.py +# @Version : None +# @Description : None + +import json +import time +from multiprocessing import Pipe + + +class StreamPipe: + parent_conn, child_conn = Pipe() + finish: bool = False + + format_data = { + "id": "chatcmpl-96bVnBOOyPFZZxEoTIGbdpFcVEnur", + "object": "chat.completion.chunk", + "created": 1711361191, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_3bc1b5746c", + "choices": [ + {"index": 0, "delta": {"role": "assistant", "content": "content"}, "logprobs": None, "finish_reason": None} + ], + } + + def set_message(self, msg): + self.parent_conn.send(msg) + + def get_message(self, timeout: int = 3): + if self.child_conn.poll(timeout): + return self.child_conn.recv() + else: + return None + + def msg2stream(self, msg): + self.format_data["created"] = int(time.time()) + self.format_data["choices"][0]["delta"]["content"] = msg + return f"data: {json.dumps(self.format_data, ensure_ascii=False)}\n".encode("utf-8") diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 0ba2daa89..0ca22cf35 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -28,6 +28,7 @@ TOKEN_COSTS = { "gpt-4-32k-0314": {"prompt": 0.06, "completion": 0.12}, "gpt-4-0613": {"prompt": 0.06, "completion": 0.12}, "gpt-4-turbo-preview": {"prompt": 0.01, "completion": 0.03}, + "gpt-4-turbo": {"prompt": 0.01, "completion": 0.03}, "gpt-4-0125-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03}, "gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator @@ -147,6 +148,7 @@ FIREWORKS_GRADE_TOKEN_COSTS = { TOKEN_MAX = { "gpt-4-0125-preview": 128000, "gpt-4-turbo-preview": 128000, + "gpt-4-turbo": 128000, "gpt-4-1106-preview": 128000, "gpt-4-vision-preview": 128000, "gpt-4-1106-vision-preview": 128000, @@ -202,6 +204,7 @@ def count_message_tokens(messages, model="gpt-3.5-turbo-0125"): "gpt-4-32k-0314", "gpt-4-0613", "gpt-4-32k-0613", + "gpt-4-turbo", "gpt-4-turbo-preview", "gpt-4-0125-preview", "gpt-4-1106-preview", diff --git a/setup.py b/setup.py index c54ace90a..e43bf3ed0 100644 --- a/setup.py +++ b/setup.py @@ -32,12 +32,15 @@ extras_require = { "llama-index-core==0.10.15", "llama-index-embeddings-azure-openai==0.1.6", "llama-index-embeddings-openai==0.1.5", + "llama-index-embeddings-gemini==0.1.6", + "llama-index-embeddings-ollama==0.1.2", "llama-index-llms-azure-openai==0.1.4", "llama-index-readers-file==0.1.4", "llama-index-retrievers-bm25==0.1.3", "llama-index-vector-stores-faiss==0.1.1", "llama-index-vector-stores-elasticsearch==0.1.6", "llama-index-vector-stores-chroma==0.1.6", + "docx2txt==0.8", ], "android_assistant": ["pyshine==0.0.9", "opencv-python==4.6.0.66"], } diff --git a/tests/metagpt/rag/factories/test_embedding.py b/tests/metagpt/rag/factories/test_embedding.py index 1ded6b4a8..1a9e9b2c9 100644 --- a/tests/metagpt/rag/factories/test_embedding.py +++ b/tests/metagpt/rag/factories/test_embedding.py @@ -1,5 +1,6 @@ import pytest +from metagpt.configs.embedding_config import EmbeddingType from metagpt.configs.llm_config import LLMType from metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory: self.embedding_factory = RAGEmbeddingFactory() @pytest.fixture - def mock_openai_embedding(self, mocker): + def mock_config(self, mocker): + return mocker.patch("metagpt.rag.factories.embedding.config") + + @staticmethod + def mock_openai_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding") - @pytest.fixture - def mock_azure_embedding(self, mocker): + @staticmethod + def mock_azure_embedding(mocker): return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding") - def test_get_rag_embedding_openai(self, mock_openai_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.OPENAI) + @staticmethod + def mock_gemini_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding") - # Assert - mock_openai_embedding.assert_called_once() + @staticmethod + def mock_ollama_embedding(mocker): + return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding") - def test_get_rag_embedding_azure(self, mock_azure_embedding): - # Exec - self.embedding_factory.get_rag_embedding(LLMType.AZURE) - - # Assert - mock_azure_embedding.assert_called_once() - - def test_get_rag_embedding_default(self, mocker, mock_openai_embedding): + @pytest.mark.parametrize( + ("mock_func", "embedding_type"), + [ + (mock_openai_embedding, LLMType.OPENAI), + (mock_azure_embedding, LLMType.AZURE), + (mock_openai_embedding, EmbeddingType.OPENAI), + (mock_azure_embedding, EmbeddingType.AZURE), + (mock_gemini_embedding, EmbeddingType.GEMINI), + (mock_ollama_embedding, EmbeddingType.OLLAMA), + ], + ) + def test_get_rag_embedding(self, mock_func, embedding_type, mocker): # Mock - mock_config = mocker.patch("metagpt.rag.factories.embedding.config") + mock = mock_func(mocker) + + # Exec + self.embedding_factory.get_rag_embedding(embedding_type) + + # Assert + mock.assert_called_once() + + def test_get_rag_embedding_default(self, mocker, mock_config): + # Mock + mock_openai_embedding = self.mock_openai_embedding(mocker) + + mock_config.embedding.api_type = None mock_config.llm.api_type = LLMType.OPENAI # Exec @@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory: # Assert mock_openai_embedding.assert_called_once() + + @pytest.mark.parametrize( + "model, embed_batch_size, expected_params", + [("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})], + ) + def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params): + # Mock + mock_config.embedding.model = model + mock_config.embedding.embed_batch_size = embed_batch_size + + # Setup + test_params = {} + + # Exec + self.embedding_factory._try_set_model_and_batch_size(test_params) + + # Assert + assert test_params == expected_params + + def test_resolve_embedding_type(self, mock_config): + # Mock + mock_config.embedding.api_type = EmbeddingType.OPENAI + + # Exec + embedding_type = self.embedding_factory._resolve_embedding_type() + + # Assert + assert embedding_type == EmbeddingType.OPENAI + + def test_resolve_embedding_type_exception(self, mock_config): + # Mock + mock_config.embedding.api_type = None + mock_config.llm.api_type = LLMType.GEMINI + + # Assert + with pytest.raises(TypeError): + self.embedding_factory._resolve_embedding_type() + + def test_raise_for_key(self): + with pytest.raises(ValueError): + self.embedding_factory._raise_for_key("key")