Merge pull request #28 from NCAR/self_aware
obs_sequence object is aware of whether it has assimilation results or not
hkershaw-brown authored Dec 19, 2024
2 parents ea68901 + f8542cc commit 9d2f70b
Showing 2 changed files with 154 additions and 62 deletions.
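
In practice, the change means an obs_seq without prior or posterior ensemble copies now fails fast with a clear ValueError instead of failing later inside pandas. A minimal sketch of the intended use, with a placeholder file path:

from pydartdiags.obs_sequence import obs_sequence as obsq

obs_seq = obsq.obs_sequence("obs_seq.final")      # placeholder path to an obs_seq.final file
if obs_seq.has_assimilation_info:
    print(obs_seq.possible_vs_used())             # guarded: needs prior ensemble copies
else:
    print("no assimilation info in this obs sequence")
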
153 changes: 92 additions & 61 deletions src/pydartdiags/obs_sequence/obs_sequence.py
@@ -5,6 +5,23 @@
import yaml
import struct

def requires_assimilation_info(func):
def wrapper(self, *args, **kwargs):
if self.has_assimilation_info:
return func(self, *args, **kwargs)
else:
raise ValueError("Assimilation information is required to call this function.")
return wrapper

def requires_posterior_info(func):
def wrapper(self, *args, **kwargs):
if self.has_posterior_info:
return func(self, *args, **kwargs)
else:
raise ValueError("Posterior information is required to call this function.")
return wrapper
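
Both decorators are applied to bound methods, so self is in scope inside wrapper and the flag can be checked before the wrapped call is dispatched. A minimal sketch of the pattern on a hypothetical toy class:

from pydartdiags.obs_sequence import obs_sequence as obsq

class Toy:
    def __init__(self, has_assimilation_info):
        self.has_assimilation_info = has_assimilation_info

    @obsq.requires_assimilation_info
    def rmse(self):
        return 0.0

print(Toy(True).rmse())     # 0.0
try:
    Toy(False).rmse()
except ValueError as err:
    print(err)              # Assimilation information is required to call this function.
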


class obs_sequence:
"""Create an obs_sequence object from an ascii observation sequence file.
@@ -59,6 +76,8 @@ class obs_sequence:

def __init__(self, file, synonyms=None):
self.loc_mod = 'None'
self.has_assimilation_info = False
self.has_posterior_info = False
self.file = file
self.synonyms_for_obs = ['NCEP BUFR observation',
'AIRS observation',
@@ -72,6 +91,17 @@ def __init__(self, file, synonyms=None):
else:
self.synonyms_for_obs.append(synonyms)

if file is None:
# Early exit for testing purposes
self.df = pd.DataFrame()
self.types = {}
self.reverse_types = {}
self.copie_names = []
self.n_copies = 0
self.seq = []
self.all_obs = []
return
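
The None path is what the new test fixture at the bottom of this diff relies on: it yields an empty object whose DataFrame and flags can then be set by hand. A small sketch, assuming the same obsq import used in the tests:

from pydartdiags.obs_sequence import obs_sequence as obsq

empty = obsq.obs_sequence(file=None)
print(empty.df.empty)                # True: placeholder DataFrame
print(empty.has_assimilation_info)   # False until set explicitly
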

module_dir = os.path.dirname(__file__)
self.default_composite_types = os.path.join(module_dir,"composite_types.yaml")

@@ -103,11 +133,16 @@ def __init__(self, file, synonyms=None):
self.synonyms_for_obs = [synonym.replace(' ', '_') for synonym in self.synonyms_for_obs]
rename_dict = {old: 'observation' for old in self.synonyms_for_obs if old in self.df.columns}
self.df = self.df.rename(columns=rename_dict)

# calculate bias and sq_err if the obs_seq is an obs_seq.final
if 'prior_ensemble_mean'.casefold() in map(str.casefold, self.columns):
self.has_assimilation_info = True
self.df['bias'] = (self.df['prior_ensemble_mean'] - self.df['observation'])
self.df['sq_err'] = self.df['bias']**2 # squared error

if 'posterior_ensemble_mean'.casefold() in map(str.casefold, self.columns):
self.has_posterior_info = True
self.df['posterior_bias'] = (self.df['posterior_ensemble_mean'] - self.df['observation'])
self.df['posterior_sq_err'] = self.df['posterior_bias']**2
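
The bias and squared-error columns are plain element-wise arithmetic against the observation column; the same calculation on a made-up two-row frame:

import pandas as pd

toy = pd.DataFrame({'prior_ensemble_mean': [1.5, 2.0], 'observation': [1.0, 2.5]})
toy['bias'] = toy['prior_ensemble_mean'] - toy['observation']    # [0.5, -0.5]
toy['sq_err'] = toy['bias']**2                                    # [0.25, 0.25]
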

def create_all_obs(self):
""" steps through the generator to create a
@@ -184,7 +219,6 @@ def split_metadata(metadata):
return metadata[:i], metadata[i:]
return metadata, []


def list_to_obs(self, data):
obs = []
obs.append('OBS ' + str(data[0])) # obs_num lots of space
@@ -312,6 +346,62 @@ def column_headers(self):
heading.append('obs_err_var')
return heading

@requires_assimilation_info
def select_by_dart_qc(self, dart_qc):
"""
Select rows from the DataFrame based on the DART quality control flag.
Parameters:
dart_qc (int): The DART quality control flag to select.
Returns:
DataFrame: A DataFrame containing only the rows with the specified DART quality control flag.
Raises:
ValueError: If the DART quality control flag is not present in the DataFrame.
"""
if dart_qc not in self.df['DART_quality_control'].unique():
raise ValueError(f"DART quality control flag '{dart_qc}' not found in DataFrame.")
else:
return self.df[self.df['DART_quality_control'] == dart_qc]

@requires_assimilation_info
def select_failed_qcs(self):
"""
Select rows from the DataFrame where the DART quality control flag is greater than 0.
Returns:
pandas.DataFrame: A DataFrame containing only the rows with a DART quality control flag greater than 0.
"""
return self.df[self.df['DART_quality_control'] > 0]
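
Both selectors are thin filters on the DART_quality_control column. A sketch of typical calls, again with a placeholder file path:

from pydartdiags.obs_sequence import obs_sequence as obsq

obs_seq = obsq.obs_sequence("obs_seq.final")   # placeholder path
passed = obs_seq.select_by_dart_qc(0)          # rows with DART QC flag exactly 0
failed = obs_seq.select_failed_qcs()           # rows with any DART QC flag > 0
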

@requires_assimilation_info
def possible_vs_used(self):
"""
Calculates the count of possible vs. used observations by type.
This method uses the object's DataFrame, which includes a 'type' column for the observation
type and an 'observation' column. The number of used observations ('used') is the total number
minus the observations that failed quality control checks (as determined by the `select_failed_qcs` method).
The result is a DataFrame with each observation type, the count of possible observations, and the count of
used observations.
Returns:
pd.DataFrame: A DataFrame with three columns: 'type', 'possible', and 'used'. 'type' is the observation type,
'possible' is the count of all observations of that type, and 'used' is the count of observations of that type
that passed quality control checks.
"""
possible = self.df.groupby('type')['observation'].count()
possible.rename('possible', inplace=True)

failed_qcs = self.select_failed_qcs().groupby('type')['observation'].count()
used = possible - failed_qcs.reindex(possible.index, fill_value=0)
used.rename('used', inplace=True)

return pd.concat([possible, used], axis=1).reset_index()
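
With the same toy values used by the new test fixture below, the returned frame has one row per observation type; output roughly as follows:

import pandas as pd
from pydartdiags.obs_sequence import obs_sequence as obsq

obs_seq = obsq.obs_sequence(file=None)
obs_seq.df = pd.DataFrame({
    'DART_quality_control': [0, 1, 2, 0, 3, 0],
    'type': ['type1', 'type2', 'type1', 'type3', 'type2', 'type1'],
    'observation': [1.0, 2.0, 3.0, 4.0, 5.0, 5.2],
})
obs_seq.has_assimilation_info = True

print(obs_seq.possible_vs_used())
#     type  possible  used
# 0  type1         3     2
# 1  type2         2     0
# 2  type3         1     1
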


@staticmethod
def is_binary(file):
"""Check if a file is binary file."""
@@ -692,65 +782,6 @@ def convert_dart_time(seconds, days):
"""
time = dt.datetime(1601,1,1) + dt.timedelta(days=days, seconds=seconds)
return time

def select_by_dart_qc(df, dart_qc):
"""
Selects rows from a DataFrame based on the DART quality control flag.
Parameters:
df (DataFrame): A pandas DataFrame.
dart_qc (int): The DART quality control flag to select.
Returns:
DataFrame: A DataFrame containing only the rows with the specified DART quality control flag.
Raises:
ValueError: If the DART quality control flag is not present in the DataFrame.
"""
if dart_qc not in df['DART_quality_control'].unique():
raise ValueError(f"DART quality control flag '{dart_qc}' not found in DataFrame.")
else:
return df[df['DART_quality_control'] == dart_qc]

def select_failed_qcs(df):
"""
Selects rows from a DataFrame where the DART quality control flag is greater than 0.
Parameters:
df (DataFrame): A pandas DataFrame.
Returns:
DataFrame: A DataFrame containing only the rows with a DART quality control flag greater than 0.
"""
return df[df['DART_quality_control'] > 0]

def possible_vs_used(df):
"""
Calculates the count of possible vs. used observations by type.
This function takes a DataFrame containing observation data, including a 'type' column for the observation
type and an 'observation' column. The number of used observations ('used'), is the total number
minus the observations that failed quality control checks (as determined by the `select_failed_qcs` function).
The result is a DataFrame with each observation type, the count of possible observations, and the count of
used observations.
Parameters:
df (pd.DataFrame): A DataFrame with at least two columns: 'type' for the observation type and 'observation'
for the observation data. It may also contain other columns required by the `select_failed_qcs` function
to determine failed quality control checks.
Returns:
pd.DataFrame: A DataFrame with three columns: 'type', 'possible', and 'used'. 'type' is the observation type,
'possible' is the count of all observations of that type, and 'used' is the count of observations of that type
that passed quality control checks.
"""
possible = df.groupby('type')['observation'].count()
possible.rename('possible', inplace=True)
used = df.groupby('type')['observation'].count() - select_failed_qcs(df).groupby('type')['observation'].count()
used.rename('used', inplace=True)
return pd.concat([possible, used], axis=1).reset_index()


def construct_composit(df_comp, composite, components):
"""
63 changes: 62 additions & 1 deletion tests/test_obs_sequence.py
@@ -3,6 +3,7 @@
import tempfile
import filecmp
import datetime as dt
import pandas as pd
from pydartdiags.obs_sequence import obs_sequence as obsq

class TestConvertDartTime:
@@ -46,7 +47,7 @@ def test_read1d(self, obs_seq_file_path):
obj = obsq.obs_sequence(obs_seq_file_path)
assert obj.loc_mod == 'loc1d'
assert len(obj.df) == 40 # 40 obs in the file
assert obj.df.columns.str.contains('posterior').sum() == 22
assert obj.df.columns.str.contains('posterior').sum() == 22 + 2 # 22 posterior copies + posterior_bias + posterior_sq_err
assert obj.df.columns.str.contains('prior').sum() == 22


@@ -160,5 +161,65 @@ def test_write_ascii(self, ascii_obs_seq_file_path, temp_dir):

# Clean up is handled by the temporary directory context manager

class TestObsDataframe:
@pytest.fixture
def obs_seq(self):
# Create a sample DataFrame to simulate the observation sequence
data = {
'DART_quality_control': [0, 1, 2, 0, 3, 0],
'type': ['type1', 'type2', 'type1', 'type3', 'type2', 'type1'],
'observation': [1.0, 2.0, 3.0, 4.0, 5.0, 5.2]
}
df = pd.DataFrame(data)

# Create an instance of obs_sequence with the sample DataFrame
obs_seq = obsq.obs_sequence(file=None)
obs_seq.df = df
obs_seq.has_assimilation_info = True # Set to True for testing purposes
return obs_seq

def test_select_by_dart_qc(self, obs_seq):
dart_qc_value = 2
result = obs_seq.select_by_dart_qc(dart_qc_value).reset_index(drop=True)

# Expected DataFrame
expected_data = {
'DART_quality_control': [2],
'type': ['type1'],
'observation': [3.0]
}
expected_df = pd.DataFrame(expected_data)

# Assert that the result matches the expected DataFrame, ignoring the index
pd.testing.assert_frame_equal(result, expected_df)

def test_select_failed_qcs(self, obs_seq):
result = obs_seq.select_failed_qcs().reset_index(drop=True)

# Expected DataFrame
expected_data = {
'DART_quality_control': [1, 2, 3],
'type': ['type2', 'type1', 'type2'],
'observation': [2.0, 3.0, 5.0]
}
expected_df = pd.DataFrame(expected_data)

# Assert that the result matches the expected DataFrame, ignoring the index
pd.testing.assert_frame_equal(result, expected_df)

def test_possible_vs_used(self, obs_seq):
result = obs_seq.possible_vs_used()

# Expected DataFrame
expected_data = {
'type': ['type1', 'type2', 'type3'],
'possible': [3, 2, 1],
'used': [2, 0, 1]
}
expected_df = pd.DataFrame(expected_data)

# Assert that the result matches the expected DataFrame, ignoring the index
pd.testing.assert_frame_equal(result, expected_df)

if __name__ == '__main__':
pytest.main()
