Skip to content

Commit

Permalink
refactor(experimental): refactor getDiscriminatedUnionCodec by using …
Browse files Browse the repository at this point in the history
…new getUnionCodec primitive (#2399)

This PR refactors the implementation of `getDiscriminatedUnionCodec` (without changing its API or behaviour) so that it uses the new `getUnionCodec` helper under the hood.
  • Loading branch information
lorisleiva authored Apr 2, 2024
1 parent bef9604 commit e77a9b4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import {
import { getU8Codec, getU16Codec, getU32Codec, getU64Codec } from '@solana/codecs-numbers';
import { getStringCodec } from '@solana/codecs-strings';
import {
SOLANA_ERROR__CODECS__ENUM_DISCRIMINATOR_OUT_OF_RANGE,
SOLANA_ERROR__CODECS__INVALID_DISCRIMINATED_UNION_VARIANT,
SOLANA_ERROR__CODECS__UNION_VARIANT_OUT_OF_RANGE,
SolanaError,
} from '@solana/errors';

Expand Down Expand Up @@ -115,11 +115,7 @@ describe('getDiscriminatedUnionCodec', () => {
}),
);
expect(() => discriminatedUnion(getWebEvent()).read(new Uint8Array([4]), 0)).toThrow(
new SolanaError(SOLANA_ERROR__CODECS__ENUM_DISCRIMINATOR_OUT_OF_RANGE, {
discriminator: 4,
maxRange: 3,
minRange: 0,
}),
new SolanaError(SOLANA_ERROR__CODECS__UNION_VARIANT_OUT_OF_RANGE, { maxRange: 3, minRange: 0, variant: 4 }),
);
});

Expand Down
103 changes: 20 additions & 83 deletions packages/codecs-data-structures/src/discriminated-union.ts
Original file line number Diff line number Diff line change
@@ -1,24 +1,11 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import {
assertByteArrayIsNotEmptyForCodec,
Codec,
combineCodec,
createDecoder,
createEncoder,
Decoder,
Encoder,
getEncodedSize,
isFixedSize,
ReadonlyUint8Array,
} from '@solana/codecs-core';
import { Codec, combineCodec, Decoder, Encoder, mapDecoder, mapEncoder } from '@solana/codecs-core';
import { getU8Decoder, getU8Encoder, NumberCodec, NumberDecoder, NumberEncoder } from '@solana/codecs-numbers';
import {
SOLANA_ERROR__CODECS__ENUM_DISCRIMINATOR_OUT_OF_RANGE,
SOLANA_ERROR__CODECS__INVALID_DISCRIMINATED_UNION_VARIANT,
SolanaError,
} from '@solana/errors';
import { SOLANA_ERROR__CODECS__INVALID_DISCRIMINATED_UNION_VARIANT, SolanaError } from '@solana/errors';

import { DrainOuterGeneric, getMaxSize, maxCodecSizes, sumCodecSizes } from './utils';
import { getTupleDecoder, getTupleEncoder } from './tuple';
import { getUnionDecoder, getUnionEncoder } from './union';
import { DrainOuterGeneric } from './utils';

/**
* Defines a discriminated union using discriminated union types.
Expand Down Expand Up @@ -132,28 +119,12 @@ export function getDiscriminatedUnionEncoder<
type TFrom = GetEncoderTypeFromVariants<TVariants, TDiscriminatorProperty>;
const discriminatorProperty = (config.discriminator ?? '__kind') as TDiscriminatorProperty;
const prefix = config.size ?? getU8Encoder();
const fixedSize = getDiscriminatedUnionFixedSize(variants, prefix);
return createEncoder({
...(fixedSize !== null
? { fixedSize }
: {
getSizeFromValue: (variant: TFrom) => {
const discriminator = getVariantDiscriminator(variants, variant[discriminatorProperty]);
const variantEncoder = variants[discriminator][1];
return (
getEncodedSize(discriminator, prefix) +
getEncodedSize(variant as TFrom & void, variantEncoder)
);
},
maxSize: getDiscriminatedUnionMaxSize(variants, prefix),
}),
write: (variant: TFrom, bytes, offset) => {
const discriminator = getVariantDiscriminator(variants, variant[discriminatorProperty]);
offset = prefix.write(discriminator, bytes, offset);
const variantEncoder = variants[discriminator][1];
return variantEncoder.write(variant as TFrom & void, bytes, offset);
},
});
return getUnionEncoder(
variants.map(([, variant], index) =>
mapEncoder(getTupleEncoder([prefix, variant]), (value: TFrom): [number, TFrom] => [index, value]),
),
value => getVariantDiscriminator(variants, value[discriminatorProperty]),
);
}

/**
Expand All @@ -169,29 +140,17 @@ export function getDiscriminatedUnionDecoder<
variants: TVariants,
config: DiscriminatedUnionCodecConfig<TDiscriminatorProperty, NumberDecoder> = {},
): Decoder<GetDecoderTypeFromVariants<TVariants, TDiscriminatorProperty>> {
type TTo = GetDecoderTypeFromVariants<TVariants, TDiscriminatorProperty>;
const discriminatorProperty = config.discriminator ?? '__kind';
const prefix = config.size ?? getU8Decoder();
const fixedSize = getDiscriminatedUnionFixedSize(variants, prefix);
return createDecoder({
...(fixedSize !== null ? { fixedSize } : { maxSize: getDiscriminatedUnionMaxSize(variants, prefix) }),
read: (bytes: ReadonlyUint8Array | Uint8Array, offset) => {
assertByteArrayIsNotEmptyForCodec('discriminatedUnion', bytes, offset);
const [discriminator, dOffset] = prefix.read(bytes, offset);
offset = dOffset;
const variantField = variants[Number(discriminator)] ?? null;
if (!variantField) {
throw new SolanaError(SOLANA_ERROR__CODECS__ENUM_DISCRIMINATOR_OUT_OF_RANGE, {
discriminator,
maxRange: variants.length - 1,
minRange: 0,
});
}
const [variant, vOffset] = variantField[1].read(bytes, offset);
offset = vOffset;
return [{ [discriminatorProperty]: variantField[0], ...(variant ?? {}) } as TTo, offset];
},
});
return getUnionDecoder(
variants.map(([discriminator, variant]) =>
mapDecoder(getTupleDecoder([prefix, variant]), ([, value]) => ({
[discriminatorProperty]: discriminator,
...value,
})),
),
(bytes, offset) => Number(prefix.read(bytes, offset)[0]),
);
}

/**
Expand Down Expand Up @@ -220,28 +179,6 @@ export function getDiscriminatedUnionCodec<
);
}

function getDiscriminatedUnionFixedSize<const TVariants extends Variants<Decoder<any> | Encoder<any>>>(
variants: TVariants,
prefix: object | { fixedSize: number },
): number | null {
if (variants.length === 0) return isFixedSize(prefix) ? prefix.fixedSize : null;
if (!isFixedSize(variants[0][1])) return null;
const variantSize = variants[0][1].fixedSize;
const sameSizedVariants = variants.every(
variant => isFixedSize(variant[1]) && variant[1].fixedSize === variantSize,
);
if (!sameSizedVariants) return null;
return isFixedSize(prefix) ? prefix.fixedSize + variantSize : null;
}

function getDiscriminatedUnionMaxSize<const TVariants extends Variants<Decoder<any> | Encoder<any>>>(
variants: TVariants,
prefix: object | { fixedSize: number },
) {
const maxVariantSize = maxCodecSizes(variants.map(([, codec]) => getMaxSize(codec)));
return sumCodecSizes([getMaxSize(prefix), maxVariantSize]) ?? undefined;
}

function getVariantDiscriminator<const TVariants extends Variants<Decoder<any> | Encoder<any>>>(
variants: TVariants,
discriminatorValue: DiscriminatorValue,
Expand Down

0 comments on commit e77a9b4

Please sign in to comment.