diff --git a/filesystem.go b/filesystem.go index d67a651..a494e8e 100644 --- a/filesystem.go +++ b/filesystem.go @@ -16,6 +16,11 @@ import ( "github.com/google/uuid" "github.com/spf13/afero" ) +import ( + "fmt" + "path/filepath" + "strings" +) type VirtualizationInstance struct { rootPath string @@ -46,7 +51,7 @@ func (instance *VirtualizationInstance) Close() error { return nil } -func StartProjecting(rootPath string, filesystem afero.Fs) (io.Closer, error) { +func StartProjecting(rootPath string, filesystem afero.Fs) (Virtualization, error) { instance := &VirtualizationInstance{ enumerations: make(map[syscall.GUID]*enumerationSession), } @@ -86,18 +91,26 @@ func (instance *VirtualizationInstance) start(rootPath string, filesystem afero. log.Printf("Error starting virtualization: %s", err) return err } - return instance.syncRemoteToLcal() + return instance.syncRemoteToLocal() +} + +func (instance *VirtualizationInstance) PerformSynchronization() error { + err := instance.syncRemoteToLocal() + if err != nil { + return err + } + return instance.syncLocalToRemote() } -func (instance *VirtualizationInstance) syncRemoteToLcal() error { - return afero.Walk(instance.fs, "", func(path string, info fs.FileInfo, err error) error { +func (instance *VirtualizationInstance) syncRemoteToLocal() error { + return afero.Walk(instance.fs, "", func(path string, remoteinfo fs.FileInfo, err error) error { if os.IsNotExist(err) { return nil } if err != nil { return err } - if info.IsDir() { + if remoteinfo.IsDir() { return nil } localpath := instance.rootPath + "\\" + path @@ -107,13 +120,18 @@ func (instance *VirtualizationInstance) syncRemoteToLcal() error { return projfs.ErrorByCode(hr) } - if localstate == projfs.PRJ_FILE_STATE_FULL { + if (localstate | (projfs.PRJ_FILE_STATE_FULL & projfs.PRJ_FILE_STATE_HYDRATED_PLACEHOLDER)) != 0 { // check if remote is newer localinfo, _ := os.Stat(localpath) - if localinfo.ModTime().UTC().Unix() > info.ModTime().UTC().Unix() { + if localinfo.ModTime().UTC().Unix() < remoteinfo.ModTime().UTC().Unix() { + log.Printf("Updating local file '%s'", path) var placeholderInfo projfs.PRJ_PLACEHOLDER_INFO - placeholderInfo.FileBasicInfo = toBasicInfo(info) - instance.UpdateFileIfNeeded(path, &placeholderInfo, uint32(info.Size()), projfs.PRJ_UPDATE_ALLOW_DIRTY_METADATA, nil) + FillInPlaceholderInfo(&placeholderInfo, remoteinfo) + //err = projfs.ErrorByCode(projfs.PrjWritePlaceholderInfo(instance._instanceHandle, path, &placeholderInfo, uint32(unsafe.Sizeof(placeholderInfo)))) + err = instance.UpdateFileIfNeeded(path, &placeholderInfo, uint32(unsafe.Sizeof(placeholderInfo)), projfs.PRJ_UPDATE_ALLOW_DIRTY_METADATA|projfs.PRJ_UPDATE_ALLOW_DIRTY_DATA) + if err != nil { + return err + } } } @@ -121,6 +139,58 @@ func (instance *VirtualizationInstance) syncRemoteToLcal() error { }) } +func (instance *VirtualizationInstance) syncLocalToRemote() error { + return filepath.Walk(instance.rootPath, func(localpath string, localinfo fs.FileInfo, err error) error { + if os.IsNotExist(err) { + return nil + } + if err != nil { + return err + } + + path := strings.TrimPrefix(localpath, instance.rootPath) + path = strings.TrimPrefix(path, "\\") + if localinfo.IsDir() { + return instance.fs.MkdirAll(path, 0777) + } + if strings.HasPrefix(path, ".") { + return nil + } + + var localstate projfs.PRJ_FILE_STATE + hr := projfs.PrjGetOnDiskFileState(localpath, &localstate) + if hr != 0 { + return projfs.ErrorByCode(hr) + } + + if (localstate | (projfs.PRJ_FILE_STATE_FULL & projfs.PRJ_FILE_STATE_HYDRATED_PLACEHOLDER)) != 0 { + // check if local is newer + remoteinfo, err := instance.fs.Stat(path) + if os.IsNotExist(err) { + // new local file, remote does not exist + log.Printf("Uploading file '%s'", path) + return instance.streamLocalToRemote(path) + } + if err != nil { + return err + } + // info from walk return modification time of 2185 TODO: why? + localinfo, err := os.Stat(localpath) + if err != nil { + return err + } + localmodtime := localinfo.ModTime() + localtime := localmodtime.Unix() + remotetime := remoteinfo.ModTime().Unix() + if localtime > remotetime { + log.Printf("Updating remote file '%s'", path) + return instance.streamLocalToRemote(path) + } + } + return nil + }) +} + func (instance *VirtualizationInstance) getVirtualizationInfoFileName() string { return instance.rootPath + "\\.virtualization" } @@ -174,8 +244,13 @@ func (instance *VirtualizationInstance) get_callbacks() *projfs.PRJ_CALLBACKS { } } -func (instance *VirtualizationInstance) UpdateFileIfNeeded(relativePath string, placeholderInfo *projfs.PRJ_PLACEHOLDER_INFO, length uint32, updateFlags projfs.PRJ_UPDATE_TYPES, failureReason *projfs.PRJ_UPDATE_FAILURE_CAUSES) error { - return projfs.ErrorByCode(projfs.PrjUpdateFileIfNeeded(instance._instanceHandle, relativePath, placeholderInfo, length, updateFlags, failureReason)) +func (instance *VirtualizationInstance) UpdateFileIfNeeded(relativePath string, placeholderInfo *projfs.PRJ_PLACEHOLDER_INFO, length uint32, updateFlags projfs.PRJ_UPDATE_TYPES) error { + var failureReason projfs.PRJ_UPDATE_FAILURE_CAUSES + err := projfs.ErrorByCode(projfs.PrjUpdateFileIfNeeded(instance._instanceHandle, relativePath, placeholderInfo, length, updateFlags, &failureReason)) + if err != nil { + err = fmt.Errorf("UpdateFileIfNeeded failed: %w (reason: %d)", err, failureReason) + } + return err } func returncode(err error) uintptr { @@ -204,11 +279,7 @@ func (instance *VirtualizationInstance) Notify(callbackData *projfs.PRJ_CALLBACK } case projfs.PRJ_NOTIFICATION_FILE_HANDLE_CLOSED_FILE_MODIFIED, projfs.PRJ_NOTIFICATION_FILE_OVERWRITTEN: if !IsDirectory { - data, err := os.ReadFile(instance.rootPath + "\\" + filename) - if err != nil { - return returncode(err) - } - return returncode(afero.WriteFile(instance.fs, filename, data, 0666)) + return returncode(instance.streamLocalToRemote(filename)) } case projfs.PRJ_NOTIFICATION_FILE_HANDLE_CLOSED_FILE_DELETED: return returncode(instance.fs.Remove(filename)) @@ -216,6 +287,14 @@ func (instance *VirtualizationInstance) Notify(callbackData *projfs.PRJ_CALLBACK return 0 } +func (instance *VirtualizationInstance) streamLocalToRemote(filename string) error { + data, err := os.ReadFile(instance.rootPath + "\\" + filename) + if err != nil { + return err + } + return afero.WriteFile(instance.fs, filename, data, 0666) +} + func (instance *VirtualizationInstance) QueryFileName(callbackData *projfs.PRJ_CALLBACK_DATA) uintptr { log.Printf("QueryFileName: '%s'", callbackData.GetFilePathName()) return 0 @@ -287,17 +366,35 @@ func (instance *VirtualizationInstance) GetDirectoryEnumeration(callbackData *pr } func toBasicInfo(file fs.FileInfo) projfs.PRJ_FILE_BASIC_INFO { + ftime := syscall.NsecToFiletime(file.ModTime().UnixNano()) return projfs.PRJ_FILE_BASIC_INFO{ IsDirectory: file.IsDir(), FileSize: file.Size(), - CreationTime: file.ModTime().Unix(), - LastAccessTime: file.ModTime().Unix(), - LastWriteTime: file.ModTime().Unix(), - ChangeTime: file.ModTime().Unix(), + CreationTime: ftime, + LastAccessTime: ftime, + LastWriteTime: ftime, + ChangeTime: ftime, FileAttributes: 0, } } +func getVersionInfo(basicInfo *projfs.PRJ_FILE_BASIC_INFO) projfs.PRJ_PLACEHOLDER_VERSION_INFO { + result := projfs.PRJ_PLACEHOLDER_VERSION_INFO{ + ProviderID: [projfs.PRJ_PLACEHOLDER_ID_LENGTH]byte{0, 0x1}, + ContentID: [projfs.PRJ_PLACEHOLDER_ID_LENGTH]byte{0}, + } + + version := uint64(basicInfo.LastWriteTime.Nanoseconds()) + binary.LittleEndian.PutUint64(result.ContentID[:], version) + log.Printf("Version: %d %v", version, result.ContentID) + return result +} + +func FillInPlaceholderInfo(data *projfs.PRJ_PLACEHOLDER_INFO, fileinfo fs.FileInfo) { + data.FileBasicInfo = toBasicInfo(fileinfo) + data.VersionInfo = getVersionInfo(&data.FileBasicInfo) +} + func (instance *VirtualizationInstance) GetPlaceholderInfo(callbackData *projfs.PRJ_CALLBACK_DATA) uintptr { var data projfs.PRJ_PLACEHOLDER_INFO filename := callbackData.GetFilePathName() @@ -310,7 +407,7 @@ func (instance *VirtualizationInstance) GetPlaceholderInfo(callbackData *projfs. log.Printf("Error getting placeholder info for %s: %s", filename, err) return uintptr(syscall.EIO) } - data.FileBasicInfo = toBasicInfo(stat) + FillInPlaceholderInfo(&data, stat) return projfs.PrjWritePlaceholderInfo(instance._instanceHandle, callbackData.GetFilePathName(), &data, uint32(unsafe.Sizeof(data))) } diff --git a/filesystem_test.go b/filesystem_test.go index 6da0b3e..1e01a52 100644 --- a/filesystem_test.go +++ b/filesystem_test.go @@ -2,13 +2,13 @@ package projfero_test import ( "bytes" - "io" "log" "os" "os/exec" "reflect" "strings" "testing" + "time" "github.com/balazsgrill/projfero" "github.com/spf13/afero" @@ -18,7 +18,7 @@ type testInstance struct { t *testing.T location string fs afero.Fs - closer io.Closer + closer projfero.Virtualization closechan chan bool } @@ -295,12 +295,19 @@ func TestChangedOnBackend(t *testing.T) { t.Errorf("expected %v, got %v", data, data2) } + // sleep for a bit to ensure that the file timestamp is different + time.Sleep(time.Second) data = []byte("somethingelse") err = afero.WriteFile(instance.fs, filename, data, 0x777) if err != nil { t.Fatal(err) } + err = instance.closer.PerformSynchronization() + if err != nil { + t.Fatal(err) + } + data2, err = os.ReadFile(instance.location + "\\" + filename) if err != nil { t.Fatal(err) @@ -350,6 +357,7 @@ func TestUpdatedLocallyWhileOffline(t *testing.T) { } instance.stop() + time.Sleep(time.Second) data = []byte("somethingelse") err = instance.osWriteFile(filename, string(data)) @@ -358,13 +366,17 @@ func TestUpdatedLocallyWhileOffline(t *testing.T) { } instance.start() + err = instance.closer.PerformSynchronization() + if err != nil { + t.Fatal(err) + } data2, err := afero.ReadFile(instance.fs, filename) if err != nil { t.Fatal(err) } - if !bytes.Equal(data, data2) { - t.Errorf("expected %s, got %s", string(data), string(data2)) + if string(data) != strings.TrimSpace(string(data2)) { + t.Errorf("expected '%s', got '%s'", string(data), string(data2)) } instance.stop() } diff --git a/go.mod b/go.mod index c5fd4cf..a6f47e2 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/balazsgrill/projfero go 1.21.0 require ( - github.com/balazsgrill/projfs v0.0.1 + github.com/balazsgrill/projfs v0.0.2 github.com/google/uuid v1.6.0 github.com/spf13/afero v1.11.0 ) diff --git a/go.sum b/go.sum index eb10ac3..6572fb7 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/balazsgrill/projfs v0.0.1 h1:ot7RKyspqAfLdr3QZjLZzhS6m0lcxp/VPrRELhZNh0U= -github.com/balazsgrill/projfs v0.0.1/go.mod h1:zxk3JTKjlt3AKYJ98gJQEpJ6ZQ+qtkyuqsgfMLb5mW4= +github.com/balazsgrill/projfs v0.0.2 h1:HyhDgJT1sgIdHd1j1OHr/YYeDauQINzd+irerpaBlUM= +github.com/balazsgrill/projfs v0.0.2/go.mod h1:zxk3JTKjlt3AKYJ98gJQEpJ6ZQ+qtkyuqsgfMLb5mW4= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=