diff --git a/deep.go b/deep.go index 18b5fca..1b1c14a 100644 --- a/deep.go +++ b/deep.go @@ -30,6 +30,8 @@ func MustCopy[T any](src T) T { return dst } +type pointersMap map[uintptr]map[string]reflect.Value + func copy[T any](src T, skipUnsupported bool) (T, error) { v := reflect.ValueOf(src) @@ -41,7 +43,7 @@ func copy[T any](src T, skipUnsupported bool) (T, error) { return t, nil } - dst, err := recursiveCopy(v, make(map[uintptr]reflect.Value), + dst, err := recursiveCopy(v, make(pointersMap), skipUnsupported) if err != nil { var t T @@ -51,7 +53,7 @@ func copy[T any](src T, skipUnsupported bool) (T, error) { return dst.Interface().(T), nil } -func recursiveCopy(v reflect.Value, pointers map[uintptr]reflect.Value, +func recursiveCopy(v reflect.Value, pointers pointersMap, skipUnsupported bool) (reflect.Value, error) { switch v.Kind() { case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, @@ -93,7 +95,7 @@ func recursiveCopy(v reflect.Value, pointers map[uintptr]reflect.Value, } } -func recursiveCopyArray(v reflect.Value, pointers map[uintptr]reflect.Value, +func recursiveCopyArray(v reflect.Value, pointers pointersMap, skipUnsupported bool) (reflect.Value, error) { dst := reflect.New(v.Type()).Elem() @@ -110,7 +112,7 @@ func recursiveCopyArray(v reflect.Value, pointers map[uintptr]reflect.Value, return dst, nil } -func recursiveCopyInterface(v reflect.Value, pointers map[uintptr]reflect.Value, +func recursiveCopyInterface(v reflect.Value, pointers pointersMap, skipUnsupported bool) (reflect.Value, error) { if v.IsNil() { // If the interface is nil, just return it. @@ -120,7 +122,7 @@ func recursiveCopyInterface(v reflect.Value, pointers map[uintptr]reflect.Value, return recursiveCopy(v.Elem(), pointers, skipUnsupported) } -func recursiveCopyMap(v reflect.Value, pointers map[uintptr]reflect.Value, +func recursiveCopyMap(v reflect.Value, pointers pointersMap, skipUnsupported bool) (reflect.Value, error) { if v.IsNil() { // If the slice is nil, just return it. @@ -143,22 +145,31 @@ func recursiveCopyMap(v reflect.Value, pointers map[uintptr]reflect.Value, return dst, nil } -func recursiveCopyPtr(v reflect.Value, pointers map[uintptr]reflect.Value, +func recursiveCopyPtr(v reflect.Value, pointers pointersMap, skipUnsupported bool) (reflect.Value, error) { // If the pointer is nil, just return it. if v.IsNil() { return v, nil } + typeName := v.Type().String() + // If the pointer is already in the pointers map, return it. ptr := v.Pointer() - if dst, ok := pointers[ptr]; ok { - return dst, nil + if dstMap, ok := pointers[ptr]; ok { + if dst, ok := dstMap[typeName]; ok { + return dst, nil + } } // Otherwise, create a new pointer and add it to the pointers map. dst := reflect.New(v.Type().Elem()) - pointers[ptr] = dst + + if pointers[ptr] == nil { + pointers[ptr] = make(map[string]reflect.Value) + } + + pointers[ptr][typeName] = dst // Proceed with the copy. elem := v.Elem() @@ -172,7 +183,7 @@ func recursiveCopyPtr(v reflect.Value, pointers map[uintptr]reflect.Value, return dst, nil } -func recursiveCopySlice(v reflect.Value, pointers map[uintptr]reflect.Value, +func recursiveCopySlice(v reflect.Value, pointers pointersMap, skipUnsupported bool) (reflect.Value, error) { if v.IsNil() { // If the slice is nil, just return it. @@ -195,7 +206,7 @@ func recursiveCopySlice(v reflect.Value, pointers map[uintptr]reflect.Value, return dst, nil } -func recursiveCopyStruct(v reflect.Value, pointers map[uintptr]reflect.Value, +func recursiveCopyStruct(v reflect.Value, pointers pointersMap, skipUnsupported bool) (reflect.Value, error) { dst := reflect.New(v.Type()).Elem() diff --git a/deep_test.go b/deep_test.go index bb2774d..64d553d 100644 --- a/deep_test.go +++ b/deep_test.go @@ -340,3 +340,18 @@ func BenchmarkCopy_Deep(b *testing.B) { MustCopy(src) } } + +func TestTrickyMemberPointer(t *testing.T) { + type Foo struct { + N int + } + type Bar struct { + F *Foo + P *int + } + + foo := Foo{N: 1} + bar := Bar{F: &foo, P: &foo.N} + + doCopyAndCheck(t, bar, false) +}