From 223c521f60acba5e98711f3278f09597cbc3634d Mon Sep 17 00:00:00 2001 From: Andrew LeFevre Date: Thu, 7 Jul 2022 17:01:07 -0400 Subject: [PATCH] wait until seccomp filters are created before starting to filter packets --- filter.go | 142 +++++++++++++++++++++++++++++++-------- filter_test.go | 111 +++++++++++++++++++++--------- filter_test_bin_setup.go | 2 + filter_test_helpers.go | 41 +++++++++++ filter_test_setup.go | 5 +- fuzz_test.go | 3 +- go.mod | 1 + go.sum | 5 +- main.go | 8 ++- 9 files changed, 251 insertions(+), 67 deletions(-) diff --git a/filter.go b/filter.go index ecb925c..2c7a6d6 100644 --- a/filter.go +++ b/filter.go @@ -35,6 +35,9 @@ const ( type FilterManager struct { ready chan struct{} + abort chan struct{} + + started bool logger *zap.Logger @@ -48,9 +51,15 @@ type FilterManager struct { } type filter struct { - dnsReqNFReady chan struct{} - genericNFReady chan struct{} - wg sync.WaitGroup + dnsReqReady chan struct{} + dnsReqAbort chan struct{} + genericReady chan struct{} + genericAbort chan struct{} + cachingReady chan struct{} + cachingAbort chan struct{} + + started bool + wg sync.WaitGroup opts *FilterOptions @@ -106,9 +115,12 @@ type resolver interface { type enforcerCreator func(ctx context.Context, logger *zap.Logger, queueNum uint16, ipv6 bool, hook nfqueue.HookFunc) (enforcer, error) -func StartFilters(ctx context.Context, logger *zap.Logger, config *Config) (*FilterManager, error) { +// CreateFilters creates packet filters. The returned FilterManager can +// be used to start or stop packet filtering. +func CreateFilters(ctx context.Context, logger *zap.Logger, config *Config) (*FilterManager, error) { f := FilterManager{ ready: make(chan struct{}), + abort: make(chan struct{}), logger: logger, queueNum4: config.InboundDNSQueue.IPv4, queueNum6: config.InboundDNSQueue.IPv6, @@ -118,14 +130,14 @@ func StartFilters(ctx context.Context, logger *zap.Logger, config *Config) (*Fil // if mock enforcers and resolver is not set, use real ones newEnforcer := config.enforcerCreator if newEnforcer == nil { - newEnforcer = startNfQueue + newEnforcer = openNfQueue } res := config.resolver if res == nil { res = &net.Resolver{} } - nf4, nf6, err := startNfQueues(ctx, logger, config.InboundDNSQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc { + nf4, nf6, err := openNfQueues(ctx, logger, config.InboundDNSQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc { return newDNSResponseCallback(&f, ipv6) }) if err != nil { @@ -136,7 +148,7 @@ func StartFilters(ctx context.Context, logger *zap.Logger, config *Config) (*Fil for i := range config.Filters { isSelfFilter := config.SelfDNSQueue == config.Filters[i].DNSQueue - filter, err := startFilter(ctx, logger, &config.Filters[i], isSelfFilter, newEnforcer, res) + filter, err := createFilter(ctx, logger, &config.Filters[i], isSelfFilter, newEnforcer, res) if err != nil { // TODO: stop other filters here return nil, err @@ -145,16 +157,32 @@ func StartFilters(ctx context.Context, logger *zap.Logger, config *Config) (*Fil f.filters[i] = filter } + return &f, nil +} + +// Start starts packet filtering. +func (f *FilterManager) Start() { // Let the DNS response callback know everything is setup. The // callback will be executing on another goroutine started by // nfqueue.RegisterWithErrorFunc, but only after a packet is // received on its nfqueue. close(f.ready) - return &f, nil + for i := range f.filters { + f.filters[i].start() + } + + f.started = true } +// Stop stops packet filtering and cleans up owned resources. func (f *FilterManager) Stop() { + // if the filters have not been started yet, tell running goroutines + // to abort and finish + if !f.started { + close(f.abort) + } + if f.dnsRespNF4 != nil { f.dnsRespNF4.Close() } @@ -167,27 +195,31 @@ func (f *FilterManager) Stop() { } } -func startFilter(ctx context.Context, logger *zap.Logger, opts *FilterOptions, isSelfFilter bool, newEnforcer enforcerCreator, res resolver) (*filter, error) { +func createFilter(ctx context.Context, logger *zap.Logger, opts *FilterOptions, isSelfFilter bool, newEnforcer enforcerCreator, res resolver) (*filter, error) { filterLogger := logger if opts.Name != "" { filterLogger = filterLogger.With(zap.String("filter.name", opts.Name)) } f := filter{ - dnsReqNFReady: make(chan struct{}), - genericNFReady: make(chan struct{}), - opts: opts, - logger: filterLogger, - res: res, - connections: NewTimedCache[connectionID](logger, true), - isSelfFilter: isSelfFilter, + dnsReqReady: make(chan struct{}), + dnsReqAbort: make(chan struct{}), + genericReady: make(chan struct{}), + genericAbort: make(chan struct{}), + cachingReady: make(chan struct{}), + cachingAbort: make(chan struct{}), + opts: opts, + logger: filterLogger, + res: res, + connections: NewTimedCache[connectionID](logger, true), + isSelfFilter: isSelfFilter, } if opts.TrafficQueue.eitherSet() { f.allowedIPs = NewTimedCache[netip.Addr](f.logger, false) f.additionalHostnames = NewTimedCache[string](filterLogger, false) - nf4, nf6, err := startNfQueues(ctx, filterLogger, opts.TrafficQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc { + nf4, nf6, err := openNfQueues(ctx, filterLogger, opts.TrafficQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc { return newGenericCallback(&f, ipv6) }) if err != nil { @@ -195,8 +227,6 @@ func startFilter(ctx context.Context, logger *zap.Logger, opts *FilterOptions, i } f.genericNF4 = nf4 f.genericNF6 = nf6 - // let the generic packet callback know everything is setup - close(f.genericNFReady) if len(f.opts.CachedHostnames) > 0 { f.wg.Add(1) @@ -209,7 +239,7 @@ func startFilter(ctx context.Context, logger *zap.Logger, opts *FilterOptions, i } if opts.DNSQueue.eitherSet() { - nf4, nf6, err := startNfQueues(ctx, filterLogger, opts.DNSQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc { + nf4, nf6, err := openNfQueues(ctx, filterLogger, opts.DNSQueue, newEnforcer, func(ipv6 bool) nfqueue.HookFunc { return newDNSRequestCallback(&f, ipv6) }) if err != nil { @@ -217,14 +247,13 @@ func startFilter(ctx context.Context, logger *zap.Logger, opts *FilterOptions, i } f.dnsReqNF4 = nf4 f.dnsReqNF6 = nf6 - // let the DNS request callback know everything is setup - close(f.dnsReqNFReady) + } return &f, nil } -func startNfQueues(ctx context.Context, logger *zap.Logger, queues queue, newEnforcer enforcerCreator, hookGen func(ipv6 bool) nfqueue.HookFunc) (nf4 enforcer, nf6 enforcer, err error) { +func openNfQueues(ctx context.Context, logger *zap.Logger, queues queue, newEnforcer enforcerCreator, hookGen func(ipv6 bool) nfqueue.HookFunc) (nf4 enforcer, nf6 enforcer, err error) { if queues.IPv4 != 0 { nf4, err = newEnforcer(ctx, logger, queues.IPv4, false, hookGen(false)) if err != nil { @@ -241,7 +270,7 @@ func startNfQueues(ctx context.Context, logger *zap.Logger, queues queue, newEnf return nf4, nf6, nil } -func startNfQueue(ctx context.Context, logger *zap.Logger, queueNum uint16, ipv6 bool, hook nfqueue.HookFunc) (enforcer, error) { +func openNfQueue(ctx context.Context, logger *zap.Logger, queueNum uint16, ipv6 bool, hook nfqueue.HookFunc) (enforcer, error) { afFamily := unix.AF_INET if ipv6 { afFamily = unix.AF_INET6 @@ -291,7 +320,30 @@ func startNfQueue(ctx context.Context, logger *zap.Logger, queueNum uint16, ipv6 return nf, nil } +func (f *filter) start() { + if f.opts.DNSQueue.eitherSet() { + close(f.dnsReqReady) + } + if f.opts.TrafficQueue.eitherSet() { + close(f.genericReady) + } + if len(f.opts.CachedHostnames) > 0 { + close(f.cachingReady) + } + + f.started = true +} + func (f *filter) cacheHostnames(ctx context.Context, logger *zap.Logger) { + // wait until the filter manager is setup to prevent race conditions + select { + case <-f.cachingReady: + case <-f.cachingAbort: + // the filter manager has been stopped before it was started, + // return so the parent filter can finish cleaning up + return + } + logger.Debug("starting cache loop") var ( @@ -344,6 +396,20 @@ func (f *filter) cacheHostnames(ctx context.Context, logger *zap.Logger) { } func (f *filter) close() { + // if the filter has not been started yet, tell running goroutines + // to abort and finish + if !f.started { + if f.opts.DNSQueue.eitherSet() { + close(f.dnsReqAbort) + } + if f.opts.TrafficQueue.eitherSet() { + close(f.genericAbort) + } + if len(f.opts.CachedHostnames) > 0 { + close(f.cachingAbort) + } + } + f.wg.Wait() if f.dnsReqNF4 != nil { @@ -381,8 +447,14 @@ func newDNSRequestCallback(f *filter, ipv6 bool) nfqueue.HookFunc { logger.Info("started nfqueue") return func(attr nfqueue.Attribute) int { - // wait until the filter is setup to prevent race conditions - <-f.dnsReqNFReady + // wait until the filter manager is setup to prevent race conditions + select { + case <-f.dnsReqReady: + // the filter manager has been stopped before it was started, + // return so the parent filter can finish cleaning up + case <-f.dnsReqAbort: + return 0 + } var dnsReqNF enforcer if !ipv6 { @@ -577,7 +649,13 @@ func newDNSResponseCallback(f *FilterManager, ipv6 bool) nfqueue.HookFunc { return func(attr nfqueue.Attribute) int { // wait until the filter manager is setup to prevent race conditions - <-f.ready + select { + case <-f.ready: + case <-f.abort: + // the filter manager has been stopped before it was started, + // return so the parent filter can finish cleaning up + return 0 + } var dnsRespNF enforcer if !ipv6 { @@ -732,8 +810,14 @@ func newGenericCallback(f *filter, ipv6 bool) nfqueue.HookFunc { logger.Info("started nfqueue") return func(attr nfqueue.Attribute) int { - // wait until the filter is setup to prevent race conditions - <-f.genericNFReady + // wait until the filter manager is setup to prevent race conditions + select { + case <-f.genericReady: + case <-f.genericAbort: + // the filter manager has been stopped before it was started, + // return so the parent filter can finish cleaning up + return 0 + } var genericNF enforcer if !ipv6 { diff --git a/filter_test.go b/filter_test.go index ceee0c1..543a57e 100644 --- a/filter_test.go +++ b/filter_test.go @@ -2,13 +2,14 @@ package main import ( "context" - "errors" "net" - "net/http" "testing" "time" + "github.com/florianl/go-nfqueue" "github.com/matryer/is" + "go.uber.org/goleak" + "go.uber.org/zap" ) func TestFiltering(t *testing.T) { @@ -213,42 +214,88 @@ cachedHostnames = [ is.True(reqFailed(err)) // lookup of disallowed domain should fail } -func makeHTTPReqs(client4, client6 *http.Client, addr string) error { - if client4 != nil { - resp, err := client4.Get(addr) - if err != nil { - return err - } - resp.Body.Close() +func TestFiltersStart(t *testing.T) { + if testingWithBinary { + t.Skip() } - if client6 != nil { - resp, err := client6.Get(addr) - if err != nil { - return err - } - resp.Body.Close() - } + configBytes := []byte(` +inboundDNSQueue.ipv6 = 10 +selfDNSQueue.ipv6 = 110 - return nil -} +[[filters]] +name = "test" +dnsQueue.ipv6 = 1010 +trafficQueue.ipv6 = 1011 +reCacheEvery = "1m" +cachedHostnames = [ + "example.com", +] +allowAnswersFor = "1s" +allowedHostnames = [ + "test.org" +]`) -func reqFailed(err error) bool { - var dnsErr *net.DNSError - if errors.As(err, &dnsErr) { - return true - } + is := is.New(t) - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - return true - } + config, err := parseConfigBytes(configBytes) + is.NoErr(err) - return false -} + config.enforcerCreator = newMockEnforcer + + t.Run("filters waiting", func(t *testing.T) { + initMockEnforcers() + + ctx, cancel := context.WithCancel(context.Background()) + f, err := CreateFilters(ctx, zap.NewNop(), config) + is.NoErr(err) + t.Cleanup(func() { + cancel() + f.Stop() + }) + + finishedAt := make(chan time.Time) + + go func() { + mockEnforcers[config.InboundDNSQueue.IPv6].hook(nfqueue.Attribute{}) + t.Log("finished DNS reply queue") + finishedAt <- time.Now() + }() + // the self-filter will be the first filter + testFilter := config.Filters[1] + go func() { + mockEnforcers[testFilter.DNSQueue.IPv6].hook(nfqueue.Attribute{}) + t.Log("finished DNS request queue") + finishedAt <- time.Now() + }() + go func() { + mockEnforcers[testFilter.TrafficQueue.IPv6].hook(nfqueue.Attribute{}) + t.Log("finished generic queue") + finishedAt <- time.Now() + }() + + time.Sleep(time.Second) + startedAt := time.Now() + f.Start() + t.Log("starting filters") + + for i := 0; i < 3; i++ { + t := <-finishedAt + is.True(t.After(startedAt)) // packet handling should have finished after filters were started + } + }) + + t.Run("stopping without starting", func(t *testing.T) { + // test that goroutines are cleanly shutdown + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) -func getTimeout(t *testing.T) context.Context { - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - t.Cleanup(cancel) + // use real nfqueues + config.enforcerCreator = nil + ctx, cancel := context.WithCancel(context.Background()) + f, err := CreateFilters(ctx, zap.NewNop(), config) + is.NoErr(err) - return ctx + cancel() + f.Stop() + }) } diff --git a/filter_test_bin_setup.go b/filter_test_bin_setup.go index e2ea034..305ff2b 100644 --- a/filter_test_bin_setup.go +++ b/filter_test_bin_setup.go @@ -11,6 +11,8 @@ import ( "time" ) +var testingWithBinary = true + func initFilters(t *testing.T, configStr string, iptablesRules, ip6tablesRules []string) (*http.Client, *http.Client, func()) { f, err := os.CreateTemp("", "egress_eddie") if err != nil { diff --git a/filter_test_helpers.go b/filter_test_helpers.go index 81e62ef..64a1989 100644 --- a/filter_test_helpers.go +++ b/filter_test_helpers.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "net" "net/http" "os/exec" @@ -11,6 +12,46 @@ import ( "github.com/anmitsu/go-shlex" ) +func makeHTTPReqs(client4, client6 *http.Client, addr string) error { + if client4 != nil { + resp, err := client4.Get(addr) + if err != nil { + return err + } + resp.Body.Close() + } + + if client6 != nil { + resp, err := client6.Get(addr) + if err != nil { + return err + } + resp.Body.Close() + } + + return nil +} + +func reqFailed(err error) bool { + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) { + return true + } + + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + return true + } + + return false +} + +func getTimeout(t *testing.T) context.Context { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + t.Cleanup(cancel) + + return ctx +} + func getHTTPClients() (*http.Client, *http.Client) { dialer := net.Dialer{ FallbackDelay: -1, diff --git a/filter_test_setup.go b/filter_test_setup.go index f49cf42..b86c076 100644 --- a/filter_test_setup.go +++ b/filter_test_setup.go @@ -11,6 +11,8 @@ import ( "go.uber.org/zap/zapcore" ) +var testingWithBinary = false + func initFilters(t *testing.T, configStr string, iptablesRules, ip6tablesRules []string) (*http.Client, *http.Client, func()) { config, err := parseConfigBytes([]byte(configStr)) if err != nil { @@ -40,10 +42,11 @@ func initFilters(t *testing.T, configStr string, iptablesRules, ip6tablesRules [ } ctx, cancel := context.WithCancel(context.Background()) - filters, err := StartFilters(ctx, logger, config) + filters, err := CreateFilters(ctx, logger, config) if err != nil { t.Fatalf("error starting filters: %v", err) } + filters.Start() client4, client6 := getHTTPClients() diff --git a/fuzz_test.go b/fuzz_test.go index 0418a42..224cd2e 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -58,10 +58,11 @@ func FuzzFiltering(f *testing.F) { // test that a config that passes validation won't cause a // error/panic when starting filters ctx, cancel := context.WithCancel(context.Background()) - f, err := StartFilters(ctx, logger, config) + f, err := CreateFilters(ctx, logger, config) if err != nil { failAndDumpConfig(t, cb, "error starting filters: %v", err) } + f.Start() allowIPv4Port := uint16(1000) allowIPv6Port := uint16(1010) diff --git a/go.mod b/go.mod index 4ae10ee..0523915 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( require ( github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/matryer/is v1.4.0 + go.uber.org/goleak v1.1.12 ) require ( diff --git a/go.sum b/go.sum index 6650504..6e763d2 100644 --- a/go.sum +++ b/go.sum @@ -41,8 +41,9 @@ github.com/stretchr/testify v1.7.1-0.20210427113832-6241f9ab9942/go.mod h1:6Fq8o github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= @@ -50,6 +51,7 @@ go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b h1:Wh+f8QHJXR411sJR8/vRBTZ7YapZaRvUcLFFJhusH0k= golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= @@ -89,6 +91,7 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/tools v0.1.9 h1:j9KsMiaP1c3B0OTQGth0/k+miLGTgLsAFUCrF2vLcF8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/main.go b/main.go index 81123d3..a949cb3 100644 --- a/main.go +++ b/main.go @@ -106,12 +106,10 @@ func main() { ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - filters, err := StartFilters(ctx, logger, config) + filters, err := CreateFilters(ctx, logger, config) if err != nil { logger.Fatal("error starting filters", zap.NamedError("error", err)) } - // TODO: block until seccomp filters are set - logger.Info("started filtering") defer func() { cancel() @@ -133,5 +131,9 @@ func main() { } logger.Info("applied seccomp filters", zap.Int("syscalls.allowed", numAllowedSyscalls)) + // Start filters now that seccomp filters have been applied + filters.Start() + logger.Info("started filtering") + <-ctx.Done() }