Skip to content

Commit 47a7a00

Browse files
authored
Rework issubclass (#5867)
* check_exact * check_class * Type compatibility tools * abstract_issubclass * recursive_issubclass
1 parent 7a6e5c4 commit 47a7a00

File tree

2 files changed

+123
-55
lines changed

2 files changed

+123
-55
lines changed

vm/src/builtins/type.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,18 @@ fn downcast_qualname(value: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyRef<
158158
}
159159
}
160160

161+
fn is_subtype_with_mro(a_mro: &[PyTypeRef], a: &Py<PyType>, b: &Py<PyType>) -> bool {
162+
if a.is(b) {
163+
return true;
164+
}
165+
for item in a_mro {
166+
if item.is(b) {
167+
return true;
168+
}
169+
}
170+
false
171+
}
172+
161173
impl PyType {
162174
pub fn new_simple_heap(
163175
name: &str,
@@ -197,6 +209,12 @@ impl PyType {
197209
Self::new_heap_inner(base, bases, attrs, slots, heaptype_ext, metaclass, ctx)
198210
}
199211

212+
/// Equivalent to CPython's PyType_Check macro
213+
/// Checks if obj is an instance of type (or its subclass)
214+
pub(crate) fn check(obj: &PyObject) -> Option<&Py<Self>> {
215+
obj.downcast_ref::<Self>()
216+
}
217+
200218
fn resolve_mro(bases: &[PyRef<Self>]) -> Result<Vec<PyTypeRef>, String> {
201219
// Check for duplicates in bases.
202220
let mut unique_bases = HashSet::new();
@@ -439,6 +457,16 @@ impl PyType {
439457
}
440458

441459
impl Py<PyType> {
460+
pub(crate) fn is_subtype(&self, other: &Py<PyType>) -> bool {
461+
is_subtype_with_mro(&self.mro.read(), self, other)
462+
}
463+
464+
/// Equivalent to CPython's PyType_CheckExact macro
465+
/// Checks if obj is exactly a type (not a subclass)
466+
pub fn check_exact<'a>(obj: &'a PyObject, vm: &VirtualMachine) -> Option<&'a Py<PyType>> {
467+
obj.downcast_ref_if_exact::<PyType>(vm)
468+
}
469+
442470
/// Determines if `subclass` is actually a subclass of `cls`, this doesn't call __subclasscheck__,
443471
/// so only use this if `cls` is known to have not overridden the base __subclasscheck__ magic
444472
/// method.

vm/src/protocol/object.rs

Lines changed: 95 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -371,80 +371,120 @@ impl PyObject {
371371
})
372372
}
373373

374-
// Equivalent to check_class. Masks Attribute errors (into TypeErrors) and lets everything
375-
// else go through.
376-
fn check_cls<F>(&self, cls: &PyObject, vm: &VirtualMachine, msg: F) -> PyResult
374+
// Equivalent to CPython's check_class. Returns Ok(()) if cls is a valid class,
375+
// Err with TypeError if not. Uses abstract_get_bases internally.
376+
fn check_class<F>(&self, vm: &VirtualMachine, msg: F) -> PyResult<()>
377377
where
378378
F: Fn() -> String,
379379
{
380-
cls.get_attr(identifier!(vm, __bases__), vm).map_err(|e| {
381-
// Only mask AttributeErrors.
382-
if e.class().is(vm.ctx.exceptions.attribute_error) {
383-
vm.new_type_error(msg())
384-
} else {
385-
e
380+
let cls = self;
381+
match cls.abstract_get_bases(vm)? {
382+
Some(_bases) => Ok(()), // Has __bases__, it's a valid class
383+
None => {
384+
// No __bases__ or __bases__ is not a tuple
385+
Err(vm.new_type_error(msg()))
386386
}
387-
})
387+
}
388388
}
389389

390-
fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
391-
let mut derived = self;
392-
let mut first_item: PyObjectRef;
393-
loop {
394-
if derived.is(cls) {
395-
return Ok(true);
390+
/// abstract_get_bases() has logically 4 return states:
391+
/// 1. getattr(cls, '__bases__') could raise an AttributeError
392+
/// 2. getattr(cls, '__bases__') could raise some other exception
393+
/// 3. getattr(cls, '__bases__') could return a tuple
394+
/// 4. getattr(cls, '__bases__') could return something other than a tuple
395+
///
396+
/// Only state #3 returns Some(tuple). AttributeErrors are masked by returning None.
397+
/// If an object other than a tuple comes out of __bases__, then again, None is returned.
398+
/// Other exceptions are propagated.
399+
fn abstract_get_bases(&self, vm: &VirtualMachine) -> PyResult<Option<PyTupleRef>> {
400+
match vm.get_attribute_opt(self.to_owned(), identifier!(vm, __bases__))? {
401+
Some(bases) => {
402+
// Check if it's a tuple
403+
match PyTupleRef::try_from_object(vm, bases) {
404+
Ok(tuple) => Ok(Some(tuple)),
405+
Err(_) => Ok(None), // Not a tuple, return None
406+
}
396407
}
408+
None => Ok(None), // AttributeError was masked
409+
}
410+
}
397411

398-
let bases = derived.get_attr(identifier!(vm, __bases__), vm)?;
399-
let tuple = PyTupleRef::try_from_object(vm, bases)?;
412+
fn abstract_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
413+
// # Safety: The lifetime of `derived` is forced to be ignored
414+
let bases = unsafe {
415+
let mut derived = self;
416+
// First loop: handle single inheritance without recursion
417+
loop {
418+
if derived.is(cls) {
419+
return Ok(true);
420+
}
400421

401-
let n = tuple.len();
402-
match n {
403-
0 => {
422+
let Some(bases) = derived.abstract_get_bases(vm)? else {
404423
return Ok(false);
405-
}
406-
1 => {
407-
first_item = tuple[0].clone();
408-
derived = &first_item;
409-
continue;
410-
}
411-
_ => {
412-
for i in 0..n {
413-
let check = vm.with_recursion("in abstract_issubclass", || {
414-
tuple[i].abstract_issubclass(cls, vm)
415-
})?;
416-
if check {
417-
return Ok(true);
418-
}
424+
};
425+
let n = bases.len();
426+
match n {
427+
0 => return Ok(false),
428+
1 => {
429+
// Avoid recursion in the single inheritance case
430+
// # safety
431+
// Intention:
432+
// ```
433+
// derived = bases.as_slice()[0].as_object();
434+
// ```
435+
// Though type-system cannot guarantee, derived does live long enough in the loop.
436+
derived = &*(bases.as_slice()[0].as_object() as *const _);
437+
continue;
438+
}
439+
_ => {
440+
// Multiple inheritance - break out to handle recursively
441+
break bases;
419442
}
420443
}
421444
}
445+
};
422446

423-
return Ok(false);
447+
// Second loop: handle multiple inheritance with recursion
448+
// At this point we know n >= 2
449+
let n = bases.len();
450+
debug_assert!(n >= 2);
451+
452+
for i in 0..n {
453+
let result = vm.with_recursion("in __issubclass__", || {
454+
bases.as_slice()[i].abstract_issubclass(cls, vm)
455+
})?;
456+
if result {
457+
return Ok(true);
458+
}
424459
}
460+
461+
Ok(false)
425462
}
426463

427464
fn recursive_issubclass(&self, cls: &PyObject, vm: &VirtualMachine) -> PyResult<bool> {
428-
if let (Ok(obj), Ok(cls)) = (self.try_to_ref::<PyType>(vm), cls.try_to_ref::<PyType>(vm)) {
429-
Ok(obj.fast_issubclass(cls))
430-
} else {
431-
// Check if derived is a class
432-
self.check_cls(self, vm, || {
433-
format!("issubclass() arg 1 must be a class, not {}", self.class())
465+
// Fast path for both being types (matches CPython's PyType_Check)
466+
if let Some(cls) = PyType::check(cls)
467+
&& let Some(derived) = PyType::check(self)
468+
{
469+
// PyType_IsSubtype equivalent
470+
return Ok(derived.is_subtype(cls));
471+
}
472+
// Check if derived is a class
473+
self.check_class(vm, || {
474+
format!("issubclass() arg 1 must be a class, not {}", self.class())
475+
})?;
476+
477+
// Check if cls is a class, tuple, or union (matches CPython's order and message)
478+
if !cls.class().is(vm.ctx.types.union_type) {
479+
cls.check_class(vm, || {
480+
format!(
481+
"issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}",
482+
cls.class()
483+
)
434484
})?;
435-
436-
// Check if cls is a class, tuple, or union
437-
if !cls.class().is(vm.ctx.types.union_type) {
438-
self.check_cls(cls, vm, || {
439-
format!(
440-
"issubclass() arg 2 must be a class, a tuple of classes, or a union, not {}",
441-
cls.class()
442-
)
443-
})?;
444-
}
445-
446-
self.abstract_issubclass(cls, vm)
447485
}
486+
487+
self.abstract_issubclass(cls, vm)
448488
}
449489

450490
/// Real issubclass check without going through __subclasscheck__
@@ -520,7 +560,7 @@ impl PyObject {
520560
Ok(retval)
521561
} else {
522562
// Not a type object, check if it's a valid class
523-
self.check_cls(cls, vm, || {
563+
cls.check_class(vm, || {
524564
format!(
525565
"isinstance() arg 2 must be a type, a tuple of types, or a union, not {}",
526566
cls.class()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy