Initial open-source release

This commit is contained in:
Andrey Avtomonov 2026-05-10 23:12:26 +02:00
commit 1a42152e6f
1199 changed files with 257054 additions and 0 deletions

104
python/klo-daemon/README.md Normal file
View file

@ -0,0 +1,104 @@
# klo-daemon
`klo-daemon` is the portable Python compute package for KLO.
It supports portable compute in two modes:
- One-shot commands, used by default by `@klo/context`.
- An explicit HTTP server for long-running local MCP sessions.
## One-shot semantic query
```bash
printf '%s\n' '{"sources":[],"query":{"measures":[],"dimensions":[]},"dialect":"postgres"}' \
| klo-daemon semantic-query
```
## One-shot source generation
Generate semantic-layer sources from schema scan data:
```bash
printf '%s\n' '{"tables":[{"name":"orders","db":"public","columns":[{"name":"id","type":"integer","primary_key":true}]}],"links":[],"dialect":"postgres"}' \
| klo-daemon semantic-generate-sources
```
## One-shot database introspection
Introspect a Postgres database schema:
```bash
printf '%s\n' '{"connection_id":"warehouse","driver":"postgres","url":"postgresql://readonly@example.test/warehouse","schemas":["public"]}' \
| klo-daemon database-introspect
```
## One-shot LookML parsing
Parse LookML projects into resolved, KSL-ready structures:
```bash
printf '%s\n' '{"files":[{"path":"views/orders.view.lkml","content":"view: orders { sql_table_name: public.orders ;; measure: order_count { type: count } }"}],"dialect":"postgres"}' \
| klo-daemon lookml-parse
```
## One-shot embeddings
Compute text embeddings locally:
```bash
printf '%s\n' '{"text":"hello"}' \
| klo-daemon embedding-compute
```
Compute text embeddings locally in bulk:
```bash
printf '%s\n' '{"texts":["hello","world"]}' \
| klo-daemon embedding-compute-bulk
```
## One-shot code execution
Execute Python code with the current in-process boundary:
```bash
printf '%s\n' '{"code":"result = 1 + 2"}' \
| klo-daemon code-execute
```
## HTTP compute server
Start the HTTP compute server with code execution disabled:
```bash
klo-daemon serve-http --host 127.0.0.1 --port 8765
```
Enable HTTP code execution explicitly:
```bash
klo-daemon serve-http --host 127.0.0.1 --port 8765 --enable-code-execution
```
Available HTTP endpoints:
- `GET /health`
- `POST /database/introspect`
- `POST /embeddings/compute`
- `POST /embeddings/compute-bulk`
- `POST /lookml/parse`
- `POST /semantic-layer/generate-sources`
- `POST /semantic-layer/query`
- `POST /semantic-layer/validate`
- `POST /code/execute` when `--enable-code-execution` is passed
The HTTP server exposes Postgres database introspection, LookML parsing, local
embedding compute, and semantic-layer compute for source generation, query
compilation, and validation.
Code execution is off by default. When enabled, it runs Python `exec` in the
daemon process with the same in-process boundary as the one-shot
`code-execute` command and does not provide OS-level sandboxing.
HTTP code execution uses the standalone KLO boundary. It does not forward
caller authorization headers to a host app and does not connect scratchpad or
visualization helpers to host application APIs.

View file

@ -0,0 +1,50 @@
[project]
name = "klo-daemon"
version = "0.1.0"
description = "Portable compute package for KLO semantic-layer operations"
readme = "README.md"
requires-python = ">=3.13"
license = "Apache-2.0"
dependencies = [
"fastapi>=0.115.0",
"klo-sl",
"lkml>=1.3.7",
"numpy>=2.2.6",
"orjson>=3.11.4",
"pandas>=2.2.3",
"psycopg[binary]>=3.2.0",
"pydantic>=2.9.0",
"requests>=2.32.0",
"sentence-transformers>=5.1.1",
"sqlglot>=26",
"torch>=2.2.0",
"uvicorn[standard]>=0.32.0",
]
[project.scripts]
klo-daemon = "klo_daemon.__main__:main"
[project.urls]
Homepage = "https://github.com/kaelio/ktx"
Repository = "https://github.com/kaelio/ktx"
Issues = "https://github.com/kaelio/ktx/issues"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/klo_daemon"]
[dependency-groups]
dev = [
"httpx>=0.28.1",
"pytest>=9.0.2",
]
[tool.uv.sources]
klo-sl = { workspace = true }
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["src"]

View file

@ -0,0 +1,6 @@
"""Portable compute package for KLO."""
PACKAGE_NAME = "klo-daemon"
VERSION = "0.1.0"
__all__ = ["PACKAGE_NAME", "VERSION"]

View file

@ -0,0 +1,172 @@
"""Command entry point for one-shot KLO daemon compute operations."""
from __future__ import annotations
import argparse
import json
import sys
from typing import Any
from pydantic import ValidationError
from klo_daemon.code_execution import ExecuteCodeRequest, execute_code_response
from klo_daemon.database_introspection import (
DatabaseIntrospectionRequest,
introspect_database_response,
)
from klo_daemon.embeddings import (
ComputeEmbeddingBulkRequest,
ComputeEmbeddingRequest,
compute_embedding_bulk_response,
compute_embedding_response,
)
from klo_daemon.lookml import ParseLookMLRequest, parse_lookml_project
from klo_daemon.semantic_layer import (
SemanticLayerQueryRequest,
ValidateSourcesRequest,
query_semantic_layer,
validate_semantic_layer,
)
from klo_daemon.source_generation import (
GenerateSourcesRequest,
generate_sources_response,
)
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(prog="klo-daemon")
subcommands = parser.add_subparsers(dest="command", required=True)
subcommands.add_parser("semantic-query", help="Compile a semantic-layer query")
subcommands.add_parser("semantic-validate", help="Validate semantic-layer sources")
subcommands.add_parser(
"semantic-generate-sources",
help="Generate semantic-layer sources from schema scan data",
)
subcommands.add_parser(
"database-introspect",
help="Introspect a Postgres database schema",
)
subcommands.add_parser(
"lookml-parse",
help="Parse LookML files into KSL-ready structures",
)
subcommands.add_parser(
"embedding-compute",
help="Compute one local text embedding",
)
subcommands.add_parser(
"embedding-compute-bulk",
help="Compute local text embeddings in bulk",
)
subcommands.add_parser(
"code-execute",
help="Execute Python code with the current in-process boundary",
)
serve_http = subcommands.add_parser(
"serve-http",
help="Run the KLO daemon portable compute HTTP server",
)
serve_http.add_argument("--host", default="127.0.0.1")
serve_http.add_argument("--port", type=int, default=8765)
serve_http.add_argument(
"--log-level",
default="info",
choices=["critical", "error", "warning", "info", "debug", "trace"],
)
serve_http.add_argument(
"--enable-code-execution",
action="store_true",
help="Expose POST /code/execute on the HTTP server",
)
return parser
def _read_stdin_json() -> dict[str, Any]:
raw = sys.stdin.read()
parsed = json.loads(raw)
if not isinstance(parsed, dict):
raise ValueError("stdin JSON must be an object")
return parsed
def run_http_server(
*,
host: str,
port: int,
log_level: str,
enable_code_execution: bool,
) -> None:
import uvicorn
from klo_daemon.app import create_app
uvicorn.run(
create_app(enable_code_execution=enable_code_execution),
host=host,
port=port,
log_level=log_level,
)
def main(argv: list[str] | None = None) -> int:
parser = build_parser()
args = parser.parse_args(argv)
if args.command == "serve-http":
run_http_server(
host=args.host,
port=args.port,
log_level=args.log_level,
enable_code_execution=args.enable_code_execution,
)
return 0
try:
payload = _read_stdin_json()
if args.command == "semantic-query":
response = query_semantic_layer(
SemanticLayerQueryRequest.model_validate(payload)
)
elif args.command == "semantic-validate":
response = validate_semantic_layer(
ValidateSourcesRequest.model_validate(payload)
)
elif args.command == "semantic-generate-sources":
response = generate_sources_response(
GenerateSourcesRequest.model_validate(payload)
)
elif args.command == "database-introspect":
response = introspect_database_response(
DatabaseIntrospectionRequest.model_validate(payload)
)
elif args.command == "lookml-parse":
response = parse_lookml_project(ParseLookMLRequest.model_validate(payload))
elif args.command == "embedding-compute":
response = compute_embedding_response(
ComputeEmbeddingRequest.model_validate(payload)
)
elif args.command == "embedding-compute-bulk":
response = compute_embedding_bulk_response(
ComputeEmbeddingBulkRequest.model_validate(payload)
)
elif args.command == "code-execute":
response = execute_code_response(
ExecuteCodeRequest.model_validate(payload),
nest_api_url=None,
auth_header=None,
)
else:
parser.error(f"Unknown command: {args.command}")
return 2
sys.stdout.write(response.model_dump_json() + "\n")
return 0
except (json.JSONDecodeError, ValidationError, ValueError) as error:
sys.stderr.write(f"{error}\n")
return 1
except Exception as error:
sys.stderr.write(f"{type(error).__name__}: {error}\n")
return 1
if __name__ == "__main__":
raise SystemExit(main())

View file

@ -0,0 +1,228 @@
"""FastAPI app factory for the KLO daemon semantic compute server."""
from __future__ import annotations
import logging
from collections.abc import Callable
from typing import Any
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from klo_daemon.code_execution import (
ExecuteCodeRequest,
ExecuteCodeResponse,
dumps_numpy_json,
execute_code_response,
)
from klo_daemon.database_introspection import (
DatabaseIntrospectionRequest,
DatabaseIntrospectionResponse,
introspect_database_response,
)
from klo_daemon.embeddings import (
ComputeEmbeddingBulkRequest,
ComputeEmbeddingBulkResponse,
ComputeEmbeddingRequest,
ComputeEmbeddingResponse,
EmbeddingProvider,
compute_embedding_bulk_response,
compute_embedding_response,
)
from klo_daemon.lookml import (
ParseLookMLRequest,
ParseLookMLResponse,
parse_lookml_project,
)
from klo_daemon.semantic_layer import (
SemanticLayerQueryRequest,
SemanticLayerQueryResponse,
ValidateSourcesRequest,
ValidateSourcesResponse,
query_semantic_layer,
validate_semantic_layer,
)
from klo_daemon.source_generation import (
GenerateSourcesRequest,
GenerateSourcesResponse,
generate_sources_response,
)
from klo_daemon.table_identifier import (
ParseTableIdentifierBatchRequest,
ParseTableIdentifierBatchResponse,
parse_table_identifier_response,
)
logger = logging.getLogger(__name__)
class NumpyORJSONResponse(Response):
media_type = "application/json"
def render(self, content: Any) -> bytes:
return dumps_numpy_json(content)
def create_app(
*,
embedding_provider: EmbeddingProvider | None = None,
database_introspector: Callable[
[DatabaseIntrospectionRequest], DatabaseIntrospectionResponse
]
| None = None,
enable_code_execution: bool = False,
) -> FastAPI:
app = FastAPI(
title="KLO Daemon",
description="Stateless portable compute server for KLO.",
version="0.1.0",
)
@app.get("/health")
async def health() -> dict[str, str]:
return {"status": "healthy"}
@app.post("/database/introspect", response_model=DatabaseIntrospectionResponse)
async def database_introspect(
request: DatabaseIntrospectionRequest,
) -> DatabaseIntrospectionResponse:
try:
introspector = database_introspector or introspect_database_response
return introspector(request)
except ValueError as error:
logger.warning("Database introspection rejected: %s", error)
raise HTTPException(status_code=400, detail=str(error)) from error
except Exception as error:
logger.exception("Database introspection failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"Database introspection failed: {error}",
) from error
@app.post("/embeddings/compute", response_model=ComputeEmbeddingResponse)
async def embedding_compute(
request: ComputeEmbeddingRequest,
) -> ComputeEmbeddingResponse:
try:
return compute_embedding_response(
request,
provider=embedding_provider,
)
except ValueError as error:
logger.warning("Embedding compute rejected: %s", error)
raise HTTPException(status_code=400, detail=str(error)) from error
except Exception as error:
logger.exception("Embedding compute failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"Embedding compute failed: {error}",
) from error
@app.post(
"/embeddings/compute-bulk",
response_model=ComputeEmbeddingBulkResponse,
)
async def embedding_compute_bulk(
request: ComputeEmbeddingBulkRequest,
) -> ComputeEmbeddingBulkResponse:
try:
return compute_embedding_bulk_response(
request,
provider=embedding_provider,
)
except ValueError as error:
logger.warning("Bulk embedding compute rejected: %s", error)
raise HTTPException(status_code=400, detail=str(error)) from error
except Exception as error:
logger.exception("Bulk embedding compute failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"Bulk embedding compute failed: {error}",
) from error
if enable_code_execution:
@app.post(
"/code/execute",
response_model=ExecuteCodeResponse,
response_class=NumpyORJSONResponse,
)
async def code_execute(request: ExecuteCodeRequest) -> ExecuteCodeResponse:
try:
return execute_code_response(
request,
nest_api_url=None,
auth_header=None,
)
except Exception as error:
logger.exception("Code execution failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"Code execution failed: {error}",
) from error
@app.post("/lookml/parse", response_model=ParseLookMLResponse)
async def lookml_parse(request: ParseLookMLRequest) -> ParseLookMLResponse:
try:
return parse_lookml_project(request)
except Exception as error:
logger.exception("LookML parsing failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"LookML parsing failed: {error}",
) from error
@app.post(
"/sql/parse-table-identifier",
response_model=ParseTableIdentifierBatchResponse,
)
async def sql_parse_table_identifier(
request: ParseTableIdentifierBatchRequest,
) -> ParseTableIdentifierBatchResponse:
try:
return parse_table_identifier_response(request)
except Exception as error:
logger.exception("Table identifier parsing failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"Table identifier parsing failed: {error}",
) from error
@app.post(
"/semantic-layer/generate-sources", response_model=GenerateSourcesResponse
)
async def semantic_generate_sources(
request: GenerateSourcesRequest,
) -> GenerateSourcesResponse:
try:
return generate_sources_response(request)
except Exception as error:
logger.exception("Semantic source generation failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"Semantic source generation failed: {error}",
) from error
@app.post("/semantic-layer/query", response_model=SemanticLayerQueryResponse)
async def semantic_query(
request: SemanticLayerQueryRequest,
) -> SemanticLayerQueryResponse:
try:
return query_semantic_layer(request)
except ValueError as error:
logger.warning("Semantic query rejected: %s", error)
raise HTTPException(status_code=400, detail=str(error)) from error
except Exception as error:
logger.exception("Semantic query failed: %s", error)
raise HTTPException(
status_code=500,
detail=f"Semantic layer query failed: {error}",
) from error
@app.post("/semantic-layer/validate", response_model=ValidateSourcesResponse)
async def semantic_validate(
request: ValidateSourcesRequest,
) -> ValidateSourcesResponse:
return validate_semantic_layer(request)
return app

View file

@ -0,0 +1,333 @@
"""Portable in-process code execution helpers for KLO daemon.
This module preserves the host application's current Python execution behavior.
It runs code with Python ``exec`` in the current process and does not provide
OS-level sandboxing.
"""
from __future__ import annotations
import json
import logging
import re
import sys
from collections.abc import Callable
from io import BytesIO, StringIO
from typing import Any
import numpy as np
import orjson
import pandas as pd
import requests
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
VALID_VISUALIZATION_TYPES = ["pie", "bar", "line", "area", "table", "boxplot"]
class ExecuteCodeRequest(BaseModel):
"""Request schema for executing Python code."""
code: str = Field(..., description="Python code to execute")
source_id: str | None = Field(
None,
description="Chat/dashboard ID for scratchpad file access",
)
message_id: str | None = Field(
None,
description="Message ID for visualization association",
)
class VisualizationSpec(BaseModel):
"""Specification for a visualization to be saved by the host application."""
type: str = Field(..., description="Type marker, always 'visualization'")
vis_type: str = Field(
...,
description="Visualization type: pie, bar, line, area, table",
)
config: dict[str, Any] = Field(
...,
description="Visualization configuration",
)
data: list[dict[str, Any]] = Field(
...,
description="Visualization data",
)
title: str | None = Field(None, description="Optional title")
class ExecuteCodeResponse(BaseModel):
"""Response schema for code execution."""
formatted_result: str = Field(
...,
description="Formatted execution result for display",
)
result: Any | None = Field(
None,
description="The value of the 'result' variable if set",
)
console_output: str | None = Field(
None,
description="Captured stdout from print statements",
)
error: str | None = Field(None, description="Error message if execution failed")
message: str | None = Field(
None,
description="Message if no clear result was returned",
)
visualizations: list[VisualizationSpec] | None = Field(
None,
description="List of visualizations detected in the result",
)
ScratchpadHelpers = tuple[
Callable[[pd.DataFrame, str | None], str],
Callable[[str], pd.DataFrame],
Callable[[str, dict[str, Any], list[dict[str, Any]]], str],
]
def dumps_numpy_json(content: Any) -> bytes:
"""Serialize JSON response content with numpy scalar and array support."""
return orjson.dumps(content, option=orjson.OPT_SERIALIZE_NUMPY)
def _strip_ansi_sequences(text: str) -> str:
ansi_escape = re.compile(
r"\x1b\[[0-9;]*[a-zA-Z]|\x1b\([0-9;]*[a-zA-Z]|\x1b\[[0-9;]*~"
)
return ansi_escape.sub("", text)
def create_scratchpad_helpers(
nest_api_url: str | None,
auth_header: str | None,
source_id: str | None,
message_id: str | None = None,
http_client: Any = requests,
) -> ScratchpadHelpers:
"""Create scratchpad and visualization helpers that call host app APIs."""
def save_df_to_scratchpad(df: pd.DataFrame, filename: str | None = None) -> str:
if not nest_api_url or not auth_header or not source_id:
raise ValueError(
"nest_api_url, Authorization header, and source_id are required "
"for scratchpad operations"
)
data_json = df.to_dict(orient="records")
url = f"{nest_api_url}/private_api/scratchpad/{source_id}/files"
response = http_client.post(
url,
data=dumps_numpy_json(
{"filename": filename, "data": data_json, "format": "json"}
),
headers={"Authorization": auth_header, "Content-Type": "application/json"},
timeout=30,
)
response.raise_for_status()
saved_filename = response.json()["filename"]
rows, _cols = df.shape
return f"{rows} rows saved to {saved_filename}"
def read_scratchpad_file(filename: str) -> pd.DataFrame:
if not nest_api_url or not auth_header or not source_id:
raise ValueError(
"nest_api_url, Authorization header, and source_id are required "
"for scratchpad operations"
)
url = f"{nest_api_url}/private_api/scratchpad/{source_id}/files/{filename}?format=raw"
response = http_client.get(
url,
headers={"Authorization": auth_header, "Accept": "text/csv"},
timeout=30,
)
response.raise_for_status()
content_type = response.headers.get("content-type", "")
if "text/csv" in content_type:
return pd.read_csv(BytesIO(response.content))
data = response.json()["data"]
return pd.DataFrame(data)
def save_visualization(
vis_type: str,
config: dict[str, Any],
data: list[dict[str, Any]],
) -> str:
if not nest_api_url or not auth_header or not source_id:
raise ValueError(
"nest_api_url, Authorization header, and source_id are required "
"for visualization operations"
)
if not message_id:
raise ValueError("message_id is required for visualization operations")
if vis_type not in VALID_VISUALIZATION_TYPES:
raise ValueError(
f"Invalid visualization type: {vis_type}. Must be one of {VALID_VISUALIZATION_TYPES}"
)
url = f"{nest_api_url}/private_api/visualizations/{source_id}"
payload = {
"visualizationType": vis_type,
"config": config,
"data": data,
"messageId": message_id,
}
response = http_client.post(
url,
data=dumps_numpy_json(payload),
headers={"Authorization": auth_header, "Content-Type": "application/json"},
timeout=30,
)
response.raise_for_status()
filename = response.json()["filename"]
print(f"Visualization saved: {filename}")
return f"![viz]({filename})"
return save_df_to_scratchpad, read_scratchpad_file, save_visualization
def detect_visualizations(result: Any) -> list[dict[str, Any]]:
"""Detect visualization specs in a code execution result value."""
visualizations = []
if isinstance(result, dict) and result.get("type") == "visualization":
visualizations.append(result)
elif isinstance(result, list):
for item in result:
if isinstance(item, dict) and item.get("type") == "visualization":
visualizations.append(item)
return visualizations
def execute_code(
code: str,
nest_api_url: str | None = None,
auth_header: str | None = None,
source_id: str | None = None,
message_id: str | None = None,
scratchpad_helpers: ScratchpadHelpers | None = None,
) -> dict[str, Any]:
"""Execute Python code with the current in-process execution boundary."""
logger.info("Starting code execution")
save_df, read_file, save_viz = scratchpad_helpers or create_scratchpad_helpers(
nest_api_url,
auth_header,
source_id,
message_id,
)
namespace = {
"pd": pd,
"np": np,
"json": json,
"requests": requests,
"save_df_to_scratchpad": save_df,
"read_scratchpad_file": read_file,
"save_visualization": save_viz,
}
stdout_capture = StringIO()
original_stdout = sys.stdout
sys.stdout = stdout_capture
console_output = ""
try:
logger.info("Executing code in current process namespace")
exec(code, namespace)
console_output = stdout_capture.getvalue()
if "result" in namespace:
logger.info("Code execution complete, 'result' variable found")
result_value = namespace["result"]
visualizations = detect_visualizations(result_value)
result = {"result": result_value}
if console_output:
result["console_output"] = console_output
if visualizations:
result["visualizations"] = visualizations
return result
logger.info("No result variable found")
result = {
"message": "Code executed successfully but no result variable was set"
}
if console_output:
result["console_output"] = console_output
return result
except Exception as error:
logger.exception("Error executing code: %s", error)
result = {"error": str(error)}
if console_output:
result["console_output"] = console_output
return result
finally:
sys.stdout = original_stdout
def format_execution_result(result: dict[str, Any]) -> str:
"""Format execution output for display in host chat responses."""
formatted_result = ""
if "console_output" in result:
formatted_result += "=== Console Output ===\n\n"
formatted_result += _strip_ansi_sequences(result["console_output"])
if "result" in result:
formatted_result += "\n\n=== Result ===\n\n"
formatted_result += str(result["result"])
elif "message" in result:
formatted_result += "\n\n=== Message ===\n\n"
formatted_result += result["message"]
elif "error" in result:
formatted_result += "\n\n=== Error ===\n\n"
formatted_result += result["error"]
return formatted_result
def execute_code_response(
request: ExecuteCodeRequest,
*,
nest_api_url: str | None,
auth_header: str | None,
) -> ExecuteCodeResponse:
"""Execute a validated request and return the public response model."""
result = execute_code(
code=request.code,
nest_api_url=nest_api_url,
auth_header=auth_header,
source_id=request.source_id,
message_id=request.message_id,
)
return ExecuteCodeResponse(
formatted_result=format_execution_result(result),
result=result.get("result"),
console_output=result.get("console_output"),
error=result.get("error"),
message=result.get("message"),
visualizations=result.get("visualizations"),
)

View file

@ -0,0 +1,284 @@
"""Portable database introspection helpers for KLO daemon."""
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from pydantic import BaseModel, Field, field_validator
TABLES_SQL = """
select
t.table_catalog,
t.table_schema,
t.table_name,
obj_description(c.oid) as table_comment
from information_schema.tables t
join pg_catalog.pg_namespace n
on n.nspname = t.table_schema
join pg_catalog.pg_class c
on c.relnamespace = n.oid
and c.relname = t.table_name
where t.table_schema = any(%s)
and t.table_type = 'BASE TABLE'
order by t.table_schema, t.table_name
"""
COLUMNS_SQL = """
select
current_database() as table_catalog,
n.nspname as table_schema,
c.relname as table_name,
a.attname as column_name,
pg_catalog.format_type(a.atttypid, a.atttypmod) as formatted_type,
not a.attnotnull as is_nullable,
exists (
select 1
from pg_catalog.pg_index i
where i.indrelid = c.oid
and i.indisprimary
and a.attnum = any(i.indkey)
) as is_primary_key,
pg_catalog.col_description(c.oid, a.attnum) as column_comment
from pg_catalog.pg_attribute a
join pg_catalog.pg_class c
on c.oid = a.attrelid
join pg_catalog.pg_namespace n
on n.oid = c.relnamespace
where n.nspname = any(%s)
and c.relkind in ('r', 'p')
and a.attnum > 0
and not a.attisdropped
order by n.nspname, c.relname, a.attnum
"""
FOREIGN_KEYS_SQL = """
select
current_database() as table_catalog,
source_constraint.table_schema,
source_constraint.table_name,
source_key.column_name as from_column,
target_key.table_name as to_table,
target_key.column_name as to_column,
source_constraint.constraint_name
from information_schema.table_constraints source_constraint
join information_schema.key_column_usage source_key
on source_key.constraint_catalog = source_constraint.constraint_catalog
and source_key.constraint_schema = source_constraint.constraint_schema
and source_key.constraint_name = source_constraint.constraint_name
join information_schema.referential_constraints ref_constraint
on ref_constraint.constraint_catalog = source_constraint.constraint_catalog
and ref_constraint.constraint_schema = source_constraint.constraint_schema
and ref_constraint.constraint_name = source_constraint.constraint_name
join information_schema.key_column_usage target_key
on target_key.constraint_catalog = ref_constraint.unique_constraint_catalog
and target_key.constraint_schema = ref_constraint.unique_constraint_schema
and target_key.constraint_name = ref_constraint.unique_constraint_name
and target_key.ordinal_position = source_key.position_in_unique_constraint
where source_constraint.constraint_type = 'FOREIGN KEY'
and source_constraint.table_schema = any(%s)
order by source_constraint.table_schema, source_constraint.table_name, source_constraint.constraint_name, source_key.ordinal_position
"""
class LiveDatabaseColumn(BaseModel):
name: str
type: str
nullable: bool = True
primary_key: bool = False
comment: str | None = None
class LiveDatabaseForeignKey(BaseModel):
from_column: str
to_table: str
to_column: str
constraint_name: str | None = None
class LiveDatabaseTable(BaseModel):
catalog: str | None = None
db: str | None = None
name: str
comment: str | None = None
columns: list[LiveDatabaseColumn] = Field(default_factory=list)
foreign_keys: list[LiveDatabaseForeignKey] = Field(default_factory=list)
class DatabaseIntrospectionRequest(BaseModel):
connection_id: str
driver: str = "postgres"
url: str
schemas: list[str] = Field(default_factory=lambda: ["public"])
statement_timeout_ms: int = Field(default=30_000, ge=1)
connection_timeout_seconds: int = Field(default=5, ge=1)
@field_validator("schemas")
@classmethod
def _schemas_must_not_be_empty(cls, value: list[str]) -> list[str]:
if not value:
raise ValueError("database introspection requires at least one schema")
return value
class DatabaseIntrospectionResponse(BaseModel):
connection_id: str
extracted_at: str
metadata: dict[str, Any]
tables: list[LiveDatabaseTable]
@dataclass(frozen=True)
class DatabaseIntrospectionRows:
table_rows: Sequence[Mapping[str, Any]]
column_rows: Sequence[Mapping[str, Any]]
foreign_key_rows: Sequence[Mapping[str, Any]]
DatabaseRowsLoader = Callable[[DatabaseIntrospectionRequest], DatabaseIntrospectionRows]
NowProvider = Callable[[], str]
def _driver_name(driver: str) -> str:
return driver.strip().lower()
def _table_key(catalog: str | None, db: str | None, name: str) -> str:
return f"{catalog or ''}\u0000{db or ''}\u0000{name}"
def _optional_string(row: Mapping[str, Any], key: str) -> str | None:
value = row.get(key)
return value if isinstance(value, str) else None
def _required_string(row: Mapping[str, Any], key: str) -> str:
value = row.get(key)
if not isinstance(value, str) or not value:
raise ValueError(f"database introspection row is missing string field {key}")
return value
def _statement_timeout_config(statement_timeout_ms: int) -> tuple[str, tuple[str]]:
return (
"SELECT set_config('statement_timeout', %s, true)",
(f"{int(statement_timeout_ms)}ms",),
)
def _load_postgres_rows(
request: DatabaseIntrospectionRequest,
) -> DatabaseIntrospectionRows:
try:
import psycopg
from psycopg.rows import dict_row
except ImportError as error:
raise RuntimeError(
"psycopg is required for Postgres database introspection"
) from error
connection = psycopg.connect(
request.url,
connect_timeout=request.connection_timeout_seconds,
application_name="klo-daemon-database-introspection",
row_factory=dict_row,
)
try:
connection.execute("BEGIN READ ONLY")
try:
connection.execute(*_statement_timeout_config(request.statement_timeout_ms))
params = (request.schemas,)
table_rows = list(connection.execute(TABLES_SQL, params))
column_rows = list(connection.execute(COLUMNS_SQL, params))
foreign_key_rows = list(connection.execute(FOREIGN_KEYS_SQL, params))
connection.execute("COMMIT")
except Exception:
connection.execute("ROLLBACK")
raise
finally:
connection.close()
return DatabaseIntrospectionRows(
table_rows=table_rows,
column_rows=column_rows,
foreign_key_rows=foreign_key_rows,
)
def _map_rows_to_tables(rows: DatabaseIntrospectionRows) -> list[LiveDatabaseTable]:
tables: dict[str, LiveDatabaseTable] = {}
for row in rows.table_rows:
catalog = _optional_string(row, "table_catalog")
db = _required_string(row, "table_schema")
name = _required_string(row, "table_name")
key = _table_key(catalog, db, name)
tables[key] = LiveDatabaseTable(
catalog=catalog,
db=db,
name=name,
comment=_optional_string(row, "table_comment"),
)
for row in rows.column_rows:
catalog = _optional_string(row, "table_catalog")
db = _required_string(row, "table_schema")
table_name = _required_string(row, "table_name")
table = tables.get(_table_key(catalog, db, table_name))
if table is None:
continue
table.columns.append(
LiveDatabaseColumn(
name=_required_string(row, "column_name"),
type=_required_string(row, "formatted_type"),
nullable=bool(row.get("is_nullable")),
primary_key=bool(row.get("is_primary_key")),
comment=_optional_string(row, "column_comment"),
)
)
for row in rows.foreign_key_rows:
catalog = _optional_string(row, "table_catalog")
db = _required_string(row, "table_schema")
table_name = _required_string(row, "table_name")
table = tables.get(_table_key(catalog, db, table_name))
if table is None:
continue
table.foreign_keys.append(
LiveDatabaseForeignKey(
from_column=_required_string(row, "from_column"),
to_table=_required_string(row, "to_table"),
to_column=_required_string(row, "to_column"),
constraint_name=_optional_string(row, "constraint_name"),
)
)
return sorted(
tables.values(),
key=lambda table: _table_key(table.catalog, table.db, table.name),
)
def introspect_database_response(
request: DatabaseIntrospectionRequest,
*,
load_rows: DatabaseRowsLoader | None = None,
now: NowProvider | None = None,
) -> DatabaseIntrospectionResponse:
driver = _driver_name(request.driver)
if driver not in {"postgres", "postgresql"}:
raise ValueError('database introspection supports only driver "postgres"')
rows = (load_rows or _load_postgres_rows)(request)
timestamp = now() if now else datetime.now(timezone.utc).isoformat()
return DatabaseIntrospectionResponse(
connection_id=request.connection_id,
extracted_at=timestamp,
metadata={"driver": driver, "schemas": list(request.schemas)},
tables=_map_rows_to_tables(rows),
)

View file

@ -0,0 +1,172 @@
"""Portable embedding compute helpers for KLO daemon."""
from __future__ import annotations
import logging
import threading
from typing import TYPE_CHECKING, Protocol
from pydantic import BaseModel, Field
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
DEFAULT_SENTENCE_TRANSFORMER_MODEL = "all-MiniLM-L6-v2"
DEFAULT_EMBEDDING_DIMENSIONS = 384
DEFAULT_MAX_BATCH_SIZE = 100
class EmbeddingProvider(Protocol):
"""Provider interface for local embedding compute."""
@property
def name(self) -> str: ...
@property
def dimensions(self) -> int: ...
@property
def max_batch_size(self) -> int: ...
def encode(self, texts: list[str]) -> list[list[float]]: ...
class ComputeEmbeddingRequest(BaseModel):
"""Request schema for computing a single embedding."""
text: str = Field(..., description="Text to compute embedding for", min_length=1)
class ComputeEmbeddingResponse(BaseModel):
"""Response schema for single embedding computation."""
embedding: list[float] = Field(..., description="384-dimensional embedding vector")
class ComputeEmbeddingBulkRequest(BaseModel):
"""Request schema for computing multiple embeddings."""
texts: list[str] = Field(
...,
description="List of texts to compute embeddings for",
min_length=1,
max_length=DEFAULT_MAX_BATCH_SIZE,
)
class ComputeEmbeddingBulkResponse(BaseModel):
"""Response schema for bulk embedding computation."""
embeddings: list[list[float]] = Field(
...,
description="List of 384-dimensional embedding vectors",
)
class SentenceTransformersEmbeddingProvider:
"""Lazy sentence-transformers provider for local embeddings."""
def __init__(
self,
model_name: str = DEFAULT_SENTENCE_TRANSFORMER_MODEL,
model: SentenceTransformer | None = None,
) -> None:
self.model_name = model_name
self._model = model
self._model_lock = threading.Lock()
@property
def name(self) -> str:
return "sentence-transformers"
@property
def dimensions(self) -> int:
return DEFAULT_EMBEDDING_DIMENSIONS
@property
def max_batch_size(self) -> int:
return DEFAULT_MAX_BATCH_SIZE
def _get_model(self) -> SentenceTransformer:
if self._model is not None:
return self._model
with self._model_lock:
if self._model is None:
from sentence_transformers import SentenceTransformer
logger.info("Loading SentenceTransformer model: %s", self.model_name)
self._model = SentenceTransformer(self.model_name)
logger.info("SentenceTransformer model loaded successfully")
return self._model
def encode(self, texts: list[str]) -> list[list[float]]:
model = self._get_model()
if len(texts) == 1:
raw_single = model.encode(texts[0]).tolist()
return [[float(value) for value in raw_single]]
raw_bulk = model.encode(texts).tolist()
return [[float(value) for value in embedding] for embedding in raw_bulk]
_default_provider: SentenceTransformersEmbeddingProvider | None = None
_default_provider_lock = threading.Lock()
def get_default_embedding_provider() -> SentenceTransformersEmbeddingProvider:
"""Return the process-wide default embedding provider."""
global _default_provider
if _default_provider is not None:
return _default_provider
with _default_provider_lock:
if _default_provider is None:
_default_provider = SentenceTransformersEmbeddingProvider()
return _default_provider
def _validate_texts(texts: list[str], max_batch_size: int) -> None:
if not texts:
raise ValueError("Texts array must not be empty")
if len(texts) > max_batch_size:
raise ValueError(f"Maximum {max_batch_size} texts allowed per batch")
empty_indices = [
index for index, text in enumerate(texts) if not text or not text.strip()
]
if empty_indices:
joined_indices = ", ".join(str(index) for index in empty_indices)
raise ValueError(f"Empty texts found at indices: {joined_indices}")
def compute_embedding_response(
request: ComputeEmbeddingRequest,
provider: EmbeddingProvider | None = None,
) -> ComputeEmbeddingResponse:
"""Compute one embedding from a request model."""
selected_provider = provider or get_default_embedding_provider()
_validate_texts([request.text], selected_provider.max_batch_size)
return ComputeEmbeddingResponse(
embedding=selected_provider.encode([request.text])[0]
)
def compute_embedding_bulk_response(
request: ComputeEmbeddingBulkRequest,
provider: EmbeddingProvider | None = None,
) -> ComputeEmbeddingBulkResponse:
"""Compute multiple embeddings from a request model."""
selected_provider = provider or get_default_embedding_provider()
_validate_texts(request.texts, selected_provider.max_batch_size)
return ComputeEmbeddingBulkResponse(
embeddings=selected_provider.encode(request.texts)
)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,136 @@
"""Semantic-layer compute helpers for the KLO daemon package."""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field
from semantic_layer.duplicate_check import validate_measure_duplicates
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import QueryResult, SourceDefinition
class SemanticLayerQueryRequest(BaseModel):
sources: list[dict[str, Any]]
query: dict[str, Any]
dialect: str = "postgres"
class SemanticLayerQueryResponse(BaseModel):
sql: str
dialect: str
columns: list[dict[str, Any]]
plan: dict[str, Any]
class ValidateSourcesRequest(BaseModel):
sources: list[dict[str, Any]]
dialect: str = "postgres"
recently_touched: list[str] | None = None
class ValidateSourcesResponse(BaseModel):
valid: bool
errors: list[str] = Field(default_factory=list)
warnings: list[str] = Field(default_factory=list)
per_source_warnings: dict[str, list[str]] = Field(default_factory=dict)
def _load_sources(raw_sources: list[dict[str, Any]]) -> dict[str, SourceDefinition]:
sources: dict[str, SourceDefinition] = {}
for raw_source in raw_sources:
source = SourceDefinition(**raw_source)
if source.name in sources:
raise ValueError(f"Duplicate source name '{source.name}'")
sources[source.name] = source
return sources
def _validate_duplicate_measure_names(source: SourceDefinition) -> list[str]:
errors: list[str] = []
seen: set[str] = set()
for measure in source.measures:
if measure.name in seen:
errors.append(
f"Duplicate measure '{measure.name}' on source '{source.name}'"
)
continue
seen.add(measure.name)
return errors
def _response_columns(result: QueryResult) -> list[dict[str, Any]]:
measure_names = {
measure.name: measure.qualified_ref
for measure in result.resolved_plan.measures
if measure.qualified_ref
}
columns: list[dict[str, Any]] = []
for column in result.columns:
dumped = column.model_dump(mode="json")
if column.provenance.value == "dimension" and column.expr:
dumped["name"] = column.expr
elif column.name in measure_names:
dumped["name"] = measure_names[column.name]
columns.append(dumped)
return columns
def query_semantic_layer(
request: SemanticLayerQueryRequest,
) -> SemanticLayerQueryResponse:
sources = _load_sources(request.sources)
engine = SemanticEngine.from_sources(sources, dialect=request.dialect)
result = engine.query(request.query)
return SemanticLayerQueryResponse(
sql=result.sql,
dialect=result.dialect,
columns=_response_columns(result),
plan=result.resolved_plan.model_dump(mode="json"),
)
def validate_semantic_layer(request: ValidateSourcesRequest) -> ValidateSourcesResponse:
errors: list[str] = []
warnings: list[str] = []
per_source_warnings: dict[str, list[str]] = {}
sources: dict[str, SourceDefinition] = {}
seen_names: set[str] = set()
for raw_source in request.sources:
raw_name = raw_source.get("name") if isinstance(raw_source, dict) else None
try:
source = SourceDefinition(**raw_source)
except Exception as error:
label = raw_name or "<unknown>"
errors.append(f"Source '{label}' failed to parse: {error}")
continue
if source.name in seen_names:
errors.append(f"Duplicate source name '{source.name}'")
continue
seen_names.add(source.name)
sources[source.name] = source
errors.extend(_validate_duplicate_measure_names(source))
if sources:
try:
engine = SemanticEngine.from_sources(sources, dialect=request.dialect)
report = engine.validate(
recently_touched=set(request.recently_touched)
if request.recently_touched
else None
)
errors.extend(report.errors)
warnings.extend(report.warnings)
per_source_warnings.update(report.per_source_warnings)
errors.extend(validate_measure_duplicates(sources, dialect=request.dialect))
except Exception as error:
errors.append(f"Validation failed: {error}")
return ValidateSourcesResponse(
valid=len(errors) == 0,
errors=errors,
warnings=warnings,
per_source_warnings=per_source_warnings,
)

View file

@ -0,0 +1,254 @@
"""Generate klo-sl YAML source definitions from database schema scan data."""
from __future__ import annotations
import logging
import re
from typing import Any
from pydantic import BaseModel
from semantic_layer.models import (
ColumnRole,
JoinDeclaration,
MeasureDefinition,
SourceColumn,
SourceDefinition,
)
logger = logging.getLogger(__name__)
_NUMBER_PATTERN = re.compile(
r"int|integer|bigint|smallint|tinyint|numeric|decimal|float|double|real|number|money",
re.IGNORECASE,
)
_TIME_PATTERN = re.compile(
r"timestamp|datetime|date|time(?!stamp)",
re.IGNORECASE,
)
_BOOLEAN_PATTERN = re.compile(r"bool|boolean|bit", re.IGNORECASE)
_ID_PATTERN = re.compile(
r"^id$|_id$|^uuid$|_uuid$|_key$|_pk$|identifier$",
re.IGNORECASE,
)
_RELATIONSHIP_MAP = {
"MANY_TO_ONE": "many_to_one",
"ONE_TO_MANY": "one_to_many",
"ONE_TO_ONE": "one_to_one",
"many_to_one": "many_to_one",
"one_to_many": "one_to_many",
"one_to_one": "one_to_one",
}
_RELATIONSHIP_INVERSE = {
"many_to_one": "one_to_many",
"one_to_many": "many_to_one",
"one_to_one": "one_to_one",
}
class ColumnInput(BaseModel):
name: str
type: str
primary_key: bool = False
nullable: bool = True
comment: str | None = None
class TableInput(BaseModel):
name: str
catalog: str | None = None
db: str | None = None
comment: str | None = None
columns: list[ColumnInput]
class LinkInput(BaseModel):
from_table: str
from_column: str
to_table: str
to_column: str
relationship_type: str
class GenerateSourcesRequest(BaseModel):
tables: list[TableInput]
links: list[LinkInput]
dialect: str = "postgres"
class GenerateSourcesResponse(BaseModel):
sources: list[dict[str, Any]]
source_count: int
def _map_column_type(db_type: str) -> str:
if _BOOLEAN_PATTERN.search(db_type):
return "boolean"
if _TIME_PATTERN.search(db_type):
return "time"
if _NUMBER_PATTERN.search(db_type):
return "number"
return "string"
def _build_table_ref(table: TableInput) -> str:
parts = []
if table.catalog:
parts.append(table.catalog)
if table.db:
parts.append(table.db)
parts.append(table.name)
return ".".join(parts)
def _generate_measures(
table_name: str,
columns: list[ColumnInput],
pk_columns: list[str],
) -> list[MeasureDefinition]:
measures: list[MeasureDefinition] = []
if pk_columns:
pk = pk_columns[0]
measures.append(
MeasureDefinition(
name="record_count",
expr=f"count({pk})",
description=f"Count of {table_name} records",
)
)
for col in columns:
if _map_column_type(col.type) != "number":
continue
if _ID_PATTERN.search(col.name):
continue
measures.append(
MeasureDefinition(
name=f"total_{col.name}",
expr=f"sum({col.name})",
description=f"Sum of {col.name}"
+ (f" \u2014 {col.comment}" if col.comment else ""),
)
)
measures.append(
MeasureDefinition(
name=f"avg_{col.name}",
expr=f"avg({col.name})",
description=f"Average of {col.name}"
+ (f" \u2014 {col.comment}" if col.comment else ""),
)
)
return measures
def generate_sources(request: GenerateSourcesRequest) -> list[dict[str, Any]]:
links_by_from: dict[str, list[LinkInput]] = {}
links_by_to: dict[str, list[LinkInput]] = {}
for link in request.links:
links_by_from.setdefault(link.from_table, []).append(link)
links_by_to.setdefault(link.to_table, []).append(link)
table_names = {table.name for table in request.tables}
sources: list[dict[str, Any]] = []
for table in request.tables:
pk_columns = [column.name for column in table.columns if column.primary_key]
grain = (
pk_columns
if pk_columns
else [table.columns[0].name]
if table.columns
else ["id"]
)
sl_columns: list[SourceColumn] = []
for column in table.columns:
sl_type = _map_column_type(column.type)
role = ColumnRole.TIME if sl_type == "time" else ColumnRole.DEFAULT
sl_columns.append(
SourceColumn(
name=column.name,
type=sl_type,
role=role,
description=column.comment,
)
)
joins: list[JoinDeclaration] = []
for link in links_by_from.get(table.name, []):
if link.to_table not in table_names:
logger.warning(
"Skipping link from %s.%s to %s.%s: target table not in scan",
link.from_table,
link.from_column,
link.to_table,
link.to_column,
)
continue
relationship = _RELATIONSHIP_MAP.get(link.relationship_type, "many_to_one")
joins.append(
JoinDeclaration(
to=link.to_table,
on=f"{link.from_column} = {link.to_table}.{link.to_column}",
relationship=relationship,
)
)
for link in links_by_to.get(table.name, []):
if link.from_table not in table_names:
logger.warning(
"Skipping reverse link from %s.%s to %s.%s: source table not in scan",
link.from_table,
link.from_column,
link.to_table,
link.to_column,
)
continue
forward_relationship = _RELATIONSHIP_MAP.get(
link.relationship_type, "many_to_one"
)
reverse_relationship = _RELATIONSHIP_INVERSE.get(
forward_relationship, "one_to_many"
)
joins.append(
JoinDeclaration(
to=link.from_table,
on=f"{link.to_column} = {link.from_table}.{link.from_column}",
relationship=reverse_relationship,
)
)
to_counts: dict[str, int] = {}
for join in joins:
to_counts[join.to] = to_counts.get(join.to, 0) + 1
if any(count > 1 for count in to_counts.values()):
for join in joins:
if to_counts[join.to] > 1:
fk_col = join.on.split(" = ")[0].strip().lower()
join.alias = f"{join.to}_{fk_col}"
source = SourceDefinition(
name=table.name,
description=table.comment,
table=_build_table_ref(table),
grain=grain,
columns=sl_columns,
joins=joins,
measures=_generate_measures(table.name, table.columns, pk_columns),
)
sources.append(source.model_dump(exclude_none=True))
logger.info("Generated %d klo-sl source definitions", len(sources))
return sources
def generate_sources_response(
request: GenerateSourcesRequest,
) -> GenerateSourcesResponse:
sources = generate_sources(request)
return GenerateSourcesResponse(sources=sources, source_count=len(sources))

View file

@ -0,0 +1,66 @@
from __future__ import annotations
from dataclasses import asdict
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
from semantic_layer.table_identifier_parser import (
ParseTableIdentifierItem as SharedParseTableIdentifierItem,
parse_table_identifier_batch,
)
ParseTableIdentifierReason = Literal[
"looker_template_unresolved",
"derived_table_not_supported",
"no_physical_table",
"multiple_table_references",
"unsupported_dialect",
"parse_error",
]
class ParseTableIdentifierItem(BaseModel):
key: str
sql_table_name: str
dialect: str
class ParseTableIdentifierBatchRequest(BaseModel):
items: list[ParseTableIdentifierItem]
class ParsedIdentifier(BaseModel):
model_config = ConfigDict(populate_by_name=True)
ok: bool
catalog: str | None = None
schema_: str | None = Field(default=None, alias="schema")
name: str | None = None
canonical_table: str | None = None
reason: ParseTableIdentifierReason | None = None
detail: str | None = None
class ParseTableIdentifierBatchResponse(BaseModel):
results: dict[str, ParsedIdentifier]
def parse_table_identifier_response(
request: ParseTableIdentifierBatchRequest,
) -> ParseTableIdentifierBatchResponse:
shared_results = parse_table_identifier_batch(
[
SharedParseTableIdentifierItem(
key=item.key,
sql_table_name=item.sql_table_name,
dialect=item.dialect,
)
for item in request.items
]
)
return ParseTableIdentifierBatchResponse(
results={
key: ParsedIdentifier.model_validate(asdict(value))
for key, value in shared_results.items()
}
)

View file

