# cuGenOpt/python/cugenopt/validation.py

"""
Input validation and friendly error translation for cuGenOpt.
Two responsibilities:
1. Validate numpy arrays before JIT compilation (dtype, shape, NaN/Inf, contiguity)
2. Translate nvcc compilation errors into actionable Python messages
"""
import re
from typing import Dict, Optional, Sequence
import numpy as np
class CuGenOptValidationError(ValueError):
    """Signal that user-supplied input was rejected by cuGenOpt validation.

    Derives from ``ValueError`` so callers may catch it either specifically
    or as an ordinary value error.
    """
class CuGenOptCompileError(RuntimeError):
    """Raised when nvcc compilation fails, with a friendly summary.

    Attributes:
        raw_stderr: The nvcc stderr text exactly as passed in.
        source_path: Path reported to the user as where the source was saved.
        friendly: Human-readable summary built by ``_translate_nvcc_error``.
    """

    def __init__(self, raw_stderr: str, source_path: str):
        self.raw_stderr = raw_stderr
        self.source_path = source_path
        self.friendly = _translate_nvcc_error(raw_stderr)
        super().__init__(
            f"{self.friendly}\n\n"
            f"[raw nvcc output]\n{_truncate(raw_stderr, 1200)}\n\n"
            f"Source saved at: {source_path}"
        )

    def __reduce__(self):
        # BaseException pickles as cls(*self.args), but self.args holds only
        # the formatted message, so unpickling would call __init__ with a
        # single argument and fail. Rebuild from the two constructor inputs so
        # the error can cross process boundaries (e.g. multiprocessing).
        return (type(self), (self.raw_stderr, self.source_path))
# ============================================================
# Array validation
# ============================================================
def validate_array(
    arr: np.ndarray,
    name: str,
    *,
    expected_dtype: Optional[np.dtype] = None,
    expected_ndim: Optional[int] = None,
    expected_shape: Optional[tuple] = None,
    min_size: int = 1,
    allow_nan: bool = False,
    allow_inf: bool = False,
) -> np.ndarray:
    """Validate a single numpy array and return a C-contiguous version of it.

    Checks run in this order: type, dimensionality, shape, size, NaN/Inf
    content, then (optionally) dtype conversion.

    Args:
        arr: The array to validate.
        name: Label used in error messages.
        expected_dtype: If given, the result is converted to this dtype.
            For integer targets, values outside the target's range are
            rejected instead of silently wrapping around.
        expected_ndim: If given, the required number of dimensions.
        expected_shape: If given, the required shape. It must have exactly
            one entry per axis; ``None`` entries act as per-axis wildcards.
        min_size: Minimum number of elements.
        allow_nan: Permit NaN values in floating-point data.
        allow_inf: Permit +/-Inf values in floating-point data.

    Returns:
        A C-contiguous array (converted to ``expected_dtype`` if given).

    Raises:
        CuGenOptValidationError: If any check fails or the conversion is
            impossible.
    """
    if not isinstance(arr, np.ndarray):
        raise CuGenOptValidationError(
            f"'{name}' must be a numpy array, got {type(arr).__name__}"
        )
    if expected_ndim is not None and arr.ndim != expected_ndim:
        raise CuGenOptValidationError(
            f"'{name}' must be {expected_ndim}D, got {arr.ndim}D with shape {arr.shape}"
        )
    if expected_shape is not None:
        # Compare rank first: zip() alone stops at the shorter sequence and
        # would silently accept e.g. shape (3,) for expected (3, 3).
        if len(expected_shape) != arr.ndim:
            raise CuGenOptValidationError(
                f"'{name}' shape mismatch: "
                f"expected {expected_shape}, got {arr.shape}"
            )
        for i, (actual, expect) in enumerate(zip(arr.shape, expected_shape)):
            if expect is not None and actual != expect:
                raise CuGenOptValidationError(
                    f"'{name}' shape mismatch at axis {i}: "
                    f"expected {expected_shape}, got {arr.shape}"
                )
    if arr.size < min_size:
        raise CuGenOptValidationError(
            f"'{name}' is too small: size={arr.size}, minimum={min_size}"
        )
    # Screen the caller's data BEFORE converting: casting NaN/Inf to an
    # integer dtype produces arbitrary values that can no longer be detected.
    _check_finite(arr, name, allow_nan, allow_inf)
    if expected_dtype is not None:
        converted = _convert_dtype(arr, name, expected_dtype)
        if converted.dtype != arr.dtype:
            # Narrowing float casts (e.g. float64 -> float32) can overflow to Inf.
            _check_finite(converted, name, allow_nan, allow_inf)
        arr = converted
    return np.ascontiguousarray(arr)


