-
-
Save notcome/df83f27b84088e0d9bcb to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveDataTypeable #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE InstanceSigs #-} | |
module Data.RecordPool where | |
import Control.Applicative ((<$>), (<*>))
import Data.Data (Data, Typeable)

import Control.Monad.Reader (ask)
import Control.Monad.State (get, put)
import Data.Acid
import Data.Acid.Advanced
import Data.IxSet
  ( Indexable, IxSet(..), ixFun, ixSet
  , empty, (@=), getOne
  )
import qualified Data.IxSet as Ix
import Data.SafeCopy (base, deriveSafeCopy, SafeCopy(..))
import Data.Time (formatTime)
import Data.Time.Clock (getCurrentTime)
import Data.Time.Clock.POSIX (getPOSIXTime)
import System.Locale (defaultTimeLocale)
import System.Random

import Data.Accounts
-- | A count of whole seconds.
--
-- The same wrapper is used for a time-to-live (the argument of 'expireIn')
-- and for an absolute expiry time (the result of 'expireIn').
newtype ExpireTime = ExpireTime Int
  deriving (Eq, Ord, Data, Typeable, Show)

-- | A verification code, stored as a plain 'String'.
newtype VCode = VCode
  { unVCode :: String
  } deriving (Eq, Ord, Data, Typeable, Show)

