Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
00ad6af
Add credential= parameter for custom Azure Identity credential support
jahnvi480 May 27, 2026
36ccf52
Sanitize conn string in credential= path, update .pyi stubs
jahnvi480 May 27, 2026
4a5687a
Fix bulkcopy auth test mock to set _custom_credential = None
jahnvi480 May 27, 2026
819549b
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 May 27, 2026
5370672
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 May 29, 2026
2304871
Rename credential= to token_provider= per team feedback
jahnvi480 May 29, 2026
b1c8a59
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 May 29, 2026
09b7198
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 Jun 2, 2026
e35270e
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 Jun 23, 2026
8e4e04a
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 Jun 24, 2026
42f976b
Merge remote-tracking branch 'origin/main' into jahnvi/custom-credent…
jahnvi480 Jun 25, 2026
84497d7
Add expiry capture and DB-API exceptions for token_provider; cover bu…
jahnvi480 Jun 25, 2026
b49daa8
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 Jun 26, 2026
e41058a
Add token_provider validation, dropped-credential warning, and Protoc…
jahnvi480 Jun 26, 2026
bc59232
Merge branch 'main' into jahnvi/custom-credential-support
jahnvi480 Jun 26, 2026
2b796ae
Disable pooling for access-token connections to prevent cross-princip…
jahnvi480 Jun 26, 2026
9814bdb
Fix f-string lint nits and type token_provider in .pyi stubs
jahnvi480 Jun 26, 2026
d723153
Remove dangling doc reference and fix expired-token warning attribution
jahnvi480 Jun 26, 2026
4559241
Revert dynamic stacklevel helper; use fixed stacklevel=2 for expired-…
jahnvi480 Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ exclude_lines =

# Don't complain if non-runnable code isn't run
if __name__ == .__main__.:

# Type-checking-only imports never execute at runtime
if TYPE_CHECKING:

