rename klo to ktx

This commit is contained in:
Andrey Avtomonov 2026-05-10 23:51:24 +02:00
parent 1a42152e6f
commit 3ce510b55b
704 changed files with 10205 additions and 10255 deletions

View file

@ -0,0 +1,6 @@
"""Portable compute package for KTX."""
# Name constant for this package, exposed to importers.
PACKAGE_NAME = "ktx-daemon"
# Version string for this package.
VERSION = "0.1.0"
# Public names exported by a star-import of the package.
__all__ = ["PACKAGE_NAME", "VERSION"]

View file

@ -0,0 +1,172 @@
"""Command entry point for one-shot KTX daemon compute operations."""
from __future__ import annotations
import argparse
import json
import sys
from typing import Any
from pydantic import ValidationError
from ktx_daemon.code_execution import ExecuteCodeRequest, execute_code_response
from ktx_daemon.database_introspection import (
DatabaseIntrospectionRequest,
introspect_database_response,
)
from ktx_daemon.embeddings import (
ComputeEmbeddingBulkRequest,
ComputeEmbeddingRequest,
compute_embedding_bulk_response,
compute_embedding_response,
)
from ktx_daemon.lookml import ParseLookMLRequest, parse_lookml_project
from ktx_daemon.semantic_layer import (
SemanticLayerQueryRequest,
ValidateSourcesRequest,
query_semantic_layer,
validate_semantic_layer,
)
from ktx_daemon.source_generation import (
GenerateSourcesRequest,
generate_sources_response,
)
def build_parser() -> argparse.ArgumentParser:
    """Build the ``ktx-daemon`` command-line parser.

    Every compute operation is registered as a required sub-command stored in
    ``args.command``. Only ``serve-http`` takes extra options (host, port,
    log level and the code-execution switch).

    Returns:
        The fully configured argument parser.
    """
    root = argparse.ArgumentParser(prog="ktx-daemon")
    commands = root.add_subparsers(dest="command", required=True)

    # Sub-commands without options of their own, in registration order.
    plain_commands = (
        ("semantic-query", "Compile a semantic-layer query"),
        ("semantic-validate", "Validate semantic-layer sources"),
        (
            "semantic-generate-sources",
            "Generate semantic-layer sources from schema scan data",
        ),
        ("database-introspect", "Introspect a Postgres database schema"),
        ("lookml-parse", "Parse LookML files into KSL-ready structures"),
        ("embedding-compute", "Compute one local text embedding"),
        ("embedding-compute-bulk", "Compute local text embeddings in bulk"),
        (
            "code-execute",
            "Execute Python code with the current in-process boundary",
        ),
    )
    for command_name, help_text in plain_commands:
        commands.add_parser(command_name, help=help_text)

    http_command = commands.add_parser(
        "serve-http",
        help="Run the KTX daemon portable compute HTTP server",
    )
    http_command.add_argument("--host", default="127.0.0.1")
    http_command.add_argument("--port", type=int, default=8765)
    http_command.add_argument(
        "--log-level",
        default="info",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
    )
    http_command.add_argument(
        "--enable-code-execution",
        action="store_true",
        help="Expose POST /code/execute on the HTTP server",
    )
    return root
def _read_stdin_json() -> dict[str, Any]:
raw = sys.stdin.read()
parsed = json.loads(raw)
if not isinstance(parsed, dict):
raise ValueError("stdin JSON must be an object")
return parsed
def run_http_server(
    *,
    host: str,
    port: int,
    log_level: str,
    enable_code_execution: bool,
) -> None:
    """Serve the KTX daemon HTTP app with uvicorn.

    Args:
        host: Interface address passed to uvicorn.
        port: TCP port passed to uvicorn.
        log_level: uvicorn log level name.
        enable_code_execution: Forwarded to ``create_app``.
    """
    # Imported inside the function, so uvicorn and the app module are only
    # loaded when the server is actually started.
    import uvicorn
    from ktx_daemon.app import create_app

    application = create_app(enable_code_execution=enable_code_execution)
    uvicorn.run(application, host=host, port=port, log_level=log_level)
