# Mirror of https://github.com/Kaelio/ktx.git
# Synced 2026-06-07 07:55:13 +02:00
# 483 lines, 14 KiB, Python
from __future__ import annotations
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
from ktx_daemon.app import create_app
|
|
from ktx_daemon.database_introspection import (
|
|
DatabaseIntrospectionResponse,
|
|
LiveDatabaseColumn,
|
|
LiveDatabaseTable,
|
|
)
|
|
|
|
|
|
# Semantic-layer source definition for the orders table, shared by the
# /semantic-layer/query and /semantic-layer/validate tests in this module.
ORDERS_SOURCE = {
    "name": "orders",
    "table": "public.orders",
    "grain": ["id"],
    "columns": [
        {"name": "id", "type": "number"},
        {"name": "status", "type": "string"},
        {"name": "amount", "type": "number"},
    ],
    "joins": [],
    "measures": [{"name": "order_count", "expr": "count(*)"}],
}
|
|
|
|
# LookML fixture posted to /lookml/parse: one view over public.orders with two
# dimensions (id as primary key, status) and a single count measure.
LOOKML_ORDER_VIEW = """
view: orders {
  sql_table_name: public.orders ;;

  dimension: id {
    primary_key: yes
    type: number
    sql: ${TABLE}.id ;;
  }

  dimension: status {
    type: string
    sql: ${TABLE}.status ;;
  }

  measure: order_count {
    type: count
  }
}
"""
|
|
|
|
|
|
class FakeEmbeddingProvider:
    """Deterministic in-memory embedding provider used as a test double.

    Each text is mapped to ``[len(text), position_in_batch, 1.0]`` and every
    batch passed to :meth:`encode` is recorded in ``calls``.
    """

    name = "fake"
    dimensions = 3
    max_batch_size = 2

    def __init__(self) -> None:
        # One entry per encode() invocation, holding a copy of that batch.
        self.calls: list[list[str]] = []

    def encode(self, texts: list[str]) -> list[list[float]]:
        """Record the batch and return one 3-float vector per input text."""
        # Copy so later mutation of the caller's list cannot alter the record.
        self.calls.append(list(texts))
        return [
            [float(len(text)), float(index), 1.0] for index, text in enumerate(texts)
        ]
|
|
|
|
|
|
def test_health_endpoint_returns_healthy(monkeypatch) -> None:
    """GET /health reports only the status when no version is configured."""
    # test_health_endpoint_returns_managed_runtime_version shows that
    # KTX_DAEMON_VERSION adds a "version" key to this payload, so clear it to
    # keep the exact-equality assertion independent of the ambient environment.
    monkeypatch.delenv("KTX_DAEMON_VERSION", raising=False)
    client = TestClient(create_app())

    response = client.get("/health")

    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}
|
|
|
|
|
|
def test_health_endpoint_returns_managed_runtime_version(monkeypatch) -> None:
    """GET /health includes the version set through KTX_DAEMON_VERSION."""
    monkeypatch.setenv("KTX_DAEMON_VERSION", "0.2.0")
    app = create_app()

    health = TestClient(app).get("/health")

    expected_payload = {"status": "healthy", "version": "0.2.0"}
    assert health.status_code == 200
    assert health.json() == expected_payload
|
|
|
|
|
|
def test_database_introspect_endpoint_returns_snapshot() -> None:
    """POST /database/introspect returns the snapshot built by the injected introspector."""
    seen_requests = []

    def fake_introspector(request):
        # Keep the request object so the test can inspect what the endpoint passed in.
        seen_requests.append(request)
        id_column = LiveDatabaseColumn(
            name="id",
            type="integer",
            nullable=False,
            primary_key=True,
        )
        orders_table = LiveDatabaseTable(
            catalog="warehouse",
            db="public",
            name="orders",
            columns=[id_column],
        )
        return DatabaseIntrospectionResponse(
            connection_id=request.connection_id,
            extracted_at="2026-04-28T10:00:00+00:00",
            metadata={"driver": request.driver, "schemas": request.schemas},
            tables=[orders_table],
        )

    payload = {
        "connection_id": "warehouse",
        "driver": "postgres",
        "url": "postgresql://readonly@example.test/warehouse",
        "schemas": ["public"],
    }
    app = create_app(database_introspector=fake_introspector)

    response = TestClient(app).post("/database/introspect", json=payload)

    assert response.status_code == 200
    body = response.json()
    assert body["connection_id"] == "warehouse"
    assert body["tables"][0]["name"] == "orders"
    assert seen_requests[0].connection_id == "warehouse"
