imate
C++/CUDA Reference
Loading...
Searching...
No Matches
convergence_tools.cpp
Go to the documentation of this file.
1/*
2 * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3 * SPDX-License-Identifier: BSD-3-Clause
4 * SPDX-FileType: SOURCE
5 *
6 * This program is free software: you can redistribute it and/or modify it
7 * under the terms of the license found in the LICENSE.txt file in the root
8 * directory of this source tree.
9 */
10
11
12// =======
13// Headers
14// =======
15
16#include "./convergence_tools.h"
17#include <cmath> // std::sqrt, std::abs, INFINITY, NAN, isnan
18#include <algorithm> // std::max
19#include "./special_functions.h" // erf_inv
20#include "../_c_arithmetics/c_arithmetics.h" // c_arithmetics
21
22
23// =================
24// check convergence
25// =================
26
94
95template <typename DataType>
97 DataType** samples,
98 const IndexType min_num_samples,
99 const IndexType num_inquiries,
100 const IndexType* processed_samples_indices,
101 const IndexType num_processed_samples,
102 const DataType confidence_level,
103 const DataType error_atol,
104 const DataType error_rtol,
105 DataType* error,
106 IndexType* num_samples_used,
107 FlagType* converged)
108{
109 FlagType all_converged;
110 IndexType j;
111
112 // If number of processed samples are not enough, set to not converged yet.
113 // This is essential since in the first few iterations, the standard
114 // deviation of the cumulative averages are still too small.
115 if (num_processed_samples < min_num_samples)
116 {
117 // Skip computing error. Fill outputs with trivial initial values
118 for (j=0; j < num_inquiries; j++)
119 {
120 error[j] = INFINITY;
121 converged[j] = 0;
122 num_samples_used[j] = num_processed_samples;
123 }
124 all_converged = 0;
125 return all_converged;
126 }
127
128 IndexType i;
129 DataType summand;
130 DataType mean;
131 DataType mean_abs;
132 DataType std;
133 DataType mean_discrepancy;
134 DataType data;
135
136 // Quantile of normal distribution (usually known as the "z" coefficient)
137 DataType standard_z_score = std::sqrt(2) * \
138 static_cast<DataType>(erf_inv(static_cast<double>(confidence_level)));
139
140 // For each column of samples, compute error of all processed rows
141 for (j=0; j < num_inquiries; ++j)
142 {
143 // Do not check convergence if j-th column already converged
144 if (converged[j] == 0)
145 {
146 // mean of j-th column using all processed rows of j-th column
147 summand = 0.0;
148 for (i=0; i < num_processed_samples; ++i)
149 {
150 summand += samples[processed_samples_indices[i]][j];
151 }
152 mean = summand / num_processed_samples;
153
154 // mean of absolute values of j-th column using all processed rows
155 // of j-th column
156 summand = 0.0;
157 for (i=0; i < num_processed_samples; ++i)
158 {
159 summand += std::abs(samples[processed_samples_indices[i]][j]);
160 }
161 mean_abs = summand / num_processed_samples;
162
163 // If mean of absolute values is zero, do not scale data with it as
164 // it causes divide by zero. In this case, all sample data are
165 // zero, and there is no need to scale them.
166 DataType zero = 0.0;
167 if (c_arithmetics::is_equal(mean_abs, zero))
168 {
169 mean_abs = 1.0;
170 }
171
172 // std of j-th column using all processed rows of j-th column
173 if (num_processed_samples > 1)
174 {
175 summand = 0.0;
176 for (i=0; i < num_processed_samples; ++i)
177 {
178 data = samples[processed_samples_indices[i]][j];
179 mean_discrepancy = data - mean;
180
181 // Normalize to the mean of absolute values to avoid
182 // underflow and overflow, but later, re-scale it back. The
183 // underflow or overflow is caused by taking the square two
184 // lines below.
185 mean_discrepancy /= mean_abs;
186
187 summand += mean_discrepancy * mean_discrepancy;
188 }
189 std = std::sqrt(summand / (num_processed_samples - 1.0));
190
191 // Re-scale back by mean of absolute values
192 std *= mean_abs;
193 }
194 else
195 {
196 std = INFINITY;
197 }
198
199 // Compute error based of std and confidence level
200 error[j] = standard_z_score * std / \
201 std::sqrt(num_processed_samples);
202
203 // Check error with atol and rtol to find if j-th column converged
204 if (error[j] < std::max(error_atol, error_rtol*mean))
205 {
206 converged[j] = 1;
207 }
208
209 // Update how many samples used so far to average j-th column
210 num_samples_used[j] = num_processed_samples;
211 }
212 }
213
214 // Check convergence is reached for all columns (all inquiries)
215 all_converged = 1;
216 for (j=0; j < num_inquiries; ++j)
217 {
218 if (converged[j] == 0)
219 {
220 // The j-th column not converged.
221 all_converged = 0;
222 break;
223 }
224 }
225
226 return all_converged;
227}
228
229
230// =================
231// average estimates
232// =================
233
287
288template <typename DataType>
290 const DataType confidence_level,
291 const DataType outlier_significance_level,
292 const IndexType num_inquiries,
293 const IndexType max_num_samples,
294 const IndexType* num_samples_used,
295 const IndexType* processed_samples_indices,
296 DataType** samples,
297 IndexType* num_outliers,
298 DataType* trace,
299 DataType* error)
300{
301 IndexType i;
302 IndexType j;
303 DataType summand;
304 DataType mean;
305 DataType mean_abs;
306 DataType std;
307 DataType mean_discrepancy;
308 DataType outlier_half_interval;
309
310 // Flag which samples are outliers
311 FlagType* outlier_indices = new FlagType[max_num_samples];
312
313 // Quantile of normal distribution (usually known as the "z" coefficient)
314 DataType error_z_score = std::sqrt(2) * erf_inv(confidence_level);
315
316 // Confidence level of outlier is the complement of significance level
317 DataType outlier_confidence_level = 1.0 - outlier_significance_level;
318
319 // Quantile of normal distribution area where is not considered as outlier
320 DataType outlier_z_score = std::sqrt(2.0) * \
321 erf_inv(outlier_confidence_level);
322
323 for (j=0; j < num_inquiries; ++j)
324 {
325 // Initialize outlier indices for each column of samples
326 for (i=0; i < max_num_samples; ++i)
327 {
328 outlier_indices[i] = 0;
329 }
330 num_outliers[j] = 0;
331
332 // Compute mean of the j-th column
333 summand = 0.0;
334 for (i=0; i < num_samples_used[j]; ++i)
335 {
336 summand += samples[processed_samples_indices[i]][j];
337 }
338 mean = summand / num_samples_used[j];
339
340 // Compute mean of the absolute values of j-th column
341 summand = 0.0;
342 for (i=0; i < num_samples_used[j]; ++i)
343 {
344 summand += std::abs(samples[processed_samples_indices[i]][j]);
345 }
346 mean_abs = summand / num_samples_used[j];
347
348 // If mean of absolute values is zero, do not scale data with it as it
349 // causes divide by zero. In this case, all sample data are zero, and
350 // there is no need to scale them.
351 DataType zero = 0.0;
352 if (c_arithmetics::is_equal(mean_abs, zero))
353 {
354 mean_abs = 1.0;
355 }
356
357 // Compute std of the j-th column
358 if (num_samples_used[j] > 1)
359 {
360 summand = 0.0;
361 for (i=0; i < num_samples_used[j]; ++i)
362 {
363 mean_discrepancy = \
364 samples[processed_samples_indices[i]][j] - mean;
365
366 // Normalize to the mean of absolute values to avoid underflow
367 // and overflow, but later, re-scale it back. The underflow or
368 // overflow is caused by taking the square two lines below.
369 mean_discrepancy /= mean_abs;
370
371 summand += mean_discrepancy * mean_discrepancy;
372 }
373 std = std::sqrt(summand / (num_samples_used[j] - 1.0));
374
375 // Re-scale back by mean of absolute values
376 std *= mean_abs;
377 }
378 else
379 {
380 std = INFINITY;
381 }
382
383 // Outlier half interval
384 outlier_half_interval = outlier_z_score * std;
385
386 // Find outliers by the difference of each element from mean
387 for (i=0; i < num_samples_used[j]; ++i)
388 {
389 mean_discrepancy = samples[processed_samples_indices[i]][j] - mean;
390 if (std::abs(mean_discrepancy) > outlier_half_interval)
391 {
392 // Outlier detected
393 outlier_indices[i] = 1;
394 num_outliers[j] += 1;
395 }
396 }
397
398 // Re-evaluate mean but leave out outliers
399 summand = 0.0;
400 for (i=0; i < num_samples_used[j]; ++i)
401 {
402 if (outlier_indices[i] == 0)
403 {
404 summand += samples[processed_samples_indices[i]][j];
405 }
406 }
407 mean = summand / (num_samples_used[j] - num_outliers[j]);
408
409 // Re-evaluate std but leave out outliers
410 if (num_samples_used[j] > 1 + num_outliers[j])
411 {
412 summand = 0.0;
413 for (i=0; i < num_samples_used[j]; ++i)
414 {
415 if (outlier_indices[i] == 0)
416 {
417 mean_discrepancy = \
418 samples[processed_samples_indices[i]][j] - mean;
419
420 // Normalize to the mean of absolute values to avoid underflow
421 // and overflow, but later, re-scale it back. The underflow or
422 // overflow is caused by taking the square two lines below.
423 mean_discrepancy /= mean_abs;
424
425 summand += mean_discrepancy * mean_discrepancy;
426 }
427 }
428 std = std::sqrt(
429 summand/(num_samples_used[j] - num_outliers[j] - 1.0));
430
431 // Re-scale back by mean of absolute values
432 std *= mean_abs;
433 }
434 else
435 {
436 std = INFINITY;
437 }
438
439 // trace and its error
440 trace[j] = mean;
441 error[j] = error_z_score * std / \
442 std::sqrt(num_samples_used[j] - num_outliers[j]);
443 }
444
445 delete[] outlier_indices;
446}
447
448
449// ===============================
450// Explicit template instantiation
451// ===============================
452
453template class ConvergenceTools<float>;
454template class ConvergenceTools<double>;
455template class ConvergenceTools<long double>;
A static class to compute the trace of implicit matrix functions using stochastic Lanczos quadrature ...
static FlagType check_convergence(DataType **samples, const IndexType min_num_samples, const IndexType num_inquiries, const IndexType *processed_samples_indices, const IndexType num_processed_samples, const DataType confidence_level, const DataType error_atol, const DataType error_rtol, DataType *error, IndexType *num_samples_used, FlagType *converged)
Checks if the standard deviation of the set of the cumulative averages of trace estimators converged ...
static void average_estimates(const DataType confidence_level, const DataType outlier_significance_level, const IndexType num_inquiries, const IndexType max_num_samples, const IndexType *num_samples_used, const IndexType *processed_samples_indices, DataType **samples, IndexType *num_outliers, DataType *trace, DataType *error)
Averages the trace estimates, removes outliers, and re-evaluates the error to account for ...
bool is_equal(DataType x, DataType y)
Check if two floating point numbers are equal within a tolerance.
Definition _c_is_equal.h:49
double erf_inv(const double x)
Inverse error function.
int FlagType
Definition types.h:68
int IndexType
Definition types.h:65