Chemical Data Processing Library C++ API - Version 1.0.0
KabschAlgorithm.hpp
Go to the documentation of this file.
1 /*
2  * KabschAlgorithm.hpp
3  *
4  * Copyright (C) 2010-2012 Thomas Seidel <thomas.seidel@univie.ac.at>
5  *
6  * This library is free software; you can redistribute it and/or
7  * modify it under the terms of the GNU Lesser General Public
8  * License as published by the Free Software Foundation; either
9  * version 2 of the License, or (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
14  * Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library; see the file COPYING. If not, write to
18  * the Free Software Foundation, Inc., 59 Temple Place - Suite 330,
19  * Boston, MA 02111-1307, USA.
20  */
21 
27 #ifndef CDPL_MATH_KABSCHALGORITHM_HPP
28 #define CDPL_MATH_KABSCHALGORITHM_HPP
29 
30 #include <cstddef>
31 
32 #include "CDPL/Math/Check.hpp"
33 #include "CDPL/Math/TypeTraits.hpp"
34 #include "CDPL/Math/Matrix.hpp"
35 #include "CDPL/Math/Vector.hpp"
39 #include "CDPL/Base/Exceptions.hpp"
40 
41 
42 namespace CDPL
43 {
44 
45  namespace Math
46  {
47 
60  template <typename T>
62  {
63 
64  public:
65  typedef T ValueType;
68 
85  template <typename M1, typename M2, typename V>
86  bool align(const MatrixExpression<M1>& points, const MatrixExpression<M2>& ref_points, const VectorExpression<V>& weights,
87  bool do_center = true, std::size_t max_svd_iter = 0)
88  {
89 
91  typename V::SizeType>::Type SizeType;
92 
93  SizeType dim = points().getSize1();
94  SizeType num_pts = points().getSize2();
95 
96  CDPL_MATH_CHECK(dim == SizeType(ref_points().getSize1()) && num_pts == SizeType(ref_points().getSize2()),
97  "KabschAlgorithm: Point-sets of different size", Base::SizeError);
98 
99  CDPL_MATH_CHECK(num_pts == SizeType(weights().getSize()),
100  "KabschAlgorithm: Number of points != number of weights", Base::SizeError);
101 
102  ValueType w_sum = ValueType();
103 
104  for (SizeType i = 0; i < num_pts; i++) {
105  CDPL_MATH_CHECK(ValueType(weights()(i)) >= ValueType(), "KabschAlgorithm: weights must be non-negative entries", Base::ValueError);
106  w_sum += weights()(i);
107  }
108 
109  CDPL_MATH_CHECK(w_sum > ValueType(), "KabschAlgorithm: weights must contain some positive entry", Base::ValueError);
110 
111  if (do_center) {
112  prod(points, weights, centroid1);
113  prod(ref_points, weights, centroid2);
114 
115  centroid1 /= w_sum;
116  centroid2 /= w_sum;
117 
118  tmpPoints.resize(dim, num_pts, false);
119  tmpPoints.assign(points);
120 
121  tmpRefPoints.resize(dim, num_pts, false);
122  tmpRefPoints.assign(ref_points);
123 
124  for (SizeType i = 0; i < num_pts; i++) {
125  column(tmpPoints, i).minusAssign(centroid1) *= weights()(i) / w_sum;
126  column(tmpRefPoints, i).minusAssign(centroid2);
127  }
128 
129  } else {
130  tmpPoints.resize(dim, num_pts, false);
131  tmpPoints.assign(points);
132 
133  for (SizeType i = 0; i < num_pts; i++)
134  column(tmpPoints, i) *= weights()(i) / w_sum;
135  }
136 
137  covarMatrix.resize(dim, dim, false);
138 
139  if (do_center)
140  prod(tmpPoints, trans(tmpRefPoints), covarMatrix);
141  else
142  prod(tmpPoints, trans(ref_points), covarMatrix);
143 
144  return align(dim, do_center, max_svd_iter);
145  }
146 
160  template <typename M1, typename M2>
161  bool align(const MatrixExpression<M1>& points, const MatrixExpression<M2>& ref_points,
162  bool do_center = true, std::size_t max_svd_iter = 0)
163  {
164 
166 
167  SizeType dim = points().getSize1();
168  SizeType num_pts = points().getSize2();
169 
170  CDPL_MATH_CHECK(dim == SizeType(ref_points().getSize1()) && num_pts == SizeType(ref_points().getSize2()),
171  "KabschAlgorithm: Point-sets of different size", Base::SizeError);
172 
173  if (do_center) {
174  prod(points, ScalarVector<ValueType>(num_pts, ValueType(1) / num_pts), centroid1);
175  prod(ref_points, ScalarVector<ValueType>(num_pts, ValueType(1) / num_pts), centroid2);
176 
177  tmpPoints.resize(dim, num_pts, false);
178  tmpPoints.assign(points);
179 
180  tmpRefPoints.resize(dim, num_pts, false);
181  tmpRefPoints.assign(ref_points);
182 
183  for (SizeType i = 0; i < num_pts; i++) {
184  column(tmpPoints, i).minusAssign(centroid1);
185  column(tmpRefPoints, i).minusAssign(centroid2);
186  }
187  }
188 
189  covarMatrix.resize(dim, dim, false);
190 
191  if (do_center)
192  prod(tmpPoints, trans(tmpRefPoints), covarMatrix);
193  else
194  prod(points, trans(ref_points), covarMatrix);
195 
196  return align(dim, do_center, max_svd_iter);
197  }
198 
199  const MatrixType& getTransform() const
200  {
201  return transform;
202  }
203 
204  private:
205  template <typename SizeType>
206  bool align(SizeType dim, bool do_center, std::size_t max_svd_iter)
207  {
208  svdW.resize(dim);
209  svdV.resize(dim, dim, false);
210 
211  if (!svDecompose(covarMatrix, svdW, svdV, max_svd_iter))
212  return false;
213 
214  if (det(prod(covarMatrix, trans(svdV))) < ValueType())
215  column(svdV, dim - 1) *= -ValueType(1);
216 
217  SizeType xform_dim = dim + 1;
218 
219  transform.resize(xform_dim, xform_dim, false);
220 
221  range(transform, 0, dim, 0, dim).assign(prod(svdV, trans(covarMatrix)));
222 
223  MatrixRow<MatrixType> last_row(transform, dim);
224  MatrixColumn<MatrixType> last_col(transform, dim);
225 
226  range(last_row, 0, dim).assign(ZeroVector<ValueType>(dim));
227 
228  if (do_center)
229  range(last_col, 0, dim).assign(centroid2 - prod(range(transform, 0, dim, 0, dim), centroid1));
230  else
231  range(last_col, 0, dim).assign(ZeroVector<ValueType>(dim));
232 
233  transform(dim, dim) = ValueType(1);
234 
235  return true;
236  }
237 
238  MatrixType transform;
239  MatrixType tmpPoints;
240  MatrixType tmpRefPoints;
241  MatrixType covarMatrix;
242  VectorType svdW;
243  MatrixType svdV;
244  VectorType centroid1;
245  VectorType centroid2;
246  };
247  } // namespace Math
248 } // namespace CDPL
249 
250 #endif // CDPL_MATH_KABSCHALGORITHM_HPP
CDPL::Math::KabschAlgorithm::ValueType
T ValueType
Definition: KabschAlgorithm.hpp:65
CDPL::Math::Vector< T >
CDPL::Math::CommonType::Type
std::common_type< T1, T2 >::type Type
Definition: CommonType.hpp:43
CDPL::Math::trans
MatrixTranspose< E > trans(MatrixExpression< E > &e)
Definition: MatrixExpression.hpp:941
CDPL::Math::ScalarVector
Definition: Vector.hpp:1470
CDPL::Math::column
MatrixColumn< M > column(MatrixExpression< M > &e, typename MatrixColumn< M >::SizeType j)
Definition: MatrixProxy.hpp:730
CDPL::Math::KabschAlgorithm::MatrixType
Matrix< T > MatrixType
Definition: KabschAlgorithm.hpp:66
CDPL::Math::VectorExpression
Definition: Expression.hpp:54
CDPL::Math::Matrix::assign
Matrix & assign(const MatrixExpression< E > &e)
Definition: Matrix.hpp:459
CDPL::Math::MatrixExpression
Definition: Expression.hpp:76
CDPL_MATH_CHECK
#define CDPL_MATH_CHECK(expr, msg, e)
Definition: Check.hpp:36
CDPL::Math::CommonType
Definition: CommonType.hpp:41
VectorProxy.hpp
Definition of vector proxy types.
CDPL::Math::det
E::ValueType det(const MatrixExpression< E > &e)
Definition: Matrix.hpp:1721
CDPL::Base::SizeError
Thrown to indicate that the size of a (multidimensional) array is not correct.
Definition: Base/Exceptions.hpp:133
CDPL::Math::MatrixRow
Definition: MatrixProxy.hpp:49
CDPL::Math::MatrixColumn
Definition: MatrixProxy.hpp:196
TypeTraits.hpp
Definition of type traits.
CDPL::Math::prod
Matrix1VectorBinaryTraits< E1, E2, MatrixVectorProduct< E1, E2 > >::ResultType prod(const MatrixExpression< E1 > &e1, const VectorExpression< E2 > &e2)
Definition: MatrixExpression.hpp:833
CDPL::Math::Matrix::resize
void resize(SizeType m, SizeType n, bool preserve=true, const ValueType &v=ValueType())
Definition: Matrix.hpp:519
CDPL::Math::ZeroVector
Definition: Vector.hpp:1292
CDPL::Base::ValueError
Thrown to indicate errors caused by some invalid value.
Definition: Base/Exceptions.hpp:76
SVDecomposition.hpp
Implementation of matrix singular value decomposition and associated operations.
CDPL::Math::KabschAlgorithm::align
bool align(const MatrixExpression< M1 > &points, const MatrixExpression< M2 > &ref_points, bool do_center=true, std::size_t max_svd_iter=0)
Computes the rigid body transformation that aligns a set of $K$-dimensional points `points` with a corresponding set of reference points `ref_points`.
Definition: KabschAlgorithm.hpp:161
CDPL::Chem::AtomType::T
const unsigned int T
Specifies Hydrogen (Tritium).
Definition: AtomType.hpp:67
CDPL::Math::KabschAlgorithm::align
bool align(const MatrixExpression< M1 > &points, const MatrixExpression< M2 > &ref_points, const VectorExpression< V > &weights, bool do_center=true, std::size_t max_svd_iter=0)
Computes the rigid body transformation that aligns a set of $K$-dimensional points `points` with a corresponding set of reference points `ref_points`, using the per-point weights `weights`.
Definition: KabschAlgorithm.hpp:86
CDPL::Math::KabschAlgorithm
Implementation of the Kabsch algorithm [KABA].
Definition: KabschAlgorithm.hpp:62
Exceptions.hpp
Definition of exception classes.
CDPL::Math::KabschAlgorithm::VectorType
Vector< T > VectorType
Definition: KabschAlgorithm.hpp:67
CDPL::Math::Matrix< T >
CDPL
The namespace of the Chemical Data Processing Library.
CDPL::Math::Vector::resize
void resize(SizeType n, const ValueType &v=ValueType())
Definition: Vector.hpp:491
MatrixProxy.hpp
Definition of matrix proxy types.
Matrix.hpp
Definition of matrix data types.
CDPL::Math::KabschAlgorithm::getTransform
const MatrixType & getTransform() const
Definition: KabschAlgorithm.hpp:199
Check.hpp
Definition of various preprocessor macros for error checking.
CDPL::Math::range
MatrixRange< E > range(MatrixExpression< E > &e, const typename MatrixRange< E >::RangeType &r1, const typename MatrixRange< E >::RangeType &r2)
Definition: MatrixProxy.hpp:744
CDPL::Math::svDecompose
bool svDecompose(MatrixExpression< A > &a, VectorExpression< W > &w, MatrixExpression< V > &v, std::size_t max_iter=0)
Computes the Singular Value Decomposition [WSVD] of an $M \times N$-dimensional matrix a.
Definition: SVDecomposition.hpp:70
Vector.hpp
Definition of vector data types.