Skip to content

Commit

Permalink
Merge pull request #10 from brunoga/pointer_map_fix
Browse files Browse the repository at this point in the history
Correctly handle 2 values with different types but same pointers.
  • Loading branch information
brunoga authored Jun 28, 2024
2 parents 3520539 + 63aa4fe commit 92c699d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 11 deletions.
33 changes: 22 additions & 11 deletions deep.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ func MustCopy[T any](src T) T {
return dst
}

type pointersMap map[uintptr]map[string]reflect.Value

func copy[T any](src T, skipUnsupported bool) (T, error) {
v := reflect.ValueOf(src)

Expand All @@ -41,7 +43,7 @@ func copy[T any](src T, skipUnsupported bool) (T, error) {
return t, nil
}

dst, err := recursiveCopy(v, make(map[uintptr]reflect.Value),
dst, err := recursiveCopy(v, make(pointersMap),
skipUnsupported)
if err != nil {
var t T
Expand All @@ -51,7 +53,7 @@ func copy[T any](src T, skipUnsupported bool) (T, error) {
return dst.Interface().(T), nil
}

func recursiveCopy(v reflect.Value, pointers map[uintptr]reflect.Value,
func recursiveCopy(v reflect.Value, pointers pointersMap,
skipUnsupported bool) (reflect.Value, error) {
switch v.Kind() {
case reflect.Bool, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
Expand Down Expand Up @@ -93,7 +95,7 @@ func recursiveCopy(v reflect.Value, pointers map[uintptr]reflect.Value,
}
}

func recursiveCopyArray(v reflect.Value, pointers map[uintptr]reflect.Value,
func recursiveCopyArray(v reflect.Value, pointers pointersMap,
skipUnsupported bool) (reflect.Value, error) {
dst := reflect.New(v.Type()).Elem()

Expand All @@ -110,7 +112,7 @@ func recursiveCopyArray(v reflect.Value, pointers map[uintptr]reflect.Value,
return dst, nil
}

func recursiveCopyInterface(v reflect.Value, pointers map[uintptr]reflect.Value,
func recursiveCopyInterface(v reflect.Value, pointers pointersMap,
skipUnsupported bool) (reflect.Value, error) {
if v.IsNil() {
// If the interface is nil, just return it.
Expand All @@ -120,7 +122,7 @@ func recursiveCopyInterface(v reflect.Value, pointers map[uintptr]reflect.Value,
return recursiveCopy(v.Elem(), pointers, skipUnsupported)
}

func recursiveCopyMap(v reflect.Value, pointers map[uintptr]reflect.Value,
func recursiveCopyMap(v reflect.Value, pointers pointersMap,
skipUnsupported bool) (reflect.Value, error) {
if v.IsNil() {
// If the slice is nil, just return it.
Expand All @@ -143,22 +145,31 @@ func recursiveCopyMap(v reflect.Value, pointers map[uintptr]reflect.Value,
return dst, nil
}

func recursiveCopyPtr(v reflect.Value, pointers map[uintptr]reflect.Value,
func recursiveCopyPtr(v reflect.Value, pointers pointersMap,
skipUnsupported bool) (reflect.Value, error) {
// If the pointer is nil, just return it.
if v.IsNil() {
return v, nil
}

typeName := v.Type().String()

// If the pointer is already in the pointers map, return it.
ptr := v.Pointer()
if dst, ok := pointers[ptr]; ok {
return dst, nil
if dstMap, ok := pointers[ptr]; ok {
if dst, ok := dstMap[typeName]; ok {
return dst, nil
}
}

// Otherwise, create a new pointer and add it to the pointers map.
dst := reflect.New(v.Type().Elem())
pointers[ptr] = dst

if pointers[ptr] == nil {
pointers[ptr] = make(map[string]reflect.Value)
}

pointers[ptr][typeName] = dst

// Proceed with the copy.
elem := v.Elem()
Expand All @@ -172,7 +183,7 @@ func recursiveCopyPtr(v reflect.Value, pointers map[uintptr]reflect.Value,
return dst, nil
}

func recursiveCopySlice(v reflect.Value, pointers map[uintptr]reflect.Value,
func recursiveCopySlice(v reflect.Value, pointers pointersMap,
skipUnsupported bool) (reflect.Value, error) {
if v.IsNil() {
// If the slice is nil, just return it.
Expand All @@ -195,7 +206,7 @@ func recursiveCopySlice(v reflect.Value, pointers map[uintptr]reflect.Value,
return dst, nil
}

func recursiveCopyStruct(v reflect.Value, pointers map[uintptr]reflect.Value,
func recursiveCopyStruct(v reflect.Value, pointers pointersMap,
skipUnsupported bool) (reflect.Value, error) {
dst := reflect.New(v.Type()).Elem()

Expand Down
15 changes: 15 additions & 0 deletions deep_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,18 @@ func BenchmarkCopy_Deep(b *testing.B) {
MustCopy(src)
}
}

func TestTrickyMemberPointer(t *testing.T) {
type Foo struct {
N int
}
type Bar struct {
F *Foo
P *int
}

foo := Foo{N: 1}
bar := Bar{F: &foo, P: &foo.N}

doCopyAndCheck(t, bar, false)
}

0 comments on commit 92c699d

Please sign in to comment.