|
|
|
|
|
|
def test_database_introspect_endpoint_maps_value_error_to_400() -> None:
    """A ValueError raised by the introspector comes back as HTTP 400 with its message."""
    error_message = 'database introspection supports only driver "postgres"'

    def fake_introspector(request):
        raise ValueError(error_message)

    payload = {
        "connection_id": "warehouse",
        "driver": "snowflake",
        "url": "snowflake://example",
    }
    app = create_app(database_introspector=fake_introspector)

    response = TestClient(app).post("/database/introspect", json=payload)

    assert response.status_code == 400
    assert response.json() == {"detail": error_message}
|
|
|
|
|
|
def test_embedding_compute_endpoint_returns_embedding() -> None:
    """POST /embeddings/compute encodes one text through the injected provider."""
    fake_provider = FakeEmbeddingProvider()
    app = create_app(embedding_provider=fake_provider)

    reply = TestClient(app).post("/embeddings/compute", json={"text": "hello"})

    # "hello" has length 5 and sits at position 0 of a single-item batch.
    expected_vector = [5.0, 0.0, 1.0]
    assert reply.status_code == 200
    assert reply.json() == {"embedding": expected_vector}
    assert fake_provider.calls == [["hello"]]
|
|
|
|
|
|
def test_embedding_compute_bulk_endpoint_returns_embeddings() -> None:
    """POST /embeddings/compute-bulk encodes a full batch in one provider call."""
    fake_provider = FakeEmbeddingProvider()
    app = create_app(embedding_provider=fake_provider)
    texts = ["one", "three"]

    reply = TestClient(app).post("/embeddings/compute-bulk", json={"texts": texts})

    expected_vectors = [[3.0, 0.0, 1.0], [5.0, 1.0, 1.0]]
    assert reply.status_code == 200
    assert reply.json() == {"embeddings": expected_vectors}
    assert fake_provider.calls == [["one", "three"]]
|
|
|
|
|
|
def test_embedding_compute_bulk_endpoint_maps_value_error_to_400() -> None:
    """A batch larger than the provider's max_batch_size is rejected with HTTP 400."""
    fake_provider = FakeEmbeddingProvider()
    app = create_app(embedding_provider=fake_provider)
    # Three texts against FakeEmbeddingProvider.max_batch_size == 2.
    oversized_batch = {"texts": ["one", "two", "three"]}

    reply = TestClient(app).post("/embeddings/compute-bulk", json=oversized_batch)

    assert reply.status_code == 400
    assert reply.json() == {"detail": "Maximum 2 texts allowed per batch"}
    # The provider must not have been asked to encode anything.
    assert fake_provider.calls == []
|
|
|
|
|
|
def test_code_execute_endpoint_is_not_registered_by_default() -> None:
    """POST /code/execute answers 404 when the app is built with default options."""
    app = create_app()
    payload = {"code": "result = 7"}

    reply = TestClient(app).post("/code/execute", json=payload)

    assert reply.status_code == 404
|
|
|
|
|
|
def test_code_execute_endpoint_returns_result_when_enabled() -> None:
    """With code execution enabled, the endpoint returns the result and captured output."""
    app = create_app(enable_code_execution=True)
    snippet = 'print("ran")\nresult = {"value": 7}'

    reply = TestClient(app).post("/code/execute", json={"code": snippet})

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["result"] == {"value": 7}
    assert payload["console_output"] == "ran\n"
    # Fields that stay unset for a successful run without visualizations.
    assert payload["error"] is None
    assert payload["message"] is None
    assert payload["visualizations"] is None
    formatted = payload["formatted_result"]
    assert "=== Console Output ===" in formatted
    assert "=== Result ===" in formatted
|
|
|
|
|
|
def test_code_execute_endpoint_serializes_numpy_result_when_enabled() -> None:
    """A numpy float64 inside the result comes back as a plain JSON number."""
    app = create_app(enable_code_execution=True)
    snippet = "import numpy as np\nresult = {'value': np.float64(1.25)}"

    reply = TestClient(app).post("/code/execute", json={"code": snippet})

    assert reply.status_code == 200
    payload = reply.json()
    assert payload["result"] == {"value": 1.25}
    assert payload["error"] is None
