Skip to content

Commit

Permalink
Fix type compatibility issues
Browse files Browse the repository at this point in the history
Fixes issues where the types were not compatible with the corresponding
types from the Prisma client. Also adds additional options for bytes
types. BufferObject matches the default JSON.stringify output for a
Buffer object.
  • Loading branch information
mogzol committed Nov 18, 2023
1 parent 45ecd46 commit 30c3085
Show file tree
Hide file tree
Showing 30 changed files with 507 additions and 240 deletions.
69 changes: 43 additions & 26 deletions generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ import { dirname } from "node:path";

interface Config {
enumType: "stringUnion" | "enum";
enumSuffix: string;
enumPrefix: string;
modelSuffix: string;
enumSuffix: string;
modelPrefix: string;
modelSuffix: string;
typePrefix: string;
typeSuffix: string;
dateType: "string" | "Date";
bigIntType: "string" | "bigint";
decimalType: "string" | "Decimal";
bytesType: "Buffer" | "string";
prettier: "true" | "false";
bytesType: "string" | "Buffer" | "BufferObject" | "number[]";
optionalRelations: boolean;
prettier: boolean;
}

// Map of Prisma scalar types to Typescript type getters
Expand All @@ -31,9 +34,10 @@ const SCALAR_TYPE_GETTERS: Record<string, (config: Config) => string> = {
// Since we want the output to have zero dependencies, define custom types which are compatible
// with the actual Prisma types. If users need the real Prisma types, they can cast to them.
const CUSTOM_TYPES: Record<string, string> = {
BufferObject: 'type BufferObject = { type: "Buffer"; data: number[] };',
Decimal: "type Decimal = { valueOf(): string };",
JsonValue:
"type JsonValue = string | number | boolean | { [key in string]: JsonValue } | Array<JsonValue> | null;",
Decimal: "interface Decimal {\n valueOf(): string;\n}",
"type JsonValue = string | number | boolean | { [key in string]?: JsonValue } | Array<JsonValue> | null;",
};

function validateConfig(config: Config) {
Expand All @@ -50,12 +54,9 @@ function validateConfig(config: Config) {
if (!["string", "Decimal"].includes(config.decimalType)) {
errors.push(`Invalid decimalType: ${config.decimalType}`);
}
if (!["Buffer", "string"].includes(config.bytesType)) {
if (!["string", "Buffer", "BufferObject", "number[]"].includes(config.bytesType)) {
errors.push(`Invalid bytesType: ${config.bytesType}`);
}
if (!["true", "false"].includes(config.prettier)) {
errors.push(`Invalid prettier: ${config.prettier}`);
}
if (errors.length > 0) {
throw new Error(errors.join("\n"));
}
Expand Down Expand Up @@ -87,12 +88,15 @@ function getModelTs(
modelData: DMMF.Model,
modelNameMap: Map<string, string>,
enumNameMap: Map<string, string>,
typeNameMap: Map<string, string>,
usedCustomTypes: Set<keyof typeof CUSTOM_TYPES>,
): string {
const fields = modelData.fields
.map(({ name, kind, type, isRequired, isList }) => {
const getDefinition = (resolvedType: string) =>
` ${name}: ${resolvedType}${isList ? "[]" : ""}${!isRequired ? " | null" : ""};`;
const getDefinition = (resolvedType: string, optional = false) =>
` ${name}${optional ? "?" : ""}: ${resolvedType}${isList ? "[]" : ""}${
!isRequired ? " | null" : ""
};`;

switch (kind) {
case "scalar": {
Expand All @@ -108,10 +112,14 @@ function getModelTs(
}
case "object": {
const modelName = modelNameMap.get(type);
if (!modelName) {
const typeName = typeNameMap.get(type);
if (!modelName && !typeName) {
throw new Error(`Unknown model name: ${type}`);
}
return getDefinition(modelName);
return getDefinition(
(modelName ?? typeName) as string,
config.optionalRelations && !typeName, // Type relations are never optional
);
}
case "enum": {
const enumName = enumNameMap.get(type);
Expand All @@ -128,7 +136,8 @@ function getModelTs(
})
.join("\n");

return `export interface ${modelNameMap.get(modelData.name)} {\n${fields}\n}`;
const name = modelNameMap.get(modelData.name) ?? typeNameMap.get(modelData.name);
return `export interface ${name} {\n${fields}\n}`;
}

generatorHandler({
Expand All @@ -139,27 +148,31 @@ generatorHandler({
};
},
async onGenerate(options) {
const baseConfig = options.generator.config;
const config: Config = {
enumType: "enum",
enumSuffix: "",
enumType: "stringUnion",
enumPrefix: "",
modelSuffix: "",
enumSuffix: "",
modelPrefix: "",
modelSuffix: "",
typePrefix: "",
typeSuffix: "",
dateType: "Date",
bigIntType: "bigint",
decimalType: "Decimal",
bytesType: "Buffer",
prettier: "false",
...options.generator.config,
...baseConfig,
// Booleans go here since in the base config they are strings
optionalRelations: baseConfig.optionalRelations !== "false", // Default true
prettier: baseConfig.prettier === "true", // Default false
};

validateConfig(config);

const datamodel = options.dmmf.datamodel;
const models = datamodel.models;
const enums = datamodel.enums;

// For the purposes of this generator, models and types are equivalent
const models = [...datamodel.models, ...datamodel.types];
const types = datamodel.types;

const usedCustomTypes = new Set<keyof typeof CUSTOM_TYPES>();

Expand All @@ -169,16 +182,20 @@ generatorHandler({
const modelNameMap = new Map<string, string>(
models.map((m) => [m.name, `${config.modelPrefix}${m.name}${config.modelSuffix}`]),
);
const typeNameMap = new Map<string, string>(
types.map((t) => [t.name, `${config.typePrefix}${t.name}${config.typeSuffix}`]),
);

const enumsTs = enums.map((e) => getEnumTs(config, e, enumNameMap));
const modelsTs = models.map((m) =>
getModelTs(config, m, modelNameMap, enumNameMap, usedCustomTypes),
// Types and Models are essentially the same thing, so we can run both through getModelTs
const modelsTs = [...models, ...types].map((m) =>
getModelTs(config, m, modelNameMap, enumNameMap, typeNameMap, usedCustomTypes),
);
const customTypesTs = Array.from(usedCustomTypes).map((t) => CUSTOM_TYPES[t]);

let ts = [...enumsTs, ...modelsTs, ...customTypesTs].join("\n\n") + "\n";

if (config.prettier === "true") {
if (config.prettier) {
// Prettier is imported inside this if so that it's not a required dependency
let prettier: typeof import("prettier");
try {
Expand Down
15 changes: 13 additions & 2 deletions prisma/example.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

datasource db {
provider = "postgresql"
url = ""
url = "postgresql://postgres:password@localhost:5432/testing?schema=public"
}

generator client {
Expand All @@ -16,6 +16,17 @@ generator typescriptInterfaces {
prettier = true
}

generator typescriptInterfacesJson {
provider = "node --loader ts-node/esm generator.ts"
output = "exampleJson.ts"
modelSuffix = "Json"
dateType = "string"
bigIntType = "string"
decimalType = "string"
bytesType = "BufferObject"
prettier = true
}

enum Gender {
Male
Female
Expand All @@ -30,8 +41,8 @@ model Person {
gender Gender
addressId Int
address Address @relation(fields: [addressId], references: [id])
friends Person[] @relation("Friends")
friendsOf Person[] @relation("Friends")
friends Person[] @relation("Friends")
data Data?
}

Expand Down
72 changes: 52 additions & 20 deletions test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
* "expected-error.txt" file in the relevant test folder.
*
* You can run specific tests by passing the one(s) you want to run as arguments to this script:
* npm run test custom-output no-options prettier
* npm run test -- custom-output no-options prettier
*
* If you want to run all tests even if some fail, pass the --continue or -c flag:
* npm run test -- -c
*/

import { exec } from "node:child_process";
Expand All @@ -15,7 +18,7 @@ import fs from "node:fs/promises";
import path from "node:path";

const TEMP_TEST_DIRNAME = "__TEST_TMP__";
const BASE_REPLACE_STRING = "// TEST_INSERT_BASE_HERE";
const BASE_REPLACE_REGEX = /^\/\/ ?#INSERT base\.([a-z]+)\.prisma$/gm;
const RED = "\x1b[1m\x1b[41m\x1b[97m";
const GREEN = "\x1b[1m\x1b[42m\x1b[97m";
const RESET = "\x1b[0m";
Expand All @@ -29,29 +32,50 @@ const trimMultiLine = (s: string) =>
.map((l) => l.trim())
.join("\n");

const testFilters = process.argv.slice(2);
let testFilters = process.argv.slice(2);

// Continue on errors if --continue or -c is passed
let continueOnError = false;
let hasErrors = false;
if (testFilters.some((f) => f === "--continue" || f === "-c")) {
continueOnError = true;
testFilters = testFilters.filter((f) => f !== "--continue" && f !== "-c");
}

const tests = (await fs.readdir("tests", { withFileTypes: true }))
.filter(
(dirent) => dirent.isDirectory() && (!testFilters.length || testFilters.includes(dirent.name)),
)
.map((t) => path.join(t.path, t.name));
const testsEntries = await fs.readdir("tests", { withFileTypes: true });
const tests = testsEntries
.filter((d) => d.isDirectory() && (!testFilters.length || testFilters.includes(d.name)))
.map((d) => path.join(d.path, d.name));

// Common schemas used by multiple tests
const baseSchemas = new Map(
await Promise.all(
testsEntries
.filter((f) => f.isFile() && /^base\.[a-z]+\.prisma$/.test(f.name))
.map<Promise<[string, string]>>((f) =>
readFile(path.join(f.path, f.name)).then((c) => [f.name, c]),
),
),
);

// Get the length of the longest test name, so we can pad the output
const longestName = Math.max(...tests.map((t) => t.length));

// Common schema text used by many of the tests
const baseSchema = await readFile(path.join("tests", "base.prisma"));

console.log("Running tests...");

try {
for (const test of tests) {
for (const test of tests) {
try {
process.stdout.write(` ${test}${" ".repeat(longestName - test.length + 2)}`);

const schema = (await readFile(path.join(test, "schema.prisma"))).replace(
BASE_REPLACE_STRING,
baseSchema,
const schema = (await readFile(path.join(test, "schema.prisma"))).replaceAll(
BASE_REPLACE_REGEX,
(_, baseName) => {
const baseSchema = baseSchemas.get(`base.${baseName}.prisma`);
if (!baseSchema) {
throw new Error(`Unknown base schema: ${baseName}`);
}
return baseSchema;
},
);

let expectedError: string | null; // Text of expected stderr after a non-zero exit code
Expand Down Expand Up @@ -119,11 +143,19 @@ try {
process.stdout.write(GREEN + " PASS " + RESET + "\n");

await rimraf(testDir);
} catch (e) {
process.stdout.write(RED + " FAIL " + RESET + "\n\n");
console.error((e as Error).message, "\n");
hasErrors = true;
if (!continueOnError) {
process.exit(1);
}
}
}

console.log("\n\nAll tests passed!");
} catch (e) {
process.stdout.write(RED + " FAIL " + RESET + "\n\n");
console.error((e as Error).message);
if (hasErrors) {
console.error("\nSome tests failed!");
process.exit(1);
} else {
console.log("\nAll tests passed!");
}
45 changes: 45 additions & 0 deletions tests/base.mongo.prisma
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// We need to use mongo to test composite types
datasource db {
provider = "mongodb"
url = ""
}

enum Gender {
Male
Female
Other
}

enum PhotoType {
Selfie
Profile
Tagged
}

model Person {
id Int @id @map("_id")
name String
gender Gender
addressId Int
address Address @relation(fields: [addressId], references: [id])
photos Photo[]
tags Tag?
}

model Address {
id Int @id @map("_id")
addresText String
people Person[]
}

type Photo {
height Int
Width Int
url String
type PhotoType
}

type Tag {
id Int
name String
}
10 changes: 10 additions & 0 deletions tests/base.prisma → tests/base.postgres.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ enum Gender {
Other
}

enum DataTest {
Apple
Banana
Orange
Pear
}

model Person {
id Int @id @default(autoincrement())
name String
Expand Down Expand Up @@ -43,6 +50,7 @@ model Data {
dateField DateTime
jsonField Json
bytesField Bytes
enumField DataTest
optionalStringField String?
optionalBooleanField Boolean?
Expand All @@ -53,6 +61,7 @@ model Data {
optionalDateField DateTime?
optionalJsonField Json?
optionalBytesField Bytes?
optionalEnumField DataTest?
stringArrayField String[]
booleanArrayField Boolean[]
Expand All @@ -63,6 +72,7 @@ model Data {
dateArrayField DateTime[]
jsonArrayField Json[]
bytesArrayField Bytes[]
enumArrayField DataTest[]
personId Int @unique
person Person @relation(fields: [personId], references: [id])
Expand Down
Loading

0 comments on commit 30c3085

Please sign in to comment.