Skip to content

Commit

Permalink
Move common pieces of mock REST servers to a new package (#267)
Browse files Browse the repository at this point in the history
The store mock server is structurally very close to the contracts mock:
YAML configuration, similar CLI, both talk via REST, both may have
blocked or disabled endpoints etc.

Thus I though that I could move the guts of the contracts mock into a
new package and make both mock servers inherit from this common base.

No semantics or endpoint behaviors should have changed with this.
  • Loading branch information
CarlosNihelton authored Sep 12, 2023
2 parents 5ca9c4a + b2a6b11 commit 8ac5850
Show file tree
Hide file tree
Showing 6 changed files with 358 additions and 271 deletions.
156 changes: 12 additions & 144 deletions mocks/contractserver/contracts/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,158 +2,26 @@
package main

import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"strconv"

"github.com/canonical/ubuntu-pro-for-windows/mocks/contractserver/contractsmockserver"
"github.com/spf13/cobra"
"gopkg.in/yaml.v3"
"github.com/canonical/ubuntu-pro-for-windows/mocks/restserver"
)

func main() {
rootCmd := rootCmd

rootCmd.AddCommand(defaultsCmd)

rootCmd.PersistentFlags().CountP("verbosity", "v", "WARNING (-v) INFO (-vv), DEBUG (-vvv)")
rootCmd.PersistentFlags().StringP("output", "o", "", "File where relevant non-log output will be written to")
rootCmd.Flags().StringP("address", "a", "", "Overrides the address where the server will be hosted")

if err := rootCmd.Execute(); err != nil {
slog.Error(fmt.Sprintf("Error executing: %v", err))
os.Exit(1)
}

os.Exit(0)
func serverFactory(settings restserver.Settings) restserver.Server {
//nolint:forcetypeassert // Let the type coersion panic on failure.
return contractsmockserver.NewServer(settings.(contractsmockserver.Settings))
}

// setVerboseMode changes the verbosity of the logs.
func setVerboseMode(n int) {
var level slog.Level
switch n {
case 0:
level = slog.LevelError
case 1:
level = slog.LevelWarn
case 2:
level = slog.LevelInfo
default:
level = slog.LevelDebug
}

h := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{Level: level})
slog.SetDefault(slog.New(h))
}
func main() {
defaultSettings := contractsmockserver.DefaultSettings()

func execName() string {
exe, err := os.Executable()
if err != nil {
slog.Error(fmt.Sprintf("Could not get executable name: %v", err))
os.Exit(1)
app := restserver.App{
Name: "contract server",
Description: "contract server",
DefaultSettings: &defaultSettings,
ServerFactory: serverFactory,
}

return filepath.Base(exe)
}

var defaultsCmd = &cobra.Command{
Use: "show-defaults",
Short: "See the default values for the contract server",
Long: "See the default values for the contract server. These are the settings that 'serve' will use unless overridden.",
Args: cobra.ExactArgs(0),
Run: func(cmd *cobra.Command, args []string) {
out, err := yaml.Marshal(contractsmockserver.DefaultSettings())
if err != nil {
slog.Error(fmt.Sprintf("Could not marshal default settings: %v", err))
os.Exit(1)
}

if outfile := cmd.Flag("output").Value.String(); outfile != "" {
if err := os.WriteFile(outfile, out, 0600); err != nil {
slog.Error(fmt.Sprintf("Could not write to output file: %v", err))
os.Exit(1)
}
return
}

fmt.Println(string(out))
},
}

var rootCmd = &cobra.Command{
Use: fmt.Sprintf("%s [settings_file]", execName()),
Short: "A mock contract server for Ubuntu Pro For Windows testing",
Long: `A mock contract server for Ubuntu Pro For Windows testing.
Serve the mock contract server with the optional settings file.
Default settings will be used if none are provided.
The outfile, if provided, will contain the address.`,
Args: cobra.RangeArgs(0, 1),
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
// Force a visit of the local flags so persistent flags for all parents are merged.
cmd.LocalFlags()

// command parsing has been successful. Returns to not print usage anymore.
cmd.SilenceUsage = true

v := cmd.Flag("verbosity").Value.String()
n, err := strconv.Atoi(v)
if err != nil {
return fmt.Errorf("could not parse verbosity: %v", err)
}

setVerboseMode(n)
return nil
},
Run: func(cmd *cobra.Command, args []string) {
ctx := context.Background()
settings := contractsmockserver.DefaultSettings()

if len(args) > 0 {
out, err := os.ReadFile(args[0])
if err != nil {
slog.Error(fmt.Sprintf("Could not read input file %q: %v", args[0], err))
os.Exit(1)
}

if err := yaml.Unmarshal(out, &settings); err != nil {
slog.Error(fmt.Sprintf("Could not unmarshal settings: %v", err))
os.Exit(1)
}
}

if addr := cmd.Flag("address").Value.String(); addr != "" {
settings.Address = addr
}

sv := contractsmockserver.NewServer(settings)
addr, err := sv.Serve(ctx)
if err != nil {
slog.Error(fmt.Sprintf("Could not serve: %v", err))
os.Exit(1)
}

defer func() {
if err := sv.Stop(); err != nil {
slog.Error(fmt.Sprintf("stopped serving: %v", err))
}
slog.Info("stopped serving")
}()

if outfile := cmd.Flag("output").Value.String(); outfile != "" {
if err := os.WriteFile(outfile, []byte(addr), 0600); err != nil {
slog.Error(fmt.Sprintf("Could not write output file: %v", err))
os.Exit(1)
}
}

slog.Info(fmt.Sprintf("Serving on address %s", addr))

// Wait loop
for scanned := ""; scanned != "exit"; fmt.Scanf("%s\n", &scanned) {
fmt.Println("Write 'exit' to stop serving")
}
},
os.Exit(app.Run())
}
139 changes: 15 additions & 124 deletions mocks/contractserver/contractsmockserver/contractsmockserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
package contractsmockserver

import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"path"
"sync"
"time"

"github.com/canonical/ubuntu-pro-for-windows/contractsapi"
"github.com/canonical/ubuntu-pro-for-windows/mocks/restserver"
)