def _check_finite(arr: np.ndarray, name: str, allow_nan: bool, allow_inf: bool) -> None:
    """Raise if floating-point ``arr`` holds NaN/Inf values that are not allowed."""
    if not np.issubdtype(arr.dtype, np.floating):
        return
    if not allow_nan:
        nan_count = int(np.count_nonzero(np.isnan(arr)))
        if nan_count:
            raise CuGenOptValidationError(
                f"'{name}' contains {nan_count} NaN value(s). "
                f"Clean your data or set allow_nan=True."
            )
    if not allow_inf:
        inf_count = int(np.count_nonzero(np.isinf(arr)))
        if inf_count:
            raise CuGenOptValidationError(
                f"'{name}' contains {inf_count} Inf value(s). "
                f"Clean your data or set allow_inf=True."
            )


def _convert_dtype(arr: np.ndarray, name: str, dtype) -> np.ndarray:
    """Return ``arr`` as a C-contiguous ``dtype`` array, with friendly errors.

    For integer targets, integer/float input outside the target's range is
    rejected, because the cast would otherwise wrap silently.
    """
    target = np.dtype(dtype)
    source_is_real = (
        np.issubdtype(arr.dtype, np.integer) or np.issubdtype(arr.dtype, np.floating)
    )
    if np.issubdtype(target, np.integer) and source_is_real and arr.size:
        info = np.iinfo(target)
        lo, hi = arr.min(), arr.max()
        if lo < info.min or hi > info.max:
            raise CuGenOptValidationError(
                f"'{name}' has values in [{lo}, {hi}] that do not fit in "
                f"{target} (valid range [{info.min}, {info.max}])"
            )
    try:
        return np.ascontiguousarray(arr, dtype=target)
    except (TypeError, ValueError) as exc:
        raise CuGenOptValidationError(
            f"'{name}' cannot be converted to {target}: {exc}"
        ) from exc
def validate_square_matrix(arr: np.ndarray, name: str, dtype=np.float32) -> np.ndarray:
    """Validate ``arr`` as a square 2D matrix and return the validated array.

    Dimensionality, dtype conversion and NaN/Inf screening are delegated to
    ``validate_array`` (with ``expected_ndim=2``); this function then requires
    both axes to have equal length.

    Raises:
        CuGenOptValidationError: If ``validate_array`` rejects the input or the
            matrix is not square.
    """
    matrix = validate_array(arr, name, expected_ndim=2, expected_dtype=dtype)
    rows, cols = matrix.shape
    if rows == cols:
        return matrix
    raise CuGenOptValidationError(
        f"'{name}' must be square, got shape {matrix.shape}"
    )
def validate_1d(arr: np.ndarray, name: str, *, length: Optional[int] = None,
                dtype=np.float32) -> np.ndarray:
    """Validate ``arr`` as a 1D array, optionally of an exact length.

    Dtype conversion and NaN/Inf screening come from ``validate_array``
    (called with ``expected_ndim=1``).

    Raises:
        CuGenOptValidationError: If ``validate_array`` rejects the input or
            ``length`` is given and does not match.
    """
    vec = validate_array(arr, name, expected_ndim=1, expected_dtype=dtype)
    actual = vec.shape[0]
    if length is None or actual == length:
        return vec
    raise CuGenOptValidationError(
        f"'{name}' length mismatch: expected {length}, got {actual}"
    )
