Skip to content

Commit

Permalink
Merge pull request #2731 from informalsystems/jk/oracles
Browse files Browse the repository at this point in the history
MockOracle refactor
  • Loading branch information
Shon Feder authored Jan 3, 2024
2 parents 253f4c4 + e37496b commit 70d7f24
Show file tree
Hide file tree
Showing 3 changed files with 156 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles

import at.forsyte.apalache.tla.bmcmt.smt.SolverContext
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.RewriterScope
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction
import at.forsyte.apalache.tla.types.tla

/**
* An oracle that always has the same value. This class specializes all methods to the case oracle == fixedValue.
* However, evalPosition always returns fixedValue.
*
* @param fixedValue
* a fixed value of the oracle
*/
class MockOracle(fixedValue: Int) extends Oracle {
require(fixedValue >= 0, "MockOracle must have a non-negative fixed value.")

override def size: Int = fixedValue + 1

override def chosenValueIsEqualToIndexedValue(scope: RewriterScope, index: BigInt): TBuilderInstruction =
tla.bool(index == fixedValue)

override def caseAssertions(
scope: RewriterScope,
assertions: Seq[TBuilderInstruction],
elseAssertionsOpt: Option[Seq[TBuilderInstruction]] = None): TBuilderInstruction = {
require(assertions.size == this.size && elseAssertionsOpt.forall { _.size == this.size },
s"Invalid call to Oracle, assertion sequences must have length $size.")
assertions(fixedValue)
}

override def getIndexOfChosenValueFromModel(solverContext: SolverContext): BigInt =
fixedValue
}

object MockOracle {
def create(fixedValue: Int): MockOracle = new MockOracle(fixedValue)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package at.forsyte.apalache.tla.bmcmt.stratifiedRules.aux.oracles

import at.forsyte.apalache.tla.bmcmt.smt.{SolverConfig, Z3SolverContext}
import at.forsyte.apalache.tla.bmcmt.stratifiedRules.RewriterScope
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.lir.values.TlaBool
import at.forsyte.apalache.tla.typecomp.TBuilderInstruction
import at.forsyte.apalache.tla.types.tla
import org.junit.runner.RunWith
import org.scalacheck.Prop.forAll
import org.scalacheck.{Gen, Prop}
import org.scalatest.BeforeAndAfterEach
import org.scalatest.funsuite.AnyFunSuite
import org.scalatestplus.junit.JUnitRunner
import org.scalatestplus.scalacheck.Checkers

@RunWith(classOf[JUnitRunner])
class TestMockOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers {

var initScope: RewriterScope = RewriterScope.initial()

override def beforeEach(): Unit = {
initScope = RewriterScope.initial()
}

val intGen: Gen[Int] = Gen.choose(-10, 10)
val nonNegIntGen: Gen[Int] = Gen.choose(0, 9)

val maxSizeAndIndexGen: Gen[(Int, Int)] = for {
max <- nonNegIntGen
idx <- Gen.choose(0, max)
} yield (max, idx)

test("Oracle cannot be constructed with a negative fixed value") {
val prop =
forAll(intGen) {
case i if i < 0 =>
Prop.throws(classOf[IllegalArgumentException]) {
MockOracle.create(i)
}
case i => MockOracle.create(i).size == i + 1
}

check(prop, minSuccessful(100), sizeRange(4))
}

test("chosenValueIsEqualToIndexedValue returns a simple boolean") {
val prop =
forAll(maxSizeAndIndexGen) { case (fixed, index) =>
val oracle = MockOracle.create(fixed)
val cmp: TlaEx = oracle.chosenValueIsEqualToIndexedValue(initScope, index)
cmp match {
case ValEx(TlaBool(v)) => v == (index == fixed)
case _ => false
}
}

check(prop, minSuccessful(1000), sizeRange(4))
}

val (assertionsA, assertionsB): (Seq[TBuilderInstruction], Seq[TBuilderInstruction]) = 0
.to(10)
.map { i =>
(tla.name(s"A$i", BoolT1), tla.name(s"B$i", BoolT1))
}
.unzip

test("caseAssertions requires assertion sequences of equal length") {
val assertionsGen: Gen[(Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for {
i <- Gen.choose(0, assertionsA.size)
j <- Gen.choose(0, assertionsB.size)
opt <- Gen.option(Gen.const(assertionsB.take(j)))
} yield (assertionsA.take(i), opt)

val prop =
forAll(Gen.zip(nonNegIntGen, assertionsGen)) { case (fixed, (assertionsIfTrue, assertionsIfFalseOpt)) =>
val oracle = MockOracle.create(fixed)
if (assertionsIfTrue.size != oracle.size || assertionsIfFalseOpt.exists { _.size != oracle.size })
Prop.throws(classOf[IllegalArgumentException]) {
oracle.caseAssertions(initScope, assertionsIfTrue, assertionsIfFalseOpt)
}
else true
}

check(prop, minSuccessful(1000), sizeRange(4))
}

test("caseAssertions always shorthands") {
val gen: Gen[(Int, Seq[TBuilderInstruction], Option[Seq[TBuilderInstruction]])] = for {
fixed <- nonNegIntGen
opt <- Gen.option(Gen.const(assertionsB.take(fixed + 1)))
} yield (fixed, assertionsA.take(fixed + 1), opt)

val prop =
forAll(gen) { case (fixed, assertionsIfTrue, assertionsIfFalseOpt) =>
val oracle = MockOracle.create(fixed)
val caseEx: TlaEx = oracle.caseAssertions(initScope, assertionsIfTrue, assertionsIfFalseOpt)
caseEx == assertionsIfTrue(fixed).build
}

check(prop, minSuccessful(1000), sizeRange(4))
}

// We don't actually need the solver in MockOracle
test("getIndexOfChosenValueFromModel recovers the index correctly") {
val prop =
forAll(Gen.zip(nonNegIntGen)) { fixed =>
val ctx = new Z3SolverContext(SolverConfig.default)
val oracle = MockOracle.create(fixed)
oracle.getIndexOfChosenValueFromModel(ctx) == fixed
}

check(prop, minSuccessful(100), sizeRange(4))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class TestSparseOracle extends AnyFunSuite with BeforeAndAfterEach with Checkers
ret
}

// 1000 is too many, since each run invokes the solver
check(prop, minSuccessful(80), sizeRange(4))
// The default minimum successful runs is 1000, but this is costly
// since each run invokes the solver.
check(prop, minSuccessful(50), sizeRange(4))
}
}

0 comments on commit 70d7f24

Please sign in to comment.