Skip to content

Commit

Permalink
feat: migration dependencies (#1197)
Browse files Browse the repository at this point in the history
* feat: migration dependency graph

* feat: run dependent scripts

* test: migration dependency

* fix: don't read dir. use the embed.FS when parsing dependencies.

* feat: add more dependency headers
  • Loading branch information
adityathebe authored Nov 15, 2024
1 parent ad99acd commit 27c2266
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 20 deletions.
79 changes: 79 additions & 0 deletions migrate/dependency.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package migrate

import (
"bufio"
"path/filepath"
"strings"

"github.com/flanksource/duty/functions"
"github.com/flanksource/duty/views"
"github.com/samber/lo"
)

func parseDependencies(script string) ([]string, error) {
const dependencyHeader = "-- dependsOn: "

var dependencies []string
scanner := bufio.NewScanner(strings.NewReader(script))
for scanner.Scan() {
line := scanner.Text()
if !strings.HasPrefix(line, dependencyHeader) {
break
}

line = strings.TrimPrefix(line, dependencyHeader)
deps := strings.Split(line, ",")
dependencies = append(dependencies, lo.Map(deps, func(x string, _ int) string {
return strings.TrimSpace(x)
})...)
}

return dependencies, nil
}

// DependencyMap map holds path -> dependents
type DependencyMap map[string][]string

// getDependencyTree returns a list of scripts and its dependents
//
// example: if a.sql dependsOn b.sql, c.sql
// it returns
//
// {
// b.sql: []string{a.sql},
// c.sql: []string{a.sql},
// }
func getDependencyTree() (DependencyMap, error) {
graph := make(DependencyMap)

funcs, err := functions.GetFunctions()
if err != nil {
return nil, err
}

views, err := views.GetViews()
if err != nil {
return nil, err
}

for i, dir := range []map[string]string{funcs, views} {
dirName := "functions"
if i == 1 {
dirName = "views"
}

for entry, content := range dir {
path := filepath.Join(dirName, entry)
dependents, err := parseDependencies(content)
if err != nil {
return nil, err
}

for _, dependent := range dependents {
graph[dependent] = append(graph[dependent], strings.TrimPrefix(path, "../"))
}
}
}

return graph, nil
}
58 changes: 58 additions & 0 deletions migrate/dependency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package migrate

import (
"testing"

"github.com/onsi/gomega"
)

func TestParseDependencies(t *testing.T) {
testdata := []struct {
script string
want []string
}{
{
script: "-- dependsOn: a.sql, b.sql",
want: []string{"a.sql", "b.sql"},
},
{
script: "SELECT 1;",
want: nil,
},
{
script: "-- dependsOn: a.sql, b.sql,c.sql",
want: []string{"a.sql", "b.sql", "c.sql"},
},
}

g := gomega.NewWithT(t) // use gomega with std go tests
for _, td := range testdata {
got, err := parseDependencies(td.script)
if err != nil {
t.Fatal(err.Error())
}

g.Expect(got).To(gomega.Equal(td.want))
}
}

func TestDependencyMap(t *testing.T) {
g := gomega.NewWithT(t) // use gomega with std go tests

graph, err := getDependencyTree()
if err != nil {
t.Fatal(err.Error())
}

expected := map[string][]string{
"functions/drop.sql": {"views/006_config_views.sql", "views/021_notification.sql"},
"views/006_config_views.sql": {"views/021_notification.sql"},
}

g.Expect(graph).To(gomega.HaveLen(len(expected)))

for key, expectedDeps := range expected {
g.Expect(graph).To(gomega.HaveKey(key))
g.Expect(graph[key]).To(gomega.ConsistOf(expectedDeps))
}
}
132 changes: 112 additions & 20 deletions migrate/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"database/sql"
"errors"
"fmt"
"path/filepath"
"sort"

"github.com/flanksource/commons/collections"
Expand Down Expand Up @@ -50,14 +51,13 @@ func RunMigrations(pool *sql.DB, config api.Config) error {
return fmt.Errorf("failed to create migration log table: %w", err)
}

l.V(3).Infof("Getting functions")
funcs, err := functions.GetFunctions()
allFunctions, allViews, err := GetExecutableScripts(pool)
if err != nil {
return fmt.Errorf("failed to get functions: %w", err)
return fmt.Errorf("failed to get executable scripts: %w", err)
}

l.V(3).Infof("Running scripts")
if err := runScripts(pool, funcs, config.SkipMigrationFiles); err != nil {
if err := runScripts(pool, allFunctions, config.SkipMigrationFiles); err != nil {
return fmt.Errorf("failed to run scripts: %w", err)
}

Expand All @@ -72,18 +72,119 @@ func RunMigrations(pool *sql.DB, config api.Config) error {
return fmt.Errorf("failed to apply schema migrations: %w", err)
}

l.V(3).Infof("Running scripts for views")
if err := runScripts(pool, allViews, config.SkipMigrationFiles); err != nil {
return fmt.Errorf("failed to run scripts for views: %w", err)
}

return nil
}

// GetExecutableScripts returns functions & views that must be applied.
// It takes dependencies into account & excludes any unchanged scripts.
func GetExecutableScripts(pool *sql.DB) (map[string]string, map[string]string, error) {
l := logger.GetLogger("migrate")

var (
allFunctions = map[string]string{}
allViews = map[string]string{}
)

l.V(3).Infof("Getting functions")
funcs, err := functions.GetFunctions()
if err != nil {
return nil, nil, fmt.Errorf("failed to get functions: %w", err)
}

l.V(3).Infof("Getting views")
views, err := views.GetViews()
if err != nil {
return fmt.Errorf("failed to get views: %w", err)
return nil, nil, fmt.Errorf("failed to get views: %w", err)
}

l.V(3).Infof("Running scripts for views")
if err := runScripts(pool, views, config.SkipMigrationFiles); err != nil {
return fmt.Errorf("failed to run scripts for views: %w", err)
depGraph, err := getDependencyTree()
if err != nil {
return nil, nil, fmt.Errorf("failed to for dependency map: %w", err)
}

return nil
currentMigrationHashes, err := readMigrationLogs(pool)
if err != nil {
return nil, nil, err
}

for path, content := range funcs {
hash := sha1.Sum([]byte(content))
if ch, ok := currentMigrationHashes[path]; ok && ch == string(hash[:]) {
continue
}

allFunctions[path] = content

// other scripts that depend on this should also be executed
for _, dependent := range depGraph[filepath.Join("functions", path)] {
baseDir := filepath.Dir(dependent)
filename := filepath.Base(dependent)

switch baseDir {
case "functions":
allFunctions[filename] = funcs[filename]
case "views":
allViews[filename] = views[filename]
default:
panic("unhandled base dir")
}
}
}

for path, content := range views {
hash := sha1.Sum([]byte(content))
if ch, ok := currentMigrationHashes[path]; ok && ch == string(hash[:]) {
continue
}

allViews[path] = content

// other scripts that depend on this should also be executed
for _, dependent := range depGraph[filepath.Join("functions", path)] {
baseDir := filepath.Dir(dependent)
filename := filepath.Base(dependent)

switch baseDir {
case "functions":
allFunctions[filename] = funcs[filename]
case "views":
allViews[filename] = views[filename]
default:
panic("unhandled base dir")
}
}
}

return allFunctions, allViews, err
}

func readMigrationLogs(pool *sql.DB) (map[string]string, error) {
rows, err := pool.Query("SELECT path, hash FROM migration_logs")
if err != nil {
return nil, fmt.Errorf("failed to read migration logs: %w", err)
}
defer rows.Close()

migrationHashes := make(map[string]string)
for rows.Next() {
var path, hash string
if err := rows.Scan(&path, &hash); err != nil {
return nil, err
}

migrationHashes[path] = hash
}

if rows.Err() != nil {
return nil, rows.Err()
}

return migrationHashes, nil
}

func createRole(db *sql.DB, roleName string, config api.Config, grants ...string) error {
Expand Down Expand Up @@ -166,6 +267,7 @@ func checkIfRoleIsGranted(pool *sql.DB, group, member string) (bool, error) {

func runScripts(pool *sql.DB, scripts map[string]string, ignoreFiles []string) error {
l := logger.GetLogger("migrate")

var filenames []string
for name := range scripts {
if collections.Contains(ignoreFiles, name) {
Expand All @@ -181,18 +283,8 @@ func runScripts(pool *sql.DB, scripts map[string]string, ignoreFiles []string) e
continue
}

var currentHash string
if err := pool.QueryRow("SELECT hash FROM migration_logs WHERE path = $1", file).Scan(&currentHash); err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}

hash := sha1.Sum([]byte(content))
if string(hash[:]) == currentHash {
l.V(3).Infof("Skipping script %s", file)
continue
}

l.Tracef("Running script %s", file)
l.Tracef("running script %s", file)
if _, err := pool.Exec(scripts[file]); err != nil {
return fmt.Errorf("failed to run script %s: %w", file, db.ErrorDetails(err))
}
Expand Down
51 changes: 51 additions & 0 deletions tests/migration_dependency_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package tests

import (
"github.com/flanksource/duty/migrate"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("migration dependency", Ordered, func() {
It("should have no executable scripts", func() {
db, err := DefaultContext.DB().DB()
Expect(err).To(BeNil())

funcs, views, err := migrate.GetExecutableScripts(db)
Expect(err).To(BeNil())
Expect(len(funcs)).To(BeZero())
Expect(len(views)).To(BeZero())
})

// FIXME: sql driver issue on CI
// It("should get correct executable scripts", func() {
// err := DefaultContext.DB().Exec(`UPDATE migration_logs SET hash = 'dummy' WHERE path = 'drop.sql'`).Error
// Expect(err).To(BeNil())
//
// db, err := DefaultContext.DB().DB()
// Expect(err).To(BeNil())
//
// funcs, views, err := migrate.GetExecutableScripts(db)
// Expect(err).To(BeNil())
// Expect(len(funcs)).To(Equal(1))
// Expect(len(views)).To(Equal(2))
//
// Expect(collections.MapKeys(funcs)).To(Equal([]string{"drop.sql"}))
// Expect(collections.MapKeys(views)).To(ConsistOf([]string{"006_config_views.sql", "021_notification.sql"}))
//
// {
// // run the migrations again to ensure that the hashes are repopulated
// err := migrate.RunMigrations(db, api.DefaultConfig)
// Expect(err).To(BeNil())
//
// // at the end, there should be no scrips to apply
// db, err := DefaultContext.DB().DB()
// Expect(err).To(BeNil())
//
// funcs, views, err := migrate.GetExecutableScripts(db)
// Expect(err).To(BeNil())
// Expect(len(funcs)).To(BeZero())
// Expect(len(views)).To(BeZero())
// }
// })
})
2 changes: 2 additions & 0 deletions views/006_config_views.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
-- dependsOn: functions/drop.sql

-- Add cascade drops first to make sure all functions and views are always recreated
DROP VIEW IF EXISTS configs CASCADE;

Expand Down
2 changes: 2 additions & 0 deletions views/021_notification.sql
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
-- dependsOn: functions/drop.sql, views/006_config_views.sql

-- Handle before updates for notifications
CREATE OR REPLACE FUNCTION reset_notification_error_before_update ()
RETURNS TRIGGER
Expand Down

0 comments on commit 27c2266

Please sign in to comment.