|
|
|
|
|
|
def test_code_execute_endpoint_uses_host_free_boundary_when_enabled() -> None:
    """The scratchpad helper reports missing host settings even though the
    request carries source_id and an Authorization header."""
    app = create_app(enable_code_execution=True)
    scratchpad_code = (
        "import pandas as pd\n"
        "result = save_df_to_scratchpad(pd.DataFrame({'value': [1]}), 'out.json')"
    )
    payload = {
        "source_id": "chat_123",
        "message_id": "message_456",
        "code": scratchpad_code,
    }
    auth_headers = {"Authorization": "Bearer should-not-forward"}

    reply = TestClient(app).post("/code/execute", json=payload, headers=auth_headers)

    expected_error = (
        "nest_api_url, Authorization header, and source_id are required for scratchpad operations"
    )
    assert reply.status_code == 200
    outcome = reply.json()
    assert outcome["result"] is None
    assert outcome["error"] == expected_error
    assert "=== Error ===" in outcome["formatted_result"]
|
|
|
|
|
|
def test_sql_parse_table_identifier_endpoint() -> None:
    """The endpoint parses a plain identifier and flags an unresolved Looker template."""
    plain_item = {
        "key": "orders",
        "sql_table_name": "public.orders",
        "dialect": "postgres",
    }
    template_item = {
        "key": "template",
        "sql_table_name": "${orders.SQL_TABLE_NAME}",
        "dialect": "postgres",
    }
    app = create_app()

    reply = TestClient(app).post(
        "/sql/parse-table-identifier",
        json={"items": [plain_item, template_item]},
    )

    assert reply.status_code == 200
    results = reply.json()["results"]

    parsed = results["orders"]
    assert parsed["ok"] is True
    assert parsed["schema"] == "public"
    assert parsed["name"] == "orders"

    unresolved = results["template"]
    assert unresolved["ok"] is False
    assert unresolved["reason"] == "looker_template_unresolved"
|
|
|
|
|
|
def test_sql_analyze_batch_endpoint_returns_per_item_results() -> None:
    """Each batch item gets its own analysis; a broken statement yields an error entry."""
    valid_item = {
        "id": "orders",
        "sql": "select status from public.orders where created_at is not null",
    }
    broken_item = {"id": "broken", "sql": "select * from where"}
    payload = {
        "dialect": "postgres",
        "max_workers": 1,
        "items": [valid_item, broken_item],
    }
    app = create_app()

    reply = TestClient(app).post("/sql/analyze-batch", json=payload)

    assert reply.status_code == 200
    results = reply.json()["results"]

    analyzed = results["orders"]
    assert analyzed["tables_touched"] == ["public.orders"]
    assert analyzed["columns_by_clause"] == {
        "select": ["status"],
        "where": ["created_at"],
    }
    assert analyzed["error"] is None

    failed = results["broken"]
    assert failed["tables_touched"] == []
    assert failed["columns_by_clause"] == {}
    assert failed["error"] is not None
|
|
|
|
|
|
def test_semantic_query_endpoint_returns_sql() -> None:
    """POST /semantic-layer/query returns SQL and column metadata for a valid query."""
    semantic_query = {
        "measures": ["orders.order_count"],
        "dimensions": ["orders.status"],
    }
    payload = {
        "sources": [ORDERS_SOURCE],
        "dialect": "postgres",
        "query": semantic_query,
    }
    app = create_app()

    reply = TestClient(app).post("/semantic-layer/query", json=payload)

    assert reply.status_code == 200
    compiled = reply.json()
    assert compiled["dialect"] == "postgres"
    assert "public.orders" in compiled["sql"]
    first_column = compiled["columns"][0]
    assert first_column["name"] == "orders.status"
|
|
|
|
|
|
def test_semantic_query_endpoint_maps_value_error_to_400() -> None:
    """Querying a measure on an unknown source yields HTTP 400 naming that measure."""
    unknown_measure = "missing.order_count"
    payload = {
        "sources": [ORDERS_SOURCE],
        "dialect": "postgres",
        "query": {
            "measures": [unknown_measure],
            "dimensions": [],
        },
    }
    app = create_app()

    reply = TestClient(app).post("/semantic-layer/query", json=payload)

    assert reply.status_code == 400
    detail = reply.json()["detail"]
    assert "missing.order_count" in detail
