Skip to content

Commit

Permalink
fix decoding embedded structs
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsannm committed Jan 3, 2024
1 parent d0e5ffb commit a5c60ee
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 57 deletions.
6 changes: 3 additions & 3 deletions kit/desc/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ func (c *Contract) Selector(s ...kit.RouteSelector) *Contract {
return c.AddSelector(s...)
}

// AddNamedSelector adds a kit.RouteSelector for this contract, and assign it a unique name.
// In case of you need to use auto-generated stub.Stub for your service/contract this name will
// AddNamedSelector adds a kit.RouteSelector for this contract, and assigns it a unique name.
// In case you need to use auto-generated stub.Stub for your service/contract this name will
// be used in the generated code.
func (c *Contract) AddNamedSelector(name string, s kit.RouteSelector) *Contract {
c.RouteSelectors = append(c.RouteSelectors, RouteSelector{
Expand Down Expand Up @@ -185,7 +185,7 @@ func (c *Contract) SetHandler(h ...kit.HandlerFunc) *Contract {
return c
}

// contractImpl is simple implementation of kit.Contract interface.
// contractImpl is a simple implementation of kit.Contract interface.
type contractImpl struct {
id string
routeSelector kit.RouteSelector
Expand Down
69 changes: 46 additions & 23 deletions std/gateways/fasthttp/decoder.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package fasthttp

import (
"fmt"
"strings"
"unsafe"

Expand Down Expand Up @@ -120,34 +121,14 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun

rVal := reflect.ValueOf(factory())
if rVal.Kind() != reflect.Ptr {
panic("x must be a pointer to struct")
panic(fmt.Sprintf("%s must be a pointer to struct", rVal.String()))
}
rVal = rVal.Elem()
if rVal.Kind() != reflect.Struct {
panic("x must be a pointer to struct")
panic(fmt.Sprintf("%s must be a pointer to struct", rVal.String()))
}

var pcs []paramCaster

for i := 0; i < rVal.NumField(); i++ {
f := rVal.Type().Field(i)
if tagValue := f.Tag.Get(tagKey); tagValue != "" {
valueParts := strings.Split(tagValue, ",")
if len(valueParts) == 1 {
valueParts = append(valueParts, "")
}

pcs = append(
pcs,
paramCaster{
offset: f.Offset,
name: valueParts[0],
opt: valueParts[1],
typ: f.Type,
},
)
}
}
pcs := extractFields(rVal, tagKey)

return func(bag Params, data []byte) (kit.Message, error) {
var (
Expand All @@ -171,6 +152,8 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun
ptr := unsafe.Add((*emptyInterface)(unsafe.Pointer(&v)).word, pcs[idx].offset)

switch pcs[idx].typ.Kind() {
default:
// simply ignore
case reflect.Int64:
*(*int64)(ptr) = utils.StrToInt64(x)
case reflect.Int32:
Expand All @@ -179,10 +162,21 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun
*(*uint64)(ptr) = utils.StrToUInt64(x)
case reflect.Uint32:
*(*uint32)(ptr) = utils.StrToUInt32(x)
case reflect.Float64:
*(*float64)(ptr) = utils.StrToFloat64(x)
case reflect.Float32:
*(*float32)(ptr) = utils.StrToFloat32(x)
case reflect.Int:
*(*int)(ptr) = utils.StrToInt(x)
case reflect.Uint:
*(*uint)(ptr) = utils.StrToUInt(x)
case reflect.Slice:
switch pcs[idx].typ.Elem().Kind() {
default:
// simply ignore
case reflect.Uint8:
*(*[]byte)(ptr) = utils.S2B(x)
}
case reflect.String:
*(*string)(ptr) = string(utils.S2B(x))
case reflect.Bool:
Expand All @@ -195,3 +189,32 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun
return v.(kit.Message), nil //nolint:forcetypeassert
}
}

func extractFields(rVal reflect.Value, tagKey string) []paramCaster {
var pcs []paramCaster
for i := 0; i < rVal.NumField(); i++ {
f := rVal.Type().Field(i)
if f.Type.Kind() == reflect.Struct && f.Anonymous {
pcs = append(pcs, extractFields(rVal.Field(i), tagKey)...)
} else {
if tagValue := f.Tag.Get(tagKey); tagValue != "" {
valueParts := strings.Split(tagValue, ",")
if len(valueParts) == 1 {
valueParts = append(valueParts, "")
}

pcs = append(
pcs,
paramCaster{
offset: f.Offset,
name: valueParts[0],
opt: valueParts[1],
typ: f.Type,
},
)
}
}
}

return pcs
}
57 changes: 57 additions & 0 deletions std/gateways/fasthttp/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,20 @@ import (
)

type message struct {
embeddedMessage

A string `json:"a"`
B int `json:"b"`
C []byte `json:"c"`
D []string `json:"d"`
Sub subMessage `json:"sub"`
}

type embeddedMessage struct {
X string `json:"x"`
Y int `json:"y"`
}

type subMessage struct {
A string `json:"a"`
B int `json:"b"`
Expand All @@ -30,6 +37,10 @@ func BenchmarkDecoder(b *testing.B) {
A: "a",
B: 1,
},
embeddedMessage: embeddedMessage{
X: "x",
Y: 10,
},
})
p := Params{}
d := reflectDecoder(kit.JSON, kit.CreateMessageFactory(&message{}))
Expand All @@ -44,5 +55,51 @@ func BenchmarkDecoder(b *testing.B) {
if msg.(*message).Sub.B != 1 { //nolint:forcetypeassert
b.Fatal("invalid value")
}
if msg.(*message).X != "x" { //nolint:forcetypeassert
b.Fatal("invalid value")
}
if msg.(*message).Y != 10 { //nolint:forcetypeassert
b.Fatal("invalid value")
}
}
}

func TestDecoder(t *testing.T) {
dec := reflectDecoder(kit.JSON, kit.CreateMessageFactory(&message{}))

params := Params{
{Key: "a", Value: "valueA"},
{Key: "b", Value: "1"},
{Key: "c", Value: "valueC"},
{Key: "d", Value: "valueD"},
{Key: "x", Value: "valueX"},
{Key: "y", Value: "2"},
}

m, err := dec(params, nil)
if err != nil {
t.Fatal(err)
}
mm, ok := m.(*message)
if !ok {
t.Fatal("invalid type")
}
if mm.A != "valueA" {
t.Fatal("invalid value for A")
}
if mm.B != 1 {
t.Fatal("invalid value for B")
}
if string(mm.C) != "valueC" {
t.Fatal("invalid value for C")
}
//if len(mm.D) != 1 || mm.D[0] != "valueD" {
// t.Fatal("invalid value for D")
//}
if mm.X != "valueX" {
t.Fatal("invalid value for X - ", mm.X)
}
if mm.Y != 2 {
t.Fatal("invalid value for Y - ", mm.Y)
}
}
10 changes: 5 additions & 5 deletions std/gateways/fasthttp/proxy/channelpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,17 @@ var (
errInvalidCapacitySetting = errors.New("invalid capacity settings")
)

// Pool interface impelement based on channel
// Pool interface implement based on a channel
// there is a channel to contain ReverseProxy object,
// and provide Get and Put method to handle with RevsereProxy
// and provide Get and Put method to handle with ReverseProxy
type chanPool struct {
// mutex makes the chanPool woking with goroutine safely
// mutex makes the chanPool working with goroutine safely
mutex sync.RWMutex

// reverseProxyChan chan of getting the *ReverseProxy and putting it back
reverseProxyChan chan *ReverseProxy

// factory is factory method to generate ReverseProxy
// factory is a factory method to generate ReverseProxy
// this can be customized
factory Factory
}
Expand Down Expand Up @@ -69,7 +69,7 @@ func (p *chanPool) getConnsAndFactory() (chan *ReverseProxy, Factory) {
return reverseProxyChan, factory
}

// Close close the pool
// Close the pool
func (p *chanPool) Close() {
p.mutex.Lock()
reverseProxyChan := p.reverseProxyChan
Expand Down
3 changes: 2 additions & 1 deletion std/gateways/silverhttp/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ func (b *bundle) registerREST(

var methods []string
if method := restSelector.GetMethod(); method == MethodWildcard {
methods = append(methods,
methods = append(
methods,
MethodGet, MethodPost, MethodPut, MethodPatch, MethodDelete, MethodOptions,
MethodConnect, MethodTrace, MethodHead,
)
Expand Down
76 changes: 51 additions & 25 deletions std/gateways/silverhttp/decoder.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package silverhttp

import (
"fmt"
"strings"
"unsafe"

Expand Down Expand Up @@ -47,38 +48,21 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun

rVal := reflect.ValueOf(factory())
if rVal.Kind() != reflect.Ptr {
panic("x must be a pointer to struct")
panic(fmt.Sprintf("%s must be a pointer to struct", rVal.String()))
}
rVal = rVal.Elem()
if rVal.Kind() != reflect.Struct {
panic("x must be a pointer to struct")
panic(fmt.Sprintf("%s must be a pointer to struct", rVal.String()))
}

var pcs []paramCaster

for i := 0; i < rVal.NumField(); i++ {
f := rVal.Type().Field(i)
if tagValue := f.Tag.Get(tagKey); tagValue != "" {
valueParts := strings.Split(tagValue, ",")
if len(valueParts) == 1 {
valueParts = append(valueParts, "")
}

pcs = append(
pcs,
paramCaster{
offset: f.Offset,
name: valueParts[0],
opt: valueParts[1],
typ: f.Type,
},
)
}
}
pcs := extractFields(rVal, tagKey)

return func(bag Params, data []byte) (kit.Message, error) {
v := factory()
var err error
var (
v = factory()
err error
)

if len(data) > 0 {
err = kit.UnmarshalMessage(data, v)
if err != nil {
Expand All @@ -95,6 +79,8 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun
ptr := unsafe.Add((*emptyInterface)(unsafe.Pointer(&v)).word, pcs[idx].offset)

switch pcs[idx].typ.Kind() {
default:
// simply ignore
case reflect.Int64:
*(*int64)(ptr) = utils.StrToInt64(x)
case reflect.Int32:
Expand All @@ -103,10 +89,21 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun
*(*uint64)(ptr) = utils.StrToUInt64(x)
case reflect.Uint32:
*(*uint32)(ptr) = utils.StrToUInt32(x)
case reflect.Float64:
*(*float64)(ptr) = utils.StrToFloat64(x)
case reflect.Float32:
*(*float32)(ptr) = utils.StrToFloat32(x)
case reflect.Int:
*(*int)(ptr) = utils.StrToInt(x)
case reflect.Uint:
*(*uint)(ptr) = utils.StrToUInt(x)
case reflect.Slice:
switch pcs[idx].typ.Elem().Kind() {
default:
// simply ignore
case reflect.Uint8:
*(*[]byte)(ptr) = utils.S2B(x)
}
case reflect.String:
*(*string)(ptr) = string(utils.S2B(x))
case reflect.Bool:
Expand All @@ -119,3 +116,32 @@ func reflectDecoder(enc kit.Encoding, factory kit.MessageFactoryFunc) DecoderFun
return v.(kit.Message), nil //nolint:forcetypeassert
}
}

func extractFields(rVal reflect.Value, tagKey string) []paramCaster {
var pcs []paramCaster
for i := 0; i < rVal.NumField(); i++ {
f := rVal.Type().Field(i)
if f.Type.Kind() == reflect.Struct && f.Anonymous {
pcs = append(pcs, extractFields(rVal.Field(i), tagKey)...)
} else {
if tagValue := f.Tag.Get(tagKey); tagValue != "" {
valueParts := strings.Split(tagValue, ",")
if len(valueParts) == 1 {
valueParts = append(valueParts, "")
}

pcs = append(
pcs,
paramCaster{
offset: f.Offset,
name: valueParts[0],
opt: valueParts[1],
typ: f.Type,
},
)
}
}
}

return pcs
}

0 comments on commit a5c60ee

Please sign in to comment.