Ginkgo Generated from branch based on master. Ginkgo version 1.7.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
triangular.hpp
1/*******************************<GINKGO LICENSE>******************************
2Copyright (c) 2017-2023, the Ginkgo authors
3All rights reserved.
4
5Redistribution and use in source and binary forms, with or without
6modification, are permitted provided that the following conditions
7are met:
8
91. Redistributions of source code must retain the above copyright
10notice, this list of conditions and the following disclaimer.
11
122. Redistributions in binary form must reproduce the above copyright
13notice, this list of conditions and the following disclaimer in the
14documentation and/or other materials provided with the distribution.
15
163. Neither the name of the copyright holder nor the names of its
17contributors may be used to endorse or promote products derived from
18this software without specific prior written permission.
19
20THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31******************************<GINKGO LICENSE>*******************************/
32
33#ifndef GKO_PUBLIC_CORE_SOLVER_TRIANGULAR_HPP_
34#define GKO_PUBLIC_CORE_SOLVER_TRIANGULAR_HPP_
35
36
37#include <memory>
38#include <utility>
39
40
41#include <ginkgo/core/base/abstract_factory.hpp>
42#include <ginkgo/core/base/array.hpp>
43#include <ginkgo/core/base/dim.hpp>
44#include <ginkgo/core/base/exception_helpers.hpp>
45#include <ginkgo/core/base/lin_op.hpp>
46#include <ginkgo/core/base/polymorphic_object.hpp>
47#include <ginkgo/core/base/types.hpp>
48#include <ginkgo/core/base/utils.hpp>
49#include <ginkgo/core/log/logger.hpp>
50#include <ginkgo/core/matrix/csr.hpp>
51#include <ginkgo/core/matrix/identity.hpp>
52#include <ginkgo/core/solver/solver_base.hpp>
53
54
55namespace gko {
56namespace solver {
57
58
// Forward declaration only: within this listing the type is used solely as the
// pointee of the solvers' `solve_struct_` shared_ptr members.
59struct SolveStruct;
60
61
// Helper for algorithm selection in the triangular solvers; the two available
// choices are `sparselib` and `syncfree`.
67enum class trisolve_algorithm { sparselib, syncfree };
68
69
// Forward declaration so that LowerTrs can name UpperTrs as a friend before
// UpperTrs is defined further down.
70template <typename ValueType, typename IndexType>
71class UpperTrs;
72
73
// LowerTrs: triangular solver for the system L x = b with a lower triangular L.
// It is a LinOp (via EnableLinOp), stores its system matrix as
// matrix::Csr<ValueType, IndexType> (via EnableSolverBase) and implements the
// Transposable interface.
//
// NOTE(review): this listing is incomplete. The embedded line numbers skip
// (e.g. 96->98, 102->104, 108->110, 124->132, 133->166), so some declarations
// of the class (among them the copy/move special members and the
// `Factory`, `CsrMatrix` and `parameters_` names used below) are not visible.
91template <typename ValueType = default_precision, typename IndexType = int32>
92class LowerTrs : public EnableLinOp<LowerTrs<ValueType, IndexType>>,
93 public EnableSolverBase<LowerTrs<ValueType, IndexType>,
94 matrix::Csr<ValueType, IndexType>>,
95 public Transposable {
    // The mixin and the matching upper solver get access to the non-public
    // members (the constructors below are protected).
96 friend class EnableLinOp<LowerTrs>;
98 friend class UpperTrs<ValueType, IndexType>;
99
100public:
    // Expose the template arguments as member types.
101 using value_type = ValueType;
102 using index_type = IndexType;
104
    // Returns a LinOp representing the transpose of this solver.
105 std::unique_ptr<LinOp> transpose() const override;
106
    // Returns a LinOp representing the conjugate transpose of this solver.
107 std::unique_ptr<LinOp> conj_transpose() const override;
108
    // Body of the factory parameter structure. The line that opens it is
    // missing from the listing; presumably it is a
    // GKO_CREATE_FACTORY_PARAMETERS(...) invocation - TODO confirm.
110 {
118
    // Scalar factory parameter `unit_diagonal`, default `false`.
    // NOTE(review): presumably tells the solver to treat the diagonal as all
    // ones - confirm against the library documentation.
123 bool GKO_FACTORY_PARAMETER_SCALAR(unit_diagonal, false);
124
    // Tail of a parameter declaration whose first line is missing; it names
    // `algorithm` with default trisolve_algorithm::sparselib.
132 algorithm, trisolve_algorithm::sparselib);
133 };
136
143
151
158
165
166protected:
168
    // Override of the base-class apply hook taking right-hand side `b` and
    // solution `x`.
169 void apply_impl(const LinOp* b, LinOp* x) const override;
170
    // Override of the scaled apply hook.
    // NOTE(review): presumably computes x = alpha * solve(b) + beta * x -
    // confirm against the LinOp documentation.
171 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
172 LinOp* x) const override;
173
    // Invoked at the end of the factory constructor below.
    // NOTE(review): presumably sets up `solve_struct_` - definition not
    // visible here.
178 void generate();
179
    // Creates a solver on the given executor without a system matrix; only
    // the EnableLinOp base receives the executor.
180 explicit LowerTrs(std::shared_ptr<const Executor> exec)
181 : EnableLinOp<LowerTrs>(std::move(exec))
182 {}
183
    // Factory constructor:
    //  - the operator lives on the factory's executor and its size is the
    //    transposed size of `system_matrix`;
    //  - `system_matrix` is passed through copy_and_convert_to<CsrMatrix> on
    //    the factory's executor before being stored in the solver base;
    //  - the factory's parameters are copied into `parameters_`;
    //  - finally generate() is called.
184 explicit LowerTrs(const Factory* factory,
185 std::shared_ptr<const LinOp> system_matrix)
186 : EnableLinOp<LowerTrs>(factory->get_executor(),
187 gko::transpose(system_matrix->get_size())),
188 EnableSolverBase<LowerTrs<ValueType, IndexType>, CsrMatrix>{
189 copy_and_convert_to<CsrMatrix>(factory->get_executor(),
190 system_matrix)},
191 parameters_{factory->get_parameters()}
192 {
193 this->generate();
194 }
195
196private:
    // Shared handle to the (forward-declared) SolveStruct.
    // NOTE(review): presumably holds algorithm-specific solver state - TODO
    // confirm.
197 std::shared_ptr<solver::SolveStruct> solve_struct_;
198};
199
200
// Specialization of workspace_traits (the traits class describing the type and
// location of workspace vectors inside a solver) for LowerTrs.
// NOTE(review): `Solver` is not declared in the visible lines (the embedded
// numbering skips 203); presumably an alias for LowerTrs<ValueType, IndexType>
// - confirm.
201template <typename ValueType, typename IndexType>
202struct workspace_traits<LowerTrs<ValueType, IndexType>> {
204 // number of vectors used by this workspace
205 static int num_vectors(const Solver&);
206 // number of arrays used by this workspace
207 static int num_arrays(const Solver&);
208 // array containing the num_vectors names for the workspace vectors
209 static std::vector<std::string> op_names(const Solver&);
210 // array containing the num_arrays names for the workspace arrays
211 static std::vector<std::string> array_names(const Solver&);
212 // array containing all varying scalar vectors (independent of problem size)
213 static std::vector<int> scalars(const Solver&);
214 // array containing all varying vectors (dependent on problem size)
215 static std::vector<int> vectors(const Solver&);
216
217 // workspace index of the transposed input vector
218 constexpr static int transposed_b = 0;
219 // workspace index of the transposed output vector
220 constexpr static int transposed_x = 1;
221};
222
223
// UpperTrs: triangular solver for the system U x = b with an upper triangular
// U. It is a LinOp (via EnableLinOp), stores its system matrix as
// matrix::Csr<ValueType, IndexType> (via EnableSolverBase) and implements the
// Transposable interface.
//
// NOTE(review): this listing is incomplete. The embedded line numbers skip
// (e.g. 246->248, 252->254, 258->260, 274->282, 283->316), so some
// declarations of the class (among them the copy/move special members and the
// `Factory`, `CsrMatrix` and `parameters_` names used below) are not visible.
241template <typename ValueType = default_precision, typename IndexType = int32>
242class UpperTrs : public EnableLinOp<UpperTrs<ValueType, IndexType>>,
243 public EnableSolverBase<UpperTrs<ValueType, IndexType>,
244 matrix::Csr<ValueType, IndexType>>,
245 public Transposable {
    // The mixin and the matching lower solver get access to the non-public
    // members (the constructors below are protected).
246 friend class EnableLinOp<UpperTrs>;
248 friend class LowerTrs<ValueType, IndexType>;
249
250public:
    // Expose the template arguments as member types.
251 using value_type = ValueType;
252 using index_type = IndexType;
254
    // Returns a LinOp representing the transpose of this solver.
255 std::unique_ptr<LinOp> transpose() const override;
256
    // Returns a LinOp representing the conjugate transpose of this solver.
257 std::unique_ptr<LinOp> conj_transpose() const override;
258
    // Body of the factory parameter structure. The line that opens it is
    // missing from the listing; presumably it is a
    // GKO_CREATE_FACTORY_PARAMETERS(...) invocation - TODO confirm.
260 {
268
    // Scalar factory parameter `unit_diagonal`, default `false`.
    // NOTE(review): presumably tells the solver to treat the diagonal as all
    // ones - confirm against the library documentation.
273 bool GKO_FACTORY_PARAMETER_SCALAR(unit_diagonal, false);
274
    // Tail of a parameter declaration whose first line is missing; it names
    // `algorithm` with default trisolve_algorithm::sparselib.
282 algorithm, trisolve_algorithm::sparselib);
283 };
286
293
301
308
315
316protected:
318
    // Override of the base-class apply hook taking right-hand side `b` and
    // solution `x`.
