From d8edc7c8150d0ae87d7d6177e7f11881d4c4af3a Mon Sep 17 00:00:00 2001 From: Ihor Stepanov Date: Tue, 24 Sep 2024 20:29:43 +0300 Subject: [PATCH 1/4] add text to the inputs and optimize tensors preparation --- src/decoder.ts | 11 +++++++++-- src/model.ts | 39 +++++++++++---------------------------- src/processor.ts | 8 ++++---- 3 files changed, 24 insertions(+), 34 deletions(-) diff --git a/src/decoder.ts b/src/decoder.ts index a36d83e..5773b7b 100644 --- a/src/decoder.ts +++ b/src/decoder.ts @@ -75,6 +75,8 @@ export class SpanDecoder extends BaseDecoder { inputLength: number, maxWidth: number, numEntities: number, + texts: string[][], + batchIds: number[], batchWordsStartIdx: number[][], batchWordsEndIdx: number[][], idToClass: Record, @@ -106,9 +108,14 @@ export class SpanDecoder extends BaseDecoder { startToken < batchWordsStartIdx[batch].length && endToken < batchWordsEndIdx[batch].length ) { + let globalBatch = batchIds[batch]; + let startIdx = batchWordsStartIdx[batch][startToken]; + let endIdx = batchWordsEndIdx[batch][endToken]; + let spanText = texts[globalBatch].slice(startIdx, endIdx); spans[batch].push([ - batchWordsStartIdx[batch][startToken], - batchWordsEndIdx[batch][endToken], + spanText, + startIdx, + endIdx, idToClass[entity + 1], prob ]); diff --git a/src/model.ts b/src/model.ts index aa9b43a..33c287a 100644 --- a/src/model.ts +++ b/src/model.ts @@ -18,42 +18,20 @@ export class Model { const num_tokens = batch.inputsIds[0].length; const num_spans = batch.spanIdxs[0].length; - const convertToBool = (arr: any[]): any[] => { - return arr.map((subArr) => subArr.map((val: any) => !!val)); - }; - - const blankConvert = (arr: any[]): any[] => { - return arr.map((val) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val)))); - }; - - const convertToInt = (arr: any[]): any[] => { - return arr.map((subArr) => - subArr.map((val: any) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val)))) - ); - }; - - const convertSpanIdx = (arr: any[]): any[] => { - return arr.flatMap((subArr) => - subArr.flatMap((pair: any) => pair.map((val: any) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val))))) - ); - }; - const createTensor = ( data: any[], shape: number[], - conversionFunc: (arr: any[]) => any[] = convertToInt, tensorType: any = "int64" ): ort.Tensor => { - const convertedData = conversionFunc(data); - return new this.onnxWrapper.ort.Tensor(tensorType, convertedData.flat(Infinity), shape); + return new this.onnxWrapper.ort.Tensor(tensorType, data.flat(Infinity), shape); }; let input_ids = createTensor(batch.inputsIds, [batch_size, num_tokens]); - let attention_mask = createTensor(batch.attentionMasks, [batch_size, num_tokens], convertToBool); // NOTE: why convert to bool but type is not bool? + let attention_mask = createTensor(batch.attentionMasks, [batch_size, num_tokens]); // NOTE: why convert to bool but type is not bool? let words_mask = createTensor(batch.wordsMasks, [batch_size, num_tokens]); - let text_lengths = createTensor(batch.textLengths, [batch_size, 1], blankConvert); - let span_idx = createTensor(batch.spanIdxs, [batch_size, num_spans, 2], convertSpanIdx); - let span_mask = createTensor(batch.spanMasks, [batch_size, num_spans], convertToBool, "bool"); + let text_lengths = createTensor(batch.textLengths, [batch_size, 1]); + let span_idx = createTensor(batch.spanIdxs, [batch_size, num_spans, 2]); + let span_mask = createTensor(batch.spanMasks, [batch_size, num_spans], "bool"); const feeds = { input_ids: input_ids, @@ -84,12 +62,14 @@ export class Model { const inputLength = Math.max(...batch.textLengths); const maxWidth = this.config.max_width; const numEntities = entities.length; - + const batchIds = Array.from({ length: batchSize }, (_, i) => i); const decodedSpans = this.decoder.decode( batchSize, inputLength, maxWidth, numEntities, + texts, + batchIds, batch.batchWordsStartIdx, batch.batchWordsEndIdx, batch.idToClass, @@ -147,6 +127,7 @@ export class Model { let currBatchTokens = batchTokens.slice(start, end); let currBatchWordsStartIdx = batchWordsStartIdx.slice(start, end); let currBatchWordsEndIdx = batchWordsEndIdx.slice(start, end); + let currBatchIds = batchIds.slice(start, end); let [inputTokens, textLengths, promptLengths] = this.processor.prepareTextInputs(currBatchTokens, entities); @@ -190,6 +171,8 @@ export class Model { inputLength, maxWidth, numEntities, + texts, + currBatchIds, batch.batchWordsStartIdx, batch.batchWordsEndIdx, batch.idToClass, diff --git a/src/processor.ts b/src/processor.ts index 0a54092..b5abdd8 100644 --- a/src/processor.ts +++ b/src/processor.ts @@ -149,20 +149,20 @@ export class SpanProcessor extends Processor { super(config, tokenizer, wordsSplitter); } - prepareSpans(batchTokens: string[][], maxWidth: number = 12): { spanIdxs: number[][][]; spanMasks: number[][] } { + prepareSpans(batchTokens: string[][], maxWidth: number = 12): { spanIdxs: number[][][]; spanMasks: boolean[][] } { let spanIdxs: number[][][] = []; - let spanMasks: number[][] = []; + let spanMasks: boolean[][] = []; batchTokens.forEach((tokens) => { let textLength = tokens.length; let spanIdx: number[][] = []; - let spanMask: number[] = []; + let spanMask: boolean[] = []; for (let i = 0; i < textLength; i++) { for (let j = 0; j < maxWidth; j++) { let endIdx = Math.min(i + j, textLength - 1); spanIdx.push([i, endIdx]); - spanMask.push(endIdx < textLength ? 1 : 0); + spanMask.push(endIdx < textLength ? true : false); } } From 33542b4e013d2b5e712f6fd6bb57d46a0476a545 Mon Sep 17 00:00:00 2001 From: Ihor Stepanov Date: Wed, 25 Sep 2024 09:36:08 +0300 Subject: [PATCH 2/4] fix chunking inference --- src/model.ts | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/model.ts b/src/model.ts index 33c287a..8f03052 100644 --- a/src/model.ts +++ b/src/model.ts @@ -88,7 +88,7 @@ export class Model { flatNer: boolean = false, threshold: number = 0.5, multiLabel: boolean = false, - batch_size: number = 8, + batch_size: number = 4, max_words: number = 512, ): Promise { @@ -119,7 +119,7 @@ export class Model { for (let id = 0; id Date: Wed, 25 Sep 2024 09:37:41 +0200 Subject: [PATCH 3/4] updated interfaces --- src/Gliner.ts | 47 ++++++++++++-- src/model.ts | 168 +++++++++++++++++++++++++------------------------- 2 files changed, 126 insertions(+), 89 deletions(-) diff --git a/src/Gliner.ts b/src/Gliner.ts index fef13a6..d5a4527 100644 --- a/src/Gliner.ts +++ b/src/Gliner.ts @@ -16,6 +16,25 @@ export interface InitConfig { maxWidth?: number; } +export interface IInference { + texts: string[]; + entities: string[]; + flatNer?: boolean; + threshold?: number; +} + +export type RawInferenceResult = [string, number, number, string, number][][] + +export interface IEntityResult { + spanText: string; + start: number; + end: number; + label: string; + score: number; +} +export type InferenceResultSingle = IEntityResult[] +export type InferenceResultMultiple = InferenceResultSingle[] + export class Gliner { private model: Model | null = null; @@ -42,19 +61,37 @@ export class Gliner { await this.model.initialize(); } - async inference(texts: string[], entities: string[], threshold: number = 0.5, flatNer: boolean = false): Promise { + async inference({ texts, entities, flatNer = false, threshold = 0.5 }: IInference): Promise { if (!this.model) { throw new Error("Model is not initialized. Call initialize() first."); } - return await this.model.inference(texts, entities, flatNer, threshold); + const result = await this.model.inference(texts, entities, flatNer, threshold); + return this.mapRawResultToResponse(result); } - async inference_with_chunking(texts: string[], entities: string[], threshold: number = 0.5, flatNer: boolean = false): Promise { + async inference_with_chunking({ texts, entities, flatNer = false, threshold = 0.5 }: IInference): Promise { if (!this.model) { throw new Error("Model is not initialized. Call initialize() first."); } - return await this.model.inference_with_chunking(texts, entities, flatNer, threshold); + const result = await this.model.inference_with_chunking(texts, entities, flatNer, threshold); + return this.mapRawResultToResponse(result); } -} + + mapRawResultToResponse(rawResult: RawInferenceResult): InferenceResultMultiple { + const response: InferenceResultMultiple = []; + for (const individualResult of rawResult) { + const entityResult: IEntityResult[] = individualResult.map(([spanText, start, end, label, score]) => ({ + spanText, + start, + end, + label, + score + })); + response.push(entityResult); + } + + return response; + } +} \ No newline at end of file diff --git a/src/model.ts b/src/model.ts index 8f03052..73b540c 100644 --- a/src/model.ts +++ b/src/model.ts @@ -1,5 +1,6 @@ import ort from "onnxruntime-web"; import { ONNXWrapper } from "./ONNXWrapper"; +import { RawInferenceResult } from "./Gliner"; export class Model { constructor( @@ -51,7 +52,7 @@ export class Model { flatNer: boolean = false, threshold: number = 0.5, multiLabel: boolean = false - ): Promise { + ): Promise { let batch = this.processor.prepareBatch(texts, entities); let feeds = this.prepareInputs(batch); const results = await this.onnxWrapper.run(feeds); @@ -90,7 +91,7 @@ export class Model { multiLabel: boolean = false, batch_size: number = 4, max_words: number = 512, - ): Promise { + ): Promise { const { classToId, idToClass } = this.processor.createMappings(entities); @@ -101,94 +102,93 @@ export class Model { texts.forEach((text, id) => { let [tokens, wordsStartIdx, wordsEndIdx] = this.processor.tokenizeText(text); let num_sub_batches: number = Math.ceil(tokens.length / max_words); - + for (let i = 0; i < num_sub_batches; i++) { - let start = i * max_words; - let end = Math.min((i + 1) * max_words, tokens.length); - - batchIds.push(id); - batchTokens.push(tokens.slice(start, end)); - batchWordsStartIdx.push(wordsStartIdx.slice(start, end)); - batchWordsEndIdx.push(wordsEndIdx.slice(start, end)); - } + let start = i * max_words; + let end = Math.min((i + 1) * max_words, tokens.length); + batchIds.push(id); + batchTokens.push(tokens.slice(start, end)); + batchWordsStartIdx.push(wordsStartIdx.slice(start, end)); + batchWordsEndIdx.push(wordsEndIdx.slice(start, end)); + } }); - let num_batches: number = Math.ceil(batchIds.length/batch_size); - - let finalDecodedSpans:number[][][] = []; - for (let id = 0; id Date: Wed, 25 Sep 2024 09:38:33 +0200 Subject: [PATCH 4/4] add changeset --- .changeset/ten-apples-care.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/ten-apples-care.md diff --git a/.changeset/ten-apples-care.md b/.changeset/ten-apples-care.md new file mode 100644 index 0000000..8c86487 --- /dev/null +++ b/.changeset/ten-apples-care.md @@ -0,0 +1,5 @@ +--- +"gliner": patch +--- + +chunking, speed optimizations, better interfaces, text span