Basix
Loading...
Searching...
No Matches
math.h
1// Copyright (C) 2021 Igor Baratta
2//
3// This file is part of DOLFINx (https://www.fenicsproject.org)
4//
5// SPDX-License-Identifier: LGPL-3.0-or-later
6
7#pragma once
8
#include "mdspan.hpp"
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <concepts>
#include <cstddef>
#include <limits>
#include <span>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
17
// BLAS/LAPACK entry points used by this header. All arguments are
// passed by pointer, and the symbols carry a trailing underscore.
extern "C"
{
  // gemm: general matrix-matrix product (single/double precision)
  void sgemm_(char* trans_a, char* trans_b, int* m, int* n, int* k,
              float* alpha, float* a, int* lda, float* b, int* ldb,
              float* beta, float* c, int* ldc);
  void dgemm_(char* trans_a, char* trans_b, int* m, int* n, int* k,
              double* alpha, double* a, int* lda, double* b, int* ldb,
              double* beta, double* c, int* ldc);

  // gesv: solve a dense linear system (single/double precision)
  void sgesv_(int* n, int* nrhs, float* a, int* lda, int* ipiv, float* b,
              int* ldb, int* info);
  void dgesv_(int* n, int* nrhs, double* a, int* lda, int* ipiv, double* b,
              int* ldb, int* info);

  // getrf: LU factorisation with pivoting (single/double precision)
  int sgetrf_(const int* m, const int* n, float* a, const int* lda, int* ipiv,
              int* info);
  int dgetrf_(const int* m, const int* n, double* a, const int* lda, int* ipiv,
              int* info);

  // syevd: symmetric eigenvalue problem (single/double precision)
  void ssyevd_(char* jobz, char* uplo, int* n, float* a, int* lda, float* w,
               float* work, int* lwork, int* iwork, int* liwork, int* info);
  void dsyevd_(char* jobz, char* uplo, int* n, double* a, int* lda, double* w,
               double* work, int* lwork, int* iwork, int* liwork, int* info);
}
42
47namespace basix::math
48{
49
50namespace impl
51{
56template <std::floating_point T>
57void dot_blas(std::span<const T> A, std::array<std::size_t, 2> Ashape,
58 std::span<const T> B, std::array<std::size_t, 2> Bshape,
59 std::span<T> C)
60{
61 static_assert(std::is_same_v<T, float> or std::is_same_v<T, double>);
62
63 assert(Ashape[1] == Bshape[0]);
64 assert(C.size() == Ashape[0] * Bshape[1]);
65
66 int M = Ashape[0];
67 int N = Bshape[1];
68 int K = Ashape[1];
69
70 T alpha = 1;
71 T beta = 0;
72 int lda = K;
73 int ldb = N;
74 int ldc = N;
75 char trans = 'N';
76 if constexpr (std::is_same_v<T, float>)
77 {
78 sgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
79 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
80 }
81 else if constexpr (std::is_same_v<T, double>)
82 {
83 dgemm_(&trans, &trans, &N, &M, &K, &alpha, const_cast<T*>(B.data()), &ldb,
84 const_cast<T*>(A.data()), &lda, &beta, C.data(), &ldc);
85 }
86}
87
88} // namespace impl
89
/// @brief Compute the outer product of vectors u and v.
/// @param[in] u The first vector
/// @param[in] v The second vector
/// @return The entries `u[i] * v[j]` stored row-major, and the shape
/// `{u.size(), v.size()}` of the result
template <typename U, typename V>
std::pair<std::vector<typename U::value_type>, std::array<std::size_t, 2>>
outer(const U& u, const V& v)
{
  const std::size_t nrows = u.size();
  const std::size_t ncols = v.size();

  std::vector<typename U::value_type> result(nrows * ncols);
  for (std::size_t i = 0; i < nrows; ++i)
    for (std::size_t j = 0; j < ncols; ++j)
      result[i * ncols + j] = u[i] * v[j];

  // Name the array type explicitly. With a nested braced list the pair's
  // forwarding constructor cannot be deduced (before C++23), so the
  // const-reference constructor is chosen and the vector is copied
  // instead of moved.
  return {std::move(result), std::array<std::size_t, 2>{nrows, ncols}};
}
105
/// @brief Compute the cross product u x v of two 3-vectors.
/// @param[in] u The first vector; it must have size 3
/// @param[in] v The second vector; it must have size 3
/// @return The three components of u x v
template <typename U, typename V>
std::array<typename U::value_type, 3> cross(const U& u, const V& v)
{
  assert(u.size() == 3);
  assert(v.size() == 3);

  const auto w0 = u[1] * v[2] - u[2] * v[1];
  const auto w1 = u[2] * v[0] - u[0] * v[2];
  const auto w2 = u[0] * v[1] - u[1] * v[0];
  return {w0, w1, w2};
}
118
125template <std::floating_point T>
126std::pair<std::vector<T>, std::vector<T>> eigh(std::span<const T> A,
127 std::size_t n)
128{
129 // Copy A
130 std::vector<T> M(A.begin(), A.end());
131
132 // Allocate storage for eigenvalues
133 std::vector<T> w(n, 0);
134
135 int N = n;
136 char jobz = 'V'; // Compute eigenvalues and eigenvectors
137 char uplo = 'L'; // Lower
138 int ldA = n;
139 int lwork = -1;
140 int liwork = -1;
141 int info;
142 std::vector<T> work(1);
143 std::vector<int> iwork(1);
144
145 // Query optimal workspace size
146 if constexpr (std::is_same_v<T, float>)
147 {
148 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
149 iwork.data(), &liwork, &info);
150 }
151 else if constexpr (std::is_same_v<T, double>)
152 {
153 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
154 iwork.data(), &liwork, &info);
155 }
156
157 if (info != 0)
158 throw std::runtime_error("Could not find workspace size for syevd.");
159
160 // Solve eigen problem
161 work.resize(work[0]);
162 iwork.resize(iwork[0]);
163 lwork = work.size();
164 liwork = iwork.size();
165 if constexpr (std::is_same_v<T, float>)
166 {
167 ssyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
168 iwork.data(), &liwork, &info);
169 }
170 else if constexpr (std::is_same_v<T, double>)
171 {
172 dsyevd_(&jobz, &uplo, &N, M.data(), &ldA, w.data(), work.data(), &lwork,
173 iwork.data(), &liwork, &info);
174 }
175 if (info != 0)
176 throw std::runtime_error("Eigenvalue computation did not converge.");
177
178 return {std::move(w), std::move(M)};
179}
180
185template <std::floating_point T>
186std::vector<T>
187solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
188 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
189 A,
190 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
191 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
192 B)
193{
194 namespace stdex
195 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
196
197 // Copy A and B to column-major storage
198 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
199 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
200 _A(A.extents()), _B(B.extents());
201 for (std::size_t i = 0; i < A.extent(0); ++i)
202 for (std::size_t j = 0; j < A.extent(1); ++j)
203 _A(i, j) = A(i, j);
204 for (std::size_t i = 0; i < B.extent(0); ++i)
205 for (std::size_t j = 0; j < B.extent(1); ++j)
206 _B(i, j) = B(i, j);
207
208 int N = _A.extent(0);
209 int nrhs = _B.extent(1);
210 int lda = _A.extent(0);
211 int ldb = _B.extent(0);
212 // Pivot indices that define the permutation matrix for the LU solver
213 std::vector<int> piv(N);
214 int info;
215 if constexpr (std::is_same_v<T, float>)
216 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
217 else if constexpr (std::is_same_v<T, double>)
218 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), _B.data(), &ldb, &info);
219 if (info != 0)
220 throw std::runtime_error("Call to dgesv failed: " + std::to_string(info));
221
222 // Copy result to row-major storage
223 std::vector<T> rb(_B.extent(0) * _B.extent(1));
224 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
225 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
226 r(rb.data(), _B.extents());
227 for (std::size_t i = 0; i < _B.extent(0); ++i)
228 for (std::size_t j = 0; j < _B.extent(1); ++j)
229 r(i, j) = _B(i, j);
230
231 return rb;
232}
233
237template <std::floating_point T>
239 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
240 const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
241 A)
242{
243 // Copy to column major matrix
244 namespace stdex
245 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
246 stdex::mdarray<T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>,
247 MDSPAN_IMPL_STANDARD_NAMESPACE::layout_left>
248 _A(A.extents());
249 for (std::size_t i = 0; i < A.extent(0); ++i)
250 for (std::size_t j = 0; j < A.extent(1); ++j)
251 _A(i, j) = A(i, j);
252
253 std::vector<T> B(A.extent(1), 1);
254 int N = _A.extent(0);
255 int nrhs = 1;
256 int lda = _A.extent(0);
257 int ldb = B.size();
258
259 // Pivot indices that define the permutation matrix for the LU solver
260 std::vector<int> piv(N);
261 int info;
262 if constexpr (std::is_same_v<T, float>)
263 sgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
264 else if constexpr (std::is_same_v<T, double>)
265 dgesv_(&N, &nrhs, _A.data(), &lda, piv.data(), B.data(), &ldb, &info);
266
267 if (info < 0)
268 {
269 throw std::runtime_error("dgesv failed due to invalid value: "
270 + std::to_string(info));
271 }
272 else if (info > 0)
273 return true;
274 else
275 return false;
276}
277
282template <std::floating_point T>
283std::vector<std::size_t>
284transpose_lu(std::pair<std::vector<T>, std::array<std::size_t, 2>>& A)
285{
286 std::size_t dim = A.second[0];
287 assert(dim == A.second[1]);
288 int N = dim;
289 int info;
290 std::vector<int> lu_perm(dim);
291
292 // Comput LU decomposition of M
293 if constexpr (std::is_same_v<T, float>)
294 sgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
295 else if constexpr (std::is_same_v<T, double>)
296 dgetrf_(&N, &N, A.first.data(), &N, lu_perm.data(), &info);
297
298 if (info != 0)
299 {
300 throw std::runtime_error("LU decomposition failed: "
301 + std::to_string(info));
302 }
303
304 std::vector<std::size_t> perm(dim);
305 for (std::size_t i = 0; i < dim; ++i)
306 perm[i] = static_cast<std::size_t>(lu_perm[i] - 1);
307
308 return perm;
309}
310
316template <typename U, typename V, typename W>
317void dot(const U& A, const V& B, W&& C)
318{
319 assert(A.extent(1) == B.extent(0));
320 assert(C.extent(0) == A.extent(0));
321 assert(C.extent(1) == B.extent(1));
322 if (A.extent(0) * B.extent(1) * A.extent(1) < 512)
323 {
324 std::fill_n(C.data_handle(), C.extent(0) * C.extent(1), 0);
325 for (std::size_t i = 0; i < A.extent(0); ++i)
326 for (std::size_t j = 0; j < B.extent(1); ++j)
327 for (std::size_t k = 0; k < A.extent(1); ++k)
328 C(i, j) += A(i, k) * B(k, j);
329 }
330 else
331 {
332 using T = typename std::decay_t<U>::value_type;
333 impl::dot_blas<T>(
334 std::span(A.data_handle(), A.size()), {A.extent(0), A.extent(1)},
335 std::span(B.data_handle(), B.size()), {B.extent(0), B.extent(1)},
336 std::span(C.data_handle(), C.size()));
337 }
338}
339
343template <std::floating_point T>
344std::vector<T> eye(std::size_t n)
345{
346 std::vector<T> I(n * n, 0);
347 namespace stdex
348 = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE;
349 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
350 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
351 Iview(I.data(), n, n);
352 for (std::size_t i = 0; i < n; ++i)
353 Iview(i, i) = 1;
354 return I;
355}
356
361template <std::floating_point T>
363 MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan<
364 T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents<std::size_t, 2>>
365 wcoeffs,
366 std::size_t start = 0)
367{
368 for (std::size_t i = start; i < wcoeffs.extent(0); ++i)
369 {
370 T norm = 0;
371 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
372 norm += wcoeffs(i, k) * wcoeffs(i, k);
373
374 norm = std::sqrt(norm);
375 if (norm < 2 * std::numeric_limits<T>::epsilon())
376 {
377 throw std::runtime_error(
378 "Cannot orthogonalise the rows of a matrix with incomplete row rank");
379 }
380
381 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
382 wcoeffs(i, k) /= norm;
383
384 for (std::size_t j = i + 1; j < wcoeffs.extent(0); ++j)
385 {
386 T a = 0;
387 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
388 a += wcoeffs(i, k) * wcoeffs(j, k);
389 for (std::size_t k = 0; k < wcoeffs.extent(1); ++k)
390 wcoeffs(j, k) -= a * wcoeffs(i, k);
391 }
392 }
393}
394//-----------------------------------------------------------------------------
395
396} // namespace basix::math
A finite element.
Definition finite-element.h:139
Definition math.h:48
void dot(const U &A, const V &B, W &&C)
Definition math.h:317
std::vector< T > solve(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > A, MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > B)
Definition math.h:187
std::array< typename U::value_type, 3 > cross(const U &u, const V &v)
Definition math.h:111
bool is_singular(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< const T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > A)
Definition math.h:238
std::vector< std::size_t > transpose_lu(std::pair< std::vector< T >, std::array< std::size_t, 2 > > &A)
Definition math.h:284
void orthogonalise(MDSPAN_IMPL_STANDARD_NAMESPACE::mdspan< T, MDSPAN_IMPL_STANDARD_NAMESPACE::dextents< std::size_t, 2 > > wcoeffs, std::size_t start=0)
Definition math.h:362
std::pair< std::vector< T >, std::vector< T > > eigh(std::span< const T > A, std::size_t n)
Definition math.h:126
std::vector< T > eye(std::size_t n)
Definition math.h:344
std::pair< std::vector< typename U::value_type >, std::array< std::size_t, 2 > > outer(const U &u, const V &v)
Compute the outer product of vectors u and v.
Definition math.h:96