Skip to content

Commit

Permalink
Implemented user sessions #317
Browse files Browse the repository at this point in the history
  • Loading branch information
tusharad committed Dec 3, 2024
1 parent 68009a0 commit 1e13008
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 4 deletions.
11 changes: 9 additions & 2 deletions Web/Scotty.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ module Web.Scotty
, ScottyM, ActionM, RoutePattern, File, Content(..), Kilobytes, ErrorHandler, Handler(..)
, ScottyState, defaultScottyState
-- ** Functions from Cookie module
, setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie
, setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie
-- ** Session Management
, Session (..), SessionId, SessionJar, createSessionJar,
createUserSession, createSession, readUserSession,
readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions
) where

import qualified Web.Scotty.Trans as Trans
Expand All @@ -76,7 +80,10 @@ import qualified Network.Wai.Parse as W
import Web.FormUrlEncoded (FromForm)
import Web.Scotty.Internal.Types (ScottyT, ActionT, ErrorHandler, Param, RoutePattern, Options, defaultOptions, File, Kilobytes, ScottyState, defaultScottyState, ScottyException, StatusError(..), Content(..))
import UnliftIO.Exception (Handler(..), catch)
import Web.Scotty.Cookie (setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie)
import Web.Scotty.Cookie (setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie)
import Web.Scotty.Session (Session (..), SessionId, SessionJar, createSessionJar,
createUserSession, createSession, readUserSession,
readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions)

