Skip to content

Commit

Permalink
feat: Add @encrypted enhancer (#1922)
Browse files Browse the repository at this point in the history
  • Loading branch information
genu authored Dec 30, 2024
1 parent a0e2b53 commit 1b7448f
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 3 deletions.
1 change: 1 addition & 0 deletions packages/runtime/src/enhancements/edge/encrypted.ts
15 changes: 13 additions & 2 deletions packages/runtime/src/enhancements/node/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ import { withJsonProcessor } from './json-processor';
import { Logger } from './logger';
import { withOmit } from './omit';
import { withPassword } from './password';
import { withEncrypted } from './encrypted';
import { policyProcessIncludeRelationPayload, withPolicy } from './policy';
import type { PolicyDef } from './types';

/**
* All enhancement kinds
*/
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate'];
const ALL_ENHANCEMENTS: EnhancementKind[] = ['password', 'omit', 'policy', 'validation', 'delegate', 'encrypted'];

/**
* Options for {@link createEnhancement}
Expand Down Expand Up @@ -100,6 +101,7 @@ export function createEnhancement<DbClient extends object>(
}

const hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password'));
const hasEncrypted = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@encrypted'));
const hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit'));
const hasDefaultAuth = allFields.some((field) => field.defaultValueProvider);
const hasTypeDefField = allFields.some((field) => field.isTypeDef);
Expand All @@ -120,13 +122,22 @@ export function createEnhancement<DbClient extends object>(
}
}

// password enhancement must be applied prior to policy because it changes then length of the field
// password and encrypted enhancement must be applied prior to policy because it changes then length of the field
// and can break validation rules like `@length`
if (hasPassword && kinds.includes('password')) {
// @password proxy
result = withPassword(result, options);
}

if (hasEncrypted && kinds.includes('encrypted')) {
if (!options.encryption) {
throw new Error('Encryption options are required for @encrypted enhancement');
}

// @encrypted proxy
result = withEncrypted(result, options);
}

// 'policy' and 'validation' enhancements are both enabled by `withPolicy`
if (kinds.includes('policy') || kinds.includes('validation')) {
result = withPolicy(result, options, context);
Expand Down
175 changes: 175 additions & 0 deletions packages/runtime/src/enhancements/node/encrypted.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
/* eslint-disable @typescript-eslint/no-unused-vars */

import {
FieldInfo,
NestedWriteVisitor,
enumerate,
getModelFields,
resolveField,
type PrismaWriteActionType,
} from '../../cross';
import { DbClientContract, CustomEncryption, SimpleEncryption } from '../../types';
import { InternalEnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
import { QueryUtils } from './query-utils';

/**
* Gets an enhanced Prisma client that supports `@encrypted` attribute.
*
* @private
*/
export function withEncrypted<DbClient extends object = any>(
prisma: DbClient,
options: InternalEnhancementOptions
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new EncryptedHandler(_prisma as DbClientContract, model, options),
'encrypted'
);
}

class EncryptedHandler extends DefaultPrismaProxyHandler {
private queryUtils: QueryUtils;
private encoder = new TextEncoder();
private decoder = new TextDecoder();

constructor(prisma: DbClientContract, model: string, options: InternalEnhancementOptions) {
super(prisma, model, options);

this.queryUtils = new QueryUtils(prisma, options);

if (!options.encryption) throw new Error('Encryption options must be provided');

if (this.isCustomEncryption(options.encryption!)) {
if (!options.encryption.encrypt || !options.encryption.decrypt)
throw new Error('Custom encryption must provide encrypt and decrypt functions');
} else {
if (!options.encryption.encryptionKey) throw new Error('Encryption key must be provided');
if (options.encryption.encryptionKey.length !== 32) throw new Error('Encryption key must be 32 bytes');
}
}

private async getKey(secret: Uint8Array): Promise<CryptoKey> {
return crypto.subtle.importKey('raw', secret, 'AES-GCM', false, ['encrypt', 'decrypt']);
}

private isCustomEncryption(encryption: CustomEncryption | SimpleEncryption): encryption is CustomEncryption {
return 'encrypt' in encryption && 'decrypt' in encryption;
}

private async encrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.encrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);
const iv = crypto.getRandomValues(new Uint8Array(12));

const encrypted = await crypto.subtle.encrypt(
{
name: 'AES-GCM',
iv,
},
key,
this.encoder.encode(data)
);

// Combine IV and encrypted data into a single array of bytes
const bytes = [...iv, ...new Uint8Array(encrypted)];

// Convert bytes to base64 string
return btoa(String.fromCharCode(...bytes));
}

private async decrypt(field: FieldInfo, data: string): Promise<string> {
if (this.isCustomEncryption(this.options.encryption!)) {
return this.options.encryption.decrypt(this.model, field, data);
}

const key = await this.getKey(this.options.encryption!.encryptionKey);

// Convert base64 back to bytes
const bytes = Uint8Array.from(atob(data), (c) => c.charCodeAt(0));

// First 12 bytes are IV, rest is encrypted data
const decrypted = await crypto.subtle.decrypt(
{
name: 'AES-GCM',
iv: bytes.slice(0, 12),
},
key,
bytes.slice(12)
);

return this.decoder.decode(decrypted);
}

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (args && args.data && actionsOfInterest.includes(action)) {
await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
}
return args;
}

