diff --git a/.gitignore b/.gitignore index 7d66e42..0d02728 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ cov.out # massive data files /pwned-passwords-* *.bin +pwned-passwords.lock diff --git a/cmd/pwnd/main.go b/cmd/pwnd/main.go index deec02c..8ac04a5 100644 --- a/cmd/pwnd/main.go +++ b/cmd/pwnd/main.go @@ -10,12 +10,16 @@ import ( func main() { // parse flags - var dbFile string + var ( + dbFile string + updatedDbFile string + ) flag.StringVar(&dbFile, "database", pwnedpass.DatabaseFilename, "path to the database file") + flag.StringVar(&updatedDbFile, "updated-database", pwnedpass.UpdatedDatabaseFilename, "path to the database file") flag.Parse() // open the offline database - od, err := pwnedpass.NewOfflineDatabase(dbFile) + od, err := pwnedpass.NewOfflineDatabase(dbFile, updatedDbFile) if err != nil { panic(err) } diff --git a/cmd/pwngen/main.go b/cmd/pwngen/main.go index fa27ab1..7b21e9c 100644 --- a/cmd/pwngen/main.go +++ b/cmd/pwngen/main.go @@ -25,7 +25,8 @@ const IndexSegmentSize = 256 << 16 << 3 // exactly 256^3 MB // DatabaseFilename indicates the default location of the database file // to be created. -var DatabaseFilename = "pwned-passwords.bin" +var DatabaseFilename = "updated-pwned-passwords.bin" +var LockFileName = "pwned-passwords.lock" type loadWorker struct { sugar *zap.SugaredLogger @@ -40,6 +41,7 @@ type response struct { // The work that needs to be performed // The input type should implement the WorkFunction interface func (w loadWorker) Run(ctx context.Context) interface{} { + os.Create(LockFileName) var body []byte for i := 0; i < 3; i++ { httpClient := &http.Client{} @@ -192,5 +194,6 @@ func main() { } sugar.Infof("OK") + os.Remove(LockFileName) } diff --git a/go.mod b/go.mod index 715e7f7..93c9285 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,8 @@ require ( golang.org/x/exp v0.0.0-20230905200255-921286631fa9 ) +require github.com/robfig/cron/v3 v3.0.0 // indirect + require ( github.com/tejzpr/ordered-concurrently/v3 v3.0.1 go.uber.org/multierr v1.10.0 // indirect diff --git a/go.sum b/go.sum index 47c2a15..73e9196 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/robfig/cron/v3 v3.0.0 h1:kQ6Cb7aHOHTSzNVNEhmp8EcWKLb4CbiMW9h9VyIhO4E= +github.com/robfig/cron/v3 v3.0.0/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/tejzpr/ordered-concurrently/v3 v3.0.1 h1:TLHtzlQEDshbmGveS8S+hxLw4s5u67aoJw5LLf+X2xY= github.com/tejzpr/ordered-concurrently/v3 v3.0.1/go.mod h1:mu/neZ6AGXm5jdPc7PEgViYK3rkYNPvVCEm15Cx/iRI= go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= diff --git a/offline.go b/offline.go index 0efb3b7..1aa7a70 100644 --- a/offline.go +++ b/offline.go @@ -8,12 +8,15 @@ import ( "errors" "fmt" "io" + "log" "net/http" "os" "strconv" "strings" "sync" + "time" + "github.com/robfig/cron/v3" "go.uber.org/zap" "go.uber.org/zap/zapcore" "golang.org/x/exp/mmap" @@ -21,7 +24,9 @@ import ( const ( // DatabaseFilename is the default path to the database. - DatabaseFilename = "pwned-passwords.bin" + DatabaseFilename = "pwned-passwords.bin" + UpdatedDatabaseFilename = "updated-pwned-passwords.bin" + LockFileName = "pwned-passwords.lock" // IndexSegmentSize is the exact size of the index segment in bytes. IndexSegmentSize = 256 << 16 << 3 // exactly 256^3 MB @@ -55,6 +60,7 @@ type ( OfflineDatabase struct { database readCloserAt logger zap.Logger + cron *cron.Cron } // readCloserAt is an io.ReaderAt that can be Closed and whose @@ -71,13 +77,8 @@ type ( // NewOfflineDatabase opens a new OfflineDatabase using the data in the given // database file. -func NewOfflineDatabase(dbFile string) (*OfflineDatabase, error) { - - db, err := mmap.Open(dbFile) - if err != nil { - return nil, fmt.Errorf("error opening index: %s", err) - } - +func NewOfflineDatabase(dbFile string, updatedDbFile string) (*OfflineDatabase, error) { + lockExists := false encoderCfg := zap.NewProductionEncoderConfig() encoderCfg.TimeKey = "timestamp" encoderCfg.EncodeTime = zapcore.ISO8601TimeEncoder @@ -99,17 +100,68 @@ func NewOfflineDatabase(dbFile string) (*OfflineDatabase, error) { "pid": os.Getpid(), }, } + logger := *zap.Must(config.Build()) + for { + if _, err := os.Stat(LockFileName); err == nil { + lockExists = true + } + if _, err := os.Stat(dbFile); err == nil { + break + } else { + // Check if error indicates a missing file? + if os.IsNotExist(err) && lockExists { + logger.Warn("Lock file exists, but database file does not. Waiting for lock to be released.") + time.Sleep(1 * time.Minute) + } + } + } + + if _, err := os.Stat(updatedDbFile); err == nil { + if !lockExists { + if err := os.Rename(updatedDbFile, dbFile); err != nil { + return nil, fmt.Errorf("error moving updated database: %s", err) + } + } + } + + db, err := mmap.Open(dbFile) + if err != nil { + return nil, fmt.Errorf("error opening index: %s", err) + } + c := cron.New() odb := &OfflineDatabase{ database: db, - logger: *zap.Must(config.Build()), + logger: logger, + cron: c, } + c.AddFunc("@hourly", func() { + if _, err := os.Stat(updatedDbFile); err == nil { + lockExists := false + if _, err := os.Stat(LockFileName); err == nil { + lockExists = true + } + if !lockExists { + db.Close() + if err := os.Rename(updatedDbFile, dbFile); err != nil { + log.Panic(err) + } + db, err := mmap.Open(dbFile) + if err != nil { + log.Panic(err) + } + odb.database = db + } + } + }) + c.Start() return odb, nil } // Close frees resources associated with the database. func (od *OfflineDatabase) Close() error { + od.cron.Stop() return od.database.Close() } diff --git a/offline_test.go b/offline_test.go index 6cb47dd..5c4f75d 100644 --- a/offline_test.go +++ b/offline_test.go @@ -25,7 +25,7 @@ func TestPwned(t *testing.T) { }, } - od, err := NewOfflineDatabase(DatabaseFilename) + od, err := NewOfflineDatabase(DatabaseFilename, UpdatedDatabaseFilename) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -44,7 +44,7 @@ func TestPwned(t *testing.T) { func BenchmarkPwned(b *testing.B) { - od, err := NewOfflineDatabase(DatabaseFilename) + od, err := NewOfflineDatabase(DatabaseFilename, UpdatedDatabaseFilename) if err != nil { b.Fatalf("unexpected error: %s", err) } @@ -93,7 +93,7 @@ func TestScan(t *testing.T) { }, } - od, err := NewOfflineDatabase(DatabaseFilename) + od, err := NewOfflineDatabase(DatabaseFilename, UpdatedDatabaseFilename) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -127,7 +127,7 @@ func BenchmarkScan(b *testing.B) { EndPrefix = [3]byte{0x05, 0x31, 0x91} ) - od, err := NewOfflineDatabase(DatabaseFilename) + od, err := NewOfflineDatabase(DatabaseFilename, UpdatedDatabaseFilename) if err != nil { b.Fatalf("unexpected error: %s", err) } @@ -179,7 +179,7 @@ func TestLookup(t *testing.T) { }, } - od, err := NewOfflineDatabase(DatabaseFilename) + od, err := NewOfflineDatabase(DatabaseFilename, UpdatedDatabaseFilename) if err != nil { t.Fatalf("unexpected error: %s", err) } @@ -206,7 +206,7 @@ func TestLookup(t *testing.T) { func BenchmarkHTTPPassword(b *testing.B) { // open the offline database - od, err := NewOfflineDatabase(DatabaseFilename) + od, err := NewOfflineDatabase(DatabaseFilename, UpdatedDatabaseFilename) if err != nil { panic(err) } @@ -242,7 +242,7 @@ func BenchmarkHTTPPassword(b *testing.B) { func BenchmarkHTTPRange(b *testing.B) { // open the offline database - od, err := NewOfflineDatabase(DatabaseFilename) + od, err := NewOfflineDatabase(DatabaseFilename, UpdatedDatabaseFilename) if err != nil { panic(err) }