BlasUtil.h 15.4 KB
Newer Older
LM's avatar
LM committed
1 2 3 4 5
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2009-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
//
Don Gagne's avatar
Don Gagne committed
6 7 8
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
LM's avatar
LM committed
9 10 11 12 13 14 15

#ifndef EIGEN_BLASUTIL_H
#define EIGEN_BLASUTIL_H

// This file contains many lightweight helper classes used to
// implement and control fast level 2 and level 3 BLAS-like routines.

Don Gagne's avatar
Don Gagne committed
16 17
namespace Eigen {

LM's avatar
LM committed
18 19 20
namespace internal {

// forward declarations
21
template<typename LhsScalar, typename RhsScalar, typename Index, typename DataMapper, int mr, int nr, bool ConjugateLhs=false, bool ConjugateRhs=false>
LM's avatar
LM committed
22 23
struct gebp_kernel;

24
template<typename Scalar, typename Index, typename DataMapper, int nr, int StorageOrder, bool Conjugate = false, bool PanelMode=false>
LM's avatar
LM committed
25 26
struct gemm_pack_rhs;

27
template<typename Scalar, typename Index, typename DataMapper, int Pack1, int Pack2, int StorageOrder, bool Conjugate = false, bool PanelMode = false>
LM's avatar
LM committed
28 29 30 31 32 33 34 35 36
struct gemm_pack_lhs;

template<
  typename Index,
  typename LhsScalar, int LhsStorageOrder, bool ConjugateLhs,
  typename RhsScalar, int RhsStorageOrder, bool ConjugateRhs,
  int ResStorageOrder>
struct general_matrix_matrix_product;

37 38 39
template<typename Index,
         typename LhsScalar, typename LhsMapper, int LhsStorageOrder, bool ConjugateLhs,
         typename RhsScalar, typename RhsMapper, bool ConjugateRhs, int Version=Specialized>
LM's avatar
LM committed
40 41 42 43 44 45 46
struct general_matrix_vector_product;


template<bool Conjugate> struct conj_if;

template<> struct conj_if<true> {
  template<typename T>
47
  inline T operator()(const T& x) const { return numext::conj(x); }
Don Gagne's avatar
Don Gagne committed
48
  template<typename T>
49
  inline T pconj(const T& x) const { return internal::pconj(x); }
LM's avatar
LM committed
50 51 52 53
};

// Conjugate == false: both members are the identity and hand back a
// reference to the very argument they were given (no copy is made).
template<>
struct conj_if<false>
{
  template<typename T>
  inline const T& operator()(const T& value) const
  {
    return value;
  }

  template<typename T>
  inline const T& pconj(const T& value) const
  {
    return value;
  }
};

// Generic implementation for custom complex types.
// The result type is taken from ScalarBinaryOpTraits<LhsScalar,RhsScalar>.
template<typename LhsScalar, typename RhsScalar, bool ConjLhs, bool ConjRhs>
struct conj_helper
{
  typedef typename ScalarBinaryOpTraits<LhsScalar,RhsScalar>::ReturnType Scalar;

  // Product of the two operands, each one first passed through conj_if
  // according to its own flag (ConjLhs for x, ConjRhs for y).
  EIGEN_STRONG_INLINE Scalar pmul(const LhsScalar& x, const RhsScalar& y) const
  {
    conj_if<ConjLhs> lhs_conj;
    conj_if<ConjRhs> rhs_conj;
    return lhs_conj(x) * rhs_conj(y);
  }

