From 5394da6d37399a480c39ae6d5ebedac7169efa25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 7 Dec 2023 15:51:12 +0800 Subject: [PATCH] fixbug: azure call function --- metagpt/actions/design_api.py | 4 ++-- metagpt/actions/prepare_documents.py | 5 ++++- metagpt/provider/openai_api.py | 12 ++++++++---- metagpt/utils/git_repository.py | 5 ++++- requirements.txt | 2 +- tests/metagpt/test_gpt.py | 27 +++++++++++++++++---------- 6 files changed, 36 insertions(+), 19 deletions(-) diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index eb73ed94f..557ebcbbd 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -267,10 +267,10 @@ class WriteDesign(Action): @staticmethod async def _save_data_api_design(design_doc): m = json.loads(design_doc.content) - data_api_design = m.get("Data structures and interface definitions") + data_api_design = m.get("Data structures and interfaces") if not data_api_design: return - pathname = CONFIG.git_repo.workdir / Path(DATA_API_DESIGN_FILE_REPO) / Path(design_doc.filename).with_suffix("") + pathname = CONFIG.git_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("") await WriteDesign._save_mermaid_file(data_api_design, pathname) logger.info(f"Save class view to {str(pathname)}") diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 4a2082a07..05255dcc5 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -26,7 +26,10 @@ class PrepareDocuments(Action): if not CONFIG.git_repo: # Create and initialize the workspace folder, initialize the Git environment. project_name = CONFIG.project_name or FileRepository.new_filename() - workdir = Path(CONFIG.project_path or DEFAULT_WORKSPACE_ROOT / project_name) + workdir = CONFIG.project_path + if not workdir and CONFIG.workspace: + workdir = Path(CONFIG.workspace) / project_name + workdir = Path(workdir or DEFAULT_WORKSPACE_ROOT / project_name) if not CONFIG.inc and workdir.exists(): shutil.rmtree(workdir) CONFIG.git_repo = GitRepository() diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 2d4b1583a..97bc67069 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -12,6 +12,7 @@ import asyncio import time from typing import NamedTuple, Union +import openai from openai import APIConnectionError, AsyncAzureOpenAI, AsyncOpenAI, RateLimitError from openai.types import CompletionUsage from tenacity import ( @@ -188,7 +189,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): else: kwargs["model"] = self.model kwargs["timeout"] = max(CONFIG.TIMEOUT, timeout) if CONFIG.TIMEOUT is not None else timeout - + return kwargs async def _achat_completion(self, messages: list[dict], timeout=3) -> dict: @@ -312,8 +313,12 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): >>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} """ messages = self._process_message(messages) - rsp = await self._achat_completion_function(messages, **kwargs) - return self.get_choice_function_arguments(rsp) + try: + rsp = await self._achat_completion_function(messages, **kwargs) + return self.get_choice_function_arguments(rsp) + except openai.NotFoundError as e: + logger.error(f"API TYPE:{CONFIG.openai_api_type}, err:{e}") + raise e def _calc_usage(self, messages: list[dict], rsp: str) -> CompletionUsage: if CONFIG.calc_usage: @@ -406,4 +411,3 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return loop else: raise e - diff --git a/metagpt/utils/git_repository.py b/metagpt/utils/git_repository.py index 9a9ed0fce..5aec4509c 100644 --- a/metagpt/utils/git_repository.py +++ b/metagpt/utils/git_repository.py @@ -197,7 +197,10 @@ class GitRepository: if new_path.exists(): logger.info(f"Delete directory {str(new_path)}") shutil.rmtree(new_path) - os.rename(src=str(self.workdir), dst=str(new_path)) # self.workdir.rename(new_path) + try: + shutil.move(src=str(self.workdir), dst=str(new_path)) + except Exception as e: + logger.warning(f"Move {str(self.workdir)} to {str(new_path)} error: {e}") logger.info(f"Rename directory {str(self.workdir)} to {str(new_path)}") self._repository = Repo(new_path) diff --git a/requirements.txt b/requirements.txt index bcd2db243..de80b0949 100644 --- a/requirements.txt +++ b/requirements.txt @@ -52,4 +52,4 @@ websocket-client==1.6.2 aiofiles==23.2.1 gitpython==3.1.40 zhipuai==1.0.7 - +socksio~=1.0.0 diff --git a/tests/metagpt/test_gpt.py b/tests/metagpt/test_gpt.py index 431858d4c..291531122 100644 --- a/tests/metagpt/test_gpt.py +++ b/tests/metagpt/test_gpt.py @@ -5,9 +5,10 @@ @Author : alexanderwu @File : test_gpt.py """ - +import openai import pytest +from metagpt.config import CONFIG from metagpt.logs import logger @@ -18,14 +19,17 @@ class TestGPT: logger.info(answer) assert len(answer) > 0 - # def test_gptapi_ask_batch(self, llm_api): - # answer = llm_api.ask_batch(['请扮演一个Google Python专家工程师,如果理解,回复明白', '写一个hello world']) - # assert len(answer) > 0 + def test_gptapi_ask_batch(self, llm_api): + answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) + assert len(answer) > 0 def test_llm_api_ask_code(self, llm_api): - answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) - logger.info(answer) - assert len(answer) > 0 + try: + answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) + logger.info(answer) + assert len(answer) > 0 + except openai.NotFoundError: + assert CONFIG.openai_api_type == "azure" @pytest.mark.asyncio async def test_llm_api_aask(self, llm_api): @@ -35,9 +39,12 @@ class TestGPT: @pytest.mark.asyncio async def test_llm_api_aask_code(self, llm_api): - answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) - logger.info(answer) - assert len(answer) > 0 + try: + answer = await llm_api.aask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) + logger.info(answer) + assert len(answer) > 0 + except openai.NotFoundError: + assert CONFIG.openai_api_type == "azure" @pytest.mark.asyncio async def test_llm_api_costs(self, llm_api):