diff --git a/internal/feed/feed.go b/internal/feed/feed.go index a94672b..3cb5e4a 100644 --- a/internal/feed/feed.go +++ b/internal/feed/feed.go @@ -250,7 +250,7 @@ func (feed *Feed) GetPublicItem(i string) (*PublicFeedItem, error) { return nil, FeedErrorInvalidFeedItem } - s, err := os.Stat(path.Join(feed.Path, i)) + s, err := os.Stat(path.Join(feed.Path, path.Join("/", i))) if err != nil { if os.IsNotExist(err) { @@ -276,8 +276,11 @@ func (feed *Feed) GetItemData(item string) ([]byte, error) { var content []byte // Get path to feed item - filePath := path.Join(feed.Path, item) + filePath := path.Join(feed.Path, path.Join("/"+item)) + if path.Base(filePath) == "secret" || path.Base(filePath) == "pin" || path.Base(filePath) == "config.json" { + return nil, fmt.Errorf("%w: %s", FeedErrorItemNotFound, item) + } // Read feed item content content, err := os.ReadFile(filePath) if err != nil { @@ -315,7 +318,7 @@ func (feed *Feed) IsSecretValid(secret string) error { // AddItem reads content from r and creates a new file in the feed directory // with a name and file extension based on contentType, then notifies clients -func (f *Feed) AddItem(contentType string, r io.ReadCloser) error { +func (f *Feed) AddItem(contentType string, r io.Reader) error { fL.Logger.Debug("Adding Item", slog.String("feed", f.Name()), slog.String("content-type", contentType)) var err error @@ -385,10 +388,11 @@ func (f *Feed) AddItem(contentType string, r io.ReadCloser) error { } // Notify additon to all connected browsers - if err = f.WebSocketManager.NotifyAdd(publicItem); err != nil { - return err + if f.WebSocketManager != nil { + if err = f.WebSocketManager.NotifyAdd(publicItem); err != nil { + return err + } } - // Send push notification to subscribed browsers err = f.sendPushNotification() if err != nil { @@ -405,7 +409,7 @@ func (f *Feed) AddItem(contentType string, r io.ReadCloser) error { func (f *Feed) RemoveItem(item string) error { fL.Logger.Debug("Remove Item", slog.String("name", item), slog.String("feed", f.Path)) - itemPath := path.Join(f.Path, item) + itemPath := path.Join(f.Path, path.Join("/", item)) // Save public item before deletion for notification later publicItem, err := f.GetPublicItem(item) @@ -423,8 +427,10 @@ func (f *Feed) RemoveItem(item string) error { } // Notify all connected websockets - if err = f.WebSocketManager.NotifyRemove(publicItem); err != nil { - return err + if f.WebSocketManager != nil { + if err = f.WebSocketManager.NotifyRemove(publicItem); err != nil { + return err + } } fL.Logger.Debug("Removed Item", slog.String("name", item), slog.String("feed", f.Path)) diff --git a/internal/feed/feed_test.go b/internal/feed/feed_test.go new file mode 100644 index 0000000..7d1e897 --- /dev/null +++ b/internal/feed/feed_test.go @@ -0,0 +1,129 @@ +package feed + +import ( + "bytes" + "os" + "testing" +) + +func TestGetFeedItemData(t *testing.T) { + t.Cleanup(func() { + os.RemoveAll("tests/feed1") + }) + f, err := NewFeed("tests/feed1") + if err != nil { + t.Fatal(err) + } + + reader := bytes.NewReader([]byte("test")) + + err = f.AddItem("text/plain", reader) + if err != nil { + t.Fatal(err) + } + + pf, err := f.Public() + if err != nil { + t.Fatal(err) + } + i := pf.Items[0] + b, err := f.GetItemData(i.Name) + if len(b) == 0 || err != nil { + t.Fatal(err) + } +} + +func TestPathTraversalGet(t *testing.T) { + t.Cleanup(func() { + os.RemoveAll("tests/feed1") + os.RemoveAll("tests/feed2") + }) + _, err := NewFeed("tests/feed1") + if err != nil { + t.Fatal(err) + } + + f, err := NewFeed("tests/feed2") + if err != nil { + t.Fatal(err) + } + + b, err := f.GetItemData("../feed1/config.json") + + if len(b) != 0 || err == nil { + t.Fatal("Path traversal not blocked") + } +} + +func TestPathTraversalDelete(t *testing.T) { + t.Cleanup(func() { + os.RemoveAll("tests/feed1") + os.RemoveAll("tests/feed2") + }) + _, err := NewFeed("tests/feed1") + if err != nil { + t.Fatal(err) + } + + f, err := NewFeed("tests/feed2") + if err != nil { + t.Fatal(err) + } + + err = f.RemoveItem("../feed1/config.json") + + if err == nil { + t.Fatal("Path traversal not blocked") + } +} + +func TestPathTraversalPublicItem(t *testing.T) { + t.Cleanup(func() { + os.RemoveAll("tests/feed1") + os.RemoveAll("tests/feed2") + }) + _, err := NewFeed("tests/feed1") + if err != nil { + t.Fatal(err) + } + + f, err := NewFeed("tests/feed2") + if err != nil { + t.Fatal(err) + } + + p, err := f.GetPublicItem("../feed1/config.json") + + if p != nil || err == nil { + t.Fatal("Path traversal not blocked") + } +} + +func TestPublicItem(t *testing.T) { + t.Cleanup(func() { + os.RemoveAll("tests/feed1") + }) + f, err := NewFeed("tests/feed1") + if err != nil { + t.Fatal(err) + } + + reader := bytes.NewReader([]byte("test")) + + err = f.AddItem("text/plain", reader) + if err != nil { + t.Fatal(err) + } + + pf, err := f.Public() + if err != nil { + t.Fatal(err) + } + i := pf.Items[0] + + p, err := f.GetPublicItem(i.Name) + + if p == nil || err != nil { + t.Fatal(err) + } +} diff --git a/internal/feed/websocket.go b/internal/feed/websocket.go index 68da388..2102650 100644 --- a/internal/feed/websocket.go +++ b/internal/feed/websocket.go @@ -89,6 +89,7 @@ func (m *WebSocketManager) RunSocketForFeed(feedName string, w http.ResponseWrit m.FeedSockets = append(m.FeedSockets, feedSockets) } + // Upgrade http connection to websocket c, err := upgrader.Upgrade(w, r, nil) if err != nil { utils.CloseWithCodeAndMessage(w, 500, "Unable to upgrade WebSocket") @@ -96,6 +97,7 @@ func (m *WebSocketManager) RunSocketForFeed(feedName string, w http.ResponseWrit feedSockets.websockets = append(feedSockets.websockets, c) + // Get provided secret and validate feed access secret, _ := utils.GetSecret(r) f, err := m.FeedManager.GetFeedWithAuth(feedName, secret) @@ -113,19 +115,26 @@ func (m *WebSocketManager) RunSocketForFeed(feedName string, w http.ResponseWrit } } + // Cleanup defer func() { feedSockets.RemoveConn(c) c.Close() }() + // Start waiting for messages for { mt, message, err := c.ReadMessage() - wsL.Logger.Debug("Message Received", slog.String("message", string(message)), slog.Int("messageType", mt)) + wsL.Logger.Debug("Message Received", + slog.String("message", string(message)), + slog.Int("messageType", mt)) if err != nil { - slog.Error("Error reading message", slog.String("error", err.Error()), slog.Int("messageType", mt)) + slog.Error("Error reading message", + slog.String("error", err.Error()), + slog.Int("messageType", mt)) break } switch strings.TrimSpace(string(message)) { + // Return pubic feed content case "feed": pf, err := f.Public() if err != nil { @@ -139,8 +148,11 @@ func (m *WebSocketManager) RunSocketForFeed(feedName string, w http.ResponseWrit } } +// NotifyAdd notifies all connected websockets that an item has been added func (m *WebSocketManager) NotifyAdd(item *PublicFeedItem) error { - wsL.Logger.Debug("Notify websocket", slog.Any("item", item), slog.Int("ws count", len(m.FeedSockets))) + wsL.Logger.Debug("Notify websocket", + slog.Any("item", item), + slog.Int("ws count", len(m.FeedSockets))) for _, f := range m.FeedSockets { wsL.Logger.Debug("checking feed", slog.String("feedName", f.feedName)) if f.feedName == item.Feed.Name { @@ -158,12 +170,15 @@ func (m *WebSocketManager) NotifyAdd(item *PublicFeedItem) error { return nil } +// NotifyRemove notify all connected websockets that an item has been removed func (m *WebSocketManager) NotifyRemove(item *PublicFeedItem) error { - wsL.Logger.Debug("Notify websocket", slog.Any("item", item), slog.Int("ws count", len(m.FeedSockets))) + wsL.Logger.Debug("Notify websocket", + slog.Any("item", item), + slog.Int("ws count", len(m.FeedSockets))) for _, f := range m.FeedSockets { wsL.Logger.Debug("checking feed", slog.String("feedName", f.feedName)) if f.feedName == item.Feed.Name { - wsL.Logger.Debug("Found feed", slog.String("feedName", f.feedName)) + wsL.Logger.Debug("found feed", slog.String("feedName", f.feedName)) for _, w := range f.websockets { if err := w.WriteJSON(FeedNotification{ Action: "remove", diff --git a/internal/handlers/handlers.go b/internal/handlers/handlers.go index 0dd667f..cdd33e0 100644 --- a/internal/handlers/handlers.go +++ b/internal/handlers/handlers.go @@ -19,18 +19,12 @@ import ( "github.com/go-chi/chi/v5" "github.com/ybizeul/ybfeed/internal/feed" "github.com/ybizeul/ybfeed/internal/utils" + "github.com/ybizeul/ybfeed/pkg/yblog" "github.com/ybizeul/ybfeed/web/ui" ) -var logLevel = new(slog.LevelVar) -var logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: logLevel})).WithGroup("http") - -func init() { - if os.Getenv("DEBUG") != "" || os.Getenv("DEBUG_HTTP") != "" { - logLevel.Set(slog.LevelDebug) - } -} +var hL = yblog.NewYBLogger("http", []string{"DEBUG", "DEBUG_HTTP"}) var webUiHandler = http.FileServer(http.FS(ui.GetUiFs())) @@ -39,7 +33,7 @@ var webUiHandler = http.FileServer(http.FS(ui.GetUiFs())) // then it serves this file from webUiHandler, otherwise it returns // index.html for proper react routing func RootHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Root request", slog.Any("request", r)) + hL.Logger.Debug("Root request", slog.Any("request_uri", r.RequestURI)) p := r.URL.Path @@ -55,7 +49,7 @@ func RootHandlerFunc(w http.ResponseWriter, r *http.Request) { matches, err := fs.Glob(ui, p) if err != nil { - logger.Error("Unable to get web ui fs", slog.String("error", err.Error())) + hL.Logger.Error("Unable to get web ui fs", slog.String("error", err.Error())) } if len(matches) == 1 { @@ -69,12 +63,12 @@ func RootHandlerFunc(w http.ResponseWriter, r *http.Request) { content, err := fs.ReadFile(ui, "index.html") if err != nil { - logger.Error("Unable to read index.html from web ui", slog.String("error", err.Error())) + hL.Logger.Error("Unable to read index.html from web ui", slog.String("error", err.Error())) } _, err = w.Write(content) if err != nil { - logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) + hL.Logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) } } @@ -167,7 +161,7 @@ func (api *ApiHandler) StartServer() { r := api.GetServer() err := http.ListenAndServe(fmt.Sprintf("%s:%d", api.ListenAddr, api.HttpPort), r) if err != nil { - logger.Error("Unable to start HTTP server", + hL.Logger.Error("Unable to start HTTP server", slog.String("error", err.Error())) os.Exit(1) } @@ -189,7 +183,7 @@ func (api *ApiHandler) GetServer() *chi.Mux { r.Get("/api", func(w http.ResponseWriter, r *http.Request) { if _, err := w.Write([]byte("OK")); err != nil { - logger.Error("Cannot write Ping response") + hL.Logger.Error("Cannot write Ping response") } }) @@ -244,7 +238,7 @@ func (api *ApiHandler) feedWSHandler(w http.ResponseWriter, r *http.Request) { } func (api *ApiHandler) feedHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Feed API request", slog.Any("request", r)) + hL.Logger.Debug("Feed API request", slog.Any("request_uri", r.RequestURI)) feedName, _ := url.QueryUnescape(chi.URLParam(r, "feedName")) @@ -301,12 +295,12 @@ func (api *ApiHandler) feedHandlerFunc(w http.ResponseWriter, r *http.Request) { return } if _, err = w.Write(j); err != nil { - logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) + hL.Logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) } } func (api *ApiHandler) feedPatchHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Feed API Set PIN request", slog.Any("request", r)) + hL.Logger.Debug("Feed API Set PIN request", slog.Any("request", r)) secret, _ := utils.GetSecret(r) feedName, _ := url.QueryUnescape(chi.URLParam(r, "feedName")) @@ -332,7 +326,7 @@ func (api *ApiHandler) feedPatchHandlerFunc(w http.ResponseWriter, r *http.Reque if err != nil { w.WriteHeader(500) if _, err = w.Write([]byte(err.Error())); err != nil { - logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) + hL.Logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) } return } @@ -347,7 +341,7 @@ func (api *ApiHandler) feedPatchHandlerFunc(w http.ResponseWriter, r *http.Reque } func (api *ApiHandler) feedItemHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Item API GET request", slog.Any("request", r)) + hL.Logger.Debug("Item API GET request", slog.Any("request", r)) secret, _ := utils.GetSecret(r) @@ -388,12 +382,12 @@ func (api *ApiHandler) feedItemHandlerFunc(w http.ResponseWriter, r *http.Reques return } if _, err = w.Write(content); err != nil { - logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) + hL.Logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) } } func (api *ApiHandler) feedPostHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Item API POST request", slog.Any("request", r)) + hL.Logger.Debug("Item API POST request", slog.Any("request", r)) secret, _ := utils.GetSecret(r) @@ -438,7 +432,7 @@ func (api *ApiHandler) feedPostHandlerFunc(w http.ResponseWriter, r *http.Reques } func (api *ApiHandler) feedItemDeleteHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Item API DELETE request", slog.Any("request", r)) + hL.Logger.Debug("Item API DELETE request", slog.Any("request", r)) secret, _ := utils.GetSecret(r) @@ -478,13 +472,13 @@ func (api *ApiHandler) feedItemDeleteHandlerFunc(w http.ResponseWriter, r *http. } if _, err = w.Write([]byte("Item Removed")); err != nil { - logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) + hL.Logger.Error("Error while writing HTTP response", slog.String("error", err.Error())) } } func (api *ApiHandler) feedSubscriptionHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Feed subscription request", slog.Any("request", r)) + hL.Logger.Debug("Feed subscription request", slog.Any("request", r)) secret, _ := utils.GetSecret(r) @@ -535,7 +529,7 @@ func (api *ApiHandler) feedSubscriptionHandlerFunc(w http.ResponseWriter, r *htt func (api *ApiHandler) feedUnsubscribeHandlerFunc(w http.ResponseWriter, r *http.Request) { - logger.Debug("Feed subscription request", slog.Any("request", r)) + hL.Logger.Debug("Feed subscription request", slog.Any("request", r)) secret, _ := utils.GetSecret(r) diff --git a/test.env b/test.env new file mode 100644 index 0000000..e69de29