Chemical Data Processing Library C++ API - Version 1.4.0
KabschAlgorithm.hpp
Go to the documentation of this file.
1 /*
2  * KabschAlgorithm.hpp
3  *
4  * Copyright (C) 2003 Thomas Seidel <thomas.seidel@univie.ac.at>
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library; see the file COPYING. If not, write to
18  * the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
19  * Boston, MA 02111-1307, USA.
20  */
21 
27 #ifndef CDPL_MATH_KABSCHALGORITHM_HPP
28 #define CDPL_MATH_KABSCHALGORITHM_HPP
29 
30 #include <cstddef>
31 
32 #include "CDPL/Math/Check.hpp"
33 #include "CDPL/Math/TypeTraits.hpp"
34 #include "CDPL/Math/Matrix.hpp"
35 #include "CDPL/Math/Vector.hpp"
39 #include "CDPL/Base/Exceptions.hpp"
40 
41 
42 namespace CDPL
43 {
44 
45  namespace Math
46  {
47 
60  template <typename T>
62  {
63 
64  public:
/** The scalar value type (the class template parameter T). */
66  typedef T ValueType;
71 
/**
 * Computes the transformation that aligns the point set \a points with the
 * reference point set \a ref_points, taking per-point weights into account.
 *
 * Both point sets are matrices that store one point per column: the number of
 * rows is the dimension and the number of columns is the number of points.
 * On success the result can be retrieved with getTransform().
 *
 * \param points       The points to align (rows = coordinates, columns = points).
 * \param ref_points   The reference points; must have the same shape as \a points.
 * \param weights      One weight per point. Every weight must be non-negative and
 *                     the sum of all weights must be positive.
 * \param do_center    If \c true, both point sets are shifted by their weighted
 *                     centroids before the covariance matrix is built and the
 *                     resulting transformation contains a translation part. If
 *                     \c false, the points are used as given and the translation
 *                     part of the result is zero.
 * \param max_svd_iter Passed on unchanged to svDecompose().
 *                     NOTE(review): the meaning of the default 0 is defined by
 *                     svDecompose() and is not visible here - confirm.
 * \return The result of the private align() overload: \c false if svDecompose()
 *         returned \c false, \c true otherwise.
 * \throw Base::SizeError  via CDPL_MATH_CHECK if the point sets differ in shape or
 *                         the number of weights differs from the number of points.
 * \throw Base::ValueError via CDPL_MATH_CHECK if a weight is negative or the
 *                         weights do not sum to a positive value.
 */
88  template <typename M1, typename M2, typename V>
89  bool align(const MatrixExpression<M1>& points, const MatrixExpression<M2>& ref_points, const VectorExpression<V>& weights,
90  bool do_center = true, std::size_t max_svd_iter = 0)
91  {
92 
// NOTE(review): this listing omits the first line of the typedef below (embedded
// line 93); only its tail is visible. SizeType is presumably the common type of the
// size types of M1, M2 and V - confirm against the original header.
94  typename V::SizeType>::Type SizeType;
95 
// Rows hold the coordinates, columns hold the individual points.
96  SizeType dim = points().getSize1();
97  SizeType num_pts = points().getSize2();
98 
99  CDPL_MATH_CHECK(dim == SizeType(ref_points().getSize1()) && num_pts == SizeType(ref_points().getSize2()),
100  "KabschAlgorithm: Point-sets of different size", Base::SizeError);
101 
102  CDPL_MATH_CHECK(num_pts == SizeType(weights().getSize()),
103  "KabschAlgorithm: Number of points != number of weights", Base::SizeError);
104 
// Validate the weights and accumulate their sum, which is used for normalization below.
105  ValueType w_sum = ValueType();
106 
107  for (SizeType i = 0; i < num_pts; i++) {
108  CDPL_MATH_CHECK(ValueType(weights()(i)) >= ValueType(), "KabschAlgorithm: weights must be non-negative entries", Base::ValueError);
109  w_sum += weights()(i);
110  }
111 
112  CDPL_MATH_CHECK(w_sum > ValueType(), "KabschAlgorithm: weights must contain some positive entry", Base::ValueError);
113 
114  if (do_center) {
// Weighted centroids: matrix * weights (presumably written into the third
// argument by this prod() form), then divided by the weight sum.
115  prod(points, weights, centroid1);
116  prod(ref_points, weights, centroid2);
117 
118  centroid1 /= w_sum;
119  centroid2 /= w_sum;
120 
// Work on copies so the caller's point sets are left untouched.
121  tmpPoints.resize(dim, num_pts, false);
122  tmpPoints.assign(points);
123 
124  tmpRefPoints.resize(dim, num_pts, false);
125  tmpRefPoints.assign(ref_points);
126 
// Center both sets; the normalized weight is applied to the moving points only,
// so each point pair enters the covariance product with its weight exactly once.
127  for (SizeType i = 0; i < num_pts; i++) {
128  column(tmpPoints, i).minusAssign(centroid1) *= weights()(i) / w_sum;
129  column(tmpRefPoints, i).minusAssign(centroid2);
130  }
131 
132  } else {
// No centering: only the moving points are copied and scaled by the normalized
// weights; the reference points are used directly further down.
133  tmpPoints.resize(dim, num_pts, false);
134  tmpPoints.assign(points);
135 
136  for (SizeType i = 0; i < num_pts; i++)
137  column(tmpPoints, i) *= weights()(i) / w_sum;
138  }
139 
// dim x dim covariance matrix: (weighted points) * (reference points)^T.
140  covarMatrix.resize(dim, dim, false);
141 
142  if (do_center)
143  prod(tmpPoints, trans(tmpRefPoints), covarMatrix);
144  else
145  prod(tmpPoints, trans(ref_points), covarMatrix);
146 
// The SVD step and the assembly of the transformation matrix are shared with the unweighted overload.
147  return align(dim, do_center, max_svd_iter);
148  }
149 
/**
 * Computes the transformation that aligns the point set \a points with the
 * reference point set \a ref_points, treating all points equally (no weights).
 *
 * Both point sets are matrices that store one point per column: the number of
 * rows is the dimension and the number of columns is the number of points.
 * On success the result can be retrieved with getTransform().
 *
 * \param points       The points to align (rows = coordinates, columns = points).
 * \param ref_points   The reference points; must have the same shape as \a points.
 * \param do_center    If \c true, both point sets are shifted by their centroids
 *                     (arithmetic mean of the columns) before the covariance matrix
 *                     is built and the resulting transformation contains a
 *                     translation part. If \c false, the points are used as given
 *                     and the translation part of the result is zero.
 * \param max_svd_iter Passed on unchanged to svDecompose().
 *                     NOTE(review): the meaning of the default 0 is defined by
 *                     svDecompose() and is not visible here - confirm.
 * \return The result of the private align() overload: \c false if svDecompose()
 *         returned \c false, \c true otherwise.
 * \throw Base::SizeError via CDPL_MATH_CHECK if the point sets differ in shape.
 */
163  template <typename M1, typename M2>
164  bool align(const MatrixExpression<M1>& points, const MatrixExpression<M2>& ref_points,
165  bool do_center = true, std::size_t max_svd_iter = 0)
166  {
167 
// NOTE(review): this listing omits embedded line 168, which must declare the
// SizeType used below (presumably the common type of the size types of M1 and
// M2) - confirm against the original header.
169 
// Rows hold the coordinates, columns hold the individual points.
170  SizeType dim = points().getSize1();
171  SizeType num_pts = points().getSize2();
172 
173  CDPL_MATH_CHECK(dim == SizeType(ref_points().getSize1()) && num_pts == SizeType(ref_points().getSize2()),
174  "KabschAlgorithm: Point-sets of different size", Base::SizeError);
175 
176  if (do_center) {
// Centroids as matrix-vector products with a constant vector whose entries are
// all 1/num_pts, i.e. the arithmetic mean of the columns.
177  prod(points, ScalarVector<ValueType>(num_pts, ValueType(1) / num_pts), centroid1);
178  prod(ref_points, ScalarVector<ValueType>(num_pts, ValueType(1) / num_pts), centroid2);
179 
// Work on copies so the caller's point sets are left untouched.
180  tmpPoints.resize(dim, num_pts, false);
181  tmpPoints.assign(points);
182 
183  tmpRefPoints.resize(dim, num_pts, false);
184  tmpRefPoints.assign(ref_points);
185 
186  for (SizeType i = 0; i < num_pts; i++) {
187  column(tmpPoints, i).minusAssign(centroid1);
188  column(tmpRefPoints, i).minusAssign(centroid2);
189  }
190  }
191 
// dim x dim covariance matrix: points * (reference points)^T. Without centering
// the inputs are used directly and the temporary copies are not touched.
192  covarMatrix.resize(dim, dim, false);
193 
194  if (do_center)
195  prod(tmpPoints, trans(tmpRefPoints), covarMatrix);
196  else
197  prod(points, trans(ref_points), covarMatrix);
198 
// The SVD step and the assembly of the transformation matrix are shared with the weighted overload.
199  return align(dim, do_center, max_svd_iter);
200  }
201 
206  const MatrixType& getTransform() const
207  {
208  return transform;
209  }
210 
211  private:
212  template <typename SizeType>
213  bool align(SizeType dim, bool do_center, std::size_t max_svd_iter)
214  {
215  svdW.resize(dim);
216  svdV.resize(dim, dim, false);
217 
218  if (!svDecompose(covarMatrix, svdW, svdV, max_svd_iter))
219  return false;
220 
221  if (det(prod(covarMatrix, trans(svdV))) < ValueType())
222  column(svdV, dim - 1) *= -ValueType(1);
223 
224  SizeType xform_dim = dim + 1;
225 
226  transform.resize(xform_dim, xform_dim, false);
227 
228  range(transform, 0, dim, 0, dim).assign(prod(svdV, trans(covarMatrix)));
229 
230  MatrixRow<MatrixType> last_row(transform, dim);
231  MatrixColumn<MatrixType> last_col(transform, dim);
232 
233  range(last_row, 0, dim).assign(ZeroVector<ValueType>(dim));
234 
235  if (do_center)
236  range(last_col, 0, dim).assign(centroid2 - prod(range(transform, 0, dim, 0, dim), centroid1));
237  else
238  range(last_col, 0, dim).assign(ZeroVector<ValueType>(dim));
239 
240  transform(dim, dim) = ValueType(1);
241 
242  return true;
243  }
244 
// Working state; kept as members and re-sized on every align() call.

// Resulting (dim + 1) x (dim + 1) transformation matrix returned by getTransform().
245  MatrixType transform;
// Working copy of the input points (centered and/or weighted by the public align() overloads).
246  MatrixType tmpPoints;
// Working copy of the centered reference points (only filled when centering is requested).
247  MatrixType tmpRefPoints;
// dim x dim covariance matrix; handed to svDecompose() as the matrix to decompose.
248  MatrixType covarMatrix;
// Vector output argument 'w' of svDecompose().
249  VectorType svdW;
// Matrix output argument 'v' of svDecompose().
250  MatrixType svdV;
// Centroid of the input points (only computed when centering is requested).
251  VectorType centroid1;
// Centroid of the reference points (only computed when centering is requested).
252  VectorType centroid2;
253  };
254  } // namespace Math
255 } // namespace CDPL
256 
257 #endif // CDPL_MATH_KABSCHALGORITHM_HPP
Definition of exception classes.
Definition of various preprocessor macros for error checking.
#define CDPL_MATH_CHECK(expr, msg, e)
Throws the exception e with message msg when the boolean expression expr evaluates to false.
Definition: Check.hpp:47
Definition of matrix proxy types.
Definition of matrix data types.
Implementation of matrix singular value decomposition and associated operations.
Definition of type traits.
Definition of vector proxy types.
Definition of vector data types.
Thrown to indicate that the size of a (multidimensional) array is not correct.
Definition: Base/Exceptions.hpp:133
Thrown to indicate errors caused by some invalid value.
Definition: Base/Exceptions.hpp:76
Implementation of the Kabsch algorithm [KABA].
Definition: KabschAlgorithm.hpp:62
Vector< T > VectorType
The vector type used for the centroids and singular-value vectors.
Definition: KabschAlgorithm.hpp:70
bool align(const MatrixExpression< M1 > &points, const MatrixExpression< M2 > &ref_points, bool do_center=true, std::size_t max_svd_iter=0)
Computes the rigid body transformation that aligns a set of $N$-dimensional points points with a corresponding set of reference points ref_points.
Definition: KabschAlgorithm.hpp:164
const MatrixType & getTransform() const
Returns the rigid-body transformation produced by the most recent successful align() call.
Definition: KabschAlgorithm.hpp:206
Matrix< T > MatrixType
The matrix type used for the transformation, the covariance matrix and the working buffers.
Definition: KabschAlgorithm.hpp:68
bool align(const MatrixExpression< M1 > &points, const MatrixExpression< M2 > &ref_points, const VectorExpression< V > &weights, bool do_center=true, std::size_t max_svd_iter=0)
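Computes the rigid body transformation that aligns a set of $N$-dimensional points points with a corresponding set of reference points ref_points, weighting each point pair by the matching entry of weights.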
Definition: KabschAlgorithm.hpp:89
T ValueType
The scalar value type.
Definition: KabschAlgorithm.hpp:66
Vector-expression proxy that views a single column of an underlying matrix.
Definition: MatrixProxy.hpp:320
CRTP base class for all matrix expression types.
Definition: Expression.hpp:104
Vector-expression proxy that views a single row of an underlying matrix.
Definition: MatrixProxy.hpp:53
Dynamically-sized dense row-major matrix with configurable underlying storage.
Definition: Matrix.hpp:455
Constant vector expression in which every element equals the same scalar value.
Definition: Vector.hpp:2625
CRTP base class for all vector expression types.
Definition: Expression.hpp:66
Dynamically-sized dense vector with configurable underlying storage.
Definition: Vector.hpp:430
Constant vector expression whose elements are all zero.
Definition: Vector.hpp:2312
constexpr unsigned int T
Specifies Hydrogen (Tritium).
Definition: AtomType.hpp:67
MatrixTranspose< E > trans(MatrixExpression< E > &e)
Returns a mutable Math::MatrixTranspose view of the matrix expression e.
Definition: MatrixExpression.hpp:1692
MatrixColumn< M > column(MatrixExpression< M > &e, typename MatrixColumn< M >::SizeType j)
Returns a mutable column proxy for column j of the matrix expression e.
Definition: MatrixProxy.hpp:1259
E::ValueType det(const MatrixExpression< E > &e)
Returns the determinant of the matrix expression e.
Definition: Matrix.hpp:2987
bool svDecompose(MatrixExpression< A > &a, VectorExpression< W > &w, MatrixExpression< V > &v, std::size_t max_iter=0)
Computes the Singular Value Decomposition [WSVD] of an $M \times N$-dimensional matrix a.
Definition: SVDecomposition.hpp:70
MatrixRange< E > range(MatrixExpression< E > &e, const typename MatrixRange< E >::RangeType &r1, const typename MatrixRange< E >::RangeType &r2)
Returns a mutable matrix range proxy viewing rows in r1 and columns in r2 of e.
Definition: MatrixProxy.hpp:1288
Matrix1VectorBinaryTraits< E1, E2, MatrixVectorProduct< E1, E2 > >::ResultType prod(const MatrixExpression< E1 > &e1, const VectorExpression< E2 > &e2)
Returns the matrix-vector product as a vector expression (named-function form of operator*).
Definition: MatrixExpression.hpp:1480
The namespace of the Chemical Data Processing Library.
Trait that resolves the common arithmetic type of T1 and T2 via std::common_type.
Definition: CommonType.hpp:46
std::common_type< T1, T2 >::type Type
The common type.
Definition: CommonType.hpp:49