diff --git a/.github/helm/affine/templates/ingress.yaml b/.github/helm/affine/templates/ingress.yaml index a4c4f89987cd3..a55d19925223a 100644 --- a/.github/helm/affine/templates/ingress.yaml +++ b/.github/helm/affine/templates/ingress.yaml @@ -60,6 +60,20 @@ spec: name: affine-graphql port: number: {{ .Values.graphql.service.port }} + - path: /oauth/login + pathType: Prefix + backend: + service: + name: affine-graphql + port: + number: {{ .Values.graphql.service.port }} + - path: /desktop-signin + pathType: Prefix + backend: + service: + name: affine-graphql + port: + number: {{ .Values.graphql.service.port }} - path: /workspace pathType: Prefix backend: diff --git a/packages/backend/server/src/base/helpers/url.ts b/packages/backend/server/src/base/helpers/url.ts index 48bf54fe18016..594a874c3fc8d 100644 --- a/packages/backend/server/src/base/helpers/url.ts +++ b/packages/backend/server/src/base/helpers/url.ts @@ -60,25 +60,31 @@ export class URLHelper { return this.url(path, query).toString(); } - safeRedirect(res: Response, to: string) { + safeLink(to?: string) { try { - const finalTo = new URL(decodeURIComponent(to), this.baseUrl); - - for (const host of this.redirectAllowHosts) { - const hostURL = new URL(host); - if ( - hostURL.origin === finalTo.origin && - finalTo.pathname.startsWith(hostURL.pathname) - ) { - return res.redirect(finalTo.toString().replace(/\/$/, '')); + if (to) { + const finalTo = new URL(decodeURIComponent(to), this.baseUrl); + + for (const host of this.redirectAllowHosts) { + const hostURL = new URL(host); + if ( + hostURL.origin === finalTo.origin && + finalTo.pathname.startsWith(hostURL.pathname) + ) { + return finalTo.toString().replace(/\/$/, ''); + } } } } catch { // just ignore invalid url } - // redirect to home if the url is invalid - return res.redirect(this.home); + // return home url if the url is invalid + return this.home; + } + + safeRedirect(res: Response, to: string) { + return res.redirect(this.safeLink(to)); } verify(url: string | URL) { diff --git a/packages/backend/server/src/plugins/oauth/controller.ts b/packages/backend/server/src/plugins/oauth/controller.ts index f38eaa2eb0992..5605ac88577e6 100644 --- a/packages/backend/server/src/plugins/oauth/controller.ts +++ b/packages/backend/server/src/plugins/oauth/controller.ts @@ -1,29 +1,128 @@ import { Body, Controller, + Get, HttpCode, HttpStatus, + Logger, Post, Req, Res, } from '@nestjs/common'; import { ConnectedAccount, PrismaClient } from '@prisma/client'; import type { Request, Response } from 'express'; +import { z } from 'zod'; import { + Config, InvalidOauthCallbackState, MissingOauthQueryParameter, OauthAccountAlreadyConnected, OauthStateExpired, UnknownOauthProvider, + URLHelper, } from '../../base'; -import { AuthService, Public } from '../../core/auth'; +import { AuthService, Public, Session } from '../../core/auth'; import { UserService } from '../../core/user'; import { OAuthProviderName } from './config'; import { OAuthAccount, Tokens } from './providers/def'; import { OAuthProviderFactory } from './register'; import { OAuthService } from './service'; +const LoginParams = z.object({ + provider: z.nativeEnum(OAuthProviderName), + redirectUri: z.string().optional(), +}); + +// handle legacy clients oauth login +@Controller('/') +export class OAuthLegacyController { + private readonly logger = new Logger(OAuthLegacyController.name); + private readonly clientSchema: z.ZodEnum; + + constructor( + config: Config, + private readonly auth: AuthService, + private readonly oauth: OAuthService, + private readonly providerFactory: OAuthProviderFactory, + private readonly url: URLHelper + ) { + this.clientSchema = z.enum([ + 'web', + 'affine', + 'affine-canary', + 'affine-beta', + ...(config.node.dev ? ['affine-dev'] : []), + ]); + } + + @Public() + @Get('/oauth/login') + @Get('/desktop-signin') + @HttpCode(HttpStatus.OK) + async legacyLogin( + @Res() res: Response, + @Session() session: Session | undefined, + @Body('provider') provider?: string, + @Body('redirect_uri') redirectUri?: string, + @Body('client') client?: string + ) { + // sign out first, web only + if ((!!client || client === 'web') && session) { + await this.auth.signOut(session.sessionId); + await this.auth.refreshCookies(res, session.sessionId); + } + + const params = LoginParams.extend({ client: this.clientSchema }).safeParse({ + provider: provider?.toLowerCase(), + redirectUri: this.url.safeLink(redirectUri), + client, + }); + if (params.error) { + return res.redirect( + this.url.link('/sign-in', { + error: `Invalid oauth parameters`, + }) + ); + } else { + const { provider: providerName, redirectUri, client } = params.data; + const provider = this.providerFactory.get(providerName); + if (!provider) { + throw new UnknownOauthProvider({ name: providerName }); + } + + try { + const token = await this.oauth.saveOAuthState({ + provider: providerName, + redirectUri, + clientId: client, + }); + // legacy client state assemble + const oAuthUrl = new URL(provider.getAuthUrl(token)); + oAuthUrl.searchParams.set( + 'state', + JSON.stringify({ + state: oAuthUrl.searchParams.get('state'), + client, + provider, + }) + ); + return res.redirect(oAuthUrl.toString()); + } catch (e: any) { + this.logger.error( + `Failed to preflight oauth login for provider ${providerName}`, + e + ); + return res.redirect( + this.url.link('/sign-in', { + error: `Invalid oauth provider parameters`, + }) + ); + } + } + } +} + @Controller('/api/oauth') export class OAuthController { constructor( @@ -39,7 +138,9 @@ export class OAuthController { @HttpCode(HttpStatus.OK) async preflight( @Body('provider') unknownProviderName?: string, - @Body('redirect_uri') redirectUri?: string + @Body('redirect_uri') redirectUri?: string, + @Body('client') clientId?: string, + @Body('state') clientState?: string ) { if (!unknownProviderName) { throw new MissingOauthQueryParameter({ name: 'provider' }); @@ -53,66 +154,152 @@ export class OAuthController { throw new UnknownOauthProvider({ name: unknownProviderName }); } - const state = await this.oauth.saveOAuthState({ + const oAuthToken = await this.oauth.saveOAuthState({ provider: providerName, redirectUri, + // new client will generate the state from the client side + clientId, + state: clientState, }); return { - url: provider.getAuthUrl(state), + url: provider.getAuthUrl(oAuthToken), }; } @Public() - @Post('/callback') + @Post('/exchangeToken') @HttpCode(HttpStatus.OK) - async callback( - @Req() req: Request, - @Res() res: Response, - @Body('code') code?: string, - @Body('state') stateStr?: string + async exchangeToken( + @Body('code') code: string, + @Body('state') oAuthToken: string ) { if (!code) { throw new MissingOauthQueryParameter({ name: 'code' }); } - - if (!stateStr) { + if (!oAuthToken) { throw new MissingOauthQueryParameter({ name: 'state' }); } - if (typeof stateStr !== 'string' || !this.oauth.isValidState(stateStr)) { + const oAuthState = await this.oauth.getOAuthState(oAuthToken); + + if (!oAuthState || !oAuthState?.state) { throw new InvalidOauthCallbackState(); } - const state = await this.oauth.getOAuthState(stateStr); - - if (!state) { - throw new OauthStateExpired(); + // for new client, need exchange cookie by client state + // we only cache the code and access token in server side + const provider = this.providerFactory.get(oAuthState.provider); + if (!provider) { + throw new UnknownOauthProvider({ + name: oAuthState.provider ?? 'unknown', + }); } + const token = await this.oauth.saveOAuthState({ ...oAuthState, code }); - if (!state.provider) { - throw new MissingOauthQueryParameter({ name: 'provider' }); - } + return { + token, + provider: oAuthState.provider, + client: oAuthState.clientId, + }; + } + + @Public() + @Post('/callback') + @HttpCode(HttpStatus.OK) + async callback( + @Req() req: Request, + @Res() res: Response, + /** @deprecated */ @Body('code') code?: string, + @Body('state') oAuthToken?: string, + // new client will send token to exchange cookie + @Body('secret') inAppState?: string + ) { + if (inAppState && oAuthToken) { + // new method, need exchange cookie by client state + // we only cache the code and access token in server side + const authState = await this.oauth.getOAuthState(oAuthToken); + if (!authState || authState.state !== inAppState || !authState.code) { + throw new OauthStateExpired(); + } - const provider = this.providerFactory.get(state.provider); + if (!authState.provider) { + throw new MissingOauthQueryParameter({ name: 'provider' }); + } - if (!provider) { - throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' }); - } + const provider = this.providerFactory.get(authState.provider); - const tokens = await provider.getToken(code); - const externAccount = await provider.getUser(tokens.accessToken); - const user = await this.loginFromOauth( - state.provider, - externAccount, - tokens - ); - - await this.auth.setCookies(req, res, user.id); - res.send({ - id: user.id, - redirectUri: state.redirectUri, - }); + if (!provider) { + throw new UnknownOauthProvider({ + name: authState.provider ?? 'unknown', + }); + } + + // NOTE: in web client, we don't need to exchange token + // and provide the auth code directly + const tokens = await provider.getToken(authState.code); + const externAccount = await provider.getUser(tokens.accessToken); + const user = await this.loginFromOauth( + authState.provider, + externAccount, + tokens + ); + + await this.auth.setCookies(req, res, user.id); + res.send({ + id: user.id, + /* @deprecated */ + redirectUri: authState.redirectUri, + }); + } else { + if (!code) { + throw new MissingOauthQueryParameter({ name: 'code' }); + } + + if (!oAuthToken) { + throw new MissingOauthQueryParameter({ name: 'state' }); + } + + if ( + typeof oAuthToken !== 'string' || + !this.oauth.isValidState(oAuthToken) + ) { + throw new InvalidOauthCallbackState(); + } + + const authState = await this.oauth.getOAuthState(oAuthToken); + + if (!authState) { + throw new OauthStateExpired(); + } + + if (!authState.provider) { + throw new MissingOauthQueryParameter({ name: 'provider' }); + } + + const provider = this.providerFactory.get(authState.provider); + + if (!provider) { + throw new UnknownOauthProvider({ + name: authState.provider ?? 'unknown', + }); + } + + const tokens = await provider.getToken(code); + const externAccount = await provider.getUser(tokens.accessToken); + const user = await this.loginFromOauth( + authState.provider, + externAccount, + tokens + ); + + await this.auth.setCookies(req, res, user.id); + res.send({ + id: user.id, + /* @deprecated */ + redirectUri: authState.redirectUri, + }); + } } private async loginFromOauth( diff --git a/packages/backend/server/src/plugins/oauth/index.ts b/packages/backend/server/src/plugins/oauth/index.ts index abfc3aab2e436..175bda3f82efa 100644 --- a/packages/backend/server/src/plugins/oauth/index.ts +++ b/packages/backend/server/src/plugins/oauth/index.ts @@ -4,7 +4,7 @@ import { AuthModule } from '../../core/auth'; import { ServerFeature } from '../../core/config'; import { UserModule } from '../../core/user'; import { Plugin } from '../registry'; -import { OAuthController } from './controller'; +import { OAuthController, OAuthLegacyController } from './controller'; import { OAuthProviders } from './providers'; import { OAuthProviderFactory } from './register'; import { OAuthResolver } from './resolver'; @@ -19,7 +19,7 @@ import { OAuthService } from './service'; OAuthResolver, ...OAuthProviders, ], - controllers: [OAuthController], + controllers: [OAuthController, OAuthLegacyController], contributesTo: ServerFeature.OAuth, if: config => config.flavor.graphql && !!config.plugins.oauth, }) diff --git a/packages/backend/server/src/plugins/oauth/service.ts b/packages/backend/server/src/plugins/oauth/service.ts index da0ccee9ce503..d74f7ef2592e6 100644 --- a/packages/backend/server/src/plugins/oauth/service.ts +++ b/packages/backend/server/src/plugins/oauth/service.ts @@ -9,6 +9,13 @@ import { OAuthProviderFactory } from './register'; const OAUTH_STATE_KEY = 'OAUTH_STATE'; interface OAuthState { + // client id, currently it's the client schema + // if not provided, it's web platform + clientId?: string; + // client state + state?: string; + // provider authorize code + code?: string; redirectUri?: string; provider: OAuthProviderName; } diff --git a/packages/frontend/core/src/components/affine/auth/oauth.tsx b/packages/frontend/core/src/components/affine/auth/oauth.tsx index e109e7a6de813..9c8f020bbb59e 100644 --- a/packages/frontend/core/src/components/affine/auth/oauth.tsx +++ b/packages/frontend/core/src/components/affine/auth/oauth.tsx @@ -1,5 +1,5 @@ import { Button } from '@affine/component/ui/button'; -import { ServerService } from '@affine/core/modules/cloud'; +import { AuthService, ServerService } from '@affine/core/modules/cloud'; import { UrlService } from '@affine/core/modules/url'; import { OAuthProviderType } from '@affine/graphql'; import track from '@affine/track'; @@ -53,43 +53,47 @@ export function OAuth({ redirectUrl }: { redirectUrl?: string }) { )); } +type OAuthProviderProps = { + provider: OAuthProviderType; + redirectUrl?: string; + scheme?: string; + popupWindow: (url: string) => void; +}; + function OAuthProvider({ provider, redirectUrl, scheme, popupWindow, -}: { - provider: OAuthProviderType; - redirectUrl?: string; - scheme?: string; - popupWindow: (url: string) => void; -}) { - const serverService = useService(ServerService); +}: OAuthProviderProps) { + const auth = useService(AuthService); const { icon } = OAuthProviderMap[provider]; const onClick = useCallback(() => { - const params = new URLSearchParams(); - - params.set('provider', provider); - - if (redirectUrl) { - params.set('redirect_uri', redirectUrl); + async function preflight() { + if (ignore) return; + try { + return await auth.oauthPreflight(provider, scheme, false, redirectUrl); + } catch { + return null; + } } - if (scheme) { - params.set('client', scheme); - } - - // TODO: Android app scheme not implemented - // if (BUILD_CONFIG.isAndroid) {} - - const oauthUrl = - serverService.server.baseUrl + `/oauth/login?${params.toString()}`; - - track.$.$.auth.signIn({ method: 'oauth', provider }); - - popupWindow(oauthUrl); - }, [popupWindow, provider, redirectUrl, scheme, serverService]); + let ignore = false; + // eslint-disable-next-line @typescript-eslint/no-floating-promises + preflight().then(url => { + // cover popup limit in safari + setTimeout(() => { + if (url && !ignore) { + track.$.$.auth.signIn({ method: 'oauth', provider }); + popupWindow(url); + } + }); + }); + return () => { + ignore = true; + }; + }, [auth, popupWindow, provider, redirectUrl, scheme]); return (