Skip to content

Commit

Permalink
fix linting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
Ingvarstep committed Oct 8, 2024
1 parent 748347b commit 9001e24
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 54 deletions.
11 changes: 7 additions & 4 deletions src/Gliner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ export class Gliner {
env.allowLocalModels = config.transformersSettings?.allowLocalModels ?? false;
env.useBrowserCache = config.transformersSettings?.useBrowserCache ?? false;

this.config = { ...config, maxWidth: config.maxWidth || 12, modelType: config.modelType || "span-level" };
this.config = {
...config,
maxWidth: config.maxWidth || 12,
modelType: config.modelType || "span-level",
};
}

async initialize(): Promise<void> {
Expand All @@ -62,8 +66,7 @@ export class Gliner {
const decoder = new SpanDecoder({ max_width: maxWidth });

this.model = new SpanModel({ max_width: maxWidth }, processor, decoder, onnxWrapper);
}
else {
} else {
console.log("Initializing Token-level Model...");

const processor = new TokenProcessor({ max_width: maxWidth }, tokenizer, wordSplitter);
Expand All @@ -79,7 +82,7 @@ export class Gliner {
entities,
flatNer = true,
threshold = 0.5,
multiLabel = false
multiLabel = false,
}: IInference): Promise<InferenceResultMultiple> {
if (!this.model) {
throw new Error("Model is not initialized. Call initialize() first.");
Expand Down
13 changes: 6 additions & 7 deletions src/decoder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ abstract class BaseDecoder {

const newList: Spans = [];
// Sort spans by their score (last element) in descending order
const spanProb: Spans = spans.slice().sort((a, b) => b[b.length-1] - a[a.length - 1]);
const spanProb: Spans = spans.slice().sort((a, b) => b[b.length - 1] - a[a.length - 1]);

for (let i = 0; i < spanProb.length; i++) {
const b = spanProb[i];
Expand All @@ -77,7 +77,7 @@ abstract class BaseDecoder {
// Sort newList by start position (second element) for correct ordering
return newList.sort((a, b) => a[1] - b[1]);
}
};
}
// SpanDecoder subclass
export class SpanDecoder extends BaseDecoder {
decode(
Expand All @@ -95,7 +95,7 @@ export class SpanDecoder extends BaseDecoder {
threshold: number = 0.5,
multiLabel: boolean = false,
): RawInferenceResult {
const spans: RawInferenceResult= [];
const spans: RawInferenceResult = [];

for (let batch = 0; batch < batchSize; batch++) {
spans.push([]);
Expand Down Expand Up @@ -152,7 +152,6 @@ export class TokenDecoder extends BaseDecoder {
threshold: number = 0.5,
multiLabel: boolean = false,
): RawInferenceResult {

const positionPadding = batchSize * inputLength * numEntities;
const batchPadding = inputLength * numEntities;
const tokenPadding = numEntities;
Expand Down Expand Up @@ -203,10 +202,10 @@ export class TokenDecoder extends BaseDecoder {
// Calculate the inside span scores
const insideSpanScores = insideScore[batch]
.slice(start, end + 1)
.map(tokenScores => tokenScores[clsSt]);
.map((tokenScores) => tokenScores[clsSt]);

// Check if all scores within the span are above the threshold
if (insideSpanScores.some(score => score < threshold)) continue;
if (insideSpanScores.some((score) => score < threshold)) continue;

// Calculate mean span score
const spanScore = insideSpanScores.reduce((a, b) => a + b, 0) / insideSpanScores.length;
Expand All @@ -233,4 +232,4 @@ export class TokenDecoder extends BaseDecoder {

return allSelectedSpans;
}
}
}
62 changes: 30 additions & 32 deletions src/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ export class Model {
async initialize(): Promise<void> {
await this.onnxWrapper.init();
}
};
}

export class SpanModel extends Model {
prepareInputs(batch: any): Record<string, ort.Tensor> {
Expand Down Expand Up @@ -91,18 +91,22 @@ export class SpanModel extends Model {
batch_size: number = 4,
max_words: number = 512,
): Promise<RawInferenceResult> {
const { idToClass }: {
idToClass: Record<number, string> } = this.processor.createMappings(entities);
const {
idToClass,
}: {
idToClass: Record<number, string>;
} = this.processor.createMappings(entities);
let batchIds: number[] = [];
let batchTokens: string[][] = [];
let batchWordsStartIdx: number[][] = [];
let batchWordsEndIdx: number[][] = [];
texts.forEach((text, id) => {
let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] = this.processor.tokenizeText(text);
let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] =
this.processor.tokenizeText(text);
let num_sub_batches: number = Math.ceil(tokens.length / max_words);

for (let i = 0; i < num_sub_batches; i++) {
let start:number = i * max_words;
let start: number = i * max_words;
let end: number = Math.min((i + 1) * max_words, tokens.length);

batchIds.push(id);
Expand All @@ -128,24 +132,18 @@ export class SpanModel extends Model {
let currBatchWordsEndIdx: number[][] = batchWordsEndIdx.slice(start, end);
let currBatchIds: number[] = batchIds.slice(start, end);

let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.processor.prepareTextInputs(
currBatchTokens,
entities,
);
let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] =
this.processor.prepareTextInputs(currBatchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.processor.encodeInputs(
inputTokens,
promptLengths,
);
let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] =
this.processor.encodeInputs(inputTokens, promptLengths);

inputsIds = this.processor.padArray(inputsIds);
attentionMasks = this.processor.padArray(attentionMasks);
wordsMasks = this.processor.padArray(wordsMasks);

let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } = this.processor.prepareSpans(
currBatchTokens,
this.config["max_width"],
);
let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } =
this.processor.prepareSpans(currBatchTokens, this.config["max_width"]);

spanIdxs = this.processor.padArray(spanIdxs, 3);
spanMasks = this.processor.padArray(spanMasks);
Expand All @@ -164,7 +162,7 @@ export class SpanModel extends Model {
};

let feeds: Record<string, ort.Tensor> = this.prepareInputs(batch);
const results: Record<string, ort.Tensor> = await this.onnxWrapper.run(feeds);
const results: Record<string, ort.Tensor> = await this.onnxWrapper.run(feeds);
const modelOutput: number[] = results["logits"].data;
// const modelOutput = results.logits.data as number[];

Expand Down Expand Up @@ -267,18 +265,22 @@ export class TokenModel extends Model {
batch_size: number = 4,
max_words: number = 512,
): Promise<RawInferenceResult> {
const { idToClass }: {
idToClass: Record<number, string> } = this.processor.createMappings(entities);
const {
idToClass,
}: {
idToClass: Record<number, string>;
} = this.processor.createMappings(entities);
let batchIds: number[] = [];
let batchTokens: string[][] = [];
let batchWordsStartIdx: number[][] = [];
let batchWordsEndIdx: number[][] = [];
texts.forEach((text, id) => {
let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] = this.processor.tokenizeText(text);
let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] =
this.processor.tokenizeText(text);
let num_sub_batches: number = Math.ceil(tokens.length / max_words);

for (let i = 0; i < num_sub_batches; i++) {
let start:number = i * max_words;
let start: number = i * max_words;
let end: number = Math.min((i + 1) * max_words, tokens.length);

batchIds.push(id);
Expand All @@ -304,15 +306,11 @@ export class TokenModel extends Model {
let currBatchWordsEndIdx: number[][] = batchWordsEndIdx.slice(start, end);
let currBatchIds: number[] = batchIds.slice(start, end);

let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.processor.prepareTextInputs(
currBatchTokens,
entities,
);
let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] =
this.processor.prepareTextInputs(currBatchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.processor.encodeInputs(
inputTokens,
promptLengths,
);
let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] =
this.processor.encodeInputs(inputTokens, promptLengths);

inputsIds = this.processor.padArray(inputsIds);
attentionMasks = this.processor.padArray(attentionMasks);
Expand All @@ -330,7 +328,7 @@ export class TokenModel extends Model {
};

let feeds: Record<string, ort.Tensor> = this.prepareInputs(batch);
const results: Record<string, ort.Tensor> = await this.onnxWrapper.run(feeds);
const results: Record<string, ort.Tensor> = await this.onnxWrapper.run(feeds);
const modelOutput: number[] = results["logits"].data;
// const modelOutput = results.logits.data as number[];

Expand Down Expand Up @@ -361,4 +359,4 @@ export class TokenModel extends Model {

return finalDecodedSpans;
}
}
}
35 changes: 24 additions & 11 deletions src/processor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,24 @@ export class SpanProcessor extends Processor {
}

prepareBatch(texts: string[], entities: string[]): Record<string, any> {
const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [string[][], number[][], number[][]] = this.batchTokenizeText(texts);
const { idToClass }: {idToClass: Record<number, string>} = this.createMappings(entities);
const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.prepareTextInputs(batchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.encodeInputs(inputTokens, promptLengths);
const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [
string[][],
number[][],
number[][],
] = this.batchTokenizeText(texts);
const { idToClass }: { idToClass: Record<number, string> } = this.createMappings(entities);
const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] =
this.prepareTextInputs(batchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] =
this.encodeInputs(inputTokens, promptLengths);

inputsIds = this.padArray(inputsIds);
attentionMasks = this.padArray(attentionMasks);
wordsMasks = this.padArray(wordsMasks);

let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } = this.prepareSpans(batchTokens, this.config["max_width"]);
let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } =
this.prepareSpans(batchTokens, this.config["max_width"]);

spanIdxs = this.padArray(spanIdxs, 3);
spanMasks = this.padArray(spanMasks);
Expand All @@ -217,11 +224,17 @@ export class TokenProcessor extends Processor {
}

prepareBatch(texts: string[], entities: string[]): Record<string, any> {
const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [string[][], number[][], number[][]] = this.batchTokenizeText(texts);
const { idToClass }: {idToClass: Record<number, string>} = this.createMappings(entities);
const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.prepareTextInputs(batchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.encodeInputs(inputTokens, promptLengths);
const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [
string[][],
number[][],
number[][],
] = this.batchTokenizeText(texts);
const { idToClass }: { idToClass: Record<number, string> } = this.createMappings(entities);
const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] =
this.prepareTextInputs(batchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] =
this.encodeInputs(inputTokens, promptLengths);

inputsIds = this.padArray(inputsIds);
attentionMasks = this.padArray(attentionMasks);
Expand Down

0 comments on commit 9001e24

Please sign in to comment.