Skip to content

Commit

Permalink
Beam: auto-merge
Browse files Browse the repository at this point in the history
  • Loading branch information
enricoros committed May 13, 2024
1 parent 3a7aa75 commit 6f35f72
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 15 deletions.
48 changes: 36 additions & 12 deletions src/modules/beam/BeamView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { BeamRayGrid } from './scatter/BeamRayGrid';
import { BeamScatterInput } from './scatter/BeamScatterInput';
import { BeamScatterPane } from './scatter/BeamScatterPane';
import { BeamStoreApi, useBeamStore } from './store-beam.hooks';
import { useModuleBeamStore } from './store-module-beam';


export function BeamView(props: {
Expand All @@ -24,6 +25,7 @@ export function BeamView(props: {
}) {

// state
const [hasAutoMerged, setHasAutoMerged] = React.useState(false);
const [warnIsScattering, setWarnIsScattering] = React.useState(false);

// external state
Expand All @@ -36,7 +38,9 @@ export function BeamView(props: {
/* root */ inputHistory, inputIssues, inputReady,
/* scatter */ isScattering, raysReady,
/* gather (composite) */ canGather,
/* IDs */ rayIds, fusionIds,
} = useBeamStore(props.beamStore, useShallow(state => ({
// input
inputHistory: state.inputHistory,
inputIssues: state.inputIssues,
inputReady: state.inputReady,
Expand All @@ -45,9 +49,13 @@ export function BeamView(props: {
raysReady: state.raysReady,
// gather (composite)
canGather: state.raysReady >= 2 && state.currentFactoryId !== null && state.currentGatherLlmId !== null,
// IDs
rayIds: state.rays.map(ray => ray.rayId),
fusionIds: state.fusions.map(fusion => fusion.fusionId),
})));
const { gatherAutoStartAfterScatter } = useModuleBeamStore(useShallow(state => ({
gatherAutoStartAfterScatter: state.gatherAutoStartAfterScatter,
})));
const rayIds = useBeamStore(props.beamStore, useShallow(state => state.rays.map(ray => ray.rayId)));
const fusionIds = useBeamStore(props.beamStore, useShallow(state => state.fusions.map(fusion => fusion.fusionId)));

// derived state
const raysCount = rayIds.length;
Expand All @@ -59,6 +67,11 @@ export function BeamView(props: {

const handleRayIncreaseCount = React.useCallback(() => setRayCount(raysCount + 1), [setRayCount, raysCount]);

const handleScatterStart = React.useCallback(() => {
setHasAutoMerged(false);
startScatteringAll();
}, [startScatteringAll]);


const handleCreateFusion = React.useCallback(() => {
// if scatter is busy, ask for confirmation
Expand All @@ -70,20 +83,31 @@ export function BeamView(props: {
}, [isScattering, props.beamStore]);


const handleStopScatterConfirmation = React.useCallback(() => {
const handleStartMergeConfirmation = React.useCallback(() => {
setWarnIsScattering(false);
stopScatteringAll();
handleCreateFusion();
}, [handleCreateFusion, stopScatteringAll]);

const handleStopScatterDenial = React.useCallback(() => setWarnIsScattering(false), []);
const handleStartMergeDenial = React.useCallback(() => setWarnIsScattering(false), []);


// auto-merge
const shallAutoMerge = gatherAutoStartAfterScatter && canGather && !isScattering && !hasAutoMerged;
React.useEffect(() => {
if (shallAutoMerge) {
setHasAutoMerged(true);
handleStartMergeConfirmation();
}
}, [handleStartMergeConfirmation, shallAutoMerge]);

// (this is great ux) scatter freed up while we were asking the question, proceed
// (great ux) scatter finished while the "start merge" (warning) dialog is up: dismiss dialog and proceed
// here we assume that 'warnIsScattering' shows the intention of the user to proceed with a merge asap
const shallResumeMerge = warnIsScattering && !isScattering && !gatherAutoStartAfterScatter;
React.useEffect(() => {
if (warnIsScattering && !isScattering)
handleStopScatterConfirmation();
}, [handleStopScatterConfirmation, isScattering, warnIsScattering]);
if (shallResumeMerge)
handleStartMergeConfirmation();
}, [handleStartMergeConfirmation, shallResumeMerge]);


// runnning
Expand Down Expand Up @@ -138,7 +162,7 @@ export function BeamView(props: {
setRayCount={handleRaySetCount}
startEnabled={inputReady}
startBusy={isScattering}
onStart={startScatteringAll}
onStart={handleScatterStart}
onStop={stopScatteringAll}
onExplainerShow={explainerShow}
/>
Expand All @@ -163,7 +187,7 @@ export function BeamView(props: {
beamStore={props.beamStore}
canGather={canGather}
isMobile={props.isMobile}
onAddFusion={handleCreateFusion}
// onAddFusion={handleCreateFusion}
raysReady={raysReady}
/>

Expand All @@ -184,8 +208,8 @@ export function BeamView(props: {
{warnIsScattering && (
<ConfirmationModal
open
onClose={handleStopScatterDenial}
onPositive={handleStopScatterConfirmation}
onClose={handleStartMergeDenial}
onPositive={handleStartMergeConfirmation}
// lowStakes
noTitleBar
confirmationText='Some responses are still being generated. Do you want to stop and proceed with merging the available responses now?'
Expand Down
2 changes: 1 addition & 1 deletion src/modules/beam/gather/BeamFusionGrid.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ export function BeamFusionGrid(props: {
</Box> : (
<Typography level='body-sm' sx={{ opacity: 0.8 }}>
{/*You need two or more replies for a {currentFactory?.shortLabel?.toLocaleLowerCase() ?? ''} merge.*/}
Waiting for multiple Beams.
Waiting for multiple responses.
</Typography>
)}
</BeamCard>
Expand Down
2 changes: 1 addition & 1 deletion src/modules/beam/gather/BeamGatherPane.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export function BeamGatherPane(props: {
beamStore: BeamStoreApi,
canGather: boolean,
isMobile: boolean,
onAddFusion: () => void,
// onAddFusion: () => void,
raysReady: number,
}) {

Expand Down
4 changes: 4 additions & 0 deletions src/modules/beam/gather/beam.gather.ts
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ interface GatherStateSlice {

// derived state (just acts as a cache to avoid re-calculating)
isGatheringAny: boolean;
// fusionsReady: number;

}

Expand All @@ -118,6 +119,7 @@ export const reInitGatherStateSlice = (prevFusions: BFusion[], gatherLlmId: DLLM
fusions: [],

isGatheringAny: false,
// fusionsReady: 0,
};
};

Expand Down Expand Up @@ -170,10 +172,12 @@ export const createGatherSlice: StateCreator<RootStoreSlice & ScatterStoreSlice

// 'or' the status of all fusions
const isGatheringAny = newFusions.some(fusionIsFusing);
// const fusionsReady = newFusions.filter(fusionIsUsableOutput).length;

_set({
fusions: newFusions,
isGatheringAny,
// fusionsReady,
});
},

Expand Down
10 changes: 9 additions & 1 deletion src/modules/beam/scatter/BeamScatterPaneDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import DriveFileRenameOutlineRoundedIcon from '@mui/icons-material/DriveFileRena
import MoreVertIcon from '@mui/icons-material/MoreVert';
import SchoolRoundedIcon from '@mui/icons-material/SchoolRounded';

import { DEV_MODE_SETTINGS } from '../../../apps/settings-modal/UxLabsSettings';

import type { DLLMId } from '~/modules/llms/store-llms';

import type { BeamStoreApi } from '../store-beam.hooks';
Expand Down Expand Up @@ -69,6 +71,7 @@ export function BeamScatterDropdown(props: {
cardScrolling, toggleCardScrolling,
scatterShowPrevMessages, toggleScatterShowPrevMessages,
scatterShowLettering, toggleScatterShowLettering,
gatherAutoStartAfterScatter, toggleGatherAutoStartAfterScatter,
gatherShowAllPrompts, toggleGatherShowAllPrompts,
} = useModuleBeamStore();

Expand Down Expand Up @@ -163,10 +166,15 @@ export function BeamScatterDropdown(props: {
Response Numbers
</MenuItem>

<ListItem onClick={() => handleClearLastConfig()}>
<ListItem onClick={DEV_MODE_SETTINGS ? () => handleClearLastConfig() : undefined}>
<Typography level='body-sm'>Advanced</Typography>
</ListItem>

<MenuItem onClick={toggleGatherAutoStartAfterScatter}>
<ListItemDecorator>{gatherAutoStartAfterScatter && <CheckRoundedIcon />}</ListItemDecorator>
Auto-Merge
</MenuItem>

<MenuItem onClick={toggleGatherShowAllPrompts}>
<ListItemDecorator>{gatherShowAllPrompts && <CheckRoundedIcon />}</ListItemDecorator>
Detailed Custom Merge
Expand Down
5 changes: 5 additions & 0 deletions src/modules/beam/store-module-beam.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ interface ModuleBeamState {
cardScrolling: boolean;
scatterShowLettering: boolean;
scatterShowPrevMessages: boolean;
gatherAutoStartAfterScatter: boolean;
gatherShowAllPrompts: boolean;
}

Expand All @@ -37,6 +38,7 @@ interface ModuleBeamStore extends ModuleBeamState {
toggleCardScrolling: () => void;
toggleScatterShowLettering: () => void;
toggleScatterShowPrevMessages: () => void;
toggleGatherAutoStartAfterScatter: () => void;
toggleGatherShowAllPrompts: () => void;
}

Expand All @@ -50,6 +52,7 @@ export const useModuleBeamStore = create<ModuleBeamStore>()(persist(
scatterShowLettering: false,
scatterShowPrevMessages: false,
gatherShowAllPrompts: false,
gatherAutoStartAfterScatter: false,


addPreset: (name, rayLlmIds, gatherLlmId, gatherFactoryId) => _set(state => ({
Expand Down Expand Up @@ -86,6 +89,8 @@ export const useModuleBeamStore = create<ModuleBeamStore>()(persist(

toggleScatterShowPrevMessages: () => _set(state => ({ scatterShowPrevMessages: !state.scatterShowPrevMessages })),

toggleGatherAutoStartAfterScatter: () => _set(state => ({ gatherAutoStartAfterScatter: !state.gatherAutoStartAfterScatter })),

toggleGatherShowAllPrompts: () => _set(state => ({ gatherShowAllPrompts: !state.gatherShowAllPrompts })),

}), {
Expand Down

0 comments on commit 6f35f72

Please sign in to comment.