diff --git a/pkg/core/controller.go b/pkg/core/controller.go index e34e755..dbbcfac 100644 --- a/pkg/core/controller.go +++ b/pkg/core/controller.go @@ -24,7 +24,6 @@ import ( "syscall" "github.com/mimuret/dtap/v2/pkg/config" - "github.com/mimuret/dtap/v2/pkg/logger" "github.com/mimuret/dtap/v2/pkg/plugin" "github.com/mimuret/dtap/v2/pkg/types" "github.com/pkg/errors" @@ -147,15 +146,20 @@ func (c *Controller) Run(ctx context.Context) error { iCtx, iCancel := context.WithCancel(ctx) for i, inputPlugin := range c.inputPlugins { iwg.Add(1) - go func(i int, ip types.InputPlugin) { - logger.GetLogger().Info("start input plugin", zap.String("name", ip.GetName()), zap.Int("no", i)) - err := ip.Start(iCtx, c.inputBuffer) - logger.GetLogger().Info("finish input plugin", zap.String("name", ip.GetName()), zap.Int("no", i), zap.Error(err)) + ic := &types.InputContext{ + No: i, + Logger: c.logger.With(zap.String("name", inputPlugin.GetName()), zap.Int("no", i)), + Writer: c.inputBuffer, + } + go func(ip types.InputPlugin, ic *types.InputContext) { + ic.Logger.Info("start input plugin") + err := ip.Start(iCtx, ic) + ic.Logger.Info("finish input plugin") if err != nil { errCh <- err } iwg.Done() - }(i, inputPlugin) + }(inputPlugin, ic) } // start outputPlugin @@ -164,15 +168,21 @@ func (c *Controller) Run(ctx context.Context) error { for _, og := range c.outputGroups { for i, outputPlugin := range og.outputs { owg.Add(1) - go func(i int, og OutputGroup, op types.OutputPlugin) { - logger.GetLogger().Info("start output plugin", zap.String("og", og.name), zap.String("name", op.GetName()), zap.Int("no", i)) - err := op.Start(oCtx, og.buffer) - logger.GetLogger().Info("finish output plugin", zap.String("og", og.name), zap.String("name", op.GetName()), zap.Int("no", i), zap.Error(err)) + oc := &types.OutputContext{ + OutputGroup: og.name, + No: i, + Logger: c.logger.With(zap.String("og", og.name), zap.String("name", outputPlugin.GetName()), zap.Int("no", i)), + Reader: og.buffer, + } + go func(op types.OutputPlugin, oc *types.OutputContext) { + oc.Logger.Info("start output plugin") + err := op.Start(oCtx, oc) + oc.Logger.Info("finish output plugin") if err != nil { errCh <- err } owg.Done() - }(i, og, outputPlugin) + }(outputPlugin, oc) } } defer func() { diff --git a/pkg/core/plugins/plugins.go b/pkg/core/plugins/plugins.go index 433e9ed..89a304f 100644 --- a/pkg/core/plugins/plugins.go +++ b/pkg/core/plugins/plugins.go @@ -26,9 +26,9 @@ import ( // input _ "github.com/mimuret/dtap/v2/pkg/plugin/input/file" + _ "github.com/mimuret/dtap/v2/pkg/plugin/input/nats" _ "github.com/mimuret/dtap/v2/pkg/plugin/input/tcp" _ "github.com/mimuret/dtap/v2/pkg/plugin/input/unix" - _ "github.com/mimuret/dtap/v2/pkg/plugin/input/nats" // output _ "github.com/mimuret/dtap/v2/pkg/plugin/output/file" diff --git a/pkg/plugin/input/file/file.go b/pkg/plugin/input/file/file.go index 30bdfe7..ba4c493 100644 --- a/pkg/plugin/input/file/file.go +++ b/pkg/plugin/input/file/file.go @@ -47,7 +47,7 @@ func SetupFile(bs json.RawMessage) (types.InputPlugin, error) { } if input.NewInputServer(p.Format, &framestream.DecoderOptions{ Bidirectional: false, - }) == nil { + }, nil) == nil { return nil, errors.Errorf("invalid format") } p.fs = afero.NewOsFs() @@ -65,16 +65,16 @@ type File struct { Format input.Format } -func (p *File) Start(_ context.Context, w types.Writer) error { +func (p *File) Start(_ context.Context, ic *types.InputContext) error { is := input.NewInputServer(p.Format, &framestream.DecoderOptions{ ContentType: dnstap.FSContentType, Bidirectional: false, - }) + }, ic) r, err := p.fs.Open(p.Path) if err != nil { return fmt.Errorf("failed to open file: %w", err) } - if err := is.Read(r, w); err != nil { + if err := is.Read(r, ic.Writer); err != nil { return fmt.Errorf("failed to push message: %w", err) } return nil diff --git a/pkg/plugin/input/file/file_test.go b/pkg/plugin/input/file/file_test.go index e598f47..56cc7df 100644 --- a/pkg/plugin/input/file/file_test.go +++ b/pkg/plugin/input/file/file_test.go @@ -23,6 +23,7 @@ import ( "github.com/mimuret/dtap/v2/pkg/buffer" "github.com/mimuret/dtap/v2/pkg/plugin/input/file" + "github.com/mimuret/dtap/v2/pkg/testtool" "github.com/mimuret/dtap/v2/pkg/types" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -80,6 +81,7 @@ var _ = Describe("input/file", func() { Context("Start", func() { var ( p types.InputPlugin + ic *types.InputContext fp *file.File err error fs afero.Fs @@ -88,6 +90,7 @@ var _ = Describe("input/file", func() { ) BeforeEach(func() { buf = buffer.NewRingBuffer(10, nil, nil) + ic = testtool.NewTestInputContext(buf) fs = afero.NewMemMapFs() p, err = file.SetupFile(validConfig) Expect(err).To(Succeed()) @@ -96,7 +99,7 @@ var _ = Describe("input/file", func() { }) When("file not exist", func() { BeforeEach(func() { - err = fp.Start(context.TODO(), buf) + err = fp.Start(context.TODO(), ic) }) It("returns error", func() { Expect(err).To(HaveOccurred()) @@ -110,7 +113,7 @@ var _ = Describe("input/file", func() { _, err = f.Write(dummyData) Expect(err).To(Succeed()) f.Close() - err = fp.Start(context.TODO(), buf) + err = fp.Start(context.TODO(), ic) }) It("returns error", func() { Expect(err).To(HaveOccurred()) @@ -124,7 +127,7 @@ var _ = Describe("input/file", func() { _, err = f.Write(validData) Expect(err).To(Succeed()) f.Close() - err = fp.Start(context.TODO(), buf) + err = fp.Start(context.TODO(), ic) }) It("returns error", func() { Expect(err).To(Succeed()) diff --git a/pkg/plugin/input/format.go b/pkg/plugin/input/format.go index 139950c..ad49710 100644 --- a/pkg/plugin/input/format.go +++ b/pkg/plugin/input/format.go @@ -4,6 +4,7 @@ import ( "strings" framestream "github.com/farsightsec/golang-framestream" + "github.com/mimuret/dtap/v2/pkg/types" ) const DefaultFormat = "DNSTAP" @@ -14,18 +15,18 @@ var ( registry = map[Format]NewFormatFunc{} ) -type NewFormatFunc func(options *framestream.DecoderOptions) *InputServer +type NewFormatFunc func(options *framestream.DecoderOptions, ic *types.InputContext) *InputServer func RegisterFormat(f Format, newFunc NewFormatFunc) { f = Format(strings.ToUpper(string(f))) registry[f] = newFunc } -func NewInputServer(f Format, options *framestream.DecoderOptions) *InputServer { +func NewInputServer(f Format, options *framestream.DecoderOptions, ic *types.InputContext) *InputServer { f = Format(strings.ToUpper(string(f))) newFunc := registry[f] if newFunc == nil { return nil } - return newFunc(options) + return newFunc(options, ic) } diff --git a/pkg/plugin/input/input.go b/pkg/plugin/input/input.go index 3155175..b1d41de 100644 --- a/pkg/plugin/input/input.go +++ b/pkg/plugin/input/input.go @@ -7,7 +7,6 @@ import ( dnstap "github.com/dnstap/golang-dnstap" framestream "github.com/farsightsec/golang-framestream" - "github.com/mimuret/dtap/v2/pkg/logger" "github.com/mimuret/dtap/v2/pkg/plugin/pub" "github.com/mimuret/dtap/v2/pkg/types" "github.com/pkg/errors" @@ -56,12 +55,12 @@ type FstrmUnmarshaler func([]byte) (*types.DnstapMessage, error) type InputServer struct { DecoderOptions *framestream.DecoderOptions - logger *zap.Logger connectionManager *connectionManager unmarshaler FstrmUnmarshaler + ic *types.InputContext } -func NewDnstapInputServer(options *framestream.DecoderOptions) *InputServer { +func NewDnstapInputServer(options *framestream.DecoderOptions, ic *types.InputContext) *InputServer { if options == nil { options = &framestream.DecoderOptions{ Bidirectional: true, @@ -72,13 +71,13 @@ func NewDnstapInputServer(options *framestream.DecoderOptions) *InputServer { } return &InputServer{ DecoderOptions: options, - logger: logger.GetLogger(), + ic: ic, connectionManager: newConnectionManager(), unmarshaler: types.NewDnstapMessage, } } -func NewDtapFrameInputServer(options *framestream.DecoderOptions) *InputServer { +func NewDtapFrameInputServer(options *framestream.DecoderOptions, ic *types.InputContext) *InputServer { if options == nil { options = &framestream.DecoderOptions{ Bidirectional: true, @@ -89,7 +88,7 @@ func NewDtapFrameInputServer(options *framestream.DecoderOptions) *InputServer { } return &InputServer{ DecoderOptions: options, - logger: logger.GetLogger(), + ic: ic, connectionManager: newConnectionManager(), unmarshaler: types.NewDnstapMessageFromDtapFrameRaw, } @@ -114,7 +113,7 @@ func (i *InputServer) Serve(ln net.Listener, buf types.Writer) error { go func(conn net.Conn) { if err := i.Read(conn, buf); err != nil { TotalDecordError.Inc() - i.logger.Warn("input error", zap.Error(err)) + i.ic.Logger.Debug("input error", zap.Error(err)) } i.connectionManager.remove(conn) wg.Done() diff --git a/pkg/plugin/input/input_test.go b/pkg/plugin/input/input_test.go index 436184a..c7b2749 100644 --- a/pkg/plugin/input/input_test.go +++ b/pkg/plugin/input/input_test.go @@ -47,7 +47,7 @@ var _ = Describe("InputServer", func() { buf types.Writer ) BeforeEach(func() { - srv = input.NewInputServer(input.FormatDNSTAP, nil) + srv = input.NewInputServer(input.FormatDNSTAP, nil, testtool.NewTestInputContext(nil)) srvErr = nil buf = buffer.NewRingBuffer(100, &counter{}, &counter{}) ln, srvErr = nettest.NewLocalListener("unix") @@ -80,7 +80,7 @@ var _ = Describe("InputServer", func() { buf types.Buffer ) BeforeEach(func() { - srv = input.NewInputServer(input.FormatDNSTAP, nil) + srv = input.NewInputServer(input.FormatDNSTAP, nil, testtool.NewTestInputContext(nil)) buf = buffer.NewRingBuffer(100, &counter{}, &counter{}) connOut, connIn = net.Pipe() srvErr = nil diff --git a/pkg/plugin/input/nats/nats.go b/pkg/plugin/input/nats/nats.go index 4e80eeb..4bee471 100644 --- a/pkg/plugin/input/nats/nats.go +++ b/pkg/plugin/input/nats/nats.go @@ -26,7 +26,6 @@ import ( "github.com/pkg/errors" "go.uber.org/zap" - "github.com/mimuret/dtap/v2/pkg/logger" "github.com/mimuret/dtap/v2/pkg/plugin" "github.com/mimuret/dtap/v2/pkg/types" @@ -59,9 +58,13 @@ func Setup(bs json.RawMessage) (types.InputPlugin, error) { } if input.NewInputServer(s.Format, &framestream.DecoderOptions{ Bidirectional: false, - }) == nil { + }, nil) == nil { return nil, errors.Errorf("invalid format") } + // for test + s.ic = &types.InputContext{ + Logger: zap.NewExample(), + } return s, nil } @@ -73,6 +76,8 @@ type Nats struct { *output.DnstapOutput + ic *types.InputContext + // config Hosts []string Subject string @@ -86,14 +91,15 @@ type Nats struct { Format input.Format } -func (f *Nats) Start(ctx context.Context, w types.Writer) error { +func (f *Nats) Start(ctx context.Context, ic *types.InputContext) error { + f.ic = ic LOOP: for { select { case <-ctx.Done(): break LOOP default: - if err := f.Subscribe(ctx, w); err != nil { + if err := f.Subscribe(ctx, ic.Writer); err != nil { return err } } @@ -121,7 +127,7 @@ func (f *Nats) Open() (*nats.Conn, error) { func (f *Nats) Subscribe(ctx context.Context, w types.Writer) error { is := input.NewInputServer(f.Format, &framestream.DecoderOptions{ Bidirectional: false, - }) + }, f.ic) nc, err := f.Open() if err != nil { return errors.Wrapf(err, "failed to connect nats server") @@ -138,7 +144,7 @@ func (f *Nats) Subscribe(ctx context.Context, w types.Writer) error { _ = sub.Unsubscribe() }() - logger.GetLogger().Info("start subscribe", zap.String("subject", f.Subject), zap.String("queue name", f.QueueName), zap.Int("queue len", f.QueueLen)) + f.ic.Logger.Info("start subscribe", zap.String("subject", f.Subject), zap.String("queue name", f.QueueName), zap.Int("queue len", f.QueueLen)) LOOP: for { select { diff --git a/pkg/plugin/input/tcp/tcp.go b/pkg/plugin/input/tcp/tcp.go index cf02adc..155e6ef 100644 --- a/pkg/plugin/input/tcp/tcp.go +++ b/pkg/plugin/input/tcp/tcp.go @@ -44,7 +44,7 @@ func SetupTCPSocket(bs json.RawMessage) (types.InputPlugin, error) { if p.Port == 0 { return nil, errors.Errorf("missing parameter Port") } - if input.NewInputServer(p.Format, nil) == nil { + if input.NewInputServer(p.Format, nil, nil) == nil { return nil, errors.Errorf("invalid format") } return p, nil @@ -76,7 +76,7 @@ func (p *TCPSocket) Close() error { return p.ln.Close() } -func (p *TCPSocket) Start(ctx context.Context, w types.Writer) error { +func (p *TCPSocket) Start(ctx context.Context, ic *types.InputContext) error { if err := p.Listen(); err != nil { return err } @@ -84,5 +84,5 @@ func (p *TCPSocket) Start(ctx context.Context, w types.Writer) error { <-ctx.Done() p.Close() }() - return input.NewInputServer(p.Format, nil).Serve(p.ln, w) + return input.NewInputServer(p.Format, nil, ic).Serve(p.ln, ic.Writer) } diff --git a/pkg/plugin/input/tcp/tcp_test.go b/pkg/plugin/input/tcp/tcp_test.go index 2963a63..2c3828f 100644 --- a/pkg/plugin/input/tcp/tcp_test.go +++ b/pkg/plugin/input/tcp/tcp_test.go @@ -16,7 +16,10 @@ package tcp_test import ( + "net" + "github.com/goccy/go-json" + "golang.org/x/net/nettest" "github.com/mimuret/dtap/v2/pkg/plugin/input/tcp" "github.com/mimuret/dtap/v2/pkg/types" @@ -68,9 +71,16 @@ var _ = Describe("input/tcp", func() { p *tcp.TCPSocket ) BeforeEach(func() { + ln, lerr := nettest.NewLocalListener("tcp") + Expect(lerr).To(Succeed()) + addr, ok := ln.Addr().(*net.TCPAddr) + Expect(ok).To(BeTrue()) + ln.Close() + ip, err = tcp.SetupTCPSocket(json.RawMessage(`{"Name": "tcp", "Port": 10053}`)) Expect(err).To(Succeed()) p = ip.(*tcp.TCPSocket) + p.Port = uint16(addr.Port) }) When("failed to listen", func() { BeforeEach(func() { diff --git a/pkg/plugin/input/unix/unix.go b/pkg/plugin/input/unix/unix.go index 1e9e9bf..64d7245 100644 --- a/pkg/plugin/input/unix/unix.go +++ b/pkg/plugin/input/unix/unix.go @@ -63,7 +63,7 @@ func SetupUnixSocket(bs json.RawMessage) (types.InputPlugin, error) { p.uid = &uid p.gid = &gid } - if input.NewInputServer(p.Format, nil) == nil { + if input.NewInputServer(p.Format, nil, nil) == nil { return nil, errors.Errorf("invalid format") } return p, nil @@ -102,7 +102,7 @@ func (p *UnixSocket) Close() error { return p.ln.Close() } -func (p *UnixSocket) Start(ctx context.Context, w types.Writer) error { +func (p *UnixSocket) Start(ctx context.Context, ic *types.InputContext) error { if err := p.Listen(); err != nil { return err } @@ -110,5 +110,5 @@ func (p *UnixSocket) Start(ctx context.Context, w types.Writer) error { <-ctx.Done() p.Close() }() - return input.NewInputServer(p.Format, nil).Serve(p.ln, w) + return input.NewInputServer(p.Format, nil, ic).Serve(p.ln, ic.Writer) } diff --git a/pkg/plugin/output/dnstap_fstrm_socket_output.go b/pkg/plugin/output/dnstap_fstrm_socket_output.go index eaf068f..d98ed22 100644 --- a/pkg/plugin/output/dnstap_fstrm_socket_output.go +++ b/pkg/plugin/output/dnstap_fstrm_socket_output.go @@ -28,6 +28,7 @@ import ( ) type SocketOutput interface { + SetOutputContext(*types.OutputContext) NewConnect() (io.Writer, error) Close() } @@ -37,6 +38,7 @@ var _ OutputHandler = &DnstapFstrmSocketOutput{} type DnstapFstrmSocketOutput struct { handler SocketOutput flushTimeout time.Duration + oc *types.OutputContext enc *framestream.Encoder encOpt *framestream.EncoderOptions @@ -60,6 +62,11 @@ func NewDnstapFstrmSocketOutput(handler SocketOutput, flushTimeout time.Duration } } +func (o *DnstapFstrmSocketOutput) SetOutputContext(oc *types.OutputContext) { + o.oc = oc + o.handler.SetOutputContext(oc) +} + func (o *DnstapFstrmSocketOutput) Open() error { w, err := o.handler.NewConnect() if err != nil { diff --git a/pkg/plugin/output/dnstap_fstrm_socket_output_test.go b/pkg/plugin/output/dnstap_fstrm_socket_output_test.go index 7df809e..28a132a 100644 --- a/pkg/plugin/output/dnstap_fstrm_socket_output_test.go +++ b/pkg/plugin/output/dnstap_fstrm_socket_output_test.go @@ -36,6 +36,7 @@ type socketOutput struct { Conn net.Conn } +func (s *socketOutput) SetOutputContext(*types.OutputContext) {} func (s *socketOutput) NewConnect() (io.Writer, error) { s.RunNewConnect++ if s.ErrNewConnect != nil { diff --git a/pkg/plugin/output/dnstap_output.go b/pkg/plugin/output/dnstap_output.go index 93cc4c2..47e6bcc 100644 --- a/pkg/plugin/output/dnstap_output.go +++ b/pkg/plugin/output/dnstap_output.go @@ -19,12 +19,15 @@ import ( "context" "time" - "github.com/mimuret/dtap/v2/pkg/logger" "github.com/mimuret/dtap/v2/pkg/types" + "github.com/pkg/errors" "go.uber.org/zap" ) +const MaxRetryDuration = time.Minute * 1 + type OutputHandler interface { + SetOutputContext(oc *types.OutputContext) Open() error Write(*types.DnstapMessage) error Close() @@ -32,44 +35,50 @@ type OutputHandler interface { type DnstapOutput struct { handler OutputHandler - logger *zap.Logger retryOpenCount uint + maxRetry uint + oc *types.OutputContext } -func NewDnstapOutput(handler OutputHandler) *DnstapOutput { +func NewDnstapOutput(handler OutputHandler, maxRetry uint) *DnstapOutput { if handler == nil { panic("handler is nil") } return &DnstapOutput{ handler: handler, - logger: logger.GetLogger(), retryOpenCount: 0, + maxRetry: 0, } } -func (o *DnstapOutput) Start(ctx context.Context, r types.Reader) error { - o.logger.Debug("start output run") +func (o *DnstapOutput) Start(ctx context.Context, oc *types.OutputContext) error { + o.oc = oc + o.handler.SetOutputContext(oc) + o.oc.Logger.Debug("start output run") L: for { select { case <-ctx.Done(): - o.logger.Debug("Run ctx done") + o.oc.Logger.Debug("Run ctx done") break L default: - if err := o.Run(ctx, r); err != nil { - o.logger.Debug("output running error", zap.Error(err)) + if err := o.Run(ctx, oc.Reader); err != nil { + if o.maxRetry != 0 && o.maxRetry <= o.retryOpenCount { + return errors.Wrap(err, "failed to open output resource") + } + o.oc.Logger.Debug("output running error", zap.Error(err)) } } } - o.logger.Debug("end output run") + o.oc.Logger.Debug("end output run") return nil } func (o *DnstapOutput) Run(ctx context.Context, r types.Reader) error { if err := o.handler.Open(); err != nil { retryDuration := time.Second * time.Duration(1+o.retryOpenCount*o.retryOpenCount) - if retryDuration > time.Minute*3 { - retryDuration = time.Minute * 3 + if retryDuration > MaxRetryDuration { + retryDuration = MaxRetryDuration } time.Sleep(retryDuration) o.retryOpenCount++ @@ -78,7 +87,7 @@ func (o *DnstapOutput) Run(ctx context.Context, r types.Reader) error { o.retryOpenCount = 0 defer o.handler.Close() - o.logger.Debug("start writer") + o.oc.Logger.Debug("start writer") L: for { select { @@ -87,12 +96,12 @@ L: case frame := <-r.Read(): if frame != nil { if err := o.handler.Write(frame); err != nil { - o.logger.Debug("writer error", zap.Error(err)) + o.oc.Logger.Debug("writer error", zap.Error(err)) return err } } } } - o.logger.Debug("end writer") + o.oc.Logger.Debug("end writer") return nil } diff --git a/pkg/plugin/output/dnstap_output_test.go b/pkg/plugin/output/dnstap_output_test.go index e16556e..43717d7 100644 --- a/pkg/plugin/output/dnstap_output_test.go +++ b/pkg/plugin/output/dnstap_output_test.go @@ -44,6 +44,8 @@ type handler struct { Msg *types.DnstapMessage } +func (h *handler) SetOutputContext(*types.OutputContext) { +} func (h *handler) Open() error { h.Opened = true return h.ErrOpen @@ -72,7 +74,7 @@ var _ = Describe("DnstapOutput", func() { ctx, cancelFunc = context.WithCancel(context.Background()) wg = &sync.WaitGroup{} h = &handler{} - out = output.NewDnstapOutput(h) + out = output.NewDnstapOutput(h, 0) msg = testtool.CreateValidDnstapMessage() }) Context("incoming message", func() { @@ -80,7 +82,7 @@ var _ = Describe("DnstapOutput", func() { buf.Write(msg) wg.Add(1) go func() { - err := out.Start(ctx, buf) + err := out.Start(ctx, testtool.NewTestOutputContext(buf)) Expect(err).To(Succeed()) wg.Done() }() @@ -100,7 +102,7 @@ var _ = Describe("DnstapOutput", func() { BeforeEach(func() { wg.Add(1) go func() { - err := out.Start(ctx, buf) + err := out.Start(ctx, testtool.NewTestOutputContext(buf)) Expect(err).To(Succeed()) wg.Done() }() @@ -117,7 +119,7 @@ var _ = Describe("DnstapOutput", func() { h.ErrOpen = errors.New("dummy") wg.Add(1) go func() { - err := out.Start(ctx, buf) + err := out.Start(ctx, testtool.NewTestOutputContext(buf)) Expect(err).To(Succeed()) wg.Done() }() diff --git a/pkg/plugin/output/file/file.go b/pkg/plugin/output/file/file.go index c0e36bc..a6d7b16 100644 --- a/pkg/plugin/output/file/file.go +++ b/pkg/plugin/output/file/file.go @@ -52,7 +52,7 @@ func setup(bs json.RawMessage) (types.OutputPlugin, error) { default: return nil, errors.New("Type is an invalid value") } - s.DnstapOutput = output.NewDnstapOutput(s) + s.DnstapOutput = output.NewDnstapOutput(s, 0) return s, nil } @@ -74,7 +74,12 @@ type Output struct { Format OutputFormat Template string - t *template.Template + t *template.Template + oc *types.OutputContext +} + +func (f *Output) SetOutputContext(oc *types.OutputContext) { + f.oc = oc } func (o *Output) Open() error { diff --git a/pkg/plugin/output/fluent/fluent.go b/pkg/plugin/output/fluent/fluent.go index cac2c27..ea534e0 100644 --- a/pkg/plugin/output/fluent/fluent.go +++ b/pkg/plugin/output/fluent/fluent.go @@ -36,7 +36,7 @@ func Setup(bs json.RawMessage) (types.OutputPlugin, error) { if err := json.Unmarshal(bs, s); err != nil { return nil, errors.Wrap(err, "failed to decode config") } - s.DnstapOutput = output.NewDnstapOutput(s) + s.DnstapOutput = output.NewDnstapOutput(s, s.MaxRetry) return s, nil } @@ -52,6 +52,11 @@ type Fluent struct { // fluent client *fluent.Fluent + oc *types.OutputContext +} + +func (f *Fluent) SetOutputContext(oc *types.OutputContext) { + f.oc = oc } func (o *Fluent) Open() error { diff --git a/pkg/plugin/output/kafka/kafka.go b/pkg/plugin/output/kafka/kafka.go index bd3fc62..058c2a3 100644 --- a/pkg/plugin/output/kafka/kafka.go +++ b/pkg/plugin/output/kafka/kafka.go @@ -66,7 +66,7 @@ func Setup(bs json.RawMessage) (types.OutputPlugin, error) { if err != nil { return nil, err } - s.DnstapOutput = output.NewDnstapOutput(s) + s.DnstapOutput = output.NewDnstapOutput(s, s.MaxRetry) return s, nil } @@ -87,6 +87,11 @@ type Kafka struct { valueSchemaID []byte keyCodec *goavro.Codec keySchemaID []byte + oc *types.OutputContext +} + +func (f *Kafka) SetOutputContext(oc *types.OutputContext) { + f.oc = oc } type OutputType string diff --git a/pkg/plugin/output/metrics/metrics.go b/pkg/plugin/output/metrics/metrics.go index 536e1ea..abac251 100644 --- a/pkg/plugin/output/metrics/metrics.go +++ b/pkg/plugin/output/metrics/metrics.go @@ -137,10 +137,15 @@ func (c *MetricsRule) GetLabels(dm *types.DnstapMessage) []string { type Metrics struct { plugin.PluginCommon Rules []*MetricsRule + oc *types.OutputContext } -func (f *Metrics) Start(ctx context.Context, r types.Reader) error { - return output.NewDnstapOutput(f).Start(ctx, r) +func (f *Metrics) SetOutputContext(oc *types.OutputContext) { + f.oc = oc +} + +func (f *Metrics) Start(ctx context.Context, oc *types.OutputContext) error { + return output.NewDnstapOutput(f, 0).Start(ctx, oc) } func (f *Metrics) Open() error { diff --git a/pkg/plugin/output/metrics/metrics_test.go b/pkg/plugin/output/metrics/metrics_test.go index 6751b97..12711bd 100644 --- a/pkg/plugin/output/metrics/metrics_test.go +++ b/pkg/plugin/output/metrics/metrics_test.go @@ -24,15 +24,15 @@ import ( "github.com/goccy/go-json" - dto "github.com/prometheus/client_model/go" - dnstap "github.com/dnstap/golang-dnstap" "github.com/miekg/dns" "github.com/mimuret/dnsutils/getter" "github.com/mimuret/dnsutils/testtool" "github.com/mimuret/dtap/v2/pkg/buffer" _ "github.com/mimuret/dtap/v2/pkg/plugin/filter/static" + dtaptesttool "github.com/mimuret/dtap/v2/pkg/testtool" "github.com/prometheus/client_golang/prometheus" + dto "github.com/prometheus/client_model/go" "google.golang.org/protobuf/proto" "github.com/mimuret/dtap/v2/pkg/plugin/output/metrics" @@ -177,7 +177,7 @@ var _ = Describe("output/metrics", func() { o = op.(*metrics.Metrics) buf.Write(dm) go func() { - err := op.Start(ctx, buf) + err := op.Start(ctx, dtaptesttool.NewTestOutputContext(buf)) Expect(err).To(Succeed()) }() }) diff --git a/pkg/plugin/output/nats/nats.go b/pkg/plugin/output/nats/nats.go index 912877c..5f1d65f 100644 --- a/pkg/plugin/output/nats/nats.go +++ b/pkg/plugin/output/nats/nats.go @@ -64,7 +64,7 @@ func Setup(bs json.RawMessage) (types.OutputPlugin, error) { if s.publisher == nil { return nil, errors.Errorf("failed to create publisher for format %s", s.Format) } - s.DnstapOutput = output.NewDnstapOutput(s) + s.DnstapOutput = output.NewDnstapOutput(s, s.MaxRetry) return s, nil } @@ -92,6 +92,12 @@ type Nats struct { IntervalSec uint Format pub.Format publisher pub.Publisher + + oc *types.OutputContext +} + +func (f *Nats) SetOutputContext(oc *types.OutputContext) { + f.oc = oc } func (f *Nats) Open() error { diff --git a/pkg/plugin/output/nop/nop.go b/pkg/plugin/output/nop/nop.go index f872a78..df06e86 100644 --- a/pkg/plugin/output/nop/nop.go +++ b/pkg/plugin/output/nop/nop.go @@ -39,13 +39,13 @@ type NOP struct { plugin.PluginCommon } -func (f *NOP) Start(ctx context.Context, r types.Reader) error { +func (f *NOP) Start(ctx context.Context, oc *types.OutputContext) error { LOOP: for { select { case <-ctx.Done(): break LOOP - case <-r.Read(): + case <-oc.Reader.Read(): } } return nil diff --git a/pkg/plugin/output/nop/nop_test.go b/pkg/plugin/output/nop/nop_test.go index 361d5e9..c5b7e99 100644 --- a/pkg/plugin/output/nop/nop_test.go +++ b/pkg/plugin/output/nop/nop_test.go @@ -22,6 +22,7 @@ import ( "github.com/mimuret/dtap/v2/pkg/buffer" "github.com/mimuret/dtap/v2/pkg/plugin/output/nop" + "github.com/mimuret/dtap/v2/pkg/testtool" "github.com/mimuret/dtap/v2/pkg/types" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -58,8 +59,9 @@ var _ = Describe("output/nop", func() { op, err = nop.Setup(json.RawMessage(`{"Name": "nop"}`)) Expect(err).To(Succeed()) ctx, cancelFunc = context.WithCancel(context.Background()) + go func() { - err := op.Start(ctx, buf) + err := op.Start(ctx, testtool.NewTestOutputContext(buf)) Expect(err).To(Succeed()) }() }) diff --git a/pkg/plugin/output/stdout/stdout.go b/pkg/plugin/output/stdout/stdout.go index 006a4b6..f1e7134 100644 --- a/pkg/plugin/output/stdout/stdout.go +++ b/pkg/plugin/output/stdout/stdout.go @@ -52,7 +52,7 @@ func setup(bs json.RawMessage) (types.OutputPlugin, error) { default: return nil, errors.New("Type is an invalid value") } - s.DnstapOutput = output.NewDnstapOutput(s) + s.DnstapOutput = output.NewDnstapOutput(s, s.MaxRetry) return s, nil } @@ -73,7 +73,12 @@ type Stdout struct { Type OutputFormat Template string - t *template.Template + t *template.Template + oc *types.OutputContext +} + +func (f *Stdout) SetOutputContext(oc *types.OutputContext) { + f.oc = oc } func (o *Stdout) Open() error { diff --git a/pkg/plugin/output/tcp/tcp.go b/pkg/plugin/output/tcp/tcp.go index 48360af..da4c7e9 100644 --- a/pkg/plugin/output/tcp/tcp.go +++ b/pkg/plugin/output/tcp/tcp.go @@ -45,7 +45,7 @@ func Setup(bs json.RawMessage) (types.OutputPlugin, error) { if s.Port == 0 { return nil, errors.Errorf("missing parameter Port") } - s.DnstapOutput = output.NewDnstapOutput(output.NewDnstapFstrmSocketOutput(s, time.Second, nil)) + s.DnstapOutput = output.NewDnstapOutput(output.NewDnstapFstrmSocketOutput(s, time.Second, nil), s.MaxRetry) return s, nil } @@ -59,6 +59,12 @@ type TCP struct { Port uint16 w net.Conn + + oc *types.OutputContext +} + +func (f *TCP) SetOutputContext(oc *types.OutputContext) { + f.oc = oc } func (t *TCP) NewConnect() (io.Writer, error) { diff --git a/pkg/plugin/output/tcp/tcp_test.go b/pkg/plugin/output/tcp/tcp_test.go index 5368688..5ad6673 100644 --- a/pkg/plugin/output/tcp/tcp_test.go +++ b/pkg/plugin/output/tcp/tcp_test.go @@ -20,6 +20,7 @@ import ( "net" "github.com/goccy/go-json" + "golang.org/x/net/nettest" "github.com/mimuret/dtap/v2/pkg/plugin/output/tcp" "github.com/mimuret/dtap/v2/pkg/types" @@ -79,11 +80,14 @@ var _ = Describe("output/tcp", func() { conn io.Writer ) BeforeEach(func() { - ln, err = net.Listen("tcp", "127.0.0.1:10053") + ln, err = nettest.NewLocalListener("tcp") Expect(err).To(Succeed()) + addr, ok := ln.Addr().(*net.TCPAddr) + Expect(ok).To(BeTrue()) op, err = tcp.Setup(json.RawMessage(`{"Name": "tcp", "Host": "127.0.0.1","Port": 10053}`)) Expect(err).To(Succeed()) p = op.(*tcp.TCP) + p.Port = uint16(addr.Port) }) AfterEach(func() { ln.Close() diff --git a/pkg/plugin/output/unix/unix.go b/pkg/plugin/output/unix/unix.go index 79b420f..d910de4 100644 --- a/pkg/plugin/output/unix/unix.go +++ b/pkg/plugin/output/unix/unix.go @@ -41,7 +41,7 @@ func Setup(bs json.RawMessage) (types.OutputPlugin, error) { if s.Path == "" { return nil, errors.New("missing parameter Path") } - s.DnstapOutput = output.NewDnstapOutput(output.NewDnstapFstrmSocketOutput(s, time.Second, nil)) + s.DnstapOutput = output.NewDnstapOutput(output.NewDnstapFstrmSocketOutput(s, time.Second, nil), s.MaxRetry) return s, nil } @@ -54,7 +54,12 @@ type Unix struct { Path string - w net.Conn + oc *types.OutputContext + w net.Conn +} + +func (f *Unix) SetOutputContext(oc *types.OutputContext) { + f.oc = oc } func (f *Unix) NewConnect() (io.Writer, error) { diff --git a/pkg/plugin/plugin.go b/pkg/plugin/plugin.go index 07ffc7e..8a0b11a 100644 --- a/pkg/plugin/plugin.go +++ b/pkg/plugin/plugin.go @@ -24,7 +24,8 @@ import ( ) type PluginCommon struct { - Name string `json:"Name"` + Name string `json:"Name"` + MaxRetry uint `json:"MaxRetry"` } func (p *PluginCommon) GetName() string { diff --git a/pkg/testtool/context.go b/pkg/testtool/context.go new file mode 100644 index 0000000..8446a80 --- /dev/null +++ b/pkg/testtool/context.go @@ -0,0 +1,30 @@ +package testtool + +import ( + "github.com/mimuret/dtap/v2/pkg/buffer" + "github.com/mimuret/dtap/v2/pkg/types" + "go.uber.org/zap" +) + +func NewTestInputContext(w types.Writer) *types.InputContext { + if w == nil { + w = buffer.NewRingBuffer(10, nil, nil) + } + return &types.InputContext{ + No: 0, + Writer: w, + Logger: zap.NewExample(), + } +} + +func NewTestOutputContext(r types.Reader) *types.OutputContext { + if r == nil { + r = buffer.NewRingBuffer(10, nil, nil) + } + return &types.OutputContext{ + OutputGroup: "og-test", + No: 0, + Reader: r, + Logger: zap.NewExample(), + } +} diff --git a/pkg/types/interface.go b/pkg/types/interface.go index 3baa11e..1a6677a 100644 --- a/pkg/types/interface.go +++ b/pkg/types/interface.go @@ -18,6 +18,8 @@ package types import ( "context" + + "go.uber.org/zap" ) type Buffer interface { @@ -40,9 +42,15 @@ type Plugin interface { GetName() string } +type InputContext struct { + No int + Logger *zap.Logger + Writer Writer +} + type InputPlugin interface { Plugin - Start(context.Context, Writer) error + Start(context.Context, *InputContext) error } type FilterPlugin interface { @@ -50,7 +58,14 @@ type FilterPlugin interface { Filter(*DnstapMessage) *DnstapMessage } +type OutputContext struct { + OutputGroup string + No int + Logger *zap.Logger + Reader Reader +} + type OutputPlugin interface { Plugin - Start(context.Context, Reader) error + Start(context.Context, *OutputContext) error }