Skip to content
This repository has been archived by the owner on Jun 30, 2024. It is now read-only.

Commit

Permalink
fran suggestions, imports, abs_tg_acceleration and better local_polar…
Browse files Browse the repository at this point in the history
…ization for n=1 neighbour
  • Loading branch information
pacorofe committed Feb 1, 2022
1 parent d744c3e commit 1c08194
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 47 deletions.
2 changes: 2 additions & 0 deletions trajectorytools/export/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .to_pandas import *
from .variables import *
59 changes: 15 additions & 44 deletions trajectorytools/export/to_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from variables import get_variable
from .variables import get_variable


def melt_df(df: pd.DataFrame, value_name: str, var_name: str):
Expand Down Expand Up @@ -38,69 +38,42 @@ def generate_var_df(


def get_focal_nb_ids(key: str, identities: List):
focal_neighbour_ids_conds = list(itertools.product(identities, identities))
focals, neighbours = list(zip(*focal_neighbour_ids_conds))
focals, neighbours = list(
zip(*[(id_, id_nb) for id_ in identities for id_nb in identities])
)
dict_vars = {
key: focals,
f"{key}_nb": neighbours,
}
return dict_vars


def is_focal_nb_variable(tr, var_array):
if len(var_array.shape) == 3:
frames_cond = len(var_array) == tr.number_of_frames
num_indiv_cond = var_array.shape[1] == tr.number_of_individuals
focal_nb_cond = var_array.shape[1] == var_array.shape[2]
return frames_cond and num_indiv_cond and focal_nb_cond
else:
return False


def is_individual_variable(tr, var_array):
if len(var_array.shape) == 2:
frames_cond = len(var_array) == tr.number_of_frames
num_indiv_cond = var_array.shape[1] == tr.number_of_individuals
return frames_cond and num_indiv_cond
else:
return False


def is_group_variable(tr, var_array):
if len(var_array.shape) == 1:
frames_cond = len(var_array) == tr.number_of_frames
return frames_cond
else:
return False


def tr_variable_to_df(tr, var):
identities = list(range(tr.number_of_individuals))

# Get variables
try:
y = get_variable(var, tr)
except ValueError:
print(f"Cannot extract {var} from {tr}")
return None

print(y.shape)
x = np.arange(len(tr))
x_name = "frame"
y_name = var["name"]

if is_focal_nb_variable(tr, y): # (frames, num_indiv, num_indiv)
if y.shape == (
tr.number_of_frames,
tr.number_of_individuals,
tr.number_of_individuals,
):
y = np.reshape(y, (y.shape[0], -1))
identity_dict = get_focal_nb_ids("identity", identities)
identity_dict = get_focal_nb_ids("identity", tr.identity_labels)
var_df = generate_var_df(x, y, x_name, y_name, identity_dict)
elif is_individual_variable(tr, y): # (num_frames, num_indiv)
identity_dict = {"identity": identities}
elif y.shape == (tr.number_of_frames, tr.number_of_individuals):
identity_dict = {"identity": tr.identity_labels}
var_df = generate_var_df(x, y, x_name, y_name, identity_dict)
elif is_group_variable(tr, y): # (num_frames,)
elif y.shape == (tr.number_of_frames,):
var_df = generate_var_df(x, y[:, np.newaxis], x_name, y_name)
else:
print(f"With var {var}")
print(f"With tr {tr}")
raise Exception(
f"Number of dimensions of y array is {y.ndim} not valid"
)
Expand All @@ -109,16 +82,14 @@ def tr_variable_to_df(tr, var):


def tr_variables_to_df(tr, variables: List):
assert len(variables) > 0
assert tr is not None
assert variables

vars_dfs = []
for variable in variables:
vars_df = tr_variable_to_df(tr, variable)
if vars_df is not None:
vars_dfs.append(vars_df)

print([len(df) for df in vars_dfs])
assert all([len(df) == len(vars_dfs[0]) for df in vars_dfs])

all_cols = [c for df in vars_dfs for c in df.columns]
Expand Down Expand Up @@ -164,4 +135,4 @@ def tr_variables_to_df(tr, variables: List):
)
indiv_df = tr_variables_to_df(tr, INDIVIDUAL_VARIALBES)
indiv_nb_df = tr_variables_to_df(tr, INDIVIDUAL_NEIGHBOUR_VARIABLES)
indiv_nb_df = tr_variables_to_df(tr, GROUP_VARIABLES)
group_df = tr_variables_to_df(tr, GROUP_VARIABLES)
9 changes: 6 additions & 3 deletions trajectorytools/export/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@


def get_variable(variable, tr):
print(f"Using function {variable['func']}")
kwargs = variable.get("kwargs", {})
print(f"With kwargs {kwargs}")
return variable["func"](tr, **kwargs)


Expand Down Expand Up @@ -67,6 +65,10 @@ def tg_acceleration(tr):
return tr.tg_acceleration


def abs_tg_acceleration(tr):
return np.abs(tr.tg_acceleration)


def distance_to_center_of_group(tr):
distance_to_group_center = tt.norm(
tr.center_of_mass.s[:, np.newaxis, :] - tr.s
Expand All @@ -76,7 +78,7 @@ def distance_to_center_of_group(tr):

def local_polarization(tr, number_of_neighbours=4):
indices = ttsocial.neighbour_indices(tr.s, number_of_neighbours)
en = ttsocial.restrict(tr.e, indices)[..., 1:, :]
en = ttsocial.restrict(tr.e, indices)
local_polarization = tt.norm(tt.collective.polarization(en))
return local_polarization

Expand Down Expand Up @@ -124,6 +126,7 @@ def focal_fwd_accel(tr):
{"name": "normal_acceleration", "func": normal_acceleration},
{"name": "abs_normal_acceleration", "func": abs_normal_acceleration},
{"name": "tg_acceleration", "func": tg_acceleration},
{"name": "abs_tg_acceleration", "func": abs_tg_acceleration},
{
"name": "distance_to_center_of_group",
"func": distance_to_center_of_group,
Expand Down

0 comments on commit 1c08194

Please sign in to comment.