  // Adds pmul(x,y) to the accumulator c through padd.
  EIGEN_STRONG_INLINE Scalar pmadd(const LhsScalar& x, const RhsScalar& y, const Scalar& c) const
  {
    return padd(c, pmul(x, y));
  }
};

// Same type on both sides and no conjugation requested:
// simply forward to internal::pmadd / internal::pmul.
template<typename Scalar>
struct conj_helper<Scalar, Scalar, false, false>
{
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
  {
    return internal::pmadd(x, y, c);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  Scalar pmul(const Scalar& x, const Scalar& y) const
  {
    return internal::pmul(x, y);
  }
};

// std::complex operands, conjugating the right-hand side only: x * conj(y).
template<typename RealScalar>
struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, false, true>
{
  typedef std::complex<RealScalar> Scalar;

  // x * conj(y), written out on the real and imaginary components:
  //   real = xr*yr + xi*yi,  imag = xi*yr - xr*yi
  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
  {
    const RealScalar xr = numext::real(x);
    const RealScalar xi = numext::imag(x);
    const RealScalar yr = numext::real(y);
    const RealScalar yi = numext::imag(y);
    return Scalar(xr*yr + xi*yi, xi*yr - xr*yi);
  }

  // c + x * conj(y)
  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
  {
    return c + pmul(x, y);
  }
};

// std::complex operands, conjugating the left-hand side only: conj(x) * y.
template<typename RealScalar>
struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true, false>
{
  typedef std::complex<RealScalar> Scalar;

  // conj(x) * y, written out on the real and imaginary components:
  //   real = xr*yr + xi*yi,  imag = xr*yi - xi*yr
  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
  {
    const RealScalar xr = numext::real(x);
    const RealScalar xi = numext::imag(x);
    const RealScalar yr = numext::real(y);
    const RealScalar yi = numext::imag(y);
    return Scalar(xr*yr + xi*yi, xr*yi - xi*yr);
  }

  // c + conj(x) * y
  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
  {
    return c + pmul(x, y);
  }
};

// std::complex operands, conjugating both sides: conj(x) * conj(y).
template<typename RealScalar>
struct conj_helper<std::complex<RealScalar>, std::complex<RealScalar>, true, true>
{
  typedef std::complex<RealScalar> Scalar;

  // conj(x) * conj(y), written out on the real and imaginary components:
  //   real = xr*yr - xi*yi,  imag = -xr*yi - xi*yr
  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const Scalar& y) const
  {
    const RealScalar xr = numext::real(x);
    const RealScalar xi = numext::imag(x);
    const RealScalar yr = numext::real(y);
    const RealScalar yi = numext::imag(y);
    return Scalar(xr*yr - xi*yi, - xr*yi - xi*yr);
  }

  // c + conj(x) * conj(y)
  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const Scalar& y, const Scalar& c) const
  {
    return c + pmul(x, y);
  }
};

// Mixed product: std::complex on the left, real on the right.
// Only the complex operand can be conjugated (flag Conj).
template<typename RealScalar, bool Conj>
struct conj_helper<std::complex<RealScalar>, RealScalar, Conj, false>
{
  typedef std::complex<RealScalar> Scalar;

  // (Conj ? conj(x) : x) * y
  EIGEN_STRONG_INLINE Scalar pmul(const Scalar& x, const RealScalar& y) const
  {
    conj_if<Conj> maybe_conj;
    return maybe_conj(x) * y;
  }

  // Adds pmul(x,y) to the accumulator c through padd.
  EIGEN_STRONG_INLINE Scalar pmadd(const Scalar& x, const RealScalar& y, const Scalar& c) const
  {
    return padd(c, pmul(x, y));
  }
};

// Mixed product: real on the left, std::complex on the right.
// Only the complex operand can be conjugated (flag Conj).
template<typename RealScalar, bool Conj>
struct conj_helper<RealScalar, std::complex<RealScalar>, false, Conj>
{
  typedef std::complex<RealScalar> Scalar;

  // x * (Conj ? conj(y) : y)
  EIGEN_STRONG_INLINE Scalar pmul(const RealScalar& x, const Scalar& y) const
  {
    conj_if<Conj> maybe_conj;
    return x * maybe_conj(y);
  }

  // Adds pmul(x,y) to the accumulator c through padd.
  EIGEN_STRONG_INLINE Scalar pmadd(const RealScalar& x, const Scalar& y, const Scalar& c) const
  {
    return padd(c, pmul(x, y));
  }
};

template<typename From,typename To> struct get_factor {
127
  EIGEN_DEVICE_FUNC static EIGEN_STRONG_INLINE To run(const From& x) { return To(x); }
LM's avatar
LM committed
128 129 130
};

template<typename Scalar> struct get_factor<Scalar,typename NumTraits<Scalar>::Real> {
131
  EIGEN_DEVICE_FUNC
Don Gagne's avatar
Don Gagne committed
132
  static EIGEN_STRONG_INLINE typename NumTraits<Scalar>::Real run(const Scalar& x) { return numext::real(x); }
LM's avatar
LM committed
133 134
};

135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189

// Thin wrapper around a raw pointer, addressed with a single index.
// The pointer is stored as given; this class never allocates or frees it.
template<typename Scalar, typename Index>
class BlasVectorMapper {
  public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasVectorMapper(Scalar *data) : m_data(data) {}

