-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
7 changed files
with
352 additions
and
117 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
module Distributions where | ||
|
||
import Sampler | ||
|
||
import qualified System.Random.MWC as MWC | ||
|
||
import Control.Monad | ||
import Control.Monad.Reader | ||
|
||
-- | Prob in range [0,1] | ||
type Prob = Double | ||
|
||
-- | Uniform standard continuous | ||
random :: Sampler Double | ||
random = Sampler $ ask >>= MWC.uniform | ||
|
||
uniform :: Double -> Double -> Sampler Double | ||
uniform a b | a > b = uniform b a | ||
uniform a b | a <= b = (\x -> (b - a) * x + a) <$> random | ||
|
||
bernoulli :: Prob -> Sampler Bool | ||
bernoulli p = (<=p) <$> random | ||
|
||
binomial :: Int -> Prob -> Sampler Int | ||
binomial n p = length . filter id <$> replicateM n (bernoulli p) | ||
|
||
gaussian :: Double -> Double -> Sampler Double | ||
gaussian μ σ² = | ||
(\u1 u2 -> μ + σ² * sqrt (-2 * log u1) * cos (2 * pi * u2)) | ||
<$> random <*> random | ||
|
||
gk :: Double -> Double -> Double -> Double -> Sampler Double | ||
gk a b g k = let c = 0.8 in | ||
(\z -> a | ||
+ b | ||
* (1 + c * tanh (g * z / 2)) | ||
* z | ||
* (1 + z**2)**k) | ||
<$> gaussian 0 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
{-# LANGUAGE DuplicateRecordFields #-} | ||
|
||
module Examples where | ||
|
||
import Distributions | ||
import Rejection | ||
import Metropolis | ||
import Sampler | ||
|
||
import qualified System.Random.MWC as MWC | ||
|
||
import Control.Parallel.Strategies | ||
import Control.Monad | ||
|
||
import Data.List | ||
|
||
-- call the examples for an external script to graph the results | ||
main :: IO () | ||
main = do | ||
model <- getLine | ||
case model of | ||
"weibull" -> approxWeibull 1 1.8 >>= print | ||
"gk" -> abcGK >>= \pars -> print (fst <$> pars) >> print (snd <$> pars) | ||
|
||
approxWeibull :: Double -> Double -> IO [Double] | ||
approxWeibull λ k = let | ||
kernel = RSMC | ||
{ prior = uniform 0 5 | ||
, priorDensity = const 2 | ||
, targetDensity = \x -> (k / λ) * (x / λ)**(k-1) * exp (-(x / λ)**k) | ||
} | ||
in MWC.createSystemRandom >>= sample (rs 10000 kernel) | ||
|
||
abcGK :: IO [(Double, Double)] | ||
abcGK = let | ||
-- y <- gk 0 1 -0.3 0.8 | ||
y :: [Double] = [-0.5104415370506095,-0.7036434632465107,-0.2668368363498056,0.6815278539148982,1.8065750007642016,-1.1482634213252663,-0.10660853352738332,1.2469280832101104,0.7656001636856136,-0.7320043142261864,-4.028497402676845,1.0387799211913844,-1.8660457828367103,-4.0299248430163175,-4.741513290215539,0.6530403315829644,0.24479825013251938,2.6113731175552144,-0.25192550180978834,-5.460086525801049e-2,1.1772787229049015,-5.169184991582197,0.2460051914736727,2.850651742166059,2.803350122147847,-2.5138331249271295,-0.16021357084783153,7.95598130026318e-2,0.1579509935011216,1.2485569713553373,-0.7026650541206256,-6.356494098001902,-1.1067147258067545,-0.3656293496022067,2.5930550010502995,1.4006864549753912,1.3287563023623362e-2,-0.4934879835848968,-0.1152751265700261,4.124238713981147,-0.3274281061664352,17.815877356198563,2.439324175744774,-0.3556365214539682,4.184379956930912,0.5211002573326591,-7.544726539658729,1.1432160816505454,0.10474592715483252,2.2598574690278124,-5.982209564872222e-2,0.3970662237111481,-0.1226321303550215,-2.833925560844585,5.606732231516373,0.4775251542596893,5.521384007685723,-1.3339762548390766,-0.5803242310454019,2.036198920878609,1.011632562481338,-6.607706277661513e-2,3.8510553864065047,-0.11955527847164106,1.6288674013209972,-1.197855950997007,-0.9812973027028926,-3.3243812608680865,-0.9287618414540904,0.6247003134293596,-0.3608071093058975,-1.6574586764389634,0.8422068782827897,1.2963993547489352,3.210441947634901e-2,1.0709150621786345,0.24485693966354982,-1.682055190609811,-1.151283428155862,7.170554007117898e-2,4.8061093617134505,-1.5299907999218874,-2.9428392355646134,0.48157705348963215,2.5763918848022368e-2,0.7333915815410592,-15.221663665916477,-1.202260148652221,-0.2840124914985552,2.3367624660280866,-0.7582762004885446,-4.103347263586696,1.1093339134558795,-1.1724787237984193,0.2366040973823823,-2.1128076576945514,-0.8438843816125938,3.0540896846195116,2.5287449193874108,1.4229384570569787,-1.6480220917225217,-4.134782467746005,0.39280543241777094,0.8511886602772839,2.7602672336570637,0.43392287493375675,-0.4682323987521255,0.7613470128077908,-11.930262800184837,-6.431491814607961,3.0805223447668006,1.9298178856672217,3.2534332707055844,4.065319796389563,-1.7722270288855324,-0.5834553617246786,2.5598942267685425,3.8319128595273138,-1.2609803136050555,-1.5906702843789262,8.158738166147467e-2,-0.7764482576234684,0.6998190710622543,-1.1476144230138676,-1.370813552369474,2.5052687698461074,-0.5323437463478694,0.759077074826166,-0.11759656272783872,-0.24720087400075288,3.72415007937246,1.7881388423706739,1.149159194536326,0.3400441025811831,-0.6689134211264488,0.7618204570110713,-0.7033980200595047,0.2540934346295764,-5.240906405876261,0.41994043930730424,2.2754579486301916,-12.254575244919574,-0.25598032216004896,1.4874270863529166,1.2377717637918372e-3,-0.27220660238528827,0.23363929554286147,-0.574466434237023,-4.759971356174461,0.24492525256681869] | ||
summarise x = let | ||
x' = sort x | ||
n = length x | ||
mean = (sum x) / (fromIntegral n) | ||
sd = (sum . map (\x -> (x - mean)**2) $ x) / (fromIntegral n-1) | ||
in (mean, sd | ||
, (1/(fromIntegral n * sd**3)) * (sum . map (\x -> (x - mean)**3)) x | ||
, (1/(fromIntegral n * sd**4)) * (sum . map (\x -> (x - mean)**4)) x) | ||
kernel = MABC | ||
{ observations = summarise y | ||
, model = \(g, k) -> summarise <$> replicateM 50 (gk 0 1 g k) | ||
, prior = \(g, k) -> (if -2 <= g && g <= 0 then 1 else 0) * (if 0 <= k && k <= 2 then 1 else 0) | ||
, transition = \(g, k) -> (,) <$> gaussian g 2.36 <*> gaussian k 2.36 | ||
, distance = \(x0, x1, x2, x3) (y0, y1, y2, y3) -> | ||
(x0 - y0)**2 + (x1 - y1)**2 + (x2 - y2)**2 + (x3 - y3)**2 | ||
, tolerance = 1.0 | ||
} | ||
in do | ||
gen <- MWC.createSystemRandom | ||
(g0:g1:g2:g3:g4:g5:_) <- sample (replicateM 6 $ uniform (-2) 0) gen | ||
(k0:k1:k2:k3:k4:k5:_) <- sample (replicateM 6 $ uniform 0 2) gen | ||
runEval $ do | ||
pars0 <- rpar (sample (mh 10000 kernel (g0,k0)) gen) | ||
pars1 <- rpar (sample (mh 10000 kernel (g1,k1)) gen) | ||
pars2 <- rpar (sample (mh 10000 kernel (g2,k2)) gen) | ||
pars3 <- rpar (sample (mh 10000 kernel (g3,k3)) gen) | ||
pars4 <- rpar (sample (mh 10000 kernel (g4,k4)) gen) | ||
pars5 <- rpar (sample (mh 10000 kernel (g5,k5)) gen) | ||
return $ do | ||
a <- pars0 ; b <- pars1 ; c <- pars2 ; d <- pars3 ; e <- pars4 ; f <- pars5 | ||
return $ a <> b <> c <> d <> e <> f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
-- | Metropolis Sampling | ||
|
||
{-# LANGUAGE FunctionalDependencies #-} | ||
{-# LANGUAGE RecordWildCards #-} | ||
{-# LANGUAGE DuplicateRecordFields #-} | ||
|
||
module Metropolis where | ||
|
||
import Sampler | ||
import Distributions | ||
|
||
mh :: MHKernel k a => Int -> k -> a -> Sampler [a] | ||
mh 0 _ _ = return [] | ||
mh n k x_0 = do | ||
x_1 <- k `perturb` x_0 | ||
a <- accepts k x_0 x_1 | ||
if a | ||
then (x_1:) <$> mh (n-1) k x_1 | ||
else (x_0:) <$> mh (n-1) k x_0 | ||
|
||
class MHKernel k a | k -> a where | ||
perturb :: k -> a -> Sampler a | ||
accepts :: k -> a -> a -> Sampler Bool | ||
|
||
data MABC θ ω = MABC | ||
{ observations :: ω | ||
, model :: θ -> Sampler ω | ||
, prior :: θ -> Double -- ^ density | ||
, transition :: θ -> Sampler θ -- ^ assumed symmetrical | ||
, distance :: ω -> ω -> Double | ||
, tolerance :: Double | ||
} | ||
|
||
instance MHKernel (MABC θ ω) θ where | ||
perturb :: MABC θ ω -> θ -> Sampler θ | ||
perturb MABC{..} = transition | ||
|
||
accepts :: MABC θ ω -> θ -> θ -> Sampler Bool | ||
accepts MABC{..} θ θ' = do | ||
x <- model θ' | ||
if distance x observations <= tolerance | ||
then bernoulli $ min 1 (prior θ' / prior θ) | ||
else return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
-- | Rejection Sampling | ||
|
||
{-# LANGUAGE FunctionalDependencies #-} | ||
{-# LANGUAGE RecordWildCards #-} | ||
{-# LANGUAGE DuplicateRecordFields #-} | ||
|
||
module Rejection where | ||
|
||
import Sampler | ||
import Distributions | ||
|
||
import Data.List | ||
|
||
import Control.Monad | ||
|
||
import qualified System.Random.MWC as MWC | ||
|
||
class RSKernel k a | k -> a where | ||
propose :: k -> Sampler a | ||
accepts :: k -> a -> Sampler Bool | ||
|
||
rs :: RSKernel k a => Int -> k -> Sampler [a] | ||
rs 0 _ = return [] | ||
rs n k = do | ||
x <- propose k | ||
a <- k `accepts` x | ||
if a | ||
then (x:) <$> rs (n-1) k | ||
else rs (n-1) k | ||
|
||
data RSMC ω = RSMC | ||
{ prior :: Sampler ω | ||
, priorDensity :: ω -> Double -- ^ scaled by M | ||
, targetDensity :: ω -> Double | ||
} | ||
|
||
instance RSKernel (RSMC ω) ω where | ||
propose :: RSMC ω -> Sampler ω | ||
propose RSMC{..} = prior | ||
|
||
accepts :: RSMC ω -> ω -> Sampler Bool | ||
accepts RSMC{..} x = let | ||
α = targetDensity x / priorDensity x | ||
in (bernoulli $ min 1 α) | ||
|
||
data RSABC θ ω = RSABC | ||
{ observations :: ω | ||
, model :: θ -> Sampler ω | ||
, prior :: Sampler θ | ||
, distance :: ω -> ω -> Double | ||
, tolerance :: Double | ||
} | ||
|
||
instance Eq ω => RSKernel (RSABC θ ω) θ where | ||
propose :: RSABC θ ω -> Sampler θ | ||
propose RSABC{..} = prior | ||
|
||
accepts :: RSABC θ ω -> θ -> Sampler Bool | ||
accepts RSABC{..} θ = do | ||
x <- model θ | ||
return $ distance x observations <= tolerance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | ||
|
||
module Sampler where | ||
|
||
import qualified System.Random.MWC as MWC | ||
|
||
import Control.Monad.Reader | ||
|
||
type Gen = MWC.GenIO | ||
|
||
newtype Sampler a = Sampler { runSampler :: ReaderT Gen IO a } | ||
deriving (Functor, Applicative, Monad, MonadIO) | ||
|
||
sample :: Sampler a -> Gen -> IO a | ||
sample = runReaderT . runSampler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import sys | ||
import subprocess | ||
import matplotlib.pyplot as plt | ||
|
||
def main(model): | ||
if model == "weibull": | ||
result = subprocess.run(["stack", "runhaskell", "Examples.hs"], input="weibull\n".encode(), stdout=subprocess.PIPE).stdout.decode() | ||
result = [float(i) for i in result[1:-2].split(",")] # turn into list | ||
plt.hist(result, bins=35, range=(-1, 8)); | ||
plt.show() | ||
elif model == "gk": | ||
result = subprocess.run(["stack", "runhaskell", "Examples.hs"], input="gk\n".encode(), stdout=subprocess.PIPE).stdout.decode() | ||
result = result[0:-1].split("\n") | ||
g = [float(i) for i in result[0][1:-1].split(",")] | ||
k = [float(i) for i in result[1][1:-1].split(",")] | ||
plt.hist(g, bins=35, range=(-3, 3), alpha=0.8) | ||
plt.hist(k, bins=35, range=(-3, 3), alpha=0.8) | ||
plt.show() | ||
else: | ||
print("bad argument") | ||
|
||
if __name__ == '__main__': | ||
main(sys.argv[1]) |
Oops, something went wrong.