Ginkgo Generated from branch based on master. Ginkgo version 1.7.0
A numerical linear algebra library targeting many-core architectures
Loading...
Searching...
No Matches
batch_solver_base.hpp
1/*******************************<GINKGO LICENSE>******************************
2Copyright (c) 2017-2023, the Ginkgo authors
3All rights reserved.
4
5Redistribution and use in source and binary forms, with or without
6modification, are permitted provided that the following conditions
7are met:
8
91. Redistributions of source code must retain the above copyright
10notice, this list of conditions and the following disclaimer.
11
122. Redistributions in binary form must reproduce the above copyright
13notice, this list of conditions and the following disclaimer in the
14documentation and/or other materials provided with the distribution.
15
163. Neither the name of the copyright holder nor the names of its
17contributors may be used to endorse or promote products derived from
18this software without specific prior written permission.
19
20THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31******************************<GINKGO LICENSE>*******************************/
32
33#ifndef GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
34#define GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
35
36
37#include <ginkgo/core/base/abstract_factory.hpp>
38#include <ginkgo/core/base/batch_lin_op.hpp>
39#include <ginkgo/core/base/batch_multi_vector.hpp>
40#include <ginkgo/core/base/utils_helper.hpp>
41#include <ginkgo/core/log/batch_logger.hpp>
42#include <ginkgo/core/matrix/batch_identity.hpp>
43#include <ginkgo/core/stop/batch_stop_enum.hpp>
44
45
46namespace gko {
47namespace batch {
48namespace solver {
49
50
58public:
64 std::shared_ptr<const BatchLinOp> get_system_matrix() const
65 {
66 return this->system_matrix_;
67 }
68
74 std::shared_ptr<const BatchLinOp> get_preconditioner() const
75 {
76 return this->preconditioner_;
77 }
78
84 double get_tolerance() const { return this->residual_tol_; }
85
93 {
94 if (res_tol < 0) {
95 GKO_INVALID_STATE("Tolerance cannot be negative!");
96 }
97 this->residual_tol_ = res_tol;
98 }
99
105 int get_max_iterations() const { return this->max_iterations_; }
106
113 void reset_max_iterations(int max_iterations)
114 {
115 if (max_iterations < 0) {
116 GKO_INVALID_STATE("Max iterations cannot be negative!");
117 }
118 this->max_iterations_ = max_iterations;
119 }
120
126 ::gko::batch::stop::tolerance_type get_tolerance_type() const
127 {
128 return this->tol_type_;
129 }
130
136 void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
137 {
138 if (tol_type == ::gko::batch::stop::tolerance_type::absolute ||
139 tol_type == ::gko::batch::stop::tolerance_type::relative) {
140 this->tol_type_ = tol_type;
141 } else {
142 GKO_INVALID_STATE("Invalid tolerance type specified!");
143 }
144 }
145
146protected:
147 BatchSolver() {}
148
149 BatchSolver(std::shared_ptr<const BatchLinOp> system_matrix,
150 std::shared_ptr<const BatchLinOp> gen_preconditioner,
151 const double res_tol, const int max_iterations,
152 const ::gko::batch::stop::tolerance_type tol_type)
153 : system_matrix_{std::move(system_matrix)},
154 preconditioner_{std::move(gen_preconditioner)},
155 residual_tol_{res_tol},
156 max_iterations_{max_iterations},
157 tol_type_{tol_type},
158 workspace_{}
159 {}
160
161 void set_system_matrix_base(std::shared_ptr<const BatchLinOp> system_matrix)
162 {
163 this->system_matrix_ = std::move(system_matrix);
164 }
165
166 void set_preconditioner_base(std::shared_ptr<const BatchLinOp> precond)
167 {
168 this->preconditioner_ = std::move(precond);
169 }
170
// System operator the solver works on; empty until set.
171 std::shared_ptr<const BatchLinOp> system_matrix_{};
// Generated preconditioner returned by get_preconditioner().
172 std::shared_ptr<const BatchLinOp> preconditioner_{};
// Stopping settings: residual tolerance, iteration limit, and whether the
// tolerance is checked as an absolute or a relative value.
173 double residual_tol_{};
174 int max_iterations_{};
175 ::gko::batch::stop::tolerance_type tol_type_{};
// Raw byte buffer. It is `mutable` so the const apply_impl() can take a view
// of it (presumably to back the per-item logging data — confirm).
176 mutable array<unsigned char> workspace_{};
177};
178
179
// Factory parameters shared by the batched iterative solvers, layered on top
// of enable_parameters_type. NOTE(review): the struct name and some parameter
// declarations (max_iterations, tolerance) are missing from this listing.
180template <typename Parameters, typename Factory>
182 : enable_parameters_type<Parameters, Factory> {
190
198
// Whether the residual tolerance is checked as an absolute or a relative
// value; defaults to absolute.
203 ::gko::batch::stop::tolerance_type GKO_FACTORY_PARAMETER_SCALAR(
204 tolerance_type, ::gko::batch::stop::tolerance_type::absolute);
205
// Factory (`preconditioner`) used to generate the preconditioner from the
// system matrix when no generated_preconditioner is supplied.
210 std::shared_ptr<const BatchLinOpFactory> GKO_DEFERRED_FACTORY_PARAMETER(
212
// An already generated preconditioner; when set, EnableBatchSolver uses it in
// preference to the `preconditioner` factory.
217 std::shared_ptr<const BatchLinOp> GKO_FACTORY_PARAMETER_SCALAR(
219};
220
221
// Mixin that supplies batched solvers with apply() overloads, construction
// from factory parameters, and checked matrix/preconditioner setters. It
// combines the BatchSolver state with EnableBatchLinOp<ConcreteSolver,
// PolymorphicBase>. (The class-name line is missing from this listing.)
230template <typename ConcreteSolver, typename ValueType,
231 typename PolymorphicBase = BatchLinOp>
233 : public BatchSolver,
234 public EnableBatchLinOp<ConcreteSolver, PolymorphicBase> {
235public:
// ValueType with any complex part removed; used to instantiate the logging
// data in apply_impl().
236 using real_type = remove_complex<ValueType>;
237
// Solves for right-hand side b and writes the result into x. Operand sizes
// are checked with validate_application_parameters(). Each operand is wrapped
// in a temporary clone on this solver's executor, the work is forwarded to
// apply_impl(), and self() is returned so calls can be chained.
238 const ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> b,
240 {
241 this->validate_application_parameters(b.get(), x.get());
242 auto exec = this->get_executor();
243 this->apply_impl(make_temporary_clone(exec, b).get(),
244 make_temporary_clone(exec, x).get());
245 return self();
246 }
247
// Scaled variant: x is updated to alpha * (solution for b) + beta * x; see
// apply_impl(alpha, b, beta, x). Same validation and executor handling.
248 const ConcreteSolver* apply(ptr_param<const MultiVector<ValueType>> alpha,
252 {
253 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
254 x.get());
255 auto exec = this->get_executor();
256 this->apply_impl(make_temporary_clone(exec, alpha).get(),
257 make_temporary_clone(exec, b).get(),
258 make_temporary_clone(exec, beta).get(),
259 make_temporary_clone(exec, x).get());
260 return self();
261 }
262
// NOTE(review): the signatures of the next two overloads are missing from
// this listing; their bodies mirror the two overloads above.
265 {
266 this->validate_application_parameters(b.get(), x.get());
267 auto exec = this->get_executor();
268 this->apply_impl(make_temporary_clone(exec, b).get(),
269 make_temporary_clone(exec, x).get());
270 return self();
271 }
272
277 {
278 this->validate_application_parameters(alpha.get(), b.get(), beta.get(),
279 x.get());
280 auto exec = this->get_executor();
281 this->apply_impl(make_temporary_clone(exec, alpha).get(),
282 make_temporary_clone(exec, b).get(),
283 make_temporary_clone(exec, beta).get(),
284 make_temporary_clone(exec, x).get());
285 return self();
286 }
287
288protected:
// Provides self(), giving access to this object as a ConcreteSolver.
289 GKO_ENABLE_SELF(ConcreteSolver);
290
// Executor-only constructor (its base-initializer line is missing from this
// listing).
291 explicit EnableBatchSolver(std::shared_ptr<const Executor> exec)
293 {}
294
// Builds the solver for `system_matrix` from factory parameters. Tolerance,
// iteration limit and tolerance type come from `params`. The operator size is
// the transposed system-matrix size, which is the same size for the required
// square items.
295 template <typename FactoryParameters>
296 explicit EnableBatchSolver(std::shared_ptr<const Executor> exec,
297 std::shared_ptr<const BatchLinOp> system_matrix,
299 : BatchSolver(system_matrix, nullptr, params.tolerance,
300 params.max_iterations, params.tolerance_type),
302 exec, gko::transpose(system_matrix->get_size()))
303 {
// Every batch item of the system matrix must be square.
304 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(system_matrix_);
305
306 using value_type = typename ConcreteSolver::value_type;
307 using Identity = matrix::Identity<value_type>;
308 using real_type = remove_complex<value_type>;
309
// Preconditioner selection, in order of priority:
//  1. an already generated preconditioner (must match this solver's size);
//  2. one generated from the system matrix by the preconditioner factory;
//  3. otherwise a batch Identity of the system-matrix size (no
//     preconditioning).
// NOTE(review): if `params` is a const reference (its declaration line is
// missing here), the std::move below is effectively a copy.
310 if (params.generated_preconditioner) {
311 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(params.generated_preconditioner,
312 this);
313 preconditioner_ = std::move(params.generated_preconditioner);
314 } else if (params.preconditioner) {
315 preconditioner_ = params.preconditioner->generate(system_matrix_);
316 } else {
317 auto id = Identity::create(exec, system_matrix->get_size());
318 preconditioner_ = std::move(id);
319 }
// Reserve on exec one real_type plus one int per batch item. apply_impl()
// takes a view of this buffer, presumably for the logged residual norms and
// iteration counts.
320 const size_type workspace_size = system_matrix->get_num_batch_items() *
321 (sizeof(real_type) + sizeof(int));
322 workspace_.set_executor(exec);
323 workspace_.resize_and_reset(workspace_size);
324 }
325
// Replaces the system matrix. A non-null matrix must have the same batch
// dimensions as this solver, with square items. If it lives on another
// executor, a line missing from this listing handles it (presumably by
// copying it to this solver's executor). Passing nullptr clears the matrix
// without checks.
326 void set_system_matrix(std::shared_ptr<const BatchLinOp> new_system_matrix)
327 {
328 auto exec = self()->get_executor();
329 if (new_system_matrix) {
330 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_system_matrix);
331 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_system_matrix);
332 if (new_system_matrix->get_executor() != exec) {
334 }
335 }
336 this->set_system_matrix_base(new_system_matrix);
337 }
338
// Replaces the preconditioner. The checks and the executor handling are the
// same as for set_system_matrix(); nullptr clears it.
339 void set_preconditioner(std::shared_ptr<const BatchLinOp> new_precond)
340 {
341 auto exec = self()->get_executor();
342 if (new_precond) {
343 GKO_ASSERT_BATCH_EQUAL_DIMENSIONS(self(), new_precond);
344 GKO_ASSERT_BATCH_HAS_SQUARE_DIMENSIONS(new_precond);
345 if (new_precond->get_executor() != exec) {
347 }
348 }
349 this->set_preconditioner_base(new_precond);
350 }
351
// Copy assignment. set_size() runs first, so the dimension checks inside
// set_system_matrix()/set_preconditioner() compare against the copied size.
// NOTE(review): two lines are missing from this listing. The visible lines
// never copy workspace_, which the factory constructor sizes per batch item.
// Confirm that copies still get a usable workspace for apply_impl().
352 EnableBatchSolver& operator=(const EnableBatchSolver& other)
353 {
354 if (&other != this) {
355 this->set_size(other.get_size());
356 this->set_system_matrix(other.get_system_matrix());
357 this->set_preconditioner(other.get_preconditioner());
358 this->reset_tolerance(other.get_tolerance());
361 }
362 return *this;
363 }
364
// Move assignment (its signature line is missing from this listing). It
// copies the state like the copy assignment, then clears the source's system
// matrix and preconditioner.
366 {
367 if (&other != this) {
368 this->set_size(other.get_size());
369 this->set_system_matrix(other.get_system_matrix());
370 this->set_preconditioner(other.get_preconditioner());
371 this->reset_tolerance(other.get_tolerance());
374 other.set_system_matrix(nullptr);
375 other.set_preconditioner(nullptr);
376 }
377 return *this;
378 }
379
// Copy and move constructors (declaration lines missing from this listing).
// They initialize the EnableBatchLinOp base from other's executor and size,
// then delegate to the matching assignment operator.
382 other.self()->get_executor(), other.self()->get_size())
383 {
384 *this = other;
385 }
386
389 other.self()->get_executor(), other.self()->get_size())
390 {
391 *this = std::move(other);
392 }
393
// Solves for a single right-hand side; more than one column in b triggers
// GKO_NOT_IMPLEMENTED. Log data is created over a view of workspace_ (its
// argument line is missing here) and filled by the concrete solver_apply().
// Its iteration counts and residual norms are then passed to a call whose
// first line is missing from this listing (presumably forwarding them to the
// attached loggers).
394 void apply_impl(const MultiVector<ValueType>* b,
395 MultiVector<ValueType>* x) const
396 {
397 auto exec = this->get_executor();
398 if (b->get_common_size()[1] > 1) {
399 GKO_NOT_IMPLEMENTED;
400 }
401 auto workspace_view = workspace_.as_view();
402 auto log_data_ = std::make_unique<log::detail::log_data<real_type>>(
404
405 this->solver_apply(b, x, log_data_.get());
406
408 log_data_->iter_counts, log_data_->res_norms);
409 }
410
411 void apply_impl(const MultiVector<ValueType>* alpha,
412 const MultiVector<ValueType>* b,
413 const MultiVector<ValueType>* beta,
414 MultiVector<ValueType>* x) const
415 {
416 auto x_clone = x->clone();
417 this->apply(b, x_clone.get());
418 x->scale(beta);
419 x->add_scaled(alpha, x_clone.get());
420 }
421
// Hook implemented by each concrete batched solver to perform the actual
// solve of b into x. apply_impl() reads `info->iter_counts` and
// `info->res_norms` afterwards, so implementations presumably fill them in.
// (One parameter line is missing from this listing.)
422 virtual void solver_apply(const MultiVector<ValueType>* b,
424 log::detail::log_data<real_type>* info) const = 0;
425};
426
427
428} // namespace solver
429} // namespace batch
430} // namespace gko
431
432
433#endif // GKO_PUBLIC_CORE_SOLVER_BATCH_SOLVER_BASE_HPP_
std::shared_ptr< const Executor > get_executor() const noexcept
Returns the Executor of the object.
Definition polymorphic_object.hpp:263
Definition batch_lin_op.hpp:88
The EnableBatchLinOp mixin can be used to provide sensible default implementations of the majority of...
Definition batch_lin_op.hpp:281
MultiVector stores multiple vectors in a batched fashion and is useful for batched operations.
Definition batch_multi_vector.hpp:85
void scale(ptr_param< const MultiVector< ValueType > > alpha)
Scales the vector with a scalar (aka: BLAS scal).
dim< 2 > get_common_size() const
Returns the common size of the batch items.
Definition batch_multi_vector.hpp:157
size_type get_num_batch_items() const
Returns the number of batch items.
Definition batch_multi_vector.hpp:147
void add_scaled(ptr_param< const MultiVector< ValueType > > alpha, ptr_param< const MultiVector< ValueType > > b)
Adds b scaled by alpha to the vector (aka: BLAS axpy).
The batch Identity matrix, which represents a batch of Identity matrices.
Definition batch_identity.hpp:61
The BatchSolver is a base class for all batched solvers and provides the common getters and setters for these solvers.
Definition batch_solver_base.hpp:57
std::shared_ptr< const BatchLinOp > get_system_matrix() const
Returns the system operator (matrix) of the linear system.
Definition batch_solver_base.hpp:64
void reset_max_iterations(int max_iterations)
Set the maximum number of iterations for the solver to use, independent of the factory that created it.
Definition batch_solver_base.hpp:113
double get_tolerance() const
Get the residual tolerance used by the solver.
Definition batch_solver_base.hpp:84
int get_max_iterations() const
Get the maximum number of iterations set on the solver.
Definition batch_solver_base.hpp:105
void reset_tolerance(double res_tol)
Update the residual tolerance to be used by the solver.
Definition batch_solver_base.hpp:92
::gko::batch::stop::tolerance_type get_tolerance_type() const
Get the tolerance type.
Definition batch_solver_base.hpp:126
void reset_tolerance_type(::gko::batch::stop::tolerance_type tol_type)
Set the type of tolerance check to use inside the solver.
Definition batch_solver_base.hpp:136
std::shared_ptr< const BatchLinOp > get_preconditioner() const
Returns the generated preconditioner.
Definition batch_solver_base.hpp:74
This mixin provides apply and common iterative solver functionality to all the batched solvers.
Definition batch_solver_base.hpp:234
The enable_parameters_type mixin is used to create a base implementation of the factory parameters structure.
Definition abstract_factory.hpp:239
This class is used for function parameters in the place of raw pointers.
Definition utils_helper.hpp:71
#define GKO_FACTORY_PARAMETER_SCALAR(_name, _default)
Creates a scalar factory parameter in the factory parameters structure.
Definition abstract_factory.hpp:473
The Ginkgo namespace.
Definition abstract_factory.hpp:48
constexpr T one()
Returns the multiplicative identity for T.
Definition math.hpp:803
typename detail::remove_complex_s< T >::type remove_complex
Obtain the type which removed the complex of complex/scalar type or the template parameter of class b...
Definition math.hpp:354
std::size_t size_type
Integral type used for allocation quantities.
Definition types.hpp:120
detail::cloned_type< Pointer > clone(const Pointer &p)
Creates a unique clone of the object pointed to by p.
Definition utils_helper.hpp:203
batch_dim< 2, DimensionType > transpose(const batch_dim< 2, DimensionType > &input)
Returns a batch_dim object with its dimensions swapped for batched operators.
Definition batch_dim.hpp:148
detail::temporary_clone< detail::pointee< Ptr > > make_temporary_clone(std::shared_ptr< const Executor > exec, Ptr &&ptr)
Creates a temporary_clone.
Definition temporary_clone.hpp:207
int max_iterations
Default maximum number of iterations allowed.
Definition batch_solver_base.hpp:189
double tolerance
Default residual tolerance.
Definition batch_solver_base.hpp:197
std::shared_ptr< const BatchLinOpFactory > preconditioner
The preconditioner to be used by the iterative solver.
Definition batch_solver_base.hpp:211
::gko::batch::stop::tolerance_type tolerance_type
Specifies which type of tolerance check is used: absolute, or relative to the L2 norm of the right-hand side.
Definition batch_solver_base.hpp:204
std::shared_ptr< const BatchLinOp > generated_preconditioner
Already generated preconditioner.
Definition batch_solver_base.hpp:218