Chemical Data Processing Library C++ API - Version 1.2.0
KabschAlgorithm.hpp
Go to the documentation of this file.
1 /*
2  * KabschAlgorithm.hpp
3  *
4  * Copyright (C) 2003 Thomas Seidel <thomas.seidel@univie.ac.at>
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library; see the file COPYING. If not, write to
18  * the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
19  * Boston, MA 02111-1307, USA.
20  */
21 
27 #ifndef CDPL_MATH_KABSCHALGORITHM_HPP
28 #define CDPL_MATH_KABSCHALGORITHM_HPP
29 
30 #include <cstddef>
31 
32 #include "CDPL/Math/Check.hpp"
33 #include "CDPL/Math/TypeTraits.hpp"
34 #include "CDPL/Math/Matrix.hpp"
35 #include "CDPL/Math/Vector.hpp"
39 #include "CDPL/Base/Exceptions.hpp"
40 
41 
42 namespace CDPL
43 {
44 
45  namespace Math
46  {
47 
60  template <typename T>
62  {
63 
64  public:
65  typedef T ValueType;
68 
85  template <typename M1, typename M2, typename V>
86  bool align(const MatrixExpression<M1>& points, const MatrixExpression<M2>& ref_points, const VectorExpression<V>& weights,
87  bool do_center = true, std::size_t max_svd_iter = 0)
88  {
89 
91  typename V::SizeType>::Type SizeType;
92 
93  SizeType dim = points().getSize1();
94  SizeType num_pts = points().getSize2();
95 
96  CDPL_MATH_CHECK(dim == SizeType(ref_points().getSize1()) && num_pts == SizeType(ref_points().getSize2()),
97  "KabschAlgorithm: Point-sets of different size", Base::SizeError);
98 
99  CDPL_MATH_CHECK(num_pts == SizeType(weights().getSize()),
100  "KabschAlgorithm: Number of points != number of weights", Base::SizeError);
101 
102  ValueType w_sum = ValueType();
103 
104  for (SizeType i = 0; i < num_pts; i++) {
105  CDPL_MATH_CHECK(ValueType(weights()(i)) >= ValueType(), "KabschAlgorithm: weights must be non-negative entries", Base::ValueError);
106  w_sum += weights()(i);
107  }
108 
109  CDPL_MATH_CHECK(w_sum > ValueType(), "KabschAlgorithm: weights must contain some positive entry", Base::ValueError);
110 
111  if (do_center) {
112  prod(points, weights, centroid1);
113  prod(ref_points, weights, centroid2);
114 
115  centroid1 /= w_sum;
116  centroid2 /= w_sum;
117 
118  tmpPoints.resize(dim, num_pts, false);
119  tmpPoints.assign(points);
120 
121  tmpRefPoints.resize(dim, num_pts, false);
122  tmpRefPoints.assign(ref_points);
123 
124  for (SizeType i = 0; i < num_pts; i++) {
125  column(tmpPoints, i).minusAssign(centroid1) *= weights()(i) / w_sum;
126  column(tmpRefPoints, i).minusAssign(centroid2);
127  }
128 
129  } else {
130  tmpPoints.resize(dim, num_pts, false);
131  tmpPoints.assign(points);
132 
133  for (SizeType i = 0; i < num_pts; i++)
134  column(tmpPoints, i) *= weights()(i) / w_sum;
135  }
136 
137  covarMatrix.resize(dim, dim, false);
138 
139  if (do_center)
140  prod(tmpPoints, trans(tmpRefPoints), covarMatrix);
141  else
142  prod(tmpPoints, trans(ref_points), covarMatrix);
143 
144  return align(dim, do_center, max_svd_iter);
145  }
146 
160  template <typename M1, typename M2>
161  bool align(const MatrixExpression<M1>& points, const MatrixExpression<M2>& ref_points,
162  bool do_center = true, std::size_t max_svd_iter = 0)
163  {
164 
166 
167  SizeType dim = points().getSize1();
168  SizeType num_pts = points().getSize2();
169 
170  CDPL_MATH_CHECK(dim == SizeType(ref_points().getSize1()) && num_pts == SizeType(ref_points().getSize2()),
171  "KabschAlgorithm: Point-sets of different size", Base::SizeError);
172 
173  if (do_center) {
174  prod(points, ScalarVector<ValueType>(num_pts, ValueType(1) / num_pts), centroid1);
175  prod(ref_points, ScalarVector<ValueType>(num_pts, ValueType(1) / num_pts), centroid2);
176 
177  tmpPoints.resize(dim, num_pts, false);
178  tmpPoints.assign(points);
179 
180  tmpRefPoints.resize(dim, num_pts, false);
181  tmpRefPoints.assign(ref_points);
182 
183  for (SizeType i = 0; i < num_pts; i++) {
184  column(tmpPoints, i).minusAssign(centroid1);
185  column(tmpRefPoints, i).minusAssign(centroid2);
186  }
187  }
188 
189  covarMatrix.resize(dim, dim, false);
190 
191  if (do_center)
192  prod(tmpPoints, trans(tmpRefPoints), covarMatrix);
193  else
194  prod(points, trans(ref_points), covarMatrix);
195 
196  return align(dim, do_center, max_svd_iter);
197  }
198 
199  const MatrixType& getTransform() const
200  {
201  return transform;
202  }
203 
204  private:
205  template <typename SizeType>
206  bool align(SizeType dim, bool do_center, std::size_t max_svd_iter)
207  {
208  svdW.resize(dim);
209  svdV.resize(dim, dim, false);
210 
211  if (!svDecompose(covarMatrix, svdW, svdV, max_svd_iter))
212  return false;
213 
214  if (det(prod(covarMatrix, trans(svdV))) < ValueType())
215  column(svdV, dim - 1) *= -ValueType(1);
216 
217  SizeType xform_dim = dim + 1;
218 
219  transform.resize(xform_dim, xform_dim, false);
220 
221  range(transform, 0, dim, 0, dim).assign(prod(svdV, trans(covarMatrix)));
222 
223  MatrixRow<MatrixType> last_row(transform, dim);
224  MatrixColumn<MatrixType> last_col(transform, dim);
225 
226  range(last_row, 0, dim).assign(ZeroVector<ValueType>(dim));
227 
228  if (do_center)
229  range(last_col, 0, dim).assign(centroid2 - prod(range(transform, 0, dim, 0, dim), centroid1));
230  else
231  range(last_col, 0, dim).assign(ZeroVector<ValueType>(dim));
232 
233  transform(dim, dim) = ValueType(1);
234 
235  return true;
236  }
237 
238  MatrixType transform;
239  MatrixType tmpPoints;
240  MatrixType tmpRefPoints;
241  MatrixType covarMatrix;
242  VectorType svdW;
243  MatrixType svdV;
244  VectorType centroid1;
245  VectorType centroid2;
246  };
247  } // namespace Math
248 } // namespace CDPL
249 
250 #endif // CDPL_MATH_KABSCHALGORITHM_HPP
Definition of exception classes.
Definition of various preprocessor macros for error checking.
#define CDPL_MATH_CHECK(expr, msg, e)
Definition: Check.hpp:36
Definition of matrix proxy types.
Definition of matrix data types.
Implementation of matrix singular value decomposition and associated operations.
Definition of type traits.
Definition of vector proxy types.
Definition of vector data types.
Thrown to indicate that the size of a (multidimensional) array is not correct.
Definition: Base/Exceptions.hpp:133
Thrown to indicate errors caused by some invalid value.
Definition: Base/Exceptions.hpp:76
Implementation of the Kabsch algorithm [KABA].
Definition: KabschAlgorithm.hpp:62
Vector< T > VectorType
Definition: KabschAlgorithm.hpp:67
bool align(const MatrixExpression< M1 > &points, const MatrixExpression< M2 > &ref_points, bool do_center=true, std::size_t max_svd_iter=0)
Computes the rigid body transformation that aligns a set of $K$-dimensional points `points` with a corresponding set of reference points `ref_points`.
Definition: KabschAlgorithm.hpp:161
const MatrixType & getTransform() const
Definition: KabschAlgorithm.hpp:199
Matrix< T > MatrixType
Definition: KabschAlgorithm.hpp:66
bool align(const MatrixExpression< M1 > &points, const MatrixExpression< M2 > &ref_points, const VectorExpression< V > &weights, bool do_center=true, std::size_t max_svd_iter=0)
Computes the rigid body transformation that aligns a set of $K$-dimensional points `points` with a corresponding set of reference points `ref_points`.
Definition: KabschAlgorithm.hpp:86
T ValueType
Definition: KabschAlgorithm.hpp:65
Definition: MatrixProxy.hpp:196
Definition: Expression.hpp:76
Definition: MatrixProxy.hpp:49
Definition: Matrix.hpp:280
Definition: Vector.hpp:1470
Definition: Expression.hpp:54
Definition: Vector.hpp:258
Definition: Vector.hpp:1292
constexpr unsigned int T
Specifies Hydrogen (Tritium).
Definition: AtomType.hpp:67
MatrixTranspose< E > trans(MatrixExpression< E > &e)
Definition: MatrixExpression.hpp:941
MatrixColumn< M > column(MatrixExpression< M > &e, typename MatrixColumn< M >::SizeType j)
Definition: MatrixProxy.hpp:730
E::ValueType det(const MatrixExpression< E > &e)
Definition: Matrix.hpp:1721
bool svDecompose(MatrixExpression< A > &a, VectorExpression< W > &w, MatrixExpression< V > &v, std::size_t max_iter=0)
Computes the Singular Value Decomposition [WSVD] of an $M \times N$-dimensional matrix `a`.
Definition: SVDecomposition.hpp:70
MatrixRange< E > range(MatrixExpression< E > &e, const typename MatrixRange< E >::RangeType &r1, const typename MatrixRange< E >::RangeType &r2)
Definition: MatrixProxy.hpp:744
Matrix1VectorBinaryTraits< E1, E2, MatrixVectorProduct< E1, E2 > >::ResultType prod(const MatrixExpression< E1 > &e1, const VectorExpression< E2 > &e2)
Definition: MatrixExpression.hpp:833
The namespace of the Chemical Data Processing Library.
Definition: CommonType.hpp:41
std::common_type< T1, T2 >::type Type
Definition: CommonType.hpp:43