Merge pull request #389 from gro-intelligence/GAIA-29731
GAIA-29731 Add client function for v2 area weighting service
cn1036 authored Jan 5, 2024
2 parents a53cfc5 + 8e448ab commit fcca87d
Showing 4 changed files with 339 additions and 1 deletion.
58 changes: 58 additions & 0 deletions groclient/client.py
@@ -4,6 +4,7 @@
import json
import os
import time
import pandas as pd

from typing import Dict, List, Optional, Union

@@ -1990,3 +1991,60 @@ def reverse_geocode_points(self, points: list):
```
"""
return lib.reverse_geocode_points(self.access_token, self.api_host, points)

def get_area_weighted_series_df(
self,
series: Dict[str, int],
region_id: int,
weights: Optional[List[Dict[str, int]]] = None,
weight_names: Optional[List[str]] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
method: Optional[str] = "sum",
) -> pd.DataFrame:
"""Compute weighted average on selected series with the given weights.
Returns a dataframe that contains weighted values and metadata.
Parameters
----------
series: dict
A dictionary that maps required entity types to its value.
e.g. {"item_id": 321, "metric_id": 70029, "frequency_id": 3, "source_id": 3}
region_id: integer
The region for which the weighted series will be computed
Supported region levels are (1, 2, 3, 4, 5, 8)
weights: list of dict
A list of dictionaries with each representing a weight object. Mutually exclusive with "weight_names".
e.g. [{"item_id": 274, "metric_id": 2120001, "frequency_id": 15, "source_id": 88}, ...]
weight_names: list of strs
List of weight names that will be used to weight the provided series. Mutually exclusive with "weights".
e.g. ['Barley', 'Corn']
For getting the full list of valid weight names, please call :meth:`~.get_area_weighting_weight_names`
start_date: str, optional
A timestamp of the format 'YYYY-MM-DD', e.g. '2023-01-01'
end_date: str, optional
A timestamp of the format 'YYYY-MM-DD', e.g. '2023-01-01'
method: str, optional, default="sum"
Multi-crop weights can be calculated with either 'sum' or 'normalize' method.
Returns
-------
DataFrame
Example::
start_date value end_date available_date region_id item_id metric_id frequency_id unit_id source_id weights
0 2016-04-26 0.502835 2016-04-26 2016-04-28 1215 321 70029 1 189 112 [{"weight_name": "Corn", "item_id": 274, "metric_id": 2120001, ...}, ...]
1 2016-04-27 0.509729 2016-04-27 2016-04-29 1215 321 70029 1 189 112 [{"weight_name": "Corn", "item_id": 274, "metric_id": 2120001, ...}, ...]
"""
return lib.get_area_weighted_series_df(
self.access_token,
self.api_host,
series,
region_id,
weights,
weight_names,
start_date,
end_date,
method,
)
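
For reference, a minimal usage sketch of the new client method. The host, token, and entity IDs below are illustrative (the IDs are copied from the docstring example above), not prescriptive:

from groclient import GroClient

# Hypothetical credentials; substitute your own API host and access token.
client = GroClient("api.gro-intelligence.com", "<YOUR_ACCESS_TOKEN>")

# Weight the selected series over region 1215 using Barley and Corn area weights.
df = client.get_area_weighted_series_df(
    series={"item_id": 321, "metric_id": 70029, "frequency_id": 3, "source_id": 3},
    region_id=1215,
    weight_names=["Barley", "Corn"],
    start_date="2023-01-01",
    end_date="2023-12-31",
    method="sum",
)
print(df[["start_date", "value"]].head())
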
56 changes: 56 additions & 0 deletions groclient/client_test.py
@@ -6,6 +6,8 @@
from mock import patch, MagicMock
from datetime import date
import os
import pandas as pd
from pandas.testing import assert_frame_equal
from unittest import TestCase

from groclient import GroClient
@@ -308,6 +310,29 @@ def mock_reverse_geocode_points(access_token, api_host, points):
]


def mock_get_area_weighted_series_df(
access_token,
api_host,
series,
region_id,
weights,
weight_names,
start_date,
end_date,
method,
):
return pd.DataFrame(
data=[
[0.42, "2023-10-07", "2023-10-07", "2023-10-08", 321, 70029, 3, 189, 3, '[{"item_id": 95, "metric_id": 2120001, "frequency_id": 15, "unit_id": 42, "source_id": 88, "weight_name": "Wheat"}]'],
[0.12, "2023-10-08", "2023-10-08", "2023-10-09", 321, 70029, 3, 189, 3, '[{"item_id": 95, "metric_id": 2120001, "frequency_id": 15, "unit_id": 42, "source_id": 88, "weight_name": "Wheat"}]'],
],
columns=[
'value', 'start_date', 'end_date', 'available_date', 'item_id',
'metric_id', 'frequency_id', 'unit_id', 'source_id', 'weights_metadata'
]
)


@patch("groclient.lib.get_available", MagicMock(side_effect=mock_get_available))
@patch("groclient.lib.list_available", MagicMock(side_effect=mock_list_available))
@patch("groclient.lib.lookup", MagicMock(side_effect=mock_lookup))
@@ -345,6 +370,10 @@ def mock_reverse_geocode_points(access_token, api_host, points):
"groclient.lib.reverse_geocode_points",
MagicMock(side_effect=mock_reverse_geocode_points),
)
@patch(
"groclient.lib.get_area_weighted_series_df",
MagicMock(side_effect=mock_get_area_weighted_series_df),
)
class GroClientTests(TestCase):
def setUp(self):
self.client = GroClient(MOCK_HOST, MOCK_TOKEN)
@@ -688,6 +717,33 @@ def test_reverse_geocode_points(self):
],
)

def test_get_area_weighted_series_df(self):
selection = {
"series": {
"metric_id": 70029,
"item_id": 321,
"frequency_id": 3,
"source_id": 3
},
"region_id": 1215,
"weight_names": ["Wheat"],
"start_date": "2023-10-01"
}
expected = pd.DataFrame(
data=[
[0.42, "2023-10-07", "2023-10-07", "2023-10-08", 321, 70029, 3, 189, 3, '[{"item_id": 95, "metric_id": 2120001, "frequency_id": 15, "unit_id": 42, "source_id": 88, "weight_name": "Wheat"}]'],
[0.12, "2023-10-08", "2023-10-08", "2023-10-09", 321, 70029, 3, 189, 3, '[{"item_id": 95, "metric_id": 2120001, "frequency_id": 15, "unit_id": 42, "source_id": 88, "weight_name": "Wheat"}]'],
],
columns=[
'value', 'start_date', 'end_date', 'available_date', 'item_id',
'metric_id', 'frequency_id', 'unit_id', 'source_id', 'weights_metadata'
]
)
assert_frame_equal(
self.client.get_area_weighted_series_df(**selection),
expected
)


class GroClientConstructorTests(TestCase):
PROD_API_HOST = "api.gro-intelligence.com"
122 changes: 121 additions & 1 deletion groclient/lib.py
@@ -20,9 +20,10 @@
import time
import platform
import warnings
import pandas as pd

from pkg_resources import get_distribution, DistributionNotFound
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Union, Any

try:
# functools are native in Python 3.2.3+
@@ -912,6 +913,125 @@ def reverse_geocode_points(access_token: str, api_host: str, points: list):
return r.json()["data"]


def validate_series_object(series_object):
required_entities = {"item_id", "metric_id", "frequency_id", "source_id"}

for entity in required_entities:
if entity not in series_object or not series_object[entity]:
raise ValueError(f"{entity} is required and supposed to be a positive integer.")

invalid_atts = set(series_object.keys()).difference(required_entities)
if len(invalid_atts):
raise ValueError(f"Unsupported fields: {invalid_atts}.")


def generate_payload_for_v2_area_weighting(
series: Dict[str, int],
region_id: int,
weights: Optional[List[Dict[str, int]]] = None,
weight_names: Optional[List[str]] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
method: Optional[str] = "sum",
):
payload = {
"region_id": region_id,
"method": method,
}

# validate series and weights selection
try:
validate_series_object(series)
except ValueError as error:
raise ValueError(f"Failed to parse series selection: {error}")
payload["series"] = series

if weights and len(weights):
if weight_names:
raise ValueError(f"weights and weight_names are mutually exclusive. Please specify only one.")
try:
for weight in weights:
validate_series_object(weight)
except ValueError as error:
raise ValueError(f"Failed to parse weight selections: {error}")
payload["weights"] = weights
else:
if not weight_names or not len(weight_names):
raise ValueError(f"Please specify either weights or weight_names in params.")
payload["weight_names"] = weight_names

# add optional attrs
if start_date:
payload["start_date"] = start_date
if end_date:
payload["end_date"] = end_date

return json.dumps(payload)
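
As a sketch, the JSON payload this helper produces for a weight_names-based selection (the IDs are illustrative, reused from the docstring example):

payload = generate_payload_for_v2_area_weighting(
    series={"item_id": 321, "metric_id": 70029, "frequency_id": 3, "source_id": 3},
    region_id=1215,
    weight_names=["Barley", "Corn"],
    start_date="2023-01-01",
)
# payload is a JSON string equivalent to:
# {"region_id": 1215, "method": "sum",
#  "series": {"item_id": 321, "metric_id": 70029, "frequency_id": 3, "source_id": 3},
#  "weight_names": ["Barley", "Corn"],
#  "start_date": "2023-01-01"}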


def format_v2_area_weighting_response(response_content: Dict[str, Any]) -> pd.DataFrame:
try:
data_points = response_content["data_points"]
weighted_series_df = pd.DataFrame(data_points)

if not len(weighted_series_df):
return weighted_series_df

# convert unix timestamps and rename date cols
datetime_col_mappings = {
"start_date": "timestamp", # add start_date col which is equivalent to end_date
"end_date": "timestamp",
"available_date": "available_timestamp",
}
for new_col, col in datetime_col_mappings.items():
weighted_series_df[new_col] = pd.to_datetime(weighted_series_df[col], unit='s').dt.strftime('%Y-%m-%d')
weighted_series_df = weighted_series_df.drop(columns=datetime_col_mappings.values())

# append selected fields of series metadata
for key in ['item_id', 'metric_id', 'frequency_id', 'unit_id', 'source_id']:
weighted_series_df[key] = response_content["series_description"][key]

# append weights metadata as a single json
weighted_series_df["weights_metadata"] = json.dumps(response_content["weights_description"])

return weighted_series_df
except KeyError as key:
raise Exception(f"Bad Implementation Error: missing {key} in API response.")


def get_area_weighted_series_df(
access_token: str,
api_host: str,
series: Dict[str, int],
region_id: int,
weights: Optional[List[Dict[str, int]]] = None,
weight_names: Optional[List[str]] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
method: Optional[str] = "sum",
) -> pd.DataFrame:
payload = generate_payload_for_v2_area_weighting(
series,
region_id,
weights,
weight_names,
start_date,
end_date,
method
)
response = requests.post(
f"https://{api_host}/v2/area-weighting",
data=payload,
headers={"Authorization": "Bearer " + access_token},
)

if response.status_code != 200:
raise Exception(response.text)

weighted_series_df = format_v2_area_weighting_response(response.json())
return weighted_series_df
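
The client method shown earlier delegates to this function; a direct lib-level call is equivalent. A sketch with illustrative arguments, including handling of the errors raised above:

from groclient import lib

try:
    df = lib.get_area_weighted_series_df(
        "<YOUR_ACCESS_TOKEN>",              # hypothetical token
        "api.gro-intelligence.com",
        series={"item_id": 321, "metric_id": 70029, "frequency_id": 3, "source_id": 3},
        region_id=1215,
        weight_names=["Barley"],
    )
except ValueError as err:
    print(f"Invalid selection: {err}")      # raised locally by payload validation
except Exception as err:
    print(f"API request failed: {err}")     # non-200 responses surface the response text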


if __name__ == "__main__":
# To run doctests:
# $ python lib.py -v