Skip to content

Commit

Permalink
Fix: ttl not always set correctly on Redis keys
Browse files Browse the repository at this point in the history
  • Loading branch information
ivard committed Oct 19, 2023
1 parent 1bc83e0 commit 31bdffd
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 20 deletions.
43 changes: 30 additions & 13 deletions server/irmaserver/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,31 +246,39 @@ func (s *memorySessionStore) deleteExpired() {
func (s *redisSessionStore) add(ctx context.Context, session *sessionData) error {
sessionJSON, err := json.Marshal(session)
if err != nil {
return err
return &RedisError{err}
}

ttl := session.ttl(s.conf)
if ttl <= 0 {
return &RedisError{errors.New("session ttl is in the past")}
}
if err := s.client.Watch(ctx, func(tx *redis.Tx) error {
if err := tx.Set(
ctx,
s.client.KeyPrefix+requestorTokenLookupPrefix+string(session.RequestorToken),
string(session.ClientToken),
ttl,
).Err(); err != nil {
return &RedisError{err}
return err
}
if err := tx.Set(ctx, s.client.KeyPrefix+clientTokenLookupPrefix+string(session.ClientToken), sessionJSON, ttl).Err(); err != nil {
return &RedisError{err}
if err := tx.Set(
ctx,
s.client.KeyPrefix+clientTokenLookupPrefix+string(session.ClientToken),
sessionJSON,
ttl,
).Err(); err != nil {
return err
}

if s.client.FailoverMode {
if err := s.client.Wait(ctx, 1, time.Second).Err(); err != nil {
return &RedisError{err}
return err
}
}
return nil
}); err != nil {
return err
return &RedisError{err}
}

s.conf.Logger.WithFields(logrus.Fields{"session": session.RequestorToken}).Debug("Session added in Redis datastore")
Expand All @@ -295,12 +303,12 @@ func (s *redisSessionStore) transaction(ctx context.Context, t irma.RequestorTok
}

func (s *redisSessionStore) clientTransaction(ctx context.Context, t irma.ClientToken, handler func(session *sessionData) (bool, error)) error {
if err := s.client.Watch(ctx, func(tx *redis.Tx) error {
err := s.client.Watch(ctx, func(tx *redis.Tx) error {
getResult := tx.Get(ctx, s.client.KeyPrefix+clientTokenLookupPrefix+string(t))
if getResult.Err() == redis.Nil {
return &UnknownSessionError{"", t}
} else if getResult.Err() != nil {
return &RedisError{getResult.Err()}
return getResult.Err()
}

session := &sessionData{}
Expand Down Expand Up @@ -329,19 +337,28 @@ func (s *redisSessionStore) clientTransaction(ctx context.Context, t irma.Client
return err
}

err = tx.Set(ctx, s.client.KeyPrefix+clientTokenLookupPrefix+string(t), sessionJSON, 0).Err()
if err != nil {
return &RedisError{err}
ttl := session.ttl(s.conf)
if ttl <= 0 {
return errors.New("session ttl is in the past")
}

if err := tx.Set(ctx, s.client.KeyPrefix+clientTokenLookupPrefix+string(t), sessionJSON, ttl).Err(); err != nil {
return err
}
if err := tx.Expire(ctx, s.client.KeyPrefix+requestorTokenLookupPrefix+string(session.RequestorToken), ttl).Err(); err != nil {
return err
}
if s.client.FailoverMode {
if err := tx.Wait(ctx, 1, time.Second).Err(); err != nil {
return &RedisError{err}
return err
}
}
return nil
}); err != nil {
})
if _, ok := err.(*UnknownSessionError); ok {
return err
} else if err != nil {
return &RedisError{err}
}
return nil
}
Expand Down
2 changes: 1 addition & 1 deletion server/keyshare/myirmaserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func New(conf *Configuration) (*Server, error) {
if err != nil {
return nil, err
}
store = &redisSessionStore{client: cl}
store = &redisSessionStore{client: cl, logger: conf.Logger}
default:
return nil, errors.New("unsupported session store type")
}
Expand Down
13 changes: 10 additions & 3 deletions server/keyshare/myirmaserver/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,16 @@ func (s *redisSessionStore) add(ctx context.Context, ses session) error {
return err
}

ttl := time.Until(ses.Expiry).Seconds()
ttl := time.Until(ses.Expiry)
if ttl <= 0 {
return errors.New("session expiry time is in the past")
}
if err := s.client.Watch(ctx, func(tx *redis.Tx) error {
if err := tx.Set(
ctx,
s.client.KeyPrefix+sessionLookupPrefix+ses.Token,
string(bytes),
time.Duration(ttl)*time.Second,
ttl,
).Err(); err != nil {
return err
}
Expand Down Expand Up @@ -138,7 +141,11 @@ func (s *redisSessionStore) update(ctx context.Context, token string, handler fu
return err
}

if err := tx.Set(ctx, key, string(updatedBytes), time.Until(session.Expiry)).Err(); err != nil {
ttl := time.Until(session.Expiry)
if ttl <= 0 {
return errors.New("session expiry time is in the past")
}
if err := tx.Set(ctx, key, string(updatedBytes), ttl).Err(); err != nil {
return err
}

Expand Down
22 changes: 19 additions & 3 deletions server/keyshare/myirmaserver/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,27 @@ import (
"testing"
"time"

// "github.com/alicebob/miniredis/v2"
"github.com/alicebob/miniredis/v2"
"github.com/go-redis/redis/v8"
irma "github.com/privacybydesign/irmago"
"github.com/privacybydesign/irmago/server"
"github.com/stretchr/testify/assert"
)

func TestSessions(t *testing.T) {
store := newMemorySessionStore()
func TestMemorySessionStore(t *testing.T) {
testSessions(t, newMemorySessionStore(), time.Sleep)
}

func TestRedisSessionStore(t *testing.T) {
mr := miniredis.NewMiniRedis()
mr.Start()
defer mr.Close()
client := redis.NewClient(&redis.Options{Addr: mr.Host() + ":" + mr.Port()})
testSessions(t, &redisSessionStore{client: &server.RedisClient{Client: client}, logger: server.Logger}, mr.FastForward)
}

func testSessions(t *testing.T, store sessionStore, sleepFn func(time.Duration)) {
s := session{
Token: "token",
Expiry: time.Now().Add(1 * time.Second),
Expand All @@ -21,6 +35,8 @@ func TestSessions(t *testing.T) {

session2, err := getSession(store, s.Token)
assert.NoError(t, err)
assert.Equal(t, s.Expiry.Unix(), session2.Expiry.Unix())
s.Expiry = session2.Expiry // Time is not exactly equal, so set it to the same value
assert.Equal(t, s, session2)

emailSessionToken := irma.RequestorToken("emailtoken")
Expand All @@ -43,7 +59,7 @@ func TestSessions(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, session4.Token, s.Token)

time.Sleep(2 * time.Second)
sleepFn(2 * time.Second)

store.flush()

Expand Down

0 comments on commit 31bdffd

Please sign in to comment.