Merge branch 'main' into code_interpreter

This commit is contained in:
garylin2099 2024-02-08 00:14:24 +08:00 committed by GitHub
commit a570c81ccf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 632 additions and 356 deletions

View file

@ -54,7 +54,6 @@ jobs:
export ALLOW_OPENAI_API_CALL=0
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -31,7 +31,7 @@ jobs:
- name: Test with pytest
run: |
export ALLOW_OPENAI_API_CALL=0
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |

View file

@ -1,5 +1,5 @@
llm:
api_type: "openai"
api_type: "openai" # or azure / ollama etc.
base_url: "YOUR_BASE_URL"
api_key: "YOUR_API_KEY"
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview

View file

@ -6,16 +6,25 @@
@File : llm_hello_world.py
"""
import asyncio
from pathlib import Path
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.common import encode_image
async def main():
llm = LLM()
logger.info(await llm.aask("hello world"))
# llm type check
id_ques = "what's your name"
logger.info(f"{id_ques}: ")
logger.info(await llm.aask(id_ques))
logger.info("\n\n")
logger.info(
await llm.aask(
"who are you", system_msgs=["act as a robot, answer 'I'am robot' if the question is 'who are you'"]
)
)
logger.info(await llm.aask_batch(["hi", "write python hello world."]))
hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]
@ -29,12 +38,6 @@ async def main():
if hasattr(llm, "completion"):
logger.info(llm.completion(hello_msg))
# check if the configured llm supports llm-vision capacity. If not, it will throw a error
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
img_base64 = encode_image(invoice_path)
res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
assert "true" in res.lower()
if __name__ == "__main__":
asyncio.run(main())

23 examples/llm_vision.py Normal file
View file

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : example to run the ability of LLM vision
import asyncio
from pathlib import Path
from metagpt.llm import LLM
from metagpt.utils.common import encode_image
async def main():
llm = LLM()
# check if the configured llm supports llm-vision capability; if not, it will throw an error
invoice_path = Path(__file__).parent.joinpath("..", "tests", "data", "invoices", "invoice-2.png")
img_base64 = encode_image(invoice_path)
res = await llm.aask(msg="if this is a invoice, just return True else return False", images=[img_base64])
assert "true" in res.lower()
if __name__ == "__main__":
asyncio.run(main())

View file

@ -133,7 +133,7 @@ class CollectLinks(Action):
if len(remove) == 0:
break
model_name = config.get_openai_llm().model
model_name = config.llm.model
prompt = reduce_message_length(gen_msg(), model_name, system_text, 4096)
logger.debug(prompt)
queries = await self._aask(prompt, [system_text])

View file

@ -24,6 +24,7 @@ class LLMType(Enum):
METAGPT = "metagpt"
AZURE = "azure"
OLLAMA = "ollama"
QIANFAN = "qianfan" # Baidu BCE
def __missing__(self, key):
return self.OPENAI
@ -36,13 +37,18 @@ class LLMConfig(YamlModel):
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
"""
api_key: str
api_key: str = "sk-"
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
# For cloud service providers like Baidu / Alibaba
access_key: Optional[str] = None
secret_key: Optional[str] = None
endpoint: Optional[str] = None # for self-deployed model on the cloud
# For Spark (Xunfei); may be removed later
app_id: Optional[str] = None
api_secret: Optional[str] = None
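A hedged sketch (not part of the diff) of the new provider fields in use; the values are placeholders, and the Spark `domain` field is taken from the mock config further below:

from metagpt.configs.llm_config import LLMConfig

# QianFan: system-level auth via access_key / secret_key, plus a model name
qianfan_cfg = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")

# Spark (Xunfei): app_id / api_key / api_secret triple over a websocket endpoint
spark_cfg = LLMConfig(
    api_type="spark",
    app_id="xxx",
    api_key="xxx",
    api_secret="xxx",
    domain="generalv2",
    base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)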

View file

@ -16,6 +16,7 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
__all__ = [
"FireworksLLM",
@ -28,4 +29,5 @@ __all__ = [
"OllamaLLM",
"HumanProvider",
"SparkLLM",
"QianFanLLM",
]

View file

@ -11,11 +11,12 @@ from abc import ABC, abstractmethod
from typing import Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs
class BaseLLM(ABC):
@ -67,6 +68,28 @@ class BaseLLM(ABC):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
"""update each request's token cost
Args:
model (str): model name or in some scenarios called endpoint
local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage
"""
calc_usage = self.config.calc_usage and local_calc_usage
model = model if model else self.model
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
if calc_usage and self.cost_manager:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
except Exception as e:
logger.error(f"{self.__class__.__name__} updats costs failed! exp: {e}")
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
async def aask(
self,
msg: str,
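A minimal usage sketch (ours) of the consolidated cost flow: provider subclasses now hand a plain usage dict, or a pydantic model, to the shared `_update_costs` and read totals back via `get_costs`. `SomeProviderLLM` is hypothetical, and `config.calc_usage` is assumed True:

from metagpt.utils.cost_manager import CostManager

llm = SomeProviderLLM(config)  # hypothetical BaseLLM subclass
llm.cost_manager = CostManager()
llm._update_costs({"prompt_tokens": 92, "completion_tokens": 110}, model="gpt-3.5-turbo")
print(llm.get_costs())  # Costs(total_prompt_tokens=92, total_completion_tokens=110, ...)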

View file

@ -19,7 +19,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM):
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
@ -113,7 +102,7 @@ class FireworksLLM(OpenAILLM):
usage = CompletionUsage(**chunk.usage)
full_content = "".join(collected_content)
self._update_costs(usage)
self._update_costs(usage.model_dump())
return full_content
@retry(

View file

@ -60,7 +60,8 @@ class GeneralAPIRequestor(APIRequestor):
self, result: requests.Response, stream: bool
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
"""Returns the response(s) and a bool indicating whether it is a stream."""
if stream and "text/event-stream" in result.headers.get("Content-Type", ""):
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
return (
self._interpret_response_line(line, result.status_code, result.headers, stream=True)
for line in parse_stream(result.iter_lines())

View file

@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM):
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text

View file

@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"ollama updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})

View file

@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.cost_manager import TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM):
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()

View file

@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -56,16 +56,13 @@ class OpenAILLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
@ -268,24 +265,16 @@ class OpenAILLM(BaseLLM):
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
except Exception as e:
logger.error(f"usage calculation failed: {e}")
logger.warning(f"usage calculation failed: {e}")
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return self.config.max_token
return get_max_completion_tokens(messages, self.model, self.config.max_tokens)
# FIXME
# https://community.openai.com/t/why-is-gpt-3-5-turbo-1106-max-tokens-limited-to-4096/494973/3
return min(get_max_completion_tokens(messages, self.model, self.config.max_token), 4096)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
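Our reading of the FIXME above, as a small worked example: the -1106 models accept long contexts but cap completion tokens at 4096, so the computed headroom is clamped (numbers are hypothetical):

available = 16385 - 2500           # context window minus prompt tokens
max_tokens = min(available, 4096)  # -> 4096, the value _get_max_tokens now returns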

View file

@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : LLM API of QianFan from Baidu; supports ERNIE (WenXin YiYan) and open-source models
import copy
import os
import qianfan
from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import (
QianFan_EndPoint_TOKEN_COSTS,
QianFan_MODEL_TOKEN_COSTS,
)
@register_provider(LLMType.QIANFAN)
class QianFanLLM(BaseLLM):
"""
Refs
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
"""
def __init__(self, config: LLMConfig):
self.config = config
self.use_system_prompt = False # only some ERNIE models support system_prompt
self.__init_qianfan()
self.cost_manager = CostManager(token_costs=self.token_costs)
def __init_qianfan(self):
if self.config.access_key and self.config.secret_key:
# for system-level auth, use access_key and secret_key (recommended by the official docs)
# the qianfan SDK reads credentials from environment variables, so set them here
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
elif self.config.api_key and self.config.secret_key:
# for application-level auth, use api_key and secret_key
# set environment variables, as the official SDK recommends
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
else:
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
support_system_pairs = [
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
("ERNIE-Bot-8k", "ernie_bot_8k"),
("ERNIE-Bot", "completions"),
("ERNIE-Bot-turbo", "eb-instant"),
("ERNIE-Speed", "ernie_speed"),
("EB-turbo-AppBuilder", "ai_apaas"),
]
if self.config.model in [pair[0] for pair in support_system_pairs]:
# only some ERNIE models support system prompts
self.use_system_prompt = True
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
self.use_system_prompt = True
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
self.token_costs = copy.deepcopy(QianFan_MODEL_TOKEN_COSTS)
self.token_costs.update(QianFan_EndPoint_TOKEN_COSTS)
# self-deployed models on the cloud don't calculate usage; they charge a resource-pool rental fee instead
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
self.aclient: ChatCompletion = qianfan.ChatCompletion()
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
"messages": messages,
"stream": stream,
}
if self.config.temperature > 0:
# each model has its own default temperature; only set it when explicitly specified
kwargs["temperature"] = self.config.temperature
if self.config.endpoint:
kwargs["endpoint"] = self.config.endpoint
elif self.config.model:
kwargs["model"] = self.config.model
if self.use_system_prompt:
# if the model supports a system prompt, extract it and pass it separately
if messages[0]["role"] == "system":
kwargs["messages"] = messages[1:]
kwargs["system"] = messages[0]["content"] # set system prompt here
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
model_or_endpoint = self.config.model if self.config.model else self.config.endpoint
local_calc_usage = model_or_endpoint in self.token_costs
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
def get_choice_text(self, resp: JsonBody) -> str:
return resp.get("result", "")
def completion(self, messages: list[dict]) -> JsonBody:
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def _achat_completion(self, messages: list[dict]) -> JsonBody:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
collected_content = []
usage = {}
async for chunk in resp:
content = chunk.body.get("result", "")
usage = chunk.body.get("usage", {})
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)
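A hedged usage sketch for the new provider, mirroring the mock config in the tests below (placeholder credentials; live network access assumed):

import asyncio

from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.qianfan_api import QianFanLLM

config = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
llm = QianFanLLM(config)
# non-streaming completion; token costs are tracked by the provider's CostManager
print(asyncio.run(llm.acompletion_text([{"role": "user", "content": "hi"}], stream=False)))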

View file

@ -3,9 +3,8 @@
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
from enum import Enum
from typing import Optional
import openai
import zhipuai
from requests import ConnectionError
from tenacity import (
after_log,
@ -14,6 +13,7 @@ from tenacity import (
stop_after_attempt,
wait_random_exponential,
)
from zhipuai.types.chat.chat_completion import Completion
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
@ -21,6 +21,7 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
from metagpt.utils.cost_manager import CostManager
class ZhiPuEvent(Enum):
@ -38,37 +39,22 @@ class ZhiPuAILLM(BaseLLM):
"""
def __init__(self, config: LLMConfig):
self.__init_zhipuai(config)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
self.config = config
self.__init_zhipuai()
self.cost_manager: Optional[CostManager] = None
def __init_zhipuai(self, config: LLMConfig):
assert config.api_key
zhipuai.api_key = config.api_key
# due to use openai sdk, set the api_key but it will't be used.
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
if config.proxy:
# FIXME: openai v1.x sdk has no proxy support
openai.proxy = config.proxy
def __init_zhipuai(self):
assert self.config.api_key
self.api_key = self.config.api_key
self.model = self.config.model # so far, it supports glm-3-turbo and glm-4
self.llm = ZhiPuModelAPI(api_key=self.api_key)
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.config.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.chat.completions.create(**self._const_kwargs(messages))
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
self._update_costs(usage)
return resp.model_dump()

View file

@ -29,6 +29,7 @@ class CostManager(BaseModel):
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
@ -41,8 +42,13 @@ class CostManager(BaseModel):
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
if model not in self.token_costs:
logger.warning(f"Model {model} not found in token_costs.")
return
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
prompt_tokens * self.token_costs[model]["prompt"]
+ completion_tokens * self.token_costs[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
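A minimal sketch of the new `token_costs` field: providers with non-OpenAI pricing can inject their own table, as QianFanLLM does above with `CostManager(token_costs=...)`. The model name here is hypothetical:

from metagpt.utils.cost_manager import CostManager

cm = CostManager(token_costs={"my-model": {"prompt": 0.002, "completion": 0.004}})
cm.update_cost(prompt_tokens=1500, completion_tokens=500, model="my-model")
# cost = (1500 * 0.002 + 500 * 0.004) / 1000 = 0.005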

View file

@ -25,7 +25,7 @@ def reduce_message_length(
"""
max_token = TOKEN_MAX.get(model_name, 2048) - count_string_tokens(system_text, model_name) - reserved
for msg in msgs:
if count_string_tokens(msg, model_name) < max_token:
if count_string_tokens(msg, model_name) < max_token or model_name not in TOKEN_MAX:
return msg
raise RuntimeError("fail to reduce message length")
@ -93,7 +93,7 @@ def split_paragraph(paragraph: str, sep: str = ".,", count: int = 2) -> list[str
continue
ret = ["".join(j) for j in _split_by_count(sentences, count)]
return ret
return _split_by_count(paragraph, count)
return list(_split_by_count(paragraph, count))
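A sketch (ours) of the fix's effect: the fallback branch used to leak a generator while the separator branch returned a list; both now return list[str]:

from metagpt.utils.common import split_paragraph

parts = split_paragraph("abcdef", sep=".,", count=2)  # no '.' or ',' in the input
assert isinstance(parts, list) and len(parts) == 2    # e.g. ["abc", "def"], assuming even chunking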
def decode_unicode_escape(text: str) -> str:

View file

@ -32,12 +32,65 @@ TOKEN_COSTS = {
"gpt-4-vision-preview": {"prompt": 0.01, "completion": 0.03}, # TODO add extra image price calculator
"gpt-4-1106-vision-preview": {"prompt": 0.01, "completion": 0.03},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"glm-3-turbo": {"prompt": 0.0007, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.014, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
}
"""
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Due to QianFan has multi price strategies, we unify `Tokens post-payment` as a statistical method.
"""
QianFan_MODEL_TOKEN_COSTS = {
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
"ERNIE-Bot": {"prompt": 0.017, "completion": 0.017},
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
}
QianFan_EndPoint_TOKEN_COSTS = {
"completions_pro": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
"ernie_bot_8k": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
"completions": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot"],
"eb-instant": QianFan_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
"ai_apaas": QianFan_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
"ernie_speed": QianFan_MODEL_TOKEN_COSTS["ERNIE-Speed"],
"bloomz_7b1": QianFan_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
"llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
"llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
"llama_2_70b": QianFan_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
"chatglm2_6b_32k": QianFan_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
"aquilachat_7b": QianFan_MODEL_TOKEN_COSTS["AquilaChat-7B"],
"mixtral_8x7b_instruct": QianFan_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
"sqlcoder_7b": QianFan_MODEL_TOKEN_COSTS["SQLCoder-7B"],
"codellama_7b_instruct": QianFan_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
"xuanyuan_70b_chat": QianFan_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
"qianfan_bloomz_7b_compressed": QianFan_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
"qianfan_chinese_llama_2_7b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
"qianfan_chinese_llama_2_13b": QianFan_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
"chatlaw": QianFan_MODEL_TOKEN_COSTS["ChatLaw"],
"yi_34b_chat": QianFan_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,
@ -58,7 +111,8 @@ TOKEN_MAX = {
"gpt-4-vision-preview": 128000,
"gpt-4-1106-vision-preview": 128000,
"text-embedding-ada-002": 8192,
"chatglm_turbo": 32768,
"glm-3-turbo": 128000,
"glm-4": 128000,
"gemini-pro": 32768,
}

View file

@ -63,8 +63,9 @@ gitignore-parser==0.1.9
websockets~=12.0
networkx~=3.2.1
google-generativeai==0.3.2
# playwright==1.40.0 # playwright extras require
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
anytree
ipywidgets==8.1.1
Pillow
imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.1

View file

@ -24,7 +24,6 @@ requirements = (here / "requirements.txt").read_text(encoding="utf-8").splitline
extras_require = {
"playwright": ["playwright>=1.26", "beautifulsoup4"],
"selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"],
"search-google": ["google-api-python-client==2.94.0"],
"search-ddg": ["duckduckgo-search~=4.1.1"],

View file

@ -14,6 +14,7 @@ from metagpt.actions.rebuild_class_view import RebuildClassView
from metagpt.llm import LLM
@pytest.mark.skip
@pytest.mark.asyncio
async def test_rebuild(context):
action = RebuildClassView(

View file

@ -176,6 +176,7 @@ class Snake:
"""
@pytest.mark.skip
@pytest.mark.asyncio
async def test_summarize_code(context):
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"

View file

@ -42,3 +42,15 @@ mock_llm_config_zhipu = LLMConfig(
model="mock_zhipu_model",
proxy="http://localhost:8080",
)
mock_llm_config_spark = LLMConfig(
api_type="spark",
app_id="xxx",
api_key="xxx",
api_secret="xxx",
domain="generalv2",
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")

View file

@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : default request & response data for provider unittest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from qianfan.resources.typing import QfResponse
from metagpt.provider.base_llm import BaseLLM
prompt = "who are you?"
messages = [{"role": "user", "content": prompt}]
resp_cont_tmpl = "I'm {name}"
default_resp_cont = resp_cont_tmpl.format(name="GPT")
# part of a full OpenAI-style ChatCompletion structure
def get_part_chat_completion(name: str) -> dict:
part_chat_completion = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": resp_cont_tmpl.format(name=name),
},
"finish_reason": "stop",
}
],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
return part_chat_completion
def get_openai_chat_completion(name: str) -> ChatCompletion:
openai_chat_completion = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="xx/xxx",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
return openai_chat_completion
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
usage = usage if not usage_as_dict else usage.model_dump()
openai_chat_completion_chunk = ChatCompletionChunk(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="xx/xxx",
object="chat.completion.chunk",
created=1703300855,
choices=[
AChoice(
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=usage,
)
return openai_chat_completion_chunk
# For gemini
gemini_messages = [{"role": "user", "parts": prompt}]
# For QianFan
qf_jsonbody_dict = {
"id": "as-4v1h587fyv",
"object": "chat.completion",
"created": 1695021339,
"result": "",
"is_truncated": False,
"need_clear_history": False,
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
}
def get_qianfan_response(name: str) -> QfResponse:
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
return QfResponse(code=200, body=qf_jsonbody_dict)
# For general LLM chat function calls, shared across provider tests
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
resp = await llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await llm.aask(prompt)
assert resp == resp_cont
resp = await llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await llm.acompletion_text(messages, stream=True)
assert resp == resp_cont

View file

@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion
from metagpt.provider.anthropic_api import Claude2
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
prompt = "who are you"
resp = "I'am Claude2"
resp_cont = resp_cont_tmpl.format(name="Claude")
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
def test_claude2_ask(mocker):
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp == Claude2(mock_llm_config).ask(prompt)
assert resp_cont == Claude2(mock_llm_config).ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp == await Claude2(mock_llm_config).aask(prompt)
assert resp_cont == await Claude2(mock_llm_config).aask(prompt)

View file

@ -11,21 +11,13 @@ import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
get_part_chat_completion,
prompt,
)
default_chat_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'am GPT",
},
"finish_reason": "stop",
}
]
}
prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
name = "GPT"
class MockBaseLLM(BaseLLM):
@ -33,16 +25,13 @@ class MockBaseLLM(BaseLLM):
pass
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(name)
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(name)
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return resp_content
async def close(self):
return default_chat_resp
return default_resp_cont
def test_base_llm():
@ -86,25 +75,25 @@ def test_base_llm():
choice_text = base_llm.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
# resp = base_llm.ask(prompt_msg)
# assert resp == resp_content
# resp = base_llm.ask(prompt)
# assert resp == default_resp_cont
# resp = base_llm.ask_batch([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_batch([prompt])
# assert resp == default_resp_cont
# resp = base_llm.ask_code([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_code([prompt])
# assert resp == default_resp_cont
@pytest.mark.asyncio
async def test_async_base_llm():
base_llm = MockBaseLLM()
resp = await base_llm.aask(prompt_msg)
assert resp == resp_content
resp = await base_llm.aask(prompt)
assert resp == default_resp_cont
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_llm.aask_batch([prompt])
assert resp == default_resp_cont
# resp = await base_llm.aask_code([prompt_msg])
# assert resp == resp_content
# resp = await base_llm.aask_code([prompt])
# assert resp == default_resp_cont

View file

@ -3,14 +3,7 @@
# @Desc : the unittest of fireworks api
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import (
@ -20,42 +13,19 @@ from metagpt.provider.fireworks_api import (
)
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=dict(default_resp.usage),
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
name = "fireworks"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
def test_fireworks_costmanager():
@ -88,27 +58,17 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_fireworks_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
fireworks_gpt = FireworksLLM(mock_llm_config)
fireworks_gpt.model = "llama-v2-13b-chat"
fireworks_llm = FireworksLLM(mock_llm_config)
fireworks_llm.model = "llama-v2-13b-chat"
fireworks_gpt._update_costs(
fireworks_llm._update_costs(
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
)
assert fireworks_gpt.get_costs() == Costs(
assert fireworks_llm.get_costs() == Costs(
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
)
resp = await fireworks_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await fireworks_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await fireworks_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont)

View file

@ -11,6 +11,12 @@ from google.generativeai.types import content_types
from metagpt.provider.google_gemini_api import GeminiLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
gemini_messages,
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
@dataclass
@ -18,10 +24,8 @@ class MockGeminiResponse(ABC):
text: str
prompt_msg = "who are you"
messages = [{"role": "user", "parts": prompt_msg}]
resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
resp_cont = resp_cont_tmpl.format(name="gemini")
default_resp = MockGeminiResponse(text=resp_cont)
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
@ -60,28 +64,18 @@ async def test_gemini_acompletion(mocker):
mock_gemini_generate_content_async,
)
gemini_gpt = GeminiLLM(mock_llm_config)
gemini_llm = GeminiLLM(mock_llm_config)
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
usage = gemini_gpt.get_usage(messages, resp_content)
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_gpt.completion(messages)
resp = gemini_llm.completion(gemini_messages)
assert resp == default_resp
resp = await gemini_gpt.acompletion(messages)
resp = await gemini_llm.acompletion(gemini_messages)
assert resp.text == default_resp.text
resp = await gemini_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)

View file

@ -9,12 +9,15 @@ import pytest
from metagpt.provider.ollama_api import OllamaLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
resp_cont = resp_cont_tmpl.format(name="ollama")
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
@ -41,19 +44,12 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
async def test_ollama_acompletion(mocker):
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_gpt = OllamaLLM(mock_llm_config)
ollama_llm = OllamaLLM(mock_llm_config)
resp = await ollama_gpt.acompletion(messages)
resp = await ollama_llm.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await ollama_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await ollama_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)

View file

@ -3,53 +3,26 @@
# @Desc :
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import Costs
from metagpt.utils.cost_manager import CostManager, Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm llama2"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="llama-v2-13b-chat",
object="chat.completion",
created=1703302755,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
)
name = "llama2-7b"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
default_resp_chunk = get_openai_chat_completion_chunk(name)
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
@ -68,25 +41,16 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_openllm_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
openllm_gpt = OpenLLM(mock_llm_config)
openllm_gpt.model = "llama-v2-13b-chat"
openllm_llm = OpenLLM(mock_llm_config)
openllm_llm.model = "llama-v2-13b-chat"
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_gpt.get_costs() == Costs(
openllm_llm.cost_manager = CostManager()
openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_llm.get_costs() == Costs(
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
)
resp = await openllm_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await openllm_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await openllm_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await openllm_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont)

View file

@ -0,0 +1,56 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of qianfan api
from typing import AsyncIterator, Union
import pytest
from qianfan.resources.typing import JsonBody, QfResponse
from metagpt.provider.qianfan_api import QianFanLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
from tests.metagpt.provider.req_resp_const import (
get_qianfan_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "ERNIE-Bot-turbo"
resp_cont = resp_cont_tmpl.format(name=name)
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
return get_qianfan_response(name=name)
async def mock_qianfan_ado(
self, messages: list[dict], model: str, stream: bool = True, system: str = None
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
resps = [get_qianfan_response(name=name)]
if stream:
async def aresp_iterator(resps: list[JsonBody]):
for resp in resps:
yield resp
return aresp_iterator(resps)
else:
return resps[0]
@pytest.mark.asyncio
async def test_qianfan_acompletion(mocker):
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
resp = qianfan_llm.completion(messages)
assert resp.get("result") == resp_cont
resp = await qianfan_llm.acompletion(messages)
assert resp.get("result") == resp_cont
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)

View file

@ -4,12 +4,18 @@
import pytest
from metagpt.config2 import Config
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
resp_content = "I'm Spark"
resp_cont = resp_cont_tmpl.format(name="Spark")
class MockWebSocketApp(object):
@ -23,7 +29,7 @@ class MockWebSocketApp(object):
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
@ -31,34 +37,26 @@ def test_get_msg_from_web(mocker):
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_content
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask():
llm = SparkLLM(Config.from_home("spark.yaml").llm)
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
print(resp)
assert resp == resp_cont
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM(mock_llm_config)
spark_llm = SparkLLM(mock_llm_config)
resp = await spark_gpt.acompletion([])
assert resp == resp_content
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=True)
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)

View file

@ -6,22 +6,24 @@ import pytest
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
from tests.metagpt.provider.req_resp_const import (
get_part_chat_completion,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
name = "ChatGLM-4"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_part_chat_completion(name)
async def mock_zhipuai_acreate_stream(**kwargs):
async def mock_zhipuai_acreate_stream(self, **kwargs):
class MockResponse(object):
async def _aread(self):
class Iterator(object):
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
async def __aiter__(self):
for event in self.events:
@ -37,7 +39,7 @@ async def mock_zhipuai_acreate_stream(**kwargs):
return MockResponse()
async def mock_zhipuai_acreate(**kwargs) -> dict:
async def mock_zhipuai_acreate(self, **kwargs) -> dict:
return default_resp
@ -46,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
resp = await zhipu_gpt.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_content
resp = await zhipu_llm.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_cont
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
def test_zhipuai_proxy():

View file

@ -42,6 +42,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1000, 8),
],
)
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):

View file

@ -1,7 +0,0 @@
llm:
api_type: "spark"
app_id: "xxx"
api_key: "xxx"
api_secret: "xxx"
domain: "generalv2"
base_url: "wss://spark-api.xf-yun.com/v3.1/chat"