This commit is contained in:
didi 2024-07-25 10:47:17 +08:00
parent ca1c8f8c5c
commit 772d2aea56
9 changed files with 583 additions and 65 deletions

View file

@ -17,6 +17,7 @@ from pydantic import BaseModel, Field, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.actions.code_sanitize import sanitize
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.llm import BaseLLM
from metagpt.logs import logger
@ -484,6 +485,7 @@ class ActionNode:
async def code_fill(
self,
context,
function_name=None,
timeout=USE_CONFIG_TIMEOUT
):
"""
@ -510,10 +512,10 @@ class ActionNode:
import re
field_name = self.get_field_name()
prompt = context
# prompt += "\nPlease wrap the generated code within triple backticks, like this: ```<code>```"
content = await self.llm.aask(prompt, timeout=timeout)
extracted_code = extract_code_from_response(content)
# TODO 在前置逻辑中完成entrypoint的提取就可以
extracted_code = sanitize(code=content, entrypoint=function_name)
# extracted_code = extract_code_from_response(content)
result = {field_name: extracted_code}
return result
@ -536,6 +538,7 @@ class ActionNode:
images: Optional[Union[str, list[str]]] = None,
timeout=USE_CONFIG_TIMEOUT,
exclude=[],
function_name: str = None
):
"""Fill the node(s) with mode.
@ -563,7 +566,7 @@ class ActionNode:
schema = self.schema
if mode == self.MODE_CODE_FILL:
result = await self.code_fill(context, timeout)
result = await self.code_fill(context, function_name, timeout)
self.instruct_content = self.create_class()(**result)
return self

View file

@ -0,0 +1,167 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/7/24 16:37
@Author : didi
@File : code_node.py
"""
import os
import ast
import pathlib
import traceback
from typing import Dict, Generator, List, Optional, Set, Tuple
import tree_sitter_python
from tqdm import tqdm
from tree_sitter import Language, Node, Parser
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
IMPORT_TYPE = ["import_statement", "import_from_statement"]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"
def traverse_tree(node: Node) -> Generator[Node, None, None]:
cursor = node.walk()
depth = 0
visited_children = False
while True:
if not visited_children:
yield cursor.node
if not cursor.goto_first_child():
depth += 1
visited_children = True
elif cursor.goto_next_sibling():
visited_children = False
elif not cursor.goto_parent() or depth == 0:
break
else:
depth -= 1
def syntax_check(code, verbose=False):
try:
ast.parse(code)
return True
except (SyntaxError, MemoryError):
if verbose:
traceback.print_exc()
return False
def code_extract(text: str) -> str:
lines = text.split("\n")
longest_line_pair = (0, 0)
longest_so_far = 0
for i in range(len(lines)):
for j in range(i + 1, len(lines)):
current_lines = "\n".join(lines[i : j + 1])
if syntax_check(current_lines):
current_length = sum(1 for line in lines[i : j + 1] if line.strip())
if current_length > longest_so_far:
longest_so_far = current_length
longest_line_pair = (i, j)
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
def get_definition_name(node: Node) -> str:
for child in node.children:
if child.type == IDENTIFIER_TYPE:
return child.text.decode("utf8")
def has_return_statement(node: Node) -> bool:
traverse_nodes = traverse_tree(node)
for node in traverse_nodes:
if node.type == RETURN_TYPE:
return True
return False
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
def dfs_get_deps(node: Node, deps: Set[str]) -> None:
for child in node.children:
if child.type == IDENTIFIER_TYPE:
deps.add(child.text.decode("utf8"))
else:
dfs_get_deps(child, deps)
name2deps = {}
for name, node in nodes:
deps = set()
dfs_get_deps(node, deps)
name2deps[name] = deps
return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
queue = [entrypoint]
visited = {entrypoint}
while queue:
current = queue.pop(0)
if current not in call_graph:
continue
for neighbour in call_graph[current]:
if not (neighbour in visited):
visited.add(neighbour)
queue.append(neighbour)
return visited
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
code = code_extract(code)
code_bytes = bytes(code, "utf8")
parser = Parser(Language(tree_sitter_python.language()))
tree = parser.parse(code_bytes)
class_names = set()
function_names = set()
variable_names = set()
root_node = tree.root_node
import_nodes = []
definition_nodes = []
for child in root_node.children:
if child.type in IMPORT_TYPE:
import_nodes.append(child)
elif child.type == CLASS_TYPE:
name = get_definition_name(child)
if not (
name in class_names or name in variable_names or name in function_names
):
definition_nodes.append((name, child))
class_names.add(name)
elif child.type == FUNCTION_TYPE:
name = get_definition_name(child)
if not (
name in function_names or name in variable_names or name in class_names
) and has_return_statement(child):
definition_nodes.append((name, child))
function_names.add(get_definition_name(child))
elif (
child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE
):
subchild = child.children[0]
name = get_definition_name(subchild)
if not (
name in variable_names or name in function_names or name in class_names
):
definition_nodes.append((name, subchild))
variable_names.add(name)
if entrypoint:
name2deps = get_deps(definition_nodes)
reacheable = get_function_dependency(entrypoint, name2deps)
sanitized_output = b""
for node in import_nodes:
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
for pair in definition_nodes:
name, node = pair
if entrypoint and not (name in reacheable):
continue
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
return sanitized_output[:-1].decode("utf8")