Merge branch 'main' into incremental_development

# Conflicts:
#	metagpt/schema.py
This commit is contained in:
mannaandpoem 2024-01-19 19:53:17 +08:00
commit e1b783ca14
127 changed files with 2346 additions and 648 deletions

View file

@ -88,6 +88,8 @@ class InvoiceOCR(Action):
async def _ocr(invoice_file_path: Path):
ocr = PaddleOCR(use_angle_cls=True, lang="ch", page_num=1)
ocr_result = ocr.ocr(str(invoice_file_path), cls=True)
for result in ocr_result[0]:
result[1] = (result[1][0], round(result[1][1], 2)) # round long confidence scores to reduce token costs
return ocr_result
async def run(self, file_path: Path, *args, **kwargs) -> list:

View file

@ -35,7 +35,6 @@ class PrepareDocuments(Action):
if path.exists() and not CONFIG.inc:
shutil.rmtree(path)
CONFIG.project_path = path
CONFIG.project_name = path.name
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):

View file

@ -9,60 +9,209 @@
import re
from pathlib import Path
import aiofiles
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import CLASS_VIEW_FILE_REPO, GRAPH_REPO_FILE_REPO
from metagpt.const import (
AGGREGATION,
COMPOSITION,
DATA_API_DESIGN_FILE_REPO,
GENERALIZATION,
GRAPH_REPO_FILE_REPO,
)
from metagpt.logs import logger
from metagpt.repo_parser import RepoParser
from metagpt.schema import ClassAttribute, ClassMethod, ClassView
from metagpt.utils.common import split_namespace
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword, GraphRepository
class RebuildClassView(Action):
def __init__(self, name="", context=None, llm=None):
super().__init__(name=name, context=context, llm=llm)
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
repo_parser = RepoParser(base_directory=self.context)
class_views = await repo_parser.rebuild_class_views(path=Path(self.context)) # use pylint
repo_parser = RepoParser(base_directory=Path(self.context))
# use pylint
class_views, relationship_views, package_root = await repo_parser.rebuild_class_views(path=Path(self.context))
await GraphRepository.update_graph_db_with_class_views(graph_db, class_views)
symbols = repo_parser.generate_symbols() # use ast
await GraphRepository.update_graph_db_with_class_relationship_views(graph_db, relationship_views)
# use ast
direction, diff_path = self._diff_path(path_root=Path(self.context).resolve(), package_root=package_root)
symbols = repo_parser.generate_symbols()
for file_info in symbols:
# Align to the same root directory in accordance with `class_views`.
file_info.file = self._align_root(file_info.file, direction, diff_path)
await GraphRepository.update_graph_db_with_file_info(graph_db, file_info)
await self._create_mermaid_class_view(graph_db=graph_db)
await self._save(graph_db=graph_db)
await self._create_mermaid_class_views(graph_db=graph_db)
await graph_db.save()
async def _create_mermaid_class_view(self, graph_db):
pass
# dataset = await graph_db.select(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_PAGE_INFO)
# if not dataset:
# logger.warning(f"No page info for {concat_namespace(filename, class_name)}")
# return
# code_block_info = CodeBlockInfo.parse_raw(dataset[0].object_)
# src_code = await read_file_block(filename=Path(self.context) / filename, lineno=code_block_info.lineno, end_lineno=code_block_info.end_lineno)
# code_type = ""
# dataset = await graph_db.select(subject=filename, predicate=GraphKeyword.IS)
# for spo in dataset:
# if spo.object_ in ["javascript", "python"]:
# code_type = spo.object_
# break
async def _create_mermaid_class_views(self, graph_db):
path = Path(CONFIG.git_repo.workdir) / DATA_API_DESIGN_FILE_REPO
path.mkdir(parents=True, exist_ok=True)
pathname = path / CONFIG.git_repo.workdir.name
async with aiofiles.open(str(pathname.with_suffix(".mmd")), mode="w", encoding="utf-8") as writer:
content = "classDiagram\n"
logger.debug(content)
await writer.write(content)
# class names
rows = await graph_db.select(predicate=GraphKeyword.IS, object_=GraphKeyword.CLASS)
class_distinct = set()
relationship_distinct = set()
for r in rows:
await RebuildClassView._create_mermaid_class(r.subject, graph_db, writer, class_distinct)
for r in rows:
await RebuildClassView._create_mermaid_relationship(r.subject, graph_db, writer, relationship_distinct)
# try:
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
# class_view = node.instruct_content.model_dump()["Class View"]
# except Exception as e:
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
# logger.info(f"{concat_namespace(filename, class_name)} {GraphKeyword.HAS_CLASS_VIEW} {class_view}")
@staticmethod
async def _create_mermaid_class(ns_class_name, graph_db, file_writer, distinct):
fields = split_namespace(ns_class_name)
if len(fields) > 2:
# Ignore sub-class
return
async def _save(self, graph_db):
class_view_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CLASS_VIEW_FILE_REPO)
dataset = await graph_db.select(predicate=GraphKeyword.HAS_CLASS_VIEW)
all_class_view = []
for spo in dataset:
title = f"---\ntitle: {spo.subject}\n---\n"
filename = re.sub(r"[/:]", "_", spo.subject) + ".mmd"
await class_view_file_repo.save(filename=filename, content=title + spo.object_)
all_class_view.append(spo.object_)
await class_view_file_repo.save(filename="all.mmd", content="\n".join(all_class_view))
class_view = ClassView(name=fields[1])
rows = await graph_db.select(subject=ns_class_name)
for r in rows:
name = split_namespace(r.object_)[-1]
name, visibility, abstraction = RebuildClassView._parse_name(name=name, language="python")
if r.predicate == GraphKeyword.HAS_CLASS_PROPERTY:
var_type = await RebuildClassView._parse_variable_type(r.object_, graph_db)
attribute = ClassAttribute(
name=name, visibility=visibility, abstraction=bool(abstraction), value_type=var_type
)
class_view.attributes.append(attribute)
elif r.predicate == GraphKeyword.HAS_CLASS_FUNCTION:
method = ClassMethod(name=name, visibility=visibility, abstraction=bool(abstraction))
await RebuildClassView._parse_function_args(method, r.object_, graph_db)
class_view.methods.append(method)
# update graph db
await graph_db.insert(ns_class_name, GraphKeyword.HAS_CLASS_VIEW, class_view.model_dump_json())
content = class_view.get_mermaid(align=1)
logger.debug(content)
await file_writer.write(content)
distinct.add(ns_class_name)
@staticmethod
async def _create_mermaid_relationship(ns_class_name, graph_db, file_writer, distinct):
s_fields = split_namespace(ns_class_name)
if len(s_fields) > 2:
# Ignore sub-class
return
predicates = {GraphKeyword.IS + v + GraphKeyword.OF: v for v in [GENERALIZATION, COMPOSITION, AGGREGATION]}
mappings = {
GENERALIZATION: " <|-- ",
COMPOSITION: " *-- ",
AGGREGATION: " o-- ",
}
content = ""
for p, v in predicates.items():
rows = await graph_db.select(subject=ns_class_name, predicate=p)
for r in rows:
o_fields = split_namespace(r.object_)
if len(o_fields) > 2:
# Ignore sub-class
continue
relationship = mappings.get(v, " .. ")
link = f"{o_fields[1]}{relationship}{s_fields[1]}"
distinct.add(link)
content += f"\t{link}\n"
if content:
logger.debug(content)
await file_writer.write(content)
@staticmethod
def _parse_name(name: str, language="python"):
pattern = re.compile(r"<I>(.*?)<\/I>")
result = re.search(pattern, name)
abstraction = ""
if result:
name = result.group(1)
abstraction = "*"
if name.startswith("__"):
visibility = "-"
elif name.startswith("_"):
visibility = "#"
else:
visibility = "+"
return name, visibility, abstraction
@staticmethod
async def _parse_variable_type(ns_name, graph_db) -> str:
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_TYPE_DESC)
if not rows:
return ""
vals = rows[0].object_.replace("'", "").split(":")
if len(vals) == 1:
return ""
val = vals[-1].strip()
return "" if val == "NoneType" else val + " "
@staticmethod
async def _parse_function_args(method: ClassMethod, ns_name: str, graph_db: GraphRepository):
rows = await graph_db.select(subject=ns_name, predicate=GraphKeyword.HAS_ARGS_DESC)
if not rows:
return
info = rows[0].object_.replace("'", "")
fs_tag = "("
ix = info.find(fs_tag)
fe_tag = "):"
eix = info.rfind(fe_tag)
if eix < 0:
fe_tag = ")"
eix = info.rfind(fe_tag)
args_info = info[ix + len(fs_tag) : eix].strip()
method.return_type = info[eix + len(fe_tag) :].strip()
if method.return_type == "None":
method.return_type = ""
if "(" in method.return_type:
method.return_type = method.return_type.replace("(", "Tuple[").replace(")", "]")
# parse args
if not args_info:
return
splitter_ixs = []
cost = 0
for i in range(len(args_info)):
if args_info[i] == "[":
cost += 1
elif args_info[i] == "]":
cost -= 1
if args_info[i] == "," and cost == 0:
splitter_ixs.append(i)
splitter_ixs.append(len(args_info))
args = []
ix = 0
for eix in splitter_ixs:
args.append(args_info[ix:eix])
ix = eix + 1
for arg in args:
parts = arg.strip().split(":")
if len(parts) == 1:
method.args.append(ClassAttribute(name=parts[0].strip()))
continue
method.args.append(ClassAttribute(name=parts[0].strip(), value_type=parts[-1].strip()))
@staticmethod
def _diff_path(path_root: Path, package_root: Path) -> (str, str):
if len(str(path_root)) > len(str(package_root)):
return "+", str(path_root.relative_to(package_root))
if len(str(path_root)) < len(str(package_root)):
return "-", str(package_root.relative_to(path_root))
return "=", "."
@staticmethod
def _align_root(path: str, direction: str, diff_path: str):
if direction == "=":
return path
if direction == "+":
return diff_path + "/" + path
else:
return path[len(diff_path) + 1 :]

View file

@ -1,33 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19
@Author : mashenquan
@File : rebuild_class_view_an.py
@Desc : Defines `ActionNode` objects used by rebuild_class_view.py
"""
from metagpt.actions.action_node import ActionNode
CLASS_SOURCE_CODE_BLOCK = ActionNode(
key="Class View",
expected_type=str,
instruction='Generate the mermaid class diagram corresponding to source code in "context."',
example="""
classDiagram
class A {
-int x
+int y
-int speed
-int direction
+__init__(x: int, y: int, speed: int, direction: int)
+change_direction(new_direction: int) None
+move() None
}
""",
)
REBUILD_CLASS_VIEW_NODES = [
CLASS_SOURCE_CODE_BLOCK,
]
REBUILD_CLASS_VIEW_NODE = ActionNode.from_children("RebuildClassView", REBUILD_CLASS_VIEW_NODES)

View file

@ -0,0 +1,60 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view.py
@Desc : Rebuild sequence view info
"""
from __future__ import annotations
from pathlib import Path
from typing import List
from metagpt.actions import Action
from metagpt.config import CONFIG
from metagpt.const import GRAPH_REPO_FILE_REPO
from metagpt.logs import logger
from metagpt.utils.common import aread, list_files
from metagpt.utils.di_graph_repository import DiGraphRepository
from metagpt.utils.graph_repository import GraphKeyword
class RebuildSequenceView(Action):
async def run(self, with_messages=None, format=CONFIG.prompt_schema):
graph_repo_pathname = CONFIG.git_repo.workdir / GRAPH_REPO_FILE_REPO / CONFIG.git_repo.workdir.name
graph_db = await DiGraphRepository.load_from(str(graph_repo_pathname.with_suffix(".json")))
entries = await RebuildSequenceView._search_main_entry(graph_db)
for entry in entries:
await self._rebuild_sequence_view(entry, graph_db)
await graph_db.save()
@staticmethod
async def _search_main_entry(graph_db) -> List:
rows = await graph_db.select(predicate=GraphKeyword.HAS_PAGE_INFO)
tag = "__name__:__main__"
entries = []
for r in rows:
if tag in r.subject or tag in r.object_:
entries.append(r)
return entries
async def _rebuild_sequence_view(self, entry, graph_db):
filename = entry.subject.split(":", 1)[0]
src_filename = RebuildSequenceView._get_full_filename(root=self.context, pathname=filename)
content = await aread(filename=src_filename, encoding="utf-8")
content = f"```python\n{content}\n```\n\n---\nTranslate the code above into Mermaid Sequence Diagram."
data = await self.llm.aask(
msg=content, system_msgs=["You are a python code to Mermaid Sequence Diagram translator in function detail"]
)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_SEQUENCE_VIEW, object_=data)
logger.info(data)
@staticmethod
def _get_full_filename(root: str | Path, pathname: str | Path) -> Path | None:
files = list_files(root=root)
postfix = "/" + str(pathname)
for i in files:
if str(i).endswith(postfix):
return i
return None

View file

@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4
@Author : mashenquan
@File : rebuild_sequence_view_an.py
"""
from metagpt.actions.action_node import ActionNode
from metagpt.utils.mermaid import MMC2
CODE_2_MERMAID_SEQUENCE_DIAGRAM = ActionNode(
key="Program call flow",
expected_type=str,
instruction='Translate the "context" content into "format example" format.',
example=MMC2,
)

View file

@ -158,7 +158,7 @@ class WriteCode(Action):
if not coding_context.code_doc:
# avoid root_path pydantic ValidationError if use WriteCode alone
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
coding_context.code_doc = Document(filename=coding_context.filename, root_path=str(root_path))
coding_context.code_doc.content = code
return coding_context

View file

@ -188,9 +188,11 @@ class WriteCodeReview(Action):
cr_prompt = EXAMPLE_AND_INSTRUCTION.format(
format_example=format_example,
)
len1 = len(iterative_code) if iterative_code else 0
len2 = len(self.context.code_doc.content) if self.context.code_doc.content else 0
logger.info(
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
f"{len(self.context.code_doc.content)=}"
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | len(iterative_code)={len1}, "
f"len(self.context.code_doc.content)={len2}"
)
result, rewrited_code = await self.write_code_review_and_rewrite(
context_prompt, cr_prompt, self.context.code_doc.filename

View file

@ -14,6 +14,7 @@
from __future__ import annotations
import json
import uuid
from pathlib import Path
from typing import Optional
@ -120,7 +121,7 @@ class WritePRD(Action):
# if sas.result:
# logger.info(sas.result)
# logger.info(rsp)
project_name = CONFIG.project_name if CONFIG.project_name else ""
project_name = CONFIG.project_name or ""
context = CONTEXT_TEMPLATE.format(requirements=requirements, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(context=context, llm=self.llm, exclude=exclude) # schema=schema
@ -190,6 +191,8 @@ class WritePRD(Action):
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
if ws_name:
CONFIG.project_name = ws_name
if not CONFIG.project_name: # The LLM failed to provide a project name, and the user didn't provide one either.
CONFIG.project_name = "app" + uuid.uuid4().hex[:16]
CONFIG.git_repo.rename_root(CONFIG.project_name)
async def _is_bugfix(self, context) -> bool:

View file

@ -11,7 +11,6 @@
from typing import Optional
from metagpt.actions.action import Action
from metagpt.config import CONFIG
from metagpt.const import TEST_CODES_FILE_REPO
from metagpt.logs import logger
from metagpt.schema import Document, TestingContext
@ -60,11 +59,12 @@ class WriteTest(Action):
self.context.test_doc = Document(
filename="test_" + self.context.code_doc.filename, root_path=TEST_CODES_FILE_REPO
)
fake_root = "/data"
prompt = PROMPT_TEMPLATE.format(
code_to_test=self.context.code_doc.content,
test_file_name=self.context.test_doc.filename,
source_file_path=self.context.code_doc.root_relative_path,
workspace=CONFIG.git_repo.workdir,
source_file_path=fake_root + "/" + self.context.code_doc.root_relative_path,
workspace=fake_root,
)
self.context.test_doc.content = await self.write_code(prompt)
return self.context

View file

@ -50,6 +50,9 @@ class LLMProviderEnum(Enum):
AZURE_OPENAI = "azure_openai"
OLLAMA = "ollama"
def __missing__(self, key):
return self.OPENAI
class Config(metaclass=Singleton):
"""
@ -108,6 +111,11 @@ class Config(metaclass=Singleton):
if v:
provider = k
break
if provider is None:
if self.DEFAULT_PROVIDER:
provider = LLMProviderEnum(self.DEFAULT_PROVIDER)
else:
raise NotConfiguredException("You should config a LLM configuration first")
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
@ -117,7 +125,6 @@ class Config(metaclass=Singleton):
if provider:
logger.info(f"API: {provider}")
return provider
raise NotConfiguredException("You should config a LLM configuration first")
def get_model_name(self, provider=None) -> str:
provider = provider or self.get_default_llm_provider_enum()
@ -137,6 +144,7 @@ class Config(metaclass=Singleton):
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("ANTHROPIC_API_KEY")
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
self.zhipuai_api_model = self._get("ZHIPUAI_API_MODEL")
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")

View file

@ -129,3 +129,8 @@ LLM_API_TIMEOUT = 300
# Message id
IGNORED_MESSAGE_ID = "0"
# Class Relationship
GENERALIZATION = "Generalize"
COMPOSITION = "Composite"
AGGREGATION = "Aggregate"

View file

@ -67,7 +67,7 @@ class SkillsDeclaration(BaseModel):
@staticmethod
async def load(skill_yaml_file_name: Path = None) -> "SkillsDeclaration":
if not skill_yaml_file_name:
skill_yaml_file_name = Path(__file__).parent.parent.parent / ".well-known/skills.yaml"
skill_yaml_file_name = Path(__file__).parent.parent.parent / "docs/.well-known/skills.yaml"
async with aiofiles.open(str(skill_yaml_file_name), mode="r") as reader:
data = await reader.read(-1)
skill_data = yaml.safe_load(data)

View file