const (
Expand All @@ -28,120 +23,44 @@ const (

// Server is a mock of the contract server, where its behaviour can be modified.
type Server struct {
restserver.ServerBase
settings Settings

server *http.Server
mu sync.RWMutex

done chan struct{}
}

// Settings contains the parameters for the Server.
type Settings struct {
Token Endpoint
Subscription Endpoint
Address string
}

// Endpoint contains settings for an API endpoint behaviour. Can be modified for testing purposes.
type Endpoint struct {
// OnSuccess is the response returned in the happy path.
OnSuccess Response

// Disabled disables the endpoint.
Disabled bool

// Blocked means that a response will not be sent back, instead it'll block until the server is stopped.
Blocked bool
}

// Response contains settings for an API endpoint response behaviour. Can be modified for testing purposes.
type Response struct {
Value string
Status int
Token restserver.Endpoint
Subscription restserver.Endpoint
}

// DefaultSettings returns the default set of settings for the server.
func DefaultSettings() Settings {
return Settings{
Token: Endpoint{OnSuccess: Response{Value: DefaultADToken, Status: http.StatusOK}},
Subscription: Endpoint{OnSuccess: Response{Value: DefaultProToken, Status: http.StatusOK}},
Address: "localhost:0",
Token: restserver.Endpoint{OnSuccess: restserver.Response{Value: DefaultADToken, Status: http.StatusOK}},
Subscription: restserver.Endpoint{OnSuccess: restserver.Response{Value: DefaultProToken, Status: http.StatusOK}},
}
}

// NewServer creates a new contract server with the provided settings.
func NewServer(s Settings) *Server {
return &Server{
settings: s,
}
}

// Stop stops the server.
func (s *Server) Stop() error {
s.mu.Lock()
defer s.mu.Unlock()

if s.server == nil {
return errors.New("already stopped")
}

err := s.server.Close()
<-s.done

s.server = nil

return err
}

// Serve starts a new HTTP server mocking the Contracts Server backend REST API with
// responses defined according to Server Settings. Use Stop to Stop the server and
// release resources.
func (s *Server) Serve(ctx context.Context) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()

if s.server != nil {
return "", errors.New("already serving")
}

var lc net.ListenConfig
lis, err := lc.Listen(ctx, "tcp", s.settings.Address)
if err != nil {
return "", fmt.Errorf("failed to listen over tcp: %v", err)
}

sv := &Server{settings: s}
mux := http.NewServeMux()

if !s.settings.Token.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.TokenPath), s.handleToken)
if !s.Token.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.TokenPath), sv.handleToken)
}

if !s.settings.Subscription.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.SubscriptionPath), s.handleSubscription)
if !s.Subscription.Disabled {
mux.HandleFunc(path.Join(contractsapi.Version, contractsapi.SubscriptionPath), sv.handleSubscription)
}
sv.Mux = mux

s.server = &http.Server{
Addr: lis.Addr().String(),
Handler: mux,
ReadHeaderTimeout: 3 * time.Second,
}

s.done = make(chan struct{})

go func() {
defer close(s.done)
if err := s.server.Serve(lis); err != nil && err != http.ErrServerClosed {
slog.Error("Failed to start the HTTP server", "error", err)
}
}()

return lis.Addr().String(), nil
return sv
}

// handleToken implements the /token endpoint.
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
if err := s.handle(w, r, http.MethodGet, s.settings.Token); err != nil {
if err := s.ValidateRequest(w, r, http.MethodGet, s.settings.Token); err != nil {
fmt.Fprintf(w, "%v", err)
return
}
Expand All @@ -155,7 +74,7 @@ func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {

// handleSubscription implements the /susbcription endpoint.
func (s *Server) handleSubscription(w http.ResponseWriter, r *http.Request) {
if err := s.handle(w, r, http.MethodPost, s.settings.Subscription); err != nil {
if err := s.ValidateRequest(w, r, http.MethodPost, s.settings.Subscription); err != nil {
fmt.Fprintf(w, "%v", err)
return
}
Expand Down Expand Up @@ -186,31 +105,3 @@ func (s *Server) handleSubscription(w http.ResponseWriter, r *http.Request) {
return
}
}

// handle extracts common boilerplate from endpoints.
func (s *Server) handle(w http.ResponseWriter, r *http.Request, wantMethod string, endpoint Endpoint) (err error) {
slog.Info("Received request", "endpoint", r.URL.Path, "method", r.Method)
defer func() {
if err != nil {
slog.Error("bad request", "error", err, "endpoint", r.URL.Path, "method", r.Method)
}
}()

if r.Method != wantMethod {
w.WriteHeader(http.StatusBadRequest)
return fmt.Errorf("this endpoint only supports %s", wantMethod)
}

if endpoint.Blocked {
<-s.done
slog.Debug("Server context was cancelled. Exiting", "endpoint", r.URL.Path)
return errors.New("server stopped")
}

if endpoint.OnSuccess.Status != 200 {
w.WriteHeader(endpoint.OnSuccess.Status)
return fmt.Errorf("mock error: %d", endpoint.OnSuccess.Status)
}

return nil
}
Loading

0 comments on commit 8ac5850

Please sign in to comment.