diff --git a/src/commands/upgrade_config_files.ts b/src/commands/upgrade_config_files.ts index 69822af..056b0a4 100644 --- a/src/commands/upgrade_config_files.ts +++ b/src/commands/upgrade_config_files.ts @@ -7,6 +7,7 @@ import { sessionConfig } from '../patchers/config_files/session_config.js' import { bodyparserConfig } from '../patchers/config_files/bodyparser_config.js' import { databaseConfig } from '../patchers/config_files/database_config.js' import { hashConfig } from '../patchers/config_files/hash_config.js' +import { shieldConfig } from '../patchers/config_files/shield_config.js' export class UpgradeConfigFiles extends BaseCommand { static commandName = `upgrade-config-files` @@ -24,6 +25,7 @@ export class UpgradeConfigFiles extends BaseCommand { staticConfig(), databaseConfig(), hashConfig(), + shieldConfig(), ], projectPath: this.projectPath, }).run() diff --git a/src/patchers/config_files/shield_config.ts b/src/patchers/config_files/shield_config.ts new file mode 100644 index 0000000..2c07a57 --- /dev/null +++ b/src/patchers/config_files/shield_config.ts @@ -0,0 +1,55 @@ +import { SyntaxKind } from 'ts-morph' +import { PatcherFactory } from '../../types/index.js' +import { ConfigUpdaterPatcher } from '../config_updater_patcher.js' + +export function shieldConfig(): PatcherFactory { + return (runner) => new ShieldConfig(runner) +} + +/** + * Rewrite the config/shield.ts file to use the new API + */ +export class ShieldConfig extends ConfigUpdaterPatcher { + static patcherName = 'shield-config' + + async invoke() { + super.invoke() + + const file = this.getConfigFile('config/shield.ts') + if (!file) return + + /** + * Take each configuration section exported from the old file + * and create a new object literal that will include all + */ + const exportedDeclarations = file.getExportedDeclarations() + const sections = ['csp', 'csrf', 'xframe', 'hsts', 'contentTypeSniffing', 'dnsPrefetchControl'] + + let configObjectLiteral = '{\n' + + sections.forEach((sectionName) => { + const symbols = exportedDeclarations.get(sectionName) + if (!symbols) return + + const symbol = symbols[0].getChildrenOfKind(SyntaxKind.ObjectLiteralExpression) + configObjectLiteral += sectionName + ': ' + symbol[0].getText() + ',\n' + }) + + configObjectLiteral += '}' + + /** + * Write the new file + */ + const newContent = ` + import { defineConfig } from '@adonisjs/shield' + + export default defineConfig(${configObjectLiteral}) + ` + file.replaceWithText(newContent) + await this.formatFile(file).save() + + this.logger.info('Updated config/shield.ts file') + + this.exit() + } +} diff --git a/tests/shield_config.spec.ts b/tests/shield_config.spec.ts new file mode 100644 index 0000000..265518c --- /dev/null +++ b/tests/shield_config.spec.ts @@ -0,0 +1,88 @@ +import dedent from 'dedent' +import { test } from '@japa/runner' + +import { createRunner } from '../test_helpers/index.js' +import { shieldConfig } from '../src/patchers/config_files/shield_config.js' + +test.group('Shield config', () => { + test('Update shield config', async ({ assert, fs }) => { + await fs.setupProject({}) + + await fs.create( + 'config/shield.ts', + dedent` + import Env from '@ioc:Adonis/Core/Env' + import { ShieldConfig } from '@ioc:Adonis/Addons/Shield' + + export const csp: ShieldConfig['csp'] = { + enabled: false, + directives: {}, + reportOnly: false, + } + + export const csrf: ShieldConfig['csrf'] = { + enabled: Env.get('NODE_ENV') !== 'test', + exceptRoutes: [], + enableXsrfCookie: true, + methods: ['POST', 'PUT', 'PATCH', 'DELETE'], + } + + export const dnsPrefetch: ShieldConfig['dnsPrefetch'] = { + enabled: true, + allow: true, + } + + export const xFrame: ShieldConfig['xFrame'] = { + enabled: true, + action: 'DENY', + } + + export const hsts: ShieldConfig['hsts'] = { + enabled: true, + maxAge: '180 days', + includeSubDomains: true, + preload: false, + } + + export const contentTypeSniffing: ShieldConfig['contentTypeSniffing'] = { + enabled: true, + } + ` + ) + + await createRunner({ + projectPath: fs.basePath, + patchers: [shieldConfig()], + }).run() + + const content = await fs.contents('config/shield.ts') + assert.snapshot(content).matchInline(` + " + import { defineConfig } from '@adonisjs/shield' + + export default defineConfig({ + csp: { + enabled: false, + directives: {}, + reportOnly: false, + }, + csrf: { + enabled: Env.get('NODE_ENV') !== 'test', + exceptRoutes: [], + enableXsrfCookie: true, + methods: ['POST', 'PUT', 'PATCH', 'DELETE'], + }, + hsts: { + enabled: true, + maxAge: '180 days', + includeSubDomains: true, + preload: false, + }, + contentTypeSniffing: { + enabled: true, + }, + }) + " + `) + }) +})