{-# LANGUAGE RecordWildCards #-}

module General.Timing(Timing, withTiming, timed, timedOverwrite) where

import Data.List.Extra
import System.Time.Extra
import Data.IORef
import Control.Monad.Extra
import System.IO
import General.Util
import Control.Monad.IO.Class

-- | Mutable state shared by every timing call made inside one 'withTiming'.
data Timing = Timing
    { timingOffset :: IO Seconds
      -- ^ Seconds elapsed since this 'Timing' was created.
    , timingStore :: IORef [(String, Seconds)]
      -- ^ Every recorded @(message, duration)@ pair (most recent first),
      -- kept so the timings can be written out to a file.
    , timingOverwrite :: IORef (Maybe (Seconds, Int))
      -- ^ @Just (t, n)@ means: while the current time is below @t@, the last
      -- @n@ characters of terminal output may be overwritten.
      -- Only used iff @timingTerminal == True@.
    , timingTerminal :: Bool
      -- ^ Whether stdout is a terminal.
    }

-- | Run an action with a fresh 'Timing', then print the total elapsed time
-- (@Took ...@) to stdout. When a file path is given, a table of every
-- recorded timing, plus an @Unrecorded@ row for the remainder, is also
-- written to that file, slowest entry first.
withTiming
    :: Maybe FilePath
    -- ^ A file to optionally write all timings to, after the action is finished
    -> (Timing -> IO a)
    -- ^ An action that can write timings into 'Timing'
    -> IO a
withTiming writeTimingsTo act = do
    clock <- offsetTime
    store <- newIORef []
    pending <- newIORef Nothing
    isTerminal <- hIsTerminalDevice stdout

    result <- act Timing
        { timingOffset = clock
        , timingStore = store
        , timingOverwrite = pending
        , timingTerminal = isTerminal
        }
    total <- clock

    whenJust writeTimingsTo $ \file -> do
        recorded <- readIORef store
        -- Typically about 2s ends up unrecorded. Most of it comes from the
        -- pipeline, with an occasional 0.01s between items as one flushes,
        -- and roughly 0.5s at the end while the final item flushes.
        let unrecorded = total - sum (map snd recorded)
            rows = sortOn (negate . snd) (("Unrecorded", unrecorded) : recorded)
        writeFile file (unlines (prettyTable 2 "Secs" rows))

    putStrLn ("Took " ++ showDuration total)
    pure result


-- Skip it if we have written output in the last 1s and it takes < 0.1s.

-- | Time an action and write the given message, followed by the elapsed time,
-- to stdout. The output line is always terminated, so it is never overwritten
-- by a later message.
timed :: MonadIO m => Timing -> String -> m a -> m a
timed timing msg act = timedEx False timing msg act

-- | Time an action and write the given message, followed by the elapsed time,
-- to stdout. The line may overwrite a previous message that was marked as
-- overwritable, and may itself be left overwritable by the next one.
timedOverwrite :: MonadIO m => Timing -> String -> m a -> m a
timedOverwrite timing msg act = timedEx True timing msg act

-- | Shared implementation of 'timed' and 'timedOverwrite'.
--
-- Prints @msg... @, runs the action, and records @(msg, elapsed)@ in
-- 'timingStore'. It then prints the elapsed time, followed by the peak
-- allocation in parentheses when 'getStatsPeakAllocBytes' returns one.
--
-- When @overwrite@ is 'True', stdout is a terminal, and the action ends before
-- the current overwrite deadline (1s after the start of the first message in
-- the current overwritable run), the line is left unterminated. Its length is
-- remembered in 'timingOverwrite' so the next overwritable message can erase it
-- with backspaces. Otherwise the line is terminated with a newline.
timedEx :: MonadIO m => Bool -> Timing -> String -> m a -> m a
timedEx overwrite Timing{..} msg act = do
    start <- liftIO timingOffset
    -- A previous message may have left its line pending. Erase it if we may
    -- overwrite and are still inside its deadline; otherwise terminate it so
    -- our output starts on a fresh line.
    liftIO $ whenJustM (readIORef timingOverwrite) $ \(t, n) ->
        if overwrite && start < t then
            putStr $ replicate n '\b' ++ replicate n ' ' ++ replicate n '\b'
        else
            putStrLn ""

    -- Print a string and return how many characters it used, so it can be
    -- erased later.
    let out str = liftIO $ putStr str >> pure (length str)
    undo1 <- out $ msg ++ "... "
    liftIO $ hFlush stdout

    res <- act
    end <- liftIO timingOffset
    let time = end - start
    -- Strict modify: the lazy variant would build a growing chain of thunks
    -- inside the ref, one per timed call.
    liftIO $ modifyIORef' timingStore ((msg, time) :)

    s <- maybe "" (\x -> " (" ++ x ++ ")") <$> liftIO getStatsPeakAllocBytes
    undo2 <- out $ showDuration time ++ s

    old <- liftIO $ readIORef timingOverwrite
    -- Keep the deadline of the current overwrite run, or open a new 1s window.
    let next = maybe (start + 1.0) fst old
    liftIO $ if timingTerminal && overwrite && end < next then do
        writeIORef timingOverwrite $ Just (next, undo1 + undo2)
        -- No newline is printed here, so flush explicitly. Otherwise a
        -- line-buffered terminal would not show the duration until some
        -- later output happened to flush it.
        hFlush stdout
     else do
        writeIORef timingOverwrite Nothing
        putStrLn ""
    pure res