Commit

go the extra mile and support quads #54

lucidrains committed Jan 23, 2024
1 parent 873a0fc commit 641659a

Showing 2 changed files with 44 additions and 34 deletions.
76 changes: 43 additions & 33 deletions meshgpt_pytorch/meshgpt_pytorch.py
@@ -136,13 +136,13 @@ def derive_angle(x, y, eps = 1e-5):

@torch.no_grad()
def get_derived_face_features(
- face_coords: TensorType['b', 'nf', 3, 3, float] # 3 vertices with 3 coordinates
+ face_coords: TensorType['b', 'nf', 'nvf', 3, float] # nvf vertices (3 or 4), each with 3 coordinates
):
shifted_face_coords = torch.cat((face_coords[:, :, -1:], face_coords[:, :, :-1]), dim = 2)

angles = derive_angle(face_coords, shifted_face_coords)

- edge1, edge2, _ = (face_coords - shifted_face_coords).unbind(dim = 2)
+ edge1, edge2, *_ = (face_coords - shifted_face_coords).unbind(dim = 2)

normals = l2norm(torch.cross(edge1, edge2, dim = -1))
area = normals.norm(dim = -1, keepdim = True) * 0.5
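To see what the looser unpacking buys, here is a minimal standalone sketch of the edge and normal derivation (plain PyTorch, all sizes hypothetical): the cyclic shift pairs each vertex with its predecessor, so unbind yields one edge per vertex whether nvf is 3 or 4, and the first two edges already determine the face normal.

```python
import torch
import torch.nn.functional as F

# hypothetical sizes: batch 2, 5 faces, nvf = 4 (quads), xyz coordinates
face_coords = torch.randn(2, 5, 4, 3)

# pair each vertex with its predecessor via a cyclic shift along the vertex dim
shifted = torch.cat((face_coords[:, :, -1:], face_coords[:, :, :-1]), dim = 2)

# one edge vector per vertex; *_ absorbs the extra edge when nvf > 3
edge1, edge2, *_ = (face_coords - shifted).unbind(dim = 2)

# two edges span the face plane, so their cross product gives the normal
normals = F.normalize(torch.cross(edge1, edge2, dim = -1), dim = -1)
print(normals.shape)  # torch.Size([2, 5, 3])
```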
@@ -429,10 +429,14 @@ def __init__(
attn_dropout = 0.,
ff_dropout = 0.,
resnet_dropout = 0,
- checkpoint_quantizer = False
+ checkpoint_quantizer = False,
+ quads = False
):
super().__init__()

+ self.num_vertices_per_face = 3 if not quads else 4
+ total_coordinates_per_face = self.num_vertices_per_face * 3

# main face coordinate embedding

self.num_discrete_coors = num_discrete_coors
@@ -464,7 +468,7 @@ def __init__(

# initial dimension

- init_dim = dim_coor_embed * 9 + dim_angle_embed * 3 + dim_normal_embed * 3 + dim_area_embed
+ init_dim = dim_coor_embed * (3 * self.num_vertices_per_face) + dim_angle_embed * self.num_vertices_per_face + dim_normal_embed * 3 + dim_area_embed
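As a sanity check on the widened input dimension, a quick back-of-the-envelope with illustrative embedding widths (hypothetical, not the library defaults):

```python
# illustrative embedding widths, not the library defaults
dim_coor_embed, dim_angle_embed, dim_normal_embed, dim_area_embed = 64, 16, 16, 16

for nvf in (3, 4):  # triangles vs quads
    init_dim = (
        dim_coor_embed * (3 * nvf)  # nvf vertices, 3 coordinates each
        + dim_angle_embed * nvf     # one interior angle per vertex
        + dim_normal_embed * 3      # one xyz normal per face
        + dim_area_embed            # one scalar area per face
    )
    print(nvf, init_dim)  # 3 -> 688, 4 -> 896
```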

# project into model dimension

@@ -510,7 +514,7 @@ def __init__(
self.codebook_size = codebook_size
self.num_quantizers = num_quantizers

- self.project_dim_codebook = nn.Linear(curr_dim, dim_codebook * 3)
+ self.project_dim_codebook = nn.Linear(curr_dim, dim_codebook * self.num_vertices_per_face)

if use_residual_lfq:
self.quantizer = ResidualLFQ(
@@ -556,7 +560,7 @@ def __init__(
assert is_odd(init_decoder_conv_kernel)

self.init_decoder_conv = nn.Sequential(
- nn.Conv1d(dim_codebook * 3, init_decoder_dim, kernel_size = init_decoder_conv_kernel, padding = init_decoder_conv_kernel // 2),
+ nn.Conv1d(dim_codebook * self.num_vertices_per_face, init_decoder_dim, kernel_size = init_decoder_conv_kernel, padding = init_decoder_conv_kernel // 2),
nn.SiLU(),
Rearrange('b c n -> b n c'),
nn.LayerNorm(init_decoder_dim),
@@ -572,8 +576,8 @@ def __init__(
curr_dim = dim_layer

self.to_coor_logits = nn.Sequential(
- nn.Linear(curr_dim, num_discrete_coors * 9),
- Rearrange('... (v c) -> ... v c', v = 9)
+ nn.Linear(curr_dim, num_discrete_coors * total_coordinates_per_face),
+ Rearrange('... (v c) -> ... v c', v = total_coordinates_per_face)
)

# loss related
@@ -586,7 +590,7 @@ def encode(
self,
*,
vertices: TensorType['b', 'nv', 3, float],
- faces: TensorType['b', 'nf', 3, int],
+ faces: TensorType['b', 'nf', 'nvf', int],
face_edges: TensorType['b', 'e', 2, int],
face_mask: TensorType['b', 'nf', bool],
face_edges_mask: TensorType['b', 'e', bool],
Expand All @@ -597,12 +601,15 @@ def encode(
b - batch
nf - number of faces
nv - number of vertices (3)
+ nvf - number of vertices per face (3 or 4) - triangles vs quads
c - coordinates (3)
d - embed dim
"""

batch, num_vertices, num_coors, device = *vertices.shape, vertices.device
- _, num_faces, _ = faces.shape
+ _, num_faces, num_vertices_per_face = faces.shape

+ assert self.num_vertices_per_face == num_vertices_per_face

face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0)

@@ -626,7 +633,7 @@ def encode(
# discretize vertices for face coordinate embedding

discrete_face_coords = self.discretize_face_coords(face_coords)
- discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)') # 9 coordinates per face
+ discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)') # 9 or 12 coordinates per face

face_coor_embed = self.coor_embed(discrete_face_coords)
face_coor_embed = rearrange(face_coor_embed, 'b nf c d -> b nf (c d)')
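The two rearranges above first flatten the per-vertex coordinates for embedding, then fuse all coordinate embeddings of a face into one vector; a short shape walk-through under hypothetical sizes:

```python
import torch
from torch import nn
from einops import rearrange

# hypothetical sizes: batch 1, 2 quad faces, 128 discrete bins, embed dim 64
coor_embed = nn.Embedding(128, 64)
discrete_face_coords = torch.randint(0, 128, (1, 2, 4, 3))  # (b, nf, nvf, 3)

flat = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)')  # 12 coords per quad face
embed = coor_embed(flat)                                            # (1, 2, 12, 64)
embed = rearrange(embed, 'b nf c d -> b nf (c d)')                  # (1, 2, 768)
print(embed.shape)
```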
@@ -684,7 +691,7 @@ def encode(
def quantize(
self,
*,
- faces: TensorType['b', 'nf', 3, int],
+ faces: TensorType['b', 'nf', 'nvf', int],
face_mask: TensorType['b', 'n', bool],
face_embed: TensorType['b', 'nf', 'd', float],
pad_id = None,
@@ -697,7 +704,7 @@ def quantize(
num_vertices = int(max_vertex_index.item() + 1)

face_embed = self.project_dim_codebook(face_embed)
- face_embed = rearrange(face_embed, 'b nf (nv d) -> b nf nv d', nv = 3)
+ face_embed = rearrange(face_embed, 'b nf (nvf d) -> b nf nvf d', nvf = self.num_vertices_per_face)

vertex_dim = face_embed.shape[-1]
vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device)
@@ -711,7 +718,7 @@ def quantize(

# prepare for scatter mean

- faces_with_dim = repeat(faces, 'b nf nv -> b (nf nv) d', d = vertex_dim)
+ faces_with_dim = repeat(faces, 'b nf nvf -> b (nf nvf) d', d = vertex_dim)

face_embed = rearrange(face_embed, 'b ... d -> b (...) d')
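These lines stage a scatter-mean: each face writes its per-vertex embeddings to the vertex indices it references, and vertices shared by several faces get averaged. An equivalent plain-torch sketch on a toy mesh (two quads sharing an edge; all sizes hypothetical):

```python
import torch
from einops import rearrange, repeat

# toy mesh: two quads sharing the edge (1, 2); 6 vertices, embed dim 8
faces = torch.tensor([[[0, 1, 2, 3], [1, 4, 5, 2]]])  # (b, nf, nvf)
face_embed = torch.randn(1, 2, 4, 8)                  # (b, nf, nvf, d)

num_vertices, dim = 6, 8
idx = rearrange(faces, 'b nf nvf -> b (nf nvf)')
src = rearrange(face_embed, 'b nf nvf d -> b (nf nvf) d')

# scatter-mean = scatter-add of the embeddings, divided by per-vertex counts
summed = torch.zeros(1, num_vertices, dim).scatter_add_(1, repeat(idx, 'b n -> b n d', d = dim), src)
counts = torch.zeros(1, num_vertices).scatter_add_(1, idx, torch.ones_like(idx, dtype = torch.float))
vertex_embed = summed / counts.clamp(min = 1).unsqueeze(-1)  # vertices 1 and 2 average two faces
print(vertex_embed.shape)  # torch.Size([1, 6, 8])
```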

@@ -749,16 +756,16 @@ def quantize_wrapper_fn(inp):
# gather quantized vertexes back to faces for decoding
# now the faces have quantized vertices

- face_embed_output = get_at('b [n] d, b nf nv -> b nf (nv d)', quantized, faces)
+ face_embed_output = get_at('b [n] d, b nf nvf -> b nf (nvf d)', quantized, faces)

# vertex codes also need to be gathered to be organized by face sequence
# for autoregressive learning

- codes_output = get_at('b [n] q, b nf nv -> b (nf nv) q', codes, faces)
+ codes_output = get_at('b [n] q, b nf nvf -> b (nf nvf) q', codes, faces)

# make sure codes being outputted have this padding

- face_mask = repeat(face_mask, 'b nf -> b (nf nv) 1', nv = 3)
+ face_mask = repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face)
codes_output = codes_output.masked_fill(~face_mask, self.pad_id)

# output quantized, codes, as well as commitment loss
@@ -783,7 +790,6 @@ def decode(
x = ff(x) + x

x = rearrange(x, 'b n d -> b d n')
-
x = x.masked_fill(~conv_face_mask, 0.)
x = self.init_decoder_conv(x)

@@ -803,7 +809,7 @@ def decode_from_codes_to_faces(
codes = rearrange(codes, 'b ... -> b (...)')

if not exists(face_mask):
- face_mask = reduce(codes != self.pad_id, 'b (nf nv q) -> b nf', 'all', nv = 3, q = self.num_quantizers)
+ face_mask = reduce(codes != self.pad_id, 'b (nf nvf q) -> b nf', 'all', nvf = self.num_vertices_per_face, q = self.num_quantizers)

# handle different code shapes

@@ -812,7 +818,7 @@ def decode_from_codes_to_faces(
# decode

quantized = self.quantizer.get_output_from_indices(codes)
- quantized = rearrange(quantized, 'b (nf nv) d -> b nf (nv d)', nv = 3)
+ quantized = rearrange(quantized, 'b (nf nvf) d -> b nf (nvf d)', nvf = self.num_vertices_per_face)

decoded = self.decode(
quantized,
@@ -824,7 +830,7 @@ def decode_from_codes_to_faces(

pred_face_coords = pred_face_coords.argmax(dim = -1)

- pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = 3)
+ pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = self.num_vertices_per_face)
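Putting the decoding shapes together, a toy walk-through of how a flat code sequence folds back into faces, assuming quads (nvf = 4) and 2 residual VQ levels:

```python
import torch
from einops import rearrange, reduce

b, nf, nvf, q, d = 1, 3, 4, 2, 16  # hypothetical: 3 quad faces, 2 RVQ levels
pad_id = -1

codes = torch.randint(0, 512, (b, nf * nvf * q))
codes[:, -nvf * q:] = pad_id  # pretend the last face is padding

# a face is real only if all nvf * q of its codes are unpadded
face_mask = reduce(codes != pad_id, 'b (nf nvf q) -> b nf', 'all', nvf = nvf, q = q)
print(face_mask)  # tensor([[ True,  True, False]])

# after the quantizer lookup, per-vertex embeddings regroup into whole faces
quantized = torch.randn(b, nf * nvf, d)
per_face = rearrange(quantized, 'b (nf nvf) d -> b nf (nvf d)', nvf = nvf)
print(per_face.shape)  # torch.Size([1, 3, 64])
```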

# back to continuous space

@@ -877,7 +883,7 @@ def forward(
self,
*,
vertices: TensorType['b', 'nv', 3, float],
- faces: TensorType['b', 'nf', 3, int],
+ faces: TensorType['b', 'nf', 'nvf', int],
face_edges: Optional[TensorType['b', 'e', 2, int]] = None,
return_codes = False,
return_loss_breakdown = False,
@@ -912,7 +918,7 @@ def forward(
if return_codes:
assert not return_recon_faces, 'cannot return reconstructed faces when just returning raw codes'

- codes = codes.masked_fill(~repeat(face_mask, 'b nf -> b (nf 3) 1'), self.pad_id)
+ codes = codes.masked_fill(~repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face), self.pad_id)
return codes

decode = self.decode(
Expand All @@ -932,7 +938,7 @@ def forward(
continuous_range = self.coor_continuous_range,
)

- recon_faces = rearrange(recon_faces, 'b nf (nv c) -> b nf nv c', nv = 3)
+ recon_faces = rearrange(recon_faces, 'b nf (nvf c) -> b nf nvf c', nvf = self.num_vertices_per_face)
face_mask = rearrange(face_mask, 'b nf -> b nf 1 1')
recon_faces = recon_faces.masked_fill(~face_mask, float('nan'))
face_mask = rearrange(face_mask, 'b nf 1 1 -> b nf')
@@ -960,7 +966,7 @@ def forward(

recon_losses = (-target_one_hot * pred_log_prob).sum(dim = 1)

- face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = 9)
+ face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = self.num_vertices_per_face * 3)
recon_loss = recon_losses[face_mask].mean()
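Since each face now contributes nvf * 3 discretized coordinates to the loss rather than a fixed 9, the per-face mask expands by the same factor; a one-line check under a hypothetical mask:

```python
import torch
from einops import repeat

nvf = 4                                    # quads
face_mask = torch.tensor([[True, False]])  # (b, nf): second face is padding

# each real face contributes nvf * 3 coordinate logits to the loss
coord_mask = repeat(face_mask, 'b nf -> b (nf r)', r = nvf * 3)
print(coord_mask.shape, coord_mask.sum().item())  # torch.Size([1, 24]) 12
```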

# calculate total loss
@@ -1012,9 +1018,11 @@ def __init__(
pad_id = -1,
condition_on_text = False,
text_condition_model_types = ('t5',),
- text_condition_cond_drop_prob = 0.25
+ text_condition_cond_drop_prob = 0.25,
+ quads = False
):
super().__init__()
+ self.num_vertices_per_face = 3 if not quads else 4

dim, dim_fine = (dim, dim) if isinstance(dim, int) else dim

@@ -1029,12 +1037,12 @@ def __init__(

# they use axial positional embeddings

- assert divisible_by(max_seq_len, 3 * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by (3 x {self.num_quantizers}) = {3 * self.num_quantizers}' # 3 vertices per face, with D codes per vertex
+ assert divisible_by(max_seq_len, self.num_vertices_per_face * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by ({self.num_vertices_per_face} x {self.num_quantizers}) = {self.num_vertices_per_face * self.num_quantizers}' # nvf vertices per face, with D codes per vertex
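A quick check of the divisibility rule under a hypothetical 2-level quantizer: a triangle face spans 6 tokens and a quad face 8, so the set of valid max_seq_len values differs between the two settings.

```python
num_quantizers = 2  # hypothetical RVQ depth

for nvf in (3, 4):
    tokens_per_face = nvf * num_quantizers
    for max_seq_len in (8190, 8192):
        print(nvf, max_seq_len, max_seq_len % tokens_per_face == 0)
# triangles (6 tokens/face): 8190 works, 8192 does not; quads (8 tokens/face): the reverse
```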

self.token_embed = nn.Embedding(self.codebook_size + 1, dim)

self.quantize_level_embed = nn.Parameter(torch.randn(self.num_quantizers, dim))
- self.vertex_embed = nn.Parameter(torch.randn(3, dim))
+ self.vertex_embed = nn.Parameter(torch.randn(self.num_vertices_per_face, dim))

self.abs_pos_emb = nn.Embedding(max_seq_len, dim)

@@ -1057,7 +1065,7 @@ def __init__(
# for summarizing the vertices of each face

self.to_face_tokens = nn.Sequential(
- nn.Linear(self.num_quantizers * 3 * dim, dim),
+ nn.Linear(self.num_quantizers * self.num_vertices_per_face * dim, dim),
nn.LayerNorm(dim)
)

@@ -1173,9 +1181,11 @@ def generate(
cache = (None, None)

for i in tqdm(range(curr_length, max_seq_len)):

+ # example below for triangles, extrapolate for quads
+ # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F)

- can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3) # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes
+ can_eos = i != 0 and divisible_by(i, self.num_quantizers * self.num_vertices_per_face) # only allow for eos to be decoded at the end of each face, defined as nvf vertices with D residual VQ codes

output = self.forward_on_codes(
codes,
@@ -1249,7 +1259,7 @@ def forward(
self,
*,
vertices: TensorType['b', 'nv', 3, int],
- faces: TensorType['b', 'nf', 3, int],
+ faces: TensorType['b', 'nf', 'nvf', int],
face_edges: Optional[TensorType['b', 'e', 2, int]] = None,
codes: Optional[Tensor] = None,
cache: Optional[LayerIntermediates] = None,
@@ -1352,13 +1362,13 @@ def forward_on_codes(

# embedding for each vertex

- vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (3 * self.num_quantizers)), q = self.num_quantizers)
+ vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (self.num_vertices_per_face * self.num_quantizers)), q = self.num_quantizers)
codes = codes + vertex_embed[:code_len]

# create a token per face, by summarizing the 3 vertices
# this is similar in design to the RQ transformer from Lee et al. https://arxiv.org/abs/2203.01941

- num_tokens_per_face = self.num_quantizers * 3
+ num_tokens_per_face = self.num_quantizers * self.num_vertices_per_face

curr_vertex_pos = code_len % num_tokens_per_face # the current intra-face vertex-code position id, needed for caching at the fine decoder stage

2 changes: 1 addition & 1 deletion meshgpt_pytorch/version.py
@@ -1 +1 @@
- __version__ = '0.7.2'
+ __version__ = '1.0.0'
