{-# LANGUAGE RankNTypes #-}
-- |
-- Module      : System.ZMQ3.Monadic
-- Copyright   : (c) 2013 Toralf Wittner
-- License     : MIT
-- Maintainer  : Toralf Wittner <tw@dtex.org>
-- Stability   : experimental
-- Portability : non-portable
--
-- This module exposes a monadic interface to 'System.ZMQ3'. Actions run
-- inside a 'ZMQ' monad and 'Socket's are guaranteed not to leak outside
-- their corresponding 'runZMQ' scope. Running 'ZMQ' computations
-- asynchronously is directly supported through 'async'.
module System.ZMQ3.Monadic
  ( -- * Type Definitions
    ZMQ
  , Socket
  , Z.Flag (SendMore)
  , Z.Switch (..)
  , Z.Timeout
  , Z.Event (..)
  , Z.EventType (..)
  , Z.EventMsg (..)
  , Z.Poll (..)

  -- ** Type Classes
  , Z.SocketType
  , Z.Sender
  , Z.Receiver
  , Z.Subscriber

  -- ** Socket Types
  , Z.Pair(..)
  , Z.Pub(..)
  , Z.Sub(..)
  , Z.XPub(..)
  , Z.XSub(..)
  , Z.Req(..)
  , Z.Rep(..)
  , Z.Dealer(..)
  , Z.Router(..)
  , Z.Pull(..)
  , Z.Push(..)

  -- * General Operations
  , version
  , runZMQ
  , async
  , socket

  -- * ZMQ Options (Read)
  , ioThreads
  , maxSockets

  -- * ZMQ Options (Write)
  , setIoThreads
  , setMaxSockets

  -- * Socket operations
  , close
  , bind
  , unbind
  , connect
  , send
  , send'
  , sendMulti
  , receive
  , receiveMulti
  , subscribe
  , unsubscribe
  , proxy
  , monitor
  , Z.poll

  -- * Socket Options (Read)
  , affinity
  , backlog
  , delayAttachOnConnect
  , events
  , fileDescriptor
  , identity
  , ipv4Only
  , lastEndpoint
  , linger
  , maxMessageSize
  , mcastHops
  , moreToReceive
  , rate
  , receiveBuffer
  , receiveHighWM
  , receiveTimeout
  , reconnectInterval
  , reconnectIntervalMax
  , recoveryInterval
  , sendBuffer
  , sendHighWM
  , sendTimeout
  , tcpKeepAlive
  , tcpKeepAliveCount
  , tcpKeepAliveIdle
  , tcpKeepAliveInterval

  -- * Socket Options (Write)
  , setAffinity
  , setBacklog
  , setDelayAttachOnConnect
  , setIdentity
  , setIpv4Only
  , setLinger
  , setMaxMessageSize
  , setMcastHops
  , setRate
  , setReceiveBuffer
  , setReceiveHighWM
  , setReceiveTimeout
  , setReconnectInterval
  , setReconnectIntervalMax
  , setRecoveryInterval
  , setRouterMandatory
  , setSendBuffer
  , setSendHighWM
  , setSendTimeout
  , setTcpAcceptFilter
  , setTcpKeepAlive
  , setTcpKeepAliveCount
  , setTcpKeepAliveIdle
  , setTcpKeepAliveInterval
  , setXPubVerbose

  -- * Error Handling
  , Z.ZMQError
  , Z.errno
  , Z.source
  , Z.message

  -- * Re-exports
  , Control.Monad.IO.Class.liftIO
  , Data.Restricted.restrict
  , Data.Restricted.toRestricted

  -- * Low-level Functions
  , waitRead
  , waitWrite
  )
where

import Control.Applicative
import Control.Concurrent.Async (Async)
import Control.Monad
import Control.Monad.Trans.Reader
import Control.Monad.IO.Class
import Control.Monad.CatchIO
import Data.Int
import Data.IORef
import Data.List.NonEmpty (NonEmpty)
import Data.Restricted
import Data.Word
import Data.ByteString (ByteString)
import System.Posix.Types (Fd)

import qualified Control.Concurrent.Async as A
import qualified Control.Exception        as E
import qualified Control.Monad.CatchIO    as M
import qualified Data.ByteString.Lazy     as Lazy
import qualified System.ZMQ3              as Z
import qualified System.ZMQ3.Internal     as I

-- | Shared state behind one 'runZMQ' invocation (and any 'async'
-- computations spawned from it).
data ZMQEnv = ZMQEnv
  { _refcount :: !(IORef Word)
    -- ^ Number of computations still using this environment: starts at 1
    -- in 'runZMQ', is incremented by 'async' and decremented by 'destroy'.
  , _context  :: !Z.Context
    -- ^ The 0MQ context all sockets of this environment are created in.
  , _sockets  :: !(IORef [I.SocketRepr])
    -- ^ Every socket created through 'socket'; closed by the final 'destroy'.
  }

-- | The 'ZMQ' monad is modeled after 'Control.Monad.ST' and encapsulates
-- a 'System.ZMQ3.Context'. The uninstantiated type variable @z@ tells
-- different invocations of 'runZMQ' apart and prevents unintended use of
-- 'Socket's outside their scope. Cf. the paper by John Launchbury and
-- Simon Peyton Jones, /Lazy Functional State Threads/.
newtype ZMQ z a = ZMQ
  { _unzmq :: ReaderT ZMQEnv IO a
  }

-- | A 0MQ socket whose type is @t@ (see 'Z.SocketType') and which belongs
-- to the 'ZMQ' computation tagged by @z@. Since @z@ is universally
-- quantified in 'runZMQ', such a socket cannot escape the 'runZMQ' call
-- that created it.
newtype Socket z t = Socket
  { _unsocket :: Z.Socket t
  }

-- The instances below are written in canonical form: 'pure' is the primary
-- definition and 'return' is defined as 'pure'. The previous
-- @return = ZMQ . return@ / @pure = return@ pairing is flagged by GHC's
-- @-Wnoncanonical-monad-instances@. Every method just unwraps the 'ZMQ'
-- newtype and delegates to the underlying @ReaderT ZMQEnv IO@ instance.

instance Functor (ZMQ z) where
    fmap f (ZMQ m) = ZMQ (fmap f m)

instance Applicative (ZMQ z) where
    pure = ZMQ . pure
    ZMQ f <*> ZMQ x = ZMQ (f <*> x)

instance Monad (ZMQ z) where
    -- Kept explicit so that compilers whose 'Monad' class predates the
    -- Applicative-Monad proposal still get a definition.
    return = pure
    ZMQ m >>= f = ZMQ (m >>= _unzmq . f)

instance MonadIO (ZMQ z) where
    liftIO = ZMQ . liftIO

instance MonadCatchIO (ZMQ z) where
    catch (ZMQ m) f = ZMQ (m `M.catch` (_unzmq . f))
    block (ZMQ m)   = ZMQ (block m)
    unblock (ZMQ m) = ZMQ (unblock m)

-- | Run a 'ZMQ' computation and return its result. Rank-2 polymorphism
-- keeps the type variable @z@ (and therefore every 'Socket' tagged with
-- it) from escaping.
--
-- Each call creates a fresh 'System.ZMQ3.Context' relative to which all
-- actions are executed. When the computation finishes, normally or by an
-- exception, the context is released, but see 'async'.
runZMQ :: MonadIO m => (forall z. ZMQ z a) -> m a
runZMQ z = liftIO (E.bracket newEnv destroy run)
  where
    run = runReaderT (_unzmq z)
    -- The environment starts with one reference: the caller of 'runZMQ'.
    newEnv = do
        refs  <- newIORef 1
        ctx   <- Z.context
        socks <- newIORef []
        return (ZMQEnv refs ctx socks)

-- | Run the given 'ZMQ' computation asynchronously, i.e. this function
-- runs the computation in a new thread using 'Control.Concurrent.Async.async'.
-- /N.B./ reference counting is used to prolong the lifetime of the
-- 'System.ZMQ3.Context' encapsulated in 'ZMQ' as necessary, e.g.:
--
-- @
-- runZMQ $ do
--     s <- socket Pair
--     async $ do
--         liftIO (threadDelay 10000000)
--         identity s >>= liftIO . print
-- @
--
-- Here, 'runZMQ' will finish before the code section in 'async', but due to
-- reference counting, the 'System.ZMQ3.Context' will only be disposed after
-- 'async' finishes as well. The reference taken here is released even if
-- the spawned thread is cancelled immediately after being created.
async :: ZMQ z a -> ZMQ z (Async a)
async z = ZMQ $ do
    e <- ask
    -- Asynchronous exceptions stay masked from the reference-count increment
    -- until the child thread is inside its 'E.finally' handler (the same
    -- pattern as 'Control.Concurrent.forkFinally'). Otherwise an exception
    -- delivered in between (e.g. an early 'A.cancel') would skip 'destroy',
    -- the count would never reach zero, and the context would leak.
    liftIO $ E.mask $ \restore -> do
        atomicModifyIORef (_refcount e) $ \n -> (succ n, ())
        -- The child runs the user computation in the caller's original
        -- masking state; if forking itself fails, give the reference back.
        A.async (restore (runReaderT (_unzmq z) e) `E.finally` destroy e)
            `E.onException` destroy e

-- | Read the @ioThreads@ option of the current context
-- (lifted 'System.ZMQ3.ioThreads').
ioThreads :: ZMQ z Word
ioThreads = onContext Z.ioThreads

-- | Write the @ioThreads@ option of the current context
-- (lifted 'System.ZMQ3.setIoThreads').
setIoThreads :: Word -> ZMQ z ()
setIoThreads n = onContext (Z.setIoThreads n)

-- | Read the @maxSockets@ option of the current context
-- (lifted 'System.ZMQ3.maxSockets').
maxSockets :: ZMQ z Word
maxSockets = onContext Z.maxSockets

-- | Write the @maxSockets@ option of the current context
-- (lifted 'System.ZMQ3.setMaxSockets').
setMaxSockets :: Word -> ZMQ z ()
setMaxSockets n = onContext (Z.setMaxSockets n)

-- | Create a socket of the given type in the current context and register
-- it with the environment, so that the final 'destroy' closes it.
--
-- NOTE(review): sockets closed early via 'close' are not removed from the
-- registry; they stay listed until the final cleanup.
socket :: Z.SocketType t => t -> ZMQ z (Socket z t)
socket t = ZMQ $ do
    env  <- ask
    repr <- liftIO $ I.mkSocketRepr t (_context env)
    liftIO $ atomicModifyIORef (_sockets env) (\xs -> (repr : xs, ()))
    return $ Socket (I.Socket repr)

-- | Version of the underlying 0MQ library (lifted 'System.ZMQ3.version').
version :: ZMQ z (Int, Int, Int)
version = liftIO Z.version

-- * Socket operations

-- Each operation below unwraps the scoped 'Socket' and lifts the
-- corresponding 'System.ZMQ3' function into 'ZMQ'.

-- | Lifted 'System.ZMQ3.close'.
close :: Socket z t -> ZMQ z ()
close (Socket s) = liftIO (Z.close s)

-- | Lifted 'System.ZMQ3.bind'.
bind :: Socket z t -> String -> ZMQ z ()
bind (Socket s) endpoint = liftIO (Z.bind s endpoint)

-- | Lifted 'System.ZMQ3.unbind'.
unbind :: Socket z t -> String -> ZMQ z ()
unbind (Socket s) endpoint = liftIO (Z.unbind s endpoint)

-- | Lifted 'System.ZMQ3.connect'.
connect :: Socket z t -> String -> ZMQ z ()
connect (Socket s) endpoint = liftIO (Z.connect s endpoint)

-- | Lifted 'System.ZMQ3.send' (strict 'ByteString' payload).
send :: Z.Sender t => Socket z t -> [Z.Flag] -> ByteString -> ZMQ z ()
send (Socket s) flags msg = liftIO (Z.send s flags msg)

-- | Lifted 'System.ZMQ3.send'' (lazy 'Lazy.ByteString' payload).
send' :: Z.Sender t => Socket z t -> [Z.Flag] -> Lazy.ByteString -> ZMQ z ()
send' (Socket s) flags msg = liftIO (Z.send' s flags msg)

