Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
denisenkom committed Nov 27, 2023
1 parent f42c165 commit 123e468
Show file tree
Hide file tree
Showing 21 changed files with 243 additions and 207 deletions.
2 changes: 1 addition & 1 deletion profiling/profile_reader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import struct
import cProfile
import pstats
import pytds.tds
import pytds.tds_socket


BUFSIZE = 4096
Expand Down
2 changes: 1 addition & 1 deletion src/pytds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
recordtype_row_strategy, # noqa: F401 # export for backward compatibility
RowStrategy,
)
from .tds import _TdsSocket
from .tds_socket import _TdsSocket
from . import instance_browser_client
from . import tds_base
from . import utils
Expand Down
10 changes: 5 additions & 5 deletions src/pytds/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
import weakref
from . import tds_base
from . import tds
from .tds_socket import _TdsSocket
from . import row_strategies
from .tds_base import logger
from . import connection_pool
Expand Down Expand Up @@ -73,10 +73,10 @@ def __init__(
self,
pooling: bool,
key: connection_pool.PoolKeyType,
tds_socket: tds._TdsSocket,
tds_socket: _TdsSocket,
) -> None:
# _tds_socket is set to None when connection is closed
self._tds_socket: tds._TdsSocket | None = tds_socket
self._tds_socket: _TdsSocket | None = tds_socket
self._key = key
self._pooling = pooling
# references to all cursors opened from connection
Expand Down Expand Up @@ -236,7 +236,7 @@ def __init__(
self,
pooling: bool,
key: connection_pool.PoolKeyType,
tds_socket: tds._TdsSocket,
tds_socket: _TdsSocket,
):
super().__init__(pooling=pooling, key=key, tds_socket=tds_socket)

Expand Down Expand Up @@ -276,7 +276,7 @@ def __init__(
self,
pooling: bool,
key: connection_pool.PoolKeyType,
tds_socket: tds._TdsSocket,
tds_socket: _TdsSocket,
):
super().__init__(pooling=pooling, key=key, tds_socket=tds_socket)
self._active_cursor: NonMarsCursor | None = None
Expand Down
2 changes: 1 addition & 1 deletion src/pytds/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Optional, Union, Tuple

from pytds.tds_base import AuthProtocol
from pytds.tds import _TdsSocket, _TdsSession
from pytds.tds_socket import _TdsSocket, _TdsSession

PoolKeyType = Tuple[
Optional[str],
Expand Down
2 changes: 1 addition & 1 deletion src/pytds/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pytds.connection import Connection, MarsConnection, NonMarsConnection
from pytds.tds_types import NVarCharType, TzInfoFactoryType

from pytds.tds import _TdsSession
from pytds.tds_socket import _TdsSession

from pytds import tds_base
from .tds_base import logger
Expand Down
27 changes: 13 additions & 14 deletions src/pytds/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
"""
from __future__ import annotations

import base64
import ctypes
import logging
import socket

Expand Down Expand Up @@ -42,7 +44,7 @@ def __init__(
server_name: str = "",
port: int | None = None,
spn: str | None = None,
):
) -> None:
from . import sspi

# parse username/password informations
Expand Down Expand Up @@ -75,7 +77,6 @@ def __init__(

def create_packet(self) -> bytes:
from . import sspi
import ctypes

buf = ctypes.create_string_buffer(4096)
ctx, status, bufs = self._cred.create_context(
Expand All @@ -91,7 +92,6 @@ def create_packet(self) -> bytes:

def handle_next(self, packet: bytes) -> bytes | None:
from . import sspi
import ctypes

if self._ctx:
buf = ctypes.create_string_buffer(4096)
Expand Down Expand Up @@ -128,7 +128,7 @@ class NtlmAuth(AuthProtocol):
:type ntlm_compatibility: int
"""

def __init__(self, user_name: str, password: str, ntlm_compatibility: int = 3):
def __init__(self, user_name: str, password: str, ntlm_compatibility: int = 3) -> None:
self._user_name = user_name
if "\\" in user_name:
domain, self._user = user_name.split("\\", 1)
Expand Down Expand Up @@ -170,7 +170,7 @@ class SpnegoAuth(AuthProtocol):
Takes same parameters as spnego.client function.
"""

def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
try:
import spnego
except ImportError:
Expand All @@ -180,7 +180,10 @@ def __init__(self, *args, **kwargs):
self._context = spnego.client(*args, **kwargs)

def create_packet(self) -> bytes:
return self._context.step()
result = self._context.step()
if not result:
raise RuntimeError("spnego did not create initial packet")
return result

def handle_next(self, packet: bytes) -> bytes | None:
return self._context.step(packet)
Expand All @@ -190,36 +193,32 @@ def close(self) -> None:


class KerberosAuth(AuthProtocol):
def __init__(self, server_principal):
def __init__(self, server_principal: str) -> None:
try:
import kerberos # type: ignore # fix later
except ImportError:
import winkerberos as kerberos # type: ignore # fix later
self._kerberos = kerberos
res, context = kerberos.authGSSClientInit(server_principal)
if res < 0:
raise RuntimeError("authGSSClientInit failed with code {}".format(res))
raise RuntimeError(f"authGSSClientInit failed with code {res}")
logger.info("Initialized GSS context")
self._context = context

def create_packet(self) -> bytes:
import base64

res = self._kerberos.authGSSClientStep(self._context, "")
if res < 0:
raise RuntimeError("authGSSClientStep failed with code {}".format(res))
raise RuntimeError(f"authGSSClientStep failed with code {res}")
data = self._kerberos.authGSSClientResponse(self._context)
logger.info("created first client GSS packet %s", data)
return base64.b64decode(data)

def handle_next(self, packet: bytes) -> bytes | None:
import base64

res = self._kerberos.authGSSClientStep(
self._context, base64.b64encode(packet).decode("ascii")
)
if res < 0:
raise RuntimeError("authGSSClientStep failed with code {}".format(res))
raise RuntimeError(f"authGSSClientStep failed with code {res}")
if res == self._kerberos.AUTH_GSS_COMPLETE:
logger.info("GSS authentication completed")
return b""
Expand Down
4 changes: 4 additions & 0 deletions src/pytds/row_strategies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
This module implements various row strategies.
E.g. row strategy that generated dictionaries or named tuples for rows.
"""
import collections
import keyword
import re
Expand Down
6 changes: 4 additions & 2 deletions src/pytds/smp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# This file implements Session Multiplex Protocol used by MARS connections
# Protocol documentation https://msdn.microsoft.com/en-us/library/cc219643.aspx
"""
This file implements Session Multiplex Protocol used by MARS connections
Protocol documentation https://msdn.microsoft.com/en-us/library/cc219643.aspx
"""
from __future__ import annotations

import struct
Expand Down
3 changes: 3 additions & 0 deletions src/pytds/sspi.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
This module implements wrapper for Windows SSPI API
"""
import logging

from ctypes import ( # type: ignore # needs fixing
Expand Down
170 changes: 4 additions & 166 deletions src/pytds/tds.py
Original file line number Diff line number Diff line change
@@ -1,170 +1,8 @@
from __future__ import annotations

import logging
import datetime

from . import tds_base
from . import tds_types
from . import tls
from .tds_base import PreLoginEnc, _TdsEnv, _TdsLogin, Route
from .row_strategies import list_row_strategy
from .smp import SmpManager

"""
This module provides backward compatibility
"""
# _token_map is needed by sqlalchemy_pytds connector
from .tds_session import (
_token_map, # noqa: F401 # _token_map is needed by sqlalchemy_pytds connector
_TdsSession,
)

logger = logging.getLogger(__name__)


class _TdsSocket:
"""
This class represents root TDS connection
if MARS is used it can have multiple sessions represented by _TdsSession class
if MARS is not used it would have single _TdsSession instance
"""

def __init__(
self,
sock: tds_base.TransportProtocol,
login: _TdsLogin,
tzinfo_factory: tds_types.TzInfoFactoryType | None = None,
row_strategy=list_row_strategy,
use_tz: datetime.tzinfo | None = None,
autocommit=False,
isolation_level=0,
):
self._is_connected = False
self.env = _TdsEnv()
self.env.isolation_level = isolation_level
self.collation = None
self.tds72_transaction = 0
self._mars_enabled = False
self.sock = sock
self.bufsize = login.blocksize
self.use_tz = use_tz
self.tds_version = login.tds_version
self.type_factory = tds_types.SerializerFactory(self.tds_version)
self._tzinfo_factory = tzinfo_factory
self._smp_manager: SmpManager | None = None
self._main_session = _TdsSession(
tds=self,
transport=sock,
tzinfo_factory=tzinfo_factory,
row_strategy=row_strategy,
env=self.env,
# initially we use fixed bufsize
# it may be updated later if server specifies different block size
bufsize=4096,
)
self._login = login
self.route: Route | None = None
self._row_strategy = row_strategy
self.env.autocommit = autocommit
self.query_timeout = login.query_timeout
self.type_inferrer = tds_types.TdsTypeInferrer(
type_factory=self.type_factory,
collation=self.collation,
bytes_to_unicode=self._login.bytes_to_unicode,
allow_tz=not self.use_tz,
)
self.server_library_version = (0, 0)
self.product_name = ""
self.product_version = 0

def __repr__(self) -> str:
fmt = "<_TdsSocket tran={} mars={} tds_version={} use_tz={}>"
return fmt.format(
self.tds72_transaction, self._mars_enabled, self.tds_version, self.use_tz
)

def login(self) -> Route | None:
self._login.server_enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP
if tds_base.IS_TDS71_PLUS(self._main_session):
self._main_session.send_prelogin(self._login)
self._main_session.process_prelogin(self._login)
self._main_session.tds7_send_login(self._login)
if self._login.server_enc_flag == PreLoginEnc.ENCRYPT_OFF:
tls.revert_to_clear(self._main_session)
self._main_session.begin_response()
if not self._main_session.process_login_tokens():
self._main_session.raise_db_exception()
if self.route is not None:
return self.route

# update block size if server returned different one
if (
self._main_session._writer.bufsize
!= self._main_session._reader.get_block_size()
):
self._main_session._reader.set_block_size(
self._main_session._writer.bufsize
)

self.type_factory = tds_types.SerializerFactory(self.tds_version)
self.type_inferrer = tds_types.TdsTypeInferrer(
type_factory=self.type_factory,
collation=self.collation,
bytes_to_unicode=self._login.bytes_to_unicode,
allow_tz=not self.use_tz,
)
if self._mars_enabled:
self._smp_manager = SmpManager(self.sock)
self._main_session = _TdsSession(
tds=self,
bufsize=self.bufsize,
transport=self._smp_manager.create_session(),
tzinfo_factory=self._tzinfo_factory,
row_strategy=self._row_strategy,
env=self.env,
)
self._is_connected = True
q = []
if self._login.database and self.env.database != self._login.database:
q.append("use " + tds_base.tds_quote_id(self._login.database))
if q:
self._main_session.submit_plain_query("".join(q))
self._main_session.process_simple_request()
return None

@property
def mars_enabled(self) -> bool:
return self._mars_enabled

@property
def main_session(self) -> _TdsSession:
return self._main_session

def create_session(self) -> _TdsSession:
if not self._smp_manager:
raise RuntimeError(
"Calling create_session on a non-MARS connection does not work"
)
return _TdsSession(
tds=self,
transport=self._smp_manager.create_session(),
tzinfo_factory=self._tzinfo_factory,
row_strategy=self._row_strategy,
bufsize=self.bufsize,
env=self.env,
)

def is_connected(self) -> bool:
return self._is_connected

def close(self) -> None:
self._is_connected = False
if self.sock is not None:
self.sock.close()
if self._smp_manager:
self._smp_manager.transport_closed()
self._main_session.state = tds_base.TDS_DEAD
if self._main_session.authentication:
self._main_session.authentication.close()
self._main_session.authentication = None

def close_all_mars_sessions(self) -> None:
if self._smp_manager:
self._smp_manager.close_all_sessions(keep=self.main_session._transport)
from . import tds_base # noqa: F401 # this is needed by sqlalchemy_pytds connector
3 changes: 3 additions & 0 deletions src/pytds/tds_reader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
"""
This module implements TdsReader class
"""
from __future__ import annotations

import struct
Expand Down
Loading

0 comments on commit 123e468

Please sign in to comment.