{- $setup
>>> :{
Expand Down
186 changes: 186 additions & 0 deletions Web/Scotty/Session.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
{-# LANGUAGE OverloadedStrings #-}

{- |
Module : Web.Scotty.Cookie
Copyright : (c) 2014, 2015 Mārtiņš Mačs,
(c) 2023 Marco Zocca
License : BSD-3-Clause
Maintainer :
Stability : experimental
Portability : GHC
This module provides session management functionality for Scotty web applications.
==Example usage:
@
\{\-\# LANGUAGE OverloadedStrings \#\-\}
import Web.Scotty
import Web.Scotty.Session
import Control.Monad.IO.Class (liftIO)
main :: IO ()
main = do
-- Create a session jar
sessionJar <- createSessionJar
scotty 3000 $ do
-- Route to create a session
get "/create" $ do
sess <- createUserSession sessionJar "user data"
html $ "Session created with ID: " <> sessId sess
-- Route to read a session
get "/read" $ do
mSession <- getUserSession sessionJar
case mSession of
Nothing -> html "No session found or session expired."
Just sess -> html $ "Session content: " <> sessContent sess
@
-}
module Web.Scotty.Session (
Session (..),
SessionId,
SessionJar,

-- * Create Session Jar
createSessionJar,

-- * Create session
createUserSession,
createSession,

-- * Read session
readUserSession,
readSession,
getUserSession,
getSession,

-- * Add session
addSession,

-- * Delte session
deleteSession,

-- * Helper functions
maintainSessions,
) where

import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad
import Control.Monad.IO.Class (MonadIO (..))
import qualified Data.HashMap.Strict as HM
import qualified Data.Text as T
import Data.Time (NominalDiffTime, UTCTime, addUTCTime, getCurrentTime)
import System.Random (randomRIO)
import Web.Scotty.Action (ActionT)
import Web.Scotty.Cookie

-- | Type alias for session identifiers.
type SessionId = T.Text

-- | Represents a session containing an ID, expiration time, and content.
data Session a = Session
{ sessId :: SessionId
-- ^ Unique identifier for the session.
, sessExpiresAt :: UTCTime
-- ^ Expiration time of the session.
, sessContent :: a
-- ^ Content stored in the session.
}
deriving (Eq, Show)

-- | Type for session storage, a transactional variable containing a map of session IDs to sessions.
type SessionJar a = TVar (HM.HashMap SessionId (Session a))

-- | Creates a new session jar and starts a background thread to maintain it.
createSessionJar :: IO (SessionJar a)
createSessionJar = do
storage <- liftIO $ newTVarIO HM.empty
_ <- liftIO $ forkIO $ maintainSessions storage
return storage

-- | Continuously removes expired sessions from the session jar.
maintainSessions :: SessionJar a -> IO ()
maintainSessions sessionJar =
do
now <- getCurrentTime
let stillValid sess = sessExpiresAt sess > now
atomically $ modifyTVar sessionJar $ \m -> HM.filter stillValid m
threadDelay 1000000
maintainSessions sessionJar

-- | Adds a new session to the session jar.
addSession :: SessionJar a -> Session a -> IO ()
addSession sessionJar sess =
atomically $ modifyTVar sessionJar $ \m -> HM.insert (sessId sess) sess m

-- | Retrieves a session by its ID from the session jar.
getSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Maybe (Session a))
getSession sessionJar sId =
do
s <- liftIO $ readTVarIO sessionJar
return $ HM.lookup sId s

-- | Deletes a session by its ID from the session jar.
deleteSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m ()
deleteSession sessionJar sId =
liftIO $
atomically $
modifyTVar sessionJar $
HM.delete sId

{- | Retrieves the current user's session based on the "sess_id" cookie.
| Returns 'Nothing' if the session is expired or does not exist.
-}
getUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Maybe (Session a))
getUserSession sessionJar = do
mSid <- getCookie "sess_id"
case mSid of
Nothing -> return Nothing
Just sid -> do
mSession <- lookupSession sid
case mSession of
Nothing -> return Nothing
Just sess -> do
now <- liftIO getCurrentTime
if sessExpiresAt sess < now
then do
deleteSession sessionJar (sessId sess)
return Nothing
else return $ Just sess
where
lookupSession = getSession sessionJar

-- | Reads the content of a session by its ID.
readSession :: (MonadIO m) => SessionJar a -> SessionId -> ActionT m (Maybe a)
readSession sessionJar sId = do
res <- getSession sessionJar sId
return $ sessContent <$> res

-- | Reads the content of the current user's session.
readUserSession :: (MonadIO m) => SessionJar a -> ActionT m (Maybe a)
readUserSession sessionJar = do
res <- getUserSession sessionJar
return $ sessContent <$> res

-- | The time-to-live for sessions, in seconds.
sessionTTL :: NominalDiffTime
sessionTTL = fromIntegral 36000 -- in seconds

-- | Creates a new session for a user, storing the content and setting a cookie.
createUserSession :: (MonadIO m) => SessionJar a -> a -> ActionT m (Session a)
createUserSession sessionJar content = do
sess <- liftIO $ createSession sessionJar content
setSimpleCookie "sess_id" (sessId sess)
return sess

-- | Creates a new session with a generated ID, sets its expiration, and adds it to the session jar.
createSession :: SessionJar a -> a -> IO (Session a)
createSession sessionJar content = do
sId <- liftIO $ T.pack <$> replicateM 32 (randomRIO ('a', 'z'))
now <- getCurrentTime
let expiresAt = addUTCTime sessionTTL now
sess = Session sId expiresAt content
liftIO $ addSession sessionJar sess
return $ Session sId expiresAt content
9 changes: 8 additions & 1 deletion Web/Scotty/Trans.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ module Web.Scotty.Trans
, ScottyT, ActionT
, ScottyState, defaultScottyState
-- ** Functions from Cookie module
, setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie
, setSimpleCookie, getCookie, getCookies, deleteCookie, makeSimpleCookie
-- ** Session Management
, Session (..), SessionId, SessionJar, createSessionJar,
createUserSession, createSession, readUserSession,
readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions
) where

import Blaze.ByteString.Builder (fromByteString)
Expand All @@ -90,6 +94,9 @@ import Web.Scotty.Body (newBodyInfo)

import UnliftIO.Exception (Handler(..), catch)
import Web.Scotty.Cookie (setSimpleCookie,getCookie,getCookies,deleteCookie,makeSimpleCookie)
import Web.Scotty.Session (Session (..), SessionId, SessionJar, createSessionJar,
createUserSession, createSession, readUserSession,
readSession, getUserSession, getSession, addSession, deleteSession, maintainSessions)


-- | Run a scotty application using the warp server.
Expand Down
31 changes: 31 additions & 0 deletions examples/session.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{-# LANGUAGE OverloadedStrings #-}
module Main (main) where

import Web.Scotty
import qualified Data.Text.Lazy as LT
import qualified Data.Text as T

main :: IO ()
main = do
sessionJar <- liftIO createSessionJar :: IO (SessionJar T.Text)
scotty 3000 $ do
-- Login route
get "/login" $ do
username <- queryParam "username" :: ActionM String
password <- queryParam "password" :: ActionM String
if username == "foo" && password == "bar"
then do
_ <- createUserSession sessionJar "foo"
text "Login successful!"
else
text "Invalid username or password."
-- Dashboard route
get "/dashboard" $ do
mUser <- readUserSession sessionJar
case mUser of
Nothing -> text "Hello, user."
Just userName -> text $ "Hello, " <> LT.fromStrict userName <> "."
-- Logout route
get "/logout" $ do
deleteCookie "sess_id"
text "Logged out successfully."
4 changes: 3 additions & 1 deletion scotty.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Library
Web.Scotty.Trans.Strict
Web.Scotty.Internal.Types
Web.Scotty.Cookie
Web.Scotty.Session
other-modules: Web.Scotty.Action
Web.Scotty.Body
Web.Scotty.Route
Expand Down Expand Up @@ -93,7 +94,8 @@ Library
unordered-containers >= 0.2.10.0 && < 0.3,
wai >= 3.0.0 && < 3.3,
wai-extra >= 3.1.14,
warp >= 3.0.13
warp >= 3.0.13,
random >= 1.0.0.0

if impl(ghc < 8.0)
build-depends: fail
Expand Down
15 changes: 15 additions & 0 deletions test/Web/ScottySpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Data.Char
import Data.String
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy as TL
import qualified Data.Text as T
import qualified Data.Text.Lazy.Encoding as TLE
import Data.Time (UTCTime(..))
import Data.Time.Calendar (fromGregorian)
Expand Down Expand Up @@ -537,6 +538,20 @@ spec = do
withApp (Scotty.get "/nested" (nested simpleApp)) $ do
it "responds with the expected simpleApp response" $ do
get "/nested" `shouldRespondWith` 200 {matchHeaders = ["Content-Type" <:> "text/plain"], matchBody = "Hello, Web!"}

describe "Session Management" $ do
withApp (Scotty.get "/scotty" $ do
sessionJar <- liftIO createSessionJar
sess <- createUserSession sessionJar ("foo" :: T.Text)
mRes <- readSession sessionJar (sessId sess)
case mRes of
Nothing -> Scotty.status status400
Just res -> do
if res /= "foo" then Scotty.status status400
else text "all good"
) $ do
it "Roundtrip of session by adding and fetching a value" $ do
get "/scotty" `shouldRespondWith` 200

-- Unix sockets not available on Windows
#if !defined(mingw32_HOST_OS)
Expand Down

0 comments on commit 1e13008

Please sign in to comment.