Skip to content

Commit

Permalink
Add switch to variant equality check
Browse files Browse the repository at this point in the history
  • Loading branch information
btwj committed Feb 6, 2024
1 parent edf6b0b commit 2a275b2
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 46 deletions.
20 changes: 12 additions & 8 deletions aeneas/src/ir/Normalization.v3
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ class ReachabilityNormalizer(config: NormalizerConfig, ra: ReachabilityAnalyzer)
var idx = 0;
for (i < box.origRanges.length) {
var or = box.origRanges[i];
var fieldArray = normFieldIntoArray(box.origFields[i], oldRecord.values[i], or.1 - or.0);
var fieldArray = if(or.1 - or.0 > 0, normFieldIntoArray(box.origFields[i], oldRecord.values[i], or.1 - or.0));
for (j < or.1 - or.0) record.values[idx++] = fieldArray[j];
}
array[boxIndex] = record;
Expand Down Expand Up @@ -614,11 +614,13 @@ class ReachabilityNormalizer(config: NormalizerConfig, ra: ReachabilityAnalyzer)
boxIc.facts |= Fact.C_HEAP;
box.ic = boxIc;

var equality = IrMethod.new(box.boxType, null, Function.sig(box.boxType, Bool.TYPE));
boxIc.methods[IrUtil.EQUALS_METHOD_INDEX] = equality;
equality.index = IrUtil.EQUALS_METHOD_INDEX;
equality.facts |= Fact.M_EQUALS;
newIr.methods.put(equality);
if (rc.raFacts.RC_EQUALITY) {
var equality = IrMethod.new(box.boxType, null, Function.sig(box.boxType, Bool.TYPE));
boxIc.methods[IrUtil.EQUALS_METHOD_INDEX] = equality;
equality.index = IrUtil.EQUALS_METHOD_INDEX;
equality.facts |= Fact.M_EQUALS;
newIr.methods.put(equality);
}

newIr.setIrClass(box.boxType, boxIc);
newIr.methods.put(constructor);
Expand Down Expand Up @@ -647,8 +649,10 @@ class ReachabilityNormalizer(config: NormalizerConfig, ra: ReachabilityAnalyzer)

