mirror of
https://github.com/Kaelio/ktx.git
synced 2026-06-10 08:05:14 +02:00
rename klo to ktx
This commit is contained in:
parent
1a42152e6f
commit
3ce510b55b
704 changed files with 10205 additions and 10255 deletions
6
python/ktx-daemon/src/ktx_daemon/__init__.py
Normal file
6
python/ktx-daemon/src/ktx_daemon/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""Portable compute package for KTX."""
|
||||
|
||||
PACKAGE_NAME = "ktx-daemon"
|
||||
VERSION = "0.1.0"
|
||||
|
||||
__all__ = ["PACKAGE_NAME", "VERSION"]
|
||||
172
python/ktx-daemon/src/ktx_daemon/__main__.py
Normal file
172
python/ktx-daemon/src/ktx_daemon/__main__.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
"""Command entry point for one-shot KTX daemon compute operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from typing import Any
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from ktx_daemon.code_execution import ExecuteCodeRequest, execute_code_response
|
||||
from ktx_daemon.database_introspection import (
|
||||
DatabaseIntrospectionRequest,
|
||||
introspect_database_response,
|
||||
)
|
||||
from ktx_daemon.embeddings import (
|
||||
ComputeEmbeddingBulkRequest,
|
||||
ComputeEmbeddingRequest,
|
||||
compute_embedding_bulk_response,
|
||||
compute_embedding_response,
|
||||
)
|
||||
from ktx_daemon.lookml import ParseLookMLRequest, parse_lookml_project
|
||||
from ktx_daemon.semantic_layer import (
|
||||
SemanticLayerQueryRequest,
|
||||
ValidateSourcesRequest,
|
||||
query_semantic_layer,
|
||||
validate_semantic_layer,
|
||||
)
|
||||
from ktx_daemon.source_generation import (
|
||||
GenerateSourcesRequest,
|
||||
generate_sources_response,
|
||||
)
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="ktx-daemon")
|
||||
subcommands = parser.add_subparsers(dest="command", required=True)
|
||||
subcommands.add_parser("semantic-query", help="Compile a semantic-layer query")
|
||||
subcommands.add_parser("semantic-validate", help="Validate semantic-layer sources")
|
||||
subcommands.add_parser(
|
||||
"semantic-generate-sources",
|
||||
help="Generate semantic-layer sources from schema scan data",
|
||||
)
|
||||
subcommands.add_parser(
|
||||
"database-introspect",
|
||||
help="Introspect a Postgres database schema",
|
||||
)
|
||||
subcommands.add_parser(
|
||||
"lookml-parse",
|
||||
help="Parse LookML files into KSL-ready structures",
|
||||
)
|
||||
subcommands.add_parser(
|
||||
"embedding-compute",
|
||||
help="Compute one local text embedding",
|
||||
)
|
||||
subcommands.add_parser(
|
||||
"embedding-compute-bulk",
|
||||
help="Compute local text embeddings in bulk",
|
||||
)
|
||||
subcommands.add_parser(
|
||||
"code-execute",
|
||||
help="Execute Python code with the current in-process boundary",
|
||||
)
|
||||
serve_http = subcommands.add_parser(
|
||||
"serve-http",
|
||||
help="Run the KTX daemon portable compute HTTP server",
|
||||
)
|
||||
serve_http.add_argument("--host", default="127.0.0.1")
|
||||
serve_http.add_argument("--port", type=int, default=8765)
|
||||
serve_http.add_argument(
|
||||
"--log-level",
|
||||
default="info",
|
||||
choices=["critical", "error", "warning", "info", "debug", "trace"],
|
||||
)
|
||||
serve_http.add_argument(
|
||||
"--enable-code-execution",
|
||||
action="store_true",
|
||||
help="Expose POST /code/execute on the HTTP server",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def _read_stdin_json() -> dict[str, Any]:
|
||||
raw = sys.stdin.read()
|
||||
parsed = json.loads(raw)
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("stdin JSON must be an object")
|
||||
return parsed
|
||||
|
||||
|
||||
def run_http_server(
|
||||
*,
|
||||
host: str,
|
||||
port: int,
|
||||
log_level: str,
|
||||
enable_code_execution: bool,
|
||||
) -> None:
|
||||
import uvicorn
|
||||
|
||||
from ktx_daemon.app import create_app
|
||||
|
||||
uvicorn.run(
|
||||
create_app(enable_code_execution=enable_code_execution),
|
||||
host=host,
|
||||
port=port,
|
||||
log_level=log_level,
|
||||
)
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if args.command == "serve-http":
|
||||
run_http_server(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.log_level,
|
||||
enable_code_execution=args.enable_code_execution,
|
||||
)
|
||||
return 0
|
||||
|
||||
try:
|
||||
payload = _read_stdin_json()
|
||||
if args.command == "semantic-query":
|
||||
response = query_semantic_layer(
|
||||
SemanticLayerQueryRequest.model_validate(payload)
|
||||
)
|
||||
elif args.command == "semantic-validate":
|
||||
response = validate_semantic_layer(
|
||||
ValidateSourcesRequest.model_validate(payload)
|
||||
)
|
||||
elif args.command == "semantic-generate-sources":
|
||||
response = generate_sources_response(
|
||||
GenerateSourcesRequest.model_validate(payload)
|
||||
)
|
||||
elif args.command == "database-introspect":
|
||||
response = introspect_database_response(
|
||||
DatabaseIntrospectionRequest.model_validate(payload)
|
||||
)
|
||||
elif args.command == "lookml-parse":
|
||||
response = parse_lookml_project(ParseLookMLRequest.model_validate(payload))
|
||||
elif args.command == "embedding-compute":
|
||||
response = compute_embedding_response(
|
||||
ComputeEmbeddingRequest.model_validate(payload)
|
||||
)
|
||||
elif args.command == "embedding-compute-bulk":
|
||||
response = compute_embedding_bulk_response(
|
||||
ComputeEmbeddingBulkRequest.model_validate(payload)
|
||||
)
|
||||
elif args.command == "code-execute":
|
||||
response = execute_code_response(
|
||||
ExecuteCodeRequest.model_validate(payload),
|
||||
nest_api_url=None,
|
||||
auth_header=None,
|
||||
)
|
||||
else:
|
||||
parser.error(f"Unknown command: {args.command}")
|
||||
return 2
|
||||
sys.stdout.write(response.model_dump_json() + "\n")
|
||||
return 0
|
||||
except (json.JSONDecodeError, ValidationError, ValueError) as error:
|
||||
sys.stderr.write(f"{error}\n")
|
||||
return 1
|
||||
except Exception as error:
|
||||
sys.stderr.write(f"{type(error).__name__}: {error}\n")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
228
python/ktx-daemon/src/ktx_daemon/app.py
Normal file
228
python/ktx-daemon/src/ktx_daemon/app.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""FastAPI app factory for the KTX daemon semantic compute server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import Response
|
||||
|
||||
from ktx_daemon.code_execution import (
|
||||
ExecuteCodeRequest,
|
||||
ExecuteCodeResponse,
|
||||
dumps_numpy_json,
|
||||
execute_code_response,
|
||||
)
|
||||
from ktx_daemon.database_introspection import (
|
||||
DatabaseIntrospectionRequest,
|
||||
DatabaseIntrospectionResponse,
|
||||
introspect_database_response,
|
||||
)
|
||||
from ktx_daemon.embeddings import (
|
||||
ComputeEmbeddingBulkRequest,
|
||||
ComputeEmbeddingBulkResponse,
|
||||
ComputeEmbeddingRequest,
|
||||
ComputeEmbeddingResponse,
|
||||
EmbeddingProvider,
|
||||
compute_embedding_bulk_response,
|
||||
compute_embedding_response,
|
||||
)
|
||||
from ktx_daemon.lookml import (
|
||||
ParseLookMLRequest,
|
||||
ParseLookMLResponse,
|
||||
parse_lookml_project,
|
||||
)
|
||||
from ktx_daemon.semantic_layer import (
|
||||
SemanticLayerQueryRequest,
|
||||
SemanticLayerQueryResponse,
|
||||
ValidateSourcesRequest,
|
||||
ValidateSourcesResponse,
|
||||
query_semantic_layer,
|
||||
validate_semantic_layer,
|
||||
)
|
||||
from ktx_daemon.source_generation import (
|
||||
GenerateSourcesRequest,
|
||||
GenerateSourcesResponse,
|
||||
generate_sources_response,
|
||||
)
|
||||
from ktx_daemon.table_identifier import (
|
||||
ParseTableIdentifierBatchRequest,
|
||||
ParseTableIdentifierBatchResponse,
|
||||
parse_table_identifier_response,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NumpyORJSONResponse(Response):
|
||||
media_type = "application/json"
|
||||
|
||||
def render(self, content: Any) -> bytes:
|
||||
return dumps_numpy_json(content)
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
embedding_provider: EmbeddingProvider | None = None,
|
||||
database_introspector: Callable[
|
||||
[DatabaseIntrospectionRequest], DatabaseIntrospectionResponse
|
||||
]
|
||||
| None = None,
|
||||
enable_code_execution: bool = False,
|
||||
) -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="KTX Daemon",
|
||||
description="Stateless portable compute server for KTX.",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict[str, str]:
|
||||
return {"status": "healthy"}
|
||||
|
||||
@app.post("/database/introspect", response_model=DatabaseIntrospectionResponse)
|
||||
async def database_introspect(
|
||||
request: DatabaseIntrospectionRequest,
|
||||
) -> DatabaseIntrospectionResponse:
|
||||
try:
|
||||
introspector = database_introspector or introspect_database_response
|
||||
return introspector(request)
|
||||
except ValueError as error:
|
||||
logger.warning("Database introspection rejected: %s", error)
|
||||
raise HTTPException(status_code=400, detail=str(error)) from error
|
||||
except Exception as error:
|
||||
logger.exception("Database introspection failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Database introspection failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post("/embeddings/compute", response_model=ComputeEmbeddingResponse)
|
||||
async def embedding_compute(
|
||||
request: ComputeEmbeddingRequest,
|
||||
) -> ComputeEmbeddingResponse:
|
||||
try:
|
||||
return compute_embedding_response(
|
||||
request,
|
||||
provider=embedding_provider,
|
||||
)
|
||||
except ValueError as error:
|
||||
logger.warning("Embedding compute rejected: %s", error)
|
||||
raise HTTPException(status_code=400, detail=str(error)) from error
|
||||
except Exception as error:
|
||||
logger.exception("Embedding compute failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Embedding compute failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post(
|
||||
"/embeddings/compute-bulk",
|
||||
response_model=ComputeEmbeddingBulkResponse,
|
||||
)
|
||||
async def embedding_compute_bulk(
|
||||
request: ComputeEmbeddingBulkRequest,
|
||||
) -> ComputeEmbeddingBulkResponse:
|
||||
try:
|
||||
return compute_embedding_bulk_response(
|
||||
request,
|
||||
provider=embedding_provider,
|
||||
)
|
||||
except ValueError as error:
|
||||
logger.warning("Bulk embedding compute rejected: %s", error)
|
||||
raise HTTPException(status_code=400, detail=str(error)) from error
|
||||
except Exception as error:
|
||||
logger.exception("Bulk embedding compute failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Bulk embedding compute failed: {error}",
|
||||
) from error
|
||||
|
||||
if enable_code_execution:
|
||||
|
||||
@app.post(
|
||||
"/code/execute",
|
||||
response_model=ExecuteCodeResponse,
|
||||
response_class=NumpyORJSONResponse,
|
||||
)
|
||||
async def code_execute(request: ExecuteCodeRequest) -> ExecuteCodeResponse:
|
||||
try:
|
||||
return execute_code_response(
|
||||
request,
|
||||
nest_api_url=None,
|
||||
auth_header=None,
|
||||
)
|
||||
except Exception as error:
|
||||
logger.exception("Code execution failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Code execution failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post("/lookml/parse", response_model=ParseLookMLResponse)
|
||||
async def lookml_parse(request: ParseLookMLRequest) -> ParseLookMLResponse:
|
||||
try:
|
||||
return parse_lookml_project(request)
|
||||
except Exception as error:
|
||||
logger.exception("LookML parsing failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"LookML parsing failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post(
|
||||
"/sql/parse-table-identifier",
|
||||
response_model=ParseTableIdentifierBatchResponse,
|
||||
)
|
||||
async def sql_parse_table_identifier(
|
||||
request: ParseTableIdentifierBatchRequest,
|
||||
) -> ParseTableIdentifierBatchResponse:
|
||||
try:
|
||||
return parse_table_identifier_response(request)
|
||||
except Exception as error:
|
||||
logger.exception("Table identifier parsing failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Table identifier parsing failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post(
|
||||
"/semantic-layer/generate-sources", response_model=GenerateSourcesResponse
|
||||
)
|
||||
async def semantic_generate_sources(
|
||||
request: GenerateSourcesRequest,
|
||||
) -> GenerateSourcesResponse:
|
||||
try:
|
||||
return generate_sources_response(request)
|
||||
except Exception as error:
|
||||
logger.exception("Semantic source generation failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Semantic source generation failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post("/semantic-layer/query", response_model=SemanticLayerQueryResponse)
|
||||
async def semantic_query(
|
||||
request: SemanticLayerQueryRequest,
|
||||
) -> SemanticLayerQueryResponse:
|
||||
try:
|
||||
return query_semantic_layer(request)
|
||||
except ValueError as error:
|
||||
logger.warning("Semantic query rejected: %s", error)
|
||||
raise HTTPException(status_code=400, detail=str(error)) from error
|
||||
except Exception as error:
|
||||
logger.exception("Semantic query failed: %s", error)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Semantic layer query failed: {error}",
|
||||
) from error
|
||||
|
||||
@app.post("/semantic-layer/validate", response_model=ValidateSourcesResponse)
|
||||
async def semantic_validate(
|
||||
request: ValidateSourcesRequest,
|
||||
) -> ValidateSourcesResponse:
|
||||
return validate_semantic_layer(request)
|
||||
|
||||
return app
|
||||
333
python/ktx-daemon/src/ktx_daemon/code_execution.py
Normal file
333
python/ktx-daemon/src/ktx_daemon/code_execution.py
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
"""Portable in-process code execution helpers for KTX daemon.
|
||||
|
||||
This module preserves the host application's current Python execution behavior.
|
||||
It runs code with Python ``exec`` in the current process and does not provide
|
||||
OS-level sandboxing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
import pandas as pd
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VALID_VISUALIZATION_TYPES = ["pie", "bar", "line", "area", "table", "boxplot"]
|
||||
|
||||
|
||||
class ExecuteCodeRequest(BaseModel):
|
||||
"""Request schema for executing Python code."""
|
||||
|
||||
code: str = Field(..., description="Python code to execute")
|
||||
source_id: str | None = Field(
|
||||
None,
|
||||
description="Chat/dashboard ID for scratchpad file access",
|
||||
)
|
||||
message_id: str | None = Field(
|
||||
None,
|
||||
description="Message ID for visualization association",
|
||||
)
|
||||
|
||||
|
||||
class VisualizationSpec(BaseModel):
|
||||
"""Specification for a visualization to be saved by the host application."""
|
||||
|
||||
type: str = Field(..., description="Type marker, always 'visualization'")
|
||||
vis_type: str = Field(
|
||||
...,
|
||||
description="Visualization type: pie, bar, line, area, table",
|
||||
)
|
||||
config: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Visualization configuration",
|
||||
)
|
||||
data: list[dict[str, Any]] = Field(
|
||||
...,
|
||||
description="Visualization data",
|
||||
)
|
||||
title: str | None = Field(None, description="Optional title")
|
||||
|
||||
|
||||
class ExecuteCodeResponse(BaseModel):
|
||||
"""Response schema for code execution."""
|
||||
|
||||
formatted_result: str = Field(
|
||||
...,
|
||||
description="Formatted execution result for display",
|
||||
)
|
||||
result: Any | None = Field(
|
||||
None,
|
||||
description="The value of the 'result' variable if set",
|
||||
)
|
||||
console_output: str | None = Field(
|
||||
None,
|
||||
description="Captured stdout from print statements",
|
||||
)
|
||||
error: str | None = Field(None, description="Error message if execution failed")
|
||||
message: str | None = Field(
|
||||
None,
|
||||
description="Message if no clear result was returned",
|
||||
)
|
||||
visualizations: list[VisualizationSpec] | None = Field(
|
||||
None,
|
||||
description="List of visualizations detected in the result",
|
||||
)
|
||||
|
||||
|
||||
ScratchpadHelpers = tuple[
|
||||
Callable[[pd.DataFrame, str | None], str],
|
||||
Callable[[str], pd.DataFrame],
|
||||
Callable[[str, dict[str, Any], list[dict[str, Any]]], str],
|
||||
]
|
||||
|
||||
|
||||
def dumps_numpy_json(content: Any) -> bytes:
|
||||
"""Serialize JSON response content with numpy scalar and array support."""
|
||||
|
||||
return orjson.dumps(content, option=orjson.OPT_SERIALIZE_NUMPY)
|
||||
|
||||
|
||||
def _strip_ansi_sequences(text: str) -> str:
|
||||
ansi_escape = re.compile(
|
||||
r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\([0-9;]*[a-zA-Z]|\x1b\[[0-9;]*~"
|
||||
)
|
||||
return ansi_escape.sub("", text)
|
||||
|
||||
|
||||
def create_scratchpad_helpers(
|
||||
nest_api_url: str | None,
|
||||
auth_header: str | None,
|
||||
source_id: str | None,
|
||||
message_id: str | None = None,
|
||||
http_client: Any = requests,
|
||||
) -> ScratchpadHelpers:
|
||||
"""Create scratchpad and visualization helpers that call host app APIs."""
|
||||
|
||||
def save_df_to_scratchpad(df: pd.DataFrame, filename: str | None = None) -> str:
|
||||
if not nest_api_url or not auth_header or not source_id:
|
||||
raise ValueError(
|
||||
"nest_api_url, Authorization header, and source_id are required "
|
||||
"for scratchpad operations"
|
||||
)
|
||||
|
||||
data_json = df.to_dict(orient="records")
|
||||
url = f"{nest_api_url}/private_api/scratchpad/{source_id}/files"
|
||||
response = http_client.post(
|
||||
url,
|
||||
data=dumps_numpy_json(
|
||||
{"filename": filename, "data": data_json, "format": "json"}
|
||||
),
|
||||
headers={"Authorization": auth_header, "Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
saved_filename = response.json()["filename"]
|
||||
rows, _cols = df.shape
|
||||
return f"{rows} rows saved to {saved_filename}"
|
||||
|
||||
def read_scratchpad_file(filename: str) -> pd.DataFrame:
|
||||
if not nest_api_url or not auth_header or not source_id:
|
||||
raise ValueError(
|
||||
"nest_api_url, Authorization header, and source_id are required "
|
||||
"for scratchpad operations"
|
||||
)
|
||||
|
||||
url = f"{nest_api_url}/private_api/scratchpad/{source_id}/files/{filename}?format=raw"
|
||||
response = http_client.get(
|
||||
url,
|
||||
headers={"Authorization": auth_header, "Accept": "text/csv"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
content_type = response.headers.get("content-type", "")
|
||||
if "text/csv" in content_type:
|
||||
return pd.read_csv(BytesIO(response.content))
|
||||
|
||||
data = response.json()["data"]
|
||||
return pd.DataFrame(data)
|
||||
|
||||
def save_visualization(
|
||||
vis_type: str,
|
||||
config: dict[str, Any],
|
||||
data: list[dict[str, Any]],
|
||||
) -> str:
|
||||
if not nest_api_url or not auth_header or not source_id:
|
||||
raise ValueError(
|
||||
"nest_api_url, Authorization header, and source_id are required "
|
||||
"for visualization operations"
|
||||
)
|
||||
|
||||
if not message_id:
|
||||
raise ValueError("message_id is required for visualization operations")
|
||||
|
||||
if vis_type not in VALID_VISUALIZATION_TYPES:
|
||||
raise ValueError(
|
||||
f"Invalid visualization type: {vis_type}. Must be one of {VALID_VISUALIZATION_TYPES}"
|
||||
)
|
||||
|
||||
url = f"{nest_api_url}/private_api/visualizations/{source_id}"
|
||||
payload = {
|
||||
"visualizationType": vis_type,
|
||||
"config": config,
|
||||
"data": data,
|
||||
"messageId": message_id,
|
||||
}
|
||||
|
||||
response = http_client.post(
|
||||
url,
|
||||
data=dumps_numpy_json(payload),
|
||||
headers={"Authorization": auth_header, "Content-Type": "application/json"},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
filename = response.json()["filename"]
|
||||
print(f"Visualization saved: {filename}")
|
||||
return f""
|
||||
|
||||
return save_df_to_scratchpad, read_scratchpad_file, save_visualization
|
||||
|
||||
|
||||
def detect_visualizations(result: Any) -> list[dict[str, Any]]:
|
||||
"""Detect visualization specs in a code execution result value."""
|
||||
|
||||
visualizations = []
|
||||
|
||||
if isinstance(result, dict) and result.get("type") == "visualization":
|
||||
visualizations.append(result)
|
||||
elif isinstance(result, list):
|
||||
for item in result:
|
||||
if isinstance(item, dict) and item.get("type") == "visualization":
|
||||
visualizations.append(item)
|
||||
|
||||
return visualizations
|
||||
|
||||
|
||||
def execute_code(
|
||||
code: str,
|
||||
nest_api_url: str | None = None,
|
||||
auth_header: str | None = None,
|
||||
source_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
scratchpad_helpers: ScratchpadHelpers | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Execute Python code with the current in-process execution boundary."""
|
||||
|
||||
logger.info("Starting code execution")
|
||||
save_df, read_file, save_viz = scratchpad_helpers or create_scratchpad_helpers(
|
||||
nest_api_url,
|
||||
auth_header,
|
||||
source_id,
|
||||
message_id,
|
||||
)
|
||||
|
||||
namespace = {
|
||||
"pd": pd,
|
||||
"np": np,
|
||||
"json": json,
|
||||
"requests": requests,
|
||||
"save_df_to_scratchpad": save_df,
|
||||
"read_scratchpad_file": read_file,
|
||||
"save_visualization": save_viz,
|
||||
}
|
||||
|
||||
stdout_capture = StringIO()
|
||||
original_stdout = sys.stdout
|
||||
sys.stdout = stdout_capture
|
||||
console_output = ""
|
||||
|
||||
try:
|
||||
logger.info("Executing code in current process namespace")
|
||||
exec(code, namespace)
|
||||
|
||||
console_output = stdout_capture.getvalue()
|
||||
if "result" in namespace:
|
||||
logger.info("Code execution complete, 'result' variable found")
|
||||
result_value = namespace["result"]
|
||||
visualizations = detect_visualizations(result_value)
|
||||
|
||||
result = {"result": result_value}
|
||||
if console_output:
|
||||
result["console_output"] = console_output
|
||||
if visualizations:
|
||||
result["visualizations"] = visualizations
|
||||
|
||||
return result
|
||||
|
||||
logger.info("No result variable found")
|
||||
result = {
|
||||
"message": "Code executed successfully but no result variable was set"
|
||||
}
|
||||
if console_output:
|
||||
result["console_output"] = console_output
|
||||
return result
|
||||
|
||||
except Exception as error:
|
||||
logger.exception("Error executing code: %s", error)
|
||||
result = {"error": str(error)}
|
||||
if console_output:
|
||||
result["console_output"] = console_output
|
||||
return result
|
||||
|
||||
finally:
|
||||
sys.stdout = original_stdout
|
||||
|
||||
|
||||
def format_execution_result(result: dict[str, Any]) -> str:
|
||||
"""Format execution output for display in host chat responses."""
|
||||
|
||||
formatted_result = ""
|
||||
if "console_output" in result:
|
||||
formatted_result += "=== Console Output ===\n\n"
|
||||
formatted_result += _strip_ansi_sequences(result["console_output"])
|
||||
|
||||
if "result" in result:
|
||||
formatted_result += "\n\n=== Result ===\n\n"
|
||||
formatted_result += str(result["result"])
|
||||
elif "message" in result:
|
||||
formatted_result += "\n\n=== Message ===\n\n"
|
||||
formatted_result += result["message"]
|
||||
elif "error" in result:
|
||||
formatted_result += "\n\n=== Error ===\n\n"
|
||||
formatted_result += result["error"]
|
||||
|
||||
return formatted_result
|
||||
|
||||
|
||||
def execute_code_response(
|
||||
request: ExecuteCodeRequest,
|
||||
*,
|
||||
nest_api_url: str | None,
|
||||
auth_header: str | None,
|
||||
) -> ExecuteCodeResponse:
|
||||
"""Execute a validated request and return the public response model."""
|
||||
|
||||
result = execute_code(
|
||||
code=request.code,
|
||||
nest_api_url=nest_api_url,
|
||||
auth_header=auth_header,
|
||||
source_id=request.source_id,
|
||||
message_id=request.message_id,
|
||||
)
|
||||
|
||||
return ExecuteCodeResponse(
|
||||
formatted_result=format_execution_result(result),
|
||||
result=result.get("result"),
|
||||
console_output=result.get("console_output"),
|
||||
error=result.get("error"),
|
||||
message=result.get("message"),
|
||||
visualizations=result.get("visualizations"),
|
||||
)
|
||||
284
python/ktx-daemon/src/ktx_daemon/database_introspection.py
Normal file
284
python/ktx-daemon/src/ktx_daemon/database_introspection.py
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
"""Portable database introspection helpers for KTX daemon."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
TABLES_SQL = """
|
||||
select
|
||||
t.table_catalog,
|
||||
t.table_schema,
|
||||
t.table_name,
|
||||
obj_description(c.oid) as table_comment
|
||||
from information_schema.tables t
|
||||
join pg_catalog.pg_namespace n
|
||||
on n.nspname = t.table_schema
|
||||
join pg_catalog.pg_class c
|
||||
on c.relnamespace = n.oid
|
||||
and c.relname = t.table_name
|
||||
where t.table_schema = any(%s)
|
||||
and t.table_type = 'BASE TABLE'
|
||||
order by t.table_schema, t.table_name
|
||||
"""
|
||||
|
||||
COLUMNS_SQL = """
|
||||
select
|
||||
current_database() as table_catalog,
|
||||
n.nspname as table_schema,
|
||||
c.relname as table_name,
|
||||
a.attname as column_name,
|
||||
pg_catalog.format_type(a.atttypid, a.atttypmod) as formatted_type,
|
||||
not a.attnotnull as is_nullable,
|
||||
exists (
|
||||
select 1
|
||||
from pg_catalog.pg_index i
|
||||
where i.indrelid = c.oid
|
||||
and i.indisprimary
|
||||
and a.attnum = any(i.indkey)
|
||||
) as is_primary_key,
|
||||
pg_catalog.col_description(c.oid, a.attnum) as column_comment
|
||||
from pg_catalog.pg_attribute a
|
||||
join pg_catalog.pg_class c
|
||||
on c.oid = a.attrelid
|
||||
join pg_catalog.pg_namespace n
|
||||
on n.oid = c.relnamespace
|
||||
where n.nspname = any(%s)
|
||||
and c.relkind in ('r', 'p')
|
||||
and a.attnum > 0
|
||||
and not a.attisdropped
|
||||
order by n.nspname, c.relname, a.attnum
|
||||
"""
|
||||
|
||||
FOREIGN_KEYS_SQL = """
|
||||
select
|
||||
current_database() as table_catalog,
|
||||
source_constraint.table_schema,
|
||||
source_constraint.table_name,
|
||||
source_key.column_name as from_column,
|
||||
target_key.table_name as to_table,
|
||||
target_key.column_name as to_column,
|
||||
source_constraint.constraint_name
|
||||
from information_schema.table_constraints source_constraint
|
||||
join information_schema.key_column_usage source_key
|
||||
on source_key.constraint_catalog = source_constraint.constraint_catalog
|
||||
and source_key.constraint_schema = source_constraint.constraint_schema
|
||||
and source_key.constraint_name = source_constraint.constraint_name
|
||||
join information_schema.referential_constraints ref_constraint
|
||||
on ref_constraint.constraint_catalog = source_constraint.constraint_catalog
|
||||
and ref_constraint.constraint_schema = source_constraint.constraint_schema
|
||||
and ref_constraint.constraint_name = source_constraint.constraint_name
|
||||
join information_schema.key_column_usage target_key
|
||||
on target_key.constraint_catalog = ref_constraint.unique_constraint_catalog
|
||||
and target_key.constraint_schema = ref_constraint.unique_constraint_schema
|
||||
and target_key.constraint_name = ref_constraint.unique_constraint_name
|
||||
and target_key.ordinal_position = source_key.position_in_unique_constraint
|
||||
where source_constraint.constraint_type = 'FOREIGN KEY'
|
||||
and source_constraint.table_schema = any(%s)
|
||||
order by source_constraint.table_schema, source_constraint.table_name, source_constraint.constraint_name, source_key.ordinal_position
|
||||
"""
|
||||
|
||||
|
||||
class LiveDatabaseColumn(BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
nullable: bool = True
|
||||
primary_key: bool = False
|
||||
comment: str | None = None
|
||||
|
||||
|
||||
class LiveDatabaseForeignKey(BaseModel):
|
||||
from_column: str
|
||||
to_table: str
|
||||
to_column: str
|
||||
constraint_name: str | None = None
|
||||
|
||||
|
||||
class LiveDatabaseTable(BaseModel):
|
||||
catalog: str | None = None
|
||||
db: str | None = None
|
||||
name: str
|
||||
comment: str | None = None
|
||||
columns: list[LiveDatabaseColumn] = Field(default_factory=list)
|
||||
foreign_keys: list[LiveDatabaseForeignKey] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DatabaseIntrospectionRequest(BaseModel):
|
||||
connection_id: str
|
||||
driver: str = "postgres"
|
||||
url: str
|
||||
schemas: list[str] = Field(default_factory=lambda: ["public"])
|
||||
statement_timeout_ms: int = Field(default=30_000, ge=1)
|
||||
connection_timeout_seconds: int = Field(default=5, ge=1)
|
||||
|
||||
@field_validator("schemas")
|
||||
@classmethod
|
||||
def _schemas_must_not_be_empty(cls, value: list[str]) -> list[str]:
|
||||
if not value:
|
||||
raise ValueError("database introspection requires at least one schema")
|
||||
return value
|
||||
|
||||
|
||||
class DatabaseIntrospectionResponse(BaseModel):
|
||||
connection_id: str
|
||||
extracted_at: str
|
||||
metadata: dict[str, Any]
|
||||
tables: list[LiveDatabaseTable]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatabaseIntrospectionRows:
|
||||
table_rows: Sequence[Mapping[str, Any]]
|
||||
column_rows: Sequence[Mapping[str, Any]]
|
||||
foreign_key_rows: Sequence[Mapping[str, Any]]
|
||||
|
||||
|
||||
DatabaseRowsLoader = Callable[[DatabaseIntrospectionRequest], DatabaseIntrospectionRows]
|
||||
NowProvider = Callable[[], str]
|
||||
|
||||
|
||||
def _driver_name(driver: str) -> str:
|
||||
return driver.strip().lower()
|
||||
|
||||
|
||||
def _table_key(catalog: str | None, db: str | None, name: str) -> str:
|
||||
return f"{catalog or ''}\u0000{db or ''}\u0000{name}"
|
||||
|
||||
|
||||
def _optional_string(row: Mapping[str, Any], key: str) -> str | None:
|
||||
value = row.get(key)
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
|
||||
def _required_string(row: Mapping[str, Any], key: str) -> str:
|
||||
value = row.get(key)
|
||||
if not isinstance(value, str) or not value:
|
||||
raise ValueError(f"database introspection row is missing string field {key}")
|
||||
return value
|
||||
|
||||
|
||||
def _statement_timeout_config(statement_timeout_ms: int) -> tuple[str, tuple[str]]:
|
||||
return (
|
||||
"SELECT set_config('statement_timeout', %s, true)",
|
||||
(f"{int(statement_timeout_ms)}ms",),
|
||||
)
|
||||
|
||||
|
||||
def _load_postgres_rows(
|
||||
request: DatabaseIntrospectionRequest,
|
||||
) -> DatabaseIntrospectionRows:
|
||||
try:
|
||||
import psycopg
|
||||
from psycopg.rows import dict_row
|
||||
except ImportError as error:
|
||||
raise RuntimeError(
|
||||
"psycopg is required for Postgres database introspection"
|
||||
) from error
|
||||
|
||||
connection = psycopg.connect(
|
||||
request.url,
|
||||
connect_timeout=request.connection_timeout_seconds,
|
||||
application_name="ktx-daemon-database-introspection",
|
||||
row_factory=dict_row,
|
||||
)
|
||||
try:
|
||||
connection.execute("BEGIN READ ONLY")
|
||||
try:
|
||||
connection.execute(*_statement_timeout_config(request.statement_timeout_ms))
|
||||
params = (request.schemas,)
|
||||
table_rows = list(connection.execute(TABLES_SQL, params))
|
||||
column_rows = list(connection.execute(COLUMNS_SQL, params))
|
||||
foreign_key_rows = list(connection.execute(FOREIGN_KEYS_SQL, params))
|
||||
connection.execute("COMMIT")
|
||||
except Exception:
|
||||
connection.execute("ROLLBACK")
|
||||
raise
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
return DatabaseIntrospectionRows(
|
||||
table_rows=table_rows,
|
||||
column_rows=column_rows,
|
||||
foreign_key_rows=foreign_key_rows,
|
||||
)
|
||||
|
||||
|
||||
def _map_rows_to_tables(rows: DatabaseIntrospectionRows) -> list[LiveDatabaseTable]:
|
||||
tables: dict[str, LiveDatabaseTable] = {}
|
||||
|
||||
for row in rows.table_rows:
|
||||
catalog = _optional_string(row, "table_catalog")
|
||||
db = _required_string(row, "table_schema")
|
||||
name = _required_string(row, "table_name")
|
||||
key = _table_key(catalog, db, name)
|
||||
tables[key] = LiveDatabaseTable(
|
||||
catalog=catalog,
|
||||
db=db,
|
||||
name=name,
|
||||
comment=_optional_string(row, "table_comment"),
|
||||
)
|
||||
|
||||
for row in rows.column_rows:
|
||||
catalog = _optional_string(row, "table_catalog")
|
||||
db = _required_string(row, "table_schema")
|
||||
table_name = _required_string(row, "table_name")
|
||||
table = tables.get(_table_key(catalog, db, table_name))
|
||||
if table is None:
|
||||
continue
|
||||
|
||||
table.columns.append(
|
||||
LiveDatabaseColumn(
|
||||
name=_required_string(row, "column_name"),
|
||||
type=_required_string(row, "formatted_type"),
|
||||
nullable=bool(row.get("is_nullable")),
|
||||
primary_key=bool(row.get("is_primary_key")),
|
||||
comment=_optional_string(row, "column_comment"),
|
||||
)
|
||||
)
|
||||
|
||||
for row in rows.foreign_key_rows:
|
||||
catalog = _optional_string(row, "table_catalog")
|
||||
db = _required_string(row, "table_schema")
|
||||
table_name = _required_string(row, "table_name")
|
||||
table = tables.get(_table_key(catalog, db, table_name))
|
||||
if table is None:
|
||||
continue
|
||||
|
||||
table.foreign_keys.append(
|
||||
LiveDatabaseForeignKey(
|
||||
from_column=_required_string(row, "from_column"),
|
||||
to_table=_required_string(row, "to_table"),
|
||||
to_column=_required_string(row, "to_column"),
|
||||
constraint_name=_optional_string(row, "constraint_name"),
|
||||
)
|
||||
)
|
||||
|
||||
return sorted(
|
||||
tables.values(),
|
||||
key=lambda table: _table_key(table.catalog, table.db, table.name),
|
||||
)
|
||||
|
||||
|
||||
def introspect_database_response(
|
||||
request: DatabaseIntrospectionRequest,
|
||||
*,
|
||||
load_rows: DatabaseRowsLoader | None = None,
|
||||
now: NowProvider | None = None,
|
||||
) -> DatabaseIntrospectionResponse:
|
||||
driver = _driver_name(request.driver)
|
||||
if driver not in {"postgres", "postgresql"}:
|
||||
raise ValueError('database introspection supports only driver "postgres"')
|
||||
|
||||
rows = (load_rows or _load_postgres_rows)(request)
|
||||
timestamp = now() if now else datetime.now(timezone.utc).isoformat()
|
||||
return DatabaseIntrospectionResponse(
|
||||
connection_id=request.connection_id,
|
||||
extracted_at=timestamp,
|
||||
metadata={"driver": driver, "schemas": list(request.schemas)},
|
||||
tables=_map_rows_to_tables(rows),
|
||||
)
|
||||
172
python/ktx-daemon/src/ktx_daemon/embeddings.py
Normal file
172
python/ktx-daemon/src/ktx_daemon/embeddings.py
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
"""Portable embedding compute helpers for KTX daemon."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"
|
||||
DEFAULT_EMBEDDING_DIMENSIONS = 384
|
||||
DEFAULT_MAX_BATCH_SIZE = 100
|
||||
|
||||
|
||||
class EmbeddingProvider(Protocol):
|
||||
"""Provider interface for local embedding compute."""
|
||||
|
||||
@property
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int: ...
|
||||
|
||||
@property
|
||||
def max_batch_size(self) -> int: ...
|
||||
|
||||
def encode(self, texts: list[str]) -> list[list[float]]: ...
|
||||
|
||||
|
||||
class ComputeEmbeddingRequest(BaseModel):
|
||||
"""Request schema for computing a single embedding."""
|
||||
|
||||
text: str = Field(..., description="Text to compute embedding for", min_length=1)
|
||||
|
||||
|
||||
class ComputeEmbeddingResponse(BaseModel):
|
||||
"""Response schema for single embedding computation."""
|
||||
|
||||
embedding: list[float] = Field(..., description="384-dimensional embedding vector")
|
||||
|
||||
|
||||
class ComputeEmbeddingBulkRequest(BaseModel):
|
||||
"""Request schema for computing multiple embeddings."""
|
||||
|
||||
texts: list[str] = Field(
|
||||
...,
|
||||
description="List of texts to compute embeddings for",
|
||||
min_length=1,
|
||||
max_length=DEFAULT_MAX_BATCH_SIZE,
|
||||
)
|
||||
|
||||
|
||||
class ComputeEmbeddingBulkResponse(BaseModel):
|
||||
"""Response schema for bulk embedding computation."""
|
||||
|
||||
embeddings: list[list[float]] = Field(
|
||||
...,
|
||||
description="List of 384-dimensional embedding vectors",
|
||||
)
|
||||
|
||||
|
||||
class SentenceTransformersEmbeddingProvider:
|
||||
"""Lazy sentence-transformers provider for local embeddings."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = DEFAULT_SENTENCE_TRANSFORMER_MODEL,
|
||||
model: SentenceTransformer | None = None,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self._model = model
|
||||
self._model_lock = threading.Lock()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "sentence-transformers"
|
||||
|
||||
@property
|
||||
def dimensions(self) -> int:
|
||||
return DEFAULT_EMBEDDING_DIMENSIONS
|
||||
|
||||
@property
|
||||
def max_batch_size(self) -> int:
|
||||
return DEFAULT_MAX_BATCH_SIZE
|
||||
|
||||
def _get_model(self) -> SentenceTransformer:
|
||||
if self._model is not None:
|
||||
return self._model
|
||||
|
||||
with self._model_lock:
|
||||
if self._model is None:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger.info("Loading SentenceTransformer model: %s", self.model_name)
|
||||
self._model = SentenceTransformer(self.model_name)
|
||||
logger.info("SentenceTransformer model loaded successfully")
|
||||
|
||||
return self._model
|
||||
|
||||
def encode(self, texts: list[str]) -> list[list[float]]:
|
||||
model = self._get_model()
|
||||
if len(texts) == 1:
|
||||
raw_single = model.encode(texts[0]).tolist()
|
||||
return [[float(value) for value in raw_single]]
|
||||
|
||||
raw_bulk = model.encode(texts).tolist()
|
||||
return [[float(value) for value in embedding] for embedding in raw_bulk]
|
||||
|
||||
|
||||
_default_provider: SentenceTransformersEmbeddingProvider | None = None
|
||||
_default_provider_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_default_embedding_provider() -> SentenceTransformersEmbeddingProvider:
|
||||
"""Return the process-wide default embedding provider."""
|
||||
|
||||
global _default_provider
|
||||
|
||||
if _default_provider is not None:
|
||||
return _default_provider
|
||||
|
||||
with _default_provider_lock:
|
||||
if _default_provider is None:
|
||||
_default_provider = SentenceTransformersEmbeddingProvider()
|
||||
|
||||
return _default_provider
|
||||
|
||||
|
||||
def _validate_texts(texts: list[str], max_batch_size: int) -> None:
|
||||
if not texts:
|
||||
raise ValueError("Texts array must not be empty")
|
||||
if len(texts) > max_batch_size:
|
||||
raise ValueError(f"Maximum {max_batch_size} texts allowed per batch")
|
||||
|
||||
empty_indices = [
|
||||
index for index, text in enumerate(texts) if not text or not text.strip()
|
||||
]
|
||||
if empty_indices:
|
||||
joined_indices = ", ".join(str(index) for index in empty_indices)
|
||||
raise ValueError(f"Empty texts found at indices: {joined_indices}")
|
||||
|
||||
|
||||
def compute_embedding_response(
|
||||
request: ComputeEmbeddingRequest,
|
||||
provider: EmbeddingProvider | None = None,
|
||||
) -> ComputeEmbeddingResponse:
|
||||
"""Compute one embedding from a request model."""
|
||||
|
||||
selected_provider = provider or get_default_embedding_provider()
|
||||
_validate_texts([request.text], selected_provider.max_batch_size)
|
||||
return ComputeEmbeddingResponse(
|
||||
embedding=selected_provider.encode([request.text])[0]
|
||||
)
|
||||
|
||||
|
||||
def compute_embedding_bulk_response(
|
||||
request: ComputeEmbeddingBulkRequest,
|
||||
provider: EmbeddingProvider | None = None,
|
||||
) -> ComputeEmbeddingBulkResponse:
|
||||
"""Compute multiple embeddings from a request model."""
|
||||
|
||||
selected_provider = provider or get_default_embedding_provider()
|
||||
_validate_texts(request.texts, selected_provider.max_batch_size)
|
||||
return ComputeEmbeddingBulkResponse(
|
||||
embeddings=selected_provider.encode(request.texts)
|
||||
)
|
||||
1056
python/ktx-daemon/src/ktx_daemon/lookml.py
Normal file
1056
python/ktx-daemon/src/ktx_daemon/lookml.py
Normal file
File diff suppressed because it is too large
Load diff
136
python/ktx-daemon/src/ktx_daemon/semantic_layer.py
Normal file
136
python/ktx-daemon/src/ktx_daemon/semantic_layer.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""Semantic-layer compute helpers for the KTX daemon package."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from semantic_layer.duplicate_check import validate_measure_duplicates
|
||||
from semantic_layer.engine import SemanticEngine
|
||||
from semantic_layer.models import QueryResult, SourceDefinition
|
||||
|
||||
|
||||
class SemanticLayerQueryRequest(BaseModel):
|
||||
sources: list[dict[str, Any]]
|
||||
query: dict[str, Any]
|
||||
dialect: str = "postgres"
|
||||
|
||||
|
||||
class SemanticLayerQueryResponse(BaseModel):
|
||||
sql: str
|
||||
dialect: str
|
||||
columns: list[dict[str, Any]]
|
||||
plan: dict[str, Any]
|
||||
|
||||
|
||||
class ValidateSourcesRequest(BaseModel):
|
||||
sources: list[dict[str, Any]]
|
||||
dialect: str = "postgres"
|
||||
recently_touched: list[str] | None = None
|
||||
|
||||
|
||||
class ValidateSourcesResponse(BaseModel):
|
||||
valid: bool
|
||||
errors: list[str] = Field(default_factory=list)
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
per_source_warnings: dict[str, list[str]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
def _load_sources(raw_sources: list[dict[str, Any]]) -> dict[str, SourceDefinition]:
|
||||
sources: dict[str, SourceDefinition] = {}
|
||||
for raw_source in raw_sources:
|
||||
source = SourceDefinition(**raw_source)
|
||||
if source.name in sources:
|
||||
raise ValueError(f"Duplicate source name '{source.name}'")
|
||||
sources[source.name] = source
|
||||
return sources
|
||||
|
||||
|
||||
def _validate_duplicate_measure_names(source: SourceDefinition) -> list[str]:
|
||||
errors: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for measure in source.measures:
|
||||
if measure.name in seen:
|
||||
errors.append(
|
||||
f"Duplicate measure '{measure.name}' on source '{source.name}'"
|
||||
)
|
||||
continue
|
||||
seen.add(measure.name)
|
||||
return errors
|
||||
|
||||
|
||||
def _response_columns(result: QueryResult) -> list[dict[str, Any]]:
|
||||
measure_names = {
|
||||
measure.name: measure.qualified_ref
|
||||
for measure in result.resolved_plan.measures
|
||||
if measure.qualified_ref
|
||||
}
|
||||
columns: list[dict[str, Any]] = []
|
||||
for column in result.columns:
|
||||
dumped = column.model_dump(mode="json")
|
||||
if column.provenance.value == "dimension" and column.expr:
|
||||
dumped["name"] = column.expr
|
||||
elif column.name in measure_names:
|
||||
dumped["name"] = measure_names[column.name]
|
||||
columns.append(dumped)
|
||||
return columns
|
||||
|
||||
|
||||
def query_semantic_layer(
|
||||
request: SemanticLayerQueryRequest,
|
||||
) -> SemanticLayerQueryResponse:
|
||||
sources = _load_sources(request.sources)
|
||||
engine = SemanticEngine.from_sources(sources, dialect=request.dialect)
|
||||
result = engine.query(request.query)
|
||||
return SemanticLayerQueryResponse(
|
||||
sql=result.sql,
|
||||
dialect=result.dialect,
|
||||
columns=_response_columns(result),
|
||||
plan=result.resolved_plan.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
|
||||
def validate_semantic_layer(request: ValidateSourcesRequest) -> ValidateSourcesResponse:
|
||||
errors: list[str] = []
|
||||
warnings: list[str] = []
|
||||
per_source_warnings: dict[str, list[str]] = {}
|
||||
sources: dict[str, SourceDefinition] = {}
|
||||
seen_names: set[str] = set()
|
||||
|
||||
for raw_source in request.sources:
|
||||
raw_name = raw_source.get("name") if isinstance(raw_source, dict) else None
|
||||
try:
|
||||
source = SourceDefinition(**raw_source)
|
||||
except Exception as error:
|
||||
label = raw_name or "<unknown>"
|
||||
errors.append(f"Source '{label}' failed to parse: {error}")
|
||||
continue
|
||||
|
||||
if source.name in seen_names:
|
||||
errors.append(f"Duplicate source name '{source.name}'")
|
||||
continue
|
||||
seen_names.add(source.name)
|
||||
sources[source.name] = source
|
||||
errors.extend(_validate_duplicate_measure_names(source))
|
||||
|
||||
if sources:
|
||||
try:
|
||||
engine = SemanticEngine.from_sources(sources, dialect=request.dialect)
|
||||
report = engine.validate(
|
||||
recently_touched=set(request.recently_touched)
|
||||
if request.recently_touched
|
||||
else None
|
||||
)
|
||||
errors.extend(report.errors)
|
||||
warnings.extend(report.warnings)
|
||||
per_source_warnings.update(report.per_source_warnings)
|
||||
errors.extend(validate_measure_duplicates(sources, dialect=request.dialect))
|
||||
except Exception as error:
|
||||
errors.append(f"Validation failed: {error}")
|
||||
|
||||
return ValidateSourcesResponse(
|
||||
valid=len(errors) == 0,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
per_source_warnings=per_source_warnings,
|
||||
)
|
||||
254
python/ktx-daemon/src/ktx_daemon/source_generation.py
Normal file
254
python/ktx-daemon/src/ktx_daemon/source_generation.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Generate ktx-sl YAML source definitions from database schema scan data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from semantic_layer.models import (
|
||||
ColumnRole,
|
||||
JoinDeclaration,
|
||||
MeasureDefinition,
|
||||
SourceColumn,
|
||||
SourceDefinition,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NUMBER_PATTERN = re.compile(
|
||||
r"int|integer|bigint|smallint|tinyint|numeric|decimal|float|double|real|number|money",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_TIME_PATTERN = re.compile(
|
||||
r"timestamp|datetime|date|time(?!stamp)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
_BOOLEAN_PATTERN = re.compile(r"bool|boolean|bit", re.IGNORECASE)
|
||||
_ID_PATTERN = re.compile(
|
||||
r"^id$|_id$|^uuid$|_uuid$|_key$|_pk$|identifier$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
_RELATIONSHIP_MAP = {
|
||||
"MANY_TO_ONE": "many_to_one",
|
||||
"ONE_TO_MANY": "one_to_many",
|
||||
"ONE_TO_ONE": "one_to_one",
|
||||
"many_to_one": "many_to_one",
|
||||
"one_to_many": "one_to_many",
|
||||
"one_to_one": "one_to_one",
|
||||
}
|
||||
|
||||
_RELATIONSHIP_INVERSE = {
|
||||
"many_to_one": "one_to_many",
|
||||
"one_to_many": "many_to_one",
|
||||
"one_to_one": "one_to_one",
|
||||
}
|
||||
|
||||
|
||||
class ColumnInput(BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
primary_key: bool = False
|
||||
nullable: bool = True
|
||||
comment: str | None = None
|
||||
|
||||
|
||||
class TableInput(BaseModel):
|
||||
name: str
|
||||
catalog: str | None = None
|
||||
db: str | None = None
|
||||
comment: str | None = None
|
||||
columns: list[ColumnInput]
|
||||
|
||||
|
||||
class LinkInput(BaseModel):
|
||||
from_table: str
|
||||
from_column: str
|
||||
to_table: str
|
||||
to_column: str
|
||||
relationship_type: str
|
||||
|
||||
|
||||
class GenerateSourcesRequest(BaseModel):
|
||||
tables: list[TableInput]
|
||||
links: list[LinkInput]
|
||||
dialect: str = "postgres"
|
||||
|
||||
|
||||
class GenerateSourcesResponse(BaseModel):
|
||||
sources: list[dict[str, Any]]
|
||||
source_count: int
|
||||
|
||||
|
||||
def _map_column_type(db_type: str) -> str:
|
||||
if _BOOLEAN_PATTERN.search(db_type):
|
||||
return "boolean"
|
||||
if _TIME_PATTERN.search(db_type):
|
||||
return "time"
|
||||
if _NUMBER_PATTERN.search(db_type):
|
||||
return "number"
|
||||
return "string"
|
||||
|
||||
|
||||
def _build_table_ref(table: TableInput) -> str:
|
||||
parts = []
|
||||
if table.catalog:
|
||||
parts.append(table.catalog)
|
||||
if table.db:
|
||||
parts.append(table.db)
|
||||
parts.append(table.name)
|
||||
return ".".join(parts)
|
||||
|
||||
|
||||
def _generate_measures(
|
||||
table_name: str,
|
||||
columns: list[ColumnInput],
|
||||
pk_columns: list[str],
|
||||
) -> list[MeasureDefinition]:
|
||||
measures: list[MeasureDefinition] = []
|
||||
|
||||
if pk_columns:
|
||||
pk = pk_columns[0]
|
||||
measures.append(
|
||||
MeasureDefinition(
|
||||
name="record_count",
|
||||
expr=f"count({pk})",
|
||||
description=f"Count of {table_name} records",
|
||||
)
|
||||
)
|
||||
|
||||
for col in columns:
|
||||
if _map_column_type(col.type) != "number":
|
||||
continue
|
||||
if _ID_PATTERN.search(col.name):
|
||||
continue
|
||||
measures.append(
|
||||
MeasureDefinition(
|
||||
name=f"total_{col.name}",
|
||||
expr=f"sum({col.name})",
|
||||
description=f"Sum of {col.name}"
|
||||
+ (f" \u2014 {col.comment}" if col.comment else ""),
|
||||
)
|
||||
)
|
||||
measures.append(
|
||||
MeasureDefinition(
|
||||
name=f"avg_{col.name}",
|
||||
expr=f"avg({col.name})",
|
||||
description=f"Average of {col.name}"
|
||||
+ (f" \u2014 {col.comment}" if col.comment else ""),
|
||||
)
|
||||
)
|
||||
|
||||
return measures
|
||||
|
||||
|
||||
def generate_sources(request: GenerateSourcesRequest) -> list[dict[str, Any]]:
|
||||
links_by_from: dict[str, list[LinkInput]] = {}
|
||||
links_by_to: dict[str, list[LinkInput]] = {}
|
||||
for link in request.links:
|
||||
links_by_from.setdefault(link.from_table, []).append(link)
|
||||
links_by_to.setdefault(link.to_table, []).append(link)
|
||||
|
||||
table_names = {table.name for table in request.tables}
|
||||
sources: list[dict[str, Any]] = []
|
||||
|
||||
for table in request.tables:
|
||||
pk_columns = [column.name for column in table.columns if column.primary_key]
|
||||
grain = (
|
||||
pk_columns
|
||||
if pk_columns
|
||||
else [table.columns[0].name]
|
||||
if table.columns
|
||||
else ["id"]
|
||||
)
|
||||
|
||||
sl_columns: list[SourceColumn] = []
|
||||
for column in table.columns:
|
||||
sl_type = _map_column_type(column.type)
|
||||
role = ColumnRole.TIME if sl_type == "time" else ColumnRole.DEFAULT
|
||||
sl_columns.append(
|
||||
SourceColumn(
|
||||
name=column.name,
|
||||
type=sl_type,
|
||||
role=role,
|
||||
description=column.comment,
|
||||
)
|
||||
)
|
||||
|
||||
joins: list[JoinDeclaration] = []
|
||||
for link in links_by_from.get(table.name, []):
|
||||
if link.to_table not in table_names:
|
||||
logger.warning(
|
||||
"Skipping link from %s.%s to %s.%s: target table not in scan",
|
||||
link.from_table,
|
||||
link.from_column,
|
||||
link.to_table,
|
||||
link.to_column,
|
||||
)
|
||||
continue
|
||||
|
||||
relationship = _RELATIONSHIP_MAP.get(link.relationship_type, "many_to_one")
|
||||
joins.append(
|
||||
JoinDeclaration(
|
||||
to=link.to_table,
|
||||
on=f"{link.from_column} = {link.to_table}.{link.to_column}",
|
||||
relationship=relationship,
|
||||
)
|
||||
)
|
||||
|
||||
for link in links_by_to.get(table.name, []):
|
||||
if link.from_table not in table_names:
|
||||
logger.warning(
|
||||
"Skipping reverse link from %s.%s to %s.%s: source table not in scan",
|
||||
link.from_table,
|
||||
link.from_column,
|
||||
link.to_table,
|
||||
link.to_column,
|
||||
)
|
||||
continue
|
||||
|
||||
forward_relationship = _RELATIONSHIP_MAP.get(
|
||||
link.relationship_type, "many_to_one"
|
||||
)
|
||||
reverse_relationship = _RELATIONSHIP_INVERSE.get(
|
||||
forward_relationship, "one_to_many"
|
||||
)
|
||||
joins.append(
|
||||
JoinDeclaration(
|
||||
to=link.from_table,
|
||||
on=f"{link.to_column} = {link.from_table}.{link.from_column}",
|
||||
relationship=reverse_relationship,
|
||||
)
|
||||
)
|
||||
|
||||
to_counts: dict[str, int] = {}
|
||||
for join in joins:
|
||||
to_counts[join.to] = to_counts.get(join.to, 0) + 1
|
||||
if any(count > 1 for count in to_counts.values()):
|
||||
for join in joins:
|
||||
if to_counts[join.to] > 1:
|
||||
fk_col = join.on.split(" = ")[0].strip().lower()
|
||||
join.alias = f"{join.to}_{fk_col}"
|
||||
|
||||
source = SourceDefinition(
|
||||
name=table.name,
|
||||
description=table.comment,
|
||||
table=_build_table_ref(table),
|
||||
grain=grain,
|
||||
columns=sl_columns,
|
||||
joins=joins,
|
||||
measures=_generate_measures(table.name, table.columns, pk_columns),
|
||||
)
|
||||
sources.append(source.model_dump(exclude_none=True))
|
||||
|
||||
logger.info("Generated %d ktx-sl source definitions", len(sources))
|
||||
return sources
|
||||
|
||||
|
||||
def generate_sources_response(
|
||||
request: GenerateSourcesRequest,
|
||||
) -> GenerateSourcesResponse:
|
||||
sources = generate_sources(request)
|
||||
return GenerateSourcesResponse(sources=sources, source_count=len(sources))
|
||||
66
python/ktx-daemon/src/ktx_daemon/table_identifier.py
Normal file
66
python/ktx-daemon/src/ktx_daemon/table_identifier.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from semantic_layer.table_identifier_parser import (
|
||||
ParseTableIdentifierItem as SharedParseTableIdentifierItem,
|
||||
parse_table_identifier_batch,
|
||||
)
|
||||
|
||||
ParseTableIdentifierReason = Literal[
|
||||
"looker_template_unresolved",
|
||||
"derived_table_not_supported",
|
||||
"no_physical_table",
|
||||
"multiple_table_references",
|
||||
"unsupported_dialect",
|
||||
"parse_error",
|
||||
]
|
||||
|
||||
|
||||
class ParseTableIdentifierItem(BaseModel):
|
||||
key: str
|
||||
sql_table_name: str
|
||||
dialect: str
|
||||
|
||||
|
||||
class ParseTableIdentifierBatchRequest(BaseModel):
|
||||
items: list[ParseTableIdentifierItem]
|
||||
|
||||
|
||||
class ParsedIdentifier(BaseModel):
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
ok: bool
|
||||
catalog: str | None = None
|
||||
schema_: str | None = Field(default=None, alias="schema")
|
||||
name: str | None = None
|
||||
canonical_table: str | None = None
|
||||
reason: ParseTableIdentifierReason | None = None
|
||||
detail: str | None = None
|
||||
|
||||
|
||||
class ParseTableIdentifierBatchResponse(BaseModel):
|
||||
results: dict[str, ParsedIdentifier]
|
||||
|
||||
|
||||
def parse_table_identifier_response(
|
||||
request: ParseTableIdentifierBatchRequest,
|
||||
) -> ParseTableIdentifierBatchResponse:
|
||||
shared_results = parse_table_identifier_batch(
|
||||
[
|
||||
SharedParseTableIdentifierItem(
|
||||
key=item.key,
|
||||
sql_table_name=item.sql_table_name,
|
||||
dialect=item.dialect,
|
||||
)
|
||||
for item in request.items
|
||||
]
|
||||
)
|
||||
return ParseTableIdentifierBatchResponse(
|
||||
results={
|
||||
key: ParsedIdentifier.model_validate(asdict(value))
|
||||
for key, value in shared_results.items()
|
||||
}
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue