Skip to content

Commit

Permalink
Implement Refresh Token (#317)
Browse files Browse the repository at this point in the history
* Implement refresh token in backend

* Implement token refresh with axios interceptor

* Fix typo

* Apply lint

* Add Refresh token initialization

* Update Axios interceptor to handle token expiration and refresh logic

* Update environment variable comments for clarity

* Separate dtos related to refresh token

* Add `@ApiProperty` decorator

* Update `@ApiBody` and `@ApiResponse` type

* Separate JWT secrets for access and refresh tokens

* Add logout logic and remove duplicate call

* Update env variable comments for clarity

* Split JwtService for access and refresh tokens

* Fix interceptor logic for refresh token

* Add error config token
xet-a authored Sep 11, 2024
1 parent 70beff7 commit 5d5ed73
Showing 17 changed files with 278 additions and 63 deletions.
11 changes: 8 additions & 3 deletions backend/.env.development
Original file line number Diff line number Diff line change
@@ -14,9 +14,14 @@ GITHUB_CLIENT_SECRET=your_github_client_secret_here
# Example: http://localhost:3000/auth/login/github (For development mode)
GITHUB_CALLBACK_URL=http://localhost:3000/auth/login/github

# JWT_AUTH_SECRET: Secret key for JWT authentication.
# This key is used to sign and verify JWT tokens.
JWT_AUTH_SECRET=you_should_change_this_secret_key_in_production
# JWT_ACCESS_TOKEN_SECRET: Secret key for signing and verifying access tokens.
# JWT_ACCESS_TOKEN_EXPIRATION_TIME: Expiration time for access tokens in seconds.
JWT_ACCESS_TOKEN_SECRET=you_should_change_this_access_token_secret_key_in_production
JWT_ACCESS_TOKEN_EXPIRATION_TIME=86400
# JWT_REFRESH_TOKEN_SECRET: Secret key for signing and verifying refresh tokens.
# JWT_REFRESH_TOKEN_EXPIRATION_TIME: Expiration time for refresh tokens in seconds.
JWT_REFRESH_TOKEN_SECRET=you_should_change_this_refresh_token_secret_key_in_production
JWT_REFRESH_TOKEN_EXPIRATION_TIME=604800

# FRONTEND_BASE_URL: Base URL of the frontend application.
# This URL is used for redirecting after authentication, etc.
49 changes: 33 additions & 16 deletions backend/src/auth/auth.controller.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
import { Controller, Get, HttpRedirectResponse, Redirect, Req, UseGuards } from "@nestjs/common";
import {
Body,
Controller,
Get,
HttpRedirectResponse,
Post,
Redirect,
Req,
UseGuards,
} from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
import { AuthGuard } from "@nestjs/passport";
import { ApiBody, ApiOperation, ApiResponse, ApiTags } from "@nestjs/swagger";
import { Public } from "src/utils/decorators/auth.decorator";
import { AuthService } from "./auth.service";
import { RefreshTokenRequestDto } from "./dto/refresh-token-request.dto";
import { RefreshTokenResponseDto } from "./dto/refresh-token-response.dto";
import { LoginRequest } from "./types/login-request.type";
import { JwtService } from "@nestjs/jwt";
import { LoginResponse } from "./types/login-response.type";
import { UsersService } from "src/users/users.service";
import { Public } from "src/utils/decorators/auth.decorator";
import { ApiOperation, ApiResponse, ApiTags } from "@nestjs/swagger";
import { ConfigService } from "@nestjs/config";

@ApiTags("Auth")
@Controller("auth")
export class AuthController {
constructor(
private configService: ConfigService,
private jwtService: JwtService,
private usersService: UsersService
private readonly authService: AuthService,
private configService: ConfigService
) {}

@Public()
@@ -28,16 +37,24 @@ export class AuthController {
})
@ApiResponse({ type: LoginResponse })
async login(@Req() req: LoginRequest): Promise<HttpRedirectResponse> {
const user = await this.usersService.findOrCreate(
req.user.socialProvider,
req.user.socialUid
);

const accessToken = this.jwtService.sign({ sub: user.id, nickname: user.nickname });
const { accessToken, refreshToken } = await this.authService.loginWithSocialProvider(req);

return {
url: `${this.configService.get("FRONTEND_BASE_URL")}/auth/callback?token=${accessToken}`,
url: `${this.configService.get("FRONTEND_BASE_URL")}/auth/callback?accessToken=${accessToken}&refreshToken=${refreshToken}`,
statusCode: 302,
};
}

@Public()
@Post("refresh")
@UseGuards(AuthGuard("refresh"))
@ApiOperation({
summary: "Refresh Access Token",
description: "Generates a new Access Token using the user's Refresh Token.",
})
@ApiBody({ type: RefreshTokenRequestDto })
@ApiResponse({ type: RefreshTokenResponseDto })
async refresh(@Body() body: RefreshTokenRequestDto): Promise<RefreshTokenResponseDto> {
return await this.authService.getNewAccessToken(body.refreshToken);
}
}
46 changes: 33 additions & 13 deletions backend/src/auth/auth.module.ts
Original file line number Diff line number Diff line change
@@ -1,27 +1,47 @@
import { Module } from "@nestjs/common";
import { AuthService } from "./auth.service";
import { ConfigService } from "@nestjs/config";
import { JwtService } from "@nestjs/jwt";
import { UsersModule } from "src/users/users.module";
import { JwtInject } from "src/utils/constants/jwt-inject";
import { AuthController } from "./auth.controller";
import { AuthService } from "./auth.service";
import { GithubStrategy } from "./github.strategy";
import { ConfigService } from "@nestjs/config";
import { JwtModule } from "@nestjs/jwt";
import { JwtRefreshStrategy } from "./jwt-refresh.strategy";
import { JwtStrategy } from "./jwt.strategy";

