Skip to content

Commit

Permalink
dev: refactor some checks to use forEachKey() (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmzane authored May 10, 2024
1 parent 6dd777b commit 13a1a83
Showing 1 changed file with 58 additions and 125 deletions.
183 changes: 58 additions & 125 deletions sloglint.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ var slogFuncs = map[string]struct {
argsPos int
skipContextCheck bool
}{
// funcName: {argsPos, skipContextCheck}
"log/slog.With": {argsPos: 0, skipContextCheck: true},
"log/slog.Log": {argsPos: 3},
"log/slog.LogAttrs": {argsPos: 3},
Expand Down Expand Up @@ -198,7 +197,7 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)

switch opts.NoGlobal {
case "all":
if strings.HasPrefix(name, "log/slog.") || globalLoggerUsed(pass.TypesInfo, call.Fun) {
if strings.HasPrefix(name, "log/slog.") || isGlobalLoggerUsed(pass.TypesInfo, call.Fun) {
pass.Reportf(call.Pos(), "global logger should not be used")
}
case "default":
Expand All @@ -217,13 +216,14 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)
}
case "scope":
typ := pass.TypesInfo.TypeOf(call.Args[0])
if typ != nil && typ.String() != "context.Context" && hasContextInScope(pass.TypesInfo, stack) {
if typ != nil && typ.String() != "context.Context" && isContextInScope(pass.TypesInfo, stack) {
pass.Reportf(call.Pos(), "%sContext should be used instead", fn.Name())
}
}
}

if opts.StaticMsg && !staticMsg(call.Args[funcInfo.argsPos-1]) {
msgPos := funcInfo.argsPos - 1
if opts.StaticMsg && !isStaticMsg(call.Args[msgPos]) {
pass.Reportf(call.Pos(), "message should be a string literal or a constant")
}

Expand Down Expand Up @@ -259,54 +259,48 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)
pass.Reportf(call.Pos(), "key-value pairs and attributes should not be mixed")
}

if opts.NoRawKeys && rawKeysUsed(pass.TypesInfo, keys, attrs) {
pass.Reportf(call.Pos(), "raw keys should not be used")
if opts.NoRawKeys {
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
if ident, ok := key.(*ast.Ident); !ok || ident.Obj == nil || ident.Obj.Kind != ast.Con {
pass.Reportf(call.Pos(), "raw keys should not be used")
}
})
}

if opts.ArgsOnSepLines && argsOnSameLine(pass.Fset, call, keys, attrs) {
pass.Reportf(call.Pos(), "arguments should be put on separate lines")
checkKeyNamingCase := func(caseFn func(string) string, caseName string) {
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
if name, ok := getKeyName(key); ok && name != caseFn(name) {
pass.Reportf(call.Pos(), "keys should be written in %s", caseName)
}
})
}

if len(opts.ForbiddenKeys) > 0 {
if name, found := badKeyNames(pass.TypesInfo, isForbiddenKey(opts.ForbiddenKeys), keys, attrs); found {
pass.Reportf(call.Pos(), "%q key is forbidden and should not be used", name)
}
switch opts.KeyNamingCase {
case snakeCase:
checkKeyNamingCase(strcase.ToSnake, "snake_case")
case kebabCase:
checkKeyNamingCase(strcase.ToKebab, "kebab-case")
case camelCase:
checkKeyNamingCase(strcase.ToCamel, "camelCase")
case pascalCase:
checkKeyNamingCase(strcase.ToPascal, "PascalCase")
}

switch {
case opts.KeyNamingCase == snakeCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToSnake), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in snake_case")
}
case opts.KeyNamingCase == kebabCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToKebab), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in kebab-case")
}
case opts.KeyNamingCase == camelCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToCamel), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in camelCase")
}
case opts.KeyNamingCase == pascalCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToPascal), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in PascalCase")
}
}
}

func isForbiddenKey(forbiddenKeys []string) func(string) bool {
return func(name string) bool {
return slices.Contains(forbiddenKeys, name)
if len(opts.ForbiddenKeys) > 0 {
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
if name, ok := getKeyName(key); ok && slices.Contains(opts.ForbiddenKeys, name) {
pass.Reportf(call.Pos(), "%q key is forbidden and should not be used", name)
}
})
}
}

