Skip to content

Commit

Permalink
Move getAxis to tslib
Browse files Browse the repository at this point in the history
  • Loading branch information
keisuke-umezawa committed Nov 1, 2024
1 parent 12319fa commit 42380e2
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 118 deletions.
1 change: 1 addition & 0 deletions optuna_dashboard/ts/components/GraphContour.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
import blue from "@mui/material/colors/blue"
import {
GraphContainer,
getAxisInfo,
useGraphComponentState,
useMergedUnionSearchSpace,
} from "@optuna/react"
Expand Down
2 changes: 1 addition & 1 deletion optuna_dashboard/ts/components/GraphParetoFront.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ import {
Typography,
useTheme,
} from "@mui/material"
import { makeHovertext } from "@optuna/react"
import * as Optuna from "@optuna/types"
import * as plotly from "plotly.js-dist-min"
import React, { FC, useEffect, useState } from "react"
import { useNavigate } from "react-router-dom"
import { StudyDetail, Trial } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { useConstants } from "../constantsProvider"
import { makeHovertext } from "../graphUtil"
import { usePlot } from "../hooks/usePlot"
import { usePlotlyColorTheme } from "../state"
import { useBackendRender } from "../state"
Expand Down
3 changes: 2 additions & 1 deletion optuna_dashboard/ts/components/GraphRank.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ import {
} from "@mui/material"
import {
GraphContainer,
getAxisInfo,
useGraphComponentState,
useMergedUnionSearchSpace,
} from "@optuna/react"
import { makeHovertext } from "@optuna/react"
import * as plotly from "plotly.js-dist-min"
import React, { FC, useEffect, useState } from "react"
import { SearchSpaceItem, StudyDetail, Trial } from "ts/types/optuna"
import { PlotType } from "../apiClient"
import { getAxisInfo, makeHovertext } from "../graphUtil"
import { usePlot } from "../hooks/usePlot"
import { useBackendRender, usePlotlyColorTheme } from "../state"

Expand Down
117 changes: 1 addition & 116 deletions optuna_dashboard/ts/graphUtil.ts
Original file line number Diff line number Diff line change
@@ -1,120 +1,5 @@
import * as Optuna from "@optuna/types"
import { SearchSpaceItem, StudyDetail, Trial } from "./types/optuna"

const PADDING_RATIO = 0.05

export type AxisInfo = {
name: string
isLog: boolean
isCat: boolean
indices: (string | number)[]
values: (string | number | null)[]
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const unique = (array: any[]) => {
const knownElements = new Map()
array.forEach((elem) => knownElements.set(elem, true))
return Array.from(knownElements.keys())
}

export const getAxisInfo = (
trials: Trial[],
param: SearchSpaceItem
): AxisInfo => {
if (param.distribution.type === "CategoricalDistribution") {
return getAxisInfoForCategoricalParams(
trials,
param.name,
param.distribution
)
} else {
return getAxisInfoForNumericalParams(trials, param.name, param.distribution)
}
}

const getAxisInfoForCategoricalParams = (
trials: Trial[],
paramName: string,
distribution: Optuna.CategoricalDistribution
): AxisInfo => {
const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === paramName)?.param_external_value ||
null
)

const indices = distribution.choices
.map((c) => c?.toString() ?? "null")
.sort((a, b) =>
a.toLowerCase() < b.toLowerCase()
? -1
: a.toLowerCase() > b.toLowerCase()
? 1
: 0
)
return {
name: paramName,
isLog: false,
isCat: true,
indices,
values,
}
}

const getAxisInfoForNumericalParams = (
trials: Trial[],
paramName: string,
distribution: Optuna.FloatDistribution | Optuna.IntDistribution
): AxisInfo => {
let min = 0
let max = 0
if (distribution.log) {
const padding =
(Math.log10(distribution.high) - Math.log10(distribution.low)) *
PADDING_RATIO
min = Math.pow(10, Math.log10(distribution.low) - padding)
max = Math.pow(10, Math.log10(distribution.high) + padding)
} else {
const padding = (distribution.high - distribution.low) * PADDING_RATIO
min = distribution.low - padding
max = distribution.high + padding
}

const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === paramName)?.param_internal_value ||
null
)
const indices = unique(values)
.filter((v) => v !== null)
.sort((a, b) => a - b)
if (indices.length >= 2) {
indices.unshift(min)
indices.push(max)
}
return {
name: paramName,
isLog: distribution.log,
isCat: false,
indices,
values,
}
}

export const makeHovertext = (trial: Trial): string => {
return JSON.stringify(
{
number: trial.number,
values: trial.values,
params: trial.params
.map((p) => [p.name, p.param_external_value])
.reduce((obj, [key, value]) => ({ ...obj, [key]: value }), {}),
},
undefined,
" "
).replace(/\n/g, "<br>")
}
import { StudyDetail } from "./types/optuna"

export const studyDetailToStudy = (
studyDetail: StudyDetail | null
Expand Down
2 changes: 2 additions & 0 deletions tslib/react/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ export {
useObjectiveAndUserAttrTargetsFromStudies,
} from "./utils/trialFilter"
export { useMergedUnionSearchSpace } from "./utils/searchSpace"
export { makeHovertext, getAxisInfo } from "./utils/graphUtil"
export type { AxisInfo } from "./utils/graphUtil"
export type { GraphComponentState } from "./types"
101 changes: 101 additions & 0 deletions tslib/react/src/utils/graphUtil.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,22 @@
import * as Optuna from "@optuna/types"

const PADDING_RATIO = 0.05

export type AxisInfo = {
name: string
isLog: boolean
isCat: boolean
indices: (string | number)[]
values: (string | number | boolean | null)[]
}

// eslint-disable-next-line @typescript-eslint/no-explicit-any
const unique = (array: any[]) => {

Check failure on line 14 in tslib/react/src/utils/graphUtil.ts

View workflow job for this annotation

GitHub Actions / Lint checking on Ubuntu

Unexpected any. Specify a different type.
const knownElements = new Map()
array.forEach((elem) => knownElements.set(elem, true))

Check failure on line 16 in tslib/react/src/utils/graphUtil.ts

View workflow job for this annotation

GitHub Actions / Lint checking on Ubuntu

Prefer for...of instead of forEach.
return Array.from(knownElements.keys())
}

export const makeHovertext = (trial: Optuna.Trial): string => {
return JSON.stringify(
{
Expand All @@ -16,3 +33,87 @@ export const makeHovertext = (trial: Optuna.Trial): string => {
" "
).replace(/\n/g, "<br>")
}

export const getAxisInfo = (
trials: Optuna.Trial[],
param: Optuna.SearchSpaceItem
): AxisInfo => {
if (param.distribution.type === "CategoricalDistribution") {
return getAxisInfoForCategoricalParams(
trials,
param.name,
param.distribution
)
} else {
return getAxisInfoForNumericalParams(trials, param.name, param.distribution)
}

Check failure on line 49 in tslib/react/src/utils/graphUtil.ts

View workflow job for this annotation

GitHub Actions / Lint checking on Ubuntu

This else clause can be omitted because previous branches break early.
}

const getAxisInfoForCategoricalParams = (
trials: Optuna.Trial[],
paramName: string,
distribution: Optuna.CategoricalDistribution
): AxisInfo => {
const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === paramName)?.param_external_value ||
null
)

const indices = distribution.choices
.map((c) => c?.toString() ?? "null")
.sort((a, b) =>
a.toLowerCase() < b.toLowerCase()
? -1
: a.toLowerCase() > b.toLowerCase()
? 1
: 0
)
return {
name: paramName,
isLog: false,
isCat: true,
indices,
values,
}
}

const getAxisInfoForNumericalParams = (
trials: Optuna.Trial[],
paramName: string,
distribution: Optuna.FloatDistribution | Optuna.IntDistribution
): AxisInfo => {
let min = 0
let max = 0
if (distribution.log) {
const padding =
(Math.log10(distribution.high) - Math.log10(distribution.low)) *
PADDING_RATIO
min = Math.pow(10, Math.log10(distribution.low) - padding)

Check failure on line 92 in tslib/react/src/utils/graphUtil.ts

View workflow job for this annotation

GitHub Actions / Lint checking on Ubuntu

Use the '**' operator instead of 'Math.pow'.
max = Math.pow(10, Math.log10(distribution.high) + padding)

Check failure on line 93 in tslib/react/src/utils/graphUtil.ts

View workflow job for this annotation

GitHub Actions / Lint checking on Ubuntu

Use the '**' operator instead of 'Math.pow'.
} else {
const padding = (distribution.high - distribution.low) * PADDING_RATIO
min = distribution.low - padding
max = distribution.high + padding
}

const values = trials.map(
(trial) =>
trial.params.find((p) => p.name === paramName)?.param_internal_value ||
null
)
const indices = unique(values)
.filter((v) => v !== null)
.sort((a, b) => a - b)
if (indices.length >= 2) {
indices.unshift(min)
indices.push(max)
}
return {
name: paramName,
isLog: distribution.log,
isCat: false,
indices,
values,
}
}

0 comments on commit 42380e2

Please sign in to comment.