def main(argv: list[str] | None = None) -> int:
    """Run the ``ktx-daemon`` command line and return a process exit code.

    ``serve-http`` starts the HTTP server. Every other command reads one JSON
    object from stdin, validates it against the command's request model, runs
    the operation and writes the response as one JSON line on stdout.

    Args:
        argv: Argument list to parse; ``None`` lets argparse use ``sys.argv``.

    Returns:
        ``0`` on success, ``1`` when reading, validating or running the
        request fails (the error text is written to stderr).
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.command == "serve-http":
        run_http_server(
            host=args.host,
            port=args.port,
            log_level=args.log_level,
            enable_code_execution=args.enable_code_execution,
        )
        return 0

    # command name -> (request model, callable taking the validated request)
    operations = {
        "semantic-query": (SemanticLayerQueryRequest, query_semantic_layer),
        "semantic-validate": (ValidateSourcesRequest, validate_semantic_layer),
        "semantic-generate-sources": (
            GenerateSourcesRequest,
            generate_sources_response,
        ),
        "database-introspect": (
            DatabaseIntrospectionRequest,
            introspect_database_response,
        ),
        "lookml-parse": (ParseLookMLRequest, parse_lookml_project),
        "embedding-compute": (ComputeEmbeddingRequest, compute_embedding_response),
        "embedding-compute-bulk": (
            ComputeEmbeddingBulkRequest,
            compute_embedding_bulk_response,
        ),
        "code-execute": (
            ExecuteCodeRequest,
            # The one-shot CLI runs without a host API URL or credentials.
            lambda request: execute_code_response(
                request,
                nest_api_url=None,
                auth_header=None,
            ),
        ),
    }

    try:
        payload = _read_stdin_json()
        operation = operations.get(args.command)
        if operation is None:
            # parser.error() exits the process itself; the return only
            # documents the intended status.
            parser.error(f"Unknown command: {args.command}")
            return 2
        request_model, handler = operation
        response = handler(request_model.model_validate(payload))
        sys.stdout.write(response.model_dump_json() + "\n")
        return 0
    except (json.JSONDecodeError, ValidationError, ValueError) as failure:
        # Input problems: report the message only.
        sys.stderr.write(f"{failure}\n")
        return 1
    except Exception as failure:
        # Anything else: include the exception type for diagnosis.
        sys.stderr.write(f"{type(failure).__name__}: {failure}\n")
        return 1
# Script entry point: exit the process with the status returned by main().
if __name__ == "__main__":
    raise SystemExit(main())

View file

@ -0,0 +1,228 @@
"""FastAPI app factory for the KTX daemon semantic compute server."""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from ktx_daemon.code_execution import (
ExecuteCodeRequest,
ExecuteCodeResponse,
dumps_numpy_json,
execute_code_response,
)
from ktx_daemon.database_introspection import (
DatabaseIntrospectionRequest,
DatabaseIntrospectionResponse,
introspect_database_response,
)
from ktx_daemon.embeddings import (
ComputeEmbeddingBulkRequest,
ComputeEmbeddingBulkResponse,
ComputeEmbeddingRequest,
ComputeEmbeddingResponse,
EmbeddingProvider,
compute_embedding_bulk_response,
compute_embedding_response,
)
from ktx_daemon.lookml import (
ParseLookMLRequest,
ParseLookMLResponse,
parse_lookml_project,
)
from ktx_daemon.semantic_layer import (
SemanticLayerQueryRequest,
SemanticLayerQueryResponse,
ValidateSourcesRequest,
ValidateSourcesResponse,
query_semantic_layer,
validate_semantic_layer,
)
from ktx_daemon.source_generation import (
GenerateSourcesRequest,
GenerateSourcesResponse,
generate_sources_response,
)
from ktx_daemon.table_identifier import (
ParseTableIdentifierBatchRequest,
ParseTableIdentifierBatchResponse,
parse_table_identifier_response,
)
logger = logging.getLogger(__name__)
class NumpyORJSONResponse(Response):
    """JSON response whose body is produced by ``dumps_numpy_json``."""

    media_type = "application/json"

    def render(self, content: Any) -> bytes:
        """Serialize ``content`` into the JSON bytes sent as the body."""
        encoded = dumps_numpy_json(content)
        return encoded
def create_app(
    *,
    embedding_provider: EmbeddingProvider | None = None,
    database_introspector: Callable[
        [DatabaseIntrospectionRequest], DatabaseIntrospectionResponse
    ]
    | None = None,
    enable_code_execution: bool = False,
) -> FastAPI:
    """Create the KTX daemon FastAPI application and register its routes.

    Args:
        embedding_provider: Provider forwarded to the embedding endpoints;
            ``None`` is passed through to the embedding helpers unchanged.
        database_introspector: Replacement for ``introspect_database_response``
            used by ``POST /database/introspect`` when given.
        enable_code_execution: When true, ``POST /code/execute`` is registered;
            otherwise that route does not exist.

    Returns:
        The configured FastAPI application.
    """
    # NOTE(review): every handler below is ``async def`` but calls its compute
    # function directly; if those calls block (database I/O, model inference,
    # exec), they would run on the event loop thread — confirm this is intended.
    app = FastAPI(
        title="KTX Daemon",
        description="Stateless portable compute server for KTX.",
        version="0.1.0",
    )

    # Liveness probe with a static payload.
    @app.get("/health")
    async def health() -> dict[str, str]:
        return {"status": "healthy"}

    # Database introspection: ValueError -> 400, any other failure -> 500.
    @app.post("/database/introspect", response_model=DatabaseIntrospectionResponse)
    async def database_introspect(
        request: DatabaseIntrospectionRequest,
    ) -> DatabaseIntrospectionResponse:
        try:
            # Prefer the injected introspector over the default implementation.
            introspector = database_introspector or introspect_database_response
            return introspector(request)
        except ValueError as error:
            logger.warning("Database introspection rejected: %s", error)
            raise HTTPException(status_code=400, detail=str(error)) from error
        except Exception as error:
            logger.exception("Database introspection failed: %s", error)
            raise HTTPException(
                status_code=500,
                detail=f"Database introspection failed: {error}",
            ) from error

    # Single embedding: ValueError -> 400, any other failure -> 500.
    @app.post("/embeddings/compute", response_model=ComputeEmbeddingResponse)
    async def embedding_compute(
        request: ComputeEmbeddingRequest,
    ) -> ComputeEmbeddingResponse:
        try:
            return compute_embedding_response(
                request,
                provider=embedding_provider,
            )
        except ValueError as error:
            logger.warning("Embedding compute rejected: %s", error)
            raise HTTPException(status_code=400, detail=str(error)) from error
        except Exception as error:
            logger.exception("Embedding compute failed: %s", error)
            raise HTTPException(
                status_code=500,
                detail=f"Embedding compute failed: {error}",
            ) from error

    # Bulk embeddings: same error mapping as the single-embedding route.
    @app.post(
        "/embeddings/compute-bulk",
        response_model=ComputeEmbeddingBulkResponse,
    )
    async def embedding_compute_bulk(
        request: ComputeEmbeddingBulkRequest,
    ) -> ComputeEmbeddingBulkResponse:
        try:
            return compute_embedding_bulk_response(
                request,
                provider=embedding_provider,
            )
        except ValueError as error:
            logger.warning("Bulk embedding compute rejected: %s", error)
            raise HTTPException(status_code=400, detail=str(error)) from error
        except Exception as error:
            logger.exception("Bulk embedding compute failed: %s", error)
            raise HTTPException(
                status_code=500,
                detail=f"Bulk embedding compute failed: {error}",
            ) from error

    # The code-execution route is opt-in and absent unless explicitly enabled.
    if enable_code_execution:

        @app.post(
            "/code/execute",
            response_model=ExecuteCodeResponse,
            response_class=NumpyORJSONResponse,
        )
        async def code_execute(request: ExecuteCodeRequest) -> ExecuteCodeResponse:
            try:
                # No host API URL or Authorization header is supplied here.
                return execute_code_response(
                    request,
                    nest_api_url=None,
                    auth_header=None,
                )
            except Exception as error:
                logger.exception("Code execution failed: %s", error)
                raise HTTPException(
                    status_code=500,
                    detail=f"Code execution failed: {error}",
                ) from error

    # LookML parsing: every failure is reported as 500.
    @app.post("/lookml/parse", response_model=ParseLookMLResponse)
    async def lookml_parse(request: ParseLookMLRequest) -> ParseLookMLResponse:
        try:
            return parse_lookml_project(request)
        except Exception as error:
            logger.exception("LookML parsing failed: %s", error)
            raise HTTPException(
                status_code=500,
                detail=f"LookML parsing failed: {error}",
            ) from error

    # Table identifier parsing: every failure is reported as 500.
    @app.post(
        "/sql/parse-table-identifier",
        response_model=ParseTableIdentifierBatchResponse,
    )
    async def sql_parse_table_identifier(
        request: ParseTableIdentifierBatchRequest,
    ) -> ParseTableIdentifierBatchResponse:
        try:
            return parse_table_identifier_response(request)
        except Exception as error:
            logger.exception("Table identifier parsing failed: %s", error)
            raise HTTPException(
                status_code=500,
                detail=f"Table identifier parsing failed: {error}",
            ) from error

    # Semantic source generation: every failure is reported as 500.
    @app.post(
        "/semantic-layer/generate-sources", response_model=GenerateSourcesResponse
    )
    async def semantic_generate_sources(
        request: GenerateSourcesRequest,
    ) -> GenerateSourcesResponse:
        try:
            return generate_sources_response(request)
        except Exception as error:
            logger.exception("Semantic source generation failed: %s", error)
            raise HTTPException(
                status_code=500,
                detail=f"Semantic source generation failed: {error}",
            ) from error

    # Semantic query: ValueError -> 400, any other failure -> 500.
    @app.post("/semantic-layer/query", response_model=SemanticLayerQueryResponse)
    async def semantic_query(
        request: SemanticLayerQueryRequest,
    ) -> SemanticLayerQueryResponse:
        try:
            return query_semantic_layer(request)
        except ValueError as error:
            logger.warning("Semantic query rejected: %s", error)
            raise HTTPException(status_code=400, detail=str(error)) from error
        except Exception as error:
            logger.exception("Semantic query failed: %s", error)
            raise HTTPException(
                status_code=500,
                detail=f"Semantic layer query failed: {error}",
            ) from error

    # NOTE(review): unlike the other routes this one has no try/except, so any
    # exception propagates to the framework's default error handling. It
    # presumably reports problems inside the response body — confirm.
    @app.post("/semantic-layer/validate", response_model=ValidateSourcesResponse)
    async def semantic_validate(
        request: ValidateSourcesRequest,
    ) -> ValidateSourcesResponse:
        return validate_semantic_layer(request)

    return app

View file

@ -0,0 +1,333 @@
"""Portable in-process code execution helpers for KTX daemon.
This module preserves the host application's current Python execution behavior.
It runs code with Python ``exec`` in the current process and does not provide
OS-level sandboxing.
"""
from __future__ import annotations
import json
import logging
import re
import sys
from collections.abc import Callable
from io import BytesIO, StringIO
from typing import Any
import numpy as np
import orjson
import pandas as pd
import requests
from pydantic import BaseModel, Field
# Module-level logger for code-execution diagnostics.
logger = logging.getLogger(__name__)
# Visualization types accepted by the ``save_visualization`` helper.
VALID_VISUALIZATION_TYPES = ["pie", "bar", "line", "area", "table", "boxplot"]
class ExecuteCodeRequest(BaseModel):
    """Request schema for executing Python code."""

    # Only ``code`` is required; both identifiers default to None.
    code: str = Field(..., description="Python code to execute")
    # Forwarded to the scratchpad helpers as their source ID.
    source_id: str | None = Field(
        None,
        description="Chat/dashboard ID for scratchpad file access",
    )
    # Forwarded to the visualization helper, which requires it.
    message_id: str | None = Field(
        None,
        description="Message ID for visualization association",
    )
class VisualizationSpec(BaseModel):
    """Specification for a visualization to be saved by the host application."""

    # Marker that identifies a dict as a visualization spec.
    type: str = Field(..., description="Type marker, always 'visualization'")
    # The description lists every entry of VALID_VISUALIZATION_TYPES; it
    # previously omitted "boxplot".
    vis_type: str = Field(
        ...,
        description="Visualization type: pie, bar, line, area, table, boxplot",
    )
    config: dict[str, Any] = Field(
        ...,
        description="Visualization configuration",
    )
    data: list[dict[str, Any]] = Field(
        ...,
        description="Visualization data",
    )
    title: str | None = Field(None, description="Optional title")
class ExecuteCodeResponse(BaseModel):
    """Response schema for code execution."""

    # Display text built by ``format_execution_result``.
    formatted_result: str = Field(
        ...,
        description="Formatted execution result for display",
    )
    # The remaining fields mirror the optional keys of the dictionary returned
    # by ``execute_code``; a key that is absent there stays None here.
    result: Any | None = Field(
        None,
        description="The value of the 'result' variable if set",
    )
    console_output: str | None = Field(
        None,
        description="Captured stdout from print statements",
    )
    error: str | None = Field(None, description="Error message if execution failed")
    message: str | None = Field(
        None,
        description="Message if no clear result was returned",
    )
    visualizations: list[VisualizationSpec] | None = Field(
        None,
        description="List of visualizations detected in the result",
    )
# Helper triple returned by ``create_scratchpad_helpers``, in this order:
# (save_df_to_scratchpad, read_scratchpad_file, save_visualization).
ScratchpadHelpers = tuple[
    Callable[[pd.DataFrame, str | None], str],
    Callable[[str], pd.DataFrame],
    Callable[[str, dict[str, Any], list[dict[str, Any]]], str],
]
def dumps_numpy_json(content: Any) -> bytes:
    """Encode ``content`` as JSON bytes using orjson's numpy serialization option."""
    numpy_option = orjson.OPT_SERIALIZE_NUMPY
    return orjson.dumps(content, option=numpy_option)
