From be1a659160c155c023a9fca20be503592d7aa0ef Mon Sep 17 00:00:00 2001 From: Samuel Burbulla Date: Wed, 31 May 2023 08:36:00 +0200 Subject: [PATCH] Preserve gradients in pytorch as_tensor. --- deepxde/backend/pytorch/tensor.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/deepxde/backend/pytorch/tensor.py b/deepxde/backend/pytorch/tensor.py index 0e59c7610..23b7666be 100644 --- a/deepxde/backend/pytorch/tensor.py +++ b/deepxde/backend/pytorch/tensor.py @@ -71,6 +71,8 @@ def Variable(initial_value, dtype=None): def as_tensor(data, dtype=None): + if isinstance(data, (list, tuple)) and all(isinstance(d, torch.Tensor) for d in data): + data = torch.stack(data) if isinstance(data, torch.Tensor): if dtype is None or data.dtype == dtype: return data