|
|
|
|
|
|
def test_semantic_validate_endpoint_returns_structured_validation() -> None:
    """A source with a duplicated measure name is reported as invalid, with HTTP 200."""
    client = TestClient(create_app())
    # Same measure twice, which the validator is expected to flag.
    duplicated_measures = [
        {"name": "revenue", "expr": "sum(amount)"} for _ in range(2)
    ]
    invalid_source = dict(ORDERS_SOURCE, measures=duplicated_measures)
    payload = {"sources": [invalid_source], "dialect": "postgres"}

    reply = client.post("/semantic-layer/validate", json=payload)

    assert reply.status_code == 200
    report = reply.json()
    assert report["valid"] is False
    duplicate_errors = [
        message for message in report["errors"] if "Duplicate measure" in message
    ]
    assert duplicate_errors
    assert report["warnings"] == []
    assert report["per_source_warnings"] == {}
|
|
|
|
|
|
def test_semantic_generate_sources_endpoint_returns_sources() -> None:
    """Two linked tables produce two sources, with grain, join and default measures on orders."""
    orders_table = {
        "name": "orders",
        "db": "public",
        "comment": "Orders table",
        "columns": [
            {
                "name": "id",
                "type": "integer",
                "primary_key": True,
                "nullable": False,
                "comment": "Order ID",
            },
            {"name": "customer_id", "type": "integer"},
            {
                "name": "amount",
                "type": "decimal",
                "comment": "Order amount",
            },
        ],
    }
    customers_table = {
        "name": "customers",
        "db": "public",
        "columns": [
            {"name": "id", "type": "integer", "primary_key": True},
            {"name": "email", "type": "varchar"},
        ],
    }
    orders_to_customers = {
        "from_table": "orders",
        "from_column": "customer_id",
        "to_table": "customers",
        "to_column": "id",
        "relationship_type": "MANY_TO_ONE",
    }
    payload = {
        "tables": [orders_table, customers_table],
        "links": [orders_to_customers],
        "dialect": "postgres",
    }
    app = create_app()

    reply = TestClient(app).post("/semantic-layer/generate-sources", json=payload)

    assert reply.status_code == 200
    generated = reply.json()
    assert generated["source_count"] == 2

    by_name = {entry["name"]: entry for entry in generated["sources"]}
    orders = by_name["orders"]
    assert orders["table"] == "public.orders"
    assert orders["description"] == "Orders table"
    assert orders["grain"] == ["id"]
    expected_join = {
        "to": "customers",
        "on": "customer_id = customers.id",
        "relationship": "many_to_one",
    }
    assert orders["joins"] == [expected_join]
    measure_names = [measure["name"] for measure in orders["measures"]]
    assert measure_names == [
        "record_count",
        "total_amount",
        "avg_amount",
    ]
|
|
|
|
|
|
def test_lookml_parse_endpoint_returns_resolved_views() -> None:
    """POST /lookml/parse resolves the orders view fixture into a single table-backed view."""
    lookml_file = {
        "path": "views/orders.view.lkml",
        "content": LOOKML_ORDER_VIEW,
    }
    payload = {"files": [lookml_file], "dialect": "postgres"}
    app = create_app()

    reply = TestClient(app).post("/lookml/parse", json=payload)

    assert reply.status_code == 200
    parsed = reply.json()
    assert parsed["joins"] == []
    assert parsed["skipped_views"] == []
    assert parsed["warnings"] == []
    assert len(parsed["views"]) == 1

    orders_view = parsed["views"][0]
    assert orders_view["name"] == "orders"
    assert orders_view["source_type"] == "table"
    assert orders_view["table_ref"] == "public.orders"
    assert orders_view["grain"] == ["id"]
    column_names = [column["name"] for column in orders_view["columns"]]
    assert column_names == ["id", "status"]
    expected_measure = {
        "name": "order_count",
        "expr": "count(*)",
        "filter": None,
        "description": None,
    }
    assert orders_view["measures"] == [expected_measure]
|