mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
add anthropic_api
This commit is contained in:
parent
0e63b92883
commit
f1f0ae4cc1
20 changed files with 228 additions and 199 deletions
|
|
@ -16,6 +16,7 @@ from metagpt.utils.yaml_model import YamlModel
|
|||
class LLMType(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
CLAUDE = "claude" # alias name of anthropic
|
||||
SPARK = "spark"
|
||||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from metagpt.provider.human_provider import HumanProvider
|
|||
from metagpt.provider.spark_api import SparkLLM
|
||||
from metagpt.provider.qianfan_api import QianFanLLM
|
||||
from metagpt.provider.dashscope_api import DashScopeLLM
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
|
||||
__all__ = [
|
||||
"GeminiLLM",
|
||||
|
|
@ -28,4 +29,5 @@ __all__ = [
|
|||
"SparkLLM",
|
||||
"QianFanLLM",
|
||||
"DashScopeLLM",
|
||||
"AnthropicLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,37 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/21 11:15
|
||||
@Author : Leo Xiao
|
||||
@File : anthropic_api.py
|
||||
"""
|
||||
|
||||
import anthropic
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message, Usage
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
class Claude2:
|
||||
@register_provider([LLMType.ANTHROPIC, LLMType.CLAUDE])
|
||||
class AnthropicLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
self.config = config
|
||||
self.__init_anthropic()
|
||||
|
||||
def ask(self, prompt: str) -> str:
|
||||
client = Anthropic(api_key=self.config.api_key)
|
||||
def __init_anthropic(self):
|
||||
self.model = self.config.model
|
||||
self.aclient: AsyncAnthropic = AsyncAnthropic(api_key=self.config.api_key, base_url=self.config.base_url)
|
||||
|
||||
res = client.completions.create(
|
||||
model="claude-2",
|
||||
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
|
||||
max_tokens_to_sample=1000,
|
||||
)
|
||||
return res.completion
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.config.max_token,
|
||||
"stream": stream,
|
||||
}
|
||||
if self.use_system_prompt:
|
||||
# if the model support system prompt, extract and pass it
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["messages"] = messages[1:]
|
||||
kwargs["system"] = messages[0]["content"] # set system prompt here
|
||||
return kwargs
|
||||
|
||||
async def aask(self, prompt: str) -> str:
|
||||
aclient = AsyncAnthropic(api_key=self.config.api_key)
|
||||
def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
    """Translate an Anthropic ``Usage`` object and hand it to the base-class cost update.

    Args:
        usage: Object exposing ``input_tokens`` and ``output_tokens``.
        model: Forwarded unchanged to the parent ``_update_costs``.
        local_calc_usage: Accepted for signature compatibility; this
            override does not forward it to the parent.
    """
    # The parent is given a dict keyed by prompt/completion token counts.
    token_counts = {
        "prompt_tokens": usage.input_tokens,
        "completion_tokens": usage.output_tokens,
    }
    super()._update_costs(token_counts, model)
|
||||
|
||||
res = await aclient.completions.create(
|
||||
model="claude-2",
|
||||
prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}",
|
||||
max_tokens_to_sample=1000,
|
||||
)
|
||||
return res.completion
|
||||
def get_choice_text(self, resp: Message) -> str:
    """Return the ``text`` of the first content block in ``resp``."""
    first_block = resp.content[0]
    return first_block.text
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> Message:
    """Send one non-streaming request and record its token usage.

    Args:
        messages: Chat messages passed through ``_const_kwargs``.
        timeout: Part of the shared signature; not used in this body.

    Returns:
        The response object returned by ``self.aclient.messages.create``.
    """
    request_kwargs = self._const_kwargs(messages)
    message: Message = await self.aclient.messages.create(**request_kwargs)
    # Costs are updated from the usage reported on the response itself.
    self._update_costs(message.usage, self.model)
    return message
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout: int = 3) -> Message:
    """Public async entry point; delegates to ``_achat_completion`` with the same arguments."""
    message = await self._achat_completion(messages, timeout=timeout)
    return message
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
    """Send a streaming request, log each text chunk, and return the joined text.

    Args:
        messages: Chat messages passed through ``_const_kwargs`` with ``stream=True``.
        timeout: Part of the shared signature; not used in this body.

    Returns:
        All ``content_block_delta`` texts concatenated in arrival order.
    """
    request_kwargs = self._const_kwargs(messages, stream=True)
    event_stream = await self.aclient.messages.create(**request_kwargs)
    text_chunks = []
    token_usage = Usage(input_tokens=0, output_tokens=0)

    async for event in event_stream:
        kind = event.type
        if kind == "content_block_delta":
            chunk = event.delta.text
            log_llm_stream(chunk)
            text_chunks.append(chunk)
        elif kind == "message_start":
            # The opening event carries the initial input/output token counts.
            start_usage = event.message.usage
            token_usage.input_tokens = start_usage.input_tokens
            token_usage.output_tokens = start_usage.output_tokens
        elif kind == "message_delta":
            # Later events overwrite the output count with the final value.
            token_usage.output_tokens = event.usage.output_tokens

    log_llm_stream("\n")
    self._update_costs(token_usage)
    return "".join(text_chunks)
|
||||
|
|
|
|||
|
|
@ -12,10 +12,18 @@ from typing import Optional, Union
|
|||
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
|
||||
|
|
@ -129,6 +137,10 @@ class BaseLLM(ABC):
|
|||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
"""_achat_completion implemented by inherited class"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""Asynchronous version of completion
|
||||
|
|
@ -141,8 +153,22 @@ class BaseLLM(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
"""_achat_completion_stream implemented by inherited class"""
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(min=1, max=60),
    after=after_log(logger, logger.level("WARNING").name),
    retry=retry_if_exception_type(ConnectionError),
    retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream: bool = False, timeout: int = 3) -> str:
    """Return the completion text for ``messages``, streaming when requested.

    The call is retried on ``ConnectionError`` for up to three attempts;
    once attempts are exhausted ``log_and_reraise`` is invoked.

    Args:
        messages: Chat messages forwarded to the provider implementation.
        stream: When true, use ``_achat_completion_stream``; otherwise use
            ``_achat_completion`` and extract the text with ``get_choice_text``.
        timeout: Forwarded to whichever implementation method is called.
    """
    if not stream:
        response = await self._achat_completion(messages, timeout=timeout)
        return self.get_choice_text(response)
    return await self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
"""Required to provide the first text of choice"""
|
||||
|
|
|
|||
|
|
@ -24,18 +24,10 @@ from dashscope.common.error import (
|
|||
ModelRequired,
|
||||
UnsupportedApiProtocol,
|
||||
)
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM, LLMConfig
|
||||
from metagpt.provider.llm_provider_registry import LLMType, register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS
|
||||
|
||||
|
|
@ -210,16 +202,16 @@ class DashScopeLLM(BaseLLM):
|
|||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> GenerationOutput:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> GenerationOutput:
|
||||
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
|
||||
self._check_response(resp)
|
||||
self._update_costs(dict(resp.usage))
|
||||
return resp.output
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
|
||||
return await self._achat_completion(messages)
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
|
|
@ -233,16 +225,3 @@ class DashScopeLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -13,19 +13,11 @@ from google.generativeai.types.generation_types import (
|
|||
GenerateContentResponse,
|
||||
GenerationConfig,
|
||||
)
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
|
||||
|
||||
class GeminiGenerativeModel(GenerativeModel):
|
||||
|
|
@ -95,16 +87,16 @@ class GeminiLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> "AsyncGenerateContentResponse":
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> "AsyncGenerateContentResponse":
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(**self._const_kwargs(messages))
|
||||
usage = await self.aget_usage(messages, resp.text)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
resp: AsyncGenerateContentResponse = await self.llm.generate_content_async(
|
||||
**self._const_kwargs(messages, stream=True)
|
||||
)
|
||||
|
|
@ -119,17 +111,3 @@ class GeminiLLM(BaseLLM):
|
|||
usage = await self.aget_usage(messages, full_content)
|
||||
self._update_costs(usage)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -35,10 +35,16 @@ class HumanProvider(BaseLLM):
|
|||
) -> str:
|
||||
return self.ask(msg, timeout=timeout)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -4,22 +4,12 @@
|
|||
|
||||
import json
|
||||
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import TokenCostManager
|
||||
|
||||
|
||||
|
|
@ -59,7 +49,7 @@ class OllamaLLM(BaseLLM):
|
|||
chunk = chunk.decode(encoding)
|
||||
return json.loads(chunk)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> dict:
|
||||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
|
|
@ -72,9 +62,9 @@ class OllamaLLM(BaseLLM):
|
|||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
stream_resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
|
|
@ -100,17 +90,3 @@ class OllamaLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import json
|
||||
import re
|
||||
from typing import AsyncIterator, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
|
@ -29,8 +29,8 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import CodeParser, decode_image
|
||||
from metagpt.utils.cost_manager import CostManager, Costs, TokenCostManager
|
||||
from metagpt.utils.common import CodeParser, decode_image, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, TokenCostManager
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
count_message_tokens,
|
||||
|
|
@ -39,17 +39,6 @@ from metagpt.utils.token_counter import (
|
|||
)
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
"""
|
||||
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
|
||||
See FAQ 5.8
|
||||
"""
|
||||
)
|
||||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
@register_provider([LLMType.OPENAI, LLMType.FIREWORKS, LLMType.OPEN_LLM, LLMType.MOONSHOT, LLMType.MISTRAL])
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
|
|
|||
|
|
@ -7,19 +7,11 @@ import os
|
|||
import qianfan
|
||||
from qianfan import ChatCompletion
|
||||
from qianfan.resources.typing import JsonBody
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.token_counter import (
|
||||
QIANFAN_ENDPOINT_TOKEN_COSTS,
|
||||
|
|
@ -115,15 +107,15 @@ class QianFanLLM(BaseLLM):
|
|||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> JsonBody:
|
||||
async def _achat_completion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
|
||||
self._update_costs(resp.body.get("usage", {}))
|
||||
return resp.body
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
|
||||
return await self._achat_completion(messages)
|
||||
async def acompletion(self, messages: list[dict], timeout: int = 3) -> JsonBody:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
|
||||
collected_content = []
|
||||
usage = {}
|
||||
|
|
@ -137,16 +129,3 @@ class QianFanLLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -31,12 +31,18 @@ class SparkLLM(BaseLLM):
|
|||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
# 不支持
|
||||
# logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
return w.run()
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages, self.config)
|
||||
|
|
|
|||
|
|
@ -5,21 +5,12 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
after_log,
|
||||
retry,
|
||||
retry_if_exception_type,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
from zhipuai.types.chat.chat_completion import Completion
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import log_llm_stream
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
|
||||
|
|
@ -86,17 +77,3 @@ class ZhiPuAILLM(BaseLLM):
|
|||
self._update_costs(usage)
|
||||
full_content = "".join(collected_content)
|
||||
return full_content
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
after=after_log(logger, logger.level("WARNING").name),
|
||||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
resp = await self._achat_completion(messages)
|
||||
return self.get_choice_text(resp)
|
||||
|
|
|
|||
|
|
@ -676,3 +676,14 @@ def decode_image(img_url_or_b64: str) -> Image:
|
|||
img_data = BytesIO(base64.b64decode(b64_data))
|
||||
img = Image.open(img_data)
|
||||
return img
|
||||
|
||||
|
||||
def log_and_reraise(retry_state: RetryCallState):
    """Log the final failure of a retried call and re-raise its last exception.

    Used as a ``retry_error_callback``: logs the exception held by
    ``retry_state.outcome`` at error level, logs a pointer to the FAQ at
    warning level, then raises that exception.
    """
    last_exception = retry_state.outcome.exception()
    logger.error(f"Retry attempts exhausted. Last exception: {last_exception}")
    logger.warning(
        """
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
See FAQ 5.8
"""
    )
    raise last_exception
|
||||
|
|
|
|||
|
|
@ -43,6 +43,11 @@ TOKEN_COSTS = {
|
|||
"mistral-small-latest": {"prompt": 0.002, "completion": 0.006},
|
||||
"mistral-medium-latest": {"prompt": 0.0027, "completion": 0.0081},
|
||||
"mistral-large-latest": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-instant-1.2": {"prompt": 0.0008, "completion": 0.0024},
|
||||
"claude-2.0": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-2.1": {"prompt": 0.008, "completion": 0.024},
|
||||
"claude-3-sonnet-20240229": {"prompt": 0.003, "completion": 0.015},
|
||||
"claude-3-opus-20240229": {"prompt": 0.015, "completion": 0.075},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -167,6 +172,11 @@ TOKEN_MAX = {
|
|||
"mistral-small-latest": 32768,
|
||||
"mistral-medium-latest": 32768,
|
||||
"mistral-large-latest": 32768,
|
||||
"claude-instant-1.2": 100000,
|
||||
"claude-2.0": 100000,
|
||||
"claude-2.1": 200000,
|
||||
"claude-3-sonnet-20240229": 200000,
|
||||
"claude-3-opus-20240229": 200000,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ tqdm==4.65.0
|
|||
#unstructured[local-inference]
|
||||
# selenium>4
|
||||
# webdriver_manager<3.9
|
||||
anthropic==0.8.1
|
||||
anthropic==0.18.1
|
||||
typing-inspect==0.8.0
|
||||
libcst==1.0.1
|
||||
qdrant-client==1.7.0
|
||||
|
|
|
|||
|
|
@ -56,3 +56,7 @@ mock_llm_config_spark = LLMConfig(
|
|||
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
|
||||
|
||||
mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max")
|
||||
|
||||
mock_llm_config_anthropic = LLMConfig(
|
||||
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,14 @@
|
|||
# @Desc : default request & response data for provider unittest
|
||||
|
||||
|
||||
from anthropic.types import (
|
||||
ContentBlock,
|
||||
ContentBlockDeltaEvent,
|
||||
Message,
|
||||
MessageStartEvent,
|
||||
TextDelta,
|
||||
)
|
||||
from anthropic.types import Usage as AnthropicUsage
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
DashScopeAPIResponse,
|
||||
GenerationOutput,
|
||||
|
|
@ -130,6 +138,38 @@ def get_dashscope_response(name: str) -> GenerationResponse:
|
|||
)
|
||||
|
||||
|
||||
# For Anthropic
|
||||
def get_anthropic_response(
    name: str, stream: bool = False
) -> "Message | list[MessageStartEvent | ContentBlockDeltaEvent]":
    """Build a canned Anthropic response for provider unit tests.

    Args:
        name: Model name; used as the ``model`` field and to fill ``resp_cont_tmpl``.
        stream: Selects which response shape is returned.

    Returns:
        When ``stream`` is false, a single ``Message`` whose only content block
        holds the templated text. When ``stream`` is true, a list of stream
        events: a ``MessageStartEvent`` with an empty text block, followed by
        one ``ContentBlockDeltaEvent`` carrying the templated text.
    """
    # The annotation is a string so it is not evaluated at definition time.
    if stream:
        return [
            MessageStartEvent(
                message=Message(
                    id="xxx",
                    model=name,
                    role="assistant",
                    type="message",
                    content=[ContentBlock(text="", type="text")],
                    usage=AnthropicUsage(input_tokens=10, output_tokens=10),
                ),
                type="message_start",
            ),
            ContentBlockDeltaEvent(
                index=0,
                delta=TextDelta(text=resp_cont_tmpl.format(name=name), type="text_delta"),
                type="content_block_delta",
            ),
        ]
    return Message(
        id="xxx",
        model=name,
        role="assistant",
        type="message",
        content=[ContentBlock(text=resp_cont_tmpl.format(name=name), type="text")],
        usage=AnthropicUsage(input_tokens=10, output_tokens=10),
    )
|
||||
|
||||
|
||||
# For llm general chat functions call
|
||||
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
|
||||
resp = await llm.aask(prompt, stream=False)
|
||||
|
|
|
|||
|
|
@ -2,31 +2,45 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Claude2
|
||||
|
||||
|
||||
import pytest
|
||||
from anthropic.resources.completions import Completion
|
||||
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
|
||||
from metagpt.provider.anthropic_api import AnthropicLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_anthropic
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_anthropic_response,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Claude")
|
||||
name = "claude-3-opus-20240229"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
|
||||
|
||||
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
async def mock_anthropic_messages_create(
|
||||
self, messages: list[dict], model: str, stream: bool = True, max_tokens: int = None, system: str = None
|
||||
) -> Completion:
|
||||
if stream:
|
||||
|
||||
async def aresp_iterator():
|
||||
resps = get_anthropic_response(name, stream=True)
|
||||
for resp in resps:
|
||||
yield resp
|
||||
|
||||
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
|
||||
assert resp_cont == Claude2(mock_llm_config).ask(prompt)
|
||||
return aresp_iterator()
|
||||
else:
|
||||
return get_anthropic_response(name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
|
||||
assert resp_cont == await Claude2(mock_llm_config).aask(prompt)
|
||||
async def test_anthropic_acompletion(mocker):
|
||||
mocker.patch("anthropic.resources.messages.AsyncMessages.create", mock_anthropic_messages_create)
|
||||
|
||||
anthropic_llm = AnthropicLLM(mock_llm_config_anthropic)
|
||||
|
||||
resp = await anthropic_llm.acompletion(messages)
|
||||
assert resp.content[0].text == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(anthropic_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -27,9 +27,15 @@ class MockBaseLLM(BaseLLM):
|
|||
def completion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3):
|
||||
pass
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(name)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = 3) -> str:
|
||||
pass
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
return default_resp_cont
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Optional, Union
|
|||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -24,17 +24,8 @@ class MockLLM(OriginalLLM):
|
|||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""Overwrite original acompletion_text to cancel retry"""
|
||||
if stream:
|
||||
resp = self._achat_completion_stream(messages, timeout=timeout)
|
||||
|
||||
collected_messages = []
|
||||
async for i in resp:
|
||||
log_llm_stream(i)
|
||||
collected_messages.append(i)
|
||||
|
||||
full_reply_content = "".join(collected_messages)
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
resp = await self._achat_completion_stream(messages, timeout=timeout)
|
||||
return resp
|
||||
|
||||
rsp = await self._achat_completion(messages, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue