Update Code Review

This commit is contained in:
didi 2024-08-01 14:20:40 +08:00
parent 686b1cd130
commit 6d4c72cdf0
11 changed files with 549 additions and 662 deletions

View file

@ -474,53 +474,26 @@ class ActionNode:
"""
model_class = self.create_class()
fields = model_class.model_fields
# Assuming there's only one field in the model
if len(fields) == 1:
return next(iter(fields))
# If there are multiple fields, we might want to use self.key to find the right one
return self.key
async def code_fill(self, context, function_name=None, timeout=USE_CONFIG_TIMEOUT):
    """Fill a CodeBlock node with code produced by the LLM.

    The context is sent to the LLM unchanged as the prompt, and the raw reply is passed
    through ``sanitize`` before being stored under this node's field name.

    :param context: Prompt text handed as-is to ``self.llm.aask``.
    :param function_name: Optional entrypoint name forwarded to ``sanitize``.
    :param timeout: Timeout forwarded to ``self.llm.aask``.
    :return: A single-entry dict mapping the node's field name to the sanitized code.
    """
    # Resolve the field name first, matching the original call order.
    field_name = self.get_field_name()
    content = await self.llm.aask(context, timeout=timeout)
    extracted_code = sanitize(code=content, entrypoint=function_name)
    return {field_name: extracted_code}
async def messages_fill(
self,
):
@ -540,7 +513,7 @@ class ActionNode:
images: Optional[Union[str, list[str]]] = None,
timeout=USE_CONFIG_TIMEOUT,
exclude=[],
function_name: str = None
function_name: str = None,
):
"""Fill the node(s) with mode.

View file

@ -4,28 +4,35 @@
@Time : 2024/7/24 16:37
@Author : didi
@File : code_node.py
@Acknowledgement : https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py
"""
import os
import ast
import pathlib
import traceback
from enum import Enum
from typing import Dict, Generator, List, Optional, Set, Tuple
import tree_sitter_python
from tqdm import tqdm
from tree_sitter import Language, Node, Parser
# Node type names compared against ``Node.type`` of tree-sitter nodes in this module.
# NOTE(review): NodeType defines the same values; these constants look superseded —
# confirm nothing still references them before removing.
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
# Both import statement forms; membership is tested with ``in`` rather than ``==``.
IMPORT_TYPE = ["import_statement", "import_from_statement"]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"
class NodeType(Enum):
    """Node type names that this module compares against ``Node.type``.

    Every member's ``value`` is a single type name, except ``IMPORT``, whose value groups
    both import statement forms and must be tested with ``in`` rather than ``==``.
    """

    CLASS = "class_definition"
    FUNCTION = "function_definition"
    # A tuple rather than a list, so the shared member value cannot be mutated in place.
    IMPORT = ("import_statement", "import_from_statement")
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
def traverse_tree(node: Node) -> Generator[Node, None, None]:
"""
Traverse the tree structure starting from the given node.
:param node: The root node to start the traversal from.
:return: A generator object that yields nodes in the tree.
"""
cursor = node.walk()
depth = 0
@ -43,6 +50,7 @@ def traverse_tree(node: Node) -> Generator[Node, None, None]:
else:
depth -= 1
def syntax_check(code, verbose=False):
try:
ast.parse(code)
@ -52,6 +60,7 @@ def syntax_check(code, verbose=False):
traceback.print_exc()
return False
def code_extract(text: str) -> str:
lines = text.split("\n")
longest_line_pair = (0, 0)
@ -68,22 +77,25 @@ def code_extract(text: str) -> str:
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
def get_definition_name(node: Node) -> Optional[str]:
    """Return the text of the first direct child of ``node`` that is an identifier.

    Only direct children are inspected; nested nodes are not searched.

    :param node: The node whose children are scanned.
    :return: The identifier text decoded as UTF-8, or None when no direct child is an
        identifier.
    """
    for child in node.children:
        if child.type == NodeType.IDENTIFIER.value:
            return child.text.decode("utf8")
    # Made explicit: the function previously fell off the end and returned None implicitly.
    return None
def has_return_statement(node: Node) -> bool:
    """Report whether the tree rooted at ``node`` contains a return statement.

    :param node: The node from which ``traverse_tree`` starts walking.
    :return: True if any node yielded by ``traverse_tree(node)`` has the return-statement
        type, otherwise False.
    """
    # A distinct loop name avoids shadowing the ``node`` parameter.
    return any(visited.type == NodeType.RETURN.value for visited in traverse_tree(node))
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
def dfs_get_deps(node: Node, deps: Set[str]) -> None:
for child in node.children:
if child.type == IDENTIFIER_TYPE:
if child.type == NodeType.IDENTIFIER.value:
deps.add(child.text.decode("utf8"))
else:
dfs_get_deps(child, deps)
@ -104,12 +116,23 @@ def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[
if current not in call_graph:
continue
for neighbour in call_graph[current]:
if not (neighbour in visited):
if neighbour not in visited:
visited.add(neighbour)
queue.append(neighbour)
return visited
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
"""
Sanitize and extract relevant parts of the given Python code.
This function parses the input code, extracts import statements, class and function definitions,
and variable assignments. If an entrypoint is provided, it only includes definitions that are
reachable from the entrypoint in the call graph.
:param code: The input Python code as a string.
:param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
:return: A sanitized version of the input code, containing only relevant parts.
"""
code = code_extract(code)
code_bytes = bytes(code, "utf8")
parser = Parser(Language(tree_sitter_python.language()))
@ -123,30 +146,24 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
definition_nodes = []
for child in root_node.children:
if child.type in IMPORT_TYPE:
if child.type in NodeType.IMPORT.value:
import_nodes.append(child)
elif child.type == CLASS_TYPE:
elif child.type == NodeType.CLASS.value:
name = get_definition_name(child)
if not (
name in class_names or name in variable_names or name in function_names
):
if not (name in class_names or name in variable_names or name in function_names):
definition_nodes.append((name, child))
class_names.add(name)
elif child.type == FUNCTION_TYPE:
elif child.type == NodeType.FUNCTION.value:
name = get_definition_name(child)
if not (
name in function_names or name in variable_names or name in class_names
) and has_return_statement(child):
if not (name in function_names or name in variable_names or name in class_names) and has_return_statement(
child
):
definition_nodes.append((name, child))
function_names.add(get_definition_name(child))
elif (
child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE
):
elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value:
subchild = child.children[0]
name = get_definition_name(subchild)
if not (
name in variable_names or name in function_names or name in class_names
):
if not (name in variable_names or name in function_names or name in class_names):
definition_nodes.append((name, subchild))
variable_names.add(name)
@ -161,7 +178,7 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
for pair in definition_nodes:
name, node = pair
if entrypoint and not (name in reacheable):
if entrypoint and name not in reacheable:
continue
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
return sanitized_output[:-1].decode("utf8")