adding label column feature
fcollman committed Jul 23, 2024
1 parent b27877b commit 29cc884
Showing 2 changed files with 54 additions and 7 deletions.
59 changes: 53 additions & 6 deletions materializationengine/blueprints/client/api2.py
@@ -209,6 +209,27 @@ def __schema__(self):
)


query_seg_prop_parser = reqparse.RequestParser()
# add an argument for a string controlling the label format
query_seg_prop_parser.add_argument(
"label_format",
type=str,
default=None,
required=False,
location="args",
help="string controlling the label format, should be formatted like a python format string,\
i.e. {cell_type}_{id}, utilizing the columns available in the response",
)
# add an argument which is a list of column strings
query_seg_prop_parser.add_argument(
"label_columns",
type=str,
action="split",
default=None,
required=False,
location="args",
help="list of column names include in a label (will be overridden by label_format)",
)
metadata_parser = reqparse.RequestParser()
# add a boolean argument for whether to return all expired versions
metadata_parser.add_argument(
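For context, here is a minimal client-side sketch (not part of the commit) of how the two new query arguments might be passed to the table info endpoint. The host, datastack, and table names are hypothetical, and the client blueprint's URL prefix is not shown in this diff.

import requests

BASE = "https://materialize.example.com"  # hypothetical host and URL prefix
url = f"{BASE}/datastack/my_datastack/version/123/table/cell_types/info"

# Option 1: a python-style format string built from columns in the response
resp = requests.get(url, params={"label_format": "{cell_type}_{id}"})

# Option 2: a comma-separated list of columns; action="split" parses it into a list server-side
resp = requests.get(url, params={"label_columns": "cell_type,id"})

resp.raise_for_status()
segment_properties = resp.json()  # a neuroglancer segment-properties document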
@@ -1118,10 +1139,10 @@ def preprocess_view_dataframe(df, view_name, db_name, column_names):
for dup in duplicates:
if dup in unique_vals[tag]:
df[tag] = df[tag].replace(dup, f"{tag}:{dup}")

return df, tags, bool_tags, numerical, root_id_col


@client_bp.expect(query_seg_prop_parser)
@client_bp.route(
"/datastack/<string:datastack_name>/version/<int:version>/table/<string:table_name>/info"
)
@@ -1160,6 +1181,10 @@ def get(
validate_table_args([table_name], target_datastack, target_version)
db_name = f"{datastack_name}__mat{version}"

args = query_seg_prop_parser.parse_args()
label_format = args.get("label_format", None)
label_columns = args.get("label_columns", None)

# if the database is a split database get a split model
# and if its not get a flat model

@@ -1211,20 +1236,24 @@ def get(
df, tags, bool_tags, numerical, root_id_col = preprocess_dataframe(
df, table_name, aligned_volume_name, column_names
)

if label_columns is None:
if label_format is None:
label_columns = "id"
seg_prop = nglui.segmentprops.SegmentProperties.from_dataframe(
df,
id_col=root_id_col,
tag_value_cols=tags,
tag_bool_cols=bool_tags,
number_cols=numerical,
label_col="id",
label_col=label_columns,
label_format_map=label_format,
)
dfjson = json.dumps(seg_prop.to_dict(), cls=current_app.json_encoder)
response = Response(dfjson, status=200, mimetype="application/json")
return after_request(response)


@client_bp.expect(query_seg_prop_parser)
@client_bp.route("/datastack/<string:datastack_name>/table/<string:table_name>/info")
class MatTableSegmentInfoLive(Resource):
method_decorators = [
@@ -1278,13 +1307,22 @@ def get(
vals = preprocess_dataframe(df, table_name, aligned_volume_name, column_names)
df, tags, bool_tags, numerical, root_id_col = vals

# parse the args
args = query_seg_prop_parser.parse_args()
label_format = args.get("label_format", None)
label_columns = args.get("label_columns", None)
if label_format is None:
if label_columns is None:
label_columns = "id"

seg_prop = nglui.segmentprops.SegmentProperties.from_dataframe(
df,
id_col=root_id_col,
tag_value_cols=tags,
tag_bool_cols=bool_tags,
number_cols=numerical,
label_col="id",
label_col=label_columns,
label_format_map=label_format,
)
dfjson = json.dumps(seg_prop.to_dict(), cls=current_app.json_encoder)
response = Response(dfjson, status=200, mimetype="application/json")
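Both table handlers above apply the same precedence: an explicit label_format takes priority, otherwise label_columns is used, and the "id" column is the fallback when neither argument is supplied. The following is a standalone sketch of that logic (not part of the commit), assuming nglui 3.3.5 as pinned below accepts these keywords the way the handlers use them; the dataframe is invented for illustration.

import pandas as pd
import nglui.segmentprops

df = pd.DataFrame(
    {
        "pt_root_id": [101, 102],
        "cell_type": ["pyramidal", "basket"],
        "id": [1, 2],
    }
)

label_format = "{cell_type}_{id}"  # as supplied via ?label_format=
label_columns = None               # as supplied via ?label_columns=

# mirror the handlers' fallback when neither argument is given
if label_format is None and label_columns is None:
    label_columns = "id"

seg_prop = nglui.segmentprops.SegmentProperties.from_dataframe(
    df,
    id_col="pt_root_id",
    label_col=label_columns,
    label_format_map=label_format,
)
print(seg_prop.to_dict())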
@@ -1764,7 +1802,7 @@ def conditional_view_cache(func):
@wraps(func)
def wrapper(*args, **kwargs):
# Generate a cache key
key = hashkey(*args, **kwargs)
key = request.url
if kwargs.get("version") == -1:
# Check if the result is in the live cache
if key in view_live_cache:
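The cache key change above (hashkey(*args, **kwargs) replaced by request.url) matters for the new feature: the full URL includes the query string, so requests that differ only in label_format or label_columns no longer share a cache entry. A small illustration (not part of the commit; the path is hypothetical):

from flask import Flask, request

app = Flask(__name__)

with app.test_request_context(
    "/datastack/my_datastack/version/-1/view/my_view/info?label_columns=cell_type,id"
):
    # prints something like:
    # http://localhost/datastack/my_datastack/version/-1/view/my_view/info?label_columns=cell_type,id
    print(request.url)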
@@ -1795,6 +1833,7 @@ def wrapper(*args, **kwargs):
return wrapper


@client_bp.expect(query_seg_prop_parser)
@client_bp.route(
"/datastack/<string:datastack_name>/version/<int(signed=True):version>/view/<string:view_name>/info"
)
@@ -1856,13 +1895,21 @@ def get(
df, view_name, mat_db_name, column_names
)

args = query_seg_prop_parser.parse_args()
label_format = args.get("label_format", None)
label_columns = args.get("label_columns", None)
if label_format is None:
if label_columns is None:
label_columns = df.columns[0]

seg_prop = nglui.segmentprops.SegmentProperties.from_dataframe(
df,
id_col=root_id_col,
tag_value_cols=tags,
tag_bool_cols=bool_tags,
number_cols=numerical,
label_col=df.columns[0],
label_col=label_columns,
label_format_map=label_format,
)
# use the current_app encoder to encode the seg_prop.to_dict()
# to ensure that the json is serialized correctly
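The view info endpoint honors the same query arguments, but falls back to the first dataframe column as the label when neither is given, since a view need not have an "id" column. A hedged client sketch (not part of the commit); the host, datastack, and view names are hypothetical, and version=-1 follows the signed-version convention handled by the live-view cache above.

import requests

BASE = "https://materialize.example.com"  # hypothetical host and URL prefix
url = f"{BASE}/datastack/my_datastack/version/-1/view/my_cell_view/info"
resp = requests.get(url, params={"label_format": "{cell_type}_{pt_root_id}"})
resp.raise_for_status()
print(list(resp.json()))  # top-level keys of the segment-properties document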
2 changes: 1 addition & 1 deletion requirements.txt
@@ -314,7 +314,7 @@ networkx==2.5
# cloud-volume
neuroglancer==2.39.2
# via nglui
nglui==3.3.4
nglui==3.3.5
# via -r requirements.in
numexpr==2.10.1
# via tables
