
Compiler crashed when compiling a simple module with a single conv layer #17701

Open
Marrrb1e opened this issue Jun 19, 2024 · 4 comments
Labels: bug 🐞 (Something isn't working), support (Request support or ask a question)

Comments

@Marrrb1e

What happened?

[screenshot of the compiler crash output]

Steps to reproduce your issue

iree crash.zip
I put the whole program into a zip so you can reproduce the crash. Note that you will need to adjust the paths in the compilation script to match your environment.

What component(s) does this issue relate to?

Compiler

Version information

I downloaded IREE recently, around June 6; I am not sure of the exact version number.

Additional context

I was trying to use IREE to compile 3×3 convolution operators at a certain size, so I wrote a network model with only a single convolution layer. Compiling it produced the error shown above. A rough sketch of the model is below.
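
For reference, the model is roughly the following (a minimal sketch; the channel counts and input shape here are placeholders, not necessarily the exact values in my script):

import torch
import torch.nn as nn

# Single 3x3 convolution module (placeholder shapes).
class SingleConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

model = SingleConv().eval()
example_input = torch.randn(1, 3, 224, 224)  # placeholder input size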

Marrrb1e added the bug 🐞 label on Jun 19, 2024
@ScottTodd (Collaborator)

What is the output from iree-compile --version?

The "openxla" in the screenshot makes me suspect this is using the latest "stable" release from April, prior to 3f51a55.

You can install a more recent nightly release by following https://iree.dev/reference/bindings/python/#__tabbed_2_2:

python -m pip install \
  --find-links https://iree.dev/pip-release-links.html \
  --upgrade \
  iree-compiler \
  iree-runtime

@Marrrb1e (Author)

I checked my version and found it was 20240226.813, so I updated IREE to 20240619.929 using your command.
[screenshots: the pip upgrade output and the updated version check]
However, when I tried to compile the module, it still returned an error.

@Marrrb1e (Author)

The new error is shown below.
[screenshot of the new compiler error]
The steps to reproduce are the same; the only difference is the IREE version:
IREE (https://iree.dev):
IREE compiler version 20240619.929 @ d792d24
LLVM version 19.0.0git

@ScottTodd (Collaborator)

Thanks for checking with a more recent version. As a general point, please include logs as text instead of screenshots (or in addition to them). Text is searchable across issues and can be more easily referenced against source files. For best results you can use markdown code blocks: https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks
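
For example, wrapping a log in triple backticks renders it as a searchable block:

```
<paste the full iree-compile output here>
```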

The input program as .mlir would be useful. I'm not going to run compile_test.py myself for initial triage, and test.log starts partway through compilation:

// -----// IR Dump Before LLVMCPUSelectLoweringStrategy (iree-llvmcpu-select-lowering-strategy) //----- //
hal.executable.variant public @embedded_elf_x86_64 target(<"llvm-cpu", "embedded-elf-x86_64", {cpu = "alderlake", cpu_features = "+prfchw,-cldemote,+avx,+aes,+sahf,+pclmul,-xop,+crc32,+xsaves,-avx512fp16,-usermsr,-sm4,-egpr,+sse4.1,-avx512ifma,+xsave,-avx512pf,+sse4.2,-tsxldtrk,-ptwrite,-widekl,-sm3,+invpcid,+64bit,+xsavec,-avx10.1-512,-avx512vpopcntdq,+cmov,-avx512vp2intersect,-avx512cd,+movbe,-avxvnniint8,-avx512er,-ccmp,-amx-int8,-kl,-avx10.1-256,-sha512,+avxvnni,-rtm,+adx,+avx2,-hreset,+movdiri,+serialize,+vpclmulqdq,-avx512vl,-uintr,-cf,+clflushopt,-raoint,-cmpccxadd,+bmi,-amx-tile,+sse,+gfni,-avxvnniint16,-amx-fp16,-ndd,+xsaveopt,+rdrnd,-avx512f,-amx-bf16,-avx512bf16,-avx512vnni,-push2pop2,+cx8,-avx512bw,+sse3,+pku,+fsgsbase,-clzero,-mwaitx,-lwp,+lzcnt,+sha,+movdir64b,-ppx,-wbnoinvd,-enqcmd,-prefetchwt1,-avxneconvert,-tbm,-pconfig,-amx-complex,+ssse3,+cx16,+bmi2,+fma,+popcnt,-avxifma,+f16c,-avx512bitalg,-rdpru,+clwb,+mmx,+sse2,+rdseed,-avx512vbmi2,-prefetchi,+rdpid,-fma4,-avx512vbmi,+shstk,+vaes,-waitpkg,-sgx,+fxsr,-avx512dq,-sse4a", data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", native_vector_size = 32 : i64, target_triple = "x86_64-unknown-unknown-eabi-elf"}>) {
  hal.executable.export public @jit_eval_dispatch_0_generic_64_f32 ordinal(0) layout(#hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer>]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>, #hal.interface.binding<0, 1>]} {
  ^bb0(%arg0: !hal.device):
    %x, %y, %z = flow.dispatch.workgroup_count_from_slice 
    hal.return %x, %y, %z : index, index, index
  }
  builtin.module {
    func.func @jit_eval_dispatch_0_generic_64_f32() {
      %cst = arith.constant 9.99999974E-6 : f32
      %cst_0 = arith.constant 1.000000e+00 : f32
      %c0 = arith.constant 0 : index
      %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<64xf32>>
      %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<64xf32>>
      %2 = flow.dispatch.tensor.load %0, offsets = [0], sizes = [64], strides = [1] : !flow.dispatch.tensor<readonly:tensor<64xf32>> -> tensor<64xf32>
      %3 = tensor.empty() : tensor<64xf32>
      %4 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%2 : tensor<64xf32>) outs(%3 : tensor<64xf32>) {
      ^bb0(%in: f32, %out: f32):
        %5 = arith.addf %in, %cst : f32
        %6 = math.sqrt %5 : f32
        %7 = arith.divf %cst_0, %6 : f32
        linalg.yield %7 : f32
      } -> tensor<64xf32>
      flow.dispatch.tensor.store %4, %1, offsets = [0], sizes = [64], strides = [1] : tensor<64xf32> -> !flow.dispatch.tensor<writeonly:tensor<64xf32>>
      return
    }
  }
}

The important part of the backtrace is:

 #3 0x00007ad6ad25cc5c mlir::DenseElementsAttr::getNumElements() const (/home/marble/iree-build/lib/libIREECompiler.so+0x1c5cc5c)
 #4 0x00007ad6b0882a98 mlir::iree_compiler::IREE::Stream::(anonymous namespace)::ConvertSplatConstantsIntoSplats::matchAndRewrite(mlir::iree_compiler::IREE::Stream::AsyncConstantOp, mlir::PatternRewriter&) const StreamOpFolders.cpp:0:0
 #5 0x00007ad6b0ae6655 mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) (/home/marble/iree-build/lib/libIREECompiler.so+0x54e6655)
 #6 0x00007ad6b0a9f26f (anonymous namespace)::GreedyPatternRewriteDriver::processWorklist() GreedyPatternRewriteDriver.cpp:0:0
 #7 0x00007ad6b0aa0c56 mlir::applyPatternsAndFoldGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) (/home/marble/iree-build/lib/libIREECompiler.so+0x54a0c56)
 #8 0x00007ad6b0a54b7a (anonymous namespace)::Canonicalizer::runOnOperation() Canonicalizer.cpp:0:0

The program looks fairly normal to me: only f32 data types (no exotic types like bf16, i4, etc.).

Possibly an issue with the Conv2d and BatchNorm2d op lowerings? The dispatch above computes 1/sqrt(x + 1e-5), which looks like the inverse-standard-deviation term of a batch norm. We have a few models using convolution coming from PyTorch working without this issue though 🤔 If you can share the input .mlir, one way to capture it is sketched below.
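
Stopping compilation at the input phase avoids the later crash and emits the imported IR (untested sketch using the iree.compiler Python bindings; the file names are placeholders, and I'm assuming your compile_test.py goes through compile_file):

import iree.compiler as ireec

# Stop at the "input" phase and save the IR instead of compiling to a binary.
# "model.mlir" stands in for whatever compile_test.py actually feeds in.
input_ir = ireec.compile_file(
    "model.mlir",
    target_backends=["llvm-cpu"],
    extra_args=["--compile-to=input"],
)
with open("model_input.mlir", "wb") as f:
    f.write(input_ir)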

ScottTodd added the support label on Jun 24, 2024