Skip to content

Commit

Permalink
Merge pull request #680 from keisuke-umezawa/fix/refactor-rank-plot
Browse files Browse the repository at this point in the history
Use the same implementation of getAxisInfo for rank and contour plot
  • Loading branch information
HideakiImamura authored Dec 4, 2023
2 parents 9eeda40 + d2748eb commit 9f021f5
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 256 deletions.
109 changes: 1 addition & 108 deletions optuna_dashboard/ts/components/GraphContour.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,8 @@ import {
import blue from "@mui/material/colors/blue"
import { plotlyDarkTemplate } from "./PlotlyDarkMode"
import { useMergedUnionSearchSpace } from "../searchSpace"
import { getAxisInfo } from "../graphUtil"

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const unique = (array: any[]) => {
const knownElements = new Map()
array.forEach((elem) => knownElements.set(elem, true))
return Array.from(knownElements.keys())
}

type AxisInfo = {
name: string
min: number
max: number
isLog: boolean
isCat: boolean
indices: (string | number)[]
values: (string | number | null)[]
}

const PADDING_RATIO = 0.05
const plotDomId = "graph-contour"

export const Contour: FC<{
Expand Down Expand Up @@ -284,93 +267,3 @@ const plotContour = (
]
plotly.react(plotDomId, plotData, layout)
}

const getAxisInfoForNumericalParams = (
trials: Trial[],
paramName: string,
distribution: FloatDistribution | IntDistribution
): AxisInfo => {
let min = 0
let max = 0
if (distribution.log) {
const padding =
(Math.log10(distribution.high) - Math.log10(distribution.low)) *
PADDING_RATIO
min = Math.pow(10, Math.log10(distribution.low) - padding)
max = Math.pow(10, Math.log10(distribution.high) + padding)
} else {
const padding = (distribution.high - distribution.low) * PADDING_RATIO
min = distribution.low - padding
max = distribution.high + padding
}

const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === paramName)?.param_internal_value ||
null
)
const indices = unique(values)
.filter((v) => v !== null)
.sort((a, b) => a - b)
if (indices.length >= 2) {
indices.unshift(min)
indices.push(max)
}
return {
name: paramName,
min,
max,
isLog: distribution.log,
isCat: false,
indices,
values,
}
}

const getAxisInfoForCategoricalParams = (
trials: Trial[],
paramName: string,
distribution: CategoricalDistribution
): AxisInfo => {
const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === paramName)?.param_external_value ||
null
)
const isDynamic = values.some((v) => v === null)
const span = distribution.choices.length - (isDynamic ? 2 : 1)
const padding = span * PADDING_RATIO
const min = -padding
const max = span + padding

const indices = distribution.choices
.map((c) => c.value)
.sort((a, b) =>
a.toLowerCase() < b.toLowerCase()
? -1
: a.toLowerCase() > b.toLowerCase()
? 1
: 0
)
return {
name: paramName,
min,
max,
isLog: false,
isCat: true,
indices,
values,
}
}

const getAxisInfo = (trials: Trial[], param: SearchSpaceItem): AxisInfo => {
if (param.distribution.type === "CategoricalDistribution") {
return getAxisInfoForCategoricalParams(
trials,
param.name,
param.distribution
)
} else {
return getAxisInfoForNumericalParams(trials, param.name, param.distribution)
}
}
216 changes: 68 additions & 148 deletions optuna_dashboard/ts/components/GraphRank.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,18 @@ import {
Box,
} from "@mui/material"
import { plotlyDarkTemplate } from "./PlotlyDarkMode"
import { makeHovertext } from "../graphUtil"
import { getAxisInfo, makeHovertext } from "../graphUtil"
import { useMergedUnionSearchSpace } from "../searchSpace"

const PADDING_RATIO = 0.05
const plotDomId = "graph-rank"

interface AxisInfo {
name: string
range: [number, number]
isLog: boolean
isCat: boolean
}

interface RankPlotInfo {
xaxis: AxisInfo
yaxis: AxisInfo
xtitle: string
ytitle: string
xtype: plotly.AxisType
ytype: plotly.AxisType
xvalues: (string | number)[]
yvalues: (string | number)[]
zvalues: number[]
colors: number[]
is_feasible: boolean[]
hovertext: string[]
Expand Down Expand Up @@ -154,38 +147,74 @@ const getRankPlotInfo = (
const xAxis = getAxisInfo(filteredTrials, xParam)
const yAxis = getAxisInfo(filteredTrials, yParam)

const xValues: (string | number)[] = []
const yValues: (string | number)[] = []
let xValues: (string | number)[] = []
let yValues: (string | number)[] = []
const zValues: number[] = []
const isFeasible: boolean[] = []
const hovertext: string[] = []
filteredTrials.forEach((trial) => {
const xValue =
trial.params.find((p) => p.name === xAxis.name)?.param_external_value ||
null
const yValue =
trial.params.find((p) => p.name === yAxis.name)?.param_external_value ||
null
if (trial.values === undefined || xValue === null || yValue === null) {
return
filteredTrials.forEach((trial, i) => {
const xValue = xAxis.values[i]
const yValue = yAxis.values[i]
if (xValue && yValue && trial.values) {
xValues.push(xValue)
yValues.push(yValue)
const zValue = Number(trial.values[objectiveId])
zValues.push(zValue)
const feasibility = trial.constraints.every((c) => c <= 0)
isFeasible.push(feasibility)
hovertext.push(makeHovertext(trial))
}
const zValue = Number(trial.values[objectiveId])
const feasibility = trial.constraints.every((c) => c <= 0)
xValues.push(xValue)
yValues.push(yValue)
zValues.push(zValue)
isFeasible.push(feasibility)
hovertext.push(makeHovertext(trial))
})

const colors = getColors(zValues)

if (xAxis.isCat && !yAxis.isCat) {
const indices: number[] = Array.from(Array(xValues.length).keys()).sort(
(a, b) =>
xValues[a]
.toString()
.toLowerCase()
.localeCompare(xValues[b].toString().toLowerCase())
)
xValues = indices.map((i) => xValues[i])
yValues = indices.map((i) => yValues[i])
} else if (!xAxis.isCat && yAxis.isCat) {
const indices: number[] = Array.from(Array(yValues.length).keys()).sort(
(a, b) =>
yValues[a]
.toString()
.toLowerCase()
.localeCompare(yValues[b].toString().toLowerCase())
)
xValues = indices.map((i) => xValues[i])
yValues = indices.map((i) => yValues[i])
} else if (xAxis.isCat && yAxis.isCat) {
const indices: number[] = Array.from(Array(xValues.length).keys()).sort(
(a, b) => {
const xComp = xValues[a]
.toString()
.toLowerCase()
.localeCompare(xValues[b].toString().toLowerCase())
if (xComp !== 0) {
return xComp
}
return yValues[a]
.toString()
.toLowerCase()
.localeCompare(yValues[b].toString().toLowerCase())
}
)
xValues = indices.map((i) => xValues[i])
yValues = indices.map((i) => yValues[i])
}

return {
xaxis: xAxis,
yaxis: yAxis,
xtitle: xAxis.name,
ytitle: yAxis.name,
xtype: xAxis.isCat ? "category" : xAxis.isLog ? "log" : "linear",
ytype: yAxis.isCat ? "category" : yAxis.isLog ? "log" : "linear",
xvalues: xValues,
yvalues: yValues,
zvalues: zValues,
colors,
is_feasible: isFeasible,
hovertext,
Expand All @@ -196,72 +225,6 @@ const filterFunc = (trial: Trial): boolean => {
return trial.state === "Complete" && trial.values !== undefined
}

const getAxisInfo = (trials: Trial[], param: SearchSpaceItem): AxisInfo => {
if (param.distribution.type === "CategoricalDistribution") {
return getAxisInfoForCategorical(trials, param.name, param.distribution)
} else {
return getAxisInfoForNumerical(trials, param.name, param.distribution)
}
}

const getAxisInfoForCategorical = (
trials: Trial[],
param: string,
distribution: CategoricalDistribution
): AxisInfo => {
const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === param)?.param_internal_value || null
)
const isDynamic = values.some((v) => v === null)
const span = distribution.choices.length - (isDynamic ? 2 : 1)
const padding = span * PADDING_RATIO
const min = -padding
const max = span + padding

return {
name: param,
range: [min, max],
isLog: false,
isCat: true,
}
}

const getAxisInfoForNumerical = (
trials: Trial[],
param: string,
distribution: FloatDistribution | IntDistribution
): AxisInfo => {
const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === param)?.param_internal_value || null
)
const nonNullValues: number[] = []
values.forEach((value) => {
if (value !== null) {
nonNullValues.push(value)
}
})
let min = Math.min(...nonNullValues)
let max = Math.max(...nonNullValues)
if (distribution.log) {
const padding = (Math.log10(max) - Math.log10(min)) * PADDING_RATIO
min = Math.pow(10, Math.log10(min) - padding)
max = Math.pow(10, Math.log10(max) + padding)
} else {
const padding = (max - min) * PADDING_RATIO
min = min - padding
max = max + padding
}

return {
name: param,
range: [min, max],
isLog: distribution.log,
isCat: false,
}
}

const getColors = (values: number[]): number[] => {
const rawRanks = getOrderWithSameOrderAveraging(values)
let colorIdxs: number[] = []
Expand Down Expand Up @@ -299,16 +262,14 @@ const plotRank = (rankPlotInfo: RankPlotInfo | null, mode: string) => {
return
}

const xAxis = rankPlotInfo.xaxis
const yAxis = rankPlotInfo.yaxis
const layout: Partial<plotly.Layout> = {
xaxis: {
title: xAxis.name,
type: xAxis.isCat ? "category" : xAxis.isLog ? "log" : "linear",
title: rankPlotInfo.xtitle,
type: rankPlotInfo.xtype,
},
yaxis: {
title: yAxis.name,
type: yAxis.isCat ? "category" : yAxis.isLog ? "log" : "linear",
title: rankPlotInfo.ytitle,
type: rankPlotInfo.ytype,
},
margin: {
l: 50,
Expand All @@ -320,49 +281,8 @@ const plotRank = (rankPlotInfo: RankPlotInfo | null, mode: string) => {
template: mode === "dark" ? plotlyDarkTemplate : {},
}

let xValues = rankPlotInfo.xvalues
let yValues = rankPlotInfo.yvalues
if (xAxis.isCat && !yAxis.isCat) {
const xIndices: number[] = Array.from(Array(xValues.length).keys()).sort(
(a, b) =>
xValues[a]
.toString()
.toLowerCase()
.localeCompare(xValues[b].toString().toLowerCase())
)
xValues = xIndices.map((i) => xValues[i])
yValues = xIndices.map((i) => yValues[i])
}
if (!xAxis.isCat && yAxis.isCat) {
const yIndices: number[] = Array.from(Array(yValues.length).keys()).sort(
(a, b) =>
yValues[a]
.toString()
.toLowerCase()
.localeCompare(yValues[b].toString().toLowerCase())
)
xValues = yIndices.map((i) => xValues[i])
yValues = yIndices.map((i) => yValues[i])
}
if (xAxis.isCat && yAxis.isCat) {
const indices: number[] = Array.from(Array(xValues.length).keys()).sort(
(a, b) => {
const xComp = xValues[a]
.toString()
.toLowerCase()
.localeCompare(xValues[b].toString().toLowerCase())
if (xComp !== 0) {
return xComp
}
return yValues[a]
.toString()
.toLowerCase()
.localeCompare(yValues[b].toString().toLowerCase())
}
)
xValues = indices.map((i) => xValues[i])
yValues = indices.map((i) => yValues[i])
}
const xValues = rankPlotInfo.xvalues
const yValues = rankPlotInfo.yvalues

const plotData: Partial<plotly.PlotData>[] = [
{
Expand Down
Loading

0 comments on commit 9f021f5

Please sign in to comment.