mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
Merge branch 'dev' of https://github.com/geekan/MetaGPT into geekan/dev
This commit is contained in:
commit
dce4696f17
57 changed files with 281 additions and 706 deletions
|
|
@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import (
|
||||
CodeSummarizeContext,
|
||||
CodingContext,
|
||||
|
|
@ -27,7 +27,7 @@ action_subclass_registry = {}
|
|||
|
||||
class Action(BaseModel):
|
||||
name: str = ""
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
|
||||
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
|
||||
prefix = "" # aask*时会加上prefix,作为system_message
|
||||
desc = "" # for skill manager
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from pydantic import BaseModel, create_model, root_validator, validator
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import BaseGPTAPI
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
|
||||
from metagpt.utils.common import OutputParser, general_after_log
|
||||
|
|
@ -60,7 +60,7 @@ class ActionNode:
|
|||
|
||||
# Action Context
|
||||
context: str # all the context, including all necessary info
|
||||
llm: BaseGPTAPI # LLM with aask interface
|
||||
llm: BaseLLM # LLM with aask interface
|
||||
children: dict[str, "ActionNode"]
|
||||
|
||||
# Action Input
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic import Field
|
|||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.highlight import highlight
|
||||
|
|
@ -33,7 +33,7 @@ def run(*args) -> pd.DataFrame:
|
|||
class CloneFunction(WriteCode):
|
||||
name: str = "CloneFunction"
|
||||
context: list[Message] = []
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def _save(self, code_path, code):
|
||||
if isinstance(code_path, str):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from pydantic import Field
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -52,7 +52,7 @@ Now you should start rewriting the code:
|
|||
class DebugError(Action):
|
||||
name: str = "DebugError"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await FileRepository.get_file(
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
|
|
@ -44,7 +44,7 @@ NEW_REQ_TEMPLATE = """
|
|||
class WriteDesign(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = (
|
||||
"Based on the PRD, think about the system design, and design the corresponding APIs, "
|
||||
"data structures, library tables, processes, and paths. Please provide your design, feedback "
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class DesignReview(Action):
|
||||
name: str = "DesignReview"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, prd, api_design):
|
||||
prompt = (
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class ExecuteTask(Action):
|
||||
name: str = "ExecuteTask"
|
||||
context: list[Message] = []
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from metagpt.prompts.invoice_ocr import (
|
|||
EXTRACT_OCR_MAIN_INFO_PROMPT,
|
||||
REPLY_OCR_QUESTION_PROMPT,
|
||||
)
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.file import File
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ class InvoiceOCR(Action):
|
|||
|
||||
name: str = "InvoiceOCR"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@staticmethod
|
||||
async def _check_file_type(file_path: Path) -> str:
|
||||
|
|
@ -132,7 +132,7 @@ class GenerateTable(Action):
|
|||
|
||||
name: str = "GenerateTable"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "ch"
|
||||
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
|
||||
|
|
@ -177,7 +177,7 @@ class ReplyQuestion(Action):
|
|||
|
||||
name: str = "ReplyQuestion"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "ch"
|
||||
|
||||
async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from metagpt.actions import Action, ActionOutput
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
|
@ -28,7 +28,7 @@ class PrepareDocuments(Action):
|
|||
|
||||
name: str = "PrepareDocuments"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def _init_repo(self):
|
||||
"""Initialize the Git environment."""
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, Documents
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ NEW_REQ_TEMPLATE = """
|
|||
class WriteTasks(Action):
|
||||
name: str = "CreateTasks"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema):
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from metagpt.actions import Action
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
|
||||
from metagpt.utils.common import OutputParser
|
||||
|
|
@ -82,7 +82,7 @@ class CollectLinks(Action):
|
|||
|
||||
name: str = "CollectLinks"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Collect links from a search engine."
|
||||
search_engine: SearchEngine = Field(default_factory=SearchEngine)
|
||||
rank_func: Union[Callable[[list[str]], None], None] = None
|
||||
|
|
@ -177,7 +177,7 @@ class WebBrowseAndSummarize(Action):
|
|||
|
||||
name: str = "WebBrowseAndSummarize"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Explore the web and provide summaries of articles and webpages."
|
||||
browse_func: Union[Callable[[list[str]], None], None] = None
|
||||
web_browser_engine: WebBrowserEngine = Field(
|
||||
|
|
@ -248,7 +248,7 @@ class ConductResearch(Action):
|
|||
|
||||
name: str = "ConductResearch"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
|
@ -79,7 +79,7 @@ standard errors:
|
|||
class RunCode(Action):
|
||||
name: str = "RunCode"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.actions import Action
|
|||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
|
@ -109,7 +109,7 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
class SearchAndSummarize(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
config: None = Field(default_factory=Config)
|
||||
engine: Optional[SearchEngineType] = CONFIG.search_engine
|
||||
search_func: Optional[Any] = None
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -95,7 +95,7 @@ flowchart TB
|
|||
class SummarizeCode(Action):
|
||||
name: str = "SummarizeCode"
|
||||
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
|
||||
async def summarize_code(self, prompt):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import CodingContext, Document, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -90,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
class WriteCode(Action):
|
||||
name: str = "WriteCode"
|
||||
context: Document = Field(default_factory=Document)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code(self, prompt) -> str:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from metagpt.actions.action import Action
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -123,7 +123,7 @@ REWRITE_CODE_TEMPLATE = """
|
|||
class WriteCodeReview(Action):
|
||||
name: str = "WriteCodeReview"
|
||||
context: CodingContext = Field(default_factory=CodingContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.pycst import merge_docstring
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ class WriteDocstring(Action):
|
|||
|
||||
desc: str = "Write docstring for code."
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import BugFixContext, Document, Documents, Message
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -67,7 +67,7 @@ NEW_REQ_TEMPLATE = """
|
|||
class WritePRD(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class WritePRDReview(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
prd: Optional[str] = None
|
||||
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
|
||||
prd_review_prompt_template: str = """
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pydantic import Field
|
|||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
REVIEW = ActionNode(
|
||||
key="Review",
|
||||
|
|
@ -38,7 +38,7 @@ class WriteReview(Action):
|
|||
"""Write a review for the given context."""
|
||||
|
||||
name: str = "WriteReview"
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, context):
|
||||
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
|
|
|
|||
|
|
@ -13,14 +13,14 @@ from metagpt.actions import Action
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class WriteTeachingPlanPart(Action):
|
||||
"""Write Teaching Plan Part"""
|
||||
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
topic: str = ""
|
||||
language: str = "Chinese"
|
||||
rsp: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from metagpt.config import CONFIG
|
|||
from metagpt.const import TEST_CODES_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, TestingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -45,7 +45,7 @@ you should correctly import the necessary classes based on these file locations!
|
|||
class WriteTest(Action):
|
||||
name: str = "WriteTest"
|
||||
context: Optional[TestingContext] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def write_code(self, prompt):
|
||||
code_rsp = await self._aask(prompt)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pydantic import Field
|
|||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ class WriteDirectory(Action):
|
|||
"""
|
||||
|
||||
name: str = "WriteDirectory"
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "Chinese"
|
||||
|
||||
async def run(self, topic: str, *args, **kwargs) -> Dict:
|
||||
|
|
@ -54,7 +54,7 @@ class WriteContent(Action):
|
|||
"""
|
||||
|
||||
name: str = "WriteContent"
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
directory: dict = dict()
|
||||
language: str = "Chinese"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
|
||||
|
||||
_ = HumanProvider() # Avoid pre-commit error
|
||||
|
||||
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI:
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
if provider is None:
|
||||
provider = CONFIG.get_default_llm_provider_enum()
|
||||
|
|
|
|||
|
|
@ -10,14 +10,15 @@
|
|||
"""
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE
|
||||
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import MetaGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, SimpleMessage
|
||||
from metagpt.utils.redis import Redis
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ class BrainMemory(BaseModel):
|
|||
is_dirty: bool = False
|
||||
last_talk: str = None
|
||||
cacheable: bool = True
|
||||
llm: Optional[BaseLLM] = None
|
||||
|
||||
def add_talk(self, msg: Message):
|
||||
"""
|
||||
|
|
@ -120,6 +122,7 @@ class BrainMemory(BaseModel):
|
|||
if isinstance(llm, MetaGPTAPI):
|
||||
return await self._metagpt_summarize(max_words=max_words)
|
||||
|
||||
self.llm = llm
|
||||
return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit)
|
||||
|
||||
async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1):
|
||||
|
|
@ -131,7 +134,7 @@ class BrainMemory(BaseModel):
|
|||
text_length = len(text)
|
||||
if limit > 0 and text_length < limit:
|
||||
return text
|
||||
summary = await llm.summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
|
||||
summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit)
|
||||
if summary:
|
||||
await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS)
|
||||
return summary
|
||||
|
|
@ -251,3 +254,74 @@ class BrainMemory(BaseModel):
|
|||
texts.append(t)
|
||||
|
||||
return "\n".join(texts)
|
||||
|
||||
async def _summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
|
||||
max_token_count = DEFAULT_MAX_TOKENS
|
||||
max_count = 100
|
||||
text_length = len(text)
|
||||
if limit > 0 and text_length < limit:
|
||||
return text
|
||||
summary = ""
|
||||
while max_count > 0:
|
||||
if text_length < max_token_count:
|
||||
summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
|
||||
break
|
||||
|
||||
padding_size = 20 if max_token_count > 20 else 0
|
||||
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
|
||||
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
|
||||
summaries = []
|
||||
for ws in text_windows:
|
||||
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
|
||||
summaries.append(response)
|
||||
if len(summaries) == 1:
|
||||
summary = summaries[0]
|
||||
break
|
||||
|
||||
# Merged and retry
|
||||
text = "\n".join(summaries)
|
||||
text_length = len(text)
|
||||
|
||||
max_count -= 1 # safeguard
|
||||
return summary
|
||||
|
||||
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
|
||||
"""Generate text summary"""
|
||||
if len(text) < max_words:
|
||||
return text
|
||||
if keep_language:
|
||||
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
|
||||
else:
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
msg = text + "\n\n" + command
|
||||
logger.debug(f"summary ask:{msg}")
|
||||
response = await self.llm.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"summary rsp: {response}")
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def split_texts(text: str, window_size) -> List[str]:
|
||||
"""Splitting long text into sliding windows text"""
|
||||
if window_size <= 0:
|
||||
window_size = DEFAULT_TOKEN_SIZE
|
||||
total_len = len(text)
|
||||
if total_len <= window_size:
|
||||
return [text]
|
||||
|
||||
padding_size = 20 if window_size > 20 else 0
|
||||
windows = []
|
||||
idx = 0
|
||||
data_len = window_size - padding_size
|
||||
while idx < total_len:
|
||||
if window_size + idx > total_len: # 不足一个滑窗
|
||||
windows.append(text[idx:])
|
||||
break
|
||||
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
|
||||
# window_size=3, padding_size=1:
|
||||
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
|
||||
# idx=2, | idx=5 | idx=8 | ...
|
||||
w = text[idx : idx + window_size]
|
||||
windows.append(w)
|
||||
idx += data_len
|
||||
|
||||
return windows
|
||||
|
|
|
|||
|
|
@ -6,22 +6,22 @@
|
|||
@File : __init__.py
|
||||
"""
|
||||
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.fireworks_api import FireworksLLM
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
|
||||
__all__ = [
|
||||
"FireWorksGPTAPI",
|
||||
"FireworksLLM",
|
||||
"GeminiGPTAPI",
|
||||
"OpenLLMGPTAPI",
|
||||
"OpenAIGPTAPI",
|
||||
"OpenAILLM",
|
||||
"ZhiPuAIGPTAPI",
|
||||
"AzureOpenAIGPTAPI",
|
||||
"AzureOpenAILLM",
|
||||
"MetaGPTAPI",
|
||||
"OllamaGPTAPI",
|
||||
"OllamaLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,60 +10,36 @@
|
|||
"""
|
||||
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
|
||||
from openai import AsyncAzureOpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.AZURE_OPENAI)
|
||||
class AzureOpenAIGPTAPI(OpenAIGPTAPI):
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self._init_openai()
|
||||
self.auto_max_tokens = False
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def _make_client(self):
|
||||
kwargs, async_kwargs = self._make_client_kwargs()
|
||||
def _init_client(self):
|
||||
kwargs = self._make_client_kwargs()
|
||||
# https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
|
||||
self.client = AzureOpenAI(**kwargs)
|
||||
self.async_client = AsyncAzureOpenAI(**async_kwargs)
|
||||
self.async_client = AsyncAzureOpenAI(**kwargs)
|
||||
self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = dict(
|
||||
api_key=self.config.OPENAI_API_KEY,
|
||||
api_version=self.config.OPENAI_API_VERSION,
|
||||
azure_endpoint=self.config.OPENAI_BASE_URL,
|
||||
)
|
||||
async_kwargs = kwargs.copy()
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
proxy_params = self._get_proxy_params()
|
||||
if proxy_params:
|
||||
kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
|
||||
async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
|
||||
|
||||
return kwargs, async_kwargs
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3,
|
||||
"model": self.model,
|
||||
}
|
||||
if configs:
|
||||
kwargs.update(configs)
|
||||
kwargs["timeout"] = max(CONFIG.timeout, timeout)
|
||||
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
|
||||
|
||||
return kwargs
|
||||
|
|
|
|||
|
|
@ -1,30 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/5 23:00
|
||||
@Author : alexanderwu
|
||||
@File : base_chatbot.py
|
||||
@Modified By: mashenquan, 2023/11/21. Add `timeout`.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseChatbot(ABC):
|
||||
"""Abstract GPT class"""
|
||||
|
||||
mode: str = "API"
|
||||
use_system_prompt: bool = True
|
||||
|
||||
@abstractmethod
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
"""Ask GPT a question and get an answer"""
|
||||
|
||||
@abstractmethod
|
||||
def ask_batch(self, msgs: list, timeout=3) -> str:
|
||||
"""Ask GPT multiple questions and get a series of answers"""
|
||||
|
||||
@abstractmethod
|
||||
def ask_code(self, msgs: list, timeout=3) -> str:
|
||||
"""Ask GPT multiple questions and get a piece of code"""
|
||||
|
|
@ -3,19 +3,18 @@
|
|||
"""
|
||||
@Time : 2023/5/5 23:04
|
||||
@Author : alexanderwu
|
||||
@File : base_gpt_api.py
|
||||
@File : base_llm.py
|
||||
@Desc : mashenquan, 2023/8/22. + try catch
|
||||
"""
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.provider.base_chatbot import BaseChatbot
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
|
||||
class BaseGPTAPI(BaseChatbot):
|
||||
"""GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
|
||||
use_system_prompt: bool = True
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
|
|
@ -33,17 +32,11 @@ class BaseGPTAPI(BaseChatbot):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
|
||||
rsp = self.completion(message, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
generator: bool = False,
|
||||
timeout=3,
|
||||
stream=True,
|
||||
) -> str:
|
||||
|
|
@ -54,23 +47,12 @@ class BaseGPTAPI(BaseChatbot):
|
|||
if format_msgs:
|
||||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg))
|
||||
rsp = await self.acompletion_text(message, stream=stream, generator=generator, timeout=timeout)
|
||||
# logger.debug(rsp)
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
return rsp
|
||||
|
||||
def _extract_assistant_rsp(self, context):
|
||||
return "\n".join([i["content"] for i in context if i["role"] == "assistant"])
|
||||
|
||||
def ask_batch(self, msgs: list, timeout=3) -> str:
|
||||
context = []
|
||||
for msg in msgs:
|
||||
umsg = self._user_msg(msg)
|
||||
context.append(umsg)
|
||||
rsp = self.completion(context, timeout=timeout)
|
||||
rsp_text = self.get_choice_text(rsp)
|
||||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
async def aask_batch(self, msgs: list, timeout=3) -> str:
|
||||
"""Sequential questioning"""
|
||||
context = []
|
||||
|
|
@ -81,26 +63,11 @@ class BaseGPTAPI(BaseChatbot):
|
|||
context.append(self._assistant_msg(rsp_text))
|
||||
return self._extract_assistant_rsp(context)
|
||||
|
||||
def ask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = self.ask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
async def aask_code(self, msgs: list[str], timeout=3) -> str:
|
||||
"""FIXME: No code segment filtering has been done here, and all results are actually displayed"""
|
||||
rsp_text = await self.aask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
@abstractmethod
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
"""All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "hello, show me python hello world code"},
|
||||
# {"role": "assistant", "content": ...}, # If there is an answer in the history, also include it
|
||||
]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""Asynchronous version of completion
|
||||
|
|
@ -113,7 +80,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""Asynchronous version of completion. Return str. Support stream-print"""
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
|
|
@ -159,16 +126,3 @@ class BaseGPTAPI(BaseChatbot):
|
|||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
return json.loads(self.get_choice_function(rsp)["arguments"])
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
return "\n".join([f"{i.role}: {i.content}" for i in messages])
|
||||
|
||||
def messages_to_dict(self, messages):
|
||||
"""objects to [{"role": "user", "content": msg}] etc."""
|
||||
return [i.to_dict() for i in messages]
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
pass
|
||||
|
|
@ -18,7 +18,7 @@ from tenacity import (
|
|||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
MODEL_GRADE_TOKEN_COSTS = {
|
||||
|
|
@ -72,18 +72,17 @@ class FireworksCostManager(CostManager):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.FIREWORKS)
|
||||
class FireWorksGPTAPI(OpenAIGPTAPI):
|
||||
class FireworksLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_fireworks()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = FireworksCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_fireworks(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._make_client()
|
||||
self._init_client()
|
||||
self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
|
|
@ -103,7 +102,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
|
|||
return self._cost_manager.get_costs()
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages), stream=True
|
||||
)
|
||||
|
||||
|
|
@ -133,9 +132,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -47,8 +47,7 @@ MAX_CONNECTION_RETRIES = 2
|
|||
# Has one attribute per thread, 'session'.
|
||||
_thread_context = threading.local()
|
||||
|
||||
LLM_LOG = os.environ.get("LLM_LOG")
|
||||
LLM_LOG = "debug"
|
||||
LLM_LOG = os.environ.get("LLM_LOG", "debug")
|
||||
|
||||
|
||||
class ApiType(Enum):
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from tenacity import (
|
|||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.GEMINI)
|
||||
class GeminiGPTAPI(BaseGPTAPI):
|
||||
class GeminiGPTAPI(BaseLLM):
|
||||
"""
|
||||
Refs to `https://ai.google.dev/tutorials/python_quickstart`
|
||||
"""
|
||||
|
|
@ -136,9 +136,7 @@ class GeminiGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ Author: garylin2099
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class HumanProvider(BaseGPTAPI):
|
||||
class HumanProvider(BaseLLM):
|
||||
"""Humans provide themselves as a 'model', which actually takes in human input as its response.
|
||||
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
|
||||
"""
|
||||
|
|
@ -31,10 +31,6 @@ class HumanProvider(BaseGPTAPI):
|
|||
) -> str:
|
||||
return self.ask(msg, timeout=timeout)
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return []
|
||||
|
|
@ -42,7 +38,3 @@ class HumanProvider(BaseGPTAPI):
|
|||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""dummy implementation of abstract method in base"""
|
||||
return ""
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
@Desc : MetaGPT LLM provider.
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.provider import OpenAIGPTAPI
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.METAGPT)
|
||||
class MetaGPTAPI(OpenAIGPTAPI):
|
||||
class MetaGPTAPI(OpenAILLM):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from tenacity import (
|
|||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
|
|
@ -39,7 +39,7 @@ class OllamaCostManager(CostManager):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.OLLAMA)
|
||||
class OllamaGPTAPI(BaseGPTAPI):
|
||||
class OllamaLLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
|
||||
"""
|
||||
|
|
@ -54,12 +54,8 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
|
||||
def __init_ollama(self, config: CONFIG):
|
||||
assert config.ollama_api_base
|
||||
|
||||
self.model = config.ollama_api_model
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
|
@ -87,18 +83,6 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
chunk = chunk.decode(encoding)
|
||||
return json.loads(chunk)
|
||||
|
||||
def completion(self, messages: list[dict]) -> dict:
|
||||
resp, _, _ = self.client.request(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
params=self._const_kwargs(messages),
|
||||
request_timeout=LLM_API_TIMEOUT,
|
||||
)
|
||||
resp = self._decode_and_load(resp)
|
||||
usage = self.get_usage(resp)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
|
|
@ -111,7 +95,7 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
|
|
@ -147,9 +131,7 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from openai.types import CompletionUsage
|
|||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
|
@ -35,18 +35,17 @@ class OpenLLMCostManager(CostManager):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
class OpenLLMGPTAPI(OpenAIGPTAPI):
|
||||
class OpenLLMGPTAPI(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openllm()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = OpenLLMCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openllm(self):
|
||||
self.is_azure = False
|
||||
self.rpm = int(self.config.get("RPM", 10))
|
||||
self._make_client()
|
||||
self._init_client()
|
||||
self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
|
|
|
|||
|
|
@ -3,20 +3,17 @@
|
|||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : openai.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation;
|
||||
Change cost control from global to company level.
|
||||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
import openai
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
|
||||
from openai._base_client import AsyncHttpxClientWrapper
|
||||
from openai.types import CompletionUsage
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from tenacity import (
|
||||
|
|
@ -28,9 +25,8 @@ from tenacity import (
|
|||
)
|
||||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -43,31 +39,6 @@ from metagpt.utils.token_counter import (
|
|||
)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
|
||||
|
||||
def __init__(self, rpm):
|
||||
self.last_call_time = 0
|
||||
# Here 1.1 is used because even if the calls are made strictly according to time,
|
||||
# they will still be QOS'd; consider switching to simple error retry later
|
||||
self.interval = 1.1 * 60 / rpm
|
||||
self.rpm = rpm
|
||||
|
||||
def split_batches(self, batch):
|
||||
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
|
||||
async def wait_if_needed(self, num_requests):
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - self.last_call_time
|
||||
|
||||
if elapsed_time < self.interval * num_requests:
|
||||
remaining_time = self.interval * num_requests - elapsed_time
|
||||
logger.info(f"sleep {remaining_time}")
|
||||
await asyncio.sleep(remaining_time)
|
||||
|
||||
self.last_call_time = time.time()
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
|
|
@ -80,39 +51,31 @@ See FAQ 5.8
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPENAI)
|
||||
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
||||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
"""
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self._init_openai()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def _init_openai(self):
|
||||
self.rpm = int(self.config.RPM or 10)
|
||||
self._make_client()
|
||||
|
||||
def _make_client(self):
|
||||
kwargs, async_kwargs = self._make_client_kwargs()
|
||||
# https://github.com/openai/openai-python#async-usage
|
||||
self.client = OpenAI(**kwargs)
|
||||
self.async_client = AsyncOpenAI(**async_kwargs)
|
||||
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _make_client_kwargs(self) -> (dict, dict):
|
||||
kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
|
||||
async_kwargs = kwargs.copy()
|
||||
def _init_client(self):
|
||||
"""https://github.com/openai/openai-python#async-usage"""
|
||||
kwargs = self._make_client_kwargs()
|
||||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
proxy_params = self._get_proxy_params()
|
||||
if proxy_params:
|
||||
kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
|
||||
async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
|
||||
if proxy_params := self._get_proxy_params():
|
||||
kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
|
||||
|
||||
return kwargs, async_kwargs
|
||||
return kwargs
|
||||
|
||||
def _get_proxy_params(self) -> dict:
|
||||
params = {}
|
||||
|
|
@ -123,8 +86,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
return params
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
|
||||
response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
|
||||
**self._cons_kwargs(messages, timeout=timeout), stream=True
|
||||
)
|
||||
|
||||
|
|
@ -132,35 +95,26 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message
|
||||
yield chunk_message
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
|
||||
def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict:
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"max_tokens": self._get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3,
|
||||
"model": self.model,
|
||||
"timeout": max(CONFIG.timeout, timeout),
|
||||
}
|
||||
if configs:
|
||||
kwargs.update(configs)
|
||||
kwargs["timeout"] = max(CONFIG.timeout, timeout)
|
||||
|
||||
if extra_kwargs:
|
||||
kwargs.update(extra_kwargs)
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
kwargs = self._cons_kwargs(messages, timeout=timeout)
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
return self._chat_completion(messages, timeout=timeout)
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
|
||||
return await self._achat_completion(messages, timeout=timeout)
|
||||
|
||||
|
|
@ -171,12 +125,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
resp = self._achat_completion_stream(messages, timeout=timeout)
|
||||
if generator:
|
||||
return resp
|
||||
|
||||
collected_messages = []
|
||||
async for i in resp:
|
||||
|
|
@ -192,9 +144,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
return self.get_choice_text(rsp)
|
||||
|
||||
def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict:
|
||||
"""
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
"""
|
||||
"""Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create"""
|
||||
if "tools" not in kwargs:
|
||||
configs = {
|
||||
"tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}],
|
||||
|
|
@ -204,14 +154,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
|
||||
|
||||
def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
|
||||
rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
|
||||
kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
|
||||
rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
|
||||
rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
|
||||
self._update_costs(rsp.usage)
|
||||
return rsp
|
||||
|
||||
|
|
@ -231,56 +176,28 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
)
|
||||
return messages
|
||||
|
||||
def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
Examples:
|
||||
|
||||
>>> llm = OpenAIGPTAPI()
|
||||
>>> llm.ask_code("Write a python hello world code.")
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
|
||||
>>> llm.ask_code(msg)
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
rsp = self._chat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
|
||||
"""Use function of tools to ask a code.
|
||||
|
||||
Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
|
||||
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
Examples:
|
||||
|
||||
>>> llm = OpenAIGPTAPI()
|
||||
>>> rsp = await llm.ask_code("Write a python hello world code.")
|
||||
>>> rsp
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
>>> llm = OpenAILLM()
|
||||
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
|
||||
>>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
>>> rsp = await llm.aask_code(msg)
|
||||
# -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
messages = self._process_message(messages)
|
||||
try:
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
except openai.BadRequestError as e:
|
||||
logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
|
||||
raise e
|
||||
rsp = await self._achat_completion_function(messages, **kwargs)
|
||||
return self.get_choice_function_arguments(rsp)
|
||||
|
||||
@handle_exception
|
||||
def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
|
||||
"""Required to provide the first function arguments of choice.
|
||||
|
||||
:return dict: return the first function arguments of choice, for example,
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
try:
|
||||
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
|
||||
|
||||
def get_choice_text(self, rsp: ChatCompletion) -> str:
|
||||
"""Required to provide the first text of choice"""
|
||||
|
|
@ -295,134 +212,24 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
usage.prompt_tokens = count_message_tokens(messages, self.model)
|
||||
usage.completion_tokens = count_string_tokens(rsp, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"usage calculation failed!: {e}")
|
||||
logger.error(f"usage calculation failed: {e}")
|
||||
|
||||
return usage
|
||||
|
||||
async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]:
|
||||
"""Return full JSON"""
|
||||
split_batches = self.split_batches(batch)
|
||||
all_results = []
|
||||
|
||||
for small_batch in split_batches:
|
||||
logger.info(small_batch)
|
||||
await self.wait_if_needed(len(small_batch))
|
||||
|
||||
future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch]
|
||||
results = await asyncio.gather(*future)
|
||||
logger.info(results)
|
||||
all_results.extend(results)
|
||||
|
||||
return all_results
|
||||
|
||||
async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]:
|
||||
"""Only return plain text"""
|
||||
raw_results = await self.acompletion_batch(batch, timeout=timeout)
|
||||
results = []
|
||||
for idx, raw_result in enumerate(raw_results, start=1):
|
||||
result = self.get_choice_text(raw_result)
|
||||
results.append(result)
|
||||
logger.info(f"Result of task {idx}: {result}")
|
||||
return results
|
||||
|
||||
@handle_exception
|
||||
def _update_costs(self, usage: CompletionUsage):
|
||||
if CONFIG.calc_usage and usage:
|
||||
try:
|
||||
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
except Exception as e:
|
||||
logger.error(f"updating costs failed!, exp: {e}")
|
||||
CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
|
||||
|
||||
def get_costs(self) -> Costs:
|
||||
return CONFIG.cost_manager.get_costs()
|
||||
|
||||
def get_max_tokens(self, messages: list[dict]):
|
||||
def _get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
||||
def moderation(self, content: Union[str, list[str]]):
|
||||
return self.client.moderations.create(input=content)
|
||||
|
||||
@handle_exception
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
return await self.async_client.moderations.create(input=content)
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
if self.client:
|
||||
self.client.close()
|
||||
self.client = None
|
||||
if self.async_client:
|
||||
await self.async_client.close()
|
||||
self.async_client = None
|
||||
|
||||
async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str:
|
||||
max_token_count = DEFAULT_MAX_TOKENS
|
||||
max_count = 100
|
||||
text_length = len(text)
|
||||
if limit > 0 and text_length < limit:
|
||||
return text
|
||||
summary = ""
|
||||
while max_count > 0:
|
||||
if text_length < max_token_count:
|
||||
summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language)
|
||||
break
|
||||
|
||||
padding_size = 20 if max_token_count > 20 else 0
|
||||
text_windows = self.split_texts(text, window_size=max_token_count - padding_size)
|
||||
part_max_words = min(int(max_words / len(text_windows)) + 1, 100)
|
||||
summaries = []
|
||||
for ws in text_windows:
|
||||
response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language)
|
||||
summaries.append(response)
|
||||
if len(summaries) == 1:
|
||||
summary = summaries[0]
|
||||
break
|
||||
|
||||
# Merged and retry
|
||||
text = "\n".join(summaries)
|
||||
text_length = len(text)
|
||||
|
||||
max_count -= 1 # safeguard
|
||||
return summary
|
||||
|
||||
async def _get_summary(self, text: str, max_words=20, keep_language: bool = False):
|
||||
"""Generate text summary"""
|
||||
if len(text) < max_words:
|
||||
return text
|
||||
if keep_language:
|
||||
command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly."
|
||||
else:
|
||||
command = f"Translate the above content into a summary of less than {max_words} words."
|
||||
msg = text + "\n\n" + command
|
||||
logger.debug(f"summary ask:{msg}")
|
||||
response = await self.aask(msg=msg, system_msgs=[])
|
||||
logger.debug(f"summary rsp: {response}")
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def split_texts(text: str, window_size) -> List[str]:
|
||||
"""Splitting long text into sliding windows text"""
|
||||
if window_size <= 0:
|
||||
window_size = DEFAULT_TOKEN_SIZE
|
||||
total_len = len(text)
|
||||
if total_len <= window_size:
|
||||
return [text]
|
||||
|
||||
padding_size = 20 if window_size > 20 else 0
|
||||
windows = []
|
||||
idx = 0
|
||||
data_len = window_size - padding_size
|
||||
while idx < total_len:
|
||||
if window_size + idx > total_len: # 不足一个滑窗
|
||||
windows.append(text[idx:])
|
||||
break
|
||||
# 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....]
|
||||
# window_size=3, padding_size=1:
|
||||
# [1, 2, 3], [3, 4, 5], [5, 6, 7], ....
|
||||
# idx=2, | idx=5 | idx=8 | ...
|
||||
w = text[idx : idx + window_size]
|
||||
windows.append(w)
|
||||
idx += data_len
|
||||
|
||||
return windows
|
||||
"""Moderate content."""
|
||||
return await self.aclient.moderations.create(input=content)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/21 11:15
|
||||
@Author : Leo Xiao
|
||||
@File : anthropic_api.py
|
||||
@File : spark_api.py
|
||||
"""
|
||||
import _thread as thread
|
||||
import base64
|
||||
|
|
@ -13,7 +11,6 @@ import hmac
|
|||
import json
|
||||
import ssl
|
||||
from time import mktime
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
|
|
@ -21,52 +18,29 @@ import websocket # 使用websocket_client
|
|||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.SPARK)
|
||||
class SparkGPTAPI(BaseGPTAPI):
|
||||
class SparkLLM(BaseLLM):
|
||||
def __init__(self):
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
rsp = self.completion(message)
|
||||
return rsp
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
else:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
rsp = await self.acompletion(message)
|
||||
logger.debug(message)
|
||||
return rsp
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
async def acompletion_text(
|
||||
self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
|
||||
) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
|
||||
# 不支持
|
||||
logger.error("该功能禁用。")
|
||||
w = GetMessageFromWeb(messages)
|
||||
return w.run()
|
||||
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages)
|
||||
return w.run()
|
||||
|
||||
def completion(self, messages: list[dict]):
|
||||
w = GetMessageFromWeb(messages)
|
||||
return w.run()
|
||||
|
||||
|
||||
class GetMessageFromWeb:
|
||||
class WsParam:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from tenacity import (
|
|||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
|
@ -31,7 +31,7 @@ class ZhiPuEvent(Enum):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.ZHIPUAI)
|
||||
class ZhiPuAIGPTAPI(BaseGPTAPI):
|
||||
class ZhiPuAIGPTAPI(BaseLLM):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
|
|
@ -131,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
|
|||
retry=retry_if_exception_type(ConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
"""response in async with stream or non-stream mode"""
|
||||
if stream:
|
||||
return await self._achat_completion_stream(messages)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from metagpt.const import SERDESER_PATH
|
|||
from metagpt.llm import LLM, HumanProvider
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
|
|
@ -141,7 +141,7 @@ class Role(BaseModel):
|
|||
desc: str = ""
|
||||
is_human: bool = False
|
||||
|
||||
_llm: BaseGPTAPI = Field(default_factory=LLM) # Each role has its own LLM, use different system message
|
||||
_llm: BaseLLM = Field(default_factory=LLM) # Each role has its own LLM, use different system message
|
||||
_role_id: str = ""
|
||||
_states: list[str] = []
|
||||
_actions: list[Action] = []
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from metagpt.actions import UserRequirement
|
|||
from metagpt.actions.execute_task import ExecuteTask
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.make_sk_kernel import make_sk_kernel
|
||||
|
|
@ -44,7 +44,7 @@ class SkAgent(Role):
|
|||
plan: Any = None
|
||||
planner_cls: Any = None
|
||||
planner: Any = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
kernel: Kernel = Field(default_factory=Kernel)
|
||||
import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None
|
||||
import_skill: Type[Kernel.import_skill] = None
|
||||
|
|
|
|||
|
|
@ -21,10 +21,6 @@ class OpenAIText2Image:
|
|||
"""
|
||||
self._llm = LLM()
|
||||
|
||||
def __del__(self):
|
||||
if self._llm:
|
||||
self._llm.close()
|
||||
|
||||
async def text_2_image(self, text, size_type="1024x1024"):
|
||||
"""Text to image
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
|
||||
|
||||
ICL_SAMPLE = """Interface definition:
|
||||
```text
|
||||
|
|
@ -278,11 +278,11 @@ class UTGenerator:
|
|||
question += self.build_api_doc(node, path, method)
|
||||
self.ask_gpt_and_save(question, tag, summary)
|
||||
|
||||
def gpt_msgs_to_code(self, messages: list) -> str:
|
||||
async def gpt_msgs_to_code(self, messages: list) -> str:
|
||||
"""Choose based on different calling methods"""
|
||||
result = ""
|
||||
if self.chatgpt_method == "API":
|
||||
result = GPTAPI().ask_code(msgs=messages)
|
||||
result = await GPTAPI().aask_code(msgs=messages)
|
||||
|
||||
return result
|
||||
|
||||
|
|
|
|||
|
|
@ -22,9 +22,9 @@ async def test_design_api():
|
|||
for prd in inputs:
|
||||
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
|
||||
|
||||
design_api = WriteDesign("design_api")
|
||||
design_api = WriteDesign()
|
||||
|
||||
result = await design_api.run([Message(content=prd, instruct_content=None)])
|
||||
result = await design_api.run(Message(content=prd, instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -6,10 +6,26 @@
|
|||
@File : test_project_management.py
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
class TestCreateProjectPlan:
|
||||
pass
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_json import DESIGN, PRD
|
||||
|
||||
|
||||
class TestAssignTasks:
|
||||
pass
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
|
||||
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
logger.info(CONFIG.git_repo)
|
||||
|
||||
action = WriteTasks()
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
||||
assert result
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
default_chat_resp = {
|
||||
|
|
@ -27,14 +27,14 @@ prompt_msg = "who are you"
|
|||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
class MockBaseGPTAPI(BaseGPTAPI):
|
||||
class MockBaseGPTAPI(BaseLLM):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
|
||||
async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
|
||||
return resp_content
|
||||
|
||||
async def close(self):
|
||||
|
|
@ -47,11 +47,6 @@ def test_base_gpt_api():
|
|||
assert "user" in str(message)
|
||||
|
||||
base_gpt_api = MockBaseGPTAPI()
|
||||
msg_prompt = base_gpt_api.messages_to_prompt([message])
|
||||
assert msg_prompt == "user: hello"
|
||||
|
||||
msg_dict = base_gpt_api.messages_to_dict([message])
|
||||
assert msg_dict == [{"role": "user", "content": "hello"}]
|
||||
|
||||
openai_funccall_resp = {
|
||||
"choices": [
|
||||
|
|
@ -87,14 +82,14 @@ def test_base_gpt_api():
|
|||
choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
|
||||
assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
|
||||
|
||||
resp = base_gpt_api.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
# resp = base_gpt_api.ask(prompt_msg)
|
||||
# assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
# resp = base_gpt_api.ask_batch([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
resp = base_gpt_api.ask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
# resp = base_gpt_api.ask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from openai.types.completion_usage import CompletionUsage
|
|||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireWorksGPTAPI,
|
||||
FireworksLLM,
|
||||
)
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
|
|
@ -55,17 +55,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return default_resp.choices[0].message.content
|
||||
|
||||
|
||||
def test_fireworks_completion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
|
||||
resp = fireworks_gpt.completion(messages)
|
||||
assert resp.choices[0].message.content == resp_content
|
||||
|
||||
resp = fireworks_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
|
||||
|
|
@ -73,7 +62,7 @@ async def test_fireworks_acompletion(mocker):
|
|||
mocker.patch(
|
||||
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
fireworks_gpt = FireworksLLM()
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages, stream=False)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
|
|
|||
|
|
@ -35,16 +35,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp.text == resp_content
|
||||
|
||||
resp = gemini_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
|
||||
|
|
|
|||
|
|
@ -17,15 +17,6 @@ async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
|||
return mock_llm_ask(msg)
|
||||
|
||||
|
||||
def test_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
assert resp_content == human_provider.ask(None)
|
||||
|
||||
assert not human_provider.completion(messages=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
|
@ -28,22 +28,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
resp = ollama_gpt.completion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = ollama_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
ollama_gpt = OllamaLLM()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -18,7 +18,7 @@ async def test_aask_code():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -28,7 +28,7 @@ async def test_aask_code_str():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -36,52 +36,6 @@ async def test_aask_code_Message():
|
|||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_list_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_list_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
print(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
|
|
@ -130,7 +84,7 @@ class TestOpenAI:
|
|||
)
|
||||
|
||||
def test_make_client_kwargs_without_proxy(self, config):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
|
|
@ -139,7 +93,7 @@ class TestOpenAI:
|
|||
assert "http_client" not in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
|
|
@ -148,14 +102,14 @@ class TestOpenAI:
|
|||
assert "http_client" not in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy(self, config_proxy):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_proxy
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
assert "http_client" in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure_proxy
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkGPTAPI
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
|
@ -18,24 +18,13 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False,
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_spark_completion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = spark_gpt.completion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = spark_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([], stream=False)
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
|
|
|
|||
|
|
@ -28,18 +28,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = zhipu_gpt.completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
resp = zhipu_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
|
||||
|
|
|
|||
|
|
@ -14,23 +14,6 @@ from metagpt.logs import logger
|
|||
|
||||
@pytest.mark.usefixtures("llm_api")
|
||||
class TestGPT:
|
||||
def test_llm_api_ask(self, llm_api):
|
||||
answer = llm_api.ask("hello chatgpt")
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
def test_gptapi_ask_batch(self, llm_api):
|
||||
answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60)
|
||||
assert len(answer) > 0
|
||||
|
||||
def test_llm_api_ask_code(self, llm_api):
|
||||
try:
|
||||
answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"])
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
except openai.BadRequestError:
|
||||
assert CONFIG.OPENAI_API_TYPE == "azure"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask(self, llm_api):
|
||||
answer = await llm_api.aask("hello chatgpt", stream=False)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -23,18 +23,11 @@ async def test_llm_aask(llm):
|
|||
assert len(rsp) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_aask_batch(llm):
|
||||
assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_acompletion(llm):
|
||||
hello_msg = [{"role": "user", "content": "hello"}]
|
||||
rsp = await llm.acompletion(hello_msg)
|
||||
assert len(rsp.choices[0].message.content) > 0
|
||||
assert len(await llm.acompletion_batch([hello_msg])) > 0
|
||||
assert len(await llm.acompletion_batch_text([hello_msg])) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/7 17:23
|
||||
@Author : alexanderwu
|
||||
@File : test_custom_aio_session.py
|
||||
"""
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
|
||||
|
||||
async def try_hello(api):
|
||||
batch = [[{"role": "user", "content": "hello"}]]
|
||||
results = await api.acompletion_batch_text(batch)
|
||||
return results
|
||||
|
||||
|
||||
async def aask_batch(api: OpenAIGPTAPI):
|
||||
results = await api.aask_batch(["hi", "write python hello world."])
|
||||
logger.info(results)
|
||||
return results
|
||||
Loading…
Add table
Add a link
Reference in a new issue