Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: migration dependencies #1197

Merged
merged 5 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions migrate/dependency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package migrate

import (
"bufio"
"path/filepath"
"strings"

"github.com/flanksource/duty/functions"
"github.com/flanksource/duty/views"
"github.com/samber/lo"
)

func parseDependencies(script string) ([]string, error) {
const dependencyHeader = "-- dependsOn: "

var dependencies []string
scanner := bufio.NewScanner(strings.NewReader(script))
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, dependencyHeader) {
break
}

line = strings.TrimPrefix(line, dependencyHeader)
deps := strings.Split(line, ",")
dependencies = append(dependencies, lo.Map(deps, func(x string, _ int) string {
return strings.TrimSpace(x)
})...)
}

return dependencies, nil
}

// DependencyMap map holds path -> dependents
type DependencyMap map[string][]string

// getDependencyTree returns a list of scripts and its dependents
//
// example: if a.sql dependsOn b.sql, c.sql
// it returns
//
// {
// b.sql: []string{a.sql},
// c.sql: []string{a.sql},
// }
func getDependencyTree() (DependencyMap, error) {
graph := make(DependencyMap)

funcs, err := functions.GetFunctions()
if err != nil {
return nil, err
}

views, err := views.GetViews()
if err != nil {
return nil, err
}

for i, dir := range []map[string]string{funcs, views} {
dirName := "functions"
if i == 1 {
dirName = "views"
}

for entry, content := range dir {
path := filepath.Join(dirName, entry)
dependents, err := parseDependencies(content)
if err != nil {
return nil, err
}

for _, dependent := range dependents {
graph[dependent] = append(graph[dependent], strings.TrimPrefix(path, "../"))
}
}
}

return graph, nil
}
58 changes: 58 additions & 0 deletions migrate/dependency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package migrate

import (
"testing"

"github.com/onsi/gomega"
)

func TestParseDependencies(t *testing.T) {
testdata := []struct {
script string
want []string
}{
{
script: "-- dependsOn: a.sql, b.sql",
want: []string{"a.sql", "b.sql"},
},
{
script: "SELECT 1;",
want: nil,
},
{
script: "-- dependsOn: a.sql, b.sql,c.sql",
want: []string{"a.sql", "b.sql", "c.sql"},
},
}

g := gomega.NewWithT(t) // use gomega with std go tests
for _, td := range testdata {
got, err := parseDependencies(td.script)
if err != nil {
t.Fatal(err.Error())
}

g.Expect(got).To(gomega.Equal(td.want))
}
}

func TestDependencyMap(t *testing.T) {
g := gomega.NewWithT(t) // use gomega with std go tests

graph, err := getDependencyTree()
if err != nil {
t.Fatal(err.Error())
}

expected := map[string][]string{
"functions/drop.sql": {"views/006_config_views.sql", "views/021_notification.sql"},
"views/006_config_views.sql": {"views/021_notification.sql"},
}

g.Expect(graph).To(gomega.HaveLen(len(expected)))

for key, expectedDeps := range expected {
g.Expect(graph).To(gomega.HaveKey(key))
g.Expect(graph[key]).To(gomega.ConsistOf(expectedDeps))
}
}
132 changes: 112 additions & 20 deletions migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql"
"errors"
"fmt"
"path/filepath"
"sort"

"github.com/flanksource/commons/collections"
Expand Down Expand Up @@ -50,14 +51,13 @@ func RunMigrations(pool *sql.DB, config api.Config) error {
return fmt.Errorf("failed to create migration log table: %w", err)
}

l.V(3).Infof("Getting functions")
funcs, err := functions.GetFunctions()
allFunctions, allViews, err := GetExecutableScripts(pool)
if err != nil {
return fmt.Errorf("failed to get functions: %w", err)
return fmt.Errorf("failed to get executable scripts: %w", err)
}

l.V(3).Infof("Running scripts")
if err := runScripts(pool, funcs, config.SkipMigrationFiles); err != nil {
if err := runScripts(pool, allFunctions, config.SkipMigrationFiles); err != nil {
return fmt.Errorf("failed to run scripts: %w", err)
}

Expand All @@ -72,18 +72,119 @@ func RunMigrations(pool *sql.DB, config api.Config) error {
return fmt.Errorf("failed to apply schema migrations: %w", err)
}

l.V(3).Infof("Running scripts for views")
if err := runScripts(pool, allViews, config.SkipMigrationFiles); err != nil {
return fmt.Errorf("failed to run scripts for views: %w", err)
}

return nil
}

