Toggle JSON mode, Fixes #515
enricoros committed May 7, 2024
1 parent 1f10905 commit 2e45325
Showing 5 changed files with 27 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/modules/llms/server/llm.server.streaming.ts
@@ -513,7 +513,7 @@ function _prepareRequestData(access: ChatStreamingInputSchema['access'], model:
     case 'ollama':
       return {
         ...ollamaAccess(access, OLLAMA_PATH_CHAT),
-        body: ollamaChatCompletionPayload(model, history, true),
+        body: ollamaChatCompletionPayload(model, history, access.ollamaJson, true),
         vendorMuxingFormat: 'json-nl',
         vendorStreamParser: createStreamParserOllama(),
       };
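
With the toggle on, the streaming path now asks Ollama for constrained JSON output. A minimal sketch of the body this call builds — the model and message values are hypothetical, and the stream field is assumed to be set from the last argument:

    // Illustrative shape of the request body when access.ollamaJson === true
    // (values are examples, not from the commit):
    const body = {
      model: 'llama3',
      messages: [{ role: 'user', content: 'Reply with a JSON object.' }],
      options: { temperature: 0.5 },
      format: 'json', // present only because jsonOutput was true
      stream: true,   // assumed: the streaming path passes stream = true
    };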
6 changes: 4 additions & 2 deletions src/modules/llms/server/ollama/ollama.router.ts
@@ -40,12 +40,13 @@ export function ollamaAccess(access: OllamaAccessSchema, apiPath: string): { hea
 }


-export const ollamaChatCompletionPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, stream: boolean): WireOllamaChatCompletionInput => ({
+export const ollamaChatCompletionPayload = (model: OpenAIModelSchema, history: OpenAIHistorySchema, jsonOutput: boolean, stream: boolean): WireOllamaChatCompletionInput => ({
   model: model.id,
   messages: history,
   options: {
     ...(model.temperature !== undefined && { temperature: model.temperature }),
   },
+  ...(jsonOutput && { format: 'json' }),
   // n: ...
   // functions: ...
   // function_call: ...
@@ -101,6 +102,7 @@ async function ollamaPOST<TOut extends object, TPostBody extends object>(access:
 export const ollamaAccessSchema = z.object({
   dialect: z.enum(['ollama']),
   ollamaHost: z.string().trim(),
+  ollamaJson: z.boolean(),
 });
 export type OllamaAccessSchema = z.infer<typeof ollamaAccessSchema>;

@@ -250,7 +252,7 @@ export const llmOllamaRouter = createTRPCRouter({
     .output(llmsChatGenerateOutputSchema)
     .mutation(async ({ input: { access, history, model } }) => {

-      const wireGeneration = await ollamaPOST(access, ollamaChatCompletionPayload(model, history, false), OLLAMA_PATH_CHAT);
+      const wireGeneration = await ollamaPOST(access, ollamaChatCompletionPayload(model, history, access.ollamaJson, false), OLLAMA_PATH_CHAT);
       const generation = wireOllamaChunkedOutputSchema.parse(wireGeneration);

       if ('error' in generation)
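
A usage sketch of the changed helper — the model and history literals are hypothetical, shaped after the OpenAI-style schemas in the signature:

    // jsonOutput = true adds `format: 'json'`; jsonOutput = false spreads
    // nothing, so the key is absent and Ollama behaves exactly as before.
    const payload = ollamaChatCompletionPayload(
      { id: 'llama3', temperature: 0.7 },               // hypothetical OpenAIModelSchema value
      [{ role: 'user', content: 'Return JSON only.' }], // hypothetical OpenAIHistorySchema value
      /* jsonOutput */ true,
      /* stream */ false,
    );
    // payload.format === 'json'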
8 changes: 6 additions & 2 deletions src/modules/llms/server/ollama/ollama.wiretypes.ts
@@ -46,12 +46,13 @@ const wireOllamaChatCompletionInputSchema = z.object({
   messages: z.array(z.object({
     role: z.enum(['assistant', 'system', 'user']),
     content: z.string(),
     images: z.array(z.string()).optional(), // base64 encoded images
   })),

   // optional
+  format: z.enum(['json']).optional(),
   options: z.object({
-    // https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md
+    // https://github.com/ollama/ollama/blob/main/docs/modelfile.md
     // Maximum number of tokens to predict when generating text.
     num_predict: z.number().int().optional(),
     // Sets the random number seed to use for generation
@@ -63,8 +64,11 @@
     // Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text. (Default 0.9)
     top_p: z.number().positive().optional(),
   }).optional(),
-  template: z.string().optional(), // overrides what is defined in the Modelfile
   stream: z.boolean().optional(), // default: true
+  keep_alive: z.string().optional(), // e.g. '5m'
+
+  // Note: not used anymore as of 2024-05-07?
+  // template: z.string().optional(), // overrides what is defined in the Modelfile
+

   // Future Improvements?
   // n: z.number().int().optional(), // number of completions to generate
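
For reference, the wire schema above mirrors Ollama's documented chat endpoint. A standalone sketch of the equivalent raw request, assuming a default local host (per the API docs linked from the new UI switch; host and model name are assumptions):

    // Plain fetch against Ollama's /api/chat; `format: 'json'` constrains
    // the model's reply to valid JSON.
    const res = await fetch('http://127.0.0.1:11434/api/chat', {
      method: 'POST',
      headers: { 'Content-Type': 'application/json' },
      body: JSON.stringify({
        model: 'llama3',
        messages: [{ role: 'user', content: 'List three colors as JSON.' }],
        format: 'json',
        keep_alive: '5m', // optional: how long the model stays loaded
        stream: false,
      }),
    });
    const data = await res.json(); // data.message.content holds the JSON text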
15 changes: 13 additions & 2 deletions src/modules/llms/vendors/ollama/OllamaSourceSetup.tsx
@@ -2,6 +2,7 @@ import * as React from 'react';

 import { Button } from '@mui/joy';

+import { FormSwitchControl } from '~/common/components/forms/FormSwitchControl';
 import { FormTextField } from '~/common/components/forms/FormTextField';
 import { InlineError } from '~/common/components/InlineError';
 import { Link } from '~/common/components/Link';
@@ -26,7 +27,7 @@ export function OllamaSourceSetup(props: { sourceId: DModelSourceId }) {
     useSourceSetup(props.sourceId, ModelVendorOllama);

   // derived state
-  const { ollamaHost } = access;
+  const { ollamaHost, ollamaJson } = access;

   const hostValid = !!asValidURL(ollamaHost);
   const hostError = !!ollamaHost && !hostValid;
@@ -41,13 +42,23 @@
     <FormTextField
       autoCompleteId='ollama-host'
       title='Ollama Host'
-      description={<Link level='body-sm' href='https://github.com/enricoros/big-agi/blob/main/docs/config-local-ollama.md' target='_blank'>information</Link>}
+      description={<Link level='body-sm' href='https://github.com/enricoros/big-agi/blob/main/docs/config-local-ollama.md' target='_blank'>Information</Link>}
       placeholder='http://127.0.0.1:11434'
       isError={hostError}
       value={ollamaHost || ''}
       onChange={text => updateSetup({ ollamaHost: text })}
     />

+    <FormSwitchControl
+      title='JSON Output' on='Enabled' fullWidth
+      description={<Link level='body-sm' href='https://github.com/ollama/ollama/blob/main/docs/api.md#generate-a-chat-completion' target='_blank'>Information</Link>}
+      checked={ollamaJson}
+      onChange={on => {
+        updateSetup({ ollamaJson: on });
+        refetch();
+      }}
+    />
+
     <SetupFormRefetchButton
       refetch={refetch} disabled={!shallFetchSucceed || isFetching} loading={isFetching} error={isError}
       leftButton={
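
Design note: the switch handler both persists the new value (updateSetup) and re-runs the source query (refetch), presumably so the refreshed model list and any subsequent requests carry an access object that already includes the new ollamaJson flag.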
4 changes: 3 additions & 1 deletion src/modules/llms/vendors/ollama/ollama.vendor.ts
@@ -1,5 +1,5 @@
 import { OllamaIcon } from '~/common/components/icons/vendors/OllamaIcon';
-import { apiAsync, apiQuery } from '~/common/util/trpc.client';
+import { apiAsync } from '~/common/util/trpc.client';

 import type { IModelVendor } from '../IModelVendor';
 import type { OllamaAccessSchema } from '../../server/ollama/ollama.router';
@@ -14,6 +14,7 @@ import { OllamaSourceSetup } from './OllamaSourceSetup';

 export interface SourceSetupOllama {
   ollamaHost: string;
+  ollamaJson: boolean;
 }
@@ -34,6 +35,7 @@ export const ModelVendorOllama: IModelVendor<SourceSetupOllama, OllamaAccessSche
   getTransportAccess: (partialSetup): OllamaAccessSchema => ({
     dialect: 'ollama',
     ollamaHost: partialSetup?.ollamaHost || '',
+    ollamaJson: partialSetup?.ollamaJson || false,
   }),

   // List Models
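
One detail worth noting: getTransportAccess defaults the new flag with `|| false`, so previously saved setups that lack ollamaJson keep the old non-JSON behavior without any migration. A quick sketch (the partial-setup literal is hypothetical):

    const access = ModelVendorOllama.getTransportAccess({ ollamaHost: 'http://127.0.0.1:11434' });
    // access => { dialect: 'ollama', ollamaHost: 'http://127.0.0.1:11434', ollamaJson: false }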
