function [w,lambda,alpha,delta,deltas] = solve_qp(kzz,y,C,B,alpha,delta,deltas,lambda)
% SOLVE_QP  Solve the dual QP of the Relative Margin Machine with a
% working-set (decomposition) method; each sub-problem is solved by MOSEK.
%
% Input:
%   kzz    - n x n Gram matrix. Its first l = length(y) rows/columns are the
%            labelled training examples; alpha only lives on those, while
%            delta/deltas live on all n points.
%   y      - l x 1 column vector of training labels (+1/-1).
%   C      - trade-off between the slack and the margin; alpha is kept in [0,C].
%   B      - bound on the training predictions: -B <= kzz*w + lambda <= B.
%   Optional warm-start arguments (any trailing subset may be omitted, or
%   passed as [] to get the default):
%   alpha  - l x 1 multipliers of the margin constraints          (default 0)
%   delta  - n x 1 multipliers of the bounds  kzz*w+lambda <= B   (default 0)
%   deltas - n x 1 multipliers of the bounds -(kzz*w+lambda) <= B (default 0)
%   lambda - bias                                                 (default 0)
%
% Output:
%   w, lambda : parameters learned from the training data,
%       w = [alpha.*y; zeros(n-l,1)] - delta + deltas.
%       If Ktz is the kernel between the test and the training examples,
%       the predictions are given by Ktz*w + lambda.
%   alpha, delta, deltas : the Lagrange multipliers on the constraints as in
%       the NIPS paper "Relative Margin Machines"; pass them back in to
%       warm-start a later call.
%
% Dependencies: mosekopt (MOSEK toolbox) and shuffle(ind,seed), which is not
% defined here (NOTE(review): assumed to return a permutation of ind).
%
% NOTE(review): this copy of the source is damaged. Every span from a '<'
% followed by a letter up to the next '>' was stripped (HTML-tag removal).
% The affected statements are marked "DAMAGED SOURCE" below and kept
% verbatim; restore them from an intact copy before running.

% Validate before any argument is read. Returning early instead of raising
% would leave the outputs unassigned and fail later in the caller.
if nargin < 4
    error('solve_qp:notEnoughInputs', 'Not enough input arguments');
end

l = length(y);
%C = C/l;
n = size(kzz,1);
%dy = diag(y);
%yyt = y*y';

% Tolerances. NOTE: 'eps' shadows MATLAB's built-in eps in this function.
eps = 1e-3;     % KKT tolerance on the margin conditions y.*f vs. 1 (and y-f)
beps = 0.01;    % relative tolerance on the prediction bound B
ceps = C*1e-3;  % alpha within (1e-5+ceps) of 0 or C counts as at that bound
deps = 1e-3;    % delta is "active" when > deps*max(delta) + 1e-5 (same for deltas)

% Fixed seed so that the random working-set choices are reproducible.
% NOTE(review): this is the legacy rand syntax; it reseeds MATLAB's global
% generator and therefore also affects the caller's random stream.
rand('seed',1);

good = 0;   % convergence score accumulated from the objective progress
obj = 0;

% Default any warm-start argument that was not supplied (or passed as []).
% With all defaults this is the cold start: alphay = 0, all_pred = 0.
if nargin < 5 || isempty(alpha)
    alpha = zeros(l,1);
end
if nargin < 6 || isempty(delta)
    delta = zeros(n,1);
end
if nargin < 7 || isempty(deltas)
    deltas = zeros(n,1);
end
if nargin < 8 || isempty(lambda)
    lambda = 0;
end
alphay = alpha.*y;
w = [alphay; zeros(n-l,1) ] - delta + deltas;
all_pred = kzz*w;
all_pred = all_pred + lambda;

% Working-set size (number of variables optimized per sub-problem).
if n>800
    q=800;
else
    q = 8*round(n/10);
end

% MOSEK interior-point tolerances and thread count.
param.MSK_DPAR_INTPNT_CO_TOL_PFEAS = 1.0e-12;
param.MSK_DPAR_INTPNT_CO_TOL_DFEAS = 1.0e-12;
param.MSK_DPAR_INTPNT_CO_TOL_REL_GAP = 1.0e-12;
param.MSK_IPAR_INTPNT_NUM_THREADS = 2;
%param.MSK_IPAR_LOG = 0;

% One row per QP variable of [alpha; delta; deltas]:
% [global index, index within its own group].
% NOTE(review): the selection code reads 'list' with columns
% (:,2) = global index and (:,3) = group index; it is built in the damaged
% region below, presumably from list1 -- confirm against an intact copy.
list1 = [ 1:(l+2*n); 1:l 1:n 1:n ]';
count = 0;
indl = [];
indv1 =[];
indv2 =[];

% Cached kernel products, updated incrementally after every sub-problem:
% kzz*[alpha.*y; 0], kzz*delta and kzz*deltas.
kzzalphay = kzz*[alphay; zeros(n-l,1) ];
kzzdelta = kzz*delta;
kzzdeltas = kzz*deltas;
maxdelta = max(delta);
maxdeltas = max(deltas);
randfrac = 0.2;   % probability of a purely random working set

while 1
    count = count + 1;
    % Working-set selection: a random subset of q variables with probability
    % randfrac and on every 10th iteration; otherwise a selection driven by
    % the ordering of 'list' (top entries first, then bottom entries).
    if (rand < randfrac) || ( mod(count,10)==0)
        ind = [1:l+2*n]';
        ind = shuffle(ind,count);
        ind = sort(ind(1:q));
        indl = ind(find(ind<=l));
        indv1 = ind(intersect(find(ind>l), find(ind<=l+n)))-l;
        % ---- DAMAGED SOURCE ------------------------------------------------
        % NOTE(review): the rest of this statement, the end of the random
        % branch, the 'else' branch that builds 'list', the initialisation of
        % 'picked' and the head of the first picking loop were lost. Only the
        % fragment below survives; the following lines are the body of that
        % first picking loop (candidates taken from the top of 'list').
        indv2 = ind(find(l+n l+2*n; break end
        % --------------------------------------------------------------------
            tl1 = list(i,2);   % global index into [alpha; delta; deltas]
            tl2 = list(i,3);   % index within that group
            if tl1 <= l
                temp = alpha(tl2);
                % DAMAGED SOURCE: the comparison of temp with the alpha
                % bounds was stripped; kept verbatim.
                if (1e-5+ceps)= temp
                    if y(tl2)==-1
                        % if not(ismember(tl2,indl))
                        indl = [ indl; tl2 ];picked = picked + 1;
                        % end
                    end
                else
                    if y(tl2)==+1
                        % if not(ismember(tl2,indl))
                        indl = [ indl; tl2 ];picked = picked + 1;
                        % end
                    end
                end
            elseif l < tl1 && tl1 <=l+n
                % if not(ismember(tl2,indv1))
                indv1 = [ indv1; tl2 ]; picked = picked + 1;
                % end
            else
                % Only active deltas; the threshold uses deltas' own maximum,
                % as every other deltas test in this function does.
                if deltas(tl2) > maxdeltas*deps+1e-5
                    % if not(ismember(tl2,indv2))
                    indv2 = [ indv2; tl2 ]; picked = picked + 1;
                    % end
                end
            end
        end
        % Fill the remaining slots from the bottom of 'list' until q
        % variables are picked or the list is exhausted.
        i=l+2*n+1;
        while 1
            if picked==q
                break;
            end
            i = i-1;
            if i<=0
                break;
            end
            tl1 = list(i,2);
            tl2 = list(i,3);
            if tl1<=l
                temp = alpha(tl2);
                % DAMAGED SOURCE: the comparison of temp with the alpha
                % bounds was stripped; kept verbatim.
                if (1e-5+ceps)= temp
                    if y(tl2)==+1
                        if not(ismember(tl2,indl))
                            indl = [ indl; tl2 ];picked = picked + 1;
                        end
                    end
                else
                    if y(tl2)==-1
                        if not(ismember(tl2,indl))
                            indl = [ indl; tl2 ];picked = picked + 1;
                        end
                    end
                end
            elseif l < tl1 && tl1 <=l+n
                if delta(tl2) > maxdelta*deps+1e-5
                    if not(ismember(tl2,indv1))
                        indv1 = [ indv1; tl2 ]; picked = picked + 1;
                    end
                end
            else
                if not(ismember(tl2,indv2))
                    indv2 = [ indv2; tl2 ]; picked = picked + 1;
                end
            end
        end
    end

    % Variables outside the working set stay fixed in this sub-problem.
    fixedl = setdiff([1:l]',indl);
    fixedv1 = setdiff([1:n]',indv1);
    fixedv2 = setdiff([1:n]',indv2);
    sl =length(indl) ; sv1=length(indv1); sv2=length(indv2);

    %% optimize the sub-problem
    % MOSEK variables are [alpha(indl); delta(indv1); deltas(indv2)] / C
    % (the solution is multiplied by C below). Only the lower triangle of
    % the quadratic term is passed, in MOSEK's qosubi/qosubj/qoval format.
    quad =[diag(y(indl))* kzz(indl,indl)*diag(y(indl)) sparse(sl,sv1+sv2); ...
        -kzz(indv1,indl)*diag(y(indl)) kzz(indv1,indv1) sparse(sv1,sv2);...
        kzz(indv2,indl)*diag(y(indl)) -kzz(indv2,indv1) kzz(indv2,indv2) ]*C;
    [prob.qosubi,prob.qosubj,prob.qoval] = find(tril(sparse(double(quad))));
    % 0 <= alpha/C <= 1; delta/C and deltas/C are only bounded below by 0.
    prob.blx = [ sparse(sl+sv1+sv2,1) ];
    prob.bux = [ ones(sl,1) ; inf(sv1+sv2,1) ];

    % Reference (dense) form of the linear terms, kept for checking:
    % temp1 = dy(indl,indl)*((kzz(indl,fixedl)*(alphay(fixedl)) ...
    % -kzz(indl,fixedv1)*delta(fixedv1) ...
    % +kzz(indl,fixedv2)*deltas(fixedv2))) ;
    % temp2 = (-kzz(indv1,fixedl)*alphay(fixedl) ...
    % + kzz(indv1,fixedv1)*delta(fixedv1) ...
    % - kzz(indv1,fixedv2)*deltas(fixedv2));
    % temp3 = (kzz(indv2,fixedl)*alphay(fixedl) ...
    % - kzz(indv2,fixedv1)*delta(fixedv1) ...
    % + kzz(indv2,fixedv2)*deltas(fixedv2));

    % Same quantities computed cheaply: take the cached full products and
    % subtract the working-set contribution, leaving that of the fixed
    % variables.
    ntemp1 = zeros(sl,1);
    ntemp2 = zeros(sv1,1);
    ntemp3 = zeros(sv2,1);
    if length(indl)
        ntemp1 = ntemp1+y(indl).*(kzzalphay(indl) - kzzdelta(indl) + kzzdeltas(indl) ...
            - kzz(indl,indl)*alphay(indl));
        if length(indv1)
            ntemp1 = ntemp1 + y(indl).*( kzz(indl,indv1)*delta(indv1)) ;
        end
        if length(indv2)
            ntemp1 = ntemp1 - y(indl).*( kzz(indl,indv2)*deltas(indv2));
        end
    end
    if length(indv1)
        ntemp2 = ntemp2 + ( kzzdelta(indv1) - kzzalphay(indv1) - kzzdeltas(indv1) ...
            - kzz(indv1,indv1)*delta(indv1)) ;
        if length(indl)
            ntemp2 = ntemp2 + ( kzz(indv1,indl)*alphay(indl));
        end
        if length(indv2)
            ntemp2 = ntemp2 + ( kzz(indv1,indv2)*deltas(indv2));
        end
    end
    if length(indv2)
        ntemp3 = ntemp3 +( kzzdeltas(indv2) -kzzdelta(indv2) + kzzalphay(indv2) ...
            - kzz(indv2,indv2)*deltas(indv2));
        if length(indv1)
            ntemp3 = ntemp3 + ( kzz(indv2,indv1)*delta(indv1));
        end
        if length(indl)
            ntemp3 = ntemp3 - ( kzz(indv2,indl)*alphay(indl));
        end
    end
    % [ length(indl) length(indv1) length(indv2) ]
    % assert(norm(temp1 - ntemp1) < 1e-5)
    % assert(norm(temp2 - ntemp2) < 1e-5 )
    % assert(norm(temp3 - ntemp3) < 1e-5 )
    % prob.c = [ prob.c; ntemp3+B*ones(sv2,1) ];
    prob.c = double([ ntemp1-ones(sl,1); ntemp2+B*ones(sv1,1); ntemp3+B*ones(sv2,1) ]);

    % Equality constraint sum(alpha.*y) - sum(delta) + sum(deltas) = 0, with
    % the fixed variables moved to the right-hand side (scaled by 1/C).
    prob.a = sparse([ y(indl)' -ones(1,sv1) ones(1,sv2) ]);
    prob.buc = (-sum(alphay(fixedl)) + sum(delta(fixedv1)) - sum(deltas(fixedv2)))/C;
    prob.blc = prob.buc ;
    % NOTE(review): the return code r and the solution status are not checked.
    [r,res] = mosekopt('minimize echo(0)',prob,param);
    % [r,res] = mosekopt('minimize',prob,param);

    % Write the sub-problem solution back (undoing the 1/C scaling).
    old_alphay = alphay(indl);
    old_delta = delta(indv1);
    old_deltas = deltas(indv2);
    alpha(indl) = C*res.sol.itr.xx(1:sl);
    delta(indv1)=C*res.sol.itr.xx(sl+[1:sv1]);
    deltas(indv2)=C*res.sol.itr.xx(sv1+sl+[1:sv2]);
    maxdelta = max(delta);
    maxdeltas = max(deltas);
    alphay(indl) = alpha(indl).*y(indl);

    % Incremental update of the cached kernel products (only the changed
    % columns are touched).
    if length(indl)
        kzzalphay = kzzalphay - kzz(:,indl)*(old_alphay - alphay(indl));
    end
    if length(indv1)
        kzzdelta = kzzdelta - kzz(:,indv1)*(old_delta - delta(indv1));
    end
    if length(indv2)
        kzzdeltas = kzzdeltas - kzz(:,indv2)*(old_deltas -deltas(indv2));
    end

    % Dual objective (scaled by 1/C):
    % 0.5*w'*kzz*w - sum(alpha) + B*(sum(delta) + sum(deltas)).
    % w = [alphay; zeros(n-l,1) ] - delta + deltas;
    if count==1
        w = [alphay; zeros(n-l,1) ] - delta + deltas;
        obj=(0.5*w'*(kzzalphay - kzzdelta + kzzdeltas) ...
            - sum(alpha) + B*(sum(delta) + sum(deltas)))/C;
    end
    % Every 10 iterations, score the relative objective decrease; stop once
    % the accumulated score reaches 10 (stagnation), even if KKT is unmet.
    if mod(count,10)==0
        w = [alphay; zeros(n-l,1) ] - delta + deltas;
        oldobj = obj;
        obj=(0.5*w'*(kzzalphay - kzzdelta + kzzdeltas)...
            -sum(alpha) + B*(sum(delta) + sum(deltas)))/C;
        progress= (oldobj - obj)/abs(obj);
        if (progress < 0) && count >20
            good = good + 5;
        elseif progress < 0
            good = good;   % objective went up early on: leave score unchanged
        elseif progress < 1e-10
            good = good + 6;
        elseif progress < 1e-9
            good = good + 5;
        elseif progress < 1e-8
            good = good + 4;
        elseif progress < 1e-7
            good = good + 2.5;
        elseif progress < 1e-6;
            good = good + 1;
        elseif progress > 1e-4;
            good = 0;      % still making real progress: reset the score
        end
        % fprintf('%f %d\n',progress,good);
        if good>=10
            break;
        end
        %pause;
    end

    % ---- KKT check: continue iterating while any condition is violated ----
    ind = intersect(find((1e-5+ceps) < alpha), find(alpha< C-(1e-5+ceps)));
    ind1 = find(delta > maxdelta*deps+1e-5);
    ind2 = find(deltas >maxdeltas*deps+1e-5);
    all_pred = kzzalphay - kzzdelta + kzzdeltas;
    % Bias from MOSEK's dual variables of the equality constraint.
    lambda = res.sol.itr.suc - res.sol.itr.slc ;
    all_pred = all_pred + lambda;
    % Unbounded alphas (strictly inside (0,C)) need f = y.
    if sum(y(ind) - all_pred(ind) > eps)
        continue;
    end
    if sum( - eps > y(ind) - all_pred(ind))
        continue;
    end
    % alpha at 0 needs y.*f >= 1.
    ind = alpha <=(1e-5+ceps);
    if sum(y(ind).*(all_pred(ind)) < 1 - eps )
        continue;
    end
    % alpha at C needs y.*f <= 1.
    ind = alpha >=C-(1e-5+ceps);
    if sum(y(ind).*(all_pred(ind)) > 1 + eps )
        continue;
    end
    % Points with an active delta (deltas) must sit on the bound +B (-B).
    ind = delta > maxdelta*deps+1e-5;
    if sum(all_pred(ind) > B + B*beps)
        continue;
    end
    if sum( all_pred(ind) < B - B*beps )
        continue;
    end
    ind = deltas > maxdeltas*deps+1e-5;
    if sum(-all_pred(ind) < B - B*beps)
        continue;
    end
    if sum( -all_pred(ind) > B + B*beps )
        continue;
    end
    % Points with an inactive delta (deltas) must lie within the bound B.
    ind = delta <= maxdelta*deps+1e-5;
    if sum(all_pred(ind) > B + B*beps)
        continue;
    end
    ind = deltas <=maxdeltas*deps+1e-5;
    if sum(-all_pred(ind) > B + B*beps)
        continue;
    end
    break;
end
% disp does not interpret escape sequences; fprintf prints a real newline.
fprintf('\n');
w = [alphay; zeros(n-l,1) ] - delta + deltas;