Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): new login flow #8387

Open
wants to merge 12 commits into
base: canary
Choose a base branch
from
14 changes: 14 additions & 0 deletions .github/helm/affine/templates/ingress.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ spec:
name: affine-graphql
port:
number: {{ .Values.graphql.service.port }}
- path: /oauth/login
pathType: Prefix
backend:
service:
name: affine-graphql
port:
number: {{ .Values.graphql.service.port }}
- path: /desktop-signin
pathType: Prefix
backend:
service:
name: affine-graphql
port:
number: {{ .Values.graphql.service.port }}
- path: /workspace
pathType: Prefix
backend:
Expand Down
30 changes: 18 additions & 12 deletions packages/backend/server/src/base/helpers/url.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,25 +60,31 @@ export class URLHelper {
return this.url(path, query).toString();
}

safeRedirect(res: Response, to: string) {
safeLink(to?: string) {
try {
const finalTo = new URL(decodeURIComponent(to), this.baseUrl);

for (const host of this.redirectAllowHosts) {
const hostURL = new URL(host);
if (
hostURL.origin === finalTo.origin &&
finalTo.pathname.startsWith(hostURL.pathname)
) {
return res.redirect(finalTo.toString().replace(/\/$/, ''));
if (to) {
const finalTo = new URL(decodeURIComponent(to), this.baseUrl);

for (const host of this.redirectAllowHosts) {
const hostURL = new URL(host);
if (
hostURL.origin === finalTo.origin &&
finalTo.pathname.startsWith(hostURL.pathname)
) {
return finalTo.toString().replace(/\/$/, '');
}
}
}
} catch {
// just ignore invalid url
}

// redirect to home if the url is invalid
return res.redirect(this.home);
// return home url if the url is invalid
return this.home;
}

safeRedirect(res: Response, to: string) {
return res.redirect(this.safeLink(to));
}

verify(url: string | URL) {
Expand Down
261 changes: 224 additions & 37 deletions packages/backend/server/src/plugins/oauth/controller.ts
Original file line number Diff line number Diff line change
@@ -1,29 +1,128 @@
import {
Body,
Controller,
Get,
HttpCode,
HttpStatus,
Logger,
Post,
Req,
Res,
} from '@nestjs/common';
import { ConnectedAccount, PrismaClient } from '@prisma/client';
import type { Request, Response } from 'express';
import { z } from 'zod';

import {
Config,
InvalidOauthCallbackState,
MissingOauthQueryParameter,
OauthAccountAlreadyConnected,
OauthStateExpired,
UnknownOauthProvider,
URLHelper,
} from '../../base';
import { AuthService, Public } from '../../core/auth';
import { AuthService, Public, Session } from '../../core/auth';
import { UserService } from '../../core/user';
import { OAuthProviderName } from './config';
import { OAuthAccount, Tokens } from './providers/def';
import { OAuthProviderFactory } from './register';
import { OAuthService } from './service';

const LoginParams = z.object({
provider: z.nativeEnum(OAuthProviderName),
redirectUri: z.string().optional(),
});

// handle legacy clients oauth login
@Controller('/')
export class OAuthLegacyController {
private readonly logger = new Logger(OAuthLegacyController.name);
private readonly clientSchema: z.ZodEnum<any>;

constructor(
config: Config,
private readonly auth: AuthService,
private readonly oauth: OAuthService,
private readonly providerFactory: OAuthProviderFactory,
private readonly url: URLHelper
) {
this.clientSchema = z.enum([
'web',
'affine',
'affine-canary',
'affine-beta',
...(config.node.dev ? ['affine-dev'] : []),
]);
}

@Public()
@Get('/oauth/login')
@Get('/desktop-signin')
@HttpCode(HttpStatus.OK)
async legacyLogin(
@Res() res: Response,
@Session() session: Session | undefined,
@Body('provider') provider?: string,
@Body('redirect_uri') redirectUri?: string,
@Body('client') client?: string
) {
// sign out first, web only
if ((!!client || client === 'web') && session) {
await this.auth.signOut(session.sessionId);
await this.auth.refreshCookies(res, session.sessionId);
}

const params = LoginParams.extend({ client: this.clientSchema }).safeParse({
provider: provider?.toLowerCase(),
redirectUri: this.url.safeLink(redirectUri),
client,
});
if (params.error) {
return res.redirect(
this.url.link('/sign-in', {
error: `Invalid oauth parameters`,
})
);
} else {
const { provider: providerName, redirectUri, client } = params.data;
const provider = this.providerFactory.get(providerName);
if (!provider) {
throw new UnknownOauthProvider({ name: providerName });
}

try {
const token = await this.oauth.saveOAuthState({
provider: providerName,
redirectUri,
clientId: client,
});
// legacy client state assemble
const oAuthUrl = new URL(provider.getAuthUrl(token));
oAuthUrl.searchParams.set(
'state',
JSON.stringify({
state: oAuthUrl.searchParams.get('state'),
client,
provider,
})
);
return res.redirect(oAuthUrl.toString());
} catch (e: any) {
this.logger.error(
`Failed to preflight oauth login for provider ${providerName}`,
e
);
return res.redirect(
this.url.link('/sign-in', {
error: `Invalid oauth provider parameters`,
})
);
}
}
}
}

@Controller('/api/oauth')
export class OAuthController {
constructor(
Expand All @@ -39,7 +138,9 @@ export class OAuthController {
@HttpCode(HttpStatus.OK)
async preflight(
@Body('provider') unknownProviderName?: string,
@Body('redirect_uri') redirectUri?: string
@Body('redirect_uri') redirectUri?: string,
@Body('client') clientId?: string,
@Body('state') clientState?: string
) {
if (!unknownProviderName) {
throw new MissingOauthQueryParameter({ name: 'provider' });
Expand All @@ -53,66 +154,152 @@ export class OAuthController {
throw new UnknownOauthProvider({ name: unknownProviderName });
}

const state = await this.oauth.saveOAuthState({
const oAuthToken = await this.oauth.saveOAuthState({
provider: providerName,
redirectUri,
// new client will generate the state from the client side
clientId,
state: clientState,
});

return {
url: provider.getAuthUrl(state),
url: provider.getAuthUrl(oAuthToken),
};
}

@Public()
@Post('/callback')
@Post('/exchangeToken')
@HttpCode(HttpStatus.OK)
async callback(
@Req() req: Request,
@Res() res: Response,
@Body('code') code?: string,
@Body('state') stateStr?: string
async exchangeToken(
@Body('code') code: string,
@Body('state') oAuthToken: string
) {
if (!code) {
throw new MissingOauthQueryParameter({ name: 'code' });
}

if (!stateStr) {
if (!oAuthToken) {
throw new MissingOauthQueryParameter({ name: 'state' });
}

if (typeof stateStr !== 'string' || !this.oauth.isValidState(stateStr)) {
const oAuthState = await this.oauth.getOAuthState(oAuthToken);

if (!oAuthState || !oAuthState?.state) {
throw new InvalidOauthCallbackState();
}

const state = await this.oauth.getOAuthState(stateStr);

if (!state) {
throw new OauthStateExpired();
// for new client, need exchange cookie by client state
// we only cache the code and access token in server side
const provider = this.providerFactory.get(oAuthState.provider);
if (!provider) {
throw new UnknownOauthProvider({
name: oAuthState.provider ?? 'unknown',
});
}
const token = await this.oauth.saveOAuthState({ ...oAuthState, code });

if (!state.provider) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}
return {
token,
provider: oAuthState.provider,
client: oAuthState.clientId,
};
}

@Public()
@Post('/callback')
@HttpCode(HttpStatus.OK)
async callback(
@Req() req: Request,
@Res() res: Response,
/** @deprecated */ @Body('code') code?: string,
@Body('state') oAuthToken?: string,
// new client will send token to exchange cookie
@Body('secret') inAppState?: string
) {
if (inAppState && oAuthToken) {
// new method, need exchange cookie by client state
// we only cache the code and access token in server side
const authState = await this.oauth.getOAuthState(oAuthToken);
if (!authState || authState.state !== inAppState || !authState.code) {
throw new OauthStateExpired();
}

const provider = this.providerFactory.get(state.provider);
if (!authState.provider) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}

if (!provider) {
throw new UnknownOauthProvider({ name: state.provider ?? 'unknown' });
}
const provider = this.providerFactory.get(authState.provider);

const tokens = await provider.getToken(code);
const externAccount = await provider.getUser(tokens.accessToken);
const user = await this.loginFromOauth(
state.provider,
externAccount,
tokens
);

await this.auth.setCookies(req, res, user.id);
res.send({
id: user.id,
redirectUri: state.redirectUri,
});
if (!provider) {
throw new UnknownOauthProvider({
name: authState.provider ?? 'unknown',
});
}

// NOTE: in web client, we don't need to exchange token
// and provide the auth code directly
const tokens = await provider.getToken(authState.code);
const externAccount = await provider.getUser(tokens.accessToken);
const user = await this.loginFromOauth(
authState.provider,
externAccount,
tokens
);

await this.auth.setCookies(req, res, user.id);
res.send({
id: user.id,
/* @deprecated */
redirectUri: authState.redirectUri,
});
} else {
if (!code) {
throw new MissingOauthQueryParameter({ name: 'code' });
}

if (!oAuthToken) {
throw new MissingOauthQueryParameter({ name: 'state' });
}

if (
typeof oAuthToken !== 'string' ||
!this.oauth.isValidState(oAuthToken)
) {
throw new InvalidOauthCallbackState();
}

const authState = await this.oauth.getOAuthState(oAuthToken);

if (!authState) {
throw new OauthStateExpired();
}

if (!authState.provider) {
throw new MissingOauthQueryParameter({ name: 'provider' });
}

const provider = this.providerFactory.get(authState.provider);

if (!provider) {
throw new UnknownOauthProvider({
name: authState.provider ?? 'unknown',
});
}

const tokens = await provider.getToken(code);
const externAccount = await provider.getUser(tokens.accessToken);
const user = await this.loginFromOauth(
authState.provider,
externAccount,
tokens
);

await this.auth.setCookies(req, res, user.id);
res.send({
id: user.id,
/* @deprecated */
redirectUri: authState.redirectUri,
});
}
}

private async loginFromOauth(
Expand Down
Loading
Loading