diff --git a/src/runtime/StateSerialization/RuntimeData.cs b/src/runtime/StateSerialization/RuntimeData.cs index 204e15b5b..0fa484ada 100644 --- a/src/runtime/StateSerialization/RuntimeData.cs +++ b/src/runtime/StateSerialization/RuntimeData.cs @@ -1,20 +1,407 @@ using System; -using System.Collections; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.Diagnostics; using System.IO; using System.Linq; +using System.Reflection; +using System.Reflection.Emit; using System.Runtime.InteropServices; using System.Runtime.Serialization; using System.Runtime.Serialization.Formatters.Binary; - using Python.Runtime.StateSerialization; using static Python.Runtime.Runtime; namespace Python.Runtime { + [System.Serializable] + public sealed class NotSerializedException: SerializationException + { + static string _message = "The underlying C# object has been deleted."; + public NotSerializedException() : base(_message){} + private NotSerializedException(SerializationInfo info, StreamingContext context) : base(info, context){} + override public void GetObjectData(SerializationInfo info, StreamingContext context) => base.GetObjectData(info, context); + } + + [Serializable] + internal static class NonSerializedTypeBuilder + { + + internal static AssemblyName nonSerializedAssemblyName = + new AssemblyName("Python.Runtime.NonSerialized.dll, Version=0.0.0.0, Culture=neutral, PublicKeyToken=null"); + internal static AssemblyBuilder assemblyForNonSerializedClasses = + AppDomain.CurrentDomain.DefineDynamicAssembly(nonSerializedAssemblyName, AssemblyBuilderAccess.Run); + internal static ModuleBuilder moduleBuilder = assemblyForNonSerializedClasses.DefineDynamicModule("NotSerializedModule"); + internal static HashSet dontReimplementMethods = new(){"Finalize", "Dispose", "GetType", "ReferenceEquals", "GetHashCode", "Equals"}; + internal const string notSerializedSuffix = "_NotSerialized"; + // dummy field name to mark classes created by the "non-serializer" so we don't loop-inherit + // on multiple cycles of de/serialization. We use a static field instead of an attribute + // becaues of a bug in mono. Put a space in the name so users will be extremely unlikely + // to create a field with the same name. + internal const string notSerializedFieldName = "__PyNet NonSerialized"; + + private static Func hasVisibility = (tp, attr) => (tp.Attributes & TypeAttributes.VisibilityMask) == attr; + private static Func isNestedType = (tp) => hasVisibility(tp, TypeAttributes.NestedPrivate) || hasVisibility(tp, TypeAttributes.NestedPublic) || hasVisibility(tp, TypeAttributes.NestedFamily) || hasVisibility(tp, TypeAttributes.NestedAssembly); + private static Func isPrivateType = (tp) => hasVisibility(tp, TypeAttributes.NotPublic) || hasVisibility(tp, TypeAttributes.NestedPrivate) || hasVisibility(tp, TypeAttributes.NestedFamily) || hasVisibility( tp, TypeAttributes.NestedAssembly); + private static Func isPublicType = (tp) => hasVisibility(tp, TypeAttributes.Public) || hasVisibility(tp,TypeAttributes.NestedPublic); + private static Func CanCreateType = (tp) => isPublicType(tp) && ((tp.Attributes & TypeAttributes.Sealed) == 0); + + public static object? CreateNewObject(Type baseType) + { + var myType = CreateType(baseType); + if (myType is null) + { + return null; + } + var myObject = Activator.CreateInstance(myType); + return myObject; + } + + static void FillTypeMethods(TypeBuilder tb) + { + var constructors = tb.BaseType.GetConstructors(); + if (constructors.Count() == 0) + { + // no constructors defined, at least declare a default + ConstructorBuilder constructor = tb.DefineDefaultConstructor(MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName); + } + else + { + foreach (var ctor in constructors) + { + var ctorParams = (from param in ctor.GetParameters() select param.ParameterType).ToArray(); + var ctorbuilder = tb.DefineConstructor(ctor.Attributes, ctor.CallingConvention, ctorParams); + ctorbuilder.GetILGenerator().Emit(OpCodes.Ret); + + } + var parameterless = tb.DefineConstructor(MethodAttributes.Public | MethodAttributes.SpecialName | MethodAttributes.RTSpecialName, CallingConventions.Standard | CallingConventions.HasThis, Type.EmptyTypes); + parameterless.GetILGenerator().Emit(OpCodes.Ret); + } + + var properties = tb.BaseType.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static); + foreach (var prop in properties) + { + CreateProperty(tb, prop); + } + + var methods = tb.BaseType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance | BindingFlags.Static); + foreach (var meth in methods) + { + CreateMethod(tb, meth); + } + + ImplementEqualityAndHash(tb); + } + + static string MakeName(Type tp) + { + string @out = tp.Name + notSerializedSuffix; + var parentType = tp.DeclaringType; + while (parentType is not null) + { + // If we have a nested class, we need the whole nester/nestee + // chain with the suffix for each. + @out = parentType.Name + notSerializedSuffix + "+" + @out; + parentType = parentType.DeclaringType; + } + return @out; + } + + public static Type? CreateType(Type tp) + { + if (!CanCreateType(tp)) + { + return null; + } + + Type existingType = assemblyForNonSerializedClasses.GetType(MakeName(tp), throwOnError:false); + if (existingType is not null) + { + return existingType; + } + var parentType = tp.DeclaringType; + if (parentType is not null) + { + // parent types for nested types must be created first. Climb up the + // declaring type chain until we find a "top-level" class. + while (parentType.DeclaringType is not null) + { + parentType = parentType.DeclaringType; + } + CreateTypeInternal(parentType); + Type nestedType = assemblyForNonSerializedClasses.GetType(MakeName(tp), throwOnError:true); + return nestedType; + } + return CreateTypeInternal(tp); + } + + private static Type? CreateTypeInternal(Type baseType) + { + if (!isPublicType(baseType)) + { + // we can't derive from non-public types. + return null; + } + Type existingType = assemblyForNonSerializedClasses.GetType(MakeName(baseType), throwOnError:false); + if (existingType is not null) + { + return existingType; + } + + TypeBuilder tb = GetTypeBuilder(baseType); + SetNonSerializedAttr(tb); + FillTypeMethods(tb); + + var nestedtypes = baseType.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static); + List nestedBuilders = new(); + foreach (var nested in nestedtypes) + { + if (isPrivateType(nested)) + { + continue; + } + var nestedBuilder = tb.DefineNestedType(nested.Name + notSerializedSuffix, + TypeAttributes.NestedPublic, + nested + ); + nestedBuilders.Add(nestedBuilder); + } + var outTp = tb.CreateType(); + foreach(var builder in nestedBuilders) + { + FillTypeMethods(builder); + SetNonSerializedAttr(builder); + builder.CreateType(); + } + return outTp; + } + + private static void ImplementEqualityAndHash(TypeBuilder tb) + { + var hashCodeMb = tb.DefineMethod("GetHashCode", + MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.ReuseSlot, + CallingConventions.Standard, + typeof(int), + Type.EmptyTypes + ); + var getHashIlGen = hashCodeMb.GetILGenerator(); + getHashIlGen.Emit(OpCodes.Ldarg_0); + getHashIlGen.EmitCall(OpCodes.Call, typeof(object).GetMethod("GetType"), Type.EmptyTypes); + getHashIlGen.EmitCall(OpCodes.Call, typeof(Type).GetProperty("Name").GetMethod, Type.EmptyTypes); + getHashIlGen.EmitCall(OpCodes.Call, typeof(string).GetMethod("GetHashCode", Type.EmptyTypes), Type.EmptyTypes); + getHashIlGen.Emit(OpCodes.Ret); + + Type[] equalsArgs = new Type[] {typeof(object), typeof(object)}; + var equalsMb = tb.DefineMethod("Equals", + MethodAttributes.Public | MethodAttributes.Final | MethodAttributes.ReuseSlot, + CallingConventions.Standard, + typeof(bool), + equalsArgs + ); + var equalsIlGen = equalsMb.GetILGenerator(); + equalsIlGen.Emit(OpCodes.Ldarg_0); // this + equalsIlGen.Emit(OpCodes.Ldarg_1); // the other object + equalsIlGen.EmitCall(OpCodes.Call, typeof(object).GetMethod("ReferenceEquals"), equalsArgs); + equalsIlGen.Emit(OpCodes.Ret); + } + + private static void SetNonSerializedAttr(TypeBuilder tb) + { + // Name of the function says we're adding an attribute, but for some + // reason on Mono the attribute is not added, and no exceptions are + // thrown. + tb.DefineField(notSerializedFieldName, typeof(int), FieldAttributes.Public | FieldAttributes.Static); + } + + public static bool IsNonSerializedType(Type tp) + { + return tp.GetField(NonSerializedTypeBuilder.notSerializedFieldName, BindingFlags.Public | BindingFlags.Static) is not null; + + } + + private static TypeBuilder GetTypeBuilder(Type baseType) + { + string typeSignature = baseType.Name + notSerializedSuffix; + TypeBuilder tb = moduleBuilder.DefineType(typeSignature, + baseType.Attributes, + baseType, + baseType.GetInterfaces()); + return tb; + } + + static ILGenerator GenerateExceptionILCode(dynamic builder) + { + ILGenerator ilgen = builder.GetILGenerator(); + var seriExc = typeof(NotSerializedException); + var exCtorInfo = seriExc.GetConstructor(new Type[]{}); + ilgen.Emit(OpCodes.Newobj, exCtorInfo); + ilgen.ThrowException(seriExc); + return ilgen; + } + + private static MethodAttributes GetMethodAttrs (MethodInfo minfo) + { + var methAttributes = minfo.Attributes; + // Always implement/shadow the method + methAttributes &=(~MethodAttributes.Abstract); + methAttributes &=(~MethodAttributes.NewSlot); + methAttributes |= MethodAttributes.ReuseSlot; + methAttributes |= MethodAttributes.HideBySig; + methAttributes |= MethodAttributes.Final; + + if (minfo.IsFinal) + { + // can't override a final method, new it instead. + methAttributes &= (~MethodAttributes.Virtual); + methAttributes |= MethodAttributes.NewSlot; + } + + return methAttributes; + } + + private static void CreateProperty(TypeBuilder tb, PropertyInfo pinfo) + { + string propertyName = pinfo.Name; + Type propertyType = pinfo.PropertyType; + FieldBuilder fieldBuilder = tb.DefineField("_" + propertyName, propertyType, FieldAttributes.Private); + PropertyBuilder propertyBuilder = tb.DefineProperty(propertyName, pinfo.Attributes, propertyType, null); + if (pinfo.GetMethod is not null) + { + var methAttributes = GetMethodAttrs(pinfo.GetMethod); + + MethodBuilder getPropMthdBldr = + tb.DefineMethod("get_" + propertyName, + methAttributes, + propertyType, + Type.EmptyTypes); + GenerateExceptionILCode(getPropMthdBldr); + propertyBuilder.SetGetMethod(getPropMthdBldr); + } + if (pinfo.SetMethod is not null) + { + var methAttributes = GetMethodAttrs(pinfo.SetMethod); + MethodBuilder setPropMthdBldr = + tb.DefineMethod("set_" + propertyName, + methAttributes, + null, + new[] { propertyType }); + + GenerateExceptionILCode(setPropMthdBldr); + propertyBuilder.SetSetMethod(setPropMthdBldr); + } + } + + private static void CreateMethod(TypeBuilder tb, MethodInfo minfo) + { + string methodName = minfo.Name; + + if (dontReimplementMethods.Contains(methodName)) + { + // Some methods must *not* be reimplemented (who wants to throw from Dispose?) + // and some methods we need to implement in a more specific way (Equals, GetHashCode) + return; + } + var methAttributes = GetMethodAttrs(minfo); + var @params = (from paraminfo in minfo.GetParameters() select paraminfo.ParameterType).ToArray(); + MethodBuilder mbuilder = tb.DefineMethod(methodName, methAttributes, minfo.CallingConvention, minfo.ReturnType, @params); + GenerateExceptionILCode(mbuilder); + } + } + + class NotSerializableSerializer : ISerializationSurrogate + { + + public NotSerializableSerializer() + { + } + + public void GetObjectData(object obj, SerializationInfo info, StreamingContext context) + { + // This type is private to System.Runtime.Serialization. We get an + // object of this type when, amongst others, the type didn't exist (yet?) + // (dll not loaded, type was removed/renamed) when we previously + // deserialized the previous domain objects. Don't serialize this + // object. + if (obj.GetType().Name == "TypeLoadExceptionHolder") + { + obj = null!; + return; + } + + MaybeType type = obj.GetType(); + + if (NonSerializedTypeBuilder.IsNonSerializedType(type.Value)) + { + // Don't serialize a _NotSerialized. Serialize the base type, and deserialize as a _NotSerialized + type = type.Value.BaseType; + obj = null!; + } + + info.AddValue("notSerialized_tp", type); + + } + + public object SetObjectData(object obj, SerializationInfo info, StreamingContext context, ISurrogateSelector selector) + { + if (info is null) + { + // `obj` is of type TypeLoadExceptionHolder. This means the type + // we're trying to load doesn't exist anymore or we haven't created + // it yet, and the runtime doesn't even gives us the chance to + // recover from this as info is null. We may even get objects + // this serializer did not serialize in a previous domain, + // like in the case of the "namespace_rename" domain reload + // test: the object successfully serialized, but it cannot be + // deserialized. + // just return null. + return null!; + } + + object nameObj = null!; + try + { + nameObj = info.GetValue("notSerialized_tp", typeof(object)); + } + catch + { + // we didn't find the expected information. We don't know + // what to do with this; return null. + return null!; + } + Debug.Assert(nameObj.GetType() == typeof(MaybeType)); + MaybeType name = (MaybeType)nameObj; + Debug.Assert(name.Valid); + if (!name.Valid) + { + // The type couldn't be loaded + return null!; + } + + obj = NonSerializedTypeBuilder.CreateNewObject(name.Value)!; + return obj; + } + } + + class NonSerializableSelector : SurrogateSelector + { + public override ISerializationSurrogate? GetSurrogate (Type type, StreamingContext context, out ISurrogateSelector selector) + { + if (type is null) + { + throw new ArgumentNullException(); + } + selector = this; + if (type.IsSerializable) + { + return null; // use whichever default + } + else + { + return new NotSerializableSerializer(); + } + } + } + public static class RuntimeData { private static Type? _formatterType; @@ -47,6 +434,73 @@ static void ClearCLRData () } } + internal static void SerializeNonSerializableTypes () + { + // Serialize the Types (Type objects) that couldn't be (de)serialized. + // This needs to be done otherwise at deserialization time we get + // TypeLoadExceptionHolder objects and we can't even recover. + + // We don't serialize the "_NotSerialized" Types, we serialize their base + // to recreate the "_NotSerialized" versions on the next domain load. + + Dictionary invalidTypes = new(); + foreach(var tp in NonSerializedTypeBuilder.assemblyForNonSerializedClasses.GetTypes()) + { + invalidTypes[tp.FullName] = new MaybeType(tp.BaseType); + } + + // delete previous data if any + BorrowedReference oldCapsule = PySys_GetObject("clr_nonSerializedTypes"); + if (!oldCapsule.IsNull) + { + IntPtr oldData = PyCapsule_GetPointer(oldCapsule, IntPtr.Zero); + PyMem_Free(oldData); + PyCapsule_SetPointer(oldCapsule, IntPtr.Zero); + } + IFormatter formatter = CreateFormatter(); + var ms = new MemoryStream(); + formatter.Serialize(ms, invalidTypes); + + Debug.Assert(ms.Length <= int.MaxValue); + byte[] data = ms.GetBuffer(); + + IntPtr mem = PyMem_Malloc(ms.Length + IntPtr.Size); + Marshal.WriteIntPtr(mem, (IntPtr)ms.Length); + Marshal.Copy(data, 0, mem + IntPtr.Size, (int)ms.Length); + + using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); + int res = PySys_SetObject("clr_nonSerializedTypes", capsule.BorrowOrThrow()); + PythonException.ThrowIfIsNotZero(res); + + } + + internal static void DeserializeNonSerializableTypes () + { + BorrowedReference capsule = PySys_GetObject("clr_nonSerializedTypes"); + if (capsule.IsNull) + { + // nothing to do. + return; + } + // get the memory stream from the capsule. + IntPtr mem = PyCapsule_GetPointer(capsule, IntPtr.Zero); + int length = (int)Marshal.ReadIntPtr(mem); + byte[] data = new byte[length]; + Marshal.Copy(mem + IntPtr.Size, data, 0, length); + var ms = new MemoryStream(data); + var formatter = CreateFormatter(); + var storage = (Dictionary)formatter.Deserialize(ms); + foreach(var item in storage) + { + if(item.Value.Valid) + { + // recreate the "_NotSerialized" Types + NonSerializedTypeBuilder.CreateType(item.Value.Value); + } + } + + } + internal static void Stash() { var runtimeStorage = new PythonNetState @@ -74,6 +528,8 @@ internal static void Stash() using NewReference capsule = PyCapsule_New(mem, IntPtr.Zero, IntPtr.Zero); int res = PySys_SetObject("clr_data", capsule.BorrowOrThrow()); PythonException.ThrowIfIsNotZero(res); + SerializeNonSerializableTypes(); + } internal static void RestoreRuntimeData() @@ -90,6 +546,9 @@ internal static void RestoreRuntimeData() private static void RestoreRuntimeDataImpl() { + // The "_NotSerialized" Types must exist before the rest of the data + // is deserialized. + DeserializeNonSerializableTypes(); BorrowedReference capsule = PySys_GetObject("clr_data"); if (capsule.IsNull) { @@ -123,19 +582,6 @@ public static void ClearStash() PySys_SetObject("clr_data", default); } - static bool CheckSerializable (object o) - { - Type type = o.GetType(); - do - { - if (!type.IsSerializable) - { - return false; - } - } while ((type = type.BaseType) != null); - return true; - } - private static SharedObjectsState SaveRuntimeDataObjects() { var contexts = new Dictionary>(PythonReferenceComparer.Instance); @@ -150,7 +596,6 @@ private static SharedObjectsState SaveRuntimeDataObjects() foreach (var pyObj in extensions) { var extension = (ExtensionType)ManagedType.GetManagedObject(pyObj)!; - Debug.Assert(CheckSerializable(extension)); var context = extension.Save(pyObj); if (context is not null) { @@ -170,6 +615,7 @@ private static SharedObjectsState SaveRuntimeDataObjects() .ToList(); foreach (var pyObj in reflectedObjects) { + // Console.WriteLine($"saving object: {pyObj} {pyObj.rawPtr} "); // Wrapper must be the CLRObject var clrObj = (CLRObject)ManagedType.GetManagedObject(pyObj)!; object inst = clrObj.inst; @@ -199,10 +645,6 @@ private static SharedObjectsState SaveRuntimeDataObjects() { if (!item.Stored) { - if (!CheckSerializable(item.Instance)) - { - continue; - } var clrO = wrappers[item.Instance].First(); foreach (var @ref in item.PyRefs) { @@ -254,7 +696,10 @@ internal static IFormatter CreateFormatter() { return FormatterType != null ? (IFormatter)Activator.CreateInstance(FormatterType) - : new BinaryFormatter(); + : new BinaryFormatter() + { + SurrogateSelector = new NonSerializableSelector(), + }; } } } diff --git a/tests/domain_tests/TestRunner.cs b/tests/domain_tests/TestRunner.cs index 4f6a3ea28..0097ea202 100644 --- a/tests/domain_tests/TestRunner.cs +++ b/tests/domain_tests/TestRunner.cs @@ -1132,6 +1132,168 @@ import System ", }, + new TestCase + { + Name = "serialize_not_serializable", + DotNetBefore = @" + namespace TestNamespace + { + + public class NotSerializableTextWriter : System.IO.TextWriter + { + override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} } + } + + [System.Serializable] + public static class SerializableWriter + { + private static System.IO.TextWriter _writer = null; + + public static System.IO.TextWriter Writer {get { return _writer; }} + + public static void SetWriter() + { + _writer = System.IO.TextWriter.Synchronized(new NotSerializableTextWriter()); + } + } + }", + DotNetAfter = @" + namespace TestNamespace + { + + public class NotSerializableTextWriter : System.IO.TextWriter + { + override public System.Text.Encoding Encoding { get { return System.Text.Encoding.ASCII;} } + } + + [System.Serializable] + public static class SerializableWriter + { + private static System.IO.TextWriter _writer = null; + + public static System.IO.TextWriter Writer {get { return _writer; }} + + public static void SetWriter(System.IO.TextWriter w) + { + _writer = System.IO.TextWriter.Synchronized(w); + } + } + }", + PythonCode = @" +import clr +import sys +clr.AddReference('DomainTests') +import TestNamespace +import System + +def before_reload(): + + TestNamespace.SerializableWriter.SetWriter(); + sys.log_writer = TestNamespace.SerializableWriter.Writer + +def after_reload(): + + assert sys.log_writer is not None + try: + encoding = sys.log_writer.Write('baba') + except System.Runtime.Serialization.SerializationException: + pass + else: + raise AssertionError('Serialized non-serializable objects should be deserialized to throwing objects') +", + }, + new TestCase + { + Name = "serialize_not_serializable_interface", + DotNetBefore = @" + namespace TestNamespace + { + public interface MyInterface + { + int InterfaceMethod(); + } + + public class NotSerializableInterfaceImplement : MyInterface + { + public int value = -1; + int MyInterface.InterfaceMethod() + { + return value; + } + } + + [System.Serializable] + public static class SerializableWriter + { + private static MyInterface _iface = null; + + public static MyInterface Writer {get { return _iface; }} + + public static void SetInterface() + { + var temp = new NotSerializableInterfaceImplement(); + temp.value = 12315; + _iface = temp; + } + } + }", + DotNetAfter = @" + namespace TestNamespace + { + public interface MyInterface + { + int InterfaceMethod(); + } + + public class NotSerializableInterfaceImplement : MyInterface + { + public int value = -1; + int MyInterface.InterfaceMethod() + { + return value; + } + } + + [System.Serializable] + public static class SerializableWriter + { + private static MyInterface _iface = null; + + public static MyInterface Writer {get { return _iface; }} + + public static void SetInterface() + { + var temp = new NotSerializableInterfaceImplement(); + temp.value = 123124; + _iface = temp; + } + } + }", + PythonCode = @" +import clr +import sys +clr.AddReference('DomainTests') +import TestNamespace +import System + +def before_reload(): + + TestNamespace.SerializableWriter.SetInterface(); + sys.log_writer = TestNamespace.SerializableWriter.Writer + assert(sys.log_writer.InterfaceMethod() == 12315) + +def after_reload(): + + assert sys.log_writer is not None + try: + retcode = sys.log_writer.InterfaceMethod() + print(f'retcode of InterfaceMethod is {retcode}') + except System.Runtime.Serialization.SerializationException: + pass + else: + raise AssertionError('Serialized non-serializable objects should be deserialized to throwing objects') +", + }, }; /// diff --git a/tests/domain_tests/test_domain_reload.py b/tests/domain_tests/test_domain_reload.py index 8999e481b..bbff73493 100644 --- a/tests/domain_tests/test_domain_reload.py +++ b/tests/domain_tests/test_domain_reload.py @@ -1,6 +1,7 @@ import subprocess import os import platform +from unittest import skip import pytest @@ -88,3 +89,10 @@ def test_nested_type(): def test_import_after_reload(): _run_test("import_after_reload") + +def test_serialize_not_serializable(): + _run_test("serialize_not_serializable") + +@skip("interface methods cannot be overriden") +def test_serialize_not_serializable_interface(): + _run_test("serialize_not_serializable_interface") \ No newline at end of file pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy