diff --git a/packages/aws-amplify/package.json b/packages/aws-amplify/package.json index 6f7e403e420..3d1f5cef58a 100644 --- a/packages/aws-amplify/package.json +++ b/packages/aws-amplify/package.json @@ -293,31 +293,31 @@ "name": "[Analytics] record (Pinpoint)", "path": "./dist/esm/analytics/index.mjs", "import": "{ record }", - "limit": "17.23 kB" + "limit": "17.28 kB" }, { "name": "[Analytics] record (Kinesis)", "path": "./dist/esm/analytics/kinesis/index.mjs", "import": "{ record }", - "limit": "48.65 kB" + "limit": "48.69 kB" }, { "name": "[Analytics] record (Kinesis Firehose)", "path": "./dist/esm/analytics/kinesis-firehose/index.mjs", "import": "{ record }", - "limit": "45.81 kB" + "limit": "45.85 kB" }, { "name": "[Analytics] record (Personalize)", "path": "./dist/esm/analytics/personalize/index.mjs", "import": "{ record }", - "limit": "49.63 kB" + "limit": "49.67 kB" }, { "name": "[Analytics] identifyUser (Pinpoint)", "path": "./dist/esm/analytics/index.mjs", "import": "{ identifyUser }", - "limit": "15.73 kB" + "limit": "15.79 kB" }, { "name": "[Analytics] enable", @@ -335,7 +335,7 @@ "name": "[API] generateClient (AppSync)", "path": "./dist/esm/api/index.mjs", "import": "{ generateClient }", - "limit": "40.19 kB" + "limit": "40.23 kB" }, { "name": "[API] REST API handlers", @@ -353,61 +353,61 @@ "name": "[Auth] resetPassword (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ resetPassword }", - "limit": "12.58 kB" + "limit": "12.62 kB" }, { "name": "[Auth] confirmResetPassword (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ confirmResetPassword }", - "limit": "12.52 kB" + "limit": "12.56 kB" }, { "name": "[Auth] signIn (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ signIn }", - "limit": "30.00 kB" + "limit": "28.78 kB" }, { "name": "[Auth] resendSignUpCode (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ resendSignUpCode }", - "limit": "12.53 kB" + "limit": "12.57 kB" }, { "name": "[Auth] confirmSignUp (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ confirmSignUp }", - "limit": "31.00 kB" + "limit": "29.40 kB" }, { "name": "[Auth] confirmSignIn (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ confirmSignIn }", - "limit": "28.42 kB" + "limit": "28.46 kB" }, { "name": "[Auth] updateMFAPreference (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ updateMFAPreference }", - "limit": "11.87 kB" + "limit": "11.92 kB" }, { "name": "[Auth] fetchMFAPreference (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ fetchMFAPreference }", - "limit": "11.91 kB" + "limit": "11.94 kB" }, { "name": "[Auth] verifyTOTPSetup (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ verifyTOTPSetup }", - "limit": "12.75 kB" + "limit": "12.78 kB" }, { "name": "[Auth] updatePassword (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ updatePassword }", - "limit": "12.76 kB" + "limit": "12.80 kB" }, { "name": "[Auth] setUpTOTP (Cognito)", @@ -419,85 +419,85 @@ "name": "[Auth] updateUserAttributes (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ updateUserAttributes }", - "limit": "11.99 kB" + "limit": "12.03 kB" }, { "name": "[Auth] getCurrentUser (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ getCurrentUser }", - "limit": "7.85 kB" + "limit": "7.86 kB" }, { "name": "[Auth] confirmUserAttribute (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ confirmUserAttribute }", - "limit": "12.75 kB" + "limit": "12.79 kB" }, { "name": "[Auth] signInWithRedirect (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ signInWithRedirect }", - "limit": "21.19 kB" + "limit": "21.21 kB" }, { "name": "[Auth] fetchUserAttributes (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ fetchUserAttributes }", - "limit": "11.82 kB" + "limit": "11.86 kB" }, { "name": "[Auth] Basic Auth Flow (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ signIn, signOut, fetchAuthSession, confirmSignIn }", - "limit": "30.20 kB" + "limit": "30.23 kB" }, { "name": "[Auth] OAuth Auth Flow (Cognito)", "path": "./dist/esm/auth/index.mjs", "import": "{ signInWithRedirect, signOut, fetchAuthSession }", - "limit": "21.62 kB" + "limit": "21.64 kB" }, { "name": "[Storage] copy (S3)", "path": "./dist/esm/storage/index.mjs", "import": "{ copy }", - "limit": "15.24 kB" + "limit": "15.42 kB" }, { "name": "[Storage] downloadData (S3)", "path": "./dist/esm/storage/index.mjs", "import": "{ downloadData }", - "limit": "15.76 kB" + "limit": "15.93 kB" }, { "name": "[Storage] getProperties (S3)", "path": "./dist/esm/storage/index.mjs", "import": "{ getProperties }", - "limit": "15.03 kB" + "limit": "15.20 kB" }, { "name": "[Storage] getUrl (S3)", "path": "./dist/esm/storage/index.mjs", "import": "{ getUrl }", - "limit": "16.09 kB" + "limit": "16.26 kB" }, { "name": "[Storage] list (S3)", "path": "./dist/esm/storage/index.mjs", "import": "{ list }", - "limit": "15.65 kB" + "limit": "15.82 kB" }, { "name": "[Storage] remove (S3)", "path": "./dist/esm/storage/index.mjs", "import": "{ remove }", - "limit": "14.88 kB" + "limit": "15.05 kB" }, { "name": "[Storage] uploadData (S3)", "path": "./dist/esm/storage/index.mjs", "import": "{ uploadData }", - "limit": "20.30 kB" + "limit": "20.48 kB" } ] } diff --git a/packages/core/__tests__/clients/middleware/retry/defaultRetryDecider.test.ts b/packages/core/__tests__/clients/middleware/retry/defaultRetryDecider.test.ts new file mode 100644 index 00000000000..b33b4e27f4d --- /dev/null +++ b/packages/core/__tests__/clients/middleware/retry/defaultRetryDecider.test.ts @@ -0,0 +1,136 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +import { HttpResponse } from '../../../../src/clients'; +import { getRetryDecider } from '../../../../src/clients/middleware/retry'; +import { isClockSkewError } from '../../../../src/clients/middleware/retry/isClockSkewError'; + +jest.mock('../../../../src/clients/middleware/retry/isClockSkewError'); + +const mockIsClockSkewError = jest.mocked(isClockSkewError); + +describe('getRetryDecider', () => { + const mockErrorParser = jest.fn(); + const mockHttpResponse: HttpResponse = { + statusCode: 200, + headers: {}, + body: 'body' as any, + }; + + beforeEach(() => { + jest.resetAllMocks(); + }); + + it('should handle network errors', async () => { + expect.assertions(2); + const retryDecider = getRetryDecider(mockErrorParser); + const connectionError = Object.assign(new Error(), { + name: 'Network error', + }); + const { retryable, isCredentialsExpiredError } = await retryDecider( + mockHttpResponse, + connectionError, + ); + expect(retryable).toBe(true); + expect(isCredentialsExpiredError).toBeFalsy(); + }); + + describe('handling throttling errors', () => { + it.each([ + 'BandwidthLimitExceeded', + 'EC2ThrottledException', + 'LimitExceededException', + 'PriorRequestNotComplete', + 'ProvisionedThroughputExceededException', + 'RequestLimitExceeded', + 'RequestThrottled', + 'RequestThrottledException', + 'SlowDown', + 'ThrottledException', + 'Throttling', + 'ThrottlingException', + 'TooManyRequestsException', + ])('should return retryable at %s error', async errorCode => { + expect.assertions(2); + mockErrorParser.mockResolvedValueOnce({ + code: errorCode, + }); + const retryDecider = getRetryDecider(mockErrorParser); + const { retryable, isCredentialsExpiredError } = await retryDecider( + mockHttpResponse, + undefined, + ); + expect(retryable).toBe(true); + expect(isCredentialsExpiredError).toBeFalsy(); + }); + + it('should set retryable for 402 error', async () => { + expect.assertions(2); + const retryDecider = getRetryDecider(mockErrorParser); + const { + retryable, + isCredentialsExpiredError: isInvalidCredentialsError, + } = await retryDecider( + { + ...mockHttpResponse, + statusCode: 429, + }, + undefined, + ); + expect(retryable).toBe(true); + expect(isInvalidCredentialsError).toBeFalsy(); + }); + }); + + describe('handling clockskew error', () => { + it.each([{ code: 'ClockSkew' }, { name: 'ClockSkew' }])( + 'should handle clockskew error %o', + async parsedError => { + expect.assertions(3); + mockErrorParser.mockResolvedValue(parsedError); + mockIsClockSkewError.mockReturnValue(true); + const retryDecider = getRetryDecider(mockErrorParser); + const { retryable, isCredentialsExpiredError } = await retryDecider( + mockHttpResponse, + undefined, + ); + expect(retryable).toBe(true); + expect(isCredentialsExpiredError).toBeFalsy(); + expect(mockIsClockSkewError).toHaveBeenCalledWith( + Object.values(parsedError)[0], + ); + }, + ); + }); + + it.each([500, 502, 503, 504])( + 'should handle server-side status code %s', + async statusCode => { + const retryDecider = getRetryDecider(mockErrorParser); + const { retryable, isCredentialsExpiredError } = await retryDecider( + { + ...mockHttpResponse, + statusCode, + }, + undefined, + ); + expect(retryable).toBe(true); + expect(isCredentialsExpiredError).toBeFalsy(); + }, + ); + + it.each(['TimeoutError', 'RequestTimeout', 'RequestTimeoutException'])( + 'should handle server-side timeout error code %s', + async errorCode => { + expect.assertions(2); + mockErrorParser.mockResolvedValue({ code: errorCode }); + const retryDecider = getRetryDecider(mockErrorParser); + const { retryable, isCredentialsExpiredError } = await retryDecider( + mockHttpResponse, + undefined, + ); + expect(retryable).toBe(true); + expect(isCredentialsExpiredError).toBeFalsy(); + }, + ); +}); diff --git a/packages/core/__tests__/clients/middleware/retry/middleware.test.ts b/packages/core/__tests__/clients/middleware/retry/middleware.test.ts index 1391f010d23..05f1b0f8de9 100644 --- a/packages/core/__tests__/clients/middleware/retry/middleware.test.ts +++ b/packages/core/__tests__/clients/middleware/retry/middleware.test.ts @@ -11,13 +11,13 @@ import { jest.spyOn(global, 'setTimeout'); jest.spyOn(global, 'clearTimeout'); -describe(`${retryMiddlewareFactory.name} middleware`, () => { +describe(`retry middleware`, () => { beforeEach(() => { jest.clearAllMocks(); }); const defaultRetryOptions = { - retryDecider: async () => true, + retryDecider: async () => ({ retryable: true }), computeDelay: () => 1, }; const defaultRequest = { url: new URL('https://a.b') }; @@ -72,7 +72,7 @@ describe(`${retryMiddlewareFactory.name} middleware`, () => { const retryableHandler = getRetryableHandler(nextHandler); const retryDecider = jest .fn() - .mockImplementation(response => response.body !== 'foo'); // retry if response is not foo + .mockImplementation(response => ({ retryable: response.body !== 'foo' })); // retry if response is not foo const resp = await retryableHandler(defaultRequest, { ...defaultRetryOptions, retryDecider, @@ -88,11 +88,9 @@ describe(`${retryMiddlewareFactory.name} middleware`, () => { .fn() .mockRejectedValue(new Error('UnretryableError')); const retryableHandler = getRetryableHandler(nextHandler); - const retryDecider = jest - .fn() - .mockImplementation( - (resp, error) => error.message !== 'UnretryableError', - ); + const retryDecider = jest.fn().mockImplementation((resp, error) => ({ + retryable: error.message !== 'UnretryableError', + })); try { await retryableHandler(defaultRequest, { ...defaultRetryOptions, @@ -103,11 +101,46 @@ describe(`${retryMiddlewareFactory.name} middleware`, () => { expect(e.message).toBe('UnretryableError'); expect(nextHandler).toHaveBeenCalledTimes(1); expect(retryDecider).toHaveBeenCalledTimes(1); - expect(retryDecider).toHaveBeenCalledWith(undefined, expect.any(Error)); + expect(retryDecider).toHaveBeenCalledWith( + undefined, + expect.any(Error), + expect.anything(), + ); } expect.assertions(4); }); + test('should set isCredentialsExpired in middleware context if retry decider returns the flag', async () => { + expect.assertions(4); + const coreHandler = jest + .fn() + .mockRejectedValueOnce(new Error('InvalidSignature')) + .mockResolvedValueOnce(defaultResponse); + + const nextMiddleware = jest.fn( + (next: MiddlewareHandler) => (request: any) => next(request), + ); + const retryableHandler = composeTransferHandler<[RetryOptions, any]>( + coreHandler, + [retryMiddlewareFactory, () => nextMiddleware], + ); + const retryDecider = jest.fn().mockImplementation((resp, error) => ({ + retryable: error?.message === 'InvalidSignature', + isCredentialsExpiredError: error?.message === 'InvalidSignature', + })); + const response = await retryableHandler(defaultRequest, { + ...defaultRetryOptions, + retryDecider, + }); + expect(response).toEqual(expect.objectContaining(defaultResponse)); + expect(coreHandler).toHaveBeenCalledTimes(2); + expect(retryDecider).toHaveBeenCalledTimes(2); + expect(nextMiddleware).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ isCredentialsExpired: true }), + ); + }); + test('should call computeDelay for intervals', async () => { const nextHandler = jest.fn().mockResolvedValue(defaultResponse); const retryableHandler = getRetryableHandler(nextHandler); @@ -152,7 +185,7 @@ describe(`${retryMiddlewareFactory.name} middleware`, () => { const nextHandler = jest.fn().mockResolvedValue(defaultResponse); const retryableHandler = getRetryableHandler(nextHandler); const controller = new AbortController(); - const retryDecider = async () => true; + const retryDecider = async () => ({ retryable: true }); const computeDelay = jest.fn().mockImplementation(attempt => { if (attempt === 1) { setTimeout(() => { @@ -204,9 +237,10 @@ describe(`${retryMiddlewareFactory.name} middleware`, () => { const retryDecider = jest .fn() .mockImplementation((response, error: Error) => { - if (error && error.message.endsWith('RetryableError')) return true; + if (error && error.message.endsWith('RetryableError')) + return { retryable: true }; - return false; + return { retryable: false }; }); const computeDelay = jest.fn().mockReturnValue(0); const response = await doubleRetryableHandler(defaultRequest, { diff --git a/packages/core/__tests__/clients/middleware/signing/middleware.test.ts b/packages/core/__tests__/clients/middleware/signing/middleware.test.ts index a3183ebcdb5..874d82e2282 100644 --- a/packages/core/__tests__/clients/middleware/signing/middleware.test.ts +++ b/packages/core/__tests__/clients/middleware/signing/middleware.test.ts @@ -11,6 +11,7 @@ import { getUpdatedSystemClockOffset } from '../../../../src/clients/middleware/ import { HttpRequest, HttpResponse, + Middleware, MiddlewareHandler, } from '../../../../src/clients/types'; @@ -113,6 +114,30 @@ describe('Signing middleware', () => { expect(credentialsProvider).toHaveBeenCalledTimes(1); }); + test('should forceRefresh credentials provider if middleware context isCredentialsInvalid flag is set', async () => { + expect.assertions(2); + const credentialsProvider = jest.fn().mockResolvedValue(credentials); + const nextHandler = jest.fn().mockResolvedValue(defaultResponse); + const setInvalidCredsMiddleware: Middleware = + () => (next, context) => request => { + context.isCredentialsExpired = true; + + return next(request); + }; + const signableHandler = composeTransferHandler< + [any, SigningOptions], + HttpRequest, + HttpResponse + >(nextHandler, [setInvalidCredsMiddleware, signingMiddlewareFactory]); + const config = { + ...defaultSigningOptions, + credentials: credentialsProvider, + }; + await signableHandler(defaultRequest, config); + expect(credentialsProvider).toHaveBeenCalledTimes(1); + expect(credentialsProvider).toHaveBeenCalledWith({ forceRefresh: true }); + }); + test.each([ ['response with Date header', 'Date'], ['response with date header', 'date'], @@ -128,6 +153,7 @@ describe('Signing middleware', () => { const middlewareFunction = signingMiddlewareFactory(defaultSigningOptions)( nextHandler, + {}, ); await middlewareFunction(defaultRequest); diff --git a/packages/core/src/clients/index.ts b/packages/core/src/clients/index.ts index a06067604bc..31abf267c77 100644 --- a/packages/core/src/clients/index.ts +++ b/packages/core/src/clients/index.ts @@ -15,9 +15,14 @@ export { } from './middleware/signing/signer/signatureV4'; export { EMPTY_HASH as EMPTY_SHA256_HASH } from './middleware/signing/signer/signatureV4/constants'; export { extendedEncodeURIComponent } from './middleware/signing/utils/extendedEncodeURIComponent'; -export { signingMiddlewareFactory, SigningOptions } from './middleware/signing'; +export { + signingMiddlewareFactory, + SigningOptions, + CredentialsProviderOptions, +} from './middleware/signing'; export { getRetryDecider, + RetryDeciderOutput, jitteredBackoff, retryMiddlewareFactory, RetryOptions, diff --git a/packages/core/src/clients/middleware/retry/defaultRetryDecider.ts b/packages/core/src/clients/middleware/retry/defaultRetryDecider.ts index 874cc74314e..edec193ebf1 100644 --- a/packages/core/src/clients/middleware/retry/defaultRetryDecider.ts +++ b/packages/core/src/clients/middleware/retry/defaultRetryDecider.ts @@ -4,6 +4,7 @@ import { ErrorParser, HttpResponse } from '../../types'; import { isClockSkewError } from './isClockSkewError'; +import { RetryDeciderOutput } from './types'; /** * Get retry decider function @@ -11,7 +12,10 @@ import { isClockSkewError } from './isClockSkewError'; */ export const getRetryDecider = (errorParser: ErrorParser) => - async (response?: HttpResponse, error?: unknown): Promise => { + async ( + response?: HttpResponse, + error?: unknown, + ): Promise => { const parsedError = (error as Error & { code: string }) ?? (await errorParser(response)) ?? @@ -19,12 +23,15 @@ export const getRetryDecider = const errorCode = parsedError?.code || parsedError?.name; const statusCode = response?.statusCode; - return ( + const isRetryable = isConnectionError(error) || isThrottlingError(statusCode, errorCode) || isClockSkewError(errorCode) || - isServerSideError(statusCode, errorCode) - ); + isServerSideError(statusCode, errorCode); + + return { + retryable: isRetryable, + }; }; // reference: https://github.com/aws/aws-sdk-js-v3/blob/ab0e7be36e7e7f8a0c04834357aaad643c7912c3/packages/service-error-classification/src/constants.ts#L22-L37 diff --git a/packages/core/src/clients/middleware/retry/index.ts b/packages/core/src/clients/middleware/retry/index.ts index 4c82c603508..fdf34552fa7 100644 --- a/packages/core/src/clients/middleware/retry/index.ts +++ b/packages/core/src/clients/middleware/retry/index.ts @@ -4,3 +4,4 @@ export { RetryOptions, retryMiddlewareFactory } from './middleware'; export { jitteredBackoff } from './jitteredBackoff'; export { getRetryDecider } from './defaultRetryDecider'; +export { RetryDeciderOutput } from './types'; diff --git a/packages/core/src/clients/middleware/retry/middleware.ts b/packages/core/src/clients/middleware/retry/middleware.ts index 1c8d88bc4fd..8d9a9c2cd9b 100644 --- a/packages/core/src/clients/middleware/retry/middleware.ts +++ b/packages/core/src/clients/middleware/retry/middleware.ts @@ -8,6 +8,8 @@ import { Response, } from '../../types/core'; +import { RetryDeciderOutput } from './types'; + const DEFAULT_RETRY_ATTEMPTS = 3; /** @@ -19,9 +21,14 @@ export interface RetryOptions { * * @param response Optional response of the request. * @param error Optional error thrown from previous attempts. + * @param middlewareContext Optional context object to store data between retries. * @returns True if the request should be retried. */ - retryDecider(response?: TResponse, error?: unknown): Promise; + retryDecider( + response?: TResponse, + error?: unknown, + middlewareContext?: MiddlewareContext, + ): Promise; /** * Function to compute the delay in milliseconds before the next retry based * on the number of attempts. @@ -87,7 +94,14 @@ export const retryMiddlewareFactory = ({ ? context.attemptsCount ?? 0 : attemptsCount + 1; context.attemptsCount = attemptsCount; - if (await retryDecider(response, error)) { + const { isCredentialsExpiredError, retryable } = await retryDecider( + response, + error, + context, + ); + if (retryable) { + // Setting isCredentialsInvalid flag to notify signing middleware to forceRefresh credentials provider. + context.isCredentialsExpired = !!isCredentialsExpiredError; if (!abortSignal?.aborted && attemptsCount < maxAttempts) { // prevent sleep for last attempt or cancelled request; const delay = computeDelay(attemptsCount); diff --git a/packages/core/src/clients/middleware/retry/types.ts b/packages/core/src/clients/middleware/retry/types.ts new file mode 100644 index 00000000000..a229216edee --- /dev/null +++ b/packages/core/src/clients/middleware/retry/types.ts @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +export interface RetryDeciderOutput { + retryable: boolean; + isCredentialsExpiredError?: boolean; +} diff --git a/packages/core/src/clients/middleware/signing/index.ts b/packages/core/src/clients/middleware/signing/index.ts index a1458bca3e4..1ce90db4b7e 100644 --- a/packages/core/src/clients/middleware/signing/index.ts +++ b/packages/core/src/clients/middleware/signing/index.ts @@ -1,4 +1,8 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -export { signingMiddlewareFactory, SigningOptions } from './middleware'; +export { + signingMiddlewareFactory, + SigningOptions, + CredentialsProviderOptions, +} from './middleware'; diff --git a/packages/core/src/clients/middleware/signing/middleware.ts b/packages/core/src/clients/middleware/signing/middleware.ts index a7bed1e6b7f..1b36519729e 100644 --- a/packages/core/src/clients/middleware/signing/middleware.ts +++ b/packages/core/src/clients/middleware/signing/middleware.ts @@ -7,16 +7,27 @@ import { HttpResponse, MiddlewareHandler, } from '../../types'; +import { MiddlewareContext } from '../../types/core'; import { signRequest } from './signer/signatureV4'; import { getSkewCorrectedDate } from './utils/getSkewCorrectedDate'; import { getUpdatedSystemClockOffset } from './utils/getUpdatedSystemClockOffset'; +/** + * Options type for the async callback function returning aws credentials. This + * function is used by SigV4 signer to resolve the aws credentials + */ +export interface CredentialsProviderOptions { + forceRefresh?: boolean; +} + /** * Configuration of the signing middleware */ export interface SigningOptions { - credentials: Credentials | (() => Promise); + credentials: + | Credentials + | ((options?: CredentialsProviderOptions) => Promise); region: string; service: string; @@ -41,12 +52,19 @@ export const signingMiddlewareFactory = ({ }: SigningOptions) => { let currentSystemClockOffset: number; - return (next: MiddlewareHandler) => + return ( + next: MiddlewareHandler, + context: MiddlewareContext, + ) => async function signingMiddleware(request: HttpRequest) { currentSystemClockOffset = currentSystemClockOffset ?? 0; const signRequestOptions = { credentials: - typeof credentials === 'function' ? await credentials() : credentials, + typeof credentials === 'function' + ? await credentials({ + forceRefresh: !!context?.isCredentialsExpired, + }) + : credentials, signingDate: getSkewCorrectedDate(currentSystemClockOffset), signingRegion: region, signingService: service, diff --git a/packages/core/src/clients/types/core.ts b/packages/core/src/clients/types/core.ts index 1fa122250b6..a6348655899 100644 --- a/packages/core/src/clients/types/core.ts +++ b/packages/core/src/clients/types/core.ts @@ -30,6 +30,11 @@ export type MiddlewareHandler = ( * The context object to store states across the middleware chain. */ export interface MiddlewareContext { + /** + * Whether an error indicating expired credentials has been returned from server-side. + * This is set by the retry middleware. + */ + isCredentialsExpired?: boolean; /** * The number of times the request has been attempted. This is set by retry middleware */ diff --git a/packages/core/src/clients/types/index.ts b/packages/core/src/clients/types/index.ts index e2b8953a4d2..0ee905fb162 100644 --- a/packages/core/src/clients/types/index.ts +++ b/packages/core/src/clients/types/index.ts @@ -4,6 +4,7 @@ export { Middleware, MiddlewareHandler, + MiddlewareContext, Request, Response, TransferHandler, diff --git a/packages/storage/__tests__/providers/s3/apis/copy.test.ts b/packages/storage/__tests__/providers/s3/apis/copy.test.ts index 56f46927b30..56104e84d17 100644 --- a/packages/storage/__tests__/providers/s3/apis/copy.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/copy.test.ts @@ -384,7 +384,6 @@ describe('copy API', () => { }, }); } catch (error: any) { - console.log(error); expect(error).toBeInstanceOf(StorageError); expect(error.name).toBe( StorageValidationErrorCode.InvalidCopyOperationStorageBucket, @@ -403,7 +402,6 @@ describe('copy API', () => { }, }); } catch (error: any) { - console.log(error); expect(error).toBeInstanceOf(StorageError); expect(error.name).toBe( StorageValidationErrorCode.InvalidCopyOperationStorageBucket, diff --git a/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts b/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts index 527e66dce18..662640e3340 100644 --- a/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts +++ b/packages/storage/__tests__/providers/s3/apis/utils/resolveS3ConfigAndInput.test.ts @@ -241,8 +241,10 @@ describe('resolveS3ConfigAndInput', () => { }); if (typeof s3Config.credentials === 'function') { - const result = await s3Config.credentials(); - expect(mockLocationCredentialsProvider).toHaveBeenCalled(); + const result = await s3Config.credentials({ forceRefresh: true }); + expect(mockLocationCredentialsProvider).toHaveBeenCalledWith({ + forceRefresh: true, + }); expect(result).toEqual(credentials); } else { throw new Error('Expect credentials to be a function'); diff --git a/packages/storage/__tests__/providers/s3/utils/client/S3/cases/completeMultipartUpload.ts b/packages/storage/__tests__/providers/s3/utils/client/S3/cases/completeMultipartUpload.ts index d94e6b94d34..6b20ab56254 100644 --- a/packages/storage/__tests__/providers/s3/utils/client/S3/cases/completeMultipartUpload.ts +++ b/packages/storage/__tests__/providers/s3/utils/client/S3/cases/completeMultipartUpload.ts @@ -109,7 +109,12 @@ const completeMultipartUploadErrorWith200CodeCase: ApiFunctionalTestCase< 'error case', 'completeMultipartUpload with 200 status', completeMultipartUpload, - { ...defaultConfig, retryDecider: async () => false }, // disable retry + { + ...defaultConfig, + retryDecider: async () => ({ + retryable: false, + }), + }, // disable retry completeMultipartUploadHappyCase[4], completeMultipartUploadHappyCase[5], { diff --git a/packages/storage/__tests__/providers/s3/utils/client/S3/utils/retryDecider.test.ts b/packages/storage/__tests__/providers/s3/utils/client/S3/utils/retryDecider.test.ts new file mode 100644 index 00000000000..5e1801c07db --- /dev/null +++ b/packages/storage/__tests__/providers/s3/utils/client/S3/utils/retryDecider.test.ts @@ -0,0 +1,103 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +import { + HttpResponse, + getRetryDecider as getDefaultRetryDecider, +} from '@aws-amplify/core/internals/aws-client-utils'; + +import { retryDecider } from '../../../../../../../src/providers/s3/utils/client/utils'; +import { parseXmlError } from '../../../../../../../src/providers/s3/utils/client/utils/parsePayload'; + +jest.mock( + '../../../../../../../src/providers/s3/utils/client/utils/parsePayload', +); +jest.mock('@aws-amplify/core/internals/aws-client-utils'); + +const mockErrorParser = jest.mocked(parseXmlError); + +describe('retryDecider', () => { + const mockHttpResponse: HttpResponse = { + statusCode: 200, + headers: {}, + body: 'body' as any, + }; + + beforeEach(() => { + jest.mocked(getDefaultRetryDecider).mockReturnValue(async () => { + return { retryable: false }; + }); + }); + + afterEach(() => { + jest.resetAllMocks(); + }); + + it('should invoke the default retry decider', async () => { + expect.assertions(3); + const { retryable, isCredentialsExpiredError } = await retryDecider( + mockHttpResponse, + undefined, + {}, + ); + expect(getDefaultRetryDecider).toHaveBeenCalledWith(mockErrorParser); + expect(retryable).toBe(false); + expect(isCredentialsExpiredError).toBeFalsy(); + }); + + describe('handling expired token errors', () => { + const mockErrorMessage = 'Token expired'; + it.each(['RequestExpired', 'ExpiredTokenException', 'ExpiredToken'])( + 'should retry if expired credentials error name %s', + async errorName => { + expect.assertions(2); + const parsedError = { + name: errorName, + message: mockErrorMessage, + $metadata: {}, + }; + mockErrorParser.mockResolvedValue(parsedError); + const { retryable, isCredentialsExpiredError } = await retryDecider( + { ...mockHttpResponse, statusCode: 400 }, + undefined, + {}, + ); + expect(retryable).toBe(true); + expect(isCredentialsExpiredError).toBe(true); + }, + ); + + it('should retry if error message indicates invalid credentials', async () => { + expect.assertions(2); + const parsedError = { + name: 'InvalidSignature', + message: 'Auth token in request is expired.', + $metadata: {}, + }; + mockErrorParser.mockResolvedValue(parsedError); + const { retryable, isCredentialsExpiredError } = await retryDecider( + { ...mockHttpResponse, statusCode: 400 }, + undefined, + {}, + ); + expect(retryable).toBe(true); + expect(isCredentialsExpiredError).toBe(true); + }); + + it('should not retry if invalid credentials error has been retried previously', async () => { + expect.assertions(2); + const parsedError = { + name: 'RequestExpired', + message: mockErrorMessage, + $metadata: {}, + }; + mockErrorParser.mockResolvedValue(parsedError); + const { retryable, isCredentialsExpiredError } = await retryDecider( + { ...mockHttpResponse, statusCode: 400 }, + undefined, + { isCredentialsExpired: true }, + ); + expect(retryable).toBe(false); + expect(isCredentialsExpiredError).toBe(true); + }); + }); +}); diff --git a/packages/storage/__tests__/storageBrowser/apis/getDataAccess.test.ts b/packages/storage/__tests__/storageBrowser/apis/getDataAccess.test.ts index 0753e0ae334..91a3fd12556 100644 --- a/packages/storage/__tests__/storageBrowser/apis/getDataAccess.test.ts +++ b/packages/storage/__tests__/storageBrowser/apis/getDataAccess.test.ts @@ -1,6 +1,8 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 +import { CredentialsProviderOptions } from '@aws-amplify/core/internals/aws-client-utils'; + import { getDataAccess } from '../../../src/storageBrowser/apis/getDataAccess'; import { getDataAccess as getDataAccessClient } from '../../../src/providers/s3/utils/client/s3control'; import { GetDataAccessInput } from '../../../src/storageBrowser/apis/types'; @@ -29,7 +31,7 @@ const MOCK_ACCESS_CREDENTIALS = { SessionToken: MOCK_SESSION_TOKEN, Expiration: MOCK_EXPIRATION_DATE, }; -const MOCK_CREDENTIAL_PROVIDER = async () => MOCK_CREDENTIALS; +const MOCK_CREDENTIAL_PROVIDER = jest.fn().mockResolvedValue(MOCK_CREDENTIALS); const sharedGetDataAccessParams: GetDataAccessInput = { accountId: MOCK_ACCOUNT_ID, @@ -41,7 +43,7 @@ const sharedGetDataAccessParams: GetDataAccessInput = { }; describe('getDataAccess', () => { - const getDataAccessClientMock = getDataAccessClient as jest.Mock; + const getDataAccessClientMock = jest.mocked(getDataAccessClient); beforeEach(() => { jest.clearAllMocks(); @@ -49,15 +51,17 @@ describe('getDataAccess', () => { getDataAccessClientMock.mockResolvedValue({ Credentials: MOCK_ACCESS_CREDENTIALS, MatchedGrantTarget: MOCK_SCOPE, + $metadata: {}, }); }); it('should invoke the getDataAccess client correctly', async () => { + expect.assertions(6); const result = await getDataAccess(sharedGetDataAccessParams); expect(getDataAccessClientMock).toHaveBeenCalledWith( expect.objectContaining({ - credentials: MOCK_CREDENTIALS.credentials, + credentials: expect.any(Function), region: MOCK_REGION, userAgentValue: expect.stringContaining('storage/8'), }), @@ -69,6 +73,15 @@ describe('getDataAccess', () => { DurationSeconds: 900, }), ); + const inputCredentialsProvider = getDataAccessClientMock.mock.calls[0][0] + .credentials as (input: CredentialsProviderOptions) => any; + expect(inputCredentialsProvider).toBeInstanceOf(Function); + await expect( + inputCredentialsProvider({ forceRefresh: true }), + ).resolves.toEqual(MOCK_CREDENTIALS.credentials); + expect(MOCK_CREDENTIAL_PROVIDER).toHaveBeenCalledWith({ + forceRefresh: true, + }); expect(result.credentials).toEqual(MOCK_CREDENTIALS.credentials); expect(result.scope).toEqual(MOCK_SCOPE); @@ -80,6 +93,7 @@ describe('getDataAccess', () => { getDataAccessClientMock.mockResolvedValue({ Credentials: undefined, MatchedGrantTarget: MOCK_SCOPE, + $metadata: {}, }); expect(getDataAccess(sharedGetDataAccessParams)).rejects.toThrow( @@ -93,6 +107,7 @@ describe('getDataAccess', () => { getDataAccessClientMock.mockResolvedValue({ Credentials: MOCK_ACCESS_CREDENTIALS, MatchedGrantTarget: MOCK_OBJECT_SCOPE, + $metadata: {}, }); const result = await getDataAccess({ diff --git a/packages/storage/__tests__/storageBrowser/apis/listCallerAccessGrants.test.ts b/packages/storage/__tests__/storageBrowser/apis/listCallerAccessGrants.test.ts index 3e0051f7461..bff4b4e07bd 100644 --- a/packages/storage/__tests__/storageBrowser/apis/listCallerAccessGrants.test.ts +++ b/packages/storage/__tests__/storageBrowser/apis/listCallerAccessGrants.test.ts @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +import { CredentialsProviderOptions } from '@aws-amplify/core/internals/aws-client-utils'; + import { listCallerAccessGrants } from '../../../src/storageBrowser/apis/listCallerAccessGrants'; import { listCallerAccessGrants as listCallerAccessGrantsClient } from '../../../src/providers/s3/utils/client/s3control'; @@ -8,7 +10,15 @@ jest.mock('../../../src/providers/s3/utils/client/s3control'); const mockAccountId = '1234567890'; const mockRegion = 'us-foo-2'; -const mockCredentialsProvider = jest.fn(); +const mockCredentials = { + accessKeyId: 'key', + secretAccessKey: 'secret', + sessionToken: 'session', + expiration: new Date(), +}; +const mockCredentialsProvider = jest + .fn() + .mockResolvedValue({ credentials: mockCredentials }); const mockNextToken = '123'; const mockPageSize = 123; @@ -18,7 +28,7 @@ describe('listCallerAccessGrants', () => { }); it('should invoke the listCallerAccessGrants client with expected parameters', async () => { - expect.assertions(1); + expect.assertions(4); jest.mocked(listCallerAccessGrantsClient).mockResolvedValue({ NextToken: undefined, CallerAccessGrantsList: [], @@ -42,6 +52,17 @@ describe('listCallerAccessGrants', () => { MaxResults: mockPageSize, }), ); + const inputCredentialsProvider = jest.mocked(listCallerAccessGrantsClient) + .mock.calls[0][0].credentials as ( + input: CredentialsProviderOptions, + ) => any; + expect(inputCredentialsProvider).toBeInstanceOf(Function); + await expect( + inputCredentialsProvider({ forceRefresh: true }), + ).resolves.toEqual(mockCredentials); + expect(mockCredentialsProvider).toHaveBeenCalledWith({ + forceRefresh: true, + }); }); it('should set a default page size', async () => { diff --git a/packages/storage/src/providers/s3/types/options.ts b/packages/storage/src/providers/s3/types/options.ts index 170cafd7313..0852c761891 100644 --- a/packages/storage/src/providers/s3/types/options.ts +++ b/packages/storage/src/providers/s3/types/options.ts @@ -3,7 +3,10 @@ import { StorageAccessLevel } from '@aws-amplify/core'; import { AWSCredentials } from '@aws-amplify/core/internals/utils'; -import { SigningOptions } from '@aws-amplify/core/internals/aws-client-utils'; +import { + CredentialsProviderOptions, + SigningOptions, +} from '@aws-amplify/core/internals/aws-client-utils'; import { TransferProgressEvent } from '../../../types'; import { @@ -15,9 +18,9 @@ import { /** * @internal */ -export type LocationCredentialsProvider = (options?: { - forceRefresh?: boolean; -}) => Promise<{ credentials: AWSCredentials }>; +export type LocationCredentialsProvider = ( + input?: CredentialsProviderOptions, +) => Promise<{ credentials: AWSCredentials }>; export interface BucketInfo { bucketName: string; diff --git a/packages/storage/src/providers/s3/utils/client/s3control/base.ts b/packages/storage/src/providers/s3/utils/client/s3control/base.ts index 380488bf4ac..a40f9f6a5dd 100644 --- a/packages/storage/src/providers/s3/utils/client/s3control/base.ts +++ b/packages/storage/src/providers/s3/utils/client/s3control/base.ts @@ -8,11 +8,10 @@ import { import { EndpointResolverOptions, getDnsSuffix, - getRetryDecider, jitteredBackoff, } from '@aws-amplify/core/internals/aws-client-utils'; -import { parseXmlError } from '../utils'; +import { retryDecider } from '../utils'; /** * The service name used to sign requests if the API requires authentication. @@ -64,7 +63,7 @@ const endpointResolver = ( export const defaultConfig = { service: SERVICE_NAME, endpointResolver, - retryDecider: getRetryDecider(parseXmlError), + retryDecider, computeDelay: jitteredBackoff, userAgentValue: getAmplifyUserAgent(), uriEscapePath: false, // Required by S3. See https://github.com/aws/aws-sdk-js-v3/blob/9ba012dfa3a3429aa2db0f90b3b0b3a7a31f9bc3/packages/signature-v4/src/SignatureV4.ts#L76-L83 diff --git a/packages/storage/src/providers/s3/utils/client/s3data/base.ts b/packages/storage/src/providers/s3/utils/client/s3data/base.ts index a31d6d5a2f1..d51c3a18a11 100644 --- a/packages/storage/src/providers/s3/utils/client/s3data/base.ts +++ b/packages/storage/src/providers/s3/utils/client/s3data/base.ts @@ -8,11 +8,10 @@ import { import { EndpointResolverOptions, getDnsSuffix, - getRetryDecider, jitteredBackoff, } from '@aws-amplify/core/internals/aws-client-utils'; -import { parseXmlError } from '../utils'; +import { retryDecider } from '../utils'; const DOMAIN_PATTERN = /^[a-z0-9][a-z0-9.-]{1,61}[a-z0-9]$/; const IP_ADDRESS_PATTERN = /(\d+\.){3}\d+/; @@ -106,7 +105,7 @@ export const isDnsCompatibleBucketName = (bucketName: string): boolean => export const defaultConfig = { service: SERVICE_NAME, endpointResolver, - retryDecider: getRetryDecider(parseXmlError), + retryDecider, computeDelay: jitteredBackoff, userAgentValue: getAmplifyUserAgent(), useAccelerateEndpoint: false, diff --git a/packages/storage/src/providers/s3/utils/client/s3data/completeMultipartUpload.ts b/packages/storage/src/providers/s3/utils/client/s3data/completeMultipartUpload.ts index 59a8e029afc..1e399e824e7 100644 --- a/packages/storage/src/providers/s3/utils/client/s3data/completeMultipartUpload.ts +++ b/packages/storage/src/providers/s3/utils/client/s3data/completeMultipartUpload.ts @@ -5,6 +5,8 @@ import { Endpoint, HttpRequest, HttpResponse, + MiddlewareContext, + RetryDeciderOutput, parseMetadata, } from '@aws-amplify/core/internals/aws-client-utils'; import { @@ -18,6 +20,7 @@ import { map, parseXmlBody, parseXmlError, + retryDecider, s3TransferHandler, serializePathnameObjectKey, validateS3RequiredParameter, @@ -136,25 +139,24 @@ const completeMultipartUploadDeserializer = async ( const retryWhenErrorWith200StatusCode = async ( response?: HttpResponse, error?: unknown, -): Promise => { + middlewareContext?: MiddlewareContext, +): Promise => { if (!response) { - return false; + return { retryable: false }; } if (response.statusCode === 200) { if (!response.body) { - return true; + return { retryable: true }; } const parsed = await parseXmlBody(response); if (parsed.Code !== undefined && parsed.Message !== undefined) { - return true; + return { retryable: true }; } - return false; + return { retryable: false }; } - const defaultRetryDecider = defaultConfig.retryDecider; - - return defaultRetryDecider(response, error); + return retryDecider(response, error, middlewareContext); }; export const completeMultipartUpload = composeServiceApi( diff --git a/packages/storage/src/providers/s3/utils/client/utils/index.ts b/packages/storage/src/providers/s3/utils/client/utils/index.ts index abfe9328d45..423987699f8 100644 --- a/packages/storage/src/providers/s3/utils/client/utils/index.ts +++ b/packages/storage/src/providers/s3/utils/client/utils/index.ts @@ -25,3 +25,4 @@ export { serializePathnameObjectKey, validateS3RequiredParameter, } from './serializeHelpers'; +export { retryDecider } from './retryDecider'; diff --git a/packages/storage/src/providers/s3/utils/client/utils/retryDecider.ts b/packages/storage/src/providers/s3/utils/client/utils/retryDecider.ts new file mode 100644 index 00000000000..3e1e0fcc3da --- /dev/null +++ b/packages/storage/src/providers/s3/utils/client/utils/retryDecider.ts @@ -0,0 +1,81 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +import { + HttpResponse, + MiddlewareContext, + RetryDeciderOutput, + getRetryDecider, +} from '@aws-amplify/core/internals/aws-client-utils'; + +import { LocationCredentialsProvider } from '../../../types/options'; + +import { parseXmlError } from './parsePayload'; + +/** + * Function to decide if the S3 request should be retried. For S3 APIs, we support forceRefresh option + * for {@link LocationCredentialsProvider | LocationCredentialsProvider } option. It's set when S3 returns + * credentials expired error. In the retry decider, we detect this response and set flag to signify a retry + * attempt. The retry attempt would invoke the LocationCredentialsProvider with forceRefresh option set. + * + * @param response Optional response of the request. + * @param error Optional error thrown from previous attempts. + * @param middlewareContext Optional context object to store data between retries. + * @returns True if the request should be retried. + */ +export const retryDecider = async ( + response?: HttpResponse, + error?: unknown, + middlewareContext?: MiddlewareContext, +): Promise => { + const defaultRetryDecider = getRetryDecider(parseXmlError); + const defaultRetryDecision = await defaultRetryDecider(response, error); + if (!response || response.statusCode < 300) { + return { retryable: false }; + } + const parsedError = await parseXmlError(response); + const errorCode = parsedError?.name; + const errorMessage = parsedError?.message; + const isCredentialsExpired = isCredentialsExpiredError( + errorCode, + errorMessage, + ); + + return { + retryable: + defaultRetryDecision.retryable || + // If we know the previous retry attempt sets isCredentialsExpired in the + // middleware context, we don't want to retry anymore. + !!(isCredentialsExpired && !middlewareContext?.isCredentialsExpired), + isCredentialsExpiredError: isCredentialsExpired, + }; +}; + +// Ref: https://github.com/aws/aws-sdk-js/blob/54829e341181b41573c419bd870dd0e0f8f10632/lib/event_listeners.js#L522-L541 +const INVALID_TOKEN_ERROR_CODES = [ + 'RequestExpired', + 'ExpiredTokenException', + 'ExpiredToken', +]; + +/** + * Given an error code, returns true if it is related to invalid credentials. + * + * @param errorCode String representation of some error. + * @returns True if given error indicates the credentials used to authorize request + * are invalid. + */ +const isCredentialsExpiredError = ( + errorCode?: string, + errorMessage?: string, +) => { + const isExpiredTokenError = + !!errorCode && INVALID_TOKEN_ERROR_CODES.includes(errorCode); + // Ref: https://github.com/aws/aws-sdk-js/blob/54829e341181b41573c419bd870dd0e0f8f10632/lib/event_listeners.js#L536-L539 + const isExpiredSignatureError = + !!errorCode && + !!errorMessage && + errorCode.includes('Signature') && + errorMessage.includes('expired'); + + return isExpiredTokenError || isExpiredSignatureError; +}; diff --git a/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts b/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts index ef3b4621c33..c07888871f6 100644 --- a/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts +++ b/packages/storage/src/providers/s3/utils/resolveS3ConfigAndInput.ts @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 import { AmplifyClassV6, StorageAccessLevel } from '@aws-amplify/core'; +import { CredentialsProviderOptions } from '@aws-amplify/core/internals/aws-client-utils'; import { assertValidationError } from '../../../errors/utils/assertValidationError'; import { StorageValidationErrorCode } from '../../../errors/types/validation'; @@ -76,14 +77,20 @@ export const resolveS3ConfigAndInput = async ( * used because the long-running tasks like multipart upload may span over the * credentials expiry. Auth.fetchAuthSession() automatically refreshes the * credentials if they are expired. + * + * The optional forceRefresh option is set when the S3 service returns expired + * tokens error in the previous API call attempt. */ - const credentialsProvider = async () => { + const credentialsProvider = async (options?: CredentialsProviderOptions) => { if (isLocationCredentialsProvider(apiOptions)) { assertStorageInput(apiInput); } + // TODO: forceRefresh option of fetchAuthSession would refresh both tokens and + // AWS credentials. So we do not support forceRefreshing from the Auth until + // we support refreshing only the credentials. const { credentials } = isLocationCredentialsProvider(apiOptions) - ? await apiOptions.locationCredentialsProvider() + ? await apiOptions.locationCredentialsProvider(options) : await amplify.Auth.fetchAuthSession(); assertValidationError( !!credentials, diff --git a/packages/storage/src/storageBrowser/apis/getDataAccess.ts b/packages/storage/src/storageBrowser/apis/getDataAccess.ts index 5e5bec23540..440a83e08cc 100644 --- a/packages/storage/src/storageBrowser/apis/getDataAccess.ts +++ b/packages/storage/src/storageBrowser/apis/getDataAccess.ts @@ -5,6 +5,7 @@ import { AmplifyErrorCode, StorageAction, } from '@aws-amplify/core/internals/utils'; +import { CredentialsProviderOptions } from '@aws-amplify/core/internals/aws-client-utils'; import { getStorageUserAgentValue } from '../../providers/s3/utils/userAgent'; import { getDataAccess as getDataAccessClient } from '../../providers/s3/utils/client/s3control'; @@ -18,11 +19,17 @@ export const getDataAccess = async ( input: GetDataAccessInput, ): Promise => { const targetType = input.scope.endsWith('*') ? undefined : 'Object'; - const { credentials } = await input.credentialsProvider(); + const clientCredentialsProvider = async ( + options?: CredentialsProviderOptions, + ) => { + const { credentials } = await input.credentialsProvider(options); + + return credentials; + }; const result = await getDataAccessClient( { - credentials, + credentials: clientCredentialsProvider, region: input.region, userAgentValue: getStorageUserAgentValue(StorageAction.GetDataAccess), }, diff --git a/packages/storage/src/storageBrowser/apis/listCallerAccessGrants.ts b/packages/storage/src/storageBrowser/apis/listCallerAccessGrants.ts index 957e6eb1fcb..12836d59880 100644 --- a/packages/storage/src/storageBrowser/apis/listCallerAccessGrants.ts +++ b/packages/storage/src/storageBrowser/apis/listCallerAccessGrants.ts @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 import { StorageAction } from '@aws-amplify/core/internals/utils'; +import { CredentialsProviderOptions } from '@aws-amplify/core/internals/aws-client-utils'; import { logger } from '../../utils'; import { listCallerAccessGrants as listCallerAccessGrantsClient } from '../../providers/s3/utils/client/s3control'; @@ -26,8 +27,10 @@ export const listCallerAccessGrants = async ( logger.debug(`defaulting pageSize to ${MAX_PAGE_SIZE}.`); } - const clientCredentialsProvider = async () => { - const { credentials } = await credentialsProvider(); + const clientCredentialsProvider = async ( + options?: CredentialsProviderOptions, + ) => { + const { credentials } = await credentialsProvider(options); return credentials; }; diff --git a/packages/storage/src/storageBrowser/types.ts b/packages/storage/src/storageBrowser/types.ts index c770b7472a3..a09492bfb65 100644 --- a/packages/storage/src/storageBrowser/types.ts +++ b/packages/storage/src/storageBrowser/types.ts @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 import { AWSCredentials } from '@aws-amplify/core/internals/utils'; +import { CredentialsProviderOptions } from '@aws-amplify/core/internals/aws-client-utils'; import { LocationCredentialsProvider } from '../providers/s3/types/options'; @@ -13,9 +14,9 @@ export type Permission = 'READ' | 'READWRITE' | 'WRITE'; /** * @internal */ -export type CredentialsProvider = (options?: { - forceRefresh?: boolean; -}) => Promise<{ credentials: AWSCredentials }>; +export type CredentialsProvider = ( + options?: CredentialsProviderOptions, +) => Promise<{ credentials: AWSCredentials }>; /** * @internal