DESERT 3.5.1
Loading...
Searching...
No Matches
least_squares.cpp
Go to the documentation of this file.
1// Redistribution and use in source and binary forms, with or without
2// modification, are permitted provided that the following conditions
3// are met:
4// 1. Redistributions of source code must retain the above copyright
5// notice, this list of conditions and the following disclaimer.
6// 2. Redistributions in binary form must reproduce the above copyright
7// notice, this list of conditions and the following disclaimer in the
8// documentation and/or other materials provided with the distribution.
9// 3. Neither the name of the University of Padova (SIGNET lab) nor the
10// names of its contributors may be used to endorse or promote products
11// derived from this software without specific prior written permission.
12//
13// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
14// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
15// TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
17// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
18// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
19// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20// OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
21// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
22// OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
23// ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
24//
25
31#include "least_squares.h"
32#include <iostream>
33#include <cmath>
34
namespace { //subroutines hidden in private namespace

/**
 * Householder transformation helper (port of Lawson & Hanson's H12).
 *
 * With flm == false the transformation is constructed from the pivot element
 * u[lfulcr*ud] and the elements u[j*ud], p1 <= j < m: the pivot is replaced by
 * the new diagonal value and su receives the extra component of the
 * Householder vector. With flm == true the transformation previously stored in
 * u/su is reused unchanged. Afterwards, if the transformation is non-trivial,
 * it is applied to nnn vectors stored in cm, element k of vector j being
 * cm[k*sk + j*borg].
 *
 * @param flm    false = construct (and apply), true = apply only.
 * @param lfulcr index of the pivot element.
 * @param p1     first index of the elements that, together with the pivot,
 *               define the transformation (must satisfy lfulcr < p1 <= m).
 * @param m      elements with index p1 <= j < m take part.
 * @param u      vector holding the transformation data with stride ud.
 * @param ud     stride between consecutive elements of u.
 * @param su     extra Householder component (written when flm == false).
 * @param cm     vectors to transform; may be nullptr when nothing is applied.
 * @param sk     stride between elements inside one vector of cm.
 * @param borg   stride between consecutive vectors of cm.
 * @param nnn    number of vectors in cm.
 * @return 0 on success or when there is nothing to do; 1 on invalid arguments
 *         (including a u too short to hold m elements at stride ud); 2 when a
 *         non-trivial transformation exists but cm is nullptr.
 */
int sHhTransf(bool flm,int lfulcr,int p1,int m,std::vector<double> &u,int ud,double &su,double *cm,int sk,int borg,int nnn) {
    if(m<1 || u.empty() || ud<1 || lfulcr<0 || lfulcr>=p1 || p1>m) return(1);
    // Every index of u touched below is at most (m-1)*ud (lfulcr < p1 <= m);
    // reject a u that is too short instead of accessing past its end.
    using size_type = std::vector<double>::size_type;
    if(static_cast<size_type>(m - 1) * static_cast<size_type>(ud) >= u.size()) return(1);

    double cl = std::abs(u[lfulcr*ud]);

    if(flm) {
        if(cl<=0.) return(0);
    } else {
        for(int j=p1; j<m; j++) {
            cl = std::max(std::abs(u[j*ud]), cl);
        }
        if(cl<=0.) return(0);

        // Scale by the largest magnitude before squaring to limit overflow/underflow.
        const double clinv=1./cl;
        const double d1=u[lfulcr*ud]*clinv;
        double sm=d1*d1;
        for(int j=p1; j<m; j++) {
            const double d2=u[j*ud]*clinv;
            sm+=d2*d2;
        }
        cl*=std::sqrt(sm);
        // Give the new diagonal the sign opposite to the pivot so that
        // su = pivot - cl does not suffer from cancellation.
        if(u[lfulcr*ud] > 0.) {cl=-cl;}
        su = u[lfulcr*ud] - cl;
        u[lfulcr*ud] = cl;
    }

    const double b=su*u[lfulcr*ud];
    // b < 0 whenever a non-trivial transformation exists; otherwise nothing to apply.
    if(b >= 0.) return(0);
    if(cm == nullptr) return(2);

    for(int j=0; j<nnn; j++) {
        double sm = cm[ lfulcr*sk + j*borg ] * su;
        for(int k=p1; k<m; k++) {
            sm += cm[ k*sk + j*borg ] * u[ k*ud ];
        }
        if(sm!=0.) {
            sm *= (1./b);
            cm[ lfulcr*sk + j*borg ] += sm*su;
            for(int k=p1; k<m; k++) {
                cm[ k*sk + j*borg ] += u[ k*ud ]*sm;
            }
        }
    }
    return(0);
}

/**
 * Constructs a Givens rotation (port of Lawson & Hanson's G1): computes c, s
 * and fre such that [c s; -s c] * [a; b] = [fre; 0], fre = sqrt(a*a + b*b).
 * When a == b == 0 the result is c = 0, s = 1, fre = 0.
 *
 * @param a   first component (by value, so fre may alias the caller's a).
 * @param b   second component (by value).
 * @param c   cosine of the rotation (output).
 * @param s   sine of the rotation (output).
 * @param fre length of (a, b) (output).
 */
void rotMat(double a, double b, double &c, double &s, double &fre)
{
    if(std::abs(a)>std::abs(b)) {
        // |a| > |b| >= 0 implies a != 0, and hypot(x, 1) >= 1, so neither
        // division in this branch can be by zero.
        const double xr=b/a;
        const double yr=std::hypot(xr, 1.);
        c=std::copysign(1./yr, a);
        s=c*xr;
        fre=std::abs(a)*yr;
    } else if(b!=0.) {
        const double xr=a/b;
        const double yr=std::hypot(xr, 1.);
        s=std::copysign(1./yr, b);
        c=s*xr;
        fre=std::abs(b)*yr;
    } else {
        fre=0.;
        c=0.;
        s=1.;
    }
}

} //private namespace
113} //private namespace
114
115LSSQ::LeastSqResult LSSQ::nnLeastSquares(std::vector<std::vector<double>> a,std::vector<double> b,std::vector<double> &x,double* resid)
116{
117 if(a.empty() || b.empty() || x.empty()) return(LeastSqResult::ERROR);
118 int m = 0;
119 int n = a.size();
120 if (n > 0) {
121 m = a[0].size();
122 for(int i = 0; i < n; i++) {
123 if (a[i].size() != m) return(LeastSqResult::ERROR);
124 }
125 } else return(LeastSqResult::ERROR);
126
127 std::vector<int> v1 (n);
128 std::vector<double> v2 (n);
129 std::vector<double> v3 (m);
130
131 for(int i=0; i<n; i++) {
132 v1[i]=i;
133 }
134 int inda=0;
135 int indb=n-1;
136 int tup=0;
137 int cccp=0;
138
139 double up=0.;
140 int mcyc;
141 if(n<3)
142 mcyc = 3*n;
143 else mcyc = n*n;
144 int iter=0;
145 int o, y=0, yy=0;
146 while(inda<=indb && tup<m) {
147 for(int iz=inda; iz<=indb; iz++) {
148 int ni=v1[iz];
149 double sm=0.;
150 for(int mi=cccp; mi<m; mi++) {
151 sm+=a[ni][mi]*b[mi];
152 }
153 v2[ni]=sm;
154 }
155 double lim;
156 int indm=0;
157 while(1) {
158 lim=0.;
159 for(int iz=inda; iz<=indb; iz++) {
160 int i=v1[iz];
161 if(v2[i]>lim) {
162 lim=v2[i];
163 indm=iz;
164 }
165 }
166
167 if(lim <= 0.) break;
168 y=v1[indm];
169 double asave=a[y][cccp];
170 up=0.;
171 sHhTransf(false, cccp, cccp+1, m, a[y], 1, up, nullptr, 1, 1, 0);
172 double plain=0.;
173 if(tup!=0) {
174 for(int mi=0; mi<tup; mi++) {
175 plain+=a[y][mi]*a[y][mi];
176 }
177 }
178 plain = std::sqrt(plain);
179 double d = plain + std::abs(a[y][cccp]) * 0.01;
180 if((d - plain) > 0.) {
181 for(int mi=0; mi<m; mi++) {
182 v3[mi]=b[mi];
183 }
184 sHhTransf(true, cccp, cccp+1, m, a[y], 1, up, v3.data(), 1, 1, 1);
185 double ztest;
186 if(a[y][cccp]!=0) {ztest=v3[cccp]/a[y][cccp];} else {std::cout << "nnls line 3" << std::endl;}
187 //double ztest=v3[cccp]/a[y][cccp];
188 if(ztest > 0.) break;
189 }
190 a[y][cccp] = asave;
191 v2[y] = 0.;
192 }
193 if(lim <= 0.) break;
194
195 for(int mi=0; mi<m; mi++) {
196 b[mi]=v3[mi];
197 }
198 v1[indm] = v1[inda];
199 v1[inda++] = y;
200 tup = 1 + cccp++;
201 if(inda<=indb)
202 for(int jz = inda; jz <= indb; jz++) {
203 yy = v1[jz];
204 sHhTransf(true, tup-1, cccp, m, a[y], 1, up, a[yy].data(), 1, m, 1);
205 }
206 if(tup!=m) {
207 for(int mi=cccp; mi<m; mi++) {
208 a[y][mi]=0.;
209 }
210 }
211 v2[y]=0.;
212
213 for(int mi=0; mi<tup; mi++) {
214 int ip=tup-(mi+1);
215 if(mi!=0) {
216 for(int ii = 0; ii <= ip; ii++) {
217 v3[ii] -= a[yy][ii]*v3[ip+1];
218 }
219 }
220 yy = v1[ip];
221 if(a[yy][ip]!=0) {v3[ip]/=a[yy][ip];} else {std::cout << "nnls line 4" << std::endl;}
222 //v3[ip]/=a[yy][ip];
223 }
224
225 while(++iter<mcyc) {
226 double cent=2.;
227 for(int ip=0; ip<tup; ip++) {
228 int ni=v1[ip];
229 if(v3[ip]<=0.) {
230 //double t=-x[ni]/(v3[ip]-x[ni]);
231 double t=2;
232 if((v3[ip]-x[ni])!=0) {t=-x[ni]/(v3[ip]-x[ni]);} else {std::cout << "nnls line 5" << std::endl;}
233 if(cent>t) {
234 cent=t;
235 yy=ip-1;
236 }
237 }
238 }
239
240 if(cent==2.) break;
241
242 for(int ip=0; ip<tup; ip++) {
243 int ni=v1[ip];
244 x[ni]+=cent*(v3[ip]-x[ni]);
245 }
246
247 int fac=1;
248 o=v1[yy+1];
249 do {
250 x[o]=0.;
251 if(yy!=(tup-1)) {
252 yy++;
253 for(int ni=yy+1; ni<tup; ni++) {
254 int ii=v1[ni];
255 v1[ni-1]=ii;
256 double rig, fil;
257 rotMat(a[ii][ni-1], a[ii][ni], fil, rig, a[ii][ni-1]);
258 a[ii][ni]=0.;
259 for(int nj=0; nj<n; nj++) {
260 if(nj!=ii) {
261 double fai=a[nj][ni-1];
262 a[nj][ni-1] = fil*fai+rig*a[nj][ni];
263 a[nj][ni] =- rig*fai+fil*a[nj][ni];
264 }
265 }
266 double fai=b[ni-1];
267 b[ni-1]=fil*fai+rig*b[ni];
268 b[ni]=-rig*fai+fil*b[ni];
269 }
270 }
271 cccp = tup-1;
272 tup--;
273 inda--;
274 v1[inda] = o;
275
276 for(int i=0, fac=1; i<tup; i++) {
277 o=v1[i];
278 if(x[o]<=0.) {
279 fac=0; break;
280 }
281 }
282 } while(fac==0);
283
284 for(int i=0; i<m; i++) {
285 v3[i]=b[i];
286 }
287 for(int i=0; i<tup; i++) {
288 int mi = tup-(i+1);
289 if(i!=0) {
290 for(int ii=0; ii<=mi; ii++) {
291 v3[ii]-=a[yy][ii]*v3[i+1];
292 }
293 }
294 yy=v1[mi];
295 if(a[yy][mi]!=0) {v3[mi]/=a[yy][mi];} else {std::cout << "nnls line 6" << std::endl;}
296 //v3[mi]/=a[yy][mi];
297 }
298 }
299
300 if(iter>=mcyc) break;
301 for(int i=0; i<tup; i++) {
302 o=v1[i];
303 x[o]=v3[i];
304 }
305 }
306
307 if(resid != nullptr) {
308 double sm=0.;
309 if(cccp<m) {
310 for(int i=cccp; i<m; i++) {
311 sm += b[i]*b[i];
312 }
313 }
314 else {
315 for(int i=0; i<n; i++) {
316 v2[i]=0.;
317 }
318 }
319 *resid=sm;
320 }
321 if(iter>=mcyc) return(LeastSqResult::TIMEOUT);
322 return(LeastSqResult::OK);
323}
324
LeastSqResult nnLeastSquares(std::vector< std::vector< double > > a, std::vector< double > b, std::vector< double > &x, double *resid=nullptr)
Least-squares linear regressor: solves the non-negative least-squares problem of minimizing ||A * X - B|| subject to X >= 0.