Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

la4j new feature #133

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.class
*.iml
*.classpath
*.project
*.project
/target
13 changes: 13 additions & 0 deletions src/main/java/org/la4j/LinearAlgebra.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import org.la4j.linear.JacobiSolver;
import org.la4j.linear.LeastSquaresSolver;
import org.la4j.linear.LinearSystemSolver;
import org.la4j.optimization.LinearSystemOptimizer;
import org.la4j.optimization.NonlinearConjugateGradientOptimizer;
import org.la4j.linear.SeidelSolver;
import org.la4j.linear.SquareRootSolver;
import org.la4j.linear.SweepSolver;
Expand Down Expand Up @@ -153,6 +155,17 @@ public LinearSystemSolver create(Matrix matrix) {

public abstract LinearSystemSolver create(Matrix matrix);
}

public static enum OptimizerFactory{
NLCG {
@Override
public LinearSystemOptimizer create(Matrix matrix){
return new NonlinearConjugateGradientOptimizer(matrix);
}
};

public abstract LinearSystemOptimizer create(Matrix matrix);
}

/**
* References to the Gaussian solver factory.
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/la4j/linear/LinearSystemSolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public interface LinearSystemSolver extends Serializable {
Matrix self();

/**
* Returns the number os unknowns in this solver.
* Returns the number of unknowns in this solver.
*
* @return
*/
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/la4j/matrix/AbstractMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.la4j.matrix.functor.MatrixFunction;
import org.la4j.matrix.functor.MatrixPredicate;
import org.la4j.matrix.functor.MatrixProcedure;
import org.la4j.optimization.LinearSystemOptimizer;
import org.la4j.vector.Vector;

