diff --git a/knownhosts.go b/knownhosts.go index c460031..15d25fa 100644 --- a/knownhosts.go +++ b/knownhosts.go @@ -5,6 +5,7 @@ package knownhosts import ( "encoding/base64" "errors" + "fmt" "io" "net" "sort" @@ -140,11 +141,15 @@ func Line(addresses []string, key ssh.PublicKey) string { func WriteKnownHost(w io.Writer, hostname string, remote net.Addr, key ssh.PublicKey) error { // Always include hostname; only also include remote if it isn't a zero value // and doesn't normalize to the same string as hostname. - addresses := []string{hostname} - remoteStr := remote.String() - remoteStrNormalized := Normalize(remoteStr) - if remoteStrNormalized != "[0.0.0.0]:0" && remoteStrNormalized != Normalize(hostname) { - addresses = append(addresses, remoteStr) + hostnameNormalized := Normalize(hostname) + if strings.ContainsAny(hostnameNormalized, "\t ") { + return fmt.Errorf("knownhosts: hostname '%s' contains spaces", hostnameNormalized) + } + addresses := []string{hostnameNormalized} + remoteStrNormalized := Normalize(remote.String()) + if remoteStrNormalized != "[0.0.0.0]:0" && remoteStrNormalized != hostnameNormalized && + !strings.ContainsAny(remoteStrNormalized, "\t ") { + addresses = append(addresses, remoteStrNormalized) } line := Line(addresses, key) + "\n" _, err := w.Write([]byte(line)) diff --git a/knownhosts_test.go b/knownhosts_test.go index 536443c..f48d2b2 100644 --- a/knownhosts_test.go +++ b/knownhosts_test.go @@ -189,18 +189,31 @@ func TestWriteKnownHost(t *testing.T) { hostname string remoteAddr string want string + err string }{ {hostname: "::1", remoteAddr: "[::1]:22", want: "::1 " + edKeyStr + "\n"}, {hostname: "127.0.0.1", remoteAddr: "127.0.0.1:22", want: "127.0.0.1 " + edKeyStr + "\n"}, {hostname: "ipv4.test", remoteAddr: "192.168.0.1:23", want: "ipv4.test,[192.168.0.1]:23 " + edKeyStr + "\n"}, {hostname: "ipv6.test", remoteAddr: "[ff01::1234]:23", want: "ipv6.test,[ff01::1234]:23 " + edKeyStr + "\n"}, + {hostname: "normal.zone", remoteAddr: "[fe80::1%en0]:22", want: "normal.zone,fe80::1%en0 " + edKeyStr + "\n"}, + {hostname: "spaces.zone", remoteAddr: "[fe80::1%Ethernet 1]:22", want: "spaces.zone " + edKeyStr + "\n"}, + {hostname: "spaces.zone", remoteAddr: "[fe80::1%Ethernet\t2]:23", want: "spaces.zone " + edKeyStr + "\n"}, + {hostname: "[fe80::1%Ethernet 1]:22", err: "knownhosts: hostname 'fe80::1%Ethernet 1' contains spaces"}, + {hostname: "[fe80::1%Ethernet\t2]:23", err: "knownhosts: hostname '[fe80::1%Ethernet\t2]:23' contains spaces"}, } { remote, err := net.ResolveTCPAddr("tcp", m.remoteAddr) if err != nil { t.Fatalf("Unable to resolve tcp addr: %v", err) } var got bytes.Buffer - if err = WriteKnownHost(&got, m.hostname, remote, edKey); err != nil { + err = WriteKnownHost(&got, m.hostname, remote, edKey) + if m.err != "" { + if err == nil || err.Error() != m.err { + t.Errorf("WriteKnownHost(%q) expected error %v, found %v", m.hostname, m.err, err) + } + continue + } + if err != nil { t.Fatalf("Unable to write known host: %v", err) } if got.String() != m.want {