Skip to content

Commit

Permalink
Merge pull request #11 from Ru-Xiang/master
Browse files Browse the repository at this point in the history
add elastic net  and  L1 bug fix
  • Loading branch information
LIBBLE authored Mar 30, 2017
2 parents bb526af + 033905a commit 73cdc97
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion src/main/scala/generalizedLinear/Updater.scala
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class L1Updater extends Updater {

override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = {

val preFact = 1.0 - stepSize * (regParam + factor)
val preFact = 1.0 - stepSize * factor

val upFact = -stepSize / preFact

Expand Down Expand Up @@ -246,3 +246,77 @@ class L2Updater extends Updater {
}
}

class elasticNetUpdater(val alpha:Double) extends Updater{
val beta= 1-alpha


override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = {

val preFact = 1.0 - stepSize * (regParam*alpha + factor)
val upFact = -stepSize / preFact
mu.plusax(-factor, weights)
val w_0 = data.sparkContext.broadcast(weights)
val fix = data.sparkContext.broadcast(mu)
val partsNum = data.partitions.length
val chkSize = findChkSize(preFact)

val l1fact=stepSize * regParam*beta

data.mapPartitions(iter => {
val omiga = new WeightsVector(w_0.value.copy, fix.value)
val indexedSeq = iter.toIndexedSeq
val pNum = indexedSeq.size

val rand = new Random(partsNum * pNum)

val flags = new Array[Int](omiga.size)
util.Arrays.fill(flags, 0)

for (j <- 1 to pNum) {
val e = indexedSeq(rand.nextInt(pNum))
val f1 = lossfunc.deltaF(e._2, e._1, omiga)
f1 -= lossfunc.deltaF(e._2, e._1, w_0.value)
// val delta = f1 x e._2
// delta += mu
if (j % chkSize == 0)
omiga.merge()

val oValues = omiga.partA.toArray
e._2.foreachActive { (i, v) =>
val wi = omiga.apply(i)
oValues(i) = (math.signum(wi) * max(0.0, abs(wi) - (j - 1 - flags(i)) * l1fact) - omiga.fac_b * omiga.partB(i)) / omiga.fac_a
flags(i) = j - 1
}

omiga.partA.plusax(upFact / omiga.fac_a, f1 x e._2)
omiga.fac_a *= preFact
omiga.fac_b *= preFact
omiga.fac_b -= stepSize

}
Iterator(omiga.toDense())

}, true).treeAggregate(new DenseVector(weights.size))(seqOp = (c, w) => {
c += w
}, combOp = (c1, c2) => {
c1 += c2
}) /= (partsNum)



}
/**
* In this method, we give the cost of the regularizer.
*
* @param weight
* @param regParam
* @return regCost
*/
override def getRegVal(weight: Vector, regParam: Double): Double = {
val norm1= weight.norm1()
val norm2=weight.norm2()
regParam*(0.5*alpha*norm2*norm2+beta*norm1)

}
}

0 comments on commit 73cdc97

Please sign in to comment.