diff --git a/.changeset/ten-apples-care.md b/.changeset/ten-apples-care.md new file mode 100644 index 0000000..8c86487 --- /dev/null +++ b/.changeset/ten-apples-care.md @@ -0,0 +1,5 @@ +--- +"gliner": patch +--- + +chunking, speed optimizations, better interfaces, text span diff --git a/src/Gliner.ts b/src/Gliner.ts index fef13a6..d5a4527 100644 --- a/src/Gliner.ts +++ b/src/Gliner.ts @@ -16,6 +16,25 @@ export interface InitConfig { maxWidth?: number; } +export interface IInference { + texts: string[]; + entities: string[]; + flatNer?: boolean; + threshold?: number; +} + +export type RawInferenceResult = [string, number, number, string, number][][] + +export interface IEntityResult { + spanText: string; + start: number; + end: number; + label: string; + score: number; +} +export type InferenceResultSingle = IEntityResult[] +export type InferenceResultMultiple = InferenceResultSingle[] + export class Gliner { private model: Model | null = null; @@ -42,19 +61,37 @@ export class Gliner { await this.model.initialize(); } - async inference(texts: string[], entities: string[], threshold: number = 0.5, flatNer: boolean = false): Promise { + async inference({ texts, entities, flatNer = false, threshold = 0.5 }: IInference): Promise { if (!this.model) { throw new Error("Model is not initialized. Call initialize() first."); } - return await this.model.inference(texts, entities, flatNer, threshold); + const result = await this.model.inference(texts, entities, flatNer, threshold); + return this.mapRawResultToResponse(result); } - async inference_with_chunking(texts: string[], entities: string[], threshold: number = 0.5, flatNer: boolean = false): Promise { + async inference_with_chunking({ texts, entities, flatNer = false, threshold = 0.5 }: IInference): Promise { if (!this.model) { throw new Error("Model is not initialized. Call initialize() first."); } - return await this.model.inference_with_chunking(texts, entities, flatNer, threshold); + const result = await this.model.inference_with_chunking(texts, entities, flatNer, threshold); + return this.mapRawResultToResponse(result); } -} + + mapRawResultToResponse(rawResult: RawInferenceResult): InferenceResultMultiple { + const response: InferenceResultMultiple = []; + for (const individualResult of rawResult) { + const entityResult: IEntityResult[] = individualResult.map(([spanText, start, end, label, score]) => ({ + spanText, + start, + end, + label, + score + })); + response.push(entityResult); + } + + return response; + } +} \ No newline at end of file diff --git a/src/decoder.ts b/src/decoder.ts index a36d83e..5773b7b 100644 --- a/src/decoder.ts +++ b/src/decoder.ts @@ -75,6 +75,8 @@ export class SpanDecoder extends BaseDecoder { inputLength: number, maxWidth: number, numEntities: number, + texts: string[][], + batchIds: number[], batchWordsStartIdx: number[][], batchWordsEndIdx: number[][], idToClass: Record, @@ -106,9 +108,14 @@ export class SpanDecoder extends BaseDecoder { startToken < batchWordsStartIdx[batch].length && endToken < batchWordsEndIdx[batch].length ) { + let globalBatch = batchIds[batch]; + let startIdx = batchWordsStartIdx[batch][startToken]; + let endIdx = batchWordsEndIdx[batch][endToken]; + let spanText = texts[globalBatch].slice(startIdx, endIdx); spans[batch].push([ - batchWordsStartIdx[batch][startToken], - batchWordsEndIdx[batch][endToken], + spanText, + startIdx, + endIdx, idToClass[entity + 1], prob ]); diff --git a/src/model.ts b/src/model.ts index aa9b43a..73b540c 100644 --- a/src/model.ts +++ b/src/model.ts @@ -1,5 +1,6 @@ import ort from "onnxruntime-web"; import { ONNXWrapper } from "./ONNXWrapper"; +import { RawInferenceResult } from "./Gliner"; export class Model { constructor( @@ -18,42 +19,20 @@ export class Model { const num_tokens = batch.inputsIds[0].length; const num_spans = batch.spanIdxs[0].length; - const convertToBool = (arr: any[]): any[] => { - return arr.map((subArr) => subArr.map((val: any) => !!val)); - }; - - const blankConvert = (arr: any[]): any[] => { - return arr.map((val) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val)))); - }; - - const convertToInt = (arr: any[]): any[] => { - return arr.map((subArr) => - subArr.map((val: any) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val)))) - ); - }; - - const convertSpanIdx = (arr: any[]): any[] => { - return arr.flatMap((subArr) => - subArr.flatMap((pair: any) => pair.map((val: any) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val))))) - ); - }; - const createTensor = ( data: any[], shape: number[], - conversionFunc: (arr: any[]) => any[] = convertToInt, tensorType: any = "int64" ): ort.Tensor => { - const convertedData = conversionFunc(data); - return new this.onnxWrapper.ort.Tensor(tensorType, convertedData.flat(Infinity), shape); + return new this.onnxWrapper.ort.Tensor(tensorType, data.flat(Infinity), shape); }; let input_ids = createTensor(batch.inputsIds, [batch_size, num_tokens]); - let attention_mask = createTensor(batch.attentionMasks, [batch_size, num_tokens], convertToBool); // NOTE: why convert to bool but type is not bool? + let attention_mask = createTensor(batch.attentionMasks, [batch_size, num_tokens]); // NOTE: why convert to bool but type is not bool? let words_mask = createTensor(batch.wordsMasks, [batch_size, num_tokens]); - let text_lengths = createTensor(batch.textLengths, [batch_size, 1], blankConvert); - let span_idx = createTensor(batch.spanIdxs, [batch_size, num_spans, 2], convertSpanIdx); - let span_mask = createTensor(batch.spanMasks, [batch_size, num_spans], convertToBool, "bool"); + let text_lengths = createTensor(batch.textLengths, [batch_size, 1]); + let span_idx = createTensor(batch.spanIdxs, [batch_size, num_spans, 2]); + let span_mask = createTensor(batch.spanMasks, [batch_size, num_spans], "bool"); const feeds = { input_ids: input_ids, @@ -73,7 +52,7 @@ export class Model { flatNer: boolean = false, threshold: number = 0.5, multiLabel: boolean = false - ): Promise { + ): Promise { let batch = this.processor.prepareBatch(texts, entities); let feeds = this.prepareInputs(batch); const results = await this.onnxWrapper.run(feeds); @@ -84,12 +63,14 @@ export class Model { const inputLength = Math.max(...batch.textLengths); const maxWidth = this.config.max_width; const numEntities = entities.length; - + const batchIds = Array.from({ length: batchSize }, (_, i) => i); const decodedSpans = this.decoder.decode( batchSize, inputLength, maxWidth, numEntities, + texts, + batchIds, batch.batchWordsStartIdx, batch.batchWordsEndIdx, batch.idToClass, @@ -108,9 +89,9 @@ export class Model { flatNer: boolean = false, threshold: number = 0.5, multiLabel: boolean = false, - batch_size: number = 8, + batch_size: number = 4, max_words: number = 512, - ): Promise { + ): Promise { const { classToId, idToClass } = this.processor.createMappings(entities); @@ -121,91 +102,93 @@ export class Model { texts.forEach((text, id) => { let [tokens, wordsStartIdx, wordsEndIdx] = this.processor.tokenizeText(text); let num_sub_batches: number = Math.ceil(tokens.length / max_words); - + for (let i = 0; i < num_sub_batches; i++) { - let start = i * max_words; - let end = Math.min((i + 1) * max_words, tokens.length); - - batchIds.push(id); - batchTokens.push(tokens.slice(start, end)); - batchWordsStartIdx.push(wordsStartIdx.slice(start, end)); - batchWordsEndIdx.push(wordsEndIdx.slice(start, end)); - } + let start = i * max_words; + let end = Math.min((i + 1) * max_words, tokens.length); + batchIds.push(id); + batchTokens.push(tokens.slice(start, end)); + batchWordsStartIdx.push(wordsStartIdx.slice(start, end)); + batchWordsEndIdx.push(wordsEndIdx.slice(start, end)); + } }); - let num_batches: number = Math.ceil(batchIds.length/batch_size); - - let finalDecodedSpans:number[][][] = []; - for (let id = 0; id { let textLength = tokens.length; let spanIdx: number[][] = []; - let spanMask: number[] = []; + let spanMask: boolean[] = []; for (let i = 0; i < textLength; i++) { for (let j = 0; j < maxWidth; j++) { let endIdx = Math.min(i + j, textLength - 1); spanIdx.push([i, endIdx]); - spanMask.push(endIdx < textLength ? 1 : 0); + spanMask.push(endIdx < textLength ? true : false); } }