Skip to content

Commit

Permalink
Refactor unique object tree walking.
Browse files Browse the repository at this point in the history
  • Loading branch information
riga committed Jul 4, 2023
1 parent fb0ee7d commit 4340a6f
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 48 deletions.
128 changes: 90 additions & 38 deletions order/unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


import collections
import warnings

import six

Expand Down Expand Up @@ -750,6 +751,55 @@ def _clear(self, remove_fn, index):
for name in index.names():
remove_fn(name)

# walk helper
def _walk(self, next_fn, algo="bfs", depth_first=False, include_self=False):
# handle cases where the deprecated depth_first argument is used
if depth_first:
if algo != "bfs":
raise Exception(
"using both 'algo' and 'depth_first' arguments is ambiguous; " +
"'depth_first' is deprecated; use 'algo=\"{}\" instead'".format(algo),
)
warnings.warn(
"the 'depth_first' attribute is deprecated; use 'algo=\"dfs\"' instead'",
DeprecationWarning,
)
algo = "dfs"

# check the algo
if algo == "dfs":
algo = "dfs_preorder"
known_algos = ["bfs", "dfs_preorder", "dfs_postorder"]
if algo not in known_algos:
_known_algos = ", ".join(map("'{}'".format, known_algos))
raise ValueError(
"unknown traversel order '{}', should be one of {}".format(algo, _known_algos),
)

lookup = collections.deque([(self, 0)])
visited = set()
while lookup:
obj, depth = lookup[0]
if obj in visited:
lookup.popleft()
continue

objs = list(next_fn(obj))
if algo == "dfs_postorder" and any(_obj not in visited for _obj in objs):
lookup.extendleft((obj, depth + 1) for obj in reversed(objs))
continue

if depth > 0 or include_self:
yield (obj, depth, objs)
visited.add(obj)

lookup.popleft()

if algo == "dfs_preorder":
lookup.extendleft((obj, depth + 1) for obj in reversed(objs))
elif algo == "bfs":
lookup.extend((obj, depth + 1) for obj in objs)

#
# child methods, independent of parents
#
Expand Down Expand Up @@ -838,28 +888,30 @@ def get(self, obj, deep=True, default=_no_default):

# walk children method
@patch("walk_" + plural)
def walk(self, depth_first=False, include_self=False):
def walk(self, algo="bfs", depth_first=False, include_self=False):
"""
Walks through the :py:attr:`{plural}` index and per iteration, yields a child
{singular}, its depth relative to *this* {singular}, and its child {plural} in a
list that can be modified to alter the walking. When *depth_first* is *True*,
iterate depth-first instead of the default breadth-first. When *include_self* is
*True*, also yield this {singular} instance with a depth of 0.
"""
lookup = collections.deque([(self, 0)])
while lookup:
obj, depth = lookup.popleft()
objs = list(getattr(obj, plural).values())
list that can be modified to alter the walking.
if include_self:
yield (obj, depth, objs)
else:
include_self = True
The traversal order is defined by *algo* which allows different values (more
`info <https://en.wikipedia.org/wiki/Tree_traversal>`__):
if depth_first:
lookup.extendleft((obj, depth + 1) for obj in reversed(objs))
else:
lookup.extend((obj, depth + 1) for obj in objs)
- ``"bfs"``: Breadth-first search.
- ``"dfs"``: Alias for ``"dfs_preorder"``.
- ``"dfs_preorder"``: Pre-order depth-first search.
- ``"dfs_postorder"``: Post-order depth-first search.
When *include_self* is *True*, this {singular} instance is yielded as well with a
depth of 0.
"""
return _walk(
self,
(lambda obj: getattr(obj, plural).values()),
algo=algo,
depth_first=depth_first,
include_self=include_self,
)

# get leaves method
@patch("get_leaf_" + plural)
Expand All @@ -869,9 +921,8 @@ def get_leaves(self):
{plural} themselves in a recursive fashion. Possible duplicates due to nested
structures are removed.
"""
walker = getattr(self, "walk_" + plural)()
leaves = []
for obj, _, objs in walker:
for obj, _, objs in getattr(self, "walk_" + plural)():
if not objs and obj not in leaves:
leaves.append(obj)
return leaves
Expand Down Expand Up @@ -1114,28 +1165,30 @@ def get(self, obj, deep=True, default=_no_default):

# walk parents method
@patch("walk_parent_" + plural) # noqa: F811
def walk(self, depth_first=False, include_self=False): # noqa: F811
def walk(self, algo="bfs", depth_first=False, include_self=False): # noqa: F811
"""
Walks through the :py:attr:`parent_{plural}` index and per iteration, yields a
parent {singular}, its depth relative to *this* {singular}, and its parent
{plural} in a list that can be modified to alter the walking. When *depth_first*
is *True*, iterate depth-first instead of the default breadth-first. When
*include_self* is *True*, also yield this {singular} instance with a depth of 0.
"""
lookup = collections.deque([(self, 0)])
while lookup:
obj, depth = lookup.popleft()
objs = list(getattr(obj, "parent_" + plural).values())
{plural} in a list that can be modified to alter the walking.
if include_self:
yield (obj, depth, objs)
else:
include_self = True
The traversal order is defined by *algo* which allows different values (more
`info <https://en.wikipedia.org/wiki/Tree_traversal>`__):
if depth_first:
lookup.extendleft((obj, depth + 1) for obj in reversed(objs))
else:
lookup.extend((obj, depth + 1) for obj in objs)
- ``"bfs"``: Breadth-first search.
- ``"dfs"``: Alias for ``"dfs_preorder"``.
- ``"dfs_preorder"``: Pre-order depth-first search.
- ``"dfs_postorder"``: Post-order depth-first search.
When *include_self* is *True*, this {singular} instance is yielded as well with
a depth of 0.
"""
return _walk(
self,
(lambda obj: getattr(obj, "parent_" + plural).values()),
algo=algo,
depth_first=depth_first,
include_self=include_self,
)

# get roots method
@patch("get_root_" + plural)
Expand All @@ -1145,9 +1198,8 @@ def get_roots(self):
no parent {plural} themselves in a recursive fashion. Possible duplicates due to
nested structures are removed.
"""
walker = getattr(self, "walk_parent_" + plural)()
roots = []
for obj, _, objs in walker:
for obj, _, objs in getattr(self, "walk_parent_" + plural)():
if not objs and obj not in roots:
roots.append(obj)
return roots
Expand Down
24 changes: 14 additions & 10 deletions tests/test_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,19 +427,23 @@ def test_walking(self):
self.assertEqual(n, n1)
self.assertEqual(len(nodes), 0)

def walk(depth, this):
return [
n
def walk(include_self, algo):
return tuple(
n.name
for n, _, _ in n1.walk_nodes(
depth_first=depth,
include_self=this,
algo=algo,
include_self=include_self,
)
]
)

self.assertEqual(walk(False, "bfs"), ("b", "c", "d"))
self.assertEqual(walk(True, "bfs"), ("a", "b", "c", "d"))

self.assertEqual(walk(False, "dfs_preorder"), ("b", "d", "c"))
self.assertEqual(walk(True, "dfs_preorder"), ("a", "b", "d", "c"))

self.assertListEqual(walk(False, False), [n2, n3, n4])
self.assertListEqual(walk(False, True), [n1, n2, n3, n4])
self.assertListEqual(walk(True, False), [n2, n4, n3])
self.assertListEqual(walk(True, True), [n1, n2, n4, n3])
self.assertEqual(walk(False, "dfs_postorder"), ("d", "b", "c"))
self.assertEqual(walk(True, "dfs_postorder"), ("d", "b", "c", "a"))

self.assertListEqual(
[n for n, _, _ in n4.walk_parent_nodes(include_self=True)],
Expand Down

0 comments on commit 4340a6f

Please sign in to comment.