// base override
protected async processResultEntity<T>(method: PrismaProxyActions, data: T): Promise<T> {
if (!data || typeof data !== 'object') {
return data;
}

for (const value of enumerate(data)) {
await this.doPostProcess(value, this.model);
}

return data;
}

private async doPostProcess(entityData: any, model: string) {
const realModel = this.queryUtils.getDelegateConcreteModel(model, entityData);

for (const field of getModelFields(entityData)) {
const fieldInfo = await resolveField(this.options.modelMeta, realModel, field);

if (!fieldInfo) {
continue;
}

const shouldDecrypt = fieldInfo.attributes?.find((attr) => attr.name === '@encrypted');
if (shouldDecrypt) {
// Don't decrypt null, undefined or empty string values
if (!entityData[field]) continue;

try {
entityData[field] = await this.decrypt(fieldInfo, entityData[field]);
} catch (error) {
console.warn('Decryption failed, keeping original value:', error);
}
}
}
}

private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
field: async (field, _action, data, context) => {
// Don't encrypt null, undefined or empty string values
if (!data) return;

const encAttr = field.attributes?.find((attr) => attr.name === '@encrypted');
if (encAttr && field.type === 'String') {
try {
context.parent[field.name] = await this.encrypt(field, data);
} catch (error) {
throw new Error(`Encryption failed for field ${field.name}: ${error}`);
}
}
},
});

await visitor.visit(model, action, args);
}
}
15 changes: 14 additions & 1 deletion packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import type { z } from 'zod';
import { FieldInfo } from './cross';

export type PrismaPromise<T> = Promise<T> & Record<string, (args?: any) => PrismaPromise<any>>;

Expand Down Expand Up @@ -133,6 +134,11 @@ export type EnhancementOptions = {
* The `isolationLevel` option passed to `prisma.$transaction()` call for transactions initiated by ZenStack.
*/
transactionIsolationLevel?: TransactionIsolationLevel;

/**
* The encryption options for using the `encrypted` enhancement.
*/
encryption?: SimpleEncryption | CustomEncryption;
};

/**
Expand All @@ -145,7 +151,7 @@ export type EnhancementContext<User extends AuthUser = AuthUser> = {
/**
* Kinds of enhancements to `PrismaClient`
*/
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate';
export type EnhancementKind = 'password' | 'omit' | 'policy' | 'validation' | 'delegate' | 'encrypted';

/**
* Function for transforming errors.
Expand All @@ -166,3 +172,10 @@ export type ZodSchemas = {
*/
input?: Record<string, Record<string, z.ZodSchema>>;
};

export type CustomEncryption = {
encrypt: (model: string, field: FieldInfo, plain: string) => Promise<string>;
decrypt: (model: string, field: FieldInfo, cipher: string) => Promise<string>;
};

export type SimpleEncryption = { encryptionKey: Uint8Array };
8 changes: 8 additions & 0 deletions packages/schema/src/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,14 @@ attribute @@auth() @@@supportTypeDef
*/
attribute @password(saltLength: Int?, salt: String?) @@@targetField([StringField])


/**
* Indicates that the field is encrypted when storing in the DB and should be decrypted when read
*
* ZenStack uses the Web Crypto API to encrypt and decrypt the field.
*/
attribute @encrypted() @@@targetField([StringField])

/**
* Indicates that the field should be omitted when read from the generated services.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import { FieldInfo } from '@zenstackhq/runtime';
import { loadSchema } from '@zenstackhq/testtools';
import path from 'path';

describe('Encrypted test', () => {
let origDir: string;

beforeAll(async () => {
origDir = path.resolve('.');
});

afterEach(async () => {
process.chdir(origDir);
});

it('Simple encryption test', async () => {
const { enhance } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
@@allow('all', true)
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const encryptionKey = new Uint8Array(Buffer.from('AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=', 'base64'));

const db = enhance(undefined, {
kinds: ['encrypted'],
encryption: { encryptionKey },
});

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});

const read = await db.user.findUnique({
where: {
id: '1',
},
});

const sudoRead = await sudoDb.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).not.toBe('abc123');
});

it('Custom encryption test', async () => {
const { enhance } = await loadSchema(`
model User {
id String @id @default(cuid())
encrypted_value String @encrypted()
@@allow('all', true)
}`);

const sudoDb = enhance(undefined, { kinds: [] });
const db = enhance(undefined, {
kinds: ['encrypted'],
encryption: {
encrypt: async (model: string, field: FieldInfo, data: string) => {
// Add _enc to the end of the input
return data + '_enc';
},
decrypt: async (model: string, field: FieldInfo, cipher: string) => {
// Remove _enc from the end of the input explicitly
if (cipher.endsWith('_enc')) {
return cipher.slice(0, -4); // Remove last 4 characters (_enc)
}

return cipher;
},
},
});

const create = await db.user.create({
data: {
id: '1',
encrypted_value: 'abc123',
},
});

const read = await db.user.findUnique({
where: {
id: '1',
},
});

const sudoRead = await sudoDb.user.findUnique({
where: {
id: '1',
},
});

expect(create.encrypted_value).toBe('abc123');
expect(read.encrypted_value).toBe('abc123');
expect(sudoRead.encrypted_value).toBe('abc123_enc');
});
});

0 comments on commit 1b7448f

Please sign in to comment.