diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 23ef79555..33511f534 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -63,3 +63,21 @@ metagpt_tti_url: "YOUR_MODEL_URL" omniparse: api_key: "YOUR_API_KEY" base_url: "YOUR_BASE_URL" + +models: +# "YOUR_MODEL_NAME_1 or YOUR_API_TYPE_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo +# api_type: "openai" # or azure / ollama / groq etc. +# base_url: "YOUR_BASE_URL" +# api_key: "YOUR_API_KEY" +# proxy: "YOUR_PROXY" # for LLM API requests +# # timeout: 600 # Optional. If set to 0, default value is 300. +# # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ +# pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's +# "YOUR_MODEL_NAME_2 or YOUR_API_TYPE_2": # model: "gpt-4-turbo" # or gpt-3.5-turbo +# api_type: "openai" # or azure / ollama / groq etc. +# base_url: "YOUR_BASE_URL" +# api_key: "YOUR_API_KEY" +# proxy: "YOUR_PROXY" # for LLM API requests +# # timeout: 600 # Optional. If set to 0, default value is 300. +# # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ +# pricing_plan: "" # Optional. 
Use for Azure LLM when its model name is not the same as OpenAI's diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1b93213f7..20c052aa9 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -8,12 +8,14 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, model_validator from metagpt.actions.action_node import ActionNode +from metagpt.configs.models_config import ModelsConfig from metagpt.context_mixin import ContextMixin +from metagpt.provider.llm_provider_registry import create_llm_instance from metagpt.schema import ( CodePlanAndChangeContext, CodeSummarizeContext, @@ -35,6 +37,19 @@ class Action(SerializationMixin, ContextMixin, BaseModel): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) + # The model name or API type of LLM of the `models` in the `config2.yaml`; + # Using `None` to use the `llm` configuration in the `config2.yaml`. + llm_name_or_type: Optional[str] = None + + @model_validator(mode="after") + @classmethod + def _update_private_llm(cls, data: Any) -> Any: + config = ModelsConfig.default().get(data.llm_name_or_type) + if config: + llm = create_llm_instance(config) + llm.cost_manager = data.llm.cost_manager + data.llm = llm + return data @property def repo(self) -> ProjectRepo: diff --git a/metagpt/configs/models_config.py b/metagpt/configs/models_config.py new file mode 100644 index 000000000..bc4897fec --- /dev/null +++ b/metagpt/configs/models_config.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +models_config.py + +This module defines the ModelsConfig class for handling configuration of LLM models. + +Attributes: + CONFIG_ROOT (Path): Root path for configuration files. + METAGPT_ROOT (Path): Root path for MetaGPT files. 
+ +Classes: + ModelsConfig (YamlModel): Configuration class for LLM models. +""" +from pathlib import Path +from typing import Dict, List, Optional + +from pydantic import Field, field_validator + +from metagpt.config2 import merge_dict +from metagpt.configs.llm_config import LLMConfig +from metagpt.const import CONFIG_ROOT, METAGPT_ROOT +from metagpt.utils.yaml_model import YamlModel + + +class ModelsConfig(YamlModel): + """ + Configuration class for `models` in `config2.yaml`. + + Attributes: + models (Dict[str, LLMConfig]): Dictionary mapping model names or types to LLMConfig objects. + + Methods: + update_llm_model(cls, value): Validates and updates LLM model configurations. + from_home(cls, path): Loads configuration from ~/.metagpt/config2.yaml. + default(cls): Loads default configuration from predefined paths. + get(self, name_or_type: str) -> Optional[LLMConfig]: Retrieves LLMConfig by name or API type. + """ + + models: Dict[str, LLMConfig] = Field(default_factory=dict) + + @field_validator("models", mode="before") + @classmethod + def update_llm_model(cls, value): + """ + Validates and updates LLM model configurations. + + Args: + value (Dict[str, Union[LLMConfig, dict]]): Dictionary of LLM configurations. + + Returns: + Dict[str, Union[LLMConfig, dict]]: Updated dictionary of LLM configurations. + """ + for key, config in value.items(): + if isinstance(config, LLMConfig): + config.model = config.model or key + elif isinstance(config, dict): + config["model"] = config.get("model") or key + return value + + @classmethod + def from_home(cls, path): + """ + Loads configuration from ~/.metagpt/config2.yaml. + + Args: + path (str): Relative path to configuration file. + + Returns: + Optional[ModelsConfig]: Loaded ModelsConfig object or None if file doesn't exist. 
+ """ + pathname = CONFIG_ROOT / path + if not pathname.exists(): + return None + return ModelsConfig.from_yaml_file(pathname) + + @classmethod + def default(cls): + """ + Loads default configuration from predefined paths. + + Returns: + ModelsConfig: Default ModelsConfig object. + """ + default_config_paths: List[Path] = [ + METAGPT_ROOT / "config/config2.yaml", + CONFIG_ROOT / "config2.yaml", + ] + + dicts = [ModelsConfig.read_yaml(path) for path in default_config_paths] + final = merge_dict(dicts) + return ModelsConfig(**final) + + def get(self, name_or_type: str) -> Optional[LLMConfig]: + """ + Retrieves LLMConfig object by name or API type. + + Args: + name_or_type (str): Name or API type of the LLM model. + + Returns: + Optional[LLMConfig]: LLMConfig object if found, otherwise None. + """ + if not name_or_type: + return None + model = self.models.get(name_or_type) + if model: + return model + for m in self.models.values(): + if m.api_type == name_or_type: + return m + return None diff --git a/metagpt/rag/parsers/omniparse.py b/metagpt/rag/parsers/omniparse.py index 85227dc06..ec08e38f1 100644 --- a/metagpt/rag/parsers/omniparse.py +++ b/metagpt/rag/parsers/omniparse.py @@ -6,10 +6,9 @@ from typing import List, Optional, Union from llama_index.core import Document from llama_index.core.async_utils import run_jobs from llama_index.core.readers.base import BaseReader -from llama_parse import ResultType from metagpt.logs import logger -from metagpt.rag.schema import OmniParseOptions, OmniParseType +from metagpt.rag.schema import OmniParseOptions, OmniParseType, ParseResultType from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.omniparse_client import OmniParseClient @@ -44,9 +43,9 @@ class OmniParse(BaseReader): self.parse_options.parse_type = parse_type @result_type.setter - def result_type(self, result_type: Union[str, ResultType]): + def result_type(self, result_type: Union[str, ParseResultType]): if isinstance(result_type, str): - 
result_type = ResultType(result_type) + result_type = ParseResultType(result_type) self.parse_options.result_type = result_type async def _aload_data( diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py index 7a640563a..9f5ef8a92 100644 --- a/metagpt/utils/redis.py +++ b/metagpt/utils/redis.py @@ -10,7 +10,7 @@ from __future__ import annotations import traceback from datetime import timedelta -import aioredis # https://aioredis.readthedocs.io/en/latest/getting-started/ +import redis.asyncio as aioredis from metagpt.configs.redis_config import RedisConfig from metagpt.logs import logger diff --git a/requirements.txt b/requirements.txt index dc8a86ae2..4d8d7f32e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -43,7 +43,9 @@ wrapt==1.15.0 #aiohttp_jinja2 # azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py #aioboto3~=12.4.0 # Used by metagpt/utils/s3.py -aioredis~=2.0.1 # Used by metagpt/utils/redis.py +redis~=5.0.0 # Used by metagpt/utils/redis.py +curl-cffi~=0.7.0 +httplib2~=0.22.0 websocket-client~=1.8.0 aiofiles==23.2.1 gitpython==3.1.40 diff --git a/tests/data/config/config2.yaml b/tests/data/config/config2.yaml new file mode 100644 index 000000000..8c9fc0703 --- /dev/null +++ b/tests/data/config/config2.yaml @@ -0,0 +1,27 @@ +llm: + api_type: "openai" # or azure / ollama / groq etc. + base_url: "YOUR_gpt-3.5-turbo_BASE_URL" + api_key: "YOUR_gpt-3.5-turbo_API_KEY" + model: "gpt-3.5-turbo" # or gpt-3.5-turbo + # proxy: "YOUR_gpt-3.5-turbo_PROXY" # for LLM API requests + # timeout: 600 # Optional. If set to 0, default value is 300. + # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ + pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's + +models: + "YOUR_MODEL_NAME_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo + api_type: "openai" # or azure / ollama / groq etc. 
+ base_url: "YOUR_MODEL_1_BASE_URL" + api_key: "YOUR_MODEL_1_API_KEY" + # proxy: "YOUR_MODEL_1_PROXY" # for LLM API requests + # timeout: 600 # Optional. If set to 0, default value is 300. + # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ + pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's + "YOUR_MODEL_NAME_2": # model: "gpt-4-turbo" # or gpt-3.5-turbo + api_type: "openai" # or azure / ollama / groq etc. + base_url: "YOUR_MODEL_2_BASE_URL" + api_key: "YOUR_MODEL_2_API_KEY" + proxy: "YOUR_MODEL_2_PROXY" # for LLM API requests + # timeout: 600 # Optional. If set to 0, default value is 300. + # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ + pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's \ No newline at end of file diff --git a/tests/metagpt/configs/__init__.py b/tests/metagpt/configs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/configs/test_models_config.py b/tests/metagpt/configs/test_models_config.py new file mode 100644 index 000000000..cfbf1f96b --- /dev/null +++ b/tests/metagpt/configs/test_models_config.py @@ -0,0 +1,34 @@ +import pytest + +from metagpt.actions.talk_action import TalkAction +from metagpt.configs.models_config import ModelsConfig +from metagpt.const import METAGPT_ROOT, TEST_DATA_PATH +from metagpt.utils.common import aread, awrite + + +@pytest.mark.asyncio +async def test_models_configs(context): + default_model = ModelsConfig.default() + assert default_model is not None + + models = ModelsConfig.from_yaml_file(TEST_DATA_PATH / "config/config2.yaml") + assert models + + default_models = ModelsConfig.default() + backup = "" + if not default_models.models: + backup = await aread(filename=METAGPT_ROOT / "config/config2.yaml") + test_data = await aread(filename=TEST_DATA_PATH / "config/config2.yaml") + await 
awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=test_data) + + try: + action = TalkAction(context=context, i_context="who are you?", llm_name_or_type="YOUR_MODEL_NAME_1") + assert action.private_llm.config.model == "YOUR_MODEL_NAME_1" + assert context.config.llm.model != "YOUR_MODEL_NAME_1" + finally: + if backup: + await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=backup) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])