diff --git a/x/marker/types/send_restrictions.go b/x/marker/types/send_restrictions.go index a7709a8ded..92e9e3fa09 100644 --- a/x/marker/types/send_restrictions.go +++ b/x/marker/types/send_restrictions.go @@ -2,7 +2,10 @@ package types import sdk "github.com/cosmos/cosmos-sdk/types" -var bypassKey = "bypass-marker-restriction" +var ( + bypassKey = "bypass-marker-restriction" + transferAgentKey = "marker-transfer-agent" +) // WithBypass returns a new context that will cause the marker bank send restriction to be skipped. func WithBypass(ctx sdk.Context) sdk.Context { @@ -23,3 +26,23 @@ func HasBypass(ctx sdk.Context) bool { bypass, isBool := bypassValue.(bool) return isBool && bypass } + +// WithTransferAgent returns a new context that contains the provided marker transfer agent. +func WithTransferAgent(ctx sdk.Context, transferAgent sdk.AccAddress) sdk.Context { + return ctx.WithValue(transferAgentKey, transferAgent) +} + +// WithoutTransferAgent returns a new context with a nil marker transfer agent. +func WithoutTransferAgent(ctx sdk.Context) sdk.Context { + return ctx.WithValue(transferAgentKey, sdk.AccAddress(nil)) +} + +// GetTransferAgent gets the marker transfer agent from the provided context. +func GetTransferAgent(ctx sdk.Context) sdk.AccAddress { + val := ctx.Value(transferAgentKey) + if val == nil { + return nil + } + rv, _ := val.(sdk.AccAddress) + return rv +} diff --git a/x/marker/types/send_restrictions_test.go b/x/marker/types/send_restrictions_test.go index f2be345f59..465d53c75a 100644 --- a/x/marker/types/send_restrictions_test.go +++ b/x/marker/types/send_restrictions_test.go @@ -9,7 +9,54 @@ import ( tmproto "github.com/tendermint/tendermint/proto/tendermint/types" ) -func TestSendRestrictionContextFuncs(t *testing.T) { +func TestKeysContainModuleName(t *testing.T) { + assert.Contains(t, bypassKey, ModuleName, "bypassKey") + assert.Contains(t, transferAgentKey, ModuleName, "transferAgentKey") +} + +func TestContextCombos(t *testing.T) { + newCtx := func() sdk.Context { + return sdk.NewContext(nil, tmproto.Header{}, false, nil) + } + + tests := []struct { + name string + ctx sdk.Context + expBypass bool + expTA sdk.AccAddress + }{ + { + name: "with transfer agent on with bypass", + ctx: WithTransferAgent(WithBypass(newCtx()), sdk.AccAddress("some_transfer_agent_")), + expBypass: true, + expTA: sdk.AccAddress("some_transfer_agent_"), + }, + { + name: "with bypass on with transfer agent", + ctx: WithBypass(WithTransferAgent(newCtx(), sdk.AccAddress("other_transfer_agent"))), + expBypass: true, + expTA: sdk.AccAddress("other_transfer_agent"), + }, + { + name: "without either on with transfer agent and bypass", + ctx: WithoutBypass(WithoutTransferAgent(WithBypass(WithTransferAgent(newCtx(), sdk.AccAddress("bad_transfer_agent__"))))), + expBypass: false, + expTA: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actBypass := HasBypass(tc.ctx) + actTA := GetTransferAgent(tc.ctx) + + assert.Equal(t, tc.expBypass, actBypass, "HasBypass") + assert.Equal(t, tc.expTA, actTA, "GetTransferAgent") + }) + } +} + +func TestBypassFuncs(t *testing.T) { tests := []struct { name string ctx sdk.Context @@ -42,7 +89,7 @@ func TestSendRestrictionContextFuncs(t *testing.T) { }, { name: "context without bypass on one that originally had it", - ctx: WithoutBypass(WithoutBypass(sdk.NewContext(nil, tmproto.Header{}, false, nil))), + ctx: WithoutBypass(WithBypass(sdk.NewContext(nil, tmproto.Header{}, false, nil))), exp: false, }, { @@ -60,7 +107,7 @@ func TestSendRestrictionContextFuncs(t *testing.T) { } } -func TestContextFuncsDoNotModifyProvided(t *testing.T) { +func TestBypassFuncsDoNotModifyProvided(t *testing.T) { origCtx := sdk.NewContext(nil, tmproto.Header{}, false, nil) assert.False(t, HasBypass(origCtx), "HasBypass(origCtx)") afterWith := WithBypass(origCtx) @@ -72,6 +119,72 @@ func TestContextFuncsDoNotModifyProvided(t *testing.T) { assert.False(t, HasBypass(origCtx), "HasBypass(origCtx) after giving afterWith to WithoutBypass") } -func TestKeyContainsModuleName(t *testing.T) { - assert.Contains(t, bypassKey, ModuleName, "bypassKey") +func TestTransferAgentFuncs(t *testing.T) { + newCtx := func() sdk.Context { + return sdk.NewContext(nil, tmproto.Header{}, false, nil) + } + + tests := []struct { + name string + ctx sdk.Context + exp sdk.AccAddress + }{ + { + name: "brand new mostly empty context", + ctx: newCtx(), + exp: nil, + }, + { + name: "context with transfer agent", + ctx: WithTransferAgent(newCtx(), sdk.AccAddress("transfer_agent______")), + exp: sdk.AccAddress("transfer_agent______"), + }, + { + name: "context without transfer agent", + ctx: WithoutTransferAgent(newCtx()), + exp: nil, + }, + { + name: "context with transfer agent twice", + ctx: WithTransferAgent(WithTransferAgent(newCtx(), sdk.AccAddress("first_transfer_agent")), sdk.AccAddress("agent_2_of_transfer_")), + exp: sdk.AccAddress("agent_2_of_transfer_"), + }, + { + name: "context without transfer agent twice", + ctx: WithoutTransferAgent(WithoutTransferAgent(newCtx())), + exp: nil, + }, + { + name: "context with transfer agent on one that originally was without it", + ctx: WithTransferAgent(WithoutTransferAgent(newCtx()), sdk.AccAddress("agent_of_transfer___")), + exp: sdk.AccAddress("agent_of_transfer___"), + }, + { + name: "context without transfer agent on one that originally had it", + ctx: WithoutTransferAgent(WithTransferAgent(newCtx(), sdk.AccAddress("the_transfer_agent__"))), + exp: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + actual := GetTransferAgent(tc.ctx) + assert.Equal(t, tc.exp, actual, "GetTransferAgent") + }) + } +} + +func TestTransferAgentFuncsDoNotModifyProvided(t *testing.T) { + origCtx := sdk.NewContext(nil, tmproto.Header{}, false, nil) + assert.Nil(t, GetTransferAgent(origCtx), "GetTransferAgent(origCtx)") + + ta := sdk.AccAddress("great_transfer_agent") + afterWith := WithTransferAgent(origCtx, ta) + assert.Equal(t, ta, GetTransferAgent(afterWith), "GetTransferAgent(afterWith)") + assert.Nil(t, GetTransferAgent(origCtx), "GetTransferAgent(origCtx) after giving it to WithTransferAgent") + + afterWithout := WithoutTransferAgent(afterWith) + assert.Nil(t, GetTransferAgent(afterWithout), "GetTransferAgent(afterWithout)") + assert.Equal(t, ta, GetTransferAgent(afterWith), "GetTransferAgent(afterWith) after giving it to WithoutTransferAgent") + assert.Nil(t, GetTransferAgent(origCtx), "GetTransferAgent(origCtx) after giving afterWith to WithoutTransferAgent") }