diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 1a8e8737f7fed..71aaee931d543 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); +/// Returns the number of successor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block); + +/// Returns `pos`-th successor of the block. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block, + intptr_t pos); + +/// Returns the number of predecessor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block); + +/// Returns `pos`-th predecessor of the block. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block, + intptr_t pos); + //===----------------------------------------------------------------------===// // Value API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cbd35f2974ae9..eb8bad36e6312 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2626,6 +2626,85 @@ class PyOpSuccessors : public Sliceable { PyOperationRef operation; }; +/// A list of block successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation and block whose successors these are, and thus +/// extends the lifetime of this operation and block. +class PyBlockSuccessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockSuccessors"; + + PyBlockSuccessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumSuccessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumSuccessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyBlockSuccessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + +/// A list of block predecessors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) predecessor list is +/// associated with the operation and block whose predecessors these are, and +/// thus extends the lifetime of this operation and block. +class PyBlockPredecessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockPredecessors"; + + PyBlockPredecessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumPredecessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumPredecessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockPredecessors slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyBlockPredecessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + /// A list of operation attributes. Can be indexed by name, producing /// attributes, or by index, producing named attributes. class PyOpAttributeMap { @@ -3655,7 +3734,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("operation"), "Appends an operation to this block. If the operation is currently " - "in another block, it will be moved."); + "in another block, it will be moved.") + .def_prop_ro( + "successors", + [](PyBlock &self) { + return PyBlockSuccessors(self, self.getParentOperation()); + }, + "Returns the list of Block successors.") + .def_prop_ro( + "predecessors", + [](PyBlock &self) { + return PyBlockPredecessors(self, self.getParentOperation()); + }, + "Returns the list of Block predecessors."); //---------------------------------------------------------------------------- // Mapping of PyInsertionPoint. @@ -4099,6 +4190,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyBlockArgumentList::bind(m); PyBlockIterator::bind(m); PyBlockList::bind(m); + PyBlockSuccessors::bind(m); + PyBlockPredecessors::bind(m); PyOperationIterator::bind(m); PyOperationList::bind(m); PyOpAttributeMap::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e0e386d55ede1..fbc66bcf5c2d0 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, unwrap(block)->print(stream); } +intptr_t mlirBlockGetNumSuccessors(MlirBlock block) { + return static_cast(unwrap(block)->getNumSuccessors()); +} + +MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) { + return wrap(unwrap(block)->getSuccessor(static_cast(pos))); +} + +intptr_t mlirBlockGetNumPredecessors(MlirBlock block) { + Block *b = unwrap(block); + return static_cast(std::distance(b->pred_begin(), b->pred_end())); +} + +MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) { + Block *b = unwrap(block); + Block::pred_iterator it = b->pred_begin(); + std::advance(it, pos); + return wrap(*it); +} + //===----------------------------------------------------------------------===// // Value API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py index 70ccaeeb5435b..ced5fce434728 100644 --- a/mlir/test/python/ir/blocks.py +++ b/mlir/test/python/ir/blocks.py @@ -1,12 +1,11 @@ # RUN: %PYTHON %s | FileCheck %s import gc -import io -import itertools -from mlir.ir import * + from mlir.dialects import builtin from mlir.dialects import cf from mlir.dialects import func +from mlir.ir import * def run(f): @@ -54,10 +53,25 @@ def testBlockCreation(): with InsertionPoint(middle_block) as middle_ip: assert middle_ip.block == middle_block cf.BranchOp([i32_arg], dest=successor_block) + module.print(enable_debug_info=True) # Ensure region back references are coherent. assert entry_block.region == middle_block.region == successor_block.region + assert len(entry_block.predecessors) == 0 + + assert len(entry_block.successors) == 1 + assert middle_block == entry_block.successors[0] + assert len(middle_block.predecessors) == 1 + assert entry_block == middle_block.predecessors[0] + + assert len(middle_block.successors) == 1 + assert successor_block == middle_block.successors[0] + assert len(successor_block.predecessors) == 1 + assert middle_block == successor_block.predecessors[0] + + assert len(successor_block.successors) == 0 + # CHECK-LABEL: TEST: testBlockCreationArgLocs @run pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy