From 033905abe7f7004dd1af8ea5d281f8b5cf9b7f20 Mon Sep 17 00:00:00 2001 From: Ru-Xiang Date: Tue, 28 Mar 2017 16:09:42 +0800 Subject: [PATCH] elastic net and bug fix --- .../scala/generalizedLinear/Updater.scala | 76 ++++++++++++++++++- 1 file changed, 75 insertions(+), 1 deletion(-) diff --git a/src/main/scala/generalizedLinear/Updater.scala b/src/main/scala/generalizedLinear/Updater.scala index b830c15..75ab1f0 100644 --- a/src/main/scala/generalizedLinear/Updater.scala +++ b/src/main/scala/generalizedLinear/Updater.scala @@ -130,7 +130,7 @@ class L1Updater extends Updater { override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = { - val preFact = 1.0 - stepSize * (regParam + factor) + val preFact = 1.0 - stepSize * factor val upFact = -stepSize / preFact @@ -246,3 +246,77 @@ class L2Updater extends Updater { } } +class elasticNetUpdater(val alpha:Double) extends Updater{ + val beta= 1-alpha + + + override def update(data: RDD[(Double, Vector)], weights: Vector, mu: Vector, lossfunc: LossFunc, stepSize: Double, factor: Double, regParam: Double): Vector = { + + val preFact = 1.0 - stepSize * (regParam*alpha + factor) + val upFact = -stepSize / preFact + mu.plusax(-factor, weights) + val w_0 = data.sparkContext.broadcast(weights) + val fix = data.sparkContext.broadcast(mu) + val partsNum = data.partitions.length + val chkSize = findChkSize(preFact) + + val l1fact=stepSize * regParam*beta + + data.mapPartitions(iter => { + val omiga = new WeightsVector(w_0.value.copy, fix.value) + val indexedSeq = iter.toIndexedSeq + val pNum = indexedSeq.size + + val rand = new Random(partsNum * pNum) + + val flags = new Array[Int](omiga.size) + util.Arrays.fill(flags, 0) + + for (j <- 1 to pNum) { + val e = indexedSeq(rand.nextInt(pNum)) + val f1 = lossfunc.deltaF(e._2, e._1, omiga) + f1 -= lossfunc.deltaF(e._2, e._1, w_0.value) + // val delta = f1 x e._2 + // delta += mu + if (j % chkSize == 0) + omiga.merge() + + val oValues = omiga.partA.toArray + e._2.foreachActive { (i, v) => + val wi = omiga.apply(i) + oValues(i) = (math.signum(wi) * max(0.0, abs(wi) - (j - 1 - flags(i)) * l1fact) - omiga.fac_b * omiga.partB(i)) / omiga.fac_a + flags(i) = j - 1 + } + + omiga.partA.plusax(upFact / omiga.fac_a, f1 x e._2) + omiga.fac_a *= preFact + omiga.fac_b *= preFact + omiga.fac_b -= stepSize + + } + Iterator(omiga.toDense()) + + }, true).treeAggregate(new DenseVector(weights.size))(seqOp = (c, w) => { + c += w + }, combOp = (c1, c2) => { + c1 += c2 + }) /= (partsNum) + + + + } + /** + * In this method, we give the cost of the regularizer. + * + * @param weight + * @param regParam + * @return regCost + */ + override def getRegVal(weight: Vector, regParam: Double): Double = { + val norm1= weight.norm1() + val norm2=weight.norm2() + regParam*(0.5*alpha*norm2*norm2+beta*norm1) + + } +} +