dynamax glm-hmm using jagged input arrays with differing E and M steps #362

Open
jess-breda opened this issue May 14, 2024 · 2 comments
@jess-breda
Summary

I'm interested in learning if/how it would be possible to fit a GLM-HMM (i.e. LogisticRegressionHMM, CategoricalRegressionHMM) with a jagged input list (i.e. a list whose elements are lists of different lengths), such that the E step is run separately on each of the inner lists, whereas the M step is run on the entire input.

Context

To motivate this question, suppose jagged_list is a list where each element is a session_list (i.e. jagged_list = [session_list_1, session_list_2, ... session_list_s]), and each session_list contains trial samples (i.e. session_list_k = [trial_1, trial_2, ... trial_t]). Because the number of trials varies per session, this is a jagged array. These data represent trials from a single subject across multiple sessions. Trials from previous or future sessions should not be used for learning state probabilities and transitions (E-step), but all trials should be used together for learning weights (M-step).
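For concreteness, here is a minimal NumPy sketch of this layout and of one possible workaround, padding every session to the longest one and keeping a boolean mask of real trials. The helper name `pad_sessions`, the toy session lengths, and the choice of two input features per trial are all assumptions made for illustration. Nothing here is Dynamax API, and the sketch does not assume that Dynamax's fitting routines can consume such a mask.

```python
import numpy as np

def pad_sessions(sessions, pad_value=0.0):
    """Pad a jagged list of (T_k, D) arrays into one (S, T_max, D) array.

    Returns the padded array and a boolean mask of shape (S, T_max) that is
    True for real trials and False for padding. (Hypothetical helper.)
    """
    lengths = np.array([len(s) for s in sessions])
    t_max = int(lengths.max())
    n_features = sessions[0].shape[1]
    padded = np.full((len(sessions), t_max, n_features), pad_value, dtype=float)
    for k, session in enumerate(sessions):
        padded[k, : len(session)] = session
    mask = np.arange(t_max)[None, :] < lengths[:, None]
    return padded, mask

# Toy data (assumed): 3 sessions with 4, 2, and 3 trials, 2 input features each.
rng = np.random.default_rng(0)
jagged_list = [rng.normal(size=(t, 2)) for t in (4, 2, 3)]

padded, mask = pad_sessions(jagged_list)
print(padded.shape)      # one rectangular array for all sessions
print(mask.sum(axis=1))  # number of real trials per session
```

Padding only fixes the shape problem. The padded trials still have to be ignored during inference, which is what the mask is for.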

Note this was previously supported in the SSM library. When the data was structured in this way in SSM, it allowed for the E-step to be run for each session, followed by the M-step across all sessions.

Issues

I think there are two roadblocks that prevent this from being possible.

  1. Dynamax does not appear to support jagged arrays because of its JAX implementation.
  • Dynamax has a procedure for batched inputs, but it requires each batch (e.g. session) to have the same number of time steps (e.g. trials).
  2. The current fit_em method runs the E step and M step for each batch (as opposed to an E-step for each batch and an M-step across all batches).
  • However, there are separate methods for the E and M steps. I'm just unsure how to properly summarize the E-step outputs iterated over batches/sessions (i.e. SuffStats) so they can be passed into a single M-step call across all batches/sessions. Any advice here would be greatly appreciated.
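To make the desired behavior concrete, here is a minimal sketch in plain NumPy of one EM iteration for a Bernoulli GLM-HMM in which forward-backward runs on each session separately and a single M-step uses the pooled statistics. It is not Dynamax code and does not use Dynamax's e_step, m_step, or SuffStats. The function names, the toy data, and the gradient-ascent update for the weights (a partial M-step, i.e. generalized EM) are all assumptions made for illustration.

```python
import numpy as np

def bernoulli_log_lik(X, y, W):
    """(T, K) log p(y_t | x_t, z_t = k) for state-specific logistic weights W (K, D)."""
    logits = X @ W.T
    signed = np.where(y[:, None] == 1, logits, -logits)
    return -np.logaddexp(0.0, -signed)  # log sigmoid(signed)

def session_e_step(pi0, A, log_lik):
    """Scaled forward-backward on ONE session; no information crosses sessions."""
    T, K = log_lik.shape
    row_max = log_lik.max(axis=1, keepdims=True)
    lik = np.exp(log_lik - row_max)  # rescaled per trial for numerical stability
    alpha = np.zeros((T, K))
    c = np.zeros(T)
    alpha[0] = pi0 * lik[0]
    c[0] = alpha[0].sum()
    alpha[0] /= c[0]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ A) * lik[t]
        c[t] = alpha[t].sum()
        alpha[t] /= c[t]
    beta = np.ones((T, K))
    for t in range(T - 2, -1, -1):
        beta[t] = (A @ (lik[t + 1] * beta[t + 1])) / c[t + 1]
    gamma = alpha * beta  # smoothed state posteriors, rows sum to 1
    xi_sum = np.zeros((K, K))  # expected transition counts for this session
    for t in range(T - 1):
        xi_sum += alpha[t][:, None] * A * (lik[t + 1] * beta[t + 1])[None, :] / c[t + 1]
    log_marginal = np.log(c).sum() + row_max.sum()
    return gamma, xi_sum, log_marginal

def em_iteration(sessions, pi0, A, W, lr=0.1, n_grad_steps=20):
    """E-step per session, then one M-step on statistics pooled over sessions."""
    init_counts = np.zeros_like(pi0)
    trans_counts = np.zeros_like(A)
    gammas, total_ll = [], 0.0
    for X, y in sessions:  # each session is treated as its own chain
        gamma, xi_sum, ll = session_e_step(pi0, A, bernoulli_log_lik(X, y, W))
        init_counts += gamma[0]
        trans_counts += xi_sum
        gammas.append(gamma)
        total_ll += ll
    # Closed-form updates from the summed expected counts.
    new_pi0 = init_counts / init_counts.sum()
    new_A = trans_counts / trans_counts.sum(axis=1, keepdims=True)
    # Weights: gradient ascent on the posterior-weighted log likelihood of ALL trials.
    X_all = np.concatenate([X for X, _ in sessions])
    y_all = np.concatenate([y for _, y in sessions])
    g_all = np.concatenate(gammas)
    new_W = W.copy()
    for _ in range(n_grad_steps):
        p = 1.0 / (1.0 + np.exp(-(X_all @ new_W.T)))
        new_W += lr * ((g_all * (y_all[:, None] - p)).T @ X_all) / len(y_all)
    return new_pi0, new_A, new_W, total_ll

# Toy data (assumed): 3 sessions of different lengths, 2 features, binary choices.
rng = np.random.default_rng(0)
sessions = []
for T in (40, 25, 60):
    sessions.append((rng.normal(size=(T, 2)), (rng.random(T) < 0.5).astype(int)))

pi0 = np.full(2, 0.5)
A = np.array([[0.9, 0.1], [0.1, 0.9]])
W = rng.normal(scale=0.1, size=(2, 2))
log_liks = []
for _ in range(5):
    pi0, A, W, ll = em_iteration(sessions, pi0, A, W)
    log_liks.append(ll)
```

The pooling is just addition and concatenation: expected initial-state and transition counts are summed over sessions, while the per-trial posteriors are concatenated to weight every trial in the regression update. Because each session starts its own forward pass from pi0, no transition is counted across a session boundary.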

Questions

  1. Is this summary of issues and comparison accurate?
  2. Is there a way to implement the desired behavior of session-level E-steps and subject-level M-step? (e.g., via padding or using e_step and m_step methods)
  3. Any additional thoughts or suggestions?

Thank you!

@atlaie
atlaie commented Jun 4, 2024

+1! I'd also be very interested in having this implemented in Dynamax

@conormcgrory
+1 Getting this implemented would also help me a lot
