-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.TxEvent
-- Copyright   :  (c) Kevin Donnelly & Matthew Fluet 2006
-- License     :  BSD-style
-- Maintainer  :  mfluet@acm.org
-- Stability   :  experimental
-- Portability :  non-portable (requires STM)
--
-- Transactional Events.
-- 
-- This library provides /first-class synchronous events/ in the style
-- of CML (<http://cml.cs.uchicago.edu/>), but generalizes the concept
-- to allow /multiple/, /dependent/ events to be combined in a single
-- event.  The semantics of these generalized events ensures an
-- /all-or-nothing/ transactional property -- either all of the
-- constituent events synchronize or none of them synchronize.  When
-- the constituent events include synchronous message passing, the
-- semantics ensures that no thread is able to complete its
-- synchronization until all of its (transitive) communications
-- partners are willing to commit to a compatible synchronization.
--
-----------------------------------------------------------------------------

module Control.Concurrent.TxEvent (
  -- * Event type
    Evt           -- abstract, instance Functor, Monad, MonadPlus

  -- * Synchronization
  , sync          -- :: Evt a -> IO a

  -- * Monadic event combinators
  , alwaysEvt     -- :: a -> Evt a
  , thenEvt       -- :: Evt a -> (a -> Evt b) -> Evt b
  , neverEvt      -- :: Evt a
  , chooseEvt     -- :: Evt a -> Evt a -> Evt a

  -- * Exceptions
  , throwEvt      -- :: Exception -> Evt a
  , catchEvt      -- :: Evt a -> (Exception -> Evt a) -> Evt a

  -- * Synchronous channels
  , SChan         -- abstract, instance Eq
  , newSChan      -- :: Evt (SChan a)
  , sendEvt       -- :: SChan a -> a -> Evt ()
  , recvEvt       -- :: SChan a -> Evt a

  -- * Time delays
  , timeOutEvt    -- :: Int -> Evt ()
  , timeDiffEvt   -- :: System.Time.TimeDiff -> Evt ()
  , clockTimeEvt  -- :: System.Time.ClockTime -> Evt ()

  -- *  Miscellaneous
  , myThreadIdEvt -- :: Evt ThreadId

  ) where

import Prelude
import Control.Monad
import qualified Control.Exception as Exception
import Control.Exception (Exception)
import Control.Concurrent
import Control.Concurrent.STM
import Control.Concurrent.STM.TVar
import Control.Concurrent.STM.TMVar

import qualified Data.List as List
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Maybe as Maybe

import qualified System.Time as Time
import System.Time (ClockTime,TimeDiff)

import qualified Debug.Trace as Trace

-- Miscellaneous helpers.

-- | Run an action in a new thread, discarding the new thread's 'ThreadId'.
forkIO_ :: IO () -> IO ()
forkIO_ act = do
  _ <- forkIO act
  return ()

-- | Raise an 'Exception.AssertionFailed' exception carrying the given
-- message.  The result type is fully polymorphic because the function
-- never returns normally.
throwAssertionFailed :: String -> a
throwAssertionFailed = Exception.throw . Exception.AssertionFailed

-- | Convert a 'TimeDiff' to a whole number of microseconds (truncated
-- toward zero), suitable for passing to 'threadDelay'.
--
-- Every field of the 'TimeDiff' contributes, not just 'Time.tdPicosec'.
-- 'Time.diffClockTimes' stores whole seconds in 'Time.tdSec' and only the
-- (possibly negative) sub-second remainder in 'Time.tdPicosec'.  Ignoring
-- the other fields would therefore shrink any delay of one second or more
-- to a fraction of a second.
--
-- Months count as 30 days and years as 365 days, the convention used by
-- old-time's 'Time.normalizeTimeDiff'.  The result is clamped to the
-- range of 'Int' so that very long delays do not wrap around.
timeDiffToMicroSeconds :: TimeDiff -> Int
timeDiffToMicroSeconds td =
    fromInteger (max lo (min hi totalMicros))
  where
    lo = toInteger (minBound :: Int)
    hi = toInteger (maxBound :: Int)
    days = toInteger (Time.tdDay td)
         + 30 * toInteger (Time.tdMonth td)
         + 365 * toInteger (Time.tdYear td)
    hours = toInteger (Time.tdHour td) + 24 * days
    minutes = toInteger (Time.tdMin td) + 60 * hours
    seconds = toInteger (Time.tdSec td) + 60 * minutes
    -- tdPicosec is already an Integer; there are 10^6 picoseconds per
    -- microsecond and 10^12 per second.
    totalMicros = (seconds * 1000000000000 + Time.tdPicosec td) `quot` 1000000

----------------------------------------------------------------------
----------------------------------------------------------------------

-- | Enable tracing through 'debugMsg'.
debugEvt :: Bool
debugEvt = False

-- | Emit a trace message, but only when 'debugEvt' is set.
debugMsg :: String -> IO ()
debugMsg msg = when debugEvt (Trace.putTraceMsg msg)

-- | Maintain the per-synchronization search-thread counters (and fork a
-- watcher for them in 'sync').
doCounts :: Bool
doCounts = False

-- | Run the internal consistency checks on event states and commit sets.
doChecks :: Bool
doChecks = False

-- | Yield after spawning the search threads for each matched
-- communication.
doYields :: Bool
doYields = True

-- Global integer counters;
-- Each thread synchronization gets its own counter of running search
-- threads.  The counter starts at 1, accounting for the initial search
-- thread forked by 'sync'.  It is only maintained when 'doCounts' is set.
type CountVar = TVar Int

newCountVar :: STM CountVar
newCountVar = newTVar 1

-- | Increment a counter (no-op unless 'doCounts').
incCount :: CountVar -> IO ()
incCount cnt = when doCounts $ atomically $ readTVar cnt >>= writeTVar cnt . (+ 1)

-- | Decrement a counter (no-op unless 'doCounts').
decCount :: CountVar -> IO ()
decCount cnt = when doCounts $ atomically $ readTVar cnt >>= writeTVar cnt . subtract 1

-- | Debugging aid meant to be forked: trace every change of the counter,
-- tagging each message with the given 'ThreadId'.  Loops forever.
watchCount :: ThreadId -> CountVar -> IO ()
watchCount tid cnt = go (-1)
  where
    go prev = do
      now <- atomically (awaitChange prev)
      debugMsg (show tid ++ ": cnt -> " ++ show now)
      go now
    -- Block (via 'retry') until the counter differs from 'prev'.
    awaitChange prev = do
      v <- readTVar cnt
      when (v == prev) retry
      return v


-- Global boolean flags;
-- Flags are allocated False, and this module only ever sets them to True.
-- A fresh flag is allocated for
--   a) each thread synchronization: set once the thread has synchronized
--      (or once 'sync' was interrupted by an exception);
--   b) each complete path: set once that path is chosen for commitment.
type BoolFlag = TVar Bool

newBoolFlag :: STM BoolFlag
newBoolFlag = newTVar False

-- Placeholder rendering so that values containing TVars can be shown
-- while debugging; the TVar's contents are not inspected.
instance Show (TVar a) where
  show _ = "TVar"

-- Event synchronization paths;
-- A 'path' records, most recent element first, the non-deterministic
-- choices, exception handlers, and communications taken by one search
-- thread during an event synchronization.
--
-- Each communication element records:
--   * the communication partner's trail at the time of the communication,
--     together with the partner side's list of completed searches; and
--   * this side's list of completed searches.
-- A completed-search entry pairs a complete path with the flag that is
-- set when that path is chosen for commitment.
type CompletedSearch = (Path, BoolFlag)
type CompletedSearchList = TVar [CompletedSearch]

newCompletedSearchList :: STM CompletedSearchList
newCompletedSearchList = newTVar []

data ChooseElement
  = ChooseLeft
  | ChooseRight
  deriving (Eq, Show)

data CommElement
  = CommSend
  | CommRecv
  deriving (Eq, Show)

data PathElement
  = Choose ChooseElement
    -- ^ a branch taken in 'chooseEvt'
  | Catch
    -- ^ entry into a 'catchEvt' handler
  | Comm CommElement (Trail, CompletedSearchList) CompletedSearchList
    -- ^ a communication: its direction, the partner's trail and
    -- completed-search list, and this side's completed-search list

-- Completed-search lists are allocated afresh for every communication,
-- so equal partner lists imply equal partner trails.  Comparing the
-- lists (cheap TVar identity) therefore suffices.
instance Eq PathElement where
  Choose a == Choose b = a == b
  Catch == Catch = True
  Comm dir (_, theirs) ours == Comm dir' (_, theirs') ours' =
    and [dir == dir', theirs == theirs', ours == ours']
  _ == _ = False

instance Show PathElement where
  show (Choose c) = show c
  show Catch = "Catch"
  show (Comm dir (partner, _) _) = unwords [show dir, show partner]

chooseLeft, chooseRight :: PathElement
chooseLeft = Choose ChooseLeft
chooseRight = Choose ChooseRight

commSend, commRecv
  :: (Trail, CompletedSearchList) -> CompletedSearchList -> PathElement
commSend = Comm CommSend
commRecv = Comm CommRecv

type Path = [PathElement]

-- Event synchronization trails;
-- A 'trail' is an in-progress path, tagged with the id of the
-- synchronizing thread and the flag of that thread synchronization.
-- Equality compares all three components.
newtype Trail = Trail (ThreadId, BoolFlag, Path)
  deriving Eq

-- Shows the thread id and path only.
instance Show Trail where
  show (Trail (tid, _, path)) = concat ["(", show tid, ",", show path, ")"]

-- Event synchronization tracks;
-- A 'track' is a completed path, tagged with the id of the synchronizing
-- thread, the flag of that thread synchronization, and the flag of this
-- particular complete path.  Equality compares all four components.
newtype Track = Track (ThreadId, BoolFlag, Path, BoolFlag)
  deriving Eq

-- Shows the thread id and path only.
instance Show Track where
  show (Track (tid, _, path, _)) = concat ["(", show tid, ",", show path, ")"]


-- Event dependency maps;
-- A 'dependency trail map' records, for each thread among the
-- (transitive) dependencies of a trail, the maximal (most extended)
-- trail of that thread.  A 'dependency track map' records the same
-- information for tracks.
type DepTrailMap = Map ThreadId Trail
type DepTrackMap = Map ThreadId Track

-- Event synchronization state;
-- Holds the current trail, its dependency trail map, and the counter of
-- active search threads for the synchronization.
newtype EvtState = EvtState (Trail, DepTrailMap, CountVar)

-- Shows only the trail; the dependency map and counter are omitted.
instance Show EvtState where
  show (EvtState (trail, _, _)) = "(" ++ show trail ++ ")"


-- Synchronous event computation;
-- Events are written in continuation-passing style.  An event receives
-- a continuation and the current synchronization state, and passes its
-- result -- either a value ('Always') or an exception ('Throw') -- to the
-- continuation.  It may call the continuation zero times (e.g.
-- 'neverEvt'), or once from each of several search threads (e.g.
-- 'chooseEvt').
data EvtRes a = Always a | Throw Exception
type EvtCont a = EvtRes a -> EvtState -> IO ()
newtype Evt a = Evt (EvtCont a -> EvtState -> IO ())
{- ^
A value of type @'Evt' a@ is an /event/.  When synchronized upon, it
performs some synchronous operations and then returns a value of type
@a@.

During synchronization the event may perform (tentative) synchronous
message passing.  It exhibits no observable effects (i.e., does not
return) until the event and all of its communication partners can
commit.

'Evt' is a monad (with plus), so 'Evt' actions can be combined using
either the @do@-notation or the operations from the 'Monad' and
'MonadPlus' classes.
-}

-- Unwrap an event to its underlying continuation-consuming function.
unEvt :: Evt a -> EvtCont a -> EvtState -> IO ()
unEvt (Evt run) = run

-- Run an event after forcing it to weak head normal form.  If forcing
-- the event raises an exception, the exception is not allowed to escape.
-- Instead it is delivered to the continuation as a 'Throw' result.
-- Exceptions raised while the event then runs are not caught here.
forceEvt :: Evt a -> EvtCont a -> EvtState -> IO ()
forceEvt evt k s = do
  run <- Exception.evaluate (unEvt evt) `Exception.catch` asThrow
  run k s
  where
    -- Stand-in event body that reports the exception to its continuation.
    asThrow exn = return (\ k' s' -> k' (Throw exn) s')

-- Decide whether a search thread may continue.  The thread is abandoned
-- ("fizzled"), decrementing the search counter, if
--  a) this synchronization's own flag is already set (the
--     synchronization completed via another search thread, or 'sync' was
--     interrupted), or
--  b) the flag of any trail in the dependency map is set, i.e. a
--     depended-upon communication partner has already synchronized.
-- Otherwise the given action runs.  The string argument labels the
-- calling combinator; it is currently unused.
fizzleEvt :: String -> EvtState -> IO () -> IO ()
fizzleEvt _label s@(EvtState (Trail (_, ownFlag, _), deps, cnt)) act = do
  checkEvtState s
  stale <- anySet (ownFlag : [ flag | Trail (_, flag, _) <- Map.elems deps ])
  if stale then decCount cnt else act
  where
    -- Read the flags one transaction at a time, in order, stopping at
    -- the first flag that is set.
    anySet [] = return False
    anySet (flag : rest) = do
      set <- atomically (readTVar flag)
      if set then return True else anySet rest

