diff --git a/.github/workflows/fulltest.yaml b/.github/workflows/fulltest.yaml index f5c6049e1..70c800481 100644 --- a/.github/workflows/fulltest.yaml +++ b/.github/workflows/fulltest.yaml @@ -54,7 +54,6 @@ jobs: export ALLOW_OPENAI_API_CALL=0 echo "${{ secrets.METAGPT_KEY_YAML }}" | base64 -d > config/key.yaml mkdir -p ~/.metagpt && echo "${{ secrets.METAGPT_CONFIG2_YAML }}" | base64 -d > ~/.metagpt/config2.yaml - echo "${{ secrets.SPARK_YAML }}" | base64 -d > ~/.metagpt/spark.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 2e7e3ce2b..afa9faba7 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -31,7 +31,7 @@ jobs: - name: Test with pytest run: | export ALLOW_OPENAI_API_CALL=0 - mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml && cp tests/spark.yaml ~/.metagpt/spark.yaml + mkdir -p ~/.metagpt && cp tests/config2.yaml ~/.metagpt/config2.yaml pytest tests/ --doctest-modules --cov=./metagpt/ --cov-report=xml:cov.xml --cov-report=html:htmlcov --durations=20 | tee unittest.txt - name: Show coverage report run: | diff --git a/examples/debate_simple.py b/examples/debate_simple.py index 869e02a0e..953f664f3 100644 --- a/examples/debate_simple.py +++ b/examples/debate_simple.py @@ -8,14 +8,17 @@ import asyncio from metagpt.actions import Action +from metagpt.config2 import Config from metagpt.environment import Environment from metagpt.roles import Role from metagpt.team import Team -action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it") -action1.llm.model = "gpt-4-1106-preview" -action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it") -action2.llm.model = "gpt-3.5-turbo-1106" +gpt35 = Config.default() +gpt35.llm.model = 
"gpt-3.5-turbo-1106" +gpt4 = Config.default() +gpt4.llm.model = "gpt-4-1106-preview" +action1 = Action(config=gpt4, name="AlexSay", instruction="Express your opinion with emotion and don't repeat it") +action2 = Action(config=gpt35, name="BobSay", instruction="Express your opinion with emotion and don't repeat it") alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]) bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]) env = Environment(desc="US election live broadcast") diff --git a/examples/llm_hello_world.py b/examples/llm_hello_world.py index 219a303c8..62fc2ed68 100644 --- a/examples/llm_hello_world.py +++ b/examples/llm_hello_world.py @@ -13,7 +13,18 @@ from metagpt.logs import logger async def main(): llm = LLM() - logger.info(await llm.aask("hello world")) + # llm type check + question = "what's your name" + logger.info(f"{question}: ") + logger.info(await llm.aask(question)) + logger.info("\n\n") + + logger.info( + await llm.aask( + "who are you", system_msgs=["act as a robot, just answer 'I'am robot' if the question is 'who are you'"] + ) + ) + logger.info(await llm.aask_batch(["hi", "write python hello world."])) hello_msg = [{"role": "user", "content": "count from 1 to 10. split by newline."}] diff --git a/examples/mi/ocr_receipt.py b/examples/mi/ocr_receipt.py index ffa5cff05..f394cccec 100644 --- a/examples/mi/ocr_receipt.py +++ b/examples/mi/ocr_receipt.py @@ -6,7 +6,9 @@ async def main(): image_path = "image.jpg" language = "English" requirement = f"""This is a {language} receipt image. - Your goal is to perform OCR on images using PaddleOCR, then extract the total amount from ocr text results, and finally save as table. Image path: {image_path}. 
+ Your goal is to perform OCR on images using PaddleOCR, output text content from the OCR results and discard + coordinates and confidence levels, then recognize the total amount from ocr text content, and finally save as table. + Image path: {image_path}. NOTE: The environments for Paddle and PaddleOCR are all ready and has been fully installed.""" mi = Interpreter() diff --git a/metagpt/configs/llm_config.py b/metagpt/configs/llm_config.py index fb923d3e4..36f5d7ae7 100644 --- a/metagpt/configs/llm_config.py +++ b/metagpt/configs/llm_config.py @@ -24,6 +24,8 @@ class LLMType(Enum): METAGPT = "metagpt" AZURE = "azure" OLLAMA = "ollama" + QIANFAN = "qianfan" # Baidu BCE + DASHSCOPE = "dashscope" # Aliyun LingJi DashScope def __missing__(self, key): return self.OPENAI @@ -36,13 +38,18 @@ class LLMConfig(YamlModel): Optional Fields in pydantic: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields """ - api_key: str + api_key: str = "sk-" api_type: LLMType = LLMType.OPENAI base_url: str = "https://api.openai.com/v1" api_version: Optional[str] = None model: Optional[str] = None # also stands for DEPLOYMENT_NAME + # For Cloud Service Provider like Baidu/ Alibaba + access_key: Optional[str] = None + secret_key: Optional[str] = None + endpoint: Optional[str] = None # for self-deployed model on the cloud + # For Spark(Xunfei), maybe remove later app_id: Optional[str] = None api_secret: Optional[str] = None diff --git a/metagpt/memory/memory_storage.py b/metagpt/memory/memory_storage.py index c029d027b..fa04d8138 100644 --- a/metagpt/memory/memory_storage.py +++ b/metagpt/memory/memory_storage.py @@ -7,7 +7,6 @@ from pathlib import Path from typing import Optional -from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores.faiss import FAISS from langchain_core.embeddings import Embeddings @@ -15,6 +14,7 @@ from metagpt.const import DATA_PATH, MEM_TTL from metagpt.document_store.faiss_store import FaissStore from metagpt.logs 
import logger from metagpt.schema import Message +from metagpt.utils.embedding import get_embedding from metagpt.utils.serialize import deserialize_message, serialize_message @@ -30,7 +30,7 @@ class MemoryStorage(FaissStore): self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories self._initialized: bool = False - self.embedding = embedding or OpenAIEmbeddings() + self.embedding = embedding or get_embedding() self.store: FAISS = None # Faiss engine @property diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 675734811..44e6d3f3b 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -16,6 +16,8 @@ from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.metagpt_api import MetaGPTLLM from metagpt.provider.human_provider import HumanProvider from metagpt.provider.spark_api import SparkLLM +from metagpt.provider.qianfan_api import QianFanLLM +from metagpt.provider.dashscope_api import DashScopeLLM __all__ = [ "FireworksLLM", @@ -28,4 +30,6 @@ __all__ = [ "OllamaLLM", "HumanProvider", "SparkLLM", + "QianFanLLM", + "DashScopeLLM", ] diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index b144471b5..7cf3faac0 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -11,11 +11,12 @@ from abc import ABC, abstractmethod from typing import Optional, Union from openai import AsyncOpenAI +from pydantic import BaseModel from metagpt.configs.llm_config import LLMConfig from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.cost_manager import CostManager +from metagpt.utils.cost_manager import CostManager, Costs class BaseLLM(ABC): @@ -67,6 +68,28 @@ class BaseLLM(ABC): def _default_system_msg(self): return self._system_msg(self.system_prompt) + def _update_costs(self, usage: Union[dict, BaseModel], model: str = None, local_calc_usage: bool = True): + """update each request's token cost 
+ Args: + model (str): model name or in some scenarios called endpoint + local_calc_usage (bool): some models don't calculate usage, it will overwrite LLMConfig.calc_usage + """ + calc_usage = self.config.calc_usage and local_calc_usage + model = model or self.model + usage = usage.model_dump() if isinstance(usage, BaseModel) else usage + if calc_usage and self.cost_manager: + try: + prompt_tokens = int(usage.get("prompt_tokens", 0)) + completion_tokens = int(usage.get("completion_tokens", 0)) + self.cost_manager.update_cost(prompt_tokens, completion_tokens, model) + except Exception as e: + logger.error(f"{self.__class__.__name__} updates costs failed! exp: {e}") + + def get_costs(self) -> Costs: + if not self.cost_manager: + return Costs(0, 0, 0, 0) + return self.cost_manager.get_costs() + async def aask( self, msg: str, diff --git a/metagpt/provider/dashscope_api.py b/metagpt/provider/dashscope_api.py new file mode 100644 index 000000000..f2b3a19a1 --- /dev/null +++ b/metagpt/provider/dashscope_api.py @@ -0,0 +1,248 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import json +from http import HTTPStatus +from typing import Any, AsyncGenerator, Dict, List, Union + +import dashscope +from dashscope.aigc.generation import Generation +from dashscope.api_entities.aiohttp_request import AioHttpRequest +from dashscope.api_entities.api_request_data import ApiRequestData +from dashscope.api_entities.api_request_factory import _get_protocol_params +from dashscope.api_entities.dashscope_response import ( + GenerationOutput, + GenerationResponse, + Message, +) +from dashscope.client.base_api import BaseAioApi +from dashscope.common.constants import SERVICE_API_PATH, ApiProtocol +from dashscope.common.error import ( + InputDataRequired, + InputRequired, + ModelRequired, + UnsupportedApiProtocol, +) +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, +) + +from metagpt.logs import 
log_llm_stream, logger +from metagpt.provider.base_llm import BaseLLM, LLMConfig +from metagpt.provider.llm_provider_registry import LLMType, register_provider +from metagpt.provider.openai_api import log_and_reraise +from metagpt.utils.cost_manager import CostManager +from metagpt.utils.token_counter import DASHSCOPE_TOKEN_COSTS + + +def build_api_arequest( + model: str, input: object, task_group: str, task: str, function: str, api_key: str, is_service=True, **kwargs +): + ( + api_protocol, + ws_stream_mode, + is_binary_input, + http_method, + stream, + async_request, + query, + headers, + request_timeout, + form, + resources, + ) = _get_protocol_params(kwargs) + task_id = kwargs.pop("task_id", None) + if api_protocol in [ApiProtocol.HTTP, ApiProtocol.HTTPS]: + if not dashscope.base_http_api_url.endswith("/"): + http_url = dashscope.base_http_api_url + "/" + else: + http_url = dashscope.base_http_api_url + + if is_service: + http_url = http_url + SERVICE_API_PATH + "/" + + if task_group: + http_url += "%s/" % task_group + if task: + http_url += "%s/" % task + if function: + http_url += function + request = AioHttpRequest( + url=http_url, + api_key=api_key, + http_method=http_method, + stream=stream, + async_request=async_request, + query=query, + timeout=request_timeout, + task_id=task_id, + ) + else: + raise UnsupportedApiProtocol("Unsupported protocol: %s, support [http, https, websocket]" % api_protocol) + + if headers is not None: + request.add_headers(headers=headers) + + if input is None and form is None: + raise InputDataRequired("There is no input data and form data") + + request_data = ApiRequestData( + model, + task_group=task_group, + task=task, + function=function, + input=input, + form=form, + is_binary_input=is_binary_input, + api_protocol=api_protocol, + ) + request_data.add_resources(resources) + request_data.add_parameters(**kwargs) + request.data = request_data + return request + + +class AGeneration(Generation, BaseAioApi): + @classmethod + 
async def acall( + cls, + model: str, + prompt: Any = None, + history: list = None, + api_key: str = None, + messages: List[Message] = None, + plugins: Union[str, Dict[str, Any]] = None, + **kwargs, + ) -> Union[GenerationResponse, AsyncGenerator[GenerationResponse, None]]: + if (prompt is None or not prompt) and (messages is None or not messages): + raise InputRequired("prompt or messages is required!") + if model is None or not model: + raise ModelRequired("Model is required!") + task_group, function = "aigc", "generation" # fixed value + if plugins is not None: + headers = kwargs.pop("headers", {}) + if isinstance(plugins, str): + headers["X-DashScope-Plugin"] = plugins + else: + headers["X-DashScope-Plugin"] = json.dumps(plugins) + kwargs["headers"] = headers + input, parameters = cls._build_input_parameters(model, prompt, history, messages, **kwargs) + + api_key, model = BaseAioApi._validate_params(api_key, model) + request = build_api_arequest( + model=model, + input=input, + task_group=task_group, + task=Generation.task, + function=function, + api_key=api_key, + **kwargs, + ) + response = await request.aio_call() + is_stream = kwargs.get("stream", False) + if is_stream: + + async def aresp_iterator(response): + async for resp in response: + yield GenerationResponse.from_api_response(resp) + + return aresp_iterator(response) + else: + return GenerationResponse.from_api_response(response) + + +@register_provider(LLMType.DASHSCOPE) +class DashScopeLLM(BaseLLM): + def __init__(self, llm_config: LLMConfig): + self.config = llm_config + self.use_system_prompt = False # only some models support system_prompt + self.__init_dashscope() + self.cost_manager = CostManager(token_costs=self.token_costs) + + def __init_dashscope(self): + self.model = self.config.model + self.api_key = self.config.api_key + self.token_costs = DASHSCOPE_TOKEN_COSTS + self.aclient: AGeneration = AGeneration + + # check support system_message models + support_system_models = [ + "qwen-", # all 
support + "llama2-", # all support + "baichuan2-7b-chat-v1", + "chatglm3-6b", + ] + for support_model in support_system_models: + if support_model in self.model: + self.use_system_prompt = True + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "api_key": self.api_key, + "model": self.model, + "messages": messages, + "stream": stream, + "result_format": "message", + } + if self.config.temperature > 0: + # different model has default temperature. only set when it"s specified. + kwargs["temperature"] = self.config.temperature + if stream: + kwargs["incremental_output"] = True + return kwargs + + def _check_response(self, resp: GenerationResponse): + if resp.status_code != HTTPStatus.OK: + raise RuntimeError(f"code: {resp.code}, request_id: {resp.request_id}, message: {resp.message}") + + def get_choice_text(self, output: GenerationOutput) -> str: + return output.get("choices", [{}])[0].get("message", {}).get("content", "") + + def completion(self, messages: list[dict]) -> GenerationOutput: + resp: GenerationResponse = self.aclient.call(**self._const_kwargs(messages, stream=False)) + self._check_response(resp) + + self._update_costs(dict(resp.usage)) + return resp.output + + async def _achat_completion(self, messages: list[dict]) -> GenerationOutput: + resp: GenerationResponse = await self.aclient.acall(**self._const_kwargs(messages, stream=False)) + self._check_response(resp) + self._update_costs(dict(resp.usage)) + return resp.output + + async def acompletion(self, messages: list[dict], timeout=3) -> GenerationOutput: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp = await self.aclient.acall(**self._const_kwargs(messages, stream=True)) + collected_content = [] + usage = {} + async for chunk in resp: + self._check_response(chunk) + content = chunk.output.choices[0]["message"]["content"] + usage = dict(chunk.usage) # each chunk has usage + 
log_llm_stream(content) + collected_content.append(content) + log_llm_stream("\n") + self._update_costs(usage) + full_content = "".join(collected_content) + return full_content + + @retry( + stop=stop_after_attempt(3), + wait=wait_random_exponential(min=1, max=60), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise, + ) + async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index d56453a85..f356c23c4 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -16,10 +16,10 @@ from tenacity import ( ) from metagpt.configs.llm_config import LLMConfig, LLMType -from metagpt.logs import logger +from metagpt.logs import log_llm_stream, logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM, log_and_reraise -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import CostManager MODEL_GRADE_TOKEN_COSTS = { "-1": {"prompt": 0.0, "completion": 0.0}, # abnormal condition @@ -81,17 +81,6 @@ class FireworksLLM(OpenAILLM): kwargs = dict(api_key=self.config.api_key, base_url=self.config.base_url) return kwargs - def _update_costs(self, usage: CompletionUsage): - if self.config.calc_usage and usage: - try: - # use FireworksCostManager not context.cost_manager - self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - except Exception as e: - logger.error(f"updating costs failed!, exp: {e}") - - def get_costs(self) -> Costs: - return self.cost_manager.get_costs() - async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: 
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages), stream=True @@ -107,10 +96,11 @@ class FireworksLLM(OpenAILLM): finish_reason = choice.finish_reason if hasattr(choice, "finish_reason") else None if choice_delta.content: collected_content.append(choice_delta.content) - print(choice_delta.content, end="") + log_llm_stream(choice_delta.content) if finish_reason: # fireworks api return usage when finish_reason is not None usage = CompletionUsage(**chunk.usage) + log_llm_stream("\n") full_content = "".join(collected_content) self._update_costs(usage) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 2647ab16b..87ea81c80 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -72,16 +72,6 @@ class GeminiLLM(BaseLLM): kwargs = {"contents": messages, "generation_config": GenerationConfig(temperature=0.3), "stream": stream} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"google gemini updats costs failed! 
exp: {e}") - def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index c9103b018..52e8dbe36 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -46,16 +46,6 @@ class OllamaLLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self._cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"ollama updats costs failed! exp: {e}") - def get_choice_text(self, resp: dict) -> str: """get the resp content from llm response""" assist_msg = resp.get("message", {}) diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index a29b263a4..69371e379 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -8,7 +8,7 @@ from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import OpenAILLM -from metagpt.utils.cost_manager import Costs, TokenCostManager +from metagpt.utils.cost_manager import TokenCostManager from metagpt.utils.token_counter import count_message_tokens, count_string_tokens @@ -34,14 +34,3 @@ class OpenLLM(OpenAILLM): logger.error(f"usage calculation failed!: {e}") return usage - - def _update_costs(self, usage: CompletionUsage): - if self.config.calc_usage and usage: - try: - # use OpenLLMCostManager not CONFIG.cost_manager - self._cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - except Exception as e: - logger.error(f"updating costs failed!, 
exp: {e}") - - def get_costs(self) -> Costs: - return self._cost_manager.get_costs() diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 28abed752..36d6f6d77 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -30,7 +30,7 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message from metagpt.utils.common import CodeParser, decode_image -from metagpt.utils.cost_manager import CostManager, Costs +from metagpt.utils.cost_manager import CostManager from metagpt.utils.exceptions import handle_exception from metagpt.utils.token_counter import ( count_message_tokens, @@ -56,16 +56,13 @@ class OpenAILLM(BaseLLM): def __init__(self, config: LLMConfig): self.config = config - self._init_model() self._init_client() self.auto_max_tokens = False self.cost_manager: Optional[CostManager] = None - def _init_model(self): - self.model = self.config.model # Used in _calc_usage & _cons_kwargs - def _init_client(self): """https://github.com/openai/openai-python#async-usage""" + self.model = self.config.model # Used in _calc_usage & _cons_kwargs kwargs = self._make_client_kwargs() self.aclient = AsyncOpenAI(**kwargs) @@ -102,7 +99,7 @@ class OpenAILLM(BaseLLM): "max_tokens": self._get_max_tokens(messages), "n": 1, # "stop": None, # default it's None and gpt4-v can't have this one - "temperature": 0.3, + "temperature": self.config.temperature, "model": self.model, "timeout": max(self.config.timeout, timeout), } @@ -272,16 +269,6 @@ class OpenAILLM(BaseLLM): return usage - @handle_exception - def _update_costs(self, usage: CompletionUsage): - if self.config.calc_usage and usage and self.cost_manager: - self.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - - def get_costs(self) -> Costs: - if not self.cost_manager: - return Costs(0, 0, 0, 0) - return self.cost_manager.get_costs() - 
def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return self.config.max_token diff --git a/metagpt/provider/qianfan_api.py b/metagpt/provider/qianfan_api.py new file mode 100644 index 000000000..4cbb76566 --- /dev/null +++ b/metagpt/provider/qianfan_api.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : llm api of qianfan from Baidu, supports ERNIE(wen xin yi yan) and opensource models +import copy +import os + +import qianfan +from qianfan import ChatCompletion +from qianfan.resources.typing import JsonBody +from tenacity import ( + after_log, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_random_exponential, +) + +from metagpt.configs.llm_config import LLMConfig, LLMType +from metagpt.logs import log_llm_stream, logger +from metagpt.provider.base_llm import BaseLLM +from metagpt.provider.llm_provider_registry import register_provider +from metagpt.provider.openai_api import log_and_reraise +from metagpt.utils.cost_manager import CostManager +from metagpt.utils.token_counter import ( + QIANFAN_ENDPOINT_TOKEN_COSTS, + QIANFAN_MODEL_TOKEN_COSTS, +) + + +@register_provider(LLMType.QIANFAN) +class QianFanLLM(BaseLLM): + """ + Refs + Auth: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/3lmokh7n6#%E3%80%90%E6%8E%A8%E8%8D%90%E3%80%91%E4%BD%BF%E7%94%A8%E5%AE%89%E5%85%A8%E8%AE%A4%E8%AF%81aksk%E9%89%B4%E6%9D%83%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B + Token Price: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9 + Models: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/wlmhm7vuo#%E5%AF%B9%E8%AF%9Dchat + https://cloud.baidu.com/doc/WENXINWORKSHOP/s/xlmokikxe#%E6%94%AF%E6%8C%81%E6%A8%A1%E5%9E%8B%E5%88%97%E8%A1%A8 + """ + + def __init__(self, config: LLMConfig): + self.config = config + self.use_system_prompt = False # only some ERNIE-x related models support system_prompt + self.__init_qianfan() + self.cost_manager = CostManager(token_costs=self.token_costs) + + def 
__init_qianfan(self): + if self.config.access_key and self.config.secret_key: + # for system level auth, use access_key and secret_key, recommended by official + # set environment variable due to official recommendation + os.environ.setdefault("QIANFAN_ACCESS_KEY", self.config.access_key) + os.environ.setdefault("QIANFAN_SECRET_KEY", self.config.secret_key) + elif self.config.api_key and self.config.secret_key: + # for application level auth, use api_key and secret_key + # set environment variable due to official recommendation + os.environ.setdefault("QIANFAN_AK", self.config.api_key) + os.environ.setdefault("QIANFAN_SK", self.config.secret_key) + else: + raise ValueError("Set the `access_key`&`secret_key` or `api_key`&`secret_key` first") + + support_system_pairs = [ + ("ERNIE-Bot-4", "completions_pro"), # (model, corresponding-endpoint) + ("ERNIE-Bot-8k", "ernie_bot_8k"), + ("ERNIE-Bot", "completions"), + ("ERNIE-Bot-turbo", "eb-instant"), + ("ERNIE-Speed", "ernie_speed"), + ("EB-turbo-AppBuilder", "ai_apaas"), + ] + if self.config.model in [pair[0] for pair in support_system_pairs]: + # only some ERNIE models support + self.use_system_prompt = True + if self.config.endpoint in [pair[1] for pair in support_system_pairs]: + self.use_system_prompt = True + + assert not (self.config.model and self.config.endpoint), "Only set `model` or `endpoint` in the config" + assert self.config.model or self.config.endpoint, "Should set one of `model` or `endpoint` in the config" + + self.token_costs = copy.deepcopy(QIANFAN_MODEL_TOKEN_COSTS) + self.token_costs.update(QIANFAN_ENDPOINT_TOKEN_COSTS) + + # self deployed model on the cloud not to calculate usage, it charges resource pool rental fee + self.calc_usage = self.config.calc_usage and self.config.endpoint is None + self.aclient: ChatCompletion = qianfan.ChatCompletion() + + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: + kwargs = { + "messages": messages, + "stream": stream, + } + if 
self.config.temperature > 0: + # different model has default temperature. only set when it's specified. + kwargs["temperature"] = self.config.temperature + if self.config.endpoint: + kwargs["endpoint"] = self.config.endpoint + elif self.config.model: + kwargs["model"] = self.config.model + + if self.use_system_prompt: + # if the model support system prompt, extract and pass it + if messages[0]["role"] == "system": + kwargs["messages"] = messages[1:] + kwargs["system"] = messages[0]["content"] # set system prompt here + return kwargs + + def _update_costs(self, usage: dict): + """update each request's token cost""" + model_or_endpoint = self.config.model or self.config.endpoint + local_calc_usage = model_or_endpoint in self.token_costs + super()._update_costs(usage, model_or_endpoint, local_calc_usage) + + def get_choice_text(self, resp: JsonBody) -> str: + return resp.get("result", "") + + def completion(self, messages: list[dict]) -> JsonBody: + resp = self.aclient.do(**self._const_kwargs(messages=messages, stream=False)) + self._update_costs(resp.body.get("usage", {})) + return resp.body + + async def _achat_completion(self, messages: list[dict]) -> JsonBody: + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=False)) + self._update_costs(resp.body.get("usage", {})) + return resp.body + + async def acompletion(self, messages: list[dict], timeout=3) -> JsonBody: + return await self._achat_completion(messages) + + async def _achat_completion_stream(self, messages: list[dict]) -> str: + resp = await self.aclient.ado(**self._const_kwargs(messages=messages, stream=True)) + collected_content = [] + usage = {} + async for chunk in resp: + content = chunk.body.get("result", "") + usage = chunk.body.get("usage", {}) + log_llm_stream(content) + collected_content.append(content) + log_llm_stream("\n") + + self._update_costs(usage) + full_content = "".join(collected_content) + return full_content + + @retry( + stop=stop_after_attempt(3), + 
wait=wait_random_exponential(min=1, max=60), + after=after_log(logger, logger.level("WARNING").name), + retry=retry_if_exception_type(ConnectionError), + retry_error_callback=log_and_reraise, + ) + async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: + if stream: + return await self._achat_completion_stream(messages) + resp = await self._achat_completion(messages) + return self.get_choice_text(resp) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 9e8e5fb53..4cbee4038 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -53,16 +53,6 @@ class ZhiPuAILLM(BaseLLM): kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3} return kwargs - def _update_costs(self, usage: dict): - """update each request's token cost""" - if self.config.calc_usage: - try: - prompt_tokens = int(usage.get("prompt_tokens", 0)) - completion_tokens = int(usage.get("completion_tokens", 0)) - self.cost_manager.update_cost(prompt_tokens, completion_tokens, self.model) - except Exception as e: - logger.error(f"zhipuai updats costs failed! exp: {e}") - def completion(self, messages: list[dict], timeout=3) -> dict: resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages)) usage = resp.usage.model_dump() diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 3938664ba..893c5cafd 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -281,7 +281,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): i = action self._init_action(i) self.actions.append(i) - self.states.append(f"{len(self.actions)}. {action}") + self.states.append(f"{len(self.actions) - 1}. {action}") def _set_react_mode(self, react_mode: str, max_react_loop: int = 1, auto_run: bool = True, use_tools: bool = False): """Set strategy of the Role reacting to observed Message. 
Variation lies in how diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 26bb29cd1..f290d497a 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -2,14 +2,11 @@ # -*- coding: utf-8 -*- import asyncio -import shutil from pathlib import Path import typer -from metagpt.config2 import config -from metagpt.const import CONFIG_ROOT, METAGPT_ROOT -from metagpt.context import Context +from metagpt.const import CONFIG_ROOT from metagpt.utils.project_repo import ProjectRepo app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False) @@ -30,6 +27,8 @@ def generate_repo( recover_path=None, ) -> ProjectRepo: """Run the startup logic. Can be called from CLI or other Python scripts.""" + from metagpt.config2 import config + from metagpt.context import Context from metagpt.roles import ( Architect, Engineer, @@ -122,7 +121,17 @@ def startup( ) -def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"): +DEFAULT_CONFIG = """# Full Example: https://github.com/geekan/MetaGPT/blob/main/config/config2.example.yaml +# Reflected Code: https://github.com/geekan/MetaGPT/blob/main/metagpt/config2.py +llm: + api_type: "openai" # or azure / ollama / open_llm etc. 
Check LLMType for more options + model: "gpt-4-turbo-preview" # or gpt-3.5-turbo-1106 / gpt-4-1106-preview + base_url: "https://api.openai.com/v1" # or forward url / other llm url + api_key: "YOUR_API_KEY" +""" + + +def copy_config_to(): """Initialize the configuration file for MetaGPT.""" target_path = CONFIG_ROOT / "config2.yaml" @@ -136,7 +145,7 @@ def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"): print(f"Existing configuration file backed up at {backup_path}") # 复制文件 - shutil.copy(str(config_path), target_path) + target_path.write_text(DEFAULT_CONFIG, encoding="utf-8") print(f"Configuration file initialized at {target_path}") diff --git a/metagpt/utils/cost_manager.py b/metagpt/utils/cost_manager.py index c4c93f91f..efff07ae1 100644 --- a/metagpt/utils/cost_manager.py +++ b/metagpt/utils/cost_manager.py @@ -29,6 +29,7 @@ class CostManager(BaseModel): total_budget: float = 0 max_budget: float = 10.0 total_cost: float = 0 + token_costs: dict[str, dict[str, float]] = TOKEN_COSTS # different model's token cost def update_cost(self, prompt_tokens, completion_tokens, model): """ @@ -41,12 +42,13 @@ class CostManager(BaseModel): """ self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens - if model not in TOKEN_COSTS: + if model not in self.token_costs: logger.warning(f"Model {model} not found in TOKEN_COSTS.") return cost = ( - prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"] + prompt_tokens * self.token_costs[model]["prompt"] + + completion_tokens * self.token_costs[model]["completion"] ) / 1000 self.total_cost += cost logger.info( diff --git a/metagpt/utils/repair_llm_raw_output.py b/metagpt/utils/repair_llm_raw_output.py index 06484f71d..b8756e8c6 100644 --- a/metagpt/utils/repair_llm_raw_output.py +++ b/metagpt/utils/repair_llm_raw_output.py @@ -119,6 +119,7 @@ def repair_json_format(output: str) -> str: logger.info(f"repair_json_format: {'}]'}") elif 
output.startswith("{") and output.endswith("]"): output = output[:-1] + "}" + # remove comments trailing a JSON value in the output string; they may start with # or with // arr = output.split("\n") new_arr = [] @@ -208,6 +209,17 @@ def repair_invalid_json(output: str, error: str) -> str: elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line: # problem, `"""` or `'''` without `,` new_line = f",{line}" + elif col_no - 1 >= 0 and rline[col_no - 1] in ['"', "'"]: + # backslash problem like \" in the output + char = rline[col_no - 1] + nearest_char_idx = rline[col_no:].find(char) + new_line = ( + rline[: col_no - 1] + + "\\" + + rline[col_no - 1 : col_no + nearest_char_idx] + + "\\" + + rline[col_no + nearest_char_idx :] + ) elif '",' not in line and "," not in line and '"' not in line: new_line = f'{line}",' elif not line.endswith(","): diff --git a/metagpt/utils/token_counter.py b/metagpt/utils/token_counter.py index 65f5fe76f..167a1d755 100644 --- a/metagpt/utils/token_counter.py +++ b/metagpt/utils/token_counter.py @@ -38,6 +38,88 @@ TOKEN_COSTS = { } + +""" +QianFan Token Price https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7#tokens%E5%90%8E%E4%BB%98%E8%B4%B9 +Because QianFan offers multiple pricing strategies, we standardize on `Tokens post-payment` as the statistical method.
+""" +QIANFAN_MODEL_TOKEN_COSTS = { + "ERNIE-Bot-4": {"prompt": 0.017, "completion": 0.017}, + "ERNIE-Bot-8k": {"prompt": 0.0034, "completion": 0.0067}, + "ERNIE-Bot": {"prompt": 0.0017, "completion": 0.0017}, + "ERNIE-Bot-turbo": {"prompt": 0.0011, "completion": 0.0011}, + "EB-turbo-AppBuilder": {"prompt": 0.0011, "completion": 0.0011}, + "ERNIE-Speed": {"prompt": 0.00056, "completion": 0.0011}, + "BLOOMZ-7B": {"prompt": 0.00056, "completion": 0.00056}, + "Llama-2-7B-Chat": {"prompt": 0.00056, "completion": 0.00056}, + "Llama-2-13B-Chat": {"prompt": 0.00084, "completion": 0.00084}, + "Llama-2-70B-Chat": {"prompt": 0.0049, "completion": 0.0049}, + "ChatGLM2-6B-32K": {"prompt": 0.00056, "completion": 0.00056}, + "AquilaChat-7B": {"prompt": 0.00056, "completion": 0.00056}, + "Mixtral-8x7B-Instruct": {"prompt": 0.0049, "completion": 0.0049}, + "SQLCoder-7B": {"prompt": 0.00056, "completion": 0.00056}, + "CodeLlama-7B-Instruct": {"prompt": 0.00056, "completion": 0.00056}, + "XuanYuan-70B-Chat-4bit": {"prompt": 0.0049, "completion": 0.0049}, + "Qianfan-BLOOMZ-7B-compressed": {"prompt": 0.00056, "completion": 0.00056}, + "Qianfan-Chinese-Llama-2-7B": {"prompt": 0.00056, "completion": 0.00056}, + "Qianfan-Chinese-Llama-2-13B": {"prompt": 0.00084, "completion": 0.00084}, + "ChatLaw": {"prompt": 0.0011, "completion": 0.0011}, + "Yi-34B-Chat": {"prompt": 0.0, "completion": 0.0}, +} + +QIANFAN_ENDPOINT_TOKEN_COSTS = { + "completions_pro": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-4"], + "ernie_bot_8k": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-8k"], + "completions": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot"], + "eb-instant": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Bot-turbo"], + "ai_apaas": QIANFAN_MODEL_TOKEN_COSTS["EB-turbo-AppBuilder"], + "ernie_speed": QIANFAN_MODEL_TOKEN_COSTS["ERNIE-Speed"], + "bloomz_7b1": QIANFAN_MODEL_TOKEN_COSTS["BLOOMZ-7B"], + "llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-7B-Chat"], + "llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Llama-2-13B-Chat"], + "llama_2_70b": 
QIANFAN_MODEL_TOKEN_COSTS["Llama-2-70B-Chat"], + "chatglm2_6b_32k": QIANFAN_MODEL_TOKEN_COSTS["ChatGLM2-6B-32K"], + "aquilachat_7b": QIANFAN_MODEL_TOKEN_COSTS["AquilaChat-7B"], + "mixtral_8x7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["Mixtral-8x7B-Instruct"], + "sqlcoder_7b": QIANFAN_MODEL_TOKEN_COSTS["SQLCoder-7B"], + "codellama_7b_instruct": QIANFAN_MODEL_TOKEN_COSTS["CodeLlama-7B-Instruct"], + "xuanyuan_70b_chat": QIANFAN_MODEL_TOKEN_COSTS["XuanYuan-70B-Chat-4bit"], + "qianfan_bloomz_7b_compressed": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-BLOOMZ-7B-compressed"], + "qianfan_chinese_llama_2_7b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-7B"], + "qianfan_chinese_llama_2_13b": QIANFAN_MODEL_TOKEN_COSTS["Qianfan-Chinese-Llama-2-13B"], + "chatlaw": QIANFAN_MODEL_TOKEN_COSTS["ChatLaw"], + "yi_34b_chat": QIANFAN_MODEL_TOKEN_COSTS["Yi-34B-Chat"], +} + +""" +DashScope Token price https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing +Each model has its own pricing detail page. Note that some models are free for a limited time.
+""" +DASHSCOPE_TOKEN_COSTS = { + "qwen-turbo": {"prompt": 0.0011, "completion": 0.0011}, + "qwen-plus": {"prompt": 0.0028, "completion": 0.0028}, + "qwen-max": {"prompt": 0.0, "completion": 0.0}, + "qwen-max-1201": {"prompt": 0.0, "completion": 0.0}, + "qwen-max-longcontext": {"prompt": 0.0, "completion": 0.0}, + "llama2-7b-chat-v2": {"prompt": 0.0, "completion": 0.0}, + "llama2-13b-chat-v2": {"prompt": 0.0, "completion": 0.0}, + "qwen-72b-chat": {"prompt": 0.0, "completion": 0.0}, + "qwen-14b-chat": {"prompt": 0.0011, "completion": 0.0011}, + "qwen-7b-chat": {"prompt": 0.00084, "completion": 0.00084}, + "qwen-1.8b-chat": {"prompt": 0.0, "completion": 0.0}, + "baichuan2-13b-chat-v1": {"prompt": 0.0011, "completion": 0.0011}, + "baichuan2-7b-chat-v1": {"prompt": 0.00084, "completion": 0.00084}, + "baichuan-7b-v1": {"prompt": 0.0, "completion": 0.0}, + "chatglm-6b-v2": {"prompt": 0.0011, "completion": 0.0011}, + "chatglm3-6b": {"prompt": 0.0, "completion": 0.0}, + "ziya-llama-13b-v1": {"prompt": 0.0, "completion": 0.0}, # no price page, judge it as free + "dolly-12b-v2": {"prompt": 0.0, "completion": 0.0}, + "belle-llama-13b-2m-v1": {"prompt": 0.0, "completion": 0.0}, + "moss-moon-003-sft-v1": {"prompt": 0.0, "completion": 0.0}, + "chatyuan-large-v2": {"prompt": 0.0, "completion": 0.0}, + "billa-7b-sft-v1": {"prompt": 0.0, "completion": 0.0}, +} + + TOKEN_MAX = { "gpt-3.5-turbo": 4096, "gpt-3.5-turbo-0301": 4096, diff --git a/requirements.txt b/requirements.txt index 4ea6dc5d2..cf50cf255 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ typer==0.9.0 # godot==0.1.1 # google_api_python_client==2.93.0 # Used by search_engine.py lancedb==0.4.0 -langchain==0.0.352 +langchain==0.1.8 loguru==0.6.0 meilisearch==0.21.0 numpy>=1.24.3,<1.25.0 @@ -27,7 +27,7 @@ python_docx==0.8.11 PyYAML==6.0.1 # sentence_transformers==2.2.2 setuptools==65.6.3 -tenacity==8.2.2 +tenacity==8.2.3 tiktoken==0.5.2 tqdm==4.65.0 #unstructured[local-inference] @@ -54,7 +54,7 @@ 
rich==13.6.0 nbclient==0.9.0 nbformat==5.9.2 ipython==8.17.2 -ipykernel==6.27.0 +ipykernel==6.27.1 scikit_learn==1.3.2 typing-extensions==4.9.0 socksio~=1.0.0 @@ -68,3 +68,5 @@ anytree ipywidgets==8.1.1 Pillow imap_tools==1.5.0 # Used by metagpt/tools/libs/email_login.py +qianfan==0.3.2 +dashscope==1.14.1 diff --git a/tests/metagpt/document_store/test_faiss_store.py b/tests/metagpt/document_store/test_faiss_store.py index 7e2979bd4..397ba6ce5 100644 --- a/tests/metagpt/document_store/test_faiss_store.py +++ b/tests/metagpt/document_store/test_faiss_store.py @@ -6,6 +6,9 @@ @File : test_faiss_store.py """ +from typing import Optional + +import numpy as np import pytest from metagpt.const import EXAMPLE_PATH @@ -14,8 +17,17 @@ from metagpt.logs import logger from metagpt.roles import Sales +def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]: + num = len(texts) + embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim + embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0)) + return embeds + + @pytest.mark.asyncio -async def test_search_json(): +async def test_search_json(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + store = FaissStore(EXAMPLE_PATH / "example.json") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" @@ -24,7 +36,9 @@ async def test_search_json(): @pytest.mark.asyncio -async def test_search_xlsx(): +async def test_search_xlsx(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + store = FaissStore(EXAMPLE_PATH / "example.xlsx") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" 
@@ -33,7 +47,9 @@ async def test_search_xlsx(): @pytest.mark.asyncio -async def test_write(): +async def test_write(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question") _faiss_store = store.write() assert _faiss_store.docstore diff --git a/tests/metagpt/memory/mock_text_embed.py b/tests/metagpt/memory/mock_text_embed.py new file mode 100644 index 000000000..897c7cf10 --- /dev/null +++ b/tests/metagpt/memory/mock_text_embed.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Optional + +import numpy as np + +dim = 1536 # openai embedding dim + +text_embed_arr = [ + {"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below + {"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])}, + {"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])}, + {"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])}, + { + "text": "The user has requested the creation of a command-line interface (CLI) snake game", + "embed": np.zeros(shape=[1, dim]), + }, + {"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])}, + { + "text": "Incorporate basic features of a snake game such as scoring and increasing difficulty", + "embed": np.ones(shape=[1, dim]), + }, +] + +text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)} + + +def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]: + idx = text_idx_dict.get(texts[0]) + embed = text_embed_arr[idx].get("embed") + return embed diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index 5c71ddd13..f7e652758 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ 
b/tests/metagpt/memory/test_longterm_memory.py @@ -4,20 +4,22 @@ @Desc : unittest of `metagpt/memory/longterm_memory.py` """ -import os import pytest from metagpt.actions import UserRequirement -from metagpt.config2 import config from metagpt.memory.longterm_memory import LongTermMemory from metagpt.roles.role import RoleContext from metagpt.schema import Message - -os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key) +from tests.metagpt.memory.mock_text_embed import ( + mock_openai_embed_documents, + text_embed_arr, +) -def test_ltm_search(): +def test_ltm_search(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + role_id = "UTUserLtm(Product Manager)" from metagpt.environment import Environment @@ -27,20 +29,20 @@ def test_ltm_search(): ltm = LongTermMemory() ltm.recover_memory(role_id, rc) - idea = "Write a cli snake game" + idea = text_embed_arr[0].get("text", "Write a cli snake game") message = Message(role="User", content=idea, cause_by=UserRequirement) news = ltm.find_news([message]) assert len(news) == 1 ltm.add(message) - sim_idea = "Write a game of cli snake" + sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) news = ltm.find_news([sim_message]) assert len(news) == 0 ltm.add(sim_message) - new_idea = "Write a 2048 web game" + new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) news = ltm.find_news([new_message]) assert len(news) == 1 @@ -56,7 +58,7 @@ def test_ltm_search(): news = ltm_new.find_news([sim_message]) assert len(news) == 0 - new_idea = "Write a Battle City" + new_idea = text_embed_arr[3].get("text", "Write a Battle City") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) news = ltm_new.find_news([new_message]) assert len(news) == 1 
diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index e82a82fc8..28a73276b 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -4,23 +4,25 @@ @Desc : the unittests of metagpt/memory/memory_storage.py """ -import os import shutil from pathlib import Path from typing import List from metagpt.actions import UserRequirement, WritePRD from metagpt.actions.action_node import ActionNode -from metagpt.config2 import config from metagpt.const import DATA_PATH from metagpt.memory.memory_storage import MemoryStorage from metagpt.schema import Message - -os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key) +from tests.metagpt.memory.mock_text_embed import ( + mock_openai_embed_documents, + text_embed_arr, +) -def test_idea_message(): - idea = "Write a cli snake game" +def test_idea_message(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + + idea = text_embed_arr[0].get("text", "Write a cli snake game") role_id = "UTUser1(Product Manager)" message = Message(role="User", content=idea, cause_by=UserRequirement) @@ -33,12 +35,12 @@ def test_idea_message(): memory_storage.add(message) assert memory_storage.is_initialized is True - sim_idea = "Write a game of cli snake" + sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake") sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement) new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] - new_idea = "Write a 2048 web game" + new_idea = text_embed_arr[2].get("text", "Write a 2048 web game") new_message = Message(role="User", content=new_idea, cause_by=UserRequirement) new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content @@ -47,13 +49,17 @@ def test_idea_message(): assert 
memory_storage.is_initialized is False -def test_actionout_message(): +def test_actionout_message(mocker): + mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents) + out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]} ic_obj = ActionNode.create_model_class("prd", out_mapping) role_id = "UTUser2(Architect)" - content = "The user has requested the creation of a command-line interface (CLI) snake game" + content = text_embed_arr[4].get( + "text", "The user has requested the creation of a command-line interface (CLI) snake game" + ) message = Message( content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD ) # WritePRD as test action @@ -67,12 +73,14 @@ def test_actionout_message(): memory_storage.add(message) assert memory_storage.is_initialized is True - sim_conent = "The request is command-line interface (CLI) snake game" + sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game") sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = memory_storage.search_dissimilar(sim_message) assert len(new_messages) == 0 # similar, return [] - new_conent = "Incorporate basic features of a snake game such as scoring and increasing difficulty" + new_conent = text_embed_arr[6].get( + "text", "Incorporate basic features of a snake game such as scoring and increasing difficulty" + ) new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD) new_messages = memory_storage.search_dissimilar(new_message) assert new_messages[0].content == message.content diff --git a/tests/metagpt/provider/mock_llm_config.py b/tests/metagpt/provider/mock_llm_config.py index e2f626a6a..e75acf68f 100644 --- a/tests/metagpt/provider/mock_llm_config.py +++ 
b/tests/metagpt/provider/mock_llm_config.py @@ -42,3 +42,17 @@ mock_llm_config_zhipu = LLMConfig( model="mock_zhipu_model", proxy="http://localhost:8080", ) + + +mock_llm_config_spark = LLMConfig( + api_type="spark", + app_id="xxx", + api_key="xxx", + api_secret="xxx", + domain="generalv2", + base_url="wss://spark-api.xf-yun.com/v3.1/chat", +) + +mock_llm_config_qianfan = LLMConfig(api_type="qianfan", access_key="xxx", secret_key="xxx", model="ERNIE-Bot-turbo") + +mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model="qwen-max") diff --git a/tests/metagpt/provider/req_resp_const.py b/tests/metagpt/provider/req_resp_const.py new file mode 100644 index 000000000..802962013 --- /dev/null +++ b/tests/metagpt/provider/req_resp_const.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : default request & response data for provider unittest + + +from dashscope.api_entities.dashscope_response import ( + DashScopeAPIResponse, + GenerationOutput, + GenerationResponse, + GenerationUsage, +) +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta +from openai.types.completion_usage import CompletionUsage +from qianfan.resources.typing import QfResponse + +from metagpt.provider.base_llm import BaseLLM + +prompt = "who are you?" 
+messages = [{"role": "user", "content": prompt}] + +resp_cont_tmpl = "I'm {name}" +default_resp_cont = resp_cont_tmpl.format(name="GPT") + + +# part of whole ChatCompletion of openai like structure +def get_part_chat_completion(name: str) -> dict: + part_chat_completion = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": resp_cont_tmpl.format(name=name), + }, + "finish_reason": "stop", + } + ], + "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, + } + return part_chat_completion + + +def get_openai_chat_completion(name: str) -> ChatCompletion: + openai_chat_completion = ChatCompletion( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="xx/xxx", + object="chat.completion", + created=1703300855, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_cont_tmpl.format(name=name)), + logprobs=None, + ) + ], + usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), + ) + return openai_chat_completion + + +def get_openai_chat_completion_chunk(name: str, usage_as_dict: bool = False) -> ChatCompletionChunk: + usage = CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202) + usage = usage if not usage_as_dict else usage.model_dump() + openai_chat_completion_chunk = ChatCompletionChunk( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="xx/xxx", + object="chat.completion.chunk", + created=1703300855, + choices=[ + AChoice( + delta=ChoiceDelta(role="assistant", content=resp_cont_tmpl.format(name=name)), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], + usage=usage, + ) + return openai_chat_completion_chunk + + +# For gemini +gemini_messages = [{"role": "user", "parts": prompt}] + + +# For QianFan +qf_jsonbody_dict = { + "id": "as-4v1h587fyv", + "object": "chat.completion", + "created": 1695021339, + "result": "", + "is_truncated": False, + "need_clear_history": False, + "usage": {"prompt_tokens": 7, 
"completion_tokens": 15, "total_tokens": 22}, +} + + +def get_qianfan_response(name: str) -> QfResponse: + qf_jsonbody_dict["result"] = resp_cont_tmpl.format(name=name) + return QfResponse(code=200, body=qf_jsonbody_dict) + + +# For DashScope +def get_dashscope_response(name: str) -> GenerationResponse: + return GenerationResponse.from_api_response( + DashScopeAPIResponse( + status_code=200, + output=GenerationOutput( + **{ + "text": "", + "finish_reason": "", + "choices": [ + { + "finish_reason": "stop", + "message": {"role": "assistant", "content": resp_cont_tmpl.format(name=name)}, + } + ], + } + ), + usage=GenerationUsage(**{"input_tokens": 12, "output_tokens": 98, "total_tokens": 110}), + ) + ) + + +# For llm general chat functions call +async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[dict], resp_cont: str): + resp = await llm.aask(prompt, stream=False) + assert resp == resp_cont + + resp = await llm.aask(prompt) + assert resp == resp_cont + + resp = await llm.acompletion_text(messages, stream=False) + assert resp == resp_cont + + resp = await llm.acompletion_text(messages, stream=True) + assert resp == resp_cont diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py index 6962ab064..93cfd7dbc 100644 --- a/tests/metagpt/provider/test_anthropic_api.py +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -8,25 +8,25 @@ from anthropic.resources.completions import Completion from metagpt.provider.anthropic_api import Claude2 from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.req_resp_const import prompt, resp_cont_tmpl -prompt = "who are you" -resp = "I'am Claude2" +resp_cont = resp_cont_tmpl.format(name="Claude") def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: - return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") + 
return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion") async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: - return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") + return Completion(id="xx", completion=resp_cont, model="claude-2", stop_reason="stop_sequence", type="completion") def test_claude2_ask(mocker): mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create) - assert resp == Claude2(mock_llm_config).ask(prompt) + assert resp_cont == Claude2(mock_llm_config).ask(prompt) @pytest.mark.asyncio async def test_claude2_aask(mocker): mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create) - assert resp == await Claude2(mock_llm_config).aask(prompt) + assert resp_cont == await Claude2(mock_llm_config).aask(prompt) diff --git a/tests/metagpt/provider/test_base_llm.py b/tests/metagpt/provider/test_base_llm.py index cc781f78a..cf44343bc 100644 --- a/tests/metagpt/provider/test_base_llm.py +++ b/tests/metagpt/provider/test_base_llm.py @@ -11,21 +11,13 @@ import pytest from metagpt.configs.llm_config import LLMConfig from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message +from tests.metagpt.provider.req_resp_const import ( + default_resp_cont, + get_part_chat_completion, + prompt, +) -default_chat_resp = { - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": "I'am GPT", - }, - "finish_reason": "stop", - } - ] -} -prompt_msg = "who are you" -resp_content = default_chat_resp["choices"][0]["message"]["content"] +name = "GPT" class MockBaseLLM(BaseLLM): @@ -33,16 +25,13 @@ class MockBaseLLM(BaseLLM): pass def completion(self, messages: list[dict], timeout=3): - return default_chat_resp + return get_part_chat_completion(name) async def 
acompletion(self, messages: list[dict], timeout=3): - return default_chat_resp + return get_part_chat_completion(name) async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: - return resp_content - - async def close(self): - return default_chat_resp + return default_resp_cont def test_base_llm(): @@ -86,25 +75,25 @@ def test_base_llm(): choice_text = base_llm.get_choice_text(openai_funccall_resp) assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"] - # resp = base_llm.ask(prompt_msg) - # assert resp == resp_content + # resp = base_llm.ask(prompt) + # assert resp == default_resp_cont - # resp = base_llm.ask_batch([prompt_msg]) - # assert resp == resp_content + # resp = base_llm.ask_batch([prompt]) + # assert resp == default_resp_cont - # resp = base_llm.ask_code([prompt_msg]) - # assert resp == resp_content + # resp = base_llm.ask_code([prompt]) + # assert resp == default_resp_cont @pytest.mark.asyncio async def test_async_base_llm(): base_llm = MockBaseLLM() - resp = await base_llm.aask(prompt_msg) - assert resp == resp_content + resp = await base_llm.aask(prompt) + assert resp == default_resp_cont - resp = await base_llm.aask_batch([prompt_msg]) - assert resp == resp_content + resp = await base_llm.aask_batch([prompt]) + assert resp == default_resp_cont - # resp = await base_llm.aask_code([prompt_msg]) - # assert resp == resp_content + # resp = await base_llm.aask_code([prompt]) + # assert resp == default_resp_cont diff --git a/tests/metagpt/provider/test_dashscope_api.py b/tests/metagpt/provider/test_dashscope_api.py new file mode 100644 index 000000000..a6dd8f247 --- /dev/null +++ b/tests/metagpt/provider/test_dashscope_api.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of DashScopeLLM + +from typing import AsyncGenerator, Union + +import pytest +from dashscope.api_entities.dashscope_response import GenerationResponse + +from metagpt.provider.dashscope_api 
import DashScopeLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config_dashscope +from tests.metagpt.provider.req_resp_const import ( + get_dashscope_response, + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) + +name = "qwen-max" +resp_cont = resp_cont_tmpl.format(name=name) + + +@classmethod +def mock_dashscope_call( + cls, + messages: list[dict], + model: str, + api_key: str, + result_format: str, + incremental_output: bool = True, + stream: bool = False, +) -> GenerationResponse: + return get_dashscope_response(name) + + +@classmethod +async def mock_dashscope_acall( + cls, + messages: list[dict], + model: str, + api_key: str, + result_format: str, + incremental_output: bool = True, + stream: bool = False, +) -> Union[AsyncGenerator[GenerationResponse, None], GenerationResponse]: + resps = [get_dashscope_response(name)] + + if stream: + + async def aresp_iterator(resps: list[GenerationResponse]): + for resp in resps: + yield resp + + return aresp_iterator(resps) + else: + return resps[0] + + +@pytest.mark.asyncio +async def test_dashscope_acompletion(mocker): + mocker.patch("dashscope.aigc.generation.Generation.call", mock_dashscope_call) + mocker.patch("metagpt.provider.dashscope_api.AGeneration.acall", mock_dashscope_acall) + + dashscope_llm = DashScopeLLM(mock_llm_config_dashscope) + + resp = dashscope_llm.completion(messages) + assert resp.choices[0]["message"]["content"] == resp_cont + + resp = await dashscope_llm.acompletion(messages) + assert resp.choices[0]["message"]["content"] == resp_cont + + await llm_general_chat_funcs_test(dashscope_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_fireworks_llm.py b/tests/metagpt/provider/test_fireworks_llm.py index 66b55e5b2..1c1aa9caa 100644 --- a/tests/metagpt/provider/test_fireworks_llm.py +++ b/tests/metagpt/provider/test_fireworks_llm.py @@ -3,14 +3,7 @@ # @Desc : the unittest of fireworks api import pytest -from 
openai.types.chat.chat_completion import ( - ChatCompletion, - ChatCompletionMessage, - Choice, -) from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.chat.chat_completion_chunk import Choice as AChoice -from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.provider.fireworks_api import ( @@ -20,42 +13,19 @@ from metagpt.provider.fireworks_api import ( ) from metagpt.utils.cost_manager import Costs from tests.metagpt.provider.mock_llm_config import mock_llm_config - -resp_content = "I'm fireworks" -default_resp = ChatCompletion( - id="cmpl-a6652c1bb181caae8dd19ad8", - model="accounts/fireworks/models/llama-v2-13b-chat", - object="chat.completion", - created=1703300855, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage(role="assistant", content=resp_content), - logprobs=None, - ) - ], - usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), +from tests.metagpt.provider.req_resp_const import ( + get_openai_chat_completion, + get_openai_chat_completion_chunk, + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, ) -default_resp_chunk = ChatCompletionChunk( - id=default_resp.id, - model=default_resp.model, - object="chat.completion.chunk", - created=default_resp.created, - choices=[ - AChoice( - delta=ChoiceDelta(content=resp_content, role="assistant"), - finish_reason="stop", - index=0, - logprobs=None, - ) - ], - usage=dict(default_resp.usage), -) - -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] +name = "fireworks" +resp_cont = resp_cont_tmpl.format(name=name) +default_resp = get_openai_chat_completion(name) +default_resp_chunk = get_openai_chat_completion_chunk(name, usage_as_dict=True) def test_fireworks_costmanager(): @@ -88,27 +58,17 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) async def 
test_fireworks_acompletion(mocker): mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - fireworks_gpt = FireworksLLM(mock_llm_config) - fireworks_gpt.model = "llama-v2-13b-chat" + fireworks_llm = FireworksLLM(mock_llm_config) + fireworks_llm.model = "llama-v2-13b-chat" - fireworks_gpt._update_costs( + fireworks_llm._update_costs( usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000) ) - assert fireworks_gpt.get_costs() == Costs( + assert fireworks_llm.get_costs() == Costs( total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0 ) - resp = await fireworks_gpt.acompletion(messages) - assert resp.choices[0].message.content in resp_content + resp = await fireworks_llm.acompletion(messages) + assert resp.choices[0].message.content in resp_cont - resp = await fireworks_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content - - resp = await fireworks_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content - - resp = await fireworks_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content - - resp = await fireworks_gpt.aask(prompt_msg) - assert resp == resp_content + await llm_general_chat_funcs_test(fireworks_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 404ae1e90..50c15ee19 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -11,6 +11,12 @@ from google.generativeai.types import content_types from metagpt.provider.google_gemini_api import GeminiLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.req_resp_const import ( + gemini_messages, + llm_general_chat_funcs_test, + prompt, + resp_cont_tmpl, +) @dataclass @@ -18,10 +24,8 @@ class MockGeminiResponse(ABC): text: str -prompt_msg 
= "who are you" -messages = [{"role": "user", "parts": prompt_msg}] -resp_content = "I'm gemini from google" -default_resp = MockGeminiResponse(text=resp_content) +resp_cont = resp_cont_tmpl.format(name="gemini") +default_resp = MockGeminiResponse(text=resp_cont) def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: @@ -60,28 +64,18 @@ async def test_gemini_acompletion(mocker): mock_gemini_generate_content_async, ) - gemini_gpt = GeminiLLM(mock_llm_config) + gemini_llm = GeminiLLM(mock_llm_config) - assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]} - assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]} + assert gemini_llm._user_msg(prompt) == {"role": "user", "parts": [prompt]} + assert gemini_llm._assistant_msg(prompt) == {"role": "model", "parts": [prompt]} - usage = gemini_gpt.get_usage(messages, resp_content) + usage = gemini_llm.get_usage(gemini_messages, resp_cont) assert usage == {"prompt_tokens": 20, "completion_tokens": 20} - resp = gemini_gpt.completion(messages) + resp = gemini_llm.completion(gemini_messages) assert resp == default_resp - resp = await gemini_gpt.acompletion(messages) + resp = await gemini_llm.acompletion(gemini_messages) assert resp.text == default_resp.text - resp = await gemini_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content - - resp = await gemini_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content - - resp = await gemini_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content - - resp = await gemini_gpt.aask(prompt_msg) - assert resp == resp_content + await llm_general_chat_funcs_test(gemini_llm, prompt, gemini_messages, resp_cont) diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 5d942598b..af2e929e9 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ 
-9,12 +9,15 @@ import pytest from metagpt.provider.ollama_api import OllamaLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.req_resp_const import ( + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] - -resp_content = "I'm ollama" -default_resp = {"message": {"role": "assistant", "content": resp_content}} +resp_cont = resp_cont_tmpl.format(name="ollama") +default_resp = {"message": {"role": "assistant", "content": resp_cont}} async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: @@ -41,19 +44,12 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An async def test_gemini_acompletion(mocker): mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest) - ollama_gpt = OllamaLLM(mock_llm_config) + ollama_llm = OllamaLLM(mock_llm_config) - resp = await ollama_gpt.acompletion(messages) + resp = await ollama_llm.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] - resp = await ollama_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content + resp = await ollama_llm.aask(prompt, stream=False) + assert resp == resp_cont - resp = await ollama_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content - - resp = await ollama_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content - - resp = await ollama_gpt.aask(prompt_msg) - assert resp == resp_content + await llm_general_chat_funcs_test(ollama_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index fc7b510cc..aa38b95a6 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -3,53 +3,26 @@ # @Desc : import pytest -from 
openai.types.chat.chat_completion import ( - ChatCompletion, - ChatCompletionMessage, - Choice, -) from openai.types.chat.chat_completion_chunk import ChatCompletionChunk -from openai.types.chat.chat_completion_chunk import Choice as AChoice -from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.provider.open_llm_api import OpenLLM -from metagpt.utils.cost_manager import Costs +from metagpt.utils.cost_manager import CostManager, Costs from tests.metagpt.provider.mock_llm_config import mock_llm_config - -resp_content = "I'm llama2" -default_resp = ChatCompletion( - id="cmpl-a6652c1bb181caae8dd19ad8", - model="llama-v2-13b-chat", - object="chat.completion", - created=1703302755, - choices=[ - Choice( - finish_reason="stop", - index=0, - message=ChatCompletionMessage(role="assistant", content=resp_content), - logprobs=None, - ) - ], +from tests.metagpt.provider.req_resp_const import ( + get_openai_chat_completion, + get_openai_chat_completion_chunk, + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, ) -default_resp_chunk = ChatCompletionChunk( - id=default_resp.id, - model=default_resp.model, - object="chat.completion.chunk", - created=default_resp.created, - choices=[ - AChoice( - delta=ChoiceDelta(content=resp_content, role="assistant"), - finish_reason="stop", - index=0, - logprobs=None, - ) - ], -) +name = "llama2-7b" +resp_cont = resp_cont_tmpl.format(name=name) +default_resp = get_openai_chat_completion(name) -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] +default_resp_chunk = get_openai_chat_completion_chunk(name) async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: @@ -68,25 +41,16 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) async def test_openllm_acompletion(mocker): 
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - openllm_gpt = OpenLLM(mock_llm_config) - openllm_gpt.model = "llama-v2-13b-chat" + openllm_llm = OpenLLM(mock_llm_config) + openllm_llm.model = "llama-v2-13b-chat" - openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) - assert openllm_gpt.get_costs() == Costs( + openllm_llm.cost_manager = CostManager() + openllm_llm._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) + assert openllm_llm.get_costs() == Costs( total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 ) - resp = await openllm_gpt.acompletion(messages) - assert resp.choices[0].message.content in resp_content + resp = await openllm_llm.acompletion(messages) + assert resp.choices[0].message.content in resp_cont - resp = await openllm_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content - - resp = await openllm_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content - - resp = await openllm_gpt.acompletion_text(messages, stream=True) - assert resp == resp_content - - resp = await openllm_gpt.aask(prompt_msg) - assert resp == resp_content + await llm_general_chat_funcs_test(openllm_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_qianfan_api.py b/tests/metagpt/provider/test_qianfan_api.py new file mode 100644 index 000000000..28341425c --- /dev/null +++ b/tests/metagpt/provider/test_qianfan_api.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of qianfan api + +from typing import AsyncIterator, Union + +import pytest +from qianfan.resources.typing import JsonBody, QfResponse + +from metagpt.provider.qianfan_api import QianFanLLM +from tests.metagpt.provider.mock_llm_config import mock_llm_config_qianfan +from tests.metagpt.provider.req_resp_const import ( + 
get_qianfan_response, + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) + +name = "ERNIE-Bot-turbo" +resp_cont = resp_cont_tmpl.format(name=name) + + +def mock_qianfan_do(self, messages: list[dict], model: str, stream: bool = False, system: str = None) -> QfResponse: + return get_qianfan_response(name=name) + + +async def mock_qianfan_ado( + self, messages: list[dict], model: str, stream: bool = True, system: str = None +) -> Union[QfResponse, AsyncIterator[QfResponse]]: + resps = [get_qianfan_response(name=name)] + if stream: + + async def aresp_iterator(resps: list[JsonBody]): + for resp in resps: + yield resp + + return aresp_iterator(resps) + else: + return resps[0] + + +@pytest.mark.asyncio +async def test_qianfan_acompletion(mocker): + mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.do", mock_qianfan_do) + mocker.patch("qianfan.resources.llm.chat_completion.ChatCompletion.ado", mock_qianfan_ado) + + qianfan_llm = QianFanLLM(mock_llm_config_qianfan) + + resp = qianfan_llm.completion(messages) + assert resp.get("result") == resp_cont + + resp = await qianfan_llm.acompletion(messages) + assert resp.get("result") == resp_cont + + await llm_general_chat_funcs_test(qianfan_llm, prompt, messages, resp_cont) diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index f5a6f66fd..9c278267d 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,12 +4,18 @@ import pytest -from metagpt.config2 import Config from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM -from tests.metagpt.provider.mock_llm_config import mock_llm_config +from tests.metagpt.provider.mock_llm_config import ( + mock_llm_config, + mock_llm_config_spark, +) +from tests.metagpt.provider.req_resp_const import ( + llm_general_chat_funcs_test, + prompt, + resp_cont_tmpl, +) -prompt_msg = "who are you" -resp_content = "I'm Spark" +resp_cont = 
resp_cont_tmpl.format(name="Spark") class MockWebSocketApp(object): @@ -23,7 +29,7 @@ class MockWebSocketApp(object): def test_get_msg_from_web(mocker): mocker.patch("websocket.WebSocketApp", MockWebSocketApp) - get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config) + get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config) assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain" ret = get_msg_from_web.run() @@ -31,34 +37,26 @@ def test_get_msg_from_web(mocker): def mock_spark_get_msg_from_web_run(self) -> str: - return resp_content + return resp_cont @pytest.mark.asyncio -async def test_spark_aask(): - llm = SparkLLM(Config.from_home("spark.yaml").llm) +async def test_spark_aask(mocker): + mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) + + llm = SparkLLM(mock_llm_config_spark) resp = await llm.aask("Hello!") - print(resp) + assert resp == resp_cont @pytest.mark.asyncio async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) - spark_gpt = SparkLLM(mock_llm_config) + spark_llm = SparkLLM(mock_llm_config) - resp = await spark_gpt.acompletion([]) - assert resp == resp_content + resp = await spark_llm.acompletion([]) + assert resp == resp_cont - resp = await spark_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content - - resp = await spark_gpt.acompletion_text([], stream=False) - assert resp == resp_content - - resp = await spark_gpt.acompletion_text([], stream=True) - assert resp == resp_content - - resp = await spark_gpt.aask(prompt_msg) - assert resp == resp_content + await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index ad2ececa2..c51010122 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -6,22 
+6,24 @@ import pytest from metagpt.provider.zhipuai_api import ZhiPuAILLM from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu +from tests.metagpt.provider.req_resp_const import ( + get_part_chat_completion, + llm_general_chat_funcs_test, + messages, + prompt, + resp_cont_tmpl, +) -prompt_msg = "who are you" -messages = [{"role": "user", "content": prompt_msg}] - -resp_content = "I'm chatglm-turbo" -default_resp = { - "choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}], - "usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41}, -} +name = "ChatGLM-4" +resp_cont = resp_cont_tmpl.format(name=name) +default_resp = get_part_chat_completion(name) async def mock_zhipuai_acreate_stream(self, **kwargs): class MockResponse(object): async def _aread(self): class Iterator(object): - events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}] + events = [{"choices": [{"index": 0, "delta": {"content": resp_cont, "role": "assistant"}}]}] async def __aiter__(self): for event in self.events: @@ -46,22 +48,12 @@ async def test_zhipuai_acompletion(mocker): mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate) mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream) - zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu) + zhipu_llm = ZhiPuAILLM(mock_llm_config_zhipu) - resp = await zhipu_gpt.acompletion(messages) - assert resp["choices"][0]["message"]["content"] == resp_content + resp = await zhipu_llm.acompletion(messages) + assert resp["choices"][0]["message"]["content"] == resp_cont - resp = await zhipu_gpt.aask(prompt_msg, stream=False) - assert resp == resp_content - - resp = await zhipu_gpt.acompletion_text(messages, stream=False) - assert resp == resp_content - - resp = await zhipu_gpt.acompletion_text(messages, stream=True) - assert resp == 
resp_content - - resp = await zhipu_gpt.aask(prompt_msg) - assert resp == resp_content + await llm_general_chat_funcs_test(zhipu_llm, prompt, messages, resp_cont) def test_zhipuai_proxy(): diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index e28423b91..7a29ea3ee 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -211,6 +211,11 @@ value output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1") assert output == target_output + raw_output = '{"key": "url "http" \\"https\\" "}' + target_output = '{"key": "url \\"http\\" \\"https\\" "}' + output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 1 column 15 (char 14)") + assert output == target_output + def test_retry_parse_json_text(): from metagpt.utils.repair_llm_raw_output import retry_parse_json_text diff --git a/tests/spark.yaml b/tests/spark.yaml deleted file mode 100644 index a5bbd98bd..000000000 --- a/tests/spark.yaml +++ /dev/null @@ -1,7 +0,0 @@ -llm: - api_type: "spark" - app_id: "xxx" - api_key: "xxx" - api_secret: "xxx" - domain: "generalv2" - base_url: "wss://spark-api.xf-yun.com/v3.1/chat" \ No newline at end of file