Add GenAI Preprocessor logic (#2314)
* add sample raw logs to e2e test resources

* working on genai preprocessor

* adding logic for genAI preprocessor

* adding UT tests

* select cols in df to v1 schema

* UT

* fix UT

* fix data

* refactor shared mdc code in genai preprocessor

* syntax fixes

* let all extra columns through

* syntax

* syntax

* doc syntax

* adding e2e test

* adding dataref example

* remove unused import

* add file datalake package to genai preprocessor

* no need to bump version. Hasn't released v1 yet

* add dataref testcase but comment out for now

* fix UT

* merge trace aggregator and genai preprocessor PRs into one (#2323)

* temp testing files

* stubbing out trace logic

* update type hinting

* add sample raw logs to e2e test resources

* remove attributes for temp testing

* temp files

* working on span tree json_string repr

* fix syntaxes

* finished logic for trace aggregator

* remove temp files

* add type hints

* syntax fix and working on UTs

* remove unused property

* make span_row internal private

* fix syntax

* syntax fixes

* add SpanTreeNode tests

* add tests for spantree and spantreenode

* syntax

* Revert "add sample raw logs to e2e test resources"

This reverts commit b331897.

* syntax fixes

* syntax fix

* fix UT

* Update test_trace_aggregator.py

* Update test_trace_aggregator.py

* change to timestamptype

* schema in alphabetical order

* columns will be in datetime. Handle correct json serialization

* expect datetime object in spantree tests not string

* moved span utils out of subfolder

* fix syntax

* add debug and temp file

* assign df to return

* remove debug info

* add from json str test

* use property in function instead of private var

* syntax changes

* syntax

* add more validation in span tree util

* move datetime serialization logic to appropriate function

* syntax

* fix test case for to_dict()

* add trace aggregator in preprocessor

* user_id and session_id may not be available for v0

* comment out user id and session id in schema

* unused import for now

* syntax

* fix UT with less schema

* slim down debugging show()

* refactor trace aggregator to use more pyspark

* fix import to be same as others

* fix UT and import stmts

* modify UT

* modify UT

* Update trace_aggregator.py

* fix UT

* fix syntax

* syntax

* refactor tree construction to include all fields from span logs

* syntax

* syntax fix

* format output of trace aggregator to fit agg trace schema

* remove unused variable

* unused imports

* unused definition

* set auto for parallel pytests

* Revert "set auto for parallel pytests"

This reverts commit d12324a.

* revert main merge

* Update model-monitoring-ci.yml

* revert auto change

* pass schema to map function, not call it on workers

* addressing some comments

* fix UT

* add new flag for calculate trace logs in preprocessor

* return None column if can't find in raw logs data

* add UT for None columns if not in raw logs

* Revert "add new flag for calculate trace logs in preprocessor"

This reverts commit fa89029.

* syntax

* syntax

* Remove problematic test for v1

* add trace aggregator UT

* setup pyspark path for each testcase

* remove broken UT for v1

* Revert "remove broken UT for v1"

This reverts commit 836396b.

* Reapply "remove broken UT for v1"

This reverts commit 7753121.

* Revert "setup pyspark path for each testcase"

This reverts commit 22daf2a.

* Revert "add trace aggregator UT"

This reverts commit 14a4d63.
alanpo1 authored Feb 20, 2024
1 parent c65d1db commit 9746f3a
Showing 18 changed files with 1,051 additions and 3 deletions.
@@ -40,6 +40,7 @@ conf:
    dependencies:
      - python=3.8
      - pip:
+       - azure-storage-file-datalake~=12.8.0
        - mltable~=1.3.0
    name: momo-base-spark
args: >-
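The new azure-storage-file-datalake pin supports the preprocessor's dataref handling (see the "add file datalake package to genai preprocessor" and "adding dataref example" commits), presumably for reading span payloads referenced in ADLS Gen2. A minimal client sketch, illustrative only — the account URL, container, file path, and use of azure-identity are assumptions, not taken from this diff:

from azure.identity import DefaultAzureCredential
from azure.storage.filedatalake import DataLakeServiceClient

# Hypothetical account/container/path, for illustration only.
service = DataLakeServiceClient(
    account_url="https://myaccount.dfs.core.windows.net",
    credential=DefaultAzureCredential(),
)
file_system = service.get_file_system_client("raw-logs")
payload = file_system.get_file_client("spans/2024/02/20/log.jsonl").download_file().readall()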
@@ -0,0 +1,48 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Constants file for genai preprocessor component."""


from pyspark.sql.types import TimestampType, StructType, StructField, StringType


def _get_preprocessed_span_logs_df_schema() -> StructType:
    """Get processed span logs DataFrame schema."""
    # TODO: The user_id and session_id may not be available in v1.
    schema = StructType([
        StructField('attributes', StringType(), False),
        StructField('end_time', TimestampType(), False),
        StructField('events', StringType(), False),
        StructField('framework', StringType(), False),
        StructField('input', StringType(), False),
        StructField('links', StringType(), False),
        StructField('name', StringType(), False),
        StructField('output', StringType(), False),
        StructField('parent_id', StringType(), True),
        # StructField('session_id', StringType(), True),
        StructField('span_id', StringType(), False),
        StructField('span_type', StringType(), False),
        StructField('start_time', TimestampType(), False),
        StructField('status', StringType(), False),
        StructField('trace_id', StringType(), False),
        # StructField('user_id', StringType(), True),
    ])
    return schema


def _get_aggregated_trace_log_spark_df_schema() -> StructType:
    """Get aggregated trace log DataFrame schema."""
    # TODO: The user_id and session_id may not be available in v0 of trace aggregator.
    schema = StructType(
        [
            StructField("trace_id", StringType(), False),
            StructField("user_id", StringType(), True),
            StructField("session_id", StringType(), True),
            StructField("start_time", TimestampType(), False),
            StructField("end_time", TimestampType(), False),
            StructField("input", StringType(), False),
            StructField("output", StringType(), False),
            StructField("root_span", StringType(), True),
        ]
    )
    return schema
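The span log schema above keeps its fields in alphabetical order (per the "schema in alphabetical order" commit), so downstream column selection is deterministic. A quick local check — a sketch assuming a live SparkSession, not anything in this diff:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("schema-check").getOrCreate()

# Empty DataFrame pinned to the span log schema: a cheap way to verify
# column names, order, and nullability without real MDC data.
empty_spans = spark.createDataFrame([], _get_preprocessed_span_logs_df_schema())
empty_spans.printSchema()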
@@ -3,4 +3,141 @@

"""Entry script for GenAI MDC Preprocessor."""

-# place holder file, will add content later
import argparse

from pyspark.sql import DataFrame
from pyspark.sql.types import TimestampType
from pyspark.sql.utils import AnalysisException
from pyspark.sql.functions import lit
from shared_utilities.io_utils import save_spark_df_as_mltable
from model_data_collector_preprocessor.store_url import StoreUrl
# TODO: move shared utils to a util py file
from model_data_collector_preprocessor.spark_run import (
    _mdc_uri_folder_to_preprocessed_spark_df,
    _convert_complex_columns_to_json_string,
)
from model_data_collector_preprocessor.trace_aggregator import (
    process_spans_into_aggregated_traces,
)


def _get_important_field_mapping() -> dict:
    """Map the span log schema names to the expected raw log schema column/field name."""
    mapping = {
        "trace_id": "context.trace_id",
        "span_id": "context.span_id",
        "span_type": "attributes.span_type",
        "status": "status.status_code",
        "framework": "attributes.framework",
        "input": "attributes.inputs",
        "output": "attributes.output",
    }
    return mapping


def _promote_fields_from_attributes(df: DataFrame) -> DataFrame:
    """Retrieve common/important fields (span_type, input, etc.) from attributes and make them first-level columns.

    Args:
        df: The raw span logs data in a dataframe.
    """
    def try_get_df_column(df: DataFrame, name: str):
        """Get column if it exists in DF. Return None if column does not exist."""
        try:
            return df[name]
        except AnalysisException:
            return None

    df = df.withColumns(
        {
            # Promote the nested field if present; otherwise fill the column with nulls.
            key: (df_col if (df_col := try_get_df_column(df, col_name)) is not None else lit(None))
            for key, col_name in _get_important_field_mapping().items()
        }
    )
    return df


def _preprocess_raw_logs_to_span_logs_spark_df(df: DataFrame) -> DataFrame:
    """Apply Gen AI preprocessing steps to raw logs dataframe.

    Args:
        df: The raw span logs data in a dataframe.
    """
    # TODO: handle if the original start/end time is not a valid time string
    df = df.withColumns(
        {'end_time': df.end_time.cast(TimestampType()), 'start_time': df.start_time.cast(TimestampType())}
    )

    df = _promote_fields_from_attributes(df)

    # Cast all remaining fields in attributes to json string, for easier schema unifying
    df = _convert_complex_columns_to_json_string(df)

    print("df processed from raw Gen AI logs:")
    df.show(truncate=False)
    df.printSchema()

    return df


def _genai_uri_folder_to_preprocessed_spark_df(
        data_window_start: str, data_window_end: str, store_url: StoreUrl, add_tags_func=None
) -> DataFrame:
    """Read raw Gen AI logs data, preprocess, and return in a Spark DataFrame."""
    df = _mdc_uri_folder_to_preprocessed_spark_df(data_window_start, data_window_end, store_url, False, add_tags_func)

    df = _preprocess_raw_logs_to_span_logs_spark_df(df)

    return df


def genai_preprocessor(
        data_window_start: str,
        data_window_end: str,
        input_data: str,
        preprocessed_span_data: str,
        aggregated_trace_data: str):
    """Extract data based on the window provided and preprocess it into MLTable.

    Args:
        data_window_start: The start date of the data window.
        data_window_end: The end date of the data window.
        input_data: The data asset on which the date filter is applied.
        preprocessed_span_data: The mltable path where the output span logs
            mltable will be written.
        aggregated_trace_data: The mltable path where the output aggregated
            trace logs mltable will be written.
    """
    store_url = StoreUrl(input_data)

    transformed_df = _genai_uri_folder_to_preprocessed_spark_df(data_window_start, data_window_end, store_url)

    trace_logs_df = process_spans_into_aggregated_traces(transformed_df)

    save_spark_df_as_mltable(transformed_df, preprocessed_span_data)

    save_spark_df_as_mltable(trace_logs_df, aggregated_trace_data)


def run():
    """Compute data window and preprocess data from MDC."""
    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_window_start", type=str)
    parser.add_argument("--data_window_end", type=str)
    parser.add_argument("--input_data", type=str)
    parser.add_argument("--preprocessed_span_data", type=str)
    parser.add_argument("--aggregated_trace_data", type=str)
    args = parser.parse_args()

    genai_preprocessor(
        args.data_window_start,
        args.data_window_end,
        args.input_data,
        args.preprocessed_span_data,
        args.aggregated_trace_data,
    )


if __name__ == "__main__":
    run()
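To make the attribute-promotion step concrete, here is a hedged sketch on a toy frame; the sample rows, DDL schema string, and SparkSession setup are illustrative only (withColumns also assumes Spark 3.3+). Fields missing from the raw logs — attributes.framework here, for example — come back as null columns, matching the "return None column if can't find in raw logs data" commit:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy raw span log: ids nested under `context`, status code under `status`.
raw_df = spark.createDataFrame(
    [({"trace_id": "t1", "span_id": "s1"}, {"status_code": "OK"})],
    "context struct<trace_id:string,span_id:string>, status struct<status_code:string>",
)

promoted = _promote_fields_from_attributes(raw_df)
# trace_id, span_id, and status are now top-level columns; framework,
# input, and output did not exist in the raw logs, so they are null.
promoted.select("trace_id", "span_id", "status", "framework").show()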
@@ -0,0 +1,168 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Span Tree internal utilities and classes for building a Span Tree from preprocessed span logs."""


import bisect
from datetime import datetime
import json

from typing import Iterator, List, Optional
from pyspark.sql import Row


class SpanTreeNode:
    """Spantree node class."""

    def __init__(self, span_row: Row) -> None:
        """Represent a singular node in a span tree."""
        self._span_row = span_row
        self._children = []

    @property
    def span_id(self) -> str:
        """Get the span id."""
        return self._span_row.span_id

    @property
    def parent_id(self) -> str:
        """Get the span's parent id."""
        return self._span_row.parent_id

    @property
    def children(self) -> List["SpanTreeNode"]:
        """Get the span's children as list."""
        return self._children

    @children.setter
    def children(self, value: list) -> None:
        """Set the span's children."""
        self._children = value

    @property
    def span_row(self) -> Row:
        """Get the span's internal row."""
        return self._span_row

    @classmethod
    def create_node_from_json_str(cls, json_str: str) -> "SpanTreeNode":
        """Parse json string to get a single SpanTree node."""
        node_dict: dict = json.loads(json_str)

        node_dict['start_time'] = datetime.fromisoformat(node_dict['start_time'])
        node_dict['end_time'] = datetime.fromisoformat(node_dict['end_time'])
        children = node_dict.pop('children', [])
        new_row = Row(**node_dict)

        obj = cls.__new__(cls)
        super(SpanTreeNode, obj).__init__()
        obj._span_row = new_row
        obj._children = children
        return obj

    def insert_child(self, span: "SpanTreeNode") -> None:
        """Insert a child span in ascending time order due to __lt__()."""
        bisect.insort(self._children, span)

    def show(self, indent: int = 0) -> None:
        """Print the current span in a formatted syntax to stdout."""
        print(f"{' '*indent}[{self.span_row.span_id}({self.span_row.start_time}, {self.span_row.end_time})]")
        for c in self.children:
            c.show(indent + 4)

    def to_dict(self) -> dict:
        """Get dictionary representation of SpanTreeNode."""
        span_row_dict = self.span_row.asDict()
        span_row_dict['children'] = self.children
        return span_row_dict

    def __iter__(self) -> Iterator["SpanTreeNode"]:
        """Iterate over current span and child spans."""
        for child_span in self._children:
            for span in child_span:
                yield span
        yield self

    def __lt__(self, other: "SpanTreeNode") -> bool:
        """Compare by end_time in bisect.insort() for python3.8."""
        return self.span_row.end_time < other.span_row.end_time


class SpanTree:
    """Spantree class."""

    def __init__(self, spans: List[SpanTreeNode]) -> None:
        """Spantree constructor to build up tree from span list."""
        self.root_span = self._construct_span_tree(spans)
        self._load_json_tree_lazy = False

    @classmethod
    def create_tree_from_json_string(cls, json_string: str, load_tree_lazy: bool = False) -> "SpanTree":
        """Create SpanTree object from "root_span" json string."""
        obj = cls.__new__(cls)
        super(SpanTree, obj).__init__()
        obj._load_json_tree_lazy = load_tree_lazy
        # Default behavior is to load the whole tree from top level json string.
        # TODO: Stretch goal will be to load the tree lazily one level at a time to save computation/space.
        if not load_tree_lazy:
            obj.root_span = obj._from_json_str_repr(json_string)
        return obj

    def show(self) -> None:
        """Print to stdout a formatted representation of the Span Tree."""
        if self.root_span is None:
            return
        self.root_span.show()

    def to_json_str(self) -> Optional[str]:
        """Get tree structure as json string."""
        if self.root_span is None:
            return None
        return self._to_json_str_repr(self.root_span)

    def _construct_span_tree(self, spans: List[SpanTreeNode]) -> Optional[SpanTreeNode]:
        """Build the span tree in ascending time order from list of all spans."""
        # construct a dict with span_id as key and span as value
        span_map = {span.span_id: span for span in spans}
        root_span = None  # guard against span lists that contain no root span
        for span in span_map.values():
            parent_id = span.parent_id
            if parent_id is None:
                root_span = span
            else:
                parent_span = span_map.get(parent_id)
                if parent_span is not None:
                    parent_span.insert_child(span)
        return root_span

    def __iter__(self) -> Iterator[SpanTreeNode]:
        """Iterate over the span tree in order."""
        if self.root_span is None:
            return
        for span in self.root_span.__iter__():
            yield span

    def _from_json_str_repr(self, json_string: str) -> SpanTreeNode:
        """Convert json string to SpanTree Node fully."""
        output_node = SpanTreeNode.create_node_from_json_str(json_string)
        child_subtree_nodes = []
        for child in output_node.children:
            new_node = self._from_json_str_repr(child)  # type: ignore
            child_subtree_nodes.append(new_node)
        output_node.children = child_subtree_nodes
        return output_node

    def _to_json_str_repr(self, curr_span: SpanTreeNode) -> str:
        """Recursively get tree structure JSON string."""
        output = curr_span.to_dict()
        start_time: datetime = output['start_time']  # type: ignore
        end_time: datetime = output['end_time']  # type: ignore
        output['start_time'] = start_time.isoformat()
        output['end_time'] = end_time.isoformat()

        child_subtree_strs = []
        for child in output['children']:
            subtree_str = self._to_json_str_repr(child)
            child_subtree_strs.append(subtree_str)
        output['children'] = child_subtree_strs
        return json.dumps(output)
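A small usage sketch (not part of the diff) showing the intended round trip: build a tree from pyspark Rows the way the trace aggregator would, then serialize it to the "root_span" json string and parse it back. The span ids and timestamps here are made up:

from datetime import datetime
from pyspark.sql import Row

root = SpanTreeNode(Row(span_id="s1", parent_id=None,
                        start_time=datetime(2024, 2, 20, 0, 0, 0),
                        end_time=datetime(2024, 2, 20, 0, 0, 9)))
child = SpanTreeNode(Row(span_id="s2", parent_id="s1",
                         start_time=datetime(2024, 2, 20, 0, 0, 1),
                         end_time=datetime(2024, 2, 20, 0, 0, 5)))

tree = SpanTree([root, child])
tree.show()                    # prints the indented [span_id(start, end)] hierarchy
json_str = tree.to_json_str()  # datetimes serialized via isoformat()
same_tree = SpanTree.create_tree_from_json_string(json_str)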
@@ -17,7 +17,8 @@
from shared_utilities.constants import (
    MDC_CORRELATION_ID_COLUMN, MDC_DATA_COLUMN, MDC_DATAREF_COLUMN, AML_MOMO_ERROR_TAG
)
-from mdc_preprocessor_helper import (
+
+from model_data_collector_preprocessor.mdc_preprocessor_helper import (
    get_file_list, set_data_access_config, serialize_credential, copy_appendblob_to_blockblob
)
from model_data_collector_preprocessor.store_url import StoreUrl