From 72ae32dbf8965f414888cc957c6f66b5a96f27c0 Mon Sep 17 00:00:00 2001 From: Gabe Cook Date: Tue, 21 Nov 2023 21:27:01 -0600 Subject: [PATCH] fix(mariadb): Fix restore when db name has special characters --- internal/database/dialect/mariadb.go | 9 ++++++++- internal/database/dialect/mariadb_test.go | 22 +++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/internal/database/dialect/mariadb.go b/internal/database/dialect/mariadb.go index 5e1184e3..c1c1cd0c 100644 --- a/internal/database/dialect/mariadb.go +++ b/internal/database/dialect/mariadb.go @@ -46,7 +46,8 @@ func (MariaDB) DefaultUser() string { return "root" } -func (MariaDB) DropDatabaseQuery(database string) string { +func (db MariaDB) DropDatabaseQuery(database string) string { + database = db.quoteIdentifier(database) return "set FOREIGN_KEY_CHECKS=0; create or replace database " + database + "; set FOREIGN_KEY_CHECKS=1; use " + database + ";" } @@ -160,3 +161,9 @@ func (db MariaDB) DumpExtension(format sqlformat.Format) string { } return "" } + +func (db MariaDB) quoteIdentifier(param string) string { + param = strings.ReplaceAll(param, "`", "``") + param = "`" + param + "`" + return param +} diff --git a/internal/database/dialect/mariadb_test.go b/internal/database/dialect/mariadb_test.go index a539e156..026e3407 100644 --- a/internal/database/dialect/mariadb_test.go +++ b/internal/database/dialect/mariadb_test.go @@ -21,7 +21,7 @@ func TestMariaDB_DropDatabaseQuery(t *testing.T) { args args want string }{ - {"database", args{"database"}, "set FOREIGN_KEY_CHECKS=0; create or replace database database; set FOREIGN_KEY_CHECKS=1; use database;"}, + {"database", args{"database"}, "set FOREIGN_KEY_CHECKS=0; create or replace database `database`; set FOREIGN_KEY_CHECKS=1; use `database`;"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -264,3 +264,23 @@ func TestMariaDB_FormatFromFilename(t *testing.T) { }) } } + +func TestMariaDB_quoteIdentifier(t *testing.T) { + type args struct { + param string + } + tests := []struct { + name string + args args + want string + }{ + {"simple", args{"table"}, "`table`"}, + {"escaped", args{"T`able"}, "`T``able`"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := MariaDB{} + assert.Equal(t, tt.want, db.quoteIdentifier(tt.args.param)) + }) + } +}