Skip to content

Commit

Permalink
Merge pull request #1336 from cogentcore/reflectx-pointers
Browse files Browse the repository at this point in the history
Redesign reflectx pointer logic to fix panic
  • Loading branch information
rcoreilly authored Nov 25, 2024
2 parents 7a8cf46 + 6b54afe commit 6ad75a1
Show file tree
Hide file tree
Showing 17 changed files with 371 additions and 144 deletions.
28 changes: 23 additions & 5 deletions base/reflectx/pointers.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,25 @@ func NonPointerType(typ reflect.Type) reflect.Type {
}

// NonPointerValue returns a non-pointer version of the given value.
// If it encounters a nil pointer, it returns the nil pointer instead
// of an invalid value.
func NonPointerValue(v reflect.Value) reflect.Value {
for v.Kind() == reflect.Pointer {
v = v.Elem()
new := v.Elem()
if !new.IsValid() {
return v
}
v = new
}
return v
}

// PointerValue returns a pointer to the given value if it is not already
// a pointer.
func PointerValue(v reflect.Value) reflect.Value {
if !v.IsValid() {
return v
}
if v.Kind() == reflect.Pointer {
return v
}
Expand All @@ -44,6 +53,9 @@ func PointerValue(v reflect.Value) reflect.Value {
// OnePointerValue returns a value that is exactly one pointer away
// from a non-pointer value.
func OnePointerValue(v reflect.Value) reflect.Value {
if !v.IsValid() {
return v
}
if v.Kind() != reflect.Pointer {
if v.CanAddr() {
return v.Addr()
Expand All @@ -61,22 +73,28 @@ func OnePointerValue(v reflect.Value) reflect.Value {
}

// Underlying returns the actual underlying version of the given value,
// going through any pointers and interfaces.
// going through any pointers and interfaces. If it encounters a nil
// pointer or interface, it returns the nil pointer or interface instead of
// an invalid value.
func Underlying(v reflect.Value) reflect.Value {
if !v.IsValid() {
return v
}
for v.Type().Kind() == reflect.Interface || v.Type().Kind() == reflect.Pointer {
v = v.Elem()
if !v.IsValid() {
new := v.Elem()
if !new.IsValid() {
return v
}
v = new
}
return v
}

// UnderlyingPointer returns a pointer to the actual underlying version of the
// given value, going through any pointers and interfaces.
// given value, going through any pointers and interfaces. It is equivalent to
// [OnePointerValue] of [Underlying], so if it encounters a nil pointer or
// interface, it stops at the nil pointer or interface instead of returning
// an invalid value.
func UnderlyingPointer(v reflect.Value) reflect.Value {
if !v.IsValid() {
return v
Expand Down
216 changes: 216 additions & 0 deletions base/reflectx/pointers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,224 @@ import (
"reflect"
"testing"
"unsafe"

"github.com/stretchr/testify/assert"
)

type myInterface interface {
myMethod()
}

func TestNonPointerType(t *testing.T) {
assert.Equal(t, reflect.TypeFor[int](), NonPointerType(reflect.TypeFor[int]()))
assert.Equal(t, reflect.TypeFor[int](), NonPointerType(reflect.TypeFor[*int]()))
assert.Equal(t, reflect.TypeFor[int](), NonPointerType(reflect.TypeFor[**int]()))
assert.Equal(t, reflect.TypeFor[int](), NonPointerType(reflect.TypeFor[***int]()))

assert.Equal(t, reflect.TypeFor[any](), NonPointerType(reflect.TypeFor[any]()))
assert.Equal(t, reflect.TypeFor[any](), NonPointerType(reflect.TypeFor[*any]()))
assert.Equal(t, reflect.TypeFor[any](), NonPointerType(reflect.TypeFor[**any]()))
assert.Equal(t, reflect.TypeFor[any](), NonPointerType(reflect.TypeFor[***any]()))

assert.Equal(t, nil, NonPointerType(reflect.TypeOf(nil)))
}

func TestNonPointerValue(t *testing.T) {
v := 1
rv := reflect.ValueOf(v)
assert.True(t, NonPointerValue(reflect.ValueOf(v)).Equal(rv))
assert.True(t, NonPointerValue(reflect.ValueOf(&v)).Equal(rv))

p := &v
assert.True(t, NonPointerValue(reflect.ValueOf(p)).Equal(rv))
assert.True(t, NonPointerValue(reflect.ValueOf(&p)).Equal(rv))

a := any(v)
assert.True(t, NonPointerValue(reflect.ValueOf(a)).Equal(rv))
assert.Equal(t, rv.Type(), NonPointerValue(reflect.ValueOf(a)).Type())
assert.True(t, NonPointerValue(reflect.ValueOf(&a)).Equal(rv))
// NonPointerValue cannot go through *any to get the true type
assert.NotEqual(t, rv.Type(), NonPointerValue(reflect.ValueOf(&a)).Type())

n := (*int)(nil)
rn := reflect.ValueOf(n)
assert.True(t, rn.IsValid())
assert.True(t, NonPointerValue(rn).IsValid())
assert.True(t, NonPointerValue(rn).Equal(rn))

in := myInterface(nil)
rinp := reflect.ValueOf(&in)
assert.True(t, rinp.IsValid())
assert.True(t, NonPointerValue(rinp).IsValid())
assert.True(t, NonPointerValue(rinp).Equal(reflect.ValueOf(in)))

an := any(nil)
ran := reflect.ValueOf(an)
assert.False(t, ran.IsValid())
assert.False(t, NonPointerValue(ran).IsValid())
}

func TestPointerValue(t *testing.T) {
v := 1
rv := reflect.ValueOf(v)
assert.False(t, rv.CanAddr())
assert.False(t, PointerValue(reflect.ValueOf(v)).Equal(rv))
assert.Equal(t, reflect.TypeFor[*int](), PointerValue(reflect.ValueOf(v)).Type())

p := &v
rp := reflect.ValueOf(p)
assert.True(t, PointerValue(rp).Equal(rp))
assert.Equal(t, reflect.TypeFor[*int](), PointerValue(rp).Type())

assert.True(t, rp.Elem().CanAddr())
assert.True(t, PointerValue(rp.Elem()).Equal(rp))
assert.True(t, PointerValue(rp.Elem()).Equal(rp.Elem().Addr()))

pp := &p
rpp := reflect.ValueOf(pp)
assert.True(t, PointerValue(rpp).Equal(rpp))
assert.Equal(t, reflect.TypeFor[**int](), PointerValue(rpp).Type())

n := (*int)(nil)
rn := reflect.ValueOf(n)
assert.True(t, PointerValue(rn).Equal(rn))

an := any(nil)
ran := reflect.ValueOf(an)
assert.False(t, ran.IsValid())
assert.False(t, PointerValue(ran).IsValid())
}

func TestOnePointerValue(t *testing.T) {
v := 1
rv := reflect.ValueOf(v)
assert.False(t, rv.CanAddr())
assert.False(t, OnePointerValue(reflect.ValueOf(v)).Equal(rv))
assert.Equal(t, reflect.TypeFor[*int](), OnePointerValue(reflect.ValueOf(v)).Type())

p := &v
rp := reflect.ValueOf(p)
assert.True(t, OnePointerValue(rp).Equal(rp))
assert.Equal(t, reflect.TypeFor[*int](), OnePointerValue(rp).Type())

assert.True(t, rp.Elem().CanAddr())
assert.True(t, OnePointerValue(rp.Elem()).Equal(rp))
assert.True(t, OnePointerValue(rp.Elem()).Equal(rp.Elem().Addr()))

pp := &p
rpp := reflect.ValueOf(pp)
assert.False(t, OnePointerValue(rpp).Equal(rpp))
assert.True(t, OnePointerValue(rpp).Equal(rp))
assert.Equal(t, reflect.TypeFor[*int](), OnePointerValue(rpp).Type())

n := (*int)(nil)
rn := reflect.ValueOf(n)
assert.True(t, rn.IsValid())
assert.True(t, OnePointerValue(rn).IsValid())
assert.True(t, OnePointerValue(rn).Equal(rn))

an := any(nil)
ran := reflect.ValueOf(an)
assert.False(t, ran.IsValid())
assert.False(t, OnePointerValue(ran).IsValid())
}

func TestUnderlying(t *testing.T) {
v := 1
rv := reflect.ValueOf(v)
assert.True(t, Underlying(reflect.ValueOf(v)).Equal(rv))
assert.True(t, Underlying(reflect.ValueOf(&v)).Equal(rv))

p := &v
assert.True(t, Underlying(reflect.ValueOf(p)).Equal(rv))
assert.True(t, Underlying(reflect.ValueOf(&p)).Equal(rv))

a := any(v)
assert.True(t, Underlying(reflect.ValueOf(a)).Equal(rv))
assert.Equal(t, rv.Type(), Underlying(reflect.ValueOf(a)).Type())
assert.True(t, Underlying(reflect.ValueOf(&a)).Equal(rv))
assert.Equal(t, rv.Type(), Underlying(reflect.ValueOf(&a)).Type())

n := (*int)(nil)
rn := reflect.ValueOf(n)
assert.True(t, rn.IsValid())
assert.True(t, Underlying(rn).IsValid())
assert.True(t, Underlying(rn).Equal(rn))

in := myInterface(nil)
rinp := reflect.ValueOf(&in)
assert.True(t, rinp.IsValid())
assert.True(t, Underlying(rinp).IsValid())
assert.True(t, Underlying(rinp).Equal(rinp.Elem()))

an := any(nil)
ran := reflect.ValueOf(an)
assert.False(t, ran.IsValid())
assert.False(t, Underlying(ran).IsValid())
}

func TestUnderlyingPointer(t *testing.T) {
v := 1
rv := reflect.ValueOf(v)
assert.False(t, rv.CanAddr())
assert.False(t, UnderlyingPointer(reflect.ValueOf(v)).Equal(rv))
assert.Equal(t, reflect.TypeFor[*int](), UnderlyingPointer(reflect.ValueOf(v)).Type())

p := &v
rp := reflect.ValueOf(p)
assert.True(t, UnderlyingPointer(rp).Equal(rp))
assert.Equal(t, reflect.TypeFor[*int](), UnderlyingPointer(rp).Type())

assert.True(t, rp.Elem().CanAddr())
assert.True(t, UnderlyingPointer(rp.Elem()).Equal(rp))
assert.True(t, UnderlyingPointer(rp.Elem()).Equal(rp.Elem().Addr()))

pp := &p
rpp := reflect.ValueOf(pp)
assert.False(t, UnderlyingPointer(rpp).Equal(rpp))
assert.True(t, UnderlyingPointer(rpp).Equal(rp))
assert.Equal(t, reflect.TypeFor[*int](), UnderlyingPointer(rpp).Type())

a := any(v)
ap := &a
rap := reflect.ValueOf(ap)
// Different pointer, same type
assert.False(t, UnderlyingPointer(rap).Equal(rp))
assert.Equal(t, rp.Type(), UnderlyingPointer(rap).Type())
assert.Equal(t, reflect.TypeFor[*int](), UnderlyingPointer(rap).Type())

n := (*int)(nil)
rn := reflect.ValueOf(n)
assert.True(t, rn.IsValid())
assert.True(t, UnderlyingPointer(rn).IsValid())
assert.True(t, UnderlyingPointer(rn).Equal(rn))

an := any(nil)
ran := reflect.ValueOf(an)
assert.False(t, ran.IsValid())
assert.False(t, UnderlyingPointer(ran).IsValid())
}

func TestNonNilNew(t *testing.T) {
n0 := NonNilNew(reflect.TypeFor[int]())
assert.Equal(t, reflect.TypeFor[*int](), n0.Type())
assert.False(t, n0.IsNil())
assert.Equal(t, 0, n0.Elem().Interface())

n1 := NonNilNew(reflect.TypeFor[*int]())
assert.Equal(t, reflect.TypeFor[**int](), n1.Type())
assert.False(t, n1.IsNil())
assert.False(t, n1.Elem().IsNil())
assert.Equal(t, 0, n1.Elem().Elem().Interface())

n2 := NonNilNew(reflect.TypeFor[**int]())
assert.Equal(t, reflect.TypeFor[***int](), n2.Type())
assert.False(t, n2.IsNil())
assert.False(t, n2.Elem().IsNil())
assert.False(t, n2.Elem().Elem().IsNil())
assert.Equal(t, 0, n2.Elem().Elem().Elem().Interface())
}

type PointerTestSub struct {
Mbr1 string
Mbr2 int
Expand Down
9 changes: 3 additions & 6 deletions base/reflectx/structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,8 @@ func ValueIsDefault(fv reflect.Value, def string) bool {
// SetFromDefaultTags sets the values of fields in the given struct based on
// `default:` default value struct field tags.
func SetFromDefaultTags(v any) error {
if AnyIsNil(v) {
return nil
}
ov := reflect.ValueOf(v)
if ov.Kind() == reflect.Pointer && ov.IsNil() {
if IsNil(ov) {
return nil
}
val := NonPointerValue(ov)
Expand Down Expand Up @@ -161,8 +158,8 @@ type ShouldSaver interface {
func NonDefaultFields(v any) map[string]any {
res := map[string]any{}

rv := NonPointerValue(reflect.ValueOf(v))
if !rv.IsValid() {
rv := Underlying(reflect.ValueOf(v))
if IsNil(rv) {
return nil
}
rt := rv.Type()
Expand Down
Loading

0 comments on commit 6ad75a1

Please sign in to comment.