Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Formatter, Linter and fixes #62

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 17 additions & 15 deletions cmd/lingo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@ import (
"time"

"github.com/prometheus/client_golang/prometheus"

"sigs.k8s.io/controller-runtime/pkg/metrics"

"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/kubernetes"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/log/zap"
"sigs.k8s.io/controller-runtime/pkg/metrics"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"

"github.com/substratusai/lingo/pkg/autoscaler"
Expand All @@ -27,10 +29,6 @@ import (
"github.com/substratusai/lingo/pkg/proxy"
"github.com/substratusai/lingo/pkg/queue"
"github.com/substratusai/lingo/pkg/stats"
"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
ctrl "sigs.k8s.io/controller-runtime"
)

var (
Expand Down Expand Up @@ -68,14 +66,18 @@ func run() error {
concurrency := getEnvInt("CONCURRENCY", 100)
scaleDownDelay := getEnvInt("SCALE_DOWN_DELAY", 30)

var metricsAddr string
var probeAddr string
var concurrencyPerReplica int
var (
metricsAddr string
probeAddr string
concurrencyPerReplica int
requestHeaderTimeout time.Duration // setting to prevent slowloris attack on http server
)

flag.StringVar(&metricsAddr, "metrics-bind-address", ":8082", "The address the metric endpoint binds to.")
flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
flag.IntVar(&concurrencyPerReplica, "concurrency", concurrency, "the number of simultaneous requests that can be processed by each replica")
flag.IntVar(&scaleDownDelay, "scale-down-delay", scaleDownDelay, "seconds to wait before scaling down")
flag.DurationVar(&requestHeaderTimeout, "request-header-timeout", 10*time.Second, "amount of time for the client to send headers before a timeout error will occur")
alpe marked this conversation as resolved.
Show resolved Hide resolved
opts := zap.Options{
Development: true,
}
Expand Down Expand Up @@ -111,7 +113,7 @@ func run() error {

hostname, err := os.Hostname()
if err != nil {
return fmt.Errorf("getting hostname: %v", err)
return fmt.Errorf("getting hostname: %w", err)
}
le := leader.NewElection(clientset, hostname, namespace)

Expand Down Expand Up @@ -154,19 +156,19 @@ func run() error {

proxy.MustRegister(metricsRegistry)
proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager)
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler, ReadHeaderTimeout: requestHeaderTimeout}

statsHandler := &stats.Handler{
Queues: queueManager,
}
statsServer := &http.Server{Addr: ":8083", Handler: statsHandler}
statsServer := &http.Server{Addr: ":8083", Handler: statsHandler, ReadHeaderTimeout: requestHeaderTimeout}

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer func() {
statsServer.Shutdown(context.Background())
proxyServer.Shutdown(context.Background())
_ = statsServer.Shutdown(context.Background())
_ = proxyServer.Shutdown(context.Background())
wg.Done()
}()
if err := mgr.Start(ctx); err != nil {
Expand Down
31 changes: 16 additions & 15 deletions pkg/autoscaler/autoscaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,16 @@ import (
"sync"
"time"

corev1 "k8s.io/api/core/v1"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/substratusai/lingo/pkg/deployments"
"github.com/substratusai/lingo/pkg/endpoints"
"github.com/substratusai/lingo/pkg/leader"
"github.com/substratusai/lingo/pkg/movingaverage"
"github.com/substratusai/lingo/pkg/queue"
"github.com/substratusai/lingo/pkg/stats"
corev1 "k8s.io/api/core/v1"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
)

func New(mgr ctrl.Manager) (*Autoscaler, error) {
Expand Down Expand Up @@ -55,15 +56,15 @@ type Autoscaler struct {
movingAvgQueueSize map[string]*movingaverage.Simple
}

func (r *Autoscaler) SetupWithManager(mgr ctrl.Manager) error {
func (a *Autoscaler) SetupWithManager(mgr ctrl.Manager) error {
return ctrl.NewControllerManagedBy(mgr).
For(&corev1.ConfigMap{}).
Complete(r)
Complete(a)
}

func (r *Autoscaler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
func (a *Autoscaler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) {
var cm corev1.ConfigMap
if err := r.Get(ctx, req.NamespacedName, &cm); err != nil {
if err := a.Get(ctx, req.NamespacedName, &cm); err != nil {
return ctrl.Result{}, fmt.Errorf("get: %w", err)
}

Expand Down Expand Up @@ -105,15 +106,15 @@ func (a *Autoscaler) Start() {
}
}

func (r *Autoscaler) getMovingAvgQueueSize(deploymentName string) *movingaverage.Simple {
r.movingAvgQueueSizeMtx.Lock()
a, ok := r.movingAvgQueueSize[deploymentName]
func (a *Autoscaler) getMovingAvgQueueSize(deploymentName string) *movingaverage.Simple {
a.movingAvgQueueSizeMtx.Lock()
avg, ok := a.movingAvgQueueSize[deploymentName]
if !ok {
a = movingaverage.NewSimple(make([]float64, r.AverageCount))
r.movingAvgQueueSize[deploymentName] = a
avg = movingaverage.NewSimple(make([]float64, a.AverageCount))
a.movingAvgQueueSize[deploymentName] = avg
}
r.movingAvgQueueSizeMtx.Unlock()
return a
a.movingAvgQueueSizeMtx.Unlock()
return avg
}

func aggregateStats(s stats.Stats, httpc *http.Client, endpoints []string) (stats.Stats, []error) {
Expand All @@ -126,7 +127,7 @@ func aggregateStats(s stats.Stats, httpc *http.Client, endpoints []string) (stat
for _, endpoint := range endpoints {
fetched, err := getStats(httpc, "http://"+endpoint)
if err != nil {
errs = append(errs, fmt.Errorf("getting stats: %v: %v", endpoint, err))
errs = append(errs, fmt.Errorf("getting stats: %v: %w", endpoint, err))
continue
}
for k, v := range fetched.ActiveRequests {
Expand Down
12 changes: 8 additions & 4 deletions pkg/deployments/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log"
"math"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -150,7 +151,7 @@ func (r *Manager) getScaler(deploymentName string) *scaler {
}

// getScalesSnapshot returns a snapshot of the stats for all scalers managed by the Manager.
// The scales are returned as a map, where the keys are the model names
// The scales are returned as a map, where the keys are the model names.
func (r *Manager) getScalesSnapshot() map[string]scale {
r.scalersMtx.Lock()
defer r.scalersMtx.Unlock()
Expand Down Expand Up @@ -235,7 +236,7 @@ func (r *Manager) Bootstrap(ctx context.Context) error {

// ReadinessChecker checks if the Manager state is loaded and ready to handle requests.
// It returns an error if Manager is not bootstrapped yet.
// To be used with sigs.k8s.io/controller-runtime manager `AddReadyzCheck`
// To be used with sigs.k8s.io/controller-runtime manager `AddReadyzCheck`.
func (r *Manager) ReadinessChecker(_ *http.Request) error {
if !r.bootstrapped.Load() {
return fmt.Errorf("not boostrapped yet")
Expand All @@ -258,6 +259,9 @@ func getAnnotationInt32(ann map[string]string, key string, defaultValue int32) i
log.Printf("parsing annotation as int: %v", err)
return defaultValue
}

return int32(value)
if value > math.MaxInt32 {
alpe marked this conversation as resolved.
Show resolved Hide resolved
log.Printf("invalid value that exceeds max int32: %d", value)
return defaultValue
}
return int32(value) // #nosec G109 : checked before
}
6 changes: 3 additions & 3 deletions pkg/deployments/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type MetricsCollector struct {
manager *Manager
}

// NewMetricsCollector constructor
// NewMetricsCollector constructor.
func NewMetricsCollector(m *Manager) *MetricsCollector {
if m == nil {
panic("manager required")
Expand All @@ -22,12 +22,12 @@ func NewMetricsCollector(m *Manager) *MetricsCollector {
}
}

// MustRegister registers all metrics
// MustRegister registers all metrics.
func (p *MetricsCollector) MustRegister(r prometheus.Registerer) {
r.MustRegister(p)
}

// Describe sends the super-set of all possible descriptors of metrics
// Describe sends the super-set of all possible descriptors of metrics.
func (p *MetricsCollector) Describe(descs chan<- *prometheus.Desc) {
descs <- p.currentScaleDescr
descs <- p.minScaleDescr
Expand Down
36 changes: 18 additions & 18 deletions pkg/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,37 +93,37 @@ func (e *endpointGroup) getPort(portName string) int32 {
return e.ports[portName]
}

func (g *endpointGroup) lenIPs() int {
g.mtx.RLock()
defer g.mtx.RUnlock()
return len(g.endpoints)
func (e *endpointGroup) lenIPs() int {
e.mtx.RLock()
defer e.mtx.RUnlock()
return len(e.endpoints)
}

func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32) {
g.mtx.Lock()
g.ports = ports
func (e *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32) {
e.mtx.Lock()
e.ports = ports
for ip := range ips {
if _, ok := g.endpoints[ip]; !ok {
g.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}}
if _, ok := e.endpoints[ip]; !ok {
e.endpoints[ip] = endpoint{inFlight: &atomic.Int64{}}
}
}
for ip := range g.endpoints {
for ip := range e.endpoints {
if _, ok := ips[ip]; !ok {
delete(g.endpoints, ip)
delete(e.endpoints, ip)
}
}
g.mtx.Unlock()
e.mtx.Unlock()

// notify waiting requests
if len(ips) > 0 {
g.broadcastEndpoints()
e.broadcastEndpoints()
}
}

func (g *endpointGroup) broadcastEndpoints() {
g.bmtx.Lock()
defer g.bmtx.Unlock()
func (e *endpointGroup) broadcastEndpoints() {
e.bmtx.Lock()
defer e.bmtx.Unlock()

close(g.bcast)
g.bcast = make(chan struct{})
close(e.bcast)
e.bcast = make(chan struct{})
}
1 change: 0 additions & 1 deletion pkg/endpoints/endpoints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"k8s.io/apimachinery/pkg/util/rand"
)

Expand Down
3 changes: 1 addition & 2 deletions pkg/endpoints/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"sync"

disv1 "k8s.io/api/discovery/v1"

ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/client"
)
Expand Down Expand Up @@ -113,7 +112,7 @@ func (r *Manager) getEndpoints(service string) *endpointGroup {
// AwaitHostAddress returns the host address with the lowest number of in-flight requests. It will block until the host address
// becomes available or the context times out.
//
// It returns a string in the format "host:port" or error on timeout
// It returns a string in the format "host:port" or error on timeout.
func (r *Manager) AwaitHostAddress(ctx context.Context, service, portName string) (string, error) {
return r.getEndpoints(service).getBestHost(ctx, portName)
}
Expand Down
8 changes: 4 additions & 4 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
modelName = "unknown"
log.Printf("error reading model from request body: %v", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Bad request: unable to parse .model from JSON payload"))
_, _ = w.Write([]byte("Bad request: unable to parse .model from JSON payload"))
nstogner marked this conversation as resolved.
Show resolved Hide resolved
return
}
log.Println("model:", modelName)
Expand All @@ -65,7 +65,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !found {
log.Printf("deployment not found for model: %v", err)
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
_, _ = w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
return
}

Expand All @@ -91,7 +91,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch {
case errors.Is(err, context.Canceled):
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("Request cancelled"))
_, _ = w.Write([]byte("Request canceled"))
return
case errors.Is(err, context.DeadlineExceeded):
w.WriteHeader(http.StatusGatewayTimeout)
Expand All @@ -112,7 +112,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// parseModel parses the model name from the request
// returns empty string when none found or an error for failures on the proxy request object
// returns empty string when none found or an error for failures on the proxy request object.
func parseModel(r *http.Request) (string, *http.Request, error) {
if model := r.Header.Get("X-Model"); model != "" {
return model, r, nil
Expand Down
4 changes: 2 additions & 2 deletions pkg/proxy/metrics_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ import (
"github.com/prometheus/client_model/go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/substratusai/lingo/pkg/deployments"
"k8s.io/apimachinery/pkg/runtime"
utilruntime "k8s.io/apimachinery/pkg/util/runtime"
clientgoscheme "k8s.io/client-go/kubernetes/scheme"

ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/cache"
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/config"
"sigs.k8s.io/controller-runtime/pkg/manager"

"github.com/substratusai/lingo/pkg/deployments"
)

func TestMetrics(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions pkg/queue/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type MetricsCollector struct {
manager *Manager
}

// NewMetricsCollector constructor
// NewMetricsCollector constructor.
func NewMetricsCollector(m *Manager) *MetricsCollector {
if m == nil {
panic("manager required")
Expand All @@ -20,12 +20,12 @@ func NewMetricsCollector(m *Manager) *MetricsCollector {
}
}

// MustRegister registers all metrics
// MustRegister registers all metrics.
func (p *MetricsCollector) MustRegister(r prometheus.Registerer) {
r.MustRegister(p)
}

// Describe sends the super-set of all possible descriptors of metrics
// Describe sends the super-set of all possible descriptors of metrics.
func (p *MetricsCollector) Describe(descs chan<- *prometheus.Desc) {
descs <- p.inFlightDescr
descs <- p.queuedDescr
Expand Down
6 changes: 5 additions & 1 deletion pkg/queue/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package queue
import (
"container/list"
"context"
"fmt"
"log"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -152,7 +153,10 @@ func (q *Queue) Start() {
continue
}

itm := e.Value.(*item)
nstogner marked this conversation as resolved.
Show resolved Hide resolved
itm, ok := e.Value.(*item)
if !ok {
panic(fmt.Sprintf("invalid type: %T", e.Value))
}
q.dequeue(itm, true)
log.Println("Dequeued: ", itm.id)

Expand Down
Loading
Loading