Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev: refactor some checks to use forEachKey() #43

Merged
merged 2 commits into from
May 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 58 additions & 125 deletions sloglint.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ var slogFuncs = map[string]struct {
argsPos int
skipContextCheck bool
}{
// funcName: {argsPos, skipContextCheck}
"log/slog.With": {argsPos: 0, skipContextCheck: true},
"log/slog.Log": {argsPos: 3},
"log/slog.LogAttrs": {argsPos: 3},
Expand Down Expand Up @@ -198,7 +197,7 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)

switch opts.NoGlobal {
case "all":
if strings.HasPrefix(name, "log/slog.") || globalLoggerUsed(pass.TypesInfo, call.Fun) {
if strings.HasPrefix(name, "log/slog.") || isGlobalLoggerUsed(pass.TypesInfo, call.Fun) {
pass.Reportf(call.Pos(), "global logger should not be used")
}
case "default":
Expand All @@ -217,13 +216,14 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)
}
case "scope":
typ := pass.TypesInfo.TypeOf(call.Args[0])
if typ != nil && typ.String() != "context.Context" && hasContextInScope(pass.TypesInfo, stack) {
if typ != nil && typ.String() != "context.Context" && isContextInScope(pass.TypesInfo, stack) {
pass.Reportf(call.Pos(), "%sContext should be used instead", fn.Name())
}
}
}

if opts.StaticMsg && !staticMsg(call.Args[funcInfo.argsPos-1]) {
msgPos := funcInfo.argsPos - 1
if opts.StaticMsg && !isStaticMsg(call.Args[msgPos]) {
pass.Reportf(call.Pos(), "message should be a string literal or a constant")
}

Expand Down Expand Up @@ -259,54 +259,48 @@ func visit(pass *analysis.Pass, opts *Options, node ast.Node, stack []ast.Node)
pass.Reportf(call.Pos(), "key-value pairs and attributes should not be mixed")
}

if opts.NoRawKeys && rawKeysUsed(pass.TypesInfo, keys, attrs) {
pass.Reportf(call.Pos(), "raw keys should not be used")
if opts.NoRawKeys {
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
if ident, ok := key.(*ast.Ident); !ok || ident.Obj == nil || ident.Obj.Kind != ast.Con {
pass.Reportf(call.Pos(), "raw keys should not be used")
}
})
}

if opts.ArgsOnSepLines && argsOnSameLine(pass.Fset, call, keys, attrs) {
pass.Reportf(call.Pos(), "arguments should be put on separate lines")
checkKeyNamingCase := func(caseFn func(string) string, caseName string) {
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
if name, ok := getKeyName(key); ok && name != caseFn(name) {
pass.Reportf(call.Pos(), "keys should be written in %s", caseName)
}
})
}

if len(opts.ForbiddenKeys) > 0 {
if name, found := badKeyNames(pass.TypesInfo, isForbiddenKey(opts.ForbiddenKeys), keys, attrs); found {
pass.Reportf(call.Pos(), "%q key is forbidden and should not be used", name)
}
switch opts.KeyNamingCase {
case snakeCase:
checkKeyNamingCase(strcase.ToSnake, "snake_case")
case kebabCase:
checkKeyNamingCase(strcase.ToKebab, "kebab-case")
case camelCase:
checkKeyNamingCase(strcase.ToCamel, "camelCase")
case pascalCase:
checkKeyNamingCase(strcase.ToPascal, "PascalCase")
}

switch {
case opts.KeyNamingCase == snakeCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToSnake), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in snake_case")
}
case opts.KeyNamingCase == kebabCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToKebab), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in kebab-case")
}
case opts.KeyNamingCase == camelCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToCamel), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in camelCase")
}
case opts.KeyNamingCase == pascalCase:
if _, found := badKeyNames(pass.TypesInfo, valueChanged(strcase.ToPascal), keys, attrs); found {
pass.Reportf(call.Pos(), "keys should be written in PascalCase")
}
}
}

