Skip to content

Commit

Permalink
Merge pull request optuna#608 from moririn2528/history-undo
Browse files Browse the repository at this point in the history
remove/restore History
  • Loading branch information
c-bata authored Sep 22, 2023
2 parents cbcb920 + 02bfb79 commit ae11fda
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 62 deletions.
8 changes: 3 additions & 5 deletions optuna_dashboard/_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,8 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
) = get_cached_extra_study_property(study_id, trials)

plotly_graph_objects = get_plotly_graph_objects(system_attrs)
trials_id2number = {trial._trial_id: trial.number for trial in trials}
skipped_trials = [
trials_id2number[trial_id] for trial_id in get_skipped_trial_ids(system_attrs)
]
skipped_trial_ids = get_skipped_trial_ids(system_attrs)
skipped_trial_numbers = [t.number for t in trials if t._trial_id in skipped_trial_ids]
return serialize_study_detail(
summary,
best_trials,
Expand All @@ -235,7 +233,7 @@ def get_study_detail(study_id: int) -> dict[str, Any]:
union_user_attrs,
has_intermediate_values,
plotly_graph_objects,
skipped_trials,
skipped_trial_numbers,
)

@app.get("/api/studies/<study_id:int>/param_importances")
Expand Down
4 changes: 2 additions & 2 deletions optuna_dashboard/_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def serialize_study_detail(
union_user_attrs: list[tuple[str, bool]],
has_intermediate_values: bool,
plotly_graph_objects: dict[str, str],
skipped_trials: list[int],
skipped_trial_numbers: list[int],
) -> dict[str, Any]:
serialized: dict[str, Any] = {
"name": summary.study_name,
Expand Down Expand Up @@ -174,7 +174,7 @@ def serialize_study_detail(
if serialized["is_preferential"]:
serialized["preference_history"] = serialize_preference_history(system_attrs)
serialized["preferences"] = get_preferences(system_attrs)
serialized["skipped_trials"] = skipped_trials
serialized["skipped_trial_numbers"] = skipped_trial_numbers
serialized["plotly_graph_objects"] = [
{"id": id_, "graph_object": graph_object}
for id_, graph_object in plotly_graph_objects.items()
Expand Down
55 changes: 52 additions & 3 deletions optuna_dashboard/ts/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import {
deleteArtifactAPI,
reportPreferenceAPI,
skipPreferentialTrialAPI,
removePreferentialHistoryAPI,
restorePreferentialHistoryAPI,
reportFeedbackComponentAPI,
} from "./apiClient"
import {
Expand Down Expand Up @@ -587,11 +589,11 @@ export const actionCreator = () => {
}

const updatePreference = (
study_id: number,
studyId: number,
candidates: number[],
clicked: number
) => {
reportPreferenceAPI(study_id, candidates, clicked).catch((err) => {
reportPreferenceAPI(studyId, candidates, clicked).catch((err) => {
const reason = err.response?.data.reason
enqueueSnackbar(`Failed to report preference. Reason: ${reason}`, {
variant: "error",
Expand All @@ -609,7 +611,6 @@ export const actionCreator = () => {
console.log(err)
})
}

const updateFeedbackComponent = (
studyId: number,
compoennt_type: FeedbackComponentType
Expand All @@ -632,6 +633,52 @@ export const actionCreator = () => {
})
}

const removePreferentialHistory = (studyId: number, historyId: string) => {
removePreferentialHistoryAPI(studyId, historyId)
.then(() => {
const newStudy = Object.assign({}, studyDetails[studyId])
newStudy.preference_history = newStudy.preference_history?.map((h) =>
h.id === historyId ? { ...h, is_removed: true } : h
)
const removed = newStudy.preference_history
?.filter((h) => h.id === historyId)
.pop()?.preferences
newStudy.preferences = newStudy.preferences?.filter(
(p) => !removed?.some((r) => r[0] === p[0] && r[1] === p[1])
)
setStudyDetailState(studyId, newStudy)
})
.catch((err) => {
const reason = err.response?.data.reason

enqueueSnackbar(`Failed to switch history. Reason: ${reason}`, {
variant: "error",
})
console.log(err)
})
}
const restorePreferentialHistory = (studyId: number, historyId: string) => {
restorePreferentialHistoryAPI(studyId, historyId)
.then(() => {
const newStudy = Object.assign({}, studyDetails[studyId])
newStudy.preference_history = newStudy.preference_history?.map((h) =>
h.id === historyId ? { ...h, is_removed: false } : h
)
const restored = newStudy.preference_history
?.filter((h) => h.id === historyId)
.pop()?.preferences
newStudy.preferences = newStudy.preferences?.concat(restored ?? [])
setStudyDetailState(studyId, newStudy)
})
.catch((err) => {
const reason = err.response?.data.reason
enqueueSnackbar(`Failed to switch history. Reason: ${reason}`, {
variant: "error",
})
console.log(err)
})
}

return {
updateAPIMeta,
updateStudyDetail,
Expand All @@ -653,6 +700,8 @@ export const actionCreator = () => {
saveTrialUserAttrs,
updatePreference,
skipPreferentialTrial,
removePreferentialHistory,
restorePreferentialHistory,
updateFeedbackComponent,
}
}
Expand Down
33 changes: 26 additions & 7 deletions optuna_dashboard/ts/apiClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ const convertTrialResponse = (res: TrialResponse): Trial => {
}
}

interface PreferenceHistoryResponce {
interface PreferenceHistoryResponse {
history: {
id: string
preference_id: string
candidates: number[]
clicked: number
mode: PreferenceFeedbackMode
Expand All @@ -69,11 +68,10 @@ interface PreferenceHistoryResponce {
}

const convertPreferenceHistory = (
res: PreferenceHistoryResponce
res: PreferenceHistoryResponse
): PreferenceHistory => {
return {
id: res.history.id,
preference_id: res.history.preference_id,
candidates: res.history.candidates,
clicked: res.history.clicked,
feedback_mode: res.history.mode,
Expand All @@ -99,10 +97,10 @@ interface StudyDetailResponse {
objective_names?: string[]
form_widgets?: FormWidgets
preferences?: [number, number][]
preference_history?: PreferenceHistoryResponce[]
preference_history?: PreferenceHistoryResponse[]
plotly_graph_objects: PlotlyGraphObject[]
feedback_component_type: FeedbackComponentType
skipped_trials?: number[]
skipped_trial_numbers?: number[]
}

export const getStudyDetailAPI = (
Expand Down Expand Up @@ -144,7 +142,7 @@ export const getStudyDetailAPI = (
convertPreferenceHistory
),
plotly_graph_objects: res.data.plotly_graph_objects,
skipped_trials: res.data.skipped_trials ?? [],
skipped_trial_numbers: res.data.skipped_trial_numbers ?? [],
}
})
}
Expand Down Expand Up @@ -379,6 +377,27 @@ export const skipPreferentialTrialAPI = (
})
}

export const removePreferentialHistoryAPI = (
studyId: number,
historyUuid: string
): Promise<void> => {
return axiosInstance
.delete<void>(`/api/studies/${studyId}/preference/${historyUuid}`)
.then(() => {
return
})
}
export const restorePreferentialHistoryAPI = (
studyId: number,
historyUuid: string
): Promise<void> => {
return axiosInstance
.post<void>(`/api/studies/${studyId}/preference/${historyUuid}`)
.then(() => {
return
})
}

export const reportFeedbackComponentAPI = (
studyId: number,
component_type: FeedbackComponentType
Expand Down
64 changes: 55 additions & 9 deletions optuna_dashboard/ts/components/PreferenceHistory.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ import {
import ClearIcon from "@mui/icons-material/Clear"
import IconButton from "@mui/material/IconButton"
import OpenInFullIcon from "@mui/icons-material/OpenInFull"
import RestoreFromTrashIcon from "@mui/icons-material/RestoreFromTrash"
import DeleteIcon from "@mui/icons-material/Delete"
import Modal from "@mui/material/Modal"
import { red } from "@mui/material/colors"

import { TrialListDetail } from "./TrialList"
import { getArtifactUrlPath } from "./PreferentialTrials"
import { formatDate } from "../dateUtil"
import { actionCreator } from "../action"
import { useStudyDetailValue } from "../state"
import { PreferentialOutputComponent } from "./PreferentialOutputComponent"

Expand Down Expand Up @@ -157,32 +160,74 @@ const CandidateTrial: FC<{
)
}

const ChoiceTrials: FC<{ choice: PreferenceHistory; trials: Trial[] }> = ({
choice,
trials,
}) => {
const ChoiceTrials: FC<{
choice: PreferenceHistory
trials: Trial[]
studyId: number
}> = ({ choice, trials, studyId }) => {
const [isRemoved, setRemoved] = useState(choice.is_removed)
const theme = useTheme()
const worst_trials = new Set([choice.clicked])
const action = actionCreator()

return (
<Box
sx={{
marginBottom: theme.spacing(4),
position: "relative",
}}
>
<Typography
variant="h6"
<Box
sx={{
fontWeight: theme.typography.fontWeightLight,
display: "flex",
flexDirection: "row",
flexWrap: "wrap",
}}
>
{formatDate(choice.timestamp)}
</Typography>
<Typography
variant="h6"
sx={{
fontWeight: theme.typography.fontWeightLight,
margin: "auto 0",
}}
>
{formatDate(choice.timestamp)}
</Typography>
{choice.is_removed ? (
<IconButton
disabled={!isRemoved}
onClick={() => {
setRemoved(false)
action.restorePreferentialHistory(studyId, choice.id)
}}
sx={{
margin: `auto ${theme.spacing(2)}`,
}}
>
<RestoreFromTrashIcon />
</IconButton>
) : (
<IconButton
disabled={isRemoved}
onClick={() => {
setRemoved(true)
action.removePreferentialHistory(studyId, choice.id)
}}
sx={{
margin: `auto ${theme.spacing(2)}`,
}}
>
<DeleteIcon />
</IconButton>
)}
</Box>
<Box
sx={{
display: "flex",
flexDirection: "row",
flexWrap: "wrap",
filter: choice.is_removed ? "brightness(0.4)" : undefined,
backgroundColor: theme.palette.background.paper,
}}
>
{choice.candidates.map((trial_num, index) => (
Expand Down Expand Up @@ -234,6 +279,7 @@ export const PreferenceHistory: FC<{ studyDetail: StudyDetail | null }> = ({
key={choice.id}
choice={choice}
trials={studyDetail.trials}
studyId={studyDetail.id}
/>
))}
</Box>
Expand Down
7 changes: 4 additions & 3 deletions optuna_dashboard/ts/components/PreferentialGraph.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ export const PreferentialGraph: FC<{
if (!studyDetail.is_preferential || studyDetail.preferences === undefined)
return
const preferences = reductionPreference(studyDetail.preferences)
const trialNodes = Array.from(new Set(preferences.flat()))
const graph: ElkNode = {
id: "root",
layoutOptions: {
Expand All @@ -211,8 +212,8 @@ export const PreferentialGraph: FC<{
"elk.layered.spacing.nodeNodeBetweenLayers": nodeMargin.toString(),
"elk.spacing.nodeNode": nodeMargin.toString(),
},
children: studyDetail.trials.map((trial) => ({
id: `${trial.number}`,
children: trialNodes.map((trial) => ({
id: `${trial}`,
targetPosition: "top",
sourcePosition: "bottom",
width: nodeWidth,
Expand All @@ -229,7 +230,7 @@ export const PreferentialGraph: FC<{
.then((layoutedGraph) => {
setNodes(
layoutedGraph.children?.map((node, index) => {
const trial = studyDetail.trials[index]
const trial = studyDetail.trials[trialNodes[index]]
return {
id: `${trial.number}`,
type: "note",
Expand Down
Loading

0 comments on commit ae11fda

Please sign in to comment.