@ -0,0 +1,442 @@
from __future__ import annotations
from fastapi.testclient import TestClient
from klo_daemon.app import create_app
from klo_daemon.database_introspection import (
DatabaseIntrospectionResponse,
LiveDatabaseColumn,
LiveDatabaseTable,
)
ORDERS_SOURCE = {
"name": "orders",
"table": "public.orders",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "status", "type": "string"},
{"name": "amount", "type": "number"},
],
"joins": [],
"measures": [{"name": "order_count", "expr": "count(*)"}],
}
LOOKML_ORDER_VIEW = """
view: orders {
sql_table_name: public.orders ;;
dimension: id {
primary_key: yes
type: number
sql: ${TABLE}.id ;;
}
dimension: status {
type: string
sql: ${TABLE}.status ;;
}
measure: order_count {
type: count
}
}
"""
class FakeEmbeddingProvider:
name = "fake"
dimensions = 3
max_batch_size = 2
def __init__(self) -> None:
self.calls: list[list[str]] = []
def encode(self, texts: list[str]) -> list[list[float]]:
self.calls.append(list(texts))
return [
[float(len(text)), float(index), 1.0] for index, text in enumerate(texts)
]
def test_health_endpoint_returns_healthy() -> None:
client = TestClient(create_app())
response = client.get("/health")
assert response.status_code == 200
assert response.json() == {"status": "healthy"}
def test_database_introspect_endpoint_returns_snapshot() -> None:
calls = []
def fake_introspector(request):
calls.append(request)
return DatabaseIntrospectionResponse(
connection_id=request.connection_id,
extracted_at="2026-04-28T10:00:00+00:00",
metadata={"driver": request.driver, "schemas": request.schemas},
tables=[
LiveDatabaseTable(
catalog="warehouse",
db="public",
name="orders",
columns=[
LiveDatabaseColumn(
name="id",
type="integer",
nullable=False,
primary_key=True,
)
],
)
],
)
client = TestClient(create_app(database_introspector=fake_introspector))
response = client.post(
"/database/introspect",
json={
"connection_id": "warehouse",
"driver": "postgres",
"url": "postgresql://readonly@example.test/warehouse",
"schemas": ["public"],
},
)
assert response.status_code == 200
assert response.json()["connection_id"] == "warehouse"
assert response.json()["tables"][0]["name"] == "orders"
assert calls[0].connection_id == "warehouse"
def test_database_introspect_endpoint_maps_value_error_to_400() -> None:
def fake_introspector(request):
raise ValueError('database introspection supports only driver "postgres"')
client = TestClient(create_app(database_introspector=fake_introspector))
response = client.post(
"/database/introspect",
json={
"connection_id": "warehouse",
"driver": "snowflake",
"url": "snowflake://example",
},
)
assert response.status_code == 400
assert response.json() == {
"detail": 'database introspection supports only driver "postgres"'
}
def test_embedding_compute_endpoint_returns_embedding() -> None:
provider = FakeEmbeddingProvider()
client = TestClient(create_app(embedding_provider=provider))
response = client.post("/embeddings/compute", json={"text": "hello"})
assert response.status_code == 200
assert response.json() == {"embedding": [5.0, 0.0, 1.0]}
assert provider.calls == [["hello"]]
def test_embedding_compute_bulk_endpoint_returns_embeddings() -> None:
provider = FakeEmbeddingProvider()
client = TestClient(create_app(embedding_provider=provider))
response = client.post(
"/embeddings/compute-bulk",
json={"texts": ["one", "three"]},
)
assert response.status_code == 200
assert response.json() == {"embeddings": [[3.0, 0.0, 1.0], [5.0, 1.0, 1.0]]}
assert provider.calls == [["one", "three"]]
def test_embedding_compute_bulk_endpoint_maps_value_error_to_400() -> None:
provider = FakeEmbeddingProvider()
client = TestClient(create_app(embedding_provider=provider))
response = client.post(
"/embeddings/compute-bulk",
json={"texts": ["one", "two", "three"]},
)
assert response.status_code == 400
assert response.json() == {"detail": "Maximum 2 texts allowed per batch"}
assert provider.calls == []
def test_code_execute_endpoint_is_not_registered_by_default() -> None:
client = TestClient(create_app())
response = client.post("/code/execute", json={"code": "result = 7"})
assert response.status_code == 404
def test_code_execute_endpoint_returns_result_when_enabled() -> None:
client = TestClient(create_app(enable_code_execution=True))
response = client.post(
"/code/execute",
json={"code": 'print("ran")\nresult = {"value": 7}'},
)
assert response.status_code == 200
body = response.json()
assert body["result"] == {"value": 7}
assert body["console_output"] == "ran\n"
assert body["error"] is None
assert body["message"] is None
assert body["visualizations"] is None
assert "=== Console Output ===" in body["formatted_result"]
assert "=== Result ===" in body["formatted_result"]
def test_code_execute_endpoint_serializes_numpy_result_when_enabled() -> None:
client = TestClient(create_app(enable_code_execution=True))
response = client.post(
"/code/execute",
json={"code": "import numpy as np\nresult = {'value': np.float64(1.25)}"},
)
assert response.status_code == 200
body = response.json()
assert body["result"] == {"value": 1.25}
assert body["error"] is None
def test_code_execute_endpoint_uses_host_free_boundary_when_enabled() -> None:
client = TestClient(create_app(enable_code_execution=True))
response = client.post(
"/code/execute",
json={
"source_id": "chat_123",
"message_id": "message_456",
"code": (
"import pandas as pd\n"
"result = save_df_to_scratchpad(pd.DataFrame({'value': [1]}), 'out.json')"
),
},
headers={"Authorization": "Bearer should-not-forward"},
)
assert response.status_code == 200
body = response.json()
assert body["result"] is None
assert (
body["error"]
== "nest_api_url, Authorization header, and source_id are required for scratchpad operations"
)
assert "=== Error ===" in body["formatted_result"]
def test_sql_parse_table_identifier_endpoint() -> None:
client = TestClient(create_app())
response = client.post(
"/sql/parse-table-identifier",
json={
"items": [
{
"key": "orders",
"sql_table_name": "public.orders",
"dialect": "postgres",
},
{
"key": "template",
"sql_table_name": "${orders.SQL_TABLE_NAME}",
"dialect": "postgres",
},
]
},
)
assert response.status_code == 200
body = response.json()
assert body["results"]["orders"]["ok"] is True
assert body["results"]["orders"]["schema"] == "public"
assert body["results"]["orders"]["name"] == "orders"
assert body["results"]["template"]["ok"] is False
assert body["results"]["template"]["reason"] == "looker_template_unresolved"
def test_semantic_query_endpoint_returns_sql() -> None:
client = TestClient(create_app())
response = client.post(
"/semantic-layer/query",
json={
"sources": [ORDERS_SOURCE],
"dialect": "postgres",
"query": {
"measures": ["orders.order_count"],
"dimensions": ["orders.status"],
},
},
)
assert response.status_code == 200
body = response.json()
assert body["dialect"] == "postgres"
assert "public.orders" in body["sql"]
assert body["columns"][0]["name"] == "orders.status"
def test_semantic_query_endpoint_maps_value_error_to_400() -> None:
client = TestClient(create_app())
response = client.post(
"/semantic-layer/query",
json={
"sources": [ORDERS_SOURCE],
"dialect": "postgres",
"query": {
"measures": ["missing.order_count"],
"dimensions": [],
},
},
)
assert response.status_code == 400
assert "missing.order_count" in response.json()["detail"]
def test_semantic_validate_endpoint_returns_structured_validation() -> None:
client = TestClient(create_app())
invalid_source = {
**ORDERS_SOURCE,
"measures": [
{"name": "revenue", "expr": "sum(amount)"},
{"name": "revenue", "expr": "sum(amount)"},
],
}
response = client.post(
"/semantic-layer/validate",
json={"sources": [invalid_source], "dialect": "postgres"},
)
assert response.status_code == 200
body = response.json()
assert body["valid"] is False
assert any("Duplicate measure" in error for error in body["errors"])
assert body["warnings"] == []
assert body["per_source_warnings"] == {}
def test_semantic_generate_sources_endpoint_returns_sources() -> None:
client = TestClient(create_app())
response = client.post(
"/semantic-layer/generate-sources",
json={
"tables": [
{
"name": "orders",
"db": "public",
"comment": "Orders table",
"columns": [
{
"name": "id",
"type": "integer",
"primary_key": True,
"nullable": False,
"comment": "Order ID",
},
{"name": "customer_id", "type": "integer"},
{
"name": "amount",
"type": "decimal",
"comment": "Order amount",
},
],
},
{
"name": "customers",
"db": "public",
"columns": [
{"name": "id", "type": "integer", "primary_key": True},
{"name": "email", "type": "varchar"},
],
},
],
"links": [
{
"from_table": "orders",
"from_column": "customer_id",
"to_table": "customers",
"to_column": "id",
"relationship_type": "MANY_TO_ONE",
}
],
"dialect": "postgres",
},
)
assert response.status_code == 200
body = response.json()
assert body["source_count"] == 2
sources = {source["name"]: source for source in body["sources"]}
assert sources["orders"]["table"] == "public.orders"
assert sources["orders"]["description"] == "Orders table"
assert sources["orders"]["grain"] == ["id"]
assert sources["orders"]["joins"] == [
{
"to": "customers",
"on": "customer_id = customers.id",
"relationship": "many_to_one",
}
]
assert [measure["name"] for measure in sources["orders"]["measures"]] == [
"record_count",
"total_amount",
"avg_amount",
]
def test_lookml_parse_endpoint_returns_resolved_views() -> None:
client = TestClient(create_app())
response = client.post(
"/lookml/parse",
json={
"files": [
{
"path": "views/orders.view.lkml",
"content": LOOKML_ORDER_VIEW,
}
],
"dialect": "postgres",
},
)
assert response.status_code == 200
body = response.json()
assert body["joins"] == []
assert body["skipped_views"] == []
assert body["warnings"] == []
assert len(body["views"]) == 1
view = body["views"][0]
assert view["name"] == "orders"
assert view["source_type"] == "table"
assert view["table_ref"] == "public.orders"
assert view["grain"] == ["id"]
assert [column["name"] for column in view["columns"]] == ["id", "status"]
assert view["measures"] == [
{
"name": "order_count",
"expr": "count(*)",
"filter": None,
"description": None,
}
]

View file

@ -0,0 +1,426 @@
from __future__ import annotations
import io
import json
import os
import subprocess
import sys
from pathlib import Path
from typing import Any
ORDERS_SOURCE = {
"name": "orders",
"table": "public.orders",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "status", "type": "string"},
{"name": "amount", "type": "number"},
],
"joins": [],
"measures": [{"name": "order_count", "expr": "count(*)"}],
}
def run_daemon_command(
command: str, payload: dict[str, object]
) -> subprocess.CompletedProcess[str]:
env = os.environ.copy()
src_path = str(Path(__file__).resolve().parents[1] / "src")
env["PYTHONPATH"] = src_path + os.pathsep + env.get("PYTHONPATH", "")
return subprocess.run(
[sys.executable, "-m", "klo_daemon", command],
input=json.dumps(payload),
text=True,
capture_output=True,
check=False,
env=env,
)
def test_semantic_query_command_reads_stdin_and_writes_json() -> None:
result = run_daemon_command(
"semantic-query",
{
"sources": [ORDERS_SOURCE],
"dialect": "postgres",
"query": {
"measures": ["orders.order_count"],
"dimensions": ["orders.status"],
},
},
)
assert result.returncode == 0, result.stderr
parsed = json.loads(result.stdout)
assert "public.orders" in parsed["sql"]
assert parsed["columns"][0]["name"] == "orders.status"
def test_semantic_validate_command_reads_stdin_and_writes_json() -> None:
result = run_daemon_command(
"semantic-validate",
{"sources": [ORDERS_SOURCE], "dialect": "postgres"},
)
assert result.returncode == 0, result.stderr
parsed = json.loads(result.stdout)
assert parsed == {
"valid": True,
"errors": [],
"warnings": [],
"per_source_warnings": {},
}
def test_command_returns_nonzero_for_invalid_json() -> None:
env = os.environ.copy()
src_path = str(Path(__file__).resolve().parents[1] / "src")
env["PYTHONPATH"] = src_path + os.pathsep + env.get("PYTHONPATH", "")
result = subprocess.run(
[sys.executable, "-m", "klo_daemon", "semantic-query"],
input="{",
text=True,
capture_output=True,
check=False,
env=env,
)
assert result.returncode == 1
assert "Expecting property name enclosed in double quotes" in result.stderr
def test_serve_http_command_starts_uvicorn_without_reading_stdin(
monkeypatch,
) -> None:
from klo_daemon import __main__ as daemon_main
calls: list[dict[str, object]] = []
class FailingStdin:
def read(self) -> str:
raise AssertionError("serve-http must not read stdin JSON")
def fake_run_http_server(
*,
host: str,
port: int,
log_level: str,
enable_code_execution: bool,
) -> None:
calls.append(
{
"host": host,
"port": port,
"log_level": log_level,
"enable_code_execution": enable_code_execution,
}
)
monkeypatch.setattr(sys, "stdin", FailingStdin())
monkeypatch.setattr(daemon_main, "run_http_server", fake_run_http_server)
assert (
daemon_main.main(
[
"serve-http",
"--host",
"127.0.0.1",
"--port",
"9191",
"--log-level",
"warning",
]
)
== 0
)
assert calls == [
{
"host": "127.0.0.1",
"port": 9191,
"log_level": "warning",
"enable_code_execution": False,
}
]
def test_serve_http_command_defaults_to_loopback(monkeypatch) -> None:
from klo_daemon import __main__ as daemon_main
calls: list[dict[str, object]] = []
def fake_run_http_server(
*,
host: str,
port: int,
log_level: str,
enable_code_execution: bool,
) -> None:
calls.append(
{
"host": host,
"port": port,
"log_level": log_level,
"enable_code_execution": enable_code_execution,
}
)
monkeypatch.setattr(daemon_main, "run_http_server", fake_run_http_server)
assert daemon_main.main(["serve-http"]) == 0
assert calls == [
{
"host": "127.0.0.1",
"port": 8765,
"log_level": "info",
"enable_code_execution": False,
}
]
def test_serve_http_command_can_enable_code_execution(monkeypatch) -> None:
from klo_daemon import __main__ as daemon_main
calls: list[dict[str, object]] = []
def fake_run_http_server(
*,
host: str,
port: int,
log_level: str,
enable_code_execution: bool,
) -> None:
calls.append(
{
"host": host,
"port": port,
"log_level": log_level,
"enable_code_execution": enable_code_execution,
}
)
monkeypatch.setattr(daemon_main, "run_http_server", fake_run_http_server)
assert daemon_main.main(["serve-http", "--enable-code-execution"]) == 0
assert calls == [
{
"host": "127.0.0.1",
"port": 8765,
"log_level": "info",
"enable_code_execution": True,
}
]
def test_lookml_parse_command_reads_stdin_and_writes_json() -> None:
result = run_daemon_command(
"lookml-parse",
{
"files": [
{
"path": "views/orders.view.lkml",
"content": """
view: orders {
sql_table_name: public.orders ;;
dimension: id {
primary_key: yes
type: number
sql: ${TABLE}.id ;;
}
measure: order_count {
type: count
}
}
""",
}
],
"dialect": "postgres",
},
)
assert result.returncode == 0, result.stderr
parsed = json.loads(result.stdout)
assert parsed["views"][0]["name"] == "orders"
assert parsed["views"][0]["table_ref"] == "public.orders"
assert parsed["views"][0]["measures"][0]["expr"] == "count(*)"
assert parsed["joins"] == []
assert parsed["skipped_views"] == []
assert parsed["warnings"] == []
def test_semantic_generate_sources_command_reads_stdin_and_writes_json() -> None:
result = run_daemon_command(
"semantic-generate-sources",
{
"tables": [
{
"name": "orders",
"db": "public",
"columns": [
{"name": "id", "type": "integer", "primary_key": True},
{"name": "amount", "type": "decimal"},
],
}
],
"links": [],
"dialect": "postgres",
},
)
assert result.returncode == 0, result.stderr
parsed = json.loads(result.stdout)
assert parsed["source_count"] == 1
assert parsed["sources"][0]["name"] == "orders"
assert parsed["sources"][0]["table"] == "public.orders"
assert parsed["sources"][0]["measures"] == [
{
"name": "record_count",
"expr": "count(id)",
"segments": [],
"description": "Count of orders records",
},
{
"name": "total_amount",
"expr": "sum(amount)",
"segments": [],
"description": "Sum of amount",
},
{
"name": "avg_amount",
"expr": "avg(amount)",
"segments": [],
"description": "Average of amount",
},
]
def test_database_introspect_command_reads_stdin_and_writes_json(
monkeypatch, capsys
) -> None:
from klo_daemon import __main__ as daemon_main
from klo_daemon.database_introspection import (
DatabaseIntrospectionResponse,
LiveDatabaseColumn,
LiveDatabaseTable,
)
def fake_introspect(request):
assert request.connection_id == "warehouse"
assert request.driver == "postgres"
assert request.schemas == ["public"]
return DatabaseIntrospectionResponse(
connection_id="warehouse",
extracted_at="2026-04-28T10:00:00+00:00",
metadata={"driver": "postgres", "schemas": ["public"]},
tables=[
LiveDatabaseTable(
catalog="warehouse",
db="public",
name="orders",
columns=[
LiveDatabaseColumn(
name="id",
type="integer",
nullable=False,
primary_key=True,
)
],
)
],
)
monkeypatch.setattr(daemon_main, "introspect_database_response", fake_introspect)
monkeypatch.setattr(
sys,
"stdin",
io.StringIO(
'{"connection_id":"warehouse","driver":"postgres","url":"postgresql://readonly@example.test/warehouse","schemas":["public"]}'
),
)
assert daemon_main.main(["database-introspect"]) == 0
captured = capsys.readouterr()
parsed = json.loads(captured.out)
assert parsed["connection_id"] == "warehouse"
assert parsed["metadata"] == {"driver": "postgres", "schemas": ["public"]}
assert parsed["tables"][0]["name"] == "orders"
assert captured.err == ""
def test_embedding_compute_command_reads_stdin_and_writes_json(
monkeypatch, capsys
) -> None:
from klo_daemon import __main__ as daemon_main
from klo_daemon.embeddings import ComputeEmbeddingResponse
def fake_compute(request):
assert request.text == "hello"
return ComputeEmbeddingResponse(embedding=[1.0, 2.0, 3.0])
monkeypatch.setattr(daemon_main, "compute_embedding_response", fake_compute)
monkeypatch.setattr(sys, "stdin", io.StringIO('{"text": "hello"}'))
assert daemon_main.main(["embedding-compute"]) == 0
captured = capsys.readouterr()
assert json.loads(captured.out) == {"embedding": [1.0, 2.0, 3.0]}
assert captured.err == ""
def test_embedding_compute_bulk_command_reads_stdin_and_writes_json(
monkeypatch, capsys
) -> None:
from klo_daemon import __main__ as daemon_main
from klo_daemon.embeddings import ComputeEmbeddingBulkResponse
def fake_compute(request):
assert request.texts == ["hello", "world"]
return ComputeEmbeddingBulkResponse(embeddings=[[1.0, 2.0], [3.0, 4.0]])
monkeypatch.setattr(daemon_main, "compute_embedding_bulk_response", fake_compute)
monkeypatch.setattr(sys, "stdin", io.StringIO('{"texts": ["hello", "world"]}'))
assert daemon_main.main(["embedding-compute-bulk"]) == 0
captured = capsys.readouterr()
assert json.loads(captured.out) == {"embeddings": [[1.0, 2.0], [3.0, 4.0]]}
assert captured.err == ""
def test_code_execute_command_reads_stdin_and_writes_json(monkeypatch, capsys) -> None:
from klo_daemon import __main__ as daemon_main
from klo_daemon.code_execution import ExecuteCodeResponse
calls: list[dict[str, Any]] = []
def fake_execute(request, *, nest_api_url, auth_header):
calls.append(
{
"request": request,
"nest_api_url": nest_api_url,
"auth_header": auth_header,
}
)
return ExecuteCodeResponse(
formatted_result="\n\n=== Result ===\n\n7",
result=7,
)
monkeypatch.setattr(daemon_main, "execute_code_response", fake_execute)
monkeypatch.setattr(sys, "stdin", io.StringIO('{"code": "result = 7"}'))
assert daemon_main.main(["code-execute"]) == 0
captured = capsys.readouterr()
assert json.loads(captured.out) == {
"formatted_result": "\n\n=== Result ===\n\n7",
"result": 7,
"console_output": None,
"error": None,
"message": None,
"visualizations": None,
}
assert captured.err == ""
assert calls[0]["request"].code == "result = 7"
assert calls[0]["nest_api_url"] is None
assert calls[0]["auth_header"] is None

View file

@ -0,0 +1,210 @@
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any
import numpy as np
import orjson
import pandas as pd
import pytest
from klo_daemon.code_execution import (
ExecuteCodeRequest,
create_scratchpad_helpers,
detect_visualizations,
dumps_numpy_json,
execute_code_response,
)
@dataclass
class FakeResponse:
json_payload: dict[str, Any] | None = None
content: bytes = b""
headers: dict[str, str] | None = None
def raise_for_status(self) -> None:
return None
def json(self) -> dict[str, Any]:
return self.json_payload or {}
class FakeHttpClient:
def __init__(self) -> None:
self.posts: list[dict[str, Any]] = []
self.gets: list[dict[str, Any]] = []
def post(
self,
url: str,
data: bytes,
headers: dict[str, str],
timeout: int,
) -> FakeResponse:
self.posts.append(
{
"url": url,
"data": orjson.loads(data),
"headers": headers,
"timeout": timeout,
}
)
return FakeResponse(json_payload={"filename": "saved.json"})
def get(
self,
url: str,
headers: dict[str, str],
timeout: int,
) -> FakeResponse:
self.gets.append({"url": url, "headers": headers, "timeout": timeout})
return FakeResponse(
content=b"value,name\n1.25,alpha\n",
headers={"content-type": "text/csv; charset=utf-8"},
)
def test_execute_code_response_captures_console_result_and_strips_ansi() -> None:
response = execute_code_response(
ExecuteCodeRequest(
code='print("\\x1b[31mhello\\x1b[0m")\nresult = {"value": 3}',
),
nest_api_url=None,
auth_header=None,
)
assert response.result == {"value": 3}
assert response.console_output == "\x1b[31mhello\x1b[0m\n"
assert "=== Console Output ===" in response.formatted_result
assert "hello" in response.formatted_result
assert "\x1b" not in response.formatted_result
assert "=== Result ===" in response.formatted_result
def test_execute_code_response_returns_message_when_result_is_absent() -> None:
response = execute_code_response(
ExecuteCodeRequest(code='print("ran")'),
nest_api_url=None,
auth_header=None,
)
assert response.result is None
assert (
response.message == "Code executed successfully but no result variable was set"
)
assert response.console_output == "ran\n"
assert "=== Message ===" in response.formatted_result
def test_execute_code_response_detects_visualization_records() -> None:
response = execute_code_response(
ExecuteCodeRequest(
code="result = "
+ json.dumps(
{
"type": "visualization",
"vis_type": "bar",
"config": {"title": "Revenue"},
"data": [{"month": "Jan", "revenue": 10}],
"title": "Revenue",
}
),
),
nest_api_url=None,
auth_header=None,
)
assert response.visualizations is not None
assert len(response.visualizations) == 1
assert response.visualizations[0].vis_type == "bar"
assert response.visualizations[0].title == "Revenue"
def test_detect_visualizations_filters_mixed_lists() -> None:
visualizations = detect_visualizations(
[
{"type": "note", "text": "skip"},
{
"type": "visualization",
"vis_type": "table",
"config": {"title": "Rows"},
"data": [{"row": 1}],
},
]
)
assert visualizations == [
{
"type": "visualization",
"vis_type": "table",
"config": {"title": "Rows"},
"data": [{"row": 1}],
}
]
def test_scratchpad_and_visualization_helpers_serialize_numpy_scalars() -> None:
client = FakeHttpClient()
save_df, read_file, save_viz = create_scratchpad_helpers(
nest_api_url="http://nest",
auth_header="Bearer token",
source_id="source_123",
message_id="message_456",
http_client=client,
)
df = pd.DataFrame({"value": [np.float64(1.25)]})
assert save_df(df, filename="df.json") == "1 rows saved to saved.json"
read_df = read_file("input.csv")
assert read_df.to_dict(orient="records") == [{"value": 1.25, "name": "alpha"}]
viz_ref = save_viz(
vis_type="bar",
config={"title": "Test", "x": "a", "y": np.float64(2.5)},
data=[{"a": "row1", "b": np.float64(3.75)}],
)
assert viz_ref == "![viz](saved.json)"
assert (
client.posts[0]["url"] == "http://nest/private_api/scratchpad/source_123/files"
)
assert client.posts[0]["data"]["data"][0]["value"] == 1.25
assert (
client.gets[0]["url"]
== "http://nest/private_api/scratchpad/source_123/files/input.csv?format=raw"
)
assert client.posts[1]["url"] == "http://nest/private_api/visualizations/source_123"
assert client.posts[1]["data"]["config"]["y"] == 2.5
assert client.posts[1]["data"]["data"][0]["b"] == 3.75
def test_scratchpad_helpers_require_app_context_only_when_called() -> None:
save_df, read_file, save_viz = create_scratchpad_helpers(
nest_api_url=None,
auth_header=None,
source_id=None,
message_id=None,
)
with pytest.raises(ValueError, match="required for scratchpad operations"):
save_df(pd.DataFrame({"value": [1]}), filename="df.json")
with pytest.raises(ValueError, match="required for scratchpad operations"):
read_file("df.csv")
with pytest.raises(ValueError, match="required for visualization operations"):
save_viz("bar", {"title": "Chart"}, [{"value": 1}])
def test_dumps_numpy_json_serializes_numpy_values() -> None:
rendered = dumps_numpy_json(
{
"scalar": np.float64(1.5),
"array": np.array([1, 2, 3]),
}
)
assert orjson.loads(rendered) == {"scalar": 1.5, "array": [1, 2, 3]}

View file

@ -0,0 +1,153 @@
from __future__ import annotations
import pytest
from klo_daemon.database_introspection import (
DatabaseIntrospectionRequest,
DatabaseIntrospectionRows,
_statement_timeout_config,
introspect_database_response,
)
def test_introspect_database_response_maps_postgres_catalog_rows() -> None:
def fake_load_rows(
request: DatabaseIntrospectionRequest,
) -> DatabaseIntrospectionRows:
assert request.connection_id == "warehouse"
assert request.driver == "postgres"
assert request.schemas == ["public"]
return DatabaseIntrospectionRows(
table_rows=[
{
"table_catalog": "warehouse",
"table_schema": "public",
"table_name": "customers",
"table_comment": None,
},
{
"table_catalog": "warehouse",
"table_schema": "public",
"table_name": "orders",
"table_comment": "Orders table",
},
],
column_rows=[
{
"table_catalog": "warehouse",
"table_schema": "public",
"table_name": "orders",
"column_name": "id",
"formatted_type": "integer",
"is_nullable": False,
"is_primary_key": True,
"column_comment": "Order ID",
},
{
"table_catalog": "warehouse",
"table_schema": "public",
"table_name": "orders",
"column_name": "customer_id",
"formatted_type": "integer",
"is_nullable": False,
"is_primary_key": False,
"column_comment": None,
},
{
"table_catalog": "warehouse",
"table_schema": "public",
"table_name": "customers",
"column_name": "id",
"formatted_type": "integer",
"is_nullable": False,
"is_primary_key": True,
"column_comment": None,
},
],
foreign_key_rows=[
{
"table_catalog": "warehouse",
"table_schema": "public",
"table_name": "orders",
"from_column": "customer_id",
"to_table": "customers",
"to_column": "id",
"constraint_name": "orders_customer_id_fkey",
}
],
)
response = introspect_database_response(
DatabaseIntrospectionRequest(
connection_id="warehouse",
driver="postgres",
url="postgresql://readonly@example.test/warehouse",
schemas=["public"],
),
load_rows=fake_load_rows,
now=lambda: "2026-04-28T10:00:00+00:00",
)
assert response.connection_id == "warehouse"
assert response.extracted_at == "2026-04-28T10:00:00+00:00"
assert response.metadata == {"driver": "postgres", "schemas": ["public"]}
assert [table.name for table in response.tables] == ["customers", "orders"]
orders = response.tables[1]
assert orders.model_dump(exclude_none=True) == {
"catalog": "warehouse",
"db": "public",
"name": "orders",
"comment": "Orders table",
"columns": [
{
"name": "id",
"type": "integer",
"nullable": False,
"primary_key": True,
"comment": "Order ID",
},
{
"name": "customer_id",
"type": "integer",
"nullable": False,
"primary_key": False,
},
],
"foreign_keys": [
{
"from_column": "customer_id",
"to_table": "customers",
"to_column": "id",
"constraint_name": "orders_customer_id_fkey",
}
],
}
def test_introspect_database_response_rejects_non_postgres_driver() -> None:
with pytest.raises(ValueError, match='supports only driver "postgres"'):
introspect_database_response(
DatabaseIntrospectionRequest(
connection_id="warehouse",
driver="snowflake",
url="snowflake://example",
),
load_rows=lambda request: DatabaseIntrospectionRows([], [], []),
)
def test_database_introspection_request_rejects_empty_schema_list() -> None:
with pytest.raises(ValueError, match="at least one schema"):
DatabaseIntrospectionRequest(
connection_id="warehouse",
driver="postgres",
url="postgresql://readonly@example.test/warehouse",
schemas=[],
)
def test_statement_timeout_config_uses_parameterized_set_config() -> None:
assert _statement_timeout_config(30_000) == (
"SELECT set_config('statement_timeout', %s, true)",
("30000ms",),
)

View file