  // Coefficient at offset i, returned by value.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
    return *(m_data + i);
  }

  // Loads a Packet starting at offset i through ploadt, using the
  // alignment given as template argument.
  template <typename Packet, int AlignmentType>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet load(Index i) const {
    Scalar* const from = m_data + i;
    return ploadt<Packet, AlignmentType>(from);
  }

  // True when the address of element i is a multiple of sizeof(Packet).
  template <typename Packet>
  EIGEN_DEVICE_FUNC bool aligned(Index i) const {
    const UIntPtr address = UIntPtr(m_data + i);
    return (address % sizeof(Packet)) == 0;
  }

  protected:
  Scalar* m_data;
};

// Thin wrapper around a raw pointer, addressed with a single index.
// AlignmentType is forwarded to ploadt/pstoret for every packet access.
// The pointer is stored as given; this class never allocates or frees it.
template<typename Scalar, typename Index, int AlignmentType>
class BlasLinearMapper {
  public:
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlasLinearMapper(Scalar *data) : m_data(data) {}

  // Issues internal::prefetch on the address of element i.
  // The offset is an Index, like in every other accessor of this class:
  // taking an int here would silently narrow offsets when Index is wider
  // than int. Callers passing an int keep working through the usual
  // integral conversion.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void prefetch(Index i) const {
    internal::prefetch(&operator()(i));
  }

  // Reference to the coefficient at offset i.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar& operator()(Index i) const {
    return m_data[i];
  }

  // Loads a full packet starting at offset i.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
    return ploadt<Packet, AlignmentType>(m_data + i);
  }

  // Loads a half packet starting at offset i.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i) const {
    return ploadt<HalfPacket, AlignmentType>(m_data + i);
  }

  // Stores packet p starting at offset i.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void storePacket(Index i, const Packet &p) const {
    pstoret<Scalar, Packet, AlignmentType>(m_data + i, p);
  }

  protected:
  Scalar *m_data;
};

LM's avatar
LM committed
190
// Lightweight helper class to access matrix coefficients.
191 192
template<typename Scalar, typename Index, int StorageOrder, int AlignmentType = Unaligned>
class blas_data_mapper {
LM's avatar
LM committed
193
  public:
194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248
  typedef typename packet_traits<Scalar>::type Packet;
  typedef typename packet_traits<Scalar>::half HalfPacket;

  typedef BlasLinearMapper<Scalar, Index, AlignmentType> LinearMapper;
  typedef BlasVectorMapper<Scalar, Index> VectorMapper;

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE blas_data_mapper(Scalar* data, Index stride) : m_data(data), m_stride(stride) {}

  EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>
  getSubMapper(Index i, Index j) const {
    return blas_data_mapper<Scalar, Index, StorageOrder, AlignmentType>(&operator()(i, j), m_stride);
  }

  EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
    return LinearMapper(&operator()(i, j));
  }

  EIGEN_DEVICE_FUNC  EIGEN_ALWAYS_INLINE VectorMapper getVectorMapper(Index i, Index j) const {
    return VectorMapper(&operator()(i, j));
  }


  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Scalar& operator()(Index i, Index j) const {
    return m_data[StorageOrder==RowMajor ? j + i*m_stride : i + j*m_stride];
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, Index j) const {
    return ploadt<Packet, AlignmentType>(&operator()(i, j));
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE HalfPacket loadHalfPacket(Index i, Index j) const {
    return ploadt<HalfPacket, AlignmentType>(&operator()(i, j));
  }

  template<typename SubPacket>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void scatterPacket(Index i, Index j, const SubPacket &p) const {
    pscatter<Scalar, SubPacket>(&operator()(i, j), p, m_stride);
  }

  template<typename SubPacket>
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SubPacket gatherPacket(Index i, Index j) const {
    return pgather<Scalar, SubPacket>(&operator()(i, j), m_stride);
  }

  EIGEN_DEVICE_FUNC const Index stride() const { return m_stride; }
  EIGEN_DEVICE_FUNC const Scalar* data() const { return m_data; }

  EIGEN_DEVICE_FUNC Index firstAligned(Index size) const {
    if (UIntPtr(m_data)%sizeof(Scalar)) {
      return -1;
    }
    return internal::first_default_aligned(m_data, size);
  }

LM's avatar
LM committed
249
  protected:
250 251
  Scalar* EIGEN_RESTRICT m_data;
  const Index m_stride;
LM's avatar
LM committed
252 253 254 255
};