public abstract class AbstractMatrix implements Matrix {
Expand Down Expand Up @@ -1055,6 +1056,11 @@ public Vector toColumnVector(Factory factory) {
@Override
public LinearSystemSolver withSolver(LinearAlgebra.SolverFactory factory) {
return factory.create(this);
}

@Override
public LinearSystemOptimizer withOptimizer(LinearAlgebra.OptimizerFactory factory){
return factory.create(this);
}

@Override
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/la4j/matrix/AbstractSafeMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.la4j.matrix.functor.MatrixFunction;
import org.la4j.matrix.functor.MatrixPredicate;
import org.la4j.matrix.functor.MatrixProcedure;
import org.la4j.optimization.LinearSystemOptimizer;
import org.la4j.vector.Vector;

public abstract class AbstractSafeMatrix implements Matrix {
Expand Down Expand Up @@ -581,6 +582,11 @@ public Vector toColumnVector(Factory factory) {
public LinearSystemSolver withSolver(LinearAlgebra.SolverFactory factory) {
return self.withSolver(factory);
}

@Override
public LinearSystemOptimizer withOptimizer(LinearAlgebra.OptimizerFactory factory){
return self.withOptimizer(factory);
}

@Override
public MatrixInverter withInverter(LinearAlgebra.InverterFactory factory) {
Expand Down
10 changes: 10 additions & 0 deletions src/main/java/org/la4j/matrix/Matrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.la4j.factory.Factory;
import org.la4j.inversion.MatrixInverter;
import org.la4j.linear.LinearSystemSolver;
import org.la4j.optimization.LinearSystemOptimizer;
import org.la4j.matrix.functor.AdvancedMatrixPredicate;
import org.la4j.matrix.functor.MatrixAccumulator;
import org.la4j.matrix.functor.MatrixFunction;
Expand Down Expand Up @@ -876,6 +877,15 @@ public interface Matrix extends Externalizable {
*/
LinearSystemSolver withSolver(LinearAlgebra.SolverFactory factory);

/**
* Creates a new optimizer with given {@code accuracy} by {@code factory} of this matrix
* @param factory
* @param accourancy
* @return
*/
LinearSystemOptimizer withOptimizer(LinearAlgebra.OptimizerFactory factory);


/**
* Creates a new inverter by given {@code factory} of this matrix.
*
Expand Down
53 changes: 53 additions & 0 deletions src/main/java/org/la4j/optimization/AbstractOptimizer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package org.la4j.optimization;

import org.la4j.matrix.Matrix;
import org.la4j.vector.Vector;

public abstract class AbstractOptimizer implements LinearSystemOptimizer {

protected Matrix matrix;

protected int unknowns;
protected int equations;

protected AbstractOptimizer(Matrix a) {
if (!applicableTo(a)) {
fail("Given coefficient matrix can not be used with this solver.");
}

this.matrix = a;
this.unknowns = a.columns();
this.equations = a.rows();
}


@Override
public Vector solve(Vector b, double accuracy) {
return solve(b, b.factory(), accuracy);
}

@Override
public Vector solve(Vector b) {
return solve(b, b.factory(), 1e-7);
}

@Override
public Matrix self() {
return matrix;
}

@Override
public int unknowns() {
return unknowns;
}

@Override
public int equations() {
return equations;
}

protected void fail(String message) {
throw new IllegalArgumentException(message);
}

}
67 changes: 67 additions & 0 deletions src/main/java/org/la4j/optimization/LinearSystemOptimizer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package org.la4j.optimization;

import java.io.Serializable;

import org.la4j.factory.Factory;
import org.la4j.matrix.Matrix;
import org.la4j.vector.Vector;

/**
* Linear System Optimizator interface;
* This class implements Strategy design pattern;
*/
public interface LinearSystemOptimizer extends Serializable{
/**
* Optimize the system A*x = b with default accuracy 1e-7;
*
* @param b
* @return
*/
Vector solve(Vector b, double accuracy);

/**
* Optimize the system A*x = b with given {@code accuracy} accuracy
*
* @param b
* @return
*/
Vector solve(Vector b);

/**
* Optimize the system A*x = b.
*
* @param b
* @param factory
* @return
*/
Vector solve(Vector b, Factory factory, double accuracy);

/**
* Returns the self matrix of the optimizator.
*
* @return
*/
Matrix self();

/**
* Returns the number of unknowns in this optimizator.
*
* @return
*/
int unknowns();

/**
* Returns the number of equations in this optimizator.
*
* @return
*/
int equations();

/**
* Checks whether this optimizator applicable to given {@code matrix} or not.
*
* @param matrix
*/
boolean applicableTo(Matrix matrix);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.la4j.optimization;

import org.la4j.LinearAlgebra;
import org.la4j.factory.Factory;
import org.la4j.matrix.Matrix;
import org.la4j.vector.Vector;

public class NonlinearConjugateGradientOptimizer extends AbstractOptimizer implements LinearSystemOptimizer{

private static final long serialVersionUID = -6730752031320270935L;

private Vector start;
private int maxIterations = 10000;

private double step;

public NonlinearConjugateGradientOptimizer(Matrix a) {
super(a);
start = LinearAlgebra.BASIC1D_FACTORY.createVector(this.unknowns);
}

@Override
public Vector solve(Vector b, Factory factory, double accuracy) {
Vector oldX = start.copy();
Vector oldGradient = formGradient(matrix, start, b).multiply(-1.0);
Vector newX = alphaLineSearch(matrix, b, oldX, oldGradient);
Vector newGradient = formGradient(matrix, start, b).multiply(-1.0);
Vector direction = newGradient.copy();

int k = 0;
double beta = 1;
step = accuracy;

while ((k++ < maxIterations) && (formGradient(matrix, oldX, b).norm() > accuracy)){
newGradient = formGradient(matrix, newX, b).multiply(-1.0);
beta = Math.pow(formGradient(matrix, newX, b).norm(), 2) / Math.pow(formGradient(matrix, oldX, b).norm(), 2);
direction = newGradient.add(direction.multiply(beta));
oldX = newX.copy();
newX = alphaLineSearch(matrix, b, oldX, direction);
}

if (k >= maxIterations){
System.out.println("Reached " + maxIterations + " iterations." );
}

return oldX;
}

@Override
public boolean applicableTo(Matrix matrix) {
return true;
}

private double form(Matrix A, Vector x, Vector b){
return Math.pow(A.multiply(x).add(b.multiply(-1.0)).norm(), 2);
}

private Vector formGradient(Matrix A, Vector x, Vector b){
return A.transpose().multiply(A.multiply(x).add(b.multiply(-1.0))).multiply(2);
}

private Vector alphaLineSearch(Matrix A, Vector b, Vector x, Vector s){
int k = 0;

Vector newX = x.copy();
Vector oldX = x.copy();
double newF = form(A, x, b);
double oldF = form(A, x, b);
double alpha = step;

do{
k++;
oldX = newX;
newX = oldX.add(s.multiply(alpha));
oldF = newF;
newF = form(A, newX, b);

}while((newF < oldF) && (k < maxIterations));

return oldX;

}
}
29 changes: 29 additions & 0 deletions src/test/java/org/la4j/optimization/AbstractOptimizerTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package org.la4j.optimization;

import org.la4j.LinearAlgebra;
import org.la4j.factory.Factory;
import org.la4j.matrix.Matrix;
import org.la4j.vector.MockVector;
import org.la4j.vector.Vector;

import junit.framework.TestCase;

public abstract class AbstractOptimizerTest extends TestCase {

public void performTest(LinearAlgebra.OptimizerFactory optimizerFactory,
double coefficientMatrix[][], double rightHandVector[], double accuracy) {

for (Factory factory : LinearAlgebra.FACTORIES) {

Matrix a = factory.createMatrix(coefficientMatrix);
Vector b = factory.createVector(rightHandVector);

LinearSystemOptimizer solver = a.withOptimizer(optimizerFactory);
Vector x = solver.solve(b, factory, accuracy);

double eps = (new MockVector(b)).add((new MockVector(a.multiply(x)).multiply(-1.0))).max();

assertTrue(Math.abs(eps) <= accuracy);
}
}
}
Loading