From b4631695ef7090b27f4d89967d57e1cdf32ced2e Mon Sep 17 00:00:00 2001 From: Gabe Cook Date: Fri, 6 May 2022 14:46:43 -0500 Subject: [PATCH] :recycle: Refactor recurse function into visitor --- cmd/cmd.go | 2 +- pkg/template/line_comment.go | 6 ++++-- pkg/template/line_comment_test.go | 6 +++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/cmd/cmd.go b/cmd/cmd.go index ee2562b..53780f8 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -148,7 +148,7 @@ func templateReader(r io.Reader) ([]byte, error) { buf.Write([]byte("---\n")) } - if err := template.RecurseNode(conf, &node); err != nil { + if err := template.VisitNodes(conf, template.LineComment, &node); err != nil { return buf.Bytes(), err } diff --git a/pkg/template/line_comment.go b/pkg/template/line_comment.go index a9eb464..6c1e4a2 100644 --- a/pkg/template/line_comment.go +++ b/pkg/template/line_comment.go @@ -15,14 +15,16 @@ func init() { funcMap["tag"] = DockerTag } -func RecurseNode(conf config.Config, node *yaml.Node) error { +type Visitor func(conf config.Config, node *yaml.Node) error + +func VisitNodes(conf config.Config, visit Visitor, node *yaml.Node) error { if len(node.Content) == 0 { if err := LineComment(conf, node); err != nil { return err } } else { for _, node := range node.Content { - if err := RecurseNode(conf, node); err != nil { + if err := VisitNodes(conf, visit, node); err != nil { return err } } diff --git a/pkg/template/line_comment_test.go b/pkg/template/line_comment_test.go index 2369a71..886e9a8 100644 --- a/pkg/template/line_comment_test.go +++ b/pkg/template/line_comment_test.go @@ -37,9 +37,9 @@ func TestRecurseNode(t *testing.T) { var node yaml.Node _ = yaml.Unmarshal([]byte(tt.args.input), &node) - if err := RecurseNode(tt.args.conf, &node); err != nil { + if err := VisitNodes(tt.args.conf, LineComment, &node); err != nil { if (err != nil) != tt.wantErr { - t.Errorf("RecurseNode() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("VisitNodes() error = %v, wantErr %v", err, tt.wantErr) } return } @@ -47,7 +47,7 @@ func TestRecurseNode(t *testing.T) { got, _ := yaml.Marshal(&node) got = bytes.TrimRight(got, "\n") if string(got) != tt.want { - t.Errorf("RecurseNode() = %v, want %v", string(got), tt.want) + t.Errorf("VisitNodes() = %v, want %v", string(got), tt.want) } }) }