Skip to content

Commit

Permalink
Encapsulate key value and svg queries in sql client. (#1489)
Browse files Browse the repository at this point in the history
  • Loading branch information
keyurva authored Jan 8, 2025
1 parent 796bce9 commit 939a584
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 139 deletions.
101 changes: 34 additions & 67 deletions internal/server/statvar/fetcher/svg_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@ package fetcher

import (
"context"
"database/sql"
"strings"

pb "github.com/datacommonsorg/mixer/internal/proto"
"github.com/datacommonsorg/mixer/internal/server/statvar/hierarchy"
"github.com/datacommonsorg/mixer/internal/sqldb/sqlquery"
"github.com/datacommonsorg/mixer/internal/sqldb"
"github.com/datacommonsorg/mixer/internal/store"
"github.com/datacommonsorg/mixer/internal/store/bigtable"
"google.golang.org/protobuf/proto"
Expand Down Expand Up @@ -101,8 +100,8 @@ func FetchAllSVG(
}
}
}
if store.SQLClient.DB != nil {
sqlResult, err := fetchSQLSVGs(store.SQLClient.DB)
if sqldb.IsConnected(&store.SQLClient) {
sqlResult, err := fetchSQLSVGs(ctx, &store.SQLClient)
if err != nil {
return nil, err
}
Expand All @@ -121,9 +120,9 @@ func FetchAllSVG(

// Fetches SVGs from SQL.
// First attempts to get it from key value store and falls back to querying sql table.
func fetchSQLSVGs(sqlClient *sql.DB) (map[string]*pb.StatVarGroupNode, error) {
func fetchSQLSVGs(ctx context.Context, sqlClient *sqldb.SQLClient) (map[string]*pb.StatVarGroupNode, error) {
// Try key value first.
keyValueSVGs, err := fetchSQLKeyValueSVGs(sqlClient)
keyValueSVGs, err := fetchSQLKeyValueSVGs(ctx, sqlClient)
if err != nil {
return map[string]*pb.StatVarGroupNode{}, err
}
Expand All @@ -133,13 +132,13 @@ func fetchSQLSVGs(sqlClient *sql.DB) (map[string]*pb.StatVarGroupNode, error) {
}

// Query sql table.
return fetchSQLTableSVGs(sqlClient)
return fetchSQLTableSVGs(ctx, sqlClient)
}

func fetchSQLKeyValueSVGs(sqlClient *sql.DB) (*pb.StatVarGroups, error) {
func fetchSQLKeyValueSVGs(ctx context.Context, sqlClient *sqldb.SQLClient) (*pb.StatVarGroups, error) {
var svgs pb.StatVarGroups

found, err := sqlquery.GetKeyValue(sqlClient, sqlquery.StatVarGroupsKey, &svgs)
found, err := sqlClient.GetKeyValue(ctx, sqldb.StatVarGroupsKey, &svgs)
if !found || err != nil {
return nil, err
}
Expand All @@ -149,85 +148,53 @@ func fetchSQLKeyValueSVGs(sqlClient *sql.DB) (*pb.StatVarGroups, error) {

// TODO: Deprecate this approach in the future
// once the KV approach is universally available.
func fetchSQLTableSVGs(sqlClient *sql.DB) (map[string]*pb.StatVarGroupNode, error) {
func fetchSQLTableSVGs(ctx context.Context, sqlClient *sqldb.SQLClient) (map[string]*pb.StatVarGroupNode, error) {
result := map[string]*pb.StatVarGroupNode{}
// Query for all the stat var group node
query := `
SELECT t1.subject_id, t2.object_value, t3.object_id
FROM triples t1 JOIN triples t2 ON t1.subject_id = t2.subject_id
JOIN triples t3 ON t1.subject_id = t3.subject_id
WHERE t1.predicate="typeOf"
AND t1.object_id="StatVarGroup"
AND t2.predicate="name"
AND t3.predicate="specializationOf";
`
svgRows, err := sqlClient.Query(query)

svgRows, err := sqlClient.GetStatVarGroups(ctx)
if err != nil {
return nil, err
}
defer svgRows.Close()
for svgRows.Next() {
var self, name, parent string
err = svgRows.Scan(&self, &name, &parent)
if err != nil {
return nil, err
}
result[self] = &pb.StatVarGroupNode{
AbsoluteName: name,
for _, svgRow := range svgRows {
result[svgRow.ID] = &pb.StatVarGroupNode{
AbsoluteName: svgRow.Name,
}
if _, ok := result[parent]; !ok {
result[parent] = &pb.StatVarGroupNode{}
if _, ok := result[svgRow.ParentID]; !ok {
result[svgRow.ParentID] = &pb.StatVarGroupNode{}
}
result[parent].ChildStatVarGroups = append(
result[parent].ChildStatVarGroups,
result[svgRow.ParentID].ChildStatVarGroups = append(
result[svgRow.ParentID].ChildStatVarGroups,
&pb.StatVarGroupNode_ChildSVG{
Id: self,
SpecializedEntity: name,
Id: svgRow.ID,
SpecializedEntity: svgRow.Name,
},
)
}
// Query for all the stat var nodes
query = `
SELECT t1.subject_id, t2.object_value, t3.object_id, COALESCE(t4.object_value, '')
FROM triples t1
JOIN triples t2 ON t1.subject_id = t2.subject_id
JOIN triples t3 ON t1.subject_id = t3.subject_id
LEFT JOIN triples t4 ON t1.subject_id = t4.subject_id AND t4.predicate = "description"
WHERE t1.predicate="typeOf"
AND t1.object_id="StatisticalVariable"
AND t2.predicate="name"
AND t3.predicate="memberOf";
`
svRows, err := sqlClient.Query(query)

svRows, err := sqlClient.GetAllStatisticalVariables(ctx)
if err != nil {
return nil, err
}
defer svRows.Close()
for svRows.Next() {
var sv, name, svg, description string
err = svRows.Scan(&sv, &name, &svg, &description)
if err != nil {
return nil, err
}
if _, ok := result[svg]; !ok {
result[svg] = &pb.StatVarGroupNode{}
for _, svRow := range svRows {
if _, ok := result[svRow.SVGID]; !ok {
result[svRow.SVGID] = &pb.StatVarGroupNode{}
}
searchNames := []string{}
if len(name) > 0 {
searchNames = append(searchNames, name)
if len(svRow.Name) > 0 {
searchNames = append(searchNames, svRow.Name)
}
if len(description) > 0 {
searchNames = append(searchNames, description)
if len(svRow.Description) > 0 {
searchNames = append(searchNames, svRow.Description)
}
result[svg].ChildStatVars = append(
result[svg].ChildStatVars,
result[svRow.SVGID].ChildStatVars = append(
result[svRow.SVGID].ChildStatVars,
&pb.StatVarGroupNode_ChildSV{
Id: sv,
DisplayName: name,
Id: svRow.ID,
DisplayName: svRow.Name,
SearchNames: searchNames,
},
)
result[svg].DescendentStatVarCount += 1
result[svRow.SVGID].DescendentStatVarCount += 1
}
return result, nil
}
15 changes: 15 additions & 0 deletions internal/sqldb/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,18 @@ func (s *StringSlice) Scan(src interface{}) error {
*s = strings.Split(val, ",")
return nil
}

// StatVarGroup represents a StatVarGroup row.
type StatVarGroup struct {
ID string `db:"svg_id"`
Name string `db:"svg_name"`
ParentID string `db:"svg_parent_id"`
}

// StatisticalVariable represents a StatisticalVariable row.
type StatisticalVariable struct {
ID string `db:"sv_id"`
Name string `db:"sv_name"`
SVGID string `db:"svg_id"`
Description string `db:"sv_description"`
}
83 changes: 83 additions & 0 deletions internal/sqldb/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,17 @@ package sqldb
import (
"context"

"github.com/datacommonsorg/mixer/internal/util"
"github.com/jmoiron/sqlx"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
)

const (
// Requests for latest dates include this literal for date in the request.
latestDate = "LATEST"
// Key for SV groups in the key_value_store table.
StatVarGroupsKey = "StatVarGroups"
)

// GetObservations retrieves observations from SQL given a list of variables and entities and a date.
Expand Down Expand Up @@ -92,6 +98,83 @@ func (sc *SQLClient) GetSVSummaries(ctx context.Context, variables []string) ([]
return summaries, nil
}

// GetStatVarGroups retrieves all StatVarGroups from the database.
func (sc *SQLClient) GetStatVarGroups(ctx context.Context) ([]*StatVarGroup, error) {
var svgs []*StatVarGroup

stmt := statement{
query: statements.getAllStatVarGroups,
args: map[string]interface{}{},
}

err := sc.queryAndCollect(
ctx,
stmt,
&svgs,
)
if err != nil {
return nil, err
}

return svgs, nil
}

// GetAllStatisticalVariables retrieves all SVs from the database.
func (sc *SQLClient) GetAllStatisticalVariables(ctx context.Context) ([]*StatisticalVariable, error) {
var svs []*StatisticalVariable

stmt := statement{
query: statements.getAllStatVars,
args: map[string]interface{}{},
}

err := sc.queryAndCollect(
ctx,
stmt,
&svs,
)
if err != nil {
return nil, err
}

return svs, nil
}

// GetKeyValue gets the value for the specified key from the key_value_store table.
// If not found, returns false.
// If found, unmarshals the value into the specified proto and returns true.
func (sc *SQLClient) GetKeyValue(ctx context.Context, key string, out protoreflect.ProtoMessage) (bool, error) {
stmt := statement{
query: statements.getKeyValue,
args: map[string]interface{}{
"key": key,
},
}

values := []string{}

err := sc.queryAndCollect(
ctx,
stmt,
&values,
)
if err != nil || len(values) == 0 {
return false, err
}

bytes, err := util.UnzipAndDecode(values[0])
if err != nil {
return false, err
}

err = proto.Unmarshal(bytes, out)
if err != nil {
return false, err
}

return true, nil
}

func (sc *SQLClient) queryAndCollect(
ctx context.Context,
stmt statement,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Google LLC
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,18 +12,18 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlquery
package sqldb

import (
"database/sql"
"context"
"testing"

pb "github.com/datacommonsorg/mixer/internal/proto"
"github.com/go-test/deep"
)

func TestGetKeyValue(t *testing.T) {
sqlClient, err := sql.Open("sqlite", "../../../test/sqlquery/key_value/datacommons.db")
sqlClient, err := NewSQLiteClient("../../test/sqlquery/key_value/datacommons.db")
if err != nil {
t.Fatalf("Could not open test database: %v", err)
}
Expand All @@ -36,7 +36,7 @@ func TestGetKeyValue(t *testing.T) {

var got pb.StatVarGroups

found, _ := GetKeyValue(sqlClient, StatVarGroupsKey, &got)
found, _ := sqlClient.GetKeyValue(context.Background(), StatVarGroupsKey, &got)
if !found {
t.Errorf("Key value data not found: %s", StatVarGroupsKey)
}
Expand Down
67 changes: 0 additions & 67 deletions internal/sqldb/sqlquery/key_value.go

This file was deleted.

Loading

0 comments on commit 939a584

Please sign in to comment.