diff --git a/metagpt/context.py b/metagpt/context.py index 3dfd52d58..0add4c71a 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -12,10 +12,14 @@ from typing import Any, Optional from pydantic import BaseModel, ConfigDict from metagpt.config2 import Config -from metagpt.configs.llm_config import LLMConfig +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import create_llm_instance -from metagpt.utils.cost_manager import CostManager +from metagpt.utils.cost_manager import ( + CostManager, + FireworksCostManager, + TokenCostManager, +) from metagpt.utils.git_repository import GitRepository from metagpt.utils.project_repo import ProjectRepo @@ -80,12 +84,21 @@ class Context(BaseModel): # self._llm = None # return self._llm + def _select_costmanager(self, llm_config: LLMConfig) -> CostManager: + """Return a CostManager instance""" + if llm_config.api_type == LLMType.FIREWORKS: + return FireworksCostManager() + elif llm_config.api_type == LLMType.OPEN_LLM: + return TokenCostManager() + else: + return self.cost_manager + def llm(self) -> BaseLLM: """Return a LLM instance, fixme: support cache""" # if self._llm is None: self._llm = create_llm_instance(self.config.llm) if self._llm.cost_manager is None: - self._llm.cost_manager = self.cost_manager + self._llm.cost_manager = self._select_costmanager(self.config.llm) return self._llm def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM: @@ -93,5 +106,5 @@ class Context(BaseModel): # if self._llm is None: llm = create_llm_instance(llm_config) if llm.cost_manager is None: - llm.cost_manager = self.cost_manager + llm.cost_manager = self._select_costmanager(llm_config) return llm