forked from hashicorp/dbassert
-
Notifications
You must be signed in to change notification settings - Fork 0
/
column.go
127 lines (111 loc) · 2.91 KB
/
column.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package dbassert
import (
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/stretchr/testify/assert"
)
// ColumnInfo defines a set of information about a column.
type ColumnInfo struct {
// TableName for the column.
TableName string
// Name of the column.
Name string
// Default value for the column.
Default string
// Type of the column.
Type string
// DomainName for the column.
DomainName string
// IsNullable defines if the column can be null.
IsNullable bool
}
// Nullable asserts that column colName of table tableName accepts NULL.
//
// If the column's schema information cannot be read, the test is failed
// via assert.FailNow with the lookup error and false is returned. If the
// column is NOT NULL, a non-fatal failure is recorded and false is
// returned. Otherwise it returns true.
func (a *DbAsserts) Nullable(tableName, colName string) bool {
	if helper, isHelper := a.T.(THelper); isHelper {
		helper.Helper()
	}

	col, err := a.getSchemaInfo(tableName, colName)
	if err != nil {
		assert.FailNow(a.T, err.Error())
		return false
	}

	if !col.IsNullable {
		assert.Fail(a.T, "column is not nullable", "%s: %s is not nullable", tableName, colName)
	}
	return col.IsNullable
}
// Domain asserts that column colName of table tableName is declared with
// the domain domainName. The names are compared case-insensitively
// (strings.EqualFold).
//
// A schema lookup error fails the test via assert.FailNow and returns
// false. A mismatched domain records a non-fatal failure and returns false.
func (a *DbAsserts) Domain(tableName, colName, domainName string) bool {
	if helper, isHelper := a.T.(THelper); isHelper {
		helper.Helper()
	}

	col, err := a.getSchemaInfo(tableName, colName)
	switch {
	case err != nil:
		assert.FailNow(a.T, err.Error())
		return false
	case !strings.EqualFold(domainName, col.DomainName):
		assert.Fail(a.T, "domain is not valid", "%s: %s is not %s", tableName, colName, domainName)
		return false
	default:
		return true
	}
}
// Column asserts that the database's metadata for column c.Name of table
// c.TableName is exactly equal to c. All fields, including Default, Type,
// DomainName and IsNullable, are compared with ==.
//
// A schema lookup error fails the test via assert.FailNow and returns
// false. Any field mismatch records a non-fatal failure that prints both
// the expected and the actual values, and returns false.
func (a *DbAsserts) Column(c ColumnInfo) bool {
	if helper, isHelper := a.T.(THelper); isHelper {
		helper.Helper()
	}

	actual, err := a.getSchemaInfo(c.TableName, c.Name)
	if err != nil {
		assert.FailNow(a.T, err.Error())
		return false
	}

	matches := c == *actual
	if !matches {
		assert.Fail(a.T, "invalid column", "%s: %+v column is not valid in the db column %+v", c.TableName, c, actual)
	}
	return matches
}
// getSchemaInfo reads the metadata of column columnName in table tableName
// from information_schema.columns and returns it as a ColumnInfo.
//
// NULL values for column_default and domain_name are returned as empty
// strings. is_nullable is mapped to true only when it equals "YES".
//
// Errors from the lookup are wrapped with the table and column names, so a
// failing assertion says which column was involved. When no matching row
// exists, the returned error wraps sql.ErrNoRows and can still be detected
// with errors.Is.
func (a *DbAsserts) getSchemaInfo(tableName, columnName string) (*ColumnInfo, error) {
	if h, ok := a.T.(THelper); ok {
		h.Helper()
	}

	// NOTE(review): the query does not filter on table_schema. If the same
	// table/column name exists in several schemas, QueryRow returns an
	// arbitrary one of the matching rows.
	const query = `
select
	table_name,
	column_name,
	column_default,
	data_type,
	domain_name,
	is_nullable
from information_schema.columns
where table_name = $1 and column_name = $2`

	var (
		table, colName, colType, colIsNullable string
		colDefault, colDomainName              sql.NullString
	)
	row := a.Db.QueryRow(query, tableName, columnName)
	if err := row.Scan(&table, &colName, &colDefault, &colType, &colDomainName, &colIsNullable); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, fmt.Errorf("column %s.%s not found in information_schema.columns: %w", tableName, columnName, err)
		}
		return nil, fmt.Errorf("reading schema info for column %s.%s: %w", tableName, columnName, err)
	}

	return &ColumnInfo{
		TableName:  tableName,
		Name:       colName,
		Default:    NullableString(colDefault),
		Type:       colType,
		DomainName: NullableString(colDomainName),
		IsNullable: colIsNullable == "YES",
	}, nil
}
// NullableString converts a sql.NullString read from a nullable database
// column into a plain string. It returns value.String when value.Valid is
// true. For a NULL value it returns "", whatever the String field holds.
//
// Note that this conversion cannot distinguish NULL from an empty string.
func NullableString(value sql.NullString) string {
	if value.Valid {
		return value.String
	}
	return ""
}