Merge branch 'geekan:dev' into dev

This commit is contained in:
better629 2024-01-15 20:18:52 +08:00 committed by GitHub
commit 71067c894b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
23 changed files with 242 additions and 123 deletions

Binary file not shown.

View file

@ -34,8 +34,10 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
node: ActionNode = Field(default=None, exclude=True)
@property
def project_repo(self):
return ProjectRepo(self.context.git_repo)
def repo(self) -> ProjectRepo:
if not self.context.repo:
self.context.repo = ProjectRepo(self.context.git_repo)
return self.context.repo
@property
def prompt_schema(self):

View file

@ -49,7 +49,7 @@ class DebugError(Action):
i_context: RunCodeContext = Field(default_factory=RunCodeContext)
async def run(self, *args, **kwargs) -> str:
output_doc = await self.project_repo.test_outputs.get(filename=self.i_context.output_filename)
output_doc = await self.repo.test_outputs.get(filename=self.i_context.output_filename)
if not output_doc:
return ""
output_detail = RunCodeResult.loads(output_doc.content)
@ -59,12 +59,12 @@ class DebugError(Action):
return ""
logger.info(f"Debug and rewrite {self.i_context.test_filename}")
code_doc = await self.project_repo.with_src_path(self.context.src_workspace).srcs.get(
code_doc = await self.repo.with_src_path(self.context.src_workspace).srcs.get(
filename=self.i_context.code_filename
)
if not code_doc:
return ""
test_doc = await self.project_repo.tests.get(filename=self.i_context.test_filename)
test_doc = await self.repo.tests.get(filename=self.i_context.test_filename)
if not test_doc:
return ""
prompt = PROMPT_TEMPLATE.format(code=code_doc.content, test_code=test_doc.content, logs=output_detail.stderr)

View file

@ -40,10 +40,10 @@ class WriteDesign(Action):
async def run(self, with_messages: Message, schema: str = None):
# Use `git status` to identify which PRD documents have been modified in the `docs/prds` directory.
changed_prds = self.project_repo.docs.prd.changed_files
changed_prds = self.repo.docs.prd.changed_files
# Use `git status` to identify which design documents in the `docs/system_designs` directory have undergone
# changes.
changed_system_designs = self.project_repo.docs.system_design.changed_files
changed_system_designs = self.repo.docs.system_design.changed_files
# For those PRDs and design documents that have undergone changes, regenerate the design content.
changed_files = Documents()
@ -73,21 +73,21 @@ class WriteDesign(Action):
return system_design_doc
async def _update_system_design(self, filename) -> Document:
prd = await self.project_repo.docs.prd.get(filename)
old_system_design_doc = await self.project_repo.docs.system_design.get(filename)
prd = await self.repo.docs.prd.get(filename)
old_system_design_doc = await self.repo.docs.system_design.get(filename)
if not old_system_design_doc:
system_design = await self._new_system_design(context=prd.content)
doc = await self.project_repo.docs.system_design.save(
doc = await self.repo.docs.system_design.save(
filename=filename,
content=system_design.instruct_content.model_dump_json(),
dependencies={prd.root_relative_path},
)
else:
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
await self.project_repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self.repo.docs.system_design.save_doc(doc=doc, dependencies={prd.root_relative_path})
await self._save_data_api_design(doc)
await self._save_seq_flow(doc)
await self.project_repo.resources.system_design.save_pdf(doc=doc)
await self.repo.resources.system_design.save_pdf(doc=doc)
return doc
async def _save_data_api_design(self, design_doc):
@ -95,7 +95,7 @@ class WriteDesign(Action):
data_api_design = m.get("Data structures and interfaces")
if not data_api_design:
return
pathname = self.project_repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
pathname = self.repo.workdir / DATA_API_DESIGN_FILE_REPO / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(data_api_design, pathname)
logger.info(f"Save class view to {str(pathname)}")
@ -104,7 +104,7 @@ class WriteDesign(Action):
seq_flow = m.get("Program call flow")
if not seq_flow:
return
pathname = self.project_repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
pathname = self.repo.workdir / Path(SEQ_FLOW_FILE_REPO) / Path(design_doc.filename).with_suffix("")
await self._save_mermaid_file(seq_flow, pathname)
logger.info(f"Saving sequence flow to {str(pathname)}")

View file

@ -15,6 +15,7 @@ from metagpt.actions import Action, ActionOutput
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.utils.file_repository import FileRepository
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class PrepareDocuments(Action):
@ -38,13 +39,14 @@ class PrepareDocuments(Action):
shutil.rmtree(path)
self.config.project_path = path
self.context.git_repo = GitRepository(local_path=path, auto_init=True)
self.context.repo = ProjectRepo(self.context.git_repo)
async def run(self, with_messages, **kwargs):
"""Create and initialize the workspace folder, initialize the Git environment."""
self._init_repo()
# Write the newly added requirements from the main parameter idea to `docs/requirement.txt`.
doc = await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
doc = await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content=with_messages[0].content)
# Send a Message notification to the WritePRD action, instructing it to process requirements using
# `docs/requirement.txt` and `docs/prds/`.
return ActionOutput(content=doc.content, instruct_content=doc)

View file

@ -13,8 +13,8 @@
import json
from typing import Optional
from metagpt.actions import ActionOutput
from metagpt.actions.action import Action
from metagpt.actions.action_output import ActionOutput
from metagpt.actions.project_management_an import PM_NODE
from metagpt.const import PACKAGE_REQUIREMENTS_FILENAME
from metagpt.logs import logger
@ -34,8 +34,8 @@ class WriteTasks(Action):
i_context: Optional[str] = None
async def run(self, with_messages):
changed_system_designs = self.project_repo.docs.system_design.changed_files
changed_tasks = self.project_repo.docs.task.changed_files
changed_system_designs = self.repo.docs.system_design.changed_files
changed_tasks = self.repo.docs.task.changed_files
change_files = Documents()
# Rewrite the system designs that have undergone changes based on the git head diff under
# `docs/system_designs/`.
@ -57,16 +57,14 @@ class WriteTasks(Action):
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _update_tasks(self, filename):
system_design_doc = await self.project_repo.docs.system_design.get(filename)
task_doc = await self.project_repo.docs.task.get(filename)
system_design_doc = await self.repo.docs.system_design.get(filename)
task_doc = await self.repo.docs.task.get(filename)
if task_doc:
task_doc = await self._merge(system_design_doc=system_design_doc, task_doc=task_doc)
await self.project_repo.docs.task.save_doc(
doc=task_doc, dependencies={system_design_doc.root_relative_path}
)
await self.repo.docs.task.save_doc(doc=task_doc, dependencies={system_design_doc.root_relative_path})
else:
rsp = await self._run_new_tasks(context=system_design_doc.content)
task_doc = await self.project_repo.docs.task.save(
task_doc = await self.repo.docs.task.save(
filename=filename,
content=rsp.instruct_content.model_dump_json(),
dependencies={system_design_doc.root_relative_path},
@ -87,7 +85,7 @@ class WriteTasks(Action):
async def _update_requirements(self, doc):
m = json.loads(doc.content)
packages = set(m.get("Required Python third-party packages", set()))
requirement_doc = await self.project_repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
requirement_doc = await self.repo.get(filename=PACKAGE_REQUIREMENTS_FILENAME)
if not requirement_doc:
requirement_doc = Document(filename=PACKAGE_REQUIREMENTS_FILENAME, root_path=".", content="")
lines = requirement_doc.content.splitlines()
@ -95,4 +93,4 @@ class WriteTasks(Action):
if pkg == "":
continue
packages.add(pkg)
await self.project_repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))
await self.repo.save(filename=PACKAGE_REQUIREMENTS_FILENAME, content="\n".join(packages))

View file

@ -98,10 +98,10 @@ class SummarizeCode(Action):
async def run(self):
design_pathname = Path(self.i_context.design_filename)
design_doc = await self.project_repo.docs.system_design.get(filename=design_pathname.name)
design_doc = await self.repo.docs.system_design.get(filename=design_pathname.name)
task_pathname = Path(self.i_context.task_filename)
task_doc = await self.project_repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.project_repo.with_src_path(self.context.src_workspace).srcs
task_doc = await self.repo.docs.task.get(filename=task_pathname.name)
src_file_repo = self.repo.with_src_path(self.context.src_workspace).srcs
code_blocks = []
for filename in self.i_context.codes_filenames:
code_doc = await src_file_repo.get(filename)

View file

@ -88,12 +88,12 @@ class WriteCode(Action):
return code
async def run(self, *args, **kwargs) -> CodingContext:
bug_feedback = await self.project_repo.docs.get(filename=BUGFIX_FILENAME)
bug_feedback = await self.repo.docs.get(filename=BUGFIX_FILENAME)
coding_context = CodingContext.loads(self.i_context.content)
test_doc = await self.project_repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
test_doc = await self.repo.test_outputs.get(filename="test_" + coding_context.filename + ".json")
summary_doc = None
if coding_context.design_doc and coding_context.design_doc.filename:
summary_doc = await self.project_repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
summary_doc = await self.repo.docs.code_summary.get(filename=coding_context.design_doc.filename)
logs = ""
if test_doc:
test_detail = RunCodeResult.loads(test_doc.content)
@ -105,7 +105,7 @@ class WriteCode(Action):
code_context = await self.get_codes(
coding_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
prompt = PROMPT_TEMPLATE.format(

View file

@ -143,7 +143,7 @@ class WriteCodeReview(Action):
code_context = await WriteCode.get_codes(
self.i_context.task_doc,
exclude=self.i_context.filename,
project_repo=self.project_repo.with_src_path(self.context.src_workspace),
project_repo=self.repo.with_src_path(self.context.src_workspace),
)
context = "\n".join(
[

View file

@ -15,7 +15,6 @@ from __future__ import annotations
import json
from pathlib import Path
from typing import Optional
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
@ -58,96 +57,106 @@ NEW_REQ_TEMPLATE = """
class WritePRD(Action):
name: str = "WritePRD"
content: Optional[str] = None
"""WritePRD deal with the following situations:
1. Bugfix: If the requirement is a bugfix, the bugfix document will be generated.
2. New requirement: If the requirement is a new requirement, the PRD document will be generated.
3. Requirement update: If the requirement is an update, the PRD document will be updated.
"""
async def run(self, with_messages, *args, **kwargs) -> ActionOutput | Message:
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
# related to the PRD. If they are related, rewrite the PRD.
requirement_doc = await self.project_repo.docs.get(filename=REQUIREMENT_FILENAME)
if requirement_doc and await self._is_bugfix(requirement_doc.content):
await self.project_repo.docs.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
await self.project_repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
sent_from=self,
send_to="Alex", # the name of Engineer
)
"""Run the action."""
req: Document = await self.repo.requirement
docs: list[Document] = await self.repo.docs.prd.get_all()
if not req:
raise FileNotFoundError("No requirement document found.")
if await self._is_bugfix(req.content):
logger.info(f"Bugfix detected: {req.content}")
return await self._handle_bugfix(req)
# remove bugfix file from last round in case of conflict
await self.repo.docs.delete(filename=BUGFIX_FILENAME)
# if requirement is related to other documents, update them, otherwise create a new one
if related_docs := await self.get_related_docs(req, docs):
logger.info(f"Requirement update detected: {req.content}")
return await self._handle_requirement_update(req, related_docs)
else:
await self.project_repo.docs.delete(filename=BUGFIX_FILENAME)
logger.info(f"New requirement detected: {req.content}")
return await self._handle_new_requirement(req)
prd_docs = await self.project_repo.docs.prd.get_all()
change_files = Documents()
for prd_doc in prd_docs:
prd_doc = await self._update_prd(requirement_doc=requirement_doc, prd_doc=prd_doc, *args, **kwargs)
if not prd_doc:
continue
change_files.docs[prd_doc.filename] = prd_doc
logger.info(f"rewrite prd: {prd_doc.filename}")
# If there is no existing PRD, generate one using 'docs/requirement.txt'.
if not change_files.docs:
prd_doc = await self._update_prd(requirement_doc=requirement_doc, *args, **kwargs)
if prd_doc:
change_files.docs[prd_doc.filename] = prd_doc
logger.debug(f"new prd: {prd_doc.filename}")
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
# 'publish' message to transition the workflow to the next stage. This design allows room for global
# optimization in subsequent steps.
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _handle_bugfix(self, req: Document) -> Message:
# ... bugfix logic ...
await self.repo.docs.save(filename=BUGFIX_FILENAME, content=req.content)
await self.repo.docs.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
sent_from=self,
send_to="Alex", # the name of Engineer
)
async def _run_new_requirement(self, requirements) -> ActionOutput:
async def _handle_new_requirement(self, req: Document) -> ActionOutput:
"""handle new requirement"""
project_name = self.project_name
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
context = CONTEXT_TEMPLATE.format(requirements=req, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
await self._rename_workspace(node)
return node
new_prd_doc = await self.repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=node.instruct_content.model_dump_json()
)
await self._save_competitive_analysis(new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return Documents.from_iterable(documents=[new_prd_doc]).to_action_output()
async def _is_relative(self, new_requirement_doc, old_prd_doc) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd_doc.content, requirements=new_requirement_doc.content)
async def _handle_requirement_update(self, req: Document, related_docs: list[Document]) -> ActionOutput:
# ... requirement update logic ...
for doc in related_docs:
await self._update_prd(req, doc)
return Documents.from_iterable(documents=related_docs).to_action_output()
async def _is_bugfix(self, context: str) -> bool:
if not self.repo.code_files_exists():
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
"""get the related documents"""
# refine: use gather to speed up
return [i for i in docs if await self._is_related(req, i)]
async def _is_related(self, req: Document, old_prd: Document) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
return node.get("is_relative") == "YES"
async def _merge(self, new_requirement_doc, prd_doc) -> Document:
async def _merge(self, req: Document, related_doc: Document) -> Document:
if not self.project_name:
self.project_name = Path(self.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
prd_doc.content = node.instruct_content.model_dump_json()
related_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return prd_doc
return related_doc
async def _update_prd(self, requirement_doc, prd_doc=None, *args, **kwargs) -> Document | None:
if not prd_doc:
prd = await self._run_new_requirement(
requirements=[requirement_doc.content if requirement_doc else ""], *args, **kwargs
)
new_prd_doc = await self.project_repo.docs.prd.save(
filename=FileRepository.new_filename() + ".json", content=prd.instruct_content.model_dump_json()
)
elif await self._is_relative(requirement_doc, prd_doc):
new_prd_doc = await self._merge(requirement_doc, prd_doc)
self.project_repo.docs.prd.save_doc(doc=new_prd_doc)
else:
return None
async def _update_prd(self, req: Document, prd_doc: Document) -> Document:
new_prd_doc: Document = await self._merge(req, prd_doc)
self.repo.docs.prd.save_doc(doc=new_prd_doc)
await self._save_competitive_analysis(new_prd_doc)
await self.project_repo.resources.prd.save_pdf(doc=new_prd_doc)
await self.repo.resources.prd.save_pdf(doc=new_prd_doc)
return new_prd_doc
async def _save_competitive_analysis(self, prd_doc):
async def _save_competitive_analysis(self, prd_doc: Document):
m = json.loads(prd_doc.content)
quadrant_chart = m.get("Competitive Quadrant Chart")
if not quadrant_chart:
return
pathname = (
self.project_repo.workdir / Path(COMPETITIVE_ANALYSIS_FILE_REPO) / Path(prd_doc.filename).with_suffix("")
)
if not pathname.parent.exists():
pathname.parent.mkdir(parents=True, exist_ok=True)
pathname = self.repo.workdir / COMPETITIVE_ANALYSIS_FILE_REPO / Path(prd_doc.filename).stem
pathname.parent.mkdir(parents=True, exist_ok=True)
await mermaid_to_file(self.config.mermaid_engine, quadrant_chart, pathname)
async def _rename_workspace(self, prd):
@ -158,15 +167,4 @@ class WritePRD(Action):
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
self.project_name = ws_name
self.project_repo.git_repo.rename_root(self.project_name)
async def _is_bugfix(self, context) -> bool:
git_workdir = self.project_repo.git_repo.workdir
src_workdir = git_workdir / git_workdir.name
if not src_workdir.exists():
return False
code_files = self.project_repo.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
return node.get("issue_type") == "BUG"
self.repo.git_repo.rename_root(self.project_name)

View file

@ -17,6 +17,7 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import CostManager
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class AttrDict(BaseModel):
@ -58,6 +59,8 @@ class Context(BaseModel):
kwargs: AttrDict = AttrDict()
config: Config = Config.default()
repo: Optional[ProjectRepo] = None
git_repo: Optional[GitRepository] = None
src_workspace: Optional[Path] = None
cost_manager: CostManager = CostManager()
@ -67,8 +70,8 @@ class Context(BaseModel):
def new_environ(self):
"""Return a new os.environ object"""
env = os.environ.copy()
i = self.options
env.update({k: v for k, v in i.items() if isinstance(v, str)})
# i = self.options
# env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
# def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:

View file

@ -235,3 +235,7 @@ class OpenAILLM(BaseLLM):
async def amoderation(self, content: Union[str, list[str]]):
"""Moderate content."""
return await self.aclient.moderations.create(input=content)
async def atext_to_speech(self, **kwargs):
"""text to speech"""
return await self.aclient.audio.speech.create(**kwargs)

View file

@ -10,8 +10,9 @@
from pydantic import Field
from metagpt.actions import ActionOutput, SearchAndSummarize
from metagpt.actions import SearchAndSummarize
from metagpt.actions.action_node import ActionNode
from metagpt.actions.action_output import ActionOutput
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message

View file

@ -23,7 +23,7 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from typing import Any, Dict, Iterable, List, Optional, Type, TypeVar, Union
from pydantic import (
BaseModel,
@ -162,6 +162,26 @@ class Documents(BaseModel):
docs: Dict[str, Document] = Field(default_factory=dict)
@classmethod
def from_iterable(cls, documents: Iterable[Document]) -> Documents:
"""Create a Documents instance from a list of Document instances.
:param documents: A list of Document instances.
:return: A Documents instance.
"""
docs = {doc.filename: doc for doc in documents}
return Documents(docs=docs)
def to_action_output(self) -> "ActionOutput":
"""Convert to action output string.
:return: A string representing action output.
"""
from metagpt.actions.action_output import ActionOutput
return ActionOutput(content=self.model_dump_json(), instruct_content=self)
class Message(BaseModel):
"""list[<role>: <content>]"""

View file

@ -21,6 +21,7 @@ from metagpt.const import (
GRAPH_REPO_FILE_REPO,
PRD_PDF_FILE_REPO,
PRDS_FILE_REPO,
REQUIREMENT_FILENAME,
RESOURCES_FILE_REPO,
SD_OUTPUT_FILE_REPO,
SEQ_FLOW_FILE_REPO,
@ -93,6 +94,10 @@ class ProjectRepo(FileRepository):
self.test_outputs = self._git_repo.new_file_repository(relative_path=TEST_OUTPUTS_FILE_REPO)
self._srcs_path = None
@property
async def requirement(self):
return await self.docs.get(filename=REQUIREMENT_FILENAME)
@property
def git_repo(self) -> GitRepository:
return self._git_repo
@ -107,6 +112,15 @@ class ProjectRepo(FileRepository):
raise ValueError("Call with_srcs first.")
return self._git_repo.new_file_repository(self._srcs_path)
def code_files_exists(self) -> bool:
git_workdir = self.git_repo.workdir
src_workdir = git_workdir / git_workdir.name
if not src_workdir.exists():
return False
code_files = self.with_src_path(path=git_workdir / git_workdir.name).srcs.all_files
if not code_files:
return False
def with_src_path(self, path: str | Path) -> ProjectRepo:
try:
self._srcs_path = Path(path).relative_to(self.workdir)

File diff suppressed because one or more lines are too long

View file

@ -11,7 +11,7 @@ from pathlib import Path
import pytest
from metagpt.config2 import config
from metagpt.config2 import Config
from metagpt.learn.text_to_embedding import text_to_embedding
from metagpt.utils.common import aread
@ -19,13 +19,14 @@ from metagpt.utils.common import aread
@pytest.mark.asyncio
async def test_text_to_embedding(mocker):
# mock
config = Config.default()
mock_post = mocker.patch("aiohttp.ClientSession.post")
mock_response = mocker.AsyncMock()
mock_response.status = 200
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
mock_response.json.return_value = json.loads(data)
mock_post.return_value.__aenter__.return_value = mock_response
type(config.get_openai_llm()).proxy = mocker.PropertyMock(return_value="http://mock.proxy")
config.get_openai_llm().proxy = mocker.PropertyMock(return_value="http://mock.proxy")
# Prerequisites
assert config.get_openai_llm().api_key

View file

@ -42,11 +42,23 @@ async def test_aask_code_message():
assert len(rsp["code"]) > 0
@pytest.mark.asyncio
async def test_text_to_speech():
llm = LLM()
resp = await llm.atext_to_speech(
model="tts-1",
voice="alloy",
input="人生说起来长,但知道一个岁月回头看,许多事件仅是仓促的。一段一段拼凑一起,合成了人生。苦难当头时,当下不免觉得是折磨;回头看,也不够是一段短短的人生旅程。",
)
assert 200 == resp.response.status_code
class TestOpenAI:
def test_make_client_kwargs_without_proxy(self):
instance = OpenAILLM(mock_llm_config)
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "mock_api_key", "base_url": "mock_base_url"}
assert kwargs["api_key"] == "mock_api_key"
assert kwargs["base_url"] == "mock_base_url"
assert "http_client" not in kwargs
def test_make_client_kwargs_with_proxy(self):

View file

@ -48,7 +48,6 @@ def test_context_1():
assert ctx.git_repo is None
assert ctx.src_workspace is None
assert ctx.cost_manager is not None
assert ctx.options is not None
def test_context_2():

View file

@ -95,7 +95,7 @@ def test_config_mixin_4_multi_inheritance_override_config():
print(obj.__dict__.keys())
assert "private_config" in obj.__dict__.keys()
assert obj.llm.model == "mock_zhipu_model"
assert obj.config.llm.model == "mock_zhipu_model"
@pytest.mark.asyncio

View file

@ -10,7 +10,7 @@ from pathlib import Path
import pytest
from metagpt.config2 import config
from metagpt.config2 import Config
from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding
from metagpt.utils.common import aread
@ -18,6 +18,7 @@ from metagpt.utils.common import aread
@pytest.mark.asyncio
async def test_embedding(mocker):
# mock
config = Config.default()
mock_post = mocker.patch("aiohttp.ClientSession.post")
mock_response = mocker.AsyncMock()
mock_response.status = 200