
[Bug] Use int64 in argmax #103

Closed
soodoshll opened this issue Feb 15, 2023 · 1 comment · Fixed by #111
Comments

@soodoshll
Collaborator

Currently, hidet uses int32 as the index dtype of ArgReduceTask:

extent=x_shape[dim], fcompute=reduce_fcompute, reduce_type=reduce_type, index_dtype='int32'

This is inconsistent with PyTorch and ONNX, both of which return int64 indices, and it leads to incompatibility with other operators. For example, concatenation requires all inputs to have the same dtype, so concatenating the output of argmax with an int64 tensor is legal in torch but illegal in hidet.
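For reference, a minimal self-contained check of the torch convention the report relies on: torch.argmax always produces int64 indices, so its output concatenates cleanly with other int64 tensors (the tensor shapes here are illustrative, not taken from the issue):

```python
import torch

# torch.argmax returns int64 indices, matching the ONNX ArgMax spec.
scores = torch.rand(5, 5)
idx = torch.argmax(scores, dim=0)
assert idx.dtype == torch.int64

# Because the index dtype is int64, concatenation with an int64 tensor is legal.
other = torch.ones(5, dtype=torch.int64)
combined = torch.cat([other, idx])
print(combined.dtype)  # torch.int64
```

An int32 index dtype, as hidet currently produces, would fail the same `torch.cat` call with a dtype mismatch.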

Here is a simple snippet:

import torch
import hidet
import onnx

class Foo(torch.nn.Module):
    def forward(self, x, y):
        y = torch.argmax(y, dim=0)  # torch returns int64 indices
        print(y.dtype)  # torch.int64
        return torch.concat([x, y])

device = 'cuda'

model = Foo()
model.to(device)

x = torch.ones([5], dtype=torch.int64, device=device)
y = torch.rand([5, 5], device=device)
z = model(x, y)
print(z.shape)

torch.onnx.export(model, (x, y), 'tmp.onnx', input_names=['x', 'y'],
                  output_names=['z'])
model = onnx.load('tmp.onnx')

hidet.torch.dynamo_config.search_space(1)

x = hidet.from_torch(x)
y = hidet.from_torch(y)
symbol_data = [hidet.symbol_like(x), hidet.symbol_like(y)]
hidet_onnx_module = hidet.graph.frontend.from_onnx(model)
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)
with hidet.graph.PassContext():
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run([x, y])

which raises an error:

ValueError: concat: expect all tensors have the same dtype, but got:
Tensor(shape=(5,), dtype='int64', device='cuda:0')
Tensor(shape=(5,), dtype='int32', device='cuda:0')
@yaoyaoding
Member

Thanks @soodoshll, should be fixed in #111.

KTong821 added a commit to KTong821/hidet that referenced this issue Apr 24, 2024
Infrastructure for compiled stable diffusion app.

Towards hidet-org#57