From d60d60f5c5397a455f46b3e8b20ed1cdd906551c Mon Sep 17 00:00:00 2001 From: Yash Mehrotra Date: Fri, 17 Nov 2023 13:45:14 +0530 Subject: [PATCH] chore: add namespace to connection string --- context/connection.go | 44 ++++++++++++++++++++++++-------------- context/connection_test.go | 22 ++++++++++++++++--- schema/connections.hcl | 5 +++-- tests/connection_test.go | 2 +- 4 files changed, 51 insertions(+), 22 deletions(-) diff --git a/context/connection.go b/context/connection.go index ad779e0c..1e9b2949 100644 --- a/context/connection.go +++ b/context/connection.go @@ -17,8 +17,8 @@ var ( ) // extractConnectionNameType extracts the name and connection type from a connection -// string formatted as "connection:///". -func extractConnectionNameType(connectionString string) (name string, connectionType string, found bool) { +// string formatted as "connection:////". +func extractConnectionNameType(connectionString string) (name, namespace, connectionType string, found bool) { prefix := "connection://" if !strings.HasPrefix(connectionString, prefix) { @@ -26,8 +26,8 @@ func extractConnectionNameType(connectionString string) (name string, connection } connectionString = strings.TrimPrefix(connectionString, prefix) - parts := strings.SplitN(connectionString, "/", 2) - if len(parts) != 2 { + parts := strings.Split(connectionString, "/") + if len(parts) > 3 || len(parts) < 1 { return } @@ -35,12 +35,20 @@ func extractConnectionNameType(connectionString string) (name string, connection return } - return parts[1], parts[0], true + if len(parts) == 3 { + name, namespace, connectionType = parts[2], parts[1], parts[0] + return name, namespace, connectionType, true + } else if len(parts) == 2 { + name, connectionType = parts[1], parts[0] + return name, "", connectionType, true + } + + return } // HydrateConnectionByURL retrieves a connection from the given connection string. // The connection string is expected to be in one of the following forms: -// - connection:/// or +// - connection:/// or connection://// // - the UUID of the connection. func HydrateConnectionByURL(ctx Context, connectionString string) (*models.Connection, error) { if connectionString == "" { @@ -72,7 +80,7 @@ func IsValidConnectionURL(connectionString string) bool { if _, err := uuid.Parse(connectionString); err == nil { return true } - _, _, found := extractConnectionNameType(connectionString) + _, _, _, found := extractConnectionNameType(connectionString) return found } @@ -87,24 +95,28 @@ func FindConnectionByURL(ctx Context, connectionString string) (*models.Connecti return &connection, nil } - name, connectionType, found := extractConnectionNameType(connectionString) + name, namespace, connectionType, found := extractConnectionNameType(connectionString) if !found { return nil, nil } - connection, err := FindConnection(ctx, connectionType, name) + connection, err := FindConnection(ctx, connectionType, name, namespace) if err != nil { - return nil, fmt.Errorf("failed to find connection (type=%s, name=%s): %w", connectionType, name, err) + return nil, fmt.Errorf("failed to find connection (type=%s, name=%s, namespace=%s): %w", connectionType, name, namespace, err) } return connection, nil } // FindConnection returns the connection with the given type and name -func FindConnection(ctx Context, connectionType, name string) (*models.Connection, error) { +func FindConnection(ctx Context, connectionType, name, namespace string) (*models.Connection, error) { var connection models.Connection - err := ctx.DB().Where("type = ? AND name = ?", connectionType, name).First(&connection).Error + if namespace == "" { + namespace = ctx.GetNamespace() + } + + err := ctx.DB().Where("type = ? AND name = ? AND namespace = ?", connectionType, name, namespace).First(&connection).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil @@ -116,12 +128,12 @@ func FindConnection(ctx Context, connectionType, name string) (*models.Connectio return &connection, nil } -func (ctx Context) GetConnection(connectionType string, name string) (*models.Connection, error) { - return GetConnection(ctx, connectionType, name) +func (ctx Context) GetConnection(connectionType, name, namespace string) (*models.Connection, error) { + return GetConnection(ctx, connectionType, name, namespace) } -func GetConnection(ctx Context, connectionType string, name string) (*models.Connection, error) { - connection, err := FindConnection(ctx, connectionType, name) +func GetConnection(ctx Context, connectionType, name, namespace string) (*models.Connection, error) { + connection, err := FindConnection(ctx, connectionType, name, namespace) if err != nil { return nil, err } diff --git a/context/connection_test.go b/context/connection_test.go index e881aaa1..7dab6f8b 100644 --- a/context/connection_test.go +++ b/context/connection_test.go @@ -8,32 +8,37 @@ func TestGetConnectionNameType(t *testing.T) { connection string Expect struct { name string + namespace string connectionType string found bool } }{ { name: "valid connection string", - connection: "connection://db/mission_control", + connection: "connection://db/default/mission_control", Expect: struct { name string + namespace string connectionType string found bool }{ name: "mission_control", + namespace: "default", connectionType: "db", found: true, }, }, { name: "valid connection string | name has /", - connection: "connection://db/mission_control//", + connection: "connection://db/default/mission_control//", Expect: struct { name string + namespace string connectionType string found bool }{ name: "mission_control//", + namespace: "default", connectionType: "db", found: true, }, @@ -43,10 +48,12 @@ func TestGetConnectionNameType(t *testing.T) { connection: "connection:///type-only", Expect: struct { name string + namespace string connectionType string found bool }{ name: "", + namespace: "", connectionType: "", found: false, }, @@ -56,10 +63,12 @@ func TestGetConnectionNameType(t *testing.T) { connection: "invalid-connection-string", Expect: struct { name string + namespace string connectionType string found bool }{ name: "", + namespace: "", connectionType: "", found: false, }, @@ -69,10 +78,12 @@ func TestGetConnectionNameType(t *testing.T) { connection: "", Expect: struct { name string + namespace string connectionType string found bool }{ name: "", + namespace: "", connectionType: "", found: false, }, @@ -82,10 +93,12 @@ func TestGetConnectionNameType(t *testing.T) { connection: "connection://type-only", Expect: struct { name string + namespace string connectionType string found bool }{ name: "", + namespace: "", connectionType: "", found: false, }, @@ -94,10 +107,13 @@ func TestGetConnectionNameType(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - name, connectionType, found := extractConnectionNameType(tc.connection) + name, namespace, connectionType, found := extractConnectionNameType(tc.connection) if name != tc.Expect.name { t.Errorf("g.Expected name %q, but got %q", tc.Expect.name, name) } + if namespace != tc.Expect.namespace { + t.Errorf("g.Expected namespace %q, but got %q", tc.Expect.namespace, namespace) + } if connectionType != tc.Expect.connectionType { t.Errorf("g.Expected connection type %q, but got %q", tc.Expect.connectionType, connectionType) } diff --git a/schema/connections.hcl b/schema/connections.hcl index 7652f66e..cb6e7223 100644 --- a/schema/connections.hcl +++ b/schema/connections.hcl @@ -10,8 +10,9 @@ table "connections" { type = text } column "namespace" { - null = true + null = false type = text + default = "default" } column "type" { null = false @@ -68,7 +69,7 @@ table "connections" { primary_key { columns = [column.id] } - index "connections_name_type_key" { + index "connections_type_name_namespace_key" { unique = true columns = [column.type, column.name, column.namespace] } diff --git a/tests/connection_test.go b/tests/connection_test.go index d0a1247f..fb535c67 100644 --- a/tests/connection_test.go +++ b/tests/connection_test.go @@ -33,7 +33,7 @@ var _ = Describe("Connection", Ordered, func() { var connection *models.Connection var err error It("should be retrieved successfully", func() { - connection, err = testutils.DefaultContext.GetConnection("test", "test") + connection, err = testutils.DefaultContext.GetConnection("test", "test", "default") Expect(err).ToNot(HaveOccurred()) })