-- | Lifted 'System.ZMQ3.sendMulti'.
sendMulti :: Z.Sender t => Socket z t -> NonEmpty ByteString -> ZMQ z ()
sendMulti (Socket s) msgs = liftIO (Z.sendMulti s msgs)

-- | Lifted 'System.ZMQ3.receive'.
receive :: Z.Receiver t => Socket z t -> ZMQ z ByteString
receive (Socket s) = liftIO (Z.receive s)

-- | Lifted 'System.ZMQ3.receiveMulti'.
receiveMulti :: Z.Receiver t => Socket z t -> ZMQ z [ByteString]
receiveMulti (Socket s) = liftIO (Z.receiveMulti s)

-- | Lifted 'System.ZMQ3.subscribe'.
subscribe :: Z.Subscriber t => Socket z t -> ByteString -> ZMQ z ()
subscribe (Socket s) topic = liftIO (Z.subscribe s topic)

-- | Lifted 'System.ZMQ3.unsubscribe'.
unsubscribe :: Z.Subscriber t => Socket z t -> ByteString -> ZMQ z ()
unsubscribe (Socket s) topic = liftIO (Z.unsubscribe s topic)

-- | Lifted 'System.ZMQ3.proxy'; the optional third socket is unwrapped
-- when present.
proxy :: Socket z a -> Socket z b -> Maybe (Socket z c) -> ZMQ z ()
proxy (Socket a) (Socket b) c = liftIO (Z.proxy a b (fmap _unsocket c))

