diff --git a/app/scripts/metamask-controller.test.js b/app/scripts/metamask-controller.test.js index 0cd4fba34589..880df69aa00f 100644 --- a/app/scripts/metamask-controller.test.js +++ b/app/scripts/metamask-controller.test.js @@ -45,6 +45,7 @@ import { } from './lib/accounts/BalancesController'; import { BalancesTracker as MultichainBalancesTracker } from './lib/accounts/BalancesTracker'; import { deferredPromise } from './lib/util'; +import { METAMASK_COOKIE_HANDLER } from './constants/stream'; import MetaMaskController, { ONE_KEY_VIA_TREZOR_MINOR_VERSION, } from './metamask-controller'; @@ -1273,6 +1274,129 @@ describe('MetaMaskController', () => { expect(mockKeyring.destroy).toHaveBeenCalledTimes(1); }); }); + describe('#setupPhishingCommunication', () => { + beforeEach(() => { + jest.spyOn(metamaskController, 'safelistPhishingDomain'); + jest.spyOn(metamaskController, 'backToSafetyPhishingWarning'); + metamaskController.preferencesController.setUsePhishDetect(true); + }); + afterEach(() => { + jest.clearAllMocks(); + }); + it('creates a phishing stream with safelistPhishingDomain and backToSafetyPhishingWarning handler', async () => { + const safelistPhishingDomainRequest = { + name: 'metamask-phishing-safelist', + data: { + id: 1, + method: 'safelistPhishingDomain', + params: ['mockHostname'], + }, + }; + const backToSafetyPhishingWarningRequest = { + name: 'metamask-phishing-safelist', + data: { id: 2, method: 'backToSafetyPhishingWarning', params: [] }, + }; + + const { promise, resolve } = deferredPromise(); + const { promise: promiseStream, resolve: resolveStream } = + deferredPromise(); + const streamTest = createThroughStream((chunk, _, cb) => { + if (chunk.name !== 'metamask-phishing-safelist') { + cb(); + return; + } + resolve(); + cb(null, chunk); + }); + + metamaskController.setupPhishingCommunication({ + connectionStream: streamTest, + }); + + streamTest.write(safelistPhishingDomainRequest, null, () => { + expect( + metamaskController.safelistPhishingDomain, + ).toHaveBeenCalledWith('mockHostname'); + }); + streamTest.write(backToSafetyPhishingWarningRequest, null, () => { + expect( + metamaskController.backToSafetyPhishingWarning, + ).toHaveBeenCalled(); + resolveStream(); + }); + + await promise; + streamTest.end(); + await promiseStream; + }); + }); + + describe('#setUpCookieHandlerCommunication', () => { + let localMetaMaskController; + beforeEach(() => { + localMetaMaskController = new MetaMaskController({ + showUserConfirmation: noop, + encryptor: mockEncryptor, + initState: { + ...cloneDeep(firstTimeState), + MetaMetricsController: { + metaMetricsId: 'MOCK_METRICS_ID', + participateInMetaMetrics: true, + dataCollectionForMarketing: true, + }, + }, + initLangCode: 'en_US', + platform: { + showTransactionNotification: () => undefined, + getVersion: () => 'foo', + }, + browser: browserPolyfillMock, + infuraProjectId: 'foo', + isFirstMetaMaskControllerSetup: true, + }); + jest.spyOn(localMetaMaskController, 'getCookieFromMarketingPage'); + }); + afterEach(() => { + jest.clearAllMocks(); + }); + it('creates a cookie handler communication stream with getCookieFromMarketingPage handler', async () => { + const attributionRequest = { + name: METAMASK_COOKIE_HANDLER, + data: { + id: 1, + method: 'getCookieFromMarketingPage', + params: [{ ga_client_id: 'XYZ.ABC' }], + }, + }; + + const { promise, resolve } = deferredPromise(); + const { promise: promiseStream, resolve: resolveStream } = + deferredPromise(); + const streamTest = createThroughStream((chunk, _, cb) => { + if (chunk.name !== METAMASK_COOKIE_HANDLER) { + cb(); + return; + } + resolve(); + cb(null, chunk); + }); + + localMetaMaskController.setUpCookieHandlerCommunication({ + connectionStream: streamTest, + }); + + streamTest.write(attributionRequest, null, () => { + expect( + localMetaMaskController.getCookieFromMarketingPage, + ).toHaveBeenCalledWith({ ga_client_id: 'XYZ.ABC' }); + resolveStream(); + }); + + await promise; + streamTest.end(); + await promiseStream; + }); + }); describe('#setupUntrustedCommunicationEip1193', () => { const mockTxParams = { from: TEST_ADDRESS };