Skip to content

Commit

Permalink
Support domain whitelisting for projects (#4729)
Browse files Browse the repository at this point in the history
* Support domain whitelisting for projects

* Fixed "sudo" command execution
  • Loading branch information
esevastyanov authored Apr 29, 2024
1 parent 994532c commit 76d1c16
Show file tree
Hide file tree
Showing 22 changed files with 6,096 additions and 3,333 deletions.
26 changes: 26 additions & 0 deletions admin/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ type DB interface {
DeleteProject(ctx context.Context, id string) error
UpdateProject(ctx context.Context, id string, opts *UpdateProjectOptions) (*Project, error)
CountProjectsForOrganization(ctx context.Context, orgID string) (int, error)
FindProjectWhitelistedDomain(ctx context.Context, projectID, domain string) (*ProjectWhitelistedDomain, error)
FindProjectWhitelistedDomainForProjectWithJoinedRoleNames(ctx context.Context, projectID string) ([]*ProjectWhitelistedDomainWithJoinedRoleNames, error)
FindProjectWhitelistedDomainsForDomain(ctx context.Context, domain string) ([]*ProjectWhitelistedDomain, error)
InsertProjectWhitelistedDomain(ctx context.Context, opts *InsertProjectWhitelistedDomainOptions) (*ProjectWhitelistedDomain, error)
DeleteProjectWhitelistedDomain(ctx context.Context, id string) error

FindExpiredDeployments(ctx context.Context) ([]*Deployment, error)
FindDeploymentsForProject(ctx context.Context, projectID string) ([]*Deployment, error)
Expand Down Expand Up @@ -107,6 +112,7 @@ type DB interface {
FindSuperusers(ctx context.Context) ([]*User, error)
UpdateSuperuser(ctx context.Context, userID string, superuser bool) error
CheckUserIsAnOrganizationMember(ctx context.Context, userID, orgID string) (bool, error)
CheckUserIsAProjectMember(ctx context.Context, userID, projectID string) (bool, error)

InsertUsergroup(ctx context.Context, opts *InsertUsergroupOptions) (*Usergroup, error)
FindUsergroupsForUser(ctx context.Context, userID, orgID string) ([]*Usergroup, error)
Expand Down Expand Up @@ -640,6 +646,26 @@ type OrganizationWhitelistedDomainWithJoinedRoleNames struct {
RoleName string `db:"name"`
}

type ProjectWhitelistedDomain struct {
ID string
ProjectID string `db:"project_id"`
ProjectRoleID string `db:"project_role_id"`
Domain string
CreatedOn time.Time `db:"created_on"`
UpdatedOn time.Time `db:"updated_on"`
}

type InsertProjectWhitelistedDomainOptions struct {
ProjectID string `validate:"required"`
ProjectRoleID string `validate:"required"`
Domain string `validate:"domain"`
}

type ProjectWhitelistedDomainWithJoinedRoleNames struct {
Domain string
RoleName string `db:"name"`
}

const (
DefaultQuotaProjects = 5
DefaultQuotaDeployments = 10
Expand Down
11 changes: 11 additions & 0 deletions admin/database/postgres/migrations/0027.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE TABLE projects_autoinvite_domains (
id UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v4(),
project_id UUID NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
project_role_id UUID NOT NULL REFERENCES project_roles (id) ON DELETE CASCADE,
domain TEXT NOT NULL,
created_on TIMESTAMPTZ NOT NULL DEFAULT now(),
updated_on TIMESTAMPTZ NOT NULL DEFAULT now()
);

CREATE INDEX projects_autoinvite_domains_domain_idx ON projects_autoinvite_domains (lower(domain));
CREATE UNIQUE INDEX projects_autoinvite_domains_project_id_domain_idx ON projects_autoinvite_domains (project_id, lower(domain));
54 changes: 54 additions & 0 deletions admin/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,51 @@ func (c *connection) CountProjectsForOrganization(ctx context.Context, orgID str
return count, nil
}

func (c *connection) FindProjectWhitelistedDomainForProjectWithJoinedRoleNames(ctx context.Context, projectID string) ([]*database.ProjectWhitelistedDomainWithJoinedRoleNames, error) {
var res []*database.ProjectWhitelistedDomainWithJoinedRoleNames
err := c.getDB(ctx).SelectContext(ctx, &res, "SELECT pad.domain, r.name FROM projects_autoinvite_domains pad JOIN project_roles r ON r.id = pad.project_role_id WHERE pad.project_id=$1", projectID)
if err != nil {
return nil, parseErr("project whitelist domains", err)
}
return res, nil
}

func (c *connection) FindProjectWhitelistedDomainsForDomain(ctx context.Context, domain string) ([]*database.ProjectWhitelistedDomain, error) {
var res []*database.ProjectWhitelistedDomain
err := c.getDB(ctx).SelectContext(ctx, "SELECT * FROM projects_autoinvite_domains WHERE lower(domain)=lower($1)", domain)
if err != nil {
return nil, parseErr("project whitelist domains", err)
}
return res, nil
}

func (c *connection) FindProjectWhitelistedDomain(ctx context.Context, projectID, domain string) (*database.ProjectWhitelistedDomain, error) {
res := &database.ProjectWhitelistedDomain{}
err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT * FROM projects_autoinvite_domains WHERE project_id=$1 AND lower(domain)=lower($2)", projectID, domain).StructScan(res)
if err != nil {
return nil, parseErr("project whitelist domain", err)
}
return res, nil
}

func (c *connection) InsertProjectWhitelistedDomain(ctx context.Context, opts *database.InsertProjectWhitelistedDomainOptions) (*database.ProjectWhitelistedDomain, error) {
if err := database.Validate(opts); err != nil {
return nil, err
}

res := &database.ProjectWhitelistedDomain{}
err := c.getDB(ctx).QueryRowxContext(ctx, `INSERT INTO projects_autoinvite_domains(project_id, project_role_id, domain) VALUES ($1, $2, $3) RETURNING *`, opts.ProjectID, opts.ProjectRoleID, opts.Domain).StructScan(res)
if err != nil {
return nil, parseErr("project whitelist domain", err)
}
return res, nil
}

func (c *connection) DeleteProjectWhitelistedDomain(ctx context.Context, id string) error {
res, err := c.getDB(ctx).ExecContext(ctx, "DELETE FROM projects_autoinvite_domains WHERE id=$1", id)
return checkDeleteRow("project whitelist domain", res, err)
}

// FindExpiredDeployments returns all the deployments which are expired as per prod ttl
func (c *connection) FindExpiredDeployments(ctx context.Context) ([]*database.Deployment, error) {
var res []*database.Deployment
Expand Down Expand Up @@ -634,6 +679,15 @@ func (c *connection) CheckUserIsAnOrganizationMember(ctx context.Context, userID
return res, nil
}

func (c *connection) CheckUserIsAProjectMember(ctx context.Context, userID, projectID string) (bool, error) {
var res bool
err := c.getDB(ctx).QueryRowxContext(ctx, "SELECT EXISTS (SELECT 1 FROM users_projects_roles WHERE user_id=$1 AND project_id=$2)", userID, projectID).Scan(&res)
if err != nil {
return false, parseErr("check", err)
}
return res, nil
}

func (c *connection) InsertUsergroup(ctx context.Context, opts *database.InsertUsergroupOptions) (*database.Usergroup, error) {
if err := database.Validate(opts); err != nil {
return nil, err
Expand Down
171 changes: 171 additions & 0 deletions admin/server/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,14 @@ import (
"context"
"errors"
"fmt"
"math"
"net/url"
"strings"
"time"

"github.com/rilldata/rill/admin"
"github.com/rilldata/rill/admin/database"
"github.com/rilldata/rill/admin/pkg/publicemail"
"github.com/rilldata/rill/admin/server/auth"
adminv1 "github.com/rilldata/rill/proto/gen/rill/admin/v1"
"github.com/rilldata/rill/runtime/pkg/email"
Expand Down Expand Up @@ -888,6 +891,174 @@ func (s *Server) SudoUpdateAnnotations(ctx context.Context, req *adminv1.SudoUpd
}, nil
}

func (s *Server) CreateProjectWhitelistedDomain(ctx context.Context, req *adminv1.CreateProjectWhitelistedDomainRequest) (*adminv1.CreateProjectWhitelistedDomainResponse, error) {
observability.AddRequestAttributes(ctx,
attribute.String("args.organization", req.Organization),
attribute.String("args.project", req.Project),
attribute.String("args.domain", req.Domain),
attribute.String("args.role", req.Role),
)

claims := auth.GetClaims(ctx)
if claims.OwnerType() != auth.OwnerTypeUser {
return nil, status.Error(codes.Unauthenticated, "not authenticated as a user")
}

proj, err := s.admin.DB.FindProjectByName(ctx, req.Organization, req.Project)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return nil, status.Error(codes.NotFound, "project not found")
}
return nil, status.Error(codes.Internal, err.Error())
}

if !claims.Superuser(ctx) {
if !claims.ProjectPermissions(ctx, proj.OrganizationID, proj.ID).ManageProject {
return nil, status.Error(codes.PermissionDenied, "only proj admins can add whitelisted domain")
}
// check if the user's domain matches the whitelist domain
user, err := s.admin.DB.FindUser(ctx, claims.OwnerID())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !strings.HasSuffix(user.Email, "@"+req.Domain) {
return nil, status.Error(codes.PermissionDenied, "Domain name doesn’t match verified email domain. Please contact Rill support.")
}

if publicemail.IsPublic(req.Domain) {
return nil, status.Errorf(codes.InvalidArgument, "Public Domain %s cannot be whitelisted", req.Domain)
}
}

role, err := s.admin.DB.FindProjectRole(ctx, req.Role)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return nil, status.Error(codes.NotFound, "role not found")
}
return nil, status.Error(codes.Internal, err.Error())
}

// find existing users belonging to the whitelisted domain to the project
users, err := s.admin.DB.FindUsersByEmailPattern(ctx, "%@"+req.Domain, "", math.MaxInt)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

// filter out users who are already members of the project
newUsers := make([]*database.User, 0)
for _, user := range users {
// check if user is already a member of the project
exists, err := s.admin.DB.CheckUserIsAProjectMember(ctx, user.ID, proj.ID)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if !exists {
newUsers = append(newUsers, user)
}
}

ctx, tx, err := s.admin.DB.NewTx(ctx)
if err != nil {
return nil, err
}
defer func() { _ = tx.Rollback() }()

_, err = s.admin.DB.InsertProjectWhitelistedDomain(ctx, &database.InsertProjectWhitelistedDomainOptions{
ProjectID: proj.ID,
ProjectRoleID: role.ID,
Domain: req.Domain,
})
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

for _, user := range newUsers {
err = s.admin.DB.InsertProjectMemberUser(ctx, proj.ID, user.ID, role.ID)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
}

err = tx.Commit()
if err != nil {
return nil, err
}

return &adminv1.CreateProjectWhitelistedDomainResponse{}, nil
}

func (s *Server) RemoveProjectWhitelistedDomain(ctx context.Context, req *adminv1.RemoveProjectWhitelistedDomainRequest) (*adminv1.RemoveProjectWhitelistedDomainResponse, error) {
observability.AddRequestAttributes(ctx,
attribute.String("args.organization", req.Organization),
attribute.String("args.project", req.Project),
attribute.String("args.domain", req.Domain),
)

claims := auth.GetClaims(ctx)

proj, err := s.admin.DB.FindProjectByName(ctx, req.Organization, req.Project)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return nil, status.Error(codes.NotFound, "project not found")
}
return nil, status.Error(codes.Internal, err.Error())
}

if !(claims.ProjectPermissions(ctx, proj.OrganizationID, proj.ID).ManageProject || claims.Superuser(ctx)) {
return nil, status.Error(codes.PermissionDenied, "only project admins can remove whitelisted domain")
}

invite, err := s.admin.DB.FindProjectWhitelistedDomain(ctx, proj.ID, req.Domain)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return nil, status.Errorf(codes.NotFound, "whitelist not found for project %q and domain %q", proj.Name, req.Domain)
}
return nil, status.Error(codes.Internal, err.Error())
}

