diff --git a/src/embed_tests/TestOperator.cs b/src/embed_tests/TestOperator.cs index ecdb0c1dc..8e9feb241 100644 --- a/src/embed_tests/TestOperator.cs +++ b/src/embed_tests/TestOperator.cs @@ -25,6 +25,17 @@ public class OperableObject { public int Num { get; set; } + public override int GetHashCode() + { + return unchecked(159832395 + Num.GetHashCode()); + } + + public override bool Equals(object obj) + { + return obj is OperableObject @object && + Num == @object.Num; + } + public OperableObject(int num) { Num = num; @@ -149,6 +160,103 @@ public OperableObject(int num) return new OperableObject(a.Num ^ b); } + public static bool operator ==(int a, OperableObject b) + { + return (a == b.Num); + } + public static bool operator ==(OperableObject a, OperableObject b) + { + return (a.Num == b.Num); + } + public static bool operator ==(OperableObject a, int b) + { + return (a.Num == b); + } + + public static bool operator !=(int a, OperableObject b) + { + return (a != b.Num); + } + public static bool operator !=(OperableObject a, OperableObject b) + { + return (a.Num != b.Num); + } + public static bool operator !=(OperableObject a, int b) + { + return (a.Num != b); + } + + public static bool operator <=(int a, OperableObject b) + { + return (a <= b.Num); + } + public static bool operator <=(OperableObject a, OperableObject b) + { + return (a.Num <= b.Num); + } + public static bool operator <=(OperableObject a, int b) + { + return (a.Num <= b); + } + + public static bool operator >=(int a, OperableObject b) + { + return (a >= b.Num); + } + public static bool operator >=(OperableObject a, OperableObject b) + { + return (a.Num >= b.Num); + } + public static bool operator >=(OperableObject a, int b) + { + return (a.Num >= b); + } + + public static bool operator >=(OperableObject a, PyObject b) + { + using (Py.GIL()) + { + // Assuming b is a tuple, take the first element. + int bNum = b[0].As(); + return a.Num >= bNum; + } + } + public static bool operator <=(OperableObject a, PyObject b) + { + using (Py.GIL()) + { + // Assuming b is a tuple, take the first element. + int bNum = b[0].As(); + return a.Num <= bNum; + } + } + + public static bool operator <(int a, OperableObject b) + { + return (a < b.Num); + } + public static bool operator <(OperableObject a, OperableObject b) + { + return (a.Num < b.Num); + } + public static bool operator <(OperableObject a, int b) + { + return (a.Num < b); + } + + public static bool operator >(int a, OperableObject b) + { + return (a > b.Num); + } + public static bool operator >(OperableObject a, OperableObject b) + { + return (a.Num > b.Num); + } + public static bool operator >(OperableObject a, int b) + { + return (a.Num > b); + } + public static OperableObject operator <<(OperableObject a, int offset) { return new OperableObject(a.Num << offset); @@ -161,7 +269,7 @@ public OperableObject(int num) } [Test] - public void OperatorOverloads() + public void SymmetricalOperatorOverloads() { string name = string.Format("{0}.{1}", typeof(OperableObject).DeclaringType.Name, @@ -206,6 +314,24 @@ public void OperatorOverloads() c = a ^ b assert c.Num == a.Num ^ b.Num + +c = a == b +assert c == (a.Num == b.Num) + +c = a != b +assert c == (a.Num != b.Num) + +c = a <= b +assert c == (a.Num <= b.Num) + +c = a >= b +assert c == (a.Num >= b.Num) + +c = a < b +assert c == (a.Num < b.Num) + +c = a > b +assert c == (a.Num > b.Num) "); } @@ -263,6 +389,51 @@ public void ForwardOperatorOverloads() c = a ^ b assert c.Num == a.Num ^ b + +c = a == b +assert c == (a.Num == b) + +c = a != b +assert c == (a.Num != b) + +c = a <= b +assert c == (a.Num <= b) + +c = a >= b +assert c == (a.Num >= b) + +c = a < b +assert c == (a.Num < b) + +c = a > b +assert c == (a.Num > b) +"); + } + + [Test] + public void TupleComparisonOperatorOverloads() + { + string name = string.Format("{0}.{1}", + typeof(OperableObject).DeclaringType.Name, + typeof(OperableObject).Name); + string module = MethodBase.GetCurrentMethod().DeclaringType.Namespace; + PythonEngine.Exec($@" +from {module} import * +cls = {name} +a = cls(2) +b = (1, 2) + +c = a >= b +assert c == (a.Num >= b[0]) + +c = a <= b +assert c == (a.Num <= b[0]) + +c = b >= a +assert c == (b[0] >= a.Num) + +c = b <= a +assert c == (b[0] <= a.Num) "); } @@ -304,6 +475,24 @@ public void ReverseOperatorOverloads() c = a ^ b assert c.Num == a ^ b.Num + +c = a == b +assert c == (a == b.Num) + +c = a != b +assert c == (a != b.Num) + +c = a <= b +assert c == (a <= b.Num) + +c = a >= b +assert c == (a >= b.Num) + +c = a < b +assert c == (a < b.Num) + +c = a > b +assert c == (a > b.Num) "); } diff --git a/src/runtime/classbase.cs b/src/runtime/classbase.cs index 0ff4ba154..872501267 100644 --- a/src/runtime/classbase.cs +++ b/src/runtime/classbase.cs @@ -21,6 +21,7 @@ internal class ClassBase : ManagedType [NonSerialized] internal List dotNetMembers; internal Indexer indexer; + internal Dictionary richcompare; internal MaybeType type; internal ClassBase(Type tp) @@ -35,6 +36,15 @@ internal virtual bool CanSubclass() return !type.Value.IsEnum; } + public readonly static Dictionary CilToPyOpMap = new Dictionary + { + ["op_Equality"] = Runtime.Py_EQ, + ["op_Inequality"] = Runtime.Py_NE, + ["op_LessThanOrEqual"] = Runtime.Py_LE, + ["op_GreaterThanOrEqual"] = Runtime.Py_GE, + ["op_LessThan"] = Runtime.Py_LT, + ["op_GreaterThan"] = Runtime.Py_GT, + }; /// /// Default implementation of [] semantics for reflected types. @@ -72,6 +82,30 @@ public static IntPtr tp_richcompare(IntPtr ob, IntPtr other, int op) { CLRObject co1; CLRObject co2; + IntPtr tp = Runtime.PyObject_TYPE(ob); + var cls = (ClassBase)GetManagedObject(tp); + // C# operator methods take precedence over IComparable. + // We first check if there's a comparison operator by looking up the richcompare table, + // otherwise fallback to checking if an IComparable interface is handled. + if (cls.richcompare.TryGetValue(op, out var methodObject)) + { + // Wrap the `other` argument of a binary comparison operator in a PyTuple. + IntPtr args = Runtime.PyTuple_New(1); + Runtime.XIncref(other); + Runtime.PyTuple_SetItem(args, 0, other); + + IntPtr value; + try + { + value = methodObject.Invoke(ob, args, IntPtr.Zero); + } + finally + { + Runtime.XDecref(args); // Free args pytuple + } + return value; + } + switch (op) { case Runtime.Py_EQ: diff --git a/src/runtime/classmanager.cs b/src/runtime/classmanager.cs index 64c985ce7..0cbff371f 100644 --- a/src/runtime/classmanager.cs +++ b/src/runtime/classmanager.cs @@ -259,6 +259,7 @@ private static void InitClassBase(Type type, ClassBase impl) ClassInfo info = GetClassInfo(type); impl.indexer = info.indexer; + impl.richcompare = new Dictionary(); // Now we allocate the Python type object to reflect the given // managed type, filling the Python type slots with thunks that @@ -284,6 +285,9 @@ private static void InitClassBase(Type type, ClassBase impl) Runtime.PyDict_SetItemString(dict, name, item.pyHandle); // Decref the item now that it's been used. item.DecrRefCount(); + if (ClassBase.CilToPyOpMap.TryGetValue(name, out var pyOp)) { + impl.richcompare.Add(pyOp, (MethodObject)item); + } } // If class has constructors, generate an __doc__ attribute. @@ -553,8 +557,7 @@ private static ClassInfo GetClassInfo(Type type) { string pyName = OperatorMethod.GetPyMethodName(name); string pyNameReverse = OperatorMethod.ReversePyMethodName(pyName); - MethodInfo[] forwardMethods, reverseMethods; - OperatorMethod.FilterMethods(mlist, out forwardMethods, out reverseMethods); + OperatorMethod.FilterMethods(mlist, out var forwardMethods, out var reverseMethods); // Only methods where the left operand is the declaring type. if (forwardMethods.Length > 0) ci.members[pyName] = new MethodObject(type, name, forwardMethods); diff --git a/src/runtime/methodbinder.cs b/src/runtime/methodbinder.cs index ba37c19c1..5de0ecc00 100644 --- a/src/runtime/methodbinder.cs +++ b/src/runtime/methodbinder.cs @@ -354,16 +354,17 @@ internal Binding Bind(IntPtr inst, IntPtr args, IntPtr kw, MethodBase info, Meth int kwargsMatched; int defaultsNeeded; bool isOperator = OperatorMethod.IsOperatorMethod(mi); - int clrnargs = pi.Length; // Binary operator methods will have 2 CLR args but only one Python arg // (unary operators will have 1 less each), since Python operator methods are bound. - isOperator = isOperator && pynargs == clrnargs - 1; + isOperator = isOperator && pynargs == pi.Length - 1; + bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator. + if (isReverse && OperatorMethod.IsComparisonOp((MethodInfo)mi)) + continue; // Comparison operators in Python have no reverse mode. if (!MatchesArgumentCount(pynargs, pi, kwargDict, out paramsArray, out defaultArgList, out kwargsMatched, out defaultsNeeded) && !isOperator) { continue; } // Preprocessing pi to remove either the first or second argument. - bool isReverse = isOperator && OperatorMethod.IsReverse((MethodInfo)mi); // Only cast if isOperator. if (isOperator && !isReverse) { // The first Python arg is the right operand, while the bound instance is the left. // We need to skip the first (left operand) CLR argument. diff --git a/src/runtime/operatormethod.cs b/src/runtime/operatormethod.cs index 1e0244510..59bf944bc 100644 --- a/src/runtime/operatormethod.cs +++ b/src/runtime/operatormethod.cs @@ -15,6 +15,7 @@ internal static class OperatorMethod /// that identifies that operator's slot (e.g. nb_add) in heap space. /// public static Dictionary OpMethodMap { get; private set; } + public static Dictionary ComparisonOpMap { get; private set; } public readonly struct SlotDefinition { public SlotDefinition(string methodName, int typeOffset) @@ -24,6 +25,7 @@ public SlotDefinition(string methodName, int typeOffset) } public string MethodName { get; } public int TypeOffset { get; } + } private static PyObject _opType; @@ -49,6 +51,16 @@ static OperatorMethod() ["op_OnesComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert), ["op_UnaryNegation"] = new SlotDefinition("__neg__", TypeOffset.nb_negative), ["op_UnaryPlus"] = new SlotDefinition("__pos__", TypeOffset.nb_positive), + ["op_OneComplement"] = new SlotDefinition("__invert__", TypeOffset.nb_invert), + }; + ComparisonOpMap = new Dictionary + { + ["op_Equality"] = "__eq__", + ["op_Inequality"] = "__ne__", + ["op_LessThanOrEqual"] = "__le__", + ["op_GreaterThanOrEqual"] = "__ge__", + ["op_LessThan"] = "__lt__", + ["op_GreaterThan"] = "__gt__", }; } @@ -72,8 +84,14 @@ public static bool IsOperatorMethod(MethodBase method) { return false; } - return OpMethodMap.ContainsKey(method.Name); + return OpMethodMap.ContainsKey(method.Name) || ComparisonOpMap.ContainsKey(method.Name); + } + + public static bool IsComparisonOp(MethodInfo method) + { + return ComparisonOpMap.ContainsKey(method.Name); } + /// /// For the operator methods of a CLR type, set the special slots of the /// corresponding Python type's operator methods. @@ -86,7 +104,9 @@ public static void FixupSlots(IntPtr pyType, Type clrType) Debug.Assert(_opType != null); foreach (var method in clrType.GetMethods(flags)) { - if (!IsOperatorMethod(method)) + // We only want to override slots for operators excluding + // comparison operators, which are handled by ClassBase.tp_richcompare. + if (!OpMethodMap.ContainsKey(method.Name)) { continue; } @@ -99,13 +119,18 @@ public static void FixupSlots(IntPtr pyType, Type clrType) // when used with a Python operator. // https://tenthousandmeters.com/blog/python-behind-the-scenes-6-how-python-object-system-works/ Marshal.WriteIntPtr(pyType, offset, func); - } } public static string GetPyMethodName(string clrName) { - return OpMethodMap[clrName].MethodName; + if (OpMethodMap.ContainsKey(clrName)) + { + return OpMethodMap[clrName].MethodName; + } else + { + return ComparisonOpMap[clrName]; + } } private static string GenerateDummyCode() pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy