Skip to content

Commit

Permalink
Fix nil references to sql DB. (#1485)
Browse files Browse the repository at this point in the history
  • Loading branch information
keyurva authored Jan 3, 2025
1 parent 2d046d8 commit c8d8fd7
Show file tree
Hide file tree
Showing 16 changed files with 33 additions and 31 deletions.
14 changes: 8 additions & 6 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,23 +202,25 @@ func main() {
}

// SQL client
var sqlClient *sqldb.SQLClient
var sqlClient sqldb.SQLClient
if *useSQLite {
sqlClient, err = sqldb.NewSQLiteClient(*sqlitePath)
client, err := sqldb.NewSQLiteClient(*sqlitePath)
if err != nil {
log.Fatalf("Cannot open sqlite database from: %s: %v", *sqlitePath, err)
}
sqlClient.DB = client.DB
defer sqlClient.Close()
}

if *useCloudSQL {
if sqlClient != nil {
if sqlClient.DB != nil {
log.Printf("SQL client has already been created, will not use CloudSQL")
} else {
sqlClient, err = sqldb.NewCloudSQLClient(*cloudSQLInstance)
client, err := sqldb.NewCloudSQLClient(*cloudSQLInstance)
if err != nil {
log.Fatalf("Cannot open cloud sql database from %s: %v", *cloudSQLInstance, err)
}
sqlClient.DB = client.DB
defer sqlClient.Close()
}
}
Expand All @@ -236,7 +238,7 @@ func main() {
}

// Store
if len(tables) == 0 && *remoteMixerDomain == "" && sqlClient == nil {
if len(tables) == 0 && *remoteMixerDomain == "" && sqlClient.DB == nil {
log.Fatal("No bigtables or remote mixer domain or sql database are provided")
}
store, err := store.NewStore(
Expand All @@ -250,7 +252,7 @@ func main() {
cacheOptions := cache.CacheOptions{
FetchSVG: *cacheSVG,
SearchSVG: *cacheSVG,
CacheSQL: store.SQLClient != nil,
CacheSQL: store.SQLClient.DB != nil,
CacheSVFormula: *cacheSVFormula,
}
c, err := cache.NewCache(ctx, store, cacheOptions, metadata)
Expand Down
2 changes: 1 addition & 1 deletion internal/server/count/count.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func countInternal(
}
}
}
if st.SQLClient != nil {
if st.SQLClient.DB != nil {
// all SV contains the SV in the request and child SV in the request SVG.
allSV := []string{}
for _, svOrSvg := range svOrSvgs {
Expand Down
5 changes: 1 addition & 4 deletions internal/server/handler_v1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/datacommonsorg/mixer/internal/proto"
pbv1 "github.com/datacommonsorg/mixer/internal/proto/v1"
"github.com/datacommonsorg/mixer/internal/server/resource"
"github.com/datacommonsorg/mixer/internal/sqldb"
"github.com/datacommonsorg/mixer/internal/store"
"github.com/go-test/deep"
)
Expand All @@ -32,9 +31,7 @@ func TestBulkVariableInfo(t *testing.T) {
ctx := context.Background()

s := Server{
store: &store.Store{
SQLClient: &sqldb.SQLClient{},
},
store: &store.Store{},
metadata: &resource.Metadata{},
httpClient: &http.Client{},
}
Expand Down
2 changes: 1 addition & 1 deletion internal/server/node/property_label.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func GetPropertiesHelper(
}
}
// Fetch data from SQLite
if store.SQLClient != nil {
if store.SQLClient.DB != nil {
var query string
if direction == util.DirectionOut {
query = fmt.Sprintf(
Expand Down
2 changes: 1 addition & 1 deletion internal/server/placein/placein.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func GetPlacesIn(
}
}
}
if store.SQLClient != nil {
if store.SQLClient.DB != nil {
var query string
var args []string
if len(parentPlaces) == 1 && parentPlaces[0] == childPlaceType {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/statvar/fetcher/entity_sv_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func FetchEntityVariables(
}
}
// Fetch from SQL database
if store.SQLClient != nil {
if store.SQLClient.DB != nil {
query := fmt.Sprintf(
`
SELECT entity, GROUP_CONCAT(DISTINCT variable) AS variables
Expand Down
2 changes: 1 addition & 1 deletion internal/server/statvar/fetcher/svg_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func FetchAllSVG(
}
}
}
if store.SQLClient != nil {
if store.SQLClient.DB != nil {
sqlResult, err := fetchSQLSVGs(store.SQLClient.DB)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions internal/server/statvar/statvar_summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (
func GetStatVarSummaryHelper(
ctx context.Context, entities []string, store *store.Store) (
map[string]*pb.StatVarSummary, error) {
if store.BtGroup == nil && store.SQLClient == nil {
if store.BtGroup == nil && store.SQLClient.DB == nil {
return nil, status.Error(codes.Internal, "No store found")
}

Expand All @@ -55,7 +55,7 @@ func GetStatVarSummaryHelper(
btChan <- map[string]*pb.StatVarSummary{}
}

if store.SQLClient != nil {
if store.SQLClient.DB != nil {
errGroup.Go(func() error {
sql, err := sqlGetStatVarSummary(entities, store.SQLClient.DB)
if err != nil {
Expand Down
3 changes: 2 additions & 1 deletion internal/server/v0/propertylabel/property_label_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"testing"

pb "github.com/datacommonsorg/mixer/internal/proto"
"github.com/datacommonsorg/mixer/internal/sqldb"
"github.com/datacommonsorg/mixer/internal/store"
"github.com/datacommonsorg/mixer/internal/store/bigtable"
"github.com/datacommonsorg/mixer/internal/util"
Expand Down Expand Up @@ -93,7 +94,7 @@ func TestMerge(t *testing.T) {

store, err := store.NewStore(
nil,
nil,
sqldb.SQLClient{},
[]*bigtable.Table{
bigtable.NewTable("borgcron_base", baseTable, false /*isCustom=*/),
bigtable.NewTable("borgcron_branch", branchTable, false /*isCustom=*/),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func BulkObservationDatesLinked(
}

// Read data from SQL store.
if store.SQLClient != nil {
if store.SQLClient.DB != nil {
childPlaces, err := shared.FetchChildPlaces(
ctx, store, metadata, httpClient, metadata.RemoteMixerDomain, linkedEntity, entityType)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/server/v1/propertyvalues/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ func Fetch(
}
// No pagination for sqlite query, so if there is a pagination token, meaning
// the data has already been queried and returned in previous query.
if store.SQLClient != nil && token == "" {
if store.SQLClient.DB != nil && token == "" {
sqlResp, err := fetchSQL(store.SQLClient.DB, nodes, properties, direction)
if err != nil {
return nil, nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/server/v2/facet/series.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ func SeriesFacet(
}
}
}
if store.SQLClient != nil {
if store.SQLClient.DB != nil {
observationCount, err := sqlquery.CountObservation(store.SQLClient.DB, entities, variables)
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion internal/server/v2/observation/contained_in.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func FetchContainedIn(

// Fetch Data from SQLite database.
var sqlResult *pbv2.ObservationResponse
if store.SQLClient != nil {
if store.SQLClient.DB != nil {
if ancestor == childType {
sqlResult = initObservationResult(variables)
variablesStr := "'" + strings.Join(variables, "', '") + "'"
Expand Down
5 changes: 1 addition & 4 deletions internal/server/v2/shared/contained_in_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,14 @@ import (
pb "github.com/datacommonsorg/mixer/internal/proto"
pbv2 "github.com/datacommonsorg/mixer/internal/proto/v2"
"github.com/datacommonsorg/mixer/internal/server/resource"
"github.com/datacommonsorg/mixer/internal/sqldb"
"github.com/datacommonsorg/mixer/internal/store"
)

func TestFetchChildPlaces(t *testing.T) {
t.Parallel()
ctx := context.Background()

s := &store.Store{
SQLClient: &sqldb.SQLClient{},
}
s := &store.Store{}
metadata := &resource.Metadata{}
httpClient := &http.Client{}

Expand Down
8 changes: 6 additions & 2 deletions internal/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,17 @@ type Store struct {
BqClient *bigquery.Client
BtGroup *bigtable.Group
RecogPlaceStore *files.RecogPlaceStore
SQLClient *sqldb.SQLClient
// TODO: Make SQLClient a pointer instead of a value once SQLClient.DB is made internal.
// Currently the direct DB connection is referenced at many places
// and a nil SQLClient pointer leads to NPEs.
// Using a value avoids those situations.
SQLClient sqldb.SQLClient
}

// NewStore creates a new store.
func NewStore(
bqClient *bigquery.Client,
sqlClient *sqldb.SQLClient,
sqlClient sqldb.SQLClient,
tables []*bigtable.Table,
branchTableName string,
metadata *resource.Metadata,
Expand Down
7 changes: 4 additions & 3 deletions test/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,13 @@ func setupInternal(
log.Fatalf("failed to create Bigquery client: %v", err)
}
// SQL client
var sqlClient *sqldb.SQLClient
var sqlClient sqldb.SQLClient
if useSQLite {
sqlClient, err = sqldb.NewSQLiteClient(filepath.Join(path.Dir(filename), "./datacommons.db"))
client, err := sqldb.NewSQLiteClient(filepath.Join(path.Dir(filename), "./datacommons.db"))
if err != nil {
log.Fatalf("Failed to read sqlite database: %v", err)
}
sqlClient.DB = client.DB
err = sqldb.CheckSchema(sqlClient.DB)
if err != nil {
log.Fatalf("SQL schema check failed: %v", err)
Expand Down Expand Up @@ -216,7 +217,7 @@ func SetupBqOnly() (pbs.MixerClient, error) {
if err != nil {
return nil, err
}
st, err := store.NewStore(bqClient, nil, nil, "", nil)
st, err := store.NewStore(bqClient, sqldb.SQLClient{}, nil, "", nil)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit c8d8fd7

Please sign in to comment.