Skip to content

Commit d0f8cdc

Browse files
committed
make .NET objects that have __call__ method callable from Python
Implemented by adding tp_call to ClassBase, that uses reflection to find __call__ methods in .NET, and falls back to invoking __call__ method from Python base classes. fixes #890 this is an amalgamation of d46878c, 5bb1007, and 960457f from https://github.com/losttech/pythonnet
1 parent c9626df commit d0f8cdc

File tree

7 files changed

+225
-2
lines changed

7 files changed

+225
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ See [Mixins/collections.py](src/runtime/Mixins/collections.py).
2323
- .NET arrays implement Python buffer protocol
2424
- Python.NET will correctly resolve .NET methods, that accept `PyList`, `PyInt`,
2525
and other `PyObject` derived types when called from Python.
26+
- .NET classes, that have `__call__` method are callable from Python
2627
- `PyIterable` type, that wraps any iterable object in Python
2728

2829

src/embed_tests/CallableObject.cs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
using System;
2+
using System.Collections.Generic;
3+
4+
using NUnit.Framework;
5+
6+
using Python.Runtime;
7+
8+
namespace Python.EmbeddingTest
9+
{
10+
public class CallableObject
11+
{
12+
[OneTimeSetUp]
13+
public void SetUp()
14+
{
15+
PythonEngine.Initialize();
16+
using (Py.GIL())
17+
{
18+
using var locals = new PyDict();
19+
PythonEngine.Exec(CallViaInheritance.BaseClassSource, locals: locals.Handle);
20+
CustomBaseTypeProvider.BaseClass = new PyType(locals[CallViaInheritance.BaseClassName]);
21+
PythonEngine.InteropConfiguration.PythonBaseTypeProviders.Add(new CustomBaseTypeProvider());
22+
}
23+
}
24+
25+
[OneTimeTearDown]
26+
public void Dispose()
27+
{
28+
PythonEngine.Shutdown();
29+
}
30+
[Test]
31+
public void CallMethodMakesObjectCallable()
32+
{
33+
var doubler = new DerivedDoubler();
34+
using (Py.GIL())
35+
{
36+
dynamic applyObjectTo21 = PythonEngine.Eval("lambda o: o(21)");
37+
Assert.AreEqual(doubler.__call__(21), (int)applyObjectTo21(doubler.ToPython()));
38+
}
39+
}
40+
[Test]
41+
public void CallMethodCanBeInheritedFromPython()
42+
{
43+
var callViaInheritance = new CallViaInheritance();
44+
using (Py.GIL())
45+
{
46+
dynamic applyObjectTo14 = PythonEngine.Eval("lambda o: o(14)");
47+
Assert.AreEqual(callViaInheritance.Call(14), (int)applyObjectTo14(callViaInheritance.ToPython()));
48+
}
49+
}
50+
51+
[Test]
52+
public void CanOverwriteCall()
53+
{
54+
var callViaInheritance = new CallViaInheritance();
55+
using var _ = Py.GIL();
56+
using var scope = Py.CreateScope();
57+
scope.Set("o", callViaInheritance);
58+
scope.Exec("orig_call = o.Call");
59+
scope.Exec("o.Call = lambda a: orig_call(a*7)");
60+
int result = scope.Eval<int>("o.Call(5)");
61+
Assert.AreEqual(105, result);
62+
}
63+
64+
class Doubler
65+
{
66+
public int __call__(int arg) => 2 * arg;
67+
}
68+
69+
class DerivedDoubler : Doubler { }
70+
71+
class CallViaInheritance
72+
{
73+
public const string BaseClassName = "Forwarder";
74+
public static readonly string BaseClassSource = $@"
75+
class MyCallableBase:
76+
def __call__(self, val):
77+
return self.Call(val)
78+
79+
class {BaseClassName}(MyCallableBase): pass
80+
";
81+
public int Call(int arg) => 3 * arg;
82+
}
83+
84+
class CustomBaseTypeProvider : IPythonBaseTypeProvider
85+
{
86+
internal static PyType BaseClass;
87+
88+
public IEnumerable<PyType> GetBaseTypes(Type type, IList<PyType> existingBases)
89+
{
90+
Assert.Greater(BaseClass.Refcount, 0);
91+
return type != typeof(CallViaInheritance)
92+
? existingBases
93+
: new[] { BaseClass };
94+
}
95+
}
96+
}
97+
}

src/runtime/classbase.cs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System;
22
using System.Collections;
33
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Reflection;
46
using System.Runtime.InteropServices;
57

68
namespace Python.Runtime
@@ -557,5 +559,91 @@ public static int mp_ass_subscript(IntPtr ob, IntPtr idx, IntPtr v)
557559

558560
return 0;
559561
}
562+
563+
static IntPtr tp_call_impl(IntPtr ob, IntPtr args, IntPtr kw)
564+
{
565+
IntPtr tp = Runtime.PyObject_TYPE(ob);
566+
var self = (ClassBase)GetManagedObject(tp);
567+
568+
if (!self.type.Valid)
569+
{
570+
return Exceptions.RaiseTypeError(self.type.DeletedMessage);
571+
}
572+
573+
Type type = self.type.Value;
574+
575+
var calls = GetCallImplementations(type).ToList();
576+
if (calls.Count > 0)
577+
{
578+
var callBinder = new MethodBinder();
579+
foreach (MethodInfo call in calls)
580+
{
581+
callBinder.AddMethod(call);
582+
}
583+
return callBinder.Invoke(ob, args, kw);
584+
}
585+
586+
return InvokeCallInheritedFromPython(new BorrowedReference(ob), args, kw);
587+
}
588+
589+
static IEnumerable<MethodInfo> GetCallImplementations(Type type)
590+
=> type.GetMethods(BindingFlags.Public | BindingFlags.Instance)
591+
.Where(m => m.Name == "__call__");
592+
593+
/// <summary>
594+
/// Find bases defined in Python and use their __call__ if any
595+
/// </summary>
596+
static IntPtr InvokeCallInheritedFromPython(BorrowedReference ob, IntPtr args, IntPtr kw)
597+
{
598+
BorrowedReference tp = Runtime.PyObject_TYPE(ob);
599+
using var super = new PyObject(new BorrowedReference(Runtime.PySuper_Type));
600+
using var pyInst = new PyObject(ob);
601+
602+
BorrowedReference mro = PyType.GetMRO(tp);
603+
nint mroLen = Runtime.PyTuple_Size(mro);
604+
for (int baseIndex = 0; baseIndex < mroLen - 1; baseIndex++)
605+
{
606+
BorrowedReference @base = Runtime.PyTuple_GetItem(mro, baseIndex);
607+
if (!IsManagedType(@base)) continue;
608+
609+
BorrowedReference nextBase = Runtime.PyTuple_GetItem(mro, baseIndex + 1);
610+
if (ManagedType.IsManagedType(nextBase)) continue;
611+
612+
// call via super
613+
using var managedBase = new PyObject(@base);
614+
using var superInstance = super.Invoke(managedBase, pyInst);
615+
using var call = Runtime.PyObject_GetAttrString(superInstance.Reference, "__call__");
616+
if (call.IsNull())
617+
{
618+
if (Exceptions.ExceptionMatches(Exceptions.AttributeError))
619+
{
620+
Runtime.PyErr_Clear();
621+
continue;
622+
}
623+
else
624+
{
625+
return IntPtr.Zero;
626+
}
627+
}
628+
629+
return Runtime.PyObject_Call(call.DangerousGetAddress(), args, kw);
630+
}
631+
632+
Exceptions.SetError(Exceptions.TypeError, "object is not callable");
633+
return IntPtr.Zero;
634+
}
635+
636+
static readonly Interop.TernaryFunc tp_call_delegate = tp_call_impl;
637+
638+
public virtual void InitializeSlots(SlotsHolder slotsHolder)
639+
{
640+
if (!this.type.Valid) return;
641+
642+
if (GetCallImplementations(this.type.Value).Any()
643+
&& !slotsHolder.IsHolding(TypeOffset.tp_call))
644+
{
645+
TypeManager.InitializeSlot(ObjectReference, TypeOffset.tp_call, tp_call_delegate, slotsHolder);
646+
}
647+
}
560648
}
561649
}

