-
Notifications
You must be signed in to change notification settings - Fork 52
/
aggregate.go
199 lines (160 loc) · 5.44 KB
/
aggregate.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
// Copyright 2019 Tim Shannon. All rights reserved.
// Use of this source code is governed by the MIT license
// that can be found in the LICENSE file.
package badgerhold
import (
"fmt"
"reflect"
"sort"
"github.com/dgraph-io/badger/v4"
)
// AggregateResult allows you to access the results of an aggregate query.
// It carries the grouping values (read with Group) and the records belonging to
// that group (read with Reduction, or summarised with Min, Max, Avg, Sum and Count).
type AggregateResult struct {
// reduction holds the records in this group; Reduction, Sort, Min, Max, Sum and Count read from it
reduction []reflect.Value // always pointers
// group holds the grouping values that Group copies out by index
group []reflect.Value
// sortby is the field name recorded by the last call to Sort, used to skip re-sorting by the same field
sortby string
}
// Group copies the values this result was grouped by into the passed in pointers.
// The i-th argument receives the i-th grouping value.
//
// Group panics if an argument is not a pointer, if it is a nil pointer, or if more
// arguments are passed than there are grouping values. Arguments before the
// offending one have already been populated when the panic occurs.
func (a *AggregateResult) Group(result ...interface{}) {
	for i := range result {
		resultVal := reflect.ValueOf(result[i])
		if resultVal.Kind() != reflect.Ptr {
			panic("result argument must be an address")
		}

		if i >= len(a.group) {
			// i is zero-based, so the caller asked for at least i+1 grouping values
			panic(fmt.Sprintf("There is not %d elements in the grouping", i+1))
		}

		// same guard as Max / Min: a nil pointer has no element to set, so fail
		// with a clear message instead of an opaque reflect panic
		if resultVal.IsNil() {
			panic("result argument must not be nil")
		}

		resultVal.Elem().Set(a.group[i])
	}
}
// Reduction appends every record that is part of this AggregateResult group to the
// slice that result points to. The slice may hold either values or pointers; the
// element type of the slice decides which form is appended.
//
// Reduction panics if result is not a pointer to a slice.
func (a *AggregateResult) Reduction(result interface{}) {
	target := reflect.ValueOf(result)
	if target.Kind() != reflect.Ptr {
		panic("result argument must be a slice address")
	}

	out := target.Elem()
	if out.Kind() != reflect.Slice {
		panic("result argument must be a slice address")
	}

	// records are stored as pointers, so only dereference when the caller wants values
	wantPointers := out.Type().Elem().Kind() == reflect.Ptr

	for _, record := range a.reduction {
		if !wantPointers {
			record = record.Elem()
		}
		out = reflect.Append(out, record)
	}

	target.Elem().Set(out.Slice(0, out.Len()))
}
// aggregateResultSort is AggregateResult under a separate named type; the sort.Interface
// methods (Len, Swap, Less) are defined on it rather than on AggregateResult itself,
// and Sort converts to it when calling sort.Sort
type aggregateResultSort AggregateResult
// Len reports how many records are in the reduction (sort.Interface)
func (a *aggregateResultSort) Len() int {
	records := a.reduction
	return len(records)
}
// Swap exchanges the records at positions i and j of the reduction (sort.Interface)
func (a *aggregateResultSort) Swap(i, j int) {
	records := a.reduction
	records[i], records[j] = records[j], records[i]
}
// Less reports whether the record at i sorts before the record at j when ordered by
// the sortby field (sort.Interface).
//
// Less panics if either record has no field named sortby, or if compare returns an error.
func (a *aggregateResultSort) Less(i, j int) bool {
	// reduction values are always pointers, so Elem() reaches the underlying struct
	iVal := a.reduction[i].Elem().FieldByName(a.sortby)
	if !iVal.IsValid() {
		panic(fmt.Sprintf("The field %s does not exist in the type %s", a.sortby, a.reduction[i].Type()))
	}

	jVal := a.reduction[j].Elem().FieldByName(a.sortby)
	if !jVal.IsValid() {
		panic(fmt.Sprintf("The field %s does not exist in the type %s", a.sortby, a.reduction[j].Type()))
	}

	c, err := compare(iVal.Interface(), jVal.Interface())
	if err != nil {
		panic(err)
	}

	// Treat any negative result as "less" rather than only exactly -1, so a comparison
	// that signals ordering with another negative value (e.g. a difference) still sorts
	// correctly. NOTE(review): compare is defined elsewhere in the package — confirm
	// whether it can ever return a negative other than -1.
	return c < 0
}
// Sort orders the aggregate reduction by the passed in field, ascending.
// Min and Max call Sort themselves before reading the first / last record.
// Calling Sort again with the field it was last sorted by is a no-op.
//
// Sort panics if startsUpper rejects the field name.
func (a *AggregateResult) Sort(field string) {
	if !startsUpper(field) {
		panic("The first letter of a field must be upper-case")
	}

	// only re-sort when the requested field differs from the one last sorted by
	if a.sortby != field {
		a.sortby = field
		sort.Sort((*aggregateResultSort)(a))
	}
}
// Max sorts the Aggregate Grouping by field and copies the record with the maximum
// value of that field into result. Ordering is whatever Sort produces.
//
// Max panics if result is not a pointer or is a nil pointer.
func (a *AggregateResult) Max(field string, result interface{}) {
	a.Sort(field)

	dest := reflect.ValueOf(result)
	switch {
	case dest.Kind() != reflect.Ptr:
		panic("result argument must be an address")
	case dest.IsNil():
		panic("result argument must not be nil")
	}

	// ascending sort puts the largest record last
	out := dest.Elem()
	last := len(a.reduction) - 1
	out.Set(a.reduction[last].Elem())
}
// Min sorts the Aggregate Grouping by field and copies the record with the minimum
// value of that field into result. Ordering is whatever Sort produces.
//
// Min panics if result is not a pointer or is a nil pointer.
func (a *AggregateResult) Min(field string, result interface{}) {
	a.Sort(field)

	dest := reflect.ValueOf(result)
	switch {
	case dest.Kind() != reflect.Ptr:
		panic("result argument must be an address")
	case dest.IsNil():
		panic("result argument must not be nil")
	}

	// ascending sort puts the smallest record first
	out := dest.Elem()
	out.Set(a.reduction[0].Elem())
}
// Avg returns the mean of field across the aggregate grouping as a float64:
// the result of Sum divided by the number of records.
// panics if the field cannot be converted to a float64
func (a *AggregateResult) Avg(field string) float64 {
	total := a.Sum(field)
	count := float64(len(a.reduction))

	return total / count
}
// Sum adds up field across every record in the aggregate grouping and returns the total.
// panics if a record has no such field, or if the field cannot be converted to a float64
func (a *AggregateResult) Sum(field string) float64 {
	total := 0.0

	for _, record := range a.reduction {
		// records are pointers, so Elem() reaches the struct holding the field
		fieldVal := record.Elem().FieldByName(field)
		if !fieldVal.IsValid() {
			panic(fmt.Sprintf("The field %s does not exist in the type %s", field, record.Type()))
		}

		total += tryFloat(fieldVal)
	}

	return total
}
// Count reports how many records are in the aggregate grouping
func (a *AggregateResult) Count() uint64 {
	n := len(a.reduction)
	return uint64(n)
}
// FindAggregate runs TxFindAggregate inside a Badger View transaction and returns
// the aggregate groupings it produces for the passed in query.
// groupBy is optional. On error the returned slice is nil.
func (s *Store) FindAggregate(dataType interface{}, query *Query, groupBy ...string) ([]*AggregateResult, error) {
	var found []*AggregateResult

	viewErr := s.Badger().View(func(tx *badger.Txn) error {
		var txErr error
		found, txErr = s.TxFindAggregate(tx, dataType, query, groupBy...)
		return txErr
	})
	if viewErr != nil {
		return nil, viewErr
	}

	return found, nil
}
// TxFindAggregate does what FindAggregate does, but runs against the transaction
// supplied by the caller instead of opening its own.
// groupBy is optional
func (s *Store) TxFindAggregate(tx *badger.Txn, dataType interface{}, query *Query,
	groupBy ...string) ([]*AggregateResult, error) {
	results, err := s.aggregateQuery(tx, dataType, query, groupBy...)
	return results, err
}
// tryFloat converts a reflected signed integer, unsigned integer or floating point
// value to a float64. It panics for a value of any other Kind.
func tryFloat(val reflect.Value) float64 {
	kind := val.Kind()

	switch kind {
	case reflect.Float32, reflect.Float64:
		return val.Float()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return float64(val.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return float64(val.Uint())
	}

	panic(fmt.Sprintf("The field is of Kind %s and cannot be converted to a float64", kind))
}