From 531fc70ea528213b28b814e69650ad7ad2fb529d Mon Sep 17 00:00:00 2001 From: Dean Srebnik <49134864+load1n9@users.noreply.github.com> Date: Sun, 29 Sep 2024 22:41:43 -0400 Subject: [PATCH] fix: slow types --- packages/core/src/core/mod.ts | 32 +++++++++++++-------- packages/utilities/src/utils/misc/argmax.ts | 2 +- packages/utilities/src/utils/misc/matrix.ts | 31 +++++++++----------- 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/packages/core/src/core/mod.ts b/packages/core/src/core/mod.ts index b34a55e..924e06e 100644 --- a/packages/core/src/core/mod.ts +++ b/packages/core/src/core/mod.ts @@ -12,6 +12,7 @@ import type { Tensor } from "./tensor/tensor.ts"; import type { NeuralNetwork } from "./api/network.ts"; import { SGDOptimizer } from "./api/optimizer.ts"; import { PostProcess, type PostProcessor } from "./api/postprocess.ts"; +import type { DenseLayerConfig } from "./api/layer.ts"; /** * Sequential Neural Network @@ -47,21 +48,22 @@ export class Sequential implements NeuralNetwork { */ async predict( data: Tensor, - config?: { postProcess?: PostProcessor; layers?: [number, number] } + config?: { postProcess?: PostProcessor; layers?: [number, number] }, ): Promise> { - if (!config) + if (!config) { config = { postProcess: PostProcess("none"), }; + } if (config.layers) { if ( config.layers[0] < 0 || config.layers[1] > this.config.layers.length ) { throw new RangeError( - `Execution range should be within (0, ${ - this.config.layers.length - }). Received (${(config.layers[0], config.layers[1])})` + `Execution range should be within (0, ${this.config.layers.length}). Received (${(config + .layers[0], + config.layers[1])})`, ); } const lastLayer = this.config.layers[config.layers[1] - 1]; @@ -77,9 +79,12 @@ export class Sequential implements NeuralNetwork { data, { postProcess: config.postProcess || PostProcess("none"), - outputShape: lastLayer.config.size, + outputShape: (lastLayer as { + type: LayerType.Dense; + config: DenseLayerConfig; + }).config.size, }, - layerList + layerList, ); } else if (lastLayer.type === LayerType.Activation) { const penultimate = this.config.layers[config.layers[1] - 2]; @@ -91,18 +96,21 @@ export class Sequential implements NeuralNetwork { data, { postProcess: config.postProcess || PostProcess("none"), - outputShape: penultimate.config.size, + outputShape: (penultimate as { + type: LayerType.Dense; + config: DenseLayerConfig; + }).config.size, }, - layerList + layerList, ); } else { throw new Error( - `The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.` + `The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.`, ); } } else { throw new Error( - `The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.` + `The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.`, ); } } @@ -110,7 +118,7 @@ export class Sequential implements NeuralNetwork { data, config.postProcess ? (config as { postProcess: PostProcessor; layers?: [number, number] }) - : { ...config, postProcess: PostProcess("none") } + : { ...config, postProcess: PostProcess("none") }, ); } diff --git a/packages/utilities/src/utils/misc/argmax.ts b/packages/utilities/src/utils/misc/argmax.ts index be8bfb1..0fa2602 100644 --- a/packages/utilities/src/utils/misc/argmax.ts +++ b/packages/utilities/src/utils/misc/argmax.ts @@ -1,4 +1,4 @@ -export function argmax(mat: ArrayLike) { +export function argmax(mat: ArrayLike): number { let max = mat[0]; let index = 0; for (let i = 0; i < mat.length; i++) { diff --git a/packages/utilities/src/utils/misc/matrix.ts b/packages/utilities/src/utils/misc/matrix.ts index 2b8dc7e..66ec500 100644 --- a/packages/utilities/src/utils/misc/matrix.ts +++ b/packages/utilities/src/utils/misc/matrix.ts @@ -27,10 +27,8 @@ export type MatrixLike
= { * This is a collection of row vectors. * A special case of Tensor for 2D data. */ -export class Matrix
- extends Tensor - implements Sliceable, MatrixLike
-{ +export class Matrix
extends Tensor + implements Sliceable, MatrixLike
{ /** * Create a matrix from a typed array * @param data Data to move into the matrix. @@ -42,15 +40,15 @@ export class Matrix
constructor(dType: DT, shape: Shape<2>); constructor( data: NDArray
[2] | DType
| DT | TensorLike, - shape?: Shape<2> | DT + shape?: Shape<2> | DT, ) { // @ts-ignore This call will work super(data, shape); } - get head() { + get head(): Matrix
{ return this.slice(0, Math.min(this.nRows, 10)); } - get tail() { + get tail(): Matrix
{ return this.slice(Math.max(this.nRows - 10, 0), this.nRows); } /** Convert the Matrix into a HTML table */ @@ -87,7 +85,7 @@ export class Matrix
/** Get the transpose of the matrix. This method clones the matrix. */ get T(): Matrix
{ const resArr = new (this.data.constructor as DTypeConstructor
)( - this.nRows * this.nCols + this.nRows * this.nCols, ) as DType
; let i = 0; for (const col of this.cols()) { @@ -114,7 +112,7 @@ export class Matrix
col(n: number): DType
{ let i = 0; const col = new (this.data.constructor as DTypeConstructor
)( - this.nRows + this.nRows, ) as DType
; let offset = 0; while (i < this.nRows) { @@ -139,7 +137,7 @@ export class Matrix
/** Get a column array of all column sums in the matrix */ colSum(): DType
{ const sum = new (this.data.constructor as DTypeConstructor
)( - this.nRows + this.nRows, ) as DType
; let i = 0; while (i < this.nCols) { @@ -169,8 +167,7 @@ export class Matrix
while (j < this.nCols) { let i = 0; while (i < this.nRows) { - const adder = - (this.item(i, j) as DTypeValue
) * + const adder = (this.item(i, j) as DTypeValue
) * (rhs.item(i, j) as DTypeValue
); // @ts-ignore I'll fix this later res += adder as DTypeValue
; @@ -182,7 +179,7 @@ export class Matrix
} /** Filter the matrix by rows */ override filter( - fn: (value: DType
, row: number, _: DType
[]) => boolean + fn: (value: DType
, row: number, _: DType
[]) => boolean, ): Matrix
{ const satisfying: number[] = []; let i = 0; @@ -224,7 +221,7 @@ export class Matrix
/** Compute the sum of all rows */ rowSum(): DType
{ const sum = new (this.data.constructor as DTypeConstructor
)( - this.nCols + this.nCols, ) as DType
; let i = 0; let offset = 0; @@ -271,9 +268,9 @@ export class Matrix
return new Matrix
( this.data.slice( start ? start * this.nCols : 0, - end ? end * this.nCols : undefined + end ? end * this.nCols : undefined, ) as DType
, - [end ? end - start : this.nRows - start, this.nCols] + [end ? end - start : this.nRows - start, this.nCols], ); } /** Iterate through rows */ @@ -290,7 +287,7 @@ export class Matrix
while (i < this.nCols) { let j = 0; const col = new (this.data.constructor as DTypeConstructor
)( - this.nRows + this.nRows, ) as DType
; while (j < this.nRows) { col[j] = this.data[j * this.nCols + i];