Skip to content

Commit

Permalink
_get_common_dtype must check PintTypes compat
Browse files Browse the repository at this point in the history
PR number 137 adds _get_common_dtype to PintType so that `PintType` operations can be performed on a mix of `PintType` and numeric values (with the later being promoted to the `PintType` for the purposes of the operation).  However, when there are multiple `PintType` elements present, it is important that all elements are in fact compatible, lest the operation attempt to combine two `PintType` elements that are not unit-compatible.
  • Loading branch information
MichaelTiemannOSC committed Jul 19, 2024
1 parent f205835 commit 32c9d64
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pint_pandas/pint_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def _get_common_dtype(self, dtypes):
In order to be able to be able to perform operation on ``PintType``
with scalars, mix of ``PintType`` and numeric values are allowed.
But all ``PintType`` elements must be compatible.
Parameters
Expand All @@ -216,7 +217,13 @@ def _get_common_dtype(self, dtypes):
if all(
isinstance(x, PintType) or pd.api.types.is_numeric_dtype(x) for x in dtypes
):
return self
PintType_list = [x for x in dtypes if isinstance(x, PintType)]
if len(PintType_list) < 2:
return self
if all (PintType_list[0].units.is_compatible_with(x.units) for x in PintType_list[1:]):
return self
else:
return None
else:
return None

Expand Down
11 changes: 11 additions & 0 deletions pint_pandas/testsuite/test_issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,3 +277,14 @@ def test_eval(self):
)
tm.assert_series_equal(df.eval("a / b"), df["a"] / df["b"])
tm.assert_series_equal(df.eval("a / c"), df["a"] / df["c"])

def test_mixed_df(self):
df = pd.DataFrame(
{
"a": pd.Series([1.0, 2.0, 3.0], dtype="pint[meter]"),
"b": pd.Series([4.0, 5.0, 6.0], dtype="pint[second]"),
"c": [1.0, 2.0, 3.0],
}
)

assert df["a"][0] == df.iloc[0][0]

0 comments on commit 32c9d64

Please sign in to comment.