diff --git a/internal/proxy/privilege_interceptor.go b/internal/proxy/privilege_interceptor.go index 9eb3e8d77f4f9..10bec733c5ae7 100644 --- a/internal/proxy/privilege_interceptor.go +++ b/internal/proxy/privilege_interceptor.go @@ -15,6 +15,7 @@ import ( "google.golang.org/grpc/status" "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" "github.com/milvus-io/milvus/pkg/log" "github.com/milvus-io/milvus/pkg/util" "github.com/milvus-io/milvus/pkg/util/funcutil" @@ -113,6 +114,11 @@ func PrivilegeInterceptor(ctx context.Context, req interface{}) (context.Context if isCurUserObject(objectType, username, objectName) { return ctx, nil } + + if isSelectMyRoleGrants(req, roleNames) { + return ctx, nil + } + objectNameIndexs := privilegeExt.ObjectNameIndexs objectNames := funcutil.GetObjectNames(req, objectNameIndexs) objectPrivilege := privilegeExt.ObjectPrivilege.String() @@ -181,6 +187,16 @@ func isCurUserObject(objectType string, curUser string, object string) bool { return curUser == object } +func isSelectMyRoleGrants(req interface{}, roleNames []string) bool { + selectGrantReq, ok := req.(*milvuspb.SelectGrantRequest) + if !ok { + return false + } + filterGrantEntity := selectGrantReq.GetEntity() + roleName := filterGrantEntity.GetRole().GetName() + return funcutil.SliceContain(roleNames, roleName) +} + func DBMatchFunc(args ...interface{}) (interface{}, error) { name1 := args[0].(string) name2 := args[1].(string) diff --git a/tests/python_client/testcases/test_utility.py b/tests/python_client/testcases/test_utility.py index b553d50d48f2d..094c3900fd5f5 100644 --- a/tests/python_client/testcases/test_utility.py +++ b/tests/python_client/testcases/test_utility.py @@ -2956,6 +2956,9 @@ def test_role_list_grants(self, host, port, with_db): r_name = cf.gen_unique_str(prefix) c_name = cf.gen_unique_str(prefix) u, _ = self.utility_wrap.create_user(user=user, password=password) + user2 = cf.gen_unique_str(prefix) + u2, _ = self.utility_wrap.create_user(user=user2, password=password) + self.utility_wrap.init_role(r_name) self.utility_wrap.create_role() @@ -2971,10 +2974,27 @@ def test_role_list_grants(self, host, port, with_db): self.utility_wrap.role_grant(grant_item["object"], grant_item["object_name"], grant_item["privilege"], **db_kwargs) - # list grants + # list grants with default user + g_list, _ = self.utility_wrap.role_list_grants(**db_kwargs) + assert len(g_list.groups) == len(grant_list) + + self.connection_wrap.disconnect(alias=DefaultConfig.DEFAULT_USING) + self.connection_wrap.connect(host=host, port=port, user=user, + password=password, check_task=ct.CheckTasks.ccr, **db_kwargs) + + # list grants with user g_list, _ = self.utility_wrap.role_list_grants(**db_kwargs) assert len(g_list.groups) == len(grant_list) + + self.connection_wrap.disconnect(alias=DefaultConfig.DEFAULT_USING) + self.connection_wrap.connect(host=host, port=port, user=user2, + password=password, check_task=ct.CheckTasks.ccr, **db_kwargs) + + # user2 can not list grants of role + self.utility_wrap.role_list_grants(**db_kwargs, + check_task=CheckTasks.check_permission_deny) + @pytest.mark.tags(CaseLabel.RBAC) def test_drop_role_which_bind_user(self, host, port): """