err = s.admin.DB.DeleteProjectWhitelistedDomain(ctx, invite.ID)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

return &adminv1.RemoveProjectWhitelistedDomainResponse{}, nil
}

func (s *Server) ListProjectWhitelistedDomains(ctx context.Context, req *adminv1.ListProjectWhitelistedDomainsRequest) (*adminv1.ListProjectWhitelistedDomainsResponse, error) {
observability.AddRequestAttributes(ctx,
attribute.String("args.organization", req.Organization),
attribute.String("args.project", req.Project),
)

proj, err := s.admin.DB.FindProjectByName(ctx, req.Organization, req.Project)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return nil, status.Error(codes.NotFound, "project not found")
}
return nil, status.Error(codes.Internal, err.Error())
}

claims := auth.GetClaims(ctx)
if !(claims.ProjectPermissions(ctx, proj.OrganizationID, proj.ID).ManageProject || claims.Superuser(ctx)) {
return nil, status.Error(codes.PermissionDenied, "only project admins can list whitelisted domains")
}

domains, err := s.admin.DB.FindProjectWhitelistedDomainForProjectWithJoinedRoleNames(ctx, proj.ID)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}

dtos := make([]*adminv1.WhitelistedDomain, len(domains))
for i, domain := range domains {
dtos[i] = &adminv1.WhitelistedDomain{
Domain: domain.Domain,
Role: domain.RoleName,
}
}

