diff --git a/pystac/catalog.py b/pystac/catalog.py index 675f8fffa..da74b0b39 100644 --- a/pystac/catalog.py +++ b/pystac/catalog.py @@ -240,10 +240,12 @@ def add_child( child: Union["Catalog", Collection], title: Optional[str] = None, strategy: Optional[HrefLayoutStrategy] = None, + keep_parent: bool = False, ) -> None: """Adds a link to a child :class:`~pystac.Catalog` or :class:`~pystac.Collection`. This method will set the child's parent to this - object, and its root to this Catalog's root. + object (except if a parent is set and keep_parent is true). + It will always set its root to this Catalog's root. Args: child : The child to add. @@ -261,7 +263,8 @@ def add_child( strategy = BestPracticesLayoutStrategy() child.set_root(self.get_root()) - child.set_parent(self) + if child.get_parent() is None or not keep_parent: + child.set_parent(self) # set self link self_href = self.get_self_href() @@ -287,10 +290,12 @@ def add_item( item: Item, title: Optional[str] = None, strategy: Optional[HrefLayoutStrategy] = None, + keep_parent: bool = False, ) -> None: """Adds a link to an :class:`~pystac.Item`. - This method will set the item's parent to this object, and its root to - this Catalog's root. + This method will set the item's parent to this object (except if a parent + is set and keep_parent is true). + It will always set its root to this Catalog's root. Args: item : The item to add. @@ -308,7 +313,8 @@ def add_item( strategy = BestPracticesLayoutStrategy() item.set_root(self.get_root()) - item.set_parent(self) + if item.get_parent() is None or not keep_parent: + item.set_parent(self) # set self link self_href = self.get_self_href() diff --git a/tests/test_catalog.py b/tests/test_catalog.py index 58e699772..5629cc60d 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -222,6 +222,54 @@ def test_add_child_throws_if_item(self) -> None: with pytest.raises(pystac.STACError): cat.add_child(item) # type:ignore + def test_add_child_override_parent(self) -> None: + parent1 = Catalog(id="parent1", description="test1") + parent2 = Catalog(id="parent2", description="test2") + child = Catalog(id="child", description="test3") + assert child.get_parent() is None + + parent1.add_child(child) + assert child.get_parent() is parent1 + + parent2.add_child(child) + assert child.get_parent() is parent2 + + def test_add_child_keep_parent(self) -> None: + parent1 = Catalog(id="parent1", description="test1") + parent2 = Catalog(id="parent2", description="test2") + child = Catalog(id="child", description="test3") + assert child.get_parent() is None + + parent1.add_child(child, keep_parent=True) + assert child.get_parent() is parent1 + + parent2.add_child(child, keep_parent=True) + assert child.get_parent() is parent1 + + def test_add_item_override_parent(self) -> None: + parent1 = Catalog(id="parent1", description="test1") + parent2 = Catalog(id="parent2", description="test2") + child = Item(id="child", description="test3") + assert child.get_parent() is None + + parent1.add_item(child) + assert child.get_parent() is parent1 + + parent2.add_item(child) + assert child.get_parent() is parent2 + + def test_add_item_keep_parent(self) -> None: + parent1 = Catalog(id="parent1", description="test1") + parent2 = Catalog(id="parent2", description="test2") + child = Item(id="child", description="test3") + assert child.get_parent() is None + + parent1.add_item(child, keep_parent=True) + assert child.get_parent() is parent1 + + parent2.add_item(child, keep_parent=True) + assert child.get_parent() is parent1 + def test_add_item_throws_if_child(self) -> None: cat = TestCases.case_1() child = next(iter(cat.get_children()))