mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 12:22:39 +02:00
Update Code Review
This commit is contained in:
parent
686b1cd130
commit
6d4c72cdf0
11 changed files with 549 additions and 662 deletions
|
|
@ -474,53 +474,26 @@ class ActionNode:
|
|||
"""
|
||||
model_class = self.create_class()
|
||||
fields = model_class.model_fields
|
||||
|
||||
|
||||
# Assuming there's only one field in the model
|
||||
if len(fields) == 1:
|
||||
return next(iter(fields))
|
||||
|
||||
|
||||
# If there are multiple fields, we might want to use self.key to find the right one
|
||||
return self.key
|
||||
|
||||
async def code_fill(
|
||||
self,
|
||||
context,
|
||||
function_name=None,
|
||||
timeout=USE_CONFIG_TIMEOUT
|
||||
):
|
||||
|
||||
async def code_fill(self, context, function_name=None, timeout=USE_CONFIG_TIMEOUT):
|
||||
"""
|
||||
fill CodeBlock Node
|
||||
"""
|
||||
|
||||
def extract_code_from_response(response):
|
||||
"""
|
||||
Extracts code wrapped in triple backticks from the response,
|
||||
removing any language specifier.
|
||||
|
||||
:param response: The full response from the LLM
|
||||
:return: The extracted code, or None if no code is found
|
||||
"""
|
||||
code_pattern = r"```(?:\w+\n)?([\s\S]*?)```"
|
||||
matches = re.findall(code_pattern, response)
|
||||
|
||||
if matches:
|
||||
# The first group in the regex contains the code without the language specifier
|
||||
code = matches[0].strip()
|
||||
return code
|
||||
return None
|
||||
|
||||
import re
|
||||
field_name = self.get_field_name()
|
||||
prompt = context
|
||||
# print("generate prompt", "\n", prompt)
|
||||
content = await self.llm.aask(prompt, timeout=timeout)
|
||||
# print("generate content", "\n", content)
|
||||
extracted_code = sanitize(code=content, entrypoint=function_name)
|
||||
# extracted_code = extract_code_from_response(content)
|
||||
result = {field_name: extracted_code}
|
||||
# print("final_result", "\n", result)
|
||||
return result
|
||||
|
||||
|
||||
async def messages_fill(
|
||||
self,
|
||||
):
|
||||
|
|
@ -540,7 +513,7 @@ class ActionNode:
|
|||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
exclude=[],
|
||||
function_name: str = None
|
||||
function_name: str = None,
|
||||
):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
|
|
|
|||
|
|
@ -4,28 +4,35 @@
|
|||
@Time : 2024/7/24 16:37
|
||||
@Author : didi
|
||||
@File : code_node.py
|
||||
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
|
||||
"""
|
||||
import os
|
||||
import ast
|
||||
import pathlib
|
||||
import traceback
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Generator, List, Optional, Set, Tuple
|
||||
|
||||
import tree_sitter_python
|
||||
from tqdm import tqdm
|
||||
from tree_sitter import Language, Node, Parser
|
||||
|
||||
CLASS_TYPE = "class_definition"
|
||||
FUNCTION_TYPE = "function_definition"
|
||||
IMPORT_TYPE = ["import_statement", "import_from_statement"]
|
||||
IDENTIFIER_TYPE = "identifier"
|
||||
ATTRIBUTE_TYPE = "attribute"
|
||||
RETURN_TYPE = "return_statement"
|
||||
EXPRESSION_TYPE = "expression_statement"
|
||||
ASSIGNMENT_TYPE = "assignment"
|
||||
|
||||
class NodeType(Enum):
|
||||
CLASS = "class_definition"
|
||||
FUNCTION = "function_definition"
|
||||
IMPORT = ["import_statement", "import_from_statement"]
|
||||
IDENTIFIER = "identifier"
|
||||
ATTRIBUTE = "attribute"
|
||||
RETURN = "return_statement"
|
||||
EXPRESSION = "expression_statement"
|
||||
ASSIGNMENT = "assignment"
|
||||
|
||||
|
||||
def traverse_tree(node: Node) -> Generator[Node, None, None]:
|
||||
"""
|
||||
Traverse the tree structure starting from the given node.
|
||||
|
||||
:param node: The root node to start the traversal from.
|
||||
:return: A generator object that yields nodes in the tree.
|
||||
"""
|
||||
cursor = node.walk()
|
||||
depth = 0
|
||||
|
||||
|
|
@ -43,6 +50,7 @@ def traverse_tree(node: Node) -> Generator[Node, None, None]:
|
|||
else:
|
||||
depth -= 1
|
||||
|
||||
|
||||
def syntax_check(code, verbose=False):
|
||||
try:
|
||||
ast.parse(code)
|
||||
|
|
@ -52,6 +60,7 @@ def syntax_check(code, verbose=False):
|
|||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def code_extract(text: str) -> str:
|
||||
lines = text.split("\n")
|
||||
longest_line_pair = (0, 0)
|
||||
|
|
@ -68,22 +77,25 @@ def code_extract(text: str) -> str:
|
|||
|
||||
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
|
||||
|
||||
|
||||
def get_definition_name(node: Node) -> str:
|
||||
for child in node.children:
|
||||
if child.type == IDENTIFIER_TYPE:
|
||||
if child.type == NodeType.IDENTIFIER.value:
|
||||
return child.text.decode("utf8")
|
||||
|
||||
|
||||
|
||||
def has_return_statement(node: Node) -> bool:
|
||||
traverse_nodes = traverse_tree(node)
|
||||
for node in traverse_nodes:
|
||||
if node.type == RETURN_TYPE:
|
||||
if node.type == NodeType.RETURN.value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
|
||||
def dfs_get_deps(node: Node, deps: Set[str]) -> None:
|
||||
for child in node.children:
|
||||
if child.type == IDENTIFIER_TYPE:
|
||||
if child.type == NodeType.IDENTIFIER.value:
|
||||
deps.add(child.text.decode("utf8"))
|
||||
else:
|
||||
dfs_get_deps(child, deps)
|
||||
|
|
@ -104,12 +116,23 @@ def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[
|
|||
if current not in call_graph:
|
||||
continue
|
||||
for neighbour in call_graph[current]:
|
||||
if not (neighbour in visited):
|
||||
if neighbour not in visited:
|
||||
visited.add(neighbour)
|
||||
queue.append(neighbour)
|
||||
return visited
|
||||
|
||||
|
||||
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
|
||||
"""
|
||||
Sanitize and extract relevant parts of the given Python code.
|
||||
This function parses the input code, extracts import statements, class and function definitions,
|
||||
and variable assignments. If an entrypoint is provided, it only includes definitions that are
|
||||
reachable from the entrypoint in the call graph.
|
||||
|
||||
:param code: The input Python code as a string.
|
||||
:param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
|
||||
:return: A sanitized version of the input code, containing only relevant parts.
|
||||
"""
|
||||
code = code_extract(code)
|
||||
code_bytes = bytes(code, "utf8")
|
||||
parser = Parser(Language(tree_sitter_python.language()))
|
||||
|
|
@ -123,30 +146,24 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
|
|||
definition_nodes = []
|
||||
|
||||
for child in root_node.children:
|
||||
if child.type in IMPORT_TYPE:
|
||||
if child.type in NodeType.IMPORT.value:
|
||||
import_nodes.append(child)
|
||||
elif child.type == CLASS_TYPE:
|
||||
elif child.type == NodeType.CLASS.value:
|
||||
name = get_definition_name(child)
|
||||
if not (
|
||||
name in class_names or name in variable_names or name in function_names
|
||||
):
|
||||
if not (name in class_names or name in variable_names or name in function_names):
|
||||
definition_nodes.append((name, child))
|
||||
class_names.add(name)
|
||||
elif child.type == FUNCTION_TYPE:
|
||||
elif child.type == NodeType.FUNCTION.value:
|
||||
name = get_definition_name(child)
|
||||
if not (
|
||||
name in function_names or name in variable_names or name in class_names
|
||||
) and has_return_statement(child):
|
||||
if not (name in function_names or name in variable_names or name in class_names) and has_return_statement(
|
||||
child
|
||||
):
|
||||
definition_nodes.append((name, child))
|
||||
function_names.add(get_definition_name(child))
|
||||
elif (
|
||||
child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE
|
||||
):
|
||||
elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
|
||||
subchild = child.children[0]
|
||||
name = get_definition_name(subchild)
|
||||
if not (
|
||||
name in variable_names or name in function_names or name in class_names
|
||||
):
|
||||
if not (name in variable_names or name in function_names or name in class_names):
|
||||
definition_nodes.append((name, subchild))
|
||||
variable_names.add(name)
|
||||
|
||||
|
|
@ -161,7 +178,7 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
|
|||
|
||||
for pair in definition_nodes:
|
||||
name, node = pair
|
||||
if entrypoint and not (name in reacheable):
|
||||
if entrypoint and name not in reacheable:
|
||||
continue
|
||||
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
|
||||
return sanitized_output[:-1].decode("utf8")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue