From f1f0ae4cc1e09cd269082e1069e26264c240573c Mon Sep 17 00:00:00 2001
From: better629
Date: Tue, 5 Mar 2024 13:58:51 +0800
Subject: [PATCH] add anthropic_api

---
 metagpt/configs/llm_config.py                 |  1 +
 metagpt/provider/__init__.py                  |  2 +
 metagpt/provider/anthropic_api.py             | 84 ++++++++++++++------
 metagpt/provider/base_llm.py                  | 28 ++++++-
 metagpt/provider/dashscope_api.py             | 29 +------
 metagpt/provider/google_gemini_api.py         | 30 +------
 metagpt/provider/human_provider.py            |  6 ++
 metagpt/provider/ollama_api.py                | 32 +-------
 metagpt/provider/openai_api.py                | 17 +---
 metagpt/provider/qianfan_api.py               | 31 ++------
 metagpt/provider/spark_api.py                 |  6 ++
 metagpt/provider/zhipuai_api.py               | 25 +----
 metagpt/utils/common.py                       | 11 +++
 metagpt/utils/token_counter.py                | 10 +++
 requirements.txt                              |  2 +-
 tests/metagpt/provider/mock_llm_config.py     |  4 +
 tests/metagpt/provider/req_resp_const.py      | 40 ++++++++++
 tests/metagpt/provider/test_anthropic_api.py  | 48 +++++++----
 tests/metagpt/provider/test_base_llm.py       |  6 ++
 tests/mock/mock_llm.py                        | 15 +---
 20 files changed, 228 insertions(+), 199 deletions(-)

diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py
index a07a87cdf..66a68cc9d 100644
--- a/metagpt/configs/llm_config.py
+++ b/metagpt/configs/llm_config.py
@@ -16,6 +16,7 @@ from metagpt.utils.yaml_model import YamlModel
 class LLMType(Enum):
     OPENAI = "openai"
     ANTHROPIC = "anthropic"
+    CLAUDE = "claude"  # alias of anthropic
     SPARK = "spark"
     ZHIPUAI = "zhipuai"
     FIREWORKS = "fireworks"
diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py
index ed49d01c9..14d5e7682 100644
--- a/metagpt/provider/__init__.py
+++ b/metagpt/provider/__init__.py
@@ -16,6 +16,7 @@ from metagpt.provider.human_provider import HumanProvider
 from metagpt.provider.spark_api import SparkLLM
 from metagpt.provider.qianfan_api import QianFanLLM
 from metagpt.provider.dashscope_api import DashScopeLLM
+from metagpt.provider.anthropic_api import AnthropicLLM

 __all__ = [
     "GeminiLLM",
@@ -28,4 +29,5 @@ __all__ = [
     "SparkLLM",
     "QianFanLLM",
     "DashScopeLLM",
+    "AnthropicLLM",
 ]
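Review note: with the CLAUDE alias registered above, either api_type value selects the same provider once anthropic_api.py (next file) registers itself for both enum members. A minimal sketch of selecting it programmatically; it mirrors mock_llm_config_anthropic later in this patch, and the api_key is a placeholder:

    from metagpt.configs.llm_config import LLMConfig

    config = LLMConfig(
        api_type="claude",  # or "anthropic"; both resolve to AnthropicLLM
        api_key="sk-...",  # placeholder
        base_url="https://api.anthropic.com",
        model="claude-3-opus-20240229",
    )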
diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py
index f31c2d04d..872f9b2c7 100644
--- a/metagpt/provider/anthropic_api.py
+++ b/metagpt/provider/anthropic_api.py
@@ -1,37 +1,71 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-"""
-@Time    : 2023/7/21 11:15
-@Author  : Leo Xiao
-@File    : anthropic_api.py
-"""
-import anthropic
-from anthropic import Anthropic, AsyncAnthropic
+from anthropic import AsyncAnthropic
+from anthropic.types import Message, Usage

-from metagpt.configs.llm_config import LLMConfig
+from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.logs import log_llm_stream
+from metagpt.provider.base_llm import BaseLLM
+from metagpt.provider.llm_provider_registry import register_provider


-class Claude2:
+@register_provider([LLMType.ANTHROPIC, LLMType.CLAUDE])
+class AnthropicLLM(BaseLLM):
     def __init__(self, config: LLMConfig):
         self.config = config
+        self.__init_anthropic()

-    def ask(self, prompt: str) -> str:
-        client = Anthropic(api_key=self.config.api_key)
+    def __init_anthropic(self):
+        self.model = self.config.model
+        self.aclient: AsyncAnthropic = AsyncAnthropic(api_key=self.config.api_key, base_url=self.config.base_url)

-        res = client.completions.create(
-            model="claude-2",
-            prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
-            max_tokens_to_sample=1000,
-        )
-        return res.completion
+    def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
+        kwargs = {
+            "model": self.model,
+            "messages": messages,
+            "max_tokens": self.config.max_token,
+            "stream": stream,
+        }
+        if self.use_system_prompt:
+            # if the model supports system prompts, extract it and pass it via the dedicated field
+            if messages[0]["role"] == "system":
+                kwargs["messages"] = messages[1:]
+                kwargs["system"] = messages[0]["content"]  # set the system prompt here
+        return kwargs

-    async def aask(self, prompt: str) -> str:
-        aclient = AsyncAnthropic(api_key=self.config.api_key)
+    def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
+        usage = {"prompt_tokens": usage.input_tokens, "completion_tokens": usage.output_tokens}
+        super()._update_costs(usage, model)

-        res = await aclient.completions.create(
-            model="claude-2",
-            prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
-            max_tokens_to_sample=1000,
-        )
-        return res.completion
+    def get_choice_text(self, resp: Message) -> str:
+        return resp.content[0].text
+
+    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
+        resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))
+        self._update_costs(resp.usage, self.model)
+        return resp
+
+    async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
+        return await self._achat_completion(messages, timeout=timeout)
+
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
+        collected_content = []
+        usage = Usage(input_tokens=0, output_tokens=0)
+        async for event in stream:
+            event_type = event.type
+            if event_type == "message_start":
+                usage.input_tokens = event.message.usage.input_tokens
+                usage.output_tokens = event.message.usage.output_tokens
+            elif event_type == "content_block_delta":
+                content = event.delta.text
+                log_llm_stream(content)
+                collected_content.append(content)
+            elif event_type == "message_delta":
+                usage.output_tokens = event.usage.output_tokens  # update the final output_tokens
+
+        log_llm_stream("\n")
+        self._update_costs(usage)
+        full_content = "".join(collected_content)
+        return full_content
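Review note: a minimal end-to-end sketch of the new provider, assuming a valid key; the message list and asyncio driver are illustrative, not part of this patch:

    import asyncio

    from metagpt.configs.llm_config import LLMConfig
    from metagpt.provider.anthropic_api import AnthropicLLM

    async def main():
        config = LLMConfig(
            api_type="anthropic",
            api_key="sk-...",  # placeholder
            base_url="https://api.anthropic.com",
            model="claude-3-opus-20240229",
        )
        llm = AnthropicLLM(config)
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},  # moved into the `system` kwarg by _const_kwargs
            {"role": "user", "content": "Say hi."},
        ]
        # stream=True exercises _achat_completion_stream and the event handling above
        print(await llm.acompletion_text(messages, stream=True))

    asyncio.run(main())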
diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py
index 7cf3faac0..da6acf09c 100644
--- a/metagpt/provider/base_llm.py
+++ b/metagpt/provider/base_llm.py
@@ -12,10 +12,18 @@ from typing import Optional, Union

 from openai import AsyncOpenAI
 from pydantic import BaseModel
+from tenacity import (
+    after_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)

 from metagpt.configs.llm_config import LLMConfig
 from metagpt.logs import logger
 from metagpt.schema import Message
+from metagpt.utils.common import log_and_reraise
 from metagpt.utils.cost_manager import CostManager, Costs


@@ -129,6 +137,10 @@ class BaseLLM(ABC):
         """FIXME: No code segment filtering has been done here, and all results are actually displayed"""
         raise NotImplementedError

+    @abstractmethod
+    async def _achat_completion(self, messages: list[dict], timeout=3):
+        """_achat_completion implemented by inherited class"""
+
     @abstractmethod
     async def acompletion(self, messages: list[dict], timeout=3):
         """Asynchronous version of completion
@@ -141,8 +153,22 @@ class BaseLLM(ABC):
         """

     @abstractmethod
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        """_achat_completion_stream implemented by inherited class"""
+
+    @retry(
+        stop=stop_after_attempt(3),
+        wait=wait_random_exponential(min=1, max=60),
+        after=after_log(logger, logger.level("WARNING").name),
+        retry=retry_if_exception_type(ConnectionError),
+        retry_error_callback=log_and_reraise,
+    )
+    async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
         """Asynchronous version of completion. Return str. Support stream-print"""
+        if stream:
+            return await self._achat_completion_stream(messages, timeout=timeout)
+        resp = await self._achat_completion(messages, timeout=timeout)
+        return self.get_choice_text(resp)

     def get_choice_text(self, rsp: dict) -> str:
         """Required to provide the first text of choice"""
diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py
index f2b3a19a1..21f3ef351 100644
--- a/metagpt/provider/dashscope_api.py
+++ b/metagpt/provider/dashscope_api.py
@@ -24,18 +24,10 @@ from dashscope.common.error import (
     ModelRequired,
     UnsupportedApiProtocol,
 )
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)

-from metagpt.logs import log_llm_stream, logger
+from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM, LLMConfig
 from metagpt.provider.llm_provider_registry import LLMType, register_provider
-from metagpt.provider.openai_api import log_and_reraise
 from metagpt.utils.cost_manager import CostManager
 from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS

@@ -210,16 +202,16 @@ class DashScopeLLM(BaseLLM):
         self._update_costs(dict(resp.usage))
         return resp.output

-    async def _achat_completion(self, messages: list[dict]) -> GenerationOutput:
+    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
         resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
         self._check_response(resp)
         self._update_costs(dict(resp.usage))
         return resp.output

     async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
-        return await self._achat_completion(messages)
+        return await self._achat_completion(messages, timeout=timeout)

-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
         resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
         collected_content = []
         usage = {}
@@ -233,16 +225,3 @@
         self._update_costs(usage)
         full_content = "".join(collected_content)
         return full_content
-
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_random_exponential(min=1, max=60),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(ConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        if stream:
-            return await self._achat_completion_stream(messages)
-        resp = await self._achat_completion(messages)
-        return self.get_choice_text(resp)
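Review note: hoisting the tenacity-decorated acompletion_text into BaseLLM is what lets DashScopeLLM above, and the providers below, delete their per-class copies. A toy demonstration of the shared policy (illustrative names; the waits are the production values, so the two retries sleep for a moment):

    import asyncio

    from tenacity import (
        retry,
        retry_if_exception_type,
        stop_after_attempt,
        wait_random_exponential,
    )

    attempts = 0

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(min=1, max=60),
        retry=retry_if_exception_type(ConnectionError),
    )
    async def flaky():
        global attempts
        attempts += 1
        if attempts < 3:
            raise ConnectionError("transient failure")  # triggers a retry
        return "ok"

    print(asyncio.run(flaky()))  # prints "ok" on the third attempt; other exception types are not retried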
diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py
index 87ea81c80..1e8fe2be5 100644
--- a/metagpt/provider/google_gemini_api.py
+++ b/metagpt/provider/google_gemini_api.py
@@ -13,19 +13,11 @@ from google.generativeai.types.generation_types import (
     GenerateContentResponse,
     GenerationConfig,
 )
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)

 from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import log_llm_stream, logger
+from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import log_and_reraise


 class GeminiGenerativeModel(GenerativeModel):
@@ -95,16 +87,16 @@ class GeminiLLM(BaseLLM):
         self._update_costs(usage)
         return resp

-    async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
+    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
         usage = await self.aget_usage(messages, resp.text)
         self._update_costs(usage)
         return resp

     async def acompletion(self, messages: list[dict], timeout=3) -> dict:
-        return await self._achat_completion(messages)
+        return await self._achat_completion(messages, timeout=timeout)

-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
         resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
             **self._const_kwargs(messages, stream=True)
         )
@@ -119,17 +111,3 @@
         usage = await self.aget_usage(messages, full_content)
         self._update_costs(usage)
         return full_content
-
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_random_exponential(min=1, max=60),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(ConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        """response in async with stream or non-stream mode"""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        resp = await self._achat_completion(messages)
-        return self.get_choice_text(resp)
diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py
index fe000b3a6..e5f37c5b9 100644
--- a/metagpt/provider/human_provider.py
+++ b/metagpt/provider/human_provider.py
@@ -35,10 +35,16 @@ class HumanProvider(BaseLLM):
     ) -> str:
         return self.ask(msg, timeout=timeout)

+    async def _achat_completion(self, messages: list[dict], timeout=3):
+        pass
+
     async def acompletion(self, messages: list[dict], timeout=3):
         """dummy implementation of abstract method in base"""
         return []

+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        pass
+
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """dummy implementation of abstract method in base"""
         return ""
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
index 52e8dbe36..cc7cc12fc 100644
--- a/metagpt/provider/ollama_api.py
+++ b/metagpt/provider/ollama_api.py
@@ -4,22 +4,12 @@

 import json

-from requests import ConnectionError
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
-
 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import LLM_API_TIMEOUT
-from metagpt.logs import log_llm_stream, logger
+from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.general_api_requestor import GeneralAPIRequestor
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import log_and_reraise
 from metagpt.utils.cost_manager import TokenCostManager

@@ -59,7 +49,7 @@ class OllamaLLM(BaseLLM):
             chunk = chunk.decode(encoding)
         return json.loads(chunk)

-    async def _achat_completion(self, messages: list[dict]) -> dict:
+    async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
@@ -72,9 +62,9 @@ class OllamaLLM(BaseLLM):
         return resp

     async def acompletion(self, messages: list[dict], timeout=3) -> dict:
-        return await self._achat_completion(messages)
+        return await self._achat_completion(messages, timeout=timeout)

-    async def _achat_completion_stream(self, messages: list[dict]) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
         stream_resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.suffix_url,
@@ -100,17 +90,3 @@
         self._update_costs(usage)
         full_content = "".join(collected_content)
         return full_content
-
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_random_exponential(min=1, max=60),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(ConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
-        """response in async with stream or non-stream mode"""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        resp = await self._achat_completion(messages)
-        return self.get_choice_text(resp)
Last exception: {retry_state.outcome.exception()}") - logger.warning( - """ -Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ -See FAQ 5.8 -""" - ) - raise retry_state.outcome.exception() - - @register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL]) class OpenAILLM(BaseLLM): """Check https://platform.openai.com/examples for examples""" diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py index 4cbb76566..50916fa3e 100644 --- a/metagpt/provider/qianfan_api.py +++ b/metagpt/provider/qianfan_api.py @@ -7,19 +7,11 @@ import os import qianfan from qianfan import ChatCompletion from qianfan.resources.typing import JsonBody -from tenacity import ( - after_log, - retry, - retry_if_exception_type, - stop_after_attempt, - wait_random_exponential, -) from metagpt.configs.llm_config import LLMConfig, LLMType -from metagpt.logs import log_llm_stream, logger +from metagpt.logs import log_llm_stream from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import log_and_reraise from metagpt.utils.cost_manager import CostManager from metagpt.utils.token_counter import ( QIANFAN_ENDPOINT_TOKEN_COSTS, @@ -115,15 +107,15 @@ class QianFanLLM(BaseLLM): self._update_costs(resp.body.get("usage", {})) return resp.body - async def _achat_completion(self, messages: list[dict]) -> JsonBody: + async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody: resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False)) self._update_costs(resp.body.get("usage", {})) return resp.body - async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody: - return await self._achat_completion(messages) + async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody: + return await self._achat_completion(messages, timeout=timeout) - async def _achat_completion_stream(self, messages: list[dict]) -> str: + async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str: resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True)) collected_content = [] usage = {} @@ -137,16 +129,3 @@ class QianFanLLM(BaseLLM): self._update_costs(usage) full_content = "".join(collected_content) return full_content - - @retry( - stop=stop_after_attempt(3), - wait=wait_random_exponential(min=1, max=60), - after=after_log(logger, logger.level("WARNING").name), - retry=retry_if_exception_type(ConnectionError), - retry_error_callback=log_and_reraise, - ) - async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: - if stream: - return await self._achat_completion_stream(messages) - resp = await self._achat_completion(messages) - return self.get_choice_text(resp) diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 5e89c26d5..882c6ce85 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -31,12 +31,18 @@ class SparkLLM(BaseLLM): def get_choice_text(self, rsp: dict) -> str: return rsp["payload"]["choices"]["text"][-1]["content"] + async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str: + pass + async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: # 不支持 # logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") w = GetMessageFromWeb(messages, 
diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py
index 5e89c26d5..882c6ce85 100644
--- a/metagpt/provider/spark_api.py
+++ b/metagpt/provider/spark_api.py
@@ -31,12 +31,18 @@ class SparkLLM(BaseLLM):
     def get_choice_text(self, rsp: dict) -> str:
         return rsp["payload"]["choices"]["text"][-1]["content"]

+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        pass
+
     async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         # not supported
         # logger.warning("This method cannot run asynchronously; calls via acompletion do not execute in parallel.")
         w = GetMessageFromWeb(messages, self.config)
         return w.run()

+    async def _achat_completion(self, messages: list[dict], timeout=3):
+        pass
+
     async def acompletion(self, messages: list[dict], timeout=3):
         # async is not supported
         w = GetMessageFromWeb(messages, self.config)
diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py
index 4cbee4038..546d2f269 100644
--- a/metagpt/provider/zhipuai_api.py
+++ b/metagpt/provider/zhipuai_api.py
@@ -5,21 +5,12 @@
 from enum import Enum
 from typing import Optional

-from requests import ConnectionError
-from tenacity import (
-    after_log,
-    retry,
-    retry_if_exception_type,
-    stop_after_attempt,
-    wait_random_exponential,
-)
 from zhipuai.types.chat.chat_completion import Completion

 from metagpt.configs.llm_config import LLMConfig, LLMType
-from metagpt.logs import log_llm_stream, logger
+from metagpt.logs import log_llm_stream
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import log_and_reraise
 from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
 from metagpt.utils.cost_manager import CostManager

@@ -86,17 +77,3 @@ class ZhiPuAILLM(BaseLLM):
         self._update_costs(usage)
         full_content = "".join(collected_content)
         return full_content
-
-    @retry(
-        stop=stop_after_attempt(3),
-        wait=wait_random_exponential(min=1, max=60),
-        after=after_log(logger, logger.level("WARNING").name),
-        retry=retry_if_exception_type(ConnectionError),
-        retry_error_callback=log_and_reraise,
-    )
-    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
-        """response in async with stream or non-stream mode"""
-        if stream:
-            return await self._achat_completion_stream(messages)
-        resp = await self._achat_completion(messages)
-        return self.get_choice_text(resp)
Last exception: {retry_state.outcome.exception()}") + logger.warning( + """ +Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ +See FAQ 5.8 +""" + ) + raise retry_state.outcome.exception() diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index c20caa8e1..54c9f3610 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -43,6 +43,11 @@ TOKEN_COSTS = { "mistral-small-latest": {"prompt": 0.002, "completion": 0.006}, "mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081}, "mistral-large-latest": {"prompt": 0.008, "completion": 0.024}, + "claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024}, + "claude-2.0": {"prompt": 0.008, "completion": 0.024}, + "claude-2.1": {"prompt": 0.008, "completion": 0.024}, + "claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015}, + "claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075}, } @@ -167,6 +172,11 @@ TOKEN_MAX = { "mistral-small-latest": 32768, "mistral-medium-latest": 32768, "mistral-large-latest": 32768, + "claude-instant-1.2": 100000, + "claude-2.0": 100000, + "claude-2.1": 200000, + "claude-3-sonnet-20240229": 200000, + "claude-3-opus-20240229": 200000, } diff --git a/requirements.txt b/requirements.txt index cf50cf255..a009464dd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,7 +33,7 @@ tqdm==4.65.0 #unstructured[local-inference] # selenium>4 # webdriver_manager<3.9 -anthropic==0.8.1 +anthropic==0.18.1 typing-inspect==0.8.0 libcst==1.0.1 qdrant-client==1.7.0 diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index e75acf68f..0c56cc8ea 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ b/tests/metagpt/provider/mock_llm_config.py @@ -56,3 +56,7 @@ mock_llm_config_spark = LLMConfig( mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo") mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max") + +mock_llm_config_anthropic = LLMConfig( + api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229" +) diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py index 802962013..7e4c1a49c 100644 --- a/tests/metagpt/provider/req_resp_const.py +++ b/tests/metagpt/provider/req_resp_const.py @@ -3,6 +3,14 @@ # @Desc : default request & response data for provider unittest +from anthropic.types import ( + ContentBlock, + ContentBlockDeltaEvent, + Message, + MessageStartEvent, + TextDelta, +) +from anthropic.types import Usage as AnthropicUsage from dashscope.api_entities.dashscope_response import ( DashScopeAPIResponse, GenerationOutput, @@ -130,6 +138,38 @@ def get_dashscope_response(name: str) -> GenerationResponse: ) +# For Anthropic +def get_anthropic_response(name: str, stream: bool = False) -> Message: + if stream: + return [ + MessageStartEvent( + message=Message( + id="xxx", + model=name, + role="assistant", + type="message", + content=[ContentBlock(text="", type="text")], + usage=AnthropicUsage(input_tokens=10, output_tokens=10), + ), + type="message_start", + ), + ContentBlockDeltaEvent( + index=0, + delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"), + type="content_block_delta", + ), + ] + else: + return Message( + id="xxx", + model=name, + role="assistant", + type="message", + 
diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py
index 802962013..7e4c1a49c 100644
--- a/tests/metagpt/provider/req_resp_const.py
+++ b/tests/metagpt/provider/req_resp_const.py
@@ -3,6 +3,15 @@
 # @Desc   : default request & response data for provider unittest

+from typing import Union
+from anthropic.types import (
+    ContentBlock,
+    ContentBlockDeltaEvent,
+    Message,
+    MessageStartEvent,
+    TextDelta,
+)
+from anthropic.types import Usage as AnthropicUsage
 from dashscope.api_entities.dashscope_response import (
     DashScopeAPIResponse,
     GenerationOutput,
@@ -130,6 +139,38 @@ def get_dashscope_response(name: str) -> GenerationResponse:
     )


+# For Anthropic
+def get_anthropic_response(name: str, stream: bool = False) -> Union[Message, list]:
+    if stream:
+        return [
+            MessageStartEvent(
+                message=Message(
+                    id="xxx",
+                    model=name,
+                    role="assistant",
+                    type="message",
+                    content=[ContentBlock(text="", type="text")],
+                    usage=AnthropicUsage(input_tokens=10, output_tokens=10),
+                ),
+                type="message_start",
+            ),
+            ContentBlockDeltaEvent(
+                index=0,
+                delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
+                type="content_block_delta",
+            ),
+        ]
+    else:
+        return Message(
+            id="xxx",
+            model=name,
+            role="assistant",
+            type="message",
+            content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
+            usage=AnthropicUsage(input_tokens=10, output_tokens=10),
+        )
+
+
 # For llm general chat functions call
 async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
     resp = await llm.aask(prompt, stream=False)
diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py
index 93cfd7dbc..b8a3fe289 100644
--- a/tests/metagpt/provider/test_anthropic_api.py
+++ b/tests/metagpt/provider/test_anthropic_api.py
@@ -2,31 +2,45 @@
 # -*- coding: utf-8 -*-
-# @Desc   : the unittest of Claude2
+# @Desc   : the unittest of AnthropicLLM

-
 import pytest
-from anthropic.resources.completions import Completion

-from metagpt.provider.anthropic_api import Claude2
-from tests.metagpt.provider.mock_llm_config import mock_llm_config
-from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
+from metagpt.provider.anthropic_api import AnthropicLLM
+from tests.metagpt.provider.mock_llm_config import mock_llm_config_anthropic
+from tests.metagpt.provider.req_resp_const import (
+    get_anthropic_response,
+    llm_general_chat_funcs_test,
+    messages,
+    prompt,
+    resp_cont_tmpl,
+)

-resp_cont = resp_cont_tmpl.format(name="Claude")
+name = "claude-3-opus-20240229"
+resp_cont = resp_cont_tmpl.format(name=name)


-def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
-    return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
+async def mock_anthropic_messages_create(
+    self, messages: list[dict], model: str, stream: bool = True, max_tokens: int = None, system: str = None
+):
+    # returns a Message directly, or an async iterator of stream events when stream=True
+    if stream:
+
+        async def aresp_iterator():
+            resps = get_anthropic_response(name, stream=True)
+            for resp in resps:
+                yield resp

-async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
-    return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
-
-
-def test_claude2_ask(mocker):
-    mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
-    assert resp_cont == Claude2(mock_llm_config).ask(prompt)
+        return aresp_iterator()
+    else:
+        return get_anthropic_response(name)


 @pytest.mark.asyncio
-async def test_claude2_aask(mocker):
-    mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
-    assert resp_cont == await Claude2(mock_llm_config).aask(prompt)
+async def test_anthropic_acompletion(mocker):
+    mocker.patch("anthropic.resources.messages.AsyncMessages.create", mock_anthropic_messages_create)
+
+    anthropic_llm = AnthropicLLM(mock_llm_config_anthropic)
+
+    resp = await anthropic_llm.acompletion(messages)
+    assert resp.content[0].text == resp_cont
+
+    await llm_general_chat_funcs_test(anthropic_llm, prompt, messages, resp_cont)
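Review note: the stream fixture above mirrors the event order the real API emits and _achat_completion_stream consumes: a message_start carrying the initial Usage, then content_block_delta text chunks (the final message_delta event is not simulated here). A sketch of replaying the fixtures by hand, assuming it runs from the repo root so the tests package is importable:

    from tests.metagpt.provider.req_resp_const import get_anthropic_response

    msg = get_anthropic_response("claude-3-opus-20240229")
    print(msg.content[0].text, msg.usage)  # full reply plus input/output token counts

    for event in get_anthropic_response("claude-3-opus-20240229", stream=True):
        print(event.type)  # message_start, content_block_delta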
diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py
index cf44343bc..bff8dbde4 100644
--- a/tests/metagpt/provider/test_base_llm.py
+++ b/tests/metagpt/provider/test_base_llm.py
@@ -27,9 +27,15 @@ class MockBaseLLM(BaseLLM):
     def completion(self, messages: list[dict], timeout=3):
         return get_part_chat_completion(name)

+    async def _achat_completion(self, messages: list[dict], timeout=3):
+        pass
+
     async def acompletion(self, messages: list[dict], timeout=3):
         return get_part_chat_completion(name)

+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
+        pass
+
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         return default_resp_cont
diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py
index 50e75dabf..b2052e2b3 100644
--- a/tests/mock/mock_llm.py
+++ b/tests/mock/mock_llm.py
@@ -3,7 +3,7 @@ from typing import Optional, Union

 from metagpt.config2 import config
 from metagpt.configs.llm_config import LLMType
-from metagpt.logs import log_llm_stream, logger
+from metagpt.logs import logger
 from metagpt.provider.azure_openai_api import AzureOpenAILLM
 from metagpt.provider.openai_api import OpenAILLM
 from metagpt.schema import Message
@@ -24,17 +24,8 @@ class MockLLM(OriginalLLM):
     async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """Overwrite original acompletion_text to cancel retry"""
         if stream:
-            resp = self._achat_completion_stream(messages, timeout=timeout)
-
-            collected_messages = []
-            async for i in resp:
-                log_llm_stream(i)
-                collected_messages.append(i)
-
-            full_reply_content = "".join(collected_messages)
-            usage = self._calc_usage(messages, full_reply_content)
-            self._update_costs(usage)
-            return full_reply_content
+            resp = await self._achat_completion_stream(messages, timeout=timeout)
+            return resp

         rsp = await self._achat_completion(messages, timeout=timeout)
         return self.get_choice_text(rsp)