-- | Lifted 'System.ZMQ3.monitor', run against the current context.
monitor :: [Z.EventType] -> Socket z t -> ZMQ z (Bool -> IO (Maybe Z.EventMsg))
monitor types (Socket s) = onContext (\ctx -> Z.monitor types ctx s)

-- * Socket Options (Read)

-- Each getter below unwraps the scoped 'Socket' and lifts the
-- corresponding 'System.ZMQ3' option reader into 'ZMQ'.

-- | Lifted 'System.ZMQ3.affinity'.
affinity :: Socket z t -> ZMQ z Word64
affinity (Socket s) = liftIO (Z.affinity s)

-- | Lifted 'System.ZMQ3.backlog'.
backlog :: Socket z t -> ZMQ z Int
backlog (Socket s) = liftIO (Z.backlog s)

-- | Lifted 'System.ZMQ3.delayAttachOnConnect'.
delayAttachOnConnect :: Socket z t -> ZMQ z Bool
delayAttachOnConnect (Socket s) = liftIO (Z.delayAttachOnConnect s)

-- | Lifted 'System.ZMQ3.events'.
events :: Socket z t -> ZMQ z [Z.Event]
events (Socket s) = liftIO (Z.events s)

-- | Lifted 'System.ZMQ3.fileDescriptor'.
fileDescriptor :: Socket z t -> ZMQ z Fd
fileDescriptor (Socket s) = liftIO (Z.fileDescriptor s)

-- | Lifted 'System.ZMQ3.identity'.
identity :: Socket z t -> ZMQ z ByteString
identity (Socket s) = liftIO (Z.identity s)

-- | Lifted 'System.ZMQ3.ipv4Only'.
ipv4Only :: Socket z t -> ZMQ z Bool
ipv4Only (Socket s) = liftIO (Z.ipv4Only s)

-- | Lifted 'System.ZMQ3.lastEndpoint'.
lastEndpoint :: Socket z t -> ZMQ z String
lastEndpoint (Socket s) = liftIO (Z.lastEndpoint s)

-- | Lifted 'System.ZMQ3.linger'.
linger :: Socket z t -> ZMQ z Int
linger (Socket s) = liftIO (Z.linger s)

-- | Lifted 'System.ZMQ3.maxMessageSize'.
maxMessageSize :: Socket z t -> ZMQ z Int64
maxMessageSize (Socket s) = liftIO (Z.maxMessageSize s)

-- | Lifted 'System.ZMQ3.mcastHops'.
mcastHops :: Socket z t -> ZMQ z Int
mcastHops (Socket s) = liftIO (Z.mcastHops s)

-- | Lifted 'System.ZMQ3.moreToReceive'.
moreToReceive :: Socket z t -> ZMQ z Bool
moreToReceive (Socket s) = liftIO (Z.moreToReceive s)

