Skip to content

Commit

Permalink
Fix equality for nested boxes
Browse files Browse the repository at this point in the history
  • Loading branch information
btwj committed Jan 5, 2024
1 parent 0f48099 commit cf35b57
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 14 deletions.
19 changes: 8 additions & 11 deletions aeneas/src/ir/Normalization.v3
Original file line number Diff line number Diff line change
Expand Up @@ -614,12 +614,11 @@ class ReachabilityNormalizer(config: NormalizerConfig, ra: ReachabilityAnalyzer)
boxIc.facts |= Fact.C_HEAP;
box.ic = boxIc;

if (rc.raFacts.RC_EQUALITY) {
var equality = IrMethod.new(box.boxType, null, Function.sig(box.boxType, Bool.TYPE));
boxIc.methods[IrUtil.EQUALS_METHOD_INDEX] = equality;
equality.index = IrUtil.EQUALS_METHOD_INDEX;
newIr.methods.put(equality);
}
var equality = IrMethod.new(box.boxType, null, Function.sig(box.boxType, Bool.TYPE));
boxIc.methods[IrUtil.EQUALS_METHOD_INDEX] = equality;
equality.index = IrUtil.EQUALS_METHOD_INDEX;
equality.facts |= Fact.M_EQUALS;
newIr.methods.put(equality);

newIr.setIrClass(box.boxType, boxIc);
newIr.methods.put(constructor);
Expand All @@ -646,12 +645,10 @@ class ReachabilityNormalizer(config: NormalizerConfig, ra: ReachabilityAnalyzer)
i++;
}

if (rc.variantNorm != null && rc.variantNorm.boxes.length > 0) {
if (rc.variantNorm != null) {
for (box in rc.variantNorm.boxes) {
if (rc.raFacts.RC_EQUALITY) {
var equality = box.ic.methods[IrUtil.EQUALS_METHOD_INDEX];
equality.ssa = BoxComparatorGen.new(this, context, box.ic, equality).generate();
}
var equality = box.ic.methods[IrUtil.EQUALS_METHOD_INDEX];
equality.ssa = BoxComparatorGen.new(this, context, box.ic, equality).generate();
}
}
}
Expand Down
14 changes: 13 additions & 1 deletion aeneas/src/ir/Reachability.v3
Original file line number Diff line number Diff line change
Expand Up @@ -929,9 +929,21 @@ class BoxComparatorGen(rn: ReachabilityNormalizer, context: SsaContext, receiver
// XXX: share code with SsaRaNormalizer
if (V3.isBox(t)) {
var ic = rn.newIr.getIrClass(t);
var x_null = b.opEqual(AnyRef.TYPE, f0, graph.nullConst(t));
var y_null = b.opEqual(AnyRef.TYPE, f1, graph.nullConst(t));
var bothNull = b.opBoolAnd0(x_null, y_null);

var split = SsaBlockSplit.new(context, b);
b = split.addIf(bothNull);
b = split.addElse();
var equals = ic.methods[IrUtil.EQUALS_METHOD_INDEX];
var call = V3Op.newCallMethod(IrSpec.new(t, [t], equals));
cmp = b.addApply(null, call, [f0, f1]);
var app = b.addApply(null, call, [f0, f1]);

b = split.finish();
var result = split.addPhi(Bool.TYPE, [graph.trueConst(), app]);

cmp = result;
} else if (V3.isVariant(t)) {
var rc = rn.ra.getClass(t);
var list = rc.methods[IrUtil.EQUALS_METHOD_INDEX];
Expand Down
14 changes: 13 additions & 1 deletion aeneas/src/ir/SsaNormalizer.v3
Original file line number Diff line number Diff line change
Expand Up @@ -706,9 +706,21 @@ class SsaRaNormalizer extends SsaRebuilder {
}
def normBoxEqualOp(oldApp: SsaApplyOp, t: Type, x: SsaInstr, y: SsaInstr) -> SsaInstr {
var ic = norm.newIr.getIrClass(t);
var x_null = curBlock.opEqual(AnyRef.TYPE, x, newGraph.nullConst(t));
var y_null = curBlock.opEqual(AnyRef.TYPE, y, newGraph.nullConst(t));
var bothNull = curBlock.opBoolAnd0(x_null, y_null);

var split = SsaBlockSplit.new(context, curBlock);
curBlock = split.addIf(bothNull);
curBlock = split.addElse();
var equals = ic.methods[IrUtil.EQUALS_METHOD_INDEX];
var call = V3Op.newCallMethod(IrSpec.new(t, [t], equals));
return curBlock.addApply(oldApp.source, call, [x, y]);
var app = curBlock.addApply(oldApp.source, call, [x, y]);

curBlock = split.finish();
var result = split.addPhi(Bool.TYPE, [newGraph.trueConst(), app]);

return result;
}
def normVariantEqual(oldApp: SsaApplyOp, t: Type, x: SsaInstr, y: SsaInstr) -> SsaInstr {
var rc = norm.ra.getClass(t);
Expand Down
3 changes: 2 additions & 1 deletion aeneas/src/jvm/JvmGen.v3
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class JvmProgram(compiler: Compiler, prog: Program) {
}
V3Kind.CLASS => jvmType(utype);
V3Kind.COMPONENT => jvmType(utype);
V3Kind.BOX => jvmType(utype);
}
}
def getIrClass(vtype: Type) -> IrClass {
Expand Down Expand Up @@ -359,7 +360,7 @@ class JvmV3ClassGen extends JvmClassGen {
}
if (ic.inherits(m)) return;
// this method is declared in this type.
var name = V3.mangleIrMember(if(m.source == null, ic.root(m), m));
var name = V3.mangleIrMember(if(m.source == null && !m.facts.M_EQUALS, ic.root(m), m));
var methType = m.getMethodType();
var jsig = jvmSig(methType);
if (m.facts.M_ABSTRACT) {
Expand Down

0 comments on commit cf35b57

Please sign in to comment.