diff --git a/src/Gliner.ts b/src/Gliner.ts index 2eaeccb..5a7c4ae 100644 --- a/src/Gliner.ts +++ b/src/Gliner.ts @@ -44,7 +44,11 @@ export class Gliner { env.allowLocalModels = config.transformersSettings?.allowLocalModels ?? false; env.useBrowserCache = config.transformersSettings?.useBrowserCache ?? false; - this.config = { ...config, maxWidth: config.maxWidth || 12, modelType: config.modelType || "span-level" }; + this.config = { + ...config, + maxWidth: config.maxWidth || 12, + modelType: config.modelType || "span-level", + }; } async initialize(): Promise { @@ -62,8 +66,7 @@ export class Gliner { const decoder = new SpanDecoder({ max_width: maxWidth }); this.model = new SpanModel({ max_width: maxWidth }, processor, decoder, onnxWrapper); - } - else { + } else { console.log("Initializing Token-level Model..."); const processor = new TokenProcessor({ max_width: maxWidth }, tokenizer, wordSplitter); @@ -79,7 +82,7 @@ export class Gliner { entities, flatNer = true, threshold = 0.5, - multiLabel = false + multiLabel = false, }: IInference): Promise { if (!this.model) { throw new Error("Model is not initialized. Call initialize() first."); diff --git a/src/decoder.ts b/src/decoder.ts index 6d826ee..828053d 100644 --- a/src/decoder.ts +++ b/src/decoder.ts @@ -57,7 +57,7 @@ abstract class BaseDecoder { const newList: Spans = []; // Sort spans by their score (last element) in descending order - const spanProb: Spans = spans.slice().sort((a, b) => b[b.length-1] - a[a.length - 1]); + const spanProb: Spans = spans.slice().sort((a, b) => b[b.length - 1] - a[a.length - 1]); for (let i = 0; i < spanProb.length; i++) { const b = spanProb[i]; @@ -77,7 +77,7 @@ abstract class BaseDecoder { // Sort newList by start position (second element) for correct ordering return newList.sort((a, b) => a[1] - b[1]); } -}; +} // SpanDecoder subclass export class SpanDecoder extends BaseDecoder { decode( @@ -95,7 +95,7 @@ export class SpanDecoder extends BaseDecoder { threshold: number = 0.5, multiLabel: boolean = false, ): RawInferenceResult { - const spans: RawInferenceResult= []; + const spans: RawInferenceResult = []; for (let batch = 0; batch < batchSize; batch++) { spans.push([]); @@ -152,7 +152,6 @@ export class TokenDecoder extends BaseDecoder { threshold: number = 0.5, multiLabel: boolean = false, ): RawInferenceResult { - const positionPadding = batchSize * inputLength * numEntities; const batchPadding = inputLength * numEntities; const tokenPadding = numEntities; @@ -203,10 +202,10 @@ export class TokenDecoder extends BaseDecoder { // Calculate the inside span scores const insideSpanScores = insideScore[batch] .slice(start, end + 1) - .map(tokenScores => tokenScores[clsSt]); + .map((tokenScores) => tokenScores[clsSt]); // Check if all scores within the span are above the threshold - if (insideSpanScores.some(score => score < threshold)) continue; + if (insideSpanScores.some((score) => score < threshold)) continue; // Calculate mean span score const spanScore = insideSpanScores.reduce((a, b) => a + b, 0) / insideSpanScores.length; @@ -233,4 +232,4 @@ export class TokenDecoder extends BaseDecoder { return allSelectedSpans; } -} \ No newline at end of file +} diff --git a/src/model.ts b/src/model.ts index 1b33ef5..72abf2a 100644 --- a/src/model.ts +++ b/src/model.ts @@ -13,7 +13,7 @@ export class Model { async initialize(): Promise { await this.onnxWrapper.init(); } -}; +} export class SpanModel extends Model { prepareInputs(batch: any): Record { @@ -91,18 +91,22 @@ export class SpanModel extends Model { batch_size: number = 4, max_words: number = 512, ): Promise { - const { idToClass }: { - idToClass: Record } = this.processor.createMappings(entities); + const { + idToClass, + }: { + idToClass: Record; + } = this.processor.createMappings(entities); let batchIds: number[] = []; let batchTokens: string[][] = []; let batchWordsStartIdx: number[][] = []; let batchWordsEndIdx: number[][] = []; texts.forEach((text, id) => { - let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] = this.processor.tokenizeText(text); + let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] = + this.processor.tokenizeText(text); let num_sub_batches: number = Math.ceil(tokens.length / max_words); for (let i = 0; i < num_sub_batches; i++) { - let start:number = i * max_words; + let start: number = i * max_words; let end: number = Math.min((i + 1) * max_words, tokens.length); batchIds.push(id); @@ -128,24 +132,18 @@ export class SpanModel extends Model { let currBatchWordsEndIdx: number[][] = batchWordsEndIdx.slice(start, end); let currBatchIds: number[] = batchIds.slice(start, end); - let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.processor.prepareTextInputs( - currBatchTokens, - entities, - ); + let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = + this.processor.prepareTextInputs(currBatchTokens, entities); - let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.processor.encodeInputs( - inputTokens, - promptLengths, - ); + let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = + this.processor.encodeInputs(inputTokens, promptLengths); inputsIds = this.processor.padArray(inputsIds); attentionMasks = this.processor.padArray(attentionMasks); wordsMasks = this.processor.padArray(wordsMasks); - let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } = this.processor.prepareSpans( - currBatchTokens, - this.config["max_width"], - ); + let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } = + this.processor.prepareSpans(currBatchTokens, this.config["max_width"]); spanIdxs = this.processor.padArray(spanIdxs, 3); spanMasks = this.processor.padArray(spanMasks); @@ -164,7 +162,7 @@ export class SpanModel extends Model { }; let feeds: Record = this.prepareInputs(batch); - const results: Record = await this.onnxWrapper.run(feeds); + const results: Record = await this.onnxWrapper.run(feeds); const modelOutput: number[] = results["logits"].data; // const modelOutput = results.logits.data as number[]; @@ -267,18 +265,22 @@ export class TokenModel extends Model { batch_size: number = 4, max_words: number = 512, ): Promise { - const { idToClass }: { - idToClass: Record } = this.processor.createMappings(entities); + const { + idToClass, + }: { + idToClass: Record; + } = this.processor.createMappings(entities); let batchIds: number[] = []; let batchTokens: string[][] = []; let batchWordsStartIdx: number[][] = []; let batchWordsEndIdx: number[][] = []; texts.forEach((text, id) => { - let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] = this.processor.tokenizeText(text); + let [tokens, wordsStartIdx, wordsEndIdx]: [string[], number[], number[]] = + this.processor.tokenizeText(text); let num_sub_batches: number = Math.ceil(tokens.length / max_words); for (let i = 0; i < num_sub_batches; i++) { - let start:number = i * max_words; + let start: number = i * max_words; let end: number = Math.min((i + 1) * max_words, tokens.length); batchIds.push(id); @@ -304,15 +306,11 @@ export class TokenModel extends Model { let currBatchWordsEndIdx: number[][] = batchWordsEndIdx.slice(start, end); let currBatchIds: number[] = batchIds.slice(start, end); - let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.processor.prepareTextInputs( - currBatchTokens, - entities, - ); + let [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = + this.processor.prepareTextInputs(currBatchTokens, entities); - let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.processor.encodeInputs( - inputTokens, - promptLengths, - ); + let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = + this.processor.encodeInputs(inputTokens, promptLengths); inputsIds = this.processor.padArray(inputsIds); attentionMasks = this.processor.padArray(attentionMasks); @@ -330,7 +328,7 @@ export class TokenModel extends Model { }; let feeds: Record = this.prepareInputs(batch); - const results: Record = await this.onnxWrapper.run(feeds); + const results: Record = await this.onnxWrapper.run(feeds); const modelOutput: number[] = results["logits"].data; // const modelOutput = results.logits.data as number[]; @@ -361,4 +359,4 @@ export class TokenModel extends Model { return finalDecodedSpans; } -} \ No newline at end of file +} diff --git a/src/processor.ts b/src/processor.ts index b49f4b9..659233f 100644 --- a/src/processor.ts +++ b/src/processor.ts @@ -181,17 +181,24 @@ export class SpanProcessor extends Processor { } prepareBatch(texts: string[], entities: string[]): Record { - const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [string[][], number[][], number[][]] = this.batchTokenizeText(texts); - const { idToClass }: {idToClass: Record} = this.createMappings(entities); - const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.prepareTextInputs(batchTokens, entities); - - let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.encodeInputs(inputTokens, promptLengths); + const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [ + string[][], + number[][], + number[][], + ] = this.batchTokenizeText(texts); + const { idToClass }: { idToClass: Record } = this.createMappings(entities); + const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = + this.prepareTextInputs(batchTokens, entities); + + let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = + this.encodeInputs(inputTokens, promptLengths); inputsIds = this.padArray(inputsIds); attentionMasks = this.padArray(attentionMasks); wordsMasks = this.padArray(wordsMasks); - let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } = this.prepareSpans(batchTokens, this.config["max_width"]); + let { spanIdxs, spanMasks }: { spanIdxs: number[][][]; spanMasks: boolean[][] } = + this.prepareSpans(batchTokens, this.config["max_width"]); spanIdxs = this.padArray(spanIdxs, 3); spanMasks = this.padArray(spanMasks); @@ -217,11 +224,17 @@ export class TokenProcessor extends Processor { } prepareBatch(texts: string[], entities: string[]): Record { - const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [string[][], number[][], number[][]] = this.batchTokenizeText(texts); - const { idToClass }: {idToClass: Record} = this.createMappings(entities); - const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = this.prepareTextInputs(batchTokens, entities); - - let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = this.encodeInputs(inputTokens, promptLengths); + const [batchTokens, batchWordsStartIdx, batchWordsEndIdx]: [ + string[][], + number[][], + number[][], + ] = this.batchTokenizeText(texts); + const { idToClass }: { idToClass: Record } = this.createMappings(entities); + const [inputTokens, textLengths, promptLengths]: [string[][], number[], number[]] = + this.prepareTextInputs(batchTokens, entities); + + let [inputsIds, attentionMasks, wordsMasks]: [number[][], number[][], number[][]] = + this.encodeInputs(inputTokens, promptLengths); inputsIds = this.padArray(inputsIds); attentionMasks = this.padArray(attentionMasks);