Skip to content

Commit 27a9830

Browse files
authored
[mypyc] Simplify IR for tagged integer comparisons (python#9607)
In a conditional context, such as in an if condition, simplify the IR for tagged integer comparisons. Also perform some additional optimizations if an operand is known to be a short integer. This slightly improves performance when compiling with no optimizations. The impact should be pretty negligible otherwise. This is a bit simple-minded, and some further optimizations are possible. For example, `3 < x < 6` could be made faster. This covers the most common cases, however. Closes mypyc/mypyc#758.
1 parent 3acbf3f commit 27a9830

File tree

10 files changed

+821
-1019
lines changed

10 files changed

+821
-1019
lines changed

mypyc/irbuild/builder.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from mypy.nodes import (
2020
MypyFile, SymbolNode, Statement, OpExpr, IntExpr, NameExpr, LDEF, Var, UnaryExpr,
2121
CallExpr, IndexExpr, Expression, MemberExpr, RefExpr, Lvalue, TupleExpr,
22-
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, GDEF, ARG_POS, ARG_NAMED
22+
TypeInfo, Decorator, OverloadedFuncDef, StarExpr, ComparisonExpr, GDEF, ARG_POS, ARG_NAMED
2323
)
2424
from mypy.types import (
2525
Type, Instance, TupleType, UninhabitedType, get_proper_type
@@ -39,7 +39,7 @@
3939
from mypyc.ir.rtypes import (
4040
RType, RTuple, RInstance, int_rprimitive, dict_rprimitive,
4141
none_rprimitive, is_none_rprimitive, object_rprimitive, is_object_rprimitive,
42-
str_rprimitive,
42+
str_rprimitive, is_tagged
4343
)
4444
from mypyc.ir.func_ir import FuncIR, INVALID_FUNC_DEF
4545
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
@@ -813,11 +813,45 @@ def process_conditional(self, e: Expression, true: BasicBlock, false: BasicBlock
813813
self.process_conditional(e.right, true, false)
814814
elif isinstance(e, UnaryExpr) and e.op == 'not':
815815
self.process_conditional(e.expr, false, true)
816-
# Catch-all for arbitrary expressions.
817816
else:
817+
res = self.maybe_process_conditional_comparison(e, true, false)
818+
if res:
819+
return
820+
# Catch-all for arbitrary expressions.
818821
reg = self.accept(e)
819822
self.add_bool_branch(reg, true, false)
820823

824+
def maybe_process_conditional_comparison(self,
825+
e: Expression,
826+
true: BasicBlock,
827+
false: BasicBlock) -> bool:
828+
"""Transform simple tagged integer comparisons in a conditional context.
829+
830+
Return True if the operation is supported (and was transformed). Otherwise,
831+
do nothing and return False.
832+
833+
Args:
834+
e: Arbitrary expression
835+
true: Branch target if comparison is true
836+
false: Branch target if comparison is false
837+
"""
838+
if not isinstance(e, ComparisonExpr) or len(e.operands) != 2:
839+
return False
840+
ltype = self.node_type(e.operands[0])
841+
rtype = self.node_type(e.operands[1])
842+
if not is_tagged(ltype) or not is_tagged(rtype):
843+
return False
844+
op = e.operators[0]
845+
if op not in ('==', '!=', '<', '<=', '>', '>='):
846+
return False
847+
left = self.accept(e.operands[0])
848+
right = self.accept(e.operands[1])
849+
# "left op right" for two tagged integers
850+
self.builder.compare_tagged_condition(left, right, op, true, false, e.line)
851+
return True
852+
853+
# Basic helpers
854+
821855
def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[ClassIR]]:
822856
"""Flatten classes in isinstance(obj, (A, (B, C))).
823857
@@ -841,8 +875,6 @@ def flatten_classes(self, arg: Union[RefExpr, TupleExpr]) -> Optional[List[Class
841875
return None
842876
return res
843877

844-
# Basic helpers
845-
846878
def enter(self, fn_info: Union[FuncInfo, str] = '') -> None:
847879
if isinstance(fn_info, str):
848880
fn_info = FuncInfo(name=fn_info)

mypyc/irbuild/ll_builder.py

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -589,17 +589,21 @@ def binary_op(self,
589589
assert target, 'Unsupported binary operation: %s' % op
590590
return target
591591

592-
def check_tagged_short_int(self, val: Value, line: int) -> Value:
593-
"""Check if a tagged integer is a short integer"""
592+
def check_tagged_short_int(self, val: Value, line: int, negated: bool = False) -> Value:
593+
"""Check if a tagged integer is a short integer.
594+
595+
Return the result of the check (value of type 'bit').
596+
"""
594597
int_tag = self.add(LoadInt(1, line, rtype=c_pyssize_t_rprimitive))
595598
bitwise_and = self.binary_int_op(c_pyssize_t_rprimitive, val,
596599
int_tag, BinaryIntOp.AND, line)
597600
zero = self.add(LoadInt(0, line, rtype=c_pyssize_t_rprimitive))
598-
check = self.comparison_op(bitwise_and, zero, ComparisonOp.EQ, line)
601+
op = ComparisonOp.NEQ if negated else ComparisonOp.EQ
602+
check = self.comparison_op(bitwise_and, zero, op, line)
599603
return check
600604

601605
def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
602-
"""Compare two tagged integers using given op"""
606+
"""Compare two tagged integers using given operator (value context)."""
603607
# generate fast binary logic ops on short ints
604608
if is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type):
605609
return self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
@@ -610,13 +614,11 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
610614
if op in ("==", "!="):
611615
check = check_lhs
612616
else:
613-
# for non-equal logical ops(less than, greater than, etc.), need to check both side
617+
# for non-equality logical ops (less/greater than, etc.), need to check both sides
614618
check_rhs = self.check_tagged_short_int(rhs, line)
615619
check = self.binary_int_op(bit_rprimitive, check_lhs,
616620
check_rhs, BinaryIntOp.AND, line)
617-
branch = Branch(check, short_int_block, int_block, Branch.BOOL)
618-
branch.negated = False
619-
self.add(branch)
621+
self.add(Branch(check, short_int_block, int_block, Branch.BOOL))
620622
self.activate_block(short_int_block)
621623
eq = self.comparison_op(lhs, rhs, op_type, line)
622624
self.add(Assign(result, eq, line))
@@ -636,6 +638,60 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
636638
self.goto_and_activate(out)
637639
return result
638640

641+
def compare_tagged_condition(self,
642+
lhs: Value,
643+
rhs: Value,
644+
op: str,
645+
true: BasicBlock,
646+
false: BasicBlock,
647+
line: int) -> None:
648+
"""Compare two tagged integers using given operator (conditional context).
649+
650+
Assume lhs and and rhs are tagged integers.
651+
652+
Args:
653+
lhs: Left operand
654+
rhs: Right operand
655+
op: Operation, one of '==', '!=', '<', '<=', '>', '<='
656+
true: Branch target if comparison is true
657+
false: Branch target if comparison is false
658+
"""
659+
is_eq = op in ("==", "!=")
660+
if ((is_short_int_rprimitive(lhs.type) and is_short_int_rprimitive(rhs.type))
661+
or (is_eq and (is_short_int_rprimitive(lhs.type) or
662+
is_short_int_rprimitive(rhs.type)))):
663+
# We can skip the tag check
664+
check = self.comparison_op(lhs, rhs, int_comparison_op_mapping[op][0], line)
665+
self.add(Branch(check, true, false, Branch.BOOL))
666+
return
667+
op_type, c_func_desc, negate_result, swap_op = int_comparison_op_mapping[op]
668+
int_block, short_int_block = BasicBlock(), BasicBlock()
669+
check_lhs = self.check_tagged_short_int(lhs, line, negated=True)
670+
if is_eq or is_short_int_rprimitive(rhs.type):
671+
self.add(Branch(check_lhs, int_block, short_int_block, Branch.BOOL))
672+
else:
673+
# For non-equality logical ops (less/greater than, etc.), need to check both sides
674+
rhs_block = BasicBlock()
675+
self.add(Branch(check_lhs, int_block, rhs_block, Branch.BOOL))
676+
self.activate_block(rhs_block)
677+
check_rhs = self.check_tagged_short_int(rhs, line, negated=True)
678+
self.add(Branch(check_rhs, int_block, short_int_block, Branch.BOOL))
679+
# Arbitrary integers (slow path)
680+
self.activate_block(int_block)
681+
if swap_op:
682+
args = [rhs, lhs]
683+
else:
684+
args = [lhs, rhs]
685+
call = self.call_c(c_func_desc, args, line)
686+
if negate_result:
687+
self.add(Branch(call, false, true, Branch.BOOL))
688+
else:
689+
self.add(Branch(call, true, false, Branch.BOOL))
690+
# Short integers (fast path)
691+
self.activate_block(short_int_block)
692+
eq = self.comparison_op(lhs, rhs, op_type, line)
693+
self.add(Branch(eq, true, false, Branch.BOOL))
694+
639695
def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
640696
"""Compare two strings"""
641697
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy