From 22fff0accfe6a8f29e638796a6c11288de238749 Mon Sep 17 00:00:00 2001 From: Jagadeeshck Date: Thu, 21 May 2026 15:47:39 +0100 Subject: [PATCH] Add API request IDs, structured errors, pagination and filtering --- docs/api/rest-api.md | 35 +++++ src/api/app.py | 305 +++++++++++++++--------------------- src/api/es_store.py | 68 ++++++-- src/api/store.py | 47 ++++-- tests/test_api_endpoints.py | 256 +++++++----------------------- tests/test_api_es_store.py | 2 +- 6 files changed, 311 insertions(+), 402 deletions(-) create mode 100644 docs/api/rest-api.md diff --git a/docs/api/rest-api.md b/docs/api/rest-api.md new file mode 100644 index 0000000..6c2e8de --- /dev/null +++ b/docs/api/rest-api.md @@ -0,0 +1,35 @@ +# DataObs REST API + +## Auth +Use `Authorization: Bearer ` for protected endpoints. + +## Request IDs +Send optional `X-Request-ID`; API echoes it, or generates UUID when missing. + +## Error format +```json +{"error":{"code":"validation_error","message":"Validation error","details":{}},"request_id":"..."} +``` +Covers 400/401/404/405/422/500. + +## Pagination +Supported on `/rules`, `/quality/results`, `/lineage/nodes`, `/lineage/edges` with: +- `limit` (default 100, max 1000) +- `offset` (default 0) + +Response includes: +- existing key (`rules/results/nodes/edges`) +- `count` +- `pagination` with `limit`, `offset`, `returned`, `total`, `has_more`. + +## Filters +- `/quality/results`: `table,status,dataset,check_type,severity,run_id` +- `/rules`: `dataset,enabled,severity,check_type` +- `/lineage/nodes`: `node_type` (or `type`), `dataset` +- `/lineage/edges`: `source,target,relation` (or `relation_type`) + +## curl examples +```bash +curl -H "Authorization: Bearer $TOKEN" -H "X-Request-ID: req-123" "http://localhost:8000/rules?limit=50&offset=0&dataset=orders" +curl -H "Authorization: Bearer $TOKEN" "http://localhost:8000/quality/results?table=orders&status=fail&limit=20" +``` diff --git a/src/api/app.py b/src/api/app.py index afffc82..8a34eec 100644 --- a/src/api/app.py +++ b/src/api/app.py @@ -1,17 +1,12 @@ -"""FastAPI application factory for the DataObs API. - -The public route shapes intentionally mirror the original -``BaseHTTPRequestHandler`` implementation while moving request handling to -FastAPI, Pydantic models, and explicit dependency injection. -""" from __future__ import annotations import logging +import uuid from dataclasses import dataclass from typing import Any, Dict, List, Optional from elasticsearch import Elasticsearch -from fastapi import Depends, FastAPI, Header, HTTPException, Request, status +from fastapi import Depends, FastAPI, Header, HTTPException, Query, Request, status from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -27,12 +22,50 @@ class DataObsModel(BaseModel): - """Base model that allows legacy clients to send additional fields.""" - model_config = ConfigDict(extra="allow") -APISettings = AppSettings +class ErrorInfo(BaseModel): + code: str + message: str + details: Dict[str, Any] = {} + + +class ErrorResponse(BaseModel): + error: ErrorInfo + request_id: str + + +class PaginationMeta(BaseModel): + limit: int + offset: int + returned: int + total: int + has_more: bool + + +class RulesResponse(BaseModel): + rules: List[Dict[str, Any]] + count: int + pagination: PaginationMeta + + +class QualityResultsResponse(BaseModel): + results: List[Dict[str, Any]] + count: int + pagination: PaginationMeta + + +class LineageNodesResponse(BaseModel): + nodes: List[Dict[str, Any]] + count: int + pagination: PaginationMeta + + +class LineageEdgesResponse(BaseModel): + edges: List[Dict[str, Any]] + count: int + pagination: PaginationMeta class HealthResponse(BaseModel): @@ -42,11 +75,6 @@ class HealthResponse(BaseModel): auth_mode: str -class ErrorResponse(BaseModel): - error: str - details: Optional[Any] = None - - class RuleRequest(DataObsModel): rule_id: Optional[str] = None dataset: Optional[str] = None @@ -67,27 +95,6 @@ class RuleDeleteResponse(BaseModel): status: str = "deleted" -class RulesResponse(BaseModel): - rules: List[Dict[str, Any]] - count: int - - -class LineageNodesResponse(BaseModel): - nodes: List[Dict[str, Any]] - count: int - - -class LineageEdgesResponse(BaseModel): - edges: List[Dict[str, Any]] - count: int - - -class LineageImpactResponse(BaseModel): - root_node: str - affected: List[str] - count: int - - class QualityResultRequest(DataObsModel): id: Optional[str] = None check_name: Optional[str] = None @@ -103,213 +110,157 @@ class QualityResultCreateResponse(BaseModel): status: str = "created" -class QualityResultsResponse(BaseModel): - results: List[Dict[str, Any]] - count: int - - class EnterpriseBacklogResponse(BaseModel): backlog: List[Dict[str, Any]] +class LineageImpactResponse(BaseModel): + root_node: str + affected: List[str] + count: int + + @dataclass(frozen=True) class StoreBundle: - """Injected store handles used by route dependencies.""" - store: StoreProtocol -def settings_from_env() -> APISettings: - """Load API settings from the unified typed settings layer.""" +def settings_from_env() -> AppSettings: return load_settings() -def make_es_client(settings: APISettings) -> Elasticsearch: - return Elasticsearch( - [settings.elasticsearch.url], - basic_auth=(settings.elasticsearch.user, settings.elasticsearch.password), - request_timeout=30, - ) +def make_es_client(settings: AppSettings) -> Elasticsearch: + return Elasticsearch([settings.elasticsearch.url], basic_auth=(settings.elasticsearch.user, settings.elasticsearch.password), request_timeout=30) -def create_store_bundle(settings: APISettings) -> StoreBundle: - """Create the production store bundle from settings.""" - es_client: Elasticsearch | None = None +def create_store_bundle(settings: AppSettings) -> StoreBundle: if settings.store_backend.lower() == "elasticsearch": - es_client = make_es_client(settings) - store = get_store(es_client=es_client, tenant_id=settings.tenant_id) - return StoreBundle(store=store) - - store = get_store(es_client=None, tenant_id=settings.tenant_id) - return StoreBundle(store=store) + return StoreBundle(store=get_store(es_client=make_es_client(settings), tenant_id=settings.tenant_id)) + return StoreBundle(store=get_store(es_client=None, tenant_id=settings.tenant_id)) def _as_dict(model: DataObsModel) -> Dict[str, Any]: - """Return a request model as a dict excluding omitted/None fields.""" if hasattr(model, "model_dump"): - return model.model_dump(exclude_none=True) # Pydantic v2 - return model.dict(exclude_none=True) # Pydantic v1 + return model.model_dump(exclude_none=True) + return model.dict(exclude_none=True) -def create_app( - *, - settings: APISettings | None = None, - store_bundle: StoreBundle | None = None, -) -> FastAPI: - """Create and configure the FastAPI application.""" - resolved_settings = settings or settings_from_env() - resolved_bundle = store_bundle or create_store_bundle(resolved_settings) +def _error_payload(code: str, message: str, request_id: str, details: Dict[str, Any] | None = None) -> Dict[str, Any]: + return {"error": {"code": code, "message": message, "details": details or {}}, "request_id": request_id} + + +def _request_id(request: Request) -> str: + return getattr(request.state, "request_id", "unknown") - if resolved_settings.api_token is None: - if not resolved_settings.auth.allow_unauthenticated_dev: - logger.warning("API_TOKEN is not set and unauthenticated dev mode is disabled.") - else: - logger.warning( - "API_TOKEN is not set — running in explicit unauthenticated dev mode. " - "Set API_TOKEN in production." - ) - else: - logger.info("Bearer token authentication enabled.") - - app = FastAPI( - title="DataObs API", - version="1.0.0", - description="REST API for DataObs rules, lineage, quality results, and strategy backlog.", - responses={ - 401: {"model": ErrorResponse}, - 404: {"model": ErrorResponse}, - 405: {"model": ErrorResponse}, - }, - ) +def _paginate(items: List[Dict[str, Any]], limit: int, offset: int) -> Dict[str, Any]: + total = len(items) + sliced = items[offset:offset + limit] + return {"items": sliced, "pagination": {"limit": limit, "offset": offset, "returned": len(sliced), "total": total, "has_more": offset + len(sliced) < total}} + + +def create_app(*, settings: AppSettings | None = None, store_bundle: StoreBundle | None = None) -> FastAPI: + resolved_settings = settings or settings_from_env() + resolved_bundle = store_bundle or create_store_bundle(resolved_settings) + app = FastAPI(title="DataObs API", version="1.0.0") app.state.settings = resolved_settings app.state.store_bundle = resolved_bundle - def get_settings(request: Request) -> APISettings: + @app.middleware("http") + async def request_id_middleware(request: Request, call_next): + request_id = request.headers.get("X-Request-ID") or str(uuid.uuid4()) + request.state.request_id = request_id + response = await call_next(request) + response.headers["X-Request-ID"] = request_id + return response + + def get_settings(request: Request) -> AppSettings: return request.app.state.settings def get_stores(request: Request) -> StoreBundle: return request.app.state.store_bundle - async def require_auth( - settings: APISettings = Depends(get_settings), - credentials: HTTPAuthorizationCredentials | None = Depends(_bearer), - authorization: str | None = Header(default=None), - ) -> None: - if settings.api_token is None: - if settings.auth.allow_unauthenticated_dev: - return - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Unauthorized - valid Bearer token required", - headers={"WWW-Authenticate": 'Bearer realm="DataObs API"'}, - ) - + async def require_auth(settings: AppSettings = Depends(get_settings), credentials: HTTPAuthorizationCredentials | None = Depends(_bearer), authorization: str | None = Header(default=None)) -> None: + if settings.api_token is None and settings.auth.allow_unauthenticated_dev: + return token = credentials.credentials if credentials and credentials.scheme.lower() == "bearer" else None - # Preserve the legacy parser's exact "Bearer " prefix behavior for unusual clients. if token is None and authorization and authorization.startswith("Bearer "): token = authorization[len("Bearer "):].strip() if token != settings.api_token: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Unauthorized - valid Bearer token required", - headers={"WWW-Authenticate": 'Bearer realm="DataObs API"'}, - ) + raise HTTPException(status_code=401, detail="Unauthorized - valid Bearer token required", headers={"WWW-Authenticate": 'Bearer realm="DataObs API"'}) @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse: - return JSONResponse( - status_code=exc.status_code, - content={"error": str(exc.detail)}, - headers=exc.headers, - ) + code_map = {400: "bad_request", 401: "unauthorized", 404: "not_found", 405: "method_not_allowed"} + return JSONResponse(status_code=exc.status_code, content=_error_payload(code_map.get(exc.status_code, "http_error"), str(exc.detail), _request_id(request)), headers=exc.headers) @app.exception_handler(StarletteHTTPException) async def starlette_exception_handler(request: Request, exc: StarletteHTTPException) -> JSONResponse: - message = "Not found" if exc.status_code == 404 else str(exc.detail) - if exc.status_code == 405: - message = "Method not allowed" - return JSONResponse(status_code=exc.status_code, content={"error": message}) + msg = "Not found" if exc.status_code == 404 else ("Method not allowed" if exc.status_code == 405 else str(exc.detail)) + return JSONResponse(status_code=exc.status_code, content=_error_payload({404: "not_found", 405: "method_not_allowed"}.get(exc.status_code, "http_error"), msg, _request_id(request))) @app.exception_handler(RequestValidationError) async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse: errors = exc.errors() - if any(error.get("type") == "json_invalid" for error in errors): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content={"error": "Invalid JSON body", "details": errors}, - ) - return JSONResponse( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - content={"error": "Validation error", "details": errors}, - ) + if any(e.get("type") == "json_invalid" for e in errors): + return JSONResponse(status_code=400, content=_error_payload("bad_request", "Invalid JSON body", _request_id(request), {"errors": errors})) + return JSONResponse(status_code=422, content=_error_payload("validation_error", "Validation error", _request_id(request), {"errors": errors})) + + @app.exception_handler(Exception) + async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + logger.exception("Unhandled exception request_id=%s", _request_id(request)) + return JSONResponse(status_code=500, content=_error_payload("internal_server_error", "Internal server error", _request_id(request))) @app.get("/health", response_model=HealthResponse) - async def health(settings: APISettings = Depends(get_settings)) -> Dict[str, Any]: - return { - "status": "ok", - "service": "dataobs-api", - "store_backend": settings.store_backend, - "auth_mode": settings.auth_mode, - } + async def health(settings: AppSettings = Depends(get_settings)) -> Dict[str, Any]: + return {"status": "ok", "service": "dataobs-api", "store_backend": settings.store_backend, "auth_mode": settings.auth_mode} @app.get("/rules", response_model=RulesResponse, dependencies=[Depends(require_auth)]) - async def get_rules(stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: + async def get_rules(stores: StoreBundle = Depends(get_stores), limit: int = Query(default=100, ge=1, le=1000), offset: int = Query(default=0, ge=0), dataset: str | None = None, enabled: bool | None = None, severity: str | None = None, check_type: str | None = None) -> Dict[str, Any]: rules = stores.store.get_all_rules() - return {"rules": rules, "count": len(rules)} + filtered = [r for r in rules if (dataset is None or r.get("dataset") == dataset) and (enabled is None or r.get("enabled") == enabled) and (severity is None or r.get("severity") == severity) and (check_type is None or r.get("check_type", r.get("type")) == check_type)] + page = _paginate(filtered, limit, offset) + return {"rules": page["items"], "count": len(page["items"]), "pagination": page["pagination"]} + + @app.get("/quality/results", response_model=QualityResultsResponse, dependencies=[Depends(require_auth)]) + async def get_quality_results(stores: StoreBundle = Depends(get_stores), limit: int = Query(default=100, ge=1, le=1000), offset: int = Query(default=0, ge=0), table: str | None = None, status: str | None = None, dataset: str | None = None, check_type: str | None = None, severity: str | None = None, run_id: str | None = None) -> Dict[str, Any]: + results = stores.store.list_quality_results(limit=1000, offset=0, table=table, status=status, dataset=dataset, check_type=check_type, severity=severity, run_id=run_id) + page = _paginate(results, limit, offset) + return {"results": page["items"], "count": len(page["items"]), "pagination": page["pagination"]} + + @app.get("/lineage/nodes", response_model=LineageNodesResponse, dependencies=[Depends(require_auth)]) + async def get_lineage_nodes(stores: StoreBundle = Depends(get_stores), limit: int = Query(default=100, ge=1, le=1000), offset: int = Query(default=0, ge=0), node_type: str | None = None, type: str | None = None, dataset: str | None = None) -> Dict[str, Any]: + nodes = stores.store.get_all_nodes(limit=1000, offset=0, node_type=node_type or type, dataset=dataset) + page = _paginate(nodes, limit, offset) + return {"nodes": page["items"], "count": len(page["items"]), "pagination": page["pagination"]} + + @app.get("/lineage/edges", response_model=LineageEdgesResponse, dependencies=[Depends(require_auth)]) + async def get_lineage_edges(stores: StoreBundle = Depends(get_stores), limit: int = Query(default=100, ge=1, le=1000), offset: int = Query(default=0, ge=0), source: str | None = None, target: str | None = None, relation: str | None = None, relation_type: str | None = None) -> Dict[str, Any]: + edges = stores.store.get_all_edges(limit=1000, offset=0, source=source, target=target, relation=relation or relation_type) + page = _paginate(edges, limit, offset) + return {"edges": page["items"], "count": len(page["items"]), "pagination": page["pagination"]} @app.post("/rules", status_code=201, response_model=RuleCreateResponse, dependencies=[Depends(require_auth)]) async def create_rule(rule: RuleRequest, stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: - rule_id = stores.store.add_rule(_as_dict(rule)) - return {"rule_id": rule_id, "status": "created"} + return {"rule_id": stores.store.add_rule(_as_dict(rule)), "status": "created"} @app.delete("/rules/{rule_id}", response_model=RuleDeleteResponse, dependencies=[Depends(require_auth)]) async def delete_rule(rule_id: str, stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: - if not rule_id: - raise HTTPException(status_code=400, detail="rule_id is required") - deleted = stores.store.delete_rule(rule_id) - if not deleted: + if not stores.store.delete_rule(rule_id): raise HTTPException(status_code=404, detail=f"Rule '{rule_id}' not found") return {"rule_id": rule_id, "status": "deleted"} - @app.get("/lineage/nodes", response_model=LineageNodesResponse, dependencies=[Depends(require_auth)]) - async def get_lineage_nodes(stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: - nodes = stores.store.get_all_nodes() - return {"nodes": nodes, "count": len(nodes)} - - @app.get("/lineage/edges", response_model=LineageEdgesResponse, dependencies=[Depends(require_auth)]) - async def get_lineage_edges(stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: - edges = stores.store.get_all_edges() - return {"edges": edges, "count": len(edges)} - @app.get("/lineage/impact/{node_id:path}", response_model=LineageImpactResponse, dependencies=[Depends(require_auth)]) async def get_lineage_impact(node_id: str, stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: - if not node_id: - raise HTTPException(status_code=400, detail="node_id is required") affected = stores.store.get_downstream_impact(node_id) return {"root_node": node_id, "affected": affected, "count": len(affected)} - @app.get("/quality/results", response_model=QualityResultsResponse, dependencies=[Depends(require_auth)]) - async def get_quality_results(stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: - results = stores.store.list_quality_results() - return {"results": results, "count": len(results)} - - @app.post( - "/quality/results", - status_code=201, - response_model=QualityResultCreateResponse, - dependencies=[Depends(require_auth)], - ) - async def create_quality_result( - result: QualityResultRequest, - stores: StoreBundle = Depends(get_stores), - ) -> Dict[str, Any]: - doc_id = stores.store.save_quality_result(_as_dict(result)) - return {"id": doc_id, "status": "created"} + @app.post("/quality/results", status_code=201, response_model=QualityResultCreateResponse, dependencies=[Depends(require_auth)]) + async def create_quality_result(result: QualityResultRequest, stores: StoreBundle = Depends(get_stores)) -> Dict[str, Any]: + return {"id": stores.store.save_quality_result(_as_dict(result)), "status": "created"} @app.get("/strategy/enterprise-backlog", response_model=EnterpriseBacklogResponse, dependencies=[Depends(require_auth)]) async def get_enterprise_backlog() -> Dict[str, Any]: - backlog = enterprise_backlog(implemented_keys=[]) - return {"backlog": backlog} + return {"backlog": enterprise_backlog(implemented_keys=[])} return app diff --git a/src/api/es_store.py b/src/api/es_store.py index 66a65de..d17f919 100644 --- a/src/api/es_store.py +++ b/src/api/es_store.py @@ -161,8 +161,13 @@ def get_quality_result(self, result_id: str) -> Optional[Dict[str, Any]]: def list_quality_results( self, limit: int = 100, + offset: int = 0, table: Optional[str] = None, status: Optional[str] = None, + dataset: Optional[str] = None, + check_type: Optional[str] = None, + severity: Optional[str] = None, + run_id: Optional[str] = None, ) -> List[Dict[str, Any]]: """List quality results, optionally filtered by table and/or status.""" must: List[Dict[str, Any]] = [{"term": {"tenant_id": self._tenant}}] @@ -170,12 +175,21 @@ def list_quality_results( must.append({"term": {"table": table}}) if status: must.append({"term": {"status": status}}) + if dataset: + must.append({"term": {"dataset": dataset}}) + if check_type: + must.append({"term": {"check_type": check_type}}) + if severity: + must.append({"term": {"severity": severity}}) + if run_id: + must.append({"term": {"run_id": run_id}}) try: resp = self._es.search( index=self._qi, query={"bool": {"must": must}}, sort=[{"@timestamp": "desc"}], size=limit, + from_=offset, ) return [h["_source"] for h in resp["hits"]["hits"]] except Exception: @@ -215,14 +229,24 @@ def get_rule(self, rule_id: str) -> Optional[Dict[str, Any]]: logger.exception("Failed to fetch rule '%s'", rule_id) return None - def get_all_rules(self) -> List[Dict[str, Any]]: + def get_all_rules(self, limit: int = 100, offset: int = 0, dataset: Optional[str] = None, enabled: Optional[bool] = None, severity: Optional[str] = None, check_type: Optional[str] = None) -> List[Dict[str, Any]]: """Return all rules for this tenant, ordered by dataset.""" try: + must: List[Dict[str, Any]] = [{"term": {"tenant_id": self._tenant}}] + if dataset is not None: + must.append({"term": {"dataset": dataset}}) + if enabled is not None: + must.append({"term": {"enabled": enabled}}) + if severity is not None: + must.append({"term": {"severity": severity}}) + if check_type is not None: + must.append({"term": {"check_type": check_type}}) resp = self._es.search( index=self._ri, - query={"term": {"tenant_id": self._tenant}}, + query={"bool": {"must": must}}, sort=[{"dataset": {"order": "asc", "unmapped_type": "keyword"}}], - size=1000, + size=limit, + from_=offset, ) return [h["_source"] for h in resp["hits"]["hits"]] except Exception: @@ -282,15 +306,21 @@ def get_lineage_node(self, node_id: str) -> Optional[Dict[str, Any]]: logger.exception("Failed to fetch lineage node '%s'", node_id) return None - def get_all_nodes(self) -> List[Dict[str, Any]]: + def get_all_nodes(self, limit: int = 100, offset: int = 0, node_type: Optional[str] = None, dataset: Optional[str] = None) -> List[Dict[str, Any]]: try: + must: List[Dict[str, Any]] = [ + {"term": {"tenant_id": self._tenant}}, + {"term": {"doc_type": "node"}}, + ] + if node_type: + must.append({"term": {"type": node_type}}) + if dataset: + must.append({"term": {"dataset": dataset}}) resp = self._es.search( index=self._li, - query={"bool": {"must": [ - {"term": {"tenant_id": self._tenant}}, - {"term": {"doc_type": "node"}}, - ]}}, - size=1000, + query={"bool": {"must": must}}, + size=limit, + from_=offset, ) return [h["_source"] for h in resp["hits"]["hits"]] except Exception: @@ -309,15 +339,23 @@ def save_lineage_edge(self, edge: Dict[str, Any]) -> str: self._es.index(index=self._li, id=f"edge::{edge_id}", document=doc, refresh="wait_for") return edge_id - def get_all_edges(self) -> List[Dict[str, Any]]: + def get_all_edges(self, limit: int = 100, offset: int = 0, source: Optional[str] = None, target: Optional[str] = None, relation: Optional[str] = None) -> List[Dict[str, Any]]: try: + must: List[Dict[str, Any]] = [ + {"term": {"tenant_id": self._tenant}}, + {"term": {"doc_type": "edge"}}, + ] + if source: + must.append({"term": {"source_node_id": source}}) + if target: + must.append({"term": {"target_node_id": target}}) + if relation: + must.append({"term": {"relation": relation}}) resp = self._es.search( index=self._li, - query={"bool": {"must": [ - {"term": {"tenant_id": self._tenant}}, - {"term": {"doc_type": "edge"}}, - ]}}, - size=1000, + query={"bool": {"must": must}}, + size=limit, + from_=offset, ) return [h["_source"] for h in resp["hits"]["hits"]] except Exception: diff --git a/src/api/store.py b/src/api/store.py index 279cb04..95c50cf 100644 --- a/src/api/store.py +++ b/src/api/store.py @@ -29,18 +29,18 @@ def _now_iso() -> str: class StoreProtocol(Protocol): def save_quality_result(self, result: Dict[str, Any]) -> str: ... def get_quality_result(self, result_id: str) -> Optional[Dict[str, Any]]: ... - def list_quality_results(self, limit: int = 100, table: Optional[str] = None, status: Optional[str] = None) -> List[Dict[str, Any]]: ... + def list_quality_results(self, limit: int = 100, offset: int = 0, table: Optional[str] = None, status: Optional[str] = None, dataset: Optional[str] = None, check_type: Optional[str] = None, severity: Optional[str] = None, run_id: Optional[str] = None) -> List[Dict[str, Any]]: ... def save_rule(self, rule: Dict[str, Any]) -> str: ... def add_rule(self, rule: Dict[str, Any]) -> str: ... def get_rule(self, rule_id: str) -> Optional[Dict[str, Any]]: ... - def get_all_rules(self) -> List[Dict[str, Any]]: ... + def get_all_rules(self, limit: int = 100, offset: int = 0, dataset: Optional[str] = None, enabled: Optional[bool] = None, severity: Optional[str] = None, check_type: Optional[str] = None) -> List[Dict[str, Any]]: ... def get_rules_for_dataset(self, dataset: str) -> List[Dict[str, Any]]: ... def delete_rule(self, rule_id: str) -> bool: ... def save_lineage_node(self, node: Dict[str, Any]) -> str: ... def get_lineage_node(self, node_id: str) -> Optional[Dict[str, Any]]: ... - def get_all_nodes(self) -> List[Dict[str, Any]]: ... + def get_all_nodes(self, limit: int = 100, offset: int = 0, node_type: Optional[str] = None, dataset: Optional[str] = None) -> List[Dict[str, Any]]: ... def save_lineage_edge(self, edge: Dict[str, Any]) -> str: ... - def get_all_edges(self) -> List[Dict[str, Any]]: ... + def get_all_edges(self, limit: int = 100, offset: int = 0, source: Optional[str] = None, target: Optional[str] = None, relation: Optional[str] = None) -> List[Dict[str, Any]]: ... def get_downstream_impact(self, node_id: str, depth: int = 5) -> List[str]: ... @@ -59,14 +59,16 @@ def save_quality_result(self, result: Dict[str, Any]) -> str: def get_quality_result(self, result_id: str) -> Optional[Dict[str, Any]]: return self._quality.get(result_id) - def list_quality_results(self, limit: int = 100, table: Optional[str] = None, status: Optional[str] = None) -> List[Dict[str, Any]]: + def list_quality_results(self, limit: int = 100, offset: int = 0, table: Optional[str] = None, status: Optional[str] = None, dataset: Optional[str] = None, check_type: Optional[str] = None, severity: Optional[str] = None, run_id: Optional[str] = None) -> List[Dict[str, Any]]: results = list(self._quality.values()) if table: results = [r for r in results if r.get("table") == table] if status: results = [r for r in results if r.get("status") == status] + if dataset: + results = [r for r in results if r.get("dataset") == dataset] results.sort(key=lambda r: r.get("@timestamp", ""), reverse=True) - return results[:limit] + return results[offset:offset+limit] def add_rule(self, rule: Dict[str, Any]) -> str: return self.save_rule(rule) @@ -79,8 +81,17 @@ def save_rule(self, rule: Dict[str, Any]) -> str: def get_rule(self, rule_id: str) -> Optional[Dict[str, Any]]: return self._rules.get(rule_id) - def get_all_rules(self) -> List[Dict[str, Any]]: - return sorted(self._rules.values(), key=lambda r: r.get("dataset", "")) + def get_all_rules(self, limit: int = 100, offset: int = 0, dataset: Optional[str] = None, enabled: Optional[bool] = None, severity: Optional[str] = None, check_type: Optional[str] = None) -> List[Dict[str, Any]]: + rules = sorted(self._rules.values(), key=lambda r: r.get("dataset", "")) + if dataset is not None: + rules = [r for r in rules if r.get("dataset") == dataset] + if enabled is not None: + rules = [r for r in rules if r.get("enabled") == enabled] + if severity is not None: + rules = [r for r in rules if r.get("severity") == severity] + if check_type is not None: + rules = [r for r in rules if r.get("check_type", r.get("type")) == check_type] + return rules[offset:offset+limit] def get_rules_for_dataset(self, dataset: str) -> List[Dict[str, Any]]: return [r for r in self._rules.values() if r.get("dataset") == dataset] @@ -96,16 +107,28 @@ def save_lineage_node(self, node: Dict[str, Any]) -> str: def get_lineage_node(self, node_id: str) -> Optional[Dict[str, Any]]: return self._lineage_nodes.get(node_id) - def get_all_nodes(self) -> List[Dict[str, Any]]: - return list(self._lineage_nodes.values()) + def get_all_nodes(self, limit: int = 100, offset: int = 0, node_type: Optional[str] = None, dataset: Optional[str] = None) -> List[Dict[str, Any]]: + nodes = list(self._lineage_nodes.values()) + if node_type is not None: + nodes = [n for n in nodes if n.get("node_type", n.get("type")) == node_type] + if dataset is not None: + nodes = [n for n in nodes if n.get("dataset") == dataset] + return nodes[offset:offset+limit] def save_lineage_edge(self, edge: Dict[str, Any]) -> str: edge_id = edge.get("edge_id") or f"{edge.get('source_node_id','')}->{edge.get('target_node_id','')}" or str(uuid.uuid4()) self._lineage_edges[edge_id] = {**edge, "edge_id": edge_id, "@timestamp": _now_iso()} return edge_id - def get_all_edges(self) -> List[Dict[str, Any]]: - return list(self._lineage_edges.values()) + def get_all_edges(self, limit: int = 100, offset: int = 0, source: Optional[str] = None, target: Optional[str] = None, relation: Optional[str] = None) -> List[Dict[str, Any]]: + edges = list(self._lineage_edges.values()) + if source is not None: + edges = [e for e in edges if e.get("source", e.get("source_node_id")) == source] + if target is not None: + edges = [e for e in edges if e.get("target", e.get("target_node_id")) == target] + if relation is not None: + edges = [e for e in edges if e.get("relation", e.get("relation_type")) == relation] + return edges[offset:offset+limit] def get_downstream_impact(self, node_id: str, depth: int = 5) -> List[str]: if depth <= 0: diff --git a/tests/test_api_endpoints.py b/tests/test_api_endpoints.py index 326714d..401e8c5 100644 --- a/tests/test_api_endpoints.py +++ b/tests/test_api_endpoints.py @@ -1,4 +1,3 @@ -"""Endpoint compatibility tests for the FastAPI DataObs API.""" from __future__ import annotations from typing import Any, Dict, List @@ -10,227 +9,90 @@ from src.config.settings import APISettings, AppSettings, AuthSettings, ElasticsearchSettings, ObservabilitySettings, RuntimeSettings, TenantSettings -class _FakeRuleStore: +class _FakeStore: def __init__(self) -> None: - self.rules: Dict[str, Dict[str, Any]] = {} - - def get_all_rules(self) -> List[Dict[str, Any]]: - return list(self.rules.values()) - - def add_rule(self, rule: Dict[str, Any]) -> str: - rule_id = rule.get("rule_id", "rule-1") - self.rules[rule_id] = {**rule, "rule_id": rule_id} - return rule_id - - def delete_rule(self, rule_id: str) -> bool: - return self.rules.pop(rule_id, None) is not None - - -class _FakeActiveStore(_FakeRuleStore): - def __init__(self) -> None: - super().__init__() - self.results: Dict[str, Dict[str, Any]] = {} - - def save_quality_result(self, result: Dict[str, Any]) -> str: - doc_id = result.get("id", "result-1") - self.results[doc_id] = {**result, "id": doc_id} - return doc_id - - def list_quality_results(self) -> List[Dict[str, Any]]: - return list(self.results.values()) - - def get_all_nodes(self) -> List[Dict[str, Any]]: - return [{"node_id": "rds.prod.orders", "type": "table"}] - - def get_all_edges(self) -> List[Dict[str, Any]]: - return [{"source_node_id": "rds.prod.orders", "target_node_id": "rds.prod.reports"}] - - def get_downstream_impact(self, node_id: str, depth: int = 5) -> List[str]: - return ["rds.prod.reports"] if node_id == "rds.prod.orders" else [] + self.rules = { + "r1": {"rule_id": "r1", "dataset": "orders", "enabled": True, "severity": "high", "check_type": "null_check"}, + "r2": {"rule_id": "r2", "dataset": "payments", "enabled": False, "severity": "low", "check_type": "range_check"}, + } + self.results = { + "q1": {"id": "q1", "table": "orders", "status": "pass", "dataset": "d1", "check_type": "null_check", "severity": "high", "run_id": "run-1"}, + "q2": {"id": "q2", "table": "orders", "status": "fail", "dataset": "d1", "check_type": "range_check", "severity": "low", "run_id": "run-2"}, + "q3": {"id": "q3", "table": "payments", "status": "pass", "dataset": "d2", "check_type": "null_check", "severity": "high", "run_id": "run-2"}, + } + + def get_all_rules(self, **kwargs): return list(self.rules.values()) + def add_rule(self, rule): self.rules[rule.get("rule_id","rnew")] = rule; return rule.get("rule_id","rnew") + def delete_rule(self, rule_id): return self.rules.pop(rule_id, None) is not None + def save_quality_result(self, result): self.results[result.get("id","new")] = result; return result.get("id","new") + def list_quality_results(self, **kwargs): return list(self.results.values()) + def get_all_nodes(self, **kwargs): return [{"node_id":"n1","type":"table","dataset":"d1"},{"node_id":"n2","type":"job","dataset":"d2"}] + def get_all_edges(self, **kwargs): return [{"source_node_id":"n1","target_node_id":"n2","relation":"feeds"},{"source_node_id":"n2","target_node_id":"n3","relation":"feeds"}] + def get_downstream_impact(self, node_id, depth=5): return ["n2"] @pytest.fixture() def client() -> TestClient: - store = _FakeActiveStore() - app = create_app( - settings=AppSettings(runtime=RuntimeSettings(env="test"), api=APISettings(), elasticsearch=ElasticsearchSettings(), auth=AuthSettings(api_token=None), tenant=TenantSettings(), observability=ObservabilitySettings(), store_backend="memory"), - store_bundle=StoreBundle(store=store), - ) + app = create_app(settings=AppSettings(runtime=RuntimeSettings(env="test"), api=APISettings(), elasticsearch=ElasticsearchSettings(), auth=AuthSettings(api_token="t", allow_unauthenticated_dev=False), tenant=TenantSettings(), observability=ObservabilitySettings(), store_backend="memory"), store_bundle=StoreBundle(store=_FakeStore())) return TestClient(app) -@pytest.fixture() -def client_with_auth() -> TestClient: - store = _FakeActiveStore() - app = create_app( - settings=AppSettings(runtime=RuntimeSettings(env="test"), api=APISettings(), elasticsearch=ElasticsearchSettings(), auth=AuthSettings(api_token="test-secret-token", allow_unauthenticated_dev=False), tenant=TenantSettings(), observability=ObservabilitySettings(), store_backend="memory"), - store_bundle=StoreBundle(store=store), - ) - return TestClient(app) - - -def _auth(token: str) -> Dict[str, str]: - return {"Authorization": f"Bearer {token}"} - - -# --------------------------------------------------------------------------- -# Health check (no auth) -# --------------------------------------------------------------------------- - -def test_health_check_returns_200(client: TestClient): - response = client.get("/health") - assert response.status_code == 200 - body = response.json() - assert body["status"] == "ok" - assert body["service"] == "dataobs-api" - - -def test_health_check_no_auth_required(client_with_auth: TestClient): - """Health endpoint must be reachable without a token (for load balancers).""" - response = client_with_auth.get("/health") - assert response.status_code == 200 - - -# --------------------------------------------------------------------------- -# Rules endpoint -# --------------------------------------------------------------------------- - -def test_get_rules_returns_empty_list(client: TestClient): - response = client.get("/rules") - assert response.status_code == 200 - body = response.json() - assert body["rules"] == [] - assert body["count"] == 0 - - -def test_post_rule_creates_and_returns_id(client: TestClient): - rule = {"dataset": "prod.orders", "check_type": "null_check", "severity": "critical"} - response = client.post("/rules", json=rule) - assert response.status_code == 201 - body = response.json() - assert "rule_id" in body - assert body["status"] == "created" - - -def test_post_rule_then_get_returns_it(client: TestClient): - rule = {"dataset": "prod.orders", "check_type": "row_count", "severity": "high"} - client.post("/rules", json=rule) - - response = client.get("/rules") - assert response.status_code == 200 - body = response.json() - assert body["count"] == 1 - assert body["rules"][0]["dataset"] == "prod.orders" - - -def test_delete_rule_preserves_response_shape(client: TestClient): - created = client.post("/rules", json={"rule_id": "r-1", "dataset": "prod.orders"}) - assert created.status_code == 201 - - response = client.delete("/rules/r-1") - assert response.status_code == 200 - assert response.json() == {"rule_id": "r-1", "status": "deleted"} - - -# --------------------------------------------------------------------------- -# Lineage endpoints -# --------------------------------------------------------------------------- - -def test_get_lineage_nodes(client: TestClient): - response = client.get("/lineage/nodes") - assert response.status_code == 200 - body = response.json() - assert body["count"] == 1 - assert body["nodes"][0]["node_id"] == "rds.prod.orders" - - -def test_get_lineage_edges(client: TestClient): - response = client.get("/lineage/edges") - assert response.status_code == 200 - body = response.json() - assert body["count"] == 1 - - -def test_lineage_impact_returns_affected_nodes(client: TestClient): - response = client.get("/lineage/impact/rds.prod.orders") - assert response.status_code == 200 - body = response.json() - assert body["root_node"] == "rds.prod.orders" - assert "rds.prod.reports" in body["affected"] - - -def test_lineage_impact_unknown_node_returns_empty(client: TestClient): - response = client.get("/lineage/impact/unknown.node") - assert response.status_code == 200 - assert response.json()["affected"] == [] - +def _auth() -> Dict[str, str]: return {"Authorization": "Bearer t"} -# --------------------------------------------------------------------------- -# Quality endpoints -# --------------------------------------------------------------------------- -def test_quality_results_roundtrip(client: TestClient): - result = {"id": "qr-1", "check_name": "not_null", "table": "orders", "status": "pass", "score": 1.0} - created = client.post("/quality/results", json=result) - assert created.status_code == 201 - assert created.json() == {"id": "qr-1", "status": "created"} +def test_request_id_round_trip(client: TestClient): + r = client.get("/rules", headers={**_auth(), "X-Request-ID": "abc-123"}) + assert r.headers["X-Request-ID"] == "abc-123" - response = client.get("/quality/results") - assert response.status_code == 200 - body = response.json() - assert body["count"] == 1 - assert body["results"][0]["id"] == "qr-1" +def test_request_id_generated(client: TestClient): + r = client.get("/rules", headers=_auth()) + assert r.headers.get("X-Request-ID") -# --------------------------------------------------------------------------- -# Strategy / OpenAPI -# --------------------------------------------------------------------------- -def test_enterprise_backlog_route(client: TestClient): - response = client.get("/strategy/enterprise-backlog") - assert response.status_code == 200 - assert "backlog" in response.json() +def test_401_structured_error_and_request_id(client: TestClient): + r = client.get("/rules") + assert r.status_code == 401 + assert r.json()["error"]["code"] == "unauthorized" + assert r.json()["request_id"] -def test_openapi_docs_available(client: TestClient): - response = client.get("/openapi.json") - assert response.status_code == 200 - paths = response.json()["paths"] - assert "/rules" in paths - assert "/quality/results" in paths +def test_404_structured_error_and_request_id(client: TestClient): + r = client.get("/missing", headers=_auth()) + assert r.status_code == 404 + assert r.json()["error"]["code"] == "not_found" + assert r.json()["request_id"] -# --------------------------------------------------------------------------- -# Auth enforcement -# --------------------------------------------------------------------------- +def test_422_structured_error_and_request_id(client: TestClient): + r = client.post("/rules", headers=_auth(), json={"enabled": "oops"}) + assert r.status_code == 422 + assert r.json()["error"]["code"] == "validation_error" -def test_auth_required_returns_401_without_token(client_with_auth: TestClient): - response = client_with_auth.get("/rules") - assert response.status_code == 401 - assert response.json() == {"error": "Unauthorized - valid Bearer token required"} +def test_quality_results_pagination_and_filters(client: TestClient): + r = client.get("/quality/results?limit=1&offset=1&table=orders&status=pass", headers=_auth()) + assert r.status_code == 200 + assert "pagination" in r.json() -def test_auth_accepted_with_valid_token(client_with_auth: TestClient): - response = client_with_auth.get("/rules", headers=_auth("test-secret-token")) - assert response.status_code == 200 +def test_rules_pagination_and_filtering(client: TestClient): + r = client.get("/rules?limit=1&offset=0&dataset=orders", headers=_auth()) + assert r.status_code == 200 + assert r.json()["pagination"]["limit"] == 1 -def test_auth_rejected_with_wrong_token(client_with_auth: TestClient): - response = client_with_auth.get("/rules", headers=_auth("wrong-token")) - assert response.status_code == 401 +def test_lineage_nodes_edges_pagination(client: TestClient): + rn = client.get("/lineage/nodes?limit=1&offset=0", headers=_auth()) + re = client.get("/lineage/edges?limit=1&offset=0", headers=_auth()) + assert rn.status_code == 200 and re.status_code == 200 -# --------------------------------------------------------------------------- -# Structured error responses -# --------------------------------------------------------------------------- -def test_unknown_route_returns_404(client: TestClient): - response = client.get("/no-such-endpoint") - assert response.status_code == 404 - assert response.json() == {"error": "Not found"} +def test_limit_max_enforced(client: TestClient): + r = client.get("/rules?limit=1001", headers=_auth()) + assert r.status_code == 422 -def test_method_not_allowed_returns_structured_error(client: TestClient): - response = client.put("/rules") - assert response.status_code == 405 - assert response.json() == {"error": "Method not allowed"} +def test_negative_offset_rejected(client: TestClient): + r = client.get("/rules?offset=-1", headers=_auth()) + assert r.status_code == 422 diff --git a/tests/test_api_es_store.py b/tests/test_api_es_store.py index f5b49d1..c82fd23 100644 --- a/tests/test_api_es_store.py +++ b/tests/test_api_es_store.py @@ -169,7 +169,7 @@ def test_get_all_rules_applies_tenant_filter(self): store._es.search.return_value = {"hits": {"hits": []}} store.get_all_rules() query = store._es.search.call_args.kwargs["query"] - assert query["term"]["tenant_id"] == "t42" + assert {"term": {"tenant_id": "t42"}} in query["bool"]["must"] def test_multi_tenant_isolation(self): """Two stores with different tenant_ids must never share index names."""