@ -0,0 +1,107 @@
from __future__ import annotations
import pytest
from klo_daemon.embeddings import (
ComputeEmbeddingBulkRequest,
ComputeEmbeddingRequest,
SentenceTransformersEmbeddingProvider,
compute_embedding_bulk_response,
compute_embedding_response,
)
class FakeEmbeddingProvider:
name = "fake"
dimensions = 3
max_batch_size = 2
def __init__(self) -> None:
self.calls: list[list[str]] = []
def encode(self, texts: list[str]) -> list[list[float]]:
self.calls.append(list(texts))
return [
[float(len(text)), float(index), 1.0] for index, text in enumerate(texts)
]
class ArrayLike:
def __init__(self, value: list[float] | list[list[float]]) -> None:
self.value = value
def tolist(self) -> list[float] | list[list[float]]:
return self.value
class FakeSentenceTransformerModel:
def __init__(self) -> None:
self.calls: list[str | list[str]] = []
def encode(self, value: str | list[str]) -> ArrayLike:
self.calls.append(value)
if isinstance(value, str):
return ArrayLike([0.1, 0.2, 0.3])
return ArrayLike(
[[float(index), float(len(text)), 0.5] for index, text in enumerate(value)]
)
def test_compute_embedding_response_uses_injected_provider() -> None:
provider = FakeEmbeddingProvider()
response = compute_embedding_response(
ComputeEmbeddingRequest(text="hello"),
provider=provider,
)
assert response.embedding == [5.0, 0.0, 1.0]
assert provider.calls == [["hello"]]
def test_compute_embedding_bulk_response_uses_injected_provider() -> None:
provider = FakeEmbeddingProvider()
response = compute_embedding_bulk_response(
ComputeEmbeddingBulkRequest(texts=["one", "three"]),
provider=provider,
)
assert response.embeddings == [[3.0, 0.0, 1.0], [5.0, 1.0, 1.0]]
assert provider.calls == [["one", "three"]]
def test_compute_embedding_bulk_rejects_empty_texts() -> None:
provider = FakeEmbeddingProvider()
with pytest.raises(ValueError, match="Empty texts found at indices: 1"):
compute_embedding_bulk_response(
ComputeEmbeddingBulkRequest(texts=["valid", " "]),
provider=provider,
)
assert provider.calls == []
def test_compute_embedding_bulk_respects_provider_batch_size() -> None:
provider = FakeEmbeddingProvider()
with pytest.raises(ValueError, match="Maximum 2 texts allowed per batch"):
compute_embedding_bulk_response(
ComputeEmbeddingBulkRequest(texts=["one", "two", "three"]),
provider=provider,
)
assert provider.calls == []
def test_sentence_transformers_provider_normalizes_single_and_bulk_outputs() -> None:
model = FakeSentenceTransformerModel()
provider = SentenceTransformersEmbeddingProvider(model=model)
assert provider.encode(["hello"]) == [[0.1, 0.2, 0.3]]
assert provider.encode(["one", "three"]) == [
[0.0, 3.0, 0.5],
[1.0, 5.0, 0.5],
]
assert model.calls == ["hello", ["one", "three"]]

View file

@ -0,0 +1,134 @@
from __future__ import annotations
from klo_daemon.lookml import (
LookMLFileInput,
ParseLookMLRequest,
parse_lookml_project,
)
ORDER_VIEW = """
view: orders {
sql_table_name: public.orders ;;
dimension: id {
primary_key: yes
type: number
sql: ${TABLE}.id ;;
}
dimension: user_id {
type: number
sql: ${TABLE}.user_id ;;
}
dimension: status {
type: string
sql: ${TABLE}.status ;;
}
measure: order_count {
type: count
}
measure: revenue {
type: sum
sql: ${TABLE}.amount ;;
}
}
"""
USER_VIEW = """
view: users {
sql_table_name: public.users ;;
dimension: id {
primary_key: yes
type: number
sql: ${TABLE}.id ;;
}
}
"""
ORDER_MODEL = """
explore: orders {
join: users {
relationship: many_to_one
sql_on: ${orders.user_id} = ${users.id} ;;
}
}
"""
DERIVED_VIEW = """
view: order_rollup {
derived_table: {
sql:
SELECT status, SUM(amount) AS total_amount
FROM public.orders
GROUP BY status ;;
}
dimension: status {
type: string
sql: ${TABLE}.status ;;
}
}
"""
def test_parse_lookml_project_returns_views_and_joins() -> None:
response = parse_lookml_project(
ParseLookMLRequest(
files=[
LookMLFileInput(path="views/orders.view.lkml", content=ORDER_VIEW),
LookMLFileInput(path="views/users.view.lkml", content=USER_VIEW),
LookMLFileInput(
path="models/ecommerce.model.lkml", content=ORDER_MODEL
),
],
dialect="postgres",
)
)
views = {view.name: view for view in response.views}
assert sorted(views) == ["orders", "users"]
assert views["orders"].source_type == "table"
assert views["orders"].table_ref == "public.orders"
assert views["orders"].grain == ["id"]
assert [measure.name for measure in views["orders"].measures] == [
"order_count",
"revenue",
]
assert views["orders"].measures[0].expr == "count(*)"
assert views["orders"].measures[1].expr == "sum(amount)"
assert response.joins[0].source_view == "orders"
assert response.joins[0].to == "users"
assert response.joins[0].relationship == "many_to_one"
assert response.joins[0].on == "orders.user_id = users.id"
assert response.skipped_views == []
assert response.warnings == []
def test_parse_lookml_project_extracts_derived_table_columns() -> None:
response = parse_lookml_project(
ParseLookMLRequest(
files=[
LookMLFileInput(
path="views/order_rollup.view.lkml", content=DERIVED_VIEW
)
],
dialect="postgres",
)
)
assert len(response.views) == 1
view = response.views[0]
assert view.name == "order_rollup"
assert view.source_type == "sql"
assert "SELECT status, SUM(amount) AS total_amount" in (view.sql or "")
assert [column.name for column in view.columns] == ["status", "total_amount"]
assert response.skipped_views == []
assert response.warnings == []

View file

@ -0,0 +1,6 @@
from klo_daemon import PACKAGE_NAME, VERSION
def test_package_metadata() -> None:
assert PACKAGE_NAME == "klo-daemon"
assert VERSION == "0.1.0"

View file

@ -0,0 +1,64 @@
from __future__ import annotations
from klo_daemon.semantic_layer import (
SemanticLayerQueryRequest,
ValidateSourcesRequest,
query_semantic_layer,
validate_semantic_layer,
)
ORDERS_SOURCE = {
"name": "orders",
"table": "public.orders",
"grain": ["id"],
"columns": [
{"name": "id", "type": "number"},
{"name": "status", "type": "string"},
{"name": "amount", "type": "number"},
],
"joins": [],
"measures": [
{"name": "order_count", "expr": "count(*)"},
{"name": "revenue", "expr": "sum(amount)"},
],
}
def test_query_semantic_layer_generates_sql_and_plan() -> None:
response = query_semantic_layer(
SemanticLayerQueryRequest(
sources=[ORDERS_SOURCE],
dialect="postgres",
query={
"measures": ["orders.order_count"],
"dimensions": ["orders.status"],
"limit": 25,
},
)
)
assert response.dialect == "postgres"
assert "public.orders" in response.sql
assert "orders.status" in response.sql
assert response.columns[0]["name"] == "orders.status"
assert response.columns[1]["name"] == "orders.order_count"
assert response.plan["sources_used"] == ["orders"]
def test_validate_semantic_layer_reports_duplicate_measure_names() -> None:
invalid_source = {
**ORDERS_SOURCE,
"measures": [
{"name": "revenue", "expr": "sum(amount)"},
{"name": "revenue", "expr": "sum(amount)"},
],
}
response = validate_semantic_layer(
ValidateSourcesRequest(sources=[invalid_source], dialect="postgres")
)
assert response.valid is False
assert any("Duplicate measure" in error for error in response.errors)
assert response.warnings == []

View file

@ -0,0 +1,161 @@
from __future__ import annotations
from klo_daemon.source_generation import (
ColumnInput,
GenerateSourcesRequest,
LinkInput,
TableInput,
generate_sources,
generate_sources_response,
)
def test_generate_sources_maps_tables_columns_measures_and_joins() -> None:
response = generate_sources_response(
GenerateSourcesRequest(
tables=[
TableInput(
name="orders",
db="public",
comment="Orders table",
columns=[
ColumnInput(
name="id",
type="integer",
primary_key=True,
nullable=False,
comment="Order ID",
),
ColumnInput(name="customer_id", type="integer"),
ColumnInput(
name="amount", type="decimal", comment="Order amount"
),
ColumnInput(name="created_at", type="timestamp"),
ColumnInput(name="status", type="varchar"),
],
),
TableInput(
name="customers",
db="public",
columns=[
ColumnInput(name="id", type="integer", primary_key=True),
ColumnInput(name="email", type="varchar"),
],
),
],
links=[
LinkInput(
from_table="orders",
from_column="customer_id",
to_table="customers",
to_column="id",
relationship_type="MANY_TO_ONE",
)
],
)
)
assert response.source_count == 2
sources = {source["name"]: source for source in response.sources}
assert sources["orders"]["description"] == "Orders table"
assert sources["orders"]["table"] == "public.orders"
assert sources["orders"]["grain"] == ["id"]
assert sources["orders"]["columns"] == [
{
"name": "id",
"type": "number",
"visibility": "public",
"role": "default",
"description": "Order ID",
},
{
"name": "customer_id",
"type": "number",
"visibility": "public",
"role": "default",
},
{
"name": "amount",
"type": "number",
"visibility": "public",
"role": "default",
"description": "Order amount",
},
{"name": "created_at", "type": "time", "visibility": "public", "role": "time"},
{"name": "status", "type": "string", "visibility": "public", "role": "default"},
]
assert sources["orders"]["joins"] == [
{
"to": "customers",
"on": "customer_id = customers.id",
"relationship": "many_to_one",
}
]
assert [measure["name"] for measure in sources["orders"]["measures"]] == [
"record_count",
"total_amount",
"avg_amount",
]
assert sources["orders"]["measures"][0]["expr"] == "count(id)"
assert sources["orders"]["measures"][1]["expr"] == "sum(amount)"
assert sources["orders"]["measures"][2]["expr"] == "avg(amount)"
assert sources["customers"]["joins"] == [
{
"to": "orders",
"on": "id = orders.customer_id",
"relationship": "one_to_many",
}
]
def test_generate_sources_aliases_multiple_joins_to_same_table() -> None:
sources = generate_sources(
GenerateSourcesRequest(
tables=[
TableInput(
name="orders",
columns=[
ColumnInput(name="id", type="integer", primary_key=True),
ColumnInput(name="buyer_id", type="integer"),
ColumnInput(name="seller_id", type="integer"),
],
),
TableInput(
name="users",
columns=[ColumnInput(name="id", type="integer", primary_key=True)],
),
],
links=[
LinkInput(
from_table="orders",
from_column="buyer_id",
to_table="users",
to_column="id",
relationship_type="many_to_one",
),
LinkInput(
from_table="orders",
from_column="seller_id",
to_table="users",
to_column="id",
relationship_type="many_to_one",
),
],
)
)
orders = next(source for source in sources if source["name"] == "orders")
assert orders["joins"] == [
{
"to": "users",
"on": "buyer_id = users.id",
"relationship": "many_to_one",
"alias": "users_buyer_id",
},
{
"to": "users",
"on": "seller_id = users.id",
"relationship": "many_to_one",
"alias": "users_seller_id",
},
]

161
python/klo-sl/AGENTS.md Normal file
View file

@ -0,0 +1,161 @@
# Semantic Layer Engine
Python semantic layer that generates SQL from structured JSON queries. No `from` clause — sources are inferred from fully-qualified field names (`source.column`).
## Quick Start
```bash
uv run pytest -q # run all tests
uv run python -m semantic_layer.cli --help
```
## Testing Corner Cases via CLI
Use `--model` to pass a self-contained YAML model (list of source definitions) instead of a directory. This lets you test any join topology or edge case without creating files.
### 1. Create an inline model file
```yaml
# /tmp/model.yaml — a YAML list of source definitions
- name: orders
table: public.orders
grain: [id]
columns:
- {name: id, type: number}
- {name: amount, type: number}
- {name: status, type: string}
joins:
- to: customers
"on": "customer_id = customers.id"
relationship: many_to_one
measures:
- {name: revenue, expr: "sum(amount)", filter: "status != 'refunded'"}
- name: customers
table: public.customers
grain: [id]
columns:
- {name: id, type: number}
- {name: segment, type: string}
```
### 2. Run queries against it
```bash
# Basic query
uv run python -m semantic_layer.cli --model /tmp/model.yaml \
-q '{"measures":["sum(orders.amount)"],"dimensions":["customers.segment"]}'
# Pre-defined measure + filter
uv run python -m semantic_layer.cli --model /tmp/model.yaml \
-q '{"measures":["orders.revenue"],"dimensions":["orders.status"],"filters":["orders.status != '"'"'cancelled'"'"'"]}'
# Show resolved plan alongside SQL
uv run python -m semantic_layer.cli --model /tmp/model.yaml \
-q '{"measures":["orders.revenue"],"dimensions":["customers.segment"]}' --plan
# Validate without generating SQL
uv run python -m semantic_layer.cli --model /tmp/model.yaml \
-q '{"measures":["orders.revenue"],"dimensions":["customers.segment"]}' --suggest
```
### 3. Test fan-out / chasm traps
Add multiple measure sources that fan out from a shared dimension hub:
```yaml
# Two independent fact tables joining to the same dimension
- name: hub
table: public.hub
grain: [id]
columns: [{name: id, type: number}, {name: segment, type: string}]
- name: fact_a
table: public.fact_a
grain: [id]
columns: [{name: id, type: number}, {name: hub_id, type: number}, {name: val, type: number}]
joins: [{to: hub, "on": "hub_id = hub.id", relationship: many_to_one}]
- name: fact_b
table: public.fact_b
grain: [id]
columns: [{name: id, type: number}, {name: hub_id, type: number}, {name: val, type: number}]
joins: [{to: hub, "on": "hub_id = hub.id", relationship: many_to_one}]
```
```bash
# This triggers aggregate locality (separate CTEs per fact table, FULL JOIN)
uv run python -m semantic_layer.cli --model /tmp/chasm.yaml \
-q '{"measures":["sum(fact_a.val)","sum(fact_b.val)"],"dimensions":["hub.segment"]}'
```
### 4. Test derived measures
```bash
uv run python -m semantic_layer.cli --model /tmp/model.yaml \
-q '{"measures":[{"expr":"sum(orders.amount)","name":"total"},{"expr":"count(orders.id)","name":"cnt"},{"expr":"total / cnt","name":"avg_order"}],"dimensions":["customers.segment"]}'
```
### 5. Test dialects
```bash
uv run python -m semantic_layer.cli --model /tmp/model.yaml \
-q '{"measures":["sum(orders.amount)"],"dimensions":["customers.segment"]}' --dialect bigquery
```
### 6. Useful flags
| Flag | Purpose |
|------|---------|
| `--model FILE` | Single YAML file with all sources (alternative to `--sources DIR`) |
| `--plan` | Show resolved plan + SQL |
| `--plan-only` | Show plan without SQL |
| `--suggest` | Validate query, show suggestions on failure |
| `--list-sources` | Print all sources, columns, measures, joins |
| `--dialect X` | postgres (default), bigquery, snowflake, duckdb, mysql |
| `--compact` | SQL without header comment |
| `-q JSON` | Pass query as JSON string |
| `--json` | Read JSON query from stdin |
## Coding Guidelines
### Expression handling — always use sqlglot AST, never regex on SQL
- **Parse expressions** with `sqlglot.parse_one(f"SELECT {expr}")` and walk/transform the AST. Never use `str.replace()`, `re.sub()`, or string splitting on SQL fragments — these corrupt string literals, aliases, and nested expressions.
- **Quote reserved words first**: always call `quote_reserved_identifiers(expr)` before passing to `sqlglot.parse_one()`. Column/source names like `group`, `key`, `order` will fail to parse otherwise.
- **Use the parse cache** in `parser.py` (`ExpressionParser._parse_as_select()`) for read-only AST walks. Direct `sqlglot.parse_one()` calls are fine when you need to `.transform()` the tree.
- **Regex is fine for non-SQL tasks**: sanitizing alias names, masking string literals before parse, etc. The rule is: don't use regex to interpret SQL structure.
### Error handling
- Never use bare `except Exception: pass`. At minimum add `logger.debug(...)` so failures are observable. Prefer catching `sqlglot.errors.ParseError` specifically.
- Regex fallback paths in generator.py exist for edge cases where sqlglot can't parse user-provided SQL sources. These are acceptable as last-resort fallbacks with logging, not as primary code paths.
### SQL generation strategy
- **Write postgres, transpile on output.** All SQL is generated as postgres dialect. `_transpile()` converts to the target dialect at the very end. Never add dialect-specific SQL generation logic.
- **f-strings for SQL skeleton** (`SELECT/FROM/JOIN/GROUP BY`) are fine and readable. Use sqlglot AST only for expression-level transformations (substitution, function translation, filter rewriting).
- **Don't build SQL via sqlglot node construction** (`exp.Select().from_(...)`). It's harder to read and debug than f-strings for structural SQL.
### Testing
- Run `uv run pytest -q` after every change. All tests must pass.
- Test CLI queries with `--model /tmp/model.yaml` for quick iteration on edge cases (see examples above).
- When adding expression handling logic, test with reserved-word identifiers (`group.key`, `order.select`) and string literals containing dots (`status = 'group.value'`).
## Project Structure
```
semantic_layer/
models.py # Pydantic data models (sources, queries, plans, results)
loader.py # YAML source file loader
graph.py # Bidirectional join graph with Dijkstra + Steiner tree
parser.py # Expression parser (source refs, aggregate detection)
planner.py # 12-step query planning pipeline
generator.py # SQL generation (simple path + aggregate locality)
engine.py # Orchestrator tying loader/graph/planner/generator
cli.py # CLI entry point
sources/
ecommerce/ # Test fixtures (6 YAML source definitions)
tests/ # 353 tests
```

1
python/klo-sl/CLAUDE.md Symbolic link
View file

@ -0,0 +1 @@
AGENTS.md

0
python/klo-sl/README.md Normal file
View file

View file

@ -0,0 +1,222 @@
# Complex CTE Runtime Join Demo
#
# Demonstrates:
# 1. Two SQL sources with internal CTEs (customer_lifetime_value, churn_risk)
# 2. Both join to `customers` at the source level (many_to_one)
# 3. `customers` joins to `regions` (many_to_one)
# 4. A query requesting measures from BOTH SQL sources + dimensions from `regions`
# triggers chasm trap detection and aggregate locality
#
# Join graph:
# customer_lifetime_value --m2o--> customers --m2o--> regions
# churn_risk --m2o--> customers --m2o--> regions
# --- Table sources ---
- name: regions
table: public.regions
grain: [id]
columns:
- name: id
type: number
- name: name
type: string
- name: continent
type: string
- name: customers
table: public.customers
grain: [id]
columns:
- name: id
type: number
- name: name
type: string
- name: segment
type: string
- name: region_id
type: number
- name: signed_at
type: time
role: time
- name: arr
type: number
joins:
- to: regions
"on": region_id = regions.id
relationship: many_to_one
- name: orders
table: public.orders
grain: [id]
columns:
- name: id
type: number
- name: customer_id
type: number
- name: amount
type: number
- name: created_at
type: time
role: time
joins:
- to: customers
"on": customer_id = customers.id
relationship: many_to_one
- name: order_items
table: public.order_items
grain: [id]
columns:
- name: id
type: number
- name: order_id
type: number
- name: quantity
type: number
- name: unit_price
type: number
joins:
- to: orders
"on": order_id = orders.id
relationship: many_to_one
# --- SQL source: Customer Lifetime Value (uses internal CTEs) ---
- name: customer_lifetime_value
description: |
Customer lifetime value estimate using monthly revenue cohort analysis.
Internal CTEs aggregate orders+order_items by month, then compute
active_months and avg_mrr per customer before estimating LTV.
sql: |
WITH monthly_revenue AS (
SELECT
o.customer_id,
DATE_TRUNC('month', o.created_at) AS month,
SUM(oi.quantity * oi.unit_price) AS mrr
FROM orders o
JOIN order_items oi ON o.id = oi.order_id
GROUP BY o.customer_id, DATE_TRUNC('month', o.created_at)
),
cohort_stats AS (
SELECT
customer_id,
MIN(month) AS first_month,
COUNT(DISTINCT month) AS active_months,
AVG(mrr) AS avg_mrr
FROM monthly_revenue
GROUP BY customer_id
)
SELECT
cs.customer_id,
cs.first_month,
cs.active_months,
cs.avg_mrr,
cs.avg_mrr * cs.active_months * 1.2 AS ltv_estimate
FROM cohort_stats cs
grain: [customer_id]
columns:
- name: customer_id
type: number
- name: first_month
type: time
- name: active_months
type: number
- name: avg_mrr
type: number
- name: ltv_estimate
type: number
joins:
- to: customers
"on": customer_id = customers.id
relationship: many_to_one
measures:
- name: avg_ltv
expr: avg(ltv_estimate)
description: "Average customer lifetime value"
- name: total_ltv
expr: sum(ltv_estimate)
description: "Total lifetime value across customers"
- name: avg_active_months
expr: avg(active_months)
description: "Average number of active months per customer"
# --- SQL source: Churn Risk (uses internal CTEs) ---
- name: churn_risk
description: |
Customer churn risk score combining recency, frequency, and support burden.
Internal CTEs compute rfm_scores from orders and ticket_counts from a
support table before producing a weighted composite score.
sql: |
WITH rfm_scores AS (
SELECT
customer_id,
EXTRACT(DAY FROM NOW() - MAX(created_at)) AS days_since_last_order,
COUNT(*) AS order_frequency,
AVG(amount) AS avg_order_value
FROM orders
GROUP BY customer_id
),
ticket_counts AS (
SELECT
customer_id,
COUNT(*) AS open_tickets,
AVG(EXTRACT(DAY FROM resolved_at - created_at)) AS avg_resolution_days
FROM support_tickets
WHERE status = 'open'
GROUP BY customer_id
)
SELECT
r.customer_id,
r.days_since_last_order,
r.order_frequency,
COALESCE(t.open_tickets, 0) AS open_tickets,
CASE
WHEN r.days_since_last_order > 180 THEN 0.9
WHEN r.days_since_last_order > 90 THEN 0.6
ELSE 0.2
END * 0.4
+ CASE
WHEN r.order_frequency < 2 THEN 0.8
WHEN r.order_frequency < 5 THEN 0.4
ELSE 0.1
END * 0.3
+ CASE
WHEN COALESCE(t.open_tickets, 0) > 3 THEN 0.9
WHEN COALESCE(t.open_tickets, 0) > 1 THEN 0.5
ELSE 0.1
END * 0.3 AS score,
CASE
WHEN r.avg_order_value < 100 THEN 'SMB'
WHEN r.avg_order_value < 1000 THEN 'Mid-Market'
ELSE 'Enterprise'
END AS customer_type
FROM rfm_scores r
LEFT JOIN ticket_counts t ON r.customer_id = t.customer_id
grain: [customer_id]
columns:
- name: customer_id
type: number
- name: days_since_last_order
type: number
- name: order_frequency
type: number
- name: open_tickets
type: number
- name: score
type: number
- name: customer_type
type: string
joins:
- to: customers
"on": customer_id = customers.id
relationship: many_to_one
measures:
- name: avg_risk
expr: avg(score)
description: "Average churn risk score"
- name: high_risk_count
expr: count(customer_id)
filter: "score > 0.7"
description: "Number of high-risk customers"

View file

@ -0,0 +1,60 @@
#!/usr/bin/env bash
# Complex CTE Runtime Join Demo
#
# Shows how two SQL sources with internal CTEs (customer_lifetime_value, churn_risk)
# are joined at runtime through the join graph to a dimension table (regions),
# triggering chasm trap detection and aggregate locality.
set -euo pipefail
cd "$(dirname "$0")/.."
MODEL="demos/complex_cte_join.yaml"
echo "============================================"
echo " Demo 1: Chasm Trap — Two CTE metrics + regions dimension"
echo "============================================"
echo ""
echo "Query: Average LTV and average churn risk by region,"
echo " for customers with churn score > 0.7"
echo ""
echo '{
"measures": ["customer_lifetime_value.avg_ltv", "churn_risk.avg_risk"],
"dimensions": ["regions.name"],
"filters": ["churn_risk.score > 0.7"]
}' | uv run python -m semantic_layer.cli --model "$MODEL" --json --plan
echo ""
echo "============================================"
echo " Demo 2: Single CTE metric enriched with regions"
echo "============================================"
echo ""
echo "Query: LTV breakdown by region and customer segment,"
echo " only customers with 6+ active months"
echo ""
echo '{
"measures": [
"customer_lifetime_value.avg_ltv",
"customer_lifetime_value.avg_active_months",
{"expr": "count(customer_lifetime_value.customer_id)", "name": "customer_count"}
],
"dimensions": ["regions.name", "customers.segment"],
"filters": ["customer_lifetime_value.active_months >= 6"]
}' | uv run python -m semantic_layer.cli --model "$MODEL" --json --plan
echo ""
echo "============================================"
echo " Demo 3: Runtime aggregation on CTE columns + cross-source join"
echo "============================================"
echo ""
echo "Query: P90 churn score and max LTV by region continent"
echo ""
echo '{
"measures": [
{"expr": "percentile(churn_risk.score, 0.9)", "name": "p90_churn"},
{"expr": "max(customer_lifetime_value.ltv_estimate)", "name": "max_ltv"}
],
"dimensions": ["regions.continent"]
}' | uv run python -m semantic_layer.cli --model "$MODEL" --json --plan

View file

