Skip to content

[mlir][python] bind block predecessors and successors #145116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

makslevental
Copy link
Contributor

@makslevental makslevental commented Jun 20, 2025

bind block.getSuccessor and block.getPredecessors.

@makslevental makslevental force-pushed the makslevental/py-block-successors branch 7 times, most recently from 173c5c3 to 2160339 Compare June 20, 2025 23:58
@makslevental makslevental marked this pull request as ready for review June 21, 2025 00:34
@llvmbot llvmbot added the mlir label Jun 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 21, 2025

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

bind block-getSuccessor.


Full diff: https://github.com/llvm/llvm-project/pull/145116.diff

4 Files Affected:

  • (modified) mlir/include/mlir-c/IR.h (+7)
  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+47-1)
  • (modified) mlir/lib/CAPI/IR/IR.cpp (+8)
  • (modified) mlir/test/python/ir/blocks.py (+14-3)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 1a8e8737f7fed..30763c0c8c052 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -986,6 +986,13 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
 MLIR_CAPI_EXPORTED void
 mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);
 
+/// Returns the number of successor blocks of the block.
+MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
+
+/// Returns `pos`-th successor of the block.
+MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block,
+                                                   intptr_t pos);
+
 //===----------------------------------------------------------------------===//
 // Value API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index cbd35f2974ae9..6f9db50e2aaa1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2626,6 +2626,45 @@ class PyOpSuccessors : public Sliceable<PyOpSuccessors, PyBlock> {
   PyOperationRef operation;
 };
 
+/// A list of block successors. Internally, these are stored as consecutive
+/// elements, random access is cheap. The (returned) successor list is
+/// associated with the operation and block whose successors these are, and thus
+/// extends the lifetime of this operation and block.
+class PyBlockSuccessors : public Sliceable<PyBlockSuccessors, PyBlock> {
+public:
+  static constexpr const char *pyClassName = "BlockSuccessors";
+
+  PyBlockSuccessors(PyBlock block, PyOperationRef operation,
+                    intptr_t startIndex = 0, intptr_t length = -1,
+                    intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirBlockGetNumSuccessors(block.get())
+                               : length,
+                  step),
+        operation(operation), block(block) {}
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyBlockSuccessors, PyBlock>;
+
+  intptr_t getRawNumElements() {
+    block.checkValid();
+    return mlirBlockGetNumSuccessors(block.get());
+  }
+
+  PyBlock getRawElement(intptr_t pos) {
+    MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos);
+    return PyBlock(operation, block);
+  }
+
+  PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyBlockSuccessors(block, operation, startIndex, length, step);
+  }
+
+  PyOperationRef operation;
+  PyBlock block;
+};
+
 /// A list of operation attributes. Can be indexed by name, producing
 /// attributes, or by index, producing named attributes.
 class PyOpAttributeMap {
@@ -3655,7 +3694,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
           },
           nb::arg("operation"),
           "Appends an operation to this block. If the operation is currently "
-          "in another block, it will be moved.");
+          "in another block, it will be moved.")
+      .def_prop_ro(
+          "successors",
+          [](PyBlock &self) {
+            return PyBlockSuccessors(self, self.getParentOperation());
+          },
+          "Returns the list of Block successors.");
 
   //----------------------------------------------------------------------------
   // Mapping of PyInsertionPoint.
@@ -4099,6 +4144,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
   PyBlockArgumentList::bind(m);
   PyBlockIterator::bind(m);
   PyBlockList::bind(m);
+  PyBlockSuccessors::bind(m);
   PyOperationIterator::bind(m);
   PyOperationList::bind(m);
   PyOpAttributeMap::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e0e386d55ede1..7fa831da3a4d8 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1059,6 +1059,14 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback,
   unwrap(block)->print(stream);
 }
 
+intptr_t mlirBlockGetNumSuccessors(MlirBlock block) {
+  return static_cast<intptr_t>(unwrap(block)->getNumSuccessors());
+}
+
+MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) {
+  return wrap(unwrap(block)->getSuccessor(static_cast<unsigned>(pos)));
+}
+
 //===----------------------------------------------------------------------===//
 // Value API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/blocks.py b/mlir/test/python/ir/blocks.py
index 70ccaeeb5435b..551b2181b6552 100644
--- a/mlir/test/python/ir/blocks.py
+++ b/mlir/test/python/ir/blocks.py
@@ -1,12 +1,11 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
-import io
-import itertools
-from mlir.ir import *
+
 from mlir.dialects import builtin
 from mlir.dialects import cf
 from mlir.dialects import func
+from mlir.ir import *
 
 
 def run(f):
@@ -54,10 +53,22 @@ def testBlockCreation():
             with InsertionPoint(middle_block) as middle_ip:
                 assert middle_ip.block == middle_block
                 cf.BranchOp([i32_arg], dest=successor_block)
+
         module.print(enable_debug_info=True)
         # Ensure region back references are coherent.
         assert entry_block.region == middle_block.region == successor_block.region
 
+        entry_block_successors = entry_block.successors
+        assert len(entry_block_successors) == 1
+        assert middle_block == entry_block_successors[0]
+
+        middle_block_successors = middle_block.successors
+        assert len(middle_block_successors) == 1
+        assert successor_block == middle_block_successors[0]
+
+        successor_block_successors = successor_block.successors
+        assert len(successor_block_successors) == 0
+
 
 # CHECK-LABEL: TEST: testBlockCreationArgLocs
 @run

@makslevental makslevental force-pushed the makslevental/py-block-successors branch 2 times, most recently from 04ef6e4 to c49fd8a Compare June 21, 2025 03:09
@makslevental makslevental changed the title [mlir][python] bind block successors [mlir][python] bind block predecessors and successors Jun 21, 2025
@makslevental makslevental force-pushed the makslevental/py-block-successors branch from c49fd8a to 69186b6 Compare June 21, 2025 03:15
@makslevental makslevental force-pushed the makslevental/py-block-successors branch from 69186b6 to bedc679 Compare June 21, 2025 03:40
@@ -986,6 +986,20 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block,
MLIR_CAPI_EXPORTED void
mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData);

/// Returns the number of successor blocks of the block.
MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's have tests for the C API as well

Comment on lines +1077 to +1079
Block::pred_iterator it = b->pred_begin();
std::advance(it, pos);
return wrap(*it);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather avoid iterating over the use-def list every time... This goes through block's use-def chain, maybe there is a way to expose a BlockOperand (and incidentally OpOperand if it isn't) and a getNextUse.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy