Skip to content

Commit

Permalink
Fixes checks (#232)
Browse files Browse the repository at this point in the history
* some stray linting problems
* incorporates @undfined's GH action fixes
* mypy regressions 
  -> pinned version as not doing so creates a lot of churn 
  -> missing explicit type package deps 
  -> some method signature overrides needed to be ignored
  • Loading branch information
cmwilhelm authored Feb 13, 2025
1 parent a824220 commit d42fa1d
Show file tree
Hide file tree
Showing 8 changed files with 31 additions and 28 deletions.
12 changes: 8 additions & 4 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,9 @@ jobs:
sudo apt-get update
sudo apt-get install --yes --upgrade build-essential cmake protobuf-compiler libssl-dev glibc-source musl-tools
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4.4
with:
overwrite: true
name: wheels
path: dist

Expand All @@ -227,8 +228,9 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --find-interpreter
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4.4
with:
overwrite: true
name: wheels
path: dist

Expand All @@ -251,8 +253,9 @@ jobs:
target: ${{ matrix.target }}
args: --release --out dist --find-interpreter
- name: Upload wheels
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4.4
with:
overwrite: true
name: wheels
path: dist

Expand All @@ -268,8 +271,9 @@ jobs:
command: sdist
args: --out dist
- name: Upload sdist
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4.4
with:
overwrite: true
name: wheels
path: dist

Expand Down
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,17 @@ dolma = "dolma.cli.__main__:main"

[project.optional-dependencies]
dev = [
"Flake8-pyproject>=1.1.0",
"black>=22.6.0",
"flake8>=5.0",
"flake8-pyi>=22.8.1",
"Flake8-pyproject>=1.1.0",
"flake8>=5.0",
"ipdb>=0.13.0",
"ipython>=8.4.0",
"isort>=5.10.1",
"mypy>=0.971",
"mypy==0.971",
"pytest>=5.2",
"types-PyYAML",
"types-dateparser",
]
# extension to process code
code = ["detect-secrets==1.4.0", "beautifulsoup4>=4", "pygments", "regex"]
Expand Down
2 changes: 1 addition & 1 deletion python/dolma/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def make_parser(parser: A, config: Type[DataClass], prefix: Optional[str] = None
# here's where we check if T is a dataclass
if is_dataclass(typ_):
# recursively add subparsers
make_parser(parser, typ_, prefix=field_name) # type: ignore
make_parser(parser, typ_, prefix=field_name)
continue

if typ_ is bool:
Expand Down
17 changes: 8 additions & 9 deletions python/dolma/cli/mixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,14 @@ def run(cls, parsed_config: MixerConfig):
)

# TODO: note that we are not using the syntax here yet; adding it later
stream_config_dict.setdefault("span_replacement", []).append(
{
"span": str(span_replacement.span),
"replacement": str(span_replacement.replacement),
"syntax": span_replacement.syntax,
**min_score_config,
**max_score_config,
}
)
span_replacement_dict: Dict[str, Any] = {
"span": str(span_replacement.span),
"replacement": str(span_replacement.replacement),
"syntax": span_replacement.syntax,
**min_score_config,
**max_score_config,
}
stream_config_dict.setdefault("span_replacement", []).append(span_replacement_dict)

if "span_replacement" not in stream_config_dict and "filter" not in stream_config_dict:
raise DolmaConfigError("Either `filter` or `span_replacement` must be specified")
Expand Down
4 changes: 2 additions & 2 deletions python/dolma/core/ft_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, model_path: str, model_mode: str) -> None:
self.mode = model_mode

@classmethod
def train(
def train( # type: ignore[override]
cls,
train_file: str,
save_path: str,
Expand Down Expand Up @@ -120,7 +120,7 @@ def train(
return classifier

@classmethod
def test(
def test( # type: ignore[override]
cls,
test_file: str,
model_path: Optional[str] = None,
Expand Down
6 changes: 3 additions & 3 deletions python/dolma/taggers/punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def __init__(self) -> None:
r"[[:punct:]]|"
r"\s|"
r"["
"\U0001F300-\U0001F64F"
"\U0001F680-\U0001F6FF"
"\u2600-\u26FF\u2700-\u27BF"
"\U0001f300-\U0001f64f"
"\U0001f680-\U0001f6ff"
"\u2600-\u26ff\u2700-\u27bf"
r"]+"
r")+$",
regex.UNICODE,
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,14 @@ def test_paragraph_with_doc_score(self):


class BaseMultilingualTaggerTest(BaseEnglishTaggerTest):
def test_document(self):
def test_document(self) -> None:
for doc in self.single_paragraph_docs:
result = self.doc_tagger.predict(doc)
best_lang = max(result.spans, key=lambda s: s.score)
self.assertEqual(best_lang.type, doc.id)
self.assertGreater(best_lang.score, 0.7)

def test_paragraph(self):
def test_paragraph(self) -> None:
for doc in self.multi_paragraph_docs:
result = self.par_tagger.predict(doc)
languages = doc.id.split("_")
Expand All @@ -152,7 +152,7 @@ def test_paragraph(self):
self.assertGreater(best_lang.score, 0.7)
self.assertEqual(best_lang.type, lang)

def test_paragraph_with_doc_score(self):
def test_paragraph_with_doc_score(self) -> None:
return


Expand Down
4 changes: 1 addition & 3 deletions tests/python/test_parallel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# mypy: disable-error-code="unused-ignore"

import os
from pathlib import Path
from tempfile import TemporaryDirectory
Expand All @@ -15,7 +13,7 @@

class MockProcessor(BaseParallelProcessor):
@classmethod
def increment_progressbar(cls, queue, /, cnt: int = 0): # type: ignore[override]
def increment_progressbar(cls, queue, /, cnt: int = 0):
return super().increment_progressbar(queue, cnt=cnt)

@classmethod
Expand Down

0 comments on commit d42fa1d

Please sign in to comment.