Skip to content
This repository has been archived by the owner on Jun 30, 2024. It is now read-only.

Commit

Permalink
More explicit and named shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
fjhheras committed Jul 6, 2020
1 parent 2af6935 commit 6a05efe
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions tests/socialcontext_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,25 @@ def test_neighbour_indices_vs_adjacency_matrix(num_neighbours):
t = np.load(cons.test_raw_trajectories_path, allow_pickle=True)
tt.interpolate_nans(t)

num_frames = t.shape[0]
num_individuals = t.shape[1]

nb_indices = neighbour_indices(t, num_neighbours=num_neighbours)
assert nb_indices.shape == tuple(
[t.shape[0], t.shape[1], num_neighbours + 1]
[num_frames, num_individuals, num_neighbours + 1]
)
adj_matrix = adjacency_matrix(t, num_neighbours=num_neighbours)
assert adj_matrix.shape == tuple([t.shape[i] for i in [0, 1, 1]])
assert adj_matrix.shape == (num_frames, num_individuals, num_individuals)

# When there is an index in neighbour_indices output, the
# corresponding elment in the adjacency_matrix must be True
for _ in range(5):
frame = random.randrange(0, t.shape[0])
individual = random.randrange(0, t.shape[1])
frame = random.randrange(0, num_frames)
individual = random.randrange(0, num_individuals)

indices_neighbours = nb_indices[frame, individual, :]
indices_no_neighbours = [
i for i in range(t.shape[1]) if i not in indices_neighbours
i for i in range(num_individuals) if i not in indices_neighbours
]
assert np.all(adj_matrix[frame, individual, indices_neighbours])
assert not np.any(adj_matrix[frame, individual, indices_no_neighbours])
Expand All @@ -56,25 +59,29 @@ def test_neighbour_indices_vs_adjacency_matrix(num_neighbours):
def test_neighbour_indices_vs_adjacency_matrix_in_frame(num_neighbours):
t = np.load(cons.test_raw_trajectories_path, allow_pickle=True)
tt.interpolate_nans(t)
frame = t[random.randrange(0, t.shape[0])]

num_frames = t.shape[0]
num_individuals = t.shape[1]

frame = t[random.randrange(0, num_frames)]

nb_indices = neighbour_indices_in_frame(
frame, num_neighbours=num_neighbours
)
assert nb_indices.shape == tuple([frame.shape[0], num_neighbours + 1])
assert nb_indices.shape == (num_individuals, num_neighbours + 1)
adj_matrix = adjacency_matrix_in_frame(
frame, num_neighbours=num_neighbours
)
assert adj_matrix.shape == tuple([frame.shape[i] for i in [0, 0]])
assert adj_matrix.shape == (num_individuals,) * 2

# When there is an index in neighbour_indices output, the
# corresponding elment in the adjacency_matrix must be True
for _ in range(5):
individual = random.randrange(0, t.shape[1])
individual = random.randrange(0, num_individuals)

indices_neighbours = nb_indices[individual, :]
indices_no_neighbours = [
i for i in range(t.shape[1]) if i not in indices_neighbours
i for i in range(num_individuals) if i not in indices_neighbours
]
assert np.all(adj_matrix[individual, indices_neighbours])
assert not np.any(adj_matrix[individual, indices_no_neighbours])

0 comments on commit 6a05efe

Please sign in to comment.