diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index 326d23a5c..03802a716 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -4,22 +4,17 @@ @Time : 2023/7/21 11:15 @Author : Leo Xiao @File : anthropic_api.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation; - Change cost control from global to company level. """ import anthropic from anthropic import Anthropic -from metagpt.config import Config +from metagpt.config import CONFIG class Claude2: - def __init__(self, options=None): - self.options = options or Config().runtime_options - def ask(self, prompt): - client = Anthropic(api_key=self.claude_api_key) + client = Anthropic(api_key=CONFIG.claude_api_key) res = client.completions.create( model="claude-2", @@ -29,7 +24,7 @@ class Claude2: return res.completion async def aask(self, prompt): - client = Anthropic(api_key=self.claude_api_key) + client = Anthropic(api_key=CONFIG.claude_api_key) res = client.completions.create( model="claude-2", @@ -37,7 +32,3 @@ class Claude2: max_tokens_to_sample=1000, ) return res.completion - - @property - def claude_api_key(self): - return self.options.get("claude_api_key") diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 098388a7c..640694b67 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -18,6 +18,7 @@ from openai.error import APIConnectionError from pydantic import BaseModel from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type +from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.utils.token_counter import ( @@ -134,23 +135,22 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): Check https://platform.openai.com/examples for examples """ - def __init__(self, options, cost_manager): - self._options = options - self.__init_openai() + def __init__(self, cost_manager): + self.__init_openai(CONFIG) self.llm = openai - self.model = self.openai_api_model + self.model = CONFIG.openai_api_model self.auto_max_tokens = False - self._cost_manager = cost_manager + self._cost_manager = cost_manager or CostManager() RateLimiter.__init__(self, rpm=self.rpm) - def __init_openai(self): - openai.api_key = self.openai_api_key - if self.openai_api_base: - openai.api_base = self.openai_api_base - if self.openai_api_type: - openai.api_type = self.openai_api_type - openai.api_version = self.openai_api_version - self.rpm = int(self._options.get("RPM", 10)) + def __init_openai(self, config): + openai.api_key = config.openai_api_key + if config.openai_api_base: + openai.api_base = config.openai_api_base + if config.openai_api_type: + openai.api_type = config.openai_api_type + openai.api_version = config.openai_api_version + self.rpm = int(config.get("RPM", 10)) async def _achat_completion_stream(self, messages: list[dict]) -> str: response = await self.async_retry_call(openai.ChatCompletion.acreate, @@ -175,9 +175,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return full_reply_content def _cons_kwargs(self, messages: list[dict]) -> dict: - if self._options.get("openai_api_type") == "azure": + if CONFIG.openai_api_type == "azure": kwargs = { - "deployment_id": self._options.get("deployment_id"), + "deployment_id": CONFIG.deployment_id, "messages": messages, "max_tokens": self.get_max_tokens(messages), "n": 1, @@ -232,7 +232,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def _calc_usage(self, messages: list[dict], rsp: str) -> dict: usage = {} - if self._options.get("calc_usage"): + if CONFIG.calc_usage: try: prompt_tokens = count_message_tokens(messages, self.model) completion_tokens = count_string_tokens(rsp, self.model) @@ -271,7 +271,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return results def _update_costs(self, usage: dict): - if self._options.get("calc_usage"): + if CONFIG.calc_usage: try: prompt_tokens = int(usage['prompt_tokens']) completion_tokens = int(usage['completion_tokens']) @@ -284,34 +284,14 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: - return self._options.get("max_tokens_rsp") - return get_max_completion_tokens(messages, self.model, self._options.get("max_tokens_rsp")) - - @property - def openai_api_model(self): - return self._options.get("openai_api_model") - - @property - def openai_api_key(self): - return self._options.get("openai_api_key") - - @property - def openai_api_base(self): - return self._options.get("openai_api_base") - - @property - def openai_api_type(self): - return self._options.get("openai_api_type") - - @property - def openai_api_version(self): - return self._options.get("openai_api_version") + return CONFIG.max_tokens_rsp + return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) async def get_summary(self, text: str, max_words=20): """Generate text summary""" if len(text) < max_words: return text - language = self._options.get("language", "English") + language = CONFIG.language or self.DEFAULT_LANGUAGE command = f"Translate the above content into a {language} summary of less than {max_words} words." msg = text + "\n\n" + command logger.info(f"summary ask:{msg}") @@ -322,7 +302,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): async def get_context_title(self, text: str, max_token_count_per_ask=None, max_words=5) -> str: """Generate text title""" max_response_token_count = 50 - max_token_count = max_token_count_per_ask or self._options.get("MAX_TOKENS", 1500) + max_token_count = max_token_count_per_ask or CONFIG.MAX_TOKENS or 1500 text_windows = self.split_texts(text, window_size=max_token_count - max_response_token_count) summaries = [] @@ -332,7 +312,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): if len(summaries) == 1: return summaries[0] - language = self._options.get("language", "English") + language = CONFIG.language or self.DEFAULT_LANGUAGE command = f"Translate the above summary into a {language} title of less than {max_words} words." summaries.append(command) msg = "\n".join(summaries) @@ -418,3 +398,4 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): raise openai.error.OpenAIError("Exceeds the maximum retries") MAX_TRY = 5 + DEFAULT_LANGUAGE = "Engilish"