imate
C++/CUDA Reference
convergence_tools.cpp
Go to the documentation of this file.
1 /*
2  * SPDX-FileCopyrightText: Copyright 2021, Siavash Ameli <sameli@berkeley.edu>
3  * SPDX-License-Identifier: BSD-3-Clause
4  * SPDX-FileType: SOURCE
5  *
6  * This program is free software: you can redistribute it and/or modify it
7  * under the terms of the license found in the LICENSE.txt file in the root
8  * directory of this source tree.
9  */
10 
11 
12 // =======
13 // Headers
14 // =======
15 
16 #include "./convergence_tools.h"
17 #include <cmath> // sqrt, std::abs, INFINITY, NAN, isnan
18 #include <algorithm> // std::max
19 #include "./special_functions.h" // erf_inv
20 
21 
22 // =================
23 // check convergence
24 // =================
25 
93 
94 template <typename DataType>
96  DataType** samples,
97  const IndexType min_num_samples,
98  const IndexType num_inquiries,
99  const IndexType* processed_samples_indices,
100  const IndexType num_processed_samples,
101  const DataType confidence_level,
102  const DataType error_atol,
103  const DataType error_rtol,
104  DataType* error,
105  IndexType* num_samples_used,
106  FlagType* converged)
107 {
108  FlagType all_converged;
109  IndexType j;
110 
111  // If number of processed samples are not enough, set to not converged yet.
112  // This is essential since in the first few iterations, the standard
113  // deviation of the cumulative averages are still too small.
114  if (num_processed_samples < min_num_samples)
115  {
116  // Skip computing error. Fill outputs with trivial initial values
117  for (j=0; j < num_inquiries; j++)
118  {
119  error[j] = INFINITY;
120  converged[j] = 0;
121  num_samples_used[j] = num_processed_samples;
122  }
123  all_converged = 0;
124  return all_converged;
125  }
126 
127  IndexType i;
128  DataType summand;
129  DataType mean;
130  DataType std;
131  DataType data;
132 
133  // Quantile of normal distribution (usually known as the "z" coefficient)
134  DataType standard_z_score = sqrt(2) * \
135  static_cast<DataType>(erf_inv(static_cast<double>(confidence_level)));
136 
137  // For each column of samples, compute error of all processed rows
138  for (j=0; j < num_inquiries; ++j)
139  {
140  // Do not check convergence if j-th column already converged
141  if (converged[j] == 0)
142  {
143  // mean of j-th column using all processed rows of j-th column
144  summand = 0.0;
145  for (i=0; i < num_processed_samples; ++i)
146  {
147  summand += samples[processed_samples_indices[i]][j];
148  }
149  mean = summand / num_processed_samples;
150 
151  // std of j-th column using all processed rows of j-th column
152  if (num_processed_samples > 1)
153  {
154  summand = 0.0;
155  for (i=0; i < num_processed_samples; ++i)
156  {
157  data = samples[processed_samples_indices[i]][j];
158  summand += (data - mean) * (data - mean);
159  }
160  std = sqrt(summand / (num_processed_samples - 1.0));
161  }
162  else
163  {
164  std = INFINITY;
165  }
166 
167  // Compute error based of std and confidence level
168  error[j] = standard_z_score * std / sqrt(num_processed_samples);
169 
170  // Check error with atol and rtol to find if j-th column converged
171  if (error[j] < std::max(error_atol, error_rtol*mean))
172  {
173  converged[j] = 1;
174  }
175 
176  // Update how many samples used so far to average j-th column
177  num_samples_used[j] = num_processed_samples;
178  }
179  }
180 
181  // Check convergence is reached for all columns (all inquiries)
182  all_converged = 1;
183  for (j=0; j < num_inquiries; ++j)
184  {
185  if (converged[j] == 0)
186  {
187  // The j-th column not converged.
188  all_converged = 0;
189  break;
190  }
191  }
192 
193  return all_converged;
194 }
195 
196 
197 // =================
198 // average estimates
199 // =================
200 
254 
255 template <typename DataType>
257  const DataType confidence_level,
258  const DataType outlier_significance_level,
259  const IndexType num_inquiries,
260  const IndexType max_num_samples,
261  const IndexType* num_samples_used,
262  const IndexType* processed_samples_indices,
263  DataType** samples,
264  IndexType* num_outliers,
265  DataType* trace,
266  DataType* error)
267 {
268  IndexType i;
269  IndexType j;
270  DataType summand;
271  DataType mean;
272  DataType std;
273  DataType mean_discrepancy;
274  DataType outlier_half_interval;
275 
276  // Flag which samples are outliers
277  FlagType* outlier_indices = new FlagType[max_num_samples];
278 
279  // Quantile of normal distribution (usually known as the "z" coefficient)
280  DataType error_z_score = sqrt(2) * erf_inv(confidence_level);
281 
282  // Confidence level of outlier is the complement of significance level
283  DataType outlier_confidence_level = 1.0 - outlier_significance_level;
284 
285  // Quantile of normal distribution area where is not considered as outlier
286  DataType outlier_z_score = sqrt(2.0) * erf_inv(outlier_confidence_level);
287 
288  for (j=0; j < num_inquiries; ++j)
289  {
290  // Initialize outlier indices for each column of samples
291  for (i=0; i < max_num_samples; ++i)
292  {
293  outlier_indices[i] = 0;
294  }
295  num_outliers[j] = 0;
296 
297  // Compute mean of the j-th column
298  summand = 0.0;
299  for (i=0; i < num_samples_used[j]; ++i)
300  {
301  summand += samples[processed_samples_indices[i]][j];
302  }
303  mean = summand / num_samples_used[j];
304 
305  // Compute std of the j-th column
306 
307  if (num_samples_used[j] > 1)
308  {
309  summand = 0.0;
310  for (i=0; i < num_samples_used[j]; ++i)
311  {
312  mean_discrepancy = \
313  samples[processed_samples_indices[i]][j] - mean;
314  summand += mean_discrepancy * mean_discrepancy;
315  }
316  std = sqrt(summand / (num_samples_used[j] - 1.0));
317  }
318  else
319  {
320  std = INFINITY;
321  }
322 
323  // Outlier half interval
324  outlier_half_interval = outlier_z_score * std;
325 
326  // Difference of each element from
327  for (i=0; i < num_samples_used[j]; ++i)
328  {
329  mean_discrepancy = samples[processed_samples_indices[i]][j] - mean;
330  if (std::abs(mean_discrepancy) > outlier_half_interval)
331  {
332  // Outlier detected
333  outlier_indices[i] = 1;
334  num_outliers[j] += 1;
335  }
336  }
337 
338  // Reevaluate mean but leave out outliers
339  summand = 0.0;
340  for (i=0; i < num_samples_used[j]; ++i)
341  {
342  if (outlier_indices[i] == 0)
343  {
344  summand += samples[processed_samples_indices[i]][j];
345  }
346  }
347  mean = summand / (num_samples_used[j] - num_outliers[j]);
348 
349  // Reevaluate std but leave out outliers
350  if (num_samples_used[j] > 1 + num_outliers[j])
351  {
352  summand = 0.0;
353  for (i=0; i < num_samples_used[j]; ++i)
354  {
355  if (outlier_indices[i] == 0)
356  {
357  mean_discrepancy = \
358  samples[processed_samples_indices[i]][j] - mean;
359  summand += mean_discrepancy * mean_discrepancy;
360  }
361  }
362  std = sqrt(summand/(num_samples_used[j] - num_outliers[j] - 1.0));
363  }
364  else
365  {
366  std = INFINITY;
367  }
368 
369  // trace and its error
370  trace[j] = mean;
371  error[j] = error_z_score * std / \
372  sqrt(num_samples_used[j] - num_outliers[j]);
373  }
374 
375  delete[] outlier_indices;
376 }
377 
378 
379 // ===============================
380 // Explicit template instantiation
381 // ===============================
382 
383 template class ConvergenceTools<float>;
384 template class ConvergenceTools<double>;
385 template class ConvergenceTools<long double>;
A static class to compute the trace of implicit matrix functions using stochastic Lanczos quadrature ...
static FlagType check_convergence(DataType **samples, const IndexType min_num_samples, const IndexType num_inquiries, const IndexType *processed_samples_indices, const IndexType num_processed_samples, const DataType confidence_level, const DataType error_atol, const DataType error_rtol, DataType *error, IndexType *num_samples_used, FlagType *converged)
Checks if the standard deviation of the set of the cumulative averages of trace estimators converged ...
static void average_estimates(const DataType confidence_level, const DataType outlier_significance_level, const IndexType num_inquiries, const IndexType max_num_samples, const IndexType *num_samples_used, const IndexType *processed_samples_indices, DataType **samples, IndexType *num_outliers, DataType *trace, DataType *error)
Averages the estimates of trace. Removes outliers and reevaluates the error to account for ...
double erf_inv(const double x)
Inverse error function.
int FlagType
Definition: types.h:68
int IndexType
Definition: types.h:65