diff --git a/optuna_dashboard/ts/components/DataGrid.tsx b/optuna_dashboard/ts/components/DataGrid.tsx index d81abc25a..adf64e774 100644 --- a/optuna_dashboard/ts/components/DataGrid.tsx +++ b/optuna_dashboard/ts/components/DataGrid.tsx @@ -10,12 +10,16 @@ import { TableSortLabel, Collapse, IconButton, - useTheme, + Menu, + MenuItem, } from "@mui/material" import { styled } from "@mui/system" import KeyboardArrowDownIcon from "@mui/icons-material/KeyboardArrowDown" import KeyboardArrowUpIcon from "@mui/icons-material/KeyboardArrowUp" -import { Clear } from "@mui/icons-material" +import CheckBoxOutlineBlankIcon from "@mui/icons-material/CheckBoxOutlineBlank" +import CheckBoxIcon from "@mui/icons-material/CheckBox" +import FilterListIcon from "@mui/icons-material/FilterList" +import ListItemIcon from "@mui/material/ListItemIcon" type Order = "asc" | "desc" @@ -29,14 +33,14 @@ interface DataGridColumn { label: string sortable?: boolean less?: (a: T, b: T, ascending: boolean) => number - filterable?: boolean + filterChoices?: string[] toCellValue?: (rowIndex: number) => string | React.ReactNode padding?: "normal" | "checkbox" | "none" } interface RowFilter { columnIdx: number - value: Value + values: Value[] } function DataGrid(props: { @@ -81,24 +85,13 @@ function DataGrid(props: { } // Filtering - const fieldAlreadyFiltered = (columnIdx: number): boolean => - filters.some((f) => f.columnIdx === columnIdx) - - const handleClickFilterCell = (columnIdx: number, value: Value) => { - if (fieldAlreadyFiltered(columnIdx)) { - return - } - const newFilters = [...filters, { columnIdx: columnIdx, value: value }] - setFilters(newFilters) - } - const filteredRows = rows.filter((row, rowIdx) => { if (defaultFilter !== undefined && defaultFilter(row)) { return false } return filters.length === 0 ? true - : filters.some((f) => { + : filters.every((f) => { if (columns.length <= f.columnIdx) { console.log( `columnIdx=${f.columnIdx} must be smaller than columns.length=${columns.length}` @@ -106,11 +99,11 @@ function DataGrid(props: { return true } const toCellValue = columns[f.columnIdx].toCellValue - if (toCellValue !== undefined) { - return toCellValue(rowIdx) === f.value - } - const field = columns[f.columnIdx].field - return row[field] === f.value + const cellValue = + toCellValue !== undefined + ? toCellValue(rowIdx) + : row[columns[f.columnIdx].field] + return f.values.some((v) => v === cellValue) }) }) @@ -137,21 +130,32 @@ function DataGrid(props: { {collapseBody ? : null} - {columns.map((column, columnIdx) => ( - - key={column.label} - column={column} - orderBy={orderBy === columnIdx ? order : null} - onOrderByChange={(direction: Order) => { - setOrder(direction) - setOrderBy(columnIdx) - }} - onFilterClear={() => { - setFilters(filters.filter((f) => f.columnIdx !== columnIdx)) - }} - filtered={fieldAlreadyFiltered(columnIdx)} - /> - ))} + {columns.map((column, columnIdx) => { + return ( + + key={columnIdx} + column={column} + order={orderBy === columnIdx ? order : null} + filter={ + filters.find((f) => f.columnIdx === columnIdx) || null + } + onOrderByChange={(direction: Order) => { + setOrder(direction) + setOrderBy(columnIdx) + }} + onFilterChange={(values: Value[]) => { + const newFilters = filters.filter( + (f) => f.columnIdx !== columnIdx + ) + newFilters.push({ + columnIdx: columnIdx, + values: values, + }) + setFilters(newFilters) + }} + /> + ) + })} @@ -163,7 +167,6 @@ function DataGrid(props: { keyField={keyField} collapseBody={collapseBody} key={`${row[keyField]}`} - handleClickFilterCell={handleClickFilterCell} /> ))} {emptyRows > 0 && ( @@ -187,70 +190,103 @@ function DataGrid(props: { ) } +const TableHeaderCellSpan = styled("span")({ + display: "inline-flex", +}) + +const HiddenSpan = styled("span")({ + border: 0, + clip: "rect(0 0 0 0)", + height: 1, + margin: -1, + overflow: "hidden", + padding: 0, + position: "absolute", + top: 20, + width: 1, +}) + function DataGridHeaderColumn(props: { column: DataGridColumn - orderBy: Order | null - onOrderByChange: (direction: Order) => void - filtered: boolean - onFilterClear: () => void + order: Order | null + onOrderByChange: (order: Order) => void + filter: RowFilter | null + onFilterChange: (values: Value[]) => void dense?: boolean }) { - const { column, orderBy, onOrderByChange, filtered, onFilterClear, dense } = + const { column, order, onOrderByChange, filter, onFilterChange, dense } = props + const [filterMenuAnchorEl, setFilterMenuAnchorEl] = + React.useState(null) + + const filterChoices = column.filterChoices - const HiddenSpan = styled("span")({ - border: 0, - clip: "rect(0 0 0 0)", - height: 1, - margin: -1, - overflow: "hidden", - padding: 0, - position: "absolute", - top: 20, - width: 1, - }) - const TableHeaderCellSpan = styled("span")({ - display: "inline-flex", - }) return ( {column.sortable ? ( { - if (orderBy === null) { - onOrderByChange("asc") - } else { - onOrderByChange(orderBy === "desc" ? "asc" : "desc") - } + onOrderByChange(order === "asc" ? "desc" : "asc") }} > {column.label} - {orderBy !== null ? ( + {order !== null ? ( - {orderBy === "desc" ? "sorted descending" : "sorted ascending"} + {order === "desc" ? "sorted descending" : "sorted ascending"} ) : null} ) : ( column.label )} - {column.filterable ? ( - { - onFilterClear() - }} - > - - + {filterChoices !== undefined ? ( + <> + { + setFilterMenuAnchorEl(e.currentTarget) + }} + > + + + { + setFilterMenuAnchorEl(null) + }} + > + {filterChoices.map((choice) => ( + { + const newTickedValues = + filter === null + ? filterChoices.filter((v) => v !== choice) // By default, every choice is ticked, so the chosen option will be unticked. + : filter.values.some((v) => v === choice) + ? filter.values.filter((v) => v !== choice) + : [...filter.values, choice] + onFilterChange(newTickedValues) + }} + > + + {!filter || filter.values.some((v) => v === choice) ? ( + + ) : ( + + )} + + {choice} + + ))} + + ) : null} @@ -263,24 +299,10 @@ function DataGridRow(props: { row: T keyField: keyof T collapseBody?: (rowIndex: number) => React.ReactNode - handleClickFilterCell: (columnIdx: number, value: Value) => void }) { - const { - columns, - rowIndex, - row, - keyField, - collapseBody, - handleClickFilterCell, - } = props + const { columns, rowIndex, row, keyField, collapseBody } = props const [open, setOpen] = React.useState(false) - const theme = useTheme() - const FilterableDiv = styled("div")({ - color: theme.palette.primary.main, - textDecoration: "underline", - cursor: "pointer", - }) return ( @@ -301,21 +323,7 @@ function DataGridRow(props: { : // TODO(c-bata): Avoid this implicit type conversion. (row[column.field] as number | string | null | undefined) - return column.filterable ? ( - { - const value = - column.toCellValue !== undefined - ? column.toCellValue(rowIndex) - : row[column.field] - handleClickFilterCell(columnIndex, value) - }} - > - {cellItem} - - ) : ( + return ( trials[i].state.toString(), }, @@ -97,7 +97,10 @@ export const TrialTable: FC<{ ) { studyDetail?.intersection_search_space.forEach((s) => { const sortable = s.distribution.type !== "CategoricalDistribution" - const filterable = s.distribution.type === "CategoricalDistribution" + const filterChoices = + s.distribution.type === "CategoricalDistribution" + ? s.distribution.choices.map((c) => c.value) + : undefined columns.push({ field: "params", label: `Param ${s.name}`, @@ -105,7 +108,7 @@ export const TrialTable: FC<{ trials[i].params.find((p) => p.name === s.name) ?.param_external_value || null, sortable: sortable, - filterable: filterable, + filterChoices: filterChoices, // eslint-disable-next-line @typescript-eslint/no-unused-vars less: (firstEl, secondEl, _): number => { const firstVal = firstEl.params.find( @@ -146,7 +149,6 @@ export const TrialTable: FC<{ trials[i].user_attrs.find((attr) => attr.key === attr_spec.key) ?.value || null, sortable: attr_spec.sortable, - filterable: !attr_spec.sortable, // eslint-disable-next-line @typescript-eslint/no-unused-vars less: (firstEl, secondEl, _): number => { const firstVal = firstEl.user_attrs.find( diff --git a/typescript_tests/DataGrid.test.tsx b/typescript_tests/DataGrid.test.tsx index 475e360ab..6d13deb74 100644 --- a/typescript_tests/DataGrid.test.tsx +++ b/typescript_tests/DataGrid.test.tsx @@ -1,7 +1,7 @@ import React from "react" global.URL.createObjectURL = jest.fn() -import { cleanup, render, fireEvent } from "@testing-library/react" +import { cleanup, render } from "@testing-library/react" import { DataGrid, DataGridColumn, @@ -9,6 +9,7 @@ import { afterEach(cleanup) +// TODO(c-bata): Add tests to check filterChoices option it("Filter rows of DataGrid", () => { interface DummyAttribute { id: number @@ -23,7 +24,7 @@ it("Filter rows of DataGrid", () => { { id: 5, key: "foo", value: 3 }, ] const columns: DataGridColumn[] = [ - { field: "key", label: "Key", filterable: true }, + { field: "key", label: "Key" }, { field: "value", label: "Value", @@ -39,46 +40,4 @@ it("Filter rows of DataGrid", () => { /> ) expect(queryAllByText("bar").length).toBe(2) - - // Filter rows by "foo" - fireEvent.click(queryAllByText("foo")[0]) - expect(queryAllByText("foo").length).toBe(3) - expect(queryAllByText("bar").length).toBe(0) -}) - -it("Filter rows after sorted", () => { - interface DummyAttribute { - id: number - key: string - value: number - } - const dummyAttributes = [ - { id: 1, key: "foo", value: 4000 }, - { id: 2, key: "bar", value: 1000 }, - { id: 3, key: "bar", value: 2000 }, - { id: 4, key: "foo", value: 3000 }, - { id: 5, key: "foo", value: 5000 }, - ] - const columns: DataGridColumn[] = [ - { field: "key", label: "Key", filterable: true }, - { - field: "value", - label: "Value", - sortable: true, - }, - ] - - const { getByText, queryAllByText } = render( - - columns={columns} - rows={dummyAttributes} - keyField={"id"} - /> - ) - // Sort and filter rows - fireEvent.click(getByText("Value")) - fireEvent.click(queryAllByText("bar")[0]) - - expect(queryAllByText("1000").length).toBe(1) - expect(queryAllByText("2000").length).toBe(1) })