Skip to content

Commit

Permalink
🐛 (core): Handle device reconnection in sendApdu
Browse files Browse the repository at this point in the history
  • Loading branch information
ofreyssinet-ledger committed Jul 9, 2024
1 parent 136d9c4 commit 2e64ba0
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 22 deletions.
6 changes: 6 additions & 0 deletions packages/core/src/api/command/Command.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ import { ApduResponse } from "@api/device-session/ApduResponse";
* @template Args - The type of the arguments passed to the command (optional).
*/
export interface Command<Response, Args = void> {
/**
* Indicates whether the command triggers a disconnection from the device when
* it succeeds.
*/
readonly triggersDisconnection?: boolean;

/**
* Gets the APDU (Application Protocol Data Unit) for the command.
*
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/api/command/os/CloseAppCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import { ApduResponse } from "@api/device-session/ApduResponse";
export class CloseAppCommand implements Command<void> {
args = undefined;

readonly triggersDisconnection = true;

getApdu(): Apdu {
const closeAppApduArgs: ApduBuilderArgs = {
cla: 0xb0,
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/api/command/os/OpenAppCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ export type OpenAppArgs = {
export class OpenAppCommand implements Command<void, OpenAppArgs> {
args: OpenAppArgs;

readonly triggersDisconnection = true;

constructor(args: OpenAppArgs) {
this.args = args;
}
Expand Down
20 changes: 16 additions & 4 deletions packages/core/src/internal/device-session/model/DeviceSession.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ export class DeviceSession {
refreshInterval: 1000,
deviceStatus: DeviceStatus.CONNECTED,
sendApduFn: (rawApdu: Uint8Array) =>
this.sendApdu(rawApdu, { isPolling: true }),
this.sendApdu(rawApdu, {
isPolling: true,
triggersDisconnection: false,
}),
updateStateFn: (state: DeviceSessionState) =>
this.setDeviceSessionState(state),
},
Expand Down Expand Up @@ -87,11 +90,17 @@ export class DeviceSession {

async sendApdu(
rawApdu: Uint8Array,
options: { isPolling: boolean } = { isPolling: false },
options: { isPolling: boolean; triggersDisconnection: boolean } = {
isPolling: false,
triggersDisconnection: false,
},
) {
if (!options.isPolling) this.updateDeviceStatus(DeviceStatus.BUSY);

const errorOrResponse = await this._connectedDevice.sendApdu(rawApdu);
const errorOrResponse = await this._connectedDevice.sendApdu(
rawApdu,
options.triggersDisconnection,
);

return errorOrResponse.ifRight((response) => {
if (CommandUtils.isLockedDeviceResponse(response)) {
Expand All @@ -106,7 +115,10 @@ export class DeviceSession {
command: Command<Response, Args>,
): Promise<Response> {
const apdu = command.getApdu();
const response = await this.sendApdu(apdu.getRawApdu());
const response = await this.sendApdu(apdu.getRawApdu(), {
isPolling: false,
triggersDisconnection: command.triggersDisconnection ?? false,
});

return response.caseOf({
Left: (err) => {
Expand Down
7 changes: 7 additions & 0 deletions packages/core/src/internal/usb/model/Errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,10 @@ export class DisconnectError extends GeneralSdkError {
super(err);
}
}

export class ReconnectionFailedError extends GeneralSdkError {
override readonly _tag = "ReconnectionFailedError";
constructor(readonly err?: unknown) {
super(err);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { SdkError } from "@api/Error";

export type SendApduFnType = (
apdu: Uint8Array,
triggersDisconnection?: boolean,
) => Promise<Either<SdkError, ApduResponse>>;

export interface DeviceConnection {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import { Left, Right } from "purify-ts";

import { ApduReceiverService } from "@internal/device-session/service/ApduReceiverService";
import { ApduSenderService } from "@internal/device-session/service/ApduSenderService";
import { defaultApduReceiverServiceStubBuilder } from "@internal/device-session/service/DefaultApduReceiverService.stub";
import { defaultApduSenderServiceStubBuilder } from "@internal/device-session/service/DefaultApduSenderService.stub";
import { DefaultLoggerPublisherService } from "@internal/logger-publisher/service/DefaultLoggerPublisherService";
import { ReconnectionFailedError } from "@internal/usb/model/Errors";
import { hidDeviceStubBuilder } from "@internal/usb/model/HIDDevice.stub";
import { UsbHidDeviceConnection } from "@internal/usb/transport/UsbHidDeviceConnection";

jest.useFakeTimers();

const RESPONSE_LOCKED_DEVICE = new Uint8Array([
0xaa, 0xaa, 0x05, 0x00, 0x00, 0x00, 0x02, 0x55, 0x15, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
Expand All @@ -14,6 +19,20 @@ const RESPONSE_LOCKED_DEVICE = new Uint8Array([
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
]);

const RESPONSE_SUCCESS = new Uint8Array([
0xaa, 0xaa, 0x05, 0x00, 0x00, 0x00, 0x02, 0x90, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
]);

/**
* Flushes all pending promises
*/
const flushPromises = () =>
new Promise(jest.requireActual("timers").setImmediate);

describe("UsbHidDeviceConnection", () => {
let device: HIDDevice;
let apduSender: ApduSenderService;
Expand Down Expand Up @@ -50,7 +69,74 @@ describe("UsbHidDeviceConnection", () => {
expect(device.sendReport).toHaveBeenCalled();
});

it("should receive APDU through hid report", () => {
it("should receive APDU through hid report", async () => {
// given
device.sendReport = jest.fn(() =>
Promise.resolve(
device.oninputreport!({
type: "inputreport",
data: new DataView(Uint8Array.from(RESPONSE_SUCCESS).buffer),
} as HIDInputReportEvent),
),
);
const connection = new UsbHidDeviceConnection(
{ device, apduSender, apduReceiver },
logger,
);
// when
const response = await connection.sendApdu(Uint8Array.from([]));
// then
expect(response).toEqual(
Right({
statusCode: new Uint8Array([0x90, 0x00]),
data: new Uint8Array([]),
}),
);
});

test("sendApdu(whatever, true) should wait for reconnection before resolving if the response is a success", async () => {
// given
device.sendReport = jest.fn(() =>
Promise.resolve(
device.oninputreport!({
type: "inputreport",
data: new DataView(Uint8Array.from(RESPONSE_SUCCESS).buffer),
} as HIDInputReportEvent),
),
);
const connection = new UsbHidDeviceConnection(
{ device, apduSender, apduReceiver },
logger,
);

let hasResolved = false;
const responsePromise = connection
.sendApdu(Uint8Array.from([]), true)
.then((response) => {
hasResolved = true;
return response;
});

// before reconnecting
await flushPromises();
expect(hasResolved).toBe(false);

// when reconnecting
connection.device = device;
await flushPromises();
expect(hasResolved).toBe(true);

const response = await responsePromise;

expect(response).toEqual(
Right({
statusCode: new Uint8Array([0x90, 0x00]),
data: new Uint8Array([]),
}),
);
});

test("sendApdu(whatever, true) should not wait for reconnection if the response is not a success", async () => {
// given
device.sendReport = jest.fn(() =>
Promise.resolve(
Expand All @@ -64,9 +150,42 @@ describe("UsbHidDeviceConnection", () => {
{ device, apduSender, apduReceiver },
logger,
);

// when
const response = connection.sendApdu(Uint8Array.from([]));
const response = await connection.sendApdu(Uint8Array.from([]), true);

// then
expect(response).toEqual(
Right({
statusCode: new Uint8Array([0x55, 0x15]),
data: new Uint8Array([]),
}),
);
});

test("sendApdu(whatever, true) should return an error if the device gets disconnected while waiting for reconnection", async () => {
// given
device.sendReport = jest.fn(() =>
Promise.resolve(
device.oninputreport!({
type: "inputreport",
data: new DataView(Uint8Array.from(RESPONSE_SUCCESS).buffer),
} as HIDInputReportEvent),
),
);
const connection = new UsbHidDeviceConnection(
{ device, apduSender, apduReceiver },
logger,
);

const responsePromise = connection.sendApdu(Uint8Array.from([]), true);

// when disconnecting
connection.disconnect();
await flushPromises();

// then
expect(response).resolves.toBe(RESPONSE_LOCKED_DEVICE);
const response = await responsePromise;
expect(response).toEqual(Left(new ReconnectionFailedError()));
});
});
73 changes: 61 additions & 12 deletions packages/core/src/internal/usb/transport/UsbHidDeviceConnection.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import { inject } from "inversify";
import { Either, Left, Right } from "purify-ts";
import { Either, Left, Maybe, Right } from "purify-ts";
import { Subject } from "rxjs";

import { ApduResponse } from "@api/device-session/ApduResponse";
import { SdkError } from "@api/Error";
import { CommandUtils } from "@api/index";
import { ApduReceiverService } from "@internal/device-session/service/ApduReceiverService";
import { ApduSenderService } from "@internal/device-session/service/ApduSenderService";
import { loggerTypes } from "@internal/logger-publisher/di/loggerTypes";
import type { LoggerPublisherService } from "@internal/logger-publisher/service/LoggerPublisherService";
import { ReconnectionFailedError } from "@internal/usb/model/Errors";

import { DeviceConnection } from "./DeviceConnection";

Expand All @@ -23,6 +25,10 @@ export class UsbHidDeviceConnection implements DeviceConnection {
private readonly _apduReceiver: ApduReceiverService;
private _sendApduSubject: Subject<ApduResponse>;
private readonly _logger: LoggerPublisherService;
private _settleReconnectionPromise: Maybe<{
resolve(): void;
reject(err: SdkError): void;
}> = Maybe.zero();

constructor(
{ device, apduSender, apduReceiver }: UsbHidDeviceConnectionConstructorArgs,
Expand All @@ -44,15 +50,44 @@ export class UsbHidDeviceConnection implements DeviceConnection {
public set device(device: HIDDevice) {
this._device = device;
this._device.oninputreport = (event) => this.receiveHidInputReport(event);

this._settleReconnectionPromise.ifJust(() => {
this.reconnected();
});
}

async sendApdu(apdu: Uint8Array): Promise<Either<SdkError, ApduResponse>> {
async sendApdu(
apdu: Uint8Array,
triggersDisconnection?: boolean,
): Promise<Either<SdkError, ApduResponse>> {
this._sendApduSubject = new Subject();

this._logger.debug("Sending APDU", {
data: { apdu },
tag: "apdu-sender",
});

const resultPromise = new Promise<Either<SdkError, ApduResponse>>(
(resolve) => {
this._sendApduSubject.subscribe({
next: async (r) => {
if (triggersDisconnection && CommandUtils.isSuccessResponse(r)) {
const reconnectionRes = await this.setupWaitForReconnection();
reconnectionRes.caseOf({
Left: (err) => resolve(Left(err)),
Right: () => resolve(Right(r)),
});
} else {
resolve(Right(r));
}
},
error: (err) => {
resolve(Left(err));
},
});
},
);

const frames = this._apduSender.getFrames(apdu);
for (const frame of frames) {
this._logger.debug("Sending Frame", {
Expand All @@ -65,16 +100,7 @@ export class UsbHidDeviceConnection implements DeviceConnection {
}
}

return new Promise((resolve) => {
this._sendApduSubject.subscribe({
next: (r) => {
resolve(Right(r));
},
error: (err) => {
resolve(Left(err));
},
});
});
return resultPromise;
}

private receiveHidInputReport(event: HIDInputReportEvent) {
Expand All @@ -99,4 +125,27 @@ export class UsbHidDeviceConnection implements DeviceConnection {
},
});
}

private setupWaitForReconnection(): Promise<Either<SdkError, void>> {
return new Promise<Either<SdkError, void>>((resolve) => {
this._settleReconnectionPromise = Maybe.of({
resolve: () => resolve(Right(undefined)),
reject: (error: SdkError) => resolve(Left(error)),
});
});
}

private reconnected() {
this._settleReconnectionPromise.ifJust((promise) => {
promise.resolve();
this._settleReconnectionPromise = Maybe.zero();
});
}

public disconnect() {
this._settleReconnectionPromise.ifJust((promise) => {
promise.reject(new ReconnectionFailedError());
this._settleReconnectionPromise = Maybe.zero();
});
}
}
Loading

0 comments on commit 2e64ba0

Please sign in to comment.