Skip to content

Commit

Permalink
fix and-node update
Browse files Browse the repository at this point in the history
  • Loading branch information
ThijsvdLaar committed May 25, 2022
1 parent 005a16c commit 3a502a7
Showing 1 changed file with 80 additions and 69 deletions.
149 changes: 80 additions & 69 deletions demo/implementing_additional_nodes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,17 @@
"Note that the combination $y=1, z=0$ is disallowed by the truth table, and therefore $f(y=1, x, z=0)=0$ for both $x$. Furthermore, note that the combination $y=0, z=0$ appears twice, such that $f(y=0, x, z=0)=1$ for both $x$. This leads to\n",
"\\begin{align*} \n",
" \\overleftarrow{\\mu}_x(x) &= (1 - p_y)(1 - p_z) + 0 + (1 - p_y)p_z f(y=0, x, z=1) + p_y p_z f(y=1, x, z=1)\\\\\n",
" &= \\begin{cases}(1-p_y)(1-p_z) &\\text{ if } x=1\\\\\n",
" (1-p_y)(1-p_z) + p_z &\\text{ if } x=0 \\end{cases}\\\\\n",
" &\\propto \\mathcal{B}er\\left(x \\bigg| \\frac{a}{2a + p_z}\\right)\\,, \\text{ with } a=(1-p_y)(1-p_z).\n",
" &= \\begin{cases}(1-p_y)(1-p_z) + p_y p_z &\\text{ if } x=1\\\\\n",
" (1-p_y)(1-p_z) + p_z - p_y p_x &\\text{ if } x=0 \\end{cases}\\\\\n",
" &= \\begin{cases}1 - p_y - p_z + 2 p_y p_z &\\text{ if } x=1\\\\\n",
" 1-p_y &\\text{ if } x=0 \\end{cases}\\\\\n",
" &\\propto \\mathcal{B}er\\left(x \\bigg| \\frac{1 - p_y - p_z + 2 p_y p_z}{2 - 2 p_y - p_z + 2 p_y p_z}\\right)\\,.\n",
"\\end{align*}\n",
"\n",
"### Backward message for $z$\n",
"From symmerty, the backward message for $z$ follows as\n",
"\\begin{align*}\n",
" \\overleftarrow{\\mu}_z(z) \\propto \\mathcal{B}er\\left(z \\bigg| \\frac{a}{2a + p_x}\\right)\\,, \\text{ with } a=(1-p_y)(1-p_x)\\,.\n",
" \\overleftarrow{\\mu}_z(z) \\propto \\mathcal{B}er\\left(z \\bigg| \\frac{1 - p_y - p_x + 2 p_y p_x}{2 - 2 p_y - p_x + 2 p_y p_x}\\right)\\,.\n",
"\\end{align*}\n",
"\n",
"## Unit Tests\n",
Expand All @@ -76,7 +78,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -144,7 +146,7 @@
" @test isApplicable(SPAndIn1BNB, [Message{Bernoulli}, Nothing, Message{Bernoulli}]) \n",
" @test !isApplicable(SPAndIn1BNB, [Message{Bernoulli}, Message{Bernoulli}, Nothing]) \n",
"\n",
" @test ruleSPAndIn1BNB(Message(Univariate, Bernoulli, p=0.1), nothing, Message(Univariate, Bernoulli, p=0.25)) == Message(Univariate, Bernoulli, p=0.421875)\n",
" @test ruleSPAndIn1BNB(Message(Univariate, Bernoulli, p=0.1), nothing, Message(Univariate, Bernoulli, p=0.25)) == Message(Univariate, Bernoulli, p=0.4375)\n",
"end\n",
"\n",
"@testset \"SPAndIn2BBN\" begin\n",
Expand All @@ -153,7 +155,7 @@
" @test isApplicable(SPAndIn2BBN, [Message{Bernoulli}, Message{Bernoulli}, Nothing]) \n",
" @test !isApplicable(SPAndIn2BBN, [Nothing, Message{Bernoulli}, Message{Bernoulli}]) \n",
"\n",
" @test ruleSPAndIn2BBN(Message(Univariate, Bernoulli, p=0.1), Message(Univariate, Bernoulli, p=0.25), nothing) == Message(Univariate, Bernoulli, p=0.421875)\n",
" @test ruleSPAndIn2BBN(Message(Univariate, Bernoulli, p=0.1), Message(Univariate, Bernoulli, p=0.25), nothing) == Message(Univariate, Bernoulli, p=0.4375)\n",
"end\n",
"\n",
"end"
Expand All @@ -169,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -233,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -265,7 +267,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -279,17 +281,19 @@
"function ruleSPAndIn1BNB(msg_out::Message{Bernoulli}, msg_in1::Nothing, msg_in2::Message{Bernoulli})\n",
" p_out = msg_out.dist.params[:p]\n",
" p_in2 = msg_in2.dist.params[:p]\n",
" a = (1 - p_out)*(1 - p_in2)\n",
" a = 1 - p_out - p_in2 + 2*p_out*p_in2\n",
" b = 1 - p_out\n",
"\n",
" return Message(Univariate, Bernoulli, p=a/(2*a + p_in2))\n",
" return Message(Univariate, Bernoulli, p=a/(a + b))\n",
"end\n",
"\n",
"function ruleSPAndIn1BNB(msg_out::Message{Bernoulli}, msg_in1::Message{Bernoulli}, msg_in2::Nothing)\n",
" p_out = msg_out.dist.params[:p]\n",
" p_in1 = msg_in1.dist.params[:p]\n",
" a = (1 - p_out)*(1 - p_in1)\n",
" a = 1 - p_out - p_in1 + 2*p_out*p_in1\n",
" b = 1 - p_out\n",
"\n",
" return Message(Univariate, Bernoulli, p=a/(2*a + p_in1))\n",
" return Message(Univariate, Bernoulli, p=a/(a + b))\n",
"end\n",
";"
]
Expand All @@ -305,7 +309,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -321,7 +325,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand All @@ -338,99 +342,99 @@
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 409)\">\r\n",
"<title>G</title>\r\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-409 336.56,-409 336.56,4 -4,4\"/>\r\n",
"<!-- 12854391413402327007 -->\r\n",
"<!-- 4739411105901139578 -->\r\n",
"<g id=\"node1\" class=\"node\">\r\n",
"<title>12854391413402327007</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"59.56,-279 5.56,-279 5.56,-225 59.56,-225 59.56,-279\"/>\r\n",
"<text text-anchor=\"middle\" x=\"32.56\" y=\"-249.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">clamp_3</text>\r\n",
"</g>\r\n",
"<!-- 8766647308822353642 -->\r\n",
"<g id=\"node2\" class=\"node\">\r\n",
"<title>8766647308822353642</title>\r\n",
"<title>4739411105901139578</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"150.56,-405 78.56,-405 78.56,-333 150.56,-333 150.56,-405\"/>\r\n",
"<text text-anchor=\"middle\" x=\"114.56\" y=\"-371.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">Ber</text>\r\n",
"<text text-anchor=\"middle\" x=\"114.56\" y=\"-361.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">bernoulli_3</text>\r\n",
"</g>\r\n",
"<!-- 8766647308822353642&#45;&#45;12854391413402327007 -->\r\n",
"<g id=\"edge4\" class=\"edge\">\r\n",
"<title>8766647308822353642&#45;&#45;12854391413402327007</title>\r\n",
"<!-- 11099479730505282382 -->\r\n",
"<g id=\"node5\" class=\"node\">\r\n",
"<title>11099479730505282382</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"59.56,-279 5.56,-279 5.56,-225 59.56,-225 59.56,-279\"/>\r\n",
"<text text-anchor=\"middle\" x=\"32.56\" y=\"-249.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">clamp_3</text>\r\n",
"</g>\r\n",
"<!-- 4739411105901139578&#45;&#45;11099479730505282382 -->\r\n",
"<g id=\"edge6\" class=\"edge\">\r\n",
"<title>4739411105901139578&#45;&#45;11099479730505282382</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M78.35,-357.2C58.09,-349.09 34.55,-335.75 22.56,-315 16.34,-304.24 17.64,-290.77 21,-279.13\"/>\r\n",
"<text text-anchor=\"start\" x=\"22.56\" y=\"-308.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">clamp_3</text>\r\n",
"<text text-anchor=\"start\" x=\"0\" y=\"-281.73\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">1 out </text>\r\n",
"<text text-anchor=\"start\" x=\"64.35\" y=\"-359.8\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">2 p </text>\r\n",
"</g>\r\n",
"<!-- 12815532752540873597 -->\r\n",
"<!-- 1579320132496355312 -->\r\n",
"<g id=\"node6\" class=\"node\">\r\n",
"<title>12815532752540873597</title>\r\n",
"<title>1579320132496355312</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"246.56,-288 174.56,-288 174.56,-216 246.56,-216 246.56,-288\"/>\r\n",
"<text text-anchor=\"middle\" x=\"210.56\" y=\"-254.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">And</text>\r\n",
"<text text-anchor=\"middle\" x=\"210.56\" y=\"-244.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">and_1</text>\r\n",
"</g>\r\n",
"<!-- 8766647308822353642&#45;&#45;12815532752540873597 -->\r\n",
"<g id=\"edge5\" class=\"edge\">\r\n",
"<title>8766647308822353642&#45;&#45;12815532752540873597</title>\r\n",
"<!-- 4739411105901139578&#45;&#45;1579320132496355312 -->\r\n",
"<g id=\"edge4\" class=\"edge\">\r\n",
"<title>4739411105901139578&#45;&#45;1579320132496355312</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M143.88,-332.87C155.68,-318.74 169.27,-302.46 181.08,-288.31\"/>\r\n",
"<text text-anchor=\"start\" x=\"164.56\" y=\"-308.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">y</text>\r\n",
"<text text-anchor=\"start\" x=\"160.08\" y=\"-290.91\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">1 out </text>\r\n",
"<text text-anchor=\"start\" x=\"122.88\" y=\"-326.47\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">1 out </text>\r\n",
"</g>\r\n",
"<!-- 15409961985777379545 -->\r\n",
"<g id=\"node3\" class=\"node\">\r\n",
"<title>15409961985777379545</title>\r\n",
"<!-- 13119880213352482441 -->\r\n",
"<g id=\"node2\" class=\"node\">\r\n",
"<title>13119880213352482441</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"145.56,-171 73.56,-171 73.56,-99 145.56,-99 145.56,-171\"/>\r\n",
"<text text-anchor=\"middle\" x=\"109.56\" y=\"-137.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">Ber</text>\r\n",
"<text text-anchor=\"middle\" x=\"109.56\" y=\"-127.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">bernoulli_1</text>\r\n",
"<text text-anchor=\"middle\" x=\"109.56\" y=\"-127.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">bernoulli_2</text>\r\n",
"</g>\r\n",
"<!-- 17911493520242536387 -->\r\n",
"<g id=\"node5\" class=\"node\">\r\n",
"<title>17911493520242536387</title>\r\n",
"<!-- 6946451901889933502 -->\r\n",
"<g id=\"node4\" class=\"node\">\r\n",
"<title>6946451901889933502</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"136.56,-54 82.56,-54 82.56,0 136.56,0 136.56,-54\"/>\r\n",
"<text text-anchor=\"middle\" x=\"109.56\" y=\"-24.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">clamp_1</text>\r\n",
"<text text-anchor=\"middle\" x=\"109.56\" y=\"-24.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">clamp_2</text>\r\n",
"</g>\r\n",
"<!-- 15409961985777379545&#45;&#45;17911493520242536387 -->\r\n",
"<!-- 13119880213352482441&#45;&#45;6946451901889933502 -->\r\n",
"<g id=\"edge3\" class=\"edge\">\r\n",
"<title>15409961985777379545&#45;&#45;17911493520242536387</title>\r\n",
"<title>13119880213352482441&#45;&#45;6946451901889933502</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M109.56,-99C109.56,-84.4 109.56,-67.75 109.56,-54.23\"/>\r\n",
"<text text-anchor=\"start\" x=\"109.56\" y=\"-74.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">clamp_1</text>\r\n",
"<text text-anchor=\"start\" x=\"109.56\" y=\"-74.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">clamp_2</text>\r\n",
"<text text-anchor=\"start\" x=\"88.56\" y=\"-56.83\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">1 out </text>\r\n",
"<text text-anchor=\"start\" x=\"95.56\" y=\"-92.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">2 p </text>\r\n",
"</g>\r\n",
"<!-- 10086520768477768404 -->\r\n",
"<g id=\"node4\" class=\"node\">\r\n",
"<title>10086520768477768404</title>\r\n",
"<!-- 2007672268546576087 -->\r\n",
"<g id=\"node3\" class=\"node\">\r\n",
"<title>2007672268546576087</title>\r\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"332.56,-171 260.56,-171 260.56,-99 332.56,-99 332.56,-171\"/>\r\n",
"<text text-anchor=\"middle\" x=\"296.56\" y=\"-137.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">Ber</text>\r\n",
"<text text-anchor=\"middle\" x=\"296.56\" y=\"-127.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">bernoulli_2</text>\r\n",
"<text text-anchor=\"middle\" x=\"296.56\" y=\"-127.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">bernoulli_1</text>\r\n",
"</g>\r\n",
"<!-- 10954844975880832864 -->\r\n",
"<!-- 16639421223427970264 -->\r\n",
"<g id=\"node7\" class=\"node\">\r\n",
"<title>10954844975880832864</title>\r\n",
"<title>16639421223427970264</title>\r\n",
"<polygon fill=\"lightgrey\" stroke=\"black\" points=\"323.56,-54 269.56,-54 269.56,0 323.56,0 323.56,-54\"/>\r\n",
"<text text-anchor=\"middle\" x=\"296.56\" y=\"-24.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">clamp_2</text>\r\n",
"<text text-anchor=\"middle\" x=\"296.56\" y=\"-24.8\" font-family=\"Times New Roman,serif\" font-size=\"9.00\">clamp_1</text>\r\n",
"</g>\r\n",
"<!-- 10086520768477768404&#45;&#45;10954844975880832864 -->\r\n",
"<g id=\"edge2\" class=\"edge\">\r\n",
"<title>10086520768477768404&#45;&#45;10954844975880832864</title>\r\n",
"<!-- 2007672268546576087&#45;&#45;16639421223427970264 -->\r\n",
"<g id=\"edge1\" class=\"edge\">\r\n",
"<title>2007672268546576087&#45;&#45;16639421223427970264</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M296.56,-99C296.56,-84.4 296.56,-67.75 296.56,-54.23\"/>\r\n",
"<text text-anchor=\"start\" x=\"296.56\" y=\"-74.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">clamp_2</text>\r\n",
"<text text-anchor=\"start\" x=\"296.56\" y=\"-74.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">clamp_1</text>\r\n",
"<text text-anchor=\"start\" x=\"275.56\" y=\"-56.83\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">1 out </text>\r\n",
"<text text-anchor=\"start\" x=\"282.56\" y=\"-92.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">2 p </text>\r\n",
"</g>\r\n",
"<!-- 12815532752540873597&#45;&#45;15409961985777379545 -->\r\n",
"<g id=\"edge6\" class=\"edge\">\r\n",
"<title>12815532752540873597&#45;&#45;15409961985777379545</title>\r\n",
"<!-- 1579320132496355312&#45;&#45;13119880213352482441 -->\r\n",
"<g id=\"edge5\" class=\"edge\">\r\n",
"<title>1579320132496355312&#45;&#45;13119880213352482441</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M179.71,-215.87C167.29,-201.74 153,-185.46 140.57,-171.31\"/>\r\n",
"<text text-anchor=\"start\" x=\"162.56\" y=\"-191.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">x</text>\r\n",
"<text text-anchor=\"start\" x=\"162.56\" y=\"-191.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">z</text>\r\n",
"<text text-anchor=\"start\" x=\"119.57\" y=\"-173.91\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">1 out </text>\r\n",
"<text text-anchor=\"start\" x=\"158.71\" y=\"-209.47\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">2 in1 </text>\r\n",
"<text text-anchor=\"start\" x=\"158.71\" y=\"-209.47\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">3 in2 </text>\r\n",
"</g>\r\n",
"<!-- 12815532752540873597&#45;&#45;10086520768477768404 -->\r\n",
"<g id=\"edge1\" class=\"edge\">\r\n",
"<title>12815532752540873597&#45;&#45;10086520768477768404</title>\r\n",
"<!-- 1579320132496355312&#45;&#45;2007672268546576087 -->\r\n",
"<g id=\"edge2\" class=\"edge\">\r\n",
"<title>1579320132496355312&#45;&#45;2007672268546576087</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M236.83,-215.87C247.4,-201.74 259.57,-185.46 270.15,-171.31\"/>\r\n",
"<text text-anchor=\"start\" x=\"255.56\" y=\"-191.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">z</text>\r\n",
"<text text-anchor=\"start\" x=\"255.56\" y=\"-191.6\" font-family=\"Times New Roman,serif\" font-size=\"8.00\" fill=\"red\">x</text>\r\n",
"<text text-anchor=\"start\" x=\"249.15\" y=\"-173.91\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">1 out </text>\r\n",
"<text text-anchor=\"start\" x=\"215.83\" y=\"-209.47\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">3 in2 </text>\r\n",
"<text text-anchor=\"start\" x=\"215.83\" y=\"-209.47\" font-family=\"Times New Roman,serif\" font-size=\"8.00\">2 in1 </text>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n"
Expand All @@ -446,7 +450,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -459,7 +463,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 17,
"metadata": {},
"outputs": [
{
Expand All @@ -481,7 +485,14 @@
"\n",
"end\n",
"\n",
"end # block\n"
"end # block"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
Expand All @@ -491,13 +502,13 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Ber(p=0.08)\n"
"Ber(p=0.83)\n"
]
},
"metadata": {},
Expand Down

0 comments on commit 3a502a7

Please sign in to comment.