From 00ad6afd5f5f11b26d12cbd05eac442e240d0cd9 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 27 May 2026 19:35:27 +0530 Subject: [PATCH 01/10] Add credential= parameter for custom Azure Identity credential support Add a new 'credential' parameter to connect() that accepts any object following the Azure TokenCredential protocol (.get_token() method). This allows users to authenticate with any azure-identity credential class without being limited to the driver's hardcoded credential map. Changes: - auth.py: Add _get_token_from_credential() shared helper, acquire_token_from_credential(), acquire_raw_token_from_credential() - db_connection.py: Add credential=None parameter to connect() - connection.py: Validate credential, acquire token, store for bulk copy token refresh. Mutually exclusive with Authentication= - cursor.py: Check _custom_credential before _auth_type in bulk copy - constants.py: Unify _KEY_* constants with _ALLOWED_CONNECTION_STRING_PARAMS to use single source of truth (_CONNECTION_STRING_*_KEY pattern) - test_008_auth.py: Add 13 new tests for custom credential flow --- mssql_python/auth.py | 65 +++++++++++++ mssql_python/connection.py | 26 +++++- mssql_python/constants.py | 30 +++--- mssql_python/cursor.py | 19 +++- mssql_python/db_connection.py | 12 +++ tests/test_008_auth.py | 169 ++++++++++++++++++++++++++++++++++ 6 files changed, 305 insertions(+), 16 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index dd716c2c0..847574f12 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -238,3 +238,68 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]: """ auth_value = parsed_params.get(_KEY_AUTHENTICATION, "").strip().lower() return _AUTH_TYPE_MAP.get(auth_value) + + +def _get_token_from_credential(credential: object) -> str: + """Internal: call credential.get_token() and return the raw JWT string. 
+ + Centralises the token-acquisition + error-wrapping logic that both + :func:`acquire_token_from_credential` and + :func:`acquire_raw_token_from_credential` need. + + Raises: + RuntimeError: If token acquisition fails. + """ + try: + raw_token = credential.get_token("https://database.windows.net/.default").token + logger.info( + "_get_token_from_credential: Token acquired from %s - length=%d chars", + type(credential).__name__, + len(raw_token), + ) + return raw_token + except Exception as e: + logger.error( + "_get_token_from_credential: Failed - credential=%s, error=%s", + type(credential).__name__, + str(e), + ) + raise RuntimeError( + f"Failed to acquire token from credential " f"({type(credential).__name__}): {e}" + ) from e + + +def acquire_token_from_credential(credential: object) -> bytes: + """Acquire an ODBC token struct from a user-supplied credential object. + + The credential must follow the Azure ``TokenCredential`` protocol — i.e. + have a ``.get_token(scope)`` method returning an object with a ``.token`` + attribute (a raw JWT string). + + Args: + credential: Any object with a ``.get_token(scope)`` method. + + Returns: + bytes: ODBC-compatible token struct for ``SQL_COPT_SS_ACCESS_TOKEN``. + + Raises: + RuntimeError: If token acquisition fails. + """ + return AADAuth.get_token_struct(_get_token_from_credential(credential)) + + +def acquire_raw_token_from_credential(credential: object) -> str: + """Acquire a raw JWT string from a user-supplied credential object. + + Used by bulk copy, which needs the raw JWT rather than the ODBC struct. + + Args: + credential: Any object with a ``.get_token(scope)`` method. + + Returns: + str: Raw JWT token string. + + Raises: + RuntimeError: If token acquisition fails. 
+ """ + return _get_token_from_credential(credential) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 0d9b4692e..3e81f3bac 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -246,6 +246,7 @@ def __init__( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> None: """ @@ -334,10 +335,33 @@ def __init__( # fresh token; re-parsing self.connection_str at that point would miss # them because UID is already gone. self._credential_kwargs: Optional[Dict[str, str]] = None + # User-supplied credential object for custom Entra ID authentication. + # Stored so bulk copy can call .get_token() for a fresh JWT later. + self._custom_credential = None + + # Custom credential= parameter — takes priority, mutually exclusive + # with Authentication= in the connection string. + if credential is not None: + if _KEY_AUTHENTICATION in parsed_params: + raise ValueError( + "Cannot specify both 'credential' parameter and " + "'Authentication' in the connection string. " + "Use one or the other." + ) + if not callable(getattr(credential, "get_token", None)): + raise TypeError( + f"credential must have a .get_token() method. " + f"Got {type(credential).__name__}." + ) + from mssql_python.auth import acquire_token_from_credential + + token = acquire_token_from_credential(credential) + self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token + self._custom_credential = credential # Handle Entra ID authentication if specified. # The parsed dict is used directly — no re-parsing of the connection string. 
- if _KEY_AUTHENTICATION in parsed_params: + elif _KEY_AUTHENTICATION in parsed_params: auth_type = process_auth_parameters(parsed_params) if auth_type: diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 3bfd39483..6283d6b90 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -467,6 +467,16 @@ def get_attribute_set_timing(attribute): _CONNECTION_STRING_DRIVER_KEY = "Driver" _CONNECTION_STRING_APP_KEY = "APP" +_CONNECTION_STRING_AUTH_KEY = "Authentication" +_CONNECTION_STRING_UID_KEY = "UID" +_CONNECTION_STRING_PWD_KEY = "PWD" +_CONNECTION_STRING_TRUSTED_CONNECTION_KEY = "Trusted_Connection" + +# Aliases used by auth.py / connection.py — kept for readability. +_KEY_AUTHENTICATION = _CONNECTION_STRING_AUTH_KEY +_KEY_UID = _CONNECTION_STRING_UID_KEY +_KEY_PWD = _CONNECTION_STRING_PWD_KEY +_KEY_TRUSTED_CONNECTION = _CONNECTION_STRING_TRUSTED_CONNECTION_KEY # Reserved connection string parameters that are controlled by the driver # and cannot be set by users @@ -486,16 +496,16 @@ def get_attribute_set_timing(attribute): "address": "Server", "addr": "Server", # Authentication - "uid": "UID", - "pwd": "PWD", - "authentication": "Authentication", - "trusted_connection": "Trusted_Connection", + "uid": _CONNECTION_STRING_UID_KEY, + "pwd": _CONNECTION_STRING_PWD_KEY, + "authentication": _CONNECTION_STRING_AUTH_KEY, + "trusted_connection": _CONNECTION_STRING_TRUSTED_CONNECTION_KEY, # Database "database": "Database", # Driver (always controlled by mssql-python) - "driver": "Driver", + "driver": _CONNECTION_STRING_DRIVER_KEY, # Application name (always controlled by mssql-python) - "app": "APP", + "app": _CONNECTION_STRING_APP_KEY, # Encryption and Security "encrypt": "Encrypt", "trustservercertificate": "TrustServerCertificate", @@ -519,14 +529,6 @@ def get_attribute_set_timing(attribute): "packetsize": "PacketSize", } -# Canonical normalized key names produced by _ConnectionStringParser._normalize_params. 
-# Consumer code should reference these instead of hard-coding raw strings so that -# a rename in _ALLOWED_CONNECTION_STRING_PARAMS is caught at import time. -_KEY_AUTHENTICATION = "Authentication" -_KEY_UID = "UID" -_KEY_PWD = "PWD" -_KEY_TRUSTED_CONNECTION = "Trusted_Connection" - def get_info_constants() -> Dict[str, int]: """ diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 49eb1b92d..966941d8d 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2956,7 +2956,24 @@ def bulkcopy( pycore_context = connstr_to_pycore_params(params) # Token acquisition — only thing cursor must handle (needs azure-identity SDK) - if self.connection._auth_type: + if self.connection._custom_credential is not None: + # User-supplied credential — use it directly for a fresh token. + from mssql_python.auth import acquire_raw_token_from_credential + + try: + raw_token = acquire_raw_token_from_credential(self.connection._custom_credential) + except RuntimeError as e: + raise RuntimeError( + f"Bulk copy failed: unable to acquire token " f"from custom credential: {e}" + ) from e + pycore_context["access_token"] = raw_token + for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) + logger.debug( + "Bulk copy: acquired fresh token from custom credential (%s)", + type(self.connection._custom_credential).__name__, + ) + elif self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection. credential # kwargs (e.g. 
user-assigned MSI client_id) were captured by # Connection.__init__ before remove_sensitive_params stripped UID diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index fe10b819b..1688a56ed 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -15,6 +15,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> Connection: """ @@ -35,6 +36,16 @@ def connect( This per-connection override is useful for migration from pyodbc: connections that need string UUIDs can pass native_uuid=False, while the default (True) returns native uuid.UUID objects. + credential (object, optional): An Azure Identity credential object (or any object with a + ``.get_token(scope)`` method) used for Entra ID authentication. When provided, the + driver calls ``credential.get_token()`` to acquire a token instead of using the + built-in credential map. Cannot be combined with ``Authentication=`` in the + connection string. + + For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault`` in + the connection string — ``DefaultAzureCredential`` automatically picks the right + credential per environment. Use ``credential=`` only when you need explicit control + (e.g., excluding specific providers or using a credential not in the built-in map). Keyword Args: **kwargs: Additional key/value pairs for the connection string. 
Below attributes are not implemented in the internal driver: @@ -58,6 +69,7 @@ def connect( attrs_before=attrs_before, timeout=timeout, native_uuid=native_uuid, + credential=credential, **kwargs, ) return conn diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index b127133a5..fe6937723 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -17,6 +17,8 @@ extract_auth_type, _credential_cache, _credential_cache_lock, + acquire_token_from_credential, + acquire_raw_token_from_credential, ) from mssql_python.constants import AuthType, ConstantsDDBC import secrets @@ -522,6 +524,7 @@ def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" mock_conn._auth_type = "msi" mock_conn._credential_kwargs = {"client_id": client_id} + mock_conn._custom_credential = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) @@ -970,3 +973,169 @@ def test_token_output_correct_on_cache_miss_and_hit(self): # Same credential instance for both assert "default" in _credential_cache + + +# ── Custom credential= parameter tests ── + + +class TestAcquireTokenFromCredential: + """Tests for the acquire_token_from_credential helper.""" + + def test_happy_path(self): + """acquire_token_from_credential returns a valid token struct.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + result = acquire_token_from_credential(mock_cred) + assert isinstance(result, bytes) + assert len(result) > 4 + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + + def test_credential_raises_exception(self): + """acquire_token_from_credential wraps credential errors in RuntimeError.""" + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + acquire_token_from_credential(mock_cred) + + +class 
TestAcquireRawTokenFromCredential: + """Tests for the acquire_raw_token_from_credential helper.""" + + def test_happy_path(self): + """acquire_raw_token_from_credential returns the raw JWT string.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + result = acquire_raw_token_from_credential(mock_cred) + assert result == SAMPLE_TOKEN + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + + def test_credential_raises_exception(self): + """acquire_raw_token_from_credential wraps credential errors in RuntimeError.""" + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + acquire_raw_token_from_credential(mock_cred) + + +class TestCustomCredentialConnect: + """Tests for the credential= parameter on connect().""" + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_happy_path(self, mock_ddbc_conn): + """credential= acquires token and sets attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("Server=test;Database=testdb", credential=mock_cred) + assert conn._custom_credential is mock_cred + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + # Existing auth_type should be None (no Authentication= in conn str) + assert conn._auth_type is None + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_plus_authentication_raises_valueerror(self, mock_ddbc_conn): + """credential= + Authentication= raises ValueError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.raises(ValueError, 
match="Cannot specify both"): + connect( + "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault", + credential=mock_cred, + ) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_plus_authentication_via_kwargs_raises_valueerror(self, mock_ddbc_conn): + """credential= + Authentication via kwargs raises ValueError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.raises(ValueError, match="Cannot specify both"): + connect( + "Server=test;Database=testdb", + credential=mock_cred, + Authentication="ActiveDirectoryDefault", + ) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_without_get_token_raises_typeerror(self, mock_ddbc_conn): + """Passing an object without .get_token() raises TypeError.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + with pytest.raises(TypeError, match="credential must have a .get_token"): + connect("Server=test;Database=testdb", credential="not_a_credential") + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_none_uses_existing_flow(self, mock_ddbc_conn): + """credential=None (default) uses existing auth flow, no change.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault") + assert conn._custom_credential is None + assert conn._auth_type == "default" + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): + """credential= works alongside non-auth attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + 
login_timeout_attr = 113 # SQL_ATTR_LOGIN_TIMEOUT + conn = connect( + "Server=test;Database=testdb", + credential=mock_cred, + attrs_before={login_timeout_attr: 30}, + ) + assert conn._attrs_before[login_timeout_attr] == 30 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_get_token_failure_raises_runtime_error(self, mock_ddbc_conn): + """If credential.get_token() fails, connect() raises RuntimeError.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("token acquisition failed") + from mssql_python import connect + + with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + connect("Server=test;Database=testdb", credential=mock_cred) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_credential_with_non_callable_get_token_raises_typeerror(self, mock_ddbc_conn): + """Object with .get_token as a non-callable attribute raises TypeError.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class BadCredential: + get_token = "not_a_method" + + with pytest.raises(TypeError, match="credential must have a .get_token"): + connect("Server=test;Database=testdb", credential=BadCredential()) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): + """Two connections can share the same credential object safely.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn1 = connect("Server=test1;Database=db1", credential=mock_cred) + conn2 = connect("Server=test2;Database=db2", credential=mock_cred) + assert conn1._custom_credential is conn2._custom_credential + assert mock_cred.get_token.call_count == 
2 + conn1.close() + conn2.close() From 36ccf52001c083e0675edb96814c5fb2e7af6b50 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 27 May 2026 19:51:39 +0530 Subject: [PATCH 02/10] Sanitize conn string in credential= path, update .pyi stubs - Strip UID/PWD/Trusted_Connection from connection_str when credential= is used (same as Authentication= path) to avoid leaking unused secrets - Add credential= parameter to Connection.__init__ and connect() in mssql_python.pyi type stubs --- mssql_python/connection.py | 4 ++++ mssql_python/mssql_python.pyi | 2 ++ 2 files changed, 6 insertions(+) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 3e81f3bac..8ecba7df6 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -358,6 +358,10 @@ def __init__( token = acquire_token_from_credential(credential) self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token self._custom_credential = credential + # Strip sensitive params (UID/PWD/Trusted_Connection) since + # access-token auth is used — same as the Authentication= path. + sanitized = remove_sensitive_params(parsed_params) + self.connection_str = _ConnectionStringBuilder(sanitized).build() # Handle Entra ID authentication if specified. # The parsed dict is used directly — no re-parsing of the connection string. diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi index 9b08913d6..ef2655bab 100644 --- a/mssql_python/mssql_python.pyi +++ b/mssql_python/mssql_python.pyi @@ -246,6 +246,7 @@ class Connection: attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> None: ... @@ -289,6 +290,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, + credential: Optional[object] = None, **kwargs: Any, ) -> Connection: ... 
From 4a5687a950ed7c37d17cbd90aa57691f9768765f Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Wed, 27 May 2026 20:02:35 +0530 Subject: [PATCH 03/10] Fix bulkcopy auth test mock to set _custom_credential = None The _make_cursor helper uses MagicMock for the connection, which auto-creates truthy attributes. Without explicitly setting _custom_credential = None, the bulk copy code takes the custom credential path instead of the expected _auth_type path. --- tests/test_020_bulkcopy_auth_cleanup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py index 164438344..404faca91 100644 --- a/tests/test_020_bulkcopy_auth_cleanup.py +++ b/tests/test_020_bulkcopy_auth_cleanup.py @@ -22,6 +22,7 @@ def _make_cursor(connection_str, auth_type): mock_conn = MagicMock() mock_conn.connection_str = connection_str mock_conn._auth_type = auth_type + mock_conn._custom_credential = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) From 2304871b74477c9d1c295f08d1863b073f38293f Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 29 May 2026 12:40:16 +0530 Subject: [PATCH 04/10] Rename credential= to token_provider= per team feedback Rename the public API parameter from 'credential' to 'token_provider' to reduce ambiguity in our multi-auth-path context. 'credential' could be confused with SQL auth username/password; 'token_provider' clearly signals token-based Entra ID auth. 
- Rename parameter: credential -> token_provider (connect, Connection) - Rename internal attr: _custom_credential -> _token_provider - Update error messages, docstrings, comments, .pyi stubs - Improve docstring with usage example and explicit guidance - All 97 tests pass --- mssql_python/connection.py | 22 ++++++------- mssql_python/cursor.py | 6 ++-- mssql_python/db_connection.py | 35 +++++++++++++------- mssql_python/mssql_python.pyi | 4 +-- tests/test_008_auth.py | 44 ++++++++++++------------- tests/test_020_bulkcopy_auth_cleanup.py | 2 +- 6 files changed, 63 insertions(+), 50 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 8ecba7df6..be6fc9a57 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -246,7 +246,7 @@ def __init__( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> None: """ @@ -335,29 +335,29 @@ def __init__( # fresh token; re-parsing self.connection_str at that point would miss # them because UID is already gone. self._credential_kwargs: Optional[Dict[str, str]] = None - # User-supplied credential object for custom Entra ID authentication. + # User-supplied token provider for custom Entra ID authentication. # Stored so bulk copy can call .get_token() for a fresh JWT later. - self._custom_credential = None + self._token_provider = None - # Custom credential= parameter — takes priority, mutually exclusive + # Custom token_provider= parameter — takes priority, mutually exclusive # with Authentication= in the connection string. - if credential is not None: + if token_provider is not None: if _KEY_AUTHENTICATION in parsed_params: raise ValueError( - "Cannot specify both 'credential' parameter and " + "Cannot specify both 'token_provider' parameter and " "'Authentication' in the connection string. " "Use one or the other." 
) - if not callable(getattr(credential, "get_token", None)): + if not callable(getattr(token_provider, "get_token", None)): raise TypeError( - f"credential must have a .get_token() method. " - f"Got {type(credential).__name__}." + f"token_provider must have a .get_token() method. " + f"Got {type(token_provider).__name__}." ) from mssql_python.auth import acquire_token_from_credential - token = acquire_token_from_credential(credential) + token = acquire_token_from_credential(token_provider) self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token - self._custom_credential = credential + self._token_provider = token_provider # Strip sensitive params (UID/PWD/Trusted_Connection) since # access-token auth is used — same as the Authentication= path. sanitized = remove_sensitive_params(parsed_params) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index c0b3e02e6..ea47d851e 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2961,12 +2961,12 @@ def bulkcopy( pycore_context = connstr_to_pycore_params(params) # Token acquisition — only thing cursor must handle (needs azure-identity SDK) - if self.connection._custom_credential is not None: + if self.connection._token_provider is not None: # User-supplied credential — use it directly for a fresh token. 
from mssql_python.auth import acquire_raw_token_from_credential try: - raw_token = acquire_raw_token_from_credential(self.connection._custom_credential) + raw_token = acquire_raw_token_from_credential(self.connection._token_provider) except RuntimeError as e: raise RuntimeError( f"Bulk copy failed: unable to acquire token " f"from custom credential: {e}" @@ -2976,7 +2976,7 @@ def bulkcopy( pycore_context.pop(key, None) logger.debug( "Bulk copy: acquired fresh token from custom credential (%s)", - type(self.connection._custom_credential).__name__, + type(self.connection._token_provider).__name__, ) elif self.connection._auth_type: # Fresh token acquisition for mssql-py-core connection. credential diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index 1688a56ed..894440009 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -15,7 +15,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> Connection: """ @@ -36,16 +36,29 @@ def connect( This per-connection override is useful for migration from pyodbc: connections that need string UUIDs can pass native_uuid=False, while the default (True) returns native uuid.UUID objects. - credential (object, optional): An Azure Identity credential object (or any object with a - ``.get_token(scope)`` method) used for Entra ID authentication. When provided, the - driver calls ``credential.get_token()`` to acquire a token instead of using the - built-in credential map. Cannot be combined with ``Authentication=`` in the - connection string. + token_provider (object, optional): A token provider for Microsoft Entra ID + authentication. 
This must be any object with a ``.get_token(scope)`` method that + returns an object with a ``.token`` attribute containing a raw JWT string — for + example, any ``azure-identity`` credential class such as + ``DefaultAzureCredential``, ``AzureCliCredential``, ``ManagedIdentityCredential``, + ``CertificateCredential``, etc. - For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault`` in - the connection string — ``DefaultAzureCredential`` automatically picks the right - credential per environment. Use ``credential=`` only when you need explicit control - (e.g., excluding specific providers or using a credential not in the built-in map). + When provided, the driver calls ``token_provider.get_token()`` to acquire an + access token for SQL Server, bypassing the built-in credential map. + Cannot be combined with ``Authentication=`` in the connection string. + + For environment-portable code, prefer ``Authentication=ActiveDirectoryDefault`` + in the connection string — ``DefaultAzureCredential`` automatically picks the + right credential per environment (CLI on dev, Managed Identity in prod). + Use ``token_provider=`` only when you need explicit control over token + acquisition (e.g., excluding specific providers, using a credential not in + the built-in map, or passing custom options to the credential constructor). + + Example:: + + from azure.identity import AzureCliCredential + conn = mssql_python.connect("Server=s;Database=d", + token_provider=AzureCliCredential()) Keyword Args: **kwargs: Additional key/value pairs for the connection string. 
Below attributes are not implemented in the internal driver: @@ -69,7 +82,7 @@ def connect( attrs_before=attrs_before, timeout=timeout, native_uuid=native_uuid, - credential=credential, + token_provider=token_provider, **kwargs, ) return conn diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi index e419b19bb..05aeec499 100644 --- a/mssql_python/mssql_python.pyi +++ b/mssql_python/mssql_python.pyi @@ -248,7 +248,7 @@ class Connection: attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> None: ... @@ -292,7 +292,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - credential: Optional[object] = None, + token_provider: Optional[object] = None, **kwargs: Any, ) -> Connection: ... diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index fe6937723..4dd13569b 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -524,7 +524,7 @@ def test_bulkcopy_path_preserves_user_assigned_msi_client_id(self): mock_conn.connection_str = "Server=tcp:test.database.windows.net;Database=testdb;" mock_conn._auth_type = "msi" mock_conn._credential_kwargs = {"client_id": client_id} - mock_conn._custom_credential = None + mock_conn._token_provider = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) @@ -975,7 +975,7 @@ def test_token_output_correct_on_cache_miss_and_hit(self): assert "default" in _credential_cache -# ── Custom credential= parameter tests ── +# ── Custom token_provider= parameter tests ── class TestAcquireTokenFromCredential: @@ -1018,18 +1018,18 @@ def test_credential_raises_exception(self): class TestCustomCredentialConnect: - """Tests for the credential= parameter on connect().""" + """Tests for the token_provider= parameter on connect().""" 
@patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_happy_path(self, mock_ddbc_conn): - """credential= acquires token and sets attrs_before.""" + """token_provider= acquires token and sets attrs_before.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) from mssql_python import connect - conn = connect("Server=test;Database=testdb", credential=mock_cred) - assert conn._custom_credential is mock_cred + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert conn._token_provider is mock_cred assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before # Existing auth_type should be None (no Authentication= in conn str) assert conn._auth_type is None @@ -1037,7 +1037,7 @@ def test_credential_happy_path(self, mock_ddbc_conn): @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_plus_authentication_raises_valueerror(self, mock_ddbc_conn): - """credential= + Authentication= raises ValueError.""" + """token_provider= + Authentication= raises ValueError.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) @@ -1046,12 +1046,12 @@ def test_credential_plus_authentication_raises_valueerror(self, mock_ddbc_conn): with pytest.raises(ValueError, match="Cannot specify both"): connect( "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault", - credential=mock_cred, + token_provider=mock_cred, ) @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_plus_authentication_via_kwargs_raises_valueerror(self, mock_ddbc_conn): - """credential= + Authentication via kwargs raises ValueError.""" + """token_provider= + Authentication via kwargs raises ValueError.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) @@ -1060,7 
+1060,7 @@ def test_credential_plus_authentication_via_kwargs_raises_valueerror(self, mock_ with pytest.raises(ValueError, match="Cannot specify both"): connect( "Server=test;Database=testdb", - credential=mock_cred, + token_provider=mock_cred, Authentication="ActiveDirectoryDefault", ) @@ -1070,23 +1070,23 @@ def test_credential_without_get_token_raises_typeerror(self, mock_ddbc_conn): mock_ddbc_conn.return_value = MagicMock() from mssql_python import connect - with pytest.raises(TypeError, match="credential must have a .get_token"): - connect("Server=test;Database=testdb", credential="not_a_credential") + with pytest.raises(TypeError, match="token_provider must have a .get_token"): + connect("Server=test;Database=testdb", token_provider="not_a_credential") @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_none_uses_existing_flow(self, mock_ddbc_conn): - """credential=None (default) uses existing auth flow, no change.""" + """token_provider=None (default) uses existing auth flow, no change.""" mock_ddbc_conn.return_value = MagicMock() from mssql_python import connect conn = connect("Server=test;Database=testdb;Authentication=ActiveDirectoryDefault") - assert conn._custom_credential is None + assert conn._token_provider is None assert conn._auth_type == "default" conn.close() @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): - """credential= works alongside non-auth attrs_before.""" + """token_provider= works alongside non-auth attrs_before.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) @@ -1095,7 +1095,7 @@ def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): login_timeout_attr = 113 # SQL_ATTR_LOGIN_TIMEOUT conn = connect( "Server=test;Database=testdb", - credential=mock_cred, + token_provider=mock_cred, attrs_before={login_timeout_attr: 30}, ) assert 
conn._attrs_before[login_timeout_attr] == 30 @@ -1111,7 +1111,7 @@ def test_credential_get_token_failure_raises_runtime_error(self, mock_ddbc_conn) from mssql_python import connect with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): - connect("Server=test;Database=testdb", credential=mock_cred) + connect("Server=test;Database=testdb", token_provider=mock_cred) @patch("mssql_python.connection.ddbc_bindings.Connection") def test_credential_with_non_callable_get_token_raises_typeerror(self, mock_ddbc_conn): @@ -1122,8 +1122,8 @@ def test_credential_with_non_callable_get_token_raises_typeerror(self, mock_ddbc class BadCredential: get_token = "not_a_method" - with pytest.raises(TypeError, match="credential must have a .get_token"): - connect("Server=test;Database=testdb", credential=BadCredential()) + with pytest.raises(TypeError, match="token_provider must have a .get_token"): + connect("Server=test;Database=testdb", token_provider=BadCredential()) @patch("mssql_python.connection.ddbc_bindings.Connection") def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): @@ -1133,9 +1133,9 @@ def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) from mssql_python import connect - conn1 = connect("Server=test1;Database=db1", credential=mock_cred) - conn2 = connect("Server=test2;Database=db2", credential=mock_cred) - assert conn1._custom_credential is conn2._custom_credential + conn1 = connect("Server=test1;Database=db1", token_provider=mock_cred) + conn2 = connect("Server=test2;Database=db2", token_provider=mock_cred) + assert conn1._token_provider is conn2._token_provider assert mock_cred.get_token.call_count == 2 conn1.close() conn2.close() diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py index 404faca91..4863c5e0e 100644 --- a/tests/test_020_bulkcopy_auth_cleanup.py +++ 
b/tests/test_020_bulkcopy_auth_cleanup.py @@ -22,7 +22,7 @@ def _make_cursor(connection_str, auth_type): mock_conn = MagicMock() mock_conn.connection_str = connection_str mock_conn._auth_type = auth_type - mock_conn._custom_credential = None + mock_conn._token_provider = None mock_conn._is_connected = True cursor = Cursor.__new__(Cursor) From 84497d7495a38cf2197309bca276f6119adab09f Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Thu, 25 Jun 2026 14:39:56 +0530 Subject: [PATCH 05/10] Add expiry capture and DB-API exceptions for token_provider; cover bulk-copy token branch C2: capture token expires_on from custom credential and store on connection. C3: raise DB-API InterfaceError/OperationalError instead of ValueError/TypeError for token_provider misuse and acquisition failures. Add unit tests covering the cursor bulk-copy token_provider branch (success, get_token failure, invalid token). --- mssql_python/auth.py | 99 +++++++++++++---- mssql_python/connection.py | 56 ++++++++-- mssql_python/cursor.py | 11 +- tests/test_008_auth.py | 135 +++++++++++++++++------- tests/test_020_bulkcopy_auth_cleanup.py | 101 ++++++++++++++++++ 5 files changed, 330 insertions(+), 72 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 365078339..0b15d198a 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -8,18 +8,19 @@ import platform import struct import threading +import time from typing import Tuple, Dict, Optional from mssql_python.logging import logger from mssql_python.constants import ( AuthType, - ConstantsDDBC, _AuthInternal, _KEY_AUTHENTICATION, _KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION, ) +from mssql_python.exceptions import InterfaceError, OperationalError # Module-level credential instance cache. # Reusing credential objects allows the Azure Identity SDK's built-in @@ -34,6 +35,12 @@ # Canonical keys to strip when handing an Entra-token connection to ODBC. 
_SENSITIVE_KEYS = frozenset({_KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION, _KEY_AUTHENTICATION}) +# Azure SQL Database OAuth scope for the Azure **commercial** cloud. Shared by +# the built-in AADAuth path and the custom token_provider path. Sovereign +# clouds (Azure US Gov, Azure China, Azure Germany) are out of scope — a token +# for a different audience is rejected by SQL Server at login. +_DATABASE_SCOPE = "https://database.windows.net/.default" + # Map Authentication connection-string values to internal short names. _AUTH_TYPE_MAP: Dict[str, str] = { AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE, @@ -147,7 +154,7 @@ def _acquire_token( auth_type, ) credential = _credential_cache[cache_key] - raw_token = credential.get_token("https://database.windows.net/.default").token + raw_token = credential.get_token(_DATABASE_SCOPE).token logger.info( "get_token: Azure AD token acquired successfully - token_length=%d chars", len(raw_token), @@ -439,55 +446,100 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]: return _AUTH_TYPE_MAP.get(auth_value) -def _get_token_from_credential(credential: object) -> str: - """Internal: call credential.get_token() and return the raw JWT string. +def _get_token_from_credential(credential: object) -> Tuple[str, Optional[int]]: + """Internal: call credential.get_token() and return ``(raw_jwt, expires_on)``. Centralises the token-acquisition + error-wrapping logic that both :func:`acquire_token_from_credential` and :func:`acquire_raw_token_from_credential` need. + ``expires_on`` is the POSIX timestamp (seconds) at which the token + expires, taken from the credential's ``AccessToken`` result when present + (it is ``None`` if the provider does not supply one). It is captured so + callers can log it and reason about token lifetime; the access token + itself is a *pre-connect* ODBC attribute and cannot be refreshed on a + live connection (see the module docs on token lifecycle). 
+ + Note: + The scope is hard-coded to the Azure **commercial** cloud + (``https://database.windows.net/.default``). Sovereign clouds + (Azure US Government, Azure China, Azure Germany) are **out of + scope** for the ``token_provider`` path. + Raises: - RuntimeError: If token acquisition fails. + InterfaceError: If the provider returns no valid ``.token`` string. + OperationalError: If the underlying ``get_token()`` call fails. """ + start_time = time.perf_counter() try: - raw_token = credential.get_token("https://database.windows.net/.default").token - logger.info( - "_get_token_from_credential: Token acquired from %s - length=%d chars", - type(credential).__name__, - len(raw_token), - ) - return raw_token + token_result = credential.get_token(_DATABASE_SCOPE) except Exception as e: logger.error( - "_get_token_from_credential: Failed - credential=%s, error=%s", + "_get_token_from_credential: get_token() failed - credential=%s, error=%s", type(credential).__name__, str(e), ) - raise RuntimeError( - f"Failed to acquire token from credential " f"({type(credential).__name__}): {e}" + # Preserve the original credential exception (e.g. azure-identity + # ClientAuthenticationError) as __cause__ for programmatic handling. + raise OperationalError( + driver_error=(f"Failed to acquire token from credential ({type(credential).__name__})"), + ddbc_error=str(e), ) from e + raw_token = getattr(token_result, "token", None) + if not isinstance(raw_token, str) or not raw_token: + raise InterfaceError( + driver_error=( + "token_provider.get_token() must return an object with a non-empty " + "string '.token' attribute." 
+ ), + ddbc_error=f"got .token of type {type(raw_token).__name__}", + ) + + expires_on = getattr(token_result, "expires_on", None) + elapsed_ms = (time.perf_counter() - start_time) * 1000 + logger.info( + "_get_token_from_credential: Token acquired from %s - length=%d chars, " + "expires_on=%s, duration_ms=%.2f", + type(credential).__name__, + len(raw_token), + expires_on, + elapsed_ms, + ) + return raw_token, expires_on -def acquire_token_from_credential(credential: object) -> bytes: + +def acquire_token_from_credential(credential: object) -> Tuple[bytes, Optional[int]]: """Acquire an ODBC token struct from a user-supplied credential object. The credential must follow the Azure ``TokenCredential`` protocol — i.e. have a ``.get_token(scope)`` method returning an object with a ``.token`` attribute (a raw JWT string). + .. note:: + The scope is fixed to the Azure **commercial** cloud + (``https://database.windows.net/.default``). Sovereign clouds (Azure + US Government, Azure China, Azure Germany) are **out of scope** — for + those, supply a pre-acquired token via + ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. + Args: credential: Any object with a ``.get_token(scope)`` method. Returns: - bytes: ODBC-compatible token struct for ``SQL_COPT_SS_ACCESS_TOKEN``. + Tuple[bytes, Optional[int]]: The ODBC token struct for + ``SQL_COPT_SS_ACCESS_TOKEN`` and the token's ``expires_on`` POSIX + timestamp (``None`` if the provider does not supply one). Raises: - RuntimeError: If token acquisition fails. + InterfaceError: If the provider returns no valid ``.token`` string. + OperationalError: If the underlying ``get_token()`` call fails. 
""" - return AADAuth.get_token_struct(_get_token_from_credential(credential)) + raw_token, expires_on = _get_token_from_credential(credential) + return AADAuth.get_token_struct(raw_token), expires_on -def acquire_raw_token_from_credential(credential: object) -> str: +def acquire_raw_token_from_credential(credential: object) -> Tuple[str, Optional[int]]: """Acquire a raw JWT string from a user-supplied credential object. Used by bulk copy, which needs the raw JWT rather than the ODBC struct. @@ -496,9 +548,12 @@ def acquire_raw_token_from_credential(credential: object) -> str: credential: Any object with a ``.get_token(scope)`` method. Returns: - str: Raw JWT token string. + Tuple[str, Optional[int]]: The raw JWT token string and the token's + ``expires_on`` POSIX timestamp (``None`` if the provider does not + supply one). Raises: - RuntimeError: If token acquisition fails. + InterfaceError: If the provider returns no valid ``.token`` string. + OperationalError: If the underlying ``get_token()`` call fails. """ return _get_token_from_credential(credential) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index ba4b5be03..ba40eca9f 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -139,10 +139,10 @@ def _validate_utf16_wchar_compatibility( # Generate context-appropriate error messages if "ctype" in context: - driver_error = f"SQL_WCHAR ctype only supports UTF-16 encodings" + driver_error = "SQL_WCHAR ctype only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR ctype" else: - driver_error = f"SQL_WCHAR only supports UTF-16 encodings" + driver_error = "SQL_WCHAR only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR" raise ProgrammingError( @@ -271,6 +271,21 @@ def __init__( native_uuid (bool, optional): Controls whether UNIQUEIDENTIFIER columns return uuid.UUID objects (True) or str (False) for cursors created from this connection. None (default) defers to the module-level ``mssql_python.native_uuid`` setting (True). 
+ token_provider (object, optional): Advanced token provider for Microsoft Entra ID + authentication. Must expose a callable ``.get_token(scope)`` method that returns + an object with a ``.token`` attribute. + + This parameter is mutually exclusive with ``Authentication=`` in the connection + string and raises ``InterfaceError`` at connect time when both are provided. + + .. note:: + The token scope is fixed to the Azure **commercial** cloud + (``https://database.windows.net/.default``). Sovereign clouds + (e.g. Azure US Government, Azure China, Azure Germany) are + **out of scope** for this parameter — a token acquired for a + different audience will be rejected by SQL Server at login. + For sovereign clouds, acquire the token yourself and pass it + via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. **kwargs: Additional key/value pairs for the connection string. Returns: @@ -343,26 +358,47 @@ def __init__( # User-supplied token provider for custom Entra ID authentication. # Stored so bulk copy can call .get_token() for a fresh JWT later. self._token_provider = None + # POSIX timestamp (seconds) at which the current access token expires, + # captured from the credential's AccessToken result. None when unknown. + # The token is a pre-connect ODBC attribute and cannot be refreshed on + # a live connection — this is exposed for diagnostics/logging only. + self._token_expires_on: Optional[int] = None # Custom token_provider= parameter — takes priority, mutually exclusive # with Authentication= in the connection string. if token_provider is not None: if _KEY_AUTHENTICATION in parsed_params: - raise ValueError( - "Cannot specify both 'token_provider' parameter and " - "'Authentication' in the connection string. " - "Use one or the other." + raise InterfaceError( + driver_error=( + "Cannot specify both 'token_provider' parameter and " + "'Authentication' in the connection string. " + "Use one or the other." 
+ ), + ddbc_error="", + ) + if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: + raise InterfaceError( + driver_error=( + "Cannot specify both 'token_provider' parameter and " + "attrs_before[SQL_COPT_SS_ACCESS_TOKEN]. " + "Use one token source." + ), + ddbc_error="", ) if not callable(getattr(token_provider, "get_token", None)): - raise TypeError( - f"token_provider must have a .get_token() method. " - f"Got {type(token_provider).__name__}." + raise InterfaceError( + driver_error=( + f"token_provider must have a .get_token() method. " + f"Got {type(token_provider).__name__}." + ), + ddbc_error="", ) from mssql_python.auth import acquire_token_from_credential - token = acquire_token_from_credential(token_provider) + token, token_expires_on = acquire_token_from_credential(token_provider) self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token self._token_provider = token_provider + self._token_expires_on = token_expires_on # Strip sensitive params (UID/PWD/Trusted_Connection) since # access-token auth is used — same as the Authentication= path. 
sanitized = remove_sensitive_params(parsed_params) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 85239ccec..b9e9282c1 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2947,10 +2947,13 @@ def bulkcopy( from mssql_python.auth import acquire_raw_token_from_credential try: - raw_token = acquire_raw_token_from_credential(self.connection._token_provider) - except RuntimeError as e: - raise RuntimeError( - f"Bulk copy failed: unable to acquire token " f"from custom credential: {e}" + raw_token, _ = acquire_raw_token_from_credential(self.connection._token_provider) + except (OperationalError, InterfaceError) as e: + raise OperationalError( + driver_error=( + "Bulk copy failed: unable to acquire token from custom credential" + ), + ddbc_error=str(e), ) from e pycore_context["access_token"] = raw_token for key in ("authentication", "user_name", "password"): diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 8413d4148..274ce4efd 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -8,6 +8,7 @@ import platform import sys import threading +from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch, MagicMock from mssql_python.auth import ( AADAuth, @@ -18,11 +19,11 @@ get_auth_token, extract_auth_type, _credential_cache, - _credential_cache_lock, acquire_token_from_credential, acquire_raw_token_from_credential, ) from mssql_python.constants import AuthType, ConstantsDDBC +from mssql_python.exceptions import InterfaceError, OperationalError import secrets SAMPLE_TOKEN = secrets.token_hex(44) @@ -1042,19 +1043,34 @@ class TestAcquireTokenFromCredential: """Tests for the acquire_token_from_credential helper.""" def test_happy_path(self): - """acquire_token_from_credential returns a valid token struct.""" + """acquire_token_from_credential returns a token struct and expiry.""" mock_cred = MagicMock() - mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) - result = 
acquire_token_from_credential(mock_cred) - assert isinstance(result, bytes) - assert len(result) > 4 + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + token_struct, expires_on = acquire_token_from_credential(mock_cred) + assert isinstance(token_struct, bytes) + assert len(token_struct) > 4 + assert expires_on == 1893456000 mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") def test_credential_raises_exception(self): - """acquire_token_from_credential wraps credential errors in RuntimeError.""" + """acquire_token_from_credential wraps credential errors in OperationalError.""" mock_cred = MagicMock() mock_cred.get_token.side_effect = Exception("auth failed") - with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + with pytest.raises(OperationalError, match="Failed to acquire token from credential"): + acquire_token_from_credential(mock_cred) + + def test_missing_token_attribute_raises_interface_error(self): + """Token provider must return an object exposing a non-empty string .token.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = object() + with pytest.raises(InterfaceError, match="non-empty"): + acquire_token_from_credential(mock_cred) + + def test_non_string_token_raises_interface_error(self): + """Token provider must return a .token value of type str.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=123) + with pytest.raises(InterfaceError, match="non-empty"): acquire_token_from_credential(mock_cred) @@ -1062,79 +1078,109 @@ class TestAcquireRawTokenFromCredential: """Tests for the acquire_raw_token_from_credential helper.""" def test_happy_path(self): - """acquire_raw_token_from_credential returns the raw JWT string.""" + """acquire_raw_token_from_credential returns the raw JWT string and expiry.""" mock_cred = MagicMock() - mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) - result = 
acquire_raw_token_from_credential(mock_cred) - assert result == SAMPLE_TOKEN + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + raw_token, expires_on = acquire_raw_token_from_credential(mock_cred) + assert raw_token == SAMPLE_TOKEN + assert expires_on == 1893456000 mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") def test_credential_raises_exception(self): - """acquire_raw_token_from_credential wraps credential errors in RuntimeError.""" + """acquire_raw_token_from_credential wraps credential errors in OperationalError.""" mock_cred = MagicMock() mock_cred.get_token.side_effect = Exception("auth failed") - with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + with pytest.raises(OperationalError, match="Failed to acquire token from credential"): + acquire_raw_token_from_credential(mock_cred) + + def test_empty_string_token_raises_interface_error(self): + """Empty token values are rejected as invalid provider output.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token="") + with pytest.raises(InterfaceError, match="non-empty"): acquire_raw_token_from_credential(mock_cred) -class TestCustomCredentialConnect: +class TestCustomTokenProviderConnect: """Tests for the token_provider= parameter on connect().""" @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_credential_happy_path(self, mock_ddbc_conn): + def test_token_provider_happy_path(self, mock_ddbc_conn): """token_provider= acquires token and sets attrs_before.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() - mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) from mssql_python import connect conn = connect("Server=test;Database=testdb", token_provider=mock_cred) assert conn._token_provider is mock_cred + assert conn._token_expires_on == 
1893456000 assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before # Existing auth_type should be None (no Authentication= in conn str) assert conn._auth_type is None conn.close() @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_credential_plus_authentication_raises_valueerror(self, mock_ddbc_conn): - """token_provider= + Authentication= raises ValueError.""" + def test_token_provider_plus_authentication_raises_valueerror(self, mock_ddbc_conn): + """token_provider= + Authentication= raises InterfaceError.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) from mssql_python import connect - with pytest.raises(ValueError, match="Cannot specify both"): + with pytest.raises(InterfaceError, match="Cannot specify both"): connect( "Server=test;Database=testdb;Authentication=ActiveDirectoryDefault", token_provider=mock_cred, ) + mock_cred.get_token.assert_not_called() + mock_ddbc_conn.assert_not_called() @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_credential_plus_authentication_via_kwargs_raises_valueerror(self, mock_ddbc_conn): - """token_provider= + Authentication via kwargs raises ValueError.""" + def test_token_provider_plus_authentication_via_kwargs_raises_valueerror(self, mock_ddbc_conn): + """token_provider= + Authentication via kwargs raises InterfaceError.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) from mssql_python import connect - with pytest.raises(ValueError, match="Cannot specify both"): + with pytest.raises(InterfaceError, match="Cannot specify both"): connect( "Server=test;Database=testdb", token_provider=mock_cred, Authentication="ActiveDirectoryDefault", ) + mock_cred.get_token.assert_not_called() + mock_ddbc_conn.assert_not_called() @patch("mssql_python.connection.ddbc_bindings.Connection") - def 
test_credential_without_get_token_raises_typeerror(self, mock_ddbc_conn): - """Passing an object without .get_token() raises TypeError.""" + def test_token_provider_plus_attrs_before_access_token_raises_valueerror(self, mock_ddbc_conn): + """token_provider= + manual attrs_before token is ambiguous and rejected.""" mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) from mssql_python import connect - with pytest.raises(TypeError, match="token_provider must have a .get_token"): + with pytest.raises(InterfaceError, match="SQL_COPT_SS_ACCESS_TOKEN"): + connect( + "Server=test;Database=testdb", + token_provider=mock_cred, + attrs_before={ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: b"existing_token"}, + ) + mock_cred.get_token.assert_not_called() + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_without_get_token_raises_typeerror(self, mock_ddbc_conn): + """Passing an object without .get_token() raises InterfaceError.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + with pytest.raises(InterfaceError, match="token_provider must have a .get_token"): connect("Server=test;Database=testdb", token_provider="not_a_credential") @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_credential_none_uses_existing_flow(self, mock_ddbc_conn): + def test_token_provider_none_uses_existing_flow(self, mock_ddbc_conn): """token_provider=None (default) uses existing auth flow, no change.""" mock_ddbc_conn.return_value = MagicMock() from mssql_python import connect @@ -1145,7 +1191,7 @@ def test_credential_none_uses_existing_flow(self, mock_ddbc_conn): conn.close() @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): + def test_token_provider_with_non_auth_attrs_before(self, mock_ddbc_conn): 
"""token_provider= works alongside non-auth attrs_before.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() @@ -1163,31 +1209,31 @@ def test_credential_with_non_auth_attrs_before(self, mock_ddbc_conn): conn.close() @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_credential_get_token_failure_raises_runtime_error(self, mock_ddbc_conn): - """If credential.get_token() fails, connect() raises RuntimeError.""" + def test_token_provider_get_token_failure_raises_runtime_error(self, mock_ddbc_conn): + """If token_provider.get_token() fails, connect() raises OperationalError.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.side_effect = Exception("token acquisition failed") from mssql_python import connect - with pytest.raises(RuntimeError, match="Failed to acquire token from credential"): + with pytest.raises(OperationalError, match="Failed to acquire token from credential"): connect("Server=test;Database=testdb", token_provider=mock_cred) @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_credential_with_non_callable_get_token_raises_typeerror(self, mock_ddbc_conn): - """Object with .get_token as a non-callable attribute raises TypeError.""" + def test_token_provider_with_non_callable_get_token_raises_typeerror(self, mock_ddbc_conn): + """Object with .get_token as a non-callable attribute raises InterfaceError.""" mock_ddbc_conn.return_value = MagicMock() from mssql_python import connect class BadCredential: get_token = "not_a_method" - with pytest.raises(TypeError, match="token_provider must have a .get_token"): + with pytest.raises(InterfaceError, match="token_provider must have a .get_token"): connect("Server=test;Database=testdb", token_provider=BadCredential()) @patch("mssql_python.connection.ddbc_bindings.Connection") - def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): - """Two connections can share the same credential object safely.""" + def 
test_multiple_connections_share_same_token_provider(self, mock_ddbc_conn): + """Two connections can share the same token provider object safely.""" mock_ddbc_conn.return_value = MagicMock() mock_cred = MagicMock() mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) @@ -1200,6 +1246,23 @@ def test_multiple_connections_share_same_credential(self, mock_ddbc_conn): conn1.close() conn2.close() + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_concurrent_connections_with_same_token_provider(self, mock_ddbc_conn): + """Concurrent connect() calls with one token provider should succeed.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + def _open_and_close(i): + conn = connect(f"Server=test{i};Database=testdb", token_provider=mock_cred) + conn.close() + + with ThreadPoolExecutor(max_workers=8) as executor: + list(executor.map(_open_and_close, range(20))) + + assert mock_cred.get_token.call_count == 20 + class TestParseTenantId: def test_guid_tenant(self): diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py index 4863c5e0e..7543066a7 100644 --- a/tests/test_020_bulkcopy_auth_cleanup.py +++ b/tests/test_020_bulkcopy_auth_cleanup.py @@ -12,6 +12,10 @@ import secrets from unittest.mock import MagicMock, patch +import pytest + +from mssql_python.exceptions import OperationalError + SAMPLE_TOKEN = secrets.token_hex(44) @@ -109,3 +113,100 @@ def capture_context(ctx, **kwargs): assert "access_token" not in captured_context assert captured_context.get("user_name") == "sa" assert captured_context.get("password") == "mypwd" + + +class TestBulkcopyTokenProvider: + """Verify cursor.bulkcopy acquires a token from a custom token_provider.""" + + @patch("mssql_python.cursor.logger") + def test_token_provider_replaces_auth_fields(self, mock_logger): + """token_provider present ⇒ fresh 
token injected, stale auth keys removed.""" + mock_logger.is_debug_enabled = False + + # Custom credential whose get_token returns an AccessToken-like object. + credential = MagicMock() + credential.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + cursor = _make_cursor( + "Server=tcp:test.database.windows.net;Database=testdb;" + "Authentication=ActiveDirectoryDefault;UID=user@test.com;PWD=secret", + "activedirectorydefault", + ) + # token_provider takes precedence over _auth_type. + cursor._connection._token_provider = credential + + captured_context = {} + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + + def capture_context(ctx, **kwargs): + captured_context.update(ctx) + return mock_pycore_conn + + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = capture_context + + with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}): + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + # The credential was consulted for a fresh token. 
+ credential.get_token.assert_called_once() + assert captured_context.get("access_token") == SAMPLE_TOKEN + assert "authentication" not in captured_context + assert "user_name" not in captured_context + assert "password" not in captured_context + + @patch("mssql_python.cursor.logger") + def test_token_provider_get_token_failure_rewrapped(self, mock_logger): + """credential.get_token raising ⇒ bulkcopy raises OperationalError.""" + mock_logger.is_debug_enabled = False + + credential = MagicMock() + credential.get_token.side_effect = RuntimeError("network down") + + cursor = _make_cursor( + "Server=tcp:test.database.windows.net;Database=testdb;" + "Authentication=ActiveDirectoryDefault", + "activedirectorydefault", + ) + cursor._connection._token_provider = credential + + mock_pycore_module = MagicMock() + + with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}): + with pytest.raises(OperationalError) as exc_info: + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + assert "unable to acquire token from custom credential" in str(exc_info.value) + + @patch("mssql_python.cursor.logger") + def test_token_provider_invalid_token_rewrapped(self, mock_logger): + """credential returning a non-string token ⇒ bulkcopy raises OperationalError.""" + mock_logger.is_debug_enabled = False + + # .token is not a non-empty string ⇒ _get_token_from_credential raises InterfaceError, + # which cursor.bulkcopy catches and re-wraps as OperationalError. 
+ credential = MagicMock() + credential.get_token.return_value = MagicMock(token="", expires_on=None) + + cursor = _make_cursor( + "Server=tcp:test.database.windows.net;Database=testdb;" + "Authentication=ActiveDirectoryDefault", + "activedirectorydefault", + ) + cursor._connection._token_provider = credential + + mock_pycore_module = MagicMock() + + with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}): + with pytest.raises(OperationalError) as exc_info: + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + assert "unable to acquire token from custom credential" in str(exc_info.value) From e41058a79987e76b40d5ed9f2162542959c695f6 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 26 Jun 2026 11:28:35 +0530 Subject: [PATCH 06/10] Add token_provider validation, dropped-credential warning, and Protocol typing; fix docstring error type - auth.py: type credential params as TokenProvider Protocol; hard-code commercial-cloud scope - connection.py: warn on ignored UID/PWD/Trusted_Connection when token_provider set; validate get_token arity; document token lifecycle limitations - db_connection.py: note sovereign clouds out of scope - test_008_auth.py: cover arity validation and dropped-credential warning --- mssql_python/auth.py | 32 ++++++++++--- mssql_python/connection.py | 84 +++++++++++++++++++++++++++++------ mssql_python/db_connection.py | 12 ++++- tests/test_008_auth.py | 81 +++++++++++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 21 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 0b15d198a..772c53c35 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -9,7 +9,7 @@ import struct import threading import time -from typing import Tuple, Dict, Optional +from typing import Tuple, Dict, Optional, Any, Protocol, runtime_checkable from mssql_python.logging import logger from mssql_python.constants import ( @@ -41,6 +41,22 @@ # for a different audience is rejected by SQL Server at login. 
_DATABASE_SCOPE = "https://database.windows.net/.default" + +@runtime_checkable +class TokenProvider(Protocol): + """Structural type accepted by the ``token_provider=`` connect parameter. + + Any object exposing a ``get_token(scope, ...)`` method qualifies — notably + every ``azure-identity`` credential class (``DefaultAzureCredential``, + ``AzureCliCredential``, ``ManagedIdentityCredential``, etc.). The returned + object must expose a non-empty ``.token`` (str); an optional + ``.expires_on`` (int POSIX timestamp) is captured for diagnostics. + """ + + def get_token(self, *scopes: str, **kwargs: Any) -> Any: # pragma: no cover - protocol + ... + + # Map Authentication connection-string values to internal short names. _AUTH_TYPE_MAP: Dict[str, str] = { AuthType.INTERACTIVE.value: _AuthInternal.INTERACTIVE, @@ -446,7 +462,7 @@ def extract_auth_type(parsed_params: Dict[str, str]) -> Optional[str]: return _AUTH_TYPE_MAP.get(auth_value) -def _get_token_from_credential(credential: object) -> Tuple[str, Optional[int]]: +def _get_token_from_credential(credential: "TokenProvider") -> Tuple[str, Optional[int]]: """Internal: call credential.get_token() and return ``(raw_jwt, expires_on)``. Centralises the token-acquisition + error-wrapping logic that both @@ -464,7 +480,8 @@ def _get_token_from_credential(credential: object) -> Tuple[str, Optional[int]]: The scope is hard-coded to the Azure **commercial** cloud (``https://database.windows.net/.default``). Sovereign clouds (Azure US Government, Azure China, Azure Germany) are **out of - scope** for the ``token_provider`` path. + scope** for the ``token_provider`` path — for those, supply a + pre-acquired token via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]``. Raises: InterfaceError: If the provider returns no valid ``.token`` string. 
@@ -509,7 +526,7 @@ def _get_token_from_credential(credential: object) -> Tuple[str, Optional[int]]: return raw_token, expires_on -def acquire_token_from_credential(credential: object) -> Tuple[bytes, Optional[int]]: +def acquire_token_from_credential(credential: "TokenProvider") -> Tuple[bytes, Optional[int]]: """Acquire an ODBC token struct from a user-supplied credential object. The credential must follow the Azure ``TokenCredential`` protocol — i.e. @@ -539,11 +556,16 @@ def acquire_token_from_credential(credential: object) -> Tuple[bytes, Optional[i return AADAuth.get_token_struct(raw_token), expires_on -def acquire_raw_token_from_credential(credential: object) -> Tuple[str, Optional[int]]: +def acquire_raw_token_from_credential(credential: "TokenProvider") -> Tuple[str, Optional[int]]: """Acquire a raw JWT string from a user-supplied credential object. Used by bulk copy, which needs the raw JWT rather than the ODBC struct. + .. note:: + The scope is fixed to the Azure **commercial** cloud. Sovereign + clouds are **out of scope** — see + :func:`acquire_token_from_credential`. + Args: credential: Any object with a ``.get_token(scope)`` method. 
diff --git a/mssql_python/connection.py b/mssql_python/connection.py index ba40eca9f..6c1ea0350 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -14,6 +14,8 @@ import weakref import re import codecs +import inspect +import warnings from typing import Any, Dict, Optional, Union, List, Tuple, Callable, TYPE_CHECKING import threading @@ -53,11 +55,13 @@ _RESERVED_PARAMETERS, _KEY_AUTHENTICATION, _KEY_UID, - _AuthInternal, + _KEY_PWD, + _KEY_TRUSTED_CONNECTION, ) if TYPE_CHECKING: from mssql_python.row import Row + from mssql_python.auth import TokenProvider # Add SQL_WMETADATA constant for metadata decoding configuration SQL_WMETADATA: int = -99 # Special flag for column name decoding @@ -139,10 +143,10 @@ def _validate_utf16_wchar_compatibility( # Generate context-appropriate error messages if "ctype" in context: - driver_error = "SQL_WCHAR ctype only supports UTF-16 encodings" + driver_error = f"SQL_WCHAR ctype only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR ctype" else: - driver_error = "SQL_WCHAR only supports UTF-16 encodings" + driver_error = f"SQL_WCHAR only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR" raise ProgrammingError( @@ -251,7 +255,7 @@ def __init__( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - token_provider: Optional[object] = None, + token_provider: Optional["TokenProvider"] = None, **kwargs: Any, ) -> None: """ @@ -276,22 +280,39 @@ def __init__( an object with a ``.token`` attribute. This parameter is mutually exclusive with ``Authentication=`` in the connection - string and raises ``ValueError`` at connect time when both are provided. + string and with ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]``; supplying more than + one token source raises ``InterfaceError`` at connect time. 
+ + If ``UID``/``PWD``/``Trusted_Connection`` are also present in the connection + string they are ignored (access-token auth wins) and a warning is emitted. .. note:: The token scope is fixed to the Azure **commercial** cloud - (``https://database.windows.net/.default``). Sovereign clouds - (e.g. Azure US Government, Azure China, Azure Germany) are - **out of scope** for this parameter — a token acquired for a - different audience will be rejected by SQL Server at login. - For sovereign clouds, acquire the token yourself and pass it - via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. + (``https://database.windows.net/.default``). Sovereign clouds (Azure US + Government, Azure China, Azure Germany) are **out of scope** for this + parameter — a token acquired for a different audience is rejected by SQL + Server at login. For sovereign clouds, acquire the token yourself and pass + it via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. + + .. note:: + Token lifecycle limitations: the access token is a *pre-connect* ODBC + attribute, so it cannot be refreshed on a live connection. Tokens are + **not** re-acquired automatically when a pooled/native connection is reused + after expiry, and Continuous Access Evaluation (CAE) claims challenges are + not handled. These require native driver support and are tracked as + follow-up work. Interactive credentials (e.g. + ``InteractiveBrowserCredential``) block ``connect()`` until the user + completes sign-in; prefer non-interactive credentials in server contexts. **kwargs: Additional key/value pairs for the connection string. Returns: None Raises: + InterfaceError: If ``token_provider`` is misused (combined with another token + source, or lacking a valid ``.get_token`` method), or the credential returns + no valid token. + OperationalError: If acquiring a token from ``token_provider`` fails. ValueError: If the connection string is invalid or connection fails. 
This method sets up the initial state for the connection object, @@ -393,8 +414,43 @@ def __init__( ), ddbc_error="", ) - from mssql_python.auth import acquire_token_from_credential - + # Validate that get_token can accept the scope positional argument. + # Inspecting the signature catches obvious arity bugs (e.g. a + # zero-arg get_token) up-front with a clear message instead of an + # opaque TypeError surfacing mid-acquisition. C-implemented or + # otherwise un-inspectable callables are skipped and validated at + # call time. + from mssql_python.auth import acquire_token_from_credential, _DATABASE_SCOPE + + get_token = getattr(token_provider, "get_token") + try: + signature = inspect.signature(get_token) + except (ValueError, TypeError): + signature = None + if signature is not None: + try: + signature.bind(_DATABASE_SCOPE) + except TypeError as exc: + raise InterfaceError( + driver_error=( + "token_provider.get_token() must accept a scope " + "positional argument, e.g. get_token(scope)." + ), + ddbc_error=str(exc), + ) from exc + # access-token auth ignores UID/PWD/Trusted_Connection — warn so the + # user is not surprised that those credentials are silently dropped. + dropped = [ + key for key in (_KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION) if key in parsed_params + ] + if dropped: + warnings.warn( + "token_provider is set, so the following connection-string " + f"credential(s) are ignored: {', '.join(sorted(dropped))}. " + "Remove them to silence this warning.", + UserWarning, + stacklevel=2, + ) token, token_expires_on = acquire_token_from_credential(token_provider) self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token self._token_provider = token_provider @@ -413,7 +469,7 @@ def __init__( # Capture credential kwargs (e.g. user-assigned MSI client_id) # from the parsed dict *before* remove_sensitive_params strips UID. 
credential_kwargs: Optional[Dict[str, str]] = None - if auth_type == _AuthInternal.MSI: + if auth_type == "msi": uid = (parsed_params.get(_KEY_UID) or "").strip() if uid: credential_kwargs = {"client_id": uid} diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index 894440009..4992ffd34 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -4,10 +4,13 @@ This module provides a way to create a new connection object to interact with the database. """ -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Union, TYPE_CHECKING from mssql_python.connection import Connection +if TYPE_CHECKING: + from mssql_python.auth import TokenProvider + def connect( connection_str: str = "", @@ -15,7 +18,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - token_provider: Optional[object] = None, + token_provider: Optional["TokenProvider"] = None, **kwargs: Any, ) -> Connection: """ @@ -59,6 +62,11 @@ def connect( from azure.identity import AzureCliCredential conn = mssql_python.connect("Server=s;Database=d", token_provider=AzureCliCredential()) + + Note: the token scope is fixed to the Azure **commercial** cloud + (``https://database.windows.net/.default``). Sovereign clouds (Azure US + Government, Azure China, Azure Germany) are **out of scope** — acquire the token + yourself and pass it via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. Keyword Args: **kwargs: Additional key/value pairs for the connection string. 
Below attributes are not implemented in the internal driver: diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 274ce4efd..19b178527 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -8,6 +8,7 @@ import platform import sys import threading +import warnings from concurrent.futures import ThreadPoolExecutor from unittest.mock import patch, MagicMock from mssql_python.auth import ( @@ -1073,6 +1074,13 @@ def test_non_string_token_raises_interface_error(self): with pytest.raises(InterfaceError, match="non-empty"): acquire_token_from_credential(mock_cred) + def test_scope_is_commercial_cloud(self): + """The scope is hard-coded to the Azure commercial-cloud audience.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + acquire_token_from_credential(mock_cred) + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + class TestAcquireRawTokenFromCredential: """Tests for the acquire_raw_token_from_credential helper.""" @@ -1264,6 +1272,79 @@ def _open_and_close(i): assert mock_cred.get_token.call_count == 20 +class TestTokenProviderValidation: + """Tests for token_provider get_token arity validation and the dropped-credential warning.""" + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_scope_is_commercial_cloud(self, mock_ddbc_conn): + """connect() requests the fixed commercial-cloud scope from the credential.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_get_token_wrong_arity_raises_interface_error(self, mock_ddbc_conn): + """A get_token() that cannot accept a scope 
argument is rejected up-front.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class ZeroArgCredential: + def get_token(self): # missing scope parameter + return MagicMock(token=SAMPLE_TOKEN) + + with pytest.raises(InterfaceError, match="must accept a scope"): + connect("Server=test;Database=testdb", token_provider=ZeroArgCredential()) + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_get_token_with_scope_param_accepted(self, mock_ddbc_conn): + """A well-formed get_token(scope) passes arity validation.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class GoodCredential: + def get_token(self, scope): + return MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + conn = connect("Server=test;Database=testdb", token_provider=GoodCredential()) + assert conn._token_expires_on == 1893456000 + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_dropped_uid_pwd_emits_warning(self, mock_ddbc_conn): + """UID/PWD in the connection string trigger a warning when token_provider is set.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.warns(UserWarning, match="credential\\(s\\) are ignored"): + conn = connect( + "Server=test;Database=testdb;UID=user@test.com;PWD=secret", + token_provider=mock_cred, + ) + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_no_warning_without_dropped_credentials(self, mock_ddbc_conn): + """No 'ignored credentials' warning when the connection string has no UID/PWD.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with warnings.catch_warnings(record=True) as caught: + 
warnings.simplefilter("always") + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert not any("are ignored" in str(w.message) for w in caught) + conn.close() + + class TestParseTenantId: def test_guid_tenant(self): url = "https://login.microsoftonline.com/aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee/" From 2b796ae86cfd887844a0b40da9f3fe19d108fc65 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 26 Jun 2026 22:10:04 +0530 Subject: [PATCH 07/10] Disable pooling for access-token connections to prevent cross-principal collision The native connection pool keys on the sanitized connection string only, and the access token lives in attrs_before (applied once on a new physical connection, never re-applied on reuse). Two different principals sharing the same Server/Database collapsed into one pool bucket, so one caller could be handed another's authenticated connection (silent identity confusion). Fix: Connection.__init__ disables pooling whenever SQL_COPT_SS_ACCESS_TOKEN is present in attrs_before. One condition covers all access-token paths: raw attrs_before token, built-in Authentication=ActiveDirectory* (token-injecting), and token_provider=. Driver-native paths (e.g. ServicePrincipal) keep creds in the connection string and remain poolable. Adds regression tests in TestTokenProviderPooling. 
--- .coveragerc | 3 + CHANGELOG.md | 9 + mssql_python/auth.py | 48 ++ mssql_python/connection.py | 223 +++++--- tests/test_008_auth.py | 678 +++++++++++++++++++++++- tests/test_020_bulkcopy_auth_cleanup.py | 72 ++- 6 files changed, 948 insertions(+), 85 deletions(-) diff --git a/.coveragerc b/.coveragerc index 1182c6524..9b922851c 100644 --- a/.coveragerc +++ b/.coveragerc @@ -20,6 +20,9 @@ exclude_lines = # Don't complain if non-runnable code isn't run if __name__ == .__main__.: + + # Type-checking-only imports never execute at runtime + if TYPE_CHECKING: # Exclude all logging statements (zero overhead when disabled by design) logger\.debug diff --git a/CHANGELOG.md b/CHANGELOG.md index 4fa44f85b..7bdab9919 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Added - New feature: Support for macOS and Linux. - Documentation: Added API documentation in the Wiki. +- New `token_provider=` parameter on `connect()` / `Connection` for Microsoft + Entra ID authentication with a custom credential object. Accepts any object + exposing a `.get_token(scope)` method (e.g. any `azure-identity` credential + such as `DefaultAzureCredential`, `AzureCliCredential`, + `ManagedIdentityCredential`). Mutually exclusive with `Authentication=` in + the connection string and with `attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`. + Bulk copy re-acquires a fresh token from the provider on each operation. The + token scope is fixed to the Azure commercial cloud; sovereign clouds are out + of scope (supply a pre-acquired token via `attrs_before` instead). - Bulk copy now supports `Authentication=ActiveDirectoryServicePrincipal` via an `entra_id_token_factory` callback registered on the mssql-py-core connection. 
The callback is invoked by mssql-tds mid-handshake (FedAuth diff --git a/mssql_python/auth.py b/mssql_python/auth.py index 772c53c35..fbe1752bc 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -5,10 +5,12 @@ """ import hashlib +import inspect import platform import struct import threading import time +import warnings from typing import Tuple, Dict, Optional, Any, Protocol, runtime_checkable from mssql_python.logging import logger @@ -490,6 +492,21 @@ def _get_token_from_credential(credential: "TokenProvider") -> Tuple[str, Option start_time = time.perf_counter() try: token_result = credential.get_token(_DATABASE_SCOPE) + except TypeError as e: + # get_token() is called with exactly one positional scope argument, so + # a TypeError here almost always means its signature can't accept a + # scope (e.g. a zero-arg or keyword-only get_token). Surface that as a + # clear, actionable InterfaceError instead of an opaque failure. This + # is the call-time source of truth for arity — the connect() path only + # *warns* on a suspicious signature so it never blocks a credential + # whose signature is merely hard to introspect (partial/decorated). + raise InterfaceError( + driver_error=( + "token_provider.get_token() must accept a scope positional " + "argument, e.g. get_token(scope)." + ), + ddbc_error=str(e), + ) from e except Exception as e: logger.error( "_get_token_from_credential: get_token() failed - credential=%s, error=%s", @@ -503,6 +520,21 @@ def _get_token_from_credential(credential: "TokenProvider") -> Tuple[str, Option ddbc_error=str(e), ) from e + # azure.identity.aio (async) credentials return a coroutine from a + # synchronous get_token() call. Detect it and fail with an async-specific + # message rather than tripping over a missing .token attribute — and close + # the coroutine so it doesn't emit a "coroutine was never awaited" warning. 
+ if inspect.iscoroutine(token_result): + token_result.close() + raise InterfaceError( + driver_error=( + "token_provider.get_token() returned a coroutine, which indicates " + "an async credential (e.g. from azure.identity.aio). Use a " + "synchronous credential instead." + ), + ddbc_error=f"got coroutine from {type(credential).__name__}.get_token()", + ) + raw_token = getattr(token_result, "token", None) if not isinstance(raw_token, str) or not raw_token: raise InterfaceError( @@ -514,6 +546,22 @@ def _get_token_from_credential(credential: "TokenProvider") -> Tuple[str, Option ) expires_on = getattr(token_result, "expires_on", None) + # Warn (don't fail) if the credential handed back an already-expired token: + # the server enforces expiry and will reject the login, so surfacing it here + # points at the real cause instead of an opaque later failure. Only numeric + # POSIX timestamps are checked; bools are excluded to avoid false positives. + if ( + isinstance(expires_on, (int, float)) + and not isinstance(expires_on, bool) + and expires_on < time.time() + ): + warnings.warn( + f"token_provider returned a token that is already expired " + f"(expires_on={expires_on} is in the past). The server will likely " + f"reject the connection.", + UserWarning, + stacklevel=2, + ) elapsed_ms = (time.perf_counter() - start_time) * 1000 logger.info( "_get_token_from_credential: Token acquired from %s - length=%d chars, " diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 6c1ea0350..f8def99cb 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -57,6 +57,7 @@ _KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION, + _AuthInternal, ) if TYPE_CHECKING: @@ -294,15 +295,24 @@ def __init__( Server at login. For sovereign clouds, acquire the token yourself and pass it via ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]`` instead. + .. 
note:: + Connection pooling is automatically disabled for any access-token + connection (``token_provider=``, built-in ``Authentication=ActiveDirectory*``, + or a raw ``attrs_before[SQL_COPT_SS_ACCESS_TOKEN]``). The native pool keys + on the sanitized connection string only, so different principals sharing the + same server/database would otherwise collide in one pool bucket and could be + handed each other's authenticated connection. Disabling pooling keeps each + principal isolated. + .. note:: Token lifecycle limitations: the access token is a *pre-connect* ODBC - attribute, so it cannot be refreshed on a live connection. Tokens are - **not** re-acquired automatically when a pooled/native connection is reused - after expiry, and Continuous Access Evaluation (CAE) claims challenges are - not handled. These require native driver support and are tracked as - follow-up work. Interactive credentials (e.g. - ``InteractiveBrowserCredential``) block ``connect()`` until the user - completes sign-in; prefer non-interactive credentials in server contexts. + attribute, so it cannot be refreshed on a live connection. Long-lived + connections must be recycled by the application once the token nears expiry, + and Continuous Access Evaluation (CAE) claims challenges are not handled. + These require native driver support and are tracked as follow-up work. + Interactive credentials (e.g. ``InteractiveBrowserCredential``) block + ``connect()`` until the user completes sign-in; prefer non-interactive + credentials in server contexts. **kwargs: Additional key/value pairs for the connection string. Returns: @@ -337,7 +347,11 @@ def __init__( self.connection_str, parsed_params = self._construct_connection_string( connection_str, **kwargs ) - self._attrs_before = attrs_before or {} + # Shallow-copy so we never mutate the caller's dict (e.g. when the + # token_provider path injects SQL_COPT_SS_ACCESS_TOKEN). 
Mutating the + # caller's object would leak the access token into user state and break + # re-using the same attrs_before dict across multiple connections. + self._attrs_before = dict(attrs_before) if attrs_before else {} # Initialize encoding settings with defaults for Python 3 # Python 3 only has str (which is Unicode), so we use utf-16le by default @@ -378,7 +392,7 @@ def __init__( self._credential_kwargs: Optional[Dict[str, str]] = None # User-supplied token provider for custom Entra ID authentication. # Stored so bulk copy can call .get_token() for a fresh JWT later. - self._token_provider = None + self._token_provider: Optional["TokenProvider"] = None # POSIX timestamp (seconds) at which the current access token expires, # captured from the credential's AccessToken result. None when unknown. # The token is a pre-connect ODBC attribute and cannot be refreshed on @@ -388,77 +402,7 @@ def __init__( # Custom token_provider= parameter — takes priority, mutually exclusive # with Authentication= in the connection string. if token_provider is not None: - if _KEY_AUTHENTICATION in parsed_params: - raise InterfaceError( - driver_error=( - "Cannot specify both 'token_provider' parameter and " - "'Authentication' in the connection string. " - "Use one or the other." - ), - ddbc_error="", - ) - if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: - raise InterfaceError( - driver_error=( - "Cannot specify both 'token_provider' parameter and " - "attrs_before[SQL_COPT_SS_ACCESS_TOKEN]. " - "Use one token source." - ), - ddbc_error="", - ) - if not callable(getattr(token_provider, "get_token", None)): - raise InterfaceError( - driver_error=( - f"token_provider must have a .get_token() method. " - f"Got {type(token_provider).__name__}." - ), - ddbc_error="", - ) - # Validate that get_token can accept the scope positional argument. - # Inspecting the signature catches obvious arity bugs (e.g. 
a - # zero-arg get_token) up-front with a clear message instead of an - # opaque TypeError surfacing mid-acquisition. C-implemented or - # otherwise un-inspectable callables are skipped and validated at - # call time. - from mssql_python.auth import acquire_token_from_credential, _DATABASE_SCOPE - - get_token = getattr(token_provider, "get_token") - try: - signature = inspect.signature(get_token) - except (ValueError, TypeError): - signature = None - if signature is not None: - try: - signature.bind(_DATABASE_SCOPE) - except TypeError as exc: - raise InterfaceError( - driver_error=( - "token_provider.get_token() must accept a scope " - "positional argument, e.g. get_token(scope)." - ), - ddbc_error=str(exc), - ) from exc - # access-token auth ignores UID/PWD/Trusted_Connection — warn so the - # user is not surprised that those credentials are silently dropped. - dropped = [ - key for key in (_KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION) if key in parsed_params - ] - if dropped: - warnings.warn( - "token_provider is set, so the following connection-string " - f"credential(s) are ignored: {', '.join(sorted(dropped))}. " - "Remove them to silence this warning.", - UserWarning, - stacklevel=2, - ) - token, token_expires_on = acquire_token_from_credential(token_provider) - self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token - self._token_provider = token_provider - self._token_expires_on = token_expires_on - # Strip sensitive params (UID/PWD/Trusted_Connection) since - # access-token auth is used — same as the Authentication= path. - sanitized = remove_sensitive_params(parsed_params) - self.connection_str = _ConnectionStringBuilder(sanitized).build() + self._configure_token_provider(token_provider, parsed_params) # Handle Entra ID authentication if specified. # The parsed dict is used directly — no re-parsing of the connection string. @@ -469,7 +413,7 @@ def __init__( # Capture credential kwargs (e.g. 
user-assigned MSI client_id) # from the parsed dict *before* remove_sensitive_params strips UID. credential_kwargs: Optional[Dict[str, str]] = None - if auth_type == "msi": + if auth_type == _AuthInternal.MSI: uid = (parsed_params.get(_KEY_UID) or "").strip() if uid: credential_kwargs = {"client_id": uid} @@ -521,6 +465,27 @@ def __init__( if not PoolingManager.is_initialized(): PoolingManager.enable() self._pooling = PoolingManager.is_enabled() + + # Access-token connections must NOT be pooled. The native pool is keyed + # on the (sanitized) connection string only, and the access token lives + # in attrs_before — which is applied solely when a *new* physical + # connection is created and is never re-applied when a pooled connection + # is reused. With pooling on, two different principals that share the + # same Server/Database collapse into the same pool bucket, so one caller + # can be handed another caller's already-authenticated connection + # (silent identity confusion / privilege escalation). This affects every + # access-token path: a raw SQL_COPT_SS_ACCESS_TOKEN supplied directly in + # attrs_before, built-in Authentication=ActiveDirectory* auth, and the + # token_provider= credential — they all funnel the token through + # attrs_before. Disabling pooling for these connections keeps each + # principal isolated. The same-principal reuse case loses pooling, which + # is an acceptable, correct default. Refreshing the token on a live + # connection (so pooling could be re-enabled safely) needs native driver + # support and is tracked as follow-up work. + # See docs/DESIGN_TOKEN_PROVIDER_SUPPORT.md. 
+ if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: + self._pooling = False + try: self._conn = ddbc_bindings.Connection( self.connection_str, self._pooling, self._attrs_before @@ -545,6 +510,102 @@ def __init__( f"Unexpected error during connection registration: {type(e).__name__}: {e}" ) + def _configure_token_provider( + self, token_provider: "TokenProvider", parsed_params: Dict[str, str] + ) -> None: + """Validate a custom ``token_provider`` and apply its access token. + + Acquires a token from ``token_provider.get_token()`` and injects it as + the ``SQL_COPT_SS_ACCESS_TOKEN`` pre-connect attribute, then strips any + sensitive params from the connection string. Mutually exclusive with + ``Authentication=`` and a manual ``attrs_before`` access token. + + Raises: + InterfaceError: If ``token_provider`` is combined with another token + source, or lacks a ``get_token(scope)`` method. + OperationalError: If acquiring a token from ``token_provider`` fails. + """ + if _KEY_AUTHENTICATION in parsed_params: + raise InterfaceError( + driver_error=( + "Cannot specify both 'token_provider' parameter and " + "'Authentication' in the connection string. " + "Use one or the other." + ), + ddbc_error="", + ) + if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: + raise InterfaceError( + driver_error=( + "Cannot specify both 'token_provider' parameter and " + "attrs_before[SQL_COPT_SS_ACCESS_TOKEN]. " + "Use one token source." + ), + ddbc_error="", + ) + get_token = getattr(token_provider, "get_token", None) + if not callable(get_token): + raise InterfaceError( + driver_error=( + f"token_provider must have a .get_token() method. " + f"Got {type(token_provider).__name__}." + ), + ddbc_error="", + ) + # Validate that get_token can accept the scope positional argument. + # Inspecting the signature catches obvious arity bugs (e.g. a + # zero-arg get_token) up-front. 
But signature introspection is + # unreliable for wrapped/partial/callable-object credentials, so a + # suspicious signature only *warns* here — it never blocks the + # credential. The authoritative arity check happens at call time in + # _get_token_from_credential, which raises a clear InterfaceError if + # the scope genuinely cannot be passed. C-implemented or otherwise + # un-inspectable callables are skipped entirely. + from mssql_python.auth import acquire_token_from_credential, _DATABASE_SCOPE + + try: + signature = inspect.signature(get_token) + except (ValueError, TypeError): + signature = None + if signature is not None: + try: + signature.bind(_DATABASE_SCOPE) + except TypeError: + warnings.warn( + "token_provider.get_token() does not appear to accept a " + "positional scope argument (expected get_token(scope)). " + "If token acquisition fails, check the credential's " + "get_token() signature.", + UserWarning, + # 3 frames out: warnings.warn -> _configure_token_provider + # -> __init__ -> caller. Keeps the warning on user code. + stacklevel=3, + ) + # access-token auth ignores UID/PWD/Trusted_Connection — warn so the + # user is not surprised that those credentials are silently dropped. + dropped = [ + key for key in (_KEY_UID, _KEY_PWD, _KEY_TRUSTED_CONNECTION) if key in parsed_params + ] + if dropped: + warnings.warn( + "token_provider is set, so the following connection-string " + f"credential(s) are ignored: {', '.join(sorted(dropped))}. " + "Remove them to silence this warning.", + UserWarning, + # 3 frames out: warnings.warn -> _configure_token_provider -> + # __init__ -> caller (connect()/Connection()). Keeps the warning + # pointed at user code, not this internal helper. 
+ stacklevel=3, + ) + token, token_expires_on = acquire_token_from_credential(token_provider) + self._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] = token + self._token_provider = token_provider + self._token_expires_on = token_expires_on + # Strip sensitive params (UID/PWD/Trusted_Connection) since + # access-token auth is used — same as the Authentication= path. + sanitized = remove_sensitive_params(parsed_params) + self.connection_str = _ConnectionStringBuilder(sanitized).build() + def _construct_connection_string( self, connection_str: str = "", **kwargs: Any ) -> Tuple[str, Dict[str, str]]: diff --git a/tests/test_008_auth.py b/tests/test_008_auth.py index 19b178527..7edf7a646 100644 --- a/tests/test_008_auth.py +++ b/tests/test_008_auth.py @@ -5,6 +5,8 @@ """ import pytest +import collections +import inspect import platform import sys import threading @@ -22,6 +24,8 @@ _credential_cache, acquire_token_from_credential, acquire_raw_token_from_credential, + TokenProvider, + _DATABASE_SCOPE, ) from mssql_python.constants import AuthType, ConstantsDDBC from mssql_python.exceptions import InterfaceError, OperationalError @@ -1081,6 +1085,66 @@ def test_scope_is_commercial_cloud(self): acquire_token_from_credential(mock_cred) mock_cred.get_token.assert_called_once_with("https://database.windows.net/.default") + def test_missing_expires_on_returns_none(self): + """A token object without .expires_on yields expires_on=None (not an error).""" + + class MinimalToken: + token = SAMPLE_TOKEN # no expires_on attribute + + mock_cred = MagicMock() + mock_cred.get_token.return_value = MinimalToken() + token_struct, expires_on = acquire_token_from_credential(mock_cred) + assert isinstance(token_struct, bytes) + assert expires_on is None + + def test_bytes_token_raises_interface_error(self): + """A bytes .token (not str) is rejected just like other non-str values.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=b"not_a_str_token") 
+ with pytest.raises(InterfaceError, match="non-empty"): + acquire_token_from_credential(mock_cred) + + def test_whitespace_only_token_is_accepted(self): + """Documents current behavior: a non-empty whitespace token passes the + client-side check (validity is enforced server-side at login).""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=" ", expires_on=None) + token_struct, _ = acquire_token_from_credential(mock_cred) + assert isinstance(token_struct, bytes) + + def test_credential_exception_preserved_as_cause(self): + """The original credential error is chained as __cause__ for callers + that want to catch the underlying azure-identity exception.""" + + class ClientAuthenticationError(Exception): + """Stand-in for azure.core.exceptions.ClientAuthenticationError.""" + + original = ClientAuthenticationError("AADSTS700016") + mock_cred = MagicMock() + mock_cred.get_token.side_effect = original + with pytest.raises(OperationalError) as exc_info: + acquire_token_from_credential(mock_cred) + assert exc_info.value.__cause__ is original + + def test_get_token_returns_none_raises_interface_error(self): + """A credential whose get_token returns None is rejected clearly.""" + mock_cred = MagicMock() + mock_cred.get_token.return_value = None + with pytest.raises(InterfaceError, match="non-empty"): + acquire_token_from_credential(mock_cred) + + def test_realistic_length_jwt_round_trips(self): + """A realistic ~1.5 KB JWT is encoded into the ODBC token struct without + truncation (length prefix + UTF-16-LE body).""" + big_jwt = "e" + "A" * 1500 + ".sig" + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=big_jwt, expires_on=None) + token_struct, _ = acquire_token_from_credential(mock_cred) + # struct = 4-byte little-endian length prefix + UTF-16-LE token bytes. 
+ expected_body = big_jwt.encode("utf-16-le") + assert token_struct[:4] == len(expected_body).to_bytes(4, "little") + assert token_struct[4:] == expected_body + class TestAcquireRawTokenFromCredential: """Tests for the acquire_raw_token_from_credential helper.""" @@ -1297,8 +1361,10 @@ class ZeroArgCredential: def get_token(self): # missing scope parameter return MagicMock(token=SAMPLE_TOKEN) - with pytest.raises(InterfaceError, match="must accept a scope"): - connect("Server=test;Database=testdb", token_provider=ZeroArgCredential()) + # Up-front signature check warns; call-time validation raises the error. + with pytest.warns(UserWarning, match="does not appear to accept"): + with pytest.raises(InterfaceError, match="must accept a scope"): + connect("Server=test;Database=testdb", token_provider=ZeroArgCredential()) mock_ddbc_conn.assert_not_called() @patch("mssql_python.connection.ddbc_bindings.Connection") @@ -1315,6 +1381,20 @@ def get_token(self, scope): assert conn._token_expires_on == 1893456000 conn.close() + @patch("mssql_python.connection.inspect.signature", side_effect=ValueError("no signature")) + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_uninspectable_get_token_skips_validation(self, mock_ddbc_conn, _mock_sig): + """A get_token whose signature can't be introspected skips arity validation.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + from mssql_python import connect + + # inspect.signature raises -> validation is skipped and connect still succeeds. 
+ conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert conn._token_expires_on == 1893456000 + conn.close() + @patch("mssql_python.connection.ddbc_bindings.Connection") def test_dropped_uid_pwd_emits_warning(self, mock_ddbc_conn): """UID/PWD in the connection string trigger a warning when token_provider is set.""" @@ -1344,6 +1424,600 @@ def test_no_warning_without_dropped_credentials(self, mock_ddbc_conn): assert not any("are ignored" in str(w.message) for w in caught) conn.close() + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_real_azure_style_signature_accepted(self, mock_ddbc_conn): + """get_token(self, *scopes, **kwargs) — the real azure-identity shape — + passes arity validation.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class AzureStyleCredential: + def get_token(self, *scopes, **kwargs): + return MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + conn = connect("Server=test;Database=testdb", token_provider=AzureStyleCredential()) + assert conn._token_expires_on == 1893456000 + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_connection_string_sanitized_of_uid_pwd(self, mock_ddbc_conn): + """UID/PWD are stripped from connection_str when token_provider is used.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + conn = connect( + "Server=test;Database=testdb;UID=user@test.com;PWD=secret", + token_provider=mock_cred, + ) + assert "UID=" not in conn.connection_str + assert "PWD=" not in conn.connection_str + assert "secret" not in conn.connection_str + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_missing_expires_on_sets_none(self, mock_ddbc_conn): + """A credential whose token lacks 
.expires_on leaves _token_expires_on None.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class MinimalToken: + token = SAMPLE_TOKEN # no expires_on + + class MinimalCredential: + def get_token(self, scope): + return MinimalToken() + + conn = connect("Server=test;Database=testdb", token_provider=MinimalCredential()) + assert conn._token_expires_on is None + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_dropped_trusted_connection_emits_warning(self, mock_ddbc_conn): + """Trusted_Connection alone also triggers the dropped-credential warning.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + with pytest.warns(UserWarning, match="credential\\(s\\) are ignored"): + conn = connect( + "Server=test;Database=testdb;Trusted_Connection=yes", + token_provider=mock_cred, + ) + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_async_credential_coroutine_rejected(self, mock_ddbc_conn): + """An async credential returns a coroutine from a synchronous get_token() + call and is rejected with a clear, async-specific InterfaceError (no + un-awaited-coroutine warning, since the coroutine is closed).""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class AsyncCredential: + async def get_token(self, scope): # azure.identity.aio shape + return MagicMock(token=SAMPLE_TOKEN) + + cred = AsyncCredential() + with pytest.raises(InterfaceError, match="async credential"): + connect("Server=test;Database=testdb", token_provider=cred) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_suspicious_signature_warns_but_does_not_block(self, mock_ddbc_conn): + """If signature introspection wrongly reports that get_token can't take 
a + scope, connect() must only WARN and still succeed when the real call + works (guards against false rejections of partial/decorated credentials).""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class WorkingCredential: + def get_token(self, scope): # genuinely accepts a scope + return MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + cred = WorkingCredential() + # Force the up-front bind() check to misfire (report a zero-arg + # signature) even though the actual call accepts the scope fine. + zero_arg_sig = inspect.signature(lambda: None) + with patch("mssql_python.connection.inspect.signature", return_value=zero_arg_sig): + with pytest.warns(UserWarning, match="does not appear to accept"): + conn = connect("Server=test;Database=testdb", token_provider=cred) + # Not blocked: the connection succeeded and captured the token. + assert conn._token_expires_on == 1893456000 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_keyword_only_scope_rejected(self, mock_ddbc_conn): + """get_token(self, *, scope) can't take scope positionally and is rejected.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + + class KeywordOnlyCredential: + def get_token(self, *, scope): + return MagicMock(token=SAMPLE_TOKEN) + + # Up-front signature check warns; call-time validation raises the error. 
+ with pytest.warns(UserWarning, match="does not appear to accept"): + with pytest.raises(InterfaceError, match="must accept a scope"): + connect("Server=test;Database=testdb", token_provider=KeywordOnlyCredential()) + mock_ddbc_conn.assert_not_called() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_caller_attrs_before_dict_not_mutated(self, mock_ddbc_conn): + """connect() must not inject the access token into the caller's own + attrs_before dict (it would leak the secret and break dict reuse).""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + login_timeout_attr = 113 # SQL_ATTR_LOGIN_TIMEOUT + caller_opts = {login_timeout_attr: 30} + conn = connect( + "Server=test;Database=testdb", + token_provider=mock_cred, + attrs_before=caller_opts, + ) + # The caller's dict is untouched: no access token leaked in. + assert caller_opts == {login_timeout_attr: 30} + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value not in caller_opts + # The connection's own copy did receive the token. + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_reusing_attrs_before_across_connections_succeeds(self, mock_ddbc_conn): + """The same attrs_before dict can be reused for a second connection with + a different provider — proves the dict isn't polluted by the first.""" + mock_ddbc_conn.return_value = MagicMock() + cred_a = MagicMock() + cred_a.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + cred_b = MagicMock() + cred_b.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + shared_opts = {113: 30} # SQL_ATTR_LOGIN_TIMEOUT + c1 = connect("Server=s;Database=d", token_provider=cred_a, attrs_before=shared_opts) + # Without the copy fix this raises "Cannot specify both ... 
access token". + c2 = connect("Server=s;Database=d", token_provider=cred_b, attrs_before=shared_opts) + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in c1._attrs_before + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in c2._attrs_before + assert c1._attrs_before is not c2._attrs_before + c1.close() + c2.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_expired_expires_on_warns_but_is_accepted(self, mock_ddbc_conn): + """An already-expired expires_on is still accepted (the server enforces + expiry), but a warning is emitted so the likely cause surfaces early.""" + mock_ddbc_conn.return_value = MagicMock() + past = 1 # 1970 — long expired + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=past) + from mssql_python import connect + + with pytest.warns(UserWarning, match="already expired"): + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert conn._token_expires_on == past + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_value_not_in_exception_message(self, mock_ddbc_conn): + """A provider failure must not leak the acquired token in the error.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.side_effect = Exception("auth failed") + from mssql_python import connect + + with pytest.raises(OperationalError) as exc_info: + connect("Server=test;Database=testdb", token_provider=mock_cred) + assert SAMPLE_TOKEN not in str(exc_info.value) + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_value_not_in_logs(self, mock_ddbc_conn, caplog): + """The raw JWT must never be written to logs (only its length).""" + import logging + + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) 
+ from mssql_python import connect + + with caplog.at_level(logging.DEBUG): + conn = connect("Server=test;Database=testdb", token_provider=mock_cred) + assert SAMPLE_TOKEN not in caplog.text + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_empty_connection_string_with_token_provider(self, mock_ddbc_conn): + """An empty connection string with token_provider should not crash the + validation path; the token is still acquired and attached.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("", token_provider=mock_cred) + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + +class TestTokenProviderProtocol: + """Tests for the runtime_checkable TokenProvider Protocol.""" + + def test_object_with_get_token_is_instance(self): + """An object exposing get_token satisfies the Protocol at runtime.""" + + class Cred: + def get_token(self, *scopes, **kwargs): + return MagicMock(token=SAMPLE_TOKEN) + + assert isinstance(Cred(), TokenProvider) + + def test_object_without_get_token_is_not_instance(self): + """An object missing get_token does not satisfy the Protocol.""" + + class NotCred: + def something_else(self): + return None + + assert not isinstance(NotCred(), TokenProvider) + + def test_database_scope_is_commercial_cloud_constant(self): + """The shared scope constant points at the Azure commercial-cloud audience.""" + assert _DATABASE_SCOPE == "https://database.windows.net/.default" + + +class TestTokenProviderPooling: + """Pins pooling behavior for access-token connections. + + The native pool keys on the (sanitized) connection string only, and the + access token lives in attrs_before — applied just once when a *new* physical + connection is created and never re-applied on reuse. 
So two different + principals that share the same Server/Database would collide in the same + pool bucket and one could be handed another's authenticated connection. + To prevent that silent identity confusion, Connection.__init__ disables + pooling whenever an access token is present in attrs_before. These tests pin + that contract for every access-token path (raw SQL_COPT_SS_ACCESS_TOKEN, + built-in Authentication=ActiveDirectory*, and token_provider=). + """ + + @staticmethod + def _pooling_arg(mock_ddbc_conn): + """Return the `pooling` positional arg passed to ddbc_bindings.Connection.""" + # ddbc_bindings.Connection(connection_str, pooling, attrs_before) + return mock_ddbc_conn.call_args.args[1] + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_provider_disables_pooling(self, mock_ddbc_conn): + """token_provider= connections must not be pooled (cross-principal + collision guard).""" + mock_ddbc_conn.return_value = MagicMock() + cred = MagicMock() + cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + conn = connect("Server=s;Database=d", token_provider=cred) + assert self._pooling_arg(mock_ddbc_conn) is False + assert conn._pooling is False + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_raw_access_token_in_attrs_before_disables_pooling(self, mock_ddbc_conn): + """A raw SQL_COPT_SS_ACCESS_TOKEN supplied directly in attrs_before (the + pyodbc-style path) must also disable pooling — this path was uncovered + before the fix.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + from mssql_python.constants import ConstantsDDBC + + token_struct = b"\x04\x00\x00\x00test" + conn = connect( + "Server=s;Database=d", + attrs_before={ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value: token_struct}, + ) + assert self._pooling_arg(mock_ddbc_conn) is False + assert conn._pooling is False + conn.close() + + 
@patch("mssql_python.connection.get_auth_token") + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_builtin_entra_auth_disables_pooling(self, mock_ddbc_conn, mock_get_token): + """Built-in Authentication=ActiveDirectory* auth that injects a token into + attrs_before (e.g. ActiveDirectoryDefault) must also disable pooling — + this path was uncovered before the fix. (Driver-native paths such as + ServicePrincipal keep credentials in the connection string and remain + poolable; see test_builtin_driver_native_auth_keeps_pooling.)""" + mock_ddbc_conn.return_value = MagicMock() + mock_get_token.return_value = b"\x04\x00\x00\x00test" + from mssql_python import connect + + conn = connect("Server=s;Database=d;Authentication=ActiveDirectoryDefault") + assert self._pooling_arg(mock_ddbc_conn) is False + assert conn._pooling is False + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_builtin_driver_native_auth_keeps_pooling(self, mock_ddbc_conn): + """Driver-native Entra auth (ServicePrincipal) keeps UID/PWD in the + connection string, so the pool key already distinguishes principals and + pooling stays enabled — no token is injected into attrs_before.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + from mssql_python.constants import ConstantsDDBC + from mssql_python.pooling import PoolingManager + + PoolingManager._reset_for_testing() + conn = connect( + "Server=s;Database=d;Authentication=ActiveDirectoryServicePrincipal;" + "UID=app-id;PWD=app-secret" + ) + # No access token was injected, so pooling is left enabled. 
+ assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value not in conn._attrs_before + assert self._pooling_arg(mock_ddbc_conn) is True + assert conn._pooling is True + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_non_token_connection_keeps_pooling_enabled(self, mock_ddbc_conn): + """A plain connection (no access token) is still eligible for pooling — + the fix must not regress normal SQL/Windows-auth pooling.""" + mock_ddbc_conn.return_value = MagicMock() + from mssql_python import connect + from mssql_python.pooling import PoolingManager + + PoolingManager._reset_for_testing() + conn = connect("Server=s;Database=d;UID=sa;PWD=secret") + assert self._pooling_arg(mock_ddbc_conn) is True + assert conn._pooling is True + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_different_providers_yield_identical_connection_string(self, mock_ddbc_conn): + """Two different providers -> same sanitized connection string. This is + exactly why pooling must be disabled: the pool key (the connection + string) can't tell the principals apart.""" + mock_ddbc_conn.return_value = MagicMock() + cred_a = MagicMock() + cred_a.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + cred_b = MagicMock() + cred_b.get_token.return_value = MagicMock(token=SAMPLE_TOKEN) + from mssql_python import connect + + c1 = connect("Server=s;Database=d", token_provider=cred_a) + c2 = connect("Server=s;Database=d", token_provider=cred_b) + assert c1.connection_str == c2.connection_str + # ...but neither is pooled, so the collision can never occur. 
+ assert c1._pooling is False + assert c2._pooling is False + c1.close() + c2.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_token_not_refreshed_after_connect(self, mock_ddbc_conn): + """The access token is a pre-connect attribute: it is acquired exactly + once at connect() and not re-acquired for the life of the connection.""" + mock_ddbc_conn.return_value = MagicMock() + mock_cred = MagicMock() + mock_cred.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1) + from mssql_python import connect + + # expires_on=1 is in the past, so the expired-token warning fires; the + # point of this test is that the token is acquired exactly once. + with pytest.warns(UserWarning, match="already expired"): + conn = connect("Server=s;Database=d", token_provider=mock_cred) + # Even though expires_on is in the past, nothing re-acquires the token. + assert mock_cred.get_token.call_count == 1 + conn.close() + assert mock_cred.get_token.call_count == 1 + + +# --- Faithful azure-identity stand-ins ------------------------------------- +# These mirror the real azure.core.credentials API so the token_provider path +# is exercised exactly as it would be with a live `azure-identity` install, +# without taking a dependency on the package or making network calls. + +# azure.core.credentials.AccessToken is a NamedTuple(token: str, expires_on: int). +_AccessToken = collections.namedtuple("AccessToken", ["token", "expires_on"]) + + +class _FakeDefaultAzureCredential: + """Mirrors azure.identity.DefaultAzureCredential. + + Real signature: + get_token(self, *scopes, claims=None, tenant_id=None, + enable_cae=False, **kwargs) -> AccessToken + The SDK caches internally and hands back the same AccessToken until it is + near expiry, so repeated calls are cheap and return a stable token. 
+ """ + + def __init__(self, token=SAMPLE_TOKEN, expires_on=1893456000): + self._cached = _AccessToken(token, expires_on) + self.calls = [] + + def get_token(self, *scopes, claims=None, tenant_id=None, enable_cae=False, **kwargs): + self.calls.append(scopes) + return self._cached + + +class _FakeClientSecretCredential: + """Mirrors azure.identity.ClientSecretCredential (service principal).""" + + def __init__(self, tenant_id, client_id, client_secret, token=SAMPLE_TOKEN): + self.tenant_id = tenant_id + self.client_id = client_id + self._secret = client_secret + self._token = token + self.calls = 0 + + def get_token(self, *scopes, **kwargs): + self.calls += 1 + return _AccessToken(self._token, 1893456000) + + +class _FakeManagedIdentityCredential: + """Mirrors azure.identity.ManagedIdentityCredential (App Service / VM).""" + + def __init__(self, client_id=None, token=SAMPLE_TOKEN): + self.client_id = client_id + self._token = token + + def get_token(self, *scopes, **kwargs): + return _AccessToken(self._token, 1893456000) + + +class _FakeInteractiveBrowserCredential: + """Mirrors azure.identity.InteractiveBrowserCredential. + + The first call performs an interactive sign-in (slow / may block); after + that the token is cached. We model that the first get_token is the one that + "logs in" and subsequent calls return the cached value. + """ + + def __init__(self, token=SAMPLE_TOKEN): + self._token = token + self.login_count = 0 + + def get_token(self, *scopes, claims=None, tenant_id=None, enable_cae=False, **kwargs): + if self.login_count == 0: + self.login_count += 1 # "interactive sign-in" happens here + return _AccessToken(self._token, 1893456000) + + +class TestTokenProviderRealWorld: + """End-to-end checks against faithful azure-identity credential stand-ins. 
+ + Validates that the token_provider= fix behaves correctly with the real + Azure SDK API shapes (AccessToken namedtuple, *scopes/**kwargs signatures) + and the real usage patterns library consumers actually write. + """ + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_default_azure_credential_end_to_end(self, mock_ddbc_conn): + """The canonical `connect(conn_str, token_provider=DefaultAzureCredential())`.""" + mock_ddbc_conn.return_value = MagicMock() + cred = _FakeDefaultAzureCredential() + from mssql_python import connect + + conn = connect("Server=myserver.database.windows.net;Database=mydb", token_provider=cred) + # Token acquired with the commercial-cloud database scope, once. + assert cred.calls == [(_DATABASE_SCOPE,)] + assert conn._token_provider is cred + assert conn._token_expires_on == 1893456000 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_client_secret_credential_service_principal(self, mock_ddbc_conn): + """Service-principal pattern: ClientSecretCredential(tenant, id, secret).""" + mock_ddbc_conn.return_value = MagicMock() + cred = _FakeClientSecretCredential( + tenant_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + client_id="11111111-2222-3333-4444-555555555555", + client_secret="super-secret", + ) + from mssql_python import connect + + conn = connect("Server=s.database.windows.net;Database=d", token_provider=cred) + assert cred.calls == 1 + assert conn._token_provider is cred + # The client secret must never end up in the (sanitized) connection string. 
+ assert "super-secret" not in conn.connection_str + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_managed_identity_credential_app_service(self, mock_ddbc_conn): + """App Service / VM pattern: ManagedIdentityCredential(client_id=...).""" + mock_ddbc_conn.return_value = MagicMock() + cred = _FakeManagedIdentityCredential(client_id="user-assigned-mi-client-id") + from mssql_python import connect + + conn = connect("Server=s.database.windows.net;Database=d", token_provider=cred) + assert conn._token_provider is cred + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in conn._attrs_before + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_interactive_browser_credential_signs_in_once(self, mock_ddbc_conn): + """Interactive credential: first connect triggers the single sign-in.""" + mock_ddbc_conn.return_value = MagicMock() + cred = _FakeInteractiveBrowserCredential() + from mssql_python import connect + + conn = connect("Server=s.database.windows.net;Database=d", token_provider=cred) + assert cred.login_count == 1 + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_access_token_namedtuple_round_trips(self, mock_ddbc_conn): + """A real AccessToken namedtuple flows through .token / .expires_on access.""" + mock_ddbc_conn.return_value = MagicMock() + cred = _FakeDefaultAzureCredential(expires_on=1999999999) + from mssql_python import connect + + conn = connect("Server=s;Database=d", token_provider=cred) + assert conn._token_expires_on == 1999999999 + # The injected attribute is the UTF-16-LE struct, not the raw JWT. 
+ token_struct = conn._attrs_before[ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value] + body = SAMPLE_TOKEN.encode("UTF-16-LE") + assert token_struct[:4] == len(body).to_bytes(4, "little") + assert token_struct[4:] == body + conn.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_one_credential_reused_across_a_connection_pool(self, mock_ddbc_conn): + """The real pattern: build the credential once, reuse for every connect. + + Each connect() acquires a fresh token from the (internally-cached) + credential, and connections never share an attrs_before dict. + """ + mock_ddbc_conn.return_value = MagicMock() + cred = _FakeDefaultAzureCredential() + from mssql_python import connect + + conns = [connect(f"Server=s{i};Database=d", token_provider=cred) for i in range(5)] + assert len(cred.calls) == 5 + # No two connections alias the same attrs_before dict (regression guard + # for the caller-dict-mutation bug). + ids = {id(c._attrs_before) for c in conns} + assert len(ids) == 5 + for c in conns: + c.close() + + @patch("mssql_python.connection.ddbc_bindings.Connection") + def test_shared_app_config_dict_reused_for_every_connection(self, mock_ddbc_conn): + """Real-world bug-fix scenario: an app holds ONE options dict (e.g. a + login timeout) and passes it to every connect() alongside a credential. + + Before the fix the first connect() injected the access token into this + shared dict, so the second connect() raised "Cannot specify both ...". + """ + mock_ddbc_conn.return_value = MagicMock() + cred = _FakeDefaultAzureCredential() + from mssql_python import connect + + SQL_ATTR_LOGIN_TIMEOUT = 113 + app_attrs = {SQL_ATTR_LOGIN_TIMEOUT: 30} # built once, reused everywhere + + c1 = connect("Server=s1;Database=d", token_provider=cred, attrs_before=app_attrs) + c2 = connect("Server=s2;Database=d", token_provider=cred, attrs_before=app_attrs) + + # The shared dict is untouched: only the login timeout, no access token. 
+ assert app_attrs == {SQL_ATTR_LOGIN_TIMEOUT: 30} + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value not in app_attrs + # Both connections got their own token + the app's login timeout. + for c in (c1, c2): + assert c._attrs_before[SQL_ATTR_LOGIN_TIMEOUT] == 30 + assert ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in c._attrs_before + c1.close() + c2.close() + class TestParseTenantId: def test_guid_tenant(self): diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py index 7543066a7..1f471f1ce 100644 --- a/tests/test_020_bulkcopy_auth_cleanup.py +++ b/tests/test_020_bulkcopy_auth_cleanup.py @@ -187,8 +187,76 @@ def test_token_provider_get_token_failure_rewrapped(self, mock_logger): assert "unable to acquire token from custom credential" in str(exc_info.value) @patch("mssql_python.cursor.logger") - def test_token_provider_invalid_token_rewrapped(self, mock_logger): - """credential returning a non-string token ⇒ bulkcopy raises OperationalError.""" + def test_each_bulkcopy_reacquires_fresh_token(self, mock_logger): + """Every bulkcopy() call asks the provider for a fresh token (no reuse + of a possibly-stale token across operations).""" + mock_logger.is_debug_enabled = False + + credential = MagicMock() + credential.get_token.return_value = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + + cursor = _make_cursor( + "Server=tcp:test.database.windows.net;Database=testdb;" + "Authentication=ActiveDirectoryDefault", + "activedirectorydefault", + ) + cursor._connection._token_provider = credential + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = lambda ctx, **kwargs: mock_pycore_conn + + with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}): + 
cursor.bulkcopy("dbo.t", [(1, "a")], timeout=10) + cursor.bulkcopy("dbo.t", [(2, "b")], timeout=10) + cursor.bulkcopy("dbo.t", [(3, "c")], timeout=10) + + assert credential.get_token.call_count == 3 + + @patch("mssql_python.cursor.logger") + def test_transient_failure_then_recovery(self, mock_logger): + """A transient provider failure on one bulkcopy raises OperationalError + but leaves the cursor usable for a subsequent successful call.""" + mock_logger.is_debug_enabled = False + + credential = MagicMock() + good = MagicMock(token=SAMPLE_TOKEN, expires_on=1893456000) + # First call fails, second call succeeds. + credential.get_token.side_effect = [RuntimeError("network blip"), good] + + cursor = _make_cursor( + "Server=tcp:test.database.windows.net;Database=testdb;" + "Authentication=ActiveDirectoryDefault", + "activedirectorydefault", + ) + cursor._connection._token_provider = credential + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = lambda ctx, **kwargs: mock_pycore_conn + + with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}): + with pytest.raises(OperationalError): + cursor.bulkcopy("dbo.t", [(1, "a")], timeout=10) + # Cursor still works after the transient failure. 
+ cursor.bulkcopy("dbo.t", [(2, "b")], timeout=10) + + assert credential.get_token.call_count == 2 + mock_logger.is_debug_enabled = False # .token is not a non-empty string ⇒ _get_token_from_credential raises InterfaceError, From 9814bdbe61e3b374fb1fa212646f241851163f0e Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 26 Jun 2026 22:19:10 +0530 Subject: [PATCH 08/10] Fix f-string lint nits and type token_provider in .pyi stubs - Remove unnecessary f-prefix from two non-interpolated SQL_WCHAR error strings in connection.py (flagged by flake8-no-fstring-u style linters). - Type token_provider as Optional[TokenProvider] (was Optional[object]) in the .pyi stubs for Connection.__init__ and connect(), matching the runtime. --- mssql_python/connection.py | 4 ++-- mssql_python/mssql_python.pyi | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mssql_python/connection.py b/mssql_python/connection.py index f8def99cb..f679569ed 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -144,10 +144,10 @@ def _validate_utf16_wchar_compatibility( # Generate context-appropriate error messages if "ctype" in context: - driver_error = f"SQL_WCHAR ctype only supports UTF-16 encodings" + driver_error = "SQL_WCHAR ctype only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR ctype" else: - driver_error = f"SQL_WCHAR only supports UTF-16 encodings" + driver_error = "SQL_WCHAR only supports UTF-16 encodings" ddbc_context = "SQL_WCHAR" raise ProgrammingError( diff --git a/mssql_python/mssql_python.pyi b/mssql_python/mssql_python.pyi index 05aeec499..9f967cfcc 100644 --- a/mssql_python/mssql_python.pyi +++ b/mssql_python/mssql_python.pyi @@ -9,6 +9,8 @@ import datetime import logging import pyarrow +from mssql_python.auth import TokenProvider + # GLOBALS - DB-API 2.0 Required Module Globals # https://www.python.org/dev/peps/pep-0249/#module-interface apilevel: str # "2.0" @@ -248,7 +250,7 @@ class Connection: attrs_before: 
Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - token_provider: Optional[object] = None, + token_provider: Optional[TokenProvider] = None, **kwargs: Any, ) -> None: ... @@ -292,7 +294,7 @@ def connect( attrs_before: Optional[Dict[int, Union[int, str, bytes]]] = None, timeout: int = 0, native_uuid: Optional[bool] = None, - token_provider: Optional[object] = None, + token_provider: Optional[TokenProvider] = None, **kwargs: Any, ) -> Connection: ... From d723153df4c2f49c6570106ac6b14b44a003b870 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 26 Jun 2026 22:31:49 +0530 Subject: [PATCH 09/10] Remove dangling doc reference and fix expired-token warning attribution - Drop the 'See docs/DESIGN_TOKEN_PROVIDER_SUPPORT.md' comment in connection.py (that file is not part of the PR). - The expired-token warning in _get_token_from_credential is reached via two call chains at different depths (connect vs bulk-copy), so a fixed stacklevel cannot point at user code for both. Compute the stacklevel dynamically via _stacklevel_to_caller(), which walks out of the package to the first external frame. Works across all supported Python versions. --- mssql_python/auth.py | 35 ++++++++++++++++++++++++++++++++++- mssql_python/connection.py | 1 - 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index fbe1752bc..d88ad5820 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -6,6 +6,7 @@ import hashlib import inspect +import os import platform import struct import threading @@ -43,6 +44,38 @@ # for a different audience is rejected by SQL Server at login. _DATABASE_SCOPE = "https://database.windows.net/.default" +# Absolute, case-normalized directory of this package, used to attribute +# warnings to the first caller *outside* the package (see _stacklevel_to_caller). 
+_PACKAGE_DIR = os.path.normcase(os.path.dirname(os.path.abspath(__file__))) + + +def _stacklevel_to_caller() -> int: + """Return a ``warnings.warn`` stacklevel pointing at the first frame outside + this package. + + The token-expiry warning is reached through two different internal call + chains — the connect path (via ``acquire_token_from_credential``) and the + bulk-copy path (via ``acquire_raw_token_from_credential``) — which sit at + different depths. A fixed ``stacklevel`` cannot point at user code for both, + so we walk outward until we leave the package. Falls back to 2 if no + external frame is found. + """ + frame = inspect.currentframe() + try: + # Start at the caller of this helper, i.e. the frame that issues warn() + # (stacklevel == 1 for that warn() call). + frame = frame.f_back if frame else None + level = 1 + while frame is not None: + if not os.path.normcase(frame.f_code.co_filename).startswith(_PACKAGE_DIR): + return level + frame = frame.f_back + level += 1 + return 2 + finally: + # Break the reference cycle created by holding a frame object. + del frame + @runtime_checkable class TokenProvider(Protocol): @@ -560,7 +593,7 @@ def _get_token_from_credential(credential: "TokenProvider") -> Tuple[str, Option f"(expires_on={expires_on} is in the past). The server will likely " f"reject the connection.", UserWarning, - stacklevel=2, + stacklevel=_stacklevel_to_caller(), ) elapsed_ms = (time.perf_counter() - start_time) * 1000 logger.info( diff --git a/mssql_python/connection.py b/mssql_python/connection.py index f679569ed..d438423aa 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -482,7 +482,6 @@ def __init__( # is an acceptable, correct default. Refreshing the token on a live # connection (so pooling could be re-enabled safely) needs native driver # support and is tracked as follow-up work. - # See docs/DESIGN_TOKEN_PROVIDER_SUPPORT.md. 
if ConstantsDDBC.SQL_COPT_SS_ACCESS_TOKEN.value in self._attrs_before: self._pooling = False From 45592417142e90151d7cacb7491ea2f82ed77084 Mon Sep 17 00:00:00 2001 From: Jahnvi Thakkar Date: Fri, 26 Jun 2026 22:38:23 +0530 Subject: [PATCH 10/10] Revert dynamic stacklevel helper; use fixed stacklevel=2 for expired-token warning The _stacklevel_to_caller() helper walked frames on every expired-token warning and relied on __file__ (fragile under zipimport/frozen deploys). The expired case is rare and the message is self-explanatory, so a fixed stacklevel=2 is sufficient. Removes the helper, _PACKAGE_DIR, and the now-unused os import. --- mssql_python/auth.py | 35 +---------------------------------- 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/mssql_python/auth.py b/mssql_python/auth.py index d88ad5820..fbe1752bc 100644 --- a/mssql_python/auth.py +++ b/mssql_python/auth.py @@ -6,7 +6,6 @@ import hashlib import inspect -import os import platform import struct import threading @@ -44,38 +43,6 @@ # for a different audience is rejected by SQL Server at login. _DATABASE_SCOPE = "https://database.windows.net/.default" -# Absolute, case-normalized directory of this package, used to attribute -# warnings to the first caller *outside* the package (see _stacklevel_to_caller). -_PACKAGE_DIR = os.path.normcase(os.path.dirname(os.path.abspath(__file__))) - - -def _stacklevel_to_caller() -> int: - """Return a ``warnings.warn`` stacklevel pointing at the first frame outside - this package. - - The token-expiry warning is reached through two different internal call - chains — the connect path (via ``acquire_token_from_credential``) and the - bulk-copy path (via ``acquire_raw_token_from_credential``) — which sit at - different depths. A fixed ``stacklevel`` cannot point at user code for both, - so we walk outward until we leave the package. Falls back to 2 if no - external frame is found. 
- """ - frame = inspect.currentframe() - try: - # Start at the caller of this helper, i.e. the frame that issues warn() - # (stacklevel == 1 for that warn() call). - frame = frame.f_back if frame else None - level = 1 - while frame is not None: - if not os.path.normcase(frame.f_code.co_filename).startswith(_PACKAGE_DIR): - return level - frame = frame.f_back - level += 1 - return 2 - finally: - # Break the reference cycle created by holding a frame object. - del frame - @runtime_checkable class TokenProvider(Protocol): @@ -593,7 +560,7 @@ def _get_token_from_credential(credential: "TokenProvider") -> Tuple[str, Option f"(expires_on={expires_on} is in the past). The server will likely " f"reject the connection.", UserWarning, - stacklevel=_stacklevel_to_caller(), + stacklevel=2, ) elapsed_ms = (time.perf_counter() - start_time) * 1000 logger.info(