diff --git a/tests/common_testing.py b/tests/common_testing.py index 729d56e72..9935f07a1 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -16,6 +16,14 @@ def assertSeparate(self, tensor1, tensor2) -> None: tensor1.storage().data_ptr(), tensor2.storage().data_ptr() ) + def assertNotSeparate(self, tensor1, tensor2) -> None: + """ + Verify that tensor1 and tensor2 have their data in the same locations. + """ + self.assertEqual( + tensor1.storage().data_ptr(), tensor2.storage().data_ptr() + ) + def assertAllSeparate(self, tensor_list) -> None: """ Verify that all tensors in tensor_list have their data in