-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.go
139 lines (109 loc) · 3.01 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
package main
import (
"fmt"
"log"
"net"
"net/http"
"os"
"promproxy/util"
"strconv"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
"golang.org/x/net/netutil"
)
// envLabel is the optional "env" label that reqHandler appends to every
// proxied metric. It stays nil unless main finds PROMPROXY_ENV_LABEL set.
var envLabel *dto.LabelPair
// accessToken is the exact Authorization header value ("Bearer " + token)
// that reqHandler requires from clients. It is empty, and the check is
// skipped, unless main finds PROMPROXY_ACCESS_TOKEN set.
var accessToken string
// main configures promproxy from PROMPROXY_* environment variables and then
// serves reqHandler on a connection-limited TCP listener until the server
// stops with an error.
//
// Recognised variables:
//   - PROMPROXY_ENV_LABEL: value of the "env" label stored in envLabel.
//   - PROMPROXY_ACCESS_TOKEN: bearer token; stored in accessToken with a
//     "Bearer " prefix.
//   - PROMPROXY_PORT: TCP port to listen on (default 9999).
//   - PROMPROXY_MAX_CONN: maximum number of simultaneously accepted
//     connections (default 10, must be at least 1).
//
// Any configuration or listen error terminates the process via log.Fatal.
func main() {
	log.Print("Starting promproxy")

	if env, ok := os.LookupEnv("PROMPROXY_ENV_LABEL"); ok {
		envLabel = util.CreateLabelPair("env", env)
		log.Print("Using environment label " + env)
	}

	if token, ok := os.LookupEnv("PROMPROXY_ACCESS_TOKEN"); ok {
		accessToken = "Bearer " + token
		// The token is a credential: report that it is in effect, but never
		// write its value to the log.
		log.Print("Access token authentication enabled")
	}

	port := ":9999"
	if configPort, ok := os.LookupEnv("PROMPROXY_PORT"); ok {
		port = ":" + configPort
	}

	maxConnections := 10
	if configMaxConn, ok := os.LookupEnv("PROMPROXY_MAX_CONN"); ok {
		var err error
		maxConnections, err = strconv.Atoi(configMaxConn)
		if err != nil {
			log.Fatalf("Invalid max connections parameter: %v", err)
		}
		// The limit is handed to netutil.LimitListener below; a value below 1
		// cannot admit any connection, so reject it up front.
		if maxConnections < 1 {
			log.Fatalf("Invalid max connections parameter: %d (must be at least 1)", maxConnections)
		}
	}

	log.Printf("Listening on port %s, max simulaneous connections %v", port, maxConnections)

	http.HandleFunc("/", reqHandler)

	l, err := net.Listen("tcp", port)
	if err != nil {
		log.Fatalf("Cannot start listening socket: %v", err)
	}

	// No deferred Close: this function only ever leaves through log.Fatal,
	// which exits the process without running deferred calls.
	l = netutil.LimitListener(l, maxConnections)
	log.Fatal(http.Serve(l, nil))
}
// reqHandler serves a single proxy request.
//
// It parses the inbound URL with parseRequest, enforces the configured access
// token (if any), resolves the requested host to a list of targets and then,
// for each target in turn, fetches http://<ip>:<port>/<path>, appends the
// target's label (and envLabel when set) to every metric, and re-encodes the
// metric families to w in the format the upstream response used.
//
// Responses:
//   - 400 when parseRequest fails.
//   - 401 when an access token is configured and the Authorization header
//     does not match it exactly.
//   - 404 when resolution fails or yields no targets.
//   - 500 when an outbound request cannot be built or sent, or an upstream
//     answers with a status other than 200.
//
// NOTE(review): targets are streamed to w one after another, so an error on a
// later target is reported after earlier output has already been written.
func reqHandler(w http.ResponseWriter, inReq *http.Request) {
	request, err := parseRequest(inReq.Context(), inReq.URL)
	if err != nil {
		http.Error(w, err.Error(), http.StatusBadRequest)
		return
	}

	log.Print(inReq.URL)

	// Header.Get yields "" when the header is missing, so an unauthenticated
	// request is rejected instead of panicking on an empty slice index.
	if accessToken != "" && inReq.Header.Get("Authorization") != accessToken {
		http.Error(w, "Unauthorized", http.StatusUnauthorized)
		return
	}

	results, err := request.resolver.Resolve(inReq.Context(), request.host)
	if err != nil || len(results) == 0 {
		http.NotFound(w, inReq)
		return
	}

	for _, result := range results {
		target := fmt.Sprintf("http://%s:%d/%s", result.IP, request.port, request.path)

		// Tie the outbound request to the inbound one so that it is cancelled
		// when the client goes away.
		outReq, err := http.NewRequestWithContext(inReq.Context(), http.MethodGet, target, nil)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}

		// Basic auth
		if request.basicAuth != nil {
			outReq.SetBasicAuth(request.basicAuth.username, request.basicAuth.password)
		}

		// Set headers
		for key, values := range request.headers {
			for _, value := range values {
				outReq.Header.Add(key, value)
			}
		}

		outRes, err := http.DefaultClient.Do(outReq)
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if outRes.StatusCode != http.StatusOK {
			// The body must be closed on this path too, otherwise the
			// connection is leaked.
			outRes.Body.Close()
			http.Error(w, "Upstream error", http.StatusInternalServerError)
			return
		}

		format := expfmt.ResponseFormat(outRes.Header)
		decoder := expfmt.NewDecoder(outRes.Body, format)
		encoder := expfmt.NewEncoder(w, format)

		var sample dto.MetricFamily
		for {
			// Any decode error, including end of input, ends this target.
			if decoder.Decode(&sample) != nil {
				break
			}
			for _, metric := range sample.Metric {
				metric.Label = append(metric.Label, result.Label)
				if envLabel != nil {
					metric.Label = append(metric.Label, envLabel)
				}
			}
			// Stay best-effort, but do not drop the failure silently.
			if err := encoder.Encode(&sample); err != nil {
				log.Printf("Cannot encode metric family from %s: %v", target, err)
			}
		}

		// Close per target rather than deferring inside the loop, so bodies
		// are not all held open until the handler returns.
		outRes.Body.Close()
	}
}