if (rc.variantNorm != null) {
for (box in rc.variantNorm.boxes) {
var equality = box.ic.methods[IrUtil.EQUALS_METHOD_INDEX];
equality.ssa = BoxComparatorGen.new(this, context, box.ic, equality).generate();
if (rc.raFacts.RC_EQUALITY) {
var equality = box.ic.methods[IrUtil.EQUALS_METHOD_INDEX];
equality.ssa = BoxComparatorGen.new(this, context, box.ic, equality).generate();
}
}
}
}
Expand Down
12 changes: 2 additions & 10 deletions aeneas/src/ir/Reachability.v3
Original file line number Diff line number Diff line change
Expand Up @@ -897,21 +897,13 @@ class BoxComparatorGen(rn: ReachabilityNormalizer, context: SsaContext, receiver
method.ssa = graph = context.graph = SsaGraph.new(params, Bool.TYPE);

var b = SsaBuilder.new(context, graph, graph.startBlock);
var refsEq = b.opEqual(AnyRef.TYPE, p0, p1);

var refsEqBlock = SsaBlock.new(), refsNotEqBlock = SsaBlock.new();
SsaBuilder.new(context, graph, refsEqBlock).addReturn([graph.trueConst()]);
b.addIf(refsEq, refsEqBlock, refsNotEqBlock);

var fieldChecks = SsaBlock.new();
var falseBlock = SsaBlock.new();
SsaBuilder.new(context, graph, falseBlock).addReturn([graph.boolConst(false)]);

b = SsaBuilder.new(context, graph, refsNotEqBlock);
p1 = b.add(V3Op.newVariantReplaceNull(receiver.ctype), [p1], Facts.NONE);

var s = SsaRaNormalizer.new(context, rn);
s.newGraph = graph;
s.curBlock = SsaBuilder.new(context, graph, refsNotEqBlock);
s.curBlock = SsaBuilder.new(context, graph, graph.startBlock);

for (f in receiver.fields) {
var spec = IrSpec.new(receiver.ctype, [receiver.ctype], f);
Expand Down
85 changes: 60 additions & 25 deletions aeneas/src/ir/SsaNormalizer.v3
Original file line number Diff line number Diff line change
Expand Up @@ -681,10 +681,13 @@ class SsaRaNormalizer extends SsaRebuilder {
return normEqualOp0(oldApp.source, tn, refs);
}
def normEqualOp0(source: Source, tn: TypeNorm, refs: Array<SsaInstr>) -> SsaInstr {
if (V3.isVariant(tn.newType) && !VariantNorm.?(tn)) return normVariantEqual(source, tn.newType, refs[0], refs[1]);
if (V3.isVariant(tn.newType)) return normVariantEqual(source, tn.newType, refs[0], refs[1]);
if (VariantNorm.?(tn)) return normUnboxedVariantEqual(source, VariantNorm.!(tn), refs);
return normEqualOps(source, tn, refs);
}
def normEqualOps(source: Source, tn: TypeNorm, refs: Array<SsaInstr>) -> SsaInstr {
if (tn.size == 0) return newGraph.trueConst();
if (tn.size == 1) return normSingleEqualOp(source, tn.newType, refs[0], refs[1]);

var join = opBoolAnd;
var expr: SsaInstr;
for (i < tn.size) {
Expand All @@ -702,25 +705,19 @@ class SsaRaNormalizer extends SsaRebuilder {
}
def normSingleEqualOp(source: Source, t: Type, x: SsaInstr, y: SsaInstr) -> SsaInstr {
var tn = normType(t);
if (V3.isVariant(t) && !VariantNorm.?(tn)) return normVariantEqual(source, t, x, y);
if (V3.isVariant(tn.newType)) return normVariantEqual(source, t, x, y);
if (VariantNorm.?(tn)) return normUnboxedVariantEqual(source, VariantNorm.!(tn), [x, y]);
if (V3.isBox(t)) return normBoxEqualOp(source, t, x, y);
return curBlock.opEqual(t, x, y);
}
def normBoxEqualOp(source: Source, t: Type, x: SsaInstr, y: SsaInstr) -> SsaInstr {
var ic = norm.newIr.getIrClass(t);
var refsEqual = curBlock.opEqual(AnyRef.TYPE, x, y);

var split = SsaBlockSplit.new(context, curBlock);
curBlock = split.addIf(refsEqual);
curBlock = split.addElse();
x = curBlock.add(V3Op.newVariantReplaceNull(t), [x], Facts.NONE);
y = curBlock.add(V3Op.newVariantReplaceNull(t), [y], Facts.NONE);
var equals = ic.methods[IrUtil.EQUALS_METHOD_INDEX];
var call = V3Op.newCallMethod(IrSpec.new(t, [t], equals));
x = curBlock.add(V3Op.newVariantReplaceNull(t), [x], Facts.NONE);
var app = curBlock.addApply(source, call, [x, y]);

curBlock = split.finish();
var result = split.addPhi(Bool.TYPE, [newGraph.trueConst(), app]);
return result;
return app;
}
def normVariantEqual(source: Source, t: Type, x: SsaInstr, y: SsaInstr) -> SsaInstr {
var rc = norm.ra.getClass(t);
Expand Down Expand Up @@ -750,6 +747,53 @@ class SsaRaNormalizer extends SsaRebuilder {
call.setFact(facts);
return call;
}
// Normalize equality between two unboxed variants of the same case (e.g. x: T.A == y: T.A)
def normUnboxedVariantCaseEqual(source: Source, vn: VariantNorm, refs: Array<SsaInstr>) -> SsaInstr {
var join = opBoolAnd;
var expr: SsaInstr;
for (i < vn.caseFields.length) {
var cf = vn.caseFields[i];
for (j < cf.indexes.length) {
var idx = cf.indexes[j];
if (idx == vn.tagIndex()) continue;
var cmp: SsaInstr, a = refs[idx], b = refs[idx + vn.size];
a = curBlock.opTypeSubsume(vn.sub[idx], cf.types[j], a);
b = curBlock.opTypeSubsume(vn.sub[idx], cf.types[j], b);
cmp = normSingleEqualOp(source, cf.types[j], a, b);
if (expr == null) expr = cmp;
else expr = join(expr, cmp);
}
}
return if(expr != null, expr, newGraph.trueConst());
}
// Normalize equality between two unboxed variants of unknown case (e.g. x: T == y: T)
def normUnboxedVariantEqual(source: Source, vn: VariantNorm, refs: Array<SsaInstr>) -> SsaInstr {
if (vn.isSingleCase() || !vn.isBoxed()) return normEqualOps(source, vn, refs);
if (vn.parentNorm != null) return normUnboxedVariantCaseEqual(source, vn, refs);
var x = refs[0 ... vn.size], y = refs[vn.size ...];
var xTag = x[vn.tagIndex()], yTag = y[vn.tagIndex()];

var tagsEqual = curBlock.opEqual(vn.tagType(), xTag, yTag);
var split = SsaBlockSplit.new(context, curBlock);
var results = Vector<SsaInstr>.new();

curBlock = split.addIfNot(tagsEqual);
results.put(newGraph.falseConst());

for (i < vn.childNorms.length) {
curBlock = split.addElse();
var cn = vn.childNorms[i];
var tag = V3.getVariantTag(cn.oldType);
curBlock = split.addIf(curBlock.opEqual(vn.tagType(), xTag, newGraph.intConst(tag)));
results.put(normUnboxedVariantCaseEqual(source, cn, refs));
}

curBlock = split.addElse();
results.put(newGraph.falseConst());
curBlock = split.finish();
var result = split.addPhi(Bool.TYPE, results.extract());
return result;
}
def normTupleGetElem(oldInstr: SsaInstr, args: Array<SsaDfEdge>, op: Operator, index: int) {
var tn = TupleNorm.!(normTypeArg(op, 0));
return mapNnf(oldInstr, tn.getElem(genRefs(args), index));
Expand All @@ -759,13 +803,6 @@ class SsaRaNormalizer extends SsaRebuilder {
if (vn.isSingleCase()) return newGraph.zeroConst();
return args[offset + vn.tagIndex()];
}
def normVariantSubsume(oi: SsaApplyOp, avn: VariantNorm, rvn: VariantNorm, inputs: Range<SsaInstr>) -> Array<SsaInstr> {
var result = Array<SsaInstr>.new(avn.size);
for (i < result.length) {
result[i] = inputs[i];
}
return result;
}
def throwTypeCheckException() {
curBlock.addThrow(curBlock.source, V3Exception.TypeCheck);
}
Expand Down Expand Up @@ -794,15 +831,13 @@ class SsaRaNormalizer extends SsaRebuilder {
var avn = VariantNorm.!(atn), rvn = VariantNorm.!(rtn);
if ((avn.isEnum && rvn.isEnum) || (avn.isSingleCase() && rvn.isSingleCase())) {
var tag = V3.getVariantTag(rtn.oldType);
curBlock.opIntRangeCheck(1, tag, tag + 1, oi[offset + avn.tagIndex()]);
for (i < avn.size) result.put(oi[offset + i]);
curBlock.opIntRangeCheck(1, tag, tag + 1, oi[offset + avn.tagIndex()]);
} else {
var tag = normVariantGetTag(avn, oi, offset);
var expectedTag = V3.getVariantTag(rtn.oldType);
curBlock.opIntRangeCheck(1, expectedTag, expectedTag + 1, tag);
var subsumed = normVariantSubsume(oldInstr, avn, rvn, oi[offset ...]);
for (i < avn.size) result.put(subsumed[i]);
}
for (i < avn.size) result.put(oi[offset + i]);
return;
}
// break
Expand Down Expand Up @@ -1294,8 +1329,8 @@ class SsaRaNormalizer extends SsaRebuilder {

// field of flattened but boxed variant case
var cf = rc.variantNorm.caseFields[0];
var receiver = ninputs[cf.indexes[0]];
var box = rc.variantNorm.boxes[0];
var receiver = curBlock.opTypeSubsume(rc.variantNorm.sub[0], box.boxType, ninputs[cf.indexes[0]]);
if (V3Op.needsNullCheck(oldApp, receiver)) receiver = curBlock.add(V3Op.newVariantReplaceNull(box.boxType), [receiver], Facts.NONE);
if (nf.length == 1) {
var read = curBlock.opGetField(IrSpec.new(box.boxType, [box.boxType], box.ic.fields[raField.normIndex]), receiver);
Expand Down
7 changes: 6 additions & 1 deletion aeneas/src/ir/VariantSolver.v3
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Norm of a variant with unboxed cases
class VariantNorm extends TypeNorm {
def parentNorm: VariantNorm; // null for single parent or enum
var childNorms: Array<VariantNorm>;
def caseFields: Array<VariantField>;
def tagField: VariantField;

Expand Down Expand Up @@ -218,7 +219,7 @@ class VariantSolver(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbose: b
var c = l.head;
var caseUnboxed = parentUnboxed || c.orig.boxing == Boxing.UNBOXED;

if (!caseUnboxed) {
if (!caseUnboxed || CLOptions.UNBOX_ALL_AND_BOX.get()) {
hasBoxedCase = true;
var boxedIndex = -1;

Expand All @@ -242,6 +243,7 @@ class VariantSolver(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbose: b
var origRanges = Array<(int, int)>.new(c.fields.length);
var start = 0;
for (i < c.fields.length) {
if (c.fields[i] == null) continue;
if (c.fields[i].norm == null) origRanges[i] = (start, start);
else {
origRanges[i] = (start, start + c.fields[i].norm.length);
Expand Down Expand Up @@ -317,14 +319,17 @@ class VariantSolver(nc: NormalizerConfig, rn: ReachabilityNormalizer, verbose: b
if (verbose) Terminal.put1("variant norm %q\n", rc.variantNorm.render);

var caseIndex = 0;
var childNorms = Vector<VariantNorm>.new();
for (l = rc.children; l != null; l = l.tail) {
var c = l.head;
var vc = cases[caseIndex]; caseIndex++;

var caseNorm = VariantNorm.new(c.oldType, newType, parentNorm, vc.fields, tagField, vc.boxes, vecT.copy());
c.variantNorm = caseNorm;
childNorms.put(caseNorm);
if (verbose) Terminal.put1("-- case norm %q\n", caseNorm.render);
}
parentNorm.childNorms = childNorms.extract();
return true;
}
def setVariantNormForChildren(rc: RaClass, tagField: VariantField, sub: Array<Type>, facts: RaFact.set) {
Expand Down
3 changes: 3 additions & 0 deletions aeneas/src/jvm/JvmTarget.v3
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class JvmTarget extends Target {
if (IntType.?(a) && IntType.?(b)) {
return if(IntType.!(a).width > IntType.!(b).width, a, b);
}
if (BoxType.?(a) && BoxType.?(b)) {
return AnyRef.TYPE;
}
if (ClassType.?(a) && ClassType.?(b)) {
return AnyRef.TYPE;
}
Expand Down
2 changes: 1 addition & 1 deletion aeneas/src/mach/MachProgram.v3
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ class MachProgram extends TargetProgram {
def isRefType(t: Type) -> bool {
match (t.typeCon.kind) {
V3Kind.VARIANT => return !prog.ir.isEnum(t);
V3Kind.CLASS, V3Kind.ARRAY => return true;
V3Kind.CLASS, V3Kind.ARRAY, V3Kind.BOX => return true;
}
return false;
}
Expand Down
2 changes: 2 additions & 0 deletions aeneas/src/main/CLOptions.v3
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ component CLOptions {
"Enable variant unboxing features.");
def UNBOX_ALL = compileOpt.newBoolOption("unbox-all", false,
"Unbox all non-recursive variants.");
def UNBOX_ALL_AND_BOX = compileOpt.newBoolOption("unbox-all-and-box", false,
"Unbox all non-recursive variants, and box each case.");
// JVM target options
def JVM_RT_PATH = jvmOpt.newStringOption("jvm.rt-path", null,
"Specify the path to the Java runtime.");
Expand Down
1 change: 0 additions & 1 deletion aeneas/src/v3/V3.v3
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ component V3 {
x: BoxType => tag = x.tag;
}
return if(tag >= 0, tag);

}
def getVariantTagType(t: Type) -> IntType {
match (t) {
Expand Down
9 changes: 9 additions & 0 deletions aeneas/src/x86-64/X86_64Darwin.v3
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class X86_64DarwinTarget extends Target {

def configureCompiler(compiler: Compiler) {
compiler.Reachability = true;
compiler.NormConfig.UnifyRepresentations = unifyRepresentations;
}
def configureProgram(prog: Program) {
def space = SPACE;
Expand Down Expand Up @@ -160,6 +161,14 @@ class X86_64DarwinTarget extends Target {
header.addCmd(s);
return s;
}
private def unifyRepresentations(compiler: Compiler, prog: Program, a: Type, b: Type) -> Type {
var mach = getRuntime(prog).mach;
if (a == b) return a;
if (IntType.?(a) && IntType.?(b))
return if(IntType.!(a).width > IntType.!(b).width, a, b);
if (mach.isRefType(a) && mach.isRefType(b)) return AnyRef.TYPE;
return null;
}
}

// Darwin-specific backend code generation.
Expand Down
1 change: 1 addition & 0 deletions apps/TypeRep/TypeSystem.v3
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ component TypeSystem {
// Return {true} if {t} is a reference type.
def isRefType(t: Type) -> bool {
match (t) {
x: BoxType => return true;
x: ClassType => return true;
x: ArrayType => return true;
x: ClosureType => return true;
Expand Down
15 changes: 15 additions & 0 deletions test/variants/ub_eq06.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//@execute 0=true;1=true;2=true;3=true;4=true;5=true
type T {
case A(x: int);
case B(y: int);
case C(a: Array<int>) #unboxed;
}

def a: Array<int> = [];
def b: Array<int> = [1, 2];
def arr1 = [T.A(0), T.A(10), T.B(20), T.B(30), T.C(a), T.C(b)];
def arr2 = [T.A(0), T.A(10), T.B(20), T.B(30), T.C(a), T.C(b)];

def main(a: int) -> bool {
return arr1[a] == arr2[a];
}
15 changes: 15 additions & 0 deletions test/variants/ub_eq07.v3
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
//@execute 0=false;1=false;2=false;3=false;4=false;5=false
type T {
case A(x: int);
case B(y: int);
case C(a: Array<int>) #unboxed;
}

def a: Array<int> = [];
def b: Array<int> = [1, 2];
def arr1 = [T.A(0), T.A(10), T.B(20), T.B(30), T.C(a), T.C(b)];
def arr2 = [T.B(20), T.B(30), T.C(a), T.C(b), T.A(0), T.A(10)];

def main(a: int) -> bool {
return arr1[a] == arr2[a];
}

0 comments on commit 2a275b2

Please sign in to comment.