// lightweight helper class to access matrix coefficients (const version)
template<typename Scalar, typename Index, int StorageOrder>
256
class const_blas_data_mapper : public blas_data_mapper<const Scalar, Index, StorageOrder> {
LM's avatar
LM committed
257
  public:
258 259 260 261 262
  EIGEN_ALWAYS_INLINE const_blas_data_mapper(const Scalar *data, Index stride) : blas_data_mapper<const Scalar, Index, StorageOrder>(data, stride) {}

  EIGEN_ALWAYS_INLINE const_blas_data_mapper<Scalar, Index, StorageOrder> getSubMapper(Index i, Index j) const {
    return const_blas_data_mapper<Scalar, Index, StorageOrder>(&(this->operator()(i, j)), this->m_stride);
  }
LM's avatar
LM committed
263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
};


/* Helper class to analyze the factors of a Product expression.
 * In particular it allows to pop out operator-, scalar multiples,
 * and conjugate */
template<typename XprType> struct blas_traits
{
  typedef typename traits<XprType>::Scalar Scalar;
  typedef const XprType& ExtractType;
  typedef XprType _ExtractType;
  enum {
    IsComplex = NumTraits<Scalar>::IsComplex,
    IsTransposed = false,
    NeedToConjugate = false,
    HasUsableDirectAccess = (    (int(XprType::Flags)&DirectAccessBit)
                              && (   bool(XprType::IsVectorAtCompileTime)
                                  || int(inner_stride_at_compile_time<XprType>::ret) == 1)
                             ) ?  1 : 0
  };
  typedef typename conditional<bool(HasUsableDirectAccess),
    ExtractType,
    typename _ExtractType::PlainObject
    >::type DirectLinearAccessType;
Don Gagne's avatar
Don Gagne committed
287
  static inline ExtractType extract(const XprType& x) { return x; }
LM's avatar
LM committed
288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303
  static inline const Scalar extractScalarFactor(const XprType&) { return Scalar(1); }
};

// pop conjugate
// Strips a scalar_conjugate_op wrapper: extraction continues on the nested
// expression, NeedToConjugate is toggled (for complex scalars) and the scalar
// factor of the nested expression is conjugated.
template<typename Scalar, typename NestedXpr>
struct blas_traits<CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> >
  : blas_traits<NestedXpr>
{
  typedef blas_traits<NestedXpr> Base;
  typedef CwiseUnaryOp<scalar_conjugate_op<Scalar>, NestedXpr> XprType;
  typedef typename Base::ExtractType ExtractType;

  enum {
    IsComplex = NumTraits<Scalar>::IsComplex,
    // Two conjugations cancel out; for non-complex scalars this stays 0.
    NeedToConjugate = Base::NeedToConjugate ? 0 : IsComplex
  };

  static inline ExtractType extract(const XprType& x)
  {
    return Base::extract(x.nestedExpression());
  }

  static inline Scalar extractScalarFactor(const XprType& x)
  {
    return conj(Base::extractScalarFactor(x.nestedExpression()));
  }
};

// pop scalar multiple
309 310
template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> >
LM's avatar
LM committed
311 312 313
 : blas_traits<NestedXpr>
{
  typedef blas_traits<NestedXpr> Base;
314
  typedef CwiseBinaryOp<scalar_product_op<Scalar>, const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain>, NestedXpr> XprType;
LM's avatar
LM committed
315
  typedef typename Base::ExtractType ExtractType;
316
  static inline ExtractType extract(const XprType& x) { return Base::extract(x.rhs()); }
LM's avatar
LM committed
317
  static inline Scalar extractScalarFactor(const XprType& x)
318
  { return x.lhs().functor().m_other * Base::extractScalarFactor(x.rhs()); }
LM's avatar
LM committed
319
};
320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
// Case "xpr * constant": extraction continues on the left operand and the
// constant is multiplied (on the right) into the scalar factor.
template<typename Scalar, typename NestedXpr, typename Plain>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>,
                                 NestedXpr,
                                 const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > >
  : blas_traits<NestedXpr>
{
  typedef blas_traits<NestedXpr> Base;
  typedef CwiseBinaryOp<scalar_product_op<Scalar>,
                        NestedXpr,
                        const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain> > XprType;
  typedef typename Base::ExtractType ExtractType;

  static inline ExtractType extract(const XprType& x)
  {
    return Base::extract(x.lhs());
  }

  static inline Scalar extractScalarFactor(const XprType& x)
  {
    return Base::extractScalarFactor(x.lhs()) * x.rhs().functor().m_other;
  }
};
// Case "constant * constant": nothing is popped, the traits are simply those
// of the left constant expression.
// NOTE(review): presumably this exists to disambiguate between the
// "constant * xpr" and "xpr * constant" specializations, which would both
// match when the two operands are constants -- confirm.
template<typename Scalar, typename Plain1, typename Plain2>
struct blas_traits<CwiseBinaryOp<scalar_product_op<Scalar>,
                                 const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1>,
                                 const CwiseNullaryOp<scalar_constant_op<Scalar>,Plain2> > >
  : blas_traits<CwiseNullaryOp<scalar_constant_op<Scalar>,Plain1> >
{
};
LM's avatar
LM committed
336 337 338 339 340 341 342 343 344

// pop opposite
// Strips a scalar_opposite_op wrapper: extraction continues on the nested
// expression and the sign is folded into the scalar factor.
template<typename Scalar, typename NestedXpr>
struct blas_traits<CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> >
  : blas_traits<NestedXpr>
{
  typedef blas_traits<NestedXpr> Base;
  typedef CwiseUnaryOp<scalar_opposite_op<Scalar>, NestedXpr> XprType;
  typedef typename Base::ExtractType ExtractType;

  static inline ExtractType extract(const XprType& x)
  {
    return Base::extract(x.nestedExpression());
  }

  static inline Scalar extractScalarFactor(const XprType& x)
  {
    return - Base::extractScalarFactor(x.nestedExpression());
  }
};

// pop/push transpose
// The nested expression is analyzed first, then the transposition is applied
// again on top of whatever was extracted. IsTransposed is toggled.
template<typename NestedXpr>
struct blas_traits<Transpose<NestedXpr> >
  : blas_traits<NestedXpr>
{
  typedef typename NestedXpr::Scalar Scalar;
  typedef blas_traits<NestedXpr> Base;
  typedef Transpose<NestedXpr> XprType;

  // const to get rid of a compile error; anyway blas traits are only used on the RHS
  typedef Transpose<const typename Base::_ExtractType>  ExtractType;
  typedef Transpose<const typename Base::_ExtractType> _ExtractType;

  // ExtractType when the nested expression has usable direct access,
  // the plain object type otherwise.
  typedef typename conditional<
    bool(Base::HasUsableDirectAccess),
    ExtractType,
    typename ExtractType::PlainObject
  >::type DirectLinearAccessType;

  enum {
    // Two transpositions cancel out.
    IsTransposed = Base::IsTransposed ? 0 : 1
  };

  static inline ExtractType extract(const XprType& x)
  {
    return ExtractType(Base::extract(x.nestedExpression()));
  }

  static inline Scalar extractScalarFactor(const XprType& x)
  {
    return Base::extractScalarFactor(x.nestedExpression());
  }
};

// A const-qualified expression type has the same traits as the
// unqualified type.
template<typename T>
struct blas_traits<const T> : blas_traits<T>
{
};

// Implementation helper of extract_data().
// Primary template: used when blas_traits<T> reports usable direct access;
// returns the data pointer of the expression extracted by blas_traits.
template<typename T, bool HasUsableDirectAccess=blas_traits<T>::HasUsableDirectAccess>
struct extract_data_selector
{
  static const typename T::Scalar* run(const T& m)
  {
    typedef blas_traits<T> Traits;
    return Traits::extract(m).data();
  }
};

// No usable direct access: there is no pointer to hand out, so a null
// pointer is returned and the argument is ignored.
template<typename T>
struct extract_data_selector<T, false>
{
  static typename T::Scalar* run(const T&)
  {
    return 0;
  }
};

template<typename T> const typename T::Scalar* extract_data(const T& m)
{
  return extract_data_selector<T>::run(m);
}

} // end namespace internal

Don Gagne's avatar
Don Gagne committed
396 397
} // end namespace Eigen

LM's avatar
LM committed
398
#endif // EIGEN_BLASUTIL_H