@ -43,7 +43,9 @@ class BaseLLM(ABC):
if system_msgs:
message = self._system_msgs(system_msgs)
else:
message = [self._default_system_msg()] if self.use_system_prompt else []
message = [self._default_system_msg()]
if not self.use_system_prompt:
message = []
if format_msgs:
message.extend(format_msgs)
message.append(self._user_msg(msg))
@ -87,6 +89,10 @@ class BaseLLM(ABC):
"""Required to provide the first text of choice"""
return rsp.get("choices")[0]["message"]["content"]
def get_choice_delta_text(self, rsp: dict) -> str:
"""Required to provide the first text of stream choice"""
return rsp.get("choices")[0]["delta"]["content"]
def get_choice_function(self, rsp: dict) -> dict:
"""Required to provide the first function of choice
:param dict rsp: OpenAI chat.comletion respond JSON, Note "message" must include "tool_calls",

View file

@ -79,10 +79,8 @@ class GeneralAPIRequestor(APIRequestor):
async def _interpret_async_response(
self, result: aiohttp.ClientResponse, stream: bool
) -> Tuple[Union[bytes, AsyncGenerator[bytes, None]], bool]:
if stream and (
"text/event-stream" in result.headers.get("Content-Type", "")
or "application/x-ndjson" in result.headers.get("Content-Type", "")
):
content_type = result.headers.get("Content-Type", "")
if stream and ("text/event-stream" in content_type or "application/x-ndjson" in content_type):
# the `Content-Type` of ollama stream resp is "application/x-ndjson"
return (
self._interpret_response_line(line, result.status, result.headers, stream=True)

View file

@ -120,6 +120,7 @@ class GeminiLLM(BaseLLM):
content = chunk.text
log_llm_stream(content)
collected_content.append(content)
log_llm_stream("\n")
full_content = "".join(collected_content)
usage = await self.aget_usage(messages, full_content)

View file

@ -119,6 +119,7 @@ class OllamaLLM(BaseLLM):
else:
# stream finished
usage = self.get_usage(chunk)
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)

View file

@ -134,6 +134,7 @@ class OpenAILLM(BaseLLM):
async for i in resp:
log_llm_stream(i)
collected_messages.append(i)
log_llm_stream("\n")
full_reply_content = "".join(collected_messages)
usage = self._calc_usage(messages, full_reply_content)

View file

@ -1,75 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : async_sse_client to make keep the use of Event to access response
# refs to `https://github.com/zhipuai/zhipuai-sdk-python/blob/main/zhipuai/utils/sse_client.py`
# refs to `zhipuai/core/_sse_client.py`
from zhipuai.utils.sse_client import _FIELD_SEPARATOR, Event, SSEClient
import json
from typing import Any, Iterator
class AsyncSSEClient(SSEClient):
async def _aread(self):
data = b""
class AsyncSSEClient(object):
def __init__(self, event_source: Iterator[Any]):
self._event_source = event_source
async def stream(self) -> dict:
if isinstance(self._event_source, bytes):
raise RuntimeError(
f"Request failed, msg: {self._event_source.decode('utf-8')}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
)
async for chunk in self._event_source:
for line in chunk.splitlines(True):
data += line
if data.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
yield data
data = b""
if data:
yield data
line = chunk.decode("utf-8")
if line.startswith(":") or not line:
return
async def async_events(self):
async for chunk in self._aread():
event = Event()
# Split before decoding so splitlines() only uses \r and \n
for line in chunk.splitlines():
# Decode the line.
line = line.decode(self._char_enc)
# Lines starting with a separator are comments and are to be
# ignored.
if not line.strip() or line.startswith(_FIELD_SEPARATOR):
continue
data = line.split(_FIELD_SEPARATOR, 1)
field = data[0]
# Ignore unknown fields.
if field not in event.__dict__:
self._logger.debug("Saw invalid field %s while parsing " "Server Side Event", field)
continue
if len(data) > 1:
# From the spec:
# "If value starts with a single U+0020 SPACE character,
# remove it from value."
if data[1].startswith(" "):
value = data[1][1:]
else:
value = data[1]
else:
# If no value is present after the separator,
# assume an empty value.
value = ""
# The data field may come over multiple lines and their values
# are concatenated with each other.
if field == "data":
event.__dict__[field] += value + "\n"
else:
event.__dict__[field] = value
# Events with no data are not dispatched.
if not event.data:
continue
# If the data field ends with a newline, remove it.
if event.data.endswith("\n"):
event.data = event.data[0:-1]
# Empty event names default to 'message'
event.event = event.event or "message"
# Dispatch the event
self._logger.debug("Dispatching %s...", event)
yield event
field, _p, value = line.partition(":")
if value.startswith(" "):
value = value[1:]
if field == "data":
if value.startswith("[DONE]"):
break
data = json.loads(value)
yield data

View file

@ -4,46 +4,27 @@
import json
import zhipuai
from zhipuai.model_api.api import InvokeType, ModelAPI
from zhipuai.utils.http_client import headers as zhipuai_default_headers
from zhipuai import ZhipuAI
from zhipuai.core._http_client import ZHIPUAI_DEFAULT_TIMEOUT
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
class ZhiPuModelAPI(ModelAPI):
@classmethod
def get_header(cls) -> dict:
token = cls._generate_token()
zhipuai_default_headers.update({"Authorization": token})
return zhipuai_default_headers
@classmethod
def get_sse_header(cls) -> dict:
token = cls._generate_token()
headers = {"Authorization": token}
return headers
@classmethod
def split_zhipu_api_url(cls, invoke_type: InvokeType, kwargs):
class ZhiPuModelAPI(ZhipuAI):
def split_zhipu_api_url(self):
# use this method to prevent zhipu api upgrading to different version.
# and follow the GeneralAPIRequestor implemented based on openai sdk
zhipu_api_url = cls._build_api_url(kwargs, invoke_type)
"""
example:
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
"""
zhipu_api_url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
arr = zhipu_api_url.split("/api/")
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
# ("https://open.bigmodel.cn/api" , "/paas/v4/chat/completions")
return f"{arr[0]}/api", f"/{arr[1]}"
@classmethod
async def arequest(cls, invoke_type: InvokeType, stream: bool, method: str, headers: dict, kwargs):
async def arequest(self, stream: bool, method: str, headers: dict, kwargs):
# TODO to make the async request to be more generic for models in http mode.
assert method in ["post", "get"]
base_url, url = cls.split_zhipu_api_url(invoke_type, kwargs)
base_url, url = self.split_zhipu_api_url()
requester = GeneralAPIRequestor(base_url=base_url)
result, _, api_key = await requester.arequest(
method=method,
@ -51,25 +32,23 @@ class ZhiPuModelAPI(ModelAPI):
headers=headers,
stream=stream,
params=kwargs,
request_timeout=zhipuai.api_timeout_seconds,
request_timeout=ZHIPUAI_DEFAULT_TIMEOUT.read,
)
return result
@classmethod
async def ainvoke(cls, **kwargs) -> dict:
async def acreate(self, **kwargs) -> dict:
"""async invoke different from raw method `async_invoke` which get the final result by task_id"""
headers = cls.get_header()
resp = await cls.arequest(
invoke_type=InvokeType.SYNC, stream=False, method="post", headers=headers, kwargs=kwargs
)
headers = self._default_headers
resp = await self.arequest(stream=False, method="post", headers=headers, kwargs=kwargs)
resp = resp.decode("utf-8")
resp = json.loads(resp)
if "error" in resp:
raise RuntimeError(
f"Request failed, msg: {resp}, please ref to `https://open.bigmodel.cn/dev/api#error-code-v3`"
)
return resp
@classmethod
async def asse_invoke(cls, **kwargs) -> AsyncSSEClient:
async def acreate_stream(self, **kwargs) -> AsyncSSEClient:
"""async sse_invoke"""
headers = cls.get_sse_header()
return AsyncSSEClient(
await cls.arequest(invoke_type=InvokeType.SSE, stream=True, method="post", headers=headers, kwargs=kwargs)
)
headers = self._default_headers
return AsyncSSEClient(await self.arequest(stream=True, method="post", headers=headers, kwargs=kwargs))

View file

@ -2,11 +2,9 @@
# -*- coding: utf-8 -*-
# @Desc : zhipuai LLM from https://open.bigmodel.cn/dev/api#sdk
import json
from enum import Enum
import openai
import zhipuai
from requests import ConnectionError
from tenacity import (
after_log,
@ -15,6 +13,7 @@ from tenacity import (
stop_after_attempt,
wait_random_exponential,
)
from zhipuai.types.chat.chat_completion import Completion
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import log_llm_stream, logger
@ -35,26 +34,25 @@ class ZhiPuEvent(Enum):
class ZhiPuAILLM(BaseLLM):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
From now, there is only one model named `chatglm_turbo`
From now, support glm-3-turboglm-4, and also system_prompt.
"""
def __init__(self):
self.__init_zhipuai(CONFIG)
self.llm = ZhiPuModelAPI
self.model = "chatglm_turbo" # so far only one model, just use it
self.use_system_prompt: bool = False # zhipuai has no system prompt when use api
self.llm = ZhiPuModelAPI(api_key=self.api_key)
def __init_zhipuai(self, config: CONFIG):
assert config.zhipuai_api_key
zhipuai.api_key = config.zhipuai_api_key
self.api_key = config.zhipuai_api_key
self.model = config.zhipuai_api_model # so far, it support glm-3-turbo、glm-4
# due to use openai sdk, set the api_key but it will't be used.
# openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used.
if config.openai_proxy:
# FIXME: openai v1.x sdk has no proxy support
openai.proxy = config.openai_proxy
def _const_kwargs(self, messages: list[dict]) -> dict:
kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3}
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
kwargs = {"model": self.model, "messages": messages, "stream": stream, "temperature": 0.3}
return kwargs
def _update_costs(self, usage: dict):
@ -67,21 +65,15 @@ class ZhiPuAILLM(BaseLLM):
except Exception as e:
logger.error(f"zhipuai updats costs failed! exp: {e}")
def get_choice_text(self, resp: dict) -> str:
"""get the first text of choice from llm response"""
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
assert assist_msg["role"] == "assistant"
return assist_msg.get("content")
def completion(self, messages: list[dict], timeout=3) -> dict:
resp = self.llm.invoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp: Completion = self.llm.chat.completions.create(**self._const_kwargs(messages))
usage = resp.usage.model_dump()
self._update_costs(usage)
return resp
return resp.model_dump()
async def _achat_completion(self, messages: list[dict], timeout=3) -> dict:
resp = await self.llm.ainvoke(**self._const_kwargs(messages))
usage = resp.get("data").get("usage")
resp = await self.llm.acreate(**self._const_kwargs(messages))
usage = resp.get("usage", {})
self._update_costs(usage)
return resp
@ -89,35 +81,19 @@ class ZhiPuAILLM(BaseLLM):
return await self._achat_completion(messages, timeout=timeout)
async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str:
response = await self.llm.asse_invoke(**self._const_kwargs(messages))
response = await self.llm.acreate_stream(**self._const_kwargs(messages, stream=True))
collected_content = []
usage = {}
async for event in response.async_events():
if event.event == ZhiPuEvent.ADD.value:
content = event.data
async for chunk in response.stream():
finish_reason = chunk.get("choices")[0].get("finish_reason")
if finish_reason == "stop":
usage = chunk.get("usage", {})
else:
content = self.get_choice_delta_text(chunk)
collected_content.append(content)
log_llm_stream(content)
elif event.event == ZhiPuEvent.ERROR.value or event.event == ZhiPuEvent.INTERRUPTED.value:
content = event.data
logger.error(f"event error: {content}", end="")
elif event.event == ZhiPuEvent.FINISH.value:
"""
event.meta
{
"task_status":"SUCCESS",
"usage":{
"completion_tokens":351,
"prompt_tokens":595,
"total_tokens":946
},
"task_id":"xx",
"request_id":"xxx"
}
"""
meta = json.loads(event.meta)
usage = meta.get("usage")
else:
print(f"zhipuapi else event: {event.data}", end="")
log_llm_stream("\n")
self._update_costs(usage)
full_content = "".join(collected_content)

