Skip to content

Commit

Permalink
Adding Simple Gin Unit Tests (#24)
Browse files Browse the repository at this point in the history
Add a very basic gin test example where we
get the index page and check it contains
a particular post in the mock database.

To complete this, I had to mock the database
by turning the previous declaration into
an interface, and moving the concrete types
into `SqlDatabase`. The mock is done in `DatabaseMock`.
  • Loading branch information
matheusgomes28 authored Feb 24, 2024
1 parent 31a88c9 commit 082e970
Show file tree
Hide file tree
Showing 15 changed files with 288 additions and 68 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ jobs:
build:
uses: ./.github/workflows/build.yml
needs: failfast
# tests:
# uses: ./.github/workflows/test.yml
# needs: build

tests:
uses: ./.github/workflows/test.yml
needs: build

# release:
# uses: ./.github/workflows/release.yml
# needs: tests
Expand Down
24 changes: 24 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
name: Tests

on: [workflow_call]

jobs:
linters:
name: Tests 🧪

runs-on: ubuntu-latest
container:
image: mattgomes28/urchin-golang:0.2
options: --user 1001

steps:
- uses: actions/checkout@v3

- name: Generating templ files
run: |
templ generate
shell: bash

- name: Running Go Tests 🧪
run: |
go test ./... -v
28 changes: 10 additions & 18 deletions admin-app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package admin_app

import (
"encoding/json"
"fmt"
"net/http"
"strconv"

Expand Down Expand Up @@ -33,7 +32,7 @@ type DeletePostRequest struct {
Id int `json:"id"`
}

func getPostHandler(database *database.Database) func(*gin.Context) {
func getPostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
// localhost:8080/post/{id}
var post_binding PostBinding
Expand Down Expand Up @@ -73,7 +72,7 @@ func getPostHandler(database *database.Database) func(*gin.Context) {
}
}

func postPostHandler(database *database.Database) func(*gin.Context) {
func postPostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
var add_post_request AddPostRequest
decoder := json.NewDecoder(c.Request.Body)
Expand Down Expand Up @@ -108,7 +107,7 @@ func postPostHandler(database *database.Database) func(*gin.Context) {
}
}

func putPostHandler(database *database.Database) func(*gin.Context) {
func putPostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
var change_post_request ChangePostRequest
decoder := json.NewDecoder(c.Request.Body)
Expand Down Expand Up @@ -145,7 +144,7 @@ func putPostHandler(database *database.Database) func(*gin.Context) {
}
}

func deletePostHandler(database *database.Database) func(*gin.Context) {
func deletePostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
var delete_post_request DeletePostRequest
decoder := json.NewDecoder(c.Request.Body)
Expand Down Expand Up @@ -177,21 +176,14 @@ func deletePostHandler(database *database.Database) func(*gin.Context) {
}
}

func Run(app_settings common.AppSettings, database database.Database) error {
func SetupRoutes(app_settings common.AppSettings, database database.Database) *gin.Engine {

r := gin.Default()
r.MaxMultipartMemory = 1

r.GET("/posts/:id", getPostHandler(&database))
r.POST("/posts", postPostHandler(&database))
r.PUT("/posts", putPostHandler(&database))
r.DELETE("/posts", deletePostHandler(&database))

err := r.Run(fmt.Sprintf(":%s", app_settings.WebserverPort))
if err != nil {
log.Error().Msgf("could not run app: %v", err)
return err
}

return nil
r.GET("/posts/:id", getPostHandler(database))
r.POST("/posts", postPostHandler(database))
r.PUT("/posts", putPostHandler(database))
r.DELETE("/posts", deletePostHandler(database))
return r
}
21 changes: 7 additions & 14 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package app

import (
"bytes"
"fmt"
"net/http"
"time"

Expand All @@ -15,9 +14,9 @@ import (

const CACHE_TIMEOUT = 20 * time.Second

type Generator = func(*gin.Context, common.AppSettings, *database.Database) ([]byte, error)
type Generator = func(*gin.Context, common.AppSettings, database.Database) ([]byte, error)

func Run(app_settings common.AppSettings, database *database.Database) error {
func SetupRoutes(app_settings common.AppSettings, database database.Database) *gin.Engine {
r := gin.Default()
r.MaxMultipartMemory = 1

Expand All @@ -31,20 +30,14 @@ func Run(app_settings common.AppSettings, database *database.Database) error {
r.POST("/contact-send", makeContactFormHandler())

r.Static("/static", "./static")
err := r.Run(fmt.Sprintf(":%s", app_settings.WebserverPort))
if err != nil {
log.Error().Msgf("could not run app: %v", err)
return err
}

return nil
return r
}

func addCachableHandler(e *gin.Engine, method string, endpoint string, generator Generator, cache *Cache, app_settings common.AppSettings, db *database.Database) {
func addCachableHandler(e *gin.Engine, method string, endpoint string, generator Generator, cache *Cache, app_settings common.AppSettings, db database.Database) {

handler := func(c *gin.Context) {
// if the endpoint is cached
cached_endpoint, err := cache.Get(c.Request.RequestURI)
cached_endpoint, err := (*cache).Get(c.Request.RequestURI)
if err == nil {
c.Data(http.StatusOK, "text/html; charset=utf-8", cached_endpoint.contents)
return
Expand All @@ -57,7 +50,7 @@ func addCachableHandler(e *gin.Engine, method string, endpoint string, generator
}

// After handler (add to cache)
err = cache.Store(c.Request.RequestURI, html_buffer)
err = (*cache).Store(c.Request.RequestURI, html_buffer)
if err != nil {
log.Warn().Msgf("could not add page to cache: %v", err)
}
Expand All @@ -81,7 +74,7 @@ func addCachableHandler(e *gin.Engine, method string, endpoint string, generator

// / This function will act as the handler for
// / the home page
func homeHandler(c *gin.Context, settings common.AppSettings, db *database.Database) ([]byte, error) {
func homeHandler(c *gin.Context, settings common.AppSettings, db database.Database) ([]byte, error) {
posts, err := db.GetPosts()
if err != nil {
return nil, err
Expand Down
45 changes: 32 additions & 13 deletions app/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,55 +22,74 @@ func emptyEndpointCache() EndpointCache {
return EndpointCache{"", []byte{}, time.Now()}
}

type Cache struct {
type Cache interface {
Get(name string) (EndpointCache, error)
Store(name string, buffer []byte) error
Size() uint64
}

type CacheValidator interface {
IsValid(cache *EndpointCache) bool
}

type TimeValidator struct{}

func (validator *TimeValidator) IsValid(cache *EndpointCache) bool {
// We only return the cache if it's still valid
return cache.validUntil.After(time.Now())
}

type TimedCache struct {
cacheMap shardedmap.ShardMap
cacheTimeout time.Duration
estimatedSize atomic.Uint64 // in bytes
validator CacheValidator
}

func (self *Cache) Store(name string, buffer []byte) error {
func (cache *TimedCache) Store(name string, buffer []byte) error {
// Only store to the cache if we have enough space left
afterSizeMB := float64(self.estimatedSize.Load()+uint64(len(buffer))) / 1000000
afterSizeMB := float64(cache.estimatedSize.Load()+uint64(len(buffer))) / 1000000
if afterSizeMB > MAX_CACHE_SIZE_MB {
return fmt.Errorf("maximum size reached")
}

var cache_entry interface{} = EndpointCache{
name: name,
contents: buffer,
validUntil: time.Now().Add(self.cacheTimeout),
validUntil: time.Now().Add(cache.cacheTimeout),
}
self.cacheMap.Set(name, &cache_entry)
self.estimatedSize.Add(uint64(len(buffer)))
cache.cacheMap.Set(name, &cache_entry)
cache.estimatedSize.Add(uint64(len(buffer)))
return nil
}

func (self *Cache) Get(name string) (EndpointCache, error) {
func (cache *TimedCache) Get(name string) (EndpointCache, error) {
// if the endpoint is cached
cached_entry := self.cacheMap.Get(name)
cached_entry := cache.cacheMap.Get(name)
if cached_entry != nil {
cache_contents := (*cached_entry).(EndpointCache)

// We only return the cache if it's still valid
if cache_contents.validUntil.After(time.Now()) {
if cache.validator.IsValid(&cache_contents) {
return cache_contents, nil
} else {
self.cacheMap.Delete(name)
cache.cacheMap.Delete(name)
return emptyEndpointCache(), fmt.Errorf("cached endpoint had expired")
}
}

return emptyEndpointCache(), fmt.Errorf("cache does not contain key")
}

func (self *Cache) Size() uint64 {
return self.estimatedSize.Load()
func (cache *TimedCache) Size() uint64 {
return cache.estimatedSize.Load()
}

func makeCache(n_shards int, expiry_duration time.Duration) Cache {
return Cache{
return &TimedCache{
cacheMap: shardedmap.NewShardMap(n_shards),
cacheTimeout: expiry_duration,
estimatedSize: atomic.Uint64{},
validator: &TimeValidator{},
}
}
95 changes: 95 additions & 0 deletions app/cache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package app

import (
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/assert"
shardedmap "github.com/zutto/shardedmap"
)

type TrueTimeMockValidator struct{}

func (validator *TrueTimeMockValidator) IsValid(cache *EndpointCache) bool {
return true
}

type FalseTimeMockValidator struct{}

func (validator *FalseTimeMockValidator) IsValid(cache *EndpointCache) bool {
return false
}

func makeTrueCacheMock() Cache {
return &TimedCache{
cacheMap: shardedmap.NewShardMap(2),
cacheTimeout: 10 * time.Second,
estimatedSize: atomic.Uint64{},
validator: &TrueTimeMockValidator{},
}
}

func makeFalseCacheMock() Cache {
return &TimedCache{
cacheMap: shardedmap.NewShardMap(2),
cacheTimeout: 10 * time.Second,
estimatedSize: atomic.Uint64{},
validator: &FalseTimeMockValidator{},
}
}

func TestCacheAddition(t *testing.T) {

test_data := []struct {
name string
contents []byte
}{
{"first", []byte("hello")},
{"second", []byte("the quick brown fox does some weird stuff")},
{"veryLonNameThatIsProbablyTooLong", []byte("Hello there my friends")},
{"nPe9Rkff6ER6EzAxPUIpxc8UBBLm71hhq2MO9hkQWisrfihUqv", []byte("oA7Hv1A7vOuZSKrPT4ZN5DGKNSHZqpLEvUA5hu54CMyIt8c78u")},
}

cache := makeTrueCacheMock()

rolling_size := uint64(0)
for _, test_case := range test_data {

rolling_size += uint64(len(test_case.contents))
err := cache.Store(test_case.name, test_case.contents)
assert.Nil(t, err)
assert.Equal(t, cache.Size(), rolling_size)

endpoint_cache, err := cache.Get(test_case.name)
assert.Nil(t, err)
assert.Equal(t, endpoint_cache.contents, test_case.contents)
}
}

func TestCacheFailure(t *testing.T) {

test_data := []struct {
name string
contents []byte
}{
{"first", []byte("hello")},
{"second", []byte("the quick brown fox does some weird stuff")},
{"veryLonNameThatIsProbablyTooLong", []byte("Hello there my friends")},
{"nPe9Rkff6ER6EzAxPUIpxc8UBBLm71hhq2MO9hkQWisrfihUqv", []byte("oA7Hv1A7vOuZSKrPT4ZN5DGKNSHZqpLEvUA5hu54CMyIt8c78u")},
}

cache := makeFalseCacheMock()

rolling_size := uint64(0)
for _, test_case := range test_data {

rolling_size += uint64(len(test_case.contents))
err := cache.Store(test_case.name, test_case.contents)
assert.Nil(t, err)
assert.Equal(t, cache.Size(), rolling_size)

_, err = cache.Get(test_case.name)
assert.NotNil(t, err)
}
}
2 changes: 1 addition & 1 deletion app/contact.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func makeContactFormHandler() func(*gin.Context) {
}

// TODO : This is a duplicate of the index handler... abstract
func contactHandler(c *gin.Context, app_settings common.AppSettings, db *database.Database) ([]byte, error) {
func contactHandler(c *gin.Context, app_settings common.AppSettings, db database.Database) ([]byte, error) {
index_view := views.MakeContactPage()
html_buffer := bytes.NewBuffer(nil)
if err := index_view.Render(c, html_buffer); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion app/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func mdToHTML(md []byte) []byte {
return markdown.Render(doc, renderer)
}

func postHandler(c *gin.Context, app_settings common.AppSettings, database *database.Database) ([]byte, error) {
func postHandler(c *gin.Context, app_settings common.AppSettings, database database.Database) ([]byte, error) {
// localhost:8080/post/{id}

var post_binding PostBinding
Expand Down
Loading

0 comments on commit 082e970

Please sign in to comment.