Skip to content

Commit

Permalink
Fix function errors in bin_chi2 function.
Browse files Browse the repository at this point in the history
  • Loading branch information
l0o0 committed Nov 26, 2018
1 parent 0714215 commit dfeaaeb
Showing 1 changed file with 18 additions and 16 deletions.
34 changes: 18 additions & 16 deletions rosaceae/bins.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ def chi2(a, bad_rate):
chi = (a[0] - b[0])**2 / b[0] + (a[1] - b[1])**2 / b[1]
return chi


def recursion(groups, counts, bins, numeric=False, verbose=False):
def recursion(groups, counts, bins, bad_rate, numeric=False, verbose=False):
max_chi = 0
if not numeric:
for _i, i in combinations(range(len(groups)), 2):
Expand Down Expand Up @@ -220,16 +219,17 @@ def recursion(groups, counts, bins, numeric=False, verbose=False):
elif chi_after > max_chi:
max_com_idx = (i, i+1)
max_chi = chi_after

merged = (groups[max_com_idx[0]][0], groups[max_com_idx[1]][1]) # create a new boundary
if verbose:
print(groups[max_com_idx[0]], groups[max_com_idx[1]], '-->' ,merged)
print('*',groups[max_com_idx[0]], groups[max_com_idx[1]], '-->' ,merged)
groups = groups[:max_com_idx[0]] + [merged] + groups[max_com_idx[1]+1:]
merged_counts = counts[max_com_idx[0]] + counts[max_com_idx[1]]
counts = counts[:max_com_idx[0]] + [merged_counts] + counts[max_com_idx[1]+1:]
if len(groups) <= bins:
return groups
else:
return recursion(groups, counts, bins, numeric=numeric, verbose=verbose)
return recursion(groups, counts, bins, bad_rate, numeric=numeric, verbose=verbose)


def bin_chi2(xarray, y, bins, min_sample=0.01, na_omit=True, verbose=False):
Expand All @@ -242,8 +242,9 @@ def bin_chi2(xarray, y, bins, min_sample=0.01, na_omit=True, verbose=False):
y = y[~pd.isna(xarray)]
elif not na_omit:
out['Miss'] = xarray.index[pd.isna(xarray)]
total_bad = xarray.sum()
total_bad = y.sum()
bad_rate = total_bad / len(y)
print('bad_rate', total_bad, len(y))
# numeric or categorical
if xarray.dtype == 'object':
if verbose:
Expand Down Expand Up @@ -276,18 +277,10 @@ def bin_chi2(xarray, y, bins, min_sample=0.01, na_omit=True, verbose=False):
for _i, b in enumerate(borders[1:]):
start = borders[i_start]
end = b
<<<<<<< HEAD
if sum((xarray >= start) & (xarray < end)) < len(y) * min_sample:
continue
else:
tmp = y[(xarray >= start) & (xarray < end)]
=======
if sum((xarray >= start) & (y < end)) < len(y) * min_sample:
continue
else:
print(_i, start, b)
tmp = y[(xarray >= start) & (y < end)]
>>>>>>> 226c7901c6fe634e021925973c741631ae9705ac
groups.append((start, end))
counts.append(np.array((sum(tmp), len(tmp)-sum(tmp))))
i_start = _i + 1
Expand All @@ -301,7 +294,16 @@ def bin_chi2(xarray, y, bins, min_sample=0.01, na_omit=True, verbose=False):

if verbose:
print('Init groups:', groups)
groups = recursion(groups, counts, bins, numeric=True, verbose=verbose)
groups = recursion(groups, counts, bins, bad_rate=bad_rate, numeric=True, verbose=verbose)

# Reformat numeric groups: replace each (start, end) boundary tuple in
# `groups` with a printable interval label.
# NOTE(review): the initial bins are selected with `start <= x < end`, while
# the labels below read as right-closed intervals -- confirm the intended
# notation before relying on the bracket style.
for _i, (start, end) in enumerate(groups):
    if np.isinf(start):
        # Open-ended lower bin.
        label = '(-inf:%s]' % end
    elif np.isinf(end):
        # Open-ended upper bin.
        label = '(%s:inf)' % start
    else:
        label = '(%s:%s]' % (start, end)
    # Write back into the list itself. The entries are tuples, so the
    # previous `g[_i] = ...` raised TypeError (tuples are immutable) and
    # also indexed the tuple with the list position.
    groups[_i] = label

return groups

Expand Down Expand Up @@ -336,10 +338,10 @@ def bin_custom(xarray, groups, na_omit=True, verbose=False):
# handle missing data.
if not na_omit:
if verbose:
print('Keep NA data')
print('Keep NA data')
tmp = np.where(pd.isna(xarray))[0]
if verbose:
print('Missing data: %s' % len(tmp))
print('Missing data: %s' % len(tmp))
if len(tmp) > 0:
out['Miss'] = np.where(pd.isna(xarray))[0]

Expand Down

0 comments on commit dfeaaeb

Please sign in to comment.