mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 17:26:22 +02:00
Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev
This commit is contained in:
commit
ef1bc01c99
33 changed files with 714 additions and 140 deletions
143
tests/metagpt/actions/mock_json.py
Normal file
143
tests/metagpt/actions/mock_json.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/24 20:32
|
||||
@Author : alexanderwu
|
||||
@File : mock_json.py
|
||||
"""
|
||||
|
||||
PRD = {
|
||||
"Language": "zh_cn",
|
||||
"Programming Language": "Python",
|
||||
"Original Requirements": "写一个简单的cli贪吃蛇",
|
||||
"Project Name": "cli_snake",
|
||||
"Product Goals": ["创建一个简单易用的贪吃蛇游戏", "提供良好的用户体验", "支持不同难度级别"],
|
||||
"User Stories": [
|
||||
"作为玩家,我希望能够选择不同的难度级别",
|
||||
"作为玩家,我希望在每局游戏结束后能够看到我的得分",
|
||||
"作为玩家,我希望在输掉游戏后能够重新开始",
|
||||
"作为玩家,我希望看到简洁美观的界面",
|
||||
"作为玩家,我希望能够在手机上玩游戏",
|
||||
],
|
||||
"Competitive Analysis": ["贪吃蛇游戏A:界面简单,缺乏响应式特性", "贪吃蛇游戏B:美观且响应式的界面,显示最高得分", "贪吃蛇游戏C:响应式界面,显示最高得分,但有很多广告"],
|
||||
"Competitive Quadrant Chart": 'quadrantChart\n title "Reach and engagement of campaigns"\n x-axis "Low Reach" --> "High Reach"\n y-axis "Low Engagement" --> "High Engagement"\n quadrant-1 "We should expand"\n quadrant-2 "Need to promote"\n quadrant-3 "Re-evaluate"\n quadrant-4 "May be improved"\n "Game A": [0.3, 0.6]\n "Game B": [0.45, 0.23]\n "Game C": [0.57, 0.69]\n "Game D": [0.78, 0.34]\n "Game E": [0.40, 0.34]\n "Game F": [0.35, 0.78]\n "Our Target Product": [0.5, 0.6]',
|
||||
"Requirement Analysis": "",
|
||||
"Requirement Pool": [["P0", "主要代码..."], ["P0", "游戏算法..."]],
|
||||
"UI Design draft": "基本功能描述,简单的风格和布局。",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
DESIGN = {
|
||||
"Implementation approach": "我们将使用Python编程语言,并选择合适的开源框架来实现贪吃蛇游戏。我们将分析需求中的难点,并选择合适的开源框架来简化开发流程。",
|
||||
"File list": ["main.py", "game.py"],
|
||||
"Data structures and interfaces": "\nclassDiagram\n class Game {\n -int width\n -int height\n -int score\n -int speed\n -List<Point> snake\n -Point food\n +__init__(width: int, height: int, speed: int)\n +start_game()\n +change_direction(direction: str)\n +game_over()\n +update_snake()\n +update_food()\n +check_collision()\n }\n class Point {\n -int x\n -int y\n +__init__(x: int, y: int)\n }\n Game --> Point\n",
|
||||
"Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n M->>G: start_game()\n M->>G: change_direction(direction)\n G->>G: update_snake()\n G->>G: update_food()\n G->>G: check_collision()\n G-->>G: game_over()\n",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
TASKS = {
|
||||
"Required Python packages": ["pygame==2.0.1"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Logic Analysis": [
|
||||
["game.py", "Contains Game class and related functions for game logic"],
|
||||
["main.py", "Contains the main function, imports Game class from game.py"],
|
||||
],
|
||||
"Task list": ["game.py", "main.py"],
|
||||
"Full API spec": "",
|
||||
"Shared Knowledge": "'game.py' contains functions shared across the project.",
|
||||
"Anything UNCLEAR": "",
|
||||
}
|
||||
|
||||
|
||||
FILE_GAME = """## game.py
|
||||
|
||||
import pygame
|
||||
import random
|
||||
|
||||
class Point:
|
||||
def __init__(self, x: int, y: int):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
class Game:
|
||||
def __init__(self, width: int, height: int, speed: int):
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.score = 0
|
||||
self.speed = speed
|
||||
self.snake = [Point(width // 2, height // 2)]
|
||||
self.food = self._create_food()
|
||||
|
||||
def start_game(self):
|
||||
pygame.init()
|
||||
self._display = pygame.display.set_mode((self.width, self.height))
|
||||
pygame.display.set_caption('Snake Game')
|
||||
self._clock = pygame.time.Clock()
|
||||
self._running = True
|
||||
|
||||
while self._running:
|
||||
self._handle_events()
|
||||
self._update_snake()
|
||||
self._update_food()
|
||||
self._check_collision()
|
||||
self._draw_screen()
|
||||
self._clock.tick(self.speed)
|
||||
|
||||
def change_direction(self, direction: str):
|
||||
# Update the direction of the snake based on user input
|
||||
pass
|
||||
|
||||
def game_over(self):
|
||||
# Display game over message and handle game over logic
|
||||
pass
|
||||
|
||||
def _create_food(self) -> Point:
|
||||
# Create and return a new food Point
|
||||
return Point(random.randint(0, self.width - 1), random.randint(0, self.height - 1))
|
||||
|
||||
def _handle_events(self):
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.QUIT:
|
||||
self._running = False
|
||||
|
||||
def _update_snake(self):
|
||||
# Update the position of the snake based on its direction
|
||||
pass
|
||||
|
||||
def _update_food(self):
|
||||
# Update the position of the food if the snake eats it
|
||||
pass
|
||||
|
||||
def _check_collision(self):
|
||||
# Check for collision between the snake and the walls or itself
|
||||
pass
|
||||
|
||||
def _draw_screen(self):
|
||||
self._display.fill((0, 0, 0)) # Clear the screen
|
||||
# Draw the snake and food on the screen
|
||||
pygame.display.update()
|
||||
|
||||
if __name__ == "__main__":
|
||||
game = Game(800, 600, 15)
|
||||
game.start_game()
|
||||
"""
|
||||
|
||||
FILE_GAME_CR_1 = """## Code Review: game.py
|
||||
1. Yes, the code is implemented as per the requirements. It initializes the game with the specified width, height, and speed, and starts the game loop.
|
||||
2. No, the logic for handling events and updating the snake, food, and collision is not implemented. To correct this, we need to implement the logic for handling events, updating the snake and food positions, and checking for collisions.
|
||||
3. Yes, the existing code follows the "Data structures and interfaces" by defining the Game and Point classes with the specified attributes and methods.
|
||||
4. No, several functions such as change_direction, game_over, _update_snake, _update_food, and _check_collision are not implemented. These functions need to be implemented to complete the game logic.
|
||||
5. Yes, all necessary pre-dependencies have been imported. The required pygame package is imported at the beginning of the file.
|
||||
6. No, methods from other files are not being reused as there are no other files being imported or referenced in the current code.
|
||||
|
||||
## Actions
|
||||
1. Implement the logic for handling events, updating the snake and food positions, and checking for collisions within the Game class.
|
||||
2. Implement the change_direction and game_over methods to handle user input and game over logic.
|
||||
3. Implement the _update_snake method to update the position of the snake based on its direction.
|
||||
4. Implement the _update_food method to update the position of the food if the snake eats it.
|
||||
5. Implement the _check_collision method to check for collision between the snake and the walls or itself.
|
||||
|
||||
## Code Review Result
|
||||
LBTM"""
|
||||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/18 23:51
|
||||
@Author : alexanderwu
|
||||
@File : mock.py
|
||||
@File : mock_markdown.py
|
||||
"""
|
||||
|
||||
PRD_SAMPLE = """## Original Requirements
|
||||
|
|
@ -5,9 +5,16 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_action.py
|
||||
"""
|
||||
from metagpt.actions import Action, WritePRD, WriteTest
|
||||
from metagpt.actions import Action, ActionType, WritePRD, WriteTest
|
||||
|
||||
|
||||
def test_action_repr():
|
||||
actions = [Action(), WriteTest(), WritePRD()]
|
||||
assert "WriteTest" in str(actions)
|
||||
|
||||
|
||||
def test_action_type():
|
||||
assert ActionType.WRITE_PRD.value == WritePRD
|
||||
assert ActionType.WRITE_TEST.value == WriteTest
|
||||
assert ActionType.WRITE_PRD.name == "WRITE_PRD"
|
||||
assert ActionType.WRITE_TEST.name == "WRITE_TEST"
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_action_node.py
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
|
|
@ -29,7 +31,7 @@ async def test_debate_two_roles():
|
|||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
|
||||
assert "BidenSay" in history
|
||||
assert "Biden" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -39,7 +41,7 @@ async def test_debate_one_role_in_env():
|
|||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden])
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
|
||||
assert "Debate" in history
|
||||
assert "Biden" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -86,3 +88,47 @@ async def test_action_node_two_layer():
|
|||
assert node_b in root.children.values()
|
||||
json_template = root.compile(context="123", schema="json", mode="auto")
|
||||
assert "i-a" in json_template
|
||||
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."],
|
||||
],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
"Full API spec": (str, ...),
|
||||
"Logic Analysis": (List[Tuple[str, str]], ...),
|
||||
"Task list": (List[str], ...),
|
||||
"Shared Knowledge": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.dict()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
|
|||
|
|
@ -1,53 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
"""
|
||||
@Time : 2023/7/11 10:49
|
||||
@Author : chengmaoyu
|
||||
@File : test_action_output
|
||||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
"Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n',
|
||||
"Logic Analysis": [
|
||||
["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."],
|
||||
["game.py", "Contains the Game and Snake classes. Handles the game logic."],
|
||||
["static/js/script.js", "Handles user interactions and updates the game UI."],
|
||||
["static/css/styles.css", "Defines the styles for the game UI."],
|
||||
["templates/index.html", "The main page of the web application. Displays the game UI."],
|
||||
],
|
||||
"Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"],
|
||||
"Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n",
|
||||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
"Full API spec": (str, ...),
|
||||
"Logic Analysis": (List[Tuple[str, str]], ...),
|
||||
"Task list": (List[str], ...),
|
||||
"Shared Knowledge": (str, ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.dict()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_model_class()
|
||||
test_create_model_class_with_mapping()
|
||||
|
|
@ -1,6 +1,13 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.clone_function import CloneFunction, run_function_code
|
||||
from metagpt.actions.clone_function import (
|
||||
CloneFunction,
|
||||
run_function_code,
|
||||
run_function_script,
|
||||
)
|
||||
|
||||
source_code = """
|
||||
import pandas as pd
|
||||
|
|
@ -55,3 +62,40 @@ async def test_clone_function():
|
|||
assert not msg
|
||||
expected_df = get_expected_res()
|
||||
assert df.equals(expected_df)
|
||||
|
||||
|
||||
def test_run_function_script():
|
||||
# 创建一个临时文件并写入脚本内容
|
||||
script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n"""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file:
|
||||
temp_file.write(script_content)
|
||||
temp_file_path = temp_file.name
|
||||
|
||||
invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n"""
|
||||
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file:
|
||||
error_temp_file.write(invalid_script_content)
|
||||
error_temp_file_path = error_temp_file.name
|
||||
|
||||
try:
|
||||
# 正常情况下运行脚本
|
||||
result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2)
|
||||
assert result == 3
|
||||
|
||||
# 不存在的脚本路径
|
||||
with pytest.raises(FileNotFoundError):
|
||||
run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2)
|
||||
|
||||
# 无效的脚本内容
|
||||
result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2)
|
||||
assert not result
|
||||
assert "SyntaxError" in traceback
|
||||
|
||||
# 函数调用失败的情况
|
||||
result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2)
|
||||
assert not result
|
||||
assert "KeyError" in traceback
|
||||
|
||||
finally:
|
||||
# 删除临时文件
|
||||
if os.path.exists(temp_file_path):
|
||||
os.remove(temp_file_path)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.const import PRDS_FILE_REPO
|
|||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock import PRD_SAMPLE
|
||||
from tests.metagpt.actions.mock_markdown import PRD_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from metagpt.actions.write_code import WriteCode
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
29
tests/metagpt/provider/test_anthropic_api.py
Normal file
29
tests/metagpt/provider/test_anthropic_api.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Claude2
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
|
||||
|
||||
def mock_llm_ask(self, msg: str) -> str:
|
||||
return resp
|
||||
|
||||
|
||||
async def mock_llm_aask(self, msg: str) -> str:
|
||||
return resp
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask)
|
||||
assert resp == Claude2().ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask)
|
||||
assert resp == await Claude2().aask(prompt)
|
||||
|
|
@ -6,10 +6,106 @@
|
|||
@File : test_base_gpt_api.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Message
|
||||
|
||||
default_chat_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "I'am GPT",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
prompt_msg = "who are you"
|
||||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
def test_message():
|
||||
message = Message(role="user", content="wtf")
|
||||
|
||||
class MockBaseGPTAPI(BaseGPTAPI):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
return default_chat_resp
|
||||
|
||||
|
||||
def test_base_gpt_api():
|
||||
message = Message(role="user", content="hello")
|
||||
assert "role" in message.to_dict()
|
||||
assert "user" in str(message)
|
||||
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
msg_prompt = base_gpt_api.messages_to_prompt([message])
|
||||
assert msg_prompt == "user: hello"
|
||||
|
||||
msg_dict = base_gpt_api.messages_to_dict([message])
|
||||
assert msg_dict == [{"role": "user", "content": "hello"}]
|
||||
|
||||
openai_funccall_resp = {
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "test",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
]
|
||||
}
|
||||
func: dict = base_gpt_api.get_choice_function(openai_funccall_resp)
|
||||
assert func == {
|
||||
"name": "execute",
|
||||
"arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}',
|
||||
}
|
||||
|
||||
func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp)
|
||||
assert func_args == {"language": "python", "code": "print('Hello, World!')"}
|
||||
|
||||
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
resp = base_gpt_api.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_base_gpt_api():
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
|
||||
resp = await base_gpt_api.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_gpt_api.aask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -10,41 +10,82 @@ from openai.types.chat.chat_completion import (
|
|||
)
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireWorksGPTAPI,
|
||||
)
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="accounts/fireworks/models/llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703300855,
|
||||
choices=[
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks"))
|
||||
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content))
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion:
|
||||
def test_fireworks_costmanager():
|
||||
cost_manager = FireworksCostManager()
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return default_resp.choices[0].message.content
|
||||
|
||||
|
||||
def test_fireworks_completion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
|
||||
resp = FireWorksGPTAPI().completion(messages)
|
||||
assert "fireworks" in resp.choices[0].message.content
|
||||
resp = fireworks_gpt.completion(messages)
|
||||
assert resp.choices[0].message.content == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> ChatCompletion:
|
||||
return default_resp
|
||||
resp = fireworks_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
|
||||
resp = await FireWorksGPTAPI().acompletion(messages, stream=False)
|
||||
resp = await fireworks_gpt.acompletion(messages, stream=False)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
assert "fireworks" in resp.choices[0].message.content
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
20
tests/metagpt/provider/test_general_api_requestor.py
Normal file
20
tests/metagpt/provider/test_general_api_requestor.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of APIRequestor
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
|
||||
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
|
||||
|
||||
|
||||
def test_api_requestor():
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_api_requestor():
|
||||
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
|
@ -9,33 +9,62 @@ import pytest
|
|||
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
|
||||
messages = [{"role": "user", "parts": "who are you"}]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockGeminiResponse(ABC):
|
||||
text: str
|
||||
|
||||
|
||||
default_resp = MockGeminiResponse(text="I'm gemini from google")
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "parts": prompt_msg}]
|
||||
resp_content = "I'm gemini from google"
|
||||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(
|
||||
self, messgaes: list[dict], stream: bool = False, timeout: int = 60
|
||||
) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask)
|
||||
resp = GeminiGPTAPI().completion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp.text == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
resp = gemini_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask)
|
||||
resp = await GeminiGPTAPI().acompletion(messages)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await gemini_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
38
tests/metagpt/provider/test_human_provider.py
Normal file
38
tests/metagpt/provider/test_human_provider.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of HumanProvider
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
resp_content = "test"
|
||||
|
||||
|
||||
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
||||
return mock_llm_ask(msg)
|
||||
|
||||
|
||||
def test_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
assert resp_content == human_provider.ask(None)
|
||||
|
||||
assert not human_provider.completion(messages=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
resp = await human_provider.aask(None)
|
||||
assert resp_content == resp
|
||||
|
||||
resp = await human_provider.acompletion([])
|
||||
assert not resp
|
||||
|
|
@ -5,11 +5,11 @@
|
|||
@Author : mashenquan
|
||||
@File : test_metagpt_llm_api.py
|
||||
"""
|
||||
from metagpt.provider.metagpt_llm_api import MetaGPTLLMAPI
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
|
||||
|
||||
def test_metagpt():
|
||||
llm = MetaGPTLLMAPI()
|
||||
llm = MetaGPTAPI()
|
||||
assert llm
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,30 +4,58 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
|
||||
CONFIG.ollama_api_base = "http://xxx"
|
||||
|
||||
|
||||
default_resp = {"message": {"role": "assisant", "content": "I'm ollama"}}
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask)
|
||||
resp = OllamaGPTAPI().completion(messages)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
resp = ollama_gpt.completion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict]) -> dict:
|
||||
return default_resp
|
||||
resp = ollama_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask)
|
||||
resp = await OllamaGPTAPI().acompletion(messages)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await ollama_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -85,14 +85,23 @@ def test_ask_code_list_str():
|
|||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other")
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
|
@ -101,7 +110,9 @@ class TestOpenAI:
|
|||
def config_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
|
@ -110,8 +121,10 @@ class TestOpenAI:
|
|||
def config_azure_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
|
@ -129,8 +142,8 @@ class TestOpenAI:
|
|||
instance = OpenAIGPTAPI()
|
||||
instance.config = config_azure
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
|
||||
assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"}
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert "http_client" not in kwargs
|
||||
assert "http_client" not in async_kwargs
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,51 @@
|
|||
from metagpt.logs import logger
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkGPTAPI
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def test_message():
|
||||
llm = SparkAPI()
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str:
|
||||
return resp_content
|
||||
|
||||
logger.info(llm.ask('只回答"收到了"这三个字。'))
|
||||
result = llm.ask("写一篇五百字的日记")
|
||||
logger.info(result)
|
||||
assert len(result) > 100
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_spark_completion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = spark_gpt.completion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = spark_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = await spark_gpt.acompletion([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.acompletion_text([], stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -4,34 +4,62 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
|
||||
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}}
|
||||
CONFIG.zhipuai_api_key = "xxx"
|
||||
|
||||
messages = [{"role": "user", "content": "who are you"}]
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}}
|
||||
|
||||
|
||||
def mock_llm_ask(self, messages: list[dict]) -> dict:
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = ZhiPuAIGPTAPI().completion(messages)
|
||||
resp = zhipu_gpt.completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict:
|
||||
return default_resp
|
||||
resp = zhipu_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False)
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
assert resp["code"] == 200
|
||||
assert "chatglm-turbo" in resp["data"]["choices"][0]["content"]
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
"""
|
||||
@Time : 2023/5/12 13:05
|
||||
@Author : alexanderwu
|
||||
@File : mock.py
|
||||
@File : mock_markdown.py
|
||||
"""
|
||||
from metagpt.actions import UserRequirement, WriteDesign, WritePRD, WriteTasks
|
||||
from metagpt.schema import Message
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue