-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.TxEvent.STM
-- Copyright   :  (c) Kevin Donnelly & Matthew Fluet 2006
-- License     :  BSD-style
-- Maintainer  :  mfluet@acm.org
-- Stability   :  experimental
-- Portability :  non-portable (requires TxEvent)
--
-- Software Transactional Memory
--
-- This library provides software transactional memory implemented on
-- top of the "Control.Concurrent.TxEvent" library.  The major
-- functional differences (other than efficiency) between this
-- implementation and that of "Control.Concurrent.STM" are (1) that
-- 'orElse' is unbiased, (2) that 'newTVar' returns an 'IO' action,
-- rather than an 'STM' action, and (3) that an uncaught exception
-- effectively behaves as 'retry'.
--
-----------------------------------------------------------------------------

module Control.Concurrent.TxEvent.STM (
  -- * The STM monad and basic operations
    STM         -- abstract, instance Functor, Monad, MonadPlus
  , atomically  -- :: STM a -> IO a
  , retry       -- :: STM a
  , orElse      -- :: STM a -> STM a -> STM a
  , check       -- :: Bool -> STM a
  , catchSTM    -- :: STM a -> (Exception -> STM a) -> STM a

  -- * TVars
  , TVar        -- abstract, instance Eq
  , newTVar     -- :: a -> IO (TVar a)
  , readTVar    -- :: TVar a -> STM a
  , writeTVar   -- :: TVar a -> a -> STM ()

  -- *  Miscellaneous
  , stmToEvt    -- :: STM a -> Evt a
  ) where

import Control.Monad
import qualified Control.Exception as Exception
import Control.Exception (Exception)

import Control.Concurrent
import Control.Concurrent.TxEvent

-- Misc
-- | Fork the given action in a new thread and discard the resulting
-- 'ThreadId'.
forkIO_ :: IO () -> IO ()
forkIO_ body = do
  _ <- forkIO body
  return ()

----------------------------------------------------------------------
----------------------------------------------------------------------

-- | A transaction, represented as a wrapper around a transactional
-- event.  'unSTM' projects out the wrapped 'Evt'.
newtype STM a = STM { unSTM :: Evt a }

-- | Run a transaction in 'IO' by synchronising on its underlying event.
atomically :: STM a -> IO a
atomically = sync . unSTM

-- Sequencing of transactions is delegated to the event combinators:
-- 'return' wraps 'alwaysEvt' and '>>=' wraps 'thenEvt'.
instance Monad STM where
  return = STM . alwaysEvt
  STM evt >>= k = STM (evt `thenEvt` (unSTM . k))
  first >> second = first >>= const second
-- 'mzero' is the event that never fires; 'mplus' is the choice between
-- the two underlying events.
instance MonadPlus STM where
  mzero = STM neverEvt
  STM left `mplus` STM right = STM (left `chooseEvt` right)
-- Mapping is defined through the 'Monad' instance: run the transaction
-- and return the transformed result.
instance Functor STM where
  fmap f stm = do
    x <- stm
    return (f x)

-- | The transaction that never succeeds: it wraps 'neverEvt', which is
-- also the 'mzero' of the 'MonadPlus' instance.
retry :: STM a
retry = STM neverEvt

-- | Choice between two transactions; this is 'mplus' for 'STM'.
orElse :: STM a -> STM a -> STM a
first `orElse` second = first `mplus` second

-- | Proceed when the condition holds; otherwise behave as 'retry'.
--
-- The result type is fully polymorphic, so there is no meaningful value
-- to return when the condition holds.  The returned value is a bottom
-- and must never be forced: use 'check' only for its effect, as in
-- @check b >> rest@.  Forcing it raises an error that names this
-- function instead of an anonymous @Prelude.undefined@.
check :: Bool -> STM a
check b
  | b = return (error "Control.Concurrent.TxEvent.STM.check: the result of check must not be forced")
  | otherwise = retry

-- | Exception handling inside a transaction, implemented by 'catchEvt'
-- on the underlying events: the handler is applied to the exception and
-- its resulting transaction is unwrapped to an event.
catchSTM :: STM a -> (Exception -> STM a) -> STM a
catchSTM (STM body) handler = STM (body `catchEvt` (unSTM . handler))


-- | A transactional variable, represented by three synchronous
-- channels: one carrying the 'ThreadId' of the client, one that
-- 'readTVar' receives from, and one that 'writeTVar' sends to.
newtype TVar a = TVar (SChan ThreadId, SChan a, SChan a)

-- Two variables are equal exactly when all three of their channels are
-- pairwise equal.
instance Eq (TVar a) where
  TVar (tidA, readA, writeA) == TVar (tidB, readB, writeB) =
    and [tidA == tidB, readA == readB, writeA == writeB]

-- | Create a transactional variable holding the initial value @x@.
--
-- The variable is implemented by a forked server thread that owns the
-- current value, plus three channels shared with clients:
--
--   * @tch@: a client sends its 'ThreadId' here before every operation;
--
--   * @rch@: the server sends the current value here ('readTVar'
--     receives on this channel);
--
--   * @wch@: the server receives a replacement value here ('writeTVar'
--     sends on this channel).
--
-- The server must send on @rch@ and receive on @wch@.  Using the same
-- direction as the clients (receiving on @rch@, sending on @wch@) would
-- leave every read and write without a matching partner.
newTVar :: a -> IO (TVar a)
newTVar x = do
  tch <- sync newSChan
  rch <- sync newSChan
  wch <- sync newSChan
  -- Serve one client operation against the value @cur@: learn which
  -- thread is asking, then either hand out @cur@ to a reader (value
  -- unchanged) or accept a new value from a writer.
  let serve cur = do
        tid <- recvEvt tch
        next <- (do sendEvt rch cur
                    alwaysEvt cur)
                `chooseEvt`
                (recvEvt wch)
        return (tid, next)
  -- After the first operation by thread @owner@, either keep serving
  -- operations announced with that same thread id, or stop and yield
  -- the value reached so far.  An operation announced with a different
  -- thread id makes this alternative 'neverEvt'.
  let loopEvt owner cur =
        (do (tid, next) <- serve cur
            if owner /= tid
               then neverEvt
               else loopEvt tid next)
        `chooseEvt`
        (alwaysEvt cur)
  -- Each iteration is a single 'sync' that serves at least one
  -- operation; the value it yields is the state for the next iteration.
  let loopIO cur = do
        final <- sync $ do
                   (tid, next) <- serve cur
                   loopEvt tid next
        loopIO final
  forkIO_ $ loopIO x
  return (TVar (tch, rch, wch))

-- | Read a variable: send the current thread's id on the variable's
-- first channel, then receive the value from its second channel.  The
-- third channel is not used here.
readTVar :: TVar a -> STM a
readTVar (TVar (tidCh, readCh, _)) =
    STM (myThreadIdEvt >>= \ self -> sendEvt tidCh self >> recvEvt readCh)

-- | Write a variable: send the current thread's id on the variable's
-- first channel, then send the new value on its third channel.  The
-- second channel is not used here.
writeTVar :: TVar a -> a -> STM ()
writeTVar (TVar (tidCh, _, writeCh)) val =
    STM (myThreadIdEvt >>= \ self -> sendEvt tidCh self >> sendEvt writeCh val)

{-|
Expose the transactional event underlying a transaction.  Perfectly
safe: it only removes the 'STM' wrapper.
-}
stmToEvt :: STM a -> Evt a
stmToEvt (STM evt) = evt
