From 6ed4ad08517c70918993d8712c10d7279af8b517 Mon Sep 17 00:00:00 2001 From: divyaac Date: Mon, 2 Dec 2024 11:44:03 -0800 Subject: [PATCH] Remove all references to current fragments, standbyfragments and partialMonthTracker (#29066) * Oss Changes Patch * Remove test from oss file --- builtin/logical/pki/acme_billing_test.go | 8 +- .../operator_usage_testonly_test.go | 2 +- sdk/helper/clientcountutil/clientcountutil.go | 27 +- .../clientcountutil/clientcountutil_test.go | 5 +- vault/activity_log.go | 468 ++++++++--------- vault/activity_log_test.go | 477 ++++++------------ vault/activity_log_testing_util.go | 84 ++- vault/activity_log_util_common.go | 19 +- vault/activity_log_util_common_test.go | 62 +-- .../acme_regeneration_test.go | 6 +- .../activity_testonly_oss_test.go | 2 +- .../activity_testonly_test.go | 24 +- .../logical_system_activity_write_testonly.go | 104 ++-- ...cal_system_activity_write_testonly_test.go | 167 ++++-- 14 files changed, 632 insertions(+), 823 deletions(-) diff --git a/builtin/logical/pki/acme_billing_test.go b/builtin/logical/pki/acme_billing_test.go index b1948d7be29c..f8db67e64478 100644 --- a/builtin/logical/pki/acme_billing_test.go +++ b/builtin/logical/pki/acme_billing_test.go @@ -104,15 +104,17 @@ func TestACMEBilling(t *testing.T) { expectedCount = validateClientCount(t, client, "ns2/pki", expectedCount+1, "unique identifier in a different namespace") // Check the current fragment - fragment := cluster.Cores[0].Core.ResetActivityLog()[0] - if fragment == nil { + localFragment, globalFragment := cluster.Cores[0].Core.ResetActivityLog() + if globalFragment == nil || localFragment == nil { t.Fatal("no fragment created") } - validateAcmeClientTypes(t, fragment, expectedCount) + validateAcmeClientTypes(t, localFragment[0], 0) + validateAcmeClientTypes(t, globalFragment[0], expectedCount) } func validateAcmeClientTypes(t *testing.T, fragment *activity.LogFragment, expectedCount int64) { t.Helper() + if int64(len(fragment.Clients)) != expectedCount { t.Fatalf("bad number of entities, expected %v: got %v, entities are: %v", expectedCount, len(fragment.Clients), fragment.Clients) } diff --git a/command/command_testonly/operator_usage_testonly_test.go b/command/command_testonly/operator_usage_testonly_test.go index 74d67291fd1f..4cdfc0536ac3 100644 --- a/command/command_testonly/operator_usage_testonly_test.go +++ b/command/command_testonly/operator_usage_testonly_test.go @@ -53,7 +53,7 @@ func TestOperatorUsageCommandRun(t *testing.T) { now := time.Now().UTC() - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(6, clientcountutil.WithClientType("entity")). NewClientsSeen(4, clientcountutil.WithClientType("non-entity-token")). diff --git a/sdk/helper/clientcountutil/clientcountutil.go b/sdk/helper/clientcountutil/clientcountutil.go index d09c5be13d33..85b25dab4348 100644 --- a/sdk/helper/clientcountutil/clientcountutil.go +++ b/sdk/helper/clientcountutil/clientcountutil.go @@ -280,39 +280,30 @@ func (d *ActivityLogDataGenerator) ToProto() *generation.ActivityLogMockInput { } // Write writes the data to the API with the given write options. The method -// returns the new paths that have been written. Note that the API endpoint will +// returns the new local and global paths that have been written. Note that the API endpoint will // only be present when Vault has been compiled with the "testonly" flag. -func (d *ActivityLogDataGenerator) Write(ctx context.Context, writeOptions ...generation.WriteOptions) ([]string, []string, []string, error) { +func (d *ActivityLogDataGenerator) Write(ctx context.Context, writeOptions ...generation.WriteOptions) ([]string, []string, error) { d.data.Write = writeOptions err := VerifyInput(d.data) if err != nil { - return nil, nil, nil, err + return nil, nil, err } data, err := d.ToJSON() if err != nil { - return nil, nil, nil, err + return nil, nil, err } resp, err := d.client.Logical().WriteWithContext(ctx, "sys/internal/counters/activity/write", map[string]interface{}{"input": string(data)}) if err != nil { - return nil, nil, nil, err + return nil, nil, err } if resp.Data == nil { - return nil, nil, nil, fmt.Errorf("received no data") - } - paths := resp.Data["paths"] - castedPaths, ok := paths.([]interface{}) - if !ok { - return nil, nil, nil, fmt.Errorf("invalid paths data: %v", paths) - } - returnPaths := make([]string, 0, len(castedPaths)) - for _, path := range castedPaths { - returnPaths = append(returnPaths, path.(string)) + return nil, nil, fmt.Errorf("received no data") } localPaths := resp.Data["local_paths"] localCastedPaths, ok := localPaths.([]interface{}) if !ok { - return nil, nil, nil, fmt.Errorf("invalid local paths data: %v", localPaths) + return nil, nil, fmt.Errorf("invalid local paths data: %v", localPaths) } returnLocalPaths := make([]string, 0, len(localCastedPaths)) for _, path := range localCastedPaths { @@ -322,13 +313,13 @@ func (d *ActivityLogDataGenerator) Write(ctx context.Context, writeOptions ...ge globalPaths := resp.Data["global_paths"] globalCastedPaths, ok := globalPaths.([]interface{}) if !ok { - return nil, nil, nil, fmt.Errorf("invalid global paths data: %v", globalPaths) + return nil, nil, fmt.Errorf("invalid global paths data: %v", globalPaths) } returnGlobalPaths := make([]string, 0, len(globalCastedPaths)) for _, path := range globalCastedPaths { returnGlobalPaths = append(returnGlobalPaths, path.(string)) } - return returnPaths, returnLocalPaths, returnGlobalPaths, nil + return returnLocalPaths, returnGlobalPaths, nil } // VerifyInput checks that the input data is valid diff --git a/sdk/helper/clientcountutil/clientcountutil_test.go b/sdk/helper/clientcountutil/clientcountutil_test.go index 4ea987fed025..637407436503 100644 --- a/sdk/helper/clientcountutil/clientcountutil_test.go +++ b/sdk/helper/clientcountutil/clientcountutil_test.go @@ -116,7 +116,7 @@ func TestNewCurrentMonthData_AddClients(t *testing.T) { // sent to the server is correct. func TestWrite(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := io.WriteString(w, `{"data":{"paths":["path1","path2"],"global_paths":["path2","path3"], "local_paths":["path3","path4"]}}`) + _, err := io.WriteString(w, `{"data":{"global_paths":["path2","path3"], "local_paths":["path3","path4"]}}`) require.NoError(t, err) body, err := io.ReadAll(r.Body) require.NoError(t, err) @@ -131,7 +131,7 @@ func TestWrite(t *testing.T) { Address: ts.URL, }) require.NoError(t, err) - paths, localPaths, globalPaths, err := NewActivityLogData(client). + localPaths, globalPaths, err := NewActivityLogData(client). NewPreviousMonthData(3). NewClientSeen(). NewPreviousMonthData(2). @@ -140,7 +140,6 @@ func TestWrite(t *testing.T) { NewCurrentMonthData().Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) require.NoError(t, err) - require.Equal(t, []string{"path1", "path2"}, paths) require.Equal(t, []string{"path2", "path3"}, globalPaths) require.Equal(t, []string{"path3", "path4"}, localPaths) } diff --git a/vault/activity_log.go b/vault/activity_log.go index 3ad43d31b479..757165f3e1f1 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -51,7 +51,6 @@ const ( distinctClientsBasePath = "log/distinctclients/" // for testing purposes (public as needed) - ActivityLogPrefix = "sys/counters/activity/log/" ActivityGlobalLogPrefix = "sys/counters/activity/global/log/" ActivityLogLocalPrefix = "sys/counters/activity/local/log/" ActivityPrefix = "sys/counters/activity/" @@ -147,8 +146,7 @@ type ActivityLog struct { // Acquire "l" before fragmentLock, globalFragmentLock, and localFragmentLock if all must be held. l sync.RWMutex - // fragmentLock protects enable, partialMonthClientTracker, fragment, - // standbyFragmentsReceived. + // fragmentLock protects enable fragmentLock sync.RWMutex // localFragmentLock protects partialMonthLocalClientTracker, localFragment, @@ -180,9 +178,6 @@ type ActivityLog struct { // could be adapted to use a secondary in the future. nodeID string - // current log fragment (may be nil) - fragment *activity.LogFragment - // Channel to signal a new fragment has been created // so it's appropriate to start the timer. newFragmentCh chan struct{} @@ -210,9 +205,6 @@ type ActivityLog struct { // track metadata and contents of the most recent local log segment currentLocalSegment segmentInfo - // Fragments received from performance standbys - standbyFragmentsReceived []*activity.LogFragment - // Local fragments received from performance standbys standbyLocalFragmentsReceived []*activity.LogFragment @@ -238,9 +230,6 @@ type ActivityLog struct { // for testing: is config currently being invalidated. protected by l configInvalidationInProgress bool - // partialMonthClientTracker tracks active clients this month. Protected by fragmentLock. - partialMonthClientTracker map[string]*activity.EntityRecord - // partialMonthLocalClientTracker tracks active local clients this month. Protected by localFragmentLock. partialMonthLocalClientTracker map[string]*activity.EntityRecord @@ -370,7 +359,6 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me newFragmentCh: make(chan struct{}, 1), sendCh: make(chan struct{}, 1), // buffered so it can be triggered by fragment size doneCh: make(chan struct{}, 1), - partialMonthClientTracker: make(map[string]*activity.EntityRecord), partialMonthLocalClientTracker: make(map[string]*activity.EntityRecord), newGlobalClientFragmentCh: make(chan struct{}, 1), globalPartialMonthClientTracker: make(map[string]*activity.EntityRecord), @@ -414,7 +402,6 @@ func NewActivityLog(core *Core, logger log.Logger, view *BarrierView, metrics me }, clientSequenceNumber: 0, }, - standbyFragmentsReceived: make([]*activity.LogFragment, 0), standbyLocalFragmentsReceived: make([]*activity.LogFragment, 0), standbyGlobalFragmentsReceived: make([]*activity.LogFragment, 0), secondaryGlobalClientFragments: make([]*activity.LogFragment, 0), @@ -462,14 +449,7 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for defer a.metrics.MeasureSinceWithLabels([]string{"core", "activity", "segment_write"}, a.clock.Now(), []metricsutil.Label{}) - // Swap out the pending regular fragments - a.fragmentLock.Lock() - currentFragment := a.fragment - a.fragment = nil - standbys := a.standbyFragmentsReceived - a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) - a.fragmentLock.Unlock() - + // Swap out the pending global fragments a.globalFragmentLock.Lock() secondaryGlobalClients := a.secondaryGlobalClientFragments a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) @@ -505,14 +485,10 @@ func (a *ActivityLog) saveCurrentSegmentToStorageLocked(ctx context.Context, for // If segment start time is zero, do not update or write // (even if force is true). This can happen if activityLog is // disabled after a save as been triggered. - if a.currentSegment.startTimestamp == 0 { + if a.currentGlobalSegment.startTimestamp == 0 { return nil } - if ret := a.createCurrentSegmentFromFragments(ctx, append(standbys, currentFragment), &a.currentSegment, force, ""); ret != nil { - return ret - } - // If we are the primary, store global clients // Create fragments from global clients and store the segment if !a.core.IsPerfSecondary() { @@ -575,7 +551,7 @@ func (a *ActivityLog) createCurrentSegmentFromFragments(ctx context.Context, fra // month when the client upgrades to 1.9, we must retain this functionality. for ns, val := range f.NonEntityTokens { // We track these pre-1.9 values in the old location, which is - // a.currentSegment.tokenCount, as opposed to the counter that stores tokens + // currentSegment.tokenCount, as opposed to the counter that stores tokens // without entities that have client IDs, namely // a.partialMonthClientTracker.nonEntityCountByNamespaceID. This preserves backward // compatibility for the precomputedQueryWorkers and the segment storing @@ -736,7 +712,7 @@ func parseSegmentNumberFromPath(path string) (int, bool) { // sorted last to first func (a *ActivityLog) availableLogs(ctx context.Context, upTo time.Time) ([]time.Time, error) { paths := make([]string, 0) - for _, basePath := range []string{activityEntityBasePath, activityLocalPathPrefix + activityEntityBasePath, activityGlobalPathPrefix + activityEntityBasePath, activityTokenLocalBasePath} { + for _, basePath := range []string{activityLocalPathPrefix + activityEntityBasePath, activityGlobalPathPrefix + activityEntityBasePath, activityTokenLocalBasePath} { p, err := a.view.List(ctx, basePath) if err != nil { return nil, err @@ -785,21 +761,17 @@ func (a *ActivityLog) getMostRecentActivityLogSegment(ctx context.Context, now t } // getLastEntitySegmentNumber returns the (non-negative) last segment number for the :startTime:, if it exists -func (a *ActivityLog) getLastEntitySegmentNumber(ctx context.Context, startTime time.Time) (uint64, uint64, uint64, bool, error) { - segmentHighestNum, segmentPresent, err := a.getLastSegmentNumberByEntityPath(ctx, activityEntityBasePath+fmt.Sprint(startTime.Unix())+"/") - if err != nil { - return 0, 0, 0, false, err - } +func (a *ActivityLog) getLastEntitySegmentNumber(ctx context.Context, startTime time.Time) (uint64, uint64, bool, error) { globalHighestNum, globalSegmentPresent, err := a.getLastSegmentNumberByEntityPath(ctx, activityGlobalPathPrefix+activityEntityBasePath+fmt.Sprint(startTime.Unix())+"/") if err != nil { - return 0, 0, 0, false, err + return 0, 0, false, err } localHighestNum, localSegmentPresent, err := a.getLastSegmentNumberByEntityPath(ctx, activityLocalPathPrefix+activityEntityBasePath+fmt.Sprint(startTime.Unix())+"/") if err != nil { - return 0, 0, 0, false, err + return 0, 0, false, err } - return segmentHighestNum, uint64(localHighestNum), uint64(globalHighestNum), (segmentPresent || localSegmentPresent || globalSegmentPresent), nil + return uint64(localHighestNum), uint64(globalHighestNum), (localSegmentPresent || globalSegmentPresent), nil } func (a *ActivityLog) getLastSegmentNumberByEntityPath(ctx context.Context, entityPath string) (uint64, bool, error) { @@ -829,30 +801,33 @@ func (a *ActivityLog) getLastSegmentNumberByEntityPath(ctx context.Context, enti // WalkEntitySegments loads each of the entity segments for a particular start time func (a *ActivityLog) WalkEntitySegments(ctx context.Context, startTime time.Time, hll *hyperloglog.Sketch, walkFn func(*activity.EntityActivityLog, time.Time, *hyperloglog.Sketch) error) error { - basePath := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" - pathList, err := a.view.List(ctx, basePath) - if err != nil { - return err - } + baseGlobalPath := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + baseLocalPath := activityLocalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" - for _, path := range pathList { - raw, err := a.view.Get(ctx, basePath+path) + for _, basePath := range []string{baseGlobalPath, baseLocalPath} { + pathList, err := a.view.List(ctx, basePath) if err != nil { return err } - if raw == nil { - a.logger.Warn("expected log segment not found", "startTime", startTime, "segment", path) - continue - } + for _, path := range pathList { + raw, err := a.view.Get(ctx, basePath+path) + if err != nil { + return err + } + if raw == nil { + a.logger.Warn("expected log segment not found", "startTime", startTime, "segment", path) + continue + } - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(raw.Value, out) - if err != nil { - return fmt.Errorf("unable to parse segment %v%v: %w", basePath, path, err) - } - err = walkFn(out, startTime, hll) - if err != nil { - return fmt.Errorf("unable to walk entities: %w", err) + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(raw.Value, out) + if err != nil { + return fmt.Errorf("unable to parse segment %v%v: %w", basePath, path, err) + } + err = walkFn(out, startTime, hll) + if err != nil { + return fmt.Errorf("unable to walk entities: %w", err) + } } } return nil @@ -889,69 +864,53 @@ func (a *ActivityLog) WalkTokenSegments(ctx context.Context, } // loadPriorEntitySegment populates the in-memory tracker for entity IDs that have -// been active "this month" -func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time.Time, sequenceNum uint64) error { - path := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) - data, err := a.view.Get(ctx, path) - if err != nil { - return err - } - if data == nil { - return nil - } - - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) - if err != nil { - return err - } - +// been active "this month". If the entity segment to load is global, globalPartialMonthClientTracker +// is updated else partialMonthLocalClientTracker gets updated. +func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time.Time, sequenceNum uint64, isLocal bool) error { a.l.RLock() defer a.l.RUnlock() + + // protecting a.enabled a.fragmentLock.Lock() - // Handle the (unlikely) case where the end of the month has been reached while background loading. - // Or the feature has been disabled. - if a.enabled && startTime.Unix() == a.currentSegment.startTimestamp { - for _, ent := range out.Clients { - a.partialMonthClientTracker[ent.ClientID] = ent - } - } - a.fragmentLock.Unlock() + defer a.fragmentLock.Unlock() // load all the active global clients - globalPath := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) - data, err = a.view.Get(ctx, globalPath) - if err != nil { - return err - } - if data == nil { - return nil - } - out = &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) - if err != nil { - return err - } - a.globalFragmentLock.Lock() - // Handle the (unlikely) case where the end of the month has been reached while background loading. - // Or the feature has been disabled. - if a.enabled && startTime.Unix() == a.currentGlobalSegment.startTimestamp { - for _, ent := range out.Clients { - a.globalPartialMonthClientTracker[ent.ClientID] = ent + if !isLocal { + globalPath := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) + data, err := a.view.Get(ctx, globalPath) + if err != nil { + return err + } + if data == nil { + return nil + } + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err } + a.globalFragmentLock.Lock() + // Handle the (unlikely) case where the end of the month has been reached while background loading. + // Or the feature has been disabled. + if a.enabled && startTime.Unix() == a.currentGlobalSegment.startTimestamp { + for _, ent := range out.Clients { + a.globalPartialMonthClientTracker[ent.ClientID] = ent + } + } + a.globalFragmentLock.Unlock() + return nil } - a.globalFragmentLock.Unlock() // load all the active local clients localPath := activityLocalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) - data, err = a.view.Get(ctx, localPath) + data, err := a.view.Get(ctx, localPath) if err != nil { return err } if data == nil { return nil } - out = &activity.EntityActivityLog{} + out := &activity.EntityActivityLog{} err = proto.Unmarshal(data.Value, out) if err != nil { return err @@ -970,75 +929,44 @@ func (a *ActivityLog) loadPriorEntitySegment(ctx context.Context, startTime time } // loadCurrentClientSegment loads the most recent segment (for "this month") -// into memory (to append new entries), and to the partialMonthClientTracker to +// into memory (to append new entries), and to the globalPartialMonthClientTracker and partialMonthLocalClientTracker to // avoid duplication call with fragmentLock, globalFragmentLock, localFragmentLock and l held. -func (a *ActivityLog) loadCurrentClientSegment(ctx context.Context, startTime time.Time, sequenceNum uint64, localSegmentSequenceNumber uint64, globalSegmentSequenceNumber uint64) error { - path := activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(sequenceNum, 10) - data, err := a.view.Get(ctx, path) - if err != nil { - return err - } - if data == nil { - return nil - } - - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) - if err != nil { - return err - } - - if !a.core.perfStandby { - a.currentSegment = segmentInfo{ - startTimestamp: startTime.Unix(), - currentClients: &activity.EntityActivityLog{ - Clients: out.Clients, - }, - tokenCount: a.currentSegment.tokenCount, - clientSequenceNumber: sequenceNum, - } - } else { - // populate this for edge case checking (if end of month passes while background loading on standby) - a.currentSegment.startTimestamp = startTime.Unix() - } - - for _, client := range out.Clients { - a.partialMonthClientTracker[client.ClientID] = client - } - +func (a *ActivityLog) loadCurrentClientSegment(ctx context.Context, startTime time.Time, localSegmentSequenceNumber uint64, globalSegmentSequenceNumber uint64) error { // load current global segment - path = activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(globalSegmentSequenceNumber, 10) - data, err = a.view.Get(ctx, path) - if err != nil { - return err - } - if data == nil { - return nil - } + path := activityGlobalPathPrefix + activityEntityBasePath + fmt.Sprint(startTime.Unix()) + "/" + strconv.FormatUint(globalSegmentSequenceNumber, 10) - out = &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) + // setting a.currentSegment timestamp to support upgrades + a.currentSegment.startTimestamp = startTime.Unix() + + data, err := a.view.Get(ctx, path) if err != nil { return err } + if data != nil { + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } - if !a.core.perfStandby { - a.currentGlobalSegment = segmentInfo{ - startTimestamp: startTime.Unix(), - currentClients: &activity.EntityActivityLog{ - Clients: out.Clients, - }, - tokenCount: &activity.TokenCount{ - CountByNamespaceID: make(map[string]uint64), - }, - clientSequenceNumber: sequenceNum, + if !a.core.perfStandby { + a.currentGlobalSegment = segmentInfo{ + startTimestamp: startTime.Unix(), + currentClients: &activity.EntityActivityLog{ + Clients: out.Clients, + }, + tokenCount: &activity.TokenCount{ + CountByNamespaceID: make(map[string]uint64), + }, + clientSequenceNumber: globalSegmentSequenceNumber, + } + } else { + // populate this for edge case checking (if end of month passes while background loading on standby) + a.currentGlobalSegment.startTimestamp = startTime.Unix() + } + for _, client := range out.Clients { + a.globalPartialMonthClientTracker[client.ClientID] = client } - } else { - // populate this for edge case checking (if end of month passes while background loading on standby) - a.currentGlobalSegment.startTimestamp = startTime.Unix() - } - for _, client := range out.Clients { - a.globalPartialMonthClientTracker[client.ClientID] = client } // load current local segment @@ -1047,31 +975,30 @@ func (a *ActivityLog) loadCurrentClientSegment(ctx context.Context, startTime ti if err != nil { return err } - if data == nil { - return nil - } - - out = &activity.EntityActivityLog{} - err = proto.Unmarshal(data.Value, out) - if err != nil { - return err - } + if data != nil { + out := &activity.EntityActivityLog{} + err = proto.Unmarshal(data.Value, out) + if err != nil { + return err + } - if !a.core.perfStandby { - a.currentLocalSegment = segmentInfo{ - startTimestamp: startTime.Unix(), - currentClients: &activity.EntityActivityLog{ - Clients: out.Clients, - }, - tokenCount: a.currentLocalSegment.tokenCount, - clientSequenceNumber: sequenceNum, + if !a.core.perfStandby { + a.currentLocalSegment = segmentInfo{ + startTimestamp: startTime.Unix(), + currentClients: &activity.EntityActivityLog{ + Clients: out.Clients, + }, + tokenCount: a.currentLocalSegment.tokenCount, + clientSequenceNumber: localSegmentSequenceNumber, + } + } else { + // populate this for edge case checking (if end of month passes while background loading on standby) + a.currentLocalSegment.startTimestamp = startTime.Unix() } - } else { - // populate this for edge case checking (if end of month passes while background loading on standby) - a.currentLocalSegment.startTimestamp = startTime.Unix() - } - for _, client := range out.Clients { - a.partialMonthLocalClientTracker[client.ClientID] = client + for _, client := range out.Clients { + a.partialMonthLocalClientTracker[client.ClientID] = client + } + } return nil @@ -1128,14 +1055,15 @@ func (a *ActivityLog) loadTokenCount(ctx context.Context, startTime time.Time) e // We must load the tokenCount of the current segment into the activity log // so that TWEs counted before the introduction of a client ID for TWEs are // still reported in the partial client counts. - a.currentSegment.tokenCount = out a.currentLocalSegment.tokenCount = out return nil } -// entityBackgroundLoader loads entity activity log records for start_date `t` -func (a *ActivityLog) entityBackgroundLoader(ctx context.Context, wg *sync.WaitGroup, t time.Time, seqNums <-chan uint64) { +// entityBackgroundLoader loads entity activity log records for start_date `t`. +// If isLocal is true, it loads the local entity activity log records else it +// loads global entity activity log records. +func (a *ActivityLog) entityBackgroundLoader(ctx context.Context, wg *sync.WaitGroup, t time.Time, seqNums <-chan uint64, isLocal bool) { defer wg.Done() for seqNum := range seqNums { select { @@ -1145,7 +1073,7 @@ func (a *ActivityLog) entityBackgroundLoader(ctx context.Context, wg *sync.WaitG default: } - err := a.loadPriorEntitySegment(ctx, t, seqNum) + err := a.loadPriorEntitySegment(ctx, t, seqNum, isLocal) if err != nil { a.logger.Error("error loading entity activity log", "time", t, "sequence", seqNum, "err", err) } @@ -1169,7 +1097,7 @@ func (a *ActivityLog) newMonthCurrentLogLocked(currentTime time.Time) { } // Initialize a new current segment, based on the given time -// should be called with fragmentLock, globalFragmentLock, localFragmentLock and l held. +// should be called with globalFragmentLock, localFragmentLock and l held. func (a *ActivityLog) newSegmentAtGivenTime(t time.Time) { timestamp := t.Unix() @@ -1182,26 +1110,17 @@ func (a *ActivityLog) newSegmentAtGivenTime(t time.Time) { // should be called with l held. func (a *ActivityLog) setCurrentSegmentTimeLocked(t time.Time) { timestamp := t.Unix() - a.currentSegment.startTimestamp = timestamp a.currentGlobalSegment.startTimestamp = timestamp a.currentLocalSegment.startTimestamp = timestamp + // setting a.currentSegment timestamp to support upgrades + a.currentSegment.startTimestamp = timestamp } // Reset all the current segment state. -// Should be called with fragmentLock, globalFragmentLock, localFragmentLock and l held. +// Should be called with globalFragmentLock, localFragmentLock and l held. func (a *ActivityLog) resetCurrentLog() { + // setting a.currentSegment timestamp to support upgrades a.currentSegment.startTimestamp = 0 - a.currentSegment.currentClients = &activity.EntityActivityLog{ - Clients: make([]*activity.EntityRecord, 0), - } - - // We must still initialize the tokenCount to recieve tokenCounts from fragments - // during the month where customers upgrade to 1.9 - a.currentSegment.tokenCount = &activity.TokenCount{ - CountByNamespaceID: make(map[string]uint64), - } - - a.currentSegment.clientSequenceNumber = 0 // global segment a.currentGlobalSegment.startTimestamp = 0 @@ -1217,16 +1136,12 @@ func (a *ActivityLog) resetCurrentLog() { } a.currentLocalSegment.clientSequenceNumber = 0 - a.fragment = nil - a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) - a.currentGlobalFragment = nil a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) a.localFragment = nil a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) - a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) a.standbyLocalFragmentsReceived = make([]*activity.LogFragment, 0) a.standbyGlobalFragmentsReceived = make([]*activity.LogFragment, 0) a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) @@ -1234,7 +1149,6 @@ func (a *ActivityLog) resetCurrentLog() { func (a *ActivityLog) deleteLogWorker(ctx context.Context, startTimestamp int64, whenDone chan struct{}) { entityPathsToDelete := make([]string, 0) - entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%v%v/", activityEntityBasePath, startTimestamp)) entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%s%v%v/", activityGlobalPathPrefix, activityEntityBasePath, startTimestamp)) entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%s%v%v/", activityLocalPathPrefix, activityEntityBasePath, startTimestamp)) entityPathsToDelete = append(entityPathsToDelete, fmt.Sprintf("%v%v/", activityTokenLocalBasePath, startTimestamp)) @@ -1350,7 +1264,7 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro } // load entity logs from storage into memory - lastSegment, localLastSegment, globalLastSegment, segmentsExist, err := a.getLastEntitySegmentNumber(ctx, mostRecent) + localLastSegment, globalLastSegment, segmentsExist, err := a.getLastEntitySegmentNumber(ctx, mostRecent) if err != nil { return err } @@ -1359,20 +1273,39 @@ func (a *ActivityLog) refreshFromStoredLog(ctx context.Context, wg *sync.WaitGro return nil } - err = a.loadCurrentClientSegment(ctx, mostRecent, lastSegment, localLastSegment, globalLastSegment) - if err != nil || lastSegment == 0 { + err = a.loadCurrentClientSegment(ctx, mostRecent, localLastSegment, globalLastSegment) + // if both localLastSegment and globalLastSegment are 0, it will return nil here + if err != nil || (localLastSegment == 0 && globalLastSegment == 0) { return err } - lastSegment-- - seqNums := make(chan uint64, lastSegment+1) - wg.Add(1) - go a.entityBackgroundLoader(ctx, wg, mostRecent, seqNums) + // if last local segment that got loaded using loadCurrentClientSegment is not 0, there are more local segments to load + if localLastSegment != 0 { + localLastSegment-- + + localSeqNums := make(chan uint64, localLastSegment+1) + wg.Add(1) + go a.entityBackgroundLoader(ctx, wg, mostRecent, localSeqNums, true) + + for n := int(localLastSegment); n >= 0; n-- { + localSeqNums <- uint64(n) + } + close(localSeqNums) + } + + // if last global segment that got loaded using loadCurrentClientSegment is not 0, there are more global segments to load + if globalLastSegment != 0 { + globalLastSegment-- - for n := int(lastSegment); n >= 0; n-- { - seqNums <- uint64(n) + globalSeqNums := make(chan uint64, globalLastSegment+1) + wg.Add(1) + go a.entityBackgroundLoader(ctx, wg, mostRecent, globalSeqNums, false) + + for n := int(globalLastSegment); n >= 0; n-- { + globalSeqNums <- uint64(n) + } + close(globalSeqNums) } - close(seqNums) return nil } @@ -1425,16 +1358,16 @@ func (a *ActivityLog) SetConfig(ctx context.Context, config activityConfig) { a.logger.Info("activity log enable changed", "original", originalEnabled, "current", a.enabled) } - if !a.enabled && a.currentSegment.startTimestamp != 0 && a.currentGlobalSegment.startTimestamp != 0 && a.currentLocalSegment.startTimestamp != 0 { + if !a.enabled && a.currentGlobalSegment.startTimestamp != 0 && a.currentLocalSegment.startTimestamp != 0 { a.logger.Trace("deleting current segment") a.deleteDone = make(chan struct{}) // this is called from a request under stateLock, so use activeContext - go a.deleteLogWorker(a.core.activeContext, a.currentSegment.startTimestamp, a.deleteDone) + go a.deleteLogWorker(a.core.activeContext, a.currentGlobalSegment.startTimestamp, a.deleteDone) a.resetCurrentLog() } forceSave := false - if a.enabled && a.currentSegment.startTimestamp == 0 && a.currentGlobalSegment.startTimestamp == 0 && a.currentLocalSegment.startTimestamp == 0 { + if a.enabled && a.currentGlobalSegment.startTimestamp == 0 && a.currentLocalSegment.startTimestamp == 0 { a.startNewCurrentLogLocked(a.clock.Now().UTC()) // Force a save so we can distinguish between // @@ -1453,7 +1386,6 @@ func (a *ActivityLog) SetConfig(ctx context.Context, config activityConfig) { if forceSave { // l is still held here - a.saveCurrentSegmentInternal(ctx, true, a.currentSegment, "") a.saveCurrentSegmentInternal(ctx, true, a.currentGlobalSegment, activityGlobalPathPrefix) a.saveCurrentSegmentInternal(ctx, true, a.currentLocalSegment, activityLocalPathPrefix) } @@ -1690,10 +1622,10 @@ func (a *ActivityLog) StartOfNextMonth() time.Time { a.l.RLock() defer a.l.RUnlock() var segmentStart time.Time - if a.currentSegment.startTimestamp == 0 { + if a.currentGlobalSegment.startTimestamp == 0 { segmentStart = a.clock.Now().UTC() } else { - segmentStart = time.Unix(a.currentSegment.startTimestamp, 0).UTC() + segmentStart = time.Unix(a.currentGlobalSegment.startTimestamp, 0).UTC() } // Basing this on the segment start will mean we trigger EOM rollover when // necessary because we were down. @@ -1868,12 +1800,6 @@ func (a *ActivityLog) perfStandbyFragmentWorker(ctx context.Context) { } sendFunc() - // clear active entity set - a.fragmentLock.Lock() - a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) - - a.fragmentLock.Unlock() - // clear local active entity set a.localFragmentLock.Lock() a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) @@ -1990,7 +1916,7 @@ func (a *ActivityLog) HandleEndOfMonth(ctx context.Context, currentTime time.Tim a.logger.Trace("starting end of month processing", "rolloverTime", currentTime) - err := a.writeIntentLog(ctx, a.currentSegment.startTimestamp, currentTime) + err := a.writeIntentLog(ctx, a.currentGlobalSegment.startTimestamp, currentTime) if err != nil { return err } @@ -2049,42 +1975,38 @@ func (a *ActivityLog) writeIntentLog(ctx context.Context, prevSegmentTimestamp i return nil } -// ResetActivityLog is used to extract the current fragment(s) during +// ResetActivityLog is used to extract the current local and global fragment(s) during // integration testing, so that it can be checked in a race-free way. -func (c *Core) ResetActivityLog() []*activity.LogFragment { +func (c *Core) ResetActivityLog() ([]*activity.LogFragment, []*activity.LogFragment) { c.stateLock.RLock() a := c.activityLog c.stateLock.RUnlock() if a == nil { - return nil + return nil, nil } - allFragments := make([]*activity.LogFragment, 1) - a.fragmentLock.Lock() - - allFragments[0] = a.fragment - a.fragment = nil - allFragments = append(allFragments, a.standbyFragmentsReceived...) - a.standbyFragmentsReceived = make([]*activity.LogFragment, 0) - a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) - a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) - a.fragmentLock.Unlock() + localFragments := make([]*activity.LogFragment, 0) + globalFragments := make([]*activity.LogFragment, 0) // local fragments a.localFragmentLock.Lock() - allFragments = append(allFragments, a.localFragment) + localFragments = append(localFragments, a.localFragment) a.localFragment = nil - allFragments = append(allFragments, a.standbyLocalFragmentsReceived...) + localFragments = append(localFragments, a.standbyLocalFragmentsReceived...) a.standbyLocalFragmentsReceived = make([]*activity.LogFragment, 0) a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) a.localFragmentLock.Unlock() // global fragments a.globalFragmentLock.Lock() + globalFragments = append(globalFragments, a.currentGlobalFragment) + a.currentGlobalFragment = nil + globalFragments = append(globalFragments, a.standbyGlobalFragmentsReceived...) a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) a.standbyGlobalFragmentsReceived = make([]*activity.LogFragment, 0) + a.secondaryGlobalClientFragments = make([]*activity.LogFragment, 0) a.globalFragmentLock.Unlock() - return allFragments + return localFragments, globalFragments } func (a *ActivityLog) AddEntityToFragment(entityID string, namespaceID string, timestamp int64) { @@ -2121,7 +2043,7 @@ func (a *ActivityLog) AddActivityToFragment(clientID string, namespaceID string, a.fragmentLock.RLock() if a.enabled { - _, presentInRegularClientMap := a.partialMonthClientTracker[clientID] + _, presentInRegularClientMap := a.globalPartialMonthClientTracker[clientID] _, presentInLocalClientmap := a.partialMonthLocalClientTracker[clientID] if presentInRegularClientMap || presentInLocalClientmap { present = true @@ -2146,7 +2068,7 @@ func (a *ActivityLog) AddActivityToFragment(clientID string, namespaceID string, defer a.globalFragmentLock.Unlock() // Re-check entity ID after re-acquiring lock - _, presentInRegularClientMap := a.partialMonthClientTracker[clientID] + _, presentInRegularClientMap := a.globalPartialMonthClientTracker[clientID] _, presentInLocalClientmap := a.partialMonthLocalClientTracker[clientID] if presentInRegularClientMap || presentInLocalClientmap { present = true @@ -2174,10 +2096,6 @@ func (a *ActivityLog) AddActivityToFragment(clientID string, namespaceID string, clientRecord.NonEntity = true } - // add the clients to the regular fragment - a.fragment.Clients = append(a.fragment.Clients, clientRecord) - a.partialMonthClientTracker[clientRecord.ClientID] = clientRecord - if local, _ := a.isClientLocal(clientRecord); local { // If the client is local then add the client to the current local fragment a.localFragment.Clients = append(a.localFragment.Clients, clientRecord) @@ -2212,17 +2130,10 @@ func (a *ActivityLog) isClientLocal(client *activity.EntityRecord) (bool, error) return false, nil } -// Create the fragments (regular fragment, local fragment and global fragment) if it doesn't already exist. +// Create the fragments (local fragment and global fragment) if it doesn't already exist. // Must be called with the fragmentLock, localFragmentLock and globalFragmentLock held. func (a *ActivityLog) createCurrentFragment() { - if a.fragment == nil { - // create regular fragment - a.fragment = &activity.LogFragment{ - OriginatingNode: a.nodeID, - Clients: make([]*activity.EntityRecord, 0, 120), - NonEntityTokens: make(map[string]uint64), - } - + if a.currentGlobalFragment == nil { // create local fragment a.localFragment = &activity.LogFragment{ OriginatingNode: a.nodeID, @@ -2232,6 +2143,7 @@ func (a *ActivityLog) createCurrentFragment() { // create global fragment a.currentGlobalFragment = &activity.LogFragment{ + OriginatingNode: a.nodeID, OriginatingCluster: a.core.ClusterID(), Clients: make([]*activity.EntityRecord, 0), } @@ -2293,7 +2205,6 @@ func (a *ActivityLog) receivedFragment(fragment *activity.LogFragment) { } for _, e := range fragment.Clients { - a.partialMonthClientTracker[e.ClientID] = e if isLocalFragment { a.partialMonthLocalClientTracker[e.ClientID] = e } else { @@ -2301,8 +2212,6 @@ func (a *ActivityLog) receivedFragment(fragment *activity.LogFragment) { } } - a.standbyFragmentsReceived = append(a.standbyFragmentsReceived, fragment) - if isLocalFragment { a.standbyLocalFragmentsReceived = append(a.standbyLocalFragmentsReceived, fragment) } else { @@ -2966,7 +2875,7 @@ func (a *ActivityLog) segmentToPrecomputedQuery(ctx context.Context, segmentTime // Iterate through entities, adding them to the hyperloglog and the summary maps in opts for { - entity, err := reader.ReadEntity(ctx) + entity, err := reader.ReadGlobalEntity(ctx) if errors.Is(err, io.EOF) { break } @@ -2981,6 +2890,23 @@ func (a *ActivityLog) segmentToPrecomputedQuery(ctx context.Context, segmentTime } } + for { + entity, err := reader.ReadLocalEntity(ctx) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + a.logger.Warn("failed to read segment", "error", err) + return err + } + err = a.handleEntitySegment(entity, segmentTime, hyperloglog, opts) + if err != nil { + a.logger.Warn("failed to handle entity segment", "error", err) + return err + } + + } + // Store the hyperloglog err = a.StoreHyperlogLog(ctx, segmentTime, hyperloglog) if err != nil { @@ -3133,7 +3059,7 @@ func (a *ActivityLog) precomputedQueryWorker(ctx context.Context, intent *Activi // too old, and startTimestamp should only go forward (unless it is zero.) // If there's an intent log, finish it even if the feature is currently disabled. a.l.RLock() - currentMonth := a.currentSegment.startTimestamp + currentMonth := a.currentGlobalSegment.startTimestamp // Base retention period on the month we are generating (even in the past)--- a.clock.Now() // would work but this will be easier to control in tests. retentionWindow := timeutil.MonthsPreviousTo(a.retentionMonths, time.Unix(intent.NextMonth, 0).UTC()) @@ -3272,7 +3198,7 @@ func (a *ActivityLog) PartialMonthMetrics(ctx context.Context) ([]metricsutil.Ga // Empty list return []metricsutil.GaugeLabelValues{}, nil } - count := len(a.partialMonthClientTracker) + count := len(a.globalPartialMonthClientTracker) + len(a.partialMonthLocalClientTracker) return []metricsutil.GaugeLabelValues{ { @@ -3298,7 +3224,7 @@ func (a *ActivityLog) populateNamespaceAndMonthlyBreakdowns() (map[int64]*proces // Parse the monthly clients and prepare the breakdowns. byNamespace := make(map[string]*processByNamespace) byMonth := make(map[int64]*processMonth) - for _, e := range a.partialMonthClientTracker { + for _, e := range a.globalPartialMonthClientTracker { processClientRecord(e, byNamespace, byMonth, a.clock.Now()) } for _, e := range a.partialMonthLocalClientTracker { diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index 4742d11467b8..1f36a7856582 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -34,7 +34,7 @@ import ( "github.com/stretchr/testify/require" ) -// TestActivityLog_Creation calls AddEntityToFragment and verifies that it appears correctly in a.fragment. +// TestActivityLog_Creation calls AddEntityToFragment and verifies that it appears correctly in a.currentGlobalFragment. func TestActivityLog_Creation(t *testing.T) { storage := &logical.InmemStorage{} coreConfig := &CoreConfig{ @@ -56,11 +56,13 @@ func TestActivityLog_Creation(t *testing.T) { if a.logger == nil || a.view == nil { t.Fatal("activity log not initialized") } - if a.fragment != nil || a.currentGlobalFragment != nil { - t.Fatal("activity log already has fragment") + currentGlobalFragment := core.GetActiveGlobalFragment() + if currentGlobalFragment != nil { + t.Fatal("activity log already has global fragment") } - if a.localFragment != nil { + localFragment := core.GetActiveLocalFragment() + if localFragment != nil { t.Fatal("activity log already has a local fragment") } @@ -69,44 +71,29 @@ func TestActivityLog_Creation(t *testing.T) { ts := time.Now() a.AddEntityToFragment(entity_id, namespace_id, ts.Unix()) - if a.fragment == nil || a.currentGlobalFragment == nil { + currentGlobalFragment = core.GetActiveGlobalFragment() + localFragment = core.GetActiveLocalFragment() + + if currentGlobalFragment == nil { t.Fatal("no fragment created") } - if a.fragment.OriginatingNode != a.nodeID { - t.Errorf("mismatched node ID, %q vs %q", a.fragment.OriginatingNode, a.nodeID) + if a.currentGlobalFragment.OriginatingNode != a.nodeID { + t.Errorf("mismatched node ID, %q vs %q", currentGlobalFragment.OriginatingNode, a.nodeID) } - if a.currentGlobalFragment.OriginatingCluster != a.core.ClusterID() { - t.Errorf("mismatched cluster ID, %q vs %q", a.currentGlobalFragment.GetOriginatingCluster(), a.core.ClusterID()) + if currentGlobalFragment.OriginatingCluster != a.core.ClusterID() { + t.Errorf("mismatched cluster ID, %q vs %q", currentGlobalFragment.GetOriginatingCluster(), a.core.ClusterID()) } - if a.fragment.Clients == nil || a.currentGlobalFragment.Clients == nil { + if currentGlobalFragment.Clients == nil { t.Fatal("no fragment entity slice") } - if a.fragment.NonEntityTokens == nil { - t.Fatal("no fragment token map") - } - - if len(a.fragment.Clients) != 1 { - t.Fatalf("wrong number of entities %v", len(a.fragment.Clients)) - } - if len(a.currentGlobalFragment.Clients) != 1 { - t.Fatalf("wrong number of entities %v", len(a.currentGlobalFragment.Clients)) + if len(currentGlobalFragment.Clients) != 1 { + t.Fatalf("wrong number of entities %v", len(currentGlobalFragment.Clients)) } - er := a.fragment.Clients[0] - if er.ClientID != entity_id { - t.Errorf("mimatched entity ID, %q vs %q", er.ClientID, entity_id) - } - if er.NamespaceID != namespace_id { - t.Errorf("mimatched namespace ID, %q vs %q", er.NamespaceID, namespace_id) - } - if er.Timestamp != ts.Unix() { - t.Errorf("mimatched timestamp, %v vs %v", er.Timestamp, ts.Unix()) - } - - er = a.currentGlobalFragment.Clients[0] + er := currentGlobalFragment.Clients[0] if er.ClientID != entity_id { t.Errorf("mimatched entity ID, %q vs %q", er.ClientID, entity_id) } @@ -118,22 +105,14 @@ func TestActivityLog_Creation(t *testing.T) { } // Reset and test the other code path - a.fragment = nil a.AddTokenToFragment(namespace_id) + currentGlobalFragment = core.GetActiveGlobalFragment() + localFragment = core.GetActiveLocalFragment() - if a.fragment == nil { + if currentGlobalFragment == nil { t.Fatal("no fragment created") } - if a.fragment.NonEntityTokens == nil { - t.Fatal("no fragment token map") - } - - actual := a.fragment.NonEntityTokens[namespace_id] - if actual != 1 { - t.Errorf("mismatched number of tokens, %v vs %v", actual, 1) - } - // test local fragment localMe := &MountEntry{ Table: credentialTableType, @@ -149,24 +128,25 @@ func TestActivityLog_Creation(t *testing.T) { local_ts := time.Now() a.AddClientToFragment(local_entity_id, "root", local_ts.Unix(), false, "local_mount_accessor") + localFragment = core.GetActiveLocalFragment() - if a.localFragment.OriginatingNode != a.nodeID { - t.Errorf("mismatched node ID, %q vs %q", a.localFragment.OriginatingNode, a.nodeID) + if localFragment.OriginatingNode != a.nodeID { + t.Errorf("mismatched node ID, %q vs %q", localFragment.OriginatingNode, a.nodeID) } - if a.localFragment.Clients == nil { + if localFragment.Clients == nil { t.Fatal("no local fragment entity slice") } - if a.localFragment.NonEntityTokens == nil { + if localFragment.NonEntityTokens == nil { t.Fatal("no local fragment token map") } - if len(a.localFragment.Clients) != 1 { - t.Fatalf("wrong number of entities %v", len(a.localFragment.Clients)) + if len(localFragment.Clients) != 1 { + t.Fatalf("wrong number of entities %v", len(localFragment.Clients)) } - er = a.localFragment.Clients[0] + er = localFragment.Clients[0] if er.ClientID != local_entity_id { t.Errorf("mimatched entity ID, %q vs %q", er.ClientID, local_entity_id) } @@ -192,17 +172,13 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { if a.logger == nil || a.view == nil { t.Fatal("activity log not initialized") } - a.fragmentLock.Lock() - if a.fragment != nil || a.currentGlobalFragment != nil { + if core.GetActiveGlobalFragment() != nil { t.Fatal("activity log already has fragment") } - a.fragmentLock.Unlock() - a.localFragmentLock.Lock() - if a.localFragment != nil { + if core.GetActiveLocalFragment() != nil { t.Fatal("activity log already has local fragment") } - a.localFragmentLock.Unlock() const namespace_id = "ns123" @@ -220,11 +196,9 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { t.Fatal(err) } - a.fragmentLock.Lock() - if a.fragment != nil || a.currentGlobalFragment != nil { + if core.GetActiveGlobalFragment() != nil { t.Fatal("fragment created") } - a.fragmentLock.Unlock() teNew := &logical.TokenEntry{ Path: "test", @@ -240,11 +214,9 @@ func TestActivityLog_Creation_WrappingTokens(t *testing.T) { t.Fatal(err) } - a.fragmentLock.Lock() - if a.fragment != nil || a.currentGlobalFragment != nil { + if core.GetActiveGlobalFragment() != nil { t.Fatal("fragment created") } - a.fragmentLock.Unlock() } func checkExpectedEntitiesInMap(t *testing.T, a *ActivityLog, entityIDs []string) { @@ -280,36 +252,15 @@ func TestActivityLog_UniqueEntities(t *testing.T) { a.AddEntityToFragment(id2, "root", t3.Unix()) a.AddEntityToFragment(id1, "root", t3.Unix()) - if a.fragment == nil || a.currentGlobalFragment == nil { - t.Fatal("no current fragment") + currentGlobalFragment := core.GetActiveGlobalFragment() + if currentGlobalFragment == nil { + t.Fatal("no current global fragment") } - - if len(a.fragment.Clients) != 2 { - t.Fatalf("number of entities is %v", len(a.fragment.Clients)) - } - if len(a.currentGlobalFragment.Clients) != 2 { - t.Fatalf("number of entities is %v", len(a.currentGlobalFragment.Clients)) + if len(currentGlobalFragment.Clients) != 2 { + t.Fatalf("number of entities is %v", len(currentGlobalFragment.Clients)) } - for i, e := range a.fragment.Clients { - expectedID := id1 - expectedTime := t1.Unix() - expectedNS := "root" - if i == 1 { - expectedID = id2 - expectedTime = t2.Unix() - } - if e.ClientID != expectedID { - t.Errorf("%v: expected %q, got %q", i, expectedID, e.ClientID) - } - if e.NamespaceID != expectedNS { - t.Errorf("%v: expected %q, got %q", i, expectedNS, e.NamespaceID) - } - if e.Timestamp != expectedTime { - t.Errorf("%v: expected %v, got %v", i, expectedTime, e.Timestamp) - } - } - for i, e := range a.currentGlobalFragment.Clients { + for i, e := range currentGlobalFragment.Clients { expectedID := id1 expectedTime := t1.Unix() expectedNS := "root" @@ -410,11 +361,11 @@ func TestActivityLog_SaveTokensToStorage(t *testing.T) { if err != nil { t.Fatalf("got error writing tokens to storage: %v", err) } - if a.fragment != nil || a.currentGlobalFragment != nil { + if core.GetActiveGlobalFragment() != nil { t.Errorf("fragment was not reset after write to storage") } - if a.localFragment != nil { + if core.GetActiveLocalFragment() != nil { t.Errorf("local fragment was not reset after write to storage") } @@ -446,11 +397,12 @@ func TestActivityLog_SaveTokensToStorage(t *testing.T) { if err != nil { t.Fatalf("got error writing tokens to storage: %v", err) } - if a.fragment != nil || a.currentGlobalFragment != nil { + + if core.GetActiveGlobalFragment() != nil { t.Errorf("fragment was not reset after write to storage") } - if a.localFragment != nil { + if core.GetActiveLocalFragment() != nil { t.Errorf("local fragment was not reset after write to storage") } @@ -492,7 +444,8 @@ func TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount(t *testing.T) { a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment tokenPath := fmt.Sprintf("%sdirecttokens/%d/0", ActivityLogLocalPrefix, a.GetStartTimestamp()) - clientPath := fmt.Sprintf("sys/counters/activity/log/entity/%d/0", a.GetStartTimestamp()) + clientPath := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/0", a.GetStartTimestamp()) + localPath := fmt.Sprintf("sys/counters/activity/local/log/entity/%d/0", a.GetStartTimestamp()) // Create some entries without entityIDs tokenEntryOne := logical.TokenEntry{NamespaceID: namespace.RootNamespaceID, Policies: []string{"hi"}} entityEntry := logical.TokenEntry{EntityID: "foo", NamespaceID: namespace.RootNamespaceID, Policies: []string{"hi"}} @@ -506,6 +459,9 @@ func TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount(t *testing.T) { } } + // verify that the client got added to a local fragment + require.Len(t, core.GetActiveLocalFragment().Clients, 1) + idEntity, isTWE := entityEntry.CreateClientID() for i := 0; i < 2; i++ { err := a.HandleTokenUsage(ctx, &entityEntry, idEntity, isTWE) @@ -513,35 +469,53 @@ func TestActivityLog_SaveTokensToStorageDoesNotUpdateTokenCount(t *testing.T) { t.Fatal(err) } } + + // verify that the client got added to the global fragment + require.Len(t, core.GetActiveGlobalFragment().Clients, 1) + err := a.saveCurrentSegmentToStorage(ctx, false) if err != nil { t.Fatalf("got error writing TWEs to storage: %v", err) } // Assert that new elements have been written to the fragment - if a.fragment != nil || a.currentGlobalFragment != nil { + if core.GetActiveGlobalFragment() != nil { t.Errorf("fragment was not reset after write to storage") } - if a.localFragment != nil { + if core.GetActiveLocalFragment() != nil { t.Errorf("local fragment was not reset after write to storage") } // Assert that no tokens have been written to the fragment readSegmentFromStorageNil(t, core, tokenPath) + allClients := make([]*activity.EntityRecord, 0) e := readSegmentFromStorage(t, core, clientPath) out := &activity.EntityActivityLog{} err = proto.Unmarshal(e.Value, out) if err != nil { t.Fatalf("could not unmarshal protobuf: %v", err) } - if len(out.Clients) != 2 { - t.Fatalf("added 3 distinct TWEs and 2 distinct entity tokens that should all result in the same ID, got: %d", len(out.Clients)) + if len(out.Clients) != 1 { + t.Fatalf("added 2 distinct entity tokens that should all result in the same ID, got: %d", len(out.Clients)) } + allClients = append(allClients, out.Clients...) + + e = readSegmentFromStorage(t, core, localPath) + out = &activity.EntityActivityLog{} + err = proto.Unmarshal(e.Value, out) + if err != nil { + t.Fatalf("could not unmarshal protobuf: %v", err) + } + if len(out.Clients) != 1 { + t.Fatalf("added 3 distinct TWEs that should all result in the same ID, got: %d", len(out.Clients)) + } + allClients = append(allClients, out.Clients...) + nonEntityTokenFlag := false entityTokenFlag := false - for _, client := range out.Clients { + for _, client := range allClients { if client.NonEntity == true { nonEntityTokenFlag = true if client.ClientID != idNonEntity { @@ -578,7 +552,6 @@ func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { now.Add(1 * time.Second).Unix(), now.Add(2 * time.Second).Unix(), } - path := fmt.Sprintf("%sentity/%d/0", ActivityLogPrefix, a.GetStartTimestamp()) globalPath := fmt.Sprintf("%sentity/%d/0", ActivityGlobalLogPrefix, a.GetStartTimestamp()) a.AddEntityToFragment(ids[0], "root", times[0]) @@ -587,14 +560,14 @@ func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { if err != nil { t.Fatalf("got error writing entities to storage: %v", err) } - if a.fragment != nil || a.currentGlobalFragment != nil { + if core.GetActiveGlobalFragment() != nil { t.Errorf("fragment was not reset after write to storage") } - if a.localFragment != nil { + if core.GetActiveLocalFragment() != nil { t.Errorf("local fragment was not reset after write to storage") } - protoSegment := readSegmentFromStorage(t, core, path) + protoSegment := readSegmentFromStorage(t, core, globalPath) out := &activity.EntityActivityLog{} err = proto.Unmarshal(protoSegment.Value, out) if err != nil { @@ -609,14 +582,6 @@ func TestActivityLog_SaveEntitiesToStorage(t *testing.T) { t.Fatalf("got error writing segments to storage: %v", err) } - protoSegment = readSegmentFromStorage(t, core, path) - out = &activity.EntityActivityLog{} - err = proto.Unmarshal(protoSegment.Value, out) - if err != nil { - t.Fatalf("could not unmarshal protobuf: %v", err) - } - expectedEntityIDs(t, out, ids) - protoSegment = readSegmentFromStorage(t, core, globalPath) out = &activity.EntityActivityLog{} err = proto.Unmarshal(protoSegment.Value, out) @@ -686,7 +651,7 @@ func TestActivityLog_SaveEntitiesToStorageCommon(t *testing.T) { if err != nil { t.Fatalf("got error writing entities to storage: %v", err) } - if a.fragment != nil { + if core.GetActiveGlobalFragment() != nil || core.GetActiveLocalFragment() != nil { t.Errorf("fragment was not reset after write to storage") } @@ -775,8 +740,8 @@ func TestModifyResponseMonthsNilAppend(t *testing.T) { } // TestActivityLog_ReceivedFragment calls receivedFragment with a fragment and verifies it gets added to -// standbyFragmentsReceived and standbyGlobalFragmentsReceived. Send the same fragment again and then verify that it doesn't change the entity map but does -// get added to standbyFragmentsReceived and standbyGlobalFragmentsReceived. +// standbyGlobalFragmentsReceived. Send the same fragment again and then verify that it doesn't change the entity map but does +// get added to standbyGlobalFragmentsReceived. func TestActivityLog_ReceivedFragment(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog @@ -806,7 +771,7 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { NonEntityTokens: make(map[string]uint64), } - if len(a.standbyFragmentsReceived) != 0 { + if len(a.standbyGlobalFragmentsReceived) != 0 { t.Fatalf("fragment already received") } @@ -814,10 +779,6 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { checkExpectedEntitiesInMap(t, a, ids) - if len(a.standbyFragmentsReceived) != 1 { - t.Fatalf("fragment count is %v, expected 1", len(a.standbyFragmentsReceived)) - } - if len(a.standbyGlobalFragmentsReceived) != 1 { t.Fatalf("fragment count is %v, expected 1", len(a.standbyGlobalFragmentsReceived)) } @@ -827,9 +788,6 @@ func TestActivityLog_ReceivedFragment(t *testing.T) { checkExpectedEntitiesInMap(t, a, ids) - if len(a.standbyFragmentsReceived) != 2 { - t.Fatalf("fragment count is %v, expected 2", len(a.standbyFragmentsReceived)) - } if len(a.standbyGlobalFragmentsReceived) != 2 { t.Fatalf("fragment count is %v, expected 2", len(a.standbyGlobalFragmentsReceived)) } @@ -856,12 +814,17 @@ func TestActivityLog_availableLogs(t *testing.T) { // set up a few files in storage core, _, _ := TestCoreUnsealed(t) a := core.activityLog - paths := [...]string{"entity/1111/1", "entity/992/3"} + globalPaths := [...]string{"entity/1111/1", "entity/992/3", "entity/991/1"} + localPaths := [...]string{"entity/1111/1", "entity/992/3", "entity/990/1"} tokenPaths := [...]string{"directtokens/1111/1", "directtokens/1000000/1", "directtokens/992/1"} - expectedTimes := [...]time.Time{time.Unix(1000000, 0), time.Unix(1111, 0), time.Unix(992, 0)} + expectedTimes := [...]time.Time{time.Unix(1000000, 0), time.Unix(1111, 0), time.Unix(992, 0), time.Unix(991, 0), time.Unix(990, 0)} - for _, path := range paths { - WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test")) + for _, path := range globalPaths { + WriteToStorage(t, core, ActivityGlobalLogPrefix+path, []byte("test")) + } + + for _, path := range localPaths { + WriteToStorage(t, core, ActivityLogLocalPrefix+path, []byte("test")) } for _, path := range tokenPaths { @@ -950,7 +913,7 @@ func TestActivityLog_createRegenerationIntentLog(t *testing.T) { } for _, subPath := range paths { - fullPath := ActivityLogPrefix + subPath + fullPath := ActivityGlobalLogPrefix + subPath WriteToStorage(t, core, fullPath, []byte("test")) deletePaths = append(deletePaths, fullPath) } @@ -999,9 +962,9 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { a.SetStartTimestamp(time.Now().Unix()) // set a nonzero segment startTimestamp := a.GetStartTimestamp() - path0 := fmt.Sprintf("sys/counters/activity/log/entity/%d/0", startTimestamp) - path1 := fmt.Sprintf("sys/counters/activity/log/entity/%d/1", startTimestamp) - path2 := fmt.Sprintf("sys/counters/activity/log/entity/%d/2", startTimestamp) + path0 := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/0", startTimestamp) + path1 := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/1", startTimestamp) + path2 := fmt.Sprintf("sys/counters/activity/global/log/entity/%d/2", startTimestamp) tokenPath := fmt.Sprintf("sys/counters/activity/local/log/directtokens/%d/0", startTimestamp) genID := func(i int) string { @@ -1094,11 +1057,6 @@ func TestActivityLog_MultipleFragmentsAndSegments(t *testing.T) { t.Fatalf("got error writing entities to storage: %v", err) } - seqNum := a.GetEntitySequenceNumber() - if seqNum != 2 { - t.Fatalf("expected sequence number 2, got %v", seqNum) - } - protoSegment0 = readSegmentFromStorage(t, core, path0) err = proto.Unmarshal(protoSegment0.Value, &entityLog0) if err != nil { @@ -1305,12 +1263,8 @@ func TestActivityLog_parseSegmentNumberFromPath(t *testing.T) { func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) { core, _, _ := TestCoreUnsealed(t) a := core.activityLog - paths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/0", "entity/1111/1"} globalPaths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/1"} localPaths := [...]string{"entity/992/0", "entity/1000/-1", "entity/1001/foo", "entity/1111/0", "entity/1111/1"} - for _, path := range paths { - WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test")) - } for _, path := range globalPaths { WriteToStorage(t, core, ActivityGlobalLogPrefix+path, []byte("test")) } @@ -1320,42 +1274,36 @@ func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) { testCases := []struct { input int64 - expectedVal uint64 expectedGlobalVal uint64 expectedLocalVal uint64 expectExists bool }{ { input: 992, - expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: true, }, { input: 1000, - expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: false, }, { input: 1001, - expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: false, }, { input: 1111, - expectedVal: 1, expectedGlobalVal: 1, expectedLocalVal: 1, expectExists: true, }, { input: 2222, - expectedVal: 0, expectedGlobalVal: 0, expectedLocalVal: 0, expectExists: false, @@ -1364,16 +1312,13 @@ func TestActivityLog_getLastEntitySegmentNumber(t *testing.T) { ctx := context.Background() for _, tc := range testCases { - result, localSegmentNumber, globalSegmentNumber, exists, err := a.getLastEntitySegmentNumber(ctx, time.Unix(tc.input, 0)) + localSegmentNumber, globalSegmentNumber, exists, err := a.getLastEntitySegmentNumber(ctx, time.Unix(tc.input, 0)) if err != nil { t.Fatalf("unexpected error for input %d: %v", tc.input, err) } if exists != tc.expectExists { t.Errorf("expected result exists: %t, got: %t for input: %d", tc.expectExists, exists, tc.input) } - if result != tc.expectedVal { - t.Errorf("expected: %d got: %d for input: %d", tc.expectedVal, result, tc.input) - } if globalSegmentNumber != tc.expectedGlobalVal { t.Errorf("expected: %d got: %d for input: %d", tc.expectedGlobalVal, globalSegmentNumber, tc.input) } @@ -1505,15 +1450,6 @@ func (a *ActivityLog) resetEntitiesInMemory(t *testing.T) { a.globalFragmentLock.Lock() defer a.globalFragmentLock.Unlock() - a.currentSegment = segmentInfo{ - startTimestamp: time.Time{}.Unix(), - currentClients: &activity.EntityActivityLog{ - Clients: make([]*activity.EntityRecord, 0), - }, - tokenCount: a.currentSegment.tokenCount, - clientSequenceNumber: 0, - } - a.currentGlobalSegment = segmentInfo{ startTimestamp: time.Time{}.Unix(), currentClients: &activity.EntityActivityLog{ @@ -1532,7 +1468,6 @@ func (a *ActivityLog) resetEntitiesInMemory(t *testing.T) { clientSequenceNumber: 0, } - a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) } @@ -1549,7 +1484,6 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { CountByNamespaceID: tokenRecords, } a.l.Lock() - a.currentSegment.tokenCount = tokenCount a.currentLocalSegment.tokenCount = tokenCount a.l.Unlock() @@ -1610,7 +1544,6 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { if err != nil { t.Fatalf(err.Error()) } - WriteToStorage(t, core, ActivityLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityGlobalLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityLogLocalPrefix+tc.path, data) } @@ -1624,7 +1557,7 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { // loadCurrentClientSegment requires us to grab the fragment lock and the // activityLog lock, as per the comment in the loadCurrentClientSegment // function - err := a.loadCurrentClientSegment(ctx, time.Unix(tc.time, 0), tc.seqNum, tc.seqNum, tc.seqNum) + err := a.loadCurrentClientSegment(ctx, time.Unix(tc.time, 0), tc.seqNum, tc.seqNum) a.localFragmentLock.Unlock() a.globalFragmentLock.Unlock() a.fragmentLock.Unlock() @@ -1639,15 +1572,9 @@ func TestActivityLog_loadCurrentClientSegment(t *testing.T) { // verify accurate data in in-memory current segment require.Equal(t, tc.time, a.GetStartTimestamp()) - require.Equal(t, tc.seqNum, a.GetEntitySequenceNumber()) require.Equal(t, tc.seqNum, a.GetGlobalEntitySequenceNumber()) require.Equal(t, tc.seqNum, a.GetLocalEntitySequenceNumber()) - currentEntities := a.GetCurrentEntities() - if !entityRecordsEqual(t, currentEntities.Clients, tc.entities.Clients) { - t.Errorf("bad data loaded. expected: %v, got: %v for path %q", tc.entities.Clients, currentEntities, tc.path) - } - globalClients := core.GetActiveGlobalClientsList() if err := ActiveEntitiesEqual(globalClients, tc.entities.Clients); err != nil { t.Errorf("bad data loaded into active global entities. expected only set of EntityID from %v in %v for path %q: %v", tc.entities.Clients, globalClients, tc.path, err) @@ -1739,7 +1666,6 @@ func TestActivityLog_loadPriorEntitySegment(t *testing.T) { if err != nil { t.Fatalf(err.Error()) } - WriteToStorage(t, core, ActivityLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityGlobalLogPrefix+tc.path, data) WriteToStorage(t, core, ActivityLogLocalPrefix+tc.path, data) } @@ -1748,20 +1674,22 @@ func TestActivityLog_loadPriorEntitySegment(t *testing.T) { for _, tc := range testCases { if tc.refresh { a.l.Lock() - a.fragmentLock.Lock() a.localFragmentLock.Lock() - a.partialMonthClientTracker = make(map[string]*activity.EntityRecord) a.partialMonthLocalClientTracker = make(map[string]*activity.EntityRecord) a.globalPartialMonthClientTracker = make(map[string]*activity.EntityRecord) - a.currentSegment.startTimestamp = tc.time a.currentGlobalSegment.startTimestamp = tc.time a.currentLocalSegment.startTimestamp = tc.time - a.fragmentLock.Unlock() a.localFragmentLock.Unlock() a.l.Unlock() } - err := a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum) + // load global segments + err := a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum, false) + if err != nil { + t.Fatalf("got error loading data for %q: %v", tc.path, err) + } + // load local segments + err = a.loadPriorEntitySegment(ctx, time.Unix(tc.time, 0), tc.seqNum, true) if err != nil { t.Fatalf("got error loading data for %q: %v", tc.path, err) } @@ -1937,14 +1865,12 @@ func setupActivityRecordsInStorage(t *testing.T, base time.Time, includeEntities } switch i { case 0: - WriteToStorage(t, core, ActivityLogPrefix+"entity/"+fmt.Sprint(monthsAgo.Unix())+"/0", entityData) WriteToStorage(t, core, ActivityGlobalLogPrefix+"entity/"+fmt.Sprint(monthsAgo.Unix())+"/0", entityData) case len(entityRecords) - 1: // local data WriteToStorage(t, core, ActivityLogLocalPrefix+"entity/"+fmt.Sprint(base.Unix())+"/"+strconv.Itoa(i-1), entityData) default: - WriteToStorage(t, core, ActivityLogPrefix+"entity/"+fmt.Sprint(base.Unix())+"/"+strconv.Itoa(i-1), entityData) WriteToStorage(t, core, ActivityGlobalLogPrefix+"entity/"+fmt.Sprint(base.Unix())+"/"+strconv.Itoa(i-1), entityData) } } @@ -1988,16 +1914,24 @@ func TestActivityLog_refreshFromStoredLog(t *testing.T) { } wg.Wait() + // active clients for the entire month expectedActive := &activity.EntityActivityLog{ Clients: expectedClientRecords[1:], } - expectedCurrent := &activity.EntityActivityLog{ - Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], + expectedActiveGlobal := &activity.EntityActivityLog{ + Clients: expectedClientRecords[1 : len(expectedClientRecords)-1], } + + // local client is only added to the newest segment for the current month. This should also appear in the active clients for the entire month. expectedCurrentLocal := &activity.EntityActivityLog{ Clients: expectedClientRecords[len(expectedClientRecords)-1:], } + // global clients added to the newest local entity segment + expectedCurrent := &activity.EntityActivityLog{ + Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], + } + currentEntities := a.GetCurrentGlobalEntities() if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { // we only expect the newest entity segment to be loaded (for the current month) @@ -2021,6 +1955,19 @@ func TestActivityLog_refreshFromStoredLog(t *testing.T) { // we expect activeClients to be loaded for the entire month t.Errorf("bad data loaded into active entities. expected only set of EntityID from %v in %v: %v", expectedActive.Clients, activeClients, err) } + + // verify active global clients list + activeGlobalClients := a.core.GetActiveGlobalClientsList() + if err := ActiveEntitiesEqual(activeGlobalClients, expectedActiveGlobal.Clients); err != nil { + // we expect activeClients to be loaded for the entire month + t.Errorf("bad data loaded into active global entities. expected only set of EntityID from %v in %v: %v", expectedActiveGlobal.Clients, activeGlobalClients, err) + } + // verify active local clients list + activeLocalClients := a.core.GetActiveLocalClientsList() + if err := ActiveEntitiesEqual(activeLocalClients, expectedCurrentLocal.Clients); err != nil { + // we expect activeClients to be loaded for the entire month + t.Errorf("bad data loaded into active local entities. expected only set of EntityID from %v in %v: %v", expectedCurrentLocal.Clients, activeLocalClients, err) + } } // TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled writes data from 3 months ago to this month. The @@ -2082,6 +2029,18 @@ func TestActivityLog_refreshFromStoredLogWithBackgroundLoadingCancelled(t *testi // we only expect activeClients to be loaded for the newest segment (for the current month) t.Error(err) } + + // verify if the right global clients are loaded for the newest segment (for the current month) + activeGlobalClients := a.core.GetActiveGlobalClientsList() + if err := ActiveEntitiesEqual(activeGlobalClients, expectedCurrent.Clients); err != nil { + t.Error(err) + } + + // the right local clients are loaded for the newest segment (for the current month) + activeLocalClients := a.core.GetActiveLocalClientsList() + if err := ActiveEntitiesEqual(activeLocalClients, currentLocalEntities.Clients); err != nil { + t.Error(err) + } } // TestActivityLog_refreshFromStoredLogContextCancelled writes data from 3 months ago to this month and calls @@ -2115,9 +2074,6 @@ func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) { expectedActive := &activity.EntityActivityLog{ Clients: expectedClientRecords[1:], } - expectedCurrent := &activity.EntityActivityLog{ - Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], - } expectedCurrentGlobal := &activity.EntityActivityLog{ Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], } @@ -2125,12 +2081,6 @@ func TestActivityLog_refreshFromStoredLogNoTokens(t *testing.T) { Clients: expectedClientRecords[len(expectedClientRecords)-1:], } - currentEntities := a.GetCurrentEntities() - if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { - // we expect all segments for the current month to be loaded - t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities) - } - currentGlobalEntities := a.GetCurrentGlobalEntities() if !entityRecordsEqual(t, currentGlobalEntities.Clients, expectedCurrentGlobal.Clients) { // we only expect the newest entity segment to be loaded (for the current month) @@ -2174,7 +2124,7 @@ func TestActivityLog_refreshFromStoredLogNoEntities(t *testing.T) { t.Errorf("bad activity token counts loaded. expected: %v got: %v", expectedTokenCounts, nsCount) } - currentEntities := a.GetCurrentEntities() + currentEntities := a.GetCurrentGlobalEntities() if len(currentEntities.Clients) > 0 { t.Errorf("expected no current entity segment to be loaded. got: %v", currentEntities) } @@ -2245,7 +2195,7 @@ func TestActivityLog_refreshFromStoredLogPreviousMonth(t *testing.T) { Clients: expectedClientRecords[len(expectedClientRecords)-2 : len(expectedClientRecords)-1], } - currentEntities := a.GetCurrentEntities() + currentEntities := a.GetCurrentGlobalEntities() if !entityRecordsEqual(t, currentEntities.Clients, expectedCurrent.Clients) { // we only expect the newest entity segment to be loaded (for the current month) t.Errorf("bad activity entity logs loaded. expected: %v got: %v", expectedCurrent, currentEntities) @@ -2343,16 +2293,7 @@ func TestActivityLog_DeleteWorker(t *testing.T) { "entity/1112/1", } for _, path := range paths { - WriteToStorage(t, core, ActivityLogPrefix+path, []byte("test")) - } - - localPaths := []string{ - "entity/1111/1", - "entity/1111/2", - "entity/1111/3", - "entity/1112/1", - } - for _, path := range localPaths { + WriteToStorage(t, core, ActivityGlobalLogPrefix+path, []byte("test")) WriteToStorage(t, core, ActivityLogLocalPrefix+path, []byte("test")) } @@ -2376,14 +2317,14 @@ func TestActivityLog_DeleteWorker(t *testing.T) { } // Check segments still present - readSegmentFromStorage(t, core, ActivityLogPrefix+"entity/1112/1") + readSegmentFromStorage(t, core, ActivityGlobalLogPrefix+"entity/1112/1") readSegmentFromStorage(t, core, ActivityLogLocalPrefix+"entity/1112/1") readSegmentFromStorage(t, core, ActivityLogLocalPrefix+"directtokens/1112/1") // Check other segments not present - expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/1") - expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/2") - expectMissingSegment(t, core, ActivityLogPrefix+"entity/1111/3") + expectMissingSegment(t, core, ActivityGlobalLogPrefix+"entity/1111/1") + expectMissingSegment(t, core, ActivityGlobalLogPrefix+"entity/1111/2") + expectMissingSegment(t, core, ActivityGlobalLogPrefix+"entity/1111/3") expectMissingSegment(t, core, ActivityLogLocalPrefix+"entity/1111/1") expectMissingSegment(t, core, ActivityLogLocalPrefix+"entity/1111/2") expectMissingSegment(t, core, ActivityLogLocalPrefix+"entity/1111/3") @@ -2473,7 +2414,7 @@ func TestActivityLog_EnableDisable(t *testing.T) { } // verify segment exists - path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, seg1) + path := fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, seg1) readSegmentFromStorage(t, core, path) // Add in-memory fragment @@ -2503,7 +2444,7 @@ func TestActivityLog_EnableDisable(t *testing.T) { } // Verify empty segments are present - path = fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, seg2) + path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, seg2) readSegmentFromStorage(t, core, path) path = fmt.Sprintf("%vdirecttokens/%v/0", ActivityLogLocalPrefix, seg2) @@ -2544,6 +2485,8 @@ func TestActivityLog_EndOfMonth(t *testing.T) { id2 := "22222222-2222-2222-2222-222222222222" id3 := "33333333-3333-3333-3333-333333333333" id4 := "44444444-4444-4444-4444-444444444444" + + // add global data a.AddEntityToFragment(id1, "root", time.Now().Unix()) // add local data @@ -2567,22 +2510,13 @@ func TestActivityLog_EndOfMonth(t *testing.T) { a.HandleEndOfMonth(ctx, month1) // Check segment is present, with 1 entity - path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, segment0) + path := fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, segment0) protoSegment := readSegmentFromStorage(t, core, path) out := &activity.EntityActivityLog{} err = proto.Unmarshal(protoSegment.Value, out) if err != nil { t.Fatal(err) } - expectedEntityIDs(t, out, []string{id1, id4}) - - path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, segment0) - protoSegment = readSegmentFromStorage(t, core, path) - out = &activity.EntityActivityLog{} - err = proto.Unmarshal(protoSegment.Value, out) - if err != nil { - t.Fatal(err) - } expectedEntityIDs(t, out, []string{id1}) path = fmt.Sprintf("%ventity/%v/0", ActivityLogLocalPrefix, segment0) @@ -2649,18 +2583,6 @@ func TestActivityLog_EndOfMonth(t *testing.T) { for i, tc := range testCases { t.Logf("checking segment %v timestamp %v", i, tc.SegmentTimestamp) - expectedAllEntities := make([]string, 0) - expectedAllEntities = append(expectedAllEntities, tc.ExpectedGlobalEntityIDs...) - expectedAllEntities = append(expectedAllEntities, tc.ExpectedLocalEntityIDs...) - path := fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, tc.SegmentTimestamp) - protoSegment := readSegmentFromStorage(t, core, path) - out := &activity.EntityActivityLog{} - err = proto.Unmarshal(protoSegment.Value, out) - if err != nil { - t.Fatalf("could not unmarshal protobuf: %v", err) - } - expectedEntityIDs(t, out, expectedAllEntities) - // Check for global entities at global storage path path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, tc.SegmentTimestamp) protoSegment = readSegmentFromStorage(t, core, path) @@ -2844,7 +2766,7 @@ func TestActivityLog_CalculatePrecomputedQueriesWithMixedTWEs(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } expectedCounts := []struct { @@ -3125,10 +3047,10 @@ func TestActivityLog_SaveAfterDisable(t *testing.T) { t.Fatal(err) } - path := ActivityLogPrefix + "entity/0/0" + path := ActivityGlobalLogPrefix + "entity/0/0" expectMissingSegment(t, core, path) - path = fmt.Sprintf("%ventity/%v/0", ActivityLogPrefix, startTimestamp) + path = fmt.Sprintf("%ventity/%v/0", ActivityGlobalLogPrefix, startTimestamp) expectMissingSegment(t, core, path) } @@ -3229,7 +3151,7 @@ func TestActivityLog_Precompute(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } @@ -3540,7 +3462,7 @@ func TestActivityLog_Precompute_SkipMonth(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } @@ -3757,7 +3679,7 @@ func TestActivityLog_PrecomputeNonEntityTokensWithID(t *testing.T) { if err != nil { t.Fatal(err) } - path := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, segment.StartTime, segment.Segment) + path := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, segment.StartTime, segment.Segment) WriteToStorage(t, core, path, data) } @@ -4146,7 +4068,7 @@ func TestActivityLog_Deletion(t *testing.T) { for i, start := range times { // no entities in some months, just for fun for j := 0; j < (i+3)%5; j++ { - entityPath := fmt.Sprintf("%ventity/%v/%v", ActivityLogPrefix, start.Unix(), j) + entityPath := fmt.Sprintf("%ventity/%v/%v", ActivityGlobalLogPrefix, start.Unix(), j) paths[i] = append(paths[i], entityPath) WriteToStorage(t, core, entityPath, []byte("test")) } @@ -4522,7 +4444,7 @@ func TestActivityLog_partialMonthClientCountWithMultipleMountPaths(t *testing.T) if err != nil { t.Fatalf(err.Error()) } - storagePath := fmt.Sprintf("%sentity/%d/%d", ActivityLogPrefix, timeutil.StartOfMonth(now).Unix(), i) + storagePath := fmt.Sprintf("%sentity/%d/%d", ActivityGlobalLogPrefix, timeutil.StartOfMonth(now).Unix(), i) WriteToStorage(t, core, storagePath, entityData) } @@ -5162,7 +5084,6 @@ func TestAddActivityToFragment(t *testing.T) { a := core.activityLog a.SetEnable(true) - require.Nil(t, a.fragment) require.Nil(t, a.localFragment) require.Nil(t, a.currentGlobalFragment) @@ -5254,10 +5175,6 @@ func TestAddActivityToFragment(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { var mountAccessor string - a.fragmentLock.RLock() - numClientsBefore := len(a.fragment.Clients) - a.fragmentLock.RUnlock() - a.globalFragmentLock.RLock() globalClientsBefore := len(a.currentGlobalFragment.Clients) a.globalFragmentLock.RUnlock() @@ -5281,9 +5198,6 @@ func TestAddActivityToFragment(t *testing.T) { a.AddActivityToFragment(tc.id, ns, 0, tc.activityType, mount) } - a.fragmentLock.RLock() - defer a.fragmentLock.RUnlock() - numClientsAfter := len(a.fragment.Clients) a.globalFragmentLock.RLock() defer a.globalFragmentLock.RUnlock() globalClientsAfter := len(a.currentGlobalFragment.Clients) @@ -5312,24 +5226,6 @@ func TestAddActivityToFragment(t *testing.T) { } } - // for now local clients are added to both regular fragment and local fragment. - // this will be modified in ticket vault-31234 - if tc.isAdded { - require.Equal(t, numClientsBefore+1, numClientsAfter) - } else { - require.Equal(t, numClientsBefore, numClientsAfter) - } - - require.Contains(t, a.partialMonthClientTracker, tc.expectedID) - require.True(t, proto.Equal(&activity.EntityRecord{ - ClientID: tc.expectedID, - NamespaceID: ns, - Timestamp: 0, - NonEntity: tc.isNonEntity, - MountAccessor: mountAccessor, - ClientType: tc.activityType, - }, a.partialMonthClientTracker[tc.expectedID])) - if tc.isLocal { require.Contains(t, a.partialMonthLocalClientTracker, tc.expectedID) require.True(t, proto.Equal(&activity.EntityRecord{ @@ -5371,7 +5267,6 @@ func TestGetAllPartialMonthClients(t *testing.T) { a := core.activityLog a.SetEnable(true) - require.Nil(t, a.fragment) require.Nil(t, a.localFragment) require.Nil(t, a.currentGlobalFragment) @@ -5385,7 +5280,6 @@ func TestGetAllPartialMonthClients(t *testing.T) { a.AddActivityToFragment(clientID, ns, 0, entityActivityType, mount) require.NotNil(t, a.localFragment) - require.NotNil(t, a.fragment) require.NotNil(t, a.currentGlobalFragment) // create a local mount accessor @@ -5783,37 +5677,6 @@ func TestCreateSegment_StoreSegment(t *testing.T) { global: true, forceStore: true, }, - - { - testName: "[non-global] max segment size", - numClients: ActivitySegmentClientCapacity, - maxClientsPerFragment: ActivitySegmentClientCapacity, - global: false, - }, - { - testName: "[non-global] max segment size, multiple fragments", - numClients: ActivitySegmentClientCapacity, - maxClientsPerFragment: ActivitySegmentClientCapacity - 1, - global: false, - }, - { - testName: "[non-global] roll over", - numClients: ActivitySegmentClientCapacity + 2, - maxClientsPerFragment: ActivitySegmentClientCapacity, - global: false, - }, - { - testName: "[non-global] max segment size, rollover multiple fragments", - numClients: ActivitySegmentClientCapacity * 2, - maxClientsPerFragment: ActivitySegmentClientCapacity - 1, - global: false, - }, - { - testName: "[non-global] max client size, drop clients", - numClients: ActivitySegmentClientCapacity*2 + 1, - maxClientsPerFragment: ActivitySegmentClientCapacity, - global: false, - }, { testName: "[local] max client size, drop clients", numClients: ActivitySegmentClientCapacity*2 + 1, @@ -5910,10 +5773,7 @@ func TestCreateSegment_StoreSegment(t *testing.T) { segment := &a.currentGlobalSegment if !test.global { - segment = &a.currentSegment - if test.pathPrefix == activityLocalPathPrefix { - segment = &a.currentLocalSegment - } + segment = &a.currentLocalSegment } // Create segments and write to storage @@ -5932,24 +5792,13 @@ func TestCreateSegment_StoreSegment(t *testing.T) { clientTotal += len(entity.GetClients()) } } else { - if test.pathPrefix == activityLocalPathPrefix { - for { - entity, err := reader.ReadLocalEntity(ctx) - if errors.Is(err, io.EOF) { - break - } - require.NoError(t, err) - clientTotal += len(entity.GetClients()) - } - } else { - for { - entity, err := reader.ReadEntity(ctx) - if errors.Is(err, io.EOF) { - break - } - require.NoError(t, err) - clientTotal += len(entity.GetClients()) + for { + entity, err := reader.ReadLocalEntity(ctx) + if errors.Is(err, io.EOF) { + break } + require.NoError(t, err) + clientTotal += len(entity.GetClients()) } } diff --git a/vault/activity_log_testing_util.go b/vault/activity_log_testing_util.go index f9bb25ba142e..d0fd4b7b35ae 100644 --- a/vault/activity_log_testing_util.go +++ b/vault/activity_log_testing_util.go @@ -36,7 +36,7 @@ func (c *Core) InjectActivityLogDataThisMonth(t *testing.T) map[string]*activity Timestamp: c.activityLog.clock.Now().Unix(), NonEntity: i%2 == 0, } - c.activityLog.partialMonthClientTracker[er.ClientID] = er + c.activityLog.globalPartialMonthClientTracker[er.ClientID] = er } if constants.IsEnterprise { @@ -49,12 +49,12 @@ func (c *Core) InjectActivityLogDataThisMonth(t *testing.T) map[string]*activity Timestamp: c.activityLog.clock.Now().Unix(), NonEntity: i%2 == 0, } - c.activityLog.partialMonthClientTracker[er.ClientID] = er + c.activityLog.globalPartialMonthClientTracker[er.ClientID] = er } } } - return c.activityLog.partialMonthClientTracker + return c.activityLog.globalPartialMonthClientTracker } // GetActiveClients returns the in-memory globalPartialMonthClientTracker and partialMonthLocalClientTracker from an @@ -93,6 +93,7 @@ func (c *Core) GetActiveClientsList() []*activity.EntityRecord { return out } +// GetActiveLocalClientsList returns the active clients from globalPartialMonthClientTracker in activity log func (c *Core) GetActiveGlobalClientsList() []*activity.EntityRecord { out := []*activity.EntityRecord{} c.activityLog.globalFragmentLock.RLock() @@ -104,6 +105,7 @@ func (c *Core) GetActiveGlobalClientsList() []*activity.EntityRecord { return out } +// GetActiveLocalClientsList returns the active clients from partialMonthLocalClientTracker in activity log func (c *Core) GetActiveLocalClientsList() []*activity.EntityRecord { out := []*activity.EntityRecord{} c.activityLog.localFragmentLock.RLock() @@ -115,21 +117,14 @@ func (c *Core) GetActiveLocalClientsList() []*activity.EntityRecord { return out } -// GetCurrentEntities returns the current entity activity log -func (a *ActivityLog) GetCurrentEntities() *activity.EntityActivityLog { - a.l.RLock() - defer a.l.RUnlock() - return a.currentSegment.currentClients -} - -// GetCurrentGlobalEntities returns the current global entity activity log +// GetCurrentGlobalEntities returns the current clients from currentGlobalSegment in activity log func (a *ActivityLog) GetCurrentGlobalEntities() *activity.EntityActivityLog { a.l.RLock() defer a.l.RUnlock() return a.currentGlobalSegment.currentClients } -// GetCurrentLocalEntities returns the current local entity activity log +// GetCurrentLocalEntities returns the current clients from currentLocalSegment in activity log func (a *ActivityLog) GetCurrentLocalEntities() *activity.EntityActivityLog { a.l.RLock() defer a.l.RUnlock() @@ -169,8 +164,11 @@ func (a *ActivityLog) SetStandbyEnable(ctx context.Context, enabled bool) { // NOTE: AddTokenToFragment is deprecated and can no longer be used, except for // testing backward compatibility. Please use AddClientToFragment instead. func (a *ActivityLog) AddTokenToFragment(namespaceID string) { - a.fragmentLock.Lock() - defer a.fragmentLock.Unlock() + a.globalFragmentLock.Lock() + defer a.globalFragmentLock.Unlock() + + a.localFragmentLock.Lock() + defer a.localFragmentLock.Unlock() if !a.enabled { return @@ -178,7 +176,7 @@ func (a *ActivityLog) AddTokenToFragment(namespaceID string) { a.createCurrentFragment() - a.fragment.NonEntityTokens[namespaceID] += 1 + a.localFragment.NonEntityTokens[namespaceID] += 1 } func RandStringBytes(n int) string { @@ -199,20 +197,29 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart defer a.l.RUnlock() a.fragmentLock.RLock() defer a.fragmentLock.RUnlock() - if a.currentSegment.currentClients == nil { + if a.currentGlobalSegment.currentClients == nil { t.Fatalf("expected non-nil currentSegment.currentClients") } - if a.currentSegment.currentClients.Clients == nil { + if a.currentGlobalSegment.currentClients.Clients == nil { t.Errorf("expected non-nil currentSegment.currentClients.Entities") } - if a.currentSegment.tokenCount == nil { + if a.currentGlobalSegment.tokenCount == nil { t.Fatalf("expected non-nil currentSegment.tokenCount") } - if a.currentSegment.tokenCount.CountByNamespaceID == nil { + if a.currentGlobalSegment.tokenCount.CountByNamespaceID == nil { t.Errorf("expected non-nil currentSegment.tokenCount.CountByNamespaceID") } - if a.partialMonthClientTracker == nil { - t.Errorf("expected non-nil partialMonthClientTracker") + if a.currentLocalSegment.currentClients == nil { + t.Fatalf("expected non-nil currentSegment.currentClients") + } + if a.currentLocalSegment.currentClients.Clients == nil { + t.Errorf("expected non-nil currentSegment.currentClients.Entities") + } + if a.currentLocalSegment.tokenCount == nil { + t.Fatalf("expected non-nil currentSegment.tokenCount") + } + if a.currentLocalSegment.tokenCount.CountByNamespaceID == nil { + t.Errorf("expected non-nil currentSegment.tokenCount.CountByNamespaceID") } if a.partialMonthLocalClientTracker == nil { t.Errorf("expected non-nil partialMonthLocalClientTracker") @@ -220,14 +227,14 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart if a.globalPartialMonthClientTracker == nil { t.Errorf("expected non-nil globalPartialMonthClientTracker") } - if len(a.currentSegment.currentClients.Clients) > 0 { - t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentSegment.currentClients) + if len(a.currentGlobalSegment.currentClients.Clients) > 0 { + t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentGlobalSegment.currentClients) } - if len(a.currentSegment.tokenCount.CountByNamespaceID) > 0 { - t.Errorf("expected no token counts to be loaded. got: %v", a.currentSegment.tokenCount.CountByNamespaceID) + if len(a.currentLocalSegment.currentClients.Clients) > 0 { + t.Errorf("expected no current entity segment to be loaded. got: %v", a.currentLocalSegment.currentClients) } - if len(a.partialMonthClientTracker) > 0 { - t.Errorf("expected no active entity segment to be loaded. got: %v", a.partialMonthClientTracker) + if len(a.currentLocalSegment.tokenCount.CountByNamespaceID) > 0 { + t.Errorf("expected no token counts to be loaded. got: %v", a.currentLocalSegment.tokenCount.CountByNamespaceID) } if len(a.partialMonthLocalClientTracker) > 0 { t.Errorf("expected no active entity segment to be loaded. got: %v", a.partialMonthLocalClientTracker) @@ -237,17 +244,12 @@ func (a *ActivityLog) ExpectCurrentSegmentRefreshed(t *testing.T, expectedStart } if verifyTimeNotZero { - if a.currentSegment.startTimestamp == 0 { - t.Error("bad start timestamp. expected no reset but timestamp was reset") - } if a.currentGlobalSegment.startTimestamp == 0 { t.Error("bad start timestamp. expected no reset but timestamp was reset") } if a.currentLocalSegment.startTimestamp == 0 { t.Error("bad start timestamp. expected no reset but timestamp was reset") } - } else if a.currentSegment.startTimestamp != expectedStart { - t.Errorf("bad start timestamp. expected: %v got: %v", expectedStart, a.currentSegment.startTimestamp) } else if a.currentGlobalSegment.startTimestamp != expectedStart { t.Errorf("bad start timestamp. expected: %v got: %v", expectedStart, a.currentGlobalSegment.startTimestamp) } else if a.currentLocalSegment.startTimestamp != expectedStart { @@ -270,9 +272,7 @@ func ActiveEntitiesEqual(active []*activity.EntityRecord, test []*activity.Entit func (a *ActivityLog) GetStartTimestamp() int64 { a.l.RLock() defer a.l.RUnlock() - // TODO: We will substitute a.currentSegment with a.currentLocalSegment when we remove - // a.currentSegment from the code - if a.currentGlobalSegment.startTimestamp != a.currentSegment.startTimestamp { + if a.currentGlobalSegment.startTimestamp != a.currentLocalSegment.startTimestamp { return -1 } return a.currentGlobalSegment.startTimestamp @@ -282,7 +282,6 @@ func (a *ActivityLog) GetStartTimestamp() int64 { func (a *ActivityLog) SetStartTimestamp(timestamp int64) { a.l.Lock() defer a.l.Unlock() - a.currentSegment.startTimestamp = timestamp a.currentGlobalSegment.startTimestamp = timestamp a.currentLocalSegment.startTimestamp = timestamp } @@ -294,13 +293,6 @@ func (a *ActivityLog) GetStoredTokenCountByNamespaceID() map[string]uint64 { return a.currentLocalSegment.tokenCount.CountByNamespaceID } -// GetEntitySequenceNumber returns the current entity sequence number -func (a *ActivityLog) GetEntitySequenceNumber() uint64 { - a.l.RLock() - defer a.l.RUnlock() - return a.currentSegment.clientSequenceNumber -} - // GetGlobalEntitySequenceNumber returns the current entity sequence number func (a *ActivityLog) GetGlobalEntitySequenceNumber() uint64 { a.l.RLock() @@ -355,12 +347,6 @@ func (c *Core) GetActiveLocalFragment() *activity.LogFragment { return c.activityLog.localFragment } -func (c *Core) GetActiveFragment() *activity.LogFragment { - c.activityLog.fragmentLock.RLock() - defer c.activityLog.fragmentLock.RUnlock() - return c.activityLog.fragment -} - // StoreCurrentSegment is a test only method to create and store // segments from fragments. This allows createCurrentSegmentFromFragments to remain // private diff --git a/vault/activity_log_util_common.go b/vault/activity_log_util_common.go index c019d03a4739..f3cd616ed99a 100644 --- a/vault/activity_log_util_common.go +++ b/vault/activity_log_util_common.go @@ -425,7 +425,6 @@ type singleTypeSegmentReader struct { } type segmentReader struct { tokens *singleTypeSegmentReader - entities *singleTypeSegmentReader globalEntities *singleTypeSegmentReader localEntities *singleTypeSegmentReader } @@ -433,16 +432,11 @@ type segmentReader struct { // SegmentReader is an interface that provides methods to read tokens and entities in order type SegmentReader interface { ReadToken(ctx context.Context) (*activity.TokenCount, error) - ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) ReadGlobalEntity(ctx context.Context) (*activity.EntityActivityLog, error) ReadLocalEntity(ctx context.Context) (*activity.EntityActivityLog, error) } func (a *ActivityLog) NewSegmentFileReader(ctx context.Context, startTime time.Time) (SegmentReader, error) { - entities, err := a.newSingleTypeSegmentReader(ctx, startTime, activityEntityBasePath) - if err != nil { - return nil, err - } globalEntities, err := a.newSingleTypeSegmentReader(ctx, startTime, activityGlobalPathPrefix+activityEntityBasePath) if err != nil { return nil, err @@ -455,7 +449,7 @@ func (a *ActivityLog) NewSegmentFileReader(ctx context.Context, startTime time.T if err != nil { return nil, err } - return &segmentReader{entities: entities, globalEntities: globalEntities, localEntities: localEntities, tokens: tokens}, nil + return &segmentReader{globalEntities: globalEntities, localEntities: localEntities, tokens: tokens}, nil } func (a *ActivityLog) newSingleTypeSegmentReader(ctx context.Context, startTime time.Time, prefix string) (*singleTypeSegmentReader, error) { @@ -510,17 +504,6 @@ func (e *segmentReader) ReadToken(ctx context.Context) (*activity.TokenCount, er return out, nil } -// ReadEntity reads an entity from the segment -// If there is none available, then the error will be io.EOF -func (e *segmentReader) ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) { - out := &activity.EntityActivityLog{} - err := e.entities.nextValue(ctx, out) - if err != nil { - return nil, err - } - return out, nil -} - // ReadGlobalEntity reads a global entity from the global segment // If there is none available, then the error will be io.EOF func (e *segmentReader) ReadGlobalEntity(ctx context.Context) (*activity.EntityActivityLog, error) { diff --git a/vault/activity_log_util_common_test.go b/vault/activity_log_util_common_test.go index 7201cdc651f9..f84775da3fc2 100644 --- a/vault/activity_log_util_common_test.go +++ b/vault/activity_log_util_common_test.go @@ -1006,14 +1006,6 @@ func writeLocalEntitySegment(t *testing.T, core *Core, ts time.Time, index int, WriteToStorage(t, core, makeSegmentPath(t, activityLocalPathPrefix+activityEntityBasePath, ts, index), protoItem) } -// writeEntitySegment writes a single segment file with the given time and index for an entity -func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) { - t.Helper() - protoItem, err := proto.Marshal(item) - require.NoError(t, err) - WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, ts, index), protoItem) -} - // writeTokenSegment writes a single segment file with the given time and index for a token func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) { t.Helper() @@ -1037,7 +1029,6 @@ func TestSegmentFileReader_BadData(t *testing.T) { // write bad data that won't be able to be unmarshaled at index 0 WriteToStorage(t, core, makeSegmentPath(t, activityTokenLocalBasePath, now, 0), []byte("fake data")) - WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data")) WriteToStorage(t, core, makeSegmentPath(t, activityGlobalPathPrefix+activityEntityBasePath, now, 0), []byte("fake data")) WriteToStorage(t, core, makeSegmentPath(t, activityLocalPathPrefix+activityEntityBasePath, now, 0), []byte("fake data")) @@ -1047,8 +1038,6 @@ func TestSegmentFileReader_BadData(t *testing.T) { ClientID: "id", }, }} - writeEntitySegment(t, core, now, 1, entity) - // write global data at index 1 writeGlobalEntitySegment(t, core, now, 1, entity) @@ -1063,25 +1052,19 @@ func TestSegmentFileReader_BadData(t *testing.T) { reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) require.NoError(t, err) - // first the bad entity is read, which returns an error - _, err = reader.ReadEntity(context.Background()) - require.Error(t, err) - // then, the reader can read the good entity at index 1 - gotEntity, err := reader.ReadEntity(context.Background()) - require.True(t, proto.Equal(gotEntity, entity)) - require.Nil(t, err) - // first the bad global entity is read, which returns an error _, err = reader.ReadGlobalEntity(context.Background()) require.Error(t, err) + // then, the reader can read the good entity at index 1 - gotEntity, err = reader.ReadGlobalEntity(context.Background()) + gotEntity, err := reader.ReadGlobalEntity(context.Background()) require.True(t, proto.Equal(gotEntity, entity)) require.Nil(t, err) // first the bad local entity is read, which returns an error _, err = reader.ReadLocalEntity(context.Background()) require.Error(t, err) + // then, the reader can read the good entity at index 1 gotEntity, err = reader.ReadLocalEntity(context.Background()) require.True(t, proto.Equal(gotEntity, entity)) @@ -1090,6 +1073,7 @@ func TestSegmentFileReader_BadData(t *testing.T) { // the bad token causes an error _, err = reader.ReadToken(context.Background()) require.Error(t, err) + // but the good token is able to be read gotToken, err := reader.ReadToken(context.Background()) require.True(t, proto.Equal(gotToken, token)) @@ -1104,9 +1088,7 @@ func TestSegmentFileReader_MissingData(t *testing.T) { // write entities and tokens at indexes 0, 1, 2 for i := 0; i < 3; i++ { WriteToStorage(t, core, makeSegmentPath(t, activityTokenLocalBasePath, now, i), []byte("fake data")) - WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data")) WriteToStorage(t, core, makeSegmentPath(t, activityGlobalPathPrefix+activityEntityBasePath, now, i), []byte("fake data")) - } // write entity at index 3 entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{ @@ -1114,7 +1096,6 @@ func TestSegmentFileReader_MissingData(t *testing.T) { ClientID: "id", }, }} - writeEntitySegment(t, core, now, 3, entity) // write global entity at index 3 writeGlobalEntitySegment(t, core, now, 3, entity) @@ -1133,25 +1114,18 @@ func TestSegmentFileReader_MissingData(t *testing.T) { // delete the indexes 0, 1, 2 for i := 0; i < 3; i++ { require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityTokenLocalBasePath, now, i))) - require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityEntityBasePath, now, i))) require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityGlobalPathPrefix+activityEntityBasePath, now, i))) require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityLocalPathPrefix+activityEntityBasePath, now, i))) } // we expect the reader to only return the data at index 3, and then be done - gotEntity, err := reader.ReadEntity(context.Background()) - require.NoError(t, err) - require.True(t, proto.Equal(gotEntity, entity)) - _, err = reader.ReadEntity(context.Background()) - require.Equal(t, err, io.EOF) - gotToken, err := reader.ReadToken(context.Background()) require.NoError(t, err) require.True(t, proto.Equal(gotToken, token)) _, err = reader.ReadToken(context.Background()) require.Equal(t, err, io.EOF) - gotEntity, err = reader.ReadGlobalEntity(context.Background()) + gotEntity, err := reader.ReadGlobalEntity(context.Background()) require.NoError(t, err) require.True(t, proto.Equal(gotEntity, entity)) _, err = reader.ReadGlobalEntity(context.Background()) @@ -1170,7 +1144,7 @@ func TestSegmentFileReader_NoData(t *testing.T) { now := time.Now() reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) require.NoError(t, err) - entity, err := reader.ReadEntity(context.Background()) + entity, err := reader.ReadGlobalEntity(context.Background()) require.Nil(t, entity) require.Equal(t, err, io.EOF) token, err := reader.ReadToken(context.Background()) @@ -1196,7 +1170,8 @@ func TestSegmentFileReader(t *testing.T) { token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{ fmt.Sprintf("ns-%d", i): uint64(i), }} - writeEntitySegment(t, core, now, i, entity) + writeGlobalEntitySegment(t, core, now, i, entity) + writeLocalEntitySegment(t, core, now, i, entity) writeTokenSegment(t, core, now, i, token) entities = append(entities, entity) tokens = append(tokens, token) @@ -1205,13 +1180,20 @@ func TestSegmentFileReader(t *testing.T) { reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now) require.NoError(t, err) - gotEntities := make([]*activity.EntityActivityLog, 0, 3) + gotGlobalEntities := make([]*activity.EntityActivityLog, 0, 3) + gotLocalEntities := make([]*activity.EntityActivityLog, 0, 3) gotTokens := make([]*activity.TokenCount, 0, 3) - // read the entities from the reader - for entity, err := reader.ReadEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadEntity(context.Background()) { + // read the global entities from the reader + for entity, err := reader.ReadGlobalEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadGlobalEntity(context.Background()) { + require.NoError(t, err) + gotGlobalEntities = append(gotGlobalEntities, entity) + } + + // read the local entities from the reader + for entity, err := reader.ReadLocalEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadLocalEntity(context.Background()) { require.NoError(t, err) - gotEntities = append(gotEntities, entity) + gotLocalEntities = append(gotLocalEntities, entity) } // read the tokens from the reader @@ -1219,13 +1201,15 @@ func TestSegmentFileReader(t *testing.T) { require.NoError(t, err) gotTokens = append(gotTokens, token) } - require.Len(t, gotEntities, 3) + require.Len(t, gotGlobalEntities, 3) + require.Len(t, gotLocalEntities, 3) require.Len(t, gotTokens, 3) // verify that the entities and tokens we got from the reader are correct // we can't use require.Equals() here because there are protobuf differences in unexported fields for i := 0; i < 3; i++ { - require.True(t, proto.Equal(gotEntities[i], entities[i])) + require.True(t, proto.Equal(gotGlobalEntities[i], entities[i])) + require.True(t, proto.Equal(gotLocalEntities[i], entities[i])) require.True(t, proto.Equal(gotTokens[i], tokens[i])) } } diff --git a/vault/external_tests/activity_testonly/acme_regeneration_test.go b/vault/external_tests/activity_testonly/acme_regeneration_test.go index c663b174b84e..5d70dc0c21c5 100644 --- a/vault/external_tests/activity_testonly/acme_regeneration_test.go +++ b/vault/external_tests/activity_testonly/acme_regeneration_test.go @@ -54,7 +54,7 @@ func TestACMERegeneration_RegenerateWithCurrentMonth(t *testing.T) { }) require.NoError(t, err) now := time.Now().UTC() - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(3). // 3 months ago, 15 non-entity clients and 10 ACME clients NewClientsSeen(15, clientcountutil.WithClientType("non-entity-token")). @@ -116,7 +116,7 @@ func TestACMERegeneration_RegenerateMuchOlder(t *testing.T) { client := cluster.Cores[0].Client now := time.Now().UTC() - _, _, _, err := clientcountutil.NewActivityLogData(client). + _, _, err := clientcountutil.NewActivityLogData(client). NewPreviousMonthData(5). // 5 months ago, 15 non-entity clients and 10 ACME clients NewClientsSeen(15, clientcountutil.WithClientType("non-entity-token")). @@ -159,7 +159,7 @@ func TestACMERegeneration_RegeneratePreviousMonths(t *testing.T) { client := cluster.Cores[0].Client now := time.Now().UTC() - _, _, _, err := clientcountutil.NewActivityLogData(client). + _, _, err := clientcountutil.NewActivityLogData(client). NewPreviousMonthData(3). // 3 months ago, 15 non-entity clients and 10 ACME clients NewClientsSeen(15, clientcountutil.WithClientType("non-entity-token")). diff --git a/vault/external_tests/activity_testonly/activity_testonly_oss_test.go b/vault/external_tests/activity_testonly/activity_testonly_oss_test.go index 4b59142008b6..c5463bb801d2 100644 --- a/vault/external_tests/activity_testonly/activity_testonly_oss_test.go +++ b/vault/external_tests/activity_testonly/activity_testonly_oss_test.go @@ -29,7 +29,7 @@ func Test_ActivityLog_Disable(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(5). NewCurrentMonthData(). diff --git a/vault/external_tests/activity_testonly/activity_testonly_test.go b/vault/external_tests/activity_testonly/activity_testonly_test.go index 3e3a1259b2e3..cd9dfb21574b 100644 --- a/vault/external_tests/activity_testonly/activity_testonly_test.go +++ b/vault/external_tests/activity_testonly/activity_testonly_test.go @@ -1,7 +1,7 @@ // Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 -//go:build testonly +////go:build testonly package activity_testonly @@ -86,7 +86,7 @@ func Test_ActivityLog_LoseLeadership(t *testing.T) { }) require.NoError(t, err) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -121,7 +121,7 @@ func Test_ActivityLog_ClientsOverlapping(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(7). NewCurrentMonthData(). @@ -169,7 +169,7 @@ func Test_ActivityLog_ClientsNewCurrentMonth(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(5). NewCurrentMonthData(). @@ -203,7 +203,7 @@ func Test_ActivityLog_EmptyDataMonths(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10). Write(context.Background(), generation.WriteOptions_WRITE_PRECOMPUTED_QUERIES, generation.WriteOptions_WRITE_ENTITIES) @@ -243,7 +243,7 @@ func Test_ActivityLog_FutureEndDate(t *testing.T) { "enabled": "enable", }) require.NoError(t, err) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientsSeen(10). NewCurrentMonthData(). @@ -316,7 +316,7 @@ func Test_ActivityLog_ClientTypeResponse(t *testing.T) { _, err := client.Logical().Write("sys/internal/counters/config", map[string]interface{}{ "enabled": "enable", }) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10, clientcountutil.WithClientType(tc.clientType)). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -369,7 +369,7 @@ func Test_ActivityLogCurrentMonth_Response(t *testing.T) { _, err := client.Logical().Write("sys/internal/counters/config", map[string]interface{}{ "enabled": "enable", }) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10, clientcountutil.WithClientType(tc.clientType)). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -420,7 +420,7 @@ func Test_ActivityLog_Deduplication(t *testing.T) { _, err := client.Logical().Write("sys/internal/counters/config", map[string]interface{}{ "enabled": "enable", }) - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewPreviousMonthData(3). NewClientsSeen(10, clientcountutil.WithClientType(tc.clientType)). NewPreviousMonthData(2). @@ -462,7 +462,7 @@ func Test_ActivityLog_MountDeduplication(t *testing.T) { require.NoError(t, err) now := time.Now().UTC() - _, localPaths, globalPaths, err := clientcountutil.NewActivityLogData(client). + localPaths, globalPaths, err := clientcountutil.NewActivityLogData(client). NewPreviousMonthData(1). NewClientSeen(clientcountutil.WithClientMount("sys")). NewClientSeen(clientcountutil.WithClientMount("secret")). @@ -673,7 +673,7 @@ func Test_ActivityLog_Export_Sudo(t *testing.T) { rootToken := client.Token() - _, _, _, err = clientcountutil.NewActivityLogData(client). + _, _, err = clientcountutil.NewActivityLogData(client). NewCurrentMonthData(). NewClientsSeen(10). Write(context.Background(), generation.WriteOptions_WRITE_ENTITIES) @@ -849,7 +849,7 @@ func TestHandleQuery_MultipleMounts(t *testing.T) { } // Write all the client count data - _, _, _, err = activityLogGenerator.Write(context.Background(), generation.WriteOptions_WRITE_PRECOMPUTED_QUERIES, generation.WriteOptions_WRITE_ENTITIES) + _, _, err = activityLogGenerator.Write(context.Background(), generation.WriteOptions_WRITE_PRECOMPUTED_QUERIES, generation.WriteOptions_WRITE_ENTITIES) require.NoError(t, err) endOfCurrentMonth := timeutil.EndOfMonth(time.Now().UTC()) diff --git a/vault/logical_system_activity_write_testonly.go b/vault/logical_system_activity_write_testonly.go index 3f6f4caa5663..51fe65e61e6d 100644 --- a/vault/logical_system_activity_write_testonly.go +++ b/vault/logical_system_activity_write_testonly.go @@ -85,14 +85,13 @@ func (b *SystemBackend) handleActivityWriteData(ctx context.Context, request *lo for _, opt := range input.Write { opts[opt] = struct{}{} } - paths, localPaths, globalPaths, err := generated.write(ctx, opts, b.Core.activityLog, now) + localPaths, globalPaths, err := generated.write(ctx, opts, b.Core.activityLog, now) if err != nil { b.logger.Debug("failed to write activity log data", "error", err.Error()) return logical.ErrorResponse("failed to write data"), err } return &logical.Response{ Data: map[string]interface{}{ - "paths": paths, "local_paths": localPaths, "global_paths": globalPaths, }, @@ -101,15 +100,10 @@ func (b *SystemBackend) handleActivityWriteData(ctx context.Context, request *lo // singleMonthActivityClients holds a single month's client IDs, in the order they were seen type singleMonthActivityClients struct { - // clients are indexed by ID - clients []*activity.EntityRecord // globalClients are indexed by ID globalClients []*activity.EntityRecord // localClients are indexed by ID localClients []*activity.EntityRecord - // predefinedSegments map from the segment number to the client's index in - // the clients slice - predefinedSegments map[int][]int // predefinedGlobalSegments map from the segment number to the client's index in // the clients slice predefinedGlobalSegments map[int][]int @@ -126,17 +120,13 @@ type multipleMonthsActivityClients struct { months []*singleMonthActivityClients } -func (s *singleMonthActivityClients) addEntityRecord(core *Core, record *activity.EntityRecord, segmentIndex *int) { - s.clients = append(s.clients, record) - local, _ := core.activityLog.isClientLocal(record) +func (s *singleMonthActivityClients) addEntityRecord(core *Core, record *activity.EntityRecord, segmentIndex *int, local bool) { if !local { s.globalClients = append(s.globalClients, record) } else { s.localClients = append(s.localClients, record) } if segmentIndex != nil { - index := len(s.clients) - 1 - s.predefinedSegments[*segmentIndex] = append(s.predefinedSegments[*segmentIndex], index) if !local { globalIndex := len(s.globalClients) - 1 s.predefinedGlobalSegments[*segmentIndex] = append(s.predefinedGlobalSegments[*segmentIndex], globalIndex) @@ -230,9 +220,15 @@ func (s *singleMonthActivityClients) addNewClients(c *generation.Client, mountAc if c.Count > 1 { count = int(c.Count) } - isNonEntity := c.ClientType != entityActivityType ts := timeutil.MonthsPreviousTo(int(monthsAgo), now) + // identify is client is local or global + isLocal, err := isClientLocal(core, c.ClientType, mountAccessor) + if err != nil { + return err + } + + isNonEntity := c.ClientType != entityActivityType for i := 0; i < count; i++ { record := &activity.EntityRecord{ ClientID: c.Id, @@ -250,7 +246,7 @@ func (s *singleMonthActivityClients) addNewClients(c *generation.Client, mountAc } } - s.addEntityRecord(core, record, segmentIndex) + s.addEntityRecord(core, record, segmentIndex, isLocal) } return nil } @@ -359,13 +355,25 @@ func (m *multipleMonthsActivityClients) addRepeatedClients(monthsAgo int32, c *g repeatedFromMonth = c.RepeatedFromMonth } repeatedFrom := m.months[repeatedFromMonth] + + // identify is client is local or global + isLocal, err := isClientLocal(core, c.ClientType, mountAccessor) + if err != nil { + return err + } + numClients := 1 if c.Count > 0 { numClients = int(c.Count) } - for _, client := range repeatedFrom.clients { + + repeatedClients := repeatedFrom.globalClients + if isLocal { + repeatedClients = repeatedFrom.localClients + } + for _, client := range repeatedClients { if c.ClientType == client.ClientType && mountAccessor == client.MountAccessor && c.Namespace == client.NamespaceID { - addingTo.addEntityRecord(core, client, segmentIndex) + addingTo.addEntityRecord(core, client, segmentIndex, isLocal) numClients-- if numClients == 0 { break @@ -378,6 +386,23 @@ func (m *multipleMonthsActivityClients) addRepeatedClients(monthsAgo int32, c *g return nil } +// isClientLocal checks whether the given client is on a local mount. +// In all other cases, we will assume it is a global client. +func isClientLocal(core *Core, clientType string, mountAccessor string) (bool, error) { + // Tokens are not replicated to performance secondary clusters + if clientType == nonEntityTokenActivityType { + return true, nil + } + mountEntry := core.router.MatchingMountByAccessor(mountAccessor) + // If the mount entry is nil, this means the mount has been deleted. We will assume it was replicated because we do not want to + // over count clients + if mountEntry != nil && mountEntry.Local { + return true, nil + } + + return false, nil +} + func (m *multipleMonthsActivityClients) addMissingCurrentMonth() { missing := m.months[0].generationParameters == nil && len(m.months) > 1 && @@ -395,8 +420,7 @@ func (m *multipleMonthsActivityClients) timestampForMonth(i int, now time.Time) return now } -func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[generation.WriteOptions]struct{}, activityLog *ActivityLog, now time.Time) ([]string, []string, []string, error) { - paths := []string{} +func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[generation.WriteOptions]struct{}, activityLog *ActivityLog, now time.Time) ([]string, []string, error) { globalPaths := []string{} localPaths := []string{} @@ -411,30 +435,10 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene continue } timestamp := m.timestampForMonth(i, now) - segments, err := month.populateSegments(month.predefinedSegments, month.clients) - if err != nil { - return nil, nil, nil, err - } - for segmentIndex, segment := range segments { - if segment == nil { - // skip the index - continue - } - entityPath, err := activityLog.saveSegmentEntitiesInternal(ctx, segmentInfo{ - startTimestamp: timestamp.Unix(), - currentClients: &activity.EntityActivityLog{Clients: segment}, - clientSequenceNumber: uint64(segmentIndex), - tokenCount: &activity.TokenCount{}, - }, true, "") - if err != nil { - return nil, nil, nil, err - } - paths = append(paths, entityPath) - } if len(month.globalClients) > 0 { globalSegments, err := month.populateSegments(month.predefinedGlobalSegments, month.globalClients) if err != nil { - return nil, nil, nil, err + return nil, nil, err } for segmentIndex, segment := range globalSegments { if segment == nil { @@ -448,7 +452,7 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene tokenCount: &activity.TokenCount{}, }, true, activityGlobalPathPrefix) if err != nil { - return nil, nil, nil, err + return nil, nil, err } globalPaths = append(globalPaths, entityPath) } @@ -456,7 +460,7 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene if len(month.localClients) > 0 { localSegments, err := month.populateSegments(month.predefinedLocalSegments, month.localClients) if err != nil { - return nil, nil, nil, err + return nil, nil, err } for segmentIndex, segment := range localSegments { if segment == nil { @@ -470,7 +474,7 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene tokenCount: &activity.TokenCount{}, }, true, activityLocalPathPrefix) if err != nil { - return nil, nil, nil, err + return nil, nil, err } localPaths = append(localPaths, entityPath) } @@ -495,16 +499,16 @@ func (m *multipleMonthsActivityClients) write(ctx context.Context, opts map[gene if writeIntentLog { err := activityLog.writeIntentLog(ctx, m.latestTimestamp(now, false).Unix(), m.latestTimestamp(now, true).UTC()) if err != nil { - return nil, nil, nil, err + return nil, nil, err } } wg := sync.WaitGroup{} err := activityLog.refreshFromStoredLog(ctx, &wg, now) if err != nil { - return nil, nil, nil, err + return nil, nil, err } wg.Wait() - return paths, localPaths, globalPaths, nil + return localPaths, globalPaths, nil } func (m *multipleMonthsActivityClients) latestTimestamp(now time.Time, includeCurrentMonth bool) time.Time { @@ -532,7 +536,6 @@ func newMultipleMonthsActivityClients(numberOfMonths int) *multipleMonthsActivit } for i := 0; i < numberOfMonths; i++ { m.months[i] = &singleMonthActivityClients{ - predefinedSegments: make(map[int][]int), predefinedGlobalSegments: make(map[int][]int), predefinedLocalSegments: make(map[int][]int), } @@ -583,12 +586,3 @@ func (p *sliceSegmentReader) ReadLocalEntity(ctx context.Context) (*activity.Ent func (p *sliceSegmentReader) ReadToken(ctx context.Context) (*activity.TokenCount, error) { return nil, io.EOF } - -func (p *sliceSegmentReader) ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) { - if p.i == len(p.records) { - return nil, io.EOF - } - record := p.records[p.i] - p.i++ - return &activity.EntityActivityLog{Clients: record}, nil -} diff --git a/vault/logical_system_activity_write_testonly_test.go b/vault/logical_system_activity_write_testonly_test.go index 4df992172d2b..420e2079d0ef 100644 --- a/vault/logical_system_activity_write_testonly_test.go +++ b/vault/logical_system_activity_write_testonly_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/hashicorp/vault/builtin/credential/userpass" "github.com/hashicorp/vault/helper/namespace" "github.com/hashicorp/vault/helper/timeutil" "github.com/hashicorp/vault/sdk/helper/clientcountutil/generation" @@ -26,11 +27,12 @@ import ( // correctly validated func TestSystemBackend_handleActivityWriteData(t *testing.T) { testCases := []struct { - name string - operation logical.Operation - input map[string]interface{} - wantError error - wantPaths int + name string + operation logical.Operation + input map[string]interface{} + hasLocalClients bool + wantError error + wantPaths int }{ { name: "read fails", @@ -84,6 +86,13 @@ func TestSystemBackend_handleActivityWriteData(t *testing.T) { input: map[string]interface{}{"input": `{"write":["WRITE_ENTITIES"],"data":[{"current_month":true,"num_segments":3,"all":{"clients":[{"count":5}]}}]}`}, wantPaths: 3, }, + { + name: "entities with multiple segments", + operation: logical.UpdateOperation, + input: map[string]interface{}{"input": `{"write":["WRITE_ENTITIES"],"data":[{"current_month":true,"num_segments":3,"all":{"clients":[{"count":5, "mount":"cubbyhole/"}]}}]}`}, + hasLocalClients: true, + wantPaths: 3, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -95,8 +104,16 @@ func TestSystemBackend_handleActivityWriteData(t *testing.T) { require.Equal(t, tc.wantError, err, resp.Error()) } else { require.NoError(t, err) - paths := resp.Data["paths"].([]string) - require.Len(t, paths, tc.wantPaths) + globalPaths := resp.Data["global_paths"].([]string) + localPaths := resp.Data["local_paths"].([]string) + if tc.hasLocalClients { + require.Len(t, globalPaths, 0) + require.Len(t, localPaths, tc.wantPaths) + } else { + require.Len(t, globalPaths, tc.wantPaths) + require.Len(t, localPaths, 0) + } + } }) } @@ -116,6 +133,7 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { wantNamespace string wantMount string wantID string + isLocal bool segmentIndex *int }{ { @@ -153,6 +171,13 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { ClientType: "non-entity", }, }, + { + name: "non entity token client", + clients: &generation.Client{ + ClientType: nonEntityTokenActivityType, + }, + isLocal: true, + }, { name: "acme client", clients: &generation.Client{ @@ -169,8 +194,8 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { t.Run(tt.name, func(t *testing.T) { core, _, _ := TestCoreUnsealed(t) m := &singleMonthActivityClients{ - predefinedSegments: make(map[int][]int), predefinedGlobalSegments: make(map[int][]int), + predefinedLocalSegments: make(map[int][]int), } err := m.addNewClients(tt.clients, tt.mount, tt.segmentIndex, 0, time.Now().UTC(), core) require.NoError(t, err) @@ -178,8 +203,16 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { if numNew == 0 { numNew = 1 } - require.Len(t, m.clients, int(numNew)) - for i, rec := range m.clients { + + var clients []*activity.EntityRecord + if tt.isLocal { + require.Len(t, m.localClients, int(numNew)) + clients = m.localClients + } else { + require.Len(t, m.globalClients, int(numNew)) + clients = m.globalClients + } + for i, rec := range clients { require.NotNil(t, rec) require.Equal(t, tt.wantNamespace, rec.NamespaceID) require.Equal(t, tt.wantMount, rec.MountAccessor) @@ -189,8 +222,11 @@ func Test_singleMonthActivityClients_addNewClients(t *testing.T) { } else { require.NotEqual(t, "", rec.ClientID) } - if tt.segmentIndex != nil { - require.Contains(t, m.predefinedSegments[*tt.segmentIndex], i) + if tt.segmentIndex != nil && tt.isLocal { + require.Contains(t, m.predefinedLocalSegments[*tt.segmentIndex], i) + } + if tt.segmentIndex != nil && !tt.isLocal { + require.Contains(t, m.predefinedGlobalSegments[*tt.segmentIndex], i) } } }) @@ -206,6 +242,7 @@ func Test_multipleMonthsActivityClients_processMonth(t *testing.T) { name string clients *generation.Data wantError bool + isLocal bool numMonths int }{ { @@ -218,6 +255,16 @@ func Test_multipleMonthsActivityClients_processMonth(t *testing.T) { }, numMonths: 1, }, + { + name: "specified namespace and local mount exist", + clients: &generation.Data{ + Clients: &generation.Data_All{All: &generation.Clients{Clients: []*generation.Client{{ + Mount: "cubbyhole/", + }}}}, + }, + numMonths: 1, + isLocal: true, + }, { name: "mount missing slash", clients: &generation.Data{ @@ -282,13 +329,24 @@ func Test_multipleMonthsActivityClients_processMonth(t *testing.T) { require.Error(t, err) } else { require.NoError(t, err) - require.Len(t, m.months[tt.clients.GetMonthsAgo()].clients, len(tt.clients.GetAll().Clients)) - for _, month := range m.months { - for _, c := range month.clients { - require.NotEmpty(t, c.NamespaceID) - require.NotEmpty(t, c.MountAccessor) + if tt.isLocal { + require.Len(t, m.months[tt.clients.GetMonthsAgo()].localClients, len(tt.clients.GetAll().Clients)) + for _, month := range m.months { + for _, c := range month.localClients { + require.NotEmpty(t, c.NamespaceID) + require.NotEmpty(t, c.MountAccessor) + } + } + } else { + require.Len(t, m.months[tt.clients.GetMonthsAgo()].globalClients, len(tt.clients.GetAll().Clients)) + for _, month := range m.months { + for _, c := range month.globalClients { + require.NotEmpty(t, c.NamespaceID) + require.NotEmpty(t, c.MountAccessor) + } } } + } }) } @@ -323,58 +381,95 @@ func Test_multipleMonthsActivityClients_processMonth_segmented(t *testing.T) { m := newMultipleMonthsActivityClients(1) core, _, _ := TestCoreUnsealed(t) require.NoError(t, m.processMonth(context.Background(), core, data, time.Now().UTC())) - require.Len(t, m.months[0].predefinedSegments, 3) - require.Len(t, m.months[0].clients, 3) + require.Len(t, m.months[0].predefinedGlobalSegments, 3) + require.Len(t, m.months[0].globalClients, 3) // segment indexes are correct - require.Contains(t, m.months[0].predefinedSegments, 0) - require.Contains(t, m.months[0].predefinedSegments, 1) - require.Contains(t, m.months[0].predefinedSegments, 7) + require.Contains(t, m.months[0].predefinedGlobalSegments, 0) + require.Contains(t, m.months[0].predefinedGlobalSegments, 1) + require.Contains(t, m.months[0].predefinedGlobalSegments, 7) // the data in each segment is correct - require.Contains(t, m.months[0].predefinedSegments[0], 0) - require.Contains(t, m.months[0].predefinedSegments[1], 1) - require.Contains(t, m.months[0].predefinedSegments[7], 2) + require.Contains(t, m.months[0].predefinedGlobalSegments[0], 0) + require.Contains(t, m.months[0].predefinedGlobalSegments[1], 1) + require.Contains(t, m.months[0].predefinedGlobalSegments[7], 2) } // Test_multipleMonthsActivityClients_addRepeatedClients adds repeated clients // from 1 month ago and 2 months ago, and verifies that the correct clients are // added based on namespace, mount, and non-entity attributes func Test_multipleMonthsActivityClients_addRepeatedClients(t *testing.T) { - core, _, _ := TestCoreUnsealed(t) + storage := &logical.InmemStorage{} + coreConfig := &CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "userpass": userpass.Factory, + }, + Physical: storage.Underlying(), + } + + cluster := NewTestCluster(t, coreConfig, nil) + core := cluster.Cores[0].Core now := time.Now().UTC() m := newMultipleMonthsActivityClients(3) defaultMount := "default" + // add global clients require.NoError(t, m.addClientToMonth(2, &generation.Client{Count: 2}, "identity", nil, now, core)) require.NoError(t, m.addClientToMonth(2, &generation.Client{Count: 2, Namespace: "other_ns"}, defaultMount, nil, now, core)) require.NoError(t, m.addClientToMonth(1, &generation.Client{Count: 2}, defaultMount, nil, now, core)) require.NoError(t, m.addClientToMonth(1, &generation.Client{Count: 2, ClientType: "non-entity"}, defaultMount, nil, now, core)) - month2Clients := m.months[2].clients - month1Clients := m.months[1].clients + // create a local mount + localMount := "localMountAccessor" + localMe := &MountEntry{ + Table: credentialTableType, + Path: "userpass-local/", + Type: "userpass", + Local: true, + Accessor: localMount, + } + err := core.enableCredential(namespace.RootContext(nil), localMe) + require.NoError(t, err) + + // add a local client + require.NoError(t, m.addClientToMonth(2, &generation.Client{Count: 2}, localMount, nil, now, core)) + require.NoError(t, m.addClientToMonth(1, &generation.Client{Count: 2}, localMount, nil, now, core)) + + month2GlobalClients := m.months[2].globalClients + month1GlobalClients := m.months[1].globalClients + + month2LocalClients := m.months[2].localClients + month1LocalClients := m.months[1].localClients thisMonth := m.months[0] // this will match the first client in month 1 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, Repeated: true}, defaultMount, nil, core)) - require.Contains(t, month1Clients, thisMonth.clients[0]) + require.Contains(t, month1GlobalClients, thisMonth.globalClients[0]) + + // this will match the first local client in month 1 + require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, Repeated: true}, localMount, nil, core)) + require.Contains(t, month1LocalClients, thisMonth.localClients[0]) // this will match the 3rd client in month 1 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, Repeated: true, ClientType: "non-entity"}, defaultMount, nil, core)) - require.Equal(t, month1Clients[2], thisMonth.clients[1]) + require.Equal(t, month1GlobalClients[2], thisMonth.globalClients[1]) // this will match the first two clients in month 1 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 2, Repeated: true}, defaultMount, nil, core)) - require.Equal(t, month1Clients[0:2], thisMonth.clients[2:4]) + require.Equal(t, month1GlobalClients[0:2], thisMonth.globalClients[2:4]) // this will match the first client in month 2 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2}, "identity", nil, core)) - require.Equal(t, month2Clients[0], thisMonth.clients[4]) + require.Equal(t, month2GlobalClients[0], thisMonth.globalClients[4]) + + // this will match the first local client in month 2 + require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2}, localMount, nil, core)) + require.Equal(t, month2LocalClients[0], thisMonth.localClients[1]) // this will match the 3rd client in month 2 require.NoError(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2, Namespace: "other_ns"}, defaultMount, nil, core)) - require.Equal(t, month2Clients[2], thisMonth.clients[5]) + require.Equal(t, month2GlobalClients[2], thisMonth.globalClients[5]) require.Error(t, m.addRepeatedClients(0, &generation.Client{Count: 1, RepeatedFromMonth: 2, Namespace: "other_ns"}, "other_mount", nil, core)) } @@ -458,8 +553,8 @@ func Test_singleMonthActivityClients_populateSegments(t *testing.T) { } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - s := singleMonthActivityClients{predefinedSegments: tc.segments, clients: clients, generationParameters: &generation.Data{EmptySegmentIndexes: tc.emptyIndexes, SkipSegmentIndexes: tc.skipIndexes, NumSegments: int32(tc.numSegments)}} - gotSegments, err := s.populateSegments(s.predefinedSegments, s.clients) + s := singleMonthActivityClients{predefinedGlobalSegments: tc.segments, globalClients: clients, generationParameters: &generation.Data{EmptySegmentIndexes: tc.emptyIndexes, SkipSegmentIndexes: tc.skipIndexes, NumSegments: int32(tc.numSegments)}} + gotSegments, err := s.populateSegments(s.predefinedGlobalSegments, s.globalClients) require.NoError(t, err) require.Equal(t, tc.wantSegments, gotSegments) }) @@ -529,7 +624,7 @@ func Test_handleActivityWriteData(t *testing.T) { req.Data = map[string]interface{}{"input": string(marshaled)} resp, err := core.systemBackend.HandleRequest(namespace.RootContext(nil), req) require.NoError(t, err) - paths := resp.Data["paths"].([]string) + paths := resp.Data["global_paths"].([]string) require.Len(t, paths, 9) times, err := core.activityLog.availableLogs(context.Background(), time.Now())