319 void apply_impl(const LinOp* b, LinOp* x) const override;
320
    // Override of the scaled apply hook.
    // NOTE(review): presumably computes x = alpha * solve(b) + beta * x -
    // confirm against the LinOp documentation.
321 void apply_impl(const LinOp* alpha, const LinOp* b, const LinOp* beta,
322 LinOp* x) const override;
323
    // Invoked at the end of the factory constructor below.
    // NOTE(review): presumably sets up `solve_struct_` - definition not
    // visible here.
328 void generate();
329
    // Creates a solver on the given executor without a system matrix; only
    // the EnableLinOp base receives the executor.
330 explicit UpperTrs(std::shared_ptr<const Executor> exec)
331 : EnableLinOp<UpperTrs>(std::move(exec))
332 {}
333
    // Factory constructor:
    //  - the operator lives on the factory's executor and its size is the
    //    transposed size of `system_matrix`;
    //  - `system_matrix` is passed through copy_and_convert_to<CsrMatrix> on
    //    the factory's executor before being stored in the solver base;
    //  - the factory's parameters are copied into `parameters_`;
    //  - finally generate() is called.
334 explicit UpperTrs(const Factory* factory,
335 std::shared_ptr<const LinOp> system_matrix)
336 : EnableLinOp<UpperTrs>(factory->get_executor(),
337 gko::transpose(system_matrix->get_size())),
338 EnableSolverBase<UpperTrs<ValueType, IndexType>, CsrMatrix>{
339 copy_and_convert_to<CsrMatrix>(factory->get_executor(),
340 system_matrix)},
341 parameters_{factory->get_parameters()}
342 {
343 this->generate();
344 }
345
346private:
    // Shared handle to the (forward-declared) SolveStruct.
    // NOTE(review): presumably holds algorithm-specific solver state - TODO
    // confirm.
