Skip to content

Commit

Permalink
Merge pull request #1 from bishabosha/add-ci
Browse files Browse the repository at this point in the history
Add more safe lookups
  • Loading branch information
bishabosha authored Oct 19, 2023
2 parents 3df5189 + 52618d2 commit e10a67a
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 86 deletions.
26 changes: 26 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
name: CI
on:
push:
branches:
- main
tags:
- "v*"
pull_request:

jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: coursier/[email protected]
- uses: VirtusLab/[email protected]
with:
power: true

- name: Check formatting
run: scala-cli fmt . --check

- name: Run unit tests
run: scala-cli test . --cross
10 changes: 10 additions & 0 deletions .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version = "3.7.14"
runner.dialect = scala3
align.preset = more
maxColumn = 100
indent.fewerBraces = never
rewrite.scala3.convertToNewSyntax = true
rewrite.scala3.removeOptionalBraces = yes
rewrite.scala3.insertEndMarkerMinLines = 5
verticalMultiline.atDefnSite = true
newlines.usingParamListModifierPrefer = before
16 changes: 13 additions & 3 deletions src/main/scala/enumextensions/EnumMirror.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,27 @@ trait EnumMirror[E]:
def mirroredName: String
def size: Int
def values: IArray[E]
def valueOf(name: String): E
def fromOrdinal(ordinal: Int): E
def declaresOrdinal(ordinal: Int): Boolean
def declaresName(name: String): Boolean
def valueOfUnsafe(name: String): E
def fromOrdinalUnsafe(ordinal: Int): E
def valueOf(name: String): Option[E] =
if declaresName(name) then Some(valueOfUnsafe(name)) else None
def fromOrdinal(ordinal: Int): Option[E] =
if declaresOrdinal(ordinal) then Some(fromOrdinalUnsafe(ordinal)) else None

extension (e: E)
def ordinal: Int
def name: String

end EnumMirror

object EnumMirror:

inline def apply[E](using mirror: EnumMirror[E]): mirror.type = mirror

transparent inline def derived[E]: EnumMirror[E] = ${ Macros.derivedEnumMirror[E] }
transparent inline def derived[E]: EnumMirror[E] = ${
Macros.derivedEnumMirror[E]
}

end EnumMirror
61 changes: 51 additions & 10 deletions src/main/scala/enumextensions/Macros.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
package enumextensions

import scala.quoted.*
import scala.collection.SeqView
import scala.deriving.Mirror

object Macros:

def string[T: Type](using Quotes): Expr[String] =
import quotes.reflect.*
val ConstantType(StringConstant(str)) = TypeRepr.of[T]: @unchecked
Expr(str)
end string

def names[T: Type](using Quotes): List[Expr[String]] = Type.of[T] match
case '[EmptyTuple] => Nil
case '[t *: ts] => string[t] :: names[ts]

def derivedEnumMirror[E: Type](using Quotes): Expr[EnumMirror[E]] =
import quotes.reflect.*

Expand All @@ -15,16 +27,29 @@ object Macros:
case _ =>
report.errorAndAbort(s"${tpe.show} is not an enum type")

val M = Expr.summon[Mirror.SumOf[E]] match
case Some(mirror) => mirror
case None =>
report.errorAndAbort(s"Could not summon a Mirror.SumOf[${tpe.show}]")

val reifiedNames: Expr[Set[String]] = M match
case '{ $m: Mirror.SumOf[E] { type MirroredElemLabels = elemLabels } } =>
'{ Set(${ Varargs(names[elemLabels]) }*) }

val E = sym.companionModule

val valuesRef =
Select.unique(Ref(E), "values").asExprOf[Array[E]]
Select.unique(Ref(E), "values").asExprOf[Array[E & reflect.Enum]]

def reifyName(name: Expr[String]) =
Select.overloaded(Ref(E), "valueOf", Nil, name.asTerm::Nil).asExprOf[E]
def reifyValueOf(name: Expr[String]) =
Select
.overloaded(Ref(E), "valueOf", Nil, name.asTerm :: Nil)
.asExprOf[E & reflect.Enum]

def reifyOrdinal(ordinal: Expr[Int]) =
Select.overloaded(Ref(E), "fromOrdinal", Nil, ordinal.asTerm::Nil).asExprOf[E]
def reifyFromOrdinal(ordinal: Expr[Int]) =
Select
.overloaded(Ref(E), "fromOrdinal", Nil, ordinal.asTerm :: Nil)
.asExprOf[E & reflect.Enum]

val sizeExpr = Expr(sym.children.length)

