From a1035f082b7d61d1df37487d495d62e5388d1e0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bal=C3=A1zs=20Grill?= Date: Thu, 18 Jul 2024 17:41:48 +0200 Subject: [PATCH] tested and implemented conflict scenarios --- filesystem.go | 150 ++++++++++++++++++++++++++++++++---- filesystem_test.go | 184 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 318 insertions(+), 16 deletions(-) diff --git a/filesystem.go b/filesystem.go index a494e8e..189f4ae 100644 --- a/filesystem.go +++ b/filesystem.go @@ -17,6 +17,8 @@ import ( "github.com/spf13/afero" ) import ( + "bytes" + "crypto/md5" "fmt" "path/filepath" "strings" @@ -94,12 +96,40 @@ func (instance *VirtualizationInstance) start(rootPath string, filesystem afero. return instance.syncRemoteToLocal() } +func (instance *VirtualizationInstance) path_localToRemote(path string) string { + p := strings.TrimPrefix(path, instance.rootPath) + p = strings.ReplaceAll(p, "\\", "/") + p = strings.TrimPrefix(p, "/") + return p +} + +func (instance *VirtualizationInstance) path_remoteToLocal(path string) string { + p := strings.TrimPrefix(path, "/") + p = strings.ReplaceAll(p, "/", "\\") + return filepath.Join(instance.rootPath, "\\", p) +} + +func (instance *VirtualizationInstance) path_getNameRemote(path string) string { + p := strings.TrimPrefix(path, "/") + return filepath.Base(p) +} + +func (instance *VirtualizationInstance) path_getNameLocal(path string) string { + return filepath.Base(strings.ReplaceAll(path, "\\", "/")) +} + +func (instance *VirtualizationInstance) path_hashFile(remotepath string) string { + fname := filepath.Base(remotepath) + dir := filepath.Dir(remotepath) + return dir + "/.md5_" + fname +} + func (instance *VirtualizationInstance) PerformSynchronization() error { - err := instance.syncRemoteToLocal() + err := instance.syncLocalToRemote() if err != nil { return err } - return instance.syncLocalToRemote() + return instance.syncRemoteToLocal() } func (instance *VirtualizationInstance) syncRemoteToLocal() error { @@ -113,7 +143,11 @@ func (instance *VirtualizationInstance) syncRemoteToLocal() error { if remoteinfo.IsDir() { return nil } - localpath := instance.rootPath + "\\" + path + filename := instance.path_getNameRemote(path) + if strings.HasPrefix(filename, ".") { + return nil + } + localpath := instance.path_remoteToLocal(path) var localstate projfs.PRJ_FILE_STATE hr := projfs.PrjGetOnDiskFileState(localpath, &localstate) if hr != 0 { @@ -139,6 +173,29 @@ func (instance *VirtualizationInstance) syncRemoteToLocal() error { }) } +func (instance *VirtualizationInstance) localHash(remotepath string) ([]byte, error) { + // only calculate hash if file is not a placeholder + var localstate projfs.PRJ_FILE_STATE + hr := projfs.PrjGetOnDiskFileState(instance.path_remoteToLocal(remotepath), &localstate) + if hr != 0 { + return nil, projfs.ErrorByCode(hr) + } + if (localstate | (projfs.PRJ_FILE_STATE_FULL & projfs.PRJ_FILE_STATE_HYDRATED_PLACEHOLDER)) == 0 { + return nil, nil + } + hash := md5.New() + f, err := os.Open(instance.path_remoteToLocal(remotepath)) + if err != nil { + return nil, err + } + defer f.Close() + _, err = io.Copy(hash, f) + if err != nil { + return nil, err + } + return hash.Sum(nil), nil +} + func (instance *VirtualizationInstance) syncLocalToRemote() error { return filepath.Walk(instance.rootPath, func(localpath string, localinfo fs.FileInfo, err error) error { if os.IsNotExist(err) { @@ -148,8 +205,7 @@ func (instance *VirtualizationInstance) syncLocalToRemote() error { return err } - path := strings.TrimPrefix(localpath, instance.rootPath) - path = strings.TrimPrefix(path, "\\") + path := instance.path_localToRemote(localpath) if localinfo.IsDir() { return instance.fs.MkdirAll(path, 0777) } @@ -167,7 +223,33 @@ func (instance *VirtualizationInstance) syncLocalToRemote() error { // check if local is newer remoteinfo, err := instance.fs.Stat(path) if os.IsNotExist(err) { - // new local file, remote does not exist + // chek if hash file exists on remote + hashpath := instance.path_hashFile(path) + exists, err := afero.Exists(instance.fs, hashpath) + if err != nil { + return err + } + if exists { + // on remote file existed before, upload only if hash is different + hash, err := afero.ReadFile(instance.fs, hashpath) + if err != nil { + return err + } + localhash, err := instance.localHash(path) + if err != nil { + return err + } + if localhash == nil { + // local file does not exist, no need to upload + // TODO is this a tombstone? + return err + } + if bytes.Equal(hash, localhash) { + // hash is the same this file has been removed remotely, delete local file + return os.Remove(localpath) + } + } + // new local file, remote does not exist, or hash is different log.Printf("Uploading file '%s'", path) return instance.streamLocalToRemote(path) } @@ -263,7 +345,7 @@ func returncode(err error) uintptr { func (instance *VirtualizationInstance) Notify(callbackData *projfs.PRJ_CALLBACK_DATA, IsDirectory bool, notification projfs.PRJ_NOTIFICATION, destinationFileName uintptr, operationParameters *projfs.PRJ_NOTIFICATION_PARAMETERS) uintptr { // operation is done on file system - filename := callbackData.GetFilePathName() + filename := instance.path_localToRemote(callbackData.GetFilePathName()) log.Printf("Notify: %t %d %d '%s', %d", IsDirectory, callbackData.CommandId, notification, filename, *operationParameters) switch notification { @@ -282,21 +364,52 @@ func (instance *VirtualizationInstance) Notify(callbackData *projfs.PRJ_CALLBACK return returncode(instance.streamLocalToRemote(filename)) } case projfs.PRJ_NOTIFICATION_FILE_HANDLE_CLOSED_FILE_DELETED: + // TODO establish protocol for deletion + // option 1: upon deletion, leave a placeholder (remove when recreated) + // option 2: upon creation, create an indicator, which remains upon deletion (can stay around) return returncode(instance.fs.Remove(filename)) } return 0 } func (instance *VirtualizationInstance) streamLocalToRemote(filename string) error { - data, err := os.ReadFile(instance.rootPath + "\\" + filename) + file, err := os.Open(instance.path_remoteToLocal(filename)) + if err != nil { + return err + } + defer file.Close() + data := make([]byte, 1024*1024) + targetfile, err := instance.fs.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0x666) if err != nil { return err } - return afero.WriteFile(instance.fs, filename, data, 0666) + defer targetfile.Close() + + hash := md5.New() + for { + n, err := file.Read(data) + if err != nil { + if err == io.EOF { + break + } + return err + } + _, err = hash.Write(data[:n]) + if err != nil { + return err + } + _, err = targetfile.Write(data[:n]) + if err != nil { + return err + } + } + + return afero.WriteFile(instance.fs, instance.path_hashFile(filename), hash.Sum(nil), 0666) } func (instance *VirtualizationInstance) QueryFileName(callbackData *projfs.PRJ_CALLBACK_DATA) uintptr { - log.Printf("QueryFileName: '%s'", callbackData.GetFilePathName()) + filename := instance.path_localToRemote(callbackData.GetFilePathName()) + log.Printf("QueryFileName: '%s'", filename) return 0 } @@ -322,7 +435,7 @@ func (instance *VirtualizationInstance) EndDirectoryEnumeration(callbackData *pr } func (instance *VirtualizationInstance) GetDirectoryEnumeration(callbackData *projfs.PRJ_CALLBACK_DATA, enumerationId *syscall.GUID, searchExpression uintptr, dirEntryBufferHandle projfs.PRJ_DIR_ENTRY_BUFFER_HANDLE) uintptr { - filepath := callbackData.GetFilePathName() + filenamepath := instance.path_localToRemote(callbackData.GetFilePathName()) first := instance.enumerations[*enumerationId].countget == 0 restart := callbackData.Flags&projfs.PRJ_CB_DATA_FLAG_ENUM_RESTART_SCAN != 0 @@ -330,7 +443,7 @@ func (instance *VirtualizationInstance) GetDirectoryEnumeration(callbackData *pr if !ok { return uintptr(syscall.EINVAL) } - log.Printf("GetDirectoryEnumeration (%t, %t, %d) %s", first, restart, session.sentcount, filepath) + log.Printf("GetDirectoryEnumeration (%t, %t, %d) %s", first, restart, session.sentcount, filenamepath) if restart || first { session.sentcount = 0 @@ -344,13 +457,18 @@ func (instance *VirtualizationInstance) GetDirectoryEnumeration(callbackData *pr } instance.enumerations[*enumerationId].countget++ - files, err := afero.ReadDir(instance.fs, filepath) + files, err := afero.ReadDir(instance.fs, filenamepath) if err != nil { - log.Printf("Error reading directory %s: %s", filepath, err) + log.Printf("Error reading directory %s: %s", filenamepath, err) return uintptr(syscall.EIO) } for _, file := range files[session.sentcount:] { + fname := filepath.Base(file.Name()) + if strings.HasPrefix(fname, ".") { + continue + } + if session.searchstr != 0 { match := projfs.PrjFileNameMatch(file.Name(), session.searchstr) if !match { @@ -397,7 +515,7 @@ func FillInPlaceholderInfo(data *projfs.PRJ_PLACEHOLDER_INFO, fileinfo fs.FileIn func (instance *VirtualizationInstance) GetPlaceholderInfo(callbackData *projfs.PRJ_CALLBACK_DATA) uintptr { var data projfs.PRJ_PLACEHOLDER_INFO - filename := callbackData.GetFilePathName() + filename := instance.path_localToRemote(callbackData.GetFilePathName()) log.Printf("GetPlaceholderInfo %s", filename) stat, err := instance.fs.Stat(filename) if os.IsNotExist(err) { @@ -412,7 +530,7 @@ func (instance *VirtualizationInstance) GetPlaceholderInfo(callbackData *projfs. } func (instance *VirtualizationInstance) GetFileData(callbackData *projfs.PRJ_CALLBACK_DATA, byteOffset uint64, length uint32) uintptr { - filename := callbackData.GetFilePathName() + filename := instance.path_localToRemote(callbackData.GetFilePathName()) log.Printf("GetFileData %s", filename) file, err := instance.fs.Open(filename) if err != nil { diff --git a/filesystem_test.go b/filesystem_test.go index 1e01a52..32f7986 100644 --- a/filesystem_test.go +++ b/filesystem_test.go @@ -415,3 +415,187 @@ func TestRemoveFolder(t *testing.T) { } } + +func TestDeletedOnBackendWhileOffline(t *testing.T) { + instance := newTestInstance(t) + instance.start() + + data := []byte("something") + filename := "test.txt" + err := instance.osWriteFile(filename, string(data)) + + if err != nil { + t.Fatal(err) + } + + instance.stop() + time.Sleep(time.Second) + + err = instance.fs.Remove(filename) + if err != nil { + t.Fatal(err) + } + _, err = instance.fs.Stat(filename) + if !os.IsNotExist(err) { + t.Error("remote file exists") + } + + instance.start() + err = instance.closer.PerformSynchronization() + if err != nil { + t.Fatal(err) + } + + _, err = os.Stat(instance.location + "\\" + filename) + if !os.IsNotExist(err) { + t.Error("local file exists") + } + _, err = instance.fs.Stat(filename) + if !os.IsNotExist(err) { + t.Error("remote file exists") + } + + instance.stop() +} + +func TestDeletedLocallyWhileOffline(t *testing.T) { + instance := newTestInstance(t) + instance.start() + + data := []byte("something") + filename := "test.txt" + err := instance.osWriteFile(filename, string(data)) + + if err != nil { + t.Fatal(err) + } + + instance.stop() + time.Sleep(time.Second) + + err = instance.osRemoveFile(filename) + if err != nil { + t.Fatal(err) + } + + instance.start() + err = instance.closer.PerformSynchronization() + if err != nil { + t.Fatal(err) + } + + _, err = os.Stat(instance.location + "\\" + filename) + if os.IsNotExist(err) { + t.Error("File should be restored locally") + } + _, err = instance.fs.Stat(filename) + if os.IsNotExist(err) { + t.Error("remote file should not be removed") + } + + instance.stop() +} + +func TestConflictWhileOfflineLocalNewer(t *testing.T) { + instance := newTestInstance(t) + instance.start() + + data := []byte("something") + filename := "test.txt" + err := instance.osWriteFile(filename, string(data)) + + if err != nil { + t.Fatal(err) + } + + instance.stop() + time.Sleep(time.Second) + + data2 := []byte("something2") + data3 := []byte("something3") + + err = afero.WriteFile(instance.fs, filename, data3, 0x777) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Second) + err = instance.osWriteFile(filename, string(data2)) + if err != nil { + t.Fatal(err) + } + + instance.start() + err = instance.closer.PerformSynchronization() + if err != nil { + t.Fatal(err) + } + + data4, err := afero.ReadFile(instance.fs, filename) + if err != nil { + t.Fatal(err) + } + if string(data2) != strings.TrimSpace(string(data4)) { + t.Errorf("expected '%s', got '%s'", string(data2), string(data4)) + } + data5, err := os.ReadFile(instance.location + "\\" + filename) + if err != nil { + t.Fatal(err) + } + if string(data2) != strings.TrimSpace(string(data5)) { + t.Errorf("expected '%s', got '%s'", string(data2), string(data5)) + } + + instance.stop() +} + +func TestConflictWhileOfflineRemoteNewer(t *testing.T) { + instance := newTestInstance(t) + instance.start() + + data := []byte("something") + filename := "test.txt" + err := instance.osWriteFile(filename, string(data)) + + if err != nil { + t.Fatal(err) + } + + instance.stop() + time.Sleep(time.Second) + + data2 := []byte("something2") + data3 := []byte("something3") + + err = instance.osWriteFile(filename, string(data2)) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Second) + err = afero.WriteFile(instance.fs, filename, data3, 0x777) + if err != nil { + t.Fatal(err) + } + + instance.start() + err = instance.closer.PerformSynchronization() + if err != nil { + t.Fatal(err) + } + + data4, err := afero.ReadFile(instance.fs, filename) + if err != nil { + t.Fatal(err) + } + if string(data3) != strings.TrimSpace(string(data4)) { + t.Errorf("expected '%s', got '%s'", string(data3), string(data4)) + } + data5, err := os.ReadFile(instance.location + "\\" + filename) + if err != nil { + t.Fatal(err) + } + if string(data3) != strings.TrimSpace(string(data5)) { + t.Errorf("expected '%s', got '%s'", string(data3), string(data5)) + } + + instance.stop() +}