From fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a Mon Sep 17 00:00:00 2001
From: Christian Sigg
Date: Thu, 16 Feb 2023 15:40:53 +0100
Subject: [PATCH] Rebase Triton to LLVM-15. (#1070)

This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are
mechanical, except for the analysis framework changes.
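The analysis framework change follows MLIR's move from
ForwardDataFlowAnalysis to the DataFlowSolver framework. As a minimal
sketch of the new usage pattern (mirroring the code in this patch, not
an exhaustive summary of it):

    // LLVM-14 style: each analysis owned its own fixpoint loop.
    //   AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
    //   axisInfoAnalysis.run(mod);

    // LLVM-15 style: analyses are loaded into a shared solver, which
    // also runs dead-code and constant analyses (createDataFlowSolver
    // below bundles these).
    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
    AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
    if (failed(solver->initializeAndRun(mod)))
      return signalPassFailure();

Transfer functions change shape accordingly: visitOperation() now
returns void and receives explicit result lattices, publishing updates
via propagateIfChanged(result, result->join(...)) instead of returning
a ChangeResult.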
---
 CMakeLists.txt                                  |   6 +-
 bin/CMakeLists.txt                              |   2 +-
 bin/FileCheck/FileCheck.cpp                     |   3 +
 bin/triton-opt.cpp                              |   6 +-
 bin/triton-translate.cpp                        |   7 +-
 include/triton/Analysis/Alias.h                 |  21 +-
 include/triton/Analysis/Allocation.h            |   2 +
 include/triton/Analysis/AxisInfo.h              |  56 ++-
 include/triton/Analysis/Utility.h               |   6 +-
 include/triton/Conversion/Passes.td             |   4 +-
 include/triton/Dialect/Triton/IR/Dialect.h      |   7 +-
 .../triton/Dialect/Triton/IR/TritonDialect.td   |   8 +-
 include/triton/Dialect/Triton/IR/TritonOps.td   |  12 +-
 .../triton/Dialect/Triton/IR/TritonTypes.td     |   2 +
 .../Dialect/Triton/Transforms/Passes.td         |   3 +-
 include/triton/Dialect/TritonGPU/IR/Dialect.h   |   4 +-
 .../Dialect/TritonGPU/IR/TritonGPUAttrDefs.td   |   7 +
 .../Dialect/TritonGPU/IR/TritonGPUDialect.td    |   2 +-
 .../Dialect/TritonGPU/IR/TritonGPUOps.td        |  13 +-
 lib/Analysis/Alias.cpp                          |  14 +-
 lib/Analysis/Allocation.cpp                     |  30 +-
 lib/Analysis/AxisInfo.cpp                       |  79 ++--
 lib/Analysis/CMakeLists.txt                     |   2 +-
 lib/Analysis/Membar.cpp                         |   2 +-
 lib/Analysis/Utility.cpp                        |  54 +++
 .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp   |   3 -
 lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h   |  10 +-
 .../TritonGPUToLLVM/DotOpToLLVM.cpp             |   5 -
 .../TritonGPUToLLVM/ElementwiseOpToLLVM.cpp     |   2 -
 .../TritonGPUToLLVM/LoadStoreOpToLLVM.cpp       |   5 +-
 .../TritonGPUToLLVM/ReduceOpToLLVM.cpp          |   2 -
 .../TritonGPUToLLVM/TritonGPUToLLVM.cpp         |   7 +-
 .../TritonGPUToLLVM/TritonGPUToLLVMBase.h       |  26 +-
 .../TritonGPUToLLVM/TritonGPUToLLVMPass.cpp     |  52 +--
 lib/Conversion/TritonGPUToLLVM/Utility.h        |   5 +-
 .../TritonToTritonGPUPass.cpp                   |  69 ++--
 lib/Dialect/Triton/IR/CMakeLists.txt            |  10 +-
 lib/Dialect/Triton/IR/Ops.cpp                   |  34 +-
 lib/Dialect/Triton/Transforms/Combine.cpp       |   6 +-
 lib/Dialect/Triton/Transforms/Combine.td        |   2 +-
 lib/Dialect/TritonGPU/IR/Dialect.cpp            |  27 +-
 lib/Dialect/TritonGPU/Transforms/Coalesce.cpp   |  20 +-
 lib/Dialect/TritonGPU/Transforms/Combine.cpp    |   2 +-
 lib/Dialect/TritonGPU/Transforms/Combine.td     |   1 +
 .../Transforms/DecomposeConversions.cpp         |   2 +-
 lib/Dialect/TritonGPU/Transforms/Pipeline.cpp   |  10 +-
 .../Transforms/ReorderInstructions.cpp          |   2 +-
 .../Transforms/TritonGPUConversion.cpp          |  12 +-
 .../Transforms/UpdateMmaForVolta.cpp            |   6 +-
 lib/Dialect/TritonGPU/Transforms/Utility.cpp    |   2 +-
 lib/Target/LLVMIR/CMakeLists.txt                |   3 +-
 lib/Target/PTX/PTXTranslation.cpp               |   3 +
 python/setup.py                                 |  15 +-
 python/src/triton.cc                            |  85 +++--
 python/test/unit/language/test_core.py          |   2 +-
 python/triton/compiler.py                       |   4 +-
 test/Analysis/test-alias.mlir                   |  24 +-
 test/Analysis/test-alignment.mlir               | 344 +++++++++---------
 test/Analysis/test-allocation.mlir              |  32 +-
 test/Analysis/test-membar.mlir                  |  38 +-
 test/Conversion/triton_ops.mlir                 |  10 +-
 test/Conversion/triton_to_tritongpu.mlir        |   6 +-
 test/Conversion/tritongpu_to_llvm.mlir          |  94 ++---
 test/Target/tritongpu_to_llvmir.mlir            |   4 +-
 test/Target/tritongpu_to_ptx.mlir               |   2 +-
 test/Triton/combine.mlir                        |  40 +-
 test/Triton/vecadd.mlir                         |   4 +-
 test/TritonGPU/coalesce.mlir                    |   2 +-
 test/TritonGPU/combine.mlir                     |  38 +-
 test/TritonGPU/loop-pipeline.mlir               |  22 +-
 test/TritonGPU/matmul.mlir                      |   4 +-
 test/TritonGPU/prefetch.mlir                    |   4 +-
 test/TritonGPU/update-mma-for-volta.mlir        |   4 +-
 test/lib/Analysis/TestAlias.cpp                 |  29 +-
 test/lib/Analysis/TestAllocation.cpp            |   5 +-
 test/lib/Analysis/TestAxisInfo.cpp              |  51 +--
 test/lib/Analysis/TestMembar.cpp                |   7 +-
 78 files changed, 808 insertions(+), 742 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index d0d361fc7c..b281a28400 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,7 @@
 cmake_minimum_required(VERSION 3.6)
+
+cmake_policy(SET CMP0116 OLD)
+
 include(ExternalProject)
 
 set(CMAKE_CXX_STANDARD 17)
@@ -155,7 +158,6 @@ if(TRITON_BUILD_PYTHON_MODULE)
     endif()
   endif()
 
-
   # # Triton
   # file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
   # if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
@@ -212,7 +214,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
       # optimizations
       MLIRPass
       MLIRTransforms
-      MLIRLLVMIR
+      MLIRLLVMDialect
       MLIRSupport
       MLIRTargetLLVMIRExport
      MLIRExecutionEngine
diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt
index 906f635f8b..695b3479fd 100644
--- a/bin/CMakeLists.txt
+++ b/bin/CMakeLists.txt
@@ -48,7 +48,7 @@ llvm_update_compile_flags(triton-translate)
   # MLIR core
   MLIROptLib
   MLIRIR
-  MLIRLLVMIR
+  MLIRLLVMDialect
   MLIRPass
   MLIRSupport
   MLIRTransforms
diff --git a/bin/FileCheck/FileCheck.cpp b/bin/FileCheck/FileCheck.cpp
index 819efc3541..9ac6f1b277 100644
--- a/bin/FileCheck/FileCheck.cpp
+++ b/bin/FileCheck/FileCheck.cpp
@@ -19,6 +19,7 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/InitLLVM.h"
 #include "llvm/Support/Process.h"
+#include "llvm/Support/SourceMgr.h"
 #include "llvm/Support/WithColor.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cmath>
@@ -360,6 +361,8 @@ static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
     return "bad-not";
   case Check::CheckBadCount:
     return "bad-count";
+  case Check::CheckMisspelled:
+    return "misspelled";
   case Check::CheckNone:
     llvm_unreachable("invalid FileCheckType");
   }
diff --git a/bin/triton-opt.cpp b/bin/triton-opt.cpp
index 9f3b53b7ae..f96232e1b0 100644
--- a/bin/triton-opt.cpp
+++ b/bin/triton-opt.cpp
@@ -8,7 +8,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/InitAllPasses.h"
-#include "mlir/Support/MlirOptMain.h"
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"
 
 namespace mlir {
 namespace test {
@@ -33,8 +33,8 @@ int main(int argc, char **argv) {
   // TODO: register Triton & TritonGPU passes
   mlir::DialectRegistry registry;
-  registry.insert<mlir::triton::TritonDialect, mlir::StandardOpsDialect,
-                  mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
-                  mlir::arith::ArithmeticDialect, mlir::scf::SCFDialect>();
+  registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
+                  mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect,
+                  mlir::arith::ArithmeticDialect, mlir::scf::SCFDialect,
+                  mlir::gpu::GPUDialect>();
 
   return mlir::asMainReturnCode(mlir::MlirOptMain(
diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp
index 05ba15e453..56b5d65857 100644
--- a/bin/triton-translate.cpp
+++ b/bin/triton-translate.cpp
@@ -3,7 +3,7 @@
 #include "mlir/IR/AsmState.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
-#include "mlir/Parser.h"
+#include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Support/FileUtilities.h"
@@ -38,7 +38,7 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
   mlir::DialectRegistry registry;
-  registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
-                  mlir::StandardOpsDialect, scf::SCFDialect>();
+  registry.insert<TritonDialect, triton::gpu::TritonGPUDialect,
+                  scf::SCFDialect>();
 
   context.appendDialectRegistry(registry);
@@ -50,7 +50,8 @@ OwningOpRef<ModuleOp> loadMLIRModule(llvm::StringRef inputFilename,
   context.loadAllAvailableDialects();
   context.allowUnregisteredDialects();
 
-  OwningOpRef<ModuleOp> module(parseSourceFile(sourceMgr, &context));
+  OwningOpRef<ModuleOp> module =
+      parseSourceFile<ModuleOp>(sourceMgr, &context);
   if (!module) {
     llvm::errs() << "Parse MLIR file failed.";
     return nullptr;
diff --git a/include/triton/Analysis/Alias.h b/include/triton/Analysis/Alias.h
index fa6b906fc9..631df518bc 100644
--- a/include/triton/Analysis/Alias.h
+++ b/include/triton/Analysis/Alias.h
@@ -2,7 +2,7 @@
 #define TRITON_ANALYSIS_ALIAS_H
 
 #include "mlir/Analysis/AliasAnalysis.h"
-#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
"mlir/Analysis/DataFlow/SparseAnalysis.h" #include "llvm/ADT/DenseSet.h" namespace mlir { @@ -21,7 +21,7 @@ class AliasInfo { } /// The pessimistic value state of a value without alias - static AliasInfo getPessimisticValueState(MLIRContext *context) { + static AliasInfo getPessimisticValueState(MLIRContext *context = nullptr) { return AliasInfo(); } static AliasInfo getPessimisticValueState(Value value) { return AliasInfo(); } @@ -29,6 +29,10 @@ class AliasInfo { /// The union of both arguments static AliasInfo join(const AliasInfo &lhs, const AliasInfo &rhs); + void print(raw_ostream &os) const { + llvm::interleaveComma(allocs, os, [&](Value alloc) { alloc.print(os); }); + } + private: /// The set of allocated values that are aliased by this lattice. /// For now, we only consider aliased value produced by the following @@ -58,9 +62,13 @@ class AliasInfo { //===----------------------------------------------------------------------===// // Shared Memory Alias Analysis //===----------------------------------------------------------------------===// -class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis { +class SharedMemoryAliasAnalysis + : public dataflow::SparseDataFlowAnalysis> { public: - using ForwardDataFlowAnalysis::ForwardDataFlowAnalysis; + using dataflow::SparseDataFlowAnalysis< + dataflow::Lattice>::SparseDataFlowAnalysis; + using dataflow::SparseDataFlowAnalysis< + dataflow::Lattice>::getLatticeElement; /// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use. /// Given two values, returns their aliasing behavior. @@ -70,9 +78,10 @@ class SharedMemoryAliasAnalysis : public ForwardDataFlowAnalysis { ModRefResult getModRef(Operation *op, Value location); /// Computes if the alloc set of the results are changed. - ChangeResult + void visitOperation(Operation *op, - ArrayRef *> operands) override; + ArrayRef *> operands, + ArrayRef *> results) override; }; } // namespace mlir diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index b7c136d602..89b77034cc 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -188,6 +188,8 @@ class Allocation { friend class triton::AllocationAnalysis; }; +template Interval(T, T) -> Interval; + } // namespace mlir #endif // TRITON_ANALYSIS_ALLOCATION_H diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index fdfbd8fbb3..7083b9c43b 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -1,9 +1,10 @@ #ifndef TRITON_ANALYSIS_AXISINFO_H #define TRITON_ANALYSIS_AXISINFO_H -#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Support/LLVM.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" @@ -62,7 +63,7 @@ class AxisInfo { } /// The pessimistic value state of the contiguity is unknown. 
-  static AxisInfo getPessimisticValueState(MLIRContext *context) {
+  static AxisInfo getPessimisticValueState(MLIRContext *context = nullptr) {
     return AxisInfo();
   }
   static AxisInfo getPessimisticValueState(Value value);
@@ -70,6 +71,22 @@ class AxisInfo {
   /// The gcd of both arguments for each dimension
   static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
 
+  void print(raw_ostream &os) const {
+    auto print = [&](StringRef name, DimVectorT vec) {
+      os << name << " = [";
+      llvm::interleaveComma(vec, os);
+      os << "]";
+    };
+    print("contiguity", contiguity);
+    print(", divisibility", divisibility);
+    print(", constancy", constancy);
+    os << ", constant_value = ";
+    if (constantValue)
+      os << *constantValue;
+    else
+      os << "<none>";
+  }
+
 private:
   /// The _contiguity_ information maps the `d`-th
   /// dimension to the length of the shortest
@@ -147,7 +164,8 @@ class AxisInfoVisitor {
   }
 
   virtual AxisInfo
-  getAxisInfo(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) = 0;
+  getAxisInfo(Operation *op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) = 0;
 
   virtual bool match(Operation *op) = 0;
 };
@@ -157,15 +175,16 @@ template <typename OpTy>
 class AxisInfoVisitorImpl : public AxisInfoVisitor {
 public:
   using AxisInfoVisitor::AxisInfoVisitor;
 
-  AxisInfo getAxisInfo(Operation *op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) final {
+  AxisInfo
+  getAxisInfo(Operation *op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) final {
     return getAxisInfo(cast<OpTy>(op), operands);
   }
 
   bool match(Operation *op) final { return isa<OpTy>(op); }
 
-  virtual AxisInfo getAxisInfo(OpTy op,
-                               ArrayRef<LatticeElement<AxisInfo> *> operands) {
+  virtual AxisInfo
+  getAxisInfo(OpTy op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
    llvm_unreachable("Unimplemented getAxisInfo");
   }
 };
@@ -176,8 +195,9 @@ class BinaryOpVisitorImpl : public AxisInfoVisitorImpl<OpTy> {
 public:
   using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(OpTy op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(OpTy op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     auto lhsInfo = operands[0]->getValue();
     auto rhsInfo = operands[1]->getValue();
     auto rank = lhsInfo.getRank();
@@ -230,7 +250,8 @@ class AxisInfoVisitorList {
     (visitors.emplace_back(std::make_unique<Ts>()), ...);
   }
 
-  AxisInfo apply(Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
+  AxisInfo apply(Operation *op,
+                 ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) {
     for (auto &visitor : visitors)
       if (visitor->match(op))
         return visitor->getAxisInfo(op, operands);
@@ -241,16 +262,19 @@ class AxisInfoVisitorList {
   std::vector<std::unique_ptr<AxisInfoVisitor>> visitors;
 };
 
-class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
+class AxisInfoAnalysis
+    : public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
 private:
   AxisInfoVisitorList visitors;
 
 public:
-  AxisInfoAnalysis(MLIRContext *context);
+  AxisInfoAnalysis(DataFlowSolver &solver);
+  using dataflow::SparseDataFlowAnalysis<
+      dataflow::Lattice<AxisInfo>>::getLatticeElement;
 
-  ChangeResult
-  visitOperation(Operation *op,
-                 ArrayRef<LatticeElement<AxisInfo> *> operands) override;
+  void visitOperation(Operation *op,
+                      ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
+                      ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
 
   unsigned getPtrContiguity(Value ptr);
 
@@ -261,4 +285,4 @@ class AxisInfoAnalysis : public ForwardDataFlowAnalysis<AxisInfo> {
 
 } // namespace mlir
 
-#endif
\ No newline at end of file
+#endif
diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h
index c5ac137dc1..ee7fadb59d 100644
--- a/include/triton/Analysis/Utility.h
+++ b/include/triton/Analysis/Utility.h
@@ -1,6 +1,7 @@
 #ifndef TRITON_ANALYSIS_UTILITY_H
 #define TRITON_ANALYSIS_UTILITY_H
 
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include
@@ -12,7 +13,7 @@ namespace mlir {
 class ReduceOpHelper {
 public:
   explicit ReduceOpHelper(triton::ReduceOp op) : op(op) {
-    srcTy = op.operand().getType().cast<RankedTensorType>();
+    srcTy = op.getOperand().getType().cast<RankedTensorType>();
   }
 
   ArrayRef<int64_t> getSrcShape() { return srcTy.getShape(); }
@@ -103,6 +104,9 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
                      TransitiveFilter backwardFilter = nullptr,
                      TransitiveFilter forwardFilter = nullptr);
 
+// Create a basic DataFlowSolver with constant and dead code analysis included.
+std::unique_ptr<DataFlowSolver> createDataFlowSolver();
+
 } // namespace mlir
 
 #endif // TRITON_ANALYSIS_UTILITY_H
diff --git a/include/triton/Conversion/Passes.td b/include/triton/Conversion/Passes.td
index 70bb20b78e..be00eb2dac 100644
--- a/include/triton/Conversion/Passes.td
+++ b/include/triton/Conversion/Passes.td
@@ -12,7 +12,6 @@ def ConvertTritonToTritonGPU: Pass<"convert-triton-to-tritongpu", "mlir::ModuleOp"> {
 
   let dependentDialects = ["mlir::arith::ArithmeticDialect",
                            "mlir::math::MathDialect",
-                           "mlir::StandardOpsDialect",
                            // TODO: Does this pass depend on SCF?
                            "mlir::scf::SCFDialect",
                            "mlir::triton::TritonDialect",
@@ -41,8 +40,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"> {
                            "mlir::tensor::TensorDialect",
                            "mlir::triton::TritonDialect",
                            "mlir::triton::gpu::TritonGPUDialect",
-                           "mlir::NVVM::NVVMDialect",
-                           "mlir::StandardOpsDialect"];
+                           "mlir::NVVM::NVVMDialect"];
 
   let options = [
       Option<"computeCapability", "compute-capability",
diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h
index e8012a51df..15869e262e 100644
--- a/include/triton/Dialect/Triton/IR/Dialect.h
+++ b/include/triton/Dialect/Triton/IR/Dialect.h
@@ -1,14 +1,15 @@
 #ifndef TRITON_DIALECT_TRITON_IR_DIALECT_H_
 #define TRITON_DIALECT_TRITON_IR_DIALECT_H_
 
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Math/IR/Math.h"
-#include "mlir/Dialect/SCF/SCF.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
-
 #include "triton/Dialect/Triton/IR/Dialect.h.inc"
 #include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
 #include "triton/Dialect/Triton/IR/Traits.h"
diff --git a/include/triton/Dialect/Triton/IR/TritonDialect.td b/include/triton/Dialect/Triton/IR/TritonDialect.td
index 07b069e14f..d98ce73884 100644
--- a/include/triton/Dialect/Triton/IR/TritonDialect.td
+++ b/include/triton/Dialect/Triton/IR/TritonDialect.td
@@ -25,12 +25,9 @@ def Triton_Dialect : Dialect {
   let dependentDialects = [
     "arith::ArithmeticDialect",
     "math::MathDialect",
-    "StandardOpsDialect",
     "scf::SCFDialect",
-
-    // Since LLVM 15
-    // "cf::ControlFlowDialect",
-    // "func::FuncDialect"
+    "cf::ControlFlowDialect",
+    "func::FuncDialect"
   ];
 
   let extraClassDeclaration = [{
@@ -38,6 +35,7 @@ def Triton_Dialect : Dialect {
   }];
 
   let hasConstantMaterializer = 1;
+  let useDefaultTypePrinterParser = 1;
 }
 
 include "triton/Dialect/Triton/IR/TritonTypes.td"
diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td
index 779e0b648c..0a69211179 100644
--- a/include/triton/Dialect/Triton/IR/TritonOps.td
+++ b/include/triton/Dialect/Triton/IR/TritonOps.td
@@ -141,11 +141,7 @@ def TT_LoadOp : TT_Op<"load",
                        "triton::EvictionPolicy":$evict,
                        "bool":$isVolatile)>,
   ];
 
-  // let assemblyFormat = "operands attr-dict `:` type($result)";
type($result)"; - let parser = [{ return mlir::triton::parseLoadOp(parser, result); }]; - - let printer = [{ return mlir::triton::printLoadOp(p, *this); }]; - + let hasCustomAssemblyFormat = 1; let hasCanonicalizer = 1; } @@ -170,11 +166,7 @@ def TT_StoreOp : TT_Op<"store", "triton::EvictionPolicy":$evict)>, ]; - // let assemblyFormat = "operands attr-dict `:` type($value)"; - let parser = [{ return mlir::triton::parseStoreOp(parser, result); }]; - - let printer = [{ return mlir::triton::printStoreOp(p, *this); }]; - + let hasCustomAssemblyFormat = 1; let hasCanonicalizer = 1; } diff --git a/include/triton/Dialect/Triton/IR/TritonTypes.td b/include/triton/Dialect/Triton/IR/TritonTypes.td index 66d2a7b9a9..2fe2fd077d 100644 --- a/include/triton/Dialect/Triton/IR/TritonTypes.td +++ b/include/triton/Dialect/Triton/IR/TritonTypes.td @@ -1,6 +1,7 @@ #ifndef TRITON_TYPES #define TRITON_TYPES +include "mlir/IR/AttrTypeBase.td" include "triton/Dialect/Triton/IR/TritonDialect.td" // @@ -58,6 +59,7 @@ def TT_Ptr : TritonTypeDef<"Pointer", "ptr"> { }]> ]; + let hasCustomAssemblyFormat = 1; let skipDefaultBuilders = 1; } def TT_PtrTensor : TensorOf<[TT_Ptr]>; diff --git a/include/triton/Dialect/Triton/Transforms/Passes.td b/include/triton/Dialect/Triton/Transforms/Passes.td index 8f77aed774..a25cdc5680 100644 --- a/include/triton/Dialect/Triton/Transforms/Passes.td +++ b/include/triton/Dialect/Triton/Transforms/Passes.td @@ -16,8 +16,7 @@ def TritonCombineOps : Pass let constructor = "mlir::triton::createCombineOpsPass()"; - let dependentDialects = ["mlir::arith::ArithmeticDialect", - /*SelectOp*/"mlir::StandardOpsDialect"]; + let dependentDialects = ["mlir::arith::ArithmeticDialect"]; } #endif diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index b4c8daec7b..dfc5f53ab1 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -1,19 +1,17 @@ #ifndef TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ #define TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ -#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" // TritonGPU depends on Triton #include "triton/Dialect/Triton/IR/Dialect.h" - #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" #include "triton/Dialect/TritonGPU/IR/Traits.h" #define GET_ATTRDEF_CLASSES -#include "triton/Dialect/Triton/IR/AttrInterfaces.h.inc" #include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.h.inc" #define GET_OP_CLASSES diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 0242c3cc17..af2aeb03a8 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -1,6 +1,7 @@ #ifndef TRITONGPU_ATTRDEFS #define TRITONGPU_ATTRDEFS +include "mlir/IR/AttrTypeBase.td" include "triton/Dialect/TritonGPU/IR/TritonGPUDialect.td" include "triton/Dialect/Triton/IR/TritonInterfaces.td" @@ -136,6 +137,7 @@ A_{3, 2} A_{3, 3} A_{3, 0} A_{3, 1} ... 
   ];
 
   let extraClassDeclaration = extraBaseClassDeclaration;
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -273,6 +275,7 @@ for
     // ArrayRefParameter<"unsigned">:$sizePerCTA
   );
 
+  let hasCustomAssemblyFormat = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -422,6 +425,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
     static constexpr int numBitsToHoldMmaV1ID{5};
   }];
 
+  let hasCustomAssemblyFormat = 1;
 }
 
 def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
@@ -456,6 +460,8 @@ def SliceEncodingAttr : DistributedEncoding<"SliceEncoding"> {
     template <typename T>
     SmallVector<T> paddedShape(ArrayRef<T> shape) const;
   }];
+
+  let hasCustomAssemblyFormat = 1;
 }
 
 def DotOperandEncodingAttr : DistributedEncoding<"DotOperandEncoding"> {
@@ -492,6 +498,7 @@ section 9.7.13.4.1 for more details.
   ];
 
+  let hasCustomAssemblyFormat = 1;
   let extraClassDeclaration = extraBaseClassDeclaration;
 }
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
index 87ec1d36c6..6489a721b4 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td
@@ -30,7 +30,7 @@ def TritonGPU_Dialect : Dialect {
     }
   }];
 
-
+  let useDefaultAttributePrinterParser = 1;
 }
 
 #endif
diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
index 510f8d0183..7aba11dc75 100644
--- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
+++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -59,7 +59,7 @@ def TTG_AsyncCommitGroupOp : TTG_Op<"async_commit_group"> {
 // This is needed because these ops don't
 // handle encodings
 // e.g., https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td#L111
-def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise, 
+def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
                                  SameOperandsAndResultShape,
                                  SameOperandsAndResultEncoding]> {
   let summary = "integer comparison operation";
@@ -73,7 +73,7 @@ def TTG_CmpIOp : TTG_Op<"cmpi", [NoSideEffect, Elementwise,
   let results = (outs TT_BoolLike:$result);
 }
 
-def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise, 
+def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
                                  SameOperandsAndResultShape,
                                  SameOperandsAndResultEncoding]> {
   let summary = "floating-point comparison operation";
@@ -88,8 +88,8 @@ def TTG_CmpFOp : TTG_Op<"cmpf", [NoSideEffect, Elementwise,
 }
 
 // TODO: migrate to arith::SelectOp on LLVM16
-def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise, 
-                                     SameOperandsAndResultShape, 
+def TTG_SelectOp : TTG_Op<"select", [NoSideEffect, Elementwise,
+                                     SameOperandsAndResultShape,
                                      SameOperandsAndResultEncoding]> {
   let summary = "select operation";
@@ -188,10 +188,7 @@ def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
     }
   }];
 
-  // The custom parser could be replaced with oilist in LLVM-16
-  let parser = [{ return parseInsertSliceAsyncOp(parser, result); }];
-
-  let printer = [{ return printInsertSliceAsyncOp(p, *this); }];
+  let hasCustomAssemblyFormat = 1;
 }
 
 def TTG_AllocTensorOp : TTG_Op<"alloc_tensor", [MemoryEffects<[MemAlloc]>,  // Allocate shared memory
diff --git a/lib/Analysis/Alias.cpp b/lib/Analysis/Alias.cpp
index a39e4de9aa..208fdd4afc 100644
--- a/lib/Analysis/Alias.cpp
+++ b/lib/Analysis/Alias.cpp
@@ -18,8 +18,9 @@ AliasInfo AliasInfo::join(const AliasInfo &lhs, const AliasInfo &rhs) {
   return ret;
 }
 
-ChangeResult SharedMemoryAliasAnalysis::visitOperation(
-    Operation *op, ArrayRef<LatticeElement<AliasInfo> *> operands) {
+void SharedMemoryAliasAnalysis::visitOperation(
+    Operation *op, ArrayRef<const dataflow::Lattice<AliasInfo> *> operands,
+    ArrayRef<dataflow::Lattice<AliasInfo> *> results) {
   AliasInfo aliasInfo;
   bool pessimistic = true;
   if (maybeSharedAllocationOp(op)) {
@@ -44,14 +45,11 @@ ChangeResult SharedMemoryAliasAnalysis::visitOperation(
   }
 
   if (pessimistic) {
-    return markAllPessimisticFixpoint(op->getResults());
+    return markAllPessimisticFixpoint(results);
   }
   // Join all lattice elements
-  ChangeResult result = ChangeResult::NoChange;
-  for (Value value : op->getResults()) {
-    result |= getLatticeElement(value).join(aliasInfo);
-  }
-  return result;
+  for (auto *result : results)
+    propagateIfChanged(result, result->join(aliasInfo));
 }
 
 AliasResult SharedMemoryAliasAnalysis::alias(Value lhs, Value rhs) {
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 712c08c475..b4de8dcd9d 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -1,4 +1,5 @@
 #include "triton/Analysis/Allocation.h"
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Analysis/Liveness.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -33,10 +34,8 @@ constexpr int kPtrBitWidth = 64;
 
 static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
 getCvtOrder(const Attribute &srcLayout, const Attribute &dstLayout) {
-  auto srcBlockedLayout = srcLayout.dyn_cast<BlockedEncodingAttr>();
   auto srcMmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
   auto srcDotLayout = srcLayout.dyn_cast<DotOperandEncodingAttr>();
-  auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>();
   auto dstMmaLayout = dstLayout.dyn_cast<MmaEncodingAttr>();
   auto dstDotLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>();
   assert(!(srcMmaLayout && dstMmaLayout) &&
@@ -224,14 +223,12 @@ class AllocationAnalysis {
   }
 
   void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
-    LatticeElement<AliasInfo> *latticeElement =
-        analysis.lookupLatticeElement(value);
-    if (latticeElement) {
-      auto &info = latticeElement->getValue();
-      if (!info.getAllocs().empty()) {
-        for (auto alloc : info.getAllocs()) {
-          allocation->addAlias(value, alloc);
-        }
+    dataflow::Lattice<AliasInfo> *latticeElement =
+        analysis.getLatticeElement(value);
+    if (latticeElement && !latticeElement->isUninitialized()) {
+      AliasInfo &info = latticeElement->getValue();
+      for (auto alloc : info.getAllocs()) {
+        allocation->addAlias(value, alloc);
       }
     }
   }
@@ -244,14 +241,19 @@ class AllocationAnalysis {
       getScratchValueSize(op);
     });
     // Get the alias values
-    SharedMemoryAliasAnalysis aliasAnalysis(operation->getContext());
-    aliasAnalysis.run(operation);
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    SharedMemoryAliasAnalysis *aliasAnalysis =
+        solver->load<SharedMemoryAliasAnalysis>();
+    if (failed(solver->initializeAndRun(operation))) {
+      // TODO: return error instead of bailing out..
+      llvm_unreachable("failed to run SharedMemoryAliasAnalysis");
+    }
     operation->walk([&](Operation *op) {
       for (auto operand : op->getOperands()) {
-        getValueAlias(operand, aliasAnalysis);
+        getValueAlias(operand, *aliasAnalysis);
       }
       for (auto value : op->getResults()) {
-        getValueAlias(value, aliasAnalysis);
+        getValueAlias(value, *aliasAnalysis);
       }
     });
   }
diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp
index 0b7142b04d..4af46c3fbb 100644
--- a/lib/Analysis/AxisInfo.cpp
+++ b/lib/Analysis/AxisInfo.cpp
@@ -1,4 +1,4 @@
-#include "mlir/Analysis/DataFlowAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -52,7 +52,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
   BlockArgument blockArg = value.dyn_cast<BlockArgument>();
   if (blockArg && blockArg.getOwner()->isEntryBlock()) {
     Operation *op = blockArg.getOwner()->getParentOp();
-    if (FuncOp fun = dyn_cast<FuncOp>(op)) {
+    if (func::FuncOp fun = dyn_cast<func::FuncOp>(op)) {
       Attribute attr =
           fun.getArgAttr(blockArg.getArgNumber(), "tt.divisibility");
       if (attr)
@@ -136,8 +136,9 @@ class CastOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
 public:
   using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(OpTy op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(OpTy op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     return operands[0]->getValue();
   }
 };
@@ -147,8 +148,9 @@ class MakeRangeOpAxisInfoVisitor final
 public:
   using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(triton::MakeRangeOp op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(triton::MakeRangeOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     auto start = op.start();
     auto end = op.end();
     return AxisInfo(/*contiguity=*/{end - start},
@@ -162,8 +164,9 @@ class ConstantOpAxisInfoVisitor final
 public:
   using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(arith::ConstantOp op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(arith::ConstantOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
     auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
     if (intAttr || boolAttr) {
@@ -416,8 +419,9 @@ class SplatOpAxisInfoVisitor final
 public:
   using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(triton::SplatOp op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(triton::SplatOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     Type _retTy = *op->result_type_begin();
     TensorType retTy = _retTy.cast<TensorType>();
     AxisInfo opInfo = operands[0]->getValue();
@@ -439,8 +443,9 @@ class ExpandDimsOpAxisInfoVisitor final
 public:
   using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(triton::ExpandDimsOp op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(triton::ExpandDimsOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     AxisInfo opInfo = operands[0]->getValue();
     AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
     AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
@@ -458,8 +463,9 @@ class BroadcastOpAxisInfoVisitor final
 public:
   using AxisInfoVisitorImpl::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(triton::BroadcastOp op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(triton::BroadcastOp op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     Type _retTy = *op->result_type_begin();
     Type _opTy = *op->operand_type_begin();
     TensorType retTy = _retTy.cast<TensorType>();
@@ -486,8 +492,9 @@ class CmpOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
 public:
   using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(OpTy op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(OpTy op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
     if (!resTy)
       return AxisInfo();
@@ -596,8 +603,9 @@ class SelectOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
 public:
   using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(OpTy op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(OpTy op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
     if (!resTy)
       return AxisInfo();
@@ -757,8 +765,9 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
 public:
   using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
 
-  AxisInfo getAxisInfo(OpTy op,
-                       ArrayRef<LatticeElement<AxisInfo> *> operands) override {
+  AxisInfo
+  getAxisInfo(OpTy op,
+              ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
     auto lhsInfo = operands[0]->getValue();
     auto rhsInfo = operands[1]->getValue();
     std::optional<int64_t> constantValue;
@@ -786,8 +795,8 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
 // AxisInfoAnalysis
 //===----------------------------------------------------------------------===//
 
-AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
-    : ForwardDataFlowAnalysis<AxisInfo>(context) {
+AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
+    : dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
   // UnrealizedConversionCast:
   // This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
   // in the process of a PartialConversion, where UnrealizedConversionCast
@@ -819,7 +828,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(MLIRContext *context)
   visitors.append<LogicalOpAxisInfoVisitor<arith::AndIOp>,
                   LogicalOpAxisInfoVisitor<arith::OrIOp>,
                   LogicalOpAxisInfoVisitor<arith::XOrIOp>>();
-  visitors.append<SelectOpAxisInfoVisitor<mlir::SelectOp>,
+  visitors.append<SelectOpAxisInfoVisitor<arith::SelectOp>,
                   SelectOpAxisInfoVisitor<triton::gpu::SelectOp>>();
   visitors.append<ShROpAxisInfoVisitor<arith::ShRUIOp>,
                   ShROpAxisInfoVisitor<arith::ShRSIOp>>();
@@ -829,11 +838,12 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
                  MaxMinOpAxisInfoVisitor<arith::MinUIOp>>();
 }
 
-ChangeResult AxisInfoAnalysis::visitOperation(
-    Operation *op, ArrayRef<LatticeElement<AxisInfo> *> operands) {
+void AxisInfoAnalysis::visitOperation(
+    Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
+    ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
   AxisInfo curr = visitors.apply(op, operands);
   if (curr.getRank() == 0) {
-    return markAllPessimisticFixpoint(op->getResults());
+    return markAllPessimisticFixpoint(results);
   }
   // override with hint
   auto newContiguity = curr.getContiguity();
@@ -854,11 +864,8 @@ void AxisInfoAnalysis::visitOperation(
   curr = mlir::AxisInfo(newContiguity, newDivisibility, newConstancy,
                         curr.getConstantValue());
   // join all lattice elements
-  ChangeResult result = ChangeResult::NoChange;
-  for (Value value : op->getResults()) {
-    result |= getLatticeElement(value).join(curr);
-  }
-  return result;
+  for (auto *result : results)
+    propagateIfChanged(result, result->join(curr));
 }
 
 unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
@@ -884,7 +891,10 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
   auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
   if (!tensorTy)
     return 1;
-  auto axisInfo = lookupLatticeElement(ptr)->getValue();
+  dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
+  if (!latticeElement || latticeElement->isUninitialized())
+    return 1;
+  auto axisInfo = latticeElement->getValue();
   auto layout = tensorTy.getEncoding();
   auto order = triton::gpu::getOrder(layout);
   auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
@@ -900,8 +910,11 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
   auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
   if (!tensorTy)
     return 1;
+  dataflow::Lattice<AxisInfo> *latticeElement =
+      getLatticeElement(mask);
+  if (!latticeElement || latticeElement->isUninitialized())
+    return 1;
+  auto maskAxis = latticeElement->getValue();
   auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
-  auto maskAxis = lookupLatticeElement(mask)->getValue();
   auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
   return alignment;
 }
diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt
index afbc692510..1f761f845c 100644
--- a/lib/Analysis/CMakeLists.txt
+++ b/lib/Analysis/CMakeLists.txt
@@ -8,7 +8,7 @@ add_mlir_library(TritonAnalysis
   DEPENDS
     TritonTableGen
     TritonGPUAttrDefsIncGen
-  
+
   LINK_LIBS PUBLIC
     MLIRAnalysis
 )
diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp
index acc885e827..910274b2ac 100644
--- a/lib/Analysis/Membar.cpp
+++ b/lib/Analysis/Membar.cpp
@@ -2,7 +2,7 @@
 #include "triton/Analysis/Alias.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 
 namespace mlir {
diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp
index d9e917e731..6ea52df272 100644
--- a/lib/Analysis/Utility.cpp
+++ b/lib/Analysis/Utility.cpp
@@ -1,5 +1,8 @@
 #include "triton/Analysis/Utility.h"
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
+#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/Matchers.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include
@@ -325,4 +328,55 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
   return multiRootTopologicalSort(slice);
 }
 
+namespace {
+// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
+// interacts with constant propagation, but SparseConstantPropagation
+// doesn't seem to be sufficient.
+struct ConstantAnalysis : public DataFlowAnalysis {
+  using DataFlowAnalysis::DataFlowAnalysis;
+
+  LogicalResult initialize(Operation *top) override {
+    WalkResult result = top->walk([&](Operation *op) {
+      if (failed(visit(op)))
+        return WalkResult::interrupt();
+      return WalkResult::advance();
+    });
+    return success(!result.wasInterrupted());
+  }
+
+  LogicalResult visit(ProgramPoint point) override {
+    Operation *op = point.get<Operation *>();
+    Attribute value;
+    if (matchPattern(op, m_Constant(&value))) {
+      auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
+          op->getResult(0));
+      propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
+                                       value, op->getDialect())));
+      return success();
+    }
+    setAllToUnknownConstants(op->getResults());
+    for (Region &region : op->getRegions())
+      setAllToUnknownConstants(region.getArguments());
+    return success();
+  }
+
+  /// Set all given values as not constants.
+  void setAllToUnknownConstants(ValueRange values) {
+    dataflow::ConstantValue unknownConstant(nullptr, nullptr);
+    for (Value value : values) {
+      auto *constant =
+          getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
+      propagateIfChanged(constant, constant->join(unknownConstant));
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
+  auto solver = std::make_unique<DataFlowSolver>();
+  solver->load<dataflow::DeadCodeAnalysis>();
+  solver->load<ConstantAnalysis>();
+  return solver;
+}
+
 } // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index 6a46265bd7..e352eb3698 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -159,9 +159,6 @@ struct ConvertLayoutOpConversion
                            Value smemBase) const {
     auto accumNumCTAsEachRep = product<unsigned>(numCTAsEachRep);
     auto layout = type.getEncoding();
-    auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>();
-    auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>();
-    auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
     auto rank = type.getRank();
     auto sizePerThread = getSizePerThread(layout);
     auto accumSizePerThread = product<unsigned>(sizePerThread);
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
index 4b89965aa9..1d9e00519b 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpHelpers.h
@@ -7,10 +7,8 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Matchers.h"
@@ -422,9 +420,9 @@ struct MMA16816ConversionHelper {
   MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
                            Value thread, ConversionPatternRewriter &rewriter,
                            TypeConverter *typeConverter, Location loc)
-      : mmaLayout(mmaLayout), thread(thread), helper(mmaLayout),
-        rewriter(rewriter), typeConverter(typeConverter), loc(loc),
-        ctx(mmaLayout.getContext()), wpt(mmaLayout.getWarpsPerCTA()) {
+      : mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread),
+        helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter),
+        loc(loc), ctx(mmaLayout.getContext()) {
     helper.deduceMmaType(dotOperand);
 
     Value _32 = i32_val(32);
diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
index 0f8070ca9f..e4bd47c411 100644
--- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp
@@ -115,8 +115,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
     auto DTensorTy = D.getType().cast<RankedTensorType>();
     auto AShape = ATensorTy.getShape();
     auto BShape = BTensorTy.getShape();
-    auto DShape = DTensorTy.getShape();
-    auto wpt = mmaLayout.getWarpsPerCTA();
 
     bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
     bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
@@ -221,7 +219,6 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
                        ConversionPatternRewriter &rewriter) const {
     auto *ctx = rewriter.getContext();
     auto loc = op.getLoc();
-    auto threadId = getThreadId(rewriter, loc);
 
     auto A = op.a();
     auto B = op.b();
@@ -230,12 +227,10 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
 
     auto aTensorTy = A.getType().cast<RankedTensorType>();
     auto bTensorTy = B.getType().cast<RankedTensorType>();
-    auto cTensorTy = C.getType().cast<RankedTensorType>();
     auto dTensorTy = D.getType().cast<RankedTensorType>();
 
     auto aShape = aTensorTy.getShape();
     auto bShape = bTensorTy.getShape();
-    auto cShape = cTensorTy.getShape();
 
     BlockedEncodingAttr dLayout =
         dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
index deb71b9597..0b9e67674b 100644
--- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -61,7 +61,6 @@ struct FpToFpOpConversion
   convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                        const Value &v0, const Value &v1, const Value &v2,
                        const Value &v3) {
-    auto ctx = rewriter.getContext();
     auto fp16x2VecTy = vec_ty(f16_ty, 2);
     Value fp16x2Vec0 = undef(fp16x2VecTy);
     Value fp16x2Vec1 = undef(fp16x2VecTy);
@@ -153,7 +152,6 @@ struct FpToFpOpConversion
   convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
                        const Value &v0, const Value &v1, const Value &v2,
                        const Value &v3) {
-    auto ctx = rewriter.getContext();
     auto bf16x2VecTy = vec_ty(i16_ty, 2);
     Value bf16x2Vec0 = undef(bf16x2VecTy);
     Value bf16x2Vec1 = undef(bf16x2VecTy);
diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
index 9a8b4702bc..bae675f0cb 100644
--- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -109,7 +109,8 @@ struct LoadOpConversion
     DenseElementsAttr constAttr;
     int64_t splatVal = 0;
     if (other && valueElemTy.isa<IntegerType>() &&
-        matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat()) {
+        matchPattern(other, m_Constant(&constAttr)) && constAttr.isSplat() &&
+        constAttr.getElementType().isa<IntegerType>()) {
       otherIsSplatConstInt = true;
       splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
     }
@@ -333,7 +334,6 @@ struct StoreOpConversion
           elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
         elem = bitcast(elem, valueElemTy);
 
-        Type u32Ty = typeConverter->convertType(type::u32Ty(ctx));
         llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
       }
       llWord = bitcast(llWord, valArgTy);
@@ -387,7 +387,6 @@ struct AtomicCASOpConversion
                   ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
     MLIRContext *ctx = rewriter.getContext();
-    Value ptr = op.ptr();
 
     Value llPtr = adaptor.ptr();
     Value llCmp = adaptor.cmp();
diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
index 69abd889be..1c973dc196 100644
--- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
@@ -286,7 +286,6 @@ struct ReduceOpConversion
     auto srcTy = op.operand().getType().cast<RankedTensorType>();
     auto srcLayout = srcTy.getEncoding();
     auto srcShape = srcTy.getShape();
-    auto srcRank = srcTy.getRank();
     auto order = getOrder(srcLayout);
 
     auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout);
@@ -351,7 +350,6 @@ struct ReduceOpConversion
 
     Value zero = i32_val(0);
     Value laneZero = icmp_eq(laneIdAxis, zero);
-    Value warpZero = icmp_eq(warpIdAxis, zero);
 
     for (auto it : accs) {
       const SmallVector<unsigned> &key = it.first;
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
index 5b77150b1a..78cfa076bd 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
@@ -11,11 +11,11 @@ using ::mlir::LLVM::getStructFromElements;
 using ::mlir::triton::gpu::getElemsPerThread;
 using ::mlir::triton::gpu::SharedEncodingAttr;
 
-struct ReturnOpConversion : public ConvertOpToLLVMPattern<::mlir::ReturnOp> {
-  using ConvertOpToLLVMPattern<ReturnOp>::ConvertOpToLLVMPattern;
+struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
+  using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
 
   LogicalResult
-  matchAndRewrite(ReturnOp op, OpAdaptor adaptor,
+  matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     unsigned numArguments = op.getNumOperands();
@@ -476,7 +476,6 @@ struct ExtractSliceOpConversion
     auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto elemPtrTy = ptr_ty(llvmElemTy, 3);
-    auto resTy = op.getType().dyn_cast<RankedTensorType>();
     smemObj = SharedMemoryObject(gep(elemPtrTy, smemObj.base, offset),
                                  strideVals, offsetVals);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
index bb10d5b24a..00e399f848 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
@@ -4,6 +4,7 @@
 // TODO: refactor so that it doesn't fail if Allocation.h
 // is included after utility.h (due to conflict in `store` macro
 // and <atomic>
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "triton/Analysis/Allocation.h"
 
 //
@@ -39,15 +40,15 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
 
 // TODO(Superjomn): remove the code when MLIR v15.0 is included.
 // All the rights are reserved by the LLVM community.
-struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
+struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
 private:
   /// Only retain those attributes that are not constructed by
   /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
   /// attributes.
-  static void filterFuncAttributes(ArrayRef<NamedAttribute> attrs,
-                                   bool filterArgAttrs,
+  static void filterFuncAttributes(func::FuncOp op, bool filterArgAttrs,
                                    SmallVectorImpl<NamedAttribute> &result) {
-    for (const auto &attr : attrs) {
+
+    for (const auto &attr : op->getAttrs()) {
       if (attr.getName() == SymbolTable::getSymbolAttrName() ||
           attr.getName() == FunctionOpInterface::getTypeAttrName() ||
           attr.getName() == "std.varargs" ||
@@ -65,27 +66,27 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<FuncOp> {
   }
 
 protected:
-  using ConvertOpToLLVMPattern<FuncOp>::ConvertOpToLLVMPattern;
+  using ConvertOpToLLVMPattern<func::FuncOp>::ConvertOpToLLVMPattern;
 
   // Convert input FuncOp to LLVMFuncOp by using the LLVMTypeConverter provided
   // to this legalization pattern.
   LLVM::LLVMFuncOp
-  convertFuncOpToLLVMFuncOp(FuncOp funcOp,
+  convertFuncOpToLLVMFuncOp(func::FuncOp funcOp,
                             ConversionPatternRewriter &rewriter) const {
     // Convert the original function arguments. They are converted using the
     // LLVMTypeConverter provided to this legalization pattern.
     auto varargsAttr = funcOp->getAttrOfType<BoolAttr>("func.varargs");
     TypeConverter::SignatureConversion result(funcOp.getNumArguments());
     auto llvmType = getTypeConverter()->convertFunctionSignature(
-        funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
+        funcOp.getFunctionType(), varargsAttr && varargsAttr.getValue(),
+        result);
     if (!llvmType)
       return nullptr;
 
     // Propagate argument/result attributes to all converted arguments/result
     // obtained after converting a given original argument/result.
     SmallVector<NamedAttribute> attributes;
-    filterFuncAttributes(funcOp->getAttrs(), /*filterArgAttrs=*/true,
-                         attributes);
+    filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, attributes);
     if (ArrayAttr resAttrDicts = funcOp.getAllResultAttrs()) {
       assert(!resAttrDicts.empty() && "expected array to be non-empty");
       auto newResAttrDicts =
@@ -131,7 +132,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<func::FuncOp> {
     }
     auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
         funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
-        /*dsoLocal*/ false, attributes);
+        /*dsoLocal*/ false, LLVM::CConv::C, attributes);
     rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
                                 newFuncOp.end());
     if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
@@ -191,8 +192,8 @@ class ConvertTritonGPUOpToLLVMPatternBase {
                                       const Allocation *allocation, Value smem,
                                       IndexCacheInfo indexCacheInfo)
-      : converter(&typeConverter), indexCacheInfo(indexCacheInfo),
-        allocation(allocation), smem(smem) {}
+      : converter(&typeConverter), allocation(allocation), smem(smem),
+        indexCacheInfo(indexCacheInfo) {}
 
   LLVMTypeConverter *getTypeConverter() const { return converter; }
@@ -861,7 +862,6 @@ class ConvertTritonGPUOpToLLVMPatternBase {
                    ArrayRef<int64_t> shape) const {
     auto parent = sliceLayout.getParent();
     unsigned dim = sliceLayout.getDim();
-    size_t rank = shape.size();
     auto parentIndices =
         emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
     unsigned numIndices = parentIndices.size();
diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
index ff1af09835..6f66af4e34 100644
--- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
@@ -1,10 +1,11 @@
 #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
 
+#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
-#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Pass/Pass.h"
@@ -40,7 +41,6 @@ class TritonLLVMConversionTarget : public ConversionTarget {
     addIllegalDialect<triton::TritonDialect>();
     addIllegalDialect<triton::gpu::TritonGPUDialect>();
     addIllegalDialect<mlir::gpu::GPUDialect>();
-    addIllegalDialect<mlir::StandardOpsDialect>();
     addLegalOp<mlir::UnrealizedConversionCastOp>();
   }
 };
@@ -51,7 +51,7 @@ class TritonLLVMFunctionConversionTarget : public ConversionTarget {
       : ConversionTarget(ctx) {
     addLegalDialect<LLVM::LLVMDialect>();
     addLegalDialect<NVVM::NVVMDialect>();
-    addIllegalOp<mlir::FuncOp>();
+    addIllegalOp<func::FuncOp>();
     addLegalOp<mlir::UnrealizedConversionCastOp>();
   }
 };
@@ -69,7 +69,7 @@ struct FuncOpConversion : public FuncOpConversionBase {
       : FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
 
   LogicalResult
-  matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
+  matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
     if (!newFuncOp)
@@ -133,7 +133,8 @@ class ConvertTritonGPUToLLVM
     decomposeBlockedToDotOperand(mod);
 
     // Step 2
-    decomposeInsertSliceAsyncOp(mod);
+    if (failed(decomposeInsertSliceAsyncOp(mod)))
+      return signalPassFailure();
 
     // Step 3
     Allocation allocation(mod);
@@ -142,7 +143,7 @@ class ConvertTritonGPUToLLVM
 
     // Step 4
     RewritePatternSet scf_patterns(context);
-    mlir::populateLoopToStdConversionPatterns(scf_patterns);
+    mlir::populateSCFToControlFlowConversionPatterns(scf_patterns);
     mlir::ConversionTarget scf_target(*context);
     scf_target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp,
                             scf::WhileOp, scf::ExecuteRegionOp>();
@@ -159,8 +160,10 @@ class ConvertTritonGPUToLLVM
       return signalPassFailure();
 
     // Step 6 - get axis and shared memory info
-    AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
-    axisInfoAnalysis.run(mod);
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
+    if (failed(solver->initializeAndRun(mod)))
+      return signalPassFailure();
     initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
     mod->setAttr("triton_gpu.shared",
                  mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
@@ -178,38 +181,39 @@ class ConvertTritonGPUToLLVM
 
     // Normal conversions
     populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
-                                    axisInfoAnalysis, &allocation, smem,
+                                    *axisInfoAnalysis, &allocation, smem,
                                     indexCacheInfo, /*benefit=*/10);
     // ConvertLayoutOp
     populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
-                                          axisInfoAnalysis, &allocation, smem,
+                                          *axisInfoAnalysis, &allocation, smem,
                                           indexCacheInfo, /*benefit=*/10);
     // DotOp
     populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
-                                axisInfoAnalysis, &allocation, smem,
+                                *axisInfoAnalysis, &allocation, smem,
                                 /*benefit=*/10);
     // ElementwiseOp
     populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
-                                        axisInfoAnalysis, &allocation, smem,
+                                        *axisInfoAnalysis, &allocation, smem,
                                         /*benefit=*/10);
     // LoadStoreOp
     populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
-                                      axisInfoAnalysis, &allocation, smem,
+                                      *axisInfoAnalysis, &allocation, smem,
                                       indexCacheInfo, /*benefit=*/10);
     // ReduceOp
     populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
-                                   axisInfoAnalysis, &allocation, smem,
+                                   *axisInfoAnalysis, &allocation, smem,
                                    indexCacheInfo, /*benefit=*/10);
     // ViewOp
     populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
-                                 axisInfoAnalysis, &allocation, smem,
+                                 *axisInfoAnalysis, &allocation, smem,
                                  /*benefit=*/10);
 
     // Add arith/math's patterns to help convert scalar expression to LLVM.
     mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
                                                             patterns);
     mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
-    mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
+    mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
+                                                          patterns);
     mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
 
     if (failed(applyPartialConversion(mod, target, std::move(patterns))))
@@ -306,9 +310,11 @@ class ConvertTritonGPUToLLVM
     });
   }
 
-  void decomposeInsertSliceAsyncOp(ModuleOp mod) const {
-    AxisInfoAnalysis axisInfoAnalysis(mod.getContext());
-    axisInfoAnalysis.run(mod);
+  LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
+    if (failed(solver->initializeAndRun(mod)))
+      return failure();
     // TODO(Keren): This is a hacky knob that may cause performance regression
     // when decomposition has been performed. We should remove this knob once we
     // have thorough analysis on async wait. Currently, we decompose
@@ -342,7 +348,7 @@ class ConvertTritonGPUToLLVM
       auto resSharedLayout =
          dstTy.getEncoding().dyn_cast<SharedEncodingAttr>();
       auto resElemTy = dstTy.getElementType();
-      unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
+      unsigned inVec = axisInfoAnalysis->getPtrContiguity(src);
       unsigned outVec = resSharedLayout.getVec();
       unsigned minVec = std::min(outVec, inVec);
       auto maxBitWidth =
@@ -400,11 +406,11 @@ class ConvertTritonGPUToLLVM
       } else if (decomposed) {
         // Wait for all previous async ops
         OpBuilder builder(asyncWaitOp);
-        auto newAsyncWaitOp =
-            builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
+        builder.create<triton::gpu::AsyncWaitOp>(asyncWaitOp.getLoc(), 0);
         asyncWaitOp.erase();
       }
     });
+    return success();
   }
 };
diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h
index d35dac28c5..11976908cf 100644
--- a/lib/Conversion/TritonGPUToLLVM/Utility.h
+++ b/lib/Conversion/TritonGPUToLLVM/Utility.h
@@ -220,10 +220,7 @@ struct SharedMemoryObject {
                      ConversionPatternRewriter &rewriter)
       : base(base) {
     strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
-
-    for (auto idx : order) {
-      offsets.emplace_back(i32_val(0));
-    }
+    offsets.append(order.size(), i32_val(0));
   }
 
   SmallVector<Value> getElems() const {
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index fe42202c34..5f230f787f 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -1,10 +1,10 @@
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
-#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -59,10 +59,13 @@ class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
     Type retType = getTypeConverter()->convertType(op.getType());
     auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
     assert(value);
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
-        op, retType,
-        value.reshape(retType) // This is a hack. We just want to add encoding
-    );
+    if (value.getElementType().isInteger(1) && value.isSplat())
+      // Workaround until https://reviews.llvm.org/D133743 is included.
+      value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
+    else
+      // This is a hack. We just want to add encoding
+      value = value.reshape(retType);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value);
     return success();
   }
 };
@@ -127,12 +130,12 @@ void populateArithmeticPatternsAndLegality(
 }
 
 // this shouldn't exist if mlir's SelectOp checked encodings properly
-class StdSelectPattern : public OpConversionPattern<SelectOp> {
+class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
 public:
-  using OpConversionPattern<SelectOp>::OpConversionPattern;
+  using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(SelectOp op, typename SelectOp::Adaptor adaptor,
+  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type retType = this->getTypeConverter()->convertType(op.getType());
     rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
@@ -148,8 +151,8 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
   MLIRContext *context = patterns.getContext();
   // Rewrite rule
   patterns.add<StdSelectPattern>(typeConverter, context);
-  target.addLegalOp<ReturnOp>(); // this is ok because all functions are inlined
-                                 // by the frontend
+  target.addLegalOp<func::ReturnOp>(); // this is ok because all functions are
+                                       // inlined by the frontend
 }
 
 void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
@@ -455,18 +458,19 @@ struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
 void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                             RewritePatternSet &patterns) {
   MLIRContext *context = patterns.getContext();
-  patterns.add< // TODO: view should have custom pattern that views the layout
-      TritonGenericPattern<triton::ViewOp>,
-      TritonGenericPattern<triton::BitcastOp>,
-      TritonGenericPattern<triton::FpToFpOp>,
-      TritonGenericPattern<triton::IntToPtrOp>,
-      TritonGenericPattern<triton::PtrToIntOp>,
-      TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
-      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
-      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
-      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
-      TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
-      TritonAtomicRMWPattern>(typeConverter, context);
+  patterns
+      .insert< // TODO: view should have custom pattern that views the layout
+          TritonGenericPattern<triton::ViewOp>,
+          TritonGenericPattern<triton::BitcastOp>,
+          TritonGenericPattern<triton::FpToFpOp>,
+          TritonGenericPattern<triton::IntToPtrOp>,
+          TritonGenericPattern<triton::PtrToIntOp>,
+          TritonGenericPattern<triton::SplatOp>, TritonBroadcastPattern,
+          TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
+          TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
+          TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
+          TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
+          TritonAtomicRMWPattern>(typeConverter, context);
 }
 
 //
@@ -623,29 +627,28 @@ void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
 
 // CF
 
-class CFBranchPattern : public OpConversionPattern<BranchOp> {
+class CFBranchPattern : public OpConversionPattern<cf::BranchOp> {
 public:
-  using OpConversionPattern<BranchOp>::OpConversionPattern;
+  using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(BranchOp op, BranchOp::Adaptor adaptor,
+  matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto converter = getTypeConverter();
-    auto newOp = rewriter.replaceOpWithNewOp<BranchOp>(op, op.getSuccessor(),
-                                                       adaptor.getOperands());
+    auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
+        op, op.getSuccessor(), adaptor.getOperands());
     return success();
} }; -class CFCondBranchPattern : public OpConversionPattern { +class CFCondBranchPattern : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CondBranchOp op, CondBranchOp::Adaptor adaptor, + matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto converter = getTypeConverter(); - auto newOp = rewriter.replaceOpWithNewOp( + auto newOp = rewriter.replaceOpWithNewOp( op, adaptor.getCondition(), op.getTrueDest(), adaptor.getTrueDestOperands(), op.getFalseDest(), adaptor.getFalseDestOperands()); diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt index 2d679b21fd..705554ba6b 100644 --- a/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -10,11 +10,7 @@ add_mlir_dialect_library(TritonIR LINK_LIBS PUBLIC MLIRIR - MLIRArithmetic - MLIRSCF - - # Since LLVM 15 - # MLIRFunc - # else - MLIRStandard + MLIRArithmeticDialect + MLIRSCFDialect + MLIRFuncDialect ) diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 3aadbfa0c0..86570359c5 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1,10 +1,9 @@ -#include "triton/Dialect/Triton/IR/Dialect.h" -#include "triton/Dialect/Triton/IR/Types.h" - #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Types.h" namespace mlir { namespace triton { @@ -38,8 +37,8 @@ static Type getPointerTypeSameShape(Type type) { } // Parser & printer for assembly forms -ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { - SmallVector allOperands; +ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector allOperands; Type resultTypes[1]; SMLoc allOperandLoc = parser.getCurrentLocation(); if (parser.parseOperandList(allOperands) || @@ -73,18 +72,18 @@ ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { return success(); } -void printLoadOp(OpAsmPrinter &printer, LoadOp loadOp) { +void LoadOp::print(OpAsmPrinter &printer) { printer << " "; - printer << loadOp.getOperation()->getOperands(); + printer << getOperation()->getOperands(); // "operand_segment_sizes" can be deduced, so we don't print it. 
- printer.printOptionalAttrDict(loadOp->getAttrs(), - {loadOp.operand_segment_sizesAttrName()}); + printer.printOptionalAttrDict(getOperation()->getAttrs(), + {operand_segment_sizesAttrName()}); printer << " : "; - printer.printStrippedAttrOrType(loadOp.result().getType()); + printer.printStrippedAttrOrType(getResult().getType()); } -ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { - SmallVector allOperands; +ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector allOperands; Type valueType; SMLoc allOperandLoc = parser.getCurrentLocation(); if (parser.parseOperandList(allOperands) || @@ -104,12 +103,12 @@ ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) { return success(); } -void printStoreOp(OpAsmPrinter &printer, StoreOp storeOp) { +void StoreOp::print(OpAsmPrinter &printer) { printer << " "; - printer << storeOp.getOperation()->getOperands(); - printer.printOptionalAttrDict(storeOp->getAttrs(), /*elidedAttrs=*/{}); + printer << getOperation()->getOperands(); + printer.printOptionalAttrDict(getOperation()->getAttrs(), /*elidedAttrs=*/{}); printer << " : "; - printer.printStrippedAttrOrType(storeOp.value().getType()); + printer.printStrippedAttrOrType(value().getType()); } } // namespace triton @@ -319,7 +318,8 @@ OpFoldResult SplatOp::fold(ArrayRef operands) { if (!constOperand) return {}; auto shapedType = getType().cast(); - auto ret = SplatElementsAttr::get(shapedType, {constOperand.getValue()}); + auto ret = SplatElementsAttr::get( + shapedType, ArrayRef(constOperand.getValue())); return ret; } diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 2261472170..11570283d6 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -57,13 +57,13 @@ DenseElementsAttr getConstantValue(Builder &builder, Attribute value, class CombineSelectMaskedLoadPattern : public mlir::RewritePattern { public: CombineSelectMaskedLoadPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(mlir::SelectOp::getOperationName(), 3, context, - {triton::LoadOp::getOperationName()}) {} + : mlir::RewritePattern(mlir::arith::SelectOp::getOperationName(), 3, + context, {triton::LoadOp::getOperationName()}) {} mlir::LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { - auto selectOp = llvm::dyn_cast(op); + auto selectOp = llvm::dyn_cast(op); if (!selectOp) return mlir::failure(); diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 14f286b26e..ded0e346e6 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -1,9 +1,9 @@ #ifndef TRITON_PATTERNS #define TRITON_PATTERNS -include "mlir/Dialect/StandardOps/IR/Ops.td" include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" // AddIOp(DotOp(a, b, c), d) and c==0 => DotOp(a, b, d) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 1fbc609e88..bfc3f3d3da 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1,14 +1,14 @@ +#include "triton/Dialect/Triton/IR/Dialect.h" + #include #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/OpImplementation.h" #include "triton/Analysis/Utility.h" -#include "triton/Dialect/Triton/IR/Dialect.h" +#include 
"triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/TypeSwitch.h" -#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc" - using namespace mlir; using namespace mlir::triton::gpu; @@ -366,7 +366,6 @@ template SmallVector SliceEncodingAttr::paddedShape(ArrayRef shape) const; unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef shape) const { - size_t rank = shape.size(); auto parent = getParent(); return ::getElemsPerThread(parent, paddedShape(shape)); } @@ -655,9 +654,9 @@ void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { // InsertSliceAsyncOp //===----------------------------------------------------------------------===// -ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, - OperationState &result) { - SmallVector allOperands; +ParseResult InsertSliceAsyncOp::parse(OpAsmParser &parser, + OperationState &result) { + SmallVector allOperands; Type srcType, dstType; SMLoc allOperandLoc = parser.getCurrentLocation(); if (parser.parseOperandList(allOperands) || @@ -696,18 +695,16 @@ ParseResult parseInsertSliceAsyncOp(OpAsmParser &parser, return success(); } -void printInsertSliceAsyncOp(OpAsmPrinter &printer, - InsertSliceAsyncOp insertSliceAsyncOp) { +void InsertSliceAsyncOp::print(OpAsmPrinter &printer) { printer << " "; - printer << insertSliceAsyncOp.getOperation()->getOperands(); + printer << getOperation()->getOperands(); // "operand_segment_sizes" can be deduced, so we don't print it. - printer.printOptionalAttrDict( - insertSliceAsyncOp->getAttrs(), - {insertSliceAsyncOp.operand_segment_sizesAttrName()}); + printer.printOptionalAttrDict(getOperation()->getAttrs(), + {operand_segment_sizesAttrName()}); printer << " : "; - printer.printStrippedAttrOrType(insertSliceAsyncOp.src().getType()); + printer.printStrippedAttrOrType(src().getType()); printer << " -> "; - printer.printStrippedAttrOrType(insertSliceAsyncOp.result().getType()); + printer.printStrippedAttrOrType(result().getType()); } //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 82407980d3..ee6009f44a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -27,7 +27,11 @@ struct CoalescePass : public TritonGPUCoalesceBase { auto origType = ptr.getType().cast(); // Get the shape of the tensor. size_t rank = origType.getRank(); - AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); + dataflow::Lattice *latticeElement = + axisInfo.getLatticeElement(ptr); + AxisInfo info = latticeElement && !latticeElement->isUninitialized() + ? 
latticeElement->getValue() + : AxisInfo(); // Get the contiguity order of `ptr` auto order = argSort(info.getContiguity()); // The desired divisibility is the maximum divisibility @@ -40,7 +44,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { for (Value val : op->getResults()) { if (val.getType() != origType) continue; - auto valInfo = axisInfo.lookupLatticeElement(val); + auto valInfo = axisInfo.getLatticeElement(val); auto currOrder = argSort(valInfo->getValue().getContiguity()); if (order == currOrder) withSameOrder.insert(val); @@ -55,7 +59,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned perThread = 1; for (Value val : withSameOrder) { - AxisInfo info = axisInfo.lookupLatticeElement(val)->getValue(); + AxisInfo info = axisInfo.getLatticeElement(val)->getValue(); unsigned maxMultipleBytes = info.getDivisibility(order[0]); unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); unsigned maxContig = info.getContiguity(order[0]); @@ -123,8 +127,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { void runOnOperation() override { Operation *op = getOperation(); // Run axis info analysis - AxisInfoAnalysis axisInfo(&getContext()); - axisInfo.run(op); + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *axisInfo = solver->load(); + if (failed(solver->initializeAndRun(op))) + return signalPassFailure(); // For each i/o operation, we determine what layout // the pointers should have for best memory coalescing @@ -146,10 +152,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { RankedTensorType ty = ptr.getType().template dyn_cast(); if (!ty || !ty.getElementType().isa()) return; - AxisInfo info = axisInfo.lookupLatticeElement(ptr)->getValue(); + AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue(); auto mod = curr->getParentOfType(); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - auto convertType = getTypeConverter(axisInfo, ptr, numWarps); + auto convertType = getTypeConverter(*axisInfo, ptr, numWarps); layoutMap[ptr] = convertType; }); diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.cpp b/lib/Dialect/TritonGPU/Transforms/Combine.cpp index efa37ff2dc..089ce3996c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Combine.cpp @@ -1,6 +1,6 @@ #include "Utility.h" #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" diff --git a/lib/Dialect/TritonGPU/Transforms/Combine.td b/lib/Dialect/TritonGPU/Transforms/Combine.td index 6bf1b14866..6a7b10dbcb 100644 --- a/lib/Dialect/TritonGPU/Transforms/Combine.td +++ b/lib/Dialect/TritonGPU/Transforms/Combine.td @@ -3,5 +3,6 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUOps.td" include "triton/Dialect/Triton/IR/TritonOps.td" +include "mlir/IR/PatternBase.td" #endif diff --git a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp index 4bd3bc76bf..b2f8defd81 100644 --- a/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/DecomposeConversions.cpp @@ -1,5 +1,5 @@ #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinAttributes.h" #include 
"mlir/IR/Matchers.h" diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 9b2f42231e..85f746c1dc 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -2,6 +2,7 @@ #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/TypeUtilities.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -160,15 +161,18 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, LogicalResult LoopPipeliner::initialize() { Block *loop = forOp.getBody(); - AxisInfoAnalysis axisInfoAnalysis(forOp.getContext()); - axisInfoAnalysis.run(forOp->getParentOfType()); + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *axisInfoAnalysis = solver->load(); + if (failed(solver->initializeAndRun(forOp->getParentOfType()))) { + return failure(); + } // can we use forOp.walk(...) here? SmallVector allLoads; for (Operation &op : *loop) if (auto loadOp = dyn_cast(&op)) { auto ptr = loadOp.ptr(); - unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); + unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr); auto tensorTy = ptr.getType().dyn_cast(); if (!tensorTy) continue; diff --git a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp index 0e7dbe5264..b95a4f50a6 100644 --- a/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp @@ -1,5 +1,5 @@ #include "mlir/Analysis/SliceAnalysis.h" -#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Matchers.h" diff --git a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp index 37ac710995..762e887f36 100644 --- a/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp +++ b/lib/Dialect/TritonGPU/Transforms/TritonGPUConversion.cpp @@ -82,12 +82,12 @@ TritonGPUConversionTarget::TritonGPUConversionTarget( scf::ReduceReturnOp>(); addDynamicallyLegalDialect([&](Operation *op) { - if (typeConverter.isLegal(op)) - return true; - return false; - }); + triton::TritonDialect, scf::SCFDialect>( + [&](Operation *op) { + if (typeConverter.isLegal(op)) + return true; + return false; + }); // We have requirements for the data layouts addDynamicallyLegalOp([](triton::DotOp dotOp) -> bool { diff --git a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp index c229104286..c911fd4a5c 100644 --- a/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp +++ b/lib/Dialect/TritonGPU/Transforms/UpdateMmaForVolta.cpp @@ -1,5 +1,5 @@ #include "Utility.h" -#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -118,8 +118,8 @@ void setOpResultType(Operation *op, ArrayRef newTypes) { .get("value") .dyn_cast(); if (attr) { - auto newAttr = mlir::DenseElementsAttr::getFromRawBuffer( - newType, attr.getRawData(), true); + auto newAttr = + mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData()); op->setAttr("value", newAttr); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp 
b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index ed15f02f67..6400f1633a 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -1,5 +1,5 @@ #include "Utility.h" -#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index f1bbd0bf4e..ac8973ad19 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -6,8 +6,7 @@ add_mlir_translation_library(TritonLLVMIR LINK_LIBS PUBLIC MLIRIR - MLIRLLVMIR - MLIRSCFToStandard + MLIRLLVMDialect MLIRSupport MLIRTargetLLVMIRExport ) diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index 4cb0d8193c..6a5453a6e7 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -1,11 +1,14 @@ #include "triton/Target/PTX/PTXTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" +#include #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LegacyPassManager.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" #include "llvm/MC/TargetRegistry.h" +#include "llvm/Pass.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/TargetSelect.h" #include "llvm/Target/TargetMachine.h" diff --git a/python/setup.py b/python/setup.py index 2ac3accd25..4530b36714 100644 --- a/python/setup.py +++ b/python/setup.py @@ -57,19 +57,10 @@ def get_pybind11_package_info(): def get_llvm_package_info(): # download if nothing is installed system = platform.system() - if system == "Darwin": - system_suffix = "apple-darwin" - elif system == "Linux": - vglibc = tuple(map(int, platform.libc_ver()[1].split('.'))) - vglibc = vglibc[0] * 100 + vglibc[1] - linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7' - system_suffix = f"linux-gnu-{linux_suffix}" - else: - raise RuntimeError(f"unsupported system: {system}") + system_suffix = {"Linux": "linux-gnu-ubuntu-18.04", "Darwin": "apple-darwin"}[system] use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") - release_suffix = "assert" if use_assert_enabled_llvm else "release" - name = f'llvm+mlir-14.0.6-x86_64-{system_suffix}-{release_suffix}' - url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-14.0.6-f28c006a5895/{name}.tar.xz" + name = 'llvm+mlir-15.0.7-x86_64-{}-{}'.format(system_suffix, "assert" if use_assert_enabled_llvm else "release") + url = "https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-15.0.7-8dfdcc7b7bf6/{}.tar.xz".format(name) return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") diff --git a/python/src/triton.cc b/python/src/triton.cc index c40b117a55..f190eacc34 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -8,9 +8,10 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "mlir/Parser.h" +#include "mlir/Parser/Parser.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "triton/Analysis/Allocation.h" #include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h" @@ -195,7 +196,7 @@ void init_triton_ir(py::module &&m) { std::string attrName = name + "_arg" + std::to_string(id); mlir::Block *owner = arg.getOwner(); if (owner->isEntryBlock() && - !mlir::isa(owner->getParentOp())) { + 
!mlir::isa(owner->getParentOp())) { owner->getParentOp()->setAttr(attrName, attr); } } @@ -348,7 +349,7 @@ void init_triton_ir(py::module &&m) { return str; }) .def("push_back", - [](mlir::ModuleOp &self, mlir::FuncOp &funcOp) -> void { + [](mlir::ModuleOp &self, mlir::func::FuncOp &funcOp) -> void { self.push_back(funcOp); }) .def("has_function", @@ -358,16 +359,18 @@ void init_triton_ir(py::module &&m) { return false; }) .def("get_function", - [](mlir::ModuleOp &self, std::string &funcName) -> mlir::FuncOp { - return self.lookupSymbol(funcName); - }) - .def("get_single_function", [](mlir::ModuleOp &self) -> mlir::FuncOp { - llvm::SmallVector funcs; - self.walk([&](mlir::FuncOp func) { funcs.push_back(func); }); - if (funcs.size() != 1) - throw std::runtime_error("Expected a single function"); - return funcs[0]; - }); + [](mlir::ModuleOp &self, + std::string &funcName) -> mlir::func::FuncOp { + return self.lookupSymbol(funcName); + }) + .def("get_single_function", + [](mlir::ModuleOp &self) -> mlir::func::FuncOp { + llvm::SmallVector funcs; + self.walk([&](mlir::func::FuncOp func) { funcs.push_back(func); }); + if (funcs.size() != 1) + throw std::runtime_error("Expected a single function"); + return funcs[0]; + }); m.def("make_attr", [](const std::vector &values, mlir::MLIRContext &context) { @@ -388,47 +391,48 @@ void init_triton_ir(py::module &&m) { registry.insert(); + mlir::func::FuncDialect, mlir::scf::SCFDialect>(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); // parse module - mlir::OwningOpRef module( - mlir::parseSourceFile(inputFilename, &context)); + mlir::OwningOpRef module = + mlir::parseSourceFile(inputFilename, &context); + if (!module) + throw std::runtime_error("Parse MLIR file failed."); // locations are incompatible with ptx < 7.5 ! 
module->walk([](mlir::Operation *op) { op->setLoc(mlir::UnknownLoc::get(op->getContext())); }); - if (!module) - throw std::runtime_error("Parse MLIR file failed."); return module->clone(); }, ret::take_ownership); - py::class_(m, "function") + py::class_(m, "function") // .def_property_readonly("attrs", &ir::function::attrs) // .def("add_attr", &ir::function::add_attr); .def("args", - [](mlir::FuncOp &self, unsigned idx) -> mlir::BlockArgument { + [](mlir::func::FuncOp &self, unsigned idx) -> mlir::BlockArgument { return self.getArgument(idx); }) .def( "add_entry_block", - [](mlir::FuncOp &self) -> mlir::Block * { + [](mlir::func::FuncOp &self) -> mlir::Block * { return self.addEntryBlock(); }, ret::reference) .def( "set_arg_attr", - [](mlir::FuncOp &self, int arg_no, const std::string &name, int val) { + [](mlir::func::FuncOp &self, int arg_no, const std::string &name, + int val) { // set arg attributes "name" to value "val" auto attrTy = mlir::IntegerType::get(self.getContext(), 32); self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val)); }, ret::reference) - .def_property_readonly("type", &mlir::FuncOp::getType) - .def("reset_type", &mlir::FuncOp::setType); + .def_property_readonly("type", &mlir::func::FuncOp::getFunctionType) + .def("reset_type", &mlir::func::FuncOp::setType); py::class_(m, "InsertPoint"); @@ -445,13 +449,13 @@ void init_triton_ir(py::module &&m) { .def("ret", [](mlir::OpBuilder &self, std::vector &vals) -> void { auto loc = self.getUnknownLoc(); - self.create(loc, vals); + self.create(loc, vals); }) .def("call", - [](mlir::OpBuilder &self, mlir::FuncOp &func, + [](mlir::OpBuilder &self, mlir::func::FuncOp &func, std::vector &args) -> mlir::OpState { auto loc = self.getUnknownLoc(); - return self.create(loc, func, args); + return self.create(loc, func, args); }) // insertion block/point .def("set_insertion_point_to_start", @@ -618,15 +622,16 @@ void init_triton_ir(py::module &&m) { .def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module, std::string &funcName, mlir::Type &funcType, - std::string &visibility) -> mlir::FuncOp { + std::string &visibility) -> mlir::func::FuncOp { if (mlir::Operation *funcOperation = module.lookupSymbol(funcName)) - return llvm::dyn_cast(funcOperation); + return llvm::dyn_cast(funcOperation); auto loc = self.getUnknownLoc(); if (auto funcTy = funcType.dyn_cast()) { llvm::SmallVector attrs = { mlir::NamedAttribute(self.getStringAttr("sym_visibility"), self.getStringAttr(visibility))}; - return self.create(loc, funcName, funcTy, attrs); + return self.create(loc, funcName, funcTy, + attrs); } throw std::runtime_error("invalid function type"); }) @@ -658,15 +663,15 @@ void init_triton_ir(py::module &&m) { [](mlir::OpBuilder &self, mlir::Value condition, mlir::Block *trueDest, mlir::Block *falseDest) { auto loc = self.getUnknownLoc(); - self.create(loc, condition, trueDest, - falseDest); + self.create(loc, condition, trueDest, + falseDest); return; }) .def("create_branch", [](mlir::OpBuilder &self, mlir::Block *dest, std::vector &args) { auto loc = self.getUnknownLoc(); - self.create(loc, dest, args); + self.create(loc, dest, args); return; }) // Structured control flow @@ -792,14 +797,14 @@ void init_triton_ir(py::module &&m) { .def("create_to_index", [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, input, - self.getIndexType()); + return self.create( + loc, self.getIndexType(), input); }) .def("create_index_to_si", [](mlir::OpBuilder 
&self, mlir::Value &input) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, input, - self.getI32Type()); + return self.create( + loc, self.getI32Type(), input); }) .def("create_fmul", [](mlir::OpBuilder &self, mlir::Value &lhs, @@ -1316,8 +1321,8 @@ void init_triton_ir(py::module &&m) { [](mlir::OpBuilder &self, mlir::Value &condition, mlir::Value &trueValue, mlir::Value &falseValue) -> mlir::Value { auto loc = self.getUnknownLoc(); - return self.create(loc, condition, trueValue, - falseValue); + return self.create(loc, condition, + trueValue, falseValue); }) .def("create_printf", [](mlir::OpBuilder &self, const std::string &prefix, @@ -1429,7 +1434,7 @@ void init_triton_ir(py::module &&m) { self.addPass(mlir::triton::createConvertTritonGPUToLLVMPass()); }) .def("add_scf_to_cfg", [](mlir::PassManager &self) { - self.addPass(mlir::createLowerToCFGPass()); + self.addPass(mlir::createConvertSCFToCFPass()); }); } diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 432544a8a4..018f544714 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1918,7 +1918,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'): #dst = {dst_layout} """ + """ module attributes {"triton_gpu.num-warps" = 4 : i32} { - func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { + func.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { %cst = arith.constant dense<128> : tensor<128x1xi32, #src> %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>> diff --git a/python/triton/compiler.py b/python/triton/compiler.py index 5d167634df..c36589037c 100644 --- a/python/triton/compiler.py +++ b/python/triton/compiler.py @@ -1514,14 +1514,14 @@ def make_hash(fn, **kwargs): return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest() -# - ^\s*func\s+ : match the start of the string, any leading whitespace, the keyword func, +# - ^\s*func\.func\s+ : match the start of the string, any leading whitespace, the keyword func, # and any following whitespace # - (public\s+)? : optionally match the keyword public and any following whitespace # - (@\w+) : match an @ symbol followed by one or more word characters # (letters, digits, or underscores), and capture it as group 1 (the function name) # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing # zero or more arguments separated by commas, and capture it as group 2 (the argument list) -mlir_prototype_pattern = r'^\s*func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' +mlir_prototype_pattern = r'^\s*func\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" prototype_pattern = { "ttir": mlir_prototype_pattern, diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index b3d5673f85..bb21615e68 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -11,7 +11,7 @@ // CHECK-LABEL: matmul_loop // There shouldn't be any aliasing with the dot op encoding. 
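
Most of the churn in the tests below is the assembly rename from `func` to
`func.func`: LLVM 15 moved the builtin `FuncOp` into the `func` dialect, so
every textual kernel gains the dialect prefix and C++ code now spells it
`mlir::func::FuncOp`. A minimal sketch of the C++ that produces this syntax
(the helper name `makeEmptyKernel` is illustrative, not part of this patch):

    #include "mlir/Dialect/Func/IR/FuncOps.h"
    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinOps.h"

    // Builds a no-op function that round-trips as `func.func @<name>() {...}`.
    static mlir::func::FuncOp makeEmptyKernel(mlir::ModuleOp module,
                                              llvm::StringRef name) {
      mlir::OpBuilder builder(module.getBodyRegion());
      auto funcType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
      auto func = builder.create<mlir::func::FuncOp>(module.getLoc(), name,
                                                     funcType);
      mlir::Block *entry = func.addEntryBlock();
      builder.setInsertionPointToEnd(entry);
      builder.create<mlir::func::ReturnOp>(module.getLoc());
      return func;
    }
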
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
   %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
   %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
@@ -36,7 +36,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
 }
 
 // CHECK-LABEL: alloc
-func @alloc(%A : !tt.ptr<f16>) {
+func.func @alloc(%A : !tt.ptr<f16>) {
   // CHECK: %cst -> %cst
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
@@ -46,7 +46,7 @@ func @alloc(%A : !tt.ptr<f16>) {
 }
 
 // CHECK-LABEL: convert
-func @convert(%A : !tt.ptr<f16>) {
+func.func @convert(%A : !tt.ptr<f16>) {
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
   // CHECK: %0 -> %0
   %cst1 = triton_gpu.convert_layout %cst0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
@@ -54,7 +54,7 @@ func @convert(%A : !tt.ptr<f16>) {
 }
 
 // CHECK-LABEL: trans
-func @trans(%A : !tt.ptr<f16>) {
+func.func @trans(%A : !tt.ptr<f16>) {
   // CHECK: %cst -> %cst
   %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
   // CHECK: %0 -> %cst
@@ -63,7 +63,7 @@ func @trans(%A : !tt.ptr<f16>) {
 }
 
 // CHECK-LABEL: insert_slice_async
-func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
+func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
   %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
   %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -76,7 +76,7 @@ func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
 }
 
 // CHECK-LABEL: insert_slice
-func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
+func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
   %a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<16x16x!tt.ptr<f16>, #AL>
   %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL>
   %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
@@ -90,7 +90,7 @@ func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
 }
 
 // CHECK-LABEL: extract_slice
-func @extract_slice(%A : !tt.ptr<f16>) {
+func.func @extract_slice(%A : !tt.ptr<f16>) {
   // CHECK: %cst -> %cst
   %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
   %index = arith.constant 0 : index
@@ -100,7 +100,7 @@ func @extract_slice(%A : !tt.ptr<f16>) {
 }
 
 // CHECK-LABEL: if_cat
-func @if_cat(%i1 : i1) {
+func.func @if_cat(%i1 : i1) {
   // CHECK: %cst -> %cst
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   // CHECK: %cst_0 -> %cst_0
@@ -119,7 +119,7 @@ func @if_cat(%i1 : i1) {
 }
 
 // CHECK-LABEL: if_alias
-func @if_alias(%i1 : i1) {
+func.func @if_alias(%i1 : i1) {
   // CHECK: %cst -> %cst
   %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
   // CHECK-NEXT: %cst_0 -> %cst_0
@@ -134,7 +134,7 @@ func @if_alias(%i1 : i1) {
 }
 
 // CHECK-LABEL: for
-func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
+func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
   // CHECK: %cst -> %cst
   %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
   // CHECK-NEXT: %cst_0 -> %cst_0
@@ -154,7 +154,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.p
 }
 
 // CHECK-LABEL: for_if
-func @for_if(%lb : index, %ub : index, %step :
index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %cst -> %cst %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_0 -> %cst_0 @@ -180,7 +180,7 @@ func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !t } // CHECK-LABEL: for_if_for -func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: %cst -> %cst %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: %cst_0 -> %cst_0 diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 0ab34c7a78..af8ea6f856 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -1,288 +1,288 @@ -// RUN: triton-opt %s -test-print-alignment -split-input-file 2>&1 | FileCheck %s +// RUN: triton-opt %s -test-print-alignment -split-input-file -o %t 2>&1 | FileCheck %s -// CHECK-LABEL: cast -func @cast() { - // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1] +// CHECK-LABEL: @cast +func.func @cast() { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 %cst = arith.constant 1 : i32 - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 %0 = arith.extsi %cst : i32 to i64 - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 %cst_tensor = arith.constant dense<1> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 %1 = tt.bitcast %cst_tensor : tensor<128xi32> -> tensor<128xi64> return } // ----- -// CHECK-LABEL: add -func @add() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @add +func.func @add() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 %1 = arith.constant dense<1> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = %2 = arith.addi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [127] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 127 %3 = arith.constant dense<127> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %4 = arith.addi %1, %3 : tensor<128xi32> return } // ----- -// CHECK-LABEL: sub -func @sub() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @sub +func.func @sub() { + // CHECK: contiguity = 
[128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 %1 = arith.constant dense<1> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = %2 = arith.subi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [129] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129 %3 = arith.constant dense<129> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %4 = arith.subi %3, %1 : tensor<128xi32> return } // ----- -// CHECK-LABEL: mul -func @mul() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @mul +func.func @mul() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 %1 = arith.constant dense<1> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %2 = arith.muli %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %3 = arith.constant dense<128> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %4 = arith.muli %3, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [2] + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2 %5 = arith.constant dense<2> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [256] ; Constancy: [128] ; ConstantValue: [256] + // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [128], constant_value = 256 %6 = arith.muli %4, %5 : tensor<128xi32> return } // ----- -// CHECK-LABEL: div -func @div() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @div +func.func @div() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 %1 = arith.constant dense<1> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; 
ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %2 = arith.divsi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %3 = arith.divui %1, %0 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %4 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = %5 = arith.divsi %0, %4 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %6 = arith.divsi %4, %0 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %7 = arith.divsi %4, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66] + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 %8 = arith.constant dense<66> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [2] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [2], constant_value = %9 = arith.divui %0, %8 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [8192] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = %10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [64] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = %11 = arith.divsi %10, %4 : tensor<128xi32> - return + return } // ----- -// CHECK-LABEL: rem -func @rem() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @rem +func.func @rem() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 1 %1 = arith.constant dense<1> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 %2 = arith.remsi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %3 = arith.remui %1, %0 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], 
constant_value = 64 %4 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [64] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [64], divisibility = [64], constancy = [1], constant_value = %5 = arith.remsi %0, %4 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [1], constant_value = %6 = arith.remsi %4, %0 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [2] ; Constancy: [128] ; ConstantValue: [66] + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 66 %7 = arith.constant dense<66> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [2] ; Divisibility: [2] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [2], divisibility = [2], constancy = [1], constant_value = %8 = arith.remui %0, %7 : tensor<128xi32> - return + return } // ----- -// CHECK-LABEL: broadcast -func @broadcast() { - // CHECK: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] +// CHECK-LABEL: @broadcast +func.func @broadcast() { + // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %0 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 1] ; ConstantValue: [64] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64 %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [64, 1] ; Constancy: [128, 128] ; ConstantValue: [64] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64 %2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32> return } // ----- -// CHECK-LABEL: splat -func @splat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { - // CHECK: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 128] ; ConstantValue: [None] +// CHECK-LABEL: @splat +func.func @splat(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 128], constant_value = %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr> return } // ----- -// CHECK-LABEL: cmp -func @cmp() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @cmp +func.func @cmp() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 %1 = arith.constant dense<0> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] 
+ // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %4 = arith.cmpi sle, %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = %5 = arith.cmpi sge, %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 %6 = arith.constant dense<8> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = %7 = arith.cmpi sgt, %0, %6 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [0] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 0 %8 = arith.cmpi sgt, %1, %6 : tensor<128xi32> return } // ----- -// CHECK-LABEL: logic -func @logic() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @logic +func.func @logic() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [64] ; Constancy: [128] ; ConstantValue: [64] + // CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %1 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16777216] ; Constancy: [64] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = %2 = arith.divsi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 %3 = arith.constant dense<8> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [134217728] ; Constancy: [8] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = %4 = arith.divsi %0, %3 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %5 = arith.andi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %6 = arith.ori %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %7 = arith.xori %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = %8 = arith.andi %2, %4 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [8] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = %9 = arith.ori %2, %4 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: 
[1] ; Constancy: [8] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = %10 = arith.xori %2, %4 : tensor<128xi32> return } // ----- -// CHECK-LABEL: select -func @select() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @select +func.func @select() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 %1 = arith.constant dense<0> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = %2 = arith.cmpi eq, %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = %3 = arith.cmpi slt, %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0] + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 %4 = arith.constant 0 : i1 - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 %7 = tt.splat %4 : (i1) -> tensor<128xi1> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [128] ; ConstantValue: [0] - %5 = select %4, %3, %7 : tensor<128xi1> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [128], constant_value = 0 + %5 = arith.select %4, %3, %7 : tensor<128xi1> + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = %8 = "triton_gpu.select"(%7, %3, %2) : (tensor<128xi1>, tensor<128xi1>, tensor<128xi1>) -> tensor<128xi1> return } // ----- -func @shift() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +func.func @shift() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 %1 = arith.constant dense<8> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4] + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 %2 = arith.constant dense<4> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [274877906944] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = %3 = arith.shli %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [67108864] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: 
contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = %4 = arith.shrsi %0, %2 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [128] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %5 = arith.shli %1, %2 : tensor<128xi32> return } // ----- -func @max_min() { - // CHECK: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] +func.func @max_min() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [64] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [64], constancy = [1], constant_value = %1 = tt.make_range {end = 192 : i32, start = 64 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %2 = arith.maxsi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %3 = arith.minsi %0, %1 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [8] ; Constancy: [128] ; ConstantValue: [8] + // CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8 %4 = arith.constant dense<8> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4] ; Constancy: [128] ; ConstantValue: [4] + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 %5 = arith.constant dense<4> : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [8] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 8 %6 = arith.maxsi %4, %5 : tensor<128xi32> return } // ----- -// CHECK-LABEL: for -func @for() { - // CHECK: Contiguity: [1, 1] ; Divisibility: [4611686018427387904, 4611686018427387904] ; Constancy: [128, 32] ; ConstantValue: [0] +// CHECK-LABEL: @for +func.func @for() { + // CHECK: contiguity = [1, 1], divisibility = [4611686018427387904, 4611686018427387904], constancy = [128, 32], constant_value = 0 %a_init = arith.constant dense<0> : tensor<128x32xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [1] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = 1 %b_init = arith.constant dense<1> : tensor<128x32xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 %c_init = arith.constant dense<4> : tensor<128x32xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128 %ub = arith.constant 128 : index - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [4611686018427387904] ; Constancy: [1] ; ConstantValue: [0] + // CHECK-NEXT: contiguity = [1], divisibility = [4611686018427387904], constancy = [1], constant_value = 0 %lb = arith.constant 0 : index - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; 
Constancy: [1] ; ConstantValue: [16] + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 %step = arith.constant 16 : index %a, %b, %c = scf.for %iv = %lb to %ub step %step iter_args(%a = %a_init, %b = %b_init, %c = %c_init) -> (tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32>) { - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = %t = arith.index_cast %iv : index to i32 - // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None] - // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 32] ; ConstantValue: [None] - // CHECK: Contiguity: [1, 1] ; Divisibility: [4, 4] ; Constancy: [128, 32] ; ConstantValue: [4] + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 32], constant_value = + // CHECK: contiguity = [1, 1], divisibility = [4, 4], constancy = [128, 32], constant_value = 4 scf.yield %b, %a, %c : tensor<128x32xi32>, tensor<128x32xi32>, tensor<128x32xi32> } return @@ -290,53 +290,53 @@ func @for() { // ----- -// CHECK-LABEL: permute_2d -func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { - // CHECK: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [128, 128] ; ConstantValue: [1] +// CHECK-LABEL: @permute_2d +func.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1, 1], divisibility = [1, 1], constancy = [128, 128], constant_value = 1 %cst = arith.constant dense : tensor<128x128xi1> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [17179869184, 16] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy 
= [1, 1], constant_value = %4 = arith.muli %2, %3 : tensor<128x1xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x1x!tt.ptr> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = %6 = tt.addptr %5, %4 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> - // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = %7 = tt.expand_dims %1 {axis = 0 : i32}: (tensor<128xi32>) -> tensor<1x128xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = %8 = tt.broadcast %6 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr> - // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [128, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [128, 1], constant_value = %9 = tt.broadcast %7 : (tensor<1x128xi32>) -> tensor<128x128xi32> - // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 16] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 16], constancy = [1, 1], constant_value = %10 = tt.addptr %8, %9 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [1073741824, 1] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128, 1], divisibility = [1073741824, 1], constancy = [1, 1], constant_value = %11 = tt.expand_dims %0 {axis = 1 : i32}: (tensor<128xi32>) -> tensor<128x1xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [128, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %12 = tt.splat %arg2 : (!tt.ptr) -> tensor<128x1x!tt.ptr> - // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = %13 = tt.addptr %12, %11 : tensor<128x1x!tt.ptr>, tensor<128x1xi32> - // CHECK-NEXT: Contiguity: [1, 128] ; Divisibility: [1, 1073741824] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 128], divisibility = [1, 1073741824], constancy = [1, 1], constant_value = %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 16] ; Constancy: [1, 128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = %16 = arith.muli %14, %15 : tensor<1x128xi32> - // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 128] ; ConstantValue: [None] + // 
CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 128], constant_value = %17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [16, 17179869184] ; Constancy: [128, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32> - // CHECK-NEXT: Contiguity: [128, 1] ; Divisibility: [16, 1] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 1], constancy = [1, 1], constant_value = %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> - // CHECK-NEXT: Contiguity: [1, 1] ; Divisibility: [1, 1] ; Constancy: [1, 1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1, 1], divisibility = [1, 1], constancy = [1, 1], constant_value = %20 = tt.load %10, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf32> tt.store %19, %20, %cst : tensor<128x128xf32> return @@ -347,29 +347,29 @@ func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {t module { // This is a tiny test for verifying StoreOp-related alignment. It simply stores a constant to a buffer. -// CHECK-LABEL: store_constant_align -func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { - // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] +// CHECK-LABEL: @store_constant_align +func.func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %pid = tt.get_program_id {axis = 0 : i32} : i32 - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [128] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = 128 %c128_i32 = arith.constant 128 : i32 - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [1], constant_value = %1 = arith.muli %pid, %c128_i32 : i32 - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [1073741824] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = %2 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [128] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = %3 = tt.splat %1 : (i32) -> tensor<128xi32> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [128] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [128], constancy = [1], constant_value = %4 = arith.addi %3, %2 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = %5 = tt.splat %addr : (!tt.ptr) -> tensor<128x!tt.ptr> - // CHECK-NEXT: Contiguity: [128] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [128], divisibility = [16], constancy = [1], constant_value = %6 = tt.addptr %5, %4 : tensor<128x!tt.ptr>, tensor<128xi32> - //
CHECK-NEXT: Contiguity: [1] ; Divisibility: [16] ; Constancy: [128] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [128], constant_value = %9 = tt.splat %n : (i32) -> tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [16], constant_value = %mask = arith.cmpi slt, %4, %9 : tensor<128xi32> - // CHECK-NEXT: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %cst = arith.constant dense<0.0> : tensor<128xf32> tt.store %5, %cst, %mask : tensor<128xf32> return @@ -381,8 +381,8 @@ func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: // This IR is dumped from vecadd test. // Note: the hint {tt.divisibility = 16 : i32} for %n_elements affects the alignment of the mask. -// CHECK-LABEL: vecadd_mask_align_16 -func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { +// CHECK-LABEL: @vecadd_mask_align_16 +func.func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32 {tt.divisibility = 16 : i32}) { %c64_i32 = arith.constant 64 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c64_i32 : i32 @@ -394,13 +394,13 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %ar %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr> %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> - // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [16] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) + // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [16], constant_value = %mask = arith.cmpi slt, %4, %9 : tensor<64xi32> %11 = tt.load %6, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> %12 = tt.load %8, %mask {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> %13 = arith.addf %11, %12 : tensor<64xf32> %14 = tt.splat %arg2 : (!tt.ptr) -> tensor<64x!tt.ptr> - // CHECK: Contiguity: [64] ; Divisibility: [16] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = tt.addptr %{{.*}}, %{{.*}} : tensor<64x!tt.ptr>, tensor<64xi32> ) + // CHECK: tt.addptr %{{.*}} => contiguity = [64], divisibility = [16], constancy = [1], constant_value = %15 = tt.addptr %14, %4 : tensor<64x!tt.ptr>, tensor<64xi32> tt.store %15, %13, %mask : tensor<64xf32> return @@ -410,8 +410,8 @@ func @vecadd_mask_align_16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %ar // This IR is dumped from vecadd test. // Note: there is no divisibility hint for %n_elements, so Triton assumes its divisibility is 1 by default.
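The contrast between these two vecadd variants comes down to a single rule, sketched below in Python under a deliberately simplified model (the helper name and the gcd formulation are illustrative assumptions, not the actual C++ AxisInfo transfer functions): a comparison `x < splat(n)` can only change truth value once per aligned block of gcd(contiguity(x), divisibility(n)) elements.

```python
from math import gcd

# Toy model of the alignment lattice these checks assert: each value
# carries (contiguity, divisibility, constancy) per dimension.
def cmp_mask_constancy(lhs_contiguity: int, rhs_divisibility: int) -> int:
    # `x < splat(n)`: within an aligned run of
    # gcd(contiguity(x), divisibility(n)) consecutive elements, x stays
    # on one side of n, so the mask is constant over blocks of that size.
    return gcd(lhs_contiguity, rhs_divisibility)

# vecadd_mask_align_16: offsets are 64-contiguous and %n_elements is
# hinted divisible by 16, so the mask constancy is [16] as checked above.
assert cmp_mask_constancy(64, 16) == 16
# vecadd_mask_align_1 (below): no hint, divisibility defaults to 1.
assert cmp_mask_constancy(64, 1) == 1
```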
-// CHECK-LABEL: vecadd_mask_align_1 -func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { +// CHECK-LABEL: @vecadd_mask_align_1 +func.func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { %c64_i32 = arith.constant 64 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c64_i32 : i32 @@ -423,7 +423,7 @@ func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<64x!tt.ptr> %8 = tt.addptr %7, %4 : tensor<64x!tt.ptr>, tensor<64xi32> %9 = tt.splat %n_elements : (i32) -> tensor<64xi32> - // CHECK: Contiguity: [1] ; Divisibility: [1] ; Constancy: [1] ; ConstantValue: [None] ( %{{.*}} = arith.cmpi slt, %{{.*}}, %{{.*}} : tensor<64xi32> ) + // CHECK: arith.cmpi slt, %{{.*}} => contiguity = [1], divisibility = [1], constancy = [1], constant_value = %10 = arith.cmpi slt, %4, %9 : tensor<64xi32> %11 = tt.load %6, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> %12 = tt.load %8, %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xf32> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index efb00c404d..f79222aa7b 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -13,7 +13,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_loop -func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { +func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> @@ -46,7 +46,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B // Shared memory is available after a tensor's liveness range ends // CHECK-LABEL: reusable -func @reusable(%A : !tt.ptr) { +func.func @reusable(%A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %cst3 = arith.constant dense : tensor<32x128xi1, #AL> @@ -78,7 +78,7 @@ func @reusable(%A : !tt.ptr) { // %cst1->%cst4 // %cst3->%g->%h->%i // CHECK-LABEL: preallocate -func @preallocate(%A : !tt.ptr) { +func.func @preallocate(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 1024, size = 512 @@ -113,7 +113,7 @@ func @preallocate(%A : !tt.ptr) { // Unused tensors are immediately released // CHECK-LABEL: unused -func @unused(%A : !tt.ptr) { +func.func @unused(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 %cst0 = arith.constant dense<0.000000e+00> : tensor<32x16xf16, #A_SHARED> // CHECK-NEXT: offset = 0, size = 512 @@ -128,7 +128,7 @@ func @unused(%A : !tt.ptr) { // cst0 is alive through the entire function, so it cannot be released before the end of the function // CHECK-LABEL: longlive -func @longlive(%A : !tt.ptr) { +func.func @longlive(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 @@ -156,7 +156,7 @@ func @longlive(%A : !tt.ptr) { } // CHECK-LABEL: alloc -func @alloc(%A : !tt.ptr) {
+func.func @alloc(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> @@ -167,7 +167,7 @@ func @alloc(%A : !tt.ptr) { } // CHECK-LABEL: scratch -func @scratch() { +func.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> // CHECK: scratch offset = 0, size = 512 %b = tt.reduce %cst0 {redOp = 1 : i32, axis = 0 : i32} : tensor<16x16xf16, #AL> -> tensor<16xf16, #sliceAd0> @@ -176,7 +176,7 @@ func @scratch() { } // CHECK-LABEL: trans -func @trans(%A : !tt.ptr) { +func.func @trans(%A : !tt.ptr) { // CHECK: offset = 0, size = 1024 %tensor = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> %b = tt.trans %tensor : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T> @@ -184,7 +184,7 @@ func @trans(%A : !tt.ptr) { } // CHECK-LABEL: insert_slice_async -func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { +func.func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> @@ -197,7 +197,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { } // CHECK-LABEL: extract_slice -func @extract_slice(%A : !tt.ptr) { +func.func @extract_slice(%A : !tt.ptr) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : index @@ -209,7 +209,7 @@ func @extract_slice(%A : !tt.ptr) { // B0 -> (B1) -> B0 // Memory used by B1 can be reused by B0. // CHECK-LABEL: if -func @if(%i1 : i1) { +func.func @if(%i1 : i1) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 @@ -233,7 +233,7 @@ func @if(%i1 : i1) { // B0 -> (B1) -> (B2) -> B0 // Memory used by B0 cannot be reused by B1 or B2. // CHECK-LABEL: if_else -func @if_else(%i1 : i1) { +func.func @if_else(%i1 : i1) { // CHECK: offset = 0, size = 512 %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK-NEXT: offset = 512, size = 512 @@ -260,7 +260,7 @@ func @if_else(%i1 : i1) { // Block arguments and yields are memory aliases that do not trigger a new // allocation. 
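The `offset = ..., size = ...` annotations asserted throughout this file come from a liveness-based placement of shared-memory buffers. A rough first-fit model in Python, kept deliberately small (the real Allocation analysis also handles aliases, scratch buffers, and nested regions such as the `for` tests that follow, so this is a sketch, not the actual algorithm):

```python
from typing import List, Tuple

def allocate(buffers: List[Tuple[int, int, int]]) -> List[int]:
    """buffers: (start, end, size) liveness intervals, in program order;
    returns one shared-memory offset per buffer (first-fit)."""
    offsets: List[int] = []
    live: List[Tuple[int, int, int, int]] = []  # (start, end, offset, size)
    for start, end, size in buffers:
        # Buffers whose liveness ended may have their space reused.
        live = [b for b in live if b[1] > start]
        # First-fit: lowest offset that does not overlap a live buffer.
        offset = 0
        for _, _, o, s in sorted(live, key=lambda b: b[2]):
            if offset + size <= o:
                break
            offset = max(offset, o + s)
        live.append((start, end, offset, size))
        offsets.append(offset)
    return offsets

# Two 512-byte tensors that are live simultaneously get offsets 0 and
# 512; a third allocated after the first dies reuses offset 0, which is
# the reuse pattern tests like `if` and `longlive` pin down.
assert allocate([(0, 2, 512), (0, 3, 512), (2, 4, 512)]) == [0, 512, 0]
```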
// CHECK-LABEL: for -func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { +func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { // CHECK: offset = 0, size = 8192 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 @@ -275,7 +275,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p } // CHECK-LABEL: for_if_slice -func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 @@ -296,7 +296,7 @@ func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr, % // c0 cannot be released in the loop // CHECK-LABEL: for_use_ancestor -func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +func.func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 @@ -316,7 +316,7 @@ func @for_use_ancestor(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { +func.func @for_if_for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %i1 : i1) { // CHECK: offset = 0, size = 8192 %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: offset = 8192, size = 8192 diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 7199e5f53d..17880b2094 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -14,7 +14,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: matmul_loop // There shouldn't be any membar with the dot op encoding. 
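The `Membar N` checks in this file count barriers inserted between conflicting shared-memory accesses: a read after a write, a write after a read, or a write after a write on the same buffer forces a barrier. A toy Python rendering of that core rule (the real Membar analysis also reasons about control-flow regions, aliases, and scratch memory, so treat this purely as an illustration):

```python
def insert_membars(ops):
    """ops: list of ('read' | 'write', buffer_id) in program order;
    returns the list with ('membar', None) inserted before conflicts."""
    out, pending_writes, pending_reads = [], set(), set()
    for kind, buf in ops:
        raw = kind == 'read' and buf in pending_writes
        war = kind == 'write' and buf in pending_reads
        waw = kind == 'write' and buf in pending_writes
        if raw or war or waw:
            out.append(('membar', None))  # a barrier clears all conflicts
            pending_writes.clear()
            pending_reads.clear()
        out.append((kind, buf))
        (pending_reads if kind == 'read' else pending_writes).add(buf)
    return out

# A shared-memory write followed by a read of the same buffer needs one
# barrier in between, as the raw_single_block test below checks.
assert insert_membars([('write', 'a'), ('read', 'a')]) == \
    [('write', 'a'), ('membar', None), ('read', 'a')]
```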
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { +func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL> @@ -42,7 +42,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B } // CHECK-LABEL: raw_single_block -func @raw_single_block(%A : !tt.ptr) { +func.func @raw_single_block(%A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> @@ -54,7 +54,7 @@ func @raw_single_block(%A : !tt.ptr) { } // CHECK-LABEL: war_single_block -func @war_single_block(%A : !tt.ptr) { +func.func @war_single_block(%A : !tt.ptr) { %cst1 = arith.constant dense : tensor<128x32xi1, #AL> %cst2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #AL> %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> @@ -70,7 +70,7 @@ func @war_single_block(%A : !tt.ptr) { } // CHECK-LABEL: scratch -func @scratch() { +func.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK: Membar 1 %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> @@ -81,7 +81,7 @@ func @scratch() { } // CHECK-LABEL: async_wait -func @async_wait() { +func.func @async_wait() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> // CHECK: Membar 1 %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> @@ -92,7 +92,7 @@ func @async_wait() { } // CHECK-LABEL: alloc -func @alloc() { +func.func @alloc() { %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> %a = tt.cat %cst0, %cst0 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> // CHECK: Membar 2 @@ -101,7 +101,7 @@ func @alloc() { } // CHECK-LABEL: extract_slice -func @extract_slice() { +func.func @extract_slice() { %cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED> %index = arith.constant 0 : index %cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED> @@ -113,14 +113,14 @@ func @extract_slice() { } // CHECK-LABEL: trans -func @trans() { +func.func @trans() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> %b = tt.trans %cst0 : (tensor<16x32xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED_T> return } // CHECK-LABEL: insert_slice_async -func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { +func.func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> @@ -135,7 +135,7 @@ func @insert_slice_async(%A : !tt.ptr, %i1 : i1) { } // CHECK-LABEL: insert_slice -func @insert_slice(%A : !tt.ptr, %i1 : i1) { +func.func @insert_slice(%A : !tt.ptr, %i1 : i1) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<16x16x!tt.ptr, #AL> %mask = tt.splat %i1 : (i1) -> tensor<16x16xi1, #AL> %other = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> @@ -153,7 +153,7 @@ func @insert_slice(%A : !tt.ptr, %i1 : i1) { // If branch 
inserted a barrier for %cst0 and %cst1 but the else branch didn't, then the barrier should be inserted in the parent region // CHECK-LABEL: multi_blocks -func @multi_blocks(%i1 : i1) { +func.func @multi_blocks(%i1 : i1) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { @@ -174,7 +174,7 @@ func @multi_blocks(%i1 : i1) { // Both branches inserted a barrier for %cst0 and %cst1, so the barrier doesn't need to be inserted in the parent region // CHECK-LABEL: multi_blocks_join_barrier -func @multi_blocks_join_barrier(%i1 : i1) { +func.func @multi_blocks_join_barrier(%i1 : i1) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { @@ -192,7 +192,7 @@ func @multi_blocks_join_barrier(%i1 : i1) { // Reading a yielded tensor requires a barrier // CHECK-LABEL: multi_blocks_yield -func @multi_blocks_yield(%i1 : i1) { +func.func @multi_blocks_yield(%i1 : i1) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %a = scf.if %i1 -> (tensor<32x16xf16, #A_SHARED>) { @@ -212,7 +212,7 @@ func @multi_blocks_yield(%i1 : i1) { // Conservatively add a barrier as if the branch (%i1) is never taken // CHECK-LABEL: multi_blocks_noelse -func @multi_blocks_noelse(%i1 : i1) { +func.func @multi_blocks_noelse(%i1 : i1) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { @@ -226,7 +226,7 @@ func @multi_blocks_noelse(%i1 : i1) { // Conservatively add a barrier as if the branch (%i2) is never taken // CHECK-LABEL: multi_blocks_nested_scf -func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { +func.func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> %cst1 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> scf.if %i1 { @@ -247,7 +247,7 @@ func @multi_blocks_nested_scf(%i1 : i1, %i2 : i1) { } // CHECK-LABEL: for -func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { +func.func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> %c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> @@ -262,7 +262,7 @@ func @for(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.p // Although a_shared and b_shared are synced before entering the loop, // they are reassociated with aliases (c_shared) and thus require a barrier. // CHECK-LABEL: for_alias -func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { +func.func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: Membar 2 @@ -282,7 +282,7 @@ func @for_alias(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // Although cst2 is not an argument of scf.yield, its memory is reused by cst1.
// So we need a barrier both before and after cst1 // CHECK-LABEL: for_reuse -func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { +func.func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: Membar 2 @@ -302,7 +302,7 @@ func @for_reuse(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : // CHECK-LABEL: for_reuse_nested -func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { +func.func @for_reuse_nested(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) { %a_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> %b_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED> // CHECK-NEXT: Membar 2 diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index e9ee502435..0e979b148d 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s | FileCheck %s -func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { +func.func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { // scalar -> scalar // CHECK: i64 -> !tt.ptr %0 = tt.int_to_ptr %scalar_i64 : i64 -> !tt.ptr @@ -35,7 +35,7 @@ func @cast_ops(%scalar_ptr: !tt.ptr, %scalar_f32: f32, %scalar_i64: i64) { return } -func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { +func.func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { // scalar -> scalar // CHECK: !tt.ptr %0 = tt.addptr %scalar_ptr, %scalar_i32 : !tt.ptr, i32 @@ -54,7 +54,7 @@ func @addptr_ops(%scalar_ptr: !tt.ptr, %scalar_i32: i32) { return } -func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %mask : i1) { +func.func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %mask : i1) { // Test if Load/Store ops can handle scalar values %other = arith.constant 0.0e+0 : f32 @@ -76,7 +76,7 @@ func @load_store_ops_scalar(%ptr: !tt.ptr {tt.divisibility = 16 : i32}, %ma return } -func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { +func.func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { // Test if reduce ops infer types correctly // CHECK: %{{.*}} = tt.reduce %{{.*}} -> tensor<2x4xf32> @@ -101,7 +101,7 @@ func @reduce_ops_infer(%ptr: !tt.ptr, %v : tensor<1x2x4xf32>) { return } -func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { +func.func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { // Test if dot ops infer types correctly %v128x32 = tt.splat %v : (f32) -> tensor<128x32xf32> %v32x128 = tt.splat %v : (f32) -> tensor<32x128xf32> diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index a160bc8815..b461ca542f 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -1,6 +1,6 @@ // RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu=num-warps=2 | FileCheck %s -func @ops() { +func.func @ops() { // CHECK: module attributes {"triton_gpu.num-warps" = 2 : i32} {{.*}} %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> @@ -11,7 +11,7 @@ func @ops() { // ----- -func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { +func.func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // Test if LoadOp is
lowered properly (see #771) %ptrs = tt.splat %ptr : (!tt.ptr) -> tensor<128x!tt.ptr> %mask = arith.constant dense : tensor<128xi1> @@ -30,7 +30,7 @@ func @load_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // ----- -func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { +func.func @reduce_ops(%ptr: !tt.ptr {tt.divisibility = 16 : i32}) { // Test if the total number of threadsPerWarp is 32 // Test if the total number of warps is 2 // CHECK: #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 2], order = [0, 1]}> diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index e9e7d5a340..507b362c99 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -4,7 +4,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr) // Here the 128 comes from the 4 in the module attribute multiplied by 32 // CHECK: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = 128 : i32} {{.*}} - func @test_empty_kernel(%lb : index, %A : !tt.ptr) { + func.func @test_empty_kernel(%lb : index, %A : !tt.ptr) { // CHECK: llvm.return return } @@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_load - func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + func.func @basic_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm // CHECK: llvm.inline_asm %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> @@ -28,7 +28,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: vectorized_load - func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { + func.func @vectorized_load(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf32, #blocked0>) { // CHECK: llvm.inline_asm // CHECK-SAME: ld.global.b32 // CHECK: llvm.inline_asm @@ -43,7 +43,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: vectorized_load_f16 - func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { + func.func @vectorized_load_f16(%a_ptr_init: tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>, %cst_0 : tensor<256xf16, #blocked0>) { // CHECK: llvm.inline_asm // CHECK-SAME: ld.global.b16 // CHECK: llvm.inline_asm @@ -59,7 +59,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: masked_load_const_other - func
@masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + func.func @masked_load_const_other(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> return @@ -72,7 +72,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: masked_load_const_other_vec - func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { + func.func @masked_load_const_other_vec(%a_ptr_init : tensor<256x!tt.ptr, #blocked0>, %cst : tensor<256xi1, #blocked0>) { %cst_0 = arith.constant dense<0.000000e+00> : tensor<256xf32, #blocked0> %1 = tt.load %a_ptr_init, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256xf32, #blocked0> return @@ -84,7 +84,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> module attributes {"triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_no_vec - func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { + func.func @global_load_store_no_vec(%arg0: !tt.ptr {tt.divisibility = 4 : i32}, %arg1: !tt.ptr {tt.divisibility = 4 : i32}, %arg2: !tt.ptr {tt.divisibility = 4 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c256_i32 : i32 @@ -128,7 +128,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> module attributes {"triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: global_load_store_vec4 - func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + func.func @global_load_store_vec4(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c256_i32 : i32 @@ -165,7 +165,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { #blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}> // Note: %n_elements doesn't have a "tt.divisibility" hint, so Triton assumes its divisibility is 1; this affects the mask's alignment and further restricts the load/store ops' vector width to 1.
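Putting the pieces together: the vector width the lowering may pick is bounded jointly by pointer contiguity, pointer divisibility, and mask constancy, which is what the vecadd_masked_vec1 test below pins at 1. A hypothetical back-of-the-envelope in Python (the function, its parameters, and the 128-bit cap are illustrative assumptions, not the actual LoadStoreOpToLLVM logic):

```python
def vector_width(contiguity: int, divisibility_elems: int,
                 mask_constancy: int, max_vec: int = 4) -> int:
    # The widest legal access must stay inside a run of elements that is
    # contiguous, sufficiently aligned, and uniformly masked.
    return min(contiguity, divisibility_elems, mask_constancy, max_vec)

# A 16-byte divisibility hint on f32 pointers (4 elements) plus a mask
# constant over 16 elements allows 4-wide (128-bit) accesses.
assert vector_width(contiguity=64, divisibility_elems=4, mask_constancy=16) == 4
# No hint on %n_elements: the mask may flip at any element, forcing
# scalar accesses, so the test expects vector width 1.
assert vector_width(contiguity=64, divisibility_elems=4, mask_constancy=1) == 1
```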
module attributes {"triton_gpu.num-warps" = 2 : i32} { - func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { + func.func @vecadd_masked_vec1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %n_elements: i32) { %c64_i32 = arith.constant 64 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c64_i32 : i32 @@ -195,7 +195,7 @@ module attributes {"triton_gpu.num-warps" = 2 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec2 - func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { + func.func @global_load_store_vec2(%arg0: !tt.ptr {tt.divisibility = 8 : i32}, %arg1: !tt.ptr {tt.divisibility = 8 : i32}, %arg2: !tt.ptr {tt.divisibility = 8 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c256_i32 : i32 @@ -240,7 +240,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: global_load_store_vec8 - func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { + func.func @global_load_store_vec8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) { %c256_i32 = arith.constant 256 : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 %1 = arith.muli %0, %c256_i32 : i32 @@ -283,7 +283,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_view_broadcast - func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { + func.func @basic_view_broadcast(%arg : tensor<256xf32,#blocked0>) { // CHECK: llvm.mlir.undef // CHECK: %[[T0:.*]] = llvm.extractvalue // CHECK: %[[T1:.*]] = llvm.extractvalue @@ -307,7 +307,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_make_range - func @basic_make_range() { + func.func @basic_make_range() { // CHECK: nvvm.read.ptx.sreg.tid.x // CHECK: llvm.mlir.undef // CHECK: llvm.insertvalue @@ -322,7 +322,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addf - func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { + func.func @basic_addf(%arg0 : tensor<256xf32,#blocked0>, %arg1 : tensor<256xf32,#blocked0>) { // CHECK: llvm.fadd // CHECK: llvm.fadd %1 = arith.addf %arg0, %arg1 : 
tensor<256xf32,#blocked0> @@ -335,7 +335,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addi - func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + func.func @basic_addi(%arg0 : tensor<256xi32,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.add // CHECK: llvm.add %1 = arith.addi %arg0, %arg1 : tensor<256xi32,#blocked0> @@ -347,7 +347,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_program_id - func @basic_program_id() { + func.func @basic_program_id() { // CHECK: nvvm.read.ptx.sreg.ctaid.x : i32 %0 = tt.get_program_id {axis = 0 : i32} : i32 return @@ -359,7 +359,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_addptr - func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { + func.func @basic_addptr(%arg0 : tensor<256x!tt.ptr,#blocked0>, %arg1 : tensor<256xi32,#blocked0>) { // CHECK: llvm.getelementptr // CHECK: llvm.getelementptr %0 = tt.addptr %arg0, %arg1 : tensor<256x!tt.ptr, #blocked0>, tensor<256xi32, #blocked0> @@ -373,7 +373,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_alloc_tensor - func @basic_alloc_tensor() { + func.func @basic_alloc_tensor() { // CHECK: llvm.mlir.addressof @global_smem // CHECK-NEXT: llvm.bitcast // CHECK-NEXT: llvm.mlir.constant @@ -390,7 +390,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: llvm.mlir.global external @global_smem // CHECK-LABEL: basic_extract_slice - func @basic_extract_slice() { + func.func @basic_extract_slice() { // CHECK: llvm.mlir.addressof @global_smem // CHECK: llvm.extractvalue // CHECK-NEXT: llvm.extractvalue @@ -423,7 +423,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_async_wait - func @basic_async_wait() { + func.func @basic_async_wait() { // CHECK: cp.async.wait_group 0x4 triton_gpu.async_wait {num = 4: i32} return @@ -442,7 +442,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_fallback - func @basic_insert_slice_async_fallback(%arg0: !tt.ptr {tt.divisibility = 1 : i32}) { + func.func @basic_insert_slice_async_fallback(%arg0: !tt.ptr {tt.divisibility = 1 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> @@ -481,7 +481,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #A = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: 
basic_insert_slice_async_v4 - func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { + func.func @basic_insert_slice_async_v4(%arg0: !tt.ptr {tt.divisibility = 32 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<64xi32, #slice3d0> %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> @@ -523,7 +523,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1 - func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + func.func @basic_insert_slice_async_v1(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #slice2d1> %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<16xi32, #slice2d1>) -> tensor<16x1xi32, #block2> @@ -568,7 +568,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #A = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 4, order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_insert_slice_async_v1_multictas - func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + func.func @basic_insert_slice_async_v1_multictas(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { %off0_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice2d1> %off1_ = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #slice3d0> %off0 = tt.expand_dims %off0_ {axis = 1 : i32} : (tensor<32xi32, #slice2d1>) -> tensor<32x1xi32, #block2> @@ -619,7 +619,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: basic_splat - func @basic_splat(%ptr: !tt.ptr) { + func.func @basic_splat(%ptr: !tt.ptr) { // CHECK: llvm.mlir.undef // CHECK: llvm.insertvalue // CHECK: llvm.insertvalue @@ -633,7 +633,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK-LABEL: basic_store - func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { + func.func @basic_store(%ptrs: tensor<256x!tt.ptr, #blocked0>, %vals: tensor<256xf32, #blocked0>, %mask: tensor<256xi1, #blocked0>) { // CHECK: llvm.inline_asm // CHECK-SAME: st.global.b32 [ ${{.*}} + 0 ], { ${{.*}} }; // CHECK: llvm.inline_asm @@ -650,7 +650,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked - func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { + func.func @convert_layout_blocked_blocked(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem // CHECK: llvm.store // CHECK-SAME: !llvm.ptr, 3> @@ -697,7 +697,7 @@ module attributes {"triton_gpu.num-warps" 
= 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked_vec - func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { + func.func @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem // CHECK: llvm.store // CHECK-SAME: !llvm.ptr, 3> @@ -720,7 +720,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep - func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { + func.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) { // CHECK: llvm.mlir.addressof @global_smem // CHECK: llvm.store // CHECK-SAME: !llvm.ptr, 3> @@ -751,7 +751,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_dot - func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { + func.func @convert_dot(%A: tensor<16x16xf16, #blocked0>, %B: tensor<16x16xf16, #blocked0>) { %AA = triton_gpu.convert_layout %A : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> %BB = triton_gpu.convert_layout %B : (tensor<16x16xf16, #blocked0>) -> tensor<16x16xf16, #shared0> // CHECK: llvm.inline_asm @@ -775,7 +775,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { // TODO: problems in MLIR's parser on slice layout // #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> // module attributes {"triton_gpu.num-warps" = 1 : i32} { -// func @make_range_sliced_layout() { +// func.func @make_range_sliced_layout() { // %0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #triton_gpu.slice<{dim = 0, parent = #blocked0}>> // return // } @@ -788,7 +788,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_mmav2_block - func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { + func.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) { // CHECK: llvm.store // CHECK-SAME: !llvm.ptr, 3> // CHECK: llvm.store @@ -808,7 +808,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_mmav1_block - func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { + func.func @convert_layout_mmav1_blocked(%arg0: tensor<32x64xf32, #mma>) { // CHECK: llvm.store // CHECK-SAME: !llvm.ptr, 3> // CHECK: llvm.store @@ -831,7 +831,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK: llvm.mlir.global external @global_smem() {addr_space = 3 : i32} : !llvm.array<0 x i8> // CHECK-LABEL: convert_layout_blocked_shared - func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { + func.func 
@convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) { // CHECK: llvm.store // CHECK-SAME: !llvm.ptr, 3> // CHECK: llvm.store @@ -847,7 +847,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice0 - func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { + func.func @convert_blocked1d_to_slice0(%src:tensor<32xi32, #blocked0>) { // CHECK-COUNT-4: llvm.load {{.*}} : !llvm.ptr, 3> %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> return @@ -860,7 +860,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 - func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { + func.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr, 3> %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> return @@ -873,7 +873,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked_to_blocked_ptr - func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { + func.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr, #blocked0>) { // CHECK: llvm.ptrtoint // CHECK: llvm.store // CHECK: nvvm.barrier0 @@ -892,7 +892,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> module attributes {"triton_gpu.num-warps" = 4 : i32} { - func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + func.func @matmul_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:tensor<128x32xf16, #shared>, %b:tensor<32x256xf16, #shared>) { %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 @@ -918,7 +918,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, isMMAv1Row=true}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, isMMAv1Row=true}> module attributes {"triton_gpu.num-warps" = 4 : i32} { - func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + func.func @matmul884_kernel_dot_operand_layout(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, %a:tensor<32x64xf16, #shared0>, %b:tensor<64x64xf16, #shared1>) { %cst = arith.constant dense<0.000000e+00> : tensor<32x64xf32, #mma> // CHECK: ldmatrix.sync.aligned.m8n8.x4.shared.b16 @@ -941,7 +941,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { #dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#blocked}> #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#blocked}> module attributes {"triton_gpu.num-warps" = 4 : i32} { - func @matmul_fmadot(%ptr:!tt.ptr {tt.divisibility = 16 : i32}, + func.func @matmul_fmadot(%ptr:!tt.ptr 
{tt.divisibility = 16 : i32},
    %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #blocked>
     // CHECK: llvm.intr.fmuladd
@@ -965,7 +965,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: matmul_tf32dot
-  func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32},
+  func.func @matmul_tf32dot(%ptr:!tt.ptr {tt.divisibility = 16 : i32},
     %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
     // CHECK: llvm.inline_asm
@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: atomic_add_f32
-  func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
+  func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
     // CHECK: llvm.inline_asm
     // CHECK-SAME: atom.global.gpu.add.f32
     %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
@@ -1012,7 +1012,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) {
+func.func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) {
   %blockidx = tt.get_program_id {axis=0:i32} : i32
   %blockidy = tt.get_program_id {axis=1:i32} : i32
   %blockidz = tt.get_program_id {axis=2:i32} : i32
@@ -1032,7 +1032,7 @@ func @test_get_program_id(%a: tensor<32x!tt.ptr, #blocked0>) {
 // -----
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
-  func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) {
+  func.func @test_get_num_program(%a: tensor<32x!tt.ptr, #blocked0>) {
     // CHECK: nvvm.read.ptx.sreg.nctaid.x
     // CHECK: nvvm.read.ptx.sreg.nctaid.y
     // CHECK: nvvm.read.ptx.sreg.nctaid.z
@@ -1052,7 +1052,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: test_index_cache
-  func @test_index_cache() {
+  func.func @test_index_cache() {
     // CHECK: nvvm.read.ptx.sreg.tid.x
    %0 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked0>
    // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
@@ -1066,7 +1066,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: test_base_index_cache
-  func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
+  func.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
     // CHECK: nvvm.read.ptx.sreg.tid.x
     %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
     // CHECK-NOT: nvvm.read.ptx.sreg.tid.x
@@ -1080,7 +1080,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #shared0 = #triton_gpu.shared<{vec = 8, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: test_index_cache_different_block
-  func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
+  func.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
     // CHECK: nvvm.read.ptx.sreg.tid.x
     %0 = triton_gpu.convert_layout %arg0 : (tensor<128x32xf32, #blocked0>) -> tensor<128x32xf32, #shared0>
     scf.if %arg1 {
diff --git a/test/Target/tritongpu_to_llvmir.mlir b/test/Target/tritongpu_to_llvmir.mlir
index cafff3ca60..114d3a9eb2 100644
--- a/test/Target/tritongpu_to_llvmir.mlir
+++ b/test/Target/tritongpu_to_llvmir.mlir
@@ -4,11 +4,11 @@
 // CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
 // CHECK: define void @test_empty_kernel
 // CHECK: !nvvm.annotations
-// CHECK: !{void (i32, half addrspace(1)*)* @test_empty_kernel, !"maxntidx", i32 128}
+// CHECK: !{ptr @test_empty_kernel, !"maxntidx", i32 128}
 
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
-func @test_empty_kernel(%lb : index, %A : !tt.ptr) {
+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr) {
 
   return
 }
diff --git a/test/Target/tritongpu_to_ptx.mlir b/test/Target/tritongpu_to_ptx.mlir
index 404e970a29..12742ad9e2 100644
--- a/test/Target/tritongpu_to_ptx.mlir
+++ b/test/Target/tritongpu_to_ptx.mlir
@@ -6,7 +6,7 @@
 
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
 
-func @test_empty_kernel(%lb : index, %A : !tt.ptr) {
+func.func @test_empty_kernel(%lb : index, %A : !tt.ptr) {
 
   return
 }
diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir
index 050a3f7565..5ef6790e69 100644
--- a/test/Triton/combine.mlir
+++ b/test/Triton/combine.mlir
@@ -2,10 +2,10 @@
 // RUN: triton-opt %s -split-input-file -canonicalize -triton-combine | FileCheck %s
 
 // CHECK-LABEL: @test_combine_dot_add_pattern
-func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
-  // CHECK: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
-  // CHECK: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
-  // CHECK: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
+func.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32>) {
+  // CHECK-DAG: %[[d:.*]] = arith.constant dense<3.000000e+00> : tensor<128x128xf32>
+  // CHECK-DAG: %[[b:.*]] = arith.constant dense<2.000000e+00> : tensor<128x128xf32>
+  // CHECK-DAG: %[[a:.*]] = arith.constant dense<1.000000e+00> : tensor<128x128xf32>
   %a = arith.constant dense<1.0> : tensor<128x128xf32>
   %b = arith.constant dense<2.0> : tensor<128x128xf32>
   %zero = arith.constant dense<0.0> : tensor<128x128xf32>
@@ -24,7 +24,7 @@ func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128xf32
 
 // COM: CHECK-LABEL: @test_combine_addptr_pattern
-func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> {
+func.func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr> {
   %off0 = arith.constant 10 : i32
   %off1 = arith.constant 15 : i32
@@ -47,46 +47,46 @@ func @test_combine_addptr_pattern(%base: !tt.ptr) -> tensor<8x!tt.ptr>
 
 // CHECK-LABEL: @test_combine_select_masked_load_pattern
-func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
+func.func @test_combine_select_masked_load_pattern(%ptr: tensor<8x!tt.ptr>, %cond: i1) -> (tensor<8xf32>, tensor<8xf32>) {
   %mask = tt.broadcast %cond : (i1) -> tensor<8xi1>
   %false_val = arith.constant dense<0.0> : tensor<8xf32>
 
   // CHECK: %[[res1:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
   %x = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
-  %0 = select %cond, %x, %false_val : tensor<8xf32>
+  %0 = arith.select %cond, %x, %false_val : tensor<8xf32>
 
   // CHECK: %[[res2:.*]] = tt.load %{{.*}}, %{{.*}}, %{{.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
   %y = tt.load %ptr, %mask, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
-  %1 = select %cond, %y, %false_val : tensor<8xf32>
+  %1 = arith.select %cond, %y, %false_val : tensor<8xf32>
 
   // CHECK: return %[[res1]], %[[res2]] : tensor<8xf32>, tensor<8xf32>
   return %0, %1 : tensor<8xf32>, tensor<8xf32>
 }
 
 // CHECK-LABEL: @test_combine_select_masked_load_fail_pattern
-func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
+func.func @test_combine_select_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %dummy_load: tensor<8xf32>, %dummy_broadcast: tensor<8xi1>, %cond0: i1, %cond1: i1) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
   %false_val = arith.constant dense<0.0> : tensor<8xf32>
 
   // Case 1: value at the "load" position is not an "op". Select should not be canonicalized.
-  // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
-  %0 = select %cond0, %dummy_load, %false_val : tensor<8xf32>
+  // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+  %0 = arith.select %cond0, %dummy_load, %false_val : tensor<8xf32>
 
   // Case 2: value at the "broadcast" position is not an "op". Select should not be canonicalized.
   %real_load0 = tt.load %ptr, %dummy_broadcast, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
-  // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
-  %1 = select %cond0, %real_load0, %false_val : tensor<8xf32>
+  // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+  %1 = arith.select %cond0, %real_load0, %false_val : tensor<8xf32>
 
   // Case 3: condition of "broadcast" is not the same as the condition of "select". Select should not be canonicalized.
  %cond0_ = tt.broadcast %cond0 : (i1) -> tensor<8xi1>
  %real_load1 = tt.load %ptr, %cond0_, %false_val {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<8xf32>
-  // CHECK: %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
-  %2 = select %cond1, %real_load1, %false_val : tensor<8xf32>
+  // CHECK: %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
+  %2 = arith.select %cond1, %real_load1, %false_val : tensor<8xf32>
 
   return %0, %1, %2 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
 }
 
 // CHECK-LABEL: @test_combine_broadcast_constant_pattern
-func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
+func.func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
   // CHECK: %[[cst:.*]] = arith.constant dense<1.000000e+00> : tensor<8x2xf32>
   %const = arith.constant dense<1.0> : tensor<8xf32>
   %bst_out = tt.broadcast %const : (tensor<8xf32>) -> tensor<8x2xf32>
@@ -96,7 +96,7 @@ func @test_combine_broadcast_constant_pattern(%cst : f32) -> tensor<8x2xf32> {
 }
 
 // CHECK-LABEL: @test_canonicalize_masked_load_pattern
-func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
+func.func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) {
   %true_mask = arith.constant dense<true> : tensor<8xi1>
   %false_mask = arith.constant dense<false> : tensor<8xi1>
   %other_val = arith.constant dense<0.0> : tensor<8xf32>
@@ -117,7 +117,7 @@ func @test_canonicalize_masked_load_pattern(%ptr: tensor<8x!tt.ptr>) -> (te
 }
 
 // CHECK-LABEL: @test_canonicalize_masked_load_fail_pattern
-func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
+func.func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %mask: tensor<8xi1>) -> (tensor<8xf32>, tensor<8xf32>) {
   %other_val = arith.constant dense<0.0> : tensor<8xf32>
 
   // Case: value at the "mask" position is not an "op". Load should not be canonicalized.
@@ -130,7 +130,7 @@ func @test_canonicalize_masked_load_fail_pattern(%ptr: tensor<8x!tt.ptr>, %
 }
 
 // CHECK-LABEL: @test_canonicalize_masked_store_pattern
-func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>) {
+func.func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>) {
   %true_mask = arith.constant dense<true> : tensor<8xi1>
   %false_mask = arith.constant dense<false> : tensor<8xi1>
@@ -144,7 +144,7 @@ func @test_canonicalize_masked_store_pattern(%ptr: tensor<8x!tt.ptr>, %val:
 }
 
 // CHECK-LABEL: @test_canonicalize_masked_store_fail_pattern
-func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
+func.func @test_canonicalize_masked_store_fail_pattern(%ptr: tensor<8x!tt.ptr>, %val: tensor<8xf32>, %mask: tensor<8xi1>) {
   // Case: value at the "mask" position is not an "op". Store should not be canonicalized.
  // CHECK: tt.store %{{.*}}, %{{.*}}, %{{.*}} : tensor<8xf32>
  tt.store %ptr, %val, %mask : tensor<8xf32>
diff --git a/test/Triton/vecadd.mlir b/test/Triton/vecadd.mlir
index 0b69ef3054..f5019b1cdd 100644
--- a/test/Triton/vecadd.mlir
+++ b/test/Triton/vecadd.mlir
@@ -1,7 +1,7 @@
 // RUN: triton-opt %s -verify-diagnostics
 
 module {
-  func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) {
+  func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) {
     %0 = tt.get_program_id {axis = 0 : i32} : i32
     %c256_i32 = arith.constant 256 : i32
     %1 = arith.muli %0, %c256_i32 : i32
@@ -43,7 +43,7 @@ module {
   }
 }
 // module {
-//   func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) {
+//   func.func @add_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32__(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32) {
 //     %c64 = arith.constant 64 : index
 //     %c32 = arith.constant 32 : index
 //     %c0 = arith.constant 0 : index
diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir
index 60e359f527..51cccccfbd 100644
--- a/test/TritonGPU/coalesce.mlir
+++ b/test/TritonGPU/coalesce.mlir
@@ -19,7 +19,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 // CHECK: [[store_val:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xf32, [[col_layout]]>
 // CHECK: [[store_mask:%.*]] = triton_gpu.convert_layout {{.*}} -> tensor<64x64xi1, [[col_layout]]>
 // CHECK: tt.store [[store_ptr]], [[store_val]], [[store_mask]]
-func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32},
+func.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32},
                 %arg1: i32 {tt.divisibility = 16 : i32},
                 %arg2: !tt.ptr {tt.divisibility = 16 : i32},
                 %arg3: i32 {tt.divisibility = 16 : i32}) {
diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir
index 2c009ffa48..7e9cb9d504 100644
--- a/test/TritonGPU/combine.mlir
+++ b/test/TritonGPU/combine.mlir
@@ -9,7 +9,7 @@
 // CHECK: [[col_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
 // CHECK: [[col_layout_novec:#.*]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}>
 // CHECK-LABEL: cst
-func @cst() -> tensor<1024xi32, #layout1> {
+func.func @cst() -> tensor<1024xi32, #layout1> {
   %cst = arith.constant dense<0> : tensor<1024xi32, #layout0>
   %1 = triton_gpu.convert_layout %cst : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
   // CHECK-NOT: triton_gpu.convert_layout
@@ -18,7 +18,7 @@ func @cst() -> tensor<1024xi32, #layout1> {
 }
 
 // CHECK-LABEL: range
-func @range() -> tensor<1024xi32, #layout1> {
+func.func @range() -> tensor<1024xi32, #layout1> {
   %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
   %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
   // CHECK-NOT: triton_gpu.convert_layout
@@ -27,7 +27,7 @@ func @range() -> tensor<1024xi32, #layout1> {
 }
 
 // CHECK-LABEL: splat
-func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
+func.func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
   %0 = tt.splat %arg0 : (i32) -> tensor<1024xi32, #layout0>
   %1 = triton_gpu.convert_layout %0 : (tensor<1024xi32, #layout0>) -> tensor<1024xi32, #layout1>
   // CHECK-NOT: triton_gpu.convert_layout
@@ -36,7 +36,7 @@ func @splat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 }
 
 // CHECK-LABEL: remat
-func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
+func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
   %0 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
   %1 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #layout0>
   %2 = arith.muli %0, %1 : tensor<1024xi32, #layout0>
@@ -56,7 +56,7 @@ func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
 }
 
 // CHECK-LABEL: remat_load_store
-func @remat_load_store(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @remat_load_store(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
   %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
   %1 = tt.splat %arg : (!tt.ptr) -> tensor<64x!tt.ptr, #layout0>
   %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr, #layout0>, tensor<64xi32, #layout0>
@@ -70,7 +70,7 @@ func @remat_load_store(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
 
 // Don't rematerialize vectorized loads
 // CHECK-LABEL: remat_expensive
-func @remat_expensive(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @remat_expensive(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
   %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
   %1 = tt.splat %arg : (!tt.ptr) -> tensor<64x!tt.ptr, #layout1>
   %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr, #layout1>, tensor<64xi32, #layout1>
@@ -85,7 +85,7 @@ func @remat_expensive(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
 
 // Don't rematerialize loads when original and target layouts are different
 // CHECK-LABEL: remat_multi_layout
-func @remat_multi_layout(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @remat_multi_layout(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
   %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
   %1 = tt.splat %arg : (!tt.ptr) -> tensor<64x!tt.ptr, #layout0>
   %2 = tt.addptr %1, %0 : tensor<64x!tt.ptr, #layout0>, tensor<64xi32, #layout0>
@@ -100,7 +100,7 @@ func @remat_multi_layout(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
 
 // Always rematerialize single value loads
 // CHECK-LABEL: remat_single_value
-func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
   %0 = tt.splat %arg : (!tt.ptr) -> tensor<1x!tt.ptr, #layout1>
   %1 = tt.load %0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1xi32, #layout1>
   // CHECK-NOT: triton_gpu.convert_layout
@@ -111,7 +111,7 @@ func @remat_single_value(%arg: !tt.ptr {tt.divisibility = 16 : i32}) {
 }
 
 // CHECK-LABEL: if
-func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
   // CHECK-NOT: triton_gpu.convert_layout
   %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout1>
   %0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -128,7 +128,7 @@ func @if(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
 }
 
 // CHECK-LABEL: if_convert_else_not
-func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
   %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
   %0 = tt.get_program_id {axis = 0 : i32} : i32
   %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -149,7 +149,7 @@ func @if_convert_else_not(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16
 }
 
 // CHECK-LABEL: if_not_else_convert
-func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
   %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
   %0 = tt.get_program_id {axis = 0 : i32} : i32
   %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -170,7 +170,7 @@ func @if_not_else_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16
 }
 
 // CHECK-LABEL: if_else_both_convert
-func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
+func.func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) {
   %c32_i32 = arith.constant dense<32> : tensor<1024xi32, #layout0>
   %0 = tt.get_program_id {axis = 0 : i32} : i32
   %1 = tt.splat %0 : (i32) -> tensor<1024xi32, #layout0>
@@ -200,7 +200,7 @@ func @if_else_both_convert(%arg0: i32, %arg1: !tt.ptr {tt.divisibility = 16
 #blocked4 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [0, 1]}>
 
 // CHECK-LABEL: transpose
-func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
+func.func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}) {
   // CHECK-NOT: triton_gpu.convert_layout
   // CHECK: [[loaded_val:%.*]] = tt.load {{.*}}, {{%cst.*}}, {{%cst.*}} {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf32, [[row_layout]]>
   // CHECK: [[cvt_val:%.*]] = triton_gpu.convert_layout [[loaded_val]] : (tensor<64x64xf32, [[row_layout]]>) -> tensor<64x64xf32, [[col_layout]]>
@@ -241,7 +241,7 @@ func @transpose(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 {tt
 }
 
 // CHECK-LABEL: loop
-func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) {
+func.func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %arg4: i32) {
   // CHECK-NOT: triton_gpu.convert_layout
   // CHECK: [[loop_ret:%.*]]:2 = scf.for {{.*}} -> (tensor<64x64xf32, [[row_layout]]>, tensor<64x64x!tt.ptr, [[row_layout]]>)
   // CHECK-NEXT: {{.*}} = tt.load {{.*}} : tensor<64x64xf32, [[row_layout]]>
@@ -295,7 +295,7 @@ func @loop(%arg0: !tt.ptr, %arg1: i32, %arg2: !tt.ptr, %arg3: i32, %ar
 }
 
 // CHECK-LABEL: vecadd
-func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) {
+func.func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32) {
   // CHECK-NOT: triton_gpu.convert_layout
   %c256_i32 = arith.constant 256 : i32
   %0 = tt.get_program_id {axis = 0 : i32} : i32
@@ -327,7 +327,7 @@ func @vecadd(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
 
 // CHECK-LABEL: select
-func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
+func.func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}) {
   // CHECK-NOT: triton_gpu.convert_layout
   %cst = arith.constant dense<30000> : tensor<1x1xi32, #blocked2>
   %cst_0 = arith.constant dense<30000> : tensor<1x512xi32, #blocked2>
@@ -378,7 +378,7 @@ func @select(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr
 
 // CHECK-LABEL: long_func
-func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: !tt.ptr {tt.divisibility = 16 : i32}, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: !tt.ptr {tt.divisibility = 16 : i32}, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
+func.func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: !tt.ptr {tt.divisibility = 16 : i32}, %arg7: !tt.ptr {tt.divisibility = 16 : i32}, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: !tt.ptr {tt.divisibility = 16 : i32}, %arg10: !tt.ptr {tt.divisibility = 16 : i32}, %arg11: !tt.ptr {tt.divisibility = 16 : i32}, %arg12: !tt.ptr {tt.divisibility = 16 : i32}, %arg13: !tt.ptr {tt.divisibility = 16 : i32}, %arg14: !tt.ptr {tt.divisibility = 16 : i32}, %arg15: !tt.ptr {tt.divisibility = 16 : i32}, %arg16: i32 {tt.divisibility = 16 : i32}) {
   %cst = arith.constant dense<1.000000e+00> : tensor<1024xf32, #blocked0>
   %cst_0 = arith.constant dense<5.000000e-04> : tensor<1024xf32, #blocked0>
   %cst_1 = arith.constant dense<0.999499976> : tensor<1024xf32, #blocked0>
@@ -775,7 +775,7 @@ func public @long_func(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1:
 
 // A mnist model from torch inductor.
 // Check if topological sort is working correct and there's no unnecessary convert
 // CHECK-LABEL: mnist
-func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
+func.func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: i32) {
   // CHECK-NOT: triton_gpu.convert_layout
   %cst = arith.constant dense<10> : tensor<16x1xi32, #blocked2>
   %cst_0 = arith.constant dense<10> : tensor<1x16xi32, #blocked3>
@@ -862,7 +862,7 @@ func public @mnist(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.
 #blocked5 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 
 // cmpf and cmpi have different operands and result types
 // CHECK-LABEL: cmp
-func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
+func.func public @cmp(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}) {
   %c64 = arith.constant 64 : index
   %c2048 = arith.constant 2048 : index
   %c0 = arith.constant 0 : index
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
index 6ee3b15fbc..663f2da7b0 100644
--- a/test/TritonGPU/loop-pipeline.mlir
+++ b/test/TritonGPU/loop-pipeline.mlir
@@ -10,7 +10,7 @@
 #A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
 #B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
 
-// CHECK: func @matmul_loop
+// CHECK: func.func @matmul_loop
 // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -46,8 +46,8 @@
 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
-func @matmul_loop(%lb : index, %ub : index, %step : index,
-                  %A : !tt.ptr {tt.divisibility = 16 : i32},
+func.func @matmul_loop(%lb : index, %ub : index, %step : index,
+                  %A : !tt.ptr {tt.divisibility = 16 : i32},
                   %B : !tt.ptr {tt.divisibility = 16 : i32}) {
   // A ptrs
   %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL>
@@ -61,7 +61,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
   %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
   %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
   %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL>
-
+
   %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
   %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
@@ -88,7 +88,7 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
 }
 
-// CHECK: func @matmul_loop_nested
+// CHECK: func.func @matmul_loop_nested
 // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -118,8 +118,8 @@ func @matmul_loop(%lb : index, %ub : index, %step : index,
 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
-func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
-                         %A : !tt.ptr {tt.divisibility = 16 : i32},
+func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
+                         %A : !tt.ptr {tt.divisibility = 16 : i32},
                          %B : !tt.ptr {tt.divisibility = 16 : i32}) {
   scf.for %iv0 = %lb to %ub step %step {
     // A ptrs
@@ -134,7 +134,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
     %b_tmp1 = tt.expand_dims %b_tmp0 {axis = 0 : i32} : (tensor<128xi32, #BLs0>) -> tensor<1x128xi32, #BL>
     %b_offs = tt.broadcast %b_tmp1 : (tensor<1x128xi32, #BL>) -> tensor<32x128xi32, #BL>
     %b_ptr_init = tt.addptr %b_ptr_splat, %b_offs : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL>
-
+
     %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
     %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
     %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
@@ -161,7 +161,7 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
 }
 
-// CHECK: func @matmul_loop_single_pipeline
+// CHECK: func.func @matmul_loop_single_pipeline
 // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
 // CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
 // CHECK-DAG: %[[CONSTANT_2:.*]] = arith.constant 2 : i32
@@ -183,8 +183,8 @@ func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
 // CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
 // CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
 // CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
-func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
-                                  %A : !tt.ptr {tt.divisibility = 16 : i32},
+func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
+                                  %A : !tt.ptr {tt.divisibility = 16 : i32},
                                   %B : !tt.ptr {tt.divisibility = 16 : i32}) {
   // A ptrs
   %a_ptr_splat = tt.splat %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL>
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
index 9bd5318e1e..01dc3f0ab1 100644
--- a/test/TritonGPU/matmul.mlir
+++ b/test/TritonGPU/matmul.mlir
@@ -4,7 +4,7 @@
 // CHECK: offset = 49152, size = 49152
 // CHECK: size = 98304
 module {
-func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
+func.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c64_13c64_14c64_15c8(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32 {tt.divisibility = 16 : i32}, %arg11: i32) {
   %cst = arith.constant dense<true> : tensor<64x64xi1>
   %c64 = arith.constant 64 : index
   %c0 = arith.constant 0 : index
@@ -22,7 +22,7 @@ func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__12c6
   %7 = arith.muli %6, %c8_i32 : i32
   %8 = arith.subi %2, %7 : i32
   %9 = arith.cmpi slt, %8, %c8_i32 : i32
-  %10 = select %9, %8, %c8_i32 : i32
+  %10 = arith.select %9, %8, %c8_i32 : i32
   %11 = arith.remsi %0, %10 : i32
   %12 = arith.addi %7, %11 : i32
   %13 = arith.remsi %0, %5 : i32
diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir
index 52b4dddec1..b427547890 100644
--- a/test/TritonGPU/prefetch.mlir
+++ b/test/TritonGPU/prefetch.mlir
@@ -11,7 +11,7 @@
 #B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
 
-// CHECK: func @matmul_loop
+// CHECK: func.func @matmul_loop
 // CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16]
 // CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
 // CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128]
@@ -28,7 +28,7 @@
 // CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128]
 // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
 // CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
-func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) {
+func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr) {
   %a_ptr_init = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL>
   %b_ptr_init = tt.broadcast %B : (!tt.ptr) -> tensor<32x128x!tt.ptr, #BL>
diff --git a/test/TritonGPU/update-mma-for-volta.mlir b/test/TritonGPU/update-mma-for-volta.mlir
index d587fffcca..7571ec6185 100644
--- a/test/TritonGPU/update-mma-for-volta.mlir
+++ b/test/TritonGPU/update-mma-for-volta.mlir
@@ -15,7 +15,7 @@
 // CHECK: [[new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
 module attributes {"triton_gpu.num-warps" = 16 : i32} {
   // CHECK-LABEL: dot_mmav1
-  func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
+  func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
     %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
     %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
     %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
@@ -50,7 +50,7 @@ module attributes {"triton_gpu.num-warps" = 16 : i32} {
 
 module attributes {"triton_gpu.num-warps" = 16 : i32} {
   // CHECK-LABEL: dot_mmav1
-  func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
+  func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
     %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
     %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
     %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
diff --git a/test/lib/Analysis/TestAlias.cpp b/test/lib/Analysis/TestAlias.cpp
index 88a4118fe9..3fd0cfd0d3 100644
--- a/test/lib/Analysis/TestAlias.cpp
+++ b/test/lib/Analysis/TestAlias.cpp
@@ -9,10 +9,10 @@ using namespace mlir;
 namespace {
 
 struct TestAliasPass
-    : public PassWrapper> {
+    : public PassWrapper> {
+
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
 
-  // LLVM15+
-  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAliasPass);
   static void print(StringRef name, SmallVector &vals, raw_ostream &os) {
     if (vals.empty())
@@ -39,23 +39,24 @@ struct TestAliasPass
     auto opName = SymbolTable::getSymbolName(operation).getValue().str();
     os << opName << "\n";
 
-    SharedMemoryAliasAnalysis analysis(&getContext());
-    analysis.run(operation);
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    SharedMemoryAliasAnalysis *analysis =
+        solver->load<SharedMemoryAliasAnalysis>();
+    if (failed(solver->initializeAndRun(operation)))
+      return signalPassFailure();
 
     AsmState state(operation->getParentOfType<ModuleOp>());
     // Get operation ids of value's aliases
     auto getAllocOpNames = [&](Value value) {
-      LatticeElement<AliasInfo> *latticeElement =
-          analysis.lookupLatticeElement(value);
+      dataflow::Lattice<AliasInfo> *latticeElement =
+          analysis->getLatticeElement(value);
       SmallVector opNames;
-      if (latticeElement) {
+      if (latticeElement && !latticeElement->isUninitialized()) {
         auto &info = latticeElement->getValue();
-        if (!info.getAllocs().empty()) {
-          for (auto &alias : info.getAllocs()) {
-            auto opName =
-                getValueOperandName(alias.getDefiningOp()->getResult(0), state);
-            opNames.push_back(std::move(opName));
-          }
+        for (auto &alias : info.getAllocs()) {
+          auto opName =
+              getValueOperandName(alias.getDefiningOp()->getResult(0), state);
+          opNames.push_back(std::move(opName));
         }
       }
       // Ensure deterministic output
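// ---------------------------------------------------------------------------
// [Editor's illustration; not part of the patch.] The TestAlias hunk above is
// representative of the analysis-framework migration in this PR: under
// LLVM-14 an analysis was a standalone object ("SharedMemoryAliasAnalysis
// analysis(&getContext()); analysis.run(op);") queried via
// lookupLatticeElement(); under LLVM-15 it is loaded into an
// mlir::DataFlowSolver and queried via getLatticeElement(), which can now
// hand back an uninitialized lattice. A minimal sketch of the new driver,
// assuming createDataFlowSolver() is the helper this patch adds to
// triton/Analysis/Utility.h; useAliasInfo() is a hypothetical consumer:
//
//   std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
//   SharedMemoryAliasAnalysis *analysis =
//       solver->load<SharedMemoryAliasAnalysis>();
//   if (failed(solver->initializeAndRun(getOperation())))
//     return signalPassFailure();
//   // Lattices are created lazily: a value the solver never visited yields
//   // a null or uninitialized lattice, so guard both cases before reading.
//   dataflow::Lattice<AliasInfo> *lattice = analysis->getLatticeElement(v);
//   if (lattice && !lattice->isUninitialized())
//     useAliasInfo(lattice->getValue());
// ---------------------------------------------------------------------------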
diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp
index 84108c4d36..35e42242bd 100644
--- a/test/lib/Analysis/TestAllocation.cpp
+++ b/test/lib/Analysis/TestAllocation.cpp
@@ -6,10 +6,9 @@ using namespace mlir;
 namespace {
 
 struct TestAllocationPass
-    : public PassWrapper> {
+    : public PassWrapper> {
 
-  // LLVM15+
-  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
 
   StringRef getArgument() const final { return "test-print-allocation"; }
 
   StringRef getDescription() const final {
diff --git a/test/lib/Analysis/TestAxisInfo.cpp b/test/lib/Analysis/TestAxisInfo.cpp
index a5205bb0a0..22347c32f0 100644
--- a/test/lib/Analysis/TestAxisInfo.cpp
+++ b/test/lib/Analysis/TestAxisInfo.cpp
@@ -1,25 +1,15 @@
 #include "mlir/Pass/Pass.h"
 #include "triton/Analysis/AxisInfo.h"
+#include "triton/Analysis/Utility.h"
 
 using namespace mlir;
 
 namespace {
 
 struct TestAxisInfoPass
-    : public PassWrapper> {
+    : public PassWrapper> {
 
-  // LLVM15+
-  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
-
-  void print(const std::string &name, raw_ostream &os, ArrayRef vals) {
-    os << name << ": [";
-    for (size_t d = 0; d < vals.size(); d++) {
-      if (d != 0)
-        os << ", ";
-      os << vals[d];
-    }
-    os << "]";
-  }
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
 
   StringRef getArgument() const final { return "test-print-alignment"; }
 
   StringRef getDescription() const final {
@@ -30,38 +20,19 @@ struct TestAxisInfoPass
     Operation *operation = getOperation();
     auto &os = llvm::errs();
     auto opName = SymbolTable::getSymbolName(operation).getValue().str();
-    os << opName << "\n";
-    AxisInfoAnalysis analysis(&getContext());
-    analysis.run(operation);
+    os << "@" << opName << "\n";
+
+    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
+    AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
+    if (failed(solver->initializeAndRun(operation)))
+      return signalPassFailure();
+
     operation->walk([&](Operation *op) {
       if (op->getNumResults() < 1)
         return;
       for (Value result : op->getResults()) {
-        // std::ostringstream oss;
-        // result.print(oss);
-        // os << " => ";
-        LatticeElement<AxisInfo> *latticeElement =
-            analysis.lookupLatticeElement(result);
-        if (!latticeElement) {
-          os << "None\n";
-          return;
-        }
-        AxisInfo &info = latticeElement->getValue();
-        print("Contiguity", os, info.getContiguity());
-        os << " ; ";
-        print("Divisibility", os, info.getDivisibility());
-        os << " ; ";
-        print("Constancy", os, info.getConstancy());
-        os << " ; ";
-        auto constantValue = info.getConstantValue();
-        os << "ConstantValue: [";
-        if (constantValue.has_value())
-          os << constantValue.value();
-        else
-          os << "None";
-        os << "] ( ";
         result.print(os);
-        os << " ) ";
+        os << " => ";
+        analysis->getLatticeElement(result)->getValue().print(os);
         os << "\n";
       }
     });
diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp
index df4279fe24..ab9b9f3fb7 100644
--- a/test/lib/Analysis/TestMembar.cpp
+++ b/test/lib/Analysis/TestMembar.cpp
@@ -1,4 +1,4 @@
-#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/Pass/Pass.h"
 #include "triton/Analysis/Allocation.h"
@@ -9,10 +9,9 @@ using namespace mlir;
 namespace {
 
 struct TestMembarPass
-    : public PassWrapper> {
+    : public PassWrapper> {
 
-  // LLVM15+
-  // MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
 
   StringRef getArgument() const final { return "test-print-membar"; }
 
   StringRef getDescription() const final {
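// ---------------------------------------------------------------------------
// [Editor's illustration; not part of the patch.] A second recurring change
// in these test files is that the previously commented-out "LLVM15+" TypeID
// macro is now active. A minimal LLVM-15 test-pass skeleton in the shape used
// above; TestFooPass is a made-up name, and the OperationPass<ModuleOp> base
// is an assumption, since the PassWrapper template arguments do not survive
// in the hunks as rendered here:
//
//   struct TestFooPass
//       : public PassWrapper<TestFooPass, OperationPass<ModuleOp>> {
//     MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooPass);
//
//     StringRef getArgument() const final { return "test-print-foo"; }
//     StringRef getDescription() const final { return "print foo"; }
//     void runOnOperation() override { /* walk getOperation() and print */ }
//   };
// ---------------------------------------------------------------------------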