Merge branch 'refs/heads/main-stable'
enricoros committed Jun 7, 2024
2 parents 5066336 + 11c41e7 commit 3b15ad5
Showing 22 changed files with 155 additions and 91 deletions.
2 changes: 1 addition & 1 deletion src/apps/chat/AppChat.tsx
@@ -277,7 +277,7 @@ export function AppChat() {
const conversation = getConversation(conversationId);
if (!conversation)
return;
- const imaginedPrompt = await imaginePromptFromText(messageText) || 'An error sign.';
+ const imaginedPrompt = await imaginePromptFromText(messageText, conversationId) || 'An error sign.';
await handleExecuteAndOutcome('generate-image', conversationId, [
...conversation.messages,
createDMessage('user', imaginedPrompt),
4 changes: 2 additions & 2 deletions src/apps/chat/editors/chat-stream.ts
@@ -2,7 +2,7 @@ import type { DLLMId } from '~/modules/llms/store-llms';
import type { StreamingClientUpdate } from '~/modules/llms/vendors/unifiedStreamingClient';
import { autoSuggestions } from '~/modules/aifn/autosuggestions/autoSuggestions';
import { conversationAutoTitle } from '~/modules/aifn/autotitle/autoTitle';
- import { llmStreamingChatGenerate, VChatContextRef, VChatContextName, VChatMessageIn } from '~/modules/llms/llm.client';
+ import { llmStreamingChatGenerate, VChatContextRef, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';
import { speakText } from '~/modules/elevenlabs/elevenlabs.client';

import type { DMessage } from '~/common/state/store-chats';
@@ -63,7 +63,7 @@ type StreamMessageStatus = { outcome: StreamMessageOutcome, errorMessage?: strin
export async function streamAssistantMessage(
llmId: DLLMId,
messagesHistory: VChatMessageIn[],
- contextName: VChatContextName,
+ contextName: VChatStreamContextName,
contextRef: VChatContextRef,
throttleUnits: number, // 0: disable, 1: default throttle (12Hz), 2+ reduce the message frequency with the square root
autoSpeak: ChatAutoSpeakType,
19 changes: 12 additions & 7 deletions src/modules/aifn/autosuggestions/autoSuggestions.ts
@@ -1,4 +1,4 @@
- import { llmChatGenerateOrThrow, VChatFunctionIn } from '~/modules/llms/llm.client';
+ import { llmChatGenerateOrThrow, VChatFunctionIn, VChatMessageIn } from '~/modules/llms/llm.client';
import { useModelsStore } from '~/modules/llms/store-llms';

import { useChatStore } from '~/common/state/store-chats';
@@ -83,13 +83,18 @@ export function autoSuggestions(conversationId: string, assistantMessageId: stri

// Follow-up: Auto-Diagrams
if (suggestDiagrams) {
- llmChatGenerateOrThrow(funcLLMId, [
- { role: 'system', content: systemMessage.text },
- { role: 'user', content: userMessage.text },
- { role: 'assistant', content: assistantMessageText },
- ], [suggestPlantUMLFn], 'draw_plantuml_diagram',
+ const instructions: VChatMessageIn[] = [
+ { role: 'system', content: systemMessage.text },
+ { role: 'user', content: userMessage.text },
+ { role: 'assistant', content: assistantMessageText },
+ ];
+ llmChatGenerateOrThrow(
+ funcLLMId,
+ instructions,
+ 'chat-followup-diagram', conversationId,
+ [suggestPlantUMLFn], 'draw_plantuml_diagram',
).then(chatResponse => {

// cheap way to check if the function was supported
if (!('function_arguments' in chatResponse))
return;

30 changes: 16 additions & 14 deletions src/modules/aifn/autotitle/autoTitle.ts
@@ -1,5 +1,5 @@
import { getFastLLMId } from '~/modules/llms/store-llms';
- import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
+ import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';

import { useChatStore } from '~/common/state/store-chats';

@@ -34,21 +34,23 @@ export async function conversationAutoTitle(conversationId: string, forceReplace

try {
// LLM chat-generate call
+ const instructions: VChatMessageIn[] = [
+ { role: 'system', content: `You are an AI conversation titles assistant who specializes in creating expressive yet few-words chat titles.` },
+ {
+ role: 'user', content:
+ 'Analyze the given short conversation (every line is truncated) and extract a concise chat title that ' +
+ 'summarizes the conversation in as little as a couple of words.\n' +
+ 'Only respond with the lowercase short title and nothing else.\n' +
+ '\n' +
+ '```\n' +
+ historyLines.join('\n') +
+ '```\n',
+ },
+ ];
const chatResponse = await llmChatGenerateOrThrow(
fastLLMId,
- [
- { role: 'system', content: `You are an AI conversation titles assistant who specializes in creating expressive yet few-words chat titles.` },
- {
- role: 'user', content:
- 'Analyze the given short conversation (every line is truncated) and extract a concise chat title that ' +
- 'summarizes the conversation in as little as a couple of words.\n' +
- 'Only respond with the lowercase short title and nothing else.\n' +
- '\n' +
- '```\n' +
- historyLines.join('\n') +
- '```\n',
- },
- ],
+ instructions,
+ 'chat-ai-title', conversationId,
null, null,
);

9 changes: 5 additions & 4 deletions src/modules/aifn/imagine/imaginePromptFromText.ts
@@ -1,5 +1,5 @@
import { getFastLLMId } from '~/modules/llms/store-llms';
- import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
+ import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';


const simpleImagineSystemPrompt =
@@ -10,14 +10,15 @@ Provide output as a lowercase prompt and nothing else.`;
/**
* Creates a caption for a drawing or photo given some description - used to elevate the quality of the imaging
*/
- export async function imaginePromptFromText(messageText: string): Promise<string | null> {
+ export async function imaginePromptFromText(messageText: string, contextRef: string): Promise<string | null> {
const fastLLMId = getFastLLMId();
if (!fastLLMId) return null;
try {
- const chatResponse = await llmChatGenerateOrThrow(fastLLMId, [
+ const instructions: VChatMessageIn[] = [
{ role: 'system', content: simpleImagineSystemPrompt },
{ role: 'user', content: 'Write a prompt, based on the following input.\n\n```\n' + messageText.slice(0, 1000) + '\n```\n' },
- ], null, null);
+ ];
+ const chatResponse = await llmChatGenerateOrThrow(fastLLMId, instructions, 'draw-expand-prompt', contextRef, null, null);
return chatResponse.content?.trim() ?? null;
} catch (error: any) {
console.error('imaginePromptFromText: fetch request error:', error);
2 changes: 1 addition & 1 deletion src/modules/aifn/react/react.ts
@@ -132,7 +132,7 @@ export class Agent {
S.messages.push({ role: 'user', content: prompt });
let content: string;
try {
- content = (await llmChatGenerateOrThrow(llmId, S.messages, null, null, 500)).content;
+ content = (await llmChatGenerateOrThrow(llmId, S.messages, 'chat-react-turn', null, null, null, 500)).content;
} catch (error: any) {
content = `Error in llmChatGenerateOrThrow: ${error}`;
}
7 changes: 4 additions & 3 deletions src/modules/aifn/summarize/summerize.ts
@@ -1,5 +1,5 @@
import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms';
- import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';
+ import { llmChatGenerateOrThrow, VChatMessageIn } from '~/modules/llms/llm.client';


// prompt to be tried when doing recursive summerization.
@@ -80,10 +80,11 @@ async function cleanUpContent(chunk: string, llmId: DLLMId, _ignored_was_targetW
const autoResponseTokensSize = contextTokens ? Math.floor(contextTokens * outputTokenShare) : null;

try {
- const chatResponse = await llmChatGenerateOrThrow(llmId, [
+ const instructions: VChatMessageIn[] = [
{ role: 'system', content: cleanupPrompt },
{ role: 'user', content: chunk },
- ], null, null, autoResponseTokensSize ?? undefined);
+ ];
+ const chatResponse = await llmChatGenerateOrThrow(llmId, instructions, 'chat-ai-summarize', null, null, null, autoResponseTokensSize ?? undefined);
return chatResponse?.content ?? '';
} catch (error: any) {
return '';
4 changes: 2 additions & 2 deletions src/modules/aifn/useLLMChain.ts
@@ -1,7 +1,7 @@
import * as React from 'react';

import { DLLMId, findLLMOrThrow } from '~/modules/llms/store-llms';
- import { llmStreamingChatGenerate, VChatContextName, VChatContextRef, VChatMessageIn } from '~/modules/llms/llm.client';
+ import { llmStreamingChatGenerate, VChatContextRef, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';


// set to true to log to the console
@@ -20,7 +20,7 @@ export interface LLMChainStep {
/**
* React hook to manage a chain of LLM transformations.
*/
- export function useLLMChain(steps: LLMChainStep[], llmId: DLLMId | undefined, chainInput: string | undefined, onSuccess: (output: string, input: string) => void, contextName: VChatContextName, contextRef: VChatContextRef) {
+ export function useLLMChain(steps: LLMChainStep[], llmId: DLLMId | undefined, chainInput: string | undefined, onSuccess: (output: string, input: string) => void, contextName: VChatStreamContextName, contextRef: VChatContextRef) {

// state
const [chain, setChain] = React.useState<ChainState | null>(null);
4 changes: 2 additions & 2 deletions src/modules/aifn/useStreamChatText.ts
@@ -1,7 +1,7 @@
import * as React from 'react';

import type { DLLMId } from '~/modules/llms/store-llms';
- import { llmStreamingChatGenerate, VChatContextName, VChatContextRef, VChatMessageIn } from '~/modules/llms/llm.client';
+ import { llmStreamingChatGenerate, VChatContextRef, VChatMessageIn, VChatStreamContextName } from '~/modules/llms/llm.client';


export function useStreamChatText() {
@@ -13,7 +13,7 @@ export function useStreamChatText() {
const abortControllerRef = React.useRef<AbortController | null>(null);


- const startStreaming = React.useCallback(async (llmId: DLLMId, prompt: VChatMessageIn[], contextName: VChatContextName, contextRef: VChatContextRef) => {
+ const startStreaming = React.useCallback(async (llmId: DLLMId, prompt: VChatMessageIn[], contextName: VChatStreamContextName, contextRef: VChatContextRef) => {
setStreamError(null);
setPartialText(null);
setText(null);
21 changes: 9 additions & 12 deletions src/modules/llms/llm.client.ts
@@ -2,7 +2,7 @@ import { sendGAEvent } from '@next/third-parties/google';

import { hasGoogleAnalytics } from '~/common/components/GoogleAnalytics';

- import type { ModelDescriptionSchema } from './server/llm.server.types';
+ import type { GenerateContextNameSchema, ModelDescriptionSchema, StreamingContextNameSchema } from './server/llm.server.types';
import type { OpenAIWire } from './server/openai/openai.wiretypes';
import type { StreamingClientUpdate } from './vendors/unifiedStreamingClient';
import { DLLM, DLLMId, DModelSource, DModelSourceId, LLM_IF_OAI_Chat, LLM_IF_OAI_Fn, useModelsStore } from './store-llms';
@@ -21,14 +21,8 @@ export interface VChatMessageIn {

export type VChatFunctionIn = OpenAIWire.ChatCompletion.RequestFunctionDef;

- export type VChatContextName =
- | 'conversation'
- | 'ai-diagram'
- | 'ai-flattener'
- | 'beam-scatter'
- | 'beam-gather'
- | 'call'
- | 'persona-extract';
+ export type VChatStreamContextName = StreamingContextNameSchema;
+ export type VChatGenerateContextName = GenerateContextNameSchema;
export type VChatContextRef = string;

export interface VChatMessageOut {
@@ -122,7 +116,10 @@ function modelDescriptionToDLLMOpenAIOptions<TSourceSetup, TLLMOptions>(model: M
export async function llmChatGenerateOrThrow<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
llmId: DLLMId,
messages: VChatMessageIn[],
- functions: VChatFunctionIn[] | null, forceFunctionName: string | null,
+ contextName: VChatGenerateContextName,
+ contextRef: VChatContextRef | null,
+ functions: VChatFunctionIn[] | null,
+ forceFunctionName: string | null,
maxTokens?: number,
): Promise<VChatMessageOut | VChatMessageOrFunctionCallOut> {

@@ -146,14 +143,14 @@ export async function llmChatGenerateOrThrow<TSourceSetup = unknown, TAccess = u
await new Promise(resolve => setTimeout(resolve, delay));

// execute via the vendor
- return await vendor.rpcChatGenerateOrThrow(access, options, messages, functions, forceFunctionName, maxTokens);
+ return await vendor.rpcChatGenerateOrThrow(access, options, messages, contextName, contextRef, functions, forceFunctionName, maxTokens);
}


export async function llmStreamingChatGenerate<TSourceSetup = unknown, TAccess = unknown, TLLMOptions = unknown>(
llmId: DLLMId,
messages: VChatMessageIn[],
- contextName: VChatContextName,
+ contextName: VChatStreamContextName,
contextRef: VChatContextRef,
functions: VChatFunctionIn[] | null,
forceFunctionName: string | null,
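
(Aside, not part of the diff: a minimal sketch of a call against the reworked llmChatGenerateOrThrow signature shown above, which now takes a generation context name and ref ahead of the function-calling arguments. The identifiers fastLLMId, instructions and conversationId are placeholders for a caller's own values, mirroring the call sites updated elsewhere in this commit.)

import { llmChatGenerateOrThrow } from '~/modules/llms/llm.client';

// sketch: non-streaming generation with the new context arguments
const chatResponse = await llmChatGenerateOrThrow(
  fastLLMId,              // DLLMId of the model to call
  instructions,           // VChatMessageIn[] (system/user/assistant messages)
  'chat-ai-summarize',    // new: VChatGenerateContextName
  conversationId,         // new: VChatContextRef | null
  null,                   // functions: VChatFunctionIn[] | null
  null,                   // forceFunctionName: string | null
);
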
8 changes: 6 additions & 2 deletions src/modules/llms/server/anthropic/anthropic.router.ts
@@ -8,7 +8,7 @@ import { fetchJsonOrTRPCError } from '~/server/api/trpc.router.fetchers';
import { fixupHost } from '~/common/util/urlUtils';

import { OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';
- import { llmsChatGenerateOutputSchema, llmsListModelsOutputSchema } from '../llm.server.types';
+ import { llmsChatGenerateOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema } from '../llm.server.types';

import { AnthropicWireMessagesRequest, anthropicWireMessagesRequestSchema, AnthropicWireMessagesResponse, anthropicWireMessagesResponseSchema } from './anthropic.wiretypes';
import { hardcodedAnthropicModels } from './anthropic.models';
@@ -158,7 +158,11 @@ const listModelsInputSchema = z.object({

const chatGenerateInputSchema = z.object({
access: anthropicAccessSchema,
- model: openAIModelSchema, history: openAIHistorySchema,
+ model: openAIModelSchema,
+ history: openAIHistorySchema,
+ // functions: openAIFunctionsSchema.optional(),
+ // forceFunctionName: z.string().optional(),
+ context: llmsGenerateContextSchema.optional(),
});


9 changes: 6 additions & 3 deletions src/modules/llms/server/gemini/gemini.router.ts
@@ -8,7 +8,7 @@ import { createTRPCRouter, publicProcedure } from '~/server/api/trpc.server';
import { fetchJsonOrTRPCError } from '~/server/api/trpc.router.fetchers';

import { fixupHost } from '~/common/util/urlUtils';
- import { llmsChatGenerateOutputSchema, llmsListModelsOutputSchema } from '../llm.server.types';
+ import { llmsChatGenerateOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema } from '../llm.server.types';

import { OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';

@@ -120,8 +120,11 @@ const accessOnlySchema = z.object({

const chatGenerateInputSchema = z.object({
access: geminiAccessSchema,
- model: openAIModelSchema, history: openAIHistorySchema,
- // functions: openAIFunctionsSchema.optional(), forceFunctionName: z.string().optional(),
+ model: openAIModelSchema,
+ history: openAIHistorySchema,
+ // functions: openAIFunctionsSchema.optional(),
+ // forceFunctionName: z.string().optional(),
+ context: llmsGenerateContextSchema.optional(),
});


12 changes: 6 additions & 6 deletions src/modules/llms/server/llm.server.streaming.ts
@@ -22,6 +22,9 @@ import type { OpenAIWire } from './openai/openai.wiretypes';
import { openAIAccess, openAIAccessSchema, openAIChatCompletionPayload, openAIHistorySchema, openAIModelSchema } from './openai/openai.router';


+ import { llmsStreamingContextSchema } from './llm.server.types';


// configuration
const USER_SYMBOL_MAX_TOKENS = '🧱';
const USER_SYMBOL_PROMPT_BLOCKED = '🚫';
@@ -46,17 +49,14 @@ type MuxingFormat = 'sse' | 'json-nl';
*/
type AIStreamParser = (data: string, eventType?: string) => { text: string, close: boolean };

- const streamingContextSchema = z.object({
- method: z.literal('chat-stream'),
- name: z.enum(['conversation', 'ai-diagram', 'ai-flattener', 'call', 'beam-scatter', 'beam-gather', 'persona-extract']),
- ref: z.string(),
- });

const chatStreamingInputSchema = z.object({
access: z.union([anthropicAccessSchema, geminiAccessSchema, ollamaAccessSchema, openAIAccessSchema]),
model: openAIModelSchema,
history: openAIHistorySchema,
- context: streamingContextSchema,
+ // NOTE: made it optional for now as we have some old requests without it
+ // 2024-07-07: remove .optional()
+ context: llmsStreamingContextSchema.optional(),
});
export type ChatStreamingInputSchema = z.infer<typeof chatStreamingInputSchema>;

19 changes: 19 additions & 0 deletions src/modules/llms/server/llm.server.types.ts
@@ -46,6 +46,25 @@ export const llmsListModelsOutputSchema = z.object({
});


+ // Chat Generation Input (some parts of)
+
+ const generateContextNameSchema = z.enum(['chat-ai-title', 'chat-ai-summarize', 'chat-followup-diagram', 'chat-react-turn', 'draw-expand-prompt']);
+ export type GenerateContextNameSchema = z.infer<typeof generateContextNameSchema>;
+ export const llmsGenerateContextSchema = z.object({
+ method: z.literal('chat-generate'),
+ name: generateContextNameSchema,
+ ref: z.string(),
+ });
+
+ const streamingContextNameSchema = z.enum(['conversation', 'ai-diagram', 'ai-flattener', 'call', 'beam-scatter', 'beam-gather', 'persona-extract']);
+ export type StreamingContextNameSchema = z.infer<typeof streamingContextNameSchema>;
+ export const llmsStreamingContextSchema = z.object({
+ method: z.literal('chat-stream'),
+ name: streamingContextNameSchema,
+ ref: z.string(),
+ });
+
+
// (non-streaming) Chat Generation Output

export const llmsChatGenerateOutputSchema = z.object({
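
(Aside, not part of the diff: a minimal sketch of payloads that would satisfy the two new context schemas defined above; the import path and ref values are illustrative.)

import { llmsGenerateContextSchema, llmsStreamingContextSchema } from './llm.server.types';

// non-streaming generation context, e.g. as attached by the auto-title helper
const generateContext = llmsGenerateContextSchema.parse({
  method: 'chat-generate',
  name: 'chat-ai-title',      // one of the generateContextNameSchema values
  ref: 'conversation-1234',   // illustrative conversation id
});

// streaming context, e.g. a regular chat turn
const streamingContext = llmsStreamingContextSchema.parse({
  method: 'chat-stream',
  name: 'conversation',       // one of the streamingContextNameSchema values
  ref: 'conversation-1234',
});
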
9 changes: 6 additions & 3 deletions src/modules/llms/server/ollama/ollama.router.ts
@@ -11,7 +11,7 @@ import { capitalizeFirstLetter } from '~/common/util/textUtils';
import { fixupHost } from '~/common/util/urlUtils';

import { OpenAIHistorySchema, openAIHistorySchema, OpenAIModelSchema, openAIModelSchema } from '../openai/openai.router';
- import { llmsChatGenerateOutputSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';
+ import { llmsChatGenerateOutputSchema, llmsGenerateContextSchema, llmsListModelsOutputSchema, ModelDescriptionSchema } from '../llm.server.types';

import { OLLAMA_BASE_MODELS, OLLAMA_PREV_UPDATE } from './ollama.models';
import { WireOllamaChatCompletionInput, wireOllamaChunkedOutputSchema, wireOllamaListModelsSchema, wireOllamaModelInfoSchema } from './ollama.wiretypes';
@@ -117,8 +117,11 @@ const adminPullModelSchema = z.object({

const chatGenerateInputSchema = z.object({
access: ollamaAccessSchema,
- model: openAIModelSchema, history: openAIHistorySchema,
- // functions: openAIFunctionsSchema.optional(), forceFunctionName: z.string().optional(),
+ model: openAIModelSchema,
+ history: openAIHistorySchema,
+ // functions: openAIFunctionsSchema.optional(),
+ // forceFunctionName: z.string().optional(),
+ context: llmsGenerateContextSchema.optional(),
});

const listPullableOutputSchema = z.object({
(diffs for the remaining 7 changed files not shown)
