diff --git a/go/Cargo.toml b/go/Cargo.toml index 62872578da..6d6c4ecb15 100644 --- a/go/Cargo.toml +++ b/go/Cargo.toml @@ -13,7 +13,6 @@ redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp" glide-core = { path = "../glide-core", features = ["socket-layer"] } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } protobuf = { version = "3.3.0", features = [] } -derivative = "2.2.0" [profile.release] lto = true diff --git a/go/api/glide_client.go b/go/api/glide_client.go index 27f3dbac75..2f8e8e5235 100644 --- a/go/api/glide_client.go +++ b/go/api/glide_client.go @@ -33,18 +33,12 @@ func NewGlideClient(config *GlideClientConfiguration) (*GlideClient, error) { // For example, to return a list of all pub/sub clients: // // client.CustomCommand([]string{"CLIENT", "LIST","TYPE", "PUBSUB"}) -// -// TODO: Add support for complex return types. func (client *GlideClient) CustomCommand(args []string) (interface{}, error) { res, err := client.executeCommand(C.CustomCommand, args) if err != nil { return nil, err } - resString, err := handleStringOrNullResponse(res) - if err != nil { - return nil, err - } - return resString.Value(), err + return handleInterfaceResponse(res) } // Sets configuration parameters to the specified values. diff --git a/go/api/response_handlers.go b/go/api/response_handlers.go index 2cd739aefd..63f0ac0007 100644 --- a/go/api/response_handlers.go +++ b/go/api/response_handlers.go @@ -59,6 +59,102 @@ func convertCharArrayToString(response *C.struct_CommandResponse, isNilable bool return CreateStringResult(string(byteSlice)), nil } +func handleInterfaceResponse(response *C.struct_CommandResponse) (interface{}, error) { + defer C.free_command_response(response) + + return parseInterface(response) +} + +func parseInterface(response *C.struct_CommandResponse) (interface{}, error) { + if response == nil { + return nil, nil + } + + switch response.response_type { + case C.Null: + return nil, nil + case C.String: + return parseString(response) + case C.Int: + return int64(response.int_value), nil + case C.Float: + return float64(response.float_value), nil + case C.Bool: + return bool(response.bool_value), nil + case C.Array: + return parseArray(response) + case C.Map: + return parseMap(response) + case C.Sets: + return parseSet(response) + } + + return nil, &RequestError{"Unexpected return type from Valkey"} +} + +func parseString(response *C.struct_CommandResponse) (interface{}, error) { + if response.string_value == nil { + return nil, nil + } + byteSlice := C.GoBytes(unsafe.Pointer(response.string_value), C.int(int64(response.string_value_len))) + + // Create Go string from byte slice (preserving null characters) + return string(byteSlice), nil +} + +func parseArray(response *C.struct_CommandResponse) (interface{}, error) { + if response.array_value == nil { + return nil, nil + } + + var slice []interface{} + for _, v := range unsafe.Slice(response.array_value, response.array_value_len) { + res, err := parseInterface(&v) + if err != nil { + return nil, err + } + slice = append(slice, res) + } + return slice, nil +} + +func parseMap(response *C.struct_CommandResponse) (interface{}, error) { + if response.array_value == nil { + return nil, nil + } + + value_map := make(map[interface{}]interface{}, response.array_value_len) + for _, v := range unsafe.Slice(response.array_value, response.array_value_len) { + res_key, err := parseInterface(v.map_key) + if err != nil { + return nil, err + } + res_val, err := parseInterface(v.map_value) + if err != nil { + return nil, err + } + value_map[res_key] = res_val + } + return value_map, nil +} + +func parseSet(response *C.struct_CommandResponse) (interface{}, error) { + if response.sets_value == nil { + return nil, nil + } + + slice := make(map[interface{}]struct{}, response.sets_value_len) + for _, v := range unsafe.Slice(response.sets_value, response.sets_value_len) { + res, err := parseInterface(&v) + if err != nil { + return nil, err + } + slice[res] = struct{}{} + } + + return slice, nil +} + func handleStringResponse(response *C.struct_CommandResponse) (Result[string], error) { defer C.free_command_response(response) diff --git a/go/integTest/standalone_commands_test.go b/go/integTest/standalone_commands_test.go index 4186c6b184..80cbb63581 100644 --- a/go/integTest/standalone_commands_test.go +++ b/go/integTest/standalone_commands_test.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "github.com/google/uuid" "github.com/valkey-io/valkey-glide/go/glide/api" "github.com/stretchr/testify/assert" @@ -21,12 +22,12 @@ func (suite *GlideTestSuite) TestCustomCommandInfo() { assert.True(suite.T(), strings.Contains(strResult, "# Stats")) } -func (suite *GlideTestSuite) TestCustomCommandPing() { +func (suite *GlideTestSuite) TestCustomCommandPing_StringResponse() { client := suite.defaultClient() result, err := client.CustomCommand([]string{"PING"}) assert.Nil(suite.T(), err) - assert.Equal(suite.T(), "PONG", result) + assert.Equal(suite.T(), "PONG", result.(string)) } func (suite *GlideTestSuite) TestCustomCommandClientInfo() { @@ -44,6 +45,106 @@ func (suite *GlideTestSuite) TestCustomCommandClientInfo() { assert.True(suite.T(), strings.Contains(strResult, fmt.Sprintf("name=%s", clientName))) } +func (suite *GlideTestSuite) TestCustomCommandGet_NullResponse() { + client := suite.defaultClient() + key := uuid.New().String() + result, err := client.CustomCommand([]string{"GET", key}) + + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), nil, result) +} + +func (suite *GlideTestSuite) TestCustomCommandDel_LongResponse() { + client := suite.defaultClient() + key := uuid.New().String() + suite.verifyOK(client.Set(key, "value")) + result, err := client.CustomCommand([]string{"DEL", key}) + + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), int64(1), result.(int64)) +} + +func (suite *GlideTestSuite) TestCustomCommandHExists_BoolResponse() { + client := suite.defaultClient() + fields := map[string]string{"field1": "value1"} + key := uuid.New().String() + + res1, err := client.HSet(key, fields) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), int64(1), res1.Value()) + + result, err := client.CustomCommand([]string{"HEXISTS", key, "field1"}) + + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), true, result.(bool)) +} + +func (suite *GlideTestSuite) TestCustomCommandIncrByFloat_FloatResponse() { + client := suite.defaultClient() + key := uuid.New().String() + + result, err := client.CustomCommand([]string{"INCRBYFLOAT", key, fmt.Sprintf("%f", 0.1)}) + + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), float64(0.1), result.(float64)) +} + +func (suite *GlideTestSuite) TestCustomCommandMGet_ArrayResponse() { + clientName := "TEST_CLIENT_NAME" + config := api.NewGlideClientConfiguration(). + WithAddress(&api.NodeAddress{Port: suite.standalonePorts[0]}). + WithClientName(clientName) + client := suite.client(config) + + key1 := uuid.New().String() + key2 := uuid.New().String() + key3 := uuid.New().String() + oldValue := uuid.New().String() + value := uuid.New().String() + suite.verifyOK(client.Set(key1, oldValue)) + keyValueMap := map[string]string{ + key1: value, + key2: value, + } + suite.verifyOK(client.MSet(keyValueMap)) + keys := []string{key1, key2, key3} + values := []interface{}{value, value, nil} + result, err := client.CustomCommand(append([]string{"MGET"}, keys...)) + + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), values, result.([]interface{})) +} + +func (suite *GlideTestSuite) TestCustomCommandConfigGet_MapResponse() { + client := suite.defaultClient() + + if suite.serverVersion < "7.0.0" { + suite.T().Skip("This feature is added in version 7") + } + configMap := map[string]string{"timeout": "1000", "maxmemory": "1GB"} + result, err := client.ConfigSet(configMap) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), "OK", result.Value()) + + result2, err := client.CustomCommand([]string{"CONFIG", "GET", "timeout", "maxmemory"}) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), map[interface{}]interface{}{"timeout": "1000", "maxmemory": "1073741824"}, result2) +} + +func (suite *GlideTestSuite) TestCustomCommandConfigSMembers_SetResponse() { + client := suite.defaultClient() + key := uuid.NewString() + members := []string{"member1", "member2", "member3"} + + res1, err := client.SAdd(key, members) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), int64(3), res1.Value()) + + result2, err := client.CustomCommand([]string{"SMEMBERS", key}) + assert.Nil(suite.T(), err) + assert.Equal(suite.T(), map[interface{}]struct{}{"member1": {}, "member2": {}, "member3": {}}, result2) +} + func (suite *GlideTestSuite) TestCustomCommand_invalidCommand() { client := suite.defaultClient() result, err := client.CustomCommand([]string{"pewpew"}) diff --git a/go/src/lib.rs b/go/src/lib.rs index d9bdf14e68..55b3c9515c 100644 --- a/go/src/lib.rs +++ b/go/src/lib.rs @@ -3,7 +3,6 @@ */ #![deny(unsafe_op_in_unsafe_fn)] -use derivative::Derivative; use glide_core::client::Client as GlideClient; use glide_core::connection_request; use glide_core::errors; @@ -28,8 +27,7 @@ use tokio::runtime::Runtime; /// The struct is freed by the external caller by using `free_command_response` to avoid memory leaks. /// TODO: Add a type enum to validate what type of response is being sent in the CommandResponse. #[repr(C)] -#[derive(Derivative)] -#[derivative(Debug, Default)] +#[derive(Debug)] pub struct CommandResponse { response_type: ResponseType, int_value: c_long, @@ -39,33 +37,47 @@ pub struct CommandResponse { /// Below two values are related to each other. /// `string_value` represents the string. /// `string_value_len` represents the length of the string. - #[derivative(Default(value = "std::ptr::null_mut()"))] string_value: *mut c_char, string_value_len: c_long, /// Below two values are related to each other. /// `array_value` represents the array of CommandResponse. /// `array_value_len` represents the length of the array. - #[derivative(Default(value = "std::ptr::null_mut()"))] array_value: *mut CommandResponse, array_value_len: c_long, /// Below two values represent the Map structure inside CommandResponse. /// The map is transformed into an array of (map_key: CommandResponse, map_value: CommandResponse) and passed to Go. /// These are represented as pointers as the map can be null (optionally present). - #[derivative(Default(value = "std::ptr::null_mut()"))] map_key: *mut CommandResponse, - #[derivative(Default(value = "std::ptr::null_mut()"))] map_value: *mut CommandResponse, /// Below two values are related to each other. /// `sets_value` represents the set of CommandResponse. /// `sets_value_len` represents the length of the set. - #[derivative(Default(value = "std::ptr::null_mut()"))] sets_value: *mut CommandResponse, sets_value_len: c_long, } +impl Default for CommandResponse { + fn default() -> Self { + CommandResponse { + response_type: ResponseType::default(), + int_value: 0, + float_value: 0.0, + bool_value: false, + string_value: std::ptr::null_mut(), + string_value_len: 0, + array_value: std::ptr::null_mut(), + array_value_len: 0, + map_key: std::ptr::null_mut(), + map_value: std::ptr::null_mut(), + sets_value: std::ptr::null_mut(), + sets_value_len: 0, + } + } +} + #[repr(C)] #[derive(Debug, Default)] pub enum ResponseType {