Expand All @@ -34,13 +59,29 @@ object Macros:

new EnumMirror[E]:

private val _values: IArray[E] = IArray.unsafeFromArray($valuesRef)
private val _values: IArray[E & reflect.Enum] =
IArray.unsafeFromArray($valuesRef)
private val _ordinals = _values.indices
private val _names = $reifiedNames

locally:
assert(_values.length == $sizeExpr)
assert(
(_values: IndexedSeq[E & reflect.Enum])
.map(_.ordinal)
.corresponds(_ordinals)(_ == _)
)

final def mirroredName: String = $mirroredNameExpr
final def size: Int = $sizeExpr
final def values: IArray[E] = _values
final def valueOf(name: String): E = ${ reifyName('name) }
final def fromOrdinal(ordinal: Int): E = ${ reifyOrdinal('ordinal) }
final def size: Int = $sizeExpr
final def values: IArray[E] = _values
final def declaresOrdinal(ordinal: Int): Boolean =
_ordinals.contains(ordinal)
final def declaresName(name: String): Boolean = _names.contains(name)
final def valueOfUnsafe(name: String): E = ${ reifyValueOf('name) }
final def fromOrdinalUnsafe(ordinal: Int): E = ${
reifyFromOrdinal('ordinal)
}

extension (e: E & scala.reflect.Enum)
final def ordinal: Int = e.ordinal
Expand Down
21 changes: 12 additions & 9 deletions src/main/scala/enumextensions/numeric/Macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,18 @@ object Macros:
case Some(sym) => sym
case _ => report.errorAndAbort(s"${tpe.show} is not a class type")

if sym.children.length > 1 then ('{
new NumericOps(using $mirror) with NumericOps.Modular[T]:
override final val zero = EnumMirror[T].fromOrdinal(0)
override final val one = EnumMirror[T].fromOrdinal(1)
})
else ('{
new NumericOps(using $mirror) with NumericOps.Singleton[T]:
override final val zero = EnumMirror[T].fromOrdinal(0)
})
if sym.children.length > 1 then
'{
new NumericOps(using $mirror) with NumericOps.Modular[T]:
override final val zero = EnumMirror[T].fromOrdinalUnsafe(0)
override final val one = EnumMirror[T].fromOrdinalUnsafe(1)
}
else
'{
new NumericOps(using $mirror) with NumericOps.Singleton[T]:
override final val zero = EnumMirror[T].fromOrdinalUnsafe(0)
}
end if
end derivedNumericOps

end Macros
72 changes: 38 additions & 34 deletions src/main/scala/enumextensions/numeric/NumericOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,19 @@ import scala.collection.immutable.NumericRange
import scala.util.Try
import scala.quoted.*

trait NumericOps[T](using final val mirror: EnumMirror[T]) extends Numeric[T] with Integral[T] { self =>
trait NumericOps[T](using final val mirror: EnumMirror[T]) extends Numeric[T] with Integral[T]:
self =>

final def parseString(str: String): Option[T] = Try(EnumMirror[T].valueOf(str)).toOption
final def parseString(str: String): Option[T] = EnumMirror[T].valueOf(str)

extension (t: T) {
def to (u: T): NumericRange[T] = NumericRange.inclusive(t, u, one)(self)
def until (u: T): NumericRange[T] = NumericRange(t, u, one)(self)
}
extension (t: T)
def to(u: T): NumericRange[T] = NumericRange.inclusive(t, u, one)(self)
def until(u: T): NumericRange[T] = NumericRange(t, u, one)(self)
end NumericOps

}
object NumericOps:

object NumericOps {

trait Singleton[T] extends NumericOps[T] {
trait Singleton[T] extends NumericOps[T]:

final def compare(l: T, r: T): Int = 0

Expand All @@ -28,40 +27,45 @@ object NumericOps {
final def fromInt(x: Int): T = zero

final def minus(x: T, y: T): T = x
final def plus(x: T, y: T): T = x
final def plus(x: T, y: T): T = x
final def times(x: T, y: T): T = x
final def quot(x: T, y: T): T = x
final def rem(x: T, y: T): T = x
final def negate(x: T): T = x
final def quot(x: T, y: T): T = x
final def rem(x: T, y: T): T = x
final def negate(x: T): T = x

final def toDouble(x: T): Double = 0
final def toFloat(x: T): Float = 0
final def toInt(x: T): Int = 0
final def toLong(x: T): Long = 0

}
final def toFloat(x: T): Float = 0
final def toInt(x: T): Int = 0
final def toLong(x: T): Long = 0
end Singleton

trait Modular[T] extends NumericOps[T] {
trait Modular[T] extends NumericOps[T]:
import mirror.size

final def compare(l: T, r: T): Int = l.ordinal compare r.ordinal

final def minus(x: T, y: T): T = EnumMirror[T].fromOrdinal((size + 1 + x.ordinal - y.ordinal) % size)
final def plus(x: T, y: T): T = EnumMirror[T].fromOrdinal((x.ordinal + y.ordinal) % size)
final def times(x: T, y: T): T = EnumMirror[T].fromOrdinal((x.ordinal * y.ordinal) % size)
final def quot(x: T, y: T): T = EnumMirror[T].fromOrdinal((x.ordinal / y.ordinal) % size)
final def rem(x: T, y: T): T = EnumMirror[T].fromOrdinal((x.ordinal % y.ordinal) % size)
final def negate(x: T): T = EnumMirror[T].fromOrdinal((size - x.ordinal) % size)

final def fromInt(x: Int): T = EnumMirror[T].fromOrdinal((x + size) % size)
final def minus(x: T, y: T): T =
EnumMirror[T].fromOrdinalUnsafe((size + 1 + x.ordinal - y.ordinal) % size)
final def plus(x: T, y: T): T =
EnumMirror[T].fromOrdinalUnsafe((x.ordinal + y.ordinal) % size)
final def times(x: T, y: T): T =
EnumMirror[T].fromOrdinalUnsafe((x.ordinal * y.ordinal) % size)
final def quot(x: T, y: T): T =
EnumMirror[T].fromOrdinalUnsafe((x.ordinal / y.ordinal) % size)
final def rem(x: T, y: T): T =
EnumMirror[T].fromOrdinalUnsafe((x.ordinal % y.ordinal) % size)
final def negate(x: T): T =
EnumMirror[T].fromOrdinalUnsafe((size - x.ordinal) % size)

final def fromInt(x: Int): T =
EnumMirror[T].fromOrdinalUnsafe((x + size) % size)

final def toDouble(x: T): Double = x.ordinal.toDouble
final def toFloat(x: T): Float = x.ordinal.toFloat
final def toInt(x: T): Int = x.ordinal
final def toLong(x: T): Long = x.ordinal.toLong

}
final def toFloat(x: T): Float = x.ordinal.toFloat
final def toInt(x: T): Int = x.ordinal
final def toLong(x: T): Long = x.ordinal.toLong
end Modular

transparent inline def derived[T](using inline mirror: EnumMirror[T]): NumericOps[T] =
${ Macros.derivedNumericOps[T]('mirror) }
}
end NumericOps
36 changes: 18 additions & 18 deletions src/test/scala/example/NumericSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,28 @@ import enumextensions.numeric.NumericOps

import scala.collection.immutable.NumericRange

object NumericSuite {
object NumericSuite:

enum Single derives EnumMirror, NumericOps {
enum Single derives EnumMirror, NumericOps:
case One
}

enum Rank derives EnumMirror, NumericOps {
case Two, Three, Four, Five, Six, Seven, Eight, Nine, Ten, Jack, Queen, King, Ace
}
enum Rank derives EnumMirror, NumericOps:
case Two, Three, Four, Five, Six, Seven, Eight, Nine, Ten, Jack, Queen,
King, Ace

enum Suit derives EnumMirror, NumericOps {
enum Suit derives EnumMirror, NumericOps:
case Clubs, Diamonds, Hearts, Spades
}

case class Card(suit: Suit, rank: Rank)
}
end NumericSuite

class NumericSuite extends munit.FunSuite {
test("make deck of cards") {
class NumericSuite extends munit.FunSuite:

test("make deck of cards"):
val deck =
for
suit <- Clubs to Spades
rank <- Two to Ace
rank <- Two to Ace
yield Card(suit, rank)

val deck2 =
Expand All @@ -45,13 +43,15 @@ class NumericSuite extends munit.FunSuite {
deck,
deck2.toIndexedSeq
)
}

test("test is numeric") {
.endLocally

test("test is numeric"):
def rangeTo[E: Integral](from: E, to: E): NumericRange[E] =
NumericRange.inclusive(from, to, summon[Numeric[E]].one)
assertEquals(
rangeTo(Clubs, Spades).toIndexedSeq,
EnumMirror[Suit].values.toIndexedSeq
)
.endLocally

assertEquals(rangeTo(Clubs, Spades).toIndexedSeq, EnumMirror[Suit].values.toIndexedSeq)
}
}
end NumericSuite
Loading

0 comments on commit e10a67a

Please sign in to comment.