mirror of
https://github.com/L-yang-yang/cugenopt.git
synced 2026-04-25 12:16:21 +02:00
Initial commit: cuGenOpt GPU optimization solver
This commit is contained in:
commit
fc5a0ff4af
117 changed files with 25545 additions and 0 deletions
261
python/cugenopt/validation.py
Normal file
261
python/cugenopt/validation.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
"""
|
||||
Input validation and friendly error translation for cuGenOpt.
|
||||
|
||||
Two responsibilities:
|
||||
1. Validate numpy arrays before JIT compilation (dtype, shape, NaN/Inf, contiguity)
|
||||
2. Translate nvcc compilation errors into actionable Python messages
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class CuGenOptValidationError(ValueError):
    """Signal that user-supplied input did not pass validation.

    Derives from ``ValueError``, so handlers written for ``ValueError``
    also catch it.
    """
|
||||
|
||||
|
||||
class CuGenOptCompileError(RuntimeError):
    """Error for a failed nvcc build, carrying a readable summary.

    Attributes set by the constructor:
        raw_stderr: the nvcc output exactly as it was passed in.
        source_path: the path passed in; echoed at the end of the message.
        friendly: the result of ``_translate_nvcc_error(raw_stderr)``.
    """

    def __init__(self, raw_stderr: str, source_path: str):
        self.raw_stderr = raw_stderr
        self.source_path = source_path
        self.friendly = _translate_nvcc_error(raw_stderr)

        # Message layout: friendly summary first, then the raw output
        # shortened via _truncate(..., 1200), then the source location.
        raw_excerpt = _truncate(raw_stderr, 1200)
        message = (
            f"{self.friendly}\n\n"
            f"[raw nvcc output]\n{raw_excerpt}\n\n"
            f"Source saved at: {source_path}"
        )
        super().__init__(message)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Array validation
|
||||
# ============================================================
|
||||
|
||||
def _check_non_finite(arr: np.ndarray, name: str, allow_nan: bool, allow_inf: bool) -> None:
    """Raise CuGenOptValidationError if *arr* holds disallowed NaN/Inf values.

    Arrays whose dtype is not a floating type are skipped.
    """
    if not np.issubdtype(arr.dtype, np.floating):
        return

    if not allow_nan:
        # One pass: the count doubles as the "any NaN?" test.
        nan_count = int(np.isnan(arr).sum())
        if nan_count:
            raise CuGenOptValidationError(
                f"'{name}' contains {nan_count} NaN value(s). "
                f"Clean your data or set allow_nan=True."
            )

    if not allow_inf:
        inf_count = int(np.isinf(arr).sum())
        if inf_count:
            raise CuGenOptValidationError(
                f"'{name}' contains {inf_count} Inf value(s). "
                f"Clean your data or set allow_inf=True."
            )


def validate_array(
    arr: np.ndarray,
    name: str,
    *,
    expected_dtype: Optional[np.dtype] = None,
    expected_ndim: Optional[int] = None,
    expected_shape: Optional[tuple] = None,
    min_size: int = 1,
    allow_nan: bool = False,
    allow_inf: bool = False,
) -> np.ndarray:
    """Validate a single numpy array and return a contiguous copy if needed.

    Checks run in this order: type, ndim, shape, size, NaN/Inf, dtype cast.

    Args:
        arr: The array to validate.
        name: Label used for the array in error messages.
        expected_dtype: If given, the result is converted to this dtype.
        expected_ndim: If given, the required number of dimensions.
        expected_shape: If given, the required shape. A ``None`` entry
            accepts any extent on that axis. Its length must equal
            ``arr.ndim``.
        min_size: Minimum number of elements (``arr.size``).
        allow_nan: Accept NaN values in floating arrays.
        allow_inf: Accept Inf values in floating arrays.

    Returns:
        A C-contiguous array (of ``expected_dtype`` when one was given).

    Raises:
        CuGenOptValidationError: with a clear message when any check fails.
    """
    if not isinstance(arr, np.ndarray):
        raise CuGenOptValidationError(
            f"'{name}' must be a numpy array, got {type(arr).__name__}"
        )

    if expected_ndim is not None and arr.ndim != expected_ndim:
        raise CuGenOptValidationError(
            f"'{name}' must be {expected_ndim}D, got {arr.ndim}D with shape {arr.shape}"
        )

    if expected_shape is not None:
        # zip() stops at the shorter sequence, so a rank mismatch must be
        # rejected explicitly or it would pass unnoticed.
        if len(expected_shape) != arr.ndim:
            raise CuGenOptValidationError(
                f"'{name}' shape mismatch: "
                f"expected {expected_shape}, got {arr.shape}"
            )
        for i, (actual, expect) in enumerate(zip(arr.shape, expected_shape)):
            if expect is not None and actual != expect:
                raise CuGenOptValidationError(
                    f"'{name}' shape mismatch at axis {i}: "
                    f"expected {expected_shape}, got {arr.shape}"
                )

    if arr.size < min_size:
        raise CuGenOptValidationError(
            f"'{name}' is too small: size={arr.size}, minimum={min_size}"
        )

    # Inspect the source values before any cast: converting a float array to
    # an integer dtype would otherwise turn NaN/Inf into arbitrary integers
    # and bypass the floating-only checks.
    _check_non_finite(arr, name, allow_nan, allow_inf)

    if expected_dtype is not None:
        source_dtype = arr.dtype
        arr = np.ascontiguousarray(arr, dtype=expected_dtype)
        # A narrowing cast (e.g. float64 -> float32) can overflow to Inf,
        # so re-check whenever the dtype actually changed.
        if arr.dtype != source_dtype:
            _check_non_finite(arr, name, allow_nan, allow_inf)

    return np.ascontiguousarray(arr)
|
||||
|
||||
|
||||
def validate_square_matrix(arr: np.ndarray, name: str, dtype=np.float32) -> np.ndarray:
    """Check that *arr* is a 2D array with equal row and column counts.

    The generic checks and the dtype conversion are delegated to
    ``validate_array``; the array it returns is what this function returns.

    Raises:
        CuGenOptValidationError: if the two dimensions differ, or if
            ``validate_array`` rejects the input.
    """
    matrix = validate_array(arr, name, expected_ndim=2, expected_dtype=dtype)
    n_rows, n_cols = matrix.shape
    if n_rows == n_cols:
        return matrix
    raise CuGenOptValidationError(
        f"'{name}' must be square, got shape {matrix.shape}"
    )
|
||||
|
||||
|
||||
def validate_1d(arr: np.ndarray, name: str, *, length: Optional[int] = None,
                dtype=np.float32) -> np.ndarray:
    """Check that *arr* is one-dimensional, optionally of an exact length.

    The generic checks and the dtype conversion are delegated to
    ``validate_array``. When ``length`` is ``None`` no length check is made.

    Raises:
        CuGenOptValidationError: if the length differs from ``length``, or
            if ``validate_array`` rejects the input.
    """
    vector = validate_array(arr, name, expected_ndim=1, expected_dtype=dtype)
    actual_length = vector.shape[0]
    if length is None or actual_length == length:
        return vector
    raise CuGenOptValidationError(
        f"'{name}' length mismatch: expected {length}, got {actual_length}"
    )
|
||||
|
||||
|
||||
def validate_data_dict(data: Dict[str, np.ndarray], dtype_tag: str) -> Dict[str, np.ndarray]:
    """Validate every array in a name -> array mapping.

    Intended for the data/int_data arguments of compile_and_solve. A
    ``dtype_tag`` of ``"float"`` selects float32 as the target dtype; any
    other tag selects int32. Each value is passed through
    ``validate_array`` and the results are collected in a new dict.

    Raises:
        CuGenOptValidationError: if a value is not a numpy array, or if
            ``validate_array`` rejects it.
    """
    if dtype_tag == "float":
        target_dtype = np.float32
    else:
        target_dtype = np.int32

    checked = {}
    for name, candidate in data.items():
        if isinstance(candidate, np.ndarray):
            checked[name] = validate_array(
                candidate, f"data['{name}']", expected_dtype=target_dtype
            )
            continue
        raise CuGenOptValidationError(
            f"data['{name}'] must be a numpy array, got {type(candidate).__name__}"
        )
    return checked
|
||||
|
||||
|
||||
def validate_encoding(encoding: str) -> str:
    """Normalise an encoding name and check it is a supported one.

    The input is lower-cased and stripped of surrounding whitespace before
    being compared against the supported names; the normalised form is
    returned.

    Raises:
        CuGenOptValidationError: if the normalised name is not supported.
    """
    supported = {"permutation", "binary", "integer"}
    normalized = encoding.lower().strip()
    if normalized in supported:
        return normalized

    choices = ", ".join(sorted(supported))
    raise CuGenOptValidationError(
        f"Unknown encoding '{encoding}'. Must be one of: {choices}"
    )
|
||||
|
||||
|
||||
def validate_positive_int(value, name: str, *, allow_zero: bool = False) -> int:
    """Validate that value is a positive integer and return it as ``int``.

    Anything ``int()`` can convert is accepted (including numeric strings
    and integral floats such as ``3.0``), except floats with a fractional
    part, which are rejected rather than silently truncated.

    Args:
        value: The candidate value.
        name: Label used for the value in error messages.
        allow_zero: When true the minimum accepted value is 0 instead of 1.

    Raises:
        CuGenOptValidationError: if the value cannot be converted, is a
            non-integral or infinite float, or is below the minimum.
    """
    try:
        v = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")); `from None` hides the
        # low-level conversion traceback from the user.
        raise CuGenOptValidationError(
            f"'{name}' must be an integer, got {type(value).__name__}: {value!r}"
        ) from None

    # int() truncates floats; reject values that would lose a fractional part.
    if isinstance(value, (float, np.floating)) and value != v:
        raise CuGenOptValidationError(
            f"'{name}' must be an integer, got {type(value).__name__}: {value!r}"
        )

    minimum = 0 if allow_zero else 1
    if v < minimum:
        raise CuGenOptValidationError(f"'{name}' must be >= {minimum}, got {v}")
    return v
|
||||
|
||||
|
||||
def validate_cuda_snippet(code: str, name: str) -> str:
    """Run a basic sanity check on a CUDA snippet and return it stripped.

    The snippet must be non-empty after stripping and must not contain any
    of a small set of call prefixes. The test is a plain substring match,
    not a parse of the code.

    Raises:
        CuGenOptValidationError: if the snippet is empty or contains one of
            the blocked call prefixes (the first one found is reported).
    """
    snippet = code.strip()
    if snippet == "":
        raise CuGenOptValidationError(f"'{name}' CUDA code snippet is empty")

    blocked_calls = ("system(", "popen(", "exec(", "fork(", "unlink(")
    offender = next((call for call in blocked_calls if call in snippet), None)
    if offender is not None:
        raise CuGenOptValidationError(
            f"'{name}' contains potentially dangerous call: '{offender}'"
        )
    return snippet
|
||||
|
||||
|
||||
# ============================================================
|
||||
# nvcc error translation
|
||||
# ============================================================
|
||||
|
||||
# Ordered (compiled regex, formatter) pairs. Each formatter receives the
# regex match object and returns the friendly message for that error class.
# All patterns are case-insensitive.
_NVCC_PATTERNS = [
    # Use of an undeclared identifier.
    (
        re.compile(r"error:\s*identifier\s+\"(\w+)\"\s+is\s+undefined", re.IGNORECASE),
        lambda m: f"Undefined identifier '{m.group(1)}'. "
                  "Check that all data field names in compute_obj/compute_penalty "
                  "match the keys in your data dict.",
    ),
    # A missing token such as ';' or ')'.
    (
        re.compile(r"error:\s*expected\s+a\s+\"([^\"]+)\"", re.IGNORECASE),
        lambda m: f"Syntax error: expected '{m.group(1)}'. "
                  "Check for missing semicolons, braces, or parentheses.",
    ),
    # No conversion between two types.
    (
        re.compile(
            r"error:\s*no\s+suitable\s+conversion\s+function\s+from\s+\"([^\"]+)\"\s+to\s+\"([^\"]+)\"",
            re.IGNORECASE,
        ),
        lambda m: f"Type mismatch: cannot convert '{m.group(1)}' to '{m.group(2)}'. "
                  "Ensure you're using the correct types (float/int).",
    ),
    # Too few or too many call arguments.
    (
        re.compile(r"error:\s*too\s+(?:few|many)\s+arguments", re.IGNORECASE),
        lambda m: "Wrong number of arguments in a function call. "
                  "Check the function signature.",
    ),
    # Access to a member the class does not have.
    (
        re.compile(r"error:\s*class\s+\"(\w+)\"\s+has\s+no\s+member\s+\"(\w+)\"", re.IGNORECASE),
        lambda m: f"'{m.group(1)}' has no member '{m.group(2)}'. "
                  "Available solution members: data[row][col], dim2_sizes[row].",
    ),
    # A constant expression was required.
    (
        re.compile(r"error:\s*expression\s+must\s+have\s+a\s+constant\s+value", re.IGNORECASE),
        lambda m: "Non-constant expression where a constant is required. "
                  "CUDA device code cannot use dynamic allocation; "
                  "use fixed-size arrays.",
    ),
    # ptxas reports too much shared data.
    (
        re.compile(
            r"ptxas\s+error\s*:\s*Entry\s+function.*uses\s+too\s+much\s+shared\s+data",
            re.IGNORECASE,
        ),
        lambda m: "Shared memory overflow. Your problem data is too large for GPU "
                  "shared memory. Try reducing problem size or data arrays.",
    ),
    # nvcc rejects the requested compute architecture.
    (
        re.compile(
            r"nvcc\s+fatal\s*:\s*Unsupported\s+gpu\s+architecture\s+'compute_(\d+)'",
            re.IGNORECASE,
        ),
        lambda m: f"GPU architecture sm_{m.group(1)} is not supported by your nvcc. "
                  "Try specifying cuda_arch='sm_75' or update your CUDA toolkit.",
    ),
    # Returned value does not match the declared return type.
    (
        re.compile(r"error:\s*return\s+value\s+type\s+does\s+not\s+match", re.IGNORECASE),
        lambda m: "Return type mismatch. compute_obj must return float. "
                  "Make sure all code paths return a float value.",
    ),
]
|
||||
|
||||
|
||||
def _translate_nvcc_error(stderr: str) -> str:
    """Turn raw nvcc output into a short, friendly explanation.

    Three outcomes, tried in order:
      1. One bullet per entry of ``_NVCC_PATTERNS`` whose regex matches.
      2. Otherwise, the first line containing "error" (case-insensitive)
         that does not start with "#", followed by a general tip.
      3. Otherwise, a generic "unknown error" message.
    """
    hints = []
    for regex, render in _NVCC_PATTERNS:
        found = regex.search(stderr)
        if found is not None:
            hints.append(render(found))

    if hints:
        bullets = "\n".join(f" - {hint}" for hint in hints)
        return "nvcc compilation failed. Likely cause(s):\n" + bullets

    # No known pattern matched: fall back to the first error-looking line.
    stripped_lines = (raw.strip() for raw in stderr.split("\n"))
    summary = next(
        (
            text for text in stripped_lines
            if "error" in text.lower() and not text.startswith("#")
        ),
        None,
    )
    if summary is not None:
        return (
            f"nvcc compilation failed:\n {summary}\n\n"
            "Tip: Check your CUDA code snippets for syntax errors. "
            "Common issues: missing semicolons, undefined variables, "
            "wrong data field names."
        )

    return (
        "nvcc compilation failed with an unknown error.\n"
        "Check the raw output below for details."
    )
|
||||
|
||||
|
||||
def _truncate(text: str, max_len: int) -> str:
    """Return *text* cut to its first ``max_len`` characters.

    Text that already fits is returned unchanged. Otherwise a trailing
    line is appended stating how many characters were dropped.
    """
    dropped = len(text) - max_len
    if dropped <= 0:
        return text
    return f"{text[:max_len]}\n... ({dropped} chars truncated)"
|
||||
Loading…
Add table
Add a link
Reference in a new issue