Skip to content

Commit

Permalink
Add ruff formating, support numpy 2 (#126)
Browse files Browse the repository at this point in the history
* [pre-commit.ci] pre-commit autoupdate

updates:
- [github.com/PyCQA/flake8: 7.0.0 → 7.1.0](PyCQA/flake8@7.0.0...7.1.0)

* maint: add ruff

* docs: add docs deps

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
steven-murray and pre-commit-ci[bot] authored Jun 27, 2024
1 parent 2a7e64e commit 75c55ec
Show file tree
Hide file tree
Showing 31 changed files with 314 additions and 426 deletions.
44 changes: 9 additions & 35 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,45 +16,19 @@ repos:
- id: mixed-line-ending
args: ['--fix=no']

- repo: https://github.com/PyCQA/flake8
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
- flake8-comprehensions
- flake8-logging-format
- flake8-builtins
- flake8-eradicate
- pep8-naming
- flake8-pytest
- flake8-docstrings
- flake8-rst-docstrings
- flake8-rst
- flake8-copyright
# - flake8-ownership
- flake8-markdown
- flake8-bugbear
- flake8-comprehensions
- flake8-print


- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
hooks:
- id: black

- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-backticks

- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.10
hooks:
- id: pyupgrade
args: [--py38-plus]
# Run the linter.
- id: ruff
args: [--fix]

# Run the formatter.
- id: ruff-format
14 changes: 6 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
"""Configuration for docs."""

import os
from datetime import datetime
from datetime import datetime, timezone

from py21cmsense import __version__

extensions = [
Expand Down Expand Up @@ -35,9 +35,9 @@
source_suffix = ".rst"
master_doc = "index"
project = "21cmSense"
year = str(datetime.now().year)
year = str(datetime.now(tz=timezone.utc).year)
author = "Jonathan Pober and Steven Murray"
copyright = "{0}, {1}".format(year, author)
copyright = f"{year}, {author}"
version = release = __version__
templates_path = ["templates"]

Expand Down Expand Up @@ -68,9 +68,7 @@
napoleon_use_rtype = False
napoleon_use_param = False

mathjax_path = (
"https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"
)
mathjax_path = "https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"

exclude_patterns = [
"_build",
Expand Down
58 changes: 57 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dynamic = ['version']

# Add here dependencies of your project (semicolon/line-separated), e.g.
dependencies = [
"numpy",
"numpy<2.0", # Restriction can be lifted once pyuvdata is good with numpy 2
"scipy",
"future",
"click",
Expand Down Expand Up @@ -64,6 +64,7 @@ dev = [
"21cmSense[docs,test]",
"pre-commit",
"commitizen",
"ruff"
]

[project.scripts]
Expand All @@ -72,6 +73,61 @@ sense = "py21cmsense.cli:main"

[tool.setuptools_scm]

[tool.ruff]
line-length = 100
target-version = "py39"

[tool.ruff.lint]
extend-select = [
"UP", # pyupgrade
"E", # pycodestyle
"W", # pycodestyle warning
"C90", # mccabe complexity
"I", # isort
"N", # pep8-naming
"D", # docstyle
"B", # bugbear
"A", # builtins
"C4", # comprehensions
"DTZ", # datetime
"FA", # future annotations
"PIE", # flake8-pie
"T", # print statements
"PT", # pytest-style
"Q", # quotes
"SIM", # simplify
# "PTH", # use Pathlib
"ERA", # kill commented code
"NPY", # numpy-specific rules
"PERF", # performance
"RUF", # ruff-specific rules
]
ignore = [
"DTZ007", # use %z in strptime
"A003", # class attribute shadows python builtin
"B008", # function call in argument defaults
"N802", # TODO: remove this (function name should be lower-case)
"B019", # no using lru_cache on methods
"G004", # logging uses f-string
"D107", # no docstring in __init__
]
[tool.ruff.lint.per-file-ignores]
"tests/*.py" = [
"D103", # ignore missing docstring in tests
"DTZ", # ignore datetime in tests
"T", # print statements
]
"docs/conf.py" = [
"A", # conf.py can shadow builtins
"ERA",
]

[tool.ruff.lint.pydocstyle]
convention = 'numpy'
property-decorators = ['property', 'functools.cached_property']

[tool.ruff.lint.mccabe]
max-complexity = 15

[tool.pytest.ini_options]
# Options for py.test:
Expand Down
11 changes: 11 additions & 0 deletions src/py21cmsense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
finally:
del version, PackageNotFoundError

__all__ = [
"data",
"theory",
"yaml",
"hera",
"BaselineRange",
"GaussianBeam",
"Observation",
"Observatory",
"PowerSpectrum",
]
from . import data, theory, yaml
from .antpos import hera
from .baseline_filters import BaselineRange
Expand Down
12 changes: 8 additions & 4 deletions src/py21cmsense/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from astropy.time import Time
from pyuvdata import utils as uvutils

from . import config


def between(xmin, xmax):
"""Return an attrs validation function that checks a number is within bounds."""
Expand All @@ -19,12 +17,18 @@ def validator(instance, att, val):


def positive(instance, att, x):
"""An attrs validator that checks a value is positive."""
"""Check that a value is positive.
This is an attrs validator.
"""
assert x > 0, "must be positive"


def nonnegative(instance, att, x):
"""An attrs validator that checks a value is non-negative."""
"""Check that a value is non-negative.
This is an attrs validator.
"""
assert x >= 0, "must be non-negative"


Expand Down
12 changes: 2 additions & 10 deletions src/py21cmsense/antpos.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import numpy as np
from astropy import units as un
from typing import Sequence

from . import units as tp
from . import yaml
Expand Down Expand Up @@ -52,10 +51,7 @@ def hera(

sep = separation.to_value("m")

if row_separation is None:
row_sep = sep * np.sqrt(3) / 2
else:
row_sep = row_separation.to_value("m")
row_sep = sep * np.sqrt(3) / 2 if row_separation is None else row_separation.to_value("m")

# construct the main hexagon
positions = []
Expand Down Expand Up @@ -97,11 +93,7 @@ def hera(
exterior_hex_num = outriggers + 2
for row in range(exterior_hex_num - 1, -exterior_hex_num, -1):
for col in range(2 * exterior_hex_num - abs(row) - 1):
x_pos = (
((2 - (2 * exterior_hex_num - abs(row))) / 2 + col)
* sep
* (hex_num - 1)
)
x_pos = ((2 - (2 * exterior_hex_num - abs(row))) / 2 + col) * sep * (hex_num - 1)
y_pos = row * (hex_num - 1) * row_sep
theta = np.arctan2(y_pos, x_pos)
if np.sqrt(x_pos**2 + y_pos**2) > sep * (hex_num + 1):
Expand Down
20 changes: 6 additions & 14 deletions src/py21cmsense/baseline_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
"""

import abc

import attr
import numpy as np
import warnings
from astropy import units as un
from pathlib import Path

from . import units as tp

Expand Down Expand Up @@ -44,29 +43,22 @@ def __call__(self, bl: tp.Length) -> bool:
bool
True if the baseline should be included.
"""
pass # pragma: no cover
# pragma: no cover


@attr.define
class BaselineRange(BaselineFilter):
"""Theory model from EOS2021 (https://arxiv.org/abs/2110.13919)."""

bl_min: tp.Length = attr.field(
default=0 * un.m, validator=tp.vld_physical_type("length")
)
bl_max: tp.Length = attr.field(
default=np.inf * un.m, validator=tp.vld_physical_type("length")
)
direction: str = attr.field(
default="mag", validator=attr.validators.in_(("ew", "ns", "mag"))
)
bl_min: tp.Length = attr.field(default=0 * un.m, validator=tp.vld_physical_type("length"))
bl_max: tp.Length = attr.field(default=np.inf * un.m, validator=tp.vld_physical_type("length"))
direction: str = attr.field(default="mag", validator=attr.validators.in_(("ew", "ns", "mag")))

@bl_max.validator
def _bl_max_vld(self, att, val):
if val <= self.bl_min:
raise ValueError(
"bl_max must be greater than bl_min, got "
f"bl_min={self.bl_min} and bl_max={val}"
"bl_max must be greater than bl_min, got " f"bl_min={self.bl_min} and bl_max={val}"
)

def __call__(self, bl: tp.Length) -> bool:
Expand Down
27 changes: 13 additions & 14 deletions src/py21cmsense/beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations

from abc import ABCMeta, abstractmethod

import attr
from abc import ABCMeta, abstractmethod, abstractproperty
from astropy import constants as cnst
from astropy import units as un
from hickleable import hickleable
Expand Down Expand Up @@ -37,25 +38,25 @@ def at(self, frequency: tp.Frequency) -> PrimaryBeam:
"""Get a copy of the object at a new frequency."""
return attr.evolve(self, frequency=frequency)

@abstractproperty
@property
@abstractmethod
def area(self) -> un.Quantity[un.steradian]:
"""Beam area [units: sr]."""
pass

@abstractproperty
@property
@abstractmethod
def width(self) -> un.Quantity[un.radians]:
"""Beam width [units: rad]."""
pass

@abstractproperty
@property
@abstractmethod
def first_null(self) -> un.Quantity[un.radians]:
"""An approximation of the first null of the beam."""
pass

@abstractproperty
@property
@abstractmethod
def sq_area(self) -> un.Quantity[un.steradian]:
"""The area of the beam^2."""
pass

@property
def b_eff(self) -> un.Quantity[un.steradian]:
Expand All @@ -65,10 +66,10 @@ def b_eff(self) -> un.Quantity[un.steradian]:
"""
return self.area**2 / self.sq_area

@abstractproperty
@property
@abstractmethod
def uv_resolution(self) -> un.Quantity[1 / un.radians]:
"""The UV footprint of the beam."""
pass

@classmethod
def from_uvbeam(cls) -> PrimaryBeam:
Expand All @@ -91,9 +92,7 @@ class GaussianBeam(PrimaryBeam):
otherwise defined. This generates the beam size.
"""

dish_size: tp.Length = attr.ib(
validator=(tp.vld_physical_type("length"), ut.positive)
)
dish_size: tp.Length = attr.ib(validator=(tp.vld_physical_type("length"), ut.positive))

@property
def wavelength(self) -> un.Quantity[un.m]:
Expand Down
Loading

0 comments on commit 75c55ec

Please sign in to comment.