diff --git a/src/demo_stochastic_matrix.cpp b/src/demo_stochastic_matrix.cpp index a949c1b..93608e1 100644 --- a/src/demo_stochastic_matrix.cpp +++ b/src/demo_stochastic_matrix.cpp @@ -42,7 +42,7 @@ int main(int argc, char **argv) // Shut GetOpt error messages down (return '?'): opterr = 0; - while ( (opt = getopt(argc, argv, "l:d:a:m:e:h:p:")) != -1 ) { + while ( (opt = getopt(argc, argv, "l:d:a:m:e:h:p:r")) != -1 ) { switch ( opt ) { case 'l': sscanf(optarg, "%lf", ¶ms.lambda); @@ -50,6 +50,9 @@ int main(int argc, char **argv) case 'd': params.d = atoi(optarg); break; + case 'r': + sscanf(optarg, "%lf", ¶ms.eta); + break; case 'm': params.maxIter = atoi(optarg); break; diff --git a/src/gradient_descend.cpp b/src/gradient_descend.cpp index aafe877..f1d4be9 100644 --- a/src/gradient_descend.cpp +++ b/src/gradient_descend.cpp @@ -130,7 +130,7 @@ void kl_minimization(coord* y, // ----- t-SNE hard coded parameters - Same as in vdM's code int stop_lying_iter = params.earlyIter, mom_switch_iter = 250; double momentum = .5, final_momentum = .8; - double eta = 200.0; + double eta = params.eta; int iterPrint = 50; double timeFattr = 0.0; diff --git a/src/sgtsne.hpp b/src/sgtsne.hpp index f5230b8..dacbc9f 100644 --- a/src/sgtsne.hpp +++ b/src/sgtsne.hpp @@ -33,8 +33,8 @@ typedef struct { double h = -1; //!< Grid side length (accuracy control) bool dropLeaf = false; //!< Drop edges originating from leaf nodes? int np = 0; //!< Number of CILK workers (processes) - -} tsneparams; + double eta = 200.0; //!< learning rate +} tsneparams; //! Sparse matrix structure in CSC format diff --git a/src/utils.cpp b/src/utils.cpp index 904d303..86387e7 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -22,6 +22,7 @@ void printParams(tsneparams P){ << "Early exag. multiplier α: " << P.alpha << std::endl << "Maximum iterations: " << P.maxIter << std::endl << "Early exag. iterations: " << P.earlyIter << std::endl + << "Learning rate: " << P.eta << std::endl << "Box side length h: " << P.h << std::endl << "Drop edges originating from leaf nodes? " << P.dropLeaf << std::endl << "Number of processes: " << P.np << std::endl;