diff --git a/Web/Scotty.hs b/Web/Scotty.hs index 415c8a1..45ea4f7 100644 --- a/Web/Scotty.hs +++ b/Web/Scotty.hs @@ -55,7 +55,11 @@ module Web.Scotty , ScottyM, ActionM, RoutePattern, File, Content(..), Kilobytes, ErrorHandler, Handler(..) , ScottyState, defaultScottyState -- ** Functions from Cookie module - , setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie + , setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie + -- ** Session Management + , Session (..), SessionId, SessionJar, createSessionJar, + createUserSession, createSession, readUserSession, + readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions ) where import qualified Web.Scotty.Trans as Trans @@ -76,7 +80,10 @@ import qualified Network.Wai.Parse as W import Web.FormUrlEncoded (FromForm) import Web.Scotty.Internal.Types (ScottyT, ActionT, ErrorHandler, Param, RoutePattern, Options, defaultOptions, File, Kilobytes, ScottyState, defaultScottyState, ScottyException, StatusError(..), Content(..)) import UnliftIO.Exception (Handler(..), catch) -import Web.Scotty.Cookie (setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie) +import Web.Scotty.Cookie (setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie) +import Web.Scotty.Session (Session (..), SessionId, SessionJar, createSessionJar, + createUserSession, createSession, readUserSession, + readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions) {- $setup >>> :{ diff --git a/Web/Scotty/Session.hs b/Web/Scotty/Session.hs new file mode 100644 index 0000000..431021f --- /dev/null +++ b/Web/Scotty/Session.hs @@ -0,0 +1,186 @@ +{-# LANGUAGE OverloadedStrings #-} + +{- | +Module : Web.Scotty.Cookie +Copyright : (c) 2014, 2015 Mārtiņš Mačs, + (c) 2023 Marco Zocca + +License : BSD-3-Clause +Maintainer : +Stability : experimental +Portability : GHC + +This module provides session management functionality for Scotty web applications. + +==Example usage: + +@ +\{\-\# LANGUAGE OverloadedStrings \#\-\} + +import Web.Scotty +import Web.Scotty.Session +import Control.Monad.IO.Class (liftIO) +main :: IO () +main = do + -- Create a session jar + sessionJar <- createSessionJar + scotty 3000 $ do + -- Route to create a session + get "/create" $ do + sess <- createUserSession sessionJar "user data" + html $ "Session created with ID: " <> sessId sess + -- Route to read a session + get "/read" $ do + mSession <- getUserSession sessionJar + case mSession of + Nothing -> html "No session found or session expired." + Just sess -> html $ "Session content: " <> sessContent sess +@ +-} +module Web.Scotty.Session ( + Session (..), + SessionId, + SessionJar, + + -- * Create Session Jar + createSessionJar, + + -- * Create session + createUserSession, + createSession, + + -- * Read session + readUserSession, + readSession, + getUserSession, + getSession, + + -- * Add session + addSession, + + -- * Delte session + deleteSession, + + -- * Helper functions + maintainSessions, +) where + +import Control.Concurrent +import Control.Concurrent.STM +import Control.Monad +import Control.Monad.IO.Class (MonadIO (..)) +import qualified Data.HashMap.Strict as HM +import qualified Data.Text as T +import Data.Time (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime) +import System.Random (randomRIO) +import Web.Scotty.Action (ActionT) +import Web.Scotty.Cookie + +-- | Type alias for session identifiers. +type SessionId = T.Text + +-- | Represents a session containing an ID, expiration time, and content. +data Session a = Session + { sessId :: SessionId + -- ^ Unique identifier for the session. + , sessExpiresAt :: UTCTime + -- ^ Expiration time of the session. + , sessContent :: a + -- ^ Content stored in the session. + } + deriving (Eq, Show) + +-- | Type for session storage, a transactional variable containing a map of session IDs to sessions. +type SessionJar a = TVar (HM.HashMap SessionId (Session a)) + +-- | Creates a new session jar and starts a background thread to maintain it. +createSessionJar :: IO (SessionJar a) +createSessionJar = do + storage <- liftIO $ newTVarIO HM.empty + _ <- liftIO $ forkIO $ maintainSessions storage + return storage + +-- | Continuously removes expired sessions from the session jar. +maintainSessions :: SessionJar a -> IO () +maintainSessions sessionJar = + do + now <- getCurrentTime + let stillValid sess = sessExpiresAt sess > now + atomically $ modifyTVar sessionJar $ \m -> HM.filter stillValid m + threadDelay 1000000 + maintainSessions sessionJar + +-- | Adds a new session to the session jar. +addSession :: SessionJar a -> Session a -> IO () +addSession sessionJar sess = + atomically $ modifyTVar sessionJar $ \m -> HM.insert (sessId sess) sess m + +-- | Retrieves a session by its ID from the session jar. +getSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Maybe (Session a)) +getSession sessionJar sId = + do + s <- liftIO $ readTVarIO sessionJar + return $ HM.lookup sId s + +-- | Deletes a session by its ID from the session jar. +deleteSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m () +deleteSession sessionJar sId = + liftIO $ + atomically $ + modifyTVar sessionJar $ + HM.delete sId + +{- | Retrieves the current user's session based on the "sess_id" cookie. +| Returns 'Nothing' if the session is expired or does not exist. +-} +getUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Maybe (Session a)) +getUserSession sessionJar = do + mSid <- getCookie "sess_id" + case mSid of + Nothing -> return Nothing + Just sid -> do + mSession <- lookupSession sid + case mSession of + Nothing -> return Nothing + Just sess -> do + now <- liftIO getCurrentTime + if sessExpiresAt sess < now + then do + deleteSession sessionJar (sessId sess) + return Nothing + else return $ Just sess + where + lookupSession = getSession sessionJar + +-- | Reads the content of a session by its ID. +readSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Maybe a) +readSession sessionJar sId = do + res <- getSession sessionJar sId + return $ sessContent <$> res + +-- | Reads the content of the current user's session. +readUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Maybe a) +readUserSession sessionJar = do + res <- getUserSession sessionJar + return $ sessContent <$> res + +-- | The time-to-live for sessions, in seconds. +sessionTTL :: NominalDiffTime +sessionTTL = fromIntegral 36000 -- in seconds + +-- | Creates a new session for a user, storing the content and setting a cookie. +createUserSession :: (MonadIO m) => SessionJar a -> a -> ActionT m (Session a) +createUserSession sessionJar content = do + sess <- liftIO $ createSession sessionJar content + setSimpleCookie "sess_id" (sessId sess) + return sess + +-- | Creates a new session with a generated ID, sets its expiration, and adds it to the session jar. +createSession :: SessionJar a -> a -> IO (Session a) +createSession sessionJar content = do + sId <- liftIO $ T.pack <$> replicateM 32 (randomRIO ('a', 'z')) + now <- getCurrentTime + let expiresAt = addUTCTime sessionTTL now + sess = Session sId expiresAt content + liftIO $ addSession sessionJar sess + return $ Session sId expiresAt content diff --git a/Web/Scotty/Trans.hs b/Web/Scotty/Trans.hs index b3468ea..fa84136 100644 --- a/Web/Scotty/Trans.hs +++ b/Web/Scotty/Trans.hs @@ -64,7 +64,11 @@ module Web.Scotty.Trans , ScottyT, ActionT , ScottyState, defaultScottyState -- ** Functions from Cookie module - , setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie + , setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie + -- ** Session Management + , Session (..), SessionId, SessionJar, createSessionJar, + createUserSession, createSession, readUserSession, + readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions ) where import Blaze.ByteString.Builder (fromByteString) @@ -90,6 +94,9 @@ import Web.Scotty.Body (newBodyInfo) import UnliftIO.Exception (Handler(..), catch) import Web.Scotty.Cookie (setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie) +import Web.Scotty.Session (Session (..), SessionId, SessionJar, createSessionJar, + createUserSession, createSession, readUserSession, + readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions) -- | Run a scotty application using the warp server. diff --git a/examples/session.hs b/examples/session.hs new file mode 100644 index 0000000..035a28c --- /dev/null +++ b/examples/session.hs @@ -0,0 +1,31 @@ +{-# LANGUAGE OverloadedStrings #-} +module Main (main) where + +import Web.Scotty +import qualified Data.Text.Lazy as LT +import qualified Data.Text as T + +main :: IO () +main = do + sessionJar <- liftIO createSessionJar :: IO (SessionJar T.Text) + scotty 3000 $ do + -- Login route + get "/login" $ do + username <- queryParam "username" :: ActionM String + password <- queryParam "password" :: ActionM String + if username == "foo" && password == "bar" + then do + _ <- createUserSession sessionJar "foo" + text "Login successful!" + else + text "Invalid username or password." + -- Dashboard route + get "/dashboard" $ do + mUser <- readUserSession sessionJar + case mUser of + Nothing -> text "Hello, user." + Just userName -> text $ "Hello, " <> LT.fromStrict userName <> "." + -- Logout route + get "/logout" $ do + deleteCookie "sess_id" + text "Logged out successfully." \ No newline at end of file diff --git a/scotty.cabal b/scotty.cabal index bd311a2..a347c47 100644 --- a/scotty.cabal +++ b/scotty.cabal @@ -64,6 +64,7 @@ Library Web.Scotty.Trans.Strict Web.Scotty.Internal.Types Web.Scotty.Cookie + Web.Scotty.Session other-modules: Web.Scotty.Action Web.Scotty.Body Web.Scotty.Route @@ -93,7 +94,8 @@ Library unordered-containers >= 0.2.10.0 && < 0.3, wai >= 3.0.0 && < 3.3, wai-extra >= 3.1.14, - warp >= 3.0.13 + warp >= 3.0.13, + random >= 1.0.0.0 if impl(ghc < 8.0) build-depends: fail diff --git a/test/Web/ScottySpec.hs b/test/Web/ScottySpec.hs index c6c36df..0b62d56 100644 --- a/test/Web/ScottySpec.hs +++ b/test/Web/ScottySpec.hs @@ -11,6 +11,7 @@ import Data.Char import Data.String import Data.Text.Lazy (Text) import qualified Data.Text.Lazy as TL +import qualified Data.Text as T import qualified Data.Text.Lazy.Encoding as TLE import Data.Time (UTCTime(..)) import Data.Time.Calendar (fromGregorian) @@ -537,6 +538,20 @@ spec = do withApp (Scotty.get "/nested" (nested simpleApp)) $ do it "responds with the expected simpleApp response" $ do get "/nested" `shouldRespondWith` 200 {matchHeaders = ["Content-Type" <:> "text/plain"], matchBody = "Hello, Web!"} + + describe "Session Management" $ do + withApp (Scotty.get "/scotty" $ do + sessionJar <- liftIO createSessionJar + sess <- createUserSession sessionJar ("foo" :: T.Text) + mRes <- readSession sessionJar (sessId sess) + case mRes of + Nothing -> Scotty.status status400 + Just res -> do + if res /= "foo" then Scotty.status status400 + else text "all good" + ) $ do + it "Roundtrip of session by adding and fetching a value" $ do + get "/scotty" `shouldRespondWith` 200 -- Unix sockets not available on Windows #if !defined(mingw32_HOST_OS)