View file

@ -12,14 +12,14 @@ import json
import re
import subprocess
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional
import aiofiles
import pandas as pd
from pydantic import BaseModel, Field
from metagpt.const import AGGREGATION, COMPOSITION, GENERALIZATION
from metagpt.logs import logger
from metagpt.utils.common import any_to_str
from metagpt.utils.common import any_to_str, aread
from metagpt.utils.exceptions import handle_exception
@ -46,6 +46,13 @@ class ClassInfo(BaseModel):
methods: Dict[str, str] = Field(default_factory=dict)
class ClassRelationship(BaseModel):
src: str = ""
dest: str = ""
relationship: str = ""
label: Optional[str] = None
class RepoParser(BaseModel):
base_directory: Path = Field(default=None)
@ -60,7 +67,8 @@ class RepoParser(BaseModel):
file_info = RepoFileInfo(file=str(file_path.relative_to(self.base_directory)))
for node in tree:
info = RepoParser.node_to_str(node)
file_info.page_info.append(info)
if info:
file_info.page_info.append(info)
if isinstance(node, ast.ClassDef):
class_methods = [m.name for m in node.body if is_func(m)]
file_info.classes.append({"name": node.name, "methods": class_methods})
@ -110,7 +118,9 @@ class RepoParser(BaseModel):
return output_path
@staticmethod
def node_to_str(node) -> (int, int, str, str | Tuple):
def node_to_str(node) -> CodeBlockInfo | None:
if isinstance(node, ast.Try):
return None
if any_to_str(node) == any_to_str(ast.Expr):
return CodeBlockInfo(
lineno=node.lineno,
@ -129,6 +139,7 @@ class RepoParser(BaseModel):
},
any_to_str(ast.If): RepoParser._parse_if,
any_to_str(ast.AsyncFunctionDef): lambda x: x.name,
any_to_str(ast.AnnAssign): lambda x: RepoParser._parse_variable(x.target),
}
func = mappings.get(any_to_str(node))
if func:
@ -143,7 +154,8 @@ class RepoParser(BaseModel):
else:
raise NotImplementedError(f"Not implement:{val}")
return code_block
raise NotImplementedError(f"Not implement code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
logger.warning(f"Unsupported code block:{node.lineno}, {node.end_lineno}, {any_to_str(node)}")
return None
@staticmethod
def _parse_expr(node) -> List:
@ -164,22 +176,51 @@ class RepoParser(BaseModel):
@staticmethod
def _parse_if(n):
tokens = [RepoParser._parse_variable(n.test.left)]
for item in n.test.comparators:
tokens.append(RepoParser._parse_variable(item))
tokens = []
try:
if isinstance(n.test, ast.BoolOp):
tokens = []
for v in n.test.values:
tokens.extend(RepoParser._parse_if_compare(v))
return tokens
if isinstance(n.test, ast.Compare):
v = RepoParser._parse_variable(n.test.left)
if v:
tokens.append(v)
for item in n.test.comparators:
v = RepoParser._parse_variable(item)
if v:
tokens.append(v)
return tokens
except Exception as e:
logger.warning(f"Unsupported if: {n}, err:{e}")
return tokens
@staticmethod
def _parse_if_compare(n):
if hasattr(n, "left"):
return RepoParser._parse_variable(n.left)
else:
return []
@staticmethod
def _parse_variable(node):
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
try:
funcs = {
any_to_str(ast.Constant): lambda x: x.value,
any_to_str(ast.Name): lambda x: x.id,
any_to_str(ast.Attribute): lambda x: f"{x.value.id}.{x.attr}"
if hasattr(x.value, "id")
else f"{x.attr}",
any_to_str(ast.Call): lambda x: RepoParser._parse_variable(x.func),
any_to_str(ast.Tuple): lambda x: "",
}
func = funcs.get(any_to_str(node))
if not func:
raise NotImplementedError(f"Not implement:{node}")
return func(node)
except Exception as e:
logger.warning(f"Unsupported variable:{node}, err:{e}")
@staticmethod
def _parse_assign(node):
@ -197,18 +238,21 @@ class RepoParser(BaseModel):
raise ValueError(f"{result}")
class_view_pathname = path / "classes.dot"
class_views = await self._parse_classes(class_view_pathname)
relationship_views = await self._parse_class_relationships(class_view_pathname)
packages_pathname = path / "packages.dot"
class_views = RepoParser._repair_namespaces(class_views=class_views, path=path)
class_views, relationship_views, package_root = RepoParser._repair_namespaces(
class_views=class_views, relationship_views=relationship_views, path=path
)
class_view_pathname.unlink(missing_ok=True)
packages_pathname.unlink(missing_ok=True)
return class_views
return class_views, relationship_views, package_root
async def _parse_classes(self, class_view_pathname):
class_views = []
if not class_view_pathname.exists():
return class_views
async with aiofiles.open(str(class_view_pathname), mode="r") as reader:
lines = await reader.readlines()
data = await aread(filename=class_view_pathname, encoding="utf-8")
lines = data.split("\n")
for line in lines:
package_name, info = RepoParser._split_class_line(line)
if not package_name:
@ -229,6 +273,19 @@ class RepoParser(BaseModel):
class_views.append(class_info)
return class_views
async def _parse_class_relationships(self, class_view_pathname) -> List[ClassRelationship]:
relationship_views = []
if not class_view_pathname.exists():
return relationship_views
data = await aread(filename=class_view_pathname, encoding="utf-8")
lines = data.split("\n")
for line in lines:
relationship = RepoParser._split_relationship_line(line)
if not relationship:
continue
relationship_views.append(relationship)
return relationship_views
@staticmethod
def _split_class_line(line):
part_splitor = '" ['
@ -247,6 +304,40 @@ class RepoParser(BaseModel):
info = re.sub(r"<br[^>]*>", "\n", info)
return class_name, info
@staticmethod
def _split_relationship_line(line):
splitters = [" -> ", " [", "];"]
idxs = []
for tag in splitters:
if tag not in line:
return None
idxs.append(line.find(tag))
ret = ClassRelationship()
ret.src = line[0 : idxs[0]].strip('"')
ret.dest = line[idxs[0] + len(splitters[0]) : idxs[1]].strip('"')
properties = line[idxs[1] + len(splitters[1]) : idxs[2]].strip(" ")
mappings = {
'arrowhead="empty"': GENERALIZATION,
'arrowhead="diamond"': COMPOSITION,
'arrowhead="odiamond"': AGGREGATION,
}
for k, v in mappings.items():
if k in properties:
ret.relationship = v
if v != GENERALIZATION:
ret.label = RepoParser._get_label(properties)
break
return ret
@staticmethod
def _get_label(line):
tag = 'label="'
if tag not in line:
return ""
ix = line.find(tag)
eix = line.find('"', ix + len(tag))
return line[ix + len(tag) : eix]
@staticmethod
def _create_path_mapping(path: str | Path) -> Dict[str, str]:
mappings = {
@ -271,9 +362,11 @@ class RepoParser(BaseModel):
return mappings
@staticmethod
def _repair_namespaces(class_views: List[ClassInfo], path: str | Path) -> List[ClassInfo]:
def _repair_namespaces(
class_views: List[ClassInfo], relationship_views: List[ClassRelationship], path: str | Path
) -> (List[ClassInfo], List[ClassRelationship], str):
if not class_views:
return []
return [], [], ""
c = class_views[0]
full_key = str(path).lstrip("/").replace("/", ".")
root_namespace = RepoParser._find_root(full_key, c.package)
@ -290,7 +383,12 @@ class RepoParser(BaseModel):
for c in class_views:
c.package = RepoParser._repair_ns(c.package, new_mappings)
return class_views
for i in range(len(relationship_views)):
v = relationship_views[i]
v.src = RepoParser._repair_ns(v.src, new_mappings)
v.dest = RepoParser._repair_ns(v.dest, new_mappings)
relationship_views[i] = v
return class_views, relationship_views, root_path
@staticmethod
def _repair_ns(package, mappings):

View file

@ -36,7 +36,8 @@ class QaEngineer(Role):
profile: str = "QaEngineer"
goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs"
constraints: str = (
"The test code you write should conform to code standard like PEP8, be modular, " "easy to read and maintain"
"The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain."
"Use same language as user requirement"
)
test_round_allowed: int = 5
test_round: int = 0
@ -62,6 +63,8 @@ class QaEngineer(Role):
if not filename or "test" in filename:
continue
code_doc = await src_file_repo.get(filename)
if not code_doc:
continue
test_doc = await tests_file_repo.get("test_" + code_doc.filename)
if not test_doc:
test_doc = Document(

View file

@ -166,6 +166,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
Role.model_rebuild()
super().__init__(**data)
if self.is_human:
self.llm = HumanProvider()
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
@ -418,7 +421,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
Use llm to select actions in _think dynamically
"""
actions_taken = 0
rsp = Message(content="No actions taken yet") # will be overwritten after Role _act
rsp = Message(content="No actions taken yet", cause_by=Action) # will be overwritten after Role _act
while actions_taken < self.rc.max_react_loop:
# think
await self._think()

View file

@ -459,3 +459,63 @@ class CodePlanAndChangeContext(BaseContext):
prd_docs: List[Document]
design_docs: List[Document]
tasks_docs: List[Document]
# mermaid class view
class ClassMeta(BaseModel):
name: str = ""
abstraction: bool = False
static: bool = False
visibility: str = ""
class ClassAttribute(ClassMeta):
value_type: str = ""
default_value: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
if self.value_type:
content += self.value_type + " "
content += self.name
if self.default_value:
content += "="
if self.value_type not in ["str", "string", "String"]:
content += self.default_value
else:
content += '"' + self.default_value.replace('"', "") + '"'
if self.abstraction:
content += "*"
if self.static:
content += "$"
return content
class ClassMethod(ClassMeta):
args: List[ClassAttribute] = Field(default_factory=list)
return_type: str = ""
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + self.visibility
content += self.name + "(" + ",".join([v.get_mermaid(align=0) for v in self.args]) + ")"
if self.return_type:
content += ":" + self.return_type
if self.abstraction:
content += "*"
if self.static:
content += "$"
return content
class ClassView(ClassMeta):
attributes: List[ClassAttribute] = Field(default_factory=list)
methods: List[ClassMethod] = Field(default_factory=list)
def get_mermaid(self, align=1) -> str:
content = "".join(["\t" for i in range(align)]) + "class " + self.name + "{\n"
for v in self.attributes:
content += v.get_mermaid(align=align + 1) + "\n"
for v in self.methods:
content += v.get_mermaid(align=align + 1) + "\n"
content += "".join(["\t" for i in range(align)]) + "}\n"
return content

View file

@ -83,11 +83,12 @@ class Team(BaseModel):
"""Invest company. raise NoMoneyException when exceed max_budget."""
self.investment = investment
CONFIG.max_budget = investment
CONFIG.cost_manager.max_budget = investment
logger.info(f"Investment: ${investment}.")
@staticmethod
def _check_balance():
if CONFIG.cost_manager.total_cost > CONFIG.cost_manager.max_budget:
if CONFIG.cost_manager.total_cost >= CONFIG.cost_manager.max_budget:
raise NoMoneyException(
CONFIG.cost_manager.total_cost, f"Insufficient funds: {CONFIG.cost_manager.max_budget}"
)

View file

@ -5,6 +5,12 @@
@Author : mashenquan
@File : metagpt_oas3_api_svc.py
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
curl -X 'POST' \
'http://localhost:8080/openapi/greeting/dave' \
-H 'accept: text/plain' \
-H 'Content-Type: application/json' \
-d '{}'
"""
from pathlib import Path
@ -15,7 +21,7 @@ import connexion
def oas_http_svc():
"""Start the OAS 3.0 OpenAPI HTTP service"""
print("http://localhost:8080/oas3/ui/")
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
app.add_api("metagpt_oas3_api.yaml")
app.add_api("openapi.yaml")

View file

@ -23,7 +23,7 @@ async def post_greeting(name: str) -> str:
if __name__ == "__main__":
specification_dir = Path(__file__).parent.parent.parent / ".well-known"
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
app.run(port=8082)

View file

@ -407,6 +407,10 @@ def concat_namespace(*args) -> str:
return ":".join(str(value) for value in args)
def split_namespace(ns_class_name: str) -> List[str]:
return ns_class_name.split(":")
def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
"""
Generates a logging function to be used after a call is retried.
@ -546,3 +550,20 @@ async def read_file_block(filename: str | Path, lineno: int, end_lineno: int):
break
lines.append(line)
return "".join(lines)
def list_files(root: str | Path) -> List[Path]:
files = []
try:
directory_path = Path(root)
if not directory_path.exists():
return []
for file_path in directory_path.iterdir():
if file_path.is_file():
files.append(file_path)
else:
subfolder_files = list_files(root=file_path)
files.extend(subfolder_files)
except Exception as e:
logger.error(f"Error: {e}")
return files

View file

@ -12,9 +12,9 @@ import json
from pathlib import Path
from typing import List
import aiofiles
import networkx
from metagpt.utils.common import aread, awrite
from metagpt.utils.graph_repository import SPO, GraphRepository
@ -55,12 +55,10 @@ class DiGraphRepository(GraphRepository):
if not path.exists():
path.mkdir(parents=True, exist_ok=True)
pathname = Path(path) / self.name
async with aiofiles.open(str(pathname.with_suffix(".json")), mode="w", encoding="utf-8") as writer:
await writer.write(data)
await awrite(filename=pathname.with_suffix(".json"), data=data, encoding="utf-8")
async def load(self, pathname: str | Path):
async with aiofiles.open(str(pathname), mode="r", encoding="utf-8") as reader:
data = await reader.read(-1)
data = await aread(filename=pathname, encoding="utf-8")
m = json.loads(data)
self._repo = networkx.node_link_graph(m)

View file

@ -55,6 +55,7 @@ class FileRepository:
"""
pathname = self.workdir / filename
pathname.parent.mkdir(parents=True, exist_ok=True)
content = content if content else "" # avoid `argument must be str, not None` to make it continue
async with aiofiles.open(str(pathname), mode="w") as writer:
await writer.write(content)
logger.info(f"save to: {str(pathname)}")
@ -138,6 +139,8 @@ class FileRepository:
files = self._git_repo.changed_files
relative_files = {}
for p, ct in files.items():
if ct.value == "D": # deleted
continue
try:
rf = Path(p).relative_to(self._relative_path)
except ValueError:

View file

@ -13,19 +13,25 @@ from typing import List
from pydantic import BaseModel
from metagpt.repo_parser import ClassInfo, RepoFileInfo
from metagpt.logs import logger
from metagpt.repo_parser import ClassInfo, ClassRelationship, RepoFileInfo
from metagpt.utils.common import concat_namespace
class GraphKeyword:
IS = "is"
OF = "Of"
ON = "On"
CLASS = "class"
FUNCTION = "function"
HAS_FUNCTION = "has_function"
SOURCE_CODE = "source_code"
NULL = "<null>"
GLOBAL_VARIABLE = "global_variable"
CLASS_FUNCTION = "class_function"
CLASS_PROPERTY = "class_property"
HAS_CLASS_FUNCTION = "has_class_function"
HAS_CLASS_PROPERTY = "has_class_property"
HAS_CLASS = "has_class"
HAS_PAGE_INFO = "has_page_info"
HAS_CLASS_VIEW = "has_class_view"
@ -73,11 +79,13 @@ class GraphRepository(ABC):
await graph_db.insert(subject=file_info.file, predicate=GraphKeyword.IS, object_=file_type)
for c in file_info.classes:
class_name = c.get("name", "")
# file -> class
await graph_db.insert(
subject=file_info.file,
predicate=GraphKeyword.HAS_CLASS,
object_=concat_namespace(file_info.file, class_name),
)
# class detail
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.IS,
@ -85,12 +93,22 @@ class GraphRepository(ABC):
)
methods = c.get("methods", [])
for fn in methods:
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name),
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(file_info.file, class_name, fn),
)
await graph_db.insert(
subject=concat_namespace(file_info.file, class_name, fn),
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS_FUNCTION,
)
for f in file_info.functions:
# file -> function
await graph_db.insert(
subject=file_info.file, predicate=GraphKeyword.HAS_FUNCTION, object_=concat_namespace(file_info.file, f)
)
# function detail
await graph_db.insert(
subject=concat_namespace(file_info.file, f), predicate=GraphKeyword.IS, object_=GraphKeyword.FUNCTION
)
@ -105,30 +123,37 @@ class GraphRepository(ABC):
await graph_db.insert(
subject=concat_namespace(file_info.file, *code_block.tokens),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
object_=code_block.model_dump_json(),
)
for k, v in code_block.properties.items():
await graph_db.insert(
subject=concat_namespace(file_info.file, k, v),
predicate=GraphKeyword.HAS_PAGE_INFO,
object_=code_block.json(ensure_ascii=False),
object_=code_block.model_dump_json(),
)
@staticmethod
async def update_graph_db_with_class_views(graph_db: "GraphRepository", class_views: List[ClassInfo]):
for c in class_views:
filename, class_name = c.package.split(":", 1)
filename, _ = c.package.split(":", 1)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=GraphKeyword.SOURCE_CODE)
file_types = {".py": "python", ".js": "javascript"}
file_type = file_types.get(Path(filename).suffix, GraphKeyword.NULL)
await graph_db.insert(subject=filename, predicate=GraphKeyword.IS, object_=file_type)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=class_name)
await graph_db.insert(subject=filename, predicate=GraphKeyword.HAS_CLASS, object_=c.package)
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.IS,
object_=GraphKeyword.CLASS,
)
for vn, vt in c.attributes.items():
# class -> property
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_PROPERTY,
object_=concat_namespace(c.package, vn),
)
# property detail
await graph_db.insert(
subject=concat_namespace(c.package, vn),
predicate=GraphKeyword.IS,
@ -138,6 +163,15 @@ class GraphRepository(ABC):
subject=concat_namespace(c.package, vn), predicate=GraphKeyword.HAS_TYPE_DESC, object_=vt
)
for fn, desc in c.methods.items():
if "</I>" in desc and "<I>" not in desc:
logger.error(desc)
# class -> function
await graph_db.insert(
subject=c.package,
predicate=GraphKeyword.HAS_CLASS_FUNCTION,
object_=concat_namespace(c.package, fn),
)
# function detail
await graph_db.insert(
subject=concat_namespace(c.package, fn),
predicate=GraphKeyword.IS,
@ -148,3 +182,19 @@ class GraphRepository(ABC):
predicate=GraphKeyword.HAS_ARGS_DESC,
object_=desc,
)
@staticmethod
async def update_graph_db_with_class_relationship_views(
graph_db: "GraphRepository", relationship_views: List[ClassRelationship]
):
for r in relationship_views:
await graph_db.insert(
subject=r.src, predicate=GraphKeyword.IS + r.relationship + GraphKeyword.OF, object_=r.dest
)
if not r.label:
continue
await graph_db.insert(
subject=r.src,
predicate=GraphKeyword.IS + r.relationship + GraphKeyword.ON,
object_=concat_namespace(r.dest, r.label),
)