@ -0,0 +1,59 @@
[project]
name = "klo-sl"
version = "0.1.0"
description = "Agent-first semantic layer engine with aggregate locality"
readme = "README.md"
requires-python = ">=3.13"
license = "Apache-2.0"
dependencies = [
"sqlglot>=26",
"pydantic>=2",
"pyyaml>=6",
]
[project.urls]
Homepage = "https://github.com/kaelio/ktx"
Repository = "https://github.com/kaelio/ktx"
Issues = "https://github.com/kaelio/ktx/issues"
[project.optional-dependencies]
dev = [
"pytest>=8",
"pytest-cov",
"ruff",
"pre-commit",
]
tpch = [
"duckdb>=1.0",
]
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["tests"]
addopts = "--cov=semantic_layer --cov-report=term-missing --cov-report=html"
[tool.coverage.run]
source = ["semantic_layer"]
branch = true
[tool.coverage.report]
show_missing = true
skip_empty = true
exclude_lines = [
"pragma: no cover",
"if __name__ == .__main__.",
"if TYPE_CHECKING:",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["semantic_layer"]
[dependency-groups]
dev = [
"pytest>=9.0.2",
"pytest-cov>=7.1.0",
]

View file

@ -0,0 +1,219 @@
#!/usr/bin/env python3
"""Generate semantic layer YAML sources from demo DB metadata.
Usage:
kubectl port-forward -n klo-demo deployment/klo-demo-db 5433:5432 &
KLO_DEMO_DB_PASSWORD=local-demo-password python scripts/gen_b2b_saas_model.py
"""
import os
import psycopg2
import yaml
CONNECTION_ID = "256bc76b-cc47-4d5d-a9fc-5bcfb0364d44"
OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "..", "sources", "b2b_saas")
DB_PARAMS = {
"host": os.environ.get("KLO_DEMO_DB_HOST", "127.0.0.1"),
"port": int(os.environ.get("KLO_DEMO_DB_PORT", "5433")),
"user": os.environ.get("KLO_DEMO_DB_USER", "klo-demo-user"),
"password": os.environ.get("KLO_DEMO_DB_PASSWORD", ""),
"dbname": os.environ.get("KLO_DEMO_DB_NAME", "klo-demo-db"),
}
# Map DB types to semantic layer types
TYPE_MAP = {
"INTEGER": "number",
"FLOAT": "number",
"NUMERIC": "number",
"DECIMAL": "number",
"BIGINT": "number",
"SMALLINT": "number",
"DOUBLE": "number",
"REAL": "number",
"VARCHAR": "string",
"TEXT": "string",
"CHAR": "string",
"DATE": "time",
"TIMESTAMP": "time",
"TIMESTAMPTZ": "time",
"DATETIME": "time",
"TIME": "time",
"BOOLEAN": "boolean",
"BOOL": "boolean",
}
# Columns whose names suggest a time role
TIME_PATTERNS = {"_at", "_date", "date", "timestamp", "created", "updated"}
def is_time_column(name: str, db_type: str) -> bool:
sl_type = TYPE_MAP.get(db_type.upper(), "string")
if sl_type == "time":
return True
# VARCHAR columns with date-like names (e.g. created_at stored as VARCHAR)
lower = name.lower()
return any(p in lower for p in TIME_PATTERNS) and sl_type == "string"
def map_type(db_type: str, col_name: str) -> str:
upper = db_type.upper()
if upper in TYPE_MAP:
base = TYPE_MAP[upper]
# Override string→time for date-like column names
if base == "string" and is_time_column(col_name, db_type):
return "time"
return base
return "string"
def main():
conn = psycopg2.connect(**DB_PARAMS)
cur = conn.cursor()
# 1. Fetch tables
cur.execute(
"SELECT id, name FROM source_tables WHERE connection_id = %s ORDER BY name",
(CONNECTION_ID,),
)
tables = {row[0]: row[1] for row in cur.fetchall()}
table_ids = tuple(tables.keys())
# 2. Fetch columns
cur.execute(
"""
SELECT id, name, type, nullable, primary_key, table_id
FROM source_columns
WHERE table_id = ANY(%s::uuid[])
ORDER BY table_id, primary_key DESC, name
""",
(list(table_ids),),
)
columns_by_table: dict[str, list] = {}
col_id_to_info: dict[str, dict] = {}
for row in cur.fetchall():
col_id, col_name, col_type, nullable, is_pk, table_id = row
info = {
"id": col_id,
"name": col_name,
"type": col_type,
"nullable": nullable,
"primary_key": is_pk,
"table_id": table_id,
}
col_id_to_info[col_id] = info
columns_by_table.setdefault(table_id, []).append(info)
# 3. Fetch links (joins)
cur.execute(
"""
SELECT from_table_id, from_column_id, to_table_id, to_column_id, relationship_type
FROM column_links
WHERE from_table_id = ANY(%s::uuid[]) OR to_table_id = ANY(%s::uuid[])
""",
(list(table_ids), list(table_ids)),
)
# Group links by from_table
joins_by_table: dict[str, list] = {}
for row in cur.fetchall():
from_table_id, from_col_id, to_table_id, to_col_id, rel_type = row
# Only include joins where both sides are in our connection
if from_table_id not in tables or to_table_id not in tables:
continue
joins_by_table.setdefault(from_table_id, []).append(
{
"from_col_id": from_col_id,
"to_table_id": to_table_id,
"to_col_id": to_col_id,
"relationship_type": rel_type,
}
)
conn.close()
# 4. Generate YAML files
os.makedirs(OUTPUT_DIR, exist_ok=True)
for table_id, table_name in sorted(tables.items(), key=lambda x: x[1]):
cols = columns_by_table.get(table_id, [])
joins = joins_by_table.get(table_id, [])
# Find primary key columns
pk_cols = [c for c in cols if c["primary_key"]]
if pk_cols:
grain = [c["name"] for c in pk_cols]
else:
# Fallback: use row_id if present, else first column
row_id_col = next((c for c in cols if c["name"] == "row_id"), None)
if row_id_col:
grain = ["row_id"]
elif cols:
grain = [cols[0]["name"]]
else:
grain = [table_name + "_id"]
# Build column definitions
yaml_columns = []
for c in cols:
sl_type = map_type(c["type"], c["name"])
col_def: dict = {"name": c["name"], "type": sl_type}
if is_time_column(c["name"], c["type"]):
col_def["role"] = "time"
yaml_columns.append(col_def)
# Build join definitions
yaml_joins = []
# Track target sources to handle aliases for multiple joins to same target
target_counts: dict[str, int] = {}
for j in joins:
to_name = tables.get(j["to_table_id"])
if not to_name:
continue
target_counts[to_name] = target_counts.get(to_name, 0) + 1
target_seen: dict[str, int] = {}
for j in joins:
to_name = tables.get(j["to_table_id"])
from_col = col_id_to_info.get(j["from_col_id"], {}).get("name")
to_col = col_id_to_info.get(j["to_col_id"], {}).get("name")
if not (to_name and from_col and to_col):
continue
rel = j["relationship_type"].lower()
join_def: dict = {
"to": to_name,
"on": f"{from_col} = {to_name}.{to_col}",
"relationship": rel,
}
# Add alias if multiple joins to same target
target_seen[to_name] = target_seen.get(to_name, 0) + 1
if target_counts.get(to_name, 0) > 1:
join_def["alias"] = f"{to_name}_{target_seen[to_name]}"
yaml_joins.append(join_def)
# Build source definition
source: dict = {
"name": table_name,
"table": table_name,
}
if grain:
source["grain"] = grain
source["columns"] = yaml_columns
if yaml_joins:
source["joins"] = yaml_joins
# Write YAML
filepath = os.path.join(OUTPUT_DIR, f"{table_name}.yaml")
with open(filepath, "w") as f:
yaml.dump(
source, f, default_flow_style=False, sort_keys=False, allow_unicode=True
)
print(f"Generated {len(tables)} source files in {OUTPUT_DIR}")
if __name__ == "__main__":
main()

View file

@ -0,0 +1,110 @@
#!/usr/bin/env python3
"""Run a semantic layer query against the b2b_saas SQLite database.
Usage:
uv run python scripts/slquery.py '{"measures":["count(opportunities.opportunity_id)"],"dimensions":["accounts.segment"]}'
uv run python scripts/slquery.py '{"measures":["churn_risk.avg_risk_score"],"dimensions":["accounts.industry"]}'
echo '{"measures":["sum(contracts.arr)"],"dimensions":["accounts.segment"]}' | uv run python scripts/slquery.py --stdin
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sqlite3
import sys
from pathlib import Path
from semantic_layer.engine import SemanticEngine
SOURCES_DIR = Path(__file__).resolve().parent.parent / "sources" / "b2b_saas"
DB_PATH = Path(
os.environ.get("KLO_B2B_SQLITE_DB", "sample-data-generator/b2b_data.db")
).expanduser()
# sqlglot's sqlite dialect handles most transpilation, but has a few gaps.
# These fixups patch what sqlglot misses.
_SQLITE_FIXUPS = [
# GROUP_CONCAT(DISTINCT x, sep) → GROUP_CONCAT(DISTINCT x) — sqlite
# only allows 1 arg with DISTINCT
(r"GROUP_CONCAT\(DISTINCT (\w+),\s*'[^']*'\)", r"GROUP_CONCAT(DISTINCT \1)"),
# CURRENT_DATE - col → integer days via julianday
(
r"CURRENT_DATE - DATE\((\w+)\)",
r"CAST(julianday('now') - julianday(\1) AS INTEGER)",
),
(r"CURRENT_DATE - (\w+)", r"CAST(julianday('now') - julianday(\1) AS INTEGER)"),
# col - CURRENT_DATE → integer days via julianday
(r"(\w+) - CURRENT_DATE", r"CAST(julianday(\1) - julianday('now') AS INTEGER)"),
# CURRENT_DATE > col → julianday comparison
(r"CURRENT_DATE > (\w+)", r"julianday('now') > julianday(\1)"),
# NULLS LAST — not supported in sqlite
(r"\s+NULLS LAST", ""),
]
def fixup_sqlite(sql: str) -> str:
for pattern, repl in _SQLITE_FIXUPS:
sql = re.sub(pattern, repl, sql)
return sql
def main() -> None:
p = argparse.ArgumentParser(description="Run SL query against b2b_saas SQLite DB")
p.add_argument("query", nargs="?", help="JSON query string")
p.add_argument("--stdin", action="store_true", help="Read JSON from stdin")
p.add_argument(
"--sql-only", action="store_true", help="Print SQL without executing"
)
p.add_argument("--db", default=str(DB_PATH), help="Path to SQLite database")
p.add_argument(
"--sources", default=str(SOURCES_DIR), help="Path to sources directory"
)
args = p.parse_args()
if args.stdin:
query_dict = json.loads(sys.stdin.read())
elif args.query:
query_dict = json.loads(args.query)
else:
p.error("Provide a JSON query string or use --stdin")
# Use sqlite dialect — sqlglot handles STRING_AGG→GROUP_CONCAT,
# DECIMAL→REAL, ::DATE→DATE(), etc.
engine = SemanticEngine(args.sources, dialect="sqlite")
result = engine.query(query_dict)
sql = fixup_sqlite(result.sql)
if args.sql_only:
print(sql)
return
conn = sqlite3.connect(args.db)
conn.row_factory = sqlite3.Row
try:
rows = conn.execute(sql).fetchall()
except sqlite3.OperationalError as e:
print(f"SQL error: {e}", file=sys.stderr)
print(f"\nGenerated SQL:\n{sql}", file=sys.stderr)
sys.exit(1)
finally:
conn.close()
if not rows:
print("(no rows)")
return
cols = rows[0].keys()
widths = [max(len(str(c)), max(len(str(r[c])) for r in rows)) for c in cols]
header = " ".join(str(c).ljust(w) for c, w in zip(cols, widths))
sep = " ".join("-" * w for w in widths)
print(header)
print(sep)
for r in rows:
print(" ".join(str(r[c]).ljust(w) for c, w in zip(cols, widths)))
if __name__ == "__main__":
main()

View file

@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""Run TPC-H queries end-to-end: generate data + semantic layer SQL + execute.
Usage:
uv run python scripts/tpch_runner.py
"""
from __future__ import annotations
import json
import duckdb
import sqlglot
from semantic_layer.engine import SemanticEngine
TPCH_TABLES = [
"region",
"nation",
"supplier",
"customer",
"part",
"partsupp",
"orders",
"lineitem",
]
def setup_tpch(sf: float = 0.01) -> duckdb.DuckDBPyConnection:
"""Create in-memory DuckDB with TPC-H data at the given scale factor."""
conn = duckdb.connect()
conn.execute("INSTALL tpch; LOAD tpch")
conn.execute(f"CALL dbgen(sf={sf})")
# YAML files use public.<table> — create views to match
conn.execute("CREATE SCHEMA IF NOT EXISTS public")
for t in TPCH_TABLES:
conn.execute(f"CREATE VIEW public.{t} AS SELECT * FROM main.{t}")
return conn
def run_query(
conn: duckdb.DuckDBPyConnection,
engine: SemanticEngine,
title: str,
query_dict: dict,
) -> None:
"""Generate SQL via semantic layer, execute it, and print results."""
print(f"\n{'=' * 60}")
print(f" {title}")
print(f"{'=' * 60}")
print("\n>> Request:")
print(json.dumps(query_dict, indent=2))
result = engine.query(query_dict)
formatted_sql = sqlglot.transpile(
result.sql, read=result.dialect, write=result.dialect, pretty=True
)[0]
print(f"\n-- dialect: {result.dialect}")
print(formatted_sql)
cursor = conn.execute(result.sql)
col_names = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
# Simple table formatting
widths = [
max(len(str(c)), *(len(str(r[i])) for r in rows))
for i, c in enumerate(col_names)
]
header = " ".join(str(c).ljust(w) for c, w in zip(col_names, widths))
print(f"\n{header}")
print(" ".join("-" * w for w in widths))
for row in rows:
print(" ".join(str(v).ljust(w) for v, w in zip(row, widths)))
print(f"\n({len(rows)} rows)")
def main() -> None:
conn = setup_tpch()
engine = SemanticEngine("sources/tpch", dialect="duckdb")
# Q1: Pricing summary by return flag / line status
run_query(
conn,
engine,
"Q1: Pricing Summary",
{
"measures": [
"lineitem.revenue",
"lineitem.total_quantity",
"lineitem.avg_discount",
"lineitem.line_count",
],
"dimensions": ["lineitem.l_returnflag", "lineitem.l_linestatus"],
},
)
# Q5-style: Revenue by nation (4-hop join) with ASIA filter
run_query(
conn,
engine,
"Q5: Revenue by Nation (ASIA)",
{
"measures": ["lineitem.revenue"],
"dimensions": ["nation.n_name"],
"filters": ["region.r_name = 'ASIA'"],
},
)
# Q3-style: Revenue by order month for BUILDING segment
run_query(
conn,
engine,
"Q3: Revenue by Month (BUILDING)",
{
"measures": ["lineitem.revenue"],
"dimensions": [{"field": "orders.o_orderdate", "granularity": "month"}],
"filters": ["customer.c_mktsegment = 'BUILDING'"],
"limit": 12,
},
)
# Q10-style: Returned revenue by customer (filtered measure)
run_query(
conn,
engine,
"Q10: Returned Revenue by Customer",
{
"measures": ["lineitem.returned_revenue"],
"dimensions": ["customer.c_name"],
"order_by": [{"field": "lineitem.returned_revenue", "direction": "desc"}],
"limit": 10,
},
)
# Multi-measure: revenue + charge + counts
run_query(
conn,
engine,
"Multi-measure: Revenue, Charge, Counts",
{
"measures": [
"lineitem.revenue",
"lineitem.charge",
"orders.order_count",
],
"dimensions": ["customer.c_mktsegment"],
},
)
# Supply cost by nation (through partsupp bridge)
run_query(
conn,
engine,
"Supply Cost by Nation",
{
"measures": ["partsupp.total_supply_cost"],
"dimensions": ["nation.n_name"],
"limit": 10,
},
)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,4 @@
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import QueryResult, SemanticQuery
__all__ = ["SemanticEngine", "SemanticQuery", "QueryResult"]

View file

@ -0,0 +1,3 @@
from semantic_layer.cli import main
main()

View file

@ -0,0 +1,268 @@
"""CLI for the semantic layer engine.
Usage:
# Simple query
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}'
# Pre-defined measure with filter
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["orders.revenue"], "dimensions": ["orders.status"]}'
# Cross-source with time granularity
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["regions.name", {"field": "orders.created_at", "granularity": "month"}], "filters": ["regions.name = '"'"'LATAM'"'"'"]}'
# Multiple dialects
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}' \
--dialect bigquery
# Plan only (no SQL generation)
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}' \
--plan-only
# JSON input from stdin
echo '{"measures":["sum(orders.amount)"],"dimensions":["orders.status"]}' | \
uv run python -m semantic_layer.cli --sources sources/ecommerce --json
# Custom ORDER BY
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"], "order_by": [{"field": "sum(orders.amount)", "direction": "desc"}]}'
# Validate query (suggest fixes on failure)
uv run python -m semantic_layer.cli \
--sources sources/ecommerce \
-q '{"measures": ["sum(orders.amount)"], "dimensions": ["orders.status"]}' \
--suggest
"""
from __future__ import annotations
import argparse
import json
import sys
import yaml
from semantic_layer.engine import SemanticEngine
from semantic_layer.models import SourceDefinition
def build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
prog="semantic-layer",
description="Query the semantic layer engine and generate SQL",
)
p.add_argument(
"--sources",
"-s",
help="Path to the sources directory (e.g. sources/ecommerce)",
)
p.add_argument(
"--model",
help="Path to a single YAML file containing all source definitions as a list",
)
p.add_argument(
"--dialect",
"-d",
default="postgres",
help="SQL dialect (postgres, bigquery, snowflake, etc.)",
)
# Query input
p.add_argument(
"--query",
"-q",
help='Raw JSON query string (e.g. \'{"measures": ["orders.revenue"], "dimensions": ["orders.status"]}\')',
)
# Output modes
p.add_argument(
"--json",
action="store_true",
dest="json_input",
help="Read query as JSON from stdin",
)
p.add_argument(
"--plan-only",
action="store_true",
help="Show the resolved plan instead of SQL",
)
p.add_argument(
"--plan",
action="store_true",
help="Show the resolved plan alongside SQL",
)
p.add_argument(
"--compact",
action="store_true",
help="Output SQL without formatting",
)
# Info commands
p.add_argument(
"--list-sources",
action="store_true",
help="List all available sources and exit",
)
p.add_argument(
"--suggest",
action="store_true",
help="Validate the query and suggest fixes if it fails",
)
return p
def list_sources(engine: SemanticEngine) -> None:
for name, src in sorted(engine.sources.items()):
print(f"\n{'' * 40}")
print(f" {name}")
src_type = "sql" if src.is_sql_source else "table"
print(f" type: {src_type}", end="")
if src.table:
print(f" table: {src.table}", end="")
print(f" grain: {src.grain}")
if src.description:
print(f" {src.description.strip()}")
if src.columns:
print(" columns:")
for col in src.columns:
role_tag = f" [{col.role.value}]" if col.role.value != "default" else ""
print(f" {col.name}: {col.type}{role_tag}")
if src.measures:
print(" measures:")
for m in src.measures:
filt = f" (filter: {m.filter})" if m.filter else ""
print(f" {m.name}: {m.expr}{filt}")
if src.joins:
print(" joins:")
for j in src.joins:
print(f"{j.to} ({j.relationship}) on {j.on}")
def print_plan(plan) -> None:
print("\n── Resolved Plan ──")
print(f" Sources: {', '.join(plan.sources_used)}")
print(f" Anchor: {plan.anchor_source}")
if plan.join_paths:
print(" Joins:")
for jp in plan.join_paths:
print(f" {jp}")
print(f" Fan-out: {plan.fan_out_description}")
if plan.aggregate_locality:
print(" Locality:")
for al in plan.aggregate_locality:
print(f" {al}")
if plan.where_filters:
print(f" WHERE: {' AND '.join(plan.where_filters)}")
if plan.having_filters:
print(f" HAVING: {' AND '.join(plan.having_filters)}")
print(" Columns:")
for col in plan.columns:
prov = col.provenance.value
gran = f" ({col.granularity})" if col.granularity else ""
print(f" {col.name} [{prov}]{gran}")
def _load_model_file(path: str) -> dict[str, SourceDefinition]:
"""Load a YAML file containing a list of source definitions."""
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, list):
raise ValueError("Model file must contain a YAML list of source definitions")
sources: dict[str, SourceDefinition] = {}
for item in data:
src = SourceDefinition(**item)
if src.name in sources:
raise ValueError(f"Duplicate source name: '{src.name}'")
sources[src.name] = src
return sources
def main(argv: list[str] | None = None) -> None:
parser = build_parser()
args = parser.parse_args(argv)
if args.model:
sources = _load_model_file(args.model)
engine = SemanticEngine.from_sources(sources, dialect=args.dialect)
elif args.sources:
engine = SemanticEngine(args.sources, dialect=args.dialect)
else:
parser.error("Provide --sources or --model")
# List sources mode
if args.list_sources:
list_sources(engine)
return
# Build query
if args.query:
query_dict = json.loads(args.query)
elif args.json_input:
raw = sys.stdin.read()
query_dict = json.loads(raw)
else:
parser.error("Provide --query or --json")
return
# Suggest mode
if args.suggest:
result = engine.suggest(query_dict)
if result["success"]:
print("Query is valid.")
print_plan(result["plan"])
else:
print(f"Query failed: {result['error']}")
if result.get("graph_errors"):
for err in result["graph_errors"]:
print(f" Graph error: {err}")
for s in result.get("suggestions", []):
if isinstance(s, dict):
print(f" Suggestion: {s.get('description', '')}")
for src in s.get("required_sources", []):
print(f" - Define source: {src}")
for j in s.get("required_joins", []):
print(
f" - Add join: {j['source']}.{j['on']} ({j['relationship']})"
)
for note in s.get("notes", []):
print(f" Note: {note}")
else:
print(f" Suggestion: {s}")
return
# Plan-only mode
if args.plan_only:
plan = engine.plan_only(query_dict)
print_plan(plan)
return
# Full query
result = engine.query(query_dict)
if args.plan:
print_plan(result.resolved_plan)
print()
if args.compact:
print(result.sql)
else:
print(f"-- dialect: {result.dialect}")
print(result.sql)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,99 @@
"""Detect semantically-redundant measure definitions on the same source."""
from __future__ import annotations
import sqlglot
from sqlglot import exp
from semantic_layer.models import SourceDefinition
from semantic_layer.parser import quote_reserved_identifiers
# DIALECT CONVENTION:
# Measure `expr` values are compared structurally. They must be parsed with
# the connection's native dialect (per sl_capture); parsing as postgres
# would drop dialect-specific tokens and miss duplicates across BigQuery
# `SAFE_DIVIDE` / Snowflake `DIV0` etc.
def validate_measure_duplicates(
sources: dict[str, SourceDefinition],
*,
dialect: str = "postgres",
) -> list[str]:
"""
Flag pairs of measures on the same source whose `expr` is structurally
equivalent. Intended to prevent capture-time churn like:
- name: active_subscription_count
expr: count(*)
filter: is_active = true
- name: new_subscription_count
expr: count(*) # same base aggregation — should be query-time filter
Returns a list of human-readable error strings (empty list = no duplicates).
Compares every pair of measures within a single source; does not compare
across sources (measures on different sources are never redundant).
"""
errors: list[str] = []
for source_name, source in sources.items():
if len(source.measures) < 2:
continue
parsed: list[tuple[str, exp.Expression | None, str | None, frozenset[str]]] = []
for m in source.measures:
try:
quoted = quote_reserved_identifiers(m.expr)
tree = sqlglot.parse_one(f"SELECT {quoted}", read=dialect)
expr_node = tree.expressions[0] if tree.expressions else None
except Exception:
# Unparseable expressions are left for the caller's normal
# validation to surface; don't block on parse failure here.
expr_node = None
parsed.append((m.name, expr_node, m.filter, frozenset(m.segments)))
for i, (name_a, expr_a, filter_a, segments_a) in enumerate(parsed):
if expr_a is None:
continue
for name_b, expr_b, filter_b, segments_b in parsed[i + 1 :]:
if expr_b is None:
continue
if not _expressions_equivalent(expr_a, expr_b):
continue
# Segments are named, reusable filter predicates; two measures
# sharing an expr but applying different segments are by design
# distinct and must not be flagged.
if segments_a != segments_b:
continue
fa = (filter_a or "").strip()
fb = (filter_b or "").strip()
if fa == fb:
errors.append(
f"{source_name}: measures '{name_a}' and '{name_b}' have the same "
f"expression and filter — remove one or differentiate them."
)
else:
errors.append(
f"{source_name}: measure '{name_b}' has the same expression as "
f"'{name_a}' — differs only by `filter`. Use query-time filtering "
f"on '{name_a}' (via semantic_query filters), or, if the filter "
f"encodes a named business segment, add a segments[] entry on this "
f"source and reference it instead."
)
return errors
def _expressions_equivalent(a: exp.Expression, b: exp.Expression) -> bool:
"""
Structural equality on sqlglot ASTs.
Normalizes via sqlglot's .sql() canonical form (handles whitespace, case,
aliasing). Does NOT reorder operands `safe_divide(a, b)` is NOT equal to
`safe_divide(b, a)`, nor is `a - b` equal to `b - a`. This is deliberate:
the check's purpose is catching accidental redundancy, not proving
mathematical equivalence.
"""
if type(a) is not type(b):
return False
return a.sql(dialect="postgres") == b.sql(dialect="postgres")

View file

