function [DST, bound] = DST_em2(D, maxemiter, maxklit, inf_only, fname)
%DST_EM2 Run EM on a Dynamical Systems Tree (DST).
%
%   [DST, bound] = DST_em2(D, maxemiter, maxklit, inf_only, fname)
%
%   Inputs:
%   D         - struct that represents the DST. It is copied into the global
%               variable DST. DST_inference4 and DST_mstep2 are called
%               without it, so they presumably read/modify that global
%               (confirm).
%   maxemiter - maximum number of EM iterations (may stop early when
%               em_converged reports convergence of the bound).
%   maxklit   - maximum number of inner E-step (KL divergence) iterations per
%               EM iteration. Can often be set to 1 for quicker inference with
%               no log-term convergence problems.
%   inf_only  - (optional, default 0) if nonzero, skip the M-step and only
%               perform inference. In that case you would want maxemiter = 1
%               and maxklit > 20 to be sure of convergence.
%   fname     - (optional) file name. When given and non-empty, the whole
%               function workspace is saved to it every 50 EM iterations and
%               once more on return.
%
%   Outputs:
%   DST       - the global DST after the last iteration.
%   bound     - row vector holding, for each EM iteration, the bound computed
%               by the E-step (recorded before that iteration's M-step).
%
%   Prediction (number of time steps to predict) is not completed: the local
%   pred is fixed to 0 and passed through to DST_inference4.
%
%   Author: Andrew Howard

global TINY
TINY = 1e-3;
global DST
% NOTE(review): DSTold is not referenced in this function; presumably it is
% shared with other DST_* routines through the global workspace.
global DSTold

% Treat an omitted or empty inf_only as "run full EM".
if nargin < 4 || isempty(inf_only)
    inf_only = 0;
end

DST = D;
pred = 0;            % prediction horizon; prediction is not completed
threshold = 1e-4;    % convergence threshold passed to em_converged
bound = [];
b = -inf;
llconverged = 0;
iter = 0;
% NOTE(review): etae, etam and alpha are not used in this function. They are
% kept only because save(fname) snapshots the whole workspace.
etae = 1;
etam = 1;
alpha = .1;

while iter < maxemiter && ~llconverged
    iter = iter + 1;

    % ---- E step: iterate inference until the bound stops changing ----
    bold = b;
    b1 = -inf;
    bound1 = [];     % bound after each inner iteration (debugging only)
    klconverged = 0;
    klit = 0;
    while klit < maxklit && ~klconverged
        b1old = b1;
        b1 = 0;
        % Sum the per-n bound contributions for n = 1..DST(1).N.
        for n = 1:DST(1).N
            b1 = b1 + DST_inference4(1, n, pred);
        end
        bound1 = [bound1 b1]; %#ok<AGROW>
        % em_converged replaces the former hand-written check. The trailing 1
        % presumably enables its decrease check (second output) — confirm.
        [klconverged, kldecrease] = em_converged(b1, b1old, threshold, 1); %#ok<ASGLU>
        klit = klit + 1;
    end
    % Debug: figure(3); plot(1:numel(bound1), bound1)
    b = b1;

    % Progress report and periodic checkpoint.
    if ~mod(iter, 50)
        fprintf(1, 'iteration %d, loglik = %f\n', iter, b);
        if nargin == 5 && ~isempty(fname)
            save(fname);
        end
    end

    % Outer convergence test on the E-step bound of successive EM iterations.
    [llconverged, lldecrease] = em_converged(b, bold, threshold, 1); %#ok<ASGLU>
    bound = [bound b]; %#ok<AGROW>

    % ---- M step (skipped in inference-only mode) ----
    if ~inf_only
        DST_mstep2(1);
    end
end

% Final checkpoint.
if nargin == 5 && ~isempty(fname)
    save(fname);
end