mirror of
https://github.com/katanemo/plano.git
synced 2026-05-15 11:02:39 +02:00
lint + formating with black (#158)
* lint + formating with black * add black as pre commit
This commit is contained in:
parent
498e7f9724
commit
5c4a6bc8ff
22 changed files with 581 additions and 295 deletions
|
|
@ -18,14 +18,20 @@ def detect_framework(tree: Any) -> str:
|
|||
return "fastapi"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def get_route_decorators(node: Any, framework: str) -> list:
|
||||
"""Extract route decorators based on the framework."""
|
||||
decorators = []
|
||||
for decorator in node.decorator_list:
|
||||
if isinstance(decorator, ast.Call) and isinstance(decorator.func, ast.Attribute):
|
||||
if isinstance(decorator, ast.Call) and isinstance(
|
||||
decorator.func, ast.Attribute
|
||||
):
|
||||
if framework == "flask" and decorator.func.attr in FLASK_ROUTE_DECORATORS:
|
||||
decorators.append(decorator.func.attr)
|
||||
elif framework == "fastapi" and decorator.func.attr in FASTAPI_ROUTE_DECORATORS:
|
||||
elif (
|
||||
framework == "fastapi"
|
||||
and decorator.func.attr in FASTAPI_ROUTE_DECORATORS
|
||||
):
|
||||
decorators.append(decorator.func.attr)
|
||||
return decorators
|
||||
|
||||
|
|
@ -36,6 +42,7 @@ def get_route_path(node: Any, framework: str) -> str:
|
|||
if isinstance(decorator, ast.Call) and decorator.args:
|
||||
return decorator.args[0].s # Assuming it's a string literal
|
||||
|
||||
|
||||
def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
|
||||
"""Check if a given type annotation is a Pydantic model."""
|
||||
# We walk through the AST to find class definitions and check if they inherit from Pydantic's BaseModel
|
||||
|
|
@ -47,6 +54,7 @@ def is_pydantic_model(annotation: ast.expr, tree: ast.AST) -> bool:
|
|||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
||||
"""Extract fields from a Pydantic model, handling list, tuple, set, dict types, and direct default values."""
|
||||
fields = []
|
||||
|
|
@ -62,15 +70,26 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
|||
required = True # Assume the field is required initially
|
||||
|
||||
# Check if the field uses Field() with required status and description
|
||||
if stmt.value and isinstance(stmt.value, ast.Call) and isinstance(stmt.value.func, ast.Name) and stmt.value.func.id == 'Field':
|
||||
if (
|
||||
stmt.value
|
||||
and isinstance(stmt.value, ast.Call)
|
||||
and isinstance(stmt.value.func, ast.Name)
|
||||
and stmt.value.func.id == "Field"
|
||||
):
|
||||
# Extract the description argument inside the Field call
|
||||
for keyword in stmt.value.keywords:
|
||||
if keyword.arg == 'description' and isinstance(keyword.value, ast.Str):
|
||||
if keyword.arg == "description" and isinstance(
|
||||
keyword.value, ast.Str
|
||||
):
|
||||
description = keyword.value.s
|
||||
if keyword.arg == 'default':
|
||||
if keyword.arg == "default":
|
||||
default_value = keyword.value
|
||||
# If Ellipsis (...) is used, it means the field is required
|
||||
if stmt.value.args and isinstance(stmt.value.args[0], ast.Constant) and stmt.value.args[0].value is Ellipsis:
|
||||
if (
|
||||
stmt.value.args
|
||||
and isinstance(stmt.value.args[0], ast.Constant)
|
||||
and stmt.value.args[0].value is Ellipsis
|
||||
):
|
||||
required = True
|
||||
else:
|
||||
required = False
|
||||
|
|
@ -80,19 +99,28 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
|||
if isinstance(stmt.value, ast.Constant):
|
||||
# Set the default value from the assignment (e.g., name: str = "John Doe")
|
||||
default_value = stmt.value.value
|
||||
required = False # Not required since it has a default value
|
||||
required = (
|
||||
False # Not required since it has a default value
|
||||
)
|
||||
|
||||
# Always extract the field type, even if there's a default value
|
||||
if isinstance(stmt.annotation, ast.Subscript):
|
||||
# Get the base type (list, tuple, set, dict)
|
||||
base_type = stmt.annotation.value.id if isinstance(stmt.annotation.value, ast.Name) else "Unknown"
|
||||
base_type = (
|
||||
stmt.annotation.value.id
|
||||
if isinstance(stmt.annotation.value, ast.Name)
|
||||
else "Unknown"
|
||||
)
|
||||
|
||||
# Handle only list, tuple, set, dict and ignore the inner types
|
||||
if base_type.lower() in ['list', 'tuple', 'set', 'dict']:
|
||||
if base_type.lower() in ["list", "tuple", "set", "dict"]:
|
||||
field_type = base_type.lower()
|
||||
|
||||
# Handle the ellipsis '...' for required fields if no Field() call
|
||||
elif isinstance(stmt.value, ast.Constant) and stmt.value.value is Ellipsis:
|
||||
elif (
|
||||
isinstance(stmt.value, ast.Constant)
|
||||
and stmt.value.value is Ellipsis
|
||||
):
|
||||
required = True
|
||||
|
||||
# Handle simple types like str, int, etc.
|
||||
|
|
@ -100,16 +128,17 @@ def get_pydantic_model_fields(model_name: str, tree: ast.AST) -> list:
|
|||
field_type = stmt.annotation.id
|
||||
|
||||
field_info = {
|
||||
"name": stmt.target.id,
|
||||
"type": field_type, # Always set the field type
|
||||
"description": description,
|
||||
"default": default_value, # Handle direct default values
|
||||
"required": required
|
||||
"name": stmt.target.id,
|
||||
"type": field_type, # Always set the field type
|
||||
"description": description,
|
||||
"default": default_value, # Handle direct default values
|
||||
"required": required,
|
||||
}
|
||||
fields.append(field_info)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
||||
"""Extract the parameters and their types from the function definition."""
|
||||
parameters = []
|
||||
|
|
@ -119,40 +148,68 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
|||
arg_descriptions = extract_arg_descriptions_from_docstring(docstring)
|
||||
|
||||
# Extract default values
|
||||
defaults = [None] * (len(node.args.args) - len(node.args.defaults)) + node.args.defaults # Align defaults with args
|
||||
defaults = [None] * (
|
||||
len(node.args.args) - len(node.args.defaults)
|
||||
) + node.args.defaults # Align defaults with args
|
||||
for arg, default in zip(node.args.args, defaults):
|
||||
if arg.arg != "self": # Skip 'self' or 'cls' in class methods
|
||||
param_info = {"name": arg.arg, "description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]")}
|
||||
param_info = {
|
||||
"name": arg.arg,
|
||||
"description": arg_descriptions.get(arg.arg, "[ADD DESCRIPTION]"),
|
||||
}
|
||||
|
||||
# Handle Pydantic model types
|
||||
if hasattr(arg, 'annotation') and is_pydantic_model(arg.annotation, tree):
|
||||
if hasattr(arg, "annotation") and is_pydantic_model(arg.annotation, tree):
|
||||
# Extract and flatten Pydantic model fields
|
||||
pydantic_fields = get_pydantic_model_fields(arg.annotation.id, tree)
|
||||
parameters.extend(pydantic_fields) # Flatten the model fields into the parameters list
|
||||
parameters.extend(
|
||||
pydantic_fields
|
||||
) # Flatten the model fields into the parameters list
|
||||
continue # Skip adding the current param_info for the model since we expand the fields
|
||||
|
||||
# Handle standard Python types (int, float, str, etc.)
|
||||
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Name):
|
||||
if arg.annotation.id in ['int', 'float', 'bool', 'str', 'list', 'tuple', 'set', 'dict']:
|
||||
elif hasattr(arg, "annotation") and isinstance(arg.annotation, ast.Name):
|
||||
if arg.annotation.id in [
|
||||
"int",
|
||||
"float",
|
||||
"bool",
|
||||
"str",
|
||||
"list",
|
||||
"tuple",
|
||||
"set",
|
||||
"dict",
|
||||
]:
|
||||
param_info["type"] = arg.annotation.id
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
|
||||
|
||||
# Handle generic subscript types (e.g., Optional, List[Type], etc.)
|
||||
elif hasattr(arg, 'annotation') and isinstance(arg.annotation, ast.Subscript):
|
||||
if isinstance(arg.annotation.value, ast.Name) and arg.annotation.value.id in ['list', 'tuple', 'set', 'dict']:
|
||||
param_info["type"] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
|
||||
elif hasattr(arg, "annotation") and isinstance(
|
||||
arg.annotation, ast.Subscript
|
||||
):
|
||||
if isinstance(
|
||||
arg.annotation.value, ast.Name
|
||||
) and arg.annotation.value.id in ["list", "tuple", "set", "dict"]:
|
||||
param_info[
|
||||
"type"
|
||||
] = f"{arg.annotation.value.id}" # e.g., "List", "Tuple", etc.
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]"
|
||||
|
||||
# Default for unknown types
|
||||
else:
|
||||
param_info["type"] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
|
||||
param_info[
|
||||
"type"
|
||||
] = "[UNKNOWN - PLEASE FIX]" # If unable to detect type
|
||||
|
||||
# Handle default values
|
||||
if default is not None:
|
||||
if isinstance(default, ast.Constant) or isinstance(default, ast.NameConstant):
|
||||
param_info["default"] = default.value # Use the default value directly
|
||||
if isinstance(default, ast.Constant) or isinstance(
|
||||
default, ast.NameConstant
|
||||
):
|
||||
param_info[
|
||||
"default"
|
||||
] = default.value # Use the default value directly
|
||||
else:
|
||||
param_info["default"] = "[UNKNOWN DEFAULT]" # Unknown default type
|
||||
param_info["required"] = False # Optional since it has a default value
|
||||
|
|
@ -164,6 +221,7 @@ def get_function_parameters(node: ast.FunctionDef, tree: ast.AST) -> list:
|
|||
|
||||
return parameters
|
||||
|
||||
|
||||
def get_function_docstring(node: Any) -> str:
|
||||
"""Extract the function's docstring description if present."""
|
||||
# Check if the first node is a docstring
|
||||
|
|
@ -178,6 +236,7 @@ def get_function_docstring(node: Any) -> str:
|
|||
|
||||
return "No description provided."
|
||||
|
||||
|
||||
def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
|
||||
"""Extract descriptions for function parameters from the 'Args' section of the docstring."""
|
||||
descriptions = {}
|
||||
|
|
@ -195,14 +254,14 @@ def extract_arg_descriptions_from_docstring(docstring: str) -> dict:
|
|||
continue # Proceed to the next line after 'Args:'
|
||||
|
||||
# End of 'Args' section if no indentation and no colon
|
||||
if in_args_section and not line.startswith(" ") and ':' not in line:
|
||||
if in_args_section and not line.startswith(" ") and ":" not in line:
|
||||
break # Stop processing if we reach a new section
|
||||
|
||||
# Process lines in the 'Args' section
|
||||
if in_args_section:
|
||||
if ':' in line:
|
||||
if ":" in line:
|
||||
# Extract parameter name and description
|
||||
param_name, description = line.split(':', 1)
|
||||
param_name, description = line.split(":", 1)
|
||||
descriptions[param_name.strip()] = description.strip()
|
||||
current_param = param_name.strip()
|
||||
elif current_param and line.startswith(" "):
|
||||
|
|
@ -230,43 +289,50 @@ def generate_prompt_targets(input_file_path: str) -> None:
|
|||
route_decorators = get_route_decorators(node, framework)
|
||||
if route_decorators:
|
||||
route_path = get_route_path(node, framework)
|
||||
function_params = get_function_parameters(node, tree) # Get parameters for the route
|
||||
function_params = get_function_parameters(
|
||||
node, tree
|
||||
) # Get parameters for the route
|
||||
function_docstring = get_function_docstring(node) # Extract docstring
|
||||
routes.append({
|
||||
'name': node.name,
|
||||
'path': route_path,
|
||||
'methods': route_decorators,
|
||||
'parameters': function_params, # Add parameters to the route
|
||||
'description': function_docstring # Add the docstring as the description
|
||||
})
|
||||
routes.append(
|
||||
{
|
||||
"name": node.name,
|
||||
"path": route_path,
|
||||
"methods": route_decorators,
|
||||
"parameters": function_params, # Add parameters to the route
|
||||
"description": function_docstring, # Add the docstring as the description
|
||||
}
|
||||
)
|
||||
|
||||
# Generate YAML structure
|
||||
output_structure = {
|
||||
"prompt_targets": []
|
||||
}
|
||||
output_structure = {"prompt_targets": []}
|
||||
|
||||
for route in routes:
|
||||
target = {
|
||||
"name": route['name'],
|
||||
"name": route["name"],
|
||||
"endpoint": [
|
||||
{
|
||||
"name": "app_server",
|
||||
"path": route['path'],
|
||||
"path": route["path"],
|
||||
}
|
||||
],
|
||||
"description": route['description'], # Use extracted docstring
|
||||
"description": route["description"], # Use extracted docstring
|
||||
"parameters": [
|
||||
{
|
||||
"name": param['name'],
|
||||
"type": param['type'],
|
||||
"name": param["name"],
|
||||
"type": param["type"],
|
||||
"description": f"{param['description']}",
|
||||
**({"default": param['default']} if "default" in param and param['default'] is not None else {}), # Only add default if it's set
|
||||
"required": param['required']
|
||||
} for param in route['parameters']
|
||||
]
|
||||
**(
|
||||
{"default": param["default"]}
|
||||
if "default" in param and param["default"] is not None
|
||||
else {}
|
||||
), # Only add default if it's set
|
||||
"required": param["required"],
|
||||
}
|
||||
for param in route["parameters"]
|
||||
],
|
||||
}
|
||||
|
||||
if route['name'] == "default":
|
||||
if route["name"] == "default":
|
||||
# Special case for `information_extraction` based on your YAML format
|
||||
target["type"] = "default"
|
||||
target["auto-llm-dispatch-on-response"] = True
|
||||
|
|
@ -274,9 +340,12 @@ def generate_prompt_targets(input_file_path: str) -> None:
|
|||
output_structure["prompt_targets"].append(target)
|
||||
|
||||
# Output as YAML
|
||||
print(yaml.dump(output_structure, sort_keys=False,default_flow_style=False, indent=3))
|
||||
print(
|
||||
yaml.dump(output_structure, sort_keys=False, default_flow_style=False, indent=3)
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) != 2:
|
||||
print("Usage: python targets.py <input_file>")
|
||||
sys.exit(1)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue