Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional Satlas pretrained models #1884

Merged
merged 9 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/landsat_pretrained_weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ ResNet50_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/
ResNet50_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://arxiv.org/abs/2306.09424>`__,"CC0-1.0",63.65,46.68,60.01,43.17
ViTSmall16_Weights.LANDSAT_OLI_SR_MOCO,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://arxiv.org/abs/2306.09424>`__,"CC0-1.0",66.81,50.16,64.17,47.24
ViTSmall16_Weights.LANDSAT_OLI_SR_SIMCLR,8--9,7,`link <https://github.com/microsoft/torchgeo>`__,`link <https://arxiv.org/abs/2306.09424>`__,"CC0-1.0",65.04,48.20,62.61,45.46
Swin_V2_B_Weights.LANDSAT_MS_SI_SATLAS,11,'link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 1 addition & 1 deletion docs/api/naip_pretrained_weights.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Weight,Channels,Source,Citation,License
Swin_V2_B_Weights.NAIP_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"Apache-2.0"
Swin_V2_B_Weights.NAIP_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
1 change: 1 addition & 0 deletions docs/api/sentinel1_pretrained_weights.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
Weight,Channels,Source,Citation,License
ResNet50_Weights.SENTINEL1_ALL_MOCO, 2,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0"
Swin_V2_B_Weights.SENTINEL1_SI_SATLAS,2,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY"
3 changes: 2 additions & 1 deletion docs/api/sentinel2_pretrained_weights.csv
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ ResNet50_Weights.SENTINEL2_RGB_MOCO, 3,`link <https://github.com/zhu-xlab/SSL4EO
ResNet50_Weights.SENTINEL2_RGB_SECO, 3,`link <https://github.com/ServiceNow/seasonal-contrast>`__,`link <https://arxiv.org/abs/2103.16607>`__,"Apache-2.0",87.81,,,
ViTSmall16_Weights.SENTINEL2_ALL_DINO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",90.5,99.0,62.2,
ViTSmall16_Weights.SENTINEL2_ALL_MOCO,13,`link <https://github.com/zhu-xlab/SSL4EO-S12>`__,`link <https://arxiv.org/abs/2211.07044>`__,"CC-BY-4.0",89.9,98.6,61.6,
Swin_V2_B_Weights.SENTINEL2_RGB_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"Apache-2.0",,,,
Swin_V2_B_Weights.SENTINEL2_RGB_SI_SATLAS,3,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
Swin_V2_B_Weights.SENTINEL2_MS_SI_SATLAS,9,`link <https://github.com/allenai/satlas>`__,`link <https://arxiv.org/abs/2211.15660>`__,"ODC-BY",,,,
83 changes: 75 additions & 8 deletions torchgeo/models/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,35 @@
import kornia.augmentation as K
import torch
import torchvision
from kornia.contrib import Lambda
from torchvision.models import SwinTransformer
from torchvision.models._api import Weights, WeightsEnum

from ..transforms import AugmentationSequential

__all__ = ["Swin_V2_B_Weights"]


# https://github.com/allenai/satlas/blob/bcaa968da5395f675d067613e02613a344e81415/satlas/cmd/model/train.py#L42 # noqa: E501
# All Satlas imagery is uint8 and normalized to the range (0, 1) by dividing by 255
# Satlas Sentinel-1 and RGB Sentinel-2 and NAIP imagery is uint8 and is normalized to (0, 1) by dividing by 255. # noqa: E501
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
_satlas_transforms = AugmentationSequential(
K.CenterCrop(256),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)), data_keys=["image"]
)

# Satlas multispectral Sentinel-2 imagery divides first 3 bands by 255 and the following 6 bands by 8160, both clipped to (0, 1). # noqa: E501
calebrob6 marked this conversation as resolved.
Show resolved Hide resolved
_std = torch.tensor(
[255.0, 255.0, 255.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0, 8160.0]
) # noqa: E501
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
_mean = torch.zeros_like(_std)
_sentinel2_ms_satlas_transforms = AugmentationSequential(
K.Normalize(mean=_mean, std=_std),
Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)),
data_keys=["image"],
)

# Satlas Landsat imagery is 16-bit, normalized by clipping some pixel N with (N-4000)/16320 to (0, 1). # noqa: E501
_landsat_satlas_transforms = AugmentationSequential(
K.Normalize(mean=torch.tensor(4000), std=torch.tensor(16320)),
Lambda(lambda x: torch.clamp(x, min=0.0, max=1.0)),
data_keys=["image"],
)

Expand All @@ -39,8 +55,8 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
.. versionadded:: 0.6
"""

NAIP_RGB_SATLAS = Weights(
url="https://huggingface.co/torchgeo/swin_v2_b_naip_rgb_satlas/resolve/main/swin_v2_b_naip_rgb_satlas-685f45bd.pth", # noqa: E501
NAIP_RGB_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/aerial_swinb_si.pth", # noqa: E501
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
transforms=_satlas_transforms,
meta={
"dataset": "Satlas",
Expand All @@ -51,8 +67,8 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
},
)

SENTINEL2_RGB_SATLAS = Weights(
url="https://huggingface.co/torchgeo/swin_v2_b_sentinel2_rgb_satlas/resolve/main/swin_v2_b_sentinel2_rgb_satlas-51471041.pth", # noqa: E501
SENTINEL2_RGB_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_rgb.pth", # noqa: E501
transforms=_satlas_transforms,
meta={
"dataset": "Satlas",
Expand All @@ -63,6 +79,57 @@ class Swin_V2_B_Weights(WeightsEnum): # type: ignore[misc]
},
)

SENTINEL2_MS_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel2_swinb_si_ms.pth", # noqa: E501
transforms=_sentinel2_ms_satlas_transforms,
meta={
"dataset": "Satlas",
"in_chans": 9,
"model": "swin_v2_b",
"publication": "https://arxiv.org/abs/2211.15660",
"repo": "https://github.com/allenai/satlas",
"bands": ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B11", "B12"],
},
)

SENTINEL1_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/sentinel1_swinb_si.pth", # noqa: E501
transforms=_satlas_transforms,
meta={
"dataset": "Satlas",
"in_chans": 2,
piperwolters marked this conversation as resolved.
Show resolved Hide resolved
"model": "swin_v2_b",
"publication": "https://arxiv.org/abs/2211.15660",
"repo": "https://github.com/allenai/satlas",
"bands": ["VH", "VV"],
},
)

LANDSAT_SI_SATLAS = Weights(
url="https://huggingface.co/allenai/satlas-pretrain/resolve/main/landsat_swinb_si.pth", # noqa: E501
transforms=_landsat_satlas_transforms,
meta={
"dataset": "Satlas",
"in_chans": 11,
"model": "swin_v2_b",
"publication": "https://arxiv.org/abs/2211.15660",
"repo": "https://github.com/allenai/satlas",
"bands": [
"B01",
"B02",
"B03",
"B04",
"B05",
"B06",
"B07",
"B08",
"B09",
"B10",
"B11",
], # noqa: E501
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved
},
)


def swin_v2_b(
weights: Optional[Swin_V2_B_Weights] = None, *args: Any, **kwargs: Any
Expand Down
6 changes: 3 additions & 3 deletions torchgeo/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import kornia.augmentation as K
import torch
from einops import rearrange
from kornia.contrib import extract_tensor_patches
from kornia.contrib import Lambda, extract_tensor_patches
from kornia.geometry import crop_by_indices
from kornia.geometry.boxes import Boxes
from torch import Tensor
Expand All @@ -25,7 +25,7 @@ class AugmentationSequential(Module):

def __init__(
self,
*args: Union[K.base._AugmentationBase, K.ImageSequential],
*args: Union[K.base._AugmentationBase, K.ImageSequential, Lambda],
data_keys: list[str],
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
else:
keys.append(key)

self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs)
self.augs = K.AugmentationSequential(*args, data_keys=keys, **kwargs) # type: ignore[arg-type] # noqa: E501
adamjstewart marked this conversation as resolved.
Show resolved Hide resolved

def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
"""Perform augmentations and update data dict.
Expand Down
Loading