diff --git a/scp.go b/scp.go index 56f29b8..06ac7ee 100644 --- a/scp.go +++ b/scp.go @@ -4,7 +4,9 @@ import ( "fmt" "golang.org/x/crypto/ssh" "io" + "io/ioutil" "os" + "path/filepath" ) // Copy: Copy `from` to `target` @@ -15,7 +17,7 @@ func Copy(session *ssh.Session, from, target string) error { } if stat.IsDir() { - return CopyFolder(session, from, target) + return CopyFolder(session, from, stat.Mode(), target) } f, err := os.Open(from) @@ -38,7 +40,11 @@ func CopyFile(session *ssh.Session, file io.Reader, filename string, size int64, return err } - err = client.WriteFile(ConvertFileModeToPermString(mode), size, filename, file) + return copyFile(client, ConvertFileModeToPermString(mode), size, filename, file) +} + +func copyFile(client *RemoteClient, perm string, size int64, filename string, file io.Reader) error { + err := client.WriteFile(perm, size, filename, file) if err != nil { return err } @@ -46,6 +52,58 @@ func CopyFile(session *ssh.Session, file io.Reader, filename string, size int64, return nil } -func CopyFolder(session *ssh.Session, from, target string) error { +func CopyFolder(session *ssh.Session, from string, mode os.FileMode, target string) error { + client, err := NewClient(session) + if err != nil { + return err + } + + err = client.Start(target, true) + if err != nil { + return err + } + + return copyFolder(client, ConvertFileModeToPermString(mode), from) +} + +func copyFolder(client *RemoteClient, perm string, path string) error { + err := client.WriteDirectoryStart(perm, filepath.Base(path)) + if err != nil { + return err + } + + files, err := ioutil.ReadDir(path) + + if err != nil { + return err + } + + for _, file := range files { + if file.IsDir() { + err = copyFolder(client, ConvertFileModeToPermString(file.Mode()), filepath.Join(path, file.Name())) + + if err != nil { + return err + } + } else { + f, err := os.Open(filepath.Join(path, file.Name())) + if err != nil { + return err + } + + err = copyFile(client, ConvertFileModeToPermString(file.Mode()), file.Size(), file.Name(), f) + f.Close() + + if err != nil { + return err + } + } + } + + err = client.WriteDirectoryEnd() + if err != nil { + return err + } + return nil } diff --git a/scp_test.go b/scp_test.go index 4d91183..48066a8 100644 --- a/scp_test.go +++ b/scp_test.go @@ -53,3 +53,26 @@ func TestCopyFileToFolder(t *testing.T) { tt.AssertTrue(t, checkFileExist(t, remotePath)) tt.AssertEqual(t, "abcdefg", string(readFile(t, remotePath))) } + +func TestCopyFolderToFolder(t *testing.T) { + reset(t) + + session := getSshSession(t) + defer session.Close() + + err := Copy(session, "tests_data/t", "/test") + tt.AssertIsNil(t, err) + + tt.AssertTrue(t, checkFileExist(t, "/test/t/a")) + tt.AssertTrue(t, checkFileExist(t, "/test/t/a/b")) + tt.AssertTrue(t, checkFileExist(t, "/test/t/a/c")) + tt.AssertTrue(t, checkFileExist(t, "/test/t/b")) + tt.AssertTrue(t, checkFileExist(t, "/test/t/c")) + tt.AssertTrue(t, checkFileExist(t, "/test/t/c/d")) + tt.AssertTrue(t, checkFileExist(t, "/test/t/e")) + + tt.AssertEqual(t, "hahaha", string(readFile(t, "/test/t/a/b"))) + tt.AssertEqual(t, "xixixixi", string(readFile(t, "/test/t/a/c"))) + tt.AssertEqual(t, "", string(readFile(t, "/test/t/c/d"))) + tt.AssertEqual(t, "root", string(readFile(t, "/test/t/e"))) +} diff --git a/tests_data/t/a/b b/tests_data/t/a/b new file mode 100644 index 0000000..1240583 --- /dev/null +++ b/tests_data/t/a/b @@ -0,0 +1 @@ +hahaha \ No newline at end of file diff --git a/tests_data/t/a/c b/tests_data/t/a/c new file mode 100644 index 0000000..d019789 --- /dev/null +++ b/tests_data/t/a/c @@ -0,0 +1 @@ +xixixixi \ No newline at end of file diff --git a/tests_data/t/c/d b/tests_data/t/c/d new file mode 100644 index 0000000..e69de29 diff --git a/tests_data/t/e b/tests_data/t/e new file mode 100644 index 0000000..93ca142 --- /dev/null +++ b/tests_data/t/e @@ -0,0 +1 @@ +root \ No newline at end of file