Finding the "normalize" function in MoSeq 2 code #277

Open
Lynnkoehler opened this issue Aug 6, 2024 · 1 comment

Comments

@Lynnkoehler

I have been using the transition matrices from my MoSeq outputs, and I noticed that they are all normalized with 'bigram'.

I see in the code that they can also be normalized by 'rows' or 'columns', but I would like to see what code or calculation is used for this normalization so I can better understand the output.

For reference, I am including a screenshot of the code where the normalize option is set.

[Image: Screenshot 2024-08-06 at 3 59 38 PM]

@davidhbrann
Collaborator

davidhbrann commented Aug 7, 2024

If you look a few lines down from the snippet you posted, you can see that normalize is passed to the get_group_trans_mats function from moseq2_viz, whose code you can find here. The details of the normalization are in the moseq2_viz.model.get_transition_matrix and moseq2_viz.model.normalize_transition_matrix functions. The normalization either divides every entry by the total number of transitions (bigram) or divides each row or column by its sum (to normalize with respect to incoming or outgoing transitions).

```python
from moseq2_viz.model.util import parse_model_results, relabel_by_usage
from moseq2_viz.model.trans_graph import get_trans_graph_groups, get_group_trans_mats

# syllable threshold defined above. Uncomment to set it manually
# max_syllables = 40

# select a transition matrix normalization method
normalize = 'bigram' # options: bigram, columns, rows

# load your model (progress_paths is defined earlier in the notebook)
model_path = progress_paths['model_path']
model_data = parse_model_results(model_path)
# relabel syllables so that lower ids correspond to more frequently used syllables
model_data['labels'] = relabel_by_usage(model_data['labels'], count='usage')[0]

# get modeled session uuids to compute the group-mean transition graph for each group
label_group, uuids = get_trans_graph_groups(model_data)
group = list(set(label_group))

# compute transition matrices and usages for each group
print('Group(s):', ', '.join(group))
trans_mats, usages = get_group_trans_mats(model_data['labels'], label_group, group, max_syllables, normalize=normalize)
```
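
To make the calculation described above concrete, here is a minimal NumPy sketch of the three normalization options. It is not the moseq2_viz source: `count_transitions` and `normalize_counts` are hypothetical helper names. The sketch assumes that rows index the syllable being left and columns the syllable being entered, and that consecutive frames with the same syllable are collapsed before counting. Please check the actual functions in moseq2_viz for the exact behavior.

```python
import numpy as np


def count_transitions(labels, max_syllable):
    """Count syllable-to-syllable transitions (bigrams) in one label sequence.

    Assumed convention: counts[i, j] is the number of times syllable i
    is immediately followed by syllable j.
    """
    labels = np.asarray(labels)
    # keep only frames where the syllable changes, so runs of the same
    # syllable are collapsed and self-transitions are not counted
    changes = np.insert(np.diff(labels) != 0, 0, True)
    seq = labels[changes]
    counts = np.zeros((max_syllable, max_syllable), dtype=float)
    for src, dst in zip(seq[:-1], seq[1:]):
        # skip labels outside 0..max_syllable-1 (e.g. negative fill values)
        if 0 <= src < max_syllable and 0 <= dst < max_syllable:
            counts[src, dst] += 1
    return counts


def normalize_counts(counts, normalize="bigram"):
    """Normalize a transition count matrix by 'bigram', 'rows', or 'columns'."""
    counts = np.asarray(counts, dtype=float)
    if normalize == "bigram":
        total = counts.sum()  # all transitions
    elif normalize == "rows":
        total = counts.sum(axis=1, keepdims=True)  # one sum per row
    elif normalize == "columns":
        total = counts.sum(axis=0, keepdims=True)  # one sum per column
    else:
        return counts.copy()  # unknown option: return the raw counts
    # leave entries at zero where the denominator is zero
    return np.divide(counts, total, out=np.zeros_like(counts), where=total != 0)


# toy label sequence with three syllables
counts = count_transitions([0, 1, 0, 2, 1, 2], max_syllable=3)
bigram = normalize_counts(counts, "bigram")
by_row = normalize_counts(counts, "rows")
by_col = normalize_counts(counts, "columns")
```

With 'bigram' the whole matrix sums to 1, so each entry is the fraction of all transitions made up by that pair. With 'rows' each non-empty row sums to 1, and with 'columns' each non-empty column sums to 1. Under the row/column convention assumed here, a row-normalized entry reads as the probability of going to syllable j given the current syllable i, and a column-normalized entry as the probability of having come from syllable i given the current syllable j.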