-- | Lifted 'System.ZMQ3.rate'.
rate :: Socket z t -> ZMQ z Int
rate (Socket s) = liftIO (Z.rate s)

-- | Lifted 'System.ZMQ3.receiveBuffer'.
receiveBuffer :: Socket z t -> ZMQ z Int
receiveBuffer (Socket s) = liftIO (Z.receiveBuffer s)

-- | Lifted 'System.ZMQ3.receiveHighWM'.
receiveHighWM :: Socket z t -> ZMQ z Int
receiveHighWM (Socket s) = liftIO (Z.receiveHighWM s)

-- | Lifted 'System.ZMQ3.receiveTimeout'.
receiveTimeout :: Socket z t -> ZMQ z Int
receiveTimeout (Socket s) = liftIO (Z.receiveTimeout s)

-- | Lifted 'System.ZMQ3.reconnectInterval'.
reconnectInterval :: Socket z t -> ZMQ z Int
reconnectInterval (Socket s) = liftIO (Z.reconnectInterval s)

-- | Lifted 'System.ZMQ3.reconnectIntervalMax'.
reconnectIntervalMax :: Socket z t -> ZMQ z Int
reconnectIntervalMax (Socket s) = liftIO (Z.reconnectIntervalMax s)

-- | Lifted 'System.ZMQ3.recoveryInterval'.
recoveryInterval :: Socket z t -> ZMQ z Int
recoveryInterval (Socket s) = liftIO (Z.recoveryInterval s)

-- | Lifted 'System.ZMQ3.sendBuffer'.
sendBuffer :: Socket z t -> ZMQ z Int
sendBuffer (Socket s) = liftIO (Z.sendBuffer s)

-- | Lifted 'System.ZMQ3.sendHighWM'.
sendHighWM :: Socket z t -> ZMQ z Int
sendHighWM (Socket s) = liftIO (Z.sendHighWM s)

-- | Lifted 'System.ZMQ3.sendTimeout'.
sendTimeout :: Socket z t -> ZMQ z Int
sendTimeout (Socket s) = liftIO (Z.sendTimeout s)

-- | Lifted 'System.ZMQ3.tcpKeepAlive'.
tcpKeepAlive :: Socket z t -> ZMQ z Z.Switch
tcpKeepAlive (Socket s) = liftIO (Z.tcpKeepAlive s)

-- | Lifted 'System.ZMQ3.tcpKeepAliveCount'.
tcpKeepAliveCount :: Socket z t -> ZMQ z Int
tcpKeepAliveCount (Socket s) = liftIO (Z.tcpKeepAliveCount s)

-- | Lifted 'System.ZMQ3.tcpKeepAliveIdle'.
tcpKeepAliveIdle :: Socket z t -> ZMQ z Int
tcpKeepAliveIdle (Socket s) = liftIO (Z.tcpKeepAliveIdle s)

-- | Lifted 'System.ZMQ3.tcpKeepAliveInterval'.
tcpKeepAliveInterval :: Socket z t -> ZMQ z Int
tcpKeepAliveInterval (Socket s) = liftIO (Z.tcpKeepAliveInterval s)

