Skip to content

Commit

Permalink
Add options KeyPreserveCase() and KeyNormalizer(func (string) string)
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitry-irtegov committed Mar 16, 2020
1 parent 97ee7ad commit 0caa8e0
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 46 deletions.
34 changes: 17 additions & 17 deletions util.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,31 @@ func (pe ConfigParseError) Error() string {
return fmt.Sprintf("While parsing config: %s", pe.err.Error())
}

// toCaseInsensitiveValue checks if the value is a map;
// if so, create a copy and lower-case the keys recursively.
func toCaseInsensitiveValue(value interface{}) interface{} {
// toNormalizedValue checks if the value is a map;
// if so, create a copy and normalize the keys recursively.
func toNormalizedValue(value interface{}, normalize keyNormalizeHookType) interface{} {
switch v := value.(type) {
case map[interface{}]interface{}:
value = copyAndInsensitiviseMap(cast.ToStringMap(v))
value = copyAndNormalizeMap(cast.ToStringMap(v), normalize)
case map[string]interface{}:
value = copyAndInsensitiviseMap(v)
value = copyAndNormalizeMap(v, normalize)
}

return value
}

// copyAndInsensitiviseMap behaves like insensitiviseMap, but creates a copy of
// any map it makes case insensitive.
func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} {
// copyAndNormalizeMap behaves like normalizeMap, but creates a copy of
// any map it normalizes.
func copyAndNormalizeMap(m map[string]interface{}, normalize keyNormalizeHookType) map[string]interface{} {
nm := make(map[string]interface{})

for key, val := range m {
lkey := strings.ToLower(key)
lkey := normalize(key)
switch v := val.(type) {
case map[interface{}]interface{}:
nm[lkey] = copyAndInsensitiviseMap(cast.ToStringMap(v))
nm[lkey] = copyAndNormalizeMap(cast.ToStringMap(v), normalize)
case map[string]interface{}:
nm[lkey] = copyAndInsensitiviseMap(v)
nm[lkey] = copyAndNormalizeMap(v, normalize)
default:
nm[lkey] = v
}
Expand All @@ -66,25 +66,25 @@ func copyAndInsensitiviseMap(m map[string]interface{}) map[string]interface{} {
return nm
}

func insensitiviseMap(m map[string]interface{}) {
func normalizeMap(m map[string]interface{}, normalize keyNormalizeHookType) {
for key, val := range m {
switch val.(type) {
case map[interface{}]interface{}:
// nested map: cast and recursively insensitivise
val = cast.ToStringMap(val)
insensitiviseMap(val.(map[string]interface{}))
normalizeMap(val.(map[string]interface{}), normalize)
case map[string]interface{}:
// nested map: recursively insensitivise
insensitiviseMap(val.(map[string]interface{}))
normalizeMap(val.(map[string]interface{}), normalize)
}

lower := strings.ToLower(key)
if key != lower {
normKey := normalize(key)
if key != normKey {
// remove old key (not lower-cased)
delete(m, key)
}
// update map
m[lower] = val
m[normKey] = val
}
}

Expand Down
3 changes: 2 additions & 1 deletion util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package viper

import (
"reflect"
"strings"
"testing"
)

Expand All @@ -33,7 +34,7 @@ func TestCopyAndInsensitiviseMap(t *testing.T) {
}
)

got := copyAndInsensitiviseMap(given)
got := copyAndNormalizeMap(given, strings.ToLower)

if !reflect.DeepEqual(got, expected) {
t.Fatalf("Got %q\nexpected\n%q", got, expected)
Expand Down
92 changes: 64 additions & 28 deletions viper.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ func DecodeHook(hook mapstructure.DecodeHookFunc) DecoderConfigOption {
}
}

type keyNormalizeHookType func(string) string

var defaultKeyNormalizer = strings.ToLower

// Viper is a prioritized configuration registry. It
// maintains a set of configuration sources, fetches
// values to populate those, and provides them according
Expand Down Expand Up @@ -180,6 +184,10 @@ type Viper struct {
// used to access a nested value in one go
keyDelim string

// Function to normalize keys
// by default, strings.ToLower
keyNormalizeHook keyNormalizeHookType

// A set of paths to look for the config file in
configPaths []string

Expand Down Expand Up @@ -220,6 +228,7 @@ type Viper struct {
func New() *Viper {
v := new(Viper)
v.keyDelim = "."
v.keyNormalizeHook = defaultKeyNormalizer
v.configName = "config"
v.configPermissions = os.FileMode(0644)
v.fs = afero.NewOsFs()
Expand Down Expand Up @@ -257,6 +266,23 @@ func KeyDelimiter(d string) Option {
})
}

// KeyNormalizer is option to set arbitrary function for key normalization
// This function will be applied to all keys after unmarshal, during merge, search for duplicates, etc
// Default normalizer is strings.ToLower
func KeyNormalizer(n keyNormalizeHookType) Option {
return optionFunc(func(v *Viper) {
v.keyNormalizeHook = n
})
}

// KeyPreserveCase is option to disable key lowercasing
// By default, Viper converts all keys to lovercase
func KeyPreserveCase() Option {
return optionFunc(func(v *Viper) {
v.keyNormalizeHook = func(key string) string { return key }
})
}

// StringReplacer applies a set of replacements to a string.
type StringReplacer interface {
// Replace returns a copy of s with all replacements performed.
Expand Down Expand Up @@ -425,6 +451,13 @@ func (v *Viper) SetEnvPrefix(in string) {
}
}

func (v *Viper) keyNormalize(k string) string {
if v.keyNormalizeHook != nil {
return v.keyNormalizeHook(k)
}
return defaultKeyNormalizer(k)
}

func (v *Viper) mergeWithEnvPrefix(in string) string {
if v.envPrefix != "" {
return strings.ToUpper(v.envPrefix + "_" + in)
Expand Down Expand Up @@ -548,7 +581,7 @@ func (v *Viper) providerPathExists(p *defaultRemoteProvider) bool {

// searchMap recursively searches for a value for path in source map.
// Returns nil if not found.
// Note: This assumes that the path entries and map keys are lower cased.
// Note: This assumes that the path entries and map keys are normalized (by default, lowercased).
func (v *Viper) searchMap(source map[string]interface{}, path []string) interface{} {
if len(path) == 0 {
return source
Expand Down Expand Up @@ -587,15 +620,15 @@ func (v *Viper) searchMap(source map[string]interface{}, path []string) interfac
// This should be useful only at config level (other maps may not contain dots
// in their keys).
//
// Note: This assumes that the path entries and map keys are lower cased.
// Note: This assumes that the path entries and map keys are normalized (by default, lowercased).
func (v *Viper) searchMapWithPathPrefixes(source map[string]interface{}, path []string) interface{} {
if len(path) == 0 {
return source
}

// search for path prefixes, starting from the longest one
for i := len(path); i > 0; i-- {
prefixKey := strings.ToLower(strings.Join(path[0:i], v.keyDelim))
prefixKey := v.keyNormalize(strings.Join(path[0:i], v.keyDelim))

next, ok := source[prefixKey]
if ok {
Expand Down Expand Up @@ -724,7 +757,7 @@ func GetViper() *Viper {
// Get returns an interface. For a specific value use one of the Get____ methods.
func Get(key string) interface{} { return v.Get(key) }
func (v *Viper) Get(key string) interface{} {
lcaseKey := strings.ToLower(key)
lcaseKey := v.keyNormalize(key)
val := v.find(lcaseKey, true)
if val == nil {
return nil
Expand Down Expand Up @@ -1001,7 +1034,7 @@ func (v *Viper) BindFlagValue(key string, flag FlagValue) error {
if flag == nil {
return fmt.Errorf("flag for %q is nil", key)
}
v.pflags[strings.ToLower(key)] = flag
v.pflags[v.keyNormalize(key)] = flag
return nil
}

Expand All @@ -1016,7 +1049,7 @@ func (v *Viper) BindEnv(input ...string) error {
return fmt.Errorf("missing key to bind to")
}

key = strings.ToLower(input[0])
key = v.keyNormalize(input[0])

if len(input) == 1 {
envkey = v.mergeWithEnvPrefix(key)
Expand All @@ -1037,7 +1070,7 @@ func (v *Viper) BindEnv(input ...string) error {
// Lastly, if no value was found and flagDefault is true, and if the key
// corresponds to a flag, the flag's default value is returned.
//
// Note: this assumes a lower-cased key given.
// Note: this assumes a normalized key given.
func (v *Viper) find(lcaseKey string, flagDefault bool) interface{} {
var (
val interface{}
Expand Down Expand Up @@ -1178,11 +1211,11 @@ func readAsCSV(val string) ([]string, error) {
}

// IsSet checks to see if the key has been set in any of the data locations.
// IsSet is case-insensitive for a key.
// IsSet normalizes the key.
func IsSet(key string) bool { return v.IsSet(key) }
func (v *Viper) IsSet(key string) bool {
lcaseKey := strings.ToLower(key)
val := v.find(lcaseKey, false)
normKey := v.keyNormalize(key)
val := v.find(normKey, false)
return val != nil
}

Expand All @@ -1205,11 +1238,11 @@ func (v *Viper) SetEnvKeyReplacer(r *strings.Replacer) {
// This enables one to change a name without breaking the application.
func RegisterAlias(alias string, key string) { v.RegisterAlias(alias, key) }
func (v *Viper) RegisterAlias(alias string, key string) {
v.registerAlias(alias, strings.ToLower(key))
v.registerAlias(alias, v.keyNormalize(key))
}

func (v *Viper) registerAlias(alias string, key string) {
alias = strings.ToLower(alias)
alias = v.keyNormalize(alias)
if alias != key && alias != v.realKey(key) {
_, exists := v.aliases[alias]

Expand Down Expand Up @@ -1260,34 +1293,36 @@ func (v *Viper) InConfig(key string) bool {
}

// SetDefault sets the default value for this key.
// SetDefault is case-insensitive for a key.
// SetDefault applies normalization (by default, lowercases) a key.
// Default only used when no value is provided by the user via flag, config or ENV.
func SetDefault(key string, value interface{}) { v.SetDefault(key, value) }
func (v *Viper) SetDefault(key string, value interface{}) {
// If alias passed in, then set the proper default
key = v.realKey(strings.ToLower(key))
value = toCaseInsensitiveValue(value)
key = v.keyNormalize(key)
value = toNormalizedValue(value, v.keyNormalize)
key = v.realKey(key)

path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1])
lastKey := v.keyNormalize(path[len(path)-1])
deepestMap := deepSearch(v.defaults, path[0:len(path)-1])

// set innermost value
deepestMap[lastKey] = value
}

// Set sets the value for the key in the override register.
// Set is case-insensitive for a key.
// Set normalizes a key.
// Will be used instead of values obtained via
// flags, config file, ENV, default, or key/value store.
func Set(key string, value interface{}) { v.Set(key, value) }
func (v *Viper) Set(key string, value interface{}) {
// If alias passed in, then set the proper override
key = v.realKey(strings.ToLower(key))
value = toCaseInsensitiveValue(value)
key = v.keyNormalize(key)
value = toNormalizedValue(value, v.keyNormalize)
key = v.realKey(key)

path := strings.Split(key, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1])
lastKey := v.keyNormalize(path[len(path)-1])
deepestMap := deepSearch(v.override, path[0:len(path)-1])

// set innermost value
Expand Down Expand Up @@ -1371,7 +1406,7 @@ func (v *Viper) MergeConfigMap(cfg map[string]interface{}) error {
if v.config == nil {
v.config = make(map[string]interface{})
}
insensitiviseMap(cfg)
normalizeMap(cfg, v.keyNormalize)
mergeMaps(cfg, v.config, nil)
return nil
}
Expand Down Expand Up @@ -1506,7 +1541,7 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
value, _ := v.properties.Get(key)
// recursively build nested maps
path := strings.Split(key, ".")
lastKey := strings.ToLower(path[len(path)-1])
lastKey := v.keyNormalize(path[len(path)-1])
deepestMap := deepSearch(c, path[0:len(path)-1])
// set innermost value
deepestMap[lastKey] = value
Expand All @@ -1530,7 +1565,8 @@ func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error {
}
}

insensitiviseMap(c)
normalizeMap(c, v.keyNormalize)

return nil
}

Expand Down Expand Up @@ -1629,9 +1665,9 @@ func (v *Viper) marshalWriter(f afero.File, configType string) error {
}

func keyExists(k string, m map[string]interface{}) string {
lk := strings.ToLower(k)
lk := v.keyNormalize(k)
for mk := range m {
lmk := strings.ToLower(mk)
lmk := v.keyNormalize(mk)
if lmk == lk {
return mk
}
Expand Down Expand Up @@ -1855,7 +1891,7 @@ func (v *Viper) flattenAndMergeMap(shadow map[string]bool, m map[string]interfac
m2 = cast.ToStringMap(val)
default:
// immediate value
shadow[strings.ToLower(fullKey)] = true
shadow[v.keyNormalize(fullKey)] = true
continue
}
// recursively merge to shadow map
Expand All @@ -1881,7 +1917,7 @@ outer:
}
}
// add key
shadow[strings.ToLower(k)] = true
shadow[v.keyNormalize(k)] = true
}
return shadow
}
Expand All @@ -1899,7 +1935,7 @@ func (v *Viper) AllSettings() map[string]interface{} {
continue
}
path := strings.Split(k, v.keyDelim)
lastKey := strings.ToLower(path[len(path)-1])
lastKey := v.keyNormalize(path[len(path)-1])
deepestMap := deepSearch(m, path[0:len(path)-1])
// set innermost value
deepestMap[lastKey] = value
Expand Down
Loading

0 comments on commit 0caa8e0

Please sign in to comment.