Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move common pieces of mock REST servers to a new package #267

Merged
merged 15 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 12 additions & 144 deletions mocks/contractserver/contracts/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,158 +2,26 @@
package main

import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"strconv"

"github.com/canonical/ubuntu-pro-for-windows/mocks/contractserver/contractsmockserver"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"github.com/canonical/ubuntu-pro-for-windows/mocks/restserver"
)

func main() {
rootCmd := rootCmd

rootCmd.AddCommand(defaultsCmd)

rootCmd.PersistentFlags().CountP("verbosity", "v", "WARNING (-v) INFO (-vv), DEBUG (-vvv)")
rootCmd.PersistentFlags().StringP("output", "o", "", "File where relevant non-log output will be written to")
rootCmd.Flags().StringP("address", "a", "", "Overrides the address where the server will be hosted")

if err := rootCmd.Execute(); err != nil {
slog.Error(fmt.Sprintf("Error executing: %v", err))
os.Exit(1)
}

os.Exit(0)
func serverFactory(settings restserver.Settings) restserver.Server {
//nolint:forcetypeassert // Let the type coersion panic on failure.
EduardGomezEscandell marked this conversation as resolved.
Show resolved Hide resolved
return contractsmockserver.NewServer(settings.(contractsmockserver.Settings))
}

// setVerboseMode changes the verbosity of the logs.
func setVerboseMode(n int) {
var level slog.Level
switch n {
case 0:
level = slog.LevelError
case 1:
level = slog.LevelWarn
case 2:
level = slog.LevelInfo
default:
level = slog.LevelDebug
}

h := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level})
slog.SetDefault(slog.New(h))
}
func main() {
defaultSettings := contractsmockserver.DefaultSettings()

func execName() string {
exe, err := os.Executable()
if err != nil {
slog.Error(fmt.Sprintf("Could not get executable name: %v", err))
os.Exit(1)
app := restserver.App{
Name: "contract server",
Description: "contract server",
DefaultSettings: &defaultSettings,
ServerFactory: serverFactory,
}

return filepath.Base(exe)
}

var defaultsCmd = &cobra.Command{
Use: "show-defaults",
Short: "See the default values for the contract server",
Long: "See the default values for the contract server. These are the settings that 'serve' will use unless overridden.",
Args: cobra.ExactArgs(0),
Run: func(cmd *cobra.Command, args []string) {
out, err := yaml.Marshal(contractsmockserver.DefaultSettings())
if err != nil {
slog.Error(fmt.Sprintf("Could not marshal default settings: %v", err))
os.Exit(1)
}

if outfile := cmd.Flag("output").Value.String(); outfile != "" {
if err := os.WriteFile(outfile, out, 0600); err != nil {
slog.Error(fmt.Sprintf("Could not write to output file: %v", err))
os.Exit(1)
}
return
}

fmt.Println(string(out))
},
}

var rootCmd = &cobra.Command{
Use: fmt.Sprintf("%s [settings_file]", execName()),
Short: "A mock contract server for Ubuntu Pro For Windows testing",
Long: `A mock contract server for Ubuntu Pro For Windows testing.
Serve the mock contract server with the optional settings file.
Default settings will be used if none are provided.
The outfile, if provided, will contain the address.`,
Args: cobra.RangeArgs(0, 1),
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
// Force a visit of the local flags so persistent flags for all parents are merged.
cmd.LocalFlags()

// command parsing has been successful. Returns to not print usage anymore.
cmd.SilenceUsage = true

v := cmd.Flag("verbosity").Value.String()
n, err := strconv.Atoi(v)
if err != nil {
return fmt.Errorf("could not parse verbosity: %v", err)
}

setVerboseMode(n)
return nil
},
Run: func(cmd *cobra.Command, args []string) {
ctx := context.Background()
settings := contractsmockserver.DefaultSettings()

if len(args) > 0 {
out, err := os.ReadFile(args[0])
if err != nil {
slog.Error(fmt.Sprintf("Could not read input file %q: %v", args[0], err))
os.Exit(1)
}

if err := yaml.Unmarshal(out, &settings); err != nil {
slog.Error(fmt.Sprintf("Could not unmarshal settings: %v", err))
os.Exit(1)
}
}

if addr := cmd.Flag("address").Value.String(); addr != "" {
settings.Address = addr
}

sv := contractsmockserver.NewServer(settings)
addr, err := sv.Serve(ctx)
if err != nil {
slog.Error(fmt.Sprintf("Could not serve: %v", err))
os.Exit(1)
}

defer func() {
if err := sv.Stop(); err != nil {
slog.Error(fmt.Sprintf("stopped serving: %v", err))
}
slog.Info("stopped serving")
}()

if outfile := cmd.Flag("output").Value.String(); outfile != "" {
if err := os.WriteFile(outfile, []byte(addr), 0600); err != nil {
slog.Error(fmt.Sprintf("Could not write output file: %v", err))
os.Exit(1)
}
}

slog.Info(fmt.Sprintf("Serving on address %s", addr))

// Wait loop
for scanned := ""; scanned != "exit"; fmt.Scanf("%s\n", &scanned) {
fmt.Println("Write 'exit' to stop serving")
}
},
os.Exit(app.Run())
}
139 changes: 15 additions & 124 deletions mocks/contractserver/contractsmockserver/contractsmockserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
package contractsmockserver

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"path"
"sync"
"time"

"github.com/canonical/ubuntu-pro-for-windows/contractsapi"
"github.com/canonical/ubuntu-pro-for-windows/mocks/restserver"
)