@ -0,0 +1,360 @@
from __future__ import annotations
from semantic_layer.generator import SqlGenerator
from semantic_layer.graph import JoinGraph
from semantic_layer.loader import SourceLoader
from semantic_layer.models import (
QueryResult,
ResolvedPlan,
SemanticQuery,
SourceDefinition,
ValidationReport,
)
from semantic_layer.planner import QueryPlanner
from semantic_layer.sql_table_extractor import (
extract_table_refs,
ref_matches_source_table,
)
class SemanticEngine:
def __init__(self, sources_dir: str, dialect: str = "postgres"):
self.loader = SourceLoader(sources_dir)
self.sources = self.loader.load_all()
self._init_engine(dialect)
@classmethod
def from_sources(
cls, sources: dict[str, SourceDefinition], dialect: str = "postgres"
) -> SemanticEngine:
"""Create engine from pre-loaded source definitions."""
obj = object.__new__(cls)
obj.loader = None
obj.sources = sources
obj._init_engine(dialect)
return obj
def _init_engine(self, dialect: str) -> None:
# Validate the dialect up-front with the user-facing "Unknown SQL
# dialect" error, before JoinGraph.build() hits sqlglot's parser.
SqlGenerator(dialect)
self.graph = JoinGraph(self.sources, dialect=dialect)
self.graph.build()
self.planner = QueryPlanner(self.sources, self.graph, dialect=dialect)
self.generator = SqlGenerator(dialect, alias_map=self.graph.alias_map)
def query(self, query: dict | SemanticQuery) -> QueryResult:
if isinstance(query, dict):
query = SemanticQuery(**query)
orphan_errors = self._collect_orphan_join_target_errors()
if orphan_errors:
raise ValueError("Cannot query semantic layer: " + "; ".join(orphan_errors))
plan = self.planner.plan(query)
sql = self.generator.generate(plan, self.sources)
return QueryResult(
resolved_plan=plan,
sql=sql,
dialect=self.generator.dialect,
columns=plan.columns,
)
def validate(self, recently_touched: set[str] | None = None) -> ValidationReport:
report = ValidationReport()
self._check_orphan_join_targets(report)
self._check_invalid_grain(report)
self._check_sql_join_coverage(report, recently_touched=recently_touched)
self._check_disconnected_components(report, recently_touched=recently_touched)
return report
def _collect_orphan_join_target_errors(self) -> list[str]:
known = set(self.sources.keys())
errors: list[str] = []
for source in self.sources.values():
for join in source.joins:
if join.to not in known:
errors.append(
f"Source '{source.name}' joins to '{join.to}', "
f"but '{join.to}' is not defined"
)
return errors
def _check_orphan_join_targets(self, report: ValidationReport) -> None:
report.errors.extend(self._collect_orphan_join_target_errors())
def _check_invalid_grain(self, report: ValidationReport) -> None:
for source in self.sources.values():
column_names = {c.name for c in source.columns}
for grain_col in source.grain:
if grain_col not in column_names:
report.errors.append(
f"Source '{source.name}' has grain column '{grain_col}' "
f"that is not in its columns list"
)
def _check_sql_join_coverage(
self,
report: ValidationReport,
recently_touched: set[str] | None = None,
) -> None:
"""Block writes whose SQL references a known source's base table
without declaring a join to that source.
Scoped to `recently_touched` so existing fragmentation isn't flagged
on every write. Only sources with `sql:` are checked. CTE
self-references are filtered by the extractor.
"""
if not recently_touched:
return
table_index: list[tuple[SourceDefinition, str]] = [
(src, src.table) for src in self.sources.values() if src.table is not None
]
if not table_index:
return
dialect = getattr(self.generator, "dialect", "postgres")
for source_name in sorted(recently_touched):
source = self.sources.get(source_name)
if source is None or not source.is_sql_source or not source.sql:
continue
declared = {j.to.lower() for j in source.joins}
refs = extract_table_refs(source.sql, dialect=dialect)
missing: list[str] = []
for ref in refs:
hit_name: str | None = None
for candidate, table_value in table_index:
if candidate.name == source.name:
continue
if ref_matches_source_table(ref, table_value):
hit_name = candidate.name
break
if hit_name is None:
continue
if hit_name.lower() in declared:
continue
if hit_name not in missing:
missing.append(hit_name)
if not missing:
continue
ref_list = ", ".join(missing)
example = missing[0]
grain_col = (
self.sources[example].grain[0] if self.sources[example].grain else "id"
)
msg = (
f"Source '{source.name}' SQL joins manifest table(s) [{ref_list}] "
f"that are not declared in joins[]. Add a join entry for each, "
f"e.g. {{to: {example}, on: '{source.name}.<your_fk> = "
f"{example}.{grain_col}', relationship: many_to_one}}. If a "
f"reference is intentionally absent, document it with a "
f"`unmapped-table-*` wiki note and remove the SQL reference."
)
report.errors.append(msg)
def _check_disconnected_components(
self,
report: ValidationReport,
recently_touched: set[str] | None = None,
) -> None:
components = self.graph.find_components()
if len(components) <= 1:
return
sorted_components = sorted(
components, key=lambda c: (-len(c), sorted(c)[0] if c else "")
)
lines = [
f"Model has {len(components)} disconnected components. "
f"Queries that span components will fail with 'No join path' errors:"
]
for i, component in enumerate(sorted_components, start=1):
names = sorted(component)
if len(names) > 3:
sample = ", ".join(names[:2])
lines.append(
f" - Component {i} ({len(names)} sources): {sample}, ... (+{len(names) - 2} more)"
)
else:
lines.append(
f" - Component {i} ({len(names)} sources): {', '.join(names)}"
)
report.warnings.append("\n".join(lines))
if recently_touched:
singleton_components = {next(iter(c)) for c in components if len(c) == 1}
for source_name in sorted(recently_touched & singleton_components):
report.per_source_warnings.setdefault(source_name, []).append(
f"Source '{source_name}' is now a singleton component (no joins to any "
f"other source). Queries that combine '{source_name}' with anything else "
f"will fail with 'No join path' errors. Run sl_discover for each table "
f"named in this source's SQL and add joins via sl_edit_source."
)
def plan_only(self, query: dict | SemanticQuery) -> ResolvedPlan:
if isinstance(query, dict):
query = SemanticQuery(**query)
return self.planner.plan(query)
def suggest(self, query: dict | SemanticQuery) -> dict:
"""Try to plan. If it fails, suggest config extensions with structured info."""
if isinstance(query, dict):
query = SemanticQuery(**query)
try:
plan = self.planner.plan(query)
# Also validate that SQL generation succeeds
try:
self.generator.generate(plan, self.sources)
except Exception as gen_err:
return {
"success": False,
"error": f"SQL generation failed: {gen_err}",
"plan": plan,
"referenced_sources": sorted(set(plan.sources_used)),
"missing_sources": [],
"graph_errors": [],
"suggestions": [
{
"description": f"SQL generation error: {gen_err}",
"required_sources": [],
"required_joins": [],
"notes": [
"The query plan was valid but the SQL generator encountered an error.",
"This may indicate a limitation in the aggregate locality system.",
],
}
],
}
return {
"success": True,
"plan": plan,
"suggestions": [],
}
except Exception as e:
from semantic_layer.parser import ExpressionParser
parser = ExpressionParser()
# Collect all source references from the query
referenced_sources: set[str] = set()
all_exprs: list[str] = []
for m in query.measures:
if isinstance(m, str):
all_exprs.append(m)
elif isinstance(m, dict):
all_exprs.append(m.get("expr", ""))
for d in query.dimensions:
if isinstance(d, str):
all_exprs.append(d)
elif isinstance(d, dict):
all_exprs.append(d.get("field", ""))
all_exprs.extend(query.filters)
for expr in all_exprs:
referenced_sources.update(parser.extract_source_refs(expr))
# Identify missing sources
known_sources = set(self.sources.keys())
missing_sources = sorted(referenced_sources - known_sources)
graph_errors = _format_component_errors(self.graph.find_components())
suggestions = []
if missing_sources:
# Suggest source definitions for missing sources
required_joins = []
for ms in missing_sources:
# Infer potential join targets from column naming (e.g. orders → orders.id)
for known_name, known_src in self.sources.items():
candidate_fk = f"{known_name}_id"
# Check if the missing source might join to this known source
if any(c.name == candidate_fk for c in known_src.columns):
required_joins.append(
{
"source": known_name,
"to": ms,
"on": f"{candidate_fk} = {ms}.id",
"relationship": "many_to_one",
}
)
suggestions.append(
{
"description": f"Define missing source(s): {', '.join(missing_sources)}",
"required_sources": missing_sources,
"required_joins": required_joins,
"notes": [
f"Create YAML definition(s) for: {', '.join(missing_sources)}",
"Each source needs at minimum: name, table (or sql), grain, and columns",
],
}
)
if not missing_sources and len(referenced_sources) > 1:
# Identify which specific pairs are disconnected
present_sources = sorted(referenced_sources & known_sources)
disconnected_pairs = []
for i, src_a in enumerate(present_sources):
for src_b in present_sources[i + 1 :]:
path = self.graph.find_path(src_a, src_b)
if path is None:
disconnected_pairs.append((src_a, src_b))
required_joins = []
for src_a, src_b in disconnected_pairs:
required_joins.append(
{
"source": src_a,
"to": src_b,
"on": f"{src_b}_id = {src_b}.id",
"relationship": "many_to_one",
}
)
suggestions.append(
{
"description": f"Add join path(s) connecting: {', '.join(present_sources)}",
"required_sources": [],
"required_joins": required_joins,
"notes": [
f"Disconnected pairs: {[f'{a}{b}' for a, b in disconnected_pairs]}"
if disconnected_pairs
else "Sources are connected but query failed for another reason",
]
if disconnected_pairs
else [
"All sources are connected; check the error message for details",
],
}
)
return {
"success": False,
"error": str(e),
"referenced_sources": sorted(referenced_sources),
"missing_sources": missing_sources,
"graph_errors": graph_errors,
"suggestions": suggestions,
}
def _format_component_errors(components: list[set[str]]) -> list[str]:
"""Render multi-component topology as graph_error strings for `suggest()` / CLI."""
if len(components) <= 1:
return []
sorted_components = sorted(
components, key=lambda c: (-len(c), sorted(c)[0] if c else "")
)
lines = []
for i, component in enumerate(sorted_components, start=1):
names = sorted(component)
if len(names) > 3:
sample = ", ".join(names[:2])
lines.append(
f"Component {i} ({len(names)} sources): {sample}, ... (+{len(names) - 2} more)"
)
else:
lines.append(f"Component {i} ({len(names)} sources): {', '.join(names)}")
return [f"Disconnected components: {len(components)}"] + lines

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,285 @@
from __future__ import annotations
import heapq
import logging
from dataclasses import dataclass, field
from semantic_layer.models import SourceDefinition
# DIALECT CONVENTION:
# YAML-authored join `on:` clauses may contain dialect-specific casts
# (e.g. BigQuery `SAFE_CAST(x AS INT64)`). `_parse_on` parses them with
# `read=self.dialect` so the AST reflects the author's intent.
logger = logging.getLogger(__name__)
RELATIONSHIP_INVERSE = {
"many_to_one": "one_to_many",
"one_to_many": "many_to_one",
"one_to_one": "one_to_one",
}
@dataclass
class JoinEdge:
from_source: str
to_source: str
from_column: str
to_column: str
relationship: str
alias: str | None = None
@dataclass
class JoinPath:
edges: list[JoinEdge]
has_one_to_many: bool = False
is_ambiguous: bool = False
@property
def source_names(self) -> list[str]:
if not self.edges:
return []
names = [self.edges[0].from_source]
for e in self.edges:
names.append(e.to_source)
return names
@dataclass
class JoinTree:
edges: list[JoinEdge] = field(default_factory=list)
sources: set[str] = field(default_factory=set)
has_one_to_many: bool = False
class JoinGraph:
def __init__(
self,
sources: dict[str, SourceDefinition],
*,
dialect: str = "postgres",
):
self.sources = sources
self.dialect = dialect
self.adjacency: dict[str, list[JoinEdge]] = {}
def build(self) -> None:
# alias_name → actual source name
self.alias_map: dict[str, str] = {}
for name in self.sources:
self.adjacency.setdefault(name, [])
for source in self.sources.values():
for join in source.joins:
from_col, to_col = self._parse_on(join.on, join.to)
target_name = join.alias if join.alias else join.to
if join.alias:
self.alias_map[join.alias] = join.to
# Forward edge: source → alias (or target)
fwd = JoinEdge(
from_source=source.name,
to_source=target_name,
from_column=from_col,
to_column=to_col,
relationship=join.relationship,
alias=join.alias,
)
self.adjacency.setdefault(target_name, [])
self.adjacency[source.name].append(fwd)
# Reverse edge: alias (or target) → source
rev = JoinEdge(
from_source=target_name,
to_source=source.name,
from_column=to_col,
to_column=from_col,
relationship=RELATIONSHIP_INVERSE[join.relationship],
alias=join.alias,
)
self.adjacency[target_name].append(rev)
def find_path(self, from_source: str, to_source: str) -> JoinPath | None:
"""Dijkstra shortest path between two sources.
Also detects ambiguity: if multiple equal-cost paths exist to the
destination, the returned ``JoinPath`` has ``is_ambiguous=True``.
"""
if from_source == to_source:
return JoinPath(edges=[], has_one_to_many=False)
if from_source not in self.adjacency or to_source not in self.adjacency:
return None
# (cost, counter, current_node, path_edges)
counter = 0
heap: list[tuple[int, int, str, list[JoinEdge]]] = [
(0, counter, from_source, [])
]
visited: set[str] = set()
first_path: JoinPath | None = None
first_cost: int | None = None
while heap:
cost, _, current, path = heapq.heappop(heap)
# All equal-cost alternatives exhausted — stop.
if first_cost is not None and cost > first_cost:
break
if current == to_source:
has_o2m = any(e.relationship == "one_to_many" for e in path)
if first_path is None:
first_path = JoinPath(edges=path, has_one_to_many=has_o2m)
first_cost = cost
continue # don't visit dest — keep looking for alternatives
else:
first_path.is_ambiguous = True
return first_path
if current in visited:
continue
visited.add(current)
for edge in self.adjacency.get(current, []):
if edge.to_source not in visited:
counter += 1
# Prefer safe (many_to_one / one_to_one) paths over one_to_many
edge_cost = (
1 if edge.relationship in ("many_to_one", "one_to_one") else 10
)
heapq.heappush(
heap, (cost + edge_cost, counter, edge.to_source, path + [edge])
)
return first_path
def resolve_join_tree(
self, source_names: set[str], root: str | None = None
) -> JoinTree:
"""
Steiner tree approximation: pick root source,
find shortest path to each other source, merge paths.
"""
if len(source_names) <= 1:
return JoinTree(sources=source_names)
if root is not None and root in source_names:
names = [root] + sorted(source_names - {root})
else:
names = sorted(source_names)
root = names[0]
tree = JoinTree(sources={root})
for target in names[1:]:
if target in tree.sources:
continue
path = self.find_path(root, target)
if path is not None and path.is_ambiguous:
logger.warning(
"Ambiguous join path from '%s' to '%s': multiple equal-cost "
"paths exist. The engine picked one arbitrarily. Use join "
"aliases to disambiguate.",
root,
target,
)
if path is None:
raise ValueError(
f"No join path from '{root}' to '{target}'. "
f"These sources are not connected in the join graph."
)
for edge in path.edges:
if not any(
e.from_source == edge.from_source and e.to_source == edge.to_source
for e in tree.edges
):
tree.edges.append(edge)
if edge.relationship == "one_to_many":
tree.has_one_to_many = True
tree.sources.add(edge.from_source)
tree.sources.add(edge.to_source)
return tree
def find_components(self) -> list[set[str]]:
"""Partition the graph into connected components.
Returns one set per component. For an empty graph, returns []. For a
fully connected graph, returns a single-element list. Used both for
validation (multi-component warning) and for suggest().
Aliases and their base source are treated as belonging to the same
component, since alias-scoped queries resolve back to the base table.
"""
# Bidirectional alias↔base adjacency so BFS treats them as one node
alias_neighbors: dict[str, list[str]] = {}
for alias, base in self.alias_map.items():
alias_neighbors.setdefault(alias, []).append(base)
alias_neighbors.setdefault(base, []).append(alias)
components: list[set[str]] = []
unvisited = set(self.adjacency)
while unvisited:
start = next(iter(unvisited))
component: set[str] = set()
queue = [start]
while queue:
node = queue.pop()
if node in component:
continue
component.add(node)
for edge in self.adjacency.get(node, []):
if edge.to_source not in component:
queue.append(edge.to_source)
for neighbor in alias_neighbors.get(node, []):
if neighbor not in component:
queue.append(neighbor)
components.append(component)
unvisited -= component
return components
def _parse_on(self, on_clause: str, target_source: str) -> tuple[str, str]:
"""
Parse join conditions into (from_columns, to_columns) using sqlglot AST.
Single key: "customer_id = customers.id" ("customer_id", "id")
Composite: "a = t.x AND b = t.y" ("a,b", "x,y")
Composite keys are stored as comma-separated strings.
"""
import sqlglot
from sqlglot import exp as _exp
from semantic_layer.parser import quote_reserved_identifiers
quoted = quote_reserved_identifiers(on_clause)
tree = sqlglot.parse_one(
f"SELECT 1 FROM _a JOIN _b ON {quoted}", read=self.dialect
)
from_cols: list[str] = []
to_cols: list[str] = []
for eq_node in tree.find_all(_exp.EQ):
left = eq_node.left
right = eq_node.right
# Reject nested equality (e.g., "a = b = c")
if isinstance(left, _exp.EQ) or isinstance(right, _exp.EQ):
raise ValueError(f"Invalid join condition: '{on_clause}'")
# Extract column name, stripping any source qualifier
def _col_name(node: _exp.Expression) -> str:
if isinstance(node, _exp.Column):
return node.name
return node.sql(dialect="postgres")
from_cols.append(_col_name(left))
to_cols.append(_col_name(right))
if not from_cols:
raise ValueError(f"Invalid join condition: '{on_clause}'")
return ",".join(from_cols), ",".join(to_cols)

View file

@ -0,0 +1,210 @@
from __future__ import annotations
import logging
import re
from copy import deepcopy
from pathlib import Path
import yaml
from semantic_layer.manifest import (
Manifest,
_description_sources,
_resolve_description,
project_manifest_entry,
validate_overlay,
)
from semantic_layer.models import (
JoinDeclaration,
MeasureDefinition,
Segment,
SourceColumn,
SourceDefinition,
)
logger = logging.getLogger(__name__)
_SCHEMA_DIR = "_schema"
def _normalize_ws(s: str) -> str:
"""Collapse whitespace for join deduplication."""
return re.sub(r"\s+", " ", s.strip())
class SourceLoader:
def __init__(self, sources_dir: str | Path):
self.sources_dir = Path(sources_dir)
def load_all(self) -> dict[str, SourceDefinition]:
"""Load all sources using two-tier architecture.
1. Load _schema/*.yaml manifest shards project to SourceDefinitions
2. Load *.yaml files outside _schema/
- Has `sql` or `table` standalone source (load directly)
- Otherwise overlay (compose with matching manifest entry)
3. Validate cross-references
"""
sources: dict[str, SourceDefinition] = {}
description_sources: dict[str, dict[str, str] | None] = {}
# 1. Load manifest shards
schema_dir = self.sources_dir / _SCHEMA_DIR
if schema_dir.is_dir():
for path in sorted(schema_dir.glob("*.yaml")):
manifest = self._load_manifest_shard(path)
for name, entry in manifest.tables.items():
if name in sources:
raise ValueError(
f"Duplicate source name '{name}' in manifest shard {path}"
)
sources[name] = project_manifest_entry(name, entry)
description_sources[name] = _description_sources(
entry.descriptions, entry.description, entry.db_description
)
# 2. Load files outside _schema/
for path in sorted(self.sources_dir.rglob("*.yaml")):
# Skip manifest shards
if _is_in_schema_dir(path, self.sources_dir):
continue
with open(path) as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
continue
name = data.get("name")
if not name:
continue
if data.get("sql") or data.get("table"):
# Standalone source — load directly
if name in sources:
raise ValueError(
f"Duplicate source name '{name}': standalone file {path} "
f"conflicts with manifest entry"
)
sources[name] = SourceDefinition(**data)
else:
# Overlay — validate and compose with matching manifest entry
errors = validate_overlay(data)
if errors:
raise ValueError(
f"Invalid overlay '{name}' in {path}: {'; '.join(errors)}"
)
base = sources.get(name)
if base:
(
sources[name],
description_sources[name],
) = self._compose(base, data, description_sources.get(name))
else:
logger.warning(
"Orphan overlay '%s' in %s: no matching manifest entry, skipping",
name,
path,
)
self._validate_cross_references(sources)
return sources
def load_file(self, path: str | Path) -> SourceDefinition:
"""Load and validate a single standalone YAML source definition."""
path = Path(path)
with open(path) as f:
data = yaml.safe_load(f)
source = SourceDefinition(**data)
if not source.table and not source.sql:
raise ValueError(
f"Standalone source '{source.name}' in {path} must have 'table' or 'sql'"
)
return source
def _load_manifest_shard(self, path: Path) -> Manifest:
"""Load a single manifest shard file."""
with open(path) as f:
data = yaml.safe_load(f)
return Manifest(**data)
def _compose(
self,
base: SourceDefinition,
overlay: dict,
base_description_sources: dict[str, str] | None = None,
) -> tuple[SourceDefinition, dict[str, str] | None]:
"""Compose a manifest-projected SourceDefinition with an overlay."""
source = deepcopy(base)
description_sources = dict(base_description_sources or {})
# Overlay description semantics match the server: `description` writes the
# `user` source key, and `descriptions` merges keyed sources before a single
# visible description is resolved from the full map.
if overlay.get("description"):
description_sources["user"] = overlay["description"]
if overlay.get("descriptions"):
description_sources.update(
{
source_name: text
for source_name, text in overlay["descriptions"].items()
if text
}
)
if overlay.get("description") or overlay.get("descriptions"):
source.description = _resolve_description(
description_sources or None,
)
# Filter columns
excluded = set(overlay.get("exclude_columns", []))
source.columns = [c for c in source.columns if c.name not in excluded]
# Append computed columns (overlay columns with expr)
for col in overlay.get("columns", []):
source.columns.append(SourceColumn(**col))
# Set measures
source.measures = [MeasureDefinition(**m) for m in overlay.get("measures", [])]
# Set segments
source.segments = [Segment(**s) for s in overlay.get("segments", [])]
# Override grain
if overlay.get("grain"):
source.grain = overlay["grain"]
# Union + dedupe joins, apply suppressions
disabled = {_normalize_ws(j) for j in overlay.get("disable_joins", [])}
manifest_joins = [
j for j in source.joins if _normalize_ws(j.on) not in disabled
]
overlay_joins = [JoinDeclaration(**j) for j in overlay.get("joins", [])]
existing_keys = {f"{j.to}::{_normalize_ws(j.on)}" for j in manifest_joins}
new_joins = [
j
for j in overlay_joins
if f"{j.to}::{_normalize_ws(j.on)}" not in existing_keys
]
source.joins = manifest_joins + new_joins
return source, (description_sources or None)
def _validate_cross_references(self, sources: dict[str, SourceDefinition]) -> None:
"""Validate that all join targets reference existing sources."""
for source in sources.values():
for join in source.joins:
if join.to not in sources:
raise ValueError(
f"Source '{source.name}' joins to '{join.to}', "
f"but '{join.to}' is not defined"
)
def _is_in_schema_dir(path: Path, sources_dir: Path) -> bool:
"""Check if a path is inside the _schema/ directory."""
try:
path.relative_to(sources_dir / _SCHEMA_DIR)
return True
except ValueError:
return False

View file

