diff --git a/cache_test.go b/cache_test.go index 33f0517..e1eeb07 100644 --- a/cache_test.go +++ b/cache_test.go @@ -226,7 +226,7 @@ const prefixKey = "#prefix#" func TestPrefixKey(t *testing.T) { memoryStore := persist.NewMemoryStore(1 * time.Minute) - cacheURIMiddleware := CacheByRequestPath( + cachePathMiddleware := CacheByRequestPath( memoryStore, 3*time.Second, WithPrefixKey(prefixKey), @@ -234,11 +234,47 @@ func TestPrefixKey(t *testing.T) { requestPath := "/cache" - w1 := mockHttpRequest(cacheURIMiddleware, requestPath, true) + w1 := mockHttpRequest(cachePathMiddleware, requestPath, true) err := memoryStore.Delete(prefixKey + requestPath) require.NoError(t, err) - w2 := mockHttpRequest(cacheURIMiddleware, requestPath, true) + w2 := mockHttpRequest(cachePathMiddleware, requestPath, true) assert.NotEqual(t, w1.Body, w2.Body) } + +func TestWithDiscardHeaders(t *testing.T) { + const headerKey = "RandKey" + + memoryStore := persist.NewMemoryStore(1 * time.Minute) + cachePathMiddleware := CacheByRequestPath( + memoryStore, + 3*time.Second, + WithDiscardHeaders([]string{ + headerKey, + }), + ) + + _, engine := gin.CreateTestContext(httptest.NewRecorder()) + + engine.GET("/cache", cachePathMiddleware, func(c *gin.Context) { + c.Header(headerKey, fmt.Sprintf("rand:%d", rand.Int())) + c.String(http.StatusOK, "value") + }) + + testRequest := httptest.NewRequest(http.MethodGet, "/cache", nil) + + { + testWriter := httptest.NewRecorder() + engine.ServeHTTP(testWriter, testRequest) + headers1 := testWriter.Header() + assert.NotEqual(t, headers1.Get(headerKey), "") + } + + { + testWriter := httptest.NewRecorder() + engine.ServeHTTP(testWriter, testRequest) + headers2 := testWriter.Header() + assert.Equal(t, headers2.Get(headerKey), "") + } +} diff --git a/option.go b/option.go index 96671a7..47df4b1 100644 --- a/option.go +++ b/option.go @@ -161,7 +161,7 @@ func WithoutHeader() Option { } } -func DiscardHeaders(headers []string) Option { +func WithDiscardHeaders(headers []string) Option { return func(c *Config) { c.discardHeaders = headers }