diff --git a/optuna_dashboard/ts/components/GraphContour.tsx b/optuna_dashboard/ts/components/GraphContour.tsx index a5416f5b8..b7895ddfa 100644 --- a/optuna_dashboard/ts/components/GraphContour.tsx +++ b/optuna_dashboard/ts/components/GraphContour.tsx @@ -14,25 +14,8 @@ import { import blue from "@mui/material/colors/blue" import { plotlyDarkTemplate } from "./PlotlyDarkMode" import { useMergedUnionSearchSpace } from "../searchSpace" +import { getAxisInfo } from "../graphUtil" -// eslint-disable-next-line @typescript-eslint/no-explicit-any -const unique = (array: any[]) => { - const knownElements = new Map() - array.forEach((elem) => knownElements.set(elem, true)) - return Array.from(knownElements.keys()) -} - -type AxisInfo = { - name: string - min: number - max: number - isLog: boolean - isCat: boolean - indices: (string | number)[] - values: (string | number | null)[] -} - -const PADDING_RATIO = 0.05 const plotDomId = "graph-contour" export const Contour: FC<{ @@ -284,93 +267,3 @@ const plotContour = ( ] plotly.react(plotDomId, plotData, layout) } - -const getAxisInfoForNumericalParams = ( - trials: Trial[], - paramName: string, - distribution: FloatDistribution | IntDistribution -): AxisInfo => { - let min = 0 - let max = 0 - if (distribution.log) { - const padding = - (Math.log10(distribution.high) - Math.log10(distribution.low)) * - PADDING_RATIO - min = Math.pow(10, Math.log10(distribution.low) - padding) - max = Math.pow(10, Math.log10(distribution.high) + padding) - } else { - const padding = (distribution.high - distribution.low) * PADDING_RATIO - min = distribution.low - padding - max = distribution.high + padding - } - - const values = trials.map( - (trial) => - trial.params.find((p) => p.name === paramName)?.param_internal_value || - null - ) - const indices = unique(values) - .filter((v) => v !== null) - .sort((a, b) => a - b) - if (indices.length >= 2) { - indices.unshift(min) - indices.push(max) - } - return { - name: paramName, - min, - max, - isLog: distribution.log, - isCat: false, - indices, - values, - } -} - -const getAxisInfoForCategoricalParams = ( - trials: Trial[], - paramName: string, - distribution: CategoricalDistribution -): AxisInfo => { - const values = trials.map( - (trial) => - trial.params.find((p) => p.name === paramName)?.param_external_value || - null - ) - const isDynamic = values.some((v) => v === null) - const span = distribution.choices.length - (isDynamic ? 2 : 1) - const padding = span * PADDING_RATIO - const min = -padding - const max = span + padding - - const indices = distribution.choices - .map((c) => c.value) - .sort((a, b) => - a.toLowerCase() < b.toLowerCase() - ? -1 - : a.toLowerCase() > b.toLowerCase() - ? 1 - : 0 - ) - return { - name: paramName, - min, - max, - isLog: false, - isCat: true, - indices, - values, - } -} - -const getAxisInfo = (trials: Trial[], param: SearchSpaceItem): AxisInfo => { - if (param.distribution.type === "CategoricalDistribution") { - return getAxisInfoForCategoricalParams( - trials, - param.name, - param.distribution - ) - } else { - return getAxisInfoForNumericalParams(trials, param.name, param.distribution) - } -} diff --git a/optuna_dashboard/ts/components/GraphRank.tsx b/optuna_dashboard/ts/components/GraphRank.tsx index 6781793fe..0ff215d1e 100644 --- a/optuna_dashboard/ts/components/GraphRank.tsx +++ b/optuna_dashboard/ts/components/GraphRank.tsx @@ -12,25 +12,18 @@ import { Box, } from "@mui/material" import { plotlyDarkTemplate } from "./PlotlyDarkMode" -import { makeHovertext } from "../graphUtil" +import { getAxisInfo, makeHovertext } from "../graphUtil" import { useMergedUnionSearchSpace } from "../searchSpace" -const PADDING_RATIO = 0.05 const plotDomId = "graph-rank" -interface AxisInfo { - name: string - range: [number, number] - isLog: boolean - isCat: boolean -} - interface RankPlotInfo { - xaxis: AxisInfo - yaxis: AxisInfo + xtitle: string + ytitle: string + xtype: plotly.AxisType + ytype: plotly.AxisType xvalues: (string | number)[] yvalues: (string | number)[] - zvalues: number[] colors: number[] is_feasible: boolean[] hovertext: string[] @@ -154,38 +147,74 @@ const getRankPlotInfo = ( const xAxis = getAxisInfo(filteredTrials, xParam) const yAxis = getAxisInfo(filteredTrials, yParam) - const xValues: (string | number)[] = [] - const yValues: (string | number)[] = [] + let xValues: (string | number)[] = [] + let yValues: (string | number)[] = [] const zValues: number[] = [] const isFeasible: boolean[] = [] const hovertext: string[] = [] - filteredTrials.forEach((trial) => { - const xValue = - trial.params.find((p) => p.name === xAxis.name)?.param_external_value || - null - const yValue = - trial.params.find((p) => p.name === yAxis.name)?.param_external_value || - null - if (trial.values === undefined || xValue === null || yValue === null) { - return + filteredTrials.forEach((trial, i) => { + const xValue = xAxis.values[i] + const yValue = yAxis.values[i] + if (xValue && yValue && trial.values) { + xValues.push(xValue) + yValues.push(yValue) + const zValue = Number(trial.values[objectiveId]) + zValues.push(zValue) + const feasibility = trial.constraints.every((c) => c <= 0) + isFeasible.push(feasibility) + hovertext.push(makeHovertext(trial)) } - const zValue = Number(trial.values[objectiveId]) - const feasibility = trial.constraints.every((c) => c <= 0) - xValues.push(xValue) - yValues.push(yValue) - zValues.push(zValue) - isFeasible.push(feasibility) - hovertext.push(makeHovertext(trial)) }) const colors = getColors(zValues) + if (xAxis.isCat && !yAxis.isCat) { + const indices: number[] = Array.from(Array(xValues.length).keys()).sort( + (a, b) => + xValues[a] + .toString() + .toLowerCase() + .localeCompare(xValues[b].toString().toLowerCase()) + ) + xValues = indices.map((i) => xValues[i]) + yValues = indices.map((i) => yValues[i]) + } else if (!xAxis.isCat && yAxis.isCat) { + const indices: number[] = Array.from(Array(yValues.length).keys()).sort( + (a, b) => + yValues[a] + .toString() + .toLowerCase() + .localeCompare(yValues[b].toString().toLowerCase()) + ) + xValues = indices.map((i) => xValues[i]) + yValues = indices.map((i) => yValues[i]) + } else if (xAxis.isCat && yAxis.isCat) { + const indices: number[] = Array.from(Array(xValues.length).keys()).sort( + (a, b) => { + const xComp = xValues[a] + .toString() + .toLowerCase() + .localeCompare(xValues[b].toString().toLowerCase()) + if (xComp !== 0) { + return xComp + } + return yValues[a] + .toString() + .toLowerCase() + .localeCompare(yValues[b].toString().toLowerCase()) + } + ) + xValues = indices.map((i) => xValues[i]) + yValues = indices.map((i) => yValues[i]) + } + return { - xaxis: xAxis, - yaxis: yAxis, + xtitle: xAxis.name, + ytitle: yAxis.name, + xtype: xAxis.isCat ? "category" : xAxis.isLog ? "log" : "linear", + ytype: yAxis.isCat ? "category" : yAxis.isLog ? "log" : "linear", xvalues: xValues, yvalues: yValues, - zvalues: zValues, colors, is_feasible: isFeasible, hovertext, @@ -196,72 +225,6 @@ const filterFunc = (trial: Trial): boolean => { return trial.state === "Complete" && trial.values !== undefined } -const getAxisInfo = (trials: Trial[], param: SearchSpaceItem): AxisInfo => { - if (param.distribution.type === "CategoricalDistribution") { - return getAxisInfoForCategorical(trials, param.name, param.distribution) - } else { - return getAxisInfoForNumerical(trials, param.name, param.distribution) - } -} - -const getAxisInfoForCategorical = ( - trials: Trial[], - param: string, - distribution: CategoricalDistribution -): AxisInfo => { - const values = trials.map( - (trial) => - trial.params.find((p) => p.name === param)?.param_internal_value || null - ) - const isDynamic = values.some((v) => v === null) - const span = distribution.choices.length - (isDynamic ? 2 : 1) - const padding = span * PADDING_RATIO - const min = -padding - const max = span + padding - - return { - name: param, - range: [min, max], - isLog: false, - isCat: true, - } -} - -const getAxisInfoForNumerical = ( - trials: Trial[], - param: string, - distribution: FloatDistribution | IntDistribution -): AxisInfo => { - const values = trials.map( - (trial) => - trial.params.find((p) => p.name === param)?.param_internal_value || null - ) - const nonNullValues: number[] = [] - values.forEach((value) => { - if (value !== null) { - nonNullValues.push(value) - } - }) - let min = Math.min(...nonNullValues) - let max = Math.max(...nonNullValues) - if (distribution.log) { - const padding = (Math.log10(max) - Math.log10(min)) * PADDING_RATIO - min = Math.pow(10, Math.log10(min) - padding) - max = Math.pow(10, Math.log10(max) + padding) - } else { - const padding = (max - min) * PADDING_RATIO - min = min - padding - max = max + padding - } - - return { - name: param, - range: [min, max], - isLog: distribution.log, - isCat: false, - } -} - const getColors = (values: number[]): number[] => { const rawRanks = getOrderWithSameOrderAveraging(values) let colorIdxs: number[] = [] @@ -299,16 +262,14 @@ const plotRank = (rankPlotInfo: RankPlotInfo | null, mode: string) => { return } - const xAxis = rankPlotInfo.xaxis - const yAxis = rankPlotInfo.yaxis const layout: Partial = { xaxis: { - title: xAxis.name, - type: xAxis.isCat ? "category" : xAxis.isLog ? "log" : "linear", + title: rankPlotInfo.xtitle, + type: rankPlotInfo.xtype, }, yaxis: { - title: yAxis.name, - type: yAxis.isCat ? "category" : yAxis.isLog ? "log" : "linear", + title: rankPlotInfo.ytitle, + type: rankPlotInfo.ytype, }, margin: { l: 50, @@ -320,49 +281,8 @@ const plotRank = (rankPlotInfo: RankPlotInfo | null, mode: string) => { template: mode === "dark" ? plotlyDarkTemplate : {}, } - let xValues = rankPlotInfo.xvalues - let yValues = rankPlotInfo.yvalues - if (xAxis.isCat && !yAxis.isCat) { - const xIndices: number[] = Array.from(Array(xValues.length).keys()).sort( - (a, b) => - xValues[a] - .toString() - .toLowerCase() - .localeCompare(xValues[b].toString().toLowerCase()) - ) - xValues = xIndices.map((i) => xValues[i]) - yValues = xIndices.map((i) => yValues[i]) - } - if (!xAxis.isCat && yAxis.isCat) { - const yIndices: number[] = Array.from(Array(yValues.length).keys()).sort( - (a, b) => - yValues[a] - .toString() - .toLowerCase() - .localeCompare(yValues[b].toString().toLowerCase()) - ) - xValues = yIndices.map((i) => xValues[i]) - yValues = yIndices.map((i) => yValues[i]) - } - if (xAxis.isCat && yAxis.isCat) { - const indices: number[] = Array.from(Array(xValues.length).keys()).sort( - (a, b) => { - const xComp = xValues[a] - .toString() - .toLowerCase() - .localeCompare(xValues[b].toString().toLowerCase()) - if (xComp !== 0) { - return xComp - } - return yValues[a] - .toString() - .toLowerCase() - .localeCompare(yValues[b].toString().toLowerCase()) - } - ) - xValues = indices.map((i) => xValues[i]) - yValues = indices.map((i) => yValues[i]) - } + const xValues = rankPlotInfo.xvalues + const yValues = rankPlotInfo.yvalues const plotData: Partial[] = [ { diff --git a/optuna_dashboard/ts/graphUtil.ts b/optuna_dashboard/ts/graphUtil.ts index 1b1c69cbd..35a9002b9 100644 --- a/optuna_dashboard/ts/graphUtil.ts +++ b/optuna_dashboard/ts/graphUtil.ts @@ -1,3 +1,104 @@ +const PADDING_RATIO = 0.05 + +export type AxisInfo = { + name: string + isLog: boolean + isCat: boolean + indices: (string | number)[] + values: (string | number | null)[] +} + +// eslint-disable-next-line @typescript-eslint/no-explicit-any +const unique = (array: any[]) => { + const knownElements = new Map() + array.forEach((elem) => knownElements.set(elem, true)) + return Array.from(knownElements.keys()) +} + +export const getAxisInfo = ( + trials: Trial[], + param: SearchSpaceItem +): AxisInfo => { + if (param.distribution.type === "CategoricalDistribution") { + return getAxisInfoForCategoricalParams( + trials, + param.name, + param.distribution + ) + } else { + return getAxisInfoForNumericalParams(trials, param.name, param.distribution) + } +} + +const getAxisInfoForCategoricalParams = ( + trials: Trial[], + paramName: string, + distribution: CategoricalDistribution +): AxisInfo => { + const values = trials.map( + (trial) => + trial.params.find((p) => p.name === paramName)?.param_external_value || + null + ) + + const indices = distribution.choices + .map((c) => c.value) + .sort((a, b) => + a.toLowerCase() < b.toLowerCase() + ? -1 + : a.toLowerCase() > b.toLowerCase() + ? 1 + : 0 + ) + return { + name: paramName, + isLog: false, + isCat: true, + indices, + values, + } +} + +const getAxisInfoForNumericalParams = ( + trials: Trial[], + paramName: string, + distribution: FloatDistribution | IntDistribution +): AxisInfo => { + let min = 0 + let max = 0 + if (distribution.log) { + const padding = + (Math.log10(distribution.high) - Math.log10(distribution.low)) * + PADDING_RATIO + min = Math.pow(10, Math.log10(distribution.low) - padding) + max = Math.pow(10, Math.log10(distribution.high) + padding) + } else { + const padding = (distribution.high - distribution.low) * PADDING_RATIO + min = distribution.low - padding + max = distribution.high + padding + } + + const values = trials.map( + (trial) => + trial.params.find((p) => p.name === paramName)?.param_internal_value || + null + ) + const indices = unique(values) + .filter((v) => v !== null) + .sort((a, b) => a - b) + if (indices.length >= 2) { + indices.unshift(min) + indices.push(max) + } + return { + name: paramName, + isLog: distribution.log, + isCat: false, + indices, + values, + } +} + export const makeHovertext = (trial: Trial): string => { return JSON.stringify( {