Skip to content

Commit

Permalink
fix: slow types
Browse files Browse the repository at this point in the history
  • Loading branch information
load1n9 committed Sep 30, 2024
1 parent 8e31438 commit 531fc70
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 30 deletions.
32 changes: 20 additions & 12 deletions packages/core/src/core/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import type { Tensor } from "./tensor/tensor.ts";
import type { NeuralNetwork } from "./api/network.ts";
import { SGDOptimizer } from "./api/optimizer.ts";
import { PostProcess, type PostProcessor } from "./api/postprocess.ts";
import type { DenseLayerConfig } from "./api/layer.ts";

/**
* Sequential Neural Network
Expand Down Expand Up @@ -47,21 +48,22 @@ export class Sequential implements NeuralNetwork {
*/
async predict(
data: Tensor<Rank>,
config?: { postProcess?: PostProcessor; layers?: [number, number] }
config?: { postProcess?: PostProcessor; layers?: [number, number] },
): Promise<Tensor<Rank>> {
if (!config)
if (!config) {
config = {
postProcess: PostProcess("none"),
};
}
if (config.layers) {
if (
config.layers[0] < 0 ||
config.layers[1] > this.config.layers.length
) {
throw new RangeError(
`Execution range should be within (0, ${
this.config.layers.length
}). Received (${(config.layers[0], config.layers[1])})`
`Execution range should be within (0, ${this.config.layers.length}). Received (${(config
.layers[0],
config.layers[1])})`,
);
}
const lastLayer = this.config.layers[config.layers[1] - 1];
Expand All @@ -77,9 +79,12 @@ export class Sequential implements NeuralNetwork {
data,
{
postProcess: config.postProcess || PostProcess("none"),
outputShape: lastLayer.config.size,
outputShape: (lastLayer as {
type: LayerType.Dense;
config: DenseLayerConfig;
}).config.size,
},
layerList
layerList,
);
} else if (lastLayer.type === LayerType.Activation) {
const penultimate = this.config.layers[config.layers[1] - 2];
Expand All @@ -91,26 +96,29 @@ export class Sequential implements NeuralNetwork {
data,
{
postProcess: config.postProcess || PostProcess("none"),
outputShape: penultimate.config.size,
outputShape: (penultimate as {
type: LayerType.Dense;
config: DenseLayerConfig;
}).config.size,
},
layerList
layerList,
);
} else {
throw new Error(
`The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.`
`The penultimate layer must be a dense layer, or a flatten layer if the last layer is an activation layer. Received ${penultimate.type}.`,
);
}
} else {
throw new Error(
`The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.`
`The output layer must be a dense layer, activation layer, or a flatten layer. Received ${lastLayer.type}.`,
);
}
}
return await this.backend.predict(
data,
config.postProcess
? (config as { postProcess: PostProcessor; layers?: [number, number] })
: { ...config, postProcess: PostProcess("none") }
: { ...config, postProcess: PostProcess("none") },
);
}

Expand Down
2 changes: 1 addition & 1 deletion packages/utilities/src/utils/misc/argmax.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export function argmax(mat: ArrayLike<number | bigint>) {
export function argmax(mat: ArrayLike<number | bigint>): number {
let max = mat[0];
let index = 0;
for (let i = 0; i < mat.length; i++) {
Expand Down
31 changes: 14 additions & 17 deletions packages/utilities/src/utils/misc/matrix.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ export type MatrixLike<DT extends DataType> = {
* This is a collection of row vectors.
* A special case of Tensor for 2D data.
*/
export class Matrix<DT extends DataType>
extends Tensor<DT, 2>
implements Sliceable, MatrixLike<DT>
{
export class Matrix<DT extends DataType> extends Tensor<DT, 2>
implements Sliceable, MatrixLike<DT> {
/**
* Create a matrix from a typed array
* @param data Data to move into the matrix.
Expand All @@ -42,15 +40,15 @@ export class Matrix<DT extends DataType>
constructor(dType: DT, shape: Shape<2>);
constructor(
data: NDArray<DT>[2] | DType<DT> | DT | TensorLike<DT, 2>,
shape?: Shape<2> | DT
shape?: Shape<2> | DT,
) {
// @ts-ignore This call will work
super(data, shape);
}
get head() {
get head(): Matrix<DT> {
return this.slice(0, Math.min(this.nRows, 10));
}
get tail() {
get tail(): Matrix<DT> {
return this.slice(Math.max(this.nRows - 10, 0), this.nRows);
}
/** Convert the Matrix into a HTML table */
Expand Down Expand Up @@ -87,7 +85,7 @@ export class Matrix<DT extends DataType>
/** Get the transpose of the matrix. This method clones the matrix. */
get T(): Matrix<DT> {
const resArr = new (this.data.constructor as DTypeConstructor<DT>)(
this.nRows * this.nCols
this.nRows * this.nCols,
) as DType<DT>;
let i = 0;
for (const col of this.cols()) {
Expand All @@ -114,7 +112,7 @@ export class Matrix<DT extends DataType>
col(n: number): DType<DT> {
let i = 0;
const col = new (this.data.constructor as DTypeConstructor<DT>)(
this.nRows
this.nRows,
) as DType<DT>;
let offset = 0;
while (i < this.nRows) {
Expand All @@ -139,7 +137,7 @@ export class Matrix<DT extends DataType>
/** Get a column array of all column sums in the matrix */
colSum(): DType<DT> {
const sum = new (this.data.constructor as DTypeConstructor<DT>)(
this.nRows
this.nRows,
) as DType<DT>;
let i = 0;
while (i < this.nCols) {
Expand Down Expand Up @@ -169,8 +167,7 @@ export class Matrix<DT extends DataType>
while (j < this.nCols) {
let i = 0;
while (i < this.nRows) {
const adder =
(this.item(i, j) as DTypeValue<DT>) *
const adder = (this.item(i, j) as DTypeValue<DT>) *
(rhs.item(i, j) as DTypeValue<DT>);
// @ts-ignore I'll fix this later
res += adder as DTypeValue<DT>;
Expand All @@ -182,7 +179,7 @@ export class Matrix<DT extends DataType>
}
/** Filter the matrix by rows */
override filter(
fn: (value: DType<DT>, row: number, _: DType<DT>[]) => boolean
fn: (value: DType<DT>, row: number, _: DType<DT>[]) => boolean,
): Matrix<DT> {
const satisfying: number[] = [];
let i = 0;
Expand Down Expand Up @@ -224,7 +221,7 @@ export class Matrix<DT extends DataType>
/** Compute the sum of all rows */
rowSum(): DType<DT> {
const sum = new (this.data.constructor as DTypeConstructor<DT>)(
this.nCols
this.nCols,
) as DType<DT>;
let i = 0;
let offset = 0;
Expand Down Expand Up @@ -271,9 +268,9 @@ export class Matrix<DT extends DataType>
return new Matrix<DT>(
this.data.slice(
start ? start * this.nCols : 0,
end ? end * this.nCols : undefined
end ? end * this.nCols : undefined,
) as DType<DT>,
[end ? end - start : this.nRows - start, this.nCols]
[end ? end - start : this.nRows - start, this.nCols],
);
}
/** Iterate through rows */
Expand All @@ -290,7 +287,7 @@ export class Matrix<DT extends DataType>
while (i < this.nCols) {
let j = 0;
const col = new (this.data.constructor as DTypeConstructor<DT>)(
this.nRows
this.nRows,
) as DType<DT>;
while (j < this.nRows) {
col[j] = this.data[j * this.nCols + i];
Expand Down

0 comments on commit 531fc70

Please sign in to comment.