-- Serialisation instances, both at version 0 with no migration.
$(deriveSafeCopy 0 'base ''ExpireTime)
$(deriveSafeCopy 0 'base ''VCode)
-- | One pool entry: a key, the verification code issued for it, and the time
-- at which that code stops being accepted.
data Record k = Record
  { getKey   :: k          -- ^ Value the verification code stands for.
  , getVCode :: VCode      -- ^ Code used to look the record up.
  , getETime :: ExpireTime -- ^ Expiry time checked by 'validateRecord'.
  } deriving (Eq, Ord, Data, Typeable, Show)

-- Serialisation instance at version 0 with no migration.
$(deriveSafeCopy 0 'base ''Record)
-- | Records carry two indexes: one on the key and one on the verification
-- code. Each record contributes exactly one entry to each index.
instance (Typeable k, Ord k) => Indexable (Record k) where
  empty =
    ixSet
      [ ixFun ((: []) . getKey)
      , ixFun ((: []) . getVCode)
      ]

-- | An indexed set of 'Record's whose keys have type @k@.
type RecordPool k = IxSet (Record k)
-- | Store a record for @key@ in the pool.
--
-- The record is written with 'Ix.updateIx' on @key@, so it takes the place of
-- whatever the pool previously held under that key.
newRecord :: (Typeable k, Ord k)
          => k -> VCode -> ExpireTime -> Update (RecordPool k) ()
newRecord key vcode etime =
  get >>= put . Ix.updateIx key entry
  where
    entry = Record key vcode etime
-- | Look a verification code up and return its key if it is still valid.
--
-- The result is 'Nothing' when 'getOne' finds no single record for @vcode@,
-- or when the record's expiry time is strictly earlier than @now@. A record
-- whose expiry equals @now@ is still accepted.
validateRecord :: (Typeable k, Ord k)
               => VCode -> ExpireTime -> Query (RecordPool k) (Maybe k)
validateRecord vcode now = keyIfValid <$> ask
  where
    keyIfValid pool =
      case getOne (pool @= vcode) of
        Just record
          | getETime record >= now -> Just (getKey record)
        _ -> Nothing
-- | Delete the record indexed under the given verification code, via
-- 'Ix.deleteIx'.
removeRecord' :: (Typeable k, Ord k)
              => VCode -> Update (RecordPool k) ()
removeRecord' vcode = get >>= put . Ix.deleteIx vcode
-- Helper Functions -- | |
-- | Draw a fresh verification code: a random 15-digit decimal number rendered
-- as a string.
--
-- NOTE(review): the number comes from the global generator via
-- 'getStdRandom'. If these codes guard cookies or password resets, confirm
-- that this generator is acceptable for security tokens.
getNextVCode :: IO VCode
getNextVCode = do
  number <- getStdRandom (randomR (lowest, highest))
  return (VCode (show number))
  where
    lowest, highest :: Integer
    lowest = 100000000000000
    highest = 999999999999999
-- | Turn a time-to-live in seconds into an absolute expiry time.
--
-- The result is the current POSIX time, rounded down to whole seconds, plus
-- the given offset. Passing @ExpireTime 0@ therefore yields the current time.
expireIn :: ExpireTime -> IO ExpireTime
expireIn (ExpireTime ttl) = do
  -- Read the clock numerically rather than formatting it with "%s" and
  -- parsing the text back with the partial 'read'.
  now <- floor <$> getPOSIXTime
  return $ ExpireTime (now + ttl)
-- acid-state's Template Haskell breaks on the definitions above, so the
-- event types and their instances are written out by hand below.
-- Replace them with makeAcidic once that bug is fixed.
-- | Event carrying the arguments of 'newRecord': key, verification code and
-- expiry time.
data NewRecord k = NewRecord k VCode ExpireTime
  deriving Typeable

$(deriveSafeCopy 0 'base ''NewRecord)

-- | The event acts on a @RecordPool k@ and returns nothing.
instance (Typeable k, SafeCopy k) => Method (NewRecord k) where
  type MethodResult (NewRecord k) = ()
  type MethodState (NewRecord k) = RecordPool k

-- | 'NewRecord' is an update event.
instance (Typeable k, SafeCopy k) => UpdateEvent (NewRecord k)
-- | Event carrying the arguments of 'validateRecord': the code to check and
-- the time to check it against.
--
-- The parameter @k@ does not appear in the payload; it only fixes the key
-- type of the pool being queried and of the result.
data ValidateRecord k = ValidateRecord VCode ExpireTime
  deriving Typeable

$(deriveSafeCopy 0 'base ''ValidateRecord)

-- | The event reads a @RecordPool k@ and returns @Maybe k@.
instance (Typeable k, SafeCopy k) => Method (ValidateRecord k) where
  type MethodResult (ValidateRecord k) = Maybe k
  type MethodState (ValidateRecord k) = RecordPool k

-- | 'ValidateRecord' is a query event.
instance (Typeable k, SafeCopy k) => QueryEvent (ValidateRecord k)
-- | Event carrying the argument of @removeRecord'@: the code whose record
-- should be deleted.
--
-- The parameter @k@ does not appear in the payload; it only fixes the key
-- type of the pool being updated.
data RemoveRecord k = RemoveRecord VCode
  deriving Typeable

$(deriveSafeCopy 0 'base ''RemoveRecord)

-- | The event acts on a @RecordPool k@ and returns nothing.
instance (Typeable k, SafeCopy k) => Method (RemoveRecord k) where
  type MethodResult (RemoveRecord k) = ()
  type MethodState (RemoveRecord k) = RecordPool k

-- | 'RemoveRecord' is an update event.
instance (Typeable k, SafeCopy k) => UpdateEvent (RemoveRecord k)
-- | Registers the three hand-written events with acid-state, pairing each
-- event constructor with the function that carries it out.
instance (Typeable k, SafeCopy k, Ord k) => IsAcidic (RecordPool k) where
  acidEvents =
    [ UpdateEvent runNew
    , QueryEvent runValidate
    , UpdateEvent runRemove
    ]
    where
      runNew (NewRecord key vcode etime) = newRecord key vcode etime
      runValidate (ValidateRecord vcode now) = validateRecord vcode now
      runRemove (RemoveRecord vcode) = removeRecord' vcode
-- Interfaces -- | |
-- | Handles to the three acid-state record pools.
data RecordPools = RecordPools
  { getNewAccountPool :: AcidState (RecordPool Email)
    -- ^ Pool keyed by 'Email', selected by 'NewAccountEmail'.
  , getResetPswdPool  :: AcidState (RecordPool Email)
    -- ^ Pool keyed by 'Email', selected by 'ResetPswdEmail'.
  , getCookiePool     :: AcidState (RecordPool AccountId)
    -- ^ Pool keyed by 'AccountId', selected by 'CookieAccountId'.
  }
-- | Types that name one of the pools in 'RecordPools' and may carry a key
-- for it.
class PoolType t where
  -- | Key type of the pool selected by @t@.
  type RecordKey t

  -- | Pick the pool that belongs to @t@ out of 'RecordPools'.
  getRecordPool :: t -> RecordPools -> AcidState (RecordPool (RecordKey t))

  -- | Extract the key wrapped by a value of @t@.
  getRecordKey :: t -> RecordKey t
-- | Selector for the new-account pool.
--
-- 'NewAccountEmail' wraps the 'Email' used as the record key;
-- 'NewAccountPool' names the pool without supplying a key.
data NewAccountEmail = NewAccountEmail Email | NewAccountPool

instance PoolType NewAccountEmail where
  type RecordKey NewAccountEmail = Email

  -- The argument only selects the pool; its contents are ignored.
  getRecordPool _ pools = getNewAccountPool pools

  getRecordKey (NewAccountEmail email) = email
  -- 'NewAccountPool' has no key. Fail with a descriptive message instead of
  -- an anonymous pattern-match failure.
  getRecordKey NewAccountPool =
    error "Data.RecordPool.getRecordKey: NewAccountPool carries no Email"
-- | Selector for the password-reset pool.
--
-- 'ResetPswdEmail' wraps the 'Email' used as the record key;
-- 'ResetPswdPool' names the pool without supplying a key.
data ResetPswdEmail = ResetPswdEmail Email | ResetPswdPool

instance PoolType ResetPswdEmail where
  type RecordKey ResetPswdEmail = Email

  -- The argument only selects the pool; its contents are ignored.
  getRecordPool _ pools = getResetPswdPool pools

  getRecordKey (ResetPswdEmail email) = email
  -- 'ResetPswdPool' has no key. Fail with a descriptive message instead of
  -- an anonymous pattern-match failure.
  getRecordKey ResetPswdPool =
    error "Data.RecordPool.getRecordKey: ResetPswdPool carries no Email"
-- | Selector for the cookie pool.
--
-- 'CookieAccountId' wraps the 'AccountId' used as the record key;
-- 'CookiePool' names the pool without supplying a key.
data CookieAccountId = CookieAccountId AccountId | CookiePool

instance PoolType CookieAccountId where
  type RecordKey CookieAccountId = AccountId

  -- The argument only selects the pool; its contents are ignored.
  getRecordPool _ pools = getCookiePool pools

  -- The binder is named 'accountId' so it does not shadow Prelude's 'id'.
  getRecordKey (CookieAccountId accountId) = accountId
  -- 'CookiePool' has no key. Fail with a descriptive message instead of an
  -- anonymous pattern-match failure.
  getRecordKey CookiePool =
    error "Data.RecordPool.getRecordKey: CookiePool carries no AccountId"
-- | Issue a fresh verification code for the key wrapped in @wrappedKey@.
--
-- The third argument is a time-to-live; it is converted to an absolute expiry
-- with 'expireIn'. The record is stored in the pool selected by @wrappedKey@,
-- and the newly generated code is returned.
insertNewRecord
  :: (PoolType k, SafeCopy (RecordKey k), Typeable (RecordKey k))
  => RecordPools -> k -> ExpireTime -> IO VCode
insertNewRecord pools wrappedKey ttl = do
  etime <- expireIn ttl
  vcode <- getNextVCode
  update' (getRecordPool wrappedKey pools) $
    NewRecord (getRecordKey wrappedKey) vcode etime
  return vcode
-- | Resolve a verification code to its key in the pool selected by @t@.
--
-- The current time is obtained as @expireIn (ExpireTime 0)@ and passed to the
-- 'ValidateRecord' query, so the result is whatever 'validateRecord' returns
-- for that code at this moment.
getKeyFromRecord
  :: (PoolType t, SafeCopy (RecordKey t), Typeable (RecordKey t))
  => RecordPools -> t -> VCode -> IO (Maybe (RecordKey t))
getKeyFromRecord pools t vcode =
  currentTime >>= query' pool . ValidateRecord vcode
  where
    pool = getRecordPool t pools
    currentTime = expireIn (ExpireTime 0)
-- | Delete the record for a verification code from the pool selected by @t@,
-- by running the 'RemoveRecord' update.
removeRecord
  :: (PoolType t, SafeCopy (RecordKey t), Typeable (RecordKey t))
  => RecordPools -> t -> VCode -> IO ()
removeRecord pools t vcode =
  update' (getRecordPool t pools) (RemoveRecord vcode)
-- | Write a cookie record for an account with a caller-supplied code and
-- expiry.
--
-- Unlike 'insertNewRecord', the expiry is stored as given (it is not passed
-- through 'expireIn') and no new code is generated.
renewCookie :: RecordPools -> AccountId -> VCode -> ExpireTime -> IO ()
renewCookie pools accountId vcode etime =
  update' (getCookiePool pools) (NewRecord accountId vcode etime)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment