Auto-detect the natural gas company when parsing a gas bill.

parse_gas_bill() now accepts company=None and, in that case, searches the
bill text for a company-specific pattern (_detect_gas_company). The
per-company parsers become private (_parse_gas_bill_*), and tests are added
for detection and for the failure case.

Review adjustments relative to the patch as first posted: identity comparison
with None, a single-line ValueError message, clearer docstrings. The index
hashes below are those of the originally posted patch.

diff --git a/rules-engine/src/rules_engine/parser.py b/rules-engine/src/rules_engine/parser.py
index 54aff093..ca966e65 100644
--- a/rules-engine/src/rules_engine/parser.py
+++ b/rules-engine/src/rules_engine/parser.py
@@ -4,6 +4,7 @@
 """
 import csv
 import io
+import re
 from datetime import datetime, timedelta
 from enum import StrEnum
 
@@ -15,6 +16,19 @@ class NaturalGasCompany(StrEnum):
     NATIONAL_GRID = "national_grid"
 
 
+class _NaturalGasCompanyBillRegex:
+    """
+    Regexes to search for in a natural gas bill to determine its company.
+    """
+
+    EVERSOURCE = re.compile(
+        r"Read Date,Usage,Number of Days,Usage per day,Charge,Average Temperature"
+    )
+    NATIONAL_GRID = re.compile(
+        r"Name,.*,,,,,\nAddress,.*,,,,,\nAccount Number,.*,,,,,\nService,.*,,,,,\n"
+    )
+
+
 class _GasBillRowEversource:
     """
     Holds data for one row of an Eversource gas bill CSV.
@@ -57,20 +71,43 @@ def __init__(self, row):
         self.usage = row["USAGE"]
 
 
-def parse_gas_bill(data: str, company: NaturalGasCompany) -> NaturalGasBillingInput:
+def _detect_gas_company(data: str) -> NaturalGasCompany:
+    """
+    Return which natural gas company issued this bill.
+    """
+    if _NaturalGasCompanyBillRegex.NATIONAL_GRID.search(data):
+        return NaturalGasCompany.NATIONAL_GRID
+    elif _NaturalGasCompanyBillRegex.EVERSOURCE.search(data):
+        return NaturalGasCompany.EVERSOURCE
+    else:
+        raise ValueError(
+            "Could not detect which company this bill was from: "
+            "no known bill format matched."
+        )
+
+
+def parse_gas_bill(
+    data: str, company: NaturalGasCompany | None = None
+) -> NaturalGasBillingInput:
     """
     Parse a natural gas bill from a given natural gas company.
+
+    If no company is given, it is detected from the contents of the bill;
+    a ValueError is raised when the bill matches no known company format.
     """
+    if company is None:
+        company = _detect_gas_company(data)
+
     match company:
         case NaturalGasCompany.EVERSOURCE:
-            return parse_gas_bill_eversource(data)
+            return _parse_gas_bill_eversource(data)
         case NaturalGasCompany.NATIONAL_GRID:
-            return parse_gas_bill_national_grid(data)
+            return _parse_gas_bill_national_grid(data)
         case _:
             raise ValueError("Wrong CSV format selected: select another format.")
 
 
-def parse_gas_bill_eversource(data: str) -> NaturalGasBillingInput:
+def _parse_gas_bill_eversource(data: str) -> NaturalGasBillingInput:
     """
     Return a list of gas bill data parsed from an Eversource CSV
     received as a string.
@@ -103,7 +140,7 @@ def parse_gas_bill_eversource(data: str) -> NaturalGasBillingInput:
     return NaturalGasBillingInput(records=records)
 
 
-def parse_gas_bill_national_grid(data: str) -> NaturalGasBillingInput:
+def _parse_gas_bill_national_grid(data: str) -> NaturalGasBillingInput:
     """
     Return a list of gas bill data parsed from an National Grid CSV
     received as a string.
diff --git a/rules-engine/tests/test_rules_engine/test_engine.py b/rules-engine/tests/test_rules_engine/test_engine.py
index 19471d3b..0b4c7421 100644
--- a/rules-engine/tests/test_rules_engine/test_engine.py
+++ b/rules-engine/tests/test_rules_engine/test_engine.py
@@ -3,7 +3,6 @@
 
 import pytest
 from pytest import approx
-
 from rules_engine import engine
 from rules_engine.pydantic_models import (
     AnalysisType,
diff --git a/rules-engine/tests/test_rules_engine/test_examples.py b/rules-engine/tests/test_rules_engine/test_examples.py
index 28b51d65..9a5497c9 100644
--- a/rules-engine/tests/test_rules_engine/test_examples.py
+++ b/rules-engine/tests/test_rules_engine/test_examples.py
@@ -8,8 +8,6 @@
 import pytest
 from pydantic import BaseModel
 from pytest import approx
-from typing_extensions import Annotated
-
 from rules_engine import engine
 from rules_engine.pydantic_models import (
     NaturalGasBillingInput,
@@ -18,6 +16,7 @@
     SummaryOutput,
     TemperatureInput,
 )
+from typing_extensions import Annotated
 
 # Test inputs are provided as separate directory within the "cases/examples" directory
 # Each subdirectory contains a JSON file (named summary.json) which specifies the inputs for the test runner
diff --git a/rules-engine/tests/test_rules_engine/test_parser.py b/rules-engine/tests/test_rules_engine/test_parser.py
index c58fc865..115ecf74 100644
--- a/rules-engine/tests/test_rules_engine/test_parser.py
+++ b/rules-engine/tests/test_rules_engine/test_parser.py
@@ -1,6 +1,7 @@
 import pathlib
 from datetime import date
 
+import pytest
 from rules_engine import parser
 from rules_engine.pydantic_models import NaturalGasBillingRecordInput
 
@@ -10,13 +11,13 @@
 # of refactoring elsewhere in the codebase.
 
 
-def _read_gas_bill_eversource():
+def _read_gas_bill_eversource() -> str:
     """Read a test natural gas bill from a test Eversource CSV"""
     with open(ROOT_DIR / "feldman" / "natural-gas-eversource.csv") as f:
         return f.read()
 
 
-def _read_gas_bill_national_grid():
+def _read_gas_bill_national_grid() -> str:
     """Read a test natural gas bill from a test National Grid CSV"""
     with open(ROOT_DIR / "quateman" / "natural-gas-national-grid.csv") as f:
         return f.read()
@@ -76,11 +77,32 @@ def test_parse_gas_bill():
 
 def test_parse_gas_bill_eversource():
     """Tests parsing a natural gas bill from Eversource."""
-    _validate_eversource(parser.parse_gas_bill_eversource(_read_gas_bill_eversource()))
+    _validate_eversource(parser._parse_gas_bill_eversource(_read_gas_bill_eversource()))
 
 
 def test_parse_gas_bill_national_grid():
     """Tests parsing a natural gas bill from National Grid."""
     _validate_national_grid(
-        parser.parse_gas_bill_national_grid(_read_gas_bill_national_grid())
+        parser._parse_gas_bill_national_grid(_read_gas_bill_national_grid())
     )
+
+
+def test_detect_natural_gas_company():
+    """Tests if the natural gas company is correctly detected from the parsed csv."""
+    read_eversource = _read_gas_bill_eversource()
+    read_nationalgrid = _read_gas_bill_national_grid()
+    assert (
+        parser._detect_gas_company(read_eversource)
+        == parser.NaturalGasCompany.EVERSOURCE
+    )
+    assert (
+        parser._detect_gas_company(read_nationalgrid)
+        == parser.NaturalGasCompany.NATIONAL_GRID
+    )
+
+
+def test_detect_natural_gas_company_with_error():
+    """Tests if an error is raised if the natural gas company is incorrect in the csv."""
+    read_csv_string = r"Some bogus string input"
+    with pytest.raises(ValueError):
+        parser._detect_gas_company(read_csv_string)