-- * Socket Options (Write)

-- Each setter below takes the option value first and the scoped 'Socket'
-- second, then lifts the corresponding 'System.ZMQ3' option writer into
-- 'ZMQ'. Value ranges are enforced by the 'Restricted' types.

-- | Lifted 'System.ZMQ3.setAffinity'.
setAffinity :: Word64 -> Socket z t -> ZMQ z ()
setAffinity val (Socket s) = liftIO (Z.setAffinity val s)

-- | Lifted 'System.ZMQ3.setBacklog'.
setBacklog :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setBacklog val (Socket s) = liftIO (Z.setBacklog val s)

-- | Lifted 'System.ZMQ3.setDelayAttachOnConnect'.
setDelayAttachOnConnect :: Bool -> Socket z t -> ZMQ z ()
setDelayAttachOnConnect val (Socket s) = liftIO (Z.setDelayAttachOnConnect val s)

-- | Lifted 'System.ZMQ3.setIdentity'.
setIdentity :: Restricted N1 N254 ByteString -> Socket z t -> ZMQ z ()
setIdentity val (Socket s) = liftIO (Z.setIdentity val s)

-- | Lifted 'System.ZMQ3.setIpv4Only'.
setIpv4Only :: Bool -> Socket z t -> ZMQ z ()
setIpv4Only val (Socket s) = liftIO (Z.setIpv4Only val s)

-- | Lifted 'System.ZMQ3.setLinger'.
setLinger :: Integral i => Restricted Nneg1 Int32 i -> Socket z t -> ZMQ z ()
setLinger val (Socket s) = liftIO (Z.setLinger val s)

-- | Lifted 'System.ZMQ3.setMaxMessageSize'.
setMaxMessageSize :: Integral i => Restricted Nneg1 Int64 i -> Socket z t -> ZMQ z ()
setMaxMessageSize val (Socket s) = liftIO (Z.setMaxMessageSize val s)

-- | Lifted 'System.ZMQ3.setMcastHops'.
setMcastHops :: Integral i => Restricted N1 Int32 i -> Socket z t -> ZMQ z ()
setMcastHops val (Socket s) = liftIO (Z.setMcastHops val s)

-- | Lifted 'System.ZMQ3.setRate'.
setRate :: Integral i => Restricted N1 Int32 i -> Socket z t -> ZMQ z ()
setRate val (Socket s) = liftIO (Z.setRate val s)

-- | Lifted 'System.ZMQ3.setReceiveBuffer'.
setReceiveBuffer :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setReceiveBuffer val (Socket s) = liftIO (Z.setReceiveBuffer val s)

-- | Lifted 'System.ZMQ3.setReceiveHighWM'.
setReceiveHighWM :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setReceiveHighWM val (Socket s) = liftIO (Z.setReceiveHighWM val s)

-- | Lifted 'System.ZMQ3.setReceiveTimeout'.
setReceiveTimeout :: Integral i => Restricted Nneg1 Int32 i -> Socket z t -> ZMQ z ()
setReceiveTimeout val (Socket s) = liftIO (Z.setReceiveTimeout val s)

-- | Lifted 'System.ZMQ3.setReconnectInterval'.
setReconnectInterval :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setReconnectInterval val (Socket s) = liftIO (Z.setReconnectInterval val s)

-- | Lifted 'System.ZMQ3.setReconnectIntervalMax'.
setReconnectIntervalMax :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setReconnectIntervalMax val (Socket s) = liftIO (Z.setReconnectIntervalMax val s)

-- | Lifted 'System.ZMQ3.setRecoveryInterval'.
setRecoveryInterval :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setRecoveryInterval val (Socket s) = liftIO (Z.setRecoveryInterval val s)

-- | Lifted 'System.ZMQ3.setRouterMandatory' (router sockets only).
setRouterMandatory :: Bool -> Socket z Z.Router -> ZMQ z ()
setRouterMandatory val (Socket s) = liftIO (Z.setRouterMandatory val s)

