diff --git a/app/app.go b/app/app.go index c74adbf..dfd4f88 100644 --- a/app/app.go +++ b/app/app.go @@ -3,6 +3,7 @@ package app import ( "bytes" "net/http" + "strconv" "time" "github.com/gin-gonic/gin" @@ -26,6 +27,9 @@ func SetupRoutes(app_settings common.AppSettings, database database.Database) *g addCachableHandler(r, "GET", "/contact", contactHandler, &cache, app_settings, database) addCachableHandler(r, "GET", "/post/:id", postHandler, &cache, app_settings, database) + // Add the pagination route as a cacheable endpoint + addCachableHandler(r, "GET", "/page/:num", homeHandler, &cache, app_settings, database) + // DO not cache as it needs to handlenew form values r.POST("/contact-send", makeContactFormHandler()) @@ -81,7 +85,23 @@ func addCachableHandler(e *gin.Engine, method string, endpoint string, generator // / This function will act as the handler for // / the home page func homeHandler(c *gin.Context, settings common.AppSettings, db database.Database) ([]byte, error) { - posts, err := db.GetPosts() + pageNum := 0 // Default to page 0 + if pageNumQuery := c.Param("num"); pageNumQuery != "" { + num, err := strconv.Atoi(pageNumQuery) + if err == nil && num > 0 { + pageNum = num + } else { + log.Error().Msgf("Invalid page number: %s", pageNumQuery) + } + } + + // Ensure pageNum is at least 1 + pageNum = max(pageNum, 1) + + limit := 10 // or whatever limit you want + offset := (pageNum - 1) * limit + + posts, err := db.GetPosts(limit, offset) if err != nil { return nil, err } @@ -92,9 +112,16 @@ func homeHandler(c *gin.Context, settings common.AppSettings, db database.Databa err = index_view.Render(c, html_buffer) if err != nil { - log.Error().Msgf("could not render index: %v", err) + log.Error().Msgf("Could not render index: %v", err) return []byte{}, err } return html_buffer.Bytes(), nil } + +func max(a, b int) int { + if a > b { + return a + } + return b +} diff --git a/database/database.go b/database/database.go index 79b8ed2..143a0ba 100644 --- a/database/database.go +++ b/database/database.go @@ -11,7 +11,7 @@ import ( ) type Database interface { - GetPosts() ([]common.Post, error) + GetPosts(int, int) ([]common.Post, error) GetPost(post_id int) (common.Post, error) AddPost(title string, excerpt string, content string) (int, error) ChangePost(id int, title string, excerpt string, content string) error @@ -28,8 +28,9 @@ type SqlDatabase struct { // / This function gets all the posts from the current // / database connection. -func (db SqlDatabase) GetPosts() (all_posts []common.Post, err error) { - rows, err := db.Connection.Query("SELECT title, excerpt, id FROM posts;") +func (db SqlDatabase) GetPosts(limit int, offset int) (all_posts []common.Post, err error) { + query := "SELECT title, excerpt, id FROM posts LIMIT ? OFFSET ?;" + rows, err := db.Connection.Query(query, limit, offset) if err != nil { return make([]common.Post, 0), err } diff --git a/tests/app_tests/endpoint_tests/index_test.go b/tests/app_tests/endpoint_tests/index_test.go index 79c73c4..d20583f 100644 --- a/tests/app_tests/endpoint_tests/index_test.go +++ b/tests/app_tests/endpoint_tests/index_test.go @@ -24,7 +24,7 @@ func TestIndexSuccess(t *testing.T) { } database_mock := mocks.DatabaseMock{ - GetPostsHandler: func() ([]common.Post, error) { + GetPostsHandler: func(limit int, offset int) ([]common.Post, error) { return []common.Post{ { Title: "TestPost", @@ -38,7 +38,7 @@ func TestIndexSuccess(t *testing.T) { r := app.SetupRoutes(app_settings, database_mock) w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/", nil) + req, _ := http.NewRequest("GET", "/page/0", nil) r.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -57,7 +57,7 @@ func TestIndexFailToGetPosts(t *testing.T) { } database_mock := mocks.DatabaseMock{ - GetPostsHandler: func() ([]common.Post, error) { + GetPostsHandler: func(limit int, offset int) ([]common.Post, error) { return []common.Post{}, fmt.Errorf("invalid") }, } diff --git a/tests/mocks/mocks.go b/tests/mocks/mocks.go index 2de05fb..8c04f07 100644 --- a/tests/mocks/mocks.go +++ b/tests/mocks/mocks.go @@ -8,11 +8,11 @@ import ( type DatabaseMock struct { GetPostHandler func(int) (common.Post, error) - GetPostsHandler func() ([]common.Post, error) + GetPostsHandler func(int, int) ([]common.Post, error) } -func (db DatabaseMock) GetPosts() ([]common.Post, error) { - return db.GetPostsHandler() +func (db DatabaseMock) GetPosts(limit int, offset int) ([]common.Post, error) { + return db.GetPostsHandler(limit, offset) } func (db DatabaseMock) GetPost(post_id int) (common.Post, error) {