Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to use C#7 features #7

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 102 additions & 236 deletions ProbabilityMonad/Base.cs

Large diffs are not rendered by default.

97 changes: 31 additions & 66 deletions ProbabilityMonad/DistGadt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,20 +54,15 @@ public interface ParallelDistInterpreter<A, X>
public class Pure<A> : Dist<A>
{
public readonly A Value;

public Pure(A value)
{
Value = value;
}
=> Value = value;

public X Run<X>(DistInterpreter<A, X> interpreter)
{
return interpreter.Pure(Value);
}
=> interpreter.Pure(Value);

public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
{
return interpreter.Pure(Value);
}
=> interpreter.Pure(Value);
}

/// <summary>
Expand All @@ -76,20 +71,15 @@ public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
public class Primitive<A> : Dist<A>
{
public readonly PrimitiveDist<A> dist;

public Primitive(PrimitiveDist<A> dist)
{
this.dist = dist;
}
=> this.dist = dist;

public X Run<X>(DistInterpreter<A, X> interpreter)
{
return interpreter.Primitive(dist);
}
=> interpreter.Primitive(dist);

public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
{
return interpreter.Primitive(dist);
}
=> interpreter.Primitive(dist);
}

/// <summary>
Expand All @@ -99,21 +89,15 @@ public class Conditional<A> : Dist<A>
{
public readonly Func<A, Prob> likelihood;
public readonly Dist<A> dist;

public Conditional(Func<A, Prob> likelihood, Dist<A> dist)
{
this.likelihood = likelihood;
this.dist = dist;
}
=> (this.likelihood, this.dist) = (likelihood, dist);

public X Run<X>(DistInterpreter<A, X> interpreter)
{
return interpreter.Conditional(likelihood, dist);
}
=> interpreter.Conditional(likelihood, dist);

public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
{
return interpreter.Conditional(likelihood, dist);
}
=> interpreter.Conditional(likelihood, dist);
}

/// <summary>
Expand All @@ -125,6 +109,7 @@ public class RunIndependent3<T1, T2, T3, A> : Dist<A>
public readonly Dist<T2> second;
public readonly Dist<T3> third;
public readonly Func<T1, T2, T3, Dist<A>> run;

public RunIndependent3(Dist<T1> first, Dist<T2> second, Dist<T3> third, Func<T1, T2, T3, Dist<A>> run)
{
this.first = first;
Expand All @@ -145,9 +130,7 @@ from result in run(x, y, z)
}

public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
{
return interpreter.RunIndependent3(first, second, third, run);
}
=> interpreter.RunIndependent3(first, second, third, run);
}

/// <summary>
Expand All @@ -158,6 +141,7 @@ public class RunIndependent<T1, T2, A> : Dist<A>
public readonly Dist<T1> first;
public readonly Dist<T2> second;
public readonly Func<T1, T2, Dist<A>> run;

public RunIndependent(Dist<T1> first, Dist<T2> second, Func<T1, T2, Dist<A>> run)
{
this.first = first;
Expand All @@ -176,9 +160,7 @@ from result in run(x, y)
}

public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
{
return interpreter.RunIndependent(first, second, run);
}
=> interpreter.RunIndependent(first, second, run);
}

/// <summary>
Expand All @@ -187,21 +169,16 @@ public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
public class Independent<A> : Dist<Dist<A>>
{
public readonly Dist<A> dist;

public Independent(Dist<A> dist)
{
this.dist = dist;
}
=> this.dist = dist;

// Run sequentially if we're not using a parallel interpreter
public X Run<X>(DistInterpreter<Dist<A>, X> interpreter)
{
return Return(dist).Run(interpreter);
}
=> Return(dist).Run(interpreter);

public X RunParallel<X>(ParallelDistInterpreter<Dist<A>, X> interpreter)
{
return interpreter.Independent(Return(dist));
}
=> interpreter.Independent(Return(dist));
}

/// <summary>
Expand All @@ -211,21 +188,18 @@ public class Bind<Y, A> : Dist<A>
{
public readonly Dist<Y> dist;
public readonly Func<Y, Dist<A>> bind;

public Bind(Dist<Y> dist, Func<Y, Dist<A>> bind)
{
this.dist = dist;
this.bind = bind;
}

public X Run<X>(DistInterpreter<A, X> interpreter)
{
return interpreter.Bind(dist, bind);
}
=> interpreter.Bind(dist, bind);

public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
{
return interpreter.Bind(dist, bind);
}
=> interpreter.Bind(dist, bind);
}

/// <summary>
Expand All @@ -235,36 +209,27 @@ public X RunParallel<X>(ParallelDistInterpreter<A, X> interpreter)
public static class DistExt
{
public static Dist<B> Select<A, B>(this Dist<A> dist, Func<A, B> f)
{
return new Bind<A, B>(dist, a => new Pure<B>(f(a)));
}
=> new Bind<A, B>(dist, a => new Pure<B>(f(a)));

public static Dist<B> SelectMany<A, B>(this Dist<A> dist, Func<A, Dist<B>> bind)
{
return new Bind<A, B>(dist, bind);
}
=> new Bind<A, B>(dist, bind);

public static Dist<C> SelectMany<A, B, C>(
this Dist<A> dist,
Func<A, Dist<B>> bind,
Func<A, B, C> project
)
{
return
new Bind<A, C>(dist, a =>
new Bind<B, C>(bind(a), b =>
new Pure<C>(project(a, b))
)
);
}
=> new Bind<A, C>(dist, a =>
new Bind<B, C>(bind(a), b =>
new Pure<C>(project(a, b))
)
);

/// <summary>
/// Default to using recursion depth limit of 100
/// </summary>
public static Dist<IEnumerable<A>> Sequence<A>(this IEnumerable<Dist<A>> dists)
{
return SequenceWithDepth(dists, 100);
}
=> SequenceWithDepth(dists, 100);

/// <summary>
/// This implementation sort of does trampolining to avoid stack overflows,
Expand Down
14 changes: 4 additions & 10 deletions ProbabilityMonad/Finite/FiniteDist.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,10 @@ namespace ProbCSharp
public class FiniteDist<A>
{
public FiniteDist(Samples<A> samples)
{
Explicit = samples;
}
=> Explicit = samples;

public FiniteDist(params ItemProb<A>[] samples)
{
Explicit = Samples(samples);
}
=> Explicit = Samples(samples);

public Samples<A> Explicit { get; }
}
Expand All @@ -31,10 +27,8 @@ public static class FiniteDistMonad
/// fmap f (FiniteDist (Samples xs)) = FiniteDist $ Samples $ map (first f) xs
/// </summary>
public static FiniteDist<B> Select<A, B>(this FiniteDist<A> self, Func<A, B> select)
{
return new FiniteDist<B>(Samples(self.Explicit.Weights.Select(i =>
ItemProb(select(i.Item), i.Prob))));
}
=> new FiniteDist<B>(Samples(self.Explicit.Weights.Select(i =>
ItemProb(@select(i.Item), i.Prob))));

/// <summary>
/// (FiniteDist dist) >>= bind = FiniteDist $ do
Expand Down
40 changes: 11 additions & 29 deletions ProbabilityMonad/Finite/FiniteExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@ public static A Pick<A>(this FiniteDist<A> distribution, Prob pickProb)
/// Lifts a FiniteDist<A> into a SampleableDist<A>
/// </summary>
public static PrimitiveDist<A> ToSampleDist<A>(this FiniteDist<A> dist)
{
return new SampleDist<A>(() =>
=> new SampleDist<A>(() =>
{
var rand = new MathNet.Numerics.Distributions.ContinuousUniform().Sample();
return dist.Pick(Prob(rand));
});
}

/// <summary>
/// Returns the probability of a certain event
Expand All @@ -51,71 +49,55 @@ public static Prob ProbOf<A>(this FiniteDist<A> dist, Func<A, bool> eventTest)
/// Reweight by a probability that depends on associated item
/// </summary>
public static FiniteDist<A> ConditionSoft<A>(this FiniteDist<A> distribution, Func<A, Prob> likelihood)
{
return new FiniteDist<A>(
=> new FiniteDist<A>(
distribution.Explicit
.Select(p => ItemProb(p.Item, likelihood(p.Item).Mult(p.Prob)))
.Normalize()
);
}

/// <summary>
/// Reweight by a probability that depends on associated item, without normalizing
/// </summary>
public static FiniteDist<A> ConditionSoftUnnormalized<A>(this FiniteDist<A> distribution, Func<A, Prob> likelihood)
{
return new FiniteDist<A>(
=> new FiniteDist<A>(
distribution.Explicit.Select(p => ItemProb(p.Item, likelihood(p.Item).Mult(p.Prob)))
);
}

/// <summary>
/// Hard reweight by a condition that depends on associated item
/// </summary>
public static FiniteDist<A> ConditionHard<A>(this FiniteDist<A> distribution, Func<A, bool> condition)
{
return new FiniteDist<A>(
=> new FiniteDist<A>(
distribution.Explicit
.Select(p => ItemProb(p.Item, condition(p.Item) ? p.Prob : Prob(0)))
.Normalize()
);


}

/// <summary>
/// Computes the posterior distribution, given a piece of data and a likelihood function
/// </summary>
public static FiniteDist<A> UpdateOn<A, D>(this FiniteDist<A> prior, Func<A, D, Prob> likelihood, D datum)
{
return prior.ConditionSoft(w => likelihood(w, datum));
}
=> prior.ConditionSoft(w => likelihood(w, datum));

/// <summary>
/// Computes the posterior distribution, given a list of data and a likelihood function
/// </summary>
public static FiniteDist<A> UpdateOn<A, D>(this FiniteDist<A> prior, Func<A, D, Prob> likelihood, IEnumerable<D> data)
{
return data.Aggregate(prior, (dist, datum) => dist.UpdateOn(likelihood, datum));
}
=> data.Aggregate(prior, (dist, datum) => dist.UpdateOn(likelihood, datum));

/// <summary>
/// Normalize a finite distribution
/// </summary>
public static FiniteDist<A> Normalize<A>(this FiniteDist<A> dist)
{
return new FiniteDist<A>(dist.Explicit.Normalize());
}

=> new FiniteDist<A>(dist.Explicit.Normalize());

/// <summary>
/// Join two independent distributions
/// </summary>
public static FiniteDist<Tuple<A,B>> Join<A, B>(this FiniteDist<A> self, FiniteDist<B> other)
{
return from a in self
from b in other
select new Tuple<A, B>(a, b);
}
=> from a in self
from b in other
select new Tuple<A, B>(a, b);

/// <summary>
/// Returns all elements, and the collection without that element
Expand Down
Loading