@ -0,0 +1,233 @@
"""Manifest models and projection for the two-tier schema architecture.
The manifest (`_schema/*.yaml`) stores physical table catalog data with DB-native
types, PK flags, and join provenance. This module handles:
- Manifest-specific data models (ManifestColumn, ManifestJoin, ManifestEntry)
- DB-native semantic type mapping
- Projection from ManifestEntry SourceDefinition
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel
from semantic_layer.models import (
ColumnRole,
DefaultTimeDimensionDbt,
FreshnessDbt,
JoinDeclaration,
SourceColumn,
SourceColumnTests,
SourceDefinition,
)
# ── Type mapping (DB-native → semantic) ─────────────────────────────
_TYPE_MAP: dict[str, str] = {
# number family
"integer": "number",
"bigint": "number",
"smallint": "number",
"numeric": "number",
"decimal": "number",
"float": "number",
"double": "number",
"real": "number",
"int": "number",
"int2": "number",
"int4": "number",
"int8": "number",
"float4": "number",
"float8": "number",
"double precision": "number",
"number": "number",
"tinyint": "number",
"mediumint": "number",
# time family
"timestamp": "time",
"timestamptz": "time",
"timestamp with time zone": "time",
"timestamp without time zone": "time",
"timestamp_ntz": "time",
"timestamp_ltz": "time",
"timestamp_tz": "time",
"datetime": "time",
"date": "time",
"time": "time",
"timetz": "time",
# boolean family
"boolean": "boolean",
"bool": "boolean",
# fallback → 'string'
}
def map_column_type(db_type: str) -> str:
"""Map a DB-native column type to a semantic type (string/number/time/boolean)."""
normalized = db_type.lower().split("(")[0].strip()
return _TYPE_MAP.get(normalized, "string")
# ── Manifest data models ────────────────────────────────────────────
_DEFAULT_PRIORITY = ["user", "ai", "dbt", "db"]
def _description_sources(
descriptions: dict[str, str] | None,
description: str | None = None,
db_description: str | None = None,
) -> dict[str, str] | None:
"""Normalize multi-source descriptions to a keyed map."""
if descriptions:
result = {source: text for source, text in descriptions.items() if text}
if result:
return result
result: dict[str, str] = {}
if description:
result["ai"] = description
if db_description:
result["db"] = db_description
return result or None
def _resolve_description(
descriptions: dict[str, str] | None,
description: str | None = None,
db_description: str | None = None,
) -> str | None:
"""Resolve a single description from a multi-source map or legacy flat fields."""
if descriptions:
for source in _DEFAULT_PRIORITY:
if text := descriptions.get(source):
return text
# Fallback: first available
for text in descriptions.values():
if text:
return text
# Legacy flat fields
if description:
return description
if db_description:
return db_description
return None
class ManifestColumn(BaseModel):
name: str
type: str # DB-native type (e.g., "integer", "varchar", "timestamp")
pk: bool = False
nullable: bool = True
descriptions: dict[str, str] | None = None
# Legacy flat fields (backwards-compatible YAML parsing)
description: str | None = None
db_description: str | None = None
constraints: dict | None = None
enum_values: dict[str, list[str]] | None = None
tests: SourceColumnTests | None = None
@property
def resolved_description(self) -> str | None:
return _resolve_description(
self.descriptions, self.description, self.db_description
)
class ManifestJoin(BaseModel):
to: str
on: str
relationship: Literal["many_to_one", "one_to_many", "one_to_one"]
source: Literal["formal", "inferred", "manual"] = "formal"
class ManifestEntry(BaseModel):
table: str
descriptions: dict[str, str] | None = None
# Legacy flat fields (backwards-compatible YAML parsing)
description: str | None = None
db_description: str | None = None
columns: list[ManifestColumn]
joins: list[ManifestJoin] = []
default_time_dimension: DefaultTimeDimensionDbt | None = None
tags: dict[str, list[str]] | None = None
freshness: dict[str, FreshnessDbt] | None = None
@property
def resolved_description(self) -> str | None:
return _resolve_description(
self.descriptions, self.description, self.db_description
)
class Manifest(BaseModel):
"""A single manifest shard file (`_schema/{schema}.yaml`)."""
tables: dict[str, ManifestEntry]
# ── Projection ──────────────────────────────────────────────────────
def validate_overlay(data: dict) -> list[str]:
"""Validate that overlay data doesn't contain structural fields.
Returns a list of error messages (empty if valid).
"""
errors: list[str] = []
if "table" in data:
errors.append("Overlay must not contain 'table' (owned by manifest)")
if "sql" in data:
errors.append(
"Overlay must not contain 'sql' (that makes it a standalone source)"
)
for col in data.get("columns", []):
if "type" in col and "expr" not in col:
errors.append(
f"Overlay column '{col.get('name', '?')}' specifies 'type' without 'expr' "
f"(structural types are inherited from manifest — only computed columns may specify a type)"
)
return errors
def project_manifest_entry(name: str, entry: ManifestEntry) -> SourceDefinition:
"""Convert a raw manifest entry into a valid SourceDefinition.
- Maps DB-native column types to semantic types
- Auto-derives grain from PK columns (or all columns if no PKs)
- Strips join provenance (source field)
"""
columns = [
SourceColumn(
name=c.name,
type=map_column_type(c.type),
role=ColumnRole.TIME
if map_column_type(c.type) == "time"
else ColumnRole.DEFAULT,
description=c.resolved_description,
constraints=c.constraints,
enum_values=c.enum_values,
tests=c.tests,
)
for c in entry.columns
]
pk_columns = [c.name for c in entry.columns if c.pk]
grain = pk_columns if pk_columns else [c.name for c in entry.columns]
return SourceDefinition(
name=name,
table=entry.table,
description=entry.resolved_description,
grain=grain,
columns=columns,
joins=[
JoinDeclaration(to=j.to, on=j.on, relationship=j.relationship)
for j in entry.joins
],
default_time_dimension=entry.default_time_dimension,
tags=entry.tags,
freshness=entry.freshness,
)

View file

@ -0,0 +1,235 @@
from __future__ import annotations
from enum import Enum
from typing import Any, Literal
from pydantic import BaseModel, Field, model_validator
# ── Source Definition Models ──────────────────────────────────────────
class ColumnVisibility(str, Enum):
PUBLIC = "public"
INTERNAL = "internal"
HIDDEN = "hidden"
class ColumnRole(str, Enum):
TIME = "time"
DEFAULT = "default"
class ColumnDbtConstraints(BaseModel):
not_null: bool | None = None
unique: bool | None = None
class DbtDataTestRef(BaseModel):
name: str
package: str
kwargs: dict[str, Any] | None = None
class SourceColumnTests(BaseModel):
dbt: list[DbtDataTestRef] | None = None
dbt_by_package: dict[str, list[str]] | None = None
class FreshnessDbt(BaseModel):
raw: Any | None = None
loaded_at_field: str | None = None
class SourceColumn(BaseModel):
name: str
type: Literal["string", "number", "time", "boolean"]
visibility: ColumnVisibility = ColumnVisibility.PUBLIC
role: ColumnRole = ColumnRole.DEFAULT
description: str | None = None
expr: str | None = None
natural_granularity: str | None = None
constraints: dict[str, ColumnDbtConstraints] | None = None
enum_values: dict[str, list[str]] | None = None
tests: SourceColumnTests | None = None
class JoinDeclaration(BaseModel):
to: str
on: str # e.g. "customer_id = customers.id"
relationship: Literal["many_to_one", "one_to_many", "one_to_one"]
alias: str | None = None
class MeasureDefinition(BaseModel):
name: str
expr: str # e.g. "sum(amount)"
filter: str | None = None # e.g. "status != 'refunded'"
segments: list[str] = [] # bare segment names defined on the measure's own source
description: str | None = None
class Segment(BaseModel):
"""A named, reusable boolean predicate scoped to a single source."""
name: str
expr: str # e.g. "is_paid = true and is_refunded = '0'"
description: str | None = None
class DefaultTimeDimensionDbt(BaseModel):
dbt: str | None = None
class SourceDefinition(BaseModel):
name: str
description: str | None = None
table: str | None = None
sql: str | None = None
grain: list[str]
columns: list[SourceColumn]
joins: list[JoinDeclaration] = []
measures: list[MeasureDefinition] = []
segments: list[Segment] = []
default_time_dimension: DefaultTimeDimensionDbt | None = None
tags: dict[str, list[str]] | None = None
freshness: dict[str, FreshnessDbt] | None = None
@model_validator(mode="after")
def validate_source(self) -> SourceDefinition:
if self.table and self.sql:
raise ValueError("'table' and 'sql' are mutually exclusive")
if not self.grain:
raise ValueError("grain must be non-empty")
return self
@property
def is_sql_source(self) -> bool:
return self.sql is not None
@property
def is_table_source(self) -> bool:
return self.table is not None
# ── Query Models ──────────────────────────────────────────────────────
class QueryMeasure(BaseModel):
"""Either a pre-defined name ('orders.revenue') or runtime expr."""
ref: str | None = None
expr: str | None = None
name: str | None = None
class QueryDimension(BaseModel):
"""Either a column ref or a time granularity."""
field: str
granularity: str | None = None
class SemanticQuery(BaseModel):
measures: list[str | dict[str, Any]]
dimensions: list[str | dict[str, Any]] = []
filters: list[str] = []
# dotted "source.segment" names; AND-ed into matching measures
segments: list[str] = []
order_by: list[str | dict[str, Any]] = []
limit: int = 1000
include_empty: bool = True
@model_validator(mode="after")
def _validate_limit(self) -> SemanticQuery:
if self.limit is not None and self.limit < 0:
raise ValueError(f"limit must be non-negative, got {self.limit}")
return self
# ── Plan & Result Models ──────────────────────────────────────────────
class Provenance(str, Enum):
VERIFIED = "verified"
COMPOSED = "composed"
DIMENSION = "dimension"
class ResolvedColumn(BaseModel):
name: str
provenance: Provenance
expr: str | None = None
description: str | None = None
granularity: str | None = None
class ResolvedMeasure(BaseModel):
name: str
expr: str # the aggregate expression, e.g. "sum(amount)"
source_name: str
original_name: str | None = None
qualified_ref: str | None = None
filter: str | None = None
provenance: Provenance = Provenance.COMPOSED
is_derived: bool = False
depends_on: list[str] = [] # names of other measures this depends on
description: str | None = None
class MeasureGroup(BaseModel):
"""A group of measures from the same source, for aggregate locality."""
source_name: str
measures: list[ResolvedMeasure]
join_path_to_dims: list[str] = []
class ResolvedJoin(BaseModel):
from_source: str
to_source: str
from_column: str
to_column: str
relationship: str
class OrderByClause(BaseModel):
field: str
direction: str = "asc"
class ResolvedPlan(BaseModel):
sources_used: list[str]
join_paths: list[str] # human-readable descriptions
joins: list[ResolvedJoin] = [] # structured join info for generator
anchor_source: str | None = None # the primary FROM source
anchor_grain: list[str]
fan_out_description: str
has_fan_out: bool = False
measure_groups: list[MeasureGroup] = []
aggregate_locality: list[str] # human-readable CTE descriptions
where_filters: list[str]
having_filters: list[str]
columns: list[ResolvedColumn]
measures: list[ResolvedMeasure] = []
dimensions: list[QueryDimension] = []
order_by: list[OrderByClause] = []
limit: int | None = None
include_empty: bool = True
class QueryResult(BaseModel):
resolved_plan: ResolvedPlan
sql: str
dialect: str
columns: list[ResolvedColumn]
class ValidationReport(BaseModel):
errors: list[str] = Field(default_factory=list)
warnings: list[str] = Field(default_factory=list)
per_source_warnings: dict[str, list[str]] = Field(default_factory=dict)
@property
def valid(self) -> bool:
return len(self.errors) == 0

View file

@ -0,0 +1,303 @@
from __future__ import annotations
import functools
import re
from dataclasses import dataclass, field
import sqlglot
from sqlglot import exp
# DIALECT CONVENTION:
# `ExpressionParser` wraps read-only AST walks over user-authored
# expressions. Callers must construct it with the connection's native
# dialect (per sl_capture). The parse cache is keyed on (sql, dialect)
# so engines with different dialects do not share AST collisions.
AGGREGATE_FUNCTIONS = frozenset(
{
"sum",
"avg",
"count",
"count_distinct",
"min",
"max",
"median",
"percentile",
}
)
# Maps sqlglot AggFunc subclasses to our canonical names
_AGG_NODE_MAP: dict[type, str] = {
exp.Sum: "sum",
exp.Avg: "avg",
exp.Count: "count",
exp.Min: "min",
exp.Max: "max",
}
# Custom aggregates that sqlglot parses as Anonymous (not standard SQL)
_CUSTOM_AGG_NAMES = frozenset({"count_distinct", "percentile", "median"})
# SQL reserved words that cause parse failures when used as identifiers
_SQL_RESERVED = frozenset(
{
"select",
"from",
"where",
"group",
"order",
"by",
"having",
"limit",
"join",
"on",
"as",
"and",
"or",
"not",
"in",
"is",
"null",
"true",
"false",
"between",
"like",
"case",
"when",
"then",
"else",
"end",
"insert",
"update",
"delete",
"create",
"drop",
"alter",
"table",
"index",
"view",
"union",
"all",
"distinct",
"into",
"values",
"set",
"with",
"exists",
"any",
"some",
"offset",
"fetch",
"for",
"grant",
"revoke",
"primary",
"key",
"foreign",
"references",
"check",
"constraint",
"default",
"column",
"cross",
"full",
"inner",
"left",
"right",
"outer",
"natural",
"using",
"except",
"intersect",
# Snowflake / cross-dialect reserved words
"glob",
"ilike",
"lateral",
"match_recognize",
"notnull",
"out",
"qualify",
"regexp",
"returning",
"rlike",
"rollback",
"sample",
"tablesample",
"top",
"uncache",
"xor",
}
)
# Regex pattern for source.column references (word.word)
_DOTTED_IDENT_RE = re.compile(r"\b(\w+)\.(\w+)\b")
# Matches single-quoted SQL string literals (including escaped quotes '')
_STRING_LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
@dataclass
class ParsedExpression:
original: str
source_refs: set[str] = field(default_factory=set)
column_refs: set[str] = field(default_factory=set) # "source.column" format
is_aggregate: bool = False
aggregate_function: str | None = None
has_window_function: bool = False
depends_on_measures: set[str] = field(default_factory=set)
def _strip_quotes(name: str) -> str:
"""Strip surrounding double quotes from an identifier."""
if name.startswith('"') and name.endswith('"'):
return name[1:-1]
return name
def quote_reserved_identifiers(expr: str) -> str:
"""Quote source.column references where either part is a SQL reserved word.
String literals are masked before processing to prevent matching
dotted identifiers inside quoted strings like 'group.value'.
"""
# Mask string literals to avoid matching inside them
literals: list[str] = []
def _mask_literal(m: re.Match) -> str:
literals.append(m.group(0))
return f"__SL_LIT_{len(literals) - 1}__"
masked = _STRING_LITERAL_RE.sub(_mask_literal, expr)
def _quote_match(m: re.Match) -> str:
source, col = m.group(1), m.group(2)
start = m.start()
if start > 0 and masked[start - 1] == '"':
return m.group(0)
needs_quote = False
source_q = source
col_q = col
if source.lower() in _SQL_RESERVED:
source_q = f'"{source}"'
needs_quote = True
if col.lower() in _SQL_RESERVED:
col_q = f'"{col}"'
needs_quote = True
if needs_quote:
return f"{source_q}.{col_q}"
return m.group(0)
result = _DOTTED_IDENT_RE.sub(_quote_match, masked)
# Restore string literals
for i, lit in enumerate(literals):
result = result.replace(f"__SL_LIT_{i}__", lit)
return result
@functools.lru_cache(maxsize=256)
def _cached_parse_select(sql: str, dialect: str) -> exp.Expression:
"""Cache parsed SELECT wrapper trees keyed by (sql, dialect).
Each (sql, dialect) pair gets its own entry, so engines using different
dialects don't share AST cache collisions.
"""
return sqlglot.parse_one(sql, read=dialect)
class ExpressionParser:
"""Parses user-authored SQL expressions for AST walks.
Must be constructed with the connection's native dialect. User-authored
`expr:`, `filter:`, and segment predicates from YAML are written in that
dialect (per the sl_capture skill contract) and parsing them as postgres
silently drops dialect-specific tokens (e.g. BigQuery `INTERVAL 30 DAY`).
"""
def __init__(self, dialect: str = "postgres") -> None:
self.dialect = dialect
def _quote_reserved_identifiers(self, expr: str) -> str:
return quote_reserved_identifiers(expr)
def _parse_as_select(self, quoted_expr: str) -> exp.Expression:
"""Parse expression wrapped in SELECT, using cache for repeated expressions."""
return _cached_parse_select(f"SELECT {quoted_expr}", self.dialect)
def parse(
self,
expr: str,
known_measure_names: set[str] | None = None,
) -> ParsedExpression:
known_measure_names = known_measure_names or set()
result = ParsedExpression(original=expr)
if not expr or not expr.strip():
return result
quoted_expr = self._quote_reserved_identifiers(expr)
tree = self._parse_as_select(quoted_expr)
# Extract source.column references
for col in tree.find_all(exp.Column):
if col.table:
source_name = _strip_quotes(col.table)
col_name = _strip_quotes(col.name)
result.source_refs.add(source_name)
result.column_refs.add(f"{source_name}.{col_name}")
# Detect aggregate functions (built-in AggFunc subclasses).
# Aggregates nested inside scalar/correlated subqueries do NOT make the
# outer expression aggregate — e.g. `col = (SELECT MAX(col) FROM t)` is a
# plain column predicate, not a HAVING candidate.
def _inside_subquery(node: exp.Expression) -> bool:
parent = node.parent
while parent is not None:
if isinstance(parent, exp.Subquery):
return True
parent = parent.parent
return False
agg_names: list[str] = []
for node in tree.find_all(exp.AggFunc):
if _inside_subquery(node):
continue
name = _AGG_NODE_MAP.get(type(node))
if name:
agg_names.append(name)
else:
agg_names.append(node.key.lower())
# Detect custom aggregates parsed as Anonymous (count_distinct, percentile, median)
for node in tree.find_all(exp.Anonymous):
if _inside_subquery(node):
continue
if node.name.lower() in _CUSTOM_AGG_NAMES:
agg_names.append(node.name.lower())
if agg_names:
result.is_aggregate = True
result.aggregate_function = agg_names[0]
# Detect window functions (OVER clause)
if tree.find(exp.Window):
result.has_window_function = True
# Detect dependencies on named measures (bare identifiers without table qualifier)
if known_measure_names:
for col in tree.find_all(exp.Column):
if not col.table and col.name in known_measure_names:
result.depends_on_measures.add(col.name)
return result
def extract_source_refs(self, expr: str) -> set[str]:
"""Quick extraction of source names from an expression."""
if not expr or not expr.strip():
return set()
quoted_expr = self._quote_reserved_identifiers(expr)
tree = self._parse_as_select(quoted_expr)
return {
_strip_quotes(col.table) for col in tree.find_all(exp.Column) if col.table
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,72 @@
from __future__ import annotations
import logging
import sqlglot
from sqlglot import exp
logger = logging.getLogger(__name__)
def extract_table_refs(sql: str, dialect: str = "postgres") -> list[tuple[str, ...]]:
"""Return a deduped list of warehouse-table refs found in `sql` as
tuples of normalized (lowercase, unquoted) name parts.
Skips CTE self-references. Returns refs in the order they first appear
so callers can present consistent error messages. Each tuple is the
fully-qualified name as written in the SQL: `("staging", "shipments")`,
`("analytics", "marts", "listings")`, or `("listings",)`.
On parse failure returns []; coverage check is best-effort and must
not break source writes when the SQL has unusual syntax.
"""
try:
tree = sqlglot.parse_one(sql, dialect=dialect)
except Exception as e:
logger.debug("sql_table_extractor: parse failed (%s); skipping coverage", e)
return []
cte_names = {cte.alias_or_name.lower() for cte in tree.find_all(exp.CTE)}
seen: set[tuple[str, ...]] = set()
out: list[tuple[str, ...]] = []
for t in tree.find_all(exp.Table):
name = (t.name or "").lower()
if not name or name in cte_names:
continue
parts: list[str] = []
catalog = t.args.get("catalog")
db = t.args.get("db")
if catalog and getattr(catalog, "name", None):
parts.append(catalog.name.lower())
if db and getattr(db, "name", None):
parts.append(db.name.lower())
parts.append(name)
ref = tuple(parts)
if ref not in seen:
seen.add(ref)
out.append(ref)
return out
def normalize_table(value: str) -> tuple[str, ...]:
"""Split a `table:` field value into normalized, lowercased parts."""
return tuple(p.strip('"').strip("`").lower() for p in value.split(".") if p)
def ref_matches_source_table(ref: tuple[str, ...], source_table: str) -> bool:
"""True iff `ref` is a suffix of `source_table` (or vice versa for the
1-part bare-name case).
Examples:
ref=(marts, listings) table=ANALYTICS.MARTS.LISTINGS True
ref=(analytics, marts, x) table=ANALYTICS.MARTS.X True
ref=(listings,) table=ANALYTICS.MARTS.LISTINGS True (bare matches last)
ref=(staging, shipments) table=ANALYTICS.MARTS.SHIPMENTS False (db differs)
"""
src = normalize_table(source_table)
if not src or not ref:
return False
if len(ref) > len(src):
return False
return src[-len(ref) :] == ref

View file

@ -0,0 +1,111 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Literal
import sqlglot
from sqlglot import exp
logger = logging.getLogger(__name__)
SUPPORTED_TABLE_IDENTIFIER_DIALECTS = {
"bigquery",
"snowflake",
"postgres",
"redshift",
"mysql",
"sqlite",
"tsql",
"clickhouse",
}
ParseTableIdentifierReason = Literal[
"looker_template_unresolved",
"derived_table_not_supported",
"no_physical_table",
"multiple_table_references",
"unsupported_dialect",
"parse_error",
]
@dataclass(frozen=True)
class ParseTableIdentifierItem:
key: str
sql_table_name: str
dialect: str
@dataclass(frozen=True)
class ParsedIdentifier:
ok: bool
catalog: str | None = None
schema_: str | None = None
name: str | None = None
canonical_table: str | None = None
reason: ParseTableIdentifierReason | None = None
detail: str | None = None
def parse_table_identifier_batch(
items: list[ParseTableIdentifierItem],
) -> dict[str, ParsedIdentifier]:
return {
item.key: parse_table_identifier_one(item.sql_table_name, item.dialect)
for item in items
}
def parse_table_identifier_one(sql_table_name: str, dialect: str) -> ParsedIdentifier:
normalized_dialect = dialect.lower()
if normalized_dialect not in SUPPORTED_TABLE_IDENTIFIER_DIALECTS:
return ParsedIdentifier(
ok=False,
reason="unsupported_dialect",
detail=f"Unsupported sqlglot dialect for table identifier parsing: {dialect}",
)
if "${" in sql_table_name or "@{" in sql_table_name:
return ParsedIdentifier(ok=False, reason="looker_template_unresolved")
try:
parsed = sqlglot.parse_one(
f"SELECT * FROM {sql_table_name}",
read=normalized_dialect,
)
from_clause = parsed.args.get("from_")
if from_clause is None or from_clause.this is None:
return ParsedIdentifier(ok=False, reason="no_physical_table")
from_expr = from_clause.this
if isinstance(from_expr, (exp.Subquery, exp.Values, exp.Lateral)):
return ParsedIdentifier(ok=False, reason="derived_table_not_supported")
if not isinstance(from_expr, exp.Table):
return ParsedIdentifier(ok=False, reason="derived_table_not_supported")
tables = list(parsed.find_all(exp.Table))
if not tables:
return ParsedIdentifier(ok=False, reason="no_physical_table")
if len(tables) > 1:
return ParsedIdentifier(ok=False, reason="multiple_table_references")
table = tables[0]
canonical_table = exp.Table(
this=exp.to_identifier(table.name),
db=exp.to_identifier(table.db) if table.db else None,
catalog=exp.to_identifier(table.catalog) if table.catalog else None,
).sql(dialect=normalized_dialect)
return ParsedIdentifier(
ok=True,
catalog=table.catalog or None,
schema_=table.db or None,
name=table.name,
canonical_table=canonical_table,
)
except sqlglot.errors.ParseError as exc:
return ParsedIdentifier(ok=False, reason="parse_error", detail=str(exc))
except Exception as exc:
logger.exception("Unexpected failure while parsing Looker sql_table_name")
return ParsedIdentifier(ok=False, reason="parse_error", detail=str(exc))

View file

@ -0,0 +1,15 @@
name: abm_engagements
table: abm_engagements
grain:
- row_id
columns:
- name: account_id
type: number
- name: engagement_month
type: string
- name: row_id
type: number
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,18 @@
name: account_intent_signals
table: account_intent_signals
grain:
- signal_id
columns:
- name: signal_id
type: number
- name: account_id
type: number
- name: signal_date
type: time
role: time
- name: topic
type: string
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,23 @@
name: accounts
table: accounts
grain:
- account_id
columns:
- name: account_id
type: number
- name: account_name
type: string
- name: csm_rep_id
type: number
- name: industry
type: string
- name: is_customer
type: string
- name: region
type: string
- name: segment
type: string
joins:
- to: sales_reps
'on': csm_rep_id = sales_reps.rep_id
relationship: many_to_one

View file

@ -0,0 +1,36 @@
name: activities
table: activities
grain:
- activity_id
columns:
- name: activity_id
type: number
- name: account_id
type: number
- name: activity_date
type: time
role: time
- name: activity_type
type: string
- name: channel
type: string
- name: direction
type: string
- name: duration_minutes
type: number
- name: opportunity_id
type: number
- name: rep_id
type: number
- name: subject
type: string
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one
- to: sales_reps
'on': rep_id = sales_reps.rep_id
relationship: many_to_one

View file

@ -0,0 +1,13 @@
name: ad_accounts
table: ad_accounts
grain:
- ad_account_id
columns:
- name: ad_account_id
type: number
- name: account_name
type: string
- name: currency
type: string
- name: platform
type: string

View file

@ -0,0 +1,24 @@
name: ad_ad_stats
table: ad_ad_stats
grain:
- row_id
columns:
- name: ad_id
type: number
- name: clicks
type: number
- name: conversions
type: number
- name: impressions
type: number
- name: row_id
type: number
- name: spend
type: number
- name: stat_date
type: time
role: time
joins:
- to: ads
'on': ad_id = ads.ad_id
relationship: many_to_one

View file

@ -0,0 +1,28 @@
name: ad_campaigns
table: ad_campaigns
grain:
- ad_campaign_id
columns:
- name: ad_campaign_id
type: number
- name: ad_account_id
type: number
- name: campaign_name
type: string
- name: channel
type: string
- name: end_date
type: time
role: time
- name: objective
type: string
- name: start_date
type: time
role: time
joins:
- to: ad_accounts
'on': ad_account_id = ad_accounts.ad_account_id
relationship: many_to_one
- to: accounts
'on': ad_account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,24 @@
name: ad_creative_stats
table: ad_creative_stats
grain:
- row_id
columns:
- name: clicks
type: number
- name: conversions
type: number
- name: creative_id
type: number
- name: impressions
type: number
- name: row_id
type: number
- name: spend
type: number
- name: stat_date
type: time
role: time
joins:
- to: ad_creatives
'on': creative_id = ad_creatives.creative_id
relationship: many_to_one

View file

@ -0,0 +1,20 @@
name: ad_creatives
table: ad_creatives
grain:
- creative_id
columns:
- name: creative_id
type: number
- name: ad_campaign_id
type: number
- name: created_at
type: time
role: time
- name: format
type: string
- name: name
type: string
joins:
- to: ad_campaigns
'on': ad_campaign_id = ad_campaigns.ad_campaign_id
relationship: many_to_one

View file

@ -0,0 +1,17 @@
name: ad_groups
table: ad_groups
grain:
- ad_group_id
columns:
- name: ad_group_id
type: number
- name: ad_campaign_id
type: number
- name: name
type: string
- name: status
type: string
joins:
- to: ad_campaigns
'on': ad_campaign_id = ad_campaigns.ad_campaign_id
relationship: many_to_one

View file

@ -0,0 +1,24 @@
name: ad_stats
table: ad_stats
grain:
- stat_id
columns:
- name: stat_id
type: number
- name: ad_campaign_id
type: number
- name: clicks
type: number
- name: conversions
type: number
- name: impressions
type: number
- name: spend
type: number
- name: stat_date
type: time
role: time
joins:
- to: ad_campaigns
'on': ad_campaign_id = ad_campaigns.ad_campaign_id
relationship: many_to_one

View file

@ -0,0 +1,20 @@
name: ads
table: ads
grain:
- ad_id
columns:
- name: ad_id
type: number
- name: ad_group_id
type: number
- name: created_at
type: time
role: time
- name: name
type: string
- name: status
type: string
joins:
- to: ad_groups
'on': ad_group_id = ad_groups.ad_group_id
relationship: many_to_one

View file

@ -0,0 +1,23 @@
name: ap_bills
table: ap_bills
grain:
- bill_id
columns:
- name: bill_id
type: number
- name: amount
type: number
- name: bill_date
type: time
role: time
- name: due_date
type: time
role: time
- name: status
type: string
- name: vendor_id
type: number
joins:
- to: vendors
'on': vendor_id = vendors.vendor_id
relationship: many_to_one

View file

@ -0,0 +1,26 @@
name: approvals
table: approvals
grain:
- approval_id
columns:
- name: approval_id
type: number
- name: approved_at
type: time
role: time
- name: approver_rep_id
type: number
- name: quote_id
type: number
- name: requested_at
type: time
role: time
- name: status
type: string
joins:
- to: sales_reps
'on': approver_rep_id = sales_reps.rep_id
relationship: many_to_one
- to: quotes
'on': quote_id = quotes.quote_id
relationship: many_to_one

View file

@ -0,0 +1,22 @@
name: attribution_credits
table: attribution_credits
grain:
- credit_id
columns:
- name: credit_id
type: number
- name: credit
type: string
- name: model
type: string
- name: opportunity_id
type: number
- name: touchpoint_id
type: number
joins:
- to: touchpoints
'on': touchpoint_id = touchpoints.touchpoint_id
relationship: many_to_one
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one

View file

@ -0,0 +1,15 @@
name: budgets
table: budgets
grain:
- budget_id
columns:
- name: budget_id
type: number
- name: department
type: string
- name: period_end
type: string
- name: period_start
type: string
- name: planned_amount
type: number

View file

@ -0,0 +1,28 @@
name: calls
table: calls
grain:
- call_id
columns:
- name: call_id
type: number
- name: call_date
type: time
role: time
- name: duration_minutes
type: number
- name: opportunity_id
type: number
- name: rep_id
type: number
- name: sentiment
type: time
role: time
- name: transcript_url
type: string
joins:
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one
- to: sales_reps
'on': rep_id = sales_reps.rep_id
relationship: many_to_one

View file

@ -0,0 +1,23 @@
name: campaign_members
table: campaign_members
grain:
- campaign_member_id
columns:
- name: campaign_member_id
type: number
- name: campaign_id
type: number
- name: lead_id
type: number
- name: responded_at
type: time
role: time
- name: status
type: string
joins:
- to: campaigns
'on': campaign_id = campaigns.campaign_id
relationship: many_to_one
- to: leads
'on': lead_id = leads.lead_id
relationship: many_to_one

View file

@ -0,0 +1,19 @@
name: campaigns
table: campaigns
grain:
- campaign_id
columns:
- name: campaign_id
type: number
- name: budget
type: string
- name: campaign_name
type: string
- name: end_date
type: time
role: time
- name: start_date
type: time
role: time
- name: type
type: string

View file

@ -0,0 +1,22 @@
name: card_transactions
table: card_transactions
grain:
- amount
columns:
- name: amount
type: number
- name: card_txn_id
type: number
- name: department
type: string
- name: employee_email
type: string
- name: txn_date
type: time
role: time
- name: vendor_id
type: number
joins:
- to: vendors
'on': vendor_id = vendors.vendor_id
relationship: many_to_one

View file

@ -0,0 +1,12 @@
name: cash_balances
table: cash_balances
grain:
- balance
columns:
- name: balance
type: string
- name: balance_date
type: time
role: time
- name: bank_account
type: string

View file

@ -0,0 +1,24 @@
name: charges
table: charges
grain:
- charge_id
columns:
- name: charge_id
type: number
- name: amount
type: number
- name: created_at
type: time
role: time
- name: currency
type: string
- name: payment_intent_id
type: number
- name: payment_method
type: string
- name: status
type: string
joins:
- to: payment_intents
'on': payment_intent_id = payment_intents.payment_intent_id
relationship: many_to_one

View file

@ -0,0 +1,290 @@
name: churn_risk
description: |
Per-account churn risk scoring for B2B SaaS customers. Combines signals from
subscriptions (cancellation history), support tickets (severity, SLA breaches),
product usage (adoption decline), contracts (renewal proximity), CSM activities
(engagement recency), and invoices (payment issues) into a weighted composite
risk_score (0-1) and risk_tier (High/Medium/Low). One row per customer account.
sql: |
WITH sub_signals AS (
SELECT
account_id,
MAX(CASE WHEN canceled_at IS NOT NULL THEN 1 ELSE 0 END) AS has_canceled,
COUNT(CASE WHEN canceled_at IS NOT NULL THEN 1 END) AS canceled_count,
STRING_AGG(DISTINCT churn_reason, ', ') AS churn_reasons
FROM subscriptions
GROUP BY account_id
),
ticket_signals AS (
SELECT
account_id,
COUNT(*) AS total_tickets,
COUNT(CASE WHEN status = 'Open' THEN 1 END) AS open_tickets,
COUNT(CASE WHEN severity = 'High' THEN 1 END) AS high_severity_tickets,
COUNT(CASE WHEN sla_breached = '1' OR sla_breached = 'true' THEN 1 END) AS sla_breaches
FROM support_tickets
GROUP BY account_id
),
usage_signals AS (
SELECT
account_id,
AVG(CASE WHEN CURRENT_DATE - usage_date <= 90
THEN CAST(active_users AS NUMERIC) END) AS recent_active_users,
AVG(CASE WHEN CURRENT_DATE - usage_date > 90
AND CURRENT_DATE - usage_date <= 180
THEN CAST(active_users AS NUMERIC) END) AS prior_active_users,
AVG(CASE WHEN CURRENT_DATE - usage_date <= 90
THEN CAST(events_count AS NUMERIC) END) AS recent_events,
AVG(CASE WHEN CURRENT_DATE - usage_date > 90
AND CURRENT_DATE - usage_date <= 180
THEN CAST(events_count AS NUMERIC) END) AS prior_events
FROM product_usage
GROUP BY account_id
),
contract_signals AS (
SELECT
account_id,
MAX(arr) AS current_arr,
MIN(CASE WHEN status = 'Active'
THEN end_date - CURRENT_DATE END) AS days_to_renewal,
COUNT(CASE WHEN status = 'Active' THEN 1 END) AS active_contracts
FROM contracts
GROUP BY account_id
),
activity_signals AS (
SELECT
account_id,
COUNT(CASE WHEN CURRENT_DATE - activity_date::date <= 90
THEN 1 END) AS recent_activities,
MIN(CURRENT_DATE - activity_date::date) AS days_since_last_activity
FROM activities
GROUP BY account_id
),
invoice_signals AS (
SELECT
account_id,
COUNT(CASE WHEN status = 'Partial' THEN 1 END) AS partial_invoices,
COUNT(CASE WHEN CURRENT_DATE > due_date
AND status != 'Paid' THEN 1 END) AS overdue_invoices
FROM invoices
GROUP BY account_id
),
scored AS (
SELECT
a.account_id,
COALESCE(s.has_canceled, 0) AS has_canceled,
COALESCE(s.canceled_count, 0) AS canceled_count,
s.churn_reasons,
COALESCE(t.open_tickets, 0) AS open_tickets,
COALESCE(t.high_severity_tickets, 0) AS high_severity_tickets,
COALESCE(t.sla_breaches, 0) AS sla_breaches,
COALESCE(u.recent_active_users, 0) AS recent_active_users,
COALESCE(u.prior_active_users, 0) AS prior_active_users,
COALESCE(u.recent_events, 0) AS recent_events,
COALESCE(c.current_arr, 0) AS current_arr,
COALESCE(c.days_to_renewal, 999) AS days_to_renewal,
COALESCE(c.active_contracts, 0) AS active_contracts,
COALESCE(act.recent_activities, 0) AS recent_activities,
COALESCE(act.days_since_last_activity, 999) AS days_since_last_activity,
COALESCE(inv.partial_invoices, 0) AS partial_invoices,
COALESCE(inv.overdue_invoices, 0) AS overdue_invoices,
CASE WHEN COALESCE(s.has_canceled, 0) = 1 THEN 1.0
WHEN COALESCE(s.canceled_count, 0) > 0 THEN 0.7
ELSE 0.1 END AS subscription_risk,
CASE WHEN COALESCE(t.high_severity_tickets, 0) >= 3 THEN 0.9
WHEN COALESCE(t.sla_breaches, 0) >= 2 THEN 0.8
WHEN COALESCE(t.open_tickets, 0) >= 3 THEN 0.7
WHEN COALESCE(t.open_tickets, 0) >= 1 THEN 0.4
ELSE 0.1 END AS support_risk,
CASE WHEN COALESCE(u.recent_active_users, 0) = 0 THEN 0.9
WHEN COALESCE(u.prior_active_users, 0) > 0
AND COALESCE(u.recent_active_users, 0) < COALESCE(u.prior_active_users, 0) * 0.5
THEN 0.8
WHEN COALESCE(u.prior_active_users, 0) > 0
AND COALESCE(u.recent_active_users, 0) < COALESCE(u.prior_active_users, 0) * 0.8
THEN 0.5
ELSE 0.1 END AS usage_risk,
CASE WHEN COALESCE(c.days_to_renewal, 999) <= 30 THEN 0.9
WHEN COALESCE(c.days_to_renewal, 999) <= 60 THEN 0.7
WHEN COALESCE(c.days_to_renewal, 999) <= 90 THEN 0.5
WHEN COALESCE(c.active_contracts, 0) = 0 THEN 0.8
ELSE 0.1 END AS contract_risk,
CASE WHEN COALESCE(act.days_since_last_activity, 999) > 90 THEN 0.9
WHEN COALESCE(act.days_since_last_activity, 999) > 60 THEN 0.7
WHEN COALESCE(act.recent_activities, 0) <= 2 THEN 0.6
WHEN COALESCE(act.days_since_last_activity, 999) > 30 THEN 0.4
ELSE 0.1 END AS engagement_risk,
CASE WHEN COALESCE(inv.overdue_invoices, 0) >= 2 THEN 0.9
WHEN COALESCE(inv.overdue_invoices, 0) >= 1 THEN 0.7
WHEN COALESCE(inv.partial_invoices, 0) >= 2 THEN 0.6
WHEN COALESCE(inv.partial_invoices, 0) >= 1 THEN 0.3
ELSE 0.1 END AS payment_risk
FROM accounts a
LEFT JOIN sub_signals s ON a.account_id = s.account_id
LEFT JOIN ticket_signals t ON a.account_id = t.account_id
LEFT JOIN usage_signals u ON a.account_id = u.account_id
LEFT JOIN contract_signals c ON a.account_id = c.account_id
LEFT JOIN activity_signals act ON a.account_id = act.account_id
LEFT JOIN invoice_signals inv ON a.account_id = inv.account_id
WHERE a.is_customer = '1'
)
SELECT
account_id,
has_canceled,
canceled_count,
churn_reasons,
open_tickets,
high_severity_tickets,
sla_breaches,
recent_active_users,
prior_active_users,
recent_events,
current_arr,
days_to_renewal,
active_contracts,
recent_activities,
days_since_last_activity,
partial_invoices,
overdue_invoices,
subscription_risk,
support_risk,
usage_risk,
contract_risk,
engagement_risk,
payment_risk,
ROUND(
subscription_risk * 0.20
+ support_risk * 0.20
+ usage_risk * 0.20
+ contract_risk * 0.15
+ engagement_risk * 0.15
+ payment_risk * 0.10,
3
) AS risk_score,
CASE
WHEN (subscription_risk * 0.20
+ support_risk * 0.20
+ usage_risk * 0.20
+ contract_risk * 0.15
+ engagement_risk * 0.15
+ payment_risk * 0.10) >= 0.7 THEN 'High'
WHEN (subscription_risk * 0.20
+ support_risk * 0.20
+ usage_risk * 0.20
+ contract_risk * 0.15
+ engagement_risk * 0.15
+ payment_risk * 0.10) >= 0.4 THEN 'Medium'
ELSE 'Low'
END AS risk_tier
FROM scored
grain:
- account_id
columns:
- name: account_id
type: number
- name: has_canceled
type: number
description: "1 if the account has any canceled subscription"
- name: canceled_count
type: number
description: "Number of canceled subscriptions"
- name: churn_reasons
type: string
description: "Comma-separated distinct churn reasons from subscriptions"
- name: open_tickets
type: number
description: "Count of currently open support tickets"
- name: high_severity_tickets
type: number
description: "Count of high-severity support tickets"
- name: sla_breaches
type: number
description: "Count of support tickets with SLA breaches"
- name: recent_active_users
type: number
description: "Average active users in the last 90 days"
- name: prior_active_users
type: number
description: "Average active users 90-180 days ago (for trend comparison)"
- name: recent_events
type: number
description: "Average event count in the last 90 days"
- name: current_arr
type: number
description: "Highest ARR from active contracts"
- name: days_to_renewal
type: number
description: "Days until the nearest active contract expires"
- name: active_contracts
type: number
description: "Count of active contracts"
- name: recent_activities
type: number
description: "CSM activities (calls, meetings, emails, tasks) in the last 90 days"
- name: days_since_last_activity
type: number
description: "Days since the most recent CSM activity"
- name: partial_invoices
type: number
description: "Count of invoices with Partial payment status"
- name: overdue_invoices
type: number
description: "Count of overdue unpaid invoices"
- name: subscription_risk
type: number
description: "Subscription cancellation risk sub-score (0.0-1.0)"
- name: support_risk
type: number
description: "Support burden risk sub-score (0.0-1.0)"
- name: usage_risk
type: number
description: "Product usage decline risk sub-score (0.0-1.0)"
- name: contract_risk
type: number
description: "Contract renewal proximity risk sub-score (0.0-1.0)"
- name: engagement_risk
type: number
description: "CSM engagement gap risk sub-score (0.0-1.0)"
- name: payment_risk
type: number
description: "Payment issues risk sub-score (0.0-1.0)"
- name: risk_score
type: number
description: "Weighted composite churn risk score (0.0-1.0); higher = riskier"
- name: risk_tier
type: string
description: "Churn risk tier: High (>=0.7), Medium (>=0.4), Low (<0.4)"
joins:
- to: accounts
"on": account_id = accounts.account_id
relationship: one_to_one
measures:
- name: avg_risk_score
expr: avg(risk_score)
description: "Average churn risk score across accounts"
- name: high_risk_accounts
expr: count(account_id)
filter: "risk_tier = 'High'"
description: "Number of accounts in the High risk tier"
- name: medium_risk_accounts
expr: count(account_id)
filter: "risk_tier = 'Medium'"
description: "Number of accounts in the Medium risk tier"
- name: low_risk_accounts
expr: count(account_id)
filter: "risk_tier = 'Low'"
description: "Number of accounts in the Low risk tier"
- name: total_arr_at_risk
expr: sum(current_arr)
filter: "risk_tier = 'High'"
description: "Total ARR from accounts in the High risk tier"
- name: avg_support_risk
expr: avg(support_risk)
description: "Average support burden risk sub-score"
- name: avg_usage_risk
expr: avg(usage_risk)
description: "Average usage decline risk sub-score"
- name: accounts_expiring_90d
expr: count(account_id)
filter: "days_to_renewal <= 90"
description: "Accounts with contracts expiring within 90 days"

View file

@ -0,0 +1,23 @@
name: contacts
table: contacts
grain:
- contact_id
columns:
- name: contact_id
type: number
- name: account_id
type: number
- name: email
type: string
- name: first_name
type: string
- name: last_name
type: string
- name: phone
type: string
- name: title
type: string
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,16 @@
name: content_assets
table: content_assets
grain:
- asset_id
columns:
- name: asset_id
type: number
- name: content_type
type: string
- name: publish_date
type: time
role: time
- name: title
type: string
- name: url
type: string

View file

@ -0,0 +1,33 @@
name: content_touches
table: content_touches
grain:
- touch_id
columns:
- name: touch_id
type: number
- name: account_id
type: number
- name: action
type: string
- name: asset_id
type: number
- name: lead_id
type: number
- name: opportunity_id
type: number
- name: touched_at
type: time
role: time
joins:
- to: leads
'on': lead_id = leads.lead_id
relationship: many_to_one
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one
- to: content_assets
'on': asset_id = content_assets.asset_id
relationship: many_to_one
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,30 @@
name: contracts
table: contracts
grain:
- contract_id
columns:
- name: contract_id
type: number
- name: account_id
type: number
- name: arr
type: number
- name: contract_number
type: string
- name: end_date
type: time
role: time
- name: opportunity_id
type: number
- name: start_date
type: time
role: time
- name: status
type: string
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one

View file

@ -0,0 +1,23 @@
name: crm_notes
table: crm_notes
grain:
- note_id
columns:
- name: note_id
type: number
- name: created_at
type: time
role: time
- name: note_text
type: string
- name: opportunity_id
type: number
- name: rep_id
type: number
joins:
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one
- to: sales_reps
'on': rep_id = sales_reps.rep_id
relationship: many_to_one

View file

@ -0,0 +1,9 @@
name: currencies
table: currencies
grain:
- currency_code
columns:
- name: currency_code
type: string
- name: currency_name
type: string

View file

@ -0,0 +1,9 @@
name: departments_hr
table: departments_hr
grain:
- dept_id
columns:
- name: dept_id
type: number
- name: dept_name
type: string

View file

@ -0,0 +1,23 @@
name: disputes
table: disputes
grain:
- dispute_id
columns:
- name: dispute_id
type: number
- name: charge_id
type: number
- name: created_at
type: time
role: time
- name: reason
type: string
- name: resolved_at
type: time
role: time
- name: status
type: string
joins:
- to: charges
'on': charge_id = charges.charge_id
relationship: many_to_one

View file

@ -0,0 +1,18 @@
name: email_events
table: email_events
grain:
- event_id
columns:
- name: event_id
type: number
- name: event_at
type: time
role: time
- name: event_type
type: string
- name: send_id
type: number
joins:
- to: email_sends
'on': send_id = email_sends.send_id
relationship: many_to_one

View file

@ -0,0 +1,33 @@
name: email_sends
table: email_sends
grain:
- send_id
columns:
- name: send_id
type: number
- name: campaign_id
type: number
- name: email_id
type: number
- name: lead_id
type: number
- name: rep_id
type: number
- name: sent_at
type: time
role: time
- name: sequence_id
type: number
joins:
- to: campaigns
'on': campaign_id = campaigns.campaign_id
relationship: many_to_one
- to: leads
'on': lead_id = leads.lead_id
relationship: many_to_one
- to: sequences
'on': sequence_id = sequences.sequence_id
relationship: many_to_one
- to: sales_reps
'on': rep_id = sales_reps.rep_id
relationship: many_to_one

View file

@ -0,0 +1,33 @@
name: employees
table: employees
grain:
- employee_id
columns:
- name: employee_id
type: number
- name: base_salary
type: number
- name: benefits_cost
type: number
- name: dept_id
type: number
- name: email
type: string
- name: first_name
type: string
- name: hire_date
type: time
role: time
- name: last_name
type: string
- name: region
type: string
- name: role
type: string
- name: termination_date
type: time
role: time
joins:
- to: departments_hr
'on': dept_id = departments_hr.dept_id
relationship: many_to_one

View file

@ -0,0 +1,21 @@
name: etl_runs
table: etl_runs
grain:
- run_id
columns:
- name: run_id
type: number
- name: destination
type: string
- name: ended_at
type: time
role: time
- name: rows_processed
type: number
- name: source
type: string
- name: started_at
type: time
role: time
- name: status
type: string

View file

@ -0,0 +1,17 @@
name: fiscal_calendar
table: fiscal_calendar
grain:
- calendar_date
columns:
- name: calendar_date
type: time
- name: fiscal_month
type: string
- name: fiscal_quarter
type: string
- name: fiscal_year
type: string
- name: is_month_start
type: string
- name: is_quarter_start
type: string

View file

@ -0,0 +1,23 @@
name: forecast_snapshots
table: forecast_snapshots
grain:
- snapshot_id
columns:
- name: snapshot_id
type: number
- name: category
type: string
- name: rep_id
type: number
- name: snapshot_date
type: time
role: time
- name: team_id
type: number
joins:
- to: sales_teams
'on': team_id = sales_teams.team_id
relationship: many_to_one
- to: sales_reps
'on': rep_id = sales_reps.rep_id
relationship: many_to_one

View file

@ -0,0 +1,14 @@
name: fx_rates
table: fx_rates
grain:
- from_currency
columns:
- name: from_currency
type: string
- name: rate
type: string
- name: rate_date
type: time
role: time
- name: to_currency
type: string

View file

@ -0,0 +1,23 @@
name: ga4_event_params
table: ga4_event_params
grain:
- param_id
columns:
- name: param_id
type: number
- name: ga4_event_id
type: number
- name: key
type: string
- name: value
type: string
joins:
- to: ga4_events
'on': ga4_event_id = ga4_events.ga4_event_id
relationship: many_to_one
- to: email_events
'on': ga4_event_id = email_events.event_id
relationship: many_to_one
- to: web_events
'on': ga4_event_id = web_events.event_id
relationship: many_to_one

View file

@ -0,0 +1,25 @@
name: ga4_events
table: ga4_events
grain:
- ga4_event_id
columns:
- name: ga4_event_id
type: number
- name: account_id
type: number
- name: event_name
type: string
- name: event_time
type: time
role: time
- name: session_id
type: number
- name: user_id
type: number
joins:
- to: web_sessions
'on': session_id = web_sessions.session_id
relationship: many_to_one
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,13 @@
name: gl_accounts
table: gl_accounts
grain:
- gl_account_id
columns:
- name: gl_account_id
type: number
- name: account_code
type: string
- name: name
type: string
- name: type
type: string

View file

@ -0,0 +1,22 @@
name: identities
table: identities
grain:
- identity_id
columns:
- name: identity_id
type: number
- name: account_id
type: number
- name: created_at
type: time
role: time
- name: device_id
type: number
- name: email
type: string
- name: user_id
type: number
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,25 @@
name: identity_links
table: identity_links
grain:
- link_id
columns:
- name: link_id
type: number
- name: child_identity_id
type: number
- name: linked_at
type: time
role: time
- name: link_source
type: string
- name: parent_identity_id
type: number
joins:
- to: identities
'on': child_identity_id = identities.identity_id
relationship: many_to_one
alias: identities_1
- to: identities
'on': parent_identity_id = identities.identity_id
relationship: many_to_one
alias: identities_2

View file

@ -0,0 +1,24 @@
name: invoice_lines
table: invoice_lines
grain:
- invoice_line_id
columns:
- name: invoice_line_id
type: number
- name: amount
type: number
- name: invoice_id
type: number
- name: product_id
type: number
- name: quantity
type: string
- name: unit_price
type: number
joins:
- to: products
'on': product_id = products.product_id
relationship: many_to_one
- to: invoices
'on': invoice_id = invoices.invoice_id
relationship: many_to_one

View file

@ -0,0 +1,28 @@
name: invoices
table: invoices
grain:
- invoice_id
columns:
- name: invoice_id
type: number
- name: account_id
type: number
- name: contract_id
type: number
- name: currency
type: string
- name: due_date
type: time
role: time
- name: invoice_date
type: time
role: time
- name: status
type: string
joins:
- to: contracts
'on': contract_id = contracts.contract_id
relationship: many_to_one
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one

View file

@ -0,0 +1,12 @@
name: journal_entries
table: journal_entries
grain:
- journal_entry_id
columns:
- name: journal_entry_id
type: number
- name: entry_date
type: time
role: time
- name: memo
type: string

View file

@ -0,0 +1,25 @@
name: journal_lines
table: journal_lines
grain:
- journal_line_id
columns:
- name: journal_line_id
type: number
- name: amount
type: number
- name: dr_cr
type: string
- name: gl_account_id
type: number
- name: journal_entry_id
type: number
joins:
- to: gl_accounts
'on': gl_account_id = gl_accounts.gl_account_id
relationship: many_to_one
- to: accounts
'on': gl_account_id = accounts.account_id
relationship: many_to_one
- to: journal_entries
'on': journal_entry_id = journal_entries.journal_entry_id
relationship: many_to_one

View file

@ -0,0 +1,20 @@
name: keyword_rankings
table: keyword_rankings
grain:
- row_id
columns:
- name: domain
type: string
- name: is_competitor
type: string
- name: keyword
type: string
- name: rank
type: string
- name: row_id
type: number
- name: search_volume
type: string
- name: stat_date
type: time
role: time

View file

@ -0,0 +1,18 @@
name: lead_status_history
table: lead_status_history
grain:
- row_id
columns:
- name: changed_at
type: time
role: time
- name: lead_id
type: number
- name: row_id
type: number
- name: status
type: string
joins:
- to: leads
'on': lead_id = leads.lead_id
relationship: many_to_one

View file

@ -0,0 +1,43 @@
name: leads
table: leads
grain:
- lead_id
columns:
- name: lead_id
type: number
- name: account_id
type: number
- name: converted_at
type: time
role: time
- name: converted_opportunity_id
type: number
- name: created_at
type: time
role: time
- name: first_touch_at
type: time
role: time
- name: last_touch_at
type: time
role: time
- name: owner_rep_id
type: number
- name: source
type: string
- name: utm_campaign
type: string
- name: utm_medium
type: string
- name: utm_source
type: string
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one
- to: sales_reps
'on': owner_rep_id = sales_reps.rep_id
relationship: many_to_one
- to: opportunities
'on': converted_opportunity_id = opportunities.opportunity_id
relationship: many_to_one

View file

@ -0,0 +1,22 @@
name: meeting_bookings
table: meeting_bookings
grain:
- meeting_date
columns:
- name: meeting_date
type: time
- name: meeting_id
type: number
- name: opportunity_id
type: number
- name: rep_id
type: number
- name: source
type: string
joins:
- to: sales_reps
'on': rep_id = sales_reps.rep_id
relationship: many_to_one
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one

View file

@ -0,0 +1,22 @@
name: open_roles
table: open_roles
grain:
- budgeted_salary
columns:
- name: budgeted_salary
type: number
- name: dept_id
type: number
- name: opened_date
type: time
role: time
- name: req_id
type: number
- name: status
type: string
- name: title
type: string
joins:
- to: departments_hr
'on': dept_id = departments_hr.dept_id
relationship: many_to_one

View file

@ -0,0 +1,40 @@
name: opportunities
table: opportunities
grain:
- opportunity_id
columns:
- name: opportunity_id
type: number
- name: account_id
type: number
- name: close_date
type: time
role: time
- name: created_date
type: time
role: time
- name: currency
type: string
- name: lead_source
type: string
- name: owner_rep_id
type: number
- name: parent_opportunity_id
type: number
- name: primary_competitor
type: string
- name: region
type: string
- name: risk_reason
type: string
- name: stage
type: string
- name: type
type: string
joins:
- to: accounts
'on': account_id = accounts.account_id
relationship: many_to_one
- to: sales_reps
'on': owner_rep_id = sales_reps.rep_id
relationship: many_to_one

View file

@ -0,0 +1,20 @@
name: opportunity_contact_roles
table: opportunity_contact_roles
grain:
- contact_id
columns:
- name: contact_id
type: number
- name: ocr_id
type: number
- name: opportunity_id
type: number
- name: role
type: string
joins:
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one
- to: contacts
'on': contact_id = contacts.contact_id
relationship: many_to_one

View file

@ -0,0 +1,24 @@
name: opportunity_line_items
table: opportunity_line_items
grain:
- discount_pct
columns:
- name: discount_pct
type: string
- name: line_item_id
type: number
- name: opportunity_id
type: number
- name: product_id
type: number
- name: quantity
type: string
- name: unit_price
type: number
joins:
- to: products
'on': product_id = products.product_id
relationship: many_to_one
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one

View file

@ -0,0 +1,21 @@
name: opportunity_stage_history
table: opportunity_stage_history
grain:
- history_id
columns:
- name: history_id
type: number
- name: entered_at
type: time
role: time
- name: exited_at
type: time
role: time
- name: opportunity_id
type: number
- name: stage
type: string
joins:
- to: opportunities
'on': opportunity_id = opportunities.opportunity_id
relationship: many_to_one

Some files were not shown because too many files have changed in this diff Show more