diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 838abff0..7e423e54 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -202,8 +202,9 @@ jobs: sudo apt-get update sudo apt-get install --yes --upgrade build-essential cmake protobuf-compiler libssl-dev glibc-source musl-tools - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4 with: + overwrite: true name: wheels path: dist @@ -227,8 +228,9 @@ jobs: target: ${{ matrix.target }} args: --release --out dist --find-interpreter - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4 with: + overwrite: true name: wheels path: dist @@ -251,8 +253,9 @@ jobs: target: ${{ matrix.target }} args: --release --out dist --find-interpreter - name: Upload wheels - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4 with: + overwrite: true name: wheels path: dist @@ -268,8 +271,9 @@ jobs: command: sdist args: --out dist - name: Upload sdist - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4.4 with: + overwrite: true name: wheels path: dist diff --git a/pyproject.toml b/pyproject.toml index a4957551..400bba6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,15 +103,17 @@ dolma = "dolma.cli.__main__:main" [project.optional-dependencies] dev = [ + "Flake8-pyproject>=1.1.0", "black>=22.6.0", - "flake8>=5.0", "flake8-pyi>=22.8.1", - "Flake8-pyproject>=1.1.0", + "flake8>=5.0", "ipdb>=0.13.0", "ipython>=8.4.0", "isort>=5.10.1", - "mypy>=0.971", + "mypy==0.971", "pytest>=5.2", + "types-PyYAML", + "types-dateparser", ] # extension to process code code = ["detect-secrets==1.4.0", "beautifulsoup4>=4", "pygments", "regex"] diff --git a/python/dolma/cli/__init__.py b/python/dolma/cli/__init__.py index 2ef221d7..de6349cd 100644 --- a/python/dolma/cli/__init__.py +++ b/python/dolma/cli/__init__.py @@ -93,7 +93,7 @@ def make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None # here's where we check if T is a dataclass if is_dataclass(typ_): # recursively add subparsers - make_parser(parser, typ_, prefix=field_name) # type: ignore + make_parser(parser, typ_, prefix=field_name) continue if typ_ is bool: diff --git a/python/dolma/cli/mixer.py b/python/dolma/cli/mixer.py index 943d7f74..136c5344 100644 --- a/python/dolma/cli/mixer.py +++ b/python/dolma/cli/mixer.py @@ -128,15 +128,14 @@ def run(cls, parsed_config: MixerConfig): ) # TODO: note that we are not using the syntax here yet; adding it later - stream_config_dict.setdefault("span_replacement", []).append( - { - "span": str(span_replacement.span), - "replacement": str(span_replacement.replacement), - "syntax": span_replacement.syntax, - **min_score_config, - **max_score_config, - } - ) + span_replacement_dict: Dict[str, Any] = { + "span": str(span_replacement.span), + "replacement": str(span_replacement.replacement), + "syntax": span_replacement.syntax, + **min_score_config, + **max_score_config, + } + stream_config_dict.setdefault("span_replacement", []).append(span_replacement_dict) if "span_replacement" not in stream_config_dict and "filter" not in stream_config_dict: raise DolmaConfigError("Either `filter` or `span_replacement` must be specified") diff --git a/python/dolma/core/ft_tagger.py b/python/dolma/core/ft_tagger.py index 53bcc031..728851dd 100644 --- a/python/dolma/core/ft_tagger.py +++ b/python/dolma/core/ft_tagger.py @@ -37,7 +37,7 @@ def __init__(self, model_path: str, model_mode: str) -> None: self.mode = model_mode @classmethod - def train( + def train( # type: ignore[override] cls, train_file: str, save_path: str, @@ -120,7 +120,7 @@ def train( return classifier @classmethod - def test( + def test( # type: ignore[override] cls, test_file: str, model_path: Optional[str] = None, diff --git a/python/dolma/taggers/punctuation.py b/python/dolma/taggers/punctuation.py index d5526d2a..3eafae36 100644 --- a/python/dolma/taggers/punctuation.py +++ b/python/dolma/taggers/punctuation.py @@ -15,9 +15,9 @@ def __init__(self) -> None: r"[[:punct:]]|" r"\s|" r"[" - "\U0001F300-\U0001F64F" - "\U0001F680-\U0001F6FF" - "\u2600-\u26FF\u2700-\u27BF" + "\U0001f300-\U0001f64f" + "\U0001f680-\U0001f6ff" + "\u2600-\u26ff\u2700-\u27bf" r"]+" r")+$", regex.UNICODE, diff --git a/tests/python/test_language.py b/tests/python/test_language.py index b4331868..514612f6 100644 --- a/tests/python/test_language.py +++ b/tests/python/test_language.py @@ -129,14 +129,14 @@ def test_paragraph_with_doc_score(self): class BaseMultilingualTaggerTest(BaseEnglishTaggerTest): - def test_document(self): + def test_document(self) -> None: for doc in self.single_paragraph_docs: result = self.doc_tagger.predict(doc) best_lang = max(result.spans, key=lambda s: s.score) self.assertEqual(best_lang.type, doc.id) self.assertGreater(best_lang.score, 0.7) - def test_paragraph(self): + def test_paragraph(self) -> None: for doc in self.multi_paragraph_docs: result = self.par_tagger.predict(doc) languages = doc.id.split("_") @@ -152,7 +152,7 @@ def test_paragraph(self): self.assertGreater(best_lang.score, 0.7) self.assertEqual(best_lang.type, lang) - def test_paragraph_with_doc_score(self): + def test_paragraph_with_doc_score(self) -> None: return diff --git a/tests/python/test_parallel.py b/tests/python/test_parallel.py index 1287247a..9f644ce3 100644 --- a/tests/python/test_parallel.py +++ b/tests/python/test_parallel.py @@ -1,5 +1,3 @@ -# mypy: disable-error-code="unused-ignore" - import os from pathlib import Path from tempfile import TemporaryDirectory @@ -15,7 +13,7 @@ class MockProcessor(BaseParallelProcessor): @classmethod - def increment_progressbar(cls, queue, /, cnt: int = 0): # type: ignore[override] + def increment_progressbar(cls, queue, /, cnt: int = 0): return super().increment_progressbar(queue, cnt=cnt) @classmethod