Actual source code: matdiagonalcupm.hpp

  1: #pragma once

  3: #include <petscmat.h>

  5: #include "../src/sys/objects/device/impls/cupm/cupmthrustutility.hpp"

  7: #include <petsc/private/cupminterface.hpp>
  8: #include <petsc/private/cupmobject.hpp>
  9: #include <petsc/private/deviceimpl.h>
 10: #include <petsc/private/vecimpl.h>
 11: #include <petsc/private/veccupmimpl.h>
 12: #include <petsc/private/matimpl.h>

 14: #include <thrust/device_ptr.h>
 15: #include <thrust/iterator/zip_iterator.h>
 16: #include <thrust/transform_reduce.h>
 17: #include <thrust/tuple.h>

 19: namespace Petsc
 20: {

 22: namespace device
 23: {

 25: namespace cupm
 26: {

 28: namespace impl
 29: {

 31: template <DeviceType T, typename VecType>
 32: struct MatDiagonal_CUPM : vec::cupm::impl::Vec_CUPMBase<T, VecType> {
 33:   PETSC_CUPMOBJECT_HEADER(T);
 34:   using base_type = ::Petsc::vec::cupm::impl::Vec_CUPMBase<T, VecType>;
 35:   friend base_type;

 37:   static PetscErrorCode ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept;
 38:   static PetscErrorCode ANormSq(Mat A, Vec x, PetscReal *z) noexcept;
 39: };

 41: namespace detail
 42: {
 43: struct adot_transform {
 44:   using argument_type = thrust::tuple<PetscScalar, PetscScalar, PetscScalar>;

 46:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return PetscConj(thrust::get<1>(tup)) * thrust::get<2>(tup) * thrust::get<0>(tup); }
 47: };
 48: } // namespace detail

 50: template <Petsc::device::cupm::DeviceType T, typename VecType>
 51: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ADot(Mat A, Vec x, Vec y, PetscScalar *z) noexcept
 52: {
 53:   PetscDeviceContext dctx;
 54:   cupmStream_t       stream;
 55:   Mat_Diagonal      *ctx  = (Mat_Diagonal *)A->data;
 56:   PetscScalar        zero = 0.;
 57:   const PetscInt     n    = x->map->n;

 59:   PetscFunctionBegin;
 60:   PetscCall(GetHandles_(&dctx, &stream));

 62:   const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
 63:   const auto ydptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, y).data());
 64:   const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());

 66:   // clang-format off
 67:     PetscCallThrust(
 68:       *z = THRUST_CALL(
 69:         thrust::transform_reduce,
 70:         stream,
 71:         thrust::make_zip_iterator(thrust::make_tuple(xdptr, ydptr, wdptr)),
 72:         thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, ydptr + n, wdptr + n)),
 73:         detail::adot_transform{},
 74:         zero,
 75:         thrust::plus<PetscScalar>()
 76:       )
 77:     );
 78:   // clang-format on
 79:   if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
 80:   PetscFunctionReturn(PETSC_SUCCESS);
 81: }

 83: namespace detail
 84: {
 85: struct anorm_transform {
 86:   using argument_type = thrust::tuple<PetscScalar, PetscScalar>;

 88:   PETSC_NODISCARD PETSC_HOSTDEVICE_INLINE_DECL PetscScalar operator()(const argument_type &tup) const noexcept { return thrust::get<1>(tup) * PetscConj(thrust::get<0>(tup)) * thrust::get<0>(tup); }
 89: };
 90: } // namespace detail

 92: template <Petsc::device::cupm::DeviceType T, typename VecType>
 93: inline PetscErrorCode MatDiagonal_CUPM<T, VecType>::ANormSq(Mat A, Vec x, PetscReal *z) noexcept
 94: {
 95:   PetscDeviceContext dctx;
 96:   cupmStream_t       stream;
 97:   Mat_Diagonal      *ctx  = (Mat_Diagonal *)A->data;
 98:   PetscScalar        zero = 0., res;
 99:   const PetscInt     n    = x->map->n;

101:   PetscFunctionBegin;
102:   PetscCall(GetHandles_(&dctx, &stream));

104:   const auto xdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, x).data());
105:   const auto wdptr = thrust::device_pointer_cast(base_type::DeviceArrayRead(dctx, ctx->diag).data());

107:   // clang-format off
108:   PetscCallThrust(
109:     res = THRUST_CALL(
110:       thrust::transform_reduce,
111:       stream,
112:       thrust::make_zip_iterator(thrust::make_tuple(xdptr, wdptr)),
113:       thrust::make_zip_iterator(thrust::make_tuple(xdptr + n, wdptr + n)),
114:       detail::anorm_transform{},
115:       zero,
116:       thrust::plus<PetscScalar>()
117:     )
118:   );
119:   // clang-format on
120:   *z = PetscRealPart(res);
121:   if (x->map->n > 0) PetscCall(PetscLogGpuFlops(3.0 * x->map->n));
122:   PetscFunctionReturn(PETSC_SUCCESS);
123: }

125: } // namespace impl

127: } // namespace cupm

129: } // namespace device

131: } // namespace Petsc