diff --git a/conftest.py b/conftest.py index 0b559d6..6db66b3 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,3 @@ def pytest_configure(config): - plugin = config.pluginmanager.getplugin('mypy') - plugin.mypy_argv.append('--check-untyped-defs') + plugin = config.pluginmanager.getplugin("mypy") + plugin.mypy_argv.append("--check-untyped-defs") diff --git a/docs/conf.py b/docs/conf.py index 8355182..c892cda 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,194 +22,203 @@ # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. -#needs_sphinx = '1.0' +# needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.viewcode'] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.viewcode", +] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'python-tds' -copyright = u'2013, Mikhail Denisenko' +project = "python-tds" +copyright = "2013, Mikhail Denisenko" # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # # The short X.Y version. -version = '1.6' +version = "1.6" # The full version, including alpha/beta/rc tags. -release = '1.6' +release = "1.6" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. -exclude_patterns = ['_build'] +exclude_patterns = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # If true, keep warnings as "system message" paragraphs in the built documents. 
-#keep_warnings = False +# keep_warnings = False # -- Options for HTML output --------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = "default" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. -#html_theme_options = {} +# html_theme_options = {} # Add any paths that contain custom themes here, relative to this directory. -#html_theme_path = [] +# html_theme_path = [] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, "Created using Sphinx" is shown in the HTML footer. Default is True. -#html_show_sphinx = True +# html_show_sphinx = True # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'python-tdsdoc' +htmlhelp_basename = "python-tdsdoc" # -- Options for LaTeX output -------------------------------------------------- latex_elements = { -# The paper size ('letterpaper' or 'a4paper'). 
-#'papersize': 'letterpaper', - -# The font size ('10pt', '11pt' or '12pt'). -#'pointsize': '10pt', - -# Additional stuff for the LaTeX preamble. -#'preamble': '', + # The paper size ('letterpaper' or 'a4paper'). + #'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + #'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + #'preamble': '', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'python-tds.tex', u'python-tds Documentation', - u'Mikhail Denisenko', 'manual'), + ( + "index", + "python-tds.tex", + "python-tds Documentation", + "Mikhail Denisenko", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # If true, show page references after internal links. -#latex_show_pagerefs = False +# latex_show_pagerefs = False # If true, show URL addresses after external links. -#latex_show_urls = False +# latex_show_urls = False # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. -#latex_domain_indices = True +# latex_domain_indices = True # -- Options for manual page output -------------------------------------------- @@ -217,12 +226,11 @@ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - ('index', 'python-tds', u'python-tds Documentation', - [u'Mikhail Denisenko'], 1) + ("index", "python-tds", "python-tds Documentation", ["Mikhail Denisenko"], 1) ] # If true, show URL addresses after external links. -#man_show_urls = False +# man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ @@ -231,19 +239,25 @@ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'python-tds', u'python-tds Documentation', - u'Mikhail Denisenko', 'python-tds', 'One line description of project.', - 'Miscellaneous'), + ( + "index", + "python-tds", + "python-tds Documentation", + "Mikhail Denisenko", + "python-tds", + "One line description of project.", + "Miscellaneous", + ), ] # Documents to append as an appendix to all manuals. -#texinfo_appendices = [] +# texinfo_appendices = [] # If false, no module index is generated. -#texinfo_domain_indices = True +# texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. -#texinfo_show_urls = 'footnote' +# texinfo_show_urls = 'footnote' # If true, do not generate a @detailmenu in the "Top" node's menu. 
-#texinfo_no_detailmenu = False +# texinfo_no_detailmenu = False diff --git a/profiling/profile_reader.py b/profiling/profile_reader.py index 4244dce..ed98da8 100644 --- a/profiling/profile_reader.py +++ b/profiling/profile_reader.py @@ -5,13 +5,13 @@ BUFSIZE = 4096 -HEADER = struct.Struct('>BBHHBx') +HEADER = struct.Struct(">BBHHBx") class Sock: def __init__(self): self._read_pos = 0 - self._buf = bytearray(b'\x00' * BUFSIZE) + self._buf = bytearray(b"\x00" * BUFSIZE) HEADER.pack_into(self._buf, 0, 0, 0, BUFSIZE, 0, 0) def sendall(self, data, flags=0): @@ -24,14 +24,14 @@ def recv_into(self, buffer, size=0): HEADER.pack_into(self._buf, 0, 0, 0, BUFSIZE, 0, 0) self._read_pos = 0 to_read = min(size, BUFSIZE - self._read_pos) - buffer[:to_read] = self._buf[self._read_pos:self._read_pos+to_read] + buffer[:to_read] = self._buf[self._read_pos : self._read_pos + to_read] return to_read def recv(self, size): if self._read_pos >= len(self._buf): HEADER.pack_into(self._buf, 0, 0, 0, BUFSIZE, 0, 0) self._read_pos = 0 - res = self._buf[self._read_pos:self._read_pos + size] + res = self._buf[self._read_pos : self._read_pos + size] self._read_pos += len(res) return res @@ -52,6 +52,6 @@ def __init__(self): for _ in range(50000): rdr.recv(BUFSIZE) pr.disable() -sortby = 'tottime' +sortby = "tottime" ps = pstats.Stats(pr).sort_stats(sortby) ps.print_stats() diff --git a/profiling/profile_smp.py b/profiling/profile_smp.py index aab9c0b..366fb81 100644 --- a/profiling/profile_smp.py +++ b/profiling/profile_smp.py @@ -8,14 +8,14 @@ transport = None bufsize = 512 -smp_header = struct.Struct(' ") @@ -23,15 +25,15 @@ def main(): for _, msg in cursor.messages: print(msg.text) if cursor.description: - print('\t'.join(col[0] for col in cursor.description)) - print('-' * 80) + print("\t".join(col[0] for col in cursor.description)) + print("-" * 80) count = 0 for row in cursor: - print('\t'.join(str(col) for col in row)) + print("\t".join(str(col) for col in row)) count += 1 - print('-' * 80) + print("-" * 80) print("Returned {} rows".format(count)) print() -main() \ No newline at end of file +main() diff --git a/src/pytds/__init__.py b/src/pytds/__init__.py index bda2e78..01b6492 100644 --- a/src/pytds/__init__.py +++ b/src/pytds/__init__.py @@ -20,35 +20,51 @@ from .connection_pool import connection_pool, PoolKeyType from .login import KerberosAuth, SspiAuth, AuthProtocol from .row_strategies import * -from .tds import ( - _TdsSocket, tds7_get_instances, - _TdsLogin -) +from .tds import _TdsSocket, tds7_get_instances, _TdsLogin from . import tds_base from .tds_base import ( - Error, LoginError, DatabaseError, ProgrammingError, - IntegrityError, DataError, InternalError, - InterfaceError, TimeoutError, OperationalError, - NotSupportedError, Warning, ClosedConnectionError, + Error, + LoginError, + DatabaseError, + ProgrammingError, + IntegrityError, + DataError, + InternalError, + InterfaceError, + TimeoutError, + OperationalError, + NotSupportedError, + Warning, + ClosedConnectionError, Column, - PreLoginEnc, _create_exception_by_message) + PreLoginEnc, + _create_exception_by_message, +) from .tds_session import _TdsSession -from .tds_types import ( - TableValuedParam, Binary -) +from .tds_types import TableValuedParam, Binary from .tds_base import ( - ROWID, DECIMAL, STRING, BINARY, NUMBER, DATETIME, INTEGER, REAL, XML, output, default + ROWID, + DECIMAL, + STRING, + BINARY, + NUMBER, + DATETIME, + INTEGER, + REAL, + XML, + output, + default, ) from . 
import tls -import pkg_resources # type: ignore # fix later +import pkg_resources # type: ignore # fix later -__author__ = 'Mikhail Denisenko <denisenkom@gmail.com>' +__author__ = "Mikhail Denisenko <denisenkom@gmail.com>" try: - __version__ = pkg_resources.get_distribution('python-tds').version + __version__ = pkg_resources.get_distribution("python-tds").version except: __version__ = "DEV" @@ -56,24 +72,28 @@ def _ver_to_int(ver): - res = ver.split('.') + res = ver.split(".") if len(res) < 2: - logger.warning('Invalid version {}, it should have 2 parts at least separated by "."'.format(ver)) + logger.warning( + 'Invalid version {}, it should have at least 2 parts separated by "."'.format( + ver + ) + ) return 0 - maj, minor, _ = ver.split('.') + maj, minor, _ = ver.split(".") return (int(maj) << 24) + (int(minor) << 16) intversion = _ver_to_int(__version__) #: Compliant with DB SIG 2.0 -apilevel = '2.0' +apilevel = "2.0" #: Module may be shared, but not connections threadsafety = 1 #: This module uses extended python format codes -paramstyle = 'pyformat' +paramstyle = "pyformat" class Cursor(Protocol, Iterable): @@ -81,6 +101,7 @@ class Cursor(Protocol, Iterable): This class defines an interface for cursor classes. It is implemented by MARS and non-MARS cursor classes. """ + def __enter__(self) -> Cursor: ... @@ -90,7 +111,11 @@ def __exit__(self, *args) -> None: def get_proc_outputs(self) -> list[Any]: ... - def callproc(self, procname: tds_base.InternalProc | str, parameters: dict[str, Any] | tuple[Any, ...] = ()) -> list[Any]: + def callproc( + self, + procname: tds_base.InternalProc | str, + parameters: dict[str, Any] | tuple[Any, ...] = (), + ) -> list[Any]: ... @property @@ -114,13 +139,25 @@ def cancel(self) -> None: def close(self) -> None: ... - def execute(self, operation: str, params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = ()) -> Cursor: + def execute( + self, + operation: str, + params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = (), + ) -> Cursor: ... - def executemany(self, operation: str, params_seq: Iterable[list[Any] | tuple[Any, ...] | dict[str, Any]]) -> None: + def executemany( + self, + operation: str, + params_seq: Iterable[list[Any] | tuple[Any, ...] | dict[str, Any]], + ) -> None: ... - def execute_scalar(self, query_string: str, params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None) -> Any: + def execute_scalar( + self, + query_string: str, + params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None, + ) -> Any: ... def nextset(self) -> bool | None: @@ -138,7 +175,9 @@ def set_stream(self, column_idx: int, stream) -> None: ... @property - def messages(self) -> list[Tuple[Type, IntegrityError | ProgrammingError | OperationalError]] | None: + def messages( + self + ) -> list[Tuple[Type, IntegrityError | ProgrammingError | OperationalError]] | None: ... @property @@ -163,21 +202,21 @@ def setoutputsize(size=None, column=0) -> None: ...
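The Cursor protocol above is the DB-API 2.0 surface shared by the MARS and non-MARS cursor classes, and paramstyle = "pyformat" means query parameters are bound with %s placeholders. A minimal usage sketch of that surface, assuming a reachable server (the host, database and credentials below are placeholders, not part of this change):

    import pytds

    # Placeholder connection details -- substitute a real server and login.
    conn = pytds.connect(dsn="localhost", database="master", user="sa", password="secret")
    try:
        with conn.cursor() as cur:  # cursors implement __enter__/__exit__, per the protocol above
            cur.execute("select name from sysobjects where type = %s", ("U",))
            print(cur.fetchall())
            # execute_scalar returns the first column of the first row of the result
            print(cur.execute_scalar("select count(*) from sysobjects"))
    finally:
        conn.close()

The copy_to method declared next extends this protocol with bulk loading.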
def copy_to( - self, - file: Iterable[str] | None = None, - table_or_view: str | None = None, - sep: str = '\t', - columns: Iterable[Column | str] | None = None, - check_constraints: bool = False, - fire_triggers: bool = False, - keep_nulls: bool = False, - kb_per_batch: int | None = None, - rows_per_batch: int | None = None, - order: str | None = None, - tablock: bool = False, - schema: str | None = None, - null_string: str | None = None, - data: Iterable[Tuple[Any, ...]] | None = None + self, + file: Iterable[str] | None = None, + table_or_view: str | None = None, + sep: str = "\t", + columns: Iterable[Column | str] | None = None, + check_constraints: bool = False, + fire_triggers: bool = False, + keep_nulls: bool = False, + kb_per_batch: int | None = None, + rows_per_batch: int | None = None, + order: str | None = None, + tablock: bool = False, + schema: str | None = None, + null_string: str | None = None, + data: Iterable[Tuple[Any, ...]] | None = None, ): ... @@ -187,6 +226,7 @@ class Connection(Protocol): This class defines interface for connection object according to DBAPI specification. This interface is implemented by MARS and non-MARS connection classes. """ + @property def autocommit(self) -> bool: ... @@ -235,10 +275,10 @@ class BaseConnection(Connection): _connection_closed_exception = InterfaceError("Connection closed") def __init__( - self, - pooling: bool, - key: PoolKeyType, - tds_socket: _TdsSocket, + self, + pooling: bool, + key: PoolKeyType, + tds_socket: _TdsSocket, ) -> None: # _tds_socket is set to None when connection is closed self._tds_socket: _TdsSocket | None = tds_socket @@ -262,7 +302,7 @@ def as_dict(self) -> bool: def as_dict(self, value: bool) -> None: warnings.warn( "setting as_dict property on the active connection, instead create connection with needed row_strategy", - DeprecationWarning + DeprecationWarning, ) if not self._tds_socket: raise self._connection_closed_exception @@ -281,8 +321,7 @@ def autocommit_state(self) -> bool: return self._tds_socket.main_session.autocommit def set_autocommit(self, value: bool) -> None: - """ An alias for `autocommit`, provided for compatibility with ADO dbapi - """ + """An alias for `autocommit`, provided for compatibility with ADO dbapi""" if not self._tds_socket: raise self._connection_closed_exception self._tds_socket.main_session.autocommit = value @@ -365,7 +404,7 @@ def rollback(self) -> None: self._tds_socket.main_session.rollback(cont=True) def close(self) -> None: - """ Close connection to an MS SQL Server. + """Close connection to an MS SQL Server. This function tries to close the connection and free all memory used. It can be called more than once in a row. No exception is raised in @@ -374,7 +413,9 @@ def close(self) -> None: if self._tds_socket: logger.debug("Closing connection") if self._pooling: - connection_pool.add(self._key, (self._tds_socket, self._tds_socket.main_session)) + connection_pool.add( + self._key, (self._tds_socket, self._tds_socket.main_session) + ) else: self._tds_socket.close() logger.debug("Closing all cursors which were opened by connection") @@ -388,8 +429,8 @@ class MarsConnection(BaseConnection): MARS connection class, this object is created by calling :func:`connect` with use_mars parameter set to True.
""" - def __init__(self, pooling: bool, key: PoolKeyType, - tds_socket: _TdsSocket): + + def __init__(self, pooling: bool, key: PoolKeyType, tds_socket: _TdsSocket): super().__init__(pooling=pooling, key=key, tds_socket=tds_socket) @property @@ -421,9 +462,8 @@ class NonMarsConnection(BaseConnection): Non-MARS connection class, this object should be created by calling :func:`connect` with use_mars parameter set to False. """ - def __init__(self, pooling: bool, key: PoolKeyType, - tds_socket: _TdsSocket): + def __init__(self, pooling: bool, key: PoolKeyType, tds_socket: _TdsSocket): super().__init__(pooling=pooling, key=key, tds_socket=tds_socket) self._active_cursor: NonMarsCursor | None = None @@ -458,6 +498,7 @@ class BaseCursor(Cursor, Iterator): There are two actual cursor classes: one for MARS connections and one for non-MARS connections. """ + _cursor_closed_exception = InterfaceError("Cursor is closed") def __init__(self, connection: Connection, session: _TdsSession): @@ -472,7 +513,7 @@ def __init__(self, connection: Connection, session: _TdsSession): def connection(self) -> Connection | None: warnings.warn( "connection property is deprecated on the cursor object and will be removed in future releases", - DeprecationWarning + DeprecationWarning, ) return self._connection @@ -499,7 +540,11 @@ def get_proc_outputs(self) -> list[Any]: raise self._cursor_closed_exception return self._session.get_proc_outputs() - def callproc(self, procname: tds_base.InternalProc | str, parameters: dict[str, Any] | tuple[Any, ...] = ()) -> list[Any]: + def callproc( + self, + procname: tds_base.InternalProc | str, + parameters: dict[str, Any] | tuple[Any, ...] = (), + ) -> list[Any]: """ Call a stored procedure with the given name. @@ -518,13 +563,12 @@ def callproc(self, procname: tds_base.InternalProc | str, parameters: dict[str, @property def return_value(self) -> int | None: - """ Alias to :func:`get_proc_return_status` - """ + """Alias to :func:`get_proc_return_status`""" return self.get_proc_return_status() @property def spid(self) -> int: - """ MSSQL Server's session ID (SPID) + """MSSQL Server's session ID (SPID) It can be used to correlate connections between client and server logs. """ @@ -545,7 +589,7 @@ def _set_tzinfo_factory(self, tzinfo_factory: TzInfoFactoryType | None) -> None: tzinfo_factory = property(_get_tzinfo_factory, _set_tzinfo_factory) def get_proc_return_status(self) -> int | None: - """ Last executed stored procedure's return value + """Last executed stored procedure's return value Returns integer value returned by `RETURN` statement from last executed stored procedure. If no value was not returned or no stored procedure was executed return `None`. @@ -555,8 +599,7 @@ def get_proc_return_status(self) -> int | None: return self._session.get_proc_return_status() def cancel(self) -> None: - """ Cancel currently executing statement or stored procedure call - """ + """Cancel currently executing statement or stored procedure call""" if self._session is None: return self._session.cancel_if_pending() @@ -569,10 +612,14 @@ def close(self) -> None: self._session = None self._connection = None - T = TypeVar('T') + T = TypeVar("T") - def execute(self, operation: str, params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = ()) -> BaseCursor: - """ Execute an SQL query + def execute( + self, + operation: str, + params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = (), + ) -> BaseCursor: + """Execute an SQL query Optionally query can be executed with parameters. 
To make a parameterized query, use `%s` in the query to denote a parameter @@ -602,7 +649,11 @@ def execute(self, operation: str, params: list[Any] | tuple[Any, ...] | dict[str # for compatibility with pyodbc return self - def executemany(self, operation: str, params_seq: Iterable[list[Any] | tuple[Any, ...] | dict[str, Any]]) -> None: + def executemany( + self, + operation: str, + params_seq: Iterable[list[Any] | tuple[Any, ...] | dict[str, Any]], + ) -> None: """ Execute the same SQL query multiple times, once for each parameter set in the `params_seq` list. """ @@ -610,7 +661,11 @@ def executemany(self, operation: str, params_seq: Iterable[list[Any] | tuple[Any raise self._cursor_closed_exception self._session.executemany(operation=operation, params_seq=params_seq) - def execute_scalar(self, query_string: str, params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None) -> Any: + def execute_scalar( + self, + query_string: str, + params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None, + ) -> Any: """ This method executes an SQL query and returns the first column of the first row of the result. @@ -632,7 +687,7 @@ def execute_scalar(self, query_string: str, params: list[Any] | tuple[Any, ...] return self._session.execute_scalar(query_string, params) def nextset(self) -> bool | None: - """ Move to next recordset in batch statement, all rows of current recordset are + """Move to next recordset in batch statement, all rows of current recordset are discarded if present. :returns: true if successful or ``None`` when there are no more recordsets @@ -643,7 +698,7 @@ @property def rowcount(self) -> int: - """ Number of rows affected by previous statement + """Number of rows affected by previous statement :returns: -1 if this information was not supplied by the server """ @@ -653,8 +708,7 @@ @property def description(self): - """ Cursor description, see http://legacy.python.org/dev/peps/pep-0249/#description - """ + """Cursor description, see http://legacy.python.org/dev/peps/pep-0249/#description""" if self._session is None: return None res = self._session.res_info @@ -696,15 +750,18 @@ def set_stream(self, column_idx: int, stream) -> None: raise self._cursor_closed_exception res_info = self._session.res_info if not res_info: - raise ValueError('No result set is active') + raise ValueError("No result set is active") if len(res_info.columns) <= column_idx or column_idx < 0: - raise ValueError('Invalid value for column_idx') - res_info.columns[column_idx].serializer.set_chunk_handler(pytds.tds_types._StreamChunkedHandler(stream)) + raise ValueError("Invalid value for column_idx") + res_info.columns[column_idx].serializer.set_chunk_handler( + pytds.tds_types._StreamChunkedHandler(stream) + ) @property - def messages(self) -> list[Tuple[Type, IntegrityError | ProgrammingError | OperationalError]] | None: - """ Messages generated by server, see http://legacy.python.org/dev/peps/pep-0249/#cursor-messages - """ + def messages( + self + ) -> list[Tuple[Type, IntegrityError | ProgrammingError | OperationalError]] | None: + """Messages generated by server, see http://legacy.python.org/dev/peps/pep-0249/#cursor-messages""" if self._session: result = [] for msg in self._session.messages: @@ -716,8 +773,7 @@ def messages(self) -> list[Tuple[Type, IntegrityError | ProgrammingError | Opera @property def native_description(self): - """ todo document - """ + """todo document""" if self._session is None: return None res = self._session.res_info @@ -727,7
+783,7 @@ def native_description(self): return None def fetchone(self) -> Any: - """ Fetch next row. + """Fetch next row. Returns row using currently configured factory, or ``None`` if there are no more rows """ @@ -736,7 +792,7 @@ return self._session.fetchone() def fetchmany(self, size=None) -> list[Any]: - """ Fetch next N rows + """Fetch next N rows :param size: Maximum number of rows to return, default value is cursor.arraysize :returns: List of rows @@ -755,7 +811,7 @@ return rows def fetchall(self) -> list[Any]: - """ Fetch all remaining rows + """Fetch all remaining rows Do not use this if you expect a large number of rows returned by the server, since this method will load all rows into memory. It is more efficient @@ -786,23 +842,23 @@ def setoutputsize(size=None, column=0) -> None: pass def copy_to( - self, - file: Iterable[str] | None = None, - table_or_view: str | None = None, - sep: str = '\t', - columns: Iterable[Column | str] | None = None, - check_constraints: bool = False, - fire_triggers: bool = False, - keep_nulls: bool = False, - kb_per_batch: int | None = None, - rows_per_batch: int | None = None, - order: str | None = None, - tablock: bool = False, - schema: str | None = None, - null_string: str | None = None, - data: Iterable[collections.abc.Sequence[Any]] | None = None + self, + file: Iterable[str] | None = None, + table_or_view: str | None = None, + sep: str = "\t", + columns: Iterable[Column | str] | None = None, + check_constraints: bool = False, + fire_triggers: bool = False, + keep_nulls: bool = False, + kb_per_batch: int | None = None, + rows_per_batch: int | None = None, + order: str | None = None, + tablock: bool = False, + schema: str | None = None, + null_string: str | None = None, + data: Iterable[collections.abc.Sequence[Any]] | None = None, ): - """ *Experimental*. Efficiently load data to database from file using ``BULK INSERT`` operation + """*Experimental*. Efficiently load data into a database from a file using the ``BULK INSERT`` operation :param file: Source file-like object, should be in csv format. Specify either this or data, not both.
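As the parameter list above and the implementation in the next hunk show, copy_to takes rows either from a csv file-like object (file) or directly from an iterable of row tuples (data), and plain string column names are wrapped into nullable NVarCharType(size=4000) columns. A short sketch under those assumptions; the cursor cur, table name and rows are hypothetical:

    # Hypothetical table and rows, for illustration only.
    cur.copy_to(
        table_or_view="test_table",
        columns=("id", "name"),         # plain names become nullable NVARCHAR(4000) columns
        data=[(1, "one"), (2, "two")],  # pass either `data` or `file`, not both
    )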
@@ -851,7 +907,7 @@ def copy_to( """ if self._session is None: raise self._cursor_closed_exception - #conn = self._conn() + # conn = self._conn() rows: typing.Iterable[collections.abc.Sequence[typing.Any]] if data is None: if file is None: @@ -859,6 +915,7 @@ def copy_to( reader = csv.reader(file, delimiter=sep) if null_string is not None: + def _convert_null_strings(csv_reader): for row in csv_reader: yield [r if r != null_string else None for r in row] @@ -871,39 +928,53 @@ def _convert_null_strings(csv_reader): obj_name = tds_base.tds_quote_id(table_or_view) if schema: - obj_name = f'{tds_base.tds_quote_id(schema)}.{obj_name}' + obj_name = f"{tds_base.tds_quote_id(schema)}.{obj_name}" if columns: metadata = [] for column in columns: if isinstance(column, Column): metadata.append(column) else: - metadata.append(Column(name=column, type=NVarCharType(size=4000), flags=Column.fNullable)) + metadata.append( + Column( + name=column, + type=NVarCharType(size=4000), + flags=Column.fNullable, + ) + ) else: - self.execute(f'select top 1 * from {obj_name} where 1<>1') - metadata = [Column(name=col[0], type=NVarCharType(size=4000), flags=Column.fNullable if col[6] else 0) - for col in self.description] - col_defs = ','.join(f'{tds_base.tds_quote_id(col.column_name)} {col.type.get_declaration()}' - for col in metadata) + self.execute(f"select top 1 * from {obj_name} where 1<>1") + metadata = [ + Column( + name=col[0], + type=NVarCharType(size=4000), + flags=Column.fNullable if col[6] else 0, + ) + for col in self.description + ] + col_defs = ",".join( + f"{tds_base.tds_quote_id(col.column_name)} {col.type.get_declaration()}" + for col in metadata + ) with_opts = [] if check_constraints: - with_opts.append('CHECK_CONSTRAINTS') + with_opts.append("CHECK_CONSTRAINTS") if fire_triggers: - with_opts.append('FIRE_TRIGGERS') + with_opts.append("FIRE_TRIGGERS") if keep_nulls: - with_opts.append('KEEP_NULLS') + with_opts.append("KEEP_NULLS") if kb_per_batch: - with_opts.append('KILOBYTES_PER_BATCH = {0}'.format(kb_per_batch)) + with_opts.append("KILOBYTES_PER_BATCH = {0}".format(kb_per_batch)) if rows_per_batch: - with_opts.append('ROWS_PER_BATCH = {0}'.format(rows_per_batch)) + with_opts.append("ROWS_PER_BATCH = {0}".format(rows_per_batch)) if order: - with_opts.append('ORDER({0})'.format(','.join(order))) + with_opts.append("ORDER({0})".format(",".join(order))) if tablock: - with_opts.append('TABLOCK') - with_part = '' + with_opts.append("TABLOCK") + with_part = "" if with_opts: - with_part = 'WITH ({0})'.format(','.join(with_opts)) - operation = 'INSERT BULK {0}({1}) {2}'.format(obj_name, col_defs, with_part) + with_part = "WITH ({0})".format(",".join(with_opts)) + operation = "INSERT BULK {0}({1}) {2}".format(obj_name, col_defs, with_part) self.execute(operation) self._session.submit_bulk(metadata, rows) self._session.process_simple_request() @@ -916,6 +987,7 @@ class NonMarsCursor(BaseCursor): Non-MARS connections allow only one cursor to be active at a given time. """ + def __init__(self, connection: NonMarsConnection, session: _TdsSession): super().__init__(connection=connection, session=session) @@ -927,6 +999,7 @@ class _MarsCursor(BaseCursor): MARS connections allow multiple cursors to be active at the same time. 
""" + def __init__(self, connection: MarsConnection, session: _TdsSession): super().__init__( connection=connection, @@ -936,7 +1009,7 @@ def __init__(self, connection: MarsConnection, session: _TdsSession): @property def spid(self) -> int: # not thread safe for connection - return self.execute_scalar('select @@SPID') + return self.execute_scalar("select @@SPID") def close(self) -> None: """ @@ -949,18 +1022,26 @@ def close(self) -> None: self._connection = None -def _resolve_instance_port(server: Any, port: int, instance: str, timeout: float = 5) -> int: +def _resolve_instance_port( + server: Any, port: int, instance: str, timeout: float = 5 +) -> int: if instance and not port: - logger.info('querying %s for list of instances', server) + logger.info("querying %s for list of instances", server) instances = tds7_get_instances(server, timeout=timeout) if not instances: - raise RuntimeError("Querying list of instances failed, returned value has invalid format") + raise RuntimeError( + "Querying list of instances failed, returned value has invalid format" + ) if instance not in instances: - raise LoginError("Instance {0} not found on server {1}".format(instance, server)) + raise LoginError( + "Instance {0} not found on server {1}".format(instance, server) + ) instdict = instances[instance] - if 'tcp' not in instdict: - raise LoginError("Instance {0} doen't have tcp connections enabled".format(instance)) - port = int(instdict['tcp']) + if "tcp" not in instdict: + raise LoginError( + "Instance {0} doen't have tcp connections enabled".format(instance) + ) + port = int(instdict["tcp"]) return port or 1433 @@ -979,11 +1060,16 @@ def _parse_server(server: str) -> Tuple[str, str]: # map to servers deques, used to store active/passive servers # between calls to connect function # deques are used because they can be rotated -_servers_deques: dict[Tuple[Tuple[Tuple[str, int | None, str], ...], str | None], deque[Tuple[Any, int | None, str]]] = {} +_servers_deques: dict[ + Tuple[Tuple[Tuple[str, int | None, str], ...], str | None], + deque[Tuple[Any, int | None, str]], +] = {} -def _get_servers_deque(servers: Tuple[Tuple[str, int | None, str], ...], database: str | None): - """ Returns deque of servers for given tuple of servers and +def _get_servers_deque( + servers: Tuple[Tuple[str, int | None, str], ...], database: str | None +): + """Returns deque of servers for given tuple of servers and database name. 
This deque has the active server at the beginning; if the first server is not accessible at the moment, the deque will be rotated, @@ -1006,48 +1092,48 @@ def _parse_connection_string(connstr: str) -> dict[str, str]: Returns normalized dictionary of connection string parameters """ res = {} - for item in connstr.split(';'): + for item in connstr.split(";"): item = item.strip() if not item: continue - key, value = item.split('=', 1) - key = key.strip().lower().replace(' ', '_') + key, value = item.split("=", 1) + key = key.strip().lower().replace(" ", "_") value = value.strip() res[key] = value return res def connect( - dsn: str | None = None, - database: str | None = None, - user: str | None = None, - password: str | None = None, - timeout: float | None = None, - login_timeout: float = 15, - as_dict: bool | None = None, - appname: str | None = None, - port: int | None = None, - tds_version: int = tds_base.TDS74, - autocommit: bool = False, - blocksize: int = 4096, - use_mars: bool = False, - auth: AuthProtocol | None = None, - readonly: bool = False, - load_balancer: tds_base.LoadBalancer | None = None, - use_tz: datetime.tzinfo | None = None, - bytes_to_unicode: bool = True, - row_strategy: RowStrategy | None = None, - failover_partner: str | None = None, - server: str | None = None, - cafile: str | None = None, - sock: socket.socket | None = None, - validate_host: bool = True, - enc_login_only: bool = False, - disable_connect_retry: bool = False, - pooling: bool = False, - use_sso: bool = False, - isolation_level: int = 0, - ): + dsn: str | None = None, + database: str | None = None, + user: str | None = None, + password: str | None = None, + timeout: float | None = None, + login_timeout: float = 15, + as_dict: bool | None = None, + appname: str | None = None, + port: int | None = None, + tds_version: int = tds_base.TDS74, + autocommit: bool = False, + blocksize: int = 4096, + use_mars: bool = False, + auth: AuthProtocol | None = None, + readonly: bool = False, + load_balancer: tds_base.LoadBalancer | None = None, + use_tz: datetime.tzinfo | None = None, + bytes_to_unicode: bool = True, + row_strategy: RowStrategy | None = None, + failover_partner: str | None = None, + server: str | None = None, + cafile: str | None = None, + sock: socket.socket | None = None, + validate_host: bool = True, + enc_login_only: bool = False, + disable_connect_retry: bool = False, + pooling: bool = False, + use_sso: bool = False, + isolation_level: int = 0, +): """ Opens connection to the database @@ -1107,33 +1193,35 @@ :returns: An instance of :class:`Connection` """ if use_sso and auth: - raise ValueError('use_sso cannot be used with auth parameter defined') + raise ValueError("use_sso cannot be used with auth parameter defined") login = _TdsLogin() login.client_host_name = socket.gethostname()[:128] login.library = "Python TDS Library" - login.user_name = user or '' - login.password = password or '' - login.app_name = appname or 'pytds' + login.user_name = user or "" + login.password = password or "" + login.app_name = appname or "pytds" login.port = port - login.language = '' # use database default - login.attach_db_file = '' + login.language = "" # use database default + login.attach_db_file = "" login.tds_version = tds_version if tds_version < tds_base.TDS70: - raise ValueError('This TDS version is not supported') - login.database = database or '' + raise ValueError("This TDS version is not supported") + login.database = database or "" login.bulk_copy = False login.client_lcid =
lcid.LANGID_ENGLISH_US login.use_mars = use_mars login.pid = os.getpid() - login.change_password = '' + login.change_password = "" login.client_id = uuid.getnode() # client mac address login.cafile = cafile login.validate_host = validate_host login.enc_login_only = enc_login_only if cafile: if not tls.OPENSSL_AVAILABLE: - raise ValueError("You are trying to use encryption but pyOpenSSL does not work, you probably " - "need to install it first") + raise ValueError( + "You are trying to use encryption but pyOpenSSL does not work, you probably " + "need to install it first" + ) login.tls_ctx = tls.create_context(cafile) if login.enc_login_only: login.enc_flag = PreLoginEnc.ENCRYPT_OFF @@ -1166,16 +1254,20 @@ raise ValueError("Both server and dsn shouldn't be specified") if server: - warnings.warn("server parameter is deprecated, use dsn instead", DeprecationWarning) + warnings.warn( + "server parameter is deprecated, use dsn instead", DeprecationWarning + ) dsn = server if load_balancer and failover_partner: - raise ValueError("Both load_balancer and failover_partner shoudln't be specified") + raise ValueError( + "Both load_balancer and failover_partner shouldn't be specified" + ) servers: list[Tuple[str, int | None]] = [] if load_balancer: servers += ((srv, None) for srv in load_balancer.choose()) else: - servers += [(dsn or 'localhost', port)] + servers += [(dsn or "localhost", port)] if failover_partner: servers.append((failover_partner, port)) @@ -1189,6 +1281,7 @@ if use_sso: spn = "MSSQLSvc@{}:{}".format(parsed_servers[0][0], parsed_servers[0][1]) from . import login as pytds_login + try: login.auth = pytds_login.SspiAuth(spn=spn) except ImportError: @@ -1217,24 +1310,26 @@ ) from .tz import FixedOffsetTimezone + tzinfo_factory = None if use_tz is None else FixedOffsetTimezone - #conn = Connection( + # conn = Connection( # login_info=login, # pooling=pooling, # key=key, # use_tz=use_tz, # autocommit=autocommit, # tzinfo_factory=tzinfo_factory - #) + # ) - assert row_strategy is None or as_dict is None,\ 'Both row_startegy and as_dict were specified, you should use either one or another' + assert ( + row_strategy is None or as_dict is None + ), "Both row_strategy and as_dict were specified, you should use one or the other" if as_dict: row_strategy = dict_row_strategy elif row_strategy is not None: row_strategy = row_strategy else: - row_strategy = tuple_row_strategy # default row strategy + row_strategy = tuple_row_strategy # default row strategy if disable_connect_retry: first_try_time = login.connect_timeout @@ -1292,11 +1387,11 @@ def ex_handler(ex: Exception) -> None: # 2) Login failed for user '' # in this case we want to retry if ex.msg_no in ( - 18456, # login failed - 18486, # account is locked - 18487, # password expired - 18488, # password should be changed - 18452, # login from untrusted domain + 18456, # login failed + 18486, # account is locked + 18487, # password expired + 18488, # password should be changed + 18452, # login from untrusted domain ): raise ex else: @@ -1311,30 +1406,26 @@ def _connect( - login: _TdsLogin, - host: str, - port: int, - instance: str, - timeout: float, - pooling: bool, - key: PoolKeyType, - autocommit: bool, - isolation_level: int, - tzinfo_factory: TzInfoFactoryType | None, - sock: socket.socket | None, - use_tz: datetime.tzinfo | None, - row_strategy: RowStrategy, + login: _TdsLogin, + host: str, + port: int, + instance: str, + timeout: float, + pooling: bool,
+ key: PoolKeyType, + autocommit: bool, + isolation_level: int, + tzinfo_factory: TzInfoFactoryType | None, + sock: socket.socket | None, + use_tz: datetime.tzinfo | None, + row_strategy: RowStrategy, ) -> BaseConnection: try: login.server_name = host login.instance_name = instance - port = _resolve_instance_port( - host, - port, - instance, - timeout=timeout) + port = _resolve_instance_port(host, port, instance, timeout=timeout) if not sock: - logger.info('Opening socket to %s:%d', host, port) + logger.info("Opening socket to %s:%d", host, port) sock = socket.create_connection((host, port), timeout) except Exception as e: raise LoginError(f"Cannot connect to server '{host}': {e}", e) @@ -1362,15 +1453,21 @@ def _connect( sock.close() ### Change SPN once route exists from . import login as pytds_login + if isinstance(login.auth, pytds_login.SspiAuth): route_spn = f"MSSQLSvc@{host}:{port}" - login.auth = pytds_login.SspiAuth(user_name=login.user_name, password=login.password, - server_name=host, port=port, spn=route_spn) + login.auth = pytds_login.SspiAuth( + user_name=login.user_name, + password=login.password, + server_name=host, + port=port, + spn=route_spn, + ) return _connect( login=login, - host=route['server'], - port=route['port'], + host=route["server"], + port=route["port"], instance=instance, timeout=timeout, pooling=pooling, @@ -1402,15 +1499,15 @@ def _connect( raise -T = TypeVar('T') +T = TypeVar("T") def exponential_backoff( - work: Callable[[float], T], - ex_handler: Callable[[Exception], None], - max_time_sec: float, - first_attempt_time_sec: float, - backoff_factor: float = 2, + work: Callable[[float], T], + ex_handler: Callable[[Exception], None], + max_time_sec: float, + first_attempt_time_sec: float, + backoff_factor: float = 2, ) -> T: try_time = first_attempt_time_sec last_error: Exception | None @@ -1438,17 +1535,35 @@ def DateFromTicks(ticks: float) -> datetime.date: return datetime.date.fromtimestamp(ticks) -def Time(hour: int, minute: int, second: int, microsecond: int = 0, tzinfo: datetime.tzinfo | None = None) -> datetime.time: +def Time( + hour: int, + minute: int, + second: int, + microsecond: int = 0, + tzinfo: datetime.tzinfo | None = None, +) -> datetime.time: return datetime.time(hour, minute, second, microsecond, tzinfo) def TimeFromTicks(ticks: float) -> datetime.time: import time + return Time(*time.localtime(ticks)[3:6]) -def Timestamp(year: int, month: int, day: int, hour: int, minute: int, second: int, microseconds: int = 0, tzinfo: datetime.tzinfo | None = None) -> datetime.datetime: - return datetime.datetime(year, month, day, hour, minute, second, microseconds, tzinfo) +def Timestamp( + year: int, + month: int, + day: int, + hour: int, + minute: int, + second: int, + microseconds: int = 0, + tzinfo: datetime.tzinfo | None = None, +) -> datetime.datetime: + return datetime.datetime( + year, month, day, hour, minute, second, microseconds, tzinfo + ) def TimestampFromTicks(ticks: float) -> datetime.datetime: diff --git a/src/pytds/collate.py b/src/pytds/collate.py index 32702bf..9b95add 100644 --- a/src/pytds/collate.py +++ b/src/pytds/collate.py @@ -8,7 +8,7 @@ TDS_CHARSET_UNICODE = 5 -ucs2_codec = codecs.lookup('utf_16_le') +ucs2_codec = codecs.lookup("utf_16_le") def sortid2charset(sort_id): @@ -18,175 +18,264 @@ def sortid2charset(sort_id): # and from " NLS Information for Microsoft Windows XP" # if sql_collate in ( - 30, # SQL_Latin1_General_CP437_BIN - 31, # SQL_Latin1_General_CP437_CS_AS - 32, # SQL_Latin1_General_CP437_CI_AS - 33, # 
SQL_Latin1_General_Pref_CP437_CI_AS - 34): # SQL_Latin1_General_CP437_CI_AI - return 'CP437' + 30, # SQL_Latin1_General_CP437_BIN + 31, # SQL_Latin1_General_CP437_CS_AS + 32, # SQL_Latin1_General_CP437_CI_AS + 33, # SQL_Latin1_General_Pref_CP437_CI_AS + 34, + ): # SQL_Latin1_General_CP437_CI_AI + return "CP437" elif sql_collate in ( - 40, # SQL_Latin1_General_CP850_BIN - 41, # SQL_Latin1_General_CP850_CS_AS - 42, # SQL_Latin1_General_CP850_CI_AS - 43, # SQL_Latin1_General_Pref_CP850_CI_AS - 44, # SQL_Latin1_General_CP850_CI_AI - 49, # SQL_1xCompat_CP850_CI_AS - 55, # SQL_AltDiction_CP850_CS_AS - 56, # SQL_AltDiction_Pref_CP850_CI_AS - 57, # SQL_AltDiction_CP850_CI_AI - 58, # SQL_Scandinavian_Pref_CP850_CI_AS - 59, # SQL_Scandinavian_CP850_CS_AS - 60, # SQL_Scandinavian_CP850_CI_AS - 61): # SQL_AltDiction_CP850_CI_AS - return 'CP850' + 40, # SQL_Latin1_General_CP850_BIN + 41, # SQL_Latin1_General_CP850_CS_AS + 42, # SQL_Latin1_General_CP850_CI_AS + 43, # SQL_Latin1_General_Pref_CP850_CI_AS + 44, # SQL_Latin1_General_CP850_CI_AI + 49, # SQL_1xCompat_CP850_CI_AS + 55, # SQL_AltDiction_CP850_CS_AS + 56, # SQL_AltDiction_Pref_CP850_CI_AS + 57, # SQL_AltDiction_CP850_CI_AI + 58, # SQL_Scandinavian_Pref_CP850_CI_AS + 59, # SQL_Scandinavian_CP850_CS_AS + 60, # SQL_Scandinavian_CP850_CI_AS + 61, + ): # SQL_AltDiction_CP850_CI_AS + return "CP850" elif sql_collate in ( - 80, # SQL_Latin1_General_1250_BIN - 81, # SQL_Latin1_General_CP1250_CS_AS - 82, # SQL_Latin1_General_Cp1250_CI_AS_KI_WI - 83, # SQL_Czech_Cp1250_CS_AS_KI_WI - 84, # SQL_Czech_Cp1250_CI_AS_KI_WI - 85, # SQL_Hungarian_Cp1250_CS_AS_KI_WI - 86, # SQL_Hungarian_Cp1250_CI_AS_KI_WI - 87, # SQL_Polish_Cp1250_CS_AS_KI_WI - 88, # SQL_Polish_Cp1250_CI_AS_KI_WI - 89, # SQL_Romanian_Cp1250_CS_AS_KI_WI - 90, # SQL_Romanian_Cp1250_CI_AS_KI_WI - 91, # SQL_Croatian_Cp1250_CS_AS_KI_WI - 92, # SQL_Croatian_Cp1250_CI_AS_KI_WI - 93, # SQL_Slovak_Cp1250_CS_AS_KI_WI - 94, # SQL_Slovak_Cp1250_CI_AS_KI_WI - 95, # SQL_Slovenian_Cp1250_CS_AS_KI_WI - 96, # SQL_Slovenian_Cp1250_CI_AS_KI_WI - ): - return 'CP1250' + 80, # SQL_Latin1_General_1250_BIN + 81, # SQL_Latin1_General_CP1250_CS_AS + 82, # SQL_Latin1_General_Cp1250_CI_AS_KI_WI + 83, # SQL_Czech_Cp1250_CS_AS_KI_WI + 84, # SQL_Czech_Cp1250_CI_AS_KI_WI + 85, # SQL_Hungarian_Cp1250_CS_AS_KI_WI + 86, # SQL_Hungarian_Cp1250_CI_AS_KI_WI + 87, # SQL_Polish_Cp1250_CS_AS_KI_WI + 88, # SQL_Polish_Cp1250_CI_AS_KI_WI + 89, # SQL_Romanian_Cp1250_CS_AS_KI_WI + 90, # SQL_Romanian_Cp1250_CI_AS_KI_WI + 91, # SQL_Croatian_Cp1250_CS_AS_KI_WI + 92, # SQL_Croatian_Cp1250_CI_AS_KI_WI + 93, # SQL_Slovak_Cp1250_CS_AS_KI_WI + 94, # SQL_Slovak_Cp1250_CI_AS_KI_WI + 95, # SQL_Slovenian_Cp1250_CS_AS_KI_WI + 96, # SQL_Slovenian_Cp1250_CI_AS_KI_WI + ): + return "CP1250" elif sql_collate in ( - 104, # SQL_Latin1_General_1251_BIN - 105, # SQL_Latin1_General_CP1251_CS_AS - 106, # SQL_Latin1_General_CP1251_CI_AS - 107, # SQL_Ukrainian_Cp1251_CS_AS_KI_WI - 108, # SQL_Ukrainian_Cp1251_CI_AS_KI_WI - ): - return 'CP1251' + 104, # SQL_Latin1_General_1251_BIN + 105, # SQL_Latin1_General_CP1251_CS_AS + 106, # SQL_Latin1_General_CP1251_CI_AS + 107, # SQL_Ukrainian_Cp1251_CS_AS_KI_WI + 108, # SQL_Ukrainian_Cp1251_CI_AS_KI_WI + ): + return "CP1251" elif sql_collate in ( - 51, # SQL_Latin1_General_Cp1_CS_AS_KI_WI - 52, # SQL_Latin1_General_Cp1_CI_AS_KI_WI - 53, # SQL_Latin1_General_Pref_Cp1_CI_AS_KI_WI - 54, # SQL_Latin1_General_Cp1_CI_AI_KI_WI - 183, # SQL_Danish_Pref_Cp1_CI_AS_KI_WI - 184, # SQL_SwedishPhone_Pref_Cp1_CI_AS_KI_WI - 185, # 
SQL_SwedishStd_Pref_Cp1_CI_AS_KI_WI - 186, # SQL_Icelandic_Pref_Cp1_CI_AS_KI_WI - ): - return 'CP1252' + 51, # SQL_Latin1_General_Cp1_CS_AS_KI_WI + 52, # SQL_Latin1_General_Cp1_CI_AS_KI_WI + 53, # SQL_Latin1_General_Pref_Cp1_CI_AS_KI_WI + 54, # SQL_Latin1_General_Cp1_CI_AI_KI_WI + 183, # SQL_Danish_Pref_Cp1_CI_AS_KI_WI + 184, # SQL_SwedishPhone_Pref_Cp1_CI_AS_KI_WI + 185, # SQL_SwedishStd_Pref_Cp1_CI_AS_KI_WI + 186, # SQL_Icelandic_Pref_Cp1_CI_AS_KI_WI + ): + return "CP1252" elif sql_collate in ( - 112, # SQL_Latin1_General_1253_BIN - 113, # SQL_Latin1_General_CP1253_CS_AS - 114, # SQL_Latin1_General_CP1253_CI_AS - 120, # SQL_MixDiction_CP1253_CS_AS - 121, # SQL_AltDiction_CP1253_CS_AS - 122, # SQL_AltDiction2_CP1253_CS_AS - 124, # SQL_Latin1_General_CP1253_CI_AI - ): - return 'CP1253' + 112, # SQL_Latin1_General_1253_BIN + 113, # SQL_Latin1_General_CP1253_CS_AS + 114, # SQL_Latin1_General_CP1253_CI_AS + 120, # SQL_MixDiction_CP1253_CS_AS + 121, # SQL_AltDiction_CP1253_CS_AS + 122, # SQL_AltDiction2_CP1253_CS_AS + 124, # SQL_Latin1_General_CP1253_CI_AI + ): + return "CP1253" elif sql_collate in ( - 128, # SQL_Latin1_General_1254_BIN - 129, # SQL_Latin1_General_Cp1254_CS_AS_KI_WI - 130, # SQL_Latin1_General_Cp1254_CI_AS_KI_WI - ): - return 'CP1254' + 128, # SQL_Latin1_General_1254_BIN + 129, # SQL_Latin1_General_Cp1254_CS_AS_KI_WI + 130, # SQL_Latin1_General_Cp1254_CI_AS_KI_WI + ): + return "CP1254" elif sql_collate in ( - 136, # SQL_Latin1_General_1255_BIN - 137, # SQL_Latin1_General_CP1255_CS_AS - 138, # SQL_Latin1_General_CP1255_CI_AS - ): - return 'CP1255' + 136, # SQL_Latin1_General_1255_BIN + 137, # SQL_Latin1_General_CP1255_CS_AS + 138, # SQL_Latin1_General_CP1255_CI_AS + ): + return "CP1255" elif sql_collate in ( - 144, # SQL_Latin1_General_1256_BIN - 145, # SQL_Latin1_General_CP1256_CS_AS - 146, # SQL_Latin1_General_CP1256_CI_AS - ): - return 'CP1256' + 144, # SQL_Latin1_General_1256_BIN + 145, # SQL_Latin1_General_CP1256_CS_AS + 146, # SQL_Latin1_General_CP1256_CI_AS + ): + return "CP1256" elif sql_collate in ( - 152, # SQL_Latin1_General_1257_BIN - 153, # SQL_Latin1_General_CP1257_CS_AS - 154, # SQL_Latin1_General_CP1257_CI_AS - 155, # SQL_Estonian_Cp1257_CS_AS_KI_WI - 156, # SQL_Estonian_Cp1257_CI_AS_KI_WI - 157, # SQL_Latvian_Cp1257_CS_AS_KI_WI - 158, # SQL_Latvian_Cp1257_CI_AS_KI_WI - 159, # SQL_Lithuanian_Cp1257_CS_AS_KI_WI - 160, # SQL_Lithuanian_Cp1257_CI_AS_KI_WI - ): - return 'CP1257' + 152, # SQL_Latin1_General_1257_BIN + 153, # SQL_Latin1_General_CP1257_CS_AS + 154, # SQL_Latin1_General_CP1257_CI_AS + 155, # SQL_Estonian_Cp1257_CS_AS_KI_WI + 156, # SQL_Estonian_Cp1257_CI_AS_KI_WI + 157, # SQL_Latvian_Cp1257_CS_AS_KI_WI + 158, # SQL_Latvian_Cp1257_CI_AS_KI_WI + 159, # SQL_Lithuanian_Cp1257_CS_AS_KI_WI + 160, # SQL_Lithuanian_Cp1257_CI_AS_KI_WI + ): + return "CP1257" else: - raise Exception("Invalid collation: 0x%X" % (sql_collate, )) + raise Exception("Invalid collation: 0x%X" % (sql_collate,)) def lcid2charset(lcid): - if lcid in (0x405, - 0x40e, # 0x1040e - 0x415, 0x418, 0x41a, 0x41b, 0x41c, 0x424, - # 0x81a, seem wrong in XP table TODO check - 0x104e): - return 'CP1250' - elif lcid in (0x402, 0x419, 0x422, 0x423, 0x42f, 0x43f, - 0x440, 0x444, 0x450, - 0x81a, # ?? 
- 0x82c, 0x843, 0xc1a): - return 'CP1251' - elif lcid in (0x1007, 0x1009, 0x100a, 0x100c, 0x1407, - 0x1409, 0x140a, 0x140c, 0x1809, 0x180a, - 0x180c, 0x1c09, 0x1c0a, 0x2009, 0x200a, - 0x2409, 0x240a, 0x2809, 0x280a, 0x2c09, - 0x2c0a, 0x3009, 0x300a, 0x3409, 0x340a, - 0x380a, 0x3c0a, 0x400a, 0x403, 0x406, - 0x407, # 0x10407 - 0x409, 0x40a, 0x40b, 0x40c, 0x40f, 0x410, - 0x413, 0x414, 0x416, 0x41d, 0x421, 0x42d, - 0x436, - 0x437, # 0x10437 - 0x438, - # 0x439, ??? Unicode only - 0x43e, 0x440a, 0x441, 0x456, 0x480a, - 0x4c0a, 0x500a, 0x807, 0x809, 0x80a, - 0x80c, 0x810, 0x813, 0x814, 0x816, - 0x81d, 0x83e, 0xc07, 0xc09, 0xc0a, 0xc0c): - return 'CP1252' + if lcid in ( + 0x405, + 0x40E, # 0x1040e + 0x415, + 0x418, + 0x41A, + 0x41B, + 0x41C, + 0x424, + # 0x81a, seem wrong in XP table TODO check + 0x104E, + ): + return "CP1250" + elif lcid in ( + 0x402, + 0x419, + 0x422, + 0x423, + 0x42F, + 0x43F, + 0x440, + 0x444, + 0x450, + 0x81A, # ?? + 0x82C, + 0x843, + 0xC1A, + ): + return "CP1251" + elif lcid in ( + 0x1007, + 0x1009, + 0x100A, + 0x100C, + 0x1407, + 0x1409, + 0x140A, + 0x140C, + 0x1809, + 0x180A, + 0x180C, + 0x1C09, + 0x1C0A, + 0x2009, + 0x200A, + 0x2409, + 0x240A, + 0x2809, + 0x280A, + 0x2C09, + 0x2C0A, + 0x3009, + 0x300A, + 0x3409, + 0x340A, + 0x380A, + 0x3C0A, + 0x400A, + 0x403, + 0x406, + 0x407, # 0x10407 + 0x409, + 0x40A, + 0x40B, + 0x40C, + 0x40F, + 0x410, + 0x413, + 0x414, + 0x416, + 0x41D, + 0x421, + 0x42D, + 0x436, + 0x437, # 0x10437 + 0x438, + # 0x439, ??? Unicode only + 0x43E, + 0x440A, + 0x441, + 0x456, + 0x480A, + 0x4C0A, + 0x500A, + 0x807, + 0x809, + 0x80A, + 0x80C, + 0x810, + 0x813, + 0x814, + 0x816, + 0x81D, + 0x83E, + 0xC07, + 0xC09, + 0xC0A, + 0xC0C, + ): + return "CP1252" elif lcid == 0x408: - return 'CP1253' - elif lcid in (0x41f, 0x42c, 0x443): - return 'CP1254' - elif lcid == 0x40d: - return 'CP1255' - elif lcid in (0x1001, 0x1401, 0x1801, 0x1c01, 0x2001, - 0x2401, 0x2801, 0x2c01, 0x3001, 0x3401, - 0x3801, 0x3c01, 0x4001, 0x401, 0x420, - 0x429, 0x801, 0xc01): - return 'CP1256' - elif lcid in (0x425, 0x426, 0x427, - 0x827): # ?? - return 'CP1257' - elif lcid == 0x42a: - return 'CP1258' - elif lcid == 0x41e: - return 'CP874' + return "CP1253" + elif lcid in (0x41F, 0x42C, 0x443): + return "CP1254" + elif lcid == 0x40D: + return "CP1255" + elif lcid in ( + 0x1001, + 0x1401, + 0x1801, + 0x1C01, + 0x2001, + 0x2401, + 0x2801, + 0x2C01, + 0x3001, + 0x3401, + 0x3801, + 0x3C01, + 0x4001, + 0x401, + 0x420, + 0x429, + 0x801, + 0xC01, + ): + return "CP1256" + elif lcid in (0x425, 0x426, 0x427, 0x827): # ?? 
+ return "CP1257" + elif lcid == 0x42A: + return "CP1258" + elif lcid == 0x41E: + return "CP874" elif lcid == 0x411: # 0x10411 - return 'CP932' - elif lcid in (0x1004, - 0x804): # 0x20804 - return 'CP936' + return "CP932" + elif lcid in (0x1004, 0x804): # 0x20804 + return "CP936" elif lcid == 0x412: # 0x10412 - return 'CP949' - elif lcid in (0x1404, - 0x404, # 0x30404 - 0xc04): - return 'CP950' + return "CP949" + elif lcid in ( + 0x1404, + 0x404, # 0x30404 + 0xC04, + ): + return "CP950" else: - return 'CP1252' + return "CP1252" class Collation(object): - _coll_struct = struct.Struct('> 26 - return cls(lcid=lcid, - ignore_case=ignore_case, - ignore_accent=ignore_accent, - ignore_width=ignore_width, - ignore_kana=ignore_kana, - binary=binary, - binary2=binary2, - version=version, - sort_id=sort_id) + version = (lump & 0xF0000000) >> 26 + return cls( + lcid=lcid, + ignore_case=ignore_case, + ignore_accent=ignore_accent, + ignore_width=ignore_width, + ignore_kana=ignore_kana, + binary=binary, + binary2=binary2, + version=version, + sort_id=sort_id, + ) def pack(self): lump = 0 - lump |= self.lcid & 0xfffff - lump |= (self.version << 26) & 0xf0000000 + lump |= self.lcid & 0xFFFFF + lump |= (self.version << 26) & 0xF0000000 if self.ignore_case: lump |= self.f_ignore_case if self.ignore_accent: diff --git a/src/pytds/connection_pool.py b/src/pytds/connection_pool.py index 39b37ab..738eaf8 100644 --- a/src/pytds/connection_pool.py +++ b/src/pytds/connection_pool.py @@ -20,7 +20,8 @@ bool, Union[AuthProtocol, None], datetime.tzinfo, - bool] + bool, +] class ConnectionPool: diff --git a/src/pytds/lcid.py b/src/pytds/lcid.py index a0d5937..8757434 100644 --- a/src/pytds/lcid.py +++ b/src/pytds/lcid.py @@ -19,77 +19,211 @@ __docformat__ = "restructuredtext en" __all__ = [ - "LANGID_AFRIKAANS", "LANGID_ALBANIAN", "LANGID_AMHARIC", "LANGID_ARABIC", - "LANGID_ARABIC_ALGERIA", "LANGID_ARABIC_BAHRAIN", "LANGID_ARABIC_EGYPT", - "LANGID_ARABIC_IRAQ", "LANGID_ARABIC_JORDAN", "LANGID_ARABIC_KUWAIT", - "LANGID_ARABIC_LEBANON", "LANGID_ARABIC_LIBYA", "LANGID_ARABIC_MOROCCO", - "LANGID_ARABIC_OMAN", "LANGID_ARABIC_QATAR", "LANGID_ARABIC_SYRIA", - "LANGID_ARABIC_TUNISIA", "LANGID_ARABIC_UAE", "LANGID_ARABIC_YEMEN", - "LANGID_ARMENIAN", "LANGID_ASSAMESE", "LANGID_AZERI_CYRILLIC", - "LANGID_AZERI_LATIN", "LANGID_BASQUE", "LANGID_BELGIAN_DUTCH", - "LANGID_BELGIAN_FRENCH", "LANGID_BENGALI", "LANGID_BULGARIAN", - "LANGID_BURMESE", "LANGID_BYELORUSSIAN", "LANGID_CATALAN", - "LANGID_CHEROKEE", "LANGID_CHINESE_HONG_KONG_SAR", - "LANGID_CHINESE_MACAO_SAR", "LANGID_CHINESE_SINGAPORE", "LANGID_CROATIAN", - "LANGID_CZECH", "LANGID_DANISH", "LANGID_DIVEHI", "LANGID_DUTCH", - "LANGID_EDO", "LANGID_ENGLISH_AUS", "LANGID_ENGLISH_BELIZE", - "LANGID_ENGLISH_CANADIAN", "LANGID_ENGLISH_CARIBBEAN", - "LANGID_ENGLISH_INDONESIA", "LANGID_ENGLISH_IRELAND", - "LANGID_ENGLISH_JAMAICA", "LANGID_ENGLISH_NEW_ZEALAND", - "LANGID_ENGLISH_PHILIPPINES", "LANGID_ENGLISH_SOUTH_AFRICA", - "LANGID_ENGLISH_TRINIDAD_TOBAGO", "LANGID_ENGLISH_UK", "LANGID_ENGLISH_US", - "LANGID_ENGLISH_ZIMBABWE", "LANGID_ESTONIAN", "LANGID_FAEROESE", - "LANGID_FILIPINO", "LANGID_FINNISH", "LANGID_FRENCH", - "LANGID_FRENCH_CAMEROON", "LANGID_FRENCH_CANADIAN", - "LANGID_FRENCH_CONGO_D_R_C", "LANGID_FRENCH_COTED_IVOIRE", - "LANGID_FRENCH_HAITI", "LANGID_FRENCH_LUXEMBOURG", "LANGID_FRENCH_MALI", - "LANGID_FRENCH_MONACO", "LANGID_FRENCH_MOROCCO", "LANGID_FRENCH_REUNION", - "LANGID_FRENCH_SENEGAL", "LANGID_FRENCH_WEST_INDIES", - "LANGID_FRISIAN_NETHERLANDS", "LANGID_FULFULDE", 
"LANGID_GAELIC_IRELAND", - "LANGID_GAELIC_SCOTLAND", "LANGID_GALICIAN", "LANGID_GEORGIAN", - "LANGID_GERMAN", "LANGID_GERMAN_AUSTRIA", "LANGID_GERMAN_LIECHTENSTEIN", - "LANGID_GERMAN_LUXEMBOURG", "LANGID_GREEK", "LANGID_GUARANI", - "LANGID_GUJARATI", "LANGID_HAUSA", "LANGID_HAWAIIAN", "LANGID_HEBREW", - "LANGID_HINDI", "LANGID_HUNGARIAN", "LANGID_IBIBIO", "LANGID_ICELANDIC", - "LANGID_IGBO", "LANGID_INDONESIAN", "LANGID_INUKTITUT", "LANGID_ITALIAN", - "LANGID_JAPANESE", "LANGID_KANNADA", "LANGID_KANURI", "LANGID_KASHMIRI", - "LANGID_KAZAKH", "LANGID_KHMER", "LANGID_KIRGHIZ", "LANGID_KONKANI", - "LANGID_KOREAN", "LANGID_KYRGYZ", "LANGID_LANGUAGE_NONE", "LANGID_LAO", - "LANGID_LATIN", "LANGID_LATVIAN", "LANGID_LITHUANIAN", - "LANGID_MACEDONIAN_FYROM", "LANGID_MALAYALAM", "LANGID_MALAYSIAN", - "LANGID_MALAY_BRUNEI_DARUSSALAM", "LANGID_MALTESE", "LANGID_MANIPURI", - "LANGID_MARATHI", "LANGID_MEXICAN_SPANISH", "LANGID_MONGOLIAN", - "LANGID_NEPALI", "LANGID_NORWEGIAN_BOKMOL", "LANGID_NORWEGIAN_NYNORSK", - "LANGID_NO_PROOFING", "LANGID_ORIYA", "LANGID_OROMO", "LANGID_PASHTO", - "LANGID_PERSIAN", "LANGID_POLISH", "LANGID_PORTUGUESE", - "LANGID_PORTUGUESE_BRAZIL", "LANGID_PUNJABI", "LANGID_RHAETO_ROMANIC", - "LANGID_ROMANIAN", "LANGID_ROMANIAN_MOLDOVA", "LANGID_RUSSIAN", - "LANGID_RUSSIAN_MOLDOVA", "LANGID_SAMI_LAPPISH", "LANGID_SANSKRIT", - "LANGID_SERBIAN_CYRILLIC", "LANGID_SERBIAN_LATIN", "LANGID_SESOTHO", - "LANGID_SIMPLIFIED_CHINESE", "LANGID_SINDHI", "LANGID_SINDHI_PAKISTAN", - "LANGID_SINHALESE", "LANGID_SLOVAK", "LANGID_SLOVENIAN", "LANGID_SOMALI", - "LANGID_SORBIAN", "LANGID_SPANISH", "LANGID_SPANISH_ARGENTINA", - "LANGID_SPANISH_BOLIVIA", "LANGID_SPANISH_CHILE", - "LANGID_SPANISH_COLOMBIA", "LANGID_SPANISH_COSTA_RICA", - "LANGID_SPANISH_DOMINICAN_REPUBLIC", "LANGID_SPANISH_ECUADOR", - "LANGID_SPANISH_EL_SALVADOR", "LANGID_SPANISH_GUATEMALA", - "LANGID_SPANISH_HONDURAS", "LANGID_SPANISH_MODERN_SORT", - "LANGID_SPANISH_NICARAGUA", "LANGID_SPANISH_PANAMA", - "LANGID_SPANISH_PARAGUAY", "LANGID_SPANISH_PERU", - "LANGID_SPANISH_PUERTO_RICO", "LANGID_SPANISH_URUGUAY", - "LANGID_SPANISH_VENEZUELA", "LANGID_SUTU", "LANGID_SWAHILI", - "LANGID_SWEDISH", "LANGID_SWEDISH_FINLAND", "LANGID_SWISS_FRENCH", - "LANGID_SWISS_GERMAN", "LANGID_SWISS_ITALIAN", "LANGID_SYRIAC", - "LANGID_TAJIK", "LANGID_TAMAZIGHT", "LANGID_TAMAZIGHT_LATIN", - "LANGID_TAMIL", "LANGID_TATAR", "LANGID_TELUGU", "LANGID_THAI", - "LANGID_TIBETAN", "LANGID_TIGRIGNA_ERITREA", "LANGID_TIGRIGNA_ETHIOPIC", - "LANGID_TRADITIONAL_CHINESE", "LANGID_TSONGA", "LANGID_TSWANA", - "LANGID_TURKISH", "LANGID_TURKMEN", "LANGID_UKRAINIAN", "LANGID_URDU", - "LANGID_UZBEK_CYRILLIC", "LANGID_UZBEK_LATIN", "LANGID_VENDA", - "LANGID_VIETNAMESE", "LANGID_WELSH", "LANGID_XHOSA", "LANGID_YI", - "LANGID_YIDDISH", "LANGID_YORUBA", "LANGID_ZULU", - - "lang_id_names" + "LANGID_AFRIKAANS", + "LANGID_ALBANIAN", + "LANGID_AMHARIC", + "LANGID_ARABIC", + "LANGID_ARABIC_ALGERIA", + "LANGID_ARABIC_BAHRAIN", + "LANGID_ARABIC_EGYPT", + "LANGID_ARABIC_IRAQ", + "LANGID_ARABIC_JORDAN", + "LANGID_ARABIC_KUWAIT", + "LANGID_ARABIC_LEBANON", + "LANGID_ARABIC_LIBYA", + "LANGID_ARABIC_MOROCCO", + "LANGID_ARABIC_OMAN", + "LANGID_ARABIC_QATAR", + "LANGID_ARABIC_SYRIA", + "LANGID_ARABIC_TUNISIA", + "LANGID_ARABIC_UAE", + "LANGID_ARABIC_YEMEN", + "LANGID_ARMENIAN", + "LANGID_ASSAMESE", + "LANGID_AZERI_CYRILLIC", + "LANGID_AZERI_LATIN", + "LANGID_BASQUE", + "LANGID_BELGIAN_DUTCH", + "LANGID_BELGIAN_FRENCH", + "LANGID_BENGALI", + "LANGID_BULGARIAN", + "LANGID_BURMESE", + 
"LANGID_BYELORUSSIAN", + "LANGID_CATALAN", + "LANGID_CHEROKEE", + "LANGID_CHINESE_HONG_KONG_SAR", + "LANGID_CHINESE_MACAO_SAR", + "LANGID_CHINESE_SINGAPORE", + "LANGID_CROATIAN", + "LANGID_CZECH", + "LANGID_DANISH", + "LANGID_DIVEHI", + "LANGID_DUTCH", + "LANGID_EDO", + "LANGID_ENGLISH_AUS", + "LANGID_ENGLISH_BELIZE", + "LANGID_ENGLISH_CANADIAN", + "LANGID_ENGLISH_CARIBBEAN", + "LANGID_ENGLISH_INDONESIA", + "LANGID_ENGLISH_IRELAND", + "LANGID_ENGLISH_JAMAICA", + "LANGID_ENGLISH_NEW_ZEALAND", + "LANGID_ENGLISH_PHILIPPINES", + "LANGID_ENGLISH_SOUTH_AFRICA", + "LANGID_ENGLISH_TRINIDAD_TOBAGO", + "LANGID_ENGLISH_UK", + "LANGID_ENGLISH_US", + "LANGID_ENGLISH_ZIMBABWE", + "LANGID_ESTONIAN", + "LANGID_FAEROESE", + "LANGID_FILIPINO", + "LANGID_FINNISH", + "LANGID_FRENCH", + "LANGID_FRENCH_CAMEROON", + "LANGID_FRENCH_CANADIAN", + "LANGID_FRENCH_CONGO_D_R_C", + "LANGID_FRENCH_COTED_IVOIRE", + "LANGID_FRENCH_HAITI", + "LANGID_FRENCH_LUXEMBOURG", + "LANGID_FRENCH_MALI", + "LANGID_FRENCH_MONACO", + "LANGID_FRENCH_MOROCCO", + "LANGID_FRENCH_REUNION", + "LANGID_FRENCH_SENEGAL", + "LANGID_FRENCH_WEST_INDIES", + "LANGID_FRISIAN_NETHERLANDS", + "LANGID_FULFULDE", + "LANGID_GAELIC_IRELAND", + "LANGID_GAELIC_SCOTLAND", + "LANGID_GALICIAN", + "LANGID_GEORGIAN", + "LANGID_GERMAN", + "LANGID_GERMAN_AUSTRIA", + "LANGID_GERMAN_LIECHTENSTEIN", + "LANGID_GERMAN_LUXEMBOURG", + "LANGID_GREEK", + "LANGID_GUARANI", + "LANGID_GUJARATI", + "LANGID_HAUSA", + "LANGID_HAWAIIAN", + "LANGID_HEBREW", + "LANGID_HINDI", + "LANGID_HUNGARIAN", + "LANGID_IBIBIO", + "LANGID_ICELANDIC", + "LANGID_IGBO", + "LANGID_INDONESIAN", + "LANGID_INUKTITUT", + "LANGID_ITALIAN", + "LANGID_JAPANESE", + "LANGID_KANNADA", + "LANGID_KANURI", + "LANGID_KASHMIRI", + "LANGID_KAZAKH", + "LANGID_KHMER", + "LANGID_KIRGHIZ", + "LANGID_KONKANI", + "LANGID_KOREAN", + "LANGID_KYRGYZ", + "LANGID_LANGUAGE_NONE", + "LANGID_LAO", + "LANGID_LATIN", + "LANGID_LATVIAN", + "LANGID_LITHUANIAN", + "LANGID_MACEDONIAN_FYROM", + "LANGID_MALAYALAM", + "LANGID_MALAYSIAN", + "LANGID_MALAY_BRUNEI_DARUSSALAM", + "LANGID_MALTESE", + "LANGID_MANIPURI", + "LANGID_MARATHI", + "LANGID_MEXICAN_SPANISH", + "LANGID_MONGOLIAN", + "LANGID_NEPALI", + "LANGID_NORWEGIAN_BOKMOL", + "LANGID_NORWEGIAN_NYNORSK", + "LANGID_NO_PROOFING", + "LANGID_ORIYA", + "LANGID_OROMO", + "LANGID_PASHTO", + "LANGID_PERSIAN", + "LANGID_POLISH", + "LANGID_PORTUGUESE", + "LANGID_PORTUGUESE_BRAZIL", + "LANGID_PUNJABI", + "LANGID_RHAETO_ROMANIC", + "LANGID_ROMANIAN", + "LANGID_ROMANIAN_MOLDOVA", + "LANGID_RUSSIAN", + "LANGID_RUSSIAN_MOLDOVA", + "LANGID_SAMI_LAPPISH", + "LANGID_SANSKRIT", + "LANGID_SERBIAN_CYRILLIC", + "LANGID_SERBIAN_LATIN", + "LANGID_SESOTHO", + "LANGID_SIMPLIFIED_CHINESE", + "LANGID_SINDHI", + "LANGID_SINDHI_PAKISTAN", + "LANGID_SINHALESE", + "LANGID_SLOVAK", + "LANGID_SLOVENIAN", + "LANGID_SOMALI", + "LANGID_SORBIAN", + "LANGID_SPANISH", + "LANGID_SPANISH_ARGENTINA", + "LANGID_SPANISH_BOLIVIA", + "LANGID_SPANISH_CHILE", + "LANGID_SPANISH_COLOMBIA", + "LANGID_SPANISH_COSTA_RICA", + "LANGID_SPANISH_DOMINICAN_REPUBLIC", + "LANGID_SPANISH_ECUADOR", + "LANGID_SPANISH_EL_SALVADOR", + "LANGID_SPANISH_GUATEMALA", + "LANGID_SPANISH_HONDURAS", + "LANGID_SPANISH_MODERN_SORT", + "LANGID_SPANISH_NICARAGUA", + "LANGID_SPANISH_PANAMA", + "LANGID_SPANISH_PARAGUAY", + "LANGID_SPANISH_PERU", + "LANGID_SPANISH_PUERTO_RICO", + "LANGID_SPANISH_URUGUAY", + "LANGID_SPANISH_VENEZUELA", + "LANGID_SUTU", + "LANGID_SWAHILI", + "LANGID_SWEDISH", + "LANGID_SWEDISH_FINLAND", + "LANGID_SWISS_FRENCH", + "LANGID_SWISS_GERMAN", 
+ "LANGID_SWISS_ITALIAN", + "LANGID_SYRIAC", + "LANGID_TAJIK", + "LANGID_TAMAZIGHT", + "LANGID_TAMAZIGHT_LATIN", + "LANGID_TAMIL", + "LANGID_TATAR", + "LANGID_TELUGU", + "LANGID_THAI", + "LANGID_TIBETAN", + "LANGID_TIGRIGNA_ERITREA", + "LANGID_TIGRIGNA_ETHIOPIC", + "LANGID_TRADITIONAL_CHINESE", + "LANGID_TSONGA", + "LANGID_TSWANA", + "LANGID_TURKISH", + "LANGID_TURKMEN", + "LANGID_UKRAINIAN", + "LANGID_URDU", + "LANGID_UZBEK_CYRILLIC", + "LANGID_UZBEK_LATIN", + "LANGID_VENDA", + "LANGID_VIETNAMESE", + "LANGID_WELSH", + "LANGID_XHOSA", + "LANGID_YI", + "LANGID_YIDDISH", + "LANGID_YORUBA", + "LANGID_ZULU", + "lang_id_names", ] LANGID_AFRIKAANS = 1078 @@ -501,5 +635,5 @@ LANGID_YI: "Yi", LANGID_YIDDISH: "Yiddish", LANGID_YORUBA: "Yoruba", - LANGID_ZULU: "Zulu" + LANGID_ZULU: "Zulu", } diff --git a/src/pytds/login.py b/src/pytds/login.py index 1429c80..3a1e361 100644 --- a/src/pytds/login.py +++ b/src/pytds/login.py @@ -17,7 +17,7 @@ class SspiAuth(AuthProtocol): - """ SSPI authentication + """SSPI authentication :platform: Windows @@ -34,18 +34,24 @@ class SspiAuth(AuthProtocol): :keyword spn: Service name :type spn: str """ - def __init__(self, user_name: str = '', password: str = '', server_name: str = '', port: int | None = None, spn: str | None = None): + + def __init__( + self, + user_name: str = "", + password: str = "", + server_name: str = "", + port: int | None = None, + spn: str | None = None, + ): from . import sspi + # parse username/password informations - if '\\' in user_name: - domain, user_name = user_name.split('\\') + if "\\" in user_name: + domain, user_name = user_name.split("\\") else: - domain = '' + domain = "" if domain and user_name: - self._identity = sspi.make_winnt_identity( - domain, - user_name, - password) + self._identity = sspi.make_winnt_identity(domain, user_name, password) else: self._identity = None # build SPN @@ -53,26 +59,31 @@ def __init__(self, user_name: str = '', password: str = '', server_name: str = ' self._sname = spn else: primary_host_name, _, _ = socket.gethostbyname_ex(server_name) - self._sname = 'MSSQLSvc/{0}:{1}'.format(primary_host_name, port) + self._sname = "MSSQLSvc/{0}:{1}".format(primary_host_name, port) # using Negotiate system will use proper protocol (either NTLM or Kerberos) self._cred = sspi.SspiCredentials( - package='Negotiate', - use=sspi.SECPKG_CRED_OUTBOUND, - identity=self._identity) - - self._flags = sspi.ISC_REQ_CONFIDENTIALITY | sspi.ISC_REQ_REPLAY_DETECT | sspi.ISC_REQ_CONNECTION + package="Negotiate", use=sspi.SECPKG_CRED_OUTBOUND, identity=self._identity + ) + + self._flags = ( + sspi.ISC_REQ_CONFIDENTIALITY + | sspi.ISC_REQ_REPLAY_DETECT + | sspi.ISC_REQ_CONNECTION + ) self._ctx = None def create_packet(self) -> bytes: from . import sspi import ctypes + buf = ctypes.create_string_buffer(4096) ctx, status, bufs = self._cred.create_context( flags=self._flags, - byte_ordering='network', + byte_ordering="network", target_name=self._sname, - output_buffers=[(sspi.SECBUFFER_TOKEN, buf)]) + output_buffers=[(sspi.SECBUFFER_TOKEN, buf)], + ) self._ctx = ctx if status == sspi.Status.SEC_I_COMPLETE_AND_CONTINUE: ctx.complete_auth_token(bufs) @@ -81,14 +92,16 @@ def create_packet(self) -> bytes: def handle_next(self, packet: bytes) -> bytes | None: from . 
import sspi import ctypes + if self._ctx: buf = ctypes.create_string_buffer(4096) status, buffers = self._ctx.next( flags=self._flags, - byte_ordering='network', + byte_ordering="network", target_name=self._sname, input_buffers=[(sspi.SECBUFFER_TOKEN, packet)], - output_buffers=[(sspi.SECBUFFER_TOKEN, buf)]) + output_buffers=[(sspi.SECBUFFER_TOKEN, buf)], + ) return buffers[0][1] else: return None @@ -117,22 +130,29 @@ class NtlmAuth(AuthProtocol): def __init__(self, user_name: str, password: str, ntlm_compatibility: int = 3): self._user_name = user_name - if '\\' in user_name: - domain, self._user = user_name.split('\\', 1) + if "\\" in user_name: + domain, self._user = user_name.split("\\", 1) self._domain = domain.upper() else: - self._domain = 'WORKSPACE' + self._domain = "WORKSPACE" self._user = user_name self._password = password self._workstation = socket.gethostname().upper() try: - from ntlm_auth.ntlm import NtlmContext # type: ignore # fix later + from ntlm_auth.ntlm import NtlmContext # type: ignore # fix later except ImportError: - raise ImportError("To use NTLM authentication you need to install ntlm-auth module") - - self._ntlm_context = NtlmContext(self._user, self._password, self._domain, self._workstation, - ntlm_compatibility=ntlm_compatibility) + raise ImportError( + "To use NTLM authentication you need to install ntlm-auth module" + ) + + self._ntlm_context = NtlmContext( + self._user, + self._password, + self._domain, + self._workstation, + ntlm_compatibility=ntlm_compatibility, + ) def create_packet(self) -> bytes: return self._ntlm_context.step() @@ -145,7 +165,7 @@ def close(self) -> None: class SpnegoAuth(AuthProtocol): - """ Authentication using Negotiate protocol, uses implementation provided pyspnego package + """Authentication using Negotiate protocol, uses implementation provided pyspnego package Takes same parameters as spnego.client function. 
""" @@ -154,7 +174,9 @@ def __init__(self, *args, **kwargs): try: import spnego except ImportError: - raise ImportError("To use spnego authentication you need to install pyspnego package") + raise ImportError( + "To use spnego authentication you need to install pyspnego package" + ) self._context = spnego.client(*args, **kwargs) def create_packet(self) -> bytes: @@ -170,36 +192,40 @@ def close(self) -> None: class KerberosAuth(AuthProtocol): def __init__(self, server_principal): try: - import kerberos # type: ignore # fix later + import kerberos # type: ignore # fix later except ImportError: - import winkerberos as kerberos # type: ignore # fix later + import winkerberos as kerberos # type: ignore # fix later self._kerberos = kerberos res, context = kerberos.authGSSClientInit(server_principal) if res < 0: - raise RuntimeError('authGSSClientInit failed with code {}'.format(res)) - logger.info('Initialized GSS context') + raise RuntimeError("authGSSClientInit failed with code {}".format(res)) + logger.info("Initialized GSS context") self._context = context def create_packet(self) -> bytes: import base64 - res = self._kerberos.authGSSClientStep(self._context, '') + + res = self._kerberos.authGSSClientStep(self._context, "") if res < 0: - raise RuntimeError('authGSSClientStep failed with code {}'.format(res)) + raise RuntimeError("authGSSClientStep failed with code {}".format(res)) data = self._kerberos.authGSSClientResponse(self._context) - logger.info('created first client GSS packet %s', data) + logger.info("created first client GSS packet %s", data) return base64.b64decode(data) def handle_next(self, packet: bytes) -> bytes | None: import base64 - res = self._kerberos.authGSSClientStep(self._context, base64.b64encode(packet).decode('ascii')) + + res = self._kerberos.authGSSClientStep( + self._context, base64.b64encode(packet).decode("ascii") + ) if res < 0: - raise RuntimeError('authGSSClientStep failed with code {}'.format(res)) + raise RuntimeError("authGSSClientStep failed with code {}".format(res)) if res == self._kerberos.AUTH_GSS_COMPLETE: - logger.info('GSS authentication completed') - return b'' + logger.info("GSS authentication completed") + return b"" else: data = self._kerberos.authGSSClientResponse(self._context) - logger.info('created client GSS packet %s', data) + logger.info("created client GSS packet %s", data) return base64.b64decode(data) def close(self) -> None: diff --git a/src/pytds/row_strategies.py b/src/pytds/row_strategies.py index e0b252f..44f59f2 100644 --- a/src/pytds/row_strategies.py +++ b/src/pytds/row_strategies.py @@ -12,21 +12,24 @@ RowStrategy = Callable[[Iterable[str]], RowGenerator] -def tuple_row_strategy(column_names: Iterable[str]) -> Callable[[Iterable[Any]], Tuple[Any, ...]]: - """ Tuple row strategy, rows returned as tuples, default - """ +def tuple_row_strategy( + column_names: Iterable[str] +) -> Callable[[Iterable[Any]], Tuple[Any, ...]]: + """Tuple row strategy, rows returned as tuples, default""" return tuple -def list_row_strategy(column_names: Iterable[str]) -> Callable[[Iterable[Any]], List[Any]]: - """ List row strategy, rows returned as lists - """ +def list_row_strategy( + column_names: Iterable[str] +) -> Callable[[Iterable[Any]], List[Any]]: + """List row strategy, rows returned as lists""" return list -def dict_row_strategy(column_names: Iterable[str]) -> Callable[[Iterable[Any]], Dict[str, Any]]: - """ Dict row strategy, rows returned as dictionaries - """ +def dict_row_strategy( + column_names: Iterable[str] +) -> 
Callable[[Iterable[Any]], Dict[str, Any]]:
+    """Dict row strategy, rows returned as dictionaries"""
     # replace empty column names with indices
     column_names = [(name or str(idx)) for idx, name in enumerate(column_names)]

@@ -37,20 +40,28 @@ def row_factory(row: Iterable[Any]) -> Dict[str, Any]:


 def is_valid_identifier(name: str) -> bool:
-    """ Returns true if given name can be used as an identifier in Python, otherwise returns false.
-    """
-    return bool(name and re.match("^[_A-Za-z][_a-zA-Z0-9]*$", name) and not keyword.iskeyword(name))
+    """Returns true if given name can be used as an identifier in Python, otherwise returns false."""
+    return bool(
+        name
+        and re.match("^[_A-Za-z][_a-zA-Z0-9]*$", name)
+        and not keyword.iskeyword(name)
+    )


-def namedtuple_row_strategy(column_names: Iterable[str]) -> Callable[[Iterable[Any]], NamedTuple]:
-    """ Namedtuple row strategy, rows returned as named tuples
+def namedtuple_row_strategy(
+    column_names: Iterable[str]
+) -> Callable[[Iterable[Any]], NamedTuple]:
+    """Namedtuple row strategy, rows returned as named tuples

     Column names that are not valid Python identifiers will be replaced
     with col_
     """
     # replace empty column names with placeholders
-    clean_column_names = [name if is_valid_identifier(name) else f'col{idx}_' for idx, name in enumerate(column_names)]
-    row_class = collections.namedtuple('Row', clean_column_names)  # type: ignore # needs fixing
+    clean_column_names = [
+        name if is_valid_identifier(name) else f"col{idx}_"
+        for idx, name in enumerate(column_names)
+    ]
+    row_class = collections.namedtuple("Row", clean_column_names)  # type: ignore # needs fixing

     def row_factory(row: Iterable[Any]) -> NamedTuple:
         return row_class(*row)
@@ -58,8 +69,10 @@ def row_factory(row: Iterable[Any]) -> NamedTuple:
     return row_factory


-def recordtype_row_strategy(column_names: Iterable[str]) -> Callable[[Iterable[Any]], Any]:
-    """ Recordtype row strategy, rows returned as recordtypes
+def recordtype_row_strategy(
+    column_names: Iterable[str]
+) -> Callable[[Iterable[Any]], Any]:
+    """Recordtype row strategy, rows returned as recordtypes

     Column names that are not valid Python identifiers will be replaced
     with col_
@@ -69,11 +82,14 @@ def recordtype_row_strategy(column_names: Iterable[str]) -> Callable[[Iterable[A
     except ImportError:
         from recordtype import recordtype  # type: ignore # needs fixing # optional dependency
     # replace empty column names with placeholders
-    column_names = [name if is_valid_identifier(name) else 'col%s_' % idx for idx, name in enumerate(column_names)]
-    recordtype_row_class = recordtype('Row', column_names)
+    column_names = [
+        name if is_valid_identifier(name) else "col%s_" % idx
+        for idx, name in enumerate(column_names)
+    ]
+    recordtype_row_class = recordtype("Row", column_names)

     # custom extension class that supports indexing
-    class Row(recordtype_row_class): # type: ignore # needs fixing
+    class Row(recordtype_row_class):  # type: ignore # needs fixing
         def __getitem__(self, index):
             if isinstance(index, slice):
                 return tuple(getattr(self, x) for x in self.__slots__[index])
diff --git a/src/pytds/smp.py b/src/pytds/smp.py
index 6567fa6..6bba78f 100644
--- a/src/pytds/smp.py
+++ b/src/pytds/smp.py
@@ -16,6 +16,7 @@
 try:
     from bitarray import bitarray  # type: ignore # fix typing later
 except ImportError:
+
     class BitArray(list):
         def __init__(self, size: int):
             super(BitArray, self).__init__()
@@ -31,7 +32,7 @@ def setall(self, val: bool) -> None:

 logger = logging.getLogger(__name__)

-SMP_HEADER = struct.Struct('<BBHLLL')
+SMP_HEADER = struct.Struct("<BBHLLL")

     def state(self) -> int | None:
         return self._state

@@ -67,7 +73,7 @@ def sendall(self, data: bytes, flags: int = 0) -> None:
         self._mgr.send_packet(self, data)

     def _recv_internal(self, size: int) -> Tuple[int, int]:
-        if not self._curr_buf[self._curr_buf_pos:]:
+        if not self._curr_buf[self._curr_buf_pos :]:
             self._curr_buf = self._mgr.recv_packet(self)
             self._curr_buf_pos = 0
             if not self._curr_buf:
@@ -77,17 +83,19 @@ def _recv_internal(self, size: int) -> Tuple[int, int]:
         self._curr_buf_pos += to_read
         return offset, to_read

-    def recv_into(self, buffer: bytearray | memoryview, size: int = 0, flags: int = 0) -> int:
+    def recv_into(
+        self, buffer: bytearray | memoryview, size: int = 0, flags: int = 0
+    ) -> int:
         if size == 0:
             size = len(buffer)
         offset, to_read = self._recv_internal(size)
-        buffer[:to_read] = self._curr_buf[offset:offset + to_read]
+        buffer[:to_read] = self._curr_buf[offset : offset + to_read]
         return to_read

     def recv(self, size: int) -> bytes:
         offset, to_read = self._recv_internal(size)
-        return self._curr_buf[offset:offset + to_read]
+        return self._curr_buf[offset : offset + to_read]

     def is_connected(self) -> bool:
         return self._state == SessionState.SESSION_ESTABLISHED
@@ -105,8 +113,8 @@ class PacketTypes:
     FIN = 0x4
     DATA = 0x8

-    #@staticmethod
-    #def type_to_str(t):
+    # @staticmethod
+    # def type_to_str(t):
     #    if t == PacketTypes.SYN:
     #        return 'SYN'
     #    elif t == PacketTypes.ACK:
@@ -126,25 +134,25 @@ class SessionState:
     @staticmethod
     def to_str(st: int) -> str:
         if st == SessionState.SESSION_ESTABLISHED:
-            return 'SESSION ESTABLISHED'
+            return "SESSION ESTABLISHED"
         elif st == SessionState.CLOSED:
-            return 'CLOSED'
+            return "CLOSED"
         elif st == SessionState.FIN_SENT:
-            return 'FIN SENT'
+            return "FIN SENT"
         elif st == SessionState.FIN_RECEIVED:
-            return 'FIN RECEIVED'
+            return "FIN RECEIVED"
         else:
             raise RuntimeError(f"invalid session state: {st}")


 class SmpManager:
-    def __init__(self, transport: TransportProtocol, max_sessions: int = 2 ** 16):
+    def __init__(self, transport: TransportProtocol, max_sessions: int = 2**16):
         self._transport = transport
         self._sessions: Dict[int, _SmpSession] = {}
         self._used_ids_ba = bitarray(max_sessions)
         self._used_ids_ba.setall(False)
         self._lock = threading.RLock()
-        self._hdr_buf = memoryview(bytearray(b'\x00' * SMP_HEADER.size))
+        self._hdr_buf = memoryview(bytearray(b"\x00" * SMP_HEADER.size))

     def __repr__(self):
         return "<SmpManager sessions={}>".format(self._sessions)

@@ -153,7 +161,9 @@ def create_session(self) -> _SmpSession:
         try:
             session_id = self._used_ids_ba.index(False)
         except ValueError:
-            raise Error("Can't create more MARS sessions, close some sessions and try again")
+            raise Error(
+                "Can't create more MARS sessions, close some sessions and try again"
+            )
         session = _SmpSession(self, session_id)
         with self._lock:
             self._sessions[session_id] = session
@@ -165,7 +175,7 @@ def create_session(self) -> _SmpSession:
                 SMP_HEADER.size,
                 0,
                 session.high_water_for_recv,
-                )
+            )
             self._transport.sendall(hdr)
             session._state = SessionState.SESSION_ESTABLISHED
             return session
@@ -187,7 +197,7 @@ def close_smp_session(self, session: _SmpSession) -> None:
                 SMP_HEADER.size,
                 session.seq_num_for_send,
                 session.high_water_for_recv,
-                )
+            )
             session._state = SessionState.FIN_SENT
             try:
                 self._transport.sendall(hdr)
@@ -200,17 +210,23 @@ def close_smp_session(self, session: _SmpSession) -> None:

     def send_queued_packets(self, session: _SmpSession) -> None:
         with self._lock:
-            while session.send_queue and session.seq_num_for_send < session.high_water_for_send:
+            while (
+                session.send_queue
+                and session.seq_num_for_send <
session.high_water_for_send + ): data = session.send_queue.pop(0) self.send_packet(session, data) @staticmethod def _add_one_wrap(val: int) -> int: - return 0 if val == 2 ** 32 - 1 else val + 1 + return 0 if val == 2**32 - 1 else val + 1 def send_packet(self, session: _SmpSession, data: bytes) -> None: with self._lock: - if session._state == SessionState.CLOSED or session._state == SessionState.FIN_SENT: + if ( + session._state == SessionState.CLOSED + or session._state == SessionState.FIN_SENT + ): raise Error("Stream closed") if session.seq_num_for_send < session.high_water_for_send: l = SMP_HEADER.size + len(data) @@ -222,7 +238,7 @@ def send_packet(self, session: _SmpSession, data: bytes) -> None: l, seq_num, session.high_water_for_recv, - ) + ) session._last_high_water_for_recv = session.high_water_for_recv self._transport.sendall(hdr + data) session.seq_num_for_send = self._add_one_wrap(session.seq_num_for_send) @@ -233,12 +249,14 @@ def send_packet(self, session: _SmpSession, data: bytes) -> None: def recv_packet(self, session: _SmpSession) -> bytes: with self._lock: if session._state == SessionState.CLOSED: - return b'' + return b"" while not session.recv_queue: self._read_smp_message() if session._state in (SessionState.CLOSED, SessionState.FIN_RECEIVED): - return b'' - session.high_water_for_recv = self._add_one_wrap(session.high_water_for_recv) + return b"" + session.high_water_for_recv = self._add_one_wrap( + session.high_water_for_recv + ) if session.high_water_for_recv - session._last_high_water_for_recv >= 2: hdr = SMP_HEADER.pack( SMP_ID, @@ -247,7 +265,7 @@ def recv_packet(self, session: _SmpSession) -> bytes: SMP_HEADER.size, session.seq_num_for_send, session.high_water_for_recv, - ) + ) self._transport.sendall(hdr) session._last_high_water_for_recv = session.high_water_for_recv return session.recv_queue.pop(0) @@ -263,25 +281,25 @@ def _read_smp_message(self) -> None: read = self._transport.recv_into(self._hdr_buf[buf_pos:]) buf_pos += read if read == 0: - self._bad_stm('Unexpected EOF while reading SMP header') + self._bad_stm("Unexpected EOF while reading SMP header") smid, flags, sid, l, seq_num, wnd = SMP_HEADER.unpack(self._hdr_buf) if smid != SMP_ID: - self._bad_stm('Invalid SMP packet signature') + self._bad_stm("Invalid SMP packet signature") try: session = self._sessions[sid] except KeyError: - self._bad_stm('Invalid SMP packet session id') + self._bad_stm("Invalid SMP packet session id") if wnd < session.high_water_for_send: - self._bad_stm('Invalid WNDW in packet from server') + self._bad_stm("Invalid WNDW in packet from server") if seq_num > session.high_water_for_recv: - self._bad_stm('Invalid SEQNUM in packet from server') + self._bad_stm("Invalid SEQNUM in packet from server") if l < SMP_HEADER.size: - self._bad_stm('Invalid LENGTH in packet from server') + self._bad_stm("Invalid LENGTH in packet from server") session._last_recv_seq_num = seq_num if flags == PacketTypes.DATA: if session._state == SessionState.SESSION_ESTABLISHED: if seq_num != self._add_one_wrap(session._seq_num_for_recv): - self._bad_stm('Invalid SEQNUM in DATA packet from server') + self._bad_stm("Invalid SEQNUM in DATA packet from server") session._seq_num_for_recv = seq_num remains = l - SMP_HEADER.size while remains: @@ -295,16 +313,20 @@ def _read_smp_message(self) -> None: elif session._state == SessionState.FIN_SENT: skipall(self._transport, l - SMP_HEADER.size) else: - self._bad_stm('Unexpected DATA packet from server') + self._bad_stm("Unexpected DATA packet from server") elif 
flags == PacketTypes.ACK: if session._state in (SessionState.FIN_RECEIVED, SessionState.CLOSED): - self._bad_stm('Unexpected ACK packet from server') + self._bad_stm("Unexpected ACK packet from server") if seq_num != session._seq_num_for_recv: - self._bad_stm('Invalid SEQNUM in ACK packet from server') + self._bad_stm("Invalid SEQNUM in ACK packet from server") session.high_water_for_send = wnd self.send_queued_packets(session) elif flags == PacketTypes.FIN: - assert session._state in (SessionState.SESSION_ESTABLISHED, SessionState.FIN_SENT, SessionState.FIN_RECEIVED) + assert session._state in ( + SessionState.SESSION_ESTABLISHED, + SessionState.FIN_SENT, + SessionState.FIN_RECEIVED, + ) if session._state == SessionState.SESSION_ESTABLISHED: session._state = SessionState.FIN_RECEIVED elif session._state == SessionState.FIN_SENT: @@ -312,11 +334,11 @@ def _read_smp_message(self) -> None: del self._sessions[session.session_id] self._used_ids_ba[session.session_id] = False elif session._state == SessionState.FIN_RECEIVED: - self._bad_stm('Unexpected FIN packet from server') + self._bad_stm("Unexpected FIN packet from server") elif flags == PacketTypes.SYN: - self._bad_stm('Unexpected SYN packet from server') + self._bad_stm("Unexpected SYN packet from server") else: - self._bad_stm('Unexpected FLAGS in packet from server') + self._bad_stm("Unexpected FLAGS in packet from server") def close(self) -> None: self._transport.close() diff --git a/src/pytds/sspi.py b/src/pytds/sspi.py index 33ae5b5..91ea665 100644 --- a/src/pytds/sspi.py +++ b/src/pytds/sspi.py @@ -1,7 +1,18 @@ import logging -from ctypes import (c_ulong, c_ushort, c_void_p, c_ulonglong, POINTER, # type: ignore # needs fixing - Structure, c_wchar_p, WINFUNCTYPE, windll, byref, cast) # type: ignore # needs fixing +from ctypes import ( + c_ulong, + c_ushort, + c_void_p, + c_ulonglong, + POINTER, # type: ignore # needs fixing + Structure, + c_wchar_p, + WINFUNCTYPE, + windll, + byref, + cast, +) # type: ignore # needs fixing logger = logging.getLogger(__name__) @@ -31,28 +42,29 @@ class Status(object): @classmethod def getname(cls, value): for name in dir(cls): - if name.startswith('SEC_E_') and getattr(cls, name) == value: + if name.startswith("SEC_E_") and getattr(cls, name) == value: return name - return 'unknown value {0:x}'.format(0x100000000 + value) + return "unknown value {0:x}".format(0x100000000 + value) -#define SECBUFFER_EMPTY 0 // Undefined, replaced by provider -#define SECBUFFER_DATA 1 // Packet data + +# define SECBUFFER_EMPTY 0 // Undefined, replaced by provider +# define SECBUFFER_DATA 1 // Packet data SECBUFFER_TOKEN = 2 -#define SECBUFFER_PKG_PARAMS 3 // Package specific parameters -#define SECBUFFER_MISSING 4 // Missing Data indicator -#define SECBUFFER_EXTRA 5 // Extra data -#define SECBUFFER_STREAM_TRAILER 6 // Security Trailer -#define SECBUFFER_STREAM_HEADER 7 // Security Header -#define SECBUFFER_NEGOTIATION_INFO 8 // Hints from the negotiation pkg -#define SECBUFFER_PADDING 9 // non-data padding -#define SECBUFFER_STREAM 10 // whole encrypted message -#define SECBUFFER_MECHLIST 11 -#define SECBUFFER_MECHLIST_SIGNATURE 12 -#define SECBUFFER_TARGET 13 // obsolete -#define SECBUFFER_CHANNEL_BINDINGS 14 -#define SECBUFFER_CHANGE_PASS_RESPONSE 15 -#define SECBUFFER_TARGET_HOST 16 -#define SECBUFFER_ALERT 17 +# define SECBUFFER_PKG_PARAMS 3 // Package specific parameters +# define SECBUFFER_MISSING 4 // Missing Data indicator +# define SECBUFFER_EXTRA 5 // Extra data +# define SECBUFFER_STREAM_TRAILER 6 // 
Security Trailer +# define SECBUFFER_STREAM_HEADER 7 // Security Header +# define SECBUFFER_NEGOTIATION_INFO 8 // Hints from the negotiation pkg +# define SECBUFFER_PADDING 9 // non-data padding +# define SECBUFFER_STREAM 10 // whole encrypted message +# define SECBUFFER_MECHLIST 11 +# define SECBUFFER_MECHLIST_SIGNATURE 12 +# define SECBUFFER_TARGET 13 // obsolete +# define SECBUFFER_CHANNEL_BINDINGS 14 +# define SECBUFFER_CHANGE_PASS_RESPONSE 15 +# define SECBUFFER_TARGET_HOST 16 +# define SECBUFFER_ALERT 17 SECPKG_CRED_INBOUND = 0x00000001 SECPKG_CRED_OUTBOUND = 0x00000002 @@ -62,10 +74,10 @@ def getname(cls, value): SECBUFFER_VERSION = 0 -#define ISC_REQ_DELEGATE 0x00000001 -#define ISC_REQ_MUTUAL_AUTH 0x00000002 +# define ISC_REQ_DELEGATE 0x00000001 +# define ISC_REQ_MUTUAL_AUTH 0x00000002 ISC_REQ_REPLAY_DETECT = 4 -#define ISC_REQ_SEQUENCE_DETECT 0x00000008 +# define ISC_REQ_SEQUENCE_DETECT 0x00000008 ISC_REQ_CONFIDENTIALITY = 0x10 ISC_REQ_USE_SESSION_KEY = 0x00000020 ISC_REQ_PROMPT_FOR_CREDS = 0x00000040 @@ -74,22 +86,22 @@ def getname(cls, value): ISC_REQ_USE_DCE_STYLE = 0x00000200 ISC_REQ_DATAGRAM = 0x00000400 ISC_REQ_CONNECTION = 0x00000800 -#define ISC_REQ_CALL_LEVEL 0x00001000 -#define ISC_REQ_FRAGMENT_SUPPLIED 0x00002000 -#define ISC_REQ_EXTENDED_ERROR 0x00004000 -#define ISC_REQ_STREAM 0x00008000 -#define ISC_REQ_INTEGRITY 0x00010000 -#define ISC_REQ_IDENTIFY 0x00020000 -#define ISC_REQ_NULL_SESSION 0x00040000 -#define ISC_REQ_MANUAL_CRED_VALIDATION 0x00080000 -#define ISC_REQ_RESERVED1 0x00100000 -#define ISC_REQ_FRAGMENT_TO_FIT 0x00200000 -#// This exists only in Windows Vista and greater -#define ISC_REQ_FORWARD_CREDENTIALS 0x00400000 -#define ISC_REQ_NO_INTEGRITY 0x00800000 // honored only by SPNEGO -#define ISC_REQ_USE_HTTP_STYLE 0x01000000 -#define ISC_REQ_UNVERIFIED_TARGET_NAME 0x20000000 -#define ISC_REQ_CONFIDENTIALITY_ONLY 0x40000000 // honored by SPNEGO/Kerberos +# define ISC_REQ_CALL_LEVEL 0x00001000 +# define ISC_REQ_FRAGMENT_SUPPLIED 0x00002000 +# define ISC_REQ_EXTENDED_ERROR 0x00004000 +# define ISC_REQ_STREAM 0x00008000 +# define ISC_REQ_INTEGRITY 0x00010000 +# define ISC_REQ_IDENTIFY 0x00020000 +# define ISC_REQ_NULL_SESSION 0x00040000 +# define ISC_REQ_MANUAL_CRED_VALIDATION 0x00080000 +# define ISC_REQ_RESERVED1 0x00100000 +# define ISC_REQ_FRAGMENT_TO_FIT 0x00200000 +# // This exists only in Windows Vista and greater +# define ISC_REQ_FORWARD_CREDENTIALS 0x00400000 +# define ISC_REQ_NO_INTEGRITY 0x00800000 // honored only by SPNEGO +# define ISC_REQ_USE_HTTP_STYLE 0x01000000 +# define ISC_REQ_UNVERIFIED_TARGET_NAME 0x20000000 +# define ISC_REQ_CONFIDENTIALITY_ONLY 0x40000000 // honored by SPNEGO/Kerberos SECURITY_NETWORK_DREP = 0 SECURITY_NATIVE_DREP = 0x10 @@ -107,9 +119,11 @@ def getname(cls, value): class SecHandle(Structure): _fields_ = [ - ('lower', c_void_p), - ('upper', c_void_p), + ("lower", c_void_p), + ("upper", c_void_p), ] + + PSecHandle = POINTER(SecHandle) CredHandle = SecHandle PCredHandle = PSecHandle @@ -118,103 +132,110 @@ class SecHandle(Structure): class SecBuffer(Structure): _fields_ = [ - ('cbBuffer', ULONG), - ('BufferType', ULONG), - ('pvBuffer', PVOID), + ("cbBuffer", ULONG), + ("BufferType", ULONG), + ("pvBuffer", PVOID), ] + + PSecBuffer = POINTER(SecBuffer) class SecBufferDesc(Structure): _fields_ = [ - ('ulVersion', ULONG), - ('cBuffers', ULONG), - ('pBuffers', PSecBuffer), + ("ulVersion", ULONG), + ("cBuffers", ULONG), + ("pBuffers", PSecBuffer), ] + + PSecBufferDesc = POINTER(SecBufferDesc) class 
SEC_WINNT_AUTH_IDENTITY(Structure): _fields_ = [ - ('User', c_wchar_p), - ('UserLength', c_ulong), - ('Domain', c_wchar_p), - ('DomainLength', c_ulong), - ('Password', c_wchar_p), - ('PasswordLength', c_ulong), - ('Flags', c_ulong), - ] + ("User", c_wchar_p), + ("UserLength", c_ulong), + ("Domain", c_wchar_p), + ("DomainLength", c_ulong), + ("Password", c_wchar_p), + ("PasswordLength", c_ulong), + ("Flags", c_ulong), + ] class SecPkgInfo(Structure): _fields_ = [ - ('fCapabilities', ULONG), - ('wVersion', USHORT), - ('wRPCID', USHORT), - ('cbMaxToken', ULONG), - ('Name', c_wchar_p), - ('Comment', c_wchar_p), + ("fCapabilities", ULONG), + ("wVersion", USHORT), + ("wRPCID", USHORT), + ("cbMaxToken", ULONG), + ("Name", c_wchar_p), + ("Comment", c_wchar_p), ] + + PSecPkgInfo = POINTER(SecPkgInfo) class SecPkgCredentials_Names(Structure): - _fields_ = [('UserName', c_wchar_p)] + _fields_ = [("UserName", c_wchar_p)] def ret_val(value): if value < 0: - raise Exception('SSPI Error {0}'.format(Status.getname(value))) + raise Exception("SSPI Error {0}".format(Status.getname(value))) return value ENUMERATE_SECURITY_PACKAGES_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing POINTER(c_ulong), - POINTER(POINTER(SecPkgInfo))) + POINTER(POINTER(SecPkgInfo)), +) ACQUIRE_CREDENTIALS_HANDLE_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing - c_wchar_p, # principal - c_wchar_p, # package - ULONG, # fCredentialUse - PLUID, # pvLogonID - PVOID, # pAuthData - PVOID, # pGetKeyFn - PVOID, # pvGetKeyArgument + c_wchar_p, # principal + c_wchar_p, # package + ULONG, # fCredentialUse + PLUID, # pvLogonID + PVOID, # pAuthData + PVOID, # pGetKeyFn + PVOID, # pvGetKeyArgument PCredHandle, # phCredential - PTimeStamp # ptsExpiry - ) -FREE_CREDENTIALS_HANDLE_FN = WINFUNCTYPE(ret_val, POINTER(SecHandle)) # type: ignore # needs fixing + PTimeStamp, # ptsExpiry +) +FREE_CREDENTIALS_HANDLE_FN = WINFUNCTYPE(ret_val, POINTER(SecHandle)) # type: ignore # needs fixing INITIALIZE_SECURITY_CONTEXT_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing PCredHandle, - PCtxtHandle, # phContext, - c_wchar_p, # pszTargetName, - ULONG, # fContextReq, - ULONG, # Reserved1, - ULONG, # TargetDataRep, + PCtxtHandle, # phContext, + c_wchar_p, # pszTargetName, + ULONG, # fContextReq, + ULONG, # Reserved1, + ULONG, # TargetDataRep, PSecBufferDesc, # pInput, - ULONG, # Reserved2, - PCtxtHandle, # phNewContext, + ULONG, # Reserved2, + PCtxtHandle, # phNewContext, PSecBufferDesc, # pOutput, - PULONG, # pfContextAttr, - PTimeStamp, # ptsExpiry - ) + PULONG, # pfContextAttr, + PTimeStamp, # ptsExpiry +) COMPLETE_AUTH_TOKEN_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing - PCtxtHandle, # phContext + PCtxtHandle, # phContext PSecBufferDesc, # pToken - ) +) -FREE_CONTEXT_BUFFER_FN = WINFUNCTYPE(ret_val, PVOID) # type: ignore # needs fixing +FREE_CONTEXT_BUFFER_FN = WINFUNCTYPE(ret_val, PVOID) # type: ignore # needs fixing QUERY_CREDENTIAL_ATTRIBUTES_FN = WINFUNCTYPE( ret_val, # type: ignore # needs fixing - PCredHandle, # cred - ULONG, # attribute - PVOID, # out buffer - ) + PCredHandle, # cred + ULONG, # attribute + PVOID, # out buffer +) ACCEPT_SECURITY_CONTEXT_FN = PVOID DELETE_SECURITY_CONTEXT_FN = WINFUNCTYPE(ret_val, PCtxtHandle) # type: ignore # needs fixing APPLY_CONTROL_TOKEN_FN = PVOID @@ -227,7 +248,7 @@ def ret_val(value): ret_val, # type: ignore # needs fixing c_wchar_p, # package name POINTER(PSecPkgInfo), - ) +) EXPORT_SECURITY_CONTEXT_FN = PVOID IMPORT_SECURITY_CONTEXT_FN = PVOID ADD_CREDENTIALS_FN = PVOID @@ 
-239,47 +260,50 @@ def ret_val(value): class SECURITY_FUNCTION_TABLE(Structure): _fields_ = [ - ('dwVersion', c_ulong), - ('EnumerateSecurityPackages', ENUMERATE_SECURITY_PACKAGES_FN), - ('QueryCredentialsAttributes', QUERY_CREDENTIAL_ATTRIBUTES_FN), - ('AcquireCredentialsHandle', ACQUIRE_CREDENTIALS_HANDLE_FN), - ('FreeCredentialsHandle', FREE_CREDENTIALS_HANDLE_FN), - ('Reserved2', c_void_p), - ('InitializeSecurityContext', INITIALIZE_SECURITY_CONTEXT_FN), - ('AcceptSecurityContext', ACCEPT_SECURITY_CONTEXT_FN), - ('CompleteAuthToken', COMPLETE_AUTH_TOKEN_FN), - ('DeleteSecurityContext', DELETE_SECURITY_CONTEXT_FN), - ('ApplyControlToken', APPLY_CONTROL_TOKEN_FN), - ('QueryContextAttributes', QUERY_CONTEXT_ATTRIBUTES_FN), - ('ImpersonateSecurityContext', IMPERSONATE_SECURITY_CONTEXT_FN), - ('RevertSecurityContext', REVERT_SECURITY_CONTEXT_FN), - ('MakeSignature', MAKE_SIGNATURE_FN), - ('VerifySignature', VERIFY_SIGNATURE_FN), - ('FreeContextBuffer', FREE_CONTEXT_BUFFER_FN), - ('QuerySecurityPackageInfo', QUERY_SECURITY_PACKAGE_INFO_FN), - ('Reserved3', c_void_p), - ('Reserved4', c_void_p), - ('ExportSecurityContext', EXPORT_SECURITY_CONTEXT_FN), - ('ImportSecurityContext', IMPORT_SECURITY_CONTEXT_FN), - ('AddCredentials', ADD_CREDENTIALS_FN), - ('Reserved8', c_void_p), - ('QuerySecurityContextToken', QUERY_SECURITY_CONTEXT_TOKEN_FN), - ('EncryptMessage', ENCRYPT_MESSAGE_FN), - ('DecryptMessage', DECRYPT_MESSAGE_FN), - ('SetContextAttributes', SET_CONTEXT_ATTRIBUTES_FN), - ] + ("dwVersion", c_ulong), + ("EnumerateSecurityPackages", ENUMERATE_SECURITY_PACKAGES_FN), + ("QueryCredentialsAttributes", QUERY_CREDENTIAL_ATTRIBUTES_FN), + ("AcquireCredentialsHandle", ACQUIRE_CREDENTIALS_HANDLE_FN), + ("FreeCredentialsHandle", FREE_CREDENTIALS_HANDLE_FN), + ("Reserved2", c_void_p), + ("InitializeSecurityContext", INITIALIZE_SECURITY_CONTEXT_FN), + ("AcceptSecurityContext", ACCEPT_SECURITY_CONTEXT_FN), + ("CompleteAuthToken", COMPLETE_AUTH_TOKEN_FN), + ("DeleteSecurityContext", DELETE_SECURITY_CONTEXT_FN), + ("ApplyControlToken", APPLY_CONTROL_TOKEN_FN), + ("QueryContextAttributes", QUERY_CONTEXT_ATTRIBUTES_FN), + ("ImpersonateSecurityContext", IMPERSONATE_SECURITY_CONTEXT_FN), + ("RevertSecurityContext", REVERT_SECURITY_CONTEXT_FN), + ("MakeSignature", MAKE_SIGNATURE_FN), + ("VerifySignature", VERIFY_SIGNATURE_FN), + ("FreeContextBuffer", FREE_CONTEXT_BUFFER_FN), + ("QuerySecurityPackageInfo", QUERY_SECURITY_PACKAGE_INFO_FN), + ("Reserved3", c_void_p), + ("Reserved4", c_void_p), + ("ExportSecurityContext", EXPORT_SECURITY_CONTEXT_FN), + ("ImportSecurityContext", IMPORT_SECURITY_CONTEXT_FN), + ("AddCredentials", ADD_CREDENTIALS_FN), + ("Reserved8", c_void_p), + ("QuerySecurityContextToken", QUERY_SECURITY_CONTEXT_TOKEN_FN), + ("EncryptMessage", ENCRYPT_MESSAGE_FN), + ("DecryptMessage", DECRYPT_MESSAGE_FN), + ("SetContextAttributes", SET_CONTEXT_ATTRIBUTES_FN), + ] + _PInitSecurityInterface = WINFUNCTYPE(POINTER(SECURITY_FUNCTION_TABLE)) -InitSecurityInterface = _PInitSecurityInterface(('InitSecurityInterfaceW', windll.secur32)) +InitSecurityInterface = _PInitSecurityInterface( + ("InitSecurityInterfaceW", windll.secur32) +) sec_fn = InitSecurityInterface() if not sec_fn: - raise Exception('InitSecurityInterface failed') + raise Exception("InitSecurityInterface failed") sec_fn = sec_fn.contents class _SecContext(object): - def __init__(self, cred: 'SspiCredentials') -> None: + def __init__(self, cred: "SspiCredentials") -> None: self._cred = cred self._handle = SecHandle() self._ts = TimeStamp() @@ 
-294,34 +318,41 @@ def __del__(self) -> None: self.close() def complete_auth_token(self, bufs): - sec_fn.CompleteAuthToken( - byref(self._handle), - byref(_make_buffers_desc(bufs))) - - def next(self, - flags, - target_name=None, - byte_ordering='network', - input_buffers=None, - output_buffers=None): - input_buffers_desc = _make_buffers_desc(input_buffers) if input_buffers else None - output_buffers_desc = _make_buffers_desc(output_buffers) if output_buffers else None + sec_fn.CompleteAuthToken(byref(self._handle), byref(_make_buffers_desc(bufs))) + + def next( + self, + flags, + target_name=None, + byte_ordering="network", + input_buffers=None, + output_buffers=None, + ): + input_buffers_desc = ( + _make_buffers_desc(input_buffers) if input_buffers else None + ) + output_buffers_desc = ( + _make_buffers_desc(output_buffers) if output_buffers else None + ) status = sec_fn.InitializeSecurityContext( byref(self._cred._handle), byref(self._handle), target_name, flags, 0, - SECURITY_NETWORK_DREP if byte_ordering == 'network' else SECURITY_NATIVE_DREP, + SECURITY_NETWORK_DREP + if byte_ordering == "network" + else SECURITY_NATIVE_DREP, byref(input_buffers_desc) if input_buffers_desc else None, 0, byref(self._handle), byref(output_buffers_desc) if input_buffers_desc else None, byref(self._attrs), - byref(self._ts)) + byref(self._ts), + ) result_buffers = [] for i, (type, buf) in enumerate(output_buffers): - buf = buf[:output_buffers_desc.pBuffers[i].cbBuffer] + buf = buf[: output_buffers_desc.pBuffers[i].cbBuffer] result_buffers.append((type, buf)) return status, result_buffers @@ -332,9 +363,16 @@ def __init__(self, package, use, identity=None): self._ts = TimeStamp() logger.debug("Acquiring credentials handle") sec_fn.AcquireCredentialsHandle( - None, package, use, - None, byref(identity) if identity and identity.Domain else None, - None, None, byref(self._handle), byref(self._ts)) + None, + package, + use, + None, + byref(identity) if identity and identity.Domain else None, + None, + None, + byref(self._handle), + byref(self._ts), + ) def close(self): if self._handle and (self._handle.lower or self._handle.upper): @@ -349,9 +387,8 @@ def query_user_name(self): names = SecPkgCredentials_Names() try: sec_fn.QueryCredentialsAttributes( - byref(self._handle), - SECPKG_CRED_ATTR_NAMES, - byref(names)) + byref(self._handle), SECPKG_CRED_ATTR_NAMES, byref(names) + ) user_name = str(names.UserName) finally: p = c_wchar_p.from_buffer(names, SecPkgCredentials_Names.UserName.offset) @@ -359,17 +396,22 @@ def query_user_name(self): return user_name def create_context( - self, - flags: int, - target_name=None, - byte_ordering='network', - input_buffers=None, - output_buffers=None): + self, + flags: int, + target_name=None, + byte_ordering="network", + input_buffers=None, + output_buffers=None, + ): if self._handle is None: raise RuntimeError("Using closed SspiCredentials object") ctx = _SecContext(cred=self) - input_buffers_desc = _make_buffers_desc(input_buffers) if input_buffers else None - output_buffers_desc = _make_buffers_desc(output_buffers) if output_buffers else None + input_buffers_desc = ( + _make_buffers_desc(input_buffers) if input_buffers else None + ) + output_buffers_desc = ( + _make_buffers_desc(output_buffers) if output_buffers else None + ) logger.debug("Initializing security context") status = sec_fn.InitializeSecurityContext( byref(self._handle), @@ -377,16 +419,19 @@ def create_context( target_name, flags, 0, - SECURITY_NETWORK_DREP if byte_ordering == 'network' else 
SECURITY_NATIVE_DREP, + SECURITY_NETWORK_DREP + if byte_ordering == "network" + else SECURITY_NATIVE_DREP, byref(input_buffers_desc) if input_buffers_desc else None, 0, byref(ctx._handle), byref(output_buffers_desc) if output_buffers_desc else None, byref(ctx._attrs), - byref(ctx._ts)) + byref(ctx._ts), + ) result_buffers = [] for i, (type, buf) in enumerate(output_buffers): - buf = buf[:output_buffers_desc.pBuffers[i].cbBuffer] + buf = buf[: output_buffers_desc.pBuffers[i].cbBuffer] result_buffers.append((type, buf)) return ctx, status, result_buffers @@ -415,7 +460,8 @@ def make_winnt_identity(domain, user_name, password): identity.UserLength = len(user_name) return identity -#class SspiSecBuffer(object): + +# class SspiSecBuffer(object): # def __init__(self, type, buflen=4096): # self._buf = create_string_buffer(int(buflen)) # self._desc = SecBuffer() @@ -423,7 +469,7 @@ def make_winnt_identity(domain, user_name, password): # self._desc.BufferType = type # self._desc.pvBuffer = cast(self._buf, PVOID) # -#class SspiSecBuffers(object): +# class SspiSecBuffers(object): # def __init__(self): # self._desc = SecBufferDesc() # self._desc.ulVersion = SECBUFFER_VERSION @@ -444,12 +490,16 @@ def enum_security_packages(): infos = POINTER(SecPkgInfo)() status = sec_fn.EnumerateSecurityPackages(byref(num), byref(infos)) try: - return [{'caps': infos[i].fCapabilities, - 'version': infos[i].wVersion, - 'rpcid': infos[i].wRPCID, - 'max_token': infos[i].cbMaxToken, - 'name': infos[i].Name, - 'comment': infos[i].Comment, - } for i in range(num.value)] + return [ + { + "caps": infos[i].fCapabilities, + "version": infos[i].wVersion, + "rpcid": infos[i].wRPCID, + "max_token": infos[i].cbMaxToken, + "name": infos[i].Name, + "comment": infos[i].Comment, + } + for i in range(num.value) + ] finally: sec_fn.FreeContextBuffer(infos) diff --git a/src/pytds/tds.py b/src/pytds/tds.py index 924e9ec..9abe599 100644 --- a/src/pytds/tds.py +++ b/src/pytds/tds.py @@ -12,6 +12,7 @@ from .tds_base import PreLoginEnc, _TdsEnv, _TdsLogin, Route from .row_strategies import list_row_strategy from .smp import SmpManager + # _token_map is needed by sqlalchemy_pytds connector from .tds_session import _token_map, _TdsSession @@ -23,14 +24,14 @@ # if MARS is not used it would have single _TdsSession instance class _TdsSocket(object): def __init__( - self, - sock: tds_base.TransportProtocol, - login: _TdsLogin, - tzinfo_factory: tds_types.TzInfoFactoryType | None = None, - row_strategy=list_row_strategy, - use_tz: datetime.tzinfo | None = None, - autocommit=False, - isolation_level=0, + self, + sock: tds_base.TransportProtocol, + login: _TdsLogin, + tzinfo_factory: tds_types.TzInfoFactoryType | None = None, + row_strategy=list_row_strategy, + use_tz: datetime.tzinfo | None = None, + autocommit=False, + isolation_level=0, ): self._is_connected = False self.env = _TdsEnv() @@ -64,7 +65,7 @@ def __init__( type_factory=self.type_factory, collation=self.collation, bytes_to_unicode=self._login.bytes_to_unicode, - allow_tz=not self.use_tz + allow_tz=not self.use_tz, ) self.server_library_version = (0, 0) self.product_name = "" @@ -72,11 +73,13 @@ def __init__( def __repr__(self) -> str: fmt = "<_TdsSocket tran={} mars={} tds_version={} use_tz={}>" - return fmt.format(self.tds72_transaction, self._mars_enabled, - self.tds_version, self.use_tz) + return fmt.format( + self.tds72_transaction, self._mars_enabled, self.tds_version, self.use_tz + ) def login(self) -> Route | None: from . 
import tls + self._login.server_enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP if tds_base.IS_TDS71_PLUS(self): self._main_session.send_prelogin(self._login) @@ -91,15 +94,20 @@ def login(self) -> Route | None: return self.route # update block size if server returned different one - if self._main_session._writer.bufsize != self._main_session._reader.get_block_size(): - self._main_session._reader.set_block_size(self._main_session._writer.bufsize) + if ( + self._main_session._writer.bufsize + != self._main_session._reader.get_block_size() + ): + self._main_session._reader.set_block_size( + self._main_session._writer.bufsize + ) self.type_factory = tds_types.SerializerFactory(self.tds_version) self.type_inferrer = tds_types.TdsTypeInferrer( type_factory=self.type_factory, collation=self.collation, bytes_to_unicode=self._login.bytes_to_unicode, - allow_tz=not self.use_tz + allow_tz=not self.use_tz, ) if self._mars_enabled: self._smp_manager = SmpManager(self.sock) @@ -114,9 +122,9 @@ def login(self) -> Route | None: self._is_connected = True q = [] if self._login.database and self.env.database != self._login.database: - q.append('use ' + tds_base.tds_quote_id(self._login.database)) + q.append("use " + tds_base.tds_quote_id(self._login.database)) if q: - self._main_session.submit_plain_query(''.join(q)) + self._main_session.submit_plain_query("".join(q)) self._main_session.process_simple_request() return None @@ -130,7 +138,9 @@ def main_session(self) -> _TdsSession: def create_session(self) -> _TdsSession: if not self._smp_manager: - raise RuntimeError("Calling create_session on a non-MARS connection does not work") + raise RuntimeError( + "Calling create_session on a non-MARS connection does not work" + ) return _TdsSession( tds=self, transport=self._smp_manager.create_session(), @@ -162,7 +172,7 @@ def close_all_mars_sessions(self) -> None: def _parse_instances(msg: bytes) -> dict[str, dict[str, str]] | None: name: str | None = None if len(msg) > 3 and tds_base.my_ord(msg[0]) == 5: - tokens = msg[3:].decode('ascii').split(';') + tokens = msg[3:].decode("ascii").split(";") results: dict[str, dict[str, str]] = {} instdict: dict[str, str] = {} got_name = False @@ -175,24 +185,27 @@ def _parse_instances(msg: bytes) -> dict[str, dict[str, str]] | None: if not name: if not instdict: break - results[instdict['InstanceName'].upper()] = instdict + results[instdict["InstanceName"].upper()] = instdict instdict = {} continue got_name = True return results return None + # # Get port of all instances # @return default port number or 0 if error # @remark experimental, cf. MC-SQLR.pdf. 
# -def tds7_get_instances(ip_addr: Any, timeout: float = 5) -> dict[str, dict[str, str]] | None: +def tds7_get_instances( + ip_addr: Any, timeout: float = 5 +) -> dict[str, dict[str, str]] | None: s = socket.socket(type=socket.SOCK_DGRAM) s.settimeout(timeout) try: # send the request - s.sendto(b'\x03', (ip_addr, 1434)) + s.sendto(b"\x03", (ip_addr, 1434)) msg = s.recv(16 * 1024 - 1) # got data, read and parse return _parse_instances(msg) diff --git a/src/pytds/tds_base.py b/src/pytds/tds_base.py index f14171e..b9e9d0f 100644 --- a/src/pytds/tds_base.py +++ b/src/pytds/tds_base.py @@ -195,7 +195,7 @@ class PacketType: TEXTTYPE = SYBTEXT = 35 # 0x23 SYBVARBINARY = 37 # 0x25 INTNTYPE = SYBINTN = 38 # 0x26 -SYBVARCHAR = 39 # 0x27 +SYBVARCHAR = 39 # 0x27 BINARYTYPE = SYBBINARY = 45 # 0x2D SYBCHAR = 47 # 0x2F INT1TYPE = SYBINT1 = 48 # 0x30 @@ -266,11 +266,11 @@ class PacketType: TDS_UT_TIMESTAMP = 80 # compute operator -SYBAOPCNT = 0x4b -SYBAOPCNTU = 0x4c -SYBAOPSUM = 0x4d -SYBAOPSUMU = 0x4e -SYBAOPAVG = 0x4f +SYBAOPCNT = 0x4B +SYBAOPCNTU = 0x4C +SYBAOPSUM = 0x4D +SYBAOPSUMU = 0x4E +SYBAOPAVG = 0x4F SYBAOPAVGU = 0x50 SYBAOPMIN = 0x51 SYBAOPMAX = 0x52 @@ -292,7 +292,7 @@ class PacketType: TDS_PENDING = 2 TDS_READING = 3 TDS_DEAD = 4 -state_names = ['IDLE', 'QUERYING', 'PENDING', 'READING', 'DEAD'] +state_names = ["IDLE", "QUERYING", "PENDING", "READING", "DEAD"] TDS_ENCRYPTION_OFF = 0 TDS_ENCRYPTION_REQUEST = 1 @@ -305,6 +305,7 @@ class PreLoginToken: Spec link: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/60f56408-0188-4cd5-8b90-25c6f2423868 """ + VERSION = 0 ENCRYPTION = 1 INSTOPT = 2 @@ -313,7 +314,7 @@ class PreLoginToken: TRACEID = 5 FEDAUTHREQUIRED = 6 NONCEOPT = 7 - TERMINATOR = 0xff + TERMINATOR = 0xFF class PreLoginEnc: @@ -322,19 +323,20 @@ class PreLoginEnc: Spec link: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/60f56408-0188-4cd5-8b90-25c6f2423868 """ + ENCRYPT_OFF = 0 # Encryption available but off ENCRYPT_ON = 1 # Encryption available and on ENCRYPT_NOT_SUP = 2 # Encryption not available ENCRYPT_REQ = 3 # Encryption required -PLP_MARKER = 0xffff -PLP_NULL = 0xffffffffffffffff -PLP_UNKNOWN = 0xfffffffffffffffe +PLP_MARKER = 0xFFFF +PLP_NULL = 0xFFFFFFFFFFFFFFFF +PLP_UNKNOWN = 0xFFFFFFFFFFFFFFFE TDS_NO_COUNT = -1 -TVP_NULL_TOKEN = 0xffff +TVP_NULL_TOKEN = 0xFFFF # TVP COLUMN FLAGS TVP_COLUMN_DEFAULT_FLAG = 0x200 @@ -354,7 +356,7 @@ def __ne__(self, other): def iterdecode(iterable, codec): - """ Uses an incremental decoder to decode each chunk of string in iterable. + """Uses an incremental decoder to decode each chunk of string in iterable. This function is a generator. :param iterable: Iterable object which yields raw data to be decoded. 
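A note on `iterdecode`, whose docstring is reflowed in the hunk above: TDS streams large character values in chunks, and a multi-byte sequence can be split across a chunk boundary, so decoding each chunk independently would fail; an incremental decoder carries the partial sequence over to the next chunk. A minimal, self-contained sketch of the same pattern using only the stdlib codecs module (the sample byte chunks are invented for illustration):

    import codecs

    def iterdecode(iterable, codec):
        # The incremental decoder buffers a partial multi-byte sequence at
        # the end of one chunk and completes it with the following chunk.
        decoder = codec.incrementaldecoder()
        for chunk in iterable:
            yield decoder.decode(chunk)
        yield decoder.decode(b"", True)  # final=True flushes any buffered bytes

    # b"\xc3\xa9" (UTF-8 for "e-acute") is split across the two chunks,
    # yet the joined output still decodes as a single character:
    chunks = [b"caf\xc3", b"\xa9 latte"]
    assert "".join(iterdecode(chunks, codecs.lookup("utf8"))) == "café latte"
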
@@ -363,7 +365,7 @@ def iterdecode(iterable, codec): decoder = codec.incrementaldecoder() for chunk in iterable: yield decoder.decode(chunk) - yield decoder.decode(b'', True) + yield decoder.decode(b"", True) def force_unicode(s): @@ -372,7 +374,7 @@ def force_unicode(s): """ if isinstance(s, bytes): try: - return s.decode('utf8') + return s.decode("utf8") except UnicodeDecodeError as e: raise DatabaseError(e) elif isinstance(s, str): @@ -382,29 +384,29 @@ def force_unicode(s): def tds_quote_id(ident): - """ Quote an identifier according to MSSQL rules + """Quote an identifier according to MSSQL rules :param ident: identifier to quote :returns: Quoted identifier """ - return '[{0}]'.format(ident.replace(']', ']]')) + return "[{0}]".format(ident.replace("]", "]]")) # store a tuple of programming error codes prog_errors = ( - 102, # syntax error - 207, # invalid column name - 208, # invalid object name - 2812, # unknown procedure - 4104 # multi-part identifier could not be bound + 102, # syntax error + 207, # invalid column name + 208, # invalid object name + 2812, # unknown procedure + 4104, # multi-part identifier could not be bound ) # store a tuple of integrity error codes integrity_errors = ( - 515, # NULL insert - 547, # FK related - 2601, # violate unique index - 2627, # violate UNIQUE KEY constraint + 515, # NULL insert + 547, # FK related + 2601, # violate unique index + 2627, # violate UNIQUE KEY constraint ) @@ -415,14 +417,14 @@ def my_ord(val): return val def join_bytearrays(ba): - return b''.join(ba) + return b"".join(ba) else: exc_base_class = StandardError my_ord = ord def join_bytearrays(bas): - return b''.join(bytes(ba) for ba in bas) + return b"".join(bytes(ba) for ba in bas) # exception hierarchy @@ -434,6 +436,7 @@ class Error(exc_base_class): """ Base class for all error classes, except TimeoutError """ + pass @@ -444,6 +447,7 @@ class InterfaceError(Error): """ TODO add documentation """ + pass @@ -451,7 +455,8 @@ class DatabaseError(Error): """ This error is raised when MSSQL server returns an error which includes error number """ - def __init__(self, msg: str, exc: typing.Any | None=None): + + def __init__(self, msg: str, exc: typing.Any | None = None): super().__init__(msg, exc) self.msg_no = 0 self.text = msg @@ -465,28 +470,42 @@ def __init__(self, msg: str, exc: typing.Any | None=None): @property def message(self): if self.procname: - return 'SQL Server message %d, severity %d, state %d, ' \ - 'procedure %s, line %d:\n%s' % (self.number, - self.severity, self.state, self.procname, - self.line, self.text) + return ( + "SQL Server message %d, severity %d, state %d, " + "procedure %s, line %d:\n%s" + % ( + self.number, + self.severity, + self.state, + self.procname, + self.line, + self.text, + ) + ) else: - return 'SQL Server message %d, severity %d, state %d, ' \ - 'line %d:\n%s' % (self.number, self.severity, - self.state, self.line, self.text) + return "SQL Server message %d, severity %d, state %d, " "line %d:\n%s" % ( + self.number, + self.severity, + self.state, + self.line, + self.text, + ) class ClosedConnectionError(InterfaceError): """ This error is raised when MSSQL server closes connection. """ + def __init__(self): - super(ClosedConnectionError, self).__init__('Server closed connection') + super(ClosedConnectionError, self).__init__("Server closed connection") class DataError(Error): """ This error is raised when input parameter contains data which cannot be converted to acceptable data type. 
""" + pass @@ -494,6 +513,7 @@ class OperationalError(DatabaseError): """ TODO add documentation """ + pass @@ -501,6 +521,7 @@ class LoginError(OperationalError): """ This error is raised if provided login credentials are invalid """ + pass @@ -508,6 +529,7 @@ class IntegrityError(DatabaseError): """ TODO add documentation """ + pass @@ -515,6 +537,7 @@ class InternalError(DatabaseError): """ TODO add documentation """ + pass @@ -522,6 +545,7 @@ class ProgrammingError(DatabaseError): """ TODO add documentation """ + pass @@ -529,6 +553,7 @@ class NotSupportedError(DatabaseError): """ TODO add documentation """ + pass @@ -537,6 +562,7 @@ class DBAPITypeObject: """ TODO add documentation """ + def __init__(self, *values): self.values = set(values) @@ -553,15 +579,32 @@ def __cmp__(self, other): # standard dbapi type objects -STRING = DBAPITypeObject(SYBVARCHAR, SYBCHAR, SYBTEXT, - XSYBNVARCHAR, XSYBNCHAR, SYBNTEXT, - XSYBVARCHAR, XSYBCHAR, SYBMSXML) +STRING = DBAPITypeObject( + SYBVARCHAR, + SYBCHAR, + SYBTEXT, + XSYBNVARCHAR, + XSYBNCHAR, + SYBNTEXT, + XSYBVARCHAR, + XSYBCHAR, + SYBMSXML, +) BINARY = DBAPITypeObject(SYBIMAGE, SYBBINARY, SYBVARBINARY, XSYBVARBINARY, XSYBBINARY) -NUMBER = DBAPITypeObject(SYBBIT, SYBBITN, SYBINT1, SYBINT2, SYBINT4, SYBINT8, SYBINTN, - SYBREAL, SYBFLT8, SYBFLTN) +NUMBER = DBAPITypeObject( + SYBBIT, + SYBBITN, + SYBINT1, + SYBINT2, + SYBINT4, + SYBINT8, + SYBINTN, + SYBREAL, + SYBFLT8, + SYBFLTN, +) DATETIME = DBAPITypeObject(SYBDATETIME, SYBDATETIME4, SYBDATETIMN) -DECIMAL = DBAPITypeObject(SYBMONEY, SYBMONEY4, SYBMONEYN, SYBNUMERIC, - SYBDECIMAL) +DECIMAL = DBAPITypeObject(SYBMONEY, SYBMONEY4, SYBMONEYN, SYBNUMERIC, SYBDECIMAL) ROWID = DBAPITypeObject() # non-standard, but useful type objects @@ -574,6 +617,7 @@ class InternalProc(object): """ TODO add documentation """ + def __init__(self, proc_id, name): self.proc_id = proc_id self.name = name @@ -581,13 +625,14 @@ def __init__(self, proc_id, name): def __unicode__(self): return self.name -SP_EXECUTESQL = InternalProc(TDS_SP_EXECUTESQL, 'sp_executesql') -SP_PREPARE = InternalProc(TDS_SP_PREPARE, 'sp_prepare') -SP_EXECUTE = InternalProc(TDS_SP_EXECUTE, 'sp_execute') + +SP_EXECUTESQL = InternalProc(TDS_SP_EXECUTESQL, "sp_executesql") +SP_PREPARE = InternalProc(TDS_SP_PREPARE, "sp_prepare") +SP_EXECUTE = InternalProc(TDS_SP_EXECUTE, "sp_execute") def skipall(stm, size): - """ Skips exactly size bytes in stm + """Skips exactly size bytes in stm If EOF is reached before size bytes are skipped will raise :class:`ClosedConnectionError` @@ -611,7 +656,7 @@ def skipall(stm, size): def read_chunks(stm, size): - """ Reads exactly size bytes from stm and produces chunks + """Reads exactly size bytes from stm and produces chunks May call stm.read multiple times until required number of bytes is read. @@ -624,7 +669,7 @@ def read_chunks(stm, size): :param size: Number of bytes to read. """ if size == 0: - yield b'' + yield b"" return res = stm.recv(size) @@ -641,7 +686,7 @@ def read_chunks(stm, size): def readall(stm, size): - """ Reads exactly size bytes from stm + """Reads exactly size bytes from stm May call stm.read multiple times until required number of bytes read. @@ -677,7 +722,7 @@ def readall_fast(stm, size): def total_seconds(td): - """ Total number of seconds in timedelta object + """Total number of seconds in timedelta object Python 2.6 doesn't have total_seconds method, this function provides a backport @@ -694,6 +739,7 @@ class Param: :type name: str :param type: Type of the parameter, e.g. 
:class:`pytds.tds_types.IntType`
     """
+
     def __init__(self, name: str = "", type=None, value=None, flags: int = 0):
         self.name = name
         self.type = type
@@ -719,13 +765,14 @@ class Column(CommonEqualityMixin):
     :param flags: Combination of flags for the column, multiple flags can be combined using binary or operator.
         Possible flags are described above.
     """
+
     fNullable = 1
     fCaseSen = 2
     fReadWrite = 8
     fIdentity = 0x10
     fComputed = 0x20

-    def __init__(self, name='', type=None, flags=fNullable, value=None):
+    def __init__(self, name="", type=None, flags=fNullable, value=None):
         self.char_codec = None
         self.column_name = name
         self.column_usertype = 0
@@ -737,16 +784,18 @@ def __init__(self, name='', type=None, flags=fNullable, value=None):
     def __repr__(self):
         val = self.value
         if isinstance(val, bytes) and len(self.value) > 100:
-            val = self.value[:100] + b'... len is ' + str(len(val)).encode('ascii')
+            val = self.value[:100] + b"... len is " + str(len(val)).encode("ascii")
         if isinstance(val, str) and len(self.value) > 100:
-            val = self.value[:100] + '... len is ' + str(len(val))
-        return '<Column(name={}, type={}, value={}, flags={}, user_type={}, codec={})>'.format(
-            repr(self.column_name),
-            repr(self.type),
-            repr(val),
-            repr(self.flags),
-            repr(self.column_usertype),
-            repr(self.char_codec),
+            val = self.value[:100] + "... len is " + str(len(val))
+        return (
+            "<Column(name={}, type={}, value={}, flags={}, user_type={}, codec={})>".format(
+                repr(self.column_name),
+                repr(self.type),
+                repr(val),
+                repr(self.flags),
+                repr(self.column_usertype),
+                repr(self.char_codec),
+            )
         )

     def choose_serializer(self, type_factory, collation):
@@ -760,7 +809,8 @@ class TransportProtocol(Protocol):
     """
     This protocol mimics socket protocol
     """
-    #def is_connected(self) -> bool:
+
+    # def is_connected(self) -> bool:
     #    ...

     def close(self) -> None:
@@ -778,9 +828,12 @@ def sendall(self, buf: bytes, flags: int = 0) -> None:
     def recv(self, size: int) -> bytes:
         ...

-    def recv_into(self, buf: bytearray | memoryview, size: int = 0, flags: int = 0) -> int:
+    def recv_into(
+        self, buf: bytearray | memoryview, size: int = 0, flags: int = 0
+    ) -> int:
         ...

+
 class LoadBalancer(Protocol):
     def choose(self) -> Iterable[str]:
         ...
@@ -799,22 +852,21 @@ def close(self) -> None:

 # packet header
 # https://msdn.microsoft.com/en-us/library/dd340948.aspx
-_header = struct.Struct('>BBHHBx')
-
-_byte = struct.Struct('B')
-_smallint_le = struct.Struct('<h')
-_usmallint_le = struct.Struct('<H')
-_int_le = struct.Struct('<l')
-_uint_le = struct.Struct('<L')
-_int8_le = struct.Struct('<q')
-_uint8_le = struct.Struct('<Q')
-
+_header = struct.Struct(">BBHHBx")
+
+_byte = struct.Struct("B")
+_smallint_le = struct.Struct("<h")
+_usmallint_le = struct.Struct("<H")
+_int_le = struct.Struct("<l")
+_uint_le = struct.Struct("<L")
+_int8_le = struct.Struct("<q")
+_uint8_le = struct.Struct("<Q")

 logging_enabled = False
@@ -838,17 +890,21 @@ def value(self):
         return self._value

     def __init__(self, value: Any = None, param_type=None):
-        """ Creates procedure output parameter.
+        """Creates procedure output parameter.
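
        For illustration (these exact values are hypothetical, not taken
        from the surrounding diff), the type can be given either as a
        Python type or as an explicit SQL type declaration::

            p_typed = output(value=None, param_type=int)
            p_declared = output(value=3, param_type="INT")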
:param param_type: either sql type declaration or python type :param value: value to pass into procedure """ if param_type is None: if value is None or value is default: - raise ValueError('Output type cannot be autodetected') + raise ValueError("Output type cannot be autodetected") elif isinstance(param_type, type) and value is not None: if value is not default and not isinstance(value, param_type): - raise ValueError('value should match param_type, value is {}, param_type is \'{}\''.format(repr(value), param_type.__name__)) + raise ValueError( + "value should match param_type, value is {}, param_type is '{}'".format( + repr(value), param_type.__name__ + ) + ) self._type = param_type self._value = value @@ -861,14 +917,14 @@ class _Default: def tds7_crypt_pass(password: str) -> bytearray: - """ Mangle password according to tds rules + """Mangle password according to tds rules :param password: Password str :returns: Byte-string with encoded password """ encoded = bytearray(ucs2_codec.encode(password)[0]) for i, ch in enumerate(encoded): - encoded[i] = ((ch << 4) & 0xff | (ch >> 4)) ^ 0xA5 + encoded[i] = ((ch << 4) & 0xFF | (ch >> 4)) ^ 0xA5 return encoded @@ -920,12 +976,14 @@ def __init__(self): self.isolation_level = 0 -def _create_exception_by_message(msg: Message, custom_error_msg: str | None = None) -> ProgrammingError | IntegrityError | OperationalError: - msg_no = msg['msgno'] +def _create_exception_by_message( + msg: Message, custom_error_msg: str | None = None +) -> ProgrammingError | IntegrityError | OperationalError: + msg_no = msg["msgno"] if custom_error_msg is not None: error_msg = custom_error_msg else: - error_msg = msg['message'] + error_msg = msg["message"] ex: ProgrammingError | IntegrityError | OperationalError if msg_no in prog_errors: ex = ProgrammingError(error_msg) @@ -933,14 +991,14 @@ def _create_exception_by_message(msg: Message, custom_error_msg: str | None = No ex = IntegrityError(error_msg) else: ex = OperationalError(error_msg) - ex.msg_no = msg['msgno'] - ex.text = msg['message'] - ex.srvname = msg['server'] - ex.procname = msg['proc_name'] - ex.number = msg['msgno'] - ex.severity = msg['severity'] - ex.state = msg['state'] - ex.line = msg['line_number'] + ex.msg_no = msg["msgno"] + ex.text = msg["message"] + ex.srvname = msg["server"] + ex.procname = msg["proc_name"] + ex.number = msg["msgno"] + ex.severity = msg["severity"] + ex.state = msg["state"] + ex.line = msg["line_number"] return ex @@ -966,4 +1024,4 @@ class _Results(object): def __init__(self) -> None: self.columns: list[Column] = [] self.row_count = 0 - self.description: tuple[tuple[str, Any, None, int, int, int, int], ...] = () \ No newline at end of file + self.description: tuple[tuple[str, Any, None, int, int, int, int], ...] = () diff --git a/src/pytds/tds_reader.py b/src/pytds/tds_reader.py index 5bc8b15..270cf5c 100644 --- a/src/pytds/tds_reader.py +++ b/src/pytds/tds_reader.py @@ -6,8 +6,19 @@ from pytds import tds_base from pytds.collate import Collation, ucs2_codec -from pytds.tds_base import readall, readall_fast, _header, _int_le, _uint_be, _uint_le, _uint8_le, _int8_le, _byte, \ - _smallint_le, _usmallint_le +from pytds.tds_base import ( + readall, + readall_fast, + _header, + _int_le, + _uint_be, + _uint_le, + _uint8_le, + _int8_le, + _byte, + _smallint_le, + _usmallint_le, +) if typing.TYPE_CHECKING: @@ -19,20 +30,27 @@ class ResponseMetadata: This class represents response metadata extracted from first response packet. 
This includes response type and session ID
     """
+
     def __init__(self):
         self.type = 0
         self.spid = 0


 class _TdsReader:
-    """ TDS stream reader
+    """TDS stream reader

     Provides stream-like interface for TDS packeted stream.
     Also provides convenience methods to decode primitive data
     like different kinds of integers etc.
     """
-    def __init__(self, tds_session: _TdsSession, transport: tds_base.TransportProtocol, bufsize: int = 4096):
-        self._buf = bytearray(b'\x00' * bufsize)
+
+    def __init__(
+        self,
+        tds_session: _TdsSession,
+        transport: tds_base.TransportProtocol,
+        bufsize: int = 4096,
+    ):
+        self._buf = bytearray(b"\x00" * bufsize)
         self._bufview = memoryview(self._buf)
         self._pos = len(self._buf)  # position in the buffer
         self._have = 0  # number of bytes read from packet
@@ -51,7 +69,7 @@ def session(self):
         return self._session

     def set_block_size(self, size: int) -> None:
-        self._buf = bytearray(b'\x00' * size)
+        self._buf = bytearray(b"\x00" * size)
         self._bufview = memoryview(self._buf)

     def get_block_size(self) -> int:
@@ -59,7 +77,7 @@ def get_block_size(self) -> int:

     @property
     def packet_type(self) -> int | None:
-        """ Type of current packet
+        """Type of current packet

         Possible values are TDS_QUERY, TDS_LOGIN, etc.
         """
@@ -78,7 +96,7 @@ def stream_finished(self) -> bool:
             return False

     def read_fast(self, size: int) -> Tuple[bytes, int]:
-        """ Faster version of read
+        """Faster version of read

         Instead of returning sliced buffer it returns reference to internal
         buffer and the offset to this buffer.
@@ -89,7 +107,7 @@ def read_fast(self, size: int) -> Tuple[bytes, int]:
         # Current response stream finished
         if self._pos >= self._size:
             if self._status == 1:
-                return b'', 0
+                return b"", 0
             self._read_packet()
         offset = self._pos
         to_read = min(size, self._size - self._pos)
@@ -100,15 +118,15 @@ def recv(self, size: int) -> bytes:
         if self._pos >= self._size:
             # Current response stream finished
             if self._status == 1:
-                return b''
+                return b""
             self._read_packet()
         offset = self._pos
         to_read = min(size, self._size - self._pos)
         self._pos += to_read
-        return self._buf[offset:offset+to_read]
+        return self._buf[offset : offset + to_read]

     def unpack(self, struc: struct.Struct) -> Tuple[Any, ...]:
-        """ Unpacks given structure from stream
+        """Unpacks given structure from stream

         :param struc: A struct.Struct instance
         :returns: Result of unpacking
@@ -117,44 +135,44 @@ def unpack(self, struc: struct.Struct) -> Tuple[Any, ...]:
         return struc.unpack_from(buf, offset)

     def get_byte(self) -> int:
-        """ Reads one byte from stream """
+        """Reads one byte from stream"""
         return self.unpack(_byte)[0]

     def get_smallint(self) -> int:
-        """ Reads 16bit signed integer from the stream """
+        """Reads 16bit signed integer from the stream"""
         return self.unpack(_smallint_le)[0]

     def get_usmallint(self) -> int:
-        """ Reads 16bit unsigned integer from the stream """
+        """Reads 16bit unsigned integer from the stream"""
         return self.unpack(_usmallint_le)[0]

     def get_int(self) -> int:
-        """ Reads 32bit signed integer from the stream """
+        """Reads 32bit signed integer from the stream"""
         return self.unpack(_int_le)[0]

     def get_uint(self) -> int:
-        """ Reads 32bit unsigned integer from the stream """
+        """Reads 32bit unsigned integer from the stream"""
         return self.unpack(_uint_le)[0]

     def get_uint_be(self) -> int:
-        """ Reads 32bit unsigned big-endian integer from the stream """
+        """Reads 32bit unsigned big-endian integer from the stream"""
         return self.unpack(_uint_be)[0]

     def get_uint8(self) -> int:
-        """ Reads 64bit unsigned integer from the
stream """ + """Reads 64bit unsigned integer from the stream""" return self.unpack(_uint8_le)[0] def get_int8(self) -> int: - """ Reads 64bit signed integer from the stream """ + """Reads 64bit signed integer from the stream""" return self.unpack(_int8_le)[0] def read_ucs2(self, num_chars: int) -> str: - """ Reads num_chars UCS2 string from the stream """ + """Reads num_chars UCS2 string from the stream""" buf = readall(self, num_chars * 2) return ucs2_codec.decode(buf)[0] def read_str(self, size: int, codec) -> str: - """ Reads byte string from the stream and decodes it + """Reads byte string from the stream and decodes it :param size: Size of string in bytes :param codec: Instance of codec to decode string @@ -163,7 +181,7 @@ def read_str(self, size: int, codec) -> str: return codec.decode(readall(self, size))[0] def get_collation(self) -> Collation: - """ Reads :class:`Collation` object from stream """ + """Reads :class:`Collation` object from stream""" buf = readall(self, Collation.wire_size) return Collation.unpack(buf) @@ -174,7 +192,9 @@ def begin_response(self) -> ResponseMetadata: read methods can be called to read contents of the response packet stream until it ends. """ if self._status != 1 or self._pos < self._size: - raise RuntimeError("begin_response was called before previous response was fully consumed") + raise RuntimeError( + "begin_response was called before previous response was fully consumed" + ) self._read_packet() res = ResponseMetadata() res.type = self._type @@ -182,18 +202,22 @@ def begin_response(self) -> ResponseMetadata: return res def _read_packet(self) -> None: - """ Reads next TDS packet from the underlying transport + """Reads next TDS packet from the underlying transport Can only be called when transport's read pointer is at the beginning of the packet. """ pos = 0 while pos < _header.size: - received = self._transport.recv_into(self._bufview[pos:], _header.size - pos) + received = self._transport.recv_into( + self._bufview[pos:], _header.size - pos + ) if received == 0: raise tds_base.ClosedConnectionError() pos += received self._pos = _header.size - self._type, self._status, self._size, self._spid, _ = _header.unpack_from(self._bufview, 0) + self._type, self._status, self._size, self._spid, _ = _header.unpack_from( + self._bufview, 0 + ) self._have = pos while pos < self._size: received = self._transport.recv_into(self._bufview[pos:], self._size - pos) @@ -203,10 +227,10 @@ def _read_packet(self) -> None: self._have += received def read_whole_packet(self) -> bytes: - """ Reads single packet and returns bytes payload of the packet + """Reads single packet and returns bytes payload of the packet Can only be called when transport's read pointer is at the beginning of the packet. 
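
        As a sketch of the 8-byte header this works with (the byte values
        below are made up for illustration: type=4 REPLY, status=1 end of
        message, size=0x20, spid=0, packet id=1)::

            >>> import struct
            >>> struct.Struct(">BBHHBx").unpack(b"\x04\x01\x00\x20\x00\x00\x01\x00")
            (4, 1, 32, 0, 1)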
""" - #self._read_packet() + # self._read_packet() return readall(self, self._size - _header.size) diff --git a/src/pytds/tds_session.py b/src/pytds/tds_session.py index 4f84b1f..ab435e5 100644 --- a/src/pytds/tds_session.py +++ b/src/pytds/tds_session.py @@ -11,8 +11,22 @@ from pytds import tds_base, tds_types from pytds.collate import lcid2charset, raw_collation -from pytds.tds_base import readall, skipall, PreLoginToken, PreLoginEnc, Message, logging_enabled, \ - _create_exception_by_message, output, default, _TdsLogin, tds7_crypt_pass, logger, _Results, _TdsEnv +from pytds.tds_base import ( + readall, + skipall, + PreLoginToken, + PreLoginEnc, + Message, + logging_enabled, + _create_exception_by_message, + output, + default, + _TdsLogin, + tds7_crypt_pass, + logger, + _Results, + _TdsEnv, +) from pytds.tds_reader import _TdsReader, ResponseMetadata from pytds.tds_writer import _TdsWriter from pytds.row_strategies import list_row_strategy, RowStrategy, RowGenerator @@ -22,7 +36,7 @@ class _TdsSession: - """ TDS session + """TDS session This class has the following responsibilities: * Track state of a single TDS session if MARS enabled there could be multiple TDS sessions @@ -30,14 +44,15 @@ class _TdsSession: * Provides API to send requests and receive responses * Does serialization of requests and deserialization of responses """ + def __init__( - self, - tds: _TdsSocket, - transport: tds_base.TransportProtocol, - tzinfo_factory: tds_types.TzInfoFactoryType | None, - env: _TdsEnv, - bufsize: int, - row_strategy: RowStrategy = list_row_strategy, + self, + tds: _TdsSocket, + transport: tds_base.TransportProtocol, + tzinfo_factory: tds_types.TzInfoFactoryType | None, + env: _TdsEnv, + bufsize: int, + row_strategy: RowStrategy = list_row_strategy, ): self.out_pos = 8 self.res_info: _Results | None = None @@ -48,8 +63,12 @@ def __init__( self.ret_status: int | None = None self.skipped_to_status = False self._transport = transport - self._reader = _TdsReader(transport=transport, bufsize=bufsize, tds_session=self) - self._writer = _TdsWriter(transport=transport, bufsize=bufsize, tds_session=self) + self._reader = _TdsReader( + transport=transport, bufsize=bufsize, tds_session=self + ) + self._writer = _TdsWriter( + transport=transport, bufsize=bufsize, tds_session=self + ) self.in_buf_max = 0 self.state = tds_base.TDS_IDLE self._tds = tds @@ -107,24 +126,35 @@ def row_strategy(self) -> Callable[[Iterable[str]], Callable[[Iterable[Any]], An return self._row_strategy @row_strategy.setter - def row_strategy(self, value: Callable[[Iterable[str]], Callable[[Iterable[Any]], Any]]) -> None: - warnings.warn("Changing row_strategy on live connection is now deprecated, you should set it when creating new connection", DeprecationWarning) + def row_strategy( + self, value: Callable[[Iterable[str]], Callable[[Iterable[Any]], Any]] + ) -> None: + warnings.warn( + "Changing row_strategy on live connection is now deprecated, you should set it when creating new connection", + DeprecationWarning, + ) self._row_strategy = value def log_response_message(self, msg): # logging is disabled by default if logging_enabled: - logger.info('[%d] %s', self._spid, msg) + logger.info("[%d] %s", self._spid, msg) def __repr__(self): fmt = "<_TdsSession state={} tds={} messages={} rows_affected={} use_tz={} spid={} in_cancel={}>" - res = fmt.format(repr(self.state), repr(self._tds), repr(self.messages), - repr(self.rows_affected), repr(self.use_tz), repr(self._spid), - self.in_cancel) + res = fmt.format( + repr(self.state), + 
repr(self._tds), + repr(self.messages), + repr(self.rows_affected), + repr(self.use_tz), + repr(self._spid), + self.in_cancel, + ) return res def raise_db_exception(self) -> None: - """ Raises exception from last server message + """Raises exception from last server message This function will skip messages: The statement has been terminated """ @@ -133,17 +163,17 @@ def raise_db_exception(self) -> None: msg = None while True: msg = self.messages[-1] - if msg['msgno'] == 3621: # the statement has been terminated + if msg["msgno"] == 3621: # the statement has been terminated self.messages = self.messages[:-1] else: break - error_msg = ' '.join(m['message'] for m in self.messages) + error_msg = " ".join(m["message"] for m in self.messages) ex = _create_exception_by_message(msg, error_msg) raise ex def get_type_info(self, curcol): - """ Reads TYPE_INFO structure (http://msdn.microsoft.com/en-us/library/dd358284.aspx) + """Reads TYPE_INFO structure (http://msdn.microsoft.com/en-us/library/dd358284.aspx) :param curcol: An instance of :class:`Column` that will receive read information """ @@ -160,12 +190,12 @@ def get_type_info(self, curcol): curcol.serializer = serializer_class.from_stream(r) def tds7_process_result(self): - """ Reads and processes COLMETADATA stream + """Reads and processes COLMETADATA stream This stream contains a list of returned columns. Stream format link: http://msdn.microsoft.com/en-us/library/dd357363.aspx """ - self.log_response_message('got COLMETADATA') + self.log_response_message("got COLMETADATA") r = self._reader # read number of columns and allocate the columns structure @@ -201,24 +231,27 @@ def tds7_process_result(self): scale = curcol.serializer.scale size = curcol.serializer.size header_tuple.append( - (curcol.column_name, - curcol.serializer.get_typeid(), - None, - size, - precision, - scale, - curcol.flags & tds_base.Column.fNullable)) + ( + curcol.column_name, + curcol.serializer.get_typeid(), + None, + size, + precision, + scale, + curcol.flags & tds_base.Column.fNullable, + ) + ) info.description = tuple(header_tuple) self._setup_row_factory() return info def process_param(self): - """ Reads and processes RETURNVALUE stream. + """Reads and processes RETURNVALUE stream. This stream is used to send OUTPUT parameters from RPC to client. Stream format url: http://msdn.microsoft.com/en-us/library/dd303881.aspx """ - self.log_response_message('got RETURNVALUE message') + self.log_response_message("got RETURNVALUE message") r = self._reader if tds_base.IS_TDS72_PLUS(self): ordinal = r.get_usmallint() @@ -242,7 +275,7 @@ def process_cancel(self): In case when no cancel request is pending this function does nothing. 
""" - self.log_response_message('got CANCEL message') + self.log_response_message("got CANCEL message") # silly cases, nothing to do if not self.in_cancel: return @@ -256,7 +289,7 @@ def process_cancel(self): self.begin_response() def process_msg(self, marker: int) -> None: - """ Reads and processes ERROR/INFO streams + """Reads and processes ERROR/INFO streams Stream formats: @@ -265,40 +298,42 @@ def process_msg(self, marker: int) -> None: :param marker: TDS_ERROR_TOKEN or TDS_INFO_TOKEN """ - self.log_response_message('got ERROR/INFO message') + self.log_response_message("got ERROR/INFO message") r = self._reader r.get_smallint() # size msg: Message = { - 'marker': marker, - 'msgno': r.get_int(), - 'state': r.get_byte(), - 'severity': r.get_byte(), - 'sql_state': None, - 'priv_msg_type': 0, - 'message': '', - 'server': '', - 'proc_name': '', - 'line_number': 0, + "marker": marker, + "msgno": r.get_int(), + "state": r.get_byte(), + "severity": r.get_byte(), + "sql_state": None, + "priv_msg_type": 0, + "message": "", + "server": "", + "proc_name": "", + "line_number": 0, } if marker == tds_base.TDS_INFO_TOKEN: - msg['priv_msg_type'] = 0 + msg["priv_msg_type"] = 0 elif marker == tds_base.TDS_ERROR_TOKEN: - msg['priv_msg_type'] = 1 + msg["priv_msg_type"] = 1 else: logger.error('tds_process_msg() called with unknown marker "%d"', marker) - msg['message'] = r.read_ucs2(r.get_smallint()) + msg["message"] = r.read_ucs2(r.get_smallint()) # server name - msg['server'] = r.read_ucs2(r.get_byte()) + msg["server"] = r.read_ucs2(r.get_byte()) # stored proc name if available - msg['proc_name'] = r.read_ucs2(r.get_byte()) - msg['line_number'] = r.get_int() if tds_base.IS_TDS72_PLUS(self) else r.get_smallint() + msg["proc_name"] = r.read_ucs2(r.get_byte()) + msg["line_number"] = ( + r.get_int() if tds_base.IS_TDS72_PLUS(self) else r.get_smallint() + ) # in case extended error data is sent, we just try to discard it # special case self.messages.append(msg) def process_row(self): - """ Reads and handles ROW stream. + """Reads and handles ROW stream. This stream contains list of values of one returned row. Stream format url: http://msdn.microsoft.com/en-us/library/dd357254.aspx @@ -311,7 +346,7 @@ def process_row(self): curcol.value = self.row[i] = curcol.serializer.read(r) def process_nbcrow(self): - """ Reads and handles NBCROW stream. + """Reads and handles NBCROW stream. This stream contains list of values of one returned row in a compressed way, introduced in TDS 7.3.B @@ -321,7 +356,7 @@ def process_nbcrow(self): r = self._reader info = self.res_info if not info: - self.bad_stream('got row without info') + self.bad_stream("got row without info") assert len(info.columns) > 0 info.row_count += 1 @@ -336,7 +371,7 @@ def process_nbcrow(self): self.row[i] = value def process_orderby(self): - """ Reads and processes ORDER stream + """Reads and processes ORDER stream Used to inform client by which column dataset is ordered. 
Stream format url: http://msdn.microsoft.com/en-us/library/dd303317.aspx @@ -345,7 +380,7 @@ def process_orderby(self): skipall(r, r.get_smallint()) def process_end(self, marker): - """ Reads and processes DONE/DONEINPROC/DONEPROC streams + """Reads and processes DONE/DONEINPROC/DONEPROC streams Stream format urls: @@ -356,9 +391,9 @@ def process_end(self, marker): :param marker: Can be TDS_DONE_TOKEN or TDS_DONEINPROC_TOKEN or TDS_DONEPROC_TOKEN """ code_to_str = { - tds_base.TDS_DONE_TOKEN: 'DONE', - tds_base.TDS_DONEINPROC_TOKEN: 'DONEINPROC', - tds_base.TDS_DONEPROC_TOKEN: 'DONEPROC', + tds_base.TDS_DONE_TOKEN: "DONE", + tds_base.TDS_DONEINPROC_TOKEN: "DONEINPROC", + tds_base.TDS_DONEPROC_TOKEN: "DONEPROC", } self.end_marker = marker self.more_rows = False @@ -371,8 +406,11 @@ def process_end(self, marker): if self.res_info: self.res_info.more_results = more_results rows_affected = r.get_int8() if tds_base.IS_TDS72_PLUS(self) else r.get_int() - self.log_response_message("got {} message, more_res={}, cancelled={}, rows_affected={}".format( - code_to_str[marker], more_results, was_cancelled, rows_affected)) + self.log_response_message( + "got {} message, more_res={}, cancelled={}, rows_affected={}".format( + code_to_str[marker], more_results, was_cancelled, rows_affected + ) + ) if was_cancelled or (not more_results and not self.in_cancel): self.in_cancel = False self.set_state(tds_base.TDS_IDLE) @@ -381,7 +419,11 @@ def process_end(self, marker): else: self.rows_affected = -1 self.done_flags = status - if self.done_flags & tds_base.TDS_DONE_ERROR and not was_cancelled and not self.in_cancel: + if ( + self.done_flags & tds_base.TDS_DONE_ERROR + and not was_cancelled + and not self.in_cancel + ): self.raise_db_exception() def _ensure_transaction(self) -> None: @@ -389,7 +431,7 @@ def _ensure_transaction(self) -> None: self.begin_tran() def process_env_chg(self): - """ Reads and processes ENVCHANGE stream. + """Reads and processes ENVCHANGE stream. 
Stream info url: http://msdn.microsoft.com/en-us/library/dd303449.aspx """ @@ -400,7 +442,7 @@ def process_env_chg(self): if type_id == tds_base.TDS_ENV_SQLCOLLATION: size = r.get_byte() self.conn.collation = r.get_collation() - logger.info('switched collation to %s', self.conn.collation) + logger.info("switched collation to %s", self.conn.collation) skipall(r, size - 5) # discard old one skipall(r, r.get_byte()) @@ -410,7 +452,10 @@ def process_env_chg(self): self.conn.tds72_transaction = r.get_uint8() # old val, should be 0 skipall(r, r.get_byte()) - elif type_id == tds_base.TDS_ENV_COMMITTRANS or type_id == tds_base.TDS_ENV_ROLLBACKTRANS: + elif ( + type_id == tds_base.TDS_ENV_COMMITTRANS + or type_id == tds_base.TDS_ENV_ROLLBACKTRANS + ): self.conn.tds72_transaction = 0 # new val, should be 0 skipall(r, r.get_byte()) @@ -428,28 +473,28 @@ def process_env_chg(self): self._writer.bufsize = new_block_size elif type_id == tds_base.TDS_ENV_DATABASE: newval = r.read_ucs2(r.get_byte()) - logger.info('switched to database %s', newval) + logger.info("switched to database %s", newval) r.read_ucs2(r.get_byte()) self.conn.env.database = newval elif type_id == tds_base.TDS_ENV_LANG: newval = r.read_ucs2(r.get_byte()) - logger.info('switched language to %s', newval) + logger.info("switched language to %s", newval) r.read_ucs2(r.get_byte()) self.conn.env.language = newval elif type_id == tds_base.TDS_ENV_CHARSET: newval = r.read_ucs2(r.get_byte()) - logger.info('switched charset to %s', newval) + logger.info("switched charset to %s", newval) r.read_ucs2(r.get_byte()) self.conn.env.charset = newval - remap = {'iso_1': 'iso8859-1'} + remap = {"iso_1": "iso8859-1"} self.conn.server_codec = codecs.lookup(remap.get(newval, newval)) elif type_id == tds_base.TDS_ENV_DB_MIRRORING_PARTNER: newval = r.read_ucs2(r.get_byte()) - logger.info('got mirroring partner %s', newval) + logger.info("got mirroring partner %s", newval) r.read_ucs2(r.get_byte()) elif type_id == tds_base.TDS_ENV_LCID: lcid = int(r.read_ucs2(r.get_byte())) - logger.info('switched lcid to %s', lcid) + logger.info("switched lcid to %s", lcid) self.conn.server_codec = codecs.lookup(lcid2charset(lcid)) r.read_ucs2(r.get_byte()) elif type_id == tds_base.TDS_ENV_UNICODE_DATA_SORT_COMP_FLAGS: @@ -462,10 +507,15 @@ def process_env_chg(self): protocol = r.get_byte() protocol_property = r.get_usmallint() alt_server = r.read_ucs2(r.get_usmallint()) - logger.info('got routing info proto=%d proto_prop=%d alt_srv=%s', protocol, protocol_property, alt_server) + logger.info( + "got routing info proto=%d proto_prop=%d alt_srv=%s", + protocol, + protocol_property, + alt_server, + ) self.conn.route = { - 'server': alt_server, - 'port': protocol_property, + "server": alt_server, + "port": protocol_property, } # OLDVALUE = 0x00, 0x00 r.get_usmallint() @@ -475,7 +525,7 @@ def process_env_chg(self): skipall(r, size - 1) def process_auth(self) -> None: - """ Reads and processes SSPI stream. + """Reads and processes SSPI stream. 
Stream info: http://msdn.microsoft.com/en-us/library/dd302844.aspx
         """
         r = self._reader
         w = self._writer
         pdu_size = r.get_smallint()
         if not self.authentication:
-            raise tds_base.Error('Got unexpected token')
+            raise tds_base.Error("Got unexpected token")
         packet = self.authentication.handle_next(readall(r, pdu_size))
         if packet:
             w.write(packet)
@@ -493,10 +543,10 @@ def is_connected(self) -> bool:
         """
         :return: True if transport is connected
         """
-        return self._transport.is_connected() # type: ignore # needs fixing
+        return self._transport.is_connected()  # type: ignore # needs fixing

     def bad_stream(self, msg) -> None:
-        """ Called when input stream contains unexpected data.
+        """Called when input stream contains unexpected data.

         Will close stream and raise :class:`InterfaceError`
         :param msg: Message for InterfaceError exception.
@@ -507,21 +557,19 @@ def bad_stream(self, msg) -> None:

     @property
     def tds_version(self) -> int:
-        """ Returns integer encoded current TDS protocol version
-        """
+        """Returns integer encoded current TDS protocol version"""
         return self._tds.tds_version

     @property
     def conn(self) -> _TdsSocket:
-        """ Reference to owning :class:`_TdsSocket`
-        """
+        """Reference to owning :class:`_TdsSocket`"""
         return self._tds

     def close(self) -> None:
         self._transport.close()

     def set_state(self, state: int) -> int:
-        """ Switches state of the TDS session.
+        """Switches state of the TDS session.

         It also performs state transition checks.
         :param state: New state, one of TDS_PENDING/TDS_READING/TDS_IDLE/TDS_DEAD/TDS_QUERYING
@@ -533,29 +581,44 @@ def set_state(self, state: int) -> int:
             if prior_state in (tds_base.TDS_READING, tds_base.TDS_QUERYING):
                 self.state = tds_base.TDS_PENDING
             else:
-                raise tds_base.InterfaceError('logic error: cannot chage query state from {0} to {1}'.
-                                              format(tds_base.state_names[prior_state], tds_base.state_names[state]))
+                raise tds_base.InterfaceError(
+                    "logic error: cannot change query state from {0} to {1}".format(
+                        tds_base.state_names[prior_state], tds_base.state_names[state]
+                    )
+                )
         elif state == tds_base.TDS_READING:
             # transition to READING is valid only from PENDING
             if self.state != tds_base.TDS_PENDING:
-                raise tds_base.InterfaceError('logic error: cannot change query state from {0} to {1}'.
-                                              format(tds_base.state_names[prior_state], tds_base.state_names[state]))
+                raise tds_base.InterfaceError(
+                    "logic error: cannot change query state from {0} to {1}".format(
+                        tds_base.state_names[prior_state], tds_base.state_names[state]
+                    )
+                )
             else:
                 self.state = state
         elif state == tds_base.TDS_IDLE:
             if prior_state == tds_base.TDS_DEAD:
-                raise tds_base.InterfaceError('logic error: cannot change query state from {0} to {1}'.
-                                              format(tds_base.state_names[prior_state], tds_base.state_names[state]))
+                raise tds_base.InterfaceError(
+                    "logic error: cannot change query state from {0} to {1}".format(
+                        tds_base.state_names[prior_state], tds_base.state_names[state]
+                    )
+                )
             self.state = state
         elif state == tds_base.TDS_DEAD:
             self.state = state
         elif state == tds_base.TDS_QUERYING:
             if self.state == tds_base.TDS_DEAD:
-                raise tds_base.InterfaceError('logic error: cannot change query state from {0} to {1}'.
-                                              format(tds_base.state_names[prior_state], tds_base.state_names[state]))
+                raise tds_base.InterfaceError(
+                    "logic error: cannot change query state from {0} to {1}".format(
+                        tds_base.state_names[prior_state], tds_base.state_names[state]
+                    )
+                )
             elif self.state != tds_base.TDS_IDLE:
-                raise tds_base.InterfaceError('logic error: cannot change query state from {0} to {1}'.
-                                              format(tds_base.state_names[prior_state], tds_base.state_names[state]))
+                raise tds_base.InterfaceError(
+                    "logic error: cannot change query state from {0} to {1}".format(
+                        tds_base.state_names[prior_state], tds_base.state_names[state]
+                    )
+                )
             else:
                 self.rows_affected = tds_base.TDS_NO_COUNT
                 self.internal_sp_called = 0
@@ -566,7 +629,7 @@

     @contextlib.contextmanager
     def querying_context(self, packet_type: int) -> typing.Iterator[None]:
-        """ Context manager for querying.
+        """Context manager for querying.

         Sets state to TDS_QUERYING, and reverts it to TDS_IDLE if exception happens inside managed block,
         and to TDS_PENDING if managed block succeeds and flushes buffer.
@@ -585,7 +648,7 @@
             self._writer.flush()

     def make_param(self, name: str, value: Any) -> tds_base.Param:
-        """ Generates instance of :class:`Param` from value and name
+        """Generates instance of :class:`Param` from value and name

        Value can also be one of the special types:
@@ -604,7 +667,10 @@
             return value

         if isinstance(value, tds_base.Column):
-            warnings.warn("Usage of Column class as parameter is deprecated, use Param class instead", DeprecationWarning)
+            warnings.warn(
+                "Usage of Column class as parameter is deprecated, use Param class instead",
+                DeprecationWarning,
+            )
             return tds_base.Param(
                 name=name,
                 type=value.type,
@@ -629,26 +695,29 @@
             param_value = value
         if param_type is None:
             param_type = self.conn.type_inferrer.from_value(value)
-        param = tds_base.Param(name=name, type=param_type, flags=param_flags, value=param_value)
+        param = tds_base.Param(
+            name=name, type=param_type, flags=param_flags, value=param_value
+        )
         return param

-    def _convert_params(self, parameters: dict[str, Any] | typing.Iterable[Any]) -> List[tds_base.Param]:
-        """ Converts a dict of list of parameters into a list of :class:`Column` instances.
+    def _convert_params(
+        self, parameters: dict[str, Any] | typing.Iterable[Any]
+    ) -> List[tds_base.Param]:
+        """Converts a dict or list of parameters into a list of :class:`Param` instances.

         :param parameters: Can be a list of parameter values, or a dict of parameter names to values.
         :return: A list of :class:`Param` instances.
         """
         if isinstance(parameters, dict):
-            return [self.make_param(name, value)
-                    for name, value in parameters.items()]
+            return [self.make_param(name, value) for name, value in parameters.items()]
         else:
             params = []
             for parameter in parameters:
-                params.append(self.make_param('', parameter))
+                params.append(self.make_param("", parameter))
             return params

     def cancel_if_pending(self) -> None:
-        """ Cancels current pending request.
+        """Cancels current pending request.

         Does nothing if no request is pending, otherwise sends cancel request,
         and waits for response.
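
A minimal sketch of the conversion rules above, assuming an open session
object `sess` (the variable name and values are illustrative only):

    # dict form: keys become the parameter names
    named = sess._convert_params({"@count": 5, "@name": "foo"})
    # -> [Param(name="@count", ...), Param(name="@name", ...)]

    # sequence form: parameters stay positional, names are left empty
    positional = sess._convert_params([5, "foo"])
    # -> [Param(name="", ...), Param(name="", ...)]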
@@ -659,8 +728,13 @@ def cancel_if_pending(self) -> None:
         self.put_cancel()
         self.process_cancel()

-    def submit_rpc(self, rpc_name: tds_base.InternalProc | str, params: List[tds_base.Param], flags: int = 0) -> None:
-        """ Sends an RPC request.
+    def submit_rpc(
+        self,
+        rpc_name: tds_base.InternalProc | str,
+        params: List[tds_base.Param],
+        flags: int = 0,
+    ) -> None:
+        """Sends an RPC request.

         This call will transition session into pending state.
         If some operation is currently pending on the session, it will be
@@ -672,7 +746,7 @@ def submit_rpc(self, rpc_name: tds_base.InternalProc | str, params: List[tds_bas
         :param params: Stored proc parameters, should be a list of :class:`Param` instances.
         :param flags: See spec for possible flags.
         """
-        logger.info('Sending RPC %s flags=%d', rpc_name, flags)
+        logger.info("Sending RPC %s flags=%d", rpc_name, flags)
         self.messages = []
         self.output_params = {}
         self.cancel_if_pending()
@@ -681,7 +755,9 @@ def submit_rpc(self, rpc_name: tds_base.InternalProc | str, params: List[tds_bas
         with self.querying_context(tds_base.PacketType.RPC):
             if tds_base.IS_TDS72_PLUS(self):
                 self._start_query()
-            if tds_base.IS_TDS71_PLUS(self) and isinstance(rpc_name, tds_base.InternalProc):
+            if tds_base.IS_TDS71_PLUS(self) and isinstance(
+                rpc_name, tds_base.InternalProc
+            ):
                 w.put_smallint(-1)
                 w.put_smallint(rpc_name.proc_id)
             else:
@@ -712,8 +788,7 @@ def submit_rpc(self, rpc_name: tds_base.InternalProc | str, params: List[tds_bas
                 # TYPE_INFO structure: https://msdn.microsoft.com/en-us/library/dd358284.aspx
                 serializer = self._tds.type_factory.serializer_by_type(
-                    sql_type=param.type,
-                    collation=self._tds.collation or raw_collation
+                    sql_type=param.type, collation=self._tds.collation or raw_collation
                 )
                 type_id = serializer.type
                 w.put_byte(type_id)
@@ -727,7 +802,11 @@ def _setup_row_factory(self) -> None:
             column_names = [col[0] for col in self.res_info.description]
             self._row_convertor = self._row_strategy(column_names)

-    def callproc(self, procname: tds_base.InternalProc | str, parameters: dict[str, Any] | typing.Iterable[Any]) -> list[Any]:
+    def callproc(
+        self,
+        procname: tds_base.InternalProc | str,
+        parameters: dict[str, Any] | typing.Iterable[Any],
+    ) -> list[Any]:
         self._ensure_transaction()
         results = list(parameters)
         conv_parameters = self._convert_params(parameters)
@@ -752,7 +831,7 @@ def get_proc_outputs(self) -> list[Any]:
         return results

     def get_proc_return_status(self) -> int | None:
-        """ Last executed stored procedure's return value
+        """Last executed stored procedure's return value

         Returns integer value returned by `RETURN` statement from last executed stored procedure.
         If no value was returned or no stored procedure was executed, returns `None`.
@@ -761,7 +840,11 @@
             self.find_return_status()
         return self.ret_status if self.has_status else None

-    def executemany(self, operation: str, params_seq: Iterable[list[Any] | tuple[Any, ...] | dict[str, Any]]) -> None:
+    def executemany(
+        self,
+        operation: str,
+        params_seq: Iterable[list[Any] | tuple[Any, ...] | dict[str, Any]],
+    ) -> None:
         """
         Execute same SQL query multiple times for each parameter set in the `params_seq` list.
         """
@@ -773,7 +856,11 @@
         if counts:
             self.rows_affected = sum(counts)

-    def execute(self, operation: str, params: list[Any] | tuple[Any, ...]
| dict[str, Any] | None = None) -> None:
+    def execute(
+        self,
+        operation: str,
+        params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None,
+    ) -> None:
         self._ensure_transaction()
         if params:
             named_params = {}
@@ -782,9 +869,9 @@ def execute(self, operation: str, params: list[Any] | tuple[Any, ...] | dict[str
                 pid = 1
                 for val in params:
                     if val is None:
-                        names.append('NULL')
+                        names.append("NULL")
                     else:
-                        name = '@P{0}'.format(pid)
+                        name = "@P{0}".format(pid)
                         names.append(name)
                         named_params[name] = val
                         pid += 1
@@ -797,21 +884,27 @@ def execute(self, operation: str, params: list[Any] | tuple[Any, ...] | dict[str
                 rename = {}
                 for name, value in params.items():
                     if value is None:
-                        rename[name] = 'NULL'
+                        rename[name] = "NULL"
                     else:
-                        mssql_name = '@{0}'.format(name.replace(' ', '_'))
+                        mssql_name = "@{0}".format(name.replace(" ", "_"))
                         rename[name] = mssql_name
                         named_params[mssql_name] = value
                 operation = operation % rename
             if named_params:
                 list_named_params = self._convert_params(named_params)
-                param_definition = u','.join(
-                    u'{0} {1}'.format(p.name, p.type.get_declaration())
-                    for p in list_named_params)
+                param_definition = ",".join(
+                    "{0} {1}".format(p.name, p.type.get_declaration())
+                    for p in list_named_params
+                )
                 self.submit_rpc(
                     tds_base.SP_EXECUTESQL,
-                    [self.make_param('', operation), self.make_param('', param_definition)] + list_named_params,
-                    0)
+                    [
+                        self.make_param("", operation),
+                        self.make_param("", param_definition),
+                    ]
+                    + list_named_params,
+                    0,
+                )
             else:
                 self.submit_plain_query(operation)
         else:
@@ -819,7 +912,11 @@ def execute(self, operation: str, params: list[Any] | tuple[Any, ...] | dict[str
         self.begin_response()
         self.find_result_or_done()

-    def execute_scalar(self, query_string: str, params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None) -> Any:
+    def execute_scalar(
+        self,
+        query_string: str,
+        params: list[Any] | tuple[Any, ...] | dict[str, Any] | None = None,
+    ) -> Any:
         """
         This method executes SQL query then returns first column of the first row of the result.
@@ -843,7 +940,7 @@ def execute_scalar(self, query_string: str, params: list[Any] | tuple[Any, ...]
         return row[0]

     def submit_plain_query(self, operation: str) -> None:
-        """ Sends a plain query to server.
+        """Sends a plain query to server.

         This call will transition session into pending state.
         If some operation is currently pending on the session, it will be
@@ -863,8 +960,12 @@
             self._start_query()
             w.write_ucs2(operation)

-    def submit_bulk(self, metadata: list[tds_base.Column], rows: Iterable[collections.abc.Sequence[Any]]) -> None:
-        """ Sends insert bulk command.
+    def submit_bulk(
+        self,
+        metadata: list[tds_base.Column],
+        rows: Iterable[collections.abc.Sequence[Any]],
+    ) -> None:
+        """Sends insert bulk command.

         Spec: http://msdn.microsoft.com/en-us/library/dd358082.aspx
@@ -872,7 +973,7 @@ def submit_bulk(self, metadata: list[tds_base.Column], rows: Iterable[collection
         :param rows: A collection of rows, each row is a collection of values.
         :return:
         """
-        logger.info('Sending INSERT BULK')
+        logger.info("Sending INSERT BULK")
         num_cols = len(metadata)
         w = self._writer
         serializers = []
@@ -911,19 +1012,19 @@ def submit_bulk(self, metadata: list[tds_base.Column], rows: Iterable[collection
         w.put_int(0)

     def put_cancel(self) -> None:
-        """ Sends a cancel request to the server.
+        """Sends a cancel request to the server.

         Switches connection to IN_CANCEL state.
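
        On the wire a cancel request is just a bare 8-byte packet header
        with type 6 (attention) and the end-of-message status bit set, for
        example (a sketch; spid and packet id of 0 are illustrative)::

            >>> import struct
            >>> struct.Struct(">BBHHBx").pack(6, 1, 8, 0, 0)
            b'\x06\x01\x00\x08\x00\x00\x00\x00'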
""" - logger.info('Sending CANCEL') + logger.info("Sending CANCEL") self._writer.begin_packet(tds_base.PacketType.CANCEL) self._writer.flush() self.in_cancel = True - _begin_tran_struct_72 = struct.Struct(' None: - logger.info('Sending BEGIN TRAN il=%x', self._env.isolation_level) + logger.info("Sending BEGIN TRAN il=%x", self._env.isolation_level) self.submit_begin_tran(isolation_level=self._env.isolation_level) self.process_simple_request() @@ -944,8 +1045,8 @@ def submit_begin_tran(self, isolation_level: int = 0) -> None: self.submit_plain_query("BEGIN TRANSACTION") self.conn.tds72_transaction = 1 - _commit_rollback_tran_struct72_hdr = struct.Struct(' None: """ @@ -956,12 +1057,12 @@ def rollback(self, cont: bool) -> None: if self._env.autocommit: return - #if not self._conn or not self._conn.is_connected(): + # if not self._conn or not self._conn.is_connected(): # return if not self._tds.tds72_transaction: return - logger.info('Sending ROLLBACK TRAN') + logger.info("Sending ROLLBACK TRAN") self.submit_rollback(cont, isolation_level=self._env.isolation_level) prev_timeout = self._tds.sock.gettimeout() self._tds.sock.settimeout(None) @@ -999,7 +1100,10 @@ def submit_rollback(self, cont: bool, isolation_level: int = 0) -> None: ) else: self.submit_plain_query( - "IF @@TRANCOUNT > 0 ROLLBACK BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 ROLLBACK") + "IF @@TRANCOUNT > 0 ROLLBACK BEGIN TRANSACTION" + if cont + else "IF @@TRANCOUNT > 0 ROLLBACK" + ) self.conn.tds72_transaction = 1 if cont else 0 def commit(self, cont: bool) -> None: @@ -1007,7 +1111,7 @@ def commit(self, cont: bool) -> None: return if not self._tds.tds72_transaction: return - logger.info('Sending COMMIT TRAN') + logger.info("Sending COMMIT TRAN") self.submit_commit(cont, isolation_level=self._env.isolation_level) prev_timeout = self._tds.sock.gettimeout() self._tds.sock.settimeout(None) @@ -1040,59 +1144,82 @@ def submit_commit(self, cont: bool, isolation_level: int = 0) -> None: ) else: self.submit_plain_query( - "IF @@TRANCOUNT > 0 COMMIT BEGIN TRANSACTION" if cont else "IF @@TRANCOUNT > 0 COMMIT") + "IF @@TRANCOUNT > 0 COMMIT BEGIN TRANSACTION" + if cont + else "IF @@TRANCOUNT > 0 COMMIT" + ) self.conn.tds72_transaction = 1 if cont else 0 - _tds72_query_start = struct.Struct(' None: w = self._writer - w.pack(_TdsSession._tds72_query_start, - 0x16, # total length - 0x12, # length - 2, # type - self.conn.tds72_transaction, - 1, # request count - ) + w.pack( + _TdsSession._tds72_query_start, + 0x16, # total length + 0x12, # length + 2, # type + self.conn.tds72_transaction, + 1, # request count + ) def send_prelogin(self, login: _TdsLogin) -> None: from . 
import intversion + # https://msdn.microsoft.com/en-us/library/dd357559.aspx - instance_name = login.instance_name or 'MSSQLServer' - instance_name_encoded = instance_name.encode('ascii') + instance_name = login.instance_name or "MSSQLServer" + instance_name_encoded = instance_name.encode("ascii") if len(instance_name_encoded) > 65490: - raise ValueError('Instance name is too long') + raise ValueError("Instance name is too long") if tds_base.IS_TDS72_PLUS(self): start_pos = 26 buf = struct.pack( - b'>BHHBHHBHHBHHBHHB', + b">BHHBHHBHHBHHBHHB", # netlib version - PreLoginToken.VERSION, start_pos, 6, + PreLoginToken.VERSION, + start_pos, + 6, # encryption - PreLoginToken.ENCRYPTION, start_pos + 6, 1, + PreLoginToken.ENCRYPTION, + start_pos + 6, + 1, # instance - PreLoginToken.INSTOPT, start_pos + 6 + 1, len(instance_name_encoded) + 1, + PreLoginToken.INSTOPT, + start_pos + 6 + 1, + len(instance_name_encoded) + 1, # thread id - PreLoginToken.THREADID, start_pos + 6 + 1 + len(instance_name_encoded) + 1, 4, + PreLoginToken.THREADID, + start_pos + 6 + 1 + len(instance_name_encoded) + 1, + 4, # MARS enabled - PreLoginToken.MARS, start_pos + 6 + 1 + len(instance_name_encoded) + 1 + 4, 1, + PreLoginToken.MARS, + start_pos + 6 + 1 + len(instance_name_encoded) + 1 + 4, + 1, # end - PreLoginToken.TERMINATOR + PreLoginToken.TERMINATOR, ) else: start_pos = 21 buf = struct.pack( - b'>BHHBHHBHHBHHB', + b">BHHBHHBHHBHHB", # netlib version - PreLoginToken.VERSION, start_pos, 6, + PreLoginToken.VERSION, + start_pos, + 6, # encryption - PreLoginToken.ENCRYPTION, start_pos + 6, 1, + PreLoginToken.ENCRYPTION, + start_pos + 6, + 1, # instance - PreLoginToken.INSTOPT, start_pos + 6 + 1, len(instance_name_encoded) + 1, + PreLoginToken.INSTOPT, + start_pos + 6 + 1, + len(instance_name_encoded) + 1, # thread id - PreLoginToken.THREADID, start_pos + 6 + 1 + len(instance_name_encoded) + 1, 4, + PreLoginToken.THREADID, + start_pos + 6 + 1 + len(instance_name_encoded) + 1, + 4, # end - PreLoginToken.TERMINATOR + PreLoginToken.TERMINATOR, ) assert start_pos == len(buf) w = self._writer @@ -1106,20 +1233,22 @@ def send_prelogin(self, login: _TdsLogin) -> None: w.put_byte(0) # zero terminate instance_name w.put_int(0) # TODO: change this to thread id attribs: dict[str, str | int | bool] = { - 'lib_ver': f'{intversion:x}', - 'enc_flag': f'{login.enc_flag:x}', - 'inst_name': instance_name, + "lib_ver": f"{intversion:x}", + "enc_flag": f"{login.enc_flag:x}", + "inst_name": instance_name, } if tds_base.IS_TDS72_PLUS(self): # MARS (1 enabled) w.put_byte(1 if login.use_mars else 0) - attribs['mars'] = login.use_mars - logger.info('Sending PRELOGIN %s', ' '.join(f'{n}={v!r}' for n, v in attribs.items())) + attribs["mars"] = login.use_mars + logger.info( + "Sending PRELOGIN %s", " ".join(f"{n}={v!r}" for n, v in attribs.items()) + ) w.flush() def begin_response(self) -> ResponseMetadata: - """ Begins reading next response from server. + """Begins reading next response from server. If timeout happens during reading of first packet will send cancellation message. 
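
A minimal standalone sketch of the option table that send_prelogin writes and
parse_prelogin (the following hunk) walks: 5-byte big-endian (token, offset,
length) records terminated by 0xFF, followed by the option payloads. The
bytes below are illustrative, not captured traffic; token 1 is ENCRYPTION per
the linked spec, and payload 2 corresponds to ENCRYPT_NOT_SUP:

    import struct

    def decode_prelogin_options(buf: bytes):
        # Walk (token, offset, length) records until TERMINATOR (0xFF).
        i = 0
        while buf[i] != 0xFF:
            token = buf[i]
            off, ln = struct.unpack_from(">HH", buf, i + 1)
            yield token, buf[off:off + ln]
            i += 5

    pkt = bytes([1, 0, 6, 0, 1, 0xFF, 2])  # one ENCRYPTION record, then its payload
    assert list(decode_prelogin_options(pkt)) == [(1, b"\x02")]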
@@ -1136,48 +1265,58 @@ def process_prelogin(self, login: _TdsLogin) -> None:
         p = self._reader.read_whole_packet()
         size = len(p)
         if size <= 0 or resp_header.type != tds_base.PacketType.REPLY:
-            self.bad_stream('Invalid packet type: {0}, expected PRELOGIN(4)'.format(self._reader.packet_type))
+            self.bad_stream(
+                "Invalid packet type: {0}, expected PRELOGIN(4)".format(
+                    self._reader.packet_type
+                )
+            )
         self.parse_prelogin(octets=p, login=login)

     def parse_prelogin(self, octets: bytes, login: _TdsLogin) -> None:
         from . import tls
+
         # https://msdn.microsoft.com/en-us/library/dd357559.aspx
         size = len(octets)
         p = octets
         # default 2, no certificate, no encryption
         crypt_flag = 2
         i = 0
-        byte_struct = struct.Struct('B')
-        off_len_struct = struct.Struct('>HH')
-        prod_version_struct = struct.Struct('>LH')
+        byte_struct = struct.Struct("B")
+        off_len_struct = struct.Struct(">HH")
+        prod_version_struct = struct.Struct(">LH")
         while True:
             if i >= size:
-                self.bad_stream('Invalid size of PRELOGIN structure')
-            type_id, = byte_struct.unpack_from(p, i)
+                self.bad_stream("Invalid size of PRELOGIN structure")
+            (type_id,) = byte_struct.unpack_from(p, i)
             if type_id == PreLoginToken.TERMINATOR:
                 break
             if i + 4 > size:
-                self.bad_stream('Invalid size of PRELOGIN structure')
+                self.bad_stream("Invalid size of PRELOGIN structure")
             off, l = off_len_struct.unpack_from(p, i + 1)
             if off > size or off + l > size:
-                self.bad_stream('Invalid offset in PRELOGIN structure')
+                self.bad_stream("Invalid offset in PRELOGIN structure")
             if type_id == PreLoginToken.VERSION:
-                self.conn.server_library_version = prod_version_struct.unpack_from(p, off)
+                self.conn.server_library_version = prod_version_struct.unpack_from(
+                    p, off
+                )
             elif type_id == PreLoginToken.ENCRYPTION and l >= 1:
-                crypt_flag, = byte_struct.unpack_from(p, off)
+                (crypt_flag,) = byte_struct.unpack_from(p, off)
             elif type_id == PreLoginToken.MARS:
                 self.conn._mars_enabled = bool(byte_struct.unpack_from(p, off)[0])
             elif type_id == PreLoginToken.INSTOPT:
                 # ignore instance name mismatch
                 pass
             i += 5
-        logger.info("Got PRELOGIN response crypt=%x mars=%d",
-                    crypt_flag, self.conn._mars_enabled)
+        logger.info(
+            "Got PRELOGIN response crypt=%x mars=%d",
+            crypt_flag,
+            self.conn._mars_enabled,
+        )
         # if the server does not have a certificate, do normal login
         login.server_enc_flag = crypt_flag
         if crypt_flag == PreLoginEnc.ENCRYPT_OFF:
             if login.enc_flag == PreLoginEnc.ENCRYPT_ON:
-                self.bad_stream('Server returned unexpected ENCRYPT_ON value')
+                self.bad_stream("Server returned unexpected ENCRYPT_ON value")
             else:
                 # encrypt login packet only
                 tls.establish_channel(self)
@@ -1187,64 +1326,85 @@ def parse_prelogin(self, octets: bytes, login: _TdsLogin) -> None:
         elif crypt_flag == PreLoginEnc.ENCRYPT_REQ:
             if login.enc_flag == PreLoginEnc.ENCRYPT_NOT_SUP:
                 # connection terminated by server and client
-                raise tds_base.Error('Client does not have encryption enabled but it is required by server, '
-                                     'enable encryption and try connecting again')
+                raise tds_base.Error(
+                    "Client does not have encryption enabled but it is required by server, "
+                    "enable encryption and try connecting again"
+                )
             else:
                 # encrypt entire connection
                 tls.establish_channel(self)
         elif crypt_flag == PreLoginEnc.ENCRYPT_NOT_SUP:
             if login.enc_flag == PreLoginEnc.ENCRYPT_ON:
                 # connection terminated by server and client
-                raise tds_base.Error('You requested encryption but it is not supported by server')
+                raise tds_base.Error(
+                    "You requested encryption but it is not supported by server"
+                )
             # do not encrypt anything
         else:
-            self.bad_stream('Unexpected value of enc_flag returned by server: {}'.format(crypt_flag))
+            self.bad_stream(
+                "Unexpected value of enc_flag returned by server: {}".format(crypt_flag)
+            )

     def tds7_send_login(self, login: _TdsLogin) -> None:
         # https://msdn.microsoft.com/en-us/library/dd304019.aspx
         option_flag2 = login.option_flag2
         user_name = login.user_name
         if len(user_name) > 128:
-            raise ValueError('User name should be no longer that 128 characters')
+            raise ValueError("User name should be no longer than 128 characters")
         if len(login.password) > 128:
-            raise ValueError('Password should be not longer than 128 characters')
+            raise ValueError("Password should not be longer than 128 characters")
         if len(login.change_password) > 128:
-            raise ValueError('Password should be not longer than 128 characters')
+            raise ValueError("Password should not be longer than 128 characters")
         if len(login.client_host_name) > 128:
-            raise ValueError('Host name should be not longer than 128 characters')
+            raise ValueError("Host name should not be longer than 128 characters")
         if len(login.app_name) > 128:
-            raise ValueError('App name should be not longer than 128 characters')
+            raise ValueError("App name should not be longer than 128 characters")
         if len(login.server_name) > 128:
-            raise ValueError('Server name should be not longer than 128 characters')
+            raise ValueError("Server name should not be longer than 128 characters")
         if len(login.database) > 128:
-            raise ValueError('Database name should be not longer than 128 characters')
+            raise ValueError("Database name should not be longer than 128 characters")
         if len(login.language) > 128:
-            raise ValueError('Language should be not longer than 128 characters')
+            raise ValueError("Language should not be longer than 128 characters")
         if len(login.attach_db_file) > 260:
-            raise ValueError('File path should be not longer than 260 characters')
+            raise ValueError("File path should not be longer than 260 characters")
         w = self._writer
         w.begin_packet(tds_base.PacketType.LOGIN)
         self.authentication = None
         current_pos = 86 + 8 if tds_base.IS_TDS72_PLUS(self) else 86
         client_host_name = login.client_host_name
         login.client_host_name = client_host_name
-        packet_size = current_pos + (len(client_host_name) + len(login.app_name) + len(login.server_name) +
-                                     len(login.library) + len(login.language) + len(login.database)) * 2
+        packet_size = (
+            current_pos
+            + (
+                len(client_host_name)
+                + len(login.app_name)
+                + len(login.server_name)
+                + len(login.library)
+                + len(login.language)
+                + len(login.database)
+            )
+            * 2
+        )
         if login.auth:
             self.authentication = login.auth
             auth_packet = login.auth.create_packet()
             packet_size += len(auth_packet)
         else:
-            auth_packet = b''
+            auth_packet = b""
             packet_size += (len(user_name) + len(login.password)) * 2
         w.put_int(packet_size)
         w.put_uint(login.tds_version)
         w.put_int(login.blocksize)
         from .
import intversion + w.put_uint(intversion) w.put_int(login.pid) w.put_uint(0) # connection id - option_flag1 = tds_base.TDS_SET_LANG_ON | tds_base.TDS_USE_DB_NOTIFY | tds_base.TDS_INIT_DB_FATAL + option_flag1 = ( + tds_base.TDS_SET_LANG_ON + | tds_base.TDS_USE_DB_NOTIFY + | tds_base.TDS_INIT_DB_FATAL + ) if not login.bulk_copy: option_flag1 |= tds_base.TDS_DUMPLOAD_OFF w.put_byte(option_flag1) @@ -1257,11 +1417,30 @@ def tds7_send_login(self, login: _TdsLogin) -> None: w.put_byte(type_flags) option_flag3 = tds_base.TDS_UNKNOWN_COLLATION_HANDLING w.put_byte(option_flag3 if tds_base.IS_TDS73_PLUS(self) else 0) - mins_fix = int((login.client_tz.utcoffset(datetime.datetime.now()) or datetime.timedelta()).total_seconds()) // 60 - logger.info('Sending LOGIN tds_ver=%x bufsz=%d pid=%d opt1=%x opt2=%x opt3=%x cli_tz=%d cli_lcid=%s ' - 'cli_host=%s lang=%s db=%s', - login.tds_version, w.bufsize, login.pid, option_flag1, option_flag2, option_flag3, mins_fix, - login.client_lcid, client_host_name, login.language, login.database) + mins_fix = ( + int( + ( + login.client_tz.utcoffset(datetime.datetime.now()) + or datetime.timedelta() + ).total_seconds() + ) + // 60 + ) + logger.info( + "Sending LOGIN tds_ver=%x bufsz=%d pid=%d opt1=%x opt2=%x opt3=%x cli_tz=%d cli_lcid=%s " + "cli_host=%s lang=%s db=%s", + login.tds_version, + w.bufsize, + login.pid, + option_flag1, + option_flag2, + option_flag3, + mins_fix, + login.client_lcid, + client_host_name, + login.language, + login.database, + ) w.put_int(mins_fix) w.put_int(login.client_lcid) w.put_smallint(current_pos) @@ -1302,7 +1481,7 @@ def tds7_send_login(self, login: _TdsLogin) -> None: w.put_smallint(len(login.database)) current_pos += len(login.database) * 2 # ClientID - client_id = struct.pack('>Q', login.client_id)[2:] + client_id = struct.pack(">Q", login.client_id)[2:] w.write(client_id) # authentication w.put_smallint(current_pos) @@ -1361,17 +1540,23 @@ def process_login_tokens(self) -> bool: size = r.get_smallint() r.get_byte() # interface version = r.get_uint_be() - self.conn.tds_version = self._SERVER_TO_CLIENT_MAPPING.get(version, version) + self.conn.tds_version = self._SERVER_TO_CLIENT_MAPPING.get( + version, version + ) if not tds_base.IS_TDS7_PLUS(self): - self.bad_stream('Only TDS 7.0 and higher are supported') + self.bad_stream("Only TDS 7.0 and higher are supported") # get server product name # ignore product name length, some servers seem to set it incorrectly r.get_byte() size -= 10 self.conn.product_name = r.read_ucs2(size // 2) product_version = r.get_uint_be() - logger.info('Got LOGINACK tds_ver=%x srv_name=%s srv_ver=%x', - self.conn.tds_version, self.conn.product_name, product_version) + logger.info( + "Got LOGINACK tds_ver=%x srv_name=%s srv_ver=%x", + self.conn.tds_version, + self.conn.product_name, + product_version, + ) # MSSQL 6.5 and 7.0 seem to return strange values for this # using TDS 4.2, something like 5F 06 32 FF for 6.50 self.conn.product_version = product_version @@ -1385,14 +1570,14 @@ def process_login_tokens(self) -> bool: return succeed def process_returnstatus(self) -> None: - self.log_response_message('got RETURNSTATUS message') + self.log_response_message("got RETURNSTATUS message") self.ret_status = self._reader.get_int() self.has_status = True def process_token(self, marker: int) -> Any: handler = _token_map.get(marker) if not handler: - self.bad_stream(f'Invalid TDS marker: {marker}({marker:x})') + self.bad_stream(f"Invalid TDS marker: {marker}({marker:x})") return return handler(self) @@ -1412,7 
+1597,11 @@ def process_simple_request(self) -> None:
         self.begin_response()
         while True:
             marker = self.get_token_id()
-            if marker in (tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN, tds_base.TDS_DONEINPROC_TOKEN):
+            if marker in (
+                tds_base.TDS_DONE_TOKEN,
+                tds_base.TDS_DONEPROC_TOKEN,
+                tds_base.TDS_DONEINPROC_TOKEN,
+            ):
                 self.process_end(marker)
                 if not self.done_flags & tds_base.TDS_DONE_MORE_RESULTS:
                     return
@@ -1437,10 +1626,14 @@ def fetchone(self) -> Any | None:

     def _fetchone(self) -> list[Any] | None:
         if self.res_info is None:
-            raise tds_base.ProgrammingError("Previous statement didn't produce any results")
+            raise tds_base.ProgrammingError(
+                "Previous statement didn't produce any results"
+            )

         if self.skipped_to_status:
-            raise tds_base.ProgrammingError("Unable to fetch any rows after accessing return_status")
+            raise tds_base.ProgrammingError(
+                "Unable to fetch any rows after accessing return_status"
+            )

         if not self.next_row():
             return None
@@ -1455,7 +1648,11 @@ def next_row(self) -> bool:
             if marker in (tds_base.TDS_ROW_TOKEN, tds_base.TDS_NBC_ROW_TOKEN):
                 self.process_token(marker)
                 return True
-            elif marker in (tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN, tds_base.TDS_DONEINPROC_TOKEN):
+            elif marker in (
+                tds_base.TDS_DONE_TOKEN,
+                tds_base.TDS_DONEPROC_TOKEN,
+                tds_base.TDS_DONEINPROC_TOKEN,
+            ):
                 self.process_end(marker)
                 return False
             else:
@@ -1468,7 +1665,11 @@ def find_result_or_done(self) -> bool:
             if marker == tds_base.TDS7_RESULT_TOKEN:
                 self.process_token(marker)
                 return True
-            elif marker in (tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN, tds_base.TDS_DONEINPROC_TOKEN):
+            elif marker in (
+                tds_base.TDS_DONE_TOKEN,
+                tds_base.TDS_DONEPROC_TOKEN,
+                tds_base.TDS_DONEINPROC_TOKEN,
+            ):
                 self.process_end(marker)
                 if self.done_flags & tds_base.TDS_DONE_MORE_RESULTS:
                     if self.done_flags & tds_base.TDS_DONE_COUNT:
@@ -1488,7 +1689,10 @@ def process_rpc(self) -> bool:
                 return True
             elif marker in (tds_base.TDS_DONE_TOKEN, tds_base.TDS_DONEPROC_TOKEN):
                 self.process_end(marker)
-                if self.done_flags & tds_base.TDS_DONE_MORE_RESULTS and not self.done_flags & tds_base.TDS_DONE_COUNT:
+                if (
+                    self.done_flags & tds_base.TDS_DONE_MORE_RESULTS
+                    and not self.done_flags & tds_base.TDS_DONE_COUNT
+                ):
                     # skip results that don't even have a rowcount
                     continue
                 return False
@@ -1513,11 +1717,17 @@ def find_return_status(self) -> None:
     tds_base.TDS_AUTH_TOKEN: _TdsSession.process_auth,
     tds_base.TDS_ENVCHANGE_TOKEN: _TdsSession.process_env_chg,
     tds_base.TDS_DONE_TOKEN: lambda self: self.process_end(tds_base.TDS_DONE_TOKEN),
-    tds_base.TDS_DONEPROC_TOKEN: lambda self: self.process_end(tds_base.TDS_DONEPROC_TOKEN),
-    tds_base.TDS_DONEINPROC_TOKEN: lambda self: self.process_end(tds_base.TDS_DONEINPROC_TOKEN),
+    tds_base.TDS_DONEPROC_TOKEN: lambda self: self.process_end(
+        tds_base.TDS_DONEPROC_TOKEN
+    ),
+    tds_base.TDS_DONEINPROC_TOKEN: lambda self: self.process_end(
+        tds_base.TDS_DONEINPROC_TOKEN
+    ),
     tds_base.TDS_ERROR_TOKEN: lambda self: self.process_msg(tds_base.TDS_ERROR_TOKEN),
     tds_base.TDS_INFO_TOKEN: lambda self: self.process_msg(tds_base.TDS_INFO_TOKEN),
-    tds_base.TDS_CAPABILITY_TOKEN: lambda self: self.process_msg(tds_base.TDS_CAPABILITY_TOKEN),
+    tds_base.TDS_CAPABILITY_TOKEN: lambda self: self.process_msg(
+        tds_base.TDS_CAPABILITY_TOKEN
+    ),
     tds_base.TDS_PARAM_TOKEN: lambda self: self.process_param(),
     tds_base.TDS7_RESULT_TOKEN: lambda self: self.tds7_process_result(),
     tds_base.TDS_ROW_TOKEN: lambda self: self.process_row(),
diff --git a/src/pytds/tds_types.py
diff --git a/src/pytds/tds_types.py b/src/pytds/tds_types.py index 3dc50d5..8a0dca3 100644 --- a/src/pytds/tds_types.py +++ b/src/pytds/tds_types.py @@ -16,8 +16,8 @@ from . import tz -_flt4_struct = struct.Struct('f') -_flt8_struct = struct.Struct('d') +_flt4_struct = struct.Struct("f") +_flt8_struct = struct.Struct("d") _utc = tz.utc @@ -32,18 +32,21 @@ def _applytz(dt, tzinfo): def _decode_num(buf): - """ Decodes little-endian integer from buffer + """Decodes little-endian integer from buffer Buffer can be of any size """ - return functools.reduce(lambda acc, val: acc * 256 + tds_base.my_ord(val), reversed(buf), 0) + return functools.reduce( + lambda acc, val: acc * 256 + tds_base.my_ord(val), reversed(buf), 0 + ) class PlpReader(object): - """ Partially length prefixed reader + """Partially length prefixed reader Spec: http://msdn.microsoft.com/en-us/library/dd340469.aspx """ + def __init__(self, r): """ :param r: An instance of :class:`_TdsReader` @@ -71,8 +74,7 @@ def size(self): return self._size def chunks(self): - """ Generates chunks from stream, each chunk is an instace of bytes. - """ + """Generates chunks from stream, each chunk is an instance of bytes.""" if self.is_null(): return total = 0 @@ -80,7 +82,10 @@ def chunks(self): chunk_len = self._rdr.get_uint() if chunk_len == 0: if not self.is_unknown_len() and total != self._size: - msg = "PLP actual length (%d) doesn't match reported length (%d)" % (total, self._size) + msg = ( + "PLP actual length (%d) doesn't match reported length (%d)" + % (total, self._size) + ) self._rdr.session.bad_stream(msg) return @@ -126,7 +131,7 @@ def __ne__(self, other): class SqlTypeMetaclass(tds_base.CommonEqualityMixin): def __repr__(self): - return '<sql type {}>'.format(self.get_declaration()) + return "<sql type {}>".format(self.get_declaration()) def get_declaration(self): raise NotImplementedError() @@ -134,7 +139,7 @@ def get_declaration(self): class ImageType(SqlTypeMetaclass): def get_declaration(self): - return 'IMAGE' + return "IMAGE" class BinaryType(SqlTypeMetaclass): @@ -146,7 +151,7 @@ def size(self): return self._size def get_declaration(self): - return 'BINARY({})'.format(self._size) + return "BINARY({})".format(self._size) class VarBinaryType(SqlTypeMetaclass): @@ -158,12 +163,12 @@ def size(self): return self._size def get_declaration(self): - return 'VARBINARY({})'.format(self._size) + return "VARBINARY({})".format(self._size) class VarBinaryMaxType(SqlTypeMetaclass): def get_declaration(self): - return 'VARBINARY(MAX)' + return "VARBINARY(MAX)" class CharType(SqlTypeMetaclass): @@ -175,7 +180,7 @@ def size(self): return self._size def get_declaration(self): - return 'CHAR({})'.format(self._size) + return "CHAR({})".format(self._size) class VarCharType(SqlTypeMetaclass): @@ -187,12 +192,12 @@ def size(self): return self._size def get_declaration(self): - return 'VARCHAR({})'.format(self._size) + return "VARCHAR({})".format(self._size) class VarCharMaxType(SqlTypeMetaclass): def get_declaration(self): - return 'VARCHAR(MAX)' + return "VARCHAR(MAX)" class NCharType(SqlTypeMetaclass): @@ -204,7 +209,7 @@ def size(self): return self._size def get_declaration(self): - return 'NCHAR({})'.format(self._size) + return "NCHAR({})".format(self._size) class NVarCharType(SqlTypeMetaclass): @@ -216,37 +221,37 @@ def size(self): return self._size def get_declaration(self): - return 'NVARCHAR({})'.format(self._size) + return "NVARCHAR({})".format(self._size) class NVarCharMaxType(SqlTypeMetaclass): def get_declaration(self): - return 'NVARCHAR(MAX)' + return "NVARCHAR(MAX)" class 
TextType(SqlTypeMetaclass): def get_declaration(self): - return 'TEXT' + return "TEXT" class NTextType(SqlTypeMetaclass): def get_declaration(self): - return 'NTEXT' + return "NTEXT" class XmlType(SqlTypeMetaclass): def get_declaration(self): - return 'XML' + return "XML" class SmallMoneyType(SqlTypeMetaclass): def get_declaration(self): - return 'SMALLMONEY' + return "SMALLMONEY" class MoneyType(SqlTypeMetaclass): def get_declaration(self): - return 'MONEY' + return "MONEY" class DecimalType(SqlTypeMetaclass): @@ -256,8 +261,8 @@ def __init__(self, precision=18, scale=0): @classmethod def from_value(cls, value): - if not (-10 ** 38 + 1 <= value <= 10 ** 38 - 1): - raise tds_base.DataError('Decimal value is out of range') + if not (-(10**38) + 1 <= value <= 10**38 - 1): + raise tds_base.DataError("Decimal value is out of range") with decimal.localcontext() as context: context.prec = 38 value = value.normalize() @@ -279,17 +284,17 @@ def scale(self): return self._scale def get_declaration(self): - return 'DECIMAL({}, {})'.format(self._precision, self._scale) + return "DECIMAL({}, {})".format(self._precision, self._scale) class UniqueIdentifierType(SqlTypeMetaclass): def get_declaration(self): - return 'UNIQUEIDENTIFIER' + return "UNIQUEIDENTIFIER" class VariantType(SqlTypeMetaclass): def get_declaration(self): - return 'SQL_VARIANT' + return "SQL_VARIANT" class SqlValueMetaclass(tds_base.CommonEqualityMixin): @@ -297,13 +302,14 @@ class SqlValueMetaclass(tds_base.CommonEqualityMixin): class BaseTypeSerializer(tds_base.CommonEqualityMixin): - """ Base type for TDS data types. + """Base type for TDS data types. All TDS types should derive from it. In addition actual types should provide the following: - type - class variable storing type identifier """ + type = 0 def __init__(self, precision=None, scale=None, size=None): @@ -324,12 +330,12 @@ def size(self): return self._size def get_typeid(self): - """ Returns type identifier of type. """ + """Returns type identifier of type.""" return self.type @classmethod def from_stream(cls, r): - """ Class method that reads and returns a type instance. + """Class method that reads and returns a type instance. :param r: An instance of :class:`_TdsReader` to read type from. @@ -338,7 +344,7 @@ def from_stream(cls, r): raise NotImplementedError def write_info(self, w): - """ Writes type info into w stream. + """Writes type info into w stream. :param w: An instance of :class:`_TdsWriter` to write into. @@ -348,7 +354,7 @@ def write_info(self, w): raise NotImplementedError def write(self, w, value): - """ Writes type's value into stream + """Writes type's value into stream :param w: An instance of :class:`_TdsWriter` to write into. :param value: A value to be stored, should be compatible with the type @@ -358,7 +364,7 @@ def write(self, w, value): raise NotImplementedError def read(self, r): - """ Reads value from the stream. + """Reads value from the stream. :param r: An instance of :class:`_TdsReader` to read value from. :return: A read value. @@ -372,7 +378,7 @@ def set_chunk_handler(self, chunk_handler): class BasePrimitiveTypeSerializer(BaseTypeSerializer): - """ Base type for primitive TDS data types. + """Base type for primitive TDS data types. Primitive type is a fixed size type with no type arguments. All primitive TDS types should derive from it. @@ -400,7 +406,7 @@ def write_info(self, w): class BaseTypeSerializerN(BaseTypeSerializer): - """ Base type for nullable TDS data types. + """Base type for nullable TDS data types. 
All nullable TDS types should derive from it. In addition actual types should provide the following: @@ -408,6 +414,7 @@ class BaseTypeSerializerN(BaseTypeSerializer): - type - class variable storing type identifier - subtypes - class variable storing dict {subtype_size: subtype_instance} """ + subtypes: dict[int, BaseTypeSerializer] = {} def __init__(self, size): @@ -422,7 +429,7 @@ def get_typeid(self): def from_stream(cls, r): size = r.get_byte() if size not in cls.subtypes: - raise tds_base.InterfaceError('Invalid %s size' % cls.type, size) + raise tds_base.InterfaceError("Invalid %s size" % cls.type, size) return cls(size) def write_info(self, w): @@ -433,7 +440,7 @@ def read(self, r): if size == 0: return None if size not in self.subtypes: - raise r.session.bad_stream('Invalid %s size' % self.type, size) + raise r.session.bad_stream("Invalid %s size" % self.type, size) return self.subtypes[size].read(r) def write(self, w, val): @@ -448,7 +455,7 @@ class BitType(SqlTypeMetaclass): type = tds_base.SYBBITN def get_declaration(self): - return 'BIT' + return "BIT" class TinyIntType(SqlTypeMetaclass): @@ -456,7 +463,7 @@ class TinyIntType(SqlTypeMetaclass): size = 1 def get_declaration(self): - return 'TINYINT' + return "TINYINT" class SmallIntType(SqlTypeMetaclass): @@ -464,7 +471,7 @@ class SmallIntType(SqlTypeMetaclass): size = 2 def get_declaration(self): - return 'SMALLINT' + return "SMALLINT" class IntType(SqlTypeMetaclass): @@ -472,11 +479,12 @@ class IntType(SqlTypeMetaclass): Integer type, corresponds to `INT `_ type in the MSSQL server. """ + type = tds_base.SYBINTN size = 4 def get_declaration(self): - return 'INT' + return "INT" class BigIntType(SqlTypeMetaclass): @@ -484,22 +492,22 @@ class BigIntType(SqlTypeMetaclass): size = 8 def get_declaration(self): - return 'BIGINT' + return "BIGINT" class RealType(SqlTypeMetaclass): def get_declaration(self): - return 'REAL' + return "REAL" class FloatType(SqlTypeMetaclass): def get_declaration(self): - return 'FLOAT' + return "FLOAT" class BitSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBBIT - declaration = 'BIT' + declaration = "BIT" def write(self, w, value): w.put_byte(1 if value else 0) @@ -507,6 +515,7 @@ def write(self, w, value): def read(self, r): return bool(r.get_byte()) + BitSerializer.instance = bit_serializer = BitSerializer() @@ -519,15 +528,15 @@ def __init__(self, typ): self._typ = typ def __repr__(self): - return 'BitNSerializer({})'.format(self._typ) + return "BitNSerializer({})".format(self._typ) -#BitNSerializer.instance = BitNSerializer(BitType()) +# BitNSerializer.instance = BitNSerializer(BitType()) class TinyIntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT1 - declaration = 'TINYINT' + declaration = "TINYINT" def write(self, w, val): w.put_byte(val) @@ -535,12 +544,13 @@ def write(self, w, val): def read(self, r): return r.get_byte() + TinyIntSerializer.instance = tiny_int_serializer = TinyIntSerializer() class SmallIntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT2 - declaration = 'SMALLINT' + declaration = "SMALLINT" def write(self, w, val): w.put_smallint(val) @@ -548,12 +558,13 @@ def write(self, w, val): def read(self, r): return r.get_smallint() + SmallIntSerializer.instance = small_int_serializer = SmallIntSerializer() class IntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT4 - declaration = 'INT' + declaration = "INT" def write(self, w, val): w.put_int(val) @@ -561,12 +572,13 @@ def write(self, w, val): def read(self, r): return r.get_int() 
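# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch, not the pytds classes): every
# fixed-size serializer reformatted in these hunks follows the same pattern,
# a TDS type id plus symmetric read/write built on a struct format. The class
# and method names below are assumptions made for this self-contained example.

import struct

class Int4Codec:
    _fmt = struct.Struct("<i")  # little-endian 32-bit signed, like SYBINT4

    def write(self, buf: bytearray, val: int) -> None:
        buf += self._fmt.pack(val)  # append the 4-byte encoding in place

    def read(self, buf: bytes, offset: int = 0) -> int:
        (val,) = self._fmt.unpack_from(buf, offset)
        return val

codec = Int4Codec()
b = bytearray()
codec.write(b, -5)
assert codec.read(bytes(b)) == -5  # round-trip check
# ---------------------------------------------------------------------------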
+ IntSerializer.instance = int_serializer = IntSerializer() class BigIntSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBINT8 - declaration = 'BIGINT' + declaration = "BIGINT" def write(self, w, val): w.put_int8(val) @@ -574,6 +586,7 @@ def write(self, w, val): def read(self, r): return r.get_int8() + BigIntSerializer.instance = big_int_serializer = BigIntSerializer() @@ -602,16 +615,16 @@ def __init__(self, typ): def from_stream(cls, r): size = r.get_byte() if size not in cls.subtypes: - raise tds_base.InterfaceError('Invalid %s size' % cls.type, size) + raise tds_base.InterfaceError("Invalid %s size" % cls.type, size) return cls(cls.type_by_size[size]) def __repr__(self): - return 'IntN({})'.format(self.size) + return "IntN({})".format(self.size) class RealSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBREAL - declaration = 'REAL' + declaration = "REAL" def write(self, w, val): w.pack(_flt4_struct, val) @@ -619,12 +632,13 @@ def write(self, w, val): def read(self, r): return r.unpack(_flt4_struct)[0] + RealSerializer.instance = real_serializer = RealSerializer() class FloatSerializer(BasePrimitiveTypeSerializer): type = tds_base.SYBFLT8 - declaration = 'FLOAT' + declaration = "FLOAT" def write(self, w, val): w.pack(_flt8_struct, val) @@ -632,6 +646,7 @@ def write(self, w, val): def read(self, r): return r.unpack(_flt8_struct)[0] + FloatSerializer.instance = float_serializer = FloatSerializer() @@ -718,7 +733,7 @@ class VarChar72Serializer(VarChar71Serializer): def from_stream(cls, r): size = r.get_usmallint() collation = r.get_collation() - if size == 0xffff: + if size == 0xFFFF: return VarCharMaxSerializer(collation) return cls(size, collation) @@ -797,7 +812,7 @@ def write_info(self, w): def write(self, w, val): if val is None: - w.put_usmallint(0xffff) + w.put_usmallint(0xFFFF) else: if isinstance(val, bytes): val = tds_base.force_unicode(val) @@ -808,7 +823,7 @@ def write(self, w, val): def read(self, r): size = r.get_usmallint() - if size == 0xffff: + if size == 0xFFFF: return None return r.read_str(size, ucs2_codec) @@ -830,7 +845,7 @@ class NVarChar72Serializer(NVarChar71Serializer): def from_stream(cls, r): size = r.get_usmallint() collation = r.get_collation() - if size == 0xffff: + if size == 0xFFFF: return NVarCharMaxSerializer(collation=collation) return cls(size / 2, collation=collation) @@ -841,7 +856,7 @@ def __init__(self, collation=raw_collation): self._chunk_handler = _DefaultChunkedHandler(StringIO()) def __repr__(self): - return 'NVarCharMax(s={},c={})'.format(self.size, repr(self._collation)) + return "NVarCharMax(s={},c={})".format(self.size, repr(self._collation)) def get_typeid(self): return tds_base.SYBNTEXT @@ -889,14 +904,14 @@ def set_chunk_handler(self, chunk_handler): class XmlSerializer(NVarCharMaxSerializer): type = tds_base.SYBMSXML - declaration = 'XML' + declaration = "XML" def __init__(self, schema=None): super(XmlSerializer, self).__init__(0) self._schema = schema or {} def __repr__(self): - return 'XmlSerializer(schema={})'.format(repr(self._schema)) + return "XmlSerializer(schema={})".format(repr(self._schema)) def get_typeid(self): return self.type @@ -906,29 +921,29 @@ def from_stream(cls, r): has_schema = r.get_byte() schema = {} if has_schema: - schema['dbname'] = r.read_ucs2(r.get_byte()) - schema['owner'] = r.read_ucs2(r.get_byte()) - schema['collection'] = r.read_ucs2(r.get_smallint()) + schema["dbname"] = r.read_ucs2(r.get_byte()) + schema["owner"] = r.read_ucs2(r.get_byte()) + schema["collection"] = 
r.read_ucs2(r.get_smallint()) return cls(schema) def write_info(self, w): if self._schema: w.put_byte(1) - w.put_byte(len(self._schema['dbname'])) - w.write_ucs2(self._schema['dbname']) - w.put_byte(len(self._schema['owner'])) - w.write_ucs2(self._schema['owner']) - w.put_usmallint(len(self._schema['collection'])) - w.write_ucs2(self._schema['collection']) + w.put_byte(len(self._schema["dbname"])) + w.write_ucs2(self._schema["dbname"]) + w.put_byte(len(self._schema["owner"])) + w.write_ucs2(self._schema["owner"]) + w.put_usmallint(len(self._schema["collection"])) + w.write_ucs2(self._schema["collection"]) else: w.put_byte(0) class Text70Serializer(BaseTypeSerializer): type = tds_base.SYBTEXT - declaration = 'TEXT' + declaration = "TEXT" - def __init__(self, size=0, table_name='', collation=raw_collation, codec=None): + def __init__(self, size=0, table_name="", collation=raw_collation, codec=None): super(Text70Serializer, self).__init__(size=size) self._table_name = table_name self._collation = collation @@ -939,7 +954,9 @@ def __init__(self, size=0, table_name='', collation=raw_collation, codec=None): self._chunk_handler = None def __repr__(self): - return 'Text70(size={},table_name={},codec={})'.format(self.size, self._table_name, self._codec) + return "Text70(size={},table_name={},codec={})".format( + self.size, self._table_name, self._codec + ) @classmethod def from_stream(cls, r): @@ -987,7 +1004,7 @@ def set_chunk_handler(self, chunk_handler): class Text71Serializer(Text70Serializer): def __repr__(self): - return 'Text71(size={}, table_name={}, collation={})'.format( + return "Text71(size={}, table_name={}, collation={})".format( self.size, self._table_name, repr(self._collation) ) @@ -1005,7 +1022,9 @@ def write_info(self, w): class Text72Serializer(Text71Serializer): def __init__(self, size=0, table_name_parts=(), collation=raw_collation): - super(Text72Serializer, self).__init__(size=size, table_name='.'.join(table_name_parts), collation=collation) + super(Text72Serializer, self).__init__( + size=size, table_name=".".join(table_name_parts), collation=collation + ) self._table_name_parts = table_name_parts @classmethod @@ -1021,16 +1040,16 @@ def from_stream(cls, r): class NText70Serializer(BaseTypeSerializer): type = tds_base.SYBNTEXT - declaration = 'NTEXT' + declaration = "NTEXT" - def __init__(self, size=0, table_name='', collation=raw_collation): + def __init__(self, size=0, table_name="", collation=raw_collation): super(NText70Serializer, self).__init__(size=size) self._collation = collation self._table_name = table_name self._chunk_handler = _DefaultChunkedHandler(StringIO()) def __repr__(self): - return 'NText70(size={}, table_name={})'.format(self.size, self._table_name) + return "NText70(size={}, table_name={})".format(self.size, self._table_name) @classmethod def from_stream(cls, r): @@ -1065,9 +1084,9 @@ def set_chunk_handler(self, chunk_handler): class NText71Serializer(NText70Serializer): def __repr__(self): - return 'NText71(size={}, table_name={}, collation={})'.format(self.size, - self._table_name, - repr(self._collation)) + return "NText71(size={}, table_name={}, collation={})".format( + self.size, self._table_name, repr(self._collation) + ) @classmethod def from_stream(cls, r): @@ -1087,8 +1106,9 @@ def __init__(self, size=0, table_name_parts=(), collation=raw_collation): self._table_name_parts = table_name_parts def __repr__(self): - return 'NText72Serializer(s={},table_name={},coll={})'.format( - self.size, self._table_name_parts, self._collation) + return 
"NText72Serializer(s={},table_name={},coll={})".format( + self.size, self._table_name_parts, self._collation + ) @classmethod def from_stream(cls, r): @@ -1103,7 +1123,7 @@ def from_stream(cls, r): class Binary(bytes, SqlValueMetaclass): def __repr__(self): - return 'Binary({0})'.format(super(Binary, self).__repr__()) + return "Binary({0})".format(super(Binary, self).__repr__()) class VarBinarySerializer(BaseTypeSerializer): @@ -1113,7 +1133,7 @@ def __init__(self, size): super(VarBinarySerializer, self).__init__(size=size) def __repr__(self): - return 'VarBinary({})'.format(self.size) + return "VarBinary({})".format(self.size) @classmethod def from_stream(cls, r): @@ -1125,26 +1145,26 @@ def write_info(self, w): def write(self, w, val): if val is None: - w.put_usmallint(0xffff) + w.put_usmallint(0xFFFF) else: w.put_usmallint(len(val)) w.write(val) def read(self, r): size = r.get_usmallint() - if size == 0xffff: + if size == 0xFFFF: return None return tds_base.readall(r, size) class VarBinarySerializer72(VarBinarySerializer): def __repr__(self): - return 'VarBinary72({})'.format(self.size) + return "VarBinary72({})".format(self.size) @classmethod def from_stream(cls, r): size = r.get_usmallint() - if size == 0xffff: + if size == 0xFFFF: return VarBinarySerializerMax() return cls(size) @@ -1155,7 +1175,7 @@ def __init__(self): self._chunk_handler = _DefaultChunkedHandler(BytesIO()) def __repr__(self): - return 'VarBinaryMax()' + return "VarBinaryMax()" def write_info(self, w): w.put_usmallint(tds_base.PLP_MARKER) @@ -1185,8 +1205,9 @@ def set_chunk_handler(self, chunk_handler): class UDT72Serializer(BaseTypeSerializer): # Data type definition stream used for UDT_INFO in TYPE_INFO # https://msdn.microsoft.com/en-us/library/a57df60e-d0a6-4e7e-a2e5-ccacd277c673/ - def __init__(self, max_byte_size, db_name, schema_name, type_name, - assembly_qualified_name): + def __init__( + self, max_byte_size, db_name, schema_name, type_name, assembly_qualified_name + ): self.max_byte_size = max_byte_size self.db_name = db_name self.schema_name = schema_name @@ -1195,19 +1216,28 @@ def __init__(self, max_byte_size, db_name, schema_name, type_name, super(UDT72Serializer, self).__init__() def __repr__(self): - return ('UDT72Serializer(max_byte_size={}, db_name={}, ' - 'schema_name={}, type_name={}, ' - 'assembly_qualified_name={})'.format( - *map(repr, ( - self.max_byte_size, self.db_name, self.schema_name, - self.type_name, self.assembly_qualified_name))) + return ( + "UDT72Serializer(max_byte_size={}, db_name={}, " + "schema_name={}, type_name={}, " + "assembly_qualified_name={})".format( + *map( + repr, + ( + self.max_byte_size, + self.db_name, + self.schema_name, + self.type_name, + self.assembly_qualified_name, + ), + ) + ) ) @classmethod def from_stream(cls, r): # MAX_BYTE_SIZE max_byte_size = r.get_usmallint() - assert max_byte_size == 0xffff or 1 < max_byte_size < 8000 + assert max_byte_size == 0xFFFF or 1 < max_byte_size < 8000 # DB_NAME -- B_VARCHAR db_name = r.read_ucs2(r.get_byte()) # SCHEMA_NAME -- B_VARCHAR @@ -1218,14 +1248,15 @@ def from_stream(cls, r): # a US_VARCHAR (2 bytes length prefix) # containing ASSEMBLY_QUALIFIED_NAME assembly_qualified_name = r.read_ucs2(r.get_smallint()) - return cls(max_byte_size, db_name, schema_name, type_name, - assembly_qualified_name) + return cls( + max_byte_size, db_name, schema_name, type_name, assembly_qualified_name + ) def read(self, r): r = PlpReader(r) if r.is_null(): return None - return b''.join(r.chunks()) + return b"".join(r.chunks()) class 
UDT72SerializerMax(UDT72Serializer): @@ -1235,15 +1266,15 @@ def __init__(self, *args, **kwargs): class Image70Serializer(BaseTypeSerializer): type = tds_base.SYBIMAGE - declaration = 'IMAGE' + declaration = "IMAGE" - def __init__(self, size=0, table_name=''): + def __init__(self, size=0, table_name=""): super(Image70Serializer, self).__init__(size=size) self._table_name = table_name self._chunk_handler = _DefaultChunkedHandler(BytesIO()) def __repr__(self): - return 'Image70(tn={},s={})'.format(repr(self._table_name), self.size) + return "Image70(tn={},s={})".format(repr(self._table_name), self.size) @classmethod def from_stream(cls, r): @@ -1279,11 +1310,11 @@ def set_chunk_handler(self, chunk_handler): class Image72Serializer(Image70Serializer): def __init__(self, size=0, parts=()): - super(Image72Serializer, self).__init__(size=size, table_name='.'.join(parts)) + super(Image72Serializer, self).__init__(size=size, table_name=".".join(parts)) self._parts = parts def __repr__(self): - return 'Image72(p={},s={})'.format(self._parts, self.size) + return "Image72(p={},s={})".format(self._parts, self.size) @classmethod def from_stream(cls, r): @@ -1300,16 +1331,17 @@ def from_stream(cls, r): class SmallDateTimeType(SqlTypeMetaclass): def get_declaration(self): - return 'SMALLDATETIME' + return "SMALLDATETIME" class DateTimeType(SqlTypeMetaclass): def get_declaration(self): - return 'DATETIME' + return "DATETIME" class SmallDateTime(SqlValueMetaclass): """Corresponds to MSSQL smalldatetime""" + def __init__(self, days, minutes): """ @@ -1328,7 +1360,9 @@ def minutes(self): return self._minutes def to_pydatetime(self): - return _datetime_base_date + datetime.timedelta(days=self._days, minutes=self._minutes) + return _datetime_base_date + datetime.timedelta( + days=self._days, minutes=self._minutes + ) @classmethod def from_pydatetime(cls, dt): @@ -1354,14 +1388,16 @@ def from_stream(cls, r): class SmallDateTimeSerializer(BasePrimitiveTypeSerializer, BaseDateTimeSerializer): type = tds_base.SYBDATETIME4 - declaration = 'SMALLDATETIME' + declaration = "SMALLDATETIME" - _struct = struct.Struct(' 38: - raise tds_base.DataError('Precision of decimal value is out of range') + raise tds_base.DataError("Precision of decimal value is out of range") def __repr__(self): - return 'MsDecimal(scale={}, prec={})'.format(self.scale, self.precision) + return "MsDecimal(scale={}, prec={})".format(self.scale, self.precision) @classmethod def from_value(cls, value): @@ -1913,7 +2010,7 @@ def write(self, w, value): if not positive: val *= -1 size -= 1 - val *= 10 ** scale + val *= 10**scale for i in range(size): w.put_byte(int(val % 256)) val //= 256 @@ -1926,7 +2023,7 @@ def _decode(self, positive, buf): ctx.prec = 38 if not positive: val *= -1 - val /= 10 ** self._scale + val /= 10**self._scale return val def read_fixed(self, r, size): @@ -1943,7 +2040,7 @@ def read(self, r): class Money4Serializer(BasePrimitiveTypeSerializer): type = tds_base.SYBMONEY4 - declaration = 'SMALLMONEY' + declaration = "SMALLMONEY" def read(self, r): return decimal.Decimal(r.get_int()) / 10000 @@ -1952,26 +2049,28 @@ def write(self, w, val): val = int(val * 10000) w.put_int(val) + Money4Serializer.instance = money4_serializer = Money4Serializer() class Money8Serializer(BasePrimitiveTypeSerializer): type = tds_base.SYBMONEY - declaration = 'MONEY' + declaration = "MONEY" - _struct = struct.Struct(' 128: - raise ValueError("Schema part of TVP name should be no longer than 128 characters") + raise ValueError( + "Schema part of TVP name 
should be no longer than 128 characters" + ) if len(typ_name) > 128: - raise ValueError("Name part of TVP name should be no longer than 128 characters") + raise ValueError( + "Name part of TVP name should be no longer than 128 characters" + ) if columns is not None: if len(columns) > 1024: raise ValueError("TVP cannot have more than 1024 columns") if len(columns) < 1: raise ValueError("TVP must have at least one column") - self._typ_dbname = '' # dbname should always be empty string for TVP according to spec + self._typ_dbname = ( + "" # dbname should always be empty string for TVP according to spec + ) self._typ_schema = typ_schema self._typ_name = typ_name self._columns = columns def __repr__(self): - return 'TableType(s={},n={},cols={})'.format( + return "TableType(s={},n={},cols={})".format( self._typ_schema, self._typ_name, repr(self._columns) ) def get_declaration(self): assert not self._typ_dbname if self._typ_schema: - full_name = '{}.{}'.format(self._typ_schema, self._typ_name) + full_name = "{}.{}".format(self._typ_schema, self._typ_name) else: full_name = self._typ_name - return '{} READONLY'.format(full_name) + return "{} READONLY".format(full_name) @property def typ_schema(self): @@ -2168,14 +2275,17 @@ class TableValuedParam(SqlValueMetaclass): """ Used to represent a value of table-valued parameter """ + def __init__(self, type_name=None, columns=None, rows=None): # parsing type name - self._typ_schema = '' - self._typ_name = '' + self._typ_schema = "" + self._typ_name = "" if type_name: - parts = type_name.split('.') + parts = type_name.split(".") if len(parts) > 2: - raise ValueError('Type name should consist of at most 2 parts, e.g. dbo.MyType') + raise ValueError( + "Type name should consist of at most 2 parts, e.g. dbo.MyType" + ) self._typ_name = parts[-1] if len(parts) > 1: self._typ_schema = parts[0] @@ -2206,13 +2316,15 @@ def peek_row(self): try: rows = iter(self._rows) except TypeError: - raise tds_base.DataError('rows should be iterable') + raise tds_base.DataError("rows should be iterable") try: row = next(rows) except StopIteration: # no rows - raise tds_base.DataError("Cannot infer columns from rows for TVP because there are no rows") + raise tds_base.DataError( + "Cannot infer columns from rows for TVP because there are no rows" + ) else: # put row back self._rows = itertools.chain([row], rows) @@ -2229,12 +2341,12 @@ class TableSerializer(BaseTypeSerializer): type = tds_base.TVPTYPE def read(self, r): - """ According to spec TDS does not support output TVP values """ + """According to spec TDS does not support output TVP values""" raise NotImplementedError @classmethod def from_stream(cls, r): - """ According to spec TDS does not support output TVP values """ + """According to spec TDS does not support output TVP values""" raise NotImplementedError def __init__(self, table_type, columns_serializers): @@ -2247,7 +2359,7 @@ def table_type(self): return self._table_type def __repr__(self): - return 'TableSerializer(t={},c={})'.format( + return "TableSerializer(t={},c={})".format( repr(self._table_type), repr(self._columns_serializers) ) @@ -2293,7 +2405,7 @@ def write(self, w, val): w.put_byte(type_id) serializer.write_info(w) - w.write_b_varchar('') # ColName, must be empty in TVP according to spec + w.write_b_varchar("") # ColName, must be empty in TVP according to spec # here can optionally send TVP_ORDER_UNIQUE and TVP_COLUMN_ORDERING # https://msdn.microsoft.com/en-us/library/dd305261.aspx @@ -2352,33 +2464,39 @@ def write(self, w, val): } _type_map71 = 
_type_map.copy() -_type_map71.update({ - tds_base.XSYBCHAR: VarChar71Serializer, - tds_base.XSYBNCHAR: NVarChar71Serializer, - tds_base.XSYBVARCHAR: VarChar71Serializer, - tds_base.XSYBNVARCHAR: NVarChar71Serializer, - tds_base.SYBTEXT: Text71Serializer, - tds_base.SYBNTEXT: NText71Serializer, -}) +_type_map71.update( + { + tds_base.XSYBCHAR: VarChar71Serializer, + tds_base.XSYBNCHAR: NVarChar71Serializer, + tds_base.XSYBVARCHAR: VarChar71Serializer, + tds_base.XSYBNVARCHAR: NVarChar71Serializer, + tds_base.SYBTEXT: Text71Serializer, + tds_base.SYBNTEXT: NText71Serializer, + } +) _type_map72 = _type_map.copy() -_type_map72.update({ - tds_base.XSYBCHAR: VarChar72Serializer, - tds_base.XSYBNCHAR: NVarChar72Serializer, - tds_base.XSYBVARCHAR: VarChar72Serializer, - tds_base.XSYBNVARCHAR: NVarChar72Serializer, - tds_base.SYBTEXT: Text72Serializer, - tds_base.SYBNTEXT: NText72Serializer, - tds_base.XSYBBINARY: VarBinarySerializer72, - tds_base.XSYBVARBINARY: VarBinarySerializer72, - tds_base.SYBIMAGE: Image72Serializer, - tds_base.UDTTYPE: UDT72Serializer, -}) +_type_map72.update( + { + tds_base.XSYBCHAR: VarChar72Serializer, + tds_base.XSYBNCHAR: NVarChar72Serializer, + tds_base.XSYBVARCHAR: VarChar72Serializer, + tds_base.XSYBNVARCHAR: NVarChar72Serializer, + tds_base.SYBTEXT: Text72Serializer, + tds_base.SYBNTEXT: NText72Serializer, + tds_base.XSYBBINARY: VarBinarySerializer72, + tds_base.XSYBVARBINARY: VarBinarySerializer72, + tds_base.SYBIMAGE: Image72Serializer, + tds_base.UDTTYPE: UDT72Serializer, + } +) _type_map73 = _type_map72.copy() -_type_map73.update({ - tds_base.TVPTYPE: TableSerializer, -}) +_type_map73.update( + { + tds_base.TVPTYPE: TableSerializer, + } +) def sql_type_by_declaration(declaration): @@ -2389,6 +2507,7 @@ class SerializerFactory(object): """ Factory class for TDS data types """ + def __init__(self, tds_ver): self._tds_ver = tds_ver if self._tds_ver >= tds_base.TDS73: @@ -2403,7 +2522,7 @@ def __init__(self, tds_ver): def get_type_serializer(self, tds_type_id): type_class = self._type_map.get(tds_type_id) if not type_class: - raise tds_base.InterfaceError('Invalid type id {}'.format(tds_type_id)) + raise tds_base.InterfaceError("Invalid type id {}".format(tds_type_id)) return type_class def long_binary_type(self): @@ -2437,7 +2556,9 @@ def datetime_with_tz(self, precision): if self._tds_ver >= tds_base.TDS72: return DateTimeOffsetType(precision=precision) else: - raise tds_base.DataError('Given TDS version does not support DATETIMEOFFSET type') + raise tds_base.DataError( + "Given TDS version does not support DATETIMEOFFSET type" + ) def date(self): if self._tds_ver >= tds_base.TDS72: @@ -2449,11 +2570,13 @@ def time(self, precision): if self._tds_ver >= tds_base.TDS72: return TimeType(precision=precision) else: - raise tds_base.DataError('Given TDS version does not support TIME type') + raise tds_base.DataError("Given TDS version does not support TIME type") def serializer_by_declaration(self, declaration, connection): sql_type = sql_type_by_declaration(declaration) - return self.serializer_by_type(sql_type=sql_type, collation=connection.collation) + return self.serializer_by_type( + sql_type=sql_type, collation=connection.collation + ) def serializer_by_type(self, sql_type, collation=raw_collation): typ = sql_type @@ -2478,13 +2601,19 @@ def serializer_by_type(self, sql_type, collation=raw_collation): elif isinstance(typ, CharType): return self._type_map[tds_base.XSYBCHAR](size=typ.size, collation=collation) elif isinstance(typ, VarCharType): - return 
self._type_map[tds_base.XSYBVARCHAR](size=typ.size, collation=collation) + return self._type_map[tds_base.XSYBVARCHAR]( + size=typ.size, collation=collation + ) elif isinstance(typ, VarCharMaxType): return VarCharMaxSerializer(collation=collation) elif isinstance(typ, NCharType): - return self._type_map[tds_base.XSYBNCHAR](size=typ.size, collation=collation) + return self._type_map[tds_base.XSYBNCHAR]( + size=typ.size, collation=collation + ) elif isinstance(typ, NVarCharType): - return self._type_map[tds_base.XSYBNVARCHAR](size=typ.size, collation=collation) + return self._type_map[tds_base.XSYBNVARCHAR]( + size=typ.size, collation=collation + ) elif isinstance(typ, NVarCharMaxType): return NVarCharMaxSerializer(collation=collation) elif isinstance(typ, TextType): @@ -2502,7 +2631,9 @@ def serializer_by_type(self, sql_type, collation=raw_collation): elif isinstance(typ, ImageType): return self._type_map[tds_base.SYBIMAGE]() elif isinstance(typ, DecimalType): - return self._type_map[tds_base.SYBDECIMAL](scale=typ.scale, precision=typ.precision) + return self._type_map[tds_base.SYBDECIMAL]( + scale=typ.scale, precision=typ.precision + ) elif isinstance(typ, VariantType): return self._type_map[tds_base.SYBVARIANT](size=0) elif isinstance(typ, SmallDateTimeType): @@ -2522,64 +2653,98 @@ def serializer_by_type(self, sql_type, collation=raw_collation): elif isinstance(typ, TableType): columns_serializers = None if typ.columns is not None: - columns_serializers = [self.serializer_by_type(col.type) for col in typ.columns] - return TableSerializer(table_type=typ, columns_serializers=columns_serializers) + columns_serializers = [ + self.serializer_by_type(col.type) for col in typ.columns + ] + return TableSerializer( + table_type=typ, columns_serializers=columns_serializers + ) else: - raise ValueError('Cannot map type {} to serializer.'.format(typ)) + raise ValueError("Cannot map type {} to serializer.".format(typ)) class DeclarationsParser(object): def __init__(self): declaration_parsers = [ - ('bit', BitType), - ('tinyint', TinyIntType), - ('smallint', SmallIntType), - ('(?:int|integer)', IntType), - ('bigint', BigIntType), - ('real', RealType), - ('(?:float|double precision)', FloatType), - ('(?:char|character)', CharType), - (r'(?:char|character)\((\d+)\)', lambda size_str: CharType(size=int(size_str))), - (r'(?:varchar|char(?:|acter)\s+varying)', VarCharType), - (r'(?:varchar|char(?:|acter)\s+varying)\((\d+)\)', lambda size_str: VarCharType(size=int(size_str))), - (r'varchar\(max\)', VarCharMaxType), - (r'(?:nchar|national\s+(?:char|character))', NCharType), - (r'(?:nchar|national\s+(?:char|character))\((\d+)\)', lambda size_str: NCharType(size=int(size_str))), - (r'(?:nvarchar|national\s+(?:char|character)\s+varying)', NVarCharType), - (r'(?:nvarchar|national\s+(?:char|character)\s+varying)\((\d+)\)', - lambda size_str: NVarCharType(size=int(size_str))), - (r'nvarchar\(max\)', NVarCharMaxType), - ('xml', XmlType), - ('text', TextType), - (r'(?:ntext|national\s+text)', NTextType), - ('binary', BinaryType), - (r'binary\((\d+)\)', lambda size_str: BinaryType(size=int(size_str))), - ('(?:varbinary|binary varying)', VarBinaryType), - (r'(?:varbinary|binary varying)\((\d+)\)', lambda size_str: VarBinaryType(size=int(size_str))), - (r'varbinary\(max\)', VarBinaryMaxType), - ('image', ImageType), - ('smalldatetime', SmallDateTimeType), - ('datetime', DateTimeType), - ('date', DateType), - (r'time', TimeType), - (r'time\((\d+)\)', lambda precision_str: TimeType(precision=int(precision_str))), - 
('datetime2', DateTime2Type), - (r'datetime2\((\d+)\)', lambda precision_str: DateTime2Type(precision=int(precision_str))), - ('datetimeoffset', DateTimeOffsetType), - (r'datetimeoffset\((\d+)\)', - lambda precision_str: DateTimeOffsetType(precision=int(precision_str))), - ('(?:decimal|dec|numeric)', DecimalType), - (r'(?:decimal|dec|numeric)\((\d+)\)', - lambda precision_str: DecimalType(precision=int(precision_str))), - (r'(?:decimal|dec|numeric)\((\d+), ?(\d+)\)', - lambda precision_str, scale_str: DecimalType(precision=int(precision_str), scale=int(scale_str))), - ('smallmoney', SmallMoneyType), - ('money', MoneyType), - ('uniqueidentifier', UniqueIdentifierType), - ('sql_variant', VariantType), + ("bit", BitType), + ("tinyint", TinyIntType), + ("smallint", SmallIntType), + ("(?:int|integer)", IntType), + ("bigint", BigIntType), + ("real", RealType), + ("(?:float|double precision)", FloatType), + ("(?:char|character)", CharType), + ( + r"(?:char|character)\((\d+)\)", + lambda size_str: CharType(size=int(size_str)), + ), + (r"(?:varchar|char(?:|acter)\s+varying)", VarCharType), + ( + r"(?:varchar|char(?:|acter)\s+varying)\((\d+)\)", + lambda size_str: VarCharType(size=int(size_str)), + ), + (r"varchar\(max\)", VarCharMaxType), + (r"(?:nchar|national\s+(?:char|character))", NCharType), + ( + r"(?:nchar|national\s+(?:char|character))\((\d+)\)", + lambda size_str: NCharType(size=int(size_str)), + ), + (r"(?:nvarchar|national\s+(?:char|character)\s+varying)", NVarCharType), + ( + r"(?:nvarchar|national\s+(?:char|character)\s+varying)\((\d+)\)", + lambda size_str: NVarCharType(size=int(size_str)), + ), + (r"nvarchar\(max\)", NVarCharMaxType), + ("xml", XmlType), + ("text", TextType), + (r"(?:ntext|national\s+text)", NTextType), + ("binary", BinaryType), + (r"binary\((\d+)\)", lambda size_str: BinaryType(size=int(size_str))), + ("(?:varbinary|binary varying)", VarBinaryType), + ( + r"(?:varbinary|binary varying)\((\d+)\)", + lambda size_str: VarBinaryType(size=int(size_str)), + ), + (r"varbinary\(max\)", VarBinaryMaxType), + ("image", ImageType), + ("smalldatetime", SmallDateTimeType), + ("datetime", DateTimeType), + ("date", DateType), + (r"time", TimeType), + ( + r"time\((\d+)\)", + lambda precision_str: TimeType(precision=int(precision_str)), + ), + ("datetime2", DateTime2Type), + ( + r"datetime2\((\d+)\)", + lambda precision_str: DateTime2Type(precision=int(precision_str)), + ), + ("datetimeoffset", DateTimeOffsetType), + ( + r"datetimeoffset\((\d+)\)", + lambda precision_str: DateTimeOffsetType(precision=int(precision_str)), + ), + ("(?:decimal|dec|numeric)", DecimalType), + ( + r"(?:decimal|dec|numeric)\((\d+)\)", + lambda precision_str: DecimalType(precision=int(precision_str)), + ), + ( + r"(?:decimal|dec|numeric)\((\d+), ?(\d+)\)", + lambda precision_str, scale_str: DecimalType( + precision=int(precision_str), scale=int(scale_str) + ), + ), + ("smallmoney", SmallMoneyType), + ("money", MoneyType), + ("uniqueidentifier", UniqueIdentifierType), + ("sql_variant", VariantType), + ] + self._compiled = [ + (re.compile(r"^" + regex + "$", re.IGNORECASE), constructor) + for regex, constructor in declaration_parsers ] - self._compiled = [(re.compile(r'^' + regex + '$', re.IGNORECASE), constructor) - for regex, constructor in declaration_parsers] def parse(self, declaration): """ @@ -2594,14 +2759,16 @@ def parse(self, declaration): m = regex.match(declaration) if m: return constructor(*m.groups()) - raise ValueError('Unable to parse type declaration', declaration) + raise ValueError("Unable 
to parse type declaration", declaration) _declarations_parser = DeclarationsParser() class TdsTypeInferrer(object): - def __init__(self, type_factory, collation=None, bytes_to_unicode=False, allow_tz=False): + def __init__( + self, type_factory, collation=None, bytes_to_unicode=False, allow_tz=False + ): """ Class used to do TDS type inference @@ -2616,7 +2783,7 @@ def __init__(self, type_factory, collation=None, bytes_to_unicode=False, allow_t self._allow_tz = allow_tz def from_value(self, value): - """ Function infers TDS type from Python value. + """Function infers TDS type from Python value. :param value: value from which to infer TDS type :return: An instance of subclass of :class:`BaseType` @@ -2628,7 +2795,7 @@ def from_value(self, value): return sql_type def from_class(self, cls): - """ Function infers TDS type from Python class. + """Function infers TDS type from Python class. :param cls: Class from which to infer type :return: An instance of subclass of :class:`BaseType` @@ -2645,11 +2812,11 @@ def _from_class_value(self, value, value_type): elif issubclass(value_type, int): if value is None: return IntType() - if -2 ** 31 <= value <= 2 ** 31 - 1: + if -(2**31) <= value <= 2**31 - 1: return IntType() - elif -2 ** 63 <= value <= 2 ** 63 - 1: + elif -(2**63) <= value <= 2**63 - 1: return BigIntType() - elif -10 ** 38 + 1 <= value <= 10 ** 38 - 1: + elif -(10**38) + 1 <= value <= 10**38 - 1: return DecimalType(precision=38) else: return VarCharMaxType() @@ -2699,14 +2866,22 @@ def _from_class_value(self, value, value_type): try: cell_iter = iter(row) except TypeError: - raise tds_base.DataError('Each row in table should be an iterable') + raise tds_base.DataError( + "Each row in table should be an iterable" + ) for cell in cell_iter: if isinstance(cell, TableValuedParam): - raise tds_base.DataError('TVP type cannot have nested TVP types') + raise tds_base.DataError( + "TVP type cannot have nested TVP types" + ) col_type = self.from_value(cell) col = tds_base.Column(type=col_type) columns.append(col) - return TableType(typ_schema=value.typ_schema, typ_name=value.typ_name, columns=columns) + return TableType( + typ_schema=value.typ_schema, typ_name=value.typ_name, columns=columns + ) else: - raise tds_base.DataError('Cannot infer TDS type from Python value: {!r}'.format(value)) + raise tds_base.DataError( + "Cannot infer TDS type from Python value: {!r}".format(value) + ) diff --git a/src/pytds/tds_writer.py b/src/pytds/tds_writer.py index 3c59172..be448eb 100644 --- a/src/pytds/tds_writer.py +++ b/src/pytds/tds_writer.py @@ -2,17 +2,30 @@ from pytds import tds_base from pytds.collate import Collation, ucs2_codec -from pytds.tds_base import _uint_be, _byte, _smallint_le, _usmallint_le, _usmallint_be, _int_le, _uint_le, _int8_le, \ - _uint8_le, _header +from pytds.tds_base import ( + _uint_be, + _byte, + _smallint_le, + _usmallint_le, + _usmallint_be, + _int_le, + _uint_le, + _int8_le, + _uint8_le, + _header, +) class _TdsWriter: - """ TDS stream writer + """TDS stream writer Handles splitting of incoming data into TDS packets according to TDS protocol. Provides convinience methods for writing primitive data types. 
""" - def __init__(self, transport: tds_base.TransportProtocol, bufsize: int, tds_session): + + def __init__( + self, transport: tds_base.TransportProtocol, bufsize: int, tds_session + ): self._transport = transport self._tds = tds_session self._pos = 0 @@ -26,7 +39,7 @@ def session(self): @property def bufsize(self) -> int: - """ Size of the buffer """ + """Size of the buffer""" return len(self._buf) @bufsize.setter @@ -35,12 +48,12 @@ def bufsize(self, bufsize: int) -> None: return if bufsize > len(self._buf): - self._buf.extend(b'\0' * (bufsize - len(self._buf))) + self._buf.extend(b"\0" * (bufsize - len(self._buf))) else: self._buf = self._buf[0:bufsize] def begin_packet(self, packet_type: int) -> None: - """ Starts new packet stream + """Starts new packet stream :param packet_type: Type of TDS stream, e.g. TDS_PRELOGIN, TDS_QUERY etc. """ @@ -48,51 +61,51 @@ def begin_packet(self, packet_type: int) -> None: self._pos = 8 def pack(self, struc: struct.Struct, *args) -> None: - """ Packs and writes structure into stream """ + """Packs and writes structure into stream""" self.write(struc.pack(*args)) def put_byte(self, value: int) -> None: - """ Writes single byte into stream """ + """Writes single byte into stream""" self.pack(_byte, value) def put_smallint(self, value: int) -> None: - """ Writes 16-bit signed integer into the stream """ + """Writes 16-bit signed integer into the stream""" self.pack(_smallint_le, value) def put_usmallint(self, value: int) -> None: - """ Writes 16-bit unsigned integer into the stream """ + """Writes 16-bit unsigned integer into the stream""" self.pack(_usmallint_le, value) def put_usmallint_be(self, value: int) -> None: - """ Writes 16-bit unsigned big-endian integer into the stream """ + """Writes 16-bit unsigned big-endian integer into the stream""" self.pack(_usmallint_be, value) def put_int(self, value: int) -> None: - """ Writes 32-bit signed integer into the stream """ + """Writes 32-bit signed integer into the stream""" self.pack(_int_le, value) def put_uint(self, value: int) -> None: - """ Writes 32-bit unsigned integer into the stream """ + """Writes 32-bit unsigned integer into the stream""" self.pack(_uint_le, value) def put_uint_be(self, value: int) -> None: - """ Writes 32-bit unsigned big-endian integer into the stream """ + """Writes 32-bit unsigned big-endian integer into the stream""" self.pack(_uint_be, value) def put_int8(self, value: int) -> None: - """ Writes 64-bit signed integer into the stream """ + """Writes 64-bit signed integer into the stream""" self.pack(_int8_le, value) def put_uint8(self, value: int) -> None: - """ Writes 64-bit unsigned integer into the stream """ + """Writes 64-bit unsigned integer into the stream""" self.pack(_uint8_le, value) def put_collation(self, collation: Collation) -> None: - """ Writes :class:`Collation` structure into the stream """ + """Writes :class:`Collation` structure into the stream""" self.write(collation.pack()) def write(self, data: bytes) -> None: - """ Writes given bytes buffer into the stream + """Writes given bytes buffer into the stream Function returns only when entire buffer is written """ @@ -103,7 +116,9 @@ def write(self, data: bytes) -> None: self._write_packet(final=False) else: to_write = min(left, len(data) - data_off) - self._buf[self._pos:self._pos + to_write] = data[data_off:data_off + to_write] + self._buf[self._pos : self._pos + to_write] = data[ + data_off : data_off + to_write + ] self._pos += to_write data_off += to_write @@ -112,30 +127,32 @@ def 
write_b_varchar(self, s: str) -> None: self.write_ucs2(s) def write_ucs2(self, s: str) -> None: - """ Write string encoding it in UCS2 into stream """ + """Write string encoding it in UCS2 into stream""" self.write_string(s, ucs2_codec) def write_string(self, s: str, codec) -> None: - """ Write string encoding it with codec into stream """ + """Write string encoding it with codec into stream""" for i in range(0, len(s), self.bufsize): - chunk = s[i:i + self.bufsize] + chunk = s[i : i + self.bufsize] buf, consumed = codec.encode(chunk) assert consumed == len(chunk) self.write(buf) def flush(self) -> None: - """ Closes current packet stream """ + """Closes current packet stream""" return self._write_packet(final=True) def _write_packet(self, final: bool) -> None: - """ Writes single TDS packet into underlying transport. + """Writes single TDS packet into underlying transport. Data for the packet is taken from internal buffer. :param final: True means this is the final packet in substream. """ status = 1 if final else 0 - _header.pack_into(self._buf, 0, self._type, status, self._pos, 0, self._packet_no) + _header.pack_into( + self._buf, 0, self._type, status, self._pos, 0, self._packet_no + ) self._packet_no = (self._packet_no + 1) % 256 - self._transport.sendall(self._buf[:self._pos]) + self._transport.sendall(self._buf[: self._pos]) self._pos = 8 diff --git a/src/pytds/tls.py b/src/pytds/tls.py index 813f9d1..dd97308 100644 --- a/src/pytds/tls.py +++ b/src/pytds/tls.py @@ -5,8 +5,8 @@ import typing try: - import OpenSSL.SSL # type: ignore # needs fixing - import cryptography.hazmat.backends.openssl.backend # type: ignore # needs fixing + import OpenSSL.SSL # type: ignore # needs fixing + import cryptography.hazmat.backends.openssl.backend # type: ignore # needs fixing except ImportError: OPENSSL_AVAILABLE = False else: @@ -25,7 +25,9 @@ class EncryptedSocket(tds_base.TransportProtocol): - def __init__(self, transport: tds_base.TransportProtocol, tls_conn: OpenSSL.SSL.Connection): + def __init__( + self, transport: tds_base.TransportProtocol, tls_conn: OpenSSL.SSL.Connection + ): super().__init__() self._transport = transport self._tls_conn = tls_conn @@ -45,19 +47,21 @@ def sendall(self, data: Any, flags: int = 0) -> None: buf = self._tls_conn.bio_read(BUFSIZE) self._transport.sendall(buf) - # def send(self, data): - # while True: - # try: - # return self._tls_conn.send(data) - # except OpenSSL.SSL.WantWriteError: - # buf = self._tls_conn.bio_read(BUFSIZE) - # self._transport.sendall(buf) - - def recv_into(self, buffer: bytearray | memoryview, size: int = 0, flags: int = 0) -> int: + # def send(self, data): + # while True: + # try: + # return self._tls_conn.send(data) + # except OpenSSL.SSL.WantWriteError: + # buf = self._tls_conn.bio_read(BUFSIZE) + # self._transport.sendall(buf) + + def recv_into( + self, buffer: bytearray | memoryview, size: int = 0, flags: int = 0 + ) -> int: if size == 0: size = len(buffer) res = self.recv(size) - buffer[0:len(res)] = res + buffer[0 : len(res)] = res return len(res) def recv(self, bufsize: int, flags: int = 0) -> bytes: @@ -76,7 +80,7 @@ def recv(self, bufsize: int, flags: int = 0) -> bytes: if buf: self._tls_conn.bio_write(buf) else: - return b'' + return b"" def close(self) -> None: self._tls_conn.shutdown() @@ -91,14 +95,18 @@ def verify_cb(conn, cert, err_num, err_depth, ret_code: int) -> bool: def is_san_matching(san: str, host_name: str) -> bool: - for item in san.split(','): - dnsentry = item.lstrip('DNS:').strip() + for item in san.split(","): 
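# ---------------------------------------------------------------------------
# Illustrative note (not part of the patch): str.lstrip("DNS:") strips a
# *character set* ({"D", "N", "S", ":"}) from the left, not the literal
# "DNS:" prefix, so SAN entries whose host name begins with one of those
# letters get corrupted:
assert "DNS:Sample.com".lstrip("DNS:") == "ample.com"  # leading "S" of Sample eaten too
# str.removeprefix (Python 3.9+) strips the exact prefix instead, which is
# what the SAN comparison on the next line needs:
assert "DNS:Sample.com".strip().removeprefix("DNS:") == "Sample.com"
# ---------------------------------------------------------------------------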
+ dnsentry = item.strip().removeprefix("DNS:") # exact prefix strip; lstrip("DNS:") would also eat leading D/N/S characters # SANs usually have the form: DNS:hostname if dnsentry == host_name: return True - if dnsentry[0:2] == "*.": # support for wildcards, but only at the first position + if ( + dnsentry[0:2] == "*." + ): # support for wildcards, but only at the first position afterstar_parts = dnsentry[2:] - afterstar_parts_sname = '.'.join(host_name.split('.')[1:]) # remove first part of dns name + afterstar_parts_sname = ".".join( + host_name.split(".")[1:] + ) # remove first part of dns name if afterstar_parts == afterstar_parts_sname: return True return False @@ -114,7 +122,7 @@ def validate_host(cert, name: bytes) -> bool: """ cn = None for t, v in cert.get_subject().get_components(): - if t == b'CN': + if t == b"CN": cn = v break @@ -122,10 +130,10 @@ return True # checking SAN - s_name = name.decode('ascii') + s_name = name.decode("ascii") for i in range(cert.get_extension_count()): ext = cert.get_extension(i) - if ext.get_short_name() == b'subjectAltName': + if ext.get_short_name() == b"subjectAltName": s = str(ext) if is_san_matching(s, s_name): return True @@ -139,9 +147,9 @@ def create_context(cafile: str) -> OpenSSL.SSL.Context: ctx.set_options(OpenSSL.SSL.OP_NO_SSLv2) ctx.set_options(OpenSSL.SSL.OP_NO_SSLv3) ctx.set_verify(OpenSSL.SSL.VERIFY_PEER, verify_cb) - #print("verify depth:", ctx.get_verify_depth()) - #print("verify mode:", ctx.get_verify_mode()) - #print("openssl version:", cryptography.hazmat.backends.openssl.backend.openssl_version_text()) + # print("verify depth:", ctx.get_verify_depth()) + # print("verify mode:", ctx.get_verify_mode()) + # print("openssl version:", cryptography.hazmat.backends.openssl.backend.openssl_version_text()) ctx.load_verify_locations(cafile=cafile) return ctx @@ -152,40 +160,51 @@ def establish_channel(tds_sock: _TdsSession) -> None: r = tds_sock._reader login = tds_sock.conn._login - bhost = login.server_name.encode('ascii') + bhost = login.server_name.encode("ascii") conn = OpenSSL.SSL.Connection(login.tls_ctx) conn.set_tlsext_host_name(bhost) # change connection to client mode conn.set_connect_state() - logger.info('doing TLS handshake') + logger.info("doing TLS handshake") while True: try: - logger.debug('calling do_handshake') + logger.debug("calling do_handshake") conn.do_handshake() except OpenSSL.SSL.WantReadError: - logger.debug('got WantReadError, getting data from the write end of the TLS connection buffer') + logger.debug( + "got WantReadError, getting data from the write end of the TLS connection buffer" + ) try: req = conn.bio_read(BUFSIZE) except OpenSSL.SSL.WantReadError: # PyOpenSSL - https://github.com/pyca/pyopenssl/issues/887 - logger.debug('got WantReadError again, waiting for response...') + logger.debug("got WantReadError again, waiting for response...") else: - logger.debug('sending %d bytes of the handshake data to the server', len(req)) + logger.debug( + "sending %d bytes of the handshake data to the server", len(req) + ) w.begin_packet(tds_base.PacketType.PRELOGIN) w.write(req) w.flush() - logger.debug('receiving response from the server') + logger.debug("receiving response from the server") resp_header = r.begin_response() resp = r.read_whole_packet() # TODO validate r.packet_type - logger.debug('adding %d bytes of the response into the TLS connection buffer', len(resp)) + logger.debug( + "adding %d bytes of the response into the TLS connection buffer", + len(resp), + ) conn.bio_write(resp) else: - logger.info('TLS handshake is complete') +
logger.info("TLS handshake is complete") if login.validate_host: if not validate_host(cert=conn.get_peer_certificate(), name=bhost): - raise tds_base.Error("Certificate does not match host name '{}'".format(login.server_name)) + raise tds_base.Error( + "Certificate does not match host name '{}'".format( + login.server_name + ) + ) enc_sock = EncryptedSocket(transport=tds_sock.conn.sock, tls_conn=conn) tds_sock.conn.sock = enc_sock tds_sock._writer._transport = enc_sock diff --git a/src/pytds/tz.py b/src/pytds/tz.py index 39119a5..b151844 100644 --- a/src/pytds/tz.py +++ b/src/pytds/tz.py @@ -27,7 +27,7 @@ def dst(self, dt): return ZERO -utc = FixedOffsetTimezone(offset=0, name='UTC') +utc = FixedOffsetTimezone(offset=0, name="UTC") STDOFFSET = timedelta(seconds=-_time.timezone) @@ -40,7 +40,6 @@ def dst(self, dt): class LocalTimezone(tzinfo): - def utcoffset(self, dt): if self._isdst(dt): return DSTOFFSET @@ -57,11 +56,20 @@ def tzname(self, dt): return _time.tzname[self._isdst(dt)] def _isdst(self, dt): - tt = (dt.year, dt.month, dt.day, - dt.hour, dt.minute, dt.second, - dt.weekday(), 0, 0) + tt = ( + dt.year, + dt.month, + dt.day, + dt.hour, + dt.minute, + dt.second, + dt.weekday(), + 0, + 0, + ) stamp = _time.mktime(tt) tt = _time.localtime(stamp) return tt.tm_isdst > 0 + local = LocalTimezone() diff --git a/tests/all_test.py b/tests/all_test.py index 628d990..5ba1fab 100644 --- a/tests/all_test.py +++ b/tests/all_test.py @@ -11,10 +11,35 @@ import utils -from pytds.tds_types import TimeType, DateTime2Type, DateType, DateTimeOffsetType, BitType, TinyIntType, SmallIntType, \ - IntType, BigIntType, RealType, FloatType, NVarCharType, VarBinaryType, SmallDateTimeType, DateTimeType, DecimalType, \ - MoneyType, UniqueIdentifierType, VariantType, ImageType, VarBinaryMaxType, VarCharType, TextType, NTextType, \ - NVarCharMaxType, VarCharMaxType, XmlType +from pytds.tds_types import ( + TimeType, + DateTime2Type, + DateType, + DateTimeOffsetType, + BitType, + TinyIntType, + SmallIntType, + IntType, + BigIntType, + RealType, + FloatType, + NVarCharType, + VarBinaryType, + SmallDateTimeType, + DateTimeType, + DecimalType, + MoneyType, + UniqueIdentifierType, + VariantType, + ImageType, + VarBinaryMaxType, + VarCharType, + TextType, + NTextType, + NVarCharMaxType, + VarCharMaxType, + XmlType, +) try: import unittest2 as unittest @@ -30,19 +55,38 @@ import pytds.tz import pytds.login import pytds.smp + tzoffset = pytds.tz.FixedOffsetTimezone utc = pytds.tz.utc import pytds.extensions from pytds import ( - connect, ProgrammingError, TimeoutError, Time, - Error, IntegrityError, Timestamp, DataError, Date, Binary, - output, default, - STRING, BINARY, NUMBER, DATETIME, DECIMAL, INTEGER, REAL, XML) -from pytds.tds_types import (DateTimeSerializer, SmallMoneyType) + connect, + ProgrammingError, + TimeoutError, + Time, + Error, + IntegrityError, + Timestamp, + DataError, + Date, + Binary, + output, + default, + STRING, + BINARY, + NUMBER, + DATETIME, + DECIMAL, + INTEGER, + REAL, + XML, +) +from pytds.tds_types import DateTimeSerializer, SmallMoneyType from pytds.tds_base import ( Param, - IS_TDS73_PLUS, IS_TDS71_PLUS, - ) + IS_TDS73_PLUS, + IS_TDS71_PLUS, +) import dbapi20 import pytds import settings @@ -50,35 +94,37 @@ logger = logging.getLogger(__name__) -LIVE_TEST = getattr(settings, 'LIVE_TEST', True) +LIVE_TEST = getattr(settings, "LIVE_TEST", True) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_timeout_with_mars(): kwargs = settings.CONNECT_KWARGS.copy() - 
kwargs['database'] = 'master' - kwargs['timeout'] = 1 - kwargs['use_mars'] = True + kwargs["database"] = "master" + kwargs["timeout"] = 1 + kwargs["use_mars"] = True with connect(*settings.CONNECT_ARGS, **kwargs) as conn: cur = conn.cursor() with pytest.raises(TimeoutError): cur.execute("waitfor delay '00:00:05'") - cur.execute('select 1') + cur.execute("select 1") @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_no_mars_autocommit(): kwargs = settings.CONNECT_KWARGS.copy() - kwargs.update({ - 'use_mars': False, - 'timeout': 1, - 'pooling': True, - 'autocommit': True, - }) + kwargs.update( + { + "use_mars": False, + "timeout": 1, + "pooling": True, + "autocommit": True, + } + ) with connect(**kwargs) as conn: with conn.cursor() as cur: # test execute scalar with empty response - cur.execute_scalar('declare @tbl table(f int); select * from @tbl') + cur.execute_scalar("declare @tbl table(f int); select * from @tbl") cur.execute("print 'hello'") messages = cur.messages @@ -86,20 +132,20 @@ def test_connection_no_mars_autocommit(): assert len(messages[0]) == 2 # in following assert exception class does not have to be exactly as specified assert messages[0][0] == pytds.OperationalError - assert messages[0][1].text == 'hello' + assert messages[0][1].text == "hello" assert messages[0][1].line == 1 assert messages[0][1].severity == 0 assert messages[0][1].number == 0 assert messages[0][1].state == 1 - assert 'hello' in messages[0][1].message + assert "hello" in messages[0][1].message # test cursor usage after close, should raise exception cur = conn.cursor() - cur.execute_scalar('select 1') + cur.execute_scalar("select 1") cur.close() with pytest.raises(Error) as ex: - cur.execute('select 1') - assert 'Cursor is closed' in str(ex.value) + cur.execute("select 1") + assert "Cursor is closed" in str(ex.value) # calling get_proc_return_status on closed cursor works # this test does not have to pass assert cur.get_proc_return_status() is None @@ -120,11 +166,13 @@ def test_connection_no_mars_autocommit(): @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_timeout_no_mars(): kwargs = settings.CONNECT_KWARGS.copy() - kwargs.update({ - 'use_mars': False, - 'timeout': 1, - 'pooling': True, - }) + kwargs.update( + { + "use_mars": False, + "timeout": 1, + "pooling": True, + } + ) with connect(**kwargs) as conn: with conn.cursor() as cur: with pytest.raises(TimeoutError): @@ -135,10 +183,10 @@ def test_connection_timeout_no_mars(): # test cancelling with conn.cursor() as cur: - cur.execute('select 1') + cur.execute("select 1") cur.cancel() assert cur.fetchall() == [] - cur.execute('select 2') + cur.execute("select 2") assert cur.fetchall() == [(2,)] # test rollback @@ -146,7 +194,7 @@ def test_connection_timeout_no_mars(): # test callproc on non-mars connection with conn.cursor() as cur: - cur.callproc('sp_reset_connection') + cur.callproc("sp_reset_connection") with conn.cursor() as cur: # test spid property on non-mars cursor @@ -158,17 +206,19 @@ def test_connection_timeout_no_mars(): # test non-mars cursor with connection pool enabled with connect(**kwargs) as conn: with conn.cursor() as cur: - cur.execute('select 1') + cur.execute("select 1") assert cur.fetchall() == [(1,)] @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_connection_no_mars_no_pooling(): kwargs = settings.CONNECT_KWARGS.copy() - kwargs.update({ - 'use_mars': False, - 'pooling': False, - }) + kwargs.update( + { + "use_mars": False, + 
"pooling": False, + } + ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1") @@ -178,28 +228,35 @@ def test_connection_no_mars_no_pooling(): @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_row_strategies(): kwargs = settings.CONNECT_KWARGS.copy() - kwargs.update({ - 'row_strategy': pytds.list_row_strategy, - }) + kwargs.update( + { + "row_strategy": pytds.list_row_strategy, + } + ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1") assert cur.fetchall() == [[1]] - kwargs.update({ - 'row_strategy': pytds.namedtuple_row_strategy, - }) + kwargs.update( + { + "row_strategy": pytds.namedtuple_row_strategy, + } + ) import collections + with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1 as f") - assert cur.fetchall() == [collections.namedtuple('Row', ['f'])(1)] - kwargs.update({ - 'row_strategy': pytds.recordtype_row_strategy, - }) + assert cur.fetchall() == [collections.namedtuple("Row", ["f"])(1)] + kwargs.update( + { + "row_strategy": pytds.recordtype_row_strategy, + } + ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute("select 1 as e, 2 as f") - row, = cur.fetchall() + (row,) = cur.fetchall() assert row.e == 1 assert row.f == 2 assert row[0] == 1 @@ -210,8 +267,8 @@ def test_row_strategies(): @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_get_instances(): - if not hasattr(settings, 'BROWSER_ADDRESS'): - return unittest.skip('BROWSER_ADDRESS setting is not defined') + if not hasattr(settings, "BROWSER_ADDRESS"): + return unittest.skip("BROWSER_ADDRESS setting is not defined") pytds.tds.tds7_get_instances(settings.BROWSER_ADDRESS) @@ -219,7 +276,7 @@ def test_get_instances(): class ConnectionTestCase(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = settings.DATABASE + kwargs["database"] = settings.DATABASE self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def tearDown(self): @@ -230,21 +287,20 @@ def tearDown(self): class NoMarsTestCase(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' - kwargs['use_mars'] = False + kwargs["database"] = "master" + kwargs["use_mars"] = False self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def tearDown(self): self.conn.close() - class TestCaseWithCursor(ConnectionTestCase): def setUp(self): super(TestCaseWithCursor, self).setUp() self.cursor = self.conn.cursor() - #def test_mars_sessions_recycle_ids(self): + # def test_mars_sessions_recycle_ids(self): # if not self.conn.mars_enabled: # self.skipTest('Only relevant to mars') # for _ in range(2 ** 16 + 1): @@ -255,51 +311,80 @@ def test_parameters_ll(self): _params_tests(self) - class TestVariant(ConnectionTestCase): def _t(self, result, sql): with self.conn.cursor() as cur: cur.execute("select cast({0} as sql_variant)".format(sql)) - val, = cur.fetchone() + (val,) = cur.fetchone() self.assertEqual(result, val) def test_new_datetime(self): if not IS_TDS73_PLUS(self.conn): - self.skipTest('Requires TDS7.3+') + self.skipTest("Requires TDS7.3+") import pytds.tz - self._t(datetime(2011, 2, 3, 10, 11, 12, 3000), "cast('2011-02-03T10:11:12.003000' as datetime2)") + + self._t( + datetime(2011, 2, 3, 10, 11, 12, 3000), + "cast('2011-02-03T10:11:12.003000' as datetime2)", + ) self._t(time(10, 11, 12, 3000), "cast('10:11:12.003000' as time)") self._t(date(2011, 2, 3), "cast('2011-02-03' as date)") - self._t(datetime(2011, 2, 3, 10, 
11, 12, 3000, pytds.tz.FixedOffsetTimezone(3 * 60)), "cast('2011-02-03T10:11:12.003000+03:00' as datetimeoffset)") + self._t( + datetime( + 2011, 2, 3, 10, 11, 12, 3000, pytds.tz.FixedOffsetTimezone(3 * 60) + ), + "cast('2011-02-03T10:11:12.003000+03:00' as datetimeoffset)", + ) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class BadConnection(unittest.TestCase): def test_invalid_parameters(self): with self.assertRaises(Error): - with connect(server=settings.HOST + 'bad', database='master', user='baduser', password=settings.PASSWORD, login_timeout=1) as conn: + with connect( + server=settings.HOST + "bad", + database="master", + user="baduser", + password=settings.PASSWORD, + login_timeout=1, + ) as conn: with conn.cursor() as cur: - cur.execute('select 1') + cur.execute("select 1") with self.assertRaises(Error): - with connect(server=settings.HOST, database='doesnotexist', user=settings.USER, password=settings.PASSWORD) as conn: + with connect( + server=settings.HOST, + database="doesnotexist", + user=settings.USER, + password=settings.PASSWORD, + ) as conn: with conn.cursor() as cur: - cur.execute('select 1') + cur.execute("select 1") with self.assertRaises(Error): - with connect(server=settings.HOST, database='master', user='baduser', password=None) as conn: + with connect( + server=settings.HOST, database="master", user="baduser", password=None + ) as conn: with conn.cursor() as cur: - cur.execute('select 1') + cur.execute("select 1") def test_instance_and_port(self): host = settings.HOST - if '\\' in host: - host, _ = host.split('\\') - with self.assertRaisesRegex(ValueError, 'Both instance and port shouldn\'t be specified'): - with connect(server=host + '\\badinstancename', database='master', user=settings.USER, password=settings.PASSWORD, port=1212) as conn: + if "\\" in host: + host, _ = host.split("\\") + with self.assertRaisesRegex( + ValueError, "Both instance and port shouldn't be specified" + ): + with connect( + server=host + "\\badinstancename", + database="master", + user=settings.USER, + password=settings.PASSWORD, + port=1212, + ) as conn: with conn.cursor() as cur: - cur.execute('select 1') + cur.execute("select 1") -#class EncryptionTest(unittest.TestCase): +# class EncryptionTest(unittest.TestCase): # def runTest(self): # conn = connect(server=settings.HOST, database='master', user=settings.USER, password=settings.PASSWORD, encryption_level=TDS_ENCRYPTION_REQUIRE) # cur = conn.cursor() @@ -310,7 +395,7 @@ def test_instance_and_port(self): class SmallDateTimeTest(ConnectionTestCase): def _testval(self, val): with self.conn.cursor() as cur: - cur.execute('select cast(%s as smalldatetime)', (val,)) + cur.execute("select cast(%s as smalldatetime)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) def runTest(self): @@ -326,17 +411,24 @@ def runTest(self): @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class DateTimeTest(ConnectionTestCase): def _testencdec(self, val): - self.assertEqual(val, DateTimeSerializer.decode(*DateTimeSerializer._struct.unpack(DateTimeSerializer.encode(val)))) + self.assertEqual( + val, + DateTimeSerializer.decode( + *DateTimeSerializer._struct.unpack(DateTimeSerializer.encode(val)) + ), + ) def _testval(self, val): with self.conn.cursor() as cur: - cur.execute('select cast(%s as datetime)', (val,)) + cur.execute("select cast(%s as datetime)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) def runTest(self): with self.conn.cursor() as cur: cur.execute("select cast('9999-12-31T23:59:59.997' as 
datetime)") - self.assertEqual(cur.fetchall(), [(Timestamp(9999, 12, 31, 23, 59, 59, 997000),)]) + self.assertEqual( + cur.fetchall(), [(Timestamp(9999, 12, 31, 23, 59, 59, 997000),)] + ) self._testencdec(Timestamp(2010, 1, 2, 10, 11, 12)) self._testval(Timestamp(2010, 1, 2, 0, 0, 0)) self._testval(Timestamp(2010, 1, 2, 10, 11, 12)) @@ -349,34 +441,43 @@ def runTest(self): with self.assertRaises(Error): self._testval(Timestamp(1752, 1, 1, 0, 0, 0)) with self.conn.cursor() as cur: - cur.execute(''' + cur.execute( + """ if object_id('testtable') is not null drop table testtable - ''') - cur.execute('create table testtable (col datetime not null)') + """ + ) + cur.execute("create table testtable (col datetime not null)") dt = Timestamp(2010, 1, 2, 20, 21, 22, 123000) - cur.execute('insert into testtable values (%s)', (dt,)) - cur.execute('select col from testtable') + cur.execute("insert into testtable values (%s)", (dt,)) + cur.execute("select col from testtable") self.assertEqual(cur.fetchone(), (dt,)) class NewDateTimeTest(ConnectionTestCase): def test_datetimeoffset(self): if not IS_TDS73_PLUS(self.conn): - self.skipTest('Requires TDS7.3+') + self.skipTest("Requires TDS7.3+") def _testval(val): with self.conn.cursor() as cur: import pytds.tz + cur.tzinfo_factory = pytds.tz.FixedOffsetTimezone - cur.execute('select cast(%s as datetimeoffset)', (val,)) + cur.execute("select cast(%s as datetimeoffset)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) with self.conn.cursor() as cur: import pytds.tz + cur.tzinfo_factory = pytds.tz.FixedOffsetTimezone - cur.execute("select cast('2010-01-02T20:21:22.1234567+05:00' as datetimeoffset)") - self.assertEqual(datetime(2010, 1, 2, 20, 21, 22, 123456, tzoffset(5 * 60)), cur.fetchone()[0]) + cur.execute( + "select cast('2010-01-02T20:21:22.1234567+05:00' as datetimeoffset)" + ) + self.assertEqual( + datetime(2010, 1, 2, 20, 21, 22, 123456, tzoffset(5 * 60)), + cur.fetchone()[0], + ) _testval(Timestamp(2010, 1, 2, 0, 0, 0, 0, utc)) _testval(Timestamp(2010, 1, 2, 0, 0, 0, 0, tzoffset(5 * 60))) _testval(Timestamp(1, 1, 1, 0, 0, 0, 0, utc)) @@ -387,11 +488,11 @@ def _testval(val): def test_time(self): if not IS_TDS73_PLUS(self.conn): - self.skipTest('Requires TDS7.3+') + self.skipTest("Requires TDS7.3+") def testval(val): with self.conn.cursor() as cur: - cur.execute('select cast(%s as time)', (val,)) + cur.execute("select cast(%s as time)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) testval(Time(14, 16, 18, 123456)) @@ -405,11 +506,11 @@ def testval(val): def test_datetime2(self): if not IS_TDS73_PLUS(self.conn): - self.skipTest('Requires TDS7.3+') + self.skipTest("Requires TDS7.3+") def testval(val): with self.conn.cursor() as cur: - cur.execute('select cast(%s as datetime2)', (val,)) + cur.execute("select cast(%s as datetime2)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) testval(Timestamp(2010, 1, 2, 20, 21, 22, 345678)) @@ -419,11 +520,11 @@ def testval(val): def test_date(self): if not IS_TDS73_PLUS(self.conn): - self.skipTest('Requires TDS7.3+') + self.skipTest("Requires TDS7.3+") def testval(val): with self.conn.cursor() as cur: - cur.execute('select cast(%s as date)', (val,)) + cur.execute("select cast(%s as date)", (val,)) self.assertEqual(cur.fetchall(), [(val,)]) testval(Date(2010, 1, 2)) @@ -434,33 +535,52 @@ def testval(val): @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class Auth(unittest.TestCase): - @unittest.skipUnless(os.getenv('NTLM_USER') and os.getenv('NTLM_PASSWORD'), "requires NTLM_USER and 
NTLM_PASSWORD environment variables to be set")
+    @unittest.skipUnless(
+        os.getenv("NTLM_USER") and os.getenv("NTLM_PASSWORD"),
+        "requires NTLM_USER and NTLM_PASSWORD environment variables to be set",
+    )
     def test_ntlm(self):
-        conn = connect(settings.HOST, auth=pytds.login.NtlmAuth(user_name=os.getenv('NTLM_USER'), password=os.getenv('NTLM_PASSWORD')))
+        conn = connect(
+            settings.HOST,
+            auth=pytds.login.NtlmAuth(
+                user_name=os.getenv("NTLM_USER"), password=os.getenv("NTLM_PASSWORD")
+            ),
+        )
         with conn.cursor() as cursor:
-            cursor.execute('select 1')
+            cursor.execute("select 1")
             cursor.fetchall()

-    @unittest.skipUnless(os.getenv('NTLM_USER') and os.getenv('NTLM_PASSWORD'), "requires NTLM_USER and NTLM_PASSWORD environment variables to be set")
+    @unittest.skipUnless(
+        os.getenv("NTLM_USER") and os.getenv("NTLM_PASSWORD"),
+        "requires NTLM_USER and NTLM_PASSWORD environment variables to be set",
+    )
     def test_spnego(self):
-        conn = connect(settings.HOST, auth=pytds.login.SpnegoAuth(os.getenv('NTLM_USER'), os.getenv('NTLM_PASSWORD')))
+        conn = connect(
+            settings.HOST,
+            auth=pytds.login.SpnegoAuth(
+                os.getenv("NTLM_USER"), os.getenv("NTLM_PASSWORD")
+            ),
+        )
         with conn.cursor() as cursor:
-            cursor.execute('select 1')
+            cursor.execute("select 1")
             cursor.fetchall()

     @unittest.skipUnless(sys.platform.startswith("win"), "requires Windows")
     def test_sspi(self):
         from pytds.login import SspiAuth
+
         with connect(settings.HOST, auth=SspiAuth()) as conn:
             with conn.cursor() as cursor:
-                cursor.execute('select 1')
+                cursor.execute("select 1")
                 cursor.fetchall()

-    @unittest.skipIf(getattr(settings, 'SKIP_SQL_AUTH', False), 'SKIP_SQL_AUTH is set')
+    @unittest.skipIf(getattr(settings, "SKIP_SQL_AUTH", False), "SKIP_SQL_AUTH is set")
     def test_sqlauth(self):
-        with connect(settings.HOST, user=settings.USER, password=settings.PASSWORD) as conn:
+        with connect(
+            settings.HOST, user=settings.USER, password=settings.PASSWORD
+        ) as conn:
             with conn.cursor() as cursor:
-                cursor.execute('select 1')
+                cursor.execute("select 1")
                 cursor.fetchall()
@@ -480,27 +600,30 @@ def test_cancel(self):
 class TimezoneTests(unittest.TestCase):
     def check_val(self, conn, sql, input, output):
         with conn.cursor() as cur:
-            cur.execute('select ' + sql, (input,))
+            cur.execute("select " + sql, (input,))
             rows = cur.fetchall()
             self.assertEqual(rows[0][0], output)

     def runTest(self):
         kwargs = settings.CONNECT_KWARGS.copy()
         use_tz = utc
-        kwargs['use_tz'] = use_tz
-        kwargs['database'] = 'master'
+        kwargs["use_tz"] = use_tz
+        kwargs["database"] = "master"
         with connect(*settings.CONNECT_ARGS, **kwargs) as conn:
             # Naive time should be interpreted as use_tz
-            self.check_val(conn, '%s',
-                           datetime(2011, 2, 3, 10, 11, 12, 3000),
-                           datetime(2011, 2, 3, 10, 11, 12, 3000, utc))
+            self.check_val(
+                conn,
+                "%s",
+                datetime(2011, 2, 3, 10, 11, 12, 3000),
+                datetime(2011, 2, 3, 10, 11, 12, 3000, utc),
+            )
             # Aware time should be passed as-is
             dt = datetime(2011, 2, 3, 10, 11, 12, 3000, tzoffset(1))
-            self.check_val(conn, '%s', dt, dt)
+            self.check_val(conn, "%s", dt, dt)
             # Aware time should be converted to use_tz if not using datetimeoffset type
             dt = datetime(2011, 2, 3, 10, 11, 12, 3000, tzoffset(1))
             if IS_TDS73_PLUS(conn):
-                self.check_val(conn, 'cast(%s as datetime2)', dt, dt.astimezone(use_tz))
+                self.check_val(conn, "cast(%s as datetime2)", dt, dt.astimezone(use_tz))

 @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set")
@@ -509,8 +632,8 @@ class DbapiTestSuite(dbapi20.DatabaseAPI20Test, ConnectionTestCase):
     connect_args = 
settings.CONNECT_ARGS connect_kw_args = settings.CONNECT_KWARGS -# def _connect(self): -# return connection + # def _connect(self): + # return connection def _try_run(self, *args): with self._connect() as con: @@ -535,7 +658,7 @@ def _callproc_setup(self, cur): select LOWER(@input) END """, - ) + ) # This should create a sproc with a return value. def _retval_setup(self, cur): @@ -549,19 +672,26 @@ def _retval_setup(self, cur): return @input+1 END """, - ) + ) def test_retval(self): with self._connect() as con: cur = con.cursor() self._retval_setup(cur) - values = cur.callproc('add_one', (1,)) - self.assertEqual(values[0], 1, 'input parameter should be left unchanged: %s' % (values[0],)) + values = cur.callproc("add_one", (1,)) + self.assertEqual( + values[0], + 1, + "input parameter should be left unchanged: %s" % (values[0],), + ) self.assertEqual(cur.description, None, "No resultset was expected.") - self.assertEqual(cur.return_value, 2, "Invalid return value: %s" % (cur.return_value,)) + self.assertEqual( + cur.return_value, 2, "Invalid return value: %s" % (cur.return_value,) + ) # This should create a sproc with a return value. + def _retval_select_setup(self, cur): self._try_run2( cur, @@ -575,24 +705,30 @@ def _retval_select_setup(self, cur): return @input+1 END """, - ) + ) def test_retval_select(self): with self._connect() as con: cur = con.cursor() self._retval_select_setup(cur) - values = cur.callproc('add_one_select', (1,)) - self.assertEqual(values[0], 1, 'input parameter should be left unchanged: %s' % (values[0],)) + values = cur.callproc("add_one_select", (1,)) + self.assertEqual( + values[0], + 1, + "input parameter should be left unchanged: %s" % (values[0],), + ) self.assertEqual(len(cur.description), 1, "Unexpected resultset.") - self.assertEqual(cur.description[0][0], 'a', "Unexpected resultset.") - self.assertEqual(cur.fetchall(), [('a',)], 'Unexpected resultset.') + self.assertEqual(cur.description[0][0], "a", "Unexpected resultset.") + self.assertEqual(cur.fetchall(), [("a",)], "Unexpected resultset.") - self.assertTrue(cur.nextset(), 'No second resultset found.') + self.assertTrue(cur.nextset(), "No second resultset found.") self.assertEqual(len(cur.description), 1, "Unexpected resultset.") - self.assertEqual(cur.description[0][0], 'b', "Unexpected resultset.") + self.assertEqual(cur.description[0][0], "b", "Unexpected resultset.") - self.assertEqual(cur.return_value, 2, "Invalid return value: %s" % (cur.return_value,)) + self.assertEqual( + cur.return_value, 2, "Invalid return value: %s" % (cur.return_value,) + ) with self.assertRaises(Error): cur.fetchall() @@ -608,34 +744,36 @@ def _outparam_setup(self, cur): SET @output = @input+1 END """, - ) + ) def test_outparam(self): with self._connect() as con: cur = con.cursor() self._outparam_setup(cur) - values = cur.callproc('add_one_out', (1, output(value=1))) - self.assertEqual(len(values), 2, 'expected 2 parameters') - self.assertEqual(values[0], 1, 'input parameter should be unchanged') - self.assertEqual(values[1], 2, 'output parameter should get new values') + values = cur.callproc("add_one_out", (1, output(value=1))) + self.assertEqual(len(values), 2, "expected 2 parameters") + self.assertEqual(values[0], 1, "input parameter should be unchanged") + self.assertEqual(values[1], 2, "output parameter should get new values") - values = cur.callproc('add_one_out', (None, output(value=1))) - self.assertEqual(len(values), 2, 'expected 2 parameters') - self.assertEqual(values[0], None, 'input parameter should be 
unchanged') - self.assertEqual(values[1], None, 'output parameter should get new values') + values = cur.callproc("add_one_out", (None, output(value=1))) + self.assertEqual(len(values), 2, "expected 2 parameters") + self.assertEqual(values[0], None, "input parameter should be unchanged") + self.assertEqual(values[1], None, "output parameter should get new values") def test_assigning_select(self): # test that assigning select does not interfere with result sets with self._connect() as con: cur = con.cursor() - cur.execute(""" + cur.execute( + """ declare @var1 int select @var1 = 1 select @var1 = 2 select 'value' -""") +""" + ) self.assertFalse(cur.description) self.assertTrue(cur.nextset()) @@ -643,10 +781,11 @@ def test_assigning_select(self): self.assertTrue(cur.nextset()) self.assertTrue(cur.description) - self.assertEqual([(u'value',)], cur.fetchall()) + self.assertEqual([("value",)], cur.fetchall()) self.assertFalse(cur.nextset()) - cur.execute(""" + cur.execute( + """ set nocount on declare @var1 int @@ -655,9 +794,10 @@ def test_assigning_select(self): select @var1 = 2 select 'value' -""") +""" + ) self.assertTrue(cur.description) - self.assertEqual([(u'value',)], cur.fetchall()) + self.assertEqual([("value",)], cur.fetchall()) self.assertFalse(cur.nextset()) # Don't need setoutputsize tests. @@ -675,8 +815,9 @@ def help_nextset_setUp(self, cur): select count(*) from %sbooze select name from %sbooze end -""" % (self.table_prefix, self.table_prefix), - ) +""" + % (self.table_prefix, self.table_prefix), + ) def help_nextset_tearDown(self, cur): cur.execute("drop procedure deleteme") @@ -686,10 +827,7 @@ def test_ExceptionsAsConnectionAttributes(self): def test_select_decimal_zero(self): with self._connect() as con: - expected = ( - Decimal('0.00'), - Decimal('0.0'), - Decimal('-0.00')) + expected = (Decimal("0.00"), Decimal("0.0"), Decimal("-0.00")) cur = con.cursor() cur.execute("SELECT %s as A, %s as B, %s as C", expected) @@ -700,7 +838,8 @@ def test_select_decimal_zero(self): def test_type_objects(self): with self._connect() as con: cur = con.cursor() - cur.execute(""" + cur.execute( + """ select cast(0 as varchar), cast(1 as binary), cast(2 as int), @@ -708,7 +847,8 @@ def test_type_objects(self): cast(4 as decimal), cast('2005' as datetime), cast('6' as xml) -""") +""" + ) self.assertTrue(cur.description) col_types = [col[1] for col in cur.description] self.assertEqual(col_types[0], STRING) @@ -727,11 +867,13 @@ def test_type_objects(self): class TestBug4(unittest.TestCase): def test_as_dict(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' - with connect(*settings.CONNECT_ARGS, **kwargs, row_strategy=pytds.dict_row_strategy) as conn: + kwargs["database"] = "master" + with connect( + *settings.CONNECT_ARGS, **kwargs, row_strategy=pytds.dict_row_strategy + ) as conn: with conn.cursor() as cur: - cur.execute('select 1 as a, 2 as b') - self.assertDictEqual({'a': 1, 'b': 2}, cur.fetchone()) + cur.execute("select 1 as a, 2 as b") + self.assertDictEqual({"a": 1, "b": 2}, cur.fetchone()) def _params_tests(self): @@ -739,7 +881,7 @@ def test_val(typ, val): with self.conn.cursor() as cur: param = Param(type=typ, value=val) logger.info("Testing with %s", repr(param)) - cur.execute('select %s', [param]) + cur.execute("select %s", [param]) self.assertTupleEqual(cur.fetchone(), (val,)) self.assertIs(cur.fetchone(), None) @@ -747,9 +889,9 @@ def test_val(typ, val): test_val(BitType(), False) test_val(BitType(), None) test_val(TinyIntType(), 255) - 
test_val(SmallIntType(), 2 ** 15 - 1) - test_val(IntType(), 2 ** 31 - 1) - test_val(BigIntType(), 2 ** 63 - 1) + test_val(SmallIntType(), 2**15 - 1) + test_val(IntType(), 2**31 - 1) + test_val(BigIntType(), 2**63 - 1) test_val(IntType(), None) test_val(RealType(), 0.25) test_val(FloatType(), 0.25) @@ -769,64 +911,73 @@ def test_val(typ, val): test_val(DateTime2Type(precision=0), datetime(1, 1, 1, 0, 0, 0)) test_val(DateTime2Type(precision=6), datetime(9999, 12, 31, 23, 59, 59, 999999)) test_val(DateTime2Type(precision=0), None) - test_val(DateTimeOffsetType(precision=6), datetime(9999, 12, 31, 23, 59, 59, 999999, utc)) - test_val(DateTimeOffsetType(precision=6), datetime(9999, 12, 31, 23, 59, 59, 999999, tzoffset(14))) - test_val(DateTimeOffsetType(precision=0), datetime(1, 1, 1, 0, 0, 0, tzinfo=tzoffset(-14))) - #test_val(DateTimeOffsetType(precision=0), datetime(1, 1, 1, 0, 0, 0, tzinfo=tzoffset(14))) + test_val( + DateTimeOffsetType(precision=6), + datetime(9999, 12, 31, 23, 59, 59, 999999, utc), + ) + test_val( + DateTimeOffsetType(precision=6), + datetime(9999, 12, 31, 23, 59, 59, 999999, tzoffset(14)), + ) + test_val( + DateTimeOffsetType(precision=0), + datetime(1, 1, 1, 0, 0, 0, tzinfo=tzoffset(-14)), + ) + # test_val(DateTimeOffsetType(precision=0), datetime(1, 1, 1, 0, 0, 0, tzinfo=tzoffset(14))) test_val(DateTimeOffsetType(precision=6), None) - test_val(DecimalType(scale=6, precision=38), Decimal('123.456789')) + test_val(DecimalType(scale=6, precision=38), Decimal("123.456789")) test_val(DecimalType(scale=6, precision=38), None) - test_val(SmallMoneyType(), Decimal('-214748.3648')) - test_val(SmallMoneyType(), Decimal('214748.3647')) - test_val(MoneyType(), Decimal('922337203685477.5807')) - test_val(MoneyType(), Decimal('-922337203685477.5808')) + test_val(SmallMoneyType(), Decimal("-214748.3648")) + test_val(SmallMoneyType(), Decimal("214748.3647")) + test_val(MoneyType(), Decimal("922337203685477.5807")) + test_val(MoneyType(), Decimal("-922337203685477.5808")) test_val(MoneyType(), None) test_val(UniqueIdentifierType(), None) test_val(UniqueIdentifierType(), uuid.uuid4()) if pytds.tds_base.IS_TDS71_PLUS(self.conn._tds_socket): test_val(VariantType(), None) - #test_val(self.conn._conn.type_factory.SqlVariant(10), 100) - test_val(VarBinaryType(size=10), b'') - test_val(VarBinaryType(size=10), b'testtest12') + # test_val(self.conn._conn.type_factory.SqlVariant(10), 100) + test_val(VarBinaryType(size=10), b"") + test_val(VarBinaryType(size=10), b"testtest12") test_val(VarBinaryType(size=10), None) - test_val(VarBinaryType(size=8000), b'x' * 8000) + test_val(VarBinaryType(size=8000), b"x" * 8000) test_val(VarCharType(size=10), None) - test_val(VarCharType(size=10), '') - test_val(VarCharType(size=10), 'test') - test_val(VarCharType(size=8000), 'x' * 8000) - test_val(NVarCharType(size=10), u'') - test_val(NVarCharType(size=10), u'testtest12') + test_val(VarCharType(size=10), "") + test_val(VarCharType(size=10), "test") + test_val(VarCharType(size=8000), "x" * 8000) + test_val(NVarCharType(size=10), "") + test_val(NVarCharType(size=10), "testtest12") test_val(NVarCharType(size=10), None) - test_val(NVarCharType(size=4000), u'x' * 4000) + test_val(NVarCharType(size=4000), "x" * 4000) test_val(TextType(), None) - test_val(TextType(), '') - test_val(TextType(), 'hello') + test_val(TextType(), "") + test_val(TextType(), "hello") test_val(NTextType(), None) - test_val(NTextType(), '') - test_val(NTextType(), 'hello') + test_val(NTextType(), "") + test_val(NTextType(), "hello") 
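# A minimal reference sketch, not part of the patch: the smallmoney/money
# boundary Decimals exercised by the test_val calls above follow from the
# types' fixed-point wire format, a 4-byte (smallmoney) or 8-byte (money)
# signed integer scaled by 10**-4.
from decimal import Decimal

SMALLMONEY_MIN = Decimal(-(2**31)).scaleb(-4)  # Decimal('-214748.3648')
SMALLMONEY_MAX = Decimal(2**31 - 1).scaleb(-4)  # Decimal('214748.3647')
MONEY_MIN = Decimal(-(2**63)).scaleb(-4)  # Decimal('-922337203685477.5808')
MONEY_MAX = Decimal(2**63 - 1).scaleb(-4)  # Decimal('922337203685477.5807')

assert SMALLMONEY_MAX == Decimal("214748.3647")
assert MONEY_MIN == Decimal("-922337203685477.5808")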
test_val(ImageType(), None) - test_val(ImageType(), b'') - test_val(ImageType(), b'test') + test_val(ImageType(), b"") + test_val(ImageType(), b"test") if pytds.tds_base.IS_TDS72_PLUS(self.conn._tds_socket): test_val(VarBinaryMaxType(), None) - test_val(VarBinaryMaxType(), b'') - test_val(VarBinaryMaxType(), b'testtest12') - test_val(VarBinaryMaxType(), b'x' * (10 ** 6)) + test_val(VarBinaryMaxType(), b"") + test_val(VarBinaryMaxType(), b"testtest12") + test_val(VarBinaryMaxType(), b"x" * (10**6)) test_val(NVarCharMaxType(), None) - test_val(NVarCharMaxType(), 'test') - test_val(NVarCharMaxType(), 'x' * (10 ** 6)) + test_val(NVarCharMaxType(), "test") + test_val(NVarCharMaxType(), "x" * (10**6)) test_val(VarCharMaxType(), None) - test_val(VarCharMaxType(), 'test') - test_val(VarCharMaxType(), 'x' * (10 ** 6)) - test_val(XmlType(), '') + test_val(VarCharMaxType(), "test") + test_val(VarCharMaxType(), "x" * (10**6)) + test_val(XmlType(), "") @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestTds70(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' - kwargs['tds_version'] = pytds.tds_base.TDS70 + kwargs["database"] = "master" + kwargs["tds_version"] = pytds.tds_base.TDS70 self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): @@ -837,8 +988,8 @@ def test_parsing(self): class TestTds71(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = settings.DATABASE - kwargs['tds_version'] = pytds.tds_base.TDS71 + kwargs["database"] = settings.DATABASE + kwargs["tds_version"] = pytds.tds_base.TDS71 self.conn = connect(*settings.CONNECT_ARGS, **kwargs) utils.create_test_database(self.conn) self.conn.commit() @@ -846,19 +997,20 @@ def setUp(self): def test_parsing(self): _params_tests(self) - def test_bulk(self): f = StringIO("42\tfoo\n74\tbar\n") with self.conn.cursor() as cur: - cur.copy_to(f, 'bulk_insert_table', schema='myschema', columns=('num', 'data')) - cur.execute('select num, data from myschema.bulk_insert_table') - self.assertListEqual(cur.fetchall(), [(42, 'foo'), (74, 'bar')]) + cur.copy_to( + f, "bulk_insert_table", schema="myschema", columns=("num", "data") + ) + cur.execute("select num, data from myschema.bulk_insert_table") + self.assertListEqual(cur.fetchall(), [(42, "foo"), (74, "bar")]) def test_call_proc(self): with self.conn.cursor() as cur: val = 45 - values = cur.callproc('testproc', (val, default, output(value=1))) - #self.assertEqual(cur.fetchall(), [(val,)]) + values = cur.callproc("testproc", (val, default, output(value=1))) + # self.assertEqual(cur.fetchall(), [(val,)]) self.assertEqual(val + 2, values[2]) self.assertEqual(val + 2, cur.get_proc_return_status()) @@ -867,8 +1019,8 @@ def test_call_proc(self): class TestTds72(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' - kwargs['tds_version'] = pytds.tds_base.TDS72 + kwargs["database"] = "master" + kwargs["tds_version"] = pytds.tds_base.TDS72 self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): @@ -879,8 +1031,8 @@ def test_parsing(self): class TestTds73A(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' - kwargs['tds_version'] = pytds.tds_base.TDS73A + kwargs["database"] = "master" + kwargs["tds_version"] = pytds.tds_base.TDS73A self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): @@ -891,36 +1043,43 @@ def 
test_parsing(self): class TestTds73B(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' - kwargs['tds_version'] = pytds.tds_base.TDS73B + kwargs["database"] = "master" + kwargs["tds_version"] = pytds.tds_base.TDS73B self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_parsing(self): _params_tests(self) + @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") class TestRawBytes(unittest.TestCase): def setUp(self): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['bytes_to_unicode'] = False - kwargs['database'] = 'master' + kwargs["bytes_to_unicode"] = False + kwargs["database"] = "master" self.conn = connect(*settings.CONNECT_ARGS, **kwargs) def test_fetch(self): cur = self.conn.cursor() - self.assertIsInstance(cur.execute_scalar("select cast('abc' as nvarchar(max))"), str) - self.assertIsInstance(cur.execute_scalar("select cast('abc' as varchar(max))"), bytes) + self.assertIsInstance( + cur.execute_scalar("select cast('abc' as nvarchar(max))"), str + ) + self.assertIsInstance( + cur.execute_scalar("select cast('abc' as varchar(max))"), bytes + ) self.assertIsInstance(cur.execute_scalar("select cast('abc' as text)"), bytes) - self.assertIsInstance(cur.execute_scalar("select %s", ['abc']), str) - self.assertIsInstance(cur.execute_scalar("select %s", [b'abc']), bytes) + self.assertIsInstance(cur.execute_scalar("select %s", ["abc"]), str) + self.assertIsInstance(cur.execute_scalar("select %s", [b"abc"]), bytes) - rawBytes = b'\x01\x02\x03' - self.assertEqual(rawBytes, cur.execute_scalar("select cast(0x010203 as varchar(max))")) + rawBytes = b"\x01\x02\x03" + self.assertEqual( + rawBytes, cur.execute_scalar("select cast(0x010203 as varchar(max))") + ) self.assertEqual(rawBytes, cur.execute_scalar("select %s", [rawBytes])) - utf8char = b'\xee\xb4\xba' + utf8char = b"\xee\xb4\xba" self.assertEqual(utf8char, cur.execute_scalar("select %s", [utf8char])) @@ -931,20 +1090,24 @@ def test_invalid_block_size(): and then it should upgrade to buffer size that was provided in login request. 
""" kwargs = settings.CONNECT_KWARGS.copy() - kwargs.update({ - 'blocksize': 4000, - }) + kwargs.update( + { + "blocksize": 4000, + } + ) with connect(**kwargs) as conn: with conn.cursor() as cur: - cur.execute_scalar("select '{}'".format('x' * 8000)) + cur.execute_scalar("select '{}'".format("x" * 8000)) @unittest.skipUnless(LIVE_TEST, "requires HOST variable to be set") def test_readonly_connection(): kwargs = settings.CONNECT_KWARGS.copy() - kwargs.update({ - 'readonly': True, - }) + kwargs.update( + { + "readonly": True, + } + ) with connect(**kwargs) as conn: with conn.cursor() as cur: cur.execute_scalar("select 1") diff --git a/tests/connected_test.py b/tests/connected_test.py index 488a31f..1325acd 100644 --- a/tests/connected_test.py +++ b/tests/connected_test.py @@ -22,30 +22,29 @@ from tests.utils import tran_count logger = logging.getLogger(__name__) -LIVE_TEST = getattr(settings, 'LIVE_TEST', True) +LIVE_TEST = getattr(settings, "LIVE_TEST", True) pytds.tds.logging_enabled = True def test_integrity_error(cursor): - cursor.execute('create table testtable_pk(pk int primary key)') - cursor.execute('insert into testtable_pk values (1)') + cursor.execute("create table testtable_pk(pk int primary key)") + cursor.execute("insert into testtable_pk values (1)") with pytest.raises(pytds.IntegrityError): - cursor.execute('insert into testtable_pk values (1)') - - + cursor.execute("insert into testtable_pk values (1)") def test_bulk_insert(cursor): cur = cursor f = StringIO("42\tfoo\n74\tbar\n") - cur.copy_to(f, 'bulk_insert_table', schema='myschema', columns=('num', 'data')) - cur.execute('select num, data from myschema.bulk_insert_table') - assert [(42, 'foo'), (74, 'bar')] == cur.fetchall() + cur.copy_to(f, "bulk_insert_table", schema="myschema", columns=("num", "data")) + cur.execute("select num, data from myschema.bulk_insert_table") + assert [(42, "foo"), (74, "bar")] == cur.fetchall() def test_bug2(cursor): cur = cursor - cur.execute(''' + cur.execute( + """ create procedure testproc_bug2 (@param int) as begin @@ -53,9 +52,10 @@ def test_bug2(cursor): select @param return @param + 1 end - ''') + """ + ) val = 45 - cur.execute('exec testproc_bug2 @param = 45') + cur.execute("exec testproc_bug2 @param = 45") assert cur.fetchall() == [(val,)] assert val + 1 == cur.get_proc_return_status() @@ -63,223 +63,299 @@ def test_bug2(cursor): def test_stored_proc(cursor): cur = cursor val = 45 - #params = {'@param': val, '@outparam': output(None), '@add': 1} - values = cur.callproc('testproc', (val, pytds.default, pytds.output(value=1))) - #self.assertEqual(cur.fetchall(), [(val,)]) + # params = {'@param': val, '@outparam': output(None), '@add': 1} + values = cur.callproc("testproc", (val, pytds.default, pytds.output(value=1))) + # self.assertEqual(cur.fetchall(), [(val,)]) assert val + 2 == values[2] assert val + 2 == cur.get_proc_return_status() # after calling stored proc which does not have RETURN statement get_proc_return_status() should return 0 # since in this case SQL server issues RETURN STATUS token with 0 value - cur.callproc('test_proc_no_return', (val,)) + cur.callproc("test_proc_no_return", (val,)) assert cur.fetchall() == [(val,)] assert cur.get_proc_return_status() == 0 - #TODO fix this part, currently it fails - #assert cur.execute_scalar("select 1") == 1 - #assert cur.get_proc_return_status() == 0 + # TODO fix this part, currently it fails + # assert cur.execute_scalar("select 1") == 1 + # assert cur.get_proc_return_status() == 0 def test_table_selects(db_connection): cur = 
db_connection.cursor() - cur.execute(u''' + cur.execute( + """ create table #testtable (id int, _text text, _xml xml, vcm varchar(max), vc varchar(10)) - ''') - cur.execute(u''' + """ + ) + cur.execute( + """ insert into #testtable (id, _text, _xml, vcm, vc) values (1, 'text', '', '', NULL) - ''') - cur.execute('select id from #testtable order by id') + """ + ) + cur.execute("select id from #testtable order by id") assert [(1,)] == cur.fetchall() cur = db_connection.cursor() - cur.execute('select _text from #testtable order by id') - assert [(u'text',)] == cur.fetchall() + cur.execute("select _text from #testtable order by id") + assert [("text",)] == cur.fetchall() cur = db_connection.cursor() - cur.execute('select _xml from #testtable order by id') - assert [('',)] == cur.fetchall() + cur.execute("select _xml from #testtable order by id") + assert [("",)] == cur.fetchall() cur = db_connection.cursor() - cur.execute('select id, _text, _xml, vcm, vc from #testtable order by id') - assert (1, 'text', '', '', None) == cur.fetchone() + cur.execute("select id, _text, _xml, vcm, vc from #testtable order by id") + assert (1, "text", "", "", None) == cur.fetchone() cur = db_connection.cursor() - cur.execute('select vc from #testtable order by id') + cur.execute("select vc from #testtable order by id") assert [(None,)] == cur.fetchall() cur = db_connection.cursor() - cur.execute('insert into #testtable (_xml) values (%s)', ('',)) + cur.execute("insert into #testtable (_xml) values (%s)", ("",)) cur = db_connection.cursor() - cur.execute(u'drop table #testtable') + cur.execute("drop table #testtable") def test_decimals(cursor): cur = cursor - assert Decimal(12) == cur.execute_scalar('select cast(12 as decimal) as fieldname') - assert Decimal(-12) == cur.execute_scalar('select cast(-12 as decimal) as fieldname') - assert Decimal('123456.12345') == cur.execute_scalar("select cast('123456.12345'as decimal(20,5)) as fieldname") - assert Decimal('-123456.12345') == cur.execute_scalar("select cast('-123456.12345'as decimal(20,5)) as fieldname") + assert Decimal(12) == cur.execute_scalar("select cast(12 as decimal) as fieldname") + assert Decimal(-12) == cur.execute_scalar( + "select cast(-12 as decimal) as fieldname" + ) + assert Decimal("123456.12345") == cur.execute_scalar( + "select cast('123456.12345'as decimal(20,5)) as fieldname" + ) + assert Decimal("-123456.12345") == cur.execute_scalar( + "select cast('-123456.12345'as decimal(20,5)) as fieldname" + ) def test_bulk_insert_with_special_chars_no_columns(cursor): cur = cursor - cur.execute('create table [test]] table](num int not null, data varchar(100))') + cur.execute("create table [test]] table](num int not null, data varchar(100))") f = StringIO("42\tfoo\n74\tbar\n") - cur.copy_to(f, 'test] table') - cur.execute('select num, data from [test]] table]') - assert cur.fetchall() == [(42, 'foo'), (74, 'bar')] + cur.copy_to(f, "test] table") + cur.execute("select num, data from [test]] table]") + assert cur.fetchall() == [(42, "foo"), (74, "bar")] def test_bulk_insert_with_special_chars(cursor): cur = cursor - cur.execute('create table [test]] table](num int, data varchar(100))') + cur.execute("create table [test]] table](num int, data varchar(100))") f = StringIO("42\tfoo\n74\tbar\n") - cur.copy_to(f, 'test] table', columns=('num', 'data')) - cur.execute('select num, data from [test]] table]') - assert cur.fetchall() == [(42, 'foo'), (74, 'bar')] + cur.copy_to(f, "test] table", columns=("num", "data")) + cur.execute("select num, data from [test]] 
table]") + assert cur.fetchall() == [(42, "foo"), (74, "bar")] def test_bulk_insert_with_keyword_column_name(cursor): cur = cursor - cur.execute('create table test_table(num int, [User] varchar(100))') + cur.execute("create table test_table(num int, [User] varchar(100))") f = StringIO("42\tfoo\n74\tbar\n") - cur.copy_to(f, 'test_table') - cur.execute('select num, [User] from test_table') - assert cur.fetchall() == [(42, 'foo'), (74, 'bar')] + cur.copy_to(f, "test_table") + cur.execute("select num, [User] from test_table") + assert cur.fetchall() == [(42, "foo"), (74, "bar")] def test_bulk_insert_with_direct_data(cursor): cur = cursor - cur.execute('create table test_table(num int, data nvarchar(max))') + cur.execute("create table test_table(num int, data nvarchar(max))") - data = [ - [42, 'foo'], - [57, ''], - [66, None], - [74, 'bar'] - ] + data = [[42, "foo"], [57, ""], [66, None], [74, "bar"]] column_types = [ - pytds.tds_base.Column('num', type=pytds.tds_types.IntType()), - pytds.tds_base.Column('data', type=pytds.tds_types.NVarCharMaxType()) + pytds.tds_base.Column("num", type=pytds.tds_types.IntType()), + pytds.tds_base.Column("data", type=pytds.tds_types.NVarCharMaxType()), ] - cur.copy_to(data=data, table_or_view='test_table', columns=column_types) - cur.execute('select num, data from test_table') - assert cur.fetchall() == [(42, 'foo'), (57, ''), (66, None), (74, 'bar')] + cur.copy_to(data=data, table_or_view="test_table", columns=column_types) + cur.execute("select num, data from test_table") + assert cur.fetchall() == [(42, "foo"), (57, ""), (66, None), (74, "bar")] def test_table_valued_type_autodetect(cursor): def rows_gen(): - yield 1, 'test1' - yield 2, 'test2' + yield 1, "test1" + yield 2, "test2" - tvp = pytds.TableValuedParam(type_name='dbo.CategoryTableType', rows=rows_gen()) - cursor.execute('SELECT * FROM %s', (tvp,)) - assert cursor.fetchall() == [(1, 'test1'), (2, 'test2')] + tvp = pytds.TableValuedParam(type_name="dbo.CategoryTableType", rows=rows_gen()) + cursor.execute("SELECT * FROM %s", (tvp,)) + assert cursor.fetchall() == [(1, "test1"), (2, "test2")] def test_table_valued_type_explicit(cursor): def rows_gen(): - yield 1, 'test1' - yield 2, 'test2' + yield 1, "test1" + yield 2, "test2" tvp = pytds.TableValuedParam( - type_name='dbo.CategoryTableType', + type_name="dbo.CategoryTableType", columns=( pytds.Column(type=pytds.tds_types.IntType()), - pytds.Column(type=pytds.tds_types.NVarCharType(size=30)) + pytds.Column(type=pytds.tds_types.NVarCharType(size=30)), ), - rows=rows_gen() + rows=rows_gen(), ) - cursor.execute('SELECT * FROM %s', (tvp,)) - assert cursor.fetchall() == [(1, 'test1'), (2, 'test2')] + cursor.execute("SELECT * FROM %s", (tvp,)) + assert cursor.fetchall() == [(1, "test1"), (2, "test2")] def test_minimal(cursor): - cursor.execute('select 1') + cursor.execute("select 1") assert [(1,)] == cursor.fetchall() def test_empty_query(cursor): - cursor.execute('') + cursor.execute("") assert cursor.description is None @pytest.mark.parametrize( - 'typ,values', + "typ,values", [ (pytds.tds_types.BitType(), [True, False]), - (pytds.tds_types.IntType(), [2 ** 31 - 1, None]), - (pytds.tds_types.IntType(), [-2 ** 31, None]), - (pytds.tds_types.SmallIntType(), [-2 ** 15, 2 ** 15 - 1]), + (pytds.tds_types.IntType(), [2**31 - 1, None]), + (pytds.tds_types.IntType(), [-(2**31), None]), + (pytds.tds_types.SmallIntType(), [-(2**15), 2**15 - 1]), (pytds.tds_types.TinyIntType(), [255, 0]), - (pytds.tds_types.BigIntType(), [2 ** 63 - 1, -2 ** 63]), - 
(pytds.tds_types.IntType(), [None, 2 ** 31 - 1]), - (pytds.tds_types.IntType(), [None, -2 ** 31]), + (pytds.tds_types.BigIntType(), [2**63 - 1, -(2**63)]), + (pytds.tds_types.IntType(), [None, 2**31 - 1]), + (pytds.tds_types.IntType(), [None, -(2**31)]), (pytds.tds_types.RealType(), [0.25, None]), (pytds.tds_types.FloatType(), [0.25, None]), - (pytds.tds_types.VarCharType(size=10), [u'', u'testtest12', None, u'foo']), - (pytds.tds_types.VarCharType(size=4000), [u'x' * 4000, u'foo']), - (pytds.tds_types.VarCharMaxType(), [u'x' * 10000, u'foo', u'', u'testtest', None, u'bar']), - (pytds.tds_types.NVarCharType(size=10), [u'', u'testtest12', None, u'foo']), - (pytds.tds_types.NVarCharType(size=4000), [u'x' * 4000, u'foo']), - (pytds.tds_types.NVarCharMaxType(), [u'x' * 10000, u'foo', u'', u'testtest', None, u'bar']), - (pytds.tds_types.VarBinaryType(size=10), [b'testtest12', b'', None]), - (pytds.tds_types.VarBinaryType(size=8000), [b'x' * 8000, b'']), - (pytds.tds_types.SmallDateTimeType(), [datetime.datetime(1900, 1, 1, 0, 0, 0), None, datetime.datetime(2079, 6, 6, 23, 59, 0)]), - (pytds.tds_types.DateTimeType(), [datetime.datetime(1753, 1, 1, 0, 0, 0), None, datetime.datetime(9999, 12, 31, 23, 59, 59, 990000)]), - (pytds.tds_types.DateType(), [datetime.date(1, 1, 1), None, datetime.date(9999, 12, 31)]), + (pytds.tds_types.VarCharType(size=10), ["", "testtest12", None, "foo"]), + (pytds.tds_types.VarCharType(size=4000), ["x" * 4000, "foo"]), + ( + pytds.tds_types.VarCharMaxType(), + ["x" * 10000, "foo", "", "testtest", None, "bar"], + ), + (pytds.tds_types.NVarCharType(size=10), ["", "testtest12", None, "foo"]), + (pytds.tds_types.NVarCharType(size=4000), ["x" * 4000, "foo"]), + ( + pytds.tds_types.NVarCharMaxType(), + ["x" * 10000, "foo", "", "testtest", None, "bar"], + ), + (pytds.tds_types.VarBinaryType(size=10), [b"testtest12", b"", None]), + (pytds.tds_types.VarBinaryType(size=8000), [b"x" * 8000, b""]), + ( + pytds.tds_types.SmallDateTimeType(), + [ + datetime.datetime(1900, 1, 1, 0, 0, 0), + None, + datetime.datetime(2079, 6, 6, 23, 59, 0), + ], + ), + ( + pytds.tds_types.DateTimeType(), + [ + datetime.datetime(1753, 1, 1, 0, 0, 0), + None, + datetime.datetime(9999, 12, 31, 23, 59, 59, 990000), + ], + ), + ( + pytds.tds_types.DateType(), + [datetime.date(1, 1, 1), None, datetime.date(9999, 12, 31)], + ), (pytds.tds_types.TimeType(precision=0), [datetime.time(0, 0, 0), None]), - (pytds.tds_types.TimeType(precision=6), [datetime.time(23, 59, 59, 999999), None]), + ( + pytds.tds_types.TimeType(precision=6), + [datetime.time(23, 59, 59, 999999), None], + ), (pytds.tds_types.TimeType(precision=0), [None]), - (pytds.tds_types.DateTime2Type(precision=0), [datetime.datetime(1, 1, 1, 0, 0, 0), None]), - (pytds.tds_types.DateTime2Type(precision=6), [datetime.datetime(9999, 12, 31, 23, 59, 59, 999999), None]), + ( + pytds.tds_types.DateTime2Type(precision=0), + [datetime.datetime(1, 1, 1, 0, 0, 0), None], + ), + ( + pytds.tds_types.DateTime2Type(precision=6), + [datetime.datetime(9999, 12, 31, 23, 59, 59, 999999), None], + ), (pytds.tds_types.DateTime2Type(precision=0), [None]), - (pytds.tds_types.DateTimeOffsetType(precision=6), [datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, pytds.tz.utc), None]), - (pytds.tds_types.DateTimeOffsetType(precision=6), [datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, pytds.tz.FixedOffsetTimezone(14)), None]), - (pytds.tds_types.DateTimeOffsetType(precision=0), [datetime.datetime(1, 1, 1, 0, 0, 0, tzinfo=pytds.tz.FixedOffsetTimezone(-14))]), - 
(pytds.tds_types.DateTimeOffsetType(precision=0), [datetime.datetime(1, 1, 1, 0, 14, 0, tzinfo=pytds.tz.FixedOffsetTimezone(14))]), + ( + pytds.tds_types.DateTimeOffsetType(precision=6), + [datetime.datetime(9999, 12, 31, 23, 59, 59, 999999, pytds.tz.utc), None], + ), + ( + pytds.tds_types.DateTimeOffsetType(precision=6), + [ + datetime.datetime( + 9999, 12, 31, 23, 59, 59, 999999, pytds.tz.FixedOffsetTimezone(14) + ), + None, + ], + ), + ( + pytds.tds_types.DateTimeOffsetType(precision=0), + [ + datetime.datetime( + 1, 1, 1, 0, 0, 0, tzinfo=pytds.tz.FixedOffsetTimezone(-14) + ) + ], + ), + ( + pytds.tds_types.DateTimeOffsetType(precision=0), + [ + datetime.datetime( + 1, 1, 1, 0, 14, 0, tzinfo=pytds.tz.FixedOffsetTimezone(14) + ) + ], + ), (pytds.tds_types.DateTimeOffsetType(precision=6), [None]), - (pytds.tds_types.DecimalType(scale=6, precision=38), [Decimal('123.456789'), None]), - (pytds.tds_types.SmallMoneyType(), [Decimal('214748.3647'), None, Decimal('-214748.3648')]), - (pytds.tds_types.MoneyType(), [Decimal('922337203685477.5807'), None, Decimal('-922337203685477.5808')]), - (pytds.tds_types.SmallMoneyType(), [Decimal('214748.3647')]), - (pytds.tds_types.MoneyType(), [Decimal('922337203685477.5807')]), + ( + pytds.tds_types.DecimalType(scale=6, precision=38), + [Decimal("123.456789"), None], + ), + ( + pytds.tds_types.SmallMoneyType(), + [Decimal("214748.3647"), None, Decimal("-214748.3648")], + ), + ( + pytds.tds_types.MoneyType(), + [Decimal("922337203685477.5807"), None, Decimal("-922337203685477.5808")], + ), + (pytds.tds_types.SmallMoneyType(), [Decimal("214748.3647")]), + (pytds.tds_types.MoneyType(), [Decimal("922337203685477.5807")]), (pytds.tds_types.MoneyType(), [None]), (pytds.tds_types.UniqueIdentifierType(), [None, uuid.uuid4()]), (pytds.tds_types.VariantType(), [None]), - #(pytds.tds_types.VariantType(), [100]), - #(pytds.tds_types.ImageType(), [None]), + # (pytds.tds_types.VariantType(), [100]), + # (pytds.tds_types.ImageType(), [None]), (pytds.tds_types.VarBinaryMaxType(), [None]), - #(pytds.tds_types.NTextType(), [None]), - #(pytds.tds_types.TextType(), [None]), - #(pytds.tds_types.ImageType(), [b'']), - #(self.conn._conn.type_factory.long_binary_type(), [b'testtest12']), - #(self.conn._conn.type_factory.long_string_type(), [None]), - #(self.conn._conn.type_factory.long_varchar_type(), [None]), - #(self.conn._conn.type_factory.long_string_type(), ['test']), - #(pytds.tds_types.ImageType(), [None]), - #(pytds.tds_types.ImageType(), [None]), - #(pytds.tds_types.ImageType(), [b'test']), -]) + # (pytds.tds_types.NTextType(), [None]), + # (pytds.tds_types.TextType(), [None]), + # (pytds.tds_types.ImageType(), [b'']), + # (self.conn._conn.type_factory.long_binary_type(), [b'testtest12']), + # (self.conn._conn.type_factory.long_string_type(), [None]), + # (self.conn._conn.type_factory.long_varchar_type(), [None]), + # (self.conn._conn.type_factory.long_string_type(), ['test']), + # (pytds.tds_types.ImageType(), [None]), + # (pytds.tds_types.ImageType(), [None]), + # (pytds.tds_types.ImageType(), [b'test']), + ], +) def test_bulk_insert_type(cursor, typ, values): cur = cursor - cur.execute('create table bulk_insert_table_ll(c1 {0})'.format(typ.get_declaration())) - cur._session.submit_plain_query('insert bulk bulk_insert_table_ll (c1 {0})'.format(typ.get_declaration())) + cur.execute( + "create table bulk_insert_table_ll(c1 {0})".format(typ.get_declaration()) + ) + cur._session.submit_plain_query( + "insert bulk bulk_insert_table_ll (c1 
{0})".format(typ.get_declaration()) + ) cur._session.process_simple_request() - col1 = pytds.Column(name='c1', type=typ, flags=pytds.Column.fNullable) + col1 = pytds.Column(name="c1", type=typ, flags=pytds.Column.fNullable) metadata = [col1] cur._session.submit_bulk(metadata, [[value] for value in values]) cur._session.process_simple_request() - cur.execute('select c1 from bulk_insert_table_ll') + cur.execute("select c1 from bulk_insert_table_ll") assert cur.fetchall() == [(value,) for value in values] assert cur.fetchone() is None def test_streaming(cursor): - val = 'x' * 10000 + val = "x" * 10000 # test nvarchar(max) cursor.execute("select N'{}', 1".format(val)) with pytest.raises(ValueError): @@ -311,14 +387,14 @@ def test_streaming(cursor): cursor.set_stream(0, BytesIO()) row = cursor.fetchone() assert isinstance(row[0], BytesIO) - assert row[0].getvalue().decode('ascii') == val + assert row[0].getvalue().decode("ascii") == val # test image type cursor.execute("select cast('{}' as image), 1".format(val)) cursor.set_stream(0, BytesIO()) row = cursor.fetchone() assert isinstance(row[0], BytesIO) - assert row[0].getvalue().decode('ascii') == val + assert row[0].getvalue().decode("ascii") == val # test ntext type cursor.execute("select cast('{}' as ntext), 1".format(val)) @@ -335,7 +411,7 @@ def test_streaming(cursor): assert row[0].getvalue() == val # test xml type - xml_val = '{}'.format(val) + xml_val = "{}".format(val) cursor.execute("select cast('{}' as xml), 1".format(xml_val)) cursor.set_stream(0, StringIO()) row = cursor.fetchone() @@ -344,8 +420,8 @@ def test_streaming(cursor): def test_dictionary_params(cursor): - assert cursor.execute_scalar('select %(param)s', {'param': None}) == None - assert cursor.execute_scalar('select %(param)s', {'param': 1}) == 1 + assert cursor.execute_scalar("select %(param)s", {"param": None}) == None + assert cursor.execute_scalar("select %(param)s", {"param": 1}) == 1 def test_properties(separate_db_connection): @@ -362,7 +438,7 @@ def test_properties(separate_db_connection): def test_fetch_on_empty_dataset(cursor): - cursor.execute('declare @x int') + cursor.execute("declare @x int") with pytest.raises(pytds.ProgrammingError): cursor.fetchall() @@ -370,33 +446,33 @@ def test_fetch_on_empty_dataset(cursor): def test_bad_collation(cursor): # exception can be different with pytest.raises(UnicodeDecodeError): - cursor.execute_scalar('select cast(0x90 as varchar)') + cursor.execute_scalar("select cast(0x90 as varchar)") # check that connection is still usable - assert 1 == cursor.execute_scalar('select 1') + assert 1 == cursor.execute_scalar("select 1") def test_overlimit(cursor): def test_val(val): - cursor.execute('select %s', (val,)) + cursor.execute("select %s", (val,)) assert cursor.fetchone() == (val,) assert cursor.fetchone() is None ##cur.execute('select %s', '\x00'*(2**31)) with pytest.raises(pytds.DataError): - test_val(Decimal('1' + '0' * 38)) + test_val(Decimal("1" + "0" * 38)) with pytest.raises(pytds.DataError): - test_val(Decimal('-1' + '0' * 38)) + test_val(Decimal("-1" + "0" * 38)) with pytest.raises(pytds.DataError): - test_val(Decimal('1E38')) - val = -10 ** 38 - cursor.execute('select %s', (val,)) + test_val(Decimal("1E38")) + val = -(10**38) + cursor.execute("select %s", (val,)) assert cursor.fetchone() == (str(val),) assert cursor.fetchone() is None def test_description(cursor): - cursor.execute('select cast(12.65 as decimal(4,2)) as testname') - assert cursor.description[0][0] == 'testname' + cursor.execute("select cast(12.65 as 
decimal(4,2)) as testname") + assert cursor.description[0][0] == "testname" assert cursor.description[0][1] == pytds.DECIMAL assert cursor.description[0][4] == 4 assert cursor.description[0][5] == 2 @@ -404,27 +480,37 @@ def test_description(cursor): def test_bug4(separate_db_connection): with separate_db_connection.cursor() as cursor: - cursor.execute(''' + cursor.execute( + """ set transaction isolation level read committed select 1 - ''') + """ + ) assert cursor.fetchall() == [(1,)] def test_row_strategies(separate_db_connection): - conn = pytds.connect(*settings.CONNECT_ARGS, **settings.CONNECT_KWARGS, row_strategy=pytds.dict_row_strategy) + conn = pytds.connect( + *settings.CONNECT_ARGS, + **settings.CONNECT_KWARGS, + row_strategy=pytds.dict_row_strategy, + ) with conn.cursor() as cur: - cur.execute('select 1 as f') - assert cur.fetchall() == [{'f': 1}] - conn = pytds.connect(*settings.CONNECT_ARGS, **settings.CONNECT_KWARGS, row_strategy=pytds.tuple_row_strategy) + cur.execute("select 1 as f") + assert cur.fetchall() == [{"f": 1}] + conn = pytds.connect( + *settings.CONNECT_ARGS, + **settings.CONNECT_KWARGS, + row_strategy=pytds.tuple_row_strategy, + ) with conn.cursor() as cur: - cur.execute('select 1 as f') + cur.execute("select 1 as f") assert cur.fetchall() == [(1,)] def test_fetchone(cursor): cur = cursor - cur.execute('select 10; select 12') + cur.execute("select 10; select 12") assert (10,) == cur.fetchone() assert cur.nextset() assert (12,) == cur.fetchone() @@ -433,7 +519,7 @@ def test_fetchone(cursor): def test_fetchall(cursor): cur = cursor - cur.execute('select 10; select 12') + cur.execute("select 10; select 12") assert [(10,)] == cur.fetchall() assert cur.nextset() assert [(12,)] == cur.fetchall() @@ -442,61 +528,67 @@ def test_fetchall(cursor): def test_cursor_closing(db_connection): with db_connection.cursor() as cur: - cur.execute('select 10; select 12') + cur.execute("select 10; select 12") cur.fetchone() with db_connection.cursor() as cur2: - cur2.execute('select 20') + cur2.execute("select 20") cur2.fetchone() def test_multi_packet(cursor): cur = cursor - param = 'x' * (cursor._connection._tds_socket.main_session._writer.bufsize * 3) - cur.execute('select %s', (param,)) - assert [(param, )] == cur.fetchall() + param = "x" * (cursor._connection._tds_socket.main_session._writer.bufsize * 3) + cur.execute("select %s", (param,)) + assert [(param,)] == cur.fetchall() def test_big_request(cursor): cur = cursor - param = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(5000)) - params = (10, datetime.datetime(2012, 11, 19, 1, 21, 37, 3000), param, 'test') - cur.execute('select %s, %s, %s, %s', params) + param = "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(5000) + ) + params = (10, datetime.datetime(2012, 11, 19, 1, 21, 37, 3000), param, "test") + cur.execute("select %s, %s, %s, %s", params) assert [params] == cur.fetchall() def test_row_count(cursor): cur = cursor - cur.execute(''' + cur.execute( + """ create table testtable_row_cnt (field int) - ''') - cur.execute('insert into testtable_row_cnt (field) values (1)') + """ + ) + cur.execute("insert into testtable_row_cnt (field) values (1)") assert cur.rowcount == 1 - cur.execute('insert into testtable_row_cnt (field) values (2)') + cur.execute("insert into testtable_row_cnt (field) values (2)") assert cur.rowcount == 1 - cur.execute('select * from testtable_row_cnt') + cur.execute("select * from testtable_row_cnt") cur.fetchall() assert cur.rowcount == 2 def 
test_no_rows(cursor):
     cur = cursor
-    cur.execute('''
+    cur.execute(
+        """
     create table testtable_no_rows (field int)
-    ''')
-    cur.execute('select * from testtable_no_rows')
+    """
+    )
+    cur.execute("select * from testtable_no_rows")
     assert [] == cur.fetchall()


 def test_fixed_size_data(cursor):
     cur = cursor
-    cur.execute('''
+    cur.execute(
+        """
     create table testtable_fixed_size_dt (chr char(5), nchr nchar(5), bfld binary(5))
     insert into testtable_fixed_size_dt values ('1', '2', cast('3' as binary(5)))
-    ''')
-    cur.execute('select * from testtable_fixed_size_dt')
-    assert cur.fetchall() == [('1    ', '2    ', b'3\x00\x00\x00\x00')]
-
-
+    """
+    )
+    cur.execute("select * from testtable_fixed_size_dt")
+    assert cur.fetchall() == [("1    ", "2    ", b"3\x00\x00\x00\x00")]


 def test_closing_cursor_in_context(db_connection):
@@ -514,22 +606,23 @@ def test_outparam_and_result_set(cursor):
     Test stored procedure which has output parameters and also a result set
     """
     cur = cursor
-    logger.info('creating stored procedure')
-    cur.execute('''
+    logger.info("creating stored procedure")
+    cur.execute(
+        """
     CREATE PROCEDURE P_OutParam_ResultSet(@A INT OUTPUT)
     AS
     BEGIN
         SET @A = 3;
         SELECT 4 AS C;
         SELECT 5 AS C;
     END;
-    '''
-    )
-    logger.info('executing stored procedure')
-    cur.callproc('P_OutParam_ResultSet', [pytds.output(value=1)])
+    """
+    )
+    logger.info("executing stored procedure")
+    cur.callproc("P_OutParam_ResultSet", [pytds.output(value=1)])
     assert [(4,)] == cur.fetchall()
     assert [3] == cur.get_proc_outputs()
-    logger.info('execurint query after stored procedure')
-    cur.execute('select 5')
+    logger.info("executing query after stored procedure")
+    cur.execute("select 5")
     assert [(5,)] == cur.fetchall()


@@ -538,7 +631,8 @@ def test_outparam_null_default(cursor):
         pytds.output(None, None)
     cur = cursor
-    cur.execute('''
+    cur.execute(
+        """
     create procedure outparam_null_testproc (@inparam int, @outint int = 8 output, @outstr varchar(max) = 'defstr' output)
     as
     begin
@@ -547,56 +641,92 @@
     set @outstr = isnull(@outstr, 'null') + cast(@inparam as varchar(max))
     set @inparam = 8
     end
-    ''')
-    values = cur.callproc('outparam_null_testproc', (1, pytds.output(value=4), pytds.output(value='str')))
-    assert [1, 5, 'str1'] == values
-    values = cur.callproc('outparam_null_testproc', (1, pytds.output(value=None, param_type='int'), pytds.output(value=None, param_type='varchar(max)')))
-    assert [1, -9, 'null1'] == values
-    values = cur.callproc('outparam_null_testproc', (1, pytds.output(value=pytds.default, param_type='int'), pytds.output(value=pytds.default, param_type='varchar(max)')))
-    assert [1, 9, 'defstr1'] == values
-    values = cur.callproc('outparam_null_testproc', (1, pytds.output(value=pytds.default, param_type='bit'), pytds.output(value=pytds.default, param_type='varchar(5)')))
-    assert [1, 1, 'defst'] == values
-    values = cur.callproc('outparam_null_testproc', (1, pytds.output(value=pytds.default, param_type=int), pytds.output(value=pytds.default, param_type=str)))
-    assert [1, 9, 'defstr1'] == values
+    """
+    )
+    values = cur.callproc(
+        "outparam_null_testproc", (1, pytds.output(value=4), pytds.output(value="str"))
+    )
+    assert [1, 5, "str1"] == values
+    values = cur.callproc(
+        "outparam_null_testproc",
+        (
+            1,
+            pytds.output(value=None, param_type="int"),
+            pytds.output(value=None, param_type="varchar(max)"),
+        ),
+    )
+    assert [1, -9, "null1"] == values
+    values = cur.callproc(
+        "outparam_null_testproc",
+        (
+            1,
+            pytds.output(value=pytds.default, param_type="int"),
pytds.output(value=pytds.default, param_type="varchar(max)"), + ), + ) + assert [1, 9, "defstr1"] == values + values = cur.callproc( + "outparam_null_testproc", + ( + 1, + pytds.output(value=pytds.default, param_type="bit"), + pytds.output(value=pytds.default, param_type="varchar(5)"), + ), + ) + assert [1, 1, "defst"] == values + values = cur.callproc( + "outparam_null_testproc", + ( + 1, + pytds.output(value=pytds.default, param_type=int), + pytds.output(value=pytds.default, param_type=str), + ), + ) + assert [1, 9, "defstr1"] == values def test_invalid_ntlm_creds(): if not LIVE_TEST: - pytest.skip('LIVE_TEST is not set') + pytest.skip("LIVE_TEST is not set") with pytest.raises(pytds.OperationalError): - pytds.connect(settings.HOST, auth=pytds.login.NtlmAuth(user_name='bad', password='bad')) + pytds.connect( + settings.HOST, auth=pytds.login.NtlmAuth(user_name="bad", password="bad") + ) def test_open_with_different_blocksize(): if not LIVE_TEST: - pytest.skip('LIVE_TEST is not set') + pytest.skip("LIVE_TEST is not set") kwargs = settings.CONNECT_KWARGS.copy() # test very small block size - kwargs['blocksize'] = 100 + kwargs["blocksize"] = 100 with pytds.connect(*settings.CONNECT_ARGS, **kwargs): pass # test very large block size - kwargs['blocksize'] = 1000000 + kwargs["blocksize"] = 1000000 with pytds.connect(*settings.CONNECT_ARGS, **kwargs): pass def test_nvarchar_multiple_rows(cursor): - cursor.execute(''' + cursor.execute( + """ set nocount on declare @tbl table (id int primary key, fld nvarchar(max)) insert into @tbl values(1, 'foo') insert into @tbl values(2, 'bar') select fld from @tbl order by id - ''' + """ ) - assert cursor.fetchall() == [('foo',), ('bar',)] + assert cursor.fetchall() == [("foo",), ("bar",)] def test_no_metadata_request(cursor): cursor._session.submit_rpc( rpc_name=pytds.tds_base.SP_PREPARE, - params=cursor._session._convert_params((pytds.output(param_type=int), '@p1 int', 'select @p1')), + params=cursor._session._convert_params( + (pytds.output(param_type=int), "@p1 int", "select @p1") + ), ) cursor._session.begin_response() cursor._session.process_rpc() @@ -617,7 +747,7 @@ def test_no_metadata_request(cursor): cursor._session.submit_rpc( rpc_name=pytds.tds_base.SP_EXECUTE, params=cursor._session._convert_params((handle, 2)), - flags=0x02 # no metadata + flags=0x02, # no metadata ) cursor._session.begin_response() cursor._session.process_rpc() @@ -629,10 +759,10 @@ def test_no_metadata_request(cursor): def test_with_sso(): if not LIVE_TEST: - pytest.skip('LIVE_TEST is not set') + pytest.skip("LIVE_TEST is not set") with pytds.connect(settings.HOST, use_sso=True) as conn: with conn.cursor() as cursor: - cursor.execute('select 1') + cursor.execute("select 1") cursor.fetchall() @@ -642,7 +772,7 @@ def test_param_as_column_backward_compat(cursor): New way to pass such parameters is to use Param object. """ param = Column(type=BitType(), value=True) - result = cursor.execute_scalar('select %s', [param]) + result = cursor.execute_scalar("select %s", [param]) assert result is True @@ -651,5 +781,5 @@ def test_param_with_spaces(cursor): For backward compatibility need to support passing parameters as Column objects New way to pass such parameters is to use Param object. 
""" - result = cursor.execute_scalar('select %(param name)s', {"param name": "abc"}) + result = cursor.execute_scalar("select %(param name)s", {"param name": "abc"}) assert result == "abc" diff --git a/tests/connection_closing_tests.py b/tests/connection_closing_tests.py index aa20cd1..10dee8c 100644 --- a/tests/connection_closing_tests.py +++ b/tests/connection_closing_tests.py @@ -15,7 +15,7 @@ def get_spid(conn): def kill(conn, spid): with conn.cursor() as cur: - cur.execute('kill {0}'.format(spid)) + cur.execute("kill {0}".format(spid)) def test_cursor_use_after_connection_closing(): @@ -37,7 +37,7 @@ def test_cursor_use_after_connection_closing(): def test_open_close(): for x in range(3): kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' + kwargs["database"] = "master" pytds.connect(**kwargs).close() @@ -46,13 +46,13 @@ def test_closing_after_closed_by_server(): You should be able to call close on connection closed by server """ kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = 'master' - kwargs['autocommit'] = True + kwargs["database"] = "master" + kwargs["autocommit"] = True with pytds.connect(**kwargs) as master_conn: - kwargs['autocommit'] = False + kwargs["autocommit"] = False with pytds.connect(**kwargs) as conn: with conn.cursor() as cur: - cur.execute('select 1') + cur.execute("select 1") conn.commit() kill(master_conn, get_spid(conn)) sleep(0.2) diff --git a/tests/connection_pool_tests.py b/tests/connection_pool_tests.py index 34aec09..83b7786 100644 --- a/tests/connection_pool_tests.py +++ b/tests/connection_pool_tests.py @@ -2,7 +2,7 @@ import settings import pytds -LIVE_TEST = getattr(settings, 'LIVE_TEST', True) +LIVE_TEST = getattr(settings, "LIVE_TEST", True) def test_broken_connection_in_pool(): @@ -23,7 +23,7 @@ def test_broken_connection_in_pool(): # kill this connection, need to use another connection to do that spid = sess.execute_scalar("select @@spid") with extra_conn.cursor() as cur: - cur.execute(f'kill {spid}') + cur.execute(f"kill {spid}") # create new connection, it should attempt to use connection from the pool # it should detect that connection is bad and create new one diff --git a/tests/dbapi20.py b/tests/dbapi20.py index cba21bf..ac0b923 100644 --- a/tests/dbapi20.py +++ b/tests/dbapi20.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -''' Python DB API 2.0 driver compliance unit test suite. +""" Python DB API 2.0 driver compliance unit test suite. This software is Public Domain and may be used without restrictions. @@ -9,11 +9,11 @@ this is turning out to be a thoroughly unwholesome unit test." -- Ian Bicking -''' +""" -__rcs_id__ = '$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $' -__version__ = '$Revision: 1.12 $'[11:-2] -__author__ = 'Stuart Bishop ' +__rcs_id__ = "$Id: dbapi20.py,v 1.11 2005/01/02 02:41:01 zenzen Exp $" +__version__ = "$Revision: 1.12 $"[11:-2] +__author__ = "Stuart Bishop " import unittest import time @@ -72,73 +72,74 @@ # - Fix bugs in test_setoutputsize_basic and test_setinputsizes # def str2bytes(sval): - if sys.version_info < (3,0) and isinstance(sval, str): + if sys.version_info < (3, 0) and isinstance(sval, str): sval = sval.decode("latin1") return sval.encode("latin1") + class DatabaseAPI20Test(unittest.TestCase): - ''' Test a database self.driver for DB API 2.0 compatibility. - This implementation tests Gadfly, but the TestCase - is structured so that other self.drivers can subclass this - test case to ensure compiliance with the DB-API. 
It is - expected that this TestCase may be expanded in the future - if ambiguities or edge conditions are discovered. + """Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. - The 'Optional Extensions' are not yet being tested. + The 'Optional Extensions' are not yet being tested. - self.drivers should subclass this test, overriding setUp, tearDown, - self.driver, connect_args and connect_kw_args. Class specification - should be as follows: + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: - import dbapi20 - class mytest(dbapi20.DatabaseAPI20Test): - [...] + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] - Don't 'import DatabaseAPI20Test from dbapi20', or you will - confuse the unit tester - just 'import dbapi20'. - ''' + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + """ # The self.driver module. This should be the module where the 'connect' # method is to be found driver = None - connect_args = () # List of arguments to pass to connect - connect_kw_args = {} # Keyword arguments for connect - table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = "dbapi20test_" # If you need to specify a prefix for tables + + ddl1 = "create table %sbooze (name varchar(20))" % table_prefix + ddl2 = "create table %sbarflys (name varchar(20))" % table_prefix + xddl1 = "drop table %sbooze" % table_prefix + xddl2 = "drop table %sbarflys" % table_prefix - ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix - ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix - xddl1 = 'drop table %sbooze' % table_prefix - xddl2 = 'drop table %sbarflys' % table_prefix + lowerfunc = "to_lower" # Name of stored procedure to convert string->lowercase - lowerfunc = 'to_lower' # Name of stored procedure to convert string->lowercase - # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. - def executeDDL1(self,cursor): + def executeDDL1(self, cursor): cursor.execute(self.ddl1) - def executeDDL2(self,cursor): + def executeDDL2(self, cursor): cursor.execute(self.ddl2) def setUp(self): - ''' self.drivers should override this method to perform required setup - if any is necessary, such as creating the database. - ''' + """self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + """ pass def tearDown(self): - ''' self.drivers should override this method to perform required cleanup - if any is necessary, such as deleting the test database. - The default drops the tables that may be created.
+ """ con = self._connect() try: cur = con.cursor() - for ddl in (self.xddl1,self.xddl2): - try: + for ddl in (self.xddl1, self.xddl2): + try: cur.execute(ddl) con.commit() - except self.driver.Error: + except self.driver.Error: # Assume table didn't exist. Other tests will check if # execute is busted. pass @@ -147,9 +148,7 @@ def tearDown(self): def _connect(self): try: - return self.driver.connect( - *self.connect_args,**self.connect_kw_args - ) + return self.driver.connect(*self.connect_args, **self.connect_kw_args) except AttributeError: self.fail("No connect method found in self.driver module") @@ -162,7 +161,7 @@ def test_apilevel(self): # Must exist apilevel = self.driver.apilevel # Must equal 2.0 - self.assertEqual(apilevel,'2.0') + self.assertEqual(apilevel, "2.0") except AttributeError: self.fail("Driver doesn't define apilevel") @@ -171,7 +170,7 @@ def test_threadsafety(self): # Must exist threadsafety = self.driver.threadsafety # Must be a valid value - self.assertTrue(threadsafety in (0,1,2,3)) + self.assertTrue(threadsafety in (0, 1, 2, 3)) except AttributeError: self.fail("Driver doesn't define threadsafety") @@ -180,43 +179,29 @@ def test_paramstyle(self): # Must exist paramstyle = self.driver.paramstyle # Must be a valid value - self.assertTrue(paramstyle in ( - 'qmark','numeric','named','format','pyformat' - )) + self.assertTrue( + paramstyle in ("qmark", "numeric", "named", "format", "pyformat") + ) except AttributeError: self.fail("Driver doesn't define paramstyle") def test_Exceptions(self): # Make sure required exceptions exist, and are in the # defined heirarchy. - if sys.version[0] == '3': #under Python 3 StardardError no longer exists - self.assertTrue(issubclass(self.driver.Warning,Exception)) - self.assertTrue(issubclass(self.driver.Error,Exception)) + if sys.version[0] == "3": # under Python 3 StardardError no longer exists + self.assertTrue(issubclass(self.driver.Warning, Exception)) + self.assertTrue(issubclass(self.driver.Error, Exception)) else: - self.assertTrue(issubclass(self.driver.Warning,StandardError)) - self.assertTrue(issubclass(self.driver.Error,StandardError)) + self.assertTrue(issubclass(self.driver.Warning, StandardError)) + self.assertTrue(issubclass(self.driver.Error, StandardError)) - self.assertTrue( - issubclass(self.driver.InterfaceError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.DatabaseError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.OperationalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.IntegrityError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.InternalError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.ProgrammingError,self.driver.Error) - ) - self.assertTrue( - issubclass(self.driver.NotSupportedError,self.driver.Error) - ) + self.assertTrue(issubclass(self.driver.InterfaceError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.DatabaseError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.OperationalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.IntegrityError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.InternalError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.ProgrammingError, self.driver.Error)) + self.assertTrue(issubclass(self.driver.NotSupportedError, self.driver.Error)) def test_ExceptionsAsConnectionAttributes(self): # OPTIONAL EXTENSION @@ -240,7 +225,6 @@ def test_ExceptionsAsConnectionAttributes(self): finally: con.close() 
- def test_commit(self): con = self._connect() try: @@ -254,14 +238,14 @@ def test_rollback(self): try: # If rollback is defined, it should either work or throw # the documented exception - if hasattr(con,'rollback'): + if hasattr(con, "rollback"): try: con.rollback() except self.driver.NotSupportedError: pass finally: con.close() - + def test_cursor(self): con = self._connect() try: @@ -277,14 +261,14 @@ def test_cursor_isolation(self): cur1 = con.cursor() cur2 = con.cursor() self.executeDDL1(cur1) - cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) + cur1.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) cur2.execute("select name from %sbooze" % self.table_prefix) booze = cur2.fetchall() - self.assertEqual(len(booze),1) - self.assertEqual(len(booze[0]),1) - self.assertEqual(booze[0][0],'Victoria Bitter') + self.assertEqual(len(booze), 1) + self.assertEqual(len(booze[0]), 1) + self.assertEqual(booze[0][0], "Victoria Bitter") finally: con.close() @@ -293,31 +277,41 @@ def test_description(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.description,None, - 'cursor.description should be none after executing a ' - 'statement that can return no rows (such as DDL)' - ) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(len(cur.description),1, - 'cursor.description describes too many columns' - ) - self.assertEqual(len(cur.description[0]),7, - 'cursor.description[x] tuples must have 7 elements' - ) - self.assertEqual(cur.description[0][0].lower(),'name', - 'cursor.description[x][0] must return column name' - ) - self.assertEqual(cur.description[0][1],self.driver.STRING, - 'cursor.description[x][1] must return column type. Got %r' - % cur.description[0][1] - ) + self.assertEqual( + cur.description, + None, + "cursor.description should be none after executing a " + "statement that can return no rows (such as DDL)", + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + len(cur.description), 1, "cursor.description describes too many columns" + ) + self.assertEqual( + len(cur.description[0]), + 7, + "cursor.description[x] tuples must have 7 elements", + ) + self.assertEqual( + cur.description[0][0].lower(), + "name", + "cursor.description[x][0] must return column name", + ) + self.assertEqual( + cur.description[0][1], + self.driver.STRING, + "cursor.description[x][1] must return column type. Got %r" + % cur.description[0][1], + ) # Make sure self.description gets reset self.executeDDL2(cur) - self.assertEqual(cur.description,None, - 'cursor.description not being set to None when executing ' - 'no-result statements (eg. DDL)' - ) + self.assertEqual( + cur.description, + None, + "cursor.description not being set to None when executing " + "no-result statements (eg. 
DDL)", + ) finally: con.close() @@ -326,48 +320,50 @@ def test_rowcount(self): try: cur = con.cursor() self.executeDDL1(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount should be -1 after executing no-result ' - 'statements' - ) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number or rows inserted, or ' - 'set to -1 after executing an insert statement' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount should be -1 after executing no-result " "statements", + ) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number or rows inserted, or " + "set to -1 after executing an insert statement", + ) cur.execute("select name from %sbooze" % self.table_prefix) - self.assertTrue(cur.rowcount in (-1,1), - 'cursor.rowcount should == number of rows returned, or ' - 'set to -1 after executing a select statement' - ) + self.assertTrue( + cur.rowcount in (-1, 1), + "cursor.rowcount should == number of rows returned, or " + "set to -1 after executing a select statement", + ) self.executeDDL2(cur) - self.assertEqual(cur.rowcount,-1, - 'cursor.rowcount not being reset to -1 after executing ' - 'no-result statements' - ) + self.assertEqual( + cur.rowcount, + -1, + "cursor.rowcount not being reset to -1 after executing " + "no-result statements", + ) finally: con.close() - lower_func = 'to_lower' + lower_func = "to_lower" + def test_callproc(self): con = self._connect() try: cur = con.cursor() self._callproc_setup(cur) - if self.lower_func and hasattr(cur,'callproc'): - r = cur.callproc(self.lower_func,('FOO',)) - self.assertEqual(len(r),1) - self.assertEqual(r[0],'FOO') + if self.lower_func and hasattr(cur, "callproc"): + r = cur.callproc(self.lower_func, ("FOO",)) + self.assertEqual(len(r), 1) + self.assertEqual(r[0], "FOO") r = cur.fetchall() - self.assertEqual(len(r),1,'callproc produced no result set') - self.assertEqual(len(r[0]),1, - 'callproc produced invalid result set' - ) - self.assertEqual(r[0][0],'foo', - 'callproc produced invalid results' - ) + self.assertEqual(len(r), 1, "callproc produced no result set") + self.assertEqual(len(r[0]), 1, "callproc produced invalid result set") + self.assertEqual(r[0][0], "foo", "callproc produced invalid results") finally: con.close() @@ -380,15 +376,16 @@ def test_close(self): # cursor.execute should raise an Error if called after connection # closed - self.assertRaises(self.driver.Error,self.executeDDL1,cur) + self.assertRaises(self.driver.Error, self.executeDDL1, cur) # connection.commit should raise an Error if called after connection' # closed.' 
- self.assertRaises(self.driver.Error,con.commit) + self.assertRaises(self.driver.Error, con.commit) # connection.close should raise an Error if called more than once -# # disabled, there is no such requirement in DBAPI PEP-0249 - #self.assertRaises(self.driver.Error,con.close) + + # # disabled, there is no such requirement in DBAPI PEP-0249 + # self.assertRaises(self.driver.Error,con.close) def test_execute(self): con = self._connect() @@ -398,105 +395,99 @@ def test_execute(self): finally: con.close() - def _paraminsert(self,cur): + def _paraminsert(self, cur): self.executeDDL1(cur) - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertTrue(cur.rowcount in (-1,1)) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertTrue(cur.rowcount in (-1, 1)) - if self.driver.paramstyle == 'qmark': + if self.driver.paramstyle == "qmark": cur.execute( - 'insert into %sbooze values (?)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "numeric": cur.execute( - 'insert into %sbooze values (:1)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "named": cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, - {'beer':"Cooper's"} - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, + {"beer": "Cooper's"}, + ) + elif self.driver.paramstyle == "format": cur.execute( - 'insert into %sbooze values (%%s)' % self.table_prefix, - ("Cooper's",) - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, ("Cooper's",) + ) + elif self.driver.paramstyle == "pyformat": cur.execute( - 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, - {'beer':"Cooper's"} - ) + "insert into %sbooze values (%%(beer)s)" % self.table_prefix, + {"beer": "Cooper's"}, + ) else: - self.fail('Invalid paramstyle') - self.assertTrue(cur.rowcount in (-1,1)) + self.fail("Invalid paramstyle") + self.assertTrue(cur.rowcount in (-1, 1)) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') - beers = [res[0][0],res[1][0]] + self.assertEqual(len(res), 2, "cursor.fetchall returned too few rows") + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Cooper's", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) - self.assertEqual(beers[1],"Victoria Bitter", - 'cursor.fetchall retrieved incorrect data, or data inserted ' - 'incorrectly' - ) + self.assertEqual( + beers[0], + "Cooper's", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) + self.assertEqual( + beers[1], + "Victoria Bitter", + "cursor.fetchall retrieved incorrect data, or data inserted " "incorrectly", + ) def test_executemany(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - largs = [ ("Cooper's",) , ("Boag's",) ] - margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] - if self.driver.paramstyle == 'qmark': + largs = [("Cooper's",), ("Boag's",)] + margs = [{"beer": "Cooper's"}, {"beer": "Boag's"}] + if 
self.driver.paramstyle == "qmark": cur.executemany( - 'insert into %sbooze values (?)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'numeric': + "insert into %sbooze values (?)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "numeric": cur.executemany( - 'insert into %sbooze values (:1)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'named': + "insert into %sbooze values (:1)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "named": cur.executemany( - 'insert into %sbooze values (:beer)' % self.table_prefix, - margs - ) - elif self.driver.paramstyle == 'format': + "insert into %sbooze values (:beer)" % self.table_prefix, margs + ) + elif self.driver.paramstyle == "format": cur.executemany( - 'insert into %sbooze values (%%s)' % self.table_prefix, - largs - ) - elif self.driver.paramstyle == 'pyformat': + "insert into %sbooze values (%%s)" % self.table_prefix, largs + ) + elif self.driver.paramstyle == "pyformat": cur.executemany( - 'insert into %sbooze values (%%(beer)s)' % ( - self.table_prefix - ), - margs - ) - else: - self.fail('Unknown paramstyle') - self.assertTrue(cur.rowcount in (-1,2), - 'insert using cursor.executemany set cursor.rowcount to ' - 'incorrect value %r' % cur.rowcount + "insert into %sbooze values (%%(beer)s)" % (self.table_prefix), + margs, ) - cur.execute('select name from %sbooze' % self.table_prefix) + else: + self.fail("Unknown paramstyle") + self.assertTrue( + cur.rowcount in (-1, 2), + "insert using cursor.executemany set cursor.rowcount to " + "incorrect value %r" % cur.rowcount, + ) + cur.execute("select name from %sbooze" % self.table_prefix) res = cur.fetchall() - self.assertEqual(len(res),2, - 'cursor.fetchall retrieved incorrect number of rows' - ) - beers = [res[0][0],res[1][0]] + self.assertEqual( + len(res), 2, "cursor.fetchall retrieved incorrect number of rows" + ) + beers = [res[0][0], res[1][0]] beers.sort() - self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') - self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + self.assertEqual(beers[0], "Boag's", "incorrect data retrieved") + self.assertEqual(beers[1], "Cooper's", "incorrect data retrieved") finally: con.close() @@ -507,59 +498,62 @@ def test_fetchone(self): # cursor.fetchone should raise an Error if called before # executing a select-type query - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows self.executeDDL1(cur) - self.assertRaises(self.driver.Error,cur.fetchone) + self.assertRaises(self.driver.Error, cur.fetchone) - cur.execute('select name from %sbooze' % self.table_prefix) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if a query retrieves ' - 'no rows' - ) - self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if a query retrieves " "no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) # cursor.fetchone should raise an Error if called after # executing a query that cannnot return rows - cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( - self.table_prefix - )) - self.assertRaises(self.driver.Error,cur.fetchone) + cur.execute( + "insert into %sbooze values ('Victoria Bitter')" % (self.table_prefix) + ) + self.assertRaises(self.driver.Error, cur.fetchone) - 
cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchone() - self.assertEqual(len(r),1, - 'cursor.fetchone should have retrieved a single row' - ) - self.assertEqual(r[0],'Victoria Bitter', - 'cursor.fetchone retrieved incorrect data' - ) - self.assertEqual(cur.fetchone(),None, - 'cursor.fetchone should return None if no more rows available' - ) - self.assertTrue(cur.rowcount in (-1,1)) + self.assertEqual( + len(r), 1, "cursor.fetchone should have retrieved a single row" + ) + self.assertEqual( + r[0], "Victoria Bitter", "cursor.fetchone retrieved incorrect data" + ) + self.assertEqual( + cur.fetchone(), + None, + "cursor.fetchone should return None if no more rows available", + ) + self.assertTrue(cur.rowcount in (-1, 1)) finally: con.close() samples = [ - 'Carlton Cold', - 'Carlton Draft', - 'Mountain Goat', - 'Redback', - 'Victoria Bitter', - 'XXXX' - ] + "Carlton Cold", + "Carlton Draft", + "Mountain Goat", + "Redback", + "Victoria Bitter", + "XXXX", + ] def _populate(self): - ''' Return a list of sql commands to setup the DB for the fetch - tests. - ''' + """Return a list of sql commands to setup the DB for the fetch + tests. + """ populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) - for s in self.samples - ] + "insert into %sbooze values ('%s')" % (self.table_prefix, s) + for s in self.samples + ] return populate def test_fetchmany(self): @@ -568,78 +562,88 @@ def test_fetchmany(self): cur = con.cursor() # cursor.fetchmany should raise an Error if called without - #issuing a query - self.assertRaises(self.driver.Error,cur.fetchmany,4) + # issuing a query + self.assertRaises(self.driver.Error, cur.fetchmany, 4) self.executeDDL1(cur) for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchmany() - self.assertEqual(len(r),1, - 'cursor.fetchmany retrieved incorrect number of rows, ' - 'default of arraysize is one.' 
- ) - cur.arraysize=10 - r = cur.fetchmany(3) # Should get 3 rows - self.assertEqual(len(r),3, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should get 2 more - self.assertEqual(len(r),2, - 'cursor.fetchmany retrieved incorrect number of rows' - ) - r = cur.fetchmany(4) # Should be an empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence after ' - 'results are exhausted' + self.assertEqual( + len(r), + 1, + "cursor.fetchmany retrieved incorrect number of rows, " + "default of arraysize is one.", + ) + cur.arraysize = 10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual( + len(r), 3, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual( + len(r), 2, "cursor.fetchmany retrieved incorrect number of rows" + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence after " + "results are exhausted", ) - self.assertTrue(cur.rowcount in (-1,6)) + self.assertTrue(cur.rowcount in (-1, 6)) # Same as above, using cursor.arraysize - cur.arraysize=4 - cur.execute('select name from %sbooze' % self.table_prefix) - r = cur.fetchmany() # Should get 4 rows - self.assertEqual(len(r),4, - 'cursor.arraysize not being honoured by fetchmany' - ) - r = cur.fetchmany() # Should get 2 more - self.assertEqual(len(r),2) - r = cur.fetchmany() # Should be an empty sequence - self.assertEqual(len(r),0) - self.assertTrue(cur.rowcount in (-1,6)) - - cur.arraysize=6 - cur.execute('select name from %sbooze' % self.table_prefix) - rows = cur.fetchmany() # Should get all rows - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows),6) - self.assertEqual(len(rows),6) + cur.arraysize = 4 + cur.execute("select name from %sbooze" % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual( + len(r), 4, "cursor.arraysize not being honoured by fetchmany" + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r), 2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r), 0) + self.assertTrue(cur.rowcount in (-1, 6)) + + cur.arraysize = 6 + cur.execute("select name from %sbooze" % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual(len(rows), 6) + self.assertEqual(len(rows), 6) rows = [r[0] for r in rows] rows.sort() - + # Make sure we get the right data back out - for i in range(0,6): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved by cursor.fetchmany' - ) - - rows = cur.fetchmany() # Should return an empty list - self.assertEqual(len(rows),0, - 'cursor.fetchmany should return an empty sequence if ' - 'called after the whole result set has been fetched' + for i in range(0, 6): + self.assertEqual( + rows[i], + self.samples[i], + "incorrect data retrieved by cursor.fetchmany", ) - self.assertTrue(cur.rowcount in (-1,6)) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual( + len(rows), + 0, + "cursor.fetchmany should return an empty sequence if " + "called after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, 6)) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) - r = cur.fetchmany() # Should get empty sequence - self.assertEqual(len(r),0, - 'cursor.fetchmany should return an empty sequence if ' - 'query retrieved no rows' - ) - 
self.assertTrue(cur.rowcount in (-1,0)) + cur.execute("select name from %sbarflys" % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual( + len(r), + 0, + "cursor.fetchmany should return an empty sequence if " + "query retrieved no rows", + ) + self.assertTrue(cur.rowcount in (-1, 0)) finally: con.close() @@ -659,40 +663,45 @@ def test_fetchall(self): # cursor.fetchall should raise an Error if called # after executing a a statement that cannot return rows - self.assertRaises(self.driver.Error,cur.fetchall) + self.assertRaises(self.driver.Error, cur.fetchall) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,len(self.samples))) - self.assertEqual(len(rows),len(self.samples), - 'cursor.fetchall did not retrieve all rows' - ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) + self.assertEqual( + len(rows), + len(self.samples), + "cursor.fetchall did not retrieve all rows", + ) rows = [r[0] for r in rows] rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'cursor.fetchall retrieved incorrect rows' + for i in range(0, len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "cursor.fetchall retrieved incorrect rows" ) rows = cur.fetchall() self.assertEqual( - len(rows),0, - 'cursor.fetchall should return an empty list if called ' - 'after the whole result set has been fetched' - ) - self.assertTrue(cur.rowcount in (-1,len(self.samples))) + len(rows), + 0, + "cursor.fetchall should return an empty list if called " + "after the whole result set has been fetched", + ) + self.assertTrue(cur.rowcount in (-1, len(self.samples))) self.executeDDL2(cur) - cur.execute('select name from %sbarflys' % self.table_prefix) + cur.execute("select name from %sbarflys" % self.table_prefix) rows = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,0)) - self.assertEqual(len(rows),0, - 'cursor.fetchall should return an empty list if ' - 'a select query returns no rows' - ) - + self.assertTrue(cur.rowcount in (-1, 0)) + self.assertEqual( + len(rows), + 0, + "cursor.fetchall should return an empty list if " + "a select query returns no rows", + ) + finally: con.close() - + def test_mixedfetch(self): con = self._connect() try: @@ -701,81 +710,81 @@ def test_mixedfetch(self): for sql in self._populate(): cur.execute(sql) - cur.execute('select name from %sbooze' % self.table_prefix) - rows1 = cur.fetchone() + cur.execute("select name from %sbooze" % self.table_prefix) + rows1 = cur.fetchone() rows23 = cur.fetchmany(2) - rows4 = cur.fetchone() + rows4 = cur.fetchone() rows56 = cur.fetchall() - self.assertTrue(cur.rowcount in (-1,6)) - self.assertEqual(len(rows23),2, - 'fetchmany returned incorrect number of rows' - ) - self.assertEqual(len(rows56),2, - 'fetchall returned incorrect number of rows' - ) + self.assertTrue(cur.rowcount in (-1, 6)) + self.assertEqual( + len(rows23), 2, "fetchmany returned incorrect number of rows" + ) + self.assertEqual( + len(rows56), 2, "fetchall returned incorrect number of rows" + ) rows = [rows1[0]] - rows.extend([rows23[0][0],rows23[1][0]]) + rows.extend([rows23[0][0], rows23[1][0]]) rows.append(rows4[0]) - rows.extend([rows56[0][0],rows56[1][0]]) + rows.extend([rows56[0][0], rows56[1][0]]) rows.sort() - for i in range(0,len(self.samples)): - self.assertEqual(rows[i],self.samples[i], - 'incorrect data retrieved or inserted' - ) + for i in range(0, 
len(self.samples)): + self.assertEqual( + rows[i], self.samples[i], "incorrect data retrieved or inserted" + ) finally: con.close() - def help_nextset_setUp(self,cur): - ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" - ''' - raise NotImplementedError('Helper not implemented') - #sql=""" + def help_nextset_setUp(self, cur): + """Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + """ + raise NotImplementedError("Helper not implemented") + # sql=""" # create procedure deleteme as # begin # select count(*) from booze # select name from booze # end - #""" - #cur.execute(sql) + # """ + # cur.execute(sql) - def help_nextset_tearDown(self,cur): - 'If cleaning up is needed after nextSetTest' - raise NotImplementedError('Helper not implemented') - #cur.execute("drop procedure deleteme") + def help_nextset_tearDown(self, cur): + "If cleaning up is needed after nextSetTest" + raise NotImplementedError("Helper not implemented") + # cur.execute("drop procedure deleteme") def test_nextset(self): con = self._connect() try: cur = con.cursor() - if not hasattr(cur,'nextset'): + if not hasattr(cur, "nextset"): return try: self.executeDDL1(cur) - sql=self._populate() + sql = self._populate() for sql in self._populate(): cur.execute(sql) self.help_nextset_setUp(cur) - cur.callproc('deleteme') - numberofrows=cur.fetchone() + cur.callproc("deleteme") + numberofrows = cur.fetchone() self.assertEqual(numberofrows[0], len(self.samples)) assert cur.nextset() - names=cur.fetchall() + names = cur.fetchall() assert len(names) == len(self.samples) - s=cur.nextset() - assert s == None,'No more return sets, should return None' + s = cur.nextset() + assert s == None, "No more return sets, should return None" finally: self.help_nextset_tearDown(cur) finally: con.close() - #def test_nextset(self): + # def test_nextset(self): # raise NotImplementedError('Drivers need to override this test') def test_arraysize(self): @@ -783,9 +792,9 @@ def test_arraysize(self): con = self._connect() try: cur = con.cursor() - self.assertTrue(hasattr(cur,'arraysize'), - 'cursor.arraysize must be defined' - ) + self.assertTrue( + hasattr(cur, "arraysize"), "cursor.arraysize must be defined" + ) finally: con.close() @@ -793,8 +802,8 @@ def test_setinputsizes(self): con = self._connect() try: cur = con.cursor() - cur.setinputsizes( (25,) ) - self._paraminsert(cur) # Make sure cursor still works + cur.setinputsizes((25,)) + self._paraminsert(cur) # Make sure cursor still works finally: con.close() @@ -804,75 +813,70 @@ def test_setoutputsize_basic(self): try: cur = con.cursor() cur.setoutputsize(1000) - cur.setoutputsize(2000,0) - self._paraminsert(cur) # Make sure the cursor still works + cur.setoutputsize(2000, 0) + self._paraminsert(cur) # Make sure the cursor still works finally: con.close() def test_setoutputsize(self): # Real test for setoutputsize is driver dependant - raise NotImplementedError('Driver needed to override this test') + raise NotImplementedError("Driver needed to override this test") def test_None(self): con = self._connect() try: cur = con.cursor() self.executeDDL1(cur) - cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) - cur.execute('select name from %sbooze' % self.table_prefix) + cur.execute("insert into %sbooze values (NULL)" % self.table_prefix) + cur.execute("select name from %sbooze" % self.table_prefix) r = cur.fetchall() - 
self.assertEqual(len(r),1) - self.assertEqual(len(r[0]),1) - self.assertEqual(r[0][0],None,'NULL value not returned as None') + self.assertEqual(len(r), 1) + self.assertEqual(len(r[0]), 1) + self.assertEqual(r[0][0], None, "NULL value not returned as None") finally: con.close() def test_Date(self): - d1 = self.driver.Date(2002,12,25) - d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + d1 = self.driver.Date(2002, 12, 25) + d2 = self.driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(d1),str(d2)) def test_Time(self): - t1 = self.driver.Time(13,45,30) - t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + t1 = self.driver.Time(13, 45, 30) + t2 = self.driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Timestamp(self): - t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t1 = self.driver.Timestamp(2002, 12, 25, 13, 45, 30) t2 = self.driver.TimestampFromTicks( - time.mktime((2002,12,25,13,45,30,0,0,0)) - ) + time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)) + ) # Can we assume this? API doesn't specify, but it seems implied # self.assertEqual(str(t1),str(t2)) def test_Binary(self): - b = self.driver.Binary(str2bytes('Something')) - b = self.driver.Binary(str2bytes('')) + b = self.driver.Binary(str2bytes("Something")) + b = self.driver.Binary(str2bytes("")) def test_STRING(self): - self.assertTrue(hasattr(self.driver,'STRING'), - 'module.STRING must be defined' - ) + self.assertTrue(hasattr(self.driver, "STRING"), "module.STRING must be defined") def test_BINARY(self): - self.assertTrue(hasattr(self.driver,'BINARY'), - 'module.BINARY must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "BINARY"), "module.BINARY must be defined." + ) def test_NUMBER(self): - self.assertTrue(hasattr(self.driver,'NUMBER'), - 'module.NUMBER must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "NUMBER"), "module.NUMBER must be defined." + ) def test_DATETIME(self): - self.assertTrue(hasattr(self.driver,'DATETIME'), - 'module.DATETIME must be defined.' - ) + self.assertTrue( + hasattr(self.driver, "DATETIME"), "module.DATETIME must be defined." + ) def test_ROWID(self): - self.assertTrue(hasattr(self.driver,'ROWID'), - 'module.ROWID must be defined.' 
- ) - + self.assertTrue(hasattr(self.driver, "ROWID"), "module.ROWID must be defined.") diff --git a/tests/fixtures.py b/tests/fixtures.py index 301ddb1..12bdb6c 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -11,16 +11,16 @@ logger = logging.getLogger(__name__) -LIVE_TEST = getattr(settings, 'LIVE_TEST', True) +LIVE_TEST = getattr(settings, "LIVE_TEST", True) pytds.tds.logging_enabled = True -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db_connection(sqlalchemy_engine): if not LIVE_TEST: - pytest.skip('LIVE_TEST is not set') + pytest.skip("LIVE_TEST is not set") kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = settings.DATABASE + kwargs["database"] = settings.DATABASE conn = pytds.connect(*settings.CONNECT_ARGS, **kwargs) utils.create_test_database(connection=conn) conn.commit() @@ -37,23 +37,25 @@ def cursor(db_connection): @pytest.fixture def separate_db_connection(): if not LIVE_TEST: - pytest.skip('LIVE_TEST is not set') + pytest.skip("LIVE_TEST is not set") kwargs = settings.CONNECT_KWARGS.copy() - kwargs['database'] = settings.DATABASE + kwargs["database"] = settings.DATABASE conn = pytds.connect(*settings.CONNECT_ARGS, **kwargs) yield conn conn.close() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def collation_set(db_connection): with db_connection.cursor() as cursor: - cursor.execute("SELECT Name, Description, COLLATIONPROPERTY(Name, 'LCID') FROM ::fn_helpcollations()") + cursor.execute( + "SELECT Name, Description, COLLATIONPROPERTY(Name, 'LCID') FROM ::fn_helpcollations()" + ) collations_list = cursor.fetchall() return set(coll_name for coll_name, _, _ in collations_list) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def sqlalchemy_engine() -> sqlalchemy.engine.Engine: host = settings.HOST hostname, _, instance = host.partition("\\") diff --git a/tests/settings.py b/tests/settings.py index d37a1c7..60bd826 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -7,7 +7,7 @@ connection_json_path = os.path.join(os.path.dirname(__file__), ".connection.json") if os.path.exists(connection_json_path): - conf = json.load(open(connection_json_path, 'rb')) + conf = json.load(open(connection_json_path, "rb")) default_host = conf["host"] default_database = conf["database"] default_user = conf["sqluser"] @@ -20,35 +20,35 @@ default_password = "sa" default_use_mars = True -LIVE_TEST = 'HOST' in os.environ or default_host +LIVE_TEST = "HOST" in os.environ or default_host if LIVE_TEST: - HOST = os.environ.get('HOST', default_host) - DATABASE = os.environ.get('DATABASE', default_database) - USER = os.environ.get('SQLUSER', default_user) - PASSWORD = os.environ.get('SQLPASSWORD', default_password) - USE_MARS = bool(os.environ.get('USE_MARS', default_use_mars)) - SKIP_SQL_AUTH = bool(os.environ.get('SKIP_SQL_AUTH')) + HOST = os.environ.get("HOST", default_host) + DATABASE = os.environ.get("DATABASE", default_database) + USER = os.environ.get("SQLUSER", default_user) + PASSWORD = os.environ.get("SQLPASSWORD", default_password) + USE_MARS = bool(os.environ.get("USE_MARS", default_use_mars)) + SKIP_SQL_AUTH = bool(os.environ.get("SKIP_SQL_AUTH")) import pytds CONNECT_KWARGS = { - 'server': HOST, - 'database': DATABASE, - 'user': USER, - 'password': PASSWORD, - 'use_mars': USE_MARS, - 'bytes_to_unicode': True, - 'pooling': True, - 'timeout': 30, + "server": HOST, + "database": DATABASE, + "user": USER, + "password": PASSWORD, + "use_mars": USE_MARS, + "bytes_to_unicode": True, + "pooling": True, + "timeout": 
30, } - if 'tds_version' in os.environ: - CONNECT_KWARGS['tds_version'] = getattr(pytds, os.environ['tds_version']) + if "tds_version" in os.environ: + CONNECT_KWARGS["tds_version"] = getattr(pytds, os.environ["tds_version"]) - if 'auth' in os.environ: + if "auth" in os.environ: import pytds.login - CONNECT_KWARGS['auth'] = getattr(pytds.login, os.environ['auth'])() + CONNECT_KWARGS["auth"] = getattr(pytds.login, os.environ["auth"])() - if 'bytes_to_unicode' in os.environ: - CONNECT_KWARGS['bytes_to_unicode'] = bool(os.environ.get('bytes_to_unicode')) + if "bytes_to_unicode" in os.environ: + CONNECT_KWARGS["bytes_to_unicode"] = bool(os.environ.get("bytes_to_unicode")) diff --git a/tests/simple_server.py b/tests/simple_server.py index 0ac8b1a..d99b026 100644 --- a/tests/simple_server.py +++ b/tests/simple_server.py @@ -10,9 +10,9 @@ import pytds.tds_writer import pytds.collate -_BYTE_STRUCT = struct.Struct('B') -_OFF_LEN_STRUCT = struct.Struct('>HH') -_PROD_VER_STRUCT = struct.Struct('>LH') +_BYTE_STRUCT = struct.Struct("B") +_OFF_LEN_STRUCT = struct.Struct(">HH") +_PROD_VER_STRUCT = struct.Struct(">LH") logger = logging.getLogger(__name__) @@ -29,15 +29,15 @@ def parse_prelogin(self, buf): while True: value = None if i >= size: - self.bad_stream('Invalid size of PRELOGIN structure') - type_id, = _BYTE_STRUCT.unpack_from(buf, i) + self.bad_stream("Invalid size of PRELOGIN structure") + (type_id,) = _BYTE_STRUCT.unpack_from(buf, i) if type_id == pytds.tds_base.PreLoginToken.TERMINATOR: break if i + 4 > size: - self.bad_stream('Invalid size of PRELOGIN structure') + self.bad_stream("Invalid size of PRELOGIN structure") off, l = _OFF_LEN_STRUCT.unpack_from(buf, i + 1) if off > size or off + l > size: - self.bad_stream('Invalid offset in PRELOGIN structure') + self.bad_stream("Invalid offset in PRELOGIN structure") if type_id == pytds.tds_base.PreLoginToken.VERSION: value = _PROD_VER_STRUCT.unpack_from(buf, off) elif type_id == pytds.tds_base.PreLoginToken.ENCRYPTION: @@ -45,7 +45,7 @@ def parse_prelogin(self, buf): elif type_id == pytds.tds_base.PreLoginToken.MARS: value = bool(_BYTE_STRUCT.unpack_from(buf, off)[0]) elif type_id == pytds.tds_base.PreLoginToken.INSTOPT: - value = buf[off:off+l].decode('ascii') + value = buf[off : off + l].decode("ascii") i += 5 result[type_id] = value return result @@ -65,9 +65,11 @@ def generate_prelogin(self, prelogin): elif type_id == pytds.tds_base.PreLoginToken.MARS: packed = [1 if value else 0] elif type_id == pytds.tds_base.PreLoginToken.INSTOPT: - packed = value.encode('ascii') + packed = value.encode("ascii") else: - raise Exception(f"not implemented prelogin option {type_id} in prelogin message generator") + raise Exception( + f"not implemented prelogin option {type_id} in prelogin message generator" + ) data_size = len(packed) @@ -83,7 +85,7 @@ def generate_prelogin(self, prelogin): return buf -class Sock(): +class Sock: # wraps request in class compatible with TdsSocket def __init__(self, req): self._req = req @@ -108,13 +110,16 @@ def handle(self): self._transport = Sock(self.request) r = pytds.tds_reader._TdsReader(tds_session=self, transport=self._transport) - w = pytds.tds_writer._TdsWriter(tds_session=self, bufsize=bufsize, transport=self._transport) + w = pytds.tds_writer._TdsWriter( + tds_session=self, bufsize=bufsize, transport=self._transport + ) resp_header = r.begin_response() buf = r.read_whole_packet() if resp_header.type != pytds.tds_base.PacketType.PRELOGIN: - msg = 'Invalid packet type: {0}, expected 
PRELOGIN({1})'.format(r.packet_type, - pytds.tds_base.PacketType.PRELOGIN) + msg = "Invalid packet type: {0}, expected PRELOGIN({1})".format( + r.packet_type, pytds.tds_base.PacketType.PRELOGIN + ) self.bad_stream(msg) prelogin = parser.parse_prelogin(buf) logger.info(f"received prelogin message from client {prelogin}") @@ -149,9 +154,11 @@ def handle(self): res_enc = pytds.PreLoginEnc.ENCRYPT_REQ # sending reply to client's prelogin packet - prelogin_resp = gen.generate_prelogin({ - pytds.tds_base.PreLoginToken.ENCRYPTION: res_enc, - }) + prelogin_resp = gen.generate_prelogin( + { + pytds.tds_base.PreLoginToken.ENCRYPTION: res_enc, + } + ) w.begin_packet(pytds.tds_base.PacketType.REPLY) w.write(prelogin_resp) w.flush() @@ -192,7 +199,9 @@ def handle(self): w.write(buf) w.flush() - wrapped_socket = pytds.tls.EncryptedSocket(transport=self.request, tls_conn=tlsconn) + wrapped_socket = pytds.tls.EncryptedSocket( + transport=self.request, tls_conn=tlsconn + ) r._transport = wrapped_socket w._transport = wrapped_socket @@ -200,7 +209,9 @@ def handle(self): r.begin_response() buf = r.read_whole_packet() except pytds.tds_base.ClosedConnectionError: - logger.info('client closed connection, probably did not like server certificate') + logger.info( + "client closed connection, probably did not like server certificate" + ) return logger.info(f"received login packet from client {buf}") @@ -209,7 +220,7 @@ def handle(self): r._transport = self._transport w._transport = self._transport - srv_name = 'Simple TDS Server' + srv_name = "Simple TDS Server" srv_ver = (1, 0, 0, 0) tds_version = self.server._tds_version @@ -233,7 +244,7 @@ def handle(self): w.put_byte(pytds.tds_base.TDS_DONE_TOKEN) w.put_usmallint(0) # status w.put_usmallint(0) # curcmd - w.put_uint8(0) # done row count + w.put_uint8(0) # done row count w.flush() @@ -244,7 +255,9 @@ def bad_stream(self, msg): class SimpleServer(socketserver.TCPServer): allow_reuse_address = True - def __init__(self, address, enc, cert=None, pkey=None, tds_version=pytds.tds_base.TDS74): + def __init__( + self, address, enc, cert=None, pkey=None, tds_version=pytds.tds_base.TDS74 + ): self._enc = enc super().__init__(address, RequestHandler) ctx = None @@ -265,11 +278,11 @@ def set_enc(self, enc): def run(address): - logger.info('Starting server...') + logger.info("Starting server...") with SimpleServer(address) as server: - logger.info('Press Ctrl+C to stop the server') + logger.info("Press Ctrl+C to stop the server") server.serve_forever() -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/tests/smp_test.py b/tests/smp_test.py index 2a2b55e..890d334 100644 --- a/tests/smp_test.py +++ b/tests/smp_test.py @@ -6,7 +6,7 @@ from utils import MockSock -smp_hdr = struct.Struct('= len(self._packets): - return b'' + return b"" if self._packet_pos >= len(self._packets[self._curr_packet]): self._curr_packet += 1 self._packet_pos = 0 if self._curr_packet >= len(self._packets): - return b'' - res = self._packets[self._curr_packet][self._packet_pos:self._packet_pos+size] + return b"" + res = self._packets[self._curr_packet][ + self._packet_pos : self._packet_pos + size + ] self._packet_pos += len(res) return res @@ -63,7 +133,7 @@ def recv_into(self, buffer, size=0): if size == 0: size = len(buffer) res = self.recv(size) - buffer[0:len(res)] = res + buffer[0 : len(res)] = res return len(res) def send(self, buf, flags=0): @@ -77,12 +147,13 @@ def setsockopt(self, *args): pass def close(self): - self._stream = b'' + self._stream = b"" class 
TestMessages(unittest.TestCase): def _make_login(self): from pytds.tds_base import TDS74 + login = _TdsLogin() login.blocksize = 4096 login.use_tz = None @@ -92,67 +163,75 @@ def _make_login(self): login.enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP login.use_mars = False login.option_flag2 = 0 - login.user_name = 'testname' - login.password = 'password' - login.app_name = 'appname' - login.server_name = 'servername' - login.library = 'library' - login.language = 'EN' - login.database = 'database' + login.user_name = "testname" + login.password = "password" + login.app_name = "appname" + login.server_name = "servername" + login.library = "library" + login.language = "EN" + login.database = "database" login.auth = None login.bulk_copy = False login.readonly = False login.client_lcid = 100 - login.attach_db_file = '' + login.attach_db_file = "" login.text_size = 0 - login.client_host_name = 'clienthost' + login.client_host_name = "clienthost" login.pid = 100 - login.change_password = '' + login.change_password = "" login.client_tz = tzoffset(5) - login.client_id = 0xabcd + login.client_id = 0xABCD login.bytes_to_unicode = True return login def test_login(self): - sock = _FakeSock([ - # prelogin response - b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', - # login resopnse - b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - # response to USE query - b'\x04\x01\x00#\x00Z\x01\x00\xe3\x0b\x00\x08\x08\x01\x00\x00\x00Z\x00\x00\x00\x00\xfd\x00\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00', - ]) + sock = _FakeSock( + [ + # prelogin response + b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', + # login resopnse + b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 
\x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + # response to USE query + b"\x04\x01\x00#\x00Z\x01\x00\xe3\x0b\x00\x08\x08\x01\x00\x00\x00Z\x00\x00\x00\x00\xfd\x00\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ] + ) _TdsSocket(sock=sock, login=self._make_login()).login() # test connection close on first message - sock = _FakeSock([ - b'\x04\x01\x00+\x00', - ]) + sock = _FakeSock( + [ + b"\x04\x01\x00+\x00", + ] + ) with self.assertRaises(pytds.Error): _TdsSocket(sock=sock, login=self._make_login()).login() # test connection close on second message - sock = _FakeSock([ - b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', - b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S", - ]) + sock = _FakeSock( + [ + b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', + b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S", + ] + ) with self.assertRaises(pytds.Error): _TdsSocket(sock=sock, login=self._make_login()).login() # test connection close on third message - #sock = _FakeSock([ + # sock = _FakeSock([ # b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', # b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", # b'\x04\x01\x00#\x00Z\x01\x00\xe3\x0b\x00\x08\x08\x01\x00\x00\x00Z\x00\x00\x00\x00\xfd\x00\x00\xfd\x00\x00', - #]) - #with self.assertRaises(pytds.Error): + # ]) + # 
with self.assertRaises(pytds.Error): # _TdsSocket().login(self._make_login(), sock, None) def test_prelogin_parsing(self): # test good packet - sock = _FakeSock([ - b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', - ]) + sock = _FakeSock( + [ + b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', + ] + ) # test repr on some objects login = _TdsLogin() login.enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP @@ -161,39 +240,47 @@ def test_prelogin_parsing(self): repr(tds) tds._main_session.process_prelogin(login) self.assertFalse(tds._mars_enabled) - self.assertTupleEqual(tds.server_library_version, (0xa001588, 0)) + self.assertTupleEqual(tds.server_library_version, (0xA001588, 0)) # test bad packet type - sock = _FakeSock([ - b'\x03\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', - ]) + sock = _FakeSock( + [ + b'\x03\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\xff\n\x00\x15\x88\x00\x00\x02\x00\x00', + ] + ) login = self._make_login() tds = _TdsSocket(sock=sock, login=login) with self.assertRaises(pytds.InterfaceError): tds._main_session.process_prelogin(login) # test bad offset 1 - sock = _FakeSock([ - b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\x00\n\x00\x15\x88\x00\x00\x02\x00\x00', - ]) + sock = _FakeSock( + [ + b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\x00\n\x00\x15\x88\x00\x00\x02\x00\x00', + ] + ) login = self._make_login() tds = _TdsSocket(sock=sock, login=login) with self.assertRaises(pytds.InterfaceError): tds._main_session.process_prelogin(login) # test bad offset 2 - sock = _FakeSock([ - b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00', - ]) + sock = _FakeSock( + [ + b'\x04\x01\x00+\x00\x00\x01\x00\x00\x00\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x01\x03\x00"\x00\x00\x04\x00"\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00', + ] + ) login = self._make_login() tds = _TdsSocket(sock=sock, login=login) with self.assertRaises(pytds.InterfaceError): tds._main_session.process_prelogin(login) # test bad size - with self.assertRaisesRegex(pytds.InterfaceError, 'Invalid size of PRELOGIN structure'): + with self.assertRaisesRegex( + pytds.InterfaceError, "Invalid size of PRELOGIN structure" + ): login = self._make_login() - tds._main_session.parse_prelogin(login=login, octets=b'\x01') + tds._main_session.parse_prelogin(login=login, octets=b"\x01") def make_tds(self): sock = _FakeSock([]) @@ -202,212 +289,251 @@ def make_tds(self): def test_prelogin_unexpected_encrypt_on(self): tds = self.make_tds() - with self.assertRaisesRegex(pytds.InterfaceError, 'Server returned unexpected ENCRYPT_ON value'): + with self.assertRaisesRegex( + pytds.InterfaceError, "Server returned unexpected ENCRYPT_ON value" + ): login = self._make_login() login.enc_flag = PreLoginEnc.ENCRYPT_ON - tds._main_session.parse_prelogin(login=login, octets=b'\x01\x00\x06\x00\x01\xff\x00') + tds._main_session.parse_prelogin( + login=login, octets=b"\x01\x00\x06\x00\x01\xff\x00" + ) def 
test_prelogin_unexpected_enc_flag(self): tds = self.make_tds() - with self.assertRaisesRegex(pytds.InterfaceError, 'Unexpected value of enc_flag returned by server: 5'): + with self.assertRaisesRegex( + pytds.InterfaceError, "Unexpected value of enc_flag returned by server: 5" + ): login = self._make_login() - tds._main_session.parse_prelogin(login=login, octets=b'\x01\x00\x06\x00\x01\xff\x05') + tds._main_session.parse_prelogin( + login=login, octets=b"\x01\x00\x06\x00\x01\xff\x05" + ) def test_prelogin_generation(self): - sock = _FakeSock('') + sock = _FakeSock("") login = _TdsLogin() - login.instance_name = 'MSSQLServer' + login.instance_name = "MSSQLServer" login.enc_flag = PreLoginEnc.ENCRYPT_NOT_SUP login.use_mars = False tds = _TdsSocket(sock=sock, login=login) tds._main_session.send_prelogin(login) - template = (b'\x12\x01\x00:\x00\x00\x00\x00\x00\x00' + - b'\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x0c\x03' + - b'\x00-\x00\x04\x04\x001\x00\x01\xff' + struct.pack('>l', pytds.intversion) + - b'\x00\x00\x02MSSQLServer\x00\x00\x00\x00\x00\x00') + template = ( + b"\x12\x01\x00:\x00\x00\x00\x00\x00\x00" + + b"\x1a\x00\x06\x01\x00 \x00\x01\x02\x00!\x00\x0c\x03" + + b"\x00-\x00\x04\x04\x001\x00\x01\xff" + + struct.pack(">l", pytds.intversion) + + b"\x00\x00\x02MSSQLServer\x00\x00\x00\x00\x00\x00" + ) self.assertEqual(sock._sent, template) - login.instance_name = 'x' * 65499 - sock._sent = b'' - with self.assertRaisesRegex(ValueError, 'Instance name is too long'): + login.instance_name = "x" * 65499 + sock._sent = b"" + with self.assertRaisesRegex(ValueError, "Instance name is too long"): tds._main_session.send_prelogin(login) - self.assertEqual(sock._sent, b'') + self.assertEqual(sock._sent, b"") - login.instance_name = u'тест' + login.instance_name = "тест" with self.assertRaises(UnicodeEncodeError): tds._main_session.send_prelogin(login) - self.assertEqual(sock._sent, b'') + self.assertEqual(sock._sent, b"") def test_login_parsing(self): - sock = _FakeSock([ - b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - ]) + sock = _FakeSock( + [ + b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 
\x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ] + ) tds = _TdsSocket(sock=sock, login=_TdsLogin()) tds._main_session.begin_response() tds._main_session.process_login_tokens() # test invalid tds version - sock = _FakeSock([ - b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01\x65\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - ]) + sock = _FakeSock( + [ + b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\x07\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01\x65\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ] + ) tds = _TdsSocket(sock=sock, login=_TdsLogin()) tds._main_session.begin_response() with 
self.assertRaises(pytds.InterfaceError): tds._main_session.process_login_tokens() # test for invalid env type - sock = _FakeSock([ - b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\xab\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", - ]) + sock = _FakeSock( + [ + b"\x04\x01\x01\xad\x00Z\x01\x00\xe3/\x00\x01\x10S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00\x06m\x00a\x00s\x00t\x00e\x00r\x00\xab~\x00E\x16\x00\x00\x02\x00/\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00d\x00a\x00t\x00a\x00b\x00a\x00s\x00e\x00 \x00c\x00o\x00n\x00t\x00e\x00x\x00t\x00 \x00t\x00o\x00 \x00'\x00S\x00u\x00b\x00m\x00i\x00s\x00s\x00i\x00o\x00n\x00P\x00o\x00r\x00t\x00a\x00l\x00'\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xe3\x08\x00\xab\x05\t\x04\x00\x01\x00\x00\xe3\x17\x00\x02\nu\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00\x00\xabn\x00G\x16\x00\x00\x01\x00'\x00C\x00h\x00a\x00n\x00g\x00e\x00d\x00 \x00l\x00a\x00n\x00g\x00u\x00a\x00g\x00e\x00 \x00s\x00e\x00t\x00t\x00i\x00n\x00g\x00 \x00t\x00o\x00 \x00u\x00s\x00_\x00e\x00n\x00g\x00l\x00i\x00s\x00h\x00.\x00\tM\x00S\x00S\x00Q\x00L\x00H\x00V\x003\x000\x00\x00\x01\x00\x00\x00\xad6\x00\x01s\x0b\x00\x03\x16M\x00i\x00c\x00r\x00o\x00s\x00o\x00f\x00t\x00 \x00S\x00Q\x00L\x00 \x00S\x00e\x00r\x00v\x00e\x00r\x00\x00\x00\x00\x00\n\x00\x15\x88\xe3\x13\x00\x04\x044\x000\x009\x006\x00\x044\x000\x009\x006\x00\xfd\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ] + ) tds = _TdsSocket(sock=sock, login=_TdsLogin()) tds._main_session.begin_response() tds._main_session.process_login_tokens() def test_login_generation(self): - sock = _FakeSock(b'') + sock = _FakeSock(b"") login = _TdsLogin() login.option_flag2 = 0 - login.user_name = 'test' - login.password = 'testpwd' - login.app_name = 'appname' - login.server_name = 'servername' - login.library = 'library' - login.language = 'en' - login.database = 'database' + login.user_name = "test" + login.password = "testpwd" + login.app_name = "appname" + login.server_name = "servername" + login.library = "library" + login.language = "en" + login.database = "database" login.auth = None login.tds_version = TDS73 login.bulk_copy = True login.client_lcid = 0x204 - login.attach_db_file = 'filepath' + login.attach_db_file = "filepath" login.readonly = False - login.client_host_name = 'subdev1' + login.client_host_name = "subdev1" login.pid = 100 - login.change_password = '' + login.change_password = "" login.client_tz = 
tzoffset(-4 * 60) - login.client_id = 0x1234567890ab + login.client_id = 0x1234567890AB tds = _TdsSocket(sock=sock, login=login) tds._main_session.tds7_send_login(login) self.assertEqual( sock._sent, - b'\x10\x01\x00\xde\x00\x00\x00\x00' + # header - b'\xc6\x00\x00\x00' + # size - b'\x03\x00\ns' + # tds version - b'\x00\x10\x00\x00' + # buf size - struct.pack('= len(self._packets): - return b'' + return b"" if self._packet_pos >= len(self._packets[self._curr_packet]): self._curr_packet += 1 self._packet_pos = 0 if self._curr_packet >= len(self._packets): - return b'' - res = self._packets[self._curr_packet][self._packet_pos:self._packet_pos+size] + return b"" + res = self._packets[self._curr_packet][ + self._packet_pos : self._packet_pos + size + ] self._packet_pos += len(res) return res def recv_into(self, buffer, size=0): if not self.is_open(): - raise Exception('Connection closed') + raise Exception("Connection closed") if size == 0: size = len(buffer) res = self.recv(size) - buffer[0:len(res)] = res + buffer[0 : len(res)] = res return len(res) def send(self, buf, flags=0): if not self.is_open(): - raise Exception('Connection closed') + raise Exception("Connection closed") self._out_packets.append(buf) return len(buf) def sendall(self, buf, flags=0): if not self.is_open(): - raise Exception('Connection closed') + raise Exception("Connection closed") self._out_packets.append(buf) def setsockopt(self, *args): @@ -78,7 +81,7 @@ def consume_output(self): """ res = self._out_packets self._out_packets = [] - return b''.join(res) + return b"".join(res) def set_input(self, packets): """ @@ -103,25 +106,35 @@ def does_schema_exist(cursor: pytds.Cursor, name: str, database: str) -> bool: f""" select count(*) from {database}.information_schema.schemata where schema_name = cast(%s as nvarchar(max)) - """, (name,)) + """, + (name,), + ) return val > 0 -def does_stored_proc_exist(cursor: pytds.Cursor, name: str, database: str, schema: str = "dbo") -> bool: +def does_stored_proc_exist( + cursor: pytds.Cursor, name: str, database: str, schema: str = "dbo" +) -> bool: val = cursor.execute_scalar( f""" select count(*) from {database}.information_schema.routines where routine_schema = cast(%s as nvarchar(max)) and routine_name = cast(%s as nvarchar(max)) - """, (schema, name)) + """, + (schema, name), + ) return val > 0 -def does_table_exist(cursor: pytds.Cursor, name: str, database: str, schema: str = "dbo") -> bool: +def does_table_exist( + cursor: pytds.Cursor, name: str, database: str, schema: str = "dbo" +) -> bool: val = cursor.execute_scalar( f""" select count(*) from {database}.information_schema.tables where table_schema = cast(%s as nvarchar(max)) and table_name = cast(%s as nvarchar(max)) - """, (schema, name)) + """, + (schema, name), + ) return val > 0 @@ -133,14 +146,26 @@ def does_user_defined_type_exist(cursor: pytds.Cursor, name: str) -> bool: def create_test_database(connection: pytds.Connection): with connection.cursor() as cur: if not does_database_exist(cursor=cur, name=settings.DATABASE): - cur.execute(f'create database [{settings.DATABASE}]') + cur.execute(f"create database [{settings.DATABASE}]") cur.execute(f"use [{settings.DATABASE}]") - if not does_schema_exist(cursor=cur, name="myschema", database=settings.DATABASE): - cur.execute('create schema myschema') - if not does_table_exist(cursor=cur, name="bulk_insert_table", schema="myschema", database=settings.DATABASE): - cur.execute('create table myschema.bulk_insert_table(num int, data varchar(100))') - if not 
does_stored_proc_exist(cursor=cur, name="testproc", database=settings.DATABASE): - cur.execute(''' + if not does_schema_exist( + cursor=cur, name="myschema", database=settings.DATABASE + ): + cur.execute("create schema myschema") + if not does_table_exist( + cursor=cur, + name="bulk_insert_table", + schema="myschema", + database=settings.DATABASE, + ): + cur.execute( + "create table myschema.bulk_insert_table(num int, data varchar(100))" + ) + if not does_stored_proc_exist( + cursor=cur, name="testproc", database=settings.DATABASE + ): + cur.execute( + """ create procedure testproc (@param int, @add int = 2, @outparam int output) as begin @@ -149,19 +174,26 @@ def create_test_database(connection: pytds.Connection): set @outparam = @param + @add return @outparam end - ''') + """ + ) # Stored procedure which does not have RETURN statement - if not does_stored_proc_exist(cursor=cur, name="test_proc_no_return", database=settings.DATABASE): - cur.execute(''' + if not does_stored_proc_exist( + cursor=cur, name="test_proc_no_return", database=settings.DATABASE + ): + cur.execute( + """ create procedure test_proc_no_return(@param int) as begin select @param end - ''') + """ + ) if not does_user_defined_type_exist(cursor=cur, name="dbo.CategoryTableType"): - cur.execute('CREATE TYPE dbo.CategoryTableType AS TABLE ( CategoryID int, CategoryName nvarchar(50) )') + cur.execute( + "CREATE TYPE dbo.CategoryTableType AS TABLE ( CategoryID int, CategoryName nvarchar(50) )" + ) def tran_count(cursor: pytds.Cursor) -> int: - return cursor.execute_scalar('select @@trancount') + return cursor.execute_scalar("select @@trancount")
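A short aside on the tests/utils.py helpers in the hunks above: each existence check funnels through pytds's `Cursor.execute_scalar` (which returns the first column of the first row), so `create_test_database` can probe with `count(*)` before every `create` and stay idempotent when the fixture database already exists. A minimal driver sketch under stated assumptions: `settings.HOST`, `settings.USER` and `settings.PASSWORD` are hypothetical field names here, mirroring the real `settings.DATABASE` the helpers already use, and the import path assumes the helpers live importable as `utils`.

    import pytds
    import settings  # test-suite settings module; HOST/USER/PASSWORD are assumed names
    from utils import create_test_database, does_table_exist, tran_count  # helpers above

    # Connect without selecting the test database, since create_test_database
    # may need to issue "create database" first; autocommit keeps that DDL out
    # of an implicit transaction.
    conn = pytds.connect(
        dsn=settings.HOST,          # assumed setting
        user=settings.USER,         # assumed setting
        password=settings.PASSWORD, # assumed setting
        autocommit=True,
    )
    try:
        create_test_database(conn)
        with conn.cursor() as cur:
            assert does_table_exist(
                cursor=cur,
                name="bulk_insert_table",
                schema="myschema",
                database=settings.DATABASE,
            )
            assert tran_count(cur) == 0  # autocommit => no transaction left open
    finally:
        conn.close()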
diff --git a/tests/utils_35.py b/tests/utils_35.py index bce1696..c02fb0d 100644 --- a/tests/utils_35.py +++ b/tests/utils_35.py @@ -14,18 +14,22 @@ class TestCA: def __init__(self): self._key_cache = {} backend = cryptography.hazmat.backends.default_backend() - self._test_cache_dir = os.path.join(os.path.dirname(__file__), '..', '.test-cache') + self._test_cache_dir = os.path.join( + os.path.dirname(__file__), "..", ".test-cache" + ) os.makedirs(self._test_cache_dir, exist_ok=True) - root_cert_path = self.cert_path('root') - self._root_key = self.key('root') + root_cert_path = self.cert_path("root") + self._root_key = self.key("root") self._root_ca = generate_root_certificate(self._root_key) - pathlib.Path(root_cert_path).write_bytes(self._root_ca.public_bytes(serialization.Encoding.PEM)) + pathlib.Path(root_cert_path).write_bytes( + self._root_ca.public_bytes(serialization.Encoding.PEM) + ) def key_path(self, name): - return os.path.join(self._test_cache_dir, name + 'key.pem') + return os.path.join(self._test_cache_dir, name + "key.pem") def cert_path(self, name): - return os.path.join(self._test_cache_dir, name + 'cert.pem') + return os.path.join(self._test_cache_dir, name + "cert.pem") def key(self, name) -> rsa.RSAPrivateKey: if name not in self._key_cache: @@ -33,7 +37,9 @@ def key(self, name) -> rsa.RSAPrivateKey: key_path = self.key_path(name) if os.path.exists(key_path): bin = pathlib.Path(key_path).read_bytes() - key = serialization.load_pem_private_key(bin, password=None, backend=backend) + key = serialization.load_pem_private_key( + bin, password=None, backend=backend + ) else: key = generate_rsa_key() bin = key.private_bytes( @@ -47,44 +53,48 @@ def key(self, name) -> rsa.RSAPrivateKey: def sign(self, name: str, cb: x509.CertificateBuilder) -> x509.Certificate: backend = cryptography.hazmat.backends.default_backend() - cert = cb.issuer_name(self._root_ca.subject) \ - .sign(private_key=self._root_key, algorithm=hashes.SHA256(), backend=backend) + cert = cb.issuer_name(self._root_ca.subject).sign( + private_key=self._root_key, algorithm=hashes.SHA256(), backend=backend + ) cert_path = self.cert_path(name) - pathlib.Path(cert_path).write_bytes(cert.public_bytes(serialization.Encoding.PEM)) + pathlib.Path(cert_path).write_bytes( + cert.public_bytes(serialization.Encoding.PEM) + ) return cert def generate_rsa_key() -> rsa.RSAPrivateKeyWithSerialization: backend = cryptography.hazmat.backends.default_backend() return rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=backend) + public_exponent=65537, key_size=2048, backend=backend + ) def generate_root_certificate(private_key: rsa.RSAPrivateKey) -> x509.Certificate: backend = cryptography.hazmat.backends.default_backend() - subject = x509.Name( - [x509.NameAttribute( - x509.oid.NameOID.COMMON_NAME, 'root' - )] - ) + subject = x509.Name([x509.NameAttribute(x509.oid.NameOID.COMMON_NAME, "root")]) builder = x509.CertificateBuilder() - return builder.subject_name(subject).issuer_name(subject) \ - .not_valid_before(datetime.datetime.utcnow()) \ - .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=1)) \ - .serial_number(1) \ - .public_key(private_key.public_key()) \ - .add_extension(x509.BasicConstraints(ca=True, path_length=1), critical=True) \ - .add_extension(x509.KeyUsage(digital_signature=False, - content_commitment=False, - key_encipherment=False, - data_encipherment=False, - key_agreement=False, - key_cert_sign=True, - crl_sign=True, - encipher_only=False, - decipher_only=False, - ), critical=True) \ + return ( + builder.subject_name(subject) + .issuer_name(subject) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=1)) + .serial_number(1) + .public_key(private_key.public_key()) + .add_extension(x509.BasicConstraints(ca=True, path_length=1), critical=True) + .add_extension( + x509.KeyUsage( + digital_signature=False, + content_commitment=False, + key_encipherment=False, + data_encipherment=False, + key_agreement=False, + key_cert_sign=True, + crl_sign=True, + encipher_only=False, + decipher_only=False, + ), + critical=True, + ) .sign(private_key=private_key, algorithm=hashes.SHA256(), backend=backend) - + ) diff --git a/version.py b/version.py index c84b9d3..ac8a07b 100644 --- a/version.py +++ b/version.py @@ -31,18 +31,17 @@ # # include RELEASE-VERSION -__all__ = ("get_git_version") +__all__ = ("get_git_version",) from subprocess import Popen, PIPE def call_git_describe(abbrev=4): try: - p = Popen(['git', 'describe', '--abbrev=%d' % abbrev], - stdout=PIPE, stderr=PIPE) + p = Popen(["git", "describe", "--abbrev=%d" % abbrev], stdout=PIPE, stderr=PIPE) p.stderr.close() line = p.stdout.readlines()[0] - return line.strip().decode('utf8') + return line.strip().decode("utf8") except: return None @@ -54,7 +53,7 @@ def read_release_version(): try: version = f.readlines()[0] - return version.strip().decode('utf8') + return version.strip().decode("utf8") finally: f.close() @@ -87,7 +86,7 @@ def get_git_version(abbrev=4): # If we still don't have anything, that's an error. if version is None: - return 'unknown' + return "unknown" # If the current version is different from what's in the # RELEASE-VERSION file, update the file to be current.
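One behavioral detail in the version.py hunk above, beyond the quoting changes: `__all__` is consumed by `from version import *` as a sequence of names, and parentheses alone do not create a tuple, so the single-entry form needs a trailing comma — hence `("get_git_version",)` rather than a bare string. A minimal illustration in plain Python, separate from the patch:

    single = ("get_git_version")   # parentheses only group; this is still a str
    proper = ("get_git_version",)  # the trailing comma is what makes the 1-tuple
    assert isinstance(single, str)
    assert isinstance(proper, tuple)
    # Bound to the bare string, __all__ would make a star-import iterate the
    # name character by character ("g", "e", "t", ...) and fail; the 1-tuple
    # exports exactly get_git_version.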