From 34a292d9d81edace6dfe658f00f57ccbc65206c6 Mon Sep 17 00:00:00 2001 From: Antti Kaihola Date: Sun, 12 Sep 2021 11:49:08 +0300 Subject: [PATCH 01/14] Add merge sort tool --- pgtricks/mergesort.py | 60 +++++++++++++++++++++++ pgtricks/tests/test_mergesort.py | 83 ++++++++++++++++++++++++++++++++ pyproject.toml | 2 + 3 files changed, 145 insertions(+) create mode 100644 pgtricks/mergesort.py create mode 100644 pgtricks/tests/test_mergesort.py diff --git a/pgtricks/mergesort.py b/pgtricks/mergesort.py new file mode 100644 index 0000000..839ef9f --- /dev/null +++ b/pgtricks/mergesort.py @@ -0,0 +1,60 @@ +"""Merge sort implementation to handle large files by sorting them in partitions.""" + +from __future__ import annotations + +import sys +from heapq import merge +from tempfile import TemporaryFile +from typing import IO, Iterable, Iterator, cast + + +class MergeSort(Iterable[str]): + """Merge sort implementation to handle large files by sorting them in partitions.""" + + def __init__(self, directory: str = ".", max_memory: int = 190) -> None: + """Initialize the merge sort object.""" + self._directory = directory + self._max_memory = max_memory + self._partitions: list[IO[str]] = [] + self._iterating: Iterable[str] | None = None + self._buffer: list[str] = [] + self._memory_counter = 0 + self._flush() + + def append(self, line: str) -> None: + """Append a line to the set of lines to be sorted.""" + if self._iterating: + message = "Can't append lines after starting to sort" + raise ValueError(message) + self._memory_counter -= sys.getsizeof(self._buffer) + self._buffer.append(line) + self._memory_counter += sys.getsizeof(self._buffer) + self._memory_counter += sys.getsizeof(line) + if self._memory_counter >= self._max_memory: + self._flush() + + def _flush(self) -> None: + if self._buffer: + self._partitions.append(TemporaryFile(mode="w+", dir=self._directory)) + self._partitions[-1].writelines(sorted(self._buffer)) + self._buffer = [] + self._memory_counter = sys.getsizeof(self._buffer) + + def __next__(self) -> str: + """Return the next line in the sorted list of lines.""" + if not self._iterating: + if self._partitions: + # At least one partition has already been flushed to disk. + # Iterate the merge sort for all partitions. + self._flush() + for partition in self._partitions: + partition.seek(0) + self._iterating = merge(*self._partitions) + else: + # All lines fit in memory. Iterate the list of lines directly. 
+ self._iterating = iter(sorted(self._buffer)) + return next(cast(Iterator[str], self._iterating)) + + def __iter__(self) -> Iterator[str]: + """Return the iterator object for the sorted list of lines.""" + return self diff --git a/pgtricks/tests/test_mergesort.py b/pgtricks/tests/test_mergesort.py new file mode 100644 index 0000000..23e5e88 --- /dev/null +++ b/pgtricks/tests/test_mergesort.py @@ -0,0 +1,83 @@ +"""Tests for the `pgtricks.mergesort` module.""" + +from types import GeneratorType +from typing import Iterable, cast + +import pytest + +from pgtricks.mergesort import MergeSort + + +def test_mergesort_append(tmpdir): + """Test appending lines to the merge sort object.""" + m = MergeSort(directory=tmpdir, max_memory=190) + m.append("1\n") + assert m._buffer == ["1\n"] + m.append("2\n") + assert m._buffer == [] + m.append("3\n") + assert m._buffer == ["3\n"] + assert len(m._partitions) == 1 + assert m._partitions[0].tell() == len("1\n2\n") + m._partitions[0].seek(0) + assert m._partitions[0].read() == "1\n2\n" + + +def test_mergesort_flush(tmpdir): + """Test flushing the buffer to disk.""" + m = MergeSort(directory=tmpdir, max_memory=190) + for value in [1, 2, 3]: + m.append(f"{value}\n") + m._flush() + assert len(m._partitions) == 2 + assert m._partitions[0].tell() == len("1\n2\n") + m._partitions[0].seek(0) + assert m._partitions[0].read() == "1\n2\n" + assert m._partitions[1].tell() == len("3\n") + m._partitions[1].seek(0) + assert m._partitions[1].read() == "3\n" + + +def test_mergesort_iterate_disk(tmpdir): + """Test iterating over the sorted lines on disk.""" + m = MergeSort(directory=tmpdir, max_memory=190) + for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: + m.append(f"{value}\n") + assert next(m) == "1\n" + assert isinstance(m._iterating, GeneratorType) + assert next(m) == "1\n" + assert next(m) == "2\n" + assert next(m) == "3\n" + assert next(m) == "3\n" + assert next(m) == "4\n" + assert next(m) == "4\n" + assert next(m) == "5\n" + assert next(m) == "5\n" + assert next(m) == "6\n" + assert next(m) == "8\n" + assert next(m) == "9\n" + with pytest.raises(StopIteration): + next(m) + + +def test_mergesort_iterate_memory(tmpdir): + """Test iterating over the sorted lines when all lines fit in memory.""" + m = MergeSort(directory=tmpdir, max_memory=1000000) + for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: + m.append(f"{value}\n") + assert next(m) == "1\n" + assert not isinstance(m._iterating, GeneratorType) + assert iter(cast(Iterable[str], m._iterating)) is m._iterating + assert next(m) == "1\n" + assert next(m) == "2\n" + assert next(m) == "3\n" + assert next(m) == "3\n" + assert next(m) == "4\n" + assert next(m) == "4\n" + assert next(m) == "5\n" + assert next(m) == "5\n" + assert next(m) == "6\n" + assert next(m) == "8\n" + assert next(m) == "9\n" + with pytest.raises(StopIteration): + next(m) diff --git a/pyproject.toml b/pyproject.toml index 086fffd..6f3739b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,9 @@ ignore = [ "ANN201", # Missing return type annotation for public function #"ANN204", # Missing return type annotation for special method `__init__` #"C408", # Unnecessary `dict` call (rewrite as a literal) + "PLR2004", # Magic value used in comparison "S101", # Use of `assert` detected + "SLF001", # Private member accessed ] [tool.ruff.lint.isort] From 128ecec90a479a188f5b74117f9c16360eae4202 Mon Sep 17 00:00:00 2001 From: Antti Kaihola Date: Sun, 12 Sep 2021 11:49:08 +0300 Subject: [PATCH 02/14] Allow custom key for merge sort --- mypy.ini | 3 
+++ pgtricks/mergesort.py | 14 ++++++++++---- pgtricks/tests/test_mergesort.py | 9 +++++++++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/mypy.ini b/mypy.ini index ed30dab..bbb2f00 100644 --- a/mypy.ini +++ b/mypy.ini @@ -32,6 +32,9 @@ strict_equality = True disallow_any_decorated = False disallow_untyped_defs = False +[mypy-pgtricks.mergesort] +disallow_any_explicit = False + [mypy-pytest.*] ignore_missing_imports = True diff --git a/pgtricks/mergesort.py b/pgtricks/mergesort.py index 839ef9f..9280fd2 100644 --- a/pgtricks/mergesort.py +++ b/pgtricks/mergesort.py @@ -5,14 +5,20 @@ import sys from heapq import merge from tempfile import TemporaryFile -from typing import IO, Iterable, Iterator, cast +from typing import IO, Any, Callable, Iterable, Iterator, cast class MergeSort(Iterable[str]): """Merge sort implementation to handle large files by sorting them in partitions.""" - def __init__(self, directory: str = ".", max_memory: int = 190) -> None: + def __init__( + self, + key: Callable[[str], Any] = str, + directory: str = ".", + max_memory: int = 190, + ) -> None: """Initialize the merge sort object.""" + self._key = key self._directory = directory self._max_memory = max_memory self._partitions: list[IO[str]] = [] @@ -36,7 +42,7 @@ def append(self, line: str) -> None: def _flush(self) -> None: if self._buffer: self._partitions.append(TemporaryFile(mode="w+", dir=self._directory)) - self._partitions[-1].writelines(sorted(self._buffer)) + self._partitions[-1].writelines(sorted(self._buffer, key=self._key)) self._buffer = [] self._memory_counter = sys.getsizeof(self._buffer) @@ -49,7 +55,7 @@ def __next__(self) -> str: self._flush() for partition in self._partitions: partition.seek(0) - self._iterating = merge(*self._partitions) + self._iterating = merge(*self._partitions, key=self._key) else: # All lines fit in memory. Iterate the list of lines directly. self._iterating = iter(sorted(self._buffer)) diff --git a/pgtricks/tests/test_mergesort.py b/pgtricks/tests/test_mergesort.py index 23e5e88..92972de 100644 --- a/pgtricks/tests/test_mergesort.py +++ b/pgtricks/tests/test_mergesort.py @@ -81,3 +81,12 @@ def test_mergesort_iterate_memory(tmpdir): assert next(m) == "9\n" with pytest.raises(StopIteration): next(m) + + +def test_mergesort_key(tmpdir): + """Test sorting lines based on a key function.""" + m = MergeSort(directory=tmpdir, key=lambda line: -int(line[0])) + for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: + m.append(f"{value}\n") + result = "".join(value[0] for value in m) + assert result == "986554433211" From c0bf8f5cf59498a38f7bb61d3074c0afe6fb21f8 Mon Sep 17 00:00:00 2001 From: Antti Kaihola Date: Sun, 12 Sep 2021 11:49:08 +0300 Subject: [PATCH 03/14] Use merge sort for `pg_dump_splitsort` Uses up to 100MB of memory by default before splitting table data into files and merge sorting them. --- pgtricks/pg_dump_splitsort.py | 41 ++++++++---- pgtricks/tests/test_pg_dump_splitsort.py | 84 +++++++++++++++++++++++- 2 files changed, 110 insertions(+), 15 deletions(-) diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 56908ea..93fe1d2 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -1,15 +1,20 @@ #!/usr/bin/env python +from __future__ import annotations + import functools import os import re import sys -from typing import IO, List, Match, Optional, Pattern, Tuple, Union, cast +from typing import IO, Match, Pattern, cast + +from pgtricks.mergesort import MergeSort COPY_RE = re.compile(r'COPY .*? 
\(.*?\) FROM stdin;\n$') -def try_float(s1: str, s2: str) -> Union[Tuple[str, str], Tuple[float, float]]: +def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]: + """Convert two strings to floats. Return original ones on conversion error.""" if not s1 or not s2 or s1[0] not in '0123456789.-' or s2[0] not in '0123456789.-': # optimization return s1, s2 @@ -22,7 +27,7 @@ def try_float(s1: str, s2: str) -> Union[Tuple[str, str], Tuple[float, float]]: def linecomp(l1: str, l2: str) -> int: p1 = l1.split('\t', 1) p2 = l2.split('\t', 1) - v1, v2 = cast(Tuple[float, float], try_float(p1[0], p2[0])) + v1, v2 = cast(tuple[float, float], try_float(p1[0], p2[0])) result = (v1 > v2) - (v1 < v2) # modifying a line to see whether Darker works: if not result and len(p1) == len(p2) == 2: @@ -37,9 +42,10 @@ def linecomp(l1: str, l2: str) -> int: class Matcher(object): def __init__(self) -> None: - self._match: Optional[Match[str]] = None + self._match: Match[str] | None = None - def match(self, pattern: Pattern[str], data: str) -> Optional[Match[str]]: + def match(self, pattern: Pattern[str], data: str) -> Match[str] | None: + """Match the regular expression pattern against the data.""" self._match = pattern.match(data) return self._match @@ -49,12 +55,15 @@ def group(self, group1: str) -> str: return self._match.group(group1) -def split_sql_file(sql_filepath: str) -> None: - +def split_sql_file( # noqa: PLR0912, C901 many branches, too complex + sql_filepath: str, + max_memory: int = 10**8, +) -> None: + """Split a SQL file so that each COPY statement is in its own file.""" directory = os.path.dirname(sql_filepath) - output: Optional[IO[str]] = None - buf: List[str] = [] + output: IO[str] | None = None + buf: list[str] = [] def flush() -> None: cast(IO[str], output).writelines(buf) @@ -65,7 +74,7 @@ def new_output(filename: str) -> IO[str]: output.close() return open(os.path.join(directory, filename), 'w') - copy_lines: Optional[List[str]] = None + copy_lines: MergeSort | None = None counter = 0 output = new_output('0000_prologue.sql') matcher = Matcher() @@ -86,7 +95,10 @@ def new_output(filename: str) -> IO[str]: schema=matcher.group('schema'), table=matcher.group('table'))) elif COPY_RE.match(line): - copy_lines = [] + copy_lines = MergeSort( + key=functools.cmp_to_key(linecomp), + max_memory=max_memory, + ) elif SEQUENCE_SET_RE.match(line): pass elif 1 <= counter < 9999: @@ -95,9 +107,10 @@ def new_output(filename: str) -> IO[str]: buf.append(line) flush() else: - if line == '\\.\n': - copy_lines.sort(key=functools.cmp_to_key(linecomp)) - buf.extend(copy_lines) + if line == "\\.\n": + for copy_line in copy_lines: + buf.append(copy_line) + flush() buf.append(line) flush() copy_lines = None diff --git a/pgtricks/tests/test_pg_dump_splitsort.py b/pgtricks/tests/test_pg_dump_splitsort.py index 3305c03..080d8d6 100644 --- a/pgtricks/tests/test_pg_dump_splitsort.py +++ b/pgtricks/tests/test_pg_dump_splitsort.py @@ -1,8 +1,9 @@ from functools import cmp_to_key +from textwrap import dedent import pytest -from pgtricks.pg_dump_splitsort import linecomp, try_float +from pgtricks.pg_dump_splitsort import linecomp, split_sql_file, try_float @pytest.mark.parametrize( @@ -101,3 +102,84 @@ def test_linecomp_by_sorting(): [r'\N', r'\N', r'\N'], [r'\N', 'foo', '.42'], ] + + +PROLOGUE = dedent( + """ + + -- + -- Name: table1; Type: TABLE; Schema: public; Owner: + -- + + (information for table1 goes here) + """, +) + +TABLE1_COPY = dedent( + r""" + + -- Data for Name: table1; Type: TABLE DATA; 
Schema: public; + + COPY foo (id) FROM stdin; + 3 + 1 + 4 + 1 + 5 + 9 + 2 + 6 + 5 + 3 + 8 + 4 + \. + """, +) + +TABLE1_COPY_SORTED = dedent( + r""" + + -- Data for Name: table1; Type: TABLE DATA; Schema: public; + + COPY foo (id) FROM stdin; + 1 + 1 + 2 + 3 + 3 + 4 + 4 + 5 + 5 + 6 + 8 + 9 + \. + """, +) + +EPILOGUE = dedent( + """ + -- epilogue + """, +) + + +def test_split_sql_file(tmpdir): + """Test splitting a SQL file with COPY statements.""" + sql_file = tmpdir / "test.sql" + sql_file.write(PROLOGUE + TABLE1_COPY + EPILOGUE) + + split_sql_file(sql_file, max_memory=190) + + split_files = sorted(path.relto(tmpdir) for path in tmpdir.listdir()) + assert split_files == [ + "0000_prologue.sql", + "0001_public.table1.sql", + "9999_epilogue.sql", + "test.sql", + ] + assert (tmpdir / "0000_prologue.sql").read() == PROLOGUE + assert (tmpdir / "0001_public.table1.sql").read() == TABLE1_COPY_SORTED + assert (tmpdir / "9999_epilogue.sql").read() == EPILOGUE From 85eebab4b7b7b08cffc0c30c32736f36ccc0e024 Mon Sep 17 00:00:00 2001 From: Antti Kaihola Date: Sun, 12 Sep 2021 11:49:08 +0300 Subject: [PATCH 04/14] Avoid buffer flushes writing sorted COPY lines --- pgtricks/pg_dump_splitsort.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 93fe1d2..1aec575 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -3,6 +3,7 @@ from __future__ import annotations import functools +import io import os import re import sys @@ -62,13 +63,21 @@ def split_sql_file( # noqa: PLR0912, C901 many branches, too complex """Split a SQL file so that each COPY statement is in its own file.""" directory = os.path.dirname(sql_filepath) - output: IO[str] | None = None + # `output` needs to be instantiated before the inner functions are defined. + # Assign it a dummy string I/O object so type checking is happy. + # This will be replaced with the prologue SQL file object. 
+ output: IO[str] = io.StringIO() buf: list[str] = [] def flush() -> None: - cast(IO[str], output).writelines(buf) + output.writelines(buf) buf[:] = [] + def writeline(line_: str) -> None: + if buf: + flush() + output.write(line_) + def new_output(filename: str) -> IO[str]: if output: output.close() @@ -84,8 +93,7 @@ def new_output(filename: str) -> IO[str]: if line in ('\n', '--\n'): buf.append(line) elif line.startswith('SET search_path = '): - flush() - buf.append(line) + writeline(line) else: if matcher.match(DATA_COMMENT_RE, line): counter += 1 @@ -104,15 +112,12 @@ def new_output(filename: str) -> IO[str]: elif 1 <= counter < 9999: counter = 9999 output = new_output('%04d_epilogue.sql' % counter) - buf.append(line) - flush() + writeline(line) else: if line == "\\.\n": for copy_line in copy_lines: - buf.append(copy_line) - flush() - buf.append(line) - flush() + writeline(copy_line) + writeline(line) copy_lines = None else: copy_lines.append(line) From 127dbbf89c3d83d5cd670902bc24fb92415947e8 Mon Sep 17 00:00:00 2001 From: Antti Kaihola Date: Sun, 12 Sep 2021 11:49:08 +0300 Subject: [PATCH 05/14] Write all COPY data in one go --- pgtricks/pg_dump_splitsort.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 1aec575..529fdae 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -7,7 +7,7 @@ import os import re import sys -from typing import IO, Match, Pattern, cast +from typing import IO, Iterable, Match, Pattern, cast from pgtricks.mergesort import MergeSort @@ -56,7 +56,7 @@ def group(self, group1: str) -> str: return self._match.group(group1) -def split_sql_file( # noqa: PLR0912, C901 many branches, too complex +def split_sql_file( # noqa: C901 too complex sql_filepath: str, max_memory: int = 10**8, ) -> None: @@ -73,10 +73,10 @@ def flush() -> None: output.writelines(buf) buf[:] = [] - def writeline(line_: str) -> None: + def writelines(lines: Iterable[str]) -> None: if buf: flush() - output.write(line_) + output.writelines(lines) def new_output(filename: str) -> IO[str]: if output: @@ -93,7 +93,7 @@ def new_output(filename: str) -> IO[str]: if line in ('\n', '--\n'): buf.append(line) elif line.startswith('SET search_path = '): - writeline(line) + writelines([line]) else: if matcher.match(DATA_COMMENT_RE, line): counter += 1 @@ -112,12 +112,11 @@ def new_output(filename: str) -> IO[str]: elif 1 <= counter < 9999: counter = 9999 output = new_output('%04d_epilogue.sql' % counter) - writeline(line) + writelines([line]) else: if line == "\\.\n": - for copy_line in copy_lines: - writeline(copy_line) - writeline(line) + writelines(copy_lines) + writelines(line) copy_lines = None else: copy_lines.append(line) From 459bb7e8d87723a155bf195b2f55e2f4569567bf Mon Sep 17 00:00:00 2001 From: Antti Kaihola Date: Sun, 12 Sep 2021 11:49:08 +0300 Subject: [PATCH 06/14] Better varname for merge sorter. Also sorted imports. 
--- pgtricks/pg_dump_splitsort.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 529fdae..8b7a123 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -83,13 +83,13 @@ def new_output(filename: str) -> IO[str]: output.close() return open(os.path.join(directory, filename), 'w') - copy_lines: MergeSort | None = None + sorted_data_lines: MergeSort | None = None counter = 0 output = new_output('0000_prologue.sql') matcher = Matcher() for line in open(sql_filepath): - if copy_lines is None: + if sorted_data_lines is None: if line in ('\n', '--\n'): buf.append(line) elif line.startswith('SET search_path = '): @@ -103,7 +103,7 @@ def new_output(filename: str) -> IO[str]: schema=matcher.group('schema'), table=matcher.group('table'))) elif COPY_RE.match(line): - copy_lines = MergeSort( + sorted_data_lines = MergeSort( key=functools.cmp_to_key(linecomp), max_memory=max_memory, ) @@ -115,11 +115,11 @@ def new_output(filename: str) -> IO[str]: writelines([line]) else: if line == "\\.\n": - writelines(copy_lines) + writelines(sorted_data_lines) writelines(line) - copy_lines = None + sorted_data_lines = None else: - copy_lines.append(line) + sorted_data_lines.append(line) flush() From 952fa77fdcf1c096beaf3c1ca96eddfd9f2b750a Mon Sep 17 00:00:00 2001 From: Antti Kaihola <13725+akaihola@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:40:44 +0300 Subject: [PATCH 07/14] Fix Python 3.8 compatibility --- pgtricks/pg_dump_splitsort.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 8b7a123..3732a02 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -28,7 +28,8 @@ def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]: def linecomp(l1: str, l2: str) -> int: p1 = l1.split('\t', 1) p2 = l2.split('\t', 1) - v1, v2 = cast(tuple[float, float], try_float(p1[0], p2[0])) + # TODO: unquote cast after support for Python 3.8 is dropped + v1, v2 = cast("tuple[float, float]", try_float(p1[0], p2[0])) result = (v1 > v2) - (v1 < v2) # modifying a line to see whether Darker works: if not result and len(p1) == len(p2) == 2: From a19b2bd2949d629bfff2fc2dbf8635f858ed5b22 Mon Sep 17 00:00:00 2001 From: Antti Kaihola <13725+akaihola@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:53:28 +0300 Subject: [PATCH 08/14] Fix tests compatibility with Windows --- pgtricks/tests/test_mergesort.py | 81 +++++++++++++++++--------------- 1 file changed, 42 insertions(+), 39 deletions(-) diff --git a/pgtricks/tests/test_mergesort.py b/pgtricks/tests/test_mergesort.py index 92972de..6a81760 100644 --- a/pgtricks/tests/test_mergesort.py +++ b/pgtricks/tests/test_mergesort.py @@ -1,5 +1,6 @@ """Tests for the `pgtricks.mergesort` module.""" +import os from types import GeneratorType from typing import Iterable, cast @@ -7,55 +8,57 @@ from pgtricks.mergesort import MergeSort +LF = os.linesep + def test_mergesort_append(tmpdir): """Test appending lines to the merge sort object.""" m = MergeSort(directory=tmpdir, max_memory=190) - m.append("1\n") - assert m._buffer == ["1\n"] - m.append("2\n") + m.append(f"1{LF}") + assert m._buffer == [f"1{LF}"] + m.append(f"2{LF}") assert m._buffer == [] - m.append("3\n") - assert m._buffer == ["3\n"] + m.append(f"3{LF}") + assert m._buffer == [f"3{LF}"] assert len(m._partitions) == 1 - assert m._partitions[0].tell() == len("1\n2\n") + assert 
m._partitions[0].tell() == len(f"1{LF}2{LF}") m._partitions[0].seek(0) - assert m._partitions[0].read() == "1\n2\n" + assert m._partitions[0].read() == f"1{LF}2{LF}" def test_mergesort_flush(tmpdir): """Test flushing the buffer to disk.""" m = MergeSort(directory=tmpdir, max_memory=190) for value in [1, 2, 3]: - m.append(f"{value}\n") + m.append(f"{value}{LF}") m._flush() assert len(m._partitions) == 2 - assert m._partitions[0].tell() == len("1\n2\n") + assert m._partitions[0].tell() == len(f"1{LF}2{LF}") m._partitions[0].seek(0) - assert m._partitions[0].read() == "1\n2\n" - assert m._partitions[1].tell() == len("3\n") + assert m._partitions[0].read() == f"1{LF}2{LF}" + assert m._partitions[1].tell() == len(f"3{LF}") m._partitions[1].seek(0) - assert m._partitions[1].read() == "3\n" + assert m._partitions[1].read() == f"3{LF}" def test_mergesort_iterate_disk(tmpdir): """Test iterating over the sorted lines on disk.""" m = MergeSort(directory=tmpdir, max_memory=190) for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: - m.append(f"{value}\n") - assert next(m) == "1\n" + m.append(f"{value}{LF}") + assert next(m) == f"1{LF}" assert isinstance(m._iterating, GeneratorType) - assert next(m) == "1\n" - assert next(m) == "2\n" - assert next(m) == "3\n" - assert next(m) == "3\n" - assert next(m) == "4\n" - assert next(m) == "4\n" - assert next(m) == "5\n" - assert next(m) == "5\n" - assert next(m) == "6\n" - assert next(m) == "8\n" - assert next(m) == "9\n" + assert next(m) == f"1{LF}" + assert next(m) == f"2{LF}" + assert next(m) == f"3{LF}" + assert next(m) == f"3{LF}" + assert next(m) == f"4{LF}" + assert next(m) == f"4{LF}" + assert next(m) == f"5{LF}" + assert next(m) == f"5{LF}" + assert next(m) == f"6{LF}" + assert next(m) == f"8{LF}" + assert next(m) == f"9{LF}" with pytest.raises(StopIteration): next(m) @@ -64,21 +67,21 @@ def test_mergesort_iterate_memory(tmpdir): """Test iterating over the sorted lines when all lines fit in memory.""" m = MergeSort(directory=tmpdir, max_memory=1000000) for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: - m.append(f"{value}\n") - assert next(m) == "1\n" + m.append(f"{value}{LF}") + assert next(m) == f"1{LF}" assert not isinstance(m._iterating, GeneratorType) assert iter(cast(Iterable[str], m._iterating)) is m._iterating - assert next(m) == "1\n" - assert next(m) == "2\n" - assert next(m) == "3\n" - assert next(m) == "3\n" - assert next(m) == "4\n" - assert next(m) == "4\n" - assert next(m) == "5\n" - assert next(m) == "5\n" - assert next(m) == "6\n" - assert next(m) == "8\n" - assert next(m) == "9\n" + assert next(m) == f"1{LF}" + assert next(m) == f"2{LF}" + assert next(m) == f"3{LF}" + assert next(m) == f"3{LF}" + assert next(m) == f"4{LF}" + assert next(m) == f"4{LF}" + assert next(m) == f"5{LF}" + assert next(m) == f"5{LF}" + assert next(m) == f"6{LF}" + assert next(m) == f"8{LF}" + assert next(m) == f"9{LF}" with pytest.raises(StopIteration): next(m) @@ -87,6 +90,6 @@ def test_mergesort_key(tmpdir): """Test sorting lines based on a key function.""" m = MergeSort(directory=tmpdir, key=lambda line: -int(line[0])) for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: - m.append(f"{value}\n") + m.append(f"{value}{LF}") result = "".join(value[0] for value in m) assert result == "986554433211" From cceee257b27de1dd33472f9fa2511df5d1b30402 Mon Sep 17 00:00:00 2001 From: Antti Kaihola <13725+akaihola@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:58:30 +0300 Subject: [PATCH 09/14] Test merge sort partition position an end of test In case of failure, this 
provides us the actual output before checking the position. --- pgtricks/tests/test_mergesort.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pgtricks/tests/test_mergesort.py b/pgtricks/tests/test_mergesort.py index 6a81760..4c3c8fd 100644 --- a/pgtricks/tests/test_mergesort.py +++ b/pgtricks/tests/test_mergesort.py @@ -21,9 +21,10 @@ def test_mergesort_append(tmpdir): m.append(f"3{LF}") assert m._buffer == [f"3{LF}"] assert len(m._partitions) == 1 - assert m._partitions[0].tell() == len(f"1{LF}2{LF}") + pos = m._partitions[0].tell() m._partitions[0].seek(0) assert m._partitions[0].read() == f"1{LF}2{LF}" + assert pos == len(f"1{LF}2{LF}") def test_mergesort_flush(tmpdir): @@ -36,9 +37,10 @@ def test_mergesort_flush(tmpdir): assert m._partitions[0].tell() == len(f"1{LF}2{LF}") m._partitions[0].seek(0) assert m._partitions[0].read() == f"1{LF}2{LF}" - assert m._partitions[1].tell() == len(f"3{LF}") + pos = m._partitions[1].tell() m._partitions[1].seek(0) assert m._partitions[1].read() == f"3{LF}" + assert pos == len(f"3{LF}") def test_mergesort_iterate_disk(tmpdir): From d746c2d53ee97e57567c3a4a0b87efa000d4dbf9 Mon Sep 17 00:00:00 2001 From: Antti Kaihola <13725+akaihola@users.noreply.github.com> Date: Sat, 20 Apr 2024 11:22:58 +0300 Subject: [PATCH 10/14] Stop newline conversion on Win. Test all newlines. --- pgtricks/mergesort.py | 18 ++++-- pgtricks/tests/test_mergesort.py | 104 ++++++++++++++++--------------- 2 files changed, 69 insertions(+), 53 deletions(-) diff --git a/pgtricks/mergesort.py b/pgtricks/mergesort.py index 9280fd2..9d235ca 100644 --- a/pgtricks/mergesort.py +++ b/pgtricks/mergesort.py @@ -21,7 +21,8 @@ def __init__( self._key = key self._directory = directory self._max_memory = max_memory - self._partitions: list[IO[str]] = [] + # Use binary mode to avoid newline conversion on Windows. + self._partitions: list[IO[bytes]] = [] self._iterating: Iterable[str] | None = None self._buffer: list[str] = [] self._memory_counter = 0 @@ -41,8 +42,11 @@ def append(self, line: str) -> None: def _flush(self) -> None: if self._buffer: - self._partitions.append(TemporaryFile(mode="w+", dir=self._directory)) - self._partitions[-1].writelines(sorted(self._buffer, key=self._key)) + # Use binary mode to avoid newline conversion on Windows. + self._partitions.append(TemporaryFile(mode="w+b", dir=self._directory)) + self._partitions[-1].writelines( + line.encode("UTF-8") for line in sorted(self._buffer, key=self._key) + ) self._buffer = [] self._memory_counter = sys.getsizeof(self._buffer) @@ -55,7 +59,13 @@ def __next__(self) -> str: self._flush() for partition in self._partitions: partition.seek(0) - self._iterating = merge(*self._partitions, key=self._key) + self._iterating = merge( + *[ + (line.decode("UTF-8") for line in partition) + for partition in self._partitions + ], + key=self._key, + ) else: # All lines fit in memory. Iterate the list of lines directly. self._iterating = iter(sorted(self._buffer)) diff --git a/pgtricks/tests/test_mergesort.py b/pgtricks/tests/test_mergesort.py index 4c3c8fd..ca1f27c 100644 --- a/pgtricks/tests/test_mergesort.py +++ b/pgtricks/tests/test_mergesort.py @@ -1,6 +1,5 @@ """Tests for the `pgtricks.mergesort` module.""" -import os from types import GeneratorType from typing import Iterable, cast @@ -8,90 +7,97 @@ from pgtricks.mergesort import MergeSort -LF = os.linesep +# This is the biggest amount of memory which can't hold two one-character lines on any +# platform. On Windows it's slightly smaller than on Unix. 
+JUST_BELOW_TWO_SHORT_LINES = 174 -def test_mergesort_append(tmpdir): +@pytest.mark.parametrize("lf", ["\n", "\r\n"]) +def test_mergesort_append(tmpdir, lf): """Test appending lines to the merge sort object.""" - m = MergeSort(directory=tmpdir, max_memory=190) - m.append(f"1{LF}") - assert m._buffer == [f"1{LF}"] - m.append(f"2{LF}") + m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES) + m.append(f"1{lf}") + assert m._buffer == [f"1{lf}"] + m.append(f"2{lf}") assert m._buffer == [] - m.append(f"3{LF}") - assert m._buffer == [f"3{LF}"] + m.append(f"3{lf}") + assert m._buffer == [f"3{lf}"] assert len(m._partitions) == 1 pos = m._partitions[0].tell() m._partitions[0].seek(0) - assert m._partitions[0].read() == f"1{LF}2{LF}" - assert pos == len(f"1{LF}2{LF}") + assert m._partitions[0].read() == f"1{lf}2{lf}".encode("UTF-8") + assert pos == len(f"1{lf}2{lf}") -def test_mergesort_flush(tmpdir): +@pytest.mark.parametrize("lf", ["\n", "\r\n"]) +def test_mergesort_flush(tmpdir, lf): """Test flushing the buffer to disk.""" - m = MergeSort(directory=tmpdir, max_memory=190) + m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES) for value in [1, 2, 3]: - m.append(f"{value}{LF}") + m.append(f"{value}{lf}") m._flush() assert len(m._partitions) == 2 - assert m._partitions[0].tell() == len(f"1{LF}2{LF}") + assert m._partitions[0].tell() == len(f"1{lf}2{lf}") m._partitions[0].seek(0) - assert m._partitions[0].read() == f"1{LF}2{LF}" + assert m._partitions[0].read() == f"1{lf}2{lf}".encode("UTF-8") pos = m._partitions[1].tell() m._partitions[1].seek(0) - assert m._partitions[1].read() == f"3{LF}" - assert pos == len(f"3{LF}") + assert m._partitions[1].read() == f"3{lf}".encode("UTF-8") + assert pos == len(f"3{lf}") -def test_mergesort_iterate_disk(tmpdir): +@pytest.mark.parametrize("lf", ["\n", "\r\n"]) +def test_mergesort_iterate_disk(tmpdir, lf): """Test iterating over the sorted lines on disk.""" - m = MergeSort(directory=tmpdir, max_memory=190) + m = MergeSort(directory=tmpdir, max_memory=JUST_BELOW_TWO_SHORT_LINES) for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: - m.append(f"{value}{LF}") - assert next(m) == f"1{LF}" + m.append(f"{value}{lf}") + assert next(m) == f"1{lf}" assert isinstance(m._iterating, GeneratorType) - assert next(m) == f"1{LF}" - assert next(m) == f"2{LF}" - assert next(m) == f"3{LF}" - assert next(m) == f"3{LF}" - assert next(m) == f"4{LF}" - assert next(m) == f"4{LF}" - assert next(m) == f"5{LF}" - assert next(m) == f"5{LF}" - assert next(m) == f"6{LF}" - assert next(m) == f"8{LF}" - assert next(m) == f"9{LF}" + assert next(m) == f"1{lf}" + assert next(m) == f"2{lf}" + assert next(m) == f"3{lf}" + assert next(m) == f"3{lf}" + assert next(m) == f"4{lf}" + assert next(m) == f"4{lf}" + assert next(m) == f"5{lf}" + assert next(m) == f"5{lf}" + assert next(m) == f"6{lf}" + assert next(m) == f"8{lf}" + assert next(m) == f"9{lf}" with pytest.raises(StopIteration): next(m) -def test_mergesort_iterate_memory(tmpdir): +@pytest.mark.parametrize("lf", ["\n", "\r\n"]) +def test_mergesort_iterate_memory(tmpdir, lf): """Test iterating over the sorted lines when all lines fit in memory.""" m = MergeSort(directory=tmpdir, max_memory=1000000) for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: - m.append(f"{value}{LF}") - assert next(m) == f"1{LF}" + m.append(f"{value}{lf}") + assert next(m) == f"1{lf}" assert not isinstance(m._iterating, GeneratorType) assert iter(cast(Iterable[str], m._iterating)) is m._iterating - assert next(m) == f"1{LF}" - assert next(m) 
== f"2{LF}" - assert next(m) == f"3{LF}" - assert next(m) == f"3{LF}" - assert next(m) == f"4{LF}" - assert next(m) == f"4{LF}" - assert next(m) == f"5{LF}" - assert next(m) == f"5{LF}" - assert next(m) == f"6{LF}" - assert next(m) == f"8{LF}" - assert next(m) == f"9{LF}" + assert next(m) == f"1{lf}" + assert next(m) == f"2{lf}" + assert next(m) == f"3{lf}" + assert next(m) == f"3{lf}" + assert next(m) == f"4{lf}" + assert next(m) == f"4{lf}" + assert next(m) == f"5{lf}" + assert next(m) == f"5{lf}" + assert next(m) == f"6{lf}" + assert next(m) == f"8{lf}" + assert next(m) == f"9{lf}" with pytest.raises(StopIteration): next(m) -def test_mergesort_key(tmpdir): +@pytest.mark.parametrize("lf", ["\n", "\r\n"]) +def test_mergesort_key(tmpdir, lf): """Test sorting lines based on a key function.""" m = MergeSort(directory=tmpdir, key=lambda line: -int(line[0])) for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: - m.append(f"{value}{LF}") + m.append(f"{value}{lf}") result = "".join(value[0] for value in m) assert result == "986554433211" From 8a599babd76e2487d751b3f46225028315a07eac Mon Sep 17 00:00:00 2001 From: Antti Kaihola <13725+akaihola@users.noreply.github.com> Date: Sat, 20 Apr 2024 11:27:31 +0300 Subject: [PATCH 11/14] Update the change log --- CHANGES.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGES.rst b/CHANGES.rst index 3fc8314..c3796a7 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -13,6 +13,9 @@ Removed Fixed ----- +- Very large tables are now sorted without crashing. This is done by merge sorting + in temporary files. + 1.0.0_ / 2021-09-11 ==================== From efd3980778034482f0f5f54f0008ab6e31b69a80 Mon Sep 17 00:00:00 2001 From: Xun Cai Date: Mon, 22 Apr 2024 14:54:38 +1000 Subject: [PATCH 12/14] use self._key to sort in MergeSort even if it fits the memory. add max_memory command line option. --- pgtricks/mergesort.py | 4 ++-- pgtricks/pg_dump_splitsort.py | 7 +++++-- pgtricks/tests/test_mergesort.py | 8 +++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pgtricks/mergesort.py b/pgtricks/mergesort.py index 9d235ca..f22f020 100644 --- a/pgtricks/mergesort.py +++ b/pgtricks/mergesort.py @@ -25,7 +25,7 @@ def __init__( self._partitions: list[IO[bytes]] = [] self._iterating: Iterable[str] | None = None self._buffer: list[str] = [] - self._memory_counter = 0 + self._memory_counter: int = sys.getsizeof(self._buffer) self._flush() def append(self, line: str) -> None: @@ -68,7 +68,7 @@ def __next__(self) -> str: ) else: # All lines fit in memory. Iterate the list of lines directly. 
- self._iterating = iter(sorted(self._buffer)) + self._iterating = iter(sorted(self._buffer, key=self._key)) return next(cast(Iterator[str], self._iterating)) def __iter__(self) -> Iterator[str]: diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 3732a02..86e20a6 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -59,7 +59,7 @@ def group(self, group1: str) -> str: def split_sql_file( # noqa: C901 too complex sql_filepath: str, - max_memory: int = 10**8, + max_memory: int = 100 * 2 ** 20, ) -> None: """Split a SQL file so that each COPY statement is in its own file.""" directory = os.path.dirname(sql_filepath) @@ -125,7 +125,10 @@ def new_output(filename: str) -> IO[str]: def main() -> None: - split_sql_file(sys.argv[1]) + max_memory = 100 * 2 ** 20 + if len(sys.argv) > 2: + max_memory = int(sys.argv[2]) * 2 ** 20 + split_sql_file(sys.argv[1], max_memory) if __name__ == '__main__': diff --git a/pgtricks/tests/test_mergesort.py b/pgtricks/tests/test_mergesort.py index ca1f27c..5340fdb 100644 --- a/pgtricks/tests/test_mergesort.py +++ b/pgtricks/tests/test_mergesort.py @@ -1,11 +1,12 @@ """Tests for the `pgtricks.mergesort` module.""" - +import functools from types import GeneratorType from typing import Iterable, cast import pytest from pgtricks.mergesort import MergeSort +from pgtricks.pg_dump_splitsort import linecomp # This is the biggest amount of memory which can't hold two one-character lines on any # platform. On Windows it's slightly smaller than on Unix. @@ -72,8 +73,8 @@ def test_mergesort_iterate_disk(tmpdir, lf): @pytest.mark.parametrize("lf", ["\n", "\r\n"]) def test_mergesort_iterate_memory(tmpdir, lf): """Test iterating over the sorted lines when all lines fit in memory.""" - m = MergeSort(directory=tmpdir, max_memory=1000000) - for value in [3, 1, 4, 1, 5, 9, 2, 6, 5, 3, 8, 4]: + m = MergeSort(directory=tmpdir, max_memory=1000000, key=functools.cmp_to_key(linecomp)) + for value in [3, 1, 4, 1, 5, 9, 2, 10, 6, 5, 3, 8, 4]: m.append(f"{value}{lf}") assert next(m) == f"1{lf}" assert not isinstance(m._iterating, GeneratorType) @@ -89,6 +90,7 @@ def test_mergesort_iterate_memory(tmpdir, lf): assert next(m) == f"6{lf}" assert next(m) == f"8{lf}" assert next(m) == f"9{lf}" + assert next(m) == f"10{lf}" with pytest.raises(StopIteration): next(m) From f8aedf8aad977d4f6549879075f123410253f851 Mon Sep 17 00:00:00 2001 From: Xun Cai Date: Mon, 22 Apr 2024 15:02:32 +1000 Subject: [PATCH 13/14] code style --- pgtricks/pg_dump_splitsort.py | 8 ++++---- pgtricks/tests/test_mergesort.py | 13 +++++++++---- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 86e20a6..024b2fa 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -59,7 +59,7 @@ def group(self, group1: str) -> str: def split_sql_file( # noqa: C901 too complex sql_filepath: str, - max_memory: int = 100 * 2 ** 20, + max_memory: int = 100 * 2**20, ) -> None: """Split a SQL file so that each COPY statement is in its own file.""" directory = os.path.dirname(sql_filepath) @@ -125,10 +125,10 @@ def new_output(filename: str) -> IO[str]: def main() -> None: - max_memory = 100 * 2 ** 20 + max_memory = 100 * 2**20 if len(sys.argv) > 2: - max_memory = int(sys.argv[2]) * 2 ** 20 - split_sql_file(sys.argv[1], max_memory) + max_memory = int(sys.argv[2]) * 2**20 + split_sql_file(sys.argv[1], max_memory) if __name__ == '__main__': diff --git a/pgtricks/tests/test_mergesort.py 
b/pgtricks/tests/test_mergesort.py index 5340fdb..6f7c0b6 100644 --- a/pgtricks/tests/test_mergesort.py +++ b/pgtricks/tests/test_mergesort.py @@ -1,4 +1,5 @@ """Tests for the `pgtricks.mergesort` module.""" + import functools from types import GeneratorType from typing import Iterable, cast @@ -26,7 +27,7 @@ def test_mergesort_append(tmpdir, lf): assert len(m._partitions) == 1 pos = m._partitions[0].tell() m._partitions[0].seek(0) - assert m._partitions[0].read() == f"1{lf}2{lf}".encode("UTF-8") + assert m._partitions[0].read() == f"1{lf}2{lf}".encode() assert pos == len(f"1{lf}2{lf}") @@ -40,10 +41,10 @@ def test_mergesort_flush(tmpdir, lf): assert len(m._partitions) == 2 assert m._partitions[0].tell() == len(f"1{lf}2{lf}") m._partitions[0].seek(0) - assert m._partitions[0].read() == f"1{lf}2{lf}".encode("UTF-8") + assert m._partitions[0].read() == f"1{lf}2{lf}".encode() pos = m._partitions[1].tell() m._partitions[1].seek(0) - assert m._partitions[1].read() == f"3{lf}".encode("UTF-8") + assert m._partitions[1].read() == f"3{lf}".encode() assert pos == len(f"3{lf}") @@ -73,7 +74,11 @@ def test_mergesort_iterate_disk(tmpdir, lf): @pytest.mark.parametrize("lf", ["\n", "\r\n"]) def test_mergesort_iterate_memory(tmpdir, lf): """Test iterating over the sorted lines when all lines fit in memory.""" - m = MergeSort(directory=tmpdir, max_memory=1000000, key=functools.cmp_to_key(linecomp)) + m = MergeSort( + directory=tmpdir, + max_memory=1000000, + key=functools.cmp_to_key(linecomp), + ) for value in [3, 1, 4, 1, 5, 9, 2, 10, 6, 5, 3, 8, 4]: m.append(f"{value}{lf}") assert next(m) == f"1{lf}" From 84d50b9c13fe5f60bb5b1c57f6219975239162b0 Mon Sep 17 00:00:00 2001 From: Antti Kaihola <13725+akaihola@users.noreply.github.com> Date: Mon, 22 Apr 2024 10:20:29 +0300 Subject: [PATCH 14/14] Support memory size units, use -m option --- pgtricks/pg_dump_splitsort.py | 37 ++++++++++++++++++++---- pgtricks/tests/test_pg_dump_splitsort.py | 29 ++++++++++++++++++- 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/pgtricks/pg_dump_splitsort.py b/pgtricks/pg_dump_splitsort.py index 024b2fa..aab1258 100755 --- a/pgtricks/pg_dump_splitsort.py +++ b/pgtricks/pg_dump_splitsort.py @@ -6,12 +6,14 @@ import io import os import re -import sys +from argparse import ArgumentParser from typing import IO, Iterable, Match, Pattern, cast from pgtricks.mergesort import MergeSort COPY_RE = re.compile(r'COPY .*? \(.*?\) FROM stdin;\n$') +KIBIBYTE, MEBIBYTE, GIBIBYTE = 2**10, 2**20, 2**30 +MEMORY_UNITS = {"": 1, "k": KIBIBYTE, "m": MEBIBYTE, "g": GIBIBYTE} def try_float(s1: str, s2: str) -> tuple[str, str] | tuple[float, float]: @@ -59,7 +61,7 @@ def group(self, group1: str) -> str: def split_sql_file( # noqa: C901 too complex sql_filepath: str, - max_memory: int = 100 * 2**20, + max_memory: int = 100 * MEBIBYTE, ) -> None: """Split a SQL file so that each COPY statement is in its own file.""" directory = os.path.dirname(sql_filepath) @@ -124,11 +126,34 @@ def new_output(filename: str) -> IO[str]: flush() +def memory_size(size: str) -> int: + """Parse a human-readable memory size. + + :param size: The memory size to parse, e.g. "100MB". + :return: The memory size in bytes. + :raise ValueError: If the memory size is invalid. 
+ + """ + match = re.match(r"([\d._]+)\s*([kmg]?)b?", size.lower().strip()) + if not match: + message = f"Invalid memory size: {size}" + raise ValueError(message) + return int(float(match.group(1)) * MEMORY_UNITS[match.group(2)]) + + def main() -> None: - max_memory = 100 * 2**20 - if len(sys.argv) > 2: - max_memory = int(sys.argv[2]) * 2**20 - split_sql_file(sys.argv[1], max_memory) + parser = ArgumentParser(description="Split a SQL file into smaller files.") + parser.add_argument("sql_filepath", help="The SQL file to split.") + parser.add_argument( + "-m", + "--max-memory", + default=100 * MEBIBYTE, + type=memory_size, + help="Max memory to use, e.g. 50_000, 200000000, 100kb, 100MB (default), 2Gig.", + ) + args = parser.parse_args() + + split_sql_file(args.sql_filepath, args.max_memory) if __name__ == '__main__': diff --git a/pgtricks/tests/test_pg_dump_splitsort.py b/pgtricks/tests/test_pg_dump_splitsort.py index 080d8d6..74e6b56 100644 --- a/pgtricks/tests/test_pg_dump_splitsort.py +++ b/pgtricks/tests/test_pg_dump_splitsort.py @@ -3,7 +3,7 @@ import pytest -from pgtricks.pg_dump_splitsort import linecomp, split_sql_file, try_float +from pgtricks.pg_dump_splitsort import linecomp, memory_size, split_sql_file, try_float @pytest.mark.parametrize( @@ -183,3 +183,30 @@ def test_split_sql_file(tmpdir): assert (tmpdir / "0000_prologue.sql").read() == PROLOGUE assert (tmpdir / "0001_public.table1.sql").read() == TABLE1_COPY_SORTED assert (tmpdir / "9999_epilogue.sql").read() == EPILOGUE + + +@pytest.mark.parametrize( + ("size", "expect"), + [ + ("0", 0), + ("1", 1), + ("1k", 1024), + ("1m", 1024**2), + ("1g", 1024**3), + ("100_000K", 102400000), + ("1.5M", 1536 * 1024), + ("1.5G", 1536 * 1024**2), + ("1.5", 1), + ("1.5 kibibytes", 1536), + ("1.5 Megabytes", 1024 * 1536), + ("1.5 Gigs", 1024**2 * 1536), + ("1.5KB", 1536), + (".5MB", 512 * 1024), + ("20GB", 20 * 1024**3), + ], +) +def test_memory_size(size, expect): + """Test parsing human-readable memory sizes with `memory_size`.""" + result = memory_size(size) + + assert result == expect