From b9d3e534f5c8d9498f7afe94852a4614096f8a0b Mon Sep 17 00:00:00 2001 From: matheusgomes28 Date: Tue, 12 Mar 2024 21:35:59 +0000 Subject: [PATCH 1/3] Adding more unit tests for endpoints (#46) Increase the coverage a little bit by testing endpoints --- admin-app/post.go | 2 +- app/app.go | 10 ++- app/cache.go | 18 ++--- cmd/urchin/index_test.go | 68 ------------------ .../endpoint_tests/endpoint_test.go | 68 ++++++++++++++++++ .../app_settings}/app_settings_test.go | 23 +++--- {app => tests/app_tests/cache}/cache_test.go | 44 +++++------- tests/app_tests/endpoint_tests/index_test.go | 71 +++++++++++++++++++ tests/app_tests/endpoint_tests/posts_test.go | 42 +++++++++++ tests/mocks/mocks.go | 36 ++++++++++ 10 files changed, 265 insertions(+), 117 deletions(-) delete mode 100644 cmd/urchin/index_test.go create mode 100644 tests/admin_app_tests/endpoint_tests/endpoint_test.go rename {common => tests/app_tests/app_settings}/app_settings_test.go (93%) rename {app => tests/app_tests/cache}/cache_test.go (67%) create mode 100644 tests/app_tests/endpoint_tests/index_test.go create mode 100644 tests/app_tests/endpoint_tests/posts_test.go create mode 100644 tests/mocks/mocks.go diff --git a/admin-app/post.go b/admin-app/post.go index 487e36f..c5952e8 100644 --- a/admin-app/post.go +++ b/admin-app/post.go @@ -57,7 +57,7 @@ func postPostHandler(database database.Database) func(*gin.Context) { err := decoder.Decode(&add_post_request) if err != nil { - log.Warn().Msgf("could not get post from DB: %v", err) + log.Warn().Msgf("invalid post request: %v", err) c.JSON(http.StatusBadRequest, gin.H{ "error": "invalid request body", "msg": err.Error(), diff --git a/app/app.go b/app/app.go index 1cf8664..c74adbf 100644 --- a/app/app.go +++ b/app/app.go @@ -21,7 +21,7 @@ func SetupRoutes(app_settings common.AppSettings, database database.Database) *g r.MaxMultipartMemory = 1 // All cache-able endpoints - cache := makeCache(4, time.Minute*10) + cache := MakeCache(4, time.Minute*10, &TimeValidator{}) addCachableHandler(r, "GET", "/", homeHandler, &cache, app_settings, database) addCachableHandler(r, "GET", "/contact", contactHandler, &cache, app_settings, database) addCachableHandler(r, "GET", "/post/:id", postHandler, &cache, app_settings, database) @@ -39,7 +39,7 @@ func addCachableHandler(e *gin.Engine, method string, endpoint string, generator // if the endpoint is cached cached_endpoint, err := (*cache).Get(c.Request.RequestURI) if err == nil { - c.Data(http.StatusOK, "text/html; charset=utf-8", cached_endpoint.contents) + c.Data(http.StatusOK, "text/html; charset=utf-8", cached_endpoint.Contents) return } @@ -47,6 +47,12 @@ func addCachableHandler(e *gin.Engine, method string, endpoint string, generator html_buffer, err := generator(c, app_settings, db) if err != nil { log.Error().Msgf("could not generate html: %v", err) + // TODO : Need a proper error page + c.JSON(http.StatusInternalServerError, gin.H{ + "error": "could not render HTML", + "msg": err.Error(), + }) + return } // After handler (add to cache) diff --git a/app/cache.go b/app/cache.go index bdd4a49..a84e584 100644 --- a/app/cache.go +++ b/app/cache.go @@ -13,9 +13,9 @@ import ( const MAX_CACHE_SIZE_MB = 10 type EndpointCache struct { - name string - contents []byte - validUntil time.Time + Name string + Contents []byte + ValidUntil time.Time } func emptyEndpointCache() EndpointCache { @@ -36,7 +36,7 @@ type TimeValidator struct{} func (validator *TimeValidator) IsValid(cache *EndpointCache) bool { // We only return the cache if it's still valid - return cache.validUntil.After(time.Now()) + return cache.ValidUntil.After(time.Now()) } type TimedCache struct { @@ -54,9 +54,9 @@ func (cache *TimedCache) Store(name string, buffer []byte) error { } var cache_entry interface{} = EndpointCache{ - name: name, - contents: buffer, - validUntil: time.Now().Add(cache.cacheTimeout), + Name: name, + Contents: buffer, + ValidUntil: time.Now().Add(cache.cacheTimeout), } cache.cacheMap.Set(name, &cache_entry) cache.estimatedSize.Add(uint64(len(buffer))) @@ -85,11 +85,11 @@ func (cache *TimedCache) Size() uint64 { return cache.estimatedSize.Load() } -func makeCache(n_shards int, expiry_duration time.Duration) Cache { +func MakeCache(n_shards int, expiry_duration time.Duration, validator CacheValidator) Cache { return &TimedCache{ cacheMap: shardedmap.NewShardMap(n_shards), cacheTimeout: expiry_duration, estimatedSize: atomic.Uint64{}, - validator: &TimeValidator{}, + validator: validator, } } diff --git a/cmd/urchin/index_test.go b/cmd/urchin/index_test.go deleted file mode 100644 index da165cb..0000000 --- a/cmd/urchin/index_test.go +++ /dev/null @@ -1,68 +0,0 @@ -package main - -import ( - "fmt" - "net/http" - "net/http/httptest" - "testing" - - "github.com/matheusgomes28/urchin/app" - "github.com/matheusgomes28/urchin/common" - - "github.com/stretchr/testify/assert" -) - -type DatabaseMock struct{} - -func (db DatabaseMock) GetPosts() ([]common.Post, error) { - return []common.Post{ - { - Title: "TestPost", - Content: "TestContent", - Excerpt: "TestExcerpt", - Id: 0, - }, - }, nil -} - -func (db DatabaseMock) GetPost(post_id int) (common.Post, error) { - return common.Post{}, fmt.Errorf("not implemented") -} - -func (db DatabaseMock) AddPost(title string, excerpt string, content string) (int, error) { - return 0, fmt.Errorf("not implemented") -} - -func (db DatabaseMock) ChangePost(id int, title string, excerpt string, content string) error { - return nil -} - -func (db DatabaseMock) DeletePost(id int) error { - return fmt.Errorf("not implemented") -} - -func (db DatabaseMock) AddImage(string, string, string) error { - return fmt.Errorf("not implemented") -} - -func TestIndexPing(t *testing.T) { - app_settings := common.AppSettings{ - DatabaseAddress: "localhost", - DatabasePort: 3006, - DatabaseUser: "root", - DatabasePassword: "root", - DatabaseName: "urchin", - WebserverPort: 8080, - } - - database_mock := DatabaseMock{} - - r := app.SetupRoutes(app_settings, database_mock) - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/", nil) - r.ServeHTTP(w, req) - - assert.Equal(t, 200, w.Code) - assert.Contains(t, w.Body.String(), "TestPost") - assert.Contains(t, w.Body.String(), "TestExcerpt") -} diff --git a/tests/admin_app_tests/endpoint_tests/endpoint_test.go b/tests/admin_app_tests/endpoint_tests/endpoint_test.go new file mode 100644 index 0000000..f2a7aa7 --- /dev/null +++ b/tests/admin_app_tests/endpoint_tests/endpoint_test.go @@ -0,0 +1,68 @@ +package endpoint_tests + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + admin_app "github.com/matheusgomes28/urchin/admin-app" + "github.com/matheusgomes28/urchin/common" + "github.com/matheusgomes28/urchin/tests/mocks" + "github.com/stretchr/testify/assert" +) + +type postRequest struct { + Title string `json:"title"` + Excerpt string `json:"excerpt"` + Content string `json:"content"` +} + +type postResponse struct { + Id int `json:"id"` +} + +var app_settings = common.AppSettings{ + DatabaseAddress: "localhost", + DatabasePort: 3006, + DatabaseUser: "root", + DatabasePassword: "root", + DatabaseName: "urchin", + WebserverPort: 8080, +} + +func TestIndexPing(t *testing.T) { + + database_mock := mocks.DatabaseMock{} + r := admin_app.SetupRoutes(app_settings, database_mock) + w := httptest.NewRecorder() + + request := postRequest{ + Title: "", + Excerpt: "", + Content: "", + } + request_body, err := json.Marshal(request) + assert.Nil(t, err) + + req, _ := http.NewRequest("POST", "/posts", bytes.NewReader(request_body)) + req.Header.Add("content-type", "application/json") + r.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + + var response postResponse + err = json.Unmarshal(w.Body.Bytes(), &response) + assert.Nil(t, err) + + assert.Equal(t, response.Id, 0) +} + +// TODO : Test request without excerpt + +// TODO : Test request without content + +// TODO : Test request without title + +// TODO : Test request that fails to be added to database diff --git a/common/app_settings_test.go b/tests/app_tests/app_settings/app_settings_test.go similarity index 93% rename from common/app_settings_test.go rename to tests/app_tests/app_settings/app_settings_test.go index 324a875..9672dd5 100644 --- a/common/app_settings_test.go +++ b/tests/app_tests/app_settings/app_settings_test.go @@ -1,10 +1,11 @@ -package common +package app_settings_tests import ( "errors" "os" "testing" + "github.com/matheusgomes28/urchin/common" "github.com/pelletier/go-toml/v2" "github.com/stretchr/testify/assert" ) @@ -28,7 +29,7 @@ func writeToml(contents []byte) (s string, err error) { } func TestCorrectToml(t *testing.T) { - expected := AppSettings{ + expected := common.AppSettings{ DatabaseAddress: "test_database_address", DatabaseUser: "test_database_user", DatabasePassword: "test_database_password", @@ -42,7 +43,7 @@ func TestCorrectToml(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - actual, err := ReadConfigToml(filepath) + actual, err := common.ReadConfigToml(filepath) assert.Nil(t, err) assert.Equal(t, actual, expected) } @@ -69,7 +70,7 @@ func TestMissingDatabaseAddress(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } @@ -95,7 +96,7 @@ func TestMissingDatabaseUser(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } @@ -120,7 +121,7 @@ func TestMissingDatabasePassword(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } @@ -145,7 +146,7 @@ func TestMissingDatabaseName(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } @@ -170,7 +171,7 @@ func TestMissingWebserverPort(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } @@ -195,7 +196,7 @@ func TestMissingDatabasePort(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } @@ -222,7 +223,7 @@ func TestWrongDatabasePortValueType(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } @@ -249,6 +250,6 @@ func TestWrongwebserverPortValueType(t *testing.T) { filepath, err := writeToml(bytes) assert.Nil(t, err) - _, err = ReadConfigToml(filepath) + _, err = common.ReadConfigToml(filepath) assert.NotNil(t, err) } diff --git a/app/cache_test.go b/tests/app_tests/cache/cache_test.go similarity index 67% rename from app/cache_test.go rename to tests/app_tests/cache/cache_test.go index 98192c4..f7dea18 100644 --- a/app/cache_test.go +++ b/tests/app_tests/cache/cache_test.go @@ -1,44 +1,25 @@ -package app +package cache_tests import ( - "sync/atomic" "testing" "time" + "github.com/matheusgomes28/urchin/app" "github.com/stretchr/testify/assert" - shardedmap "github.com/zutto/shardedmap" ) type TrueTimeMockValidator struct{} -func (validator *TrueTimeMockValidator) IsValid(cache *EndpointCache) bool { +func (validator *TrueTimeMockValidator) IsValid(cache *app.EndpointCache) bool { return true } type FalseTimeMockValidator struct{} -func (validator *FalseTimeMockValidator) IsValid(cache *EndpointCache) bool { +func (validator *FalseTimeMockValidator) IsValid(cache *app.EndpointCache) bool { return false } -func makeTrueCacheMock() Cache { - return &TimedCache{ - cacheMap: shardedmap.NewShardMap(2), - cacheTimeout: 10 * time.Second, - estimatedSize: atomic.Uint64{}, - validator: &TrueTimeMockValidator{}, - } -} - -func makeFalseCacheMock() Cache { - return &TimedCache{ - cacheMap: shardedmap.NewShardMap(2), - cacheTimeout: 10 * time.Second, - estimatedSize: atomic.Uint64{}, - validator: &FalseTimeMockValidator{}, - } -} - func TestCacheAddition(t *testing.T) { test_data := []struct { @@ -51,7 +32,7 @@ func TestCacheAddition(t *testing.T) { {"nPe9Rkff6ER6EzAxPUIpxc8UBBLm71hhq2MO9hkQWisrfihUqv", []byte("oA7Hv1A7vOuZSKrPT4ZN5DGKNSHZqpLEvUA5hu54CMyIt8c78u")}, } - cache := makeTrueCacheMock() + cache := app.MakeCache(1, 10*time.Second, &TrueTimeMockValidator{}) rolling_size := uint64(0) for _, test_case := range test_data { @@ -63,7 +44,7 @@ func TestCacheAddition(t *testing.T) { endpoint_cache, err := cache.Get(test_case.name) assert.Nil(t, err) - assert.Equal(t, endpoint_cache.contents, test_case.contents) + assert.Equal(t, endpoint_cache.Contents, test_case.contents) } } @@ -79,7 +60,7 @@ func TestCacheFailure(t *testing.T) { {"nPe9Rkff6ER6EzAxPUIpxc8UBBLm71hhq2MO9hkQWisrfihUqv", []byte("oA7Hv1A7vOuZSKrPT4ZN5DGKNSHZqpLEvUA5hu54CMyIt8c78u")}, } - cache := makeFalseCacheMock() + cache := app.MakeCache(1, 10*time.Second, &FalseTimeMockValidator{}) rolling_size := uint64(0) for _, test_case := range test_data { @@ -93,3 +74,14 @@ func TestCacheFailure(t *testing.T) { assert.NotNil(t, err) } } + +// Tests that storing over 10MB fails +func TestCacheStoreMaxBytes(t *testing.T) { + cache := app.MakeCache(1, 10*time.Second, &FalseTimeMockValidator{}) + + err := cache.Store("fatty", make([]byte, 10000000)) + assert.Nil(t, err) + + err = cache.Store("slim", make([]byte, 1000)) + assert.NotNil(t, err) +} diff --git a/tests/app_tests/endpoint_tests/index_test.go b/tests/app_tests/endpoint_tests/index_test.go new file mode 100644 index 0000000..79c73c4 --- /dev/null +++ b/tests/app_tests/endpoint_tests/index_test.go @@ -0,0 +1,71 @@ +package endpoint_tests + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/matheusgomes28/urchin/app" + "github.com/matheusgomes28/urchin/common" + "github.com/matheusgomes28/urchin/tests/mocks" + + "github.com/stretchr/testify/assert" +) + +func TestIndexSuccess(t *testing.T) { + app_settings := common.AppSettings{ + DatabaseAddress: "localhost", + DatabasePort: 3006, + DatabaseUser: "root", + DatabasePassword: "root", + DatabaseName: "urchin", + WebserverPort: 8080, + } + + database_mock := mocks.DatabaseMock{ + GetPostsHandler: func() ([]common.Post, error) { + return []common.Post{ + { + Title: "TestPost", + Content: "TestContent", + Excerpt: "TestExcerpt", + Id: 0, + }, + }, nil + }, + } + + r := app.SetupRoutes(app_settings, database_mock) + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "TestPost") + assert.Contains(t, w.Body.String(), "TestExcerpt") +} + +func TestIndexFailToGetPosts(t *testing.T) { + app_settings := common.AppSettings{ + DatabaseAddress: "localhost", + DatabasePort: 3006, + DatabaseUser: "root", + DatabasePassword: "root", + DatabaseName: "urchin", + WebserverPort: 8080, + } + + database_mock := mocks.DatabaseMock{ + GetPostsHandler: func() ([]common.Post, error) { + return []common.Post{}, fmt.Errorf("invalid") + }, + } + + r := app.SetupRoutes(app_settings, database_mock) + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusInternalServerError, w.Code) +} diff --git a/tests/app_tests/endpoint_tests/posts_test.go b/tests/app_tests/endpoint_tests/posts_test.go new file mode 100644 index 0000000..e67179c --- /dev/null +++ b/tests/app_tests/endpoint_tests/posts_test.go @@ -0,0 +1,42 @@ +package endpoint_tests + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/matheusgomes28/urchin/app" + "github.com/matheusgomes28/urchin/common" + "github.com/matheusgomes28/urchin/tests/mocks" + "github.com/stretchr/testify/assert" +) + +func TestPostSuccess(t *testing.T) { + app_settings := common.AppSettings{ + DatabaseAddress: "localhost", + DatabasePort: 3006, + DatabaseUser: "root", + DatabasePassword: "root", + DatabaseName: "urchin", + WebserverPort: 8080, + } + + database_mock := mocks.DatabaseMock{ + GetPostHandler: func(post_id int) (common.Post, error) { + return common.Post{ + Title: "TestPost", + Content: "TestContent", + Excerpt: "TestExcerpt", + Id: post_id, + }, nil + }, + } + + r := app.SetupRoutes(app_settings, database_mock) + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/post/0", nil) + r.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "TestPost") +} diff --git a/tests/mocks/mocks.go b/tests/mocks/mocks.go new file mode 100644 index 0000000..2de05fb --- /dev/null +++ b/tests/mocks/mocks.go @@ -0,0 +1,36 @@ +package mocks + +import ( + "fmt" + + "github.com/matheusgomes28/urchin/common" +) + +type DatabaseMock struct { + GetPostHandler func(int) (common.Post, error) + GetPostsHandler func() ([]common.Post, error) +} + +func (db DatabaseMock) GetPosts() ([]common.Post, error) { + return db.GetPostsHandler() +} + +func (db DatabaseMock) GetPost(post_id int) (common.Post, error) { + return db.GetPostHandler(post_id) +} + +func (db DatabaseMock) AddPost(title string, excerpt string, content string) (int, error) { + return 0, nil +} + +func (db DatabaseMock) ChangePost(id int, title string, excerpt string, content string) error { + return nil +} + +func (db DatabaseMock) DeletePost(id int) error { + return fmt.Errorf("not implemented") +} + +func (db DatabaseMock) AddImage(string, string, string) error { + return fmt.Errorf("not implemented") +} From 79e6602532c58d10c90551abb5a8c5b77a3dfac6 Mon Sep 17 00:00:00 2001 From: Milad Rasouli <106727848+Milad75Rasouli@users.noreply.github.com> Date: Sat, 16 Mar 2024 14:19:00 +0330 Subject: [PATCH 2/3] check database connection #51 (#53) Adding a ping when database object is created to fail fast. --- database/database.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/database/database.go b/database/database.go index 79b8ed2..405316b 100644 --- a/database/database.go +++ b/database/database.go @@ -174,13 +174,17 @@ func (db SqlDatabase) AddImage(uuid string, name string, alt string) (err error) } func MakeSqlConnection(user string, password string, address string, port int, database string) (SqlDatabase, error) { - /// Checking the DB connection + /// TODO : let user specify the DB connection_str := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", user, password, address, port, database) db, err := sql.Open("mysql", connection_str) if err != nil { return SqlDatabase{}, err } + + if err := db.Ping(); err != nil { + return SqlDatabase{}, err + } // See "Important settings" section. db.SetConnMaxLifetime(time.Second * 5) db.SetMaxOpenConns(10) From 9efb9402cd318520cc6e74150f7e67dbeb75c24e Mon Sep 17 00:00:00 2001 From: Ali Assar <150648843+Ali-Assar@users.noreply.github.com> Date: Sat, 16 Mar 2024 17:58:44 +0330 Subject: [PATCH 3/3] Add Pagination (#50) Here is the things I've done for pagination: I edited the database query to support pagination Then I saw home handler use GetPosts so i modified that handler in app.go and at the end I Changed SetupRoutes in app.go --- app/app.go | 20 ++++++++++++++++++-- database/database.go | 7 ++++--- tests/app_tests/endpoint_tests/index_test.go | 6 +++--- tests/mocks/mocks.go | 6 +++--- 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/app/app.go b/app/app.go index c74adbf..fcd6858 100644 --- a/app/app.go +++ b/app/app.go @@ -3,6 +3,7 @@ package app import ( "bytes" "net/http" + "strconv" "time" "github.com/gin-gonic/gin" @@ -26,6 +27,9 @@ func SetupRoutes(app_settings common.AppSettings, database database.Database) *g addCachableHandler(r, "GET", "/contact", contactHandler, &cache, app_settings, database) addCachableHandler(r, "GET", "/post/:id", postHandler, &cache, app_settings, database) + // Add the pagination route as a cacheable endpoint + addCachableHandler(r, "GET", "/page/:num", homeHandler, &cache, app_settings, database) + // DO not cache as it needs to handlenew form values r.POST("/contact-send", makeContactFormHandler()) @@ -81,7 +85,19 @@ func addCachableHandler(e *gin.Engine, method string, endpoint string, generator // / This function will act as the handler for // / the home page func homeHandler(c *gin.Context, settings common.AppSettings, db database.Database) ([]byte, error) { - posts, err := db.GetPosts() + pageNum := 0 // Default to page 0 + if pageNumQuery := c.Param("num"); pageNumQuery != "" { + num, err := strconv.Atoi(pageNumQuery) + if err == nil && num > 0 { + pageNum = num + } else { + log.Error().Msgf("Invalid page number: %s", pageNumQuery) + } + } + limit := 10 // or whatever limit you want + offset := max((pageNum-1)*limit, 0) + + posts, err := db.GetPosts(limit, offset) if err != nil { return nil, err } @@ -92,7 +108,7 @@ func homeHandler(c *gin.Context, settings common.AppSettings, db database.Databa err = index_view.Render(c, html_buffer) if err != nil { - log.Error().Msgf("could not render index: %v", err) + log.Error().Msgf("Could not render index: %v", err) return []byte{}, err } diff --git a/database/database.go b/database/database.go index 405316b..07e570f 100644 --- a/database/database.go +++ b/database/database.go @@ -11,7 +11,7 @@ import ( ) type Database interface { - GetPosts() ([]common.Post, error) + GetPosts(int, int) ([]common.Post, error) GetPost(post_id int) (common.Post, error) AddPost(title string, excerpt string, content string) (int, error) ChangePost(id int, title string, excerpt string, content string) error @@ -28,8 +28,9 @@ type SqlDatabase struct { // / This function gets all the posts from the current // / database connection. -func (db SqlDatabase) GetPosts() (all_posts []common.Post, err error) { - rows, err := db.Connection.Query("SELECT title, excerpt, id FROM posts;") +func (db SqlDatabase) GetPosts(limit int, offset int) (all_posts []common.Post, err error) { + query := "SELECT title, excerpt, id FROM posts LIMIT ? OFFSET ?;" + rows, err := db.Connection.Query(query, limit, offset) if err != nil { return make([]common.Post, 0), err } diff --git a/tests/app_tests/endpoint_tests/index_test.go b/tests/app_tests/endpoint_tests/index_test.go index 79c73c4..d20583f 100644 --- a/tests/app_tests/endpoint_tests/index_test.go +++ b/tests/app_tests/endpoint_tests/index_test.go @@ -24,7 +24,7 @@ func TestIndexSuccess(t *testing.T) { } database_mock := mocks.DatabaseMock{ - GetPostsHandler: func() ([]common.Post, error) { + GetPostsHandler: func(limit int, offset int) ([]common.Post, error) { return []common.Post{ { Title: "TestPost", @@ -38,7 +38,7 @@ func TestIndexSuccess(t *testing.T) { r := app.SetupRoutes(app_settings, database_mock) w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/", nil) + req, _ := http.NewRequest("GET", "/page/0", nil) r.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -57,7 +57,7 @@ func TestIndexFailToGetPosts(t *testing.T) { } database_mock := mocks.DatabaseMock{ - GetPostsHandler: func() ([]common.Post, error) { + GetPostsHandler: func(limit int, offset int) ([]common.Post, error) { return []common.Post{}, fmt.Errorf("invalid") }, } diff --git a/tests/mocks/mocks.go b/tests/mocks/mocks.go index 2de05fb..8c04f07 100644 --- a/tests/mocks/mocks.go +++ b/tests/mocks/mocks.go @@ -8,11 +8,11 @@ import ( type DatabaseMock struct { GetPostHandler func(int) (common.Post, error) - GetPostsHandler func() ([]common.Post, error) + GetPostsHandler func(int, int) ([]common.Post, error) } -func (db DatabaseMock) GetPosts() ([]common.Post, error) { - return db.GetPostsHandler() +func (db DatabaseMock) GetPosts(limit int, offset int) ([]common.Post, error) { + return db.GetPostsHandler(limit, offset) } func (db DatabaseMock) GetPost(post_id int) (common.Post, error) {