{-|
The event that is always commitable, yielding the given value;
the 'return' of the 'Evt' monad.
-}
alwaysEvt :: a -> Evt a
alwaysEvt x = Evt body
  where
    body k s = fizzleEvt "alwaysEvt" s (k (Always x) s)

{-|
Sequential composition of event computations;
the '>>=' of the 'Evt' monad.  If the first event produces an
exception, the second event is skipped and the exception is passed
straight on to the continuation.
-}
thenEvt :: Evt a -> (a -> Evt b) -> Evt b
thenEvt evt f = Evt body
  where
    body k s = fizzleEvt "thenEvt" s (forceEvt evt (andThen k) s)
    -- Continuation for the first event: feed a value into 'f', or
    -- forward an exception unchanged.
    andThen k (Always v) s' = forceEvt (f v) k s'
    andThen k (Throw exn) s' = k (Throw exn) s'

{-|
The event that is never commitable; the 'mzero' of the 'Evt' monad.
The search thread just ends, decrementing the search counter.
-}
neverEvt :: Evt a
neverEvt = Evt (\ _ s@(EvtState (_, _, cnt)) -> fizzleEvt "neverEvt" s (decCount cnt))

{-|
Non-deterministic composition of event computations;
the 'mplus' of the 'Evt' monad.  Each alternative is explored in its own
forked search thread, and the choice is recorded on that thread's path.
The current search thread then ends.
-}
chooseEvt :: Evt a -> Evt a -> Evt a
chooseEvt evt1 evt2 = Evt body
  where
    body k s@(EvtState (Trail (tid, bft, p), dm, cnt)) =
      fizzleEvt "chooseEvt" s $ do
        explore chooseLeft evt1
        explore chooseRight evt2
        decCount cnt
      where
        -- Fork a search thread that follows one alternative, with the
        -- trail (and this thread's dependency-map entry) extended by
        -- the choice.
        explore choice evt = forkIO_ $ do
          incCount cnt
          let trail = Trail (tid, bft, choice : p)
          forceEvt evt k (EvtState (trail, Map.insert tid trail dm, cnt))

-- 'Evt' is a monad with plus, defined directly in terms of 'alwaysEvt',
-- 'thenEvt', 'neverEvt' and 'chooseEvt'.
instance Functor Evt where
  {-# INLINE fmap #-}
  fmap f evt = thenEvt evt (alwaysEvt . f)

instance Monad Evt where
  {-# INLINE return #-}
  {-# INLINE (>>=)  #-}
  {-# INLINE (>>)   #-}
  return = alwaysEvt
  (>>=) = thenEvt
  evt >> evt' = thenEvt evt (\ _ -> evt')

instance MonadPlus Evt where
  {-# INLINE mzero #-}
  {-# INLINE mplus #-}
  mzero = neverEvt
  mplus = chooseEvt

{-|
A variant of `Exception.throw` that can be used within the 'Evt' monad.
The exception is passed to the continuation as a 'Throw' result, where
an enclosing 'catchEvt' can handle it.
-}
throwEvt :: Exception -> Evt a
throwEvt exn = Evt (\ k s -> fizzleEvt "throwEvt" s (k (Throw exn) s))

{-|
Exception handling within event computations.  If the event yields an
exception (from 'throwEvt', or raised while the event is being forced),
the handler's event runs in its place.  A 'Catch' marker is recorded on
the path.  Values pass through unchanged.
-}
catchEvt :: Evt a -> (Exception -> Evt a) -> Evt a
catchEvt evt f = Evt body
  where
    body k s = fizzleEvt "catchEvt" s (forceEvt evt (handled k) s)
    -- Continuation for the protected event.
    handled k res st@(EvtState (Trail (tid, bft, p), dm, cnt)) =
      case res of
        Always v -> k (Always v) st
        Throw exn ->
          let trail = Trail (tid, bft, Catch : p)
          in forceEvt (f exn) k (EvtState (trail, Map.insert tid trail dm, cnt))

-- Synchronous channels;
-- A channel holds a list of pending senders and a list of pending
-- receivers.  A sender entry records the sender's event state, the value
-- offered, and a continuation that spawns a search thread when given the
-- details of a matching receive.  A receiver entry records the
-- receiver's event state and a continuation that spawns a search thread
-- when given the sent value and the details of a matching send.
type Sender a =
    ( EvtState
    , a
    , CompletedSearchList -> DepTrailMap -> EvtState
      -> CompletedSearchList -> IO ()
    )
type SendList a = [Sender a]

type Recver a =
    ( EvtState
    , a -> CompletedSearchList -> DepTrailMap -> EvtState
        -> CompletedSearchList -> IO ()
    )
type RecvList a = [Recver a]

{-|
A 'SChan' is a synchronous channel, used for communication between
concurrent threads.  Message passing is synchronous: both the sender
and the receiver must be ready to communicate before either can
proceed.
-}
newtype SChan a = SChan (TVar (SendList a), TVar (RecvList a))
  -- Two channels are equal exactly when they share both underlying TVars.
  deriving Eq
-- Remove from a send list every entry whose sending thread has already
-- synchronized (its synchronization flag is set).  Writes the pruned
-- list back and returns it.
--
-- 'filterM' keeps the surviving entries in their existing order (newest
-- first, since 'sendEvt' prepends new senders).  A left fold with cons
-- would instead reverse the list on every cleaning, making the matching
-- order depend on how often the list had been cleaned.
cleanSendList :: TVar (SendList a) -> STM (SendList a)
cleanSendList sl = do
  entries <- readTVar sl
  live <- filterM isLive entries
  writeTVar sl live
  return live
  where
    isLive (EvtState (Trail (_, bft, _), _, _), _, _) = liftM not (readTVar bft)

-- Remove from a recv list every entry whose receiving thread has
-- already synchronized (its synchronization flag is set).  Writes the
-- pruned list back and returns it.
--
-- 'filterM' keeps the surviving entries in their existing order (newest
-- first, since 'recvEvt' prepends new receivers).  A left fold with cons
-- would instead reverse the list on every cleaning, making the matching
-- order depend on how often the list had been cleaned.
cleanRecvList :: TVar (RecvList a) -> STM (RecvList a)
cleanRecvList rl = do
  entries <- readTVar rl
  live <- filterM isLive entries
  writeTVar rl live
  return live
  where
    isLive (EvtState (Trail (_, bft, _), _, _), _) = liftM not (readTVar bft)

{-|
Create a new synchronous channel with empty send and receive lists.
-}
newSChan :: Evt (SChan a)
newSChan = Evt body
  where
    body k s = fizzleEvt "newSChan" s $ do
      chan <- atomically $ do
        senders <- newTVar []
        receivers <- newTVar []
        return (SChan (senders, receivers))
      k (Always chan) s

{-|
Send a value on the channel.

When synchronized upon, the sender registers itself on the channel's
send list.  For every receiver currently on the recv list whose event
state is 'coherent' with the sender's, a pair of search threads is
spawned: one continues the sender, the other continues the receiver.
Receivers that register later find this sender on the send list and
perform the matching from their side.  The current search thread then
ends.
-}
sendEvt :: SChan a -> a -> Evt ()
sendEvt (SChan (sl, rl)) x = 
    Evt (\ k s@(EvtState (t@(Trail(tid, bft, p)), dm, cnt)) ->
             let sendEvtAct = do
                   -- sendCont pcs dmU s' pcs' forks a search thread that
                   -- continues this send after matching the receiver whose
                   -- state is s'.  pcs is the completed-search list for
                   -- this (sending) side of the communication, pcs' is the
                   -- one for the receiver's side, and dmU is the merged
                   -- dependency map returned by 'coherent'.  Both trails
                   -- are extended with the communication and recorded in
                   -- the new dependency map.
                   let sendCont pcs dmU (EvtState (t'@(Trail (tid', bft', p')), _, _)) 
                                pcs' = 
                           forkIO_ $ do
                             incCount cnt
                             let pRecv' = ((commRecv (t, pcs) pcs'):p')
                                 tRecv' = Trail (tid', bft', pRecv')
                                 pSend = ((commSend (t', pcs') pcs):p)
                                 tSend = Trail (tid, bft, pSend)
                                 dmU' = Map.insert tid tSend (Map.insert tid' tRecv' dmU)
                             k (Always ()) (EvtState (tSend, dmU', cnt))
                   -- Note that by atomically adding ourselves to the
                   -- send list and taking a copy of the recv list,
                   -- we send to all receivers already on the recv
                   -- list and are on the send list for all future
                   -- receivers.  Both lists are pruned of synchronized
                   -- threads in the same transaction.
                   rl' <- atomically $ do
                            sl' <- cleanSendList sl
                            writeTVar sl ((s, x, sendCont) : sl')
                            rl' <- cleanRecvList rl
                            return rl'
                   -- Spawn search threads for all matching communications.
                   -- Receivers whose state is not coherent with ours
                   -- (including our own thread's receives) are skipped.
                   mapM_ (\ (s', recvCont') -> do
                            let mdmU = s `coherent` s'
                            case mdmU of
                              Just dmU -> do -- Allocate completed search lists
                                             pcs <- atomically newCompletedSearchList
                                             pcs' <- atomically newCompletedSearchList
                                             -- The receiver's continuation gets the
                                             -- same lists with the roles swapped.
                                             sendCont pcs dmU s' pcs'
                                             recvCont' x pcs' dmU s pcs
                                             when doYields yield
                                             return ()
                              Nothing -> return ())
                         rl'
                   decCount cnt
                   return () in
             fizzleEvt "sendEvt" s sendEvtAct)

{-|
Receive a value on the channel.

When synchronized upon, the receiver registers itself on the channel's
recv list.  For every sender currently on the send list whose event
state is 'coherent' with the receiver's, a pair of search threads is
spawned: one continues the receiver with the sent value, the other
continues the sender.  Senders that register later find this receiver
on the recv list and perform the matching from their side.  The current
search thread then ends.
-}
recvEvt :: SChan a -> Evt a
recvEvt (SChan (sl, rl)) = 
    Evt (\ k s@(EvtState (t@(Trail(tid, bft, p)), dm, cnt)) ->
             let recvEvtAct = do
                   -- recvCont x pcs dmU s' pcs' forks a search thread that
                   -- continues this receive with value x after matching
                   -- the sender whose state is s'.  pcs is the
                   -- completed-search list for this (receiving) side,
                   -- pcs' is the one for the sender's side, and dmU is the
                   -- merged dependency map returned by 'coherent'.
                   let recvCont x pcs dmU (EvtState (t'@(Trail (tid', bft', p')), _, _)) 
                                  pcs' =
                           forkIO_ $ do
                             incCount cnt 
                             let pSend' = ((commSend (t, pcs) pcs'):p')
                                 tSend' = Trail (tid', bft', pSend')
                                 pRecv = ((commRecv (t', pcs') pcs):p)
                                 tRecv = Trail (tid, bft, pRecv)
                                 dmU' = Map.insert tid tRecv (Map.insert tid' tSend' dmU)
                             k (Always x) (EvtState (tRecv, dmU', cnt))
                   -- Note that by atomically adding ourselves to the
                   -- recv list and taking a copy of the send list,
                   -- we recv from all senders already on the send
                   -- list and are on the recv list for all future
                   -- senders.  Both lists are pruned of synchronized
                   -- threads in the same transaction.
                   sl' <- atomically $ do
                            rl' <- cleanRecvList rl
                            writeTVar rl ((s, recvCont) : rl')
                            sl' <- cleanSendList sl
                            return sl'
                   -- Spawn search threads for all matching communications.
                   -- Senders whose state is not coherent with ours
                   -- (including our own thread's sends) are skipped.
                   mapM_ (\ (s', x', sendCont') -> do
                            let mdmU = s `coherent` s'
                            case mdmU of
                              Just dmU -> do -- Allocate completed search lists
                                             pcs <- atomically newCompletedSearchList
                                             pcs' <- atomically newCompletedSearchList
                                             -- The sender's continuation gets the
                                             -- same lists with the roles swapped.
                                             recvCont x' pcs dmU s' pcs'
                                             sendCont' pcs' dmU s pcs
                                             when doYields yield
                                             return ()
                              Nothing -> return ())
                         sl'
                   decCount cnt
                   return () in
             fizzleEvt "recvEvt" s recvEvtAct)

-- Coherence of event states for communication.
-- Two search threads may communicate only if they belong to different
-- synchronizing threads and their dependencies are compatible:
--  * if either state's dependency map has an entry for the other's
--    thread, that entry must carry the same synchronization flag and a
--    path that the other's current path extends; and
--  * their dependency maps must be coherent ('coherentDepTrailMaps').
-- On success the merged dependency map is returned.
coherent :: EvtState -> EvtState -> Maybe DepTrailMap
coherent (EvtState (Trail (tid1, bft1, p1), dm1, _))
         (EvtState (Trail (tid2, bft2, p2), dm2, _))
  | tid1 == tid2 = Nothing
  | extendsEntry tid1 bft1 p1 dm2 && extendsEntry tid2 bft2 p2 dm1 =
      coherentDepTrailMaps dm1 dm2
  | otherwise = Nothing
  where
    -- Does the current path p of thread tid (with flag bft) extend the
    -- path recorded for tid in dm?  Trivially true when dm has no entry
    -- for tid.
    extendsEntry tid bft p dm =
      case Map.lookup tid dm of
        Nothing -> True
        Just (Trail (_, bft', p')) -> bft == bft' && List.isSuffixOf p' p

-- WISHLIST: 
-- Data.Map.unionWithKeyPartial :: Ord k =>
--                                 (k -> a -> a -> Maybe a) ->
--                                 Map k a ->
--                                 Map k a ->
--                                 Maybe (Map k a)

-- Merge two dependency trail maps, or return Nothing if they conflict.
-- Threads present in only one map are copied as-is.  For a thread
-- present in both maps, the two trails must carry the same
-- synchronization flag and one path must extend the other.  The merged
-- map keeps the more extended trail.  (The WISHLIST function above would
-- express this directly.)
--
-- Implementation: fold over the entries of dm1, starting from Just dm2,
-- and short-circuit to Nothing once a conflict is found.
coherentDepTrailMaps :: DepTrailMap -> DepTrailMap -> Maybe DepTrailMap
coherentDepTrailMaps dm1 dm2 =
    Map.fold (\ t1@(Trail (tid1, bft1, p1)) mm ->
                  case mm of 
                    Nothing -> Nothing
                    Just m -> case Map.lookup tid1 m of
                                Nothing -> Just (Map.insert tid1 t1 m)
                                -- tid1 is in both maps (m's entry
                                -- comes from dm2).  The flags must
                                -- agree and one path must extend the
                                -- other.
                                -- If dm2's path is the more extended
                                -- (p1 is a suffix of p2), keep m.
                                -- If dm1's path is the more extended,
                                -- replace the entry with t1.
                                -- Otherwise the paths diverge.
                                Just (Trail (_, bft2, p2)) ->
                                    if bft1 /= bft2
                                       then Nothing
                                    else if List.isSuffixOf p1 p2
                                       then Just m
                                    else if List.isSuffixOf p2 p1
                                       then Just (Map.insert tid1 t1 m)
                                    else Nothing)
             (Just dm2) dm1
-- NOTE(review): the block comment below is an earlier, unused
-- formulation kept for reference.  It would not compile as written: it
-- refers to an undefined `cap` (presumably `common`), and it binds the
-- Trail fields as (bft, _, p), whereas Trail is (ThreadId, BoolFlag, Path).
{--
    let -- If there is a common depended upon thread,
        -- then the path in one dependency map must be an extension
        -- of the path in the other dependency map.
        -- If the path in thread 2's dependency map is maximal,
        -- then it becomes the member of the common dependency map.
        -- If the path in thread 1's dependency map is maximal,
        -- then it becomes the member of the common dependency map.
        common = Map.intersectionWith
                 (\ t1@(Trail (bft1, _, p1)) t2@(Trail (bft2, _, p2)) ->
                      if bft1 /= bft2
                         then Nothing
                      else if List.isSuffixOf p1 p2
                         then Just t2
                      else if List.isSuffixOf p2 p1
                         then Just t1
                      else Nothing)
                 dm1 dm2
        b = Map.fold (\ m -> ((Maybe.isJust m) &&)) True cap in
    if b
       then let common' = Map.map Maybe.fromJust common in
            Just (common' `Map.union` (dm1 `Map.union` dm2))
       else Nothing
--}

{-|
Create an event that becomes commitable the given number of
microseconds after synchronization.  A non-positive delay makes the
event commitable immediately.
-}
timeOutEvt :: Int -> Evt ()
timeOutEvt us = Evt body
  where
    body k s = fizzleEvt "timeOutEvt" s (fire k s)
    fire k s
      | us <= 0 = k (Always ()) s
      | otherwise = do
          threadDelay us
          -- Re-check after sleeping: the synchronization, or a
          -- depended-upon partner, may have committed in the meantime.
          fizzleEvt "timeOutEvt" s (k (Always ()) s)

{-|
Create an event that becomes commitable the specified time interval
after synchronization.  An interval that compares at or below
'Time.noTimeDiff' makes the event commitable immediately.
-}
timeDiffEvt :: TimeDiff -> Evt ()
timeDiffEvt td = Evt body
  where
    body k s = fizzleEvt "timeDiffEvt" s (fire k s)
    -- NOTE(review): if TimeDiff's Ord instance is the derived,
    -- field-by-field one, a TimeDiff with mixed-sign fields may be
    -- classified differently from its total duration -- confirm.
    fire k s
      | td <= Time.noTimeDiff = k (Always ()) s
      | otherwise = do
          threadDelay (timeDiffToMicroSeconds td)
          -- Re-check after sleeping: the synchronization, or a
          -- depended-upon partner, may have committed in the meantime.
          fizzleEvt "timeDiffEvt" s (k (Always ()) s)

{-|
Create an event that becomes commitable at the specified time.  If
that time has already been reached when the event runs, it is
commitable immediately.
-}
clockTimeEvt :: ClockTime -> Evt ()
clockTimeEvt tm = Evt body
  where
    body k s = fizzleEvt "clockTimeEvt" s (fire k s)
    fire k s = do
      now <- Time.getClockTime
      if tm <= now
        then k (Always ()) s
        else do
          threadDelay (timeDiffToMicroSeconds (Time.diffClockTimes tm now))
          -- Re-check after sleeping: the synchronization, or a
          -- depended-upon partner, may have committed in the meantime.
          fizzleEvt "clockTimeEvt" s (k (Always ()) s)

{-|
An always commitable event computation that returns the 'ThreadId' of
the synchronizing thread, i.e. the thread that called 'sync', not the
search thread running the event.
-}
myThreadIdEvt :: Evt ThreadId
myThreadIdEvt = Evt (\ k s@(EvtState (Trail (tid, _, _), _, _)) ->
                       fizzleEvt "myThreadIdEvt" s (k (Always tid) s))

{- |
Synchronize on an event.  Blocks the calling thread until the event and
all of its communication partners commit, then returns the event's
result.

If an exception is delivered to the calling thread while it is blocked,
the synchronization is abandoned and the exception is re-raised.

NOTE(review): an exception raised inside the event and not handled by
'catchEvt' only abandons the search path it occurred on.  It is not
propagated to the caller.  If every path ends that way, 'sync' blocks
indefinitely (absent an asynchronous exception).
-}
sync :: Evt a -> IO a
-- Synchronization on an event proceeds as follows:
--  a) allocate a new boolean flag for this synchronization (plus a
--     counter of search threads, watched when 'doCounts' is set),
--  b) allocate a new TMVar to return the result of synchronization to
--     the synchronizing thread,
--  c) spawn a search thread with an initial continuation and initial
--     event state (the empty path, depending only on this thread),
--  d) block reading the TMVar for the synchronization result.
sync evt = do
  tid <- myThreadId
  bft <- atomically newBoolFlag
  cnt <- atomically newCountVar
  res <- atomically newEmptyTMVar
  when doCounts $ forkIO_ $ watchCount tid cnt
  let k = initEvtCont res
      t = Trail (tid, bft, [])
      dm = Map.singleton tid t
      s = EvtState (t, dm, cnt)
  forkIO_ $ forceEvt evt k s
  Exception.catch (atomically (readTMVar res))
                  (\ e -> do
                     -- The exception is necessarily asynchronous.
                     -- Set the boolean flag to True to fizzle all
                     -- search threads.  Note: if the flag is still
                     -- False, then the semantics is as though the
                     -- exception arrived before the synchronization.
                     -- If the flag is already True, then the
                     -- semantics is as though the exception arrived
                     -- after the synchronization.  In either case,
                     -- the atomicity of the synchronization is
                     -- preserved.
                     atomically (writeTVar bft True)
                     Exception.throwIO e)

-- Initial event synchronization continuation;
-- 'sync' installs this continuation; it receives the result of each
-- search thread that runs to completion.
--  * A 'Throw' result (an exception not handled inside the event)
--    abandons ("fizzles") the path.
--  * An 'Always' result completes a path.  The path gets a fresh flag
--    and is registered as a possible commit at every communication point
--    along it.  A search for a commitable set containing it then runs,
--    and the first set found, if any, is committed.  Finally the thread
--    waits until the synchronization completes.  It delivers the result
--    through the TMVar only if this path is the one that was committed.
initEvtCont :: TMVar a -> EvtCont a
initEvtCont res r s@(EvtState (Trail (tid, bft, p), _, cnt)) = do
  checkEvtState s
  case r of
    Throw _ ->
      -- An uncaught exception fizzles the path.
      decCount cnt
    Always v -> do
      -- Allocate a new boolean flag to govern this path.
      bfp <- atomically newBoolFlag
      -- Add this (complete) path as a potential commit to this side's
      -- completed-search list at every communication point on the path.
      let register (Comm _ _ pcs) = do
            others <- readTVar pcs
            writeTVar pcs ((p, bfp) : others)
          register _ = return ()
      atomically $ mapM_ register p
      -- Search for a consistent synchronization set that includes this
      -- path.  If one exists, commit the first one found.
      atomically $ do
        ms <- commitSearch (Track (tid, bft, p, bfp)) Map.empty
        checkCommitSearch ms
        case ms of
          m : _ -> mapM_ commit (Map.elems m)
          [] -> return ()
      decCount cnt
      -- Block until either
      --  a) some other path synchronized the thread, in which case
      --     this path is fizzled, or
      --  b) this path is chosen to synchronize the thread, in which
      --     case we return the synchronization result via the TMVar.
      atomically $ do
        threadDone <- readTVar bft
        pathChosen <- readTVar bfp
        case (threadDone, pathChosen) of
          (False, False) -> retry
          (True, False) -> return ()
          (True, True) -> putTMVar res v
          -- A path flag is only ever set in the same transaction as its
          -- thread's flag, so this combination means an invariant broke.
          (False, True) ->
            throwAssertionFailed "AssertionFailed:: initEvtCont"
  where
    -- Commit one member of the chosen set: mark its thread as
    -- synchronized and its path as the one chosen.
    commit (Track (_, bftC, _, bfpC)) = do
      writeTVar bftC True
      writeTVar bfpC True

-- Search for a commitable set.
-- Takes a complete path (as a track) and a map m of tracks already
-- chosen (at most one per thread).  Returns every map that extends m
-- with this track such that, for each communication on each chosen path,
-- one of the partner's completed paths registered for that communication
-- is also chosen.
--  * If the thread is already in m, returns [m] when the existing
--    track has the same thread and path flags, and [] otherwise.
--  * If the thread has already synchronized (its flag is set), returns [].
-- NOTE(review): every extension is computed, although 'initEvtCont'
-- only uses the first one; the search may be costly when many
-- completed paths are registered.
commitSearch :: Track -> DepTrackMap -> STM [DepTrackMap]
commitSearch (Track (tid, bft, p, bfp)) m = 
    case Map.lookup tid m of
      Just (Track (_, bft', _, bfp')) ->
          if bft == bft' && bfp == bfp' 
             then return [m]
             else return []
      Nothing -> do
        b <- readTVar bft
        if b then return []
             else commitSearchAux p (Map.insert tid (Track (tid, bft, p, bfp)) m)
    -- Walk the path (most recent element first).  Choices and catches
    -- impose no constraint.  At a communication, first extend the map
    -- with the rest of the path.  Then, for each resulting map, try every
    -- completed path (p', bfp') registered in the partner's
    -- completed-search list pcs'.
    where commitSearchAux [] m = return [m]
          commitSearchAux ((Choose _):p) m = commitSearchAux p m
          commitSearchAux ((Catch):p) m = commitSearchAux p m
          commitSearchAux ((Comm _ ((Trail (tid', bft', _)), pcs') _):p) m = do
            ms <- commitSearchAux p m
            pcs'' <- readTVar pcs'
            -- NOTE(review): the commented-out version below predates
            -- 'Track'; it calls commitSearch with five arguments and
            -- would not type-check against the current signature.
{--
            mss <- sequence [ commitSearch tid' bft' p' bfp' m' 
                                  | m' <- ms, (p', bfp') <- pcs'' ]
            return (concat mss)
--}
            foldM (\ acc m' -> do 
                     foldM (\ acc (p', bfp') -> do
                              ms <- commitSearch (Track (tid', bft', p', bfp')) m'
                              return (ms ++ acc))
                           acc
                           pcs'')
                  []
                  ms

----------------------------------------------------------------------
----------------------------------------------------------------------
----------------------------------------------------------------------
----------------------------------------------------------------------

-- Compute the dependency trail map from a trail.
-- Used by the debugging checks ('checkDepTrailMap',
-- 'checkCommitSearchAuxAux') to recompute from scratch what the event
-- combinators maintain incrementally.
depTrailMap :: Trail -> DepTrailMap
depTrailMap t = depTrailMapAux t Map.empty
-- Add trail t to map m, keeping the most extended trail per thread, and
-- then follow t's communications.
--  * If m has no entry for the thread, insert t and follow its path.
--  * If m already holds a trail that extends t, leave m unchanged; the
--    existing trail's path contains t's path as a suffix.
--  * If t extends the existing trail, replace the entry and follow t's path.
--  * Otherwise (different flag, or diverging paths) raise an assertion
--    failure.
depTrailMapAux :: Trail -> DepTrailMap -> DepTrailMap
depTrailMapAux t@(Trail (tid, bft, p)) m =
    case Map.lookup tid m of
      Nothing -> depTrailMapAuxAux t (Map.insert tid t m)
      Just (Trail (_, bft', p')) ->
          if bft' /= bft
             then throwAssertionFailed "AssertionFailed:: depTrailMapAux"
          else if List.isSuffixOf p p'
             then m
          else if List.isSuffixOf p' p
             then depTrailMapAuxAux t (Map.insert tid t m)
          else throwAssertionFailed "AssertionFailed:: depTrailMapAux"
-- Walk a trail's path (most recent element first).  For each
-- communication, add the partner's trail as it stood right after that
-- communication: the recorded partner trail extended with the matching
-- communication element of the opposite direction.  That element points
-- back at this thread's trail prefix before the communication.
depTrailMapAuxAux :: Trail -> DepTrailMap -> DepTrailMap
depTrailMapAuxAux (Trail (tid, bft, p)) m =
    case p of
      [] -> m
      (Choose _):p -> depTrailMapAuxAux (Trail (tid, bft, p)) m
      (Catch):p -> depTrailMapAuxAux (Trail (tid, bft, p)) m
      (Comm ce (Trail (tid', bft', p'),pcs') pcs):p -> 
          -- The partner saw this communication in the opposite direction.
          let ce' = case ce of CommSend -> CommRecv ; CommRecv -> CommSend in
          -- Handle the older part of our own path first.
          let m' = depTrailMapAuxAux (Trail (tid, bft, p)) m in
          -- Reconstruct the partner's trail just after the communication.
          let t' = Trail (tid', bft', 
                          ((Comm ce' ((Trail (tid, bft, p)),pcs) pcs'):p')) in
          depTrailMapAux t' m'

-- Debug check: the dependency map maintained incrementally in an event
-- state must equal the one recomputed from its trail by 'depTrailMap'.
-- Otherwise raises 'Exception.AssertionFailed', showing both maps.
checkDepTrailMap :: Monad m => EvtState -> m ()
checkDepTrailMap (EvtState (t, dm, _))
  | dm == expected = return ()
  | otherwise = Exception.throw (Exception.AssertionFailed msg)
  where
    expected = depTrailMap t
    msg = "AssertionFailed:: checkDepTrailMap: "
          ++ show dm ++ " ** " ++ show expected

-- Check the event state for consistency (only when 'doChecks' is set).
checkEvtState :: Monad m => EvtState -> m ()
checkEvtState s = when doChecks (checkDepTrailMap s)


-- Debug check that one set returned by 'commitSearch' is commitable.
-- For every track in the set, consider each trail in the dependency map
-- of its path.  That trail's thread must be present in the set, with the
-- same synchronization flag and a path that extends the trail's path.
-- ('depTrailMap' itself raises an assertion failure if a dependency map
-- is incoherent.)
checkCommitSearchAuxAux :: DepTrackMap -> Bool
checkCommitSearchAuxAux m = Map.fold trackOk True m
  where
    trackOk (Track (tid, bft, p, _)) acc =
      Map.fold trailOk acc (depTrailMap (Trail (tid, bft, p)))
    trailOk (Trail (tid', bft', p')) acc =
      acc && maybe False
                   (\ (Track (_, bft'', p'', _)) ->
                        bft'' == bft' && List.isSuffixOf p' p'')
                   (Map.lookup tid' m)

-- Raise an assertion failure, showing the set, if it is not commitable.
checkCommitSearchAux :: Monad m => DepTrackMap -> m ()
checkCommitSearchAux m =
  unless (checkCommitSearchAuxAux m) $
    Exception.throw (Exception.AssertionFailed
                       ("AssertionFailed:: checkCommitSearchAux:" ++ show m))

-- Check every set returned by 'commitSearch' (only when 'doChecks' is set).
checkCommitSearch :: Monad m => [DepTrackMap] -> m ()
checkCommitSearch ms = when doChecks (mapM_ checkCommitSearchAux ms)
