From fa70a70f53b9c2a55625a3eb56029e11647c4e37 Mon Sep 17 00:00:00 2001 From: geekan Date: Sun, 24 Dec 2023 20:51:50 +0800 Subject: [PATCH 1/5] add json mock --- metagpt/config.py | 1 + metagpt/utils/common.py | 2 +- tests/metagpt/actions/mock_json.py | 143 ++++++++++++++++++ .../actions/{mock.py => mock_markdown.py} | 2 +- tests/metagpt/actions/test_design_api.py | 2 +- tests/metagpt/actions/test_write_code.py | 2 +- tests/metagpt/roles/mock.py | 2 +- 7 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 tests/metagpt/actions/mock_json.py rename tests/metagpt/actions/{mock.py => mock_markdown.py} (99%) diff --git a/metagpt/config.py b/metagpt/config.py index 9a452cab0..0109f4b1d 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -81,6 +81,7 @@ class Config(metaclass=Singleton): logger.debug("Config loading done.") def get_default_llm_provider_enum(self) -> LLMProviderEnum: + """Get first valid LLM provider enum""" mappings = { LLMProviderEnum.OPENAI: bool( self._is_valid_llm_key(self.OPENAI_API_KEY) and not self.OPENAI_API_TYPE and self.OPENAI_API_MODEL diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 382523083..09cc092fc 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -48,7 +48,7 @@ def check_cmd_exists(command) -> int: return result -def require_python_version(req_version: tuple[int]) -> bool: +def require_python_version(req_version: Tuple) -> bool: if not (2 <= len(req_version) <= 3): raise ValueError("req_version should be (3, 9) or (3, 10, 13)") return True if sys.version_info > req_version else False diff --git a/tests/metagpt/actions/mock_json.py b/tests/metagpt/actions/mock_json.py new file mode 100644 index 000000000..875d74d3c --- /dev/null +++ b/tests/metagpt/actions/mock_json.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2023/12/24 20:32 +@Author : alexanderwu +@File : mock_json.py +""" + +PRD = { + "Language": "zh_cn", + "Programming Language": "Python", + "Original Requirements": "写一个简单的cli贪吃蛇", + "Project Name": "cli_snake", + "Product Goals": ["创建一个简单易用的贪吃蛇游戏", "提供良好的用户体验", "支持不同难度级别"], + "User Stories": [ + "作为玩家,我希望能够选择不同的难度级别", + "作为玩家,我希望在每局游戏结束后能够看到我的得分", + "作为玩家,我希望在输掉游戏后能够重新开始", + "作为玩家,我希望看到简洁美观的界面", + "作为玩家,我希望能够在手机上玩游戏", + ], + "Competitive Analysis": ["贪吃蛇游戏A:界面简单,缺乏响应式特性", "贪吃蛇游戏B:美观且响应式的界面,显示最高得分", "贪吃蛇游戏C:响应式界面,显示最高得分,但有很多广告"], + "Competitive Quadrant Chart": 'quadrantChart\n title "Reach and engagement of campaigns"\n x-axis "Low Reach" --> "High Reach"\n y-axis "Low Engagement" --> "High Engagement"\n quadrant-1 "We should expand"\n quadrant-2 "Need to promote"\n quadrant-3 "Re-evaluate"\n quadrant-4 "May be improved"\n "Game A": [0.3, 0.6]\n "Game B": [0.45, 0.23]\n "Game C": [0.57, 0.69]\n "Game D": [0.78, 0.34]\n "Game E": [0.40, 0.34]\n "Game F": [0.35, 0.78]\n "Our Target Product": [0.5, 0.6]', + "Requirement Analysis": "", + "Requirement Pool": [["P0", "主要代码..."], ["P0", "游戏算法..."]], + "UI Design draft": "基本功能描述,简单的风格和布局。", + "Anything UNCLEAR": "", +} + + +DESIGN = { + "Implementation approach": "我们将使用Python编程语言,并选择合适的开源框架来实现贪吃蛇游戏。我们将分析需求中的难点,并选择合适的开源框架来简化开发流程。", + "File list": ["main.py", "game.py"], + "Data structures and interfaces": "\nclassDiagram\n class Game {\n -int width\n -int height\n -int score\n -int speed\n -List snake\n -Point food\n +__init__(width: int, height: int, speed: int)\n +start_game()\n +change_direction(direction: str)\n +game_over()\n +update_snake()\n +update_food()\n +check_collision()\n }\n class Point {\n -int x\n -int y\n +__init__(x: int, y: int)\n }\n Game --> Point\n", + "Program call flow": "\nsequenceDiagram\n participant M as Main\n participant G as Game\n M->>G: start_game()\n M->>G: change_direction(direction)\n G->>G: update_snake()\n G->>G: update_food()\n G->>G: check_collision()\n G-->>G: game_over()\n", + "Anything UNCLEAR": "", +} + + +TASKS = { + "Required Python packages": ["pygame==2.0.1"], + "Required Other language third-party packages": ["No third-party dependencies required"], + "Logic Analysis": [ + ["game.py", "Contains Game class and related functions for game logic"], + ["main.py", "Contains the main function, imports Game class from game.py"], + ], + "Task list": ["game.py", "main.py"], + "Full API spec": "", + "Shared Knowledge": "'game.py' contains functions shared across the project.", + "Anything UNCLEAR": "", +} + + +FILE_GAME = """## game.py + +import pygame +import random + +class Point: + def __init__(self, x: int, y: int): + self.x = x + self.y = y + +class Game: + def __init__(self, width: int, height: int, speed: int): + self.width = width + self.height = height + self.score = 0 + self.speed = speed + self.snake = [Point(width // 2, height // 2)] + self.food = self._create_food() + + def start_game(self): + pygame.init() + self._display = pygame.display.set_mode((self.width, self.height)) + pygame.display.set_caption('Snake Game') + self._clock = pygame.time.Clock() + self._running = True + + while self._running: + self._handle_events() + self._update_snake() + self._update_food() + self._check_collision() + self._draw_screen() + self._clock.tick(self.speed) + + def change_direction(self, direction: str): + # Update the direction of the snake based on user input + pass + + def game_over(self): + # Display game over message and handle game over logic + pass + + def _create_food(self) -> Point: + # Create and return a new food Point + return Point(random.randint(0, self.width - 1), random.randint(0, self.height - 1)) + + def _handle_events(self): + for event in pygame.event.get(): + if event.type == pygame.QUIT: + self._running = False + + def _update_snake(self): + # Update the position of the snake based on its direction + pass + + def _update_food(self): + # Update the position of the food if the snake eats it + pass + + def _check_collision(self): + # Check for collision between the snake and the walls or itself + pass + + def _draw_screen(self): + self._display.fill((0, 0, 0)) # Clear the screen + # Draw the snake and food on the screen + pygame.display.update() + +if __name__ == "__main__": + game = Game(800, 600, 15) + game.start_game() +""" + +FILE_GAME_CR_1 = """## Code Review: game.py +1. Yes, the code is implemented as per the requirements. It initializes the game with the specified width, height, and speed, and starts the game loop. +2. No, the logic for handling events and updating the snake, food, and collision is not implemented. To correct this, we need to implement the logic for handling events, updating the snake and food positions, and checking for collisions. +3. Yes, the existing code follows the "Data structures and interfaces" by defining the Game and Point classes with the specified attributes and methods. +4. No, several functions such as change_direction, game_over, _update_snake, _update_food, and _check_collision are not implemented. These functions need to be implemented to complete the game logic. +5. Yes, all necessary pre-dependencies have been imported. The required pygame package is imported at the beginning of the file. +6. No, methods from other files are not being reused as there are no other files being imported or referenced in the current code. + +## Actions +1. Implement the logic for handling events, updating the snake and food positions, and checking for collisions within the Game class. +2. Implement the change_direction and game_over methods to handle user input and game over logic. +3. Implement the _update_snake method to update the position of the snake based on its direction. +4. Implement the _update_food method to update the position of the food if the snake eats it. +5. Implement the _check_collision method to check for collision between the snake and the walls or itself. + +## Code Review Result +LBTM""" diff --git a/tests/metagpt/actions/mock.py b/tests/metagpt/actions/mock_markdown.py similarity index 99% rename from tests/metagpt/actions/mock.py rename to tests/metagpt/actions/mock_markdown.py index f6602a82b..c5d984146 100644 --- a/tests/metagpt/actions/mock.py +++ b/tests/metagpt/actions/mock_markdown.py @@ -3,7 +3,7 @@ """ @Time : 2023/5/18 23:51 @Author : alexanderwu -@File : mock.py +@File : mock_markdown.py """ PRD_SAMPLE = """## Original Requirements diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index e90707d1a..fe98b9120 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -13,7 +13,7 @@ from metagpt.const import PRDS_FILE_REPO from metagpt.logs import logger from metagpt.schema import Message from metagpt.utils.file_repository import FileRepository -from tests.metagpt.actions.mock import PRD_SAMPLE +from tests.metagpt.actions.mock_markdown import PRD_SAMPLE @pytest.mark.asyncio diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index 73f3a6dcf..ba7cb6f2d 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -12,7 +12,7 @@ from metagpt.actions.write_code import WriteCode from metagpt.logs import logger from metagpt.provider.openai_api import OpenAIGPTAPI as LLM from metagpt.schema import CodingContext, Document -from tests.metagpt.actions.mock import TASKS_2, WRITE_CODE_PROMPT_SAMPLE +from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE @pytest.mark.asyncio diff --git a/tests/metagpt/roles/mock.py b/tests/metagpt/roles/mock.py index 75f6b3b43..2ea036bb7 100644 --- a/tests/metagpt/roles/mock.py +++ b/tests/metagpt/roles/mock.py @@ -3,7 +3,7 @@ """ @Time : 2023/5/12 13:05 @Author : alexanderwu -@File : mock.py +@File : mock_markdown.py """ from metagpt.actions import UserRequirement, WriteDesign, WritePRD, WriteTasks from metagpt.schema import Message From a41ed7df66498c7e3c1016d9aac01818e1aca08a Mon Sep 17 00:00:00 2001 From: geekan Date: Mon, 25 Dec 2023 16:41:09 +0800 Subject: [PATCH 2/5] refine test code --- tests/metagpt/actions/test_action.py | 9 +++- tests/metagpt/actions/test_action_node.py | 50 ++++++++++++++++++- tests/metagpt/actions/test_action_output.py | 53 --------------------- 3 files changed, 56 insertions(+), 56 deletions(-) delete mode 100644 tests/metagpt/actions/test_action_output.py diff --git a/tests/metagpt/actions/test_action.py b/tests/metagpt/actions/test_action.py index 9775630cc..f750b5e6f 100644 --- a/tests/metagpt/actions/test_action.py +++ b/tests/metagpt/actions/test_action.py @@ -5,9 +5,16 @@ @Author : alexanderwu @File : test_action.py """ -from metagpt.actions import Action, WritePRD, WriteTest +from metagpt.actions import Action, ActionType, WritePRD, WriteTest def test_action_repr(): actions = [Action(), WriteTest(), WritePRD()] assert "WriteTest" in str(actions) + + +def test_action_type(): + assert ActionType.WRITE_PRD.value == WritePRD + assert ActionType.WRITE_TEST.value == WriteTest + assert ActionType.WRITE_PRD.name == "WRITE_PRD" + assert ActionType.WRITE_TEST.name == "WRITE_TEST" diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index 5bafe2bf2..92d8a1bbc 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -5,6 +5,8 @@ @Author : alexanderwu @File : test_action_node.py """ +from typing import List, Tuple + import pytest from metagpt.actions import Action @@ -29,7 +31,7 @@ async def test_debate_two_roles(): team = Team(investment=10.0, env=env, roles=[biden, trump]) history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3) - assert "BidenSay" in history + assert "Biden" in history @pytest.mark.asyncio @@ -39,7 +41,7 @@ async def test_debate_one_role_in_env(): env = Environment(desc="US election live broadcast") team = Team(investment=10.0, env=env, roles=[biden]) history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3) - assert "Debate" in history + assert "Biden" in history @pytest.mark.asyncio @@ -86,3 +88,47 @@ async def test_action_node_two_layer(): assert node_b in root.children.values() json_template = root.compile(context="123", schema="json", mode="auto") assert "i-a" in json_template + + +t_dict = { + "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', + "Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n', + "Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n', + "Logic Analysis": [ + ["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."], + ["game.py", "Contains the Game and Snake classes. Handles the game logic."], + ["static/js/script.js", "Handles user interactions and updates the game UI."], + ["static/css/styles.css", "Defines the styles for the game UI."], + ["templates/index.html", "The main page of the web application. Displays the game UI."], + ], + "Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"], + "Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n", + "Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?", +} + +WRITE_TASKS_OUTPUT_MAPPING = { + "Required Python third-party packages": (str, ...), + "Required Other language third-party packages": (str, ...), + "Full API spec": (str, ...), + "Logic Analysis": (List[Tuple[str, str]], ...), + "Task list": (List[str], ...), + "Shared Knowledge": (str, ...), + "Anything UNCLEAR": (str, ...), +} + + +def test_create_model_class(): + test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) + assert test_class.__name__ == "test_class" + + +def test_create_model_class_with_mapping(): + t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) + t1 = t(**t_dict) + value = t1.dict()["Task list"] + assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] + + +if __name__ == "__main__": + test_create_model_class() + test_create_model_class_with_mapping() diff --git a/tests/metagpt/actions/test_action_output.py b/tests/metagpt/actions/test_action_output.py deleted file mode 100644 index f1765cb03..000000000 --- a/tests/metagpt/actions/test_action_output.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python -# coding: utf-8 -""" -@Time : 2023/7/11 10:49 -@Author : chengmaoyu -@File : test_action_output -""" -from typing import List, Tuple - -from metagpt.actions.action_node import ActionNode - -t_dict = { - "Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n', - "Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n', - "Full API spec": '"""\nopenapi: 3.0.0\ninfo:\n title: Web Snake Game API\n version: 1.0.0\npaths:\n /game:\n get:\n summary: Get the current game state\n responses:\n \'200\':\n description: A JSON object of the game state\n post:\n summary: Send a command to the game\n requestBody:\n required: true\n content:\n application/json:\n schema:\n type: object\n properties:\n command:\n type: string\n responses:\n \'200\':\n description: A JSON object of the updated game state\n"""\n', - "Logic Analysis": [ - ["app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."], - ["game.py", "Contains the Game and Snake classes. Handles the game logic."], - ["static/js/script.js", "Handles user interactions and updates the game UI."], - ["static/css/styles.css", "Defines the styles for the game UI."], - ["templates/index.html", "The main page of the web application. Displays the game UI."], - ], - "Task list": ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"], - "Shared Knowledge": "\"\"\"\n'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.\n\n'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.\n\n'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.\n\n'static/css/styles.css' defines the styles for the game UI.\n\n'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.\n\"\"\"\n", - "Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?", -} - -WRITE_TASKS_OUTPUT_MAPPING = { - "Required Python third-party packages": (str, ...), - "Required Other language third-party packages": (str, ...), - "Full API spec": (str, ...), - "Logic Analysis": (List[Tuple[str, str]], ...), - "Task list": (List[str], ...), - "Shared Knowledge": (str, ...), - "Anything UNCLEAR": (str, ...), -} - - -def test_create_model_class(): - test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING) - assert test_class.__name__ == "test_class" - - -def test_create_model_class_with_mapping(): - t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING) - t1 = t(**t_dict) - value = t1.dict()["Task list"] - assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"] - - -if __name__ == "__main__": - test_create_model_class() - test_create_model_class_with_mapping() From 8a5f8b7ee0d22c8286771a1eab7e64faaf962a7f Mon Sep 17 00:00:00 2001 From: geekan Date: Mon, 25 Dec 2023 18:00:41 +0800 Subject: [PATCH 3/5] add #TOTEST flag --- metagpt/actions/search_and_summarize.py | 1 + metagpt/actions/skill_action.py | 1 + metagpt/actions/summarize_code.py | 1 + metagpt/actions/talk_action.py | 1 + 4 files changed, 4 insertions(+) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 25af21795..9fd392a5c 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -105,6 +105,7 @@ You are a member of a professional butler team and will provide helpful suggesti """ +# TOTEST class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None diff --git a/metagpt/actions/skill_action.py b/metagpt/actions/skill_action.py index c95a83cbb..292202294 100644 --- a/metagpt/actions/skill_action.py +++ b/metagpt/actions/skill_action.py @@ -19,6 +19,7 @@ from metagpt.learn.skill_loader import Skill from metagpt.logs import logger +# TOTEST class ArgumentsParingAction(Action): skill: Skill ask: str diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 0aec15937..2d1cd4d3d 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -91,6 +91,7 @@ flowchart TB """ +# TOTEST class SummarizeCode(Action): name: str = "SummarizeCode" context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) diff --git a/metagpt/actions/talk_action.py b/metagpt/actions/talk_action.py index 3695ec5bb..1c22e86de 100644 --- a/metagpt/actions/talk_action.py +++ b/metagpt/actions/talk_action.py @@ -15,6 +15,7 @@ from metagpt.llm import LLMType from metagpt.logs import logger +# TOTEST class TalkAction(Action): def __init__(self, name: str = "", talk="", history_summary="", knowledge="", context=None, llm=None, **kwargs): context = context or {} From 454e6164fb804bba1fcc58797140e3ee15e137ab Mon Sep 17 00:00:00 2001 From: better629 Date: Mon, 25 Dec 2023 18:00:51 +0800 Subject: [PATCH 4/5] update provider unittests --- metagpt/provider/anthropic_api.py | 10 +- metagpt/provider/base_gpt_api.py | 2 +- metagpt/provider/fireworks_api.py | 4 +- metagpt/provider/google_gemini_api.py | 7 +- metagpt/provider/ollama_api.py | 7 +- metagpt/provider/spark_api.py | 11 +- metagpt/provider/zhipuai_api.py | 5 +- tests/metagpt/provider/test_anthropic_api.py | 29 +++++ tests/metagpt/provider/test_base_gpt_api.py | 100 +++++++++++++++++- tests/metagpt/provider/test_fireworks_api.py | 67 +++++++++--- .../provider/test_general_api_requestor.py | 20 ++++ .../provider/test_google_gemini_api.py | 53 +++++++--- tests/metagpt/provider/test_human_provider.py | 38 +++++++ .../metagpt/provider/test_metagpt_llm_api.py | 4 +- tests/metagpt/provider/test_ollama_api.py | 52 ++++++--- tests/metagpt/provider/test_openai.py | 19 +++- tests/metagpt/provider/test_spark_api.py | 56 ++++++++-- tests/metagpt/provider/test_zhipuai_api.py | 54 +++++++--- 18 files changed, 460 insertions(+), 78 deletions(-) create mode 100644 tests/metagpt/provider/test_anthropic_api.py create mode 100644 tests/metagpt/provider/test_general_api_requestor.py create mode 100644 tests/metagpt/provider/test_human_provider.py diff --git a/metagpt/provider/anthropic_api.py b/metagpt/provider/anthropic_api.py index f5b06c855..b9d7d9e38 100644 --- a/metagpt/provider/anthropic_api.py +++ b/metagpt/provider/anthropic_api.py @@ -7,13 +7,13 @@ """ import anthropic -from anthropic import Anthropic +from anthropic import Anthropic, AsyncAnthropic from metagpt.config import CONFIG class Claude2: - def ask(self, prompt): + def ask(self, prompt: str) -> str: client = Anthropic(api_key=CONFIG.anthropic_api_key) res = client.completions.create( @@ -23,10 +23,10 @@ class Claude2: ) return res.completion - async def aask(self, prompt): - client = Anthropic(api_key=CONFIG.anthropic_api_key) + async def aask(self, prompt: str) -> str: + aclient = AsyncAnthropic(api_key=CONFIG.anthropic_api_key) - res = client.completions.create( + res = await aclient.completions.create( model="claude-2", prompt=f"{anthropic.HUMAN_PROMPT} {prompt} {anthropic.AI_PROMPT}", max_tokens_to_sample=1000, diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index f650305e3..a5541324f 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -162,7 +162,7 @@ class BaseGPTAPI(BaseChatbot): def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" - return "\n".join([f"{i['role']}: {i['content']}" for i in messages]) + return "\n".join([f"{i.role}: {i.content}" for i in messages]) def messages_to_dict(self, messages): """objects to [{"role": "user", "content": msg}] etc.""" diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 96b7db453..55b1b6c28 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -133,7 +133,9 @@ class FireWorksGPTAPI(OpenAIGPTAPI): retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: """when streaming, print each token in place.""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index eb91cc32b..e9d3ea70d 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -79,6 +79,9 @@ class GeminiGPTAPI(BaseGPTAPI): except Exception as e: logger.error(f"google gemini updats costs failed! exp: {e}") + def close(self): + pass + def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text @@ -133,7 +136,9 @@ class GeminiGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 05bdb5a1f..7d858e769 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -57,6 +57,9 @@ class OllamaGPTAPI(BaseGPTAPI): self.model = config.ollama_api_model + def close(self): + pass + def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs @@ -144,7 +147,9 @@ class OllamaGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 484fa7956..70076bc86 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -26,16 +26,19 @@ from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.SPARK) -class SparkAPI(BaseGPTAPI): +class SparkGPTAPI(BaseGPTAPI): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") + def close(self): + pass + def ask(self, msg: str) -> str: message = [self._default_system_msg(), self._user_msg(msg)] rsp = self.completion(message) return rsp - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None) -> str: + async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str: if system_msgs: message = self._system_msgs(system_msgs) + [self._user_msg(msg)] else: @@ -47,7 +50,9 @@ class SparkAPI(BaseGPTAPI): def get_choice_text(self, rsp: dict) -> str: return rsp["payload"]["choices"]["text"][-1]["content"] - async def acompletion_text(self, messages: list[dict], stream=False) -> str: + async def acompletion_text( + self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 + ) -> str: # 不支持 logger.error("该功能禁用。") w = GetMessageFromWeb(messages) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 4a2cae51d..0d5663431 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -64,6 +64,9 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") + def close(self): + pass + def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] @@ -131,6 +134,6 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: """response in async with stream or non-stream mode""" if stream: - return await self._achat_completion_stream(messages, timeout=timeout) + return await self._achat_completion_stream(messages) resp = await self._achat_completion(messages) return self.get_choice_text(resp) diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py new file mode 100644 index 000000000..4d3de5320 --- /dev/null +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of Claude2 + +import pytest + +from metagpt.provider.anthropic_api import Claude2 + +prompt = "who are you" +resp = "I'am Claude2" + + +def mock_llm_ask(self, msg: str) -> str: + return resp + + +async def mock_llm_aask(self, msg: str) -> str: + return resp + + +def test_claude2_ask(mocker): + mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask) + assert resp == Claude2().ask(prompt) + + +@pytest.mark.asyncio +async def test_claude2_aask(mocker): + mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask) + assert resp == await Claude2().aask(prompt) diff --git a/tests/metagpt/provider/test_base_gpt_api.py b/tests/metagpt/provider/test_base_gpt_api.py index 6cfe3b02d..aaa7b64ff 100644 --- a/tests/metagpt/provider/test_base_gpt_api.py +++ b/tests/metagpt/provider/test_base_gpt_api.py @@ -6,10 +6,106 @@ @File : test_base_gpt_api.py """ +import pytest + +from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.schema import Message +default_chat_resp = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "I'am GPT", + }, + "finish_reason": "stop", + } + ] +} +prompt_msg = "who are you" +resp_content = default_chat_resp["choices"][0]["message"]["content"] -def test_message(): - message = Message(role="user", content="wtf") + +class MockBaseGPTAPI(BaseGPTAPI): + def completion(self, messages: list[dict], timeout=3): + return default_chat_resp + + async def acompletion(self, messages: list[dict], timeout=3): + return default_chat_resp + + async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + return resp_content + + async def close(self): + return default_chat_resp + + +def test_base_gpt_api(): + message = Message(role="user", content="hello") assert "role" in message.to_dict() assert "user" in str(message) + + base_gpt_api = MockBaseGPTAPI() + msg_prompt = base_gpt_api.messages_to_prompt([message]) + assert msg_prompt == "user: hello" + + msg_dict = base_gpt_api.messages_to_dict([message]) + assert msg_dict == [{"role": "user", "content": "hello"}] + + openai_funccall_resp = { + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "test", + "tool_calls": [ + { + "id": "call_Y5r6Ddr2Qc2ZrqgfwzPX5l72", + "type": "function", + "function": { + "name": "execute", + "arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}', + }, + } + ], + }, + "finish_reason": "stop", + } + ] + } + func: dict = base_gpt_api.get_choice_function(openai_funccall_resp) + assert func == { + "name": "execute", + "arguments": '{\n "language": "python",\n "code": "print(\'Hello, World!\')"\n}', + } + + func_args: dict = base_gpt_api.get_choice_function_arguments(openai_funccall_resp) + assert func_args == {"language": "python", "code": "print('Hello, World!')"} + + choice_text = base_gpt_api.get_choice_text(openai_funccall_resp) + assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"] + + resp = base_gpt_api.ask(prompt_msg) + assert resp == resp_content + + resp = base_gpt_api.ask_batch([prompt_msg]) + assert resp == resp_content + + resp = base_gpt_api.ask_code([prompt_msg]) + assert resp == resp_content + + +@pytest.mark.asyncio +async def test_async_base_gpt_api(): + base_gpt_api = MockBaseGPTAPI() + + resp = await base_gpt_api.aask(prompt_msg) + assert resp == resp_content + + resp = await base_gpt_api.aask_batch([prompt_msg]) + assert resp == resp_content + + resp = await base_gpt_api.aask_code([prompt_msg]) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index 43e45adf3..caf8b9f45 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -10,41 +10,82 @@ from openai.types.chat.chat_completion import ( ) from openai.types.completion_usage import CompletionUsage -from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.fireworks_api import ( + MODEL_GRADE_TOKEN_COSTS, + FireworksCostManager, + FireWorksGPTAPI, +) +resp_content = "I'm fireworks" default_resp = ChatCompletion( id="cmpl-a6652c1bb181caae8dd19ad8", model="accounts/fireworks/models/llama-v2-13b-chat", object="chat.completion", created=1703300855, choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content="I'm fireworks")) + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content)) ], usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) -messages = [{"role": "user", "content": "who are you"}] +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] -def mock_llm_ask(self, messages: list[dict]) -> ChatCompletion: +def test_fireworks_costmanager(): + cost_manager = FireworksCostManager() + assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("test") + assert MODEL_GRADE_TOKEN_COSTS["-1"] == cost_manager.model_grade_token_costs("xxx-81b-chat") + assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("llama-v2-13b-chat") + assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-15.5b-chat") + assert MODEL_GRADE_TOKEN_COSTS["16"] == cost_manager.model_grade_token_costs("xxx-16b-chat") + assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat") + assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat") + + +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion: return default_resp +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return default_resp.choices[0].message.content + + def test_fireworks_completion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_ask) + mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion) + fireworks_gpt = FireWorksGPTAPI() - resp = FireWorksGPTAPI().completion(messages) - assert "fireworks" in resp.choices[0].message.content + resp = fireworks_gpt.completion(messages) + assert resp.choices[0].message.content == resp_content - -async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> ChatCompletion: - return default_resp + resp = fireworks_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_aask) + mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch( + "metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + ) + fireworks_gpt = FireWorksGPTAPI() - resp = await FireWorksGPTAPI().acompletion(messages, stream=False) + resp = await fireworks_gpt.acompletion(messages, stream=False) + assert resp.choices[0].message.content in resp_content - assert "fireworks" in resp.choices[0].message.content + resp = await fireworks_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await fireworks_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await fireworks_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await fireworks_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_general_api_requestor.py b/tests/metagpt/provider/test_general_api_requestor.py new file mode 100644 index 000000000..28130fa65 --- /dev/null +++ b/tests/metagpt/provider/test_general_api_requestor.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of APIRequestor + +import pytest + +from metagpt.provider.general_api_requestor import GeneralAPIRequestor + +api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com") + + +def test_api_requestor(): + resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") + assert b"baidu" in resp + + +@pytest.mark.asyncio +async def test_async_api_requestor(): + resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu") + assert b"baidu" in resp diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 9c8cf46c0..aec7b8520 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -9,33 +9,62 @@ import pytest from metagpt.provider.google_gemini_api import GeminiGPTAPI -messages = [{"role": "user", "parts": "who are you"}] - @dataclass class MockGeminiResponse(ABC): text: str -default_resp = MockGeminiResponse(text="I'm gemini from google") +prompt_msg = "who are you" +messages = [{"role": "user", "parts": prompt_msg}] +resp_content = "I'm gemini from google" +default_resp = MockGeminiResponse(text=resp_content) -def mock_llm_ask(self, messages: list[dict]) -> MockGeminiResponse: +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse: return default_resp +async def mock_llm_acompletion( + self, messgaes: list[dict], stream: bool = False, timeout: int = 60 +) -> MockGeminiResponse: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return resp_content + + def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_ask) - resp = GeminiGPTAPI().completion(messages) - assert resp.text == default_resp.text + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion) + gemini_gpt = GeminiGPTAPI() + resp = gemini_gpt.completion(messages) + assert resp.text == resp_content - -async def mock_llm_aask(self, messgaes: list[dict]) -> MockGeminiResponse: - return default_resp + resp = gemini_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_aask) - resp = await GeminiGPTAPI().acompletion(messages) + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch( + "metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + ) + gemini_gpt = GeminiGPTAPI() + + resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text + + resp = await gemini_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await gemini_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await gemini_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await gemini_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_human_provider.py b/tests/metagpt/provider/test_human_provider.py new file mode 100644 index 000000000..caab9f15f --- /dev/null +++ b/tests/metagpt/provider/test_human_provider.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of HumanProvider + +import pytest + +from metagpt.provider.human_provider import HumanProvider + +resp_content = "test" + + +def mock_llm_ask(msg: str, timeout: int = 3) -> str: + return resp_content + + +async def mock_llm_aask(msg: str, timeout: int = 3) -> str: + return mock_llm_ask(msg) + + +def test_human_provider(mocker): + mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask) + human_provider = HumanProvider() + + assert resp_content == human_provider.ask(None) + + assert not human_provider.completion(messages=[]) + + +@pytest.mark.asyncio +async def test_async_human_provider(mocker): + mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask) + human_provider = HumanProvider() + + resp = await human_provider.aask(None) + assert resp_content == resp + + resp = await human_provider.acompletion([]) + assert not resp diff --git a/tests/metagpt/provider/test_metagpt_llm_api.py b/tests/metagpt/provider/test_metagpt_llm_api.py index 9c8356ca6..f454b08a7 100644 --- a/tests/metagpt/provider/test_metagpt_llm_api.py +++ b/tests/metagpt/provider/test_metagpt_llm_api.py @@ -5,11 +5,11 @@ @Author : mashenquan @File : test_metagpt_llm_api.py """ -from metagpt.provider.metagpt_llm_api import MetaGPTLLMAPI +from metagpt.provider.metagpt_api import MetaGPTAPI def test_metagpt(): - llm = MetaGPTLLMAPI() + llm = MetaGPTAPI() assert llm diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index 2798f5cc3..d552d9f9e 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -4,30 +4,58 @@ import pytest +from metagpt.config import CONFIG from metagpt.provider.ollama_api import OllamaGPTAPI -messages = [{"role": "user", "content": "who are you"}] +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + +resp_content = "I'm ollama" +default_resp = {"message": {"role": "assistant", "content": resp_content}} + +CONFIG.ollama_api_base = "http://xxx" -default_resp = {"message": {"role": "assisant", "content": "I'm ollama"}} - - -def mock_llm_ask(self, messages: list[dict]) -> dict: +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: return default_resp +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return resp_content + + def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_ask) - resp = OllamaGPTAPI().completion(messages) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion) + ollama_gpt = OllamaGPTAPI() + resp = ollama_gpt.completion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] - -async def mock_llm_aask(self, messgaes: list[dict]) -> dict: - return default_resp + resp = ollama_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_aask) - resp = await OllamaGPTAPI().acompletion(messages) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream) + ollama_gpt = OllamaGPTAPI() + + resp = await ollama_gpt.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] + + resp = await ollama_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await ollama_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await ollama_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await ollama_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 332d554cf..1f25951b1 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -85,14 +85,23 @@ def test_ask_code_list_str(): class TestOpenAI: @pytest.fixture def config(self): - return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other") + return Mock( + openai_api_key="test_key", + OPENAI_API_KEY="test_key", + openai_base_url="test_url", + OPENAI_BASE_URL="test_url", + openai_proxy=None, + openai_api_type="other", + ) @pytest.fixture def config_azure(self): return Mock( openai_api_key="test_key", + OPENAI_API_KEY="test_key", openai_api_version="test_version", openai_base_url="test_url", + OPENAI_BASE_URL="test_url", openai_proxy=None, openai_api_type="azure", ) @@ -101,7 +110,9 @@ class TestOpenAI: def config_proxy(self): return Mock( openai_api_key="test_key", + OPENAI_API_KEY="test_key", openai_base_url="test_url", + OPENAI_BASE_URL="test_url", openai_proxy="http://proxy.com", openai_api_type="other", ) @@ -110,8 +121,10 @@ class TestOpenAI: def config_azure_proxy(self): return Mock( openai_api_key="test_key", + OPENAI_API_KEY="test_key", openai_api_version="test_version", openai_base_url="test_url", + OPENAI_BASE_URL="test_url", openai_proxy="http://proxy.com", openai_api_type="azure", ) @@ -129,8 +142,8 @@ class TestOpenAI: instance = OpenAIGPTAPI() instance.config = config_azure kwargs, async_kwargs = instance._make_client_kwargs() - assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} - assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} assert "http_client" not in kwargs assert "http_client" not in async_kwargs diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 3b3dd67f4..61ae8cbec 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -1,11 +1,51 @@ -from metagpt.logs import logger -from metagpt.provider.spark_api import SparkAPI +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : the unittest of spark api + +import pytest + +from metagpt.provider.spark_api import SparkGPTAPI + +prompt_msg = "who are you" +resp_content = "I'm Spark" -def test_message(): - llm = SparkAPI() +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str: + return resp_content - logger.info(llm.ask('只回答"收到了"这三个字。')) - result = llm.ask("写一篇五百字的日记") - logger.info(result) - assert len(result) > 100 + +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str: + return resp_content + + +def test_spark_completion(mocker): + mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion) + spark_gpt = SparkGPTAPI() + + resp = spark_gpt.completion([]) + assert resp == resp_content + + resp = spark_gpt.ask(prompt_msg) + assert resp == resp_content + + +@pytest.mark.asyncio +async def test_spark_acompletion(mocker): + mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion) + spark_gpt = SparkGPTAPI() + + resp = await spark_gpt.acompletion([], stream=False) + assert resp == resp_content + + resp = await spark_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await spark_gpt.acompletion_text([], stream=False) + assert resp == resp_content + + resp = await spark_gpt.acompletion_text([], stream=True) + assert resp == resp_content + + resp = await spark_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 4684e8887..ec02e1b47 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -4,34 +4,62 @@ import pytest +from metagpt.config import CONFIG from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI -default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": "I'm chatglm-turbo"}]}} +CONFIG.zhipuai_api_key = "xxx" -messages = [{"role": "user", "content": "who are you"}] +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + +resp_content = "I'm chatglm-turbo" +default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}} -def mock_llm_ask(self, messages: list[dict]) -> dict: +def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: return default_resp +async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: + return default_resp + + +async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: + return resp_content + + def test_zhipuai_completion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_ask) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion) + zhipu_gpt = ZhiPuAIGPTAPI() - resp = ZhiPuAIGPTAPI().completion(messages) + resp = zhipu_gpt.completion(messages) assert resp["code"] == 200 - assert "chatglm-turbo" in resp["data"]["choices"][0]["content"] + assert resp["data"]["choices"][0]["content"] == resp_content - -async def mock_llm_aask(self, messgaes: list[dict], stream: bool = False) -> dict: - return default_resp + resp = zhipu_gpt.ask(prompt_msg) + assert resp == resp_content @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion_text", mock_llm_aask) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch( + "metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + ) + zhipu_gpt = ZhiPuAIGPTAPI() - resp = await ZhiPuAIGPTAPI().acompletion_text(messages, stream=False) + resp = await zhipu_gpt.acompletion(messages) + assert resp["data"]["choices"][0]["content"] == resp_content - assert resp["code"] == 200 - assert "chatglm-turbo" in resp["data"]["choices"][0]["content"] + resp = await zhipu_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await zhipu_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await zhipu_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await zhipu_gpt.aask(prompt_msg) + assert resp == resp_content From 2b57b88ec8364553b7995be274438daf801c799b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 25 Dec 2023 18:25:41 +0800 Subject: [PATCH 5/5] add test for run_function_script. --- tests/metagpt/actions/test_clone_function.py | 46 +++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/metagpt/actions/test_clone_function.py b/tests/metagpt/actions/test_clone_function.py index 44248eb80..93ead48bd 100644 --- a/tests/metagpt/actions/test_clone_function.py +++ b/tests/metagpt/actions/test_clone_function.py @@ -1,6 +1,13 @@ +import os +import tempfile + import pytest -from metagpt.actions.clone_function import CloneFunction, run_function_code +from metagpt.actions.clone_function import ( + CloneFunction, + run_function_code, + run_function_script, +) source_code = """ import pandas as pd @@ -55,3 +62,40 @@ async def test_clone_function(): assert not msg expected_df = get_expected_res() assert df.equals(expected_df) + + +def test_run_function_script(): + # 创建一个临时文件并写入脚本内容 + script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n""" + with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file: + temp_file.write(script_content) + temp_file_path = temp_file.name + + invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n""" + with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file: + error_temp_file.write(invalid_script_content) + error_temp_file_path = error_temp_file.name + + try: + # 正常情况下运行脚本 + result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2) + assert result == 3 + + # 不存在的脚本路径 + with pytest.raises(FileNotFoundError): + run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2) + + # 无效的脚本内容 + result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2) + assert not result + assert "SyntaxError" in traceback + + # 函数调用失败的情况 + result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2) + assert not result + assert "KeyError" in traceback + + finally: + # 删除临时文件 + if os.path.exists(temp_file_path): + os.remove(temp_file_path)