# Exclude all logging statements (zero overhead when disabled by design)
logger\.debug
Expand Down
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Added
- New feature: Support for macOS and Linux.
- Documentation: Added API documentation in the Wiki.
- New `token_provider=` parameter on `connect()` / `Connection` for Microsoft
Entra ID authentication with a custom credential object. Accepts any object
exposing a `.get_token(scope)` method (e.g. any `azure-identity` credential
such as `DefaultAzureCredential`, `AzureCliCredential`,
`ManagedIdentityCredential`). Mutually exclusive with `Authentication=` in
the connection string and with `attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`.
Bulk copy re-acquires a fresh token from the provider on each operation. The
OAuth token scope is fixed to the Azure commercial cloud; sovereign clouds are
not supported by `token_provider=` (supply a pre-acquired token via `attrs_before` instead).
- Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal`
via an `entra_id_token_factory` callback registered on the mssql-py-core
connection. The callback is invoked by mssql-tds mid-handshake (FedAuth
Expand Down
196 changes: 193 additions & 3 deletions mssql_python/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,24 @@
"""

import hashlib
import inspect
import platform
import struct
import threading
from typing import Tuple, Dict, Optional
import time
import warnings
from typing import Tuple, Dict, Optional, Any, Protocol, runtime_checkable

from mssql_python.logging import logger
from mssql_python.constants import (
AuthType,
ConstantsDDBC,
_AuthInternal,
_KEY_AUTHENTICATION,
_KEY_UID,
_KEY_PWD,
_KEY_TRUSTED_CONNECTION,
)
from mssql_python.exceptions import InterfaceError, OperationalError

# Module-level credential instance cache.
# Reusing credential objects allows the Azure Identity SDK's built-in
Expand All @@ -34,6 +37,28 @@
# Canonical keys to strip when handing an Entra-token connection to ODBC.
_SENSITIVE_KEYS = frozenset({_KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION, _KEY_AUTHENTICATION})

# Azure SQL Database OAuth scope for the Azure **commercial** cloud. Shared by
# the built-in AADAuth path and the custom token_provider path, so both request
# tokens for the same audience. Sovereign clouds (Azure US Gov, Azure China,
# Azure Germany) are not supported through this constant — a token for a
# different audience is rejected by SQL Server at login.
_DATABASE_SCOPE = "https://database.windows.net/.default"


@runtime_checkable
class TokenProvider(Protocol):
    """Duck-typed contract for the ``token_provider=`` connect parameter.

    An object satisfies this protocol as soon as it exposes a
    ``get_token(scope, ...)`` method. Every ``azure-identity`` credential
    class (``DefaultAzureCredential``, ``AzureCliCredential``,
    ``ManagedIdentityCredential``, etc.) does, and so does any hand-written
    object with the same method.

    The object returned by ``get_token`` is expected to carry:

    * ``.token`` — a non-empty ``str`` (required).
    * ``.expires_on`` — an ``int`` POSIX timestamp (optional; captured for
      diagnostics only).

    Because the class is ``runtime_checkable``, ``isinstance(obj,
    TokenProvider)`` only verifies that a ``get_token`` attribute exists; it
    does not validate the method's signature or its return value.
    """

    def get_token(self, *scopes: str, **kwargs: Any) -> Any:  # pragma: no cover - protocol
        ...


# Map Authentication connection-string values to internal short names.
_AUTH_TYPE_MAP: Dict[str, str] = {
AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE,
Expand Down Expand Up @@ -147,7 +172,7 @@ def _acquire_token(
auth_type,
)
credential = _credential_cache[cache_key]
raw_token = credential.get_token("https://database.windows.net/.default").token
raw_token = credential.get_token(_DATABASE_SCOPE).token
logger.info(
"get_token: Azure AD token acquired successfully - token_length=%d chars",
len(raw_token),
Expand Down Expand Up @@ -437,3 +462,168 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]:
"""
auth_value = parsed_params.get(_KEY_AUTHENTICATION, "").strip().lower()
return _AUTH_TYPE_MAP.get(auth_value)


def _get_token_from_credential(credential: "TokenProvider") -> Tuple[str, Optional[int]]:
    """Internal: ask *credential* for a token and return ``(raw_jwt, expires_on)``.

    Shared implementation behind :func:`acquire_token_from_credential` and
    :func:`acquire_raw_token_from_credential`: it performs the single
    ``get_token()`` call, validates the result, and translates failures into
    DB-API exceptions.

    ``expires_on`` is whatever the returned object exposes under that
    attribute name (``None`` when the attribute is absent). It is returned
    and logged so callers can reason about token lifetime. If it is a
    numeric timestamp already in the past, a ``UserWarning`` is emitted but
    the token is still returned.

    Note:
        The requested scope is always ``_DATABASE_SCOPE``, the Azure
        **commercial** cloud audience. Sovereign clouds (Azure US
        Government, Azure China, Azure Germany) are not supported on the
        ``token_provider`` path — for those, supply a pre-acquired token via
        ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]``.

    Args:
        credential: Any object with a ``.get_token(scope)`` method.

    Returns:
        Tuple[str, Optional[int]]: The raw token string and the
        ``expires_on`` value reported by the provider (``None`` if missing).

    Raises:
        InterfaceError: If ``get_token()`` raises ``TypeError`` (treated as a
            signature that cannot accept a positional scope), returns a
            coroutine (async credential), or returns an object without a
            non-empty string ``.token``.
        OperationalError: If ``get_token()`` raises any other exception; the
            original exception is chained as ``__cause__``.
    """
    provider_name = type(credential).__name__
    started = time.perf_counter()

    try:
        result = credential.get_token(_DATABASE_SCOPE)
    except TypeError as exc:
        # The call above passes exactly one positional scope, so a TypeError
        # almost always means get_token() cannot take it (zero-arg or
        # keyword-only signature). Report that as an actionable
        # InterfaceError. Arity is judged here, at call time, because the
        # connect() path only warns about suspicious signatures so that
        # hard-to-introspect callables (partials, decorated methods) are
        # never blocked up front.
        raise InterfaceError(
            driver_error=(
                "token_provider.get_token() must accept a scope "
                "positional argument, e.g. get_token(scope)."
            ),
            ddbc_error=str(exc),
        ) from exc
    except Exception as exc:
        logger.error(
            "_get_token_from_credential: get_token() failed - credential=%s, error=%s",
            provider_name,
            str(exc),
        )
        # Chain the provider's own exception (e.g. azure-identity
        # ClientAuthenticationError) so callers can still inspect __cause__.
        raise OperationalError(
            driver_error=f"Failed to acquire token from credential ({provider_name})",
            ddbc_error=str(exc),
        ) from exc

    # A synchronous call into an azure.identity.aio credential yields a
    # coroutine instead of a token. Reject it with an async-specific message,
    # closing the coroutine first so Python does not emit a
    # "coroutine was never awaited" warning.
    if inspect.iscoroutine(result):
        result.close()
        raise InterfaceError(
            driver_error=(
                "token_provider.get_token() returned a coroutine, "
                "which indicates an async credential "
                "(e.g. from azure.identity.aio). "
                "Use a synchronous credential instead."
            ),
            ddbc_error=f"got coroutine from {provider_name}.get_token()",
        )

    jwt = getattr(result, "token", None)
    if not (isinstance(jwt, str) and jwt):
        raise InterfaceError(
            driver_error=(
                "token_provider.get_token() must return an object "
                "with a non-empty string '.token' attribute."
            ),
            ddbc_error=f"got .token of type {type(jwt).__name__}",
        )

    expiry = getattr(result, "expires_on", None)
    # Only genuine numeric POSIX timestamps take part in the expiry check;
    # bool is a subclass of int and is excluded to avoid false positives.
    is_timestamp = isinstance(expiry, (int, float)) and not isinstance(expiry, bool)
    if is_timestamp and expiry < time.time():
        # Warn rather than fail: the server enforces expiry and will reject
        # the login, so flagging it here points at the real cause instead of
        # an opaque failure later on.
        warnings.warn(
            "token_provider returned a token that is already expired "
            f"(expires_on={expiry} is in the past). "
            "The server will likely reject the connection.",
            UserWarning,
            stacklevel=2,
        )

    duration_ms = (time.perf_counter() - started) * 1000
    logger.info(
        "_get_token_from_credential: Token acquired from %s - length=%d chars, "
        "expires_on=%s, duration_ms=%.2f",
        provider_name,
        len(jwt),
        expiry,
        duration_ms,
    )
    return jwt, expiry


def acquire_token_from_credential(credential: "TokenProvider") -> Tuple[bytes, Optional[int]]:
    """Build the ODBC access-token struct from a user-supplied credential.

    The credential is expected to follow the Azure ``TokenCredential``
    protocol: a ``.get_token(scope)`` method whose result has a ``.token``
    attribute holding the raw JWT string. The JWT is obtained through
    :func:`_get_token_from_credential` and then packed with
    ``AADAuth.get_token_struct``.

    .. note::
        The scope is fixed to the Azure **commercial** cloud
        (``https://database.windows.net/.default``). Sovereign clouds (Azure
        US Government, Azure China, Azure Germany) are not supported here —
        for those, supply a pre-acquired token via
        ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead.

    Args:
        credential: Any object with a ``.get_token(scope)`` method.

    Returns:
        Tuple[bytes, Optional[int]]: The token struct intended for
        ``SQL_COPT_SS_ACCESS_TOKEN``, paired with the token's ``expires_on``
        POSIX timestamp (``None`` if the provider does not supply one).

    Raises:
        InterfaceError: If the provider's ``get_token()`` cannot accept a
            positional scope, returns a coroutine, or returns no valid
            ``.token`` string.
        OperationalError: If the underlying ``get_token()`` call fails.
    """
    jwt, expiry = _get_token_from_credential(credential)
    token_struct = AADAuth.get_token_struct(jwt)
    return token_struct, expiry


def acquire_raw_token_from_credential(credential: "TokenProvider") -> Tuple[str, Optional[int]]:
    """Fetch the raw JWT string from a user-supplied credential.

    This is the entry point for bulk copy, which consumes the raw JWT rather
    than the packed ODBC struct. It is a thin public wrapper that delegates
    entirely to :func:`_get_token_from_credential`.

    .. note::
        The scope is fixed to the Azure **commercial** cloud
        (``https://database.windows.net/.default``). Sovereign clouds are
        not supported on this path.

    Args:
        credential: Any object with a ``.get_token(scope)`` method.

    Returns:
        Tuple[str, Optional[int]]: The raw JWT token string, paired with the
        token's ``expires_on`` POSIX timestamp (``None`` if the provider
        does not supply one).

    Raises:
        InterfaceError: If the provider's ``get_token()`` cannot accept a
            positional scope, returns a coroutine, or returns no valid
            ``.token`` string.
        OperationalError: If the underlying ``get_token()`` call fails.
    """
    token_and_expiry = _get_token_from_credential(credential)
    return token_and_expiry
Loading
Loading