Skip to content

Commit

Permalink
Parse original resolv.conf (#1270)
Browse files Browse the repository at this point in the history
Handle original search domains in resolv.conf type implementations.

- parse the original resolv.conf file
- merge the search domains
- ignore the domain keyword
- append any other config lines (sortstlist, options)
- fix read origin resolv.conf from bkp in resolvconf implementation
- fix line length validation
- fix number of search domains validation
  • Loading branch information
pappz authored Nov 3, 2023
1 parent 2c01514 commit 9c4bf1e
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 94 deletions.
220 changes: 160 additions & 60 deletions client/internal/dns/file_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,25 @@
package dns

import (
"bufio"
"bytes"
"fmt"
"os"
"strings"

log "github.com/sirupsen/logrus"
)

const (
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfSearchBeginContent = "search "
fileGeneratedResolvConfContentFormat = fileGeneratedResolvConfContentHeader +
"\n# If needed you can restore the original file by copying back %s\n\nnameserver %s\n" +
fileGeneratedResolvConfSearchBeginContent + "%s\n\n" +
"%s\n"
)
fileGeneratedResolvConfContentHeader = "# Generated by NetBird"
fileGeneratedResolvConfContentHeaderNextLine = fileGeneratedResolvConfContentHeader + `
# If needed you can restore the original file by copying back ` + fileDefaultResolvConfBackupLocation + "\n\n"

const (
fileDefaultResolvConfBackupLocation = defaultResolvConfPath + ".original.netbird"
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
)

var fileSearchLineBeginCharCount = len(fileGeneratedResolvConfSearchBeginContent)
fileMaxLineCharsLimit = 256
fileMaxNumberOfSearchDomains = 6
)

type fileConfigurator struct {
originalPerms os.FileMode
Expand Down Expand Up @@ -55,58 +51,39 @@ func (f *fileConfigurator) applyDNSConfig(config hostDNSConfig) error {
}
return fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
}
managerType, err := getOSDNSManagerType()
if err != nil {
return err
}
switch managerType {
case fileManager, netbirdManager:
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
}
}
default:
// todo improve this and maybe restart DNS manager from scratch
return fmt.Errorf("something happened and file manager is not your preferred host dns configurator, restart the agent")
}

var searchDomains string
appendedDomains := 0
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}
if appendedDomains >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, dConf.domain)
continue
}
if fileSearchLineBeginCharCount+len(searchDomains) > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, dConf.domain)
continue
if !backupFileExist {
err = f.backup()
if err != nil {
return fmt.Errorf("unable to backup the resolv.conf file")
}

searchDomains += " " + dConf.domain
appendedDomains++
}

originalContent, err := os.ReadFile(fileDefaultResolvConfBackupLocation)
searchDomainList := searchDomains(config)

originalSearchDomains, nameServers, others, err := originalDNSConfigs(fileDefaultResolvConfBackupLocation)
if err != nil {
log.Errorf("Could not read existing resolv.conf")
log.Error(err)
}
content := fmt.Sprintf(fileGeneratedResolvConfContentFormat, fileDefaultResolvConfBackupLocation, config.serverIP, searchDomains, string(originalContent))
err = writeDNSConfig(content, defaultResolvConfPath, f.originalPerms)

searchDomainList = mergeSearchDomains(searchDomainList, originalSearchDomains)

buf := prepareResolvConfContent(
searchDomainList,
append([]string{config.serverIP}, nameServers...),
others)

log.Debugf("creating managed file %s", defaultResolvConfPath)
err = os.WriteFile(defaultResolvConfPath, buf.Bytes(), f.originalPerms)
if err != nil {
err = f.restore()
if err != nil {
restoreErr := f.restore()
if restoreErr != nil {
log.Errorf("attempt to restore default file failed with error: %s", err)
}
return err
return fmt.Errorf("got an creating resolver file %s. Error: %s", defaultResolvConfPath, err)
}
log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, appendedDomains, searchDomains)

log.Infof("created a NetBird managed %s file with your DNS settings. Added %d search domains. Search list: %s", defaultResolvConfPath, len(searchDomainList), searchDomainList)
return nil
}

Expand Down Expand Up @@ -138,15 +115,138 @@ func (f *fileConfigurator) restore() error {
return os.RemoveAll(fileDefaultResolvConfBackupLocation)
}

func writeDNSConfig(content, fileName string, permissions os.FileMode) error {
log.Debugf("creating managed file %s", fileName)
func prepareResolvConfContent(searchDomains, nameServers, others []string) bytes.Buffer {
var buf bytes.Buffer
buf.WriteString(content)
err := os.WriteFile(fileName, buf.Bytes(), permissions)
buf.WriteString(fileGeneratedResolvConfContentHeaderNextLine)

for _, cfgLine := range others {
buf.WriteString(cfgLine)
buf.WriteString("\n")
}

if len(searchDomains) > 0 {
buf.WriteString("search ")
buf.WriteString(strings.Join(searchDomains, " "))
buf.WriteString("\n")
}

for _, ns := range nameServers {
buf.WriteString("nameserver ")
buf.WriteString(ns)
buf.WriteString("\n")
}
return buf
}

func searchDomains(config hostDNSConfig) []string {
listOfDomains := make([]string, 0)
for _, dConf := range config.domains {
if dConf.matchOnly || dConf.disabled {
continue
}

listOfDomains = append(listOfDomains, dConf.domain)
}
return listOfDomains
}

func originalDNSConfigs(resolvconfFile string) (searchDomains, nameServers, others []string, err error) {
file, err := os.Open(resolvconfFile)
if err != nil {
return fmt.Errorf("got an creating resolver file %s. Error: %s", fileName, err)
err = fmt.Errorf(`could not read existing resolv.conf`)
return
}
return nil
defer file.Close()

reader := bufio.NewReader(file)

for {
lineBytes, isPrefix, readErr := reader.ReadLine()
if readErr != nil {
break
}

if isPrefix {
err = fmt.Errorf(`resolv.conf line too long`)
return
}

line := strings.TrimSpace(string(lineBytes))

if strings.HasPrefix(line, "#") {
continue
}

if strings.HasPrefix(line, "domain") {
continue
}

if strings.HasPrefix(line, "options") && strings.Contains(line, "rotate") {
line = strings.ReplaceAll(line, "rotate", "")
splitLines := strings.Fields(line)
if len(splitLines) == 1 {
continue
}
line = strings.Join(splitLines, " ")
}

if strings.HasPrefix(line, "search") {
splitLines := strings.Fields(line)
if len(splitLines) < 2 {
continue
}

searchDomains = splitLines[1:]
continue
}

if strings.HasPrefix(line, "nameserver") {
splitLines := strings.Fields(line)
if len(splitLines) != 2 {
continue
}
nameServers = append(nameServers, splitLines[1])
continue
}

others = append(others, line)
}
return
}

// merge search domains lists and cut off the list if it is too long
func mergeSearchDomains(searchDomains []string, originalSearchDomains []string) []string {
lineSize := len("search")
searchDomainsList := make([]string, 0, len(searchDomains)+len(originalSearchDomains))

lineSize = validateAndFillSearchDomains(lineSize, &searchDomainsList, searchDomains)
_ = validateAndFillSearchDomains(lineSize, &searchDomainsList, originalSearchDomains)

return searchDomainsList
}

// validateAndFillSearchDomains checks if the search domains list is not too long and if the line is not too long
// extend s slice with vs elements
// return with the number of characters in the searchDomains line
func validateAndFillSearchDomains(initialLineChars int, s *[]string, vs []string) int {
for _, sd := range vs {
tmpCharsNumber := initialLineChars + 1 + len(sd)
if tmpCharsNumber > fileMaxLineCharsLimit {
// lets log all skipped domains
log.Infof("search list line is larger than %d characters. Skipping append of %s domain", fileMaxLineCharsLimit, sd)
continue
}

initialLineChars = tmpCharsNumber

if len(*s) >= fileMaxNumberOfSearchDomains {
// lets log all skipped domains
log.Infof("already appended %d domains to search list. Skipping append of %s domain", fileMaxNumberOfSearchDomains, sd)
continue
}
*s = append(*s, sd)
}
return initialLineChars
}

func copyFile(src, dest string) error {
Expand Down
62 changes: 62 additions & 0 deletions client/internal/dns/file_linux_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package dns

import (
"fmt"
"testing"
)

func Test_mergeSearchDomains(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"a", "b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 4 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 4)
}
}

func Test_mergeSearchTooMuchDomains(t *testing.T) {
searchDomains := []string{"a", "b", "c", "d", "e", "f", "g"}
originDomains := []string{"h", "i"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}

func Test_mergeSearchTooMuchDomainsInOrigin(t *testing.T) {
searchDomains := []string{"a", "b"}
originDomains := []string{"c", "d", "e", "f", "g"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 6 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 6)
}
}

func Test_mergeSearchTooLongDomain(t *testing.T) {
searchDomains := []string{getLongLine()}
originDomains := []string{"b"}
mergedDomains := mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}

searchDomains = []string{"b"}
originDomains = []string{getLongLine()}

mergedDomains = mergeSearchDomains(searchDomains, originDomains)
if len(mergedDomains) != 1 {
t.Errorf("invalid len of result domains: %d, want: %d", len(mergedDomains), 1)
}
}

func getLongLine() string {
x := "search "
for {
for i := 0; i <= 9; i++ {
if len(x) > fileMaxLineCharsLimit {
return x
}
x = fmt.Sprintf("%s%d", x, i)
}
}
}
Loading

0 comments on commit 9c4bf1e

Please sign in to comment.