Skip to content

Commit

Permalink
take care of normalization of images within discrete vae class, to account for discrepancy between preprocessing of openai vae and offered vae
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 9, 2021
1 parent 79e005d commit 357c16f
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
18 changes: 17 additions & 1 deletion dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __init__(
smooth_l1_loss = False,
temperature = 0.9,
straight_through = False,
kl_div_loss_weight = 0.
kl_div_loss_weight = 0.,
normalization = ((0.5,) * 3, (0.5,) * 3)
):
super().__init__()
assert log2(image_size).is_integer(), 'image size must be a power of 2'
Expand Down Expand Up @@ -126,6 +127,19 @@ def __init__(
self.loss_fn = F.smooth_l1_loss if smooth_l1_loss else F.mse_loss
self.kl_div_loss_weight = kl_div_loss_weight

# take care of normalization within class
self.normalization = normalization

def norm(self, images):
    """Apply the configured per-channel normalization to ``images``.

    ``self.normalization`` is unpacked as a ``(means, stds)`` pair of
    per-channel sequences. When it is not set (per the ``exists`` helper),
    the input is handed back untouched. Otherwise a normalized copy is
    returned and the caller's tensor is left unmodified.

    NOTE(review): the statistics are reshaped to broadcast along axis 1 of
    a 4-D tensor, so ``images`` is presumably (batch, channel, height,
    width) -- confirm against callers.
    """
    stats = self.normalization
    if not exists(stats):
        return images

    def to_channel_stat(values):
        # Match the dtype/device of ``images``, then add unit axes so the
        # 1-D per-channel vector broadcasts: (c,) -> (1, c, 1, 1).
        stat = torch.as_tensor(values).to(images)
        return rearrange(stat, 'c -> () c () ()')

    means, stds = (to_channel_stat(values) for values in stats)

    # Work on a copy so the in-place arithmetic never touches the input.
    normalized = images.clone()
    normalized.sub_(means)
    normalized.div_(stds)
    return normalized

@torch.no_grad()
@eval_decorator
def get_codebook_indices(self, images):
Expand Down Expand Up @@ -156,6 +170,8 @@ def forward(
device, num_tokens, image_size, kl_div_loss_weight = img.device, self.num_tokens, self.image_size, self.kl_div_loss_weight
assert img.shape[-1] == image_size and img.shape[-2] == image_size, f'input must have the correct image size {image_size}'

img = self.norm(img)

logits = self.encoder(img)

if return_logits:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '0.2.4',
version = '0.2.5',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down
4 changes: 2 additions & 2 deletions train_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@
ds = ImageFolder(
IMAGE_PATH,
T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(IMAGE_SIZE),
T.CenterCrop(IMAGE_SIZE),
T.ToTensor(),
T.Normalize((0.5,) * 3, (0.5,) * 3)
T.ToTensor()
])
)

Expand Down

0 comments on commit 357c16f

Please sign in to comment.