def _strip_ansi_sequences(text: str) -> str:
ansi_escape = re.compile(
r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\([0-9;]*[a-zA-Z]|\x1b\[[0-9;]*~"
)
return ansi_escape.sub("", text)
def create_scratchpad_helpers(
    nest_api_url: str | None,
    auth_header: str | None,
    source_id: str | None,
    message_id: str | None = None,
    http_client: Any = requests,
) -> ScratchpadHelpers:
    """Build the scratchpad and visualization helpers bound to one host context.

    The returned closures call the host application's private API with
    ``http_client``. Each one validates its context when called, not when it
    is created, so helpers can be built even if the context is incomplete.

    Args:
        nest_api_url: Base URL of the host application API.
        auth_header: Value sent as the ``Authorization`` header.
        source_id: Identifier placed in the scratchpad/visualization URLs.
        message_id: Identifier attached to saved visualizations.
        http_client: Object providing ``get`` and ``post``; defaults to the
            ``requests`` module.

    Returns:
        ``(save_df_to_scratchpad, read_scratchpad_file, save_visualization)``.
    """

    def _require_host_context(operation_kind: str) -> None:
        # All helpers need the API URL, the credentials and a source ID.
        if nest_api_url and auth_header and source_id:
            return
        raise ValueError(
            "nest_api_url, Authorization header, and source_id are required "
            f"for {operation_kind} operations"
        )

    def _json_headers() -> dict[str, Any]:
        return {"Authorization": auth_header, "Content-Type": "application/json"}

    def save_df_to_scratchpad(df: pd.DataFrame, filename: str | None = None) -> str:
        """Upload ``df`` as JSON records and return a one-line summary."""
        _require_host_context("scratchpad")
        records = df.to_dict(orient="records")
        body = dumps_numpy_json(
            {"filename": filename, "data": records, "format": "json"}
        )
        reply = http_client.post(
            f"{nest_api_url}/private_api/scratchpad/{source_id}/files",
            data=body,
            headers=_json_headers(),
            timeout=30,
        )
        reply.raise_for_status()
        stored_name = reply.json()["filename"]
        row_count, _ = df.shape
        return f"{row_count} rows saved to {stored_name}"

    def read_scratchpad_file(filename: str) -> pd.DataFrame:
        """Download a scratchpad file and load it into a DataFrame."""
        _require_host_context("scratchpad")
        reply = http_client.get(
            f"{nest_api_url}/private_api/scratchpad/{source_id}/files/{filename}?format=raw",
            headers={"Authorization": auth_header, "Accept": "text/csv"},
            timeout=30,
        )
        reply.raise_for_status()
        # CSV replies are parsed directly; anything else is treated as a JSON
        # document with the rows under "data".
        if "text/csv" in reply.headers.get("content-type", ""):
            return pd.read_csv(BytesIO(reply.content))
        return pd.DataFrame(reply.json()["data"])

    def save_visualization(
        vis_type: str,
        config: dict[str, Any],
        data: list[dict[str, Any]],
    ) -> str:
        """Save a visualization through the host API and return a markdown image link."""
        _require_host_context("visualization")
        if not message_id:
            raise ValueError("message_id is required for visualization operations")
        if vis_type not in VALID_VISUALIZATION_TYPES:
            raise ValueError(
                f"Invalid visualization type: {vis_type}. Must be one of {VALID_VISUALIZATION_TYPES}"
            )
        body = dumps_numpy_json(
            {
                "visualizationType": vis_type,
                "config": config,
                "data": data,
                "messageId": message_id,
            }
        )
        reply = http_client.post(
            f"{nest_api_url}/private_api/visualizations/{source_id}",
            data=body,
            headers=_json_headers(),
            timeout=30,
        )
        reply.raise_for_status()
        stored_name = reply.json()["filename"]
        print(f"Visualization saved: {stored_name}")
        return f"![viz]({stored_name})"

    return save_df_to_scratchpad, read_scratchpad_file, save_visualization
def detect_visualizations(result: Any) -> list[dict[str, Any]]:
    """Collect visualization specs from an execution result.

    A spec is a dict whose ``"type"`` entry equals ``"visualization"``. The
    result may be one spec or a list containing specs; list items that are
    not specs are ignored.

    Returns:
        The matching dicts in their original order, or an empty list.
    """

    def is_spec(candidate: Any) -> bool:
        return isinstance(candidate, dict) and candidate.get("type") == "visualization"

    if is_spec(result):
        return [result]
    if isinstance(result, list):
        return [item for item in result if is_spec(item)]
    return []
def execute_code(
    code: str,
    nest_api_url: str | None = None,
    auth_header: str | None = None,
    source_id: str | None = None,
    message_id: str | None = None,
    scratchpad_helpers: ScratchpadHelpers | None = None,
) -> dict[str, Any]:
    """Execute Python code with the current in-process execution boundary.

    The code runs through ``exec`` in a fresh namespace that exposes ``pd``,
    ``np``, ``json``, ``requests`` and the three scratchpad helpers. While it
    runs, ``sys.stdout`` is replaced process-wide by an in-memory buffer and
    restored afterwards.

    Args:
        code: Python source to execute.
        nest_api_url: Host API base URL used to build the default helpers.
        auth_header: Authorization header value for the default helpers.
        source_id: Source identifier for the default helpers.
        message_id: Message identifier for the default helpers.
        scratchpad_helpers: Pre-built helper triple; when given, the four
            context arguments above are not used.

    Returns:
        A dictionary containing exactly one of ``"result"`` (value of the
        ``result`` variable), ``"message"`` (no ``result`` was set) or
        ``"error"`` (stringified exception). ``"console_output"`` is added
        when anything was printed, including output printed before an
        exception, and ``"visualizations"`` when the result contains
        visualization specs.
    """
    logger.info("Starting code execution")
    save_df, read_file, save_viz = scratchpad_helpers or create_scratchpad_helpers(
        nest_api_url,
        auth_header,
        source_id,
        message_id,
    )
    namespace: dict[str, Any] = {
        "pd": pd,
        "np": np,
        "json": json,
        "requests": requests,
        "save_df_to_scratchpad": save_df,
        "read_scratchpad_file": read_file,
        "save_visualization": save_viz,
    }
    stdout_capture = StringIO()

    def _with_console_output(outcome: dict[str, Any]) -> dict[str, Any]:
        # Read the buffer at the time the outcome is built, so output printed
        # before a failure is kept as well.
        captured = stdout_capture.getvalue()
        if captured:
            outcome["console_output"] = captured
        return outcome

    original_stdout = sys.stdout
    sys.stdout = stdout_capture
    try:
        logger.info("Executing code in current process namespace")
        # SECURITY: exec runs the code with the full privileges of this
        # process; there is no sandbox (see the module docstring).
        exec(code, namespace)
        if "result" in namespace:
            logger.info("Code execution complete, 'result' variable found")
            result_value = namespace["result"]
            visualizations = detect_visualizations(result_value)
            outcome = _with_console_output({"result": result_value})
            if visualizations:
                outcome["visualizations"] = visualizations
            return outcome
        logger.info("No result variable found")
        return _with_console_output(
            {"message": "Code executed successfully but no result variable was set"}
        )
    except Exception as error:
        logger.exception("Error executing code: %s", error)
        # Previously the captured output was only read after a successful
        # exec, so anything printed before the exception was lost here.
        return _with_console_output({"error": str(error)})
    finally:
        sys.stdout = original_stdout
def format_execution_result(result: dict[str, Any]) -> str:
    """Render an ``execute_code`` outcome as display text.

    The text has an optional console section (with ANSI escape sequences
    stripped) followed by at most one of the result, message or error
    sections, checked in that order.

    Returns:
        The concatenated sections; an empty string when no key is present.
    """
    sections: list[str] = []
    if "console_output" in result:
        sections.append("=== Console Output ===\n\n")
        sections.append(_strip_ansi_sequences(result["console_output"]))

    if "result" in result:
        sections.append("\n\n=== Result ===\n\n")
        sections.append(str(result["result"]))
    elif "message" in result:
        sections.append("\n\n=== Message ===\n\n")
        sections.append(result["message"])
    elif "error" in result:
        sections.append("\n\n=== Error ===\n\n")
        sections.append(result["error"])
    return "".join(sections)
def execute_code_response(
    request: ExecuteCodeRequest,
    *,
    nest_api_url: str | None,
    auth_header: str | None,
) -> ExecuteCodeResponse:
    """Run the code of a validated request and wrap the outcome in the response model.

    Args:
        request: Validated execution request.
        nest_api_url: Host API base URL forwarded to ``execute_code``.
        auth_header: Authorization header value forwarded to ``execute_code``.

    Returns:
        The response model; keys missing from the outcome become ``None``.
    """
    outcome = execute_code(
        code=request.code,
        nest_api_url=nest_api_url,
        auth_header=auth_header,
        source_id=request.source_id,
        message_id=request.message_id,
    )
    optional_keys = ("result", "console_output", "error", "message", "visualizations")
    optional_fields = {key: outcome.get(key) for key in optional_keys}
    return ExecuteCodeResponse(
        formatted_result=format_execution_result(outcome),
        **optional_fields,
    )

View file

@ -0,0 +1,284 @@
"""Portable database introspection helpers for KTX daemon."""
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel, Field, field_validator
# Lists the base tables of the requested schemas together with their comments.
# The single %s placeholder is bound to the list of schema names.
TABLES_SQL = """
select
t.table_catalog,
t.table_schema,
t.table_name,
obj_description(c.oid) as table_comment
from information_schema.tables t
join pg_catalog.pg_namespace n
on n.nspname = t.table_schema
join pg_catalog.pg_class c
on c.relnamespace = n.oid
and c.relname = t.table_name
where t.table_schema = any(%s)
and t.table_type = 'BASE TABLE'
order by t.table_schema, t.table_name
"""
# One row per live column of ordinary ('r') and partitioned ('p') tables:
# name, formatted type, nullability, primary-key membership and comment.
# The single %s placeholder is bound to the list of schema names.
COLUMNS_SQL = """
select
current_database() as table_catalog,
n.nspname as table_schema,
c.relname as table_name,
a.attname as column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as formatted_type,
not a.attnotnull as is_nullable,
exists (
select 1
from pg_catalog.pg_index i
where i.indrelid = c.oid
and i.indisprimary
and a.attnum = any(i.indkey)
) as is_primary_key,
pg_catalog.col_description(c.oid, a.attnum) as column_comment
from pg_catalog.pg_attribute a
join pg_catalog.pg_class c
on c.oid = a.attrelid
join pg_catalog.pg_namespace n
on n.oid = c.relnamespace
where n.nspname = any(%s)
and c.relkind in ('r', 'p')
and a.attnum > 0
and not a.attisdropped
order by n.nspname, c.relname, a.attnum
"""
# One row per foreign-key column, pairing the source column with the
# referenced table and column. The single %s placeholder is bound to the list
# of schema names.
# NOTE(review): only the referenced table name is selected, not its schema, and
# the joins match constraints by catalog/schema/name — confirm this is
# sufficient when constraint names repeat within a schema.
FOREIGN_KEYS_SQL = """
select
current_database() as table_catalog,
source_constraint.table_schema,
source_constraint.table_name,
source_key.column_name as from_column,
target_key.table_name as to_table,
target_key.column_name as to_column,
source_constraint.constraint_name
from information_schema.table_constraints source_constraint
join information_schema.key_column_usage source_key
on source_key.constraint_catalog = source_constraint.constraint_catalog
and source_key.constraint_schema = source_constraint.constraint_schema
and source_key.constraint_name = source_constraint.constraint_name
join information_schema.referential_constraints ref_constraint
on ref_constraint.constraint_catalog = source_constraint.constraint_catalog
and ref_constraint.constraint_schema = source_constraint.constraint_schema
and ref_constraint.constraint_name = source_constraint.constraint_name
join information_schema.key_column_usage target_key
on target_key.constraint_catalog = ref_constraint.unique_constraint_catalog
and target_key.constraint_schema = ref_constraint.unique_constraint_schema
and target_key.constraint_name = ref_constraint.unique_constraint_name
and target_key.ordinal_position = source_key.position_in_unique_constraint
where source_constraint.constraint_type = 'FOREIGN KEY'
and source_constraint.table_schema = any(%s)
order by source_constraint.table_schema, source_constraint.table_name, source_constraint.constraint_name, source_key.ordinal_position
"""
class LiveDatabaseColumn(BaseModel):
    # One column of an introspected table.
    name: str
    # Type text as produced by the column query's ``formatted_type`` field.
    type: str
    nullable: bool = True
    primary_key: bool = False
    comment: str | None = None
class LiveDatabaseForeignKey(BaseModel):
    # One foreign-key column pair: ``from_column`` of the owning table
    # references ``to_table.to_column``.
    from_column: str
    to_table: str
    to_column: str
    constraint_name: str | None = None
class LiveDatabaseTable(BaseModel):
    # One introspected table with its columns and foreign keys.
    catalog: str | None = None
    # Filled from the ``table_schema`` field of the introspection rows.
    db: str | None = None
    name: str
    comment: str | None = None
    columns: list[LiveDatabaseColumn] = Field(default_factory=list)
    foreign_keys: list[LiveDatabaseForeignKey] = Field(default_factory=list)
class DatabaseIntrospectionRequest(BaseModel):
    # Echoed back unchanged in the response.
    connection_id: str
    # Only "postgres"/"postgresql" (case-insensitive, trimmed) is accepted by
    # ``introspect_database_response``.
    driver: str = "postgres"
    # Connection string handed to the database driver.
    url: str
    # Schemas to inspect; defaults to a fresh ["public"] list per instance.
    schemas: list[str] = Field(default_factory=lambda: ["public"])
    # Applied to the session through ``statement_timeout``; must be >= 1.
    statement_timeout_ms: int = Field(default=30_000, ge=1)
    # Passed to the driver as the connect timeout; must be >= 1.
    connection_timeout_seconds: int = Field(default=5, ge=1)

    @field_validator("schemas")
    @classmethod
    def _schemas_must_not_be_empty(cls, value: list[str]) -> list[str]:
        # Reject an explicitly empty schema list.
        if not value:
            raise ValueError("database introspection requires at least one schema")
        return value
class DatabaseIntrospectionResponse(BaseModel):
    # Copied from the request.
    connection_id: str
    # Timestamp string; an ISO-8601 UTC time unless a custom clock is injected.
    extracted_at: str
    # Holds the normalized driver name and the inspected schemas.
    metadata: dict[str, Any]
    # Tables sorted by catalog, schema and name.
    tables: list[LiveDatabaseTable]
@dataclass(frozen=True)
class DatabaseIntrospectionRows:
    """Raw result rows of the three introspection queries."""

    # Rows of the table, column and foreign-key queries, as mappings keyed by
    # the selected column names.
    table_rows: Sequence[Mapping[str, Any]]
    column_rows: Sequence[Mapping[str, Any]]
    foreign_key_rows: Sequence[Mapping[str, Any]]
# Callable that fetches the raw introspection rows for a request.
DatabaseRowsLoader = Callable[[DatabaseIntrospectionRequest], DatabaseIntrospectionRows]
# Callable that returns the timestamp string stored in ``extracted_at``.
NowProvider = Callable[[], str]
def _driver_name(driver: str) -> str:
return driver.strip().lower()
def _table_key(catalog: str | None, db: str | None, name: str) -> str:
return f"{catalog or ''}\u0000{db or ''}\u0000{name}"
def _optional_string(row: Mapping[str, Any], key: str) -> str | None:
value = row.get(key)
return value if isinstance(value, str) else None
def _required_string(row: Mapping[str, Any], key: str) -> str:
value = row.get(key)
if not isinstance(value, str) or not value:
raise ValueError(f"database introspection row is missing string field {key}")
return value
def _statement_timeout_config(statement_timeout_ms: int) -> tuple[str, tuple[str]]:
return (
"SELECT set_config('statement_timeout', %s, true)",
(f"{int(statement_timeout_ms)}ms",),
)
def _load_postgres_rows(
    request: DatabaseIntrospectionRequest,
) -> DatabaseIntrospectionRows:
    """Fetch table, column and foreign-key rows from Postgres.

    The three catalog queries run in one explicit read-only transaction with
    the requested statement timeout. The connection is always closed.

    Args:
        request: Supplies the connection URL, timeouts and schema list.

    Returns:
        The raw rows of the three queries, each as a list of dict rows.

    Raises:
        RuntimeError: If psycopg is not installed.
        Exception: Any driver error is re-raised after a ROLLBACK.
    """
    try:
        import psycopg
        from psycopg.rows import dict_row
    except ImportError as error:
        raise RuntimeError(
            "psycopg is required for Postgres database introspection"
        ) from error
    # autocommit=True keeps psycopg from opening its own implicit transaction
    # before the first statement. Without it, the BEGIN READ ONLY below would
    # run inside an already-open transaction and have no effect.
    connection = psycopg.connect(
        request.url,
        autocommit=True,
        connect_timeout=request.connection_timeout_seconds,
        application_name="ktx-daemon-database-introspection",
        row_factory=dict_row,
    )
    try:
        connection.execute("BEGIN READ ONLY")
        try:
            # The timeout is set with is_local=true, so it lasts only for this
            # transaction.
            connection.execute(*_statement_timeout_config(request.statement_timeout_ms))
            # Every query takes the schema list as its single parameter.
            params = (request.schemas,)
            table_rows = list(connection.execute(TABLES_SQL, params))
            column_rows = list(connection.execute(COLUMNS_SQL, params))
            foreign_key_rows = list(connection.execute(FOREIGN_KEYS_SQL, params))
            connection.execute("COMMIT")
        except Exception:
            connection.execute("ROLLBACK")
            raise
    finally:
        connection.close()
    return DatabaseIntrospectionRows(
        table_rows=table_rows,
        column_rows=column_rows,
        foreign_key_rows=foreign_key_rows,
    )
def _map_rows_to_tables(rows: DatabaseIntrospectionRows) -> list[LiveDatabaseTable]:
    """Assemble table models from the raw introspection rows.

    Table rows create the tables. Column and foreign-key rows are attached to
    the table with the same catalog, schema and name, and are skipped when no
    such table exists.

    Returns:
        The tables ordered by their catalog/schema/name key.

    Raises:
        ValueError: If a row lacks a required non-empty string field.
    """

    def key_of(row: Mapping[str, Any]) -> str:
        # The catalog may be absent; schema and table name are mandatory.
        return _table_key(
            _optional_string(row, "table_catalog"),
            _required_string(row, "table_schema"),
            _required_string(row, "table_name"),
        )

    tables_by_key: dict[str, LiveDatabaseTable] = {}
    for table_row in rows.table_rows:
        tables_by_key[key_of(table_row)] = LiveDatabaseTable(
            catalog=_optional_string(table_row, "table_catalog"),
            db=_required_string(table_row, "table_schema"),
            name=_required_string(table_row, "table_name"),
            comment=_optional_string(table_row, "table_comment"),
        )

    for column_row in rows.column_rows:
        owner = tables_by_key.get(key_of(column_row))
        if owner is not None:
            owner.columns.append(
                LiveDatabaseColumn(
                    name=_required_string(column_row, "column_name"),
                    type=_required_string(column_row, "formatted_type"),
                    nullable=bool(column_row.get("is_nullable")),
                    primary_key=bool(column_row.get("is_primary_key")),
                    comment=_optional_string(column_row, "column_comment"),
                )
            )

    for fk_row in rows.foreign_key_rows:
        owner = tables_by_key.get(key_of(fk_row))
        if owner is not None:
            owner.foreign_keys.append(
                LiveDatabaseForeignKey(
                    from_column=_required_string(fk_row, "from_column"),
                    to_table=_required_string(fk_row, "to_table"),
                    to_column=_required_string(fk_row, "to_column"),
                    constraint_name=_optional_string(fk_row, "constraint_name"),
                )
            )

    # Each dictionary key is the table's own catalog/schema/name key, so
    # ordering by dictionary key orders the tables themselves.
    return [tables_by_key[key] for key in sorted(tables_by_key)]
def introspect_database_response(
    request: DatabaseIntrospectionRequest,
    *,
    load_rows: DatabaseRowsLoader | None = None,
    now: NowProvider | None = None,
) -> DatabaseIntrospectionResponse:
    """Introspect the database described by ``request``.

    Args:
        request: Connection details and the schemas to inspect.
        load_rows: Row loader to use instead of the Postgres loader.
        now: Clock returning the ``extracted_at`` string; defaults to the
            current UTC time in ISO-8601 format.

    Returns:
        The response holding the mapped tables and request metadata.

    Raises:
        ValueError: If the normalized driver is neither ``postgres`` nor
            ``postgresql``.
    """
    driver = _driver_name(request.driver)
    if driver not in {"postgres", "postgresql"}:
        raise ValueError('database introspection supports only driver "postgres"')

    loader = load_rows or _load_postgres_rows
    rows = loader(request)
    if now:
        extracted_at = now()
    else:
        extracted_at = datetime.now(timezone.utc).isoformat()
    metadata = {"driver": driver, "schemas": list(request.schemas)}
    return DatabaseIntrospectionResponse(
        connection_id=request.connection_id,
        extracted_at=extracted_at,
        metadata=metadata,
        tables=_map_rows_to_tables(rows),
    )

View file

@ -0,0 +1,172 @@
"""Portable embedding compute helpers for KTX daemon."""
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING, Protocol
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
# Module-level logger for embedding diagnostics.
logger = logging.getLogger(__name__)
# Model name loaded by the default sentence-transformers provider.
DEFAULT_SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"
# Vector size reported by the default provider.
DEFAULT_EMBEDDING_DIMENSIONS = 384
# Largest number of texts accepted in one bulk request.
DEFAULT_MAX_BATCH_SIZE = 100
class EmbeddingProvider(Protocol):
    """Provider interface for local embedding compute."""

    # Identifier of the provider implementation.
    @property
    def name(self) -> str: ...

    # Length of each embedding vector the provider reports.
    @property
    def dimensions(self) -> int: ...

    # Upper bound on the number of texts per call, enforced by the compute
    # helpers before ``encode`` is invoked.
    @property
    def max_batch_size(self) -> int: ...

    # Returns one embedding vector per input text.
    def encode(self, texts: list[str]) -> list[list[float]]: ...
class ComputeEmbeddingRequest(BaseModel):
    """Request schema for computing a single embedding."""

    # min_length=1 rejects the empty string at validation time.
    text: str = Field(..., description="Text to compute embedding for", min_length=1)
class ComputeEmbeddingResponse(BaseModel):
    """Response schema for single embedding computation."""

    # NOTE(review): the length is not validated here; 384 presumably reflects
    # the default provider only — confirm for custom providers.
    embedding: list[float] = Field(..., description="384-dimensional embedding vector")
class ComputeEmbeddingBulkRequest(BaseModel):
    """Request schema for computing multiple embeddings."""

    # Between 1 and DEFAULT_MAX_BATCH_SIZE texts per request.
    texts: list[str] = Field(
        ...,
        description="List of texts to compute embeddings for",
        min_length=1,
        max_length=DEFAULT_MAX_BATCH_SIZE,
    )
class ComputeEmbeddingBulkResponse(BaseModel):
    """Response schema for bulk embedding computation."""

    # Filled with the provider's ``encode`` output for the request texts.
    embeddings: list[list[float]] = Field(
        ...,
        description="List of 384-dimensional embedding vectors",
    )
class SentenceTransformersEmbeddingProvider:
    """Embedding provider backed by a sentence-transformers model.

    The model is created on first use unless one is passed to the constructor.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_SENTENCE_TRANSFORMER_MODEL,
        model: SentenceTransformer | None = None,
    ) -> None:
        """Store the model name and an optional pre-built model."""
        self.model_name = model_name
        self._model = model
        self._model_lock = threading.Lock()

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "sentence-transformers"

    @property
    def dimensions(self) -> int:
        """Embedding size reported by this provider."""
        return DEFAULT_EMBEDDING_DIMENSIONS

    @property
    def max_batch_size(self) -> int:
        """Largest number of texts accepted per call."""
        return DEFAULT_MAX_BATCH_SIZE

    def _get_model(self) -> SentenceTransformer:
        """Return the model, loading it under the lock on first use."""
        if self._model is None:
            with self._model_lock:
                # Check again inside the lock before loading.
                if self._model is None:
                    from sentence_transformers import SentenceTransformer

                    logger.info(
                        "Loading SentenceTransformer model: %s", self.model_name
                    )
                    self._model = SentenceTransformer(self.model_name)
                    logger.info("SentenceTransformer model loaded successfully")
        return self._model

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Encode ``texts`` and return one list of floats per text."""
        model = self._get_model()
        if len(texts) == 1:
            # A lone text is encoded as a plain string, then wrapped so the
            # return shape matches the bulk case.
            vectors = [model.encode(texts[0]).tolist()]
        else:
            vectors = model.encode(texts).tolist()
        return [[float(component) for component in vector] for vector in vectors]
# Process-wide default provider, created on first request.
_default_provider: SentenceTransformersEmbeddingProvider | None = None
# Guards creation of the default provider.
_default_provider_lock = threading.Lock()
def get_default_embedding_provider() -> SentenceTransformersEmbeddingProvider:
    """Return the shared default embedding provider, creating it on first call."""
    global _default_provider
    if _default_provider is None:
        with _default_provider_lock:
            # Check again inside the lock before creating the provider.
            if _default_provider is None:
                _default_provider = SentenceTransformersEmbeddingProvider()
    return _default_provider
def _validate_texts(texts: list[str], max_batch_size: int) -> None:
if not texts:
raise ValueError("Texts array must not be empty")
if len(texts) > max_batch_size:
raise ValueError(f"Maximum {max_batch_size} texts allowed per batch")
empty_indices = [
index for index, text in enumerate(texts) if not text or not text.strip()
]
if empty_indices:
joined_indices = ", ".join(str(index) for index in empty_indices)
raise ValueError(f"Empty texts found at indices: {joined_indices}")
def compute_embedding_response(
    request: ComputeEmbeddingRequest,
    provider: EmbeddingProvider | None = None,
) -> ComputeEmbeddingResponse:
    """Compute one embedding from a request model.

    Uses *provider* when one is supplied, otherwise the process-wide default.
    The text is validated as a one-element batch before it is encoded.

    Raises:
        ValueError: if ``request.text`` fails ``_validate_texts``.
    """
    if provider:
        selected_provider = provider
    else:
        selected_provider = get_default_embedding_provider()

    batch = [request.text]
    _validate_texts(batch, selected_provider.max_batch_size)
    # A fresh one-element list is encoded; the single resulting vector is the answer.
    vectors = selected_provider.encode([request.text])
    return ComputeEmbeddingResponse(embedding=vectors[0])
def compute_embedding_bulk_response(
    request: ComputeEmbeddingBulkRequest,
    provider: EmbeddingProvider | None = None,
) -> ComputeEmbeddingBulkResponse:
    """Compute multiple embeddings from a request model.

    Uses *provider* when one is supplied, otherwise the process-wide default.
    The texts are validated against the provider's batch limit before encoding.

    Raises:
        ValueError: if ``request.texts`` fails ``_validate_texts``.
    """
    if provider:
        selected_provider = provider
    else:
        selected_provider = get_default_embedding_provider()

    _validate_texts(request.texts, selected_provider.max_batch_size)
    vectors = selected_provider.encode(request.texts)
    return ComputeEmbeddingBulkResponse(embeddings=vectors)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,136 @@
"""Semantic-layer compute helpers for the KTX daemon package."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field
from semantic_layer.duplicate_check import validate_measure_duplicates
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import QueryResult, SourceDefinition
class SemanticLayerQueryRequest(BaseModel):
    """Input accepted by ``query_semantic_layer``."""

    # Raw source definitions; each dict is expanded into a ``SourceDefinition``.
    sources: list[dict[str, Any]]
    # Query payload handed unchanged to ``SemanticEngine.query``.
    query: dict[str, Any]
    # Dialect name passed to ``SemanticEngine.from_sources``.
    dialect: str = "postgres"
class SemanticLayerQueryResponse(BaseModel):
    """Output returned by ``query_semantic_layer``."""

    # SQL text taken from the engine's query result.
    sql: str
    # Dialect reported by the engine's query result.
    dialect: str
    # Column metadata as produced by ``_response_columns``.
    columns: list[dict[str, Any]]
    # JSON-mode dump of the result's resolved plan.
    plan: dict[str, Any]
class ValidateSourcesRequest(BaseModel):
    """Input accepted by ``validate_semantic_layer``."""

    # Raw source definitions; each dict is parsed into a ``SourceDefinition``.
    sources: list[dict[str, Any]]
    # Dialect name passed to the engine and to the measure duplicate check.
    dialect: str = "postgres"
    # Optional names forwarded to ``SemanticEngine.validate`` as a set;
    # an empty or missing list is forwarded as ``None``.
    recently_touched: list[str] | None = None
class ValidateSourcesResponse(BaseModel):
    """Output returned by ``validate_semantic_layer``."""

    # True only when no errors were collected.
    valid: bool
    # Parse failures, duplicate names, and errors reported by the engine.
    errors: list[str] = Field(default_factory=list)
    # Warnings copied from the engine's validation report.
    warnings: list[str] = Field(default_factory=list)
    # Per-source warnings copied from the engine's validation report.
    per_source_warnings: dict[str, list[str]] = Field(default_factory=dict)
def _load_sources(raw_sources: list[dict[str, Any]]) -> dict[str, SourceDefinition]:
    """Parse raw source dicts into ``SourceDefinition`` objects keyed by name.

    Raises:
        ValueError: if two sources share the same name.

    Any exception raised by the ``SourceDefinition`` constructor propagates.
    """
    by_name: dict[str, SourceDefinition] = {}
    for payload in raw_sources:
        parsed = SourceDefinition(**payload)
        if parsed.name not in by_name:
            by_name[parsed.name] = parsed
            continue
        raise ValueError(f"Duplicate source name '{parsed.name}'")
    return by_name
def _validate_duplicate_measure_names(source: SourceDefinition) -> list[str]:
    """Return one error message per repeated measure name on *source*.

    The first occurrence of a name is accepted; every later occurrence
    produces a message, so a name used three times yields two messages.
    """
    problems: list[str] = []
    known_names: set[str] = set()
    for measure in source.measures:
        if measure.name not in known_names:
            known_names.add(measure.name)
        else:
            problems.append(
                f"Duplicate measure '{measure.name}' on source '{source.name}'"
            )
    return problems
def _response_columns(result: QueryResult) -> list[dict[str, Any]]:
    """Dump the result's columns to dicts, rewriting the ``name`` entry.

    For a column whose provenance value is ``"dimension"`` and whose ``expr``
    is truthy, ``name`` becomes that ``expr``. Otherwise, if the column name
    matches a measure in the resolved plan that has a truthy
    ``qualified_ref``, ``name`` becomes that reference. Other columns keep
    whatever ``model_dump`` produced.
    """
    qualified_by_measure: dict[str, Any] = {}
    for measure in result.resolved_plan.measures:
        if measure.qualified_ref:
            qualified_by_measure[measure.name] = measure.qualified_ref

    rendered: list[dict[str, Any]] = []
    for column in result.columns:
        payload = column.model_dump(mode="json")
        is_dimension = column.provenance.value == "dimension"
        if is_dimension and column.expr:
            payload["name"] = column.expr
        elif column.name in qualified_by_measure:
            payload["name"] = qualified_by_measure[column.name]
        rendered.append(payload)
    return rendered
def query_semantic_layer(
    request: SemanticLayerQueryRequest,
) -> SemanticLayerQueryResponse:
    """Run ``request.query`` through a ``SemanticEngine`` built from the request's sources.

    The response carries the result's SQL and dialect, the column metadata
    from ``_response_columns``, and a JSON-mode dump of the resolved plan.

    Raises:
        ValueError: if two sources share a name (from ``_load_sources``).

    Exceptions raised while parsing sources or by the engine propagate.
    """
    engine = SemanticEngine.from_sources(
        _load_sources(request.sources),
        dialect=request.dialect,
    )
    result = engine.query(request.query)
    response_fields = {
        "sql": result.sql,
        "dialect": result.dialect,
        "columns": _response_columns(result),
        "plan": result.resolved_plan.model_dump(mode="json"),
    }
    return SemanticLayerQueryResponse(**response_fields)
def validate_semantic_layer(request: ValidateSourcesRequest) -> ValidateSourcesResponse:
    """Validate raw source definitions and collect every problem found.

    Nothing is raised for invalid input: parse failures, duplicate source
    names, duplicate measure names, and any exception raised by the engine
    are all recorded as error strings. Sources that fail to parse or repeat
    an earlier name are left out of the engine-level validation.
    """
    errors: list[str] = []
    warnings: list[str] = []
    per_source_warnings: dict[str, list[str]] = {}
    parsed_sources: dict[str, SourceDefinition] = {}

    for raw_source in request.sources:
        # Captured before parsing so a failure can still be attributed to a name.
        declared_name = raw_source.get("name") if isinstance(raw_source, dict) else None
        try:
            source = SourceDefinition(**raw_source)
        except Exception as error:
            label = declared_name or "<unknown>"
            errors.append(f"Source '{label}' failed to parse: {error}")
            continue

        # The dict of accepted sources doubles as the set of names seen so far.
        if source.name in parsed_sources:
            errors.append(f"Duplicate source name '{source.name}'")
            continue

        parsed_sources[source.name] = source
        errors += _validate_duplicate_measure_names(source)

    if parsed_sources:
        try:
            engine = SemanticEngine.from_sources(
                parsed_sources, dialect=request.dialect
            )
            touched = (
                set(request.recently_touched) if request.recently_touched else None
            )
            report = engine.validate(recently_touched=touched)
            errors += report.errors
            warnings += report.warnings
            per_source_warnings.update(report.per_source_warnings)
            errors += validate_measure_duplicates(
                parsed_sources, dialect=request.dialect
            )
        except Exception as error:
            # Any engine failure becomes a validation error instead of propagating.
            errors.append(f"Validation failed: {error}")

    return ValidateSourcesResponse(
        valid=not errors,
        errors=errors,
        warnings=warnings,
        per_source_warnings=per_source_warnings,
    )

View file

@ -0,0 +1,254 @@
"""Generate ktx-sl YAML source definitions from database schema scan data."""
from __future__ import annotations
import logging
import re
from typing import Any
from pydantic import BaseModel
from semantic_layer.models import (
ColumnRole,
JoinDeclaration,
MeasureDefinition,
SourceColumn,
SourceDefinition,
)
logger = logging.getLogger(__name__)

# Unanchored, case-insensitive fragments: a database type counts as numeric,
# temporal, or boolean when the fragment occurs anywhere in its type name.
_NUMBER_PATTERN = re.compile(
    r"int|integer|bigint|smallint|tinyint|numeric|decimal|float|double|real|number|money",
    re.IGNORECASE,
)
_TIME_PATTERN = re.compile(
    r"timestamp|datetime|date|time(?!stamp)",
    re.IGNORECASE,
)
_BOOLEAN_PATTERN = re.compile(r"bool|boolean|bit", re.IGNORECASE)

# Column names that look like identifiers or keys; numeric columns matching
# this pattern get no sum/avg measures.
_ID_PATTERN = re.compile(
    r"^id$|_id$|^uuid$|_uuid$|_key$|_pk$|identifier$",
    re.IGNORECASE,
)

# Accepted spellings of a link's relationship type mapped to the lower-case form.
_RELATIONSHIP_MAP = {
    "MANY_TO_ONE": "many_to_one",
    "ONE_TO_MANY": "one_to_many",
    "ONE_TO_ONE": "one_to_one",
    "many_to_one": "many_to_one",
    "one_to_many": "one_to_many",
    "one_to_one": "one_to_one",
}

# Relationship as seen from the other end of the same link.
_RELATIONSHIP_INVERSE = {
    "many_to_one": "one_to_many",
    "one_to_many": "many_to_one",
    "one_to_one": "one_to_one",
}
class ColumnInput(BaseModel):
    """One scanned table column as supplied to source generation."""

    # Column name; used verbatim in generated measure expressions.
    name: str
    # Database type name; classified by ``_map_column_type``.
    type: str
    # Primary-key columns form the source grain and drive ``record_count``.
    primary_key: bool = False
    # Not read by the generation code in this module.
    nullable: bool = True
    # Copied to the generated column description and appended to measure descriptions.
    comment: str | None = None
class TableInput(BaseModel):
    """One scanned table as supplied to source generation."""

    # Table name; becomes the generated source name.
    name: str
    # Optional qualifiers that ``_build_table_ref`` puts in front of the name.
    catalog: str | None = None
    db: str | None = None
    # Copied to the generated source description.
    comment: str | None = None
    columns: list[ColumnInput]
class LinkInput(BaseModel):
    """A column-to-column link between two scanned tables."""

    from_table: str
    from_column: str
    to_table: str
    to_column: str
    # Looked up in ``_RELATIONSHIP_MAP``; unrecognised values are treated as
    # ``many_to_one`` by ``generate_sources``.
    relationship_type: str
class GenerateSourcesRequest(BaseModel):
    """Input accepted by ``generate_sources``."""

    tables: list[TableInput]
    links: list[LinkInput]
    # Not read by the generation code in this module.
    dialect: str = "postgres"
class GenerateSourcesResponse(BaseModel):
    """Output returned by ``generate_sources_response``."""

    # Dumped source definitions, one per input table.
    sources: list[dict[str, Any]]
    # Length of ``sources``.
    source_count: int
def _map_column_type(db_type: str) -> str:
    """Classify a database type name as boolean, time, number, or string.

    The patterns are tried in that order with an unanchored search and the
    first match wins; a type matching none of them is ``"string"``.
    """
    ordered_checks = (
        (_BOOLEAN_PATTERN, "boolean"),
        (_TIME_PATTERN, "time"),
        (_NUMBER_PATTERN, "number"),
    )
    for pattern, semantic_type in ordered_checks:
        if pattern.search(db_type):
            return semantic_type
    return "string"
def _build_table_ref(table: TableInput) -> str:
    """Return the dot-separated table reference ``[catalog.][db.]name``.

    ``catalog`` and ``db`` are included only when they are truthy.
    """
    qualifiers = [part for part in (table.catalog, table.db) if part]
    return ".".join([*qualifiers, table.name])
def _generate_measures(
    table_name: str,
    columns: list[ColumnInput],
    pk_columns: list[str],
) -> list[MeasureDefinition]:
    """Build the default measures for one table.

    When there is at least one primary-key column, a ``record_count`` measure
    counting the first one comes first. Then every column classified as
    ``"number"`` whose name does not match ``_ID_PATTERN`` contributes a
    ``total_<col>`` (sum) measure followed by an ``avg_<col>`` measure.
    """
    measures: list[MeasureDefinition] = []

    if pk_columns:
        measures.append(
            MeasureDefinition(
                name="record_count",
                expr=f"count({pk_columns[0]})",
                description=f"Count of {table_name} records",
            )
        )

    # (name prefix, SQL aggregate, description label), in emission order.
    aggregates = (("total", "sum", "Sum"), ("avg", "avg", "Average"))

    for col in columns:
        is_numeric = _map_column_type(col.type) == "number"
        # Identifier-like numeric columns are skipped.
        if not is_numeric or _ID_PATTERN.search(col.name):
            continue

        # The column comment, when present, is appended after an em dash.
        comment_suffix = f" \u2014 {col.comment}" if col.comment else ""
        for prefix, function, label in aggregates:
            measures.append(
                MeasureDefinition(
                    name=f"{prefix}_{col.name}",
                    expr=f"{function}({col.name})",
                    description=f"{label} of {col.name}{comment_suffix}",
                )
            )

    return measures
def _assign_join_aliases(
    joins: list[JoinDeclaration],
    remote_columns: list[str],
) -> None:
    """Set a distinct ``alias`` on every join whose target occurs more than once.

    The first choice is ``<target>_<local column>``, where the local column
    is the left-hand side of the join's ``on`` expression. When several joins
    would receive that same alias (for example two reverse joins that both
    start from the same local column), each of them falls back to
    ``<target>_<remote column>``, with a numeric suffix if that is still
    taken. Joins whose target occurs only once are left untouched.

    Args:
        joins: joins of one source; mutated in place.
        remote_columns: for each join, the column on the joined table,
            in the same order as *joins*.
    """
    to_counts: dict[str, int] = {}
    for join in joins:
        to_counts[join.to] = to_counts.get(join.to, 0) + 1

    # First pass: propose the local-column alias for every ambiguous join.
    proposed: dict[int, str] = {}
    proposed_counts: dict[str, int] = {}
    for index, join in enumerate(joins):
        if to_counts[join.to] > 1:
            local_column = join.on.split(" = ")[0].strip().lower()
            alias = f"{join.to}_{local_column}"
            proposed[index] = alias
            proposed_counts[alias] = proposed_counts.get(alias, 0) + 1

    # Aliases proposed exactly once are kept as they are and reserved.
    taken = {alias for alias, count in proposed_counts.items() if count == 1}

    # Second pass: resolve collisions using the remote column.
    for index, alias in proposed.items():
        join = joins[index]
        if proposed_counts[alias] > 1:
            base = f"{join.to}_{remote_columns[index].strip().lower()}"
            alias = base
            suffix = 2
            while alias in taken:
                alias = f"{base}_{suffix}"
                suffix += 1
            taken.add(alias)
        join.alias = alias


def generate_sources(request: GenerateSourcesRequest) -> list[dict[str, Any]]:
    """Generate one dumped source definition per scanned table.

    For each table this derives the grain (primary-key columns, else the
    first column, else ``["id"]``), typed columns, joins in both directions
    for every link whose other table is part of the request, and default
    measures. Links pointing at tables outside the request are skipped with a
    warning. Joins that share a target table receive distinct aliases.

    Returns:
        The ``model_dump(exclude_none=True)`` of every generated source, in
        the order of ``request.tables``.
    """
    links_by_from: dict[str, list[LinkInput]] = {}
    links_by_to: dict[str, list[LinkInput]] = {}
    for link in request.links:
        links_by_from.setdefault(link.from_table, []).append(link)
        links_by_to.setdefault(link.to_table, []).append(link)

    table_names = {table.name for table in request.tables}
    sources: list[dict[str, Any]] = []

    for table in request.tables:
        pk_columns = [column.name for column in table.columns if column.primary_key]
        if pk_columns:
            grain = pk_columns
        elif table.columns:
            grain = [table.columns[0].name]
        else:
            grain = ["id"]

        sl_columns: list[SourceColumn] = []
        for column in table.columns:
            sl_type = _map_column_type(column.type)
            role = ColumnRole.TIME if sl_type == "time" else ColumnRole.DEFAULT
            sl_columns.append(
                SourceColumn(
                    name=column.name,
                    type=sl_type,
                    role=role,
                    description=column.comment,
                )
            )

        joins: list[JoinDeclaration] = []
        # Column on the joined table for each entry in ``joins`` (same order).
        remote_columns: list[str] = []

        # Forward joins: links that start at this table.
        for link in links_by_from.get(table.name, []):
            if link.to_table not in table_names:
                logger.warning(
                    "Skipping link from %s.%s to %s.%s: target table not in scan",
                    link.from_table,
                    link.from_column,
                    link.to_table,
                    link.to_column,
                )
                continue
            relationship = _RELATIONSHIP_MAP.get(link.relationship_type, "many_to_one")
            joins.append(
                JoinDeclaration(
                    to=link.to_table,
                    on=f"{link.from_column} = {link.to_table}.{link.to_column}",
                    relationship=relationship,
                )
            )
            remote_columns.append(link.to_column)

        # Reverse joins: links that end at this table, seen from this side.
        for link in links_by_to.get(table.name, []):
            if link.from_table not in table_names:
                logger.warning(
                    "Skipping reverse link from %s.%s to %s.%s: source table not in scan",
                    link.from_table,
                    link.from_column,
                    link.to_table,
                    link.to_column,
                )
                continue
            forward_relationship = _RELATIONSHIP_MAP.get(
                link.relationship_type, "many_to_one"
            )
            reverse_relationship = _RELATIONSHIP_INVERSE.get(
                forward_relationship, "one_to_many"
            )
            joins.append(
                JoinDeclaration(
                    to=link.from_table,
                    on=f"{link.to_column} = {link.from_table}.{link.from_column}",
                    relationship=reverse_relationship,
                )
            )
            remote_columns.append(link.from_column)

        # Two reverse joins from the same local column (e.g. two foreign keys
        # pointing at this table's ``id``) used to receive identical aliases;
        # the helper keeps the old alias when it is unique and disambiguates
        # otherwise.
        _assign_join_aliases(joins, remote_columns)

        source = SourceDefinition(
            name=table.name,
            description=table.comment,
            table=_build_table_ref(table),
            grain=grain,
            columns=sl_columns,
            joins=joins,
            measures=_generate_measures(table.name, table.columns, pk_columns),
        )
        sources.append(source.model_dump(exclude_none=True))

    logger.info("Generated %d ktx-sl source definitions", len(sources))
    return sources
def generate_sources_response(
    request: GenerateSourcesRequest,
) -> GenerateSourcesResponse:
    """Run ``generate_sources`` and wrap the result together with its length."""
    generated = generate_sources(request)
    return GenerateSourcesResponse(
        sources=generated,
        source_count=len(generated),
    )

View file

@ -0,0 +1,66 @@
from __future__ import annotations
from dataclasses import asdict
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
from semantic_layer.table_identifier_parser import (
ParseTableIdentifierItem as SharedParseTableIdentifierItem,
parse_table_identifier_batch,
)
# Closed set of codes accepted by ``ParsedIdentifier.reason``.
ParseTableIdentifierReason = Literal[
    "looker_template_unresolved",
    "derived_table_not_supported",
    "no_physical_table",
    "multiple_table_references",
    "unsupported_dialect",
    "parse_error",
]
class ParseTableIdentifierItem(BaseModel):
    """One table identifier to parse; copied field-for-field into the shared item type."""

    # Caller-chosen key; results are keyed by string in the batch response.
    key: str
    sql_table_name: str
    dialect: str
class ParseTableIdentifierBatchRequest(BaseModel):
    """Input accepted by ``parse_table_identifier_response``."""

    items: list[ParseTableIdentifierItem]
class ParsedIdentifier(BaseModel):
    """Outcome of parsing one table identifier."""

    # Lets the schema field be populated as either ``schema`` or ``schema_``.
    model_config = ConfigDict(populate_by_name=True)

    ok: bool
    catalog: str | None = None
    # Exposed under the alias ``schema``; the attribute carries a trailing underscore.
    schema_: str | None = Field(default=None, alias="schema")
    name: str | None = None
    canonical_table: str | None = None
    # One of the ``ParseTableIdentifierReason`` codes, or None.
    # NOTE(review): presumably only set when ``ok`` is False - confirm against the parser.
    reason: ParseTableIdentifierReason | None = None
    detail: str | None = None
class ParseTableIdentifierBatchResponse(BaseModel):
    """Output returned by ``parse_table_identifier_response``."""

    # Parsed outcomes keyed as returned by ``parse_table_identifier_batch``.
    results: dict[str, ParsedIdentifier]
def parse_table_identifier_response(
    request: ParseTableIdentifierBatchRequest,
) -> ParseTableIdentifierBatchResponse:
    """Parse a batch of table identifiers with the shared parser.

    Each request item is copied into the shared item type, the batch is
    handed to ``parse_table_identifier_batch``, and every returned value is
    converted with ``dataclasses.asdict`` and validated into a
    ``ParsedIdentifier``.
    """
    shared_items = []
    for item in request.items:
        shared_items.append(
            SharedParseTableIdentifierItem(
                key=item.key,
                sql_table_name=item.sql_table_name,
                dialect=item.dialect,
            )
        )

    shared_results = parse_table_identifier_batch(shared_items)

    parsed: dict[str, ParsedIdentifier] = {}
    for key, value in shared_results.items():
        parsed[key] = ParsedIdentifier.model_validate(asdict(value))
    return ParseTableIdentifierBatchResponse(results=parsed)