Skip to content

Commit

Permalink
feat: user authz (#3941)
Browse files Browse the repository at this point in the history
* feat: change user table to match mysql

* feat: support user authz

* fix: cean up created users
  • Loading branch information
oh2024 authored Jun 26, 2024
1 parent e3da2a6 commit 2a73952
Show file tree
Hide file tree
Showing 17 changed files with 543 additions and 34 deletions.
4 changes: 4 additions & 0 deletions hybridse/include/node/node_enum.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ enum SqlNodeType {
kColumnSchema,
kCreateUserStmt,
kAlterUserStmt,
kGrantStmt,
kRevokeStmt,
kCallStmt,
kSqlNodeTypeLast, // debug type
kVariadicUdfDef,
Expand Down Expand Up @@ -347,6 +349,8 @@ enum PlanType {
kPlanTypeShow,
kPlanTypeCreateUser,
kPlanTypeAlterUser,
kPlanTypeGrant,
kPlanTypeRevoke,
kPlanTypeCallStmt,
kUnknowPlan = -1,
};
Expand Down
61 changes: 61 additions & 0 deletions hybridse/include/node/plan_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,67 @@ class CreateUserPlanNode : public LeafPlanNode {
const std::shared_ptr<OptionsMap> options_;
};

class GrantPlanNode : public LeafPlanNode {
public:
explicit GrantPlanNode(std::optional<std::string> target_type, std::string database, std::string target,
std::vector<std::string> privileges, bool is_all_privileges,
std::vector<std::string> grantees, bool with_grant_option)
: LeafPlanNode(kPlanTypeGrant),
target_type_(target_type),
database_(database),
target_(target),
privileges_(privileges),
is_all_privileges_(is_all_privileges),
grantees_(grantees),
with_grant_option_(with_grant_option) {}
~GrantPlanNode() = default;
const std::vector<std::string> Privileges() const { return privileges_; }
const std::vector<std::string> Grantees() const { return grantees_; }
const std::string Database() const { return database_; }
const std::string Target() const { return target_; }
const std::optional<std::string> TargetType() const { return target_type_; }
const bool IsAllPrivileges() const { return is_all_privileges_; }
const bool WithGrantOption() const { return with_grant_option_; }

private:
std::optional<std::string> target_type_;
std::string database_;
std::string target_;
std::vector<std::string> privileges_;
bool is_all_privileges_;
std::vector<std::string> grantees_;
bool with_grant_option_;
};

class RevokePlanNode : public LeafPlanNode {
public:
explicit RevokePlanNode(std::optional<std::string> target_type, std::string database, std::string target,
std::vector<std::string> privileges, bool is_all_privileges,
std::vector<std::string> grantees)
: LeafPlanNode(kPlanTypeRevoke),
target_type_(target_type),
database_(database),
target_(target),
privileges_(privileges),
is_all_privileges_(is_all_privileges),
grantees_(grantees) {}
~RevokePlanNode() = default;
const std::vector<std::string> Privileges() const { return privileges_; }
const std::vector<std::string> Grantees() const { return grantees_; }
const std::string Database() const { return database_; }
const std::string Target() const { return target_; }
const std::optional<std::string> TargetType() const { return target_type_; }
const bool IsAllPrivileges() const { return is_all_privileges_; }

private:
std::optional<std::string> target_type_;
std::string database_;
std::string target_;
std::vector<std::string> privileges_;
bool is_all_privileges_;
std::vector<std::string> grantees_;
};

class AlterUserPlanNode : public LeafPlanNode {
public:
explicit AlterUserPlanNode(const std::string& name, bool if_exists, std::shared_ptr<OptionsMap> options)
Expand Down
58 changes: 58 additions & 0 deletions hybridse/include/node/sql_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -2421,6 +2421,64 @@ class AlterUserNode : public SqlNode {
const std::shared_ptr<OptionsMap> options_;
};

class GrantNode : public SqlNode {
public:
explicit GrantNode(std::optional<std::string> target_type, std::string database, std::string target,
std::vector<std::string> privileges, bool is_all_privileges, std::vector<std::string> grantees,
bool with_grant_option)
: SqlNode(kGrantStmt, 0, 0),
target_type_(target_type),
database_(database),
target_(target),
privileges_(privileges),
is_all_privileges_(is_all_privileges),
grantees_(grantees),
with_grant_option_(with_grant_option) {}
const std::vector<std::string> Privileges() const { return privileges_; }
const std::vector<std::string> Grantees() const { return grantees_; }
const std::string Database() const { return database_; }
const std::string Target() const { return target_; }
const std::optional<std::string> TargetType() const { return target_type_; }
const bool IsAllPrivileges() const { return is_all_privileges_; }
const bool WithGrantOption() const { return with_grant_option_; }

private:
std::optional<std::string> target_type_;
std::string database_;
std::string target_;
std::vector<std::string> privileges_;
bool is_all_privileges_;
std::vector<std::string> grantees_;
bool with_grant_option_;
};

class RevokeNode : public SqlNode {
public:
explicit RevokeNode(std::optional<std::string> target_type, std::string database, std::string target,
std::vector<std::string> privileges, bool is_all_privileges, std::vector<std::string> grantees)
: SqlNode(kRevokeStmt, 0, 0),
target_type_(target_type),
database_(database),
target_(target),
privileges_(privileges),
is_all_privileges_(is_all_privileges),
grantees_(grantees) {}
const std::vector<std::string> Privileges() const { return privileges_; }
const std::vector<std::string> Grantees() const { return grantees_; }
const std::string Database() const { return database_; }
const std::string Target() const { return target_; }
const std::optional<std::string> TargetType() const { return target_type_; }
const bool IsAllPrivileges() const { return is_all_privileges_; }

private:
std::optional<std::string> target_type_;
std::string database_;
std::string target_;
std::vector<std::string> privileges_;
bool is_all_privileges_;
std::vector<std::string> grantees_;
};

class ExplainNode : public SqlNode {
public:
explicit ExplainNode(const QueryNode *query, node::ExplainType explain_type)
Expand Down
16 changes: 16 additions & 0 deletions hybridse/src/plan/planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,22 @@ base::Status SimplePlanner::CreatePlanTree(const NodePointVector &parser_trees,
plan_trees.push_back(create_user_plan_node);
break;
}
case ::hybridse::node::kGrantStmt: {
auto node = dynamic_cast<node::GrantNode *>(parser_tree);
auto grant_plan_node = node_manager_->MakeNode<node::GrantPlanNode>(
node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(),
node->Grantees(), node->WithGrantOption());
plan_trees.push_back(grant_plan_node);
break;
}
case ::hybridse::node::kRevokeStmt: {
auto node = dynamic_cast<node::RevokeNode *>(parser_tree);
auto revoke_plan_node = node_manager_->MakeNode<node::RevokePlanNode>(
node->TargetType(), node->Database(), node->Target(), node->Privileges(), node->IsAllPrivileges(),
node->Grantees());
plan_trees.push_back(revoke_plan_node);
break;
}
case ::hybridse::node::kAlterUserStmt: {
auto node = dynamic_cast<node::AlterUserNode *>(parser_tree);
auto alter_user_plan_node = node_manager_->MakeNode<node::AlterUserPlanNode>(node->Name(),
Expand Down
90 changes: 90 additions & 0 deletions hybridse/src/planv2/ast_node_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "absl/strings/ascii.h"
#include "absl/strings/match.h"
#include "absl/types/span.h"
#include "ast_node_converter.h"
#include "base/fe_status.h"
#include "node/sql_node.h"
#include "udf/udf.h"
Expand Down Expand Up @@ -725,6 +726,20 @@ base::Status ConvertStatement(const zetasql::ASTStatement* statement, node::Node
*output = create_user_node;
break;
}
case zetasql::AST_GRANT_STATEMENT: {
const zetasql::ASTGrantStatement* grant_stmt = statement->GetAsOrNull<zetasql::ASTGrantStatement>();
node::GrantNode* grant_node = nullptr;
CHECK_STATUS(ConvertGrantStatement(grant_stmt, node_manager, &grant_node))
*output = grant_node;
break;
}
case zetasql::AST_REVOKE_STATEMENT: {
const zetasql::ASTRevokeStatement* revoke_stmt = statement->GetAsOrNull<zetasql::ASTRevokeStatement>();
node::RevokeNode* revoke_node = nullptr;
CHECK_STATUS(ConvertRevokeStatement(revoke_stmt, node_manager, &revoke_node))
*output = revoke_node;
break;
}
case zetasql::AST_ALTER_USER_STATEMENT: {
const zetasql::ASTAlterUserStatement* alter_user_stmt =
statement->GetAsOrNull<zetasql::ASTAlterUserStatement>();
Expand Down Expand Up @@ -2133,6 +2148,81 @@ base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* roo
return base::Status::OK();
}

base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* node_manager,
node::GrantNode** output) {
CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTGrantStatement");
std::vector<std::string> target_path;
CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path));
std::optional<std::string> target_type = std::nullopt;
if (root->target_type() != nullptr) {
target_type = root->target_type()->GetAsString();
}

std::vector<std::string> privileges;
std::vector<std::string> grantees;
for (auto privilege : root->privileges()->privileges()) {
if (privilege == nullptr) {
continue;
}

auto privilege_action = privilege->privilege_action();
if (privilege_action != nullptr) {
privileges.push_back(privilege_action->GetAsString());
}
}

for (auto grantee : root->grantee_list()->grantee_list()) {
if (grantee == nullptr) {
continue;
}

std::string grantee_str;
CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str));
grantees.push_back(grantee_str);
}
*output = node_manager->MakeNode<node::GrantNode>(target_type, target_path.at(0), target_path.at(1), privileges,
root->privileges()->is_all_privileges(), grantees,
root->with_grant_option());
return base::Status::OK();
}

base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager,
node::RevokeNode** output) {
CHECK_TRUE(root != nullptr, common::kSqlAstError, "not an ASTRevokeStatement");
std::vector<std::string> target_path;
CHECK_STATUS(AstPathExpressionToStringList(root->target_path(), target_path));
std::optional<std::string> target_type = std::nullopt;
if (root->target_type() != nullptr) {
target_type = root->target_type()->GetAsString();
}

std::vector<std::string> privileges;
std::vector<std::string> grantees;
for (auto privilege : root->privileges()->privileges()) {
if (privilege == nullptr) {
continue;
}

auto privilege_action = privilege->privilege_action();
if (privilege_action != nullptr) {
privileges.push_back(privilege_action->GetAsString());
}
}

for (auto grantee : root->grantee_list()->grantee_list()) {
if (grantee == nullptr) {
continue;
}

std::string grantee_str;
CHECK_STATUS(AstStringLiteralToString(grantee, &grantee_str));
grantees.push_back(grantee_str);
}
*output = node_manager->MakeNode<node::RevokeNode>(target_type, target_path.at(0), target_path.at(1), privileges,
root->privileges()->is_all_privileges(), grantees);
return base::Status::OK();
}

