From 2dc4eb5c2662b8d7f76fb03d3065007df925d4e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bal=C3=A1zs=20Grill?= Date: Sat, 13 Jul 2024 05:40:32 +0200 Subject: [PATCH] first commit --- example/main.go | 25 ++++ filesystem.go | 291 +++++++++++++++++++++++++++++++++++++++++++++ filesystem_test.go | 213 +++++++++++++++++++++++++++++++++ go.mod | 13 ++ go.sum | 6 + 5 files changed, 548 insertions(+) create mode 100644 example/main.go create mode 100644 filesystem.go create mode 100644 filesystem_test.go create mode 100644 go.mod create mode 100644 go.sum diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..5af5e1b --- /dev/null +++ b/example/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "log" + "os" + "os/signal" + "syscall" + + "github.com/balazsgrill/projfero" + "github.com/spf13/afero" +) + +func main() { + fs := afero.NewBasePathFs(afero.NewOsFs(), "C:\\work\\vfsbase") + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + closer, err := projfero.StartProjecting("C:\\work\\vfs", fs) + if err != nil { + log.Panic(err) + } + + <-c + closer.Close() + os.Exit(1) +} diff --git a/filesystem.go b/filesystem.go new file mode 100644 index 0000000..5c859c5 --- /dev/null +++ b/filesystem.go @@ -0,0 +1,291 @@ +package projfero + +import ( + "encoding/binary" + "errors" + "io" + "io/fs" + "log" + "os" + "syscall" + "unsafe" + + "C" + + "github.com/balazsgrill/projfs" + "github.com/google/uuid" + "github.com/spf13/afero" +) + +type VirtualizationInstance struct { + rootPath string + fs afero.Fs + _instanceHandle projfs.PRJ_NAMESPACE_VIRTUALIZATION_CONTEXT + enumerations map[syscall.GUID]*enumerationSession +} + +type enumerationSession struct { + searchstr uintptr + countget int + sentcount int + wildcard bool +} + +func (instance *VirtualizationInstance) Close() error { + if instance._instanceHandle == 0 { + return errors.New("not started") + } + projfs.PrjStopVirtualizing(instance._instanceHandle) + instance._instanceHandle = 0 + log.Println("Stopped virtualization") + return nil +} + +func StartProjecting(rootPath string, filesystem afero.Fs) (io.Closer, error) { + instance := &VirtualizationInstance{ + enumerations: make(map[syscall.GUID]*enumerationSession), + } + return instance, instance.start(rootPath, filesystem) +} + +func (instance *VirtualizationInstance) start(rootPath string, filesystem afero.Fs) error { + if instance._instanceHandle != 0 { + return errors.New("already started") + } + instance.rootPath = rootPath + instance.fs = filesystem + + id, err := instance.ensureVirtualizationFolderExists() + if err != nil { + return err + } + + hr := projfs.PrjMarkDirectoryAsPlaceholder(rootPath, "", nil, id) + if hr != 0 { + log.Printf("Error marking directory as placeholder: %s", projfs.ErrorByCode(hr)) + return projfs.ErrorByCode(hr) + } + log.Printf("Starting virtualization of '%s' (%v)", rootPath, *id) + options := &projfs.PRJ_STARTVIRTUALIZING_OPTIONS{ + NotificationMappings: &projfs.PRJ_NOTIFICATION_MAPPING{ + NotificationBitMask: projfs.PRJ_NOTIFY_NEW_FILE_CREATED | projfs.PRJ_NOTIFY_FILE_OVERWRITTEN | projfs.PRJ_NOTIFY_FILE_HANDLE_CLOSED_FILE_DELETED | projfs.PRJ_NOTIFY_FILE_HANDLE_CLOSED_FILE_MODIFIED, + NotificationRoot: projfs.GetPointer(""), + }, + NotificationMappingsCount: 1, + PoolThreadCount: 4, + ConcurrentThreadCount: 4, + } + hr = projfs.PrjStartVirtualizing(rootPath, instance.get_callbacks(), instance, options, &instance._instanceHandle) + return projfs.ErrorByCode(hr) +} + +func (instance *VirtualizationInstance) getVirtualizationInfoFileName() string { + return instance.rootPath + "\\.virtualization" +} + +func bytesToGuid(b []byte) *syscall.GUID { + return &syscall.GUID{ + Data1: binary.LittleEndian.Uint32(b[0:4]), + Data2: binary.LittleEndian.Uint16(b[4:6]), + Data3: binary.LittleEndian.Uint16(b[6:8]), + Data4: ([8]byte)(b[8:16]), + } +} + +func (instance *VirtualizationInstance) ensureVirtualizationFolderExists() (*syscall.GUID, error) { + err := os.MkdirAll(instance.rootPath, 0777) + if err != nil { + return nil, err + } + + if _, err := os.Stat(instance.getVirtualizationInfoFileName()); errors.Is(err, os.ErrNotExist) { + uuid, _ := uuid.NewRandom() + id := bytesToGuid(uuid[:]) + err = os.WriteFile(instance.getVirtualizationInfoFileName(), uuid[:], 0666) + if err != nil { + return nil, err + } + return id, nil + } + + bytes, err := os.ReadFile(instance.getVirtualizationInfoFileName()) + if err != nil { + return nil, err + } + if len(bytes) != 16 { + return nil, errors.New("invalid virtualization info file") + } + + return bytesToGuid(bytes), nil +} + +func (instance *VirtualizationInstance) get_callbacks() *projfs.PRJ_CALLBACKS { + return &projfs.PRJ_CALLBACKS{ + NotificationCallback: instance.Notify, + QueryFileNameCallback: instance.QueryFileName, + CancelCommandCallback: instance.CancelCommand, + StartDirectoryEnumerationCallback: instance.StartDirectoryEnumeration, + GetDirectoryEnumerationCallback: instance.GetDirectoryEnumeration, + EndDirectoryEnumerationCallback: instance.EndDirectoryEnumeration, + GetPlaceholderInfoCallback: instance.GetPlaceholderInfo, + GetFileDataCallback: instance.GetFileData, + } +} + +func (instance *VirtualizationInstance) UpdateFileIfNeeded(relativePath string, placeholderInfo *projfs.PRJ_PLACEHOLDER_INFO, length uint32, updateFlags projfs.PRJ_UPDATE_TYPES, failureReason *projfs.PRJ_UPDATE_FAILURE_CAUSES) error { + return projfs.ErrorByCode(projfs.PrjUpdateFileIfNeeded(instance._instanceHandle, relativePath, placeholderInfo, length, updateFlags, failureReason)) +} + +func returncode(err error) uintptr { + if err != nil { + log.Println(err) + return 1 + } + return 0 +} + +func (instance *VirtualizationInstance) Notify(callbackData *projfs.PRJ_CALLBACK_DATA, IsDirectory bool, notification projfs.PRJ_NOTIFICATION, destinationFileName uintptr, operationParameters *projfs.PRJ_NOTIFICATION_PARAMETERS) uintptr { + // operation is done on file system + filename := callbackData.GetFilePathName() + log.Printf("Notify: %t %d %d '%s', %d", IsDirectory, callbackData.CommandId, notification, filename, *operationParameters) + switch notification { + + case projfs.PRJ_NOTIFICATION_NEW_FILE_CREATED: + if IsDirectory { + return returncode(instance.fs.Mkdir(filename, 0777)) + } else { + _, err := instance.fs.Create(filename) + if err != nil { + return returncode(err) + } + return returncode(err) + } + case projfs.PRJ_NOTIFICATION_FILE_HANDLE_CLOSED_FILE_MODIFIED, projfs.PRJ_NOTIFICATION_FILE_OVERWRITTEN: + if !IsDirectory { + data, err := os.ReadFile(instance.rootPath + "\\" + filename) + if err != nil { + return returncode(err) + } + return returncode(afero.WriteFile(instance.fs, filename, data, 0666)) + } + case projfs.PRJ_NOTIFICATION_FILE_HANDLE_CLOSED_FILE_DELETED: + return returncode(instance.fs.Remove(filename)) + } + return 0 +} + +func (instance *VirtualizationInstance) QueryFileName(callbackData *projfs.PRJ_CALLBACK_DATA) uintptr { + log.Printf("QueryFileName: '%s'", callbackData.GetFilePathName()) + return 0 +} + +func (instance *VirtualizationInstance) CancelCommand(callbackData *projfs.PRJ_CALLBACK_DATA) uintptr { + return 0 +} + +func (instance *VirtualizationInstance) StartDirectoryEnumeration(callbackData *projfs.PRJ_CALLBACK_DATA, enumerationId *syscall.GUID) uintptr { + log.Printf("StartDirectoryEnumeration: '%v'", *enumerationId) + instance.enumerations[*enumerationId] = &enumerationSession{ + searchstr: 0, + countget: 0, + sentcount: 0, + wildcard: false, + } + return 0 +} + +func (instance *VirtualizationInstance) EndDirectoryEnumeration(callbackData *projfs.PRJ_CALLBACK_DATA, enumerationId *syscall.GUID) uintptr { + log.Printf("EndDirectoryEnumeration: '%v'", *enumerationId) + instance.enumerations[*enumerationId] = nil + return 0 +} + +func (instance *VirtualizationInstance) GetDirectoryEnumeration(callbackData *projfs.PRJ_CALLBACK_DATA, enumerationId *syscall.GUID, searchExpression uintptr, dirEntryBufferHandle projfs.PRJ_DIR_ENTRY_BUFFER_HANDLE) uintptr { + filepath := callbackData.GetFilePathName() + first := instance.enumerations[*enumerationId].countget == 0 + restart := callbackData.Flags&projfs.PRJ_CB_DATA_FLAG_ENUM_RESTART_SCAN != 0 + + session, ok := instance.enumerations[*enumerationId] + if !ok { + return uintptr(syscall.EINVAL) + } + log.Printf("GetDirectoryEnumeration (%t, %t, %d) %s", first, restart, session.sentcount, filepath) + + if restart || first { + session.sentcount = 0 + if searchExpression != 0 { + session.searchstr = searchExpression + session.wildcard = projfs.PrjDoesNameContainWildCards(searchExpression) + } else { + session.searchstr = 0 + session.wildcard = false + } + } + instance.enumerations[*enumerationId].countget++ + + files, err := afero.ReadDir(instance.fs, filepath) + if err != nil { + log.Printf("Error reading directory %s: %s", filepath, err) + return uintptr(syscall.EIO) + } + + for _, file := range files[session.sentcount:] { + if session.searchstr != 0 { + match := projfs.PrjFileNameMatch(file.Name(), session.searchstr) + if !match { + continue + } + } + dirEntry := toBasicInfo(file) + projfs.PrjFillDirEntryBuffer(file.Name(), &dirEntry, dirEntryBufferHandle) + session.sentcount += 1 + } + log.Printf("Sent %d entries", session.sentcount) + return 0 +} + +func toBasicInfo(file fs.FileInfo) projfs.PRJ_FILE_BASIC_INFO { + return projfs.PRJ_FILE_BASIC_INFO{ + IsDirectory: file.IsDir(), + FileSize: file.Size(), + CreationTime: file.ModTime().Unix(), + LastAccessTime: file.ModTime().Unix(), + LastWriteTime: file.ModTime().Unix(), + ChangeTime: file.ModTime().Unix(), + FileAttributes: 0, + } +} + +func (instance *VirtualizationInstance) GetPlaceholderInfo(callbackData *projfs.PRJ_CALLBACK_DATA) uintptr { + var data projfs.PRJ_PLACEHOLDER_INFO + filename := callbackData.GetFilePathName() + log.Printf("GetPlaceholderInfo %s", filename) + stat, err := instance.fs.Stat(filename) + if os.IsNotExist(err) { + return uintptr(0x80070002) + } + if err != nil { + log.Printf("Error getting placeholder info for %s: %s", filename, err) + return uintptr(syscall.EIO) + } + data.FileBasicInfo = toBasicInfo(stat) + return projfs.PrjWritePlaceholderInfo(instance._instanceHandle, callbackData.GetFilePathName(), &data, uint32(unsafe.Sizeof(data))) +} + +func (instance *VirtualizationInstance) GetFileData(callbackData *projfs.PRJ_CALLBACK_DATA, byteOffset uint64, length uint32) uintptr { + filename := callbackData.GetFilePathName() + log.Printf("GetFileData %s", filename) + file, err := instance.fs.Open(filename) + if err != nil { + log.Printf("Error opening file %s: %s", filename, err) + return uintptr(syscall.EIO) + } + defer file.Close() + buffer := make([]byte, length) + _, err = file.ReadAt(buffer, int64(byteOffset)) + if err != nil { + log.Printf("Error reading file %s: %s", filename, err) + return uintptr(syscall.EIO) + } + return projfs.PrjWriteFileData(instance._instanceHandle, &callbackData.DataStreamId, &buffer[0], byteOffset, length) +} diff --git a/filesystem_test.go b/filesystem_test.go new file mode 100644 index 0000000..717f512 --- /dev/null +++ b/filesystem_test.go @@ -0,0 +1,213 @@ +package projfero_test + +import ( + "bytes" + "io" + "log" + "os" + "os/exec" + "reflect" + "strings" + "testing" + + "github.com/balazsgrill/projfero" + "github.com/spf13/afero" +) + +type testInstance struct { + t *testing.T + location string + fs afero.Fs + closer io.Closer + closechan chan bool +} + +func newTestInstance(t *testing.T) *testInstance { + location := t.TempDir() + os.RemoveAll(location) + os.MkdirAll(location, 0x777) + return &testInstance{ + t: t, + location: location, + fs: afero.NewMemMapFs(), + closechan: make(chan bool), + } +} + +func (i *testInstance) start() { + started := make(chan bool) + var err error + go func() { + i.closer, err = projfero.StartProjecting(i.location, i.fs) + started <- true + <-i.closechan + i.closer.Close() + }() + <-started + if err != nil { + log.Fatal(err) + } +} + +func (i *testInstance) osWriteFile(filename string, content string) error { + return exec.Command("cmd", "/c", "echo", content, ">", i.location+"\\"+filename).Run() +} + +func (i *testInstance) osRemoveFile(filename string) error { + return exec.Command("cmd", "/c", "del", i.location+"\\"+filename).Run() +} + +func (i *testInstance) osCreateDir(filename string) error { + return exec.Command("cmd", "/c", "mkdir", i.location+"\\"+filename).Run() +} + +func (i *testInstance) osRemoveDir(filename string) error { + return exec.Command("cmd", "/c", "rmdir", i.location+"\\"+filename).Run() +} + +func (i *testInstance) stop() { + i.closechan <- true +} + +func TestExistingFileOnBackend(t *testing.T) { + instance := newTestInstance(t) + + data := []byte("something") + filename := "test.txt" + err := afero.WriteFile(instance.fs, filename, data, 0x777) + if err != nil { + t.Fatal(err) + } + + instance.start() + defer instance.stop() + + data2, err := os.ReadFile(instance.location + "\\" + filename) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, data2) { + t.Errorf("expected %v, got %v", data, data2) + } +} + +func TestFileCreation(t *testing.T) { + instance := newTestInstance(t) + instance.start() + defer instance.stop() + + filename := "test.txt" + data := "something" + err := instance.osWriteFile(filename, data) + if err != nil { + t.Fatal(err) + } + + data2, err := afero.ReadFile(instance.fs, filename) + if err != nil { + t.Fatal(err) + } + + if data != strings.TrimSpace(string(data2)) { + t.Errorf("expected '%s', got '%s'", data, string(data2)) + } +} + +func TestUpdateExistingFileOnBackend(t *testing.T) { + instance := newTestInstance(t) + + data := "something" + filename := "test.txt" + err := afero.WriteFile(instance.fs, filename, []byte(data), 0x777) + if err != nil { + t.Fatal(err) + } + + instance.start() + defer instance.stop() + + data = "somethingelse" + err = instance.osWriteFile(filename, data) + if err != nil { + t.Fatal(err) + } + + data2, err := afero.ReadFile(instance.fs, filename) + if err != nil { + t.Fatal(err) + } + + if data != strings.TrimSpace(string(data2)) { + t.Errorf("expected %s, got %s", data, string(data2)) + } +} + +func TestDeleteExistingFileOnBackend(t *testing.T) { + instance := newTestInstance(t) + data := "something" + filename := "test.txt" + err := afero.WriteFile(instance.fs, filename, []byte(data), 0x777) + if err != nil { + t.Fatal(err) + } + + instance.start() + defer instance.stop() + + err = instance.osRemoveFile(filename) + if err != nil { + t.Fatal(err) + } + + _, err = instance.fs.Stat(filename) + if err != nil { + if os.IsNotExist(err) { + //ok + return + } + t.Fatal(err) + } else { + t.Error("File exists") + } +} + +func TestListFiles(t *testing.T) { + instance := newTestInstance(t) + instance.start() + defer instance.stop() + + data := "something" + filename := "test.txt" + err := afero.WriteFile(instance.fs, filename, []byte(data), 0x777) + if err != nil { + t.Fatal(err) + } + + filename2 := "test2.txt" + err = instance.osWriteFile(filename2, data) + if err != nil { + t.Fatal(err) + } + + expected := make(map[string]bool) + expected[filename] = true + expected[filename2] = true + + entries, err := os.ReadDir(instance.location) + if err != nil { + t.Fatal(err) + } + + actual := make(map[string]bool) + for _, entry := range entries { + if strings.HasPrefix(entry.Name(), ".") { + continue + } + actual[entry.Name()] = true + } + + if !reflect.DeepEqual(expected, actual) { + t.Errorf("expected %v, got %v", expected, actual) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..79c8db6 --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module github.com/balazsgrill/projfero + +go 1.21.0 + +require ( + github.com/balazsgrill/projfs v0.0.0 + github.com/google/uuid v1.6.0 + github.com/spf13/afero v1.11.0 +) + +require golang.org/x/text v0.14.0 // indirect + +replace github.com/balazsgrill/projfs => ../projfs diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..97ff556 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=