Merge remote-tracking branch 'origin/main'

mannaandpoem 2024-02-29 16:50:50 +08:00
commit eb3c6d14f9
41 changed files with 1093 additions and 363 deletions


@ -54,7 +54,6 @@ jobs:
export ALLOW_OPENAI_API_CALL=0
echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml
mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml
echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |


@ -31,7 +31,7 @@ jobs:
- name: Test with pytest
run: |
export ALLOW_OPENAI_API_CALL=0
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml
mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml
pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt
- name: Show coverage report
run: |


@ -8,14 +8,17 @@
import asyncio
from metagpt.actions import Action
from metagpt.config2 import Config
from metagpt.environment import Environment
from metagpt.roles import Role
from metagpt.team import Team
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action1.llm.model = "gpt-4-1106-preview"
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
action2.llm.model = "gpt-3.5-turbo-1106"
gpt35 = Config.default()
gpt35.llm.model = "gpt-3.5-turbo-1106"
gpt4 = Config.default()
gpt4.llm.model = "gpt-4-1106-preview"
action1 = Action(config=gpt4, name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
action2 = Action(config=gpt35, name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2])
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
env = Environment(desc="US election live broadcast")
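
The debate example above now gives each Action its own Config rather than mutating action.llm.model in place. For orientation, a minimal sketch of how the prepared roles and environment might then be run as a team (the Team arguments, topic, and round count are illustrative assumptions, not part of this diff; Team and asyncio are already imported in the example):

team = Team(investment=10.0, env=env, roles=[alex, bob])  # assumed Team wiring
asyncio.run(team.run(idea="Topic: climate policy. Keep each reply under 80 words.", send_to="Alex", n_round=5))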


@ -13,7 +13,18 @@ from metagpt.logs import logger
async def main():
llm = LLM()
logger.info(await llm.aask("hello world"))
# llm type check
question = "what's your name"
logger.info(f"{question}: ")
logger.info(await llm.aask(question))
logger.info("\n\n")
logger.info(
await llm.aask(
"who are you", system_msgs=["act as a robot, just answer 'I'am robot' if the question is 'who are you'"]
)
)
logger.info(await llm.aask_batch(["hi", "write python hello world."]))
hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}]


@ -6,7 +6,9 @@ async def main():
image_path = "image.jpg"
language = "English"
requirement = f"""This is a {language} receipt image.
Your goal is to perform OCR on images using PaddleOCR, then extract the total amount from ocr text results, and finally save as table. Image path: {image_path}.
Your goal is to perform OCR on the image using PaddleOCR, output only the text content from the OCR results (discarding
coordinates and confidence levels), then recognize the total amount from the OCR text, and finally save it as a table.
Image path: {image_path}.
NOTE: The environments for Paddle and PaddleOCR are ready and have been fully installed."""
mi = Interpreter()


@ -24,6 +24,8 @@ class LLMType(Enum):
METAGPT = "metagpt"
AZURE = "azure"
OLLAMA = "ollama"
QIANFAN = "qianfan" # Baidu BCE
DASHSCOPE = "dashscope" # Aliyun LingJi DashScope
def __missing__(self, key):
return self.OPENAI
@ -36,13 +38,18 @@ class LLMConfig(YamlModel):
Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
"""
api_key: str
api_key: str = "sk-"
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
# For cloud service providers like Baidu / Alibaba
access_key: Optional[str] = None
secret_key: Optional[str] = None
endpoint: Optional[str] = None # for self-deployed model on the cloud
# For Spark(Xunfei), maybe remove later
app_id: Optional[str] = None
api_secret: Optional[str] = None
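
The new QIANFAN and DASHSCOPE types plus the access_key/secret_key/endpoint fields map directly onto per-provider configs. A small sketch with placeholder credentials, consistent with the mock configs added in the tests further down:

from metagpt.configs.llm_config import LLMConfig, LLMType

# QianFan uses system-level access_key/secret_key auth
qianfan_config = LLMConfig(api_type=LLMType.QIANFAN, access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
# DashScope uses a plain api_key
dashscope_config = LLMConfig(api_type=LLMType.DASHSCOPE, api_key="xxx", model="qwen-max")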


@ -7,7 +7,6 @@
from pathlib import Path
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.embedding import get_embedding
from metagpt.utils.serialize import deserialize_message, serialize_message
@ -30,7 +30,7 @@ class MemoryStorage(FaissStore):
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.embedding = embedding or get_embedding()
self.store: FAISS = None # Faiss engine
@property
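
MemoryStorage now defers to a shared get_embedding() helper instead of constructing OpenAIEmbeddings inline. The helper itself is not part of this excerpt; a plausible shape, assuming it simply reuses the OpenAI-compatible credentials exposed by config.get_openai_llm() (which other files in this commit rely on), would be:

from langchain_community.embeddings import OpenAIEmbeddings

from metagpt.config2 import config


def get_embedding() -> OpenAIEmbeddings:
    # assumed implementation: build the embedding client from the configured OpenAI-compatible LLM
    llm = config.get_openai_llm()
    return OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)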


@ -16,6 +16,8 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.spark_api import SparkLLM
from metagpt.provider.qianfan_api import QianFanLLM
from metagpt.provider.dashscope_api import DashScopeLLM
__all__ = [
"FireworksLLM",
@ -28,4 +30,6 @@ __all__ = [
"OllamaLLM",
"HumanProvider",
"SparkLLM",
"QianFanLLM",
"DashScopeLLM",
]


@ -11,11 +11,12 @@ from abc import ABC, abstractmethod
from typing import Optional, Union
from openai import AsyncOpenAI
from pydantic import BaseModel
from metagpt.configs.llm_config import LLMConfig
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.cost_manager import CostManager, Costs
class BaseLLM(ABC):
@ -67,6 +68,28 @@ class BaseLLM(ABC):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True):
"""update each request's token cost
Args:
model (str): model name, or the endpoint name in some scenarios
local_calc_usage (bool): set False for providers that don't report usage; combined with LLMConfig.calc_usage to decide whether to compute cost
"""
calc_usage = self.config.calc_usage and local_calc_usage
model = model or self.model
usage = usage.model_dump() if isinstance(usage, BaseModel) else usage
if calc_usage and self.cost_manager:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, model)
except Exception as e:
logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}")
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
async def aask(
self,
msg: str,


@ -0,0 +1,248 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import json
from http import HTTPStatus
from typing import Any, AsyncGenerator, Dict, List, Union
import dashscope
from dashscope.aigc.generation import Generation
from dashscope.api_entities.aiohttp_request import AioHttpRequest
from dashscope.api_entities.api_request_data import ApiRequestData
from dashscope.api_entities.api_request_factory import _get_protocol_params
from dashscope.api_entities.dashscope_response import (
GenerationOutput,
GenerationResponse,
Message,
)
from dashscope.client.base_api import BaseAioApi
from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol
from dashscope.common.error import (
InputDataRequired,
InputRequired,
ModelRequired,
UnsupportedApiProtocol,
)
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM, LLMConfig
from metagpt.provider.llm_provider_registry import LLMType, register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS
def build_api_arequest(
model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs
):
(
api_protocol,
ws_stream_mode,
is_binary_input,
http_method,
stream,
async_request,
query,
headers,
request_timeout,
form,
resources,
) = _get_protocol_params(kwargs)
task_id = kwargs.pop("task_id", None)
if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]:
if not dashscope.base_http_api_url.endswith("/"):
http_url = dashscope.base_http_api_url + "/"
else:
http_url = dashscope.base_http_api_url
if is_service:
http_url = http_url + SERVICE_API_PATH + "/"
if task_group:
http_url += "%s/" % task_group
if task:
http_url += "%s/" % task
if function:
http_url += function
request = AioHttpRequest(
url=http_url,
api_key=api_key,
http_method=http_method,
stream=stream,
async_request=async_request,
query=query,
timeout=request_timeout,
task_id=task_id,
)
else:
raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol)
if headers is not None:
request.add_headers(headers=headers)
if input is None and form is None:
raise InputDataRequired("There is no input data and form data")
request_data = ApiRequestData(
model,
task_group=task_group,
task=task,
function=function,
input=input,
form=form,
is_binary_input=is_binary_input,
api_protocol=api_protocol,
)
request_data.add_resources(resources)
request_data.add_parameters(**kwargs)
request.data = request_data
return request
class AGeneration(Generation, BaseAioApi):
@classmethod
async def acall(
cls,
model: str,
prompt: Any = None,
history: list = None,
api_key: str = None,
messages: List[Message] = None,
plugins: Union[str, Dict[str, Any]] = None,
**kwargs,
) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]:
if (prompt is None or not prompt) and (messages is None or not messages):
raise InputRequired("prompt or messages is required!")
if model is None or not model:
raise ModelRequired("Model is required!")
task_group, function = "aigc", "generation" # fixed value
if plugins is not None:
headers = kwargs.pop("headers", {})
if isinstance(plugins, str):
headers["X-DashScope-Plugin"] = plugins
else:
headers["X-DashScope-Plugin"] = json.dumps(plugins)
kwargs["headers"] = headers
input, parameters = cls._build_input_parameters(model, prompt, history, messages, **kwargs)
api_key, model = BaseAioApi._validate_params(api_key, model)
request = build_api_arequest(
model=model,
input=input,
task_group=task_group,
task=Generation.task,
function=function,
api_key=api_key,
**kwargs,
)
response = await request.aio_call()
is_stream = kwargs.get("stream", False)
if is_stream:
async def aresp_iterator(response):
async for resp in response:
yield GenerationResponse.from_api_response(resp)
return aresp_iterator(response)
else:
return GenerationResponse.from_api_response(response)
@register_provider(LLMType.DASHSCOPE)
class DashScopeLLM(BaseLLM):
def __init__(self, llm_config: LLMConfig):
self.config = llm_config
self.use_system_prompt = False # only some models support system_prompt
self.__init_dashscope()
self.cost_manager = CostManager(token_costs=self.token_costs)
def __init_dashscope(self):
self.model = self.config.model
self.api_key = self.config.api_key
self.token_costs = DASHSCOPE_TOKEN_COSTS
self.aclient: AGeneration = AGeneration
# models known to support system_message
support_system_models = [
"qwen-", # all support
"llama2-", # all support
"baichuan2-7b-chat-v1",
"chatglm3-6b",
]
for support_model in support_system_models:
if support_model in self.model:
self.use_system_prompt = True
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
"api_key": self.api_key,
"model": self.model,
"messages": messages,
"stream": stream,
"result_format": "message",
}
if self.config.temperature > 0:
# each model has its own default temperature; only set it when explicitly specified
kwargs["temperature"] = self.config.temperature
if stream:
kwargs["incremental_output"] = True
return kwargs
def _check_response(self, resp: GenerationResponse):
if resp.status_code != HTTPStatus.OK:
raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}")
def get_choice_text(self, output: GenerationOutput) -> str:
return output.get("choices", [{}])[0].get("message", {}).get("content", "")
def completion(self, messages: list[dict]) -> GenerationOutput:
resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False))
self._check_response(resp)
self._update_costs(dict(resp.usage))
return resp.output
async def _achat_completion(self, messages: list[dict]) -> GenerationOutput:
resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False))
self._check_response(resp)
self._update_costs(dict(resp.usage))
return resp.output
async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput:
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
async for chunk in resp:
self._check_response(chunk)
content = chunk.output.choices[0]["message"]["content"]
usage = dict(chunk.usage) # each chunk has usage
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)
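
Since DashScopeLLM registers itself under LLMType.DASHSCOPE and implements the common BaseLLM interface, it can be driven like any other provider. A usage sketch, assuming a valid DashScope API key and network access:

import asyncio

from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.dashscope_api import DashScopeLLM


async def demo():
    llm = DashScopeLLM(LLMConfig(api_type="dashscope", api_key="sk-xxx", model="qwen-max"))
    print(await llm.aask("hello"))  # both stream and non-stream paths go through AGeneration.acall
    print(llm.get_costs())  # accumulated via the shared BaseLLM._update_costs


asyncio.run(demo())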


@ -16,10 +16,10 @@ from tenacity import (
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
MODEL_GRADE_TOKEN_COSTS = {
"-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition
@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM):
kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url)
return kwargs
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use FireworksCostManager not context.cost_manager
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self.cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
@ -107,10 +96,11 @@ class FireworksLLM(OpenAILLM):
finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None
if choice_delta.content:
collected_content.append(choice_delta.content)
print(choice_delta.content, end="")
log_llm_stream(choice_delta.content)
if finish_reason:
# fireworks api return usage when finish_reason is not None
usage = CompletionUsage(**chunk.usage)
log_llm_stream("\n")
full_content = "".join(collected_content)
self._update_costs(usage)


@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM):
kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"google gemini updats costs failed! exp: {e}")
def get_choice_text(self, resp: GenerateContentResponse) -> str:
return resp.text


@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"ollama updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the resp content from llm response"""
assist_msg = resp.get("message", {})


@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import Costs, TokenCostManager
from metagpt.utils.cost_manager import TokenCostManager
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM):
logger.error(f"usage calculation failed!: {e}")
return usage
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage:
try:
# use OpenLLMCostManager not CONFIG.cost_manager
self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
def get_costs(self) -> Costs:
return self._cost_manager.get_costs()


@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import CodeParser, decode_image
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
count_message_tokens,
@ -56,16 +56,13 @@ class OpenAILLM(BaseLLM):
def __init__(self, config: LLMConfig):
self.config = config
self._init_model()
self._init_client()
self.auto_max_tokens = False
self.cost_manager: Optional[CostManager] = None
def _init_model(self):
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
self.model = self.config.model # Used in _calc_usage & _cons_kwargs
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
@ -102,7 +99,7 @@ class OpenAILLM(BaseLLM):
"max_tokens": self._get_max_tokens(messages),
"n": 1,
# "stop": None, # default it's None and gpt4-v can't have this one
"temperature": 0.3,
"temperature": self.config.temperature,
"model": self.model,
"timeout": max(self.config.timeout, timeout),
}
@ -272,16 +269,6 @@ class OpenAILLM(BaseLLM):
return usage
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if self.config.calc_usage and usage and self.cost_manager:
self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
if not self.cost_manager:
return Costs(0, 0, 0, 0)
return self.cost_manager.get_costs()
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return self.config.max_token


@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : LLM API of QianFan from Baidu; supports ERNIE (wen xin yi yan) and open-source models
import copy
import os
import qianfan
from qianfan import ChatCompletion
from qianfan.resources.typing import JsonBody
from tenacity import (
after_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_random_exponential,
)
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import (
QIANFAN_ENDPOINT_TOKEN_COSTS,
QIANFAN_MODEL_TOKEN_COSTS,
)
@register_provider(LLMType.QIANFAN)
class QianFanLLM(BaseLLM):
"""
Refs
Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B
Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat
https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8
"""
def __init__(self, config: LLMConfig):
self.config = config
self.use_system_prompt = False # only some ERNIE-x related models support system_prompt
self.__init_qianfan()
self.cost_manager = CostManager(token_costs=self.token_costs)
def __init_qianfan(self):
if self.config.access_key and self.config.secret_key:
# system-level auth with access_key and secret_key, as officially recommended
# set environment variables, per the official recommendation
os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key)
os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key)
elif self.config.api_key and self.config.secret_key:
# for application level auth, use api_key and secret_key
# set environment variables, per the official recommendation
os.environ.setdefault("QIANFAN_AK", self.config.api_key)
os.environ.setdefault("QIANFAN_SK", self.config.secret_key)
else:
raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first")
support_system_pairs = [
("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint)
("ERNIE-Bot-8k", "ernie_bot_8k"),
("ERNIE-Bot", "completions"),
("ERNIE-Bot-turbo", "eb-instant"),
("ERNIE-Speed", "ernie_speed"),
("EB-turbo-AppBuilder", "ai_apaas"),
]
if self.config.model in [pair[0] for pair in support_system_pairs]:
# only some ERNIE models support
self.use_system_prompt = True
if self.config.endpoint in [pair[1] for pair in support_system_pairs]:
self.use_system_prompt = True
assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config"
assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config"
self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS)
self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS)
# don't calculate usage for self-deployed models on the cloud; they are billed by resource pool rental instead
self.calc_usage = self.config.calc_usage and self.config.endpoint is None
self.aclient: ChatCompletion = qianfan.ChatCompletion()
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {
"messages": messages,
"stream": stream,
}
if self.config.temperature > 0:
# each model has its own default temperature; only set it when explicitly specified.
kwargs["temperature"] = self.config.temperature
if self.config.endpoint:
kwargs["endpoint"] = self.config.endpoint
elif self.config.model:
kwargs["model"] = self.config.model
if self.use_system_prompt:
# if the model supports a system prompt, extract it and pass it separately
if messages[0]["role"] == "system":
kwargs["messages"] = messages[1:]
kwargs["system"] = messages[0]["content"] # set system prompt here
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
model_or_endpoint = self.config.model or self.config.endpoint
local_calc_usage = model_or_endpoint in self.token_costs
super()._update_costs(usage, model_or_endpoint, local_calc_usage)
def get_choice_text(self, resp: JsonBody) -> str:
return resp.get("result", "")
def completion(self, messages: list[dict]) -> JsonBody:
resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def _achat_completion(self, messages: list[dict]) -> JsonBody:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False))
self._update_costs(resp.body.get("usage", {}))
return resp.body
async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody:
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True))
collected_content = []
usage = {}
async for chunk in resp:
content = chunk.body.get("result", "")
usage = chunk.body.get("usage", {})
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)
return full_content
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(min=1, max=60),
after=after_log(logger, logger.level("WARNING").name),
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
if stream:
return await self._achat_completion_stream(messages)
resp = await self._achat_completion(messages)
return self.get_choice_text(resp)
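
QianFanLLM requires exactly one of model or endpoint (see the asserts above); choosing a self-deployed endpoint also sets self.calc_usage to False, since such deployments are billed by resource pool rental rather than per token. A configuration sketch with placeholder credentials:

from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.qianfan_api import QianFanLLM

# public model, billed per token: usage feeds the CostManager
llm = QianFanLLM(LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo"))
# self-deployed endpoint (name is a placeholder): per-token usage accounting is switched off
llm_ep = QianFanLLM(LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", endpoint="your_endpoint"))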


@ -53,16 +53,6 @@ class ZhiPuAILLM(BaseLLM):
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
"""update each request's token cost"""
if self.config.calc_usage:
try:
prompt_tokens = int(usage.get("prompt_tokens", 0))
completion_tokens = int(usage.get("completion_tokens", 0))
self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model)
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()


@ -281,7 +281,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
i = action
self._init_action(i)
self.actions.append(i)
self.states.append(f"{len(self.actions)}. {action}")
self.states.append(f"{len(self.actions) - 1}. {action}")
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False):
"""Set strategy of the Role reacting to observed Message. Variation lies in how


@ -2,14 +2,11 @@
# -*- coding: utf-8 -*-
import asyncio
import shutil
from pathlib import Path
import typer
from metagpt.config2 import config
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
from metagpt.context import Context
from metagpt.const import CONFIG_ROOT
from metagpt.utils.project_repo import ProjectRepo
app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False)
@ -30,6 +27,8 @@ def generate_repo(
recover_path=None,
) -> ProjectRepo:
"""Run the startup logic. Can be called from CLI or other Python scripts."""
from metagpt.config2 import config
from metagpt.context import Context
from metagpt.roles import (
Architect,
Engineer,
@ -122,7 +121,17 @@ def startup(
)
def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml
# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py
llm:
api_type: "openai" # or azure / ollama / open_llm etc. Check LLMType for more options
model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview
base_url: "https://api.openai.com/v1" # or forward url / other llm url
api_key: "YOUR_API_KEY"
"""
def copy_config_to():
"""Initialize the configuration file for MetaGPT."""
target_path = CONFIG_ROOT / "config2.yaml"
@ -136,7 +145,7 @@ def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"):
print(f"Existing configuration file backed up at {backup_path}")
# Copy the file
shutil.copy(str(config_path), target_path)
target_path.write_text(DEFAULT_CONFIG, encoding="utf-8")
print(f"Configuration file initialized at {target_path}")


@ -29,6 +29,7 @@ class CostManager(BaseModel):
total_budget: float = 0
max_budget: float = 10.0
total_cost: float = 0
token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost
def update_cost(self, prompt_tokens, completion_tokens, model):
"""
@ -41,12 +42,13 @@ class CostManager(BaseModel):
"""
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
if model not in TOKEN_COSTS:
if model not in self.token_costs:
logger.warning(f"Model {model} not found in TOKEN_COSTS.")
return
cost = (
prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]
prompt_tokens * self.token_costs[model]["prompt"]
+ completion_tokens * self.token_costs[model]["completion"]
) / 1000
self.total_cost += cost
logger.info(
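
Because CostManager now carries its own token_costs table, provider-specific price lists (such as the QianFan table added to token_counter in this commit) plug straight into the same formula. A small worked example:

from metagpt.utils.cost_manager import CostManager
from metagpt.utils.token_counter import QIANFAN_MODEL_TOKEN_COSTS

cm = CostManager(token_costs=QIANFAN_MODEL_TOKEN_COSTS)
cm.update_cost(prompt_tokens=1000, completion_tokens=500, model="ERNIE-Bot-4")
# cost = (1000 * 0.017 + 500 * 0.017) / 1000 = 0.0255
assert round(cm.total_cost, 4) == 0.0255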


@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str:
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
# remove comments from the output JSON string; a comment after a JSON value may start with # or //
arr = output.split("\n")
new_arr = []
@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str:
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
# problem, `"""` or `'''` without `,`
new_line = f",{line}"
elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]:
# backslash problem like \" in the output
char = rline[col_no - 1]
nearest_char_idx = rline[col_no:].find(char)
new_line = (
rline[: col_no - 1]
+ "\\"
+ rline[col_no - 1 : col_no + nearest_char_idx]
+ "\\"
+ rline[col_no + nearest_char_idx :]
)
elif '",' not in line and "," not in line and '"' not in line:
new_line = f'{line}",'
elif not line.endswith(","):


@ -38,6 +38,88 @@ TOKEN_COSTS = {
}
"""
QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9
Since QianFan offers multiple pricing strategies, we standardize on `Tokens post-payment` for cost accounting.
"""
QIANFAN_MODEL_TOKEN_COSTS = {
"ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017},
"ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067},
"ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017},
"ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011},
"EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011},
"ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011},
"BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056},
"Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084},
"Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049},
"ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056},
"AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056},
"Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049},
"SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056},
"CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056},
"XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049},
"Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056},
"Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084},
"ChatLaw": {"prompt": 0.0011, "completion": 0.0011},
"Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0},
}
QIANFAN_ENDPOINT_TOKEN_COSTS = {
"completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"],
"ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"],
"completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"],
"eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"],
"ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"],
"ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"],
"bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"],
"llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"],
"llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"],
"llama_2_70b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"],
"chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"],
"aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"],
"mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"],
"sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"],
"codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"],
"xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"],
"qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"],
"qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"],
"qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"],
"chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"],
"yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"],
}
"""
DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
Each model has its own pricing detail page. Note that some models are free for a limited time.
"""
DASHSCOPE_TOKEN_COSTS = {
"qwen-turbo": {"prompt": 0.0011, "completion": 0.0011},
"qwen-plus": {"prompt": 0.0028, "completion": 0.0028},
"qwen-max": {"prompt": 0.0, "completion": 0.0},
"qwen-max-1201": {"prompt": 0.0, "completion": 0.0},
"qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0},
"llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0},
"qwen-72b-chat": {"prompt": 0.0, "completion": 0.0},
"qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011},
"qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084},
"qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0},
"baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011},
"baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084},
"baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0},
"chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011},
"chatglm3-6b": {"prompt": 0.0, "completion": 0.0},
"ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free
"dolly-12b-v2": {"prompt": 0.0, "completion": 0.0},
"belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0},
"moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0},
"chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0},
"billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0},
}
TOKEN_MAX = {
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-0301": 4096,


@ -11,7 +11,7 @@ typer==0.9.0
# godot==0.1.1
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.4.0
langchain==0.0.352
langchain==0.1.8
loguru==0.6.0
meilisearch==0.21.0
numpy>=1.24.3,<1.25.0
@ -27,7 +27,7 @@ python_docx==0.8.11
PyYAML==6.0.1
# sentence_transformers==2.2.2
setuptools==65.6.3
tenacity==8.2.2
tenacity==8.2.3
tiktoken==0.5.2
tqdm==4.65.0
#unstructured[local-inference]
@ -54,7 +54,7 @@ rich==13.6.0
nbclient==0.9.0
nbformat==5.9.2
ipython==8.17.2
ipykernel==6.27.0
ipykernel==6.27.1
scikit_learn==1.3.2
typing-extensions==4.9.0
socksio~=1.0.0
@ -68,3 +68,5 @@ anytree
ipywidgets==8.1.1
Pillow
imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py
qianfan==0.3.2
dashscope==1.14.1


@ -6,6 +6,9 @@
@File : test_faiss_store.py
"""
from typing import Optional
import numpy as np
import pytest
from metagpt.const import EXAMPLE_PATH
@ -14,8 +17,17 @@ from metagpt.logs import logger
from metagpt.roles import Sales
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
num = len(texts)
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0))
return embeds
@pytest.mark.asyncio
async def test_search_json():
async def test_search_json(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
store = FaissStore(EXAMPLE_PATH / "example.json")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
@ -24,7 +36,9 @@ async def test_search_json():
@pytest.mark.asyncio
async def test_search_xlsx():
async def test_search_xlsx(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
@ -33,7 +47,9 @@ async def test_search_xlsx():
@pytest.mark.asyncio
async def test_write():
async def test_write(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
_faiss_store = store.write()
assert _faiss_store.docstore


@ -0,0 +1,33 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
from typing import Optional
import numpy as np
dim = 1536 # openai embedding dim
text_embed_arr = [
{"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below
{"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
{"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
{"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
{
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
"embed": np.zeros(shape=[1, dim]),
},
{"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
{
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
"embed": np.ones(shape=[1, dim]),
},
]
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
idx = text_idx_dict.get(texts[0])
embed = text_embed_arr[idx].get("embed")
return embed


@ -4,20 +4,22 @@
@Desc : unittest of `metagpt/memory/longterm_memory.py`
"""
import os
import pytest
from metagpt.actions import UserRequirement
from metagpt.config2 import config
from metagpt.memory.longterm_memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
from tests.metagpt.memory.mock_text_embed import (
mock_openai_embed_documents,
text_embed_arr,
)
def test_ltm_search():
def test_ltm_search(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
role_id = "UTUserLtm(Product Manager)"
from metagpt.environment import Environment
@ -27,20 +29,20 @@ def test_ltm_search():
ltm = LongTermMemory()
ltm.recover_memory(role_id, rc)
idea = "Write a cli snake game"
idea = text_embed_arr[0].get("text", "Write a cli snake game")
message = Message(role="User", content=idea, cause_by=UserRequirement)
news = ltm.find_news([message])
assert len(news) == 1
ltm.add(message)
sim_idea = "Write a game of cli snake"
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
news = ltm.find_news([sim_message])
assert len(news) == 0
ltm.add(sim_message)
new_idea = "Write a 2048 web game"
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
news = ltm.find_news([new_message])
assert len(news) == 1
@ -56,7 +58,7 @@ def test_ltm_search():
news = ltm_new.find_news([sim_message])
assert len(news) == 0
new_idea = "Write a Battle City"
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
news = ltm_new.find_news([new_message])
assert len(news) == 1


@ -4,23 +4,25 @@
@Desc : the unittests of metagpt/memory/memory_storage.py
"""
import os
import shutil
from pathlib import Path
from typing import List
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.config2 import config
from metagpt.const import DATA_PATH
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
from tests.metagpt.memory.mock_text_embed import (
mock_openai_embed_documents,
text_embed_arr,
)
def test_idea_message():
idea = "Write a cli snake game"
def test_idea_message(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
idea = text_embed_arr[0].get("text", "Write a cli snake game")
role_id = "UTUser1(Product Manager)"
message = Message(role="User", content=idea, cause_by=UserRequirement)
@ -33,12 +35,12 @@ def test_idea_message():
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_idea = "Write a game of cli snake"
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_idea = "Write a 2048 web game"
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
@ -47,13 +49,17 @@ def test_idea_message():
assert memory_storage.is_initialized is False
def test_actionout_message():
def test_actionout_message(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
ic_obj = ActionNode.create_model_class("prd", out_mapping)
role_id = "UTUser2(Architect)"
content = "The user has requested the creation of a command-line interface (CLI) snake game"
content = text_embed_arr[4].get(
"text", "The user has requested the creation of a command-line interface (CLI) snake game"
)
message = Message(
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
) # WritePRD as test action
@ -67,12 +73,14 @@ def test_actionout_message():
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_conent = "The request is command-line interface (CLI) snake game"
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty"
new_conent = text_embed_arr[6].get(
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
)
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content


@ -42,3 +42,17 @@ mock_llm_config_zhipu = LLMConfig(
model="mock_zhipu_model",
proxy="http://localhost:8080",
)
mock_llm_config_spark = LLMConfig(
api_type="spark",
app_id="xxx",
api_key="xxx",
api_secret="xxx",
domain="generalv2",
base_url="wss://spark-api.xf-yun.com/v3.1/chat",
)
mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo")
mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max")


@ -0,0 +1,145 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : default request & response data for provider unittest
from dashscope.api_entities.dashscope_response import (
DashScopeAPIResponse,
GenerationOutput,
GenerationResponse,
GenerationUsage,
)
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from qianfan.resources.typing import QfResponse
from metagpt.provider.base_llm import BaseLLM
prompt = "who are you?"
messages = [{"role": "user", "content": prompt}]
resp_cont_tmpl = "I'm {name}"
default_resp_cont = resp_cont_tmpl.format(name="GPT")
# part of whole ChatCompletion of openai like structure
def get_part_chat_completion(name: str) -> dict:
part_chat_completion = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": resp_cont_tmpl.format(name=name),
},
"finish_reason": "stop",
}
],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
return part_chat_completion
def get_openai_chat_completion(name: str) -> ChatCompletion:
openai_chat_completion = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="xx/xxx",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
)
return openai_chat_completion
def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
usage = usage if not usage_as_dict else usage.model_dump()
openai_chat_completion_chunk = ChatCompletionChunk(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="xx/xxx",
object="chat.completion.chunk",
created=1703300855,
choices=[
AChoice(
delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=usage,
)
return openai_chat_completion_chunk
# For gemini
gemini_messages = [{"role": "user", "parts": prompt}]
# For QianFan
qf_jsonbody_dict = {
"id": "as-4v1h587fyv",
"object": "chat.completion",
"created": 1695021339,
"result": "",
"is_truncated": False,
"need_clear_history": False,
"usage": {"prompt_tokens": 7, "completion_tokens": 15, "total_tokens": 22},
}
def get_qianfan_response(name: str) -> QfResponse:
qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name)
return QfResponse(code=200, body=qf_jsonbody_dict)
# For DashScope
def get_dashscope_response(name: str) -> GenerationResponse:
return GenerationResponse.from_api_response(
DashScopeAPIResponse(
status_code=200,
output=GenerationOutput(
**{
"text": "",
"finish_reason": "",
"choices": [
{
"finish_reason": "stop",
"message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)},
}
],
}
),
usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}),
)
)
# For llm general chat functions call
async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str):
resp = await llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await llm.aask(prompt)
assert resp == resp_cont
resp = await llm.acompletion_text(messages, stream=False)
assert resp == resp_cont
resp = await llm.acompletion_text(messages, stream=True)
assert resp == resp_cont


@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion
from metagpt.provider.anthropic_api import Claude2
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl
prompt = "who are you"
resp = "I'am Claude2"
resp_cont = resp_cont_tmpl.format(name="Claude")
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion")
def test_claude2_ask(mocker):
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
assert resp == Claude2(mock_llm_config).ask(prompt)
assert resp_cont == Claude2(mock_llm_config).ask(prompt)
@pytest.mark.asyncio
async def test_claude2_aask(mocker):
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
assert resp == await Claude2(mock_llm_config).aask(prompt)
assert resp_cont == await Claude2(mock_llm_config).aask(prompt)


@ -11,21 +11,13 @@ import pytest
from metagpt.configs.llm_config import LLMConfig
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from tests.metagpt.provider.req_resp_const import (
default_resp_cont,
get_part_chat_completion,
prompt,
)
default_chat_resp = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "I'am GPT",
},
"finish_reason": "stop",
}
]
}
prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
name = "GPT"
class MockBaseLLM(BaseLLM):
@ -33,16 +25,13 @@ class MockBaseLLM(BaseLLM):
pass
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(name)
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
return get_part_chat_completion(name)
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return resp_content
async def close(self):
return default_chat_resp
return default_resp_cont
def test_base_llm():
@ -86,25 +75,25 @@ def test_base_llm():
choice_text = base_llm.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
# resp = base_llm.ask(prompt_msg)
# assert resp == resp_content
# resp = base_llm.ask(prompt)
# assert resp == default_resp_cont
# resp = base_llm.ask_batch([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_batch([prompt])
# assert resp == default_resp_cont
# resp = base_llm.ask_code([prompt_msg])
# assert resp == resp_content
# resp = base_llm.ask_code([prompt])
# assert resp == default_resp_cont
@pytest.mark.asyncio
async def test_async_base_llm():
base_llm = MockBaseLLM()
resp = await base_llm.aask(prompt_msg)
assert resp == resp_content
resp = await base_llm.aask(prompt)
assert resp == default_resp_cont
resp = await base_llm.aask_batch([prompt_msg])
assert resp == resp_content
resp = await base_llm.aask_batch([prompt])
assert resp == default_resp_cont
# resp = await base_llm.aask_code([prompt_msg])
# assert resp == resp_content
# resp = await base_llm.aask_code([prompt])
# assert resp == default_resp_cont


@ -0,0 +1,73 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of DashScopeLLM
from typing import AsyncGenerator, Union
import pytest
from dashscope.api_entities.dashscope_response import GenerationResponse
from metagpt.provider.dashscope_api import DashScopeLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope
from tests.metagpt.provider.req_resp_const import (
get_dashscope_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "qwen-max"
resp_cont = resp_cont_tmpl.format(name=name)
@classmethod
def mock_dashscope_call(
cls,
messages: list[dict],
model: str,
api_key: str,
result_format: str,
incremental_output: bool = True,
stream: bool = False,
) -> GenerationResponse:
return get_dashscope_response(name)
@classmethod
async def mock_dashscope_acall(
cls,
messages: list[dict],
model: str,
api_key: str,
result_format: str,
incremental_output: bool = True,
stream: bool = False,
) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]:
resps = [get_dashscope_response(name)]
if stream:
async def aresp_iterator(resps: list[GenerationResponse]):
for resp in resps:
yield resp
return aresp_iterator(resps)
else:
return resps[0]
@pytest.mark.asyncio
async def test_dashscope_acompletion(mocker):
mocker.patch("dashscope.aigc.generation.Generation.call", mock_dashscope_call)
mocker.patch("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall)
dashscope_llm = DashScopeLLM(mock_llm_config_dashscope)
resp = dashscope_llm.completion(messages)
assert resp.choices[0]["message"]["content"] == resp_cont
resp = await dashscope_llm.acompletion(messages)
assert resp.choices[0]["message"]["content"] == resp_cont
await llm_general_chat_funcs_test(dashscope_llm, prompt, messages, resp_cont)


@ -3,14 +3,7 @@
# @Desc : the unittest of fireworks api
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import (
@ -20,42 +13,19 @@ from metagpt.provider.fireworks_api import (
)
from metagpt.utils.cost_manager import Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm fireworks"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="accounts/fireworks/models/llama-v2-13b-chat",
object="chat.completion",
created=1703300855,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
usage=dict(default_resp.usage),
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
name = "fireworks"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True)
def test_fireworks_costmanager():
@ -88,27 +58,17 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_fireworks_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
fireworks_gpt = FireworksLLM(mock_llm_config)
fireworks_gpt.model = "llama-v2-13b-chat"
fireworks_llm = FireworksLLM(mock_llm_config)
fireworks_llm.model = "llama-v2-13b-chat"
fireworks_gpt._update_costs(
fireworks_llm._update_costs(
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
)
assert fireworks_gpt.get_costs() == Costs(
assert fireworks_llm.get_costs() == Costs(
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
)
resp = await fireworks_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await fireworks_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await fireworks_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont)
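The verbose ChatCompletion/ChatCompletionChunk literals deleted above are replaced by two factories imported from req_resp_const.py. A plausible sketch of those factories, assuming they simply parameterize the deleted literals by model name (IDs, timestamps, and token counts here are illustrative placeholders carried over from the removed code):

from openai.types.chat.chat_completion import ChatCompletion, ChatCompletionMessage, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage

resp_cont_tmpl = "I'm {name}"  # assumed template, see the sketch after the dashscope test


def get_openai_chat_completion(name: str) -> ChatCompletion:
    # Rebuilds the ChatCompletion literal deleted above, parameterized by model name.
    return ChatCompletion(
        id="cmpl-a6652c1bb181caae8dd19ad8",
        model=name,
        object="chat.completion",
        created=1703300855,
        choices=[
            Choice(
                finish_reason="stop",
                index=0,
                message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)),
                logprobs=None,
            )
        ],
        usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
    )


def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk:
    # Same idea for the streaming chunk; the Fireworks test passes usage back as a plain dict.
    usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202)
    return ChatCompletionChunk(
        id="cmpl-a6652c1bb181caae8dd19ad8",
        model=name,
        object="chat.completion.chunk",
        created=1703300855,
        choices=[
            AChoice(
                delta=ChoiceDelta(content=resp_cont_tmpl.format(name=name), role="assistant"),
                finish_reason="stop",
                index=0,
                logprobs=None,
            )
        ],
        usage=dict(usage) if usage_as_dict else usage,
    )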

View file

@@ -11,6 +11,12 @@ from google.generativeai.types import content_types
from metagpt.provider.google_gemini_api import GeminiLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
gemini_messages,
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
@dataclass
@@ -18,10 +24,8 @@ class MockGeminiResponse(ABC):
text: str
prompt_msg = "who are you"
messages = [{"role": "user", "parts": prompt_msg}]
resp_content = "I'm gemini from google"
default_resp = MockGeminiResponse(text=resp_content)
resp_cont = resp_cont_tmpl.format(name="gemini")
default_resp = MockGeminiResponse(text=resp_cont)
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
@@ -60,28 +64,18 @@ async def test_gemini_acompletion(mocker):
mock_gemini_generate_content_async,
)
gemini_gpt = GeminiLLM(mock_llm_config)
gemini_llm = GeminiLLM(mock_llm_config)
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]}
assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]}
usage = gemini_gpt.get_usage(messages, resp_content)
usage = gemini_llm.get_usage(gemini_messages, resp_cont)
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
resp = gemini_gpt.completion(messages)
resp = gemini_llm.completion(gemini_messages)
assert resp == default_resp
resp = await gemini_gpt.acompletion(messages)
resp = await gemini_llm.acompletion(gemini_messages)
assert resp.text == default_resp.text
resp = await gemini_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await gemini_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await gemini_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont)
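The gemini_messages constant imported above presumably mirrors the _user_msg structure asserted in this test; a one-line sketch (an assumption, not the module's verbatim contents):

gemini_messages = [{"role": "user", "parts": [prompt]}]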

View file

@@ -9,12 +9,15 @@ import pytest
from metagpt.provider.ollama_api import OllamaLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm ollama"
default_resp = {"message": {"role": "assistant", "content": resp_content}}
resp_cont = resp_cont_tmpl.format(name="ollama")
default_resp = {"message": {"role": "assistant", "content": resp_cont}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
@@ -41,19 +44,12 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
async def test_ollama_acompletion(mocker):
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
ollama_gpt = OllamaLLM(mock_llm_config)
ollama_llm = OllamaLLM(mock_llm_config)
resp = await ollama_gpt.acompletion(messages)
resp = await ollama_llm.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = await ollama_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await ollama_llm.aask(prompt, stream=False)
assert resp == resp_cont
resp = await ollama_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await ollama_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await ollama_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont)

View file

@@ -3,53 +3,26 @@
# @Desc :
import pytest
from openai.types.chat.chat_completion import (
ChatCompletion,
ChatCompletionMessage,
Choice,
)
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as AChoice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.completion_usage import CompletionUsage
from metagpt.provider.open_llm_api import OpenLLM
from metagpt.utils.cost_manager import Costs
from metagpt.utils.cost_manager import CostManager, Costs
from tests.metagpt.provider.mock_llm_config import mock_llm_config
resp_content = "I'm llama2"
default_resp = ChatCompletion(
id="cmpl-a6652c1bb181caae8dd19ad8",
model="llama-v2-13b-chat",
object="chat.completion",
created=1703302755,
choices=[
Choice(
finish_reason="stop",
index=0,
message=ChatCompletionMessage(role="assistant", content=resp_content),
logprobs=None,
)
],
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
get_openai_chat_completion_chunk,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
default_resp_chunk = ChatCompletionChunk(
id=default_resp.id,
model=default_resp.model,
object="chat.completion.chunk",
created=default_resp.created,
choices=[
AChoice(
delta=ChoiceDelta(content=resp_content, role="assistant"),
finish_reason="stop",
index=0,
logprobs=None,
)
],
)
name = "llama2-7b"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_openai_chat_completion(name)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
default_resp_chunk = get_openai_chat_completion_chunk(name)
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
@@ -68,25 +41,16 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
async def test_openllm_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
openllm_gpt = OpenLLM(mock_llm_config)
openllm_gpt.model = "llama-v2-13b-chat"
openllm_llm = OpenLLM(mock_llm_config)
openllm_llm.model = "llama-v2-13b-chat"
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_gpt.get_costs() == Costs(
openllm_llm.cost_manager = CostManager()
openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
assert openllm_llm.get_costs() == Costs(
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
)
resp = await openllm_gpt.acompletion(messages)
assert resp.choices[0].message.content in resp_content
resp = await openllm_llm.acompletion(messages)
assert resp.choices[0].message.content in resp_cont
resp = await openllm_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await openllm_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await openllm_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont)

View file

@@ -0,0 +1,56 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of qianfan api
from typing import AsyncIterator, Union
import pytest
from qianfan.resources.typing import JsonBody, QfResponse
from metagpt.provider.qianfan_api import QianFanLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan
from tests.metagpt.provider.req_resp_const import (
get_qianfan_response,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "ERNIE-Bot-turbo"
resp_cont = resp_cont_tmpl.format(name=name)
def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse:
return get_qianfan_response(name=name)
async def mock_qianfan_ado(
self, messages: list[dict], model: str, stream: bool = True, system: str = None
) -> Union[QfResponse, AsyncIterator[QfResponse]]:
resps = [get_qianfan_response(name=name)]
if stream:
async def aresp_iterator(resps: list[JsonBody]):
for resp in resps:
yield resp
return aresp_iterator(resps)
else:
return resps[0]
@pytest.mark.asyncio
async def test_qianfan_acompletion(mocker):
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do)
mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado)
qianfan_llm = QianFanLLM(mock_llm_config_qianfan)
resp = qianfan_llm.completion(messages)
assert resp.get("result") == resp_cont
resp = await qianfan_llm.acompletion(messages)
assert resp.get("result") == resp_cont
await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont)

View file

@@ -4,12 +4,18 @@
import pytest
from metagpt.config2 import Config
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
resp_content = "I'm Spark"
resp_cont = resp_cont_tmpl.format(name="Spark")
class MockWebSocketApp(object):
@@ -23,7 +29,7 @@ class MockWebSocketApp(object):
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
@@ -31,34 +37,26 @@ def test_get_msg_from_web(mocker):
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_content
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask():
llm = SparkLLM(Config.from_home("spark.yaml").llm)
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
print(resp)
assert resp == resp_cont
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
spark_gpt = SparkLLM(mock_llm_config)
spark_llm = SparkLLM(mock_llm_config)
resp = await spark_gpt.acompletion([])
assert resp == resp_content
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=False)
assert resp == resp_content
resp = await spark_gpt.acompletion_text([], stream=True)
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)

View file

@@ -6,22 +6,24 @@ import pytest
from metagpt.provider.zhipuai_api import ZhiPuAILLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
from tests.metagpt.provider.req_resp_const import (
get_part_chat_completion,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
resp_content = "I'm chatglm-turbo"
default_resp = {
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
}
name = "ChatGLM-4"
resp_cont = resp_cont_tmpl.format(name=name)
default_resp = get_part_chat_completion(name)
async def mock_zhipuai_acreate_stream(self, **kwargs):
class MockResponse(object):
async def _aread(self):
class Iterator(object):
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}]
async def __aiter__(self):
for event in self.events:
@@ -46,22 +48,12 @@ async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu)
resp = await zhipu_gpt.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_content
resp = await zhipu_llm.acompletion(messages)
assert resp["choices"][0]["message"]["content"] == resp_cont
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
assert resp == resp_content
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
assert resp == resp_content
resp = await zhipu_gpt.aask(prompt_msg)
assert resp == resp_content
await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont)
def test_zhipuai_proxy():
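The get_part_chat_completion helper imported above most plausibly rebuilds the default_resp dict deleted from this file, keyed by model name; a sketch under that assumption (resp_cont_tmpl as sketched after the dashscope test):

def get_part_chat_completion(name: str) -> dict:
    # Rebuilds the partial chat-completion dict deleted above, parameterized by model name.
    resp_cont = resp_cont_tmpl.format(name=name)
    return {
        "choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_cont, "role": "assistant"}}],
        "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
    }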

View file

@@ -211,6 +211,11 @@ value
output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
assert output == target_output
raw_output = '{"key": "url "http" \\"https\\" "}'
target_output = '{"key": "url \\"http\\" \\"https\\" "}'
output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 1 column 15 (char 14)")
assert output == target_output
def test_retry_parse_json_text():
from metagpt.utils.repair_llm_raw_output import retry_parse_json_text
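The new case added above covers a bare double quote inside a JSON string value; once the repair escapes it, the string parses cleanly. A quick illustration (not part of the test file):

import json

# The raw output fails to parse because of the unescaped quotes around "http";
# the repaired target string below is valid JSON.
assert json.loads('{"key": "url \\"http\\" \\"https\\" "}') == {"key": 'url "http" "https" '}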

View file

@@ -1,7 +0,0 @@
llm:
api_type: "spark"
app_id: "xxx"
api_key: "xxx"
api_secret: "xxx"
domain: "generalv2"
base_url: "wss://spark-api.xf-yun.com/v3.1/chat"