View file

@ -120,6 +120,15 @@ def repair_json_format(output: str) -> str:
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
# remove `#` in output json str, usually appeared in `glm-4`
arr = output.split("\n")
new_arr = []
for line in arr:
idx = line.find("#")
if idx >= 0:
line = line[:idx]
new_arr.append(line)
output = "\n".join(new_arr)
return output
@ -168,15 +177,17 @@ def repair_invalid_json(output: str, error: str) -> str:
example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)
"""
pattern = r"line ([0-9]+)"
pattern = r"line ([0-9]+) column ([0-9]+)"
matches = re.findall(pattern, error, re.DOTALL)
if len(matches) > 0:
line_no = int(matches[0]) - 1
line_no = int(matches[0][0]) - 1
col_no = int(matches[0][1]) - 1
# due to CustomDecoder can handle `"": ''` or `'': ""`, so convert `"""` -> `"`, `'''` -> `'`
output = output.replace('"""', '"').replace("'''", '"')
arr = output.split("\n")
rline = arr[line_no] # raw line
line = arr[line_no].strip()
# different general problems
if line.endswith("],"):
@ -187,9 +198,12 @@ def repair_invalid_json(output: str, error: str) -> str:
new_line = line.replace("}", "")
elif line.endswith("},") and output.endswith("},"):
new_line = line[:-1]
elif '",' not in line and "," not in line:
elif (rline[col_no] in ["'", '"']) and (line.startswith('"') or line.startswith("'")) and "," not in line:
# problem, `"""` or `'''` without `,`
new_line = f",{line}"
elif '",' not in line and "," not in line and '"' not in line:
new_line = f'{line}",'
elif "," not in line:
elif not line.endswith(","):
# problem, miss char `,` at the end.
new_line = f"{line},"
elif "," in line and len(line) == 1:

View file

@ -27,7 +27,8 @@ TOKEN_COSTS = {
"gpt-4-0613": {"prompt": 0.06, "completion": 0.12},
"gpt-4-1106-preview": {"prompt": 0.01, "completion": 0.03},
"text-embedding-ada-002": {"prompt": 0.0004, "completion": 0.0},
"chatglm_turbo": {"prompt": 0.0, "completion": 0.00069}, # 32k version, prompt + completion tokens=0.005¥/k-tokens
"glm-3-turbo": {"prompt": 0.0, "completion": 0.0007}, # 128k version, prompt + completion tokens=0.005¥/k-tokens
"glm-4": {"prompt": 0.0, "completion": 0.014}, # 128k version, prompt + completion tokens=0.1¥/k-tokens
"gemini-pro": {"prompt": 0.00025, "completion": 0.0005},
}