const (
Expand All @@ -28,120 +23,44 @@ const (

// Server is a mock of the contract server, where its behaviour can be modified.
type Server struct {
restserver.ServerBase
settings Settings

server *http.Server
mu sync.RWMutex

done chan struct{}
}

// Settings contains the parameters for the Server.
type Settings struct {
Token Endpoint
Subscription Endpoint
Address string
}

// Endpoint contains settings for an API endpoint behaviour. Can be modified for testing purposes.
type Endpoint struct {
// OnSuccess is the response returned in the happy path.
OnSuccess Response

// Disabled disables the endpoint.
Disabled bool

// Blocked means that a response will not be sent back, instead it'll block until the server is stopped.
Blocked bool
}

// Response contains settings for an API endpoint response behaviour. Can be modified for testing purposes.
type Response struct {
Value string
Status int
Token restserver.Endpoint
Subscription restserver.Endpoint
}

// DefaultSettings returns the default set of settings for the server.
func DefaultSettings() Settings {
EduardGomezEscandell marked this conversation as resolved.
Show resolved Hide resolved
return Settings{
Token: Endpoint{OnSuccess: Response{Value: DefaultADToken, Status: http.StatusOK}},
Subscription: Endpoint{OnSuccess: Response{Value: DefaultProToken, Status: http.StatusOK}},
Address: "localhost:0",
Token: restserver.Endpoint{OnSuccess: restserver.Response{Value: DefaultADToken, Status: http.StatusOK}},
Subscription: restserver.Endpoint{OnSuccess: restserver.Response{Value: DefaultProToken, Status: http.StatusOK}},
}
}

// NewServer creates a new contract server with the provided settings.
func NewServer(s Settings) *Server {
return &Server{
settings: s,
}
}

// Stop stops the server.
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()

if s.server == nil {
return errors.New("already stopped")
}

err := s.server.Close()
<-s.done

s.server = nil

return err
}

// Serve starts a new HTTP server mocking the Contracts Server backend REST API with
// responses defined according to Server Settings. Use Stop to Stop the server and
// release resources.
func (s *Server) Serve(ctx context.Context) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()

if s.server != nil {
return "", errors.New("already serving")
}

var lc net.ListenConfig
lis, err := lc.Listen(ctx, "tcp", s.settings.Address)
if err != nil {
return "", fmt.Errorf("failed to listen over tcp: %v", err)
}

sv := &Server{settings: s}
mux := http.NewServeMux()

if !s.settings.Token.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.TokenPath), s.handleToken)
if !s.Token.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.TokenPath), sv.handleToken)
}

if !s.settings.Subscription.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.SubscriptionPath), s.handleSubscription)
if !s.Subscription.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.SubscriptionPath), sv.handleSubscription)
}
sv.Mux = mux

s.server = &http.Server{
Addr: lis.Addr().String(),
Handler: mux,
ReadHeaderTimeout: 3 * time.Second,
}

s.done = make(chan struct{})

go func() {
defer close(s.done)
if err := s.server.Serve(lis); err != nil && err != http.ErrServerClosed {
slog.Error("Failed to start the HTTP server", "error", err)
}
}()

return lis.Addr().String(), nil
return sv
}

// handleToken implements the /token endpoint.
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
if err := s.handle(w, r, http.MethodGet, s.settings.Token); err != nil {
if err := s.ValidateRequest(w, r, http.MethodGet, s.settings.Token); err != nil {
fmt.Fprintf(w, "%v", err)
return
}
Expand All @@ -155,7 +74,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {

// handleSubscription implements the /susbcription endpoint.
func (s *Server) handleSubscription(w http.ResponseWriter, r *http.Request) {
if err := s.handle(w, r, http.MethodPost, s.settings.Subscription); err != nil {
if err := s.ValidateRequest(w, r, http.MethodPost, s.settings.Subscription); err != nil {
fmt.Fprintf(w, "%v", err)
return
}
Expand Down Expand Up @@ -186,31 +105,3 @@ func (s *Server) handleSubscription(w http.ResponseWriter, r *http.Request) {
return
}
}

// handle extracts common boilerplate from endpoints.
func (s *Server) handle(w http.ResponseWriter, r *http.Request, wantMethod string, endpoint Endpoint) (err error) {
slog.Info("Received request", "endpoint", r.URL.Path, "method", r.Method)
defer func() {
if err != nil {
slog.Error("bad request", "error", err, "endpoint", r.URL.Path, "method", r.Method)
}
}()

if r.Method != wantMethod {
w.WriteHeader(http.StatusBadRequest)
return fmt.Errorf("this endpoint only supports %s", wantMethod)
}

if endpoint.Blocked {
<-s.done
slog.Debug("Server context was cancelled. Exiting", "endpoint", r.URL.Path)
return errors.New("server stopped")
}

if endpoint.OnSuccess.Status != 200 {
w.WriteHeader(endpoint.OnSuccess.Status)
return fmt.Errorf("mock error: %d", endpoint.OnSuccess.Status)
}

return nil
}
Loading