Skip to content

Commit

Permalink
[WebNN] Fixes MLTensor caching across different contexts (microsoft#2…
Browse files Browse the repository at this point in the history
…3100)

We weren't checking that MLTensors were from the same context before
reusing them.

Found while debugging microsoft/webnn-developer-preview#69
  • Loading branch information
egalli authored Dec 17, 2024
1 parent 5afab78 commit 54edb43
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions js/web/lib/wasm/jsep/webnn/tensor-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,9 @@ class TensorWrapper {
return this.mlContext.readTensor(this.mlTensor);
}

public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
return (
this.mlContext === context &&
this.dataType === dataType &&
this.tensorShape.length === shape.length &&
this.tensorShape.every((v, i) => v === shape[i])
Expand Down Expand Up @@ -176,12 +177,13 @@ class TensorIdTracker {
}

public async ensureTensor(
context: MLContext,
dataType: MLOperandDataType,
shape: readonly number[],
copyOld: boolean,
): Promise<MLTensor> {
if (this.wrapper) {
if (this.wrapper.sameTypeAndShape(dataType, shape)) {
if (this.wrapper.canReuseTensor(context, dataType, shape)) {
return this.wrapper.tensor;
} else {
if (copyOld) {
Expand Down Expand Up @@ -288,7 +290,7 @@ class TensorManagerImpl implements TensorManager {
if (!tensor) {
throw new Error('Tensor not found.');
}
return tensor.ensureTensor(dataType, shape, copyOld);
return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld);
}

public upload(tensorId: TensorId, data: Uint8Array): void {
Expand Down Expand Up @@ -354,15 +356,15 @@ class TensorManagerImpl implements TensorManager {
readable: boolean,
): Promise<TensorWrapper> {
const sessionId = this.backend.currentSessionId;
const context = this.backend.currentContext;
for (const [index, tensor] of this.freeTensors.entries()) {
if (tensor.sameTypeAndShape(dataType, shape)) {
if (tensor.canReuseTensor(context, dataType, shape)) {
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
const wrapper = this.freeTensors.splice(index, 1)[0];
wrapper.sessionId = sessionId;
return wrapper;
}
}
const context = this.backend.currentContext;
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
const tensor = await context.createTensor({
dataType,
Expand Down

0 comments on commit 54edb43

Please sign in to comment.