src/runtime/classmanager.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ internal static Dictionary<ManagedType, InterDomainContext> RestoreRuntimeData(R
162162
Runtime.PyType_Modified(pair.Value.TypeReference);
163163
var context = contexts[pair.Value.pyHandle];
164164
pair.Value.Load(context);
165+
var slotsHolder = TypeManager.GetSlotsHolder(pyType);
166+
pair.Value.InitializeSlots(slotsHolder);
167+
Runtime.PyType_Modified(pair.Value.TypeReference);
165168
loadedObjs.Add(pair.Value, context);
166169
}
167170

src/runtime/interop.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,13 @@ internal static ThunkInfo GetThunk(MethodInfo method, string funcType = null)
242242
return ThunkInfo.Empty;
243243
}
244244
Delegate d = Delegate.CreateDelegate(dt, method);
245-
var info = new ThunkInfo(d);
246-
allocatedThunks[info.Address] = d;
245+
return GetThunk(d);
246+
}
247+
248+
internal static ThunkInfo GetThunk(Delegate @delegate)
249+
{
250+
var info = new ThunkInfo(@delegate);
251+
allocatedThunks[info.Address] = @delegate;
247252
return info;
248253
}
249254

src/runtime/pytype.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,20 @@ internal static BorrowedReference GetBase(BorrowedReference type)
121121
return new BorrowedReference(basePtr);
122122
}
123123

124+
internal static BorrowedReference GetBases(BorrowedReference type)
125+
{
126+
Debug.Assert(IsType(type));
127+
IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_bases);
128+
return new BorrowedReference(basesPtr);
129+
}
130+
131+
internal static BorrowedReference GetMRO(BorrowedReference type)
132+
{
133+
Debug.Assert(IsType(type));
134+
IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_mro);
135+
return new BorrowedReference(basesPtr);
136+
}
137+
124138
private static IntPtr EnsureIsType(in StolenReference reference)
125139
{
126140
IntPtr address = reference.DangerousGetAddressOrNull();

src/runtime/typemanager.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,10 @@ static void InitializeClass(PyType pyType, ClassBase impl, Type clrType)
404404
impl.tpHandle = type;
405405
impl.pyHandle = type;
406406

407+
impl.InitializeSlots(slotsHolder);
408+
409+
Runtime.PyType_Modified(pyType.Reference);
410+
407411
//DebugUtil.DumpType(type);
408412
}
409413

@@ -787,6 +791,12 @@ static void InitializeSlot(IntPtr type, int slotOffset, MethodInfo method, Slots
787791
InitializeSlot(type, slotOffset, thunk, slotsHolder);
788792
}
789793

794+
internal static void InitializeSlot(BorrowedReference type, int slotOffset, Delegate impl, SlotsHolder slotsHolder)
795+
{
796+
var thunk = Interop.GetThunk(impl);
797+
InitializeSlot(type.DangerousGetAddress(), slotOffset, thunk, slotsHolder);
798+
}
799+
790800
static void InitializeSlot(IntPtr type, int slotOffset, ThunkInfo thunk, SlotsHolder slotsHolder)
791801
{
792802
Marshal.WriteIntPtr(type, slotOffset, thunk.Address);
@@ -848,6 +858,9 @@ private static SlotsHolder CreateSolotsHolder(IntPtr type)
848858
_slotsHolders.Add(type, holder);
849859
return holder;
850860
}
861+
862+
internal static SlotsHolder GetSlotsHolder(PyType type)
863+
=> _slotsHolders[type.Handle];
851864
}
852865

853866

@@ -873,6 +886,8 @@ public SlotsHolder(IntPtr type)
873886
_type = type;
874887
}
875888

889+
public bool IsHolding(int offset) => _slots.ContainsKey(offset);
890+
876891
public void Set(int offset, ThunkInfo thunk)
877892
{
878893
_slots[offset] = thunk;

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy