Replies: 1 comment 1 reply
The VI loss looks like it hasn't stabilized. Can you try reducing the learning rate (e.g., to 0.01 or 0.001)? It may just be adapting too much between iterations.
On Thu, Feb 13, 2025 at 1:39 PM, katesjef wrote:
Hi all!
As described in several previous posts (e.g., #527, #323), I have quite a complex hierarchical model and have had trouble achieving convergence for non-decision time (the distribution runs up against a ceiling). One of the workarounds I am trying is to use variational inference instead.
When I run VI, even for long periods of time, I see that the loss is stuck at a very high value (~200k). Running a posterior predictive check also reveals opposite behavior for observed and predicted RT.
My questions are: Is it worth it to continue pursuing VI in this case? If so, are there ways to improve optimization and/or efficiency? Or might there be better alternatives to this approach?
Thank you for this great package!
HSSM v0.2.4

```python
param_t = {
    'name': 't',
    'formula': 't ~ 1',
    'prior': {
        'Intercept': {'name': 'Gamma', 'mu': 0.15, 'sigma': 0.13, 'initval': 0.01},
    },
    'bounds': (0, 2),
}
```
When using NUTS:
![NUTS_t_intercept_trace](https://github.com/user-attachments/assets/dce830ac-f94f-4e6d-96f5-c65a056f837a)
![NUTS_t_intercept_summary](https://github.com/user-attachments/assets/a5f2de91-f034-4ba9-a119-a0d4f82b6aa4)
And a posterior predictive plot of:
![NUTS_ppc](https://github.com/user-attachments/assets/bb53acbe-8049-42e3-b76f-0e3af4966042)
When using VI:
(from the tutorial)
```python
import pymc as pm
from pymc.blocking import DictToArrayBijection, RaveledVars

with model_6.pymc_model:
    advi = pm.FullRankADVI()
    start = model_6.pymc_model.initial_point()
    vars_dict = {var.name: var for var in model_6.pymc_model.continuous_value_vars}
    x0 = DictToArrayBijection.map(
        {var_name: value for var_name, value in start.items() if var_name in vars_dict}
    )
    # Record the approximation's mean and std at every iteration
    tracker = pm.callbacks.Tracker(
        mean=lambda: DictToArrayBijection.rmap(
            RaveledVars(advi.approx.mean.eval(), x0.point_map_info), start
        ),
        std=lambda: DictToArrayBijection.rmap(
            RaveledVars(advi.approx.std.eval(), x0.point_map_info), start
        ),
    )
    approx = advi.fit(n=1000000, callbacks=[tracker])
    vi_posterior_samples = approx.sample(1000)
```
![VI_means](https://github.com/user-attachments/assets/8b4a4506-f83a-4add-8594-2e1555c670a8)
![VI_loss](https://github.com/user-attachments/assets/a6ebb760-4c40-4a28-ab5c-e8c180ef9dc0)
A "t" posterior of:
![VI_t_intercept_posterior](https://github.com/user-attachments/assets/b9a8bcaf-7534-422e-9304-03e32bd40f61)
And the posterior predictive plot:
![VI_ppc](https://github.com/user-attachments/assets/5df16d2e-cbca-4319-9c65-095108600b47)