|
15 | 15 | # specific language governing permissions and limitations
|
16 | 16 | # under the License.
|
17 | 17 | # pylint: disable=protected-access,redefined-outer-name
|
| 18 | +import base64 |
18 | 19 | import copy
|
| 20 | +import struct |
| 21 | +import threading |
19 | 22 | import uuid
|
| 23 | +from collections.abc import Generator |
20 | 24 | from copy import deepcopy
|
| 25 | +from typing import Optional |
21 | 26 | from unittest.mock import MagicMock, call, patch
|
22 | 27 |
|
23 | 28 | import pytest
|
| 29 | +import thrift.transport.TSocket |
24 | 30 | from hive_metastore.ttypes import (
|
25 | 31 | AlreadyExistsException,
|
26 | 32 | FieldSchema,
|
|
38 | 44 |
|
39 | 45 | from pyiceberg.catalog import PropertiesUpdateSummary
|
40 | 46 | from pyiceberg.catalog.hive import (
|
| 47 | + HIVE_KERBEROS_AUTH, |
41 | 48 | LOCK_CHECK_MAX_WAIT_TIME,
|
42 | 49 | LOCK_CHECK_MIN_WAIT_TIME,
|
43 | 50 | LOCK_CHECK_RETRIES,
|
44 | 51 | HiveCatalog,
|
45 | 52 | _construct_hive_storage_descriptor,
|
| 53 | + _HiveClient, |
46 | 54 | )
|
47 | 55 | from pyiceberg.exceptions import (
|
48 | 56 | NamespaceAlreadyExistsError,
|
@@ -183,6 +191,59 @@ def hive_database(tmp_path_factory: pytest.TempPathFactory) -> HiveDatabase:
|
183 | 191 | )
|
184 | 192 |
|
185 | 193 |
|
class SaslServer(threading.Thread):
    """Minimal fake SASL server for tests.

    Listens on the wrapped thrift server socket in a daemon thread and answers
    every accepted connection with the same canned ``response`` frame. Use the
    ``port`` property to block until the listening port is known.
    """

    def __init__(self, socket: "thrift.transport.TSocket.TServerSocket", response: bytes) -> None:
        super().__init__()
        self.daemon = True
        self._socket = socket
        self._response = response
        self._port: Optional[int] = None
        self._port_bound = threading.Event()

    def run(self) -> None:
        self._socket.listen()

        try:
            address = self._socket.handle.getsockname()
            # AF_INET addresses are 2-tuples (host, port) and AF_INET6 are
            # 4-tuples (host, port, ...), i.e. port is always at index 1.
            _host, self._port, *_ = address
        finally:
            # Unblock waiters on `port` even if reading the address failed.
            self._port_bound.set()

        # Accept connections and respond to each connection with the same message.
        # The responsibility for closing the connection is on the client.
        while True:
            try:
                client = self._socket.accept()
                if client:
                    client.write(self._response)
                    client.flush()
            except Exception:
                # `accept` raises once the listening socket is closed (see
                # `close`); exit the thread instead of busy-spinning forever.
                break

    @property
    def port(self) -> Optional[int]:
        """Block until the server thread has bound, then return the port."""
        self._port_bound.wait()
        return self._port

    def close(self) -> None:
        """Close the listening socket; this also terminates the accept loop."""
        self._socket.close()
| 232 | + |
| 233 | + |
@pytest.fixture(scope="session")
def kerberized_hive_metastore_fake_url() -> Generator[str, None, None]:
    """Yield a thrift:// URL for a fake metastore that completes any SASL handshake.

    The backing server always replies with SASL status 5 (COMPLETE) and an
    empty payload, which is enough for the client-side Kerberos plumbing
    exercised by the tests.
    """
    server = SaslServer(
        # Port 0 means pick any available port.
        socket=thrift.transport.TSocket.TServerSocket(port=0),
        # Always return a message with status 5 (COMPLETE).
        response=struct.pack(">BI", 5, 0),
    )
    server.start()
    try:
        yield f"thrift://localhost:{server.port}"
    finally:
        # Without try/finally the socket would leak when a consuming test
        # raises: pytest re-raises that exception at the yield point, so
        # code after the yield would otherwise be skipped.
        server.close()
| 245 | + |
| 246 | + |
186 | 247 | def test_no_uri_supplied() -> None:
|
187 | 248 | with pytest.raises(KeyError):
|
188 | 249 | HiveCatalog("production")
|
@@ -1239,3 +1300,45 @@ def test_create_hive_client_failure() -> None:
|
1239 | 1300 | with pytest.raises(Exception, match="Connection failed"):
|
1240 | 1301 | HiveCatalog._create_hive_client(properties)
|
1241 | 1302 | assert mock_hive_client.call_count == 2
|
| 1303 | + |
| 1304 | + |
def test_create_hive_client_with_kerberos(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """Creating a Hive client with Kerberos auth enabled should succeed."""
    catalog_properties = {
        "uri": kerberized_hive_metastore_fake_url,
        "ugi": "user",
        HIVE_KERBEROS_AUTH: "true",
    }
    assert HiveCatalog._create_hive_client(catalog_properties) is not None
| 1315 | + |
| 1316 | + |
def test_create_hive_client_with_kerberos_using_context_manager(
    kerberized_hive_metastore_fake_url: str,
) -> None:
    """_HiveClient with Kerberos auth should open, close, and re-open cleanly."""
    client = _HiveClient(
        uri=kerberized_hive_metastore_fake_url,
        kerberos_auth=True,
    )

    # Stub out the GSSAPI handshake so no real KDC is needed.
    gss_step = patch(
        "puresasl.mechanisms.kerberos.authGSSClientStep",
        return_value=None,
    )
    gss_response = patch(
        "puresasl.mechanisms.kerberos.authGSSClientResponse",
        return_value=base64.b64encode(b"Some Response"),
    )
    gss_complete = patch(
        "puresasl.mechanisms.GSSAPIMechanism.complete",
        return_value=True,
    )

    with gss_step, gss_response, gss_complete:
        with client as open_client:
            assert open_client._iprot.trans.isOpen()

        # Use the context manager a second time to see if
        # closing and re-opening work as expected.
        with client as open_client:
            assert open_client._iprot.trans.isOpen()
0 commit comments