Skip to content

Commit

Permalink
Merge branch 'main' into demo/chart-gallery
Browse files Browse the repository at this point in the history
  • Loading branch information
huong-li-nguyen committed Jul 8, 2024
2 parents cd1ff04 + 6ba4119 commit 3cfa78a
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 39 deletions.
48 changes: 48 additions & 0 deletions vizro-ai/changelog.d/20240626_224646_anna_xiong_azure_openai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Fixed
- A bullet item for the Fixed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
58 changes: 24 additions & 34 deletions vizro-ai/src/vizro_ai/chains/_llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,24 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_openai import ChatOpenAI

# TODO constant of model inventory, can be converted to yaml and link to docs
PREDEFINED_MODELS: Dict[str, Dict[str, Union[int, BaseChatModel]]] = {
"gpt-3.5-turbo-0613": {
"max_tokens": 4096,
"wrapper": ChatOpenAI,
},
"gpt-4-0613": {
"max_tokens": 8192,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo-1106": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-4-1106-preview": {
"max_tokens": 128000,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo-0125": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-3.5-turbo": {
"max_tokens": 16385,
"wrapper": ChatOpenAI,
},
"gpt-4-turbo": {
"max_tokens": 128000,
"wrapper": ChatOpenAI,
},
SUPPORTED_MODELS = {
"OpenAI": [
"gpt-4-0613",
"gpt-3.5-turbo-1106",
"gpt-4-1106-preview",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo",
"gpt-4-turbo",
"gpt-4o",
]
}

DEFAULT_WRAPPER_MAP: Dict[str, BaseChatModel] = {"OpenAI": ChatOpenAI}
DEFAULT_MODEL = "gpt-3.5-turbo"
DEFAULT_TEMPERATURE = 0

model_to_vendor = {model: key for key, models in SUPPORTED_MODELS.items() for model in models}


def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatModel:
"""Fetches and initializes an instance of the LLM.
Expand All @@ -54,14 +37,21 @@ def _get_llm_model(model: Optional[Union[ChatOpenAI, str]] = None) -> BaseChatMo
"""
if not model:
return ChatOpenAI(model_name=DEFAULT_MODEL, temperature=DEFAULT_TEMPERATURE)
if isinstance(model, ChatOpenAI):

if isinstance(model, BaseChatModel):
return model
if isinstance(model, str) and model in PREDEFINED_MODELS:
return PREDEFINED_MODELS.get(model)["wrapper"](model_name=model, temperature=DEFAULT_TEMPERATURE)

if isinstance(model, str):
if any(model in model_list for model_list in SUPPORTED_MODELS.values()):
vendor = model_to_vendor[model]
return DEFAULT_WRAPPER_MAP.get(vendor)(model_name=model, temperature=DEFAULT_TEMPERATURE)

raise ValueError(
f"Model {model} not found! List of available model can be found at https://vizro.readthedocs.io/projects/vizro-ai/en/latest/pages/explanation/faq/#which-llms-are-supported-by-vizro-ai"
)


if __name__ == "__main__":
llm_chat_openai = _get_llm_model()
llm_chat_openai = _get_llm_model(model="gpt-3.5-turbo")
print(repr(llm_chat_openai)) # noqa: T201
print(llm_chat_openai.model_name) # noqa: T201
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
<!--
A new scriv changelog fragment.
Uncomment the section that is right (remove the HTML comment wrapper).
-->

<!--
### Highlights ✨
- A bullet item for the Highlights ✨ category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Removed
- A bullet item for the Removed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Added
- A bullet item for the Added category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Changed
- A bullet item for the Changed category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
<!--
### Deprecated
- A bullet item for the Deprecated category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->

### Fixed

- Ensure that categorical selectors always return a list of values. ([#562](https://github.com/mckinsey/vizro/pull/562))

<!--
### Security
- A bullet item for the Security category with a link to the relevant PR at the end of your entry, e.g. Enable feature XXX ([#1](https://github.com/mckinsey/vizro/pull/1))
-->
15 changes: 11 additions & 4 deletions vizro-core/src/vizro/actions/_actions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,19 @@ def _get_parametrized_config(target: ModelID, ctd_parameters: List[CallbackTrigg
config["data_frame"] = {}

for ctd in ctd_parameters:
selector_value = ctd[
"value"
] # TODO: needs to be refactored so that it is independent of implementation details
# TODO: needs to be refactored so that it is independent of implementation details
selector_value = ctd["value"]

if hasattr(selector_value, "__iter__") and ALL_OPTION in selector_value: # type: ignore[operator]
selector: SelectorType = model_manager[ctd["id"]]
selector_value = selector.options

# Even if options are provided as List[Dict], the Dash component only returns a List of values.
# So we need to ensure that we always return a List only as well to provide consistent types.
if all(isinstance(option, dict) for option in selector.options):
selector_value = [option["value"] for option in selector.options]
else:
selector_value = selector.options

selector_value = _validate_selector_value_none(selector_value)
selector_actions = _get_component_actions(model_manager[ctd["id"]])

Expand Down
2 changes: 1 addition & 1 deletion vizro-core/src/vizro/models/_action/_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _action_callback_function(
) -> Any:
logger.debug("===== Running action with id %s, function %s =====", self.id, self.function._function.__name__)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Action inputs:\n%s", pformat(inputs, depth=2, width=200))
logger.debug("Action inputs:\n%s", pformat(inputs, depth=3, width=200))
logger.debug("Action outputs:\n%s", pformat(outputs, width=200))

if isinstance(inputs, Mapping):
Expand Down
28 changes: 28 additions & 0 deletions vizro-core/tests/unit/vizro/actions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ def gapminder_2007(gapminder):
return gapminder.query("year == 2007")


@pytest.fixture
def iris():
return px.data.iris()


@pytest.fixture
def gapminder_dynamic_first_n_last_n_function(gapminder):
return lambda first_n=None, last_n=None: (
Expand Down Expand Up @@ -44,6 +49,16 @@ def scatter_chart(gapminder_2007, scatter_params):
return px.scatter(gapminder_2007, **scatter_params).update_layout(margin_t=24)


@pytest.fixture
def scatter_matrix_params():
return {"dimensions": ["sepal_width", "sepal_length", "petal_width", "petal_length"]}


@pytest.fixture
def scatter_matrix_chart(iris, scatter_matrix_params):
return px.scatter_matrix(iris, **scatter_matrix_params).update_layout(margin_t=24)


@pytest.fixture
def scatter_chart_dynamic_data_frame(scatter_params):
return px.scatter("gapminder_dynamic_first_n_last_n", **scatter_params).update_layout(margin_t=24)
Expand Down Expand Up @@ -110,3 +125,16 @@ def managers_one_page_two_graphs_one_table_one_aggrid_one_button(
],
)
Vizro._pre_build()


@pytest.fixture
def managers_one_page_one_graph_with_dict_param_input(scatter_matrix_chart):
"""Instantiates a model_manager and data_manager with a page and a graph that requires a list input."""
vm.Page(
id="test_page",
title="My first dashboard",
components=[
vm.Graph(id="scatter_matrix_chart", figure=scatter_matrix_chart),
],
)
Vizro._pre_build()
72 changes: 72 additions & 0 deletions vizro-core/tests/unit/vizro/actions/test_parameter_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ def target_scatter_parameter_y(request, gapminder_2007, scatter_params):
return px.scatter(gapminder_2007, **scatter_params).update_layout(margin_t=24)


@pytest.fixture
def target_scatter_matrix_parameter_dimensions(request, iris, scatter_matrix_params):
dimensions = request.param
scatter_matrix_params["dimensions"] = dimensions
return px.scatter_matrix(iris, **scatter_matrix_params).update_layout(margin_t=24)


@pytest.fixture
def target_scatter_parameter_hover_data(request, gapminder_2007, scatter_params):
hover_data = request.param
Expand Down Expand Up @@ -95,6 +102,38 @@ def ctx_parameter_y(request):
return context_value


@pytest.fixture
def ctx_parameter_dimensions(request):
"""Mock dash.ctx that represents `dimensions` Parameter value selection."""
y = request.param
mock_ctx = {
"args_grouping": {
"external": {
"filter_interaction": [],
"filters": [],
"parameters": [
CallbackTriggerDict(
id="dimensions_parameter",
property="value",
value=y,
str_id="dimensions_parameter",
triggered=False,
)
],
"theme_selector": CallbackTriggerDict(
id="theme_selector",
property="checked",
value=False,
str_id="theme_selector",
triggered=False,
),
}
}
}
context_value.set(AttributeDict(**mock_ctx))
return context_value


@pytest.fixture
def ctx_parameter_hover_data(request):
"""Mock dash.ctx that represents hover_data Parameter value selection."""
Expand Down Expand Up @@ -497,3 +536,36 @@ def test_data_frame_parameters_multiple_targets(
}

assert result == expected

@pytest.mark.usefixtures("managers_one_page_one_graph_with_dict_param_input")
@pytest.mark.parametrize(
"ctx_parameter_dimensions, target_scatter_matrix_parameter_dimensions",
[("ALL", ["sepal_length", "sepal_width", "petal_length", "petal_width"]), (["sepal_width"], ["sepal_width"])],
indirect=True,
)
def test_one_parameter_with_dict_input_as_options(
self, ctx_parameter_dimensions, target_scatter_matrix_parameter_dimensions
):
# If the options are provided as a list of dictionaries, the value should be correctly passed to the
# target as a list. So when "ALL" is selected, a list of all possible values should be returned.
dimensions_parameter = vm.Parameter(
id="test_parameter_dimensions",
targets=["scatter_matrix_chart.dimensions"],
selector=vm.RadioItems(
id="dimensions_parameter",
options=[
{"label": "sepal_length", "value": "sepal_length"},
{"label": "sepal_width", "value": "sepal_width"},
{"label": "petal_length", "value": "petal_length"},
{"label": "petal_width", "value": "petal_width"},
],
),
)
model_manager["test_page"].controls = [dimensions_parameter]
dimensions_parameter.pre_build()

# Run action by picking the above added action function and executing it with ()
result = model_manager[f"{PARAMETER_ACTION_PREFIX}_test_parameter_dimensions"].function()
expected = {"scatter_matrix_chart": target_scatter_matrix_parameter_dimensions}

assert result == expected

0 comments on commit 3cfa78a

Please sign in to comment.