Merge pull request #11 from Ingvarstep/optim
Optim
Ingvarstep authored Sep 25, 2024
2 parents e4123c9 + 0c5d792 commit dc7150e
Showing 5 changed files with 152 additions and 120 deletions.
5 changes: 5 additions & 0 deletions .changeset/ten-apples-care.md
@@ -0,0 +1,5 @@
---
"gliner": patch
---

chunking, speed optimizations, better interfaces, text span
47 changes: 42 additions & 5 deletions src/Gliner.ts
@@ -16,6 +16,25 @@ export interface InitConfig {
maxWidth?: number;
}

export interface IInference {
texts: string[];
entities: string[];
flatNer?: boolean;
threshold?: number;
}

export type RawInferenceResult = [string, number, number, string, number][][]

export interface IEntityResult {
spanText: string;
start: number;
end: number;
label: string;
score: number;
}
export type InferenceResultSingle = IEntityResult[]
export type InferenceResultMultiple = InferenceResultSingle[]

export class Gliner {
private model: Model | null = null;

@@ -42,19 +61,37 @@ export class Gliner {
await this.model.initialize();
}

async inference(texts: string[], entities: string[], threshold: number = 0.5, flatNer: boolean = false): Promise<any[]> {
async inference({ texts, entities, flatNer = false, threshold = 0.5 }: IInference): Promise<InferenceResultMultiple> {
if (!this.model) {
throw new Error("Model is not initialized. Call initialize() first.");
}

return await this.model.inference(texts, entities, flatNer, threshold);
const result = await this.model.inference(texts, entities, flatNer, threshold);
return this.mapRawResultToResponse(result);
}

async inference_with_chunking(texts: string[], entities: string[], threshold: number = 0.5, flatNer: boolean = false): Promise<any[]> {
async inference_with_chunking({ texts, entities, flatNer = false, threshold = 0.5 }: IInference): Promise<InferenceResultMultiple> {
if (!this.model) {
throw new Error("Model is not initialized. Call initialize() first.");
}

return await this.model.inference_with_chunking(texts, entities, flatNer, threshold);
const result = await this.model.inference_with_chunking(texts, entities, flatNer, threshold);
return this.mapRawResultToResponse(result);
}
}

mapRawResultToResponse(rawResult: RawInferenceResult): InferenceResultMultiple {
const response: InferenceResultMultiple = [];
for (const individualResult of rawResult) {
const entityResult: IEntityResult[] = individualResult.map(([spanText, start, end, label, score]) => ({
spanText,
start,
end,
label,
score
}));
response.push(entityResult);
}

return response;
}
}
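
A minimal sketch of how the reworked API is meant to be called, for reference. The constructor option `modelPath` and the label names are illustrative assumptions; only `maxWidth`, `initialize()`, the new object-based `inference` signature, and the `IEntityResult` fields come from this diff.

```ts
import { Gliner, InferenceResultMultiple } from "gliner";

// "modelPath" is a hypothetical option; only maxWidth is visible in this diff.
const gliner = new Gliner({ modelPath: "gliner.onnx", maxWidth: 12 });
await gliner.initialize();

// New object-based signature, replacing the old positional
// (texts, entities, threshold, flatNer) form.
const results: InferenceResultMultiple = await gliner.inference({
  texts: ["Bill Gates founded Microsoft in Albuquerque."],
  entities: ["person", "company", "city"],
  threshold: 0.5,
});

// Results are now structured objects instead of raw tuples.
for (const { spanText, start, end, label, score } of results[0]) {
  console.log(`${label}: "${spanText}" [${start}-${end}] score=${score.toFixed(2)}`);
}
```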
11 changes: 9 additions & 2 deletions src/decoder.ts
@@ -75,6 +75,8 @@ export class SpanDecoder extends BaseDecoder {
inputLength: number,
maxWidth: number,
numEntities: number,
texts: string[][],
batchIds: number[],
batchWordsStartIdx: number[][],
batchWordsEndIdx: number[][],
idToClass: Record<number, string>,
@@ -106,9 +108,14 @@
startToken < batchWordsStartIdx[batch].length &&
endToken < batchWordsEndIdx[batch].length
) {
let globalBatch = batchIds[batch];
let startIdx = batchWordsStartIdx[batch][startToken];
let endIdx = batchWordsEndIdx[batch][endToken];
let spanText = texts[globalBatch].slice(startIdx, endIdx);
spans[batch].push([
batchWordsStartIdx[batch][startToken],
batchWordsEndIdx[batch][endToken],
spanText,
startIdx,
endIdx,
idToClass[entity + 1],
prob
]);
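
Each decoded span is now the five-tuple `[spanText, startIdx, endIdx, label, prob]`, matching `RawInferenceResult` in Gliner.ts. The text-recovery step in isolation, as a sketch; it assumes `startIdx`/`endIdx` are character offsets into the original string, and `recoverSpanText` is a hypothetical helper, not the repo's API:

```ts
// Sketch only: names mirror the diff, but this free function is illustrative.
function recoverSpanText(
  texts: string[],      // original, un-chunked input texts
  batchIds: number[],   // maps each row of the current batch to its source text
  batch: number,        // row index within the current batch
  startIdx: number,     // character start offset of the span
  endIdx: number        // character end offset (used as slice's exclusive end)
): string {
  const globalBatch = batchIds[batch]; // chunk row -> original text index
  return texts[globalBatch].slice(startIdx, endIdx);
}
```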
201 changes: 92 additions & 109 deletions src/model.ts
@@ -1,5 +1,6 @@
import ort from "onnxruntime-web";
import { ONNXWrapper } from "./ONNXWrapper";
import { RawInferenceResult } from "./Gliner";

export class Model {
constructor(
@@ -18,42 +19,20 @@ export class Model {
const num_tokens = batch.inputsIds[0].length;
const num_spans = batch.spanIdxs[0].length;

const convertToBool = (arr: any[]): any[] => {
return arr.map((subArr) => subArr.map((val: any) => !!val));
};

const blankConvert = (arr: any[]): any[] => {
return arr.map((val) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val))));
};

const convertToInt = (arr: any[]): any[] => {
return arr.map((subArr) =>
subArr.map((val: any) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val))))
);
};

const convertSpanIdx = (arr: any[]): any[] => {
return arr.flatMap((subArr) =>
subArr.flatMap((pair: any) => pair.map((val: any) => (isNaN(Number(val)) ? 0 : Math.floor(Number(val)))))
);
};

const createTensor = (
data: any[],
shape: number[],
conversionFunc: (arr: any[]) => any[] = convertToInt,
tensorType: any = "int64"
): ort.Tensor => {
const convertedData = conversionFunc(data);
return new this.onnxWrapper.ort.Tensor(tensorType, convertedData.flat(Infinity), shape);
return new this.onnxWrapper.ort.Tensor(tensorType, data.flat(Infinity), shape);
};

let input_ids = createTensor(batch.inputsIds, [batch_size, num_tokens]);
let attention_mask = createTensor(batch.attentionMasks, [batch_size, num_tokens], convertToBool); // NOTE: why convert to bool but type is not bool?
let attention_mask = createTensor(batch.attentionMasks, [batch_size, num_tokens]); // NOTE: why convert to bool but type is not bool?
let words_mask = createTensor(batch.wordsMasks, [batch_size, num_tokens]);
let text_lengths = createTensor(batch.textLengths, [batch_size, 1], blankConvert);
let span_idx = createTensor(batch.spanIdxs, [batch_size, num_spans, 2], convertSpanIdx);
let span_mask = createTensor(batch.spanMasks, [batch_size, num_spans], convertToBool, "bool");
let text_lengths = createTensor(batch.textLengths, [batch_size, 1]);
let span_idx = createTensor(batch.spanIdxs, [batch_size, num_spans, 2]);
let span_mask = createTensor(batch.spanMasks, [batch_size, num_spans], "bool");

const feeds = {
input_ids: input_ids,
@@ -73,7 +52,7 @@
flatNer: boolean = false,
threshold: number = 0.5,
multiLabel: boolean = false
): Promise<number[][][]> {
): Promise<RawInferenceResult> {
let batch = this.processor.prepareBatch(texts, entities);
let feeds = this.prepareInputs(batch);
const results = await this.onnxWrapper.run(feeds);
@@ -84,12 +63,14 @@
const inputLength = Math.max(...batch.textLengths);
const maxWidth = this.config.max_width;
const numEntities = entities.length;

const batchIds = Array.from({ length: batchSize }, (_, i) => i);
const decodedSpans = this.decoder.decode(
batchSize,
inputLength,
maxWidth,
numEntities,
texts,
batchIds,
batch.batchWordsStartIdx,
batch.batchWordsEndIdx,
batch.idToClass,
@@ -108,9 +89,9 @@
flatNer: boolean = false,
threshold: number = 0.5,
multiLabel: boolean = false,
batch_size: number = 8,
batch_size: number = 4,
max_words: number = 512,
): Promise<number[][][]> {
): Promise<RawInferenceResult> {

const { classToId, idToClass } = this.processor.createMappings(entities);

@@ -121,91 +102,93 @@
texts.forEach((text, id) => {
let [tokens, wordsStartIdx, wordsEndIdx] = this.processor.tokenizeText(text);
let num_sub_batches: number = Math.ceil(tokens.length / max_words);

for (let i = 0; i < num_sub_batches; i++) {
let start = i * max_words;
let end = Math.min((i + 1) * max_words, tokens.length);

batchIds.push(id);
batchTokens.push(tokens.slice(start, end));
batchWordsStartIdx.push(wordsStartIdx.slice(start, end));
batchWordsEndIdx.push(wordsEndIdx.slice(start, end));
}
let start = i * max_words;
let end = Math.min((i + 1) * max_words, tokens.length);

batchIds.push(id);
batchTokens.push(tokens.slice(start, end));
batchWordsStartIdx.push(wordsStartIdx.slice(start, end));
batchWordsEndIdx.push(wordsEndIdx.slice(start, end));
}
});
let num_batches: number = Math.ceil(batchIds.length/batch_size);

let finalDecodedSpans:number[][][] = [];
for (let id = 0; id<texts.length; id++) {

let num_batches: number = Math.ceil(batchIds.length / batch_size);

let finalDecodedSpans: RawInferenceResult = [];
for (let id = 0; id < texts.length; id++) {
finalDecodedSpans.push([]);
}

for (let batch_id = 0; batch_id<num_batches; batch_id++) {
let start: number = batch_id * batch_size;
let end: number = Math.min((batch_id + 1) * batch_size, batchIds.length);

let currBatchTokens = batchTokens.slice(start, end);
let currBatchWordsStartIdx = batchWordsStartIdx.slice(start, end);
let currBatchWordsEndIdx = batchWordsEndIdx.slice(start, end);

let [inputTokens, textLengths, promptLengths] = this.processor.prepareTextInputs(currBatchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks] = this.processor.encodeInputs(inputTokens, promptLengths);

inputsIds = this.processor.padArray(inputsIds);
attentionMasks = this.processor.padArray(attentionMasks);
wordsMasks = this.processor.padArray(wordsMasks);

let { spanIdxs, spanMasks } = this.processor.prepareSpans(batchTokens, this.config["max_width"]);

spanIdxs = this.processor.padArray(spanIdxs, 3);
spanMasks = this.processor.padArray(spanMasks);


let batch = {
inputsIds: inputsIds,
attentionMasks: attentionMasks,
wordsMasks: wordsMasks,
textLengths: textLengths,
spanIdxs: spanIdxs,
spanMasks: spanMasks,
idToClass: idToClass,
batchTokens: batchTokens,
batchWordsStartIdx: currBatchWordsStartIdx,
batchWordsEndIdx: currBatchWordsEndIdx,
};

let feeds = this.prepareInputs(batch);
const results = await this.onnxWrapper.run(feeds);
const modelOutput = results.logits.data;
// const modelOutput = results.logits.data as number[];

const batchSize = batch.batchTokens.length;
const inputLength = Math.max(...batch.textLengths);
const maxWidth = this.config.max_width;
const numEntities = entities.length;

const decodedSpans = this.decoder.decode(
batchSize,
inputLength,
maxWidth,
numEntities,
batch.batchWordsStartIdx,
batch.batchWordsEndIdx,
batch.idToClass,
modelOutput,
flatNer,
threshold,
multiLabel
);

for (let i = 0; i < decodedSpans.length; i++) {
const originalTextId = batchIds[start + i];
finalDecodedSpans[originalTextId].push(...decodedSpans[i]);
}
}
for (let batch_id = 0; batch_id < num_batches; batch_id++) {
let start: number = batch_id * batch_size;
let end: number = Math.min((batch_id + 1) * batch_size, batchIds.length);

let currBatchTokens = batchTokens.slice(start, end);
let currBatchWordsStartIdx = batchWordsStartIdx.slice(start, end);
let currBatchWordsEndIdx = batchWordsEndIdx.slice(start, end);
let currBatchIds = batchIds.slice(start, end);

let [inputTokens, textLengths, promptLengths] = this.processor.prepareTextInputs(currBatchTokens, entities);

let [inputsIds, attentionMasks, wordsMasks] = this.processor.encodeInputs(inputTokens, promptLengths);

inputsIds = this.processor.padArray(inputsIds);
attentionMasks = this.processor.padArray(attentionMasks);
wordsMasks = this.processor.padArray(wordsMasks);

let { spanIdxs, spanMasks } = this.processor.prepareSpans(currBatchTokens, this.config["max_width"]);

spanIdxs = this.processor.padArray(spanIdxs, 3);
spanMasks = this.processor.padArray(spanMasks);

let batch = {
inputsIds: inputsIds,
attentionMasks: attentionMasks,
wordsMasks: wordsMasks,
textLengths: textLengths,
spanIdxs: spanIdxs,
spanMasks: spanMasks,
idToClass: idToClass,
batchTokens: currBatchTokens,
batchWordsStartIdx: currBatchWordsStartIdx,
batchWordsEndIdx: currBatchWordsEndIdx,
};

let feeds = this.prepareInputs(batch);
const results = await this.onnxWrapper.run(feeds);
const modelOutput = results.logits.data;
// const modelOutput = results.logits.data as number[];

const batchSize = batch.batchTokens.length;
const inputLength = Math.max(...batch.textLengths);
const maxWidth = this.config.max_width;
const numEntities = entities.length;

const decodedSpans = this.decoder.decode(
batchSize,
inputLength,
maxWidth,
numEntities,
texts,
currBatchIds,
batch.batchWordsStartIdx,
batch.batchWordsEndIdx,
batch.idToClass,
modelOutput,
flatNer,
threshold,
multiLabel
);

return finalDecodedSpans;
for (let i = 0; i < currBatchIds.length; i++) {
const originalTextId = currBatchIds[i];
finalDecodedSpans[originalTextId].push(...decodedSpans[i]);
}
}

return finalDecodedSpans;
}

}
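
The chunking strategy above, condensed into a standalone sketch; tokenizer output is simplified to string arrays and the helper name `chunkTexts` is hypothetical:

```ts
// Condensed sketch of the windowing in inference_with_chunking: each
// tokenized text is split into windows of at most maxWords tokens, and
// sourceIds remembers which original text each window came from.
function chunkTexts(
  tokenized: string[][], // tokens per input text
  maxWords: number = 512
): { windows: string[][]; sourceIds: number[] } {
  const windows: string[][] = [];
  const sourceIds: number[] = [];
  tokenized.forEach((tokens, id) => {
    const numChunks = Math.ceil(tokens.length / maxWords);
    for (let i = 0; i < numChunks; i++) {
      const start = i * maxWords;
      const end = Math.min((i + 1) * maxWords, tokens.length);
      windows.push(tokens.slice(start, end));
      sourceIds.push(id); // later used to merge spans back per text
    }
  });
  return { windows, sourceIds };
}
```

The windows are then run through the model in mini-batches of `batch_size` (default lowered from 8 to 4 in this commit), and each window's decoded spans are merged back into `finalDecodedSpans[sourceIds[i]]`, so callers still get one result per input text.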
8 changes: 4 additions & 4 deletions src/processor.ts
@@ -149,20 +149,20 @@ export class SpanProcessor extends Processor {
super(config, tokenizer, wordsSplitter);
}

prepareSpans(batchTokens: string[][], maxWidth: number = 12): { spanIdxs: number[][][]; spanMasks: number[][] } {
prepareSpans(batchTokens: string[][], maxWidth: number = 12): { spanIdxs: number[][][]; spanMasks: boolean[][] } {
let spanIdxs: number[][][] = [];
let spanMasks: number[][] = [];
let spanMasks: boolean[][] = [];

batchTokens.forEach((tokens) => {
let textLength = tokens.length;
let spanIdx: number[][] = [];
let spanMask: number[] = [];
let spanMask: boolean[] = [];

for (let i = 0; i < textLength; i++) {
for (let j = 0; j < maxWidth; j++) {
let endIdx = Math.min(i + j, textLength - 1);
spanIdx.push([i, endIdx]);
spanMask.push(endIdx < textLength ? 1 : 0);
spanMask.push(endIdx < textLength ? true : false);
}
}

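
For reference, the span-enumeration scheme this hunk retypes, isolated for a single text; `enumerateSpans` is a hypothetical free-function version of `SpanProcessor.prepareSpans`:

```ts
// Enumerate candidate spans [i, i + j] for widths j < maxWidth, clamping
// the end index to the last token.
function enumerateSpans(
  textLength: number,
  maxWidth: number = 12
): { spanIdx: [number, number][]; spanMask: boolean[] } {
  const spanIdx: [number, number][] = [];
  const spanMask: boolean[] = [];
  for (let i = 0; i < textLength; i++) {
    for (let j = 0; j < maxWidth; j++) {
      const endIdx = Math.min(i + j, textLength - 1); // clamp to last token
      spanIdx.push([i, endIdx]);
      // Always true after the clamp above; the commit only changes the
      // mask's type from 0/1 numbers to booleans, not its logic.
      spanMask.push(endIdx < textLength);
    }
  }
  return { spanIdx, spanMask };
}
```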
