CNS-930: add cache to epoch CU
oren-lava committed Mar 27, 2024
1 parent 1b4d203 commit 82daa93
Showing 4 changed files with 115 additions and 11 deletions.
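In short, the change wraps the ProviderEpochCu and ProviderConsumerEpochCu prefix stores in in-memory cachekv stores (the new EpochCuCache handler), routes RelayPayment's per-relay CU bookkeeping through that cache, and writes everything back to the underlying store once per transaction via Flush().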
51 changes: 51 additions & 0 deletions x/pairing/keeper/epoch_cu.go
@@ -4,6 +4,7 @@ import (
"strconv"
"strings"

"github.com/cosmos/cosmos-sdk/store/cachekv"
"github.com/cosmos/cosmos-sdk/store/prefix"
sdk "github.com/cosmos/cosmos-sdk/types"
"github.com/lavanet/lava/utils"
@@ -209,3 +210,53 @@ func (k Keeper) GetAllProviderConsumerEpochCuStore(ctx sdk.Context) []types.Prov

    return info
}

/* ########## EpochCuCache ############ */
type EpochCuCache struct {
    Keeper
    ProviderEpochCuCache         *cachekv.Store
    ProviderConsumerEpochCuCache *cachekv.Store
}

func (k Keeper) NewEpochCuCacheHandler(ctx sdk.Context) EpochCuCache {
    return EpochCuCache{
        Keeper:                       k,
        ProviderEpochCuCache:         cachekv.NewStore(prefix.NewStore(ctx.KVStore(k.storeKey), types.ProviderEpochCuKeyPrefix())),
        ProviderConsumerEpochCuCache: cachekv.NewStore(prefix.NewStore(ctx.KVStore(k.storeKey), types.ProviderConsumerEpochCuKeyPrefix())),
    }
}

func (k EpochCuCache) SetProviderEpochCuCached(ctx sdk.Context, epoch uint64, provider string, chainID string, providerEpochCu types.ProviderEpochCu) {
    b := k.cdc.MustMarshal(&providerEpochCu)
    k.ProviderEpochCuCache.Set(types.ProviderEpochCuKey(epoch, provider, chainID), b)
}

func (k EpochCuCache) GetProviderEpochCuCached(ctx sdk.Context, epoch uint64, provider string, chainID string) (val types.ProviderEpochCu, found bool) {
    b := k.ProviderEpochCuCache.Get(types.ProviderEpochCuKey(epoch, provider, chainID))
    if b == nil {
        return val, false
    }

    k.cdc.MustUnmarshal(b, &val)
    return val, true
}

func (k EpochCuCache) SetProviderConsumerEpochCuCached(ctx sdk.Context, epoch uint64, provider string, project string, chainID string, providerConsumerEpochCu types.ProviderConsumerEpochCu) {
    b := k.cdc.MustMarshal(&providerConsumerEpochCu)
    k.ProviderConsumerEpochCuCache.Set(types.ProviderConsumerEpochCuKey(epoch, provider, project, chainID), b)
}

func (k EpochCuCache) GetProviderConsumerEpochCuCached(ctx sdk.Context, epoch uint64, provider string, project string, chainID string) (val types.ProviderConsumerEpochCu, found bool) {
    b := k.ProviderConsumerEpochCuCache.Get(types.ProviderConsumerEpochCuKey(epoch, provider, project, chainID))
    if b == nil {
        return val, false
    }

    k.cdc.MustUnmarshal(b, &val)
    return val, true
}

func (k EpochCuCache) Flush() {
    k.ProviderEpochCuCache.Write()
    k.ProviderConsumerEpochCuCache.Write()
}
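For context, cachekv.Store maintains an in-memory overlay over a parent KVStore: Set and Get operate on the overlay (so later reads see earlier writes), and nothing touches the parent until Write() is called. A minimal sketch of that pattern follows; storeKey and keyPrefix stand in for the keeper's actual values, and the function is illustrative, not part of the commit:

import (
    "github.com/cosmos/cosmos-sdk/store/cachekv"
    "github.com/cosmos/cosmos-sdk/store/prefix"
    storetypes "github.com/cosmos/cosmos-sdk/store/types"
    sdk "github.com/cosmos/cosmos-sdk/types"
)

// cachedWriteSketch shows the overlay-then-flush pattern behind EpochCuCache.
func cachedWriteSketch(ctx sdk.Context, storeKey storetypes.StoreKey, keyPrefix []byte) {
    // layer an in-memory cache over the prefixed sub-store
    cache := cachekv.NewStore(prefix.NewStore(ctx.KVStore(storeKey), keyPrefix))

    cache.Set([]byte("k"), []byte("v")) // buffered in memory only
    _ = cache.Get([]byte("k"))          // read-your-writes: sees the buffered value

    cache.Write() // push all buffered writes to the parent store in one pass
}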
12 changes: 6 additions & 6 deletions x/pairing/keeper/epoch_payment.go
@@ -7,27 +7,27 @@ import (
)

// AddEpochPayment adds a new epoch payment and returns the updated CU used between provider and project
-func (k Keeper) AddEpochPayment(ctx sdk.Context, chainID string, epoch uint64, project string, provider string, cu uint64, sessionID uint64) uint64 {
-   // register new epoch session (not checking double spend because it's alreday checked before calling this function)
+func (k EpochCuCache) AddEpochPayment(ctx sdk.Context, chainID string, epoch uint64, project string, provider string, cu uint64, sessionID uint64) uint64 {
+   // register new epoch session (not checking double spend because it's already checked before calling this function)
    k.SetUniqueEpochSession(ctx, epoch, provider, project, chainID, sessionID)

    // update provider serviced CU
-   pec, found := k.GetProviderEpochCu(ctx, epoch, provider, chainID)
+   pec, found := k.GetProviderEpochCuCached(ctx, epoch, provider, chainID)
    if !found {
        pec = types.ProviderEpochCu{ServicedCu: cu}
    } else {
        pec.ServicedCu += cu
    }
-   k.SetProviderEpochCu(ctx, epoch, provider, chainID, pec)
+   k.SetProviderEpochCuCached(ctx, epoch, provider, chainID, pec)

    // update provider CU for the specific project
-   pcec, found := k.GetProviderConsumerEpochCu(ctx, epoch, provider, project, chainID)
+   pcec, found := k.GetProviderConsumerEpochCuCached(ctx, epoch, provider, project, chainID)
    if !found {
        pcec = types.ProviderConsumerEpochCu{Cu: cu}
    } else {
        pcec.Cu += cu
    }
-   k.SetProviderConsumerEpochCu(ctx, epoch, provider, project, chainID, pcec)
+   k.SetProviderConsumerEpochCuCached(ctx, epoch, provider, project, chainID, pcec)
    return pcec.Cu
}
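Since the cached getters observe writes buffered earlier in the same transaction, the get-modify-set sequence above accumulates correctly across relays before anything is persisted: for example, three 10-CU relays for the same provider and project move ServicedCu through 10, 20, and 30 in memory, and only the final value reaches the store when Flush() is called.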

13 changes: 8 additions & 5 deletions x/pairing/keeper/msg_server_relay_payment.go
@@ -40,6 +40,7 @@ func (k msgServer) RelayPayment(goCtx context.Context, msg *types.MsgRelayPaymen

    ctx := sdk.UnwrapSDKContext(goCtx)
    logger := k.Logger(ctx)
+   epochCuCache := k.NewEpochCuCacheHandler(ctx)
    lavaChainID := ctx.BlockHeader().ChainID
    creator, err := sdk.AccAddressFromBech32(msg.Creator)
    if err != nil {
@@ -165,7 +166,7 @@ func (k msgServer) RelayPayment(goCtx context.Context, msg *types.MsgRelayPaymen
// if they failed (one relay should affect all of them). From here on, every check will
// fail the TX ***

totalCUInEpochForUserProvider := k.Keeper.AddEpochPayment(ctx, relay.SpecId, epochStart, project.Index, relay.Provider, relay.CuSum, relay.SessionId)
totalCUInEpochForUserProvider := epochCuCache.AddEpochPayment(ctx, relay.SpecId, epochStart, project.Index, relay.Provider, relay.CuSum, relay.SessionId)
if badgeFound {
k.handleBadgeCu(ctx, badgeData, relay.Provider, relay.CuSum, newBadgeTimerExpiry)
}
@@ -281,7 +282,7 @@ func (k msgServer) RelayPayment(goCtx context.Context, msg *types.MsgRelayPaymen
        }

        // update provider payment storage with complainer's CU
-       err = k.updateProvidersComplainerCU(ctx, relay.UnresponsiveProviders, epochStart, relay.SpecId, cuAfterQos, providers, project.Index)
+       err = epochCuCache.updateProvidersComplainerCU(ctx, relay.UnresponsiveProviders, epochStart, relay.SpecId, cuAfterQos, providers, project.Index)
        if err != nil {
            var reportedProviders []string
            for _, p := range relay.UnresponsiveProviders {
@@ -324,6 +325,8 @@ func (k msgServer) RelayPayment(goCtx context.Context, msg *types.MsgRelayPaymen
    }
    utils.LogLavaEvent(ctx, logger, types.LatestBlocksReportEventName, latestBlockReports, "New LatestBlocks Report for provider")

+   epochCuCache.Flush()
+
    return &types.MsgRelayPaymentResponse{RejectedRelays: rejected_relays}, nil
}

@@ -338,7 +341,7 @@ func (k msgServer) setStakeEntryBlockReport(ctx sdk.Context, providerAddr sdk.Ac
    }
}

-func (k msgServer) updateProvidersComplainerCU(ctx sdk.Context, unresponsiveProviders []*types.ReportedProvider, epoch uint64, chainID string, cu uint64, pairedProviders []epochstoragetypes.StakeEntry, project string) error {
+func (k EpochCuCache) updateProvidersComplainerCU(ctx sdk.Context, unresponsiveProviders []*types.ReportedProvider, epoch uint64, chainID string, cu uint64, pairedProviders []epochstoragetypes.StakeEntry, project string) error {
    // check that unresponsiveData exists and that the paired providers list is larger than 1
    if len(unresponsiveProviders) == 0 || len(pairedProviders) <= 1 {
        return nil
@@ -364,13 +367,13 @@ func (k msgServer) updateProvidersComplainerCU(ctx sdk.Context, unresponsiveProv
            continue
        }

-       pec, found := k.GetProviderEpochCu(ctx, epoch, unresponsiveProvider.Address, chainID)
+       pec, found := k.GetProviderEpochCuCached(ctx, epoch, unresponsiveProvider.Address, chainID)
        if !found {
            pec = types.ProviderEpochCu{ComplainersCu: complainerCuToAdd}
        } else {
            pec.ComplainersCu += complainerCuToAdd
        }
-       k.SetProviderEpochCu(ctx, epoch, unresponsiveProvider.Address, chainID, pec)
+       k.SetProviderEpochCuCached(ctx, epoch, unresponsiveProvider.Address, chainID, pec)

        timestamp := time.Unix(unresponsiveProvider.TimestampS, 0)
        details := map[string]string{
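Taken together: the handler constructs the cache once at the top of RelayPayment, routes each relay's serviced CU and complainer CU through it, and flushes exactly once on the success path. A simplified sketch of that flow, assuming the loop shape and field names (msg.Relays, epochStart, project) that the hunks above only partially show:

epochCuCache := k.NewEpochCuCacheHandler(ctx)
for _, relay := range msg.Relays {
    // per-relay validation omitted; CU accounting hits only the in-memory cache
    epochCuCache.AddEpochPayment(ctx, relay.SpecId, epochStart, project.Index, relay.Provider, relay.CuSum, relay.SessionId)
}
epochCuCache.Flush() // single write-back to the underlying KV stores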
50 changes: 50 additions & 0 deletions x/pairing/keeper/msg_server_relay_payment_test.go
@@ -969,3 +969,53 @@ func TestIntOverflow(t *testing.T) {
        ts.AdvanceEpoch()
    }
}

func TestPairingCaching(t *testing.T) {
    ts := newTester(t)
    ts.setupForPayments(3, 3, 0) // 3 providers, 3 clients, default providers-to-pair

    ts.AdvanceEpoch()

    relayNum := uint64(0)
    totalCU := uint64(0)
    // trigger relay payment with cache
    for i := 0; i < 3; i++ { // provider loop
        relays := []*types.RelaySession{}
        _, provider1Addr := ts.GetAccount(common.PROVIDER, i)
        for i := 0; i < 3; i++ { // consumer loop (shadows the provider index)
            consumerAcct, _ := ts.GetAccount(common.CONSUMER, i)
            totalCU = 0
            for i := 0; i < 50; i++ { // relay loop (shadows again); relay i carries i CU
                totalCU += uint64(i)
                relaySession := ts.newRelaySession(provider1Addr, relayNum, uint64(i), ts.BlockHeight(), 0)
                sig, err := sigs.Sign(consumerAcct.SK, *relaySession)
                relaySession.Sig = sig
                require.NoError(t, err)
                relays = append(relays, relaySession)
                relayNum++
            }
        }
        _, err := ts.TxPairingRelayPayment(provider1Addr, relays...)
        require.NoError(t, err)
    }

    pecs := ts.Keepers.Pairing.GetAllProviderEpochCuStore(ts.Ctx)
    require.Len(t, pecs, 3)

    UniquePayments := ts.Keepers.Pairing.GetAllUniqueEpochSessionStore(ts.Ctx)
    require.Len(t, UniquePayments, 3*3*50)

    storages := ts.Keepers.Pairing.GetAllProviderConsumerEpochCuStore(ts.Ctx)
    require.Len(t, storages, 3*3)

    for i := 0; i < 3; i++ {
        consumerAcct, _ := ts.GetAccount(common.CONSUMER, i)
        project, err := ts.GetProjectForDeveloper(consumerAcct.Addr.String(), ts.BlockHeight())
        require.NoError(t, err)
        require.Equal(t, totalCU*3, project.UsedCu)

        sub, err := ts.QuerySubscriptionCurrent(consumerAcct.Addr.String())
        require.NoError(t, err)
        require.Equal(t, totalCU*3, sub.Sub.MonthCuTotal-sub.Sub.MonthCuLeft)
    }
}
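As a sanity check on the asserted values: each (provider, consumer) pair sends 50 relays with CU 0..49, so totalCU ends at 0+1+...+49 = 1225. That gives 3 providers x 3 consumers x 50 relays = 450 unique epoch sessions, 3 ProviderEpochCu entries (one per provider), 3 x 3 = 9 ProviderConsumerEpochCu entries, and each project's UsedCu equals totalCU*3 = 3675, since every project is served by all three providers.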