def validate_data_dict(data: Dict[str, np.ndarray], dtype_tag: str) -> Dict[str, np.ndarray]:
    """Validate a dict of name -> array for compile_and_solve data/int_data.

    Args:
        data: Mapping of field name to numpy array.
        dtype_tag: ``"float"`` (arrays become float32) or ``"int"`` (arrays
            become int32). ``"float32"``/``"int32"`` are accepted as aliases;
            case and surrounding whitespace are ignored.

    Returns:
        A new dict with the same keys, each array validated and converted by
        ``validate_array``.

    Raises:
        CuGenOptValidationError: If ``dtype_tag`` is not recognized or any
            value is not a valid numpy array.
    """
    dtype_by_tag = {
        "float": np.float32,
        "float32": np.float32,
        "int": np.int32,
        "int32": np.int32,
    }
    tag = dtype_tag.strip().lower() if isinstance(dtype_tag, str) else None
    target_dtype = dtype_by_tag.get(tag)
    # Reject unknown tags explicitly: falling back to int32 would silently
    # truncate float data passed with a mistyped tag such as "Float".
    if target_dtype is None:
        raise CuGenOptValidationError(
            f"Unknown dtype_tag {dtype_tag!r}. "
            f"Must be one of: {', '.join(sorted(dtype_by_tag))}"
        )
    validated = {}
    for name, arr in data.items():
        if not isinstance(arr, np.ndarray):
            raise CuGenOptValidationError(
                f"data['{name}'] must be a numpy array, got {type(arr).__name__}"
            )
        validated[name] = validate_array(arr, f"data['{name}']", expected_dtype=target_dtype)
    return validated
def validate_encoding(encoding: str) -> str:
    """Normalize and validate a solution-encoding name.

    Args:
        encoding: Encoding name; case and surrounding whitespace are ignored.

    Returns:
        The normalized (stripped, lower-cased) encoding name.

    Raises:
        CuGenOptValidationError: If ``encoding`` is not a string or is not one
            of the supported encodings.
    """
    valid = {"permutation", "binary", "integer"}
    # Without this check a non-string (e.g. None) would fail with an opaque
    # AttributeError on .lower() instead of a friendly validation error.
    if not isinstance(encoding, str):
        raise CuGenOptValidationError(
            f"Encoding must be a string, got {type(encoding).__name__}: {encoding!r}. "
            f"Must be one of: {', '.join(sorted(valid))}"
        )
    enc = encoding.lower().strip()
    if enc not in valid:
        raise CuGenOptValidationError(
            f"Unknown encoding '{encoding}'. Must be one of: {', '.join(sorted(valid))}"
        )
    return enc
def validate_positive_int(value, name: str, *, allow_zero: bool = False) -> int:
    """Coerce ``value`` to ``int`` and check that it is >= 1 (or >= 0).

    Args:
        value: Anything ``int()`` accepts, e.g. ``5``, ``"5"`` or ``5.0``.
            Floating-point values must be integral: ``2.5`` is rejected
            rather than silently truncated to ``2``.
        name: Label used in error messages.
        allow_zero: If True require ``>= 0``; otherwise require ``>= 1``.

    Returns:
        The value as a Python ``int``.

    Raises:
        CuGenOptValidationError: If ``value`` is not an integer or is out of
            range.
    """
    try:
        v = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")). `from None` drops the
        # low-level traceback, which adds nothing to this message.
        raise CuGenOptValidationError(
            f"'{name}' must be an integer, got {type(value).__name__}: {value!r}"
        ) from None
    if isinstance(value, (float, np.floating)) and v != value:
        raise CuGenOptValidationError(
            f"'{name}' must be an integer, got {type(value).__name__}: {value!r}"
        )
    if allow_zero and v < 0:
        raise CuGenOptValidationError(f"'{name}' must be >= 0, got {v}")
    if not allow_zero and v < 1:
        raise CuGenOptValidationError(f"'{name}' must be >= 1, got {v}")
    return v
def validate_cuda_snippet(code: str, name: str) -> str:
    """Perform a basic sanity check on a user-supplied CUDA code snippet.

    Strips surrounding whitespace, rejects empty snippets, and rejects calls
    to a small deny-list of process/file-system functions. This is a
    best-effort textual scan to catch accidents, NOT a security boundary; it
    can be bypassed (e.g. via macros or function pointers).

    Returns:
        The stripped snippet.

    Raises:
        CuGenOptValidationError: If ``code`` is not a string, is empty after
            stripping, or calls a deny-listed function.
    """
    if not isinstance(code, str):
        raise CuGenOptValidationError(
            f"'{name}' CUDA code snippet must be a string, got {type(code).__name__}"
        )
    code = code.strip()
    if not code:
        raise CuGenOptValidationError(f"'{name}' CUDA code snippet is empty")
    for func in ("system", "popen", "exec", "fork", "unlink"):
        # Whole-word match so e.g. `filesystem(` is not flagged, with optional
        # whitespace so `system (` is not missed (plain substring did both).
        if re.search(rf"\b{func}\s*\(", code):
            raise CuGenOptValidationError(
                f"'{name}' contains potentially dangerous call: '{func}('"
            )
    return code
