diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py index b4e78d07..df5bb76e 100644 --- a/meshgpt_pytorch/meshgpt_pytorch.py +++ b/meshgpt_pytorch/meshgpt_pytorch.py @@ -136,13 +136,13 @@ def derive_angle(x, y, eps = 1e-5): @torch.no_grad() def get_derived_face_features( - face_coords: TensorType['b', 'nf', 3, 3, float] # 3 vertices with 3 coordinates + face_coords: TensorType['b', 'nf', 'nvf', 3, float] # 3 vertices with 3 coordinates ): shifted_face_coords = torch.cat((face_coords[:, :, -1:], face_coords[:, :, :-1]), dim = 2) angles = derive_angle(face_coords, shifted_face_coords) - edge1, edge2, _ = (face_coords - shifted_face_coords).unbind(dim = 2) + edge1, edge2, *_ = (face_coords - shifted_face_coords).unbind(dim = 2) normals = l2norm(torch.cross(edge1, edge2, dim = -1)) area = normals.norm(dim = -1, keepdim = True) * 0.5 @@ -429,10 +429,14 @@ def __init__( attn_dropout = 0., ff_dropout = 0., resnet_dropout = 0, - checkpoint_quantizer = False + checkpoint_quantizer = False, + quads = False ): super().__init__() + self.num_vertices_per_face = 3 if not quads else 4 + total_coordinates_per_face = self.num_vertices_per_face * 3 + # main face coordinate embedding self.num_discrete_coors = num_discrete_coors @@ -464,7 +468,7 @@ def __init__( # initial dimension - init_dim = dim_coor_embed * 9 + dim_angle_embed * 3 + dim_normal_embed * 3 + dim_area_embed + init_dim = dim_coor_embed * (3 * self.num_vertices_per_face) + dim_angle_embed * self.num_vertices_per_face + dim_normal_embed * 3 + dim_area_embed # project into model dimension @@ -510,7 +514,7 @@ def __init__( self.codebook_size = codebook_size self.num_quantizers = num_quantizers - self.project_dim_codebook = nn.Linear(curr_dim, dim_codebook * 3) + self.project_dim_codebook = nn.Linear(curr_dim, dim_codebook * self.num_vertices_per_face) if use_residual_lfq: self.quantizer = ResidualLFQ( @@ -556,7 +560,7 @@ def __init__( assert is_odd(init_decoder_conv_kernel) 
self.init_decoder_conv = nn.Sequential( - nn.Conv1d(dim_codebook * 3, init_decoder_dim, kernel_size = init_decoder_conv_kernel, padding = init_decoder_conv_kernel // 2), + nn.Conv1d(dim_codebook * self.num_vertices_per_face, init_decoder_dim, kernel_size = init_decoder_conv_kernel, padding = init_decoder_conv_kernel // 2), nn.SiLU(), Rearrange('b c n -> b n c'), nn.LayerNorm(init_decoder_dim), @@ -572,8 +576,8 @@ def __init__( curr_dim = dim_layer self.to_coor_logits = nn.Sequential( - nn.Linear(curr_dim, num_discrete_coors * 9), - Rearrange('... (v c) -> ... v c', v = 9) + nn.Linear(curr_dim, num_discrete_coors * total_coordinates_per_face), + Rearrange('... (v c) -> ... v c', v = total_coordinates_per_face) ) # loss related @@ -586,7 +590,7 @@ def encode( self, *, vertices: TensorType['b', 'nv', 3, float], - faces: TensorType['b', 'nf', 3, int], + faces: TensorType['b', 'nf', 'nvf', int], face_edges: TensorType['b', 'e', 2, int], face_mask: TensorType['b', 'nf', bool], face_edges_mask: TensorType['b', 'e', bool], @@ -597,12 +601,15 @@ def encode( b - batch nf - number of faces nv - number of vertices (3) + nvf - number of vertices per face (3 or 4) - triangles vs quads c - coordinates (3) d - embed dim """ batch, num_vertices, num_coors, device = *vertices.shape, vertices.device - _, num_faces, _ = faces.shape + _, num_faces, num_vertices_per_face = faces.shape + + assert self.num_vertices_per_face == num_vertices_per_face face_without_pad = faces.masked_fill(~rearrange(face_mask, 'b nf -> b nf 1'), 0) @@ -626,7 +633,7 @@ def encode( # discretize vertices for face coordinate embedding discrete_face_coords = self.discretize_face_coords(face_coords) - discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)') # 9 coordinates per face + discrete_face_coords = rearrange(discrete_face_coords, 'b nf nv c -> b nf (nv c)') # 9 or 12 coordinates per face face_coor_embed = self.coor_embed(discrete_face_coords) face_coor_embed = 
rearrange(face_coor_embed, 'b nf c d -> b nf (c d)') @@ -684,7 +691,7 @@ def encode( def quantize( self, *, - faces: TensorType['b', 'nf', 3, int], + faces: TensorType['b', 'nf', 'nvf', int], face_mask: TensorType['b', 'n', bool], face_embed: TensorType['b', 'nf', 'd', float], pad_id = None, @@ -697,7 +704,7 @@ def quantize( num_vertices = int(max_vertex_index.item() + 1) face_embed = self.project_dim_codebook(face_embed) - face_embed = rearrange(face_embed, 'b nf (nv d) -> b nf nv d', nv = 3) + face_embed = rearrange(face_embed, 'b nf (nvf d) -> b nf nvf d', nvf = self.num_vertices_per_face) vertex_dim = face_embed.shape[-1] vertices = torch.zeros((batch, num_vertices, vertex_dim), device = device) @@ -711,7 +718,7 @@ def quantize( # prepare for scatter mean - faces_with_dim = repeat(faces, 'b nf nv -> b (nf nv) d', d = vertex_dim) + faces_with_dim = repeat(faces, 'b nf nvf -> b (nf nvf) d', d = vertex_dim) face_embed = rearrange(face_embed, 'b ... d -> b (...) d') @@ -749,16 +756,16 @@ def quantize_wrapper_fn(inp): # gather quantized vertexes back to faces for decoding # now the faces have quantized vertices - face_embed_output = get_at('b [n] d, b nf nv -> b nf (nv d)', quantized, faces) + face_embed_output = get_at('b [n] d, b nf nvf -> b nf (nvf d)', quantized, faces) # vertex codes also need to be gathered to be organized by face sequence # for autoregressive learning - codes_output = get_at('b [n] q, b nf nv -> b (nf nv) q', codes, faces) + codes_output = get_at('b [n] q, b nf nvf -> b (nf nvf) q', codes, faces) # make sure codes being outputted have this padding - face_mask = repeat(face_mask, 'b nf -> b (nf nv) 1', nv = 3) + face_mask = repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face) codes_output = codes_output.masked_fill(~face_mask, self.pad_id) # output quantized, codes, as well as commitment loss @@ -783,7 +790,6 @@ def decode( x = ff(x) + x x = rearrange(x, 'b n d -> b d n') - x = x.masked_fill(~conv_face_mask, 0.) 
x = self.init_decoder_conv(x) @@ -803,7 +809,7 @@ def decode_from_codes_to_faces( codes = rearrange(codes, 'b ... -> b (...)') if not exists(face_mask): - face_mask = reduce(codes != self.pad_id, 'b (nf nv q) -> b nf', 'all', nv = 3, q = self.num_quantizers) + face_mask = reduce(codes != self.pad_id, 'b (nf nvf q) -> b nf', 'all', nvf = self.num_vertices_per_face, q = self.num_quantizers) # handle different code shapes @@ -812,7 +818,7 @@ def decode_from_codes_to_faces( # decode quantized = self.quantizer.get_output_from_indices(codes) - quantized = rearrange(quantized, 'b (nf nv) d -> b nf (nv d)', nv = 3) + quantized = rearrange(quantized, 'b (nf nvf) d -> b nf (nvf d)', nvf = self.num_vertices_per_face) decoded = self.decode( quantized, @@ -824,7 +830,7 @@ def decode_from_codes_to_faces( pred_face_coords = pred_face_coords.argmax(dim = -1) - pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... v c', v = 3) + pred_face_coords = rearrange(pred_face_coords, '... (v c) -> ... 
v c', v = self.num_vertices_per_face) # back to continuous space @@ -877,7 +883,7 @@ def forward( self, *, vertices: TensorType['b', 'nv', 3, float], - faces: TensorType['b', 'nf', 3, int], + faces: TensorType['b', 'nf', 'nvf', int], face_edges: Optional[TensorType['b', 'e', 2, int]] = None, return_codes = False, return_loss_breakdown = False, @@ -912,7 +918,7 @@ def forward( if return_codes: assert not return_recon_faces, 'cannot return reconstructed faces when just returning raw codes' - codes = codes.masked_fill(~repeat(face_mask, 'b nf -> b (nf 3) 1'), self.pad_id) + codes = codes.masked_fill(~repeat(face_mask, 'b nf -> b (nf nvf) 1', nvf = self.num_vertices_per_face), self.pad_id) return codes decode = self.decode( @@ -932,7 +938,7 @@ def forward( continuous_range = self.coor_continuous_range, ) - recon_faces = rearrange(recon_faces, 'b nf (nv c) -> b nf nv c', nv = 3) + recon_faces = rearrange(recon_faces, 'b nf (nvf c) -> b nf nvf c', nvf = self.num_vertices_per_face) face_mask = rearrange(face_mask, 'b nf -> b nf 1 1') recon_faces = recon_faces.masked_fill(~face_mask, float('nan')) face_mask = rearrange(face_mask, 'b nf 1 1 -> b nf') @@ -960,7 +966,7 @@ def forward( recon_losses = (-target_one_hot * pred_log_prob).sum(dim = 1) - face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = 9) + face_mask = repeat(face_mask, 'b nf -> b (nf r)', r = self.num_vertices_per_face * 3) recon_loss = recon_losses[face_mask].mean() # calculate total loss @@ -1012,9 +1018,11 @@ def __init__( pad_id = -1, condition_on_text = False, text_condition_model_types = ('t5',), - text_condition_cond_drop_prob = 0.25 + text_condition_cond_drop_prob = 0.25, + quads = False ): super().__init__() + self.num_vertices_per_face = 3 if not quads else 4 dim, dim_fine = (dim, dim) if isinstance(dim, int) else dim @@ -1029,12 +1037,12 @@ def __init__( # they use axial positional embeddings - assert divisible_by(max_seq_len, 3 * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible 
by (3 x {self.num_quantizers}) = {3 * self.num_quantizers}' # 3 vertices per face, with D codes per vertex + assert divisible_by(max_seq_len, self.num_vertices_per_face * self.num_quantizers), f'max_seq_len ({max_seq_len}) must be divisible by ({self.num_vertices_per_face} x {self.num_quantizers}) = {self.num_vertices_per_face * self.num_quantizers}' # all vertices of a face (3 for triangles, 4 for quads), with D codes per vertex self.token_embed = nn.Embedding(self.codebook_size + 1, dim) self.quantize_level_embed = nn.Parameter(torch.randn(self.num_quantizers, dim)) - self.vertex_embed = nn.Parameter(torch.randn(3, dim)) + self.vertex_embed = nn.Parameter(torch.randn(self.num_vertices_per_face, dim)) self.abs_pos_emb = nn.Embedding(max_seq_len, dim) @@ -1057,7 +1065,7 @@ def __init__( # for summarizing the vertices of each face self.to_face_tokens = nn.Sequential( - nn.Linear(self.num_quantizers * 3 * dim, dim), + nn.Linear(self.num_quantizers * self.num_vertices_per_face * dim, dim), nn.LayerNorm(dim) ) @@ -1173,9 +1181,11 @@ def generate( cache = (None, None) for i in tqdm(range(curr_length, max_seq_len)): + + # example below for triangles, extrapolate for quads # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F) - can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3) # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes + can_eos = i != 0 and divisible_by(i, self.num_quantizers * self.num_vertices_per_face) # only allow for eos to be decoded at the end of each face, defined as all the vertices of a face with D residual VQ codes output = self.forward_on_codes( codes, @@ -1249,7 +1259,7 @@ def forward( self, *, vertices: TensorType['b', 'nv', 3, int], - faces: TensorType['b', 'nf', 3, int], + faces: TensorType['b', 'nf', 'nvf', int], face_edges: Optional[TensorType['b', 'e', 2, int]] = None, codes: Optional[Tensor] = None, cache: Optional[LayerIntermediates] = None, @@ -1352,13 +1362,13 @@ def 
forward_on_codes( # embedding for each vertex - vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (3 * self.num_quantizers)), q = self.num_quantizers) + vertex_embed = repeat(self.vertex_embed, 'nv d -> (r nv q) d', r = ceil(code_len / (self.num_vertices_per_face * self.num_quantizers)), q = self.num_quantizers) codes = codes + vertex_embed[:code_len] # create a token per face, by summarizing the 3 vertices # this is similar in design to the RQ transformer from Lee et al. https://arxiv.org/abs/2203.01941 - num_tokens_per_face = self.num_quantizers * 3 + num_tokens_per_face = self.num_quantizers * self.num_vertices_per_face curr_vertex_pos = code_len % num_tokens_per_face # the current intra-face vertex-code position id, needed for caching at the fine decoder stage diff --git a/meshgpt_pytorch/version.py b/meshgpt_pytorch/version.py index fb9b668f..1f356cc5 100644 --- a/meshgpt_pytorch/version.py +++ b/meshgpt_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.7.2' +__version__ = '1.0.0'