diff --git a/config/config2.example.yaml b/config/config2.example.yaml index 64cce630f..0fe11df4e 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -59,3 +59,21 @@ iflytek_api_key: "YOUR_API_KEY" iflytek_api_secret: "YOUR_API_SECRET" metagpt_tti_url: "YOUR_MODEL_URL" + +models: +# "YOUR_MODEL_NAME_1 or YOUR_API_TYPE_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo +# api_type: "openai" # or azure / ollama / groq etc. +# base_url: "YOUR_BASE_URL" +# api_key: "YOUR_API_KEY" +# proxy: "YOUR_PROXY" # for LLM API requests +# # timeout: 600 # Optional. If set to 0, default value is 300. +# # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ +# pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's +# "YOUR_MODEL_NAME_2 or YOUR_API_TYPE_2": # api_type: "openai" # or azure / ollama / groq etc. +# api_type: "openai" # or azure / ollama / groq etc. +# base_url: "YOUR_BASE_URL" +# api_key: "YOUR_API_KEY" +# proxy: "YOUR_PROXY" # for LLM API requests +# # timeout: 600 # Optional. If set to 0, default value is 300. +# # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ +# pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's \ No newline at end of file diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index 1b93213f7..20c052aa9 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -8,12 +8,14 @@ from __future__ import annotations -from typing import Optional, Union +from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, model_validator from metagpt.actions.action_node import ActionNode +from metagpt.configs.models_config import ModelsConfig from metagpt.context_mixin import ContextMixin +from metagpt.provider.llm_provider_registry import create_llm_instance from metagpt.schema import ( CodePlanAndChangeContext, CodeSummarizeContext, @@ -35,6 +37,19 @@ class Action(SerializationMixin, ContextMixin, BaseModel): prefix: str = "" # aask*时会加上prefix,作为system_message desc: str = "" # for skill manager node: ActionNode = Field(default=None, exclude=True) + # The model name or API type of LLM of the `models` in the `config2.yaml`; + # Using `None` to use the `llm` configuration in the `config2.yaml`. + llm_name_or_type: Optional[str] = None + + @model_validator(mode="after") + @classmethod + def _update_private_llm(cls, data: Any) -> Any: + config = ModelsConfig.default().get(data.llm_name_or_type) + if config: + llm = create_llm_instance(config) + llm.cost_manager = data.llm.cost_manager + data.llm = llm + return data @property def repo(self) -> ProjectRepo: diff --git a/metagpt/configs/models_config.py b/metagpt/configs/models_config.py new file mode 100644 index 000000000..bc4897fec --- /dev/null +++ b/metagpt/configs/models_config.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +models_config.py + +This module defines the ModelsConfig class for handling configuration of LLM models. + +Attributes: + CONFIG_ROOT (Path): Root path for configuration files. + METAGPT_ROOT (Path): Root path for MetaGPT files. + +Classes: + ModelsConfig (YamlModel): Configuration class for LLM models. +""" +from pathlib import Path +from typing import Dict, List, Optional + +from pydantic import Field, field_validator + +from metagpt.config2 import merge_dict +from metagpt.configs.llm_config import LLMConfig +from metagpt.const import CONFIG_ROOT, METAGPT_ROOT +from metagpt.utils.yaml_model import YamlModel + + +class ModelsConfig(YamlModel): + """ + Configuration class for `models` in `config2.yaml`. + + Attributes: + models (Dict[str, LLMConfig]): Dictionary mapping model names or types to LLMConfig objects. + + Methods: + update_llm_model(cls, value): Validates and updates LLM model configurations. + from_home(cls, path): Loads configuration from ~/.metagpt/config2.yaml. + default(cls): Loads default configuration from predefined paths. + get(self, name_or_type: str) -> Optional[LLMConfig]: Retrieves LLMConfig by name or API type. + """ + + models: Dict[str, LLMConfig] = Field(default_factory=dict) + + @field_validator("models", mode="before") + @classmethod + def update_llm_model(cls, value): + """ + Validates and updates LLM model configurations. + + Args: + value (Dict[str, Union[LLMConfig, dict]]): Dictionary of LLM configurations. + + Returns: + Dict[str, Union[LLMConfig, dict]]: Updated dictionary of LLM configurations. + """ + for key, config in value.items(): + if isinstance(config, LLMConfig): + config.model = config.model or key + elif isinstance(config, dict): + config["model"] = config.get("model") or key + return value + + @classmethod + def from_home(cls, path): + """ + Loads configuration from ~/.metagpt/config2.yaml. + + Args: + path (str): Relative path to configuration file. + + Returns: + Optional[ModelsConfig]: Loaded ModelsConfig object or None if file doesn't exist. + """ + pathname = CONFIG_ROOT / path + if not pathname.exists(): + return None + return ModelsConfig.from_yaml_file(pathname) + + @classmethod + def default(cls): + """ + Loads default configuration from predefined paths. + + Returns: + ModelsConfig: Default ModelsConfig object. + """ + default_config_paths: List[Path] = [ + METAGPT_ROOT / "config/config2.yaml", + CONFIG_ROOT / "config2.yaml", + ] + + dicts = [ModelsConfig.read_yaml(path) for path in default_config_paths] + final = merge_dict(dicts) + return ModelsConfig(**final) + + def get(self, name_or_type: str) -> Optional[LLMConfig]: + """ + Retrieves LLMConfig object by name or API type. + + Args: + name_or_type (str): Name or API type of the LLM model. + + Returns: + Optional[LLMConfig]: LLMConfig object if found, otherwise None. + """ + if not name_or_type: + return None + model = self.models.get(name_or_type) + if model: + return model + for m in self.models.values(): + if m.api_type == name_or_type: + return m + return None diff --git a/tests/data/config/config2.yaml b/tests/data/config/config2.yaml new file mode 100644 index 000000000..8c9fc0703 --- /dev/null +++ b/tests/data/config/config2.yaml @@ -0,0 +1,27 @@ +llm: + api_type: "openai" # or azure / ollama / groq etc. + base_url: "YOUR_gpt-3.5-turbo_BASE_URL" + api_key: "YOUR_gpt-3.5-turbo_API_KEY" + model: "gpt-3.5-turbo" # or gpt-3.5-turbo + # proxy: "YOUR_gpt-3.5-turbo_PROXY" # for LLM API requests + # timeout: 600 # Optional. If set to 0, default value is 300. + # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ + pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's + +models: + "YOUR_MODEL_NAME_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo + api_type: "openai" # or azure / ollama / groq etc. + base_url: "YOUR_MODEL_1_BASE_URL" + api_key: "YOUR_MODEL_1_API_KEY" + # proxy: "YOUR_MODEL_1_PROXY" # for LLM API requests + # timeout: 600 # Optional. If set to 0, default value is 300. + # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ + pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's + "YOUR_MODEL_NAME_2": # model: "gpt-4-turbo" # or gpt-3.5-turbo + api_type: "openai" # or azure / ollama / groq etc. + base_url: "YOUR_MODEL_2_BASE_URL" + api_key: "YOUR_MODEL_2_API_KEY" + proxy: "YOUR_MODEL_2_PROXY" # for LLM API requests + # timeout: 600 # Optional. If set to 0, default value is 300. + # Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/ + pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's \ No newline at end of file diff --git a/tests/metagpt/configs/__init__.py b/tests/metagpt/configs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/metagpt/configs/test_models_config.py b/tests/metagpt/configs/test_models_config.py new file mode 100644 index 000000000..cfbf1f96b --- /dev/null +++ b/tests/metagpt/configs/test_models_config.py @@ -0,0 +1,34 @@ +import pytest + +from metagpt.actions.talk_action import TalkAction +from metagpt.configs.models_config import ModelsConfig +from metagpt.const import METAGPT_ROOT, TEST_DATA_PATH +from metagpt.utils.common import aread, awrite + + +@pytest.mark.asyncio +async def test_models_configs(context): + default_model = ModelsConfig.default() + assert default_model is not None + + models = ModelsConfig.from_yaml_file(TEST_DATA_PATH / "config/config2.yaml") + assert models + + default_models = ModelsConfig.default() + backup = "" + if not default_models.models: + backup = await aread(filename=METAGPT_ROOT / "config/config2.yaml") + test_data = await aread(filename=TEST_DATA_PATH / "config/config2.yaml") + await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=test_data) + + try: + action = TalkAction(context=context, i_context="who are you?", llm_name_or_type="YOUR_MODEL_NAME_1") + assert action.private_llm.config.model == "YOUR_MODEL_NAME_1" + assert context.config.llm.model != "YOUR_MODEL_NAME_1" + finally: + if backup: + await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=backup) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])