-
Notifications
You must be signed in to change notification settings - Fork 64
/
driver.go
153 lines (131 loc) · 4.13 KB
/
driver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package athena
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net/url"
"sync"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/athena"
)
// openFromSessionMutex serializes access to openFromSessionCount.
var openFromSessionMutex sync.Mutex

// openFromSessionCount is incremented by Open each time it registers a
// config-bound driver, giving every registration a distinct numeric suffix.
var openFromSessionCount int
// Driver is a sql.Driver. It's intended for database/sql.Open().
//
// cfg is optional. When it is nil, Open builds the configuration from the
// connection string it receives. When it is set (via NewDriver or the
// package-level Open), that configuration is used and the connection string
// is ignored.
type Driver struct {
	cfg *DriverConfig // pre-built configuration; nil means "parse the connection string"
}
// NewDriver returns a Driver bound to cfg, for callers that want to register
// it themselves with `sql.Register` in more complex setups. The motivating use
// case is described in PR #3:
// https://github.com/segmentio/go-athena/pull/3
//
// Passing a nil cfg yields a Driver that configures itself from each
// connection string. Generally, sql.Open() or athena.Open() should suffice.
func NewDriver(cfg *DriverConfig) *Driver {
	return &Driver{cfg: cfg}
}
func init() {
var drv driver.Driver = &Driver{}
sql.Register("athena", drv)
}
// Open should be used via `database/sql.Open("athena", "<params>")`.
// The following parameters are supported in URI query format (k=v&k2=v2&...)
//
// - `db` (required)
// This is the Athena database name. In the UI, this defaults to "default",
// but the driver requires it regardless.
//
// - `output_location` (required)
// This is the S3 location Athena will dump query results in the format
// "s3://bucket/and/so/forth". In the AWS UI, this defaults to
// "s3://aws-athena-query-results-<ACCOUNTID>-<REGION>", but the driver requires it.
//
// - `poll_frequency` (optional)
// Athena's API requires polling to retrieve query results. This is the frequency at
// which the driver will poll for results. It should be a time/Duration.String().
// A completely arbitrary default of "5s" was chosen; it is also used when the
// configured value is zero or negative.
//
// - `region` (optional)
// Override AWS region. Useful if it is not set with environment variable.
//
// If the Driver was created with a DriverConfig (NewDriver or athena.Open),
// connStr is ignored and that configuration is used instead.
//
// Credentials must be accessible via the SDK's Default Credential Provider Chain.
// For more advanced AWS credentials/config management, please supply
// a custom aws.Config directly via `athena.Open()`.
func (d *Driver) Open(connStr string) (driver.Conn, error) {
	cfg := d.cfg
	if cfg == nil {
		var err error
		// TODO: Implement DriverContext to get proper access to context
		cfg, err = configFromConnectionString(context.TODO(), connStr)
		if err != nil {
			return nil, err
		}
	}
	if cfg.Config == nil {
		// A Driver built via NewDriver may carry a DriverConfig without AWS
		// settings; report it instead of panicking on the dereference below.
		return nil, errors.New("AWS config is required")
	}

	// database/sql may call Open from several goroutines at once, and d.cfg is
	// shared by all of them. Apply the default to a local value instead of
	// writing it back into the shared config (which would be a data race).
	pollFrequency := cfg.PollFrequency
	if pollFrequency <= 0 {
		// NOTE(review): a non-positive interval would either busy-poll Athena
		// or panic in a ticker, depending on how conn uses it; fall back.
		pollFrequency = 5 * time.Second
	}

	return &conn{
		athena:         athena.NewFromConfig(*cfg.Config),
		db:             cfg.Database,
		OutputLocation: cfg.OutputLocation,
		pollFrequency:  pollFrequency,
	}, nil
}
// Open is a more robust version of `sql.Open`, as it accepts a fully built
// DriverConfig (including an aws.Config) instead of a connection string.
// This is useful if you have a complex AWS configuration, since the driver
// doesn't attempt to serialize all options into a string.
//
// Database, OutputLocation and Config are required; PollFrequency is optional.
//
// Each call registers a new, uniquely named driver with database/sql
// ("athena-1", "athena-2", ...) bound to this function's copy of cfg (the
// aws.Config pointer itself is still shared with the caller). database/sql has
// no way to unregister drivers, so those registrations last for the rest of
// the process; prefer calling Open once and reusing the returned *sql.DB.
func Open(cfg DriverConfig) (*sql.DB, error) {
	if cfg.Database == "" {
		return nil, errors.New("db is required")
	}
	if cfg.OutputLocation == "" {
		// Named after the DSN key / field, not a parameter this API doesn't have.
		return nil, errors.New("output_location is required")
	}
	if cfg.Config == nil {
		return nil, errors.New("AWS config is required")
	}

	// This hack was copied from jackc/pgx. Sorry :(
	// https://github.com/jackc/pgx/blob/70a284f4f33a9cc28fd1223f6b83fb00deecfe33/stdlib/sql.go#L130-L136
	// sql.Register panics on duplicate names, so every call needs a fresh one.
	openFromSessionMutex.Lock()
	openFromSessionCount++
	name := fmt.Sprintf("athena-%d", openFromSessionCount)
	openFromSessionMutex.Unlock()

	sql.Register(name, &Driver{&cfg})
	return sql.Open(name, "")
}
// DriverConfig is the input to Open() and NewDriver(). It carries everything a
// connection needs, so no connection string has to be parsed.
type DriverConfig struct {
	// Config is the AWS SDK configuration the Athena client is built from.
	// Required by Open.
	Config *aws.Config
	// Database is the Athena database name (the `db` DSN parameter).
	// Required by Open.
	Database string
	// OutputLocation is the S3 location for query results, e.g.
	// "s3://bucket/prefix" (the `output_location` DSN parameter).
	// Required by Open.
	OutputLocation string
	// PollFrequency is how often the driver polls Athena for query results.
	// Zero selects the driver's default of 5 seconds.
	PollFrequency time.Duration
}
// configFromConnectionString builds a DriverConfig from a data source name in
// URL query format, for example "db=default&output_location=s3://bucket/path".
//
// Recognized keys are `db`, `output_location`, `poll_frequency` (parsed with
// time.ParseDuration) and `region`. Missing keys are left as zero values; this
// function performs no required-field validation.
//
// AWS settings come from config.LoadDefaultConfig. A non-empty `region` is
// passed in as a load option rather than patched onto the result afterwards.
// That way region-dependent settings resolved during loading see the same
// region as the Athena client.
//
// Errors from query parsing and config loading are returned as-is. A bad
// poll_frequency is reported with the offending value and wraps the
// ParseDuration error.
func configFromConnectionString(ctx context.Context, connStr string) (*DriverConfig, error) {
	args, err := url.ParseQuery(connStr)
	if err != nil {
		return nil, err
	}

	var cfg DriverConfig
	cfg.Database = args.Get("db")
	cfg.OutputLocation = args.Get("output_location")

	// Validate cheap, local parameters before touching the AWS config chain.
	if frequencyStr := args.Get("poll_frequency"); frequencyStr != "" {
		cfg.PollFrequency, err = time.ParseDuration(frequencyStr)
		if err != nil {
			return nil, fmt.Errorf("invalid poll_frequency parameter: %s: %w", frequencyStr, err)
		}
	}

	var loadOpts []func(*config.LoadOptions) error
	if region := args.Get("region"); region != "" {
		loadOpts = append(loadOpts, config.WithRegion(region))
	}
	awsConfig, err := config.LoadDefaultConfig(ctx, loadOpts...)
	if err != nil {
		return nil, err
	}
	cfg.Config = &awsConfig

	return &cfg, nil
}