diff --git a/fact.cabal b/fact.cabal index 78fc1f1..a6a024c 100644 --- a/fact.cabal +++ b/fact.cabal @@ -54,6 +54,7 @@ library , criterion-measurement >=0.2.2.0 , monadlist >=0.0.2 , mtl >=2.3.1 + , transformers >= 0.6.1.1 default-language: Haskell2010 executable fact-exe @@ -70,5 +71,6 @@ executable fact-exe , criterion-measurement >=0.2.2.0 , monadlist >=0.0.2 , mtl >=2.3.1 + , transformers >= 0.6.1.1 , fact default-language: Haskell2010 diff --git a/src/Benchmarks.hs b/src/Benchmarks.hs index af70708..c4d8a26 100644 --- a/src/Benchmarks.hs +++ b/src/Benchmarks.hs @@ -2,6 +2,7 @@ module Benchmarks where import Examples.ChemicalReaction import Examples.Lorenz +import Examples.Sine import Driver import CT import IO @@ -18,7 +19,7 @@ import qualified Criterion.Measurement.Types as Criterion.Measurement.Measured perform :: IO [Double] -> IO (Double, Maybe Int64) perform test = do - (performance, _) <- measure (nfIO test) 10 + (performance, _) <- measure (nfIO test) 1 return (Criterion.Measurement.Measured.measTime performance, Criterion.Measurement.Measured.fromInt $ Criterion.Measurement.Measured.measAllocated performance) benchmarks :: IO () diff --git a/src/CT.hs b/src/CT.hs index 71e3747..316328b 100644 --- a/src/CT.hs +++ b/src/CT.hs @@ -37,18 +37,17 @@ -- Stability : stable -- Tested with: GHC 8.10.7 -- | - +{-# LANGUAGE FlexibleInstances #-} module CT (CT(..), Parameters(..)) where -import Control.Monad -import Control.Monad.Fix +import Control.Monad.Trans.Reader ( ReaderT ) -import Types +import Types ( Iteration ) -import Solver -import Simulation +import Solver ( Solver ) +import Simulation ( Interval ) -- | It defines the simulation time appended with additional information. data Parameters = Parameters { interval :: Interval, -- ^ the simulation interval @@ -57,83 +56,38 @@ data Parameters = Parameters { interval :: Interval, -- ^ the simulation interv iteration :: Iteration -- ^ the current iteration } deriving (Eq, Show) -newtype CT a = CT {apply :: Parameters -> IO a} - -instance Functor CT where - fmap f (CT da) = CT $ \ps -> fmap f (da ps) - -instance Applicative CT where - pure a = CT $ const (pure a) - (CT df) <*> (CT da) = CT $ \ps -> do f <- df ps - fmap f (da ps) - -appComposition :: CT (a -> b) -> CT a -> CT b -appComposition (CT df) (CT da) - = CT $ \ps -> df ps >>= \f -> fmap f (da ps) - -instance Monad CT where - return = pure - (CT m) >>= k = CT $ \ps -> do a <- m ps - k a `apply` ps - -instance MonadFix CT where - -- mfix :: (a -> m a) -> m a - mfix f = - CT $ \ps -> mfix ((`apply` ps) . f) - -returnD :: a -> CT a -returnD a = CT $ const (return a) - -bindD :: (a -> CT b ) -> CT a -> CT b -bindD k (CT m) = - CT $ \ps -> m ps >>= \a -> (\(CT m') -> m' ps) $ k a - -bindD' :: (a -> CT b ) -> CT a -> CT b -bindD' k (CT m) = CT $ \ps -> do - a <- m ps - k a `apply` ps - -instance Eq (CT a) where - x == y = error "<< Can't compare dynamics >>" - -instance Show (CT a) where - showsPrec _ x = showString "<< CT >>" - -unaryOP :: (a -> b) -> CT a -> CT b -unaryOP = fmap -binaryOP :: (a -> b -> c) -> CT a -> CT b -> CT c -binaryOP func da db = fmap func da <*> db +type CT a = ReaderT Parameters IO a instance (Num a) => Num (CT a) where - x + y = binaryOP (+) x y - x - y = binaryOP (-) x y - x * y = binaryOP (*) x y - negate = unaryOP negate - abs = unaryOP abs - signum = unaryOP signum + x + y = (+) <$> x <*> y + x - y = (-) <$> x <*> y + x * y = (*) <$> x <*> y + negate = fmap negate + abs = fmap abs + signum = fmap signum fromInteger i = return $ fromInteger i instance (Fractional a) => Fractional (CT a) where - x / y = binaryOP (/) x y - recip = unaryOP recip + x / y = (/) <$> x <*> y + recip = fmap recip fromRational t = return $ fromRational t instance (Floating a) => Floating (CT a) where pi = return pi - exp = unaryOP exp - log = unaryOP log - sqrt = unaryOP sqrt - x ** y = binaryOP (**) x y - sin = unaryOP sin - cos = unaryOP cos - tan = unaryOP tan - asin = unaryOP asin - acos = unaryOP acos - atan = unaryOP atan - sinh = unaryOP sinh - cosh = unaryOP cosh - tanh = unaryOP tanh - asinh = unaryOP asinh - acosh = unaryOP acosh - atanh = unaryOP atanh + exp = fmap exp + log = fmap log + sqrt = fmap sqrt + x ** y = (**) <$> x <*> y + sin = fmap sin + cos = fmap cos + tan = fmap tan + asin = fmap asin + acos = fmap acos + atan = fmap atan + sinh = fmap sinh + cosh = fmap cosh + tanh = fmap tanh + asinh = fmap asinh + acosh = fmap acosh + atanh = fmap atanh diff --git a/src/Driver.hs b/src/Driver.hs index 32df314..20701cf 100644 --- a/src/Driver.hs +++ b/src/Driver.hs @@ -1,9 +1,12 @@ module Driver where import CT + ( CT, Parameters(Parameters, solver, interval, time, iteration) ) import Solver + ( Solver(stage, dt), Stage(SolverStage, Interpolate), iterToTime ) import Simulation -import Types + ( Interval(Interval), iterationHiBnd, iterationBnds ) +import Control.Monad.Trans.Reader ( ReaderT(runReaderT) ) type Model a = CT (CT a) @@ -12,43 +15,43 @@ epslon = 0.00001 -- | Run the simulation and return the result in the last -- time point using the specified simulation specs. runCTFinal :: Model a -> Double -> Solver -> IO a -runCTFinal (CT m) t sl = - do d <- m Parameters { interval = Interval 0 t, - time = 0, - iteration = 0, - solver = sl { stage = SolverStage 0 }} +runCTFinal m t sl = + do d <- runReaderT m $ Parameters { interval = Interval 0 t, + time = 0, + iteration = 0, + solver = sl { stage = SolverStage 0 }} subRunCTFinal d t sl -- | Auxiliary functions to runCTFinal subRunCTFinal :: CT a -> Double -> Solver -> IO a -subRunCTFinal (CT m) t sl = - do let iv = Interval 0 t - n = iterationHiBnd iv (dt sl) - disct = iterToTime iv sl n (SolverStage 0) - x = m Parameters { interval = iv, - time = disct, - iteration = n, - solver = sl { stage = SolverStage 0 }} - if disct - t < epslon - then x - else m Parameters { interval = iv, - time = t, - iteration = n, - solver = sl { stage = Interpolate }} +subRunCTFinal m t sl = do + let iv = Interval 0 t + n = iterationHiBnd iv (dt sl) + disct = iterToTime iv sl n (SolverStage 0) + x = runReaderT m $ Parameters { interval = iv, + time = disct, + iteration = n, + solver = sl { stage = SolverStage 0 }} + if disct - t < epslon + then x + else runReaderT m $ Parameters { interval = iv, + time = t, + iteration = n, + solver = sl { stage = Interpolate }} -- | Run the simulation and return the results in all -- integration time points using the specified simulation specs. runCT :: Model a -> Double -> Solver -> IO [a] -runCT (CT m) t sl = do - d <- m Parameters { interval = Interval 0 t, - time = 0, - iteration = 0, - solver = sl { stage = SolverStage 0}} +runCT m t sl = do + d <- runReaderT m $ Parameters { interval = Interval 0 t, + time = 0, + iteration = 0, + solver = sl { stage = SolverStage 0}} sequence $ subRunCT d t sl -- | Auxiliary functions to runCT subRunCT :: CT a -> Double -> Solver -> [IO a] -subRunCT (CT m) t sl = do +subRunCT m t sl = do let iv = Interval 0 t (nl, nu) = iterationBnds iv (dt sl) parameterize n = @@ -65,8 +68,8 @@ subRunCT (CT m) t sl = do iteration = nu, solver = sl {stage = Interpolate}} endTime = iterToTime iv sl nu (SolverStage 0) - values = map (m . parameterize) [nl .. nu] + values = map (runReaderT m . parameterize) [nl .. nu] if endTime - t < epslon then values - else init values ++ [m ps] + else init values ++ [runReaderT m ps] diff --git a/src/Examples/Sine.hs b/src/Examples/Sine.hs index e382832..db06812 100644 --- a/src/Examples/Sine.hs +++ b/src/Examples/Sine.hs @@ -1,4 +1,3 @@ -{-# LANGUAGE RecordWildCards #-} module Examples.Sine where import Driver @@ -11,8 +10,8 @@ import Data.List import Simulation sineSolv = Solver { dt = 0.01, - method = RungeKutta4, - stage = SolverStage 0 } + method = RungeKutta4, + stage = SolverStage 0 } sineModel :: Model [Double] sineModel = diff --git a/src/Examples/Test.hs b/src/Examples/Test.hs index 95f2f1c..e4be2e7 100644 --- a/src/Examples/Test.hs +++ b/src/Examples/Test.hs @@ -61,25 +61,25 @@ predicate initialCondition = if initialCondition >= 20 then example2 else example1 -demux :: Predicate Double Result -> HybridModel Result -demux predicate (initialCondition, _) p = do - let m = predicate initialCondition - model <- m (pure initialCondition) `apply` p - head <$> model `apply` p +-- demux :: Predicate Double Result -> HybridModel Result +-- demux predicate (initialCondition, _) p = do +-- let m = predicate initialCondition +-- model <- m (pure initialCondition) `apply` p +-- head <$> model `apply` p -hybrid :: (MonadPlus p, Monad m) => (a -> Parameters -> m a) -> a -> Double -> Solver -> m (p a) -hybrid f z t sl = - do let iv = Interval 0 t - (nl, nu) = iterationBnds iv (dt sl) - parameterise n = Parameters { interval = iv, - time = iterToTime iv sl n (SolverStage 0), - iteration = 1, - solver = sl { stage = SolverStage 0 }} - ps = map parameterise [nl..nu] - scanM f z ps +-- hybrid :: (MonadPlus p, Monad m) => (a -> Parameters -> m a) -> a -> Double -> Solver -> m (p a) +-- hybrid f z t sl = +-- do let iv = Interval 0 t +-- (nl, nu) = iterationBnds iv (dt sl) +-- parameterise n = Parameters { interval = iv, +-- time = iterToTime iv sl n (SolverStage 0), +-- iteration = 1, +-- solver = sl { stage = SolverStage 0 }} +-- ps = map parameterise [nl..nu] +-- scanM f z ps -test = do - t <- hybrid (demux predicate) (1, "initial") 40 sineSolv2 - case t of - [] -> fail "Something went wrong during hybrid simulation" - list -> print list +-- test = do +-- t <- hybrid (demux predicate) (1, "initial") 40 sineSolv2 +-- case t of +-- [] -> fail "Something went wrong during hybrid simulation" +-- list -> print list diff --git a/src/IO.hs b/src/IO.hs index a18749e..052f481 100644 --- a/src/IO.hs +++ b/src/IO.hs @@ -2,8 +2,8 @@ module IO where import System.IO ( Handle, hClose, openFile, hPutStrLn, IOMode(WriteMode) ) -import Simulation -import Solver +import Simulation ( Interval(stopTime, startTime) ) +import Solver ( Solver(dt) ) addTime :: IO [[a]] -> Interval -> Solver -> IO [(Double, [a])] addTime answers intv solver = fmap (zip input) answers diff --git a/src/Integrator.hs b/src/Integrator.hs index c582c6b..63f3081 100644 --- a/src/Integrator.hs +++ b/src/Integrator.hs @@ -2,21 +2,25 @@ {-# LANGUAGE RecursiveDo #-} module Integrator where -import Data.IORef -import Control.Monad.Trans +import Data.IORef ( IORef, newIORef, readIORef, writeIORef ) -import Types -import CT +import CT ( Parameters(solver, interval, time, iteration), CT ) import Solver -import Interpolation -import Memo + ( Solver(dt, method, stage), + Method(Euler, RungeKutta2, RungeKutta4), + Stage(SolverStage), + getSolverStage, + iterToTime ) +import Interpolation ( interpolate ) +import Memo ( memo ) +import Control.Monad.Trans.Reader ( ReaderT(ReaderT, runReaderT) ) integ :: CT Double -> CT Double -> CT (CT Double) integ diff i = mdo y <- memo interpolate z - z <- CT $ \ps -> - let f = solverToFunction (method $ solver ps) - in return . CT $ f diff i y + z <- ReaderT $ \ps -> + let f = solverToFunction (method $ solver ps) + in pure . ReaderT $ f diff i y return y -- | The Integrator type represents an integral with caching. @@ -26,45 +30,45 @@ data Integrator = Integrator { initial :: CT Double, -- ^ The initial value. } initialize :: CT a -> CT a -initialize (CT m) = - CT $ \ps -> +initialize m = + ReaderT $ \ps -> if iteration ps == 0 && getSolverStage (stage $ solver ps) == 0 then - m ps + runReaderT m ps else let iv = interval ps sl = solver ps - in m $ ps { time = iterToTime iv sl 0 (SolverStage 0), - iteration = 0, - solver = sl { stage = SolverStage 0 }} + in runReaderT m $ ps { time = iterToTime iv sl 0 (SolverStage 0), + iteration = 0, + solver = sl { stage = SolverStage 0 }} createInteg :: CT Double -> CT Integrator createInteg i = - CT $ \ps -> + ReaderT $ \ps -> do r1 <- newIORef $ initialize i r2 <- newIORef $ initialize i let integ = Integrator { initial = i, cache = r1, computation = r2 } - z = CT $ \ps -> + z = ReaderT $ \ps -> do v <- readIORef (computation integ) - v `apply` ps - y <- memo interpolate z `apply` ps + runReaderT v ps + y <- runReaderT (memo interpolate z) ps writeIORef (cache integ) y return integ readInteg :: Integrator -> CT Double readInteg integ = - CT $ \ps -> (`apply` ps) =<< readIORef (cache integ) + ReaderT $ \ps -> flip runReaderT ps =<< readIORef (cache integ) updateInteg :: Integrator -> CT Double -> CT () updateInteg integ diff = - CT . const $ writeIORef (computation integ) z + ReaderT . const $ writeIORef (computation integ) z where i = initial integ - z = CT $ \ps -> + z = ReaderT $ \ps -> let f = solverToFunction (method $ solver ps) in (\y -> f diff i y ps) =<< readIORef (cache integ) - + solverToFunction Euler = integEuler solverToFunction RungeKutta2 = integRK2 solverToFunction RungeKutta4 = integRK4 @@ -73,17 +77,17 @@ integEuler :: CT Double -> CT Double -> CT Double -> Parameters -> IO Double -integEuler (CT diff) (CT i) (CT y) ps = +integEuler diff i y ps = case iteration ps of 0 -> - i ps + runReaderT i ps n -> do let iv = interval ps sl = solver ps ty = iterToTime iv sl (n - 1) (SolverStage 0) psy = ps { time = ty, iteration = n - 1, solver = sl { stage = SolverStage 0} } - a <- y psy - b <- diff psy + a <- runReaderT y psy + b <- runReaderT diff psy let !v = a + dt (solver ps) * b return v @@ -91,11 +95,11 @@ integRK2 :: CT Double -> CT Double -> CT Double -> Parameters -> IO Double -integRK2 (CT f) (CT i) (CT y) ps = +integRK2 f i y ps = case stage (solver ps) of SolverStage 0 -> case iteration ps of 0 -> - i ps + runReaderT i ps n -> do let iv = interval ps sl = solver ps @@ -105,9 +109,9 @@ integRK2 (CT f) (CT i) (CT y) ps = psy = ps { time = ty, iteration = n - 1, solver = sl { stage = SolverStage 0 }} ps1 = psy ps2 = ps { time = t2, iteration = n - 1, solver = sl { stage = SolverStage 1 }} - vy <- y psy - k1 <- f ps1 - k2 <- f ps2 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 + k2 <- runReaderT f ps2 let !v = vy + dt sl / 2.0 * (k1 + k2) return v SolverStage 1 -> do @@ -118,8 +122,8 @@ integRK2 (CT f) (CT i) (CT y) ps = t1 = ty psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps1 = psy - vy <- y psy - k1 <- f ps1 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 let !v = vy + dt sl * k1 return v _ -> @@ -129,11 +133,11 @@ integRK4 :: CT Double -> CT Double -> CT Double -> Parameters -> IO Double -integRK4 (CT f) (CT i) (CT y) ps = +integRK4 f i y ps = case stage (solver ps) of SolverStage 0 -> case iteration ps of 0 -> - i ps + runReaderT i ps n -> do let iv = interval ps sl = solver ps @@ -147,11 +151,11 @@ integRK4 (CT f) (CT i) (CT y) ps = ps2 = ps { time = t2, iteration = n - 1, solver = sl { stage = SolverStage 1 }} ps3 = ps { time = t3, iteration = n - 1, solver = sl { stage = SolverStage 2 }} ps4 = ps { time = t4, iteration = n - 1, solver = sl { stage = SolverStage 3 }} - vy <- y psy - k1 <- f ps1 - k2 <- f ps2 - k3 <- f ps3 - k4 <- f ps4 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 + k2 <- runReaderT f ps2 + k3 <- runReaderT f ps3 + k4 <- runReaderT f ps4 let !v = vy + dt sl / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4) return v SolverStage 1 -> do @@ -162,8 +166,8 @@ integRK4 (CT f) (CT i) (CT y) ps = t1 = ty psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps1 = psy - vy <- y psy - k1 <- f ps1 + vy <- runReaderT y psy + k1 <- runReaderT f ps1 let !v = vy + dt sl / 2.0 * k1 return v SolverStage 2 -> do @@ -174,8 +178,8 @@ integRK4 (CT f) (CT i) (CT y) ps = t2 = iterToTime iv sl n (SolverStage 1) psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps2 = ps { time = t2, iteration = n, solver = sl { stage = SolverStage 1 }} - vy <- y psy - k2 <- f ps2 + vy <- runReaderT y psy + k2 <- runReaderT f ps2 let !v = vy + dt sl / 2.0 * k2 return v SolverStage 3 -> do @@ -186,8 +190,8 @@ integRK4 (CT f) (CT i) (CT y) ps = t3 = iterToTime iv sl n (SolverStage 2) psy = ps { time = ty, iteration = n, solver = sl { stage = SolverStage 0 }} ps3 = ps { time = t3, iteration = n, solver = sl { stage = SolverStage 2 }} - vy <- y psy - k3 <- f ps3 + vy <- runReaderT y psy + k3 <- runReaderT f ps3 let !v = vy + dt sl * k3 return v _ -> diff --git a/src/Interpolation.hs b/src/Interpolation.hs index 224b048..f343a23 100644 --- a/src/Interpolation.hs +++ b/src/Interpolation.hs @@ -1,9 +1,14 @@ module Interpolation where -import Types -import CT +import CT ( Parameters(solver, interval, time, iteration), CT ) import Simulation + ( Interval(startTime), iterationLoBnd, iterationHiBnd ) import Solver + ( Solver(stage, dt), + Stage(SolverStage, Interpolate), + getSolverStage, + iterToTime ) +import Control.Monad.Trans.Reader ( ReaderT(ReaderT, runReaderT) ) -- | Function to solve floating point approximations neighborhood :: Solver -> Double -> Double -> Bool @@ -12,32 +17,32 @@ neighborhood sl t t' = -- | Discretize the computation in the integration time points. discrete :: CT a -> CT a -discrete (CT m) = - CT $ \ps -> +discrete m = + ReaderT $ \ps -> let st = getSolverStage $ stage (solver ps) - r | st == 0 = m ps + r | st == 0 = runReaderT m ps | st > 0 = let iv = interval ps sl = solver ps n = iteration ps - in m $ ps { time = iterToTime iv sl n (SolverStage 0), - solver = sl {stage = SolverStage 0} } + in runReaderT m $ ps { time = iterToTime iv sl n (SolverStage 0), + solver = sl {stage = SolverStage 0} } | otherwise = let iv = interval ps t = time ps sl = solver ps n = iteration ps t' = startTime iv + fromIntegral (n + 1) * dt sl n' = if neighborhood sl t t' then n + 1 else n - in m $ ps { time = iterToTime iv sl n' (SolverStage 0), - iteration = n', - solver = sl { stage = SolverStage 0} } + in runReaderT m $ ps { time = iterToTime iv sl n' (SolverStage 0), + iteration = n', + solver = sl { stage = SolverStage 0} } in r -- | Interpolate the computation based on the integration time points only. interpolate :: CT Double -> CT Double -interpolate (CT m) = - CT $ \ps -> +interpolate m = + ReaderT $ \ps -> case stage $ solver ps of - SolverStage _ -> m ps + SolverStage _ -> runReaderT m ps Interpolate -> let iv = interval ps sl = solver ps @@ -49,13 +54,13 @@ interpolate (CT m) = t1 = iterToTime iv sl n1 (SolverStage 0) t2 = iterToTime iv sl n2 (SolverStage 0) z1 = - m $ ps { time = t1, - iteration = n1, - solver = sl { stage = SolverStage 0 }} + runReaderT m $ ps { time = t1, + iteration = n1, + solver = sl { stage = SolverStage 0 }} z2 = - m $ ps { time = t2, - iteration = n2, - solver = sl { stage = SolverStage 0 }} + runReaderT m $ ps { time = t2, + iteration = n2, + solver = sl { stage = SolverStage 0 }} in do y1 <- z1 y2 <- z2 return $ y1 + (y2 - y1) * (t - t1) / (t2 - t1) diff --git a/src/Memo.hs b/src/Memo.hs index d0ecc24..28bf9ea 100644 --- a/src/Memo.hs +++ b/src/Memo.hs @@ -2,13 +2,21 @@ module Memo where -import CT +import CT ( CT, Parameters(solver, interval, time, iteration) ) import Solver -import Simulation + ( getSolverStage, + iterToTime, + stageBnds, + stageHiBnd, + Solver(stage, dt), + Stage(SolverStage) ) +import Simulation ( iterationBnds ) -import Data.IORef -import Data.Array +import Data.IORef ( newIORef, readIORef, writeIORef ) +import Data.Array ( Ix ) import Data.Array.IO + ( Ix, readArray, writeArray, MArray(newArray_), IOUArray, IOArray ) +import Control.Monad.Trans.Reader ( ReaderT(ReaderT, runReaderT) ) -- -- | The 'Memo' class specifies a type for which an array can be created. class (MArray IOArray e IO) => Memo e where @@ -28,68 +36,68 @@ instance (MArray IOUArray e IO) => UMemo e where -- the specified interpolation and being aware of the Runge-Kutta method. memo :: UMemo e => (CT e -> CT e) -> CT e -> CT (CT e) -memo tr (CT m) = - CT $ \ps -> - do let sl = solver ps - iv = interval ps - (SolverStage stl, SolverStage stu) = stageBnds sl - (nl, nu) = iterationBnds iv (dt sl) - arr <- newMemoUArray_ ((stl, nl), (stu, nu)) - nref <- newIORef 0 - stref <- newIORef 0 - let r ps = - do let sl = solver ps - iv = interval ps - n = iteration ps - st = getSolverStage $ stage sl - stu = getSolverStage $ stageHiBnd sl - loop n' st' = - if (n' > n) || ((n' == n) && (st' > st)) - then - readArray arr (st, n) - else - let ps' = ps { time = iterToTime iv sl n' (SolverStage st'), - iteration = n', - solver = sl { stage = SolverStage st' }} - in do a <- m ps' - a `seq` writeArray arr (st', n') a - if st' >= stu - then do writeIORef stref 0 - writeIORef nref (n' + 1) - loop (n' + 1) 0 - else do writeIORef stref (st' + 1) - loop n' (st' + 1) - n' <- readIORef nref - st' <- readIORef stref - loop n' st' - return $ tr $ CT r +memo tr m = + ReaderT $ \ps -> do + let sl = solver ps + iv = interval ps + (SolverStage stl, SolverStage stu) = stageBnds sl + (nl, nu) = iterationBnds iv (dt sl) + arr <- newMemoUArray_ ((stl, nl), (stu, nu)) + nref <- newIORef 0 + stref <- newIORef 0 + let r ps = do + let sl = solver ps + iv = interval ps + n = iteration ps + st = getSolverStage $ stage sl + stu = getSolverStage $ stageHiBnd sl + loop n' st' = + if (n' > n) || ((n' == n) && (st' > st)) + then + readArray arr (st, n) + else + let ps' = ps { time = iterToTime iv sl n' (SolverStage st'), + iteration = n', + solver = sl { stage = SolverStage st' }} + in do a <- runReaderT m ps' + a `seq` writeArray arr (st', n') a + if st' >= stu + then do writeIORef stref 0 + writeIORef nref (n' + 1) + loop (n' + 1) 0 + else do writeIORef stref (st' + 1) + loop n' (st' + 1) + n' <- readIORef nref + st' <- readIORef stref + loop n' st' + pure . tr . ReaderT $ r -- | Memoize and order the computation in the integration time points using -- the specified interpolation and without knowledge of the Runge-Kutta method. memo0 :: Memo e => (CT e -> CT e) -> CT e -> CT (CT e) -memo0 tr (CT m) = - CT $ \ps -> - do let iv = interval ps - bnds = iterationBnds iv (dt (solver ps)) - arr <- newMemoArray_ bnds - nref <- newIORef 0 - let r ps = - do let sl = solver ps - iv = interval ps - n = iteration ps - loop n' = - if n' > n - then - readArray arr n - else - let ps' = ps { time = iterToTime iv sl n' (SolverStage 0), - iteration = n', - solver = sl { stage = SolverStage 0} } - in do a <- m ps' - a `seq` writeArray arr n' a - writeIORef nref (n' + 1) - loop (n' + 1) - n' <- readIORef nref - loop n' - return $ tr $ CT r +memo0 tr m = + ReaderT $ \ps -> do + let iv = interval ps + bnds = iterationBnds iv (dt (solver ps)) + arr <- newMemoArray_ bnds + nref <- newIORef 0 + let r ps = do + let sl = solver ps + iv = interval ps + n = iteration ps + loop n' = + if n' > n + then + readArray arr n + else + let ps' = ps { time = iterToTime iv sl n' (SolverStage 0), + iteration = n', + solver = sl { stage = SolverStage 0} } + in do a <- runReaderT m ps' + a `seq` writeArray arr n' a + writeIORef nref (n' + 1) + loop (n' + 1) + n' <- readIORef nref + loop n' + pure . tr . ReaderT $ r diff --git a/src/Simulation.hs b/src/Simulation.hs index 4e88830..79ab1cd 100644 --- a/src/Simulation.hs +++ b/src/Simulation.hs @@ -1,6 +1,6 @@ module Simulation where -import Types +import Types ( Iteration, TimeStep ) -- | It defines a time interval data Interval = Interval { startTime :: Double, -- ^ the start time diff --git a/src/Solver.hs b/src/Solver.hs index 300c71d..b63ab54 100644 --- a/src/Solver.hs +++ b/src/Solver.hs @@ -1,7 +1,7 @@ module Solver where -import Types -import Simulation +import Types ( Iteration ) +import Simulation ( Interval(startTime) ) -- | It defines configurations to use within the solver data Solver = Solver { dt :: Double, -- ^ the integration time step