// GetExecutableScripts returns functions & views that must be applied.
// It takes dependencies into account & excludes any unchanged scripts.
func GetExecutableScripts(pool *sql.DB) (map[string]string, map[string]string, error) {
l := logger.GetLogger("migrate")

var (
allFunctions = map[string]string{}
allViews = map[string]string{}
)

l.V(3).Infof("Getting functions")
funcs, err := functions.GetFunctions()
if err != nil {
return nil, nil, fmt.Errorf("failed to get functions: %w", err)
}

l.V(3).Infof("Getting views")
views, err := views.GetViews()
if err != nil {
return fmt.Errorf("failed to get views: %w", err)
return nil, nil, fmt.Errorf("failed to get views: %w", err)
}

l.V(3).Infof("Running scripts for views")
if err := runScripts(pool, views, config.SkipMigrationFiles); err != nil {
return fmt.Errorf("failed to run scripts for views: %w", err)
depGraph, err := getDependencyTree()
if err != nil {
return nil, nil, fmt.Errorf("failed to for dependency map: %w", err)
}

return nil
currentMigrationHashes, err := readMigrationLogs(pool)
if err != nil {
return nil, nil, err
}

for path, content := range funcs {
hash := sha1.Sum([]byte(content))
if ch, ok := currentMigrationHashes[path]; ok && ch == string(hash[:]) {
continue
}

allFunctions[path] = content

// other scripts that depend on this should also be executed
for _, dependent := range depGraph[filepath.Join("functions", path)] {
baseDir := filepath.Dir(dependent)
filename := filepath.Base(dependent)

switch baseDir {
case "functions":
allFunctions[filename] = funcs[filename]
case "views":
allViews[filename] = views[filename]
default:
panic("unhandled base dir")
}
}
}

for path, content := range views {
hash := sha1.Sum([]byte(content))
if ch, ok := currentMigrationHashes[path]; ok && ch == string(hash[:]) {
continue
}

allViews[path] = content

// other scripts that depend on this should also be executed
for _, dependent := range depGraph[filepath.Join("functions", path)] {
baseDir := filepath.Dir(dependent)
filename := filepath.Base(dependent)

switch baseDir {
case "functions":
allFunctions[filename] = funcs[filename]
case "views":
allViews[filename] = views[filename]
default:
panic("unhandled base dir")
}
}
}

return allFunctions, allViews, err
}

func readMigrationLogs(pool *sql.DB) (map[string]string, error) {
rows, err := pool.Query("SELECT path, hash FROM migration_logs")
if err != nil {
return nil, fmt.Errorf("failed to read migration logs: %w", err)
}
defer rows.Close()

migrationHashes := make(map[string]string)
for rows.Next() {
var path, hash string
if err := rows.Scan(&path, &hash); err != nil {
return nil, err
}

migrationHashes[path] = hash
}

if rows.Err() != nil {
return nil, rows.Err()
}

return migrationHashes, nil
}

func createRole(db *sql.DB, roleName string, config api.Config, grants ...string) error {
Expand Down Expand Up @@ -166,6 +267,7 @@ func checkIfRoleIsGranted(pool *sql.DB, group, member string) (bool, error) {

func runScripts(pool *sql.DB, scripts map[string]string, ignoreFiles []string) error {
l := logger.GetLogger("migrate")

var filenames []string
for name := range scripts {
if collections.Contains(ignoreFiles, name) {
Expand All @@ -181,18 +283,8 @@ func runScripts(pool *sql.DB, scripts map[string]string, ignoreFiles []string) e
continue
}

var currentHash string
if err := pool.QueryRow("SELECT hash FROM migration_logs WHERE path = $1", file).Scan(&currentHash); err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}

hash := sha1.Sum([]byte(content))
if string(hash[:]) == currentHash {
l.V(3).Infof("Skipping script %s", file)
continue
}

l.Tracef("Running script %s", file)
l.Tracef("running script %s", file)
if _, err := pool.Exec(scripts[file]); err != nil {
return fmt.Errorf("failed to run script %s: %w", file, db.ErrorDetails(err))
}
Expand Down
51 changes: 51 additions & 0 deletions tests/migration_dependency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package tests

import (
"github.com/flanksource/duty/migrate"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("migration dependency", Ordered, func() {
It("should have no executable scripts", func() {
db, err := DefaultContext.DB().DB()
Expect(err).To(BeNil())

funcs, views, err := migrate.GetExecutableScripts(db)
Expect(err).To(BeNil())
Expect(len(funcs)).To(BeZero())
Expect(len(views)).To(BeZero())
})

// FIXME: sql driver issue on CI
// It("should get correct executable scripts", func() {
// err := DefaultContext.DB().Exec(`UPDATE migration_logs SET hash = 'dummy' WHERE path = 'drop.sql'`).Error
// Expect(err).To(BeNil())
//
// db, err := DefaultContext.DB().DB()
// Expect(err).To(BeNil())
//
// funcs, views, err := migrate.GetExecutableScripts(db)
// Expect(err).To(BeNil())
// Expect(len(funcs)).To(Equal(1))
// Expect(len(views)).To(Equal(2))
//
// Expect(collections.MapKeys(funcs)).To(Equal([]string{"drop.sql"}))
// Expect(collections.MapKeys(views)).To(ConsistOf([]string{"006_config_views.sql", "021_notification.sql"}))
//
// {
// // run the migrations again to ensure that the hashes are repopulated
// err := migrate.RunMigrations(db, api.DefaultConfig)
// Expect(err).To(BeNil())
//
// // at the end, there should be no scrips to apply
// db, err := DefaultContext.DB().DB()
// Expect(err).To(BeNil())
//
// funcs, views, err := migrate.GetExecutableScripts(db)
// Expect(err).To(BeNil())
// Expect(len(funcs)).To(BeZero())
// Expect(len(views)).To(BeZero())
// }
// })
})
2 changes: 2 additions & 0 deletions views/006_config_views.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
-- dependsOn: functions/drop.sql

-- Add cascade drops first to make sure all functions and views are always recreated
DROP VIEW IF EXISTS configs CASCADE;

Expand Down
2 changes: 2 additions & 0 deletions views/021_notification.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
-- dependsOn: functions/drop.sql, views/006_config_views.sql

-- Handle before updates for notifications
CREATE OR REPLACE FUNCTION reset_notification_error_before_update ()
RETURNS TRIGGER
Expand Down
Loading