Skip to content

Commit

Permalink
Preserve gradients in pytorch as_tensor.
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelburbulla committed May 31, 2023
1 parent 0877e11 commit be1a659
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions deepxde/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def Variable(initial_value, dtype=None):


def as_tensor(data, dtype=None):
if isinstance(data, (list, tuple)) and all(isinstance(d, torch.Tensor) for d in data):
data = torch.stack(data)
if isinstance(data, torch.Tensor):
if dtype is None or data.dtype == dtype:
return data
Expand Down

0 comments on commit be1a659

Please sign in to comment.