# ============================================================
# nvcc error translation
# ============================================================
_NVCC_PATTERNS = [
(
re.compile(r"error:\s*identifier\s+\"(\w+)\"\s+is\s+undefined", re.I),
lambda m: f"Undefined identifier '{m.group(1)}'. "
f"Check that all data field names in compute_obj/compute_penalty "
f"match the keys in your data dict."
),
(
re.compile(r"error:\s*expected\s+a\s+\"([^\"]+)\"", re.I),
lambda m: f"Syntax error: expected '{m.group(1)}'. "
f"Check for missing semicolons, braces, or parentheses."
),
(
re.compile(r"error:\s*no\s+suitable\s+conversion\s+function\s+from\s+\"([^\"]+)\"\s+to\s+\"([^\"]+)\"", re.I),
lambda m: f"Type mismatch: cannot convert '{m.group(1)}' to '{m.group(2)}'. "
f"Ensure you're using the correct types (float/int)."
),
(
re.compile(r"error:\s*too\s+(?:few|many)\s+arguments", re.I),
lambda m: f"Wrong number of arguments in a function call. "
f"Check the function signature."
),
(
re.compile(r"error:\s*class\s+\"(\w+)\"\s+has\s+no\s+member\s+\"(\w+)\"", re.I),
lambda m: f"'{m.group(1)}' has no member '{m.group(2)}'. "
f"Available solution members: data[row][col], dim2_sizes[row]."
),
(
re.compile(r"error:\s*expression\s+must\s+have\s+a\s+constant\s+value", re.I),
lambda m: f"Non-constant expression where a constant is required. "
f"CUDA device code cannot use dynamic allocation; "
f"use fixed-size arrays."
),
(
re.compile(r"ptxas\s+error\s*:\s*Entry\s+function.*uses\s+too\s+much\s+shared\s+data", re.I),
lambda m: f"Shared memory overflow. Your problem data is too large for GPU "
f"shared memory. Try reducing problem size or data arrays."
),
(
re.compile(r"nvcc\s+fatal\s*:\s*Unsupported\s+gpu\s+architecture\s+'compute_(\d+)'", re.I),
lambda m: f"GPU architecture sm_{m.group(1)} is not supported by your nvcc. "
f"Try specifying cuda_arch='sm_75' or update your CUDA toolkit."
),
(
re.compile(r"error:\s*return\s+value\s+type\s+does\s+not\s+match", re.I),
lambda m: f"Return type mismatch. compute_obj must return float. "
f"Make sure all code paths return a float value."
),
]
def _translate_nvcc_error(stderr: str) -> str:
    """Extract the most relevant error(s) from nvcc output as a friendly message.

    Every pattern in ``_NVCC_PATTERNS`` is run over the whole output and every
    match is formatted; identical hints are listed once, in first-seen order.
    If no pattern matches, the first line mentioning ``error`` or ``fatal`` is
    quoted (nvcc driver failures are printed as ``nvcc fatal : ...`` and do
    not contain the word "error"). Otherwise a generic message is returned.
    ``None`` or empty output is treated as having no recognizable error.
    """
    text = stderr or ""
    # dict.fromkeys de-duplicates while keeping order: two different undefined
    # identifiers give two hints, but repeats of the same hint collapse.
    messages = list(dict.fromkeys(
        formatter(match)
        for pattern, formatter in _NVCC_PATTERNS
        for match in pattern.finditer(text)
    ))
    if messages:
        header = "nvcc compilation failed. Likely cause(s):\n"
        return header + "\n".join(f" - {m}" for m in messages)
    summary = next(
        (
            line.strip() for line in text.splitlines()
            if not line.strip().startswith("#")
            and ("error" in line.lower() or "fatal" in line.lower())
        ),
        None,
    )
    if summary is not None:
        return (
            f"nvcc compilation failed:\n {summary}\n\n"
            f"Tip: Check your CUDA code snippets for syntax errors. "
            f"Common issues: missing semicolons, undefined variables, "
            f"wrong data field names."
        )
    return (
        "nvcc compilation failed with an unknown error.\n"
        "Check the raw output below for details."
    )
def _truncate(text: str, max_len: int) -> str:
if len(text) <= max_len:
return text
return text[:max_len] + f"\n... ({len(text) - max_len} chars truncated)"