
softmax + transpose + div_ triggers assertion failure in compile mode #93371

Closed
Kristoff-starling opened this issue on Jan 31, 2023 · 2 comments
Labels: module: inductor, oncall: pt2, triaged

Comments

@Kristoff-starling
Contributor

Kristoff-starling commented Jan 31, 2023

🐛 Describe the bug

The following program works fine in eager mode but raises an assertion failure in compile mode. All three operators are necessary to trigger the failure.

Error logs

==== Eager mode OK! ====
==== torchcomp compilation OK! ====
python3.8/site-packages/torch/_inductor/compile_fx.py:90: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
Traceback (most recent call last):
  File "python3.8/site-packages/torch/_dynamo/output_graph.py", line 692, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.fake_example_inputs())
  File "python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1047, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs, **kwargs)
  File "python3.8/site-packages/torch/__init__.py", line 1323, in __call__
    return self.compile_fn(model_, inputs_)
  File "python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 24, in inner
    return fn(gm, example_inputs, **kwargs)
  File "python3.8/site-packages/torch/_dynamo/optimizations/backends.py", line 61, in inductor
    return compile_fx(*args, **kwargs)
  File "python3.8/site-packages/torch/_inductor/compile_fx.py", line 413, in compile_fx
    return aot_autograd(
  File "python3.8/site-packages/torch/_dynamo/optimizations/training.py", line 74, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2483, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "python3.8/site-packages/torch/_dynamo/utils.py", line 161, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2180, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
  File "python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1411, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config)
  File "python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1061, in aot_dispatch_base
    compiled_fw = aot_config.fw_compiler(fw_module, flat_args)
  File "python3.8/site-packages/torch/_dynamo/utils.py", line 161, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.8/site-packages/torch/_inductor/compile_fx.py", line 388, in fw_compiler
    return inner_compile(
  File "python3.8/site-packages/torch/_dynamo/debug_utils.py", line 586, in debug_wrapper
    compiled_fn = compiler_fn(gm, example_inputs, **kwargs)
  File "python3.8/site-packages/torch/_inductor/debug.py", line 239, in inner
    return fn(*args, **kwargs)
  File "/opt/miniconda3/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "python3.8/site-packages/torch/_inductor/compile_fx.py", line 151, in compile_fx_inner
    compiled_fn = graph.compile_to_fn()
  File "python3.8/site-packages/torch/_inductor/graph.py", line 567, in compile_to_fn
    return self.compile_to_module().call
  File "python3.8/site-packages/torch/_dynamo/utils.py", line 161, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.8/site-packages/torch/_inductor/graph.py", line 552, in compile_to_module
    code = self.codegen()
  File "python3.8/site-packages/torch/_inductor/graph.py", line 501, in codegen
    self.scheduler = Scheduler(self.buffers)
  File "python3.8/site-packages/torch/_dynamo/utils.py", line 161, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.8/site-packages/torch/_inductor/scheduler.py", line 567, in __init__
    self.nodes.append(SchedulerNode(self, node, group_fn))
  File "python3.8/site-packages/torch/_inductor/scheduler.py", line 234, in __init__
    super().__init__(scheduler, node)
  File "python3.8/site-packages/torch/_inductor/scheduler.py", line 58, in __init__
    self.set_read_writes(node.get_read_writes())
  File "python3.8/site-packages/torch/_inductor/utils.py", line 206, in wrapper
    setattr(self, key, fn(self))
  File "python3.8/site-packages/torch/_inductor/ir.py", line 2035, in get_read_writes
    self.get_store_function(),
  File "python3.8/site-packages/torch/_inductor/ir.py", line 2040, in get_store_function
    indexer = self.layout.as_fixed().make_indexer()
  File "python3.8/site-packages/torch/_inductor/ir.py", line 1883, in make_indexer
    return self.target.make_indexer()
AttributeError: 'PermuteView' object has no attribute 'make_indexer'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "repro.py", line 16, in <module>
    ret_compiled = compiled(x)
  File "python3.8/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "python3.8/site-packages/torch/_dynamo/eval_frame.py", line 332, in catch_errors
    return callback(frame, cache_size, hooks)
  File "python3.8/site-packages/torch/_dynamo/convert_frame.py", line 403, in _convert_frame
    result = inner_convert(frame, cache_size, hooks)
  File "python3.8/site-packages/torch/_dynamo/convert_frame.py", line 103, in _fn
    return fn(*args, **kwargs)
  File "python3.8/site-packages/torch/_dynamo/convert_frame.py", line 261, in _convert_frame_assert
    return _compile(
  File "python3.8/site-packages/torch/_dynamo/utils.py", line 161, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.8/site-packages/torch/_dynamo/convert_frame.py", line 323, in _compile
    out_code = transform_code_object(code, transform)
  File "python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 339, in transform_code_object
    transformations(instructions, code_options)
  File "python3.8/site-packages/torch/_dynamo/convert_frame.py", line 310, in transform
    tracer.run()
  File "python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1715, in run
    super().run()
  File "python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 564, in run
    and self.step()
  File "python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 527, in step
    getattr(self, inst.opname)(inst)
  File "python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1781, in RETURN_VALUE
    self.output.compile_subgraph(self)
  File "python3.8/site-packages/torch/_dynamo/output_graph.py", line 539, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "python3.8/site-packages/torch/_dynamo/output_graph.py", line 610, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "python3.8/site-packages/torch/_dynamo/utils.py", line 161, in time_wrapper
    r = func(*args, **kwargs)
  File "python3.8/site-packages/torch/_dynamo/output_graph.py", line 697, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised AttributeError: 'PermuteView' object has no attribute 'make_indexer'

Set torch._dynamo.config.verbose=True for more information


You can suppress this exception and fall back to eager by setting:
    torch._dynamo.config.suppress_errors = True
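The error message above names two config flags. A minimal sketch of how to apply them before calling the compiled function, using only the flags quoted in the log (their effect on this specific repro is untested):

import torch
import torch._dynamo

# Print more detailed information about the compilation failure.
torch._dynamo.config.verbose = True
# Suppress backend compiler errors and fall back to eager execution.
torch._dynamo.config.suppress_errors = True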

Minified repro

import torch

def fn(input):
    v1 = torch.nn.functional.softmax(input, 1) # works fine w/o this line
    v2 = v1.transpose(0, 3) # works fine w/o this line
    return v2.div_(2.0) # works fine with non-inplace operator "div"

x = torch.rand([4, 6, 4, 1])

ret_eager = fn(x)
print('==== Eager mode OK! ====')

compiled = torch.compile(fn)
print('==== torchcomp compilation OK! ====')

ret_compiled = compiled(x)
print('==== torchcomp mode OK! ====')
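As the inline comments note, the failure hinges on the in-place div_; a hypothetical workaround (not part of the original report) is to use the out-of-place div, which the reporter says does not trigger the failure:

import torch

def fn_workaround(input):
    v1 = torch.nn.functional.softmax(input, 1)
    v2 = v1.transpose(0, 3)
    # Out-of-place division returns a new tensor instead of mutating
    # the transposed view; per the report, this compiles fine.
    return v2.div(2.0)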

Versions

Environment
PyTorch version: 2.0.0.dev20230130+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: 10.0.0-4ubuntu1 
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.13 (default, Mar 28 2022, 11:38:47)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-137-generic-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000

Nvidia driver version: 510.68.02
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.2
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.2
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-triton==2.0.0+0d7e753227
[pip3] torch==2.0.0.dev20230130+cu117
[pip3] torchaudio==2.0.0.dev20230126+cu117
[pip3] torchvision==0.15.0.dev20230126+cu117
[conda] No relevant packages

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @mlazos @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @desertfire

@ezyang
Contributor

ezyang commented Jan 31, 2023

Maybe the same as #93386.

@albanD added the triaged label on Feb 7, 2023
@ngimel
Collaborator

ngimel commented Feb 22, 2023

Can't repro anymore

@ngimel closed this as completed on Feb 22, 2023