Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev

This commit is contained in:
莘权 马 2023-12-26 19:13:13 +08:00
commit dce4696f17
57 changed files with 281 additions and 706 deletions

View file

@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import (
CodeSummarizeContext,
CodingContext,
@ -27,7 +27,7 @@ action_subclass_registry = {}
class Action(BaseModel):
name: str = ""
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
prefix = "" # aask*时会加上prefix作为system_message
desc = "" # for skill manager

View file

@ -15,7 +15,7 @@ from pydantic import BaseModel, create_model, root_validator, validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
from metagpt.llm import BaseGPTAPI
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
from metagpt.utils.common import OutputParser, general_after_log
@ -60,7 +60,7 @@ class ActionNode:
# Action Context
context: str # all the context, including all necessary info
llm: BaseGPTAPI # LLM with aask interface
llm: BaseLLM # LLM with aask interface
children: dict[str, "ActionNode"]
# Action Input

View file

@ -5,7 +5,7 @@ from pydantic import Field
from metagpt.actions.write_code import WriteCode
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.highlight import highlight
@ -33,7 +33,7 @@ def run(*args) -> pd.DataFrame:
class CloneFunction(WriteCode):
name: str = "CloneFunction"
context: list[Message] = []
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
def _save(self, code_path, code):
if isinstance(code_path, str):

View file

@ -15,7 +15,7 @@ from pydantic import Field
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
from metagpt.llm import LLM, BaseGPTAPI
from metagpt.llm import LLM, BaseLLM
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.common import CodeParser
@ -52,7 +52,7 @@ Now you should start rewriting the code:
class DebugError(Action):
name: str = "DebugError"
context: RunCodeContext = Field(default_factory=RunCodeContext)
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def run(self, *args, **kwargs) -> str:
output_doc = await FileRepository.get_file(

View file

@ -27,7 +27,7 @@ from metagpt.const import (
)
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Document, Documents, Message
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.mermaid import mermaid_to_file
@ -44,7 +44,7 @@ NEW_REQ_TEMPLATE = """
class WriteDesign(Action):
name: str = ""
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
desc: str = (
"Based on the PRD, think about the system design, and design the corresponding APIs, "
"data structures, library tables, processes, and paths. Please provide your design, feedback "

View file

@ -12,13 +12,13 @@ from pydantic import Field
from metagpt.actions.action import Action
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
class DesignReview(Action):
name: str = "DesignReview"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def run(self, prd, api_design):
prompt = (

View file

@ -10,14 +10,14 @@ from pydantic import Field
from metagpt.actions import Action
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
class ExecuteTask(Action):
name: str = "ExecuteTask"
context: list[Message] = []
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def run(self, *args, **kwargs):
pass

View file

@ -26,7 +26,7 @@ from metagpt.prompts.invoice_ocr import (
EXTRACT_OCR_MAIN_INFO_PROMPT,
REPLY_OCR_QUESTION_PROMPT,
)
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import OutputParser
from metagpt.utils.file import File
@ -42,7 +42,7 @@ class InvoiceOCR(Action):
name: str = "InvoiceOCR"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
@staticmethod
async def _check_file_type(file_path: Path) -> str:
@ -132,7 +132,7 @@ class GenerateTable(Action):
name: str = "GenerateTable"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
language: str = "ch"
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
@ -177,7 +177,7 @@ class ReplyQuestion(Action):
name: str = "ReplyQuestion"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
language: str = "ch"
async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:

View file

@ -17,7 +17,7 @@ from metagpt.actions import Action, ActionOutput
from metagpt.config import CONFIG
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Document
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
@ -28,7 +28,7 @@ class PrepareDocuments(Action):
name: str = "PrepareDocuments"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
def _init_repo(self):
"""Initialize the Git environment."""

View file

@ -27,7 +27,7 @@ from metagpt.const import (
)
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Document, Documents
from metagpt.utils.file_repository import FileRepository
@ -43,7 +43,7 @@ NEW_REQ_TEMPLATE = """
class WriteTasks(Action):
name: str = "CreateTasks"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def run(self, with_messages, schema=CONFIG.prompt_schema):
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)

View file

@ -11,7 +11,7 @@ from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.tools.search_engine import SearchEngine
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
from metagpt.utils.common import OutputParser
@ -82,7 +82,7 @@ class CollectLinks(Action):
name: str = "CollectLinks"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
rank_func: Union[Callable[[list[str]], None], None] = None
@ -177,7 +177,7 @@ class WebBrowseAndSummarize(Action):
name: str = "WebBrowseAndSummarize"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: WebBrowserEngine = Field(
@ -248,7 +248,7 @@ class ConductResearch(Action):
name: str = "ConductResearch"
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
def __init__(self, **kwargs):
super().__init__(**kwargs)

View file

@ -22,7 +22,7 @@ from pydantic import Field
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.llm import LLM, BaseGPTAPI
from metagpt.llm import LLM, BaseLLM
from metagpt.logs import logger
from metagpt.schema import RunCodeContext, RunCodeResult
from metagpt.utils.exceptions import handle_exception
@ -79,7 +79,7 @@ standard errors:
class RunCode(Action):
name: str = "RunCode"
context: RunCodeContext = Field(default_factory=RunCodeContext)
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
@classmethod
@handle_exception

View file

@ -14,7 +14,7 @@ from metagpt.actions import Action
from metagpt.config import CONFIG, Config
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@ -109,7 +109,7 @@ You are a member of a professional butler team and will provide helpful suggesti
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
config: None = Field(default_factory=Config)
engine: Optional[SearchEngineType] = CONFIG.search_engine
search_func: Optional[Any] = None

View file

@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.llm import LLM, BaseGPTAPI
from metagpt.llm import LLM, BaseLLM
from metagpt.logs import logger
from metagpt.schema import CodeSummarizeContext
from metagpt.utils.file_repository import FileRepository
@ -95,7 +95,7 @@ flowchart TB
class SummarizeCode(Action):
name: str = "SummarizeCode"
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
async def summarize_code(self, prompt):

View file

@ -31,7 +31,7 @@ from metagpt.const import (
)
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import CodingContext, Document, RunCodeResult
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
@ -90,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
class WriteCode(Action):
name: str = "WriteCode"
context: Document = Field(default_factory=Document)
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def write_code(self, prompt) -> str:

View file

@ -16,7 +16,7 @@ from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import CodingContext
from metagpt.utils.common import CodeParser
@ -123,7 +123,7 @@ REWRITE_CODE_TEMPLATE = """
class WriteCodeReview(Action):
name: str = "WriteCodeReview"
context: CodingContext = Field(default_factory=CodingContext)
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):

View file

@ -28,7 +28,7 @@ from pydantic import Field
from metagpt.actions.action import Action
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import OutputParser
from metagpt.utils.pycst import merge_docstring
@ -163,7 +163,7 @@ class WriteDocstring(Action):
desc: str = "Write docstring for code."
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def run(
self,

View file

@ -38,7 +38,7 @@ from metagpt.const import (
)
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import BugFixContext, Document, Documents, Message
from metagpt.utils.common import CodeParser
from metagpt.utils.file_repository import FileRepository
@ -67,7 +67,7 @@ NEW_REQ_TEMPLATE = """
class WritePRD(Action):
name: str = ""
content: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are

View file

@ -12,13 +12,13 @@ from pydantic import Field
from metagpt.actions.action import Action
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
class WritePRDReview(Action):
name: str = ""
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
prd: Optional[str] = None
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
prd_review_prompt_template: str = """

View file

@ -11,7 +11,7 @@ from pydantic import Field
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
REVIEW = ActionNode(
key="Review",
@ -38,7 +38,7 @@ class WriteReview(Action):
"""Write a review for the given context."""
name: str = "WriteReview"
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def run(self, context):
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")

View file

@ -13,14 +13,14 @@ from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
class WriteTeachingPlanPart(Action):
"""Write Teaching Plan Part"""
context: Optional[str] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
topic: str = ""
language: str = "Chinese"
rsp: Optional[str] = None

View file

@ -17,7 +17,7 @@ from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Document, TestingContext
from metagpt.utils.common import CodeParser
@ -45,7 +45,7 @@ you should correctly import the necessary classes based on these file locations!
class WriteTest(Action):
name: str = "WriteTest"
context: Optional[TestingContext] = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
async def write_code(self, prompt):
code_rsp = await self._aask(prompt)

View file

@ -14,7 +14,7 @@ from pydantic import Field
from metagpt.actions import Action
from metagpt.llm import LLM
from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import OutputParser
@ -27,7 +27,7 @@ class WriteDirectory(Action):
"""
name: str = "WriteDirectory"
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
language: str = "Chinese"
async def run(self, topic: str, *args, **kwargs) -> Dict:
@ -54,7 +54,7 @@ class WriteContent(Action):
"""
name: str = "WriteContent"
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
directory: dict = dict()
language: str = "Chinese"

View file

@ -9,14 +9,14 @@
from typing import Optional
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
_ = HumanProvider() # Avoid pre-commit error
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI:
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
"""get the default llm provider"""
if provider is None:
provider = CONFIG.get_default_llm_provider_enum()

View file

@ -10,14 +10,15 @@
"""
import json
import re
from typing import Dict, List
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from metagpt.config import CONFIG
from metagpt.const import DEFAULT_LANGUAGE
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.logs import logger
from metagpt.provider import MetaGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, SimpleMessage
from metagpt.utils.redis import Redis
@ -30,6 +31,7 @@ class BrainMemory(BaseModel):
is_dirty: bool = False
last_talk: str = None
cacheable: bool = True
llm: Optional[BaseLLM] = None
def add_talk(self, msg: Message):
"""
@ -120,6 +122,7 @@ class BrainMemory(BaseModel):
if isinstance(llm, MetaGPTAPI):
return await self._metagpt_summarize(max_words=max_words)
self.llm = llm
return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit)
async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1):
@ -131,7 +134,7 @@ class BrainMemory(BaseModel):
text_length = len(text)
if limit > 0 and text_length < limit:
return text
summary = await llm.summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
if summary:
await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
return summary
@ -251,3 +254,74 @@ class BrainMemory(BaseModel):
texts.append(t)
return "\n".join(texts)
async def _summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
text_length = len(text)
if limit > 0 and text_length < limit:
return text
summary = ""
while max_count > 0:
if text_length < max_token_count:
summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
break
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
summaries = []
for ws in text_windows:
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
summary = summaries[0]
break
# Merged and retry
text = "\n".join(summaries)
text_length = len(text)
max_count -= 1 # safeguard
return summary
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
"""Generate text summary"""
if len(text) < max_words:
return text
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.debug(f"summary ask:{msg}")
response = await self.llm.aask(msg=msg, system_msgs=[])
logger.debug(f"summary rsp: {response}")
return response
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
"""Splitting long text into sliding windows text"""
if window_size <= 0:
window_size = DEFAULT_TOKEN_SIZE
total_len = len(text)
if total_len <= window_size:
return [text]
padding_size = 20 if window_size > 20 else 0
windows = []
idx = 0
data_len = window_size - padding_size
while idx < total_len:
if window_size + idx > total_len: # 不足一个滑窗
windows.append(text[idx:])
break
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
# window_size=3, padding_size=1
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
# idx=2, | idx=5 | idx=8 | ...
w = text[idx : idx + window_size]
windows.append(w)
idx += data_len
return windows

View file

@ -6,22 +6,22 @@
@File : __init__.py
"""
from metagpt.provider.fireworks_api import FireWorksGPTAPI
from metagpt.provider.fireworks_api import FireworksLLM
from metagpt.provider.google_gemini_api import GeminiGPTAPI
from metagpt.provider.ollama_api import OllamaGPTAPI
from metagpt.provider.ollama_api import OllamaLLM
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.openai_api import OpenAILLM
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.metagpt_api import MetaGPTAPI
__all__ = [
"FireWorksGPTAPI",
"FireworksLLM",
"GeminiGPTAPI",
"OpenLLMGPTAPI",
"OpenAIGPTAPI",
"OpenAILLM",
"ZhiPuAIGPTAPI",
"AzureOpenAIGPTAPI",
"AzureOpenAILLM",
"MetaGPTAPI",
"OllamaGPTAPI",
"OllamaLLM",
]

View file

@ -10,60 +10,36 @@
"""
from openai import AsyncAzureOpenAI, AzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
from openai import AsyncAzureOpenAI
from openai._base_client import AsyncHttpxClientWrapper
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.config import LLMProviderEnum
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
from metagpt.provider.openai_api import OpenAILLM
@register_provider(LLMProviderEnum.AZURE_OPENAI)
class AzureOpenAIGPTAPI(OpenAIGPTAPI):
class AzureOpenAILLM(OpenAILLM):
"""
Check https://platform.openai.com/examples for examples
"""
def __init__(self):
self.config: Config = CONFIG
self._init_openai()
self.auto_max_tokens = False
RateLimiter.__init__(self, rpm=self.rpm)
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
def _init_client(self):
kwargs = self._make_client_kwargs()
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
self.client = AzureOpenAI(**kwargs)
self.async_client = AsyncAzureOpenAI(**async_kwargs)
self.async_client = AsyncAzureOpenAI(**kwargs)
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
def _make_client_kwargs(self) -> dict:
kwargs = dict(
api_key=self.config.OPENAI_API_KEY,
api_version=self.config.OPENAI_API_VERSION,
azure_endpoint=self.config.OPENAI_BASE_URL,
)
async_kwargs = kwargs.copy()
# to use proxy, openai v1 needs http_client
proxy_params = self._get_proxy_params()
if proxy_params:
kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs, async_kwargs
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self.get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3,
"model": self.model,
}
if configs:
kwargs.update(configs)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs

View file

@ -1,30 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/5 23:00
@Author : alexanderwu
@File : base_chatbot.py
@Modified By: mashenquan, 2023/11/21. Add `timeout`.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
@dataclass
class BaseChatbot(ABC):
"""Abstract GPT class"""
mode: str = "API"
use_system_prompt: bool = True
@abstractmethod
def ask(self, msg: str, timeout=3) -> str:
"""Ask GPT a question and get an answer"""
@abstractmethod
def ask_batch(self, msgs: list, timeout=3) -> str:
"""Ask GPT multiple questions and get a series of answers"""
@abstractmethod
def ask_code(self, msgs: list, timeout=3) -> str:
"""Ask GPT multiple questions and get a piece of code"""

View file

@ -3,19 +3,18 @@
"""
@Time : 2023/5/5 23:04
@Author : alexanderwu
@File : base_gpt_api.py
@File : base_llm.py
@Desc : mashenquan, 2023/8/22. + try catch
"""
import json
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Optional
from metagpt.provider.base_chatbot import BaseChatbot
class BaseLLM(ABC):
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
class BaseGPTAPI(BaseChatbot):
"""GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
use_system_prompt: bool = True
system_prompt = "You are a helpful assistant."
def _user_msg(self, msg: str) -> dict[str, str]:
@ -33,17 +32,11 @@ class BaseGPTAPI(BaseChatbot):
def _default_system_msg(self):
return self._system_msg(self.system_prompt)
def ask(self, msg: str, timeout=3) -> str:
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
rsp = self.completion(message, timeout=timeout)
return self.get_choice_text(rsp)
async def aask(
self,
msg: str,
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
generator: bool = False,
timeout=3,
stream=True,
) -> str:
@ -54,23 +47,12 @@ class BaseGPTAPI(BaseChatbot):
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
rsp = await self.acompletion_text(message, stream=stream, generator=generator, timeout=timeout)
# logger.debug(rsp)
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
return rsp
def _extract_assistant_rsp(self, context):
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
def ask_batch(self, msgs: list, timeout=3) -> str:
context = []
for msg in msgs:
umsg = self._user_msg(msg)
context.append(umsg)
rsp = self.completion(context, timeout=timeout)
rsp_text = self.get_choice_text(rsp)
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
async def aask_batch(self, msgs: list, timeout=3) -> str:
"""Sequential questioning"""
context = []
@ -81,26 +63,11 @@ class BaseGPTAPI(BaseChatbot):
context.append(self._assistant_msg(rsp_text))
return self._extract_assistant_rsp(context)
def ask_code(self, msgs: list[str], timeout=3) -> str:
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
rsp_text = self.ask_batch(msgs, timeout=timeout)
return rsp_text
async def aask_code(self, msgs: list[str], timeout=3) -> str:
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
rsp_text = await self.aask_batch(msgs, timeout=timeout)
return rsp_text
@abstractmethod
def completion(self, messages: list[dict], timeout=3):
"""All GPTAPIs are required to provide the standard OpenAI completion interface
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "hello, show me python hello world code"},
# {"role": "assistant", "content": ...}, # If there is an answer in the history, also include it
]
"""
@abstractmethod
async def acompletion(self, messages: list[dict], timeout=3):
"""Asynchronous version of completion
@ -113,7 +80,7 @@ class BaseGPTAPI(BaseChatbot):
"""
@abstractmethod
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""Asynchronous version of completion. Return str. Support stream-print"""
def get_choice_text(self, rsp: dict) -> str:
@ -159,16 +126,3 @@ class BaseGPTAPI(BaseChatbot):
{'language': 'python', 'code': "print('Hello, World!')"}
"""
return json.loads(self.get_choice_function(rsp)["arguments"])
def messages_to_prompt(self, messages: list[dict]):
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
return "\n".join([f"{i.role}: {i.content}" for i in messages])
def messages_to_dict(self, messages):
"""objects to [{"role": "user", "content": msg}] etc."""
return [i.to_dict() for i in messages]
@abstractmethod
async def close(self):
"""Close connection"""
pass

View file

@ -18,7 +18,7 @@ from tenacity import (
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
from metagpt.utils.cost_manager import CostManager, Costs
MODEL_GRADE_TOKEN_COSTS = {
@ -72,18 +72,17 @@ class FireworksCostManager(CostManager):
@register_provider(LLMProviderEnum.FIREWORKS)
class FireWorksGPTAPI(OpenAIGPTAPI):
class FireworksLLM(OpenAILLM):
def __init__(self):
self.config: Config = CONFIG
self.__init_fireworks()
self.auto_max_tokens = False
self._cost_manager = FireworksCostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_fireworks(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self._init_client()
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):
@ -103,7 +102,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
return self._cost_manager.get_costs()
async def _achat_completion_stream(self, messages: list[dict]) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages), stream=True
)
@ -133,9 +132,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""when streaming, print each token in place."""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -47,8 +47,7 @@ MAX_CONNECTION_RETRIES = 2
# Has one attribute per thread, 'session'.
_thread_context = threading.local()
LLM_LOG = os.environ.get("LLM_LOG")
LLM_LOG = "debug"
LLM_LOG = os.environ.get("LLM_LOG", "debug")
class ApiType(Enum):

View file

@ -21,7 +21,7 @@ from tenacity import (
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel):
@register_provider(LLMProviderEnum.GEMINI)
class GeminiGPTAPI(BaseGPTAPI):
class GeminiGPTAPI(BaseLLM):
"""
Refs to `https://ai.google.dev/tutorials/python_quickstart`
"""
@ -136,9 +136,7 @@ class GeminiGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -6,10 +6,10 @@ Author: garylin2099
from typing import Optional
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
class HumanProvider(BaseGPTAPI):
class HumanProvider(BaseLLM):
"""Humans provide themselves as a 'model', which actually takes in human input as its response.
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
"""
@ -31,10 +31,6 @@ class HumanProvider(BaseGPTAPI):
) -> str:
return self.ask(msg, timeout=timeout)
def completion(self, messages: list[dict], timeout=3):
"""dummy implementation of abstract method in base"""
return []
async def acompletion(self, messages: list[dict], timeout=3):
"""dummy implementation of abstract method in base"""
return []
@ -42,7 +38,3 @@ class HumanProvider(BaseGPTAPI):
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""dummy implementation of abstract method in base"""
return ""
async def close(self):
"""Close connection"""
pass

View file

@ -6,11 +6,11 @@
@Desc : MetaGPT LLM provider.
"""
from metagpt.config import LLMProviderEnum
from metagpt.provider import OpenAIGPTAPI
from metagpt.provider import OpenAILLM
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.METAGPT)
class MetaGPTAPI(OpenAIGPTAPI):
class MetaGPTAPI(OpenAILLM):
def __init__(self):
super().__init__()

View file

@ -16,7 +16,7 @@ from tenacity import (
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.const import LLM_API_TIMEOUT
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
@ -39,7 +39,7 @@ class OllamaCostManager(CostManager):
@register_provider(LLMProviderEnum.OLLAMA)
class OllamaGPTAPI(BaseGPTAPI):
class OllamaLLM(BaseLLM):
"""
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
"""
@ -54,12 +54,8 @@ class OllamaGPTAPI(BaseGPTAPI):
def __init_ollama(self, config: CONFIG):
assert config.ollama_api_base
self.model = config.ollama_api_model
def close(self):
pass
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
return kwargs
@ -87,18 +83,6 @@ class OllamaGPTAPI(BaseGPTAPI):
chunk = chunk.decode(encoding)
return json.loads(chunk)
def completion(self, messages: list[dict]) -> dict:
resp, _, _ = self.client.request(
method=self.http_method,
url=self.suffix_url,
params=self._const_kwargs(messages),
request_timeout=LLM_API_TIMEOUT,
)
resp = self._decode_and_load(resp)
usage = self.get_usage(resp)
self._update_costs(usage)
return resp
async def _achat_completion(self, messages: list[dict]) -> dict:
resp, _, _ = await self.client.arequest(
method=self.http_method,
@ -111,7 +95,7 @@ class OllamaGPTAPI(BaseGPTAPI):
self._update_costs(usage)
return resp
async def acompletion(self, messages: list[dict]) -> dict:
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
return await self._achat_completion(messages)
async def _achat_completion_stream(self, messages: list[dict]) -> str:
@ -147,9 +131,7 @@ class OllamaGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -7,7 +7,7 @@ from openai.types import CompletionUsage
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
from metagpt.provider.openai_api import OpenAILLM
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
@ -35,18 +35,17 @@ class OpenLLMCostManager(CostManager):
@register_provider(LLMProviderEnum.OPEN_LLM)
class OpenLLMGPTAPI(OpenAIGPTAPI):
class OpenLLMGPTAPI(OpenAILLM):
def __init__(self):
self.config: Config = CONFIG
self.__init_openllm()
self.auto_max_tokens = False
self._cost_manager = OpenLLMCostManager()
RateLimiter.__init__(self, rpm=self.rpm)
def __init_openllm(self):
self.is_azure = False
self.rpm = int(self.config.get("RPM", 10))
self._make_client()
self._init_client()
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
def _make_client_kwargs(self) -> (dict, dict):

View file

@ -3,20 +3,17 @@
@Time : 2023/5/5 23:08
@Author : alexanderwu
@File : openai.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation;
Change cost control from global to company level.
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
"""
import asyncio
import json
import time
from typing import List, Union
from typing import AsyncIterator, Union
import openai
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
from openai._base_client import AsyncHttpxClientWrapper
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from tenacity import (
@ -28,9 +25,8 @@ from tenacity import (
)
from metagpt.config import CONFIG, Config, LLMProviderEnum
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
@ -43,31 +39,6 @@ from metagpt.utils.token_counter import (
)
class RateLimiter:
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
def __init__(self, rpm):
self.last_call_time = 0
# Here 1.1 is used because even if the calls are made strictly according to time,
# they will still be QOS'd; consider switching to simple error retry later
self.interval = 1.1 * 60 / rpm
self.rpm = rpm
def split_batches(self, batch):
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
async def wait_if_needed(self, num_requests):
current_time = time.time()
elapsed_time = current_time - self.last_call_time
if elapsed_time < self.interval * num_requests:
remaining_time = self.interval * num_requests - elapsed_time
logger.info(f"sleep {remaining_time}")
await asyncio.sleep(remaining_time)
self.last_call_time = time.time()
def log_and_reraise(retry_state):
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
logger.warning(
@ -80,39 +51,31 @@ See FAQ 5.8
@register_provider(LLMProviderEnum.OPENAI)
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
Check https://platform.openai.com/examples for examples
"""
class OpenAILLM(BaseLLM):
"""Check https://platform.openai.com/examples for examples"""
def __init__(self):
self.config: Config = CONFIG
self._init_openai()
self._init_client()
self.auto_max_tokens = False
RateLimiter.__init__(self, rpm=self.rpm)
def _init_openai(self):
self.rpm = int(self.config.RPM or 10)
self._make_client()
def _make_client(self):
kwargs, async_kwargs = self._make_client_kwargs()
# https://github.com/openai/openai-python#async-usage
self.client = OpenAI(**kwargs)
self.async_client = AsyncOpenAI(**async_kwargs)
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
def _make_client_kwargs(self) -> (dict, dict):
kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
async_kwargs = kwargs.copy()
def _init_client(self):
"""https://github.com/openai/openai-python#async-usage"""
kwargs = self._make_client_kwargs()
self.aclient = AsyncOpenAI(**kwargs)
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
# to use proxy, openai v1 needs http_client
proxy_params = self._get_proxy_params()
if proxy_params:
kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
if proxy_params := self._get_proxy_params():
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
return kwargs, async_kwargs
return kwargs
def _get_proxy_params(self) -> dict:
params = {}
@ -123,8 +86,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return params
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
**self._cons_kwargs(messages, timeout=timeout), stream=True
)
@ -132,35 +95,26 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
yield chunk_message
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
kwargs = {
"messages": messages,
"max_tokens": self.get_max_tokens(messages),
"max_tokens": self._get_max_tokens(messages),
"n": 1,
"stop": None,
"temperature": 0.3,
"model": self.model,
"timeout": max(CONFIG.timeout, timeout),
}
if configs:
kwargs.update(configs)
kwargs["timeout"] = max(CONFIG.timeout, timeout)
if extra_kwargs:
kwargs.update(extra_kwargs)
return kwargs
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
kwargs = self._cons_kwargs(messages, timeout=timeout)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
self._update_costs(rsp.usage)
return rsp
def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return self._chat_completion(messages, timeout=timeout)
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
return await self._achat_completion(messages, timeout=timeout)
@ -171,12 +125,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
retry=retry_if_exception_type(APIConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""when streaming, print each token in place."""
if stream:
resp = self._achat_completion_stream(messages, timeout=timeout)
if generator:
return resp
collected_messages = []
async for i in resp:
@ -192,9 +144,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return self.get_choice_text(rsp)
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
"""
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
"""
"""Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
if "tools" not in kwargs:
configs = {
"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
@ -204,14 +154,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
self._update_costs(rsp.usage)
return rsp
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
self._update_costs(rsp.usage)
return rsp
@ -231,56 +176,28 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
)
return messages
def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
"""Use function of tools to ask a code.
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
Examples:
>>> llm = OpenAIGPTAPI()
>>> llm.ask_code("Write a python hello world code.")
{'language': 'python', 'code': "print('Hello, World!')"}
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
>>> llm.ask_code(msg)
{'language': 'python', 'code': "print('Hello, World!')"}
"""
messages = self._process_message(messages)
rsp = self._chat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
"""Use function of tools to ask a code.
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
Examples:
>>> llm = OpenAIGPTAPI()
>>> rsp = await llm.ask_code("Write a python hello world code.")
>>> rsp
{'language': 'python', 'code': "print('Hello, World!')"}
>>> llm = OpenAILLM()
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
>>> rsp = await llm.aask_code(msg)
# -> {'language': 'python', 'code': "print('Hello, World!')"}
"""
messages = self._process_message(messages)
try:
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
except openai.BadRequestError as e:
logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
raise e
rsp = await self._achat_completion_function(messages, **kwargs)
return self.get_choice_function_arguments(rsp)
@handle_exception
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
"""Required to provide the first function arguments of choice.
:return dict: return the first function arguments of choice, for example,
{'language': 'python', 'code': "print('Hello, World!')"}
"""
try:
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
except json.JSONDecodeError:
return {}
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
def get_choice_text(self, rsp: ChatCompletion) -> str:
"""Required to provide the first text of choice"""
@ -295,134 +212,24 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
usage.prompt_tokens = count_message_tokens(messages, self.model)
usage.completion_tokens = count_string_tokens(rsp, self.model)
except Exception as e:
logger.error(f"usage calculation failed!: {e}")
logger.error(f"usage calculation failed: {e}")
return usage
async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]:
"""Return full JSON"""
split_batches = self.split_batches(batch)
all_results = []
for small_batch in split_batches:
logger.info(small_batch)
await self.wait_if_needed(len(small_batch))
future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch]
results = await asyncio.gather(*future)
logger.info(results)
all_results.extend(results)
return all_results
async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]:
"""Only return plain text"""
raw_results = await self.acompletion_batch(batch, timeout=timeout)
results = []
for idx, raw_result in enumerate(raw_results, start=1):
result = self.get_choice_text(raw_result)
results.append(result)
logger.info(f"Result of task {idx}: {result}")
return results
@handle_exception
def _update_costs(self, usage: CompletionUsage):
if CONFIG.calc_usage and usage:
try:
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
except Exception as e:
logger.error(f"updating costs failed!, exp: {e}")
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
def get_costs(self) -> Costs:
return CONFIG.cost_manager.get_costs()
def get_max_tokens(self, messages: list[dict]):
def _get_max_tokens(self, messages: list[dict]):
if not self.auto_max_tokens:
return CONFIG.max_tokens_rsp
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
def moderation(self, content: Union[str, list[str]]):
return self.client.moderations.create(input=content)
@handle_exception
async def amoderation(self, content: Union[str, list[str]]):
return await self.async_client.moderations.create(input=content)
async def close(self):
"""Close connection"""
if self.client:
self.client.close()
self.client = None
if self.async_client:
await self.async_client.close()
self.async_client = None
async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
max_token_count = DEFAULT_MAX_TOKENS
max_count = 100
text_length = len(text)
if limit > 0 and text_length < limit:
return text
summary = ""
while max_count > 0:
if text_length < max_token_count:
summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
break
padding_size = 20 if max_token_count > 20 else 0
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
summaries = []
for ws in text_windows:
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
summaries.append(response)
if len(summaries) == 1:
summary = summaries[0]
break
# Merged and retry
text = "\n".join(summaries)
text_length = len(text)
max_count -= 1 # safeguard
return summary
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
"""Generate text summary"""
if len(text) < max_words:
return text
if keep_language:
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
else:
command = f"Translate the above content into a summary of less than {max_words} words."
msg = text + "\n\n" + command
logger.debug(f"summary ask:{msg}")
response = await self.aask(msg=msg, system_msgs=[])
logger.debug(f"summary rsp: {response}")
return response
@staticmethod
def split_texts(text: str, window_size) -> List[str]:
"""Splitting long text into sliding windows text"""
if window_size <= 0:
window_size = DEFAULT_TOKEN_SIZE
total_len = len(text)
if total_len <= window_size:
return [text]
padding_size = 20 if window_size > 20 else 0
windows = []
idx = 0
data_len = window_size - padding_size
while idx < total_len:
if window_size + idx > total_len: # 不足一个滑窗
windows.append(text[idx:])
break
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
# window_size=3, padding_size=1
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
# idx=2, | idx=5 | idx=8 | ...
w = text[idx : idx + window_size]
windows.append(w)
idx += data_len
return windows
"""Moderate content."""
return await self.aclient.moderations.create(input=content)

View file

@ -1,9 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/7/21 11:15
@Author : Leo Xiao
@File : anthropic_api.py
@File : spark_api.py
"""
import _thread as thread
import base64
@ -13,7 +11,6 @@ import hmac
import json
import ssl
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time
@ -21,52 +18,29 @@ import websocket # 使用websocket_client
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.SPARK)
class SparkGPTAPI(BaseGPTAPI):
class SparkLLM(BaseLLM):
def __init__(self):
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")
def close(self):
pass
def ask(self, msg: str) -> str:
message = [self._default_system_msg(), self._user_msg(msg)]
rsp = self.completion(message)
return rsp
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
if system_msgs:
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
else:
message = [self._default_system_msg(), self._user_msg(msg)]
rsp = await self.acompletion(message)
logger.debug(message)
return rsp
def get_choice_text(self, rsp: dict) -> str:
return rsp["payload"]["choices"]["text"][-1]["content"]
async def acompletion_text(
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
# 不支持
logger.error("该功能禁用。")
w = GetMessageFromWeb(messages)
return w.run()
async def acompletion(self, messages: list[dict]):
async def acompletion(self, messages: list[dict], timeout=3):
# 不支持异步
w = GetMessageFromWeb(messages)
return w.run()
def completion(self, messages: list[dict]):
w = GetMessageFromWeb(messages)
return w.run()
class GetMessageFromWeb:
class WsParam:

View file

@ -17,7 +17,7 @@ from tenacity import (
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
@ -31,7 +31,7 @@ class ZhiPuEvent(Enum):
@register_provider(LLMProviderEnum.ZHIPUAI)
class ZhiPuAIGPTAPI(BaseGPTAPI):
class ZhiPuAIGPTAPI(BaseLLM):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`
@ -131,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
retry=retry_if_exception_type(ConnectionError),
retry_error_callback=log_and_reraise,
)
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
"""response in async with stream or non-stream mode"""
if stream:
return await self._achat_completion_stream(messages)

View file

@ -36,7 +36,7 @@ from metagpt.const import SERDESER_PATH
from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue
from metagpt.utils.common import (
any_to_name,
@ -141,7 +141,7 @@ class Role(BaseModel):
desc: str = ""
is_human: bool = False
_llm: BaseGPTAPI = Field(default_factory=LLM) # Each role has its own LLM, use different system message
_llm: BaseLLM = Field(default_factory=LLM) # Each role has its own LLM, use different system message
_role_id: str = ""
_states: list[str] = []
_actions: list[Action] = []

View file

@ -19,7 +19,7 @@ from metagpt.actions import UserRequirement
from metagpt.actions.execute_task import ExecuteTask
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.make_sk_kernel import make_sk_kernel
@ -44,7 +44,7 @@ class SkAgent(Role):
plan: Any = None
planner_cls: Any = None
planner: Any = None
llm: BaseGPTAPI = Field(default_factory=LLM)
llm: BaseLLM = Field(default_factory=LLM)
kernel: Kernel = Field(default_factory=Kernel)
import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None
import_skill: Type[Kernel.import_skill] = None

View file

@ -21,10 +21,6 @@ class OpenAIText2Image:
"""
self._llm = LLM()
def __del__(self):
if self._llm:
self._llm.close()
async def text_2_image(self, text, size_type="1024x1024"):
"""Text to image

View file

@ -4,7 +4,7 @@
import json
from pathlib import Path
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
ICL_SAMPLE = """Interface definition:
```text
@ -278,11 +278,11 @@ class UTGenerator:
question += self.build_api_doc(node, path, method)
self.ask_gpt_and_save(question, tag, summary)
def gpt_msgs_to_code(self, messages: list) -> str:
async def gpt_msgs_to_code(self, messages: list) -> str:
"""Choose based on different calling methods"""
result = ""
if self.chatgpt_method == "API":
result = GPTAPI().ask_code(msgs=messages)
result = await GPTAPI().aask_code(msgs=messages)
return result

View file

@ -22,9 +22,9 @@ async def test_design_api():
for prd in inputs:
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
design_api = WriteDesign("design_api")
design_api = WriteDesign()
result = await design_api.run([Message(content=prd, instruct_content=None)])
result = await design_api.run(Message(content=prd, instruct_content=None))
logger.info(result)
assert result

View file

@ -6,10 +6,26 @@
@File : test_project_management.py
"""
import pytest
class TestCreateProjectPlan:
pass
from metagpt.actions.project_management import WriteTasks
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.utils.file_repository import FileRepository
from tests.metagpt.actions.mock_json import DESIGN, PRD
class TestAssignTasks:
pass
@pytest.mark.asyncio
async def test_design_api():
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
logger.info(CONFIG.git_repo)
action = WriteTasks()
result = await action.run(Message(content="", instruct_content=None))
logger.info(result)
assert result

View file

@ -10,7 +10,7 @@ import pytest
from metagpt.actions.write_code import WriteCode
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
from metagpt.provider.openai_api import OpenAILLM as LLM
from metagpt.schema import CodingContext, Document
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE

View file

@ -8,7 +8,7 @@
import pytest
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message
default_chat_resp = {
@ -27,14 +27,14 @@ prompt_msg = "who are you"
resp_content = default_chat_resp["choices"][0]["message"]["content"]
class MockBaseGPTAPI(BaseGPTAPI):
class MockBaseGPTAPI(BaseLLM):
def completion(self, messages: list[dict], timeout=3):
return default_chat_resp
async def acompletion(self, messages: list[dict], timeout=3):
return default_chat_resp
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
return resp_content
async def close(self):
@ -47,11 +47,6 @@ def test_base_gpt_api():
assert "user" in str(message)
base_gpt_api = MockBaseGPTAPI()
msg_prompt = base_gpt_api.messages_to_prompt([message])
assert msg_prompt == "user: hello"
msg_dict = base_gpt_api.messages_to_dict([message])
assert msg_dict == [{"role": "user", "content": "hello"}]
openai_funccall_resp = {
"choices": [
@ -87,14 +82,14 @@ def test_base_gpt_api():
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
resp = base_gpt_api.ask(prompt_msg)
assert resp == resp_content
# resp = base_gpt_api.ask(prompt_msg)
# assert resp == resp_content
resp = base_gpt_api.ask_batch([prompt_msg])
assert resp == resp_content
# resp = base_gpt_api.ask_batch([prompt_msg])
# assert resp == resp_content
resp = base_gpt_api.ask_code([prompt_msg])
assert resp == resp_content
# resp = base_gpt_api.ask_code([prompt_msg])
# assert resp == resp_content
@pytest.mark.asyncio

View file

@ -13,7 +13,7 @@ from openai.types.completion_usage import CompletionUsage
from metagpt.provider.fireworks_api import (
MODEL_GRADE_TOKEN_COSTS,
FireworksCostManager,
FireWorksGPTAPI,
FireworksLLM,
)
resp_content = "I'm fireworks"
@ -55,17 +55,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return default_resp.choices[0].message.content
def test_fireworks_completion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
fireworks_gpt = FireWorksGPTAPI()
resp = fireworks_gpt.completion(messages)
assert resp.choices[0].message.content == resp_content
resp = fireworks_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_fireworks_acompletion(mocker):
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
@ -73,7 +62,7 @@ async def test_fireworks_acompletion(mocker):
mocker.patch(
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
)
fireworks_gpt = FireWorksGPTAPI()
fireworks_gpt = FireworksLLM()
resp = await fireworks_gpt.acompletion(messages, stream=False)
assert resp.choices[0].message.content in resp_content

View file

@ -35,16 +35,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_gemini_completion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
gemini_gpt = GeminiGPTAPI()
resp = gemini_gpt.completion(messages)
assert resp.text == resp_content
resp = gemini_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)

View file

@ -17,15 +17,6 @@ async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
return mock_llm_ask(msg)
def test_human_provider(mocker):
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
human_provider = HumanProvider()
assert resp_content == human_provider.ask(None)
assert not human_provider.completion(messages=[])
@pytest.mark.asyncio
async def test_async_human_provider(mocker):
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)

View file

@ -5,7 +5,7 @@
import pytest
from metagpt.config import CONFIG
from metagpt.provider.ollama_api import OllamaGPTAPI
from metagpt.provider.ollama_api import OllamaLLM
prompt_msg = "who are you"
messages = [{"role": "user", "content": prompt_msg}]
@ -28,22 +28,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_gemini_completion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
ollama_gpt = OllamaGPTAPI()
resp = ollama_gpt.completion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]
resp = ollama_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_gemini_acompletion(mocker):
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
ollama_gpt = OllamaGPTAPI()
ollama_gpt = OllamaLLM()
resp = await ollama_gpt.acompletion(messages)
assert resp["message"]["content"] == default_resp["message"]["content"]

View file

@ -2,13 +2,13 @@ from unittest.mock import Mock
import pytest
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import UserMessage
@pytest.mark.asyncio
async def test_aask_code():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -18,7 +18,7 @@ async def test_aask_code():
@pytest.mark.asyncio
async def test_aask_code_str():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = "Write a python hello world code."
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -28,7 +28,7 @@ async def test_aask_code_str():
@pytest.mark.asyncio
async def test_aask_code_Message():
llm = OpenAIGPTAPI()
llm = OpenAILLM()
msg = UserMessage("Write a python hello world code.")
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
@ -36,52 +36,6 @@ async def test_aask_code_Message():
assert len(rsp["code"]) > 0
def test_ask_code():
llm = OpenAIGPTAPI()
msg = [{"role": "user", "content": "Write a python hello world code."}]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_str():
llm = OpenAIGPTAPI()
msg = "Write a python hello world code."
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_Message():
llm = OpenAIGPTAPI()
msg = UserMessage("Write a python hello world code.")
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_list_Message():
llm = OpenAIGPTAPI()
msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
def test_ask_code_list_str():
llm = OpenAIGPTAPI()
msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
print(rsp)
assert "language" in rsp
assert "code" in rsp
assert len(rsp["code"]) > 0
class TestOpenAI:
@pytest.fixture
def config(self):
@ -130,7 +84,7 @@ class TestOpenAI:
)
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config
kwargs, async_kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
@ -139,7 +93,7 @@ class TestOpenAI:
assert "http_client" not in async_kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
@ -148,14 +102,14 @@ class TestOpenAI:
assert "http_client" not in async_kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAIGPTAPI()
instance = OpenAILLM()
instance.config = config_azure_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs

View file

@ -4,7 +4,7 @@
import pytest
from metagpt.provider.spark_api import SparkGPTAPI
from metagpt.provider.spark_api import SparkLLM
prompt_msg = "who are you"
resp_content = "I'm Spark"
@ -18,24 +18,13 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False,
return resp_content
def test_spark_completion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
spark_gpt = SparkGPTAPI()
resp = spark_gpt.completion([])
assert resp == resp_content
resp = spark_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
spark_gpt = SparkGPTAPI()
spark_gpt = SparkLLM()
resp = await spark_gpt.acompletion([], stream=False)
resp = await spark_gpt.acompletion([])
assert resp == resp_content
resp = await spark_gpt.aask(prompt_msg, stream=False)

View file

@ -28,18 +28,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
return resp_content
def test_zhipuai_completion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
zhipu_gpt = ZhiPuAIGPTAPI()
resp = zhipu_gpt.completion(messages)
assert resp["code"] == 200
assert resp["data"]["choices"][0]["content"] == resp_content
resp = zhipu_gpt.ask(prompt_msg)
assert resp == resp_content
@pytest.mark.asyncio
async def test_zhipuai_acompletion(mocker):
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)

View file

@ -14,23 +14,6 @@ from metagpt.logs import logger
@pytest.mark.usefixtures("llm_api")
class TestGPT:
def test_llm_api_ask(self, llm_api):
answer = llm_api.ask("hello chatgpt")
logger.info(answer)
assert len(answer) > 0
def test_gptapi_ask_batch(self, llm_api):
answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"], timeout=60)
assert len(answer) > 0
def test_llm_api_ask_code(self, llm_api):
try:
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师如果理解回复明白", "写一个hello world"])
logger.info(answer)
assert len(answer) > 0
except openai.BadRequestError:
assert CONFIG.OPENAI_API_TYPE == "azure"
@pytest.mark.asyncio
async def test_llm_api_aask(self, llm_api):
answer = await llm_api.aask("hello chatgpt", stream=False)

View file

@ -9,7 +9,7 @@
import pytest
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
from metagpt.provider.openai_api import OpenAILLM as LLM
@pytest.fixture()
@ -23,18 +23,11 @@ async def test_llm_aask(llm):
assert len(rsp) > 0
@pytest.mark.asyncio
async def test_llm_aask_batch(llm):
assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0
@pytest.mark.asyncio
async def test_llm_acompletion(llm):
hello_msg = [{"role": "user", "content": "hello"}]
rsp = await llm.acompletion(hello_msg)
assert len(rsp.choices[0].message.content) > 0
assert len(await llm.acompletion_batch([hello_msg])) > 0
assert len(await llm.acompletion_batch_text([hello_msg])) > 0
if __name__ == "__main__":

View file

@ -1,21 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/7 17:23
@Author : alexanderwu
@File : test_custom_aio_session.py
"""
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI
async def try_hello(api):
batch = [[{"role": "user", "content": "hello"}]]
results = await api.acompletion_batch_text(batch)
return results
async def aask_batch(api: OpenAIGPTAPI):
results = await api.aask_batch(["hi", "write python hello world."])
logger.info(results)
return results