Skip to content

Commit

Permalink
Add a very basic gin test example
Browse files Browse the repository at this point in the history
where we get the index page and check it contains
a particular post in the mock database.

To complete this, I had to mock the database
by turning the previous declaration into
an interface, and moving the concrete types
into SqlDatabase. The mock is done in DatabaseMock.
  • Loading branch information
matheusgomes28 committed Feb 24, 2024
1 parent 31a88c9 commit 2b8b006
Show file tree
Hide file tree
Showing 11 changed files with 130 additions and 63 deletions.
28 changes: 10 additions & 18 deletions admin-app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package admin_app

import (
"encoding/json"
"fmt"
"net/http"
"strconv"

Expand Down Expand Up @@ -33,7 +32,7 @@ type DeletePostRequest struct {
Id int `json:"id"`
}

func getPostHandler(database *database.Database) func(*gin.Context) {
func getPostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
// localhost:8080/post/{id}
var post_binding PostBinding
Expand Down Expand Up @@ -73,7 +72,7 @@ func getPostHandler(database *database.Database) func(*gin.Context) {
}
}

func postPostHandler(database *database.Database) func(*gin.Context) {
func postPostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
var add_post_request AddPostRequest
decoder := json.NewDecoder(c.Request.Body)
Expand Down Expand Up @@ -108,7 +107,7 @@ func postPostHandler(database *database.Database) func(*gin.Context) {
}
}

func putPostHandler(database *database.Database) func(*gin.Context) {
func putPostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
var change_post_request ChangePostRequest
decoder := json.NewDecoder(c.Request.Body)
Expand Down Expand Up @@ -145,7 +144,7 @@ func putPostHandler(database *database.Database) func(*gin.Context) {
}
}

func deletePostHandler(database *database.Database) func(*gin.Context) {
func deletePostHandler(database database.Database) func(*gin.Context) {
return func(c *gin.Context) {
var delete_post_request DeletePostRequest
decoder := json.NewDecoder(c.Request.Body)
Expand Down Expand Up @@ -177,21 +176,14 @@ func deletePostHandler(database *database.Database) func(*gin.Context) {
}
}

func Run(app_settings common.AppSettings, database database.Database) error {
func SetupRoutes(app_settings common.AppSettings, database database.Database) *gin.Engine{

r := gin.Default()
r.MaxMultipartMemory = 1

r.GET("/posts/:id", getPostHandler(&database))
r.POST("/posts", postPostHandler(&database))
r.PUT("/posts", putPostHandler(&database))
r.DELETE("/posts", deletePostHandler(&database))

err := r.Run(fmt.Sprintf(":%s", app_settings.WebserverPort))
if err != nil {
log.Error().Msgf("could not run app: %v", err)
return err
}

return nil
r.GET("/posts/:id", getPostHandler(database))
r.POST("/posts", postPostHandler(database))
r.PUT("/posts", putPostHandler(database))
r.DELETE("/posts", deletePostHandler(database))
return r
}
21 changes: 7 additions & 14 deletions app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package app

import (
"bytes"
"fmt"
"net/http"
"time"

Expand All @@ -15,9 +14,9 @@ import (

const CACHE_TIMEOUT = 20 * time.Second

type Generator = func(*gin.Context, common.AppSettings, *database.Database) ([]byte, error)
type Generator = func(*gin.Context, common.AppSettings, database.Database) ([]byte, error)

func Run(app_settings common.AppSettings, database *database.Database) error {
func SetupRoutes(app_settings common.AppSettings, database database.Database) *gin.Engine{
r := gin.Default()
r.MaxMultipartMemory = 1

Expand All @@ -31,16 +30,10 @@ func Run(app_settings common.AppSettings, database *database.Database) error {
r.POST("/contact-send", makeContactFormHandler())

r.Static("/static", "./static")
err := r.Run(fmt.Sprintf(":%s", app_settings.WebserverPort))
if err != nil {
log.Error().Msgf("could not run app: %v", err)
return err
}

return nil
return r
}

func addCachableHandler(e *gin.Engine, method string, endpoint string, generator Generator, cache *Cache, app_settings common.AppSettings, db *database.Database) {
func addCachableHandler(e *gin.Engine, method string, endpoint string, generator Generator, cache *Cache, app_settings common.AppSettings, db database.Database) {

handler := func(c *gin.Context) {
// if the endpoint is cached
Expand Down Expand Up @@ -79,9 +72,9 @@ func addCachableHandler(e *gin.Engine, method string, endpoint string, generator
}
}

// / This function will act as the handler for
// / the home page
func homeHandler(c *gin.Context, settings common.AppSettings, db *database.Database) ([]byte, error) {
/// This function will act as the handler for
/// the home page
func homeHandler(c *gin.Context, settings common.AppSettings, db database.Database) ([]byte, error) {
posts, err := db.GetPosts()
if err != nil {
return nil, err
Expand Down
2 changes: 1 addition & 1 deletion app/contact.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func makeContactFormHandler() func(*gin.Context) {
}

// TODO : This is a duplicate of the index handler... abstract
func contactHandler(c *gin.Context, app_settings common.AppSettings, db *database.Database) ([]byte, error) {
func contactHandler(c *gin.Context, app_settings common.AppSettings, db database.Database) ([]byte, error) {
index_view := views.MakeContactPage()
html_buffer := bytes.NewBuffer(nil)
if err := index_view.Render(c, html_buffer); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion app/post.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func mdToHTML(md []byte) []byte {
return markdown.Render(doc, renderer)
}

func postHandler(c *gin.Context, app_settings common.AppSettings, database *database.Database) ([]byte, error) {
func postHandler(c *gin.Context, app_settings common.AppSettings, database database.Database) ([]byte, error) {
// localhost:8080/post/{id}

var post_binding PostBinding
Expand Down
8 changes: 4 additions & 4 deletions cmd/urchin-admin/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"fmt"

_ "github.com/go-sql-driver/mysql"
admin_app "github.com/matheusgomes28/urchin/admin-app"
"github.com/matheusgomes28/urchin/common"
Expand Down Expand Up @@ -30,8 +32,6 @@ func main() {
log.Fatal().Msgf("could not create database: %v", err)
}

err = admin_app.Run(app_settings, database)
if err != nil {
log.Fatal().Msgf("could not run app: %v", err)
}
r := admin_app.SetupRoutes(app_settings, database)
r.Run(fmt.Sprintf(":%d", app_settings.WebserverPort))
}
64 changes: 64 additions & 0 deletions cmd/urchin/index_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package main

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/matheusgomes28/urchin/app"
"github.com/matheusgomes28/urchin/common"

"github.com/stretchr/testify/assert"
)

type DatabaseMock struct {}

func (db DatabaseMock) GetPosts() ([]common.Post, error) {
return []common.Post{
{
Title: "TestPost",
Content: "TestContent",
Excerpt: "TestExcerpt",
Id: 0,
},
}, nil
}

func (db DatabaseMock) GetPost(post_id int) (common.Post, error) {
return common.Post{}, fmt.Errorf("not implemented")
}

func (db DatabaseMock) AddPost(title string, excerpt string, content string) (int, error) {
return 0, fmt.Errorf("not implemented")
}

func (db DatabaseMock) ChangePost(id int, title string, excerpt string, content string) error {
return nil
}

func (db DatabaseMock) DeletePost(id int) error {
return fmt.Errorf("not implemented")
}

func TestIndexPing(t *testing.T) {
app_settings := common.AppSettings{
DatabaseAddress: "localhost",
DatabasePort: 3006,
DatabaseUser: "root",
DatabasePassword: "root",
DatabaseName: "urchin",
WebserverPort: 8080,
}

database_mock := DatabaseMock{}

r := app.SetupRoutes(app_settings, database_mock)
w := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
r.ServeHTTP(w, req)

assert.Equal(t, 200, w.Code)
assert.Contains(t, w.Body.String(), "TestPost")
assert.Contains(t, w.Body.String(), "TestExcerpt")
}
7 changes: 4 additions & 3 deletions cmd/urchin/main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"fmt"

_ "github.com/go-sql-driver/mysql"
"github.com/matheusgomes28/urchin/app"
"github.com/matheusgomes28/urchin/common"
Expand All @@ -27,7 +29,6 @@ func main() {
log.Error().Msgf("could not create database connection: %v", err)
}

if err = app.Run(app_settings, &db_connection); err != nil {
log.Error().Msgf("could not run app: %v", err)
}
r := app.SetupRoutes(app_settings, db_connection)
r.Run(fmt.Sprintf(":%d", app_settings.WebserverPort))
}
13 changes: 9 additions & 4 deletions common/app_settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ type AppSettings struct {
DatabasePort int
DatabaseUser string
DatabasePassword string
DatabaseName string
WebserverPort string
DatabaseName string
WebserverPort int
}

func LoadSettings() (AppSettings, error) {
Expand Down Expand Up @@ -47,11 +47,16 @@ func LoadSettings() (AppSettings, error) {
return AppSettings{}, fmt.Errorf("URCHIN_DATABASE_PORT is not a valid integer: %v", err)
}

webserver_port := os.Getenv("URCHIN_WEBSERVER_PORT")
if webserver_port == "" {
webserver_port_str := os.Getenv("URCHIN_WEBSERVER_PORT")
if webserver_port_str == "" {
return AppSettings{}, fmt.Errorf("URCHIN_WEBSERVER_PORT is not defined")
}

webserver_port, err := strconv.Atoi(webserver_port_str)
if err != nil {
return AppSettings{}, fmt.Errorf("URCHIN_WEBSERVER_PORT is not valid: %v", err)
}

return AppSettings{
DatabaseUser: database_user,
DatabasePassword: database_password,
Expand Down
40 changes: 24 additions & 16 deletions database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,15 @@ import (
"github.com/rs/zerolog/log"
)

type Database struct {
type Database interface {
GetPosts() ([]common.Post, error)
GetPost(post_id int) (common.Post, error)
AddPost(title string, excerpt string, content string) (int, error)
ChangePost(id int, title string, excerpt string, content string) error
DeletePost(id int) error
}

type SqlDatabase struct {
Address string
Port int
User string
Expand All @@ -19,7 +27,7 @@ type Database struct {

// / This function gets all the posts from the current
// / database connection.
func (db Database) GetPosts() ([]common.Post, error) {
func (db SqlDatabase) GetPosts() ([]common.Post, error) {
rows, err := db.Connection.Query("SELECT title, excerpt, id FROM posts;")
if err != nil {
return make([]common.Post, 0), err
Expand All @@ -40,7 +48,7 @@ func (db Database) GetPosts() ([]common.Post, error) {

// / This function gets a post from the database
// / with the given ID.
func (db *Database) GetPost(post_id int) (common.Post, error) {
func (db SqlDatabase) GetPost(post_id int) (common.Post, error) {
rows, err := db.Connection.Query("SELECT title, content FROM posts WHERE id=?;", post_id)
if err != nil {
return common.Post{}, err
Expand All @@ -56,8 +64,8 @@ func (db *Database) GetPost(post_id int) (common.Post, error) {
return post, nil
}

// / This function adds a post to the database
func (db *Database) AddPost(title string, excerpt string, content string) (int, error) {
/// This function adds a post to the database
func (db SqlDatabase) AddPost(title string, excerpt string, content string) (int, error) {
res, err := db.Connection.Exec("INSERT INTO posts(content, title, excerpt) VALUES(?, ?, ?)", content, title, excerpt)
if err != nil {
return -1, err
Expand All @@ -75,10 +83,10 @@ func (db *Database) AddPost(title string, excerpt string, content string) (int,
return int(id), nil
}

// / This function changes a post based on the values
// / provided. Note that empty strings will mean that
// / the value will not be updated.
func (db *Database) ChangePost(id int, title string, excerpt string, content string) error {
/// This function changes a post based on the values
/// provided. Note that empty strings will mean that
/// the value will not be updated.
func (db SqlDatabase) ChangePost(id int, title string, excerpt string, content string) error {
tx, err := db.Connection.Begin()
if err != nil {
return err
Expand Down Expand Up @@ -115,31 +123,31 @@ func (db *Database) ChangePost(id int, title string, excerpt string, content str
return nil
}

// / This function changes a post based on the values
// / provided. Note that empty strings will mean that
// / the value will not be updated.
func (db *Database) DeletePost(id int) error {
/// This function changes a post based on the values
/// provided. Note that empty strings will mean that
/// the value will not be updated.
func (db SqlDatabase) DeletePost(id int) error {
if _, err := db.Connection.Exec("DELETE FROM posts WHERE id=?;", id); err != nil {
return err
}

return nil
}

func MakeSqlConnection(user string, password string, address string, port int, database string) (Database, error) {
func MakeSqlConnection(user string, password string, address string, port int, database string) (SqlDatabase, error) {
/// Checking the DB connection
/// TODO : let user specify the DB
connection_str := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, password, address, port, database)
db, err := sql.Open("mysql", connection_str)
if err != nil {
return Database{}, err
return SqlDatabase{}, err
}
// See "Important settings" section.
db.SetConnMaxLifetime(time.Second * 5)
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(10)

return Database{
return SqlDatabase{
Address: address,
Port: port,
User: user,
Expand Down
Loading

0 comments on commit 2b8b006

Please sign in to comment.