return &adminv1.ListProjectWhitelistedDomainsResponse{Domains: dtos}, nil
}

func (s *Server) projToDTO(p *database.Project, orgName string) *adminv1.Project {
frontendURL, _ := url.JoinPath(s.opts.FrontendURL, orgName, p.Name)

Expand Down
37 changes: 34 additions & 3 deletions admin/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,13 @@ func (s *Service) CreateOrUpdateUser(ctx context.Context, email, name, photoURL
addedToOrgNames = append(addedToOrgNames, org.Name)
}

// check if users email domain is whitelisted
// check if users email domain is whitelisted for some organizations
domain := email[strings.LastIndex(email, "@")+1:]
whitelists, err := s.DB.FindOrganizationWhitelistedDomainsForDomain(ctx, domain)
organizationWhitelistedDomains, err := s.DB.FindOrganizationWhitelistedDomainsForDomain(ctx, domain)
if err != nil {
return nil, err
}
for _, whitelist := range whitelists {
for _, whitelist := range organizationWhitelistedDomains {
// if user is already a member of the org then skip, prefer explicit invite to whitelist
if _, ok := addedToOrgIDs[whitelist.OrgID]; ok {
continue
Expand All @@ -122,7 +122,13 @@ func (s *Service) CreateOrUpdateUser(ctx context.Context, email, name, photoURL
}

// handle project invites
addedToProjectIDs := make(map[string]bool)
addedToProjectNames := make([]string, 0)
for _, invite := range projectInvites {
project, err := s.DB.FindProject(ctx, invite.ProjectID)
if err != nil {
return nil, err
}
err = s.DB.InsertProjectMemberUser(ctx, invite.ProjectID, user.ID, invite.ProjectRoleID)
if err != nil {
return nil, err
Expand All @@ -131,6 +137,30 @@ func (s *Service) CreateOrUpdateUser(ctx context.Context, email, name, photoURL
if err != nil {
return nil, err
}
addedToProjectIDs[project.ID] = true
addedToProjectNames = append(addedToProjectNames, project.Name)
}

// check if users email domain is whitelisted for some projects
projectWhitelistedDomains, err := s.DB.FindProjectWhitelistedDomainsForDomain(ctx, domain)
if err != nil {
return nil, err
}
for _, whitelist := range projectWhitelistedDomains {
// if user is already a member of the project then skip, prefer explicit invite to whitelist
if _, ok := addedToProjectIDs[whitelist.ProjectID]; ok {
continue
}
project, err := s.DB.FindProject(ctx, whitelist.ProjectID)
if err != nil {
return nil, err
}
err = s.DB.InsertProjectMemberUser(ctx, whitelist.ProjectID, user.ID, whitelist.ProjectRoleID)
if err != nil {
return nil, err
}
addedToProjectIDs[project.ID] = true
addedToProjectNames = append(addedToProjectNames, project.Name)
}

err = tx.Commit()
Expand All @@ -143,6 +173,7 @@ func (s *Service) CreateOrUpdateUser(ctx context.Context, email, name, photoURL
zap.String("email", user.Email),
zap.String("name", user.DisplayName),
zap.String("org", strings.Join(addedToOrgNames, ",")),
zap.String("project", strings.Join(addedToProjectNames, ",")),
)

return user, nil
Expand Down
Loading

1 comment on commit 76d1c16

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.