base::Status ConvertCreateIndexStatement(const zetasql::ASTCreateIndexStatement* root, node::NodeManager* node_manager,
node::CreateIndexNode** output) {
CHECK_TRUE(nullptr != root, common::kSqlAstError, "not an ASTCreateIndexStatement")
Expand Down
6 changes: 6 additions & 0 deletions hybridse/src/planv2/ast_node_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ base::Status ConvertCreateUserStatement(const zetasql::ASTCreateUserStatement* r
base::Status ConvertAlterUserStatement(const zetasql::ASTAlterUserStatement* root, node::NodeManager* node_manager,
node::AlterUserNode** output);

base::Status ConvertGrantStatement(const zetasql::ASTGrantStatement* root, node::NodeManager* node_manager,
node::GrantNode** output);

base::Status ConvertRevokeStatement(const zetasql::ASTRevokeStatement* root, node::NodeManager* node_manager,
node::RevokeNode** output);

base::Status ConvertQueryNode(const zetasql::ASTQuery* root, node::NodeManager* node_manager, node::QueryNode** output);

base::Status ConvertQueryExpr(const zetasql::ASTQueryExpression* query_expr, node::NodeManager* node_manager,
Expand Down
41 changes: 35 additions & 6 deletions src/auth/user_access_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void UserAccessManager::StopSyncTask() {

void UserAccessManager::SyncWithDB() {
if (auto it_pair = user_table_iterator_factory_(::openmldb::nameserver::USER_INFO_NAME); it_pair) {
auto new_user_map = std::make_unique<std::unordered_map<std::string, std::string>>();
auto new_user_map = std::make_unique<std::unordered_map<std::string, UserRecord>>();
auto it = it_pair->first.get();
it->SeekToFirst();
while (it->Valid()) {
Expand All @@ -56,26 +56,55 @@ void UserAccessManager::SyncWithDB() {
auto size = it->GetValue().size();
codec::RowView row_view(*it_pair->second.get(), buf, size);
std::string host, user, password;
std::string privilege_level_str;
row_view.GetStrValue(0, &host);
row_view.GetStrValue(1, &user);
row_view.GetStrValue(2, &password);
row_view.GetStrValue(5, &privilege_level_str);
openmldb::nameserver::PrivilegeLevel privilege_level;
::openmldb::nameserver::PrivilegeLevel_Parse(privilege_level_str, &privilege_level);
UserRecord user_record = {password, privilege_level};
if (host == "%") {
new_user_map->emplace(user, password);
new_user_map->emplace(user, user_record);
} else {
new_user_map->emplace(FormUserHost(user, host), password);
new_user_map->emplace(FormUserHost(user, host), user_record);
}
it->Next();
}
user_map_.Refresh(std::move(new_user_map));
}
}

std::optional<std::string> UserAccessManager::GetUserPassword(const std::string& host, const std::string& user) {
if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) {
return user_record.value().password;
} else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) {
return stored_password.value().password;
} else {
return std::nullopt;
}
}

bool UserAccessManager::IsAuthenticated(const std::string& host, const std::string& user, const std::string& password) {
if (auto stored_password = user_map_.Get(FormUserHost(user, host)); stored_password.has_value()) {
return stored_password.value() == password;
if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) {
return user_record.value().password == password;
} else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) {
return stored_password.value() == password;
return stored_password.value().password == password;
}
return false;
}

::openmldb::nameserver::PrivilegeLevel UserAccessManager::GetPrivilegeLevel(const std::string& user_at_host) {
std::size_t at_pos = user_at_host.find('@');
if (at_pos != std::string::npos) {
std::string user = user_at_host.substr(0, at_pos);
std::string host = user_at_host.substr(at_pos + 1);
if (auto user_record = user_map_.Get(FormUserHost(user, host)); user_record.has_value()) {
return user_record.value().privilege_level;
} else if (auto stored_password = user_map_.Get(user); stored_password.has_value()) {
return stored_password.value().privilege_level;
}
}
return ::openmldb::nameserver::PrivilegeLevel::NO_PRIVILEGE;
}
} // namespace openmldb::auth
10 changes: 9 additions & 1 deletion src/auth/user_access_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,15 @@
#include <utility>

#include "catalog/distribute_iterator.h"
#include "proto/name_server.pb.h"
#include "refreshable_map.h"

namespace openmldb::auth {
struct UserRecord {
std::string password;
::openmldb::nameserver::PrivilegeLevel privilege_level;
};

class UserAccessManager {
public:
using IteratorFactory = std::function<std::optional<
Expand All @@ -39,11 +45,13 @@ class UserAccessManager {

~UserAccessManager();
bool IsAuthenticated(const std::string& host, const std::string& username, const std::string& password);
::openmldb::nameserver::PrivilegeLevel GetPrivilegeLevel(const std::string& user_at_host);
void SyncWithDB();
std::optional<std::string> GetUserPassword(const std::string& host, const std::string& user);

private:
IteratorFactory user_table_iterator_factory_;
RefreshableMap<std::string, std::string> user_map_;
RefreshableMap<std::string, UserRecord> user_map_;
std::thread sync_task_thread_;
std::promise<void> stop_promise_;
void StartSyncTask();
Expand Down
Loading

0 comments on commit 2a73952

Please sign in to comment.