func isForbiddenKey(forbiddenKeys []string) func(string) bool {
return func(name string) bool {
return slices.Contains(forbiddenKeys, name)
if len(opts.ForbiddenKeys) > 0 {
forEachKey(pass.TypesInfo, keys, attrs, func(key ast.Expr) {
if name, ok := getKeyName(key); ok && slices.Contains(opts.ForbiddenKeys, name) {
pass.Reportf(call.Pos(), "%q key is forbidden and should not be used", name)
}
})
}
}

func valueChanged(handler func(string) string) func(string) bool {
return func(name string) bool {
return handler(name) != name
if opts.ArgsOnSepLines && areArgsOnSameLine(pass.Fset, call, keys, attrs) {
pass.Reportf(call.Pos(), "arguments should be put on separate lines")
}
}

func globalLoggerUsed(info *types.Info, expr ast.Expr) bool {
selector, ok := expr.(*ast.SelectorExpr)
func isGlobalLoggerUsed(info *types.Info, call ast.Expr) bool {
selector, ok := call.(*ast.SelectorExpr)
if !ok {
return false
}
Expand All @@ -318,7 +312,7 @@ func globalLoggerUsed(info *types.Info, expr ast.Expr) bool {
return obj.Parent() == obj.Pkg().Scope()
}

func hasContextInScope(info *types.Info, stack []ast.Node) bool {
func isContextInScope(info *types.Info, stack []ast.Node) bool {
for i := len(stack) - 1; i >= 0; i-- {
decl, ok := stack[i].(*ast.FuncDecl)
if !ok {
Expand All @@ -336,8 +330,8 @@ func hasContextInScope(info *types.Info, stack []ast.Node) bool {
return false
}

func staticMsg(expr ast.Expr) bool {
switch msg := expr.(type) {
func isStaticMsg(msg ast.Expr) bool {
switch msg := msg.(type) {
case *ast.BasicLit: // e.g. slog.Info("msg")
return msg.Kind == token.STRING
case *ast.Ident: // e.g. const msg = "msg"; slog.Info(msg)
Expand All @@ -347,114 +341,53 @@ func staticMsg(expr ast.Expr) bool {
}
}

func rawKeysUsed(info *types.Info, keys, attrs []ast.Expr) bool {
isConst := func(expr ast.Expr) bool {
ident, ok := expr.(*ast.Ident)
return ok && ident.Obj != nil && ident.Obj.Kind == ast.Con
}

for _, key := range keys {
if !isConst(key) {
return true
}
}

for _, attr := range attrs {
switch attr := attr.(type) {
case *ast.CallExpr: // e.g. slog.Int()
fn := typeutil.StaticCallee(info, attr)
if _, ok := attrFuncs[fn.FullName()]; ok && !isConst(attr.Args[0]) {
return true
}

case *ast.CompositeLit: // slog.Attr{}
isRawKey := func(kv *ast.KeyValueExpr) bool {
return kv.Key.(*ast.Ident).Name == "Key" && !isConst(kv.Value)
}

switch len(attr.Elts) {
case 1: // slog.Attr{Key: ...} | slog.Attr{Value: ...}
kv := attr.Elts[0].(*ast.KeyValueExpr)
if isRawKey(kv) {
return true
}
case 2: // slog.Attr{..., ...} | slog.Attr{Key: ..., Value: ...}
kv1, ok := attr.Elts[0].(*ast.KeyValueExpr)
if ok {
kv2 := attr.Elts[1].(*ast.KeyValueExpr)
if isRawKey(kv1) || isRawKey(kv2) {
return true
}
} else if !isConst(attr.Elts[0]) {
return true
}
}
}
}

return false
}

func badKeyNames(info *types.Info, validationFn func(string) bool, keys, attrs []ast.Expr) (string, bool) {
func forEachKey(info *types.Info, keys, attrs []ast.Expr, fn func(key ast.Expr)) {
for _, key := range keys {
if name, ok := getKeyName(key); ok && validationFn(name) {
return name, true
}
fn(key)
}

for _, attr := range attrs {
var expr ast.Expr

switch attr := attr.(type) {
case *ast.CallExpr: // e.g. slog.Int()
fn := typeutil.StaticCallee(info, attr)
if fn == nil {
callee := typeutil.StaticCallee(info, attr)
if callee == nil {
continue
}
if _, ok := attrFuncs[fn.FullName()]; !ok {
if _, ok := attrFuncs[callee.FullName()]; !ok {
continue
}
expr = attr.Args[0]
fn(attr.Args[0])

case *ast.CompositeLit: // slog.Attr{}
switch len(attr.Elts) {
case 1: // slog.Attr{Key: ...} | slog.Attr{Value: ...}
if kv := attr.Elts[0].(*ast.KeyValueExpr); kv.Key.(*ast.Ident).Name == "Key" {
expr = kv.Value
fn(kv.Value)
}
case 2: // slog.Attr{..., ...} | slog.Attr{Key: ..., Value: ...}
expr = attr.Elts[0]
if kv1, ok := attr.Elts[0].(*ast.KeyValueExpr); ok && kv1.Key.(*ast.Ident).Name == "Key" {
expr = kv1.Value
}
if kv2, ok := attr.Elts[1].(*ast.KeyValueExpr); ok && kv2.Key.(*ast.Ident).Name == "Key" {
expr = kv2.Value
case 2: // slog.Attr{Key: ..., Value: ...} | slog.Attr{Value: ..., Key: ...} | slog.Attr{..., ...}
if kv, ok := attr.Elts[0].(*ast.KeyValueExpr); ok && kv.Key.(*ast.Ident).Name == "Key" {
fn(kv.Value)
} else if kv, ok := attr.Elts[1].(*ast.KeyValueExpr); ok && kv.Key.(*ast.Ident).Name == "Key" {
fn(kv.Value)
} else {
fn(attr.Elts[0])
}
}
}

if name, ok := getKeyName(expr); ok && validationFn(name) {
return name, true
}
}

return "", false
}

func getKeyName(expr ast.Expr) (string, bool) {
if expr == nil {
return "", false
}
if ident, ok := expr.(*ast.Ident); ok {
func getKeyName(key ast.Expr) (string, bool) {
if ident, ok := key.(*ast.Ident); ok {
if ident.Obj == nil || ident.Obj.Decl == nil || ident.Obj.Kind != ast.Con {
return "", false
}
if spec, ok := ident.Obj.Decl.(*ast.ValueSpec); ok && len(spec.Values) > 0 {
// TODO: support len(spec.Values) > 1; e.g. "const foo, bar = 1, 2"
expr = spec.Values[0]
// TODO: support len(spec.Values) > 1; e.g. const foo, bar = 1, 2
key = spec.Values[0]
}
}
if lit, ok := expr.(*ast.BasicLit); ok && lit.Kind == token.STRING {
if lit, ok := key.(*ast.BasicLit); ok && lit.Kind == token.STRING {
// string literals are always quoted.
value, err := strconv.Unquote(lit.Value)
if err != nil {
Expand All @@ -465,7 +398,7 @@ func getKeyName(expr ast.Expr) (string, bool) {
return "", false
}

func argsOnSameLine(fset *token.FileSet, call ast.Expr, keys, attrs []ast.Expr) bool {
func areArgsOnSameLine(fset *token.FileSet, call ast.Expr, keys, attrs []ast.Expr) bool {
if len(keys)+len(attrs) <= 1 {
return false // special case: slog.Info("msg", "key", "value") is ok.
}
Expand Down
Loading