Skip to content

Commit

Permalink
feat(ibc-hooks): improved UnmarshalJSON validation & error handling (#…
Browse files Browse the repository at this point in the history
…324)

* feat(ibc-hooks): improved UnmarshalJSON validation & error handling

Descritpion
-----------
Improves AsyncCallback JSON unmarshaling with:
- Better validation for ModuleAddress format and required fields
- More specific error messages
- Memory optimization using single intermediate struct
- Comprehensive test coverage for error cases

Tetsing the introduced `feat`
-----------------------------
Fetch this PR branch and from the root directory run:
```
go test ./x/ibc-hooks/move-hooks/message_test.go -v
```

* refact: improved AsyncCallback.UnmarshalJSON validation

Description
-----------
- Added overflow checking for uint64 ID conversion
- Added validation for decimal and negative values
- Added comprehensive test cases for edge cases
  • Loading branch information
0xObsidian authored Dec 24, 2024
1 parent 10dcf14 commit e0e3c76
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 53 deletions.
77 changes: 45 additions & 32 deletions x/ibc-hooks/move-hooks/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package move_hooks

import (
"encoding/json"
"fmt"
"strings"

movetypes "github.com/initia-labs/initia/x/move/types"
)
Expand Down Expand Up @@ -53,44 +55,55 @@ type HookData struct {
AsyncCallback *AsyncCallback `json:"async_callback,omitempty"`
}

// asyncCallback is same as AsyncCallback.
type asyncCallback struct {
// callback id should be issued form the executor contract
Id uint64 `json:"id"`
ModuleAddress string `json:"module_address"`
ModuleName string `json:"module_name"`
}

// asyncCallbackStringID is same as AsyncCallback but
// it has Id as string.
type asyncCallbackStringID struct {
// callback id should be issued form the executor contract
Id uint64 `json:"id,string"`
ModuleAddress string `json:"module_address"`
ModuleName string `json:"module_name"`
// intermediateCallback is used internally for JSON unmarshaling
type intermediateCallback struct {
Id interface{} `json:"id"`
ModuleAddress string `json:"module_address"`
ModuleName string `json:"module_name"`
}

// UnmarshalJSON implements the json unmarshaler interface.
// custom unmarshaler is required because we have to handle
// id as string and uint64.
// It handles both string and numeric id formats and validates the module address.
func (a *AsyncCallback) UnmarshalJSON(bz []byte) error {
var ac asyncCallback
err := json.Unmarshal(bz, &ac)
if err != nil {
var acStr asyncCallbackStringID
err := json.Unmarshal(bz, &acStr)
if err != nil {
return err
}
var ic intermediateCallback
if err := json.Unmarshal(bz, &ic); err != nil {
return fmt.Errorf("failed to unmarshal AsyncCallback: %w", err)
}

a.Id = acStr.Id
a.ModuleAddress = acStr.ModuleAddress
a.ModuleName = acStr.ModuleName
return nil
// Validate required fields
if ic.ModuleAddress == "" {
return fmt.Errorf("module_address cannot be empty")
}
if ic.ModuleName == "" {
return fmt.Errorf("module_name cannot be empty")
}

// Validate module address format
if !strings.HasPrefix(ic.ModuleAddress, "0x") {
return fmt.Errorf("invalid module_address format: must start with '0x'")
}

// Handle ID based on type with overflow checking
switch v := ic.Id.(type) {
case float64:
if v < 0 || v > float64(^uint64(0)) || v != float64(uint64(v)) {
return fmt.Errorf("id value out of range or contains decimals")
}
a.Id = uint64(v)
case string:
var parsed float64
if err := json.Unmarshal([]byte(v), &parsed); err != nil {
return fmt.Errorf("invalid id format: %w", err)
}
if parsed < 0 || parsed > float64(^uint64(0)) || parsed != float64(uint64(parsed)) {
return fmt.Errorf("id value out of range or contains decimals")
}
a.Id = uint64(parsed)
default:
return fmt.Errorf("invalid id type: expected string or number")
}

a.Id = ac.Id
a.ModuleAddress = ac.ModuleAddress
a.ModuleName = ac.ModuleName
a.ModuleAddress = ic.ModuleAddress
a.ModuleName = ic.ModuleName
return nil
}
178 changes: 157 additions & 21 deletions x/ibc-hooks/move-hooks/message_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,25 +9,161 @@ import (
)

func Test_Unmarshal_AsyncCallback(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": 99,
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.NoError(t, err)
require.Equal(t, movehooks.AsyncCallback{
Id: 99,
ModuleAddress: "0x1",
ModuleName: "Counter",
}, callback)

var callbackStringID movehooks.AsyncCallback
err = json.Unmarshal([]byte(`{
"id": "99",
"module_address": "0x1",
"module_name": "Counter"
}`), &callbackStringID)
require.NoError(t, err)
require.Equal(t, callback, callbackStringID)
t.Run("valid numeric id", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": 99,
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.NoError(t, err)
require.Equal(t, movehooks.AsyncCallback{
Id: 99,
ModuleAddress: "0x1",
ModuleName: "Counter",
}, callback)
})

t.Run("valid string id", func(t *testing.T) {
var callbackStringID movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": "99",
"module_address": "0x1",
"module_name": "Counter"
}`), &callbackStringID)
require.NoError(t, err)
require.Equal(t, movehooks.AsyncCallback{
Id: 99,
ModuleAddress: "0x1",
ModuleName: "Counter",
}, callbackStringID)
})

t.Run("empty module address", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": 99,
"module_address": "",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "module_address cannot be empty")
})

t.Run("empty module name", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": 99,
"module_address": "0x1",
"module_name": ""
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "module_name cannot be empty")
})

t.Run("invalid module address format", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": 99,
"module_address": "invalid",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid module_address format")
})

t.Run("invalid id type", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": true,
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid id type")
})

t.Run("invalid id string format", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": "not_a_number",
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid id format")
})

t.Run("malformed json", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{malformed`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "invalid character")
})

t.Run("id with decimal value", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": 99.5,
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "id value out of range or contains decimals")
})

t.Run("id with string decimal value", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": "99.5",
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "id value out of range or contains decimals")
})

t.Run("negative id value", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": -1,
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "id value out of range or contains decimals")
})

t.Run("negative string id value", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": "-1",
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "id value out of range or contains decimals")
})

t.Run("id value exceeding uint64 max", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": 18446744073709551616,
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "id value out of range or contains decimals")
})

t.Run("string id value exceeding uint64 max", func(t *testing.T) {
var callback movehooks.AsyncCallback
err := json.Unmarshal([]byte(`{
"id": "18446744073709551616",
"module_address": "0x1",
"module_name": "Counter"
}`), &callback)
require.Error(t, err)
require.Contains(t, err.Error(), "id value out of range or contains decimals")
})
}

0 comments on commit e0e3c76

Please sign in to comment.