function varGausschain(cur, n)
%VARGAUSSCHAIN Variational inference for the Gaussian chain of an SLDS node.
%   VARGAUSSCHAIN(CUR, N) recomputes, for sequence N of node DST(CUR), the
%   variational posterior q(x_1..x_T) over the continuous hidden states,
%   holding fixed the switch weights stored in the parent node,
%   DST(DST(CUR).parent).s{N} (indexed s(i,t): regime i, time step t).
%
%   The posterior is built as a Markov chain by a backward pass,
%       q(x_1)         = N(vB(:,1), vQ(:,:,1))
%       q(x_t|x_{t-1}) = N(vA(:,:,t)*x_{t-1} + vB(:,t), vQ(:,:,t)),  t >= 2,
%   and its marginals are then propagated by a forward pass.
%
%   Reads (from the global struct array DST):
%       DST(cur):               T(n), xdim, parent, A, Qinv, Q1inv, mu,
%                               C, R, missing{n}, y{n}
%       DST(DST(cur).parent):   S, s{n}
%   Writes:
%       DST(cur).x{n}       xdim-by-T posterior means
%       DST(cur).sigma{n}   xdim-by-xdim-by-T posterior covariances
%       DST(cur).sigma2{n}  xdim-by-xdim-by-T cross-covariances
%                           vA(:,:,t)*sigma(:,:,t-1), i.e. Cov(x_t, x_{t-1})
%                           under q; slice 1 is left as zeros.
%
%   At time steps where DST(cur).missing{n}(t) is true, the observation
%   terms C'*inv(R)*C and C'*inv(R)*y_t are omitted. A length-1 sequence
%   (T == 1) is supported.
%
%   NOTE(review): s{n}(i,t) is presumably the posterior probability of
%   regime i at time t, and safeinv presumably a numerically guarded matrix
%   inverse -- confirm against their definitions.

global DST;

node = DST(cur);
S    = DST(node.parent).S;
s    = DST(node.parent).s{n};
s    = s(1:S, :);                 % only the first S rows serve as weights
T    = node.T(n);
xdim = node.xdim;
miss = node.missing{n};

Rinv    = safeinv(node.R);
CtRinv  = node.C' * Rinv;
CtRinvC = CtRinv * node.C;        % observation contribution to the precision

% Per-regime terms that do not depend on t; computed once instead of at
% every time step.
Qinv    = node.Qinv(:, :, 1:S);
Q1inv   = node.Q1inv(:, :, 1:S);
QinvA   = zeros(xdim, xdim, S);
AtQinvA = zeros(xdim, xdim, S);
Q1invMu = zeros(xdim, S);
for i = 1:S
    QinvA(:, :, i)   = Qinv(:, :, i) * node.A(:, :, i);
    AtQinvA(:, :, i) = node.A(:, :, i)' * QinvA(:, :, i);
    Q1invMu(:, i)    = Q1inv(:, :, i) * node.mu(:, i);
end

% wsum(M, w) = sum_i w(i) * M(:,:,i) for an xdim-by-xdim-by-S stack M
% and an S-by-1 weight vector w.
wsum = @(M, w) reshape(reshape(M, xdim * xdim, S) * w, xdim, xdim);

%% Backward pass: variational precisions vQinv, covariances vQ, and the
%% linear-Gaussian chain parameters vA, vB.
vQinv = zeros(xdim, xdim, T);
vQ    = zeros(xdim, xdim, T);
vA    = zeros(xdim, xdim, T);     % vA(:,:,1) is never used
vB    = zeros(xdim, T);           % vB(:,1) holds the mean of q(x_1)

for t = T:-1:1
    % Prior term for x_t: the initial-state prior at t = 1, otherwise the
    % switch-weighted transition precision from x_{t-1}.
    if t == 1
        P = wsum(Q1inv, s(:, 1));
        h = Q1invMu * s(:, 1);
    else
        P = wsum(Qinv, s(:, t));
        h = zeros(xdim, 1);
    end

    % Observation term, skipped where y_t is flagged as missing.
    if ~miss(t)
        P = P + CtRinvC;
        h = h + CtRinv * node.y{n}(t, :)';
    end

    % Contribution left over after integrating out x_{t+1}, whose
    % conditional q(x_{t+1}|x_t) was computed in the previous iteration.
    % It does not exist at the last time step; guarding it is what makes
    % T == 1 work.
    if t < T
        AtP = vA(:, :, t+1)' * vQinv(:, :, t+1);
        P = P + wsum(AtQinvA, s(:, t+1)) - AtP * vA(:, :, t+1);
        h = h + AtP * vB(:, t+1);
    end

    vQinv(:, :, t) = P;
    vQ(:, :, t)    = safeinv(P);
    vB(:, t)       = vQ(:, :, t) * h;
    if t > 1
        vA(:, :, t) = vQ(:, :, t) * wsum(QinvA, s(:, t));
    end
end

%% Forward pass: propagate marginal means and covariances along the chain.
x      = zeros(xdim, T);
sigma  = zeros(xdim, xdim, T);
sigma2 = zeros(xdim, xdim, T);
x(:, 1)        = vB(:, 1);
sigma(:, :, 1) = vQ(:, :, 1);
for t = 2:T
    x(:, t)         = vB(:, t) + vA(:, :, t) * x(:, t-1);
    sigma2(:, :, t) = vA(:, :, t) * sigma(:, :, t-1);
    sigma(:, :, t)  = sigma2(:, :, t) * vA(:, :, t)' + vQ(:, :, t);
end

DST(cur).x{n}      = x;
DST(cur).sigma{n}  = sigma;
DST(cur).sigma2{n} = sigma2;