@Module({
imports: [
UsersModule,
JwtModule.registerAsync({
imports: [UsersModule],
providers: [
AuthService,
GithubStrategy,
JwtStrategy,
JwtRefreshStrategy,
{
provide: JwtInject.ACCESS,
useFactory: async (configService: ConfigService) => {
return new JwtService({
secret: configService.get<string>("JWT_ACCESS_TOKEN_SECRET"),
signOptions: {
expiresIn: `${configService.get("JWT_ACCESS_TOKEN_EXPIRATION_TIME")}s`,
},
});
},
inject: [ConfigService],
},
{
provide: JwtInject.REFRESH,
useFactory: async (configService: ConfigService) => {
return {
global: true,
signOptions: { expiresIn: "24h" },
secret: configService.get<string>("JWT_AUTH_SECRET"),
};
return new JwtService({
secret: configService.get<string>("JWT_REFRESH_TOKEN_SECRET"),
signOptions: {
expiresIn: `${configService.get("JWT_REFRESH_TOKEN_EXPIRATION_TIME")}s`,
},
});
},
inject: [ConfigService],
}),
},
],
providers: [AuthService, GithubStrategy, JwtStrategy],
exports: [JwtInject.ACCESS, JwtInject.REFRESH],
controllers: [AuthController],
})
export class AuthModule {}
48 changes: 47 additions & 1 deletion backend/src/auth/auth.service.spec.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,64 @@
import { ConfigModule } from "@nestjs/config";
import { JwtService } from "@nestjs/jwt";
import { Test, TestingModule } from "@nestjs/testing";
import { UsersService } from "../users/users.service";
import { AuthService } from "./auth.service";

describe("AuthService", () => {
let service: AuthService;
let jwtService: JwtService;

beforeEach(async () => {
const module: TestingModule = await Test.createTestingModule({
providers: [AuthService],
imports: [ConfigModule.forRoot()],
providers: [
AuthService,
{
provide: UsersService,
useValue: {
findOrCreate: jest
.fn()
.mockResolvedValue({ id: "123", nickname: "testuser" }),
},
},
{
provide: JwtService,
useValue: {
sign: jest.fn().mockReturnValue("signedToken"),
verify: jest.fn().mockReturnValue({ sub: "123", nickname: "testuser" }),
},
},
],
}).compile();

service = module.get<AuthService>(AuthService);
jwtService = module.get<JwtService>(JwtService);
});

it("should be defined", () => {
expect(service).toBeDefined();
});

describe("getNewAccessToken", () => {
it("should generate a new access token using refresh token", async () => {
const newToken = await service.getNewAccessToken("refreshToken");

expect(newToken).toBe("signedToken");
expect(jwtService.verify).toHaveBeenCalledWith("refreshToken");
expect(jwtService.sign).toHaveBeenCalledWith(
{ sub: "123", nickname: "testuser" },
expect.any(Object)
);
});

it("should throw an error if refresh token is invalid", async () => {
jwtService.verify = jest.fn().mockImplementation(() => {
throw new Error("Invalid token");
});

await expect(service.getNewAccessToken("invalidToken")).rejects.toThrow(
"Invalid token"
);
});
});
});
36 changes: 34 additions & 2 deletions backend/src/auth/auth.service.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,39 @@
import { Injectable } from "@nestjs/common";
import { Inject, Injectable } from "@nestjs/common";
import { JwtService } from "@nestjs/jwt";
import { UsersService } from "src/users/users.service";
import { JwtInject } from "src/utils/constants/jwt-inject";
import { RefreshTokenResponseDto } from "./dto/refresh-token-response.dto";
import { LoginRequest } from "./types/login-request.type";
import { LoginResponse } from "./types/login-response.type";

@Injectable()
export class AuthService {
constructor(private usersService: UsersService) {}
constructor(
private readonly usersService: UsersService,
@Inject(JwtInject.ACCESS) private readonly jwtAccessService: JwtService,
@Inject(JwtInject.REFRESH) private readonly jwtRefreshService: JwtService
) {}

async loginWithSocialProvider(req: LoginRequest): Promise<LoginResponse> {
const user = await this.usersService.findOrCreate(
req.user.socialProvider,
req.user.socialUid
);

const accessToken = this.jwtAccessService.sign({ sub: user.id, nickname: user.nickname });
const refreshToken = this.jwtRefreshService.sign({ sub: user.id });

return { accessToken, refreshToken };
}

async getNewAccessToken(refreshToken: string): Promise<RefreshTokenResponseDto> {
const payload = this.jwtRefreshService.verify(refreshToken);

const newAccessToken = this.jwtAccessService.sign({
sub: payload.sub,
nickname: payload.nickname,
});

return { newAccessToken };
}
}
6 changes: 6 additions & 0 deletions backend/src/auth/dto/refresh-token-request.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { ApiProperty } from "@nestjs/swagger";

export class RefreshTokenRequestDto {
@ApiProperty({ type: String, description: "The refresh token to request a new access token" })
refreshToken: string;
}
6 changes: 6 additions & 0 deletions backend/src/auth/dto/refresh-token-response.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { ApiProperty } from "@nestjs/swagger";

export class RefreshTokenResponseDto {
@ApiProperty({ type: String, description: "The new access token" })
newAccessToken: string;
}
26 changes: 26 additions & 0 deletions backend/src/auth/jwt-refresh.strategy.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { Injectable } from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
import { PassportStrategy } from "@nestjs/passport";
import { Strategy as PassportJwtStrategy } from "passport-jwt";
import { JwtPayload } from "src/utils/types/jwt.type";
import { AuthorizedUser } from "src/utils/types/req.type";

@Injectable()
export class JwtRefreshStrategy extends PassportStrategy(PassportJwtStrategy, "refresh") {
constructor(configService: ConfigService) {
super({
jwtFromRequest: (req) => {
if (req && req.body.refreshToken) {
return req.body.refreshToken;
}
return null;
},
ignoreExpiration: false,
secretOrKey: configService.get<string>("JWT_REFRESH_TOKEN_SECRET"),
});
}

async validate(payload: JwtPayload): Promise<AuthorizedUser> {
return { id: payload.sub, nickname: payload.nickname };
}
}
8 changes: 4 additions & 4 deletions backend/src/auth/jwt.strategy.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import { ExtractJwt, Strategy as PassportJwtStrategy } from "passport-jwt";
import { ConfigService } from "@nestjs/config";
import { Injectable } from "@nestjs/common";
import { ConfigService } from "@nestjs/config";
import { PassportStrategy } from "@nestjs/passport";
import { ExtractJwt, Strategy as PassportJwtStrategy } from "passport-jwt";
import { JwtPayload } from "src/utils/types/jwt.type";
import { AuthorizedUser } from "src/utils/types/req.type";

@Injectable()
export class JwtStrategy extends PassportStrategy(PassportJwtStrategy) {
export class JwtStrategy extends PassportStrategy(PassportJwtStrategy, "jwt") {
constructor(configService: ConfigService) {
super({
jwtFromRequest: ExtractJwt.fromAuthHeaderAsBearerToken(),
ignoreExpiration: false,
secretOrKey: configService.get<string>("JWT_AUTH_SECRET"),
secretOrKey: configService.get<string>("JWT_ACCESS_TOKEN_SECRET"),
});
}

3 changes: 3 additions & 0 deletions backend/src/auth/types/login-response.type.ts
Original file line number Diff line number Diff line change
@@ -3,4 +3,7 @@ import { ApiProperty } from "@nestjs/swagger";
export class LoginResponse {
@ApiProperty({ type: String, description: "Access token for CodePair" })
accessToken: string;

@ApiProperty({ type: String, description: "Refresh token to get a new access token" })
refreshToken: string;
}
4 changes: 2 additions & 2 deletions backend/src/users/users.module.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { Module } from "@nestjs/common";
import { UsersService } from "./users.service";
import { CheckService } from "src/check/check.service";
import { PrismaService } from "src/db/prisma.service";
import { UsersController } from "./users.controller";
import { CheckService } from "src/check/check.service";
import { UsersService } from "./users.service";

@Module({
providers: [UsersService, PrismaService, CheckService],
4 changes: 4 additions & 0 deletions backend/src/utils/constants/jwt-inject.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export const JwtInject = {
ACCESS: "JWT_ACCESS_SERVICE",
REFRESH: "JWT_REFRESH_SERVICE",
};
57 changes: 47 additions & 10 deletions frontend/src/App.tsx
Original file line number Diff line number Diff line change
@@ -2,9 +2,12 @@ import "@fontsource/roboto/300.css";
import "@fontsource/roboto/400.css";
import "@fontsource/roboto/500.css";
import "@fontsource/roboto/700.css";
import "./App.css";
import { Box, CssBaseline, ThemeProvider, createTheme, useMediaQuery } from "@mui/material";
import { useSelector } from "react-redux";
import * as Sentry from "@sentry/react";
import { QueryCache, QueryClient, QueryClientProvider } from "@tanstack/react-query";
import axios from "axios";
import { useEffect, useMemo } from "react";
import { useDispatch, useSelector } from "react-redux";
import {
RouterProvider,
createBrowserRouter,
@@ -13,15 +16,15 @@ import {
useLocation,
useNavigationType,
} from "react-router-dom";
import { useEffect, useMemo } from "react";
import { selectConfig } from "./store/configSlice";
import axios from "axios";
import { routes } from "./routes";
import { QueryCache, QueryClient, QueryClientProvider } from "@tanstack/react-query";
import AuthProvider from "./providers/AuthProvider";
import { useErrorHandler } from "./hooks/useErrorHandler";
import * as Sentry from "@sentry/react";
import "./App.css";
import { useGetSettingsQuery } from "./hooks/api/settings";
import { useErrorHandler } from "./hooks/useErrorHandler";
import AuthProvider from "./providers/AuthProvider";
import { routes } from "./routes";
import { logout, setAccessToken } from "./store/authSlice";
import { selectConfig } from "./store/configSlice";
import { store } from "./store/store";
import { setUserData } from "./store/userSlice";
import { isAxios404Error, isAxios500Error } from "./utils/axios.default";

if (import.meta.env.PROD) {
@@ -58,6 +61,7 @@ function SettingLoader() {

function App() {
const config = useSelector(selectConfig);
const dispatch = useDispatch();
const prefersDarkMode = useMediaQuery("(prefers-color-scheme: dark)");
const theme = useMemo(() => {
const defaultMode = prefersDarkMode ? "dark" : "light";
@@ -104,6 +108,39 @@ function App() {
});
}, [handleError]);

useEffect(() => {
const handleRefreshTokenExpiration = () => {
dispatch(logout());
dispatch(setUserData(null));
};

const interceptor = axios.interceptors.response.use(
(response) => response,
async (error) => {
if (error.response?.status === 401 && !error.config._retry) {
if (error.config.url === "/auth/refresh") {
handleRefreshTokenExpiration();
return Promise.reject(error);
} else {
error.config._retry = true;
const refreshToken = store.getState().auth.refreshToken;
const response = await axios.post("/auth/refresh", { refreshToken });
const newAccessToken = response.data.newAccessToken;
dispatch(setAccessToken(newAccessToken));
axios.defaults.headers.common["Authorization"] = `Bearer ${newAccessToken}`;
error.config.headers["Authorization"] = `Bearer ${newAccessToken}`;
return axios(error.config);
}
}
return Promise.reject(error);
}
);

return () => {
axios.interceptors.response.eject(interceptor);
};
}, [dispatch]);

return (
<QueryClientProvider client={queryClient}>
<AuthProvider>
4 changes: 2 additions & 2 deletions frontend/src/components/popovers/ProfilePopover.tsx
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ import {
import { useDispatch } from "react-redux";
import { useNavigate } from "react-router-dom";
import { useCurrentTheme } from "../../hooks/useCurrentTheme";
import { setAccessToken } from "../../store/authSlice";
import { logout } from "../../store/authSlice";
import { setTheme, ThemeType } from "../../store/configSlice";
import { setUserData } from "../../store/userSlice";

@@ -23,7 +23,7 @@ function ProfilePopover(props: PopoverProps) {
const navigate = useNavigate();

const handleLogout = () => {
dispatch(setAccessToken(null));
dispatch(logout());
dispatch(setUserData(null));
};

8 changes: 4 additions & 4 deletions frontend/src/hooks/api/user.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import { useDispatch, useSelector } from "react-redux";
import { selectAuth, setAccessToken } from "../../store/authSlice";
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import axios from "axios";
import { GetUserResponse, UpdateUserRequest } from "./types/user";
import { useEffect } from "react";
import { useDispatch, useSelector } from "react-redux";
import { logout, selectAuth } from "../../store/authSlice";
import { User, setUserData } from "../../store/userSlice";
import { GetUserResponse, UpdateUserRequest } from "./types/user";

export const generateGetUserQueryKey = (accessToken: string) => {
return ["users", accessToken];
@@ -28,7 +28,7 @@ export const useGetUserQuery = () => {
if (query.isSuccess) {
dispatch(setUserData(query.data as User));
} else if (query.isError) {
dispatch(setAccessToken(null));
dispatch(logout());
dispatch(setUserData(null));
axios.defaults.headers.common["Authorization"] = "";
}
10 changes: 6 additions & 4 deletions frontend/src/pages/auth/callback/Index.tsx
Original file line number Diff line number Diff line change
@@ -2,22 +2,24 @@ import { Box } from "@mui/material";
import { useEffect } from "react";
import { useDispatch } from "react-redux";
import { useNavigate, useSearchParams } from "react-router-dom";
import { setAccessToken } from "../../../store/authSlice";
import { setAccessToken, setRefreshToken } from "../../../store/authSlice";

function CallbackIndex() {
const dispatch = useDispatch();
const navigate = useNavigate();
const [searchParams] = useSearchParams();

useEffect(() => {
const token = searchParams.get("token");
const accessToken = searchParams.get("accessToken");
const refreshToken = searchParams.get("refreshToken");

if (!token) {
if (!accessToken || !refreshToken) {
navigate("/");
return;
}

dispatch(setAccessToken(token));
dispatch(setAccessToken(accessToken));
dispatch(setRefreshToken(refreshToken));
}, [dispatch, navigate, searchParams]);

return <Box></Box>;
15 changes: 13 additions & 2 deletions frontend/src/store/authSlice.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import { createSlice } from "@reduxjs/toolkit";
import type { PayloadAction } from "@reduxjs/toolkit";
import { createSlice } from "@reduxjs/toolkit";
import axios from "axios";
import { RootState } from "./store";

export interface AuthState {
accessToken: string | null;
refreshToken: string | null;
}

const initialState: AuthState = {
accessToken: null,
refreshToken: null,
};

export const authSlice = createSlice({
@@ -17,10 +20,18 @@ export const authSlice = createSlice({
setAccessToken: (state, action: PayloadAction<string | null>) => {
state.accessToken = action.payload;
},
setRefreshToken(state, action: PayloadAction<string | null>) {
state.refreshToken = action.payload;
},
logout: (state) => {
state.accessToken = null;
state.refreshToken = null;
axios.defaults.headers.common["Authorization"] = "";
},
},
});

export const { setAccessToken } = authSlice.actions;
export const { setAccessToken, setRefreshToken, logout } = authSlice.actions;

export const selectAuth = (state: RootState) => state.auth;

0 comments on commit 5d5ed73

Please sign in to comment.