From 85912a46f7421f930af9a7322b35e861ddcc4cba Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Thu, 16 Nov 2023 17:11:50 -0500 Subject: [PATCH 1/2] update psis to not add back the max and return the unnormalized resampling ratios --- src/stan/services/pathfinder/psis.hpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/stan/services/pathfinder/psis.hpp b/src/stan/services/pathfinder/psis.hpp index 7a13a80726..1684daf5cf 100644 --- a/src/stan/services/pathfinder/psis.hpp +++ b/src/stan/services/pathfinder/psis.hpp @@ -270,14 +270,7 @@ inline Eigen::Array psis_weights( } // truncate at max of raw wts (i.e., 0 since max has been subtracted) - for (Eigen::Index i = 0; i < llr_weights.size(); ++i) { - if (llr_weights.coeff(i) > 0) { - llr_weights.coeffRef(i) = 0.0; - } - } - auto max_adj = (llr_weights + max_log_ratio).eval(); - auto max_adj_exp = max_adj.exp(); - return max_adj_exp / max_adj_exp.sum(); + return (llr_weights.array() < 0.0).select(llr_weights, 0.0).exp().eval(); } } // namespace psis From 3c31d68d7002d6c6270f880d7071bfbacc8a2ebb Mon Sep 17 00:00:00 2001 From: Steve Bronder Date: Mon, 20 Nov 2023 12:32:48 -0500 Subject: [PATCH 2/2] fix test --- .../unit/services/pathfinder/psis_test.cpp | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/test/unit/services/pathfinder/psis_test.cpp b/src/test/unit/services/pathfinder/psis_test.cpp index d07b8281c7..def8bd4437 100644 --- a/src/test/unit/services/pathfinder/psis_test.cpp +++ b/src/test/unit/services/pathfinder/psis_test.cpp @@ -120,40 +120,40 @@ TEST(ServicesPSIS, get_psis_weights) { 4.91652175788515, 5.59472544249795, 7.16126247369561, 6.15854810500356, 6.62555174418028, 9.103992907802, 6.7399306070758, 6.04794961458687; Eigen::Array answer(100); - answer << 0.00614128190501742, 0.00259883000263199, 0.00142240988349195, - 0.0274531755800624, 0.0157476498780024, 0.00391823907901988, - 8.26414513956511e-14, 0.00322003256443345, 0.00619646954412031, - 0.00346793510111242, 0.0444046803091539, 0.004029171121401, - 0.00366480923122832, 0.00175331655026224, 0.00491192297921304, - 0.000810758279198907, 0.00588439834963157, 0.17822256163288, - 0.00297698115204899, 0.00580171797853455, 0.0121325368865641, - 0.00408651931997244, 0.00457647470707981, 0.00999374265076167, - 4.75420585761083e-06, 0.00250429809109908, 0.00273193115270838, - 0.00357946965550952, 0.00255728407443698, 0.00345942783466891, - 0.00585416468904412, 0.00866717090821407, 0.00223521120433322, - 0.00335369108675144, 0.000278240478565439, 0.00841871317684748, - 0.00303339482203412, 0.0108274835712062, 0.00635298901334288, - 0.0366679814601552, 0.0128890587651439, 0.0313596935307308, - 0.00212538535536358, 0.00646666659496334, 0.00787540323350999, - 0.00513279331867044, 0.0114488824095472, 0.00943345186908113, - 0.00744772541269413, 0.00262296832653951, 0.00786635558506987, - 0.00149440115404536, 0.00693444017895142, 0.00452127888795521, - 0.0137317162483214, 0.001172190208605, 0.00470337325067493, - 0.00880970502513788, 0.00388287423800577, 0.00615587072760303, - 0.0220257764512452, 0.000965367587466615, 0.00414255659344651, - 0.00869536963179888, 0.00266583380629748, 0.000131838303785288, - 4.379914619164e-16, 0.00226285858615159, 0.0183863334607642, - 0.00209481287761514, 2.69620598333185e-11, 2.54440757139945e-05, - 0.00400503003258219, 0.0200460546415302, 0.00559698462905835, - 0.00716790107562062, 0.00329971992325907, 0.0169711871963573, - 0.00473505622866302, 0.057026147481862, 0.00651433804563218, - 0.0102597432232274, 0.00575355555045384, 0.00379438966189875, - 0.0244370937648389, 0.00151257245292768, 0.00487470007189976, - 0.0054709853543924, 0.0044516176713708, 0.00952362373631346, - 0.00464287019866733, 9.6288493325318e-09, 0.00147240135547797, - 0.00290112423098392, 0.0146773422920254, 0.00509837055791414, - 0.00813295746634018, 0.082538932616576, 0.00911848336266506, - 0.00456456171607177; + answer << 0.0292354130264213, 0.0123716627387399, 0.00677134531192521, + 0.130690129419765, 0.0749662782948843, 0.0186526753832652, + 3.93412483260354e-13, 0.0153288814022406, 0.0294981323492076, + 0.0165090149903415, 0.211387327470457, 0.019180764490236, + 0.0174462291741511, 0.00834661790579885, 0.0233830817852443, + 0.00385959373361536, 0.0280125255319936, 0.848423876407167, + 0.0141718414653653, 0.0276189277725056, 0.0577566267796964, + 0.019453769100265, 0.0217861889969809, 0.0475749523623197, + 2.26322735235991e-05, 0.0119216460287785, 0.0130052873071864, + 0.0170399723401178, 0.0121738844264701, 0.0164685163693616, + 0.0278685989759023, 0.0412598420315477, 0.0106406648922235, + 0.0159651593267689, 0.00132455657260805, 0.0400770654535543, + 0.0144403972763279, 0.0515439543627507, 0.0302432392178012, + 0.174556973513457, 0.0613580295366247, 0.149287006675898, + 0.0101178417902614, 0.0307843983930222, 0.0374906525774646, + 0.0244345293004686, 0.0545021073956205, 0.0449077026462719, + 0.0354547034177988, 0.0124865726028556, 0.037447581482023, + 0.0071140578858597, 0.0330112223920261, 0.0215234307334137, + 0.0653694786023926, 0.00558018104755882, 0.0223902862183483, + 0.0419383719937075, 0.0184843219760041, 0.0293048627374842, + 0.104853136810685, 0.00459560733058814, 0.0197205330854113, + 0.041394081322883, 0.0126906326060736, 0.000627612821472451, + 2.08504698029095e-15, 0.0107722795353974, 0.0875276629668783, + 0.00997230230387398, 1.28352185661132e-10, 0.000121125861681488, + 0.0190658414638319, 0.0954287225467601, 0.0266442999790266, + 0.0341226069278941, 0.0157082310045359, 0.0807908959249787, + 0.0225411122122881, 0.271471494180354, 0.0310113370959555, + 0.0488412411797974, 0.0273896517843532, 0.0180630934492948, + 0.116332150262759, 0.00720056188230011, 0.0232058831016178, + 0.0260444837040436, 0.0211918103208897, 0.0453369634785661, + 0.0221022630104646, 4.58378871967692e-08, 0.00700932841609052, + 0.0138107265625568, 0.0698711068265654, 0.02427066064181, + 0.0387167328144493, 0.392924445274868, 0.0434083032508679, + 0.0217294775126546; stan::test::test_logger warner; auto blah = stan::services::psis::psis_weights(lrms, 20, warner); for (Eigen::Index i = 0; i < answer.size(); ++i) {