Skip to content

Commit

Permalink
refactor(infra): refactor copilot client (#8813)
Browse files Browse the repository at this point in the history
  • Loading branch information
EYHN committed Nov 27, 2024
1 parent 6b4a1aa commit 6e25243
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 85 deletions.
4 changes: 0 additions & 4 deletions packages/frontend/core/src/blocksuite/presets/ai/provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ export class AIProvider {
...options: Parameters<BlockSuitePresets.AIActions[T]>
) => ReturnType<BlockSuitePresets.AIActions[T]>
): void {
if (this.actions[id]) {
console.warn(`AI action ${id} is already provided`);
}

// @ts-expect-error TODO: maybe fix this
this.actions[id] = (
...args: Parameters<BlockSuitePresets.AIActions[T]>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ import {
getCopilotHistoriesQuery,
getCopilotHistoryIdsQuery,
getCopilotSessionsQuery,
gqlFetcherFactory,
GraphQLError,
type GraphQLQuery,
type QueryOptions,
type QueryResponse,
type RequestOptions,
UserFriendlyError,
} from '@affine/graphql';
Expand All @@ -21,26 +21,6 @@ import {
} from '@blocksuite/affine/blocks';
import { getCurrentStore } from '@toeverything/infra';

/**
* @deprecated will be removed soon
*/
export function getBaseUrl(): string {
if (BUILD_CONFIG.isElectron || BUILD_CONFIG.isIOS || BUILD_CONFIG.isAndroid) {
return BUILD_CONFIG.serverUrlPrefix;
}
if (typeof window === 'undefined') {
// is nodejs
return '';
}
const { protocol, hostname, port } = window.location;
return `${protocol}//${hostname}${port ? `:${port}` : ''}`;
}

/**
* @deprecated will be removed soon
*/
const defaultFetcher = gqlFetcherFactory(getBaseUrl() + '/graphql');

type OptionsField<T extends GraphQLQuery> =
RequestOptions<T>['variables'] extends { options: infer U } ? U : never;

Expand Down Expand Up @@ -76,23 +56,22 @@ export function handleError(src: any) {
return err;
}

const fetcher = async <Query extends GraphQLQuery>(
options: QueryOptions<Query>
) => {
try {
return await defaultFetcher<Query>(options);
} catch (err) {
throw handleError(err);
}
};

export class CopilotClient {
readonly backendUrl = getBaseUrl();
constructor(
readonly gql: <Query extends GraphQLQuery>(
options: QueryOptions<Query>
) => Promise<QueryResponse<Query>>,
readonly fetcher: (input: string, init?: RequestInit) => Promise<Response>,
readonly eventSource: (
url: string,
eventSourceInitDict?: EventSourceInit
) => EventSource
) {}

async createSession(
options: OptionsField<typeof createCopilotSessionMutation>
) {
const res = await fetcher({
const res = await this.gql({
query: createCopilotSessionMutation,
variables: {
options,
Expand All @@ -102,7 +81,7 @@ export class CopilotClient {
}

async forkSession(options: OptionsField<typeof forkCopilotSessionMutation>) {
const res = await fetcher({
const res = await this.gql({
query: forkCopilotSessionMutation,
variables: {
options,
Expand All @@ -114,7 +93,7 @@ export class CopilotClient {
async createMessage(
options: OptionsField<typeof createCopilotMessageMutation>
) {
const res = await fetcher({
const res = await this.gql({
query: createCopilotMessageMutation,
variables: {
options,
Expand All @@ -124,7 +103,7 @@ export class CopilotClient {
}

async getSessions(workspaceId: string) {
const res = await fetcher({
const res = await this.gql({
query: getCopilotSessionsQuery,
variables: {
workspaceId,
Expand All @@ -140,7 +119,7 @@ export class CopilotClient {
typeof getCopilotHistoriesQuery
>['variables']['options']
) {
const res = await fetcher({
const res = await this.gql({
query: getCopilotHistoriesQuery,
variables: {
workspaceId,
Expand All @@ -159,7 +138,7 @@ export class CopilotClient {
typeof getCopilotHistoriesQuery
>['variables']['options']
) {
const res = await fetcher({
const res = await this.gql({
query: getCopilotHistoryIdsQuery,
variables: {
workspaceId,
Expand All @@ -176,7 +155,7 @@ export class CopilotClient {
docId: string;
sessionIds: string[];
}) {
const res = await fetcher({
const res = await this.gql({
query: cleanupCopilotSessionMutation,
variables: {
input,
Expand All @@ -194,11 +173,11 @@ export class CopilotClient {
messageId?: string;
signal?: AbortSignal;
}) {
const url = new URL(`${this.backendUrl}/api/copilot/chat/${sessionId}`);
let url = `/api/copilot/chat/${sessionId}`;
if (messageId) {
url.searchParams.set('messageId', messageId);
url += `?messageId=${encodeURIComponent(messageId)}`;
}
const response = await fetch(url.toString(), { signal });
const response = await this.fetcher(url.toString(), { signal });
return response.text();
}

Expand All @@ -213,11 +192,11 @@ export class CopilotClient {
},
endpoint = 'stream'
) {
const url = new URL(
`${this.backendUrl}/api/copilot/chat/${sessionId}/${endpoint}`
);
if (messageId) url.searchParams.set('messageId', messageId);
return new EventSource(url.toString());
let url = `/api/copilot/chat/${sessionId}/${endpoint}`;
if (messageId) {
url += `?messageId=${encodeURIComponent(messageId)}`;
}
return this.eventSource(url);
}

// Text or image to images
Expand All @@ -227,15 +206,18 @@ export class CopilotClient {
seed?: string,
endpoint = 'images'
) {
const url = new URL(
`${this.backendUrl}/api/copilot/chat/${sessionId}/${endpoint}`
);
if (messageId) {
url.searchParams.set('messageId', messageId);
}
if (seed) {
url.searchParams.set('seed', seed);
let url = `/api/copilot/chat/${sessionId}/${endpoint}`;

if (messageId || seed) {
url += '?';
url += new URLSearchParams(
Object.fromEntries(
Object.entries({ messageId, seed }).filter(
([_, v]) => v !== undefined
)
) as Record<string, string>
).toString();
}
return new EventSource(url);
return this.eventSource(url);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@ import type { ForkChatSessionInput } from '@affine/graphql';
import { assertExists } from '@blocksuite/affine/global/utils';
import { partition } from 'lodash-es';

import { CopilotClient } from './copilot-client';
import type { CopilotClient } from './copilot-client';
import { delay, toTextStream } from './event-source';
import type { PromptKey } from './prompt';

const TIMEOUT = 50000;

const client = new CopilotClient();

export type TextToTextOptions = {
client: CopilotClient;
docId: string;
workspaceId: string;
promptName?: PromptKey;
Expand All @@ -33,9 +32,11 @@ export type ToImageOptions = TextToTextOptions & {
};

export function createChatSession({
client,
workspaceId,
docId,
}: {
client: CopilotClient;
workspaceId: string;
docId: string;
}) {
Expand All @@ -46,7 +47,10 @@ export function createChatSession({
});
}

export function forkCopilotSession(forkChatSessionInput: ForkChatSessionInput) {
export function forkCopilotSession(
client: CopilotClient,
forkChatSessionInput: ForkChatSessionInput
) {
return client.forkSession(forkChatSessionInput);
}

Expand Down Expand Up @@ -83,6 +87,7 @@ async function resizeImage(blob: Blob | File): Promise<Blob | null> {
}

async function createSessionMessage({
client,
docId,
workspaceId,
promptName,
Expand Down Expand Up @@ -140,6 +145,7 @@ async function createSessionMessage({
}

export function textToText({
client,
docId,
workspaceId,
promptName,
Expand Down Expand Up @@ -169,6 +175,7 @@ export function textToText({
_messageId = undefined;
} else {
const message = await createSessionMessage({
client,
docId,
workspaceId,
promptName,
Expand Down Expand Up @@ -242,6 +249,7 @@ export function textToText({
_messageId = undefined;
} else {
const message = await createSessionMessage({
client,
docId,
workspaceId,
promptName,
Expand All @@ -268,10 +276,6 @@ export function textToText({
}
}

export const listHistories = client.getHistories;

export const listHistoryIds = client.getHistoryIds;

// Only one image is currently being processed
export function toImage({
docId,
Expand All @@ -286,6 +290,7 @@ export function toImage({
timeout = TIMEOUT,
retry = false,
workflow = false,
client,
}: ToImageOptions) {
let _sessionId: string;
let _messageId: string | undefined;
Expand All @@ -305,6 +310,7 @@ export function toImage({
content,
attachments,
params,
client,
});
_sessionId = sessionId;
_messageId = messageId;
Expand Down Expand Up @@ -334,10 +340,12 @@ export function cleanupSessions({
workspaceId,
docId,
sessionIds,
client,
}: {
workspaceId: string;
docId: string;
sessionIds: string[];
client: CopilotClient;
}) {
return client.cleanupSessions({ workspaceId, docId, sessionIds });
}
Loading

0 comments on commit 6e25243

Please sign in to comment.