diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index f39e708eb..f1590a77c 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -4,6 +4,7 @@ @Time : 2023/5/5 23:04 @Author : alexanderwu @File : base_gpt_api.py +@Desc : mashenquan, 2023/8/22. + try catch """ from abc import abstractmethod from typing import Optional @@ -41,7 +42,11 @@ class BaseGPTAPI(BaseChatbot): message = self._system_msgs(system_msgs) + [self._user_msg(msg)] else: message = [self._default_system_msg(), self._user_msg(msg)] - rsp = await self.acompletion_text(message, stream=True) + try: + rsp = await self.acompletion_text(message, stream=True) + except Exception as e: + logger.exception(f"{e}") + raise e logger.debug(message) # logger.debug(rsp) return rsp diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 2e951b36f..abfb796f3 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -8,10 +8,11 @@ """ import asyncio import time -from typing import NamedTuple +from typing import NamedTuple, Dict import openai from openai.error import APIConnectionError +from pydantic import BaseModel from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type from metagpt.logs import logger @@ -35,7 +36,7 @@ class RateLimiter: self.rpm = rpm def split_batches(self, batch): - return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)] + return [batch[i: i + self.rpm] for i in range(0, len(batch), self.rpm)] async def wait_if_needed(self, num_requests): current_time = time.time() @@ -56,14 +57,14 @@ class Costs(NamedTuple): total_budget: float -class CostManager: +class CostManager(BaseModel): """计算使用接口的开销""" - def __init__(self, options): - self.total_prompt_tokens = 0 - self.total_completion_tokens = 0 - self.options = options - self.total_budget = 0 + total_prompt_tokens: int = 0 + total_completion_tokens: int = 0 + total_budget: int = 0 + max_budget: int + total_cost: int = 0 def update_cost(self, prompt_tokens, completion_tokens, model): """ @@ -76,7 +77,8 @@ class CostManager: """ self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens - cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000 + cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model][ + "completion"]) / 1000 self.total_cost += cost logger.info( f"Total running cost: ${self.total_cost:.3f} | Max budget: ${self.max_budget:.3f} | " @@ -114,18 +116,6 @@ class CostManager: """获得所有开销""" return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) - @property - def total_cost(self): - return self.options.get("total_cost", 0) - - @total_cost.setter - def total_cost(self, v): - self.options["total_cost"] = v - - @property - def max_budget(self): - return self.options.get("max_budget", 0) - def log_and_reraise(retry_state): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") diff --git a/metagpt/schema.py b/metagpt/schema.py index f45d1e36d..56e9ad95c 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -9,7 +9,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from enum import StrEnum +from enum import Enum from typing import Type, TypedDict, Set, Optional from pydantic import BaseModel @@ -17,7 +17,7 @@ from pydantic import BaseModel from metagpt.logs import logger -class MessageTag(StrEnum): +class MessageTag(Enum): Prerequisite = "prerequisite" diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 3f6f484b4..87b24a1cb 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -30,7 +30,7 @@ class SoftwareCompany(BaseModel): investment: float = Field(default=10.0) idea: str = Field(default="") options: Dict = Field(default=Config().runtime_options) - cost_manager: CostManager = Field(default=CostManager(Config().runtime_options)) + cost_manager: CostManager = Field(default=CostManager(**Config().runtime_options)) class Config: arbitrary_types_allowed = True diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 04216ad7c..9861fd4cd 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -19,7 +19,7 @@ from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE async def test_write_code(): api_design = "设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。" conf = Config() - cost_manager = CostManager(conf.runtime_options) + cost_manager = CostManager(**conf.runtime_options) llm = LLM(options=conf.runtime_options, cost_manager=cost_manager) write_code = WriteCode(options=conf.runtime_options, name="write_code", llm=llm) @@ -35,6 +35,6 @@ async def test_write_code(): async def test_write_code_directly(): prompt = WRITE_CODE_PROMPT_SAMPLE + '\n' + TASKS_2[0] options = Config().runtime_options - llm = LLM(options=options, cost_manager=CostManager(options=options)) + llm = LLM(options=options, cost_manager=CostManager(**options)) rsp = await llm.aask(prompt) logger.info(rsp) diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index d10c93ec0..57650d145 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -26,7 +26,7 @@ def env(): def test_add_role(env: Environment): conf = Config() - cost_manager = CostManager(options=conf.runtime_options) + cost_manager = CostManager(**conf.runtime_options) role = ProductManager(options=conf.runtime_options, cost_manager=cost_manager, name="Alice", @@ -39,7 +39,7 @@ def test_add_role(env: Environment): def test_get_roles(env: Environment): conf = Config() - cost_manager = CostManager(options=conf.runtime_options) + cost_manager = CostManager(**conf.runtime_options) role1 = Role(options=conf.runtime_options, cost_manager=cost_manager, name="Alice", profile="product manager", goal="create a new product", constraints="limited resources") role2 = Role(options=conf.runtime_options, cost_manager=cost_manager, name="Bob", profile="engineer", @@ -53,7 +53,7 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): conf = Config() - cost_manager = CostManager(options=conf.runtime_options) + cost_manager = CostManager(**conf.runtime_options) product_manager = ProductManager(options=conf.runtime_options, cost_manager=cost_manager, name="Alice", profile="Product Manager", diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index 77de6df0c..f61793151 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -16,7 +16,7 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as LLM, CostManager @pytest.fixture() def llm(): options = Config().runtime_options - return LLM(options=options, cost_manager=CostManager(options)) + return LLM(options=options, cost_manager=CostManager(**options)) @pytest.mark.asyncio