-- | Lifted 'System.ZMQ3.setSendBuffer'.
setSendBuffer :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setSendBuffer val (Socket s) = liftIO (Z.setSendBuffer val s)

-- | Lifted 'System.ZMQ3.setSendHighWM'.
setSendHighWM :: Integral i => Restricted N0 Int32 i -> Socket z t -> ZMQ z ()
setSendHighWM val (Socket s) = liftIO (Z.setSendHighWM val s)

-- | Lifted 'System.ZMQ3.setSendTimeout'.
setSendTimeout :: Integral i => Restricted Nneg1 Int32 i -> Socket z t -> ZMQ z ()
setSendTimeout val (Socket s) = liftIO (Z.setSendTimeout val s)

-- | Lifted 'System.ZMQ3.setTcpAcceptFilter'.
setTcpAcceptFilter :: Maybe ByteString -> Socket z t -> ZMQ z ()
setTcpAcceptFilter val (Socket s) = liftIO (Z.setTcpAcceptFilter val s)

-- | Lifted 'System.ZMQ3.setTcpKeepAlive'.
setTcpKeepAlive :: Z.Switch -> Socket z t -> ZMQ z ()
setTcpKeepAlive val (Socket s) = liftIO (Z.setTcpKeepAlive val s)

-- | Lifted 'System.ZMQ3.setTcpKeepAliveCount'.
setTcpKeepAliveCount :: Integral i => Restricted Nneg1 Int32 i -> Socket z t -> ZMQ z ()
setTcpKeepAliveCount val (Socket s) = liftIO (Z.setTcpKeepAliveCount val s)

-- | Lifted 'System.ZMQ3.setTcpKeepAliveIdle'.
setTcpKeepAliveIdle :: Integral i => Restricted Nneg1 Int32 i -> Socket z t -> ZMQ z ()
setTcpKeepAliveIdle val (Socket s) = liftIO (Z.setTcpKeepAliveIdle val s)

-- | Lifted 'System.ZMQ3.setTcpKeepAliveInterval'.
setTcpKeepAliveInterval :: Integral i => Restricted Nneg1 Int32 i -> Socket z t -> ZMQ z ()
setTcpKeepAliveInterval val (Socket s) = liftIO (Z.setTcpKeepAliveInterval val s)

-- | Lifted 'System.ZMQ3.setXPubVerbose' (XPub sockets only).
setXPubVerbose :: Bool -> Socket z Z.XPub -> ZMQ z ()
setXPubVerbose val (Socket s) = liftIO (Z.setXPubVerbose val s)

-- * Low Level Functions

-- | Lifted 'System.ZMQ3.waitRead'.
waitRead :: Socket z t -> ZMQ z ()
waitRead (Socket s) = liftIO (Z.waitRead s)

-- | Lifted 'System.ZMQ3.waitWrite'.
waitWrite :: Socket z t -> ZMQ z ()
waitWrite (Socket s) = liftIO (Z.waitWrite s)

-- * Internal

-- | Run an 'IO' action against the 'Z.Context' held by the current
-- 'ZMQ' environment.
onContext :: (Z.Context -> IO a) -> ZMQ z a
onContext f = ZMQ $ do
    ctx <- asks _context
    liftIO (f ctx)

-- | Drop one reference to the environment. Whoever drops the last
-- reference (the counter was 1 before the decrement) closes every socket
-- registered through 'socket' and then destroys the context. An exception
-- raised while closing an individual socket is printed and does not stop
-- the remaining cleanup.
destroy :: ZMQEnv -> IO ()
destroy env = do
    previous <- atomicModifyIORef (_refcount env) (\c -> (pred c, c))
    when (previous == 1) release
  where
    release = do
        readIORef (_sockets env) >>= mapM_ closeQuietly
        Z.destroy (_context env)
    closeQuietly s = I.closeSock s `E.catch` report
    report e = print (e :: E.SomeException)