func valueChanged(handler func(string) string) func(string) bool {
return func(name string) bool {
return handler(name) != name
if opts.ArgsOnSepLines && areArgsOnSameLine(pass.Fset, call, keys, attrs) {
pass.Reportf(call.Pos(), "arguments should be put on separate lines")
}
}

func globalLoggerUsed(info *types.Info, expr ast.Expr) bool {
selector, ok := expr.(*ast.SelectorExpr)
func isGlobalLoggerUsed(info *types.Info, call ast.Expr) bool {
selector, ok := call.(*ast.SelectorExpr)
if !ok {
return false
}
Expand All @@ -318,7 +312,7 @@ func globalLoggerUsed(info *types.Info, expr ast.Expr) bool {
return obj.Parent() == obj.Pkg().Scope()
}

func hasContextInScope(info *types.Info, stack []ast.Node) bool {
func isContextInScope(info *types.Info, stack []ast.Node) bool {
for i := len(stack) - 1; i >= 0; i-- {
decl, ok := stack[i].(*ast.FuncDecl)
if !ok {
Expand All @@ -336,8 +330,8 @@ func hasContextInScope(info *types.Info, stack []ast.Node) bool {
return false
}

func staticMsg(expr ast.Expr) bool {
switch msg := expr.(type) {
func isStaticMsg(msg ast.Expr) bool {
switch msg := msg.(type) {
case *ast.BasicLit: // e.g. slog.Info("msg")
return msg.Kind == token.STRING
case *ast.Ident: // e.g. const msg = "msg"; slog.Info(msg)
Expand All @@ -347,114 +341,53 @@ func staticMsg(expr ast.Expr) bool {
}
}

func rawKeysUsed(info *types.Info, keys, attrs []ast.Expr) bool {
isConst := func(expr ast.Expr) bool {
ident, ok := expr.(*ast.Ident)
return ok && ident.Obj != nil && ident.Obj.Kind == ast.Con
}

for _, key := range keys {
if !isConst(key) {
return true
}
}

for _, attr := range attrs {
switch attr := attr.(type) {
case *ast.CallExpr: // e.g. slog.Int()
fn := typeutil.StaticCallee(info, attr)
if _, ok := attrFuncs[fn.FullName()]; ok && !isConst(attr.Args[0]) {
return true
}

case *ast.CompositeLit: // slog.Attr{}
isRawKey := func(kv *ast.KeyValueExpr) bool {
return kv.Key.(*ast.Ident).Name == "Key" && !isConst(kv.Value)
}

switch len(attr.Elts) {
case 1: // slog.Attr{Key: ...} | slog.Attr{Value: ...}
kv := attr.Elts[0].(*ast.KeyValueExpr)
if isRawKey(kv) {
return true
}
case 2: // slog.Attr{..., ...} | slog.Attr{Key: ..., Value: ...}
kv1, ok := attr.Elts[0].(*ast.KeyValueExpr)
if ok {
kv2 := attr.Elts[1].(*ast.KeyValueExpr)
if isRawKey(kv1) || isRawKey(kv2) {
return true
}
} else if !isConst(attr.Elts[0]) {
return true
}
}
}
}

return false
}

func badKeyNames(info *types.Info, validationFn func(string) bool, keys, attrs []ast.Expr) (string, bool) {
func forEachKey(info *types.Info, keys, attrs []ast.Expr, fn func(key ast.Expr)) {
for _, key := range keys {
if name, ok := getKeyName(key); ok && validationFn(name) {
return name, true
}
fn(key)
}

for _, attr := range attrs {
var expr ast.Expr

switch attr := attr.(type) {
case *ast.CallExpr: // e.g. slog.Int()
fn := typeutil.StaticCallee(info, attr)
if fn == nil {
callee := typeutil.StaticCallee(info, attr)
if callee == nil {
continue
}
if _, ok := attrFuncs[fn.FullName()]; !ok {
if _, ok := attrFuncs[callee.FullName()]; !ok {
continue
}
expr = attr.Args[0]
fn(attr.Args[0])

case *ast.CompositeLit: // slog.Attr{}
switch len(attr.Elts) {
case 1: // slog.Attr{Key: ...} | slog.Attr{Value: ...}
if kv := attr.Elts[0].(*ast.KeyValueExpr); kv.Key.(*ast.Ident).Name == "Key" {
expr = kv.Value
fn(kv.Value)
}
case 2: // slog.Attr{..., ...} | slog.Attr{Key: ..., Value: ...}
expr = attr.Elts[0]
if kv1, ok := attr.Elts[0].(*ast.KeyValueExpr); ok && kv1.Key.(*ast.Ident).Name == "Key" {
expr = kv1.Value
}
if kv2, ok := attr.Elts[1].(*ast.KeyValueExpr); ok && kv2.Key.(*ast.Ident).Name == "Key" {
expr = kv2.Value
case 2: // slog.Attr{Key: ..., Value: ...} | slog.Attr{Value: ..., Key: ...} | slog.Attr{..., ...}
if kv, ok := attr.Elts[0].(*ast.KeyValueExpr); ok && kv.Key.(*ast.Ident).Name == "Key" {
fn(kv.Value)
} else if kv, ok := attr.Elts[1].(*ast.KeyValueExpr); ok && kv.Key.(*ast.Ident).Name == "Key" {
fn(kv.Value)
} else {
fn(attr.Elts[0])
}
}
}

if name, ok := getKeyName(expr); ok && validationFn(name) {
return name, true
}
}

return "", false
}

func getKeyName(expr ast.Expr) (string, bool) {
if expr == nil {
return "", false
}
if ident, ok := expr.(*ast.Ident); ok {
func getKeyName(key ast.Expr) (string, bool) {
if ident, ok := key.(*ast.Ident); ok {
if ident.Obj == nil || ident.Obj.Decl == nil || ident.Obj.Kind != ast.Con {
return "", false
}
if spec, ok := ident.Obj.Decl.(*ast.ValueSpec); ok && len(spec.Values) > 0 {
// TODO: support len(spec.Values) > 1; e.g. "const foo, bar = 1, 2"
expr = spec.Values[0]
// TODO: support len(spec.Values) > 1; e.g. const foo, bar = 1, 2
key = spec.Values[0]
}
}
if lit, ok := expr.(*ast.BasicLit); ok && lit.Kind == token.STRING {
if lit, ok := key.(*ast.BasicLit); ok && lit.Kind == token.STRING {
// string literals are always quoted.
value, err := strconv.Unquote(lit.Value)
if err != nil {
Expand All @@ -465,7 +398,7 @@ func getKeyName(expr ast.Expr) (string, bool) {
return "", false
}

func argsOnSameLine(fset *token.FileSet, call ast.Expr, keys, attrs []ast.Expr) bool {
func areArgsOnSameLine(fset *token.FileSet, call ast.Expr, keys, attrs []ast.Expr) bool {
if len(keys)+len(attrs) <= 1 {
return false // special case: slog.Info("msg", "key", "value") is ok.
}
Expand Down

0 comments on commit 13a1a83

Please sign in to comment.