From 9e2db661ef6b200f6e1e6173ceb0421df08d6206 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Wed, 24 Jul 2024 13:05:49 -0500 Subject: [PATCH] Improve inferability (#85) --- src/alspgrad.jl | 8 ++++---- src/interf.jl | 26 ++++++++++++++++---------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/alspgrad.jl b/src/alspgrad.jl index 5e6df63..fabc73c 100644 --- a/src/alspgrad.jl +++ b/src/alspgrad.jl @@ -385,13 +385,13 @@ solve!(alg::ALSPGrad, X, W, H) = struct ALSPGradUpd_State{T} WH::Matrix{T} - uhstate::ALSGradUpdH_State - uwstate::ALSGradUpdW_State + uhstate::ALSGradUpdH_State{T} + uwstate::ALSGradUpdW_State{T} ALSPGradUpd_State{T}(X, W, H) where {T} = new{T}(W * H, - ALSGradUpdH_State(X, W, H), - ALSGradUpdW_State(X, W, H)) + ALSGradUpdH_State{T}(X, W, H), + ALSGradUpdW_State{T}(X, W, H)) end prepare_state(::ALSPGradUpd{T}, X, W, H) where {T} = ALSPGradUpd_State{T}(X, W, H) diff --git a/src/interf.jl b/src/interf.jl index 110edf1..6a33cd6 100644 --- a/src/interf.jl +++ b/src/interf.jl @@ -54,37 +54,43 @@ function nnmf(X::AbstractMatrix{T}, k::Integer; else throw(ArgumentError("Invalid value for init.")) end + W = W::Matrix{T} + H = H::Matrix{T} # choose algorithm if alg == :projals - alginst = ProjectedALS{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H) + ret = solve_replicates!(ProjectedALS{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH) elseif alg == :alspgrad - alginst = ALSPGrad{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H) + ret = solve_replicates!(ALSPGrad{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH) elseif alg == :multmse - alginst = MultUpdate{T}(obj=:mse, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H) + ret = solve_replicates!(MultUpdate{T}(obj=:mse, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH) elseif alg == :multdiv - alginst = MultUpdate{T}(obj=:div, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H) + ret = solve_replicates!(MultUpdate{T}(obj=:div, maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH) elseif alg == :cd - alginst = CoordinateDescent{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H) + ret = solve_replicates!(CoordinateDescent{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH) elseif alg == :greedycd - alginst = GreedyCD{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H) + ret = solve_replicates!(GreedyCD{T}(maxiter=maxiter, tol=tol, verbose=verbose, update_H=update_H), X, W, H; replicates, initH) elseif alg == :spa if init != :spa throw(ArgumentError("Invalid value for init, use :spa instead.")) end - alginst = SPA{T}(obj=:mse) + ret = solve_replicates!(SPA{T}(obj=:mse), X, W, H; replicates, initH) else throw(ArgumentError("Invalid algorithm.")) end - # run optimization + return ret +end + +function solve_replicates!(alginst, X, W, H; replicates, initH) ret = solve!(alginst, X, W, H) + k = size(W, 2) # replicates minobjv = ret.objvalue for _ in 2:replicates - W, H = randinit(X, k; zeroh=!initH, normalize=true) - tmp = solve!(alginst, X, W, H) + Wrand, Hrand = randinit(X, k; zeroh=!initH, normalize=true) + tmp = solve!(alginst, X, Wrand, Hrand) if minobjv > tmp.objvalue ret = tmp minobjv = tmp.objvalue