347 std::shared_ptr<solver::SolveStruct> solve_struct_;
348};
349
350
// Specialization of workspace_traits (the traits class describing the type and
// location of workspace vectors inside a solver) for UpperTrs.
// NOTE(review): `Solver` is not declared in the visible lines (the embedded
// numbering skips 353); presumably an alias for UpperTrs<ValueType, IndexType>
// - confirm.
351template <typename ValueType, typename IndexType>
352struct workspace_traits<UpperTrs<ValueType, IndexType>> {
354 // number of vectors used by this workspace
355 static int num_vectors(const Solver&);
356 // number of arrays used by this workspace
357 static int num_arrays(const Solver&);
358 // array containing the num_vectors names for the workspace vectors
359 static std::vector<std::string> op_names(const Solver&);
360 // array containing the num_arrays names for the workspace arrays
361 static std::vector<std::string> array_names(const Solver&);
362 // array containing all varying scalar vectors (independent of problem size)
363 static std::vector<int> scalars(const Solver&);
364 // array containing all varying vectors (dependent on problem size)
365 static std::vector<int> vectors(const Solver&);
366
367 // workspace index of the transposed input vector
368 constexpr static int transposed_b = 0;
369 // workspace index of the transposed output vector
370 constexpr static int transposed_x = 1;
371};
372
373
374} // namespace solver
375} // namespace gko
376
377
378#endif // GKO_PUBLIC_CORE_SOLVER_TRIANGULAR_HPP_
The EnableLinOp mixin can be used to provide sensible default implementations of the majority of the ...
Definition lin_op.hpp:908
This mixin inherits from (a subclass of) PolymorphicObject and provides a base implementation of a ne...
Definition polymorphic_object.hpp:691
Definition lin_op.hpp:146
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition polymorphic_object.hpp:263
Linear operators which support transposition should implement the Transposable interface.
Definition lin_op.hpp:462
CSR is a matrix format which stores only the nonzero coefficients by compressing each row of the matr...
Definition csr.hpp:146
A LinOp deriving from this CRTP class stores a system matrix.
Definition solver_base.hpp:570
Definition triangular.hpp:134
LowerTrs is the triangular solver which solves the system L x = b, when L is a lower triangular matrix.
Definition triangular.hpp:95
LowerTrs & operator=(LowerTrs &&)
Move-assigns a triangular solver.
LowerTrs(const LowerTrs &)
Copy-constructs a triangular solver.
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
LowerTrs & operator=(const LowerTrs &)
Copy-assigns a triangular solver.
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
LowerTrs(LowerTrs &&)
Move-constructs a triangular solver.
Definition triangular.hpp:284
UpperTrs is the triangular solver which solves the system U x = b, when U is an upper triangular matrix.
Definition triangular.hpp:245
UpperTrs(UpperTrs &&)
Move-constructs a triangular solver.
UpperTrs(const UpperTrs &)
Copy-constructs a triangular solver.
std::unique_ptr< LinOp > conj_transpose() const override
Returns a LinOp representing the conjugate transpose of the Transposable object.
UpperTrs & operator=(UpperTrs &&)
Move-assigns a triangular solver.
std::unique_ptr< LinOp > transpose() const override
Returns a LinOp representing the transpose of the Transposable object.
UpperTrs & operator=(const UpperTrs &)
Copy-assigns a triangular solver.
#define GKO_CREATE_FACTORY_PARAMETERS(_parameters_name, _factory_name)
This Macro will generate a new type containing the parameters for the factory _factory_name.
Definition abstract_factory.hpp:308
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:473
#define GKO_ENABLE_BUILD_METHOD(_factory_name)
Defines a build method for the factory, simplifying its construction by removing the repetitive typin...
Definition abstract_factory.hpp:422
#define GKO_ENABLE_LIN_OP_FACTORY(_lin_op, _parameters_name, _factory_name)
This macro will generate a default implementation of a LinOpFactory for the LinOp subclass it is defi...
Definition lin_op.hpp:1046
trisolve_algorithm
A helper for algorithm selection in the triangular solvers.
Definition triangular.hpp:67
The Ginkgo namespace.
Definition abstract_factory.hpp:48
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:803
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:120
Traits class providing information on the type and location of workspace vectors inside a solver.
Definition solver_base.hpp:267