From 3036026e0416c8fdb37e9e1c7357ccb5ff932a44 Mon Sep 17 00:00:00 2001 From: Florian Ritterhoff Date: Wed, 25 Oct 2023 09:54:16 +0200 Subject: [PATCH] fix: try better reloading of database --- offline.go | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/offline.go b/offline.go index 21c980a..2e1f7f7 100644 --- a/offline.go +++ b/offline.go @@ -58,9 +58,13 @@ var ( type ( // An OfflineDatabase is a client for querying Pwned Passwords locally. OfflineDatabase struct { + file *DatabaseFile + logger zap.Logger + cron *cron.Cron + } + + DatabaseFile struct { database readCloserAt - logger zap.Logger - cron *cron.Cron } // readCloserAt is an io.ReaderAt that can be Closed and whose @@ -135,9 +139,9 @@ func NewOfflineDatabase(dbFile string, updatedDbFile string) (*OfflineDatabase, logger.Info("Database opened") c := cron.New() odb := &OfflineDatabase{ - database: db, - logger: logger, - cron: c, + file: &DatabaseFile{db}, + logger: logger, + cron: c, } c.AddFunc("@hourly", func() { @@ -156,7 +160,7 @@ func NewOfflineDatabase(dbFile string, updatedDbFile string) (*OfflineDatabase, if err != nil { log.Panic(err) } - odb.database = db + odb.file = &DatabaseFile{db} logger.Info("Database updated") } } @@ -169,7 +173,7 @@ func NewOfflineDatabase(dbFile string, updatedDbFile string) (*OfflineDatabase, // Close frees resources associated with the database. func (od *OfflineDatabase) Close() error { od.cron.Stop() - return od.database.Close() + return od.file.database.Close() } // Pwned checks how frequently the given hash is included in the Pwned Passwords @@ -211,7 +215,7 @@ func (od *OfflineDatabase) Pwned(hash [20]byte) (frequency int, err error) { test = lo + ((hi - lo) / 2) // lookup - if _, err := od.database.ReadAt(rbuf[:], DataSegmentOffset+start+int64(test*19)); err != nil { + if _, err := od.file.database.ReadAt(rbuf[:], DataSegmentOffset+start+int64(test*19)); err != nil { return 0, err } @@ -270,7 +274,7 @@ func (od *OfflineDatabase) Scan(startPrefix, endPrefix [3]byte, hash []byte, cb } // read from the data file - if _, err := od.database.ReadAt(buffer[0:length], DataSegmentOffset+start); err != nil { + if _, err := od.file.database.ReadAt(buffer[0:length], DataSegmentOffset+start); err != nil { return err } @@ -325,18 +329,18 @@ func (od *OfflineDatabase) lookup(start [3]byte) (location, length int64, err er case [3]byte{0xFF, 0xFF, 0xFF}: // read the required index - if _, err := od.database.ReadAt(buffer[0:8], int64(prefixIndex)*8); err != nil { + if _, err := od.file.database.ReadAt(buffer[0:8], int64(prefixIndex)*8); err != nil { return 0, 0, err } // look up locations and calculate length loc = int64(binary.BigEndian.Uint64(buffer[0:8])) - dataLen = int64(od.database.Len()-IndexSegmentSize) - loc + dataLen = int64(od.file.database.Len()-IndexSegmentSize) - loc default: // read the required index, and the next one (to calculate length) - if _, err := od.database.ReadAt(buffer[0:16], int64(prefixIndex)*8); err != nil { + if _, err := od.file.database.ReadAt(buffer[0:16], int64(prefixIndex)*8); err != nil { return 0, 0, err } @@ -378,8 +382,8 @@ func (od *OfflineDatabase) ServeHTTP(w http.ResponseWriter, r *http.Request) { hash = sha1.Sum([]byte(pw)) } - frequency, err := od.Pwned(hash) od.logger.Sugar().Infof("checking password: %v", hash) + frequency, err := od.Pwned(hash) if err != nil { od.logger.Sugar().Warnf("error checking password: %v", err) w.WriteHeader(http.StatusInternalServerError)