diff --git a/CHANGELOG.md b/CHANGELOG.md index 35ef66882..3599c619b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ See [Mixins/collections.py](src/runtime/Mixins/collections.py). - .NET arrays implement Python buffer protocol - Python.NET will correctly resolve .NET methods, that accept `PyList`, `PyInt`, and other `PyObject` derived types when called from Python. +- .NET classes, that have `__call__` method are callable from Python - `PyIterable` type, that wraps any iterable object in Python diff --git a/src/embed_tests/CallableObject.cs b/src/embed_tests/CallableObject.cs new file mode 100644 index 000000000..ab732be15 --- /dev/null +++ b/src/embed_tests/CallableObject.cs @@ -0,0 +1,87 @@ +using System; +using System.Collections.Generic; + +using NUnit.Framework; + +using Python.Runtime; + +namespace Python.EmbeddingTest +{ + public class CallableObject + { + [OneTimeSetUp] + public void SetUp() + { + PythonEngine.Initialize(); + using var locals = new PyDict(); + PythonEngine.Exec(CallViaInheritance.BaseClassSource, locals: locals.Handle); + CustomBaseTypeProvider.BaseClass = new PyType(locals[CallViaInheritance.BaseClassName]); + PythonEngine.InteropConfiguration.PythonBaseTypeProviders.Add(new CustomBaseTypeProvider()); + } + + [OneTimeTearDown] + public void Dispose() + { + PythonEngine.Shutdown(); + } + [Test] + public void CallMethodMakesObjectCallable() + { + var doubler = new DerivedDoubler(); + dynamic applyObjectTo21 = PythonEngine.Eval("lambda o: o(21)"); + Assert.AreEqual(doubler.__call__(21), (int)applyObjectTo21(doubler.ToPython())); + } + [Test] + public void CallMethodCanBeInheritedFromPython() + { + var callViaInheritance = new CallViaInheritance(); + dynamic applyObjectTo14 = PythonEngine.Eval("lambda o: o(14)"); + Assert.AreEqual(callViaInheritance.Call(14), (int)applyObjectTo14(callViaInheritance.ToPython())); + } + + [Test] + public void CanOverwriteCall() + { + var callViaInheritance = new CallViaInheritance(); + using var scope = Py.CreateScope(); + scope.Set("o", callViaInheritance); + scope.Exec("orig_call = o.Call"); + scope.Exec("o.Call = lambda a: orig_call(a*7)"); + int result = scope.Eval("o.Call(5)"); + Assert.AreEqual(105, result); + } + + class Doubler + { + public int __call__(int arg) => 2 * arg; + } + + class DerivedDoubler : Doubler { } + + class CallViaInheritance + { + public const string BaseClassName = "Forwarder"; + public static readonly string BaseClassSource = $@" +class MyCallableBase: + def __call__(self, val): + return self.Call(val) + +class {BaseClassName}(MyCallableBase): pass +"; + public int Call(int arg) => 3 * arg; + } + + class CustomBaseTypeProvider : IPythonBaseTypeProvider + { + internal static PyType BaseClass; + + public IEnumerable GetBaseTypes(Type type, IList existingBases) + { + Assert.Greater(BaseClass.Refcount, 0); + return type != typeof(CallViaInheritance) + ? existingBases + : new[] { BaseClass }; + } + } + } +} diff --git a/src/runtime/classbase.cs b/src/runtime/classbase.cs index 570ce3062..311b5b5f3 100644 --- a/src/runtime/classbase.cs +++ b/src/runtime/classbase.cs @@ -1,6 +1,9 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection; using System.Runtime.InteropServices; namespace Python.Runtime @@ -557,5 +560,44 @@ public static int mp_ass_subscript(IntPtr ob, IntPtr idx, IntPtr v) return 0; } + + static IntPtr tp_call_impl(IntPtr ob, IntPtr args, IntPtr kw) + { + IntPtr tp = Runtime.PyObject_TYPE(ob); + var self = (ClassBase)GetManagedObject(tp); + + if (!self.type.Valid) + { + return Exceptions.RaiseTypeError(self.type.DeletedMessage); + } + + Type type = self.type.Value; + + var calls = GetCallImplementations(type).ToList(); + Debug.Assert(calls.Count > 0); + var callBinder = new MethodBinder(); + foreach (MethodInfo call in calls) + { + callBinder.AddMethod(call); + } + return callBinder.Invoke(ob, args, kw); + } + + static IEnumerable GetCallImplementations(Type type) + => type.GetMethods(BindingFlags.Public | BindingFlags.Instance) + .Where(m => m.Name == "__call__"); + + static readonly Interop.TernaryFunc tp_call_delegate = tp_call_impl; + + public virtual void InitializeSlots(SlotsHolder slotsHolder) + { + if (!this.type.Valid) return; + + if (GetCallImplementations(this.type.Value).Any() + && !slotsHolder.IsHolding(TypeOffset.tp_call)) + { + TypeManager.InitializeSlot(ObjectReference, TypeOffset.tp_call, tp_call_delegate, slotsHolder); + } + } } } diff --git a/src/runtime/classmanager.cs b/src/runtime/classmanager.cs index 589ac0ad1..06d82c7b8 100644 --- a/src/runtime/classmanager.cs +++ b/src/runtime/classmanager.cs @@ -162,6 +162,9 @@ internal static Dictionary RestoreRuntimeData(R Runtime.PyType_Modified(pair.Value.TypeReference); var context = contexts[pair.Value.pyHandle]; pair.Value.Load(context); + var slotsHolder = TypeManager.GetSlotsHolder(pyType); + pair.Value.InitializeSlots(slotsHolder); + Runtime.PyType_Modified(pair.Value.TypeReference); loadedObjs.Add(pair.Value, context); } diff --git a/src/runtime/interop.cs b/src/runtime/interop.cs index 188db3a58..e10348e39 100644 --- a/src/runtime/interop.cs +++ b/src/runtime/interop.cs @@ -242,8 +242,13 @@ internal static ThunkInfo GetThunk(MethodInfo method, string funcType = null) return ThunkInfo.Empty; } Delegate d = Delegate.CreateDelegate(dt, method); - var info = new ThunkInfo(d); - allocatedThunks[info.Address] = d; + return GetThunk(d); + } + + internal static ThunkInfo GetThunk(Delegate @delegate) + { + var info = new ThunkInfo(@delegate); + allocatedThunks[info.Address] = @delegate; return info; } diff --git a/src/runtime/pytype.cs b/src/runtime/pytype.cs index 52ef60d04..546a3ed05 100644 --- a/src/runtime/pytype.cs +++ b/src/runtime/pytype.cs @@ -121,6 +121,20 @@ internal static BorrowedReference GetBase(BorrowedReference type) return new BorrowedReference(basePtr); } + internal static BorrowedReference GetBases(BorrowedReference type) + { + Debug.Assert(IsType(type)); + IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_bases); + return new BorrowedReference(basesPtr); + } + + internal static BorrowedReference GetMRO(BorrowedReference type) + { + Debug.Assert(IsType(type)); + IntPtr basesPtr = Marshal.ReadIntPtr(type.DangerousGetAddress(), TypeOffset.tp_mro); + return new BorrowedReference(basesPtr); + } + private static IntPtr EnsureIsType(in StolenReference reference) { IntPtr address = reference.DangerousGetAddressOrNull(); diff --git a/src/runtime/typemanager.cs b/src/runtime/typemanager.cs index 1d6321791..7a836bf05 100644 --- a/src/runtime/typemanager.cs +++ b/src/runtime/typemanager.cs @@ -404,6 +404,10 @@ static void InitializeClass(PyType pyType, ClassBase impl, Type clrType) impl.tpHandle = type; impl.pyHandle = type; + impl.InitializeSlots(slotsHolder); + + Runtime.PyType_Modified(pyType.Reference); + //DebugUtil.DumpType(type); } @@ -787,6 +791,12 @@ static void InitializeSlot(IntPtr type, int slotOffset, MethodInfo method, Slots InitializeSlot(type, slotOffset, thunk, slotsHolder); } + internal static void InitializeSlot(BorrowedReference type, int slotOffset, Delegate impl, SlotsHolder slotsHolder) + { + var thunk = Interop.GetThunk(impl); + InitializeSlot(type.DangerousGetAddress(), slotOffset, thunk, slotsHolder); + } + static void InitializeSlot(IntPtr type, int slotOffset, ThunkInfo thunk, SlotsHolder slotsHolder) { Marshal.WriteIntPtr(type, slotOffset, thunk.Address); @@ -848,6 +858,9 @@ private static SlotsHolder CreateSolotsHolder(IntPtr type) _slotsHolders.Add(type, holder); return holder; } + + internal static SlotsHolder GetSlotsHolder(PyType type) + => _slotsHolders[type.Handle]; } @@ -873,6 +886,8 @@ public SlotsHolder(IntPtr type) _type = type; } + public bool IsHolding(int offset) => _slots.ContainsKey(offset); + public void Set(int offset, ThunkInfo thunk) { _slots[offset] = thunk; pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy