diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 15465a404c8a34f64831e57baed096d1384bd1fc..40a6d551c5f09c4a1b28641223e709b8a40ac409 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -24,17 +24,15 @@ jobs: - name: Formatting check run: - cd faer-libs && cargo fmt --all -- --check && - cd ../faer-entity && + cd ./faer-entity && cargo fmt --all -- --check # want to get all quality issues continue-on-error: true - name: Linting check run: - cd faer-libs && cargo clippy --all-targets && - cd ../faer-entity && + cd ./faer-entity && cargo clippy --all-targets continue-on-error: true diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index f2b902ab87093e466f52e060d841baddeaa277f7..ab00d53d1d663539e4d7bfd46ed809a199c7b259 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -34,9 +34,8 @@ jobs: - name: Verify 1.67.0 run: - cd faer-libs && cargo check && - cd ../faer-entity && + cd ./faer-entity && cargo check testing: @@ -70,7 +69,7 @@ jobs: uses: taiki-e/install-action@cargo-llvm-cov - name: Collect coverage data - run: cd faer-libs && cargo llvm-cov nextest --lcov --output-path lcov.info --workspace + run: cargo llvm-cov nextest --lcov --output-path lcov.info --workspace - name: Upload coverage data to codecov uses: codecov/codecov-action@v3 diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..169858d0611c418459196e751ddb79defda76bb6 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,58 @@ +[package] +name = "faer" +version = "0.18.0" +edition = "2021" + +[dependencies] +bytemuck = "1.14.3" +coe-rs = "0.1.2" +dbgf = "0.1.1" +paste = "1.0.14" +reborrow = "0.5.5" + +dyn-stack = "0.10.0" +equator = "0.1.10" +faer-entity = { version ="0.17.0", default-features = false, path = "./faer-entity" } + +gemm = { version = "0.17.1", default-features = false } +num-complex = { version = "0.4.5", default-features = false } +num-traits = { version = "0.2.18", default-features = false } + +matrixcompare-core = { version = "0.1.0", optional = true } +matrixcompare = { version = "0.3", optional = true } + +rayon = { version = "1.8.1", optional = true } +serde = { version = "1", optional = true, features = ["derive"] } +log = { version = "0.4", optional = true, default-features = false } +npyz = { version = "0.8", optional = true } + +[features] +default = ["std", "rayon", "serde"] +std = [ + "faer-entity/std", + "gemm/std", + "matrixcompare-core", + "matrixcompare", + "num-traits/std", + "num-complex/std", +] +rayon = ["std", "gemm/rayon", "dep:rayon"] +nightly = ["faer-entity/nightly", "gemm/nightly"] +perf-warn = ["log"] +serde = ["dep:serde"] +npy = ["std", "dep:npyz"] + +[dev-dependencies] +amd = "0.2.2" +assert_approx_eq = "1.1.0" +matrix-market-rs = "0.1.3" +matrixcompare = "0.3.0" +rand = "0.8.5" +serde_test = "1.0.176" + +[profile.dev] +opt-level = 3 + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] diff --git a/book/Cargo.toml b/book/Cargo.toml deleted file mode 100644 index 334a817812a6c01ec9196277c4472767e445da21..0000000000000000000000000000000000000000 --- a/book/Cargo.toml +++ /dev/null @@ -1,17 +0,0 @@ -[package] -name = "faer-book" -version = "0.0.0" -edition = "2021" - -[dependencies] -faer = "0.15.0" -faer-core = "0.15.0" -faer-cholesky = "0.15.0" -faer-qr = "0.15.0" -faer-lu = "0.15.0" -faer-svd = "0.15.0" -faer-evd = "0.15.0" - -[[bin]] -name = "intro" -path = "src/intro.rs" diff --git a/book/brand-rust.svg b/book/brand-rust.svg deleted file mode 100644 index dd830c261bf3675146ac5456a152fce57a1b941b..0000000000000000000000000000000000000000 --- a/book/brand-rust.svg +++ /dev/null @@ -1,57 +0,0 @@ - - - \ No newline at end of file diff --git a/book/dev_guide.typ b/book/dev_guide.typ deleted file mode 100644 index 131a0968df162e2c209b255c9031c52c6646de3c..0000000000000000000000000000000000000000 --- a/book/dev_guide.typ +++ /dev/null @@ -1,1034 +0,0 @@ -#set text(font: "New Computer Modern") - -#show raw: set text(font: "New Computer Modern Mono", size: 1.2em) - -#show par: set block(spacing: 0.55em) - -#show heading: set block(above: 1.4em, below: 1em) - -#show link: underline - -#set page(numbering: "1") - -#set par(leading: 0.55em, justify: true) - -#set heading(numbering: "1.1") - -#show heading.where(level: 1): it => pagebreak(weak:true) + block({ - set text(font: "New Computer Modern", weight: "black") - v(2cm) - block(text(18pt)[Chapter #counter(heading).display()]) - v(1cm) - block(text(22pt)[#it.body]) - v(1cm) -}) - -#import "@preview/codly:0.1.0" -#import "@preview/tablex:0.0.6": tablex, rowspanx, colspanx, gridx, hlinex, vlinex -#import "@preview/colorful-boxes:1.2.0": colorbox - -#let icon(codepoint) = { - box( - height: 0.8em, - baseline: 0.05em, - image(codepoint) - ) - h(0.1em) -} - -#show: codly.codly-init.with() - -#codly.codly( - languages: ( - rust: (name: "Rust", icon: icon("brand-rust.svg"), color: rgb("#CE412B")), - ), - breakable: false, - width-numbers: none, -) - -#outline() - -== Introduction -_`faer-rs`_ is a general-purpose linear algebra library for the Rust -programming language, with a focus on correctness, portability, and -performance. -In this book, we'll be assuming version `0.16.0` of the library. - -_`faer`_ is designed around a high level API that sacrifices some amount of -performance and customizability in exchange for ease of use, as well as a low -level API that offers more control over memory allocations and multithreading -capabilities. The two APIs share the same data structures and can be used -together or separately, depending on the user's needs. - -This book assumes some level of familiarity with Rust, linear algebra and _`faer`_'s API. -Users who are new to the library are encouraged to get started by taking a look -at the user guide, the library's examples directory -#footnote[`faer-rs/faer-libs/faer/examples`] and browsing the `docs.rs` -documentation #footnote[https://docs.rs/faer/0.16.0/faer/index.html]. - -We will go into detail over the various operations and matrix decompositions -that are provided by the library, as well as their implementation details. We -will also explain the architecture of _`faer`_'s data structures and how low -level operations are handled using vectorized SIMD instructions. - -#pagebreak() - -= Data layout and the `Entity` trait - -In most linear algebra libraries, matrix data is stored contiguously in memory, -regardless of the scalar type. This can be done in two ways, either a row-major -layout or a column-major layout. - -Consider the matrix -$ mat( - a_11, a_12; - a_21, a_22; - a_31, a_32; -) $ -Storing it in row-major layout would place the values in memory in the following order: -$ ( - a_11, a_12, - a_21, a_22, - a_31, a_32 -), $ -while storing it in column-major order would place the values in memory in this order: -$ ( - a_11, a_21, a_31, - a_12, a_22, a_32 -). $ - -_`faer`_, on the other hand, first splits each scalar into its atomic units, -then stores each unit matrix separately in a contiguous fashion. The library -does not mandate the usage of one layout or the other, but heavily prefers to receive -data in column-major layout, with the notable exception of matrix multiplication which -we try to optimize for both column-major and row-major layouts. - -The way in which a scalar can be split is chosen by the scalar type itself. -For example, a complex floating point type may choose to either be stored as one unit -or as a group of two units. - -Given the following complex matrix: -$ mat( - a_11 + i b_11, a_12 + i b_12; - a_21 + i b_21, a_22 + i b_22; - a_31 + i b_31, a_32 + i b_32; -), $ -and assuming column-major layout, we can either choose the following storage scheme in which -the full number is considered a single unit: -$ ( - a_11, b_11, a_21, b_21, a_31, b_31, - a_12, b_12, a_22, b_22, a_32, b_32 -), $ - -or the following scheme in which the real and imaginary parts are considered two distinct units -$ ( - a_11, a_21, a_31, - a_12, a_22, a_32 -),\ -( - b_11, b_21, b_31, - b_12, b_22, b_32 -). $ - -The former is commonly referred to as AoS layout (array of structures), while -the latter is called SoA (structure of arrays). The choice of which one to use -depends on the context. As a general rule, types that are natively vectorizable -(have direct CPU support for arithmetic operations) prefer to be laid out in -AoS layout. On the other hand, types that do not have native vectorization -support but can still be vectorized by combining more primitive operations -prefer to be laid out in SoA layout. - -Types that are not vectorizable may be in either one, but the AoS layout is -typically easier to work with in that scenario. - -== `Entity` trait -The `Entity` trait determines how a type prefers to be stored in memory, -through its associated type `Group`. - -Given some type `E` that implements `Entity`, we can manipulate groups of -arbitrary types in a generic way. - -For example, `faer_core::GroupFor` is an `E`-group of `E::Unit`, which can be -thought of as a raw representation of `E`. - -Pre-existing data can be referred to using a reference to a slice or a raw -pointer, for example `GroupFor`. - -The `Entity` trait requires associated functions to convert from one `E`-group type to another. -For example, we can take a reference to each element in a group with -`E::faer_as_ref`, or `E::faer_as_mut`. - -```rust -use faer_core::{Entity, GroupFor}; - -fn value_to_unit_references(value: E) { - let units: GroupFor = value.into_units(); - let references: GroupFor = E::faer_as_ref(&units); -} -``` - -We can map one group type to another using `E::faer_map`. -```rust -use faer_core::{Entity, GroupFor}; - -fn slice_to_ptr( - slice: GroupFor -) -> GroupFor { - E::faer_map(slice, |slice| slice.as_ptr()) -} -``` - -We can also zip and unzip groups of values with `E::faer_zip` and `E::faer_unzip`. -```rust -use faer_core::{Entity, GroupFor}; - -unsafe fn ptr_to_slice<'a, E: Entity>( - ptr: GroupFor, - len: GroupFor -) -> GroupFor { - let zipped: GroupFor = E::faer_zip(ptr, len); - E::faer_map(zipped, |(ptr, len)| std::slice::from_raw_parts(ptr, len)) -} - -unsafe fn split_at( - slice: GroupFor, - mid: usize -) -> (GroupFor, GroupFor) { - E::faer_unzip(E::faer_map(slice, |slice| slice.split_at(mid))) -} -``` - -== Matrix layout -Matrices in _`faer`_ fall into two broad categories with respect to layout. Owned -matrices (`Mat`) which are always stored in column-major layout, and matrix views -(`MatRef`/`MatMut`) which allow any strided layout. - -Note that even though matrix views allow for any row and column stride, they -are still typically optimized for column major layout, since that happens to be -the preferred layout for most matrix decompositions. - -Matrix views are roughly defined as: -```rust -struct MatRef<'a, E: Entity> { - ptr: GroupFor, - nrows: usize, - ncols: usize, - row_stride: isize, - col_stride: isize, - __marker: PhantomData<&'a E>, -} - -struct MatMut<'a, E: Entity> { - ptr: GroupFor, - nrows: usize, - ncols: usize, - row_stride: isize, - col_stride: isize, - __marker: PhantomData<&'a mut E>, -} -``` - -The actual implementation is slightly different in order to allow `MatRef` to -have `Copy` semantics, as well as make use of the fact that `ptr` is never null -to allow for niche optimizations (such as `Option>` having the -same layout as `MatRef<'_, E>`). - -`ptr` is a group of non-null pointers to units, each pointing to a matrix with an -underlying contiguous allocation. In other words, even though the data itself -is strided, it has to have a contiguous underlying storage in order to allow -for pointer arithmetic to be valid. - -`nrows`, `ncols`, `row_stride` and `col_stride` are the matrix dimensions and -strides, which must be the same for every unit matrix in the group. - -Finally, `__marker` imbues `MatRef` and `MatMut` with the correct variance, -This allows `MatRef<'short_lifetime, E>` to be a subtype of -`MatRef<'long_lifetime, E>`, which allows for better ergonomics. - -In addition to `Copy` semantics for `MatRef`, both `MatRef` and `MatMut` naturally -provide `Move` semantics, as do most Rust types. On top of that, they also provide -`Reborrow` semantics, which currently need to be explicitly used, unlike native -references which are implicitly reborrowed. - -#pagebreak() - -Reborrowing is the act of temporarily borrowing a matrix view as another matrix -view with a shorter lifetime. For example, given a `MatMut<'a, E>`, we would like -to pass it to functions taking `MatMut<'_, E>` by value without having to consume -our object. Unlike `MatRef<'a, E>`, this is not done automatically as `MatMut` -is not `Copy`. The solution is to mutably reborrow our `MatMut` object like this -```rust -fn function_taking_mat_ref(mat: MatRef<'_, E>) {} -fn function_taking_mat_mut(mat: MatMut<'_, E>) {} - -fn mutable_reborrow_example(mut mat: MatMut<'_, E>) { - use faer::prelude::*; - - function_taking_mat_mut(mat.rb_mut()); - function_taking_mat_mut(mat.rb_mut()); - function_taking_mat_ref(mat.rb()); - function_taking_mat_ref(mat.rb()); - function_taking_mat_mut(mat); - - // does not compile, since `mat` was moved in the previous call - // function_taking_mat_mut(mat); -} -``` - -Owned matrices on the other hand are roughly defined as: -```rust -struct Mat { - ptr: GroupFor, - nrows: usize, - ncols: usize, - row_capacity: usize, - col_capacity: usize, - __marker: PhantomData, -} - -impl Drop for Mat { - fn drop(&mut self) { - // deallocate the storage - } -} -``` -Unlike matrix views, we don't need to explicitly store the strides. We know that -the row stride is equal to `1`, since the layout is column major, and the column -stride is equal to `row_capacity`. - -We also have two new fields: `row_capacity` and `col_capacity`, which represent -how much storage we have for resizing the matrix without having to reallocate. - -`Mat` can be converted to `MatRef` using `Mat::as_ref(&self)` or `MatMut` using -`Mat::as_mut(&mut self)`. - -= Vector operations -== Componentwise operations -Componentwise operations are operations that take $n$ matrices with matching -dimensions, producing an output of the same shape. Addition and subtraction -are examples of commonly used componentwise operations. - -Componentwise operations can be expressed in _`faer`_ using the `zipped!` -macro, followed by a call to `for_each` (for in-place iteration) or `map` (for -producing an output value). - -```rust -use faer_core::{zipped, unzipped}; - -fn a_plus_3b(a: MatRef<'_, f64>, b: MatRef<'_, f64>) -> Mat { - zipped!(a, b).map(|unzipped!(a, b)| { - *a + 3.0 * *b - }) -} - -fn swap_a_b(a: MatMut<'_, f64>, b: MatMut<'_, f64>) { - zipped!(a, b).for_each(|unzipped!(mut a, mut b)| { - (*a, *b) = (*b, *a); - }) -} -``` - -`zipped!` function calls can be more efficient than naive nested loops. The -reason for this is that `zipped!` analyzes the layout of the input matrices in -order to determine the optimal iteration order. For example whether it should -iterate over rows first, before columns. Or whether the iteration should happen -in reverse order (starting from the last row/column) instead of the forward -order. - -Currently, `zipped!` determines the iteration order based on the preferred -iteration order of the first matrix, but this may change in a future release. - -== Vectorized operations -SIMD (Single Instruction, Multiple Data) refers to the usage of CPU instructions -that take vectors of inputs, packed together in CPU registers, and perform the -same operation on all of them. As an example, classic addition takes two scalars -as an input and produces one output, while SIMD addition could take two vectors, -each containing 4 scalars, and adds them componentwise, producing an output vector -of 4 scalars. Correct SIMD usage is a crucial part of any linear algebra -library, given that most linear algebra operations lend themselves well to -vectorization. - -== SIMD with _`pulp`_ - -_`faer`_ provides a common interface for generic and composable SIMD, using the -_`pulp`_ crate as a backend. _`pulp`_'s high level API abstracts away the differences -between various instruction sets and provides a common API that's generic over -them (but not the scalar type). This allows users to write a generic implementation -that gets turned into several functions, one for each possible instruction set -among a predetermined subset. Finally, the generic implementation can be used along -with an `Arch` structure that determines the best implementation at runtime. - -Here's an example of how _`pulp`_ could be used to compute the expression $x^2 + -2y - |z|$, and store it into an output vector. - -```rust -use core::iter::zip; - -fn compute_expr(out: &mut[f64], x: &[f64], y: &[f64], z: &[f64]) { - struct Impl<'a> { - out: &'a mut [f64], - x: &'a [f64], - y: &'a [f64], - z: &'a [f64], - } - - impl pulp::WithSimd for Impl<'_> { - type Output = (); - - #[inline(always)] - fn with_simd(self, simd: S) { - let Self { out, x, y, z } = self; - - let (out_head, out_tail) = S::f64s_as_mut_simd(out); - let (x_head, x_tail) = S::f64s_as_simd(x); - let (y_head, y_tail) = S::f64s_as_simd(y); - let (z_head, z_tail) = S::f64s_as_simd(z); - - let two = simd.f64s_splat(2.0); - for (out, (&x, (&y, &z))) in zip( - out_head, - zip(x_head, zip(y_head, z_head)), - ) { - *out = simd.f64s_add( - x, - simd.f64s_sub(simd.f64s_mul(two, y), simd.f64s_abs(z)), - ); - } - - for (out, (&x, (&y, &z))) in zip( - out_tail, - zip(x_tail, zip(y_tail, z_tail)), - ) { - *out = x - 2.0 * y - z.abs(); - } - } - } - - pulp::Arch::new().dispatch(Impl { out, x, y, z }); -} -``` - -There's a lot of things going on at the same time in this code example. Let us -go over them step by step. - -_`pulp`_'s generic SIMD implementation happens through the `WithSimd` trait, -which takes `self` by value to pass in the function parameters. It additionally -provides another parameter to `with_simd` describing the instruction set being -used. `WithSimd::with_simd` *must* be marked with the `#[inline(always)]` attribute. -Forgetting to do so could lead to a significant performance drop. - -Inside the body of the function, we split up each of `out`, `x`, `y` and -`z` into two parts using `S::f64s_as[_mut]_simd`. The first part (`head`) is a -slice of `S::f64s`, representing the vectorizable part of the original slice. -The second part (`tail`) contains the remainder that doesn't fit into a vector -register. - -Handling the head section is done using vectorized operation. Currently these -need to take `simd` as a parameter, in order to guarantee its availability in a -sound way. This is what allows the API to be safe. The tail section is handled -using scalar operations. - -The final step is actually calling into our SIMD implementation. This is done -by creating an instance of `pulp::Arch` that performs the runtime detection -(and caches the result, so that future invocations are as fast as possible), -then calling `Arch::dispatch` which takes a type that implements `WithSimd`, -and chooses the best SIMD implementation for it. - -=== Memory alignment - -Instead of splitting the input and output slices into two sections -(vectorizable head + non-vectorizable tail), an alternative approach would be -to split them up into three sections instead (vectorizable head + vectorizable -body + vectorizable tail). This can be accomplished using masked loads and -stores, which can speed things up if the slices are _similarly aligned_. - -Similarly aligned slices are slices which have the same base address modulo -the byte size of the CPU's vector registers. The simplest way to guarantee this -is to allocate the slices in aligned memory (such that the base address is a -multiple of the register size in bytes), in which case the slices are similarly -aligned, and any subslices of them (with a shared offset and size) will also be -similarly aligned. Aligned allocation is done automatically for matrices in _`faer`_, -which helps uphold these guarantees for maximum performance. - -Here's an example of how one might write an implementation that makes use of -memory alignment, using _`pulp`_. - -```rust -use core::iter::zip; -use pulp::{Read, Write}; - -#[inline(always)] -fn compute_expr_register( - simd: S, - mut out: impl Write, - x: impl Read, - y: impl Read, - z: impl Read, -) { - let zero = simd.f64s_splat(0.0); - let x = x.read_or(zero); - let y = y.read_or(zero); - let z = z.read_or(zero); - let two = simd.f64s_splat(2.0); - out.write(simd.f64s_add( - x, - simd.f64s_sub(simd.f64s_mul(two, y), simd.f64s_abs(z)), - )); -} -impl pulp::WithSimd for Impl<'_> { - type Output = (); - #[inline(always)] - fn with_simd(self, simd: S) { - let Self { out, x, y, z } = self; - let offset = simd.f64s_align_offset(out.as_ptr(), out.len()); - - let (out_head, out_body, out_tail) = - simd.f64s_as_aligned_mut_simd(out, offset); - let (x_head, x_body, x_tail) = simd.f64s_as_aligned_simd(x, offset); - let (y_head, y_body, y_tail) = simd.f64s_as_aligned_simd(y, offset); - let (z_head, z_body, z_tail) = simd.f64s_as_aligned_simd(z, offset); - - compute_expr_register(simd, out_head, x_head, y_head, z_head); - for (out, (x, (y, z))) in zip( - out_body, - zip(x_body, zip(y_body, z_body)), - ) { - compute_expr_register(simd, out, x, y, z); - } - compute_expr_register(simd, out_tail, x_tail, y_tail, z_tail); - } -} -``` - -_`faer`_ adds one more abstraction layer on top of _`pulp`_, in order to make the -SIMD operations generic over the scalar type. This is done using the -`faer_core::group_helpers::SimdFor` struct that's effectively a thin -wrapper over `S`, and only exposes operations specific to the type `E`. - -Here's how one might implement the previous operation for a generic real scalar type. - -```rust -use faer_core::group_helpers::{SliceGroup, SliceGroupMut}; -use faer_core::RealField; - -struct Impl<'a, E: RealField> { - out: SliceGroupMut<'a, E>, - x: SliceGroup<'a, E>, - y: SliceGroup<'a, E>, - z: SliceGroup<'a, E>, -} -``` - -`&[f64]` and `&mut [f64]` are replaced by `SliceGroup<'_, E>` and -`SliceGroupMut<'_, E>`, to accomodate the fact that `E` might be an SoA type -that wants to be decomposed into multiple units. Aside from that change, most -of the code looks similar to what we had before. - -```rust -use core::iter::zip; -use faer_core::{RealField, SimdGroupFor, group_helpers::SimdFor}; -use pulp::{Read, Write}; -use reborrow::*; -#[inline(always)] -fn compute_expr_register( - simd: SimdFor, - mut out: impl Write>, - x: impl Read>, - y: impl Read>, - z: impl Read>, -) { - let zero = simd.splat(E::faer_zero()); - let two = simd.splat(E::faer_from_f64(2.0)); - let x = x.read_or(zero); - let y = y.read_or(zero); - let z = z.read_or(zero); - out.write(simd.add(x, simd.sub(simd.mul(two, y), simd.abs(z)))); -} - -``` - -```rust -impl pulp::WithSimd for Impl<'_, E> { - type Output = (); - #[inline(always)] - fn with_simd(self, simd: S) { - let Self { out, x, y, z } = self; - let simd = SimdFor::::new(simd); - let offset = simd.align_offset(out.rb()); - let (out_head, out_body, out_tail) = - simd.as_aligned_simd_mut(out, offset); - let (x_head, x_body, x_tail) = simd.as_aligned_simd(x, offset); - let (y_head, y_body, y_tail) = simd.as_aligned_simd(y, offset); - let (z_head, z_body, z_tail) = simd.as_aligned_simd(z, offset); - compute_expr_register(simd, out_head, x_head, y_head, z_head); - for (out, (x, (y, z))) in zip( - out_body.into_mut_iter(), - zip( - x_body.into_ref_iter(), - zip(y_body.into_ref_iter(), z_body.into_ref_iter()) - ), - ) { - compute_expr_register(simd, out, x, y, z); - } - compute_expr_register(simd, out_tail, x_tail, y_tail, z_tail); - } -} -``` - -== SIMD reductions -The previous examples focused on _vertical_ operations, which compute the -output of a componentwise operation and storing the result in an output vector. -Another interesting kind of operations is _horizontal_ ones, which accumulate -the result of one or more vector into one or more scalar values. - -One example of this is the dot product, which takes two vectors $a$ and $b$ of -size $n$ and computes $sum_(i = 0)^n a_i b_i$. - -One way to implement it would be like this: - -```rust -use faer_core::group_helpers::{SliceGroup, SliceGroupMut}; -use faer_core::RealField; - -struct Impl<'a, E: RealField> { - a: SliceGroup<'a, E>, - b: SliceGroup<'a, E>, -} -#[inline(always)] -fn dot_register( - simd: SimdFor, - acc: SimdGroupFor, - b: impl Read>, - a: impl Read>, -) -> SimdGroupFor { - let zero = simd.splat(E::faer_zero()); - let a = a.read_or(zero); - let b = b.read_or(zero); - simd.mul_add(a, b, acc) -} - -impl pulp::WithSimd for Impl<'_, E> { - type Output = (); - #[inline(always)] - fn with_simd(self, simd: S) { - let Self { a, b } = self; - let simd = SimdFor::::new(simd); - let offset = simd.align_offset(a); - - let (a_head, a_body, a_tail) = simd.as_aligned_simd(a, offset); - let (b_head, b_body, b_tail) = simd.as_aligned_simd(b, offset); - - let mut acc = simd.splat(E::faer_zero()); - acc = dot_register(simd, acc, a_head, b_head); - for (a, b) in zip(a_body, b_body) { - acc = dot_register(simd, acc, a, b); - } - acc = dot_register(simd, acc, a_tail, b_tail); - - simd.reduce_add(simd.rotate_left(acc, offset.rotate_left_amount())) - } -} -``` -The code looks similar to what we've written before. An interesting addition -is the use of `simd.rotate_left` in the last line. The reason for this is to -make sure our computed reduction doesn't depend on the memory offset, which can -help avoid variations in the output due to the non-associativity of floating point -arithmetic. - -For example, suppose our register size is 4 elements, and we want to compute -the dot product of 13 elements from each of $a$ and $b$. - -In the case where the memory is aligned, this is what the head, body and tail -of $a$ and $b$ look like: - -$ -a_("head") &= (a_1, a_2, a_3, a_4),\ -b_("head") &= (b_1, b_2, b_3, b_4),\ -\ -a_("body") &= [(a_5, a_6, a_7, a_8), (a_9, a_10, a_11, a_12)],\ -b_("body") &= [(b_5, b_6, b_7, b_8), (b_9, b_10, b_11, b_12)],\ -\ -a_("tail") &= (a_13, 0, 0, 0),\ -b_("tail") &= (b_13, 0, 0, 0).\ -$ - -Right before we perform the rotation, the accumulator contains the following result -$ - "acc"_1 &= a_1 b_1 + a_5 b_5 + a_9 b_9 + a_13 b_13,\ - "acc"_2 &= a_2 b_2 + a_6 b_6 + a_10 b_10 ,\ - "acc"_3 &= a_3 b_3 + a_7 b_7 + a_11 b_11 ,\ - "acc"_4 &= a_4 b_4 + a_8 b_8 + a_12 b_12 .\ -$ - -If we assume the reduction operation `simd.reduce_add` sums the elements -sequentially, we get the final result: -$ - "acc"_"aligned" = &(a_1 b_1 + a_5 b_5 + a_9 b_9 + a_13 b_13)\ - &+ (a_2 b_2 + a_6 b_6 + a_10 b_10 )\ - &+ (a_3 b_3 + a_7 b_7 + a_11 b_11 )\ - &+ (a_4 b_4 + a_8 b_8 + a_12 b_12 ).\ -$ - -Now let's take a look at the case where the memory is unaligned, for example -with an offset of 1. In this case $a$ and $b$ look like: - -$ -a'_("head") &= (0, a_1, a_2, a_3),\ -b'_("head") &= (0, b_1, b_2, b_3),\ -\ -a'_("body") &= [(a_4, a_5, a_6, a_7), (a_8, a_9, a_10, a_11)],\ -b'_("body") &= [(a_4, b_5, b_6, b_7), (b_8, b_9, b_10, b_11)],\ -\ -a'_("tail") &= (a_12, a_13, 0, 0),\ -b'_("tail") &= (b_12, b_13, 0, 0).\ -$ - -Right before we perform the rotation, the accumulator contains the following result -$ - "acc'"_1 &= a_4 b_4 + a_8 b_8 + a_12 b_12,\ - "acc'"_2 &= a_1 b_1 + a_5 b_5 + a_9 b_9 + a_13 b_13,\ - "acc'"_3 &= a_2 b_2 + a_6 b_6 + a_10 b_10 ,\ - "acc'"_4 &= a_3 b_3 + a_7 b_7 + a_11 b_11 .\ -$ - -If we use `simd.reduce_add` directly, without going through `simd.rotate_left` first, we get -this result: - -$ - "result"_("unaligned"(1)) = &( a_4 b_4 + a_8 b_8 + a_12 b_12)\ - +&(a_1 b_1 + a_5 b_5 + a_9 b_9 + a_13 b_13)\ - +&(a_2 b_2 + a_6 b_6 + a_10 b_10 )\ - +&(a_3 b_3 + a_7 b_7 + a_11 b_11 )\ -$ - -Mathematically, the result is equivalent, but since floating point operations round the result, -we would get a slightly different result for the aligned and unaligned cases. - -Our solution is to first rotate the accumulator to the left by the alignment offset. -Doing this right before the accumulation would give us the rotated accumulator: -$ - "rotate"_1 &=& "acc'"_2 &= a_1 b_1 + a_5 b_5 + a_9 b_9 + a_13 b_13 &&= "acc"_1,\ - "rotate"_2 &=& "acc'"_3 &= a_2 b_2 + a_6 b_6 + a_10 b_10 &&= "acc"_2,\ - "rotate"_3 &=& "acc'"_4 &= a_3 b_3 + a_7 b_7 + a_11 b_11 &&= "acc"_3,\ - "rotate"_4 &=& "acc'"_1 &= a_4 b_4 + a_8 b_8 + a_12 b_12 &&= "acc"_4.\ -$ - -Summing these sequentially would then give us the exact result as $"acc"_"aligned"$. - -#pagebreak() - -= Matrix multiplication -In this section we will give a detailed overview of the techniques used to -speed up matrix multiplication in _`faer`_. The approach we use is a -reimplementation of BLIS's matrix multiplication algorithm with some -modifications. - -Consider three matrices $A$, $B$ and $C$, such that we want to perform the operation -$ C "+=" A B. $ - -We can chunk $A$, $B$ and $C$ in a way that is compatible with matrix multiplication: -$ -A = mat( - A_11 , A_12 , ... , A_(1 k); - A_21 , A_22 , ... , A_(2 k); - dots.v , dots.v , dots.down, dots.v ; - A_(m 1), A_(m 2), ... , A_(m k); -),\ -B = mat( - B_11 , B_12 , ... , B_(1 n); - B_21 , B_22 , ... , B_(2 n); - dots.v , dots.v , dots.down, dots.v ; - B_(k 1), B_(k 2), ... , B_(k n); -),\ -C = mat( - C_11 , C_12 , ... , C_(1 n); - C_21 , C_22 , ... , C_(2 n); - dots.v , dots.v , dots.down, dots.v ; - C_(m 1), C_(m 2), ... , C_(m n); -). -$ - -Then the $C "+=" A B$ operation may be decomposed into: -#set math.mat(delim: none, column-gap: 2.0em) -$ -mat( - C_(1 1) "+=" sum_(p = 1)^(k) A_(1 p) B_(p 1), C_(1 2) "+=" sum_(p = 1)^(k) A_(1 p) B_(p 2), ... , C_(1 n) "+=" sum_(p = 1)^(k) A_(1 p) B_(p n); - C_(2 1) "+=" sum_(p = 1)^(k) A_(2 p) B_(p 1), C_(2 2) "+=" sum_(p = 1)^(k) A_(2 p) B_(p 2), ... , C_(2 n) "+=" sum_(p = 1)^(k) A_(2 p) B_(p n); - dots.v , dots.v , dots.down, dots.v ; - C_(m 1) "+=" sum_(p = 1)^(k) A_(m p) B_(p 1), C_(m 2) "+=" sum_(p = 1)^(k) A_(m p) B_(p 2), ... , C_(m n) "+=" sum_(p = 1)^(k) A_(m p) B_(p n); -). -$ - -#set math.mat(delim: "(", column-gap: 0.5em) - -Doing so does not decrease the number of flops (floating point operations). But -this restructuring step can lead to a large speedup if done correctly, by -making use of cache locality on modern CPUs/GPUs. - -The general idea revolves around memory reuse. For now, let us consider the -case of a single thread executing the entire operation. The multithreaded case -can be handled with a few adjustments. - - - The algorithm we use computes sequentially $C$ by column blocks. In other words, we first compute $C_(: 1)$, then $C_(: 2)$ and so on. - - For each column block $j$, we sequentially iterate over $p$, computing all the terms $A_(: p) B_(p j)$, and accumulate them to the output. - - Then for each row block $i$, we compute $A_(i p) B_(p j)$ and accumulate it to $C_(i j)$. - -Since most modern CPUs have a hierarchical cache structure (usually ranging -from L1 (smallest) to L3 (largest)), we would like to make use of this in our -algorithm for maximum efficiency. - -The way we exploit this is by choosing the chunk dimensions so that $B_(p j)$ -remains in the L3 cache during each iteration of the second loop, and $A_(i p)$ -remains in the L2 cache during each iteration of the third loop. This leaves us -with one more cache level to use: the L1 cache. - -We make the most use out of this by chunking the inner product once again, resulting in -two more loop levels: -$ -A_(i p) = mat( - A'_1p ; - A'_2p ; - dots.v ; - A'_(m' p); -),\ -B_(p j) = mat( - B'_(p 1) , B_(p 2) , ... , B_(p n'); -),\ -C_(i j) = mat( - C'_(1 1) , C'_(1 2) , ... , C'_(1 n') ; - C'_(2 1) , C'_(2 2) , ... , C'_(2 n') ; - dots.v , dots.v , dots.down, dots.v ; - C'_(m' 1), C'_(m' 2), ... , C'_(m' n'); -). -$ - -We iterate over each column block $B'_(p j')$, then over each row block -$A'_(i' p)$, -and accumulate the product $A'_(i' p) B'_(p j')$ to $C'_(i' j')$. - -In the outer loop, $B'_(p j')$ is brought from the L3 cache into the L1 cache, -and stays there until the outer loop iteration is done. -Then in the inner loop, it is brought once again from the L1 cache into -registers, while $A'_(i' p)$ is brought from the L2 cache into registers, -so that the computation can be performed. - -This last step is done using a vectorized microkernel, which heavily uses -SIMD instructions to maximize efficiency. -The number of registers limits the size of the microkernel. - -During each iteration $p'$ of the microkernel, we can bring in $m_r$ elements -from the $A'_(i' p')$, and $n_r$ elements from $B'_(p' j')$, where $m_r$ -and $n_r$ are the dimensions of the microkernel (as well as the dimensions of -each block $C'_(i' j')$). We use the following algorithm: - - - - Load one element from $B'$, - - Load $m_r$ elements from $A'$, - - Multiply the $m_r$ elements from $A'$ by the element from $B'$, and accumulate the result to $C'$ - -Consider x86-64 as an example, with the AVX2+FMA instruction set (256-bit -registers), and suppose the scalar type has a size of 64 bits so that each -register can hold $N = 256/64 = 4$ scalars. We have 16 available registers in -total, this means we can load one element from $B'$ into one register, and $m_r$ -elements from $A'$ into $m_r / N$ registers. - -Since we don't want to constantly read and write to $C'$ from main memory, we use a local accumulator -that occupies $m_r / N n_r$ registers. - -In this case we have a total of 16 available registers, which need to hold $1 + -m_r / N + m_r / N n_r$ registers. A good choice for our case is $m_r = 3N$, -$n_r = 4$, which requires exactly 16 registers. - -To determine the chunk sizes, we compute them starting from the innermost loop -to the outermost loop. - -Given that we've already computed $m_r$ and $n_r$, we determine $k_c$ (the -number of columns of $A_(i p)$, and also the number of rows of $B_(p j)$) so -that $B'_(p, j')$ fits into the L1 cache, then we determine $m_c$ (the number -of rows of $A_(i p)$) so that $A_(i p)$ fits into the L2 cache. -And finally we determine $n_c$ (the number of columns of $B_(p j)$) so that -$B_(: j)$ fits into the L3 cache. - -Note that bringing data into the cache is typically done automatically by the CPU. -However, in our case, we want to perform that explicitly by storing each L3/L2 chunk -into packed storage, which allows for contiguous access that's friendly to the CPU's -hardware prefetcher and minimizes TLB (Translation Lookaside Buffer) misses. - -In order to parallelize the algorithm, we have a few options. The second -loop can't be easily parallelized without allocating extra storage for the -accumulators, since we have a data dependency in the second loop. -The microkernel also doesn't perform enough work to compensate for the overhead -of core synchronization, so it also makes for a poor multithreading candidate. - -This leaves us with the first loop, as well as the third, fourth and fifth -loops. The first loop stores its data in the L3 cache, which is typically -shared between cores. So it's not usually a very attractive candidate for -parallelization. - -Loops three through five however can be parallelized in a straightforward way, -and make good use of each core's separate L1 (and often L2) cache, which leads -to a significant speedup. - -Since data is packed explicitly during matrix multiplication, the original -layout of the input matrices has little effect on efficiency when the -dimensions are medium or large. This has the side-effect of matrix multiplication -being highly efficient regardless of the matrix layout. - -== Special cases -In this section, we will refer to matrix multiplication with dimensions $(m, n) -× (n, k)$ with $(m, n, k)$ as a shorthand. - -For special common matrix dimensions, we do not usually want to go through the -aforementioned strategy, because the packing and unpacking steps, as well as -the microkernel indirection can add considerable overhead. - -Such cases include: -- inner product $(1, 1, k)$, -- outer product: $(m, n, 1)$, -- matrix-vector: $(m, 1, k)$, -- vector-matrix: $(1, n, k)$, which can be rewritten in terms of matrix-vector by transposing, if we assume that the scalar multiplication is commutative (which _`faer`_ generally does), since $C "+=" A B <=> C^top "+=" B^top A^top$. - -The $(1, 1, k)$ case can be optimized for when $A$ is row-major and $B$ is -column-major, and is written similarly to our previous dot product example, -with one difference: Instead of using one accumulator for the result, we use -multiple accumulators, and then sum them together at the end. This can speed up -the computation by making use of instruction level parallelism, since each -accumulator can be computed independently from the others. - -For the $(m, n, 1)$ case, we assume $C$ is column-major. If it is row-major, we -can implicitly transpose the matrix multiply operation. The algorithm we use -consists of computing $C$ column by column, which is equivalent to $C_(: j) "+=" -A b_(1 j)$. This can be vectorized as a vertical operation, if $A$ is column-major. -If it is not, we can store it to contiguous temporary storage before performing -the computation. Note that this is a relatively cheap operation since its -dimensions are $(m, 1)$, which is usually much smaller than the size of $C$: -$(m, n)$. - -For the $(m, 1, k)$ case, there are two interesting cases. The first one is -when $A$ is column major. In this case we assume $C$ is column major -(otherwise, we can compute the result in a temporary vector and accumulate it -to $C$ afterwards). For each column of $A$, we multiply it by the corresponding -element of $B$ and accumulate it to $C$. The inner kernel for this operation is -$C "+=" A_(: k) b_(k 1)$, which is essentially the same as the one from the outer -product. - -When $A$ is row-major, we assume $B$ is column-major, and compute $C_(i 1) "+=" -A_(i :) B_(: 1)$, which uses the same kernel as the $(1, 1, k)$ case. - -== Triangular matrix products -In some cases, one of the matrices $A$ and $B$ is triangular (with a possibly -implicit zero or one diagonal), or we only want to compute the lower or upper -half of the output. _`faer`_ currently uses recursive implementations that are -padded and handled as rectangular matrix multiplication in the base case. For -example, we may want to compute $A B$ where $A$ is lower triangular and $B$ is -upper triangular: - -$ -mat( - C_(1 1), C_(1 2); - C_(2 1), C_(2 2); -) "+=" -mat( - A_(1 1), 0; - A_(2 1), A_(2 2); -) -mat( - B_(1 1), B_(1 2); - 0 , B_(2 2); -) -$ - -This can be split up into a sequence of products: - -$ -C_(1 1) & "+=" A_(1 1) B_(1 1),\ -C_(1 2) & "+=" A_(1 1) B_(1 2),\ -C_(2 1) & "+=" A_(2 1) B_(1 1),\ -C_(2 2) & "+=" A_(2 1) B_(1 2) + A_(2 2) B_(2 2). -$ - -The steps $C_(1 1) "+=" A_(1 1) B_(1 1)$ and $C_(2 2) "+=" A_(2 2) B_(2 2)$ -are also matrix prodcuts where the LHS is lower triangular and the RHS is upper -triangular, so we call the algorithm recursively for them. - -All of these products can be performed either sequentially or in parallel. -In the parallel case, we group them like this to avoid load imbalances between -threads. - -#set align(center) -#gridx( - columns: (auto, auto), - align: center, - "Thread 1", vlinex(), "Thread 2", - $ - C_(1 1) & "+=" A_(1 1) B_(1 1)\ - C_(2 2) & "+=" A_(2 1) B_(1 2) + A_(2 2) B_(2 2) - $, (), - $ - C_(1 2) & "+=" A_(1 1) B_(1 2)\ - C_(2 1) & "+=" A_(2 1) B_(1 1)\ - $, -) -#set align(left) - -This way each thread performs roughtly the same number of flops, which helps -avoid idling threads that spend time waiting for the others to finish. -Every time we recurse to another triangular matrix multiplication, we can split -up the work again. And if we perform a rectangular matrix multiply, we can rely on -its inherent parallelism. - -This is one of the scenarios where _`faer`_'s fine control over multithreading shines, -as we can provide a hint for each nested operation that it doesn't need to use -all the available cores, which can reduce synchronization overhead and conflict -over shared resources by the different threads. - -#colorbox( - title: "PERF", - color: "blue", - radius: 2pt, - width: auto -)[ - The current strategy doesn't take advantage of the CPU's cache hierarchy for - deciding how to split up the work between threads. This could lead to - multiple threads contending over the L3 cache for very large matrices. - - In that case it could be worth investigating if it's better to only use - triangular threading when the sizes fall below a certain threshold. -] - -#colorbox( - title: "PERF", - color: "blue", - radius: 2pt, - width: auto -)[ - Specialized implementations for specific dimensions, similarly to what is - done for the rectangular matrix multiply. For example a lower triangular - matrix times a column vector. -] - -== Triangular matrix solve -Solving systems of the form $A X = B$ (where $A$ is a triangular matrix) is -another primitive that is implemented using a recursive implementation that -relies on matrix multiplication. For simplicity, we assume $A$ is lower -triangular. The cases where $A$ is unit lower triangular or [unit] upper -triangular are handled similarly. - -If we decompose $A$, $B$ and $X$ as follows: -$ -A = mat(A_(1 1), ; A_(2 1), A_(2 2)), -B = mat(B_1; B_2), -X = mat(X_1; X_2), -$ -then the system can be reformulated as -$ -A_(1 1) X_1 &= B_1,\ -A_(2 2) X_2 &= B_2 - A_(2 1) X_1. -$ -This system can be solved sequentially by first solving the first equation, -then substituting its solution in the second equation. Moreover, the system can -be solved in place, by taking an inout parameter that contains $B$ on input, -and $X$ on output. - -Once the recursion reaches a certain threshold, we fall back to a sequential -implementation to avoid the recursion overhead. diff --git a/book/user_guide.typ b/book/user_guide.typ deleted file mode 100644 index 2aaacaf29abc0b1be039fdfb941dba2a7765dc09..0000000000000000000000000000000000000000 --- a/book/user_guide.typ +++ /dev/null @@ -1,512 +0,0 @@ -#set text(font: "New Computer Modern") - -#show raw: set text(font: "New Computer Modern Mono", size: 1.2em) - -#show par: set block(spacing: 0.55em) - -#show heading: set block(above: 1.4em, below: 1em) - -#show link: underline - -#set page(numbering: "1") - -#set par(leading: 0.55em, justify: true) - -#set heading(numbering: "1.1") - -#show heading.where(level: 1): it => pagebreak(weak:true) + block({ - set text(font: "New Computer Modern", weight: "black") - v(2cm) - block(text(18pt)[Chapter #counter(heading).display()]) - v(1cm) - block(text(22pt)[#it.body]) - v(1cm) -}) - -#import "@preview/codly:0.1.0" -#import "@preview/tablex:0.0.6": tablex, rowspanx, colspanx, gridx, hlinex, vlinex -#import "@preview/colorful-boxes:1.2.0": colorbox - -#let icon(codepoint) = { - box( - height: 0.8em, - baseline: 0.05em, - image(codepoint) - ) - h(0.1em) -} - -#show: codly.codly-init.with() - -#codly.codly( - languages: ( - rust: (name: "Rust", icon: icon("brand-rust.svg"), color: rgb("#CE412B")), - ), - breakable: false, - width-numbers: none, -) - -#outline() - -== Introduction -_`faer-rs`_ is a general-purpose linear algebra library for the Rust -programming language, with a focus on correctness, portability, and -performance. -In this book, we'll be assuming version `0.16.0` of the library. - -A matrix is a 2-dimensional array of numerical values, which can represent -different things depending on the context. In the context of linear algebra, -it is often used to represent a linear transformation mapping vectors from -one finite-dimensional space to another. - -Column vectors are typically elements of the vector space, but may also be used -interchangeably with $n×1$ matrices. Row vectors are also similarly used -interchangeably with $1×n$ matrices. - -= Dense linear algebra -== Creating a matrix -_`faer`_ provides several ways to create matrices and matrix views. - -The main matrix types are `faer_core::Mat`, `faer_core::MatRef` and `faer_core::MatMut`, -which can be thought of as being analogous to `Vec`, `&[_]` and `&mut [_]`. - -The most flexible way to initialize a matrix is to initialize a zero matrix, -then fill out the values by hand. - -```rust -let mut a = Mat::::zeros(4, 3); - -for j in 0..a.ncols() { - for i in 0..a.nrows() { - a[(i, j)] = 9.0; - } -} -``` - -Given a callable object that outputs the matrix elements, `Mat::from_fn`, can also be used. - -```rust -let a = Mat::from_fn(3, 4, |i, j| (i + j) as f64); -``` - -For common matrices such as the zero matrix and the identity matrix, shorthands -are provided. - -```rust -use faer_core::Mat; - -// creates a 10×4 matrix whose values are all `0.0`. -let a = Mat::::zeros(10, 4); - -// creates a 5×4 matrix containing `0.0` except on the main diagonal, -// which contains `1.0` instead. -let a = Mat::::identity(5, 4); -``` -#colorbox( - title: "Note", - color: "green", - radius: 2pt, - width: auto -)[ - In some cases, users may wish to avoid the cost of initializing the matrix to zero, - in which case, unsafe code may be used to allocate an uninitialized matrix, which - can then be filled out before it's used. - ```rust - // `a` is initially a 0×0 matrix. - let mut a = Mat::::with_capacity(4, 3); - - // `a` is now a 4×3 matrix, whose values are uninitialized. - unsafe { a.set_dims(4, 3) }; - - for j in 0..a.ncols() { - for i in 0..a.nrows() { - // we cannot write `a[(i, j)] = 9.0`, as that would - // create a reference to uninitialized data, - // which is currently disallowed by Rust. - a.write(i, j, 9.0); - } - } - ``` -] - -== Creating a matrix view -In some situations, it may be desirable to create a matrix view over existing -data. -In that case, we can use `faer_core::MatRef` (or `faer_core::MatMut` for -mutable views). - -They can be created in a safe way using: - -`faer_core::mat::{from_column_major_slice, from_row_major_slice}`, - -`faer_core::mat::{from_column_major_slice_mut, from_row_major_slice_mut}`, - -for contiguous matrix storage, or: - -`{from_column_major_slice_with_stride, from_row_major_slice_with_stride}`, - -`{from_column_major_slice_with_stride_mut, from_row_major_slice_with_stride_mut}`, - -for strided matrix storage. - -#colorbox( - title: "Note", - color: "green", - radius: 2pt, - width: auto -)[ - A lower level pointer API is also provided for handling uninitialized data - or arbitrary strides `{from_raw_parts, from_raw_parts_mut}`. -] - -== Converting to a view - -A `Mat` instance `m` can be converted to `MatRef` or `MatMut` by writing `m.as_ref()` -or `m.as_mut()`. - -== Reborrowing a mutable view - -Immutable matrix views can be freely copied around, since they are non-owning -wrappers around a pointer and the matrix dimensions/strides. - -Mutable matrices however are limited by Rust's borrow checker. Copying them -would be unsound since only a single active mutable view is allowed at a time. - -This means the following code does not compile. - -```rust -use faer::{Mat, MatMut}; - -fn takes_view_mut(m: MatMut) {} - -let mut a = Mat::::new(); -let view = a.as_mut(); - -takes_view_mut(view); - -// This would have failed to compile since `MatMut` is never `Copy` -// takes_view_mut(view); -``` - -The alternative is to temporarily give up ownership over the data, by creating -a view with a shorter lifetime, then recovering the ownership when the view is -no longer being used. - -This is also called reborrowing. - -```rust -use faer::{Mat, MatMut, MatRef}; -use reborrow::*; - -fn takes_view(m: MatRef) {} -fn takes_view_mut(m: MatMut) {} - -let mut a = Mat::::new(); -let mut view = a.as_mut(); - -takes_view_mut(view.rb_mut()); -takes_view_mut(view.rb_mut()); -takes_view(view.rb()); // We can also reborrow immutably - -{ - let short_view = view.rb_mut(); - - // This would have failed to compile since we can't use the original view - // while the reborrowed view is still being actively used - // takes_view_mut(view); - - takes_view_mut(short_view); -} - -// We can once again use the original view -takes_view_mut(view.rb_mut()); - -// Or consume it to convert it to an immutable view -takes_view(view.into_const()); -``` - -== Splitting a matrix view, slicing a submatrix -A matrix view can be split up along its row axis, column axis or both. -This is done using `MatRef::split_at_row`, `MatRef::split_at_col` or -`MatRef::split_at` (or `MatMut::split_at_row_mut`, `MatMut::split_at_col_mut` or -`MatMut::split_at_mut`). - -These functions take the middle index at which the split is performed, and return -the two sides (in top/bottom or left/right order) or the four corners (top -left, top right, bottom left, bottom right) - -We can also take a submatrix using `MatRef::subrows`, `MatRef::subcols` or -`MatRef::submatrix` (or `MatMut::subrows_mut`, `MatMut::subcols_mut` or -`MatMut::submatrix_mut`). - -Alternatively, we can also use `MatRef::get` or `MatMut::get_mut`, which take -as parameters the row and column ranges. - -#colorbox( - title: "Warning", - color: "red", - radius: 2pt, - width: auto -)[ - Note that `MatRef::submatrix` (and `MatRef::subrows`, `MatRef::subcols`) takes - as a parameter, the first row and column of the submatrix, then the number - of rows and columns of the submatrix. - - On the other hand, `MatRef::get` takes a range from the first row and column - to the last row and column. -] - -== Matrix arithmetic operations -_`faer`_ matrices implement most of the arithmetic operators, so two matrices -can be added simply by writing `&a + &b`, the result of the expression is a -`faer::Mat`, which allows chaining operations (e.g. `(&a + &b) * &c`), although -at the cost of allocating temporary matrices. - -#colorbox( - title: "Note", - color: "green", - radius: 2pt, - width: auto -)[ - Temporary allocations can be avoided by using the zip api: -```rust -use faer::{Mat, zipped, unzipped}; - -let a = Mat::::zeros(4, 3); -let b = Mat::::zeros(4, 3); -let mut c = Mat::::zeros(4, 3); - -// Sums `a` and `b` and stores the result in `c`. -zipped!(&mut c, &a, &b).for_each(|unzipped!(c, a, b)| *c = *a + *b); - -// Sums `a`, `b` and `c` into a new matrix `d`. -let d = zipped!(&mut c, &a, &b).map(|unzipped!(c, a, b)| *a + *b + *c); -``` - For matrix multiplication, the non-allocating api is provided in the - `faer_core::mul` module. - -```rust -use faer::{Mat, Parallelism}; -use faer_core::mul::matmul; - -let a = Mat::::zeros(4, 3); -let b = Mat::::zeros(3, 5); - -let mut c = Mat::::zeros(4, 5); - -// Computes `faer::scale(3.0) * &a * &b` and stores the result in `c`. -matmul(c.as_mut(), a.as_ref(), b.as_ref(), None, 3.0, Parallelism::None); - -// Computes `faer::scale(3.0) * &a * &b + 5.0 * &c` and stores the result in `c`. -matmul(c.as_mut(), a.as_ref(), b.as_ref(), Some(5.0), 3.0, Parallelism::None); -``` -] - -== Solving a linear system -Several applications require solving a linear system of the form $A x = b$. -The recommended method can vary depending on the properties of $A$, and the -desired numerical accuracy. - -=== $A$ is triangular -In this case, one can use $A$ and $b$ directly to find $x$, using the functions -provided in `faer_core::solve`. - -```rust -use faer::{Mat, Parallelism}; -use faer_core::solve::solve_lower_triangular_in_place; - -let a = Mat::::from_fn(4, 4, |i, j| if i >= j { 1.0 } else { 0.0 }); -let b = Mat::::from_fn(4, 2, |i, j| (i - j) as f64); - -let mut x = Mat::::zeros(4, 2); -x.copy_from(&b); -solve_lower_triangular_in_place(a.as_ref(), x.as_mut(), Parallelism::None); - -// x now contains the approximate solution -``` - -In the case where $A$ has a unit diagonal, one can use -`solve_unit_lower_triangular_in_place`, which avoids reading the diagonal, and -instead implicitly uses the value `1.0` as a replacement. - -=== $A$ is real-symmetric/complex-Hermitian -If $A$ is Hermitian and positive definite, users can use the Cholesky LLT -decomposition. - -```rust -use faer::{mat, Side}; -use faer::prelude::*; - -let a = mat![ - [10.0, 2.0], - [2.0, 10.0f64], -]; -let b = mat![[15.0], [-3.0f64]]; - -// Compute the Cholesky decomposition, -// reading only the lower triangular half of the matrix. -let llt = a.cholesky(Side::Lower).unwrap(); - -let x = llt.solve(&b); -``` - -Alternatively, a lower-level API could be used to avoid temporary allocations. -The corresponding code for other decompositions follows the same pattern, so we -will avoid repeating it. - -```rust -use faer::{mat, Parallelism, Conj}; -use faer_cholesky::llt::compute::cholesky_in_place_req; -use faer_cholesky::llt::compute::{cholesky_in_place, LltRegularization, LltParams}; -use faer_cholesky::llt::solve::solve_in_place_req; -use faer_cholesky::llt::solve::solve_in_place_with_conj; -use dyn_stack::{PodStack, GlobalPodBuffer}; - -let a = mat![ - [10.0, 2.0], - [2.0, 10.0f64], -]; -let mut b = mat![[15.0], [-3.0f64]]; - -let mut llt = Mat::::zeros(2, 2); -let no_par = Parallelism::None; - -// Compute the size and alignment of the required scratch space -let cholesky_memory = cholesky_in_place_req::( - a.nrows(), - Parallelism::None, - LltParams::default(), -).unwrap(); -let solve_memory = solve_in_place_req::( - a.nrows(), - b.ncols(), - Parallelism::None, -).unwrap(); - -// Allocate the scratch space -let mut memory = GlobalPodBuffer::new(cholesky_memory.or(solve_memory)); -let mut stack = PodStack::new(&mut mem); - -// Compute the decomposition -llt.copy_from(&a); -cholesky_in_place( - llt.as_mut(), - LltRegularization::default(), // no regularization - no_par, - stack.rb_mut(), // scratch space - LltParams::default(), // default settings -); -// Solve the linear system -solve_in_place_with_conj(llt.as_ref(), Conj::No, b.as_mut(), no_par, stack); -``` - -If $A$ is not positive definite, the Bunch-Kaufman LBLT decomposition is recommended instead. -```rust -use faer::{mat, Side}; -use faer::prelude::*; - -let a = mat![ - [10.0, 2.0], - [2.0, -10.0f64], -]; -let b = mat![[15.0], [-3.0f64]]; - -// Compute the Bunch-Kaufman LBLT decomposition, -// reading only the lower triangular half of the matrix. -let lblt = a.lblt(Side::Lower); - -let x = lblt.solve(&b); -``` - -=== $A$ is square -For a square matrix $A$, we can use the LU decomposition with partial pivoting, -or the full pivoting variant which is slower but can be more accurate when the -matrix is nearly singular. - -```rust -use faer::mat; -use faer::prelude::*; - -let a = mat![ - [10.0, 3.0], - [2.0, -10.0f64], -]; -let b = mat![[15.0], [-3.0f64]]; - -// Compute the LU decomposition with partial pivoting, -let plu = a.partial_piv_lu(); -let x1 = plu.solve(&b); - -// or the LU decomposition with full pivoting. -let flu = a.full_piv_lu(); -let x2 = flu.solve(&b); -``` - -=== $A$ is a tall matrix (least squares solution) -When the linear system is over-determined, an exact solution may not -necessarily exist, in which case we can get a best-effort result by computing -the least squares solution. -That is, the solution that minimizes $||A x - b||$. - -This can be done using the QR decomposition. - -```rust -use faer::mat; -use faer::prelude::*; - -let a = mat![ - [10.0, 3.0], - [2.0, -10.0], - [3.0, -45.0f64], -]; -let b = mat![[15.0], [-3.0], [13.1f64]]; - -// Compute the QR decomposition. -let qr = a.qr(); -let x = qr.solve_lstsq(&b); -``` - -== Computing the singular value decomposition -```rust -use faer::mat; -use faer::prelude::*; - -let a = mat![ - [10.0, 3.0], - [2.0, -10.0], - [3.0, -45.0f64], -]; - -// Compute the SVD decomposition. -let svd = a.svd(); -// Compute the thin SVD decomposition. -let svd = a.thin_svd(); -// Compute the singular values. -let svd = a.singular_values(); -``` - -== Computing the eigenvalue decomposition -```rust -use faer::mat; -use faer::prelude::*; -use faer::complex_native::c64; - -let a = mat![ - [10.0, 3.0], - [2.0, -10.0f64], -]; - -// Compute the eigendecomposition. -let evd = a.eigendecomposition::(); - -// Compute the eigenvalues. -let evd = a.eigen_values::(); - -// Compute the eigendecomposition assuming `a` is Hermitian. -let evd = a.selfadjoint_eigendecomposition(); - -// Compute the eigenvalues assuming `a` is Hermitian. -let evd = a.selfadjoint_eigenvalues(); -``` - -= Sparse linear algebra diff --git a/faer-libs/faer/examples/cholesky.rs b/examples/cholesky.rs similarity index 100% rename from faer-libs/faer/examples/cholesky.rs rename to examples/cholesky.rs diff --git a/faer-libs/faer/examples/pseudoinverse.rs b/examples/pseudoinverse.rs similarity index 84% rename from faer-libs/faer/examples/pseudoinverse.rs rename to examples/pseudoinverse.rs index 64af6169d778506b406ca34adeeac2c0b84e761a..a178e27a3787f2b941d83d44d41ade134da223e4 100644 --- a/faer-libs/faer/examples/pseudoinverse.rs +++ b/examples/pseudoinverse.rs @@ -1,4 +1,4 @@ -use faer::{assert_matrix_eq, mat, prelude::*, Mat}; +use faer::{assert_matrix_eq, mat, Mat}; fn main() { let matrix = mat![ @@ -14,7 +14,7 @@ fn main() { let s_diag = svd.s_diagonal(); let mut s_inv = Mat::zeros(n, m); for i in 0..Ord::min(m, n) { - s_inv[(i, i)] = 1.0 / s_diag[(i, 0)]; + s_inv[(i, i)] = 1.0 / s_diag[i]; } let pseudoinv = svd.v() * &s_inv * svd.u().adjoint(); diff --git a/faer-bench/Cargo.toml b/faer-bench/Cargo.toml index ac19f3e86cee926305cabfdfc621a06411c11757..9b327ccc9bb95b303860c6f0fea0694913f6c9e2 100644 --- a/faer-bench/Cargo.toml +++ b/faer-bench/Cargo.toml @@ -11,13 +11,7 @@ nalgebra = "0.32.1" ndarray = { version = "0.15.6", features = ["blas"] } blas-src = { version = "0.8", features = ["openblas"] } faer-entity = { path = "../faer-entity", features = ["nightly"] } - -faer-core = { path = "../faer-libs/faer-core", features = ["nightly"] } -faer-lu = { path = "../faer-libs/faer-lu", features = ["nightly"] } -faer-qr = { path = "../faer-libs/faer-qr", features = ["nightly"] } -faer-svd = { path = "../faer-libs/faer-svd", features = ["nightly"] } -faer-evd = { path = "../faer-libs/faer-evd", features = ["nightly"] } -faer-cholesky = { path = "../faer-libs/faer-cholesky", features = ["nightly"] } +faer = { path = "..", features = ["nightly"] } human-repr = "1.0.1" ndarray-linalg = { version = "0.16.0", features = ["openblas-system"], git = "https://github.com/mike-kfed/ndarray-linalg.git", branch="arm-cross-compile"} @@ -28,7 +22,6 @@ openmp-sys = "1.2.3" num-traits = "0.2.15" coe-rs = "0.1.2" num-complex = "0.4.3" -pulp = "0.16" bytemuck = "1" rayon = "1.8" core_affinity = "0.8" diff --git a/faer-bench/src/cholesky.rs b/faer-bench/src/cholesky.rs index f829265c2b26efb7e6315556086cf160e9c7a164..d1d7a1fff009517cc026c75ac21f8c8795ad931f 100644 --- a/faer-bench/src/cholesky.rs +++ b/faer-bench/src/cholesky.rs @@ -1,6 +1,6 @@ use super::timeit; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{unzipped, zipped, Mat, Parallelism}; +use faer::{linalg::cholesky as faer_cholesky, unzipped, zipped, Mat, Parallelism}; use ndarray_linalg::Cholesky; use std::time::Duration; @@ -44,10 +44,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/col_piv_qr.rs b/faer-bench/src/col_piv_qr.rs index 9a5a5a171ca4ef37167568bcd8b8ec2bf338709b..6184976f830dac14cfea108332b1444ca029adc8 100644 --- a/faer-bench/src/col_piv_qr.rs +++ b/faer-bench/src/col_piv_qr.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{unzipped, zipped, Mat, Parallelism}; +use faer::{linalg::qr as faer_qr, unzipped, zipped, Mat, Parallelism}; use faer_qr::no_pivoting::compute::recommended_blocksize; use std::time::Duration; @@ -36,10 +36,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/double_f64.rs b/faer-bench/src/double_f64.rs index 6bb248b0e7b074b1864db46c2430787322873ddc..72e04d07466f214024ee289cc51b288f80641326 100644 --- a/faer-bench/src/double_f64.rs +++ b/faer-bench/src/double_f64.rs @@ -1,11 +1,17 @@ use bytemuck::{Pod, Zeroable}; -use faer_core::pulp::Simd; -use pulp::Scalar; +use faer_entity::{ + pulp::{Scalar, Simd}, + *, +}; +// https://web.mit.edu/tabbott/Public/quaddouble-debian/qd-2.3.4-old/docs/qd.pdf +// https://gitlab.com/hodge_star/mantis + +/// Value representing the implicit sum of two floating point terms, such that the absolute +/// value of the second term is less half a ULP of the first term. #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] #[repr(C)] pub struct Double(pub T, pub T); -pub type DoubleF64 = Double; unsafe impl Zeroable for Double {} unsafe impl Pod for Double {} @@ -22,20 +28,19 @@ impl Iterator for Double { } #[inline(always)] -fn quick_two_sum(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { - let s = simd.f64s_add(a, b); - let err = simd.f64s_sub(b, simd.f64s_sub(s, a)); - (s, err) -} - -#[inline(always)] -fn two_sum(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { - let s = simd.f64s_add(a, b); - let bb = simd.f64s_sub(s, a); - - // (a - (s - bb)) + (b - bb) - let err = simd.f64s_add(simd.f64s_sub(a, simd.f64s_sub(s, bb)), simd.f64s_sub(b, bb)); - (s, err) +#[allow(dead_code)] +fn two_sum_e(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { + let sign_bit = simd.f64s_splat(-0.0); + let cmp = simd.u64s_greater_than( + pulp::cast(simd.f64s_or(a, sign_bit)), + pulp::cast(simd.f64s_or(b, sign_bit)), + ); + let (a, b) = ( + simd.m64s_select_f64s(cmp, a, b), + simd.m64s_select_f64s(cmp, b, a), + ); + + quick_two_sum(simd, a, b) } #[inline(always)] @@ -47,6 +52,7 @@ fn quick_two_diff(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s } #[inline(always)] +#[allow(dead_code)] fn two_diff(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { let s = simd.f64s_sub(a, b); let bb = simd.f64s_sub(s, a); @@ -56,6 +62,29 @@ fn two_diff(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { (s, err) } +#[inline(always)] +#[allow(dead_code)] +fn two_diff_e(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { + two_sum_e(simd, a, simd.f64s_neg(b)) +} + +#[inline(always)] +fn quick_two_sum(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { + let s = simd.f64s_add(a, b); + let err = simd.f64s_sub(b, simd.f64s_sub(s, a)); + (s, err) +} + +#[inline(always)] +fn two_sum(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { + let s = simd.f64s_add(a, b); + let bb = simd.f64s_sub(s, a); + + // (a - (s - bb)) + (b - bb) + let err = simd.f64s_add(simd.f64s_sub(a, simd.f64s_sub(s, bb)), simd.f64s_sub(b, bb)); + (s, err) +} + #[inline(always)] fn two_prod(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { let p = simd.f64s_mul(a, b); @@ -64,12 +93,12 @@ fn two_prod(simd: S, a: S::f64s, b: S::f64s) -> (S::f64s, S::f64s) { (p, err) } -pub mod simd { +pub mod double { use super::*; #[inline(always)] pub fn simd_add(simd: S, a: Double, b: Double) -> Double { - let (s, e) = two_sum(simd, a.0, b.0); + let (s, e) = two_sum_e(simd, a.0, b.0); let e = simd.f64s_add(e, simd.f64s_add(a.1, b.1)); let (s, e) = quick_two_sum(simd, s, e); Double(s, e) @@ -77,7 +106,7 @@ pub mod simd { #[inline(always)] pub fn simd_sub(simd: S, a: Double, b: Double) -> Double { - let (s, e) = two_diff(simd, a.0, b.0); + let (s, e) = two_diff_e(simd, a.0, b.0); let e = simd.f64s_add(e, a.1); let e = simd.f64s_sub(e, b.1); let (s, e) = quick_two_sum(simd, s, e); @@ -100,15 +129,6 @@ pub mod simd { Double(p1, p2) } - #[inline(always)] - pub fn simd_mul_power_of_two( - simd: S, - a: Double, - b: S::f64s, - ) -> Double { - Double(simd.f64s_mul(a.0, b), simd.f64s_mul(a.1, b)) - } - #[inline(always)] fn simd_mul_f64(simd: S, a: Double, b: S::f64s) -> Double { let (p1, p2) = two_prod(simd, a.0, b); @@ -164,7 +184,7 @@ pub mod simd { let s2 = simd.f64s_add(s2, a.1); let q2 = simd.f64s_div(simd.f64s_add(s1, s2), b.0); - let (r0, r1) = quick_two_sum(simd, q1, q2); + let (q0, q1) = quick_two_sum(simd, q1, q2); simd_select( simd, @@ -180,7 +200,7 @@ pub mod simd { simd.f64s_or(combined_sign, pos_zero), simd.f64s_or(combined_sign, pos_zero), ), - Double(r0, r1), + Double(q0, q1), ), ) }, @@ -244,71 +264,71 @@ pub mod simd { } } -impl core::ops::Add for DoubleF64 { +impl core::ops::Add for Double { type Output = Self; #[inline(always)] fn add(self, rhs: Self) -> Self::Output { - simd::simd_add(Scalar::new(), self, rhs) + double::simd_add(Scalar::new(), self, rhs) } } -impl core::ops::Sub for DoubleF64 { +impl core::ops::Sub for Double { type Output = Self; #[inline(always)] fn sub(self, rhs: Self) -> Self::Output { - simd::simd_sub(Scalar::new(), self, rhs) + double::simd_sub(Scalar::new(), self, rhs) } } -impl core::ops::Mul for DoubleF64 { +impl core::ops::Mul for Double { type Output = Self; #[inline(always)] fn mul(self, rhs: Self) -> Self::Output { - simd::simd_mul(Scalar::new(), self, rhs) + double::simd_mul(Scalar::new(), self, rhs) } } -impl core::ops::Div for DoubleF64 { +impl core::ops::Div for Double { type Output = Self; #[inline(always)] fn div(self, rhs: Self) -> Self::Output { - simd::simd_div(Scalar::new(), self, rhs) + double::simd_div(Scalar::new(), self, rhs) } } -impl core::ops::AddAssign for DoubleF64 { +impl core::ops::AddAssign for Double { #[inline(always)] fn add_assign(&mut self, rhs: Self) { *self = *self + rhs; } } -impl core::ops::SubAssign for DoubleF64 { +impl core::ops::SubAssign for Double { #[inline(always)] fn sub_assign(&mut self, rhs: Self) { *self = *self - rhs; } } -impl core::ops::MulAssign for DoubleF64 { +impl core::ops::MulAssign for Double { #[inline(always)] fn mul_assign(&mut self, rhs: Self) { *self = *self * rhs; } } -impl core::ops::DivAssign for DoubleF64 { +impl core::ops::DivAssign for Double { #[inline(always)] fn div_assign(&mut self, rhs: Self) { *self = *self / rhs; } } -impl core::ops::Neg for DoubleF64 { +impl core::ops::Neg for Double { type Output = Self; #[inline(always)] @@ -317,7 +337,7 @@ impl core::ops::Neg for DoubleF64 { } } -impl DoubleF64 { +impl Double { /// 2.0^{-100} pub const EPSILON: Self = Self(7.888609052210118e-31, 0.0); /// 2.0^{-970}: precision below this value begins to degrade. @@ -326,16 +346,15 @@ impl DoubleF64 { pub const ZERO: Self = Self(0.0, 0.0); pub const NAN: Self = Self(f64::NAN, f64::NAN); pub const INFINITY: Self = Self(f64::INFINITY, f64::INFINITY); - pub const NEG_INFINITY: Self = Self(f64::NEG_INFINITY, f64::NEG_INFINITY); #[inline(always)] pub fn abs(self) -> Self { - simd::simd_abs(Scalar::new(), self) + double::simd_abs(Scalar::new(), self) } #[inline(always)] pub fn recip(self) -> Self { - simd::simd_div(Scalar::new(), Self(1.0, 0.0), self) + double::simd_div(Scalar::new(), Self(1.0, 0.0), self) } #[inline] @@ -360,17 +379,19 @@ pub struct DoubleGroup { __private: (), } +impl ForType for DoubleGroup { + type FaerOf = Double; +} +impl ForCopyType for DoubleGroup { + type FaerOfCopy = Double; +} +impl ForDebugType for DoubleGroup { + type FaerOfDebug = Double; +} + mod faer_impl { use super::*; - use faer_core::{ComplexField, Conjugate, Entity, RealField}; - use faer_entity::*; - - impl ForType for DoubleGroup { - type FaerOf = Double; - } - impl ForCopyType for DoubleGroup { - type FaerOfCopy = Double; - } + use faer::{ComplexField, Conjugate, Entity, RealField}; unsafe impl Entity for Double { type Unit = f64; @@ -383,9 +404,19 @@ mod faer_impl { type Group = DoubleGroup; type Iter = Double; + type PrefixUnit<'a, S: Simd> = pulp::Prefix<'a, f64, S, S::m64s>; + type SuffixUnit<'a, S: Simd> = pulp::Suffix<'a, f64, S, S::m64s>; + type PrefixMutUnit<'a, S: Simd> = pulp::PrefixMut<'a, f64, S, S::m64s>; + type SuffixMutUnit<'a, S: Simd> = pulp::SuffixMut<'a, f64, S, S::m64s>; + const N_COMPONENTS: usize = 2; const UNIT: GroupCopyFor = Double((), ()); + #[inline(always)] + fn faer_first(group: GroupFor) -> T { + group.0 + } + #[inline(always)] fn faer_from_units(group: GroupFor) -> Self { group @@ -406,6 +437,16 @@ mod faer_impl { Double(&mut group.0, &mut group.1) } + #[inline(always)] + fn faer_as_ptr(group: *mut GroupFor) -> GroupFor { + unsafe { + Double( + core::ptr::addr_of_mut!((*group).0), + core::ptr::addr_of_mut!((*group).1), + ) + } + } + #[inline(always)] fn faer_map_impl( group: GroupFor, @@ -449,19 +490,21 @@ mod faer_impl { } } - unsafe impl Conjugate for DoubleF64 { - type Conj = DoubleF64; - type Canonical = DoubleF64; + unsafe impl Conjugate for Double { + type Conj = Double; + type Canonical = Double; #[inline(always)] fn canonicalize(self) -> Self::Canonical { self } } - impl RealField for DoubleF64 { + impl RealField for Double { + #[inline(always)] fn faer_epsilon() -> Option { Some(Self::EPSILON) } + #[inline(always)] fn faer_zero_threshold() -> Option { Some(Self::MIN_POSITIVE) } @@ -492,7 +535,7 @@ mod faer_impl { a: SimdGroupFor, b: SimdGroupFor, ) -> Self::SimdMask { - simd::simd_less_than(simd, a, b) + double::simd_less_than(simd, a, b) } #[inline(always)] @@ -501,7 +544,7 @@ mod faer_impl { a: SimdGroupFor, b: SimdGroupFor, ) -> Self::SimdMask { - simd::simd_less_than_or_equal(simd, a, b) + double::simd_less_than_or_equal(simd, a, b) } #[inline(always)] @@ -510,7 +553,7 @@ mod faer_impl { a: SimdGroupFor, b: SimdGroupFor, ) -> Self::SimdMask { - simd::simd_greater_than(simd, a, b) + double::simd_greater_than(simd, a, b) } #[inline(always)] @@ -519,7 +562,7 @@ mod faer_impl { a: SimdGroupFor, b: SimdGroupFor, ) -> Self::SimdMask { - simd::simd_greater_than_or_equal(simd, a, b) + double::simd_greater_than_or_equal(simd, a, b) } #[inline(always)] @@ -529,7 +572,7 @@ mod faer_impl { if_true: SimdGroupFor, if_false: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_select(simd, mask, if_true, if_false) + double::simd_select(simd, mask, if_true, if_false) } #[inline(always)] @@ -561,12 +604,47 @@ mod faer_impl { ) -> Self::SimdIndex { simd.u64s_add(a, b) } + + #[inline(always)] + fn faer_min_positive() -> Self { + Self::MIN_POSITIVE + } + + #[inline(always)] + fn faer_min_positive_inv() -> Self { + Self::MIN_POSITIVE.recip() + } + + #[inline(always)] + fn faer_min_positive_sqrt() -> Self { + Self::MIN_POSITIVE.sqrt() + } + + #[inline(always)] + fn faer_min_positive_sqrt_inv() -> Self { + Self::MIN_POSITIVE.sqrt().recip() + } + + #[inline(always)] + fn faer_simd_index_rotate_left( + simd: S, + values: SimdIndexFor, + amount: usize, + ) -> SimdIndexFor { + simd.u64s_rotate_left(values, amount) + } + + #[inline(always)] + fn faer_simd_abs(simd: S, values: SimdGroupFor) -> SimdGroupFor { + double::simd_abs(simd, values) + } } - impl ComplexField for DoubleF64 { - type Real = DoubleF64; + impl ComplexField for Double { + type Real = Double; type Simd = pulp::Arch; type ScalarSimd = pulp::Arch; + type PortableSimd = pulp::Arch; #[inline(always)] fn faer_sqrt(self) -> Self { @@ -600,7 +678,7 @@ mod faer_impl { #[inline(always)] fn faer_inv(self) -> Self { - (self).recip() + self.recip() } #[inline(always)] @@ -664,29 +742,26 @@ mod faer_impl { } #[inline(always)] - fn faer_slice_as_simd( + fn faer_slice_as_simd( slice: &[Self::Unit], ) -> (&[Self::SimdUnit], &[Self::Unit]) { S::f64s_as_simd(slice) } #[inline(always)] - fn faer_slice_as_mut_simd( + fn faer_slice_as_simd_mut( slice: &mut [Self::Unit], ) -> (&mut [Self::SimdUnit], &mut [Self::Unit]) { S::f64s_as_mut_simd(slice) } #[inline(always)] - fn faer_partial_load_unit( - simd: S, - slice: &[Self::Unit], - ) -> Self::SimdUnit { + fn faer_partial_load_unit(simd: S, slice: &[Self::Unit]) -> Self::SimdUnit { simd.f64s_partial_load(slice) } #[inline(always)] - fn faer_partial_store_unit( + fn faer_partial_store_unit( simd: S, slice: &mut [Self::Unit], values: Self::SimdUnit, @@ -695,7 +770,7 @@ mod faer_impl { } #[inline(always)] - fn faer_partial_load_last_unit( + fn faer_partial_load_last_unit( simd: S, slice: &[Self::Unit], ) -> Self::SimdUnit { @@ -703,7 +778,7 @@ mod faer_impl { } #[inline(always)] - fn faer_partial_store_last_unit( + fn faer_partial_store_last_unit( simd: S, slice: &mut [Self::Unit], values: Self::SimdUnit, @@ -712,23 +787,17 @@ mod faer_impl { } #[inline(always)] - fn faer_simd_splat_unit( - simd: S, - unit: Self::Unit, - ) -> Self::SimdUnit { + fn faer_simd_splat_unit(simd: S, unit: Self::Unit) -> Self::SimdUnit { simd.f64s_splat(unit) } #[inline(always)] - fn faer_simd_neg( - simd: S, - values: SimdGroupFor, - ) -> SimdGroupFor { - simd::simd_neg(simd, values) + fn faer_simd_neg(simd: S, values: SimdGroupFor) -> SimdGroupFor { + double::simd_neg(simd, values) } #[inline(always)] - fn faer_simd_conj( + fn faer_simd_conj( simd: S, values: SimdGroupFor, ) -> SimdGroupFor { @@ -737,68 +806,68 @@ mod faer_impl { } #[inline(always)] - fn faer_simd_add( + fn faer_simd_add( simd: S, lhs: SimdGroupFor, rhs: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_add(simd, lhs, rhs) + double::simd_add(simd, lhs, rhs) } #[inline(always)] - fn faer_simd_sub( + fn faer_simd_sub( simd: S, lhs: SimdGroupFor, rhs: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_sub(simd, lhs, rhs) + double::simd_sub(simd, lhs, rhs) } #[inline(always)] - fn faer_simd_mul( + fn faer_simd_mul( simd: S, lhs: SimdGroupFor, rhs: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_mul(simd, lhs, rhs) + double::simd_mul(simd, lhs, rhs) } #[inline(always)] - fn faer_simd_scale_real( + fn faer_simd_scale_real( simd: S, lhs: SimdGroupFor, rhs: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_mul(simd, lhs, rhs) + double::simd_mul(simd, lhs, rhs) } #[inline(always)] - fn faer_simd_conj_mul( + fn faer_simd_conj_mul( simd: S, lhs: SimdGroupFor, rhs: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_mul(simd, lhs, rhs) + double::simd_mul(simd, lhs, rhs) } #[inline(always)] - fn faer_simd_mul_adde( + fn faer_simd_mul_adde( simd: S, lhs: SimdGroupFor, rhs: SimdGroupFor, acc: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_add(simd, acc, simd::simd_mul(simd, lhs, rhs)) + double::simd_add(simd, acc, double::simd_mul(simd, lhs, rhs)) } #[inline(always)] - fn faer_simd_conj_mul_adde( + fn faer_simd_conj_mul_adde( simd: S, lhs: SimdGroupFor, rhs: SimdGroupFor, acc: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_add(simd, acc, simd::simd_mul(simd, lhs, rhs)) + double::simd_add(simd, acc, double::simd_mul(simd, lhs, rhs)) } #[inline(always)] @@ -806,7 +875,7 @@ mod faer_impl { simd: S, values: SimdGroupFor, ) -> SimdGroupFor { - simd::simd_abs(simd, values) + double::simd_abs(simd, values) } #[inline(always)] @@ -831,16 +900,19 @@ mod faer_impl { let _ = simd; lhs * rhs } + #[inline(always)] fn faer_simd_scalar_conj_mul(simd: S, lhs: Self, rhs: Self) -> Self { let _ = simd; lhs * rhs } + #[inline(always)] fn faer_simd_scalar_mul_adde(simd: S, lhs: Self, rhs: Self, acc: Self) -> Self { let _ = simd; lhs * rhs + acc } + #[inline(always)] fn faer_simd_scalar_conj_mul_adde( simd: S, @@ -851,5 +923,51 @@ mod faer_impl { let _ = simd; lhs * rhs + acc } + + #[inline(always)] + fn faer_slice_as_aligned_simd( + simd: S, + slice: &[UnitFor], + offset: pulp::Offset>, + ) -> ( + pulp::Prefix<'_, UnitFor, S, SimdMaskFor>, + &[SimdUnitFor], + pulp::Suffix<'_, UnitFor, S, SimdMaskFor>, + ) { + simd.f64s_as_aligned_simd(slice, offset) + } + #[inline(always)] + fn faer_slice_as_aligned_simd_mut( + simd: S, + slice: &mut [UnitFor], + offset: pulp::Offset>, + ) -> ( + pulp::PrefixMut<'_, UnitFor, S, SimdMaskFor>, + &mut [SimdUnitFor], + pulp::SuffixMut<'_, UnitFor, S, SimdMaskFor>, + ) { + simd.f64s_as_aligned_mut_simd(slice, offset) + } + + #[inline(always)] + fn faer_simd_rotate_left( + simd: S, + values: SimdGroupFor, + amount: usize, + ) -> SimdGroupFor { + Double( + simd.f64s_rotate_left(values.0, amount), + simd.f64s_rotate_left(values.1, amount), + ) + } + + #[inline(always)] + fn faer_align_offset( + simd: S, + ptr: *const UnitFor, + len: usize, + ) -> pulp::Offset> { + simd.f64s_align_offset(ptr, len) + } } } diff --git a/faer-bench/src/evd.rs b/faer-bench/src/evd.rs index c740db1614d0548a71fe678c3ed8ae38ed9e5ac0..ee4f49c9daa882c6353ec290aa901b1ef57e2e4b 100644 --- a/faer-bench/src/evd.rs +++ b/faer-bench/src/evd.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{Mat, Parallelism}; +use faer::{linalg::evd as faer_evd, prelude::*, Parallelism}; use ndarray_linalg::Eig; use std::time::Duration; @@ -36,10 +36,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/full_piv_lu.rs b/faer-bench/src/full_piv_lu.rs index 441c4ad3fc3597dc87740de2554ccc250fc14c0d..6c7ba1582ac8693553d8f621079f1d49d35b245b 100644 --- a/faer-bench/src/full_piv_lu.rs +++ b/faer-bench/src/full_piv_lu.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{unzipped, zipped, Mat, Parallelism}; +use faer::{linalg::lu as faer_lu, unzipped, zipped, Mat, Parallelism}; use std::time::Duration; pub fn ndarray(sizes: &[usize]) -> Vec { @@ -35,10 +35,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/gemm.rs b/faer-bench/src/gemm.rs index 38ad4aff87b8cf343bdf7aef7ebf3b4850b6a866..477fa5a17275247097efe81b12e5d653d4c61f07 100644 --- a/faer-bench/src/gemm.rs +++ b/faer-bench/src/gemm.rs @@ -1,5 +1,5 @@ use super::timeit; -use faer_core::{Mat, Parallelism}; +use faer::{prelude::*, Parallelism}; use num_traits::Zero; use std::time::Duration; @@ -45,10 +45,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() @@ -58,7 +55,7 @@ pub fn faer( let b = Mat::::zeros(n, n); let time = timeit(|| { - faer_core::mul::matmul( + faer::linalg::matmul::matmul( c.as_mut(), a.as_ref(), b.as_ref(), diff --git a/faer-bench/src/inverse.rs b/faer-bench/src/inverse.rs index 6945ebd98dbf49b679b884cfce4124e3c85227f0..eacaff8bfb3746edc9015e5bec2696e76cbabb6e 100644 --- a/faer-bench/src/inverse.rs +++ b/faer-bench/src/inverse.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut, StackReq}; -use faer_core::{unzipped, zipped, Mat, Parallelism}; +use faer::{linalg::lu as faer_lu, unzipped, zipped, Mat, Parallelism}; use ndarray_linalg::Inverse; use reborrow::*; use std::time::Duration; @@ -50,10 +50,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/main.rs b/faer-bench/src/main.rs index cb1aafc7c8f27427fdf4f83f6c8e33c01f9c1eb9..88db805bea792e2bd9be03aacdedd0bebdd06ef8 100644 --- a/faer-bench/src/main.rs +++ b/faer-bench/src/main.rs @@ -1,12 +1,15 @@ #![allow(dead_code)] use coe::is_same; use eyre::Result; -use faer_core::{c32, c64, Parallelism}; +use faer::{ + complex_native::{c32, c64}, + Parallelism, +}; use human_repr::HumanDuration; use std::{fs::File, io::Write, time::Duration}; -// use double_f64::DoubleF64 as f128; -type f128 = f64; +#[allow(non_camel_case_types)] +type f128 = double_f64::Double; #[allow(non_camel_case_types)] type c128 = num_complex::Complex; @@ -18,17 +21,17 @@ fn random() -> T { coe::coerce_static(rand::random::()) } else if is_same::() { coe::coerce_static(rand::random::()) - // } else if is_same::() { - // coe::coerce_static(double_f64::Double(rand::random::(), 0.0)) + } else if is_same::() { + coe::coerce_static(double_f64::Double(rand::random::(), 0.0)) } else if is_same::() { coe::coerce_static(c32::new(rand::random(), rand::random())) } else if is_same::() { coe::coerce_static(c64::new(rand::random(), rand::random())) - // } else if is_same::() { - // coe::coerce_static(c128::new( - // double_f64::Double(rand::random(), 0.0), - // double_f64::Double(rand::random(), 0.0), - // )) + } else if is_same::() { + coe::coerce_static(c128::new( + double_f64::Double(rand::random(), 0.0), + double_f64::Double(rand::random(), 0.0), + )) } else if is_same::() { coe::coerce_static(Cplx32::new(rand::random(), rand::random())) } else if is_same::() { @@ -70,7 +73,7 @@ fn timeit(f: impl FnMut()) -> f64 { }) / n as f64 } -// mod double_f64; +mod double_f64; mod gemm; mod tr_inverse; diff --git a/faer-bench/src/no_piv_qr.rs b/faer-bench/src/no_piv_qr.rs index 690f17304b5040f244e15825bbae4f797ac31d83..c5e89904188c088ad11a75b2dcb7357b9849d258 100644 --- a/faer-bench/src/no_piv_qr.rs +++ b/faer-bench/src/no_piv_qr.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{unzipped, zipped, Mat, Parallelism}; +use faer::{linalg::qr as faer_qr, unzipped, zipped, Mat, Parallelism}; use ndarray_linalg::QR; use std::time::Duration; @@ -49,10 +49,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/partial_piv_lu.rs b/faer-bench/src/partial_piv_lu.rs index 10fa3cd18c20326d6dc77e33edc74caf40165608..c861aa3a439a6063c71a837803e55797fef37dc3 100644 --- a/faer-bench/src/partial_piv_lu.rs +++ b/faer-bench/src/partial_piv_lu.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{unzipped, zipped, Mat, Parallelism}; +use faer::{linalg::lu as faer_lu, unzipped, zipped, Mat, Parallelism}; use ndarray_linalg::solve::Factorize; use std::time::Duration; @@ -49,10 +49,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/rectangular_svd.rs b/faer-bench/src/rectangular_svd.rs index bef0c7455f48a76f616011d97b7beb24d88202c4..6ec9b054c106f9b4f34924cad29700332c80792d 100644 --- a/faer-bench/src/rectangular_svd.rs +++ b/faer-bench/src/rectangular_svd.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{Mat, Parallelism}; +use faer::{linalg::svd as faer_svd, Mat, Parallelism}; use ndarray_linalg::{JobSvd, SVDDC}; use std::time::Duration; @@ -49,10 +49,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/svd.rs b/faer-bench/src/svd.rs index 077cf7bd84ff0912c4e47df7acb9e18446e84a78..25793c9175b86a192346caaf36288b2d21920427 100644 --- a/faer-bench/src/svd.rs +++ b/faer-bench/src/svd.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{Mat, Parallelism}; +use faer::{linalg::svd as faer_svd, prelude::*, Parallelism}; use ndarray_linalg::{JobSvd, SVDDC}; use std::time::Duration; @@ -49,10 +49,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/symmetric_evd.rs b/faer-bench/src/symmetric_evd.rs index d25514c8cbf583c6feef92a4261858fe925e6e37..d76edc7f464b7df7ce394083c89cd6f7a133e305 100644 --- a/faer-bench/src/symmetric_evd.rs +++ b/faer-bench/src/symmetric_evd.rs @@ -1,7 +1,7 @@ use super::timeit; use crate::random; use dyn_stack::{GlobalPodBuffer, PodStack, ReborrowMut}; -use faer_core::{Mat, Parallelism}; +use faer::{linalg::evd as faer_evd, Mat, Parallelism}; use ndarray_linalg::Eigh; use std::time::Duration; @@ -52,10 +52,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() diff --git a/faer-bench/src/tr_inverse.rs b/faer-bench/src/tr_inverse.rs index 7cae4abe5d656160c19b9d5911171deaf90f96d1..87cd03c25a55feb9dd83a42b2652e71835db9775 100644 --- a/faer-bench/src/tr_inverse.rs +++ b/faer-bench/src/tr_inverse.rs @@ -1,5 +1,5 @@ use super::timeit; -use faer_core::{Mat, Parallelism}; +use faer::{Mat, Parallelism}; use std::time::Duration; pub fn ndarray(sizes: &[usize]) -> Vec { @@ -10,10 +10,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { super::trsm::nalgebra::(sizes) } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() @@ -22,7 +19,7 @@ pub fn faer( let a = Mat::::zeros(n, n); let time = timeit(|| { - faer_core::inverse::invert_unit_lower_triangular( + faer::linalg::triangular_inverse::invert_unit_lower_triangular( c.as_mut(), a.as_ref(), parallelism, diff --git a/faer-bench/src/trsm.rs b/faer-bench/src/trsm.rs index 96fb649d1688344ab58bf44ed5edd859dfe09805..14605b3d5bddf7bb1fea49fcb5c21951bfaa1980 100644 --- a/faer-bench/src/trsm.rs +++ b/faer-bench/src/trsm.rs @@ -1,5 +1,5 @@ use super::timeit; -use faer_core::{Mat, Parallelism}; +use faer::{Mat, Parallelism}; use ndarray_linalg::*; use std::time::Duration; @@ -44,10 +44,7 @@ pub fn nalgebra(sizes: &[usize]) -> Vec { .collect() } -pub fn faer( - sizes: &[usize], - parallelism: Parallelism, -) -> Vec { +pub fn faer(sizes: &[usize], parallelism: Parallelism) -> Vec { sizes .iter() .copied() @@ -56,7 +53,7 @@ pub fn faer( let a = Mat::::zeros(n, n); let time = timeit(|| { - faer_core::solve::solve_unit_lower_triangular_in_place( + faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place( a.as_ref(), c.as_mut(), parallelism, diff --git a/faer-libs/Cargo.toml b/faer-libs/Cargo.toml deleted file mode 100644 index 3bb85795f8e8b50ef286e00dc76de9691a1cd398..0000000000000000000000000000000000000000 --- a/faer-libs/Cargo.toml +++ /dev/null @@ -1,64 +0,0 @@ -[workspace] -members = [ - "faer-core", - "faer-cholesky", - "faer-lu", - "faer-qr", - "faer-svd", - "faer-evd", - "faer-sparse", - - "faer", -] -resolver = "2" - -[workspace.dependencies] -faer-entity = { version = "0.17.0", default-features = false, path = "../faer-entity" } - -gemm = { version = "0.17.1", default-features = false } - -coe-rs = "0.1" -reborrow = "0.5" -pulp = { version = "0.18.8", default-features = false } -dyn-stack = { version = "0.10", default-features = false } - -num-traits = { version = "0.2", default-features = false } -num-complex = { version = "0.4", default-features = false } -libm = "0.2" -bytemuck = { version = "1", default-features = false } - -rand = { version = "0.8", default-features = false } -rayon = "1" -serde = { version = "1", features = ["derive"] } -assert2 = "0.3" -equator = "0.1.10" -log = { version = "0.4", default-features = false } - -criterion = { git = "https://github.com/sarah-ek/criterion.rs" } - -[profile.unopt] -inherits = "dev" -opt-level = 0 -debug = true -debug-assertions = true -overflow-checks = true -lto = false -panic = 'unwind' -incremental = true -codegen-units = 256 -rpath = false - -[profile.dev] -opt-level = 3 -debug = true -debug-assertions = true -overflow-checks = true -lto = false -panic = 'unwind' -incremental = true -codegen-units = 256 -rpath = false - -[profile.bench] -inherits = "release" -debug = false diff --git a/faer-libs/faer-cholesky/Cargo.toml b/faer-libs/faer-cholesky/Cargo.toml deleted file mode 100644 index 5b5b622978d9128b87970f6462b20eb0aef56f92..0000000000000000000000000000000000000000 --- a/faer-libs/faer-cholesky/Cargo.toml +++ /dev/null @@ -1,52 +0,0 @@ -[package] -name = "faer-cholesky" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] - -[dependencies] -faer-entity = { workspace = true, default-features = false } - -faer-core = { version = "0.17.1", default-features = false, path = "../faer-core" } -seq-macro = "0.3" - -reborrow = { workspace = true } -pulp = { workspace = true, default-features = false } -dyn-stack = { workspace = true, default-features = false } - -num-traits = { workspace = true, default-features = false } -num-complex = { workspace = true, default-features = false } -bytemuck = { workspace = true } - -log = { workspace = true, optional = true, default-features = false } - -[features] -default = ["std", "rayon"] -std = [ - "faer-core/std", - "pulp/std", -] -perf-warn = ["log", "faer-core/perf-warn"] -rayon = ["std", "faer-core/rayon"] -nightly = ["faer-core/nightly", "pulp/nightly"] - -[dev-dependencies] -criterion = "0.5" -rand = "0.8.5" -nalgebra = "0.32.3" -assert_approx_eq = "1.1.0" -rayon = "1.8" -dbgf = "0.1.1" -env_logger = "0.10" - -[[bench]] -name = "bench" -harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer-cholesky/LICENSE b/faer-libs/faer-cholesky/LICENSE deleted file mode 100644 index b3e9659c8860f4d82899554c214b91d46760ea59..0000000000000000000000000000000000000000 --- a/faer-libs/faer-cholesky/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 sarah - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/faer-libs/faer-cholesky/benches/bench.rs b/faer-libs/faer-cholesky/benches/bench.rs deleted file mode 100644 index 17ab2261007e151a820f8cfcff0497d60273a34e..0000000000000000000000000000000000000000 --- a/faer-libs/faer-cholesky/benches/bench.rs +++ /dev/null @@ -1,324 +0,0 @@ -use std::time::Duration; - -use criterion::{criterion_group, criterion_main, Criterion}; -use dyn_stack::{GlobalPodBuffer, PodStack}; -use faer_cholesky::bunch_kaufman; -use faer_core::{c64, ComplexField}; -use reborrow::*; - -use faer_core::{Mat, Parallelism}; -use nalgebra::DMatrix; - -pub fn cholesky(c: &mut Criterion) { - use faer_cholesky::{ldlt_diagonal, llt}; - - for n in [6, 8, 12, 16, 24, 32, 64, 128, 256, 512, 1024, 2000, 4096] { - c.bench_function(&format!("faer-st-bk-{n}"), |b| { - let mut mat = Mat::from_fn(n, n, |_, _| rand::random::()); - let mut subdiag = Mat::zeros(n, 1); - - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - - let mut mem = GlobalPodBuffer::new( - bunch_kaufman::compute::cholesky_in_place_req::( - n, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - bunch_kaufman::compute::cholesky_in_place( - mat.as_mut(), - subdiag.as_mut(), - Default::default(), - &mut perm, - &mut perm_inv, - Parallelism::None, - stack.rb_mut(), - Default::default(), - ); - }) - }); - - c.bench_function(&format!("faer-mt-bk-{n}"), |b| { - let mut mat = Mat::from_fn(n, n, |_, _| rand::random::()); - let mut subdiag = Mat::zeros(n, 1); - - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - - let mut mem = GlobalPodBuffer::new( - bunch_kaufman::compute::cholesky_in_place_req::( - n, - Parallelism::Rayon(0), - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - bunch_kaufman::compute::cholesky_in_place( - mat.as_mut(), - subdiag.as_mut(), - Default::default(), - &mut perm, - &mut perm_inv, - Parallelism::Rayon(0), - stack.rb_mut(), - Default::default(), - ); - }) - }); - - c.bench_function(&format!("faer-st-cplx-bk-{n}"), |b| { - let mut mat = Mat::from_fn(n, n, |_, _| c64::new(rand::random(), rand::random())); - let mut subdiag = Mat::zeros(n, 1); - - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - - let mut mem = GlobalPodBuffer::new( - bunch_kaufman::compute::cholesky_in_place_req::( - n, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - bunch_kaufman::compute::cholesky_in_place( - mat.as_mut(), - subdiag.as_mut(), - Default::default(), - &mut perm, - &mut perm_inv, - Parallelism::None, - stack.rb_mut(), - Default::default(), - ); - }) - }); - c.bench_function(&format!("faer-mt-cplx-bk-{n}"), |b| { - let mut mat = Mat::from_fn(n, n, |_, _| c64::new(rand::random(), rand::random())); - let mut subdiag = Mat::zeros(n, 1); - - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - - let mut mem = GlobalPodBuffer::new( - bunch_kaufman::compute::cholesky_in_place_req::( - n, - Parallelism::Rayon(0), - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - bunch_kaufman::compute::cholesky_in_place( - mat.as_mut(), - subdiag.as_mut(), - Default::default(), - &mut perm, - &mut perm_inv, - Parallelism::Rayon(0), - stack.rb_mut(), - Default::default(), - ); - }) - }); - - c.bench_function(&format!("faer-st-ldlt-{n}"), |b| { - let mut mat = Mat::new(); - - mat.resize_with(n, n, |i, j| if i == j { 1.0 } else { 0.0 }); - let mut mem = GlobalPodBuffer::new( - ldlt_diagonal::compute::raw_cholesky_in_place_req::( - n, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - ldlt_diagonal::compute::raw_cholesky_in_place( - mat.as_mut(), - Default::default(), - Parallelism::None, - stack.rb_mut(), - Default::default(), - ); - }) - }); - - c.bench_function(&format!("faer-mt-ldlt-{n}"), |b| { - let mut mat = Mat::new(); - - mat.resize_with(n, n, |i, j| if i == j { 1.0 } else { 0.0 }); - let mut mem = GlobalPodBuffer::new( - ldlt_diagonal::compute::raw_cholesky_in_place_req::( - n, - Parallelism::Rayon(rayon::current_num_threads()), - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - ldlt_diagonal::compute::raw_cholesky_in_place( - mat.as_mut(), - Default::default(), - Parallelism::Rayon(rayon::current_num_threads()), - stack.rb_mut(), - Default::default(), - ); - }) - }); - - c.bench_function(&format!("faer-st-llt-{n}"), |b| { - let mut mat = Mat::new(); - - mat.resize_with(n, n, |i, j| if i == j { 1.0 } else { 0.0 }); - let mut mem = GlobalPodBuffer::new( - llt::compute::cholesky_in_place_req::( - n, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - llt::compute::cholesky_in_place( - mat.as_mut(), - Default::default(), - Parallelism::None, - stack.rb_mut(), - Default::default(), - ) - .unwrap(); - }) - }); - - c.bench_function(&format!("faer-mt-llt-{n}"), |b| { - let mut mat = Mat::new(); - - mat.resize_with(n, n, |i, j| if i == j { 1.0 } else { 0.0 }); - let mut mem = GlobalPodBuffer::new( - llt::compute::cholesky_in_place_req::( - n, - Parallelism::Rayon(rayon::current_num_threads()), - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - llt::compute::cholesky_in_place( - mat.as_mut(), - Default::default(), - Parallelism::Rayon(rayon::current_num_threads()), - stack.rb_mut(), - Default::default(), - ) - .unwrap(); - }) - }); - - c.bench_function(&format!("faer-st-cplx-llt-{n}"), |b| { - let mut mat = Mat::from_fn(n, n, |i, j| { - if i == j { - c64::faer_one() - } else { - c64::faer_zero() - } - }); - - let mut mem = GlobalPodBuffer::new( - llt::compute::cholesky_in_place_req::( - n, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - llt::compute::cholesky_in_place( - mat.as_mut(), - Default::default(), - Parallelism::None, - stack.rb_mut(), - Default::default(), - ) - .unwrap(); - }) - }); - - c.bench_function(&format!("faer-mt-cplx-llt-{n}"), |b| { - let mut mat = Mat::from_fn(n, n, |i, j| { - if i == j { - c64::faer_one() - } else { - c64::faer_zero() - } - }); - - let mut mem = GlobalPodBuffer::new( - llt::compute::cholesky_in_place_req::( - n, - Parallelism::Rayon(rayon::current_num_threads()), - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - b.iter(|| { - llt::compute::cholesky_in_place( - mat.as_mut(), - Default::default(), - Parallelism::Rayon(rayon::current_num_threads()), - stack.rb_mut(), - Default::default(), - ) - .unwrap(); - }) - }); - - c.bench_function(&format!("nalg-st-llt-{n}"), |b| { - let mut mat = DMatrix::::zeros(n, n); - for i in 0..n { - mat[(i, i)] = 1.0; - } - - b.iter(|| { - let _ = mat.clone().cholesky(); - }) - }); - } -} - -criterion_group!( - name = benches; - config = Criterion::default() - .warm_up_time(Duration::from_secs(3)) - .measurement_time(Duration::from_secs(5)) - .sample_size(10); - targets = cholesky -); -criterion_main!(benches); diff --git a/faer-libs/faer-cholesky/katex-header.html b/faer-libs/faer-cholesky/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer-cholesky/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer-core/Cargo.toml b/faer-libs/faer-core/Cargo.toml deleted file mode 100644 index cd4ec78c2b28e1d3d94147173f9e97944d018fb7..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/Cargo.toml +++ /dev/null @@ -1,65 +0,0 @@ -[package] -name = "faer-core" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] -rust-version = "1.67.0" - -[dependencies] -faer-entity = { workspace = true, default-features = false } - -seq-macro = "0.3" -paste = "1.0" - -coe-rs = { workspace = true } -reborrow = { workspace = true } - -pulp = { workspace = true, default-features = false } -dyn-stack = { workspace = true, default-features = false } -gemm = { workspace = true, default-features = false } - -num-traits = { workspace = true, default-features = false } -num-complex = { workspace = true, default-features = false } -bytemuck = { workspace = true } - -rand = { workspace = true, optional = true, default-features = false } -rayon = { workspace = true, optional = true } -serde = { workspace = true, optional = true } -equator = { workspace = true } -log = { workspace = true, optional = true, default-features = false } -matrixcompare-core = { version = "0.1.0", optional = true } - -[features] -default = ["std", "rayon", "serde"] -std = [ - "faer-entity/std", - "gemm/std", - "pulp/std", - "matrixcompare-core", - "num-traits/std", - "num-complex/std", -] -rayon = ["std", "gemm/rayon", "dep:rayon"] -nightly = ["faer-entity/nightly", "gemm/nightly", "pulp/nightly"] -perf-warn = ["log"] -serde = ["dep:serde"] - -[dev-dependencies] -criterion = "0.5" -rand = "0.8.5" -nalgebra = "0.32.3" -assert_approx_eq = "1.1.0" -dbgf = "0.1.1" -serde_test = "1.0.176" - -[[bench]] -name = "bench" -harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer-core/LICENSE b/faer-libs/faer-core/LICENSE deleted file mode 100644 index b3e9659c8860f4d82899554c214b91d46760ea59..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 sarah - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/faer-libs/faer-core/benches/bench.rs b/faer-libs/faer-core/benches/bench.rs deleted file mode 100644 index 92bd7ee3d2284ef230c1ead458a8af3605c2d514..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/benches/bench.rs +++ /dev/null @@ -1,94 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use faer_core::{ - c32, c64, is_vectorizable, mul::matmul_with_conj_gemm_dispatch as matmul_with_conj, - ComplexField, Conj, Mat, Parallelism, -}; -use std::time::Duration; - -pub fn matmul(criterion: &mut Criterion) { - let name = core::any::type_name::(); - for n in [2, 3, 4, 5, 6, 8, 32, 64, 128, 256, 512, 1024] { - let mut acc = Mat::::zeros(n, n); - let a = Mat::::zeros(n, n); - let b = Mat::::zeros(n, n); - criterion.bench_function(&format!("matmul-{name}-st-{n}"), |bencher| { - bencher.iter(|| { - matmul_with_conj( - acc.as_mut(), - a.as_ref(), - Conj::No, - b.as_ref(), - Conj::No, - None, - E::faer_one(), - Parallelism::None, - false, - ); - }); - }); - criterion.bench_function(&format!("matmul-{name}-mt-{n}"), |bencher| { - bencher.iter(|| { - matmul_with_conj( - acc.as_mut(), - a.as_ref(), - Conj::No, - b.as_ref(), - Conj::No, - None, - E::faer_one(), - Parallelism::Rayon(0), - false, - ); - }); - }); - - if is_vectorizable::() { - criterion.bench_function(&format!("gemm-{name}-mt-{n}"), |bencher| { - bencher.iter(|| { - matmul_with_conj( - acc.as_mut(), - a.as_ref(), - Conj::No, - b.as_ref(), - Conj::No, - None, - E::faer_one(), - Parallelism::Rayon(0), - true, - ); - }); - }); - criterion.bench_function(&format!("gemm-{name}-st-{n}"), |bencher| { - bencher.iter(|| { - matmul_with_conj( - acc.as_mut(), - a.as_ref(), - Conj::No, - b.as_ref(), - Conj::No, - None, - E::faer_one(), - Parallelism::None, - true, - ); - }); - }); - } - } -} - -criterion_group!( - name = benches; - config = Criterion::default() - .warm_up_time(Duration::from_secs(1)) - .measurement_time(Duration::from_secs(1)) - .sample_size(10); - targets = - matmul::, - matmul::, - matmul::, - matmul::, - matmul::, - matmul::, -); -criterion_main!(benches); diff --git a/faer-libs/faer-core/katex-header.html b/faer-libs/faer-core/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer-core/src/jacobi.rs b/faer-libs/faer-core/src/jacobi.rs deleted file mode 100644 index 59b3ed9e63b285c065a899ccdb80c9c86db84d0d..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/src/jacobi.rs +++ /dev/null @@ -1,316 +0,0 @@ -use crate::{assert, group_helpers::*, unzipped, zipped, MatMut, RealField}; -use faer_entity::{SimdCtx, SimdGroupFor}; -use reborrow::*; - -#[derive(Copy, Clone, Debug)] -#[repr(C)] -pub struct JacobiRotation { - pub c: T, - pub s: T, -} - -unsafe impl bytemuck::Zeroable for JacobiRotation {} -unsafe impl bytemuck::Pod for JacobiRotation {} - -impl JacobiRotation { - #[inline] - pub fn make_givens(p: E, q: E) -> Self { - if q == E::faer_zero() { - Self { - c: if p < E::faer_zero() { - E::faer_one().faer_neg() - } else { - E::faer_one() - }, - s: E::faer_zero(), - } - } else if p == E::faer_zero() { - Self { - c: E::faer_zero(), - s: if q < E::faer_zero() { - E::faer_one().faer_neg() - } else { - E::faer_one() - }, - } - } else if p.faer_abs() > q.faer_abs() { - let t = q.faer_div(p); - let mut u = E::faer_one().faer_add(t.faer_abs2()).faer_sqrt(); - if p < E::faer_zero() { - u = u.faer_neg(); - } - let c = u.faer_inv(); - let s = t.faer_neg().faer_mul(c); - - Self { c, s } - } else { - let t = p.faer_div(q); - let mut u = E::faer_one().faer_add(t.faer_abs2()).faer_sqrt(); - if q < E::faer_zero() { - u = u.faer_neg(); - } - let s = u.faer_inv().faer_neg(); - let c = t.faer_neg().faer_mul(s); - - Self { c, s } - } - } - - #[inline] - pub fn from_triplet(x: E, y: E, z: E) -> Self { - let abs_y = y.faer_abs(); - let two_abs_y = abs_y.faer_add(abs_y); - if two_abs_y == E::faer_zero() { - Self { - c: E::faer_one(), - s: E::faer_zero(), - } - } else { - let tau = (x.faer_sub(z)).faer_mul(two_abs_y.faer_inv()); - let w = ((tau.faer_mul(tau)).faer_add(E::faer_one())).faer_sqrt(); - let t = if tau > E::faer_zero() { - (tau.faer_add(w)).faer_inv() - } else { - (tau.faer_sub(w)).faer_inv() - }; - - let neg_sign_y = if y > E::faer_zero() { - E::faer_one().faer_neg() - } else { - E::faer_one() - }; - let n = (t.faer_mul(t).faer_add(E::faer_one())) - .faer_sqrt() - .faer_inv(); - - Self { - c: n, - s: neg_sign_y.faer_mul(t).faer_mul(n), - } - } - } - - #[inline] - pub fn apply_on_the_left_2x2(&self, m00: E, m01: E, m10: E, m11: E) -> (E, E, E, E) { - let Self { c, s } = *self; - ( - m00.faer_mul(c).faer_add(m10.faer_mul(s)), - m01.faer_mul(c).faer_add(m11.faer_mul(s)), - s.faer_neg().faer_mul(m00).faer_add(c.faer_mul(m10)), - s.faer_neg().faer_mul(m01).faer_add(c.faer_mul(m11)), - ) - } - - #[inline] - pub fn apply_on_the_right_2x2(&self, m00: E, m01: E, m10: E, m11: E) -> (E, E, E, E) { - let (r00, r01, r10, r11) = self.transpose().apply_on_the_left_2x2(m00, m10, m01, m11); - (r00, r10, r01, r11) - } - - #[inline] - pub fn apply_on_the_left_in_place(&self, x: MatMut<'_, E>, y: MatMut<'_, E>) { - self.apply_on_the_left_in_place_arch(E::Simd::default(), x, y); - } - - #[inline(never)] - fn apply_on_the_left_in_place_fallback(&self, x: MatMut<'_, E>, y: MatMut<'_, E>) { - let Self { c, s } = *self; - zipped!(x, y).for_each(move |unzipped!(mut x, mut y)| { - let x_ = x.read(); - let y_ = y.read(); - x.write(c.faer_mul(x_).faer_add(s.faer_mul(y_))); - y.write(s.faer_neg().faer_mul(x_).faer_add(c.faer_mul(y_))); - }); - } - - #[inline(always)] - pub fn apply_on_the_right_in_place_with_simd_and_offset( - &self, - simd: S, - offset: pulp::Offset>, - x: MatMut<'_, E>, - y: MatMut<'_, E>, - ) { - self.transpose() - .apply_on_the_left_in_place_with_simd_and_offset( - simd, - offset, - x.transpose_mut(), - y.transpose_mut(), - ); - } - - #[inline(always)] - pub fn apply_on_the_left_in_place_with_simd_and_offset( - &self, - simd: S, - offset: pulp::Offset>, - x: MatMut<'_, E>, - y: MatMut<'_, E>, - ) { - let Self { c, s } = *self; - assert!(all(x.nrows() == 1, y.nrows() == 1, x.ncols() == y.ncols())); - - if c == E::faer_one() && s == E::faer_zero() { - return; - } - - if x.col_stride() != 1 || y.col_stride() != 1 { - self.apply_on_the_left_in_place_fallback(x, y); - return; - } - - let simd = SimdFor::::new(simd); - - let x = SliceGroupMut::<'_, E>::new(x.transpose_mut().try_get_contiguous_col_mut(0)); - let y = SliceGroupMut::<'_, E>::new(y.transpose_mut().try_get_contiguous_col_mut(0)); - - let c = simd.splat(c); - let s = simd.splat(s); - - let (x_head, x_body, x_tail) = simd.as_aligned_simd_mut(x, offset); - let (y_head, y_body, y_tail) = simd.as_aligned_simd_mut(y, offset); - - #[inline(always)] - fn process( - simd: SimdFor, - mut x: impl Write>, - mut y: impl Write>, - c: SimdGroupFor, - s: SimdGroupFor, - ) { - let zero = simd.splat(E::faer_zero()); - let x_ = x.read_or(zero); - let y_ = y.read_or(zero); - x.write(simd.mul_add_e(c, x_, simd.mul(s, y_))); - y.write(simd.mul_add_e(c, y_, simd.neg(simd.mul(s, x_)))); - } - - process(simd, x_head, y_head, c, s); - for (x, y) in x_body.into_mut_iter().zip(y_body.into_mut_iter()) { - process(simd, x, y, c, s); - } - process(simd, x_tail, y_tail, c, s); - } - - #[inline] - pub fn apply_on_the_left_in_place_arch( - &self, - arch: E::Simd, - x: MatMut<'_, E>, - y: MatMut<'_, E>, - ) { - struct ApplyOnLeft<'a, E: RealField> { - c: E, - s: E, - x: MatMut<'a, E>, - y: MatMut<'a, E>, - } - - impl pulp::WithSimd for ApplyOnLeft<'_, E> { - type Output = (); - - #[inline(always)] - fn with_simd(self, simd: S) -> Self::Output { - let Self { x, y, c, s } = self; - assert!(all(x.nrows() == 1, y.nrows() == 1, x.ncols() == y.ncols())); - - if c == E::faer_one() && s == E::faer_zero() { - return; - } - - let simd = SimdFor::::new(simd); - - let x = - SliceGroupMut::<'_, E>::new(x.transpose_mut().try_get_contiguous_col_mut(0)); - let y = - SliceGroupMut::<'_, E>::new(y.transpose_mut().try_get_contiguous_col_mut(0)); - - let offset = simd.align_offset(x.rb()); - - let c = simd.splat(c); - let s = simd.splat(s); - - let (x_head, x_body, x_tail) = simd.as_aligned_simd_mut(x, offset); - let (y_head, y_body, y_tail) = simd.as_aligned_simd_mut(y, offset); - - #[inline(always)] - fn process( - simd: SimdFor, - mut x: impl Write>, - mut y: impl Write>, - c: SimdGroupFor, - s: SimdGroupFor, - ) { - let zero = simd.splat(E::faer_zero()); - let x_ = x.read_or(zero); - let y_ = y.read_or(zero); - x.write(simd.mul_add_e(c, x_, simd.mul(s, y_))); - y.write(simd.mul_add_e(c, y_, simd.neg(simd.mul(s, x_)))); - } - - process(simd, x_head, y_head, c, s); - for (x, y) in x_body.into_mut_iter().zip(y_body.into_mut_iter()) { - process(simd, x, y, c, s); - } - process(simd, x_tail, y_tail, c, s); - } - } - - let Self { c, s } = *self; - - let mut x = x; - let mut y = y; - - if x.col_stride() == 1 && y.col_stride() == 1 { - arch.dispatch(ApplyOnLeft::<'_, E> { c, s, x, y }); - } else { - zipped!(x, y).for_each(move |unzipped!(mut x, mut y)| { - let x_ = x.read(); - let y_ = y.read(); - x.write(c.faer_mul(x_).faer_add(s.faer_mul(y_))); - y.write(s.faer_neg().faer_mul(x_).faer_add(c.faer_mul(y_))); - }); - } - } - - #[inline] - pub fn apply_on_the_right_in_place(&self, x: MatMut<'_, E>, y: MatMut<'_, E>) { - self.transpose() - .apply_on_the_left_in_place(x.transpose_mut(), y.transpose_mut()); - } - - #[inline] - pub fn apply_on_the_right_in_place_arch( - &self, - arch: E::Simd, - x: MatMut<'_, E>, - y: MatMut<'_, E>, - ) { - self.transpose().apply_on_the_left_in_place_arch( - arch, - x.transpose_mut(), - y.transpose_mut(), - ); - } - - #[inline] - pub fn transpose(&self) -> Self { - Self { - c: self.c, - s: self.s.faer_neg(), - } - } -} - -impl core::ops::Mul for JacobiRotation { - type Output = Self; - - #[inline] - fn mul(self, rhs: Self) -> Self::Output { - Self { - c: self.c.faer_mul(rhs.c).faer_sub(self.s.faer_mul(rhs.s)), - s: self.c.faer_mul(rhs.s).faer_add(self.s.faer_mul(rhs.c)), - } - } -} diff --git a/faer-libs/faer-core/src/lib.rs b/faer-libs/faer-core/src/lib.rs deleted file mode 100644 index 37036f320ea78bbba8b70cf57b3bc7be75501a34..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/src/lib.rs +++ /dev/null @@ -1,15566 +0,0 @@ -//! `faer` is a linear algebra library for Rust, with a focus on high performance for -//! medium/large matrices. -//! -//! The core module contains the building blocks of linear algebra: -//! * Matrix structure definitions: [`Mat`], [`MatRef`], and [`MatMut`]. -//! * Coefficient-wise matrix operations, like addition and subtraction: either using the builtin -//! `+` and `-` operators or using the low level api [`zipped!`]. -//! * Matrix multiplication: either using the builtin `*` operator or the low level [`mul`] module. -//! * Triangular matrix solve: the [`solve`] module. -//! * Triangular matrix inverse: the [`inverse`] module. -//! * Householder matrix multiplication: the [`householder`] module. -//! -//! # Example -//! ``` -//! use faer_core::{mat, scale, Mat}; -//! -//! let a = mat![ -//! [1.0, 5.0, 9.0], -//! [2.0, 6.0, 10.0], -//! [3.0, 7.0, 11.0], -//! [4.0, 8.0, 12.0f64], -//! ]; -//! -//! let b = Mat::::from_fn(4, 3, |i, j| (i + j) as f64); -//! -//! let add = &a + &b; -//! let sub = &a - &b; -//! let scale = scale(3.0) * &a; -//! let mul = &a * b.transpose(); -//! ``` -//! -//! # Entity trait -//! Matrices are built on top of the [`Entity`] trait, which describes the prefered memory storage -//! layout for a given type `E`. An entity can be decomposed into a group of units: for a natively -//! supported type ([`f32`], [`f64`], [`c32`], [`c64`]), the unit is simply the type itself, and a -//! group contains a single element. On the other hand, for a type with a more specific preferred -//! layout, like an extended precision floating point type, or a dual number type, the unit would -//! be one of the natively supported types, and the group would be a structure holding the -//! components that build up the full value. -//! -//! To take a more specific example: [`num_complex::Complex`] has a storage memory layout that -//! differs from that of [`c64`] (see [`complex_native`] for more details). Its real and complex -//! components are stored separately, so its unit type is `f64`, while its group type is `Complex`. -//! In practice, this means that for a `Mat`, methods such as [`Mat::col_as_slice`] will return -//! a `&[f64]`. Meanwhile, for a `Mat>`, [`Mat::col_as_slice`] will return -//! `Complex<&[f64]>`, which holds two slices, each pointing respectively to a view over the real -//! and the imaginary components. -//! -//! While the design of the entity trait is unconventional, it helps us achieve much higher -//! performance when targetting non native types, due to the design matching the typical preffered -//! CPU layout for SIMD operations. And for native types, since [`Group` is just -//! `T`](Entity#impl-Entity-for-f64), the entity layer is a no-op, and the matrix layout is -//! compatible with the classic contiguous layout that's commonly used by other libraries. -//! -//! # Memory allocation -//! Since most `faer` crates aim to expose a low level api for optimal performance, most algorithms -//! try to defer memory allocation to the user. -//! -//! However, since a lot of algorithms need some form of temporary space for intermediate -//! computations, they may ask for a slice of memory for that purpose, by taking a [`stack: -//! PodStack`](dyn_stack::PodStack) parameter. A `PodStack` is a thin wrapper over a slice of -//! memory bytes. This memory may come from any valid source (heap allocation, fixed-size array on -//! the stack, etc.). The functions taking a `PodStack` parameter have a corresponding function -//! with a similar name ending in `_req` that returns the memory requirements of the algorithm. For -//! example: -//! [`householder::apply_block_householder_on_the_left_in_place_with_conj`] and -//! [`householder::apply_block_householder_on_the_left_in_place_req`]. -//! -//! The memory stack may be reused in user-code to avoid repeated allocations, and it is also -//! possible to compute the sum ([`dyn_stack::StackReq::all_of`]) or union -//! ([`dyn_stack::StackReq::any_of`]) of multiple requirements, in order to optimally combine them -//! into a single allocation. -//! -//! After computing a [`dyn_stack::StackReq`], one can query its size and alignment to allocate the -//! required memory. The simplest way to do so is through [`dyn_stack::GlobalMemBuffer::new`]. - -#![allow(non_snake_case)] -#![allow(clippy::type_complexity)] -#![allow(clippy::too_many_arguments)] -#![warn(missing_docs)] -#![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(not(feature = "std"), no_std)] - -use faer_entity::*; -pub use faer_entity::{ - ComplexField, Conjugate, Entity, GroupFor, IdentityGroup, RealField, SimdCtx, SimpleEntity, -}; - -#[doc(hidden)] -pub use equator::{assert, debug_assert}; - -pub use dyn_stack; -pub use reborrow; -pub use faer_entity::pulp; - -use coe::Coerce; -use core::{ - fmt::Debug, marker::PhantomData, mem::ManuallyDrop, ptr::NonNull, sync::atomic::AtomicUsize, -}; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use group_helpers::{SliceGroup, SliceGroupMut}; -use num_complex::Complex; -use pulp::Simd; -use reborrow::*; - -#[cfg(feature = "perf-warn")] -#[macro_export] -#[doc(hidden)] -macro_rules! __perf_warn { - ($name: ident) => {{ - #[inline(always)] - #[allow(non_snake_case)] - fn $name() -> &'static ::core::sync::atomic::AtomicBool { - static $name: ::core::sync::atomic::AtomicBool = - ::core::sync::atomic::AtomicBool::new(false); - &$name - } - ::core::matches!( - $name().compare_exchange( - false, - true, - ::core::sync::atomic::Ordering::Relaxed, - ::core::sync::atomic::Ordering::Relaxed, - ), - Ok(_) - ) - }}; -} - -#[doc(hidden)] -pub trait DivCeil: Sized { - fn msrv_div_ceil(self, rhs: Self) -> Self; - fn msrv_next_multiple_of(self, rhs: Self) -> Self; - fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option; -} - -impl DivCeil for usize { - #[inline] - fn msrv_div_ceil(self, rhs: Self) -> Self { - let d = self / rhs; - let r = self % rhs; - if r > 0 { - d + 1 - } else { - d - } - } - - #[inline] - fn msrv_next_multiple_of(self, rhs: Self) -> Self { - match self % rhs { - 0 => self, - r => self + (rhs - r), - } - } - - #[inline] - fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option { - { - match self.checked_rem(rhs)? { - 0 => Some(self), - r => self.checked_add(rhs - r), - } - } - } -} - -/// Specifies whether the triangular lower or upper part of a matrix should be accessed. -#[derive(Copy, Clone, Debug, PartialEq)] -pub enum Side { - /// Lower half should be accessed. - Lower, - /// Upper half should be accessed. - Upper, -} - -/// Errors that can occur in sparse algorithms. -#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] -#[non_exhaustive] -pub enum FaerError { - /// An index exceeding the maximum value (`I::Signed::MAX` for a given index type `I`). - IndexOverflow, - /// Memory allocation failed. - OutOfMemory, -} - -impl core::fmt::Display for FaerError { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - core::fmt::Debug::fmt(self, f) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for FaerError {} - -extern crate alloc; - -pub mod householder; -#[doc(hidden)] -pub mod jacobi; - -pub mod inverse; -pub mod mul; -pub mod permutation; -pub mod solve; - -pub mod matrix_ops; - -#[cfg(feature = "serde")] -pub mod serde_impl; -pub mod sparse; - -/// Thin wrapper used for scalar multiplication of a matrix by a scalar value. -pub use matrix_ops::scale; - -#[doc(hidden)] -pub mod simd; - -#[doc(hidden)] -pub use faer_entity::transmute_unchecked; - -pub mod complex_native; -pub use complex_native::*; - -mod sort; - -/// Whether a matrix should be implicitly conjugated when read or not. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum Conj { - /// Do conjugate. - Yes, - /// Do not conjugate. - No, -} - -impl Conj { - /// Combine `self` and `other` to create a new conjugation object. - #[inline] - pub fn compose(self, other: Conj) -> Conj { - if self == other { - Conj::No - } else { - Conj::Yes - } - } -} - -/// Trait for types that can be converted to a row view. -pub trait AsRowRef { - /// Convert to a row view. - fn as_row_ref(&self) -> RowRef<'_, E>; -} -/// Trait for types that can be converted to a mutable row view. -pub trait AsRowMut { - /// Convert to a mutable row view. - fn as_row_mut(&mut self) -> RowMut<'_, E>; -} - -/// Trait for types that can be converted to a column view. -pub trait AsColRef { - /// Convert to a column view. - fn as_col_ref(&self) -> ColRef<'_, E>; -} -/// Trait for types that can be converted to a mutable col view. -pub trait AsColMut { - /// Convert to a mutable column view. - fn as_col_mut(&mut self) -> ColMut<'_, E>; -} - -/// Trait for types that can be converted to a matrix view. -/// -/// This trait is implemented for types of the matrix family, like [`Mat`], -/// [`MatRef`], and [`MatMut`], but not for types like [`Col`], [`Row`], or -/// their families. For a more general trait, see [`As2D`]. -pub trait AsMatRef { - /// Convert to a matrix view. - fn as_mat_ref(&self) -> MatRef<'_, E>; -} -/// Trait for types that can be converted to a mutable matrix view. -/// -/// This trait is implemented for types of the matrix family, like [`Mat`], -/// [`MatRef`], and [`MatMut`], but not for types like [`Col`], [`Row`], or -/// their families. For a more general trait, see [`As2D`]. -pub trait AsMatMut { - /// Convert to a mutable matrix view. - fn as_mat_mut(&mut self) -> MatMut<'_, E>; -} - -/// Trait for types that can be converted to a 2D matrix view. -/// -/// This trait is implemented for any type that can be represented as a -/// 2D matrix view, like [`Mat`], [`Row`], [`Col`], and their respective -/// references and mutable references. For a trait specific to the matrix -/// family, see [`AsMatRef`] or [`AsMatMut`]. -pub trait As2D { - /// Convert to a 2D matrix view. - fn as_2d_ref(&self) -> MatRef<'_, E>; -} -/// Trait for types that can be converted to a mutable 2D matrix view. -/// -/// This trait is implemented for any type that can be represented as a -/// 2D matrix view, like [`Mat`], [`Row`], [`Col`], and their respective -/// references and mutable references. For a trait specific to the matrix -/// family, see [`AsMatRef`] or [`AsMatMut`]. -pub trait As2DMut { - /// Convert to a mutable 2D matrix view. - fn as_2d_mut(&mut self) -> MatMut<'_, E>; -} - -// AS COL -const _: () = { - impl AsColRef for ColRef<'_, E> { - #[inline] - fn as_col_ref(&self) -> ColRef<'_, E> { - *self - } - } - impl AsColRef for &'_ ColRef<'_, E> { - #[inline] - fn as_col_ref(&self) -> ColRef<'_, E> { - **self - } - } - impl AsColRef for ColMut<'_, E> { - #[inline] - fn as_col_ref(&self) -> ColRef<'_, E> { - (*self).rb() - } - } - impl AsColRef for &'_ ColMut<'_, E> { - #[inline] - fn as_col_ref(&self) -> ColRef<'_, E> { - (**self).rb() - } - } - impl AsColRef for Col { - #[inline] - fn as_col_ref(&self) -> ColRef<'_, E> { - (*self).as_ref() - } - } - impl AsColRef for &'_ Col { - #[inline] - fn as_col_ref(&self) -> ColRef<'_, E> { - (**self).as_ref() - } - } - - impl AsColMut for ColMut<'_, E> { - #[inline] - fn as_col_mut(&mut self) -> ColMut<'_, E> { - (*self).rb_mut() - } - } - - impl AsColMut for &'_ mut ColMut<'_, E> { - #[inline] - fn as_col_mut(&mut self) -> ColMut<'_, E> { - (**self).rb_mut() - } - } - - impl AsColMut for Col { - #[inline] - fn as_col_mut(&mut self) -> ColMut<'_, E> { - (*self).as_mut() - } - } - - impl AsColMut for &'_ mut Col { - #[inline] - fn as_col_mut(&mut self) -> ColMut<'_, E> { - (**self).as_mut() - } - } -}; - -// AS ROW -const _: () = { - impl AsRowRef for RowRef<'_, E> { - #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { - *self - } - } - impl AsRowRef for &'_ RowRef<'_, E> { - #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { - **self - } - } - impl AsRowRef for RowMut<'_, E> { - #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { - (*self).rb() - } - } - impl AsRowRef for &'_ RowMut<'_, E> { - #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { - (**self).rb() - } - } - impl AsRowRef for Row { - #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { - (*self).as_ref() - } - } - impl AsRowRef for &'_ Row { - #[inline] - fn as_row_ref(&self) -> RowRef<'_, E> { - (**self).as_ref() - } - } - - impl AsRowMut for RowMut<'_, E> { - #[inline] - fn as_row_mut(&mut self) -> RowMut<'_, E> { - (*self).rb_mut() - } - } - - impl AsRowMut for &'_ mut RowMut<'_, E> { - #[inline] - fn as_row_mut(&mut self) -> RowMut<'_, E> { - (**self).rb_mut() - } - } - - impl AsRowMut for Row { - #[inline] - fn as_row_mut(&mut self) -> RowMut<'_, E> { - (*self).as_mut() - } - } - - impl AsRowMut for &'_ mut Row { - #[inline] - fn as_row_mut(&mut self) -> RowMut<'_, E> { - (**self).as_mut() - } - } -}; - -// AS MAT -const _: () = { - impl AsMatRef for MatRef<'_, E> { - #[inline] - fn as_mat_ref(&self) -> MatRef<'_, E> { - *self - } - } - impl AsMatRef for &'_ MatRef<'_, E> { - #[inline] - fn as_mat_ref(&self) -> MatRef<'_, E> { - **self - } - } - impl AsMatRef for MatMut<'_, E> { - #[inline] - fn as_mat_ref(&self) -> MatRef<'_, E> { - (*self).rb() - } - } - impl AsMatRef for &'_ MatMut<'_, E> { - #[inline] - fn as_mat_ref(&self) -> MatRef<'_, E> { - (**self).rb() - } - } - impl AsMatRef for Mat { - #[inline] - fn as_mat_ref(&self) -> MatRef<'_, E> { - (*self).as_ref() - } - } - impl AsMatRef for &'_ Mat { - #[inline] - fn as_mat_ref(&self) -> MatRef<'_, E> { - (**self).as_ref() - } - } - - impl AsMatMut for MatMut<'_, E> { - #[inline] - fn as_mat_mut(&mut self) -> MatMut<'_, E> { - (*self).rb_mut() - } - } - - impl AsMatMut for &'_ mut MatMut<'_, E> { - #[inline] - fn as_mat_mut(&mut self) -> MatMut<'_, E> { - (**self).rb_mut() - } - } - - impl AsMatMut for Mat { - #[inline] - fn as_mat_mut(&mut self) -> MatMut<'_, E> { - (*self).as_mut() - } - } - - impl AsMatMut for &'_ mut Mat { - #[inline] - fn as_mat_mut(&mut self) -> MatMut<'_, E> { - (**self).as_mut() - } - } -}; - -// AS 2D -const _: () = { - // Matrix family - impl As2D for &'_ MatRef<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - **self - } - } - - impl As2D for MatRef<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - *self - } - } - - impl As2D for &'_ MatMut<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (**self).rb() - } - } - - impl As2D for MatMut<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (*self).rb() - } - } - - impl As2D for &'_ Mat { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (**self).as_ref() - } - } - - impl As2D for Mat { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (*self).as_ref() - } - } - - // Row Family - impl As2D for &'_ RowRef<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.as_2d() - } - } - - impl As2D for RowRef<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.as_2d() - } - } - - impl As2D for &'_ RowMut<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (**self).rb().as_2d() - } - } - - impl As2D for RowMut<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.rb().as_2d() - } - } - - impl As2D for &'_ Row { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (**self).as_ref().as_2d() - } - } - - impl As2D for Row { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.as_ref().as_2d() - } - } - - // Col Family - impl As2D for &'_ ColRef<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.as_2d() - } - } - - impl As2D for ColRef<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.as_2d() - } - } - - impl As2D for &'_ ColMut<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (**self).rb().as_2d() - } - } - - impl As2D for ColMut<'_, E> { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.rb().as_2d() - } - } - - impl As2D for &'_ Col { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - (**self).as_ref().as_2d() - } - } - - impl As2D for Col { - #[inline] - fn as_2d_ref(&self) -> MatRef<'_, E> { - self.as_ref().as_2d() - } - } -}; - -// AS 2D MUT -const _: () = { - // Matrix family - impl As2DMut for &'_ mut MatMut<'_, E> { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (**self).rb_mut() - } - } - - impl As2DMut for MatMut<'_, E> { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (*self).rb_mut() - } - } - - impl As2DMut for &'_ mut Mat { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (**self).as_mut() - } - } - - impl As2DMut for Mat { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (*self).as_mut() - } - } - - // Row Family - impl As2DMut for &'_ mut RowMut<'_, E> { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (**self).rb_mut().as_2d_mut() - } - } - - impl As2DMut for RowMut<'_, E> { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - self.rb_mut().as_2d_mut() - } - } - - impl As2DMut for &'_ mut Row { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (**self).as_mut().as_2d_mut() - } - } - - impl As2DMut for Row { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - self.as_mut().as_2d_mut() - } - } - - // Col Family - impl As2DMut for &'_ mut ColMut<'_, E> { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (**self).rb_mut().as_2d_mut() - } - } - - impl As2DMut for ColMut<'_, E> { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - self.rb_mut().as_2d_mut() - } - } - - impl As2DMut for &'_ mut Col { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - (**self).as_mut().as_2d_mut() - } - } - - impl As2DMut for Col { - #[inline] - fn as_2d_mut(&mut self) -> MatMut<'_, E> { - self.as_mut().as_2d_mut() - } - } -}; - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for MatRef<'_, E> { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Dense(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::DenseAccess for MatRef<'_, E> { - #[inline] - fn fetch_single(&self, row: usize, col: usize) -> E { - self.read(row, col) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for MatMut<'_, E> { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Dense(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::DenseAccess for MatMut<'_, E> { - #[inline] - fn fetch_single(&self, row: usize, col: usize) -> E { - self.read(row, col) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for Mat { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Dense(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::DenseAccess for Mat { - #[inline] - fn fetch_single(&self, row: usize, col: usize) -> E { - self.read(row, col) - } -} - -#[repr(C)] -struct VecImpl { - ptr: GroupCopyFor>, - len: usize, - stride: isize, -} -#[repr(C)] -struct VecOwnImpl { - ptr: GroupCopyFor>, - len: usize, -} - -#[repr(C)] -struct MatImpl { - ptr: GroupCopyFor>, - nrows: usize, - ncols: usize, - row_stride: isize, - col_stride: isize, -} -#[repr(C)] -struct MatOwnImpl { - ptr: GroupCopyFor>, - nrows: usize, - ncols: usize, -} - -impl Copy for VecImpl {} -impl Clone for VecImpl { - #[inline(always)] - fn clone(&self) -> Self { - *self - } -} - -impl Copy for MatImpl {} -impl Clone for MatImpl { - #[inline(always)] - fn clone(&self) -> Self { - *self - } -} - -/// Generic matrix container. -#[derive(Copy, Clone)] -pub struct Matrix { - inner: M, -} - -/// Specialized containers that are used with [`Matrix`]. -pub mod inner { - use super::*; - use crate::group_helpers::VecGroup; - - impl Copy for DiagRef<'_, E> {} - impl Clone for DiagRef<'_, E> { - #[inline(always)] - fn clone(&self) -> Self { - *self - } - } - - impl Copy for DenseRowRef<'_, E> {} - impl Clone for DenseRowRef<'_, E> { - #[inline(always)] - fn clone(&self) -> Self { - *self - } - } - - impl Copy for DenseColRef<'_, E> {} - impl Clone for DenseColRef<'_, E> { - #[inline(always)] - fn clone(&self) -> Self { - *self - } - } - - impl Copy for DenseRef<'_, E> {} - impl Clone for DenseRef<'_, E> { - #[inline(always)] - fn clone(&self) -> Self { - *self - } - } - - impl Copy for PermRef<'_, I, E> {} - impl Clone for PermRef<'_, I, E> { - #[inline(always)] - fn clone(&self) -> Self { - *self - } - } - - /// Immutable permutation view. - #[repr(C)] - #[derive(Debug)] - pub struct PermRef<'a, I, E: Entity> { - pub(crate) forward: &'a [I], - pub(crate) inverse: &'a [I], - pub(crate) __marker: PhantomData, - } - /// Mutable permutation view. - #[repr(C)] - #[derive(Debug)] - pub struct PermMut<'a, I, E: Entity> { - pub(crate) forward: &'a mut [I], - pub(crate) inverse: &'a mut [I], - pub(crate) __marker: PhantomData, - } - /// Owned permutation. - #[repr(C)] - #[derive(Debug)] - pub struct PermOwn { - pub(crate) forward: alloc::boxed::Box<[I]>, - pub(crate) inverse: alloc::boxed::Box<[I]>, - pub(crate) __marker: PhantomData, - } - - /// Immutable diagonal matrix view. - #[repr(C)] - pub struct DiagRef<'a, E: Entity> { - pub(crate) inner: ColRef<'a, E>, - } - - /// Mutable diagonal matrix view. - #[repr(C)] - pub struct DiagMut<'a, E: Entity> { - pub(crate) inner: ColMut<'a, E>, - } - - /// Owned diagonal matrix. - #[repr(C)] - pub struct DiagOwn { - pub(crate) inner: Col, - } - - /// Immutable column vector view. - #[repr(C)] - pub struct DenseColRef<'a, E: Entity> { - pub(crate) inner: VecImpl, - pub(crate) __marker: PhantomData<&'a E>, - } - - /// Mutable column vector view. - #[repr(C)] - pub struct DenseColMut<'a, E: Entity> { - pub(crate) inner: VecImpl, - pub(crate) __marker: PhantomData<&'a mut E>, - } - - /// Owned column vector. - #[repr(C)] - pub struct DenseColOwn { - pub(crate) inner: VecOwnImpl, - pub(crate) row_capacity: usize, - } - - /// Immutable row vector view. - #[repr(C)] - pub struct DenseRowRef<'a, E: Entity> { - pub(crate) inner: VecImpl, - pub(crate) __marker: PhantomData<&'a E>, - } - - /// Mutable row vector view. - #[repr(C)] - pub struct DenseRowMut<'a, E: Entity> { - pub(crate) inner: VecImpl, - pub(crate) __marker: PhantomData<&'a mut E>, - } - - /// Owned row vector. - #[repr(C)] - pub struct DenseRowOwn { - pub(crate) inner: VecOwnImpl, - pub(crate) col_capacity: usize, - } - - /// Immutable dense matrix view. - #[repr(C)] - pub struct DenseRef<'a, E: Entity> { - pub(crate) inner: MatImpl, - pub(crate) __marker: PhantomData<&'a E>, - } - - /// Mutable dense matrix view. - #[repr(C)] - pub struct DenseMut<'a, E: Entity> { - pub(crate) inner: MatImpl, - pub(crate) __marker: PhantomData<&'a mut E>, - } - - /// Owned dense matrix. - #[repr(C)] - pub struct DenseOwn { - pub(crate) inner: MatOwnImpl, - pub(crate) row_capacity: usize, - pub(crate) col_capacity: usize, - } - - /// Scaling factor. - #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] - #[repr(transparent)] - pub struct Scale(pub E); - - use permutation::Index; - - /// Immutable sparse matrix view, in column-major order. - #[derive(Debug)] - pub struct SparseColMatRef<'a, I: Index, E: Entity> { - pub(crate) symbolic: sparse::SymbolicSparseColMatRef<'a, I>, - pub(crate) values: SliceGroup<'a, E>, - } - - /// Immutable sparse matrix view, in row-major order. - #[derive(Debug)] - pub struct SparseRowMatRef<'a, I: Index, E: Entity> { - pub(crate) symbolic: sparse::SymbolicSparseRowMatRef<'a, I>, - pub(crate) values: SliceGroup<'a, E>, - } - - /// Mutable sparse matrix view, in column-major order. - #[derive(Debug)] - pub struct SparseColMatMut<'a, I: Index, E: Entity> { - pub(crate) symbolic: sparse::SymbolicSparseColMatRef<'a, I>, - pub(crate) values: SliceGroupMut<'a, E>, - } - - /// Mutable sparse matrix view, in row-major order. - #[derive(Debug)] - pub struct SparseRowMatMut<'a, I: Index, E: Entity> { - pub(crate) symbolic: sparse::SymbolicSparseRowMatRef<'a, I>, - pub(crate) values: SliceGroupMut<'a, E>, - } - - /// Owned sparse matrix, in column-major order. - #[derive(Debug, Clone)] - pub struct SparseColMat { - pub(crate) symbolic: sparse::SymbolicSparseColMat, - pub(crate) values: VecGroup, - } - - /// Owned sparse matrix, in row-major order. - #[derive(Debug, Clone)] - pub struct SparseRowMat { - pub(crate) symbolic: sparse::SymbolicSparseRowMat, - pub(crate) values: VecGroup, - } - - impl Copy for SparseRowMatRef<'_, I, E> {} - impl Clone for SparseRowMatRef<'_, I, E> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - impl Copy for SparseColMatRef<'_, I, E> {} - impl Clone for SparseColMatRef<'_, I, E> { - #[inline] - fn clone(&self) -> Self { - *self - } - } -} - -/// Advanced: Helper types for working with [`GroupFor`] in generic contexts. -pub mod group_helpers { - pub use pulp::{Read, Write}; - - /// Analogous to [`alloc::vec::Vec`] for groups. - pub struct VecGroup> { - inner: GroupFor>, - } - - impl Clone for VecGroup { - #[inline] - fn clone(&self) -> Self { - Self { - inner: E::faer_map(E::faer_as_ref(&self.inner), |v| (*v).clone()), - } - } - } - - impl Debug for VecGroup { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_slice().fmt(f) - } - } - - unsafe impl Sync for VecGroup {} - unsafe impl Send for VecGroup {} - - impl VecGroup { - /// Create a new [`VecGroup`] from a group of [`alloc::vec::Vec`]. - #[inline] - pub fn from_inner(inner: GroupFor>) -> Self { - Self { inner } - } - - /// Consume `self` to return a group of [`alloc::vec::Vec`]. - #[inline] - pub fn into_inner(self) -> GroupFor> { - self.inner - } - - /// Return a reference to the inner group of [`alloc::vec::Vec`]. - #[inline] - pub fn as_inner_ref(&self) -> GroupFor> { - E::faer_as_ref(&self.inner) - } - - /// Return a mutable reference to the inner group of [`alloc::vec::Vec`]. - #[inline] - pub fn as_inner_mut(&mut self) -> GroupFor> { - E::faer_as_mut(&mut self.inner) - } - - /// Return a [`SliceGroup`] view over the elements of `self`. - #[inline] - pub fn as_slice(&self) -> SliceGroup<'_, E, T> { - SliceGroup::new(E::faer_map( - E::faer_as_ref(&self.inner), - #[inline] - |slice| &**slice, - )) - } - - /// Return a [`SliceGroupMut`] mutable view over the elements of `self`. - #[inline] - pub fn as_slice_mut(&mut self) -> SliceGroupMut<'_, E, T> { - SliceGroupMut::new(E::faer_map( - E::faer_as_mut(&mut self.inner), - #[inline] - |slice| &mut **slice, - )) - } - - /// Create an empty [`VecGroup`]. - #[inline] - pub fn new() -> Self { - Self { - inner: E::faer_map(E::UNIT, |()| alloc::vec::Vec::new()), - } - } - - /// Returns the length of the vector group. - #[inline] - pub fn len(&self) -> usize { - let mut len = usize::MAX; - E::faer_map( - E::faer_as_ref(&self.inner), - #[inline(always)] - |slice| len = Ord::min(len, slice.len()), - ); - len - } - - /// Returns the capacity of the vector group. - #[inline] - pub fn capacity(&self) -> usize { - let mut cap = usize::MAX; - E::faer_map( - E::faer_as_ref(&self.inner), - #[inline(always)] - |slice| cap = Ord::min(cap, slice.capacity()), - ); - cap - } - - /// Reserve enough capacity for extra `additional` elements. - pub fn reserve(&mut self, additional: usize) { - E::faer_map(E::faer_as_mut(&mut self.inner), |v| v.reserve(additional)); - } - - /// Reserve exactly enough capacity for extra `additional` elements. - pub fn reserve_exact(&mut self, additional: usize) { - E::faer_map(E::faer_as_mut(&mut self.inner), |v| { - v.reserve_exact(additional) - }); - } - - /// Try to reserve enough capacity for extra `additional` elements. - pub fn try_reserve( - &mut self, - additional: usize, - ) -> Result<(), alloc::collections::TryReserveError> { - let mut result = Ok(()); - E::faer_map(E::faer_as_mut(&mut self.inner), |v| match &result { - Ok(()) => result = v.try_reserve(additional), - Err(_) => {} - }); - result - } - - /// Try to reserve exactly enough capacity for extra `additional` elements. - pub fn try_reserve_exact( - &mut self, - additional: usize, - ) -> Result<(), alloc::collections::TryReserveError> { - let mut result = Ok(()); - E::faer_map(E::faer_as_mut(&mut self.inner), |v| match &result { - Ok(()) => result = v.try_reserve_exact(additional), - Err(_) => {} - }); - result - } - - /// Truncate the length of the vector to `len`. - pub fn truncate(&mut self, len: usize) { - E::faer_map(E::faer_as_mut(&mut self.inner), |v| v.truncate(len)); - } - - /// Clear the vector, making it empty. - pub fn clear(&mut self) { - E::faer_map(E::faer_as_mut(&mut self.inner), |v| v.clear()); - } - - /// Resize the vector to `new_len`, filling the new elements with - /// `value`. - pub fn resize(&mut self, new_len: usize, value: GroupFor) - where - T: Clone, - { - E::faer_map( - E::faer_zip(E::faer_as_mut(&mut self.inner), value), - |(v, value)| v.resize(new_len, value), - ); - } - - /// Resize the vector to `new_len`, filling the new elements with - /// the output of `f`. - pub fn resize_with(&mut self, new_len: usize, f: impl FnMut() -> GroupFor) { - let len = self.len(); - let mut f = f; - if new_len <= len { - self.truncate(new_len); - } else { - self.reserve(new_len - len); - for _ in len..new_len { - self.push(f()) - } - } - } - - /// Push a new element to the end of `self`. - #[inline] - pub fn push(&mut self, value: GroupFor) { - E::faer_map( - E::faer_zip(E::faer_as_mut(&mut self.inner), value), - #[inline] - |(v, value)| v.push(value), - ); - } - - /// Remove a new element from the end of `self`, and return it. - #[inline] - pub fn pop(&mut self) -> Option> { - if self.len() >= 1 { - Some(E::faer_map( - E::faer_as_mut(&mut self.inner), - #[inline] - |v| v.pop().unwrap(), - )) - } else { - None - } - } - - /// Remove a new element from position `index`, and return it. - #[inline] - pub fn remove(&mut self, index: usize) -> GroupFor { - E::faer_map( - E::faer_as_mut(&mut self.inner), - #[inline] - |v| v.remove(index), - ) - } - } - - /// Do conjugate. - #[derive(Copy, Clone, Debug)] - pub struct YesConj; - /// Do not conjugate. - #[derive(Copy, Clone, Debug)] - pub struct NoConj; - - /// Similar to [`Conj`], but determined at compile time instead of runtime. - pub trait ConjTy: Copy + Debug { - /// The corresponding [`Conj`] value. - const CONJ: Conj; - /// The opposing conjugation type. - type Flip: ConjTy; - - /// Returns an instance of the corresponding conjugation type. - fn flip(self) -> Self::Flip; - } - - impl ConjTy for YesConj { - const CONJ: Conj = Conj::Yes; - type Flip = NoConj; - #[inline(always)] - fn flip(self) -> Self::Flip { - NoConj - } - } - impl ConjTy for NoConj { - const CONJ: Conj = Conj::No; - type Flip = YesConj; - #[inline(always)] - fn flip(self) -> Self::Flip { - YesConj - } - } - - use super::*; - use crate::{assert, debug_assert}; - use core::ops::Range; - - /// Wrapper for simd operations for type `E`. - pub struct SimdFor { - /// Simd token. - pub simd: S, - __marker: PhantomData, - } - - impl Copy for SimdFor {} - impl Clone for SimdFor { - #[inline] - fn clone(&self) -> Self { - *self - } - } - - impl SimdFor { - /// Create a new wrapper from a simd token. - #[inline(always)] - pub fn new(simd: S) -> Self { - Self { - simd, - __marker: PhantomData, - } - } - - /// Computes the alignment offset for subsequent aligned loads. - #[inline(always)] - pub fn align_offset(self, slice: SliceGroup<'_, E>) -> pulp::Offset> { - let slice = E::faer_first(slice.into_inner()); - E::faer_align_offset(self.simd, slice.as_ptr(), slice.len()) - } - - /// Computes the alignment offset for subsequent aligned loads from a pointer. - #[inline(always)] - pub fn align_offset_ptr( - self, - ptr: GroupFor, - len: usize, - ) -> pulp::Offset> { - E::faer_align_offset(self.simd, E::faer_first(ptr), len) - } - - /// Convert a slice to a slice over vector registers, and a scalar tail. - #[inline(always)] - pub fn as_simd( - self, - slice: SliceGroup<'_, E>, - ) -> (SliceGroup<'_, E, SimdUnitFor>, SliceGroup<'_, E>) { - let (head, tail) = slice_as_simd::(slice.into_inner()); - (SliceGroup::new(head), SliceGroup::new(tail)) - } - - /// Convert a mutable slice to a slice over vector registers, and a scalar tail. - #[inline(always)] - pub fn as_simd_mut( - self, - slice: SliceGroupMut<'_, E>, - ) -> ( - SliceGroupMut<'_, E, SimdUnitFor>, - SliceGroupMut<'_, E>, - ) { - let (head, tail) = slice_as_mut_simd::(slice.into_inner()); - (SliceGroupMut::new(head), SliceGroupMut::new(tail)) - } - - /// Convert a slice to a partial register prefix and suffix, and a vector register slice - /// (body). - #[inline(always)] - pub fn as_aligned_simd( - self, - slice: SliceGroup<'_, E>, - offset: pulp::Offset>, - ) -> ( - Prefix<'_, E, S>, - SliceGroup<'_, E, SimdUnitFor>, - Suffix<'_, E, S>, - ) { - let (head_tail, body) = E::faer_unzip(E::faer_map(slice.into_inner(), |slice| { - let (head, body, tail) = E::faer_slice_as_aligned_simd(self.simd, slice, offset); - ((head, tail), body) - })); - - let (head, tail) = E::faer_unzip(head_tail); - - unsafe { - ( - Prefix( - transmute_unchecked::< - GroupCopyFor>, - GroupCopyFor>, - >(into_copy::(head)), - PhantomData, - ), - SliceGroup::new(body), - Suffix( - transmute_unchecked::< - GroupCopyFor>, - GroupCopyFor>, - >(into_copy::(tail)), - PhantomData, - ), - ) - } - } - - /// Convert a mutable slice to a partial register prefix and suffix, and a vector register - /// slice (body). - #[inline(always)] - pub fn as_aligned_simd_mut( - self, - slice: SliceGroupMut<'_, E>, - offset: pulp::Offset>, - ) -> ( - PrefixMut<'_, E, S>, - SliceGroupMut<'_, E, SimdUnitFor>, - SuffixMut<'_, E, S>, - ) { - let (head_tail, body) = E::faer_unzip(E::faer_map(slice.into_inner(), |slice| { - let (head, body, tail) = - E::faer_slice_as_aligned_simd_mut(self.simd, slice, offset); - ((head, tail), body) - })); - - let (head, tail) = E::faer_unzip(head_tail); - - ( - PrefixMut( - unsafe { - transmute_unchecked::< - GroupFor>, - GroupFor>, - >(head) - }, - PhantomData, - ), - SliceGroupMut::new(body), - SuffixMut( - unsafe { - transmute_unchecked::< - GroupFor>, - GroupFor>, - >(tail) - }, - PhantomData, - ), - ) - } - - /// Fill all the register lanes with the same value. - #[inline(always)] - pub fn splat(self, value: E) -> SimdGroupFor { - E::faer_simd_splat(self.simd, value) - } - - /// Returns `lhs * rhs`. - #[inline(always)] - pub fn scalar_mul(self, lhs: E, rhs: E) -> E { - E::faer_simd_scalar_mul(self.simd, lhs, rhs) - } - /// Returns `conj(lhs) * rhs`. - #[inline(always)] - pub fn scalar_conj_mul(self, lhs: E, rhs: E) -> E { - E::faer_simd_scalar_conj_mul(self.simd, lhs, rhs) - } - /// Returns an estimate of `lhs * rhs + acc`. - #[inline(always)] - pub fn scalar_mul_add_e(self, lhs: E, rhs: E, acc: E) -> E { - E::faer_simd_scalar_mul_adde(self.simd, lhs, rhs, acc) - } - /// Returns an estimate of `conj(lhs) * rhs + acc`. - #[inline(always)] - pub fn scalar_conj_mul_add_e(self, lhs: E, rhs: E, acc: E) -> E { - E::faer_simd_scalar_conj_mul_adde(self.simd, lhs, rhs, acc) - } - - /// Returns an estimate of `op(lhs) * rhs`, where `op` is either the conjugation - /// or the identity operation. - #[inline(always)] - pub fn scalar_conditional_conj_mul(self, conj: C, lhs: E, rhs: E) -> E { - let _ = conj; - if C::CONJ == Conj::Yes { - self.scalar_conj_mul(lhs, rhs) - } else { - self.scalar_mul(lhs, rhs) - } - } - /// Returns an estimate of `op(lhs) * rhs + acc`, where `op` is either the conjugation or - /// the identity operation. - #[inline(always)] - pub fn scalar_conditional_conj_mul_add_e( - self, - conj: C, - lhs: E, - rhs: E, - acc: E, - ) -> E { - let _ = conj; - if C::CONJ == Conj::Yes { - self.scalar_conj_mul_add_e(lhs, rhs, acc) - } else { - self.scalar_mul_add_e(lhs, rhs, acc) - } - } - - /// Returns `lhs + rhs`. - #[inline(always)] - pub fn add(self, lhs: SimdGroupFor, rhs: SimdGroupFor) -> SimdGroupFor { - E::faer_simd_add(self.simd, lhs, rhs) - } - /// Returns `lhs - rhs`. - #[inline(always)] - pub fn sub(self, lhs: SimdGroupFor, rhs: SimdGroupFor) -> SimdGroupFor { - E::faer_simd_sub(self.simd, lhs, rhs) - } - /// Returns `-a`. - #[inline(always)] - pub fn neg(self, a: SimdGroupFor) -> SimdGroupFor { - E::faer_simd_neg(self.simd, a) - } - /// Returns `lhs * rhs`. - #[inline(always)] - pub fn scale_real( - self, - lhs: SimdGroupFor, - rhs: SimdGroupFor, - ) -> SimdGroupFor { - E::faer_simd_scale_real(self.simd, lhs, rhs) - } - /// Returns `lhs * rhs`. - #[inline(always)] - pub fn mul(self, lhs: SimdGroupFor, rhs: SimdGroupFor) -> SimdGroupFor { - E::faer_simd_mul(self.simd, lhs, rhs) - } - /// Returns `conj(lhs) * rhs`. - #[inline(always)] - pub fn conj_mul( - self, - lhs: SimdGroupFor, - rhs: SimdGroupFor, - ) -> SimdGroupFor { - E::faer_simd_conj_mul(self.simd, lhs, rhs) - } - /// Returns `op(lhs) * rhs`, where `op` is either the conjugation or the identity - /// operation. - #[inline(always)] - pub fn conditional_conj_mul( - self, - conj: C, - lhs: SimdGroupFor, - rhs: SimdGroupFor, - ) -> SimdGroupFor { - let _ = conj; - if C::CONJ == Conj::Yes { - self.conj_mul(lhs, rhs) - } else { - self.mul(lhs, rhs) - } - } - - /// Returns `lhs * rhs + acc`. - #[inline(always)] - pub fn mul_add_e( - self, - lhs: SimdGroupFor, - rhs: SimdGroupFor, - acc: SimdGroupFor, - ) -> SimdGroupFor { - E::faer_simd_mul_adde(self.simd, lhs, rhs, acc) - } - /// Returns `conj(lhs) * rhs + acc`. - #[inline(always)] - pub fn conj_mul_add_e( - self, - lhs: SimdGroupFor, - rhs: SimdGroupFor, - acc: SimdGroupFor, - ) -> SimdGroupFor { - E::faer_simd_conj_mul_adde(self.simd, lhs, rhs, acc) - } - /// Returns `op(lhs) * rhs + acc`, where `op` is either the conjugation or the identity - /// operation. - #[inline(always)] - pub fn conditional_conj_mul_add_e( - self, - conj: C, - lhs: SimdGroupFor, - rhs: SimdGroupFor, - acc: SimdGroupFor, - ) -> SimdGroupFor { - let _ = conj; - if C::CONJ == Conj::Yes { - self.conj_mul_add_e(lhs, rhs, acc) - } else { - self.mul_add_e(lhs, rhs, acc) - } - } - - /// Returns `abs(values) * abs(values) + acc`. - #[inline(always)] - pub fn abs2_add_e( - self, - values: SimdGroupFor, - acc: SimdGroupFor, - ) -> SimdGroupFor { - E::faer_simd_abs2_adde(self.simd, values, acc) - } - /// Returns `abs(values) * abs(values)`. - #[inline(always)] - pub fn abs2(self, values: SimdGroupFor) -> SimdGroupFor { - E::faer_simd_abs2(self.simd, values) - } - /// Returns `abs(values)` or `abs(values) * abs(values)`, whichever is cheaper to compute. - #[inline(always)] - pub fn score(self, values: SimdGroupFor) -> SimdGroupFor { - E::faer_simd_score(self.simd, values) - } - - /// Sum the components of a vector register into a single accumulator. - #[inline(always)] - pub fn reduce_add(self, values: SimdGroupFor) -> E { - E::faer_simd_reduce_add(self.simd, values) - } - - /// Rotate `values` to the left, with overflowing entries wrapping around to the right side - /// of the register. - #[inline(always)] - pub fn rotate_left(self, values: SimdGroupFor, amount: usize) -> SimdGroupFor { - E::faer_simd_rotate_left(self.simd, values, amount) - } - } - - impl SimdFor { - /// Returns `abs(values)`. - #[inline(always)] - pub fn abs(self, values: SimdGroupFor) -> SimdGroupFor { - E::faer_simd_abs(self.simd, values) - } - /// Returns `a < b`. - #[inline(always)] - pub fn less_than(self, a: SimdGroupFor, b: SimdGroupFor) -> SimdMaskFor { - E::faer_simd_less_than(self.simd, a, b) - } - /// Returns `a <= b`. - #[inline(always)] - pub fn less_than_or_equal( - self, - a: SimdGroupFor, - b: SimdGroupFor, - ) -> SimdMaskFor { - E::faer_simd_less_than_or_equal(self.simd, a, b) - } - /// Returns `a > b`. - #[inline(always)] - pub fn greater_than( - self, - a: SimdGroupFor, - b: SimdGroupFor, - ) -> SimdMaskFor { - E::faer_simd_greater_than(self.simd, a, b) - } - /// Returns `a >= b`. - #[inline(always)] - pub fn greater_than_or_equal( - self, - a: SimdGroupFor, - b: SimdGroupFor, - ) -> SimdMaskFor { - E::faer_simd_greater_than_or_equal(self.simd, a, b) - } - - /// Returns `if mask { if_true } else { if_false }` - #[inline(always)] - pub fn select( - self, - mask: SimdMaskFor, - if_true: SimdGroupFor, - if_false: SimdGroupFor, - ) -> SimdGroupFor { - E::faer_simd_select(self.simd, mask, if_true, if_false) - } - /// Returns `if mask { if_true } else { if_false }` - #[inline(always)] - pub fn index_select( - self, - mask: SimdMaskFor, - if_true: SimdIndexFor, - if_false: SimdIndexFor, - ) -> SimdIndexFor { - E::faer_simd_index_select(self.simd, mask, if_true, if_false) - } - /// Returns `[0, 1, 2, 3, ..., REGISTER_SIZE - 1]` - #[inline(always)] - pub fn index_seq(self) -> SimdIndexFor { - E::faer_simd_index_seq(self.simd) - } - /// Fill all the register lanes with the same value. - #[inline(always)] - pub fn index_splat(self, value: IndexFor) -> SimdIndexFor { - E::faer_simd_index_splat(self.simd, value) - } - /// Returns `a + b`. - #[inline(always)] - pub fn index_add(self, a: SimdIndexFor, b: SimdIndexFor) -> SimdIndexFor { - E::faer_simd_index_add(self.simd, a, b) - } - } - - /// Analogous to an immutable reference to a [prim@slice] for groups. - pub struct SliceGroup<'a, E: Entity, T: 'a = ::Unit>( - GroupCopyFor, - PhantomData<&'a ()>, - ); - /// Analogous to a mutable reference to a [prim@slice] for groups. - pub struct SliceGroupMut<'a, E: Entity, T: 'a = ::Unit>( - GroupFor, - PhantomData<&'a mut ()>, - ); - - /// Simd prefix, contains the elements before the body. - pub struct Prefix<'a, E: Entity, S: pulp::Simd>( - GroupCopyFor>, - PhantomData<&'a ()>, - ); - /// Simd suffix, contains the elements after the body. - pub struct Suffix<'a, E: Entity, S: pulp::Simd>( - GroupCopyFor>, - PhantomData<&'a mut ()>, - ); - /// Simd prefix (mutable), contains the elements before the body. - pub struct PrefixMut<'a, E: Entity, S: pulp::Simd>( - GroupFor>, - PhantomData<&'a ()>, - ); - /// Simd suffix (mutable), contains the elements after the body. - pub struct SuffixMut<'a, E: Entity, S: pulp::Simd>( - GroupFor>, - PhantomData<&'a mut ()>, - ); - - impl Read for RefGroupMut<'_, E, T> { - type Output = GroupCopyFor; - #[inline(always)] - fn read_or(&self, _or: Self::Output) -> Self::Output { - self.get() - } - } - impl Write for RefGroupMut<'_, E, T> { - #[inline(always)] - fn write(&mut self, values: Self::Output) { - self.set(values) - } - } - impl Read for RefGroup<'_, E, T> { - type Output = GroupCopyFor; - #[inline(always)] - fn read_or(&self, _or: Self::Output) -> Self::Output { - self.get() - } - } - - impl Read for Prefix<'_, E, S> { - type Output = SimdGroupFor; - #[inline(always)] - fn read_or(&self, or: Self::Output) -> Self::Output { - into_copy::(E::faer_map( - E::faer_zip(from_copy::(self.0), from_copy::(or)), - #[inline(always)] - |(prefix, or)| prefix.read_or(or), - )) - } - } - impl Read for PrefixMut<'_, E, S> { - type Output = SimdGroupFor; - #[inline(always)] - fn read_or(&self, or: Self::Output) -> Self::Output { - self.rb().read_or(or) - } - } - impl Write for PrefixMut<'_, E, S> { - #[inline(always)] - fn write(&mut self, values: Self::Output) { - E::faer_map( - E::faer_zip(self.rb_mut().0, from_copy::(values)), - #[inline(always)] - |(mut prefix, values)| prefix.write(values), - ); - } - } - - impl Read for Suffix<'_, E, S> { - type Output = SimdGroupFor; - #[inline(always)] - fn read_or(&self, or: Self::Output) -> Self::Output { - into_copy::(E::faer_map( - E::faer_zip(from_copy::(self.0), from_copy::(or)), - #[inline(always)] - |(suffix, or)| suffix.read_or(or), - )) - } - } - impl Read for SuffixMut<'_, E, S> { - type Output = SimdGroupFor; - #[inline(always)] - fn read_or(&self, or: Self::Output) -> Self::Output { - self.rb().read_or(or) - } - } - impl Write for SuffixMut<'_, E, S> { - #[inline(always)] - fn write(&mut self, values: Self::Output) { - E::faer_map( - E::faer_zip(self.rb_mut().0, from_copy::(values)), - #[inline(always)] - |(mut suffix, values)| suffix.write(values), - ); - } - } - - impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for PrefixMut<'_, E, S> { - type Target = Prefix<'short, E, S>; - #[inline] - fn rb(&'short self) -> Self::Target { - unsafe { - Prefix( - into_copy::(transmute_unchecked::< - GroupFor as Reborrow<'_>>::Target>, - GroupFor>, - >(E::faer_map( - E::faer_as_ref(&self.0), - |x| (*x).rb(), - ))), - PhantomData, - ) - } - } - } - impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for PrefixMut<'_, E, S> { - type Target = PrefixMut<'short, E, S>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - unsafe { - PrefixMut( - transmute_unchecked::< - GroupFor as ReborrowMut<'_>>::Target>, - GroupFor>, - >(E::faer_map(E::faer_as_mut(&mut self.0), |x| { - (*x).rb_mut() - })), - PhantomData, - ) - } - } - } - impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for SuffixMut<'_, E, S> { - type Target = Suffix<'short, E, S>; - #[inline] - fn rb(&'short self) -> Self::Target { - unsafe { - Suffix( - into_copy::(transmute_unchecked::< - GroupFor as Reborrow<'_>>::Target>, - GroupFor>, - >(E::faer_map( - E::faer_as_ref(&self.0), - |x| (*x).rb(), - ))), - PhantomData, - ) - } - } - } - impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for SuffixMut<'_, E, S> { - type Target = SuffixMut<'short, E, S>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - unsafe { - SuffixMut( - transmute_unchecked::< - GroupFor as ReborrowMut<'_>>::Target>, - GroupFor>, - >(E::faer_map(E::faer_as_mut(&mut self.0), |x| { - (*x).rb_mut() - })), - PhantomData, - ) - } - } - } - - impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for Prefix<'_, E, S> { - type Target = Prefix<'short, E, S>; - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } - } - impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for Prefix<'_, E, S> { - type Target = Prefix<'short, E, S>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for Suffix<'_, E, S> { - type Target = Suffix<'short, E, S>; - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } - } - impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for Suffix<'_, E, S> { - type Target = Suffix<'short, E, S>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl Copy for Prefix<'_, E, S> {} - impl Clone for Prefix<'_, E, S> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - impl Copy for Suffix<'_, E, S> {} - impl Clone for Suffix<'_, E, S> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - - /// Wrapper around a group of references. - pub struct RefGroup<'a, E: Entity, T: 'a = ::Unit>( - GroupCopyFor, - PhantomData<&'a ()>, - ); - /// Wrapper around a group of mutable references. - pub struct RefGroupMut<'a, E: Entity, T: 'a = ::Unit>( - GroupFor, - PhantomData<&'a mut ()>, - ); - - unsafe impl Send for SliceGroup<'_, E, T> {} - unsafe impl Sync for SliceGroup<'_, E, T> {} - unsafe impl Send for SliceGroupMut<'_, E, T> {} - unsafe impl Sync for SliceGroupMut<'_, E, T> {} - - impl Copy for SliceGroup<'_, E, T> {} - impl Copy for RefGroup<'_, E, T> {} - impl Clone for SliceGroup<'_, E, T> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - impl Clone for RefGroup<'_, E, T> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - - impl<'a, E: Entity, T> RefGroup<'a, E, T> { - /// Create a new [`RefGroup`] from a group of references. - #[inline(always)] - pub fn new(reference: GroupFor) -> Self { - Self( - into_copy::(E::faer_map( - reference, - #[inline(always)] - |reference| reference as *const T, - )), - PhantomData, - ) - } - - /// Consume `self` to return the internally stored group of references. - #[inline(always)] - pub fn into_inner(self) -> GroupFor { - E::faer_map( - from_copy::(self.0), - #[inline(always)] - |ptr| unsafe { &*ptr }, - ) - } - - /// Copies and returns the value pointed to by the references. - #[inline(always)] - pub fn get(self) -> GroupCopyFor - where - T: Copy, - { - into_copy::(E::faer_deref(self.into_inner())) - } - } - - impl<'a, E: Entity, T, const N: usize> RefGroup<'a, E, [T; N]> { - /// Convert a reference to an array to an array of references. - #[inline(always)] - pub fn unzip(self) -> [RefGroup<'a, E, T>; N] { - unsafe { - let mut out = transmute_unchecked::< - core::mem::MaybeUninit<[RefGroup<'a, E, T>; N]>, - [core::mem::MaybeUninit>; N], - >( - core::mem::MaybeUninit::<[RefGroup<'a, E, T>; N]>::uninit() - ); - for (out, inp) in - core::iter::zip(out.iter_mut(), E::faer_into_iter(self.into_inner())) - { - out.write(RefGroup::new(inp)); - } - transmute_unchecked::< - [core::mem::MaybeUninit>; N], - [RefGroup<'a, E, T>; N], - >(out) - } - } - } - - impl<'a, E: Entity, T, const N: usize> RefGroupMut<'a, E, [T; N]> { - /// Convert a mutable reference to an array to an array of mutable references. - #[inline(always)] - pub fn unzip(self) -> [RefGroupMut<'a, E, T>; N] { - unsafe { - let mut out = transmute_unchecked::< - core::mem::MaybeUninit<[RefGroupMut<'a, E, T>; N]>, - [core::mem::MaybeUninit>; N], - >( - core::mem::MaybeUninit::<[RefGroupMut<'a, E, T>; N]>::uninit() - ); - for (out, inp) in - core::iter::zip(out.iter_mut(), E::faer_into_iter(self.into_inner())) - { - out.write(RefGroupMut::new(inp)); - } - transmute_unchecked::< - [core::mem::MaybeUninit>; N], - [RefGroupMut<'a, E, T>; N], - >(out) - } - } - } - - impl<'a, E: Entity, T> RefGroupMut<'a, E, T> { - /// Create a new [`RefGroupMut`] from a group of mutable references. - #[inline(always)] - pub fn new(reference: GroupFor) -> Self { - Self( - E::faer_map( - reference, - #[inline(always)] - |reference| reference as *mut T, - ), - PhantomData, - ) - } - - /// Consume `self` to return the internally stored group of references. - #[inline(always)] - pub fn into_inner(self) -> GroupFor { - E::faer_map( - self.0, - #[inline(always)] - |ptr| unsafe { &mut *ptr }, - ) - } - - /// Copies and returns the value pointed to by the references. - #[inline(always)] - pub fn get(&self) -> GroupCopyFor - where - T: Copy, - { - self.rb().get() - } - - /// Writes `value` to the location pointed to by the references. - #[inline(always)] - pub fn set(&mut self, value: GroupCopyFor) - where - T: Copy, - { - E::faer_map( - E::faer_zip(self.rb_mut().into_inner(), from_copy::(value)), - #[inline(always)] - |(r, value)| *r = value, - ); - } - } - - impl<'a, E: Entity, T> IntoConst for SliceGroup<'a, E, T> { - type Target = SliceGroup<'a, E, T>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - self - } - } - impl<'a, E: Entity, T> IntoConst for SliceGroupMut<'a, E, T> { - type Target = SliceGroup<'a, E, T>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - SliceGroup::new(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| &*slice, - )) - } - } - - impl<'a, E: Entity, T> IntoConst for RefGroup<'a, E, T> { - type Target = RefGroup<'a, E, T>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - self - } - } - impl<'a, E: Entity, T> IntoConst for RefGroupMut<'a, E, T> { - type Target = RefGroup<'a, E, T>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - RefGroup::new(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| &*slice, - )) - } - } - - impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for RefGroup<'a, E, T> { - type Target = RefGroup<'short, E, T>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl<'short, 'a, E: Entity, T> Reborrow<'short> for RefGroup<'a, E, T> { - type Target = RefGroup<'short, E, T>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for RefGroupMut<'a, E, T> { - type Target = RefGroupMut<'short, E, T>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - RefGroupMut::new(E::faer_map( - E::faer_as_mut(&mut self.0), - #[inline(always)] - |this| unsafe { &mut **this }, - )) - } - } - - impl<'short, 'a, E: Entity, T> Reborrow<'short> for RefGroupMut<'a, E, T> { - type Target = RefGroup<'short, E, T>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - RefGroup::new(E::faer_map( - E::faer_as_ref(&self.0), - #[inline(always)] - |this| unsafe { &**this }, - )) - } - } - - impl<'a, E: Entity, T> SliceGroup<'a, E, T> { - /// Create a new [`SliceGroup`] from a group of slice references. - #[inline(always)] - pub fn new(slice: GroupFor) -> Self { - Self( - into_copy::(E::faer_map(slice, |slice| slice as *const [T])), - PhantomData, - ) - } - - /// Consume `self` to return the internally stored group of slice references. - #[inline(always)] - pub fn into_inner(self) -> GroupFor { - unsafe { E::faer_map(from_copy::(self.0), |ptr| &*ptr) } - } - - /// Decompose `self` into a slice of arrays of size `N`, and a remainder part with length - /// `< N`. - #[inline(always)] - pub fn as_arrays( - self, - ) -> (SliceGroup<'a, E, [T; N]>, SliceGroup<'a, E, T>) { - let (head, tail) = E::faer_as_arrays::(self.into_inner()); - (SliceGroup::new(head), SliceGroup::new(tail)) - } - } - - impl<'a, E: Entity, T> SliceGroupMut<'a, E, T> { - /// Create a new [`SliceGroup`] from a group of mutable slice references. - #[inline(always)] - pub fn new(slice: GroupFor) -> Self { - Self(E::faer_map(slice, |slice| slice as *mut [T]), PhantomData) - } - - /// Consume `self` to return the internally stored group of mutable slice references. - #[inline(always)] - pub fn into_inner(self) -> GroupFor { - unsafe { E::faer_map(self.0, |ptr| &mut *ptr) } - } - - /// Decompose `self` into a mutable slice of arrays of size `N`, and a remainder part with - /// length `< N`. - #[inline(always)] - pub fn as_arrays_mut( - self, - ) -> (SliceGroupMut<'a, E, [T; N]>, SliceGroupMut<'a, E, T>) { - let (head, tail) = E::faer_as_arrays_mut::(self.into_inner()); - (SliceGroupMut::new(head), SliceGroupMut::new(tail)) - } - } - - impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for SliceGroup<'a, E, T> { - type Target = SliceGroup<'short, E, T>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl<'short, 'a, E: Entity, T> Reborrow<'short> for SliceGroup<'a, E, T> { - type Target = SliceGroup<'short, E, T>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for SliceGroupMut<'a, E, T> { - type Target = SliceGroupMut<'short, E, T>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - SliceGroupMut::new(E::faer_map( - E::faer_as_mut(&mut self.0), - #[inline(always)] - |this| unsafe { &mut **this }, - )) - } - } - - impl<'short, 'a, E: Entity, T> Reborrow<'short> for SliceGroupMut<'a, E, T> { - type Target = SliceGroup<'short, E, T>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - SliceGroup::new(E::faer_map( - E::faer_as_ref(&self.0), - #[inline(always)] - |this| unsafe { &**this }, - )) - } - } - - impl<'a, E: Entity> RefGroup<'a, E> { - /// Read the element pointed to by the references. - #[inline(always)] - pub fn read(&self) -> E { - E::faer_from_units(E::faer_deref(self.into_inner())) - } - } - - impl<'a, E: Entity> RefGroupMut<'a, E> { - /// Read the element pointed to by the references. - #[inline(always)] - pub fn read(&self) -> E { - self.rb().read() - } - - /// Write `value` to the location pointed to by the references. - #[inline(always)] - pub fn write(&mut self, value: E) { - E::faer_map( - E::faer_zip(self.rb_mut().into_inner(), value.faer_into_units()), - #[inline(always)] - |(r, value)| *r = value, - ); - } - } - - impl<'a, E: Entity> SliceGroup<'a, E> { - /// Read the element at position `idx`. - #[inline(always)] - #[track_caller] - pub fn read(&self, idx: usize) -> E { - assert!(idx < self.len()); - unsafe { self.read_unchecked(idx) } - } - - /// Read the element at position `idx`, without bound checks. - /// - /// # Safety - /// The behavior is undefined if `idx >= self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, idx: usize) -> E { - debug_assert!(idx < self.len()); - E::faer_from_units(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| *slice.get_unchecked(idx), - )) - } - } - impl<'a, E: Entity, T> SliceGroup<'a, E, T> { - /// Get a [`RefGroup`] pointing to the element at position `idx`. - #[inline(always)] - #[track_caller] - pub fn get(self, idx: usize) -> RefGroup<'a, E, T> { - assert!(idx < self.len()); - unsafe { self.get_unchecked(idx) } - } - - /// Get a [`RefGroup`] pointing to the element at position `idx`, without bound checks. - /// - /// # Safety - /// The behavior is undefined if `idx >= self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_unchecked(self, idx: usize) -> RefGroup<'a, E, T> { - debug_assert!(idx < self.len()); - RefGroup::new(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.get_unchecked(idx), - )) - } - - /// Checks whether the slice is empty. - #[inline] - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns the length of the slice. - #[inline] - pub fn len(&self) -> usize { - let mut len = usize::MAX; - E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| len = Ord::min(len, slice.len()), - ); - len - } - - /// Returns the subslice of `self` from the start to the end of the provided range. - #[inline(always)] - #[track_caller] - pub fn subslice(self, range: Range) -> Self { - assert!(all(range.start <= range.end, range.end <= self.len())); - unsafe { self.subslice_unchecked(range) } - } - - /// Split `self` at the midpoint `idx`, and return the two parts. - #[inline(always)] - #[track_caller] - pub fn split_at(self, idx: usize) -> (Self, Self) { - assert!(idx <= self.len()); - let (head, tail) = E::faer_unzip(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.split_at(idx), - )); - (Self::new(head), Self::new(tail)) - } - - /// Returns the subslice of `self` from the start to the end of the provided range, without - /// bound checks. - /// - /// # Safety - /// The behavior is undefined if `range.start > range.end` or `range.end > self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn subslice_unchecked(self, range: Range) -> Self { - debug_assert!(all(range.start <= range.end, range.end <= self.len())); - Self::new(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.get_unchecked(range.start..range.end), - )) - } - - /// Returns an iterator of [`RefGroup`] over the elements of the slice. - #[inline(always)] - pub fn into_ref_iter(self) -> impl Iterator> { - E::faer_into_iter(self.into_inner()).map(RefGroup::new) - } - - /// Returns an iterator of slices over chunks of size `chunk_size`, and the remainder of - /// the slice. - #[inline(always)] - pub fn into_chunks_exact( - self, - chunk_size: usize, - ) -> (impl Iterator>, Self) { - let len = self.len(); - let mid = len / chunk_size * chunk_size; - let (head, tail) = E::faer_unzip(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.split_at(mid), - )); - let head = E::faer_map( - head, - #[inline(always)] - |head| head.chunks_exact(chunk_size), - ); - ( - E::faer_into_iter(head).map(SliceGroup::new), - SliceGroup::new(tail), - ) - } - } - - impl<'a, E: Entity> SliceGroupMut<'a, E> { - /// Read the element at position `idx`. - #[inline(always)] - #[track_caller] - pub fn read(&self, idx: usize) -> E { - self.rb().read(idx) - } - - /// Read the element at position `idx`, without bound checks. - /// - /// # Safety - /// The behavior is undefined if `idx >= self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, idx: usize) -> E { - self.rb().read_unchecked(idx) - } - - /// Write `value` to the location at position `idx`. - #[inline(always)] - #[track_caller] - pub fn write(&mut self, idx: usize, value: E) { - assert!(idx < self.len()); - unsafe { self.write_unchecked(idx, value) } - } - - /// Write `value` to the location at position `idx`, without bound checks. - /// - /// # Safety - /// The behavior is undefined if `idx >= self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn write_unchecked(&mut self, idx: usize, value: E) { - debug_assert!(idx < self.len()); - E::faer_map( - E::faer_zip(self.rb_mut().into_inner(), value.faer_into_units()), - #[inline(always)] - |(slice, value)| *slice.get_unchecked_mut(idx) = value, - ); - } - - /// Fill the slice with zeros. - #[inline] - pub fn fill_zero(&mut self) { - E::faer_map(self.rb_mut().into_inner(), |slice| unsafe { - let len = slice.len(); - core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len); - }); - } - } - - impl<'a, E: Entity, T> SliceGroupMut<'a, E, T> { - /// Get a [`RefGroupMut`] pointing to the element at position `idx`. - #[inline(always)] - #[track_caller] - pub fn get_mut(self, idx: usize) -> RefGroupMut<'a, E, T> { - assert!(idx < self.len()); - unsafe { self.get_unchecked_mut(idx) } - } - - /// Get a [`RefGroupMut`] pointing to the element at position `idx`. - /// - /// # Safety - /// The behavior is undefined if `idx >= self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_unchecked_mut(self, idx: usize) -> RefGroupMut<'a, E, T> { - debug_assert!(idx < self.len()); - RefGroupMut::new(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.get_unchecked_mut(idx), - )) - } - - /// Get a [`RefGroup`] pointing to the element at position `idx`. - #[inline(always)] - #[track_caller] - pub fn get(self, idx: usize) -> RefGroup<'a, E, T> { - self.into_const().get(idx) - } - - /// Get a [`RefGroup`] pointing to the element at position `idx`, without bound checks. - /// - /// # Safety - /// The behavior is undefined if `idx >= self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_unchecked(self, idx: usize) -> RefGroup<'a, E, T> { - self.into_const().get_unchecked(idx) - } - - /// Checks whether the slice is empty. - #[inline] - pub fn is_empty(&self) -> bool { - self.rb().is_empty() - } - - /// Returns the length of the slice. - #[inline] - pub fn len(&self) -> usize { - self.rb().len() - } - - /// Returns the subslice of `self` from the start to the end of the provided range. - #[inline(always)] - #[track_caller] - pub fn subslice(self, range: Range) -> Self { - assert!(all(range.start <= range.end, range.end <= self.len())); - unsafe { self.subslice_unchecked(range) } - } - - /// Returns the subslice of `self` from the start to the end of the provided range, without - /// bound checks. - /// - /// # Safety - /// The behavior is undefined if `range.start > range.end` or `range.end > self.len()`. - #[inline(always)] - #[track_caller] - pub unsafe fn subslice_unchecked(self, range: Range) -> Self { - debug_assert!(all(range.start <= range.end, range.end <= self.len())); - Self::new(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.get_unchecked_mut(range.start..range.end), - )) - } - - /// Returns an iterator of [`RefGroupMut`] over the elements of the slice. - #[inline(always)] - pub fn into_mut_iter(self) -> impl Iterator> { - E::faer_into_iter(self.into_inner()).map(RefGroupMut::new) - } - - /// Split `self` at the midpoint `idx`, and return the two parts. - #[inline(always)] - #[track_caller] - pub fn split_at(self, idx: usize) -> (Self, Self) { - assert!(idx <= self.len()); - let (head, tail) = E::faer_unzip(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.split_at_mut(idx), - )); - (Self::new(head), Self::new(tail)) - } - - /// Returns an iterator of slices over chunks of size `chunk_size`, and the remainder of - /// the slice. - #[inline(always)] - pub fn into_chunks_exact( - self, - chunk_size: usize, - ) -> (impl Iterator>, Self) { - let len = self.len(); - let mid = len % chunk_size * chunk_size; - let (head, tail) = E::faer_unzip(E::faer_map( - self.into_inner(), - #[inline(always)] - |slice| slice.split_at_mut(mid), - )); - let head = E::faer_map( - head, - #[inline(always)] - |head| head.chunks_exact_mut(chunk_size), - ); - ( - E::faer_into_iter(head).map(SliceGroupMut::new), - SliceGroupMut::new(tail), - ) - } - } - - impl core::fmt::Debug for Prefix<'_, E, S> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - unsafe { - transmute_unchecked::, GroupDebugFor>>( - self.read_or(core::mem::zeroed()), - ) - .fmt(f) - } - } - } - impl core::fmt::Debug for PrefixMut<'_, E, S> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } - } - impl core::fmt::Debug for Suffix<'_, E, S> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - unsafe { - transmute_unchecked::, GroupDebugFor>>( - self.read_or(core::mem::zeroed()), - ) - .fmt(f) - } - } - } - impl core::fmt::Debug for SuffixMut<'_, E, S> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } - } - impl core::fmt::Debug for RefGroup<'_, E, T> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - unsafe { - transmute_unchecked::, GroupDebugFor>(self.into_inner()) - .fmt(f) - } - } - } - impl core::fmt::Debug for RefGroupMut<'_, E, T> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } - } - impl core::fmt::Debug for SliceGroup<'_, E, T> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_list().entries(self.into_ref_iter()).finish() - } - } - impl core::fmt::Debug for SliceGroupMut<'_, E, T> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } - } -} - -/// Immutable view over a column vector, similar to an immutable reference to a strided -/// [prim@slice]. -/// -/// # Note -/// -/// Unlike a slice, the data pointed to by `ColRef<'_, E>` is allowed to be partially or fully -/// uninitialized under certain conditions. In this case, care must be taken to not perform any -/// operations that read the uninitialized values, or form references to them, either directly -/// through [`ColRef::read`], or indirectly through any of the numerical library routines, unless -/// it is explicitly permitted. -pub type ColRef<'a, E> = Matrix>; - -/// Immutable view over a row vector, similar to an immutable reference to a strided [prim@slice]. -/// -/// # Note -/// -/// Unlike a slice, the data pointed to by `RowRef<'_, E>` is allowed to be partially or fully -/// uninitialized under certain conditions. In this case, care must be taken to not perform any -/// operations that read the uninitialized values, or form references to them, either directly -/// through [`RowRef::read`], or indirectly through any of the numerical library routines, unless -/// it is explicitly permitted. -pub type RowRef<'a, E> = Matrix>; - -/// Immutable view over a matrix, similar to an immutable reference to a 2D strided [prim@slice]. -/// -/// # Note -/// -/// Unlike a slice, the data pointed to by `MatRef<'_, E>` is allowed to be partially or fully -/// uninitialized under certain conditions. In this case, care must be taken to not perform any -/// operations that read the uninitialized values, or form references to them, either directly -/// through [`MatRef::read`], or indirectly through any of the numerical library routines, unless -/// it is explicitly permitted. -pub type MatRef<'a, E> = Matrix>; - -/// Mutable view over a column vector, similar to a mutable reference to a strided [prim@slice]. -/// -/// # Note -/// -/// Unlike a slice, the data pointed to by `ColMut<'_, E>` is allowed to be partially or fully -/// uninitialized under certain conditions. In this case, care must be taken to not perform any -/// operations that read the uninitialized values, or form references to them, either directly -/// through [`ColMut::read`], or indirectly through any of the numerical library routines, unless -/// it is explicitly permitted. -pub type ColMut<'a, E> = Matrix>; - -/// Mutable view over a row vector, similar to a mutable reference to a strided [prim@slice]. -/// -/// # Note -/// -/// Unlike a slice, the data pointed to by `RowMut<'_, E>` is allowed to be partially or fully -/// uninitialized under certain conditions. In this case, care must be taken to not perform any -/// operations that read the uninitialized values, or form references to them, either directly -/// through [`RowMut::read`], or indirectly through any of the numerical library routines, unless -/// it is explicitly permitted. -pub type RowMut<'a, E> = Matrix>; - -/// Mutable view over a matrix, similar to a mutable reference to a 2D strided [prim@slice]. -/// -/// # Note -/// -/// Unlike a slice, the data pointed to by `MatMut<'_, E>` is allowed to be partially or fully -/// uninitialized under certain conditions. In this case, care must be taken to not perform any -/// operations that read the uninitialized values, or form references to them, either directly -/// through [`MatMut::read`], or indirectly through any of the numerical library routines, unless -/// it is explicitly permitted. -/// -/// # Move semantics -/// Since `MatMut` mutably borrows data, it cannot be [`Copy`]. This means that if we pass a -/// `MatMut` to a function that takes it by value, or use a method that consumes `self` like -/// [`MatMut::transpose`], this renders the original variable unusable. -/// ```compile_fail -/// use faer_core::{Mat, MatMut}; -/// -/// fn takes_matmut(view: MatMut<'_, f64>) {} -/// -/// let mut matrix = Mat::new(); -/// let view = matrix.as_mut(); -/// -/// takes_matmut(view); // `view` is moved (passed by value) -/// takes_matmut(view); // this fails to compile since `view` was moved -/// ``` -/// The way to get around it is to use the [`reborrow::ReborrowMut`] trait, which allows us to -/// mutably borrow a `MatMut` to obtain another `MatMut` for the lifetime of the borrow. -/// It's also similarly possible to immutably borrow a `MatMut` to obtain a `MatRef` for the -/// lifetime of the borrow, using [`reborrow::Reborrow`]. -/// ``` -/// use faer_core::{Mat, MatMut, MatRef}; -/// use reborrow::*; -/// -/// fn takes_matmut(view: MatMut<'_, f64>) {} -/// fn takes_matref(view: MatRef<'_, f64>) {} -/// -/// let mut matrix = Mat::new(); -/// let mut view = matrix.as_mut(); -/// -/// takes_matmut(view.rb_mut()); -/// takes_matmut(view.rb_mut()); -/// takes_matref(view.rb()); -/// // view is still usable here -/// ``` -pub type MatMut<'a, E> = Matrix>; - -/// Wrapper around a scalar value that allows scalar multiplication by matrices. -pub type MatScale = Matrix>; - -impl MatScale { - /// Returns a new scaling factor with the given value. - #[inline(always)] - pub fn new(value: E) -> Self { - Self { - inner: inner::Scale(value), - } - } - - /// Returns the value of the scaling factor. - #[inline(always)] - pub fn value(self) -> E { - self.inner.0 - } -} - -// COL_REBORROW -const _: () = { - impl<'a, E: Entity> IntoConst for ColMut<'a, E> { - type Target = ColRef<'a, E>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - ColRef { - inner: inner::DenseColRef { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'short, 'a, E: Entity> Reborrow<'short> for ColMut<'a, E> { - type Target = ColRef<'short, E>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - ColRef { - inner: inner::DenseColRef { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'short, 'a, E: Entity> ReborrowMut<'short> for ColMut<'a, E> { - type Target = ColMut<'short, E>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - ColMut { - inner: inner::DenseColMut { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'a, E: Entity> IntoConst for ColRef<'a, E> { - type Target = ColRef<'a, E>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - self - } - } - - impl<'short, 'a, E: Entity> Reborrow<'short> for ColRef<'a, E> { - type Target = ColRef<'short, E>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'short, 'a, E: Entity> ReborrowMut<'short> for ColRef<'a, E> { - type Target = ColRef<'short, E>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } -}; - -// ROW REBORROW -const _: () = { - impl<'a, E: Entity> IntoConst for RowMut<'a, E> { - type Target = RowRef<'a, E>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - RowRef { - inner: inner::DenseRowRef { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'short, 'a, E: Entity> Reborrow<'short> for RowMut<'a, E> { - type Target = RowRef<'short, E>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - RowRef { - inner: inner::DenseRowRef { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'short, 'a, E: Entity> ReborrowMut<'short> for RowMut<'a, E> { - type Target = RowMut<'short, E>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - RowMut { - inner: inner::DenseRowMut { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'a, E: Entity> IntoConst for RowRef<'a, E> { - type Target = RowRef<'a, E>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - self - } - } - - impl<'short, 'a, E: Entity> Reborrow<'short> for RowRef<'a, E> { - type Target = RowRef<'short, E>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'short, 'a, E: Entity> ReborrowMut<'short> for RowRef<'a, E> { - type Target = RowRef<'short, E>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } -}; - -// MAT_REBORROW -const _: () = { - impl<'a, E: Entity> IntoConst for MatMut<'a, E> { - type Target = MatRef<'a, E>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - MatRef { - inner: inner::DenseRef { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'short, 'a, E: Entity> Reborrow<'short> for MatMut<'a, E> { - type Target = MatRef<'short, E>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - MatRef { - inner: inner::DenseRef { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'short, 'a, E: Entity> ReborrowMut<'short> for MatMut<'a, E> { - type Target = MatMut<'short, E>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - MatMut { - inner: inner::DenseMut { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl<'a, E: Entity> IntoConst for MatRef<'a, E> { - type Target = MatRef<'a, E>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - self - } - } - - impl<'short, 'a, E: Entity> Reborrow<'short> for MatRef<'a, E> { - type Target = MatRef<'short, E>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'short, 'a, E: Entity> ReborrowMut<'short> for MatRef<'a, E> { - type Target = MatRef<'short, E>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } -}; - -impl<'a, E: Entity> IntoConst for Matrix> { - type Target = Matrix>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - Matrix { - inner: inner::DiagRef { - inner: self.inner.inner.into_const(), - }, - } - } -} - -impl<'short, 'a, E: Entity> Reborrow<'short> for Matrix> { - type Target = Matrix>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - Matrix { - inner: inner::DiagRef { - inner: self.inner.inner.rb(), - }, - } - } -} - -impl<'short, 'a, E: Entity> ReborrowMut<'short> for Matrix> { - type Target = Matrix>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - Matrix { - inner: inner::DiagMut { - inner: self.inner.inner.rb_mut(), - }, - } - } -} - -impl<'a, E: Entity> IntoConst for Matrix> { - type Target = Matrix>; - - #[inline(always)] - fn into_const(self) -> Self::Target { - self - } -} - -impl<'short, 'a, E: Entity> Reborrow<'short> for Matrix> { - type Target = Matrix>; - - #[inline(always)] - fn rb(&'short self) -> Self::Target { - *self - } -} - -impl<'short, 'a, E: Entity> ReborrowMut<'short> for Matrix> { - type Target = Matrix>; - - #[inline(always)] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } -} - -unsafe impl Send for VecImpl {} -unsafe impl Sync for VecImpl {} -unsafe impl Send for VecOwnImpl {} -unsafe impl Sync for VecOwnImpl {} - -unsafe impl Send for MatImpl {} -unsafe impl Sync for MatImpl {} -unsafe impl Send for MatOwnImpl {} -unsafe impl Sync for MatOwnImpl {} - -#[doc(hidden)] -#[inline] -pub fn par_split_indices(n: usize, idx: usize, chunk_count: usize) -> (usize, usize) { - let chunk_size = n / chunk_count; - let rem = n % chunk_count; - - let idx_to_col_start = move |idx| { - if idx < rem { - idx * (chunk_size + 1) - } else { - rem + idx * chunk_size - } - }; - - let start = idx_to_col_start(idx); - let end = idx_to_col_start(idx + 1); - (start, end - start) -} - -mod seal { - pub trait Seal {} -} -impl<'a, E: Entity> seal::Seal for MatRef<'a, E> {} -impl<'a, E: Entity> seal::Seal for MatMut<'a, E> {} -impl<'a, E: Entity> seal::Seal for ColRef<'a, E> {} -impl<'a, E: Entity> seal::Seal for ColMut<'a, E> {} -impl<'a, E: Entity> seal::Seal for RowRef<'a, E> {} -impl<'a, E: Entity> seal::Seal for RowMut<'a, E> {} - -/// Represents a type that can be used to slice a row, such as an index or a range of indices. -pub trait RowIndex: seal::Seal + Sized { - /// Resulting type of the indexing operation. - type Target; - - /// Index the row at `col`, without bound checks. - #[allow(clippy::missing_safety_doc)] - unsafe fn get_unchecked(this: Self, col: ColRange) -> Self::Target { - >::get(this, col) - } - /// Index the row at `col`. - fn get(this: Self, col: ColRange) -> Self::Target; -} - -/// Represents a type that can be used to slice a column, such as an index or a range of indices. -pub trait ColIndex: seal::Seal + Sized { - /// Resulting type of the indexing operation. - type Target; - - /// Index the column at `row`, without bound checks. - #[allow(clippy::missing_safety_doc)] - unsafe fn get_unchecked(this: Self, row: RowRange) -> Self::Target { - >::get(this, row) - } - /// Index the column at `row`. - fn get(this: Self, row: RowRange) -> Self::Target; -} - -/// Represents a type that can be used to slice a matrix, such as an index or a range of indices. -pub trait MatIndex: seal::Seal + Sized { - /// Resulting type of the indexing operation. - type Target; - - /// Index the matrix at `(row, col)`, without bound checks. - #[allow(clippy::missing_safety_doc)] - unsafe fn get_unchecked(this: Self, row: RowRange, col: ColRange) -> Self::Target { - >::get(this, row, col) - } - /// Index the matrix at `(row, col)`. - fn get(this: Self, row: RowRange, col: ColRange) -> Self::Target; -} - -// MAT INDEX -const _: () = { - // RangeFull - // Range - // RangeInclusive - // RangeTo - // RangeToInclusive - // usize - - use core::ops::RangeFull; - type Range = core::ops::Range; - type RangeInclusive = core::ops::RangeInclusive; - type RangeFrom = core::ops::RangeFrom; - type RangeTo = core::ops::RangeTo; - type RangeToInclusive = core::ops::RangeToInclusive; - - impl MatIndex for MatRef<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeFrom, - ) -> >::Target { - let ncols = this.ncols(); - >::get(this, row, col.start..ncols) - } - } - impl MatIndex for MatRef<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeTo, - ) -> >::Target { - >::get(this, row, 0..col.end) - } - } - impl MatIndex for MatRef<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeToInclusive, - ) -> >::Target { - assert!(col.end != usize::MAX); - >::get(this, row, 0..col.end + 1) - } - } - impl MatIndex for MatRef<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeInclusive, - ) -> >::Target { - assert!(*col.end() != usize::MAX); - >::get(this, row, *col.start()..*col.end() + 1) - } - } - impl MatIndex for MatRef<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeFull, - ) -> >::Target { - let _ = col; - let ncols = this.ncols(); - >::get(this, row, 0..ncols) - } - } - - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFull, col: Range) -> Self { - let _ = row; - this.subcols(col.start, col.end - col.start) - } - } - impl<'a, E: Entity> MatIndex for MatRef<'a, E> { - type Target = ColRef<'a, E>; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFull, col: usize) -> Self::Target { - let _ = row; - this.col(col) - } - } - - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: Range, col: Range) -> Self { - this.submatrix( - row.start, - col.start, - row.end - row.start, - col.end - col.start, - ) - } - } - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: Range, col: usize) -> Self { - this.submatrix(row.start, col, row.end - row.start, 1) - } - } - - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeInclusive, col: Range) -> Self { - assert!(*row.end() != usize::MAX); - >::get(this, *row.start()..*row.end() + 1, col) - } - } - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeInclusive, col: usize) -> Self { - assert!(*row.end() != usize::MAX); - >::get(this, *row.start()..*row.end() + 1, col) - } - } - - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFrom, col: Range) -> Self { - let nrows = this.nrows(); - >::get(this, row.start..nrows, col) - } - } - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFrom, col: usize) -> Self { - let nrows = this.nrows(); - >::get(this, row.start..nrows, col) - } - } - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeTo, col: Range) -> Self { - >::get(this, 0..row.end, col) - } - } - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeTo, col: usize) -> Self { - >::get(this, 0..row.end, col) - } - } - - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeToInclusive, col: Range) -> Self { - assert!(row.end != usize::MAX); - >::get(this, 0..row.end + 1, col) - } - } - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeToInclusive, col: usize) -> Self { - assert!(row.end != usize::MAX); - >::get(this, 0..row.end + 1, col) - } - } - - impl MatIndex for MatRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: usize, col: Range) -> Self { - this.submatrix(row, col.start, 1, col.end - col.start) - } - } - - impl MatIndex for MatMut<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeFrom, - ) -> >::Target { - let ncols = this.ncols(); - >::get(this, row, col.start..ncols) - } - } - impl MatIndex for MatMut<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeTo, - ) -> >::Target { - >::get(this, row, 0..col.end) - } - } - impl MatIndex for MatMut<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeToInclusive, - ) -> >::Target { - assert!(col.end != usize::MAX); - >::get(this, row, 0..col.end + 1) - } - } - impl MatIndex for MatMut<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeInclusive, - ) -> >::Target { - assert!(*col.end() != usize::MAX); - >::get(this, row, *col.start()..*col.end() + 1) - } - } - impl MatIndex for MatMut<'_, E> - where - Self: MatIndex, - { - type Target = >::Target; - - #[track_caller] - #[inline(always)] - fn get( - this: Self, - row: RowRange, - col: RangeFull, - ) -> >::Target { - let _ = col; - let ncols = this.ncols(); - >::get(this, row, 0..ncols) - } - } - - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFull, col: Range) -> Self { - let _ = row; - this.subcols_mut(col.start, col.end - col.start) - } - } - impl<'a, E: Entity> MatIndex for MatMut<'a, E> { - type Target = ColMut<'a, E>; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFull, col: usize) -> Self::Target { - let _ = row; - this.col_mut(col) - } - } - - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: Range, col: Range) -> Self { - this.submatrix_mut( - row.start, - col.start, - row.end - row.start, - col.end - col.start, - ) - } - } - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: Range, col: usize) -> Self { - this.submatrix_mut(row.start, col, row.end - row.start, 1) - } - } - - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeInclusive, col: Range) -> Self { - assert!(*row.end() != usize::MAX); - >::get(this, *row.start()..*row.end() + 1, col) - } - } - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeInclusive, col: usize) -> Self { - assert!(*row.end() != usize::MAX); - >::get(this, *row.start()..*row.end() + 1, col) - } - } - - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFrom, col: Range) -> Self { - let nrows = this.nrows(); - >::get(this, row.start..nrows, col) - } - } - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFrom, col: usize) -> Self { - let nrows = this.nrows(); - >::get(this, row.start..nrows, col) - } - } - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeTo, col: Range) -> Self { - >::get(this, 0..row.end, col) - } - } - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeTo, col: usize) -> Self { - >::get(this, 0..row.end, col) - } - } - - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeToInclusive, col: Range) -> Self { - assert!(row.end != usize::MAX); - >::get(this, 0..row.end + 1, col) - } - } - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeToInclusive, col: usize) -> Self { - assert!(row.end != usize::MAX); - >::get(this, 0..row.end + 1, col) - } - } - - impl MatIndex for MatMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: usize, col: Range) -> Self { - this.submatrix_mut(row, col.start, 1, col.end - col.start) - } - } - - impl<'a, E: Entity> MatIndex for MatRef<'a, E> { - type Target = GroupFor; - - #[track_caller] - #[inline(always)] - unsafe fn get_unchecked(this: Self, row: usize, col: usize) -> Self::Target { - unsafe { E::faer_map(this.ptr_inbounds_at(row, col), |ptr| &*ptr) } - } - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: usize, col: usize) -> Self::Target { - assert!(all(row < this.nrows(), col < this.ncols())); - unsafe { >::get_unchecked(this, row, col) } - } - } - - impl<'a, E: Entity> MatIndex for MatMut<'a, E> { - type Target = GroupFor; - - #[track_caller] - #[inline(always)] - unsafe fn get_unchecked(this: Self, row: usize, col: usize) -> Self::Target { - unsafe { E::faer_map(this.ptr_inbounds_at_mut(row, col), |ptr| &mut *ptr) } - } - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: usize, col: usize) -> Self::Target { - assert!(all(row < this.nrows(), col < this.ncols())); - unsafe { >::get_unchecked(this, row, col) } - } - } -}; - -// COL INDEX -const _: () = { - // RangeFull - // Range - // RangeInclusive - // RangeTo - // RangeToInclusive - // usize - - use core::ops::RangeFull; - type Range = core::ops::Range; - type RangeInclusive = core::ops::RangeInclusive; - type RangeFrom = core::ops::RangeFrom; - type RangeTo = core::ops::RangeTo; - type RangeToInclusive = core::ops::RangeToInclusive; - - impl ColIndex for ColRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFull) -> Self { - let _ = row; - this - } - } - - impl ColIndex for ColRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: Range) -> Self { - this.subrows(row.start, row.end - row.start) - } - } - - impl ColIndex for ColRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeInclusive) -> Self { - assert!(*row.end() != usize::MAX); - >::get(this, *row.start()..*row.end() + 1) - } - } - - impl ColIndex for ColRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFrom) -> Self { - let nrows = this.nrows(); - >::get(this, row.start..nrows) - } - } - impl ColIndex for ColRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeTo) -> Self { - >::get(this, 0..row.end) - } - } - - impl ColIndex for ColRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeToInclusive) -> Self { - assert!(row.end != usize::MAX); - >::get(this, 0..row.end + 1) - } - } - - impl<'a, E: Entity> ColIndex for ColRef<'a, E> { - type Target = GroupFor; - - #[track_caller] - #[inline(always)] - unsafe fn get_unchecked(this: Self, row: usize) -> Self::Target { - unsafe { E::faer_map(this.ptr_inbounds_at(row), |ptr: *const _| &*ptr) } - } - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: usize) -> Self::Target { - assert!(row < this.nrows()); - unsafe { >::get_unchecked(this, row) } - } - } - - impl ColIndex for ColMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFull) -> Self { - let _ = row; - this - } - } - - impl ColIndex for ColMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: Range) -> Self { - this.subrows_mut(row.start, row.end - row.start) - } - } - - impl ColIndex for ColMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeInclusive) -> Self { - assert!(*row.end() != usize::MAX); - >::get(this, *row.start()..*row.end() + 1) - } - } - - impl ColIndex for ColMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeFrom) -> Self { - let nrows = this.nrows(); - >::get(this, row.start..nrows) - } - } - impl ColIndex for ColMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeTo) -> Self { - >::get(this, 0..row.end) - } - } - - impl ColIndex for ColMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: RangeToInclusive) -> Self { - assert!(row.end != usize::MAX); - >::get(this, 0..row.end + 1) - } - } - - impl<'a, E: Entity> ColIndex for ColMut<'a, E> { - type Target = GroupFor; - - #[track_caller] - #[inline(always)] - unsafe fn get_unchecked(this: Self, row: usize) -> Self::Target { - unsafe { E::faer_map(this.ptr_inbounds_at_mut(row), |ptr: *mut _| &mut *ptr) } - } - - #[track_caller] - #[inline(always)] - fn get(this: Self, row: usize) -> Self::Target { - assert!(row < this.nrows()); - unsafe { >::get_unchecked(this, row) } - } - } -}; - -// ROW INDEX -const _: () = { - // RangeFull - // Range - // RangeInclusive - // RangeTo - // RangeToInclusive - // usize - - use core::ops::RangeFull; - type Range = core::ops::Range; - type RangeInclusive = core::ops::RangeInclusive; - type RangeFrom = core::ops::RangeFrom; - type RangeTo = core::ops::RangeTo; - type RangeToInclusive = core::ops::RangeToInclusive; - - impl RowIndex for RowRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeFull) -> Self { - let _ = col; - this - } - } - - impl RowIndex for RowRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: Range) -> Self { - this.subcols(col.start, col.end - col.start) - } - } - - impl RowIndex for RowRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeInclusive) -> Self { - assert!(*col.end() != usize::MAX); - >::get(this, *col.start()..*col.end() + 1) - } - } - - impl RowIndex for RowRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeFrom) -> Self { - let ncols = this.ncols(); - >::get(this, col.start..ncols) - } - } - impl RowIndex for RowRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeTo) -> Self { - >::get(this, 0..col.end) - } - } - - impl RowIndex for RowRef<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeToInclusive) -> Self { - assert!(col.end != usize::MAX); - >::get(this, 0..col.end + 1) - } - } - - impl<'a, E: Entity> RowIndex for RowRef<'a, E> { - type Target = GroupFor; - - #[track_caller] - #[inline(always)] - unsafe fn get_unchecked(this: Self, col: usize) -> Self::Target { - unsafe { E::faer_map(this.ptr_inbounds_at(col), |ptr: *const _| &*ptr) } - } - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: usize) -> Self::Target { - assert!(col < this.ncols()); - unsafe { >::get_unchecked(this, col) } - } - } - - impl RowIndex for RowMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeFull) -> Self { - let _ = col; - this - } - } - - impl RowIndex for RowMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: Range) -> Self { - this.subcols_mut(col.start, col.end - col.start) - } - } - - impl RowIndex for RowMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeInclusive) -> Self { - assert!(*col.end() != usize::MAX); - >::get(this, *col.start()..*col.end() + 1) - } - } - - impl RowIndex for RowMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeFrom) -> Self { - let ncols = this.ncols(); - >::get(this, col.start..ncols) - } - } - - impl RowIndex for RowMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeTo) -> Self { - >::get(this, 0..col.end) - } - } - - impl RowIndex for RowMut<'_, E> { - type Target = Self; - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: RangeToInclusive) -> Self { - assert!(col.end != usize::MAX); - >::get(this, 0..col.end + 1) - } - } - - impl<'a, E: Entity> RowIndex for RowMut<'a, E> { - type Target = GroupFor; - - #[track_caller] - #[inline(always)] - unsafe fn get_unchecked(this: Self, col: usize) -> Self::Target { - unsafe { E::faer_map(this.ptr_inbounds_at_mut(col), |ptr: *mut _| &mut *ptr) } - } - - #[track_caller] - #[inline(always)] - fn get(this: Self, col: usize) -> Self::Target { - assert!(col < this.ncols()); - unsafe { >::get_unchecked(this, col) } - } - } -}; - -impl<'a, E: Entity> Matrix> { - /// Returns the diagonal as a column vector view. - #[inline(always)] - pub fn column_vector(self) -> ColRef<'a, E> { - self.inner.inner - } -} - -impl<'a, E: Entity> Matrix> { - /// Returns the diagonal as a mutable column vector view. - #[inline(always)] - pub fn column_vector_mut(self) -> ColMut<'a, E> { - self.inner.inner - } -} - -impl Matrix> { - /// Returns the diagonal as a column vector. - #[inline(always)] - pub fn into_column_vector(self) -> Col { - self.inner.inner - } - - /// Returns a view over `self`. - #[inline(always)] - pub fn as_ref(&self) -> Matrix> { - Matrix { - inner: inner::DiagRef { - inner: self.inner.inner.as_ref(), - }, - } - } - - /// Returns a mutable view over `self`. - #[inline(always)] - pub fn as_mut(&mut self) -> Matrix> { - Matrix { - inner: inner::DiagMut { - inner: self.inner.inner.as_mut(), - }, - } - } -} - -#[track_caller] -#[inline] -fn from_slice_assert(nrows: usize, ncols: usize, len: usize) { - // we don't have to worry about size == usize::MAX == slice.len(), because the length of a - // slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case - // we don't care - let size = usize::checked_mul(nrows, ncols).unwrap_or(usize::MAX); - assert!(size == len); -} - -#[track_caller] -#[inline] -fn from_strided_column_major_slice_assert( - nrows: usize, - ncols: usize, - col_stride: usize, - len: usize, -) { - // we don't have to worry about size == usize::MAX == slice.len(), because the length of a - // slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case - // we don't care - let last = usize::checked_mul(col_stride, ncols - 1) - .and_then(|last_col| last_col.checked_add(nrows - 1)) - .unwrap_or(usize::MAX); - assert!(last < len); -} - -#[track_caller] -#[inline] -fn from_strided_column_major_slice_mut_assert( - nrows: usize, - ncols: usize, - col_stride: usize, - len: usize, -) { - // we don't have to worry about size == usize::MAX == slice.len(), because the length of a - // slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case - // we don't care - let last = usize::checked_mul(col_stride, ncols - 1) - .and_then(|last_col| last_col.checked_add(nrows - 1)) - .unwrap_or(usize::MAX); - assert!(all(col_stride >= nrows, last < len)); -} - -#[inline(always)] -unsafe fn unchecked_mul(a: usize, b: isize) -> isize { - let (sum, overflow) = (a as isize).overflowing_mul(b); - if overflow { - core::hint::unreachable_unchecked(); - } - sum -} - -#[inline(always)] -unsafe fn unchecked_add(a: isize, b: isize) -> isize { - let (sum, overflow) = a.overflowing_add(b); - if overflow { - core::hint::unreachable_unchecked(); - } - sum -} - -// COL IMPL -const _: () = { - impl<'a, E: Entity> ColRef<'a, E> { - #[track_caller] - #[inline(always)] - #[doc(hidden)] - pub fn try_get_contiguous_col(self) -> GroupFor { - assert!(self.row_stride() == 1); - let m = self.nrows(); - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts(ptr, m) }, - ) - } - - /// Returns the number of rows of the column. - #[inline(always)] - pub fn nrows(&self) -> usize { - self.inner.inner.len - } - /// Returns the number of columns of the column. This is always equal to `1`. - #[inline(always)] - pub fn ncols(&self) -> usize { - 1 - } - - /// Returns pointers to the matrix data. - #[inline(always)] - pub fn as_ptr(self) -> GroupFor { - E::faer_map( - from_copy::(self.inner.inner.ptr), - #[inline(always)] - |ptr| ptr.as_ptr() as *const E::Unit, - ) - } - - /// Returns the row stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn row_stride(&self) -> isize { - self.inner.inner.stride - } - - /// Returns `self` as a matrix view. - #[inline(always)] - pub fn as_2d(self) -> MatRef<'a, E> { - let nrows = self.nrows(); - let row_stride = self.row_stride(); - unsafe { mat::from_raw_parts(self.as_ptr(), nrows, 1, row_stride, 0) } - } - - /// Returns raw pointers to the element at the given index. - #[inline(always)] - pub fn ptr_at(self, row: usize) -> GroupFor { - let offset = (row as isize).wrapping_mul(self.inner.inner.stride); - - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.wrapping_offset(offset), - ) - } - - #[inline(always)] - unsafe fn unchecked_ptr_at(self, row: usize) -> GroupFor { - let offset = unchecked_mul(row, self.inner.inner.stride); - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - - #[inline(always)] - unsafe fn overflowing_ptr_at(self, row: usize) -> GroupFor { - unsafe { - let cond = row != self.nrows(); - let offset = (cond as usize).wrapping_neg() as isize - & (row as isize).wrapping_mul(self.inner.inner.stride); - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - } - - /// Returns raw pointers to the element at the given index, assuming the provided index - /// is within the size of the vector. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn ptr_inbounds_at(self, row: usize) -> GroupFor { - debug_assert!(row < self.nrows()); - self.unchecked_ptr_at(row) - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * top. - /// * bottom. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_unchecked(self, row: usize) -> (Self, Self) { - debug_assert!(row <= self.nrows()); - - let row_stride = self.row_stride(); - - let nrows = self.nrows(); - - unsafe { - let top = self.as_ptr(); - let bot = self.overflowing_ptr_at(row); - - ( - col::from_raw_parts(top, row, row_stride), - col::from_raw_parts(bot, nrows - row, row_stride), - ) - } - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * top. - /// * bottom. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at(self, row: usize) -> (Self, Self) { - assert!(row <= self.nrows()); - unsafe { self.split_at_unchecked(row) } - } - - /// Returns references to the element at the given index, or subvector if `row` is a - /// range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_unchecked( - self, - row: RowRange, - ) -> >::Target - where - Self: ColIndex, - { - >::get_unchecked(self, row) - } - - /// Returns references to the element at the given index, or subvector if `row` is a - /// range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline(always)] - #[track_caller] - pub fn get(self, row: RowRange) -> >::Target - where - Self: ColIndex, - { - >::get(self, row) - } - - /// Reads the value of the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, row: usize) -> E { - E::faer_from_units(E::faer_map( - self.get_unchecked(row), - #[inline(always)] - |ptr| *ptr, - )) - } - - /// Reads the value of the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, row: usize) -> E { - E::faer_from_units(E::faer_map( - self.get(row), - #[inline(always)] - |ptr| *ptr, - )) - } - - /// Returns a view over the transpose of `self`. - #[inline(always)] - #[must_use] - pub fn transpose(self) -> RowRef<'a, E> { - unsafe { row::from_raw_parts(self.as_ptr(), self.nrows(), self.row_stride()) } - } - - /// Returns a view over the conjugate of `self`. - #[inline(always)] - #[must_use] - pub fn conjugate(self) -> ColRef<'a, E::Conj> - where - E: Conjugate, - { - unsafe { - // SAFETY: Conjugate requires that E::Unit and E::Conj::Unit have the same layout - // and that GroupCopyFor == E::Conj::GroupCopy - col::from_raw_parts::<'_, E::Conj>( - transmute_unchecked::< - GroupFor>, - GroupFor>, - >(self.as_ptr()), - self.nrows(), - self.row_stride(), - ) - } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline(always)] - pub fn adjoint(self) -> RowRef<'a, E::Conj> - where - E: Conjugate, - { - self.conjugate().transpose() - } - - /// Returns a view over the canonical representation of `self`, as well as a flag declaring - /// whether `self` is implicitly conjugated or not. - #[inline(always)] - pub fn canonicalize(self) -> (ColRef<'a, E::Canonical>, Conj) - where - E: Conjugate, - { - ( - unsafe { - // SAFETY: see Self::conjugate - col::from_raw_parts::<'_, E::Canonical>( - transmute_unchecked::< - GroupFor, - GroupFor>, - >(self.as_ptr()), - self.nrows(), - self.row_stride(), - ) - }, - if coe::is_same::() { - Conj::No - } else { - Conj::Yes - }, - ) - } - - /// Returns a view over the `self`, with the rows in reversed order. - #[inline(always)] - #[must_use] - pub fn reverse_rows(self) -> Self { - let nrows = self.nrows(); - let row_stride = self.row_stride().wrapping_neg(); - - let ptr = unsafe { self.unchecked_ptr_at(nrows.saturating_sub(1)) }; - unsafe { col::from_raw_parts(ptr, nrows, row_stride) } - } - - /// Returns a view over the subvector starting at row `row_start`, and with number of rows - /// `nrows`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `nrows <= self.nrows() - row_start`. - #[track_caller] - #[inline(always)] - pub unsafe fn subrows_unchecked(self, row_start: usize, nrows: usize) -> Self { - debug_assert!(all( - row_start <= self.nrows(), - nrows <= self.nrows() - row_start - )); - let row_stride = self.row_stride(); - unsafe { col::from_raw_parts(self.overflowing_ptr_at(row_start), nrows, row_stride) } - } - - /// Returns a view over the subvector starting at row `row_start`, and with number of rows - /// `nrows`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `nrows <= self.nrows() - row_start`. - #[track_caller] - #[inline(always)] - pub fn subrows(self, row_start: usize, nrows: usize) -> Self { - assert!(all( - row_start <= self.nrows(), - nrows <= self.nrows() - row_start - )); - unsafe { self.subrows_unchecked(row_start, nrows) } - } - - /// Given a matrix with a single column, returns an object that interprets - /// the column as a diagonal matrix, whoes diagonal elements are values in the column. - #[track_caller] - #[inline(always)] - pub fn column_vector_as_diagonal(self) -> Matrix> { - Matrix { - inner: inner::DiagRef { inner: self }, - } - } - - /// Returns an owning [`Col`] of the data. - #[inline] - pub fn to_owned(&self) -> Col - where - E: Conjugate, - { - let mut mat = Col::new(); - mat.resize_with( - self.nrows(), - #[inline(always)] - |row| unsafe { self.read_unchecked(row).canonicalize() }, - ); - mat - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - (*self).as_2d().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - (*self).rb().as_2d().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).rb().as_2d()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).rb().as_2d()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).rb().as_2d()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } - - /// Returns a view over the matrix. - #[inline] - pub fn as_ref(&self) -> ColRef<'_, E> { - *self - } - - #[doc(hidden)] - #[inline(always)] - pub unsafe fn const_cast(self) -> ColMut<'a, E> { - ColMut { - inner: inner::DenseColMut { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl core::ops::Index for ColRef<'_, E> { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, row: usize) -> &E { - self.get(row) - } - } - - impl core::ops::Index for ColMut<'_, E> { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, row: usize) -> &E { - (*self).rb().get(row) - } - } - - impl core::ops::IndexMut for ColMut<'_, E> { - #[inline] - #[track_caller] - fn index_mut(&mut self, row: usize) -> &mut E { - (*self).rb_mut().get_mut(row) - } - } - - impl core::ops::Index for Col { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, row: usize) -> &E { - self.as_ref().get(row) - } - } - - impl core::ops::IndexMut for Col { - #[inline] - #[track_caller] - fn index_mut(&mut self, row: usize) -> &mut E { - self.as_mut().get_mut(row) - } - } - - impl<'a, E: Entity> ColMut<'a, E> { - #[track_caller] - #[inline(always)] - #[doc(hidden)] - pub fn try_get_contiguous_col_mut(self) -> GroupFor { - assert!(self.row_stride() == 1); - let m = self.nrows(); - E::faer_map( - self.as_ptr_mut(), - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, m) }, - ) - } - - /// Returns the number of rows of the column. - #[inline(always)] - pub fn nrows(&self) -> usize { - self.inner.inner.len - } - /// Returns the number of columns of the column. This is always equal to `1`. - #[inline(always)] - pub fn ncols(&self) -> usize { - 1 - } - - /// Returns pointers to the matrix data. - #[inline(always)] - pub fn as_ptr_mut(self) -> GroupFor { - E::faer_map( - from_copy::(self.inner.inner.ptr), - #[inline(always)] - |ptr| ptr.as_ptr() as *mut E::Unit, - ) - } - - /// Returns the row stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn row_stride(&self) -> isize { - self.inner.inner.stride - } - - /// Returns `self` as a mutable matrix view. - #[inline(always)] - pub fn as_2d_mut(self) -> MatMut<'a, E> { - let nrows = self.nrows(); - let row_stride = self.row_stride(); - unsafe { mat::from_raw_parts_mut(self.as_ptr_mut(), nrows, 1, row_stride, 0) } - } - - /// Returns raw pointers to the element at the given index. - #[inline(always)] - pub fn ptr_at_mut(self, row: usize) -> GroupFor { - let offset = (row as isize).wrapping_mul(self.inner.inner.stride); - - E::faer_map( - self.as_ptr_mut(), - #[inline(always)] - |ptr| ptr.wrapping_offset(offset), - ) - } - - #[inline(always)] - unsafe fn ptr_at_mut_unchecked(self, row: usize) -> GroupFor { - let offset = unchecked_mul(row, self.inner.inner.stride); - E::faer_map( - self.as_ptr_mut(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - - /// Returns raw pointers to the element at the given index, assuming the provided index - /// is within the size of the vector. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn ptr_inbounds_at_mut(self, row: usize) -> GroupFor { - debug_assert!(row < self.nrows()); - self.ptr_at_mut_unchecked(row) - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * top. - /// * bottom. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_mut_unchecked(self, row: usize) -> (Self, Self) { - let (top, bot) = self.into_const().split_at_unchecked(row); - unsafe { (top.const_cast(), bot.const_cast()) } - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * top. - /// * bottom. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn split_at_mut(self, row: usize) -> (Self, Self) { - assert!(row <= self.nrows()); - unsafe { self.split_at_mut_unchecked(row) } - } - - /// Returns references to the element at the given index, or subvector if `row` is a - /// range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_unchecked_mut( - self, - row: RowRange, - ) -> >::Target - where - Self: ColIndex, - { - >::get_unchecked(self, row) - } - - /// Returns references to the element at the given index, or subvector if `row` is a - /// range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline(always)] - #[track_caller] - pub fn get_mut(self, row: RowRange) -> >::Target - where - Self: ColIndex, - { - >::get(self, row) - } - - /// Reads the value of the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, row: usize) -> E { - self.rb().read_unchecked(row) - } - - /// Reads the value of the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, row: usize) -> E { - self.rb().read(row) - } - - /// Writes the value to the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn write_unchecked(&mut self, row: usize, value: E) { - let units = value.faer_into_units(); - let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(row)); - E::faer_map( - zipped, - #[inline(always)] - |(unit, ptr)| *ptr = unit, - ); - } - - /// Writes the value to the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn write(&mut self, row: usize, value: E) { - assert!(row < self.nrows()); - unsafe { self.write_unchecked(row, value) }; - } - - /// Copies the values from `other` into `self`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `self.nrows() == other.nrows()`. - /// * `self.ncols() == other.ncols()`. - #[track_caller] - pub fn copy_from(&mut self, other: impl AsColRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: ColMut<'_, E>, other: ColRef<'_, E>) { - zipped!(this.as_2d_mut(), other.as_2d()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - } - implementation(self.rb_mut(), other.as_col_ref()) - } - - /// Fills the elements of `self` with zeros. - #[track_caller] - pub fn fill_zero(&mut self) - where - E: ComplexField, - { - zipped!(self.rb_mut().as_2d_mut()).for_each( - #[inline(always)] - |unzipped!(mut x)| x.write(E::faer_zero()), - ); - } - - /// Fills the elements of `self` with copies of `constant`. - #[track_caller] - pub fn fill(&mut self, constant: E) { - zipped!((*self).rb_mut().as_2d_mut()).for_each( - #[inline(always)] - |unzipped!(mut x)| x.write(constant), - ); - } - - /// Returns a view over the transpose of `self`. - #[inline(always)] - #[must_use] - pub fn transpose_mut(self) -> RowMut<'a, E> { - unsafe { self.into_const().transpose().const_cast() } - } - - /// Returns a view over the conjugate of `self`. - #[inline(always)] - #[must_use] - pub fn conjugate_mut(self) -> ColMut<'a, E::Conj> - where - E: Conjugate, - { - unsafe { self.into_const().conjugate().const_cast() } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline(always)] - pub fn adjoint_mut(self) -> RowMut<'a, E::Conj> - where - E: Conjugate, - { - self.conjugate_mut().transpose_mut() - } - - /// Returns a view over the canonical representation of `self`, as well as a flag declaring - /// whether `self` is implicitly conjugated or not. - #[inline(always)] - pub fn canonicalize_mut(self) -> (ColMut<'a, E::Canonical>, Conj) - where - E: Conjugate, - { - let (canon, conj) = self.into_const().canonicalize(); - unsafe { (canon.const_cast(), conj) } - } - - /// Returns a view over the `self`, with the rows in reversed order. - #[inline(always)] - #[must_use] - pub fn reverse_rows_mut(self) -> Self { - unsafe { self.into_const().reverse_rows().const_cast() } - } - - /// Returns a view over the subvector starting at row `row_start`, and with number of rows - /// `nrows`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `nrows <= self.nrows() - row_start`. - #[track_caller] - #[inline(always)] - pub unsafe fn subrows_mut_unchecked(self, row_start: usize, nrows: usize) -> Self { - self.into_const() - .subrows_unchecked(row_start, nrows) - .const_cast() - } - - /// Returns a view over the subvector starting at row `row_start`, and with number of rows - /// `nrows`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `nrows <= self.nrows() - row_start`. - #[track_caller] - #[inline(always)] - pub fn subrows_mut(self, row_start: usize, nrows: usize) -> Self { - unsafe { self.into_const().subrows(row_start, nrows).const_cast() } - } - - /// Given a matrix with a single column, returns an object that interprets - /// the column as a diagonal matrix, whoes diagonal elements are values in the column. - #[track_caller] - #[inline(always)] - pub fn column_vector_as_diagonal(self) -> Matrix> { - Matrix { - inner: inner::DiagMut { inner: self }, - } - } - - /// Returns an owning [`Col`] of the data. - #[inline] - pub fn to_owned(&self) -> Col - where - E: Conjugate, - { - (*self).rb().to_owned() - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - (*self).rb().as_2d().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - (*self).rb().as_2d().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).rb().as_2d()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).rb().as_2d()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).rb().as_2d()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } - - /// Returns a view over the matrix. - #[inline] - pub fn as_ref(&self) -> ColRef<'_, E> { - (*self).rb() - } - } -}; - -// ROW IMPL -const _: () = { - impl<'a, E: Entity> RowRef<'a, E> { - /// Returns the number of rows of the row. This is always equal to `1`. - #[inline(always)] - pub fn nrows(&self) -> usize { - 1 - } - /// Returns the number of columns of the row. - #[inline(always)] - pub fn ncols(&self) -> usize { - self.inner.inner.len - } - - /// Returns pointers to the matrix data. - #[inline(always)] - pub fn as_ptr(self) -> GroupFor { - E::faer_map( - from_copy::(self.inner.inner.ptr), - #[inline(always)] - |ptr| ptr.as_ptr() as *const E::Unit, - ) - } - - /// Returns the column stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn col_stride(&self) -> isize { - self.inner.inner.stride - } - - /// Returns `self` as a matrix view. - #[inline(always)] - pub fn as_2d(self) -> MatRef<'a, E> { - let ncols = self.ncols(); - let col_stride = self.col_stride(); - unsafe { mat::from_raw_parts(self.as_ptr(), 1, ncols, 0, col_stride) } - } - - /// Returns raw pointers to the element at the given index. - #[inline(always)] - pub fn ptr_at(self, col: usize) -> GroupFor { - let offset = (col as isize).wrapping_mul(self.inner.inner.stride); - - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.wrapping_offset(offset), - ) - } - - #[inline(always)] - unsafe fn unchecked_ptr_at(self, col: usize) -> GroupFor { - let offset = unchecked_mul(col, self.inner.inner.stride); - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - - #[inline(always)] - unsafe fn overflowing_ptr_at(self, col: usize) -> GroupFor { - unsafe { - let cond = col != self.ncols(); - let offset = (cond as usize).wrapping_neg() as isize - & (col as isize).wrapping_mul(self.inner.inner.stride); - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - } - - /// Returns raw pointers to the element at the given index, assuming the provided index - /// is within the size of the vector. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn ptr_inbounds_at(self, col: usize) -> GroupFor { - debug_assert!(col < self.ncols()); - self.unchecked_ptr_at(col) - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * left. - /// * right. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_unchecked(self, col: usize) -> (Self, Self) { - debug_assert!(col <= self.ncols()); - - let col_stride = self.col_stride(); - - let ncols = self.ncols(); - - unsafe { - let top = self.as_ptr(); - let bot = self.overflowing_ptr_at(col); - - ( - row::from_raw_parts(top, col, col_stride), - row::from_raw_parts(bot, ncols - col, col_stride), - ) - } - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * top. - /// * bottom. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at(self, col: usize) -> (Self, Self) { - assert!(col <= self.ncols()); - unsafe { self.split_at_unchecked(col) } - } - - /// Returns references to the element at the given index, or subvector if `row` is a - /// range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_unchecked( - self, - col: ColRange, - ) -> >::Target - where - Self: RowIndex, - { - >::get_unchecked(self, col) - } - - /// Returns references to the element at the given index, or subvector if `col` is a - /// range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub fn get(self, col: ColRange) -> >::Target - where - Self: RowIndex, - { - >::get(self, col) - } - - /// Reads the value of the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, col: usize) -> E { - E::faer_from_units(E::faer_map( - self.get_unchecked(col), - #[inline(always)] - |ptr| *ptr, - )) - } - - /// Reads the value of the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, col: usize) -> E { - E::faer_from_units(E::faer_map( - self.get(col), - #[inline(always)] - |ptr| *ptr, - )) - } - - /// Returns a view over the transpose of `self`. - #[inline(always)] - #[must_use] - pub fn transpose(self) -> ColRef<'a, E> { - unsafe { col::from_raw_parts(self.as_ptr(), self.ncols(), self.col_stride()) } - } - - /// Returns a view over the conjugate of `self`. - #[inline(always)] - #[must_use] - pub fn conjugate(self) -> RowRef<'a, E::Conj> - where - E: Conjugate, - { - unsafe { - // SAFETY: Conjugate requires that E::Unit and E::Conj::Unit have the same layout - // and that GroupCopyFor == E::Conj::GroupCopy - row::from_raw_parts::<'_, E::Conj>( - transmute_unchecked::< - GroupFor>, - GroupFor>, - >(self.as_ptr()), - self.ncols(), - self.col_stride(), - ) - } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline(always)] - pub fn adjoint(self) -> ColRef<'a, E::Conj> - where - E: Conjugate, - { - self.conjugate().transpose() - } - - /// Returns a view over the canonical representation of `self`, as well as a flag declaring - /// whether `self` is implicitly conjugated or not. - #[inline(always)] - pub fn canonicalize(self) -> (RowRef<'a, E::Canonical>, Conj) - where - E: Conjugate, - { - ( - unsafe { - // SAFETY: see Self::conjugate - row::from_raw_parts::<'_, E::Canonical>( - transmute_unchecked::< - GroupFor, - GroupFor>, - >(self.as_ptr()), - self.ncols(), - self.col_stride(), - ) - }, - if coe::is_same::() { - Conj::No - } else { - Conj::Yes - }, - ) - } - - /// Returns a view over the `self`, with the columnss in reversed order. - #[inline(always)] - #[must_use] - pub fn reverse_cols(self) -> Self { - let ncols = self.ncols(); - let col_stride = self.col_stride().wrapping_neg(); - - let ptr = unsafe { self.unchecked_ptr_at(ncols.saturating_sub(1)) }; - unsafe { row::from_raw_parts(ptr, ncols, col_stride) } - } - - /// Returns a view over the subvector starting at column `col_start`, and with number of - /// columns `ncols`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col_start <= self.ncols()`. - /// * `ncols <= self.ncols() - col_start`. - #[track_caller] - #[inline(always)] - pub unsafe fn subcols_unchecked(self, col_start: usize, ncols: usize) -> Self { - debug_assert!(col_start <= self.ncols()); - debug_assert!(ncols <= self.ncols() - col_start); - let col_stride = self.col_stride(); - unsafe { row::from_raw_parts(self.overflowing_ptr_at(col_start), ncols, col_stride) } - } - - /// Returns a view over the subvector starting at col `col_start`, and with number of cols - /// `ncols`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col_start <= self.ncols()`. - /// * `ncols <= self.ncols() - col_start`. - #[track_caller] - #[inline(always)] - pub fn subcols(self, col_start: usize, ncols: usize) -> Self { - assert!(col_start <= self.ncols()); - assert!(ncols <= self.ncols() - col_start); - unsafe { self.subcols_unchecked(col_start, ncols) } - } - - /// Returns an owning [`Row`] of the data. - #[inline] - pub fn to_owned(&self) -> Row - where - E: Conjugate, - { - let mut mat = Row::new(); - mat.resize_with( - self.ncols(), - #[inline(always)] - |col| unsafe { self.read_unchecked(col).canonicalize() }, - ); - mat - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - (*self).rb().as_2d().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - (*self).rb().as_2d().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).rb().as_2d()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).rb().as_2d()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).rb().as_2d()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } - - /// Returns a view over the matrix. - #[inline] - pub fn as_ref(&self) -> RowRef<'_, E> { - *self - } - - #[doc(hidden)] - #[inline(always)] - pub unsafe fn const_cast(self) -> RowMut<'a, E> { - RowMut { - inner: inner::DenseRowMut { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - } - - impl core::ops::Index for RowRef<'_, E> { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, col: usize) -> &E { - self.get(col) - } - } - - impl core::ops::Index for RowMut<'_, E> { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, col: usize) -> &E { - (*self).rb().get(col) - } - } - - impl core::ops::IndexMut for RowMut<'_, E> { - #[inline] - #[track_caller] - fn index_mut(&mut self, col: usize) -> &mut E { - (*self).rb_mut().get_mut(col) - } - } - - impl core::ops::Index for Row { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, col: usize) -> &E { - self.as_ref().get(col) - } - } - - impl core::ops::IndexMut for Row { - #[inline] - #[track_caller] - fn index_mut(&mut self, col: usize) -> &mut E { - self.as_mut().get_mut(col) - } - } - - impl<'a, E: Entity> RowMut<'a, E> { - /// Returns the number of rows of the row. This is always equal to `1`. - #[inline(always)] - pub fn nrows(&self) -> usize { - 1 - } - /// Returns the number of columns of the row. - #[inline(always)] - pub fn ncols(&self) -> usize { - self.inner.inner.len - } - - /// Returns pointers to the matrix data. - #[inline(always)] - pub fn as_ptr_mut(self) -> GroupFor { - E::faer_map( - from_copy::(self.inner.inner.ptr), - #[inline(always)] - |ptr| ptr.as_ptr() as *mut E::Unit, - ) - } - - /// Returns the column stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn col_stride(&self) -> isize { - self.inner.inner.stride - } - - /// Returns `self` as a mutable matrix view. - #[inline(always)] - pub fn as_2d_mut(self) -> MatMut<'a, E> { - let ncols = self.ncols(); - let col_stride = self.col_stride(); - unsafe { mat::from_raw_parts_mut(self.as_ptr_mut(), 1, ncols, 0, col_stride) } - } - - /// Returns raw pointers to the element at the given index. - #[inline(always)] - pub fn ptr_at_mut(self, col: usize) -> GroupFor { - let offset = (col as isize).wrapping_mul(self.inner.inner.stride); - - E::faer_map( - self.as_ptr_mut(), - #[inline(always)] - |ptr| ptr.wrapping_offset(offset), - ) - } - - #[inline(always)] - unsafe fn ptr_at_mut_unchecked(self, col: usize) -> GroupFor { - let offset = unchecked_mul(col, self.inner.inner.stride); - E::faer_map( - self.as_ptr_mut(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - - /// Returns raw pointers to the element at the given index, assuming the provided index - /// is within the size of the vector. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn ptr_inbounds_at_mut(self, col: usize) -> GroupFor { - debug_assert!(col < self.ncols()); - self.ptr_at_mut_unchecked(col) - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * left. - /// * right. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_mut_unchecked(self, col: usize) -> (Self, Self) { - let (left, right) = self.into_const().split_at_unchecked(col); - unsafe { (left.const_cast(), right.const_cast()) } - } - - /// Splits the column vector at the given index into two parts and - /// returns an array of each subvector, in the following order: - /// * top. - /// * bottom. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn split_at_mut(self, col: usize) -> (Self, Self) { - assert!(col <= self.ncols()); - unsafe { self.split_at_mut_unchecked(col) } - } - - /// Returns references to the element at the given index, or subvector if `col` is a - /// range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_mut_unchecked( - self, - col: ColRange, - ) -> >::Target - where - Self: RowIndex, - { - >::get_unchecked(self, col) - } - - /// Returns references to the element at the given index, or subvector if `col` is a - /// range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub fn get_mut(self, col: ColRange) -> >::Target - where - Self: RowIndex, - { - >::get(self, col) - } - - /// Reads the value of the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, col: usize) -> E { - self.rb().read_unchecked(col) - } - - /// Reads the value of the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, col: usize) -> E { - self.rb().read(col) - } - - /// Writes the value to the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn write_unchecked(&mut self, col: usize, value: E) { - let units = value.faer_into_units(); - let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(col)); - E::faer_map( - zipped, - #[inline(always)] - |(unit, ptr)| *ptr = unit, - ); - } - - /// Writes the value to the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn write(&mut self, col: usize, value: E) { - assert!(col < self.ncols()); - unsafe { self.write_unchecked(col, value) }; - } - - /// Copies the values from `other` into `self`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `self.ncols() == other.ncols()`. - #[track_caller] - pub fn copy_from(&mut self, other: impl AsRowRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: RowMut<'_, E>, other: RowRef<'_, E>) { - zipped!(this.as_2d_mut(), other.as_2d()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - } - implementation(self.rb_mut(), other.as_row_ref()) - } - - /// Fills the elements of `self` with zeros. - #[track_caller] - pub fn fill_zero(&mut self) - where - E: ComplexField, - { - zipped!(self.rb_mut().as_2d_mut()).for_each( - #[inline(always)] - |unzipped!(mut x)| x.write(E::faer_zero()), - ); - } - - /// Fills the elements of `self` with copies of `constant`. - #[track_caller] - pub fn fill(&mut self, constant: E) { - zipped!((*self).rb_mut().as_2d_mut()).for_each( - #[inline(always)] - |unzipped!(mut x)| x.write(constant), - ); - } - - /// Returns a view over the transpose of `self`. - #[inline(always)] - #[must_use] - pub fn transpose_mut(self) -> ColMut<'a, E> { - unsafe { self.into_const().transpose().const_cast() } - } - - /// Returns a view over the conjugate of `self`. - #[inline(always)] - #[must_use] - pub fn conjugate_mut(self) -> RowMut<'a, E::Conj> - where - E: Conjugate, - { - unsafe { self.into_const().conjugate().const_cast() } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline(always)] - pub fn adjoint_mut(self) -> ColMut<'a, E::Conj> - where - E: Conjugate, - { - self.conjugate_mut().transpose_mut() - } - - /// Returns a view over the canonical representation of `self`, as well as a flag declaring - /// whether `self` is implicitly conjugated or not. - #[inline(always)] - pub fn canonicalize_mut(self) -> (RowMut<'a, E::Canonical>, Conj) - where - E: Conjugate, - { - let (canon, conj) = self.into_const().canonicalize(); - unsafe { (canon.const_cast(), conj) } - } - - /// Returns a view over the `self`, with the columnss in reversed order. - #[inline(always)] - #[must_use] - pub fn reverse_cols_mut(self) -> Self { - unsafe { self.into_const().reverse_cols().const_cast() } - } - - /// Returns a view over the subvector starting at col `col_start`, and with number of - /// columns `ncols`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col_start <= self.ncols()`. - /// * `ncols <= self.ncols() - col_start`. - #[track_caller] - #[inline(always)] - pub unsafe fn subcols_mut_unchecked(self, col_start: usize, ncols: usize) -> Self { - self.into_const() - .subcols_unchecked(col_start, ncols) - .const_cast() - } - - /// Returns a view over the subvector starting at col `col_start`, and with number of - /// columns `ncols`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col_start <= self.ncols()`. - /// * `ncols <= self.ncols() - col_start`. - #[track_caller] - #[inline(always)] - pub fn subcols_mut(self, col_start: usize, ncols: usize) -> Self { - unsafe { self.into_const().subcols(col_start, ncols).const_cast() } - } - - /// Returns an owning [`Row`] of the data. - #[inline] - pub fn to_owned(&self) -> Row - where - E: Conjugate, - { - (*self).rb().to_owned() - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - (*self).rb().as_2d().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - (*self).rb().as_2d().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).rb().as_2d()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).rb().as_2d()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).rb().as_2d()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } - - /// Returns a view over the matrix. - #[inline] - pub fn as_ref(&self) -> RowRef<'_, E> { - (*self).rb() - } - } -}; - -// MAT IMPL -const _: () = { - impl<'a, E: Entity> MatRef<'a, E> { - #[track_caller] - #[inline(always)] - #[doc(hidden)] - pub fn try_get_contiguous_col(self, j: usize) -> GroupFor { - assert!(self.row_stride() == 1); - let col = self.col(j); - if col.nrows() == 0 { - E::faer_map( - E::UNIT, - #[inline(always)] - |()| &[] as &[E::Unit], - ) - } else { - let m = col.nrows(); - E::faer_map( - col.as_ptr(), - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts(ptr, m) }, - ) - } - } - - /// Returns the number of rows of the matrix. - #[inline(always)] - pub fn nrows(&self) -> usize { - self.inner.inner.nrows - } - /// Returns the number of columns of the matrix. - #[inline(always)] - pub fn ncols(&self) -> usize { - self.inner.inner.ncols - } - - /// Returns pointers to the matrix data. - #[inline(always)] - pub fn as_ptr(self) -> GroupFor { - E::faer_map( - from_copy::(self.inner.inner.ptr), - #[inline(always)] - |ptr| ptr.as_ptr() as *const E::Unit, - ) - } - - /// Returns the row stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn row_stride(&self) -> isize { - self.inner.inner.row_stride - } - - /// Returns the column stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn col_stride(&self) -> isize { - self.inner.inner.col_stride - } - - /// Returns raw pointers to the element at the given indices. - #[inline(always)] - pub fn ptr_at(self, row: usize, col: usize) -> GroupFor { - let offset = ((row as isize).wrapping_mul(self.inner.inner.row_stride)) - .wrapping_add((col as isize).wrapping_mul(self.inner.inner.col_stride)); - - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.wrapping_offset(offset), - ) - } - - #[inline(always)] - unsafe fn unchecked_ptr_at(self, row: usize, col: usize) -> GroupFor { - let offset = unchecked_add( - unchecked_mul(row, self.inner.inner.row_stride), - unchecked_mul(col, self.inner.inner.col_stride), - ); - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - - #[inline(always)] - unsafe fn overflowing_ptr_at(self, row: usize, col: usize) -> GroupFor { - unsafe { - let cond = (row != self.nrows()) & (col != self.ncols()); - let offset = (cond as usize).wrapping_neg() as isize - & (isize::wrapping_add( - (row as isize).wrapping_mul(self.inner.inner.row_stride), - (col as isize).wrapping_mul(self.inner.inner.col_stride), - )); - E::faer_map( - self.as_ptr(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - } - - /// Returns raw pointers to the element at the given indices, assuming the provided indices - /// are within the matrix dimensions. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn ptr_inbounds_at(self, row: usize, col: usize) -> GroupFor { - debug_assert!(all(row < self.nrows(), col < self.ncols())); - self.unchecked_ptr_at(row, col) - } - - /// Splits the matrix horizontally and vertically at the given indices into four corners and - /// returns an array of each submatrix, in the following order: - /// * top left. - /// * top right. - /// * bottom left. - /// * bottom right. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row <= self.nrows()`. - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_unchecked(self, row: usize, col: usize) -> (Self, Self, Self, Self) { - debug_assert!(all(row <= self.nrows(), col <= self.ncols())); - - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - - let nrows = self.nrows(); - let ncols = self.ncols(); - - unsafe { - let top_left = self.overflowing_ptr_at(0, 0); - let top_right = self.overflowing_ptr_at(0, col); - let bot_left = self.overflowing_ptr_at(row, 0); - let bot_right = self.overflowing_ptr_at(row, col); - - ( - mat::from_raw_parts(top_left, row, col, row_stride, col_stride), - mat::from_raw_parts(top_right, row, ncols - col, row_stride, col_stride), - mat::from_raw_parts(bot_left, nrows - row, col, row_stride, col_stride), - mat::from_raw_parts( - bot_right, - nrows - row, - ncols - col, - row_stride, - col_stride, - ), - ) - } - } - - /// Splits the matrix horizontally and vertically at the given indices into four corners and - /// returns an array of each submatrix, in the following order: - /// * top left. - /// * top right. - /// * bottom left. - /// * bottom right. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row <= self.nrows()`. - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn split_at(self, row: usize, col: usize) -> (Self, Self, Self, Self) { - assert!(all(row <= self.nrows(), col <= self.ncols())); - unsafe { self.split_at_unchecked(row, col) } - } - - /// Splits the matrix horizontally at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * top. - /// * bottom. - /// - /// # Safety - /// The behavior is undefined if the following condition is violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_row_unchecked(self, row: usize) -> (Self, Self) { - debug_assert!(row <= self.nrows()); - - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - - let nrows = self.nrows(); - let ncols = self.ncols(); - - unsafe { - let top_right = self.overflowing_ptr_at(0, 0); - let bot_right = self.overflowing_ptr_at(row, 0); - - ( - mat::from_raw_parts(top_right, row, ncols, row_stride, col_stride), - mat::from_raw_parts(bot_right, nrows - row, ncols, row_stride, col_stride), - ) - } - } - - /// Splits the matrix horizontally at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * top. - /// * bottom. - /// - /// # Panics - /// The function panics if the following condition is violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn split_at_row(self, row: usize) -> (Self, Self) { - assert!(row <= self.nrows()); - unsafe { self.split_at_row_unchecked(row) } - } - - /// Splits the matrix vertically at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * left. - /// * right. - /// - /// # Safety - /// The behavior is undefined if the following condition is violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_col_unchecked(self, col: usize) -> (Self, Self) { - debug_assert!(col <= self.ncols()); - - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - - let nrows = self.nrows(); - let ncols = self.ncols(); - - unsafe { - let bot_left = self.overflowing_ptr_at(0, 0); - let bot_right = self.overflowing_ptr_at(0, col); - - ( - mat::from_raw_parts(bot_left, nrows, col, row_stride, col_stride), - mat::from_raw_parts(bot_right, nrows, ncols - col, row_stride, col_stride), - ) - } - } - - /// Splits the matrix vertically at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * left. - /// * right. - /// - /// # Panics - /// The function panics if the following condition is violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn split_at_col(self, col: usize) -> (Self, Self) { - assert!(col <= self.ncols()); - unsafe { self.split_at_col_unchecked(col) } - } - - /// Returns references to the element at the given indices, or submatrices if either `row` - /// or `col` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_unchecked( - self, - row: RowRange, - col: ColRange, - ) -> >::Target - where - Self: MatIndex, - { - >::get_unchecked(self, row, col) - } - - /// Returns references to the element at the given indices, or submatrices if either `row` - /// or `col` is a range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub fn get( - self, - row: RowRange, - col: ColRange, - ) -> >::Target - where - Self: MatIndex, - { - >::get(self, row, col) - } - - /// Reads the value of the element at the given indices. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { - E::faer_from_units(E::faer_map( - self.get_unchecked(row, col), - #[inline(always)] - |ptr| *ptr, - )) - } - - /// Reads the value of the element at the given indices, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, row: usize, col: usize) -> E { - E::faer_from_units(E::faer_map( - self.get(row, col), - #[inline(always)] - |ptr| *ptr, - )) - } - - /// Returns a view over the transpose of `self`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_ref(); - /// let transpose = view.transpose(); - /// - /// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; - /// assert_eq!(expected.as_ref(), transpose); - /// ``` - #[inline(always)] - #[must_use] - pub fn transpose(self) -> Self { - unsafe { - mat::from_raw_parts( - self.as_ptr(), - self.ncols(), - self.nrows(), - self.col_stride(), - self.row_stride(), - ) - } - } - - /// Returns a view over the conjugate of `self`. - #[inline(always)] - #[must_use] - pub fn conjugate(self) -> MatRef<'a, E::Conj> - where - E: Conjugate, - { - unsafe { - // SAFETY: Conjugate requires that E::Unit and E::Conj::Unit have the same layout - // and that GroupCopyFor == E::Conj::GroupCopy - mat::from_raw_parts::<'_, E::Conj>( - transmute_unchecked::< - GroupFor>, - GroupFor>, - >(self.as_ptr()), - self.nrows(), - self.ncols(), - self.row_stride(), - self.col_stride(), - ) - } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline(always)] - pub fn adjoint(self) -> MatRef<'a, E::Conj> - where - E: Conjugate, - { - self.transpose().conjugate() - } - - /// Returns a view over the canonical representation of `self`, as well as a flag declaring - /// whether `self` is implicitly conjugated or not. - #[inline(always)] - pub fn canonicalize(self) -> (MatRef<'a, E::Canonical>, Conj) - where - E: Conjugate, - { - ( - unsafe { - // SAFETY: see Self::conjugate - mat::from_raw_parts::<'_, E::Canonical>( - transmute_unchecked::< - GroupFor, - GroupFor>, - >(self.as_ptr()), - self.nrows(), - self.ncols(), - self.row_stride(), - self.col_stride(), - ) - }, - if coe::is_same::() { - Conj::No - } else { - Conj::Yes - }, - ) - } - - /// Returns a view over the `self`, with the rows in reversed order. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_ref(); - /// let reversed_rows = view.reverse_rows(); - /// - /// let expected = mat![[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]; - /// assert_eq!(expected.as_ref(), reversed_rows); - /// ``` - #[inline(always)] - #[must_use] - pub fn reverse_rows(self) -> Self { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride().wrapping_neg(); - let col_stride = self.col_stride(); - - let ptr = unsafe { self.unchecked_ptr_at(nrows.saturating_sub(1), 0) }; - unsafe { mat::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } - } - - /// Returns a view over the `self`, with the columns in reversed order. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_ref(); - /// let reversed_cols = view.reverse_cols(); - /// - /// let expected = mat![[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]; - /// assert_eq!(expected.as_ref(), reversed_cols); - /// ``` - #[inline(always)] - #[must_use] - pub fn reverse_cols(self) -> Self { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride(); - let col_stride = self.col_stride().wrapping_neg(); - let ptr = unsafe { self.unchecked_ptr_at(0, ncols.saturating_sub(1)) }; - unsafe { mat::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } - } - - /// Returns a view over the `self`, with the rows and the columns in reversed order. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_ref(); - /// let reversed = view.reverse_rows_and_cols(); - /// - /// let expected = mat![[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]; - /// assert_eq!(expected.as_ref(), reversed); - /// ``` - #[inline(always)] - #[must_use] - pub fn reverse_rows_and_cols(self) -> Self { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = -self.row_stride(); - let col_stride = -self.col_stride(); - - let ptr = - unsafe { self.unchecked_ptr_at(nrows.saturating_sub(1), ncols.saturating_sub(1)) }; - unsafe { mat::from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } - } - - /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with - /// dimensions `(nrows, ncols)`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `col_start <= self.ncols()`. - /// * `nrows <= self.nrows() - row_start`. - /// * `ncols <= self.ncols() - col_start`. - #[track_caller] - #[inline(always)] - pub unsafe fn submatrix_unchecked( - self, - row_start: usize, - col_start: usize, - nrows: usize, - ncols: usize, - ) -> Self { - debug_assert!(all(row_start <= self.nrows(), col_start <= self.ncols())); - debug_assert!(all( - nrows <= self.nrows() - row_start, - ncols <= self.ncols() - col_start, - )); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - - unsafe { - mat::from_raw_parts( - self.overflowing_ptr_at(row_start, col_start), - nrows, - ncols, - row_stride, - col_stride, - ) - } - } - - /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with - /// dimensions `(nrows, ncols)`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `col_start <= self.ncols()`. - /// * `nrows <= self.nrows() - row_start`. - /// * `ncols <= self.ncols() - col_start`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let matrix = mat![ - /// [1.0, 5.0, 9.0], - /// [2.0, 6.0, 10.0], - /// [3.0, 7.0, 11.0], - /// [4.0, 8.0, 12.0f64], - /// ]; - /// - /// let view = matrix.as_ref(); - /// let submatrix = view.submatrix(2, 1, 2, 2); - /// - /// let expected = mat![[7.0, 11.0], [8.0, 12.0f64]]; - /// assert_eq!(expected.as_ref(), submatrix); - /// ``` - #[track_caller] - #[inline(always)] - pub fn submatrix( - self, - row_start: usize, - col_start: usize, - nrows: usize, - ncols: usize, - ) -> Self { - assert!(all(row_start <= self.nrows(), col_start <= self.ncols())); - assert!(all( - nrows <= self.nrows() - row_start, - ncols <= self.ncols() - col_start, - )); - unsafe { self.submatrix_unchecked(row_start, col_start, nrows, ncols) } - } - - /// Returns a view over the submatrix starting at row `row_start`, and with number of rows - /// `nrows`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `nrows <= self.nrows() - row_start`. - #[track_caller] - #[inline(always)] - pub unsafe fn subrows_unchecked(self, row_start: usize, nrows: usize) -> Self { - debug_assert!(row_start <= self.nrows()); - debug_assert!(nrows <= self.nrows() - row_start); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - unsafe { - mat::from_raw_parts( - self.overflowing_ptr_at(row_start, 0), - nrows, - self.ncols(), - row_stride, - col_stride, - ) - } - } - - /// Returns a view over the submatrix starting at row `row_start`, and with number of rows - /// `nrows`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `nrows <= self.nrows() - row_start`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let matrix = mat![ - /// [1.0, 5.0, 9.0], - /// [2.0, 6.0, 10.0], - /// [3.0, 7.0, 11.0], - /// [4.0, 8.0, 12.0f64], - /// ]; - /// - /// let view = matrix.as_ref(); - /// let subrows = view.subrows(1, 2); - /// - /// let expected = mat![[2.0, 6.0, 10.0], [3.0, 7.0, 11.0],]; - /// assert_eq!(expected.as_ref(), subrows); - /// ``` - #[track_caller] - #[inline(always)] - pub fn subrows(self, row_start: usize, nrows: usize) -> Self { - assert!(row_start <= self.nrows()); - assert!(nrows <= self.nrows() - row_start); - unsafe { self.subrows_unchecked(row_start, nrows) } - } - - /// Returns a view over the submatrix starting at column `col_start`, and with number of - /// columns `ncols`. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col_start <= self.ncols()`. - /// * `ncols <= self.ncols() - col_start`. - #[track_caller] - #[inline(always)] - pub unsafe fn subcols_unchecked(self, col_start: usize, ncols: usize) -> Self { - debug_assert!(col_start <= self.ncols()); - debug_assert!(ncols <= self.ncols() - col_start); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - unsafe { - mat::from_raw_parts( - self.overflowing_ptr_at(0, col_start), - self.nrows(), - ncols, - row_stride, - col_stride, - ) - } - } - - /// Returns a view over the submatrix starting at column `col_start`, and with number of - /// columns `ncols`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col_start <= self.ncols()`. - /// * `ncols <= self.ncols() - col_start`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let matrix = mat![ - /// [1.0, 5.0, 9.0], - /// [2.0, 6.0, 10.0], - /// [3.0, 7.0, 11.0], - /// [4.0, 8.0, 12.0f64], - /// ]; - /// - /// let view = matrix.as_ref(); - /// let subcols = view.subcols(2, 1); - /// - /// let expected = mat![[9.0], [10.0], [11.0], [12.0f64]]; - /// assert_eq!(expected.as_ref(), subcols); - /// ``` - #[track_caller] - #[inline(always)] - pub fn subcols(self, col_start: usize, ncols: usize) -> Self { - debug_assert!(col_start <= self.ncols()); - debug_assert!(ncols <= self.ncols() - col_start); - unsafe { self.subcols_unchecked(col_start, ncols) } - } - - /// Returns a view over the row at the given index. - /// - /// # Safety - /// The function panics if any of the following conditions are violated: - /// * `row_idx < self.nrows()`. - #[track_caller] - #[inline(always)] - pub unsafe fn row_unchecked(self, row_idx: usize) -> RowRef<'a, E> { - debug_assert!(row_idx < self.nrows()); - unsafe { - row::from_raw_parts( - self.overflowing_ptr_at(row_idx, 0), - self.ncols(), - self.col_stride(), - ) - } - } - - /// Returns a view over the row at the given index. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row_idx < self.nrows()`. - #[track_caller] - #[inline(always)] - pub fn row(self, row_idx: usize) -> RowRef<'a, E> { - assert!(row_idx < self.nrows()); - unsafe { self.row_unchecked(row_idx) } - } - - /// Returns a view over the column at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col_idx < self.ncols()`. - #[track_caller] - #[inline(always)] - pub unsafe fn col_unchecked(self, col_idx: usize) -> ColRef<'a, E> { - debug_assert!(col_idx < self.ncols()); - unsafe { - col::from_raw_parts( - self.overflowing_ptr_at(0, col_idx), - self.nrows(), - self.row_stride(), - ) - } - } - - /// Returns a view over the column at the given index. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col_idx < self.ncols()`. - #[track_caller] - #[inline(always)] - pub fn col(self, col_idx: usize) -> ColRef<'a, E> { - assert!(col_idx < self.ncols()); - unsafe { self.col_unchecked(col_idx) } - } - - /// Given a matrix with a single column, returns an object that interprets - /// the column as a diagonal matrix, whoes diagonal elements are values in the column. - #[track_caller] - #[inline(always)] - pub fn column_vector_as_diagonal(self) -> Matrix> { - assert!(self.ncols() == 1); - Matrix { - inner: inner::DiagRef { inner: self.col(0) }, - } - } - - /// Returns the diagonal of the matrix. - #[inline(always)] - pub fn diagonal(self) -> Matrix> { - let size = self.nrows().min(self.ncols()); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - unsafe { - Matrix { - inner: inner::DiagRef { - inner: col::from_raw_parts(self.as_ptr(), size, row_stride + col_stride), - }, - } - } - } - - /// Returns an owning [`Mat`] of the data. - #[inline] - pub fn to_owned(&self) -> Mat - where - E: Conjugate, - { - let mut mat = Mat::new(); - mat.resize_with( - self.nrows(), - self.ncols(), - #[inline(always)] - |row, col| unsafe { self.read_unchecked(row, col).canonicalize() }, - ); - mat - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - let mut found_nan = false; - zipped!(*self).for_each(|unzipped!(x)| { - found_nan |= x.read().faer_is_nan(); - }); - found_nan - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - let mut all_finite = true; - zipped!(*self).for_each(|unzipped!(x)| { - all_finite &= x.read().faer_is_finite(); - }); - all_finite - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).rb()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).rb()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).rb()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - let lhs = (*self).rb(); - let rhs = rhs.as_2d_ref(); - let mut dst = Mat::new(); - dst.resize_with( - lhs.nrows() * rhs.nrows(), - lhs.ncols() * rhs.ncols(), - |_, _| E::zeroed(), - ); - kron(dst.as_mut(), lhs, rhs); - dst - } - - /// Returns a view over the matrix. - #[inline] - pub fn as_ref(&self) -> MatRef<'_, E> { - *self - } - - #[doc(hidden)] - #[inline(always)] - pub unsafe fn const_cast(self) -> MatMut<'a, E> { - MatMut { - inner: inner::DenseMut { - inner: self.inner.inner, - __marker: PhantomData, - }, - } - } - - /// Returns an iterator that provides successive chunks of the columns of this matrix, with - /// each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have - /// `chunk_size` columns. - #[inline] - #[track_caller] - pub fn col_chunks( - self, - chunk_size: usize, - ) -> impl 'a + DoubleEndedIterator> { - assert!(chunk_size > 0); - let chunk_count = self.ncols().msrv_div_ceil(chunk_size); - (0..chunk_count).map(move |chunk_idx| { - let pos = chunk_size * chunk_idx; - self.subcols(pos, Ord::min(chunk_size, self.ncols() - pos)) - }) - } - - /// Returns an iterator that provides successive chunks of the rows of this matrix, with - /// each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - #[inline] - #[track_caller] - pub fn row_chunks( - self, - chunk_size: usize, - ) -> impl 'a + DoubleEndedIterator> { - self.transpose() - .col_chunks(chunk_size) - .map(|chunk| chunk.transpose()) - } - - /// Returns a parallel iterator that provides successive chunks of the columns of this - /// matrix, with each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have - /// `chunk_size` columns. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_col_chunks( - self, - chunk_size: usize, - ) -> impl 'a + rayon::iter::IndexedParallelIterator> { - use rayon::prelude::*; - - assert!(chunk_size > 0); - let chunk_count = self.ncols().msrv_div_ceil(chunk_size); - (0..chunk_count).into_par_iter().map(move |chunk_idx| { - let pos = chunk_size * chunk_idx; - self.subcols(pos, Ord::min(chunk_size, self.ncols() - pos)) - }) - } - - /// Returns a parallel iterator that provides successive chunks of the rows of this matrix, - /// with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_row_chunks( - self, - chunk_size: usize, - ) -> impl 'a + rayon::iter::IndexedParallelIterator> { - use rayon::prelude::*; - - self.transpose() - .par_col_chunks(chunk_size) - .map(|chunk| chunk.transpose()) - } - - /// Returns a parallel iterator that provides successive chunks of the rows of this matrix, - /// with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - #[deprecated = "replaced by `MatRef::par_row_chunks`"] - pub fn into_par_row_chunks( - self, - chunk_size: usize, - ) -> impl 'a + rayon::iter::IndexedParallelIterator> { - self.par_row_chunks(chunk_size) - } - } - - impl core::ops::Index<(usize, usize)> for MatRef<'_, E> { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &E { - self.get(row, col) - } - } - - impl core::ops::Index<(usize, usize)> for MatMut<'_, E> { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &E { - (*self).rb().get(row, col) - } - } - - impl core::ops::IndexMut<(usize, usize)> for MatMut<'_, E> { - #[inline] - #[track_caller] - fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut E { - (*self).rb_mut().get_mut(row, col) - } - } - - impl core::ops::Index<(usize, usize)> for Mat { - type Output = E; - - #[inline] - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &E { - self.as_ref().get(row, col) - } - } - - impl core::ops::IndexMut<(usize, usize)> for Mat { - #[inline] - #[track_caller] - fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut E { - self.as_mut().get_mut(row, col) - } - } - - impl<'a, E: Entity> MatMut<'a, E> { - #[track_caller] - #[inline(always)] - #[doc(hidden)] - pub fn try_get_contiguous_col_mut(self, j: usize) -> GroupFor { - assert!(self.row_stride() == 1); - let col = self.col_mut(j); - if col.nrows() == 0 { - E::faer_map( - E::UNIT, - #[inline(always)] - |()| &mut [] as &mut [E::Unit], - ) - } else { - let m = col.nrows(); - E::faer_map( - col.as_ptr_mut(), - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, m) }, - ) - } - } - - /// Returns the number of rows of the matrix. - #[inline(always)] - pub fn nrows(&self) -> usize { - self.inner.inner.nrows - } - /// Returns the number of columns of the matrix. - #[inline(always)] - pub fn ncols(&self) -> usize { - self.inner.inner.ncols - } - - /// Returns pointers to the matrix data. - #[inline(always)] - pub fn as_ptr_mut(self) -> GroupFor { - E::faer_map( - from_copy::(self.inner.inner.ptr), - #[inline(always)] - |ptr| ptr.as_ptr(), - ) - } - - /// Returns the row stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn row_stride(&self) -> isize { - self.inner.inner.row_stride - } - - /// Returns the column stride of the matrix, specified in number of elements, not in bytes. - #[inline(always)] - pub fn col_stride(&self) -> isize { - self.inner.inner.col_stride - } - - /// Returns raw pointers to the element at the given indices. - #[inline(always)] - pub fn ptr_at_mut(self, row: usize, col: usize) -> GroupFor { - let offset = ((row as isize).wrapping_mul(self.inner.inner.row_stride)) - .wrapping_add((col as isize).wrapping_mul(self.inner.inner.col_stride)); - E::faer_map( - self.as_ptr_mut(), - #[inline(always)] - |ptr| ptr.wrapping_offset(offset), - ) - } - - #[inline(always)] - unsafe fn ptr_at_mut_unchecked(self, row: usize, col: usize) -> GroupFor { - let offset = unchecked_add( - unchecked_mul(row, self.inner.inner.row_stride), - unchecked_mul(col, self.inner.inner.col_stride), - ); - E::faer_map( - self.as_ptr_mut(), - #[inline(always)] - |ptr| ptr.offset(offset), - ) - } - - /// Returns raw pointers to the element at the given indices, assuming the provided indices - /// are within the matrix dimensions. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn ptr_inbounds_at_mut( - self, - row: usize, - col: usize, - ) -> GroupFor { - debug_assert!(all(row < self.nrows(), col < self.ncols())); - self.ptr_at_mut_unchecked(row, col) - } - - /// Splits the matrix horizontally and vertically at the given indices into four corners and - /// returns an array of each submatrix, in the following order: - /// * top left. - /// * top right. - /// * bottom left. - /// * bottom right. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row <= self.nrows()`. - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_mut_unchecked( - self, - row: usize, - col: usize, - ) -> (Self, Self, Self, Self) { - let (top_left, top_right, bot_left, bot_right) = - self.into_const().split_at_unchecked(row, col); - ( - top_left.const_cast(), - top_right.const_cast(), - bot_left.const_cast(), - bot_right.const_cast(), - ) - } - - /// Splits the matrix horizontally and vertically at the given indices into four corners and - /// returns an array of each submatrix, in the following order: - /// * top left. - /// * top right. - /// * bottom left. - /// * bottom right. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row <= self.nrows()`. - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn split_at_mut(self, row: usize, col: usize) -> (Self, Self, Self, Self) { - let (top_left, top_right, bot_left, bot_right) = self.into_const().split_at(row, col); - unsafe { - ( - top_left.const_cast(), - top_right.const_cast(), - bot_left.const_cast(), - bot_right.const_cast(), - ) - } - } - - /// Splits the matrix horizontally at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * top. - /// * bottom. - /// - /// # Safety - /// The behavior is undefined if the following condition is violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_row_mut_unchecked(self, row: usize) -> (Self, Self) { - let (top, bot) = self.into_const().split_at_row_unchecked(row); - (top.const_cast(), bot.const_cast()) - } - - /// Splits the matrix horizontally at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * top. - /// * bottom. - /// - /// # Panics - /// The function panics if the following condition is violated: - /// * `row <= self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn split_at_row_mut(self, row: usize) -> (Self, Self) { - let (top, bot) = self.into_const().split_at_row(row); - unsafe { (top.const_cast(), bot.const_cast()) } - } - - /// Splits the matrix vertically at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * left. - /// * right. - /// - /// # Safety - /// The behavior is undefined if the following condition is violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn split_at_col_mut_unchecked(self, col: usize) -> (Self, Self) { - let (left, right) = self.into_const().split_at_col_unchecked(col); - (left.const_cast(), right.const_cast()) - } - - /// Splits the matrix vertically at the given row into two parts and returns an array of - /// each submatrix, in the following order: - /// * left. - /// * right. - /// - /// # Panics - /// The function panics if the following condition is violated: - /// * `col <= self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn split_at_col_mut(self, col: usize) -> (Self, Self) { - let (left, right) = self.into_const().split_at_col(col); - unsafe { (left.const_cast(), right.const_cast()) } - } - - /// Returns mutable references to the element at the given indices, or submatrices if either - /// `row` or `col` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub unsafe fn get_mut_unchecked( - self, - row: RowRange, - col: ColRange, - ) -> >::Target - where - Self: MatIndex, - { - >::get_unchecked(self, row, col) - } - - /// Returns mutable references to the element at the given indices, or submatrices if either - /// `row` or `col` is a range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline(always)] - #[track_caller] - pub fn get_mut( - self, - row: RowRange, - col: ColRange, - ) -> >::Target - where - Self: MatIndex, - { - >::get(self, row, col) - } - - /// Reads the value of the element at the given indices. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { - self.rb().read_unchecked(row, col) - } - - /// Reads the value of the element at the given indices, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, row: usize, col: usize) -> E { - self.rb().read(row, col) - } - - /// Writes the value to the element at the given indices. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn write_unchecked(&mut self, row: usize, col: usize, value: E) { - let units = value.faer_into_units(); - let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(row, col)); - E::faer_map( - zipped, - #[inline(always)] - |(unit, ptr)| *ptr = unit, - ); - } - - /// Writes the value to the element at the given indices, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn write(&mut self, row: usize, col: usize, value: E) { - assert!(all(row < self.nrows(), col < self.ncols())); - unsafe { self.write_unchecked(row, col, value) }; - } - - /// Copies the values from the lower triangular part of `other` into the lower triangular - /// part of `self`. The diagonal part is included. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `self.nrows() == other.nrows()`. - /// * `self.ncols() == other.ncols()`. - /// * `self.nrows() == self.ncols()`. - #[track_caller] - pub fn copy_from_triangular_lower(&mut self, other: impl AsMatRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: MatMut<'_, E>, other: MatRef<'_, E>) { - zipped!(this, other).for_each_triangular_lower( - zip::Diag::Include, - #[inline(always)] - |unzipped!(mut dst, src)| dst.write(src.read()), - ); - } - implementation(self.rb_mut(), other.as_mat_ref()) - } - - /// Copies the values from the lower triangular part of `other` into the lower triangular - /// part of `self`. The diagonal part is excluded. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `self.nrows() == other.nrows()`. - /// * `self.ncols() == other.ncols()`. - /// * `self.nrows() == self.ncols()`. - #[track_caller] - pub fn copy_from_strict_triangular_lower(&mut self, other: impl AsMatRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: MatMut<'_, E>, other: MatRef<'_, E>) { - zipped!(this, other).for_each_triangular_lower( - zip::Diag::Skip, - #[inline(always)] - |unzipped!(mut dst, src)| dst.write(src.read()), - ); - } - implementation(self.rb_mut(), other.as_mat_ref()) - } - - /// Copies the values from the upper triangular part of `other` into the upper triangular - /// part of `self`. The diagonal part is included. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `self.nrows() == other.nrows()`. - /// * `self.ncols() == other.ncols()`. - /// * `self.nrows() == self.ncols()`. - #[track_caller] - #[inline(always)] - pub fn copy_from_triangular_upper(&mut self, other: impl AsMatRef) { - (*self) - .rb_mut() - .transpose_mut() - .copy_from_triangular_lower(other.as_mat_ref().transpose()) - } - - /// Copies the values from the upper triangular part of `other` into the upper triangular - /// part of `self`. The diagonal part is excluded. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `self.nrows() == other.nrows()`. - /// * `self.ncols() == other.ncols()`. - /// * `self.nrows() == self.ncols()`. - #[track_caller] - #[inline(always)] - pub fn copy_from_strict_triangular_upper(&mut self, other: impl AsMatRef) { - (*self) - .rb_mut() - .transpose_mut() - .copy_from_strict_triangular_lower(other.as_mat_ref().transpose()) - } - - /// Copies the values from `other` into `self`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `self.nrows() == other.nrows()`. - /// * `self.ncols() == other.ncols()`. - #[track_caller] - pub fn copy_from(&mut self, other: impl AsMatRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: MatMut<'_, E>, other: MatRef<'_, E>) { - zipped!(this, other).for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - } - implementation(self.rb_mut(), other.as_mat_ref()) - } - - /// Fills the elements of `self` with zeros. - #[track_caller] - pub fn fill_zero(&mut self) - where - E: ComplexField, - { - zipped!(self.rb_mut()).for_each( - #[inline(always)] - |unzipped!(mut x)| x.write(E::faer_zero()), - ); - } - - /// Fills the elements of `self` with copies of `constant`. - #[track_caller] - pub fn fill(&mut self, constant: E) { - zipped!((*self).rb_mut()).for_each( - #[inline(always)] - |unzipped!(mut x)| x.write(constant), - ); - } - - /// Returns a view over the transpose of `self`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_mut(); - /// let transpose = view.transpose_mut(); - /// - /// let mut expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; - /// assert_eq!(expected.as_mut(), transpose); - /// ``` - #[inline(always)] - #[must_use] - pub fn transpose_mut(self) -> Self { - unsafe { - mat::from_raw_parts_mut( - E::faer_map( - from_copy::(self.inner.inner.ptr), - #[inline(always)] - |ptr| ptr.as_ptr(), - ), - self.ncols(), - self.nrows(), - self.col_stride(), - self.row_stride(), - ) - } - } - - /// Returns a view over the conjugate of `self`. - #[inline(always)] - #[must_use] - pub fn conjugate_mut(self) -> MatMut<'a, E::Conj> - where - E: Conjugate, - { - unsafe { self.into_const().conjugate().const_cast() } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline(always)] - #[must_use] - pub fn adjoint_mut(self) -> MatMut<'a, E::Conj> - where - E: Conjugate, - { - self.transpose_mut().conjugate_mut() - } - - /// Returns a view over the canonical representation of `self`, as well as a flag declaring - /// whether `self` is implicitly conjugated or not. - #[inline(always)] - #[must_use] - pub fn canonicalize_mut(self) -> (MatMut<'a, E::Canonical>, Conj) - where - E: Conjugate, - { - let (canonical, conj) = self.into_const().canonicalize(); - unsafe { (canonical.const_cast(), conj) } - } - - /// Returns a view over the `self`, with the rows in reversed order. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_mut(); - /// let reversed_rows = view.reverse_rows_mut(); - /// - /// let mut expected = mat![[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]; - /// assert_eq!(expected.as_mut(), reversed_rows); - /// ``` - #[inline(always)] - #[must_use] - pub fn reverse_rows_mut(self) -> Self { - unsafe { self.into_const().reverse_rows().const_cast() } - } - - /// Returns a view over the `self`, with the columns in reversed order. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_mut(); - /// let reversed_cols = view.reverse_cols_mut(); - /// - /// let mut expected = mat![[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]; - /// assert_eq!(expected.as_mut(), reversed_cols); - /// ``` - #[inline(always)] - #[must_use] - pub fn reverse_cols_mut(self) -> Self { - unsafe { self.into_const().reverse_cols().const_cast() } - } - - /// Returns a view over the `self`, with the rows and the columns in reversed order. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// let view = matrix.as_mut(); - /// let reversed = view.reverse_rows_and_cols_mut(); - /// - /// let mut expected = mat![[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]; - /// assert_eq!(expected.as_mut(), reversed); - /// ``` - #[inline(always)] - #[must_use] - pub fn reverse_rows_and_cols_mut(self) -> Self { - unsafe { self.into_const().reverse_rows_and_cols().const_cast() } - } - - /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with - /// dimensions `(nrows, ncols)`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `col_start <= self.ncols()`. - /// * `nrows <= self.nrows() - row_start`. - /// * `ncols <= self.ncols() - col_start`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut matrix = mat![ - /// [1.0, 5.0, 9.0], - /// [2.0, 6.0, 10.0], - /// [3.0, 7.0, 11.0], - /// [4.0, 8.0, 12.0f64], - /// ]; - /// - /// let view = matrix.as_mut(); - /// let submatrix = view.submatrix_mut(2, 1, 2, 2); - /// - /// let mut expected = mat![[7.0, 11.0], [8.0, 12.0f64]]; - /// assert_eq!(expected.as_mut(), submatrix); - /// ``` - #[track_caller] - #[inline(always)] - pub fn submatrix_mut( - self, - row_start: usize, - col_start: usize, - nrows: usize, - ncols: usize, - ) -> Self { - unsafe { - self.into_const() - .submatrix(row_start, col_start, nrows, ncols) - .const_cast() - } - } - - /// Returns a view over the submatrix starting at row `row_start`, and with number of rows - /// `nrows`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row_start <= self.nrows()`. - /// * `nrows <= self.nrows() - row_start`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut matrix = mat![ - /// [1.0, 5.0, 9.0], - /// [2.0, 6.0, 10.0], - /// [3.0, 7.0, 11.0], - /// [4.0, 8.0, 12.0f64], - /// ]; - /// - /// let view = matrix.as_mut(); - /// let subrows = view.subrows_mut(1, 2); - /// - /// let mut expected = mat![[2.0, 6.0, 10.0], [3.0, 7.0, 11.0],]; - /// assert_eq!(expected.as_mut(), subrows); - /// ``` - #[track_caller] - #[inline(always)] - pub fn subrows_mut(self, row_start: usize, nrows: usize) -> Self { - unsafe { self.into_const().subrows(row_start, nrows).const_cast() } - } - - /// Returns a view over the submatrix starting at column `col_start`, and with number of - /// columns `ncols`. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col_start <= self.ncols()`. - /// * `ncols <= self.ncols() - col_start`. - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut matrix = mat![ - /// [1.0, 5.0, 9.0], - /// [2.0, 6.0, 10.0], - /// [3.0, 7.0, 11.0], - /// [4.0, 8.0, 12.0f64], - /// ]; - /// - /// let view = matrix.as_mut(); - /// let subcols = view.subcols_mut(2, 1); - /// - /// let mut expected = mat![[9.0], [10.0], [11.0], [12.0f64]]; - /// assert_eq!(expected.as_mut(), subcols); - /// ``` - #[track_caller] - #[inline(always)] - pub fn subcols_mut(self, col_start: usize, ncols: usize) -> Self { - unsafe { self.into_const().subcols(col_start, ncols).const_cast() } - } - - /// Returns a view over the row at the given index. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row_idx < self.nrows()`. - #[track_caller] - #[inline(always)] - pub fn row_mut(self, row_idx: usize) -> RowMut<'a, E> { - unsafe { self.into_const().row(row_idx).const_cast() } - } - - /// Returns a view over the column at the given index. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col_idx < self.ncols()`. - #[track_caller] - #[inline(always)] - pub fn col_mut(self, col_idx: usize) -> ColMut<'a, E> { - unsafe { self.into_const().col(col_idx).const_cast() } - } - - /// Given a matrix with a single column, returns an object that interprets - /// the column as a diagonal matrix, whoes diagonal elements are values in the column. - #[track_caller] - #[inline(always)] - pub fn column_vector_as_diagonal_mut(self) -> Matrix> { - assert!(self.ncols() == 1); - Matrix { - inner: inner::DiagMut { - inner: self.col_mut(0), - }, - } - } - - /// Returns the diagonal of the matrix. - #[inline(always)] - pub fn diagonal_mut(self) -> Matrix> { - let size = self.nrows().min(self.ncols()); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - unsafe { - Matrix { - inner: inner::DiagMut { - inner: col::from_raw_parts_mut( - self.as_ptr_mut(), - size, - row_stride + col_stride, - ), - }, - } - } - } - - /// Returns an owning [`Mat`] of the data - #[inline] - pub fn to_owned(&self) -> Mat - where - E: Conjugate, - { - self.rb().to_owned() - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - self.rb().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - self.rb().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).rb()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).rb()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).rb()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } - - /// Returns a view over the matrix. - #[inline] - pub fn as_ref(&self) -> MatRef<'_, E> { - self.rb() - } - - /// Returns a mutable view over the matrix. - #[inline] - pub fn as_mut(&mut self) -> MatMut<'_, E> { - self.rb_mut() - } - - /// Returns an iterator that provides successive chunks of the columns of this matrix, with - /// each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have - /// `chunk_size` columns. - #[inline] - #[track_caller] - pub fn col_chunks_mut( - self, - chunk_size: usize, - ) -> impl 'a + DoubleEndedIterator> { - self.into_const() - .col_chunks(chunk_size) - .map(|chunk| unsafe { chunk.const_cast() }) - } - - /// Returns an iterator that provides successive chunks of the rows of this matrix, - /// with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - #[inline] - #[track_caller] - pub fn row_chunks_mut( - self, - chunk_size: usize, - ) -> impl 'a + DoubleEndedIterator> { - self.into_const() - .row_chunks(chunk_size) - .map(|chunk| unsafe { chunk.const_cast() }) - } - - /// Returns a parallel iterator that provides successive chunks of the columns of this - /// matrix, with each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have - /// `chunk_size` columns. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_col_chunks_mut( - self, - chunk_size: usize, - ) -> impl 'a + rayon::iter::IndexedParallelIterator> { - use rayon::prelude::*; - self.into_const() - .par_col_chunks(chunk_size) - .map(|chunk| unsafe { chunk.const_cast() }) - } - - /// Returns a parallel iterator that provides successive chunks of the rows of this matrix, - /// with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_row_chunks_mut( - self, - chunk_size: usize, - ) -> impl 'a + rayon::iter::IndexedParallelIterator> { - use rayon::prelude::*; - self.into_const() - .par_row_chunks(chunk_size) - .map(|chunk| unsafe { chunk.const_cast() }) - } - } - - impl<'a, E: RealField> MatRef<'a, Complex> { - /// Returns the real and imaginary components of `self`. - #[inline(always)] - pub fn real_imag(self) -> Complex> { - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - let nrows = self.nrows(); - let ncols = self.ncols(); - let Complex { re, im } = self.as_ptr(); - unsafe { - Complex { - re: mat::from_raw_parts(re, nrows, ncols, row_stride, col_stride), - im: mat::from_raw_parts(im, nrows, ncols, row_stride, col_stride), - } - } - } - } - - impl<'a, E: RealField> MatMut<'a, Complex> { - /// Returns the real and imaginary components of `self`. - #[inline(always)] - pub fn real_imag_mut(self) -> Complex> { - let Complex { re, im } = self.into_const().real_imag(); - unsafe { - Complex { - re: re.const_cast(), - im: im.const_cast(), - } - } - } - } -}; - -#[repr(C)] -struct RawMatUnit { - ptr: NonNull, - row_capacity: usize, - col_capacity: usize, -} - -impl RawMatUnit { - pub fn new(row_capacity: usize, col_capacity: usize) -> Self { - let dangling = NonNull::::dangling(); - if core::mem::size_of::() == 0 { - Self { - ptr: dangling, - row_capacity, - col_capacity, - } - } else { - let cap = row_capacity - .checked_mul(col_capacity) - .unwrap_or_else(capacity_overflow); - let cap_bytes = cap - .checked_mul(core::mem::size_of::()) - .unwrap_or_else(capacity_overflow); - if cap_bytes > isize::MAX as usize { - capacity_overflow::<()>(); - } - - use alloc::alloc::{alloc, handle_alloc_error, Layout}; - - let layout = Layout::from_size_align(cap_bytes, align_for::()) - .ok() - .unwrap_or_else(capacity_overflow); - - let ptr = if layout.size() == 0 { - dangling - } else { - // SAFETY: we checked that layout has non zero size - let ptr = unsafe { alloc(layout) } as *mut T; - if ptr.is_null() { - handle_alloc_error(layout) - } else { - // SAFETY: we checked that the pointer is not null - unsafe { NonNull::::new_unchecked(ptr) } - } - }; - - Self { - ptr, - row_capacity, - col_capacity, - } - } - } -} - -impl Drop for RawMatUnit { - fn drop(&mut self) { - use alloc::alloc::{dealloc, Layout}; - // this cannot overflow because we already allocated this much memory - // self.row_capacity.wrapping_mul(self.col_capacity) may overflow if T is a zst - // but that's fine since we immediately multiply it by 0. - let alloc_size = - self.row_capacity.wrapping_mul(self.col_capacity) * core::mem::size_of::(); - if alloc_size != 0 { - // SAFETY: pointer was allocated with alloc::alloc::alloc - unsafe { - dealloc( - self.ptr.as_ptr() as *mut u8, - Layout::from_size_align_unchecked(alloc_size, align_for::()), - ); - } - } - } -} - -#[repr(C)] -struct RawMat { - ptr: GroupCopyFor>, - row_capacity: usize, - col_capacity: usize, -} - -#[cold] -fn capacity_overflow_impl() -> ! { - panic!("capacity overflow") -} - -#[inline(always)] -fn capacity_overflow() -> T { - capacity_overflow_impl(); -} - -#[doc(hidden)] -#[inline(always)] -pub fn is_vectorizable() -> bool { - coe::is_same::() - || coe::is_same::() - || coe::is_same::() - || coe::is_same::() - || coe::is_same::() - || coe::is_same::() -} - -// https://rust-lang.github.io/hashbrown/src/crossbeam_utils/cache_padded.rs.html#128-130 -#[doc(hidden)] -pub const CACHELINE_ALIGN: usize = { - #[cfg(any( - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "powerpc64", - ))] - { - 128 - } - #[cfg(any( - target_arch = "arm", - target_arch = "mips", - target_arch = "mips64", - target_arch = "riscv64", - ))] - { - 32 - } - #[cfg(target_arch = "s390x")] - { - 256 - } - #[cfg(not(any( - target_arch = "x86_64", - target_arch = "aarch64", - target_arch = "powerpc64", - target_arch = "arm", - target_arch = "mips", - target_arch = "mips64", - target_arch = "riscv64", - target_arch = "s390x", - )))] - { - 64 - } -}; - -#[doc(hidden)] -#[inline(always)] -pub fn align_for() -> usize { - if is_vectorizable::() { - Ord::max( - core::mem::size_of::(), - Ord::max(core::mem::align_of::(), CACHELINE_ALIGN), - ) - } else { - core::mem::align_of::() - } -} - -impl RawMat { - pub fn new(row_capacity: usize, col_capacity: usize) -> Self { - // allocate the unit matrices - let group = E::faer_map(E::UNIT, |()| { - RawMatUnit::::new(row_capacity, col_capacity) - }); - - let group = E::faer_map(group, ManuallyDrop::new); - - Self { - ptr: into_copy::(E::faer_map(group, |mat| mat.ptr)), - row_capacity, - col_capacity, - } - } -} - -impl Drop for RawMat { - fn drop(&mut self) { - drop(E::faer_map(from_copy::(self.ptr), |ptr| RawMatUnit { - ptr, - row_capacity: self.row_capacity, - col_capacity: self.col_capacity, - })); - } -} - -/// Heap allocated resizable matrix, similar to a 2D [`Vec`]. -/// -/// # Note -/// -/// The memory layout of `Mat` is guaranteed to be column-major, meaning that it has a row stride -/// of `1`, and an unspecified column stride that can be queried with [`Mat::col_stride`]. -/// -/// This implies that while each individual column is stored contiguously in memory, the matrix as -/// a whole may not necessarily be contiguous. The implementation may add padding at the end of -/// each column when overaligning each column can provide a performance gain. -/// -/// Let us consider a 3×4 matrix -/// -/// ```notcode -/// 0 │ 3 │ 6 │ 9 -/// ───┼───┼───┼─── -/// 1 │ 4 │ 7 │ 10 -/// ───┼───┼───┼─── -/// 2 │ 5 │ 8 │ 11 -/// ``` -/// The memory representation of the data held by such a matrix could look like the following: -/// -/// ```notcode -/// 0 1 2 X 3 4 5 X 6 7 8 X 9 10 11 X -/// ``` -/// -/// where X represents padding elements. -pub type Mat = Matrix>; - -/// Heap allocated resizable column vector. -/// -/// # Note -/// -/// The memory layout of `Col` is guaranteed to be column-major, meaning that it has a row stride -/// of `1`. -pub type Col = Matrix>; - -/// Heap allocated resizable row vector. -/// -/// # Note -/// -/// The memory layout of `Col` is guaranteed to be row-major, meaning that it has a column stride -/// of `1`. -pub type Row = Matrix>; - -#[repr(C)] -struct MatUnit { - raw: RawMatUnit, - nrows: usize, - ncols: usize, -} - -impl Clone for Mat { - fn clone(&self) -> Self { - let this = self.as_ref(); - unsafe { - Self::from_fn(self.nrows(), self.ncols(), |i, j| { - E::faer_from_units(E::faer_deref(this.get_unchecked(i, j))) - }) - } - } -} - -impl MatUnit { - #[cold] - fn do_reserve_exact(&mut self, mut new_row_capacity: usize, mut new_col_capacity: usize) { - new_row_capacity = self.raw.row_capacity.max(new_row_capacity); - new_col_capacity = self.raw.col_capacity.max(new_col_capacity); - - let new_ptr = if self.raw.row_capacity == new_row_capacity - && self.raw.row_capacity != 0 - && self.raw.col_capacity != 0 - { - // case 1: - // we have enough row capacity, and we've already allocated memory. - // use realloc to get extra column memory - - use alloc::alloc::{handle_alloc_error, realloc, Layout}; - - // this shouldn't overflow since we already hold this many bytes - let old_cap = self.raw.row_capacity * self.raw.col_capacity; - let old_cap_bytes = old_cap * core::mem::size_of::(); - - let new_cap = new_row_capacity - .checked_mul(new_col_capacity) - .unwrap_or_else(capacity_overflow); - let new_cap_bytes = new_cap - .checked_mul(core::mem::size_of::()) - .unwrap_or_else(capacity_overflow); - - if new_cap_bytes > isize::MAX as usize { - capacity_overflow::<()>(); - } - - // SAFETY: this shouldn't overflow since we already checked that it's valid during - // allocation - let old_layout = - unsafe { Layout::from_size_align_unchecked(old_cap_bytes, align_for::()) }; - let new_layout = Layout::from_size_align(new_cap_bytes, align_for::()) - .ok() - .unwrap_or_else(capacity_overflow); - - // SAFETY: - // * old_ptr is non null and is the return value of some previous call to alloc - // * old_layout is the same layout that was used to provide the old allocation - // * new_cap_bytes is non zero since new_row_capacity and new_col_capacity are larger - // than self.raw.row_capacity and self.raw.col_capacity respectively, and the computed - // product doesn't overflow. - // * new_cap_bytes, when rounded up to the nearest multiple of the alignment does not - // overflow, since we checked that we can create new_layout with it. - unsafe { - let old_ptr = self.raw.ptr.as_ptr(); - let new_ptr = realloc(old_ptr as *mut u8, old_layout, new_cap_bytes); - if new_ptr.is_null() { - handle_alloc_error(new_layout); - } - new_ptr as *mut T - } - } else { - // case 2: - // use alloc and move stuff manually. - - // allocate new memory region - let new_ptr = { - let m = ManuallyDrop::new(RawMatUnit::::new(new_row_capacity, new_col_capacity)); - m.ptr.as_ptr() - }; - - let old_ptr = self.raw.ptr.as_ptr(); - - // copy each column to new matrix - for j in 0..self.ncols { - // SAFETY: - // * pointer offsets can't overflow since they're within an already allocated - // memory region less than isize::MAX bytes in size. - // * new and old allocation can't overlap, so copy_nonoverlapping is fine here. - unsafe { - let old_ptr = old_ptr.add(j * self.raw.row_capacity); - let new_ptr = new_ptr.add(j * new_row_capacity); - core::ptr::copy_nonoverlapping(old_ptr, new_ptr, self.nrows); - } - } - - // deallocate old matrix memory - let _ = RawMatUnit:: { - // SAFETY: this ptr was checked to be non null, or was acquired from a NonNull - // pointer. - ptr: unsafe { NonNull::new_unchecked(old_ptr) }, - row_capacity: self.raw.row_capacity, - col_capacity: self.raw.col_capacity, - }; - - new_ptr - }; - self.raw.row_capacity = new_row_capacity; - self.raw.col_capacity = new_col_capacity; - self.raw.ptr = unsafe { NonNull::::new_unchecked(new_ptr) }; - } -} - -impl Drop for inner::DenseOwn { - fn drop(&mut self) { - drop(RawMat:: { - ptr: self.inner.ptr, - row_capacity: self.row_capacity, - col_capacity: self.col_capacity, - }); - } -} -impl Drop for inner::DenseColOwn { - fn drop(&mut self) { - drop(RawMat:: { - ptr: self.inner.ptr, - row_capacity: self.row_capacity, - col_capacity: 1, - }); - } -} -impl Drop for inner::DenseRowOwn { - fn drop(&mut self) { - drop(RawMat:: { - ptr: self.inner.ptr, - row_capacity: self.col_capacity, - col_capacity: 1, - }); - } -} - -impl Default for Mat { - #[inline] - fn default() -> Self { - Self::new() - } -} -impl Default for Col { - #[inline] - fn default() -> Self { - Self::new() - } -} -impl Default for Row { - #[inline] - fn default() -> Self { - Self::new() - } -} - -impl Col { - /// Returns an empty column of dimension `0`. - #[inline] - pub fn new() -> Self { - Self { - inner: inner::DenseColOwn { - inner: VecOwnImpl { - ptr: into_copy::(E::faer_map(E::UNIT, |()| { - NonNull::::dangling() - })), - len: 0, - }, - row_capacity: 0, - }, - } - } - - /// Returns a new column vector with 0 rows, with enough capacity to hold a maximum of - /// `row_capacity` rows columns without reallocating. If `row_capacity` is `0`, - /// the function will not allocate. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn with_capacity(row_capacity: usize) -> Self { - let raw = ManuallyDrop::new(RawMat::::new(row_capacity, 1)); - Self { - inner: inner::DenseColOwn { - inner: VecOwnImpl { - ptr: raw.ptr, - len: 0, - }, - row_capacity: raw.row_capacity, - }, - } - } - - /// Returns a new matrix with number of rows `nrows`, filled with the provided function. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn from_fn(nrows: usize, f: impl FnMut(usize) -> E) -> Self { - let mut this = Self::new(); - this.resize_with(nrows, f); - this - } - - /// Returns a new matrix with number of rows `nrows`, filled with zeros. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn zeros(nrows: usize) -> Self - where - E: ComplexField, - { - Self::from_fn(nrows, |_| E::faer_zero()) - } - - /// Returns the number of rows of the column. - #[inline(always)] - pub fn nrows(&self) -> usize { - self.inner.inner.len - } - /// Returns the number of columns of the column. This is always equal to `1`. - #[inline(always)] - pub fn ncols(&self) -> usize { - 1 - } - - /// Set the dimensions of the matrix. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `nrows < self.row_capacity()`. - /// * The elements that were previously out of bounds but are now in bounds must be - /// initialized. - #[inline] - pub unsafe fn set_nrows(&mut self, nrows: usize) { - self.inner.inner.len = nrows; - } - - /// Returns a pointer to the data of the matrix. - #[inline] - pub fn as_ptr(&self) -> GroupFor { - E::faer_map(from_copy::(self.inner.inner.ptr), |ptr| { - ptr.as_ptr() as *const E::Unit - }) - } - - /// Returns a mutable pointer to the data of the matrix. - #[inline] - pub fn as_ptr_mut(&mut self) -> GroupFor { - E::faer_map(from_copy::(self.inner.inner.ptr), |ptr| ptr.as_ptr()) - } - - /// Returns the row capacity, that is, the number of rows that the matrix is able to hold - /// without needing to reallocate, excluding column insertions. - #[inline] - pub fn row_capacity(&self) -> usize { - self.inner.row_capacity - } - - /// Returns the offset between the first elements of two successive rows in the matrix. - /// Always returns `1` since the matrix is column major. - #[inline] - pub fn row_stride(&self) -> isize { - 1 - } - - #[cold] - fn do_reserve_exact(&mut self, mut new_row_capacity: usize) { - if is_vectorizable::() { - let align_factor = align_for::() / core::mem::size_of::(); - new_row_capacity = new_row_capacity - .msrv_checked_next_multiple_of(align_factor) - .unwrap(); - } - - let nrows = self.inner.inner.len; - let old_row_capacity = self.inner.row_capacity; - - let mut this = ManuallyDrop::new(core::mem::take(self)); - { - let mut this_group = - E::faer_map(from_copy::(this.inner.inner.ptr), |ptr| MatUnit { - raw: RawMatUnit { - ptr, - row_capacity: old_row_capacity, - col_capacity: 1, - }, - nrows, - ncols: 1, - }); - - E::faer_map(E::faer_as_mut(&mut this_group), |mat_unit| { - mat_unit.do_reserve_exact(new_row_capacity, 1); - }); - - let this_group = E::faer_map(this_group, ManuallyDrop::new); - this.inner.inner.ptr = - into_copy::(E::faer_map(this_group, |mat_unit| mat_unit.raw.ptr)); - this.inner.row_capacity = new_row_capacity; - } - *self = ManuallyDrop::into_inner(this); - } - - /// Reserves the minimum capacity for `row_capacity` rows without reallocating. Does nothing if - /// the capacity is already sufficient. - /// - /// # Panics - /// The function panics if the new total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn reserve_exact(&mut self, row_capacity: usize) { - if self.row_capacity() >= row_capacity { - // do nothing - } else if core::mem::size_of::() == 0 { - self.inner.row_capacity = self.row_capacity().max(row_capacity); - } else { - self.do_reserve_exact(row_capacity); - } - } - - unsafe fn insert_block_with E>( - &mut self, - f: &mut F, - row_start: usize, - row_end: usize, - ) { - debug_assert!(row_start <= row_end); - - let ptr = self.as_ptr_mut(); - - for i in row_start..row_end { - // SAFETY: - // * pointer to element at index (i, j), which is within the - // allocation since we reserved enough space - // * writing to this memory region is sound since it is properly - // aligned and valid for writes - let ptr_ij = E::faer_map(E::faer_copy(&ptr), |ptr| ptr.add(i)); - let value = E::faer_into_units(f(i)); - - E::faer_map(E::faer_zip(ptr_ij, value), |(ptr_ij, value)| { - core::ptr::write(ptr_ij, value) - }); - } - } - - fn erase_last_rows(&mut self, new_nrows: usize) { - let old_nrows = self.nrows(); - debug_assert!(new_nrows <= old_nrows); - self.inner.inner.len = new_nrows; - } - - unsafe fn insert_last_rows_with E>(&mut self, f: &mut F, new_nrows: usize) { - let old_nrows = self.nrows(); - - debug_assert!(new_nrows > old_nrows); - - self.insert_block_with(f, old_nrows, new_nrows); - self.inner.inner.len = new_nrows; - } - - /// Resizes the vector in-place so that the new number of rows is `new_nrows`. - /// New elements are created with the given function `f`, so that elements at index `i` - /// are created by calling `f(i)`. - pub fn resize_with(&mut self, new_nrows: usize, f: impl FnMut(usize) -> E) { - let mut f = f; - let old_nrows = self.nrows(); - - if new_nrows <= old_nrows { - self.erase_last_rows(new_nrows); - } else { - self.reserve_exact(new_nrows); - unsafe { - self.insert_last_rows_with(&mut f, new_nrows); - } - } - } - - /// Returns a reference to a slice over the column. - #[inline] - #[track_caller] - pub fn as_slice(&self) -> GroupFor { - let nrows = self.nrows(); - let ptr = self.as_ref().as_ptr(); - E::faer_map( - ptr, - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts(ptr, nrows) }, - ) - } - - /// Returns a mutable reference to a slice over the column. - #[inline] - #[track_caller] - pub fn as_slice_mut(&mut self) -> GroupFor { - let nrows = self.nrows(); - let ptr = self.as_ptr_mut(); - E::faer_map( - ptr, - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, nrows) }, - ) - } - - /// Returns a view over the vector. - #[inline] - pub fn as_ref(&self) -> ColRef<'_, E> { - unsafe { col::from_raw_parts(self.as_ptr(), self.nrows(), 1) } - } - - /// Returns a mutable view over the vector. - #[inline] - pub fn as_mut(&mut self) -> ColMut<'_, E> { - unsafe { col::from_raw_parts_mut(self.as_ptr_mut(), self.nrows(), 1) } - } - - /// Returns references to the element at the given index, or submatrices if `row` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline] - pub unsafe fn get_unchecked( - &self, - row: RowRange, - ) -> as ColIndex>::Target - where - for<'a> ColRef<'a, E>: ColIndex, - { - self.as_ref().get_unchecked(row) - } - - /// Returns references to the element at the given index, or submatrices if `row` is a range, - /// with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline] - pub fn get(&self, row: RowRange) -> as ColIndex>::Target - where - for<'a> ColRef<'a, E>: ColIndex, - { - self.as_ref().get(row) - } - - /// Returns mutable references to the element at the given index, or submatrices if - /// `row` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline] - pub unsafe fn get_mut_unchecked( - &mut self, - row: RowRange, - ) -> as ColIndex>::Target - where - for<'a> ColMut<'a, E>: ColIndex, - { - self.as_mut().get_unchecked_mut(row) - } - - /// Returns mutable references to the element at the given index, or submatrices if - /// `row` is a range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - #[inline] - pub fn get_mut( - &mut self, - row: RowRange, - ) -> as ColIndex>::Target - where - for<'a> ColMut<'a, E>: ColIndex, - { - self.as_mut().get_mut(row) - } - - /// Reads the value of the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, row: usize) -> E { - self.as_ref().read_unchecked(row) - } - - /// Reads the value of the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, row: usize) -> E { - self.as_ref().read(row) - } - - /// Writes the value to the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub unsafe fn write_unchecked(&mut self, row: usize, value: E) { - self.as_mut().write_unchecked(row, value); - } - - /// Writes the value to the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - #[inline(always)] - #[track_caller] - pub fn write(&mut self, row: usize, value: E) { - self.as_mut().write(row, value); - } - - /// Copies the values from `other` into `self`. - #[inline(always)] - #[track_caller] - pub fn copy_from(&mut self, other: impl AsColRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: &mut Col, other: ColRef<'_, E>) { - let mut mat = Col::::new(); - mat.resize_with( - other.nrows(), - #[inline(always)] - |row| unsafe { other.read_unchecked(row) }, - ); - *this = mat; - } - implementation(self, other.as_col_ref()); - } - - /// Fills the elements of `self` with zeros. - #[inline(always)] - #[track_caller] - pub fn fill_zero(&mut self) - where - E: ComplexField, - { - self.as_mut().fill_zero() - } - - /// Fills the elements of `self` with copies of `constant`. - #[inline(always)] - #[track_caller] - pub fn fill(&mut self, constant: E) { - self.as_mut().fill(constant) - } - - /// Returns a view over the transpose of `self`. - #[inline] - pub fn transpose(&self) -> RowRef<'_, E> { - self.as_ref().transpose() - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn conjugate(&self) -> ColRef<'_, E::Conj> - where - E: Conjugate, - { - self.as_ref().conjugate() - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn adjoint(&self) -> RowRef<'_, E::Conj> - where - E: Conjugate, - { - self.as_ref().adjoint() - } - - /// Returns an owning [`Col`] of the data - #[inline] - pub fn to_owned(&self) -> Col - where - E: Conjugate, - { - self.as_ref().to_owned() - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - self.as_ref().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - self.as_ref().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).as_ref().as_2d()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).as_ref().as_2d()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).as_ref().as_2d()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } -} - -impl Row { - /// Returns an empty row of dimension `0`. - #[inline] - pub fn new() -> Self { - Self { - inner: inner::DenseRowOwn { - inner: VecOwnImpl { - ptr: into_copy::(E::faer_map(E::UNIT, |()| { - NonNull::::dangling() - })), - len: 0, - }, - col_capacity: 0, - }, - } - } - - /// Returns a new column vector with 0 columns, with enough capacity to hold a maximum of - /// `col_capacity` columnss columns without reallocating. If `col_capacity` is `0`, - /// the function will not allocate. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn with_capacity(col_capacity: usize) -> Self { - let raw = ManuallyDrop::new(RawMat::::new(col_capacity, 1)); - Self { - inner: inner::DenseRowOwn { - inner: VecOwnImpl { - ptr: raw.ptr, - len: 0, - }, - col_capacity: raw.row_capacity, - }, - } - } - - /// Returns a new matrix with number of columns `ncols`, filled with the provided function. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn from_fn(ncols: usize, f: impl FnMut(usize) -> E) -> Self { - let mut this = Self::new(); - this.resize_with(ncols, f); - this - } - - /// Returns a new matrix with number of columns `ncols`, filled with zeros. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn zeros(ncols: usize) -> Self - where - E: ComplexField, - { - Self::from_fn(ncols, |_| E::faer_zero()) - } - - /// Returns the number of rows of the row. This is always equal to `1`. - #[inline(always)] - pub fn nrows(&self) -> usize { - 1 - } - /// Returns the number of columns of the row. - #[inline(always)] - pub fn ncols(&self) -> usize { - self.inner.inner.len - } - - /// Set the dimensions of the matrix. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `ncols < self.col_capacity()`. - /// * The elements that were previously out of bounds but are now in bounds must be - /// initialized. - #[inline] - pub unsafe fn set_ncols(&mut self, ncols: usize) { - self.inner.inner.len = ncols; - } - - /// Returns a pointer to the data of the matrix. - #[inline] - pub fn as_ptr(&self) -> GroupFor { - E::faer_map(from_copy::(self.inner.inner.ptr), |ptr| { - ptr.as_ptr() as *const E::Unit - }) - } - - /// Returns a mutable pointer to the data of the matrix. - #[inline] - pub fn as_ptr_mut(&mut self) -> GroupFor { - E::faer_map(from_copy::(self.inner.inner.ptr), |ptr| ptr.as_ptr()) - } - - /// Returns the col capacity, that is, the number of cols that the matrix is able to hold - /// without needing to reallocate, excluding column insertions. - #[inline] - pub fn col_capacity(&self) -> usize { - self.inner.col_capacity - } - - /// Returns the offset between the first elements of two successive columns in the matrix. - /// Always returns `1` since the matrix is column major. - #[inline] - pub fn col_stride(&self) -> isize { - 1 - } - - #[cold] - fn do_reserve_exact(&mut self, mut new_col_capacity: usize) { - if is_vectorizable::() { - let align_factor = align_for::() / core::mem::size_of::(); - new_col_capacity = new_col_capacity - .msrv_checked_next_multiple_of(align_factor) - .unwrap(); - } - - let ncols = self.inner.inner.len; - let old_col_capacity = self.inner.col_capacity; - - let mut this = ManuallyDrop::new(core::mem::take(self)); - { - let mut this_group = - E::faer_map(from_copy::(this.inner.inner.ptr), |ptr| MatUnit { - raw: RawMatUnit { - ptr, - row_capacity: old_col_capacity, - col_capacity: 1, - }, - ncols, - nrows: 1, - }); - - E::faer_map(E::faer_as_mut(&mut this_group), |mat_unit| { - mat_unit.do_reserve_exact(new_col_capacity, 1); - }); - - let this_group = E::faer_map(this_group, ManuallyDrop::new); - this.inner.inner.ptr = - into_copy::(E::faer_map(this_group, |mat_unit| mat_unit.raw.ptr)); - this.inner.col_capacity = new_col_capacity; - } - *self = ManuallyDrop::into_inner(this); - } - - /// Reserves the minimum capacity for `col_capacity` columns without reallocating. Does nothing - /// if the capacity is already sufficient. - /// - /// # Panics - /// The function panics if the new total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn reserve_exact(&mut self, col_capacity: usize) { - if self.col_capacity() >= col_capacity { - // do nothing - } else if core::mem::size_of::() == 0 { - self.inner.col_capacity = self.col_capacity().max(col_capacity); - } else { - self.do_reserve_exact(col_capacity); - } - } - - unsafe fn insert_block_with E>( - &mut self, - f: &mut F, - col_start: usize, - col_end: usize, - ) { - debug_assert!(col_start <= col_end); - - let ptr = self.as_ptr_mut(); - - for j in col_start..col_end { - // SAFETY: - // * pointer to element at index (i, j), which is within the - // allocation since we reserved enough space - // * writing to this memory region is sound since it is properly - // aligned and valid for writes - let ptr_ij = E::faer_map(E::faer_copy(&ptr), |ptr| ptr.add(j)); - let value = E::faer_into_units(f(j)); - - E::faer_map(E::faer_zip(ptr_ij, value), |(ptr_ij, value)| { - core::ptr::write(ptr_ij, value) - }); - } - } - - fn erase_last_cols(&mut self, new_ncols: usize) { - let old_ncols = self.ncols(); - debug_assert!(new_ncols <= old_ncols); - self.inner.inner.len = new_ncols; - } - - unsafe fn insert_last_cols_with E>(&mut self, f: &mut F, new_ncols: usize) { - let old_ncols = self.ncols(); - - debug_assert!(new_ncols > old_ncols); - - self.insert_block_with(f, old_ncols, new_ncols); - self.inner.inner.len = new_ncols; - } - - /// Resizes the vector in-place so that the new number of columns is `new_ncols`. - /// New elements are created with the given function `f`, so that elements at index `i` - /// are created by calling `f(i)`. - pub fn resize_with(&mut self, new_ncols: usize, f: impl FnMut(usize) -> E) { - let mut f = f; - let old_ncols = self.ncols(); - - if new_ncols <= old_ncols { - self.erase_last_cols(new_ncols); - } else { - self.reserve_exact(new_ncols); - unsafe { - self.insert_last_cols_with(&mut f, new_ncols); - } - } - } - - /// Returns a reference to a slice over the row. - #[inline] - #[track_caller] - pub fn as_slice(&self) -> GroupFor { - let ncols = self.ncols(); - let ptr = self.as_ref().as_ptr(); - E::faer_map( - ptr, - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts(ptr, ncols) }, - ) - } - - /// Returns a mutable reference to a slice over the row. - #[inline] - #[track_caller] - pub fn as_slice_mut(&mut self) -> GroupFor { - let ncols = self.ncols(); - let ptr = self.as_ptr_mut(); - E::faer_map( - ptr, - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, ncols) }, - ) - } - - /// Returns a view over the vector. - #[inline] - pub fn as_ref(&self) -> RowRef<'_, E> { - unsafe { row::from_raw_parts(self.as_ptr(), self.ncols(), 1) } - } - - /// Returns a mutable view over the vector. - #[inline] - pub fn as_mut(&mut self) -> RowMut<'_, E> { - unsafe { row::from_raw_parts_mut(self.as_ptr_mut(), self.ncols(), 1) } - } - - /// Returns references to the element at the given index, or submatrices if `col` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub unsafe fn get_unchecked( - &self, - col: ColRange, - ) -> as RowIndex>::Target - where - for<'a> RowRef<'a, E>: RowIndex, - { - self.as_ref().get_unchecked(col) - } - - /// Returns references to the element at the given index, or submatrices if `col` is a range, - /// with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub fn get(&self, col: ColRange) -> as RowIndex>::Target - where - for<'a> RowRef<'a, E>: RowIndex, - { - self.as_ref().get(col) - } - - /// Returns mutable references to the element at the given index, or submatrices if - /// `col` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub unsafe fn get_mut_unchecked( - &mut self, - col: ColRange, - ) -> as RowIndex>::Target - where - for<'a> RowMut<'a, E>: RowIndex, - { - self.as_mut().get_mut_unchecked(col) - } - - /// Returns mutable references to the element at the given index, or submatrices if - /// `col` is a range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub fn get_mut( - &mut self, - col: ColRange, - ) -> as RowIndex>::Target - where - for<'a> RowMut<'a, E>: RowIndex, - { - self.as_mut().get_mut(col) - } - - /// Reads the value of the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, col: usize) -> E { - self.as_ref().read_unchecked(col) - } - - /// Reads the value of the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, col: usize) -> E { - self.as_ref().read(col) - } - - /// Writes the value to the element at the given index. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn write_unchecked(&mut self, col: usize, value: E) { - self.as_mut().write_unchecked(col, value); - } - - /// Writes the value to the element at the given index, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn write(&mut self, col: usize, value: E) { - self.as_mut().write(col, value); - } - - /// Copies the values from `other` into `self`. - #[inline(always)] - #[track_caller] - pub fn copy_from(&mut self, other: impl AsRowRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: &mut Row, other: RowRef<'_, E>) { - let mut mat = Row::::new(); - mat.resize_with( - other.nrows(), - #[inline(always)] - |row| unsafe { other.read_unchecked(row) }, - ); - *this = mat; - } - implementation(self, other.as_row_ref()); - } - - /// Fills the elements of `self` with zeros. - #[inline(always)] - #[track_caller] - pub fn fill_zero(&mut self) - where - E: ComplexField, - { - self.as_mut().fill_zero() - } - - /// Fills the elements of `self` with copies of `constant`. - #[inline(always)] - #[track_caller] - pub fn fill(&mut self, constant: E) { - self.as_mut().fill(constant) - } - - /// Returns a view over the transpose of `self`. - #[inline] - pub fn transpose(&self) -> ColRef<'_, E> { - self.as_ref().transpose() - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn conjugate(&self) -> RowRef<'_, E::Conj> - where - E: Conjugate, - { - self.as_ref().conjugate() - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn adjoint(&self) -> ColRef<'_, E::Conj> - where - E: Conjugate, - { - self.as_ref().adjoint() - } - - /// Returns an owning [`Row`] of the data - #[inline] - pub fn to_owned(&self) -> Row - where - E: Conjugate, - { - self.as_ref().to_owned() - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - self.as_ref().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - self.as_ref().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).as_ref().as_2d()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).as_ref().as_2d()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).as_ref().as_2d()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } -} - -impl Mat { - /// Returns an empty matrix of dimension `0×0`. - #[inline] - pub fn new() -> Self { - Self { - inner: inner::DenseOwn { - inner: MatOwnImpl { - ptr: into_copy::(E::faer_map(E::UNIT, |()| { - NonNull::::dangling() - })), - nrows: 0, - ncols: 0, - }, - row_capacity: 0, - col_capacity: 0, - }, - } - } - - /// Returns a new matrix with dimensions `(0, 0)`, with enough capacity to hold a maximum of - /// `row_capacity` rows and `col_capacity` columns without reallocating. If either is `0`, - /// the matrix will not allocate. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn with_capacity(row_capacity: usize, col_capacity: usize) -> Self { - let raw = ManuallyDrop::new(RawMat::::new(row_capacity, col_capacity)); - Self { - inner: inner::DenseOwn { - inner: MatOwnImpl { - ptr: raw.ptr, - nrows: 0, - ncols: 0, - }, - row_capacity: raw.row_capacity, - col_capacity: raw.col_capacity, - }, - } - } - - /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with the provided function. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn from_fn(nrows: usize, ncols: usize, f: impl FnMut(usize, usize) -> E) -> Self { - let mut this = Self::new(); - this.resize_with(nrows, ncols, f); - this - } - - /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with zeros. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn zeros(nrows: usize, ncols: usize) -> Self - where - E: ComplexField, - { - Self::from_fn(nrows, ncols, |_, _| E::faer_zero()) - } - - /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with zeros, except the main - /// diagonal which is filled with ones. - /// - /// # Panics - /// The function panics if the total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn identity(nrows: usize, ncols: usize) -> Self - where - E: ComplexField, - { - let mut matrix = Self::zeros(nrows, ncols); - matrix - .as_mut() - .diagonal_mut() - .column_vector_mut() - .fill(E::faer_one()); - matrix - } - - /// Returns the number of rows of the matrix. - #[inline(always)] - pub fn nrows(&self) -> usize { - self.inner.inner.nrows - } - /// Returns the number of columns of the matrix. - #[inline(always)] - pub fn ncols(&self) -> usize { - self.inner.inner.ncols - } - - /// Set the dimensions of the matrix. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `nrows < self.row_capacity()`. - /// * `ncols < self.col_capacity()`. - /// * The elements that were previously out of bounds but are now in bounds must be - /// initialized. - #[inline] - pub unsafe fn set_dims(&mut self, nrows: usize, ncols: usize) { - self.inner.inner.nrows = nrows; - self.inner.inner.ncols = ncols; - } - - /// Returns a pointer to the data of the matrix. - #[inline] - pub fn as_ptr(&self) -> GroupFor { - E::faer_map(from_copy::(self.inner.inner.ptr), |ptr| { - ptr.as_ptr() as *const E::Unit - }) - } - - /// Returns a mutable pointer to the data of the matrix. - #[inline] - pub fn as_ptr_mut(&mut self) -> GroupFor { - E::faer_map(from_copy::(self.inner.inner.ptr), |ptr| ptr.as_ptr()) - } - - /// Returns the row capacity, that is, the number of rows that the matrix is able to hold - /// without needing to reallocate, excluding column insertions. - #[inline] - pub fn row_capacity(&self) -> usize { - self.inner.row_capacity - } - - /// Returns the column capacity, that is, the number of columns that the matrix is able to hold - /// without needing to reallocate, excluding row insertions. - #[inline] - pub fn col_capacity(&self) -> usize { - self.inner.col_capacity - } - - /// Returns the offset between the first elements of two successive rows in the matrix. - /// Always returns `1` since the matrix is column major. - #[inline] - pub fn row_stride(&self) -> isize { - 1 - } - - /// Returns the offset between the first elements of two successive columns in the matrix. - #[inline] - pub fn col_stride(&self) -> isize { - self.row_capacity() as isize - } - - #[cold] - fn do_reserve_exact(&mut self, mut new_row_capacity: usize, new_col_capacity: usize) { - if is_vectorizable::() { - let align_factor = align_for::() / core::mem::size_of::(); - new_row_capacity = new_row_capacity - .msrv_checked_next_multiple_of(align_factor) - .unwrap(); - } - - let nrows = self.inner.inner.nrows; - let ncols = self.inner.inner.ncols; - let old_row_capacity = self.inner.row_capacity; - let old_col_capacity = self.inner.col_capacity; - - let mut this = ManuallyDrop::new(core::mem::take(self)); - { - let mut this_group = - E::faer_map(from_copy::(this.inner.inner.ptr), |ptr| MatUnit { - raw: RawMatUnit { - ptr, - row_capacity: old_row_capacity, - col_capacity: old_col_capacity, - }, - nrows, - ncols, - }); - - E::faer_map(E::faer_as_mut(&mut this_group), |mat_unit| { - mat_unit.do_reserve_exact(new_row_capacity, new_col_capacity); - }); - - let this_group = E::faer_map(this_group, ManuallyDrop::new); - this.inner.inner.ptr = - into_copy::(E::faer_map(this_group, |mat_unit| mat_unit.raw.ptr)); - this.inner.row_capacity = new_row_capacity; - this.inner.col_capacity = new_col_capacity; - } - *self = ManuallyDrop::into_inner(this); - } - - /// Reserves the minimum capacity for `row_capacity` rows and `col_capacity` - /// columns without reallocating. Does nothing if the capacity is already sufficient. - /// - /// # Panics - /// The function panics if the new total capacity in bytes exceeds `isize::MAX`. - #[inline] - pub fn reserve_exact(&mut self, row_capacity: usize, col_capacity: usize) { - if self.row_capacity() >= row_capacity && self.col_capacity() >= col_capacity { - // do nothing - } else if core::mem::size_of::() == 0 { - self.inner.row_capacity = self.row_capacity().max(row_capacity); - self.inner.col_capacity = self.col_capacity().max(col_capacity); - } else { - self.do_reserve_exact(row_capacity, col_capacity); - } - } - - unsafe fn insert_block_with E>( - &mut self, - f: &mut F, - row_start: usize, - row_end: usize, - col_start: usize, - col_end: usize, - ) { - debug_assert!(all(row_start <= row_end, col_start <= col_end)); - - let ptr = self.as_ptr_mut(); - - for j in col_start..col_end { - let ptr_j = E::faer_map(E::faer_copy(&ptr), |ptr| { - ptr.wrapping_offset(j as isize * self.col_stride()) - }); - - for i in row_start..row_end { - // SAFETY: - // * pointer to element at index (i, j), which is within the - // allocation since we reserved enough space - // * writing to this memory region is sound since it is properly - // aligned and valid for writes - let ptr_ij = E::faer_map(E::faer_copy(&ptr_j), |ptr_j| ptr_j.add(i)); - let value = E::faer_into_units(f(i, j)); - - E::faer_map(E::faer_zip(ptr_ij, value), |(ptr_ij, value)| { - core::ptr::write(ptr_ij, value) - }); - } - } - } - - fn erase_last_cols(&mut self, new_ncols: usize) { - let old_ncols = self.ncols(); - debug_assert!(new_ncols <= old_ncols); - self.inner.inner.ncols = new_ncols; - } - - fn erase_last_rows(&mut self, new_nrows: usize) { - let old_nrows = self.nrows(); - debug_assert!(new_nrows <= old_nrows); - self.inner.inner.nrows = new_nrows; - } - - unsafe fn insert_last_cols_with E>( - &mut self, - f: &mut F, - new_ncols: usize, - ) { - let old_ncols = self.ncols(); - - debug_assert!(new_ncols > old_ncols); - - self.insert_block_with(f, 0, self.nrows(), old_ncols, new_ncols); - self.inner.inner.ncols = new_ncols; - } - - unsafe fn insert_last_rows_with E>( - &mut self, - f: &mut F, - new_nrows: usize, - ) { - let old_nrows = self.nrows(); - - debug_assert!(new_nrows > old_nrows); - - self.insert_block_with(f, old_nrows, new_nrows, 0, self.ncols()); - self.inner.inner.nrows = new_nrows; - } - - /// Resizes the matrix in-place so that the new dimensions are `(new_nrows, new_ncols)`. - /// New elements are created with the given function `f`, so that elements at indices `(i, j)` - /// are created by calling `f(i, j)`. - pub fn resize_with( - &mut self, - new_nrows: usize, - new_ncols: usize, - f: impl FnMut(usize, usize) -> E, - ) { - let mut f = f; - let old_nrows = self.nrows(); - let old_ncols = self.ncols(); - - if new_ncols <= old_ncols { - self.erase_last_cols(new_ncols); - if new_nrows <= old_nrows { - self.erase_last_rows(new_nrows); - } else { - self.reserve_exact(new_nrows, new_ncols); - unsafe { - self.insert_last_rows_with(&mut f, new_nrows); - } - } - } else { - if new_nrows <= old_nrows { - self.erase_last_rows(new_nrows); - } else { - self.reserve_exact(new_nrows, new_ncols); - unsafe { - self.insert_last_rows_with(&mut f, new_nrows); - } - } - self.reserve_exact(new_nrows, new_ncols); - unsafe { - self.insert_last_cols_with(&mut f, new_ncols); - } - } - } - - /// Returns a reference to a slice over the column at the given index. - #[inline] - #[track_caller] - pub fn col_as_slice(&self, col: usize) -> GroupFor { - assert!(col < self.ncols()); - let nrows = self.nrows(); - let ptr = self.as_ref().ptr_at(0, col); - E::faer_map( - ptr, - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts(ptr, nrows) }, - ) - } - - /// Returns a mutable reference to a slice over the column at the given index. - #[inline] - #[track_caller] - pub fn col_as_slice_mut(&mut self, col: usize) -> GroupFor { - assert!(col < self.ncols()); - let nrows = self.nrows(); - let ptr = self.as_mut().ptr_at_mut(0, col); - E::faer_map( - ptr, - #[inline(always)] - |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, nrows) }, - ) - } - - /// Returns a reference to a slice over the column at the given index. - #[inline] - #[track_caller] - #[deprecated = "replaced by `Mat::col_as_slice`"] - pub fn col_ref(&self, col: usize) -> GroupFor { - self.col_as_slice(col) - } - - /// Returns a mutable reference to a slice over the column at the given index. - #[inline] - #[track_caller] - #[deprecated = "replaced by `Mat::col_as_slice_mut`"] - pub fn col_mut(&mut self, col: usize) -> GroupFor { - self.col_as_slice_mut(col) - } - - /// Returns a view over the matrix. - #[inline] - pub fn as_ref(&self) -> MatRef<'_, E> { - unsafe { - mat::from_raw_parts( - self.as_ptr(), - self.nrows(), - self.ncols(), - 1, - self.col_stride(), - ) - } - } - - /// Returns a mutable view over the matrix. - #[inline] - pub fn as_mut(&mut self) -> MatMut<'_, E> { - unsafe { - mat::from_raw_parts_mut( - self.as_ptr_mut(), - self.nrows(), - self.ncols(), - 1, - self.col_stride(), - ) - } - } - - /// Returns references to the element at the given indices, or submatrices if either `row` or - /// `col` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub unsafe fn get_unchecked( - &self, - row: RowRange, - col: ColRange, - ) -> as MatIndex>::Target - where - for<'a> MatRef<'a, E>: MatIndex, - { - self.as_ref().get_unchecked(row, col) - } - - /// Returns references to the element at the given indices, or submatrices if either `row` or - /// `col` is a range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub fn get( - &self, - row: RowRange, - col: ColRange, - ) -> as MatIndex>::Target - where - for<'a> MatRef<'a, E>: MatIndex, - { - self.as_ref().get(row, col) - } - - /// Returns mutable references to the element at the given indices, or submatrices if either - /// `row` or `col` is a range. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub unsafe fn get_mut_unchecked( - &mut self, - row: RowRange, - col: ColRange, - ) -> as MatIndex>::Target - where - for<'a> MatMut<'a, E>: MatIndex, - { - self.as_mut().get_mut_unchecked(row, col) - } - - /// Returns mutable references to the element at the given indices, or submatrices if either - /// `row` or `col` is a range, with bound checks. - /// - /// # Note - /// The values pointed to by the references are expected to be initialized, even if the - /// pointed-to value is not read, otherwise the behavior is undefined. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row` must be contained in `[0, self.nrows())`. - /// * `col` must be contained in `[0, self.ncols())`. - #[inline] - pub fn get_mut( - &mut self, - row: RowRange, - col: ColRange, - ) -> as MatIndex>::Target - where - for<'a> MatMut<'a, E>: MatIndex, - { - self.as_mut().get_mut(row, col) - } - - /// Reads the value of the element at the given indices. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { - self.as_ref().read_unchecked(row, col) - } - - /// Reads the value of the element at the given indices, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn read(&self, row: usize, col: usize) -> E { - self.as_ref().read(row, col) - } - - /// Writes the value to the element at the given indices. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub unsafe fn write_unchecked(&mut self, row: usize, col: usize, value: E) { - self.as_mut().write_unchecked(row, col, value); - } - - /// Writes the value to the element at the given indices, with bound checks. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `row < self.nrows()`. - /// * `col < self.ncols()`. - #[inline(always)] - #[track_caller] - pub fn write(&mut self, row: usize, col: usize, value: E) { - self.as_mut().write(row, col, value); - } - - /// Copies the values from `other` into `self`. - #[inline(always)] - #[track_caller] - pub fn copy_from(&mut self, other: impl AsMatRef) { - #[track_caller] - #[inline(always)] - fn implementation(this: &mut Mat, other: MatRef<'_, E>) { - let mut mat = Mat::::new(); - mat.resize_with( - other.nrows(), - other.ncols(), - #[inline(always)] - |row, col| unsafe { other.read_unchecked(row, col) }, - ); - *this = mat; - } - implementation(self, other.as_mat_ref()); - } - - /// Fills the elements of `self` with zeros. - #[inline(always)] - #[track_caller] - pub fn fill_zero(&mut self) - where - E: ComplexField, - { - self.as_mut().fill_zero() - } - - /// Fills the elements of `self` with copies of `constant`. - #[inline(always)] - #[track_caller] - pub fn fill(&mut self, constant: E) { - self.as_mut().fill(constant) - } - - /// Returns a view over the transpose of `self`. - #[inline] - pub fn transpose(&self) -> MatRef<'_, E> { - self.as_ref().transpose() - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn conjugate(&self) -> MatRef<'_, E::Conj> - where - E: Conjugate, - { - self.as_ref().conjugate() - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn adjoint(&self) -> MatRef<'_, E::Conj> - where - E: Conjugate, - { - self.as_ref().adjoint() - } - - /// Returns a view over the diagonal of the matrix. - #[inline] - pub fn diagonal(&self) -> Matrix> { - self.as_ref().diagonal() - } - - /// Returns an owning [`Mat`] of the data - #[inline] - pub fn to_owned(&self) -> Mat - where - E: Conjugate, - { - self.as_ref().to_owned() - } - - /// Returns `true` if any of the elements is NaN, otherwise returns `false`. - #[inline] - pub fn has_nan(&self) -> bool - where - E: ComplexField, - { - self.as_ref().has_nan() - } - - /// Returns `true` if all of the elements are finite, otherwise returns `false`. - #[inline] - pub fn is_all_finite(&self) -> bool - where - E: ComplexField, - { - self.as_ref().is_all_finite() - } - - /// Returns the maximum norm of `self`. - #[inline] - pub fn norm_max(&self) -> E::Real - where - E: ComplexField, - { - norm_max((*self).as_ref()) - } - /// Returns the L2 norm of `self`. - #[inline] - pub fn norm_l2(&self) -> E::Real - where - E: ComplexField, - { - norm_l2((*self).as_ref()) - } - - /// Returns the sum of `self`. - #[inline] - pub fn sum(&self) -> E - where - E: ComplexField, - { - sum((*self).as_ref()) - } - - /// Kroneckor product of `self` and `rhs`. - /// - /// This is an allocating operation; see [`kron`] for the - /// allocation-free version or more info in general. - #[inline] - #[track_caller] - pub fn kron(&self, rhs: impl As2D) -> Mat - where - E: ComplexField, - { - self.as_2d_ref().kron(rhs) - } - - /// Returns an iterator that provides successive chunks of the columns of a view over this - /// matrix, with each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// columns. - #[inline] - #[track_caller] - pub fn col_chunks( - &self, - chunk_size: usize, - ) -> impl '_ + DoubleEndedIterator> { - self.as_ref().col_chunks(chunk_size) - } - - /// Returns an iterator that provides successive chunks of the columns of a mutable view over - /// this matrix, with each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// columns. - #[inline] - #[track_caller] - pub fn col_chunks_mut( - &mut self, - chunk_size: usize, - ) -> impl '_ + DoubleEndedIterator> { - self.as_mut().col_chunks_mut(chunk_size) - } - - /// Returns a parallel iterator that provides successive chunks of the columns of a view over - /// this matrix, with each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// columns. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_col_chunks( - &self, - chunk_size: usize, - ) -> impl '_ + rayon::iter::IndexedParallelIterator> { - self.as_ref().par_col_chunks(chunk_size) - } - - /// Returns a parallel iterator that provides successive chunks of the columns of a mutable view - /// over this matrix, with each having at most `chunk_size` columns. - /// - /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// columns. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_col_chunks_mut( - &mut self, - chunk_size: usize, - ) -> impl '_ + rayon::iter::IndexedParallelIterator> { - self.as_mut().par_col_chunks_mut(chunk_size) - } - - /// Returns an iterator that provides successive chunks of the rows of a view over this - /// matrix, with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - #[inline] - #[track_caller] - pub fn row_chunks( - &self, - chunk_size: usize, - ) -> impl '_ + DoubleEndedIterator> { - self.as_ref().row_chunks(chunk_size) - } - - /// Returns an iterator that provides successive chunks of the rows of a mutable view over - /// this matrix, with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - #[inline] - #[track_caller] - pub fn row_chunks_mut( - &mut self, - chunk_size: usize, - ) -> impl '_ + DoubleEndedIterator> { - self.as_mut().row_chunks_mut(chunk_size) - } - - /// Returns a parallel iterator that provides successive chunks of the rows of a view over this - /// matrix, with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_row_chunks( - &self, - chunk_size: usize, - ) -> impl '_ + rayon::iter::IndexedParallelIterator> { - self.as_ref().par_row_chunks(chunk_size) - } - - /// Returns a parallel iterator that provides successive chunks of the rows of a mutable view - /// over this matrix, with each having at most `chunk_size` rows. - /// - /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` - /// rows. - /// - /// Only available with the `rayon` feature. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - #[inline] - #[track_caller] - pub fn par_row_chunks_mut( - &mut self, - chunk_size: usize, - ) -> impl '_ + rayon::iter::IndexedParallelIterator> { - self.as_mut().par_row_chunks_mut(chunk_size) - } -} - -#[doc(hidden)] -#[inline(always)] -pub fn ref_to_ptr(ptr: &T) -> *const T { - ptr -} - -#[macro_export] -#[doc(hidden)] -macro_rules! __transpose_impl { - ([$([$($col:expr),*])*] $($v:expr;)* ) => { - [$([$($col,)*],)* [$($v,)*]] - }; - ([$([$($col:expr),*])*] $($v0:expr, $($v:expr),* ;)*) => { - $crate::__transpose_impl!([$([$($col),*])* [$($v0),*]] $($($v),* ;)*) - }; -} - -/// Creates a [`Mat`] containing the arguments. -/// -/// ``` -/// use faer_core::mat; -/// -/// let matrix = mat![ -/// [1.0, 5.0, 9.0], -/// [2.0, 6.0, 10.0], -/// [3.0, 7.0, 11.0], -/// [4.0, 8.0, 12.0f64], -/// ]; -/// -/// assert_eq!(matrix.read(0, 0), 1.0); -/// assert_eq!(matrix.read(1, 0), 2.0); -/// assert_eq!(matrix.read(2, 0), 3.0); -/// assert_eq!(matrix.read(3, 0), 4.0); -/// -/// assert_eq!(matrix.read(0, 1), 5.0); -/// assert_eq!(matrix.read(1, 1), 6.0); -/// assert_eq!(matrix.read(2, 1), 7.0); -/// assert_eq!(matrix.read(3, 1), 8.0); -/// -/// assert_eq!(matrix.read(0, 2), 9.0); -/// assert_eq!(matrix.read(1, 2), 10.0); -/// assert_eq!(matrix.read(2, 2), 11.0); -/// assert_eq!(matrix.read(3, 2), 12.0); -/// ``` -#[macro_export] -macro_rules! mat { - () => { - { - compile_error!("number of columns in the matrix is ambiguous"); - } - }; - - ($([$($v:expr),* $(,)?] ),* $(,)?) => { - { - let data = ::core::mem::ManuallyDrop::new($crate::__transpose_impl!([] $($($v),* ;)*)); - let data = &*data; - let ncols = data.len(); - let nrows = (*data.get(0).unwrap()).len(); - - #[allow(unused_unsafe)] - unsafe { - $crate::Mat::<_>::from_fn(nrows, ncols, |i, j| $crate::ref_to_ptr(&data[j][i]).read()) - } - } - }; -} - -/// Concatenates the matrices in each row horizontally, -/// then concatenates the results vertically. -/// `concat![[a0, a1, a2], [b1, b2]]` results in the matrix -/// -/// ```notcode -/// [a0 | a1 | a2][b0 | b1] -/// ``` -#[macro_export] -macro_rules! concat { - () => { - { - compile_error!("number of columns in the matrix is ambiguous"); - } - }; - - ($([$($v:expr),* $(,)?] ),* $(,)?) => { - { - $crate::__concat_impl(&[$(&[$(($v).as_ref(),)*],)*]) - } - }; -} - -/// Creates a [`Col`] containing the arguments. -/// -/// ``` -/// use faer_core::col; -/// -/// let col_vec = col![3.0, 5.0, 7.0, 9.0]; -/// -/// assert_eq!(col_vec.read(0), 3.0); -/// assert_eq!(col_vec.read(1), 5.0); -/// assert_eq!(col_vec.read(2), 7.0); -/// assert_eq!(col_vec.read(3), 9.0); -/// ``` -#[macro_export] -macro_rules! col { - () => { - $crate::Col::<_>::new() - }; - - ($($v:expr),+ $(,)?) => {{ - let data = &[$($v),+]; - let n = data.len(); - - #[allow(unused_unsafe)] - unsafe { - $crate::Col::<_>::from_fn(n, |i| $crate::ref_to_ptr(&data[i]).read()) - } - }}; -} - -/// Creates a [`Row`] containing the arguments. -/// -/// ``` -/// use faer_core::row; -/// -/// let row_vec = row![3.0, 5.0, 7.0, 9.0]; -/// -/// assert_eq!(row_vec.read(0), 3.0); -/// assert_eq!(row_vec.read(1), 5.0); -/// assert_eq!(row_vec.read(2), 7.0); -/// assert_eq!(row_vec.read(3), 9.0); -/// ``` -#[macro_export] -macro_rules! row { - () => { - $crate::Row::<_>::new() - }; - - ($($v:expr),+ $(,)?) => {{ - let data = &[$($v),+]; - let n = data.len(); - - #[allow(unused_unsafe)] - unsafe { - $crate::Row::<_>::from_fn(n, |i| $crate::ref_to_ptr(&data[i]).read()) - } - }}; -} - -/// Parallelism strategy that can be passed to most of the routines in the library. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum Parallelism { - /// No parallelism. - /// - /// The code is executed sequentially on the same thread that calls a function - /// and passes this argument. - None, - /// Rayon parallelism. Only avaialble with the `rayon` feature. - /// - /// The code is possibly executed in parallel on the current thread, as well as the currently - /// active rayon thread pool. - /// - /// The contained value represents a hint about the number of threads an implementation should - /// use, but there is no way to guarantee how many or which threads will be used. - /// - /// A value of `0` treated as equivalent to `rayon::current_num_threads()`. - #[cfg(feature = "rayon")] - #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] - Rayon(usize), -} - -/// 0: Disable -/// 1: None -/// n >= 2: Rayon(n - 2) -/// -/// default: Rayon(0) -static GLOBAL_PARALLELISM: AtomicUsize = { - #[cfg(feature = "rayon")] - { - AtomicUsize::new(2) - } - #[cfg(not(feature = "rayon"))] - { - AtomicUsize::new(1) - } -}; - -/// Causes functions that access global parallelism settings to panic. -pub fn disable_global_parallelism() { - GLOBAL_PARALLELISM.store(0, core::sync::atomic::Ordering::Relaxed); -} - -/// Sets the global parallelism settings. -pub fn set_global_parallelism(parallelism: Parallelism) { - let value = match parallelism { - Parallelism::None => 1, - #[cfg(feature = "rayon")] - Parallelism::Rayon(n) => n.saturating_add(2), - }; - GLOBAL_PARALLELISM.store(value, core::sync::atomic::Ordering::Relaxed); -} - -/// Gets the global parallelism settings. -/// -/// # Panics -/// Panics if global parallelism is disabled. -#[track_caller] -pub fn get_global_parallelism() -> Parallelism { - let value = GLOBAL_PARALLELISM.load(core::sync::atomic::Ordering::Relaxed); - match value { - 0 => panic!("Global parallelism is disabled."), - 1 => Parallelism::None, - #[cfg(feature = "rayon")] - n => Parallelism::Rayon(n - 2), - #[cfg(not(feature = "rayon"))] - _ => unreachable!(), - } -} - -#[inline] -#[doc(hidden)] -pub fn join_raw( - op_a: impl Send + FnOnce(Parallelism), - op_b: impl Send + FnOnce(Parallelism), - parallelism: Parallelism, -) { - fn implementation( - op_a: &mut (dyn Send + FnMut(Parallelism)), - op_b: &mut (dyn Send + FnMut(Parallelism)), - parallelism: Parallelism, - ) { - match parallelism { - Parallelism::None => (op_a(parallelism), op_b(parallelism)), - #[cfg(feature = "rayon")] - Parallelism::Rayon(n_threads) => { - if n_threads == 1 { - (op_a(Parallelism::None), op_b(Parallelism::None)) - } else { - let n_threads = if n_threads > 0 { - n_threads - } else { - rayon::current_num_threads() - }; - let parallelism = Parallelism::Rayon(n_threads - n_threads / 2); - rayon::join(|| op_a(parallelism), || op_b(parallelism)) - } - } - }; - } - let mut op_a = Some(op_a); - let mut op_b = Some(op_b); - implementation( - &mut |parallelism| (op_a.take().unwrap())(parallelism), - &mut |parallelism| (op_b.take().unwrap())(parallelism), - parallelism, - ) -} - -#[inline] -#[doc(hidden)] -pub fn for_each_raw(n_tasks: usize, op: impl Send + Sync + Fn(usize), parallelism: Parallelism) { - fn implementation( - n_tasks: usize, - op: &(dyn Send + Sync + Fn(usize)), - parallelism: Parallelism, - ) { - if n_tasks == 1 { - op(0); - return; - } - - match parallelism { - Parallelism::None => (0..n_tasks).for_each(op), - #[cfg(feature = "rayon")] - Parallelism::Rayon(n_threads) => { - let n_threads = if n_threads > 0 { - n_threads - } else { - rayon::current_num_threads() - }; - - use rayon::prelude::*; - let min_len = n_tasks / n_threads; - (0..n_tasks) - .into_par_iter() - .with_min_len(min_len) - .for_each(op); - } - } - } - implementation(n_tasks, &op, parallelism); -} - -#[doc(hidden)] -pub struct Ptr(pub *mut T); -unsafe impl Send for Ptr {} -unsafe impl Sync for Ptr {} -impl Copy for Ptr {} -impl Clone for Ptr { - #[inline] - fn clone(&self) -> Self { - *self - } -} - -#[inline] -#[doc(hidden)] -pub fn parallelism_degree(parallelism: Parallelism) -> usize { - match parallelism { - Parallelism::None => 1, - #[cfg(feature = "rayon")] - Parallelism::Rayon(0) => rayon::current_num_threads(), - #[cfg(feature = "rayon")] - Parallelism::Rayon(n_threads) => n_threads, - } -} - -/// Creates a temporary matrix of constant values, from the given memory stack. -pub fn temp_mat_constant( - nrows: usize, - ncols: usize, - value: E, - stack: PodStack<'_>, -) -> (MatMut<'_, E>, PodStack<'_>) { - let (mut mat, stack) = temp_mat_uninit::(nrows, ncols, stack); - mat.as_mut().fill(value); - (mat, stack) -} - -/// Creates a temporary matrix of zero values, from the given memory stack. -pub fn temp_mat_zeroed( - nrows: usize, - ncols: usize, - stack: PodStack<'_>, -) -> (MatMut<'_, E>, PodStack<'_>) { - let (mut mat, stack) = temp_mat_uninit::(nrows, ncols, stack); - mat.as_mut().fill_zero(); - (mat, stack) -} - -/// Creates a temporary matrix of untouched values, from the given memory stack. -pub fn temp_mat_uninit( - nrows: usize, - ncols: usize, - stack: PodStack<'_>, -) -> (MatMut<'_, E>, PodStack<'_>) { - let col_stride = col_stride::(nrows); - let alloc_size = ncols.checked_mul(col_stride).unwrap(); - - let (stack, alloc) = E::faer_map_with_context(stack, E::UNIT, &mut { - #[inline(always)] - |stack, ()| { - let (alloc, stack) = - stack.make_aligned_raw::(alloc_size, align_for::()); - (stack, alloc) - } - }); - ( - unsafe { - mat::from_raw_parts_mut( - E::faer_map(alloc, |alloc| alloc.as_mut_ptr()), - nrows, - ncols, - 1, - col_stride as isize, - ) - }, - stack, - ) -} - -#[doc(hidden)] -#[inline] -pub fn col_stride(nrows: usize) -> usize { - if !is_vectorizable::() || nrows >= isize::MAX as usize { - nrows - } else { - nrows - .msrv_checked_next_multiple_of(align_for::() / core::mem::size_of::()) - .unwrap() - } -} - -/// Returns the stack requirements for creating a temporary matrix with the given dimensions. -#[inline] -pub fn temp_mat_req(nrows: usize, ncols: usize) -> Result { - let col_stride = col_stride::(nrows); - let alloc_size = ncols.checked_mul(col_stride).ok_or(SizeOverflow)?; - let additional = StackReq::try_new_aligned::(alloc_size, align_for::())?; - - let req = Ok(StackReq::empty()); - let (req, _) = E::faer_map_with_context(req, E::UNIT, &mut { - #[inline(always)] - |req, ()| { - let req = match req { - Ok(req) => req.try_and(additional), - _ => Err(SizeOverflow), - }; - - (req, ()) - } - }); - - req -} - -impl<'a, FromE: Entity, ToE: Entity> Coerce> for MatRef<'a, FromE> { - #[inline(always)] - fn coerce(self) -> MatRef<'a, ToE> { - assert!(coe::is_same::()); - unsafe { transmute_unchecked::, MatRef<'a, ToE>>(self) } - } -} -impl<'a, FromE: Entity, ToE: Entity> Coerce> for MatMut<'a, FromE> { - #[inline(always)] - fn coerce(self) -> MatMut<'a, ToE> { - assert!(coe::is_same::()); - unsafe { transmute_unchecked::, MatMut<'a, ToE>>(self) } - } -} - -/// Zips together matrix of the same size, so that coefficient-wise operations can be performed on -/// their elements. -/// -/// # Note -/// The order in which the matrix elements are traversed is unspecified. -/// -/// # Example -/// ``` -/// use faer_core::{mat, unzipped, zipped, Mat}; -/// -/// let nrows = 2; -/// let ncols = 3; -/// -/// let a = mat![[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]; -/// let b = mat![[7.0, 9.0, 11.0], [8.0, 10.0, 12.0]]; -/// let mut sum = Mat::::zeros(nrows, ncols); -/// -/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| { -/// let a = a.read(); -/// let b = b.read(); -/// sum.write(a + b); -/// }); -/// -/// for i in 0..nrows { -/// for j in 0..ncols { -/// assert_eq!(sum.read(i, j), a.read(i, j) + b.read(i, j)); -/// } -/// } -/// ``` -#[macro_export] -macro_rules! zipped { - ($head: expr $(,)?) => { - $crate::zip::LastEq($crate::zip::ViewMut::view_mut(&mut { $head })) - }; - - ($head: expr, $($tail: expr),* $(,)?) => { - $crate::zip::ZipEq::new($crate::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped!($($tail,)*)) - }; -} - -/// Used to undo the zipping by the [`zipped!`] macro. -#[macro_export] -macro_rules! unzipped { - ($head: pat $(,)?) => { - $crate::zip::Last($head) - }; - - ($head: pat, $($tail: pat),* $(,)?) => { - $crate::zip::Zip($head, $crate::unzipped!($($tail,)*)) - }; -} - -impl<'a, E: Entity> Debug for RowRef<'a, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_2d().fmt(f) - } -} -impl<'a, E: Entity> Debug for RowMut<'a, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } -} - -impl<'a, E: Entity> Debug for ColRef<'a, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_2d().fmt(f) - } -} -impl<'a, E: Entity> Debug for ColMut<'a, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } -} - -impl<'a, E: Entity> Debug for MatRef<'a, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - struct DebugRow<'a, T: Entity>(MatRef<'a, T>); - - impl<'a, T: Entity> Debug for DebugRow<'a, T> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mut j = 0; - f.debug_list() - .entries(core::iter::from_fn(|| { - let ret = if j < self.0.ncols() { - Some(T::faer_from_units(T::faer_deref(self.0.get(0, j)))) - } else { - None - }; - j += 1; - ret - })) - .finish() - } - } - - writeln!(f, "[")?; - for i in 0..self.nrows() { - let row = self.subrows(i, 1); - DebugRow(row).fmt(f)?; - f.write_str(",\n")?; - } - write!(f, "]") - } -} - -impl<'a, E: Entity> Debug for MatMut<'a, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } -} - -impl Debug for Mat { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_ref().fmt(f) - } -} - -/// Advanced: Module for index and matrix types with compile time checks, instead of bound checking -/// at runtime. -pub mod constrained { - use core::ops::Range; - - use super::*; - use crate::{ - assert, debug_assert, - permutation::{Index, SignedIndex}, - }; - - #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] - #[repr(transparent)] - struct Branded<'a, T: ?Sized> { - __marker: PhantomData &'a ()>, - inner: T, - } - - /// `usize` value tied to the lifetime `'n`. - #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] - #[repr(transparent)] - pub struct Size<'n>(Branded<'n, usize>); - - /// `I` value smaller than the size corresponding to the lifetime `'n`. - #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] - #[repr(transparent)] - pub struct Idx<'n, I>(Branded<'n, I>); - - /// `I` value smaller or equal to the size corresponding to the lifetime `'n`. - #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] - #[repr(transparent)] - pub struct IdxInclusive<'n, I>(Branded<'n, I>); - - /// `I` value smaller than the size corresponding to the lifetime `'n`, or `None`. - #[derive(Copy, Clone, PartialEq, Eq)] - #[repr(transparent)] - pub struct MaybeIdx<'n, I: Index>(Branded<'n, I>); - - impl core::ops::Deref for Size<'_> { - type Target = usize; - #[inline] - fn deref(&self) -> &Self::Target { - &self.0.inner - } - } - impl core::ops::Deref for Idx<'_, I> { - type Target = I; - #[inline] - fn deref(&self) -> &Self::Target { - &self.0.inner - } - } - impl core::ops::Deref for MaybeIdx<'_, I> { - type Target = I::Signed; - #[inline] - fn deref(&self) -> &Self::Target { - bytemuck::cast_ref(&self.0.inner) - } - } - impl core::ops::Deref for IdxInclusive<'_, I> { - type Target = I; - #[inline] - fn deref(&self) -> &Self::Target { - &self.0.inner - } - } - - /// Array of length equal to the value tied to `'n`. - #[derive(PartialEq, Eq, PartialOrd, Ord)] - #[repr(transparent)] - pub struct Array<'n, T>(Branded<'n, [T]>); - - /// Immutable dense matrix view with dimensions equal to the values tied to `('nrows, 'ncols)`. - #[repr(transparent)] - pub struct MatRef<'nrows, 'ncols, 'a, E: Entity>( - Branded<'ncols, Branded<'nrows, super::MatRef<'a, E>>>, - ); - /// Mutable dense matrix view with dimensions equal to the values tied to `('nrows, 'ncols)`. - #[repr(transparent)] - pub struct MatMut<'nrows, 'ncols, 'a, E: Entity>( - Branded<'ncols, Branded<'nrows, super::MatMut<'a, E>>>, - ); - - /// Permutations with compile-time checks. - pub mod permutation { - use super::*; - use crate::assert; - - /// Permutation of length equal to the value tied to `'n`. - #[repr(transparent)] - pub struct PermutationRef<'n, 'a, I, E: Entity>( - Branded<'n, crate::permutation::PermutationRef<'a, I, E>>, - ); - - impl<'n, 'a, I: Index, E: Entity> PermutationRef<'n, 'a, I, E> { - /// Returns a new permutation after checking that it matches the size tied to `'n`. - #[inline] - #[track_caller] - pub fn new(perm: crate::permutation::PermutationRef<'a, I, E>, size: Size<'n>) -> Self { - let (fwd, inv) = perm.into_arrays(); - assert!(all( - fwd.len() == size.into_inner(), - inv.len() == size.into_inner(), - )); - Self(Branded { - __marker: PhantomData, - inner: perm, - }) - } - - /// Returns the inverse permutation. - #[inline] - pub fn inverse(self) -> PermutationRef<'n, 'a, I, E> { - PermutationRef(Branded { - __marker: PhantomData, - inner: self.0.inner.inverse(), - }) - } - - /// Returns the forward and inverse permutation indices. - #[inline] - pub fn into_arrays(self) -> (&'a Array<'n, Idx<'n, I>>, &'a Array<'n, Idx<'n, I>>) { - unsafe { - let (fwd, inv) = self.0.inner.into_arrays(); - let fwd = &*(fwd as *const [I] as *const Array<'n, Idx<'n, I>>); - let inv = &*(inv as *const [I] as *const Array<'n, Idx<'n, I>>); - (fwd, inv) - } - } - - /// Returns the unconstrained permutation. - #[inline] - pub fn into_inner(self) -> crate::permutation::PermutationRef<'a, I, E> { - self.0.inner - } - - /// Returns the length of the permutation. - #[inline] - pub fn len(&self) -> Size<'n> { - unsafe { Size::new_raw_unchecked(self.into_inner().len()) } - } - - /// Casts the permutation to one with a different type. - pub fn cast(self) -> PermutationRef<'n, 'a, I, T> { - PermutationRef(Branded { - __marker: PhantomData, - inner: self.into_inner().cast(), - }) - } - } - - impl Clone for PermutationRef<'_, '_, I, E> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - impl Copy for PermutationRef<'_, '_, I, E> {} - - impl Debug for PermutationRef<'_, '_, I, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } - } - } - - /// Sparse matrices with compile-time checks. - pub mod sparse { - use super::*; - use crate::{assert, group_helpers::SliceGroup, sparse::__get_unchecked}; - use core::ops::Range; - - /// Symbolic structure view with dimensions equal to the values tied to `('nrows, 'ncols)`, - /// in column-major order. - #[repr(transparent)] - pub struct SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>( - Branded<'ncols, Branded<'nrows, crate::sparse::SymbolicSparseColMatRef<'a, I>>>, - ); - /// Immutable sparse matrix view with dimensions equal to the values tied to `('nrows, - /// 'ncols)`, in column-major order. - pub struct SparseColMatRef<'nrows, 'ncols, 'a, I, E: Entity> { - symbolic: SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>, - values: SliceGroup<'a, E>, - } - /// Mutable sparse matrix view with dimensions equal to the values tied to `('nrows, - /// 'ncols)`, in column-major order. - pub struct SparseColMatMut<'nrows, 'ncols, 'a, I, E: Entity> { - symbolic: SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>, - values: SliceGroupMut<'a, E>, - } - - impl<'nrows, 'ncols, 'a, I: Index> SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I> { - /// Returns a new symbolic structure after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - #[inline] - pub fn new( - inner: crate::sparse::SymbolicSparseColMatRef<'a, I>, - nrows: Size<'nrows>, - ncols: Size<'ncols>, - ) -> Self { - assert!(all( - inner.nrows() == nrows.into_inner(), - inner.ncols() == ncols.into_inner(), - )); - Self(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - - /// Returns the unconstrained symbolic structure. - #[inline] - pub fn into_inner(self) -> crate::sparse::SymbolicSparseColMatRef<'a, I> { - self.0.inner.inner - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> Size<'nrows> { - unsafe { Size::new_raw_unchecked(self.0.inner.inner.nrows()) } - } - - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> Size<'ncols> { - unsafe { Size::new_raw_unchecked(self.0.inner.inner.ncols()) } - } - - #[inline] - #[track_caller] - #[doc(hidden)] - pub fn col_range(&self, j: Idx<'ncols, usize>) -> Range { - unsafe { self.into_inner().col_range_unchecked(j.into_inner()) } - } - - /// Returns the row indices in column `j`. - #[inline] - #[track_caller] - pub fn row_indices_of_col_raw(&self, j: Idx<'ncols, usize>) -> &'a [Idx<'nrows, I>] { - unsafe { - &*(__get_unchecked(self.into_inner().row_indices(), self.col_range(j)) - as *const [I] as *const [Idx<'_, I>]) - } - } - - /// Returns the row indices in column `j`. - #[inline] - #[track_caller] - pub fn row_indices_of_col( - &self, - j: Idx<'ncols, usize>, - ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator> - { - unsafe { - __get_unchecked( - self.into_inner().row_indices(), - self.into_inner().col_range_unchecked(j.into_inner()), - ) - .iter() - .map( - #[inline(always)] - move |&row| Idx::new_raw_unchecked(row.zx()), - ) - } - } - } - - impl<'nrows, 'ncols, 'a, I: Index, E: Entity> SparseColMatRef<'nrows, 'ncols, 'a, I, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - pub fn new( - inner: crate::sparse::SparseColMatRef<'a, I, E>, - nrows: Size<'nrows>, - ncols: Size<'ncols>, - ) -> Self { - assert!(all( - inner.nrows() == nrows.into_inner(), - inner.ncols() == ncols.into_inner(), - )); - Self { - symbolic: SymbolicSparseColMatRef::new(inner.symbolic(), nrows, ncols), - values: SliceGroup::new(inner.values()), - } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> crate::sparse::SparseColMatRef<'a, I, E> { - crate::sparse::SparseColMatRef::new( - self.symbolic.into_inner(), - self.values.into_inner(), - ) - } - - /// Returns the values in column `j`. - #[inline] - pub fn values_of_col(&self, j: Idx<'ncols, usize>) -> GroupFor { - unsafe { - self.values - .subslice_unchecked(self.col_range(j)) - .into_inner() - } - } - } - - impl<'nrows, 'ncols, 'a, I: Index, E: Entity> SparseColMatMut<'nrows, 'ncols, 'a, I, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - pub fn new( - inner: crate::sparse::SparseColMatMut<'a, I, E>, - nrows: Size<'nrows>, - ncols: Size<'ncols>, - ) -> Self { - assert!(all( - inner.nrows() == nrows.into_inner(), - inner.ncols() == ncols.into_inner(), - )); - Self { - symbolic: SymbolicSparseColMatRef::new(inner.symbolic(), nrows, ncols), - values: SliceGroupMut::new(inner.values_mut()), - } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> crate::sparse::SparseColMatMut<'a, I, E> { - crate::sparse::SparseColMatMut::new( - self.symbolic.into_inner(), - self.values.into_inner(), - ) - } - - /// Returns the values in column `j`. - #[inline] - pub fn values_of_col_mut( - &mut self, - j: Idx<'ncols, usize>, - ) -> GroupFor { - unsafe { - let range = self.col_range(j); - self.values.rb_mut().subslice_unchecked(range).into_inner() - } - } - } - - impl Copy for SparseColMatRef<'_, '_, '_, I, E> {} - impl Clone for SparseColMatRef<'_, '_, '_, I, E> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - impl Copy for SymbolicSparseColMatRef<'_, '_, '_, I> {} - impl Clone for SymbolicSparseColMatRef<'_, '_, '_, I> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - - impl<'nrows, 'ncols, 'a, I, E: Entity> core::ops::Deref - for SparseColMatRef<'nrows, 'ncols, 'a, I, E> - { - type Target = SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.symbolic - } - } - - impl<'nrows, 'ncols, 'a, I, E: Entity> core::ops::Deref - for SparseColMatMut<'nrows, 'ncols, 'a, I, E> - { - type Target = SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>; - - #[inline] - fn deref(&self) -> &Self::Target { - &self.symbolic - } - } - - impl<'short, 'nrows, 'ncols, 'a, I, E: Entity> ReborrowMut<'short> - for SparseColMatRef<'nrows, 'ncols, 'a, I, E> - { - type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl<'short, 'nrows, 'ncols, 'a, I, E: Entity> Reborrow<'short> - for SparseColMatRef<'nrows, 'ncols, 'a, I, E> - { - type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'nrows, 'ncols, 'a, I, E: Entity> IntoConst for SparseColMatRef<'nrows, 'ncols, 'a, I, E> { - type Target = SparseColMatRef<'nrows, 'ncols, 'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - self - } - } - - impl<'short, 'nrows, 'ncols, 'a, I, E: Entity> ReborrowMut<'short> - for SparseColMatMut<'nrows, 'ncols, 'a, I, E> - { - type Target = SparseColMatMut<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - SparseColMatMut::<'nrows, 'ncols, 'short, I, E> { - symbolic: self.symbolic, - values: self.values.rb_mut(), - } - } - } - - impl<'short, 'nrows, 'ncols, 'a, I, E: Entity> Reborrow<'short> - for SparseColMatMut<'nrows, 'ncols, 'a, I, E> - { - type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - SparseColMatRef::<'nrows, 'ncols, 'short, I, E> { - symbolic: self.symbolic, - values: self.values.rb(), - } - } - } - - impl<'nrows, 'ncols, 'a, I, E: Entity> IntoConst for SparseColMatMut<'nrows, 'ncols, 'a, I, E> { - type Target = SparseColMatRef<'nrows, 'ncols, 'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - SparseColMatRef::<'nrows, 'ncols, 'a, I, E> { - symbolic: self.symbolic, - values: self.values.into_const(), - } - } - } - } - - /// Group helpers with compile-time checks. - pub mod group_helpers { - use super::*; - use crate::{ - assert, - group_helpers::{SliceGroup, SliceGroupMut}, - }; - use core::ops::Range; - - /// Immutable array group of length equal to the value tied to `'n`. - pub struct ArrayGroup<'n, 'a, E: Entity>(Branded<'n, SliceGroup<'a, E>>); - /// Mutable array group of length equal to the value tied to `'n`. - pub struct ArrayGroupMut<'n, 'a, E: Entity>(Branded<'n, SliceGroupMut<'a, E>>); - - impl Debug for ArrayGroup<'_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } - } - impl Debug for ArrayGroupMut<'_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } - } - - impl Copy for ArrayGroup<'_, '_, E> {} - impl Clone for ArrayGroup<'_, '_, E> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - - impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroup<'n, 'a, E> { - type Target = ArrayGroup<'n, 'short, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroup<'n, 'a, E> { - type Target = ArrayGroup<'n, 'short, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroupMut<'n, 'a, E> { - type Target = ArrayGroupMut<'n, 'short, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - ArrayGroupMut(Branded { - __marker: PhantomData, - inner: self.0.inner.rb_mut(), - }) - } - } - - impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroupMut<'n, 'a, E> { - type Target = ArrayGroup<'n, 'short, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - ArrayGroup(Branded { - __marker: PhantomData, - inner: self.0.inner.rb(), - }) - } - } - - impl<'n, 'a, E: Entity> ArrayGroupMut<'n, 'a, E> { - /// Returns an array group with length after checking that its length matches - /// the value tied to `'n`. - #[inline] - pub fn new(slice: GroupFor, len: Size<'n>) -> Self { - let slice = SliceGroupMut::<'_, E>::new(slice); - assert!(slice.rb().len() == len.into_inner()); - ArrayGroupMut(Branded { - __marker: PhantomData, - inner: slice, - }) - } - - /// Returns the unconstrained slice. - #[inline] - pub fn into_slice(self) -> GroupFor { - self.0.inner.into_inner() - } - - /// Returns a subslice at from the range start to its end. - #[inline] - pub fn subslice( - self, - range: Range>, - ) -> GroupFor { - unsafe { - SliceGroupMut::<'_, E>::new(self.into_slice()) - .subslice_unchecked(range.start.into_inner()..range.end.into_inner()) - .into_inner() - } - } - - /// Read the element at position `j`. - #[inline] - pub fn read(&self, j: Idx<'n, usize>) -> E { - self.rb().read(j) - } - - /// Write `value` to the location at position `j`. - #[inline] - pub fn write(&mut self, j: Idx<'n, usize>, value: E) { - unsafe { - SliceGroupMut::new(self.rb_mut().into_slice()) - .write_unchecked(j.into_inner(), value) - } - } - } - - impl<'n, 'a, E: Entity> ArrayGroup<'n, 'a, E> { - /// Returns an array group with length after checking that its length matches - /// the value tied to `'n`. - #[inline] - pub fn new(slice: GroupFor, len: Size<'n>) -> Self { - let slice = SliceGroup::<'_, E>::new(slice); - assert!(slice.rb().len() == len.into_inner()); - ArrayGroup(Branded { - __marker: PhantomData, - inner: slice, - }) - } - - /// Returns the unconstrained slice. - #[inline] - pub fn into_slice(self) -> GroupFor { - self.0.inner.into_inner() - } - - /// Returns a subslice at from the range start to its end. - #[inline] - pub fn subslice( - self, - range: Range>, - ) -> GroupFor { - unsafe { - SliceGroup::<'_, E>::new(self.into_slice()) - .subslice_unchecked(range.start.into_inner()..range.end.into_inner()) - .into_inner() - } - } - - /// Read the element at position `j`. - #[inline] - pub fn read(&self, j: Idx<'n, usize>) -> E { - unsafe { SliceGroup::new(self.into_slice()).read_unchecked(j.into_inner()) } - } - } - } - - impl<'size> Size<'size> { - /// Create a new [`Size`] with a lifetime tied to `n`. - #[track_caller] - #[inline] - pub fn with(n: usize, f: impl for<'n> FnOnce(Size<'n>) -> R) -> R { - f(Size(Branded { - __marker: PhantomData, - inner: n, - })) - } - - /// Create two new [`Size`] with lifetimes tied to `m` and `n`. - #[track_caller] - #[inline] - pub fn with2( - m: usize, - n: usize, - f: impl for<'m, 'n> FnOnce(Size<'m>, Size<'n>) -> R, - ) -> R { - f( - Size(Branded { - __marker: PhantomData, - inner: m, - }), - Size(Branded { - __marker: PhantomData, - inner: n, - }), - ) - } - - /// Create a new [`Size`] tied to the lifetime `'n`. - #[inline] - pub unsafe fn new_raw_unchecked(n: usize) -> Self { - Size(Branded { - __marker: PhantomData, - inner: n, - }) - } - - /// Returns the unconstrained value. - #[inline] - pub fn into_inner(self) -> usize { - self.0.inner - } - - /// Returns an iterator of the indices smaller than `self`. - #[inline] - pub fn indices(self) -> impl DoubleEndedIterator> { - (0..self.0.inner).map(|i| unsafe { Idx::new_raw_unchecked(i) }) - } - - /// Check that the index is bounded by `self`, or panic otherwise. - #[track_caller] - #[inline] - pub fn check(self, idx: I) -> Idx<'size, I> { - Idx::new_checked(idx, self) - } - - /// Check that the index is bounded by `self`, or return `None` otherwise. - #[inline] - pub fn try_check(self, idx: I) -> Option> { - if idx.zx() < self.into_inner() { - Some(Idx(Branded { - __marker: PhantomData, - inner: idx, - })) - } else { - None - } - } - } - - impl<'n> Idx<'n, usize> { - /// Truncate `self` to a smaller type `I`. - pub fn truncate(self) -> Idx<'n, I> { - unsafe { Idx::new_raw_unchecked(I::truncate(self.into_inner())) } - } - } - - impl<'n, I: Index> Idx<'n, I> { - /// Returns a new index after asserting that it's bounded by `size`. - #[track_caller] - #[inline] - pub fn new_checked(idx: I, size: Size<'n>) -> Self { - assert!(idx.zx() < size.into_inner()); - Self(Branded { - __marker: PhantomData, - inner: idx, - }) - } - /// Returns a new index without asserting that it's bounded by `size`. - #[track_caller] - #[inline] - pub unsafe fn new_unchecked(idx: I, size: Size<'n>) -> Self { - debug_assert!(idx.zx() < size.into_inner()); - Self(Branded { - __marker: PhantomData, - inner: idx, - }) - } - - /// Returns a new index without asserting that it's bounded by the value tied to the - /// lifetime `'n`. - #[inline] - pub unsafe fn new_raw_unchecked(idx: I) -> Self { - Self(Branded { - __marker: PhantomData, - inner: idx, - }) - } - - /// Returns the unconstrained value. - #[inline] - pub fn into_inner(self) -> I { - self.0.inner - } - - /// Zero extend the value. - #[inline] - pub fn zx(self) -> Idx<'n, usize> { - unsafe { Idx::new_raw_unchecked(self.0.inner.zx()) } - } - - /// Unimplemented: Sign extend the value. - #[inline] - pub fn sx(self) -> ! { - unimplemented!() - } - - /// Returns the index, bounded inclusively by the value tied to `'n`. - #[inline] - pub fn to_inclusive(self) -> IdxInclusive<'n, I> { - unsafe { IdxInclusive::new_raw_unchecked(self.into_inner()) } - } - /// Returns the next index, bounded inclusively by the value tied to `'n`. - #[inline] - pub fn next(self) -> IdxInclusive<'n, I> { - unsafe { IdxInclusive::new_raw_unchecked(self.into_inner() + I::truncate(1)) } - } - - /// Assert that the values of `slice` are all bounded by `size`. - #[track_caller] - #[inline] - pub fn from_slice_mut_checked<'a>( - slice: &'a mut [I], - size: Size<'n>, - ) -> &'a mut [Idx<'n, I>] { - Self::from_slice_ref_checked(slice, size); - unsafe { &mut *(slice as *mut _ as *mut _) } - } - - /// Assume that the values of `slice` are all bounded by the value tied to `'n`. - #[track_caller] - #[inline] - pub unsafe fn from_slice_mut_unchecked<'a>(slice: &'a mut [I]) -> &'a mut [Idx<'n, I>] { - unsafe { &mut *(slice as *mut _ as *mut _) } - } - - /// Assert that the values of `slice` are all bounded by `size`. - #[track_caller] - pub fn from_slice_ref_checked<'a>(slice: &'a [I], size: Size<'n>) -> &'a [Idx<'n, I>] { - for &idx in slice { - Self::new_checked(idx, size); - } - unsafe { &*(slice as *const _ as *const _) } - } - - /// Assume that the values of `slice` are all bounded by the value tied to `'n`. - #[track_caller] - #[inline] - pub unsafe fn from_slice_ref_unchecked<'a>(slice: &'a [I]) -> &'a [Idx<'n, I>] { - unsafe { &*(slice as *const _ as *const _) } - } - } - - impl<'n, I: Index> MaybeIdx<'n, I> { - /// Returns an index value. - #[inline] - pub fn from_index(idx: Idx<'n, I>) -> Self { - unsafe { Self::new_raw_unchecked(idx.into_inner()) } - } - /// Returns a `None` value. - #[inline] - pub fn none() -> Self { - unsafe { Self::new_raw_unchecked(I::truncate(usize::MAX)) } - } - - /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. - #[inline] - pub fn new_checked(idx: I::Signed, size: Size<'n>) -> Self { - assert!((idx.sx() as isize) < size.into_inner() as isize); - Self(Branded { - __marker: PhantomData, - inner: I::from_signed(idx), - }) - } - - /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. - #[inline] - pub unsafe fn new_unchecked(idx: I::Signed, size: Size<'n>) -> Self { - debug_assert!((idx.sx() as isize) < size.into_inner() as isize); - Self(Branded { - __marker: PhantomData, - inner: I::from_signed(idx), - }) - } - - /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. - #[inline] - pub unsafe fn new_raw_unchecked(idx: I) -> Self { - Self(Branded { - __marker: PhantomData, - inner: idx, - }) - } - - /// Returns the inner value. - #[inline] - pub fn into_inner(self) -> I { - self.0.inner - } - - /// Returns the index if available, or `None` otherwise. - #[inline] - pub fn idx(self) -> Option> { - if self.0.inner.to_signed() >= I::Signed::truncate(0) { - Some(unsafe { Idx::new_raw_unchecked(self.into_inner()) }) - } else { - None - } - } - - /// Unimplemented: Zero extend the value. - #[inline] - pub fn zx(self) -> ! { - unimplemented!() - } - - /// Sign extend the value. - #[inline] - pub fn sx(self) -> MaybeIdx<'n, usize> { - unsafe { MaybeIdx::new_raw_unchecked(self.0.inner.to_signed().sx()) } - } - - /// Assert that the values of `slice` are all bounded by `size`. - #[track_caller] - #[inline] - pub fn from_slice_mut_checked<'a>( - slice: &'a mut [I::Signed], - size: Size<'n>, - ) -> &'a mut [MaybeIdx<'n, I>] { - Self::from_slice_ref_checked(slice, size); - unsafe { &mut *(slice as *mut _ as *mut _) } - } - - /// Assume that the values of `slice` are all bounded by the value tied to `'n`. - #[track_caller] - #[inline] - pub unsafe fn from_slice_mut_unchecked<'a>( - slice: &'a mut [I::Signed], - ) -> &'a mut [MaybeIdx<'n, I>] { - unsafe { &mut *(slice as *mut _ as *mut _) } - } - - /// Assert that the values of `slice` are all bounded by `size`. - #[track_caller] - pub fn from_slice_ref_checked<'a>( - slice: &'a [I::Signed], - size: Size<'n>, - ) -> &'a [MaybeIdx<'n, I>] { - for &idx in slice { - Self::new_checked(idx, size); - } - unsafe { &*(slice as *const _ as *const _) } - } - - /// Convert a constrained slice to an unconstrained one. - #[track_caller] - pub fn as_slice_ref<'a>(slice: &'a [MaybeIdx<'n, I>]) -> &'a [I::Signed] { - unsafe { &*(slice as *const _ as *const _) } - } - - /// Assume that the values of `slice` are all bounded by the value tied to `'n`. - #[track_caller] - #[inline] - pub unsafe fn from_slice_ref_unchecked<'a>( - slice: &'a [I::Signed], - ) -> &'a [MaybeIdx<'n, I>] { - unsafe { &*(slice as *const _ as *const _) } - } - } - - impl<'n> IdxInclusive<'n, usize> { - /// Returns an iterator over constrained indices from `0` to `self` (exclusive). - #[inline] - pub fn range_to(self, last: Self) -> impl DoubleEndedIterator> { - (*self..*last).map( - #[inline(always)] - |idx| unsafe { Idx::new_raw_unchecked(idx) }, - ) - } - } - - impl<'n, I: Index> IdxInclusive<'n, I> { - /// Returns a constrained inclusive index after checking that it's bounded (inclusively) by - /// `size`. - #[inline] - pub fn new_checked(idx: I, size: Size<'n>) -> Self { - assert!(idx.zx() <= size.into_inner()); - Self(Branded { - __marker: PhantomData, - inner: idx, - }) - } - /// Returns a constrained inclusive index, assuming that it's bounded (inclusively) by - /// `size`. - #[inline] - pub unsafe fn new_unchecked(idx: I, size: Size<'n>) -> Self { - debug_assert!(idx.zx() <= size.into_inner()); - Self(Branded { - __marker: PhantomData, - inner: idx, - }) - } - - /// Returns a constrained inclusive index, assuming that it's bounded (inclusively) by - /// the size tied to `'n`. - #[inline] - pub unsafe fn new_raw_unchecked(idx: I) -> Self { - Self(Branded { - __marker: PhantomData, - inner: idx, - }) - } - - /// Returns the unconstrained value. - #[inline] - pub fn into_inner(self) -> I { - self.0.inner - } - - /// Unimplemented: Sign extend the value. - #[inline] - pub fn sx(self) -> ! { - unimplemented!() - } - /// Unimplemented: Zero extend the value. - #[inline] - pub fn zx(self) -> ! { - unimplemented!() - } - } - - impl<'n, T> Array<'n, T> { - /// Returns a constrained array after checking that its length matches `size`. - #[inline] - #[track_caller] - pub fn from_ref<'a>(slice: &'a [T], size: Size<'n>) -> &'a Self { - assert!(slice.len() == size.into_inner()); - unsafe { &*(slice as *const [T] as *const Self) } - } - - /// Returns a constrained array after checking that its length matches `size`. - #[inline] - #[track_caller] - pub fn from_mut<'a>(slice: &'a mut [T], size: Size<'n>) -> &'a mut Self { - assert!(slice.len() == size.into_inner()); - unsafe { &mut *(slice as *mut [T] as *mut Self) } - } - - /// Returns the unconstrained slice. - #[inline] - #[track_caller] - pub fn as_ref(&self) -> &[T] { - unsafe { &*(self as *const _ as *const _) } - } - - /// Returns the unconstrained slice. - #[inline] - #[track_caller] - pub fn as_mut<'a>(&mut self) -> &'a mut [T] { - unsafe { &mut *(self as *mut _ as *mut _) } - } - - /// Returns the length of `self`. - #[inline] - pub fn len(&self) -> Size<'n> { - unsafe { Size::new_raw_unchecked(self.0.inner.len()) } - } - } - - impl<'nrows, 'ncols, 'a, E: Entity> MatRef<'nrows, 'ncols, 'a, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - #[inline] - #[track_caller] - pub fn new(inner: super::MatRef<'a, E>, nrows: Size<'nrows>, ncols: Size<'ncols>) -> Self { - assert!(all( - inner.nrows() == nrows.into_inner(), - inner.ncols() == ncols.into_inner(), - )); - Self(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> Size<'nrows> { - unsafe { Size::new_raw_unchecked(self.0.inner.inner.nrows()) } - } - - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> Size<'ncols> { - unsafe { Size::new_raw_unchecked(self.0.inner.inner.ncols()) } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> super::MatRef<'a, E> { - self.0.inner.inner - } - - /// Returns the element at position `(i, j)`. - #[inline] - #[track_caller] - pub fn read(&self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>) -> E { - unsafe { - self.0 - .inner - .inner - .read_unchecked(i.into_inner(), j.into_inner()) - } - } - } - - impl<'nrows, 'ncols, 'a, E: Entity> MatMut<'nrows, 'ncols, 'a, E> { - /// Returns a new matrix view after checking that its dimensions match the - /// dimensions tied to `('nrows, 'ncols)`. - #[inline] - #[track_caller] - pub fn new(inner: super::MatMut<'a, E>, nrows: Size<'nrows>, ncols: Size<'ncols>) -> Self { - assert!(all( - inner.nrows() == nrows.into_inner(), - inner.ncols() == ncols.into_inner(), - )); - Self(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> Size<'nrows> { - unsafe { Size::new_raw_unchecked(self.0.inner.inner.nrows()) } - } - - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> Size<'ncols> { - unsafe { Size::new_raw_unchecked(self.0.inner.inner.ncols()) } - } - - /// Returns the unconstrained matrix. - #[inline] - pub fn into_inner(self) -> super::MatMut<'a, E> { - self.0.inner.inner - } - - /// Returns the element at position `(i, j)`. - #[inline] - #[track_caller] - pub fn read(&self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>) -> E { - unsafe { - self.0 - .inner - .inner - .read_unchecked(i.into_inner(), j.into_inner()) - } - } - - /// Writes `value` to the location at position `(i, j)`. - #[inline] - #[track_caller] - pub fn write(&mut self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>, value: E) { - unsafe { - self.0 - .inner - .inner - .write_unchecked(i.into_inner(), j.into_inner(), value) - }; - } - } - - impl Clone for MatRef<'_, '_, '_, E> { - #[inline] - fn clone(&self) -> Self { - *self - } - } - impl Copy for MatRef<'_, '_, '_, E> {} - - impl<'nrows, 'ncols, 'a, E: Entity> IntoConst for MatRef<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'a, E>; - #[inline] - fn into_const(self) -> Self::Target { - self - } - } - impl<'nrows, 'ncols, 'a, 'short, E: Entity> Reborrow<'short> for MatRef<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } - } - impl<'nrows, 'ncols, 'a, 'short, E: Entity> ReborrowMut<'short> for MatRef<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl<'nrows, 'ncols, 'a, E: Entity> IntoConst for MatMut<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'a, E>; - #[inline] - fn into_const(self) -> Self::Target { - let inner = self.0.inner.inner.into_const(); - MatRef(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - } - impl<'nrows, 'ncols, 'a, 'short, E: Entity> Reborrow<'short> for MatMut<'nrows, 'ncols, 'a, E> { - type Target = MatRef<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb(&'short self) -> Self::Target { - let inner = self.0.inner.inner.rb(); - MatRef(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - } - impl<'nrows, 'ncols, 'a, 'short, E: Entity> ReborrowMut<'short> for MatMut<'nrows, 'ncols, 'a, E> { - type Target = MatMut<'nrows, 'ncols, 'short, E>; - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - let inner = self.0.inner.inner.rb_mut(); - MatMut(Branded { - __marker: PhantomData, - inner: Branded { - __marker: PhantomData, - inner, - }, - }) - } - } - - impl Debug for Size<'_> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } - } - impl Debug for Idx<'_, I> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } - } - impl Debug for IdxInclusive<'_, I> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } - } - impl Debug for MaybeIdx<'_, I> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - #[derive(Debug)] - struct None; - - match self.idx() { - Some(idx) => idx.fmt(f), - Option::None => None.fmt(f), - } - } - } - impl Debug for Array<'_, T> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.fmt(f) - } - } - impl Debug for MatRef<'_, '_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.inner.fmt(f) - } - } - impl Debug for MatMut<'_, '_, '_, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.0.inner.inner.fmt(f) - } - } - - impl<'n, T> core::ops::Index>> for Array<'n, T> { - type Output = [T]; - #[track_caller] - fn index(&self, idx: Range>) -> &Self::Output { - #[cfg(debug_assertions)] - { - &self.0.inner[idx.start.into_inner()..idx.end.into_inner()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0 - .inner - .get_unchecked(idx.start.into_inner()..idx.end.into_inner()) - } - } - } - impl<'n, T> core::ops::IndexMut>> for Array<'n, T> { - #[track_caller] - fn index_mut(&mut self, idx: Range>) -> &mut Self::Output { - #[cfg(debug_assertions)] - { - &mut self.0.inner[idx.start.into_inner()..idx.end.into_inner()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0 - .inner - .get_unchecked_mut(idx.start.into_inner()..idx.end.into_inner()) - } - } - } - impl<'n, T> core::ops::Index> for Array<'n, T> { - type Output = T; - #[track_caller] - fn index(&self, idx: Idx<'n, usize>) -> &Self::Output { - #[cfg(debug_assertions)] - { - &self.0.inner[idx.into_inner()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0.inner.get_unchecked(idx.into_inner()) - } - } - } - impl<'n, T> core::ops::IndexMut> for Array<'n, T> { - #[track_caller] - fn index_mut(&mut self, idx: Idx<'n, usize>) -> &mut Self::Output { - #[cfg(debug_assertions)] - { - &mut self.0.inner[idx.into_inner()] - } - #[cfg(not(debug_assertions))] - unsafe { - self.0.inner.get_unchecked_mut(idx.into_inner()) - } - } - } -} - -/// Kronecker product of two matrices. -/// -/// The Kronecker product of two matrices `A` and `B` is a block matrix -/// `C` with the following structure: -/// -/// ```text -/// C = [ a[(0, 0)] * B , a[(0, 1)] * B , ... , a[(0, n-1)] * B ] -/// [ a[(1, 0)] * B , a[(1, 1)] * B , ... , a[(1, n-1)] * B ] -/// [ ... , ... , ... , ... ] -/// [ a[(m-1, 0)] * B , a[(m-1, 1)] * B , ... , a[(m-1, n-1)] * B ] -/// ``` -/// -/// # Panics -/// -/// Panics if `dst` does not have the correct dimensions. The dimensions -/// of `dst` must be `nrows(A) * nrows(B)` by `ncols(A) * ncols(B)`. -/// -/// # Example -/// -/// ``` -/// use faer_core::{kron, mat, Mat}; -/// -/// let a = mat![[1.0, 2.0], [3.0, 4.0]]; -/// let b = mat![[0.0, 5.0], [6.0, 7.0]]; -/// let c = mat![ -/// [0.0, 5.0, 0.0, 10.0], -/// [6.0, 7.0, 12.0, 14.0], -/// [0.0, 15.0, 0.0, 20.0], -/// [18.0, 21.0, 24.0, 28.0], -/// ]; -/// let mut dst = Mat::new(); -/// dst.resize_with(4, 4, |_, _| 0f64); -/// kron(dst.as_mut(), a.as_ref(), b.as_ref()); -/// assert_eq!(dst, c); -/// ``` -#[track_caller] -pub fn kron(dst: MatMut, lhs: MatRef, rhs: MatRef) { - let mut dst = dst; - let mut lhs = lhs; - let mut rhs = rhs; - if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() { - dst = dst.transpose_mut(); - lhs = lhs.transpose(); - rhs = rhs.transpose(); - } - - assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows())); - assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols())); - - for lhs_j in 0..lhs.ncols() { - for lhs_i in 0..lhs.nrows() { - let lhs_val = lhs.read(lhs_i, lhs_j); - let mut dst = dst.rb_mut().submatrix_mut( - lhs_i * rhs.nrows(), - lhs_j * rhs.ncols(), - rhs.nrows(), - rhs.ncols(), - ); - - for rhs_j in 0..rhs.ncols() { - for rhs_i in 0..rhs.nrows() { - // SAFETY: Bounds have been checked. - unsafe { - let rhs_val = rhs.read_unchecked(rhs_i, rhs_j); - dst.write_unchecked(rhs_i, rhs_j, lhs_val.faer_mul(rhs_val)); - } - } - } - } - } -} - -#[inline(always)] -fn norm_l2_with_simd_and_offset_prologue( - simd: S, - data: SliceGroup<'_, E>, - offset: pulp::Offset>, -) -> ( - SimdGroupFor, - SimdGroupFor, - SimdGroupFor, -) { - use group_helpers::*; - - let simd_real = SimdFor::::new(simd); - let simd = SimdFor::::new(simd); - let half_big = simd_real.splat(E::Real::faer_min_positive_sqrt_inv()); - let half_small = simd_real.splat(E::Real::faer_min_positive_sqrt()); - let zero = simd.splat(E::faer_zero()); - let zero_real = simd_real.splat(E::Real::faer_zero()); - - let (head, body, tail) = simd.as_aligned_simd(data, offset); - let (body2, body1) = body.as_arrays::<2>(); - - let mut acc0 = simd.abs2(head.read_or(zero)); - let mut acc1 = zero_real; - - let mut acc_small0 = simd.abs2(simd.scale_real(half_small, head.read_or(zero))); - let mut acc_small1 = zero_real; - - let mut acc_big0 = simd.abs2(simd.scale_real(half_big, head.read_or(zero))); - let mut acc_big1 = zero_real; - - for [x0, x1] in body2.into_ref_iter().map(RefGroup::unzip) { - let x0 = x0.get(); - let x1 = x1.get(); - acc0 = simd.abs2_add_e(x0, acc0); - acc1 = simd.abs2_add_e(x1, acc1); - - acc_small0 = simd.abs2_add_e(simd.scale_real(half_small, x0), acc_small0); - acc_small1 = simd.abs2_add_e(simd.scale_real(half_small, x1), acc_small1); - - acc_big0 = simd.abs2_add_e(simd.scale_real(half_big, x0), acc_big0); - acc_big1 = simd.abs2_add_e(simd.scale_real(half_big, x1), acc_big1); - } - - for x0 in body1.into_ref_iter() { - let x0 = x0.get(); - acc0 = simd.abs2_add_e(x0, acc0); - acc_small0 = simd.abs2_add_e(simd.scale_real(half_small, x0), acc_small0); - acc_big0 = simd.abs2_add_e(simd.scale_real(half_big, x0), acc_big0); - } - - acc0 = simd.abs2_add_e(tail.read_or(zero), acc0); - acc_small0 = simd.abs2_add_e(simd.scale_real(half_small, tail.read_or(zero)), acc_small0); - acc_big0 = simd.abs2_add_e(simd.scale_real(half_big, tail.read_or(zero)), acc_big0); - - acc0 = simd_real.add(acc0, acc1); - acc_small0 = simd_real.add(acc_small0, acc_small1); - acc_big0 = simd_real.add(acc_big0, acc_big1); - - (acc_small0, acc0, acc_big0) -} - -#[inline(always)] -fn sum_with_simd_and_offset_prologue( - simd: S, - data: SliceGroup<'_, E>, - offset: pulp::Offset>, -) -> SimdGroupFor { - use group_helpers::*; - - let simd = SimdFor::::new(simd); - - let zero = simd.splat(E::faer_zero()); - - let mut acc0 = zero; - let mut acc1 = zero; - let mut acc2 = zero; - let mut acc3 = zero; - let (head, body, tail) = simd.as_aligned_simd(data, offset); - let (body4, body1) = body.as_arrays::<4>(); - let head = head.read_or(zero); - acc0 = simd.add(acc0, head); - - for [x0, x1, x2, x3] in body4.into_ref_iter().map(RefGroup::unzip) { - let x0 = x0.get(); - let x1 = x1.get(); - let x2 = x2.get(); - let x3 = x3.get(); - acc0 = simd.add(acc0, x0); - acc1 = simd.add(acc1, x1); - acc2 = simd.add(acc2, x2); - acc3 = simd.add(acc3, x3); - } - - for x0 in body1.into_ref_iter() { - let x0 = x0.get(); - acc0 = simd.add(acc0, x0); - } - - let tail = tail.read_or(zero); - acc3 = simd.add(acc3, tail); - - acc0 = simd.add(acc0, acc1); - acc2 = simd.add(acc2, acc3); - simd.add(acc0, acc2) -} - -#[inline(always)] -fn norm_max_contiguous(data: MatRef<'_, E>) -> E { - struct Impl<'a, E: RealField> { - data: MatRef<'a, E>, - } - - impl pulp::WithSimd for Impl<'_, E> { - type Output = E; - - #[inline(always)] - fn with_simd(self, simd: S) -> Self::Output { - let Self { data } = self; - use group_helpers::*; - let m = data.nrows(); - let n = data.ncols(); - - let offset = SimdFor::::new(simd).align_offset_ptr(data.as_ptr(), m); - - let simd = SimdFor::::new(simd); - - let zero = simd.splat(E::faer_zero()); - - let mut acc0 = zero; - let mut acc1 = zero; - let mut acc2 = zero; - let mut acc3 = zero; - for j in 0..n { - let col = SliceGroup::<'_, E>::new(data.try_get_contiguous_col(j)); - let (head, body, tail) = simd.as_aligned_simd(col, offset); - let (body4, body1) = body.as_arrays::<4>(); - - let head = simd.abs(head.read_or(zero)); - acc0 = simd.select(simd.greater_than(head, acc0), head, acc0); - - for [x0, x1, x2, x3] in body4.into_ref_iter().map(RefGroup::unzip) { - let x0 = simd.abs(x0.get()); - let x1 = simd.abs(x1.get()); - let x2 = simd.abs(x2.get()); - let x3 = simd.abs(x3.get()); - acc0 = simd.select(simd.greater_than(x0, acc0), x0, acc0); - acc1 = simd.select(simd.greater_than(x1, acc1), x1, acc1); - acc2 = simd.select(simd.greater_than(x2, acc2), x2, acc2); - acc3 = simd.select(simd.greater_than(x3, acc3), x3, acc3); - } - - for x0 in body1.into_ref_iter() { - let x0 = simd.abs(x0.get()); - acc0 = simd.select(simd.greater_than(x0, acc0), x0, acc0); - } - - let tail = simd.abs(tail.read_or(zero)); - acc3 = simd.select(simd.greater_than(tail, acc3), tail, acc3); - } - acc0 = simd.select(simd.greater_than(acc0, acc1), acc0, acc1); - acc2 = simd.select(simd.greater_than(acc2, acc3), acc2, acc3); - acc0 = simd.select(simd.greater_than(acc0, acc2), acc0, acc2); - - let acc0 = from_copy::(simd.rotate_left(acc0, offset.rotate_left_amount())); - let acc = SliceGroup::<'_, E>::new(E::faer_map( - E::faer_as_ref(&acc0), - #[inline(always)] - |acc| bytemuck::cast_slice::<_, ::Unit>(core::slice::from_ref(acc)), - )); - let mut acc_scalar = E::faer_zero(); - for x in acc.into_ref_iter() { - let x = x.read(); - acc_scalar = if acc_scalar > x { acc_scalar } else { x }; - } - acc_scalar - } - } - - E::Simd::default().dispatch(Impl { data }) -} - -const NORM_L2_THRESHOLD: usize = 128; - -#[inline(always)] -fn norm_l2_with_simd_and_offset_pairwise_rows( - simd: S, - data: SliceGroup<'_, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, -) -> ( - SimdGroupFor, - SimdGroupFor, - SimdGroupFor, -) { - struct Impl<'a, E: ComplexField, S: Simd> { - simd: S, - data: SliceGroup<'a, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, - } - - impl pulp::NullaryFnOnce for Impl<'_, E, S> { - type Output = ( - SimdGroupFor, - SimdGroupFor, - SimdGroupFor, - ); - - #[inline(always)] - fn call(self) -> Self::Output { - let Self { - simd, - data, - offset, - last_offset, - } = self; - - if data.len() == NORM_L2_THRESHOLD { - norm_l2_with_simd_and_offset_prologue(simd, data, offset) - } else if data.len() < NORM_L2_THRESHOLD { - norm_l2_with_simd_and_offset_prologue(simd, data, last_offset) - } else { - let split_point = ((data.len() + 1) / 2).next_power_of_two(); - let (head, tail) = data.split_at(split_point); - let (acc_small0, acc0, acc_big0) = - norm_l2_with_simd_and_offset_pairwise_rows(simd, head, offset, last_offset); - let (acc_small1, acc1, acc_big1) = - norm_l2_with_simd_and_offset_pairwise_rows(simd, tail, offset, last_offset); - - use group_helpers::*; - let simd = SimdFor::::new(simd); - ( - simd.add(acc_small0, acc_small1), - simd.add(acc0, acc1), - simd.add(acc_big0, acc_big1), - ) - } - } - } - - simd.vectorize(Impl { - simd, - data, - offset, - last_offset, - }) -} - -#[inline(always)] -fn sum_with_simd_and_offset_pairwise_rows( - simd: S, - data: SliceGroup<'_, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, -) -> SimdGroupFor { - struct Impl<'a, E: ComplexField, S: Simd> { - simd: S, - data: SliceGroup<'a, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, - } - - impl pulp::NullaryFnOnce for Impl<'_, E, S> { - type Output = SimdGroupFor; - - #[inline(always)] - fn call(self) -> Self::Output { - let Self { - simd, - data, - offset, - last_offset, - } = self; - - if data.len() == NORM_L2_THRESHOLD { - sum_with_simd_and_offset_prologue(simd, data, offset) - } else if data.len() < NORM_L2_THRESHOLD { - sum_with_simd_and_offset_prologue(simd, data, last_offset) - } else { - let split_point = ((data.len() + 1) / 2).next_power_of_two(); - let (head, tail) = data.split_at(split_point); - let acc0 = sum_with_simd_and_offset_pairwise_rows(simd, head, offset, last_offset); - let acc1 = sum_with_simd_and_offset_pairwise_rows(simd, tail, offset, last_offset); - - use group_helpers::*; - let simd = SimdFor::::new(simd); - simd.add(acc0, acc1) - } - } - } - - simd.vectorize(Impl { - simd, - data, - offset, - last_offset, - }) -} - -#[inline(always)] -fn norm_l2_with_simd_and_offset_pairwise_cols( - simd: S, - data: MatRef<'_, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, -) -> ( - SimdGroupFor, - SimdGroupFor, - SimdGroupFor, -) { - struct Impl<'a, E: ComplexField, S: Simd> { - simd: S, - data: MatRef<'a, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, - } - - impl pulp::NullaryFnOnce for Impl<'_, E, S> { - type Output = ( - SimdGroupFor, - SimdGroupFor, - SimdGroupFor, - ); - - #[inline(always)] - fn call(self) -> Self::Output { - use group_helpers::*; - - let Self { - simd, - data, - offset, - last_offset, - } = self; - if data.ncols() == 1 { - norm_l2_with_simd_and_offset_pairwise_rows( - simd, - SliceGroup::<'_, E>::new(data.try_get_contiguous_col(0)), - offset, - last_offset, - ) - } else { - let split_point = (data.ncols() / 2).next_power_of_two(); - - let (head, tail) = data.split_at_col(split_point); - - let (acc_small0, acc0, acc_big0) = - norm_l2_with_simd_and_offset_pairwise_cols(simd, head, offset, last_offset); - let (acc_small1, acc1, acc_big1) = - norm_l2_with_simd_and_offset_pairwise_cols(simd, tail, offset, last_offset); - - let simd = SimdFor::::new(simd); - ( - simd.add(acc_small0, acc_small1), - simd.add(acc0, acc1), - simd.add(acc_big0, acc_big1), - ) - } - } - } - - simd.vectorize(Impl { - simd, - data, - offset, - last_offset, - }) -} - -#[inline(always)] -fn sum_with_simd_and_offset_pairwise_cols( - simd: S, - data: MatRef<'_, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, -) -> SimdGroupFor { - struct Impl<'a, E: ComplexField, S: Simd> { - simd: S, - data: MatRef<'a, E>, - offset: pulp::Offset>, - last_offset: pulp::Offset>, - } - - impl pulp::NullaryFnOnce for Impl<'_, E, S> { - type Output = SimdGroupFor; - - #[inline(always)] - fn call(self) -> Self::Output { - use group_helpers::*; - - let Self { - simd, - data, - offset, - last_offset, - } = self; - if data.ncols() == 1 { - sum_with_simd_and_offset_pairwise_rows( - simd, - SliceGroup::<'_, E>::new(data.try_get_contiguous_col(0)), - offset, - last_offset, - ) - } else { - let split_point = (data.ncols() / 2).next_power_of_two(); - - let (head, tail) = data.split_at_col(split_point); - - let acc0 = sum_with_simd_and_offset_pairwise_cols(simd, head, offset, last_offset); - let acc1 = sum_with_simd_and_offset_pairwise_cols(simd, tail, offset, last_offset); - - let simd = SimdFor::::new(simd); - simd.add(acc0, acc1) - } - } - } - - simd.vectorize(Impl { - simd, - data, - offset, - last_offset, - }) -} - -fn norm_l2_contiguous(data: MatRef<'_, E>) -> (E::Real, E::Real, E::Real) { - struct Impl<'a, E: ComplexField> { - data: MatRef<'a, E>, - } - - impl pulp::WithSimd for Impl<'_, E> { - type Output = (E::Real, E::Real, E::Real); - - #[inline(always)] - fn with_simd(self, simd: S) -> Self::Output { - let Self { data } = self; - use group_helpers::*; - - let offset = - SimdFor::::new(simd).align_offset_ptr(data.as_ptr(), NORM_L2_THRESHOLD); - - let last_offset = SimdFor::::new(simd) - .align_offset_ptr(data.as_ptr(), data.nrows() % NORM_L2_THRESHOLD); - - let (acc_small, acc, acc_big) = - norm_l2_with_simd_and_offset_pairwise_cols(simd, data, offset, last_offset); - - let simd = SimdFor::::new(simd); - ( - simd.reduce_add(simd.rotate_left(acc_small, offset.rotate_left_amount())), - simd.reduce_add(simd.rotate_left(acc, offset.rotate_left_amount())), - simd.reduce_add(simd.rotate_left(acc_big, offset.rotate_left_amount())), - ) - } - } - - E::Simd::default().dispatch(Impl { data }) -} - -fn sum_contiguous(data: MatRef<'_, E>) -> E { - struct Impl<'a, E: ComplexField> { - data: MatRef<'a, E>, - } - - impl pulp::WithSimd for Impl<'_, E> { - type Output = E; - - #[inline(always)] - fn with_simd(self, simd: S) -> Self::Output { - let Self { data } = self; - use group_helpers::*; - - let offset = - SimdFor::::new(simd).align_offset_ptr(data.as_ptr(), NORM_L2_THRESHOLD); - - let last_offset = SimdFor::::new(simd) - .align_offset_ptr(data.as_ptr(), data.nrows() % NORM_L2_THRESHOLD); - - let acc = sum_with_simd_and_offset_pairwise_cols(simd, data, offset, last_offset); - - let simd = SimdFor::::new(simd); - simd.reduce_add(simd.rotate_left(acc, offset.rotate_left_amount())) - } - } - - E::Simd::default().dispatch(Impl { data }) -} - -fn norm_l2(mut mat: MatRef<'_, E>) -> E::Real { - if mat.ncols() > 1 && mat.col_stride().unsigned_abs() < mat.row_stride().unsigned_abs() { - mat = mat.transpose(); - } - if mat.row_stride() < 0 { - mat = mat.reverse_rows(); - } - - if mat.nrows() == 0 || mat.ncols() == 0 { - E::Real::faer_zero() - } else { - let m = mat.nrows(); - let n = mat.ncols(); - - let half_small = E::Real::faer_min_positive_sqrt(); - let half_big = E::Real::faer_min_positive_sqrt_inv(); - - let mut acc_small = E::Real::faer_zero(); - let mut acc = E::Real::faer_zero(); - let mut acc_big = E::Real::faer_zero(); - - if mat.row_stride() == 1 { - if coe::is_same::() { - let mat: MatRef<'_, c32> = coe::coerce(mat); - let mat = unsafe { - mat::from_raw_parts( - mat.as_ptr() as *const f32, - 2 * mat.nrows(), - mat.ncols(), - 1, - 2 * mat.col_stride(), - ) - }; - let (acc_small_, acc_, acc_big_) = norm_l2_contiguous::(mat); - acc_small = coe::coerce_static(acc_small_); - acc = coe::coerce_static(acc_); - acc_big = coe::coerce_static(acc_big_); - } else if coe::is_same::() { - let mat: MatRef<'_, c64> = coe::coerce(mat); - let mat = unsafe { - mat::from_raw_parts( - mat.as_ptr() as *const f64, - 2 * mat.nrows(), - mat.ncols(), - 1, - 2 * mat.col_stride(), - ) - }; - let (acc_small_, acc_, acc_big_) = norm_l2_contiguous::(mat); - acc_small = coe::coerce_static(acc_small_); - acc = coe::coerce_static(acc_); - acc_big = coe::coerce_static(acc_big_); - } else { - (acc_small, acc, acc_big) = norm_l2_contiguous(mat); - } - } else { - for j in 0..n { - for i in 0..m { - let val = mat.read(i, j); - let val_small = val.faer_scale_power_of_two(half_small); - let val_big = val.faer_scale_power_of_two(half_big); - - acc_small = acc_small.faer_add(val_small.faer_abs2()); - acc = acc.faer_add(val.faer_abs2()); - acc_big = acc_big.faer_add(val_big.faer_abs2()); - } - } - } - - if acc_small >= E::Real::faer_one() { - acc_small.faer_sqrt().faer_mul(half_big) - } else if acc_big <= E::Real::faer_one() { - acc_big.faer_sqrt().faer_mul(half_small) - } else { - acc.faer_sqrt() - } - } -} - -fn sum(mut mat: MatRef<'_, E>) -> E { - if mat.ncols() > 1 && mat.col_stride().unsigned_abs() < mat.row_stride().unsigned_abs() { - mat = mat.transpose(); - } - if mat.row_stride() < 0 { - mat = mat.reverse_rows(); - } - - if mat.nrows() == 0 || mat.ncols() == 0 { - E::faer_zero() - } else { - let m = mat.nrows(); - let n = mat.ncols(); - - let mut acc = E::faer_zero(); - - if mat.row_stride() == 1 { - acc = sum_contiguous(mat); - } else { - for j in 0..n { - for i in 0..m { - acc = acc.faer_add(mat.read(i, j)); - } - } - } - - acc - } -} -fn norm_max(mut mat: MatRef<'_, E>) -> E::Real { - if mat.ncols() > 1 && mat.col_stride().unsigned_abs() < mat.row_stride().unsigned_abs() { - mat = mat.transpose(); - } - if mat.row_stride() < 0 { - mat = mat.reverse_rows(); - } - - if mat.nrows() == 0 || mat.ncols() == 0 { - E::Real::faer_zero() - } else { - let m = mat.nrows(); - let n = mat.ncols(); - - if mat.row_stride() == 1 { - if coe::is_same::() { - let mat: MatRef<'_, c32> = coe::coerce(mat); - let mat = unsafe { - mat::from_raw_parts( - mat.as_ptr() as *const f32, - 2 * mat.nrows(), - mat.ncols(), - 1, - 2 * mat.col_stride(), - ) - }; - return coe::coerce_static(norm_max_contiguous::(mat)); - } else if coe::is_same::() { - let mat: MatRef<'_, c64> = coe::coerce(mat); - let mat = unsafe { - mat::from_raw_parts( - mat.as_ptr() as *const f64, - 2 * mat.nrows(), - mat.ncols(), - 1, - 2 * mat.col_stride(), - ) - }; - return coe::coerce_static(norm_max_contiguous::(mat)); - } else if coe::is_same::>() { - let mat: MatRef<'_, num_complex::Complex> = coe::coerce(mat); - let num_complex::Complex { re, im } = mat.real_imag(); - let re = norm_max_contiguous(re); - let im = norm_max_contiguous(im); - return if re > im { re } else { im }; - } else if coe::is_same::() { - let mat: MatRef<'_, E::Real> = coe::coerce(mat); - return norm_max_contiguous(mat); - } - } - - let mut acc = E::Real::faer_zero(); - for j in 0..n { - for i in 0..m { - let val = mat.read(i, j); - let re = val.faer_real(); - let im = val.faer_imag(); - acc = if re > acc { re } else { acc }; - acc = if im > acc { im } else { acc }; - } - } - acc - } -} - -/// Matrix view creation module. -pub mod mat { - use super::*; - - /// Creates a `MatRef` from pointers to the matrix data, dimensions, and strides. - /// - /// The row (resp. column) stride is the offset from the memory address of a given matrix - /// element at indices `(row: i, col: j)`, to the memory address of the matrix element at - /// indices `(row: i + 1, col: 0)` (resp. `(row: 0, col: i + 1)`). This offset is specified in - /// number of elements, not in bytes. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * For each matrix unit, the entire memory region addressed by the matrix must be contained - /// within a single allocation, accessible in its entirety by the corresponding pointer in - /// `ptr`. - /// * For each matrix unit, the corresponding pointer must be properly aligned, - /// even for a zero-sized matrix. - /// * The values accessible by the matrix must be initialized at some point before they are - /// read, or references to them are formed. - /// * No mutable aliasing is allowed. In other words, none of the elements accessible by any - /// matrix unit may be accessed for writes by any other means for the duration of the lifetime - /// `'a`. - /// - /// # Example - /// - /// ``` - /// use faer_core::mat; - /// - /// // row major matrix with 2 rows, 3 columns, with a column at the end that we want to skip. - /// // the row stride is the pointer offset from the address of 1.0 to the address of 4.0, - /// // which is 4. - /// // the column stride is the pointer offset from the address of 1.0 to the address of 2.0, - /// // which is 1. - /// let data = [[1.0, 2.0, 3.0, f64::NAN], [4.0, 5.0, 6.0, f64::NAN]]; - /// let matrix = unsafe { mat::from_raw_parts::(data.as_ptr() as *const f64, 2, 3, 4, 1) }; - /// - /// let expected = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// assert_eq!(expected.as_ref(), matrix); - /// ``` - #[inline(always)] - pub unsafe fn from_raw_parts<'a, E: Entity>( - ptr: GroupFor, - nrows: usize, - ncols: usize, - row_stride: isize, - col_stride: isize, - ) -> MatRef<'a, E> { - MatRef { - inner: inner::DenseRef { - inner: MatImpl { - ptr: into_copy::(E::faer_map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), - nrows, - ncols, - row_stride, - col_stride, - }, - __marker: PhantomData, - }, - } - } - - /// Creates a `MatMut` from pointers to the matrix data, dimensions, and strides. - /// - /// The row (resp. column) stride is the offset from the memory address of a given matrix - /// element at indices `(row: i, col: j)`, to the memory address of the matrix element at - /// indices `(row: i + 1, col: 0)` (resp. `(row: 0, col: i + 1)`). This offset is specified in - /// number of elements, not in bytes. - /// - /// # Safety - /// The behavior is undefined if any of the following conditions are violated: - /// * For each matrix unit, the entire memory region addressed by the matrix must be contained - /// within a single allocation, accessible in its entirety by the corresponding pointer in - /// `ptr`. - /// * For each matrix unit, the corresponding pointer must be non null and properly aligned, - /// even for a zero-sized matrix. - /// * The values accessible by the matrix must be initialized at some point before they are - /// read, or - /// references to them are formed. - /// * No aliasing (including self aliasing) is allowed. In other words, none of the elements - /// accessible by any matrix unit may be accessed for reads or writes by any other means for - /// the duration of the lifetime `'a`. No two elements within a single matrix unit may point to - /// the same address (such a thing can be achieved with a zero stride, for example), and no two - /// matrix units may point to the same address. - /// - /// # Example - /// - /// ``` - /// use faer_core::mat; - /// - /// // row major matrix with 2 rows, 3 columns, with a column at the end that we want to skip. - /// // the row stride is the pointer offset from the address of 1.0 to the address of 4.0, - /// // which is 4. - /// // the column stride is the pointer offset from the address of 1.0 to the address of 2.0, - /// // which is 1. - /// let mut data = [[1.0, 2.0, 3.0, f64::NAN], [4.0, 5.0, 6.0, f64::NAN]]; - /// let mut matrix = - /// unsafe { mat::from_raw_parts_mut::(data.as_mut_ptr() as *mut f64, 2, 3, 4, 1) }; - /// - /// let expected = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - /// assert_eq!(expected.as_ref(), matrix); - /// ``` - #[inline(always)] - pub unsafe fn from_raw_parts_mut<'a, E: Entity>( - ptr: GroupFor, - nrows: usize, - ncols: usize, - row_stride: isize, - col_stride: isize, - ) -> MatMut<'a, E> { - MatMut { - inner: inner::DenseMut { - inner: MatImpl { - ptr: into_copy::(E::faer_map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), - nrows, - ncols, - row_stride, - col_stride, - }, - __marker: PhantomData, - }, - } - } - - /// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions. - /// The data is interpreted in a column-major format, so that the first chunk of `nrows` - /// values from the slices goes in the first column of the matrix, the second chunk of `nrows` - /// values goes in the second column, and so on. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `nrows * ncols == slice.len()` - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; - /// let view = mat::from_column_major_slice::(&slice, 3, 2); - /// - /// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; - /// assert_eq!(expected, view); - /// ``` - #[track_caller] - #[inline(always)] - pub fn from_column_major_slice<'a, E: Entity>( - slice: GroupFor, - nrows: usize, - ncols: usize, - ) -> MatRef<'a, E> { - from_slice_assert( - nrows, - ncols, - SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(), - ); - - unsafe { - from_raw_parts( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_ptr(), - ), - nrows, - ncols, - 1, - nrows as isize, - ) - } - } - - /// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions. - /// The data is interpreted in a row-major format, so that the first chunk of `ncols` - /// values from the slices goes in the first column of the matrix, the second chunk of `ncols` - /// values goes in the second column, and so on. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `nrows * ncols == slice.len()` - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; - /// let view = mat::from_row_major_slice::(&slice, 3, 2); - /// - /// let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; - /// assert_eq!(expected, view); - /// ``` - #[track_caller] - #[inline(always)] - pub fn from_row_major_slice<'a, E: Entity>( - slice: GroupFor, - nrows: usize, - ncols: usize, - ) -> MatRef<'a, E> { - from_column_major_slice(slice, ncols, nrows).transpose() - } - - /// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions. - /// The data is interpreted in a column-major format, where the beginnings of two consecutive - /// columns are separated by `col_stride` elements. - #[track_caller] - pub fn from_column_major_slice_with_stride<'a, E: Entity>( - slice: GroupFor, - nrows: usize, - ncols: usize, - col_stride: usize, - ) -> MatRef<'a, E> { - from_strided_column_major_slice_assert( - nrows, - ncols, - col_stride, - SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(), - ); - - unsafe { - from_raw_parts( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_ptr(), - ), - nrows, - ncols, - 1, - col_stride as isize, - ) - } - } - - /// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions. - /// The data is interpreted in a column-major format, so that the first chunk of `nrows` - /// values from the slices goes in the first column of the matrix, the second chunk of `nrows` - /// values goes in the second column, and so on. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `nrows * ncols == slice.len()` - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; - /// let view = mat::from_column_major_slice_mut::(&mut slice, 3, 2); - /// - /// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; - /// assert_eq!(expected, view); - /// ``` - #[track_caller] - pub fn from_column_major_slice_mut<'a, E: Entity>( - slice: GroupFor, - nrows: usize, - ncols: usize, - ) -> MatMut<'a, E> { - from_slice_assert( - nrows, - ncols, - SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(), - ); - unsafe { - from_raw_parts_mut( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_mut_ptr(), - ), - nrows, - ncols, - 1, - nrows as isize, - ) - } - } - - /// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions. - /// The data is interpreted in a row-major format, so that the first chunk of `ncols` - /// values from the slices goes in the first column of the matrix, the second chunk of `ncols` - /// values goes in the second column, and so on. - /// - /// # Panics - /// The function panics if any of the following conditions are violated: - /// * `nrows * ncols == slice.len()` - /// - /// # Example - /// ``` - /// use faer_core::mat; - /// - /// let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; - /// let view = mat::from_row_major_slice_mut::(&mut slice, 3, 2); - /// - /// let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; - /// assert_eq!(expected, view); - /// ``` - #[inline(always)] - #[track_caller] - pub fn from_row_major_slice_mut<'a, E: Entity>( - slice: GroupFor, - nrows: usize, - ncols: usize, - ) -> MatMut<'a, E> { - from_column_major_slice_mut(slice, ncols, nrows).transpose_mut() - } - - /// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions. - /// The data is interpreted in a column-major format, where the beginnings of two consecutive - /// columns are separated by `col_stride` elements. - #[track_caller] - pub fn from_column_major_slice_with_stride_mut<'a, E: Entity>( - slice: GroupFor, - nrows: usize, - ncols: usize, - col_stride: usize, - ) -> MatMut<'a, E> { - from_strided_column_major_slice_mut_assert( - nrows, - ncols, - col_stride, - SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(), - ); - unsafe { - from_raw_parts_mut( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_mut_ptr(), - ), - nrows, - ncols, - 1, - col_stride as isize, - ) - } - } -} - -/// Column view creation module. -pub mod col { - use super::*; - - /// Creates a `ColRef` from pointers to the column vector data, number of rows, and row stride. - /// - /// # Safety: - /// This function has the same safety requirements as - /// [`mat::from_raw_parts(ptr, nrows, 1, row_stride, 0)`] - #[inline(always)] - pub unsafe fn from_raw_parts<'a, E: Entity>( - ptr: GroupFor, - nrows: usize, - row_stride: isize, - ) -> ColRef<'a, E> { - ColRef { - inner: inner::DenseColRef { - inner: VecImpl { - ptr: into_copy::(E::faer_map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), - len: nrows, - stride: row_stride, - }, - __marker: PhantomData, - }, - } - } - - /// Creates a `ColMut` from pointers to the column vector data, number of rows, and row stride. - /// - /// # Safety: - /// This function has the same safety requirements as - /// [`mat::from_raw_parts_mut(ptr, nrows, 1, row_stride, 0)`] - #[inline(always)] - pub unsafe fn from_raw_parts_mut<'a, E: Entity>( - ptr: GroupFor, - nrows: usize, - row_stride: isize, - ) -> ColMut<'a, E> { - ColMut { - inner: inner::DenseColMut { - inner: VecImpl { - ptr: into_copy::(E::faer_map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), - len: nrows, - stride: row_stride, - }, - __marker: PhantomData, - }, - } - } - - /// Creates a `ColRef` from slice views over the column vector data, The result has the same - /// number of rows as the length of the input slice. - #[inline(always)] - pub fn from_slice<'a, E: Entity>(slice: GroupFor) -> ColRef<'a, E> { - let nrows = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(); - - unsafe { - from_raw_parts( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_ptr(), - ), - nrows, - 1, - ) - } - } - - /// Creates a `ColMut` from slice views over the column vector data, The result has the same - /// number of rows as the length of the input slice. - #[inline(always)] - pub fn from_slice_mut<'a, E: Entity>(slice: GroupFor) -> ColMut<'a, E> { - let nrows = SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(); - - unsafe { - from_raw_parts_mut( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_mut_ptr(), - ), - nrows, - 1, - ) - } - } -} - -/// Row view creation module. -pub mod row { - use super::*; - - /// Creates a `RowRef` from pointers to the row vector data, number of columns, and column - /// stride. - /// - /// # Safety: - /// This function has the same safety requirements as - /// [`mat::from_raw_parts(ptr, 1, ncols, 0, col_stride)`] - #[inline(always)] - pub unsafe fn from_raw_parts<'a, E: Entity>( - ptr: GroupFor, - ncols: usize, - col_stride: isize, - ) -> RowRef<'a, E> { - RowRef { - inner: inner::DenseRowRef { - inner: VecImpl { - ptr: into_copy::(E::faer_map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), - len: ncols, - stride: col_stride, - }, - __marker: PhantomData, - }, - } - } - - /// Creates a `RowMut` from pointers to the row vector data, number of columns, and column - /// stride. - /// - /// # Safety: - /// This function has the same safety requirements as - /// [`mat::from_raw_parts_mut(ptr, 1, ncols, 0, col_stride)`] - #[inline(always)] - pub unsafe fn from_raw_parts_mut<'a, E: Entity>( - ptr: GroupFor, - ncols: usize, - col_stride: isize, - ) -> RowMut<'a, E> { - RowMut { - inner: inner::DenseRowMut { - inner: VecImpl { - ptr: into_copy::(E::faer_map(ptr, |ptr| { - NonNull::new_unchecked(ptr as *mut E::Unit) - })), - len: ncols, - stride: col_stride, - }, - __marker: PhantomData, - }, - } - } - - /// Creates a `RowRef` from slice views over the row vector data, The result has the same - /// number of columns as the length of the input slice. - #[inline(always)] - pub fn from_slice<'a, E: Entity>(slice: GroupFor) -> RowRef<'a, E> { - let nrows = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(); - - unsafe { - from_raw_parts( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_ptr(), - ), - nrows, - 1, - ) - } - } - - /// Creates a `RowMut` from slice views over the row vector data, The result has the same - /// number of columns as the length of the input slice. - #[inline(always)] - pub fn from_slice_mut<'a, E: Entity>(slice: GroupFor) -> RowMut<'a, E> { - let nrows = SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(); - - unsafe { - from_raw_parts_mut( - E::faer_map( - slice, - #[inline(always)] - |slice| slice.as_mut_ptr(), - ), - nrows, - 1, - ) - } - } -} - -/// Convenience function to concatonate a nested list of matrices into a single -/// big ['Mat']. Concatonation pattern follows the numpy.block convention that -/// each sub-list must have an equal number of columns (net) but the boundaries -/// do not need to align. In other words, this sort of thing: -/// ```notcode -/// AAAbb -/// AAAbb -/// cDDDD -/// ``` -/// is perfectly acceptable. -#[doc(hidden)] -#[track_caller] -pub fn __concat_impl(blocks: &[&[MatRef<'_, E>]]) -> Mat { - #[inline(always)] - fn count_total_columns(block_row: &[MatRef<'_, E>]) -> usize { - let mut out: usize = 0; - for elem in block_row.iter() { - out += elem.ncols(); - } - out - } - - #[inline(always)] - #[track_caller] - fn count_rows(block_row: &[MatRef<'_, E>]) -> usize { - let mut out: usize = 0; - for (i, e) in block_row.iter().enumerate() { - if i.eq(&0) { - out = e.nrows(); - } else { - assert!(e.nrows().eq(&out)); - } - } - out - } - - // get size of result while doing checks - let mut n: usize = 0; - let mut m: usize = 0; - for row in blocks.iter() { - n += count_rows(row); - } - for (i, row) in blocks.iter().enumerate() { - let cols = count_total_columns(row); - if i.eq(&0) { - m = cols; - } else { - assert!(cols.eq(&m)); - } - } - - let mut mat = Mat::::zeros(n, m); - let mut ni: usize = 0; - let mut mj: usize; - for row in blocks.iter() { - mj = 0; - - for elem in row.iter() { - mat.as_mut() - .submatrix_mut(ni, mj, elem.nrows(), elem.ncols()) - .copy_from(elem); - mj += elem.ncols(); - } - ni += row[0].nrows(); - } - - mat -} - -#[cfg(test)] -mod tests { - macro_rules! impl_unit_entity { - ($ty: ty) => { - unsafe impl Entity for $ty { - type Unit = Self; - type Index = (); - type SimdUnit = (); - type SimdMask = (); - type SimdIndex = (); - type Group = IdentityGroup; - type Iter = I; - - type PrefixUnit<'a, S: Simd> = &'a [()]; - type SuffixUnit<'a, S: Simd> = &'a [()]; - type PrefixMutUnit<'a, S: Simd> = &'a mut [()]; - type SuffixMutUnit<'a, S: Simd> = &'a mut [()]; - - const N_COMPONENTS: usize = 1; - const UNIT: GroupCopyFor = (); - - #[inline(always)] - fn faer_first(group: GroupFor) -> T { - group - } - - #[inline(always)] - fn faer_from_units(group: GroupFor) -> Self { - group - } - - #[inline(always)] - fn faer_into_units(self) -> GroupFor { - self - } - - #[inline(always)] - fn faer_as_ref(group: &GroupFor) -> GroupFor { - group - } - - #[inline(always)] - fn faer_as_mut(group: &mut GroupFor) -> GroupFor { - group - } - - #[inline(always)] - fn faer_as_ptr(group: *mut GroupFor) -> GroupFor { - group - } - - #[inline(always)] - fn faer_map_impl( - group: GroupFor, - f: &mut impl FnMut(T) -> U, - ) -> GroupFor { - (*f)(group) - } - - #[inline(always)] - fn faer_map_with_context( - ctx: Ctx, - group: GroupFor, - f: &mut impl FnMut(Ctx, T) -> (Ctx, U), - ) -> (Ctx, GroupFor) { - (*f)(ctx, group) - } - - #[inline(always)] - fn faer_zip( - first: GroupFor, - second: GroupFor, - ) -> GroupFor { - (first, second) - } - #[inline(always)] - fn faer_unzip( - zipped: GroupFor, - ) -> (GroupFor, GroupFor) { - zipped - } - - #[inline(always)] - fn faer_into_iter( - iter: GroupFor, - ) -> Self::Iter { - iter.into_iter() - } - } - }; - } - - use super::*; - use crate::assert; - - #[test] - fn basic_slice() { - let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; - let slice = unsafe { mat::from_raw_parts::<'_, f64>(data.as_ptr(), 2, 3, 3, 1) }; - - assert!(slice.get(0, 0) == &1.0); - assert!(slice.get(0, 1) == &2.0); - assert!(slice.get(0, 2) == &3.0); - - assert!(slice.get(1, 0) == &4.0); - assert!(slice.get(1, 1) == &5.0); - assert!(slice.get(1, 2) == &6.0); - } - - #[test] - fn empty() { - { - let m = Mat::::new(); - assert!(m.nrows() == 0); - assert!(m.ncols() == 0); - assert!(m.row_capacity() == 0); - assert!(m.col_capacity() == 0); - } - - { - let m = Mat::::with_capacity(100, 120); - assert!(m.nrows() == 0); - assert!(m.ncols() == 0); - assert!(m.row_capacity() == 100); - assert!(m.col_capacity() == 120); - } - } - - #[test] - fn reserve() { - let mut m = Mat::::new(); - - m.reserve_exact(0, 0); - assert!(m.row_capacity() == 0); - assert!(m.col_capacity() == 0); - - m.reserve_exact(1, 1); - assert!(m.row_capacity() >= 1); - assert!(m.col_capacity() == 1); - - m.reserve_exact(2, 0); - assert!(m.row_capacity() >= 2); - assert!(m.col_capacity() == 1); - - m.reserve_exact(2, 3); - assert!(m.row_capacity() >= 2); - assert!(m.col_capacity() == 3); - } - - #[derive(Debug, PartialEq, Clone, Copy)] - struct Zst; - unsafe impl bytemuck::Zeroable for Zst {} - unsafe impl bytemuck::Pod for Zst {} - - #[test] - fn reserve_zst() { - impl_unit_entity!(Zst); - - let mut m = Mat::::new(); - - m.reserve_exact(0, 0); - assert!(m.row_capacity() == 0); - assert!(m.col_capacity() == 0); - - m.reserve_exact(1, 1); - assert!(m.row_capacity() == 1); - assert!(m.col_capacity() == 1); - - m.reserve_exact(2, 0); - assert!(m.row_capacity() == 2); - assert!(m.col_capacity() == 1); - - m.reserve_exact(2, 3); - assert!(m.row_capacity() == 2); - assert!(m.col_capacity() == 3); - - m.reserve_exact(usize::MAX, usize::MAX); - } - - #[test] - fn resize() { - let mut m = Mat::new(); - let f = |i, j| i as f64 - j as f64; - m.resize_with(2, 3, f); - assert!(m.read(0, 0) == 0.0); - assert!(m.read(0, 1) == -1.0); - assert!(m.read(0, 2) == -2.0); - assert!(m.read(1, 0) == 1.0); - assert!(m.read(1, 1) == 0.0); - assert!(m.read(1, 2) == -1.0); - - m.resize_with(1, 2, f); - assert!(m.read(0, 0) == 0.0); - assert!(m.read(0, 1) == -1.0); - - m.resize_with(2, 1, f); - assert!(m.read(0, 0) == 0.0); - assert!(m.read(1, 0) == 1.0); - - m.resize_with(1, 2, f); - assert!(m.read(0, 0) == 0.0); - assert!(m.read(0, 1) == -1.0); - } - - #[test] - fn resize_zst() { - // miri test - let mut m = Mat::new(); - let f = |_i, _j| Zst; - m.resize_with(2, 3, f); - m.resize_with(1, 2, f); - m.resize_with(2, 1, f); - m.resize_with(1, 2, f); - } - - #[test] - #[should_panic] - fn cap_overflow_1() { - let _ = Mat::::with_capacity(isize::MAX as usize, 1); - } - - #[test] - #[should_panic] - fn cap_overflow_2() { - let _ = Mat::::with_capacity(isize::MAX as usize, isize::MAX as usize); - } - - #[test] - fn matrix_macro() { - let mut x = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]; - - assert!(x[(0, 0)] == 1.0); - assert!(x[(0, 1)] == 2.0); - assert!(x[(0, 2)] == 3.0); - - assert!(x[(1, 0)] == 4.0); - assert!(x[(1, 1)] == 5.0); - assert!(x[(1, 2)] == 6.0); - - assert!(x[(2, 0)] == 7.0); - assert!(x[(2, 1)] == 8.0); - assert!(x[(2, 2)] == 9.0); - - x[(0, 0)] = 13.0; - assert!(x[(0, 0)] == 13.0); - - assert!(x.get(.., ..) == x); - assert!(x.get(.., 1..3) == x.as_ref().submatrix(0, 1, 3, 2)); - } - - #[test] - fn matrix_macro_cplx() { - let new = Complex::new; - let mut x = mat![ - [new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0)], - [new(7.0, 8.0), new(9.0, 10.0), new(11.0, 12.0)], - [new(13.0, 14.0), new(15.0, 16.0), new(17.0, 18.0)] - ]; - - assert!(x.read(0, 0) == Complex::new(1.0, 2.0)); - assert!(x.read(0, 1) == Complex::new(3.0, 4.0)); - assert!(x.read(0, 2) == Complex::new(5.0, 6.0)); - - assert!(x.read(1, 0) == Complex::new(7.0, 8.0)); - assert!(x.read(1, 1) == Complex::new(9.0, 10.0)); - assert!(x.read(1, 2) == Complex::new(11.0, 12.0)); - - assert!(x.read(2, 0) == Complex::new(13.0, 14.0)); - assert!(x.read(2, 1) == Complex::new(15.0, 16.0)); - assert!(x.read(2, 2) == Complex::new(17.0, 18.0)); - - x.write(1, 0, Complex::new(3.0, 2.0)); - assert!(x.read(1, 0) == Complex::new(3.0, 2.0)); - } - - #[test] - fn matrix_macro_native_cplx() { - let new = Complex::new; - let mut x = mat![ - [new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0)], - [new(7.0, 8.0), new(9.0, 10.0), new(11.0, 12.0)], - [new(13.0, 14.0), new(15.0, 16.0), new(17.0, 18.0)] - ]; - - assert!(x.read(0, 0) == Complex::new(1.0, 2.0)); - assert!(x.read(0, 1) == Complex::new(3.0, 4.0)); - assert!(x.read(0, 2) == Complex::new(5.0, 6.0)); - - assert!(x.read(1, 0) == Complex::new(7.0, 8.0)); - assert!(x.read(1, 1) == Complex::new(9.0, 10.0)); - assert!(x.read(1, 2) == Complex::new(11.0, 12.0)); - - assert!(x.read(2, 0) == Complex::new(13.0, 14.0)); - assert!(x.read(2, 1) == Complex::new(15.0, 16.0)); - assert!(x.read(2, 2) == Complex::new(17.0, 18.0)); - - x.write(1, 0, Complex::new(3.0, 2.0)); - assert!(x.read(1, 0) == Complex::new(3.0, 2.0)); - } - - #[test] - fn col_macro() { - let mut x = col![3.0, 5.0, 7.0, 9.0]; - - assert!(x[0] == 3.0); - assert!(x[1] == 5.0); - assert!(x[2] == 7.0); - assert!(x[3] == 9.0); - - x[0] = 13.0; - assert!(x[0] == 13.0); - - // TODO: - // Col::get() seems to be missing - // assert!(x.get(...) == x); - } - - #[test] - fn col_macro_cplx() { - let new = Complex::new; - let mut x = col![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; - - assert!(x.read(0) == Complex::new(1.0, 2.0)); - assert!(x.read(1) == Complex::new(3.0, 4.0)); - assert!(x.read(2) == Complex::new(5.0, 6.0)); - - x.write(0, Complex::new(3.0, 2.0)); - assert!(x.read(0) == Complex::new(3.0, 2.0)); - } - - #[test] - fn col_macro_native_cplx() { - let new = Complex::new; - let mut x = col![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; - - assert!(x.read(0) == Complex::new(1.0, 2.0)); - assert!(x.read(1) == Complex::new(3.0, 4.0)); - assert!(x.read(2) == Complex::new(5.0, 6.0)); - - x.write(0, Complex::new(3.0, 2.0)); - assert!(x.read(0) == Complex::new(3.0, 2.0)); - } - - #[test] - fn row_macro() { - let mut x = row![3.0, 5.0, 7.0, 9.0]; - - assert!(x[0] == 3.0); - assert!(x[1] == 5.0); - assert!(x[2] == 7.0); - assert!(x[3] == 9.0); - - x.write(0, 13.0); - assert!(x.read(0) == 13.0); - } - - #[test] - fn row_macro_cplx() { - let new = Complex::new; - let mut x = row![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; - - assert!(x.read(0) == Complex::new(1.0, 2.0)); - assert!(x.read(1) == Complex::new(3.0, 4.0)); - assert!(x.read(2) == Complex::new(5.0, 6.0)); - - x.write(0, Complex::new(3.0, 2.0)); - assert!(x.read(0) == Complex::new(3.0, 2.0)); - } - - #[test] - fn row_macro_native_cplx() { - let new = Complex::new; - let mut x = row![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; - - assert!(x.read(0) == new(1.0, 2.0)); - assert!(x.read(1) == new(3.0, 4.0)); - assert!(x.read(2) == new(5.0, 6.0)); - - x.write(0, new(3.0, 2.0)); - assert!(x.read(0) == new(3.0, 2.0)); - } - - #[test] - fn null_col_and_row() { - let null_col: Col = col![]; - assert!(null_col == Col::::new()); - - let null_row: Row = row![]; - assert!(null_row == Row::::new()); - } - - #[test] - fn positive_concat_f64() { - let a0: Mat = Mat::from_fn(2, 2, |_, _| 1f64); - let a1: Mat = Mat::from_fn(2, 3, |_, _| 2f64); - let a2: Mat = Mat::from_fn(2, 4, |_, _| 3f64); - - let b0: Mat = Mat::from_fn(1, 6, |_, _| 4f64); - let b1: Mat = Mat::from_fn(1, 3, |_, _| 5f64); - - let c0: Mat = Mat::from_fn(6, 1, |_, _| 6f64); - let c1: Mat = Mat::from_fn(6, 3, |_, _| 7f64); - let c2: Mat = Mat::from_fn(6, 2, |_, _| 8f64); - let c3: Mat = Mat::from_fn(6, 3, |_, _| 9f64); - - let x = __concat_impl(&[ - &[a0.as_ref(), a1.as_ref(), a2.as_ref()], - &[b0.as_ref(), b1.as_ref()], - &[c0.as_ref(), c1.as_ref(), c2.as_ref(), c3.as_ref()], - ]); - - assert!(x == concat![[a0, a1, a2], [b0, b1], [c0, c1, c2, &c3]]); - - assert!(x[(0, 0)] == 1f64); - assert!(x[(1, 1)] == 1f64); - assert!(x[(2, 2)] == 4f64); - assert!(x[(3, 3)] == 7f64); - assert!(x[(4, 4)] == 8f64); - assert!(x[(5, 5)] == 8f64); - assert!(x[(6, 6)] == 9f64); - assert!(x[(7, 7)] == 9f64); - assert!(x[(8, 8)] == 9f64); - } - - #[test] - fn to_owned_equality() { - use num_complex::Complex as C; - let mut mf32: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; - let mut mf64: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; - let mut mf32c: Mat> = mat![ - [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], - [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], - [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] - ]; - let mut mf64c: Mat> = mat![ - [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], - [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], - [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] - ]; - - assert!(mf32.transpose().to_owned().as_ref() == mf32.transpose()); - assert!(mf64.transpose().to_owned().as_ref() == mf64.transpose()); - assert!(mf32c.transpose().to_owned().as_ref() == mf32c.transpose()); - assert!(mf64c.transpose().to_owned().as_ref() == mf64c.transpose()); - - assert!(mf32.as_mut().transpose_mut().to_owned().as_ref() == mf32.transpose()); - assert!(mf64.as_mut().transpose_mut().to_owned().as_ref() == mf64.transpose()); - assert!(mf32c.as_mut().transpose_mut().to_owned().as_ref() == mf32c.transpose()); - assert!(mf64c.as_mut().transpose_mut().to_owned().as_ref() == mf64c.transpose()); - } - - #[test] - fn conj_to_owned_equality() { - use num_complex::Complex as C; - let mut mf32: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; - let mut mf64: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; - let mut mf32c: Mat> = mat![ - [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], - [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], - [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] - ]; - let mut mf64c: Mat> = mat![ - [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], - [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], - [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] - ]; - - assert!(mf32.as_ref().adjoint().to_owned().as_ref() == mf32.adjoint()); - assert!(mf64.as_ref().adjoint().to_owned().as_ref() == mf64.adjoint()); - assert!(mf32c.as_ref().adjoint().to_owned().as_ref() == mf32c.adjoint()); - assert!(mf64c.as_ref().adjoint().to_owned().as_ref() == mf64c.adjoint()); - - assert!(mf32.as_mut().adjoint_mut().to_owned().as_ref() == mf32.adjoint()); - assert!(mf64.as_mut().adjoint_mut().to_owned().as_ref() == mf64.adjoint()); - assert!(mf32c.as_mut().adjoint_mut().to_owned().as_ref() == mf32c.adjoint()); - assert!(mf64c.as_mut().adjoint_mut().to_owned().as_ref() == mf64c.adjoint()); - } - - #[test] - fn mat_mul_assign_scalar() { - let mut x = mat![[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]; - - let expected = mat![[0.0, 2.0], [4.0, 6.0], [8.0, 10.0]]; - x *= scale(2.0); - assert_eq!(x, expected); - - let expected = mat![[0.0, 4.0], [8.0, 12.0], [16.0, 20.0]]; - let mut x_mut = x.as_mut(); - x_mut *= scale(2.0); - assert_eq!(x, expected); - } - - #[test] - fn test_col_slice() { - let mut matrix = mat![[1.0, 5.0, 9.0], [2.0, 6.0, 10.0], [3.0, 7.0, 11.0f64]]; - - assert_eq!(matrix.col_as_slice(1), &[5.0, 6.0, 7.0]); - assert_eq!(matrix.col_as_slice_mut(0), &[1.0, 2.0, 3.0]); - - matrix - .col_as_slice_mut(0) - .copy_from_slice(&[-1.0, -2.0, -3.0]); - - let expected = mat![[-1.0, 5.0, 9.0], [-2.0, 6.0, 10.0], [-3.0, 7.0, 11.0f64]]; - assert_eq!(matrix, expected); - } - - #[test] - fn from_slice() { - let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; - - let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; - let view = mat::from_column_major_slice::<'_, f64>(&slice, 3, 2); - assert_eq!(expected, view); - let view = mat::from_column_major_slice::<'_, f64>(&mut slice, 3, 2); - assert_eq!(expected, view); - - let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; - let view = mat::from_row_major_slice::<'_, f64>(&slice, 3, 2); - assert_eq!(expected, view); - let view = mat::from_row_major_slice::<'_, f64>(&mut slice, 3, 2); - assert_eq!(expected, view); - } - - #[test] - #[should_panic] - fn from_slice_too_big() { - let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0_f64]; - mat::from_column_major_slice::<'_, f64>(&slice, 3, 2); - } - - #[test] - #[should_panic] - fn from_slice_too_small() { - let slice = [1.0, 2.0, 3.0, 4.0, 5.0_f64]; - mat::from_column_major_slice::<'_, f64>(&slice, 3, 2); - } - - #[test] - fn test_is_finite() { - let inf = f32::INFINITY; - let nan = f32::NAN; - - { - assert!(::faer_is_finite(&1.0)); - assert!(!::faer_is_finite(&inf)); - assert!(!::faer_is_finite(&-inf)); - assert!(!::faer_is_finite(&nan)); - } - { - let x = c32::new(1.0, 2.0); - assert!(::faer_is_finite(&x)); - - let x = c32::new(inf, 2.0); - assert!(!::faer_is_finite(&x)); - - let x = c32::new(1.0, inf); - assert!(!::faer_is_finite(&x)); - - let x = c32::new(inf, inf); - assert!(!::faer_is_finite(&x)); - - let x = c32::new(nan, 2.0); - assert!(!::faer_is_finite(&x)); - - let x = c32::new(1.0, nan); - assert!(!::faer_is_finite(&x)); - - let x = c32::new(nan, nan); - assert!(!::faer_is_finite(&x)); - } - } - - #[test] - fn test_iter() { - let mut mat = Mat::from_fn(9, 10, |i, j| (i + j) as f64); - let mut iter = mat.row_chunks_mut(4); - - let first = iter.next(); - let second = iter.next(); - let last = iter.next(); - let none = iter.next(); - - assert!(first == Some(Mat::from_fn(4, 10, |i, j| (i + j) as f64).as_mut())); - assert!(second == Some(Mat::from_fn(4, 10, |i, j| (i + j + 4) as f64).as_mut())); - assert!(last == Some(Mat::from_fn(1, 10, |i, j| (i + j + 8) as f64).as_mut())); - assert!(none == None); - } - - #[test] - fn test_norm_l2() { - let relative_err = |a: f64, b: f64| (a - b).abs() / f64::max(a.abs(), b.abs()); - - for (m, n) in [(9, 10), (1023, 5), (42, 1)] { - for factor in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] { - let mat = Mat::from_fn(m, n, |i, j| factor * ((i + j) as f64)); - let mut target = 0.0; - zipped!(mat.as_ref()).for_each(|unzipped!(x)| { - target = f64::hypot(*x, target); - }); - - if factor == 0.0 { - assert!(mat.norm_l2() == target); - } else { - assert!(relative_err(mat.norm_l2(), target) < 1e-14); - } - } - } - - let mat = Col::from_fn(10000000, |_| 0.3); - let target = (0.3 * 0.3 * 10000000.0f64).sqrt(); - assert!(relative_err(mat.norm_l2(), target) < 1e-14); - } - - #[test] - fn test_sum() { - let relative_err = |a: f64, b: f64| (a - b).abs() / f64::max(a.abs(), b.abs()); - - for (m, n) in [(9, 10), (1023, 5), (42, 1)] { - for factor in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] { - let mat = Mat::from_fn(m, n, |i, j| factor * ((i + j) as f64)); - let mut target = 0.0; - zipped!(mat.as_ref()).for_each(|unzipped!(x)| { - target += *x; - }); - - if factor == 0.0 { - assert!(mat.sum() == target); - } else { - assert!(relative_err(mat.sum(), target) < 1e-14); - } - } - } - - let mat = Col::from_fn(10000000, |_| 0.3); - let target = 0.3 * 10000000.0f64; - assert!(relative_err(mat.sum(), target) < 1e-14); - } - - #[test] - fn test_kron_ones() { - for (m, n, p, q) in [(2, 3, 4, 5), (3, 2, 5, 4), (1, 1, 1, 1)] { - let a = Mat::from_fn(m, n, |_, _| 1 as f64); - let b = Mat::from_fn(p, q, |_, _| 1 as f64); - let expected = Mat::from_fn(m * p, n * q, |_, _| 1 as f64); - assert!(a.kron(&b) == expected); - } - - for (m, n, p) in [(2, 3, 4), (3, 2, 5), (1, 1, 1)] { - let a = Mat::from_fn(m, n, |_, _| 1 as f64); - let b = Col::from_fn(p, |_| 1 as f64); - let expected = Mat::from_fn(m * p, n, |_, _| 1 as f64); - assert!(a.kron(&b) == expected); - assert!(b.kron(&a) == expected); - - let a = Mat::from_fn(m, n, |_, _| 1 as f64); - let b = Row::from_fn(p, |_| 1 as f64); - let expected = Mat::from_fn(m, n * p, |_, _| 1 as f64); - assert!(a.kron(&b) == expected); - assert!(b.kron(&a) == expected); - } - - for (m, n) in [(2, 3), (3, 2), (1, 1)] { - let a = Row::from_fn(m, |_| 1 as f64); - let b = Col::from_fn(n, |_| 1 as f64); - let expected = Mat::from_fn(n, m, |_, _| 1 as f64); - assert!(a.kron(&b) == expected); - assert!(b.kron(&a) == expected); - - let c = Row::from_fn(n, |_| 1 as f64); - let expected = Mat::from_fn(1, m * n, |_, _| 1 as f64); - assert!(a.kron(&c) == expected); - - let d = Col::from_fn(m, |_| 1 as f64); - let expected = Mat::from_fn(m * n, 1, |_, _| 1 as f64); - assert!(d.kron(&b) == expected); - } - } - - #[test] - fn test_col_index() { - let mut col_32: Col = Col::from_fn(3, |i| i as f32); - col_32.as_mut()[1] = 10f32; - let tval: f32 = (10f32 - col_32[1]).abs(); - assert!(tval < 1e-14); - - let mut col_64: Col = Col::from_fn(3, |i| i as f64); - col_64.as_mut()[1] = 10f64; - let tval: f64 = (10f64 - col_64[1]).abs(); - assert!(tval < 1e-14); - } - - #[test] - fn test_row_index() { - let mut row_32: Row = Row::from_fn(3, |i| i as f32); - row_32.as_mut()[1] = 10f32; - let tval: f32 = (10f32 - row_32[1]).abs(); - assert!(tval < 1e-14); - - let mut row_64: Row = Row::from_fn(3, |i| i as f64); - row_64.as_mut()[1] = 10f64; - let tval: f64 = (10f64 - row_64[1]).abs(); - assert!(tval < 1e-14); - } -} - -/// Implementation of [`zipped!`] structures. -pub mod zip { - use super::{assert, debug_assert, *}; - use core::mem::MaybeUninit; - - /// Read only view over a single matrix element. - pub struct Read<'a, E: Entity> { - ptr: GroupFor>, - } - /// Read-write view over a single matrix element. - pub struct ReadWrite<'a, E: Entity> { - ptr: GroupFor>, - } - - /// Type that can be converted to a view. - pub trait ViewMut { - /// View type. - type Target<'a> - where - Self: 'a; - - /// Returns the view over self. - fn view_mut(&mut self) -> Self::Target<'_>; - } - - impl ViewMut for Row { - type Target<'a> = RowRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - self.as_ref() - } - } - impl ViewMut for &Row { - type Target<'a> = RowRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).as_ref() - } - } - impl ViewMut for &mut Row { - type Target<'a> = RowMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).as_mut() - } - } - - impl ViewMut for RowRef<'_, E> { - type Target<'a> = RowRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - *self - } - } - impl ViewMut for RowMut<'_, E> { - type Target<'a> = RowMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).rb_mut() - } - } - impl ViewMut for &mut RowRef<'_, E> { - type Target<'a> = RowRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - **self - } - } - impl ViewMut for &mut RowMut<'_, E> { - type Target<'a> = RowMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (**self).rb_mut() - } - } - impl ViewMut for &RowRef<'_, E> { - type Target<'a> = RowRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - **self - } - } - impl ViewMut for &RowMut<'_, E> { - type Target<'a> = RowRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (**self).rb() - } - } - - impl ViewMut for Col { - type Target<'a> = ColRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - self.as_ref() - } - } - impl ViewMut for &Col { - type Target<'a> = ColRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).as_ref() - } - } - impl ViewMut for &mut Col { - type Target<'a> = ColMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).as_mut() - } - } - - impl ViewMut for ColRef<'_, E> { - type Target<'a> = ColRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - *self - } - } - impl ViewMut for ColMut<'_, E> { - type Target<'a> = ColMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).rb_mut() - } - } - impl ViewMut for &mut ColRef<'_, E> { - type Target<'a> = ColRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - **self - } - } - impl ViewMut for &mut ColMut<'_, E> { - type Target<'a> = ColMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (**self).rb_mut() - } - } - impl ViewMut for &ColRef<'_, E> { - type Target<'a> = ColRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - **self - } - } - impl ViewMut for &ColMut<'_, E> { - type Target<'a> = ColRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (**self).rb() - } - } - - impl ViewMut for Mat { - type Target<'a> = MatRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - self.as_ref() - } - } - impl ViewMut for &Mat { - type Target<'a> = MatRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).as_ref() - } - } - impl ViewMut for &mut Mat { - type Target<'a> = MatMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).as_mut() - } - } - - impl ViewMut for MatRef<'_, E> { - type Target<'a> = MatRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - *self - } - } - impl ViewMut for MatMut<'_, E> { - type Target<'a> = MatMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (*self).rb_mut() - } - } - impl ViewMut for &mut MatRef<'_, E> { - type Target<'a> = MatRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - **self - } - } - impl ViewMut for &mut MatMut<'_, E> { - type Target<'a> = MatMut<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (**self).rb_mut() - } - } - impl ViewMut for &MatRef<'_, E> { - type Target<'a> = MatRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - **self - } - } - impl ViewMut for &MatMut<'_, E> { - type Target<'a> = MatRef<'a, E> - where - Self: 'a; - - #[inline] - fn view_mut(&mut self) -> Self::Target<'_> { - (**self).rb() - } - } - - impl core::ops::Deref for Read<'_, E> { - type Target = E; - #[inline(always)] - fn deref(&self) -> &Self::Target { - unsafe { &*(self.ptr as *const _ as *const E::Unit) } - } - } - impl core::ops::Deref for ReadWrite<'_, E> { - type Target = E; - #[inline(always)] - fn deref(&self) -> &Self::Target { - unsafe { &*(self.ptr as *const _ as *const E::Unit) } - } - } - impl core::ops::DerefMut for ReadWrite<'_, E> { - #[inline(always)] - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { &mut *(self.ptr as *mut _ as *mut E::Unit) } - } - } - - impl Read<'_, E> { - /// Read the value of the element. - #[inline(always)] - pub fn read(&self) -> E { - E::faer_from_units(E::faer_map( - E::faer_as_ref(&self.ptr), - #[inline(always)] - |ptr| unsafe { ptr.assume_init_read() }, - )) - } - } - impl ReadWrite<'_, E> { - /// Read the value of the element. - #[inline(always)] - pub fn read(&self) -> E { - E::faer_from_units(E::faer_map( - E::faer_as_ref(&self.ptr), - #[inline(always)] - |ptr| unsafe { *ptr.assume_init_ref() }, - )) - } - - /// Write to the location of the element. - #[inline(always)] - pub fn write(&mut self, value: E) { - let value = E::faer_into_units(value); - E::faer_map( - E::faer_zip(E::faer_as_mut(&mut self.ptr), value), - #[inline(always)] - |(ptr, value)| unsafe { *ptr.assume_init_mut() = value }, - ); - } - } - - /// Specifies whether the main diagonal should be traversed, when iterating over a triangular - /// chunk of the matrix. - #[derive(Copy, Clone, Debug, PartialEq, Eq)] - pub enum Diag { - /// Do not include diagonal of matrix - Skip, - /// Include diagonal of matrix - Include, - } - - /// Matrix layout transformation. Used for zipping optimizations. - #[derive(Copy, Clone)] - pub enum MatLayoutTransform { - /// Matrix is used as-is. - None, - /// Matrix rows are reversed. - ReverseRows, - /// Matrix is transposed. - Transpose, - /// Matrix is transposed, then rows are reversed. - TransposeReverseRows, - } - - /// Vector layout transformation. Used for zipping optimizations. - #[derive(Copy, Clone)] - pub enum VecLayoutTransform { - /// Vector is used as-is. - None, - /// Vector is reversed. - Reverse, - } - - /// Type with a given matrix shape. - pub trait MatShape { - /// Type of rows. - type Rows: Copy + Eq; - /// Type of columns. - type Cols: Copy + Eq; - /// Returns the number of rows. - fn nrows(&self) -> Self::Rows; - /// Returns the number of columns. - fn ncols(&self) -> Self::Cols; - } - - /// Zipped matrix views. - pub unsafe trait MaybeContiguous: MatShape { - /// Indexing type. - type Index: Copy; - /// Contiguous slice type. - type Slice; - /// Layout transformation type. - type LayoutTransform: Copy; - /// Returns slice at index of length `n_elems`. - unsafe fn get_slice_unchecked(&mut self, idx: Self::Index, n_elems: usize) -> Self::Slice; - } - - /// Zipped matrix views. - pub unsafe trait MatIndex<'a, _Outlives = &'a Self>: MaybeContiguous { - /// Item produced by the zipped views. - type Item; - - /// Get the item at the given index, skipping bound checks. - unsafe fn get_unchecked(&'a mut self, index: Self::Index) -> Self::Item; - /// Get the item at the given slice position, skipping bound checks. - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item; - - /// Checks if the zipped matrices are contiguous. - fn is_contiguous(&self) -> bool; - /// Computes the preferred iteration layout of the matrices. - fn preferred_layout(&self) -> Self::LayoutTransform; - /// Applies the layout transformation to the matrices. - fn with_layout(self, layout: Self::LayoutTransform) -> Self; - } - - /// Single element. - #[derive(Copy, Clone, Debug)] - pub struct Last(pub Mat); - - /// Zipped elements. - #[derive(Copy, Clone, Debug)] - pub struct Zip(pub Head, pub Tail); - - /// Single matrix view. - #[derive(Copy, Clone, Debug)] - pub struct LastEq>(pub Mat); - /// Zipped matrix views. - #[derive(Copy, Clone, Debug)] - pub struct ZipEq< - Rows, - Cols, - Head: MatShape, - Tail: MatShape, - >(Head, Tail); - - impl< - Rows: Copy + Eq, - Cols: Copy + Eq, - Head: MatShape, - Tail: MatShape, - > ZipEq - { - /// Creates a zipped matrix, after asserting that the dimensions match. - #[inline(always)] - pub fn new(head: Head, tail: Tail) -> Self { - assert!((head.nrows(), head.ncols()) == (tail.nrows(), tail.ncols())); - Self(head, tail) - } - - /// Creates a zipped matrix, assuming that the dimensions match. - #[inline(always)] - pub fn new_unchecked(head: Head, tail: Tail) -> Self { - debug_assert!((head.nrows(), head.ncols()) == (tail.nrows(), tail.ncols())); - Self(head, tail) - } - } - - impl> MatShape - for LastEq - { - type Rows = Rows; - type Cols = Cols; - #[inline(always)] - fn nrows(&self) -> Self::Rows { - self.0.nrows() - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - self.0.ncols() - } - } - - impl< - Rows: Copy + Eq, - Cols: Copy + Eq, - Head: MatShape, - Tail: MatShape, - > MatShape for ZipEq - { - type Rows = Rows; - type Cols = Cols; - #[inline(always)] - fn nrows(&self) -> Self::Rows { - self.0.nrows() - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - self.0.ncols() - } - } - - impl MatShape for ColRef<'_, E> { - type Rows = usize; - type Cols = (); - #[inline(always)] - fn nrows(&self) -> Self::Rows { - (*self).nrows() - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - () - } - } - - impl MatShape for ColMut<'_, E> { - type Rows = usize; - type Cols = (); - #[inline(always)] - fn nrows(&self) -> Self::Rows { - (*self).nrows() - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - () - } - } - - impl MatShape for RowRef<'_, E> { - type Rows = (); - type Cols = usize; - #[inline(always)] - fn nrows(&self) -> Self::Rows { - () - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - (*self).ncols() - } - } - impl MatShape for RowMut<'_, E> { - type Rows = (); - type Cols = usize; - #[inline(always)] - fn nrows(&self) -> Self::Rows { - () - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - (*self).ncols() - } - } - - impl MatShape for MatRef<'_, E> { - type Rows = usize; - type Cols = usize; - #[inline(always)] - fn nrows(&self) -> Self::Rows { - (*self).nrows() - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - (*self).ncols() - } - } - - impl MatShape for MatMut<'_, E> { - type Rows = usize; - type Cols = usize; - #[inline(always)] - fn nrows(&self) -> Self::Rows { - (*self).nrows() - } - #[inline(always)] - fn ncols(&self) -> Self::Cols { - (*self).ncols() - } - } - - unsafe impl> - MaybeContiguous for LastEq - { - type Index = Mat::Index; - type Slice = Last; - type LayoutTransform = Mat::LayoutTransform; - #[inline(always)] - unsafe fn get_slice_unchecked(&mut self, idx: Self::Index, n_elems: usize) -> Self::Slice { - Last(self.0.get_slice_unchecked(idx, n_elems)) - } - } - - unsafe impl<'a, Rows: Copy + Eq, Cols: Copy + Eq, Mat: MatIndex<'a, Rows = Rows, Cols = Cols>> - MatIndex<'a> for LastEq - { - type Item = Last; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, index: Self::Index) -> Self::Item { - Last(self.0.get_unchecked(index)) - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - Last(Mat::get_from_slice_unchecked(&mut slice.0, idx)) - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.0.is_contiguous() - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - self.0.preferred_layout() - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - Self(self.0.with_layout(layout)) - } - } - - unsafe impl< - Rows: Copy + Eq, - Cols: Copy + Eq, - Head: MaybeContiguous, - Tail: MaybeContiguous< - Rows = Rows, - Cols = Cols, - Index = Head::Index, - LayoutTransform = Head::LayoutTransform, - >, - > MaybeContiguous for ZipEq - { - type Index = Head::Index; - type Slice = Zip; - type LayoutTransform = Head::LayoutTransform; - #[inline(always)] - unsafe fn get_slice_unchecked(&mut self, idx: Self::Index, n_elems: usize) -> Self::Slice { - Zip( - self.0.get_slice_unchecked(idx, n_elems), - self.1.get_slice_unchecked(idx, n_elems), - ) - } - } - - unsafe impl< - 'a, - Rows: Copy + Eq, - Cols: Copy + Eq, - Head: MatIndex<'a, Rows = Rows, Cols = Cols>, - Tail: MatIndex< - 'a, - Rows = Rows, - Cols = Cols, - Index = Head::Index, - LayoutTransform = Head::LayoutTransform, - >, - > MatIndex<'a> for ZipEq - { - type Item = Zip; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, index: Self::Index) -> Self::Item { - Zip(self.0.get_unchecked(index), self.1.get_unchecked(index)) - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - Zip( - Head::get_from_slice_unchecked(&mut slice.0, idx), - Tail::get_from_slice_unchecked(&mut slice.1, idx), - ) - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.0.is_contiguous() && self.1.is_contiguous() - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - self.0.preferred_layout() - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - ZipEq(self.0.with_layout(layout), self.1.with_layout(layout)) - } - } - - unsafe impl MaybeContiguous for ColRef<'_, E> { - type Index = (usize, ()); - type Slice = GroupFor]>; - type LayoutTransform = VecLayoutTransform; - - #[inline(always)] - unsafe fn get_slice_unchecked( - &mut self, - (i, _): Self::Index, - n_elems: usize, - ) -> Self::Slice { - E::faer_map( - (*self).rb().ptr_at(i), - #[inline(always)] - |ptr| core::slice::from_raw_parts(ptr as *const MaybeUninit, n_elems), - ) - } - } - unsafe impl<'a, E: Entity> MatIndex<'a> for ColRef<'_, E> { - type Item = Read<'a, E>; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, (i, _): Self::Index) -> Self::Item { - Read { - ptr: E::faer_map( - self.rb().ptr_inbounds_at(i), - #[inline(always)] - |ptr| &*(ptr as *const MaybeUninit), - ), - } - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - let slice = E::faer_rb(E::faer_as_ref(slice)); - Read { - ptr: E::faer_map( - slice, - #[inline(always)] - |slice| slice.get_unchecked(idx), - ), - } - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.row_stride() == 1 - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - let rs = self.row_stride(); - if self.nrows() > 1 && rs == 1 { - VecLayoutTransform::None - } else if self.nrows() > 1 && rs == -1 { - VecLayoutTransform::Reverse - } else { - VecLayoutTransform::None - } - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - use VecLayoutTransform::*; - match layout { - None => self, - Reverse => self.reverse_rows(), - } - } - } - - unsafe impl MaybeContiguous for ColMut<'_, E> { - type Index = (usize, ()); - type Slice = GroupFor]>; - type LayoutTransform = VecLayoutTransform; - - #[inline(always)] - unsafe fn get_slice_unchecked( - &mut self, - (i, _): Self::Index, - n_elems: usize, - ) -> Self::Slice { - E::faer_map( - (*self).rb_mut().ptr_at_mut(i), - #[inline(always)] - |ptr| core::slice::from_raw_parts_mut(ptr as *mut MaybeUninit, n_elems), - ) - } - } - unsafe impl<'a, E: Entity> MatIndex<'a> for ColMut<'_, E> { - type Item = ReadWrite<'a, E>; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, (i, _): Self::Index) -> Self::Item { - ReadWrite { - ptr: E::faer_map( - self.rb_mut().ptr_inbounds_at_mut(i), - #[inline(always)] - |ptr| &mut *(ptr as *mut MaybeUninit), - ), - } - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - let slice = E::faer_rb_mut(E::faer_as_mut(slice)); - ReadWrite { - ptr: E::faer_map( - slice, - #[inline(always)] - |slice| slice.get_unchecked_mut(idx), - ), - } - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.row_stride() == 1 - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - let rs = self.row_stride(); - if self.nrows() > 1 && rs == 1 { - VecLayoutTransform::None - } else if self.nrows() > 1 && rs == -1 { - VecLayoutTransform::Reverse - } else { - VecLayoutTransform::None - } - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - use VecLayoutTransform::*; - match layout { - None => self, - Reverse => self.reverse_rows_mut(), - } - } - } - - unsafe impl MaybeContiguous for RowRef<'_, E> { - type Index = ((), usize); - type Slice = GroupFor]>; - type LayoutTransform = VecLayoutTransform; - - #[inline(always)] - unsafe fn get_slice_unchecked( - &mut self, - (_, j): Self::Index, - n_elems: usize, - ) -> Self::Slice { - E::faer_map( - (*self).rb().ptr_at(j), - #[inline(always)] - |ptr| core::slice::from_raw_parts(ptr as *const MaybeUninit, n_elems), - ) - } - } - unsafe impl<'a, E: Entity> MatIndex<'a> for RowRef<'_, E> { - type Item = Read<'a, E>; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, (_, j): Self::Index) -> Self::Item { - Read { - ptr: E::faer_map( - self.rb().ptr_inbounds_at(j), - #[inline(always)] - |ptr| &*(ptr as *const MaybeUninit), - ), - } - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - let slice = E::faer_rb(E::faer_as_ref(slice)); - Read { - ptr: E::faer_map( - slice, - #[inline(always)] - |slice| slice.get_unchecked(idx), - ), - } - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.col_stride() == 1 - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - let cs = self.col_stride(); - if self.ncols() > 1 && cs == 1 { - VecLayoutTransform::None - } else if self.ncols() > 1 && cs == -1 { - VecLayoutTransform::Reverse - } else { - VecLayoutTransform::None - } - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - use VecLayoutTransform::*; - match layout { - None => self, - Reverse => self.reverse_cols(), - } - } - } - - unsafe impl MaybeContiguous for RowMut<'_, E> { - type Index = ((), usize); - type Slice = GroupFor]>; - type LayoutTransform = VecLayoutTransform; - - #[inline(always)] - unsafe fn get_slice_unchecked( - &mut self, - (_, j): Self::Index, - n_elems: usize, - ) -> Self::Slice { - E::faer_map( - (*self).rb_mut().ptr_at_mut(j), - #[inline(always)] - |ptr| core::slice::from_raw_parts_mut(ptr as *mut MaybeUninit, n_elems), - ) - } - } - unsafe impl<'a, E: Entity> MatIndex<'a> for RowMut<'_, E> { - type Item = ReadWrite<'a, E>; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, (_, j): Self::Index) -> Self::Item { - ReadWrite { - ptr: E::faer_map( - self.rb_mut().ptr_inbounds_at_mut(j), - #[inline(always)] - |ptr| &mut *(ptr as *mut MaybeUninit), - ), - } - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - let slice = E::faer_rb_mut(E::faer_as_mut(slice)); - ReadWrite { - ptr: E::faer_map( - slice, - #[inline(always)] - |slice| slice.get_unchecked_mut(idx), - ), - } - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.col_stride() == 1 - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - let cs = self.col_stride(); - if self.ncols() > 1 && cs == 1 { - VecLayoutTransform::None - } else if self.ncols() > 1 && cs == -1 { - VecLayoutTransform::Reverse - } else { - VecLayoutTransform::None - } - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - use VecLayoutTransform::*; - match layout { - None => self, - Reverse => self.reverse_cols_mut(), - } - } - } - - unsafe impl MaybeContiguous for MatRef<'_, E> { - type Index = (usize, usize); - type Slice = GroupFor]>; - type LayoutTransform = MatLayoutTransform; - - #[inline(always)] - unsafe fn get_slice_unchecked( - &mut self, - (i, j): Self::Index, - n_elems: usize, - ) -> Self::Slice { - E::faer_map( - (*self).rb().overflowing_ptr_at(i, j), - #[inline(always)] - |ptr| core::slice::from_raw_parts(ptr as *const MaybeUninit, n_elems), - ) - } - } - unsafe impl<'a, E: Entity> MatIndex<'a> for MatRef<'_, E> { - type Item = Read<'a, E>; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, (i, j): Self::Index) -> Self::Item { - Read { - ptr: E::faer_map( - self.rb().ptr_inbounds_at(i, j), - #[inline(always)] - |ptr| &*(ptr as *const MaybeUninit), - ), - } - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - let slice = E::faer_rb(E::faer_as_ref(slice)); - Read { - ptr: E::faer_map( - slice, - #[inline(always)] - |slice| slice.get_unchecked(idx), - ), - } - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.row_stride() == 1 - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - let rs = self.row_stride(); - let cs = self.col_stride(); - if self.nrows() > 1 && rs == 1 { - MatLayoutTransform::None - } else if self.nrows() > 1 && rs == -1 { - MatLayoutTransform::ReverseRows - } else if self.ncols() > 1 && cs == 1 { - MatLayoutTransform::Transpose - } else if self.ncols() > 1 && cs == -1 { - MatLayoutTransform::TransposeReverseRows - } else { - MatLayoutTransform::None - } - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - use MatLayoutTransform::*; - match layout { - None => self, - ReverseRows => self.reverse_rows(), - Transpose => self.transpose(), - TransposeReverseRows => self.transpose().reverse_rows(), - } - } - } - - unsafe impl MaybeContiguous for MatMut<'_, E> { - type Index = (usize, usize); - type Slice = GroupFor]>; - type LayoutTransform = MatLayoutTransform; - - #[inline(always)] - unsafe fn get_slice_unchecked( - &mut self, - (i, j): Self::Index, - n_elems: usize, - ) -> Self::Slice { - E::faer_map( - (*self).rb().overflowing_ptr_at(i, j), - #[inline(always)] - |ptr| core::slice::from_raw_parts_mut(ptr as *mut MaybeUninit, n_elems), - ) - } - } - - unsafe impl<'a, E: Entity> MatIndex<'a> for MatMut<'_, E> { - type Item = ReadWrite<'a, E>; - - #[inline(always)] - unsafe fn get_unchecked(&'a mut self, (i, j): Self::Index) -> Self::Item { - ReadWrite { - ptr: E::faer_map( - self.rb_mut().ptr_inbounds_at_mut(i, j), - #[inline(always)] - |ptr| &mut *(ptr as *mut MaybeUninit), - ), - } - } - - #[inline(always)] - unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { - let slice = E::faer_rb_mut(E::faer_as_mut(slice)); - ReadWrite { - ptr: E::faer_map( - slice, - #[inline(always)] - |slice| slice.get_unchecked_mut(idx), - ), - } - } - - #[inline(always)] - fn is_contiguous(&self) -> bool { - self.row_stride() == 1 - } - #[inline(always)] - fn preferred_layout(&self) -> Self::LayoutTransform { - let rs = self.row_stride(); - let cs = self.col_stride(); - if self.nrows() > 1 && rs == 1 { - MatLayoutTransform::None - } else if self.nrows() > 1 && rs == -1 { - MatLayoutTransform::ReverseRows - } else if self.ncols() > 1 && cs == 1 { - MatLayoutTransform::Transpose - } else if self.ncols() > 1 && cs == -1 { - MatLayoutTransform::TransposeReverseRows - } else { - MatLayoutTransform::None - } - } - #[inline(always)] - fn with_layout(self, layout: Self::LayoutTransform) -> Self { - use MatLayoutTransform::*; - match layout { - None => self, - ReverseRows => self.reverse_rows_mut(), - Transpose => self.transpose_mut(), - TransposeReverseRows => self.transpose_mut().reverse_rows_mut(), - } - } - } - - #[inline(always)] - fn annotate_noalias_mat MatIndex<'a>>( - f: &mut impl for<'a> FnMut(>::Item), - mut slice: Z::Slice, - i_begin: usize, - i_end: usize, - _j: usize, - ) { - for i in i_begin..i_end { - unsafe { f(Z::get_from_slice_unchecked(&mut slice, i - i_begin)) }; - } - } - - #[inline(always)] - fn annotate_noalias_col MatIndex<'a>>( - f: &mut impl for<'a> FnMut(>::Item), - mut slice: Z::Slice, - i_begin: usize, - i_end: usize, - ) { - for i in i_begin..i_end { - unsafe { f(Z::get_from_slice_unchecked(&mut slice, i - i_begin)) }; - } - } - - #[inline(always)] - fn for_each_mat MatIndex<'a, Rows = usize, Cols = usize, Index = (usize, usize)>>( - z: Z, - mut f: impl for<'a> FnMut(>::Item), - ) { - let layout = z.preferred_layout(); - let mut z = z.with_layout(layout); - - let m = z.nrows(); - let n = z.ncols(); - if m == 0 || n == 0 { - return; - } - - unsafe { - if z.is_contiguous() { - for j in 0..n { - annotate_noalias_mat::(&mut f, z.get_slice_unchecked((0, j), m), 0, m, j); - } - } else { - for j in 0..n { - for i in 0..m { - f(z.get_unchecked((i, j))) - } - } - } - } - } - - #[inline(always)] - fn for_each_mat_triangular_lower< - Z: for<'a> MatIndex< - 'a, - Rows = usize, - Cols = usize, - Index = (usize, usize), - LayoutTransform = MatLayoutTransform, - >, - >( - z: Z, - diag: Diag, - transpose: bool, - mut f: impl for<'a> FnMut(>::Item), - ) { - use MatLayoutTransform::*; - - let z = if transpose { - z.with_layout(MatLayoutTransform::Transpose) - } else { - z - }; - let layout = z.preferred_layout(); - let mut z = z.with_layout(layout); - - let m = z.nrows(); - let n = z.ncols(); - let n = match layout { - None | ReverseRows => Ord::min(m, n), - Transpose | TransposeReverseRows => n, - }; - if m == 0 || n == 0 { - return; - } - - let strict = match diag { - Diag::Skip => true, - Diag::Include => false, - }; - - unsafe { - if z.is_contiguous() { - for j in 0..n { - let (start, end) = match layout { - None => (j + strict as usize, m), - ReverseRows => (0, (m - (j + strict as usize))), - Transpose => (0, (j + !strict as usize).min(m)), - TransposeReverseRows => (m - ((j + !strict as usize).min(m)), m), - }; - - let len = end - start; - - annotate_noalias_mat::( - &mut f, - z.get_slice_unchecked((start, j), len), - start, - end, - j, - ); - } - } else { - for j in 0..n { - let (start, end) = match layout { - None => (j + strict as usize, m), - ReverseRows => (0, (m - (j + strict as usize))), - Transpose => (0, (j + !strict as usize).min(m)), - TransposeReverseRows => (m - ((j + !strict as usize).min(m)), m), - }; - - for i in start..end { - f(z.get_unchecked((i, j))) - } - } - } - } - } - - #[inline(always)] - fn for_each_col MatIndex<'a, Rows = usize, Cols = (), Index = (usize, ())>>( - z: Z, - mut f: impl for<'a> FnMut(>::Item), - ) { - let layout = z.preferred_layout(); - let mut z = z.with_layout(layout); - - let m = z.nrows(); - if m == 0 { - return; - } - - unsafe { - if z.is_contiguous() { - annotate_noalias_col::(&mut f, z.get_slice_unchecked((0, ()), m), 0, m); - } else { - for i in 0..m { - f(z.get_unchecked((i, ()))) - } - } - } - } - - #[inline(always)] - fn for_each_row MatIndex<'a, Rows = (), Cols = usize, Index = ((), usize)>>( - z: Z, - mut f: impl for<'a> FnMut(>::Item), - ) { - let layout = z.preferred_layout(); - let mut z = z.with_layout(layout); - - let n = z.ncols(); - if n == 0 { - return; - } - - unsafe { - if z.is_contiguous() { - annotate_noalias_col::(&mut f, z.get_slice_unchecked(((), 0), n), 0, n); - } else { - for j in 0..n { - f(z.get_unchecked(((), j))) - } - } - } - } - - impl< - M: for<'a> MatIndex< - 'a, - Rows = usize, - Cols = usize, - Index = (usize, usize), - LayoutTransform = MatLayoutTransform, - >, - > LastEq - { - /// Applies `f` to each element of `self`. - #[inline(always)] - pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { - for_each_mat(self, f); - } - - /// Applies `f` to each element of the lower triangular half of `self`. - /// - /// `diag` specifies whether the diagonal should be included or excluded. - #[inline(always)] - pub fn for_each_triangular_lower( - self, - diag: Diag, - f: impl for<'a> FnMut(>::Item), - ) { - for_each_mat_triangular_lower(self, diag, false, f); - } - - /// Applies `f` to each element of the upper triangular half of `self`. - /// - /// `diag` specifies whether the diagonal should be included or excluded. - #[inline(always)] - pub fn for_each_triangular_upper( - self, - diag: Diag, - f: impl for<'a> FnMut(>::Item), - ) { - for_each_mat_triangular_lower(self, diag, true, f); - } - - /// Applies `f` to each element of `self` and collect its result into a new matrix. - #[inline(always)] - pub fn map( - self, - f: impl for<'a> FnMut(>::Item) -> E, - ) -> Mat { - let (m, n) = (self.nrows(), self.ncols()); - let mut out = Mat::::with_capacity(m, n); - let rs = 1; - let cs = out.col_stride(); - let out_view = - unsafe { mat::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, n, rs, cs) }; - let mut f = f; - ZipEq::new(out_view, self).for_each( - #[inline(always)] - |Zip(mut out, item)| out.write(f(item)), - ); - unsafe { out.set_dims(m, n) }; - out - } - } - - impl< - M: for<'a> MatIndex< - 'a, - Rows = (), - Cols = usize, - Index = ((), usize), - LayoutTransform = VecLayoutTransform, - >, - > LastEq<(), usize, M> - { - /// Applies `f` to each element of `self`. - #[inline(always)] - pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { - for_each_row(self, f); - } - - /// Applies `f` to each element of `self` and collect its result into a new row. - #[inline(always)] - pub fn map( - self, - f: impl for<'a> FnMut(>::Item) -> E, - ) -> Row { - let (_, n) = (self.nrows(), self.ncols()); - let mut out = Row::::with_capacity(n); - let out_view = unsafe { row::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), n, 1) }; - let mut f = f; - ZipEq::new(out_view, self).for_each( - #[inline(always)] - |Zip(mut out, item)| out.write(f(item)), - ); - unsafe { out.set_ncols(n) }; - out - } - } - - impl< - M: for<'a> MatIndex< - 'a, - Rows = usize, - Cols = (), - Index = (usize, ()), - LayoutTransform = VecLayoutTransform, - >, - > LastEq - { - /// Applies `f` to each element of `self`. - #[inline(always)] - pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { - for_each_col(self, f); - } - - /// Applies `f` to each element of `self` and collect its result into a new column. - #[inline(always)] - pub fn map( - self, - f: impl for<'a> FnMut(>::Item) -> E, - ) -> Col { - let (m, _) = (self.nrows(), self.ncols()); - let mut out = Col::::with_capacity(m); - let out_view = unsafe { col::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, 1) }; - let mut f = f; - ZipEq::new(out_view, self).for_each( - #[inline(always)] - |Zip(mut out, item)| out.write(f(item)), - ); - unsafe { out.set_nrows(m) }; - out - } - } - - impl< - Head: for<'a> MatIndex< - 'a, - Rows = (), - Cols = usize, - Index = ((), usize), - LayoutTransform = VecLayoutTransform, - >, - Tail: for<'a> MatIndex< - 'a, - Rows = (), - Cols = usize, - Index = ((), usize), - LayoutTransform = VecLayoutTransform, - >, - > ZipEq<(), usize, Head, Tail> - { - /// Applies `f` to each element of `self`. - #[inline(always)] - pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { - for_each_row(self, f); - } - - /// Applies `f` to each element of `self` and collect its result into a new row. - #[inline(always)] - pub fn map( - self, - f: impl for<'a> FnMut(>::Item) -> E, - ) -> Row { - let (_, n) = (self.nrows(), self.ncols()); - let mut out = Row::::with_capacity(n); - let out_view = unsafe { row::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), n, 1) }; - let mut f = f; - ZipEq::new(out_view, self).for_each( - #[inline(always)] - |Zip(mut out, item)| out.write(f(item)), - ); - unsafe { out.set_ncols(n) }; - out - } - } - - impl< - Head: for<'a> MatIndex< - 'a, - Rows = usize, - Cols = (), - Index = (usize, ()), - LayoutTransform = VecLayoutTransform, - >, - Tail: for<'a> MatIndex< - 'a, - Rows = usize, - Cols = (), - Index = (usize, ()), - LayoutTransform = VecLayoutTransform, - >, - > ZipEq - { - /// Applies `f` to each element of `self`. - #[inline(always)] - pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { - for_each_col(self, f); - } - - /// Applies `f` to each element of `self` and collect its result into a new column. - #[inline(always)] - pub fn map( - self, - f: impl for<'a> FnMut(>::Item) -> E, - ) -> Col { - let (m, _) = (self.nrows(), self.ncols()); - let mut out = Col::::with_capacity(m); - let out_view = unsafe { col::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, 1) }; - let mut f = f; - ZipEq::new(out_view, self).for_each( - #[inline(always)] - |Zip(mut out, item)| out.write(f(item)), - ); - unsafe { out.set_nrows(m) }; - out - } - } - - impl< - Head: for<'a> MatIndex< - 'a, - Rows = usize, - Cols = usize, - Index = (usize, usize), - LayoutTransform = MatLayoutTransform, - >, - Tail: for<'a> MatIndex< - 'a, - Rows = usize, - Cols = usize, - Index = (usize, usize), - LayoutTransform = MatLayoutTransform, - >, - > ZipEq - { - /// Applies `f` to each element of `self`. - #[inline(always)] - pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { - for_each_mat(self, f); - } - - /// Applies `f` to each element of the lower triangular half of `self`. - /// - /// `diag` specifies whether the diagonal should be included or excluded. - #[inline(always)] - pub fn for_each_triangular_lower( - self, - diag: Diag, - f: impl for<'a> FnMut(>::Item), - ) { - for_each_mat_triangular_lower(self, diag, false, f); - } - - /// Applies `f` to each element of the upper triangular half of `self`. - /// - /// `diag` specifies whether the diagonal should be included or excluded. - #[inline(always)] - pub fn for_each_triangular_upper( - self, - diag: Diag, - f: impl for<'a> FnMut(>::Item), - ) { - for_each_mat_triangular_lower(self, diag, true, f); - } - - /// Applies `f` to each element of `self` and collect its result into a new matrix. - #[inline(always)] - pub fn map( - self, - f: impl for<'a> FnMut(>::Item) -> E, - ) -> Mat { - let (m, n) = (self.nrows(), self.ncols()); - let mut out = Mat::::with_capacity(m, n); - let rs = 1; - let cs = out.col_stride(); - let out_view = - unsafe { mat::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, n, rs, cs) }; - let mut f = f; - ZipEq::new(out_view, self).for_each( - #[inline(always)] - |Zip(mut out, item)| out.write(f(item)), - ); - unsafe { out.set_dims(m, n) }; - out - } - } - - #[cfg(test)] - mod tests { - use super::*; - use crate::{assert, unzipped, zipped, ComplexField, Mat}; - - #[test] - fn test_zip() { - for (m, n) in [(2, 2), (4, 2), (2, 4)] { - for rev_dst in [false, true] { - for rev_src in [false, true] { - for transpose_dst in [false, true] { - for transpose_src in [false, true] { - for diag in [Diag::Include, Diag::Skip] { - let mut dst = Mat::from_fn( - if transpose_dst { n } else { m }, - if transpose_dst { m } else { n }, - |_, _| f64::faer_zero(), - ); - let src = Mat::from_fn( - if transpose_src { n } else { m }, - if transpose_src { m } else { n }, - |_, _| f64::faer_one(), - ); - - let mut target = Mat::from_fn(m, n, |_, _| f64::faer_zero()); - let target_src = Mat::from_fn(m, n, |_, _| f64::faer_one()); - - zipped!(target.as_mut(), target_src.as_ref()) - .for_each_triangular_lower( - diag, - |unzipped!(mut dst, src)| dst.write(src.read()), - ); - - let mut dst = dst.as_mut(); - let mut src = src.as_ref(); - - if transpose_dst { - dst = dst.transpose_mut(); - } - if rev_dst { - dst = dst.reverse_rows_mut(); - } - - if transpose_src { - src = src.transpose(); - } - if rev_src { - src = src.reverse_rows(); - } - - zipped!(dst.rb_mut(), src).for_each_triangular_lower( - diag, - |unzipped!(mut dst, src)| dst.write(src.read()), - ); - - assert!(dst.rb() == target.as_ref()); - } - } - } - } - } - } - - { - let m = 3; - for rev_dst in [false, true] { - for rev_src in [false, true] { - let mut dst = Col::::zeros(m); - let src = Col::from_fn(m, |i| (i + 1) as f64); - - let mut target = Col::::zeros(m); - let target_src = - Col::from_fn(m, |i| if rev_src { m - i } else { i + 1 } as f64); - - zipped!(target.as_mut(), target_src.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - - let mut dst = dst.as_mut(); - let mut src = src.as_ref(); - - if rev_dst { - dst = dst.reverse_rows_mut(); - } - if rev_src { - src = src.reverse_rows(); - } - - zipped!(dst.rb_mut(), src) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - - assert!(dst.rb() == target.as_ref()); - } - } - } - - { - let m = 3; - for rev_dst in [false, true] { - for rev_src in [false, true] { - let mut dst = Row::::zeros(m); - let src = Row::from_fn(m, |i| (i + 1) as f64); - - let mut target = Row::::zeros(m); - let target_src = - Row::from_fn(m, |i| if rev_src { m - i } else { i + 1 } as f64); - - zipped!(target.as_mut(), target_src.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - - let mut dst = dst.as_mut(); - let mut src = src.as_ref(); - - if rev_dst { - dst = dst.reverse_cols_mut(); - } - if rev_src { - src = src.reverse_cols(); - } - - zipped!(&mut dst, src) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - - assert!(dst.rb() == target.as_ref()); - } - } - } - } - } -} diff --git a/faer-libs/faer-core/src/matrix_ops.rs b/faer-libs/faer-core/src/matrix_ops.rs deleted file mode 100644 index 389bc6ab54e7d1635339beb3818b30cadc3984ca..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/src/matrix_ops.rs +++ /dev/null @@ -1,2303 +0,0 @@ -//! addition and subtraction of matrices - -use super::*; -use crate::{ - assert, - permutation::{Index, SignedIndex}, - sparse, -}; - -/// Scalar value tag. -pub struct Scalar { - __private: (), -} -/// Dense column vector tag. -pub struct DenseCol { - __private: (), -} -/// Dense row vector tag. -pub struct DenseRow { - __private: (), -} -/// Dense matrix tag. -pub struct Dense { - __private: (), -} -/// Sparse column-major matrix tag. -pub struct SparseColMat { - __private: PhantomData, -} -/// Sparse row-major matrix tag. -pub struct SparseRowMat { - __private: PhantomData, -} -/// Diagonal matrix tag. -pub struct Diag { - __private: (), -} -/// Scaling factor tag. -pub struct Scale { - __private: (), -} -/// Permutation matrix tag. -pub struct Perm { - __private: PhantomData, -} - -/// Trait for describing the view and owning variants of a given matrix type tag. -pub trait MatrixKind { - /// Immutable view type. - type Ref<'a, E: Entity>: Copy; - /// Mutable view type. - type Mut<'a, E: Entity>; - /// Owning type. - type Own; -} -type KindRef<'a, E, K> = ::Ref<'a, E>; -type KindMut<'a, E, K> = ::Mut<'a, E>; -type KindOwn = ::Own; - -impl MatrixKind for Scalar { - type Ref<'a, E: Entity> = &'a E; - type Mut<'a, E: Entity> = &'a mut E; - type Own = E; -} -impl MatrixKind for DenseCol { - type Ref<'a, E: Entity> = ColRef<'a, E>; - type Mut<'a, E: Entity> = ColMut<'a, E>; - type Own = Col; -} -impl MatrixKind for DenseRow { - type Ref<'a, E: Entity> = RowRef<'a, E>; - type Mut<'a, E: Entity> = RowMut<'a, E>; - type Own = Row; -} -impl MatrixKind for Dense { - type Ref<'a, E: Entity> = MatRef<'a, E>; - type Mut<'a, E: Entity> = MatMut<'a, E>; - type Own = Mat; -} -impl MatrixKind for SparseColMat { - type Ref<'a, E: Entity> = sparse::SparseColMatRef<'a, I, E>; - type Mut<'a, E: Entity> = sparse::SparseColMatMut<'a, I, E>; - type Own = sparse::SparseColMat; -} -impl MatrixKind for SparseRowMat { - type Ref<'a, E: Entity> = sparse::SparseRowMatRef<'a, I, E>; - type Mut<'a, E: Entity> = sparse::SparseRowMatMut<'a, I, E>; - type Own = sparse::SparseRowMat; -} -impl MatrixKind for Scale { - type Ref<'a, E: Entity> = &'a MatScale; - type Mut<'a, E: Entity> = &'a mut MatScale; - type Own = MatScale; -} -impl MatrixKind for Diag { - type Ref<'a, E: Entity> = Matrix>; - type Mut<'a, E: Entity> = Matrix>; - type Own = Matrix>; -} -impl MatrixKind for Perm { - type Ref<'a, E: Entity> = Matrix>; - type Mut<'a, E: Entity> = Matrix>; - type Own = Matrix>; -} - -/// Generic matrix trait. -pub trait GenericMatrix: Sized { - /// Tag type. - type Kind: MatrixKind; - /// Scalar element type. - type Elem: Entity; - - /// Returns a view over the matrix. - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem>; -} -/// Generic mutable matrix trait. -pub trait GenericMatrixMut: GenericMatrix { - /// Returns a mutable over the matrix. - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem>; -} - -impl GenericMatrix for inner::PermRef<'_, I, E> { - type Kind = Perm; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - *this - } -} -impl GenericMatrix for inner::PermMut<'_, I, E> { - type Kind = Perm; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.rb() - } -} -impl GenericMatrix for inner::PermOwn { - type Kind = Perm; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.as_ref() - } -} - -impl GenericMatrix for inner::DenseRowRef<'_, E> { - type Kind = DenseRow; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - *this - } -} -impl GenericMatrix for inner::DenseRowMut<'_, E> { - type Kind = DenseRow; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.rb() - } -} -impl GenericMatrix for inner::DenseRowOwn { - type Kind = DenseRow; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.as_ref() - } -} -impl GenericMatrixMut for inner::DenseRowMut<'_, E> { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.rb_mut() - } -} -impl GenericMatrixMut for inner::DenseRowOwn { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.as_mut() - } -} - -impl GenericMatrix for inner::DenseColRef<'_, E> { - type Kind = DenseCol; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - *this - } -} -impl GenericMatrix for inner::DenseColMut<'_, E> { - type Kind = DenseCol; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.rb() - } -} -impl GenericMatrix for inner::DenseColOwn { - type Kind = DenseCol; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.as_ref() - } -} -impl GenericMatrixMut for inner::DenseColMut<'_, E> { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.rb_mut() - } -} -impl GenericMatrixMut for inner::DenseColOwn { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.as_mut() - } -} - -impl GenericMatrix for inner::DenseRef<'_, E> { - type Kind = Dense; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - *this - } -} -impl GenericMatrix for inner::DenseMut<'_, E> { - type Kind = Dense; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.rb() - } -} -impl GenericMatrix for inner::DenseOwn { - type Kind = Dense; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.as_ref() - } -} -impl GenericMatrixMut for inner::DenseMut<'_, E> { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.rb_mut() - } -} -impl GenericMatrixMut for inner::DenseOwn { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.as_mut() - } -} - -impl GenericMatrix for inner::DiagRef<'_, E> { - type Kind = Diag; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - *this - } -} -impl GenericMatrix for inner::DiagMut<'_, E> { - type Kind = Diag; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.rb() - } -} -impl GenericMatrix for inner::DiagOwn { - type Kind = Diag; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.as_ref() - } -} -impl GenericMatrixMut for inner::DiagMut<'_, E> { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.rb_mut() - } -} -impl GenericMatrixMut for inner::DiagOwn { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.as_mut() - } -} - -impl GenericMatrix for inner::Scale { - type Kind = Scale; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this - } -} -impl GenericMatrixMut for inner::Scale { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this - } -} - -impl GenericMatrix for inner::SparseColMatRef<'_, I, E> { - type Kind = SparseColMat; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - *this - } -} - -impl GenericMatrix for inner::SparseRowMatRef<'_, I, E> { - type Kind = SparseRowMat; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - *this - } -} - -impl GenericMatrix for inner::SparseColMatMut<'_, I, E> { - type Kind = SparseColMat; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.rb() - } -} - -impl GenericMatrix for inner::SparseRowMatMut<'_, I, E> { - type Kind = SparseRowMat; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.rb() - } -} - -impl GenericMatrixMut for inner::SparseColMatMut<'_, I, E> { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.rb_mut() - } -} - -impl GenericMatrixMut for inner::SparseRowMatMut<'_, I, E> { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.rb_mut() - } -} - -impl GenericMatrix for inner::SparseColMat { - type Kind = SparseColMat; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.as_ref() - } -} - -impl GenericMatrix for inner::SparseRowMat { - type Kind = SparseRowMat; - type Elem = E; - - #[inline(always)] - fn as_ref(this: &Matrix) -> ::Ref<'_, Self::Elem> { - this.as_ref() - } -} - -impl GenericMatrixMut for inner::SparseColMat { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.as_mut() - } -} - -impl GenericMatrixMut for inner::SparseRowMat { - #[inline(always)] - fn as_mut(this: &mut Matrix) -> ::Mut<'_, Self::Elem> { - this.as_mut() - } -} - -mod __matmul_assign { - use super::*; - - impl MatMulAssign for DenseCol { - #[track_caller] - fn mat_mul_assign>( - lhs: KindMut<'_, E, DenseCol>, - rhs: KindRef<'_, RhsE, Scale>, - ) { - let rhs = rhs.value().canonicalize(); - zipped!(lhs.as_2d_mut()) - .for_each(|unzipped!(mut lhs)| lhs.write(lhs.read().faer_mul(rhs))); - } - } - impl MatMulAssign for DenseRow { - #[track_caller] - fn mat_mul_assign>( - lhs: KindMut<'_, E, DenseRow>, - rhs: KindRef<'_, RhsE, Scale>, - ) { - let rhs = rhs.value().canonicalize(); - zipped!(lhs.as_2d_mut()) - .for_each(|unzipped!(mut lhs)| lhs.write(lhs.read().faer_mul(rhs))); - } - } - impl MatMulAssign for Dense { - #[track_caller] - fn mat_mul_assign>( - lhs: KindMut<'_, E, Dense>, - rhs: KindRef<'_, RhsE, Scale>, - ) { - let rhs = rhs.value().canonicalize(); - zipped!(lhs).for_each(|unzipped!(mut lhs)| lhs.write(lhs.read().faer_mul(rhs))); - } - } - impl MatMulAssign for Scale { - #[track_caller] - fn mat_mul_assign>( - lhs: KindMut<'_, E, Scale>, - rhs: KindRef<'_, RhsE, Scale>, - ) { - let rhs = rhs.value().canonicalize(); - *lhs = scale((*lhs).value().faer_mul(rhs)); - } - } - - impl MatMulAssign for Diag { - #[track_caller] - fn mat_mul_assign>( - lhs: KindMut<'_, E, Diag>, - rhs: KindRef<'_, RhsE, Diag>, - ) { - zipped!(lhs.inner.inner.as_2d_mut(), rhs.inner.inner.as_2d()).for_each( - |unzipped!(mut lhs, rhs)| lhs.write(lhs.read().faer_mul(rhs.read().canonicalize())), - ); - } - } -} - -mod __matmul { - use super::*; - use crate::{assert, permutation::Permutation}; - - impl MatMul> for Perm { - type Output = Perm; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Perm>, - rhs: KindRef<'_, RhsE, Perm>, - ) -> KindOwn { - assert!(lhs.len() == rhs.len()); - let truncate = ::truncate; - let mut fwd = alloc::vec![I::from_signed(truncate(0)); lhs.len()].into_boxed_slice(); - let mut inv = alloc::vec![I::from_signed(truncate(0)); lhs.len()].into_boxed_slice(); - - for (fwd, rhs) in fwd.iter_mut().zip(rhs.inner.forward) { - *fwd = lhs.inner.forward[rhs.to_signed().zx()]; - } - for (i, fwd) in fwd.iter().enumerate() { - inv[fwd.to_signed().zx()] = I::from_signed(I::Signed::truncate(i)); - } - - Permutation { - inner: inner::PermOwn { - forward: fwd, - inverse: inv, - __marker: core::marker::PhantomData, - }, - } - } - } - - impl MatMul for Perm { - type Output = DenseCol; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Perm>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) -> KindOwn { - assert!(lhs.len() == rhs.nrows()); - let mut out = Col::zeros(rhs.nrows()); - let fwd = lhs.inner.forward; - for (i, fwd) in fwd.iter().enumerate() { - out.write(i, rhs.read(fwd.to_signed().zx()).canonicalize()); - } - out - } - } - impl MatMul for Perm { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Perm>, - rhs: KindRef<'_, RhsE, Dense>, - ) -> KindOwn { - assert!(lhs.len() == rhs.nrows()); - let mut out = Mat::zeros(rhs.nrows(), rhs.ncols()); - let fwd = lhs.inner.forward; - - for j in 0..rhs.ncols() { - for (i, fwd) in fwd.iter().enumerate() { - out.write(i, j, rhs.read(fwd.to_signed().zx(), j).canonicalize()); - } - } - out - } - } - impl MatMul> for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, DenseRow>, - rhs: KindRef<'_, RhsE, Perm>, - ) -> KindOwn { - assert!(lhs.ncols() == rhs.len()); - let mut out = Row::zeros(lhs.ncols()); - let inv = rhs.inner.inverse; - - for (j, inv) in inv.iter().enumerate() { - out.write(j, lhs.read(inv.to_signed().zx()).canonicalize()); - } - out - } - } - impl MatMul> for Dense { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Dense>, - rhs: KindRef<'_, RhsE, Perm>, - ) -> KindOwn { - assert!(lhs.ncols() == rhs.len()); - let mut out = Mat::zeros(lhs.nrows(), lhs.ncols()); - let inv = rhs.inner.inverse; - - for (j, inv) in inv.iter().enumerate() { - for i in 0..lhs.nrows() { - out.write(i, j, lhs.read(i, inv.to_signed().zx()).canonicalize()); - } - } - out - } - } - - impl MatMul for Scale { - type Output = DenseCol; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Scale>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) -> KindOwn { - let mut out = Col::::zeros(rhs.nrows()); - let lhs = lhs.inner.0.canonicalize(); - zipped!(out.as_mut().as_2d_mut(), rhs.as_2d()).for_each(|unzipped!(mut out, rhs)| { - out.write(E::faer_mul(lhs, rhs.read().canonicalize())) - }); - out - } - } - impl MatMul for DenseCol { - type Output = DenseCol; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, DenseCol>, - rhs: KindRef<'_, RhsE, Scale>, - ) -> KindOwn { - let mut out = Col::::zeros(lhs.nrows()); - let rhs = rhs.inner.0.canonicalize(); - zipped!(out.as_mut().as_2d_mut(), lhs.as_2d()).for_each(|unzipped!(mut out, lhs)| { - out.write(E::faer_mul(lhs.read().canonicalize(), rhs)) - }); - out - } - } - impl MatMul for Scale { - type Output = DenseRow; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Scale>, - rhs: KindRef<'_, RhsE, DenseRow>, - ) -> KindOwn { - let mut out = Row::::zeros(rhs.nrows()); - let lhs = lhs.inner.0.canonicalize(); - zipped!(out.as_mut().as_2d_mut(), rhs.as_2d()).for_each(|unzipped!(mut out, rhs)| { - out.write(E::faer_mul(lhs, rhs.read().canonicalize())) - }); - out - } - } - impl MatMul for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, DenseRow>, - rhs: KindRef<'_, RhsE, Scale>, - ) -> KindOwn { - let mut out = Row::::zeros(lhs.nrows()); - let rhs = rhs.inner.0.canonicalize(); - zipped!(out.as_mut().as_2d_mut(), lhs.as_2d()).for_each(|unzipped!(mut out, lhs)| { - out.write(E::faer_mul(lhs.read().canonicalize(), rhs)) - }); - out - } - } - impl MatMul for Scale { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Scale>, - rhs: KindRef<'_, RhsE, Dense>, - ) -> KindOwn { - let mut out = Mat::::zeros(rhs.nrows(), rhs.ncols()); - let lhs = lhs.inner.0.canonicalize(); - zipped!(out.as_mut(), rhs).for_each(|unzipped!(mut out, rhs)| { - out.write(E::faer_mul(lhs, rhs.read().canonicalize())) - }); - out - } - } - impl MatMul for Dense { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Dense>, - rhs: KindRef<'_, RhsE, Scale>, - ) -> KindOwn { - let mut out = Mat::::zeros(lhs.nrows(), lhs.ncols()); - let rhs = rhs.inner.0.canonicalize(); - zipped!(out.as_mut(), lhs).for_each(|unzipped!(mut out, lhs)| { - out.write(E::faer_mul(lhs.read().canonicalize(), rhs)) - }); - out - } - } - impl MatMul for Scale { - type Output = Scale; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Scale>, - rhs: KindRef<'_, RhsE, Scale>, - ) -> KindOwn { - scale(E::faer_mul( - lhs.inner.0.canonicalize(), - rhs.inner.0.canonicalize(), - )) - } - } - - impl MatMul for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, DenseRow>, - rhs: KindRef<'_, RhsE, Diag>, - ) -> KindOwn { - let lhs_ncols = lhs.ncols(); - let rhs_dim = rhs.inner.inner.nrows(); - assert!(lhs_ncols == rhs_dim); - - Row::from_fn(lhs_ncols, |j| unsafe { - E::faer_mul( - lhs.read_unchecked(j).canonicalize(), - rhs.inner.inner.read_unchecked(j).canonicalize(), - ) - }) - } - } - impl MatMul for Dense { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Dense>, - rhs: KindRef<'_, RhsE, Diag>, - ) -> KindOwn { - let lhs_ncols = lhs.ncols(); - let rhs_dim = rhs.inner.inner.nrows(); - assert!(lhs_ncols == rhs_dim); - - Mat::from_fn(lhs.nrows(), lhs.ncols(), |i, j| unsafe { - E::faer_mul( - lhs.read_unchecked(i, j).canonicalize(), - rhs.inner.inner.read_unchecked(j).canonicalize(), - ) - }) - } - } - - impl MatMul for Diag { - type Output = DenseCol; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Diag>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) -> KindOwn { - let lhs_dim = lhs.inner.inner.nrows(); - let rhs_nrows = rhs.nrows(); - assert!(lhs_dim == rhs_nrows); - - Col::from_fn(rhs.nrows(), |i| unsafe { - E::faer_mul( - lhs.inner.inner.read_unchecked(i).canonicalize(), - rhs.read_unchecked(i).canonicalize(), - ) - }) - } - } - impl MatMul for Diag { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Diag>, - rhs: KindRef<'_, RhsE, Dense>, - ) -> KindOwn { - let lhs_dim = lhs.inner.inner.nrows(); - let rhs_nrows = rhs.nrows(); - assert!(lhs_dim == rhs_nrows); - - Mat::from_fn(rhs.nrows(), rhs.ncols(), |i, j| unsafe { - E::faer_mul( - lhs.inner.inner.read_unchecked(i).canonicalize(), - rhs.read_unchecked(i, j).canonicalize(), - ) - }) - } - } - - impl MatMul for Diag { - type Output = Diag; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Diag>, - rhs: KindRef<'_, RhsE, Diag>, - ) -> KindOwn { - let lhs_dim = lhs.inner.inner.nrows(); - let rhs_dim = rhs.inner.inner.nrows(); - assert!(lhs_dim == rhs_dim); - - Matrix { - inner: inner::DiagOwn { - inner: Col::from_fn(lhs_dim, |i| unsafe { - E::faer_mul( - lhs.inner.inner.read_unchecked(i).canonicalize(), - rhs.inner.inner.read_unchecked(i).canonicalize(), - ) - }), - }, - } - } - } - - impl MatMul for Dense { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - assert!(lhs.ncols() == rhs.nrows()); - let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); - mul::matmul( - out.as_mut(), - lhs, - rhs, - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - - impl MatMul for Dense { - type Output = DenseCol; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Dense>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) -> KindOwn { - assert!(lhs.ncols() == rhs.nrows()); - let mut out = Col::zeros(lhs.nrows()); - mul::matmul( - out.as_mut().as_2d_mut(), - lhs, - rhs.as_2d(), - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - impl MatMul for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, DenseRow>, - rhs: KindRef<'_, RhsE, Dense>, - ) -> KindOwn { - assert!(lhs.ncols() == rhs.nrows()); - let mut out = Row::zeros(lhs.nrows()); - mul::matmul( - out.as_mut().as_2d_mut(), - lhs.as_2d(), - rhs, - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - - impl MatMul for DenseRow { - type Output = Scalar; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, DenseRow>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) -> KindOwn { - assert!(lhs.ncols() == rhs.nrows()); - let (lhs, conj_lhs) = lhs.canonicalize(); - let (rhs, conj_rhs) = rhs.canonicalize(); - - crate::mul::inner_prod::inner_prod_with_conj( - lhs.transpose().as_2d(), - conj_lhs, - rhs.as_2d(), - conj_rhs, - ) - } - } - - impl MatMul for DenseCol { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, DenseCol>, - rhs: KindRef<'_, RhsE, DenseRow>, - ) -> KindOwn { - assert!(lhs.ncols() == rhs.nrows()); - let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); - mul::matmul( - out.as_mut(), - lhs.as_2d(), - rhs.as_2d(), - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - - impl MatMul> for Dense { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) -> KindOwn { - let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); - sparse::mul::dense_sparse_matmul( - out.as_mut(), - lhs, - rhs, - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - impl MatMul> for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) -> KindOwn { - let mut out = Row::zeros(rhs.ncols()); - sparse::mul::dense_sparse_matmul( - out.as_mut().as_2d_mut(), - lhs.as_2d(), - rhs, - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - impl MatMul> for Dense { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) -> KindOwn { - let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); - sparse::mul::sparse_dense_matmul( - out.as_mut().transpose_mut(), - rhs.transpose(), - lhs.transpose(), - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - impl MatMul> for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) -> KindOwn { - let mut out = Row::zeros(rhs.ncols()); - sparse::mul::sparse_dense_matmul( - out.as_mut().transpose_mut().as_2d_mut(), - rhs.transpose(), - lhs.transpose().as_2d(), - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - - impl MatMul for SparseColMat { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Dense>, - ) -> KindOwn { - let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); - sparse::mul::sparse_dense_matmul( - out.as_mut(), - lhs, - rhs, - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - impl MatMul for SparseColMat { - type Output = DenseCol; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) -> KindOwn { - let mut out = Col::zeros(lhs.nrows()); - sparse::mul::sparse_dense_matmul( - out.as_mut().as_2d_mut(), - lhs, - rhs.as_2d(), - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - impl MatMul for SparseRowMat { - type Output = Dense; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Dense>, - ) -> KindOwn { - let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); - sparse::mul::dense_sparse_matmul( - out.as_mut().transpose_mut(), - rhs.transpose(), - lhs.transpose(), - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } - impl MatMul for SparseRowMat { - type Output = DenseCol; - - #[track_caller] - fn mat_mul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) -> KindOwn { - let mut out = Col::zeros(lhs.nrows()); - sparse::mul::dense_sparse_matmul( - out.as_mut().transpose_mut().as_2d_mut(), - rhs.transpose().as_2d(), - lhs.transpose(), - None, - E::faer_one(), - get_global_parallelism(), - ); - out - } - } -} - -/// Matrix multiplication. -pub trait MatMulAssign: MatrixKind { - /// Computes `lhs * rhs` and assigns it to `lhs`. - fn mat_mul_assign>( - lhs: KindMut<'_, E, Self>, - rhs: KindRef<'_, RhsE, Rhs>, - ); -} -/// Matrix addition. -pub trait MatAddAssign: MatrixKind { - /// Computes `lhs + rhs` and assigns it to `lhs`. - fn mat_add_assign>( - lhs: KindMut<'_, E, Self>, - rhs: KindRef<'_, RhsE, Rhs>, - ); -} -/// Matrix subtraction. -pub trait MatSubAssign: MatrixKind { - /// Computes `lhs - rhs` and assigns it to `lhs`. - fn mat_sub_assign>( - lhs: KindMut<'_, E, Self>, - rhs: KindRef<'_, RhsE, Rhs>, - ); -} - -/// Matrix equality comparison. -pub trait MatEq: MatrixKind { - /// Computes `lhs == rhs`. - fn mat_eq, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Rhs>, - ) -> bool; -} - -/// Matrix multiplication. -pub trait MatMul: MatrixKind { - /// Result matrix type. - type Output: MatrixKind; - - /// Returns `lhs * rhs`. - fn mat_mul, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Rhs>, - ) -> KindOwn; -} -/// Matrix addition. -pub trait MatAdd: MatrixKind { - /// Result matrix type. - type Output: MatrixKind; - - /// Returns `lhs + rhs`. - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Rhs>, - ) -> KindOwn; -} -/// Matrix subtraction. -pub trait MatSub: MatrixKind { - /// Result matrix type. - type Output: MatrixKind; - - /// Returns `lhs - rhs`. - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Rhs>, - ) -> KindOwn; -} -/// Matrix negation. -pub trait MatNeg: MatrixKind { - /// Result matrix type. - type Output: MatrixKind; - - /// Returns `-mat`. - fn mat_neg(mat: KindRef<'_, E, Self>) -> KindOwn - where - E::Canonical: ComplexField; -} - -impl MatEq> for Perm { - #[track_caller] - fn mat_eq, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> bool { - lhs.inner.forward == rhs.inner.forward - } -} - -impl MatEq for DenseCol { - #[track_caller] - fn mat_eq, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> bool { - lhs.as_2d() == rhs.as_2d() - } -} -impl MatEq for DenseRow { - #[track_caller] - fn mat_eq, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> bool { - lhs.as_2d() == rhs.as_2d() - } -} - -impl MatEq for Dense { - #[track_caller] - fn mat_eq, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> bool { - if (lhs.nrows(), lhs.ncols()) != (rhs.nrows(), rhs.ncols()) { - return false; - } - let m = lhs.nrows(); - let n = lhs.ncols(); - for j in 0..n { - for i in 0..m { - if !(lhs.read(i, j).canonicalize() == rhs.read(i, j).canonicalize()) { - return false; - } - } - } - - true - } -} - -impl MatEq> for SparseColMat { - fn mat_eq, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> bool { - if lhs.nrows() != rhs.nrows() || lhs.ncols() != rhs.ncols() { - return false; - } - - let n = lhs.ncols(); - let mut equal = true; - for j in 0..n { - equal &= lhs.row_indices_of_col_raw(j) == rhs.row_indices_of_col_raw(j); - let lhs_val = SliceGroup::<'_, LhsE>::new(lhs.values_of_col(j)); - let rhs_val = SliceGroup::<'_, RhsE>::new(rhs.values_of_col(j)); - equal &= lhs_val - .into_ref_iter() - .map(|r| r.read().canonicalize()) - .eq(rhs_val.into_ref_iter().map(|r| r.read().canonicalize())) - } - - equal - } -} - -impl MatEq> for SparseRowMat { - fn mat_eq, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> bool { - lhs.transpose() == rhs.transpose() - } -} - -impl MatAdd for DenseCol { - type Output = DenseCol; - - #[track_caller] - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - assert!(all(lhs.nrows() == rhs.nrows(), lhs.ncols() == rhs.ncols())); - let mut out = Col::::zeros(lhs.nrows()); - zipped!(out.as_mut().as_2d_mut(), lhs.as_2d(), rhs.as_2d()).for_each( - |unzipped!(mut out, lhs, rhs)| { - out.write(E::faer_add( - lhs.read().canonicalize(), - rhs.read().canonicalize(), - )) - }, - ); - out - } -} -impl MatSub for DenseCol { - type Output = DenseCol; - - #[track_caller] - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - assert!(all(lhs.nrows() == rhs.nrows(), lhs.ncols() == rhs.ncols())); - let mut out = Col::::zeros(lhs.nrows()); - zipped!(out.as_mut().as_2d_mut(), lhs.as_2d(), rhs.as_2d()).for_each( - |unzipped!(mut out, lhs, rhs)| { - out.write(E::faer_sub( - lhs.read().canonicalize(), - rhs.read().canonicalize(), - )) - }, - ); - out - } -} -impl MatAddAssign for DenseCol { - #[track_caller] - fn mat_add_assign>( - lhs: KindMut<'_, E, DenseCol>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) { - zipped!(lhs.as_2d_mut(), rhs.as_2d()).for_each(|unzipped!(mut lhs, rhs)| { - lhs.write(lhs.read().faer_add(rhs.read().canonicalize())) - }); - } -} -impl MatSubAssign for DenseCol { - #[track_caller] - fn mat_sub_assign>( - lhs: KindMut<'_, E, DenseCol>, - rhs: KindRef<'_, RhsE, DenseCol>, - ) { - zipped!(lhs.as_2d_mut(), rhs.as_2d()).for_each(|unzipped!(mut lhs, rhs)| { - lhs.write(lhs.read().faer_sub(rhs.read().canonicalize())) - }); - } -} - -impl MatNeg for DenseCol { - type Output = DenseCol; - - fn mat_neg(mat: KindRef<'_, E, Self>) -> KindOwn - where - E::Canonical: ComplexField, - { - let mut out = Col::::zeros(mat.nrows()); - zipped!(out.as_mut().as_2d_mut(), mat.as_2d()) - .for_each(|unzipped!(mut out, src)| out.write(src.read().canonicalize().faer_neg())); - out - } -} - -impl MatAdd for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - assert!(all(lhs.nrows() == rhs.nrows(), lhs.ncols() == rhs.ncols())); - let mut out = Row::::zeros(lhs.nrows()); - zipped!(out.as_mut().as_2d_mut(), lhs.as_2d(), rhs.as_2d()).for_each( - |unzipped!(mut out, lhs, rhs)| { - out.write(E::faer_add( - lhs.read().canonicalize(), - rhs.read().canonicalize(), - )) - }, - ); - out - } -} -impl MatSub for DenseRow { - type Output = DenseRow; - - #[track_caller] - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - assert!(all(lhs.nrows() == rhs.nrows(), lhs.ncols() == rhs.ncols())); - let mut out = Row::::zeros(lhs.nrows()); - zipped!(out.as_mut().as_2d_mut(), lhs.as_2d(), rhs.as_2d()).for_each( - |unzipped!(mut out, lhs, rhs)| { - out.write(E::faer_sub( - lhs.read().canonicalize(), - rhs.read().canonicalize(), - )) - }, - ); - out - } -} -impl MatAddAssign for DenseRow { - #[track_caller] - fn mat_add_assign>( - lhs: KindMut<'_, E, DenseRow>, - rhs: KindRef<'_, RhsE, DenseRow>, - ) { - zipped!(lhs.as_2d_mut(), rhs.as_2d()).for_each(|unzipped!(mut lhs, rhs)| { - lhs.write(lhs.read().faer_add(rhs.read().canonicalize())) - }); - } -} -impl MatSubAssign for DenseRow { - #[track_caller] - fn mat_sub_assign>( - lhs: KindMut<'_, E, DenseRow>, - rhs: KindRef<'_, RhsE, DenseRow>, - ) { - zipped!(lhs.as_2d_mut(), rhs.as_2d()).for_each(|unzipped!(mut lhs, rhs)| { - lhs.write(lhs.read().faer_sub(rhs.read().canonicalize())) - }); - } -} - -impl MatNeg for DenseRow { - type Output = DenseRow; - - fn mat_neg(mat: KindRef<'_, E, Self>) -> KindOwn - where - E::Canonical: ComplexField, - { - let mut out = Row::::zeros(mat.nrows()); - zipped!(out.as_mut().as_2d_mut(), mat.as_2d()) - .for_each(|unzipped!(mut out, src)| out.write(src.read().canonicalize().faer_neg())); - out - } -} - -impl MatAdd for Dense { - type Output = Dense; - - #[track_caller] - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - assert!(all(lhs.nrows() == rhs.nrows(), lhs.ncols() == rhs.ncols())); - let mut out = Mat::::zeros(lhs.nrows(), rhs.ncols()); - zipped!(out.as_mut(), lhs, rhs).for_each(|unzipped!(mut out, lhs, rhs)| { - out.write(E::faer_add( - lhs.read().canonicalize(), - rhs.read().canonicalize(), - )) - }); - out - } -} -impl MatSub for Dense { - type Output = Dense; - - #[track_caller] - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - assert!(all(lhs.nrows() == rhs.nrows(), lhs.ncols() == rhs.ncols())); - let mut out = Mat::::zeros(lhs.nrows(), rhs.ncols()); - zipped!(out.as_mut(), lhs, rhs).for_each(|unzipped!(mut out, lhs, rhs)| { - out.write(E::faer_sub( - lhs.read().canonicalize(), - rhs.read().canonicalize(), - )) - }); - out - } -} -impl MatAddAssign for Dense { - #[track_caller] - fn mat_add_assign>( - lhs: KindMut<'_, E, Dense>, - rhs: KindRef<'_, RhsE, Dense>, - ) { - zipped!(lhs, rhs).for_each(|unzipped!(mut lhs, rhs)| { - lhs.write(lhs.read().faer_add(rhs.read().canonicalize())) - }); - } -} -impl MatSubAssign for Dense { - #[track_caller] - fn mat_sub_assign>( - lhs: KindMut<'_, E, Dense>, - rhs: KindRef<'_, RhsE, Dense>, - ) { - zipped!(lhs, rhs).for_each(|unzipped!(mut lhs, rhs)| { - lhs.write(lhs.read().faer_sub(rhs.read().canonicalize())) - }); - } -} - -impl MatNeg for Dense { - type Output = Dense; - - fn mat_neg(mat: KindRef<'_, E, Self>) -> KindOwn - where - E::Canonical: ComplexField, - { - let mut out = Mat::::zeros(mat.nrows(), mat.ncols()); - zipped!(out.as_mut(), mat) - .for_each(|unzipped!(mut out, src)| out.write(src.read().canonicalize().faer_neg())); - out - } -} - -impl MatAdd> for SparseColMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - crate::sparse::ops::add(lhs, rhs).unwrap() - } -} -impl MatAdd> for SparseRowMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - crate::sparse::ops::add(lhs.transpose(), rhs.transpose()) - .unwrap() - .into_transpose() - .to_col_major() - .unwrap() - } -} -impl MatAdd> for SparseColMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) -> KindOwn { - crate::sparse::ops::add(lhs, rhs.to_col_major().unwrap().as_ref()).unwrap() - } -} -impl MatAdd> for SparseRowMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_add, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) -> KindOwn { - crate::sparse::ops::add(lhs.to_col_major().unwrap().as_ref(), rhs).unwrap() - } -} - -impl MatSub> for SparseColMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - crate::sparse::ops::sub(lhs, rhs).unwrap() - } -} -impl MatSub> for SparseRowMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, Self>, - ) -> KindOwn { - crate::sparse::ops::sub(lhs.transpose(), rhs.transpose()) - .unwrap() - .into_transpose() - .to_col_major() - .unwrap() - } -} -impl MatSub> for SparseColMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) -> KindOwn { - crate::sparse::ops::sub(lhs, rhs.to_col_major().unwrap().as_ref()).unwrap() - } -} -impl MatSub> for SparseRowMat { - type Output = SparseColMat; - - #[track_caller] - fn mat_sub, RhsE: Conjugate>( - lhs: KindRef<'_, LhsE, Self>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) -> KindOwn { - crate::sparse::ops::sub(lhs.to_col_major().unwrap().as_ref(), rhs).unwrap() - } -} - -impl MatNeg for SparseColMat { - type Output = SparseColMat; - - fn mat_neg(mat: KindRef<'_, E, Self>) -> KindOwn - where - E::Canonical: ComplexField, - { - let mut out = mat.to_owned().unwrap(); - crate::group_helpers::SliceGroupMut::::new(out.as_mut().values_mut()) - .into_mut_iter() - .for_each(|mut x| x.write(x.read().faer_neg())); - out - } -} - -impl MatNeg for SparseRowMat { - type Output = SparseColMat; - - fn mat_neg(mat: KindRef<'_, E, Self>) -> KindOwn - where - E::Canonical: ComplexField, - { - let mut out = mat.to_col_major().unwrap(); - crate::group_helpers::SliceGroupMut::::new(out.as_mut().values_mut()) - .into_mut_iter() - .for_each(|mut x| x.write(x.read().faer_neg())); - out - } -} - -impl MatAddAssign> for SparseColMat { - #[track_caller] - fn mat_add_assign>( - lhs: KindMut<'_, E, SparseColMat>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) { - crate::sparse::ops::add_assign(lhs, rhs); - } -} -impl MatSubAssign> for SparseColMat { - #[track_caller] - fn mat_sub_assign>( - lhs: KindMut<'_, E, SparseColMat>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) { - crate::sparse::ops::sub_assign(lhs, rhs); - } -} - -impl MatAddAssign> for SparseRowMat { - #[track_caller] - fn mat_add_assign>( - lhs: KindMut<'_, E, SparseRowMat>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) { - crate::sparse::ops::add_assign(lhs.transpose_mut(), rhs.transpose()); - } -} -impl MatSubAssign> for SparseRowMat { - #[track_caller] - fn mat_sub_assign>( - lhs: KindMut<'_, E, SparseRowMat>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) { - crate::sparse::ops::sub_assign(lhs.transpose_mut(), rhs.transpose()); - } -} - -impl MatAddAssign> for SparseRowMat { - #[track_caller] - fn mat_add_assign>( - lhs: KindMut<'_, E, SparseRowMat>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) { - crate::sparse::ops::add_assign( - lhs.transpose_mut(), - rhs.transpose().to_col_major().unwrap().as_ref(), - ); - } -} -impl MatSubAssign> for SparseRowMat { - #[track_caller] - fn mat_sub_assign>( - lhs: KindMut<'_, E, SparseRowMat>, - rhs: KindRef<'_, RhsE, SparseColMat>, - ) { - crate::sparse::ops::sub_assign( - lhs.transpose_mut(), - rhs.transpose().to_col_major().unwrap().as_ref(), - ); - } -} - -impl MatAddAssign> for SparseColMat { - #[track_caller] - fn mat_add_assign>( - lhs: KindMut<'_, E, SparseColMat>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) { - crate::sparse::ops::add_assign(lhs, rhs.to_col_major().unwrap().as_ref()); - } -} -impl MatSubAssign> for SparseColMat { - #[track_caller] - fn mat_sub_assign>( - lhs: KindMut<'_, E, SparseColMat>, - rhs: KindRef<'_, RhsE, SparseRowMat>, - ) { - crate::sparse::ops::sub_assign(lhs, rhs.to_col_major().unwrap().as_ref()); - } -} - -/// Returns a scaling factor with the given value. -#[inline(always)] -pub fn scale(value: E) -> Matrix> { - Matrix { - inner: inner::Scale(value), - } -} - -const _: () = { - use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; - - impl Mul<&Matrix> for &Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatMul, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn mul(self, rhs: &Matrix) -> Self::Output { - >::mat_mul( - GenericMatrix::as_ref(self), - GenericMatrix::as_ref(rhs), - ) - } - } - impl Mul<&Matrix> for Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatMul, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn mul(self, rhs: &Matrix) -> Self::Output { - &self * rhs - } - } - impl Mul> for &Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatMul, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn mul(self, rhs: Matrix) -> Self::Output { - self * &rhs - } - } - - impl Mul> for Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatMul, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn mul(self, rhs: Matrix) -> Self::Output { - &self * &rhs - } - } - - impl Add<&Matrix> for &Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatAdd, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn add(self, rhs: &Matrix) -> Self::Output { - >::mat_add( - GenericMatrix::as_ref(self), - GenericMatrix::as_ref(rhs), - ) - } - } - impl Add<&Matrix> for Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatAdd, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn add(self, rhs: &Matrix) -> Self::Output { - &self + rhs - } - } - impl Add> for &Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatAdd, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn add(self, rhs: Matrix) -> Self::Output { - self + &rhs - } - } - impl Add> for Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatAdd, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn add(self, rhs: Matrix) -> Self::Output { - &self + &rhs - } - } - - impl Sub<&Matrix> for &Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatSub, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn sub(self, rhs: &Matrix) -> Self::Output { - >::mat_sub( - GenericMatrix::as_ref(self), - GenericMatrix::as_ref(rhs), - ) - } - } - - impl Sub<&Matrix> for Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatSub, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn sub(self, rhs: &Matrix) -> Self::Output { - &self - rhs - } - } - impl Sub> for &Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatSub, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn sub(self, rhs: Matrix) -> Self::Output { - self - &rhs - } - } - impl Sub> for Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatSub, - { - type Output = - KindOwn<::Canonical, >::Output>; - - #[track_caller] - fn sub(self, rhs: Matrix) -> Self::Output { - &self - &rhs - } - } - - impl Neg for &Matrix - where - Mat::Elem: Conjugate, - ::Canonical: ComplexField, - Mat::Kind: MatNeg, - { - type Output = KindOwn<::Canonical, ::Output>; - fn neg(self) -> Self::Output { - ::mat_neg(GenericMatrix::as_ref(self)) - } - } - impl Neg for Matrix - where - Mat::Elem: Conjugate, - ::Canonical: ComplexField, - Mat::Kind: MatNeg, - { - type Output = KindOwn<::Canonical, ::Output>; - fn neg(self) -> Self::Output { - -&self - } - } - - impl PartialEq> for Matrix - where - Lhs::Elem: Conjugate, - Rhs::Elem: Conjugate::Canonical>, - ::Canonical: ComplexField, - Lhs::Kind: MatEq, - { - fn eq(&self, rhs: &Matrix) -> bool { - >::mat_eq( - GenericMatrix::as_ref(self), - GenericMatrix::as_ref(rhs), - ) - } - } - - impl MulAssign<&Matrix> for Matrix - where - Lhs::Elem: ComplexField, - Rhs::Elem: Conjugate, - Lhs::Kind: MatMulAssign, - { - #[track_caller] - fn mul_assign(&mut self, rhs: &Matrix) { - >::mat_mul_assign( - GenericMatrixMut::as_mut(self), - GenericMatrix::as_ref(rhs), - ); - } - } - impl MulAssign> for Matrix - where - Lhs::Elem: ComplexField, - Rhs::Elem: Conjugate, - Lhs::Kind: MatMulAssign, - { - #[track_caller] - fn mul_assign(&mut self, rhs: Matrix) { - *self *= &rhs; - } - } - - impl AddAssign<&Matrix> for Matrix - where - Lhs::Elem: ComplexField, - Rhs::Elem: Conjugate, - Lhs::Kind: MatAddAssign, - { - #[track_caller] - fn add_assign(&mut self, rhs: &Matrix) { - >::mat_add_assign( - GenericMatrixMut::as_mut(self), - GenericMatrix::as_ref(rhs), - ); - } - } - impl AddAssign> for Matrix - where - Lhs::Elem: ComplexField, - Rhs::Elem: Conjugate, - Lhs::Kind: MatAddAssign, - { - #[track_caller] - fn add_assign(&mut self, rhs: Matrix) { - *self += &rhs; - } - } - - impl SubAssign<&Matrix> for Matrix - where - Lhs::Elem: ComplexField, - Rhs::Elem: Conjugate, - Lhs::Kind: MatSubAssign, - { - #[track_caller] - fn sub_assign(&mut self, rhs: &Matrix) { - >::mat_sub_assign( - GenericMatrixMut::as_mut(self), - GenericMatrix::as_ref(rhs), - ); - } - } - impl SubAssign> for Matrix - where - Lhs::Elem: ComplexField, - Rhs::Elem: Conjugate, - Lhs::Kind: MatSubAssign, - { - #[track_caller] - fn sub_assign(&mut self, rhs: Matrix) { - *self -= &rhs; - } - } -}; - -#[cfg(test)] -#[allow(non_snake_case)] -mod test { - use crate::{ - assert, mat, - permutation::{Permutation, PermutationRef}, - Col, Mat, Row, - }; - use assert_approx_eq::assert_approx_eq; - - fn matrices() -> (Mat, Mat) { - let A = mat![[2.8, -3.3], [-1.7, 5.2], [4.6, -8.3],]; - - let B = mat![[-7.9, 8.3], [4.7, -3.2], [3.8, -5.2],]; - (A, B) - } - - #[test] - #[should_panic] - fn test_adding_matrices_of_different_sizes_should_panic() { - let A = mat![[1.0, 2.0], [3.0, 4.0]]; - let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - _ = A + B; - } - - #[test] - #[should_panic] - fn test_subtracting_two_matrices_of_different_sizes_should_panic() { - let A = mat![[1.0, 2.0], [3.0, 4.0]]; - let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; - _ = A - B; - } - - #[test] - fn test_add() { - let (A, B) = matrices(); - - let expected = mat![[-5.1, 5.0], [3.0, 2.0], [8.4, -13.5],]; - - assert_matrix_approx_eq(A.as_ref() + B.as_ref(), &expected); - assert_matrix_approx_eq(&A + &B, &expected); - assert_matrix_approx_eq(A.as_ref() + &B, &expected); - assert_matrix_approx_eq(&A + B.as_ref(), &expected); - assert_matrix_approx_eq(A.as_ref() + B.clone(), &expected); - assert_matrix_approx_eq(&A + B.clone(), &expected); - assert_matrix_approx_eq(A.clone() + B.as_ref(), &expected); - assert_matrix_approx_eq(A.clone() + &B, &expected); - assert_matrix_approx_eq(A + B, &expected); - } - - #[test] - fn test_sub() { - let (A, B) = matrices(); - - let expected = mat![[10.7, -11.6], [-6.4, 8.4], [0.8, -3.1],]; - - assert_matrix_approx_eq(A.as_ref() - B.as_ref(), &expected); - assert_matrix_approx_eq(&A - &B, &expected); - assert_matrix_approx_eq(A.as_ref() - &B, &expected); - assert_matrix_approx_eq(&A - B.as_ref(), &expected); - assert_matrix_approx_eq(A.as_ref() - B.clone(), &expected); - assert_matrix_approx_eq(&A - B.clone(), &expected); - assert_matrix_approx_eq(A.clone() - B.as_ref(), &expected); - assert_matrix_approx_eq(A.clone() - &B, &expected); - assert_matrix_approx_eq(A - B, &expected); - } - - #[test] - fn test_neg() { - let (A, _) = matrices(); - - let expected = mat![[-2.8, 3.3], [1.7, -5.2], [-4.6, 8.3],]; - - assert_eq!(-A, expected); - } - - #[test] - fn test_scalar_mul() { - use crate::scale; - - let (A, _) = matrices(); - let scale = scale(3.0); - let expected = Mat::from_fn(A.nrows(), A.ncols(), |i, j| A.read(i, j) * scale.value()); - - { - assert_matrix_approx_eq(A.as_ref() * scale, &expected); - assert_matrix_approx_eq(&A * scale, &expected); - assert_matrix_approx_eq(A.as_ref() * scale, &expected); - assert_matrix_approx_eq(&A * scale, &expected); - assert_matrix_approx_eq(A.as_ref() * scale, &expected); - assert_matrix_approx_eq(&A * scale, &expected); - assert_matrix_approx_eq(A.clone() * scale, &expected); - assert_matrix_approx_eq(A.clone() * scale, &expected); - assert_matrix_approx_eq(A * scale, &expected); - } - - let (A, _) = matrices(); - { - assert_matrix_approx_eq(scale * A.as_ref(), &expected); - assert_matrix_approx_eq(scale * &A, &expected); - assert_matrix_approx_eq(scale * A.as_ref(), &expected); - assert_matrix_approx_eq(scale * &A, &expected); - assert_matrix_approx_eq(scale * A.as_ref(), &expected); - assert_matrix_approx_eq(scale * &A, &expected); - assert_matrix_approx_eq(scale * A.clone(), &expected); - assert_matrix_approx_eq(scale * A.clone(), &expected); - assert_matrix_approx_eq(scale * A, &expected); - } - } - - #[test] - fn test_diag_mul() { - let (A, _) = matrices(); - let diag_left = mat![[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]; - let diag_right = mat![[4.0, 0.0], [0.0, 5.0]]; - - assert!(&diag_left * &A == diag_left.diagonal() * &A); - assert!(&A * &diag_right == &A * diag_right.diagonal()); - } - - #[test] - fn test_perm_mul() { - let A = Mat::from_fn(6, 5, |i, j| (j + 5 * i) as f64); - let pl = Permutation::::new_checked( - Box::new([5, 1, 4, 0, 2, 3]), - Box::new([3, 1, 4, 5, 2, 0]), - ); - let pr = Permutation::::new_checked( - Box::new([1, 4, 0, 2, 3]), - Box::new([2, 0, 3, 4, 1]), - ); - - let perm_left = mat![ - [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - ]; - let perm_right = mat![ - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - ]; - - assert!( - &pl * pl.as_ref().inverse() - == PermutationRef::<'_, usize, f64>::new_checked( - &[0, 1, 2, 3, 4, 5], - &[0, 1, 2, 3, 4, 5], - ) - ); - assert!(&perm_left * &A == &pl * &A); - assert!(&A * &perm_right == &A * &pr); - } - - #[test] - fn test_matmul_col_row() { - let A = Col::from_fn(6, |i| i as f64); - let B = Row::from_fn(6, |j| (5 * j + 1) as f64); - - // outer product - assert_eq!(&A * &B, A.as_ref().as_2d() * B.as_ref().as_2d()); - // inner product - assert_eq!( - &B * &A, - (B.as_ref().as_2d() * A.as_ref().as_2d()).read(0, 0), - ); - } - - fn assert_matrix_approx_eq(given: Mat, expected: &Mat) { - assert_eq!(given.nrows(), expected.nrows()); - assert_eq!(given.ncols(), expected.ncols()); - for i in 0..given.nrows() { - for j in 0..given.ncols() { - assert_approx_eq!(given.read(i, j), expected.read(i, j)); - } - } - } -} diff --git a/faer-libs/faer-core/src/permutation.rs b/faer-libs/faer-core/src/permutation.rs deleted file mode 100644 index 15b0ca01ee5659af6b9be2a1a144c3fb8df045d4..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/src/permutation.rs +++ /dev/null @@ -1,926 +0,0 @@ -//! Permutation matrices. -#![allow(clippy::len_without_is_empty)] - -use crate::{ - assert, constrained, debug_assert, - inner::{PermMut, PermOwn, PermRef}, - seal::Seal, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Entity, MatMut, MatRef, Matrix, -}; -use bytemuck::Pod; -use core::fmt::Debug; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use reborrow::*; - -impl Seal for i32 {} -impl Seal for i64 {} -impl Seal for i128 {} -impl Seal for isize {} -impl Seal for u32 {} -impl Seal for u64 {} -impl Seal for u128 {} -impl Seal for usize {} - -/// Trait for unsigned integers that can be indexed with. -/// -/// Always smaller than or equal to `usize`. -pub trait Index: - Seal - + core::fmt::Debug - + core::ops::Not - + core::ops::Add - + core::ops::Sub - + core::ops::AddAssign - + core::ops::SubAssign - + Pod - + Eq - + Ord - + Send - + Sync -{ - /// Equally-sized index type with a fixed size (no `usize`). - type FixedWidth: Index; - /// Equally-sized signed index type. - type Signed: SignedIndex; - - /// Truncate `value` to type [`Self`]. - #[must_use] - #[inline(always)] - fn truncate(value: usize) -> Self { - Self::from_signed(::truncate(value)) - } - - /// Zero extend `self`. - #[must_use] - #[inline(always)] - fn zx(self) -> usize { - self.to_signed().zx() - } - - /// Convert a reference to a slice of [`Self`] to fixed width types. - #[inline(always)] - fn canonicalize(slice: &[Self]) -> &[Self::FixedWidth] { - bytemuck::cast_slice(slice) - } - - /// Convert a mutable reference to a slice of [`Self`] to fixed width types. - #[inline(always)] - fn canonicalize_mut(slice: &mut [Self]) -> &mut [Self::FixedWidth] { - bytemuck::cast_slice_mut(slice) - } - - /// Convert a signed value to an unsigned one. - #[inline(always)] - fn from_signed(value: Self::Signed) -> Self { - pulp::cast(value) - } - - /// Convert an unsigned value to a signed one. - #[inline(always)] - fn to_signed(self) -> Self::Signed { - pulp::cast(self) - } - - /// Sum values while checking for overflow. - #[inline] - fn sum_nonnegative(slice: &[Self]) -> Option { - Self::Signed::sum_nonnegative(bytemuck::cast_slice(slice)).map(Self::from_signed) - } -} - -/// Trait for signed integers corresponding to the ones satisfying [`Index`]. -/// -/// Always smaller than or equal to `isize`. -pub trait SignedIndex: - Seal - + core::fmt::Debug - + core::ops::Neg - + core::ops::Add - + core::ops::Sub - + core::ops::AddAssign - + core::ops::SubAssign - + Pod - + Eq - + Ord - + Send - + Sync -{ - /// Maximum representable value. - const MAX: Self; - - /// Truncate `value` to type [`Self`]. - #[must_use] - fn truncate(value: usize) -> Self; - - /// Zero extend `self`. - #[must_use] - fn zx(self) -> usize; - /// Sign extend `self`. - #[must_use] - fn sx(self) -> usize; - - /// Sum nonnegative values while checking for overflow. - fn sum_nonnegative(slice: &[Self]) -> Option { - let mut acc = Self::zeroed(); - for &i in slice { - if Self::MAX - i < acc { - return None; - } - acc += i; - } - Some(acc) - } -} - -#[cfg(any( - target_pointer_width = "32", - target_pointer_width = "64", - target_pointer_width = "128", -))] -impl Index for u32 { - type FixedWidth = u32; - type Signed = i32; -} -#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))] -impl Index for u64 { - type FixedWidth = u64; - type Signed = i64; -} -#[cfg(target_pointer_width = "128")] -impl Index for u128 { - type FixedWidth = u128; - type Signed = i128; -} - -impl Index for usize { - #[cfg(target_pointer_width = "32")] - type FixedWidth = u32; - #[cfg(target_pointer_width = "64")] - type FixedWidth = u64; - #[cfg(target_pointer_width = "128")] - type FixedWidth = u128; - - type Signed = isize; -} - -#[cfg(any( - target_pointer_width = "32", - target_pointer_width = "64", - target_pointer_width = "128", -))] -impl SignedIndex for i32 { - const MAX: Self = Self::MAX; - - #[inline(always)] - fn truncate(value: usize) -> Self { - #[allow(clippy::assertions_on_constants)] - const _: () = { - core::assert!(i32::BITS <= usize::BITS); - }; - value as isize as Self - } - - #[inline(always)] - fn zx(self) -> usize { - self as u32 as usize - } - - #[inline(always)] - fn sx(self) -> usize { - self as isize as usize - } -} - -#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))] -impl SignedIndex for i64 { - const MAX: Self = Self::MAX; - - #[inline(always)] - fn truncate(value: usize) -> Self { - #[allow(clippy::assertions_on_constants)] - const _: () = { - core::assert!(i64::BITS <= usize::BITS); - }; - value as isize as Self - } - - #[inline(always)] - fn zx(self) -> usize { - self as u64 as usize - } - - #[inline(always)] - fn sx(self) -> usize { - self as isize as usize - } -} - -#[cfg(target_pointer_width = "128")] -impl SignedIndex for i128 { - const MAX: Self = Self::MAX; - - #[inline(always)] - fn truncate(value: usize) -> Self { - #[allow(clippy::assertions_on_constants)] - const _: () = { - core::assert!(i128::BITS <= usize::BITS); - }; - value as isize as Self - } - - #[inline(always)] - fn zx(self) -> usize { - self as u128 as usize - } - - #[inline(always)] - fn sx(self) -> usize { - self as isize as usize - } -} - -impl SignedIndex for isize { - const MAX: Self = Self::MAX; - - #[inline(always)] - fn truncate(value: usize) -> Self { - value as isize - } - - #[inline(always)] - fn zx(self) -> usize { - self as usize - } - - #[inline(always)] - fn sx(self) -> usize { - self as usize - } -} - -/// Swaps the two columns at indices `a` and `b` in the given matrix. -/// -/// # Panics -/// -/// Panics if either `a` or `b` is out of bounds. -/// -/// # Example -/// -/// ``` -/// use faer_core::{mat, permutation::swap_cols}; -/// -/// let mut m = mat![ -/// [1.0, 2.0, 3.0], -/// [4.0, 5.0, 6.0], -/// [7.0, 8.0, 9.0], -/// [10.0, 14.0, 12.0], -/// ]; -/// -/// swap_cols(m.as_mut(), 0, 2); -/// -/// let swapped = mat![ -/// [3.0, 2.0, 1.0], -/// [6.0, 5.0, 4.0], -/// [9.0, 8.0, 7.0], -/// [12.0, 14.0, 10.0], -/// ]; -/// -/// assert_eq!(m, swapped); -/// ``` -#[track_caller] -#[inline] -pub fn swap_cols(mat: MatMut<'_, E>, a: usize, b: usize) { - assert!(all(a < mat.ncols(), b < mat.ncols())); - - if a == b { - return; - } - - let mat = mat.into_const(); - let mat_a = mat.col(a); - let mat_b = mat.col(b); - - unsafe { - zipped!( - mat_a.const_cast().as_2d_mut(), - mat_b.const_cast().as_2d_mut(), - ) - } - .for_each(|unzipped!(mut a, mut b)| { - let (a_read, b_read) = (a.read(), b.read()); - a.write(b_read); - b.write(a_read); - }); -} - -/// Swaps the two rows at indices `a` and `b` in the given matrix. -/// -/// # Panics -/// -/// Panics if either `a` or `b` is out of bounds. -/// -/// # Example -/// -/// ``` -/// use faer_core::{mat, permutation::swap_rows}; -/// -/// let mut m = mat![ -/// [1.0, 2.0, 3.0], -/// [4.0, 5.0, 6.0], -/// [7.0, 8.0, 9.0], -/// [10.0, 14.0, 12.0], -/// ]; -/// -/// swap_rows(m.as_mut(), 0, 2); -/// -/// let swapped = mat![ -/// [7.0, 8.0, 9.0], -/// [4.0, 5.0, 6.0], -/// [1.0, 2.0, 3.0], -/// [10.0, 14.0, 12.0], -/// ]; -/// -/// assert_eq!(m, swapped); -/// ``` -#[track_caller] -#[inline] -pub fn swap_rows(mat: MatMut<'_, E>, a: usize, b: usize) { - swap_cols(mat.transpose_mut(), a, b) -} - -/// Immutable permutation view. -pub type PermutationRef<'a, I, E> = Matrix>; -/// Mutable permutation view. -pub type PermutationMut<'a, I, E> = Matrix>; -/// Owned permutation. -pub type Permutation = Matrix>; - -impl Permutation { - /// Convert `self` to a permutation view. - #[inline] - pub fn as_ref(&self) -> PermutationRef<'_, I, E> { - PermutationRef { - inner: PermRef { - forward: &self.inner.forward, - inverse: &self.inner.inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Convert `self` to a mutable permutation view. - #[inline] - pub fn as_mut(&mut self) -> PermutationMut<'_, I, E> { - PermutationMut { - inner: PermMut { - forward: &mut self.inner.forward, - inverse: &mut self.inner.inverse, - __marker: core::marker::PhantomData, - }, - } - } -} - -impl Permutation { - /// Creates a new permutation, by checking the validity of the inputs. - /// - /// # Panics - /// - /// The function panics if any of the following conditions are violated: - /// `forward` and `inverse` must have the same length which must be less than or equal to - /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. - #[inline] - #[track_caller] - pub fn new_checked(forward: alloc::boxed::Box<[I]>, inverse: alloc::boxed::Box<[I]>) -> Self { - PermutationRef::<'_, I, E>::new_checked(&forward, &inverse); - Self { - inner: PermOwn { - forward, - inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Creates a new permutation reference, without checking the validity of the inputs. - /// - /// # Safety - /// - /// `forward` and `inverse` must have the same length which must be less than or equal to - /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. - #[inline] - #[track_caller] - pub unsafe fn new_unchecked( - forward: alloc::boxed::Box<[I]>, - inverse: alloc::boxed::Box<[I]>, - ) -> Self { - let n = forward.len(); - assert!(all( - forward.len() == inverse.len(), - n <= I::Signed::MAX.zx(), - )); - Self { - inner: PermOwn { - forward, - inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Returns the permutation as an array. - #[inline] - pub fn into_arrays(self) -> (alloc::boxed::Box<[I]>, alloc::boxed::Box<[I]>) { - (self.inner.forward, self.inner.inverse) - } - - /// Returns the dimension of the permutation. - #[inline] - pub fn len(&self) -> usize { - self.inner.forward.len() - } - - /// Returns the inverse permutation. - #[inline] - pub fn inverse(self) -> Self { - Self { - inner: PermOwn { - forward: self.inner.inverse, - inverse: self.inner.forward, - __marker: core::marker::PhantomData, - }, - } - } - - /// Cast the permutation to a different scalar type. - #[inline] - pub fn cast(self) -> Permutation { - Permutation { - inner: PermOwn { - forward: self.inner.forward, - inverse: self.inner.inverse, - __marker: core::marker::PhantomData, - }, - } - } -} - -impl<'a, I: Index, E: Entity> PermutationRef<'a, I, E> { - /// Creates a new permutation reference, by checking the validity of the inputs. - /// - /// # Panics - /// - /// The function panics if any of the following conditions are violated: - /// `forward` and `inverse` must have the same length which must be less than or equal to - /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. - #[inline] - #[track_caller] - pub fn new_checked(forward: &'a [I], inverse: &'a [I]) -> Self { - #[track_caller] - fn check(forward: &[I], inverse: &[I]) { - let n = forward.len(); - assert!(all( - forward.len() == inverse.len(), - n <= I::Signed::MAX.zx() - )); - for (i, &p) in forward.iter().enumerate() { - let p = p.to_signed().zx(); - assert!(p < n); - assert!(inverse[p].to_signed().zx() == i); - } - } - - check(I::canonicalize(forward), I::canonicalize(inverse)); - Self { - inner: PermRef { - forward, - inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Creates a new permutation reference, without checking the validity of the inputs. - /// - /// # Safety - /// - /// `forward` and `inverse` must have the same length which must be less than or equal to - /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. - #[inline] - #[track_caller] - pub unsafe fn new_unchecked(forward: &'a [I], inverse: &'a [I]) -> Self { - let n = forward.len(); - assert!(all( - forward.len() == inverse.len(), - n <= I::Signed::MAX.zx(), - )); - - Self { - inner: PermRef { - forward, - inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Returns the permutation as an array. - #[inline] - pub fn into_arrays(self) -> (&'a [I], &'a [I]) { - (self.inner.forward, self.inner.inverse) - } - - /// Returns the dimension of the permutation. - #[inline] - pub fn len(&self) -> usize { - debug_assert!(self.inner.inverse.len() == self.inner.forward.len()); - self.inner.forward.len() - } - - /// Returns the inverse permutation. - #[inline] - pub fn inverse(self) -> Self { - Self { - inner: PermRef { - forward: self.inner.inverse, - inverse: self.inner.forward, - __marker: core::marker::PhantomData, - }, - } - } - - /// Cast the permutation to a different scalar type. - #[inline] - pub fn cast(self) -> PermutationRef<'a, I, T> { - PermutationRef { - inner: PermRef { - forward: self.inner.forward, - inverse: self.inner.inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Cast the permutation to the fixed width index type. - #[inline(always)] - pub fn canonicalize(self) -> PermutationRef<'a, I::FixedWidth, E> { - PermutationRef { - inner: PermRef { - forward: I::canonicalize(self.inner.forward), - inverse: I::canonicalize(self.inner.inverse), - __marker: core::marker::PhantomData, - }, - } - } - - /// Cast the permutation from the fixed width index type. - #[inline(always)] - pub fn uncanonicalize(self) -> PermutationRef<'a, J, E> { - assert!(core::mem::size_of::() == core::mem::size_of::()); - PermutationRef { - inner: PermRef { - forward: bytemuck::cast_slice(self.inner.forward), - inverse: bytemuck::cast_slice(self.inner.inverse), - __marker: core::marker::PhantomData, - }, - } - } -} - -impl<'a, I: Index, E: Entity> PermutationMut<'a, I, E> { - /// Creates a new permutation mutable reference, by checking the validity of the inputs. - /// - /// # Panics - /// - /// The function panics if any of the following conditions are violated: - /// `forward` and `inverse` must have the same length which must be less than or equal to - /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. - #[inline] - #[track_caller] - pub fn new_checked(forward: &'a mut [I], inverse: &'a mut [I]) -> Self { - PermutationRef::<'_, I, E>::new_checked(forward, inverse); - Self { - inner: PermMut { - forward, - inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Creates a new permutation mutable reference, without checking the validity of the inputs. - /// - /// # Safety - /// - /// `forward` and `inverse` must have the same length which must be less than or equal to - /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. - #[inline] - #[track_caller] - pub unsafe fn new_unchecked(forward: &'a mut [I], inverse: &'a mut [I]) -> Self { - let n = forward.len(); - assert!(all( - forward.len() == inverse.len(), - n <= I::Signed::MAX.zx(), - )); - - Self { - inner: PermMut { - forward, - inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Returns the permutation as an array. - /// - /// # Safety - /// - /// The behavior is undefined if the arrays are no longer inverse permutations of each other by - /// the end of lifetime `'a`. - #[inline] - pub unsafe fn into_arrays(self) -> (&'a mut [I], &'a mut [I]) { - (self.inner.forward, self.inner.inverse) - } - - /// Returns the dimension of the permutation. - #[inline] - pub fn len(&self) -> usize { - debug_assert!(self.inner.inverse.len() == self.inner.forward.len()); - self.inner.forward.len() - } - - /// Returns the inverse permutation. - #[inline] - pub fn inverse(self) -> Self { - Self { - inner: PermMut { - forward: self.inner.inverse, - inverse: self.inner.forward, - __marker: core::marker::PhantomData, - }, - } - } - - /// Cast the permutation to a different scalar type. - #[inline] - pub fn cast(self) -> PermutationMut<'a, I, T> { - PermutationMut { - inner: PermMut { - forward: self.inner.forward, - inverse: self.inner.inverse, - __marker: core::marker::PhantomData, - }, - } - } - - /// Cast the permutation to the fixed width index type. - #[inline(always)] - pub fn canonicalize(self) -> PermutationMut<'a, I::FixedWidth, E> { - PermutationMut { - inner: PermMut { - forward: I::canonicalize_mut(self.inner.forward), - inverse: I::canonicalize_mut(self.inner.inverse), - __marker: core::marker::PhantomData, - }, - } - } - - /// Cast the permutation from the fixed width index type. - #[inline(always)] - pub fn uncanonicalize(self) -> PermutationMut<'a, J, E> { - assert!(core::mem::size_of::() == core::mem::size_of::()); - PermutationMut { - inner: PermMut { - forward: bytemuck::cast_slice_mut(self.inner.forward), - inverse: bytemuck::cast_slice_mut(self.inner.inverse), - __marker: core::marker::PhantomData, - }, - } - } -} - -impl<'short, 'a, I, E: Entity> Reborrow<'short> for PermutationRef<'a, I, E> { - type Target = PermutationRef<'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } -} - -impl<'short, 'a, I, E: Entity> ReborrowMut<'short> for PermutationRef<'a, I, E> { - type Target = PermutationRef<'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } -} - -impl<'short, 'a, I, E: Entity> Reborrow<'short> for PermutationMut<'a, I, E> { - type Target = PermutationRef<'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - PermutationRef { - inner: PermRef { - forward: &*self.inner.forward, - inverse: &*self.inner.inverse, - __marker: core::marker::PhantomData, - }, - } - } -} - -impl<'short, 'a, I, E: Entity> ReborrowMut<'short> for PermutationMut<'a, I, E> { - type Target = PermutationMut<'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - PermutationMut { - inner: PermMut { - forward: &mut *self.inner.forward, - inverse: &mut *self.inner.inverse, - __marker: core::marker::PhantomData, - }, - } - } -} - -impl<'a, I: Debug, E: Entity> Debug for PermutationRef<'a, I, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.inner.fmt(f) - } -} -impl<'a, I: Debug, E: Entity> Debug for PermutationMut<'a, I, E> { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } -} -impl<'a, I: Debug, E: Entity> Debug for Permutation { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_ref().fmt(f) - } -} - -/// Computes a permutation of the columns of the source matrix using the given permutation, and -/// stores the result in the destination matrix. -/// -/// # Panics -/// -/// - Panics if the matrices do not have the same shape. -/// - Panics if the size of the permutation doesn't match the number of columns of the matrices. -#[inline] -#[track_caller] -pub fn permute_cols( - dst: MatMut<'_, E>, - src: MatRef<'_, E>, - perm_indices: PermutationRef<'_, I, E>, -) { - assert!(all( - src.nrows() == dst.nrows(), - src.ncols() == dst.ncols(), - perm_indices.into_arrays().0.len() == src.ncols(), - )); - - permute_rows( - dst.transpose_mut(), - src.transpose(), - perm_indices.canonicalize(), - ); -} - -/// Computes a permutation of the rows of the source matrix using the given permutation, and -/// stores the result in the destination matrix. -/// -/// # Panics -/// -/// - Panics if the matrices do not have the same shape. -/// - Panics if the size of the permutation doesn't match the number of rows of the matrices. -#[inline] -#[track_caller] -pub fn permute_rows( - dst: MatMut<'_, E>, - src: MatRef<'_, E>, - perm_indices: PermutationRef<'_, I, E>, -) { - #[track_caller] - fn implementation( - dst: MatMut<'_, E>, - src: MatRef<'_, E>, - perm_indices: PermutationRef<'_, I, E>, - ) { - assert!(all( - src.nrows() == dst.nrows(), - src.ncols() == dst.ncols(), - perm_indices.into_arrays().0.len() == src.nrows(), - )); - - constrained::Size::with2(src.nrows(), src.ncols(), |m, n| { - let mut dst = constrained::MatMut::new(dst, m, n); - let src = constrained::MatRef::new(src, m, n); - let perm = constrained::permutation::PermutationRef::new(perm_indices, m) - .into_arrays() - .0; - - if dst.rb().into_inner().row_stride().unsigned_abs() - < dst.rb().into_inner().col_stride().unsigned_abs() - { - for j in n.indices() { - for i in m.indices() { - dst.rb_mut().write(i, j, src.read(perm[i].zx(), j)); - } - } - } else { - for i in m.indices() { - let src_i = src.into_inner().row(perm[i].zx().into_inner()); - let mut dst_i = dst.rb_mut().into_inner().row_mut(i.into_inner()); - - dst_i.copy_from(src_i); - } - } - }); - } - - implementation(dst, src, perm_indices.canonicalize()) -} - -/// Computes the size and alignment of required workspace for applying a row permutation to a -/// matrix in place. -pub fn permute_rows_in_place_req( - nrows: usize, - ncols: usize, -) -> Result { - temp_mat_req::(nrows, ncols) -} - -/// Computes the size and alignment of required workspace for applying a column permutation to a -/// matrix in place. -pub fn permute_cols_in_place_req( - nrows: usize, - ncols: usize, -) -> Result { - temp_mat_req::(nrows, ncols) -} - -/// Computes a permutation of the rows of the matrix using the given permutation, and -/// stores the result in the same matrix. -/// -/// # Panics -/// -/// - Panics if the size of the permutation doesn't match the number of rows of the matrix. -#[inline] -#[track_caller] -pub fn permute_rows_in_place( - matrix: MatMut<'_, E>, - perm_indices: PermutationRef<'_, I, E>, - stack: PodStack<'_>, -) { - #[inline] - #[track_caller] - fn implementation( - matrix: MatMut<'_, E>, - perm_indices: PermutationRef<'_, I, E>, - stack: PodStack<'_>, - ) { - let mut matrix = matrix; - let (mut tmp, _) = temp_mat_uninit::(matrix.nrows(), matrix.ncols(), stack); - tmp.rb_mut().copy_from(matrix.rb()); - permute_rows(matrix.rb_mut(), tmp.rb(), perm_indices); - } - - implementation(matrix, perm_indices.canonicalize(), stack) -} - -/// Computes a permutation of the columns of the matrix using the given permutation, and -/// stores the result in the same matrix. -/// -/// # Panics -/// -/// - Panics if the size of the permutation doesn't match the number of columns of the matrix. -#[inline] -#[track_caller] -pub fn permute_cols_in_place( - matrix: MatMut<'_, E>, - perm_indices: PermutationRef<'_, I, E>, - stack: PodStack<'_>, -) { - #[inline] - #[track_caller] - fn implementation( - matrix: MatMut<'_, E>, - perm_indices: PermutationRef<'_, I, E>, - stack: PodStack<'_>, - ) { - let mut matrix = matrix; - let (mut tmp, _) = temp_mat_uninit::(matrix.nrows(), matrix.ncols(), stack); - tmp.rb_mut().copy_from(matrix.rb()); - permute_cols(matrix.rb_mut(), tmp.rb(), perm_indices); - } - - implementation(matrix, perm_indices.canonicalize(), stack) -} diff --git a/faer-libs/faer-core/src/simd.rs b/faer-libs/faer-core/src/simd.rs deleted file mode 100644 index 78d8445b6a37d77d85b89ec07fb253b4ba1e188a..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/src/simd.rs +++ /dev/null @@ -1,147 +0,0 @@ -pub use faer_entity::{ - one_simd_as_slice, simd_as_slice, simd_as_slice_unit, simd_index_as_slice, slice_as_mut_simd, - slice_as_simd, -}; - -fn sum_i32_scalar(slice: &[i32]) -> Option { - let mut overflow = false; - let mut sum = 0i32; - for &v in slice { - let o; - (sum, o) = i32::overflowing_add(sum, v); - overflow |= o; - } - (!overflow).then_some(sum) -} -fn sum_i64_scalar(slice: &[i64]) -> Option { - let mut overflow = false; - let mut sum = 0i64; - for &v in slice { - let o; - (sum, o) = i64::overflowing_add(sum, v); - overflow |= o; - } - (!overflow).then_some(sum) -} - -pub fn sum_i32(slice: &[i32]) -> Option { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if let Some(simd) = pulp::x86::V3::try_new() { - return x86::sum_i32_v3(simd, slice); - } - sum_i32_scalar(slice) -} - -pub fn sum_i64(slice: &[i64]) -> Option { - #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] - if let Some(simd) = pulp::x86::V3::try_new() { - return x86::sum_i64_v3(simd, slice); - } - sum_i64_scalar(slice) -} - -#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] -mod x86 { - use super::*; - use pulp::{x86::*, *}; - pub fn sum_i32_v3(simd: V3, slice: &[i32]) -> Option { - struct Impl<'a> { - simd: V3, - slice: &'a [i32], - } - - impl pulp::NullaryFnOnce for Impl<'_> { - type Output = Option; - - #[inline(always)] - fn call(self) -> Self::Output { - let Self { simd, slice } = self; - let (head, tail) = V3::i32s_as_simd(slice); - - let zero = simd.splat_i32x8(0); - let mut sum = zero; - let mut overflow = simd.splat_m32x8(m32::new(false)); - - for &v in head { - sum = simd.wrapping_add_i32x8(sum, v); - overflow = simd.or_m32x8(overflow, simd.cmp_lt_i32x8(sum, zero)); - } - - if overflow != simd.splat_m32x8(m32::new(false)) { - return None; - } - - i32::checked_add( - sum_i32_scalar(tail)?, - sum_i32_scalar(bytemuck::cast_slice(&[sum]))?, - ) - } - } - - simd.vectorize(Impl { simd, slice }) - } - - pub fn sum_i64_v3(simd: V3, slice: &[i64]) -> Option { - struct Impl<'a> { - simd: V3, - slice: &'a [i64], - } - - impl pulp::NullaryFnOnce for Impl<'_> { - type Output = Option; - - #[inline(always)] - fn call(self) -> Self::Output { - let Self { simd, slice } = self; - let (head, tail) = V3::i64s_as_simd(slice); - - let zero = simd.splat_i64x4(0); - let mut sum = zero; - let mut overflow = simd.splat_m64x4(m64::new(false)); - - for &v in head { - sum = simd.wrapping_add_i64x4(sum, v); - overflow = simd.or_m64x4(overflow, simd.cmp_lt_i64x4(sum, zero)); - } - - if overflow != simd.splat_m64x4(m64::new(false)) { - return None; - } - - i64::checked_add( - sum_i64_scalar(tail)?, - sum_i64_scalar(bytemuck::cast_slice(&[sum]))?, - ) - } - } - - simd.vectorize(Impl { simd, slice }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::assert; - - #[test] - fn test_sum_i32() { - let array = vec![100_000_000i32; 1000]; - assert!(sum_i32(&array) == None); - - let array = vec![1_000_000i32; 1000]; - assert!(sum_i32(&array) == Some(1_000_000_000i32)); - } - - #[test] - fn test_sum_i64() { - let array = vec![i64::MAX / 100; 1000]; - assert!(sum_i64(&array) == None); - - let array = vec![100_000_000i64; 1000]; - assert!(sum_i64(&array) == Some(100_000_000_000i64)); - - let array = vec![1_000_000i64; 1000]; - assert!(sum_i64(&array) == Some(1_000_000_000i64)); - } -} diff --git a/faer-libs/faer-core/src/sparse.rs b/faer-libs/faer-core/src/sparse.rs deleted file mode 100644 index 38d99f457c004bf45d29e7a5287bfc3b928013b0..0000000000000000000000000000000000000000 --- a/faer-libs/faer-core/src/sparse.rs +++ /dev/null @@ -1,5411 +0,0 @@ -//! Sparse matrix data structures. -//! -//! Most sparse matrix algorithms accept matrices in sparse column-oriented format. -//! This format represents each column of the matrix by storing the row indices of its non-zero -//! elements, as well as their values. -//! -//! The indices and the values are each stored in a contiguous slice (or group of slices for -//! arbitrary values). In order to specify where each column starts and ends, a slice of size -//! `ncols + 1` stores the start of each column, with the last element being equal to the total -//! number of non-zeros (or the capacity in uncompressed mode). -//! -//! # Example -//! -//! Consider the 4-by-5 matrix: -//! ```notcode -//! 10.0 0.0 12.0 -1.0 13.0 -//! 0.0 0.0 25.0 -2.0 0.0 -//! 1.0 0.0 0.0 0.0 0.0 -//! 4.0 0.0 0.0 0.0 5.0 -//! ``` -//! -//! The matrix is stored as follows: -//! ```notcode -//! column pointers: 0 | 3 | 3 | 5 | 7 | 9 -//! -//! row indices: 0 | 2 | 3 | 0 | 1 | 0 | 1 | 0 | 3 -//! values : 10.0 | 1.0 | 4.0 | 12.0 | 25.0 | -1.0 | -2.0 | 13.0 | 5.0 -//! ``` - -use super::*; -use crate::{assert, group_helpers::VecGroup}; -use core::{cell::Cell, iter::zip, ops::Range, slice::SliceIndex}; -use dyn_stack::GlobalPodBuffer; -use group_helpers::SliceGroup; - -pub use permutation::{Index, SignedIndex}; - -mod ghost { - pub use crate::constrained::{group_helpers::*, permutation::*, sparse::*, *}; -} - -mod mem { - #[inline] - pub fn fill_zero(slice: &mut [I]) { - let len = slice.len(); - unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len) } - } -} - -#[inline(always)] -#[track_caller] -#[doc(hidden)] -pub unsafe fn __get_unchecked>(slice: &[I], i: R) -> &R::Output { - #[cfg(debug_assertions)] - { - let _ = &slice[i.clone()]; - } - unsafe { slice.get_unchecked(i) } -} -#[inline(always)] -#[track_caller] -#[doc(hidden)] -pub unsafe fn __get_unchecked_mut>( - slice: &mut [I], - i: R, -) -> &mut R::Output { - #[cfg(debug_assertions)] - { - let _ = &slice[i.clone()]; - } - unsafe { slice.get_unchecked_mut(i) } -} - -#[inline(always)] -#[doc(hidden)] -pub fn windows2(slice: &[I]) -> impl DoubleEndedIterator { - slice - .windows(2) - .map(|window| unsafe { &*(window.as_ptr() as *const [I; 2]) }) -} - -#[inline] -#[doc(hidden)] -pub const fn repeat_byte(byte: u8) -> usize { - union Union { - bytes: [u8; 32], - value: usize, - } - - let data = Union { bytes: [byte; 32] }; - unsafe { data.value } -} - -/// Symbolic view structure of sparse matrix in column format, either compressed or uncompressed. -/// -/// Requires: -/// * `nrows <= I::Signed::MAX` (always checked) -/// * `ncols <= I::Signed::MAX` (always checked) -/// * `col_ptrs` has length `ncols + 1` (always checked) -/// * `col_ptrs` is non-decreasing -/// * `col_ptrs[0]..col_ptrs[ncols]` is a valid range in row_indices (always checked, assuming -/// non-decreasing) -/// * if `nnz_per_col` is `None`, elements of `row_indices[col_ptrs[j]..col_ptrs[j + 1]]` are less -/// than `nrows` -/// -/// * `nnz_per_col[j] <= col_ptrs[j+1] - col_ptrs[j]` -/// * if `nnz_per_col` is `Some(_)`, elements of `row_indices[col_ptrs[j]..][..nnz_per_col[j]]` are -/// less than `nrows` -/// -/// * Within each column, row indices are unique and sorted in increasing order. -/// -/// # Note -/// Some algorithms allow working with matrices containing duplicate and/or unsorted row -/// indicers per column. -/// -/// Passing such a matrix to an algorithm that does not explicitly permit this is unspecified -/// (though not undefined) behavior. -pub struct SymbolicSparseColMatRef<'a, I> { - nrows: usize, - ncols: usize, - col_ptr: &'a [I], - col_nnz: Option<&'a [I]>, - row_ind: &'a [I], -} - -/// Symbolic view structure of sparse matrix in row format, either compressed or uncompressed. -/// -/// Requires: -/// * `nrows <= I::Signed::MAX` (always checked) -/// * `ncols <= I::Signed::MAX` (always checked) -/// * `row_ptrs` has length `nrows + 1` (always checked) -/// * `row_ptrs` is non-decreasing -/// * `row_ptrs[0]..row_ptrs[nrows]` is a valid range in row_indices (always checked, assuming -/// non-decreasing) -/// * if `nnz_per_row` is `None`, elements of `col_indices[row_ptrs[i]..row_ptrs[i + 1]]` are less -/// than `ncols` -/// -/// * `nnz_per_row[i] <= row_ptrs[i+1] - row_ptrs[i]` -/// * if `nnz_per_row` is `Some(_)`, elements of `col_indices[row_ptrs[i]..][..nnz_per_row[i]]` are -/// less than `ncols` -/// -/// * Within each row, column indices are unique and sorted in increasing order. -/// -/// # Note -/// Some algorithms allow working with matrices containing duplicate and/or unsorted column -/// indicers per row. -/// -/// Passing such a matrix to an algorithm that does not explicitly permit this is unspecified -/// (though not undefined) behavior. -pub struct SymbolicSparseRowMatRef<'a, I> { - nrows: usize, - ncols: usize, - row_ptr: &'a [I], - row_nnz: Option<&'a [I]>, - col_ind: &'a [I], -} - -/// Symbolic structure of sparse matrix in column format, either compressed or uncompressed. -/// -/// Requires: -/// * `nrows <= I::Signed::MAX` (always checked) -/// * `ncols <= I::Signed::MAX` (always checked) -/// * `col_ptrs` has length `ncols + 1` (always checked) -/// * `col_ptrs` is non-decreasing -/// * `col_ptrs[0]..col_ptrs[ncols]` is a valid range in row_indices (always checked, assuming -/// non-decreasing) -/// * if `nnz_per_col` is `None`, elements of `row_indices[col_ptrs[j]..col_ptrs[j + 1]]` are less -/// than `nrows` -/// -/// * `nnz_per_col[j] <= col_ptrs[j+1] - col_ptrs[j]` -/// * if `nnz_per_col` is `Some(_)`, elements of `row_indices[col_ptrs[j]..][..nnz_per_col[j]]` are -/// less than `nrows` -#[derive(Clone)] -pub struct SymbolicSparseColMat { - nrows: usize, - ncols: usize, - col_ptr: Vec, - col_nnz: Option>, - row_ind: Vec, -} - -/// Symbolic structure of sparse matrix in row format, either compressed or uncompressed. -/// -/// Requires: -/// * `nrows <= I::Signed::MAX` (always checked) -/// * `ncols <= I::Signed::MAX` (always checked) -/// * `row_ptrs` has length `nrows + 1` (always checked) -/// * `row_ptrs` is non-decreasing -/// * `row_ptrs[0]..row_ptrs[nrows]` is a valid range in row_indices (always checked, assuming -/// non-decreasing) -/// * if `nnz_per_row` is `None`, elements of `col_indices[row_ptrs[i]..row_ptrs[i + 1]]` are less -/// than `ncols` -/// -/// * `nnz_per_row[i] <= row_ptrs[i+1] - row_ptrs[i]` -/// * if `nnz_per_row` is `Some(_)`, elements of `col_indices[row_ptrs[i]..][..nnz_per_row[i]]` are -/// less than `ncols` -#[derive(Clone)] -pub struct SymbolicSparseRowMat { - nrows: usize, - ncols: usize, - row_ptr: Vec, - row_nnz: Option>, - col_ind: Vec, -} - -impl Copy for SymbolicSparseColMatRef<'_, I> {} -impl Clone for SymbolicSparseColMatRef<'_, I> { - #[inline] - fn clone(&self) -> Self { - *self - } -} -impl Copy for SymbolicSparseRowMatRef<'_, I> {} -impl Clone for SymbolicSparseRowMatRef<'_, I> { - #[inline] - fn clone(&self) -> Self { - *self - } -} - -impl SymbolicSparseRowMat { - /// Creates a new symbolic matrix view after asserting its invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_checked( - nrows: usize, - ncols: usize, - row_ptrs: Vec, - nnz_per_row: Option>, - col_indices: Vec, - ) -> Self { - SymbolicSparseRowMatRef::new_checked( - nrows, - ncols, - &row_ptrs, - nnz_per_row.as_deref(), - &col_indices, - ); - - Self { - nrows, - ncols, - row_ptr: row_ptrs, - row_nnz: nnz_per_row, - col_ind: col_indices, - } - } - - /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted column - /// indices per row, after asserting its other invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_unsorted_checked( - nrows: usize, - ncols: usize, - row_ptrs: Vec, - nnz_per_row: Option>, - col_indices: Vec, - ) -> Self { - SymbolicSparseRowMatRef::new_unsorted_checked( - nrows, - ncols, - &row_ptrs, - nnz_per_row.as_deref(), - &col_indices, - ); - - Self { - nrows, - ncols, - row_ptr: row_ptrs, - row_nnz: nnz_per_row, - col_ind: col_indices, - } - } - - /// Creates a new symbolic matrix view without asserting its invariants. - /// - /// # Safety - /// - /// See type level documentation. - #[inline(always)] - #[track_caller] - pub unsafe fn new_unchecked( - nrows: usize, - ncols: usize, - row_ptrs: Vec, - nnz_per_row: Option>, - col_indices: Vec, - ) -> Self { - SymbolicSparseRowMatRef::new_unchecked( - nrows, - ncols, - &row_ptrs, - nnz_per_row.as_deref(), - &col_indices, - ); - - Self { - nrows, - ncols, - row_ptr: row_ptrs, - row_nnz: nnz_per_row, - col_ind: col_indices, - } - } - - /// Returns the components of the matrix in the order: - /// - row count, - /// - column count, - /// - row pointers, - /// - nonzeros per row, - /// - column indices. - #[inline] - pub fn into_parts(self) -> (usize, usize, Vec, Option>, Vec) { - ( - self.nrows, - self.ncols, - self.row_ptr, - self.row_nnz, - self.col_ind, - ) - } - - /// Returns a view over the symbolic structure of `self`. - #[inline] - pub fn as_ref(&self) -> SymbolicSparseRowMatRef<'_, I> { - SymbolicSparseRowMatRef { - nrows: self.nrows, - ncols: self.ncols, - row_ptr: &self.row_ptr, - row_nnz: self.row_nnz.as_deref(), - col_ind: &self.col_ind, - } - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> usize { - self.nrows - } - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> usize { - self.ncols - } - - /// Consumes the matrix, and returns its transpose in column-major format without reallocating. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn into_transpose(self) -> SymbolicSparseColMat { - SymbolicSparseColMat { - nrows: self.ncols, - ncols: self.nrows, - col_ptr: self.row_ptr, - col_nnz: self.row_nnz, - row_ind: self.col_ind, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> { - self.as_ref().to_owned() - } - - /// Copies the current matrix into a newly allocated matrix, with column-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_col_major(&self) -> Result, FaerError> { - self.as_ref().to_col_major() - } - - /// Returns the number of symbolic non-zeros in the matrix. - /// - /// The value is guaranteed to be less than `I::Signed::MAX`. - /// - /// # Note - /// Allows unsorted matrices, but the output is a count of all the entries, including the - /// duplicate ones. - #[inline] - pub fn compute_nnz(&self) -> usize { - self.as_ref().compute_nnz() - } - - /// Returns the column pointers. - #[inline] - pub fn row_ptrs(&self) -> &[I] { - &self.row_ptr - } - - /// Returns the count of non-zeros per row of the matrix. - #[inline] - pub fn nnz_per_row(&self) -> Option<&[I]> { - self.row_nnz.as_deref() - } - - /// Returns the column indices. - #[inline] - pub fn col_indices(&self) -> &[I] { - &self.col_ind - } - - /// Returns the column indices of row `i`. - /// - /// # Panics - /// - /// Panics if `i >= self.nrows()`. - #[inline] - #[track_caller] - pub fn col_indices_of_row_raw(&self, i: usize) -> &[I] { - &self.col_ind[self.row_range(i)] - } - - /// Returns the column indices of row `i`. - /// - /// # Panics - /// - /// Panics if `i >= self.ncols()`. - #[inline] - #[track_caller] - pub fn col_indices_of_row( - &self, - i: usize, - ) -> impl '_ + ExactSizeIterator + DoubleEndedIterator { - self.as_ref().col_indices_of_row(i) - } - - /// Returns the range that the row `i` occupies in `self.col_indices()`. - /// - /// # Panics - /// - /// Panics if `i >= self.nrows()`. - #[inline] - #[track_caller] - pub fn row_range(&self, i: usize) -> Range { - self.as_ref().row_range(i) - } - - /// Returns the range that the row `i` occupies in `self.col_indices()`. - /// - /// # Safety - /// - /// The behavior is undefined if `i >= self.nrows()`. - #[inline] - #[track_caller] - pub unsafe fn row_range_unchecked(&self, i: usize) -> Range { - self.as_ref().row_range_unchecked(i) - } -} - -impl SymbolicSparseColMat { - /// Creates a new symbolic matrix view after asserting its invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_checked( - nrows: usize, - ncols: usize, - col_ptrs: Vec, - nnz_per_col: Option>, - row_indices: Vec, - ) -> Self { - SymbolicSparseColMatRef::new_checked( - nrows, - ncols, - &col_ptrs, - nnz_per_col.as_deref(), - &row_indices, - ); - - Self { - nrows, - ncols, - col_ptr: col_ptrs, - col_nnz: nnz_per_col, - row_ind: row_indices, - } - } - - /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted row - /// indices per column, after asserting its other invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_unsorted_checked( - nrows: usize, - ncols: usize, - col_ptrs: Vec, - nnz_per_col: Option>, - row_indices: Vec, - ) -> Self { - SymbolicSparseColMatRef::new_unsorted_checked( - nrows, - ncols, - &col_ptrs, - nnz_per_col.as_deref(), - &row_indices, - ); - - Self { - nrows, - ncols, - col_ptr: col_ptrs, - col_nnz: nnz_per_col, - row_ind: row_indices, - } - } - - /// Creates a new symbolic matrix view without asserting its invariants. - /// - /// # Safety - /// - /// See type level documentation. - #[inline(always)] - #[track_caller] - pub unsafe fn new_unchecked( - nrows: usize, - ncols: usize, - col_ptrs: Vec, - nnz_per_col: Option>, - row_indices: Vec, - ) -> Self { - SymbolicSparseRowMatRef::new_unchecked( - nrows, - ncols, - &col_ptrs, - nnz_per_col.as_deref(), - &row_indices, - ); - - Self { - nrows, - ncols, - col_ptr: col_ptrs, - col_nnz: nnz_per_col, - row_ind: row_indices, - } - } - - /// Returns the components of the matrix in the order: - /// - row count, - /// - column count, - /// - column pointers, - /// - nonzeros per column, - /// - row indices. - #[inline] - pub fn into_parts(self) -> (usize, usize, Vec, Option>, Vec) { - ( - self.nrows, - self.ncols, - self.col_ptr, - self.col_nnz, - self.row_ind, - ) - } - - /// Returns a view over the symbolic structure of `self`. - #[inline] - pub fn as_ref(&self) -> SymbolicSparseColMatRef<'_, I> { - SymbolicSparseColMatRef { - nrows: self.nrows, - ncols: self.ncols, - col_ptr: &self.col_ptr, - col_nnz: self.col_nnz.as_deref(), - row_ind: &self.row_ind, - } - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> usize { - self.nrows - } - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> usize { - self.ncols - } - - /// Consumes the matrix, and returns its transpose in row-major format without reallocating. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn into_transpose(self) -> SymbolicSparseRowMat { - SymbolicSparseRowMat { - nrows: self.ncols, - ncols: self.nrows, - row_ptr: self.col_ptr, - row_nnz: self.col_nnz, - col_ind: self.row_ind, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> { - self.as_ref().to_owned() - } - - /// Copies the current matrix into a newly allocated matrix, with row-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_row_major(&self) -> Result, FaerError> { - self.as_ref().to_row_major() - } - - /// Returns the number of symbolic non-zeros in the matrix. - /// - /// The value is guaranteed to be less than `I::Signed::MAX`. - /// - /// # Note - /// Allows unsorted matrices, but the output is a count of all the entries, including the - /// duplicate ones. - #[inline] - pub fn compute_nnz(&self) -> usize { - self.as_ref().compute_nnz() - } - - /// Returns the column pointers. - #[inline] - pub fn col_ptrs(&self) -> &[I] { - &self.col_ptr - } - - /// Returns the count of non-zeros per column of the matrix. - #[inline] - pub fn nnz_per_col(&self) -> Option<&[I]> { - self.col_nnz.as_deref() - } - - /// Returns the row indices. - #[inline] - pub fn row_indices(&self) -> &[I] { - &self.row_ind - } - - /// Returns the row indices of column `j`. - /// - /// # Panics - /// - /// Panics if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub fn row_indices_of_col_raw(&self, j: usize) -> &[I] { - &self.row_ind[self.col_range(j)] - } - - /// Returns the row indices of column `j`. - /// - /// # Panics - /// - /// Panics if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub fn row_indices_of_col( - &self, - j: usize, - ) -> impl '_ + ExactSizeIterator + DoubleEndedIterator { - self.as_ref().row_indices_of_col(j) - } - - /// Returns the range that the column `j` occupies in `self.row_indices()`. - /// - /// # Panics - /// - /// Panics if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub fn col_range(&self, j: usize) -> Range { - self.as_ref().col_range(j) - } - - /// Returns the range that the column `j` occupies in `self.row_indices()`. - /// - /// # Safety - /// - /// The behavior is undefined if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub unsafe fn col_range_unchecked(&self, j: usize) -> Range { - self.as_ref().col_range_unchecked(j) - } -} - -impl<'a, I: Index> SymbolicSparseRowMatRef<'a, I> { - /// Creates a new symbolic matrix view after asserting its invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_checked( - nrows: usize, - ncols: usize, - row_ptrs: &'a [I], - nnz_per_row: Option<&'a [I]>, - col_indices: &'a [I], - ) -> Self { - assert!(all( - ncols <= I::Signed::MAX.zx(), - nrows <= I::Signed::MAX.zx(), - )); - assert!(row_ptrs.len() == nrows + 1); - for &[c, c_next] in windows2(row_ptrs) { - assert!(c <= c_next); - } - assert!(row_ptrs[ncols].zx() <= col_indices.len()); - - if let Some(nnz_per_row) = nnz_per_row { - for (&nnz_i, &[c, c_next]) in zip(nnz_per_row, windows2(row_ptrs)) { - assert!(nnz_i <= c_next - c); - let col_indices = &col_indices[c.zx()..c.zx() + nnz_i.zx()]; - if !col_indices.is_empty() { - let mut j_prev = col_indices[0]; - for &j in &col_indices[1..] { - assert!(j_prev < j); - j_prev = j; - } - let ncols = I::truncate(ncols); - assert!(j_prev < ncols); - } - } - } else { - for &[c, c_next] in windows2(row_ptrs) { - let col_indices = &col_indices[c.zx()..c_next.zx()]; - if !col_indices.is_empty() { - let mut j_prev = col_indices[0]; - for &j in &col_indices[1..] { - assert!(j_prev < j); - j_prev = j; - } - let ncols = I::truncate(ncols); - assert!(j_prev < ncols); - } - } - } - - Self { - nrows, - ncols, - row_ptr: row_ptrs, - row_nnz: nnz_per_row, - col_ind: col_indices, - } - } - - /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted column - /// indices per row, after asserting its other invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_unsorted_checked( - nrows: usize, - ncols: usize, - row_ptrs: &'a [I], - nnz_per_row: Option<&'a [I]>, - col_indices: &'a [I], - ) -> Self { - assert!(all( - ncols <= I::Signed::MAX.zx(), - nrows <= I::Signed::MAX.zx(), - )); - assert!(row_ptrs.len() == nrows + 1); - for &[c, c_next] in windows2(row_ptrs) { - assert!(c <= c_next); - } - assert!(row_ptrs[ncols].zx() <= col_indices.len()); - - if let Some(nnz_per_row) = nnz_per_row { - for (&nnz_i, &[c, c_next]) in zip(nnz_per_row, windows2(row_ptrs)) { - assert!(nnz_i <= c_next - c); - for &j in &col_indices[c.zx()..c.zx() + nnz_i.zx()] { - assert!(j < I::truncate(ncols)); - } - } - } else { - let c0 = row_ptrs[0].zx(); - let cn = row_ptrs[ncols].zx(); - for &j in &col_indices[c0..cn] { - assert!(j < I::truncate(ncols)); - } - } - - Self { - nrows, - ncols, - row_ptr: row_ptrs, - row_nnz: nnz_per_row, - col_ind: col_indices, - } - } - - /// Creates a new symbolic matrix view without asserting its invariants. - /// - /// # Safety - /// - /// See type level documentation. - #[inline(always)] - #[track_caller] - pub unsafe fn new_unchecked( - nrows: usize, - ncols: usize, - row_ptrs: &'a [I], - nnz_per_row: Option<&'a [I]>, - col_indices: &'a [I], - ) -> Self { - assert!(all( - ncols <= ::MAX.zx(), - nrows <= ::MAX.zx(), - )); - assert!(row_ptrs.len() == nrows + 1); - assert!(row_ptrs[nrows].zx() <= col_indices.len()); - - Self { - nrows, - ncols, - row_ptr: row_ptrs, - row_nnz: nnz_per_row, - col_ind: col_indices, - } - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> usize { - self.nrows - } - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> usize { - self.ncols - } - - /// Returns a view over the transpose of `self` in column-major format. - #[inline] - pub fn transpose(self) -> SymbolicSparseColMatRef<'a, I> { - SymbolicSparseColMatRef { - nrows: self.ncols, - ncols: self.nrows, - col_ptr: self.row_ptr, - col_nnz: self.row_nnz, - row_ind: self.col_ind, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> { - self.transpose() - .to_owned() - .map(SymbolicSparseColMat::into_transpose) - } - - /// Copies the current matrix into a newly allocated matrix, with column-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_col_major(&self) -> Result, FaerError> { - self.transpose().to_row_major().map(|m| m.into_transpose()) - } - - /// Returns the number of symbolic non-zeros in the matrix. - /// - /// The value is guaranteed to be less than `I::Signed::MAX`. - /// - /// # Note - /// Allows unsorted matrices, but the output is a count of all the entries, including the - /// duplicate ones. - #[inline] - pub fn compute_nnz(&self) -> usize { - self.transpose().compute_nnz() - } - - /// Returns the column pointers. - #[inline] - pub fn row_ptrs(&self) -> &'a [I] { - self.row_ptr - } - - /// Returns the count of non-zeros per column of the matrix. - #[inline] - pub fn nnz_per_row(&self) -> Option<&'a [I]> { - self.row_nnz - } - - /// Returns the column indices. - #[inline] - pub fn col_indices(&self) -> &'a [I] { - self.col_ind - } - - /// Returns the column indices of row i. - /// - /// # Panics - /// - /// Panics if `i >= self.nrows()`. - #[inline] - #[track_caller] - pub fn col_indices_of_row_raw(&self, i: usize) -> &'a [I] { - &self.col_ind[self.row_range(i)] - } - - /// Returns the column indices of row i. - /// - /// # Panics - /// - /// Panics if `i >= self.ncols()`. - #[inline] - #[track_caller] - pub fn col_indices_of_row( - &self, - i: usize, - ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { - self.col_indices_of_row_raw(i).iter().map( - #[inline(always)] - |&i| i.zx(), - ) - } - - /// Returns the range that the row `i` occupies in `self.col_indices()`. - /// - /// # Panics - /// - /// Panics if `i >= self.nrows()`. - #[inline] - #[track_caller] - pub fn row_range(&self, i: usize) -> Range { - let start = self.row_ptr[i].zx(); - let end = self - .row_nnz - .map(|row_nnz| row_nnz[i].zx() + start) - .unwrap_or(self.row_ptr[i + 1].zx()); - - start..end - } - - /// Returns the range that the row `i` occupies in `self.col_indices()`. - /// - /// # Safety - /// - /// The behavior is undefined if `i >= self.nrows()`. - #[inline] - #[track_caller] - pub unsafe fn row_range_unchecked(&self, i: usize) -> Range { - let start = __get_unchecked(self.row_ptr, i).zx(); - let end = self - .row_nnz - .map(|row_nnz| (__get_unchecked(row_nnz, i).zx() + start)) - .unwrap_or(__get_unchecked(self.row_ptr, i + 1).zx()); - - start..end - } -} - -impl<'a, I: Index> SymbolicSparseColMatRef<'a, I> { - /// Creates a new symbolic matrix view after asserting its invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_checked( - nrows: usize, - ncols: usize, - col_ptrs: &'a [I], - nnz_per_col: Option<&'a [I]>, - row_indices: &'a [I], - ) -> Self { - assert!(all( - ncols <= I::Signed::MAX.zx(), - nrows <= I::Signed::MAX.zx(), - )); - assert!(col_ptrs.len() == ncols + 1); - for &[c, c_next] in windows2(col_ptrs) { - assert!(c <= c_next); - } - assert!(col_ptrs[ncols].zx() <= row_indices.len()); - - if let Some(nnz_per_col) = nnz_per_col { - for (&nnz_j, &[c, c_next]) in zip(nnz_per_col, windows2(col_ptrs)) { - assert!(nnz_j <= c_next - c); - let row_indices = &row_indices[c.zx()..c.zx() + nnz_j.zx()]; - if !row_indices.is_empty() { - let mut i_prev = row_indices[0]; - for &i in &row_indices[1..] { - assert!(i_prev < i); - i_prev = i; - } - let nrows = I::truncate(nrows); - assert!(i_prev < nrows); - } - } - } else { - for &[c, c_next] in windows2(col_ptrs) { - let row_indices = &row_indices[c.zx()..c_next.zx()]; - if !row_indices.is_empty() { - let mut i_prev = row_indices[0]; - for &i in &row_indices[1..] { - assert!(i_prev < i); - i_prev = i; - } - let nrows = I::truncate(nrows); - assert!(i_prev < nrows); - } - } - } - - Self { - nrows, - ncols, - col_ptr: col_ptrs, - col_nnz: nnz_per_col, - row_ind: row_indices, - } - } - - /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted row - /// indices per column, after asserting its other invariants. - /// - /// # Panics - /// - /// See type level documentation. - #[inline] - #[track_caller] - pub fn new_unsorted_checked( - nrows: usize, - ncols: usize, - col_ptrs: &'a [I], - nnz_per_col: Option<&'a [I]>, - row_indices: &'a [I], - ) -> Self { - assert!(all( - ncols <= I::Signed::MAX.zx(), - nrows <= I::Signed::MAX.zx(), - )); - assert!(col_ptrs.len() == ncols + 1); - for &[c, c_next] in windows2(col_ptrs) { - assert!(c <= c_next); - } - assert!(col_ptrs[ncols].zx() <= row_indices.len()); - - if let Some(nnz_per_col) = nnz_per_col { - for (&nnz_j, &[c, c_next]) in zip(nnz_per_col, windows2(col_ptrs)) { - assert!(nnz_j <= c_next - c); - for &i in &row_indices[c.zx()..c.zx() + nnz_j.zx()] { - assert!(i < I::truncate(nrows)); - } - } - } else { - let c0 = col_ptrs[0].zx(); - let cn = col_ptrs[ncols].zx(); - for &i in &row_indices[c0..cn] { - assert!(i < I::truncate(nrows)); - } - } - - Self { - nrows, - ncols, - col_ptr: col_ptrs, - col_nnz: nnz_per_col, - row_ind: row_indices, - } - } - - /// Creates a new symbolic matrix view without asserting its invariants. - /// - /// # Safety - /// - /// See type level documentation. - #[inline(always)] - #[track_caller] - pub unsafe fn new_unchecked( - nrows: usize, - ncols: usize, - col_ptrs: &'a [I], - nnz_per_col: Option<&'a [I]>, - row_indices: &'a [I], - ) -> Self { - assert!(all( - ncols <= ::MAX.zx(), - nrows <= ::MAX.zx(), - )); - assert!(col_ptrs.len() == ncols + 1); - assert!(col_ptrs[ncols].zx() <= row_indices.len()); - - Self { - nrows, - ncols, - col_ptr: col_ptrs, - col_nnz: nnz_per_col, - row_ind: row_indices, - } - } - - /// Returns the number of rows of the matrix. - #[inline] - pub fn nrows(&self) -> usize { - self.nrows - } - /// Returns the number of columns of the matrix. - #[inline] - pub fn ncols(&self) -> usize { - self.ncols - } - - /// Returns a view over the transpose of `self` in row-major format. - #[inline] - pub fn transpose(self) -> SymbolicSparseRowMatRef<'a, I> { - SymbolicSparseRowMatRef { - nrows: self.ncols, - ncols: self.nrows, - row_ptr: self.col_ptr, - row_nnz: self.col_nnz, - col_ind: self.row_ind, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> { - Ok(SymbolicSparseColMat { - nrows: self.nrows, - ncols: self.ncols, - col_ptr: try_collect(self.col_ptr.iter().copied())?, - col_nnz: self - .col_nnz - .map(|nnz| try_collect(nnz.iter().copied())) - .transpose()?, - row_ind: try_collect(self.row_ind.iter().copied())?, - }) - } - - /// Copies the current matrix into a newly allocated matrix, with row-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_row_major(&self) -> Result, FaerError> { - let mut col_ptr = try_zeroed::(self.nrows + 1)?; - let mut row_ind = try_zeroed::(self.compute_nnz())?; - - let mut mem = GlobalPodBuffer::try_new(StackReq::new::(self.nrows)) - .map_err(|_| FaerError::OutOfMemory)?; - - util::adjoint_symbolic(&mut col_ptr, &mut row_ind, *self, PodStack::new(&mut mem)); - - let transpose = unsafe { - SymbolicSparseColMat::new_unchecked(self.ncols, self.nrows, col_ptr, None, row_ind) - }; - - Ok(transpose.into_transpose()) - } - - /// Returns the number of symbolic non-zeros in the matrix. - /// - /// The value is guaranteed to be less than `I::Signed::MAX`. - /// - /// # Note - /// Allows unsorted matrices, but the output is a count of all the entries, including the - /// duplicate ones. - #[inline] - pub fn compute_nnz(&self) -> usize { - match self.col_nnz { - Some(col_nnz) => { - let mut nnz = 0usize; - for &nnz_j in col_nnz { - // can't overflow - nnz += nnz_j.zx(); - } - nnz - } - None => self.col_ptr[self.ncols].zx() - self.col_ptr[0].zx(), - } - } - - /// Returns the column pointers. - #[inline] - pub fn col_ptrs(&self) -> &'a [I] { - self.col_ptr - } - - /// Returns the count of non-zeros per column of the matrix. - #[inline] - pub fn nnz_per_col(&self) -> Option<&'a [I]> { - self.col_nnz - } - - /// Returns the row indices. - #[inline] - pub fn row_indices(&self) -> &'a [I] { - self.row_ind - } - - /// Returns the row indices of column `j`. - /// - /// # Panics - /// - /// Panics if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub fn row_indices_of_col_raw(&self, j: usize) -> &'a [I] { - &self.row_ind[self.col_range(j)] - } - - /// Returns the row indices of column `j`. - /// - /// # Panics - /// - /// Panics if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub fn row_indices_of_col( - &self, - j: usize, - ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { - self.row_indices_of_col_raw(j).iter().map( - #[inline(always)] - |&i| i.zx(), - ) - } - - /// Returns the range that the column `j` occupies in `self.row_indices()`. - /// - /// # Panics - /// - /// Panics if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub fn col_range(&self, j: usize) -> Range { - let start = self.col_ptr[j].zx(); - let end = self - .col_nnz - .map(|col_nnz| col_nnz[j].zx() + start) - .unwrap_or(self.col_ptr[j + 1].zx()); - - start..end - } - - /// Returns the range that the column `j` occupies in `self.row_indices()`. - /// - /// # Safety - /// - /// The behavior is undefined if `j >= self.ncols()`. - #[inline] - #[track_caller] - pub unsafe fn col_range_unchecked(&self, j: usize) -> Range { - let start = __get_unchecked(self.col_ptr, j).zx(); - let end = self - .col_nnz - .map(|col_nnz| (__get_unchecked(col_nnz, j).zx() + start)) - .unwrap_or(__get_unchecked(self.col_ptr, j + 1).zx()); - - start..end - } -} - -impl core::fmt::Debug for SymbolicSparseColMatRef<'_, I> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mat = *self; - let mut iter = (0..mat.ncols()).into_iter().flat_map(move |j| { - struct Wrapper(usize, usize); - impl core::fmt::Debug for Wrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let row = self.0; - let col = self.1; - write!(f, "({row}, {col}") - } - } - - mat.row_indices_of_col(j).map(move |i| Wrapper(i, j)) - }); - - f.debug_list().entries(&mut iter).finish() - } -} - -impl core::fmt::Debug for SymbolicSparseRowMatRef<'_, I> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mat = *self; - let mut iter = (0..mat.nrows()).into_iter().flat_map(move |i| { - struct Wrapper(usize, usize); - impl core::fmt::Debug for Wrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let row = self.0; - let col = self.1; - write!(f, "({row}, {col}") - } - } - - mat.col_indices_of_row(i).map(move |j| Wrapper(i, j)) - }); - - f.debug_list().entries(&mut iter).finish() - } -} -impl core::fmt::Debug for SparseColMatRef<'_, I, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mat = *self; - let mut iter = (0..mat.ncols()).into_iter().flat_map(move |j| { - struct Wrapper(usize, usize, E); - impl core::fmt::Debug for Wrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let row = self.0; - let col = self.1; - let val = &self.2; - write!(f, "({row}, {col}, {val:?})") - } - } - - mat.row_indices_of_col(j) - .zip(SliceGroup::::new(mat.values_of_col(j)).into_ref_iter()) - .map(move |(i, val)| Wrapper(i, j, val.read())) - }); - - f.debug_list().entries(&mut iter).finish() - } -} - -impl core::fmt::Debug for SparseRowMatRef<'_, I, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - let mat = *self; - let mut iter = (0..mat.nrows()).into_iter().flat_map(move |i| { - struct Wrapper(usize, usize, E); - impl core::fmt::Debug for Wrapper { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let row = self.0; - let col = self.1; - let val = &self.2; - write!(f, "({row}, {col}, {val:?})") - } - } - - mat.col_indices_of_row(i) - .zip(SliceGroup::::new(mat.values_of_row(i)).into_ref_iter()) - .map(move |(j, val)| Wrapper(i, j, val.read())) - }); - - f.debug_list().entries(&mut iter).finish() - } -} - -impl core::fmt::Debug for SparseColMatMut<'_, I, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } -} - -impl core::fmt::Debug for SparseColMat { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_ref().fmt(f) - } -} - -impl core::fmt::Debug for SymbolicSparseColMat { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_ref().fmt(f) - } -} - -impl core::fmt::Debug for SparseRowMatMut<'_, I, E> { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.rb().fmt(f) - } -} - -impl core::fmt::Debug for SparseRowMat { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_ref().fmt(f) - } -} - -impl core::fmt::Debug for SymbolicSparseRowMat { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - self.as_ref().fmt(f) - } -} - -/// Sparse matrix mutable view in row-major format, either compressed or uncompressed. -/// -/// Note that only the values are modifiable through this view. The row pointers and column -/// indices are fixed. -pub type SparseRowMatMut<'a, I, E> = Matrix>; - -/// Sparse matrix mutable view in column-major format, either compressed or uncompressed. -/// -/// Note that only the values are modifiable through this view. The column pointers and row indices -/// are fixed. -pub type SparseColMatMut<'a, I, E> = Matrix>; - -/// Sparse matrix view in row-major format, either compressed or uncompressed. -pub type SparseRowMatRef<'a, I, E> = Matrix>; - -/// Sparse matrix view in column-major format, either compressed or uncompressed. -pub type SparseColMatRef<'a, I, E> = Matrix>; - -/// Sparse matrix in row-major format, either compressed or uncompressed. -pub type SparseRowMat = Matrix>; - -/// Sparse matrix in column-major format, either compressed or uncompressed. -pub type SparseColMat = Matrix>; - -impl<'a, I: Index, E: Entity> SparseRowMatMut<'a, I, E> { - /// Creates a new sparse matrix view. - /// - /// # Panics - /// - /// Panics if the length of `values` is not equal to the length of - /// `symbolic.col_indices()`. - #[inline] - #[track_caller] - pub fn new( - symbolic: SymbolicSparseRowMatRef<'a, I>, - values: GroupFor, - ) -> Self { - let values = SliceGroupMut::new(values); - assert!(symbolic.col_indices().len() == values.len()); - Self { - inner: inner::SparseRowMatMut { symbolic, values }, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.rb().to_owned() - } - - /// Copies the current matrix into a newly allocated matrix, with column-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_col_major(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.rb().to_col_major() - } - - /// Returns a view over the transpose of `self` in column-major format. - #[inline] - pub fn transpose_mut(self) -> SparseColMatMut<'a, I, E> { - SparseColMatMut { - inner: inner::SparseColMatMut { - symbolic: SymbolicSparseColMatRef { - nrows: self.inner.symbolic.ncols, - ncols: self.inner.symbolic.nrows, - col_ptr: self.inner.symbolic.row_ptr, - col_nnz: self.inner.symbolic.row_nnz, - row_ind: self.inner.symbolic.col_ind, - }, - values: self.inner.values, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn canonicalize_mut(self) -> (SparseRowMatMut<'a, I, E::Canonical>, Conj) - where - E: Conjugate, - { - ( - SparseRowMatMut { - inner: inner::SparseRowMatMut { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroupMut::<'a, E::Canonical>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >( - E::faer_map(self.inner.values.into_inner(), |slice| { - let len = slice.len(); - core::slice::from_raw_parts_mut( - slice.as_mut_ptr() as *mut UnitFor - as *mut UnitFor, - len, - ) - }), - )) - }, - }, - }, - if coe::is_same::() { - Conj::No - } else { - Conj::Yes - }, - ) - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn conjugate_mut(self) -> SparseRowMatMut<'a, I, E::Conj> - where - E: Conjugate, - { - SparseRowMatMut { - inner: inner::SparseRowMatMut { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroupMut::<'a, E::Conj>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >(E::faer_map( - self.inner.values.into_inner(), - |slice| { - let len = slice.len(); - core::slice::from_raw_parts_mut( - slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor, - len, - ) - }, - ))) - }, - }, - } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn adjoint_mut(self) -> SparseColMatMut<'a, I, E::Conj> - where - E: Conjugate, - { - self.transpose_mut().conjugate_mut() - } - - /// Returns the numerical values of the matrix. - #[inline] - pub fn values_mut(self) -> GroupFor { - self.inner.values.into_inner() - } - - /// Returns the numerical values of row `i` of the matrix. - /// - /// # Panics: - /// - /// Panics if `i >= nrows`. - #[inline] - #[track_caller] - pub fn values_of_row_mut(self, i: usize) -> GroupFor { - let range = self.symbolic().row_range(i); - self.inner.values.subslice(range).into_inner() - } - - /// Returns the symbolic structure of the matrix. - #[inline] - pub fn symbolic(&self) -> SymbolicSparseRowMatRef<'a, I> { - self.inner.symbolic - } - - /// Decomposes the matrix into the symbolic part and the numerical values. - #[inline] - pub fn into_parts( - self, - ) -> ( - SymbolicSparseRowMatRef<'a, I>, - GroupFor, - ) { - (self.inner.symbolic, self.inner.values.into_inner()) - } -} - -impl<'a, I: Index, E: Entity> SparseColMatMut<'a, I, E> { - /// Creates a new sparse matrix view. - /// - /// # Panics - /// - /// Panics if the length of `values` is not equal to the length of - /// `symbolic.row_indices()`. - #[inline] - #[track_caller] - pub fn new( - symbolic: SymbolicSparseColMatRef<'a, I>, - values: GroupFor, - ) -> Self { - let values = SliceGroupMut::new(values); - assert!(symbolic.row_indices().len() == values.len()); - Self { - inner: inner::SparseColMatMut { symbolic, values }, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.rb().to_owned() - } - - /// Copies the current matrix into a newly allocated matrix, with row-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_row_major(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.rb().to_row_major() - } - - /// Returns a view over the transpose of `self` in row-major format. - #[inline] - pub fn transpose_mut(self) -> SparseRowMatMut<'a, I, E> { - SparseRowMatMut { - inner: inner::SparseRowMatMut { - symbolic: SymbolicSparseRowMatRef { - nrows: self.inner.symbolic.ncols, - ncols: self.inner.symbolic.nrows, - row_ptr: self.inner.symbolic.col_ptr, - row_nnz: self.inner.symbolic.col_nnz, - col_ind: self.inner.symbolic.row_ind, - }, - values: self.inner.values, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn conjugate_mut(self) -> SparseColMatMut<'a, I, E::Conj> - where - E: Conjugate, - { - SparseColMatMut { - inner: inner::SparseColMatMut { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroupMut::<'a, E::Conj>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >(E::faer_map( - self.inner.values.into_inner(), - |slice| { - let len = slice.len(); - core::slice::from_raw_parts_mut( - slice.as_ptr() as *mut UnitFor as *mut UnitFor, - len, - ) - }, - ))) - }, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn canonicalize_mut(self) -> (SparseColMatMut<'a, I, E::Canonical>, Conj) - where - E: Conjugate, - { - ( - SparseColMatMut { - inner: inner::SparseColMatMut { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroupMut::<'a, E::Canonical>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >( - E::faer_map(self.inner.values.into_inner(), |slice| { - let len = slice.len(); - core::slice::from_raw_parts_mut( - slice.as_mut_ptr() as *mut UnitFor - as *mut UnitFor, - len, - ) - }), - )) - }, - }, - }, - if coe::is_same::() { - Conj::No - } else { - Conj::Yes - }, - ) - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn adjoint_mut(self) -> SparseRowMatMut<'a, I, E::Conj> - where - E: Conjugate, - { - self.transpose_mut().conjugate_mut() - } - - /// Returns the numerical values of the matrix. - #[inline] - pub fn values_mut(self) -> GroupFor { - self.inner.values.into_inner() - } - - /// Returns the numerical values of column `j` of the matrix. - /// - /// # Panics: - /// - /// Panics if `j >= ncols`. - #[inline] - #[track_caller] - pub fn values_of_col_mut(self, j: usize) -> GroupFor { - let range = self.symbolic().col_range(j); - self.inner.values.subslice(range).into_inner() - } - - /// Returns the symbolic structure of the matrix. - #[inline] - pub fn symbolic(&self) -> SymbolicSparseColMatRef<'a, I> { - self.inner.symbolic - } - - /// Decomposes the matrix into the symbolic part and the numerical values. - #[inline] - pub fn into_parts_mut( - self, - ) -> ( - SymbolicSparseColMatRef<'a, I>, - GroupFor, - ) { - (self.inner.symbolic, self.inner.values.into_inner()) - } -} - -impl<'a, I: Index, E: Entity> SparseRowMatRef<'a, I, E> { - /// Creates a new sparse matrix view. - /// - /// # Panics - /// - /// Panics if the length of `values` is not equal to the length of - /// `symbolic.col_indices()`. - #[inline] - #[track_caller] - pub fn new( - symbolic: SymbolicSparseRowMatRef<'a, I>, - values: GroupFor, - ) -> Self { - let values = SliceGroup::new(values); - assert!(symbolic.col_indices().len() == values.len()); - Self { - inner: inner::SparseRowMatRef { symbolic, values }, - } - } - - /// Returns the numerical values of the matrix. - #[inline] - pub fn values(self) -> GroupFor { - self.inner.values.into_inner() - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.transpose() - .to_owned() - .map(SparseColMat::into_transpose) - } - - /// Copies the current matrix into a newly allocated matrix, with column-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_col_major(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.transpose() - .to_row_major() - .map(SparseRowMat::into_transpose) - } - - /// Returns a view over the transpose of `self` in column-major format. - #[inline] - pub fn transpose(self) -> SparseColMatRef<'a, I, E> { - SparseColMatRef { - inner: inner::SparseColMatRef { - symbolic: SymbolicSparseColMatRef { - nrows: self.inner.symbolic.ncols, - ncols: self.inner.symbolic.nrows, - col_ptr: self.inner.symbolic.row_ptr, - col_nnz: self.inner.symbolic.row_nnz, - row_ind: self.inner.symbolic.col_ind, - }, - values: self.inner.values, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn conjugate(self) -> SparseRowMatRef<'a, I, E::Conj> - where - E: Conjugate, - { - SparseRowMatRef { - inner: inner::SparseRowMatRef { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroup::<'a, E::Conj>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >(E::faer_map( - self.inner.values.into_inner(), - |slice| { - let len = slice.len(); - core::slice::from_raw_parts( - slice.as_ptr() as *const UnitFor as *const UnitFor, - len, - ) - }, - ))) - }, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn canonicalize(self) -> (SparseRowMatRef<'a, I, E::Canonical>, Conj) - where - E: Conjugate, - { - ( - SparseRowMatRef { - inner: inner::SparseRowMatRef { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroup::<'a, E::Canonical>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >(E::faer_map( - self.inner.values.into_inner(), - |slice| { - let len = slice.len(); - core::slice::from_raw_parts( - slice.as_ptr() as *const UnitFor - as *const UnitFor, - len, - ) - }, - ))) - }, - }, - }, - if coe::is_same::() { - Conj::No - } else { - Conj::Yes - }, - ) - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn adjoint(self) -> SparseColMatRef<'a, I, E::Conj> - where - E: Conjugate, - { - self.transpose().conjugate() - } - - /// Returns the numerical values of row `i` of the matrix. - /// - /// # Panics: - /// - /// Panics if `i >= nrows`. - #[inline] - #[track_caller] - pub fn values_of_row(self, i: usize) -> GroupFor { - self.inner.values.subslice(self.row_range(i)).into_inner() - } - - /// Returns the symbolic structure of the matrix. - #[inline] - pub fn symbolic(&self) -> SymbolicSparseRowMatRef<'a, I> { - self.inner.symbolic - } - - /// Decomposes the matrix into the symbolic part and the numerical values. - #[inline] - pub fn into_parts(self) -> (SymbolicSparseRowMatRef<'a, I>, GroupFor) { - (self.inner.symbolic, self.inner.values.into_inner()) - } -} - -impl<'a, I: Index, E: Entity> SparseColMatRef<'a, I, E> { - /// Creates a new sparse matrix view. - /// - /// # Panics - /// - /// Panics if the length of `values` is not equal to the length of - /// `symbolic.row_indices()`. - #[inline] - #[track_caller] - pub fn new( - symbolic: SymbolicSparseColMatRef<'a, I>, - values: GroupFor, - ) -> Self { - let values = SliceGroup::new(values); - assert!(symbolic.row_indices().len() == values.len()); - Self { - inner: inner::SparseColMatRef { symbolic, values }, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - let symbolic = self.symbolic().to_owned()?; - let mut values = VecGroup::::new(); - - values - .try_reserve_exact(self.inner.values.len()) - .map_err(|_| FaerError::OutOfMemory)?; - - values.resize( - self.inner.values.len(), - E::Canonical::faer_zero().faer_into_units(), - ); - - let src = self.inner.values; - let dst = values.as_slice_mut(); - - for (mut dst, src) in core::iter::zip(dst.into_mut_iter(), src.into_ref_iter()) { - dst.write(src.read().canonicalize()); - } - - Ok(SparseColMat { - inner: inner::SparseColMat { symbolic, values }, - }) - } - - /// Copies the current matrix into a newly allocated matrix, with row-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_row_major(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - let mut col_ptr = try_zeroed::(self.nrows + 1)?; - let nnz = self.compute_nnz(); - let mut row_ind = try_zeroed::(nnz)?; - let mut values = VecGroup::::new(); - values - .try_reserve_exact(nnz) - .map_err(|_| FaerError::OutOfMemory)?; - values.resize(nnz, E::Canonical::faer_zero().faer_into_units()); - - let mut mem = GlobalPodBuffer::try_new(StackReq::new::(self.nrows)) - .map_err(|_| FaerError::OutOfMemory)?; - - let (this, conj) = self.canonicalize(); - - if conj == Conj::No { - util::transpose::( - &mut col_ptr, - &mut row_ind, - values.as_slice_mut().into_inner(), - this, - PodStack::new(&mut mem), - ); - } else { - util::adjoint::( - &mut col_ptr, - &mut row_ind, - values.as_slice_mut().into_inner(), - this, - PodStack::new(&mut mem), - ); - } - - let transpose = unsafe { - SparseColMat::new( - SymbolicSparseColMat::new_unchecked(self.ncols, self.nrows, col_ptr, None, row_ind), - values.into_inner(), - ) - }; - - Ok(transpose.into_transpose()) - } - - /// Returns a view over the transpose of `self` in row-major format. - #[inline] - pub fn transpose(self) -> SparseRowMatRef<'a, I, E> { - SparseRowMatRef { - inner: inner::SparseRowMatRef { - symbolic: SymbolicSparseRowMatRef { - nrows: self.inner.symbolic.ncols, - ncols: self.inner.symbolic.nrows, - row_ptr: self.inner.symbolic.col_ptr, - row_nnz: self.inner.symbolic.col_nnz, - col_ind: self.inner.symbolic.row_ind, - }, - values: self.inner.values, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn conjugate(self) -> SparseColMatRef<'a, I, E::Conj> - where - E: Conjugate, - { - SparseColMatRef { - inner: inner::SparseColMatRef { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroup::<'a, E::Conj>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >(E::faer_map( - self.inner.values.into_inner(), - |slice| { - let len = slice.len(); - core::slice::from_raw_parts( - slice.as_ptr() as *const UnitFor as *const UnitFor, - len, - ) - }, - ))) - }, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn canonicalize(self) -> (SparseColMatRef<'a, I, E::Canonical>, Conj) - where - E: Conjugate, - { - ( - SparseColMatRef { - inner: inner::SparseColMatRef { - symbolic: self.inner.symbolic, - values: unsafe { - SliceGroup::<'a, E::Canonical>::new(transmute_unchecked::< - GroupFor]>, - GroupFor]>, - >(E::faer_map( - self.inner.values.into_inner(), - |slice| { - let len = slice.len(); - core::slice::from_raw_parts( - slice.as_ptr() as *const UnitFor - as *const UnitFor, - len, - ) - }, - ))) - }, - }, - }, - if coe::is_same::() { - Conj::No - } else { - Conj::Yes - }, - ) - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn adjoint(self) -> SparseRowMatRef<'a, I, E::Conj> - where - E: Conjugate, - { - self.transpose().conjugate() - } - - /// Returns the numerical values of the matrix. - #[inline] - pub fn values(self) -> GroupFor { - self.inner.values.into_inner() - } - - /// Returns the numerical values of column `j` of the matrix. - /// - /// # Panics: - /// - /// Panics if `j >= ncols`. - #[inline] - #[track_caller] - pub fn values_of_col(self, j: usize) -> GroupFor { - self.inner.values.subslice(self.col_range(j)).into_inner() - } - - /// Returns the symbolic structure of the matrix. - #[inline] - pub fn symbolic(&self) -> SymbolicSparseColMatRef<'a, I> { - self.inner.symbolic - } - - /// Decomposes the matrix into the symbolic part and the numerical values. - #[inline] - pub fn into_parts(self) -> (SymbolicSparseColMatRef<'a, I>, GroupFor) { - (self.inner.symbolic, self.inner.values.into_inner()) - } -} - -impl SparseColMat { - /// Creates a new sparse matrix view. - /// - /// # Panics - /// - /// Panics if the length of `values` is not equal to the length of - /// `symbolic.row_indices()`. - #[inline] - #[track_caller] - pub fn new(symbolic: SymbolicSparseColMat, values: GroupFor>) -> Self { - let values = VecGroup::from_inner(values); - assert!(symbolic.row_indices().len() == values.len()); - Self { - inner: inner::SparseColMat { symbolic, values }, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.as_ref().to_owned() - } - - /// Copies the current matrix into a newly allocated matrix, with row-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_row_major(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.as_ref().to_row_major() - } - - /// Decomposes the matrix into the symbolic part and the numerical values. - #[inline] - pub fn into_parts(self) -> (SymbolicSparseColMat, GroupFor>) { - (self.inner.symbolic, self.inner.values.into_inner()) - } - - /// Returns a view over `self`. - #[inline] - pub fn as_ref(&self) -> SparseColMatRef<'_, I, E> { - SparseColMatRef { - inner: inner::SparseColMatRef { - symbolic: self.inner.symbolic.as_ref(), - values: self.inner.values.as_slice(), - }, - } - } - - /// Returns a mutable view over `self`. - /// - /// Note that the symbolic structure cannot be changed through this view. - #[inline] - pub fn as_mut(&mut self) -> SparseColMatMut<'_, I, E> { - SparseColMatMut { - inner: inner::SparseColMatMut { - symbolic: self.inner.symbolic.as_ref(), - values: self.inner.values.as_slice_mut(), - }, - } - } - - /// Returns a slice over the numerical values of the matrix. - #[inline] - pub fn values(&self) -> GroupFor { - self.inner.values.as_slice().into_inner() - } - - /// Returns a mutable slice over the numerical values of the matrix. - #[inline] - pub fn values_mut(&mut self) -> GroupFor { - self.inner.values.as_slice_mut().into_inner() - } - - /// Returns a view over the transpose of `self` in row-major format. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn into_transpose(self) -> SparseRowMat { - SparseRowMat { - inner: inner::SparseRowMat { - symbolic: SymbolicSparseRowMat { - nrows: self.inner.symbolic.ncols, - ncols: self.inner.symbolic.nrows, - row_ptr: self.inner.symbolic.col_ptr, - row_nnz: self.inner.symbolic.col_nnz, - col_ind: self.inner.symbolic.row_ind, - }, - values: self.inner.values, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn into_conjugate(self) -> SparseColMat - where - E: Conjugate, - { - SparseColMat { - inner: inner::SparseColMat { - symbolic: self.inner.symbolic, - values: unsafe { - VecGroup::::from_inner(transmute_unchecked::< - GroupFor>>, - GroupFor>>, - >(E::faer_map( - self.inner.values.into_inner(), - |mut slice| { - let len = slice.len(); - let cap = slice.capacity(); - let ptr = - slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor; - - Vec::from_raw_parts(ptr, len, cap) - }, - ))) - }, - }, - } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn into_adjoint(self) -> SparseRowMat - where - E: Conjugate, - { - self.into_transpose().into_conjugate() - } -} - -impl SparseRowMat { - /// Creates a new sparse matrix view. - /// - /// # Panics - /// - /// Panics if the length of `values` is not equal to the length of - /// `symbolic.col_indices()`. - #[inline] - #[track_caller] - pub fn new(symbolic: SymbolicSparseRowMat, values: GroupFor>) -> Self { - let values = VecGroup::from_inner(values); - assert!(symbolic.col_indices().len() == values.len()); - Self { - inner: inner::SparseRowMat { symbolic, values }, - } - } - - /// Copies the current matrix into a newly allocated matrix. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn to_owned(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.as_ref().to_owned() - } - - /// Copies the current matrix into a newly allocated matrix, with column-major order. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - #[inline] - pub fn to_col_major(&self) -> Result, FaerError> - where - E: Conjugate, - E::Canonical: ComplexField, - { - self.as_ref().to_col_major() - } - - /// Decomposes the matrix into the symbolic part and the numerical values. - #[inline] - pub fn into_parts(self) -> (SymbolicSparseRowMat, GroupFor>) { - (self.inner.symbolic, self.inner.values.into_inner()) - } - - /// Returns a view over `self`. - #[inline] - pub fn as_ref(&self) -> SparseRowMatRef<'_, I, E> { - SparseRowMatRef { - inner: inner::SparseRowMatRef { - symbolic: self.inner.symbolic.as_ref(), - values: self.inner.values.as_slice(), - }, - } - } - - /// Returns a mutable view over `self`. - /// - /// Note that the symbolic structure cannot be changed through this view. - #[inline] - pub fn as_mut(&mut self) -> SparseRowMatMut<'_, I, E> { - SparseRowMatMut { - inner: inner::SparseRowMatMut { - symbolic: self.inner.symbolic.as_ref(), - values: self.inner.values.as_slice_mut(), - }, - } - } - - /// Returns a slice over the numerical values of the matrix. - #[inline] - pub fn values(&self) -> GroupFor { - self.inner.values.as_slice().into_inner() - } - - /// Returns a mutable slice over the numerical values of the matrix. - #[inline] - pub fn values_mut(&mut self) -> GroupFor { - self.inner.values.as_slice_mut().into_inner() - } - - /// Returns a view over the transpose of `self` in column-major format. - /// - /// # Note - /// Allows unsorted matrices, producing an unsorted output. - #[inline] - pub fn into_transpose(self) -> SparseColMat { - SparseColMat { - inner: inner::SparseColMat { - symbolic: SymbolicSparseColMat { - nrows: self.inner.symbolic.ncols, - ncols: self.inner.symbolic.nrows, - col_ptr: self.inner.symbolic.row_ptr, - col_nnz: self.inner.symbolic.row_nnz, - row_ind: self.inner.symbolic.col_ind, - }, - values: self.inner.values, - }, - } - } - - /// Returns a view over the conjugate of `self`. - #[inline] - pub fn into_conjugate(self) -> SparseRowMat - where - E: Conjugate, - { - SparseRowMat { - inner: inner::SparseRowMat { - symbolic: self.inner.symbolic, - values: unsafe { - VecGroup::::from_inner(transmute_unchecked::< - GroupFor>>, - GroupFor>>, - >(E::faer_map( - self.inner.values.into_inner(), - |mut slice| { - let len = slice.len(); - let cap = slice.capacity(); - let ptr = - slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor; - - Vec::from_raw_parts(ptr, len, cap) - }, - ))) - }, - }, - } - } - - /// Returns a view over the conjugate transpose of `self`. - #[inline] - pub fn into_adjoint(self) -> SparseColMat - where - E: Conjugate, - { - self.into_transpose().into_conjugate() - } -} - -// DEREF/REBORROW -const _: () = { - impl<'a, I: Index, E: Entity> core::ops::Deref for SparseRowMatMut<'a, I, E> { - type Target = SymbolicSparseRowMatRef<'a, I>; - #[inline] - fn deref(&self) -> &Self::Target { - &self.inner.symbolic - } - } - - impl<'a, I: Index, E: Entity> core::ops::Deref for SparseColMatMut<'a, I, E> { - type Target = SymbolicSparseColMatRef<'a, I>; - #[inline] - fn deref(&self) -> &Self::Target { - &self.inner.symbolic - } - } - - impl<'a, I: Index, E: Entity> core::ops::Deref for SparseRowMatRef<'a, I, E> { - type Target = SymbolicSparseRowMatRef<'a, I>; - #[inline] - fn deref(&self) -> &Self::Target { - &self.inner.symbolic - } - } - - impl<'a, I: Index, E: Entity> core::ops::Deref for SparseColMatRef<'a, I, E> { - type Target = SymbolicSparseColMatRef<'a, I>; - #[inline] - fn deref(&self) -> &Self::Target { - &self.inner.symbolic - } - } - - impl core::ops::Deref for SparseRowMat { - type Target = SymbolicSparseRowMat; - #[inline] - fn deref(&self) -> &Self::Target { - &self.inner.symbolic - } - } - - impl core::ops::Deref for SparseColMat { - type Target = SymbolicSparseColMat; - #[inline] - fn deref(&self) -> &Self::Target { - &self.inner.symbolic - } - } - - impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseRowMatRef<'_, I, E> { - type Target = SparseRowMatRef<'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseRowMatRef<'_, I, E> { - type Target = SparseRowMatRef<'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'a, I: Index, E: Entity> IntoConst for SparseRowMatRef<'a, I, E> { - type Target = SparseRowMatRef<'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - self - } - } - - impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseColMatRef<'_, I, E> { - type Target = SparseColMatRef<'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - *self - } - } - - impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseColMatRef<'_, I, E> { - type Target = SparseColMatRef<'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - *self - } - } - - impl<'a, I: Index, E: Entity> IntoConst for SparseColMatRef<'a, I, E> { - type Target = SparseColMatRef<'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - self - } - } - - impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseRowMatMut<'_, I, E> { - type Target = SparseRowMatMut<'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - SparseRowMatMut { - inner: inner::SparseRowMatMut { - symbolic: self.inner.symbolic, - values: self.inner.values.rb_mut(), - }, - } - } - } - - impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseRowMatMut<'_, I, E> { - type Target = SparseRowMatRef<'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - SparseRowMatRef { - inner: inner::SparseRowMatRef { - symbolic: self.inner.symbolic, - values: self.inner.values.rb(), - }, - } - } - } - - impl<'a, I: Index, E: Entity> IntoConst for SparseRowMatMut<'a, I, E> { - type Target = SparseRowMatRef<'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - SparseRowMatRef { - inner: inner::SparseRowMatRef { - symbolic: self.inner.symbolic, - values: self.inner.values.into_const(), - }, - } - } - } - - impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseColMatMut<'_, I, E> { - type Target = SparseColMatMut<'short, I, E>; - - #[inline] - fn rb_mut(&'short mut self) -> Self::Target { - SparseColMatMut { - inner: inner::SparseColMatMut { - symbolic: self.inner.symbolic, - values: self.inner.values.rb_mut(), - }, - } - } - } - - impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseColMatMut<'_, I, E> { - type Target = SparseColMatRef<'short, I, E>; - - #[inline] - fn rb(&'short self) -> Self::Target { - SparseColMatRef { - inner: inner::SparseColMatRef { - symbolic: self.inner.symbolic, - values: self.inner.values.rb(), - }, - } - } - } - - impl<'a, I: Index, E: Entity> IntoConst for SparseColMatMut<'a, I, E> { - type Target = SparseColMatRef<'a, I, E>; - - #[inline] - fn into_const(self) -> Self::Target { - SparseColMatRef { - inner: inner::SparseColMatRef { - symbolic: self.inner.symbolic, - values: self.inner.values.into_const(), - }, - } - } - } -}; - -/// Errors that can occur in sparse algorithms. -#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] -#[non_exhaustive] -pub enum CreationError { - /// Generic error (allocation or index overflow). - Generic(FaerError), - /// Matrix index out-of-bounds error. - OutOfBounds { - /// Row of the out-of-bounds index. - row: usize, - /// Column of the out-of-bounds index. - col: usize, - }, -} - -impl From for CreationError { - #[inline] - fn from(value: FaerError) -> Self { - Self::Generic(value) - } -} -impl core::fmt::Display for CreationError { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - core::fmt::Debug::fmt(self, f) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for CreationError {} - -#[inline] -#[track_caller] -fn try_zeroed(n: usize) -> Result, FaerError> { - let mut v = alloc::vec::Vec::new(); - v.try_reserve_exact(n).map_err(|_| FaerError::OutOfMemory)?; - unsafe { - core::ptr::write_bytes::(v.as_mut_ptr(), 0u8, n); - v.set_len(n); - } - Ok(v) -} - -#[inline] -#[track_caller] -fn try_collect(iter: I) -> Result, FaerError> { - let iter = iter.into_iter(); - let mut v = alloc::vec::Vec::new(); - v.try_reserve_exact(iter.size_hint().0) - .map_err(|_| FaerError::OutOfMemory)?; - v.extend(iter); - Ok(v) -} - -/// The order values should be read in, when constructing/filling from indices and values. -/// -/// Allows separately creating the symbolic structure and filling the numerical values. -#[derive(Debug, Clone)] -pub struct ValuesOrder { - argsort: Vec, - all_nnz: usize, - nnz: usize, - __marker: PhantomData, -} - -/// Whether the filled values should replace the current matrix values or be added to them. -#[derive(Debug, Copy, Clone)] -pub enum FillMode { - /// New filled values should replace the old values. - Replace, - /// New filled values should be added to the old values. - Add, -} - -// FROM TRIPLETS -const _: () = { - const TOP_BIT: usize = 1usize << (usize::BITS - 1); - const TOP_BIT_MASK: usize = TOP_BIT - 1; - - impl SymbolicSparseColMat { - fn try_new_from_indices_impl( - nrows: usize, - ncols: usize, - indices: impl Fn(usize) -> (I, I), - all_nnz: usize, - ) -> Result<(Self, ValuesOrder), CreationError> { - if nrows > I::Signed::MAX.zx() || ncols > I::Signed::MAX.zx() { - return Err(CreationError::Generic(FaerError::IndexOverflow)); - } - - if all_nnz == 0 { - return Ok(( - Self { - nrows, - ncols, - col_ptr: try_zeroed(ncols + 1)?, - col_nnz: None, - row_ind: Vec::new(), - }, - ValuesOrder { - argsort: Vec::new(), - all_nnz, - nnz: 0, - __marker: PhantomData, - }, - )); - } - - let mut argsort = try_collect(0..all_nnz)?; - argsort.sort_unstable_by_key(|&i| { - let (row, col) = indices(i); - (col, row) - }); - - let mut n_duplicates = 0usize; - let mut current_bit = 0usize; - - let mut prev = indices(argsort[0]); - for i in 1..all_nnz { - let idx = indices(argsort[i]); - let same_as_prev = idx == prev; - prev = idx; - current_bit = ((current_bit == ((same_as_prev as usize) << (usize::BITS - 1))) - as usize) - << (usize::BITS - 1); - argsort[i] |= current_bit; - - n_duplicates += same_as_prev as usize; - } - - let nnz = all_nnz - n_duplicates; - if nnz > I::Signed::MAX.zx() { - return Err(CreationError::Generic(FaerError::IndexOverflow)); - } - - let mut col_ptr = try_zeroed::(ncols + 1)?; - let mut row_ind = try_zeroed::(nnz)?; - - let mut original_pos = 0usize; - let mut new_pos = 0usize; - - for j in 0..ncols { - let mut n_unique = 0usize; - - while original_pos < all_nnz { - let (row, col) = indices(argsort[original_pos] & TOP_BIT_MASK); - if row.zx() >= nrows || col.zx() >= ncols { - return Err(CreationError::OutOfBounds { - row: row.zx(), - col: col.zx(), - }); - } - - if col.zx() != j { - break; - } - - row_ind[new_pos] = row; - - n_unique += 1; - - new_pos += 1; - original_pos += 1; - - while original_pos < all_nnz - && indices(argsort[original_pos] & TOP_BIT_MASK) == (row, col) - { - original_pos += 1; - } - } - - col_ptr[j + 1] = col_ptr[j] + I::truncate(n_unique); - } - - Ok(( - Self { - nrows, - ncols, - col_ptr, - col_nnz: None, - row_ind, - }, - ValuesOrder { - argsort, - all_nnz, - nnz, - __marker: PhantomData, - }, - )) - } - - fn try_new_from_nonnegative_indices_impl( - nrows: usize, - ncols: usize, - indices: impl Fn(usize) -> (I::Signed, I::Signed), - all_nnz: usize, - ) -> Result<(Self, ValuesOrder), CreationError> { - if nrows > I::Signed::MAX.zx() || ncols > I::Signed::MAX.zx() { - return Err(CreationError::Generic(FaerError::IndexOverflow)); - } - - let mut argsort = try_collect(0..all_nnz)?; - argsort.sort_unstable_by_key(|&i| { - let (row, col) = indices(i); - let ignore = (row < I::Signed::truncate(0)) | (col < I::Signed::truncate(0)); - (ignore, col, row) - }); - - let all_nnz = argsort.partition_point(|&i| { - let (row, col) = indices(i); - let ignore = (row < I::Signed::truncate(0)) | (col < I::Signed::truncate(0)); - !ignore - }); - - if all_nnz == 0 { - return Ok(( - Self { - nrows, - ncols, - col_ptr: try_zeroed(ncols + 1)?, - col_nnz: None, - row_ind: Vec::new(), - }, - ValuesOrder { - argsort: Vec::new(), - all_nnz, - nnz: 0, - __marker: PhantomData, - }, - )); - } - - let mut n_duplicates = 0usize; - let mut current_bit = 0usize; - - let mut prev = indices(argsort[0]); - - for i in 1..all_nnz { - let idx = indices(argsort[i]); - let same_as_prev = idx == prev; - prev = idx; - current_bit = ((current_bit == ((same_as_prev as usize) << (usize::BITS - 1))) - as usize) - << (usize::BITS - 1); - argsort[i] |= current_bit; - - n_duplicates += same_as_prev as usize; - } - - let nnz = all_nnz - n_duplicates; - if nnz > I::Signed::MAX.zx() { - return Err(CreationError::Generic(FaerError::IndexOverflow)); - } - - let mut col_ptr = try_zeroed::(ncols + 1)?; - let mut row_ind = try_zeroed::(nnz)?; - - let mut original_pos = 0usize; - let mut new_pos = 0usize; - - for j in 0..ncols { - let mut n_unique = 0usize; - - while original_pos < all_nnz { - let (row, col) = indices(argsort[original_pos] & TOP_BIT_MASK); - if row.zx() >= nrows || col.zx() >= ncols { - return Err(CreationError::OutOfBounds { - row: row.zx(), - col: col.zx(), - }); - } - - if col.zx() != j { - break; - } - - row_ind[new_pos] = I::from_signed(row); - - n_unique += 1; - - new_pos += 1; - original_pos += 1; - - while original_pos < all_nnz - && indices(argsort[original_pos] & TOP_BIT_MASK) == (row, col) - { - original_pos += 1; - } - } - - col_ptr[j + 1] = col_ptr[j] + I::truncate(n_unique); - } - - Ok(( - Self { - nrows, - ncols, - col_ptr, - col_nnz: None, - row_ind, - }, - ValuesOrder { - argsort, - all_nnz, - nnz, - __marker: PhantomData, - }, - )) - } - - /// Create a new symbolic structure, and the corresponding order for the numerical values - /// from pairs of indices `(row, col)`. - #[inline] - pub fn try_new_from_indices( - nrows: usize, - ncols: usize, - indices: &[(I, I)], - ) -> Result<(Self, ValuesOrder), CreationError> { - Self::try_new_from_indices_impl(nrows, ncols, |i| indices[i], indices.len()) - } - - /// Create a new symbolic structure, and the corresponding order for the numerical values - /// from pairs of indices `(row, col)`. - /// - /// Negative indices are ignored. - #[inline] - pub fn try_new_from_nonnegative_indices( - nrows: usize, - ncols: usize, - indices: &[(I::Signed, I::Signed)], - ) -> Result<(Self, ValuesOrder), CreationError> { - Self::try_new_from_nonnegative_indices_impl(nrows, ncols, |i| indices[i], indices.len()) - } - } - - impl SymbolicSparseRowMat { - /// Create a new symbolic structure, and the corresponding order for the numerical values - /// from pairs of indices `(row, col)`. - #[inline] - pub fn try_new_from_indices( - nrows: usize, - ncols: usize, - indices: &[(I, I)], - ) -> Result<(Self, ValuesOrder), CreationError> { - SymbolicSparseColMat::try_new_from_indices_impl( - ncols, - nrows, - |i| { - let (row, col) = indices[i]; - (col, row) - }, - indices.len(), - ) - .map(|(m, o)| (m.into_transpose(), o)) - } - - /// Create a new symbolic structure, and the corresponding order for the numerical values - /// from pairs of indices `(row, col)`. - /// - /// Negative indices are ignored. - #[inline] - pub fn try_new_from_nonnegative_indices( - nrows: usize, - ncols: usize, - indices: &[(I::Signed, I::Signed)], - ) -> Result<(Self, ValuesOrder), CreationError> { - SymbolicSparseColMat::try_new_from_nonnegative_indices_impl( - ncols, - nrows, - |i| { - let (row, col) = indices[i]; - (col, row) - }, - indices.len(), - ) - .map(|(m, o)| (m.into_transpose(), o)) - } - } - - impl SparseColMat { - #[track_caller] - fn new_from_order_and_values_impl( - symbolic: SymbolicSparseColMat, - order: &ValuesOrder, - all_values: impl Fn(usize) -> E, - values_len: usize, - ) -> Result { - { - let nnz = order.argsort.len(); - assert!(values_len == nnz); - } - - let all_nnz = order.all_nnz; - - let mut values = VecGroup::::new(); - match values.try_reserve_exact(order.nnz) { - Ok(()) => {} - Err(_) => return Err(FaerError::OutOfMemory), - }; - - let mut pos = 0usize; - let mut pos_unique = usize::MAX; - let mut current_bit = TOP_BIT; - - while pos < all_nnz { - let argsort_pos = order.argsort[pos]; - let extracted_bit = argsort_pos & TOP_BIT; - let argsort_pos = argsort_pos & TOP_BIT_MASK; - - let val = all_values(argsort_pos); - if extracted_bit != current_bit { - values.push(val.faer_into_units()); - pos_unique = pos_unique.wrapping_add(1); - } else { - let old_val = values.as_slice().read(pos_unique); - values - .as_slice_mut() - .write(pos_unique, old_val.faer_add(val)); - } - - current_bit = extracted_bit; - - pos += 1; - } - - Ok(Self { - inner: inner::SparseColMat { symbolic, values }, - }) - } - - /// Create a new matrix from a previously created symbolic structure and value order. - /// The provided values must correspond to the same indices that were provided in the - /// function call from which the order was created. - #[track_caller] - pub fn new_from_order_and_values( - symbolic: SymbolicSparseColMat, - order: &ValuesOrder, - values: GroupFor, - ) -> Result { - let values = SliceGroup::<'_, E>::new(values); - Self::new_from_order_and_values_impl(symbolic, order, |i| values.read(i), values.len()) - } - - /// Create a new matrix from triplets `(row, col, value)`. - #[track_caller] - pub fn try_new_from_triplets( - nrows: usize, - ncols: usize, - triplets: &[(I, I, E)], - ) -> Result { - let (symbolic, order) = SymbolicSparseColMat::try_new_from_indices_impl( - nrows, - ncols, - |i| { - let (row, col, _) = triplets[i]; - (row, col) - }, - triplets.len(), - )?; - Ok(Self::new_from_order_and_values_impl( - symbolic, - &order, - |i| triplets[i].2, - triplets.len(), - )?) - } - - /// Create a new matrix from triplets `(row, col, value)`. Negative indices are ignored. - #[track_caller] - pub fn try_new_from_nonnegative_triplets( - nrows: usize, - ncols: usize, - triplets: &[(I::Signed, I::Signed, E)], - ) -> Result { - let (symbolic, order) = - SymbolicSparseColMat::::try_new_from_nonnegative_indices_impl( - nrows, - ncols, - |i| { - let (row, col, _) = triplets[i]; - (row, col) - }, - triplets.len(), - )?; - Ok(Self::new_from_order_and_values_impl( - symbolic, - &order, - |i| triplets[i].2, - triplets.len(), - )?) - } - } - - impl SparseRowMat { - /// Create a new matrix from a previously created symbolic structure and value order. - /// The provided values must correspond to the same indices that were provided in the - /// function call from which the order was created. - #[track_caller] - pub fn new_from_order_and_values( - symbolic: SymbolicSparseRowMat, - order: &ValuesOrder, - values: GroupFor, - ) -> Result { - SparseColMat::new_from_order_and_values(symbolic.into_transpose(), order, values) - .map(SparseColMat::into_transpose) - } - - /// Create a new matrix from triplets `(row, col, value)`. - #[track_caller] - pub fn try_new_from_triplets( - nrows: usize, - ncols: usize, - triplets: &[(I, I, E)], - ) -> Result { - let (symbolic, order) = SymbolicSparseColMat::try_new_from_indices_impl( - ncols, - nrows, - |i| { - let (row, col, _) = triplets[i]; - (col, row) - }, - triplets.len(), - )?; - Ok(SparseColMat::new_from_order_and_values_impl( - symbolic, - &order, - |i| triplets[i].2, - triplets.len(), - )? - .into_transpose()) - } - - /// Create a new matrix from triplets `(row, col, value)`. Negative indices are ignored. - #[track_caller] - pub fn try_new_from_nonnegative_triplets( - nrows: usize, - ncols: usize, - triplets: &[(I::Signed, I::Signed, E)], - ) -> Result { - let (symbolic, order) = - SymbolicSparseColMat::::try_new_from_nonnegative_indices_impl( - ncols, - nrows, - |i| { - let (row, col, _) = triplets[i]; - (col, row) - }, - triplets.len(), - )?; - Ok(SparseColMat::new_from_order_and_values_impl( - symbolic, - &order, - |i| triplets[i].2, - triplets.len(), - )? - .into_transpose()) - } - } - - impl SparseColMatMut<'_, I, E> { - /// Fill the matrix from a previously created value order. - /// The provided values must correspond to the same indices that were provided in the - /// function call from which the order was created. - /// - /// # Note - /// The symbolic structure is not changed. - pub fn fill_from_order_and_values( - &mut self, - order: &ValuesOrder, - values: GroupFor, - mode: FillMode, - ) { - let values = SliceGroup::<'_, E>::new(values); - - { - let nnz = order.argsort.len(); - assert!(values.len() == nnz); - assert!(order.nnz == self.inner.values.len()); - } - let all_nnz = order.all_nnz; - let mut dst = self.inner.values.rb_mut(); - - let mut pos = 0usize; - let mut pos_unique = usize::MAX; - let mut current_bit = TOP_BIT; - - match mode { - FillMode::Replace => { - while pos < all_nnz { - let argsort_pos = order.argsort[pos]; - let extracted_bit = argsort_pos & TOP_BIT; - let argsort_pos = argsort_pos & TOP_BIT_MASK; - - let val = values.read(argsort_pos); - if extracted_bit != current_bit { - pos_unique = pos_unique.wrapping_add(1); - dst.write(pos_unique, val); - } else { - let old_val = dst.read(pos_unique); - dst.write(pos_unique, old_val.faer_add(val)); - } - - current_bit = extracted_bit; - - pos += 1; - } - } - FillMode::Add => { - while pos < all_nnz { - let argsort_pos = order.argsort[pos]; - let extracted_bit = argsort_pos & TOP_BIT; - let argsort_pos = argsort_pos & TOP_BIT_MASK; - - let val = values.read(argsort_pos); - if extracted_bit != current_bit { - pos_unique = pos_unique.wrapping_add(1); - } - - let old_val = dst.read(pos_unique); - dst.write(pos_unique, old_val.faer_add(val)); - - current_bit = extracted_bit; - - pos += 1; - } - } - } - } - } - - impl SparseRowMatMut<'_, I, E> { - /// Fill the matrix from a previously created value order. - /// The provided values must correspond to the same indices that were provided in the - /// function call from which the order was created. - /// - /// # Note - /// The symbolic structure is not changed. - pub fn fill_from_order_and_values( - &mut self, - order: &ValuesOrder, - values: GroupFor, - mode: FillMode, - ) { - self.rb_mut() - .transpose_mut() - .fill_from_order_and_values(order, values, mode); - } - } -}; - -/// Useful sparse matrix primitives. -pub mod util { - use super::*; - use crate::{assert, debug_assert}; - - /// Sorts `row_indices` and `values` simultaneously so that `row_indices` is nonincreasing. - pub fn sort_indices( - col_ptrs: &[I], - row_indices: &mut [I], - values: GroupFor, - ) { - assert!(col_ptrs.len() >= 1); - let mut values = SliceGroupMut::<'_, E>::new(values); - - let n = col_ptrs.len() - 1; - for j in 0..n { - let start = col_ptrs[j].zx(); - let end = col_ptrs[j + 1].zx(); - - unsafe { - crate::sort::sort_indices( - &mut row_indices[start..end], - values.rb_mut().subslice(start..end), - ); - } - } - } - - #[doc(hidden)] - pub unsafe fn ghost_permute_hermitian_unsorted<'n, 'out, I: Index, E: ComplexField>( - new_values: SliceGroupMut<'out, E>, - new_col_ptrs: &'out mut [I], - new_row_indices: &'out mut [I], - A: ghost::SparseColMatRef<'n, 'n, '_, I, E>, - perm: ghost::PermutationRef<'n, '_, I, E>, - in_side: Side, - out_side: Side, - sort: bool, - stack: PodStack<'_>, - ) -> ghost::SparseColMatMut<'n, 'n, 'out, I, E> { - let N = A.ncols(); - let n = *A.ncols(); - - // (1) - assert!(new_col_ptrs.len() == n + 1); - let (_, perm_inv) = perm.into_arrays(); - - let (current_row_position, _) = stack.make_raw::(n); - let current_row_position = ghost::Array::from_mut(current_row_position, N); - - mem::fill_zero(current_row_position.as_mut()); - let col_counts = &mut *current_row_position; - match (in_side, out_side) { - (Side::Lower, Side::Lower) => { - for old_j in N.indices() { - let new_j = perm_inv[old_j].zx(); - for old_i in A.row_indices_of_col(old_j) { - if old_i >= old_j { - let new_i = perm_inv[old_i].zx(); - let new_min = Ord::min(new_i, new_j); - // cannot overflow because A.compute_nnz() <= I::MAX - // col_counts[new_max] always >= 0 - col_counts[new_min] += I::truncate(1); - } - } - } - } - (Side::Lower, Side::Upper) => { - for old_j in N.indices() { - let new_j = perm_inv[old_j].zx(); - for old_i in A.row_indices_of_col(old_j) { - if old_i >= old_j { - let new_i = perm_inv[old_i].zx(); - let new_max = Ord::max(new_i, new_j); - // cannot overflow because A.compute_nnz() <= I::MAX - // col_counts[new_max] always >= 0 - col_counts[new_max] += I::truncate(1); - } - } - } - } - (Side::Upper, Side::Lower) => { - for old_j in N.indices() { - let new_j = perm_inv[old_j].zx(); - for old_i in A.row_indices_of_col(old_j) { - if old_i <= old_j { - let new_i = perm_inv[old_i].zx(); - let new_min = Ord::min(new_i, new_j); - // cannot overflow because A.compute_nnz() <= I::MAX - // col_counts[new_max] always >= 0 - col_counts[new_min] += I::truncate(1); - } - } - } - } - (Side::Upper, Side::Upper) => { - for old_j in N.indices() { - let new_j = perm_inv[old_j].zx(); - for old_i in A.row_indices_of_col(old_j) { - if old_i <= old_j { - let new_i = perm_inv[old_i].zx(); - let new_max = Ord::max(new_i, new_j); - // cannot overflow because A.compute_nnz() <= I::MAX - // col_counts[new_max] always >= 0 - col_counts[new_max] += I::truncate(1); - } - } - } - } - } - - // col_counts[_] >= 0 - // cumulative sum cannot overflow because it is <= A.compute_nnz() - - // SAFETY: new_col_ptrs.len() == n + 1 > 0 - new_col_ptrs[0] = I::truncate(0); - for (count, [ci0, ci1]) in zip( - col_counts.as_mut(), - windows2(Cell::as_slice_of_cells(Cell::from_mut(&mut *new_col_ptrs))), - ) { - let ci0 = ci0.get(); - ci1.set(ci0 + *count); - *count = ci0; - } - // new_col_ptrs is non-decreasing - - let nnz = new_col_ptrs[n].zx(); - let new_row_indices = &mut new_row_indices[..nnz]; - let mut new_values = new_values.subslice(0..nnz); - - ghost::Size::with( - nnz, - #[inline(always)] - |NNZ| { - let mut new_values = - ghost::ArrayGroupMut::new(new_values.rb_mut().into_inner(), NNZ); - let new_row_indices = ghost::Array::from_mut(new_row_indices, NNZ); - - let conj_if = |cond: bool, x: E| { - if !coe::is_same::() && cond { - x.faer_conj() - } else { - x - } - }; - - match (in_side, out_side) { - (Side::Lower, Side::Lower) => { - for old_j in N.indices() { - let new_j_ = perm_inv[old_j]; - let new_j = new_j_.zx(); - - for (old_i, val) in zip( - A.row_indices_of_col(old_j), - SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), - ) { - if old_i >= old_j { - let new_i_ = perm_inv[old_i]; - let new_i = new_i_.zx(); - - let new_max = Ord::max(new_i_, new_j_); - let new_min = Ord::min(new_i, new_j); - let current_row_pos: &mut I = - &mut current_row_position[new_min]; - // SAFETY: current_row_pos < NNZ - let row_pos = unsafe { - ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) - }; - *current_row_pos += I::truncate(1); - new_values - .write(row_pos, conj_if(new_min == new_i, val.read())); - // (2) - new_row_indices[row_pos] = *new_max; - } - } - } - } - (Side::Lower, Side::Upper) => { - for old_j in N.indices() { - let new_j_ = perm_inv[old_j]; - let new_j = new_j_.zx(); - - for (old_i, val) in zip( - A.row_indices_of_col(old_j), - SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), - ) { - if old_i >= old_j { - let new_i_ = perm_inv[old_i]; - let new_i = new_i_.zx(); - - let new_max = Ord::max(new_i, new_j); - let new_min = Ord::min(new_i_, new_j_); - let current_row_pos = &mut current_row_position[new_max]; - // SAFETY: current_row_pos < NNZ - let row_pos = unsafe { - ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) - }; - *current_row_pos += I::truncate(1); - new_values - .write(row_pos, conj_if(new_max == new_i, val.read())); - // (2) - new_row_indices[row_pos] = *new_min; - } - } - } - } - (Side::Upper, Side::Lower) => { - for old_j in N.indices() { - let new_j_ = perm_inv[old_j]; - let new_j = new_j_.zx(); - - for (old_i, val) in zip( - A.row_indices_of_col(old_j), - SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), - ) { - if old_i <= old_j { - let new_i_ = perm_inv[old_i]; - let new_i = new_i_.zx(); - - let new_max = Ord::max(new_i_, new_j_); - let new_min = Ord::min(new_i, new_j); - let current_row_pos = &mut current_row_position[new_min]; - // SAFETY: current_row_pos < NNZ - let row_pos = unsafe { - ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) - }; - *current_row_pos += I::truncate(1); - new_values - .write(row_pos, conj_if(new_min == new_i, val.read())); - // (2) - new_row_indices[row_pos] = *new_max; - } - } - } - } - (Side::Upper, Side::Upper) => { - for old_j in N.indices() { - let new_j_ = perm_inv[old_j]; - let new_j = new_j_.zx(); - - for (old_i, val) in zip( - A.row_indices_of_col(old_j), - SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), - ) { - if old_i <= old_j { - let new_i_ = perm_inv[old_i]; - let new_i = new_i_.zx(); - - let new_max = Ord::max(new_i, new_j); - let new_min = Ord::min(new_i_, new_j_); - let current_row_pos = &mut current_row_position[new_max]; - // SAFETY: current_row_pos < NNZ - let row_pos = unsafe { - ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) - }; - *current_row_pos += I::truncate(1); - new_values - .write(row_pos, conj_if(new_max == new_i, val.read())); - // (2) - new_row_indices[row_pos] = *new_min; - } - } - } - } - } - debug_assert!(current_row_position.as_ref() == &new_col_ptrs[1..]); - }, - ); - - if sort { - sort_indices::( - new_col_ptrs, - new_row_indices, - new_values.rb_mut().into_inner(), - ); - } - - // SAFETY: - // 0. new_col_ptrs is non-decreasing - // 1. new_values.len() == new_row_indices.len() - // 2. all written row indices are less than n - unsafe { - ghost::SparseColMatMut::new( - SparseColMatMut::new( - SymbolicSparseColMatRef::new_unchecked( - n, - n, - new_col_ptrs, - None, - new_row_indices, - ), - new_values.into_inner(), - ), - N, - N, - ) - } - } - - #[doc(hidden)] - pub unsafe fn ghost_permute_hermitian_unsorted_symbolic<'n, 'out, I: Index>( - new_col_ptrs: &'out mut [I], - new_row_indices: &'out mut [I], - A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, - perm: ghost::PermutationRef<'n, '_, I, Symbolic>, - in_side: Side, - out_side: Side, - stack: PodStack<'_>, - ) -> ghost::SymbolicSparseColMatRef<'n, 'n, 'out, I> { - let old_values = &*Symbolic::materialize(A.into_inner().row_indices().len()); - let new_values = Symbolic::materialize(new_row_indices.len()); - *ghost_permute_hermitian_unsorted( - SliceGroupMut::<'_, Symbolic>::new(new_values), - new_col_ptrs, - new_row_indices, - ghost::SparseColMatRef::new( - SparseColMatRef::new(A.into_inner(), old_values), - A.nrows(), - A.ncols(), - ), - perm, - in_side, - out_side, - false, - stack, - ) - } - - /// Computes the self-adjoint permutation $P A P^\top$ of the matrix `A` without sorting the row - /// indices, and returns a view over it. - /// - /// The result is stored in `new_col_ptrs`, `new_row_indices`. - #[doc(hidden)] - pub unsafe fn permute_hermitian_unsorted<'out, I: Index, E: ComplexField>( - new_values: GroupFor, - new_col_ptrs: &'out mut [I], - new_row_indices: &'out mut [I], - A: SparseColMatRef<'_, I, E>, - perm: crate::permutation::PermutationRef<'_, I, E>, - in_side: Side, - out_side: Side, - stack: PodStack<'_>, - ) -> SparseColMatMut<'out, I, E> { - ghost::Size::with(A.nrows(), |N| { - assert!(A.nrows() == A.ncols()); - ghost_permute_hermitian_unsorted( - SliceGroupMut::new(new_values), - new_col_ptrs, - new_row_indices, - ghost::SparseColMatRef::new(A, N, N), - ghost::PermutationRef::new(perm, N), - in_side, - out_side, - false, - stack, - ) - .into_inner() - }) - } - - /// Computes the self-adjoint permutation $P A P^\top$ of the matrix `A` and returns a view over - /// it. - /// - /// The result is stored in `new_col_ptrs`, `new_row_indices`. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - pub fn permute_hermitian<'out, I: Index, E: ComplexField>( - new_values: GroupFor, - new_col_ptrs: &'out mut [I], - new_row_indices: &'out mut [I], - A: SparseColMatRef<'_, I, E>, - perm: crate::permutation::PermutationRef<'_, I, E>, - in_side: Side, - out_side: Side, - stack: PodStack<'_>, - ) -> SparseColMatMut<'out, I, E> { - ghost::Size::with(A.nrows(), |N| { - assert!(A.nrows() == A.ncols()); - unsafe { - ghost_permute_hermitian_unsorted( - SliceGroupMut::new(new_values), - new_col_ptrs, - new_row_indices, - ghost::SparseColMatRef::new(A, N, N), - ghost::PermutationRef::new(perm, N), - in_side, - out_side, - true, - stack, - ) - } - .into_inner() - }) - } - - #[doc(hidden)] - pub fn ghost_adjoint_symbolic<'m, 'n, 'a, I: Index>( - new_col_ptrs: &'a mut [I], - new_row_indices: &'a mut [I], - A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - stack: PodStack<'_>, - ) -> ghost::SymbolicSparseColMatRef<'n, 'm, 'a, I> { - let old_values = &*Symbolic::materialize(A.into_inner().row_indices().len()); - let new_values = Symbolic::materialize(new_row_indices.len()); - *ghost_adjoint( - new_col_ptrs, - new_row_indices, - SliceGroupMut::<'_, Symbolic>::new(new_values), - ghost::SparseColMatRef::new( - SparseColMatRef::new(A.into_inner(), old_values), - A.nrows(), - A.ncols(), - ), - stack, - ) - } - - #[doc(hidden)] - pub fn ghost_adjoint<'m, 'n, 'a, I: Index, E: ComplexField>( - new_col_ptrs: &'a mut [I], - new_row_indices: &'a mut [I], - new_values: SliceGroupMut<'a, E>, - A: ghost::SparseColMatRef<'m, 'n, '_, I, E>, - stack: PodStack<'_>, - ) -> ghost::SparseColMatMut<'n, 'm, 'a, I, E> { - let M = A.nrows(); - let N = A.ncols(); - assert!(new_col_ptrs.len() == *M + 1); - - let (col_count, _) = stack.make_raw::(*M); - let col_count = ghost::Array::from_mut(col_count, M); - mem::fill_zero(col_count.as_mut()); - - // can't overflow because the total count is A.compute_nnz() <= I::MAX - for j in N.indices() { - for i in A.row_indices_of_col(j) { - col_count[i] += I::truncate(1); - } - } - - new_col_ptrs[0] = I::truncate(0); - // col_count elements are >= 0 - for (j, [pj0, pj1]) in zip( - M.indices(), - windows2(Cell::as_slice_of_cells(Cell::from_mut(new_col_ptrs))), - ) { - let cj = &mut col_count[j]; - let pj = pj0.get(); - // new_col_ptrs is non-decreasing - pj1.set(pj + *cj); - *cj = pj; - } - - let new_row_indices = &mut new_row_indices[..new_col_ptrs[*M].zx()]; - let mut new_values = new_values.subslice(0..new_col_ptrs[*M].zx()); - let current_row_position = &mut *col_count; - // current_row_position[i] == col_ptr[i] - for j in N.indices() { - let j_: ghost::Idx<'n, I> = j.truncate::(); - for (i, val) in zip( - A.row_indices_of_col(j), - SliceGroup::<'_, E>::new(A.values_of_col(j)).into_ref_iter(), - ) { - let ci = &mut current_row_position[i]; - - // SAFETY: see below - unsafe { - *new_row_indices.get_unchecked_mut(ci.zx()) = *j_; - new_values.write_unchecked(ci.zx(), val.read().faer_conj()) - }; - *ci += I::truncate(1); - } - } - // current_row_position[i] == col_ptr[i] + col_count[i] == col_ptr[i + 1] <= col_ptr[m] - // so all the unchecked accesses were valid and non-overlapping, which means the entire - // array is filled - debug_assert!(current_row_position.as_ref() == &new_col_ptrs[1..]); - - // SAFETY: - // 0. new_col_ptrs is non-decreasing - // 1. all written row indices are less than n - ghost::SparseColMatMut::new( - unsafe { - SparseColMatMut::new( - SymbolicSparseColMatRef::new_unchecked( - *N, - *M, - new_col_ptrs, - None, - new_row_indices, - ), - new_values.into_inner(), - ) - }, - N, - M, - ) - } - - #[doc(hidden)] - pub fn ghost_transpose<'m, 'n, 'a, I: Index, E: Entity>( - new_col_ptrs: &'a mut [I], - new_row_indices: &'a mut [I], - new_values: SliceGroupMut<'a, E>, - A: ghost::SparseColMatRef<'m, 'n, '_, I, E>, - stack: PodStack<'_>, - ) -> ghost::SparseColMatMut<'n, 'm, 'a, I, E> { - let M = A.nrows(); - let N = A.ncols(); - assert!(new_col_ptrs.len() == *M + 1); - - let (col_count, _) = stack.make_raw::(*M); - let col_count = ghost::Array::from_mut(col_count, M); - mem::fill_zero(col_count.as_mut()); - - // can't overflow because the total count is A.compute_nnz() <= I::MAX - for j in N.indices() { - for i in A.row_indices_of_col(j) { - col_count[i] += I::truncate(1); - } - } - - new_col_ptrs[0] = I::truncate(0); - // col_count elements are >= 0 - for (j, [pj0, pj1]) in zip( - M.indices(), - windows2(Cell::as_slice_of_cells(Cell::from_mut(new_col_ptrs))), - ) { - let cj = &mut col_count[j]; - let pj = pj0.get(); - // new_col_ptrs is non-decreasing - pj1.set(pj + *cj); - *cj = pj; - } - - let new_row_indices = &mut new_row_indices[..new_col_ptrs[*M].zx()]; - let mut new_values = new_values.subslice(0..new_col_ptrs[*M].zx()); - let current_row_position = &mut *col_count; - // current_row_position[i] == col_ptr[i] - for j in N.indices() { - let j_: ghost::Idx<'n, I> = j.truncate::(); - for (i, val) in zip( - A.row_indices_of_col(j), - SliceGroup::<'_, E>::new(A.values_of_col(j)).into_ref_iter(), - ) { - let ci = &mut current_row_position[i]; - - // SAFETY: see below - unsafe { - *new_row_indices.get_unchecked_mut(ci.zx()) = *j_; - new_values.write_unchecked(ci.zx(), val.read()) - }; - *ci += I::truncate(1); - } - } - // current_row_position[i] == col_ptr[i] + col_count[i] == col_ptr[i + 1] <= col_ptr[m] - // so all the unchecked accesses were valid and non-overlapping, which means the entire - // array is filled - debug_assert!(current_row_position.as_ref() == &new_col_ptrs[1..]); - - // SAFETY: - // 0. new_col_ptrs is non-decreasing - // 1. all written row indices are less than n - ghost::SparseColMatMut::new( - unsafe { - SparseColMatMut::new( - SymbolicSparseColMatRef::new_unchecked( - *N, - *M, - new_col_ptrs, - None, - new_row_indices, - ), - new_values.into_inner(), - ) - }, - N, - M, - ) - } - - /// Computes the transpose of the matrix `A` and returns a view over it. - /// - /// The result is stored in `new_col_ptrs`, `new_row_indices` and `new_values`. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - pub fn transpose<'a, I: Index, E: Entity>( - new_col_ptrs: &'a mut [I], - new_row_indices: &'a mut [I], - new_values: GroupFor, - A: SparseColMatRef<'_, I, E>, - stack: PodStack<'_>, - ) -> SparseColMatMut<'a, I, E> { - ghost::Size::with2(A.nrows(), A.ncols(), |M, N| { - ghost_transpose( - new_col_ptrs, - new_row_indices, - SliceGroupMut::new(new_values), - ghost::SparseColMatRef::new(A, M, N), - stack, - ) - .into_inner() - }) - } - - /// Computes the adjoint of the matrix `A` and returns a view over it. - /// - /// The result is stored in `new_col_ptrs`, `new_row_indices` and `new_values`. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - pub fn adjoint<'a, I: Index, E: ComplexField>( - new_col_ptrs: &'a mut [I], - new_row_indices: &'a mut [I], - new_values: GroupFor, - A: SparseColMatRef<'_, I, E>, - stack: PodStack<'_>, - ) -> SparseColMatMut<'a, I, E> { - ghost::Size::with2(A.nrows(), A.ncols(), |M, N| { - ghost_adjoint( - new_col_ptrs, - new_row_indices, - SliceGroupMut::new(new_values), - ghost::SparseColMatRef::new(A, M, N), - stack, - ) - .into_inner() - }) - } - - /// Computes the adjoint of the symbolic matrix `A` and returns a view over it. - /// - /// The result is stored in `new_col_ptrs`, `new_row_indices`. - /// - /// # Note - /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. - pub fn adjoint_symbolic<'a, I: Index>( - new_col_ptrs: &'a mut [I], - new_row_indices: &'a mut [I], - A: SymbolicSparseColMatRef<'_, I>, - stack: PodStack<'_>, - ) -> SymbolicSparseColMatRef<'a, I> { - ghost::Size::with2(A.nrows(), A.ncols(), |M, N| { - ghost_adjoint_symbolic( - new_col_ptrs, - new_row_indices, - ghost::SymbolicSparseColMatRef::new(A, M, N), - stack, - ) - .into_inner() - }) - } -} - -/// Arithmetic and generic binary/ternary operations. -pub mod ops { - use super::*; - use crate::assert; - - /// Returns the resulting matrix obtained by applying `f` to the elements from `lhs` and `rhs`, - /// skipping entries that are unavailable in both of `lhs` and `rhs`. - /// - /// # Panics - /// Panics if `lhs` and `rhs` don't have matching dimensions. - #[track_caller] - pub fn binary_op( - lhs: SparseColMatRef<'_, I, LhsE>, - rhs: SparseColMatRef<'_, I, RhsE>, - f: impl FnMut(LhsE, RhsE) -> E, - ) -> Result, FaerError> { - assert!(lhs.nrows() == rhs.nrows()); - assert!(lhs.ncols() == rhs.ncols()); - let mut f = f; - let m = lhs.nrows(); - let n = lhs.ncols(); - - let mut col_ptrs = try_zeroed::(n + 1)?; - - let mut nnz = 0usize; - for j in 0..n { - let lhs = lhs.row_indices_of_col_raw(j); - let rhs = rhs.row_indices_of_col_raw(j); - - let mut lhs_pos = 0usize; - let mut rhs_pos = 0usize; - while lhs_pos < lhs.len() && rhs_pos < rhs.len() { - let lhs = lhs[lhs_pos]; - let rhs = rhs[rhs_pos]; - - lhs_pos += (lhs <= rhs) as usize; - rhs_pos += (rhs <= lhs) as usize; - nnz += 1; - } - nnz += lhs.len() - lhs_pos; - nnz += rhs.len() - rhs_pos; - col_ptrs[j + 1] = I::truncate(nnz); - } - - if nnz > I::Signed::MAX.zx() { - return Err(FaerError::IndexOverflow); - } - - let mut row_indices = try_zeroed(nnz)?; - let mut values = VecGroup::::new(); - values - .try_reserve_exact(nnz) - .map_err(|_| FaerError::OutOfMemory)?; - values.resize(nnz, unsafe { core::mem::zeroed() }); - - let mut nnz = 0usize; - for j in 0..n { - let mut values = values.as_slice_mut(); - let lhs_values = SliceGroup::::new(lhs.values_of_col(j)); - let rhs_values = SliceGroup::::new(rhs.values_of_col(j)); - let lhs = lhs.row_indices_of_col_raw(j); - let rhs = rhs.row_indices_of_col_raw(j); - - let mut lhs_pos = 0usize; - let mut rhs_pos = 0usize; - while lhs_pos < lhs.len() && rhs_pos < rhs.len() { - let lhs = lhs[lhs_pos]; - let rhs = rhs[rhs_pos]; - - match lhs.cmp(&rhs) { - core::cmp::Ordering::Less => { - row_indices[nnz] = lhs; - values.write( - nnz, - f(lhs_values.read(lhs_pos), unsafe { core::mem::zeroed() }), - ); - } - core::cmp::Ordering::Equal => { - row_indices[nnz] = lhs; - values.write(nnz, f(lhs_values.read(lhs_pos), rhs_values.read(rhs_pos))); - } - core::cmp::Ordering::Greater => { - row_indices[nnz] = rhs; - values.write( - nnz, - f(unsafe { core::mem::zeroed() }, rhs_values.read(rhs_pos)), - ); - } - } - - lhs_pos += (lhs <= rhs) as usize; - rhs_pos += (rhs <= lhs) as usize; - nnz += 1; - } - row_indices[nnz..nnz + lhs.len() - lhs_pos].copy_from_slice(&lhs[lhs_pos..]); - for (mut dst, src) in values - .rb_mut() - .subslice(nnz..nnz + lhs.len() - lhs_pos) - .into_mut_iter() - .zip(lhs_values.subslice(lhs_pos..lhs.len()).into_ref_iter()) - { - dst.write(f(src.read(), unsafe { core::mem::zeroed() })); - } - nnz += lhs.len() - lhs_pos; - - row_indices[nnz..nnz + rhs.len() - rhs_pos].copy_from_slice(&rhs[rhs_pos..]); - for (mut dst, src) in values - .rb_mut() - .subslice(nnz..nnz + rhs.len() - rhs_pos) - .into_mut_iter() - .zip(rhs_values.subslice(rhs_pos..rhs.len()).into_ref_iter()) - { - dst.write(f(unsafe { core::mem::zeroed() }, src.read())); - } - nnz += rhs.len() - rhs_pos; - } - - Ok(SparseColMat::::new( - SymbolicSparseColMat::::new_checked(m, n, col_ptrs, None, row_indices), - values.into_inner(), - )) - } - - /// Returns the resulting matrix obtained by applying `f` to the elements from `dst` and `src` - /// skipping entries that are unavailable in both of them. - /// The sparsity patter of `dst` is unchanged. - /// - /// # Panics - /// Panics if `src` and `dst` don't have matching dimensions. - /// Panics if `src` contains an index that's unavailable in `dst`. - #[track_caller] - pub fn binary_op_assign_into( - dst: SparseColMatMut<'_, I, E>, - src: SparseColMatRef<'_, I, SrcE>, - f: impl FnMut(E, SrcE) -> E, - ) { - { - assert!(dst.nrows() == src.nrows()); - assert!(dst.ncols() == src.ncols()); - - let n = dst.ncols(); - let mut dst = dst; - let mut f = f; - unsafe { - assert!(f(core::mem::zeroed(), core::mem::zeroed()) == core::mem::zeroed()); - } - - for j in 0..n { - let (dst, dst_val) = dst.rb_mut().into_parts_mut(); - - let mut dst_val = SliceGroupMut::::new(dst_val).subslice(dst.col_range(j)); - let src_val = SliceGroup::::new(src.values_of_col(j)); - - let dst = dst.row_indices_of_col_raw(j); - let src = src.row_indices_of_col_raw(j); - - let mut dst_pos = 0usize; - let mut src_pos = 0usize; - - while src_pos < src.len() { - let src = src[src_pos]; - - if dst[dst_pos] < src { - dst_val.write( - dst_pos, - f(dst_val.read(dst_pos), unsafe { core::mem::zeroed() }), - ); - dst_pos += 1; - continue; - } - - assert!(dst[dst_pos] == src); - - dst_val.write(dst_pos, f(dst_val.read(dst_pos), src_val.read(src_pos))); - - src_pos += 1; - dst_pos += 1; - } - while dst_pos < dst.len() { - dst_val.write( - dst_pos, - f(dst_val.read(dst_pos), unsafe { core::mem::zeroed() }), - ); - dst_pos += 1; - } - } - } - } - - /// Returns the resulting matrix obtained by applying `f` to the elements from `dst`, `lhs` and - /// `rhs`, skipping entries that are unavailable in all of `dst`, `lhs` and `rhs`. - /// The sparsity patter of `dst` is unchanged. - /// - /// # Panics - /// Panics if `lhs`, `rhs` and `dst` don't have matching dimensions. - /// Panics if `lhs` or `rhs` contains an index that's unavailable in `dst`. - #[track_caller] - pub fn ternary_op_assign_into( - dst: SparseColMatMut<'_, I, E>, - lhs: SparseColMatRef<'_, I, LhsE>, - rhs: SparseColMatRef<'_, I, RhsE>, - f: impl FnMut(E, LhsE, RhsE) -> E, - ) { - { - assert!(dst.nrows() == lhs.nrows()); - assert!(dst.ncols() == lhs.ncols()); - assert!(dst.nrows() == rhs.nrows()); - assert!(dst.ncols() == rhs.ncols()); - - let n = dst.ncols(); - let mut dst = dst; - let mut f = f; - unsafe { - assert!( - f( - core::mem::zeroed(), - core::mem::zeroed(), - core::mem::zeroed() - ) == core::mem::zeroed() - ); - } - - for j in 0..n { - let (dst, dst_val) = dst.rb_mut().into_parts_mut(); - - let mut dst_val = SliceGroupMut::::new(dst_val); - let lhs_val = SliceGroup::::new(lhs.values_of_col(j)); - let rhs_val = SliceGroup::::new(rhs.values_of_col(j)); - - let dst = dst.row_indices_of_col_raw(j); - let rhs = rhs.row_indices_of_col_raw(j); - let lhs = lhs.row_indices_of_col_raw(j); - - let mut dst_pos = 0usize; - let mut lhs_pos = 0usize; - let mut rhs_pos = 0usize; - - while lhs_pos < lhs.len() && rhs_pos < rhs.len() { - let lhs = lhs[lhs_pos]; - let rhs = rhs[rhs_pos]; - - if dst[dst_pos] < Ord::min(lhs, rhs) { - dst_val.write( - dst_pos, - f( - dst_val.read(dst_pos), - unsafe { core::mem::zeroed() }, - unsafe { core::mem::zeroed() }, - ), - ); - dst_pos += 1; - continue; - } - - assert!(dst[dst_pos] == Ord::min(lhs, rhs)); - - match lhs.cmp(&rhs) { - core::cmp::Ordering::Less => { - dst_val.write( - dst_pos, - f(dst_val.read(dst_pos), lhs_val.read(lhs_pos), unsafe { - core::mem::zeroed() - }), - ); - } - core::cmp::Ordering::Equal => { - dst_val.write( - dst_pos, - f( - dst_val.read(dst_pos), - lhs_val.read(lhs_pos), - rhs_val.read(rhs_pos), - ), - ); - } - core::cmp::Ordering::Greater => { - dst_val.write( - dst_pos, - f( - dst_val.read(dst_pos), - unsafe { core::mem::zeroed() }, - rhs_val.read(rhs_pos), - ), - ); - } - } - - lhs_pos += (lhs <= rhs) as usize; - rhs_pos += (rhs <= lhs) as usize; - dst_pos += 1; - } - while lhs_pos < lhs.len() { - let lhs = lhs[lhs_pos]; - if dst[dst_pos] < lhs { - dst_val.write( - dst_pos, - f( - dst_val.read(dst_pos), - unsafe { core::mem::zeroed() }, - unsafe { core::mem::zeroed() }, - ), - ); - dst_pos += 1; - continue; - } - dst_val.write( - dst_pos, - f(dst_val.read(dst_pos), lhs_val.read(lhs_pos), unsafe { - core::mem::zeroed() - }), - ); - lhs_pos += 1; - dst_pos += 1; - } - while rhs_pos < rhs.len() { - let rhs = rhs[rhs_pos]; - if dst[dst_pos] < rhs { - dst_val.write( - dst_pos, - f( - dst_val.read(dst_pos), - unsafe { core::mem::zeroed() }, - unsafe { core::mem::zeroed() }, - ), - ); - dst_pos += 1; - continue; - } - dst_val.write( - dst_pos, - f( - dst_val.read(dst_pos), - unsafe { core::mem::zeroed() }, - rhs_val.read(rhs_pos), - ), - ); - rhs_pos += 1; - dst_pos += 1; - } - while rhs_pos < rhs.len() { - let rhs = rhs[rhs_pos]; - dst_pos += dst[dst_pos..].binary_search(&rhs).unwrap(); - dst_val.write( - dst_pos, - f( - dst_val.read(dst_pos), - unsafe { core::mem::zeroed() }, - rhs_val.read(rhs_pos), - ), - ); - rhs_pos += 1; - } - } - } - } - - /// Returns the sparsity pattern containing the union of those of `lhs` and `rhs`. - /// - /// # Panics - /// Panics if `lhs` and `rhs` don't have mathcing dimensions. - #[track_caller] - #[inline] - pub fn union_symbolic( - lhs: SymbolicSparseColMatRef<'_, I>, - rhs: SymbolicSparseColMatRef<'_, I>, - ) -> Result, FaerError> { - Ok(binary_op( - SparseColMatRef::::new(lhs, Symbolic::materialize(lhs.compute_nnz())), - SparseColMatRef::::new(rhs, Symbolic::materialize(rhs.compute_nnz())), - #[inline(always)] - |_, _| Symbolic, - )? - .into_parts() - .0) - } - - /// Returns the sum of `lhs` and `rhs`. - /// - /// # Panics - /// Panics if `lhs` and `rhs` don't have mathcing dimensions. - #[track_caller] - #[inline] - pub fn add< - I: Index, - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - lhs: SparseColMatRef<'_, I, LhsE>, - rhs: SparseColMatRef<'_, I, RhsE>, - ) -> Result, FaerError> { - binary_op(lhs, rhs, |lhs, rhs| { - lhs.canonicalize().faer_add(rhs.canonicalize()) - }) - } - - /// Returns the difference of `lhs` and `rhs`. - /// - /// # Panics - /// Panics if `lhs` and `rhs` don't have matching dimensions. - #[track_caller] - #[inline] - pub fn sub< - I: Index, - LhsE: Conjugate, - RhsE: Conjugate, - E: ComplexField, - >( - lhs: SparseColMatRef<'_, I, LhsE>, - rhs: SparseColMatRef<'_, I, RhsE>, - ) -> Result, FaerError> { - binary_op(lhs, rhs, |lhs, rhs| { - lhs.canonicalize().faer_sub(rhs.canonicalize()) - }) - } - - /// Computes the sum of `dst` and `src` and stores the result in `dst` without changing its - /// symbolic structure. - /// - /// # Panics - /// Panics if `dst` and `rhs` don't have matching dimensions. - /// Panics if `rhs` contains an index that's unavailable in `dst`. - pub fn add_assign>( - dst: SparseColMatMut<'_, I, E>, - rhs: SparseColMatRef<'_, I, RhsE>, - ) { - binary_op_assign_into(dst, rhs, |dst, rhs| dst.faer_add(rhs.canonicalize())) - } - - /// Computes the difference of `dst` and `src` and stores the result in `dst` without changing - /// its symbolic structure. - /// - /// # Panics - /// Panics if `dst` and `rhs` don't have matching dimensions. - /// Panics if `rhs` contains an index that's unavailable in `dst`. - pub fn sub_assign>( - dst: SparseColMatMut<'_, I, E>, - rhs: SparseColMatRef<'_, I, RhsE>, - ) { - binary_op_assign_into(dst, rhs, |dst, rhs| dst.faer_sub(rhs.canonicalize())) - } - - /// Computes the sum of `lhs` and `rhs`, storing the result in `dst` without changing its - /// symbolic structure. - /// - /// # Panics - /// Panics if `dst`, `lhs` and `rhs` don't have matching dimensions. - /// Panics if `lhs` or `rhs` contains an index that's unavailable in `dst`. - #[track_caller] - #[inline] - pub fn add_into< - I: Index, - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - dst: SparseColMatMut<'_, I, E>, - lhs: SparseColMatRef<'_, I, LhsE>, - rhs: SparseColMatRef<'_, I, RhsE>, - ) { - ternary_op_assign_into(dst, lhs, rhs, |_, lhs, rhs| { - lhs.canonicalize().faer_add(rhs.canonicalize()) - }) - } - - /// Computes the difference of `lhs` and `rhs`, storing the result in `dst` without changing its - /// symbolic structure. - /// - /// # Panics - /// Panics if `dst`, `lhs` and `rhs` don't have matching dimensions. - /// Panics if `lhs` or `rhs` contains an index that's unavailable in `dst`. - #[track_caller] - #[inline] - pub fn sub_into< - I: Index, - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - dst: SparseColMatMut<'_, I, E>, - lhs: SparseColMatRef<'_, I, LhsE>, - rhs: SparseColMatRef<'_, I, RhsE>, - ) { - ternary_op_assign_into(dst, lhs, rhs, |_, lhs, rhs| { - lhs.canonicalize().faer_sub(rhs.canonicalize()) - }) - } -} - -impl<'a, I: Index, E: Entity> SparseColMatRef<'a, I, E> { - /// Returns a reference to the value at the given index, or None if the symbolic structure - /// doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get(self, row: usize, col: usize) -> Option> { - assert!(row < self.nrows()); - assert!(col < self.ncols()); - - let Ok(pos) = self - .row_indices_of_col_raw(col) - .binary_search(&I::truncate(row)) - else { - return None; - }; - - Some(E::faer_map(self.values_of_col(col), |slice| &slice[pos])) - } -} - -impl<'a, I: Index, E: Entity> SparseColMatMut<'a, I, E> { - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get(self, row: usize, col: usize) -> Option> { - self.into_const().get(row, col) - } - - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get_mut(self, row: usize, col: usize) -> Option> { - assert!(row < self.nrows()); - assert!(col < self.ncols()); - - let Ok(pos) = self - .row_indices_of_col_raw(col) - .binary_search(&I::truncate(row)) - else { - return None; - }; - - Some(E::faer_map(self.values_of_col_mut(col), |slice| { - &mut slice[pos] - })) - } -} - -impl SparseColMat { - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get(&self, row: usize, col: usize) -> Option> { - self.as_ref().get(row, col) - } - - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get_mut(&mut self, row: usize, col: usize) -> Option> { - self.as_mut().get_mut(row, col) - } -} - -impl<'a, I: Index, E: Entity> SparseRowMatRef<'a, I, E> { - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get(self, row: usize, col: usize) -> Option> { - assert!(row < self.nrows()); - assert!(col < self.ncols()); - - let Ok(pos) = self - .col_indices_of_row_raw(row) - .binary_search(&I::truncate(col)) - else { - return None; - }; - - Some(E::faer_map(self.values_of_row(row), |slice| &slice[pos])) - } -} - -impl<'a, I: Index, E: Entity> SparseRowMatMut<'a, I, E> { - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get_mut(self, row: usize, col: usize) -> Option> { - assert!(row < self.nrows()); - assert!(col < self.ncols()); - - let Ok(pos) = self - .col_indices_of_row_raw(row) - .binary_search(&I::truncate(col)) - else { - return None; - }; - - Some(E::faer_map(self.values_of_row_mut(row), |slice| { - &mut slice[pos] - })) - } -} - -impl SparseRowMat { - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get(&self, row: usize, col: usize) -> Option> { - self.as_ref().get(row, col) - } - - /// Returns a reference to the value at the given index using a binary search, or None if the - /// symbolic structure doesn't contain it - /// - /// # Panics - /// Panics if `row >= self.nrows()` - /// Panics if `col >= self.ncols()` - #[track_caller] - pub fn get_mut(&mut self, row: usize, col: usize) -> Option> { - self.as_mut().get_mut(row, col) - } -} - -impl core::ops::Index<(usize, usize)> for SparseColMatRef<'_, I, E> { - type Output = E; - - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &Self::Output { - self.get(row, col).unwrap() - } -} - -impl core::ops::Index<(usize, usize)> for SparseRowMatRef<'_, I, E> { - type Output = E; - - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &Self::Output { - self.get(row, col).unwrap() - } -} - -impl core::ops::Index<(usize, usize)> for SparseColMatMut<'_, I, E> { - type Output = E; - - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &Self::Output { - self.rb().get(row, col).unwrap() - } -} - -impl core::ops::Index<(usize, usize)> for SparseRowMatMut<'_, I, E> { - type Output = E; - - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &Self::Output { - self.rb().get(row, col).unwrap() - } -} - -impl core::ops::IndexMut<(usize, usize)> for SparseColMatMut<'_, I, E> { - #[track_caller] - fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { - self.rb_mut().get_mut(row, col).unwrap() - } -} - -impl core::ops::IndexMut<(usize, usize)> for SparseRowMatMut<'_, I, E> { - #[track_caller] - fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { - self.rb_mut().get_mut(row, col).unwrap() - } -} - -impl core::ops::Index<(usize, usize)> for SparseColMat { - type Output = E; - - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &Self::Output { - self.as_ref().get(row, col).unwrap() - } -} - -impl core::ops::Index<(usize, usize)> for SparseRowMat { - type Output = E; - - #[track_caller] - fn index(&self, (row, col): (usize, usize)) -> &Self::Output { - self.as_ref().get(row, col).unwrap() - } -} - -impl core::ops::IndexMut<(usize, usize)> for SparseColMat { - #[track_caller] - fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { - self.as_mut().get_mut(row, col).unwrap() - } -} - -impl core::ops::IndexMut<(usize, usize)> for SparseRowMat { - #[track_caller] - fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { - self.as_mut().get_mut(row, col).unwrap() - } -} - -/// Sparse matrix multiplication. -pub mod mul { - // TODO: sparse_sparse_matmul - // - // PERF: optimize matmul - // - parallelization - // - simd(?) - - use super::*; - use crate::{ - assert, - constrained::{self, Size}, - }; - - /// Multiplies a sparse matrix `lhs` by a dense matrix `rhs`, and stores the result in - /// `acc`. See [`crate::mul::matmul`] for more details. - /// - /// # Note - /// Allows unsorted matrices. - #[track_caller] - pub fn sparse_dense_matmul< - I: Index, - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - acc: MatMut<'_, E>, - lhs: SparseColMatRef<'_, I, LhsE>, - rhs: MatRef<'_, RhsE>, - alpha: Option, - beta: E, - parallelism: Parallelism, - ) { - assert!(all( - acc.nrows() == lhs.nrows(), - acc.ncols() == rhs.ncols(), - lhs.ncols() == rhs.nrows(), - )); - - let _ = parallelism; - let m = acc.nrows(); - let n = acc.ncols(); - let k = lhs.ncols(); - - let mut acc = acc; - - match alpha { - Some(alpha) => { - if alpha != E::faer_one() { - zipped!(acc.rb_mut()) - .for_each(|unzipped!(mut dst)| dst.write(dst.read().faer_mul(alpha))) - } - } - None => acc.fill_zero(), - } - - Size::with2(m, n, |m, n| { - Size::with(k, |k| { - let mut acc = constrained::MatMut::new(acc, m, n); - let lhs = constrained::sparse::SparseColMatRef::new(lhs, m, k); - let rhs = constrained::MatRef::new(rhs, k, n); - - for j in n.indices() { - for depth in k.indices() { - let rhs_kj = rhs.read(depth, j).canonicalize().faer_mul(beta); - for (i, lhs_ik) in zip( - lhs.row_indices_of_col(depth), - SliceGroup::<'_, LhsE>::new(lhs.values_of_col(depth)).into_ref_iter(), - ) { - acc.write( - i, - j, - acc.read(i, j) - .faer_add(lhs_ik.read().canonicalize().faer_mul(rhs_kj)), - ); - } - } - } - }); - }); - } - - /// Multiplies a dense matrix `lhs` by a sparse matrix `rhs`, and stores the result in - /// `acc`. See [`crate::mul::matmul`] for more details. - /// - /// # Note - /// Allows unsorted matrices. - #[track_caller] - pub fn dense_sparse_matmul< - I: Index, - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - acc: MatMut<'_, E>, - lhs: MatRef<'_, LhsE>, - rhs: SparseColMatRef<'_, I, RhsE>, - alpha: Option, - beta: E, - parallelism: Parallelism, - ) { - assert!(all( - acc.nrows() == lhs.nrows(), - acc.ncols() == rhs.ncols(), - lhs.ncols() == rhs.nrows(), - )); - - let _ = parallelism; - let m = acc.nrows(); - let n = acc.ncols(); - let k = lhs.ncols(); - - let mut acc = acc; - - match alpha { - Some(alpha) => { - if alpha != E::faer_one() { - zipped!(acc.rb_mut()) - .for_each(|unzipped!(mut dst)| dst.write(dst.read().faer_mul(alpha))) - } - } - None => acc.fill_zero(), - } - - Size::with2(m, n, |m, n| { - Size::with(k, |k| { - let mut acc = constrained::MatMut::new(acc, m, n); - let lhs = constrained::MatRef::new(lhs, m, k); - let rhs = constrained::sparse::SparseColMatRef::new(rhs, k, n); - - for i in m.indices() { - for j in n.indices() { - let mut acc_ij = E::faer_zero(); - for (depth, rhs_kj) in zip( - rhs.row_indices_of_col(j), - SliceGroup::<'_, RhsE>::new(rhs.values_of_col(j)).into_ref_iter(), - ) { - let lhs_ik = lhs.read(i, depth); - acc_ij = acc_ij.faer_add( - lhs_ik.canonicalize().faer_mul(rhs_kj.read().canonicalize()), - ); - } - - acc.write(i, j, acc.read(i, j).faer_add(beta.faer_mul(acc_ij))); - } - } - }); - }); - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for SparseColMatRef<'_, I, E> { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Sparse(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::SparseAccess for SparseColMatRef<'_, I, E> { - #[inline] - fn nnz(&self) -> usize { - self.compute_nnz() - } - - #[inline] - fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { - let mut triplets = Vec::new(); - for j in 0..self.ncols() { - for (i, val) in self - .row_indices_of_col(j) - .zip(SliceGroup::<'_, E>::new(self.values_of_col(j)).into_ref_iter()) - { - triplets.push((i, j, val.read())) - } - } - triplets - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for SparseRowMatRef<'_, I, E> { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Sparse(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::SparseAccess for SparseRowMatRef<'_, I, E> { - #[inline] - fn nnz(&self) -> usize { - self.compute_nnz() - } - - #[inline] - fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { - let mut triplets = Vec::new(); - for i in 0..self.nrows() { - for (j, val) in self - .col_indices_of_row(i) - .zip(SliceGroup::<'_, E>::new(self.values_of_row(i)).into_ref_iter()) - { - triplets.push((i, j, val.read())) - } - } - triplets - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for SparseColMatMut<'_, I, E> { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Sparse(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::SparseAccess for SparseColMatMut<'_, I, E> { - #[inline] - fn nnz(&self) -> usize { - self.compute_nnz() - } - - #[inline] - fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { - self.rb().fetch_triplets() - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for SparseColMat { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Sparse(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::SparseAccess for SparseColMat { - #[inline] - fn nnz(&self) -> usize { - self.compute_nnz() - } - - #[inline] - fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { - self.as_ref().fetch_triplets() - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for SparseRowMatMut<'_, I, E> { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Sparse(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::SparseAccess for SparseRowMatMut<'_, I, E> { - #[inline] - fn nnz(&self) -> usize { - self.compute_nnz() - } - - #[inline] - fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { - self.rb().fetch_triplets() - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::Matrix for SparseRowMat { - #[inline] - fn rows(&self) -> usize { - self.nrows() - } - #[inline] - fn cols(&self) -> usize { - self.ncols() - } - #[inline] - fn access(&self) -> matrixcompare_core::Access<'_, E> { - matrixcompare_core::Access::Sparse(self) - } -} - -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl matrixcompare_core::SparseAccess for SparseRowMat { - #[inline] - fn nnz(&self) -> usize { - self.compute_nnz() - } - - #[inline] - fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { - self.as_ref().fetch_triplets() - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::assert; - - #[test] - fn test_from_indices() { - let nrows = 5; - let ncols = 4; - - let indices = &[(0, 0), (1, 2), (0, 0), (1, 1), (0, 1), (3, 3), (3, 3usize)]; - let values = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0f64]; - - let triplets = &[ - (0, 0, 1.0), - (1, 2, 2.0), - (0, 0, 3.0), - (1, 1, 4.0), - (0, 1, 5.0), - (3, 3, 6.0), - (3, 3usize, 7.0), - ]; - - { - let mat = SymbolicSparseColMat::try_new_from_indices(nrows, ncols, indices); - assert!(mat.is_ok()); - - let (mat, order) = mat.unwrap(); - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); - assert!(mat.nnz_per_col() == None); - assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); - - let mat = - SparseColMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); - assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - - { - let mat = SparseColMat::try_new_from_triplets(nrows, ncols, triplets); - assert!(mat.is_ok()); - let mat = mat.unwrap(); - - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); - assert!(mat.nnz_per_col() == None); - assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); - assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - - { - let mat = SymbolicSparseRowMat::try_new_from_indices(nrows, ncols, indices); - assert!(mat.is_ok()); - - let (mat, order) = mat.unwrap(); - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); - assert!(mat.nnz_per_row() == None); - assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); - - let mat = - SparseRowMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); - assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - { - let mat = SparseRowMat::try_new_from_triplets(nrows, ncols, triplets); - assert!(mat.is_ok()); - - let mat = mat.unwrap(); - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); - assert!(mat.nnz_per_row() == None); - assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); - assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - } - - #[test] - fn test_from_nonnegative_indices() { - let nrows = 5; - let ncols = 4; - - let indices = &[ - (0, 0), - (1, 2), - (0, 0), - (1, 1), - (0, 1), - (-1, 2), - (-2, 1), - (-3, -4), - (3, 3), - (3, 3isize), - ]; - let values = &[ - 1.0, - 2.0, - 3.0, - 4.0, - 5.0, - f64::NAN, - f64::NAN, - f64::NAN, - 6.0, - 7.0f64, - ]; - - let triplets = &[ - (0, 0, 1.0), - (1, 2, 2.0), - (0, 0, 3.0), - (1, 1, 4.0), - (0, 1, 5.0), - (-1, 2, f64::NAN), - (-2, 1, f64::NAN), - (-3, -4, f64::NAN), - (3, 3, 6.0), - (3, 3isize, 7.0), - ]; - - { - let mat = SymbolicSparseColMat::::try_new_from_nonnegative_indices( - nrows, ncols, indices, - ); - assert!(mat.is_ok()); - - let (mat, order) = mat.unwrap(); - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); - assert!(mat.nnz_per_col() == None); - assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); - - let mat = - SparseColMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); - assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - - { - let mat = - SparseColMat::::try_new_from_nonnegative_triplets(nrows, ncols, triplets); - assert!(mat.is_ok()); - let mat = mat.unwrap(); - - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); - assert!(mat.nnz_per_col() == None); - assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); - assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - - { - let mat = SymbolicSparseRowMat::::try_new_from_nonnegative_indices( - nrows, ncols, indices, - ); - assert!(mat.is_ok()); - - let (mat, order) = mat.unwrap(); - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); - assert!(mat.nnz_per_row() == None); - assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); - - let mat = - SparseRowMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); - assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - { - let mat = - SparseRowMat::::try_new_from_nonnegative_triplets(nrows, ncols, triplets); - assert!(mat.is_ok()); - - let mat = mat.unwrap(); - assert!(mat.nrows() == nrows); - assert!(mat.ncols() == ncols); - assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); - assert!(mat.nnz_per_row() == None); - assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); - assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - { - let order = SymbolicSparseRowMat::::try_new_from_nonnegative_indices( - nrows, ncols, indices, - ) - .unwrap() - .1; - - let new_values = &mut [f64::NAN; 5]; - let mut mat = SparseRowMatMut::<'_, usize, f64>::new( - SymbolicSparseRowMatRef::new_checked( - nrows, - ncols, - &[0, 2, 4, 4, 5, 5], - None, - &[0, 1, 1, 2, 3], - ), - new_values, - ); - mat.fill_from_order_and_values(&order, values, FillMode::Replace); - - assert!(&*new_values == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); - } - } - - #[test] - fn test_from_indices_oob_row() { - let nrows = 5; - let ncols = 4; - - let indices = &[ - (0, 0), - (1, 2), - (0, 0), - (1, 1), - (0, 1), - (3, 3), - (3, 3), - (5, 3usize), - ]; - let err = SymbolicSparseColMat::try_new_from_indices(nrows, ncols, indices); - assert!(err.is_err()); - let err = err.unwrap_err(); - assert!(err == CreationError::OutOfBounds { row: 5, col: 3 }); - } - - #[test] - fn test_from_indices_oob_col() { - let nrows = 5; - let ncols = 4; - - let indices = &[ - (0, 0), - (1, 2), - (0, 0), - (1, 1), - (0, 1), - (3, 3), - (3, 3), - (2, 4usize), - ]; - let err = SymbolicSparseColMat::try_new_from_indices(nrows, ncols, indices); - assert!(err.is_err()); - let err = err.unwrap_err(); - assert!(err == CreationError::OutOfBounds { row: 2, col: 4 }); - } - - #[test] - fn test_add_intersecting() { - let lhs = SparseColMat::::try_new_from_triplets( - 5, - 4, - &[ - (1, 0, 1.0), - (2, 1, 2.0), - (3, 2, 3.0), - (0, 0, 4.0), - (1, 1, 5.0), - (2, 2, 6.0), - (3, 3, 7.0), - (2, 0, 8.0), - (3, 1, 9.0), - (4, 2, 10.0), - (0, 2, 11.0), - (1, 3, 12.0), - (4, 0, 13.0), - ], - ) - .unwrap(); - - let rhs = SparseColMat::::try_new_from_triplets( - 5, - 4, - &[ - (1, 0, 10.0), - (2, 1, 14.0), - (3, 2, 15.0), - (4, 3, 16.0), - (0, 1, 17.0), - (1, 2, 18.0), - (2, 3, 19.0), - (3, 0, 20.0), - (4, 1, 21.0), - (0, 3, 22.0), - ], - ) - .unwrap(); - - let sum = ops::add(lhs.as_ref(), rhs.as_ref()).unwrap(); - assert!(sum.compute_nnz() == lhs.compute_nnz() + rhs.compute_nnz() - 3); - - for j in 0..4 { - for i in 0..5 { - assert!(sum.row_indices_of_col_raw(j)[i] == i); - } - } - - for j in 0..4 { - for i in 0..5 { - assert!( - sum[(i, j)] == lhs.get(i, j).unwrap_or(&0.0) + rhs.get(i, j).unwrap_or(&0.0) - ); - } - } - } - - #[test] - fn test_add_disjoint() { - let lhs = SparseColMat::::try_new_from_triplets( - 5, - 4, - &[ - (0, 0, 1.0), - (1, 1, 2.0), - (2, 2, 3.0), - (3, 3, 4.0), - (2, 0, 5.0), - (3, 1, 6.0), - (4, 2, 7.0), - (0, 2, 8.0), - (1, 3, 9.0), - (4, 0, 10.0), - ], - ) - .unwrap(); - - let rhs = SparseColMat::::try_new_from_triplets( - 5, - 4, - &[ - (1, 0, 11.0), - (2, 1, 12.0), - (3, 2, 13.0), - (4, 3, 14.0), - (0, 1, 15.0), - (1, 2, 16.0), - (2, 3, 17.0), - (3, 0, 18.0), - (4, 1, 19.0), - (0, 3, 20.0), - ], - ) - .unwrap(); - - let sum = ops::add(lhs.as_ref(), rhs.as_ref()).unwrap(); - assert!(sum.compute_nnz() == lhs.compute_nnz() + rhs.compute_nnz()); - - for j in 0..4 { - for i in 0..5 { - assert!(sum.row_indices_of_col_raw(j)[i] == i); - } - } - - for j in 0..4 { - for i in 0..5 { - assert!( - sum[(i, j)] == lhs.get(i, j).unwrap_or(&0.0) + rhs.get(i, j).unwrap_or(&0.0) - ); - } - } - } -} diff --git a/faer-libs/faer-evd/Cargo.toml b/faer-libs/faer-evd/Cargo.toml deleted file mode 100644 index e664fac244e3c4b981fc31cd5479b9a1e16e93fe..0000000000000000000000000000000000000000 --- a/faer-libs/faer-evd/Cargo.toml +++ /dev/null @@ -1,61 +0,0 @@ -[package] -name = "faer-evd" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] - -[dependencies] -faer-entity = { workspace = true, default-features = false } - -faer-core = { version = "0.17.1", default-features = false, path = "../faer-core" } -faer-qr = { version = "0.17.1", default-features = false, path = "../faer-qr" } - -coe-rs = { workspace = true } -reborrow = { workspace = true } -pulp = { workspace = true, default-features = false } -dyn-stack = { workspace = true, default-features = false } - -num-traits = { workspace = true, default-features = false } -num-complex = { workspace = true, default-features = false } -bytemuck = { workspace = true } - -log = { workspace = true, optional = true, default-features = false } -libm = { workspace = true } -dbgf = "0.1.1" - -[dev-dependencies] -criterion = "0.5" -rand = "0.8.5" -nalgebra = "0.32.3" -assert_approx_eq = "1.1.0" - -[features] -default = ["std", "rayon"] -std = [ - "faer-core/std", - "faer-qr/std", - "pulp/std", -] -perf-warn = ["log", "faer-core/perf-warn"] -rayon = [ - "std", - "faer-core/rayon", - "faer-qr/rayon", -] -nightly = [ - "faer-core/nightly", - "faer-qr/nightly", - "pulp/nightly", -] - -[[bench]] -name = "bench" -harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer-evd/LICENSE.MIT b/faer-libs/faer-evd/LICENSE.MIT deleted file mode 100644 index b3e9659c8860f4d82899554c214b91d46760ea59..0000000000000000000000000000000000000000 --- a/faer-libs/faer-evd/LICENSE.MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 sarah - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/faer-libs/faer-evd/LICENSE.MPL2 b/faer-libs/faer-evd/LICENSE.MPL2 deleted file mode 100644 index ee6256cdb62a765749a71aae3abea32884301cd1..0000000000000000000000000000000000000000 --- a/faer-libs/faer-evd/LICENSE.MPL2 +++ /dev/null @@ -1,373 +0,0 @@ -Mozilla Public License Version 2.0 -================================== - -1. Definitions --------------- - -1.1. "Contributor" - means each individual or legal entity that creates, contributes to - the creation of, or owns Covered Software. - -1.2. "Contributor Version" - means the combination of the Contributions of others (if any) used - by a Contributor and that particular Contributor's Contribution. - -1.3. "Contribution" - means Covered Software of a particular Contributor. - -1.4. "Covered Software" - means Source Code Form to which the initial Contributor has attached - the notice in Exhibit A, the Executable Form of such Source Code - Form, and Modifications of such Source Code Form, in each case - including portions thereof. - -1.5. "Incompatible With Secondary Licenses" - means - - (a) that the initial Contributor has attached the notice described - in Exhibit B to the Covered Software; or - - (b) that the Covered Software was made available under the terms of - version 1.1 or earlier of the License, but not also under the - terms of a Secondary License. - -1.6. "Executable Form" - means any form of the work other than Source Code Form. - -1.7. "Larger Work" - means a work that combines Covered Software with other material, in - a separate file or files, that is not Covered Software. - -1.8. "License" - means this document. - -1.9. "Licensable" - means having the right to grant, to the maximum extent possible, - whether at the time of the initial grant or subsequently, any and - all of the rights conveyed by this License. - -1.10. "Modifications" - means any of the following: - - (a) any file in Source Code Form that results from an addition to, - deletion from, or modification of the contents of Covered - Software; or - - (b) any new file in Source Code Form that contains any Covered - Software. - -1.11. "Patent Claims" of a Contributor - means any patent claim(s), including without limitation, method, - process, and apparatus claims, in any patent Licensable by such - Contributor that would be infringed, but for the grant of the - License, by the making, using, selling, offering for sale, having - made, import, or transfer of either its Contributions or its - Contributor Version. - -1.12. "Secondary License" - means either the GNU General Public License, Version 2.0, the GNU - Lesser General Public License, Version 2.1, the GNU Affero General - Public License, Version 3.0, or any later versions of those - licenses. - -1.13. "Source Code Form" - means the form of the work preferred for making modifications. - -1.14. "You" (or "Your") - means an individual or a legal entity exercising rights under this - License. For legal entities, "You" includes any entity that - controls, is controlled by, or is under common control with You. For - purposes of this definition, "control" means (a) the power, direct - or indirect, to cause the direction or management of such entity, - whether by contract or otherwise, or (b) ownership of more than - fifty percent (50%) of the outstanding shares or beneficial - ownership of such entity. - -2. License Grants and Conditions --------------------------------- - -2.1. Grants - -Each Contributor hereby grants You a world-wide, royalty-free, -non-exclusive license: - -(a) under intellectual property rights (other than patent or trademark) - Licensable by such Contributor to use, reproduce, make available, - modify, display, perform, distribute, and otherwise exploit its - Contributions, either on an unmodified basis, with Modifications, or - as part of a Larger Work; and - -(b) under Patent Claims of such Contributor to make, use, sell, offer - for sale, have made, import, and otherwise transfer either its - Contributions or its Contributor Version. - -2.2. Effective Date - -The licenses granted in Section 2.1 with respect to any Contribution -become effective for each Contribution on the date the Contributor first -distributes such Contribution. - -2.3. Limitations on Grant Scope - -The licenses granted in this Section 2 are the only rights granted under -this License. No additional rights or licenses will be implied from the -distribution or licensing of Covered Software under this License. -Notwithstanding Section 2.1(b) above, no patent license is granted by a -Contributor: - -(a) for any code that a Contributor has removed from Covered Software; - or - -(b) for infringements caused by: (i) Your and any other third party's - modifications of Covered Software, or (ii) the combination of its - Contributions with other software (except as part of its Contributor - Version); or - -(c) under Patent Claims infringed by Covered Software in the absence of - its Contributions. - -This License does not grant any rights in the trademarks, service marks, -or logos of any Contributor (except as may be necessary to comply with -the notice requirements in Section 3.4). - -2.4. Subsequent Licenses - -No Contributor makes additional grants as a result of Your choice to -distribute the Covered Software under a subsequent version of this -License (see Section 10.2) or under the terms of a Secondary License (if -permitted under the terms of Section 3.3). - -2.5. Representation - -Each Contributor represents that the Contributor believes its -Contributions are its original creation(s) or it has sufficient rights -to grant the rights to its Contributions conveyed by this License. - -2.6. Fair Use - -This License is not intended to limit any rights You have under -applicable copyright doctrines of fair use, fair dealing, or other -equivalents. - -2.7. Conditions - -Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted -in Section 2.1. - -3. Responsibilities -------------------- - -3.1. Distribution of Source Form - -All distribution of Covered Software in Source Code Form, including any -Modifications that You create or to which You contribute, must be under -the terms of this License. You must inform recipients that the Source -Code Form of the Covered Software is governed by the terms of this -License, and how they can obtain a copy of this License. You may not -attempt to alter or restrict the recipients' rights in the Source Code -Form. - -3.2. Distribution of Executable Form - -If You distribute Covered Software in Executable Form then: - -(a) such Covered Software must also be made available in Source Code - Form, as described in Section 3.1, and You must inform recipients of - the Executable Form how they can obtain a copy of such Source Code - Form by reasonable means in a timely manner, at a charge no more - than the cost of distribution to the recipient; and - -(b) You may distribute such Executable Form under the terms of this - License, or sublicense it under different terms, provided that the - license for the Executable Form does not attempt to limit or alter - the recipients' rights in the Source Code Form under this License. - -3.3. Distribution of a Larger Work - -You may create and distribute a Larger Work under terms of Your choice, -provided that You also comply with the requirements of this License for -the Covered Software. If the Larger Work is a combination of Covered -Software with a work governed by one or more Secondary Licenses, and the -Covered Software is not Incompatible With Secondary Licenses, this -License permits You to additionally distribute such Covered Software -under the terms of such Secondary License(s), so that the recipient of -the Larger Work may, at their option, further distribute the Covered -Software under the terms of either this License or such Secondary -License(s). - -3.4. Notices - -You may not remove or alter the substance of any license notices -(including copyright notices, patent notices, disclaimers of warranty, -or limitations of liability) contained within the Source Code Form of -the Covered Software, except that You may alter any license notices to -the extent required to remedy known factual inaccuracies. - -3.5. Application of Additional Terms - -You may choose to offer, and to charge a fee for, warranty, support, -indemnity or liability obligations to one or more recipients of Covered -Software. However, You may do so only on Your own behalf, and not on -behalf of any Contributor. You must make it absolutely clear that any -such warranty, support, indemnity, or liability obligation is offered by -You alone, and You hereby agree to indemnify every Contributor for any -liability incurred by such Contributor as a result of warranty, support, -indemnity or liability terms You offer. You may include additional -disclaimers of warranty and limitations of liability specific to any -jurisdiction. - -4. Inability to Comply Due to Statute or Regulation ---------------------------------------------------- - -If it is impossible for You to comply with any of the terms of this -License with respect to some or all of the Covered Software due to -statute, judicial order, or regulation then You must: (a) comply with -the terms of this License to the maximum extent possible; and (b) -describe the limitations and the code they affect. Such description must -be placed in a text file included with all distributions of the Covered -Software under this License. Except to the extent prohibited by statute -or regulation, such description must be sufficiently detailed for a -recipient of ordinary skill to be able to understand it. - -5. Termination --------------- - -5.1. The rights granted under this License will terminate automatically -if You fail to comply with any of its terms. However, if You become -compliant, then the rights granted under this License from a particular -Contributor are reinstated (a) provisionally, unless and until such -Contributor explicitly and finally terminates Your grants, and (b) on an -ongoing basis, if such Contributor fails to notify You of the -non-compliance by some reasonable means prior to 60 days after You have -come back into compliance. Moreover, Your grants from a particular -Contributor are reinstated on an ongoing basis if such Contributor -notifies You of the non-compliance by some reasonable means, this is the -first time You have received notice of non-compliance with this License -from such Contributor, and You become compliant prior to 30 days after -Your receipt of the notice. - -5.2. If You initiate litigation against any entity by asserting a patent -infringement claim (excluding declaratory judgment actions, -counter-claims, and cross-claims) alleging that a Contributor Version -directly or indirectly infringes any patent, then the rights granted to -You by any and all Contributors for the Covered Software under Section -2.1 of this License shall terminate. - -5.3. In the event of termination under Sections 5.1 or 5.2 above, all -end user license agreements (excluding distributors and resellers) which -have been validly granted by You or Your distributors under this License -prior to termination shall survive termination. - -************************************************************************ -* * -* 6. Disclaimer of Warranty * -* ------------------------- * -* * -* Covered Software is provided under this License on an "as is" * -* basis, without warranty of any kind, either expressed, implied, or * -* statutory, including, without limitation, warranties that the * -* Covered Software is free of defects, merchantable, fit for a * -* particular purpose or non-infringing. The entire risk as to the * -* quality and performance of the Covered Software is with You. * -* Should any Covered Software prove defective in any respect, You * -* (not any Contributor) assume the cost of any necessary servicing, * -* repair, or correction. This disclaimer of warranty constitutes an * -* essential part of this License. No use of any Covered Software is * -* authorized under this License except under this disclaimer. * -* * -************************************************************************ - -************************************************************************ -* * -* 7. Limitation of Liability * -* -------------------------- * -* * -* Under no circumstances and under no legal theory, whether tort * -* (including negligence), contract, or otherwise, shall any * -* Contributor, or anyone who distributes Covered Software as * -* permitted above, be liable to You for any direct, indirect, * -* special, incidental, or consequential damages of any character * -* including, without limitation, damages for lost profits, loss of * -* goodwill, work stoppage, computer failure or malfunction, or any * -* and all other commercial damages or losses, even if such party * -* shall have been informed of the possibility of such damages. This * -* limitation of liability shall not apply to liability for death or * -* personal injury resulting from such party's negligence to the * -* extent applicable law prohibits such limitation. Some * -* jurisdictions do not allow the exclusion or limitation of * -* incidental or consequential damages, so this exclusion and * -* limitation may not apply to You. * -* * -************************************************************************ - -8. Litigation -------------- - -Any litigation relating to this License may be brought only in the -courts of a jurisdiction where the defendant maintains its principal -place of business and such litigation shall be governed by laws of that -jurisdiction, without reference to its conflict-of-law provisions. -Nothing in this Section shall prevent a party's ability to bring -cross-claims or counter-claims. - -9. Miscellaneous ----------------- - -This License represents the complete agreement concerning the subject -matter hereof. If any provision of this License is held to be -unenforceable, such provision shall be reformed only to the extent -necessary to make it enforceable. Any law or regulation which provides -that the language of a contract shall be construed against the drafter -shall not be used to construe this License against a Contributor. - -10. Versions of the License ---------------------------- - -10.1. New Versions - -Mozilla Foundation is the license steward. Except as provided in Section -10.3, no one other than the license steward has the right to modify or -publish new versions of this License. Each version will be given a -distinguishing version number. - -10.2. Effect of New Versions - -You may distribute the Covered Software under the terms of the version -of the License under which You originally received the Covered Software, -or under the terms of any subsequent version published by the license -steward. - -10.3. Modified Versions - -If you create software not governed by this License, and you want to -create a new license for such software, you may create and use a -modified version of this License if you rename the license and remove -any references to the name of the license steward (except to note that -such modified license differs from this License). - -10.4. Distributing Source Code Form that is Incompatible With Secondary -Licenses - -If You choose to distribute Source Code Form that is Incompatible With -Secondary Licenses under the terms of this version of the License, the -notice described in Exhibit B of this License must be attached. - -Exhibit A - Source Code Form License Notice -------------------------------------------- - - This Source Code Form is subject to the terms of the Mozilla Public - License, v. 2.0. If a copy of the MPL was not distributed with this - file, You can obtain one at https://mozilla.org/MPL/2.0/. - -If it is not possible or desirable to put the notice in a particular -file, then You may include the notice in a location (such as a LICENSE -file in a relevant directory) where a recipient would be likely to look -for such a notice. - -You may add additional accurate notices of copyright ownership. - -Exhibit B - "Incompatible With Secondary Licenses" Notice ---------------------------------------------------------- - - This Source Code Form is "Incompatible With Secondary Licenses", as - defined by the Mozilla Public License, v. 2.0. diff --git a/faer-libs/faer-evd/benches/bench.rs b/faer-libs/faer-evd/benches/bench.rs deleted file mode 100644 index 203fc3adc39126108b7f89ef9f1790f79475037a..0000000000000000000000000000000000000000 --- a/faer-libs/faer-evd/benches/bench.rs +++ /dev/null @@ -1,561 +0,0 @@ -use criterion::*; -use dyn_stack::{GlobalPodBuffer, PodStack}; -use faer_core::{c32, c64, unzipped, zipped, ComplexField, Mat, Parallelism, RealField}; -use faer_evd::{ - tridiag::{tridiagonalize_in_place, tridiagonalize_in_place_req}, - tridiag_real_evd::{compute_tridiag_real_evd, compute_tridiag_real_evd_req}, -}; -use reborrow::*; -use std::any::type_name; - -fn random() -> E { - if coe::is_same::() { - coe::coerce_static(rand::random::()) - } else if coe::is_same::() { - coe::coerce_static(rand::random::()) - } else if coe::is_same::() { - coe::coerce_static(c32::new(rand::random(), rand::random())) - } else if coe::is_same::() { - coe::coerce_static(c64::new(rand::random(), rand::random())) - } else if coe::is_same::() { - coe::coerce_static(num_complex::Complex32::new(rand::random(), rand::random())) - } else if coe::is_same::() { - coe::coerce_static(num_complex::Complex64::new(rand::random(), rand::random())) - } else { - panic!() - } -} - -fn epsilon() -> E { - if coe::is_same::() { - coe::coerce_static(f32::EPSILON) - } else if coe::is_same::() { - coe::coerce_static(f64::EPSILON) - } else { - panic!() - } -} - -fn min_positive() -> E { - if coe::is_same::() { - coe::coerce_static(f32::MIN_POSITIVE) - } else if coe::is_same::() { - coe::coerce_static(f64::MIN_POSITIVE) - } else { - panic!() - } -} - -fn tridiagonalization(criterion: &mut Criterion) { - for n in [32, 64, 128, 256, 512, 1024, 2000, 4000] { - let mut mat = Mat::from_fn(n, n, |_, _| random::()); - let adjoint = mat.adjoint().to_owned(); - - zipped!(mat.as_mut(), adjoint.as_ref()) - .for_each(|unzipped!(mut x, y)| x.write(x.read().faer_add(y.read()))); - - let mut trid = mat.clone(); - let mut tau_left = Mat::zeros(n - 1, 1); - - { - let parallelism = Parallelism::None; - let mut mem = - GlobalPodBuffer::new(tridiagonalize_in_place_req::(n, parallelism).unwrap()); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("tridiag-st-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - zipped!(trid.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - tridiagonalize_in_place( - trid.as_mut(), - tau_left.as_mut().col_mut(0).as_2d_mut(), - parallelism, - stack.rb_mut(), - ); - }); - }, - ); - } - { - let parallelism = Parallelism::Rayon(0); - let mut mem = - GlobalPodBuffer::new(tridiagonalize_in_place_req::(n, parallelism).unwrap()); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("tridiag-mt-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - zipped!(trid.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - tridiagonalize_in_place( - trid.as_mut(), - tau_left.as_mut().col_mut(0).as_2d_mut(), - parallelism, - stack.rb_mut(), - ); - }); - }, - ); - } - } -} - -fn tridiagonal_evd(criterion: &mut Criterion) { - for n in [32, 64, 128, 256, 512, 1024, 4096] { - let diag = (0..n).map(|_| random::()).collect::>(); - let offdiag = (0..n - 1).map(|_| random::()).collect::>(); - let mut u = Mat::::zeros(n, n); - - let parallelism = Parallelism::None; - let mut mem = - GlobalPodBuffer::new(compute_tridiag_real_evd_req::(n, parallelism).unwrap()); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("tridiag-evd-st-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - let mut diag = diag.clone(); - let mut offdiag = offdiag.clone(); - compute_tridiag_real_evd( - &mut diag, - &mut offdiag, - u.as_mut(), - epsilon(), - min_positive(), - parallelism, - stack.rb_mut(), - ); - }); - }, - ); - let parallelism = Parallelism::Rayon(0); - let mut mem = - GlobalPodBuffer::new(compute_tridiag_real_evd_req::(n, parallelism).unwrap()); - let mut stack = PodStack::new(&mut mem); - criterion.bench_function( - &format!("tridiag-evd-mt-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - let mut diag = diag.clone(); - let mut offdiag = offdiag.clone(); - compute_tridiag_real_evd( - &mut diag, - &mut offdiag, - u.as_mut(), - epsilon(), - min_positive(), - parallelism, - stack.rb_mut(), - ); - }); - }, - ); - } -} - -fn evd(criterion: &mut Criterion) { - for n in [4, 6, 8, 10, 12, 16, 24, 32, 64, 128, 256, 512, 1024, 4096] { - let mut mat = Mat::from_fn(n, n, |_, _| random::()); - let adjoint = mat.adjoint().to_owned(); - - zipped!(mat.as_mut(), adjoint.as_ref()) - .for_each(|unzipped!(mut x, y)| x.write(x.read().faer_add(y.read()))); - - let mut s = Mat::zeros(n, n); - let mut u = Mat::zeros(n, n); - - { - let parallelism = Parallelism::None; - let mut mem = GlobalPodBuffer::new( - faer_evd::compute_hermitian_evd_req::( - n, - faer_evd::ComputeVectors::Yes, - parallelism, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("sym-evd-st-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - faer_evd::compute_hermitian_evd( - mat.as_ref(), - s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(), - Some(u.as_mut()), - parallelism, - stack.rb_mut(), - Default::default(), - ); - }); - }, - ); - } - { - let parallelism = Parallelism::Rayon(0); - let mut mem = GlobalPodBuffer::new( - faer_evd::compute_hermitian_evd_req::( - n, - faer_evd::ComputeVectors::Yes, - parallelism, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("sym-evd-mt-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - faer_evd::compute_hermitian_evd( - mat.as_ref(), - s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(), - Some(u.as_mut()), - parallelism, - stack.rb_mut(), - Default::default(), - ); - }); - }, - ); - } - } -} - -fn evd_nalgebra(criterion: &mut Criterion) { - for n in [4, 6, 8, 10, 12, 16, 24, 32, 64, 128, 256, 512, 1024, 4096] { - let mat = nalgebra::DMatrix::::from_fn(n, n, |_, _| random::()); - criterion.bench_function( - &format!("sym-evd-nalgebra-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - mat.clone().symmetric_eigen(); - }); - }, - ); - } -} - -fn cplx_schur(criterion: &mut Criterion) { - for n in [32, 64, 128, 256, 512, 1024, 4096] { - let mat = Mat::from_fn(n, n, |_, _| random::()); - let mut t = mat.clone(); - let mut z = mat.clone(); - let mut w = Mat::zeros(n, 1); - - { - let parallelism = Parallelism::None; - let mut mem = GlobalPodBuffer::new( - faer_evd::hessenberg_cplx_evd::multishift_qr_req::( - n, - n, - true, - true, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function(&format!("schur-st-{}-{}", type_name::(), n), |bencher| { - bencher.iter(|| { - zipped!(t.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - zipped!(z.as_mut()).for_each(|unzipped!(mut x)| x.write(E::faer_zero())); - zipped!(z.as_mut().diagonal_mut().column_vector_mut().as_2d_mut()) - .for_each(|unzipped!(mut x)| x.write(E::faer_one())); - - faer_evd::hessenberg_cplx_evd::multishift_qr( - true, - t.as_mut(), - Some(z.as_mut()), - w.as_mut(), - 0, - n, - epsilon(), - min_positive(), - parallelism, - stack.rb_mut(), - Default::default(), - ); - }); - }); - } - { - let parallelism = Parallelism::Rayon(0); - let mut mem = GlobalPodBuffer::new( - faer_evd::hessenberg_cplx_evd::multishift_qr_req::( - n, - n, - true, - true, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function(&format!("schur-mt-{}-{}", type_name::(), n), |bencher| { - bencher.iter(|| { - zipped!(t.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - zipped!(z.as_mut()).for_each(|unzipped!(mut x)| x.write(E::faer_zero())); - zipped!(z.as_mut().diagonal_mut().column_vector_mut().as_2d_mut()) - .for_each(|unzipped!(mut x)| x.write(E::faer_one())); - - faer_evd::hessenberg_cplx_evd::multishift_qr( - true, - t.as_mut(), - Some(z.as_mut()), - w.as_mut(), - 0, - n, - epsilon(), - min_positive(), - parallelism, - stack.rb_mut(), - Default::default(), - ); - }); - }); - } - } -} - -fn real_schur(criterion: &mut Criterion) { - for n in [32, 64, 128, 256, 512, 1024, 4096] { - let mat = Mat::from_fn(n, n, |_, _| random::()); - let mut t = mat.clone(); - let mut z = mat.clone(); - let mut w_re = Mat::zeros(n, 1); - let mut w_im = Mat::zeros(n, 1); - - { - let parallelism = Parallelism::None; - let mut mem = GlobalPodBuffer::new( - faer_evd::hessenberg_real_evd::multishift_qr_req::( - n, - n, - true, - true, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function(&format!("schur-st-{}-{}", type_name::(), n), |bencher| { - bencher.iter(|| { - zipped!(t.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - zipped!(z.as_mut()).for_each(|unzipped!(mut x)| x.write(E::faer_zero())); - zipped!(z.as_mut().diagonal_mut().column_vector_mut().as_2d_mut()) - .for_each(|unzipped!(mut x)| x.write(E::faer_one())); - - faer_evd::hessenberg_real_evd::multishift_qr( - true, - t.as_mut(), - Some(z.as_mut()), - w_re.as_mut(), - w_im.as_mut(), - 0, - n, - epsilon(), - min_positive(), - parallelism, - stack.rb_mut(), - Default::default(), - ); - }); - }); - } - { - let parallelism = Parallelism::Rayon(0); - let mut mem = GlobalPodBuffer::new( - faer_evd::hessenberg_real_evd::multishift_qr_req::( - n, - n, - true, - true, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function(&format!("schur-mt-{}-{}", type_name::(), n), |bencher| { - bencher.iter(|| { - zipped!(t.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - zipped!(z.as_mut()).for_each(|unzipped!(mut x)| x.write(E::faer_zero())); - zipped!(z.as_mut().diagonal_mut().column_vector_mut().as_2d_mut()) - .for_each(|unzipped!(mut x)| x.write(E::faer_one())); - - faer_evd::hessenberg_real_evd::multishift_qr( - true, - t.as_mut(), - Some(z.as_mut()), - w_re.as_mut(), - w_im.as_mut(), - 0, - n, - epsilon(), - min_positive(), - parallelism, - stack.rb_mut(), - Default::default(), - ); - }); - }); - } - } -} - -fn hessenberg(criterion: &mut Criterion) { - for n in [32, 64, 128, 256, 512, 1024, 4096] { - let mat = Mat::from_fn(n, n, |_, _| random::()); - let mut t = mat.clone(); - let bs = faer_qr::no_pivoting::compute::recommended_blocksize::(n - 1, n - 1); - let mut householder = Mat::zeros(n - 1, bs); - - let parallelism = Parallelism::Rayon(0); - let mut mem = GlobalPodBuffer::new( - faer_evd::hessenberg::make_hessenberg_in_place_req::(n, bs, Parallelism::None) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("hessenberg-mt-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - zipped!(t.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - - faer_evd::hessenberg::make_hessenberg_in_place( - t.as_mut(), - householder.as_mut(), - parallelism, - stack.rb_mut(), - ); - }); - }, - ); - } -} - -fn unsym_evd(criterion: &mut Criterion) { - for n in [32, 64, 128, 256, 512, 1024, 2048, 4096] { - let mat = Mat::from_fn(n, n, |_, _| random::()); - let mut z = mat.clone(); - let mut w_re = Mat::zeros(n, 1); - let mut w_im = Mat::zeros(n, 1); - - { - let parallelism = Parallelism::None; - let mut mem = GlobalPodBuffer::new( - faer_evd::compute_evd_req::( - n, - faer_evd::ComputeVectors::Yes, - Parallelism::None, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("unsym-evd-st-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - faer_evd::compute_evd_real( - mat.as_ref(), - w_re.as_mut(), - w_im.as_mut(), - Some(z.as_mut()), - parallelism, - stack.rb_mut(), - Default::default(), - ); - }); - }, - ); - } - { - let parallelism = Parallelism::Rayon(0); - let mut mem = GlobalPodBuffer::new( - faer_evd::compute_evd_req::( - n, - faer_evd::ComputeVectors::Yes, - parallelism, - Default::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - criterion.bench_function( - &format!("unsym-evd-mt-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| { - let t = std::time::Instant::now(); - faer_evd::compute_evd_real( - mat.as_ref(), - w_re.as_mut(), - w_im.as_mut(), - Some(z.as_mut()), - parallelism, - stack.rb_mut(), - Default::default(), - ); - dbg!(t.elapsed()); - }); - }, - ); - } - } -} - -fn cplx_schur_nalgebra(criterion: &mut Criterion) { - type Cplx64 = num_complex::Complex64; - for n in [32, 64, 128, 256, 512, 1024, 4096] { - let mat = nalgebra::DMatrix::::from_fn(n, n, |_, _| random::()); - criterion.bench_function( - &format!("schur-nalgebra-{}-{}", type_name::(), n), - |bencher| { - bencher.iter(|| mat.clone().schur()); - }, - ); - } -} - -criterion_group!( - benches, - tridiagonalization::, - tridiagonalization::, - tridiagonalization::, - tridiagonal_evd::, - tridiagonal_evd::, - evd::, - evd::, - evd::, - cplx_schur::, - real_schur::, - hessenberg::, - unsym_evd::, - cplx_schur_nalgebra, - evd_nalgebra, -); -criterion_main!(benches); diff --git a/faer-libs/faer-evd/katex-header.html b/faer-libs/faer-evd/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer-evd/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer-lu/Cargo.toml b/faer-libs/faer-lu/Cargo.toml deleted file mode 100644 index bd245ff0bcf14c3e8927e13ba342f4a0cf5b70e7..0000000000000000000000000000000000000000 --- a/faer-libs/faer-lu/Cargo.toml +++ /dev/null @@ -1,55 +0,0 @@ -[package] -name = "faer-lu" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] - -[dependencies] -paste = "1.0.14" - -faer-entity = { workspace = true, default-features = false } - -faer-core = { version = "0.17.1", default-features = false, path = "../faer-core" } - -coe-rs = { workspace = true } -reborrow = { workspace = true } -pulp = { workspace = true, default-features = false } -dyn-stack = { workspace = true, default-features = false } - -num-traits = { workspace = true, default-features = false } -num-complex = { workspace = true, default-features = false } -bytemuck = { workspace = true } - -rayon = { workspace = true, optional = true } -log = { workspace = true, optional = true, default-features = false } -hurdles = "1.0.1" - -[features] -default = ["std", "rayon"] -std = [ - "faer-core/std", - "pulp/std", -] -perf-warn = ["log", "faer-core/perf-warn"] -rayon = ["std", "faer-core/rayon", "dep:rayon"] -nightly = ["faer-core/nightly", "pulp/nightly"] - -[dev-dependencies] -criterion = "0.5" -rand = "0.8.5" -nalgebra = "0.32.3" -assert_approx_eq = "1.1.0" -rayon = { workspace = true } -core_affinity = "0.8" - -[[bench]] -name = "bench" -harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer-lu/LICENSE b/faer-libs/faer-lu/LICENSE deleted file mode 100644 index b3e9659c8860f4d82899554c214b91d46760ea59..0000000000000000000000000000000000000000 --- a/faer-libs/faer-lu/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 sarah - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/faer-libs/faer-lu/benches/bench.rs b/faer-libs/faer-lu/benches/bench.rs deleted file mode 100644 index a786359dd3f8d0e730fb46c7f540165ca87bd9c5..0000000000000000000000000000000000000000 --- a/faer-libs/faer-lu/benches/bench.rs +++ /dev/null @@ -1,168 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use dyn_stack::{GlobalPodBuffer, PodStack}; -use faer_core::{Mat, Parallelism}; -use faer_lu::{ - full_pivoting::compute::FullPivLuComputeParams, - partial_pivoting::compute::PartialPivLuComputeParams, -}; -use rand::random; -use reborrow::*; - -use faer_lu::{full_pivoting, partial_pivoting}; - -pub fn lu(c: &mut Criterion) { - for n in [4, 6, 8, 12, 32, 64, 128, 256, 512, 1023, 1024, 2048, 4096] { - let partial_params = PartialPivLuComputeParams::default(); - let full_params = FullPivLuComputeParams::default(); - - let mat = nalgebra::DMatrix::::from_fn(n, n, |_, _| random::()); - { - c.bench_function(&format!("nalg-st-plu-{n}"), |b| { - b.iter(|| { - mat.clone().lu(); - }) - }); - c.bench_function(&format!("nalg-st-flu-{n}"), |b| { - b.iter(|| { - mat.clone().full_piv_lu(); - }) - }); - } - - let mat = Mat::from_fn(n, n, |_, _| random::()); - { - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - let mut copy = mat.clone(); - - let mut mem = GlobalPodBuffer::new( - partial_pivoting::compute::lu_in_place_req::( - n, - n, - Parallelism::None, - partial_params, - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-st-plu-{n}"), |b| { - b.iter(|| { - copy.as_mut().copy_from(mat.as_ref()); - partial_pivoting::compute::lu_in_place( - copy.as_mut(), - &mut perm, - &mut perm_inv, - Parallelism::None, - stack.rb_mut(), - partial_params, - ); - }) - }); - } - { - let mut copy = mat.clone(); - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - - let mut mem = GlobalPodBuffer::new( - partial_pivoting::compute::lu_in_place_req::( - n, - n, - Parallelism::Rayon(rayon::current_num_threads()), - partial_params, - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-mt-plu-{n}"), |b| { - b.iter(|| { - copy.as_mut().copy_from(mat.as_ref()); - partial_pivoting::compute::lu_in_place( - copy.as_mut(), - &mut perm, - &mut perm_inv, - Parallelism::Rayon(0), - stack.rb_mut(), - partial_params, - ); - }) - }); - } - - { - let mut copy = mat.clone(); - let mut row_perm = vec![0usize; n]; - let mut row_perm_inv = vec![0; n]; - let mut col_perm = vec![0; n]; - let mut col_perm_inv = vec![0; n]; - - let mut mem = GlobalPodBuffer::new( - full_pivoting::compute::lu_in_place_req::( - n, - n, - Parallelism::None, - full_params, - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - c.bench_function(&format!("faer-st-flu-{n}"), |b| { - b.iter(|| { - copy.as_mut().copy_from(mat.as_ref()); - full_pivoting::compute::lu_in_place( - copy.as_mut(), - &mut row_perm, - &mut row_perm_inv, - &mut col_perm, - &mut col_perm_inv, - Parallelism::None, - stack.rb_mut(), - full_params, - ); - }) - }); - } - - { - let mut copy = mat.clone(); - let mut row_perm = vec![0usize; n]; - let mut row_perm_inv = vec![0; n]; - let mut col_perm = vec![0; n]; - let mut col_perm_inv = vec![0; n]; - - let mut mem = GlobalPodBuffer::new( - full_pivoting::compute::lu_in_place_req::( - n, - n, - Parallelism::None, - full_params, - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-mt-flu-{n}"), |b| { - b.iter(|| { - copy.as_mut().copy_from(mat.as_ref()); - full_pivoting::compute::lu_in_place( - copy.as_mut(), - &mut row_perm, - &mut row_perm_inv, - &mut col_perm, - &mut col_perm_inv, - Parallelism::Rayon(rayon::current_num_threads()), - stack.rb_mut(), - full_params, - ); - }) - }); - } - } -} - -criterion_group!( - name = benches; - config = Criterion::default(); - targets = lu -); -criterion_main!(benches); diff --git a/faer-libs/faer-lu/katex-header.html b/faer-libs/faer-lu/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer-lu/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer-lu/src/lib.rs b/faer-libs/faer-lu/src/lib.rs deleted file mode 100644 index 2a52ba060ea45ff445e2db5df9381bcf94559ee9..0000000000000000000000000000000000000000 --- a/faer-libs/faer-lu/src/lib.rs +++ /dev/null @@ -1,6 +0,0 @@ -#![allow(clippy::type_complexity)] -#![allow(clippy::too_many_arguments)] -#![cfg_attr(not(feature = "std"), no_std)] - -pub mod full_pivoting; -pub mod partial_pivoting; diff --git a/faer-libs/faer-qr/Cargo.toml b/faer-libs/faer-qr/Cargo.toml deleted file mode 100644 index ca43e85f34d9010adef167d20ac84f689a8a7302..0000000000000000000000000000000000000000 --- a/faer-libs/faer-qr/Cargo.toml +++ /dev/null @@ -1,53 +0,0 @@ -[package] -name = "faer-qr" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] - -[dependencies] -faer-entity = { workspace = true, default-features = false } - -faer-core = { version = "0.17.1", default-features = false, path = "../faer-core" } - -coe-rs = { workspace = true } -reborrow = { workspace = true } -pulp = { workspace = true, default-features = false } -dyn-stack = { workspace = true, default-features = false } - -num-traits = { workspace = true, default-features = false } -num-complex = { workspace = true, default-features = false } -bytemuck = { workspace = true } - -rayon = { workspace = true, optional = true } -log = { workspace = true, optional = true, default-features = false } - -[features] -default = ["std", "rayon"] -std = [ - "faer-core/std", - "pulp/std", -] -perf-warn = ["log", "faer-core/perf-warn"] -rayon = ["std", "faer-core/rayon", "dep:rayon"] -nightly = ["faer-core/nightly", "pulp/nightly"] - -[dev-dependencies] -criterion = "0.5" -rand = "0.8.5" -nalgebra = "0.32.3" -assert_approx_eq = "1.1.0" -rayon = "1.8" -dbgf = "0.1.1" -matrixcompare = "0.3" - -[[bench]] -name = "bench" -harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer-qr/LICENSE b/faer-libs/faer-qr/LICENSE deleted file mode 100644 index b3e9659c8860f4d82899554c214b91d46760ea59..0000000000000000000000000000000000000000 --- a/faer-libs/faer-qr/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 sarah - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/faer-libs/faer-qr/benches/bench.rs b/faer-libs/faer-qr/benches/bench.rs deleted file mode 100644 index dda03bab84cce0f24248b93c754437f31274aa5e..0000000000000000000000000000000000000000 --- a/faer-libs/faer-qr/benches/bench.rs +++ /dev/null @@ -1,189 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use dyn_stack::*; -use faer_core::{c64, unzipped, Mat, Parallelism}; -use faer_qr::no_pivoting::compute::recommended_blocksize; -use rand::random; - -pub fn qr(c: &mut Criterion) { - use faer_qr::*; - - for (m, n) in [ - (6, 6), - (8, 8), - (10, 10), - (12, 12), - (24, 24), - (32, 32), - (64, 64), - (128, 128), - (256, 256), - (512, 512), - (1024, 1024), - (10000, 128), - (10000, 1024), - (2048, 2048), - (4096, 4096), - (8192, 8192), - ] { - let mat = nalgebra::DMatrix::::from_fn(m, n, |_, _| random::()); - { - c.bench_function(&format!("nalg-st-qr-{m}x{n}"), |b| { - b.iter(|| { - mat.clone().qr(); - }) - }); - c.bench_function(&format!("nalg-st-colqr-{m}x{n}"), |b| { - b.iter(|| { - mat.clone().col_piv_qr(); - }) - }); - } - - let mat = Mat::from_fn(m, n, |_, _| random::()); - - { - let mut copy = mat.clone(); - let blocksize = no_pivoting::compute::recommended_blocksize::(m, n); - let mut householder = Mat::from_fn(blocksize, n, |_, _| random::()); - - let mut mem = GlobalPodBuffer::new(StackReq::new::(1024 * 1024 * 1024)); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-st-qr-{m}x{n}"), |b| { - b.iter(|| { - faer_core::zipped!(copy.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - no_pivoting::compute::qr_in_place( - copy.as_mut(), - householder.as_mut(), - Parallelism::None, - stack.rb_mut(), - Default::default(), - ); - }) - }); - } - { - let mut copy = mat.clone(); - let blocksize = no_pivoting::compute::recommended_blocksize::(m, n); - let mut householder = Mat::from_fn(blocksize, n, |_, _| random::()); - - let mut mem = GlobalPodBuffer::new(StackReq::new::(1024 * 1024 * 1024)); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-mt-qr-{m}x{n}"), |b| { - b.iter(|| { - faer_core::zipped!(copy.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - no_pivoting::compute::qr_in_place( - copy.as_mut(), - householder.as_mut(), - Parallelism::Rayon(0), - stack.rb_mut(), - Default::default(), - ); - }) - }); - } - - { - let mut copy = mat.clone(); - let blocksize = recommended_blocksize::(m, n); - let mut householder = Mat::from_fn(blocksize, n, |_, _| random::()); - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - c.bench_function(&format!("faer-st-colqr-{m}x{n}"), |b| { - b.iter(|| { - faer_core::zipped!(copy.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - col_pivoting::compute::qr_in_place( - copy.as_mut(), - householder.as_mut(), - &mut perm, - &mut perm_inv, - Parallelism::None, - PodStack::new(&mut []), - Default::default(), - ); - }) - }); - } - - { - let mut copy = mat.clone(); - let blocksize = recommended_blocksize::(m, n); - let mut householder = Mat::from_fn(blocksize, n, |_, _| random::()); - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - c.bench_function(&format!("faer-mt-colqr-{m}x{n}"), |b| { - b.iter(|| { - faer_core::zipped!(copy.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - col_pivoting::compute::qr_in_place( - copy.as_mut(), - householder.as_mut(), - &mut perm, - &mut perm_inv, - Parallelism::Rayon(0), - PodStack::new(&mut []), - Default::default(), - ); - }) - }); - } - - let mat = Mat::from_fn(m, n, |_, _| c64::new(random(), random())); - { - let mut copy = mat.clone(); - let blocksize = recommended_blocksize::(m, n); - let mut householder = Mat::from_fn(blocksize, n, |_, _| c64::new(random(), random())); - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - c.bench_function(&format!("faer-st-cplx-colqr-{m}x{n}"), |b| { - b.iter(|| { - faer_core::zipped!(copy.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - col_pivoting::compute::qr_in_place( - copy.as_mut(), - householder.as_mut(), - &mut perm, - &mut perm_inv, - Parallelism::None, - PodStack::new(&mut []), - Default::default(), - ); - }) - }); - } - - { - let mut copy = mat.clone(); - let blocksize = recommended_blocksize::(m, n); - let mut householder = Mat::from_fn(blocksize, n, |_, _| c64::new(random(), random())); - let mut perm = vec![0usize; n]; - let mut perm_inv = vec![0; n]; - c.bench_function(&format!("faer-mt-cplx-colqr-{m}x{n}"), |b| { - b.iter(|| { - faer_core::zipped!(copy.as_mut(), mat.as_ref()) - .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); - col_pivoting::compute::qr_in_place( - copy.as_mut(), - householder.as_mut(), - &mut perm, - &mut perm_inv, - Parallelism::Rayon(0), - PodStack::new(&mut []), - Default::default(), - ); - }) - }); - } - } - - let _c = c; -} - -criterion_group!( - name = benches; - config = Criterion::default(); - targets = qr -); -criterion_main!(benches); diff --git a/faer-libs/faer-qr/katex-header.html b/faer-libs/faer-qr/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer-qr/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer-sparse/Cargo.toml b/faer-libs/faer-sparse/Cargo.toml deleted file mode 100644 index f84289e353b4bf0f59ae72be271f2dad193b68c0..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/Cargo.toml +++ /dev/null @@ -1,79 +0,0 @@ -[package] -name = "faer-sparse" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] - -[dependencies] -faer-entity = { workspace = true, default-features = false } - -faer-core = { version = "0.17.1", default-features = false, path = "../faer-core" } -faer-cholesky = { version = "0.17.1", default-features = false, path = "../faer-cholesky" } -faer-qr = { version = "0.17.1", default-features = false, path = "../faer-qr" } -faer-lu = { version = "0.17.1", default-features = false, path = "../faer-lu" } - -reborrow = { workspace = true } -dyn-stack = { workspace = true } -log = { workspace = true, optional = true } -rayon = { workspace = true } -bytemuck = { workspace = true } -pulp = { workspace = true, default-features = false} -coe-rs = { workspace = true } -dbgf = "0.1.1" - -[features] -default = ["std", "rayon"] -std = [ - "faer-core/std", - "faer-cholesky/std", - "faer-qr/std", - "faer-lu/std", - "pulp/std", -] -rayon = [ - "std", - "faer-core/rayon", - "faer-cholesky/rayon", - "faer-qr/rayon", - "faer-lu/rayon", -] -nightly = [ - "faer-core/nightly", - "faer-cholesky/nightly", - "faer-qr/nightly", - "faer-lu/nightly", - "pulp/nightly", -] -perf-warn = ["log"] - -[dev-dependencies] -criterion.workspace = true -rand = "0.8.5" -assert_approx_eq = "1.1.0" -dbgf = "0.1.1" -paste = "1.0.14" -matrix-market-rs = "0.1.3" -amd = "0.2.2" -colamd = { version = "0.1", features = ["i64"] } -num-complex.workspace = true -regex = "1.10.3" - -[[bench]] -name = "cholesky" -harness = false - -[[bench]] -name = "qr" -harness = false - -[[bench]] -name = "lu" -harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer-sparse/LICENSE b/faer-libs/faer-sparse/LICENSE deleted file mode 100644 index b3e9659c8860f4d82899554c214b91d46760ea59..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 sarah - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/faer-libs/faer-sparse/benches/cholesky.rs b/faer-libs/faer-sparse/benches/cholesky.rs deleted file mode 100644 index 49431630a9a4b0bf77d7e7bb5cff58c60e23ddea..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/benches/cholesky.rs +++ /dev/null @@ -1,213 +0,0 @@ -#![allow(non_snake_case)] - -use core::iter::zip; -use dyn_stack::{GlobalPodBuffer, PodStack}; -use faer_core::{ - sparse::{SparseColMatRef, SymbolicSparseColMatRef}, - Parallelism, Side, -}; -use faer_sparse::{cholesky::*, Index, SupernodalThreshold}; -use matrix_market_rs::MtxData; -use reborrow::*; -use regex::Regex; -use std::{ - env::args, - ffi::OsStr, - time::{Duration, Instant}, -}; - -fn load_mtx(data: MtxData) -> (usize, Vec, Vec, Vec) { - let I = I::truncate; - - let MtxData::Sparse([nrows, _], coo_indices, coo_values, _) = data else { - panic!() - }; - - let n = nrows; - let mut col_counts = vec![I(0); n]; - let mut col_ptr = vec![I(0); n + 1]; - - for &[_, j] in &coo_indices { - col_counts[j] += I(1); - } - - for i in 0..n { - col_ptr[i + 1] = col_ptr[i] + col_counts[i]; - } - let nnz = col_ptr[n].zx(); - - let mut row_ind = vec![I(0); nnz]; - let mut values = vec![0.0; nnz]; - - col_counts.copy_from_slice(&col_ptr[..n]); - - for (&[i, j], &val) in zip(&coo_indices, &coo_values) { - row_ind[col_counts[j].zx()] = I(i); - values[col_counts[j].zx()] = val; - col_counts[j] += I(1); - } - - (n, col_ptr, row_ind, values) -} - -fn time(mut f: impl FnMut()) -> Duration { - let now = Instant::now(); - f(); - now.elapsed() -} - -fn timeit(mut f: impl FnMut(), time_limit: Duration) -> Duration { - let mut n_iters: u32 = 1; - loop { - let t = time(|| { - for _ in 0..n_iters { - f(); - } - }); - - if t >= time_limit || n_iters > 1_000_000_000 { - return t / n_iters; - } - - n_iters = 10 * (time_limit.as_secs_f64() / t.as_secs_f64()) as u32; - } -} - -fn main() { - let regexes = args() - .skip(1) - .filter(|x| !x.trim().starts_with('-')) - .map(|s| Regex::new(&s).unwrap()) - .collect::>(); - - let matches = |s: &str| regexes.is_empty() || regexes.iter().any(|regex| regex.is_match(s)); - - let methods = [ - ( - "simplicial ", - SupernodalThreshold::FORCE_SIMPLICIAL, - Parallelism::None, - ), - ( - "supernodal ", - SupernodalThreshold::FORCE_SUPERNODAL, - Parallelism::None, - ), - ( - "supernodal-bk", - SupernodalThreshold::FORCE_SUPERNODAL, - Parallelism::None, - ), - ]; - - let time_limit = Duration::from_secs_f64(0.1); - - type I = usize; - let mut files = Vec::new(); - - for file in std::fs::read_dir("./bench_data").unwrap() { - let file = file.unwrap(); - if file.path().extension() == Some(OsStr::new("mtx")) { - let name = file - .path() - .file_name() - .unwrap() - .to_string_lossy() - .into_owned(); - files.push((name.strip_suffix(".mtx").unwrap().to_string(), Side::Upper)) - } - } - - files.sort_by(|(f0, _), (f1, _)| str::cmp(f0, f1)); - - for (file, side) in files { - if !matches(&file) { - continue; - } - if file.starts_with("chain") { - continue; - } - let path = "./bench_data/".to_string() + &*file + ".mtx"; - let Ok(data) = MtxData::::from_file(path) else { - continue; - }; - - let (n, col_ptr, row_ind, values) = load_mtx::(data); - let A = SparseColMatRef::<_, f64>::new( - SymbolicSparseColMatRef::new_checked(n, n, &col_ptr, None, &row_ind), - &*values, - ); - - let mut auto = ""; - time(|| { - let symbolic_cholesky = - &factorize_symbolic_cholesky(A.symbolic(), side, Default::default()).unwrap(); - auto = match symbolic_cholesky.raw() { - SymbolicCholeskyRaw::Simplicial(_) => "simplicial", - SymbolicCholeskyRaw::Supernodal(_) => "supernodal", - }; - println!("picked {auto} method for {file}"); - }); - - let times = methods.map(|(method, supernodal_flop_ratio_threshold, parallelism)| { - let symbolic = factorize_symbolic_cholesky( - A.symbolic(), - side, - CholeskySymbolicParams { - supernodal_flop_ratio_threshold, - ..Default::default() - }, - ) - .unwrap(); - - let mut mem = GlobalPodBuffer::new( - symbolic - .factorize_numeric_ldlt_req::(false, parallelism) - .unwrap(), - ); - let mut L_values = vec![0.0f64; symbolic.len_values()]; - let mut subdiag = vec![0.0f64; n]; - let mut fwd = vec![0; n]; - let mut inv = vec![0; n]; - let mut L_values = &mut *L_values; - - let f = || { - if method == "supernodal-bk" { - symbolic.factorize_numeric_intranode_bunch_kaufman( - L_values.rb_mut(), - &mut *subdiag, - &mut fwd, - &mut inv, - A, - side, - Default::default(), - parallelism, - PodStack::new(&mut mem), - ); - } else { - symbolic.factorize_numeric_ldlt( - L_values.rb_mut(), - A, - side, - Default::default(), - parallelism, - PodStack::new(&mut mem), - ); - } - }; - - let time = timeit(f, time_limit); - println!("{method}: {time:>35?}"); - (method, time) - }); - let best = times[..2].iter().min_by_key(|(_, time)| time).unwrap(); - let worst = times[..2].iter().max_by_key(|(_, time)| time).unwrap(); - - if best.0.trim_end() == auto { - println!("good: {}", worst.1.as_secs_f64() / best.1.as_secs_f64()); - } else { - println!("bad: {}", best.1.as_secs_f64() / worst.1.as_secs_f64()); - } - println!(); - } -} diff --git a/faer-libs/faer-sparse/benches/lu.rs b/faer-libs/faer-sparse/benches/lu.rs deleted file mode 100644 index bc1037d27463078b60373a26f8b77a315c5e08e1..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/benches/lu.rs +++ /dev/null @@ -1,269 +0,0 @@ -#![allow(non_snake_case)] - -use core::iter::zip; -use dyn_stack::{GlobalPodBuffer, PodStack, StackReq}; -use faer_core::{ - permutation::PermutationRef, - sparse::{SparseColMatRef, SymbolicSparseColMatRef}, -}; -use faer_entity::*; -use faer_sparse::{lu::supernodal::*, qr::col_etree, Index}; -use matrix_market_rs::MtxData; -use regex::Regex; -use std::{ - env::args, - ffi::OsStr, - time::{Duration, Instant}, -}; - -fn load_mtx(data: MtxData) -> (usize, usize, Vec, Vec, Vec) { - let I = I::truncate; - - let MtxData::Sparse([nrows, ncols], coo_indices, coo_values, _) = data else { - panic!() - }; - - let m = nrows; - let n = ncols; - let mut col_counts = vec![I(0); n]; - let mut col_ptr = vec![I(0); n + 1]; - - for &[i, j] in &coo_indices { - col_counts[j] += I(1); - if i != j { - col_counts[i] += I(1); - } - } - - for i in 0..n { - col_ptr[i + 1] = col_ptr[i] + col_counts[i]; - } - let nnz = col_ptr[n].zx(); - - let mut row_ind = vec![I(0); nnz]; - let mut values = vec![0.0; nnz]; - - col_counts.copy_from_slice(&col_ptr[..n]); - - for (&[i, j], &val) in zip(&coo_indices, &coo_values) { - if i == j { - values[col_counts[j].zx()] = 2.0 * val; - } else { - values[col_counts[i].zx()] = val; - values[col_counts[j].zx()] = val; - } - - row_ind[col_counts[j].zx()] = I(i); - col_counts[j] += I(1); - - if i != j { - row_ind[col_counts[i].zx()] = I(j); - col_counts[i] += I(1); - } - } - - (m, n, col_ptr, row_ind, values) -} - -fn time(mut f: impl FnMut()) -> Duration { - let now = Instant::now(); - f(); - now.elapsed() -} - -fn timeit(mut f: impl FnMut(), time_limit: Duration) -> Duration { - let mut n_iters: u32 = 1; - loop { - let t = time(|| { - for _ in 0..n_iters { - f(); - } - }); - - if t >= time_limit || n_iters > 1_000_000_000 { - return t / n_iters; - } - - n_iters = 2 * Ord::max((time_limit.as_secs_f64() / t.as_secs_f64()) as u32, n_iters); - } -} - -fn main() { - let regexes = args() - .skip(1) - .filter(|x| !x.trim().starts_with('-')) - .map(|s| Regex::new(&s).unwrap()) - .collect::>(); - - let matches = |s: &str| regexes.is_empty() || regexes.iter().any(|regex| regex.is_match(s)); - let time_limit = Duration::from_secs_f64(1.0); - - type I = usize; - - let mut files = Vec::new(); - - for file in std::fs::read_dir("./bench_data").unwrap() { - let file = file.unwrap(); - if file.path().extension() == Some(OsStr::new("mtx")) { - let name = file - .path() - .file_name() - .unwrap() - .to_string_lossy() - .into_owned(); - files.push(name.strip_suffix(".mtx").unwrap().to_string()) - } - } - files.sort(); - - for file in files { - if !matches(&file) { - continue; - } - let path = "./bench_data/".to_string() + &*file + ".mtx"; - let Ok(data) = MtxData::::from_file(path) else { - continue; - }; - - let (m, n, col_ptr, row_ind, values) = load_mtx::(data); - let nnz = values.len(); - println!("{file}: {m}×{n}, {nnz} non-zeros"); - - let A = SparseColMatRef::<'_, I, f64>::new( - SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_ind), - &values, - ); - - let mut p = vec![0usize; n].into_boxed_slice(); - let mut p_inv = vec![0usize; n].into_boxed_slice(); - { - let mut mem = GlobalPodBuffer::new(StackReq::new::(4 * 1024 * 1024 * 1024)); - - faer_sparse::colamd::order( - &mut p, - &mut p_inv, - *A, - Default::default(), - PodStack::new(&mut mem), - ) - .unwrap(); - } - let fill_col_perm = PermutationRef::<'_, I, Symbolic>::new_checked(&p, &p_inv); - - let mut etree = vec![0usize; n]; - let mut min_col = vec![0usize; m]; - let mut col_counts = vec![0usize; n]; - let etree = { - let mut mem = GlobalPodBuffer::new(StackReq::new::(4 * 1024 * 1024 * 1024)); - let nnz = A.compute_nnz(); - let mut new_col_ptrs = vec![0usize; m + 1]; - let mut new_row_ind = vec![0usize; nnz]; - let mut new_values = vec![0.0; nnz]; - let AT = faer_sparse::adjoint::( - &mut new_col_ptrs, - &mut new_row_ind, - &mut new_values, - A, - PodStack::new(&mut mem), - ); - - let mut post = vec![0usize; n]; - - let etree = col_etree(*A, Some(fill_col_perm), &mut etree, PodStack::new(&mut mem)); - faer_sparse::qr::postorder(&mut post, etree, PodStack::new(&mut mem)); - faer_sparse::qr::column_counts_aat( - &mut col_counts, - &mut min_col, - *AT, - Some(fill_col_perm), - etree, - &post, - PodStack::new(&mut mem), - ); - etree - }; - - let mut mem = GlobalPodBuffer::new(StackReq::new::(4 * 1024 * 1024 * 1024)); - - let symbolic = faer_sparse::lu::supernodal::factorize_supernodal_symbolic_lu::( - *A, - Some(fill_col_perm), - &min_col, - etree, - &col_counts, - PodStack::new(&mut mem), - faer_sparse::SymbolicSupernodalParams { - relax: Some(&[(4, 1.0), (16, 0.8), (48, 0.1), (usize::MAX, 0.05)]), - }, - ) - .unwrap(); - - let mut row_perm = vec![0; n]; - let mut row_perm_inv = vec![0; n]; - let mut col_perm = vec![0; n]; - let mut col_perm_inv = vec![0; n]; - - { - let mut lu = SupernodalLu::::new(); - let mut op = |parallelism| { - let _ = faer_sparse::lu::supernodal::factorize_supernodal_numeric_lu::( - &mut row_perm, - &mut row_perm_inv, - &mut lu, - A, - A, - fill_col_perm.cast(), - &symbolic, - parallelism, - PodStack::new(&mut mem), - ); - }; - - let warmup = time(|| op(faer_core::Parallelism::None)).as_secs_f64(); - println!("Multifrontal warmup : {warmup:>12.9}s"); - - let single_thread = - timeit(|| op(faer_core::Parallelism::None), time_limit).as_secs_f64(); - println!("Multifrontal single thread : {single_thread:>12.9}s"); - - let multithread = - timeit(|| op(faer_core::Parallelism::Rayon(0)), time_limit).as_secs_f64(); - println!("Multifrontal multithread : {multithread:>12.9}s"); - } - { - let mut lu = faer_sparse::superlu::supernodal::SupernodalLu::::new(); - let mut work = vec![]; - - let mut op = |parallelism| { - let _ = - faer_sparse::superlu::supernodal::factorize_supernodal_numeric_lu::( - &mut row_perm, - &mut row_perm_inv, - &mut col_perm, - &mut col_perm_inv, - &mut lu, - &mut work, - A, - fill_col_perm.cast(), - etree, - parallelism, - PodStack::new(&mut mem), - Default::default(), - ) - .unwrap(); - }; - - let warmup = time(|| op(faer_core::Parallelism::None)).as_secs_f64(); - println!("SuperLU warmup : {warmup:>12.9}s"); - - let single_thread = - timeit(|| op(faer_core::Parallelism::None), time_limit).as_secs_f64(); - println!("SuperLU single thread : {single_thread:>12.9}s"); - - let multithread = - timeit(|| op(faer_core::Parallelism::Rayon(0)), time_limit).as_secs_f64(); - println!("SuperLU multithread : {multithread:>12.9}s"); - } - println!(); - } -} diff --git a/faer-libs/faer-sparse/benches/qr.rs b/faer-libs/faer-sparse/benches/qr.rs deleted file mode 100644 index 10da4a7baef0d3500105d745d4668107615303de..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/benches/qr.rs +++ /dev/null @@ -1,246 +0,0 @@ -#![allow(non_snake_case)] - -use core::iter::zip; -use dyn_stack::{GlobalPodBuffer, PodStack, StackReq}; -use faer_core::{ - permutation::PermutationRef, - sparse::{SparseColMatRef, SymbolicSparseColMatRef}, -}; -use faer_entity::Symbolic; -use faer_sparse::{ - adjoint, - qr::{ - col_etree, column_counts_aat, postorder, - supernodal::{factorize_supernodal_numeric_qr, factorize_supernodal_symbolic_qr}, - }, - Index, -}; -use matrix_market_rs::MtxData; -use reborrow::*; -use regex::Regex; -use std::{ - env::args, - ffi::OsStr, - time::{Duration, Instant}, -}; - -fn load_mtx(data: MtxData) -> (usize, usize, Vec, Vec, Vec) { - let I = I::truncate; - - let MtxData::Sparse([nrows, ncols], coo_indices, coo_values, _) = data else { - panic!() - }; - - let m = nrows; - let n = ncols; - let mut col_counts = vec![I(0); n]; - let mut col_ptr = vec![I(0); n + 1]; - - for &[_, j] in &coo_indices { - col_counts[j] += I(1); - } - - for i in 0..n { - col_ptr[i + 1] = col_ptr[i] + col_counts[i]; - } - let nnz = col_ptr[n].zx(); - - let mut row_ind = vec![I(0); nnz]; - let mut values = vec![0.0; nnz]; - - col_counts.copy_from_slice(&col_ptr[..n]); - - for (&[i, j], &val) in zip(&coo_indices, &coo_values) { - row_ind[col_counts[j].zx()] = I(i); - values[col_counts[j].zx()] = val; - col_counts[j] += I(1); - } - - (m, n, col_ptr, row_ind, values) -} - -fn time(mut f: impl FnMut()) -> Duration { - let now = Instant::now(); - f(); - now.elapsed() -} - -fn timeit(mut f: impl FnMut(), time_limit: Duration) -> Duration { - let mut n_iters: u32 = 1; - loop { - let t = time(|| { - for _ in 0..n_iters { - f(); - } - }); - - if t >= time_limit || n_iters > 1_000_000_000 { - return t / n_iters; - } - - n_iters = 2 * Ord::max((time_limit.as_secs_f64() / t.as_secs_f64()) as u32, n_iters); - } -} - -fn main() { - let regexes = args() - .skip(1) - .filter(|x| !x.trim().starts_with('-')) - .map(|s| Regex::new(&s).unwrap()) - .collect::>(); - - let matches = |s: &str| regexes.is_empty() || regexes.iter().any(|regex| regex.is_match(s)); - let time_limit = Duration::from_secs_f64(1.0); - - type I = usize; - let I = I::truncate; - - let mut files = Vec::new(); - - for file in std::fs::read_dir("./bench_data/qr").unwrap() { - let file = file.unwrap(); - if file.path().extension() == Some(OsStr::new("mtx")) { - let name = file - .path() - .file_name() - .unwrap() - .to_string_lossy() - .into_owned(); - files.push(name.strip_suffix(".mtx").unwrap().to_string()) - } - } - files.sort(); - - let mut mem = GlobalPodBuffer::new(StackReq::new::(1024 * 1024 * 1024)); - - for file in files { - if !matches(&file) { - continue; - } - let path = "./bench_data/qr/".to_string() + &*file + ".mtx"; - let Ok(data) = MtxData::::from_file(path) else { - continue; - }; - - let (m, n, col_ptr, row_ind, values) = load_mtx::(data); - let nnz = row_ind.len(); - - let A = SparseColMatRef::<'_, I, f64>::new( - SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_ind), - &values, - ); - - let zero = I(0); - let mut new_col_ptrs = vec![zero; m + 1]; - let mut new_row_ind = vec![zero; nnz]; - let mut new_values = vec![0.0; nnz]; - - let AT = adjoint::( - &mut new_col_ptrs, - &mut new_row_ind, - &mut new_values, - A, - PodStack::new(&mut mem), - ) - .into_const(); - - let mut p = vec![0usize; n].into_boxed_slice(); - let mut p_inv = vec![0usize; n].into_boxed_slice(); - - faer_sparse::colamd::order( - &mut p, - &mut p_inv, - *A, - Default::default(), - PodStack::new(&mut mem), - ) - .unwrap(); - - let p = PermutationRef::<'_, I, Symbolic>::new_checked(&p, &p_inv); - - let mut etree = vec![zero; n]; - let mut post = vec![zero; n]; - let mut col_counts = vec![zero; n]; - let mut min_row = vec![zero; m]; - - let etree = col_etree(*A, Some(p), &mut etree, PodStack::new(&mut mem)); - postorder(&mut post, etree, PodStack::new(&mut mem)); - - column_counts_aat( - &mut col_counts, - &mut min_row, - *AT, - Some(p), - etree, - &post, - PodStack::new(&mut mem), - ); - - let min_col = min_row; - - let symbolic = factorize_supernodal_symbolic_qr::( - *A, - Some(p), - min_col, - etree, - &col_counts, - PodStack::new(&mut mem), - Default::default(), - ) - .unwrap(); - - let householder_nnz = symbolic.householder().len_householder_row_indices(); - let mut row_indices_in_panel = vec![zero; householder_nnz]; - - dbg!(&file); - dbg!(m, n, A.compute_nnz()); - let mut L_values = vec![0.0; symbolic.r_adjoint().len_values()]; - let mut householder_values = vec![0.0; symbolic.householder().len_householder_values()]; - let mut tau_values = vec![0.0; symbolic.householder().len_tau_values()]; - - let mut tau_blocksize = vec![I(0); n]; - let mut householder_nrows = vec![I(0); n]; - let mut householder_ncols = vec![I(0); n]; - - let multithread = timeit( - || { - factorize_supernodal_numeric_qr::( - &mut row_indices_in_panel, - &mut tau_blocksize, - &mut householder_nrows, - &mut householder_ncols, - &mut L_values, - &mut householder_values, - &mut tau_values, - AT, - Some(p.cast()), - &symbolic, - faer_core::Parallelism::Rayon(0), - PodStack::new(&mut mem), - ); - }, - time_limit, - ); - dbg!(multithread); - let single_thread = timeit( - || { - factorize_supernodal_numeric_qr::( - &mut row_indices_in_panel, - &mut tau_blocksize, - &mut householder_nrows, - &mut householder_ncols, - &mut L_values, - &mut householder_values, - &mut tau_values, - AT, - Some(p.cast()), - &symbolic, - faer_core::Parallelism::None, - PodStack::new(&mut mem), - ); - }, - time_limit, - ); - dbg!(single_thread); - } -} diff --git a/faer-libs/faer-sparse/katex-header.html b/faer-libs/faer-sparse/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer-sparse/src/ghost.rs b/faer-libs/faer-sparse/src/ghost.rs deleted file mode 100644 index f9cd1fd82fe8ca9b617f28c6437f21a080b34846..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/src/ghost.rs +++ /dev/null @@ -1,39 +0,0 @@ -pub use faer_core::constrained::{group_helpers::*, permutation::*, sparse::*, *}; -use faer_core::permutation::Index; - -pub const NONE_BYTE: u8 = u8::MAX; - -#[inline] -pub fn with_size(n: usize, f: impl FnOnce(Size<'_>) -> R) -> R { - Size::with(n, f) -} - -#[inline] -pub fn fill_zero<'n, 'a, I: Index>(slice: &'a mut [I], size: Size<'n>) -> &'a mut [Idx<'n, I>] { - let len = slice.len(); - if len > 0 { - assert!(*size > 0); - } - unsafe { - core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len); - &mut *(slice as *mut _ as *mut _) - } -} - -#[inline] -pub fn fill_none<'n, 'a, I: Index>( - slice: &'a mut [I::Signed], - size: Size<'n>, -) -> &'a mut [MaybeIdx<'n, I>] { - let _ = size; - let len = slice.len(); - unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), NONE_BYTE, len) }; - unsafe { &mut *(slice as *mut _ as *mut _) } -} - -#[inline] -pub fn copy_slice<'n, 'a, I: Index>(dst: &'a mut [I], src: &[Idx<'n, I>]) -> &'a mut [Idx<'n, I>] { - let dst: &mut [Idx<'_, I>] = unsafe { &mut *(dst as *mut _ as *mut _) }; - dst.copy_from_slice(src); - dst -} diff --git a/faer-libs/faer-sparse/src/mem.rs b/faer-libs/faer-sparse/src/mem.rs deleted file mode 100644 index efdcc9f500a833eecd8e714ed02ca2dd47ec1d8d..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/src/mem.rs +++ /dev/null @@ -1,15 +0,0 @@ -use faer_core::permutation::SignedIndex; - -pub const NONE_BYTE: u8 = 0xFF; -pub const NONE: usize = faer_core::sparse::repeat_byte(NONE_BYTE); - -#[inline] -pub fn fill_none(slice: &mut [I]) { - let len = slice.len(); - unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), NONE_BYTE, len) } -} -#[inline] -pub fn fill_zero(slice: &mut [I]) { - let len = slice.len(); - unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len) } -} diff --git a/faer-libs/faer-sparse/src/superlu.rs b/faer-libs/faer-sparse/src/superlu.rs deleted file mode 100644 index 36aa7702f114365c876b837c948b645009642435..0000000000000000000000000000000000000000 --- a/faer-libs/faer-sparse/src/superlu.rs +++ /dev/null @@ -1,2568 +0,0 @@ -// Copyright (c) 2003, The Regents of the University of California, through -// Lawrence Berkeley National Laboratory (subject to receipt of any required -// approvals from U.S. Dept. of Energy) -// -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are met: -// -// (1) Redistributions of source code must retain the above copyright notice, -// this list of conditions and the following disclaimer. -// (2) Redistributions in binary form must reproduce the above copyright notice, -// this list of conditions and the following disclaimer in the documentation -// and/or other materials provided with the distribution. -// (3) Neither the name of Lawrence Berkeley National Laboratory, U.S. Dept. of -// Energy nor the names of its contributors may be used to endorse or promote -// products derived from this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS -// IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, -// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR -// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, -// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, -// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR -// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING -// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -use crate::{ - ghost::{Array, Idx, MaybeIdx}, - mem::{self}, - Index, -}; -use reborrow::*; - -#[inline] -fn relaxed_supernodes<'n, I: Index>( - etree: &Array<'n, MaybeIdx<'n, I>>, - postorder: &Array<'n, Idx<'n, I>>, - postorder_inv: &Array<'n, Idx<'n, I>>, - relax_columns: usize, - descendants: &mut Array<'n, I>, - relax_end: &mut Array<'n, I::Signed>, -) { - let I = I::truncate; - - mem::fill_none(relax_end.as_mut()); - mem::fill_zero(descendants.as_mut()); - - let N = etree.len(); - let etree = |i: Idx<'n, usize>| match etree[postorder[i].zx()].idx() { - Some(parent) => MaybeIdx::from_index(postorder_inv[parent.zx()]), - None => MaybeIdx::none(), - }; - - for j in N.indices() { - if let Some(parent) = etree(j.zx()).idx() { - let parent = parent.zx(); - descendants[parent] = descendants[parent] + descendants[j] + I(1); - } - } - - let mut j = 0; - while j < *N { - let mut parent = etree(N.check(j).zx()).sx(); - let snode_start = j; - while let Some(parent_) = parent.idx() { - if descendants[parent_] >= I(relax_columns) { - break; - } - j = *parent_; - parent = etree(N.check(j).zx()).sx(); - } - relax_end[N.check(snode_start)] = I(j).to_signed(); - - j += 1; - - while j < *N && descendants[N.check(j)] != I(0) { - j += 1; - } - } -} - -pub mod supernodal { - use super::*; - use crate::{cholesky::simplicial::EliminationTreeRef, mem::NONE, FaerError}; - use core::iter::zip; - use dyn_stack::{PodStack, SizeOverflow, StackReq}; - use faer_core::{ - assert, - constrained::Size, - group_helpers::{SliceGroup, SliceGroupMut}, - mul, - permutation::{PermutationRef, SignedIndex}, - solve, - sparse::SparseColMatRef, - temp_mat_req, temp_mat_uninit, Conj, MatMut, Parallelism, - }; - use faer_entity::*; - - pub struct SupernodalLu { - nrows: usize, - ncols: usize, - nsupernodes: usize, - xsup: alloc::vec::Vec, - supno: alloc::vec::Vec, - lsub: alloc::vec::Vec, - xlusup: alloc::vec::Vec, - xlsub: alloc::vec::Vec, - usub: alloc::vec::Vec, - xusub: alloc::vec::Vec, - - lusup: GroupFor>, - ucol: GroupFor>, - } - - unsafe impl Send for SupernodalLu {} - unsafe impl Sync for SupernodalLu {} - - impl Default for SupernodalLu { - #[inline] - fn default() -> Self { - Self::new() - } - } - - impl SupernodalLu { - #[inline] - pub fn new() -> Self { - Self { - nrows: 0, - ncols: 0, - nsupernodes: 0, - xsup: alloc::vec::Vec::new(), - supno: alloc::vec::Vec::new(), - lsub: alloc::vec::Vec::new(), - xlusup: alloc::vec::Vec::new(), - xlsub: alloc::vec::Vec::new(), - usub: alloc::vec::Vec::new(), - xusub: alloc::vec::Vec::new(), - lusup: E::faer_map(E::UNIT, |()| alloc::vec::Vec::::new()), - ucol: E::faer_map(E::UNIT, |()| alloc::vec::Vec::::new()), - } - } - - #[inline] - pub fn nrows(&self) -> usize { - self.nrows - } - - #[inline] - pub fn ncols(&self) -> usize { - self.ncols - } - - #[inline] - pub fn n_supernodes(&self) -> usize { - self.nsupernodes - } - - #[track_caller] - pub fn solve_transpose_in_place_with_conj( - &self, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, - conj_lhs: Conj, - rhs: MatMut<'_, E>, - parallelism: Parallelism, - stack: PodStack<'_>, - ) where - E: ComplexField, - { - assert!(self.nrows() == self.ncols()); - assert!(self.nrows() == rhs.nrows()); - let mut X = rhs; - let (mut temp, mut stack) = temp_mat_uninit::(self.nrows(), X.ncols(), stack); - faer_core::permutation::permute_rows(temp.rb_mut(), X.rb(), col_perm); - self.u_solve_transpose_in_place_with_conj( - conj_lhs, - temp.rb_mut(), - parallelism, - stack.rb_mut(), - ); - self.l_solve_transpose_in_place_with_conj( - conj_lhs, - temp.rb_mut(), - parallelism, - stack.rb_mut(), - ); - faer_core::permutation::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse()); - } - - #[track_caller] - pub fn solve_in_place_with_conj( - &self, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, - conj_lhs: Conj, - rhs: MatMut<'_, E>, - parallelism: Parallelism, - stack: PodStack<'_>, - ) where - E: ComplexField, - { - assert!(self.nrows() == self.ncols()); - assert!(self.nrows() == rhs.nrows()); - let mut X = rhs; - let (mut temp, mut stack) = temp_mat_uninit::(self.nrows(), X.ncols(), stack); - faer_core::permutation::permute_rows(temp.rb_mut(), X.rb(), row_perm); - self.l_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), parallelism, stack.rb_mut()); - self.u_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), parallelism, stack.rb_mut()); - faer_core::permutation::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse()); - } - - #[track_caller] - pub fn l_solve_transpose_in_place_with_conj( - &self, - conj_lhs: Conj, - rhs: MatMut<'_, E>, - parallelism: Parallelism, - stack: PodStack<'_>, - ) where - E: ComplexField, - { - assert!(self.nrows() == self.ncols()); - assert!(self.nrows() == rhs.nrows()); - - let (mut work, _) = faer_core::temp_mat_uninit::(rhs.nrows(), rhs.ncols(), stack); - - let mut X = rhs; - let nrhs = X.ncols(); - - let nzval = - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&self.lusup), |lusup| { - &**lusup - })); - let nzval_colptr = &*self.xlusup; - let rowind = &*self.lsub; - let rowind_colptr = &*self.xlsub; - let sup_to_col = &*self.xsup; - - for k in (0..self.n_supernodes()).rev() { - let fsupc = sup_to_col[k].zx(); - let istart = rowind_colptr[fsupc].zx(); - let nsupr = rowind_colptr[fsupc + 1].zx() - istart; - let nsupc = sup_to_col[k + 1].zx() - fsupc; - let nrow = nsupr - nsupc; - - let luptr = nzval_colptr[fsupc].zx(); - let lda = nzval_colptr[fsupc + 1].zx() - luptr; - - let mut work = work.rb_mut().subrows_mut(0, nrow); - let A = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - nzval.subslice(luptr..nzval.len()).into_inner(), - nsupr, - nsupc, - lda, - ); - - let A_top = A.subrows(0, nsupc); - let A_bot = A.subrows(nsupc, nrow); - - for j in 0..nrhs { - let mut iptr = istart + nsupc; - for i in 0..nrow { - let irow = rowind[iptr].zx(); - work.write(i, j, X.read(irow, j)); - iptr += 1; - } - } - - mul::matmul_with_conj( - X.rb_mut().subrows_mut(fsupc, nsupc), - A_bot.transpose(), - conj_lhs, - work.rb().subrows(0, nrow), - Conj::No, - Some(E::faer_one()), - E::faer_one().faer_neg(), - parallelism, - ); - solve::solve_unit_upper_triangular_in_place_with_conj( - A_top.transpose(), - conj_lhs, - X.rb_mut().subrows_mut(fsupc, nsupc), - parallelism, - ); - } - } - - #[track_caller] - pub fn l_solve_in_place_with_conj( - &self, - conj_lhs: Conj, - rhs: MatMut<'_, E>, - parallelism: Parallelism, - stack: PodStack<'_>, - ) where - E: ComplexField, - { - assert!(self.nrows() == self.ncols()); - assert!(self.nrows() == rhs.nrows()); - - let (mut work, _) = faer_core::temp_mat_uninit::(rhs.nrows(), rhs.ncols(), stack); - - let mut X = rhs; - let nrhs = X.ncols(); - - let nzval = - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&self.lusup), |lusup| { - &**lusup - })); - let nzval_colptr = &*self.xlusup; - let rowind = &*self.lsub; - let rowind_colptr = &*self.xlsub; - let sup_to_col = &*self.xsup; - - for k in 0..self.n_supernodes() { - let fsupc = sup_to_col[k].zx(); - let istart = rowind_colptr[fsupc].zx(); - let nsupr = rowind_colptr[fsupc + 1].zx() - istart; - let nsupc = sup_to_col[k + 1].zx() - fsupc; - let nrow = nsupr - nsupc; - - let luptr = nzval_colptr[fsupc].zx(); - let lda = nzval_colptr[fsupc + 1].zx() - luptr; - - let mut work = work.rb_mut().subrows_mut(0, nrow); - let A = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - nzval.rb().subslice(luptr..nzval.len()).into_inner(), - nsupr, - nsupc, - lda, - ); - - let A_top = A.subrows(0, nsupc); - let A_bot = A.subrows(nsupc, nrow); - - solve::solve_unit_lower_triangular_in_place_with_conj( - A_top, - conj_lhs, - X.rb_mut().subrows_mut(fsupc, nsupc), - parallelism, - ); - mul::matmul_with_conj( - work.rb_mut(), - A_bot, - conj_lhs, - X.rb().subrows(fsupc, nsupc), - Conj::No, - None, - E::faer_one(), - parallelism, - ); - - for j in 0..nrhs { - let mut iptr = istart + nsupc; - for i in 0..nrow { - let irow = rowind[iptr].zx(); - X.write(irow, j, X.read(irow, j).faer_sub(work.read(i, j))); - iptr += 1; - } - } - } - } - - #[track_caller] - pub fn u_solve_transpose_in_place_with_conj( - &self, - conj_lhs: Conj, - rhs: MatMut<'_, E>, - parallelism: Parallelism, - stack: PodStack<'_>, - ) where - E: ComplexField, - { - assert!(self.ncols() == rhs.nrows()); - - let _ = stack; - - let mut X = rhs; - let nrhs = X.ncols(); - - let nzval = - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&self.lusup), |lusup| { - &**lusup - })); - let nzval_colptr = &*self.xlusup; - let sup_to_col = &*self.xsup; - - let u_col_ptr = &*self.xusub; - let u_row_ind = &*self.usub; - let u_val = - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&self.ucol), |ucol| &**ucol)); - - for k in 0..self.n_supernodes() { - let fsupc = sup_to_col[k].zx(); - let nsupc = sup_to_col[k + 1].zx() - fsupc; - - let luptr = nzval_colptr[fsupc].zx(); - let lda = nzval_colptr[fsupc + 1].zx() - luptr; - - let A = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - nzval.rb().subslice(luptr..nzval.len()).into_inner(), - nsupc, - nsupc, - lda, - ); - - let A_top = A.subrows(0, nsupc); - - // PERF(sparse-dense gemm) - for j in 0..nrhs { - for jcol in fsupc..fsupc + nsupc { - let start = u_col_ptr[jcol].zx(); - let end = u_col_ptr[jcol + 1].zx(); - let mut acc = E::faer_zero(); - for (row, val) in u_row_ind[start..end] - .iter() - .zip(u_val.subslice(start..end).into_ref_iter()) - { - let val = val.read(); - let val = if conj_lhs == Conj::Yes { - val.faer_conj() - } else { - val - }; - acc = acc.faer_add(X.read(row.zx(), j).faer_mul(val)); - } - X.write(jcol, j, X.read(jcol, j).faer_sub(acc)); - } - } - - solve::solve_lower_triangular_in_place_with_conj( - A_top.transpose(), - conj_lhs, - X.rb_mut().subrows_mut(fsupc, nsupc), - parallelism, - ); - } - } - - #[track_caller] - pub fn u_solve_in_place_with_conj( - &self, - conj_lhs: Conj, - rhs: MatMut<'_, E>, - parallelism: Parallelism, - stack: PodStack<'_>, - ) where - E: ComplexField, - { - assert!(self.ncols() == rhs.nrows()); - - let _ = stack; - - let mut X = rhs; - let nrhs = X.ncols(); - - let nzval = - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&self.lusup), |lusup| { - &**lusup - })); - let nzval_colptr = &*self.xlusup; - let sup_to_col = &*self.xsup; - - let u_col_ptr = &*self.xusub; - let u_row_ind = &*self.usub; - let u_val = - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&self.ucol), |ucol| &**ucol)); - - for k in (0..self.n_supernodes()).rev() { - let fsupc = sup_to_col[k].zx(); - let nsupc = sup_to_col[k + 1].zx() - fsupc; - - let luptr = nzval_colptr[fsupc].zx(); - let lda = nzval_colptr[fsupc + 1].zx() - luptr; - - let A = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - nzval.rb().subslice(luptr..nzval.len()).into_inner(), - nsupc, - nsupc, - lda, - ); - - let A_top = A.subrows(0, nsupc); - - solve::solve_upper_triangular_in_place_with_conj( - A_top, - conj_lhs, - X.rb_mut().subrows_mut(fsupc, nsupc), - parallelism, - ); - - // PERF(sparse-dense gemm) - for j in 0..nrhs { - for jcol in fsupc..fsupc + nsupc { - let start = u_col_ptr[jcol].zx(); - let end = u_col_ptr[jcol + 1].zx(); - let x_jcol = X.read(jcol, j); - for (row, val) in u_row_ind[start..end] - .iter() - .zip(u_val.subslice(start..end).into_ref_iter()) - { - let val = val.read(); - let val = if conj_lhs == Conj::Yes { - val.faer_conj() - } else { - val - }; - X.write( - row.zx(), - j, - X.read(row.zx(), j).faer_sub(x_jcol.faer_mul(val)), - ) - } - } - } - } - } - } - - impl core::fmt::Debug for SupernodalLu { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("SupernodalLu") - .field("xsup", &self.xsup) - .field("supno", &self.supno) - .field("lsub", &self.lsub) - .field("xlusup", &self.xlusup) - .field("xlsub", &self.xlsub) - .field("usub", &self.usub) - .field("xusub", &self.xusub) - .field("lusup", &to_slice_group::(&self.lusup)) - .field("ucol", &to_slice_group::(&self.ucol)) - .finish() - } - } - - #[derive(Copy, Clone)] - pub struct SupernodalLuParams { - pub panel_size: usize, - pub relax: usize, - pub max_super: usize, - pub row_block: usize, - pub col_block: usize, - pub fill_factor: usize, - } - - impl Default for SupernodalLuParams { - fn default() -> Self { - Self { - panel_size: 8, - relax: 4, - max_super: 128, - row_block: 16, - col_block: 8, - fill_factor: 20, - } - } - } - - #[inline(never)] - fn resize_scalar( - v: &mut GroupFor>, - n: usize, - exact: bool, - reserve_only: bool, - ) -> Result<(), FaerError> { - let mut failed = false; - let reserve = if exact { - alloc::vec::Vec::try_reserve_exact - } else { - alloc::vec::Vec::try_reserve - }; - - E::faer_map(E::faer_as_mut(v), |v| { - if !failed { - failed = reserve(v, n.saturating_sub(v.len())).is_err(); - if !reserve_only { - v.resize(Ord::max(n, v.len()), unsafe { core::mem::zeroed() }); - } - } - }); - if failed { - Err(FaerError::OutOfMemory) - } else { - Ok(()) - } - } - - #[inline(never)] - fn resize_work( - v: &mut alloc::vec::Vec, - req: Result, - exact: bool, - ) -> Result<(), FaerError> { - let reserve = if exact { - alloc::vec::Vec::try_reserve_exact - } else { - alloc::vec::Vec::try_reserve - }; - let n = req - .and_then(|req| req.try_unaligned_bytes_required()) - .map_err(|_| FaerError::OutOfMemory)?; - reserve(v, n.saturating_sub(v.len())).map_err(|_| FaerError::OutOfMemory)?; - v.resize(n, 0); - unsafe { v.set_len(n) }; - Ok(()) - } - - #[inline(never)] - fn resize_index( - v: &mut alloc::vec::Vec, - n: usize, - exact: bool, - reserve_only: bool, - ) -> Result<(), FaerError> { - let reserve = if exact { - alloc::vec::Vec::try_reserve_exact - } else { - alloc::vec::Vec::try_reserve - }; - reserve(v, n.saturating_sub(v.len())).map_err(|_| FaerError::OutOfMemory)?; - if !reserve_only { - v.resize(Ord::max(n, v.len()), I::truncate(0)); - } - Ok(()) - } - - impl SupernodalLu { - #[inline(never)] - fn mem_init( - &mut self, - m: usize, - n: usize, - a_nnz: usize, - fillratio: usize, - ) -> Result<(), FaerError> { - use FaerError::IndexOverflow; - - let nzumax = Ord::min( - fillratio.checked_mul(a_nnz + 1).ok_or(IndexOverflow)?, - m.checked_mul(n).ok_or(IndexOverflow)?, - ); - let nzlumax = nzumax; - let nzlmax = Ord::max(4, fillratio) - .checked_mul(a_nnz + 1) - .ok_or(IndexOverflow)? - / 4; - - resize_index(&mut self.xsup, n + 1, true, false)?; - resize_index(&mut self.supno, n + 1, true, false)?; - resize_index(&mut self.xlsub, n + 1, true, false)?; - resize_index(&mut self.xlusup, n + 1, true, false)?; - resize_index(&mut self.xusub, n + 1, true, false)?; - - resize_scalar::(&mut self.lusup, nzlumax, true, true)?; - resize_scalar::(&mut self.ucol, nzumax, true, true)?; - resize_index(&mut self.lsub, nzlmax, true, true)?; - resize_index(&mut self.usub, nzumax, true, true)?; - - Ok(()) - } - } - - #[inline(never)] - fn panel_dfs( - m: usize, - w: usize, - jcol: usize, - A: SparseColMatRef<'_, I, E>, - col_perm: &[I], - perm_r: &[I], - nseg: &mut usize, - dense: SliceGroupMut<'_, E>, - panel_lsub: &mut [I::Signed], - segrep: &mut [I], - repfnz: &mut [I::Signed], - xprune: &mut [I], - marker: &mut [I::Signed], - parent: &mut [I], - xplore: &mut [I], - lu: &mut SupernodalLu, - ) { - let I = I::truncate; - *nseg = 0; - - let mut dense = dense; - - for jj in jcol..jcol + w { - let mut nextl_col = (jj - jcol) * m; - let repfnz_col = &mut repfnz[nextl_col..][..m]; - let mut dense_col = dense.rb_mut().subslice(nextl_col..nextl_col + m); - - let pjj = col_perm[jj].zx(); - for (krow, value) in zip( - A.row_indices_of_col(pjj), - SliceGroup::<'_, E>::new(A.values_of_col(pjj)).into_ref_iter(), - ) { - dense_col.write(krow, value.read()); - let kmark = marker[krow]; - if kmark == I(jj).to_signed() { - continue; - } - - dfs_kernel_for_panel_dfs( - m, - jcol, - jj, - perm_r, - nseg, - panel_lsub, - segrep, - repfnz_col, - xprune, - marker, - parent, - xplore, - lu, - &mut nextl_col, - krow, - ); - } - } - } - - #[inline(never)] - fn dfs_kernel_for_column_dfs( - jsuper: &mut usize, - jj: usize, - perm_r: &[I], - nseg: &mut usize, - segrep: &mut [I], - repfnz_col: &mut [I::Signed], - xprune: &mut [I], - marker: &mut [I::Signed], - parent: &mut [I], - xplore: &mut [I], - lu: &mut SupernodalLu, - nextl_col: &mut usize, - krow: usize, - ) -> Result<(), FaerError> { - let I = I::truncate; - let panel_lsub = &mut lu.lsub; - - let kmark = marker[krow]; - - marker[krow] = I(jj).to_signed(); - let kperm = perm_r[krow]; - - if kperm == I(NONE) { - resize_index(panel_lsub, *nextl_col + 1, false, false)?; - panel_lsub[*nextl_col] = I(krow); - *nextl_col += 1; - - if kmark + I(1).to_signed() != I(jj).to_signed() { - *jsuper = NONE; - } - } else { - let mut krep = lu.xsup[lu.supno[kperm.zx()].zx() + 1] - .zx() - .saturating_sub(1); - let mut myfnz = repfnz_col[krep]; - - if myfnz != I(NONE).to_signed() { - if myfnz > kperm.to_signed() { - repfnz_col[krep] = kperm.to_signed(); - } - } else { - let mut oldrep = I(NONE); - parent[krep] = oldrep; - repfnz_col[krep] = kperm.to_signed(); - let mut xdfs = lu.xlsub[krep].zx(); - let mut maxdfs = xprune[krep].zx(); - - loop { - while xdfs < maxdfs { - let kchild = panel_lsub[xdfs].zx(); - xdfs += 1; - let chmark = marker[kchild]; - - if chmark != I(jj).to_signed() { - marker[kchild] = I(jj).to_signed(); - let chperm = perm_r[kchild]; - if chperm == I(NONE) { - resize_index(panel_lsub, *nextl_col + 1, false, false)?; - panel_lsub[*nextl_col] = I(kchild); - *nextl_col += 1; - - if chmark + I(1).to_signed() != I(jj).to_signed() { - *jsuper = NONE; - } - } else { - let chrep = lu.xsup[lu.supno[chperm.zx()].zx() + 1] - .zx() - .saturating_sub(1); - myfnz = repfnz_col[chrep]; - - if myfnz != I(NONE).to_signed() { - if myfnz > chperm.to_signed() { - repfnz_col[chrep] = chperm.to_signed(); - } - } else { - xplore[krep] = I(xdfs); - oldrep = I(krep); - krep = chrep; - parent[krep] = oldrep; - repfnz_col[krep] = chperm.to_signed(); - xdfs = lu.xlsub[krep].zx(); - maxdfs = xprune[krep].zx(); - } - } - } - } - - segrep[*nseg] = I(krep); - *nseg += 1; - - let kpar = parent[krep]; - if kpar == I(NONE) { - break; - } - - krep = kpar.zx(); - xdfs = xplore[krep].zx(); - maxdfs = xprune[krep].zx(); - } - } - } - Ok(()) - } - - #[inline(never)] - fn dfs_kernel_for_panel_dfs( - m: usize, - jcol: usize, - jj: usize, - perm_r: &[I], - nseg: &mut usize, - panel_lsub: &mut [I::Signed], - segrep: &mut [I], - repfnz_col: &mut [I::Signed], - xprune: &mut [I], - marker: &mut [I::Signed], - parent: &mut [I], - xplore: &mut [I], - lu: &mut SupernodalLu, - nextl_col: &mut usize, - krow: usize, - ) { - let I = I::truncate; - - marker[krow] = I(jj).to_signed(); - let kperm = perm_r[krow]; - - if kperm == I(NONE) { - panel_lsub[*nextl_col] = I(krow).to_signed(); - *nextl_col += 1; - } else { - let mut krep = lu.xsup[lu.supno[kperm.zx()].zx() + 1].zx() - 1; - let mut myfnz = repfnz_col[krep]; - - if myfnz != I(NONE).to_signed() { - if myfnz > kperm.to_signed() { - repfnz_col[krep] = kperm.to_signed(); - } - } else { - let mut oldrep = I(NONE); - parent[krep] = oldrep; - repfnz_col[krep] = kperm.to_signed(); - let mut xdfs = lu.xlsub[krep].zx(); - let mut maxdfs = xprune[krep].zx(); - - loop { - while xdfs < maxdfs { - let kchild = lu.lsub[xdfs].zx(); - xdfs += 1; - let chmark = marker[kchild]; - - if chmark != I(jj).to_signed() { - marker[kchild] = I(jj).to_signed(); - let chperm = perm_r[kchild]; - if chperm == I(NONE) { - panel_lsub[*nextl_col] = I(kchild).to_signed(); - *nextl_col += 1; - } else { - let chrep = lu.xsup[lu.supno[chperm.zx()].zx() + 1] - .zx() - .saturating_sub(1); - myfnz = repfnz_col[chrep]; - - if myfnz != I(NONE).to_signed() { - if myfnz > chperm.to_signed() { - repfnz_col[chrep] = chperm.to_signed(); - } - } else { - xplore[krep] = I(xdfs); - oldrep = I(krep); - krep = chrep; - parent[krep] = oldrep; - repfnz_col[krep] = chperm.to_signed(); - xdfs = lu.xlsub[krep].zx(); - maxdfs = xprune[krep].zx(); - } - } - } - } - - if marker[m + krep] < I(jcol).to_signed() { - marker[m + krep] = I(jj).to_signed(); - segrep[*nseg] = I(krep); - *nseg += 1; - } - - let kpar = parent[krep]; - if kpar == I(NONE) { - break; - } - - krep = kpar.zx(); - xdfs = xplore[krep].zx(); - maxdfs = xprune[krep].zx(); - } - } - } - } - - #[derive(Copy, Clone, Debug)] - pub enum LuError { - Generic(FaerError), - ZeroColumn(usize), - } - - impl From for LuError { - #[inline] - fn from(value: FaerError) -> Self { - Self::Generic(value) - } - } - - #[inline(never)] - fn prune_l( - jcol: usize, - row_perm: &[I], - pivrow: usize, - nseg: usize, - segrep: &[I], - repfnz: &mut [I::Signed], - xprune: &mut [I], - lu: &mut SupernodalLu, - ) { - let I = I::truncate; - - let jsupno = lu.supno[jcol].zx(); - let mut kmin = 0; - let mut kmax = 0; - - for i in 0..nseg { - let irep = segrep[i].zx(); - let irep1 = irep + 1; - let mut do_prune = false; - if repfnz[irep] == I(NONE).to_signed() { - continue; - } - if lu.supno[irep] == lu.supno[irep1] { - continue; - } - - if lu.supno[irep] != I(jsupno) { - if xprune[irep] >= lu.xlsub[irep1] { - kmin = lu.xlsub[irep].zx(); - kmax = lu.xlsub[irep1].zx() - 1; - for krow in kmin..kmax + 1 { - if lu.lsub[krow] == I(pivrow) { - do_prune = true; - break; - } - } - } - - if do_prune { - let mut movnum = false; - if I(irep) == lu.xsup[lu.supno[irep].zx()] { - movnum = true; - } - - while kmin <= kmax { - if row_perm[lu.lsub[kmax].zx()] == I(NONE) { - kmax -= 1; - } else if row_perm[lu.lsub[kmin].zx()] != I(NONE) { - kmin += 1; - } else { - lu.lsub.swap(kmin, kmax); - if movnum { - let minloc = lu.xlusup[irep].zx() + (kmin - lu.xlsub[irep].zx()); - let maxloc = lu.xlusup[irep].zx() + (kmax - lu.xlsub[irep].zx()); - E::faer_map(E::faer_as_mut(&mut lu.lusup), |lusup| { - lusup.swap(minloc, maxloc) - }); - } - kmin += 1; - kmax -= 1; - } - } - xprune[irep] = I(kmin); - } - } - } - } - - #[inline(never)] - fn pivot_l( - jcol: usize, - diag_pivot_thresh: E::Real, - row_perm: &mut [I], - col_perm_inv: PermutationRef<'_, I, E>, - pivrow: &mut usize, - lu: &mut SupernodalLu, - ) -> bool { - let I = I::truncate; - - let mut values = - SliceGroupMut::<'_, E>::new(E::faer_map(E::faer_as_mut(&mut lu.lusup), |x| &mut **x)); - let indices = &mut *lu.lsub; - - let fsupc = lu.xsup[lu.supno[jcol].zx()].zx(); - let nsupc = jcol - fsupc; - let lptr = lu.xlsub[fsupc].zx(); - let nsupr = lu.xlsub[fsupc + 1].zx() - lptr; - let lda = (lu.xlusup[fsupc + 1] - lu.xlusup[fsupc]).zx(); - - let lu_sup_ptr = lu.xlusup[fsupc].zx(); - let lu_col_ptr = lu.xlusup[jcol].zx(); - let lsub_ptr = lptr; - - let diagind = col_perm_inv.into_arrays().0[jcol].zx(); - let mut pivmax = E::Real::faer_one().faer_neg(); - let mut pivptr = nsupc; - let mut diag = I(NONE); - for isub in nsupc..nsupr { - let rtemp = values.read(lu_col_ptr + isub).faer_abs(); - if rtemp > pivmax { - pivmax = rtemp; - pivptr = isub; - } - if indices[lsub_ptr + isub].zx() == diagind { - diag = I(isub); - } - } - - if pivmax <= E::Real::faer_zero() { - *pivrow = if pivmax < E::Real::faer_zero() { - diagind - } else { - indices[lsub_ptr + pivptr].zx() - }; - row_perm[*pivrow] = I(jcol); - return true; - } - - let thresh = diag_pivot_thresh.faer_mul(pivmax); - if diag.to_signed() >= I(0).to_signed() { - let rtemp = values.read(lu_col_ptr + diag.zx()).faer_abs(); - if rtemp != E::Real::faer_zero() && rtemp >= thresh { - pivptr = diag.zx(); - } - } - *pivrow = indices[lsub_ptr + pivptr].zx(); - row_perm[*pivrow] = I(jcol); - - if pivptr != nsupc { - indices.swap(lsub_ptr + pivptr, lsub_ptr + nsupc); - for icol in 0..nsupc + 1 { - let itemp = pivptr + icol * lda; - let tmp = values.read(lu_sup_ptr + itemp); - values.write( - lu_sup_ptr + itemp, - values.read(lu_sup_ptr + nsupc + icol * lda), - ); - values.write(lu_sup_ptr + nsupc + icol * lda, tmp); - } - } - - let temp = values.read(lu_col_ptr + nsupc).faer_inv(); - for k in nsupc + 1..nsupr { - values.write(lu_col_ptr + k, values.read(lu_col_ptr + k).faer_mul(temp)); - } - - return false; - } - - #[inline(never)] - fn copy_to_ucol( - jcol: usize, - nseg: usize, - segrep: &mut [I], - repfnz: &mut [::Signed], - row_perm: &[I], - mut dense: SliceGroupMut<'_, E>, - lu: &mut SupernodalLu, - ) -> Result<(), FaerError> { - let I = I::truncate; - - let jsupno = lu.supno[jcol].zx(); - let mut nextu = lu.xusub[jcol].zx(); - - for k in (0..nseg).rev() { - let krep = segrep[k].zx(); - let ksupno = lu.supno[krep].zx(); - - if jsupno != ksupno { - let kfnz = repfnz[krep]; - if kfnz != I(NONE).to_signed() { - let fsupc = lu.xsup[ksupno].zx(); - let mut isub = lu.xlsub[fsupc].zx() + kfnz.zx() - fsupc; - let segsize = krep + 1 - kfnz.zx(); - let new_next = nextu + segsize; - - resize_scalar::(&mut lu.ucol, new_next, false, false)?; - resize_index(&mut lu.usub, new_next, false, false)?; - - let mut ucol = SliceGroupMut::<'_, E>::new(E::faer_map( - E::faer_as_mut(&mut lu.ucol), - |ucol| &mut **ucol, - )); - - for _ in 0..segsize { - let irow = lu.lsub[isub].zx(); - - lu.usub[nextu] = row_perm[irow]; - ucol.write(nextu, dense.read(irow)); - dense.write(irow, E::faer_zero()); - - nextu += 1; - isub += 1; - } - } - } - } - lu.xusub[jcol + 1] = I(nextu); - - Ok(()) - } - - #[inline(never)] - fn column_bmod( - jcol: usize, - nseg: usize, - mut dense: SliceGroupMut<'_, E>, - work: &mut alloc::vec::Vec, - segrep: &mut [I], - repfnz: &mut [I::Signed], - fpanelc: usize, - lu: &mut SupernodalLu, - ) -> Result<(), FaerError> { - let I = I::truncate; - - let jsupno = lu.supno[jcol].zx(); - - for k in (0..nseg).rev() { - let krep = segrep[k].zx(); - let ksupno = lu.supno[krep]; - if I(jsupno) != ksupno { - let ksupno = ksupno.zx(); - let fsupc = lu.xsup[ksupno].zx(); - let fst_col = Ord::max(fsupc, fpanelc); - let d_fsupc = fst_col - fsupc; - let mut luptr = lu.xlusup[fst_col].zx() + d_fsupc; - let lptr = lu.xlsub[fsupc].zx() + d_fsupc; - let kfnz = repfnz[krep]; - let kfnz = Ord::max(kfnz, I(fpanelc).to_signed()).zx(); - let segsize = krep + 1 - kfnz; - let nsupc = krep + 1 - fst_col; - let nsupr = (lu.xlsub[fsupc + 1] - lu.xlsub[fsupc]).zx(); - let nrow = nsupr - d_fsupc - nsupc; - let lda = (lu.xlusup[fst_col + 1] - lu.xlusup[fst_col]).zx(); - let no_zeros = kfnz - fst_col; - - lu_kernel_bmod( - segsize, - dense.rb_mut(), - work, - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&lu.lusup), |lusup| { - &**lusup - })), - &mut luptr, - lda, - nrow, - &lu.lsub, - lptr, - no_zeros, - )?; - } - } - - let mut nextlu = lu.xlusup[jcol].zx(); - let fsupc = lu.xsup[jsupno].zx(); - - let new_next = nextlu + lu.xlsub[fsupc + 1].zx() - lu.xlsub[fsupc].zx(); - resize_scalar::(&mut lu.lusup, new_next, false, false)?; - - let mut lusup = - SliceGroupMut::<'_, E>::new(E::faer_map(E::faer_as_mut(&mut lu.lusup), |x| &mut **x)); - - for isub in lu.xlsub[fsupc].zx()..lu.xlsub[fsupc + 1].zx() { - let irow = lu.lsub[isub].zx(); - lusup.write(nextlu, dense.read(irow)); - dense.write(irow, E::faer_zero()); - nextlu += 1; - } - lu.xlusup[jcol + 1] = I(nextlu); - let fst_col = Ord::max(fsupc, fpanelc); - - if fst_col < jcol { - let d_fsupc = fst_col - fsupc; - let luptr = lu.xlusup[fst_col].zx() + d_fsupc; - let nsupr = lu.xlsub[fsupc + 1].zx() - lu.xlsub[fsupc].zx(); - let nsupc = jcol - fst_col; - let nrow = nsupr - d_fsupc - nsupc; - - let ufirst = lu.xlusup[jcol].zx() + d_fsupc; - let lda = lu.xlusup[jcol + 1].zx() - lu.xlusup[jcol].zx(); - - let (left, right) = lusup.rb_mut().split_at(ufirst); - let (mid, right) = right.split_at(nsupc); - - let A = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - left.rb().subslice(luptr..left.len()).into_inner(), - nsupr, - nsupc, - lda, - ); - let (A_top, A_bot) = A.split_at_row(nsupc); - let A_bot = A_bot.subrows(0, nrow); - - let mut l = faer_core::mat::from_column_major_slice_mut::<'_, E>( - right.subslice(0..nrow).into_inner(), - nrow, - 1, - ); - let mut u = faer_core::mat::from_column_major_slice_mut::<'_, E>( - mid.subslice(0..nsupc).into_inner(), - nsupc, - 1, - ); - - solve::solve_unit_lower_triangular_in_place(A_top, u.rb_mut(), Parallelism::None); - - mul::matmul( - l.rb_mut(), - A_bot, - u.rb(), - Some(E::faer_one()), - E::faer_one().faer_neg(), - Parallelism::None, - ); - } - - Ok(()) - } - - #[inline(never)] - fn lu_kernel_bmod( - segsize: usize, - mut dense: SliceGroupMut<'_, E>, - work: &mut alloc::vec::Vec, - lusup: SliceGroup<'_, E>, - luptr: &mut usize, - lda: usize, - nrow: usize, - lsub: &[I], - lptr: usize, - no_zeros: usize, - ) -> Result<(), FaerError> { - if segsize == 1 { - let f = dense.read(lsub[lptr + no_zeros].zx()); - *luptr += lda * no_zeros + no_zeros + 1; - let a = lusup.subslice(*luptr..*luptr + nrow); - let irow = &lsub[lptr + no_zeros + 1..][..nrow]; - let mut i = 0; - let chunk2 = irow.chunks_exact(2); - let rem2 = chunk2.remainder(); - for i0i1 in chunk2 { - let i0 = i0i1[0].zx(); - let i1 = i0i1[1].zx(); - - unsafe { - let a0 = a.read_unchecked(i); - let a1 = a.read_unchecked(i + 1); - - let d0 = dense.read_unchecked(i0); - let d1 = dense.read_unchecked(i1); - dense.write_unchecked(i0, d0.faer_sub(f.faer_mul(a0))); - dense.write_unchecked(i1, d1.faer_sub(f.faer_mul(a1))); - } - - i += 2; - } - for i0 in rem2 { - let i0 = i0.zx(); - let a0 = a.read(i); - let d0 = dense.read(i0); - dense.write(i0, d0.faer_sub(f.faer_mul(a0))); - } - } else { - resize_work(work, temp_mat_req::(segsize + nrow, 1), false)?; - let stack = PodStack::new(work); - let (_, mut storage) = E::faer_map_with_context(stack, E::UNIT, &mut |stack, ()| { - let (storage, stack) = - stack.make_aligned_raw::(segsize + nrow, faer_core::CACHELINE_ALIGN); - (stack, storage) - }); - let mut tempv = unsafe { - faer_core::mat::from_raw_parts_mut::<'_, E>( - E::faer_map(E::faer_as_mut(&mut storage), |storage| { - storage.as_mut_ptr() as *mut E::Unit - }), - segsize + nrow, - 1, - 1, - 1, - ) - }; - - let mut isub = lptr + no_zeros; - for i in 0..segsize { - let irow = lsub[isub].zx(); - tempv.write(i, 0, dense.read(irow)); - isub += 1; - } - - assert!(lda > 0); - assert!(*luptr + (segsize + nrow - 1) + (segsize - 1) * lda < lusup.len()); - *luptr += lda * no_zeros + no_zeros; - - let A = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - lusup.subslice(*luptr..lusup.len()).into_inner(), - segsize + nrow, - segsize, - lda, - ); - *luptr += segsize; - - let (A, B) = A.split_at_row(segsize); - let B = B.subrows(0, nrow); - - let (mut u, mut l) = tempv.rb_mut().split_at_row_mut(segsize); - - solve::solve_unit_lower_triangular_in_place(A, u.rb_mut(), Parallelism::None); - mul::matmul( - l.rb_mut(), - B, - u.rb(), - None, - E::faer_one(), - Parallelism::None, - ); - let mut isub = lptr + no_zeros; - for i in 0..segsize { - let irow = lsub[isub].zx(); - isub += 1; - dense.write(irow, u.read(i, 0)); - } - for i in 0..nrow { - let irow = lsub[isub].zx(); - isub += 1; - dense.write(irow, dense.read(irow).faer_sub(l.read(i, 0))); - } - - for i in 0..segsize { - tempv.write(i, 0, E::faer_zero()); - } - } - Ok(()) - } - - #[inline(never)] - fn column_dfs( - m: usize, - jcol: usize, - row_perm: &[I], - maxsuper: usize, - nseg: &mut usize, - lsub_col: &mut [I::Signed], - segrep: &mut [I], - repfnz: &mut [I::Signed], - xprune: &mut [I], - marker: &mut [I::Signed], - parent: &mut [I], - xplore: &mut [I], - lu: &mut SupernodalLu, - ) -> Result<(), FaerError> { - let I = I::truncate; - - let mut jsuper = lu.supno[jcol].zx(); - let mut nextl = lu.xlsub[jcol].zx(); - let marker2 = &mut marker[2 * m..][..m]; - - let mut k = 0; - while if k < m { - lsub_col[k] != I(NONE).to_signed() - } else { - false - } { - let krow = lsub_col[k].zx(); - lsub_col[k] = I(NONE).to_signed(); - let kmark = marker2[krow]; - - if kmark == I(jcol).to_signed() { - k += 1; - continue; - } - - dfs_kernel_for_column_dfs( - &mut jsuper, - jcol, - row_perm, - nseg, - segrep, - repfnz, - xprune, - marker2, - parent, - xplore, - lu, - &mut nextl, - krow, - )?; - - k += 1; - } - - let jcolp1 = jcol + 1; - - let mut nsuper = lu.supno[jcol].zx(); - if jcol == 0 { - nsuper = 0; - lu.supno[0] = I(0); - } else { - let jcolm1 = jcol - 1; - let fsupc = lu.xsup[nsuper].zx(); - let jptr = lu.xlsub[jcol].zx(); - let jm1ptr = lu.xlsub[jcolm1].zx(); - - if nextl - jptr != jptr - jm1ptr - 1 { - jsuper = NONE; - } - if jcol - fsupc >= maxsuper { - jsuper = NONE; - } - - if jsuper == NONE { - if fsupc + 1 < jcolm1 { - let mut ito = lu.xlsub[fsupc + 1]; - lu.xlsub[jcolm1] = ito; - let istop = ito + I(jptr) - I(jm1ptr); - xprune[jcolm1] = istop; - lu.xlsub[jcol] = istop; - - for ifrom in jm1ptr..nextl { - lu.lsub[ito.zx()] = lu.lsub[ifrom]; - ito += I(1); - } - nextl = ito.zx(); - } - nsuper += 1; - lu.supno[jcol] = I(nsuper); - } - } - - lu.xsup[nsuper + 1] = I(jcolp1); - lu.supno[jcolp1] = I(nsuper); - xprune[jcol] = I(nextl); - lu.xlsub[jcolp1] = I(nextl); - - Ok(()) - } - - #[inline(never)] - fn panel_bmod( - m: usize, - w: usize, - jcol: usize, - nseg: usize, - mut dense: SliceGroupMut<'_, E>, - work: &mut alloc::vec::Vec, - segrep: &mut [I], - repfnz: &mut [I::Signed], - lu: &mut SupernodalLu, - parallelism: Parallelism, - ) -> Result<(), FaerError> { - let I = I::truncate; - for k in (0..nseg).rev() { - let krep = segrep[k].zx(); - let fsupc = lu.xsup[lu.supno[krep].zx()].zx(); - let nsupc = krep + 1 - fsupc; - let nsupr = (lu.xlsub[fsupc + 1] - lu.xlsub[fsupc]).zx(); - let nrow = nsupr - nsupc; - let lptr = lu.xlsub[fsupc].zx(); - - let mut u_rows = 0usize; - let mut u_cols = 0usize; - - for jj in jcol..jcol + w { - let nextl_col = (jj - jcol) * m; - let repfnz_col = &mut repfnz[nextl_col..][..m]; - let kfnz = repfnz_col[krep]; - if kfnz == I(NONE).to_signed() { - continue; - } - - let segsize = krep + 1 - kfnz.zx(); - u_cols += 1; - u_rows = Ord::max(u_rows, segsize); - } - - if nsupc >= 2 { - resize_work(work, temp_mat_req::(u_rows + nrow, u_cols), false)?; - let stack = PodStack::new(work); - let tmp_lda = faer_core::col_stride::(u_rows + nrow); - let (_, mut storage) = - E::faer_map_with_context(stack, E::UNIT, &mut |stack, ()| { - let (storage, stack) = stack.make_aligned_raw::( - tmp_lda * u_cols, - faer_core::CACHELINE_ALIGN, - ); - (stack, storage) - }); - let mut tempv = unsafe { - faer_core::mat::from_raw_parts_mut::<'_, E>( - E::faer_map(E::faer_as_mut(&mut storage), |storage| { - storage.as_mut_ptr() as *mut E::Unit - }), - u_rows + nrow, - u_cols, - 1, - tmp_lda as isize, - ) - }; - - let (mut U, mut L) = tempv.rb_mut().split_at_row_mut(u_rows); - - let mut u_col = 0usize; - for jj in jcol..jcol + w { - let nextl_col = (jj - jcol) * m; - let repfnz_col = &mut repfnz[nextl_col..][..m]; - let dense_col = dense.rb_mut().subslice(nextl_col..nextl_col + m); - - let kfnz = repfnz_col[krep]; - if kfnz == I(NONE).to_signed() { - continue; - } - - let segsize = krep + 1 - kfnz.zx(); - let no_zeros = kfnz.zx() - fsupc; - let isub = lptr + no_zeros; - let off = u_rows - segsize; - assert!(off <= U.nrows()); - assert!(u_col < U.ncols()); - for i in 0..off { - U.write(i, u_col, E::faer_zero()); - } - let mut U = U.rb_mut().get_mut(off.., ..); - assert!(segsize <= U.nrows()); - assert!(u_col < U.ncols()); - for (i, irow) in lu.lsub[isub..][..segsize].iter().enumerate() { - U.write(i, u_col, dense_col.read(irow.zx())); - } - u_col += 1; - } - let mut luptr = lu.xlusup[fsupc].zx(); - let lda = (lu.xlusup[fsupc + 1] - lu.xlusup[fsupc]).zx(); - let no_zeros = (krep + 1 - u_rows) - fsupc; - luptr += lda * no_zeros + no_zeros; - let l_val = to_slice_group::(&lu.lusup); - let A = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - l_val.subslice(luptr..l_val.len()).into_inner(), - u_rows + nrow, - u_rows, - lda, - ); - let (A, B) = A.split_at_row(u_rows); - let B = B.subrows(0, nrow); - - solve::solve_unit_lower_triangular_in_place(A, U.rb_mut(), parallelism); - - mul::matmul(L.rb_mut(), B, U.rb(), None, E::faer_one(), parallelism); - let mut u_col = 0usize; - for jj in jcol..jcol + w { - let nextl_col = (jj - jcol) * m; - let repfnz_col = &mut repfnz[nextl_col..][..m]; - let mut dense_col = dense.rb_mut().subslice(nextl_col..nextl_col + m); - - let kfnz = repfnz_col[krep]; - if kfnz == I(NONE).to_signed() { - continue; - } - - let segsize = krep + 1 - kfnz.zx(); - let no_zeros = kfnz.zx() - fsupc; - let mut isub = lptr + no_zeros; - let off = u_rows - segsize; - - let mut U = U.rb_mut().get_mut(off.., ..); - assert!(segsize <= U.nrows()); - assert!(u_col < U.ncols()); - for (i, irow) in lu.lsub[isub..][..segsize].iter().enumerate() { - let irow = irow.zx(); - unsafe { - dense_col.write_unchecked(irow, U.read_unchecked(i, u_col)); - U.write_unchecked(i, u_col, E::faer_zero()); - } - } - isub += segsize; - assert!(nrow <= L.nrows()); - assert!(u_col < L.ncols()); - for (i, irow) in lu.lsub[isub..][..nrow].iter().enumerate() { - let irow = irow.zx(); - unsafe { - dense_col.write_unchecked( - irow, - dense_col - .read_unchecked(irow) - .faer_sub(L.read_unchecked(i, u_col)), - ); - L.write_unchecked(i, u_col, E::faer_zero()); - } - } - u_col += 1; - } - } else { - for jj in jcol..jcol + w { - let nextl_col = (jj - jcol) * m; - let repfnz_col = &mut repfnz[nextl_col..][..m]; - let dense_col = dense.rb_mut().subslice(nextl_col..nextl_col + m); - - let kfnz = repfnz_col[krep]; - if kfnz == I(NONE).to_signed() { - continue; - } - let kfnz = kfnz.zx(); - let segsize = krep + 1 - kfnz; - let mut luptr = lu.xlusup[fsupc].zx(); - let lda = lu.xlusup[fsupc + 1].zx() - lu.xlusup[fsupc].zx(); - let no_zeros = kfnz - fsupc; - - lu_kernel_bmod( - segsize, - dense_col, - work, - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(&lu.lusup), |lusup| { - &**lusup - })), - &mut luptr, - lda, - nrow, - &mut lu.lsub, - lptr, - no_zeros, - )?; - } - } - } - Ok(()) - } - - #[inline] - fn to_slice_group_mut( - v: &mut GroupFor>, - ) -> SliceGroupMut<'_, E> { - SliceGroupMut::<'_, E>::new(E::faer_map(E::faer_as_mut(v), |v| &mut **v)) - } - #[inline] - fn to_slice_group(v: &GroupFor>) -> SliceGroup<'_, E> { - SliceGroup::<'_, E>::new(E::faer_map(E::faer_as_ref(v), |v| &**v)) - } - - #[inline(never)] - fn snode_dfs( - jcol: usize, - kcol: usize, - A: SparseColMatRef<'_, I, E>, - row_perm: &[I], - col_perm: &[I], - xprune: &mut [I], - marker: &mut [I::Signed], - lu: &mut SupernodalLu, - ) -> Result<(), FaerError> { - let I = I::truncate; - let SI = I::Signed::truncate; - - // TODO: handle non leaf nodes properly - let _ = row_perm; - - let nsuper = lu.supno[jcol].zx().wrapping_add(1); - lu.supno[jcol] = I(nsuper); - let mut nextl = lu.xlsub[jcol].zx(); - - for i in jcol..kcol { - let i_p = col_perm[i].zx(); - for krow in A.row_indices_of_col(i_p) { - let kmark = marker[krow].zx(); - if kmark != kcol - 1 { - marker[krow] = SI(kcol - 1); - resize_index::(&mut lu.lsub, nextl + 1, false, false)?; - lu.lsub[nextl] = I(krow); - nextl += 1; - } - } - lu.supno[i] = I(nsuper); - } - - if jcol + 1 < kcol { - let new_next = nextl + (nextl - lu.xlsub[jcol].zx()); - resize_index::(&mut lu.lsub, new_next, false, false)?; - - lu.lsub.copy_within(lu.xlsub[jcol].zx()..nextl, nextl); - for i in jcol + 1..kcol { - lu.xlsub[i] = I(nextl); - } - nextl = new_next; - } - - lu.xsup[nsuper + 1] = I(kcol); - lu.supno[kcol] = I(nsuper); - xprune[kcol - 1] = I(nextl); - lu.xlsub[kcol] = I(nextl); - - Ok(()) - } - - #[inline(never)] - fn snode_bmod_setup( - jcol: usize, - fsupc: usize, - mut dense: SliceGroupMut<'_, E>, - lu: &mut SupernodalLu, - ) { - let I = I::truncate; - - let mut nextlu = lu.xlusup[jcol].zx(); - let mut lusup = to_slice_group_mut::(&mut lu.lusup); - for isub in lu.xlsub[fsupc].zx()..lu.xlsub[fsupc + 1].zx() { - let irow = lu.lsub[isub].zx(); - lusup.write(nextlu, dense.read(irow)); - dense.write(irow, E::faer_zero()); - nextlu += 1; - } - - lu.xlusup[jcol + 1] = I(nextlu); - } - - #[allow(dead_code)] - #[inline(never)] - fn snode_bmod( - jcol: usize, - kcol: usize, - _jsupno: usize, - fsupc: usize, - lu: &mut SupernodalLu, - parallelism: Parallelism, - ) { - let ufirst = lu.xlusup[jcol].zx(); - let left_lda = lu.xlusup[fsupc + 1].zx() - lu.xlusup[fsupc].zx(); - let right_lda = lu.xlusup[jcol + 1].zx() - lu.xlusup[jcol].zx(); - let nsupr = (lu.xlsub[fsupc + 1] - lu.xlsub[fsupc]).zx(); - let nsupc = jcol - fsupc; - let luptr = lu.xlusup[fsupc].zx(); - - let (left, right) = to_slice_group_mut::(&mut lu.lusup).split_at(ufirst); - - let (A, B) = faer_core::mat::from_column_major_slice_with_stride::<'_, E>( - left.rb() - .subslice(luptr..luptr + left_lda * nsupc) - .into_inner(), - nsupr, - nsupc, - left_lda, - ) - .subrows(0, nsupr) - .split_at_row(nsupc); - - let (mut top, mut bot) = faer_core::mat::from_column_major_slice_with_stride_mut::<'_, E>( - right.subslice(0..right_lda * (kcol - jcol)).into_inner(), - nsupr, - kcol - jcol, - right_lda, - ) - .subrows_mut(0, nsupr) - .split_at_row_mut(nsupc); - - solve::solve_unit_lower_triangular_in_place(A, top.rb_mut(), parallelism); - mul::matmul( - bot.rb_mut(), - B, - top.rb(), - Some(E::faer_one()), - E::faer_one().faer_neg(), - parallelism, - ); - } - - #[inline(never)] - fn snode_lu( - kcol: usize, - _jsupno: usize, - fsupc: usize, - transpositions: &mut [I], - lu: &mut SupernodalLu, - parallelism: Parallelism, - ) { - let lda = lu.xlusup[fsupc + 1].zx() - lu.xlusup[fsupc].zx(); - let nsupr = (lu.xlsub[fsupc + 1] - lu.xlsub[fsupc]).zx(); - let nsupc = kcol - fsupc; - let luptr = lu.xlusup[fsupc].zx(); - - let lu_val = to_slice_group_mut::(&mut lu.lusup); - let lu_val_len = lu_val.len(); - let mut LU = faer_core::mat::from_column_major_slice_with_stride_mut::<'_, E>( - lu_val.subslice(luptr..lu_val_len).into_inner(), - nsupr, - nsupc, - lda, - ); - assert!(LU.nrows() >= LU.ncols()); - faer_lu::partial_pivoting::compute::lu_in_place_impl( - LU.rb_mut(), - 0, - nsupc, - transpositions, - parallelism, - ); - } - - #[track_caller] - pub fn factorize_supernodal_numeric_lu( - row_perm: &mut [I], - row_perm_inv: &mut [I], - col_perm: &mut [I], - col_perm_inv: &mut [I], - lu: &mut SupernodalLu, - work: &mut alloc::vec::Vec, - - A: SparseColMatRef<'_, I, E>, - fill_reducing_col_perm: PermutationRef<'_, I, E>, - etree: EliminationTreeRef<'_, I>, - - parallelism: Parallelism, - stack: PodStack<'_>, - params: SupernodalLuParams, - ) -> Result<(), LuError> { - let I = I::truncate; - - let m = A.nrows(); - let n = A.ncols(); - - assert!(row_perm.len() == m); - assert!(row_perm_inv.len() == m); - assert!(fill_reducing_col_perm.len() == n); - assert!(etree.into_inner().len() == n); - - let (row_perm, row_perm_inv) = (row_perm_inv, row_perm); - - let a_nnz = A.compute_nnz(); - lu.nrows = 0; - lu.ncols = 0; - lu.nsupernodes = 0; - - let maxpanel = params.panel_size.checked_mul(m).unwrap(); - - let (descendants, stack) = stack.make_raw::(n); - let (relax_end, mut stack) = stack.make_raw::(n); - - { - let (post, stack) = stack.rb_mut().make_raw::(n); - let (post_inv, mut stack) = stack.make_raw::(n); - - Size::with(n, |N| { - crate::qr::postorder::(post, etree, stack.rb_mut()); - for i in 0..n { - post_inv[post[i].zx()] = I(i); - } - relaxed_supernodes( - etree.ghost_inner(N), - Array::from_ref(Idx::from_slice_ref_checked(&post, N), N), - Array::from_ref(Idx::from_slice_ref_checked(&post_inv, N), N), - params.relax, - Array::from_mut(descendants, N), - Array::from_mut(bytemuck::cast_slice_mut(relax_end), N), - ); - }); - - for i in 0..n { - col_perm[i] = fill_reducing_col_perm.into_arrays().0[post[i].zx()]; - } - } - for i in 0..n { - col_perm_inv[col_perm[i].zx()] = I(i); - } - let col_perm = PermutationRef::new_checked(&col_perm, &col_perm_inv); - - let (repfnz, stack) = stack.make_raw::(maxpanel); - let (panel_lsub, stack) = stack.make_raw::(maxpanel); - let (marker, stack) = stack.make_raw::(m.checked_mul(3).unwrap()); - let (segrep, stack) = stack.make_raw::(m); - let (parent, stack) = stack.make_raw::(m); - let (xplore, stack) = stack.make_raw::(m); - let (xprune, stack) = stack.make_raw::(n); - let (transpositions, stack) = stack.make_raw::(n); - - let (mut dense, _) = crate::make_raw::(maxpanel, stack); - - lu.mem_init(m, n, a_nnz, params.fill_factor)?; - - mem::fill_none::(bytemuck::cast_slice_mut(row_perm)); - - mem::fill_zero(segrep); - mem::fill_zero(parent); - mem::fill_zero(xplore); - mem::fill_zero(xprune); - - mem::fill_none(marker); - mem::fill_none(repfnz); - mem::fill_none(panel_lsub); - - dense.fill_zero(); - - lu.supno[0] = I(NONE); - lu.xlsub[0] = I(0); - lu.xusub[0] = I(0); - lu.xlusup[0] = I(0); - mem::fill_zero(&mut lu.xsup); - mem::fill_none::(bytemuck::cast_slice_mut(&mut lu.supno)); - - let mut jcol = 0; - - let diag_pivot_thresh = E::Real::faer_one(); - while jcol < n { - if relax_end[jcol] != I(NONE) && relax_end[jcol].zx() - jcol + 1 > 16 { - let kcol = relax_end[jcol].zx() + 1; - snode_dfs( - jcol, - kcol, - A, - row_perm, - col_perm.into_arrays().0, - xprune, - marker, - lu, - )?; - - let nextu = lu.xusub[jcol].zx(); - let nextlu = lu.xlusup[jcol].zx(); - let jsupno = lu.supno[jcol].zx(); - let fsupc = lu.xsup[jsupno].zx(); - let new_next = - nextlu + (lu.xlsub[fsupc + 1] - lu.xlsub[fsupc]).zx() * (kcol - jcol); - resize_scalar::(&mut lu.lusup, new_next, false, false)?; - - for icol in jcol..kcol { - lu.xusub[icol + 1] = I(nextu); - - let icol_p = col_perm.into_arrays().0[icol].zx(); - for (i, val) in zip( - A.row_indices_of_col(icol_p), - SliceGroup::<'_, E>::new(A.values_of_col(icol_p)).into_ref_iter(), - ) { - dense.write(i, val.read()); - } - - // TODO: handle non leaf nodes properly - snode_bmod_setup(icol, fsupc, dense.rb_mut(), lu); - } - let transpositions = &mut transpositions[jcol..kcol]; - snode_lu(kcol, jsupno, fsupc, transpositions, lu, parallelism); - for (idx, t) in transpositions.iter().enumerate() { - let t = t.zx(); - let j = jcol + idx; - let lptr = lu.xlsub[fsupc].zx(); - let pivrow = lu.lsub[lptr + idx + t].zx(); - row_perm[pivrow] = I(j); - lu.lsub.swap(lptr + idx, lptr + idx + t); - } - - jcol = kcol; - } else { - let mut panel_size = params.panel_size; - - let mut nseg1 = 0; - let mut nseg; - - { - let mut k = jcol + 1; - while k < Ord::min(jcol.saturating_add(panel_size), n) { - if relax_end[k] != I(NONE) { - panel_size = k - jcol; - break; - } - k += 1; - } - if k == n { - panel_size = n - jcol; - } - } - - panel_dfs( - m, - panel_size, - jcol, - A, - col_perm.into_arrays().0, - row_perm, - &mut nseg1, - dense.rb_mut(), - panel_lsub, - segrep, - repfnz, - xprune, - marker, - parent, - xplore, - lu, - ); - panel_bmod( - m, - panel_size, - jcol, - nseg1, - dense.rb_mut(), - work, - segrep, - repfnz, - lu, - parallelism, - )?; - - for jj in jcol..jcol + panel_size { - let k = (jj - jcol) * m; - nseg = nseg1; - - let panel_lsubk = &mut panel_lsub[k..][..m]; - let repfnz_k = &mut repfnz[k..][..m]; - column_dfs( - m, - jj, - row_perm, - params.max_super, - &mut nseg, - panel_lsubk, - segrep, - repfnz_k, - xprune, - marker, - parent, - xplore, - lu, - )?; - - let mut dense_k = dense.rb_mut().subslice(k..k + m); - let segrep_k = &mut segrep[nseg1..m]; - - column_bmod( - jj, - nseg - nseg1, - dense_k.rb_mut(), - work, - segrep_k, - repfnz_k, - jcol, - lu, - )?; - copy_to_ucol(jj, nseg, segrep, repfnz_k, row_perm, dense_k, lu)?; - - let mut pivrow = 0usize; - if pivot_l( - jj, - diag_pivot_thresh, - row_perm, - col_perm.inverse(), - &mut pivrow, - lu, - ) { - return Err(LuError::ZeroColumn(jj)); - } - - prune_l(jj, row_perm, pivrow, nseg, segrep, repfnz_k, xprune, lu); - - for &irep in &segrep[..nseg] { - repfnz_k[irep.zx()] = I(NONE).to_signed(); - } - } - jcol += panel_size; - } - } - - for i in 0..m { - row_perm_inv[row_perm[i].zx()] = I(i); - } - - let mut nextl = 0usize; - let nsuper = lu.supno[n].zx(); - - for i in 0..nsuper + 1 { - let fsupc = lu.xsup[i].zx(); - let jstart = lu.xlsub[fsupc].zx(); - lu.xlsub[fsupc] = I(nextl); - for j in jstart..lu.xlsub[fsupc + 1].zx() { - lu.lsub[nextl] = row_perm[lu.lsub[j].zx()]; - nextl += 1; - } - for k in fsupc + 1..lu.xsup[i + 1].zx() { - lu.xlsub[k] = I(nextl); - } - } - lu.xlsub[n] = I(nextl); - lu.nrows = m; - lu.ncols = n; - lu.nsupernodes = nsuper + 1; - - Ok(()) - } -} - -#[cfg(test)] -#[cfg(__false)] -mod tests { - use super::{supernodal::SupernodalLu, *}; - use crate::{ - cholesky::simplicial::EliminationTreeRef, - ghost, - mem::NONE, - qr::{col_etree, ghost_col_etree}, - superlu::supernodal::factorize_supernodal_numeric_lu, - SymbolicSparseColMatRef, - }; - use core::iter::zip; - use dyn_stack::{GlobalPodBuffer, PodStack, StackReq}; - use faer_core::{ - assert, c64, group_helpers::SliceGroup, permutation::PermutationRef, - sparse::SparseColMatRef, Conj, Mat, - }; - use faer_entity::ComplexField; - use matrix_market_rs::MtxData; - use rand::{Rng, SeedableRng}; - - fn sparse_to_dense(sparse: SparseColMatRef<'_, I, E>) -> Mat { - let m = sparse.nrows(); - let n = sparse.ncols(); - - let mut dense = Mat::::zeros(m, n); - let slice_group = SliceGroup::<'_, E>::new; - - for j in 0..n { - for (i, val) in zip( - sparse.row_indices_of_col(j), - slice_group(sparse.values_of_col(j)).into_ref_iter(), - ) { - dense.write(i, j, val.read()); - } - } - - dense - } - - fn test_supernodes() { - let I = I::truncate; - - let n = 11; - let col_ptr = &[0, 3, 6, 10, 13, 16, 21, 24, 29, 31, 37, 43].map(I); - let row_ind = &[ - 0, 5, 6, // 0 - 1, 2, 7, // 1 - 1, 2, 9, 10, // 2 - 3, 5, 9, // 3 - 4, 7, 10, // 4 - 0, 3, 5, 8, 9, // 5 - 0, 6, 10, // 6 - 1, 4, 7, 9, 10, // 7 - 5, 8, // 8 - 2, 3, 5, 7, 9, 10, // 9 - 2, 4, 6, 7, 9, 10, // 10 - ] - .map(I); - - let A = SymbolicSparseColMatRef::new_checked(n, n, col_ptr, None, row_ind); - let zero = I(0); - let mut etree = vec![zero.to_signed(); n]; - let mut descendants = vec![zero; n]; - let mut relax_end = vec![zero.to_signed(); n]; - ghost::with_size(n, |N| { - let A = ghost::SymbolicSparseColMatRef::new(A, N, N); - ghost_col_etree( - A, - None, - Array::from_mut(&mut etree, N), - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(2 * *N))), - ); - let mut post = vec![I(0); n]; - let mut post_inv = vec![I(0); n]; - crate::qr::postorder::( - &mut post, - EliminationTreeRef { inner: &etree }, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(3 * *N))), - ); - for i in 0..n { - post_inv[post[i].zx()] = I(i); - } - relaxed_supernodes( - Array::from_ref(MaybeIdx::from_slice_ref_checked(&etree, N), N), - Array::from_ref(Idx::from_slice_ref_checked(&post, N), N), - Array::from_ref(Idx::from_slice_ref_checked(&post_inv, N), N), - 1, - Array::from_mut(&mut descendants, N), - Array::from_mut(&mut relax_end, N), - ); - }); - assert!( - etree - == [3, 2, 3, 4, 5, 6, 7, 8, 9, 10, NONE] - .map(I) - .map(I::to_signed) - ); - assert!( - relax_end - == [0, 1, NONE, NONE, NONE, NONE, NONE, NONE, NONE, NONE, NONE] - .map(I) - .map(I::to_signed) - ); - } - - #[test] - fn test_numeric_lu_tiny() { - let n = 8; - let col_ptr = &[0, 4, 8, 12, 16, 21, 23, 27, 29]; - let row_ind = &[ - 1, 4, 6, 7, 0, 5, 6, 7, 0, 1, 3, 6, 0, 1, 4, 5, 0, 2, 5, 6, 7, 3, 4, 0, 4, 6, 7, 2, 4, - ]; - let val = &[ - 0.783099, 0.335223, 0.55397, 0.628871, 0.513401, 0.606969, 0.242887, 0.804177, - 0.400944, 0.108809, 0.512932, 0.637552, 0.972775, 0.771358, 0.891529, 0.352458, - 0.949327, 0.192214, 0.0641713, 0.457702, 0.23828, 0.53976, 0.760249, 0.437638, - 0.738534, 0.687861, 0.440105, 0.228968, 0.68667, - ]; - let A = SparseColMatRef::<'_, u32, f64>::new( - SymbolicSparseColMatRef::new_checked(n, n, col_ptr, None, row_ind), - val, - ); - - let mut etree = vec![0u32; n]; - let etree = col_etree( - *A, - None, - &mut etree, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(2 * n))), - ); - - let mut row_perm = vec![0u32; n]; - let mut row_perm_inv = vec![0u32; n]; - let mut col_perm = vec![0u32; n]; - let mut col_perm_inv = vec![0u32; n]; - let mut fill_col_perm = vec![0u32; n]; - let mut fill_col_perm_inv = vec![0u32; n]; - for i in 0..n { - fill_col_perm[i] = i as _; - fill_col_perm_inv[i] = i as _; - } - let fill_col_perm = PermutationRef::new_checked(&fill_col_perm, &fill_col_perm_inv); - - let mut lu = SupernodalLu::::new(); - - factorize_supernodal_numeric_lu( - &mut row_perm, - &mut row_perm_inv, - &mut col_perm, - &mut col_perm_inv, - &mut lu, - &mut vec![], - A, - fill_col_perm, - etree, - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(1024 * 1024))), - Default::default(), - ) - .unwrap(); - { - let row_perm = PermutationRef::<'_, _, f64>::new_checked(&row_perm, &row_perm_inv); - let col_perm = PermutationRef::<'_, _, f64>::new_checked(&col_perm, &col_perm_inv); - let mut gen = rand::rngs::StdRng::seed_from_u64(0); - let A_dense = sparse_to_dense(A); - let k = 2; - let rhs = Mat::from_fn(n, k, |_, _| gen.gen::()); - let mut x = rhs.clone(); - - lu.solve_in_place_with_conj( - row_perm, - col_perm, - Conj::No, - x.as_mut(), - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - ); - dbgf::dbgf!("?", &A_dense * &x - &rhs); - } - } - - #[test] - fn test_numeric_lu_small() { - let n = 16; - let col_ptr = &[ - 0, 2, 10, 15, 19, 20, 27, 34, 39, 44, 45, 53, 57, 61, 68, 71, 75, - ]; - let row_ind = &[ - 5, 7, 0, 1, 2, 3, 4, 6, 9, 13, 2, 5, 7, 11, 13, 1, 7, 11, 15, 3, 1, 3, 7, 10, 11, 13, - 14, 7, 9, 11, 12, 13, 14, 15, 6, 8, 9, 13, 15, 0, 1, 2, 5, 14, 6, 1, 3, 4, 5, 8, 10, - 13, 15, 6, 9, 14, 15, 4, 9, 12, 13, 0, 5, 6, 7, 8, 10, 15, 5, 7, 12, 2, 5, 9, 11, - ]; - let mut gen = rand::rngs::StdRng::seed_from_u64(0); - let val: &[c64; 75] = &core::array::from_fn(|_| c64::new(gen.gen(), gen.gen())); - let A = SparseColMatRef::<'_, usize, c64>::new( - SymbolicSparseColMatRef::new_checked(n, n, col_ptr, None, row_ind), - val, - ); - - let mut etree = vec![0usize; n]; - let etree = col_etree( - *A, - None, - &mut etree, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(2 * n))), - ); - - let mut row_perm = vec![0usize; n]; - let mut row_perm_inv = vec![0usize; n]; - let mut col_perm = vec![0usize; n]; - let mut col_perm_inv = vec![0usize; n]; - let fill_col_perm = vec![6, 2, 8, 7, 11, 4, 12, 14, 5, 3, 0, 9, 13, 15, 10, 1]; - let mut fill_col_perm_inv = vec![0usize; n]; - for i in 0..n { - fill_col_perm_inv[fill_col_perm[i]] = i; - } - let fill_col_perm = PermutationRef::new_checked(&fill_col_perm, &fill_col_perm_inv); - - let mut lu = SupernodalLu::::new(); - - factorize_supernodal_numeric_lu( - &mut row_perm, - &mut row_perm_inv, - &mut col_perm, - &mut col_perm_inv, - &mut lu, - &mut vec![], - A, - fill_col_perm, - etree, - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - Default::default(), - ) - .unwrap(); - let row_perm = PermutationRef::<'_, _, c64>::new_checked(&row_perm, &row_perm_inv); - let A_dense = sparse_to_dense(A); - let k = 1; - let rhs = Mat::from_fn(n, k, |_, _| c64::new(gen.gen(), gen.gen())); - - { - let mut x = rhs.clone(); - let col_perm = PermutationRef::<'_, _, c64>::new_checked(&col_perm, &col_perm_inv); - lu.solve_in_place_with_conj( - row_perm, - col_perm, - Conj::No, - x.as_mut(), - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - ); - assert!((&A_dense * &x - &rhs).norm_max() < 1e-14); - } - - { - let mut x = rhs.clone(); - let col_perm = PermutationRef::<'_, _, c64>::new_checked(&col_perm, &col_perm_inv); - lu.solve_transpose_in_place_with_conj( - row_perm, - col_perm, - Conj::No, - x.as_mut(), - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - ); - assert!((A_dense.transpose() * &x - &rhs).norm_max() < 1e-14); - } - - { - let mut x = rhs.clone(); - let col_perm = PermutationRef::<'_, _, c64>::new_checked(&col_perm, &col_perm_inv); - lu.solve_in_place_with_conj( - row_perm, - col_perm, - Conj::Yes, - x.as_mut(), - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - ); - assert!((A_dense.conjugate() * &x - &rhs).norm_max() < 1e-14); - } - - { - let mut x = rhs.clone(); - let col_perm = PermutationRef::<'_, _, c64>::new_checked(&col_perm, &col_perm_inv); - lu.solve_transpose_in_place_with_conj( - row_perm, - col_perm, - Conj::Yes, - x.as_mut(), - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - ); - assert!((A_dense.adjoint() * &x - &rhs).norm_max() < 1e-14); - } - } - - fn load_mtx( - data: MtxData, - ) -> ( - usize, - usize, - alloc::vec::Vec, - alloc::vec::Vec, - alloc::vec::Vec, - ) { - let I = I::truncate; - - let MtxData::Sparse([nrows, ncols], coo_indices, coo_values, _) = data else { - panic!() - }; - - let m = nrows; - let n = ncols; - let mut col_counts = vec![I(0); n]; - let mut col_ptr = vec![I(0); n + 1]; - - for &[i, j] in &coo_indices { - col_counts[j] += I(1); - if i != j { - col_counts[i] += I(1); - } - } - - for i in 0..n { - col_ptr[i + 1] = col_ptr[i] + col_counts[i]; - } - let nnz = col_ptr[n].zx(); - - let mut row_ind = vec![I(0); nnz]; - let mut values = vec![0.0; nnz]; - - col_counts.copy_from_slice(&col_ptr[..n]); - - for (&[i, j], &val) in zip(&coo_indices, &coo_values) { - if i == j { - values[col_counts[j].zx()] = 2.0 * val; - } else { - values[col_counts[i].zx()] = val; - values[col_counts[j].zx()] = val; - } - - row_ind[col_counts[j].zx()] = I(i); - col_counts[j] += I(1); - - if i != j { - row_ind[col_counts[i].zx()] = I(j); - col_counts[i] += I(1); - } - } - - (m, n, col_ptr, row_ind, values) - } - - #[test] - fn test_numeric_lu_mtx() { - let (m, n, col_ptr, row_ind, val) = - load_mtx::(MtxData::from_file("bench_data/rijc781.mtx").unwrap()); - - let A = SparseColMatRef::<'_, usize, f64>::new( - SymbolicSparseColMatRef::new_checked(m, n, &col_ptr, None, &row_ind), - &val, - ); - - let mut etree = vec![0usize; n]; - let etree = col_etree( - *A, - None, - &mut etree, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(2 * n))), - ); - - let mut row_perm = vec![0usize; n]; - let mut row_perm_inv = vec![0usize; n]; - let mut col_perm = vec![0usize; n]; - let mut col_perm_inv = vec![0usize; n]; - let mut fill_col_perm = vec![0usize; n]; - let mut fill_col_perm_inv = vec![0usize; n]; - for i in 0..n { - fill_col_perm[i] = i; - fill_col_perm_inv[i] = i; - } - let fill_col_perm = PermutationRef::new_checked(&fill_col_perm, &fill_col_perm_inv); - - let mut lu = SupernodalLu::::new(); - - factorize_supernodal_numeric_lu( - &mut row_perm, - &mut row_perm_inv, - &mut col_perm, - &mut col_perm_inv, - &mut lu, - &mut vec![], - A, - fill_col_perm, - etree, - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - Default::default(), - ) - .unwrap(); - - { - let row_perm = PermutationRef::<'_, _, f64>::new_checked(&row_perm, &row_perm_inv); - let col_perm = PermutationRef::<'_, _, f64>::new_checked(&col_perm, &col_perm_inv); - let mut gen = rand::rngs::StdRng::seed_from_u64(0); - let A_dense = sparse_to_dense(A); - let k = 2; - let rhs = Mat::from_fn(n, k, |_, _| gen.gen::()); - let mut x = rhs.clone(); - - lu.solve_in_place_with_conj( - row_perm, - col_perm, - Conj::No, - x.as_mut(), - faer_core::Parallelism::None, - PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::( - 1024 * 1024, - ))), - ); - assert!((&A_dense * &x - &rhs).norm_max() < 1e-14); - } - } - - monomorphize_test!(test_supernodes); -} diff --git a/faer-libs/faer-svd/Cargo.toml b/faer-libs/faer-svd/Cargo.toml deleted file mode 100644 index 1be95c7b4dabd527c72f948e2db1233e378241c9..0000000000000000000000000000000000000000 --- a/faer-libs/faer-svd/Cargo.toml +++ /dev/null @@ -1,60 +0,0 @@ -[package] -name = "faer-svd" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] - -[dependencies] -faer-entity = { workspace = true, default-features = false } - -faer-core = { version = "0.17.1", default-features = false, path = "../faer-core" } -faer-qr = { version = "0.17.1", default-features = false, path = "../faer-qr" } - -coe-rs = { workspace = true } -reborrow = { workspace = true } -pulp = { workspace = true, default-features = false } -dyn-stack = { workspace = true, default-features = false } - -num-traits = { workspace = true, default-features = false } -num-complex = { workspace = true, default-features = false } -bytemuck = { workspace = true } - -log = { workspace = true, optional = true, default-features = false } -dbgf = "0.1.1" - -[dev-dependencies] -criterion = "0.5" -rand = "0.8.5" -nalgebra = "0.32.3" -assert_approx_eq = "1.1.0" - -[features] -default = ["std", "rayon"] -std = [ - "faer-core/std", - "faer-qr/std", - "pulp/std", -] -perf-warn = ["log", "faer-core/perf-warn"] -rayon = [ - "std", - "faer-core/rayon", - "faer-qr/rayon", -] -nightly = [ - "faer-core/nightly", - "faer-qr/nightly", - "pulp/nightly", -] - -[[bench]] -name = "bench" -harness = false - -[package.metadata.docs.rs] -rustdoc-args = ["--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer-svd/LICENSE.MIT b/faer-libs/faer-svd/LICENSE.MIT deleted file mode 100644 index b3e9659c8860f4d82899554c214b91d46760ea59..0000000000000000000000000000000000000000 --- a/faer-libs/faer-svd/LICENSE.MIT +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2022 sarah - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/faer-libs/faer-svd/LICENSE.MPL2 b/faer-libs/faer-svd/LICENSE.MPL2 deleted file mode 100644 index ee6256cdb62a765749a71aae3abea32884301cd1..0000000000000000000000000000000000000000 --- a/faer-libs/faer-svd/LICENSE.MPL2 +++ /dev/null @@ -1,373 +0,0 @@ -Mozilla Public License Version 2.0 -================================== - -1. Definitions --------------- - -1.1. "Contributor" - means each individual or legal entity that creates, contributes to - the creation of, or owns Covered Software. - -1.2. "Contributor Version" - means the combination of the Contributions of others (if any) used - by a Contributor and that particular Contributor's Contribution. - -1.3. "Contribution" - means Covered Software of a particular Contributor. - -1.4. "Covered Software" - means Source Code Form to which the initial Contributor has attached - the notice in Exhibit A, the Executable Form of such Source Code - Form, and Modifications of such Source Code Form, in each case - including portions thereof. - -1.5. "Incompatible With Secondary Licenses" - means - - (a) that the initial Contributor has attached the notice described - in Exhibit B to the Covered Software; or - - (b) that the Covered Software was made available under the terms of - version 1.1 or earlier of the License, but not also under the - terms of a Secondary License. - -1.6. "Executable Form" - means any form of the work other than Source Code Form. - -1.7. "Larger Work" - means a work that combines Covered Software with other material, in - a separate file or files, that is not Covered Software. - -1.8. "License" - means this document. - -1.9. "Licensable" - means having the right to grant, to the maximum extent possible, - whether at the time of the initial grant or subsequently, any and - all of the rights conveyed by this License. - -1.10. "Modifications" - means any of the following: - - (a) any file in Source Code Form that results from an addition to, - deletion from, or modification of the contents of Covered - Software; or - - (b) any new file in Source Code Form that contains any Covered - Software. - -1.11. "Patent Claims" of a Contributor - means any patent claim(s), including without limitation, method, - process, and apparatus claims, in any patent Licensable by such - Contributor that would be infringed, but for the grant of the - License, by the making, using, selling, offering for sale, having - made, import, or transfer of either its Contributions or its - Contributor Version. - -1.12. "Secondary License" - means either the GNU General Public License, Version 2.0, the GNU - Lesser General Public License, Version 2.1, the GNU Affero General - Public License, Version 3.0, or any later versions of those - licenses. - -1.13. "Source Code Form" - means the form of the work preferred for making modifications. - -1.14. "You" (or "Your") - means an individual or a legal entity exercising rights under this - License. For legal entities, "You" includes any entity that - controls, is controlled by, or is under common control with You. For - purposes of this definition, "control" means (a) the power, direct - or indirect, to cause the direction or management of such entity, - whether by contract or otherwise, or (b) ownership of more than - fifty percent (50%) of the outstanding shares or beneficial - ownership of such entity. - -2. License Grants and Conditions --------------------------------- - -2.1. Grants - -Each Contributor hereby grants You a world-wide, royalty-free, -non-exclusive license: - -(a) under intellectual property rights (other than patent or trademark) - Licensable by such Contributor to use, reproduce, make available, - modify, display, perform, distribute, and otherwise exploit its - Contributions, either on an unmodified basis, with Modifications, or - as part of a Larger Work; and - -(b) under Patent Claims of such Contributor to make, use, sell, offer - for sale, have made, import, and otherwise transfer either its - Contributions or its Contributor Version. - -2.2. Effective Date - -The licenses granted in Section 2.1 with respect to any Contribution -become effective for each Contribution on the date the Contributor first -distributes such Contribution. - -2.3. Limitations on Grant Scope - -The licenses granted in this Section 2 are the only rights granted under -this License. No additional rights or licenses will be implied from the -distribution or licensing of Covered Software under this License. -Notwithstanding Section 2.1(b) above, no patent license is granted by a -Contributor: - -(a) for any code that a Contributor has removed from Covered Software; - or - -(b) for infringements caused by: (i) Your and any other third party's - modifications of Covered Software, or (ii) the combination of its - Contributions with other software (except as part of its Contributor - Version); or - -(c) under Patent Claims infringed by Covered Software in the absence of - its Contributions. - -This License does not grant any rights in the trademarks, service marks, -or logos of any Contributor (except as may be necessary to comply with -the notice requirements in Section 3.4). - -2.4. Subsequent Licenses - -No Contributor makes additional grants as a result of Your choice to -distribute the Covered Software under a subsequent version of this -License (see Section 10.2) or under the terms of a Secondary License (if -permitted under the terms of Section 3.3). - -2.5. Representation - -Each Contributor represents that the Contributor believes its -Contributions are its original creation(s) or it has sufficient rights -to grant the rights to its Contributions conveyed by this License. - -2.6. Fair Use - -This License is not intended to limit any rights You have under -applicable copyright doctrines of fair use, fair dealing, or other -equivalents. - -2.7. Conditions - -Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted -in Section 2.1. - -3. Responsibilities -------------------- - -3.1. Distribution of Source Form - -All distribution of Covered Software in Source Code Form, including any -Modifications that You create or to which You contribute, must be under -the terms of this License. You must inform recipients that the Source -Code Form of the Covered Software is governed by the terms of this -License, and how they can obtain a copy of this License. You may not -attempt to alter or restrict the recipients' rights in the Source Code -Form. - -3.2. Distribution of Executable Form - -If You distribute Covered Software in Executable Form then: - -(a) such Covered Software must also be made available in Source Code - Form, as described in Section 3.1, and You must inform recipients of - the Executable Form how they can obtain a copy of such Source Code - Form by reasonable means in a timely manner, at a charge no more - than the cost of distribution to the recipient; and - -(b) You may distribute such Executable Form under the terms of this - License, or sublicense it under different terms, provided that the - license for the Executable Form does not attempt to limit or alter - the recipients' rights in the Source Code Form under this License. - -3.3. Distribution of a Larger Work - -You may create and distribute a Larger Work under terms of Your choice, -provided that You also comply with the requirements of this License for -the Covered Software. If the Larger Work is a combination of Covered -Software with a work governed by one or more Secondary Licenses, and the -Covered Software is not Incompatible With Secondary Licenses, this -License permits You to additionally distribute such Covered Software -under the terms of such Secondary License(s), so that the recipient of -the Larger Work may, at their option, further distribute the Covered -Software under the terms of either this License or such Secondary -License(s). - -3.4. Notices - -You may not remove or alter the substance of any license notices -(including copyright notices, patent notices, disclaimers of warranty, -or limitations of liability) contained within the Source Code Form of -the Covered Software, except that You may alter any license notices to -the extent required to remedy known factual inaccuracies. - -3.5. Application of Additional Terms - -You may choose to offer, and to charge a fee for, warranty, support, -indemnity or liability obligations to one or more recipients of Covered -Software. However, You may do so only on Your own behalf, and not on -behalf of any Contributor. You must make it absolutely clear that any -such warranty, support, indemnity, or liability obligation is offered by -You alone, and You hereby agree to indemnify every Contributor for any -liability incurred by such Contributor as a result of warranty, support, -indemnity or liability terms You offer. You may include additional -disclaimers of warranty and limitations of liability specific to any -jurisdiction. - -4. Inability to Comply Due to Statute or Regulation ---------------------------------------------------- - -If it is impossible for You to comply with any of the terms of this -License with respect to some or all of the Covered Software due to -statute, judicial order, or regulation then You must: (a) comply with -the terms of this License to the maximum extent possible; and (b) -describe the limitations and the code they affect. Such description must -be placed in a text file included with all distributions of the Covered -Software under this License. Except to the extent prohibited by statute -or regulation, such description must be sufficiently detailed for a -recipient of ordinary skill to be able to understand it. - -5. Termination --------------- - -5.1. The rights granted under this License will terminate automatically -if You fail to comply with any of its terms. However, if You become -compliant, then the rights granted under this License from a particular -Contributor are reinstated (a) provisionally, unless and until such -Contributor explicitly and finally terminates Your grants, and (b) on an -ongoing basis, if such Contributor fails to notify You of the -non-compliance by some reasonable means prior to 60 days after You have -come back into compliance. Moreover, Your grants from a particular -Contributor are reinstated on an ongoing basis if such Contributor -notifies You of the non-compliance by some reasonable means, this is the -first time You have received notice of non-compliance with this License -from such Contributor, and You become compliant prior to 30 days after -Your receipt of the notice. - -5.2. If You initiate litigation against any entity by asserting a patent -infringement claim (excluding declaratory judgment actions, -counter-claims, and cross-claims) alleging that a Contributor Version -directly or indirectly infringes any patent, then the rights granted to -You by any and all Contributors for the Covered Software under Section -2.1 of this License shall terminate. - -5.3. In the event of termination under Sections 5.1 or 5.2 above, all -end user license agreements (excluding distributors and resellers) which -have been validly granted by You or Your distributors under this License -prior to termination shall survive termination. - -************************************************************************ -* * -* 6. Disclaimer of Warranty * -* ------------------------- * -* * -* Covered Software is provided under this License on an "as is" * -* basis, without warranty of any kind, either expressed, implied, or * -* statutory, including, without limitation, warranties that the * -* Covered Software is free of defects, merchantable, fit for a * -* particular purpose or non-infringing. The entire risk as to the * -* quality and performance of the Covered Software is with You. * -* Should any Covered Software prove defective in any respect, You * -* (not any Contributor) assume the cost of any necessary servicing, * -* repair, or correction. This disclaimer of warranty constitutes an * -* essential part of this License. No use of any Covered Software is * -* authorized under this License except under this disclaimer. * -* * -************************************************************************ - -************************************************************************ -* * -* 7. Limitation of Liability * -* -------------------------- * -* * -* Under no circumstances and under no legal theory, whether tort * -* (including negligence), contract, or otherwise, shall any * -* Contributor, or anyone who distributes Covered Software as * -* permitted above, be liable to You for any direct, indirect, * -* special, incidental, or consequential damages of any character * -* including, without limitation, damages for lost profits, loss of * -* goodwill, work stoppage, computer failure or malfunction, or any * -* and all other commercial damages or losses, even if such party * -* shall have been informed of the possibility of such damages. This * -* limitation of liability shall not apply to liability for death or * -* personal injury resulting from such party's negligence to the * -* extent applicable law prohibits such limitation. Some * -* jurisdictions do not allow the exclusion or limitation of * -* incidental or consequential damages, so this exclusion and * -* limitation may not apply to You. * -* * -************************************************************************ - -8. Litigation -------------- - -Any litigation relating to this License may be brought only in the -courts of a jurisdiction where the defendant maintains its principal -place of business and such litigation shall be governed by laws of that -jurisdiction, without reference to its conflict-of-law provisions. -Nothing in this Section shall prevent a party's ability to bring -cross-claims or counter-claims. - -9. Miscellaneous ----------------- - -This License represents the complete agreement concerning the subject -matter hereof. If any provision of this License is held to be -unenforceable, such provision shall be reformed only to the extent -necessary to make it enforceable. Any law or regulation which provides -that the language of a contract shall be construed against the drafter -shall not be used to construe this License against a Contributor. - -10. Versions of the License ---------------------------- - -10.1. New Versions - -Mozilla Foundation is the license steward. Except as provided in Section -10.3, no one other than the license steward has the right to modify or -publish new versions of this License. Each version will be given a -distinguishing version number. - -10.2. Effect of New Versions - -You may distribute the Covered Software under the terms of the version -of the License under which You originally received the Covered Software, -or under the terms of any subsequent version published by the license -steward. - -10.3. Modified Versions - -If you create software not governed by this License, and you want to -create a new license for such software, you may create and use a -modified version of this License if you rename the license and remove -any references to the name of the license steward (except to note that -such modified license differs from this License). - -10.4. Distributing Source Code Form that is Incompatible With Secondary -Licenses - -If You choose to distribute Source Code Form that is Incompatible With -Secondary Licenses under the terms of this version of the License, the -notice described in Exhibit B of this License must be attached. - -Exhibit A - Source Code Form License Notice -------------------------------------------- - - This Source Code Form is subject to the terms of the Mozilla Public - License, v. 2.0. If a copy of the MPL was not distributed with this - file, You can obtain one at https://mozilla.org/MPL/2.0/. - -If it is not possible or desirable to put the notice in a particular -file, then You may include the notice in a location (such as a LICENSE -file in a relevant directory) where a recipient would be likely to look -for such a notice. - -You may add additional accurate notices of copyright ownership. - -Exhibit B - "Incompatible With Secondary Licenses" Notice ---------------------------------------------------------- - - This Source Code Form is "Incompatible With Secondary Licenses", as - defined by the Mozilla Public License, v. 2.0. diff --git a/faer-libs/faer-svd/benches/bench.rs b/faer-libs/faer-svd/benches/bench.rs deleted file mode 100644 index b722d74eac0efd40753b5d7f65e8a377e652b1c8..0000000000000000000000000000000000000000 --- a/faer-libs/faer-svd/benches/bench.rs +++ /dev/null @@ -1,274 +0,0 @@ -use criterion::{criterion_group, criterion_main, Criterion}; -use faer_svd::{ - bidiag::bidiagonalize_in_place, bidiag_real_svd::compute_bidiag_real_svd, compute_svd, - SvdParams, -}; -use std::time::Duration; - -use dyn_stack::*; -use rand::random; - -use faer_core::{Mat, Parallelism}; - -pub fn bidiag(c: &mut Criterion) { - for (m, n) in [ - (32, 32), - (64, 64), - (128, 128), - (256, 256), - (512, 512), - (1024, 1024), - (10000, 128), - (10000, 1024), - (2048, 2048), - (4096, 4096), - ] { - let mat = Mat::from_fn(m, n, |_, _| random::()); - - { - let mut copy = mat.clone(); - let mut householder_left = Mat::from_fn(n, 1, |_, _| random::()); - let mut householder_right = Mat::from_fn(n, 1, |_, _| random::()); - - let mut mem = GlobalPodBuffer::new( - faer_svd::bidiag::bidiagonalize_in_place_req::(m, n, Parallelism::None) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-st-bidiag-{m}x{n}"), |b| { - b.iter(|| { - copy.as_mut().copy_from(mat.as_ref()); - bidiagonalize_in_place( - copy.as_mut(), - householder_left.as_mut().col_mut(0).as_2d_mut(), - householder_right.as_mut().col_mut(0).as_2d_mut(), - Parallelism::None, - stack.rb_mut(), - ) - }) - }); - } - - { - let mut copy = mat.clone(); - let mut householder_left = Mat::from_fn(n, 1, |_, _| random::()); - let mut householder_right = Mat::from_fn(n, 1, |_, _| random::()); - - let mut mem = GlobalPodBuffer::new( - faer_svd::bidiag::bidiagonalize_in_place_req::(m, n, Parallelism::Rayon(0)) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-st-bidiag-{m}x{n}"), |b| { - b.iter(|| { - copy.as_mut().copy_from(mat.as_ref()); - bidiagonalize_in_place( - copy.as_mut(), - householder_left.as_mut().col_mut(0).as_2d_mut(), - householder_right.as_mut().col_mut(0).as_2d_mut(), - Parallelism::Rayon(0), - stack.rb_mut(), - ) - }) - }); - } - } -} - -fn bidiag_svd(c: &mut Criterion) { - for n in [32, 64, 128, 256, 1024, 4096] { - let diag = (0..n).map(|_| rand::random::()).collect::>(); - let subdiag = (0..n).map(|_| rand::random::()).collect::>(); - - { - let mut diag_copy = diag.clone(); - let mut subdiag_copy = subdiag.clone(); - - let mut u = Mat::zeros(n + 1, n + 1); - let mut v = Mat::zeros(n, n); - - let mut mem = GlobalPodBuffer::new( - faer_svd::bidiag_real_svd::bidiag_real_svd_req::( - n, - 4, - true, - true, - Parallelism::None, - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - - c.bench_function(&format!("faer-st-bidiag-svd-{n}"), |bencher| { - bencher.iter(|| { - diag_copy.clone_from_slice(&diag); - subdiag_copy.clone_from_slice(&subdiag); - let mut diag = (0..n).map(|_| rand::random::()).collect::>(); - let mut subdiag = (0..n).map(|_| rand::random::()).collect::>(); - compute_bidiag_real_svd( - &mut diag, - &mut subdiag, - Some(u.as_mut()), - Some(v.as_mut()), - 4, - 128, - f64::EPSILON, - f64::MIN_POSITIVE, - Parallelism::None, - stack.rb_mut(), - ); - }); - }); - } - let mut diag_copy = diag.clone(); - let mut subdiag_copy = subdiag.clone(); - - let mut u = Mat::zeros(n + 1, n + 1); - let mut v = Mat::zeros(n, n); - - let mut mem = GlobalPodBuffer::new( - faer_svd::bidiag_real_svd::bidiag_real_svd_req::( - n, - 4, - true, - true, - Parallelism::Rayon(0), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - { - c.bench_function(&format!("faer-mt-bidiag-svd-{n}"), |bencher| { - bencher.iter(|| { - diag_copy.clone_from_slice(&diag); - subdiag_copy.clone_from_slice(&subdiag); - compute_bidiag_real_svd( - &mut diag_copy, - &mut subdiag_copy, - Some(u.as_mut()), - Some(v.as_mut()), - 4, - 128, - f64::EPSILON, - f64::MIN_POSITIVE, - Parallelism::Rayon(0), - stack.rb_mut(), - ); - }); - }); - } - } -} - -fn real_svd(c: &mut Criterion) { - for (m, n) in [ - (8, 8), - (16, 16), - (32, 32), - (64, 64), - (128, 128), - (256, 256), - (512, 512), - (32, 4096), - (1024, 1024), - (10000, 128), - (10000, 1024), - (2048, 2048), - (4096, 4096), - ] { - let mat = Mat::from_fn(m, n, |_, _| rand::random::()); - let mat = mat.as_ref(); - - let size = m.min(n); - let mut s = Mat::zeros(size, 1); - let mut u = Mat::zeros(m, size); - let mut v = Mat::zeros(n, size); - - { - let mut mem = GlobalPodBuffer::new( - faer_svd::compute_svd_req::( - m, - n, - faer_svd::ComputeVectors::Full, - faer_svd::ComputeVectors::Full, - Parallelism::None, - SvdParams::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-st-svd-f64-{m}x{n}"), |bencher| { - bencher.iter(|| { - compute_svd( - mat, - s.as_mut().col_mut(0).as_2d_mut(), - Some(u.as_mut()), - Some(v.as_mut()), - Parallelism::None, - stack.rb_mut(), - SvdParams::default(), - ); - }); - }); - } - { - let mut mem = GlobalPodBuffer::new( - faer_svd::compute_svd_req::( - m, - n, - faer_svd::ComputeVectors::Full, - faer_svd::ComputeVectors::Full, - Parallelism::Rayon(0), - SvdParams::default(), - ) - .unwrap(), - ); - let mut stack = PodStack::new(&mut mem); - c.bench_function(&format!("faer-mt-svd-f64-{m}x{n}"), |bencher| { - bencher.iter(|| { - compute_svd( - mat, - s.as_mut().col_mut(0).as_2d_mut(), - Some(u.as_mut()), - Some(v.as_mut()), - Parallelism::Rayon(0), - stack.rb_mut(), - SvdParams::default(), - ); - }); - }); - } - } - - for (m, n) in [ - (8, 8), - (16, 16), - (32, 32), - (64, 64), - (128, 128), - (256, 256), - (512, 512), - (32, 4096), - ] { - let mut mat = nalgebra::DMatrix::::zeros(m, n); - for i in 0..m { - for j in 0..n { - mat[(i, j)] = random(); - } - } - - c.bench_function(&format!("nalgebra-svd-f64-{m}x{n}"), |bencher| { - bencher.iter(|| mat.clone().svd(true, true)) - }); - } -} - -criterion_group!( - name = benches; - config = Criterion::default() - .warm_up_time(Duration::from_secs(1)) - .measurement_time(Duration::from_secs(5)) - .sample_size(10); - targets = bidiag, bidiag_svd, real_svd, -); -criterion_main!(benches); diff --git a/faer-libs/faer-svd/katex-header.html b/faer-libs/faer-svd/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer-svd/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer/Cargo.toml b/faer-libs/faer/Cargo.toml deleted file mode 100644 index 0fa2e88777b3f9c020b7324a79f311f8280fc456..0000000000000000000000000000000000000000 --- a/faer-libs/faer/Cargo.toml +++ /dev/null @@ -1,107 +0,0 @@ -[package] -name = "faer" -version = "0.17.1" -edition = "2021" -authors = ["sarah <>"] -description = "Basic linear algebra routines" -readme = "../../README.md" -repository = "https://github.com/sarah-ek/faer-rs/" -license = "MIT" -keywords = ["math", "matrix", "linear-algebra"] - -[dependencies] -faer-core = { version = "0.17.1", path = "../faer-core", default-features = false } -faer-cholesky = { version = "0.17.1", path = "../faer-cholesky", default-features = false } -faer-lu = { version = "0.17.1", path = "../faer-lu", default-features = false } -faer-qr = { version = "0.17.1", path = "../faer-qr", default-features = false } -faer-svd = { version = "0.17.1", path = "../faer-svd", default-features = false } -faer-evd = { version = "0.17.1", path = "../faer-evd", default-features = false } -faer-sparse = { version = "0.17.1", path = "../faer-sparse", default-features = false } - -coe-rs = { workspace = true } -reborrow = { workspace = true } -pulp = { workspace = true, default-features = false } -dyn-stack = { workspace = true, default-features = false } -bytemuck = { workspace = true } - -dbgf = "0.1.1" - -num-complex = { workspace = true, default-features = false } - -matrixcompare = { version = "0.3", optional = true } - -nalgebra = { version = "0.32", optional = true, default-features = false } -ndarray = { version = "0.15", optional = true, default-features = false } -polars = { version = "0.37", optional = true, features = ["lazy"] } - -log = { workspace = true, optional = true, default-features = false } - -npyz = { version = "0.8", optional = true } - -[features] -default = ["std", "rayon", "matrixcompare"] -serde = ["faer-core/serde"] -std = [ - "faer-core/std", - "faer-cholesky/std", - "faer-lu/std", - "faer-qr/std", - "faer-svd/std", - "faer-evd/std", - "faer-sparse/std", - "matrixcompare", -] -perf-warn = [ - "log", - "faer-core/perf-warn", - "faer-cholesky/perf-warn", - "faer-lu/perf-warn", - "faer-qr/perf-warn", - "faer-svd/perf-warn", - "faer-evd/perf-warn", - "faer-sparse/perf-warn", -] -rayon = [ - "std", - "faer-core/rayon", - "faer-cholesky/rayon", - "faer-lu/rayon", - "faer-qr/rayon", - "faer-svd/rayon", - "faer-evd/rayon", - "faer-sparse/rayon", -] -nightly = [ - "faer-core/nightly", - "faer-cholesky/nightly", - "faer-lu/nightly", - "faer-qr/nightly", - "faer-svd/nightly", - "faer-evd/nightly", - "faer-sparse/nightly", - "pulp/nightly", -] -matrixcompare = ["dep:matrixcompare"] -nalgebra = ["dep:nalgebra"] -ndarray = ["dep:ndarray"] -polars = ["dep:polars"] -npy = ["std", "dep:npyz"] - -[dev-dependencies] -assert_approx_eq = "1.1.0" -rand = "0.8.5" -nalgebra = "0.32" -ndarray = "0.15" -polars = { version = "0.37", features = ["lazy", "parquet"] } - -[[example]] -name = "conversions" -required-features = ["nalgebra", "ndarray"] - -[[example]] -name = "polars" -required-features = ["polars" ] - -[package.metadata.docs.rs] -all-features = true -rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"] diff --git a/faer-libs/faer/examples/conversions.rs b/faer-libs/faer/examples/conversions.rs deleted file mode 100644 index 49b153e5f8bdc7d9d22f932fd3f281276afdce4e..0000000000000000000000000000000000000000 --- a/faer-libs/faer/examples/conversions.rs +++ /dev/null @@ -1,28 +0,0 @@ -use faer::{assert_matrix_eq, mat, IntoFaer, IntoNalgebra, IntoNdarray}; - -fn main() { - let matrix = mat![ - [2.28583329, -0.90628668, -1.71493024], - [-0.90628668, 4.00729077, 2.17332502], - [-1.71493024, 2.17332502, 1.97196187] - ]; - - let nalgebra = matrix.as_ref().into_nalgebra(); - let ndarray = matrix.as_ref().into_ndarray(); - - // compare multiplication using faer, with multiplication using nalgebra - assert_matrix_eq!( - &matrix * &matrix, - (nalgebra * nalgebra).view_range(.., ..).into_faer(), - comp = abs, - tol = 1e-14 - ); - - // compare addition using faer, with addition using ndarray - assert_matrix_eq!( - &matrix + &matrix, - (&ndarray + &ndarray).view().into_faer(), - comp = abs, - tol = 1e-14 - ); -} diff --git a/faer-libs/faer/examples/diabetes_data_raw.parquet b/faer-libs/faer/examples/diabetes_data_raw.parquet deleted file mode 100644 index bc91f12bb5d1b33d33cc229c56ed9fd4ff2c8616..0000000000000000000000000000000000000000 Binary files a/faer-libs/faer/examples/diabetes_data_raw.parquet and /dev/null differ diff --git a/faer-libs/faer/examples/iris.parquet b/faer-libs/faer/examples/iris.parquet deleted file mode 100644 index 3341b7f3e9f93edb14cd40e3a9f48a42a628c3b0..0000000000000000000000000000000000000000 Binary files a/faer-libs/faer/examples/iris.parquet and /dev/null differ diff --git a/faer-libs/faer/examples/lu.rs b/faer-libs/faer/examples/lu.rs deleted file mode 100644 index ed29c208bd487c0c424107a79df0d485da9d682c..0000000000000000000000000000000000000000 --- a/faer-libs/faer/examples/lu.rs +++ /dev/null @@ -1,23 +0,0 @@ -use faer::{assert_matrix_eq, mat, prelude::*, Mat}; - -fn main() { - let matrix = mat![ - [2.28583329, -0.90628668, -1.71493024], - [-0.90628668, 4.00729077, 2.17332502], - [-1.71493024, 2.17332502, 1.97196187] - ]; - - let lu = matrix.partial_piv_lu(); - - let rhs = mat![ - [-0.29945184, -0.5228196], - [0.84136125, -1.15768694], - [1.25678304, -0.46203532] - ]; - - let sol = lu.solve(&rhs); - let inv = lu.inverse(); - - assert_matrix_eq!(rhs, &matrix * &sol, comp = abs, tol = 1e-10); - assert_matrix_eq!(Mat::identity(3, 3), &matrix * &inv, comp = abs, tol = 1e-10); -} diff --git a/faer-libs/faer/examples/polars.rs b/faer-libs/faer/examples/polars.rs deleted file mode 100644 index a65df5365264bc36d5752a0f277798eb770159cc..0000000000000000000000000000000000000000 --- a/faer-libs/faer/examples/polars.rs +++ /dev/null @@ -1,29 +0,0 @@ -use faer::{dbgf, polars::polars_to_faer_f64}; -use polars::prelude::*; - -fn main() -> PolarsResult<()> { - let directory = "./faer/examples/"; - for filename in ["diabetes_data_raw.parquet", "iris.parquet"] { - dbg!(filename); - - let data = LazyFrame::scan_parquet( - format!("{directory}{filename}"), - ScanArgsParquet { - n_rows: None, - cache: true, - parallel: ParallelStrategy::Auto, - rechunk: true, - row_index: None, - low_memory: false, - cloud_options: None, - use_statistics: true, - ..Default::default() - }, - ) - .and_then(|lf| polars_to_faer_f64(lf)) - .unwrap(); - dbgf!("6.2?", data); - } - - Ok(()) -} diff --git a/faer-libs/faer/katex-header.html b/faer-libs/faer/katex-header.html deleted file mode 100644 index 32ac35a411428d1bcf1914b639299df9f86e448c..0000000000000000000000000000000000000000 --- a/faer-libs/faer/katex-header.html +++ /dev/null @@ -1,15 +0,0 @@ - - - - diff --git a/faer-libs/faer/src/lib.rs b/faer-libs/faer/src/lib.rs deleted file mode 100644 index 006b7a9b4259afce7360613f5a07cfdd99228f7e..0000000000000000000000000000000000000000 --- a/faer-libs/faer/src/lib.rs +++ /dev/null @@ -1,5225 +0,0 @@ -//! `faer` is a general-purpose linear algebra library for Rust, with a focus on high performance -//! for algebraic operations on medium/large matrices, as well as matrix decompositions. -//! -//! Most of the high-level functionality in this library is provided through associated functions in -//! its vocabulary types: [`Mat`]/[`MatRef`]/[`MatMut`], as well as the [`FaerMat`] extension trait. -//! The parent crates (`faer-core`, `faer-cholesky`, `faer-lu`, etc.), on the other hand, offer a -//! lower-level of abstraction in exchange for more control over memory allocations and -//! multithreading behavior. -//! -//! `faer` is recommended for applications that handle medium to large dense matrices, and its -//! design is not well suited for applications that operate mostly on low dimensional vectors and -//! matrices such as computer graphics or game development. For those purposes, `nalgebra` and -//! `cgmath` may provide better tools. -//! -//! # Basic usage -//! -//! [`Mat`] is a resizable matrix type with dynamic capacity, which can be created using -//! [`Mat::new`] to produce an empty $0\times 0$ matrix, [`Mat::zeros`] to create a rectangular -//! matrix filled with zeros, [`Mat::identity`] to create an identity matrix, or [`Mat::from_fn`] -//! for the most generic case. -//! -//! Given a `&Mat` (resp. `&mut Mat`), a [`MatRef<'_, E>`](MatRef) (resp. [`MatMut<'_, -//! E>`](MatMut)) can be created by calling [`Mat::as_ref`] (resp. [`Mat::as_mut`]), which allow -//! for more flexibility than `Mat` in that they allow slicing ([`MatRef::get`]) and splitting -//! ([`MatRef::split_at`]). -//! -//! `MatRef` and `MatMut` are lightweight view objects. The former can be copied freely while the -//! latter has move and reborrow semantics, as described in its documentation. -//! -//! More details about the vocabulary types can be found in the `faer-core` crate-level -//! documentation. See also: [`faer_core::Entity`] and [`faer_core::complex_native`]. -//! -//! Most of the matrix operations can be used through the corresponding math operators: `+` for -//! matrix addition, `-` for subtraction, `*` for either scalar or matrix multiplication depending -//! on the types of the operands. -//! -//! ## Example -//! ``` -//! use faer::{mat, prelude::*, scale, Mat}; -//! -//! let a = mat![ -//! [1.0, 5.0, 9.0], -//! [2.0, 6.0, 10.0], -//! [3.0, 7.0, 11.0], -//! [4.0, 8.0, 12.0f64], -//! ]; -//! -//! let b = Mat::::from_fn(4, 3, |i, j| (i + j) as f64); -//! -//! let add = &a + &b; -//! let sub = &a - &b; -//! let scale = scale(3.0) * &a; -//! let mul = &a * b.transpose(); -//! -//! let a00 = a[(0, 0)]; -//! ``` -//! -//! # Matrix decompositions -//! `faer` provides a variety of matrix factorizations, each with its own advantages and drawbacks: -//! -//! ## Cholesky decomposition -//! [`FaerMat::cholesky`] decomposes a self-adjoint positive definite matrix $A$ such that -//! $$A = LL^H,$$ -//! where $L$ is a lower triangular matrix. This decomposition is highly efficient and has good -//! stability properties. -//! -//! [An implementation for sparse matrices is also available.](sparse::solvers::Cholesky) -//! -//! ## Bunch-Kaufman decomposition -//! [`FaerMat::lblt`] decomposes a self-adjoint (possibly indefinite) matrix $A$ such that -//! $$P A P^\top = LBL^H,$$ -//! where $P$ is a permutation matrix, $L$ is a lower triangular matrix, and $B$ is a block -//! diagonal matrix, with $1 \times 1$ or $2 \times 2$ diagonal blocks. -//! This decomposition is efficient and has good stability properties. -//! ## LU decomposition with partial pivoting -//! [`FaerMat::partial_piv_lu`] decomposes a square invertible matrix $A$ into a lower triangular -//! matrix $L$, a unit upper triangular matrix $U$, and a permutation matrix $P$, such that -//! $$PA = LU.$$ -//! It is used by default for computing the determinant, and is generally the recommended method -//! for solving a square linear system or computing the inverse of a matrix (although we generally -//! recommend using a [`Solver`] instead of computing the inverse explicitly). -//! -//! [An implementation for sparse matrices is also available.](sparse::solvers::Lu) -//! -//! ## LU decomposition with full pivoting -//! [`FaerMat::full_piv_lu`] Decomposes a generic rectangular matrix $A$ into a lower triangular -//! matrix $L$, a unit upper triangular matrix $U$, and permutation matrices $P$ and $Q$, such that -//! $$PAQ^\top = LU.$$ -//! It can be more stable than the LU decomposition with partial pivoting, in exchange for being -//! more computationally expensive. -//! -//! ## QR decomposition -//! The QR decomposition ([`FaerMat::qr`]) decomposes a matrix $A$ into the product -//! $$A = QR,$$ -//! where $Q$ is a unitary matrix, and $R$ is an upper trapezoidal matrix. It is often used for -//! solving least squares problems. -//! -//! [An implementation for sparse matrices is also available.](sparse::solvers::Qr) -//! -//! ## QR decomposition with column pivoting -//! The QR decomposition with column pivoting ([`FaerMat::col_piv_qr`]) decomposes a matrix $A$ into -//! the product -//! $$AP^T = QR,$$ -//! where $P$ is a permutation matrix, $Q$ is a unitary matrix, and $R$ is an upper trapezoidal -//! matrix. -//! -//! It is slower than the version with no pivoting, in exchange for being more numerically stable -//! for rank-deficient matrices. -//! -//! ## Singular value decomposition -//! The SVD of a matrix $M$ of shape $(m, n)$ is a decomposition into three components $U$, $S$, -//! and $V$, such that: -//! -//! - $U$ has shape $(m, m)$ and is a unitary matrix, -//! - $V$ has shape $(n, n)$ and is a unitary matrix, -//! - $S$ has shape $(m, n)$ and is zero everywhere except the main diagonal, with nonnegative -//! diagonal values in nonincreasing order, -//! - and finally: -//! -//! $$M = U S V^H.$$ -//! -//! The SVD is provided in two forms: either the full matrices $U$ and $V$ are computed, using -//! [`FaerMat::svd`], or only their first $\min(m, n)$ columns are computed, using -//! [`FaerMat::thin_svd`]. -//! -//! If only the singular values (elements of $S$) are desired, they can be obtained in -//! nonincreasing order using [`FaerMat::singular_values`]. -//! -//! ## Eigendecomposition -//! **Note**: The order of the eigenvalues is currently unspecified and may be changed in a future -//! release. -//! -//! The eigendecomposition of a square matrix $M$ of shape $(n, n)$ is a decomposition into -//! two components $U$, $S$: -//! -//! - $U$ has shape $(n, n)$ and is invertible, -//! - $S$ has shape $(n, n)$ and is a diagonal matrix, -//! - and finally: -//! -//! $$M = U S U^{-1}.$$ -//! -//! If $M$ is hermitian, then $U$ can be made unitary ($U^{-1} = U^H$), and $S$ is real valued. -//! -//! Depending on the domain of the input matrix and whether it is self-adjoint, multiple methods -//! are provided to compute the eigendecomposition: -//! * [`FaerMat::selfadjoint_eigendecomposition`] can be used with either real or complex matrices, -//! producing an eigendecomposition of the same type. -//! * [`FaerMat::eigendecomposition`] can be used with either real or complex matrices, but the -//! output -//! complex type has to be specified. -//! * [`FaerMat::complex_eigendecomposition`] can only be used with complex matrices, with the -//! output -//! having the same type. -//! -//! If only the eigenvalues (elements of $S$) are desired, they can be obtained in -//! nonincreasing order using [`FaerMat::selfadjoint_eigenvalues`], [`FaerMat::eigenvalues`], or -//! [`FaerMat::complex_eigenvalues`], with the same conditions described above. -//! -//! # Crate features -//! -//! - `std`: enabled by default. Links with the standard library to enable additional features such -//! as cpu feature detection at runtime. -//! - `rayon`: enabled by default. Enables the `rayon` parallel backend and enables global -//! parallelism by default. -//! - `matrixcompare`: enabled by default. Enables macros for approximate equality checks on -//! matrices. -//! - `serde`: Enables serialization and deserialization of [`Mat`]. -//! - `npy`: Enables conversions to/from numpy's matrix file format. -//! - `perf-warn`: Produces performance warnings when matrix operations are called with suboptimal -//! data layout. -//! - `polars`: Enables basic interoperability with the `polars` crate. -//! - `nalgebra`: Enables basic interoperability with the `nalgebra` crate. -//! - `ndarray`: Enables basic interoperability with the `ndarray` crate. -//! - `nightly`: Requires the nightly compiler. Enables experimental SIMD features such as AVX512. - -#![cfg_attr(docsrs, feature(doc_cfg))] -#![cfg_attr(not(feature = "std"), no_std)] - -use dyn_stack::{GlobalPodBuffer, PodStack}; -use faer_core::{AsMatMut, AsMatRef, ComplexField, Conj, Conjugate, Entity}; -use prelude::*; -use solvers::*; - -/// Similar to the [`dbg`] macro, but takes a format spec as a first parameter. -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -pub use dbgf::dbgf; -pub use faer_cholesky::llt::CholeskyError; - -/// Re-exports. -pub mod modules { - pub use faer_cholesky as cholesky; - pub use faer_core as core; - pub use faer_evd as evd; - pub use faer_lu as lu; - pub use faer_qr as qr; - pub use faer_sparse as sparse; - pub use faer_svd as svd; -} - -/// Commonly used traits for a streamlined user experience. -pub mod prelude { - pub use crate::{ - solvers::{Solver, SolverCore, SolverLstsq, SolverLstsqCore}, - sparse::solvers::{SpSolver, SpSolverCore, SpSolverLstsq, SpSolverLstsqCore}, - FaerMat, IntoFaer, IntoFaerComplex, - }; - pub use reborrow::{IntoConst, Reborrow, ReborrowMut}; -} - -pub use faer_core::{ - col, complex_native, get_global_parallelism, mat, row, scale, set_global_parallelism, unzipped, - zipped, Col, ColMut, ColRef, Mat, MatMut, MatRef, Parallelism, Row, RowMut, RowRef, Side, -}; -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -pub use matrixcompare::assert_matrix_eq; - -extern crate alloc; -use alloc::{vec, vec::Vec}; - -/// Matrix solvers and decompositions. -pub mod solvers { - use super::*; - use faer_core::{assert, permutation::PermutationRef, zipped}; - use sparse::solvers::{SpSolverCore, SpSolverLstsqCore}; - - pub trait SolverCore: SpSolverCore { - /// Reconstructs the original matrix using the decomposition. - fn reconstruct(&self) -> Mat; - /// Computes the inverse of the original matrix using the decomposition. - /// - /// # Panics - /// Panics if the matrix is not square. - fn inverse(&self) -> Mat; - } - pub trait SolverLstsqCore: SolverCore + SpSolverLstsqCore {} - - pub trait Solver: SolverCore + SpSolver {} - pub trait SolverLstsq: SolverLstsqCore + SpSolverLstsq {} - - const _: () = { - fn __assert_object_safe() { - let _: Option<&dyn SolverCore> = None; - let _: Option<&dyn SolverLstsqCore> = None; - } - }; - - impl> SolverLstsq for Dec {} - - impl> Solver for Dec {} - - /// Cholesky decomposition. - pub struct Cholesky { - factors: Mat, - } - - /// Bunch-Kaufman decomposition. - pub struct Lblt { - factors: Mat, - subdiag: Mat, - perm: Vec, - perm_inv: Vec, - } - - /// LU decomposition with partial pivoting. - pub struct PartialPivLu { - pub(crate) factors: Mat, - row_perm: Vec, - row_perm_inv: Vec, - n_transpositions: usize, - } - /// LU decomposition with full pivoting. - pub struct FullPivLu { - factors: Mat, - row_perm: Vec, - row_perm_inv: Vec, - col_perm: Vec, - col_perm_inv: Vec, - n_transpositions: usize, - } - - /// QR decomposition. - pub struct Qr { - factors: Mat, - householder: Mat, - } - /// QR decomposition with column pivoting. - pub struct ColPivQr { - factors: Mat, - householder: Mat, - col_perm: Vec, - col_perm_inv: Vec, - } - - /// Singular value decomposition. - pub struct Svd { - s: Mat, - u: Mat, - v: Mat, - } - /// Thin singular value decomposition. - pub struct ThinSvd { - inner: Svd, - } - - /// Self-adjoint eigendecomposition. - pub struct SelfAdjointEigendecomposition { - s: Mat, - u: Mat, - } - - /// Complex eigendecomposition. - pub struct Eigendecomposition { - s: Col, - u: Mat, - } - - impl Cholesky { - #[track_caller] - pub fn try_new>( - matrix: MatRef<'_, ViewE>, - side: Side, - ) -> Result { - assert!(matrix.nrows() == matrix.ncols()); - - let dim = matrix.nrows(); - let parallelism = get_global_parallelism(); - - let mut factors = Mat::::zeros(dim, dim); - match side { - Side::Lower => { - zipped!(factors.as_mut(), matrix).for_each_triangular_lower( - faer_core::zip::Diag::Include, - |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), - ); - } - Side::Upper => { - zipped!(factors.as_mut(), matrix.adjoint()).for_each_triangular_lower( - faer_core::zip::Diag::Include, - |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), - ); - } - } - - let params = Default::default(); - - faer_cholesky::llt::compute::cholesky_in_place( - factors.as_mut(), - Default::default(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_cholesky::llt::compute::cholesky_in_place_req::( - dim, - parallelism, - params, - ) - .unwrap(), - )), - params, - )?; - Ok(Self { factors }) - } - - fn dim(&self) -> usize { - self.factors.nrows() - } - - pub fn compute_l(&self) -> Mat { - let mut factor = self.factors.to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_upper(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - } - } - impl SpSolverCore for Cholesky { - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_cholesky::llt::solve::solve_in_place_with_conj( - self.factors.as_ref(), - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_cholesky::llt::solve::solve_in_place_req::( - self.dim(), - rhs_ncols, - parallelism, - ) - .unwrap(), - )), - ); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - self.solve_in_place_with_conj_impl(rhs, conj.compose(Conj::Yes)) - } - - fn nrows(&self) -> usize { - self.factors.nrows() - } - - fn ncols(&self) -> usize { - self.factors.ncols() - } - } - impl SolverCore for Cholesky { - fn inverse(&self) -> Mat { - let mut inv = Mat::::zeros(self.dim(), self.dim()); - let parallelism = get_global_parallelism(); - - faer_cholesky::llt::inverse::invert_lower( - inv.as_mut(), - self.factors.as_ref(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_cholesky::llt::inverse::invert_lower_req::(self.dim(), parallelism) - .unwrap(), - )), - ); - - for j in 0..self.dim() { - for i in 0..j { - inv.write(i, j, inv.read(j, i).faer_conj()); - } - } - - inv - } - - fn reconstruct(&self) -> Mat { - let mut rec = Mat::::zeros(self.dim(), self.dim()); - let parallelism = get_global_parallelism(); - - faer_cholesky::llt::reconstruct::reconstruct_lower( - rec.as_mut(), - self.factors.as_ref(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_cholesky::llt::reconstruct::reconstruct_lower_req::(self.dim()) - .unwrap(), - )), - ); - - for j in 0..self.dim() { - for i in 0..j { - rec.write(i, j, rec.read(j, i).faer_conj()); - } - } - - rec - } - } - - impl Lblt { - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>, side: Side) -> Self { - assert!(matrix.nrows() == matrix.ncols()); - - let dim = matrix.nrows(); - let parallelism = get_global_parallelism(); - - let mut factors = Mat::::zeros(dim, dim); - let mut subdiag = Mat::::zeros(dim, 1); - let mut perm = vec![0; dim]; - let mut perm_inv = vec![0; dim]; - - match side { - Side::Lower => { - zipped!(factors.as_mut(), matrix).for_each_triangular_lower( - faer_core::zip::Diag::Include, - |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), - ); - } - Side::Upper => { - zipped!(factors.as_mut(), matrix.adjoint()).for_each_triangular_lower( - faer_core::zip::Diag::Include, - |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), - ); - } - } - - let params = Default::default(); - - faer_cholesky::bunch_kaufman::compute::cholesky_in_place( - factors.as_mut(), - subdiag.as_mut(), - Default::default(), - &mut perm, - &mut perm_inv, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_cholesky::bunch_kaufman::compute::cholesky_in_place_req::( - dim, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - Self { - factors, - subdiag, - perm, - perm_inv, - } - } - - fn dim(&self) -> usize { - self.factors.nrows() - } - } - - impl SpSolverCore for Lblt { - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_cholesky::bunch_kaufman::solve::solve_in_place_with_conj( - self.factors.as_ref(), - self.subdiag.as_ref(), - conj, - unsafe { PermutationRef::new_unchecked(&self.perm, &self.perm_inv) }, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_cholesky::bunch_kaufman::solve::solve_in_place_req::( - self.dim(), - rhs_ncols, - parallelism, - ) - .unwrap(), - )), - ); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - self.solve_in_place_with_conj_impl(rhs, conj.compose(Conj::Yes)) - } - - fn nrows(&self) -> usize { - self.factors.nrows() - } - - fn ncols(&self) -> usize { - self.factors.ncols() - } - } - impl SolverCore for Lblt { - fn inverse(&self) -> Mat { - let n = self.dim(); - let mut inv = Mat::identity(n, n); - self.solve_in_place_with_conj_impl(inv.as_mut(), Conj::No); - inv - } - - fn reconstruct(&self) -> Mat { - let parallelism = get_global_parallelism(); - let n = self.dim(); - let lbl = self.factors.as_ref(); - let subdiag = self.subdiag.as_ref(); - let mut mat = Mat::::identity(n, n); - let mut mat2 = Mat::::identity(n, n); - zipped!(mat.as_mut(), lbl).for_each_triangular_lower( - faer_core::zip::Diag::Skip, - |unzipped!(mut dst, src)| dst.write(src.read()), - ); - - let mut j = 0; - while j < n { - if subdiag.read(j, 0) == E::faer_zero() { - let d = lbl.read(j, j).faer_real().faer_inv(); - for i in 0..n { - mat.write(i, j, mat.read(i, j).faer_scale_real(d)); - } - j += 1; - } else { - let akp1k = subdiag.read(j, 0).faer_inv(); - let ak = akp1k.faer_scale_real(lbl.read(j, j).faer_real()); - let akp1 = akp1k - .faer_conj() - .faer_scale_real(lbl.read(j + 1, j + 1).faer_real()); - let denom = ak - .faer_mul(akp1) - .faer_sub(E::faer_one()) - .faer_real() - .faer_inv(); - - for i in 0..n { - let xk = mat.read(i, j).faer_mul(akp1k); - let xkp1 = mat.read(i, j + 1).faer_mul(akp1k.faer_conj()); - - mat.write( - i, - j, - (akp1.faer_mul(xk).faer_sub(xkp1)).faer_scale_real(denom), - ); - mat.write( - i, - j + 1, - (ak.faer_mul(xkp1).faer_sub(xk)).faer_scale_real(denom), - ); - } - j += 2; - } - } - faer_core::mul::triangular::matmul( - mat2.as_mut(), - faer_core::mul::triangular::BlockStructure::TriangularLower, - lbl, - faer_core::mul::triangular::BlockStructure::UnitTriangularLower, - mat.as_ref().adjoint(), - faer_core::mul::triangular::BlockStructure::Rectangular, - None, - E::faer_one(), - parallelism, - ); - - for j in 0..n { - let pj = self.perm_inv[j]; - for i in j..n { - let pi = self.perm_inv[i]; - - mat.write( - i, - j, - if pi >= pj { - mat2.read(pi, pj) - } else { - mat2.read(pj, pi).faer_conj() - }, - ); - } - } - - for j in 0..n { - mat.write(j, j, E::faer_from_real(mat.read(j, j).faer_real())); - for i in 0..j { - mat.write(i, j, mat.read(j, i).faer_conj()); - } - } - - mat - } - } - - impl PartialPivLu { - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { - assert!(matrix.nrows() == matrix.ncols()); - - let dim = matrix.nrows(); - let parallelism = get_global_parallelism(); - - let mut factors = matrix.to_owned(); - - let params = Default::default(); - - let mut row_perm = vec![0usize; dim]; - let mut row_perm_inv = vec![0usize; dim]; - - let (n_transpositions, _) = faer_lu::partial_pivoting::compute::lu_in_place( - factors.as_mut(), - &mut row_perm, - &mut row_perm_inv, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::partial_pivoting::compute::lu_in_place_req::( - dim, - dim, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - Self { - n_transpositions: n_transpositions.transposition_count, - factors, - row_perm, - row_perm_inv, - } - } - - fn dim(&self) -> usize { - self.factors.nrows() - } - - pub fn row_permutation(&self) -> PermutationRef<'_, usize, E> { - unsafe { PermutationRef::new_unchecked(&self.row_perm, &self.row_perm_inv) } - } - - pub fn transposition_count(&self) -> usize { - self.n_transpositions - } - - pub fn compute_l(&self) -> Mat { - let mut factor = self.factors.to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_upper(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - } - pub fn compute_u(&self) -> Mat { - let mut factor = self.factors.to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - .as_mut() - .diagonal_mut() - .column_vector_mut() - .fill(E::faer_one()); - factor - } - } - impl SpSolverCore for PartialPivLu { - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_lu::partial_pivoting::solve::solve_in_place( - self.factors.as_ref(), - conj, - self.row_permutation(), - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::partial_pivoting::solve::solve_in_place_req::( - self.dim(), - self.dim(), - rhs_ncols, - parallelism, - ) - .unwrap(), - )), - ); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_lu::partial_pivoting::solve::solve_transpose_in_place( - self.factors.as_ref(), - conj, - self.row_permutation(), - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::partial_pivoting::solve::solve_transpose_in_place_req::( - self.dim(), - self.dim(), - rhs_ncols, - parallelism, - ) - .unwrap(), - )), - ); - } - - fn nrows(&self) -> usize { - self.factors.nrows() - } - - fn ncols(&self) -> usize { - self.factors.ncols() - } - } - impl SolverCore for PartialPivLu { - fn inverse(&self) -> Mat { - let mut inv = Mat::::zeros(self.dim(), self.dim()); - let parallelism = get_global_parallelism(); - - faer_lu::partial_pivoting::inverse::invert( - inv.as_mut(), - self.factors.as_ref(), - self.row_permutation(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::partial_pivoting::inverse::invert_req::( - self.dim(), - self.dim(), - parallelism, - ) - .unwrap(), - )), - ); - - inv - } - - fn reconstruct(&self) -> Mat { - let mut rec = Mat::::zeros(self.dim(), self.dim()); - let parallelism = get_global_parallelism(); - - faer_lu::partial_pivoting::reconstruct::reconstruct( - rec.as_mut(), - self.factors.as_ref(), - self.row_permutation(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::partial_pivoting::reconstruct::reconstruct_req::( - self.dim(), - self.dim(), - parallelism, - ) - .unwrap(), - )), - ); - - rec - } - } - - impl FullPivLu { - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { - let m = matrix.nrows(); - let n = matrix.ncols(); - let parallelism = get_global_parallelism(); - - let mut factors = matrix.to_owned(); - - let params = Default::default(); - - let mut row_perm = vec![0usize; m]; - let mut row_perm_inv = vec![0usize; m]; - let mut col_perm = vec![0usize; n]; - let mut col_perm_inv = vec![0usize; n]; - - let (n_transpositions, _, _) = faer_lu::full_pivoting::compute::lu_in_place( - factors.as_mut(), - &mut row_perm, - &mut row_perm_inv, - &mut col_perm, - &mut col_perm_inv, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::full_pivoting::compute::lu_in_place_req::( - m, - n, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - Self { - factors, - row_perm, - row_perm_inv, - col_perm, - col_perm_inv, - n_transpositions: n_transpositions.transposition_count, - } - } - - pub fn row_permutation(&self) -> PermutationRef<'_, usize, E> { - unsafe { PermutationRef::new_unchecked(&self.row_perm, &self.row_perm_inv) } - } - pub fn col_permutation(&self) -> PermutationRef<'_, usize, E> { - unsafe { PermutationRef::new_unchecked(&self.col_perm, &self.col_perm_inv) } - } - - pub fn transposition_count(&self) -> usize { - self.n_transpositions - } - - pub fn compute_l(&self) -> Mat { - let size = Ord::min(self.nrows(), self.ncols()); - let mut factor = self - .factors - .as_ref() - .submatrix(0, 0, self.nrows(), size) - .to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_upper(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - } - pub fn compute_u(&self) -> Mat { - let size = Ord::min(self.nrows(), self.ncols()); - let mut factor = self - .factors - .as_ref() - .submatrix(0, 0, size, self.ncols()) - .to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - .as_mut() - .diagonal_mut() - .column_vector_mut() - .fill(E::faer_one()); - factor - } - } - impl SpSolverCore for FullPivLu { - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_lu::full_pivoting::solve::solve_in_place( - self.factors.as_ref(), - conj, - self.row_permutation(), - self.col_permutation(), - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::full_pivoting::solve::solve_in_place_req::( - self.nrows(), - self.ncols(), - rhs_ncols, - parallelism, - ) - .unwrap(), - )), - ); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_lu::full_pivoting::solve::solve_transpose_in_place( - self.factors.as_ref(), - conj, - self.row_permutation(), - self.col_permutation(), - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::full_pivoting::solve::solve_transpose_in_place_req::( - self.nrows(), - self.ncols(), - rhs_ncols, - parallelism, - ) - .unwrap(), - )), - ); - } - - fn nrows(&self) -> usize { - self.factors.nrows() - } - - fn ncols(&self) -> usize { - self.factors.ncols() - } - } - impl SolverCore for FullPivLu { - #[track_caller] - fn inverse(&self) -> Mat { - assert!(self.nrows() == self.ncols()); - - let dim = self.nrows(); - - let mut inv = Mat::::zeros(dim, dim); - let parallelism = get_global_parallelism(); - - faer_lu::full_pivoting::inverse::invert( - inv.as_mut(), - self.factors.as_ref(), - self.row_permutation(), - self.col_permutation(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::full_pivoting::inverse::invert_req::(dim, dim, parallelism) - .unwrap(), - )), - ); - - inv - } - - fn reconstruct(&self) -> Mat { - let mut rec = Mat::::zeros(self.nrows(), self.ncols()); - let parallelism = get_global_parallelism(); - - faer_lu::full_pivoting::reconstruct::reconstruct( - rec.as_mut(), - self.factors.as_ref(), - self.row_permutation(), - self.col_permutation(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_lu::full_pivoting::reconstruct::reconstruct_req::( - self.nrows(), - self.ncols(), - parallelism, - ) - .unwrap(), - )), - ); - - rec - } - } - - impl Qr { - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { - let parallelism = get_global_parallelism(); - let nrows = matrix.nrows(); - let ncols = matrix.ncols(); - - let mut factors = matrix.to_owned(); - let size = Ord::min(nrows, ncols); - let blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::(nrows, ncols); - let mut householder = Mat::::zeros(blocksize, size); - - let params = Default::default(); - - faer_qr::no_pivoting::compute::qr_in_place( - factors.as_mut(), - householder.as_mut(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::no_pivoting::compute::qr_in_place_req::( - nrows, - ncols, - blocksize, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - Self { - factors, - householder, - } - } - - fn blocksize(&self) -> usize { - self.householder.nrows() - } - - pub fn compute_r(&self) -> Mat { - let mut factor = self.factors.to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - } - - pub fn compute_q(&self) -> Mat { - Self::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), false) - } - - pub fn compute_thin_r(&self) -> Mat { - let m = self.nrows(); - let n = self.ncols(); - let mut factor = self.factors.as_ref().subrows(0, Ord::min(m, n)).to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - } - - pub fn compute_thin_q(&self) -> Mat { - Self::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), true) - } - - fn __compute_q_impl( - factors: MatRef<'_, E>, - householder: MatRef<'_, E>, - thin: bool, - ) -> Mat { - let parallelism = get_global_parallelism(); - let m = factors.nrows(); - let size = Ord::min(m, factors.ncols()); - - let mut q = Mat::::zeros(m, if thin { size } else { m }); - q.as_mut() - .diagonal_mut() - .column_vector_mut() - .fill(E::faer_one()); - - faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( - factors, - householder, - Conj::No, - q.as_mut(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_req::( - m, - householder.nrows(), - m, - ) - .unwrap(), - )), - ); - - q - } - } - impl SpSolverCore for Qr { - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - self.solve_lstsq_in_place_with_conj_impl(rhs, conj) - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_qr::no_pivoting::solve::solve_transpose_in_place( - self.factors.as_ref(), - self.householder.as_ref(), - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::no_pivoting::solve::solve_transpose_in_place_req::( - self.nrows(), - self.blocksize(), - rhs_ncols, - ) - .unwrap(), - )), - ); - } - - fn nrows(&self) -> usize { - self.factors.nrows() - } - - fn ncols(&self) -> usize { - self.factors.ncols() - } - } - impl SolverCore for Qr { - fn reconstruct(&self) -> Mat { - let mut rec = Mat::::zeros(self.nrows(), self.ncols()); - let parallelism = get_global_parallelism(); - - faer_qr::no_pivoting::reconstruct::reconstruct( - rec.as_mut(), - self.factors.as_ref(), - self.householder.as_ref(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::no_pivoting::reconstruct::reconstruct_req::( - self.nrows(), - self.ncols(), - self.blocksize(), - parallelism, - ) - .unwrap(), - )), - ); - - rec - } - - fn inverse(&self) -> Mat { - assert!(self.nrows() == self.ncols()); - - let mut inv = Mat::::zeros(self.nrows(), self.ncols()); - let parallelism = get_global_parallelism(); - - faer_qr::no_pivoting::inverse::invert( - inv.as_mut(), - self.factors.as_ref(), - self.householder.as_ref(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::no_pivoting::inverse::invert_req::( - self.nrows(), - self.ncols(), - self.blocksize(), - parallelism, - ) - .unwrap(), - )), - ); - - inv - } - } - - impl SpSolverLstsqCore for Qr { - #[track_caller] - fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_qr::no_pivoting::solve::solve_in_place( - self.factors.as_ref(), - self.householder.as_ref(), - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::no_pivoting::solve::solve_in_place_req::( - self.nrows(), - self.blocksize(), - rhs_ncols, - ) - .unwrap(), - )), - ); - } - } - impl SolverLstsqCore for Qr {} - - impl ColPivQr { - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { - let parallelism = get_global_parallelism(); - let nrows = matrix.nrows(); - let ncols = matrix.ncols(); - - let mut factors = matrix.to_owned(); - let size = Ord::min(nrows, ncols); - let blocksize = - faer_qr::col_pivoting::compute::recommended_blocksize::(nrows, ncols); - let mut householder = Mat::::zeros(blocksize, size); - - let params = Default::default(); - - let mut col_perm = vec![0usize; ncols]; - let mut col_perm_inv = vec![0usize; ncols]; - - faer_qr::col_pivoting::compute::qr_in_place( - factors.as_mut(), - householder.as_mut(), - &mut col_perm, - &mut col_perm_inv, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::col_pivoting::compute::qr_in_place_req::( - nrows, - ncols, - blocksize, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - Self { - factors, - householder, - col_perm, - col_perm_inv, - } - } - - pub fn col_permutation(&self) -> PermutationRef<'_, usize, E> { - unsafe { PermutationRef::new_unchecked(&self.col_perm, &self.col_perm_inv) } - } - - fn blocksize(&self) -> usize { - self.householder.nrows() - } - - pub fn compute_r(&self) -> Mat { - let mut factor = self.factors.to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - } - - pub fn compute_q(&self) -> Mat { - Qr::::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), false) - } - - pub fn compute_thin_r(&self) -> Mat { - let m = self.nrows(); - let n = self.ncols(); - let mut factor = self.factors.as_ref().subrows(0, Ord::min(m, n)).to_owned(); - zipped!(factor.as_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { - dst.write(E::faer_zero()) - }); - factor - } - - pub fn compute_thin_q(&self) -> Mat { - Qr::::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), true) - } - } - impl SpSolverCore for ColPivQr { - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - self.solve_lstsq_in_place_with_conj_impl(rhs, conj); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_qr::col_pivoting::solve::solve_transpose_in_place( - self.factors.as_ref(), - self.householder.as_ref(), - self.col_permutation(), - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::col_pivoting::solve::solve_transpose_in_place_req::( - self.nrows(), - self.blocksize(), - rhs_ncols, - ) - .unwrap(), - )), - ); - } - - fn nrows(&self) -> usize { - self.factors.nrows() - } - - fn ncols(&self) -> usize { - self.factors.ncols() - } - } - impl SolverCore for ColPivQr { - fn reconstruct(&self) -> Mat { - let mut rec = Mat::::zeros(self.nrows(), self.ncols()); - let parallelism = get_global_parallelism(); - - faer_qr::col_pivoting::reconstruct::reconstruct( - rec.as_mut(), - self.factors.as_ref(), - self.householder.as_ref(), - self.col_permutation(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::col_pivoting::reconstruct::reconstruct_req::( - self.nrows(), - self.ncols(), - self.blocksize(), - parallelism, - ) - .unwrap(), - )), - ); - - rec - } - - fn inverse(&self) -> Mat { - assert!(self.nrows() == self.ncols()); - - let mut inv = Mat::::zeros(self.nrows(), self.ncols()); - let parallelism = get_global_parallelism(); - - faer_qr::col_pivoting::inverse::invert( - inv.as_mut(), - self.factors.as_ref(), - self.householder.as_ref(), - self.col_permutation(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::col_pivoting::inverse::invert_req::( - self.nrows(), - self.ncols(), - self.blocksize(), - parallelism, - ) - .unwrap(), - )), - ); - - inv - } - } - - impl SpSolverLstsqCore for ColPivQr { - #[track_caller] - fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - - faer_qr::col_pivoting::solve::solve_in_place( - self.factors.as_ref(), - self.householder.as_ref(), - self.col_permutation(), - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_qr::col_pivoting::solve::solve_in_place_req::( - self.nrows(), - self.blocksize(), - rhs_ncols, - ) - .unwrap(), - )), - ); - } - } - impl SolverLstsqCore for ColPivQr {} - - impl Svd { - #[track_caller] - fn __new_impl((matrix, conj): (MatRef<'_, E>, Conj), thin: bool) -> Self { - let parallelism = get_global_parallelism(); - let m = matrix.nrows(); - let n = matrix.ncols(); - let size = Ord::min(m, n); - - let mut s = Mat::::zeros(size, 1); - let mut u = Mat::::zeros(m, if thin { size } else { m }); - let mut v = Mat::::zeros(n, if thin { size } else { n }); - - let params = Default::default(); - - let compute_vecs = if thin { - faer_svd::ComputeVectors::Thin - } else { - faer_svd::ComputeVectors::Full - }; - - faer_svd::compute_svd( - matrix, - s.as_mut(), - Some(u.as_mut()), - Some(v.as_mut()), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_svd::compute_svd_req::( - m, - n, - compute_vecs, - compute_vecs, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - if matches!(conj, Conj::Yes) { - zipped!(u.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); - zipped!(v.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); - } - - Self { s, u, v } - } - - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { - Self::__new_impl(matrix.canonicalize(), false) - } - - pub fn u(&self) -> MatRef<'_, E> { - self.u.as_ref() - } - pub fn s_diagonal(&self) -> MatRef<'_, E> { - self.s.as_ref() - } - pub fn v(&self) -> MatRef<'_, E> { - self.v.as_ref() - } - } - fn div_by_s(rhs: MatMut<'_, E>, s: MatRef<'_, E>) { - let mut rhs = rhs; - for j in 0..rhs.ncols() { - zipped!(rhs.rb_mut().col_mut(j).as_2d_mut(), s).for_each(|unzipped!(mut rhs, s)| { - rhs.write(rhs.read().faer_scale_real(s.read().faer_real().faer_inv())) - }); - } - } - impl SpSolverCore for Svd { - fn nrows(&self) -> usize { - self.u.nrows() - } - - fn ncols(&self) -> usize { - self.v.nrows() - } - - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - let mut rhs = rhs; - - let u = self.u.as_ref(); - let v = self.v.as_ref(); - let s = self.s.as_ref(); - - match conj { - Conj::Yes => { - rhs.copy_from((u.transpose() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((v.conjugate() * rhs.rb()).as_ref()); - } - Conj::No => { - rhs.copy_from((u.adjoint() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((v * rhs.rb()).as_ref()); - } - } - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - let mut rhs = rhs; - - let u = self.u.as_ref(); - let v = self.v.as_ref(); - let s = self.s.as_ref(); - - match conj { - Conj::No => { - rhs.copy_from((v.transpose() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((u.conjugate() * rhs.rb()).as_ref()); - } - Conj::Yes => { - rhs.copy_from((v.adjoint() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((u * rhs.rb()).as_ref()); - } - } - } - } - impl SolverCore for Svd { - fn reconstruct(&self) -> Mat { - let m = self.nrows(); - let n = self.ncols(); - let size = Ord::min(m, n); - - let thin_u = self.u.as_ref().submatrix(0, 0, m, size); - let s = self.s.as_ref(); - let us = Mat::::from_fn(m, size, |i, j| thin_u.read(i, j).faer_mul(s.read(j, 0))); - - us * self.v.adjoint() - } - - fn inverse(&self) -> Mat { - assert!(self.nrows() == self.ncols()); - let dim = self.nrows(); - - let u = self.u.as_ref(); - let v = self.v.as_ref(); - let s = self.s.as_ref(); - - let vs_inv = Mat::::from_fn(dim, dim, |i, j| { - v.read(i, j).faer_mul(s.read(j, 0).faer_inv()) - }); - - vs_inv * u.adjoint() - } - } - - impl ThinSvd { - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { - Self { - inner: Svd::__new_impl(matrix.canonicalize(), true), - } - } - - pub fn u(&self) -> MatRef<'_, E> { - self.inner.u.as_ref() - } - pub fn s_diagonal(&self) -> MatRef<'_, E> { - self.inner.s.as_ref() - } - pub fn v(&self) -> MatRef<'_, E> { - self.inner.v.as_ref() - } - } - impl SpSolverCore for ThinSvd { - fn nrows(&self) -> usize { - self.inner.nrows() - } - - fn ncols(&self) -> usize { - self.inner.ncols() - } - - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - self.inner.solve_in_place_with_conj_impl(rhs, conj) - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - self.inner - .solve_transpose_in_place_with_conj_impl(rhs, conj) - } - } - impl SolverCore for ThinSvd { - fn reconstruct(&self) -> Mat { - self.inner.reconstruct() - } - - fn inverse(&self) -> Mat { - self.inner.inverse() - } - } - - impl SelfAdjointEigendecomposition { - #[track_caller] - fn __new_impl((matrix, conj): (MatRef<'_, E>, Conj), side: Side) -> Self { - assert!(matrix.nrows() == matrix.ncols()); - let parallelism = get_global_parallelism(); - - let dim = matrix.nrows(); - - let mut s = Mat::::zeros(dim, 1); - let mut u = Mat::::zeros(dim, dim); - - let matrix = match side { - Side::Lower => matrix, - Side::Upper => matrix.transpose(), - }; - let conj = conj.compose(match side { - Side::Lower => Conj::No, - Side::Upper => Conj::Yes, - }); - - let params = Default::default(); - faer_evd::compute_hermitian_evd( - matrix, - s.as_mut(), - Some(u.as_mut()), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_evd::compute_hermitian_evd_req::( - dim, - faer_evd::ComputeVectors::Yes, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - if matches!(conj, Conj::Yes) { - zipped!(u.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); - } - - Self { s, u } - } - - #[track_caller] - pub fn new>(matrix: MatRef<'_, ViewE>, side: Side) -> Self { - Self::__new_impl(matrix.canonicalize(), side) - } - - pub fn u(&self) -> MatRef<'_, E> { - self.u.as_ref() - } - pub fn s_diagonal(&self) -> MatRef<'_, E> { - self.s.as_ref() - } - } - impl SpSolverCore for SelfAdjointEigendecomposition { - fn nrows(&self) -> usize { - self.u.nrows() - } - - fn ncols(&self) -> usize { - self.u.nrows() - } - - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - let mut rhs = rhs; - - let u = self.u.as_ref(); - let s = self.s.as_ref(); - - match conj { - Conj::Yes => { - rhs.copy_from((u.transpose() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((u.conjugate() * rhs.rb()).as_ref()); - } - Conj::No => { - rhs.copy_from((u.adjoint() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((u * rhs.rb()).as_ref()); - } - } - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - assert!(self.nrows() == self.ncols()); - let mut rhs = rhs; - - let u = self.u.as_ref(); - let s = self.s.as_ref(); - - match conj { - Conj::No => { - rhs.copy_from((u.transpose() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((u.conjugate() * rhs.rb()).as_ref()); - } - Conj::Yes => { - rhs.copy_from((u.adjoint() * rhs.rb()).as_ref()); - div_by_s(rhs.rb_mut(), s); - rhs.copy_from((u * rhs.rb()).as_ref()); - } - } - } - } - impl SolverCore for SelfAdjointEigendecomposition { - fn reconstruct(&self) -> Mat { - let size = self.nrows(); - - let u = self.u.as_ref(); - let s = self.s.as_ref(); - let us = Mat::::from_fn(size, size, |i, j| u.read(i, j).faer_mul(s.read(j, 0))); - - us * u.adjoint() - } - - fn inverse(&self) -> Mat { - let dim = self.nrows(); - - let u = self.u.as_ref(); - let s = self.s.as_ref(); - - let us_inv = Mat::::from_fn(dim, dim, |i, j| { - u.read(i, j).faer_mul(s.read(j, 0).faer_inv()) - }); - - us_inv * u.adjoint() - } - } - - impl Eigendecomposition { - #[track_caller] - pub(crate) fn __values_from_real(matrix: MatRef<'_, E::Real>) -> Vec { - assert!(matrix.nrows() == matrix.ncols()); - if coe::is_same::() { - panic!( - "The type E ({}) must not be real-valued.", - core::any::type_name::(), - ); - } - - let parallelism = get_global_parallelism(); - - let dim = matrix.nrows(); - let mut s_re = Mat::::zeros(dim, 1); - let mut s_im = Mat::::zeros(dim, 1); - - let params = Default::default(); - - faer_evd::compute_evd_real( - matrix, - s_re.as_mut(), - s_im.as_mut(), - None, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_evd::compute_evd_req::( - dim, - faer_evd::ComputeVectors::Yes, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - let imag = E::faer_from_f64(-1.0).faer_sqrt(); - let cplx = |re: E::Real, im: E::Real| -> E { - E::faer_from_real(re).faer_add(imag.faer_mul(E::faer_from_real(im))) - }; - - (0..dim) - .map(|i| cplx(s_re.read(i, 0), s_im.read(i, 0))) - .collect() - } - - #[track_caller] - pub(crate) fn __values_from_complex_impl((matrix, conj): (MatRef<'_, E>, Conj)) -> Vec { - assert!(matrix.nrows() == matrix.ncols()); - if coe::is_same::() { - panic!( - "The type E ({}) must not be real-valued.", - core::any::type_name::(), - ); - } - - let parallelism = get_global_parallelism(); - let dim = matrix.nrows(); - - let mut s = Mat::::zeros(dim, 1); - - let params = Default::default(); - - faer_evd::compute_evd_complex( - matrix, - s.as_mut(), - None, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_evd::compute_evd_req::( - dim, - faer_evd::ComputeVectors::Yes, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - if matches!(conj, Conj::Yes) { - zipped!(s.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); - } - - (0..dim).map(|i| s.read(i, 0)).collect() - } - - #[track_caller] - pub fn new_from_real(matrix: MatRef<'_, E::Real>) -> Self { - assert!(matrix.nrows() == matrix.ncols()); - if coe::is_same::() { - panic!( - "The type E ({}) must not be real-valued.", - core::any::type_name::(), - ); - } - - let parallelism = get_global_parallelism(); - - let dim = matrix.nrows(); - let mut s_re = Col::::zeros(dim); - let mut s_im = Col::::zeros(dim); - let mut u_real = Mat::::zeros(dim, dim); - - let params = Default::default(); - - faer_evd::compute_evd_real( - matrix, - s_re.as_mut().as_2d_mut(), - s_im.as_mut().as_2d_mut(), - Some(u_real.as_mut()), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_evd::compute_evd_req::( - dim, - faer_evd::ComputeVectors::Yes, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - let imag = E::faer_from_f64(-1.0).faer_sqrt(); - let cplx = |re: E::Real, im: E::Real| -> E { - E::faer_from_real(re).faer_add(imag.faer_mul(E::faer_from_real(im))) - }; - - let s = Col::::from_fn(dim, |i| cplx(s_re.read(i), s_im.read(i))); - let mut u = Mat::::zeros(dim, dim); - let u_real = u_real.as_ref(); - - let mut j = 0usize; - while j < dim { - if s_im.read(j) == E::Real::faer_zero() { - zipped!(u.as_mut().col_mut(j).as_2d_mut(), u_real.col(j).as_2d()).for_each( - |unzipped!(mut dst, src)| dst.write(E::faer_from_real(src.read())), - ); - j += 1; - } else { - let (u_left, u_right) = u.as_mut().split_at_col_mut(j + 1); - - zipped!( - u_left.col_mut(j).as_2d_mut(), - u_right.col_mut(0).as_2d_mut(), - u_real.col(j).as_2d(), - u_real.col(j + 1).as_2d(), - ) - .for_each(|unzipped!(mut dst, mut dst_conj, re, im)| { - let re = re.read(); - let im = im.read(); - dst_conj.write(cplx(re, im.faer_neg())); - dst.write(cplx(re, im)); - }); - - j += 2; - } - } - - Self { s, u } - } - - #[track_caller] - pub(crate) fn __new_from_complex_impl((matrix, conj): (MatRef<'_, E>, Conj)) -> Self { - assert!(matrix.nrows() == matrix.ncols()); - if coe::is_same::() { - panic!( - "The type E ({}) must not be real-valued.", - core::any::type_name::(), - ); - } - - let parallelism = get_global_parallelism(); - let dim = matrix.nrows(); - - let mut s = Col::::zeros(dim); - let mut u = Mat::::zeros(dim, dim); - - let params = Default::default(); - - faer_evd::compute_evd_complex( - matrix, - s.as_mut().as_2d_mut(), - Some(u.as_mut()), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_evd::compute_evd_req::( - dim, - faer_evd::ComputeVectors::Yes, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - if matches!(conj, Conj::Yes) { - zipped!(s.as_mut().as_2d_mut()) - .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); - zipped!(u.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); - } - - Self { s, u } - } - - #[track_caller] - pub fn new_from_complex>( - matrix: MatRef<'_, ViewE>, - ) -> Self { - Self::__new_from_complex_impl(matrix.canonicalize()) - } - - pub fn u(&self) -> MatRef<'_, E> { - self.u.as_ref() - } - pub fn s_diagonal(&self) -> ColRef<'_, E> { - self.s.as_ref() - } - } -} - -/// Extension trait for `faer` types. -pub trait FaerMat { - /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and - /// stores the result in `rhs`. - fn solve_lower_triangular_in_place(&self, rhs: impl AsMatMut); - /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and - /// stores the result in `rhs`. - fn solve_upper_triangular_in_place(&self, rhs: impl AsMatMut); - /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, - /// and stores the result in `rhs`. - fn solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut); - /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, - /// and stores the result in `rhs`. - fn solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut); - - /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and - /// returns the result. - #[track_caller] - fn solve_lower_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.solve_lower_triangular_in_place(rhs.as_mut()); - rhs - } - /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and - /// returns the result. - #[track_caller] - fn solve_upper_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.solve_upper_triangular_in_place(rhs.as_mut()); - rhs - } - /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, and - /// returns the result. - #[track_caller] - fn solve_unit_lower_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.solve_unit_lower_triangular_in_place(rhs.as_mut()); - rhs - } - /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, and - /// returns the result. - #[track_caller] - fn solve_unit_upper_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.solve_unit_upper_triangular_in_place(rhs.as_mut()); - rhs - } - - /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. - fn cholesky(&self, side: Side) -> Result, CholeskyError>; - /// Returns the Bunch-Kaufman decomposition of `self`. Only the provided side is accessed. - fn lblt(&self, side: Side) -> Lblt; - /// Returns the LU decomposition of `self` with partial (row) pivoting. - fn partial_piv_lu(&self) -> PartialPivLu; - /// Returns the LU decomposition of `self` with full pivoting. - fn full_piv_lu(&self) -> FullPivLu; - /// Returns the QR decomposition of `self`. - fn qr(&self) -> Qr; - /// Returns the QR decomposition of `self`, with column pivoting. - fn col_piv_qr(&self) -> ColPivQr; - /// Returns the SVD of `self`. - fn svd(&self) -> Svd; - /// Returns the thin SVD of `self`. - fn thin_svd(&self) -> ThinSvd; - /// Returns the eigendecomposition of `self`, assuming it is self-adjoint. Only the provided - /// side is accessed. - fn selfadjoint_eigendecomposition(&self, side: Side) -> SelfAdjointEigendecomposition; - /// Returns the eigendecomposition of `self`, as a complex matrix. - fn eigendecomposition>( - &self, - ) -> Eigendecomposition; - /// Returns the eigendecomposition of `self`, when `E` is in the complex domain. - fn complex_eigendecomposition(&self) -> Eigendecomposition; - - /// Returns the determinant of `self`. - fn determinant(&self) -> E; - /// Returns the singular values of `self`, in nonincreasing order. - fn singular_values(&self) -> Vec; - /// Returns the eigenvalues of `self`, assuming it is self-adjoint. Only the provided - /// side is accessed. The order of the eigenvalues is currently unspecified. - fn selfadjoint_eigenvalues(&self, side: Side) -> Vec; - /// Returns the eigenvalues of `self`, as complex values. The order of the eigenvalues is - /// currently unspecified. - fn eigenvalues>(&self) -> Vec; - /// Returns the eigenvalues of `self`, when `E` is in the complex domain. The order of the - /// eigenvalues is currently unspecified. - fn complex_eigenvalues(&self) -> Vec; -} - -/// Sparse solvers and traits. -pub mod sparse { - use super::*; - use faer_core::group_helpers::VecGroup; - - pub use faer_core::{ - permutation::Index, - sparse::{ - SparseColMat, SparseColMatMut, SparseColMatRef, SparseRowMat, SparseRowMatMut, - SparseRowMatRef, SymbolicSparseColMat, SymbolicSparseColMatRef, SymbolicSparseRowMat, - SymbolicSparseRowMatRef, - }, - }; - pub use faer_sparse::{lu::LuError, FaerError}; - - /// Sparse Cholesky error. - #[derive(Copy, Clone, Debug)] - pub enum CholeskyError { - Generic(FaerError), - SymbolicSingular, - NotPositiveDefinite, - } - - impl core::fmt::Display for CholeskyError { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - core::fmt::Debug::fmt(self, f) - } - } - - #[cfg(feature = "std")] - impl std::error::Error for CholeskyError {} - - impl From for CholeskyError { - #[inline] - fn from(value: FaerError) -> Self { - Self::Generic(value) - } - } - - impl From for CholeskyError { - #[inline] - fn from(_: crate::CholeskyError) -> Self { - Self::NotPositiveDefinite - } - } - - /// Sparse solvers. - /// - /// Each solver satisfies the [`SpSolver`] and/or [`SpSolverLstsq`] traits, which can be used - /// to solve linear systems. - pub mod solvers { - use super::*; - - /// Object-safe base for [`SpSolver`] - pub trait SpSolverCore { - /// Returns the number of rows of the matrix used to construct this decomposition. - fn nrows(&self) -> usize; - /// Returns the number of columns of the matrix used to construct this decomposition. - fn ncols(&self) -> usize; - - #[doc(hidden)] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj); - #[doc(hidden)] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj); - } - - pub trait SpSolverLstsqCore: SpSolverCore { - #[doc(hidden)] - fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj); - } - - pub trait SpSolver: SpSolverCore { - /// Solves the equation `self * X = rhs`, and stores the result in `rhs`. - fn solve_in_place(&self, rhs: impl AsMatMut); - /// Solves the equation `conjugate(self) * X = rhs`, and stores the result in `rhs`. - fn solve_conj_in_place(&self, rhs: impl AsMatMut); - /// Solves the equation `transpose(self) * X = rhs`, and stores the result in `rhs`. - fn solve_transpose_in_place(&self, rhs: impl AsMatMut); - /// Solves the equation `adjoint(self) * X = rhs`, and stores the result in `rhs`. - fn solve_conj_transpose_in_place(&self, rhs: impl AsMatMut); - /// Solves the equation `self * X = rhs`, and returns the result. - fn solve>(&self, rhs: impl AsMatRef) -> Mat; - /// Solves the equation `conjugate(self) * X = rhs`, and returns the result. - fn solve_conj>( - &self, - rhs: impl AsMatRef, - ) -> Mat; - /// Solves the equation `transpose(self) * X = rhs`, and returns the result. - fn solve_transpose>( - &self, - rhs: impl AsMatRef, - ) -> Mat; - /// Solves the equation `adjoint(self) * X = rhs`, and returns the result. - fn solve_conj_transpose>( - &self, - rhs: impl AsMatRef, - ) -> Mat; - } - - pub trait SpSolverLstsq: SpSolverLstsqCore { - /// Solves the equation `self * X = rhs`, in the sense of least squares, and stores the - /// result in the top rows of `rhs`. - fn solve_lstsq_in_place(&self, rhs: impl AsMatMut); - /// Solves the equation `conjugate(self) * X = rhs`, in the sense of least squares, and - /// stores the result in the top rows of `rhs`. - fn solve_lstsq_conj_in_place(&self, rhs: impl AsMatMut); - /// Solves the equation `self * X = rhs`, and returns the result. - fn solve_lstsq>( - &self, - rhs: impl AsMatRef, - ) -> Mat; - /// Solves the equation `conjugate(self) * X = rhs`, and returns the result. - fn solve_lstsq_conj>( - &self, - rhs: impl AsMatRef, - ) -> Mat; - } - - #[track_caller] - fn solve_with_conj_impl< - E: ComplexField, - D: ?Sized + SpSolverCore, - ViewE: Conjugate, - >( - d: &D, - rhs: MatRef<'_, ViewE>, - conj: Conj, - ) -> Mat { - let mut rhs = rhs.to_owned(); - d.solve_in_place_with_conj_impl(rhs.as_mut(), conj); - rhs - } - - #[track_caller] - fn solve_transpose_with_conj_impl< - E: ComplexField, - D: ?Sized + SpSolverCore, - ViewE: Conjugate, - >( - d: &D, - rhs: MatRef<'_, ViewE>, - conj: Conj, - ) -> Mat { - let mut rhs = rhs.to_owned(); - d.solve_transpose_in_place_with_conj_impl(rhs.as_mut(), conj); - rhs - } - - #[track_caller] - fn solve_lstsq_with_conj_impl< - E: ComplexField, - D: ?Sized + SpSolverLstsqCore, - ViewE: Conjugate, - >( - d: &D, - rhs: MatRef<'_, ViewE>, - conj: Conj, - ) -> Mat { - let mut rhs = rhs.to_owned(); - let k = rhs.ncols(); - d.solve_lstsq_in_place_with_conj_impl(rhs.as_mut(), conj); - rhs.resize_with(d.ncols(), k, |_, _| unreachable!()); - rhs - } - - impl> SpSolver for Dec { - #[track_caller] - fn solve_in_place(&self, rhs: impl AsMatMut) { - let mut rhs = rhs; - self.solve_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::No) - } - - #[track_caller] - fn solve_conj_in_place(&self, rhs: impl AsMatMut) { - let mut rhs = rhs; - self.solve_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::Yes) - } - - #[track_caller] - fn solve_transpose_in_place(&self, rhs: impl AsMatMut) { - let mut rhs = rhs; - self.solve_transpose_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::No) - } - - #[track_caller] - fn solve_conj_transpose_in_place(&self, rhs: impl AsMatMut) { - let mut rhs = rhs; - self.solve_transpose_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::Yes) - } - - #[track_caller] - fn solve>(&self, rhs: impl AsMatRef) -> Mat { - solve_with_conj_impl::(self, rhs.as_mat_ref(), Conj::No) - } - - #[track_caller] - fn solve_conj>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - solve_with_conj_impl::(self, rhs.as_mat_ref(), Conj::Yes) - } - - #[track_caller] - fn solve_transpose>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - solve_transpose_with_conj_impl::(self, rhs.as_mat_ref(), Conj::No) - } - - #[track_caller] - fn solve_conj_transpose>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - solve_transpose_with_conj_impl::(self, rhs.as_mat_ref(), Conj::Yes) - } - } - - impl> SpSolverLstsq for Dec { - #[track_caller] - fn solve_lstsq_in_place(&self, rhs: impl AsMatMut) { - let mut rhs = rhs; - self.solve_lstsq_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::No) - } - - #[track_caller] - fn solve_lstsq_conj_in_place(&self, rhs: impl AsMatMut) { - let mut rhs = rhs; - self.solve_lstsq_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::Yes) - } - - #[track_caller] - fn solve_lstsq>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - solve_lstsq_with_conj_impl::(self, rhs.as_mat_ref(), Conj::No) - } - - #[track_caller] - fn solve_lstsq_conj>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - solve_lstsq_with_conj_impl::(self, rhs.as_mat_ref(), Conj::Yes) - } - } - - #[derive(Debug)] - pub struct SymbolicCholesky { - inner: alloc::sync::Arc>, - } - #[derive(Clone, Debug)] - pub struct Cholesky { - symbolic: SymbolicCholesky, - values: VecGroup, - } - - #[derive(Debug)] - pub struct SymbolicQr { - inner: alloc::sync::Arc>, - } - #[derive(Clone, Debug)] - pub struct Qr { - symbolic: SymbolicQr, - indices: alloc::vec::Vec, - values: VecGroup, - } - - #[derive(Debug)] - pub struct SymbolicLu { - inner: alloc::sync::Arc>, - } - #[derive(Clone, Debug)] - pub struct Lu { - symbolic: SymbolicLu, - numeric: faer_sparse::lu::NumericLu, - } - - impl Clone for SymbolicCholesky { - #[inline] - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } - } - impl Clone for SymbolicQr { - #[inline] - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } - } - impl Clone for SymbolicLu { - #[inline] - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } - } - - impl SymbolicCholesky { - #[track_caller] - pub fn try_new( - mat: SymbolicSparseColMatRef<'_, I>, - side: Side, - ) -> Result { - Ok(Self { - inner: alloc::sync::Arc::new( - faer_sparse::cholesky::factorize_symbolic_cholesky( - mat, - side, - Default::default(), - )?, - ), - }) - } - } - impl SymbolicQr { - #[track_caller] - pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result { - Ok(Self { - inner: alloc::sync::Arc::new(faer_sparse::qr::factorize_symbolic_qr( - mat, - Default::default(), - )?), - }) - } - } - impl SymbolicLu { - #[track_caller] - pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result { - Ok(Self { - inner: alloc::sync::Arc::new(faer_sparse::lu::factorize_symbolic_lu( - mat, - Default::default(), - )?), - }) - } - } - - impl Cholesky { - #[track_caller] - pub fn try_new_with_symbolic( - symbolic: SymbolicCholesky, - mat: SparseColMatRef<'_, I, E>, - side: Side, - ) -> Result { - let len_values = symbolic.inner.len_values(); - let mut values = VecGroup::new(); - values - .try_reserve_exact(len_values) - .map_err(|_| FaerError::OutOfMemory)?; - values.resize(len_values, E::faer_zero().faer_into_units()); - let parallelism = get_global_parallelism(); - symbolic.inner.factorize_numeric_llt::( - values.as_slice_mut().into_inner(), - mat, - side, - Default::default(), - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - symbolic - .inner - .factorize_numeric_llt_req::(parallelism) - .map_err(|_| FaerError::OutOfMemory)?, - )), - )?; - Ok(Self { symbolic, values }) - } - } - - impl Qr { - #[track_caller] - pub fn try_new_with_symbolic( - symbolic: SymbolicQr, - mat: SparseColMatRef<'_, I, E>, - ) -> Result { - let len_values = symbolic.inner.len_values(); - let len_indices = symbolic.inner.len_indices(); - let mut values = VecGroup::new(); - let mut indices = alloc::vec::Vec::new(); - values - .try_reserve_exact(len_values) - .map_err(|_| FaerError::OutOfMemory)?; - indices - .try_reserve_exact(len_indices) - .map_err(|_| FaerError::OutOfMemory)?; - values.resize(len_values, E::faer_zero().faer_into_units()); - indices.resize(len_indices, I::truncate(0)); - let parallelism = get_global_parallelism(); - symbolic.inner.factorize_numeric_qr::( - &mut indices, - values.as_slice_mut().into_inner(), - mat, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - symbolic - .inner - .factorize_numeric_qr_req::(parallelism) - .map_err(|_| FaerError::OutOfMemory)?, - )), - ); - Ok(Self { - symbolic, - indices, - values, - }) - } - } - - impl Lu { - #[track_caller] - pub fn try_new_with_symbolic( - symbolic: SymbolicLu, - mat: SparseColMatRef<'_, I, E>, - ) -> Result { - let mut numeric = faer_sparse::lu::NumericLu::new(); - let parallelism = get_global_parallelism(); - symbolic.inner.factorize_numeric_lu::( - &mut numeric, - mat, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - symbolic - .inner - .factorize_numeric_lu_req::(parallelism) - .map_err(|_| FaerError::OutOfMemory)?, - )), - )?; - Ok(Self { symbolic, numeric }) - } - } - - impl SpSolverCore for Cholesky { - #[inline] - fn nrows(&self) -> usize { - self.symbolic.inner.nrows() - } - #[inline] - fn ncols(&self) -> usize { - self.symbolic.inner.ncols() - } - - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - faer_sparse::cholesky::LltRef::<'_, I, E>::new( - &self.symbolic.inner, - self.values.as_slice().into_inner(), - ) - .solve_in_place_with_conj( - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - self.symbolic - .inner - .solve_in_place_req::(rhs_ncols) - .unwrap(), - )), - ); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - faer_sparse::cholesky::LltRef::<'_, I, E>::new( - &self.symbolic.inner, - self.values.as_slice().into_inner(), - ) - .solve_in_place_with_conj( - conj.compose(Conj::Yes), - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - self.symbolic - .inner - .solve_in_place_req::(rhs_ncols) - .unwrap(), - )), - ); - } - } - - impl SpSolverCore for Qr { - #[inline] - fn nrows(&self) -> usize { - self.symbolic.inner.nrows() - } - #[inline] - fn ncols(&self) -> usize { - self.symbolic.inner.ncols() - } - - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - self.solve_lstsq_in_place_with_conj_impl(rhs, conj); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let _ = (&rhs, &conj); - unimplemented!( - "the sparse QR decomposition doesn't support solve_transpose.\n\ - consider using the sparse LU or Cholesky instead." - ) - } - } - - impl SpSolverLstsqCore for Qr { - #[track_caller] - fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - unsafe { - faer_sparse::qr::QrRef::<'_, I, E>::new_unchecked( - &self.symbolic.inner, - &self.indices, - self.values.as_slice().into_inner(), - ) - } - .solve_in_place_with_conj( - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - self.symbolic - .inner - .solve_in_place_req::(rhs_ncols, parallelism) - .unwrap(), - )), - ); - } - } - - impl SpSolverCore for Lu { - #[inline] - fn nrows(&self) -> usize { - self.symbolic.inner.nrows() - } - #[inline] - fn ncols(&self) -> usize { - self.symbolic.inner.ncols() - } - - #[track_caller] - fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - unsafe { - faer_sparse::lu::LuRef::<'_, I, E>::new_unchecked( - &self.symbolic.inner, - &self.numeric, - ) - } - .solve_in_place_with_conj( - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - self.symbolic - .inner - .solve_in_place_req::(rhs_ncols, parallelism) - .unwrap(), - )), - ); - } - - #[track_caller] - fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { - let parallelism = get_global_parallelism(); - let rhs_ncols = rhs.ncols(); - unsafe { - faer_sparse::lu::LuRef::<'_, I, E>::new_unchecked( - &self.symbolic.inner, - &self.numeric, - ) - } - .solve_transpose_in_place_with_conj( - conj, - rhs, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - self.symbolic - .inner - .solve_in_place_req::(rhs_ncols, parallelism) - .unwrap(), - )), - ); - } - } - } - - /// Extension trait for sparse `faer` types. - pub trait FaerSparseMat { - /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and - /// stores the result in `rhs`. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the first stored element in each column, if the - /// matrix is column-major, or the last stored element in each row, if it is row-major. - fn sp_solve_lower_triangular_in_place(&self, rhs: impl AsMatMut); - /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and - /// stores the result in `rhs`. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the last stored element in each column, if the - /// matrix is column-major, or the first stored element in each row, if it is row-major. - fn sp_solve_upper_triangular_in_place(&self, rhs: impl AsMatMut); - /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, - /// and stores the result in `rhs`. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the first stored element in each column, if the - /// matrix is column-major, or the last stored element in each row, if it is row-major. - fn sp_solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut); - /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, - /// and stores the result in `rhs`. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the last stored element in each column, if the - /// matrix is column-major, or the first stored element in each row, if it is row-major. - fn sp_solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut); - - /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and - /// returns the result. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the first stored element in each column, if the - /// matrix is column-major, or the last stored element in each row, if it is row-major. - #[track_caller] - fn sp_solve_lower_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.sp_solve_lower_triangular_in_place(rhs.as_mut()); - rhs - } - /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and - /// returns the result. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the last stored element in each column, if the - /// matrix is column-major, or the first stored element in each row, if it is row-major. - #[track_caller] - fn sp_solve_upper_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.sp_solve_upper_triangular_in_place(rhs.as_mut()); - rhs - } - /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, - /// and returns the result. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the first stored element in each column, if the - /// matrix is column-major, or the last stored element in each row, if it is row-major. - #[track_caller] - fn sp_solve_unit_lower_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.sp_solve_unit_lower_triangular_in_place(rhs.as_mut()); - rhs - } - /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, - /// and returns the result. - /// - /// # Note - /// The matrix indices need not be sorted, but - /// the diagonal element is assumed to be the first stored element in each column, if the - /// matrix is column-major, or the last stored element in each row, if it is row-major. - #[track_caller] - fn sp_solve_unit_upper_triangular>( - &self, - rhs: impl AsMatRef, - ) -> Mat { - let mut rhs = rhs.as_mat_ref().to_owned(); - self.sp_solve_unit_upper_triangular_in_place(rhs.as_mut()); - rhs - } - - /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. - fn sp_cholesky(&self, side: Side) - -> Result, sparse::CholeskyError>; - - /// Returns the LU decomposition of `self` with partial (row) pivoting. - fn sp_lu(&self) -> Result, LuError>; - - /// Returns the QR decomposition of `self`. - fn sp_qr(&self) -> Result, FaerError>; - } - - impl FaerSparseMat for SparseColMatRef<'_, I, E> { - #[track_caller] - fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_lower_triangular_in_place( - *self, - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - #[track_caller] - fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_upper_triangular_in_place( - *self, - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - #[track_caller] - fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_unit_lower_triangular_in_place( - *self, - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - #[track_caller] - fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_unit_upper_triangular_in_place( - *self, - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - - /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. - #[track_caller] - fn sp_cholesky( - &self, - side: Side, - ) -> Result, sparse::CholeskyError> { - solvers::Cholesky::try_new_with_symbolic( - solvers::SymbolicCholesky::try_new(self.symbolic(), side)?, - *self, - side, - ) - } - - /// Returns the LU decomposition of `self` with partial (row) pivoting. - #[track_caller] - fn sp_lu(&self) -> Result, LuError> { - solvers::Lu::try_new_with_symbolic( - solvers::SymbolicLu::try_new(self.symbolic())?, - *self, - ) - } - - /// Returns the QR decomposition of `self`. - #[track_caller] - fn sp_qr(&self) -> Result, FaerError> { - solvers::Qr::try_new_with_symbolic( - solvers::SymbolicQr::try_new(self.symbolic())?, - *self, - ) - } - } - - impl FaerSparseMat for SparseRowMatRef<'_, I, E> { - #[track_caller] - fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_upper_triangular_in_place( - self.transpose(), - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - #[track_caller] - fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_lower_triangular_in_place( - self.transpose(), - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - #[track_caller] - fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_unit_upper_triangular_in_place( - self.transpose(), - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - #[track_caller] - fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { - faer_sparse::triangular_solve::solve_unit_lower_triangular_in_place( - self.transpose(), - Conj::No, - rhs.as_mat_mut(), - get_global_parallelism(), - ); - } - - /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. - #[track_caller] - fn sp_cholesky( - &self, - side: Side, - ) -> Result, sparse::CholeskyError> { - let this = self.transpose(); - let side = match side { - Side::Lower => Side::Upper, - Side::Upper => Side::Lower, - }; - solvers::Cholesky::try_new_with_symbolic( - solvers::SymbolicCholesky::try_new(this.symbolic(), side)?, - this, - side, - ) - } - - /// Returns the LU decomposition of `self` with partial (row) pivoting. - #[track_caller] - fn sp_lu(&self) -> Result, LuError> { - let this = self.to_col_major()?; - let this = this.as_ref(); - solvers::Lu::try_new_with_symbolic(solvers::SymbolicLu::try_new(this.symbolic())?, this) - } - - /// Returns the QR decomposition of `self`. - #[track_caller] - fn sp_qr(&self) -> Result, FaerError> { - let this = self.to_col_major()?; - let this = this.as_ref(); - solvers::Qr::try_new_with_symbolic(solvers::SymbolicQr::try_new(this.symbolic())?, this) - } - } - - impl FaerSparseMat for SparseColMat { - #[track_caller] - fn sp_solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_lower_triangular_in_place(rhs); - } - #[track_caller] - fn sp_solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_upper_triangular_in_place(rhs); - } - #[track_caller] - fn sp_solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_unit_lower_triangular_in_place(rhs); - } - #[track_caller] - fn sp_solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_unit_upper_triangular_in_place(rhs); - } - - /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. - #[track_caller] - fn sp_cholesky( - &self, - side: Side, - ) -> Result, sparse::CholeskyError> { - self.as_ref().sp_cholesky(side) - } - - /// Returns the LU decomposition of `self` with partial (row) pivoting. - #[track_caller] - fn sp_lu(&self) -> Result, LuError> { - self.as_ref().sp_lu() - } - - /// Returns the QR decomposition of `self`. - #[track_caller] - fn sp_qr(&self) -> Result, FaerError> { - self.as_ref().sp_qr() - } - } - - impl FaerSparseMat for SparseRowMat { - #[track_caller] - fn sp_solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_lower_triangular_in_place(rhs); - } - #[track_caller] - fn sp_solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_upper_triangular_in_place(rhs); - } - #[track_caller] - fn sp_solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_unit_lower_triangular_in_place(rhs); - } - #[track_caller] - fn sp_solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().sp_solve_unit_upper_triangular_in_place(rhs); - } - - /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. - #[track_caller] - fn sp_cholesky( - &self, - side: Side, - ) -> Result, sparse::CholeskyError> { - self.as_ref().sp_cholesky(side) - } - - /// Returns the LU decomposition of `self` with partial (row) pivoting. - #[track_caller] - fn sp_lu(&self) -> Result, LuError> { - self.as_ref().sp_lu() - } - - /// Returns the QR decomposition of `self`. - #[track_caller] - fn sp_qr(&self) -> Result, FaerError> { - self.as_ref().sp_qr() - } - } -} - -impl FaerMat for MatRef<'_, E> -where - E::Canonical: ComplexField, -{ - #[track_caller] - fn solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - let parallelism = get_global_parallelism(); - let mut rhs = rhs; - faer_core::solve::solve_lower_triangular_in_place(*self, rhs.as_mat_mut(), parallelism); - } - #[track_caller] - fn solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - let parallelism = get_global_parallelism(); - let mut rhs = rhs; - faer_core::solve::solve_upper_triangular_in_place(*self, rhs.as_mat_mut(), parallelism); - } - #[track_caller] - fn solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - let parallelism = get_global_parallelism(); - let mut rhs = rhs; - faer_core::solve::solve_unit_lower_triangular_in_place( - *self, - rhs.as_mat_mut(), - parallelism, - ); - } - #[track_caller] - fn solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - let parallelism = get_global_parallelism(); - let mut rhs = rhs; - faer_core::solve::solve_unit_upper_triangular_in_place( - *self, - rhs.as_mat_mut(), - parallelism, - ); - } - - #[track_caller] - fn cholesky(&self, side: Side) -> Result, CholeskyError> { - Cholesky::try_new(self.as_ref(), side) - } - #[track_caller] - fn lblt(&self, side: Side) -> Lblt { - Lblt::new(self.as_ref(), side) - } - #[track_caller] - fn partial_piv_lu(&self) -> PartialPivLu { - PartialPivLu::::new(self.as_ref()) - } - #[track_caller] - fn full_piv_lu(&self) -> FullPivLu { - FullPivLu::::new(self.as_ref()) - } - #[track_caller] - fn qr(&self) -> Qr { - Qr::::new(self.as_ref()) - } - #[track_caller] - fn col_piv_qr(&self) -> ColPivQr { - ColPivQr::::new(self.as_ref()) - } - #[track_caller] - fn svd(&self) -> Svd { - Svd::::new(self.as_ref()) - } - #[track_caller] - fn thin_svd(&self) -> ThinSvd { - ThinSvd::::new(self.as_ref()) - } - #[track_caller] - fn selfadjoint_eigendecomposition( - &self, - side: Side, - ) -> SelfAdjointEigendecomposition { - SelfAdjointEigendecomposition::::new(self.as_ref(), side) - } - - #[track_caller] - fn eigendecomposition::Real>>( - &self, - ) -> Eigendecomposition { - if coe::is_same::::Real>() { - let matrix: MatRef<'_, ::Real> = - coe::coerce(self.as_ref()); - Eigendecomposition::::new_from_real(matrix) - } else if coe::is_same::() { - let (matrix, conj) = self.as_ref().canonicalize(); - Eigendecomposition::::__new_from_complex_impl((coe::coerce(matrix), conj)) - } else { - panic!( - "The type ComplexE must be either E::Canonical ({}) or E::Canonical::Real ({})", - core::any::type_name::(), - core::any::type_name::<::Real>(), - ); - } - } - - #[track_caller] - fn complex_eigendecomposition(&self) -> Eigendecomposition { - Eigendecomposition::::new_from_complex(self.as_ref()) - } - - #[track_caller] - fn determinant(&self) -> E::Canonical { - assert!(self.nrows() == self.ncols()); - let lu = self.partial_piv_lu(); - let mut det = E::Canonical::faer_one(); - for i in 0..self.nrows() { - det = det.faer_mul(lu.factors.read(i, i)); - } - if lu.transposition_count() % 2 == 0 { - det - } else { - det.faer_neg() - } - } - - #[track_caller] - fn selfadjoint_eigenvalues(&self, side: Side) -> Vec<::Real> { - let matrix = match side { - Side::Lower => *self, - Side::Upper => self.transpose(), - }; - - assert!(matrix.nrows() == matrix.ncols()); - let dim = matrix.nrows(); - let parallelism = get_global_parallelism(); - - let mut s = Mat::::zeros(dim, 1); - let params = Default::default(); - faer_evd::compute_hermitian_evd( - matrix.canonicalize().0, - s.as_mut(), - None, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_evd::compute_hermitian_evd_req::( - dim, - faer_evd::ComputeVectors::No, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - (0..dim).map(|i| s.read(i, 0).faer_real()).collect() - } - - #[track_caller] - fn singular_values(&self) -> Vec<::Real> { - let dim = Ord::min(self.nrows(), self.ncols()); - let parallelism = get_global_parallelism(); - - let mut s = Mat::::zeros(dim, 1); - let params = Default::default(); - faer_svd::compute_svd( - self.canonicalize().0, - s.as_mut(), - None, - None, - parallelism, - PodStack::new(&mut GlobalPodBuffer::new( - faer_svd::compute_svd_req::( - self.nrows(), - self.ncols(), - faer_svd::ComputeVectors::No, - faer_svd::ComputeVectors::No, - parallelism, - params, - ) - .unwrap(), - )), - params, - ); - - (0..dim).map(|i| s.read(i, 0).faer_real()).collect() - } - - #[track_caller] - fn eigenvalues::Real>>( - &self, - ) -> Vec { - if coe::is_same::::Real>() { - let matrix: MatRef<'_, ::Real> = - coe::coerce(self.as_ref()); - Eigendecomposition::::__values_from_real(matrix) - } else if coe::is_same::() { - let (matrix, conj) = self.as_ref().canonicalize(); - Eigendecomposition::::__values_from_complex_impl((coe::coerce(matrix), conj)) - } else { - panic!( - "The type ComplexE must be either E::Canonical ({}) or E::Canonical::Real ({})", - core::any::type_name::(), - core::any::type_name::<::Real>(), - ); - } - } - - #[track_caller] - fn complex_eigenvalues(&self) -> Vec { - Eigendecomposition::::__values_from_complex_impl(self.canonicalize()) - } -} - -impl FaerMat for MatMut<'_, E> -where - E::Canonical: ComplexField, -{ - #[track_caller] - fn solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_lower_triangular_in_place(rhs) - } - #[track_caller] - fn solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_upper_triangular_in_place(rhs) - } - #[track_caller] - fn solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_unit_lower_triangular_in_place(rhs) - } - #[track_caller] - fn solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_unit_upper_triangular_in_place(rhs) - } - - #[track_caller] - fn cholesky(&self, side: Side) -> Result, CholeskyError> { - self.as_ref().cholesky(side) - } - #[track_caller] - fn lblt(&self, side: Side) -> Lblt { - self.as_ref().lblt(side) - } - #[track_caller] - fn partial_piv_lu(&self) -> PartialPivLu { - self.as_ref().partial_piv_lu() - } - #[track_caller] - fn full_piv_lu(&self) -> FullPivLu { - self.as_ref().full_piv_lu() - } - #[track_caller] - fn qr(&self) -> Qr { - self.as_ref().qr() - } - #[track_caller] - fn col_piv_qr(&self) -> ColPivQr { - self.as_ref().col_piv_qr() - } - #[track_caller] - fn svd(&self) -> Svd { - self.as_ref().svd() - } - #[track_caller] - fn thin_svd(&self) -> ThinSvd { - self.as_ref().thin_svd() - } - #[track_caller] - fn selfadjoint_eigendecomposition( - &self, - side: Side, - ) -> SelfAdjointEigendecomposition { - self.as_ref().selfadjoint_eigendecomposition(side) - } - - #[track_caller] - fn eigendecomposition::Real>>( - &self, - ) -> Eigendecomposition { - self.as_ref().eigendecomposition::() - } - - #[track_caller] - fn complex_eigendecomposition(&self) -> Eigendecomposition { - self.as_ref().complex_eigendecomposition() - } - - #[track_caller] - fn determinant(&self) -> E::Canonical { - self.as_ref().determinant() - } - - #[track_caller] - fn selfadjoint_eigenvalues(&self, side: Side) -> Vec<::Real> { - self.as_ref().selfadjoint_eigenvalues(side) - } - - #[track_caller] - fn singular_values(&self) -> Vec<::Real> { - self.as_ref().singular_values() - } - - #[track_caller] - fn eigenvalues::Real>>( - &self, - ) -> Vec { - self.as_ref().eigenvalues() - } - - #[track_caller] - fn complex_eigenvalues(&self) -> Vec { - self.as_ref().complex_eigenvalues() - } -} - -impl FaerMat for Mat -where - E::Canonical: ComplexField, -{ - #[track_caller] - fn solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_lower_triangular_in_place(rhs) - } - #[track_caller] - fn solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_upper_triangular_in_place(rhs) - } - #[track_caller] - fn solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_unit_lower_triangular_in_place(rhs) - } - #[track_caller] - fn solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { - self.as_ref().solve_unit_upper_triangular_in_place(rhs) - } - - #[track_caller] - fn cholesky(&self, side: Side) -> Result, CholeskyError> { - self.as_ref().cholesky(side) - } - #[track_caller] - fn lblt(&self, side: Side) -> Lblt { - self.as_ref().lblt(side) - } - #[track_caller] - fn partial_piv_lu(&self) -> PartialPivLu { - self.as_ref().partial_piv_lu() - } - #[track_caller] - fn full_piv_lu(&self) -> FullPivLu { - self.as_ref().full_piv_lu() - } - #[track_caller] - fn qr(&self) -> Qr { - self.as_ref().qr() - } - #[track_caller] - fn col_piv_qr(&self) -> ColPivQr { - self.as_ref().col_piv_qr() - } - #[track_caller] - fn svd(&self) -> Svd { - self.as_ref().svd() - } - #[track_caller] - fn thin_svd(&self) -> ThinSvd { - self.as_ref().thin_svd() - } - #[track_caller] - fn selfadjoint_eigendecomposition( - &self, - side: Side, - ) -> SelfAdjointEigendecomposition { - self.as_ref().selfadjoint_eigendecomposition(side) - } - - #[track_caller] - fn eigendecomposition::Real>>( - &self, - ) -> Eigendecomposition { - self.as_ref().eigendecomposition::() - } - - #[track_caller] - fn complex_eigendecomposition(&self) -> Eigendecomposition { - self.as_ref().complex_eigendecomposition() - } - - #[track_caller] - fn determinant(&self) -> E::Canonical { - self.as_ref().determinant() - } - - #[track_caller] - fn selfadjoint_eigenvalues(&self, side: Side) -> Vec<::Real> { - self.as_ref().selfadjoint_eigenvalues(side) - } - - #[track_caller] - fn singular_values(&self) -> Vec<::Real> { - self.as_ref().singular_values() - } - - #[track_caller] - fn eigenvalues::Real>>( - &self, - ) -> Vec { - self.as_ref().eigenvalues() - } - - #[track_caller] - fn complex_eigenvalues(&self) -> Vec { - self.as_ref().complex_eigenvalues() - } -} - -/// Conversions from external library matrix views into `faer` types. -pub trait IntoFaer { - type Faer; - fn into_faer(self) -> Self::Faer; -} - -#[cfg(feature = "nalgebra")] -#[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] -/// Conversions from external library matrix views into `nalgebra` types. -pub trait IntoNalgebra { - type Nalgebra; - fn into_nalgebra(self) -> Self::Nalgebra; -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -/// Conversions from external library matrix views into `ndarray` types. -pub trait IntoNdarray { - type Ndarray; - fn into_ndarray(self) -> Self::Ndarray; -} - -/// Conversions from external library matrix views into complex `faer` types. -pub trait IntoFaerComplex { - type Faer; - fn into_faer_complex(self) -> Self::Faer; -} - -#[cfg(feature = "nalgebra")] -#[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] -/// Conversions from external library matrix views into complex `nalgebra` types. -pub trait IntoNalgebraComplex { - type Nalgebra; - fn into_nalgebra_complex(self) -> Self::Nalgebra; -} - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -/// Conversions from external library matrix views into complex `ndarray` types. -pub trait IntoNdarrayComplex { - type Ndarray; - fn into_ndarray_complex(self) -> Self::Ndarray; -} - -#[cfg(feature = "nalgebra")] -#[cfg_attr(docsrs, doc(cfg(feature = "nalgebra")))] -const _: () = { - use complex_native::*; - use faer_core::SimpleEntity; - use nalgebra::{Dim, Dyn, MatrixView, MatrixViewMut, ViewStorage, ViewStorageMut}; - use num_complex::{Complex32, Complex64}; - - impl<'a, T: SimpleEntity, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaer - for MatrixView<'a, T, R, C, RStride, CStride> - { - type Faer = MatRef<'a, T>; - - #[track_caller] - fn into_faer(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides = self.strides(); - let ptr = self.as_ptr(); - unsafe { - faer_core::mat::from_raw_parts( - ptr, - nrows, - ncols, - strides.0.try_into().unwrap(), - strides.1.try_into().unwrap(), - ) - } - } - } - - impl<'a, T: SimpleEntity, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaer - for MatrixViewMut<'a, T, R, C, RStride, CStride> - { - type Faer = MatMut<'a, T>; - - #[track_caller] - fn into_faer(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides = self.strides(); - let ptr = { self }.as_mut_ptr(); - unsafe { - faer_core::mat::from_raw_parts_mut::<'_, T>( - ptr, - nrows, - ncols, - strides.0.try_into().unwrap(), - strides.1.try_into().unwrap(), - ) - } - } - } - - impl<'a, T: SimpleEntity> IntoNalgebra for MatRef<'a, T> { - type Nalgebra = MatrixView<'a, T, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - let ptr = self.as_ptr(); - unsafe { - MatrixView::<'_, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorage::< - '_, - T, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - - impl<'a, T: SimpleEntity> IntoNalgebra for MatMut<'a, T> { - type Nalgebra = MatrixViewMut<'a, T, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - let ptr = self.as_ptr_mut(); - unsafe { - MatrixViewMut::<'_, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorageMut::< - '_, - T, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - - impl<'a, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaerComplex - for MatrixView<'a, Complex32, R, C, RStride, CStride> - { - type Faer = MatRef<'a, c32>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides = self.strides(); - let ptr = self.as_ptr() as *const c32; - unsafe { - faer_core::mat::from_raw_parts( - ptr, - nrows, - ncols, - strides.0.try_into().unwrap(), - strides.1.try_into().unwrap(), - ) - } - } - } - - impl<'a, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaerComplex - for MatrixViewMut<'a, Complex32, R, C, RStride, CStride> - { - type Faer = MatMut<'a, c32>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides = self.strides(); - let ptr = { self }.as_mut_ptr() as *mut c32; - unsafe { - faer_core::mat::from_raw_parts_mut( - ptr, - nrows, - ncols, - strides.0.try_into().unwrap(), - strides.1.try_into().unwrap(), - ) - } - } - } - - impl<'a> IntoNalgebraComplex for MatRef<'a, c32> { - type Nalgebra = MatrixView<'a, Complex32, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra_complex(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - let ptr = self.as_ptr() as *const Complex32; - unsafe { - MatrixView::<'_, Complex32, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorage::< - '_, - Complex32, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - - impl<'a> IntoNalgebraComplex for MatMut<'a, c32> { - type Nalgebra = MatrixViewMut<'a, Complex32, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra_complex(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - let ptr = self.as_ptr_mut() as *mut Complex32; - unsafe { - MatrixViewMut::<'_, Complex32, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorageMut::< - '_, - Complex32, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - - impl<'a, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaerComplex - for MatrixView<'a, Complex64, R, C, RStride, CStride> - { - type Faer = MatRef<'a, c64>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides = self.strides(); - let ptr = self.as_ptr() as *const c64; - unsafe { - faer_core::mat::from_raw_parts( - ptr, - nrows, - ncols, - strides.0.try_into().unwrap(), - strides.1.try_into().unwrap(), - ) - } - } - } - - impl<'a, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoFaerComplex - for MatrixViewMut<'a, Complex64, R, C, RStride, CStride> - { - type Faer = MatMut<'a, c64>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides = self.strides(); - let ptr = { self }.as_mut_ptr() as *mut c64; - unsafe { - faer_core::mat::from_raw_parts_mut( - ptr, - nrows, - ncols, - strides.0.try_into().unwrap(), - strides.1.try_into().unwrap(), - ) - } - } - } - - impl<'a> IntoNalgebraComplex for MatRef<'a, c64> { - type Nalgebra = MatrixView<'a, Complex64, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra_complex(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - let ptr = self.as_ptr() as *const Complex64; - unsafe { - MatrixView::<'_, Complex64, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorage::< - '_, - Complex64, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - - impl<'a> IntoNalgebraComplex for MatMut<'a, c64> { - type Nalgebra = MatrixViewMut<'a, Complex64, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra_complex(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride = self.row_stride(); - let col_stride = self.col_stride(); - let ptr = self.as_ptr_mut() as *mut Complex64; - unsafe { - MatrixViewMut::<'_, Complex64, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorageMut::< - '_, - Complex64, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } -}; - -#[cfg(feature = "ndarray")] -#[cfg_attr(docsrs, doc(cfg(feature = "ndarray")))] -const _: () = { - use complex_native::*; - use faer_core::SimpleEntity; - use ndarray::{ArrayView, ArrayViewMut, IntoDimension, Ix2, ShapeBuilder}; - use num_complex::{Complex32, Complex64}; - - impl<'a, T: SimpleEntity> IntoFaer for ArrayView<'a, T, Ix2> { - type Faer = MatRef<'a, T>; - - #[track_caller] - fn into_faer(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = self.as_ptr(); - unsafe { faer_core::mat::from_raw_parts(ptr, nrows, ncols, strides[0], strides[1]) } - } - } - - impl<'a, T: SimpleEntity> IntoFaer for ArrayViewMut<'a, T, Ix2> { - type Faer = MatMut<'a, T>; - - #[track_caller] - fn into_faer(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = { self }.as_mut_ptr(); - unsafe { - faer_core::mat::from_raw_parts_mut::<'_, T>( - ptr, nrows, ncols, strides[0], strides[1], - ) - } - } - } - - impl<'a, T: SimpleEntity> IntoNdarray for MatRef<'a, T> { - type Ndarray = ArrayView<'a, T, Ix2>; - - #[track_caller] - fn into_ndarray(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride: usize = self.row_stride().try_into().unwrap(); - let col_stride: usize = self.col_stride().try_into().unwrap(); - let ptr = self.as_ptr(); - unsafe { - ArrayView::<'_, T, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - - impl<'a, T: SimpleEntity> IntoNdarray for MatMut<'a, T> { - type Ndarray = ArrayViewMut<'a, T, Ix2>; - - #[track_caller] - fn into_ndarray(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride: usize = self.row_stride().try_into().unwrap(); - let col_stride: usize = self.col_stride().try_into().unwrap(); - let ptr = self.as_ptr_mut(); - unsafe { - ArrayViewMut::<'_, T, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - - impl<'a> IntoFaerComplex for ArrayView<'a, Complex32, Ix2> { - type Faer = MatRef<'a, c32>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = self.as_ptr() as *const c32; - unsafe { faer_core::mat::from_raw_parts(ptr, nrows, ncols, strides[0], strides[1]) } - } - } - - impl<'a> IntoFaerComplex for ArrayViewMut<'a, Complex32, Ix2> { - type Faer = MatMut<'a, c32>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = { self }.as_mut_ptr() as *mut c32; - unsafe { faer_core::mat::from_raw_parts_mut(ptr, nrows, ncols, strides[0], strides[1]) } - } - } - - impl<'a> IntoNdarrayComplex for MatRef<'a, c32> { - type Ndarray = ArrayView<'a, Complex32, Ix2>; - - #[track_caller] - fn into_ndarray_complex(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride: usize = self.row_stride().try_into().unwrap(); - let col_stride: usize = self.col_stride().try_into().unwrap(); - let ptr = self.as_ptr() as *const Complex32; - unsafe { - ArrayView::<'_, Complex32, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - - impl<'a> IntoNdarrayComplex for MatMut<'a, c32> { - type Ndarray = ArrayViewMut<'a, Complex32, Ix2>; - - #[track_caller] - fn into_ndarray_complex(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride: usize = self.row_stride().try_into().unwrap(); - let col_stride: usize = self.col_stride().try_into().unwrap(); - let ptr = self.as_ptr_mut() as *mut Complex32; - unsafe { - ArrayViewMut::<'_, Complex32, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - - impl<'a> IntoFaerComplex for ArrayView<'a, Complex64, Ix2> { - type Faer = MatRef<'a, c64>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = self.as_ptr() as *const c64; - unsafe { faer_core::mat::from_raw_parts(ptr, nrows, ncols, strides[0], strides[1]) } - } - } - - impl<'a> IntoFaerComplex for ArrayViewMut<'a, Complex64, Ix2> { - type Faer = MatMut<'a, c64>; - - #[track_caller] - fn into_faer_complex(self) -> Self::Faer { - let nrows = self.nrows(); - let ncols = self.ncols(); - let strides: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = { self }.as_mut_ptr() as *mut c64; - unsafe { faer_core::mat::from_raw_parts_mut(ptr, nrows, ncols, strides[0], strides[1]) } - } - } - - impl<'a> IntoNdarrayComplex for MatRef<'a, c64> { - type Ndarray = ArrayView<'a, Complex64, Ix2>; - - #[track_caller] - fn into_ndarray_complex(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride: usize = self.row_stride().try_into().unwrap(); - let col_stride: usize = self.col_stride().try_into().unwrap(); - let ptr = self.as_ptr() as *const Complex64; - unsafe { - ArrayView::<'_, Complex64, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - - impl<'a> IntoNdarrayComplex for MatMut<'a, c64> { - type Ndarray = ArrayViewMut<'a, Complex64, Ix2>; - - #[track_caller] - fn into_ndarray_complex(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let row_stride: usize = self.row_stride().try_into().unwrap(); - let col_stride: usize = self.col_stride().try_into().unwrap(); - let ptr = self.as_ptr_mut() as *mut Complex64; - unsafe { - ArrayViewMut::<'_, Complex64, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } -}; - -#[cfg(all(feature = "nalgebra", feature = "ndarray"))] -#[cfg_attr(docsrs, doc(cfg(all(feature = "nalgebra", feature = "ndarray"))))] -const _: () = - { - use nalgebra::{Dim, Dyn, MatrixView, MatrixViewMut, ViewStorage, ViewStorageMut}; - use ndarray::{ArrayView, ArrayViewMut, IntoDimension, Ix2, ShapeBuilder}; - use num_complex::Complex; - - impl<'a, T> IntoNalgebra for ArrayView<'a, T, Ix2> { - type Nalgebra = MatrixView<'a, T, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let [row_stride, col_stride]: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = self.as_ptr(); - - unsafe { - MatrixView::<'_, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorage::< - '_, - T, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - impl<'a, T> IntoNalgebra for ArrayViewMut<'a, T, Ix2> { - type Nalgebra = MatrixViewMut<'a, T, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let [row_stride, col_stride]: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = { self }.as_mut_ptr(); - - unsafe { - MatrixViewMut::<'_, T, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorageMut::< - '_, - T, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - - impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarray - for MatrixView<'a, T, R, C, RStride, CStride> - { - type Ndarray = ArrayView<'a, T, Ix2>; - - #[track_caller] - fn into_ndarray(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let (row_stride, col_stride) = self.strides(); - let ptr = self.as_ptr(); - - unsafe { - ArrayView::<'_, T, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarray - for MatrixViewMut<'a, T, R, C, RStride, CStride> - { - type Ndarray = ArrayViewMut<'a, T, Ix2>; - - #[track_caller] - fn into_ndarray(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let (row_stride, col_stride) = self.strides(); - let ptr = { self }.as_mut_ptr(); - - unsafe { - ArrayViewMut::<'_, T, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - - impl<'a, T> IntoNalgebraComplex for ArrayView<'a, Complex, Ix2> { - type Nalgebra = MatrixView<'a, Complex, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra_complex(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let [row_stride, col_stride]: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = self.as_ptr(); - - unsafe { - MatrixView::<'_, Complex, Dyn, Dyn, Dyn, Dyn>::from_data(ViewStorage::< - '_, - Complex, - Dyn, - Dyn, - Dyn, - Dyn, - >::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - )) - } - } - } - impl<'a, T> IntoNalgebraComplex for ArrayViewMut<'a, Complex, Ix2> { - type Nalgebra = MatrixViewMut<'a, Complex, Dyn, Dyn, Dyn, Dyn>; - - #[track_caller] - fn into_nalgebra_complex(self) -> Self::Nalgebra { - let nrows = self.nrows(); - let ncols = self.ncols(); - let [row_stride, col_stride]: [isize; 2] = self.strides().try_into().unwrap(); - let ptr = { self }.as_mut_ptr(); - - unsafe { - MatrixViewMut::<'_, Complex, Dyn, Dyn, Dyn, Dyn>::from_data( - ViewStorageMut::<'_, Complex, Dyn, Dyn, Dyn, Dyn>::from_raw_parts( - ptr, - (Dyn(nrows), Dyn(ncols)), - ( - Dyn(row_stride.try_into().unwrap()), - Dyn(col_stride.try_into().unwrap()), - ), - ), - ) - } - } - } - - impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarrayComplex - for MatrixView<'a, Complex, R, C, RStride, CStride> - { - type Ndarray = ArrayView<'a, Complex, Ix2>; - - #[track_caller] - fn into_ndarray_complex(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let (row_stride, col_stride) = self.strides(); - let ptr = self.as_ptr(); - - unsafe { - ArrayView::<'_, Complex, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - impl<'a, T, R: Dim, C: Dim, RStride: Dim, CStride: Dim> IntoNdarrayComplex - for MatrixViewMut<'a, Complex, R, C, RStride, CStride> - { - type Ndarray = ArrayViewMut<'a, Complex, Ix2>; - - #[track_caller] - fn into_ndarray_complex(self) -> Self::Ndarray { - let nrows = self.nrows(); - let ncols = self.ncols(); - let (row_stride, col_stride) = self.strides(); - let ptr = { self }.as_mut_ptr(); - - unsafe { - ArrayViewMut::<'_, Complex, Ix2>::from_shape_ptr( - (nrows, ncols) - .into_shape() - .strides((row_stride, col_stride).into_dimension()), - ptr, - ) - } - } - } - }; - -#[cfg(feature = "polars")] -#[cfg_attr(docsrs, doc(cfg(feature = "polars")))] -pub mod polars { - use super::Mat; - use polars::prelude::*; - - pub trait Frame { - fn is_valid(self) -> PolarsResult; - } - - impl Frame for LazyFrame { - fn is_valid(self) -> PolarsResult { - let test_dtypes: bool = self - .clone() - .limit(0) - .collect() - .unwrap() - .dtypes() - .into_iter() - .map(|e| { - matches!( - e, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - ) - }) - .all(|e| e); - let test_no_nulls: bool = self - .clone() - .null_count() - .with_column( - fold_exprs(lit(0u64), |acc, x| Ok(Some(acc + x)), [col("*")]).alias("sum"), - ) - .select(&[col("sum")]) - .collect() - .unwrap() - .column("sum") - .unwrap() - .u64() - .unwrap() - .into_iter() - .map(|e| e.eq(&Some(0u64))) - .collect::>()[0]; - match (test_dtypes, test_no_nulls) { - (true, true) => Ok(self), - (false, true) => Err(PolarsError::InvalidOperation( - "frame contains non-numerical data".into(), - )), - (true, false) => Err(PolarsError::InvalidOperation( - "frame contains null entries".into(), - )), - (false, false) => Err(PolarsError::InvalidOperation( - "frame contains non-numerical data and null entries".into(), - )), - } - } - } - - macro_rules! polars_impl { - ($ty: ident, $dtype: ident, $fn_name: ident) => { - /// Converts a `polars` lazyframe into a [`Mat`]. - /// - /// Note that this function expects that the frame passed "looks like" - /// a numerical array and all values will be cast to either f32 or f64 - /// prior to building [`Mat`]. - /// - /// Passing a frame with either non-numerical column data or null - /// entries will result in a error. Users are expected to reolve - /// these issues in `polars` prior calling this function. - #[cfg(feature = "polars")] - #[cfg_attr(docsrs, doc(cfg(feature = "polars")))] - pub fn $fn_name( - frame: impl Frame, - ) -> PolarsResult> { - use core::{iter::zip, mem::MaybeUninit}; - - fn implementation( - lf: LazyFrame, - ) -> PolarsResult> { - let df = lf - .select(&[col("*").cast(DataType::$dtype)]) - .collect() - .unwrap(); - - let nrows = df.height(); - let ncols = df.get_column_names().len(); - - let mut out = Mat::<$ty>::with_capacity(df.height(), df.get_column_names().len()); - - df.get_column_names().iter() - .enumerate() - .try_for_each(|(j, col)| -> PolarsResult<()> { - let mut row_start = 0usize; - - // SAFETY: this is safe since we allocated enough space for `ncols` columns and - // `nrows` rows - let out_col = unsafe { - core::slice::from_raw_parts_mut( - out.as_mut().ptr_at_mut(0, j) as *mut MaybeUninit<$ty>, - nrows, - ) - }; - - df.column(col)?.$ty()?.downcast_iter().try_for_each( - |chunk| -> PolarsResult<()> { - let len = chunk.len(); - if len == 0 { - return Ok(()); - } - - match row_start.checked_add(len) { - Some(next_row_start) => { - if next_row_start <= nrows { - let mut out_slice = &mut out_col[row_start..next_row_start]; - let mut values = chunk.values_iter().as_slice(); - let validity = chunk.validity(); - - assert_eq!(values.len(), len); - - match validity { - Some(bitmap) => { - let (mut bytes, offset, bitmap_len) = bitmap.as_slice(); - assert_eq!(bitmap_len, len); - const BITS_PER_BYTE: usize = 8; - - if offset > 0 { - let first_byte_len = Ord::min(len, 8 - offset); - - let (out_prefix, out_suffix) = out_slice.split_at_mut(first_byte_len); - let (values_prefix, values_suffix) = values.split_at(first_byte_len); - - for (out_elem, value_elem) in zip( - out_prefix, - values_prefix, - ) { - *out_elem = MaybeUninit::new(*value_elem) - } - - bytes = &bytes[1..]; - values = values_suffix; - out_slice = out_suffix; - } - - if bytes.len() > 0 { - for (out_slice8, values8) in zip( - out_slice.chunks_exact_mut(BITS_PER_BYTE), - values.chunks_exact(BITS_PER_BYTE), - ) { - for (out_elem, value_elem) in zip(out_slice8, values8) { - *out_elem = MaybeUninit::new(*value_elem); - } - } - - for (out_elem, value_elem) in zip( - out_slice.chunks_exact_mut(BITS_PER_BYTE).into_remainder(), - values.chunks_exact(BITS_PER_BYTE).remainder(), - ) { - *out_elem = MaybeUninit::new(*value_elem); - } - } - } - None => { - // SAFETY: T and MaybeUninit have the same layout - // NOTE: This state should not be reachable - let values = unsafe { - core::slice::from_raw_parts( - values.as_ptr() as *const MaybeUninit<$ty>, - values.len(), - ) - }; - out_slice.copy_from_slice(values); - } - } - - row_start = next_row_start; - Ok(()) - } else { - Err(PolarsError::ShapeMismatch( - format!("too many values in column {col}").into(), - )) - } - } - None => Err(PolarsError::ShapeMismatch( - format!("too many values in column {col}").into(), - )), - } - }, - )?; - - if row_start < nrows { - Err(PolarsError::ShapeMismatch( - format!("not enough values in column {col} (column has {row_start} values, while dataframe has {nrows} rows)").into(), - )) - } else { - Ok(()) - } - })?; - - // SAFETY: we initialized every `ncols` columns, and each one was initialized with `nrows` - // elements - unsafe { out.set_dims(nrows, ncols) }; - - Ok(out) - } - - implementation(frame.is_valid()?) - } - }; - } - - polars_impl!(f32, Float32, polars_to_faer_f32); - polars_impl!(f64, Float64, polars_to_faer_f64); -} - -/// De-serialization from common matrix file formats. -#[cfg(feature = "std")] -#[cfg_attr(docsrs, doc(cfg(feature = "std")))] -pub mod io { - #[allow(unused_imports)] - use super::*; - #[allow(unused_imports)] - use complex_native::{c32, c64}; - - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - pub struct Npy<'a> { - aligned_bytes: &'a [u8], - nrows: usize, - ncols: usize, - prefix_len: usize, - dtype: NpyDType, - fortran_order: bool, - } - - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - #[derive(Debug, Copy, Clone, PartialEq, Eq)] - pub enum NpyDType { - F32, - F64, - C32, - C64, - Other, - } - - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - pub trait FromNpy: modules::core::SimpleEntity { - const DTYPE: NpyDType; - } - - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - impl FromNpy for f32 { - const DTYPE: NpyDType = NpyDType::F32; - } - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - impl FromNpy for f64 { - const DTYPE: NpyDType = NpyDType::F64; - } - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - impl FromNpy for c32 { - const DTYPE: NpyDType = NpyDType::C32; - } - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - impl FromNpy for c64 { - const DTYPE: NpyDType = NpyDType::C64; - } - - #[cfg(feature = "npy")] - #[cfg_attr(docsrs, doc(cfg(feature = "npy")))] - impl<'a> Npy<'a> { - fn parse_npyz( - data: &[u8], - npyz: npyz::NpyFile<&[u8]>, - ) -> Result<(NpyDType, usize, usize, usize, bool), std::io::Error> { - let ver_major = data[6] - b'\x00'; - let length = if ver_major <= 1 { - 2usize - } else if ver_major <= 3 { - 4usize - } else { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "unsupported version", - )); - }; - let header_len = if length == 2 { - u16::from_le_bytes(data[8..10].try_into().unwrap()) as usize - } else { - u32::from_le_bytes(data[8..12].try_into().unwrap()) as usize - }; - let dtype = || -> NpyDType { - match npyz.dtype() { - npyz::DType::Plain(str) => { - let is_complex = match str.type_char() { - npyz::TypeChar::Float => false, - npyz::TypeChar::Complex => true, - _ => return NpyDType::Other, - }; - - let byte_size = str.size_field(); - if byte_size == 8 && is_complex { - NpyDType::C32 - } else if byte_size == 16 && is_complex { - NpyDType::C64 - } else if byte_size == 4 && !is_complex { - NpyDType::F32 - } else if byte_size == 16 && !is_complex { - NpyDType::F64 - } else { - NpyDType::Other - } - } - _ => NpyDType::Other, - } - }; - - let dtype = dtype(); - let order = npyz.header().order(); - let shape = npyz.shape(); - let nrows = shape.get(0).copied().unwrap_or(1) as usize; - let ncols = shape.get(1).copied().unwrap_or(1) as usize; - let prefix_len = 8 + length + header_len; - let fortran_order = order == npyz::Order::Fortran; - Ok((dtype, nrows, ncols, prefix_len, fortran_order)) - } - - #[inline] - pub fn new(data: &'a [u8]) -> Result { - let npyz = npyz::NpyFile::new(data)?; - - let (dtype, nrows, ncols, prefix_len, fortran_order) = Self::parse_npyz(data, npyz)?; - - Ok(Self { - aligned_bytes: data, - prefix_len, - nrows, - ncols, - dtype, - fortran_order, - }) - } - - #[inline] - pub fn dtype(&self) -> NpyDType { - self.dtype - } - - #[inline] - pub fn is_aligned(&self) -> bool { - self.aligned_bytes.as_ptr().align_offset(64) == 0 - } - - #[inline] - pub fn as_aligned_ref(&self) -> MatRef<'_, E> { - assert!(self.is_aligned()); - assert!(self.dtype == E::DTYPE); - - if self.fortran_order { - faer_core::mat::from_column_major_slice( - bytemuck::cast_slice(&self.aligned_bytes[self.prefix_len..]), - self.nrows, - self.ncols, - ) - } else { - faer_core::mat::from_row_major_slice( - bytemuck::cast_slice(&self.aligned_bytes[self.prefix_len..]), - self.nrows, - self.ncols, - ) - } - } - - #[inline] - pub fn to_mat(&self) -> Mat { - assert!(self.dtype == E::DTYPE); - - let mut mat = Mat::::with_capacity(self.nrows, self.ncols); - unsafe { mat.set_dims(self.nrows, self.ncols) }; - - let data = &self.aligned_bytes[self.prefix_len..]; - - if self.fortran_order { - for j in 0..self.ncols { - bytemuck::cast_slice_mut(mat.col_as_slice_mut(j)).copy_from_slice( - &data[j * self.nrows * core::mem::size_of::()..] - [..self.nrows * core::mem::size_of::()], - ) - } - } else { - for j in 0..self.ncols { - for i in 0..self.nrows { - bytemuck::cast_slice_mut(&mut mat.col_as_slice_mut(j)[i..i + 1]) - .copy_from_slice( - &data[(i * self.ncols + j) * core::mem::size_of::()..] - [..core::mem::size_of::()], - ) - } - } - }; - - mat - } - } -} - -#[cfg(test)] -mod tests { - #![allow(non_snake_case)] - - use super::*; - use complex_native::*; - use faer_core::{assert, RealField}; - - #[track_caller] - fn assert_approx_eq(a: impl AsMatRef, b: impl AsMatRef) { - let a = a.as_mat_ref(); - let b = b.as_mat_ref(); - let eps = E::Real::faer_epsilon().unwrap().faer_sqrt(); - - assert!(a.nrows() == b.nrows()); - assert!(a.ncols() == b.ncols()); - - let m = a.nrows(); - let n = a.ncols(); - - for j in 0..n { - for i in 0..m { - assert!((a.read(i, j).faer_sub(b.read(i, j))).faer_abs() < eps); - } - } - } - - fn test_solver_real(H: impl AsMatRef, decomp: &dyn SolverCore) { - let H = H.as_mat_ref(); - let n = H.nrows(); - let k = 2; - - let random = |_, _| rand::random::(); - let rhs = Mat::from_fn(n, k, random); - - let I = Mat::from_fn(n, n, |i, j| { - if i == j { - f64::faer_one() - } else { - f64::faer_zero() - } - }); - - let sol = decomp.solve(&rhs); - assert_approx_eq(H * &sol, &rhs); - - let sol = decomp.solve_conj(&rhs); - assert_approx_eq(H.conjugate() * &sol, &rhs); - - let sol = decomp.solve_transpose(&rhs); - assert_approx_eq(H.transpose() * &sol, &rhs); - - let sol = decomp.solve_conj_transpose(&rhs); - assert_approx_eq(H.adjoint() * &sol, &rhs); - - assert_approx_eq(decomp.reconstruct(), H); - assert_approx_eq(H * decomp.inverse(), I); - } - - fn test_solver(H: impl AsMatRef, decomp: &dyn SolverCore) { - let H = H.as_mat_ref(); - let n = H.nrows(); - let k = 2; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let rhs = Mat::from_fn(n, k, random); - - let I = Mat::from_fn(n, n, |i, j| { - if i == j { - c64::faer_one() - } else { - c64::faer_zero() - } - }); - - let sol = decomp.solve(&rhs); - assert_approx_eq(H * &sol, &rhs); - - let sol = decomp.solve_conj(&rhs); - assert_approx_eq(H.conjugate() * &sol, &rhs); - - let sol = decomp.solve_transpose(&rhs); - assert_approx_eq(H.transpose() * &sol, &rhs); - - let sol = decomp.solve_conj_transpose(&rhs); - assert_approx_eq(H.adjoint() * &sol, &rhs); - - assert_approx_eq(decomp.reconstruct(), H); - assert_approx_eq(H * decomp.inverse(), I); - } - - fn test_solver_lstsq(H: impl AsMatRef, decomp: &dyn SolverLstsqCore) { - let H = H.as_mat_ref(); - - let m = H.nrows(); - let k = 2; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let rhs = Mat::from_fn(m, k, random); - - let sol = decomp.solve_lstsq(&rhs); - assert_approx_eq(H.adjoint() * H * &sol, H.adjoint() * &rhs); - - let sol = decomp.solve_lstsq_conj(&rhs); - assert_approx_eq(H.transpose() * H.conjugate() * &sol, H.transpose() * &rhs); - } - - #[test] - fn test_lblt_real() { - let n = 7; - - let random = |_, _| rand::random::(); - let H = Mat::from_fn(n, n, random); - let H = &H + H.adjoint(); - - test_solver_real(&H, &H.lblt(Side::Lower)); - test_solver_real(&H, &H.lblt(Side::Upper)); - } - - #[test] - fn test_lblt() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - let H = &H + H.adjoint(); - - test_solver(&H, &H.lblt(Side::Lower)); - test_solver(&H, &H.lblt(Side::Upper)); - } - - #[test] - fn test_cholesky() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - let H = &H * H.adjoint(); - - test_solver(&H, &H.cholesky(Side::Lower).unwrap()); - test_solver(&H, &H.cholesky(Side::Upper).unwrap()); - } - - #[test] - fn test_partial_piv_lu() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - - test_solver(&H, &H.partial_piv_lu()); - } - - #[test] - fn test_full_piv_lu() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - - test_solver(&H, &H.full_piv_lu()); - } - - #[test] - fn test_qr_real() { - let n = 7; - - let random = |_, _| rand::random::(); - let H = Mat::from_fn(n, n, random); - - let qr = H.qr(); - test_solver_real(&H, &qr); - - for (m, n) in [(7, 5), (5, 7), (7, 7)] { - let H = Mat::from_fn(m, n, random); - let qr = H.qr(); - assert_approx_eq(qr.compute_q() * qr.compute_r(), &H); - assert_approx_eq(qr.compute_thin_q() * qr.compute_thin_r(), &H); - } - } - - #[test] - fn test_qr() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - - let qr = H.qr(); - test_solver(&H, &qr); - - for (m, n) in [(7, 5), (5, 7), (7, 7)] { - let H = Mat::from_fn(m, n, random); - let qr = H.qr(); - assert_approx_eq(qr.compute_q() * qr.compute_r(), &H); - assert_approx_eq(qr.compute_thin_q() * qr.compute_thin_r(), &H); - if m >= n { - test_solver_lstsq(H, &qr) - } - } - } - - #[test] - fn test_col_piv_qr() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - - test_solver(&H, &H.col_piv_qr()); - - for (m, n) in [(7, 5), (5, 7), (7, 7)] { - let H = Mat::from_fn(m, n, random); - let qr = H.col_piv_qr(); - assert_approx_eq( - qr.compute_q() * qr.compute_r(), - &H * qr.col_permutation().inverse(), - ); - assert_approx_eq( - qr.compute_thin_q() * qr.compute_thin_r(), - &H * qr.col_permutation().inverse(), - ); - if m >= n { - test_solver_lstsq(H, &qr) - } - } - } - - #[test] - fn test_svd() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - - test_solver(&H, &H.svd()); - test_solver(&H.adjoint().to_owned(), &H.adjoint().svd()); - - let svd = H.svd(); - for i in 0..n - 1 { - assert!(svd.s_diagonal()[(i, 0)].re >= svd.s_diagonal()[(i + 1, 0)].re); - } - let svd = H.singular_values(); - for i in 0..n - 1 { - assert!(svd[i] >= svd[i + 1]); - } - } - - #[test] - fn test_thin_svd() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - - test_solver(&H, &H.thin_svd()); - test_solver(&H.adjoint().to_owned(), &H.adjoint().thin_svd()); - } - - #[test] - fn test_selfadjoint_eigendecomposition() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - let H = &H * H.adjoint(); - - test_solver(&H, &H.selfadjoint_eigendecomposition(Side::Lower)); - test_solver(&H, &H.selfadjoint_eigendecomposition(Side::Upper)); - test_solver( - &H.adjoint().to_owned(), - &H.adjoint().selfadjoint_eigendecomposition(Side::Lower), - ); - test_solver( - &H.adjoint().to_owned(), - &H.adjoint().selfadjoint_eigendecomposition(Side::Upper), - ); - - let evd = H.selfadjoint_eigendecomposition(Side::Lower); - for i in 0..n - 1 { - assert!(evd.s_diagonal()[(i, 0)].re <= evd.s_diagonal()[(i + 1, 0)].re); - } - let evd = H.selfadjoint_eigenvalues(Side::Lower); - for i in 0..n - 1 { - assert!(evd[i] <= evd[i + 1]); - } - } - - #[test] - fn test_eigendecomposition() { - let n = 7; - - let random = |_, _| c64::new(rand::random(), rand::random()); - let H = Mat::from_fn(n, n, random); - - { - let eigen = H.eigendecomposition::(); - let mut s = Mat::zeros(n, n); - s.as_mut() - .diagonal_mut() - .column_vector_mut() - .copy_from(eigen.s_diagonal()); - let u = eigen.u(); - assert_approx_eq(u * &s, &H * u); - } - - { - let eigen = H.complex_eigendecomposition(); - let mut s = Mat::zeros(n, n); - s.as_mut() - .diagonal_mut() - .column_vector_mut() - .copy_from(eigen.s_diagonal()); - let u = eigen.u(); - assert_approx_eq(u * &s, &H * u); - } - - let det = H.determinant(); - let eigen_det = H - .complex_eigenvalues() - .into_iter() - .fold(c64::faer_one(), |a, b| a * b); - - assert!((det - eigen_det).faer_abs() < 1e-8); - } - - #[test] - fn test_real_eigendecomposition() { - let n = 7; - - let random = |_, _| rand::random::(); - let H_real = Mat::from_fn(n, n, random); - let H = Mat::from_fn(n, n, |i, j| c64::new(H_real.read(i, j), 0.0)); - - let eigen = H_real.eigendecomposition::(); - let mut s = Mat::zeros(n, n); - s.as_mut() - .diagonal_mut() - .column_vector_mut() - .copy_from(eigen.s_diagonal()); - let u = eigen.u(); - assert_approx_eq(u * &s, &H * u); - } - - #[cfg(feature = "ndarray")] - #[test] - fn test_ext_ndarray() { - let mut I_faer = Mat::::identity(8, 7); - let mut I_ndarray = ndarray::Array2::::zeros([8, 7]); - I_ndarray.diag_mut().fill(1.0); - - assert_matrix_eq!(I_ndarray.view().into_faer(), I_faer, comp = exact); - assert!(I_faer.as_ref().into_ndarray() == I_ndarray); - - assert!(I_ndarray.view_mut().into_faer() == I_faer); - assert!(I_faer.as_mut().into_ndarray() == I_ndarray); - } - - #[cfg(feature = "nalgebra")] - #[test] - fn test_ext_nalgebra() { - let mut I_faer = Mat::::identity(8, 7); - let mut I_nalgebra = nalgebra::DMatrix::::identity(8, 7); - - assert!(I_nalgebra.view_range(.., ..).into_faer() == I_faer); - assert!(I_faer.as_ref().into_nalgebra() == I_nalgebra); - - assert!(I_nalgebra.view_range_mut(.., ..).into_faer() == I_faer); - assert!(I_faer.as_mut().into_nalgebra() == I_nalgebra); - } - - #[cfg(feature = "polars")] - #[test] - fn test_polars_pos() { - use crate::polars::{polars_to_faer_f32, polars_to_faer_f64}; - #[rustfmt::skip] - use ::polars::prelude::*; - - let s0: Series = Series::new("a", [1, 2, 3]); - let s1: Series = Series::new("b", [10, 11, 12]); - - let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy(); - - let arr_32 = polars_to_faer_f32(lf.clone()).unwrap(); - let arr_64 = polars_to_faer_f64(lf).unwrap(); - - let expected_32 = mat![[1f32, 10f32], [2f32, 11f32], [3f32, 12f32]]; - let expected_64 = mat![[1f64, 10f64], [2f64, 11f64], [3f64, 12f64]]; - - assert_approx_eq(arr_32, expected_32); - assert_approx_eq(arr_64, expected_64); - } - - #[cfg(feature = "polars")] - #[test] - #[should_panic(expected = "frame contains null entries")] - fn test_polars_neg_32_null() { - use crate::polars::polars_to_faer_f32; - #[rustfmt::skip] - use ::polars::prelude::*; - - let s0: Series = Series::new("a", [1, 2, 3]); - let s1: Series = Series::new("b", [Some(10), Some(11), None]); - - let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy(); - - polars_to_faer_f32(lf).unwrap(); - } - - #[cfg(feature = "polars")] - #[test] - #[should_panic(expected = "frame contains non-numerical data")] - fn test_polars_neg_32_strl() { - use crate::polars::polars_to_faer_f32; - #[rustfmt::skip] - use ::polars::prelude::*; - - let s0: Series = Series::new("a", [1, 2, 3]); - let s1: Series = Series::new("b", ["fish", "dog", "crocodile"]); - - let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy(); - - polars_to_faer_f32(lf).unwrap(); - } - - #[cfg(feature = "polars")] - #[test] - #[should_panic(expected = "frame contains non-numerical data and null entries")] - fn test_polars_neg_32_combo() { - use crate::polars::polars_to_faer_f32; - #[rustfmt::skip] - use ::polars::prelude::*; - - let s0: Series = Series::new("a", [1, 2, 3]); - let s1: Series = Series::new("b", [Some(10), Some(11), None]); - let s2: Series = Series::new("c", [Some("fish"), Some("dog"), None]); - - let lf = DataFrame::new(vec![s0, s1, s2]).unwrap().lazy(); - - polars_to_faer_f32(lf).unwrap(); - } - - #[cfg(feature = "polars")] - #[test] - #[should_panic(expected = "frame contains null entries")] - fn test_polars_neg_64_null() { - use crate::polars::polars_to_faer_f64; - #[rustfmt::skip] - use ::polars::prelude::*; - - let s0: Series = Series::new("a", [1, 2, 3]); - let s1: Series = Series::new("b", [Some(10), Some(11), None]); - - let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy(); - - polars_to_faer_f64(lf).unwrap(); - } - - #[cfg(feature = "polars")] - #[test] - #[should_panic(expected = "frame contains non-numerical data")] - fn test_polars_neg_64_strl() { - use crate::polars::polars_to_faer_f64; - #[rustfmt::skip] - use ::polars::prelude::*; - - let s0: Series = Series::new("a", [1, 2, 3]); - let s1: Series = Series::new("b", ["fish", "dog", "crocodile"]); - - let lf = DataFrame::new(vec![s0, s1]).unwrap().lazy(); - - polars_to_faer_f64(lf).unwrap(); - } - - #[cfg(feature = "polars")] - #[test] - #[should_panic(expected = "frame contains non-numerical data and null entries")] - fn test_polars_neg_64_combo() { - use crate::polars::polars_to_faer_f64; - #[rustfmt::skip] - use ::polars::prelude::*; - - let s0: Series = Series::new("a", [1, 2, 3]); - let s1: Series = Series::new("b", [Some(10), Some(11), None]); - let s2: Series = Series::new("c", [Some("fish"), Some("dog"), None]); - - let lf = DataFrame::new(vec![s0, s1, s2]).unwrap().lazy(); - - polars_to_faer_f64(lf).unwrap(); - } - - #[test] - fn this_other_tree_has_correct_maximum_eigenvalue_20() { - let edges = [ - (3, 2), - (6, 1), - (7, 4), - (7, 6), - (8, 5), - (9, 4), - (11, 2), - (12, 2), - (13, 2), - (15, 6), - (16, 2), - (16, 4), - (17, 8), - (18, 0), - (18, 8), - (18, 2), - (19, 6), - (19, 10), - (19, 14), - ]; - let mut a = Mat::zeros(20, 20); - for (v, u) in edges.iter() { - a[(*v, *u)] = 1.0; - a[(*u, *v)] = 1.0; - } - let eigs_complex: Vec = a.eigenvalues(); - println!("{eigs_complex:?}"); - let eigs_real = eigs_complex.iter().map(|e| e.re).collect::>(); - println!("{eigs_real:?}"); - let lambda_1 = *eigs_real - .iter() - .max_by(|a, b| a.partial_cmp(&b).unwrap()) - .unwrap(); - let correct_lamba_1 = 2.6148611139728866; - assert!( - (lambda_1 - correct_lamba_1).abs() < 1e-10, - "lambda_1 = {lambda_1}, correct_lamba_1 = {correct_lamba_1}", - ); - } - - #[test] - fn this_other_tree_has_correct_maximum_eigenvalue_3() { - let edges = [(1, 0), (0, 2)]; - let mut a = Mat::zeros(3, 3); - for (v, u) in edges.iter() { - a[(*v, *u)] = 1.0; - a[(*u, *v)] = 1.0; - } - let eigs_complex: Vec = a.eigenvalues(); - let eigs_real = eigs_complex.iter().map(|e| e.re).collect::>(); - let lambda_1 = *eigs_real - .iter() - .max_by(|a, b| a.partial_cmp(&b).unwrap()) - .unwrap(); - let correct_lamba_1 = 1.414213562373095; - assert!( - (lambda_1 - correct_lamba_1).abs() < 1e-10, - "lambda_1 = {lambda_1}, correct_lamba_1 = {correct_lamba_1}", - ); - } -} diff --git a/src/col/col_index.rs b/src/col/col_index.rs new file mode 100644 index 0000000000000000000000000000000000000000..269f35d7ceb91ba5643a96099302a4eb9521606e --- /dev/null +++ b/src/col/col_index.rs @@ -0,0 +1,174 @@ +// RangeFull +// Range +// RangeInclusive +// RangeTo +// RangeToInclusive +// usize + +use super::*; +use core::ops::RangeFull; +type Range = core::ops::Range; +type RangeInclusive = core::ops::RangeInclusive; +type RangeFrom = core::ops::RangeFrom; +type RangeTo = core::ops::RangeTo; +type RangeToInclusive = core::ops::RangeToInclusive; + +impl ColIndex for ColRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFull) -> Self { + let _ = row; + this + } +} + +impl ColIndex for ColRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: Range) -> Self { + this.subrows(row.start, row.end - row.start) + } +} + +impl ColIndex for ColRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeInclusive) -> Self { + assert!(*row.end() != usize::MAX); + >::get(this, *row.start()..*row.end() + 1) + } +} + +impl ColIndex for ColRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFrom) -> Self { + let nrows = this.nrows(); + >::get(this, row.start..nrows) + } +} +impl ColIndex for ColRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeTo) -> Self { + >::get(this, 0..row.end) + } +} + +impl ColIndex for ColRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeToInclusive) -> Self { + assert!(row.end != usize::MAX); + >::get(this, 0..row.end + 1) + } +} + +impl<'a, E: Entity> ColIndex for ColRef<'a, E> { + type Target = GroupFor; + + #[track_caller] + #[inline(always)] + unsafe fn get_unchecked(this: Self, row: usize) -> Self::Target { + unsafe { E::faer_map(this.ptr_inbounds_at(row), |ptr: *const _| &*ptr) } + } + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: usize) -> Self::Target { + assert!(row < this.nrows()); + unsafe { >::get_unchecked(this, row) } + } +} + +impl ColIndex for ColMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFull) -> Self { + let _ = row; + this + } +} + +impl ColIndex for ColMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: Range) -> Self { + this.subrows_mut(row.start, row.end - row.start) + } +} + +impl ColIndex for ColMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeInclusive) -> Self { + assert!(*row.end() != usize::MAX); + >::get(this, *row.start()..*row.end() + 1) + } +} + +impl ColIndex for ColMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFrom) -> Self { + let nrows = this.nrows(); + >::get(this, row.start..nrows) + } +} +impl ColIndex for ColMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeTo) -> Self { + >::get(this, 0..row.end) + } +} + +impl ColIndex for ColMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeToInclusive) -> Self { + assert!(row.end != usize::MAX); + >::get(this, 0..row.end + 1) + } +} + +impl<'a, E: Entity> ColIndex for ColMut<'a, E> { + type Target = GroupFor; + + #[track_caller] + #[inline(always)] + unsafe fn get_unchecked(this: Self, row: usize) -> Self::Target { + unsafe { E::faer_map(this.ptr_inbounds_at_mut(row), |ptr: *mut _| &mut *ptr) } + } + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: usize) -> Self::Target { + assert!(row < this.nrows()); + unsafe { >::get_unchecked(this, row) } + } +} diff --git a/src/col/colmut.rs b/src/col/colmut.rs new file mode 100644 index 0000000000000000000000000000000000000000..d21eefbd81917a2a4daccfae12d946734663c9ef --- /dev/null +++ b/src/col/colmut.rs @@ -0,0 +1,597 @@ +use crate::{ + diag::DiagMut, + mat::{self, As2D, As2DMut, Mat, MatMut, MatRef}, + row::RowMut, + unzipped, zipped, +}; + +use super::*; + +/// Mutable view over a column vector, similar to a mutable reference to a strided [prim@slice]. +/// +/// # Note +/// +/// Unlike a slice, the data pointed to by `ColMut<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions. In this case, care must be taken to not perform any +/// operations that read the uninitialized values, or form references to them, either directly +/// through [`ColMut::read`], or indirectly through any of the numerical library routines, unless +/// it is explicitly permitted. +/// +/// # Move semantics +/// See [`faer::Mat`](crate::Mat) for information about reborrowing when using this type. +#[repr(C)] +pub struct ColMut<'a, E: Entity> { + pub(super) inner: VecImpl, + pub(super) __marker: PhantomData<&'a E>, +} + +impl<'short, E: Entity> Reborrow<'short> for ColMut<'_, E> { + type Target = ColRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + ColRef { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for ColMut<'_, E> { + type Target = ColMut<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + ColMut { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'a, E: Entity> IntoConst for ColMut<'a, E> { + type Target = ColRef<'a, E>; + + #[inline] + fn into_const(self) -> Self::Target { + ColRef { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'a, E: Entity> ColMut<'a, E> { + #[inline] + pub(crate) unsafe fn __from_raw_parts( + ptr: GroupFor, + nrows: usize, + row_stride: isize, + ) -> Self { + Self { + inner: VecImpl { + ptr: into_copy::(E::faer_map( + ptr, + #[inline] + |ptr| NonNull::new_unchecked(ptr), + )), + len: nrows, + stride: row_stride, + }, + __marker: PhantomData, + } + } + + #[track_caller] + #[inline(always)] + #[doc(hidden)] + pub fn try_get_contiguous_col_mut(self) -> GroupFor { + assert!(self.row_stride() == 1); + let m = self.nrows(); + E::faer_map( + self.as_ptr_mut(), + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, m) }, + ) + } + + /// Returns the number of rows of the column. + #[inline(always)] + pub fn nrows(&self) -> usize { + self.inner.len + } + /// Returns the number of columns of the column. This is always equal to `1`. + #[inline(always)] + pub fn ncols(&self) -> usize { + 1 + } + + /// Returns pointers to the matrix data. + #[inline(always)] + pub fn as_ptr_mut(self) -> GroupFor { + E::faer_map( + from_copy::(self.inner.ptr), + #[inline(always)] + |ptr| ptr.as_ptr() as *mut E::Unit, + ) + } + + /// Returns the row stride of the matrix, specified in number of elements, not in bytes. + #[inline(always)] + pub fn row_stride(&self) -> isize { + self.inner.stride + } + + /// Returns `self` as a mutable matrix view. + #[inline(always)] + pub fn as_2d_mut(self) -> MatMut<'a, E> { + let nrows = self.nrows(); + let row_stride = self.row_stride(); + unsafe { mat::from_raw_parts_mut(self.as_ptr_mut(), nrows, 1, row_stride, isize::MAX) } + } + + /// Returns raw pointers to the element at the given index. + #[inline(always)] + pub fn ptr_at_mut(self, row: usize) -> GroupFor { + let offset = (row as isize).wrapping_mul(self.inner.stride); + + E::faer_map( + self.as_ptr_mut(), + #[inline(always)] + |ptr| ptr.wrapping_offset(offset), + ) + } + + #[inline(always)] + unsafe fn ptr_at_mut_unchecked(self, row: usize) -> GroupFor { + let offset = crate::utils::unchecked_mul(row, self.inner.stride); + E::faer_map( + self.as_ptr_mut(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + + /// Returns raw pointers to the element at the given index, assuming the provided index + /// is within the size of the vector. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn ptr_inbounds_at_mut(self, row: usize) -> GroupFor { + debug_assert!(row < self.nrows()); + self.ptr_at_mut_unchecked(row) + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * top. + /// * bottom. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_mut_unchecked(self, row: usize) -> (Self, Self) { + let (top, bot) = self.into_const().split_at_unchecked(row); + unsafe { (top.const_cast(), bot.const_cast()) } + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * top. + /// * bottom. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn split_at_mut(self, row: usize) -> (Self, Self) { + assert!(row <= self.nrows()); + unsafe { self.split_at_mut_unchecked(row) } + } + + /// Returns references to the element at the given index, or subvector if `row` is a + /// range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_unchecked_mut( + self, + row: RowRange, + ) -> >::Target + where + Self: ColIndex, + { + >::get_unchecked(self, row) + } + + /// Returns references to the element at the given index, or subvector if `row` is a + /// range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline(always)] + #[track_caller] + pub fn get_mut(self, row: RowRange) -> >::Target + where + Self: ColIndex, + { + >::get(self, row) + } + + /// Reads the value of the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, row: usize) -> E { + self.rb().read_unchecked(row) + } + + /// Reads the value of the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, row: usize) -> E { + self.rb().read(row) + } + + /// Writes the value to the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn write_unchecked(&mut self, row: usize, value: E) { + let units = value.faer_into_units(); + let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(row)); + E::faer_map( + zipped, + #[inline(always)] + |(unit, ptr)| *ptr = unit, + ); + } + + /// Writes the value to the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn write(&mut self, row: usize, value: E) { + assert!(row < self.nrows()); + unsafe { self.write_unchecked(row, value) }; + } + + /// Copies the values from `other` into `self`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.nrows() == other.nrows()`. + /// * `self.ncols() == other.ncols()`. + #[track_caller] + pub fn copy_from(&mut self, other: impl AsColRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: ColMut<'_, E>, other: ColRef<'_, E>) { + zipped!(this.as_2d_mut(), other.as_2d()) + .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); + } + implementation(self.rb_mut(), other.as_col_ref()) + } + + /// Fills the elements of `self` with zeros. + #[track_caller] + pub fn fill_zero(&mut self) + where + E: ComplexField, + { + zipped!(self.rb_mut().as_2d_mut()).for_each( + #[inline(always)] + |unzipped!(mut x)| x.write(E::faer_zero()), + ); + } + + /// Fills the elements of `self` with copies of `constant`. + #[track_caller] + pub fn fill(&mut self, constant: E) { + zipped!((*self).rb_mut().as_2d_mut()).for_each( + #[inline(always)] + |unzipped!(mut x)| x.write(constant), + ); + } + + /// Returns a view over the transpose of `self`. + #[inline(always)] + #[must_use] + pub fn transpose_mut(self) -> RowMut<'a, E> { + unsafe { self.into_const().transpose().const_cast() } + } + + /// Returns a view over the conjugate of `self`. + #[inline(always)] + #[must_use] + pub fn conjugate_mut(self) -> ColMut<'a, E::Conj> + where + E: Conjugate, + { + unsafe { self.into_const().conjugate().const_cast() } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline(always)] + pub fn adjoint_mut(self) -> RowMut<'a, E::Conj> + where + E: Conjugate, + { + self.conjugate_mut().transpose_mut() + } + + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. + #[inline(always)] + pub fn canonicalize_mut(self) -> (ColMut<'a, E::Canonical>, Conj) + where + E: Conjugate, + { + let (canon, conj) = self.into_const().canonicalize(); + unsafe { (canon.const_cast(), conj) } + } + + /// Returns a view over the `self`, with the rows in reversed order. + #[inline(always)] + #[must_use] + pub fn reverse_rows_mut(self) -> Self { + unsafe { self.into_const().reverse_rows().const_cast() } + } + + /// Returns a view over the subvector starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + #[track_caller] + #[inline(always)] + pub unsafe fn subrows_mut_unchecked(self, row_start: usize, nrows: usize) -> Self { + self.into_const() + .subrows_unchecked(row_start, nrows) + .const_cast() + } + + /// Returns a view over the subvector starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + #[track_caller] + #[inline(always)] + pub fn subrows_mut(self, row_start: usize, nrows: usize) -> Self { + unsafe { self.into_const().subrows(row_start, nrows).const_cast() } + } + + /// Given a matrix with a single column, returns an object that interprets + /// the column as a diagonal matrix, whoes diagonal elements are values in the column. + #[track_caller] + #[inline(always)] + pub fn column_vector_as_diagonal_mut(self) -> DiagMut<'a, E> { + DiagMut { inner: self } + } + + /// Returns an owning [`Col`] of the data. + #[inline] + pub fn to_owned(&self) -> Col + where + E: Conjugate, + { + (*self).rb().to_owned() + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + (*self).rb().as_2d().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + (*self).rb().as_2d().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + self.rb().as_2d().norm_max() + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + self.rb().as_2d().norm_l2() + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + self.rb().as_2d().sum() + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.as_ref().kron(rhs) + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> ColRef<'_, E> { + (*self).rb() + } + + /// Returns a mutable view over the matrix. + #[inline] + pub fn as_mut(&mut self) -> ColMut<'_, E> { + (*self).rb_mut() + } +} + +/// Creates a `ColMut` from pointers to the column vector data, number of rows, and row stride. +/// +/// # Safety: +/// This function has the same safety requirements as +/// [`mat::from_raw_parts_mut(ptr, nrows, 1, row_stride, 0)`] +#[inline(always)] +pub unsafe fn from_raw_parts_mut<'a, E: Entity>( + ptr: GroupFor, + nrows: usize, + row_stride: isize, +) -> ColMut<'a, E> { + ColMut::__from_raw_parts(ptr, nrows, row_stride) +} + +/// Creates a `ColMut` from slice views over the column vector data, The result has the same +/// number of rows as the length of the input slice. +#[inline(always)] +pub fn from_slice_mut(slice: GroupFor) -> ColMut<'_, E> { + let nrows = SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(); + + unsafe { + from_raw_parts_mut( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_mut_ptr(), + ), + nrows, + 1, + ) + } +} + +impl As2D for &'_ ColMut<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).rb().as_2d() + } +} + +impl As2D for ColMut<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).rb().as_2d() + } +} + +impl As2DMut for &'_ mut ColMut<'_, E> { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (**self).rb_mut().as_2d_mut() + } +} + +impl As2DMut for ColMut<'_, E> { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (*self).rb_mut().as_2d_mut() + } +} + +impl AsColRef for ColMut<'_, E> { + #[inline] + fn as_col_ref(&self) -> ColRef<'_, E> { + (*self).rb() + } +} +impl AsColRef for &'_ ColMut<'_, E> { + #[inline] + fn as_col_ref(&self) -> ColRef<'_, E> { + (**self).rb() + } +} + +impl AsColMut for ColMut<'_, E> { + #[inline] + fn as_col_mut(&mut self) -> ColMut<'_, E> { + (*self).rb_mut() + } +} + +impl AsColMut for &'_ mut ColMut<'_, E> { + #[inline] + fn as_col_mut(&mut self) -> ColMut<'_, E> { + (**self).rb_mut() + } +} + +impl<'a, E: Entity> core::fmt::Debug for ColMut<'a, E> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.rb().fmt(f) + } +} + +impl core::ops::Index for ColMut<'_, E> { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, row: usize) -> &E { + (*self).rb().get(row) + } +} + +impl core::ops::IndexMut for ColMut<'_, E> { + #[inline] + #[track_caller] + fn index_mut(&mut self, row: usize) -> &mut E { + (*self).rb_mut().get_mut(row) + } +} diff --git a/src/col/colown.rs b/src/col/colown.rs new file mode 100644 index 0000000000000000000000000000000000000000..e53b8247525c4ac5b4a28a8458edbccdf2615cea --- /dev/null +++ b/src/col/colown.rs @@ -0,0 +1,634 @@ +use super::*; +use crate::{ + debug_assert, + diag::Diag, + mat::{ + matalloc::{align_for, is_vectorizable, MatUnit, RawMat, RawMatUnit}, + As2D, As2DMut, Mat, MatMut, MatRef, + }, + row::RowRef, + utils::DivCeil, +}; +use core::mem::ManuallyDrop; + +/// Heap allocated resizable column vector. +/// +/// # Note +/// +/// The memory layout of `Col` is guaranteed to be column-major, meaning that it has a row stride +/// of `1`. +#[repr(C)] +pub struct Col { + inner: VecOwnImpl, + row_capacity: usize, + __marker: PhantomData, +} + +impl Col { + /// Returns an empty column of dimension `0`. + #[inline] + pub fn new() -> Self { + Self { + inner: VecOwnImpl { + ptr: into_copy::(E::faer_map(E::UNIT, |()| NonNull::::dangling())), + len: 0, + }, + row_capacity: 0, + __marker: PhantomData, + } + } + + /// Returns a new column vector with 0 rows, with enough capacity to hold a maximum of + /// `row_capacity` rows columns without reallocating. If `row_capacity` is `0`, + /// the function will not allocate. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn with_capacity(row_capacity: usize) -> Self { + let raw = ManuallyDrop::new(RawMat::::new(row_capacity, 1)); + Self { + inner: VecOwnImpl { + ptr: raw.ptr, + len: 0, + }, + row_capacity: raw.row_capacity, + __marker: PhantomData, + } + } + + /// Returns a new matrix with number of rows `nrows`, filled with the provided function. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn from_fn(nrows: usize, f: impl FnMut(usize) -> E) -> Self { + let mut this = Self::new(); + this.resize_with(nrows, f); + this + } + + /// Returns a new matrix with number of rows `nrows`, filled with zeros. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn zeros(nrows: usize) -> Self + where + E: ComplexField, + { + Self::from_fn(nrows, |_| E::faer_zero()) + } + + /// Returns the number of rows of the column. + #[inline(always)] + pub fn nrows(&self) -> usize { + self.inner.len + } + /// Returns the number of columns of the column. This is always equal to `1`. + #[inline(always)] + pub fn ncols(&self) -> usize { + 1 + } + + /// Set the dimensions of the matrix. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `nrows < self.row_capacity()`. + /// * The elements that were previously out of bounds but are now in bounds must be + /// initialized. + #[inline] + pub unsafe fn set_nrows(&mut self, nrows: usize) { + self.inner.len = nrows; + } + + /// Returns a pointer to the data of the matrix. + #[inline] + pub fn as_ptr(&self) -> GroupFor { + E::faer_map(from_copy::(self.inner.ptr), |ptr| { + ptr.as_ptr() as *const E::Unit + }) + } + + /// Returns a mutable pointer to the data of the matrix. + #[inline] + pub fn as_ptr_mut(&mut self) -> GroupFor { + E::faer_map(from_copy::(self.inner.ptr), |ptr| ptr.as_ptr()) + } + + /// Returns the row capacity, that is, the number of rows that the matrix is able to hold + /// without needing to reallocate, excluding column insertions. + #[inline] + pub fn row_capacity(&self) -> usize { + self.row_capacity + } + + /// Returns the offset between the first elements of two successive rows in the matrix. + /// Always returns `1` since the matrix is column major. + #[inline] + pub fn row_stride(&self) -> isize { + 1 + } + + #[cold] + fn do_reserve_exact(&mut self, mut new_row_capacity: usize) { + if is_vectorizable::() { + let align_factor = align_for::() / core::mem::size_of::(); + new_row_capacity = new_row_capacity + .msrv_checked_next_multiple_of(align_factor) + .unwrap(); + } + + let nrows = self.inner.len; + let old_row_capacity = self.row_capacity; + + let mut this = ManuallyDrop::new(core::mem::take(self)); + { + let mut this_group = E::faer_map(from_copy::(this.inner.ptr), |ptr| MatUnit { + raw: RawMatUnit { + ptr, + row_capacity: old_row_capacity, + col_capacity: 1, + }, + nrows, + ncols: 1, + }); + + E::faer_map(E::faer_as_mut(&mut this_group), |mat_unit| { + mat_unit.do_reserve_exact(new_row_capacity, 1); + }); + + let this_group = E::faer_map(this_group, ManuallyDrop::new); + this.inner.ptr = + into_copy::(E::faer_map(this_group, |mat_unit| mat_unit.raw.ptr)); + this.row_capacity = new_row_capacity; + } + *self = ManuallyDrop::into_inner(this); + } + + /// Reserves the minimum capacity for `row_capacity` rows without reallocating. Does nothing if + /// the capacity is already sufficient. + /// + /// # Panics + /// The function panics if the new total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn reserve_exact(&mut self, row_capacity: usize) { + if self.row_capacity() >= row_capacity { + // do nothing + } else if core::mem::size_of::() == 0 { + self.row_capacity = self.row_capacity().max(row_capacity); + } else { + self.do_reserve_exact(row_capacity); + } + } + + unsafe fn insert_block_with E>( + &mut self, + f: &mut F, + row_start: usize, + row_end: usize, + ) { + debug_assert!(row_start <= row_end); + + let ptr = self.as_ptr_mut(); + + for i in row_start..row_end { + // SAFETY: + // * pointer to element at index (i, j), which is within the + // allocation since we reserved enough space + // * writing to this memory region is sound since it is properly + // aligned and valid for writes + let ptr_ij = E::faer_map(E::faer_copy(&ptr), |ptr| ptr.add(i)); + let value = E::faer_into_units(f(i)); + + E::faer_map(E::faer_zip(ptr_ij, value), |(ptr_ij, value)| { + core::ptr::write(ptr_ij, value) + }); + } + } + + fn erase_last_rows(&mut self, new_nrows: usize) { + let old_nrows = self.nrows(); + debug_assert!(new_nrows <= old_nrows); + self.inner.len = new_nrows; + } + + unsafe fn insert_last_rows_with E>(&mut self, f: &mut F, new_nrows: usize) { + let old_nrows = self.nrows(); + + debug_assert!(new_nrows > old_nrows); + + self.insert_block_with(f, old_nrows, new_nrows); + self.inner.len = new_nrows; + } + + /// Resizes the vector in-place so that the new number of rows is `new_nrows`. + /// New elements are created with the given function `f`, so that elements at index `i` + /// are created by calling `f(i)`. + pub fn resize_with(&mut self, new_nrows: usize, f: impl FnMut(usize) -> E) { + let mut f = f; + let old_nrows = self.nrows(); + + if new_nrows <= old_nrows { + self.erase_last_rows(new_nrows); + } else { + self.reserve_exact(new_nrows); + unsafe { + self.insert_last_rows_with(&mut f, new_nrows); + } + } + } + + /// Returns a reference to a slice over the column. + #[inline] + #[track_caller] + pub fn as_slice(&self) -> GroupFor { + let nrows = self.nrows(); + let ptr = self.as_ref().as_ptr(); + E::faer_map( + ptr, + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts(ptr, nrows) }, + ) + } + + /// Returns a mutable reference to a slice over the column. + #[inline] + #[track_caller] + pub fn as_slice_mut(&mut self) -> GroupFor { + let nrows = self.nrows(); + let ptr = self.as_ptr_mut(); + E::faer_map( + ptr, + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, nrows) }, + ) + } + + /// Returns a view over the vector. + #[inline] + pub fn as_ref(&self) -> ColRef<'_, E> { + unsafe { super::from_raw_parts(self.as_ptr(), self.nrows(), 1) } + } + + /// Returns a mutable view over the vector. + #[inline] + pub fn as_mut(&mut self) -> ColMut<'_, E> { + unsafe { super::from_raw_parts_mut(self.as_ptr_mut(), self.nrows(), 1) } + } + + /// Returns references to the element at the given index, or submatrices if `row` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline] + pub unsafe fn get_unchecked( + &self, + row: RowRange, + ) -> as ColIndex>::Target + where + for<'a> ColRef<'a, E>: ColIndex, + { + self.as_ref().get_unchecked(row) + } + + /// Returns references to the element at the given index, or submatrices if `row` is a range, + /// with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline] + pub fn get(&self, row: RowRange) -> as ColIndex>::Target + where + for<'a> ColRef<'a, E>: ColIndex, + { + self.as_ref().get(row) + } + + /// Returns mutable references to the element at the given index, or submatrices if + /// `row` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline] + pub unsafe fn get_mut_unchecked( + &mut self, + row: RowRange, + ) -> as ColIndex>::Target + where + for<'a> ColMut<'a, E>: ColIndex, + { + self.as_mut().get_unchecked_mut(row) + } + + /// Returns mutable references to the element at the given index, or submatrices if + /// `row` is a range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline] + pub fn get_mut( + &mut self, + row: RowRange, + ) -> as ColIndex>::Target + where + for<'a> ColMut<'a, E>: ColIndex, + { + self.as_mut().get_mut(row) + } + + /// Reads the value of the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, row: usize) -> E { + self.as_ref().read_unchecked(row) + } + + /// Reads the value of the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, row: usize) -> E { + self.as_ref().read(row) + } + + /// Writes the value to the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn write_unchecked(&mut self, row: usize, value: E) { + self.as_mut().write_unchecked(row, value); + } + + /// Writes the value to the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn write(&mut self, row: usize, value: E) { + self.as_mut().write(row, value); + } + + /// Copies the values from `other` into `self`. + #[inline(always)] + #[track_caller] + pub fn copy_from(&mut self, other: impl AsColRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: &mut Col, other: ColRef<'_, E>) { + let mut mat = Col::::new(); + mat.resize_with( + other.nrows(), + #[inline(always)] + |row| unsafe { other.read_unchecked(row) }, + ); + *this = mat; + } + implementation(self, other.as_col_ref()); + } + + /// Fills the elements of `self` with zeros. + #[inline(always)] + #[track_caller] + pub fn fill_zero(&mut self) + where + E: ComplexField, + { + self.as_mut().fill_zero() + } + + /// Fills the elements of `self` with copies of `constant`. + #[inline(always)] + #[track_caller] + pub fn fill(&mut self, constant: E) { + self.as_mut().fill(constant) + } + + /// Returns a view over the transpose of `self`. + #[inline] + pub fn transpose(&self) -> RowRef<'_, E> { + self.as_ref().transpose() + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn conjugate(&self) -> ColRef<'_, E::Conj> + where + E: Conjugate, + { + self.as_ref().conjugate() + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn adjoint(&self) -> RowRef<'_, E::Conj> + where + E: Conjugate, + { + self.as_ref().adjoint() + } + + /// Given a matrix with a single column, returns an object that interprets + /// the column as a diagonal matrix, whoes diagonal elements are values in the column. + #[track_caller] + #[inline(always)] + pub fn column_vector_into_diagonal(self) -> Diag { + Diag { inner: self } + } + + /// Returns an owning [`Col`] of the data + #[inline] + pub fn to_owned(&self) -> Col + where + E: Conjugate, + { + self.as_ref().to_owned() + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + self.as_ref().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + self.as_ref().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + self.as_ref().as_2d().norm_max() + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + self.as_ref().as_2d().norm_l2() + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + self.as_ref().as_2d().sum() + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.as_2d_ref().kron(rhs) + } +} + +impl Default for Col { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Clone for Col { + fn clone(&self) -> Self { + let this = self.as_ref(); + unsafe { + Self::from_fn(self.nrows(), |i| { + E::faer_from_units(E::faer_deref(this.get_unchecked(i))) + }) + } + } +} + +impl As2D for &'_ Col { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).as_ref().as_2d() + } +} + +impl As2D for Col { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).as_ref().as_2d() + } +} + +impl As2DMut for &'_ mut Col { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (**self).as_mut().as_2d_mut() + } +} + +impl As2DMut for Col { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (*self).as_mut().as_2d_mut() + } +} + +impl AsColRef for Col { + #[inline] + fn as_col_ref(&self) -> ColRef<'_, E> { + (*self).as_ref() + } +} +impl AsColRef for &'_ Col { + #[inline] + fn as_col_ref(&self) -> ColRef<'_, E> { + (**self).as_ref() + } +} + +impl AsColMut for Col { + #[inline] + fn as_col_mut(&mut self) -> ColMut<'_, E> { + (*self).as_mut() + } +} + +impl AsColMut for &'_ mut Col { + #[inline] + fn as_col_mut(&mut self) -> ColMut<'_, E> { + (**self).as_mut() + } +} + +impl core::fmt::Debug for Col { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_ref().fmt(f) + } +} + +impl core::ops::Index for Col { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, row: usize) -> &E { + self.as_ref().get(row) + } +} + +impl core::ops::IndexMut for Col { + #[inline] + #[track_caller] + fn index_mut(&mut self, row: usize) -> &mut E { + self.as_mut().get_mut(row) + } +} diff --git a/src/col/colref.rs b/src/col/colref.rs new file mode 100644 index 0000000000000000000000000000000000000000..15ed6e50b2dbf09ed0436485bae381bd48f03072 --- /dev/null +++ b/src/col/colref.rs @@ -0,0 +1,572 @@ +use super::*; +use crate::{ + assert, debug_assert, + diag::DiagRef, + mat::{As2D, Mat, MatRef}, + row::RowRef, +}; + +/// Immutable view over a column vector, similar to an immutable reference to a strided +/// [prim@slice]. +/// +/// # Note +/// +/// Unlike a slice, the data pointed to by `ColRef<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions. In this case, care must be taken to not perform any +/// operations that read the uninitialized values, or form references to them, either directly +/// through [`ColRef::read`], or indirectly through any of the numerical library routines, unless +/// it is explicitly permitted. +#[repr(C)] +pub struct ColRef<'a, E: Entity> { + pub(super) inner: VecImpl, + pub(super) __marker: PhantomData<&'a E>, +} + +impl Clone for ColRef<'_, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl Copy for ColRef<'_, E> {} + +impl<'short, E: Entity> Reborrow<'short> for ColRef<'_, E> { + type Target = ColRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for ColRef<'_, E> { + type Target = ColRef<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl IntoConst for ColRef<'_, E> { + type Target = Self; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, E: Entity> ColRef<'a, E> { + #[inline] + pub(crate) unsafe fn __from_raw_parts( + ptr: GroupFor, + nrows: usize, + row_stride: isize, + ) -> Self { + Self { + inner: VecImpl { + ptr: into_copy::(E::faer_map( + ptr, + #[inline] + |ptr| NonNull::new_unchecked(ptr as *mut E::Unit), + )), + len: nrows, + stride: row_stride, + }, + __marker: PhantomData, + } + } + + #[track_caller] + #[inline(always)] + #[doc(hidden)] + pub fn try_get_contiguous_col(self) -> GroupFor { + assert!(self.row_stride() == 1); + let m = self.nrows(); + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts(ptr, m) }, + ) + } + + /// Returns the number of rows of the column. + #[inline(always)] + pub fn nrows(&self) -> usize { + self.inner.len + } + /// Returns the number of columns of the column. This is always equal to `1`. + #[inline(always)] + pub fn ncols(&self) -> usize { + 1 + } + + /// Returns pointers to the matrix data. + #[inline(always)] + pub fn as_ptr(self) -> GroupFor { + E::faer_map( + from_copy::(self.inner.ptr), + #[inline(always)] + |ptr| ptr.as_ptr() as *const E::Unit, + ) + } + + /// Returns the row stride of the matrix, specified in number of elements, not in bytes. + #[inline(always)] + pub fn row_stride(&self) -> isize { + self.inner.stride + } + + /// Returns `self` as a matrix view. + #[inline(always)] + pub fn as_2d(self) -> MatRef<'a, E> { + let nrows = self.nrows(); + let row_stride = self.row_stride(); + unsafe { crate::mat::from_raw_parts(self.as_ptr(), nrows, 1, row_stride, isize::MAX) } + } + + /// Returns raw pointers to the element at the given index. + #[inline(always)] + pub fn ptr_at(self, row: usize) -> GroupFor { + let offset = (row as isize).wrapping_mul(self.inner.stride); + + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.wrapping_offset(offset), + ) + } + + #[inline(always)] + unsafe fn unchecked_ptr_at(self, row: usize) -> GroupFor { + let offset = crate::utils::unchecked_mul(row, self.inner.stride); + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + + #[inline(always)] + unsafe fn overflowing_ptr_at(self, row: usize) -> GroupFor { + unsafe { + let cond = row != self.nrows(); + let offset = (cond as usize).wrapping_neg() as isize + & (row as isize).wrapping_mul(self.inner.stride); + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + } + + /// Returns raw pointers to the element at the given index, assuming the provided index + /// is within the size of the vector. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn ptr_inbounds_at(self, row: usize) -> GroupFor { + debug_assert!(row < self.nrows()); + self.unchecked_ptr_at(row) + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * top. + /// * bottom. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_unchecked(self, row: usize) -> (Self, Self) { + debug_assert!(row <= self.nrows()); + + let row_stride = self.row_stride(); + + let nrows = self.nrows(); + + unsafe { + let top = self.as_ptr(); + let bot = self.overflowing_ptr_at(row); + + ( + Self::__from_raw_parts(top, row, row_stride), + Self::__from_raw_parts(bot, nrows - row, row_stride), + ) + } + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * top. + /// * bottom. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at(self, row: usize) -> (Self, Self) { + assert!(row <= self.nrows()); + unsafe { self.split_at_unchecked(row) } + } + + /// Returns references to the element at the given index, or subvector if `row` is a + /// range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_unchecked( + self, + row: RowRange, + ) -> >::Target + where + Self: ColIndex, + { + >::get_unchecked(self, row) + } + + /// Returns references to the element at the given index, or subvector if `row` is a + /// range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + #[inline(always)] + #[track_caller] + pub fn get(self, row: RowRange) -> >::Target + where + Self: ColIndex, + { + >::get(self, row) + } + + /// Reads the value of the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, row: usize) -> E { + E::faer_from_units(E::faer_map( + self.get_unchecked(row), + #[inline(always)] + |ptr| *ptr, + )) + } + + /// Reads the value of the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, row: usize) -> E { + E::faer_from_units(E::faer_map( + self.get(row), + #[inline(always)] + |ptr| *ptr, + )) + } + + /// Returns a view over the transpose of `self`. + #[inline(always)] + #[must_use] + pub fn transpose(self) -> RowRef<'a, E> { + unsafe { crate::row::from_raw_parts(self.as_ptr(), self.nrows(), self.row_stride()) } + } + + /// Returns a view over the conjugate of `self`. + #[inline(always)] + #[must_use] + pub fn conjugate(self) -> ColRef<'a, E::Conj> + where + E: Conjugate, + { + unsafe { + // SAFETY: Conjugate requires that E::Unit and E::Conj::Unit have the same layout + // and that GroupCopyFor == E::Conj::GroupCopy + super::from_raw_parts::<'_, E::Conj>( + transmute_unchecked::< + GroupFor>, + GroupFor>, + >(self.as_ptr()), + self.nrows(), + self.row_stride(), + ) + } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline(always)] + pub fn adjoint(self) -> RowRef<'a, E::Conj> + where + E: Conjugate, + { + self.conjugate().transpose() + } + + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. + #[inline(always)] + pub fn canonicalize(self) -> (ColRef<'a, E::Canonical>, Conj) + where + E: Conjugate, + { + ( + unsafe { + // SAFETY: see Self::conjugate + super::from_raw_parts::<'_, E::Canonical>( + transmute_unchecked::< + GroupFor, + GroupFor>, + >(self.as_ptr()), + self.nrows(), + self.row_stride(), + ) + }, + if coe::is_same::() { + Conj::No + } else { + Conj::Yes + }, + ) + } + + /// Returns a view over the `self`, with the rows in reversed order. + #[inline(always)] + #[must_use] + pub fn reverse_rows(self) -> Self { + let nrows = self.nrows(); + let row_stride = self.row_stride().wrapping_neg(); + + let ptr = unsafe { self.unchecked_ptr_at(nrows.saturating_sub(1)) }; + unsafe { Self::__from_raw_parts(ptr, nrows, row_stride) } + } + + /// Returns a view over the subvector starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + #[track_caller] + #[inline(always)] + pub unsafe fn subrows_unchecked(self, row_start: usize, nrows: usize) -> Self { + debug_assert!(all( + row_start <= self.nrows(), + nrows <= self.nrows() - row_start + )); + let row_stride = self.row_stride(); + unsafe { Self::__from_raw_parts(self.overflowing_ptr_at(row_start), nrows, row_stride) } + } + + /// Returns a view over the subvector starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + #[track_caller] + #[inline(always)] + pub fn subrows(self, row_start: usize, nrows: usize) -> Self { + assert!(all( + row_start <= self.nrows(), + nrows <= self.nrows() - row_start + )); + unsafe { self.subrows_unchecked(row_start, nrows) } + } + + /// Given a matrix with a single column, returns an object that interprets + /// the column as a diagonal matrix, whoes diagonal elements are values in the column. + #[track_caller] + #[inline(always)] + pub fn column_vector_as_diagonal(self) -> DiagRef<'a, E> { + DiagRef { inner: self } + } + + /// Returns an owning [`Col`] of the data. + #[inline] + pub fn to_owned(&self) -> Col + where + E: Conjugate, + { + let mut mat = Col::new(); + mat.resize_with( + self.nrows(), + #[inline(always)] + |row| unsafe { self.read_unchecked(row).canonicalize() }, + ); + mat + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + (*self).as_2d().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + (*self).rb().as_2d().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + self.as_2d().norm_max() + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + self.as_2d().norm_l2() + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + self.as_2d().sum() + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.as_2d_ref().kron(rhs) + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> ColRef<'_, E> { + *self + } + + #[doc(hidden)] + #[inline(always)] + pub unsafe fn const_cast(self) -> ColMut<'a, E> { + ColMut { + inner: self.inner, + __marker: PhantomData, + } + } +} + +/// Creates a `ColRef` from pointers to the column vector data, number of rows, and row stride. +/// +/// # Safety: +/// This function has the same safety requirements as +/// [`mat::from_raw_parts(ptr, nrows, 1, row_stride, 0)`] +#[inline(always)] +pub unsafe fn from_raw_parts<'a, E: Entity>( + ptr: GroupFor, + nrows: usize, + row_stride: isize, +) -> ColRef<'a, E> { + ColRef::__from_raw_parts(ptr, nrows, row_stride) +} + +/// Creates a `ColRef` from slice views over the column vector data, The result has the same +/// number of rows as the length of the input slice. +#[inline(always)] +pub fn from_slice(slice: GroupFor) -> ColRef<'_, E> { + let nrows = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(); + + unsafe { + from_raw_parts( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_ptr(), + ), + nrows, + 1, + ) + } +} +impl As2D for &'_ ColRef<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).as_2d() + } +} + +impl As2D for ColRef<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).as_2d() + } +} + +impl AsColRef for ColRef<'_, E> { + #[inline] + fn as_col_ref(&self) -> ColRef<'_, E> { + *self + } +} +impl AsColRef for &'_ ColRef<'_, E> { + #[inline] + fn as_col_ref(&self) -> ColRef<'_, E> { + **self + } +} + +impl<'a, E: Entity> core::fmt::Debug for ColRef<'a, E> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_2d().fmt(f) + } +} + +impl core::ops::Index for ColRef<'_, E> { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, row: usize) -> &E { + self.get(row) + } +} diff --git a/src/col/mod.rs b/src/col/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..a3f76ba0cdc3fc4bb6a2c36aac7d2166044f6c48 --- /dev/null +++ b/src/col/mod.rs @@ -0,0 +1,65 @@ +use crate::{utils::slice::*, Conj}; +use core::{marker::PhantomData, ptr::NonNull}; +use faer_entity::*; +use reborrow::*; + +#[repr(C)] +pub(crate) struct VecImpl { + pub(crate) ptr: GroupCopyFor>, + pub(crate) len: usize, + pub(crate) stride: isize, +} +#[repr(C)] +pub(crate) struct VecOwnImpl { + pub(crate) ptr: GroupCopyFor>, + pub(crate) len: usize, +} + +impl Copy for VecImpl {} +impl Clone for VecImpl { + #[inline(always)] + fn clone(&self) -> Self { + *self + } +} + +unsafe impl Sync for VecImpl {} +unsafe impl Send for VecImpl {} +unsafe impl Sync for VecOwnImpl {} +unsafe impl Send for VecOwnImpl {} + +/// Represents a type that can be used to slice a column, such as an index or a range of indices. +pub trait ColIndex: crate::seal::Seal + Sized { + /// Resulting type of the indexing operation. + type Target; + + /// Index the column at `row`, without bound checks. + #[allow(clippy::missing_safety_doc)] + unsafe fn get_unchecked(this: Self, row: RowRange) -> Self::Target { + >::get(this, row) + } + /// Index the column at `row`. + fn get(this: Self, row: RowRange) -> Self::Target; +} + +/// Trait for types that can be converted to a column view. +pub trait AsColRef { + /// Convert to a column view. + fn as_col_ref(&self) -> ColRef<'_, E>; +} +/// Trait for types that can be converted to a mutable column view. +pub trait AsColMut { + /// Convert to a mutable column view. + fn as_col_mut(&mut self) -> ColMut<'_, E>; +} + +mod col_index; + +mod colref; +pub use colref::{from_raw_parts, from_slice, ColRef}; + +mod colmut; +pub use colmut::{from_raw_parts_mut, from_slice_mut, ColMut}; + +mod colown; +pub use colown::Col; diff --git a/faer-libs/faer-core/src/complex_native/c32conj_impl.rs b/src/complex_native/c32_conj_impl.rs similarity index 91% rename from faer-libs/faer-core/src/complex_native/c32conj_impl.rs rename to src/complex_native/c32_conj_impl.rs index 9c04f2424fbb19c9e4275395cf33e991ac205eb1..ed2e78ca4bdba11a95efb4ef4d2dbfdd8e624f02 100644 --- a/faer-libs/faer-core/src/complex_native/c32conj_impl.rs +++ b/src/complex_native/c32_conj_impl.rs @@ -1,18 +1,7 @@ -use crate::complex_native::c32_impl::c32; +use super::{c32, c32conj}; use faer_entity::*; use pulp::Simd; -/// 32-bit implicitly conjugated complex floating point type. -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c32conj { - /// Real part. - pub re: f32, - /// Imaginary part. - pub neg_im: f32, -} - unsafe impl bytemuck::Zeroable for c32conj {} unsafe impl bytemuck::Pod for c32conj {} diff --git a/faer-libs/faer-core/src/complex_native/c32_impl.rs b/src/complex_native/c32_impl.rs similarity index 98% rename from faer-libs/faer-core/src/complex_native/c32_impl.rs rename to src/complex_native/c32_impl.rs index 153e4ecbb917662134c4ddfb8299d1ef082342e3..bb646af367d65f6c6d30bb91721cdaacdfd2c84a 100644 --- a/faer-libs/faer-core/src/complex_native/c32_impl.rs +++ b/src/complex_native/c32_impl.rs @@ -1,4 +1,5 @@ -use crate::complex_native::c32conj_impl::c32conj; +use super::{c32, c32conj}; + use faer_entity::*; #[cfg(not(feature = "std"))] use num_traits::float::FloatCore; @@ -22,17 +23,6 @@ macro_rules! impl_from_num_complex { }; } -/// 32-bit complex floating point type. See the module-level documentation for more details. -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c32 { - /// Real part. - pub re: f32, - /// Imaginary part. - pub im: f32, -} - impl c32 { /// Create a new complex number. #[inline(always)] diff --git a/faer-libs/faer-core/src/complex_native/c64conj_impl.rs b/src/complex_native/c64_conj_impl.rs similarity index 91% rename from faer-libs/faer-core/src/complex_native/c64conj_impl.rs rename to src/complex_native/c64_conj_impl.rs index 2e8936f74f59a75601a1eba8128e87bc2e328590..fd794c7d23853a71b763aebc20bc64e8923404b1 100644 --- a/faer-libs/faer-core/src/complex_native/c64conj_impl.rs +++ b/src/complex_native/c64_conj_impl.rs @@ -1,18 +1,8 @@ -use crate::complex_native::c64_impl::c64; +use super::{c64, c64conj}; + use faer_entity::*; use pulp::Simd; -/// 64-bit implicitly conjugated complex floating point type. -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c64conj { - /// Real part. - pub re: f64, - /// Imaginary part. - pub neg_im: f64, -} - unsafe impl bytemuck::Pod for c64conj {} unsafe impl bytemuck::Zeroable for c64conj {} diff --git a/faer-libs/faer-core/src/complex_native/c64_impl.rs b/src/complex_native/c64_impl.rs similarity index 98% rename from faer-libs/faer-core/src/complex_native/c64_impl.rs rename to src/complex_native/c64_impl.rs index 988ab258e08ce375825294a99786c9bf0978756d..5cf203e516364a184c1aff409c1dc9d91be27480 100644 --- a/faer-libs/faer-core/src/complex_native/c64_impl.rs +++ b/src/complex_native/c64_impl.rs @@ -1,4 +1,4 @@ -use crate::complex_native::c64conj_impl::c64conj; +use super::{c64, c64conj}; use faer_entity::*; #[cfg(not(feature = "std"))] use num_traits::float::FloatCore; @@ -21,17 +21,6 @@ macro_rules! impl_from_num_complex { }; } -/// 64-bit complex floating point type. See the module-level documentation for more details. -#[allow(non_camel_case_types)] -#[derive(Copy, Clone, PartialEq)] -#[repr(C)] -pub struct c64 { - /// Real part. - pub re: f64, - /// Imaginary part. - pub im: f64, -} - impl c64 { /// Create a new complex number. #[inline(always)] diff --git a/faer-libs/faer-core/src/complex_native/mod.rs b/src/complex_native/mod.rs similarity index 70% rename from faer-libs/faer-core/src/complex_native/mod.rs rename to src/complex_native/mod.rs index 3d81fb0beae72806faa3bebc934a834714cd104e..0e2c8054ecba4579806c126868f08ff20d48f0bd 100644 --- a/faer-libs/faer-core/src/complex_native/mod.rs +++ b/src/complex_native/mod.rs @@ -40,12 +40,47 @@ //! └──────────┘ └──────────┘ //! ``` +mod c32_conj_impl; mod c32_impl; -mod c32conj_impl; +mod c64_conj_impl; mod c64_impl; -mod c64conj_impl; -pub use c32_impl::c32; -pub use c32conj_impl::c32conj; -pub use c64_impl::c64; -pub use c64conj_impl::c64conj; +/// 32-bit complex floating point type. See the module-level documentation for more details. +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, PartialEq)] +pub struct c32 { + /// Real part. + pub re: f32, + /// Negated imaginary part. + pub im: f32, +} + +/// 64-bit complex floating point type. See the module-level documentation for more details. +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, PartialEq)] +pub struct c64 { + /// Real part. + pub re: f64, + /// Negated imaginary part. + pub im: f64, +} + +/// 32-bit implicitly conjugated complex floating point type. +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, PartialEq)] +pub struct c32conj { + /// Real part. + pub re: f32, + /// Negated imaginary part. + pub neg_im: f32, +} + +/// 64-bit implicitly conjugated complex floating point type. +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, PartialEq)] +pub struct c64conj { + /// Real part. + pub re: f64, + /// Negated imaginary part. + pub neg_im: f64, +} diff --git a/src/diag/diagmut.rs b/src/diag/diagmut.rs new file mode 100644 index 0000000000000000000000000000000000000000..cd7d59e581df46b4d9ffb5475de38c83428f78dd --- /dev/null +++ b/src/diag/diagmut.rs @@ -0,0 +1,60 @@ +use super::*; +use crate::col::ColMut; + +/// Diagonal mutable matrix view. +pub struct DiagMut<'a, E: Entity> { + pub(crate) inner: ColMut<'a, E>, +} + +impl<'a, E: Entity> DiagMut<'a, E> { + /// Returns the diagonal as a mutable column vector view. + #[inline(always)] + pub fn column_vector_mut(self) -> ColMut<'a, E> { + self.inner + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> DiagRef<'_, E> { + self.rb() + } + + /// Returns a mutable view over the matrix. + #[inline] + pub fn as_mut(&mut self) -> DiagMut<'_, E> { + self.rb_mut() + } +} + +impl<'short, E: Entity> Reborrow<'short> for DiagMut<'_, E> { + type Target = DiagRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + DiagRef { + inner: self.inner.rb(), + } + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for DiagMut<'_, E> { + type Target = DiagMut<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + DiagMut { + inner: self.inner.rb_mut(), + } + } +} + +impl<'a, E: Entity> IntoConst for DiagMut<'a, E> { + type Target = DiagRef<'a, E>; + + #[inline] + fn into_const(self) -> Self::Target { + DiagRef { + inner: self.inner.into_const(), + } + } +} diff --git a/src/diag/diagown.rs b/src/diag/diagown.rs new file mode 100644 index 0000000000000000000000000000000000000000..a85fc55a1b0a0e608fad7ab339107a3f7ff97a03 --- /dev/null +++ b/src/diag/diagown.rs @@ -0,0 +1,32 @@ +use super::*; +use crate::col::Col; +use faer_entity::Entity; + +/// Diagonal matrix. +pub struct Diag { + pub(crate) inner: Col, +} + +impl Diag { + /// Returns the diagonal as a column vector. + #[inline(always)] + pub fn into_column_vector(self) -> Col { + self.inner + } + + /// Returns a view over `self`. + #[inline(always)] + pub fn as_ref(&self) -> DiagRef<'_, E> { + DiagRef { + inner: self.inner.as_ref(), + } + } + + /// Returns a mutable view over `self`. + #[inline(always)] + pub fn as_mut(&mut self) -> DiagMut<'_, E> { + DiagMut { + inner: self.inner.as_mut(), + } + } +} diff --git a/src/diag/diagref.rs b/src/diag/diagref.rs new file mode 100644 index 0000000000000000000000000000000000000000..9978f9d3dd47aa9ba54671c7a4c4225c4ecd7b34 --- /dev/null +++ b/src/diag/diagref.rs @@ -0,0 +1,57 @@ +use super::*; +use crate::col::ColRef; + +/// Diagonal matrix view. +pub struct DiagRef<'a, E: Entity> { + pub(crate) inner: ColRef<'a, E>, +} + +impl<'a, E: Entity> DiagRef<'a, E> { + /// Returns the diagonal as a column vector view. + #[inline(always)] + pub fn column_vector(self) -> ColRef<'a, E> { + self.inner + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> DiagRef<'_, E> { + *self + } +} + +impl Clone for DiagRef<'_, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl Copy for DiagRef<'_, E> {} + +impl<'short, E: Entity> Reborrow<'short> for DiagRef<'_, E> { + type Target = DiagRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for DiagRef<'_, E> { + type Target = DiagRef<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl IntoConst for DiagRef<'_, E> { + type Target = Self; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} diff --git a/src/diag/mod.rs b/src/diag/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..cb2efe4f2960e97167cb7d5c44ed9ea9fcc43264 --- /dev/null +++ b/src/diag/mod.rs @@ -0,0 +1,11 @@ +use faer_entity::*; +use reborrow::*; + +mod diagref; +pub use diagref::DiagRef; + +mod diagmut; +pub use diagmut::DiagMut; + +mod diagown; +pub use diagown::Diag; diff --git a/src/io.rs b/src/io.rs new file mode 100644 index 0000000000000000000000000000000000000000..df6e874332043f89ff86f90ef99ce616292cdeca --- /dev/null +++ b/src/io.rs @@ -0,0 +1,192 @@ +#[allow(unused_imports)] +use super::*; +#[allow(unused_imports)] +use crate::assert; +#[allow(unused_imports)] +use complex_native::{c32, c64}; + +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +pub struct Npy<'a> { + aligned_bytes: &'a [u8], + nrows: usize, + ncols: usize, + prefix_len: usize, + dtype: NpyDType, + fortran_order: bool, +} + +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum NpyDType { + F32, + F64, + C32, + C64, + Other, +} + +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +pub trait FromNpy: faer_entity::SimpleEntity { + const DTYPE: NpyDType; +} + +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +impl FromNpy for f32 { + const DTYPE: NpyDType = NpyDType::F32; +} +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +impl FromNpy for f64 { + const DTYPE: NpyDType = NpyDType::F64; +} +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +impl FromNpy for c32 { + const DTYPE: NpyDType = NpyDType::C32; +} +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +impl FromNpy for c64 { + const DTYPE: NpyDType = NpyDType::C64; +} + +#[cfg(feature = "npy")] +#[cfg_attr(docsrs, doc(cfg(feature = "npy")))] +impl<'a> Npy<'a> { + fn parse_npyz( + data: &[u8], + npyz: npyz::NpyFile<&[u8]>, + ) -> Result<(NpyDType, usize, usize, usize, bool), std::io::Error> { + let ver_major = data[6] - b'\x00'; + let length = if ver_major <= 1 { + 2usize + } else if ver_major <= 3 { + 4usize + } else { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "unsupported version", + )); + }; + let header_len = if length == 2 { + u16::from_le_bytes(data[8..10].try_into().unwrap()) as usize + } else { + u32::from_le_bytes(data[8..12].try_into().unwrap()) as usize + }; + let dtype = || -> NpyDType { + match npyz.dtype() { + npyz::DType::Plain(str) => { + let is_complex = match str.type_char() { + npyz::TypeChar::Float => false, + npyz::TypeChar::Complex => true, + _ => return NpyDType::Other, + }; + + let byte_size = str.size_field(); + if byte_size == 8 && is_complex { + NpyDType::C32 + } else if byte_size == 16 && is_complex { + NpyDType::C64 + } else if byte_size == 4 && !is_complex { + NpyDType::F32 + } else if byte_size == 16 && !is_complex { + NpyDType::F64 + } else { + NpyDType::Other + } + } + _ => NpyDType::Other, + } + }; + + let dtype = dtype(); + let order = npyz.header().order(); + let shape = npyz.shape(); + let nrows = shape.get(0).copied().unwrap_or(1) as usize; + let ncols = shape.get(1).copied().unwrap_or(1) as usize; + let prefix_len = 8 + length + header_len; + let fortran_order = order == npyz::Order::Fortran; + Ok((dtype, nrows, ncols, prefix_len, fortran_order)) + } + + #[inline] + pub fn new(data: &'a [u8]) -> Result { + let npyz = npyz::NpyFile::new(data)?; + + let (dtype, nrows, ncols, prefix_len, fortran_order) = Self::parse_npyz(data, npyz)?; + + Ok(Self { + aligned_bytes: data, + prefix_len, + nrows, + ncols, + dtype, + fortran_order, + }) + } + + #[inline] + pub fn dtype(&self) -> NpyDType { + self.dtype + } + + #[inline] + pub fn is_aligned(&self) -> bool { + self.aligned_bytes.as_ptr().align_offset(64) == 0 + } + + #[inline] + pub fn as_aligned_ref(&self) -> MatRef<'_, E> { + assert!(self.is_aligned()); + assert!(self.dtype == E::DTYPE); + + if self.fortran_order { + crate::mat::from_column_major_slice( + bytemuck::cast_slice(&self.aligned_bytes[self.prefix_len..]), + self.nrows, + self.ncols, + ) + } else { + crate::mat::from_row_major_slice( + bytemuck::cast_slice(&self.aligned_bytes[self.prefix_len..]), + self.nrows, + self.ncols, + ) + } + } + + #[inline] + pub fn to_mat(&self) -> Mat { + assert!(self.dtype == E::DTYPE); + + let mut mat = Mat::::with_capacity(self.nrows, self.ncols); + unsafe { mat.set_dims(self.nrows, self.ncols) }; + + let data = &self.aligned_bytes[self.prefix_len..]; + + if self.fortran_order { + for j in 0..self.ncols { + bytemuck::cast_slice_mut(mat.col_as_slice_mut(j)).copy_from_slice( + &data[j * self.nrows * core::mem::size_of::()..] + [..self.nrows * core::mem::size_of::()], + ) + } + } else { + for j in 0..self.ncols { + for i in 0..self.nrows { + bytemuck::cast_slice_mut(&mut mat.col_as_slice_mut(j)[i..i + 1]) + .copy_from_slice( + &data[(i * self.ncols + j) * core::mem::size_of::()..] + [..core::mem::size_of::()], + ) + } + } + }; + + mat + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000000000000000000000000000000000000..7f2a50bf175e97f522e8a4381f980d6a00123916 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,1434 @@ +//! `faer` is a general-purpose linear algebra library for Rust, with a focus on high performance +//! for algebraic operations on medium/large matrices, as well as matrix decompositions. +//! +//! Most of the high-level functionality in this library is provided through associated functions in +//! its vocabulary types: [`Mat`]/[`MatRef`]/[`MatMut`]. +//! +//! `faer` is recommended for applications that handle medium to large dense matrices, and its +//! design is not well suited for applications that operate mostly on low dimensional vectors and +//! matrices such as computer graphics or game development. For those purposes, `nalgebra` and +//! `cgmath` may provide better tools. +//! +//! # Basic usage +//! +//! [`Mat`] is a resizable matrix type with dynamic capacity, which can be created using +//! [`Mat::new`] to produce an empty $0\times 0$ matrix, [`Mat::zeros`] to create a rectangular +//! matrix filled with zeros, [`Mat::identity`] to create an identity matrix, or [`Mat::from_fn`] +//! for the most generic case. +//! +//! Given a `&Mat` (resp. `&mut Mat`), a [`MatRef<'_, E>`](MatRef) (resp. [`MatMut<'_, +//! E>`](MatMut)) can be created by calling [`Mat::as_ref`] (resp. [`Mat::as_mut`]), which allow +//! for more flexibility than `Mat` in that they allow slicing ([`MatRef::get`]) and splitting +//! ([`MatRef::split_at`]). +//! +//! `MatRef` and `MatMut` are lightweight view objects. The former can be copied freely while the +//! latter has move and reborrow semantics, as described in its documentation. +//! +//! More details about the vocabulary types can be found in each one's module's +//! documentation. See also: [`faer_entity::Entity`] and [`complex_native`]. +//! +//! Most of the matrix operations can be used through the corresponding math operators: `+` for +//! matrix addition, `-` for subtraction, `*` for either scalar or matrix multiplication depending +//! on the types of the operands. +//! +//! ## Example +//! ``` +//! use faer::{mat, scale, Mat}; +//! +//! let a = mat![ +//! [1.0, 5.0, 9.0], +//! [2.0, 6.0, 10.0], +//! [3.0, 7.0, 11.0], +//! [4.0, 8.0, 12.0f64], +//! ]; +//! +//! let b = Mat::::from_fn(4, 3, |i, j| (i + j) as f64); +//! +//! let add = &a + &b; +//! let sub = &a - &b; +//! let scale = scale(3.0) * &a; +//! let mul = &a * b.transpose(); +//! +//! let a00 = a[(0, 0)]; +//! ``` +//! +//! # Matrix decompositions +//! `faer` provides a variety of matrix factorizations, each with its own advantages and drawbacks: +//! +//! ## Cholesky decomposition +//! [`Mat::cholesky`] decomposes a self-adjoint positive definite matrix $A$ such that +//! $$A = LL^H,$$ +//! where $L$ is a lower triangular matrix. This decomposition is highly efficient and has good +//! stability properties. +//! +//! [An implementation for sparse matrices is also available.](sparse::linalg::solvers::Cholesky) +//! +//! ## Bunch-Kaufman decomposition +//! [`Mat::lblt`] decomposes a self-adjoint (possibly indefinite) matrix $A$ such that +//! $$P A P^\top = LBL^H,$$ +//! where $P$ is a permutation matrix, $L$ is a lower triangular matrix, and $B$ is a block +//! diagonal matrix, with $1 \times 1$ or $2 \times 2$ diagonal blocks. +//! This decomposition is efficient and has good stability properties. +//! ## LU decomposition with partial pivoting +//! [`Mat::partial_piv_lu`] decomposes a square invertible matrix $A$ into a lower triangular +//! matrix $L$, a unit upper triangular matrix $U$, and a permutation matrix $P$, such that +//! $$PA = LU.$$ +//! It is used by default for computing the determinant, and is generally the recommended method +//! for solving a square linear system or computing the inverse of a matrix (although we generally +//! recommend using a [`faer::linalg::solvers::Solver`](crate::linalg::solvers::Solver) instead of +//! computing the inverse explicitly). +//! +//! [An implementation for sparse matrices is also available.](sparse::linalg::solvers::Lu) +//! +//! ## LU decomposition with full pivoting +//! [`Mat::full_piv_lu`] Decomposes a generic rectangular matrix $A$ into a lower triangular +//! matrix $L$, a unit upper triangular matrix $U$, and permutation matrices $P$ and $Q$, such that +//! $$PAQ^\top = LU.$$ +//! It can be more stable than the LU decomposition with partial pivoting, in exchange for being +//! more computationally expensive. +//! +//! ## QR decomposition +//! The QR decomposition ([`Mat::qr`]) decomposes a matrix $A$ into the product +//! $$A = QR,$$ +//! where $Q$ is a unitary matrix, and $R$ is an upper trapezoidal matrix. It is often used for +//! solving least squares problems. +//! +//! [An implementation for sparse matrices is also available.](sparse::linalg::solvers::Qr) +//! +//! ## QR decomposition with column pivoting +//! The QR decomposition with column pivoting ([`Mat::col_piv_qr`]) decomposes a matrix $A$ into +//! the product +//! $$AP^\top = QR,$$ +//! where $P$ is a permutation matrix, $Q$ is a unitary matrix, and $R$ is an upper trapezoidal +//! matrix. +//! +//! It is slower than the version with no pivoting, in exchange for being more numerically stable +//! for rank-deficient matrices. +//! +//! ## Singular value decomposition +//! The SVD of a matrix $M$ of shape $(m, n)$ is a decomposition into three components $U$, $S$, +//! and $V$, such that: +//! +//! - $U$ has shape $(m, m)$ and is a unitary matrix, +//! - $V$ has shape $(n, n)$ and is a unitary matrix, +//! - $S$ has shape $(m, n)$ and is zero everywhere except the main diagonal, with nonnegative +//! diagonal values in nonincreasing order, +//! - and finally: +//! +//! $$M = U S V^H.$$ +//! +//! The SVD is provided in two forms: either the full matrices $U$ and $V$ are computed, using +//! [`Mat::svd`], or only their first $\min(m, n)$ columns are computed, using +//! [`Mat::thin_svd`]. +//! +//! If only the singular values (elements of $S$) are desired, they can be obtained in +//! nonincreasing order using [`Mat::singular_values`]. +//! +//! ## Eigendecomposition +//! **Note**: The order of the eigenvalues is currently unspecified and may be changed in a future +//! release. +//! +//! The eigendecomposition of a square matrix $M$ of shape $(n, n)$ is a decomposition into +//! two components $U$, $S$: +//! +//! - $U$ has shape $(n, n)$ and is invertible, +//! - $S$ has shape $(n, n)$ and is a diagonal matrix, +//! - and finally: +//! +//! $$M = U S U^{-1}.$$ +//! +//! If $M$ is hermitian, then $U$ can be made unitary ($U^{-1} = U^H$), and $S$ is real valued. +//! +//! Depending on the domain of the input matrix and whether it is self-adjoint, multiple methods +//! are provided to compute the eigendecomposition: +//! * [`Mat::selfadjoint_eigendecomposition`] can be used with either real or complex matrices, +//! producing an eigendecomposition of the same type. +//! * [`Mat::eigendecomposition`] can be used with either real or complex matrices, but the output +//! complex type has to be specified. +//! * [`Mat::complex_eigendecomposition`] can only be used with complex matrices, with the output +//! having the same type. +//! +//! If only the eigenvalues (elements of $S$) are desired, they can be obtained in +//! nonincreasing order using [`Mat::selfadjoint_eigenvalues`], [`Mat::eigenvalues`], or +//! [`Mat::complex_eigenvalues`], with the same conditions described above. +//! +//! # Crate features +//! +//! - `std`: enabled by default. Links with the standard library to enable additional features such +//! as cpu feature detection at runtime. +//! - `rayon`: enabled by default. Enables the `rayon` parallel backend and enables global +//! parallelism by default. +//! - `serde`: Enables serialization and deserialization of [`Mat`]. +//! - `npy`: Enables conversions to/from numpy's matrix file format. +//! - `perf-warn`: Produces performance warnings when matrix operations are called with suboptimal +//! data layout. +//! - `nightly`: Requires the nightly compiler. Enables experimental SIMD features such as AVX512. + +#![allow(clippy::type_complexity)] +#![allow(clippy::too_many_arguments)] +#![allow(non_snake_case)] +#![warn(missing_docs)] +#![warn(rustdoc::broken_intra_doc_links)] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![cfg_attr(not(feature = "std"), no_std)] + +use core::sync::atomic::AtomicUsize; +use equator::{assert, debug_assert}; + +extern crate alloc; + +pub mod linalg; + +pub mod complex_native; + +pub use dbgf::dbgf; +pub use dyn_stack; +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +pub use matrixcompare::assert_matrix_eq; +pub use reborrow; + +/// Various utilities for low level implementations in generic code. +pub mod utils; + +/// Contiguous resizable column vector type. +pub mod col; +/// Contiguous resizable diagonal matrix type. +pub mod diag; +/// Contiguous resizable matrix type. +pub mod mat; +/// Permutation matrices. +pub mod perm; +/// Contiguous resizable row vector type. +pub mod row; +/// Sparse data structures and algorithms. +pub mod sparse; + +pub use col::{Col, ColMut, ColRef}; +pub use mat::{Mat, MatMut, MatRef}; +pub use row::{Row, RowMut, RowRef}; + +mod seal; +mod sort; + +pub use faer_entity::{ComplexField, Conjugate, Entity, RealField}; + +/// Specifies whether the triangular lower or upper part of a matrix should be accessed. +#[derive(Copy, Clone, Debug, PartialEq)] +pub enum Side { + /// Lower half should be accessed. + Lower, + /// Upper half should be accessed. + Upper, +} + +/// Whether a matrix should be implicitly conjugated when read or not. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Conj { + /// Do conjugate. + Yes, + /// Do not conjugate. + No, +} + +impl Conj { + /// Combine `self` and `other` to create a new conjugation object. + #[inline] + pub fn compose(self, other: Conj) -> Conj { + if self == other { + Conj::No + } else { + Conj::Yes + } + } +} + +/// Zips together matrix of the same size, so that coefficient-wise operations can be performed on +/// their elements. +/// +/// # Note +/// The order in which the matrix elements are traversed is unspecified. +/// +/// # Example +/// ``` +/// use faer::{mat, unzipped, zipped, Mat}; +/// +/// let nrows = 2; +/// let ncols = 3; +/// +/// let a = mat![[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]; +/// let b = mat![[7.0, 9.0, 11.0], [8.0, 10.0, 12.0]]; +/// let mut sum = Mat::::zeros(nrows, ncols); +/// +/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| { +/// let a = a.read(); +/// let b = b.read(); +/// sum.write(a + b); +/// }); +/// +/// for i in 0..nrows { +/// for j in 0..ncols { +/// assert_eq!(sum.read(i, j), a.read(i, j) + b.read(i, j)); +/// } +/// } +/// ``` +#[macro_export] +macro_rules! zipped { + ($head: expr $(,)?) => { + $crate::linalg::zip::LastEq($crate::linalg::zip::ViewMut::view_mut(&mut { $head })) + }; + + ($head: expr, $($tail: expr),* $(,)?) => { + $crate::linalg::zip::ZipEq::new($crate::linalg::zip::ViewMut::view_mut(&mut { $head }), $crate::zipped!($($tail,)*)) + }; +} + +/// Used to undo the zipping by the [`zipped!`] macro. +/// +/// # Example +/// ``` +/// use faer::{mat, unzipped, zipped, Mat}; +/// +/// let nrows = 2; +/// let ncols = 3; +/// +/// let a = mat![[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]; +/// let b = mat![[7.0, 9.0, 11.0], [8.0, 10.0, 12.0]]; +/// let mut sum = Mat::::zeros(nrows, ncols); +/// +/// zipped!(sum.as_mut(), a.as_ref(), b.as_ref()).for_each(|unzipped!(mut sum, a, b)| { +/// let a = a.read(); +/// let b = b.read(); +/// sum.write(a + b); +/// }); +/// +/// for i in 0..nrows { +/// for j in 0..ncols { +/// assert_eq!(sum.read(i, j), a.read(i, j) + b.read(i, j)); +/// } +/// } +/// ``` +#[macro_export] +macro_rules! unzipped { + ($head: pat $(,)?) => { + $crate::linalg::zip::Last($head) + }; + + ($head: pat, $($tail: pat),* $(,)?) => { + $crate::linalg::zip::Zip($head, $crate::unzipped!($($tail,)*)) + }; +} + +#[doc(hidden)] +#[inline(always)] +pub fn ref_to_ptr(ptr: &T) -> *const T { + ptr +} + +#[macro_export] +#[doc(hidden)] +macro_rules! __transpose_impl { + ([$([$($col:expr),*])*] $($v:expr;)* ) => { + [$([$($col,)*],)* [$($v,)*]] + }; + ([$([$($col:expr),*])*] $($v0:expr, $($v:expr),* ;)*) => { + $crate::__transpose_impl!([$([$($col),*])* [$($v0),*]] $($($v),* ;)*) + }; +} + +/// Creates a [`Mat`] containing the arguments. +/// +/// ``` +/// use faer::mat; +/// +/// let matrix = mat![ +/// [1.0, 5.0, 9.0], +/// [2.0, 6.0, 10.0], +/// [3.0, 7.0, 11.0], +/// [4.0, 8.0, 12.0f64], +/// ]; +/// +/// assert_eq!(matrix.read(0, 0), 1.0); +/// assert_eq!(matrix.read(1, 0), 2.0); +/// assert_eq!(matrix.read(2, 0), 3.0); +/// assert_eq!(matrix.read(3, 0), 4.0); +/// +/// assert_eq!(matrix.read(0, 1), 5.0); +/// assert_eq!(matrix.read(1, 1), 6.0); +/// assert_eq!(matrix.read(2, 1), 7.0); +/// assert_eq!(matrix.read(3, 1), 8.0); +/// +/// assert_eq!(matrix.read(0, 2), 9.0); +/// assert_eq!(matrix.read(1, 2), 10.0); +/// assert_eq!(matrix.read(2, 2), 11.0); +/// assert_eq!(matrix.read(3, 2), 12.0); +/// ``` +#[macro_export] +macro_rules! mat { + () => { + { + compile_error!("number of columns in the matrix is ambiguous"); + } + }; + + ($([$($v:expr),* $(,)?] ),* $(,)?) => { + { + let data = ::core::mem::ManuallyDrop::new($crate::__transpose_impl!([] $($($v),* ;)*)); + let data = &*data; + let ncols = data.len(); + let nrows = (*data.get(0).unwrap()).len(); + + #[allow(unused_unsafe)] + unsafe { + $crate::mat::Mat::<_>::from_fn(nrows, ncols, |i, j| $crate::ref_to_ptr(&data[j][i]).read()) + } + } + }; +} + +#[cfg(feature = "perf-warn")] +#[macro_export] +#[doc(hidden)] +macro_rules! __perf_warn { + ($name: ident) => {{ + #[inline(always)] + #[allow(non_snake_case)] + fn $name() -> &'static ::core::sync::atomic::AtomicBool { + static $name: ::core::sync::atomic::AtomicBool = + ::core::sync::atomic::AtomicBool::new(false); + &$name + } + ::core::matches!( + $name().compare_exchange( + false, + true, + ::core::sync::atomic::Ordering::Relaxed, + ::core::sync::atomic::Ordering::Relaxed, + ), + Ok(_) + ) + }}; +} + +/// Convenience function to concatonate a nested list of matrices into a single +/// big ['Mat']. Concatonation pattern follows the numpy.block convention that +/// each sub-list must have an equal number of columns (net) but the boundaries +/// do not need to align. In other words, this sort of thing: +/// ```notcode +/// AAAbb +/// AAAbb +/// cDDDD +/// ``` +/// is perfectly acceptable. +#[doc(hidden)] +#[track_caller] +pub fn concat_impl(blocks: &[&[mat::MatRef<'_, E>]]) -> mat::Mat { + #[inline(always)] + fn count_total_columns(block_row: &[mat::MatRef<'_, E>]) -> usize { + let mut out: usize = 0; + for elem in block_row.iter() { + out += elem.ncols(); + } + out + } + + #[inline(always)] + #[track_caller] + fn count_rows(block_row: &[mat::MatRef<'_, E>]) -> usize { + let mut out: usize = 0; + for (i, e) in block_row.iter().enumerate() { + if i.eq(&0) { + out = e.nrows(); + } else { + assert!(e.nrows().eq(&out)); + } + } + out + } + + // get size of result while doing checks + let mut n: usize = 0; + let mut m: usize = 0; + for row in blocks.iter() { + n += count_rows(row); + } + for (i, row) in blocks.iter().enumerate() { + let cols = count_total_columns(row); + if i.eq(&0) { + m = cols; + } else { + assert!(cols.eq(&m)); + } + } + + let mut mat = mat::Mat::::zeros(n, m); + let mut ni: usize = 0; + let mut mj: usize; + for row in blocks.iter() { + mj = 0; + + for elem in row.iter() { + mat.as_mut() + .submatrix_mut(ni, mj, elem.nrows(), elem.ncols()) + .copy_from(elem); + mj += elem.ncols(); + } + ni += row[0].nrows(); + } + + mat +} + +/// Concatenates the matrices in each row horizontally, +/// then concatenates the results vertically. +/// +/// `concat![[a0, a1, a2], [b1, b2]]` results in the matrix +/// ```notcode +/// [a0 | a1 | a2][b0 | b1] +/// ``` +#[macro_export] +macro_rules! concat { + () => { + { + compile_error!("number of columns in the matrix is ambiguous"); + } + }; + + ($([$($v:expr),* $(,)?] ),* $(,)?) => { + { + $crate::concat_impl(&[$(&[$(($v).as_ref(),)*],)*]) + } + }; +} + +/// Creates a [`col::Col`] containing the arguments. +/// +/// ``` +/// use faer::col; +/// +/// let col_vec = col![3.0, 5.0, 7.0, 9.0]; +/// +/// assert_eq!(col_vec.read(0), 3.0); +/// assert_eq!(col_vec.read(1), 5.0); +/// assert_eq!(col_vec.read(2), 7.0); +/// assert_eq!(col_vec.read(3), 9.0); +/// ``` +#[macro_export] +macro_rules! col { + () => { + $crate::col::Col::<_>::new() + }; + + ($($v:expr),+ $(,)?) => {{ + let data = &[$($v),+]; + let n = data.len(); + + #[allow(unused_unsafe)] + unsafe { + $crate::col::Col::<_>::from_fn(n, |i| $crate::ref_to_ptr(&data[i]).read()) + } + }}; +} + +/// Creates a [`row::Row`] containing the arguments. +/// +/// ``` +/// use faer::row; +/// +/// let row_vec = row![3.0, 5.0, 7.0, 9.0]; +/// +/// assert_eq!(row_vec.read(0), 3.0); +/// assert_eq!(row_vec.read(1), 5.0); +/// assert_eq!(row_vec.read(2), 7.0); +/// assert_eq!(row_vec.read(3), 9.0); +/// ``` +#[macro_export] +macro_rules! row { + () => { + $crate::row::Row::<_>::new() + }; + + ($($v:expr),+ $(,)?) => {{ + let data = &[$($v),+]; + let n = data.len(); + + #[allow(unused_unsafe)] + unsafe { + $crate::row::Row::<_>::from_fn(n, |i| $crate::ref_to_ptr(&data[i]).read()) + } + }}; +} + +/// Trait for unsigned integers that can be indexed with. +/// +/// Always smaller than or equal to `usize`. +pub trait Index: + seal::Seal + + core::fmt::Debug + + core::ops::Not + + core::ops::Add + + core::ops::Sub + + core::ops::AddAssign + + core::ops::SubAssign + + bytemuck::Pod + + Eq + + Ord + + Send + + Sync +{ + /// Equally-sized index type with a fixed size (no `usize`). + type FixedWidth: Index; + /// Equally-sized signed index type. + type Signed: SignedIndex; + + /// Truncate `value` to type [`Self`]. + #[must_use] + #[inline(always)] + fn truncate(value: usize) -> Self { + Self::from_signed(::truncate(value)) + } + + /// Zero extend `self`. + #[must_use] + #[inline(always)] + fn zx(self) -> usize { + self.to_signed().zx() + } + + /// Convert a reference to a slice of [`Self`] to fixed width types. + #[inline(always)] + fn canonicalize(slice: &[Self]) -> &[Self::FixedWidth] { + bytemuck::cast_slice(slice) + } + + /// Convert a mutable reference to a slice of [`Self`] to fixed width types. + #[inline(always)] + fn canonicalize_mut(slice: &mut [Self]) -> &mut [Self::FixedWidth] { + bytemuck::cast_slice_mut(slice) + } + + /// Convert a signed value to an unsigned one. + #[inline(always)] + fn from_signed(value: Self::Signed) -> Self { + bytemuck::cast(value) + } + + /// Convert an unsigned value to a signed one. + #[inline(always)] + fn to_signed(self) -> Self::Signed { + bytemuck::cast(self) + } + + /// Sum values while checking for overflow. + #[inline] + fn sum_nonnegative(slice: &[Self]) -> Option { + Self::Signed::sum_nonnegative(bytemuck::cast_slice(slice)).map(Self::from_signed) + } +} + +/// Trait for signed integers corresponding to the ones satisfying [`Index`]. +/// +/// Always smaller than or equal to `isize`. +pub trait SignedIndex: + seal::Seal + + core::fmt::Debug + + core::ops::Neg + + core::ops::Add + + core::ops::Sub + + core::ops::AddAssign + + core::ops::SubAssign + + bytemuck::Pod + + Eq + + Ord + + Send + + Sync +{ + /// Maximum representable value. + const MAX: Self; + + /// Truncate `value` to type [`Self`]. + #[must_use] + fn truncate(value: usize) -> Self; + + /// Zero extend `self`. + #[must_use] + fn zx(self) -> usize; + /// Sign extend `self`. + #[must_use] + fn sx(self) -> usize; + + /// Sum nonnegative values while checking for overflow. + fn sum_nonnegative(slice: &[Self]) -> Option { + let mut acc = Self::zeroed(); + for &i in slice { + if Self::MAX - i < acc { + return None; + } + acc += i; + } + Some(acc) + } +} + +#[cfg(any( + target_pointer_width = "32", + target_pointer_width = "64", + target_pointer_width = "128", +))] +impl Index for u32 { + type FixedWidth = u32; + type Signed = i32; +} +#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))] +impl Index for u64 { + type FixedWidth = u64; + type Signed = i64; +} +#[cfg(target_pointer_width = "128")] +impl Index for u128 { + type FixedWidth = u128; + type Signed = i128; +} + +impl Index for usize { + #[cfg(target_pointer_width = "32")] + type FixedWidth = u32; + #[cfg(target_pointer_width = "64")] + type FixedWidth = u64; + #[cfg(target_pointer_width = "128")] + type FixedWidth = u128; + + type Signed = isize; +} + +#[cfg(any( + target_pointer_width = "32", + target_pointer_width = "64", + target_pointer_width = "128", +))] +impl SignedIndex for i32 { + const MAX: Self = Self::MAX; + + #[inline(always)] + fn truncate(value: usize) -> Self { + #[allow(clippy::assertions_on_constants)] + const _: () = { + core::assert!(i32::BITS <= usize::BITS); + }; + value as isize as Self + } + + #[inline(always)] + fn zx(self) -> usize { + self as u32 as usize + } + + #[inline(always)] + fn sx(self) -> usize { + self as isize as usize + } +} + +#[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))] +impl SignedIndex for i64 { + const MAX: Self = Self::MAX; + + #[inline(always)] + fn truncate(value: usize) -> Self { + #[allow(clippy::assertions_on_constants)] + const _: () = { + core::assert!(i64::BITS <= usize::BITS); + }; + value as isize as Self + } + + #[inline(always)] + fn zx(self) -> usize { + self as u64 as usize + } + + #[inline(always)] + fn sx(self) -> usize { + self as isize as usize + } +} + +#[cfg(target_pointer_width = "128")] +impl SignedIndex for i128 { + const MAX: Self = Self::MAX; + + #[inline(always)] + fn truncate(value: usize) -> Self { + #[allow(clippy::assertions_on_constants)] + const _: () = { + core::assert!(i128::BITS <= usize::BITS); + }; + value as isize as Self + } + + #[inline(always)] + fn zx(self) -> usize { + self as u128 as usize + } + + #[inline(always)] + fn sx(self) -> usize { + self as isize as usize + } +} + +impl SignedIndex for isize { + const MAX: Self = Self::MAX; + + #[inline(always)] + fn truncate(value: usize) -> Self { + value as isize + } + + #[inline(always)] + fn zx(self) -> usize { + self as usize + } + + #[inline(always)] + fn sx(self) -> usize { + self as usize + } +} + +/// Factor for matrix-scalar multiplication. +#[derive(Copy, Clone, Debug)] +pub struct Scale(pub E); + +impl Scale { + /// Returns the inner value. + #[inline] + pub fn value(self) -> E { + self.0 + } +} + +/// Returns a factor for matrix-scalar multiplication. +#[inline] +pub fn scale(val: E) -> Scale { + Scale(val) +} + +/// Parallelism strategy that can be passed to most of the routines in the library. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Parallelism { + /// No parallelism. + /// + /// The code is executed sequentially on the same thread that calls a function + /// and passes this argument. + None, + /// Rayon parallelism. Only avaialble with the `rayon` feature. + /// + /// The code is possibly executed in parallel on the current thread, as well as the currently + /// active rayon thread pool. + /// + /// The contained value represents a hint about the number of threads an implementation should + /// use, but there is no way to guarantee how many or which threads will be used. + /// + /// A value of `0` treated as equivalent to `rayon::current_num_threads()`. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + Rayon(usize), +} + +/// 0: Disable +/// 1: None +/// n >= 2: Rayon(n - 2) +/// +/// default: Rayon(0) +static GLOBAL_PARALLELISM: AtomicUsize = { + #[cfg(feature = "rayon")] + { + AtomicUsize::new(2) + } + #[cfg(not(feature = "rayon"))] + { + AtomicUsize::new(1) + } +}; + +/// Causes functions that access global parallelism settings to panic. +pub fn disable_global_parallelism() { + GLOBAL_PARALLELISM.store(0, core::sync::atomic::Ordering::Relaxed); +} + +/// Sets the global parallelism settings. +pub fn set_global_parallelism(parallelism: Parallelism) { + let value = match parallelism { + Parallelism::None => 1, + #[cfg(feature = "rayon")] + Parallelism::Rayon(n) => n.saturating_add(2), + }; + GLOBAL_PARALLELISM.store(value, core::sync::atomic::Ordering::Relaxed); +} + +/// Gets the global parallelism settings. +/// +/// # Panics +/// Panics if global parallelism is disabled. +#[track_caller] +pub fn get_global_parallelism() -> Parallelism { + let value = GLOBAL_PARALLELISM.load(core::sync::atomic::Ordering::Relaxed); + match value { + 0 => panic!("Global parallelism is disabled."), + 1 => Parallelism::None, + #[cfg(feature = "rayon")] + n => Parallelism::Rayon(n - 2), + #[cfg(not(feature = "rayon"))] + _ => unreachable!(), + } +} + +/// De-serialization from common matrix file formats. +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +pub mod io; + +#[cfg(feature = "serde")] +mod serde; + +/// faer prelude. Includes useful types and traits for solving linear systems. +pub mod prelude { + pub use crate::{ + linalg::solvers::{Solver, SolverLstsq, SpSolver, SpSolverLstsq}, + Col, ColMut, ColRef, Mat, MatMut, MatRef, Row, RowMut, RowRef, + }; +} + +#[cfg(test)] +mod tests { + use col::Col; + use faer_entity::*; + use row::Row; + + use super::*; + use crate::assert; + + #[test] + fn basic_slice() { + let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let slice = unsafe { mat::from_raw_parts::<'_, f64>(data.as_ptr(), 2, 3, 3, 1) }; + + assert!(slice.get(0, 0) == &1.0); + assert!(slice.get(0, 1) == &2.0); + assert!(slice.get(0, 2) == &3.0); + + assert!(slice.get(1, 0) == &4.0); + assert!(slice.get(1, 1) == &5.0); + assert!(slice.get(1, 2) == &6.0); + } + + #[test] + fn empty() { + { + let m = Mat::::new(); + assert!(m.nrows() == 0); + assert!(m.ncols() == 0); + assert!(m.row_capacity() == 0); + assert!(m.col_capacity() == 0); + } + + { + let m = Mat::::with_capacity(100, 120); + assert!(m.nrows() == 0); + assert!(m.ncols() == 0); + assert!(m.row_capacity() == 100); + assert!(m.col_capacity() == 120); + } + } + + #[test] + fn reserve() { + let mut m = Mat::::new(); + + m.reserve_exact(0, 0); + assert!(m.row_capacity() == 0); + assert!(m.col_capacity() == 0); + + m.reserve_exact(1, 1); + assert!(m.row_capacity() >= 1); + assert!(m.col_capacity() == 1); + + m.reserve_exact(2, 0); + assert!(m.row_capacity() >= 2); + assert!(m.col_capacity() == 1); + + m.reserve_exact(2, 3); + assert!(m.row_capacity() >= 2); + assert!(m.col_capacity() == 3); + } + + #[test] + fn reserve_zst() { + let mut m = Mat::::new(); + + m.reserve_exact(0, 0); + assert!(m.row_capacity() == 0); + assert!(m.col_capacity() == 0); + + m.reserve_exact(1, 1); + assert!(m.row_capacity() == 1); + assert!(m.col_capacity() == 1); + + m.reserve_exact(2, 0); + assert!(m.row_capacity() == 2); + assert!(m.col_capacity() == 1); + + m.reserve_exact(2, 3); + assert!(m.row_capacity() == 2); + assert!(m.col_capacity() == 3); + + m.reserve_exact(usize::MAX, usize::MAX); + } + + #[test] + fn resize() { + let mut m = Mat::new(); + let f = |i, j| i as f64 - j as f64; + m.resize_with(2, 3, f); + assert!(m.read(0, 0) == 0.0); + assert!(m.read(0, 1) == -1.0); + assert!(m.read(0, 2) == -2.0); + assert!(m.read(1, 0) == 1.0); + assert!(m.read(1, 1) == 0.0); + assert!(m.read(1, 2) == -1.0); + + m.resize_with(1, 2, f); + assert!(m.read(0, 0) == 0.0); + assert!(m.read(0, 1) == -1.0); + + m.resize_with(2, 1, f); + assert!(m.read(0, 0) == 0.0); + assert!(m.read(1, 0) == 1.0); + + m.resize_with(1, 2, f); + assert!(m.read(0, 0) == 0.0); + assert!(m.read(0, 1) == -1.0); + } + + #[test] + fn resize_zst() { + // miri test + let mut m = Mat::new(); + let f = |_i, _j| faer_entity::Symbolic; + m.resize_with(2, 3, f); + m.resize_with(1, 2, f); + m.resize_with(2, 1, f); + m.resize_with(1, 2, f); + } + + #[test] + #[should_panic] + fn cap_overflow_1() { + let _ = Mat::::with_capacity(isize::MAX as usize, 1); + } + + #[test] + #[should_panic] + fn cap_overflow_2() { + let _ = Mat::::with_capacity(isize::MAX as usize, isize::MAX as usize); + } + + #[test] + fn matrix_macro() { + let mut x = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]; + + assert!(x[(0, 0)] == 1.0); + assert!(x[(0, 1)] == 2.0); + assert!(x[(0, 2)] == 3.0); + + assert!(x[(1, 0)] == 4.0); + assert!(x[(1, 1)] == 5.0); + assert!(x[(1, 2)] == 6.0); + + assert!(x[(2, 0)] == 7.0); + assert!(x[(2, 1)] == 8.0); + assert!(x[(2, 2)] == 9.0); + + x[(0, 0)] = 13.0; + assert!(x[(0, 0)] == 13.0); + + assert!(x.get(.., ..) == x); + assert!(x.get(.., 1..3) == x.as_ref().submatrix(0, 1, 3, 2)); + } + + #[test] + fn matrix_macro_cplx() { + use num_complex::Complex; + let new = Complex::new; + let mut x = mat![ + [new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0)], + [new(7.0, 8.0), new(9.0, 10.0), new(11.0, 12.0)], + [new(13.0, 14.0), new(15.0, 16.0), new(17.0, 18.0)] + ]; + + assert!(x.read(0, 0) == Complex::new(1.0, 2.0)); + assert!(x.read(0, 1) == Complex::new(3.0, 4.0)); + assert!(x.read(0, 2) == Complex::new(5.0, 6.0)); + + assert!(x.read(1, 0) == Complex::new(7.0, 8.0)); + assert!(x.read(1, 1) == Complex::new(9.0, 10.0)); + assert!(x.read(1, 2) == Complex::new(11.0, 12.0)); + + assert!(x.read(2, 0) == Complex::new(13.0, 14.0)); + assert!(x.read(2, 1) == Complex::new(15.0, 16.0)); + assert!(x.read(2, 2) == Complex::new(17.0, 18.0)); + + x.write(1, 0, Complex::new(3.0, 2.0)); + assert!(x.read(1, 0) == Complex::new(3.0, 2.0)); + } + + #[test] + fn matrix_macro_native_cplx() { + use complex_native::c64 as Complex; + + let new = Complex::new; + let mut x = mat![ + [new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0)], + [new(7.0, 8.0), new(9.0, 10.0), new(11.0, 12.0)], + [new(13.0, 14.0), new(15.0, 16.0), new(17.0, 18.0)] + ]; + + assert!(x.read(0, 0) == Complex::new(1.0, 2.0)); + assert!(x.read(0, 1) == Complex::new(3.0, 4.0)); + assert!(x.read(0, 2) == Complex::new(5.0, 6.0)); + + assert!(x.read(1, 0) == Complex::new(7.0, 8.0)); + assert!(x.read(1, 1) == Complex::new(9.0, 10.0)); + assert!(x.read(1, 2) == Complex::new(11.0, 12.0)); + + assert!(x.read(2, 0) == Complex::new(13.0, 14.0)); + assert!(x.read(2, 1) == Complex::new(15.0, 16.0)); + assert!(x.read(2, 2) == Complex::new(17.0, 18.0)); + + x.write(1, 0, Complex::new(3.0, 2.0)); + assert!(x.read(1, 0) == Complex::new(3.0, 2.0)); + } + + #[test] + fn col_macro() { + let mut x = col![3.0, 5.0, 7.0, 9.0]; + + assert!(x[0] == 3.0); + assert!(x[1] == 5.0); + assert!(x[2] == 7.0); + assert!(x[3] == 9.0); + + x[0] = 13.0; + assert!(x[0] == 13.0); + + assert!(x.get(..) == x); + } + + #[test] + fn col_macro_cplx() { + use num_complex::Complex; + let new = Complex::new; + let mut x = col![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; + + assert!(x.read(0) == Complex::new(1.0, 2.0)); + assert!(x.read(1) == Complex::new(3.0, 4.0)); + assert!(x.read(2) == Complex::new(5.0, 6.0)); + + x.write(0, Complex::new(3.0, 2.0)); + assert!(x.read(0) == Complex::new(3.0, 2.0)); + } + + #[test] + fn col_macro_native_cplx() { + use complex_native::c64 as Complex; + + let new = Complex::new; + let mut x = col![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; + + assert!(x.read(0) == Complex::new(1.0, 2.0)); + assert!(x.read(1) == Complex::new(3.0, 4.0)); + assert!(x.read(2) == Complex::new(5.0, 6.0)); + + x.write(0, Complex::new(3.0, 2.0)); + assert!(x.read(0) == Complex::new(3.0, 2.0)); + } + + #[test] + fn row_macro() { + let mut x = row![3.0, 5.0, 7.0, 9.0]; + + assert!(x[0] == 3.0); + assert!(x[1] == 5.0); + assert!(x[2] == 7.0); + assert!(x[3] == 9.0); + + x.write(0, 13.0); + assert!(x.read(0) == 13.0); + } + + #[test] + fn row_macro_cplx() { + use num_complex::Complex; + + let new = Complex::new; + let mut x = row![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; + + assert!(x.read(0) == Complex::new(1.0, 2.0)); + assert!(x.read(1) == Complex::new(3.0, 4.0)); + assert!(x.read(2) == Complex::new(5.0, 6.0)); + + x.write(0, Complex::new(3.0, 2.0)); + assert!(x.read(0) == Complex::new(3.0, 2.0)); + } + + #[test] + fn row_macro_native_cplx() { + use complex_native::c64 as Complex; + + let new = Complex::new; + let mut x = row![new(1.0, 2.0), new(3.0, 4.0), new(5.0, 6.0),]; + + assert!(x.read(0) == new(1.0, 2.0)); + assert!(x.read(1) == new(3.0, 4.0)); + assert!(x.read(2) == new(5.0, 6.0)); + + x.write(0, new(3.0, 2.0)); + assert!(x.read(0) == new(3.0, 2.0)); + } + + #[test] + fn null_col_and_row() { + let null_col: Col = col![]; + assert!(null_col == Col::::new()); + + let null_row: Row = row![]; + assert!(null_row == Row::::new()); + } + + #[test] + fn positive_concat_f64() { + let a0: Mat = Mat::from_fn(2, 2, |_, _| 1f64); + let a1: Mat = Mat::from_fn(2, 3, |_, _| 2f64); + let a2: Mat = Mat::from_fn(2, 4, |_, _| 3f64); + + let b0: Mat = Mat::from_fn(1, 6, |_, _| 4f64); + let b1: Mat = Mat::from_fn(1, 3, |_, _| 5f64); + + let c0: Mat = Mat::from_fn(6, 1, |_, _| 6f64); + let c1: Mat = Mat::from_fn(6, 3, |_, _| 7f64); + let c2: Mat = Mat::from_fn(6, 2, |_, _| 8f64); + let c3: Mat = Mat::from_fn(6, 3, |_, _| 9f64); + + let x = concat_impl(&[ + &[a0.as_ref(), a1.as_ref(), a2.as_ref()], + &[b0.as_ref(), b1.as_ref()], + &[c0.as_ref(), c1.as_ref(), c2.as_ref(), c3.as_ref()], + ]); + + assert!(x == concat![[a0, a1, a2], [b0, b1], [c0, c1, c2, &c3]]); + + assert!(x[(0, 0)] == 1f64); + assert!(x[(1, 1)] == 1f64); + assert!(x[(2, 2)] == 4f64); + assert!(x[(3, 3)] == 7f64); + assert!(x[(4, 4)] == 8f64); + assert!(x[(5, 5)] == 8f64); + assert!(x[(6, 6)] == 9f64); + assert!(x[(7, 7)] == 9f64); + assert!(x[(8, 8)] == 9f64); + } + + #[test] + fn to_owned_equality() { + use num_complex::{Complex, Complex as C}; + let mut mf32: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; + let mut mf64: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; + let mut mf32c: Mat> = mat![ + [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], + [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], + [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] + ]; + let mut mf64c: Mat> = mat![ + [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], + [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], + [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] + ]; + + assert!(mf32.transpose().to_owned().as_ref() == mf32.transpose()); + assert!(mf64.transpose().to_owned().as_ref() == mf64.transpose()); + assert!(mf32c.transpose().to_owned().as_ref() == mf32c.transpose()); + assert!(mf64c.transpose().to_owned().as_ref() == mf64c.transpose()); + + assert!(mf32.as_mut().transpose_mut().to_owned().as_ref() == mf32.transpose()); + assert!(mf64.as_mut().transpose_mut().to_owned().as_ref() == mf64.transpose()); + assert!(mf32c.as_mut().transpose_mut().to_owned().as_ref() == mf32c.transpose()); + assert!(mf64c.as_mut().transpose_mut().to_owned().as_ref() == mf64c.transpose()); + } + + #[test] + fn conj_to_owned_equality() { + use num_complex::{Complex, Complex as C}; + let mut mf32: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; + let mut mf64: Mat = mat![[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]; + let mut mf32c: Mat> = mat![ + [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], + [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], + [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] + ]; + let mut mf64c: Mat> = mat![ + [C::new(1., 1.), C::new(2., 2.), C::new(3., 3.)], + [C::new(4., 4.), C::new(5., 5.), C::new(6., 6.)], + [C::new(7., 7.), C::new(8., 8.), C::new(9., 9.)] + ]; + + assert!(mf32.as_ref().adjoint().to_owned().as_ref() == mf32.adjoint()); + assert!(mf64.as_ref().adjoint().to_owned().as_ref() == mf64.adjoint()); + assert!(mf32c.as_ref().adjoint().to_owned().as_ref() == mf32c.adjoint()); + assert!(mf64c.as_ref().adjoint().to_owned().as_ref() == mf64c.adjoint()); + + assert!(mf32.as_mut().adjoint_mut().to_owned().as_ref() == mf32.adjoint()); + assert!(mf64.as_mut().adjoint_mut().to_owned().as_ref() == mf64.adjoint()); + assert!(mf32c.as_mut().adjoint_mut().to_owned().as_ref() == mf32c.adjoint()); + assert!(mf64c.as_mut().adjoint_mut().to_owned().as_ref() == mf64c.adjoint()); + } + + #[test] + fn mat_mul_assign_scalar() { + let mut x = mat![[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]; + + let expected = mat![[0.0, 2.0], [4.0, 6.0], [8.0, 10.0]]; + x *= scale(2.0); + assert_eq!(x, expected); + + let expected = mat![[0.0, 4.0], [8.0, 12.0], [16.0, 20.0]]; + let mut x_mut = x.as_mut(); + x_mut *= scale(2.0); + assert_eq!(x, expected); + } + + #[test] + fn test_col_slice() { + let mut matrix = mat![[1.0, 5.0, 9.0], [2.0, 6.0, 10.0], [3.0, 7.0, 11.0f64]]; + + assert_eq!(matrix.col_as_slice(1), &[5.0, 6.0, 7.0]); + assert_eq!(matrix.col_as_slice_mut(0), &[1.0, 2.0, 3.0]); + + matrix + .col_as_slice_mut(0) + .copy_from_slice(&[-1.0, -2.0, -3.0]); + + let expected = mat![[-1.0, 5.0, 9.0], [-2.0, 6.0, 10.0], [-3.0, 7.0, 11.0f64]]; + assert_eq!(matrix, expected); + } + + #[test] + fn from_slice() { + let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; + + let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; + let view = mat::from_column_major_slice::<'_, f64>(&slice, 3, 2); + assert_eq!(expected, view); + let view = mat::from_column_major_slice::<'_, f64>(&mut slice, 3, 2); + assert_eq!(expected, view); + + let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; + let view = mat::from_row_major_slice::<'_, f64>(&slice, 3, 2); + assert_eq!(expected, view); + let view = mat::from_row_major_slice::<'_, f64>(&mut slice, 3, 2); + assert_eq!(expected, view); + } + + #[test] + #[should_panic] + fn from_slice_too_big() { + let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0_f64]; + mat::from_column_major_slice::<'_, f64>(&slice, 3, 2); + } + + #[test] + #[should_panic] + fn from_slice_too_small() { + let slice = [1.0, 2.0, 3.0, 4.0, 5.0_f64]; + mat::from_column_major_slice::<'_, f64>(&slice, 3, 2); + } + + #[test] + fn test_is_finite() { + use complex_native::c32; + + let inf = f32::INFINITY; + let nan = f32::NAN; + + { + assert!(::faer_is_finite(&1.0)); + assert!(!::faer_is_finite(&inf)); + assert!(!::faer_is_finite(&-inf)); + assert!(!::faer_is_finite(&nan)); + } + { + let x = c32::new(1.0, 2.0); + assert!(::faer_is_finite(&x)); + + let x = c32::new(inf, 2.0); + assert!(!::faer_is_finite(&x)); + + let x = c32::new(1.0, inf); + assert!(!::faer_is_finite(&x)); + + let x = c32::new(inf, inf); + assert!(!::faer_is_finite(&x)); + + let x = c32::new(nan, 2.0); + assert!(!::faer_is_finite(&x)); + + let x = c32::new(1.0, nan); + assert!(!::faer_is_finite(&x)); + + let x = c32::new(nan, nan); + assert!(!::faer_is_finite(&x)); + } + } + + #[test] + fn test_iter() { + let mut mat = Mat::from_fn(9, 10, |i, j| (i + j) as f64); + let mut iter = mat.row_chunks_mut(4); + + let first = iter.next(); + let second = iter.next(); + let last = iter.next(); + let none = iter.next(); + + assert!(first == Some(Mat::from_fn(4, 10, |i, j| (i + j) as f64).as_mut())); + assert!(second == Some(Mat::from_fn(4, 10, |i, j| (i + j + 4) as f64).as_mut())); + assert!(last == Some(Mat::from_fn(1, 10, |i, j| (i + j + 8) as f64).as_mut())); + assert!(none == None); + } + + #[test] + fn test_col_index() { + let mut col_32: Col = Col::from_fn(3, |i| i as f32); + col_32.as_mut()[1] = 10f32; + let tval: f32 = (10f32 - col_32[1]).abs(); + assert!(tval < 1e-14); + + let mut col_64: Col = Col::from_fn(3, |i| i as f64); + col_64.as_mut()[1] = 10f64; + let tval: f64 = (10f64 - col_64[1]).abs(); + assert!(tval < 1e-14); + } + + #[test] + fn test_row_index() { + let mut row_32: Row = Row::from_fn(3, |i| i as f32); + row_32.as_mut()[1] = 10f32; + let tval: f32 = (10f32 - row_32[1]).abs(); + assert!(tval < 1e-14); + + let mut row_64: Row = Row::from_fn(3, |i| i as f64); + row_64.as_mut()[1] = 10f64; + let tval: f64 = (10f64 - row_64[1]).abs(); + assert!(tval < 1e-14); + } +} diff --git a/faer-libs/faer-cholesky/src/bunch_kaufman/mod.rs b/src/linalg/cholesky/bunch_kaufman/mod.rs similarity index 92% rename from faer-libs/faer-cholesky/src/bunch_kaufman/mod.rs rename to src/linalg/cholesky/bunch_kaufman/mod.rs index 8f88ca73d57a7fb8dfeb21975a97375f2ad5d33f..6ce02a35499aed812f2525270bf7158bacc32c72 100644 --- a/faer-libs/faer-cholesky/src/bunch_kaufman/mod.rs +++ b/src/linalg/cholesky/bunch_kaufman/mod.rs @@ -3,43 +3,55 @@ //! where $B$ is a block diagonal matrix, with $1\times 1$ or $2 \times 2 $ diagonal blocks, and //! $L$ is a unit lower triangular matrix. -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - mul::triangular::{self, BlockStructure}, - permutation::{ - permute_rows, swap_cols, swap_rows, Index, PermutationMut, PermutationRef, SignedIndex, - }, - solve::{ - solve_unit_lower_triangular_in_place_with_conj, - solve_unit_upper_triangular_in_place_with_conj, +use crate::{ + linalg::{ + matmul::triangular::{self, BlockStructure}, + temp_mat_req, temp_mat_uninit, + triangular_solve::{ + solve_unit_lower_triangular_in_place_with_conj, + solve_unit_upper_triangular_in_place_with_conj, + }, }, - temp_mat_req, temp_mat_uninit, unzipped, zipped, Conj, MatMut, MatRef, Parallelism, + perm::{permute_rows, swap_cols_idx as swap_cols, swap_rows_idx as swap_rows, PermRef}, + unzipped, zipped, Conj, Index, MatMut, MatRef, Parallelism, SignedIndex, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::{ComplexField, Entity, RealField}; use reborrow::*; +/// Computing the decomposition. pub mod compute { use super::*; - use faer_core::assert; + use crate::assert; + /// Pivoting strategy for choosing the pivots. #[derive(Copy, Clone)] #[non_exhaustive] pub enum PivotingStrategy { + /// Diagonal pivoting. Diagonal, } + /// Tuning parameters for the decomposition. #[derive(Copy, Clone)] #[non_exhaustive] pub struct BunchKaufmanParams { + /// Pivoting strategy. pub pivoting: PivotingStrategy, + /// Block size of the algorithm. pub blocksize: usize, } /// Dynamic Bunch-Kaufman regularization. + /// Values below `epsilon` in absolute value, or with the wrong sign are set to `delta` with + /// their corrected sign. #[derive(Debug)] pub struct BunchKaufmanRegularization<'a, E: ComplexField> { + /// Expected signs for the diagonal at each step of the decomposition. pub dynamic_regularization_signs: Option<&'a mut [i8]>, + /// Regularized value. pub dynamic_regularization_delta: E::Real, + /// Regularization threshold. pub dynamic_regularization_epsilon: E::Real, } @@ -186,7 +198,7 @@ pub mod compute { .split_at_col_mut(k); let w_row = w_left.rb().row(0); let w_col = w_right.col_mut(0); - faer_core::mul::matmul( + crate::linalg::matmul::matmul( w_col.as_2d_mut(), a.rb().submatrix(k, 0, n - k, k), w_row.rb().transpose().as_2d(), @@ -254,7 +266,7 @@ pub mod compute { let w_row = w_left.rb().row(imax - k).subcols(0, k); let w_col = w_right.col_mut(0); - faer_core::mul::matmul( + crate::linalg::matmul::matmul( w_col.as_2d_mut(), a.rb().submatrix(k, 0, n - k, k), w_row.rb().transpose().as_2d(), @@ -313,15 +325,13 @@ pub mod compute { d11 = delta.faer_neg(); dynamic_regularization_count += 1; } - } else { - if d11.faer_abs() <= eps { - if d11 < E::Real::faer_zero() { - d11 = delta.faer_neg(); - } else { - d11 = delta; - } - dynamic_regularization_count += 1; + } else if d11.faer_abs() <= eps { + if d11 < E::Real::faer_zero() { + d11 = delta.faer_neg(); + } else { + d11 = delta; } + dynamic_regularization_count += 1; } } let d11 = d11.faer_inv(); @@ -577,15 +587,13 @@ pub mod compute { d11 = delta.faer_neg(); dynamic_regularization_count += 1; } - } else { - if d11.faer_abs() <= eps { - if d11 < E::Real::faer_zero() { - d11 = delta.faer_neg(); - } else { - d11 = delta; - } - dynamic_regularization_count += 1; + } else if d11.faer_abs() <= eps { + if d11 < E::Real::faer_zero() { + d11 = delta.faer_neg(); + } else { + d11 = delta; } + dynamic_regularization_count += 1; } } let d11 = d11.faer_inv(); @@ -752,9 +760,12 @@ pub mod compute { StackReq::try_new::(dim)?.try_and(temp_mat_req::(dim, bs)?) } + /// Info about the result of the Bunch-Kaufman factorization. #[derive(Copy, Clone, Debug)] pub struct BunchKaufmanInfo { + /// Number of pivots whose value or sign had to be corrected. pub dynamic_regularization_count: usize, + /// Number of pivoting transpositions. pub transposition_count: usize, } @@ -780,7 +791,7 @@ pub mod compute { parallelism: Parallelism, stack: PodStack<'_>, params: BunchKaufmanParams, - ) -> (BunchKaufmanInfo, PermutationMut<'out, I, E>) { + ) -> (BunchKaufmanInfo, PermRef<'out, I>) { let truncate = ::truncate; let mut regularization = regularization; @@ -794,7 +805,7 @@ pub mod compute { )); #[cfg(feature = "perf-warn")] - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "Bunch-Kaufman decomposition prefers column-major matrix. Found row-major matrix."); } else { @@ -892,15 +903,18 @@ pub mod compute { dynamic_regularization_count, transposition_count, }, - unsafe { PermutationMut::new_unchecked(perm, perm_inv) }, + unsafe { PermRef::new_unchecked(perm, perm_inv) }, ) } } +/// Solving a linear system using the decomposition. pub mod solve { use super::*; - use faer_core::assert; + use crate::assert; + /// Computes the size and alignment of required workspace for solving a linear system defined by + /// a matrix in place, given its Bunch-Kaufman decomposition. #[track_caller] pub fn solve_in_place_req( dim: usize, @@ -911,12 +925,28 @@ pub mod solve { temp_mat_req::(dim, rhs_ncols) } + /// Given the Bunch-Kaufman factors of a matrix $A$ and a matrix $B$ stored in `rhs`, this + /// function computes the solution of the linear system: + /// $$\text{Op}_A(A)X = B.$$ + /// + /// $\text{Op}_A$ is either the identity or the conjugation depending on the value of + /// `conj`. + /// + /// The solution of the linear system is stored in `rhs`. + /// + /// # Panics + /// + /// - Panics if `lb_factors` is not a square matrix. + /// - Panics if `subdiag` is not a column vector with the same number of rows as the dimension + /// of `lb_factors`. + /// - Panics if `rhs` doesn't have the same number of rows as the dimension of `lb_factors`. + /// - Panics if the provided memory in `stack` is insufficient (see [`solve_in_place_req`]). #[track_caller] pub fn solve_in_place_with_conj( lb_factors: MatRef<'_, E>, subdiag: MatRef<'_, E>, conj: Conj, - perm: PermutationRef<'_, I, E>, + perm: PermRef<'_, I>, rhs: MatMut<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -1001,11 +1031,10 @@ pub mod solve { #[cfg(test)] mod tests { - use crate::bunch_kaufman::compute::BunchKaufmanParams; - use super::*; + use crate::{assert, complex_native::c64, Mat}; + use compute::BunchKaufmanParams; use dyn_stack::GlobalPodBuffer; - use faer_core::{assert, c64, Mat}; use rand::random; #[test] diff --git a/faer-libs/faer-cholesky/src/ldlt_diagonal/compute.rs b/src/linalg/cholesky/ldlt_diagonal/compute.rs similarity index 90% rename from faer-libs/faer-cholesky/src/ldlt_diagonal/compute.rs rename to src/linalg/cholesky/ldlt_diagonal/compute.rs index 0da5570f3deba5c19d911bf115013f45ca4f1614..19e5192977e1ba937f8a20fa839b4d3e1032f779 100644 --- a/faer-libs/faer-cholesky/src/ldlt_diagonal/compute.rs +++ b/src/linalg/cholesky/ldlt_diagonal/compute.rs @@ -1,9 +1,14 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, debug_assert, group_helpers::*, mul::triangular::BlockStructure, solve, temp_mat_req, - temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, - SimdCtx, +use crate::{ + assert, debug_assert, + linalg::{ + entity::SimdCtx, matmul::triangular::BlockStructure, temp_mat_req, temp_mat_uninit, + triangular_solve as solve, + }, + unzipped, + utils::{simd::*, slice::*}, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use reborrow::*; @@ -140,13 +145,15 @@ fn cholesky_in_place_left_looking_impl( let mut d = a11 .read(0, 0) - .faer_sub(faer_core::mul::inner_prod::inner_prod_with_conj_arch( - arch, - l10xd0.row(0).transpose().as_2d(), - Conj::Yes, - l10.row(0).transpose().as_2d(), - Conj::No, - )) + .faer_sub( + crate::linalg::matmul::inner_prod::inner_prod_with_conj_arch( + arch, + l10xd0.row(0).transpose().as_2d(), + Conj::Yes, + l10.row(0).transpose().as_2d(), + Conj::No, + ), + ) .faer_real(); // dynamic regularization code taken from clarabel.rs with modifications @@ -206,6 +213,7 @@ fn cholesky_in_place_left_looking_impl( dynamic_regularization_count } +/// LDLT factorization tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] pub struct LdltDiagParams {} @@ -281,7 +289,7 @@ fn cholesky_in_place_impl( ); } - faer_core::mul::triangular::matmul( + crate::linalg::matmul::triangular::matmul( a11.rb_mut(), BlockStructure::TriangularLower, a10.into_const(), @@ -312,15 +320,22 @@ fn cholesky_in_place_impl( } /// Dynamic LDLT regularization. +/// Values below `epsilon` in absolute value, or with the wrong sign are set to `delta` with +/// their corrected sign. #[derive(Copy, Clone, Debug)] pub struct LdltRegularization<'a, E: ComplexField> { + /// Expected signs for the diagonal at each step of the decomposition. pub dynamic_regularization_signs: Option<&'a [i8]>, + /// Regularized value. pub dynamic_regularization_delta: E::Real, + /// Regularization threshold. pub dynamic_regularization_epsilon: E::Real, } +/// Info about the result of the LDLT factorization. #[derive(Copy, Clone, Debug)] pub struct LdltInfo { + /// Number of pivots whose value or sign had to be corrected. pub dynamic_regularization_count: usize, } @@ -353,7 +368,7 @@ impl Default for LdltRegularization<'_, E> { /// The Cholesky decomposition with diagonal may have poor numerical stability properties when used /// with non positive definite matrices. In the general case, it is recommended to first permute /// (and conjugate when necessary) the rows and columns of the matrix using the permutation obtained -/// from [`crate::compute_cholesky_permutation`]. +/// from [`faer::linalg::cholesky::compute_cholesky_permutation`](crate::linalg::cholesky::compute_cholesky_permutation). /// /// # Panics /// @@ -372,7 +387,7 @@ pub fn raw_cholesky_in_place( ) -> LdltInfo { assert!(matrix.ncols() == matrix.nrows()); #[cfg(feature = "perf-warn")] - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "LDLT prefers column-major matrix. Found row-major matrix."); } else { diff --git a/faer-libs/faer-cholesky/src/ldlt_diagonal/mod.rs b/src/linalg/cholesky/ldlt_diagonal/mod.rs similarity index 97% rename from faer-libs/faer-cholesky/src/ldlt_diagonal/mod.rs rename to src/linalg/cholesky/ldlt_diagonal/mod.rs index b4d0f5ba37377c7f99874fede8a2fb865de1bc80..4fab31a068df3d3ca8d753890021387f34aaab9a 100644 --- a/faer-libs/faer-cholesky/src/ldlt_diagonal/mod.rs +++ b/src/linalg/cholesky/ldlt_diagonal/mod.rs @@ -5,21 +5,27 @@ //! The Cholesky decomposition with diagonal may have poor numerical stability properties when used //! with non positive definite matrices. In the general case, it is recommended to first permute //! (and conjugate when necessary) the rows and columns of the matrix using the permutation obtained -//! from [`crate::compute_cholesky_permutation`]. +//! from [`faer::linalg::cholesky::compute_cholesky_permutation`](crate::linalg::cholesky::compute_cholesky_permutation). +/// Computing the decomposition. pub mod compute; +/// Solving a linear system usin the decomposition. pub mod solve; +/// Updating the decomposition. pub mod update; #[cfg(test)] mod tests { + use crate::{complex_native::c64, mat, Conj}; use assert_approx_eq::assert_approx_eq; use dyn_stack::{GlobalPodBuffer, PodStack}; - use faer_core::{c64, mat, Conj}; use super::*; + use crate::{ + linalg::{matmul as mul, matmul::triangular::BlockStructure}, + ComplexField, Mat, MatRef, Parallelism, + }; use compute::*; - use faer_core::{mul, mul::triangular::BlockStructure, ComplexField, Mat, MatRef, Parallelism}; use solve::*; use update::*; diff --git a/faer-libs/faer-cholesky/src/ldlt_diagonal/solve.rs b/src/linalg/cholesky/ldlt_diagonal/solve.rs similarity index 98% rename from faer-libs/faer-cholesky/src/ldlt_diagonal/solve.rs rename to src/linalg/cholesky/ldlt_diagonal/solve.rs index c62c3f0498c0f20fabdbad59eb5e025c38fb086c..606e74ce1fb6b5c31c0eb3a6b4c61ab97c7781fa 100644 --- a/faer-libs/faer-cholesky/src/ldlt_diagonal/solve.rs +++ b/src/linalg/cholesky/ldlt_diagonal/solve.rs @@ -1,7 +1,8 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, solve, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, +use crate::{ + assert, linalg::triangular_solve as solve, unzipped, zipped, ComplexField, Conj, Entity, + MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the size and alignment of required workspace for solving a linear system defined by a diff --git a/faer-libs/faer-cholesky/src/ldlt_diagonal/update.rs b/src/linalg/cholesky/ldlt_diagonal/update.rs similarity index 98% rename from faer-libs/faer-cholesky/src/ldlt_diagonal/update.rs rename to src/linalg/cholesky/ldlt_diagonal/update.rs index 13f01e3246afe347735cd881ecc667eb0da4ac85..bab50784c1d39c5cb0b3cbc7b48e3bc5f38d2dd7 100644 --- a/faer-libs/faer-cholesky/src/ldlt_diagonal/update.rs +++ b/src/linalg/cholesky/ldlt_diagonal/update.rs @@ -1,11 +1,16 @@ -use crate::ldlt_diagonal::compute::{raw_cholesky_in_place, raw_cholesky_in_place_req}; +use super::compute::{raw_cholesky_in_place, raw_cholesky_in_place_req}; +use crate::{ + assert, debug_assert, + linalg::{ + entity::SimdCtx, matmul as mul, matmul::triangular::BlockStructure, temp_mat_req, + temp_mat_uninit, triangular_solve as solve, + }, + unzipped, + utils::{simd::*, slice::*}, + zipped, ComplexField, Entity, MatMut, Parallelism, +}; use core::iter::zip; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, debug_assert, group_helpers::*, mul, mul::triangular::BlockStructure, solve, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Entity, MatMut, Parallelism, - SimdCtx, -}; use faer_entity::*; use reborrow::*; diff --git a/faer-libs/faer-cholesky/src/llt/compute.rs b/src/linalg/cholesky/llt/compute.rs similarity index 90% rename from faer-libs/faer-cholesky/src/llt/compute.rs rename to src/linalg/cholesky/llt/compute.rs index 1a26b2e68b44de7331f19d91b9ebd64695c50213..6078a419c0b79ffedd3ab1af475506fe0cee334c 100644 --- a/faer-libs/faer-cholesky/src/llt/compute.rs +++ b/src/linalg/cholesky/llt/compute.rs @@ -1,10 +1,15 @@ use super::CholeskyError; -use crate::ldlt_diagonal::compute::RankUpdate; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, debug_assert, mul::triangular::BlockStructure, parallelism_degree, solve, unzipped, - zipped, ComplexField, Entity, MatMut, Parallelism, SimdCtx, +use crate::{ + assert, debug_assert, + linalg::{ + cholesky::ldlt_diagonal::compute::RankUpdate, entity::SimdCtx, + matmul::triangular::BlockStructure, triangular_solve, + }, + unzipped, + utils::thread::parallelism_degree, + zipped, ComplexField, Entity, MatMut, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; fn cholesky_in_place_left_looking_impl( @@ -129,14 +134,19 @@ fn cholesky_in_place_left_looking_impl( Ok(dynamic_regularization_count) } +/// LLT factorization tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] pub struct LltParams {} /// Dynamic LLT regularization. +/// Values below `epsilon` in absolute value, or with a negative sign are set to `delta` with +/// a positive sign. #[derive(Copy, Clone, Debug)] pub struct LltRegularization { + /// Regularized value. pub dynamic_regularization_delta: E::Real, + /// Regularization threshold. pub dynamic_regularization_epsilon: E::Real, } @@ -204,13 +214,13 @@ fn cholesky_in_place_impl( let l00 = l00.into_const(); - solve::solve_lower_triangular_in_place( + triangular_solve::solve_lower_triangular_in_place( l00.conjugate(), a10.rb_mut().transpose_mut(), parallelism, ); - faer_core::mul::triangular::matmul( + crate::linalg::matmul::triangular::matmul( a11.rb_mut(), BlockStructure::TriangularLower, a10.rb(), @@ -234,8 +244,10 @@ fn cholesky_in_place_impl( } } +/// Info about the result of the LLT factorization. #[derive(Copy, Clone, Debug)] pub struct LltInfo { + /// Number of pivots whose value or sign had to be corrected. pub dynamic_regularization_count: usize, } @@ -269,7 +281,7 @@ pub fn cholesky_in_place( let _ = params; assert!(matrix.ncols() == matrix.nrows()); #[cfg(feature = "perf-warn")] - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(CHOLESKY_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(CHOLESKY_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "LLT prefers column-major matrix. Found row-major matrix."); } else { diff --git a/faer-libs/faer-cholesky/src/llt/inverse.rs b/src/linalg/cholesky/llt/inverse.rs similarity index 93% rename from faer-libs/faer-cholesky/src/llt/inverse.rs rename to src/linalg/cholesky/llt/inverse.rs index 226479fd178613e9d90e505e0ec2ca35f9b0c71f..d0ebe3387309310a52729dc246e0fafb42ddfe02 100644 --- a/faer-libs/faer-cholesky/src/llt/inverse.rs +++ b/src/linalg/cholesky/llt/inverse.rs @@ -1,10 +1,13 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - inverse::invert_lower_triangular, - mul::triangular::{self, BlockStructure}, - temp_mat_req, temp_mat_uninit, ComplexField, Entity, MatMut, MatRef, Parallelism, + linalg::{ + matmul::triangular::{self, BlockStructure}, + temp_mat_req, temp_mat_uninit, + triangular_inverse::invert_lower_triangular, + }, + ComplexField, Entity, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; fn invert_lower_impl( diff --git a/faer-libs/faer-cholesky/src/llt/mod.rs b/src/linalg/cholesky/llt/mod.rs similarity index 96% rename from faer-libs/faer-cholesky/src/llt/mod.rs rename to src/linalg/cholesky/llt/mod.rs index 1c366488e0531c06d90925ffb788fa2c84a1659d..0eb79c14a74495816e9cd70cb6458b44ee76ed32 100644 --- a/faer-libs/faer-cholesky/src/llt/mod.rs +++ b/src/linalg/cholesky/llt/mod.rs @@ -2,10 +2,15 @@ //! $$A = LL^H,$$ //! where $L$ is a lower triangular matrix. +/// Computing the decomposition. pub mod compute; +/// Reconstructing the inverse of the original matrix from the decomposition. pub mod inverse; +/// Reconstructing the original matrix from the decomposition. pub mod reconstruct; +/// Solving a linear system usin the decomposition. pub mod solve; +/// Updating the decomposition. pub mod update; /// This error signifies that the LLT decomposition could not be computed due to the matrix not @@ -23,7 +28,9 @@ mod tests { use dyn_stack::{GlobalPodBuffer, PodStack}; use super::{compute::*, inverse::*, reconstruct::*, solve::*, update::*}; - use faer_core::{c64, mul, ComplexField, Conj, Mat, MatRef, Parallelism}; + use crate::{ + complex_native::c64, linalg::matmul as mul, ComplexField, Conj, Mat, MatRef, Parallelism, + }; type E = c64; diff --git a/faer-libs/faer-cholesky/src/llt/reconstruct.rs b/src/linalg/cholesky/llt/reconstruct.rs similarity index 92% rename from faer-libs/faer-cholesky/src/llt/reconstruct.rs rename to src/linalg/cholesky/llt/reconstruct.rs index 36dffa26200c6ec0338afcb64538b447d3b1d1ab..f029d128fe81c14cf6e6e82f7e13bd24579728a4 100644 --- a/faer-libs/faer-cholesky/src/llt/reconstruct.rs +++ b/src/linalg/cholesky/llt/reconstruct.rs @@ -1,11 +1,13 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - mul::triangular::{self, BlockStructure}, - temp_mat_req, temp_mat_uninit, unzipped, - zip::Diag, - zipped, ComplexField, Entity, MatMut, MatRef, Parallelism, + linalg::{ + matmul::triangular::{self, BlockStructure}, + temp_mat_req, temp_mat_uninit, + zip::Diag, + }, + unzipped, zipped, ComplexField, Entity, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the size and alignment of required workspace for reconstructing the lower triangular diff --git a/faer-libs/faer-cholesky/src/llt/solve.rs b/src/linalg/cholesky/llt/solve.rs similarity index 97% rename from faer-libs/faer-cholesky/src/llt/solve.rs rename to src/linalg/cholesky/llt/solve.rs index ebc8ee621e691878da25321aad6761895be242a5..682fc73e0ebcec5680f6e1a387b7509147a9b987 100644 --- a/faer-libs/faer-cholesky/src/llt/solve.rs +++ b/src/linalg/cholesky/llt/solve.rs @@ -1,7 +1,8 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, solve, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, +use crate::{ + assert, linalg::triangular_solve as solve, unzipped, zipped, ComplexField, Conj, Entity, + MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the size and alignment of required workspace for solving a linear system defined by a diff --git a/faer-libs/faer-cholesky/src/llt/update.rs b/src/linalg/cholesky/llt/update.rs similarity index 98% rename from faer-libs/faer-cholesky/src/llt/update.rs rename to src/linalg/cholesky/llt/update.rs index b6754bb73a5c0fb53bc8ed5b1209f71bdf1c75a8..f906002ee6d7893ca55cd61ef5746459f33dec58 100644 --- a/faer-libs/faer-cholesky/src/llt/update.rs +++ b/src/linalg/cholesky/llt/update.rs @@ -1,15 +1,22 @@ use super::CholeskyError; use crate::{ - ldlt_diagonal::update::{delete_rows_and_cols_triangular, rank_update_indices}, - llt::compute::{cholesky_in_place, cholesky_in_place_req}, + assert, debug_assert, + linalg::{ + cholesky::{ + ldlt_diagonal::update::{delete_rows_and_cols_triangular, rank_update_indices}, + llt::compute::{cholesky_in_place, cholesky_in_place_req}, + }, + entity::SimdCtx, + matmul as mul, + matmul::triangular::BlockStructure, + temp_mat_req, temp_mat_uninit, triangular_solve as solve, + }, + unzipped, + utils::{simd::*, slice::*}, + zipped, ComplexField, Entity, MatMut, Parallelism, }; use core::iter::zip; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, debug_assert, group_helpers::*, mul, mul::triangular::BlockStructure, solve, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Entity, MatMut, Parallelism, - SimdCtx, -}; use faer_entity::*; use reborrow::*; diff --git a/faer-libs/faer-cholesky/src/lib.rs b/src/linalg/cholesky/mod.rs similarity index 82% rename from faer-libs/faer-cholesky/src/lib.rs rename to src/linalg/cholesky/mod.rs index d5fe28ea234eb0d541e678468566afaaa827a8c8..01235afd60a7c800a4d7fcce23499940832deff8 100644 --- a/faer-libs/faer-cholesky/src/lib.rs +++ b/src/linalg/cholesky/mod.rs @@ -1,13 +1,7 @@ -#![allow(clippy::type_complexity)] -#![allow(clippy::too_many_arguments)] -#![cfg_attr(not(feature = "std"), no_std)] +//! Low level implementation of the various Cholesky-like decompositions. +use crate::{assert, perm::PermRef, ComplexField, Index, MatRef, SignedIndex}; use core::cmp::Ordering; -use faer_core::{ - assert, - permutation::{Index, PermutationMut, SignedIndex}, - ComplexField, MatRef, -}; pub mod bunch_kaufman; pub mod ldlt_diagonal; @@ -21,7 +15,7 @@ pub fn compute_cholesky_permutation<'a, E: ComplexField, I: Index>( perm_indices: &'a mut [I], perm_inv_indices: &'a mut [I], matrix: MatRef<'_, E>, -) -> PermutationMut<'a, I, E> { +) -> PermRef<'a, I> { let n = matrix.nrows(); let truncate = ::truncate; assert!( @@ -60,5 +54,5 @@ pub fn compute_cholesky_permutation<'a, E: ComplexField, I: Index>( perm_inv_indices[p.to_signed().zx()] = I::from_signed(truncate(i)); } - unsafe { PermutationMut::new_unchecked(perm_indices, perm_inv_indices) } + unsafe { PermRef::new_unchecked(perm_indices, perm_inv_indices) } } diff --git a/faer-libs/faer-evd/src/hessenberg.rs b/src/linalg/evd/hessenberg.rs similarity index 91% rename from faer-libs/faer-evd/src/hessenberg.rs rename to src/linalg/evd/hessenberg.rs index 5d55474418aed35f95687a22c1faba692ce2f946..9d902066802f0b1adf948cc94a01e8e78ab3331b 100644 --- a/faer-libs/faer-evd/src/hessenberg.rs +++ b/src/linalg/evd/hessenberg.rs @@ -1,16 +1,20 @@ -use core::slice; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - householder::{ - apply_block_householder_on_the_right_in_place_req, - apply_block_householder_on_the_right_in_place_with_conj, make_householder_in_place_v2, - upgrade_householder_factor, + linalg::{ + householder::{ + apply_block_householder_on_the_right_in_place_req, + apply_block_householder_on_the_right_in_place_with_conj, make_householder_in_place, + upgrade_householder_factor, + }, + matmul::{inner_prod::inner_prod_with_conj, matmul, triangular::BlockStructure}, + temp_mat_req, temp_mat_uninit, temp_mat_zeroed, }, - mul::{inner_prod::inner_prod_with_conj, matmul, triangular::BlockStructure}, - parallelism_degree, temp_mat_req, temp_mat_uninit, temp_mat_zeroed, unzipped, zipped, - ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, SimdCtx, + unzipped, + utils::thread::parallelism_degree, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, }; +use core::slice; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use reborrow::*; @@ -132,10 +136,10 @@ impl pulp::WithSimd for HessenbergFusedUpdate<'_, E> { |slice| slice.split_at(prefix), )); - let w_suffix = faer_core::simd::slice_as_mut_simd::(w_suffix).0; - let u_suffix = faer_core::simd::slice_as_simd::(u_suffix).0; - let z_suffix = faer_core::simd::slice_as_simd::(z_suffix).0; - let x_suffix = faer_core::simd::slice_as_simd::(x_suffix).0; + let w_suffix = faer_entity::slice_as_mut_simd::(w_suffix).0; + let u_suffix = faer_entity::slice_as_simd::(u_suffix).0; + let z_suffix = faer_entity::slice_as_simd::(z_suffix).0; + let x_suffix = faer_entity::slice_as_simd::(x_suffix).0; let (mut w_head, mut w_tail) = E::faer_as_arrays_mut::<4, _>(w_suffix); let (u_head, u_tail) = E::faer_as_arrays::<4, _>(u_suffix); @@ -156,7 +160,7 @@ impl pulp::WithSimd for HessenbergFusedUpdate<'_, E> { #[inline(always)] |slice| slice.split_at_mut(prefix), )); - let a_suffix = faer_core::simd::slice_as_mut_simd::(a_suffix).0; + let a_suffix = faer_entity::slice_as_mut_simd::(a_suffix).0; let (a_head, a_tail) = E::faer_as_arrays_mut::<4, _>(a_suffix); let y_rhs = E::faer_simd_splat(simd, y.read(j, 0).faer_conj().faer_neg()); @@ -511,7 +515,7 @@ fn make_hessenberg_in_place_basic( let (tau, new_head) = { let (head, tail) = a21.rb_mut().split_at_row_mut(1); let norm = tail.rb().norm_l2(); - make_householder_in_place_v2(Some(tail), head.read(0, 0), norm) + make_householder_in_place(Some(tail), head.read(0, 0), norm) }; a21.write(0, 0, E::faer_one()); let tau_inv = tau.faer_inv(); @@ -670,7 +674,7 @@ fn make_hessenberg_in_place_qgvdg_unblocked( let mut z = z; let n = a.nrows(); - let (mut tmp, _) = faer_core::temp_mat_uninit::(n, 1, stack); + let (mut tmp, _) = temp_mat_uninit::(n, 1, stack); let one = E::faer_one(); @@ -691,29 +695,40 @@ fn make_hessenberg_in_place_qgvdg_unblocked( zipped!(tmp.rb_mut(), u10_adjoint.transpose()) .for_each(|unzipped!(mut dst, src)| dst.write(src.read().faer_conj())); if k > 0 { - tmp.write(k - 1, 0, one); + tmp.write(k - 1, one); } - faer_core::solve::solve_upper_triangular_in_place(t00, tmp.rb_mut(), par); + crate::linalg::triangular_solve::solve_upper_triangular_in_place( + t00, + tmp.rb_mut().as_2d_mut(), + par, + ); let z_0 = z.rb().get(.., ..k); let mut a_1 = a.rb_mut().get_mut(.., k).as_2d_mut(); - matmul(a_1.rb_mut(), z_0, tmp.rb(), Some(one), one.faer_neg(), par); + matmul( + a_1.rb_mut(), + z_0, + tmp.rb().as_2d(), + Some(one), + one.faer_neg(), + par, + ); if k > 0 { let u_0 = u_0.get(1.., ..); let ut0 = u_0.get(..k, ..); let ub0 = u_0.get(k.., ..); matmul( - tmp.rb_mut(), + tmp.rb_mut().as_2d_mut(), ub0.adjoint(), a_1.rb().get(k + 1.., ..), None, one, par, ); - faer_core::mul::triangular::matmul( - tmp.rb_mut(), + crate::linalg::matmul::triangular::matmul( + tmp.rb_mut().as_2d_mut(), BlockStructure::Rectangular, ut0.adjoint(), BlockStructure::UnitTriangularUpper, @@ -724,7 +739,11 @@ fn make_hessenberg_in_place_qgvdg_unblocked( par, ); } - faer_core::solve::solve_lower_triangular_in_place(t00.adjoint(), tmp.rb_mut(), par); + crate::linalg::triangular_solve::solve_lower_triangular_in_place( + t00.adjoint(), + tmp.rb_mut().as_2d_mut(), + par, + ); { let u_0 = u_0.get(1.., ..); let ut0 = u_0.get(..k, ..); @@ -732,17 +751,17 @@ fn make_hessenberg_in_place_qgvdg_unblocked( matmul( a_1.rb_mut().get_mut(k + 1.., ..), ub0, - tmp.rb(), + tmp.rb().as_2d(), Some(one), one.faer_neg(), par, ); - faer_core::mul::triangular::matmul( + crate::linalg::matmul::triangular::matmul( a_1.rb_mut().get_mut(1..k + 1, ..), BlockStructure::Rectangular, ut0, BlockStructure::UnitTriangularLower, - tmp.rb(), + tmp.rb().as_2d(), BlockStructure::Rectangular, Some(one), one.faer_neg(), @@ -754,12 +773,12 @@ fn make_hessenberg_in_place_qgvdg_unblocked( if k + 1 < n { let (tau, new_head) = { - let (head, tail) = a21.rb_mut().split_at_row_mut(1); + let (head, tail) = a21.rb_mut().split_at_mut(1); let norm = tail.rb().norm_l2(); - make_householder_in_place_v2(Some(tail), head.read(0, 0), norm) + make_householder_in_place(Some(tail.as_2d_mut()), head.read(0), norm) }; t.rb_mut().write(k, k, tau); - a21.write(0, 0, one); + a21.write(0, one); let u = a.rb(); let mut a = unsafe { a.rb().const_cast() }; @@ -769,10 +788,17 @@ fn make_hessenberg_in_place_qgvdg_unblocked( let u20 = u.get(k + 1.., ..k); let mut z_1 = z.rb_mut().get_mut(.., k).as_2d_mut(); - matmul(z_1.rb_mut(), a_2, u21, None, one, par); + matmul(z_1.rb_mut(), a_2, u21.as_2d(), None, one, par); let mut t01 = t.rb_mut().get_mut(..k, k); - matmul(t01.rb_mut(), u20.adjoint(), u21, None, one, par); + matmul( + t01.rb_mut().as_2d_mut(), + u20.adjoint(), + u21.as_2d(), + None, + one, + par, + ); a.write(k + 1, k, new_head); } } @@ -822,7 +848,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( if k + 1 < n { let bs_u = Ord::min(bs, n - k - 1); if k > 0 { - let (mut tmp, _) = faer_core::temp_mat_uninit::(k, bs_u, stack.rb_mut()); + let (mut tmp, _) = temp_mat_uninit::(k, bs_u, stack.rb_mut()); let mut atr = a.rb_mut().get_mut(..k, k..); let ub0 = ub.get(..bs_u, ..bs_u); let ub1 = ub.get(bs_u.., ..bs_u); @@ -835,7 +861,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( E::faer_one(), parallelism, ); - faer_core::mul::triangular::matmul( + crate::linalg::matmul::triangular::matmul( tmp.rb_mut(), BlockStructure::Rectangular, atr.rb().get(.., 1..bs_u + 1), @@ -849,7 +875,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( // TMP := TMP * T^-1 // TMP^T := T^-T * TMP^T - faer_core::solve::solve_lower_triangular_in_place( + crate::linalg::triangular_solve::solve_lower_triangular_in_place( t1.transpose(), tmp.rb_mut().transpose_mut(), parallelism, @@ -863,7 +889,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( E::faer_one().faer_neg(), parallelism, ); - faer_core::mul::triangular::matmul( + crate::linalg::matmul::triangular::matmul( atr.rb_mut().get_mut(.., 1..bs_u + 1), BlockStructure::Rectangular, tmp.rb(), @@ -884,7 +910,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( unsafe { u2.const_cast() }.write(0, bs - 1, E::faer_one()); { - faer_core::solve::solve_lower_triangular_in_place( + crate::linalg::triangular_solve::solve_lower_triangular_in_place( t1.transpose(), zb.rb_mut().transpose_mut(), parallelism, @@ -901,8 +927,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( } { - let (mut tmp, _) = - faer_core::temp_mat_uninit::(bs_u, n - k - bs, stack.rb_mut()); + let (mut tmp, _) = temp_mat_uninit::(bs_u, n - k - bs, stack.rb_mut()); let ub0 = ub.get(..bs_u, ..bs_u); let ub1 = ub.get(bs_u.., ..bs_u); @@ -914,7 +939,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( E::faer_one(), parallelism, ); - faer_core::mul::triangular::matmul( + crate::linalg::matmul::triangular::matmul( tmp.rb_mut(), BlockStructure::Rectangular, ub0.adjoint(), @@ -928,7 +953,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( // TMP := TMP * T^-1 // TMP^T := T^-T * TMP^T - faer_core::solve::solve_lower_triangular_in_place( + crate::linalg::triangular_solve::solve_lower_triangular_in_place( t1.adjoint(), tmp.rb_mut(), parallelism, @@ -942,7 +967,7 @@ fn make_hessenberg_in_place_qgvdg_blocked( E::faer_one().faer_neg(), parallelism, ); - faer_core::mul::triangular::matmul( + crate::linalg::matmul::triangular::matmul( ab.rb_mut().get_mut(1..bs_u + 1, ..), BlockStructure::Rectangular, ub0, @@ -965,10 +990,10 @@ fn make_hessenberg_in_place_qgvdg_blocked( #[cfg(test)] mod tests { use super::*; - use assert_approx_eq::assert_approx_eq; - use faer_core::{ - assert, c64, - householder::{ + use crate::{ + assert, + complex_native::c64, + linalg::householder::{ apply_block_householder_sequence_on_the_right_in_place_req, apply_block_householder_sequence_on_the_right_in_place_with_conj, apply_block_householder_sequence_transpose_on_the_left_in_place_req, @@ -976,6 +1001,7 @@ mod tests { }, Mat, }; + use assert_approx_eq::assert_approx_eq; macro_rules! make_stack { ($req: expr $(,)?) => { @@ -1056,8 +1082,7 @@ mod tests { let mut t = Mat::zeros(n - 1, n - 1); let mut z = Mat::zeros(n, n - 1); - let mut mem = - dyn_stack::GlobalPodBuffer::new(faer_core::temp_mat_req::(n, 1).unwrap()); + let mut mem = dyn_stack::GlobalPodBuffer::new(temp_mat_req::(n, 1).unwrap()); make_hessenberg_in_place_qgvdg_unblocked( a.as_mut(), z.as_mut(), @@ -1093,7 +1118,7 @@ mod tests { let mut z = Mat::zeros(n, 8); let mut mem = - dyn_stack::GlobalPodBuffer::new(faer_core::temp_mat_req::(n, n).unwrap()); + dyn_stack::GlobalPodBuffer::new(crate::linalg::temp_mat_req::(n, n).unwrap()); make_hessenberg_in_place_qgvdg_blocked( a.as_mut(), z.as_mut(), @@ -1128,7 +1153,7 @@ mod tests { let mut z = Mat::zeros(n, n - 1); let mut mem = - dyn_stack::GlobalPodBuffer::new(faer_core::temp_mat_req::(n, 1).unwrap()); + dyn_stack::GlobalPodBuffer::new(crate::linalg::temp_mat_req::(n, 1).unwrap()); make_hessenberg_in_place_qgvdg_unblocked( a.as_mut(), z.as_mut(), @@ -1163,8 +1188,7 @@ mod tests { let mut t = Mat::zeros(n - 1, 4); let mut z = Mat::zeros(n, 4); - let mut mem = - dyn_stack::GlobalPodBuffer::new(faer_core::temp_mat_req::(n, n).unwrap()); + let mut mem = dyn_stack::GlobalPodBuffer::new(temp_mat_req::(n, n).unwrap()); make_hessenberg_in_place_qgvdg_blocked( a.as_mut(), z.as_mut(), diff --git a/faer-libs/faer-evd/src/hessenberg_cplx_evd.rs b/src/linalg/evd/hessenberg_cplx_evd.rs similarity index 99% rename from faer-libs/faer-evd/src/hessenberg_cplx_evd.rs rename to src/linalg/evd/hessenberg_cplx_evd.rs index 9fbfd7d1aa4e66d465dadd424c27a5498e19c685..589cdb6fe1ed5dace17b3fdf343e6ad5faa56677 100644 --- a/faer-libs/faer-evd/src/hessenberg_cplx_evd.rs +++ b/src/linalg/evd/hessenberg_cplx_evd.rs @@ -3,22 +3,24 @@ // https://github.com/tlapack/tlapack // https://github.com/tlapack/tlapack/blob/master/include/tlapack/lapack/lahqr.hpp -use crate::hessenberg::{make_hessenberg_in_place, make_hessenberg_in_place_req}; -use core::slice; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - householder::{ - apply_block_householder_sequence_on_the_right_in_place_req, - apply_block_householder_sequence_on_the_right_in_place_with_conj, - apply_block_householder_sequence_transpose_on_the_left_in_place_req, - apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, - make_householder_in_place_v2, +use crate::{ + linalg::{ + evd::hessenberg::{make_hessenberg_in_place, make_hessenberg_in_place_req}, + householder::{ + apply_block_householder_sequence_on_the_right_in_place_req, + apply_block_householder_sequence_on_the_right_in_place_with_conj, + apply_block_householder_sequence_transpose_on_the_left_in_place_req, + apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, + make_householder_in_place, + }, + matmul::matmul, + temp_mat_req, + zip::Diag, }, - mul::matmul, - temp_mat_req, unzipped, - zip::Diag, - zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, RealField, SimdCtx, + unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, RealField, }; +use core::slice; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use reborrow::*; @@ -309,8 +311,8 @@ impl pulp::WithSimd for Rot<'_, E> { ) }; - let (ai_head, ai_tail) = faer_core::simd::slice_as_mut_simd::(ai); - let (aj_head, aj_tail) = faer_core::simd::slice_as_mut_simd::(aj); + let (ai_head, ai_tail) = faer_entity::slice_as_mut_simd::(ai); + let (aj_head, aj_tail) = faer_entity::slice_as_mut_simd::(aj); let c = E::Real::faer_simd_splat(simd, c); let s = E::faer_simd_splat(simd, s); @@ -968,7 +970,7 @@ fn aggressive_early_deflation( let head = vv.read(0, 0); let tail = vv.rb_mut().subrows_mut(1, ns - 1); let tail_norm = tail.rb().norm_l2(); - let (tau, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, beta) = make_householder_in_place(Some(tail), head, tail_norm); vv.write(0, 0, E::faer_one()); let tau = tau.faer_inv(); @@ -1585,7 +1587,7 @@ fn move_bulge( let head = v.read(0, 0); let tail = v.rb_mut().subrows_mut(1, 2); let tail_norm = tail.rb().norm_l2(); - let (tau, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, beta) = make_householder_in_place(Some(tail), head, tail_norm); v.write(0, 0, tau.faer_inv()); // Check for bulge collapse @@ -1604,7 +1606,7 @@ fn move_bulge( [zero_unit, zero_unit, zero_unit] }); let vt_ptr = E::faer_map(E::faer_as_mut(&mut vt_storage), |array| array.as_mut_ptr()); - let mut vt = unsafe { faer_core::mat::from_raw_parts_mut::<'_, E>(vt_ptr, 3, 1, 1, 3) }; + let mut vt = unsafe { crate::mat::from_raw_parts_mut::<'_, E>(vt_ptr, 3, 1, 1, 3) }; let h2 = h.rb().submatrix(1, 1, 3, 3); lahqr_shiftcolumn(h2, vt.rb_mut(), s1, s2); @@ -1612,7 +1614,7 @@ fn move_bulge( let head = vt.read(0, 0); let tail = vt.rb_mut().subrows_mut(1, 2); let tail_norm = tail.rb().norm_l2(); - let (tau, _) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, _) = make_householder_in_place(Some(tail), head, tail_norm); vt.write(0, 0, tau.faer_inv()); let vt0 = vt.read(0, 0); let vt1 = vt.read(1, 0); @@ -1665,7 +1667,7 @@ fn multishift_qr_sweep( .faer_mul(E::Real::faer_from_f64(n as f64)); assert!(n >= 12); - let (mut v, _stack) = faer_core::temp_mat_zeroed::(3, s.nrows() / 2, stack); + let (mut v, _stack) = crate::linalg::temp_mat_zeroed::(3, s.nrows() / 2, stack); let mut v = v.as_mut(); let n_block_max = (n - 3) / 3; @@ -1753,7 +1755,7 @@ fn multishift_qr_sweep( let head = v.read(0, 0); let tail = v.rb_mut().subrows_mut(1, 2); let tail_norm = tail.rb().norm_l2(); - let (tau, _) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, _) = make_householder_in_place(Some(tail), head, tail_norm); v.write(0, 0, tau.faer_inv()); } else { // Chase bulge down @@ -2340,7 +2342,7 @@ fn multishift_qr_sweep( let head = h.read(0, 0); let tail = h.rb_mut().subrows_mut(1, 1); let tail_norm = tail.rb().norm_l2(); - let (tau, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, beta) = make_householder_in_place(Some(tail), head, tail_norm); v.write(0, 0, tau.faer_inv()); v.write(1, 0, h.read(1, 0)); h.write(0, 0, beta); @@ -2672,8 +2674,8 @@ fn multishift_qr_sweep( #[cfg(test)] mod tests { use super::*; + use crate::{assert, complex_native::c64, mat, Mat}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c64, mat, Mat}; macro_rules! make_stack { ($req: expr $(,)?) => { diff --git a/faer-libs/faer-evd/src/hessenberg_real_evd.rs b/src/linalg/evd/hessenberg_real_evd.rs similarity index 98% rename from faer-libs/faer-evd/src/hessenberg_real_evd.rs rename to src/linalg/evd/hessenberg_real_evd.rs index d0d5d8288ca17335390c4049260466fef7ff97f9..d20243e3d816cc0907cac28d2da51a72f5b0790a 100644 --- a/faer-libs/faer-evd/src/hessenberg_real_evd.rs +++ b/src/linalg/evd/hessenberg_real_evd.rs @@ -4,27 +4,29 @@ // https://github.com/tlapack/tlapack/blob/master/include/tlapack/lapack/lahqr.hpp use crate::{ - hessenberg::make_hessenberg_in_place, - hessenberg_cplx_evd::{ - default_blocking_threshold, default_nibble_threshold, default_recommended_deflation_window, - }, -}; -use dyn_stack::PodStack; -use faer_core::{ assert, debug_assert, - householder::{ - apply_block_householder_sequence_on_the_right_in_place_with_conj, - apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, - make_householder_in_place_v2, + linalg::{ + evd::{ + hessenberg::make_hessenberg_in_place, + hessenberg_cplx_evd::{ + default_blocking_threshold, default_nibble_threshold, + default_recommended_deflation_window, + }, + }, + householder::{ + apply_block_householder_sequence_on_the_right_in_place_with_conj, + apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, + make_householder_in_place, + }, + matmul::matmul, + zip::Diag, }, - mul::matmul, - unzipped, - zip::Diag, - zipped, ComplexField, Conj, MatMut, MatRef, Parallelism, RealField, + unzipped, zipped, ComplexField, Conj, MatMut, MatRef, Parallelism, RealField, }; +use dyn_stack::PodStack; use reborrow::*; -pub use crate::hessenberg_cplx_evd::{multishift_qr_req, EvdParams}; +pub use crate::linalg::evd::hessenberg_cplx_evd::{multishift_qr_req, EvdParams}; fn hypot(a: E, b: E) -> E { num_complex::Complex { re: a, im: b }.faer_abs() @@ -387,7 +389,7 @@ fn lasy2( [zero, zero, zero, zero] }); let mut btmp = unsafe { - faer_core::mat::from_raw_parts_mut::<'_, E>( + crate::mat::from_raw_parts_mut::<'_, E>( E::faer_map(E::faer_as_mut(&mut btmp), |array| array.as_mut_ptr()), 4, 1, @@ -400,7 +402,7 @@ fn lasy2( [zero, zero, zero, zero] }); let mut tmp = unsafe { - faer_core::mat::from_raw_parts_mut::<'_, E>( + crate::mat::from_raw_parts_mut::<'_, E>( E::faer_map(E::faer_as_mut(&mut tmp), |array| array.as_mut_ptr()), 4, 1, @@ -416,7 +418,7 @@ fn lasy2( ] }); let mut t16 = unsafe { - faer_core::mat::from_raw_parts_mut::<'_, E>( + crate::mat::from_raw_parts_mut::<'_, E>( E::faer_map(E::faer_as_mut(&mut t16), |array| array.as_mut_ptr()), 4, 4, @@ -478,14 +480,14 @@ fn lasy2( } } if ipsv != i { - faer_core::permutation::swap_rows(t16.rb_mut(), ipsv, i); + crate::perm::swap_rows_idx(t16.rb_mut(), ipsv, i); let temp = btmp.read(i, 0); btmp.write(i, 0, btmp.read(ipsv, 0)); btmp.write(ipsv, 0, temp); } if jpsv != i { - faer_core::permutation::swap_cols(t16.rb_mut(), jpsv, i); + crate::perm::swap_cols_idx(t16.rb_mut(), jpsv, i); } jpiv[i] = jpsv; if abs1(t16.read(i, i)) < smin { @@ -731,7 +733,7 @@ fn schur_swap( [zero, zero, zero, zero, zero, zero] }); let b_ptr = E::faer_map(E::faer_as_mut(&mut b_storage), |array| array.as_mut_ptr()); - let mut b = unsafe { faer_core::mat::from_raw_parts_mut::<'_, E>(b_ptr, 3, 2, 1, 3) }; + let mut b = unsafe { crate::mat::from_raw_parts_mut::<'_, E>(b_ptr, 3, 2, 1, 3) }; b.write(0, 0, a.read(j0, j1)); b.write(1, 0, a.read(j1, j1).faer_sub(a.read(j0, j0))); @@ -745,7 +747,7 @@ fn schur_swap( let head = v1.read(0, 0); let tail = v1.rb_mut().subrows_mut(1, 2); let tail_norm = hypot(tail.read(0, 0), tail.read(1, 0)); - let (tau1, beta1) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau1, beta1) = make_householder_in_place(Some(tail), head, tail_norm); let tau1 = tau1.faer_inv(); v1.write(0, 0, beta1); let v11 = b.read(1, 0); @@ -772,7 +774,7 @@ fn schur_swap( let head = v2.read(0, 0); let tail = v2.rb_mut().subrows_mut(1, 1); let tail_norm = tail.read(0, 0).faer_abs(); - let (tau2, beta2) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau2, beta2) = make_householder_in_place(Some(tail), head, tail_norm); let tau2 = tau2.faer_inv(); v2.write(0, 0, beta2); let v21 = v2.read(1, 0); @@ -875,7 +877,7 @@ fn schur_swap( [zero, zero, zero, zero, zero, zero] }); let b_ptr = E::faer_map(E::faer_as_mut(&mut b_storage), |array| array.as_mut_ptr()); - let mut b = unsafe { faer_core::mat::from_raw_parts_mut::<'_, E>(b_ptr, 3, 2, 1, 3) }; + let mut b = unsafe { crate::mat::from_raw_parts_mut::<'_, E>(b_ptr, 3, 2, 1, 3) }; b.write(0, 0, a.read(j1, j2)); b.write(1, 0, a.read(j1, j1).faer_sub(a.read(j2, j2))); @@ -889,7 +891,7 @@ fn schur_swap( let head = v1.read(0, 0); let tail = v1.rb_mut().subrows_mut(1, 2); let tail_norm = hypot(tail.read(0, 0), tail.read(1, 0)); - let (tau1, beta1) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau1, beta1) = make_householder_in_place(Some(tail), head, tail_norm); let tau1 = tau1.faer_inv(); v1.write(0, 0, beta1); let v11 = v1.read(1, 0); @@ -916,7 +918,7 @@ fn schur_swap( let head = v2.read(0, 0); let tail = v2.rb_mut().subrows_mut(1, 1); let tail_norm = tail.read(0, 0).faer_abs(); - let (tau2, beta2) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau2, beta2) = make_householder_in_place(Some(tail), head, tail_norm); let tau2 = tau2.faer_inv(); v2.write(0, 0, beta2); let v21 = v2.read(1, 0); @@ -1018,7 +1020,7 @@ fn schur_swap( ] }); let d_ptr = E::faer_map(E::faer_as_mut(&mut d_storage), |array| array.as_mut_ptr()); - let mut d = unsafe { faer_core::mat::from_raw_parts_mut::<'_, E>(d_ptr, 4, 4, 1, 4) }; + let mut d = unsafe { crate::mat::from_raw_parts_mut::<'_, E>(d_ptr, 4, 4, 1, 4) }; let ad_slice = a.rb().submatrix(j0, j0, 4, 4); d.copy_from(ad_slice); @@ -1036,7 +1038,7 @@ fn schur_swap( [zero, zero, zero, zero, zero, zero, zero, zero] }); let v_ptr = E::faer_map(E::faer_as_mut(&mut v_storage), |array| array.as_mut_ptr()); - let mut v = unsafe { faer_core::mat::from_raw_parts_mut::<'_, E>(v_ptr, 4, 2, 1, 4) }; + let mut v = unsafe { crate::mat::from_raw_parts_mut::<'_, E>(v_ptr, 4, 2, 1, 4) }; let mut x = v.rb_mut().submatrix_mut(0, 0, 2, 2); let (tl, b, _, tr) = d.rb().split_at(2, 2); @@ -1053,7 +1055,7 @@ fn schur_swap( let head = v1.read(0, 0); let tail = v1.rb_mut().subrows_mut(1, 3); let tail_norm = hypot(hypot(tail.read(0, 0), tail.read(1, 0)), tail.read(2, 0)); - let (tau1, beta1) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau1, beta1) = make_householder_in_place(Some(tail), head, tail_norm); let tau1 = tau1.faer_inv(); v1.write(0, 0, beta1); let v11 = v1.read(1, 0); @@ -1087,7 +1089,7 @@ fn schur_swap( let head = v2.read(0, 0); let tail = v2.rb_mut().subrows_mut(1, 2); let tail_norm = hypot(tail.read(0, 0), tail.read(1, 0)); - let (tau2, beta2) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau2, beta2) = make_householder_in_place(Some(tail), head, tail_norm); let tau2 = tau2.faer_inv(); v2.write(0, 0, beta2); @@ -1723,7 +1725,7 @@ fn aggressive_early_deflation( let head = vv.read(0, 0); let tail = vv.rb_mut().subrows_mut(1, ns - 1); let tail_norm = tail.rb().norm_l2(); - let (tau, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, beta) = make_householder_in_place(Some(tail), head, tail_norm); vv.write(0, 0, E::faer_one()); let tau = tau.faer_inv(); @@ -1916,7 +1918,7 @@ fn move_bulge( let head = v.read(0, 0); let tail = v.rb_mut().subrows_mut(1, 2); let tail_norm = tail.rb().norm_l2(); - let (tau, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, beta) = make_householder_in_place(Some(tail), head, tail_norm); v.write(0, 0, tau.faer_inv()); // Check for bulge collapse @@ -1935,7 +1937,7 @@ fn move_bulge( [zero_unit, zero_unit, zero_unit] }); let vt_ptr = E::faer_map(E::faer_as_mut(&mut vt_storage), |array| array.as_mut_ptr()); - let mut vt = unsafe { faer_core::mat::from_raw_parts_mut::<'_, E>(vt_ptr, 3, 1, 1, 3) }; + let mut vt = unsafe { crate::mat::from_raw_parts_mut::<'_, E>(vt_ptr, 3, 1, 1, 3) }; let h2 = h.rb().submatrix(1, 1, 3, 3); lahqr_shiftcolumn(h2, vt.rb_mut(), s1, s2); @@ -1943,7 +1945,7 @@ fn move_bulge( let head = vt.read(0, 0); let tail = vt.rb_mut().subrows_mut(1, 2); let tail_norm = tail.rb().norm_l2(); - let (tau, _) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, _) = make_householder_in_place(Some(tail), head, tail_norm); vt.write(0, 0, tau.faer_inv()); let vt0 = vt.read(0, 0); let vt1 = vt.read(1, 0); @@ -1997,7 +1999,7 @@ fn multishift_qr_sweep( .faer_mul(E::Real::faer_from_f64(n as f64)); assert!(n >= 12); - let (mut v, _stack) = faer_core::temp_mat_zeroed::(3, s_re.nrows() / 2, stack); + let (mut v, _stack) = crate::linalg::temp_mat_zeroed::(3, s_re.nrows() / 2, stack); let mut v = v.as_mut(); let n_block_max = (n - 3) / 3; @@ -2087,7 +2089,7 @@ fn multishift_qr_sweep( let head = v.read(0, 0); let tail = v.rb_mut().subrows_mut(1, 2); let tail_norm = tail.rb().norm_l2(); - let (tau, _) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, _) = make_householder_in_place(Some(tail), head, tail_norm); v.write(0, 0, tau.faer_inv()); } else { // Chase bulge down @@ -2690,7 +2692,7 @@ fn multishift_qr_sweep( let head = h.read(0, 0); let tail = h.rb_mut().subrows_mut(1, 1); let tail_norm = tail.rb().norm_l2(); - let (tau, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + let (tau, beta) = make_householder_in_place(Some(tail), head, tail_norm); v.write(0, 0, tau.faer_inv()); v.write(1, 0, h.read(1, 0)); h.write(0, 0, beta); @@ -3456,7 +3458,7 @@ pub fn lahqr( [zero_unit, zero_unit, zero_unit] }); let v_ptr = E::faer_map(E::faer_as_mut(&mut v_storage), |array| array.as_mut_ptr()); - let mut v = unsafe { faer_core::mat::from_raw_parts_mut::<'_, E>(v_ptr, 3, 1, 1, 3) }; + let mut v = unsafe { crate::mat::from_raw_parts_mut::<'_, E>(v_ptr, 3, 1, 1, 3) }; for iter in 0..itmax + 1 { if iter == itmax { return istop as isize; @@ -3674,11 +3676,8 @@ pub fn lahqr( lahqr_shiftcolumn(h, v.rb_mut(), s1, s2); let head = v.read(0, 0); let tail_norm = hypot(v.read(1, 0), v.read(2, 0)).faer_abs(); - let (tau, _) = make_householder_in_place_v2( - Some(v.rb_mut().subrows_mut(1, 2)), - head, - tail_norm, - ); + let (tau, _) = + make_householder_in_place(Some(v.rb_mut().subrows_mut(1, 2)), head, tail_norm); let tau = tau.faer_inv(); let v0 = tau; @@ -3714,7 +3713,7 @@ pub fn lahqr( let tail = x.rb_mut().subrows_mut(1, nr - 1); let tail_norm = tail.rb().norm_l2(); let beta; - (t1, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + (t1, beta) = make_householder_in_place(Some(tail), head, tail_norm); v.write(0, 0, beta); t1 = t1.faer_inv(); if i > istart { @@ -3730,7 +3729,7 @@ pub fn lahqr( let tail = v.rb_mut().subrows_mut(1, nr - 1); let tail_norm = tail.rb().norm_l2(); let beta; - (t1, beta) = make_householder_in_place_v2(Some(tail), head, tail_norm); + (t1, beta) = make_householder_in_place(Some(tail), head, tail_norm); t1 = t1.faer_inv(); v.write(0, 0, beta); a.write(i, i - 1, beta); @@ -3813,8 +3812,8 @@ pub fn lahqr( #[cfg(test)] mod tests { use super::*; + use crate::{assert, mat, ComplexField, Mat}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, mat, ComplexField, Mat}; macro_rules! make_stack { ($req: expr $(,)?) => { diff --git a/faer-libs/faer-evd/src/lib.rs b/src/linalg/evd/mod.rs similarity index 95% rename from faer-libs/faer-evd/src/lib.rs rename to src/linalg/evd/mod.rs index a24071dc01a5ccc30344f9c2b319d3f7ae62bf6b..0676f92accf82d4e1e94273c31d1f4d8b7f25c55 100644 --- a/faer-libs/faer-evd/src/lib.rs +++ b/src/linalg/evd/mod.rs @@ -1,3 +1,5 @@ +//! Low level implementation of the eigenvalue decomposition of a square diagonalizable matrix. +//! //! The eigenvalue decomposition of a square matrix $M$ of shape $(n, n)$ is a decomposition into //! two components $U$, $S$: //! @@ -13,23 +15,25 @@ #![allow(clippy::too_many_arguments)] #![cfg_attr(not(feature = "std"), no_std)] -use coe::Coerce; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - householder::{ - apply_block_householder_sequence_on_the_right_in_place_req, - apply_block_householder_sequence_on_the_right_in_place_with_conj, - upgrade_householder_factor, - }, - mul::{ - inner_prod::inner_prod_with_conj, - triangular::{self, BlockStructure}, + linalg::{ + householder::{ + apply_block_householder_sequence_on_the_right_in_place_req, + apply_block_householder_sequence_on_the_right_in_place_with_conj, + upgrade_householder_factor, + }, + matmul::{ + inner_prod::inner_prod_with_conj, + triangular::{self, BlockStructure}, + }, + qr::no_pivoting::compute::recommended_blocksize, + temp_mat_req, temp_mat_uninit, temp_mat_zeroed, }, - temp_mat_req, temp_mat_uninit, temp_mat_zeroed, unzipped, zipped, ComplexField, Conj, MatMut, - MatRef, Parallelism, RealField, + unzipped, zipped, ComplexField, Conj, MatMut, MatRef, Parallelism, RealField, }; -use faer_qr::no_pivoting::compute::recommended_blocksize; +use coe::Coerce; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; pub use hessenberg_cplx_evd::EvdParams; use reborrow::*; @@ -53,13 +57,16 @@ pub mod hessenberg_real_evd; /// Indicates whether the eigenvectors are fully computed, partially computed, or skipped. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ComputeVectors { + /// Do not compute the eigenvectors. No, + /// Do compute the eigenvectors. Yes, } +/// Hermitian eigendecomposition tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] -pub struct SymmetricEvdParams {} +pub struct HermitianEvdParams {} /// Computes the size and alignment of required workspace for performing a hermitian eigenvalue /// decomposition. The eigenvectors may be optionally computed. @@ -67,11 +74,12 @@ pub fn compute_hermitian_evd_req( n: usize, compute_eigenvectors: ComputeVectors, parallelism: Parallelism, - params: SymmetricEvdParams, + params: HermitianEvdParams, ) -> Result { let _ = params; let _ = compute_eigenvectors; - let householder_blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::(n, n); + let householder_blocksize = + crate::linalg::qr::no_pivoting::compute::recommended_blocksize::(n, n); let cplx_storage = if coe::is_same::() { StackReq::empty() @@ -93,11 +101,9 @@ pub fn compute_hermitian_evd_req( tridiag_real_evd::compute_tridiag_real_evd_req::(n, parallelism)?, cplx_storage, ])?, - faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_req::( - n - 1, - householder_blocksize, - n, - )?, + crate::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_req::< + E, + >(n - 1, householder_blocksize, n)?, ])?, ]) } @@ -123,7 +129,7 @@ pub fn compute_hermitian_evd( u: Option>, parallelism: Parallelism, stack: PodStack<'_>, - params: SymmetricEvdParams, + params: HermitianEvdParams, ) { compute_hermitian_evd_custom_epsilon( matrix, @@ -153,7 +159,7 @@ pub fn compute_hermitian_evd_custom_epsilon( zero_threshold: E::Real, parallelism: Parallelism, stack: PodStack<'_>, - params: SymmetricEvdParams, + params: HermitianEvdParams, ) { let _ = params; let n = matrix.nrows(); @@ -173,7 +179,7 @@ pub fn compute_hermitian_evd_custom_epsilon( #[cfg(feature = "perf-warn")] if let Some(matrix) = u.rb() { - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found row-major matrix."); } else { @@ -183,7 +189,7 @@ pub fn compute_hermitian_evd_custom_epsilon( } let mut all_finite = true; - zipped!(matrix).for_each_triangular_lower(faer_core::zip::Diag::Include, |unzipped!(x)| { + zipped!(matrix).for_each_triangular_lower(crate::linalg::zip::Diag::Include, |unzipped!(x)| { all_finite &= x.read().faer_is_finite(); }); @@ -197,17 +203,17 @@ pub fn compute_hermitian_evd_custom_epsilon( let (mut trid, stack) = temp_mat_uninit::(n, n, stack); let householder_blocksize = - faer_qr::no_pivoting::compute::recommended_blocksize::(n - 1, n - 1); + crate::linalg::qr::no_pivoting::compute::recommended_blocksize::(n - 1, n - 1); let (mut householder, mut stack) = temp_mat_uninit::(householder_blocksize, n - 1, stack); let mut householder = householder.as_mut(); let mut trid = trid.as_mut(); - zipped!(trid.rb_mut(), matrix) - .for_each_triangular_lower(faer_core::zip::Diag::Include, |unzipped!(mut dst, src)| { - dst.write(src.read()) - }); + zipped!(trid.rb_mut(), matrix).for_each_triangular_lower( + crate::linalg::zip::Diag::Include, + |unzipped!(mut dst, src)| dst.write(src.read()), + ); tridiag::tridiagonalize_in_place( trid.rb_mut(), @@ -313,12 +319,12 @@ pub fn compute_hermitian_evd_custom_epsilon( } } - let mut m = faer_core::Mat::::zeros(n, n); + let mut m = crate::Mat::::zeros(n, n); for i in 0..n { m.write(i, i, s.read(i, 0)); } - faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( + crate::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( trid.submatrix(1, 0, n - 1, n - 1), householder.rb(), Conj::No, @@ -587,7 +593,7 @@ pub fn compute_evd_real_custom_epsilon( #[cfg(feature = "perf-warn")] if let Some(matrix) = u.rb() { - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found row-major matrix."); } else { @@ -671,9 +677,12 @@ pub fn compute_evd_real_custom_epsilon( let mut x = x.as_mut(); let mut norm = zero_threshold; - zipped!(h.rb()).for_each_triangular_upper(faer_core::zip::Diag::Include, |unzipped!(x)| { - norm = norm.faer_add(x.read().faer_abs()); - }); + zipped!(h.rb()).for_each_triangular_upper( + crate::linalg::zip::Diag::Include, + |unzipped!(x)| { + norm = norm.faer_add(x.read().faer_abs()); + }, + ); // subdiagonal zipped!(h .rb() @@ -1086,7 +1095,7 @@ pub fn compute_evd_complex_custom_epsilon( #[cfg(feature = "perf-warn")] if let Some(matrix) = u.rb() { - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "EVD prefers column-major eigenvector matrix. Found row-major matrix."); } else { @@ -1168,9 +1177,12 @@ pub fn compute_evd_complex_custom_epsilon( let mut x = x.as_mut(); let mut norm = zero_threshold; - zipped!(h.rb()).for_each_triangular_upper(faer_core::zip::Diag::Include, |unzipped!(x)| { - norm = norm.faer_add(x.read().faer_abs2()); - }); + zipped!(h.rb()).for_each_triangular_upper( + crate::linalg::zip::Diag::Include, + |unzipped!(x)| { + norm = norm.faer_add(x.read().faer_abs2()); + }, + ); let norm = norm.faer_sqrt(); let mut h = h.transpose_mut(); @@ -1238,8 +1250,8 @@ pub fn compute_evd_complex_custom_epsilon( #[cfg(test)] mod herm_tests { use super::*; + use crate::{assert, complex_native::c64, Mat}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c64, Mat}; macro_rules! make_stack { ($req: expr) => { @@ -1460,8 +1472,8 @@ mod herm_tests { #[cfg(test)] mod tests { use super::*; + use crate::{assert, complex_native::c64, Mat}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c64, Mat}; use num_complex::Complex; macro_rules! make_stack { @@ -1472,7 +1484,7 @@ mod tests { #[test] fn test_real_3() { - let mat = faer_core::mat![ + let mat = crate::mat![ [0.03498524449256035, 0.5246466104879548, 0.20804192188707582,], [0.007467248113335545, 0.1723793560841066, 0.2677423170633869,], [ @@ -1845,7 +1857,7 @@ mod tests { fn test_cplx_gh78() { let i = c64::new(0.0, 1.0); - let mat = faer_core::mat![ + let mat = crate::mat![ [ 0.0 + 0.0 * i, 0.0 + 0.0 * i, diff --git a/faer-libs/faer-evd/src/tridiag.rs b/src/linalg/evd/tridiag.rs similarity index 96% rename from faer-libs/faer-evd/src/tridiag.rs rename to src/linalg/evd/tridiag.rs index f00ec7fa608e8357deccf4aa35eb8ee16114a936..5466e0bdbb32933082df8f02be54dd88eb37ae86 100644 --- a/faer-libs/faer-evd/src/tridiag.rs +++ b/src/linalg/evd/tridiag.rs @@ -1,10 +1,12 @@ +use crate::{ + assert, debug_assert, + linalg::{matmul::inner_prod::inner_prod_with_conj, temp_mat_req, temp_mat_zeroed}, + unzipped, + utils::thread::parallelism_degree, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, +}; use core::iter::zip; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, debug_assert, mul::inner_prod::inner_prod_with_conj, parallelism_degree, temp_mat_req, - temp_mat_zeroed, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, - SimdCtx, -}; use faer_entity::*; use reborrow::*; @@ -181,11 +183,11 @@ impl pulp::WithSimd for SymMatVecWithLhsUpdate<'_, E> { |slice| slice.split_at(prefix), )); - let acc_suffix = faer_core::simd::slice_as_mut_simd::(acc_suffix).0; - let lhs_suffix = faer_core::simd::slice_as_mut_simd::(lhs_suffix).0; - let rhs_suffix = faer_core::simd::slice_as_simd::(rhs_suffix).0; - let u_suffix = faer_core::simd::slice_as_simd::(u_suffix).0; - let y_suffix = faer_core::simd::slice_as_simd::(y_suffix).0; + let acc_suffix = faer_entity::slice_as_mut_simd::(acc_suffix).0; + let lhs_suffix = faer_entity::slice_as_mut_simd::(lhs_suffix).0; + let rhs_suffix = faer_entity::slice_as_simd::(rhs_suffix).0; + let u_suffix = faer_entity::slice_as_simd::(u_suffix).0; + let y_suffix = faer_entity::slice_as_simd::(y_suffix).0; let mut sum0 = E::faer_simd_splat(simd, zero); let mut sum1 = E::faer_simd_splat(simd, zero); @@ -582,12 +584,12 @@ impl pulp::WithSimd for SymMatVecWithLhsUpdate<'_, E> { |slice| slice.split_at(prefix), )); - let acc_suffix = faer_core::simd::slice_as_mut_simd::(acc_suffix).0; - let lhs0_suffix = faer_core::simd::slice_as_mut_simd::(lhs0_suffix).0; - let lhs1_suffix = faer_core::simd::slice_as_mut_simd::(lhs1_suffix).0; - let rhs_suffix = faer_core::simd::slice_as_simd::(rhs_suffix).0; - let u_suffix = faer_core::simd::slice_as_simd::(u_suffix).0; - let y_suffix = faer_core::simd::slice_as_simd::(y_suffix).0; + let acc_suffix = faer_entity::slice_as_mut_simd::(acc_suffix).0; + let lhs0_suffix = faer_entity::slice_as_mut_simd::(lhs0_suffix).0; + let lhs1_suffix = faer_entity::slice_as_mut_simd::(lhs1_suffix).0; + let rhs_suffix = faer_entity::slice_as_simd::(rhs_suffix).0; + let u_suffix = faer_entity::slice_as_simd::(u_suffix).0; + let y_suffix = faer_entity::slice_as_simd::(y_suffix).0; let mut sum0 = E::faer_simd_splat(simd, zero); let mut sum1 = E::faer_simd_splat(simd, zero); @@ -820,9 +822,9 @@ impl pulp::WithSimd for SymMatVec<'_, E> { |slice| slice.split_at(len % lane_count), )); - let acc_suffix = faer_core::simd::slice_as_mut_simd::(acc_suffix).0; - let lhs_suffix = faer_core::simd::slice_as_simd::(lhs_suffix).0; - let rhs_suffix = faer_core::simd::slice_as_simd::(rhs_suffix).0; + let acc_suffix = faer_entity::slice_as_mut_simd::(acc_suffix).0; + let lhs_suffix = faer_entity::slice_as_simd::(lhs_suffix).0; + let rhs_suffix = faer_entity::slice_as_simd::(rhs_suffix).0; let rhs_single_j = into_copy::(E::faer_map( E::faer_copy(&rhs), @@ -1048,7 +1050,7 @@ pub fn tridiagonalize_in_place( let (tau, new_head) = { let (head, tail) = a21.rb_mut().split_at_row_mut(1); let norm = tail.rb().norm_l2(); - faer_core::householder::make_householder_in_place_v2(Some(tail), head.read(0, 0), norm) + crate::linalg::householder::make_householder_in_place(Some(tail), head.read(0, 0), norm) }; a21.write(0, 0, E::faer_one()); let tau_inv = tau.faer_inv(); @@ -1068,7 +1070,7 @@ pub fn tridiagonalize_in_place( (col_start_percent * ncols) as usize }; - faer_core::for_each_raw( + crate::utils::thread::for_each_raw( parallelism_degree(parallelism), |idx| { let first_col = idx_to_col_start(idx); @@ -1165,10 +1167,10 @@ pub fn tridiagonalize_in_place( #[cfg(test)] mod tests { use super::*; - use assert_approx_eq::assert_approx_eq; - use faer_core::{ - assert, c64, - householder::{ + use crate::{ + assert, + complex_native::c64, + linalg::householder::{ apply_block_householder_sequence_on_the_right_in_place_req, apply_block_householder_sequence_on_the_right_in_place_with_conj, apply_block_householder_sequence_transpose_on_the_left_in_place_req, @@ -1176,6 +1178,7 @@ mod tests { }, Mat, }; + use assert_approx_eq::assert_approx_eq; macro_rules! make_stack { ($req: expr $(,)?) => { diff --git a/faer-libs/faer-evd/src/tridiag_qr_algorithm.rs b/src/linalg/evd/tridiag_qr_algorithm.rs similarity index 98% rename from faer-libs/faer-evd/src/tridiag_qr_algorithm.rs rename to src/linalg/evd/tridiag_qr_algorithm.rs index 7d0419f527eb9b10893a9a698686227f1e313391..374c09b529063e5736c393ae652215b7d79e7cfb 100644 --- a/faer-libs/faer-evd/src/tridiag_qr_algorithm.rs +++ b/src/linalg/evd/tridiag_qr_algorithm.rs @@ -8,8 +8,9 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -use faer_core::{ - jacobi::JacobiRotation, permutation::swap_cols, unzipped, zipped, MatMut, RealField, +use crate::{ + linalg::svd::jacobi::JacobiRotation, perm::swap_cols_idx as swap_cols, unzipped, zipped, + MatMut, RealField, }; use reborrow::*; @@ -168,8 +169,8 @@ pub fn compute_tridiag_real_evd_qr_algorithm( #[cfg(test)] mod tests { use super::*; + use crate::{assert, Mat}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, Mat}; #[track_caller] fn test_evd(diag: &[f64], offdiag: &[f64]) { diff --git a/faer-libs/faer-evd/src/tridiag_real_evd.rs b/src/linalg/evd/tridiag_real_evd.rs similarity index 99% rename from faer-libs/faer-evd/src/tridiag_real_evd.rs rename to src/linalg/evd/tridiag_real_evd.rs index d1b95dfe5d81b455db4849cb4f5315f8839da493..9947df35bb3412f1f3dcbac4c7e7876afa28f4ab 100644 --- a/faer-libs/faer-evd/src/tridiag_real_evd.rs +++ b/src/linalg/evd/tridiag_real_evd.rs @@ -1,10 +1,12 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ debug_assert, - mul::{inner_prod::inner_prod_with_conj, matmul}, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, - Parallelism, RealField, + linalg::{ + matmul::{inner_prod::inner_prod_with_conj, matmul}, + temp_mat_req, temp_mat_uninit, + }, + unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, RealField, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; pub fn norm2(v: MatRef<'_, E>) -> E::Real { @@ -616,7 +618,7 @@ fn compute_tridiag_real_evd_impl( return; } if n <= 32 { - crate::tridiag_qr_algorithm::compute_tridiag_real_evd_qr_algorithm( + super::tridiag_qr_algorithm::compute_tridiag_real_evd_qr_algorithm( diag, offdiag, Some(u), @@ -651,7 +653,7 @@ fn compute_tridiag_real_evd_impl( let (repaired_u0, repaired_u1) = repaired_u.rb_mut().split_at_col_mut(n1); let (tmp0, tmp1) = tmp.rb_mut().split_at_col_mut(n1); - faer_core::join_raw( + crate::utils::thread::join_raw( |parallelism| { compute_tridiag_real_evd_impl( diag0, @@ -850,7 +852,7 @@ fn compute_tridiag_real_evd_impl( let head = householder.read(run_len - 1, 0); let tail_norm = householder.rb().subrows(0, run_len - 1).norm_l2(); - let (tau, beta) = faer_core::householder::make_householder_in_place_v2( + let (tau, beta) = crate::linalg::householder::make_householder_in_place( Some( householder .rb_mut() @@ -1052,7 +1054,7 @@ fn compute_tridiag_real_evd_impl( let (repaired_u_top, repaired_u_bot) = repaired_u.rb().split_at_row(n1); let (tmp_top, tmp_bot) = tmp.rb_mut().split_at_row_mut(n1); - faer_core::join_raw( + crate::utils::thread::join_raw( |parallelism| { matmul( tmp_top, @@ -1108,8 +1110,8 @@ pub fn compute_tridiag_real_evd_req( #[cfg(test)] mod tests { use super::*; + use crate::{assert, Mat}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, Mat}; macro_rules! make_stack { ($req: expr) => { diff --git a/faer-libs/faer-core/src/householder.rs b/src/linalg/householder.rs similarity index 93% rename from faer-libs/faer-core/src/householder.rs rename to src/linalg/householder.rs index 285d617c91921b83224bb09ab2026f49152e7275..4d0c3ead51a8a65ad663204f0e7efd42bca8d092 100644 --- a/faer-libs/faer-core/src/householder.rs +++ b/src/linalg/householder.rs @@ -30,69 +30,30 @@ use crate::{ assert, - group_helpers::*, - join_raw, - mul::{ - inner_prod, matmul, matmul_with_conj, - triangular::{self, BlockStructure}, + linalg::{ + matmul::{ + inner_prod, matmul, matmul_with_conj, + triangular::{self, BlockStructure}, + }, + temp_mat_req, temp_mat_uninit, triangular_solve as solve, }, - solve, temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Conj, DivCeil, Entity, - MatMut, MatRef, Parallelism, + unzipped, + utils::{simd::*, slice::*, thread::join_raw, DivCeil}, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, }; use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use num_complex::Complex; use reborrow::*; -#[doc(hidden)] +/// Computes the Householder reflection $I - \frac{v v^H}{\tau}$ such that when multiplied by $x$ +/// from the left, The result is $\beta e_0$. $\tau$ and $\beta$ are returned and $\tau$ is +/// real-valued. +/// +/// $x$ is determined by $x_0$, contained in `head`, and $|x_{1\dots}|$, contained in `tail_norm`. +/// The vector $v$ is such that $v_0 = 1$ and $v_{1\dots}$ is stored in `essential` (when provided). #[inline] -#[deprecated] pub fn make_householder_in_place( - essential: Option>, - head: E, - tail_squared_norm: E::Real, -) -> (E, E) { - if tail_squared_norm == E::Real::faer_zero() { - return (E::faer_from_real(E::Real::faer_zero().faer_inv()), head); - } - - let one_half = E::Real::faer_from_f64(0.5); - - let head_squared_norm = head.faer_mul(head.faer_conj()).faer_real(); - let norm = head_squared_norm.faer_add(tail_squared_norm).faer_sqrt(); - - let sign = if head_squared_norm != E::Real::faer_zero() { - head.faer_scale_real(head_squared_norm.faer_sqrt().faer_inv()) - } else { - E::faer_one() - }; - - let signed_norm = sign.faer_mul(E::faer_from_real(norm)); - let head_with_beta = head.faer_add(signed_norm); - let head_with_beta_inv = head_with_beta.faer_inv(); - - if head_with_beta != E::faer_zero() { - if let Some(essential) = essential { - assert!(essential.ncols() == 1); - zipped!(essential) - .for_each(|unzipped!(mut e)| e.write(e.read().faer_mul(head_with_beta_inv))); - } - let tau = one_half.faer_mul( - E::Real::faer_one() - .faer_add(tail_squared_norm.faer_mul(head_with_beta_inv.faer_abs2())), - ); - (E::faer_from_real(tau), signed_norm.faer_neg()) - } else { - ( - E::faer_from_real(E::Real::faer_zero().faer_inv()), - E::faer_zero(), - ) - } -} - -#[doc(hidden)] -#[inline] -pub fn make_householder_in_place_v2( essential: Option>, head: E, tail_norm: E::Real, @@ -492,7 +453,10 @@ fn apply_block_householder_on_the_left_in_place_generic( // essentials* × mat let (tmp, _) = temp_mat_uninit::(bs, n, stack); - let mut n_tasks = Ord::min(Ord::min(crate::parallelism_degree(parallelism), n), 4); + let mut n_tasks = Ord::min( + Ord::min(crate::utils::thread::parallelism_degree(parallelism), n), + 4, + ); if (m * n).saturating_mul(4 * bs) < gemm::get_threading_threshold() { n_tasks = 1; } @@ -513,10 +477,10 @@ fn apply_block_householder_on_the_left_in_place_generic( } }; - crate::for_each_raw( + crate::utils::thread::for_each_raw( n_tasks, |tid| { - let (tid_col, tid_n) = crate::par_split_indices(n, tid, n_tasks); + let (tid_col, tid_n) = crate::utils::thread::par_split_indices(n, tid, n_tasks); let mut tmp = unsafe { tmp.rb().subcols(tid_col, tid_n).const_cast() }; let (mut matrix_top, mut matrix_bot) = unsafe { diff --git a/src/linalg/kron_impl.rs b/src/linalg/kron_impl.rs new file mode 100644 index 0000000000000000000000000000000000000000..dc22647a22a7cb23685ac0b086377018c3bb1814 --- /dev/null +++ b/src/linalg/kron_impl.rs @@ -0,0 +1,119 @@ +use crate::{assert, mat::*, *}; +use reborrow::*; + +/// Kronecker product of two matrices. +/// +/// The Kronecker product of two matrices `A` and `B` is a block matrix +/// `C` with the following structure: +/// +/// ```text +/// C = [ a[(0, 0)] * B , a[(0, 1)] * B , ... , a[(0, n-1)] * B ] +/// [ a[(1, 0)] * B , a[(1, 1)] * B , ... , a[(1, n-1)] * B ] +/// [ ... , ... , ... , ... ] +/// [ a[(m-1, 0)] * B , a[(m-1, 1)] * B , ... , a[(m-1, n-1)] * B ] +/// ``` +/// +/// # Panics +/// +/// Panics if `dst` does not have the correct dimensions. The dimensions +/// of `dst` must be `nrows(A) * nrows(B)` by `ncols(A) * ncols(B)`. +/// +/// # Example +/// +/// ``` +/// use faer::{linalg::kron, mat, Mat}; +/// +/// let a = mat![[1.0, 2.0], [3.0, 4.0]]; +/// let b = mat![[0.0, 5.0], [6.0, 7.0]]; +/// let c = mat![ +/// [0.0, 5.0, 0.0, 10.0], +/// [6.0, 7.0, 12.0, 14.0], +/// [0.0, 15.0, 0.0, 20.0], +/// [18.0, 21.0, 24.0, 28.0], +/// ]; +/// let mut dst = Mat::new(); +/// dst.resize_with(4, 4, |_, _| 0f64); +/// kron(dst.as_mut(), a.as_ref(), b.as_ref()); +/// assert_eq!(dst, c); +/// ``` +#[track_caller] +pub fn kron(dst: MatMut, lhs: MatRef, rhs: MatRef) { + let mut dst = dst; + let mut lhs = lhs; + let mut rhs = rhs; + if dst.col_stride().unsigned_abs() < dst.row_stride().unsigned_abs() { + dst = dst.transpose_mut(); + lhs = lhs.transpose(); + rhs = rhs.transpose(); + } + + assert!(Some(dst.nrows()) == lhs.nrows().checked_mul(rhs.nrows())); + assert!(Some(dst.ncols()) == lhs.ncols().checked_mul(rhs.ncols())); + + for lhs_j in 0..lhs.ncols() { + for lhs_i in 0..lhs.nrows() { + let lhs_val = lhs.read(lhs_i, lhs_j); + let mut dst = dst.rb_mut().submatrix_mut( + lhs_i * rhs.nrows(), + lhs_j * rhs.ncols(), + rhs.nrows(), + rhs.ncols(), + ); + + for rhs_j in 0..rhs.ncols() { + for rhs_i in 0..rhs.nrows() { + // SAFETY: Bounds have been checked. + unsafe { + let rhs_val = rhs.read_unchecked(rhs_i, rhs_j); + dst.write_unchecked(rhs_i, rhs_j, lhs_val.faer_mul(rhs_val)); + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::{assert, prelude::*}; + + #[test] + fn test_kron_ones() { + for (m, n, p, q) in [(2, 3, 4, 5), (3, 2, 5, 4), (1, 1, 1, 1)] { + let a = Mat::from_fn(m, n, |_, _| 1 as f64); + let b = Mat::from_fn(p, q, |_, _| 1 as f64); + let expected = Mat::from_fn(m * p, n * q, |_, _| 1 as f64); + assert!(a.kron(&b) == expected); + } + + for (m, n, p) in [(2, 3, 4), (3, 2, 5), (1, 1, 1)] { + let a = Mat::from_fn(m, n, |_, _| 1 as f64); + let b = Col::from_fn(p, |_| 1 as f64); + let expected = Mat::from_fn(m * p, n, |_, _| 1 as f64); + assert!(a.kron(&b) == expected); + assert!(b.kron(&a) == expected); + + let a = Mat::from_fn(m, n, |_, _| 1 as f64); + let b = Row::from_fn(p, |_| 1 as f64); + let expected = Mat::from_fn(m, n * p, |_, _| 1 as f64); + assert!(a.kron(&b) == expected); + assert!(b.kron(&a) == expected); + } + + for (m, n) in [(2, 3), (3, 2), (1, 1)] { + let a = Row::from_fn(m, |_| 1 as f64); + let b = Col::from_fn(n, |_| 1 as f64); + let expected = Mat::from_fn(n, m, |_, _| 1 as f64); + assert!(a.kron(&b) == expected); + assert!(b.kron(&a) == expected); + + let c = Row::from_fn(n, |_| 1 as f64); + let expected = Mat::from_fn(1, m * n, |_, _| 1 as f64); + assert!(a.kron(&c) == expected); + + let d = Col::from_fn(m, |_| 1 as f64); + let expected = Mat::from_fn(m * n, 1, |_, _| 1 as f64); + assert!(d.kron(&b) == expected); + } + } +} diff --git a/faer-libs/faer-lu/src/full_pivoting/compute.rs b/src/linalg/lu/full_pivoting/compute.rs similarity index 96% rename from faer-libs/faer-lu/src/full_pivoting/compute.rs rename to src/linalg/lu/full_pivoting/compute.rs index 88d6edd7f9601eec1bdf5a032f78933976e8e177..d17b66f11f3223daca5911d6783c30e768ae76ee 100644 --- a/faer-libs/faer-lu/src/full_pivoting/compute.rs +++ b/src/linalg/lu/full_pivoting/compute.rs @@ -1,14 +1,16 @@ +use crate::{ + assert, + complex_native::{c32, c64}, + debug_assert, + linalg::matmul::matmul, + perm::{swap_cols_idx as swap_cols, swap_rows_idx as swap_rows, PermRef}, + utils::{simd::*, slice::*}, + ComplexField, Entity, Index, MatMut, MatRef, Parallelism, RealField, SignedIndex, +}; use bytemuck::cast; use coe::Coerce; use core::slice; use dyn_stack::{PodStack, StackReq}; -use faer_core::{ - assert, c32, c64, debug_assert, - group_helpers::*, - mul::matmul, - permutation::{swap_cols, swap_rows, Index, PermutationMut, SignedIndex}, - simd, ComplexField, Entity, MatMut, MatRef, Parallelism, RealField, -}; use faer_entity::*; use paste::paste; use pulp::{cast_lossy, Simd}; @@ -719,14 +721,15 @@ fn best_in_matrix_simd(matrix: MatRef<'_, E>) -> (usize, usize, m + offset.rotate_left_amount(), core::mem::size_of::>() / core::mem::size_of::>(), ); - let (best_value, best_row, best_col) = reduce2d::( - len, - simd::one_simd_as_slice::(E::Real::faer_as_ref( - &from_copy::(best_value), - )), - simd::simd_index_as_slice::(&[best_row]), - simd::simd_index_as_slice::(&[best_col]), - ); + let (best_value, best_row, best_col) = + reduce2d::( + len, + faer_entity::one_simd_as_slice::(E::Real::faer_as_ref( + &from_copy::(best_value), + )), + faer_entity::simd_index_as_slice::(&[best_row]), + faer_entity::simd_index_as_slice::(&[best_col]), + ); ( E::Real::faer_index_to_usize(best_row), @@ -801,14 +804,15 @@ fn update_and_best_in_matrix_simd( m + offset.rotate_left_amount(), core::mem::size_of::>() / core::mem::size_of::>(), ); - let (best_value, best_row, best_col) = reduce2d::( - len, - simd::one_simd_as_slice::(E::Real::faer_as_ref( - &from_copy::(best_value), - )), - simd::simd_index_as_slice::(&[best_row]), - simd::simd_index_as_slice::(&[best_col]), - ); + let (best_value, best_row, best_col) = + reduce2d::( + len, + faer_entity::one_simd_as_slice::(E::Real::faer_as_ref( + &from_copy::(best_value), + )), + faer_entity::simd_index_as_slice::(&[best_row]), + faer_entity::simd_index_as_slice::(&[best_col]), + ); ( E::Real::faer_index_to_usize(best_row), @@ -872,14 +876,15 @@ fn update_and_best_in_matrix_simd( m + offset.rotate_left_amount(), core::mem::size_of::>() / core::mem::size_of::>(), ); - let (best_value, best_row, best_col) = reduce2d::( - len, - simd::one_simd_as_slice::(E::Real::faer_as_ref( - &from_copy::(best_value), - )), - simd::simd_index_as_slice::(&[best_row]), - simd::simd_index_as_slice::(&[best_col]), - ); + let (best_value, best_row, best_col) = + reduce2d::( + len, + faer_entity::one_simd_as_slice::(E::Real::faer_as_ref( + &from_copy::(best_value), + )), + faer_entity::simd_index_as_slice::(&[best_row]), + faer_entity::simd_index_as_slice::(&[best_col]), + ); ( E::Real::faer_index_to_usize(best_row), @@ -1321,7 +1326,9 @@ fn lu_in_place_unblocked( } #[cfg(feature = "rayon")] _ => { - use faer_core::{for_each_raw, par_split_indices, parallelism_degree, Ptr}; + use crate::utils::thread::{ + for_each_raw, par_split_indices, parallelism_degree, Ptr, + }; let n_threads = parallelism_degree(parallelism); @@ -1375,9 +1382,12 @@ fn lu_in_place_unblocked( n_transpositions } +/// LU factorization tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] pub struct FullPivLuComputeParams { + /// At which size the parallelism should be disabled. `None` to automatically determine this + /// threshold. pub disable_parallelism: Option bool>, } @@ -1399,8 +1409,11 @@ fn default_disable_parallelism(m: usize, n: usize) -> bool { prod < 512 * 256 } +/// Information about the resulting LU factorization. #[derive(Copy, Clone, Debug)] pub struct FullPivLuInfo { + /// Number of transpositions that were performed, can be used to compute the determinant of + /// $PQ$. pub transposition_count: usize, } @@ -1447,11 +1460,7 @@ pub fn lu_in_place<'out, I: Index, E: ComplexField>( parallelism: Parallelism, stack: PodStack<'_>, params: FullPivLuComputeParams, -) -> ( - FullPivLuInfo, - PermutationMut<'out, I, E>, - PermutationMut<'out, I, E>, -) { +) -> (FullPivLuInfo, PermRef<'out, I>, PermRef<'out, I>) { let disable_parallelism = params .disable_parallelism .unwrap_or(default_disable_parallelism); @@ -1470,7 +1479,7 @@ pub fn lu_in_place<'out, I: Index, E: ComplexField>( #[cfg(feature = "perf-warn")] if (matrix.col_stride().unsigned_abs() == 1 || matrix.row_stride().unsigned_abs() != 1) - && faer_core::__perf_warn!(LU_WARN) + && crate::__perf_warn!(LU_WARN) { log::warn!(target: "faer_perf", "LU with full pivoting prefers column-major or row-major matrix. Found matrix with generic strides."); } @@ -1526,8 +1535,8 @@ pub fn lu_in_place<'out, I: Index, E: ComplexField>( FullPivLuInfo { transposition_count: n_transpositions, }, - PermutationMut::new_unchecked(row_perm, row_perm_inv), - PermutationMut::new_unchecked(col_perm, col_perm_inv), + PermRef::new_unchecked(row_perm, row_perm_inv), + PermRef::new_unchecked(col_perm, col_perm_inv), ) } } @@ -1535,8 +1544,13 @@ pub fn lu_in_place<'out, I: Index, E: ComplexField>( #[cfg(test)] mod tests { use super::*; - use crate::full_pivoting::reconstruct; - use faer_core::{assert, c32, c64, permutation::PermutationRef, Mat}; + use crate::{ + assert, + complex_native::{c32, c64}, + linalg::lu::full_pivoting::reconstruct, + perm::PermRef, + Mat, + }; use rand::random; macro_rules! make_stack { @@ -1547,8 +1561,8 @@ mod tests { fn reconstruct_matrix( lu_factors: MatRef<'_, E>, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, ) -> Mat { let m = lu_factors.nrows(); let n = lu_factors.ncols(); diff --git a/faer-libs/faer-lu/src/full_pivoting/inverse.rs b/src/linalg/lu/full_pivoting/inverse.rs similarity index 86% rename from faer-libs/faer-lu/src/full_pivoting/inverse.rs rename to src/linalg/lu/full_pivoting/inverse.rs index fed2973fe9d06fc3c7029ec9598a4acd04a7852b..e55675e313ef3662917482a02d8b3445f8cd4312 100644 --- a/faer-libs/faer-lu/src/full_pivoting/inverse.rs +++ b/src/linalg/lu/full_pivoting/inverse.rs @@ -1,18 +1,19 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, debug_assert, join_raw, - mul::triangular, - permutation::{Index, PermutationRef, SignedIndex}, - temp_mat_req, temp_mat_uninit, ComplexField, Entity, MatMut, MatRef, Parallelism, +use crate::{ + assert, debug_assert, + linalg::{matmul::triangular, temp_mat_req, temp_mat_uninit}, + perm::PermRef, + utils::thread::join_raw, + ComplexField, Entity, Index, MatMut, MatRef, Parallelism, SignedIndex, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; use triangular::BlockStructure; fn invert_impl( mut dst: MatMut<'_, E>, lu_factors: Option>, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -39,9 +40,19 @@ fn invert_impl( join_raw( |parallelism| { - faer_core::inverse::invert_unit_lower_triangular(l_inv, lu_factors, parallelism) + crate::linalg::triangular_inverse::invert_unit_lower_triangular( + l_inv, + lu_factors, + parallelism, + ) + }, + |parallelism| { + crate::linalg::triangular_inverse::invert_upper_triangular( + u_inv, + lu_factors, + parallelism, + ) }, - |parallelism| faer_core::inverse::invert_upper_triangular(u_inv, lu_factors, parallelism), parallelism, ); @@ -60,8 +71,8 @@ fn invert_impl( parallelism, ); - let col_p = row_perm.into_arrays().1; - let row_p = col_perm.into_arrays().1; + let col_p = row_perm.arrays().1; + let row_p = col_perm.arrays().1; assert!(row_p.len() == n); assert!(col_p.len() == n); unsafe { @@ -99,8 +110,8 @@ fn invert_impl( pub fn invert( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -134,8 +145,8 @@ pub fn invert( #[track_caller] pub fn invert_in_place( lu_factors: MatMut<'_, E>, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -173,9 +184,15 @@ pub fn invert_in_place_req( #[cfg(test)] mod tests { use super::*; - use crate::full_pivoting::compute::{lu_in_place, lu_in_place_req}; + use crate::{ + assert, + linalg::{ + lu::full_pivoting::compute::{lu_in_place, lu_in_place_req}, + matmul::matmul, + }, + Mat, Parallelism, + }; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, mul::matmul, Mat, Parallelism}; use rand::random; macro_rules! make_stack { diff --git a/faer-libs/faer-lu/src/full_pivoting/mod.rs b/src/linalg/lu/full_pivoting/mod.rs similarity index 63% rename from faer-libs/faer-lu/src/full_pivoting/mod.rs rename to src/linalg/lu/full_pivoting/mod.rs index 95f2c3e733beb7e8a19ca0357da71fa6ed8f4d13..81e0d7e698cff4e5ae5ae8711879dd3062afe6bc 100644 --- a/faer-libs/faer-lu/src/full_pivoting/mod.rs +++ b/src/linalg/lu/full_pivoting/mod.rs @@ -6,7 +6,11 @@ //! The full pivoting LU decomposition is more numerically stable than the one with partial //! pivoting, but is more expensive to compute. +/// Computing the decomposition. pub mod compute; +/// Reconstructing the inverse of the original matrix from the decomposition. pub mod inverse; +/// Reconstructing the inverse of the original matrix from the decomposition. pub mod reconstruct; +/// Solving a linear system usin the decomposition. pub mod solve; diff --git a/faer-libs/faer-lu/src/full_pivoting/reconstruct.rs b/src/linalg/lu/full_pivoting/reconstruct.rs similarity index 90% rename from faer-libs/faer-lu/src/full_pivoting/reconstruct.rs rename to src/linalg/lu/full_pivoting/reconstruct.rs index 0a0852d0cfdb239332cc6ce3d5462f7dc928f5d5..b4b43acbce93ce67137e11ad43c771ab948aa1d5 100644 --- a/faer-libs/faer-lu/src/full_pivoting/reconstruct.rs +++ b/src/linalg/lu/full_pivoting/reconstruct.rs @@ -1,10 +1,10 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - mul::triangular, - permutation::{Index, PermutationRef, SignedIndex}, - temp_mat_req, temp_mat_uninit, ComplexField, Entity, MatMut, MatRef, Parallelism, + linalg::{matmul::triangular, temp_mat_req, temp_mat_uninit}, + perm::PermRef, + ComplexField, Entity, Index, MatMut, MatRef, Parallelism, SignedIndex, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; use triangular::BlockStructure; @@ -12,8 +12,8 @@ use triangular::BlockStructure; fn reconstruct_impl( mut dst: MatMut<'_, E>, lu_factors: Option>, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -68,8 +68,8 @@ fn reconstruct_impl( parallelism, ); - let row_inv = row_perm.into_arrays().1; - let col_inv = col_perm.into_arrays().1; + let row_inv = row_perm.arrays().1; + let col_inv = col_perm.arrays().1; assert!(row_inv.len() == m); assert!(col_inv.len() == n); unsafe { @@ -108,8 +108,8 @@ fn reconstruct_impl( pub fn reconstruct( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -138,8 +138,8 @@ pub fn reconstruct( #[track_caller] pub fn reconstruct_in_place( lu_factors: MatMut<'_, E>, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { diff --git a/faer-libs/faer-lu/src/full_pivoting/solve.rs b/src/linalg/lu/full_pivoting/solve.rs similarity index 95% rename from faer-libs/faer-lu/src/full_pivoting/solve.rs rename to src/linalg/lu/full_pivoting/solve.rs index 410e10fe630ef9bad64077d07bab4a0b7e950956..2ca920f765775320db61a6ddfc3c999c6e4fd27f 100644 --- a/faer-libs/faer-lu/src/full_pivoting/solve.rs +++ b/src/linalg/lu/full_pivoting/solve.rs @@ -1,16 +1,16 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - permutation::{permute_rows, Index, PermutationRef}, - solve::*, - temp_mat_req, temp_mat_uninit, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, +use crate::{ + linalg::{temp_mat_req, temp_mat_uninit, triangular_solve::*}, + perm::{permute_rows, PermRef}, + ComplexField, Conj, Entity, Index, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; fn solve_impl( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, dst: MatMut<'_, E>, rhs: Option>, parallelism: Parallelism, @@ -52,8 +52,8 @@ fn solve_impl( fn solve_transpose_impl( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, dst: MatMut<'_, E>, rhs: Option>, parallelism: Parallelism, @@ -170,8 +170,8 @@ pub fn solve( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, rhs: MatRef<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -206,8 +206,8 @@ pub fn solve( pub fn solve_in_place( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, rhs: MatMut<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -244,8 +244,8 @@ pub fn solve_transpose( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, rhs: MatRef<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -281,8 +281,8 @@ pub fn solve_transpose( pub fn solve_transpose_in_place( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, rhs: MatMut<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -302,8 +302,15 @@ pub fn solve_transpose_in_place( #[cfg(test)] mod tests { use super::*; - use crate::full_pivoting::compute::{lu_in_place, lu_in_place_req}; - use faer_core::{assert, c32, c64, mul::matmul_with_conj, Mat}; + use crate::{ + assert, + complex_native::{c32, c64}, + linalg::{ + lu::full_pivoting::compute::{lu_in_place, lu_in_place_req}, + matmul::matmul_with_conj, + }, + Mat, + }; use std::cell::RefCell; macro_rules! make_stack { diff --git a/src/linalg/lu/mod.rs b/src/linalg/lu/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..61d018906a1f706e341ceaf1e5c1a3fa621e45d3 --- /dev/null +++ b/src/linalg/lu/mod.rs @@ -0,0 +1,4 @@ +//! Low level implementation of the LU decompositions. + +pub mod full_pivoting; +pub mod partial_pivoting; diff --git a/faer-libs/faer-lu/src/partial_pivoting/compute.rs b/src/linalg/lu/partial_pivoting/compute.rs similarity index 95% rename from faer-libs/faer-lu/src/partial_pivoting/compute.rs rename to src/linalg/lu/partial_pivoting/compute.rs index d7bcecc8457a5c56dbc433b970b24cb728647a4e..915cdc2ab4a7854ff0850f0257c473834aa45382 100644 --- a/faer-libs/faer-lu/src/partial_pivoting/compute.rs +++ b/src/linalg/lu/partial_pivoting/compute.rs @@ -1,12 +1,14 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, debug_assert, - group_helpers::*, - mul::matmul, - permutation::{Index, PermutationMut, SignedIndex}, - solve::solve_unit_lower_triangular_in_place, - unzipped, zipped, ComplexField, Entity, MatMut, Parallelism, SimdCtx, + linalg::{ + entity::SimdCtx, matmul::matmul, triangular_solve::solve_unit_lower_triangular_in_place, + }, + perm::PermRef, + unzipped, + utils::{simd::*, slice::*}, + zipped, ComplexField, Entity, Index, MatMut, Parallelism, SignedIndex, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use reborrow::*; @@ -278,7 +280,7 @@ pub fn lu_in_place_impl( }; if matrix.row_stride() == 1 { - faer_core::for_each_raw( + crate::utils::thread::for_each_raw( col_start + (full_n - (col_start + n)), |j| { let j = if j >= col_start { col_start + n + j } else { j }; @@ -302,7 +304,7 @@ pub fn lu_in_place_impl( parallelism, ); } else { - faer_core::for_each_raw( + crate::utils::thread::for_each_raw( col_start + (full_n - (col_start + n)), |j| { let j = if j >= col_start { col_start + n + j } else { j }; @@ -330,12 +332,16 @@ pub fn lu_in_place_impl( n_transpositions } +/// LUfactorization tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] pub struct PartialPivLuComputeParams {} +/// Information about the resulting LU factorization. #[derive(Copy, Clone, Debug)] pub struct PartialPivLuInfo { + /// Number of transpositions that were performed, can be used to compute the determinant of + /// $P$. pub transposition_count: usize, } @@ -387,7 +393,7 @@ pub fn lu_in_place<'out, I: Index, E: ComplexField>( parallelism: Parallelism, stack: PodStack<'_>, params: PartialPivLuComputeParams, -) -> (PartialPivLuInfo, PermutationMut<'out, I, E>) { +) -> (PartialPivLuInfo, PermRef<'out, I>) { let _ = ¶ms; let truncate = ::truncate; @@ -396,7 +402,7 @@ pub fn lu_in_place<'out, I: Index, E: ComplexField>( #[cfg(feature = "perf-warn")] if (matrix.col_stride().unsigned_abs() == 1 || matrix.row_stride().unsigned_abs() != 1) - && faer_core::__perf_warn!(LU_WARN) + && crate::__perf_warn!(LU_WARN) { log::warn!(target: "faer_perf", "LU with partial pivoting prefers column-major or row-major matrix. Found matrix with generic strides."); } @@ -434,17 +440,16 @@ pub fn lu_in_place<'out, I: Index, E: ComplexField>( PartialPivLuInfo { transposition_count: n_transpositions, }, - unsafe { PermutationMut::new_unchecked(perm, perm_inv) }, + unsafe { PermRef::new_unchecked(perm, perm_inv) }, ) } #[cfg(test)] mod tests { use super::*; - use crate::partial_pivoting::reconstruct; + use crate::{assert, linalg::lu::partial_pivoting::reconstruct, Mat, MatRef}; use assert_approx_eq::assert_approx_eq; use dyn_stack::GlobalPodBuffer; - use faer_core::{assert, permutation::PermutationRef, Mat, MatRef}; use rand::random; macro_rules! make_stack { @@ -455,7 +460,7 @@ mod tests { fn reconstruct_matrix( lu_factors: MatRef<'_, E>, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, ) -> Mat { let m = lu_factors.nrows(); let n = lu_factors.ncols(); diff --git a/faer-libs/faer-lu/src/partial_pivoting/inverse.rs b/src/linalg/lu/partial_pivoting/inverse.rs similarity index 91% rename from faer-libs/faer-lu/src/partial_pivoting/inverse.rs rename to src/linalg/lu/partial_pivoting/inverse.rs index 2e94082457685630aa8020f8aca9c60deac829c5..231a52135f1ecf8d01139158f3c4cd61a90b55b1 100644 --- a/faer-libs/faer-lu/src/partial_pivoting/inverse.rs +++ b/src/linalg/lu/partial_pivoting/inverse.rs @@ -1,17 +1,18 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, debug_assert, inverse, join_raw, - mul::triangular, - permutation::{permute_cols, Index, PermutationRef}, - temp_mat_req, temp_mat_uninit, ComplexField, Entity, MatMut, MatRef, Parallelism, +use crate::{ + assert, debug_assert, + linalg::{matmul::triangular, temp_mat_req, temp_mat_uninit, triangular_inverse as inverse}, + perm::{permute_cols, PermRef}, + utils::thread::join_raw, + ComplexField, Entity, Index, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; use triangular::BlockStructure; fn invert_impl( dst: MatMut<'_, E>, lu_factors: Option>, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -95,7 +96,7 @@ pub fn invert_req( pub fn invert( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -121,7 +122,7 @@ pub fn invert( #[track_caller] pub fn invert_in_place( lu_factors: MatMut<'_, E>, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -136,9 +137,15 @@ pub fn invert_in_place( #[cfg(test)] mod tests { use super::*; - use crate::partial_pivoting::compute::{lu_in_place, lu_in_place_req}; + use crate::{ + assert, + linalg::{ + lu::partial_pivoting::compute::{lu_in_place, lu_in_place_req}, + matmul::matmul, + }, + Mat, Parallelism, + }; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, mul::matmul, Mat, Parallelism}; use rand::random; macro_rules! make_stack { diff --git a/faer-libs/faer-lu/src/partial_pivoting/mod.rs b/src/linalg/lu/partial_pivoting/mod.rs similarity index 53% rename from faer-libs/faer-lu/src/partial_pivoting/mod.rs rename to src/linalg/lu/partial_pivoting/mod.rs index 97107c5aafabb03c8d8b79b405f21c904db543f0..081d95e93d2da5e41536bd3d2e157a8506e52a95 100644 --- a/faer-libs/faer-lu/src/partial_pivoting/mod.rs +++ b/src/linalg/lu/partial_pivoting/mod.rs @@ -3,7 +3,11 @@ //! where $P$ is a permutation matrix, $L$ is a unit lower triangular matrix, and $U$ is //! an upper triangular matrix. +/// Computing the decomposition. pub mod compute; +/// Reconstructing the inverse of the original matrix from the decomposition. pub mod inverse; +/// Reconstructing the original matrix from the decomposition. pub mod reconstruct; +/// Solving a linear system usin the decomposition. pub mod solve; diff --git a/faer-libs/faer-lu/src/partial_pivoting/reconstruct.rs b/src/linalg/lu/partial_pivoting/reconstruct.rs similarity index 91% rename from faer-libs/faer-lu/src/partial_pivoting/reconstruct.rs rename to src/linalg/lu/partial_pivoting/reconstruct.rs index 07f19017f073f17d10fe67cde4e9fdcb6978f6f5..ee6e0d5b1db4203377ddce9e9c8b7d2faf9c083c 100644 --- a/faer-libs/faer-lu/src/partial_pivoting/reconstruct.rs +++ b/src/linalg/lu/partial_pivoting/reconstruct.rs @@ -1,10 +1,10 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - mul::triangular, - permutation::{Index, PermutationRef}, - temp_mat_req, temp_mat_uninit, ComplexField, Entity, MatMut, MatRef, Parallelism, + linalg::{matmul::triangular, temp_mat_req, temp_mat_uninit}, + perm::PermRef, + ComplexField, Entity, Index, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; use triangular::BlockStructure; @@ -12,7 +12,7 @@ use triangular::BlockStructure; fn reconstruct_impl( dst: MatMut<'_, E>, lu_factors: Option>, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -67,7 +67,7 @@ fn reconstruct_impl( parallelism, ); - faer_core::permutation::permute_rows(dst, lu.rb(), row_perm.inverse()); + crate::perm::permute_rows(dst, lu.rb(), row_perm.inverse()); } /// Computes the reconstructed matrix, given its partial pivoting LU decomposition, @@ -83,7 +83,7 @@ fn reconstruct_impl( pub fn reconstruct( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -103,7 +103,7 @@ pub fn reconstruct( #[track_caller] pub fn reconstruct_in_place( lu_factors: MatMut<'_, E>, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { diff --git a/faer-libs/faer-lu/src/partial_pivoting/solve.rs b/src/linalg/lu/partial_pivoting/solve.rs similarity index 96% rename from faer-libs/faer-lu/src/partial_pivoting/solve.rs rename to src/linalg/lu/partial_pivoting/solve.rs index c1b7c382009f534b969c6d7d1d2e59d766a9dd73..3532591d20a45df0e2896490dba687ea9111a083 100644 --- a/faer-libs/faer-lu/src/partial_pivoting/solve.rs +++ b/src/linalg/lu/partial_pivoting/solve.rs @@ -1,16 +1,15 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - permutation::{permute_rows, Index, PermutationRef}, - solve::*, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, - Parallelism, +use crate::{ + linalg::{temp_mat_req, temp_mat_uninit, triangular_solve::*}, + perm::{permute_rows, PermRef}, + unzipped, zipped, ComplexField, Conj, Entity, Index, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; fn solve_impl( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, dst: MatMut<'_, E>, rhs: Option>, parallelism: Parallelism, @@ -52,7 +51,7 @@ fn solve_impl( fn solve_transpose_impl( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, dst: MatMut<'_, E>, rhs: Option>, parallelism: Parallelism, @@ -168,7 +167,7 @@ pub fn solve( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, rhs: MatRef<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -201,7 +200,7 @@ pub fn solve( pub fn solve_in_place( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, rhs: MatMut<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -236,7 +235,7 @@ pub fn solve_transpose( dst: MatMut<'_, E>, lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, rhs: MatRef<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -269,7 +268,7 @@ pub fn solve_transpose( pub fn solve_transpose_in_place( lu_factors: MatRef<'_, E>, conj_lhs: Conj, - row_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, rhs: MatMut<'_, E>, parallelism: Parallelism, stack: PodStack<'_>, @@ -288,8 +287,15 @@ pub fn solve_transpose_in_place( #[cfg(test)] mod tests { use super::*; - use crate::partial_pivoting::compute::{lu_in_place, lu_in_place_req}; - use faer_core::{assert, c32, c64, mul::matmul_with_conj, Mat}; + use crate::{ + assert, + complex_native::{c32, c64}, + linalg::{ + lu::partial_pivoting::compute::{lu_in_place, lu_in_place_req}, + matmul::matmul_with_conj, + }, + Mat, + }; use std::cell::RefCell; macro_rules! make_stack { diff --git a/src/linalg/mat_ops.rs b/src/linalg/mat_ops.rs new file mode 100644 index 0000000000000000000000000000000000000000..619b37c51b452b6afcba3f40ae849db3548f7f55 --- /dev/null +++ b/src/linalg/mat_ops.rs @@ -0,0 +1,2992 @@ +use crate::{assert, col::*, diag::*, mat::*, perm::*, row::*, sparse::*, *}; +use faer_entity::*; + +use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; + +macro_rules! impl_partial_eq { + ($lhs: ty, $rhs: ty) => { + impl> PartialEq<$rhs> + for $lhs + { + fn eq(&self, other: &$rhs) -> bool { + self.as_ref().eq(&other.as_ref()) + } + } + }; +} + +macro_rules! impl_add_sub { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl, RhsE: Conjugate> + Add<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn add(self, other: $rhs) -> Self::Output { + self.as_ref().add(other.as_ref()) + } + } + + impl, RhsE: Conjugate> + Sub<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn sub(self, other: $rhs) -> Self::Output { + self.as_ref().sub(other.as_ref()) + } + } + }; +} + +macro_rules! impl_add_sub_assign { + ($lhs: ty, $rhs: ty) => { + impl> AddAssign<$rhs> for $lhs { + #[track_caller] + fn add_assign(&mut self, other: $rhs) { + self.as_mut().add_assign(other.as_ref()) + } + } + + impl> SubAssign<$rhs> for $lhs { + #[track_caller] + fn sub_assign(&mut self, other: $rhs) { + self.as_mut().sub_assign(other.as_ref()) + } + } + }; +} + +macro_rules! impl_neg { + ($mat: ty, $out: ty) => { + impl Neg for $mat + where + E::Canonical: ComplexField, + { + type Output = $out; + #[track_caller] + fn neg(self) -> Self::Output { + self.as_ref().neg() + } + } + }; +} + +macro_rules! impl_mul { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl, RhsE: Conjugate> + Mul<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.as_ref().mul(other.as_ref()) + } + } + }; +} + +macro_rules! impl_perm { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl Mul<$rhs> for $lhs + where + E::Canonical: ComplexField, + { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.as_ref().mul(other.as_ref()) + } + } + }; +} + +macro_rules! impl_perm_perm { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl Mul<$rhs> for $lhs { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.as_ref().mul(other.as_ref()) + } + } + }; +} + +macro_rules! impl_scalar_mul { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl, RhsE: Conjugate> + Mul<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.mul(other.as_ref()) + } + } + }; +} + +macro_rules! impl_mul_scalar { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl, RhsE: Conjugate> + Mul<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.as_ref().mul(other) + } + } + }; +} + +macro_rules! impl_mul_assign_scalar { + ($lhs: ty, $rhs: ty) => { + impl> MulAssign<$rhs> for $lhs { + #[track_caller] + fn mul_assign(&mut self, other: $rhs) { + self.as_mut().mul_assign(other) + } + } + }; +} + +macro_rules! impl_sparse_mul { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, + > Mul<$rhs> for $lhs + where + E::Canonical: ComplexField, + { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.as_ref().mul(other.as_ref()) + } + } + }; +} + +macro_rules! impl_partial_eq_sparse { + ($lhs: ty, $rhs: ty) => { + impl> + PartialEq<$rhs> for $lhs + { + fn eq(&self, other: &$rhs) -> bool { + self.as_ref().eq(&other.as_ref()) + } + } + }; +} + +macro_rules! impl_add_sub_sparse { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, + > Add<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn add(self, other: $rhs) -> Self::Output { + self.as_ref().add(other.as_ref()) + } + } + + impl< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, + > Sub<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn sub(self, other: $rhs) -> Self::Output { + self.as_ref().sub(other.as_ref()) + } + } + }; +} + +macro_rules! impl_add_sub_assign_sparse { + ($lhs: ty, $rhs: ty) => { + impl> AddAssign<$rhs> + for $lhs + { + #[track_caller] + fn add_assign(&mut self, other: $rhs) { + self.as_mut().add_assign(other.as_ref()) + } + } + + impl> SubAssign<$rhs> + for $lhs + { + #[track_caller] + fn sub_assign(&mut self, other: $rhs) { + self.as_mut().sub_assign(other.as_ref()) + } + } + }; +} + +macro_rules! impl_neg_sparse { + ($mat: ty, $out: ty) => { + impl Neg for $mat + where + E::Canonical: ComplexField, + { + type Output = $out; + #[track_caller] + fn neg(self) -> Self::Output { + self.as_ref().neg() + } + } + }; +} + +macro_rules! impl_scalar_mul_sparse { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, + > Mul<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.mul(other.as_ref()) + } + } + }; +} + +macro_rules! impl_mul_scalar_sparse { + ($lhs: ty, $rhs: ty, $out: ty) => { + impl< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, + > Mul<$rhs> for $lhs + { + type Output = $out; + #[track_caller] + fn mul(self, other: $rhs) -> Self::Output { + self.as_ref().mul(other) + } + } + }; +} + +impl> PartialEq> + for MatRef<'_, LhsE> +{ + fn eq(&self, other: &MatRef<'_, RhsE>) -> bool { + let lhs = *self; + let rhs = *other; + + if (lhs.nrows(), lhs.ncols()) != (rhs.nrows(), rhs.ncols()) { + return false; + } + let m = lhs.nrows(); + let n = lhs.ncols(); + for j in 0..n { + for i in 0..m { + if !(lhs.read(i, j).canonicalize() == rhs.read(i, j).canonicalize()) { + return false; + } + } + } + + true + } +} + +// impl_partial_eq!(MatRef<'_, LhsE>, MatRef<'_, RhsE>); +impl_partial_eq!(MatRef<'_, LhsE>, MatMut<'_, RhsE>); +impl_partial_eq!(MatRef<'_, LhsE>, Mat); + +impl_partial_eq!(MatMut<'_, LhsE>, MatRef<'_, RhsE>); +impl_partial_eq!(MatMut<'_, LhsE>, MatMut<'_, RhsE>); +impl_partial_eq!(MatMut<'_, LhsE>, Mat); + +impl_partial_eq!(Mat, MatRef<'_, RhsE>); +impl_partial_eq!(Mat, MatMut<'_, RhsE>); +impl_partial_eq!(Mat, Mat); + +impl> PartialEq> + for ColRef<'_, LhsE> +{ + fn eq(&self, other: &ColRef<'_, RhsE>) -> bool { + self.as_2d().eq(&other.as_2d()) + } +} + +// impl_partial_eq!(ColRef<'_, LhsE>, ColRef<'_, RhsE>); +impl_partial_eq!(ColRef<'_, LhsE>, ColMut<'_, RhsE>); +impl_partial_eq!(ColRef<'_, LhsE>, Col); + +impl_partial_eq!(ColMut<'_, LhsE>, ColRef<'_, RhsE>); +impl_partial_eq!(ColMut<'_, LhsE>, ColMut<'_, RhsE>); +impl_partial_eq!(ColMut<'_, LhsE>, Col); + +impl_partial_eq!(Col, ColRef<'_, RhsE>); +impl_partial_eq!(Col, ColMut<'_, RhsE>); +impl_partial_eq!(Col, Col); + +impl> PartialEq> + for RowRef<'_, LhsE> +{ + fn eq(&self, other: &RowRef<'_, RhsE>) -> bool { + self.as_2d().eq(&other.as_2d()) + } +} + +// impl_partial_eq!(RowRef<'_, LhsE>, RowRef<'_, RhsE>); +impl_partial_eq!(RowRef<'_, LhsE>, RowMut<'_, RhsE>); +impl_partial_eq!(RowRef<'_, LhsE>, Row); + +impl_partial_eq!(RowMut<'_, LhsE>, RowRef<'_, RhsE>); +impl_partial_eq!(RowMut<'_, LhsE>, RowMut<'_, RhsE>); +impl_partial_eq!(RowMut<'_, LhsE>, Row); + +impl_partial_eq!(Row, RowRef<'_, RhsE>); +impl_partial_eq!(Row, RowMut<'_, RhsE>); +impl_partial_eq!(Row, Row); + +impl> PartialEq> + for DiagRef<'_, LhsE> +{ + fn eq(&self, other: &DiagRef<'_, RhsE>) -> bool { + self.column_vector().eq(&other.column_vector()) + } +} + +// impl_partial_eq!(DiagRef<'_, LhsE>, DiagRef<'_, RhsE>); +impl_partial_eq!(DiagRef<'_, LhsE>, DiagMut<'_, RhsE>); +impl_partial_eq!(DiagRef<'_, LhsE>, Diag); + +impl_partial_eq!(DiagMut<'_, LhsE>, DiagRef<'_, RhsE>); +impl_partial_eq!(DiagMut<'_, LhsE>, DiagMut<'_, RhsE>); +impl_partial_eq!(DiagMut<'_, LhsE>, Diag); + +impl_partial_eq!(Diag, DiagRef<'_, RhsE>); +impl_partial_eq!(Diag, DiagMut<'_, RhsE>); +impl_partial_eq!(Diag, Diag); + +impl PartialEq> for PermRef<'_, I> { + #[inline] + fn eq(&self, other: &PermRef<'_, I>) -> bool { + self.arrays().0 == other.arrays().0 + } +} +impl PartialEq> for Perm { + #[inline] + fn eq(&self, other: &PermRef<'_, I>) -> bool { + self.as_ref() == other.as_ref() + } +} +impl PartialEq> for PermRef<'_, I> { + #[inline] + fn eq(&self, other: &Perm) -> bool { + self.as_ref() == other.as_ref() + } +} +impl PartialEq> for Perm { + #[inline] + fn eq(&self, other: &Perm) -> bool { + self.as_ref() == other.as_ref() + } +} + +impl, RhsE: Conjugate> + Add> for MatRef<'_, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn add(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + zipped!(self, rhs).map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_add(rhs.read().canonicalize()) + }) + } +} + +impl, RhsE: Conjugate> + Sub> for MatRef<'_, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn sub(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + zipped!(self, rhs).map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_sub(rhs.read().canonicalize()) + }) + } +} + +impl> AddAssign> + for MatMut<'_, LhsE> +{ + #[track_caller] + fn add_assign(&mut self, rhs: MatRef<'_, RhsE>) { + zipped!(self.as_mut(), rhs).for_each(|unzipped!(mut lhs, rhs)| { + lhs.write(lhs.read().faer_add(rhs.read().canonicalize())) + }) + } +} + +impl> SubAssign> + for MatMut<'_, LhsE> +{ + #[track_caller] + fn sub_assign(&mut self, rhs: MatRef<'_, RhsE>) { + zipped!(self.as_mut(), rhs).for_each(|unzipped!(mut lhs, rhs)| { + lhs.write(lhs.read().faer_sub(rhs.read().canonicalize())) + }) + } +} + +impl Neg for MatRef<'_, E> +where + E::Canonical: ComplexField, +{ + type Output = Mat; + + fn neg(self) -> Self::Output { + zipped!(self).map(|unzipped!(x)| x.read().canonicalize().faer_neg()) + } +} + +impl, RhsE: Conjugate> + Add> for ColRef<'_, LhsE> +{ + type Output = Col; + + #[track_caller] + fn add(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + zipped!(self, rhs).map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_add(rhs.read().canonicalize()) + }) + } +} + +impl, RhsE: Conjugate> + Sub> for ColRef<'_, LhsE> +{ + type Output = Col; + + #[track_caller] + fn sub(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + zipped!(self, rhs).map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_sub(rhs.read().canonicalize()) + }) + } +} + +impl> AddAssign> + for ColMut<'_, LhsE> +{ + #[track_caller] + fn add_assign(&mut self, rhs: ColRef<'_, RhsE>) { + zipped!(self.as_mut(), rhs).for_each(|unzipped!(mut lhs, rhs)| { + lhs.write(lhs.read().faer_add(rhs.read().canonicalize())) + }) + } +} + +impl> SubAssign> + for ColMut<'_, LhsE> +{ + #[track_caller] + fn sub_assign(&mut self, rhs: ColRef<'_, RhsE>) { + zipped!(self.as_mut(), rhs).for_each(|unzipped!(mut lhs, rhs)| { + lhs.write(lhs.read().faer_sub(rhs.read().canonicalize())) + }) + } +} + +impl Neg for ColRef<'_, E> +where + E::Canonical: ComplexField, +{ + type Output = Col; + + fn neg(self) -> Self::Output { + zipped!(self).map(|unzipped!(x)| x.read().canonicalize().faer_neg()) + } +} + +impl, RhsE: Conjugate> + Add> for RowRef<'_, LhsE> +{ + type Output = Row; + + #[track_caller] + fn add(self, rhs: RowRef<'_, RhsE>) -> Self::Output { + zipped!(self, rhs).map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_add(rhs.read().canonicalize()) + }) + } +} + +impl, RhsE: Conjugate> + Sub> for RowRef<'_, LhsE> +{ + type Output = Row; + + #[track_caller] + fn sub(self, rhs: RowRef<'_, RhsE>) -> Self::Output { + zipped!(self, rhs).map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_sub(rhs.read().canonicalize()) + }) + } +} + +impl> AddAssign> + for RowMut<'_, LhsE> +{ + #[track_caller] + fn add_assign(&mut self, rhs: RowRef<'_, RhsE>) { + zipped!(self.as_mut(), rhs).for_each(|unzipped!(mut lhs, rhs)| { + lhs.write(lhs.read().faer_add(rhs.read().canonicalize())) + }) + } +} + +impl> SubAssign> + for RowMut<'_, LhsE> +{ + #[track_caller] + fn sub_assign(&mut self, rhs: RowRef<'_, RhsE>) { + zipped!(self.as_mut(), rhs).for_each(|unzipped!(mut lhs, rhs)| { + lhs.write(lhs.read().faer_sub(rhs.read().canonicalize())) + }) + } +} + +impl Neg for RowRef<'_, E> +where + E::Canonical: ComplexField, +{ + type Output = Row; + + fn neg(self) -> Self::Output { + zipped!(self).map(|unzipped!(x)| x.read().canonicalize().faer_neg()) + } +} + +impl, RhsE: Conjugate> + Add> for DiagRef<'_, LhsE> +{ + type Output = Diag; + + #[track_caller] + fn add(self, rhs: DiagRef<'_, RhsE>) -> Self::Output { + zipped!(self.column_vector(), rhs.column_vector()) + .map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_add(rhs.read().canonicalize()) + }) + .column_vector_into_diagonal() + } +} + +impl, RhsE: Conjugate> + Sub> for DiagRef<'_, LhsE> +{ + type Output = Diag; + + #[track_caller] + fn sub(self, rhs: DiagRef<'_, RhsE>) -> Self::Output { + zipped!(self.column_vector(), rhs.column_vector()) + .map(|unzipped!(lhs, rhs)| { + lhs.read() + .canonicalize() + .faer_sub(rhs.read().canonicalize()) + }) + .column_vector_into_diagonal() + } +} + +impl> AddAssign> + for DiagMut<'_, LhsE> +{ + #[track_caller] + fn add_assign(&mut self, rhs: DiagRef<'_, RhsE>) { + zipped!(self.as_mut().column_vector_mut(), rhs.column_vector()).for_each( + |unzipped!(mut lhs, rhs)| lhs.write(lhs.read().faer_add(rhs.read().canonicalize())), + ) + } +} + +impl> SubAssign> + for DiagMut<'_, LhsE> +{ + #[track_caller] + fn sub_assign(&mut self, rhs: DiagRef<'_, RhsE>) { + zipped!(self.as_mut().column_vector_mut(), rhs.column_vector()).for_each( + |unzipped!(mut lhs, rhs)| lhs.write(lhs.read().faer_sub(rhs.read().canonicalize())), + ) + } +} + +impl Neg for DiagRef<'_, E> +where + E::Canonical: ComplexField, +{ + type Output = Diag; + + fn neg(self) -> Self::Output { + zipped!(self.column_vector()) + .map(|unzipped!(x)| x.read().canonicalize().faer_neg()) + .column_vector_into_diagonal() + } +} + +// impl_add_sub!(MatRef<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_add_sub!(MatRef<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_add_sub!(MatRef<'_, LhsE>, Mat, Mat); +impl_add_sub!(MatRef<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_add_sub!(MatRef<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_add_sub!(MatRef<'_, LhsE>, &Mat, Mat); +impl_add_sub!(&MatRef<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_add_sub!(&MatRef<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_add_sub!(&MatRef<'_, LhsE>, Mat, Mat); +impl_add_sub!(&MatRef<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_add_sub!(&MatRef<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_add_sub!(&MatRef<'_, LhsE>, &Mat, Mat); + +impl_add_sub!(MatMut<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_add_sub!(MatMut<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_add_sub!(MatMut<'_, LhsE>, Mat, Mat); +impl_add_sub!(MatMut<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_add_sub!(MatMut<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_add_sub!(MatMut<'_, LhsE>, &Mat, Mat); +impl_add_sub!(&MatMut<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_add_sub!(&MatMut<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_add_sub!(&MatMut<'_, LhsE>, Mat, Mat); +impl_add_sub!(&MatMut<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_add_sub!(&MatMut<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_add_sub!(&MatMut<'_, LhsE>, &Mat, Mat); + +impl_add_sub!(Mat, MatRef<'_, RhsE>, Mat); +impl_add_sub!(Mat, MatMut<'_, RhsE>, Mat); +impl_add_sub!(Mat, Mat, Mat); +impl_add_sub!(Mat, &MatRef<'_, RhsE>, Mat); +impl_add_sub!(Mat, &MatMut<'_, RhsE>, Mat); +impl_add_sub!(Mat, &Mat, Mat); +impl_add_sub!(&Mat, MatRef<'_, RhsE>, Mat); +impl_add_sub!(&Mat, MatMut<'_, RhsE>, Mat); +impl_add_sub!(&Mat, Mat, Mat); +impl_add_sub!(&Mat, &MatRef<'_, RhsE>, Mat); +impl_add_sub!(&Mat, &MatMut<'_, RhsE>, Mat); +impl_add_sub!(&Mat, &Mat, Mat); + +// impl_add_sub_assign!(MatMut<'_, LhsE>, MatRef<'_, RhsE>); +impl_add_sub_assign!(MatMut<'_, LhsE>, MatMut<'_, RhsE>); +impl_add_sub_assign!(MatMut<'_, LhsE>, Mat); +impl_add_sub_assign!(MatMut<'_, LhsE>, &MatRef<'_, RhsE>); +impl_add_sub_assign!(MatMut<'_, LhsE>, &MatMut<'_, RhsE>); +impl_add_sub_assign!(MatMut<'_, LhsE>, &Mat); + +impl_add_sub_assign!(Mat, MatRef<'_, RhsE>); +impl_add_sub_assign!(Mat, MatMut<'_, RhsE>); +impl_add_sub_assign!(Mat, Mat); +impl_add_sub_assign!(Mat, &MatRef<'_, RhsE>); +impl_add_sub_assign!(Mat, &MatMut<'_, RhsE>); +impl_add_sub_assign!(Mat, &Mat); + +// impl_neg!(MatRef<'_, E>, Mat); +impl_neg!(MatMut<'_, E>, Mat); +impl_neg!(Mat, Mat); +impl_neg!(&MatRef<'_, E>, Mat); +impl_neg!(&MatMut<'_, E>, Mat); +impl_neg!(&Mat, Mat); + +// impl_add_sub!(ColRef<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_add_sub!(ColRef<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_add_sub!(ColRef<'_, LhsE>, Col, Col); +impl_add_sub!(ColRef<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_add_sub!(ColRef<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_add_sub!(ColRef<'_, LhsE>, &Col, Col); +impl_add_sub!(&ColRef<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_add_sub!(&ColRef<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_add_sub!(&ColRef<'_, LhsE>, Col, Col); +impl_add_sub!(&ColRef<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_add_sub!(&ColRef<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_add_sub!(&ColRef<'_, LhsE>, &Col, Col); + +impl_add_sub!(ColMut<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_add_sub!(ColMut<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_add_sub!(ColMut<'_, LhsE>, Col, Col); +impl_add_sub!(ColMut<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_add_sub!(ColMut<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_add_sub!(ColMut<'_, LhsE>, &Col, Col); +impl_add_sub!(&ColMut<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_add_sub!(&ColMut<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_add_sub!(&ColMut<'_, LhsE>, Col, Col); +impl_add_sub!(&ColMut<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_add_sub!(&ColMut<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_add_sub!(&ColMut<'_, LhsE>, &Col, Col); + +impl_add_sub!(Col, ColRef<'_, RhsE>, Col); +impl_add_sub!(Col, ColMut<'_, RhsE>, Col); +impl_add_sub!(Col, Col, Col); +impl_add_sub!(Col, &ColRef<'_, RhsE>, Col); +impl_add_sub!(Col, &ColMut<'_, RhsE>, Col); +impl_add_sub!(Col, &Col, Col); +impl_add_sub!(&Col, ColRef<'_, RhsE>, Col); +impl_add_sub!(&Col, ColMut<'_, RhsE>, Col); +impl_add_sub!(&Col, Col, Col); +impl_add_sub!(&Col, &ColRef<'_, RhsE>, Col); +impl_add_sub!(&Col, &ColMut<'_, RhsE>, Col); +impl_add_sub!(&Col, &Col, Col); + +// impl_add_sub_assign!(ColMut<'_, LhsE>, ColRef<'_, RhsE>); +impl_add_sub_assign!(ColMut<'_, LhsE>, ColMut<'_, RhsE>); +impl_add_sub_assign!(ColMut<'_, LhsE>, Col); +impl_add_sub_assign!(ColMut<'_, LhsE>, &ColRef<'_, RhsE>); +impl_add_sub_assign!(ColMut<'_, LhsE>, &ColMut<'_, RhsE>); +impl_add_sub_assign!(ColMut<'_, LhsE>, &Col); + +impl_add_sub_assign!(Col, ColRef<'_, RhsE>); +impl_add_sub_assign!(Col, ColMut<'_, RhsE>); +impl_add_sub_assign!(Col, Col); +impl_add_sub_assign!(Col, &ColRef<'_, RhsE>); +impl_add_sub_assign!(Col, &ColMut<'_, RhsE>); +impl_add_sub_assign!(Col, &Col); + +// impl_neg!(ColRef<'_, E>, Col); +impl_neg!(ColMut<'_, E>, Col); +impl_neg!(Col, Col); +impl_neg!(&ColRef<'_, E>, Col); +impl_neg!(&ColMut<'_, E>, Col); +impl_neg!(&Col, Col); + +// impl_add_sub!(RowRef<'_, LhsE>, RowRef<'_, RhsE>, Row); +impl_add_sub!(RowRef<'_, LhsE>, RowMut<'_, RhsE>, Row); +impl_add_sub!(RowRef<'_, LhsE>, Row, Row); +impl_add_sub!(RowRef<'_, LhsE>, &RowRef<'_, RhsE>, Row); +impl_add_sub!(RowRef<'_, LhsE>, &RowMut<'_, RhsE>, Row); +impl_add_sub!(RowRef<'_, LhsE>, &Row, Row); +impl_add_sub!(&RowRef<'_, LhsE>, RowRef<'_, RhsE>, Row); +impl_add_sub!(&RowRef<'_, LhsE>, RowMut<'_, RhsE>, Row); +impl_add_sub!(&RowRef<'_, LhsE>, Row, Row); +impl_add_sub!(&RowRef<'_, LhsE>, &RowRef<'_, RhsE>, Row); +impl_add_sub!(&RowRef<'_, LhsE>, &RowMut<'_, RhsE>, Row); +impl_add_sub!(&RowRef<'_, LhsE>, &Row, Row); + +impl_add_sub!(RowMut<'_, LhsE>, RowRef<'_, RhsE>, Row); +impl_add_sub!(RowMut<'_, LhsE>, RowMut<'_, RhsE>, Row); +impl_add_sub!(RowMut<'_, LhsE>, Row, Row); +impl_add_sub!(RowMut<'_, LhsE>, &RowRef<'_, RhsE>, Row); +impl_add_sub!(RowMut<'_, LhsE>, &RowMut<'_, RhsE>, Row); +impl_add_sub!(RowMut<'_, LhsE>, &Row, Row); +impl_add_sub!(&RowMut<'_, LhsE>, RowRef<'_, RhsE>, Row); +impl_add_sub!(&RowMut<'_, LhsE>, RowMut<'_, RhsE>, Row); +impl_add_sub!(&RowMut<'_, LhsE>, Row, Row); +impl_add_sub!(&RowMut<'_, LhsE>, &RowRef<'_, RhsE>, Row); +impl_add_sub!(&RowMut<'_, LhsE>, &RowMut<'_, RhsE>, Row); +impl_add_sub!(&RowMut<'_, LhsE>, &Row, Row); + +impl_add_sub!(Row, RowRef<'_, RhsE>, Row); +impl_add_sub!(Row, RowMut<'_, RhsE>, Row); +impl_add_sub!(Row, Row, Row); +impl_add_sub!(Row, &RowRef<'_, RhsE>, Row); +impl_add_sub!(Row, &RowMut<'_, RhsE>, Row); +impl_add_sub!(Row, &Row, Row); +impl_add_sub!(&Row, RowRef<'_, RhsE>, Row); +impl_add_sub!(&Row, RowMut<'_, RhsE>, Row); +impl_add_sub!(&Row, Row, Row); +impl_add_sub!(&Row, &RowRef<'_, RhsE>, Row); +impl_add_sub!(&Row, &RowMut<'_, RhsE>, Row); +impl_add_sub!(&Row, &Row, Row); + +// impl_add_sub_assign!(RowMut<'_, LhsE>, RowRef<'_, RhsE>); +impl_add_sub_assign!(RowMut<'_, LhsE>, RowMut<'_, RhsE>); +impl_add_sub_assign!(RowMut<'_, LhsE>, Row); +impl_add_sub_assign!(RowMut<'_, LhsE>, &RowRef<'_, RhsE>); +impl_add_sub_assign!(RowMut<'_, LhsE>, &RowMut<'_, RhsE>); +impl_add_sub_assign!(RowMut<'_, LhsE>, &Row); + +impl_add_sub_assign!(Row, RowRef<'_, RhsE>); +impl_add_sub_assign!(Row, RowMut<'_, RhsE>); +impl_add_sub_assign!(Row, Row); +impl_add_sub_assign!(Row, &RowRef<'_, RhsE>); +impl_add_sub_assign!(Row, &RowMut<'_, RhsE>); +impl_add_sub_assign!(Row, &Row); + +// impl_neg!(RowRef<'_, E>, Row); +impl_neg!(RowMut<'_, E>, Row); +impl_neg!(Row, Row); +impl_neg!(&RowRef<'_, E>, Row); +impl_neg!(&RowMut<'_, E>, Row); +impl_neg!(&Row, Row); + +// impl_add_sub!(DiagRef<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_add_sub!(DiagRef<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_add_sub!(DiagRef<'_, LhsE>, Diag, Diag); +impl_add_sub!(DiagRef<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_add_sub!(DiagRef<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_add_sub!(DiagRef<'_, LhsE>, &Diag, Diag); +impl_add_sub!(&DiagRef<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_add_sub!(&DiagRef<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_add_sub!(&DiagRef<'_, LhsE>, Diag, Diag); +impl_add_sub!(&DiagRef<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_add_sub!(&DiagRef<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_add_sub!(&DiagRef<'_, LhsE>, &Diag, Diag); + +impl_add_sub!(DiagMut<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_add_sub!(DiagMut<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_add_sub!(DiagMut<'_, LhsE>, Diag, Diag); +impl_add_sub!(DiagMut<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_add_sub!(DiagMut<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_add_sub!(DiagMut<'_, LhsE>, &Diag, Diag); +impl_add_sub!(&DiagMut<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_add_sub!(&DiagMut<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_add_sub!(&DiagMut<'_, LhsE>, Diag, Diag); +impl_add_sub!(&DiagMut<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_add_sub!(&DiagMut<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_add_sub!(&DiagMut<'_, LhsE>, &Diag, Diag); + +impl_add_sub!(Diag, DiagRef<'_, RhsE>, Diag); +impl_add_sub!(Diag, DiagMut<'_, RhsE>, Diag); +impl_add_sub!(Diag, Diag, Diag); +impl_add_sub!(Diag, &DiagRef<'_, RhsE>, Diag); +impl_add_sub!(Diag, &DiagMut<'_, RhsE>, Diag); +impl_add_sub!(Diag, &Diag, Diag); +impl_add_sub!(&Diag, DiagRef<'_, RhsE>, Diag); +impl_add_sub!(&Diag, DiagMut<'_, RhsE>, Diag); +impl_add_sub!(&Diag, Diag, Diag); +impl_add_sub!(&Diag, &DiagRef<'_, RhsE>, Diag); +impl_add_sub!(&Diag, &DiagMut<'_, RhsE>, Diag); +impl_add_sub!(&Diag, &Diag, Diag); + +// impl_add_sub_assign!(DiagMut<'_, LhsE>, DiagRef<'_, RhsE>); +impl_add_sub_assign!(DiagMut<'_, LhsE>, DiagMut<'_, RhsE>); +impl_add_sub_assign!(DiagMut<'_, LhsE>, Diag); +impl_add_sub_assign!(DiagMut<'_, LhsE>, &DiagRef<'_, RhsE>); +impl_add_sub_assign!(DiagMut<'_, LhsE>, &DiagMut<'_, RhsE>); +impl_add_sub_assign!(DiagMut<'_, LhsE>, &Diag); + +impl_add_sub_assign!(Diag, DiagRef<'_, RhsE>); +impl_add_sub_assign!(Diag, DiagMut<'_, RhsE>); +impl_add_sub_assign!(Diag, Diag); +impl_add_sub_assign!(Diag, &DiagRef<'_, RhsE>); +impl_add_sub_assign!(Diag, &DiagMut<'_, RhsE>); +impl_add_sub_assign!(Diag, &Diag); + +// impl_neg!(DiagRef<'_, E>, Diag); +impl_neg!(DiagMut<'_, E>, Diag); +impl_neg!(Diag, Diag); +impl_neg!(&DiagRef<'_, E>, Diag); +impl_neg!(&DiagMut<'_, E>, Diag); +impl_neg!(&Diag, Diag); + +impl, RhsE: Conjugate> + Mul> for Scale +{ + type Output = Scale; + + #[inline] + fn mul(self, rhs: Scale) -> Self::Output { + Scale(self.0.canonicalize().faer_mul(rhs.0.canonicalize())) + } +} + +impl, RhsE: Conjugate> + Add> for Scale +{ + type Output = Scale; + + #[inline] + fn add(self, rhs: Scale) -> Self::Output { + Scale(self.0.canonicalize().faer_add(rhs.0.canonicalize())) + } +} + +impl, RhsE: Conjugate> + Sub> for Scale +{ + type Output = Scale; + + #[inline] + fn sub(self, rhs: Scale) -> Self::Output { + Scale(self.0.canonicalize().faer_add(rhs.0.canonicalize())) + } +} + +impl> MulAssign> for Scale { + #[inline] + fn mul_assign(&mut self, rhs: Scale) { + self.0 = self.0.faer_mul(rhs.0.canonicalize()) + } +} + +impl> AddAssign> for Scale { + #[inline] + fn add_assign(&mut self, rhs: Scale) { + self.0 = self.0.faer_add(rhs.0.canonicalize()) + } +} + +impl> SubAssign> for Scale { + #[inline] + fn sub_assign(&mut self, rhs: Scale) { + self.0 = self.0.faer_sub(rhs.0.canonicalize()) + } +} + +impl, RhsE: Conjugate> + Mul> for MatRef<'_, LhsE> +{ + type Output = Mat; + + #[inline] + #[track_caller] + fn mul(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + let lhs = self; + assert!(lhs.ncols() == rhs.nrows()); + let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); + crate::linalg::matmul::matmul( + out.as_mut(), + lhs, + rhs, + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for MatRef<'_, LhsE> +{ + type Output = Col; + + #[inline] + #[track_caller] + fn mul(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + let lhs = self; + assert!(lhs.ncols() == rhs.nrows()); + let mut out = Col::zeros(lhs.nrows()); + crate::linalg::matmul::matmul( + out.as_mut().as_2d_mut(), + lhs, + rhs.as_2d(), + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for RowRef<'_, LhsE> +{ + type Output = Row; + + #[inline] + #[track_caller] + fn mul(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + let lhs = self; + assert!(lhs.ncols() == rhs.nrows()); + let mut out = Row::zeros(rhs.ncols()); + crate::linalg::matmul::matmul( + out.as_mut().as_2d_mut(), + lhs.as_2d(), + rhs, + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for RowRef<'_, LhsE> +{ + type Output = E; + + #[inline] + #[track_caller] + fn mul(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + let lhs = self; + assert!(lhs.ncols() == rhs.nrows()); + let (lhs, conj_lhs) = lhs.as_2d().transpose().canonicalize(); + let (rhs, conj_rhs) = rhs.as_2d().canonicalize(); + crate::linalg::matmul::inner_prod::inner_prod_with_conj(lhs, conj_lhs, rhs, conj_rhs) + } +} + +impl, RhsE: Conjugate> + Mul> for ColRef<'_, LhsE> +{ + type Output = Mat; + + #[inline] + #[track_caller] + fn mul(self, rhs: RowRef<'_, RhsE>) -> Self::Output { + let lhs = self; + assert!(lhs.ncols() == rhs.nrows()); + let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); + crate::linalg::matmul::matmul( + out.as_mut(), + lhs.as_2d(), + rhs.as_2d(), + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +// impl_mul!(MatRef<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, Mat, Mat); +impl_mul!(MatRef<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, &Mat, Mat); +impl_mul!(&MatRef<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, Mat, Mat); +impl_mul!(&MatRef<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, &Mat, Mat); + +impl_mul!(MatMut<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, Mat, Mat); +impl_mul!(MatMut<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, &Mat, Mat); +impl_mul!(&MatMut<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, Mat, Mat); +impl_mul!(&MatMut<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, &Mat, Mat); + +impl_mul!(Mat, MatRef<'_, RhsE>, Mat); +impl_mul!(Mat, MatMut<'_, RhsE>, Mat); +impl_mul!(Mat, Mat, Mat); +impl_mul!(Mat, &MatRef<'_, RhsE>, Mat); +impl_mul!(Mat, &MatMut<'_, RhsE>, Mat); +impl_mul!(Mat, &Mat, Mat); +impl_mul!(&Mat, MatRef<'_, RhsE>, Mat); +impl_mul!(&Mat, MatMut<'_, RhsE>, Mat); +impl_mul!(&Mat, Mat, Mat); +impl_mul!(&Mat, &MatRef<'_, RhsE>, Mat); +impl_mul!(&Mat, &MatMut<'_, RhsE>, Mat); +impl_mul!(&Mat, &Mat, Mat); + +// impl_mul!(MatRef<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(MatRef<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(MatRef<'_, LhsE>, Col, Col); +impl_mul!(MatRef<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(MatRef<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(MatRef<'_, LhsE>, &Col, Col); +impl_mul!(&MatRef<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(&MatRef<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(&MatRef<'_, LhsE>, Col, Col); +impl_mul!(&MatRef<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(&MatRef<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(&MatRef<'_, LhsE>, &Col, Col); + +impl_mul!(MatMut<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(MatMut<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(MatMut<'_, LhsE>, Col, Col); +impl_mul!(MatMut<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(MatMut<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(MatMut<'_, LhsE>, &Col, Col); +impl_mul!(&MatMut<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(&MatMut<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(&MatMut<'_, LhsE>, Col, Col); +impl_mul!(&MatMut<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(&MatMut<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(&MatMut<'_, LhsE>, &Col, Col); + +impl_mul!(Mat, ColRef<'_, RhsE>, Col); +impl_mul!(Mat, ColMut<'_, RhsE>, Col); +impl_mul!(Mat, Col, Col); +impl_mul!(Mat, &ColRef<'_, RhsE>, Col); +impl_mul!(Mat, &ColMut<'_, RhsE>, Col); +impl_mul!(Mat, &Col, Col); +impl_mul!(&Mat, ColRef<'_, RhsE>, Col); +impl_mul!(&Mat, ColMut<'_, RhsE>, Col); +impl_mul!(&Mat, Col, Col); +impl_mul!(&Mat, &ColRef<'_, RhsE>, Col); +impl_mul!(&Mat, &ColMut<'_, RhsE>, Col); +impl_mul!(&Mat, &Col, Col); + +// impl_mul!(RowRef<'_, LhsE>, MatRef<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, MatMut<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, Mat, Row); +impl_mul!(RowRef<'_, LhsE>, &MatRef<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, &MatMut<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, &Mat, Row); +impl_mul!(&RowRef<'_, LhsE>, MatRef<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, MatMut<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, Mat, Row); +impl_mul!(&RowRef<'_, LhsE>, &MatRef<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, &MatMut<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, &Mat, Row); + +impl_mul!(RowMut<'_, LhsE>, MatRef<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, MatMut<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, Mat, Row); +impl_mul!(RowMut<'_, LhsE>, &MatRef<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, &MatMut<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, &Mat, Row); +impl_mul!(&RowMut<'_, LhsE>, MatRef<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, MatMut<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, Mat, Row); +impl_mul!(&RowMut<'_, LhsE>, &MatRef<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, &MatMut<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, &Mat, Row); + +impl_mul!(Row, MatRef<'_, RhsE>, Row); +impl_mul!(Row, MatMut<'_, RhsE>, Row); +impl_mul!(Row, Mat, Row); +impl_mul!(Row, &MatRef<'_, RhsE>, Row); +impl_mul!(Row, &MatMut<'_, RhsE>, Row); +impl_mul!(Row, &Mat, Row); +impl_mul!(&Row, MatRef<'_, RhsE>, Row); +impl_mul!(&Row, MatMut<'_, RhsE>, Row); +impl_mul!(&Row, Mat, Row); +impl_mul!(&Row, &MatRef<'_, RhsE>, Row); +impl_mul!(&Row, &MatMut<'_, RhsE>, Row); +impl_mul!(&Row, &Mat, Row); + +// impl_mul!(RowRef<'_, LhsE>, ColRef<'_, RhsE>, E); +impl_mul!(RowRef<'_, LhsE>, ColMut<'_, RhsE>, E); +impl_mul!(RowRef<'_, LhsE>, Col, E); +impl_mul!(RowRef<'_, LhsE>, &ColRef<'_, RhsE>, E); +impl_mul!(RowRef<'_, LhsE>, &ColMut<'_, RhsE>, E); +impl_mul!(RowRef<'_, LhsE>, &Col, E); +impl_mul!(&RowRef<'_, LhsE>, ColRef<'_, RhsE>, E); +impl_mul!(&RowRef<'_, LhsE>, ColMut<'_, RhsE>, E); +impl_mul!(&RowRef<'_, LhsE>, Col, E); +impl_mul!(&RowRef<'_, LhsE>, &ColRef<'_, RhsE>, E); +impl_mul!(&RowRef<'_, LhsE>, &ColMut<'_, RhsE>, E); +impl_mul!(&RowRef<'_, LhsE>, &Col, E); + +impl_mul!(RowMut<'_, LhsE>, ColRef<'_, RhsE>, E); +impl_mul!(RowMut<'_, LhsE>, ColMut<'_, RhsE>, E); +impl_mul!(RowMut<'_, LhsE>, Col, E); +impl_mul!(RowMut<'_, LhsE>, &ColRef<'_, RhsE>, E); +impl_mul!(RowMut<'_, LhsE>, &ColMut<'_, RhsE>, E); +impl_mul!(RowMut<'_, LhsE>, &Col, E); +impl_mul!(&RowMut<'_, LhsE>, ColRef<'_, RhsE>, E); +impl_mul!(&RowMut<'_, LhsE>, ColMut<'_, RhsE>, E); +impl_mul!(&RowMut<'_, LhsE>, Col, E); +impl_mul!(&RowMut<'_, LhsE>, &ColRef<'_, RhsE>, E); +impl_mul!(&RowMut<'_, LhsE>, &ColMut<'_, RhsE>, E); +impl_mul!(&RowMut<'_, LhsE>, &Col, E); + +impl_mul!(Row, ColRef<'_, RhsE>, E); +impl_mul!(Row, ColMut<'_, RhsE>, E); +impl_mul!(Row, Col, E); +impl_mul!(Row, &ColRef<'_, RhsE>, E); +impl_mul!(Row, &ColMut<'_, RhsE>, E); +impl_mul!(Row, &Col, E); +impl_mul!(&Row, ColRef<'_, RhsE>, E); +impl_mul!(&Row, ColMut<'_, RhsE>, E); +impl_mul!(&Row, Col, E); +impl_mul!(&Row, &ColRef<'_, RhsE>, E); +impl_mul!(&Row, &ColMut<'_, RhsE>, E); +impl_mul!(&Row, &Col, E); + +// impl_mul!(ColRef<'_, LhsE>, RowRef<'_, RhsE>, Mat); +impl_mul!(ColRef<'_, LhsE>, RowMut<'_, RhsE>, Mat); +impl_mul!(ColRef<'_, LhsE>, Row, Mat); +impl_mul!(ColRef<'_, LhsE>, &RowRef<'_, RhsE>, Mat); +impl_mul!(ColRef<'_, LhsE>, &RowMut<'_, RhsE>, Mat); +impl_mul!(ColRef<'_, LhsE>, &Row, Mat); +impl_mul!(&ColRef<'_, LhsE>, RowRef<'_, RhsE>, Mat); +impl_mul!(&ColRef<'_, LhsE>, RowMut<'_, RhsE>, Mat); +impl_mul!(&ColRef<'_, LhsE>, Row, Mat); +impl_mul!(&ColRef<'_, LhsE>, &RowRef<'_, RhsE>, Mat); +impl_mul!(&ColRef<'_, LhsE>, &RowMut<'_, RhsE>, Mat); +impl_mul!(&ColRef<'_, LhsE>, &Row, Mat); + +impl_mul!(ColMut<'_, LhsE>, RowRef<'_, RhsE>, Mat); +impl_mul!(ColMut<'_, LhsE>, RowMut<'_, RhsE>, Mat); +impl_mul!(ColMut<'_, LhsE>, Row, Mat); +impl_mul!(ColMut<'_, LhsE>, &RowRef<'_, RhsE>, Mat); +impl_mul!(ColMut<'_, LhsE>, &RowMut<'_, RhsE>, Mat); +impl_mul!(ColMut<'_, LhsE>, &Row, Mat); +impl_mul!(&ColMut<'_, LhsE>, RowRef<'_, RhsE>, Mat); +impl_mul!(&ColMut<'_, LhsE>, RowMut<'_, RhsE>, Mat); +impl_mul!(&ColMut<'_, LhsE>, Row, Mat); +impl_mul!(&ColMut<'_, LhsE>, &RowRef<'_, RhsE>, Mat); +impl_mul!(&ColMut<'_, LhsE>, &RowMut<'_, RhsE>, Mat); +impl_mul!(&ColMut<'_, LhsE>, &Row, Mat); + +impl_mul!(Col, RowRef<'_, RhsE>, Mat); +impl_mul!(Col, RowMut<'_, RhsE>, Mat); +impl_mul!(Col, Row, Mat); +impl_mul!(Col, &RowRef<'_, RhsE>, Mat); +impl_mul!(Col, &RowMut<'_, RhsE>, Mat); +impl_mul!(Col, &Row, Mat); +impl_mul!(&Col, RowRef<'_, RhsE>, Mat); +impl_mul!(&Col, RowMut<'_, RhsE>, Mat); +impl_mul!(&Col, Row, Mat); +impl_mul!(&Col, &RowRef<'_, RhsE>, Mat); +impl_mul!(&Col, &RowMut<'_, RhsE>, Mat); +impl_mul!(&Col, &Row, Mat); + +impl, RhsE: Conjugate> + Mul> for DiagRef<'_, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + let lhs = self.column_vector(); + let lhs_dim = lhs.nrows(); + let rhs_nrows = rhs.nrows(); + assert!(lhs_dim == rhs_nrows); + + Mat::from_fn(rhs.nrows(), rhs.ncols(), |i, j| unsafe { + E::faer_mul( + lhs.read_unchecked(i).canonicalize(), + rhs.read_unchecked(i, j).canonicalize(), + ) + }) + } +} + +// impl_mul!(DiagRef<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(DiagRef<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(DiagRef<'_, LhsE>, Mat, Mat); +impl_mul!(DiagRef<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(DiagRef<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(DiagRef<'_, LhsE>, &Mat, Mat); +impl_mul!(&DiagRef<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(&DiagRef<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(&DiagRef<'_, LhsE>, Mat, Mat); +impl_mul!(&DiagRef<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(&DiagRef<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(&DiagRef<'_, LhsE>, &Mat, Mat); + +impl_mul!(DiagMut<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(DiagMut<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(DiagMut<'_, LhsE>, Mat, Mat); +impl_mul!(DiagMut<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(DiagMut<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(DiagMut<'_, LhsE>, &Mat, Mat); +impl_mul!(&DiagMut<'_, LhsE>, MatRef<'_, RhsE>, Mat); +impl_mul!(&DiagMut<'_, LhsE>, MatMut<'_, RhsE>, Mat); +impl_mul!(&DiagMut<'_, LhsE>, Mat, Mat); +impl_mul!(&DiagMut<'_, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_mul!(&DiagMut<'_, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_mul!(&DiagMut<'_, LhsE>, &Mat, Mat); + +impl_mul!(Diag, MatRef<'_, RhsE>, Mat); +impl_mul!(Diag, MatMut<'_, RhsE>, Mat); +impl_mul!(Diag, Mat, Mat); +impl_mul!(Diag, &MatRef<'_, RhsE>, Mat); +impl_mul!(Diag, &MatMut<'_, RhsE>, Mat); +impl_mul!(Diag, &Mat, Mat); +impl_mul!(&Diag, MatRef<'_, RhsE>, Mat); +impl_mul!(&Diag, MatMut<'_, RhsE>, Mat); +impl_mul!(&Diag, Mat, Mat); +impl_mul!(&Diag, &MatRef<'_, RhsE>, Mat); +impl_mul!(&Diag, &MatMut<'_, RhsE>, Mat); +impl_mul!(&Diag, &Mat, Mat); + +impl, RhsE: Conjugate> + Mul> for DiagRef<'_, LhsE> +{ + type Output = Col; + + #[track_caller] + fn mul(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + let lhs = self.column_vector(); + let lhs_dim = lhs.nrows(); + let rhs_nrows = rhs.nrows(); + assert!(lhs_dim == rhs_nrows); + + Col::from_fn(rhs.nrows(), |i| unsafe { + E::faer_mul( + lhs.read_unchecked(i).canonicalize(), + rhs.read_unchecked(i).canonicalize(), + ) + }) + } +} + +// impl_mul!(DiagRef<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(DiagRef<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(DiagRef<'_, LhsE>, Col, Col); +impl_mul!(DiagRef<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(DiagRef<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(DiagRef<'_, LhsE>, &Col, Col); +impl_mul!(&DiagRef<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(&DiagRef<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(&DiagRef<'_, LhsE>, Col, Col); +impl_mul!(&DiagRef<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(&DiagRef<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(&DiagRef<'_, LhsE>, &Col, Col); + +impl_mul!(DiagMut<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(DiagMut<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(DiagMut<'_, LhsE>, Col, Col); +impl_mul!(DiagMut<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(DiagMut<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(DiagMut<'_, LhsE>, &Col, Col); +impl_mul!(&DiagMut<'_, LhsE>, ColRef<'_, RhsE>, Col); +impl_mul!(&DiagMut<'_, LhsE>, ColMut<'_, RhsE>, Col); +impl_mul!(&DiagMut<'_, LhsE>, Col, Col); +impl_mul!(&DiagMut<'_, LhsE>, &ColRef<'_, RhsE>, Col); +impl_mul!(&DiagMut<'_, LhsE>, &ColMut<'_, RhsE>, Col); +impl_mul!(&DiagMut<'_, LhsE>, &Col, Col); + +impl_mul!(Diag, ColRef<'_, RhsE>, Col); +impl_mul!(Diag, ColMut<'_, RhsE>, Col); +impl_mul!(Diag, Col, Col); +impl_mul!(Diag, &ColRef<'_, RhsE>, Col); +impl_mul!(Diag, &ColMut<'_, RhsE>, Col); +impl_mul!(Diag, &Col, Col); +impl_mul!(&Diag, ColRef<'_, RhsE>, Col); +impl_mul!(&Diag, ColMut<'_, RhsE>, Col); +impl_mul!(&Diag, Col, Col); +impl_mul!(&Diag, &ColRef<'_, RhsE>, Col); +impl_mul!(&Diag, &ColMut<'_, RhsE>, Col); +impl_mul!(&Diag, &Col, Col); + +impl, RhsE: Conjugate> + Mul> for MatRef<'_, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: DiagRef<'_, RhsE>) -> Self::Output { + let lhs = self; + let rhs = rhs.column_vector(); + let lhs_ncols = lhs.ncols(); + let rhs_dim = rhs.nrows(); + assert!(lhs_ncols == rhs_dim); + + Mat::from_fn(lhs.nrows(), lhs.ncols(), |i, j| unsafe { + E::faer_mul( + lhs.read_unchecked(i, j).canonicalize(), + rhs.read_unchecked(j).canonicalize(), + ) + }) + } +} + +// impl_mul!(MatRef<'_, LhsE>, DiagRef<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, DiagMut<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, Diag, Mat); +impl_mul!(MatRef<'_, LhsE>, &DiagRef<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, &DiagMut<'_, RhsE>, Mat); +impl_mul!(MatRef<'_, LhsE>, &Diag, Mat); +impl_mul!(&MatRef<'_, LhsE>, DiagRef<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, DiagMut<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, Diag, Mat); +impl_mul!(&MatRef<'_, LhsE>, &DiagRef<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, &DiagMut<'_, RhsE>, Mat); +impl_mul!(&MatRef<'_, LhsE>, &Diag, Mat); + +impl_mul!(MatMut<'_, LhsE>, DiagRef<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, DiagMut<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, Diag, Mat); +impl_mul!(MatMut<'_, LhsE>, &DiagRef<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, &DiagMut<'_, RhsE>, Mat); +impl_mul!(MatMut<'_, LhsE>, &Diag, Mat); +impl_mul!(&MatMut<'_, LhsE>, DiagRef<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, DiagMut<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, Diag, Mat); +impl_mul!(&MatMut<'_, LhsE>, &DiagRef<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, &DiagMut<'_, RhsE>, Mat); +impl_mul!(&MatMut<'_, LhsE>, &Diag, Mat); + +impl_mul!(Mat, DiagRef<'_, RhsE>, Mat); +impl_mul!(Mat, DiagMut<'_, RhsE>, Mat); +impl_mul!(Mat, Diag, Mat); +impl_mul!(Mat, &DiagRef<'_, RhsE>, Mat); +impl_mul!(Mat, &DiagMut<'_, RhsE>, Mat); +impl_mul!(Mat, &Diag, Mat); +impl_mul!(&Mat, DiagRef<'_, RhsE>, Mat); +impl_mul!(&Mat, DiagMut<'_, RhsE>, Mat); +impl_mul!(&Mat, Diag, Mat); +impl_mul!(&Mat, &DiagRef<'_, RhsE>, Mat); +impl_mul!(&Mat, &DiagMut<'_, RhsE>, Mat); +impl_mul!(&Mat, &Diag, Mat); + +impl, RhsE: Conjugate> + Mul> for RowRef<'_, LhsE> +{ + type Output = Row; + + #[track_caller] + fn mul(self, rhs: DiagRef<'_, RhsE>) -> Self::Output { + let lhs = self; + let rhs = rhs.column_vector(); + let lhs_ncols = lhs.ncols(); + let rhs_dim = rhs.nrows(); + assert!(lhs_ncols == rhs_dim); + + Row::from_fn(lhs.ncols(), |j| unsafe { + E::faer_mul( + lhs.read_unchecked(j).canonicalize(), + rhs.read_unchecked(j).canonicalize(), + ) + }) + } +} + +// impl_mul!(RowRef<'_, LhsE>, DiagRef<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, DiagMut<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, Diag, Row); +impl_mul!(RowRef<'_, LhsE>, &DiagRef<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, &DiagMut<'_, RhsE>, Row); +impl_mul!(RowRef<'_, LhsE>, &Diag, Row); +impl_mul!(&RowRef<'_, LhsE>, DiagRef<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, DiagMut<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, Diag, Row); +impl_mul!(&RowRef<'_, LhsE>, &DiagRef<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, &DiagMut<'_, RhsE>, Row); +impl_mul!(&RowRef<'_, LhsE>, &Diag, Row); + +impl_mul!(RowMut<'_, LhsE>, DiagRef<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, DiagMut<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, Diag, Row); +impl_mul!(RowMut<'_, LhsE>, &DiagRef<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, &DiagMut<'_, RhsE>, Row); +impl_mul!(RowMut<'_, LhsE>, &Diag, Row); +impl_mul!(&RowMut<'_, LhsE>, DiagRef<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, DiagMut<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, Diag, Row); +impl_mul!(&RowMut<'_, LhsE>, &DiagRef<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, &DiagMut<'_, RhsE>, Row); +impl_mul!(&RowMut<'_, LhsE>, &Diag, Row); + +impl_mul!(Row, DiagRef<'_, RhsE>, Row); +impl_mul!(Row, DiagMut<'_, RhsE>, Row); +impl_mul!(Row, Diag, Row); +impl_mul!(Row, &DiagRef<'_, RhsE>, Row); +impl_mul!(Row, &DiagMut<'_, RhsE>, Row); +impl_mul!(Row, &Diag, Row); +impl_mul!(&Row, DiagRef<'_, RhsE>, Row); +impl_mul!(&Row, DiagMut<'_, RhsE>, Row); +impl_mul!(&Row, Diag, Row); +impl_mul!(&Row, &DiagRef<'_, RhsE>, Row); +impl_mul!(&Row, &DiagMut<'_, RhsE>, Row); +impl_mul!(&Row, &Diag, Row); + +impl, RhsE: Conjugate> + Mul> for DiagRef<'_, LhsE> +{ + type Output = Diag; + + #[track_caller] + fn mul(self, rhs: DiagRef<'_, RhsE>) -> Self::Output { + let lhs = self.column_vector(); + let rhs = rhs.column_vector(); + assert!(lhs.nrows() == rhs.nrows()); + + Col::from_fn(lhs.nrows(), |i| unsafe { + E::faer_mul( + lhs.read_unchecked(i).canonicalize(), + rhs.read_unchecked(i).canonicalize(), + ) + }) + .column_vector_into_diagonal() + } +} + +// impl_mul!(DiagRef<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_mul!(DiagRef<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_mul!(DiagRef<'_, LhsE>, Diag, Diag); +impl_mul!(DiagRef<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_mul!(DiagRef<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_mul!(DiagRef<'_, LhsE>, &Diag, Diag); +impl_mul!(&DiagRef<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_mul!(&DiagRef<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_mul!(&DiagRef<'_, LhsE>, Diag, Diag); +impl_mul!(&DiagRef<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_mul!(&DiagRef<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_mul!(&DiagRef<'_, LhsE>, &Diag, Diag); + +impl_mul!(DiagMut<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_mul!(DiagMut<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_mul!(DiagMut<'_, LhsE>, Diag, Diag); +impl_mul!(DiagMut<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_mul!(DiagMut<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_mul!(DiagMut<'_, LhsE>, &Diag, Diag); +impl_mul!(&DiagMut<'_, LhsE>, DiagRef<'_, RhsE>, Diag); +impl_mul!(&DiagMut<'_, LhsE>, DiagMut<'_, RhsE>, Diag); +impl_mul!(&DiagMut<'_, LhsE>, Diag, Diag); +impl_mul!(&DiagMut<'_, LhsE>, &DiagRef<'_, RhsE>, Diag); +impl_mul!(&DiagMut<'_, LhsE>, &DiagMut<'_, RhsE>, Diag); +impl_mul!(&DiagMut<'_, LhsE>, &Diag, Diag); + +impl_mul!(Diag, DiagRef<'_, RhsE>, Diag); +impl_mul!(Diag, DiagMut<'_, RhsE>, Diag); +impl_mul!(Diag, Diag, Diag); +impl_mul!(Diag, &DiagRef<'_, RhsE>, Diag); +impl_mul!(Diag, &DiagMut<'_, RhsE>, Diag); +impl_mul!(Diag, &Diag, Diag); +impl_mul!(&Diag, DiagRef<'_, RhsE>, Diag); +impl_mul!(&Diag, DiagMut<'_, RhsE>, Diag); +impl_mul!(&Diag, Diag, Diag); +impl_mul!(&Diag, &DiagRef<'_, RhsE>, Diag); +impl_mul!(&Diag, &DiagMut<'_, RhsE>, Diag); +impl_mul!(&Diag, &Diag, Diag); + +impl Mul> for PermRef<'_, I> { + type Output = Perm; + + #[track_caller] + fn mul(self, rhs: PermRef<'_, I>) -> Self::Output { + let lhs = self; + assert!(lhs.len() == rhs.len()); + let truncate = ::truncate; + let mut fwd = alloc::vec![I::from_signed(truncate(0)); lhs.len()].into_boxed_slice(); + let mut inv = alloc::vec![I::from_signed(truncate(0)); lhs.len()].into_boxed_slice(); + + for (fwd, rhs) in fwd.iter_mut().zip(rhs.arrays().0) { + *fwd = lhs.arrays().0[rhs.to_signed().zx()]; + } + for (i, fwd) in fwd.iter().enumerate() { + inv[fwd.to_signed().zx()] = I::from_signed(I::Signed::truncate(i)); + } + + Perm::new_checked(fwd, inv) + } +} + +// impl_perm_perm!(PermRef<'_, I>, PermRef<'_, I>, Perm); +impl_perm_perm!(PermRef<'_, I>, Perm, Perm); +impl_perm_perm!(PermRef<'_, I>, &PermRef<'_, I>, Perm); +impl_perm_perm!(PermRef<'_, I>, &Perm, Perm); +impl_perm_perm!(&PermRef<'_, I>, PermRef<'_, I>, Perm); +impl_perm_perm!(&PermRef<'_, I>, Perm, Perm); +impl_perm_perm!(&PermRef<'_, I>, &PermRef<'_, I>, Perm); +impl_perm_perm!(&PermRef<'_, I>, &Perm, Perm); + +impl_perm_perm!(Perm, PermRef<'_, I>, Perm); +impl_perm_perm!(Perm, Perm, Perm); +impl_perm_perm!(Perm, &PermRef<'_, I>, Perm); +impl_perm_perm!(Perm, &Perm, Perm); +impl_perm_perm!(&Perm, PermRef<'_, I>, Perm); +impl_perm_perm!(&Perm, Perm, Perm); +impl_perm_perm!(&Perm, &PermRef<'_, I>, Perm); +impl_perm_perm!(&Perm, &Perm, Perm); + +impl Mul> for PermRef<'_, I> +where + E::Canonical: ComplexField, +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: MatRef<'_, E>) -> Self::Output { + let lhs = self; + + assert!(lhs.len() == rhs.nrows()); + let mut out = Mat::zeros(rhs.nrows(), rhs.ncols()); + let fwd = lhs.arrays().0; + + for j in 0..rhs.ncols() { + for (i, fwd) in fwd.iter().enumerate() { + out.write(i, j, rhs.read(fwd.to_signed().zx(), j).canonicalize()); + } + } + out + } +} + +// impl_perm!(PermRef<'_, I>, MatRef<'_, E>, Mat); +impl_perm!(PermRef<'_, I>, MatMut<'_, E>, Mat); +impl_perm!(PermRef<'_, I>, Mat, Mat); +impl_perm!(PermRef<'_, I>, &MatRef<'_, E>, Mat); +impl_perm!(PermRef<'_, I>, &MatMut<'_, E>, Mat); +impl_perm!(PermRef<'_, I>, &Mat, Mat); +impl_perm!(&PermRef<'_, I>, MatRef<'_, E>, Mat); +impl_perm!(&PermRef<'_, I>, MatMut<'_, E>, Mat); +impl_perm!(&PermRef<'_, I>, Mat, Mat); +impl_perm!(&PermRef<'_, I>, &MatRef<'_, E>, Mat); +impl_perm!(&PermRef<'_, I>, &MatMut<'_, E>, Mat); +impl_perm!(&PermRef<'_, I>, &Mat, Mat); + +impl_perm!(Perm, MatRef<'_, E>, Mat); +impl_perm!(Perm, MatMut<'_, E>, Mat); +impl_perm!(Perm, Mat, Mat); +impl_perm!(Perm, &MatRef<'_, E>, Mat); +impl_perm!(Perm, &MatMut<'_, E>, Mat); +impl_perm!(Perm, &Mat, Mat); +impl_perm!(&Perm, MatRef<'_, E>, Mat); +impl_perm!(&Perm, MatMut<'_, E>, Mat); +impl_perm!(&Perm, Mat, Mat); +impl_perm!(&Perm, &MatRef<'_, E>, Mat); +impl_perm!(&Perm, &MatMut<'_, E>, Mat); +impl_perm!(&Perm, &Mat, Mat); + +impl Mul> for PermRef<'_, I> +where + E::Canonical: ComplexField, +{ + type Output = Col; + + #[track_caller] + fn mul(self, rhs: ColRef<'_, E>) -> Self::Output { + let lhs = self; + + assert!(lhs.len() == rhs.nrows()); + let mut out = Col::zeros(rhs.nrows()); + let fwd = lhs.arrays().0; + + for (i, fwd) in fwd.iter().enumerate() { + out.write(i, rhs.read(fwd.to_signed().zx()).canonicalize()); + } + out + } +} + +// impl_perm!(PermRef<'_, I>, ColRef<'_, E>, Col); +impl_perm!(PermRef<'_, I>, ColMut<'_, E>, Col); +impl_perm!(PermRef<'_, I>, Col, Col); +impl_perm!(PermRef<'_, I>, &ColRef<'_, E>, Col); +impl_perm!(PermRef<'_, I>, &ColMut<'_, E>, Col); +impl_perm!(PermRef<'_, I>, &Col, Col); +impl_perm!(&PermRef<'_, I>, ColRef<'_, E>, Col); +impl_perm!(&PermRef<'_, I>, ColMut<'_, E>, Col); +impl_perm!(&PermRef<'_, I>, Col, Col); +impl_perm!(&PermRef<'_, I>, &ColRef<'_, E>, Col); +impl_perm!(&PermRef<'_, I>, &ColMut<'_, E>, Col); +impl_perm!(&PermRef<'_, I>, &Col, Col); + +impl_perm!(Perm, ColRef<'_, E>, Col); +impl_perm!(Perm, ColMut<'_, E>, Col); +impl_perm!(Perm, Col, Col); +impl_perm!(Perm, &ColRef<'_, E>, Col); +impl_perm!(Perm, &ColMut<'_, E>, Col); +impl_perm!(Perm, &Col, Col); +impl_perm!(&Perm, ColRef<'_, E>, Col); +impl_perm!(&Perm, ColMut<'_, E>, Col); +impl_perm!(&Perm, Col, Col); +impl_perm!(&Perm, &ColRef<'_, E>, Col); +impl_perm!(&Perm, &ColMut<'_, E>, Col); +impl_perm!(&Perm, &Col, Col); + +impl Mul> for MatRef<'_, E> +where + E::Canonical: ComplexField, +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: PermRef<'_, I>) -> Self::Output { + let lhs = self; + + assert!(lhs.ncols() == rhs.len()); + let mut out = Mat::zeros(lhs.nrows(), lhs.ncols()); + let inv = rhs.arrays().1; + + for (j, inv) in inv.iter().enumerate() { + for i in 0..lhs.nrows() { + out.write(i, j, lhs.read(i, inv.to_signed().zx()).canonicalize()); + } + } + out + } +} + +// impl_perm!(MatRef<'_, E>, PermRef<'_, I>, Mat); +impl_perm!(MatRef<'_, E>, Perm, Mat); +impl_perm!(MatRef<'_, E>, &PermRef<'_, I>, Mat); +impl_perm!(MatRef<'_, E>, &Perm, Mat); +impl_perm!(&MatRef<'_, E>, PermRef<'_, I>, Mat); +impl_perm!(&MatRef<'_, E>, Perm, Mat); +impl_perm!(&MatRef<'_, E>, &PermRef<'_, I>, Mat); +impl_perm!(&MatRef<'_, E>, &Perm, Mat); + +impl_perm!(MatMut<'_, E>, PermRef<'_, I>, Mat); +impl_perm!(MatMut<'_, E>, Perm, Mat); +impl_perm!(MatMut<'_, E>, &PermRef<'_, I>, Mat); +impl_perm!(MatMut<'_, E>, &Perm, Mat); +impl_perm!(&MatMut<'_, E>, PermRef<'_, I>, Mat); +impl_perm!(&MatMut<'_, E>, Perm, Mat); +impl_perm!(&MatMut<'_, E>, &PermRef<'_, I>, Mat); +impl_perm!(&MatMut<'_, E>, &Perm, Mat); + +impl_perm!(Mat, PermRef<'_, I>, Mat); +impl_perm!(Mat, Perm, Mat); +impl_perm!(Mat, &PermRef<'_, I>, Mat); +impl_perm!(Mat, &Perm, Mat); +impl_perm!(&Mat, PermRef<'_, I>, Mat); +impl_perm!(&Mat, Perm, Mat); +impl_perm!(&Mat, &PermRef<'_, I>, Mat); +impl_perm!(&Mat, &Perm, Mat); + +impl Mul> for RowRef<'_, E> +where + E::Canonical: ComplexField, +{ + type Output = Row; + + #[track_caller] + fn mul(self, rhs: PermRef<'_, I>) -> Self::Output { + let lhs = self; + + assert!(lhs.ncols() == rhs.len()); + let mut out = Row::zeros(lhs.ncols()); + let inv = rhs.arrays().1; + + for (j, inv) in inv.iter().enumerate() { + out.write(j, lhs.read(inv.to_signed().zx()).canonicalize()); + } + out + } +} + +// impl_perm!(RowRef<'_, E>, PermRef<'_, I>, Row); +impl_perm!(RowRef<'_, E>, Perm, Row); +impl_perm!(RowRef<'_, E>, &PermRef<'_, I>, Row); +impl_perm!(RowRef<'_, E>, &Perm, Row); +impl_perm!(&RowRef<'_, E>, PermRef<'_, I>, Row); +impl_perm!(&RowRef<'_, E>, Perm, Row); +impl_perm!(&RowRef<'_, E>, &PermRef<'_, I>, Row); +impl_perm!(&RowRef<'_, E>, &Perm, Row); + +impl_perm!(RowMut<'_, E>, PermRef<'_, I>, Row); +impl_perm!(RowMut<'_, E>, Perm, Row); +impl_perm!(RowMut<'_, E>, &PermRef<'_, I>, Row); +impl_perm!(RowMut<'_, E>, &Perm, Row); +impl_perm!(&RowMut<'_, E>, PermRef<'_, I>, Row); +impl_perm!(&RowMut<'_, E>, Perm, Row); +impl_perm!(&RowMut<'_, E>, &PermRef<'_, I>, Row); +impl_perm!(&RowMut<'_, E>, &Perm, Row); + +impl_perm!(Row, PermRef<'_, I>, Row); +impl_perm!(Row, Perm, Row); +impl_perm!(Row, &PermRef<'_, I>, Row); +impl_perm!(Row, &Perm, Row); +impl_perm!(&Row, PermRef<'_, I>, Row); +impl_perm!(&Row, Perm, Row); +impl_perm!(&Row, &PermRef<'_, I>, Row); +impl_perm!(&Row, &Perm, Row); + +impl, RhsE: Conjugate> + Mul> for MatRef<'_, LhsE> +{ + type Output = Mat; + + fn mul(self, rhs: Scale) -> Self::Output { + zipped!(self).map(|unzipped!(x)| x.read().canonicalize().faer_mul(rhs.0.canonicalize())) + } +} +impl, RhsE: Conjugate> + Mul> for Scale +{ + type Output = Mat; + + fn mul(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + zipped!(rhs).map(|unzipped!(x)| x.read().canonicalize().faer_mul(self.0.canonicalize())) + } +} + +impl, RhsE: Conjugate> + Mul> for ColRef<'_, LhsE> +{ + type Output = Col; + + fn mul(self, rhs: Scale) -> Self::Output { + zipped!(self).map(|unzipped!(x)| x.read().canonicalize().faer_mul(rhs.0.canonicalize())) + } +} +impl, RhsE: Conjugate> + Mul> for Scale +{ + type Output = Col; + + fn mul(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + zipped!(rhs).map(|unzipped!(x)| x.read().canonicalize().faer_mul(self.0.canonicalize())) + } +} + +impl, RhsE: Conjugate> + Mul> for RowRef<'_, LhsE> +{ + type Output = Row; + + fn mul(self, rhs: Scale) -> Self::Output { + zipped!(self).map(|unzipped!(x)| x.read().canonicalize().faer_mul(rhs.0.canonicalize())) + } +} +impl, RhsE: Conjugate> + Mul> for Scale +{ + type Output = Row; + + fn mul(self, rhs: RowRef<'_, RhsE>) -> Self::Output { + zipped!(rhs).map(|unzipped!(x)| x.read().canonicalize().faer_mul(self.0.canonicalize())) + } +} + +impl, RhsE: Conjugate> + Mul> for DiagRef<'_, LhsE> +{ + type Output = Diag; + + fn mul(self, rhs: Scale) -> Self::Output { + zipped!(self.column_vector()) + .map(|unzipped!(x)| x.read().canonicalize().faer_mul(rhs.0.canonicalize())) + .column_vector_into_diagonal() + } +} +impl, RhsE: Conjugate> + Mul> for Scale +{ + type Output = Diag; + + fn mul(self, rhs: DiagRef<'_, RhsE>) -> Self::Output { + zipped!(rhs.column_vector()) + .map(|unzipped!(x)| x.read().canonicalize().faer_mul(self.0.canonicalize())) + .column_vector_into_diagonal() + } +} + +// impl_mul_scalar!(MatRef<'_, LhsE>, Scale, Mat); +impl_mul_scalar!(MatMut<'_, LhsE>, Scale, Mat); +impl_mul_scalar!(Mat, Scale, Mat); +impl_mul_scalar!(&MatRef<'_, LhsE>, Scale, Mat); +impl_mul_scalar!(&MatMut<'_, LhsE>, Scale, Mat); +impl_mul_scalar!(&Mat, Scale, Mat); + +// impl_scalar_mul!(Scale, MatRef<'_, RhsE>, Mat); +impl_scalar_mul!(Scale, MatMut<'_, RhsE>, Mat); +impl_scalar_mul!(Scale, Mat, Mat); +impl_scalar_mul!(Scale, &MatRef<'_, RhsE>, Mat); +impl_scalar_mul!(Scale, &MatMut<'_, RhsE>, Mat); +impl_scalar_mul!(Scale, &Mat, Mat); + +// impl_mul_scalar!(ColRef<'_, LhsE>, Scale, Col); +impl_mul_scalar!(ColMut<'_, LhsE>, Scale, Col); +impl_mul_scalar!(Col, Scale, Col); +impl_mul_scalar!(&ColRef<'_, LhsE>, Scale, Col); +impl_mul_scalar!(&ColMut<'_, LhsE>, Scale, Col); +impl_mul_scalar!(&Col, Scale, Col); + +// impl_scalar_mul!(Scale, ColRef<'_, RhsE>, Col); +impl_scalar_mul!(Scale, ColMut<'_, RhsE>, Col); +impl_scalar_mul!(Scale, Col, Col); +impl_scalar_mul!(Scale, &ColRef<'_, RhsE>, Col); +impl_scalar_mul!(Scale, &ColMut<'_, RhsE>, Col); +impl_scalar_mul!(Scale, &Col, Col); + +// impl_mul_scalar!(RowRef<'_, LhsE>, Scale, Row); +impl_mul_scalar!(RowMut<'_, LhsE>, Scale, Row); +impl_mul_scalar!(Row, Scale, Row); +impl_mul_scalar!(&RowRef<'_, LhsE>, Scale, Row); +impl_mul_scalar!(&RowMut<'_, LhsE>, Scale, Row); +impl_mul_scalar!(&Row, Scale, Row); + +// impl_scalar_mul!(Scale, RowRef<'_, RhsE>, Row); +impl_scalar_mul!(Scale, RowMut<'_, RhsE>, Row); +impl_scalar_mul!(Scale, Row, Row); +impl_scalar_mul!(Scale, &RowRef<'_, RhsE>, Row); +impl_scalar_mul!(Scale, &RowMut<'_, RhsE>, Row); +impl_scalar_mul!(Scale, &Row, Row); + +// impl_mul_scalar!(DiagRef<'_, LhsE>, Scale, Diag); +impl_mul_scalar!(DiagMut<'_, LhsE>, Scale, Diag); +impl_mul_scalar!(Diag, Scale, Diag); +impl_mul_scalar!(&DiagRef<'_, LhsE>, Scale, Diag); +impl_mul_scalar!(&DiagMut<'_, LhsE>, Scale, Diag); +impl_mul_scalar!(&Diag, Scale, Diag); + +// impl_scalar_mul!(Scale, DiagRef<'_, RhsE>, Diag); +impl_scalar_mul!(Scale, DiagMut<'_, RhsE>, Diag); +impl_scalar_mul!(Scale, Diag, Diag); +impl_scalar_mul!(Scale, &DiagRef<'_, RhsE>, Diag); +impl_scalar_mul!(Scale, &DiagMut<'_, RhsE>, Diag); +impl_scalar_mul!(Scale, &Diag, Diag); + +impl> MulAssign> + for MatMut<'_, LhsE> +{ + fn mul_assign(&mut self, rhs: Scale) { + zipped!(self.as_mut()) + .for_each(|unzipped!(mut x)| x.write(x.read().faer_mul(rhs.0.canonicalize()))) + } +} +impl> MulAssign> + for ColMut<'_, LhsE> +{ + fn mul_assign(&mut self, rhs: Scale) { + zipped!(self.as_mut()) + .for_each(|unzipped!(mut x)| x.write(x.read().faer_mul(rhs.0.canonicalize()))) + } +} +impl> MulAssign> + for RowMut<'_, LhsE> +{ + fn mul_assign(&mut self, rhs: Scale) { + zipped!(self.as_mut()) + .for_each(|unzipped!(mut x)| x.write(x.read().faer_mul(rhs.0.canonicalize()))) + } +} +impl> MulAssign> + for DiagMut<'_, LhsE> +{ + fn mul_assign(&mut self, rhs: Scale) { + zipped!(self.as_mut().column_vector_mut()) + .for_each(|unzipped!(mut x)| x.write(x.read().faer_mul(rhs.0.canonicalize()))) + } +} + +// impl_mul_assign_scalar!(MatMut<'_, LhsE>, Scale); +impl_mul_assign_scalar!(Mat, Scale); +// impl_mul_assign_scalar!(ColMut<'_, LhsE>, Scale); +impl_mul_assign_scalar!(Col, Scale); +// impl_mul_assign_scalar!(RowMut<'_, LhsE>, Scale); +impl_mul_assign_scalar!(Row, Scale); +// impl_mul_assign_scalar!(DiagMut<'_, LhsE>, Scale); +impl_mul_assign_scalar!(Diag, Scale); + +impl, RhsE: Conjugate> + Mul> for SparseColMatRef<'_, I, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); + crate::sparse::linalg::matmul::sparse_dense_matmul( + out.as_mut(), + lhs, + rhs, + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for SparseColMatRef<'_, I, LhsE> +{ + type Output = Col; + + #[track_caller] + fn mul(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Col::zeros(lhs.nrows()); + crate::sparse::linalg::matmul::sparse_dense_matmul( + out.as_mut().as_2d_mut(), + lhs, + rhs.as_2d(), + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for MatRef<'_, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: SparseColMatRef<'_, I, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); + crate::sparse::linalg::matmul::dense_sparse_matmul( + out.as_mut(), + lhs, + rhs, + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for RowRef<'_, LhsE> +{ + type Output = Row; + + #[track_caller] + fn mul(self, rhs: SparseColMatRef<'_, I, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Row::zeros(rhs.ncols()); + crate::sparse::linalg::matmul::dense_sparse_matmul( + out.as_mut().as_2d_mut(), + lhs.as_2d(), + rhs, + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for SparseRowMatRef<'_, I, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: MatRef<'_, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); + crate::sparse::linalg::matmul::dense_sparse_matmul( + out.as_mut().transpose_mut(), + rhs.transpose(), + lhs.transpose(), + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for SparseRowMatRef<'_, I, LhsE> +{ + type Output = Col; + + #[track_caller] + fn mul(self, rhs: ColRef<'_, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Col::zeros(lhs.nrows()); + crate::sparse::linalg::matmul::dense_sparse_matmul( + out.as_mut().transpose_mut().as_2d_mut(), + rhs.transpose().as_2d(), + lhs.transpose(), + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for MatRef<'_, LhsE> +{ + type Output = Mat; + + #[track_caller] + fn mul(self, rhs: SparseRowMatRef<'_, I, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Mat::zeros(lhs.nrows(), rhs.ncols()); + crate::sparse::linalg::matmul::sparse_dense_matmul( + out.as_mut().transpose_mut(), + rhs.transpose(), + lhs.transpose(), + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +impl, RhsE: Conjugate> + Mul> for RowRef<'_, LhsE> +{ + type Output = Row; + + #[track_caller] + fn mul(self, rhs: SparseRowMatRef<'_, I, RhsE>) -> Self::Output { + let lhs = self; + let mut out = Row::zeros(rhs.ncols()); + crate::sparse::linalg::matmul::sparse_dense_matmul( + out.as_mut().transpose_mut().as_2d_mut(), + rhs.transpose(), + lhs.transpose().as_2d(), + None, + E::faer_one(), + get_global_parallelism(), + ); + out + } +} + +// impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, &Mat, Mat); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, &Mat, Mat); + +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, &Mat, Mat); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, &Mat, Mat); + +impl_sparse_mul!(SparseColMat, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMat, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMat, Mat, Mat); +impl_sparse_mul!(SparseColMat, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMat, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseColMat, &Mat, Mat); +impl_sparse_mul!(&SparseColMat, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMat, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMat, Mat, Mat); +impl_sparse_mul!(&SparseColMat, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMat, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseColMat, &Mat, Mat); + +// impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, &Mat, Mat); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, &Mat, Mat); + +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, &Mat, Mat); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, Mat, Mat); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, &Mat, Mat); + +impl_sparse_mul!(SparseRowMat, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMat, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMat, Mat, Mat); +impl_sparse_mul!(SparseRowMat, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMat, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(SparseRowMat, &Mat, Mat); +impl_sparse_mul!(&SparseRowMat, MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMat, MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMat, Mat, Mat); +impl_sparse_mul!(&SparseRowMat, &MatRef<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMat, &MatMut<'_, RhsE>, Mat); +impl_sparse_mul!(&SparseRowMat, &Mat, Mat); + +// impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatRef<'_, I, LhsE>, &Col, Col); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatRef<'_, I, LhsE>, &Col, Col); + +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMatMut<'_, I, LhsE>, &Col, Col); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMatMut<'_, I, LhsE>, &Col, Col); + +impl_sparse_mul!(SparseColMat, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMat, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMat, Col, Col); +impl_sparse_mul!(SparseColMat, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMat, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseColMat, &Col, Col); +impl_sparse_mul!(&SparseColMat, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMat, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMat, Col, Col); +impl_sparse_mul!(&SparseColMat, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMat, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseColMat, &Col, Col); + +// impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatRef<'_, I, LhsE>, &Col, Col); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatRef<'_, I, LhsE>, &Col, Col); + +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMatMut<'_, I, LhsE>, &Col, Col); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, Col, Col); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMatMut<'_, I, LhsE>, &Col, Col); + +impl_sparse_mul!(SparseRowMat, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMat, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMat, Col, Col); +impl_sparse_mul!(SparseRowMat, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMat, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(SparseRowMat, &Col, Col); +impl_sparse_mul!(&SparseRowMat, ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMat, ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMat, Col, Col); +impl_sparse_mul!(&SparseRowMat, &ColRef<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMat, &ColMut<'_, RhsE>, Col); +impl_sparse_mul!(&SparseRowMat, &Col, Col); + +// impl_sparse_mul!(MatRef<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, SparseColMat, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, &SparseColMat, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, SparseColMat, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, &SparseColMat, Mat); + +impl_sparse_mul!(MatMut<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, SparseColMat, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, &SparseColMat, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, SparseColMat, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, &SparseColMat, Mat); + +impl_sparse_mul!(Mat, SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat, SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat< LhsE>, SparseColMat, Mat); +impl_sparse_mul!(Mat, &SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat, &SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat, &SparseColMat, Mat); +impl_sparse_mul!(&Mat, SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, SparseColMat, Mat); +impl_sparse_mul!(&Mat, &SparseColMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, &SparseColMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, &SparseColMat, Mat); + +// impl_sparse_mul!(RowRef<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, SparseColMat, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, &SparseColMat, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, SparseColMat, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, &SparseColMat, Row); + +impl_sparse_mul!(RowMut<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, SparseColMat, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, &SparseColMat, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, SparseColMat, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, &SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, &SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, &SparseColMat, Row); + +impl_sparse_mul!(Row, SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, SparseColMat, Row); +impl_sparse_mul!(Row, &SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, &SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, &SparseColMat, Row); +impl_sparse_mul!(&Row, SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, SparseColMat, Row); +impl_sparse_mul!(&Row, &SparseColMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, &SparseColMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, &SparseColMat, Row); + +// impl_sparse_mul!(MatRef<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, SparseRowMat, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatRef<'_, LhsE>, &SparseRowMat, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, SparseRowMat, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatRef<'_, LhsE>, &SparseRowMat, Mat); + +impl_sparse_mul!(MatMut<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, SparseRowMat, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(MatMut<'_, LhsE>, &SparseRowMat, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, SparseRowMat, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&MatMut<'_, LhsE>, &SparseRowMat, Mat); + +impl_sparse_mul!(Mat, SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat, SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat< LhsE>, SparseRowMat, Mat); +impl_sparse_mul!(Mat, &SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat, &SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(Mat, &SparseRowMat, Mat); +impl_sparse_mul!(&Mat, SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, SparseRowMat, Mat); +impl_sparse_mul!(&Mat, &SparseRowMatRef<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, &SparseRowMatMut<'_, I, RhsE>, Mat); +impl_sparse_mul!(&Mat, &SparseRowMat, Mat); + +// impl_sparse_mul!(RowRef<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, SparseRowMat, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowRef<'_, LhsE>, &SparseRowMat, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, SparseRowMat, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowRef<'_, LhsE>, &SparseRowMat, Row); + +impl_sparse_mul!(RowMut<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, SparseRowMat, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(RowMut<'_, LhsE>, &SparseRowMat, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, SparseRowMat, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, &SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, &SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&RowMut<'_, LhsE>, &SparseRowMat, Row); + +impl_sparse_mul!(Row, SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, SparseRowMat, Row); +impl_sparse_mul!(Row, &SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, &SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(Row, &SparseRowMat, Row); +impl_sparse_mul!(&Row, SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, SparseRowMat, Row); +impl_sparse_mul!(&Row, &SparseRowMatRef<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, &SparseRowMatMut<'_, I, RhsE>, Row); +impl_sparse_mul!(&Row, &SparseRowMat, Row); + +impl> + PartialEq> for SparseColMatRef<'_, I, LhsE> +{ + fn eq(&self, other: &SparseColMatRef<'_, I, RhsE>) -> bool { + let lhs = *self; + let rhs = *other; + + if lhs.nrows() != rhs.nrows() || lhs.ncols() != rhs.ncols() { + return false; + } + + let n = lhs.ncols(); + let mut equal = true; + for j in 0..n { + equal &= lhs.row_indices_of_col_raw(j) == rhs.row_indices_of_col_raw(j); + if !equal { + return false; + } + + let lhs_val = crate::utils::slice::SliceGroup::<'_, LhsE>::new(lhs.values_of_col(j)); + let rhs_val = crate::utils::slice::SliceGroup::<'_, RhsE>::new(rhs.values_of_col(j)); + equal &= lhs_val + .into_ref_iter() + .map(|r| r.read().canonicalize()) + .eq(rhs_val.into_ref_iter().map(|r| r.read().canonicalize())); + + if !equal { + return false; + } + } + + equal + } +} + +impl> + PartialEq> for SparseRowMatRef<'_, I, LhsE> +{ + #[inline] + fn eq(&self, other: &SparseRowMatRef<'_, I, RhsE>) -> bool { + self.transpose() == other.transpose() + } +} + +// impl_partial_eq_sparse!(SparseColMatRef<'_, I, LhsE>, SparseColMatRef<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseColMatRef<'_, I, LhsE>, SparseColMatMut<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseColMatRef<'_, I, LhsE>, SparseColMat); +impl_partial_eq_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMatRef<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMatMut<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMat); +impl_partial_eq_sparse!(SparseColMat, SparseColMatRef<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseColMat, SparseColMatMut<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseColMat, SparseColMat); + +// impl_partial_eq_sparse!(SparseRowMatRef<'_, I, LhsE>, SparseRowMatRef<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseRowMatRef<'_, I, LhsE>, SparseRowMatMut<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseRowMatRef<'_, I, LhsE>, SparseRowMat); +impl_partial_eq_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMatRef<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMatMut<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMat); +impl_partial_eq_sparse!(SparseRowMat, SparseRowMatRef<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseRowMat, SparseRowMatMut<'_, I, RhsE>); +impl_partial_eq_sparse!(SparseRowMat, SparseRowMat); + +impl, RhsE: Conjugate> + Add> for SparseColMatRef<'_, I, LhsE> +{ + type Output = SparseColMat; + #[track_caller] + fn add(self, rhs: SparseColMatRef<'_, I, RhsE>) -> Self::Output { + crate::sparse::ops::add(self, rhs).unwrap() + } +} + +impl, RhsE: Conjugate> + Sub> for SparseColMatRef<'_, I, LhsE> +{ + type Output = SparseColMat; + #[track_caller] + fn sub(self, rhs: SparseColMatRef<'_, I, RhsE>) -> Self::Output { + crate::sparse::ops::sub(self, rhs).unwrap() + } +} + +impl, RhsE: Conjugate> + Add> for SparseRowMatRef<'_, I, LhsE> +{ + type Output = SparseRowMat; + #[track_caller] + fn add(self, rhs: SparseRowMatRef<'_, I, RhsE>) -> Self::Output { + (self.transpose() + rhs.transpose()).into_transpose() + } +} + +impl, RhsE: Conjugate> + Sub> for SparseRowMatRef<'_, I, LhsE> +{ + type Output = SparseRowMat; + #[track_caller] + fn sub(self, rhs: SparseRowMatRef<'_, I, RhsE>) -> Self::Output { + (self.transpose() - rhs.transpose()).into_transpose() + } +} + +impl> + AddAssign> for SparseColMatMut<'_, I, LhsE> +{ + #[track_caller] + fn add_assign(&mut self, other: SparseColMatRef<'_, I, RhsE>) { + crate::sparse::ops::add_assign(self.as_mut(), other); + } +} + +impl> + SubAssign> for SparseColMatMut<'_, I, LhsE> +{ + #[track_caller] + fn sub_assign(&mut self, other: SparseColMatRef<'_, I, RhsE>) { + crate::sparse::ops::sub_assign(self.as_mut(), other); + } +} + +impl> + AddAssign> for SparseRowMatMut<'_, I, LhsE> +{ + #[track_caller] + fn add_assign(&mut self, other: SparseRowMatRef<'_, I, RhsE>) { + crate::sparse::ops::add_assign(self.as_mut().transpose_mut(), other.transpose()); + } +} + +impl> + SubAssign> for SparseRowMatMut<'_, I, LhsE> +{ + #[track_caller] + fn sub_assign(&mut self, other: SparseRowMatRef<'_, I, RhsE>) { + crate::sparse::ops::sub_assign(self.as_mut().transpose_mut(), other.transpose()); + } +} + +impl Neg for SparseColMatRef<'_, I, E> +where + E::Canonical: ComplexField, +{ + type Output = SparseColMat; + #[track_caller] + fn neg(self) -> Self::Output { + let mut out = self.to_owned().unwrap(); + for mut x in crate::utils::slice::SliceGroupMut::<'_, E::Canonical>::new(out.values_mut()) + .into_mut_iter() + { + x.write(x.read().faer_neg()) + } + out + } +} +impl Neg for SparseRowMatRef<'_, I, E> +where + E::Canonical: ComplexField, +{ + type Output = SparseRowMat; + #[track_caller] + fn neg(self) -> Self::Output { + (-self.transpose()).into_transpose() + } +} + +impl, RhsE: Conjugate> + Mul> for Scale +{ + type Output = SparseColMat; + #[track_caller] + fn mul(self, rhs: SparseColMatRef<'_, I, RhsE>) -> Self::Output { + let mut out = rhs.to_owned().unwrap(); + for mut x in + crate::utils::slice::SliceGroupMut::<'_, E>::new(out.values_mut()).into_mut_iter() + { + x.write(self.0.canonicalize().faer_mul(x.read())) + } + out + } +} + +impl, RhsE: Conjugate> + Mul> for SparseColMatRef<'_, I, LhsE> +{ + type Output = SparseColMat; + #[track_caller] + fn mul(self, rhs: Scale) -> Self::Output { + let mut out = self.to_owned().unwrap(); + for mut x in + crate::utils::slice::SliceGroupMut::<'_, E>::new(out.values_mut()).into_mut_iter() + { + x.write(x.read().faer_mul(rhs.0.canonicalize())) + } + out + } +} + +impl, RhsE: Conjugate> + Mul> for Scale +{ + type Output = SparseRowMat; + #[track_caller] + fn mul(self, rhs: SparseRowMatRef<'_, I, RhsE>) -> Self::Output { + self.mul(rhs.transpose()).into_transpose() + } +} + +impl, RhsE: Conjugate> + Mul> for SparseRowMatRef<'_, I, LhsE> +{ + type Output = SparseRowMat; + #[track_caller] + fn mul(self, rhs: Scale) -> Self::Output { + self.transpose().mul(rhs).into_transpose() + } +} + +#[rustfmt::skip] +// impl_add_sub_sparse!(SparseColMatRef<'_, I, LhsE>, SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatRef<'_, I, LhsE>, SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatRef<'_, I, LhsE>, SparseColMat, SparseColMat); +impl_add_sub_sparse!(SparseColMatRef<'_, I, LhsE>, &SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatRef<'_, I, LhsE>, &SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatRef<'_, I, LhsE>, &SparseColMat, SparseColMat); +impl_add_sub_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMat, SparseColMat); +impl_add_sub_sparse!(SparseColMatMut<'_, I, LhsE>, &SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatMut<'_, I, LhsE>, &SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMatMut<'_, I, LhsE>, &SparseColMat, SparseColMat); +impl_add_sub_sparse!(SparseColMat, SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMat, SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMat, SparseColMat, SparseColMat); +impl_add_sub_sparse!(SparseColMat, &SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMat, &SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(SparseColMat, &SparseColMat, SparseColMat); +impl_add_sub_sparse!(&SparseColMatRef<'_, I, LhsE>, SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatRef<'_, I, LhsE>, SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatRef<'_, I, LhsE>, SparseColMat, SparseColMat); +impl_add_sub_sparse!(&SparseColMatRef<'_, I, LhsE>, &SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatRef<'_, I, LhsE>, &SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatRef<'_, I, LhsE>, &SparseColMat, SparseColMat); +impl_add_sub_sparse!(&SparseColMatMut<'_, I, LhsE>, SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatMut<'_, I, LhsE>, SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatMut<'_, I, LhsE>, SparseColMat, SparseColMat); +impl_add_sub_sparse!(&SparseColMatMut<'_, I, LhsE>, &SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatMut<'_, I, LhsE>, &SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMatMut<'_, I, LhsE>, &SparseColMat, SparseColMat); +impl_add_sub_sparse!(&SparseColMat, SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMat, SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMat, SparseColMat, SparseColMat); +impl_add_sub_sparse!(&SparseColMat, &SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMat, &SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_add_sub_sparse!(&SparseColMat, &SparseColMat, SparseColMat); +#[rustfmt::skip] +// impl_add_sub_sparse!(SparseRowMatRef<'_, I, LhsE>, SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatRef<'_, I, LhsE>, SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatRef<'_, I, LhsE>, SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatRef<'_, I, LhsE>, &SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatRef<'_, I, LhsE>, &SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatRef<'_, I, LhsE>, &SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatMut<'_, I, LhsE>, &SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatMut<'_, I, LhsE>, &SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMatMut<'_, I, LhsE>, &SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(SparseRowMat, SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMat, SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMat, SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(SparseRowMat, &SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMat, &SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(SparseRowMat, &SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatRef<'_, I, LhsE>, SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatRef<'_, I, LhsE>, SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatRef<'_, I, LhsE>, SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatRef<'_, I, LhsE>, &SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatRef<'_, I, LhsE>, &SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatRef<'_, I, LhsE>, &SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatMut<'_, I, LhsE>, SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatMut<'_, I, LhsE>, SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatMut<'_, I, LhsE>, SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatMut<'_, I, LhsE>, &SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatMut<'_, I, LhsE>, &SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMatMut<'_, I, LhsE>, &SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMat, SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMat, SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMat, SparseRowMat, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMat, &SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMat, &SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_add_sub_sparse!(&SparseRowMat, &SparseRowMat, SparseRowMat); + +// impl_add_sub_assign_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMatMut<'_, I, LhsE>, SparseColMat); +impl_add_sub_assign_sparse!(SparseColMatMut<'_, I, LhsE>, &SparseColMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMatMut<'_, I, LhsE>, &SparseColMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMatMut<'_, I, LhsE>, &SparseColMat); +impl_add_sub_assign_sparse!(SparseColMat, SparseColMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMat, SparseColMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMat, SparseColMat); +impl_add_sub_assign_sparse!(SparseColMat, &SparseColMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMat, &SparseColMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseColMat, &SparseColMat); +// impl_add_sub_assign_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMatMut<'_, I, LhsE>, SparseRowMat); +impl_add_sub_assign_sparse!(SparseRowMatMut<'_, I, LhsE>, &SparseRowMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMatMut<'_, I, LhsE>, &SparseRowMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMatMut<'_, I, LhsE>, &SparseRowMat); +impl_add_sub_assign_sparse!(SparseRowMat, SparseRowMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMat, SparseRowMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMat, SparseRowMat); +impl_add_sub_assign_sparse!(SparseRowMat, &SparseRowMatRef<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMat, &SparseRowMatMut<'_, I, RhsE>); +impl_add_sub_assign_sparse!(SparseRowMat, &SparseRowMat); + +// impl_neg_sparse!(SparseColMatRef<'_, I, E>, SparseColMat); +impl_neg_sparse!(SparseColMatMut<'_, I, E>, SparseColMat); +impl_neg_sparse!(SparseColMat, SparseColMat); +impl_neg_sparse!(&SparseColMatRef<'_, I, E>, SparseColMat); +impl_neg_sparse!(&SparseColMatMut<'_, I, E>, SparseColMat); +impl_neg_sparse!(&SparseColMat, SparseColMat); +// impl_neg_sparse!(SparseRowMatRef<'_, I, E>, SparseRowMat); +impl_neg_sparse!(SparseRowMatMut<'_, I, E>, SparseRowMat); +impl_neg_sparse!(SparseRowMat, SparseRowMat); +impl_neg_sparse!(&SparseRowMatRef<'_, I, E>, SparseRowMat); +impl_neg_sparse!(&SparseRowMatMut<'_, I, E>, SparseRowMat); +impl_neg_sparse!(&SparseRowMat, SparseRowMat); + +// impl_scalar_mul_sparse!(Scale, SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_scalar_mul_sparse!(Scale, SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_scalar_mul_sparse!(Scale, SparseColMat, SparseColMat); +impl_scalar_mul_sparse!(Scale, &SparseColMatRef<'_, I, RhsE>, SparseColMat); +impl_scalar_mul_sparse!(Scale, &SparseColMatMut<'_, I, RhsE>, SparseColMat); +impl_scalar_mul_sparse!(Scale, &SparseColMat, SparseColMat); + +// impl_scalar_mul_sparse!(Scale, SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_scalar_mul_sparse!(Scale, SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_scalar_mul_sparse!(Scale, SparseRowMat, SparseRowMat); +impl_scalar_mul_sparse!(Scale, &SparseRowMatRef<'_, I, RhsE>, SparseRowMat); +impl_scalar_mul_sparse!(Scale, &SparseRowMatMut<'_, I, RhsE>, SparseRowMat); +impl_scalar_mul_sparse!(Scale, &SparseRowMat, SparseRowMat); + +// impl_mul_scalar_sparse!(SparseColMatRef<'_, I, LhsE>, Scale, SparseColMat); +impl_mul_scalar_sparse!(SparseColMatMut<'_, I, LhsE>, Scale, SparseColMat); +impl_mul_scalar_sparse!(SparseColMat, Scale, SparseColMat); +impl_mul_scalar_sparse!(&SparseColMatRef<'_, I, LhsE>, Scale, SparseColMat); +impl_mul_scalar_sparse!(&SparseColMatMut<'_, I, LhsE>, Scale, SparseColMat); +impl_mul_scalar_sparse!(&SparseColMat, Scale, SparseColMat); + +// impl_mul_scalar_sparse!(SparseRowMatRef<'_, I, LhsE>, Scale, SparseRowMat); +impl_mul_scalar_sparse!(SparseRowMatMut<'_, I, LhsE>, Scale, SparseRowMat); +impl_mul_scalar_sparse!(SparseRowMat, Scale, SparseRowMat); +impl_mul_scalar_sparse!(&SparseRowMatRef<'_, I, LhsE>, Scale, SparseRowMat); +impl_mul_scalar_sparse!(&SparseRowMatMut<'_, I, LhsE>, Scale, SparseRowMat); +impl_mul_scalar_sparse!(&SparseRowMat, Scale, SparseRowMat); + +#[cfg(test)] +#[allow(non_snake_case)] +mod test { + use crate::{assert, col::*, mat, mat::*, perm::*, row::*}; + use assert_approx_eq::assert_approx_eq; + + fn matrices() -> (Mat, Mat) { + let A = mat![[2.8, -3.3], [-1.7, 5.2], [4.6, -8.3],]; + + let B = mat![[-7.9, 8.3], [4.7, -3.2], [3.8, -5.2],]; + (A, B) + } + + #[test] + #[should_panic] + fn test_adding_matrices_of_different_sizes_should_panic() { + let A = mat![[1.0, 2.0], [3.0, 4.0]]; + let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + _ = A + B; + } + + #[test] + #[should_panic] + fn test_subtracting_two_matrices_of_different_sizes_should_panic() { + let A = mat![[1.0, 2.0], [3.0, 4.0]]; + let B = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + _ = A - B; + } + + #[test] + fn test_add() { + let (A, B) = matrices(); + + let expected = mat![[-5.1, 5.0], [3.0, 2.0], [8.4, -13.5],]; + + assert_matrix_approx_eq(A.as_ref() + B.as_ref(), &expected); + assert_matrix_approx_eq(&A + &B, &expected); + assert_matrix_approx_eq(A.as_ref() + &B, &expected); + assert_matrix_approx_eq(&A + B.as_ref(), &expected); + assert_matrix_approx_eq(A.as_ref() + B.clone(), &expected); + assert_matrix_approx_eq(&A + B.clone(), &expected); + assert_matrix_approx_eq(A.clone() + B.as_ref(), &expected); + assert_matrix_approx_eq(A.clone() + &B, &expected); + assert_matrix_approx_eq(A + B, &expected); + } + + #[test] + fn test_sub() { + let (A, B) = matrices(); + + let expected = mat![[10.7, -11.6], [-6.4, 8.4], [0.8, -3.1],]; + + assert_matrix_approx_eq(A.as_ref() - B.as_ref(), &expected); + assert_matrix_approx_eq(&A - &B, &expected); + assert_matrix_approx_eq(A.as_ref() - &B, &expected); + assert_matrix_approx_eq(&A - B.as_ref(), &expected); + assert_matrix_approx_eq(A.as_ref() - B.clone(), &expected); + assert_matrix_approx_eq(&A - B.clone(), &expected); + assert_matrix_approx_eq(A.clone() - B.as_ref(), &expected); + assert_matrix_approx_eq(A.clone() - &B, &expected); + assert_matrix_approx_eq(A - B, &expected); + } + + #[test] + fn test_neg() { + let (A, _) = matrices(); + + let expected = mat![[-2.8, 3.3], [1.7, -5.2], [-4.6, 8.3],]; + + assert_eq!(-A, expected); + } + + #[test] + fn test_scalar_mul() { + use crate::scale; + + let (A, _) = matrices(); + let scale = scale(3.0); + let expected = Mat::from_fn(A.nrows(), A.ncols(), |i, j| A.read(i, j) * scale.value()); + + { + assert_matrix_approx_eq(A.as_ref() * scale, &expected); + assert_matrix_approx_eq(&A * scale, &expected); + assert_matrix_approx_eq(A.as_ref() * scale, &expected); + assert_matrix_approx_eq(&A * scale, &expected); + assert_matrix_approx_eq(A.as_ref() * scale, &expected); + assert_matrix_approx_eq(&A * scale, &expected); + assert_matrix_approx_eq(A.clone() * scale, &expected); + assert_matrix_approx_eq(A.clone() * scale, &expected); + assert_matrix_approx_eq(A * scale, &expected); + } + + let (A, _) = matrices(); + { + assert_matrix_approx_eq(scale * A.as_ref(), &expected); + assert_matrix_approx_eq(scale * &A, &expected); + assert_matrix_approx_eq(scale * A.as_ref(), &expected); + assert_matrix_approx_eq(scale * &A, &expected); + assert_matrix_approx_eq(scale * A.as_ref(), &expected); + assert_matrix_approx_eq(scale * &A, &expected); + assert_matrix_approx_eq(scale * A.clone(), &expected); + assert_matrix_approx_eq(scale * A.clone(), &expected); + assert_matrix_approx_eq(scale * A, &expected); + } + } + + #[test] + fn test_diag_mul() { + let (A, _) = matrices(); + let diag_left = mat![[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]; + let diag_right = mat![[4.0, 0.0], [0.0, 5.0]]; + + assert!(&diag_left * &A == diag_left.diagonal() * &A); + assert!(&A * &diag_right == &A * diag_right.diagonal()); + } + + #[test] + fn test_perm_mul() { + let A = Mat::from_fn(6, 5, |i, j| (j + 5 * i) as f64); + let pl = + Perm::::new_checked(Box::new([5, 1, 4, 0, 2, 3]), Box::new([3, 1, 4, 5, 2, 0])); + let pr = Perm::::new_checked(Box::new([1, 4, 0, 2, 3]), Box::new([2, 0, 3, 4, 1])); + + let perm_left = mat![ + [0.0, 0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0, 0.0], + [1.0, 0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0], + ]; + let perm_right = mat![ + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + ]; + + assert!( + &pl * pl.as_ref().inverse() + == PermRef::<'_, usize>::new_checked(&[0, 1, 2, 3, 4, 5], &[0, 1, 2, 3, 4, 5],) + ); + assert!(&perm_left * &A == &pl * &A); + assert!(&A * &perm_right == &A * &pr); + } + + #[test] + fn test_matmul_col_row() { + let A = Col::from_fn(6, |i| i as f64); + let B = Row::from_fn(6, |j| (5 * j + 1) as f64); + + // outer product + assert_eq!(&A * &B, A.as_ref().as_2d() * B.as_ref().as_2d()); + // inner product + assert_eq!( + &B * &A, + (B.as_ref().as_2d() * A.as_ref().as_2d()).read(0, 0), + ); + } + + fn assert_matrix_approx_eq(given: Mat, expected: &Mat) { + assert_eq!(given.nrows(), expected.nrows()); + assert_eq!(given.ncols(), expected.ncols()); + for i in 0..given.nrows() { + for j in 0..given.ncols() { + assert_approx_eq!(given.read(i, j), expected.read(i, j)); + } + } + } +} diff --git a/faer-libs/faer-core/src/mul.rs b/src/linalg/matmul/mod.rs similarity index 64% rename from faer-libs/faer-core/src/mul.rs rename to src/linalg/matmul/mod.rs index 488820f56e4b98588532cc06e8e945fd2426a720..7e53dd0e9f7d26baefe316fc9d44f924537c15ed 100644 --- a/faer-libs/faer-core/src/mul.rs +++ b/src/linalg/matmul/mod.rs @@ -1,12 +1,17 @@ //! Matrix multiplication. use crate::{ - assert, c32, c64, group_helpers::*, transmute_unchecked, unzipped, zipped, ComplexField, Conj, - Conjugate, DivCeil, MatMut, MatRef, Parallelism, SimdGroupFor, + assert, + complex_native::*, + linalg::entity::{transmute_unchecked, SimdGroupFor}, + mat::{MatMut, MatRef}, + unzipped, + utils::{simd::*, slice::*, DivCeil}, + zipped, ComplexField, Conj, Conjugate, Parallelism, }; use core::{iter::zip, marker::PhantomData, mem::MaybeUninit}; use faer_entity::{SimdCtx, *}; -use pulp::Simd; +use pulp::{Read, Simd, Write}; use reborrow::*; #[doc(hidden)] @@ -221,15 +226,15 @@ pub mod inner_prod { }) } } else { - crate::constrained::Size::with2( + crate::utils::constrained::Size::with2( nrows, 1, #[inline(always)] |nrows, ncols| { let zero_idx = ncols.check(0); - let a = crate::constrained::MatRef::new(a, nrows, ncols); - let b = crate::constrained::MatRef::new(b, nrows, ncols); + let a = crate::utils::constrained::mat::MatRef::new(a, nrows, ncols); + let b = crate::utils::constrained::mat::MatRef::new(b, nrows, ncols); let mut acc = E::faer_zero(); if conj_lhs == conj_rhs { for i in nrows.indices() { @@ -477,7 +482,7 @@ pub mod matvec_colmajor { matvec_with_conj_impl(acc, lhs, conj_lhs, rhs, conj_rhs, beta); } else { - let mut tmp = crate::Mat::::zeros(m, 1); + let mut tmp = crate::mat::Mat::::zeros(m, 1); matvec_with_conj_impl(tmp.as_mut(), lhs, conj_lhs, rhs, conj_rhs, beta); match alpha { Some(alpha) => { @@ -1215,7 +1220,7 @@ fn matmul_with_conj_impl( } }; - crate::for_each_raw(job_count, job, parallelism); + crate::utils::thread::for_each_raw(job_count, job, parallelism); depth_outer += k_chunk; } @@ -1723,7 +1728,7 @@ pub fn matmul_with_conj_gemm_dispatch( let mut a_copy = a.to_owned(); a_copy.resize_with(padded_m, k, |_, _| E::faer_zero()); let a_copy = a_copy.as_ref(); - let mut tmp = crate::Mat::::zeros(padded_m, n); + let mut tmp = crate::mat::Mat::::zeros(padded_m, n); let tmp_conj_b = match (conj_a, conj_b) { (Conj::Yes, Conj::Yes) | (Conj::No, Conj::No) => Conj::No, (Conj::Yes, Conj::No) | (Conj::No, Conj::Yes) => Conj::Yes, @@ -1791,7 +1796,7 @@ pub fn matmul_with_conj_gemm_dispatch( /// # Example /// /// ``` -/// use faer_core::{mat, mul::matmul_with_conj, unzipped, zipped, Conj, Mat, Parallelism}; +/// use faer::{linalg::matmul::matmul_with_conj, mat, unzipped, zipped, Conj, Mat, Parallelism}; /// /// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; /// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; @@ -1871,7 +1876,7 @@ pub fn matmul_with_conj( /// # Example /// /// ``` -/// use faer_core::{mat, mul::matmul, unzipped, zipped, Mat, Parallelism}; +/// use faer::{linalg::matmul::matmul, mat, unzipped, zipped, Mat, Parallelism}; /// /// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; /// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; @@ -1924,7 +1929,7 @@ macro_rules! stack_mat_16x16_begin { <$ty as $crate::Entity>::UNIT, #[inline(always)] |()| unsafe { - $crate::transmute_unchecked::< + $crate::linalg::entity::transmute_unchecked::< ::core::mem::MaybeUninit<[<$ty as $crate::Entity>::Unit; 16 * 16]>, [::core::mem::MaybeUninit<<$ty as $crate::Entity>::Unit>; 16 * 16], >(::core::mem::MaybeUninit::< @@ -1971,1365 +1976,7 @@ macro_rules! stack_mat_16x16_begin { /// Triangular matrix multiplication module, where some of the operands are treated as triangular /// matrices. -pub mod triangular { - use super::*; - use crate::{assert, debug_assert, join_raw, zip::Diag}; - - #[repr(u8)] - #[derive(Copy, Clone, Debug)] - pub(crate) enum DiagonalKind { - Zero, - Unit, - Generic, - } - - unsafe fn copy_lower( - mut dst: MatMut<'_, E>, - src: MatRef<'_, E>, - src_diag: DiagonalKind, - ) { - let n = dst.nrows(); - debug_assert!(n == dst.nrows()); - debug_assert!(n == dst.ncols()); - debug_assert!(n == src.nrows()); - debug_assert!(n == src.ncols()); - - let strict = match src_diag { - DiagonalKind::Zero => { - for j in 0..n { - dst.write_unchecked(j, j, E::faer_zero()); - } - true - } - DiagonalKind::Unit => { - for j in 0..n { - dst.write_unchecked(j, j, E::faer_one()); - } - true - } - DiagonalKind::Generic => false, - }; - - zipped!(dst.rb_mut()) - .for_each_triangular_upper(Diag::Skip, |unzipped!(mut dst)| dst.write(E::faer_zero())); - zipped!(dst, src).for_each_triangular_lower( - if strict { Diag::Skip } else { Diag::Include }, - |unzipped!(mut dst, src)| dst.write(src.read()), - ); - } - - unsafe fn accum_lower( - dst: MatMut<'_, E>, - src: MatRef<'_, E>, - skip_diag: bool, - alpha: Option, - ) { - let n = dst.nrows(); - debug_assert!(n == dst.nrows()); - debug_assert!(n == dst.ncols()); - debug_assert!(n == src.nrows()); - debug_assert!(n == src.ncols()); - - match alpha { - Some(alpha) => { - zipped!(dst, src).for_each_triangular_lower( - if skip_diag { Diag::Skip } else { Diag::Include }, - |unzipped!(mut dst, src)| { - dst.write(alpha.faer_mul(dst.read().faer_add(src.read()))) - }, - ); - } - None => { - zipped!(dst, src).for_each_triangular_lower( - if skip_diag { Diag::Skip } else { Diag::Include }, - |unzipped!(mut dst, src)| dst.write(src.read()), - ); - } - } - } - - #[inline] - unsafe fn copy_upper( - dst: MatMut<'_, E>, - src: MatRef<'_, E>, - src_diag: DiagonalKind, - ) { - copy_lower(dst.transpose_mut(), src.transpose(), src_diag) - } - - #[inline] - unsafe fn mul( - dst: MatMut<'_, E>, - lhs: MatRef<'_, E>, - rhs: MatRef<'_, E>, - alpha: Option, - beta: E, - conj_lhs: Conj, - conj_rhs: Conj, - parallelism: Parallelism, - ) { - super::matmul_with_conj(dst, lhs, conj_lhs, rhs, conj_rhs, alpha, beta, parallelism); - } - - unsafe fn mat_x_lower_into_lower_impl_unchecked( - dst: MatMut<'_, E>, - skip_diag: bool, - lhs: MatRef<'_, E>, - rhs: MatRef<'_, E>, - rhs_diag: DiagonalKind, - alpha: Option, - beta: E, - conj_lhs: Conj, - conj_rhs: Conj, - parallelism: Parallelism, - ) { - let n = dst.nrows(); - debug_assert!(n == dst.nrows()); - debug_assert!(n == dst.ncols()); - debug_assert!(n == lhs.nrows()); - debug_assert!(n == lhs.ncols()); - debug_assert!(n == rhs.nrows()); - debug_assert!(n == rhs.ncols()); - - if n <= 16 { - let op = { - #[inline(never)] - || { - stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); - stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); - - copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); - mul( - temp_dst.rb_mut(), - lhs, - temp_rhs.rb(), - None, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - accum_lower(dst, temp_dst.rb(), skip_diag, alpha); - } - }; - op(); - } else { - let bs = n / 2; - - let (mut dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); - let (lhs_top_left, lhs_top_right, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs); - let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); - - // lhs_bot_right × rhs_bot_left => dst_bot_left | mat × mat => mat | 1 - // lhs_bot_right × rhs_bot_right => dst_bot_right | mat × low => low | X - // - // lhs_top_left × rhs_top_left => dst_top_left | mat × low => low | X - // lhs_top_right × rhs_bot_left => dst_top_left | mat × mat => low | 1/2 - // lhs_bot_left × rhs_top_left => dst_bot_left | mat × low => mat | 1/2 - - mul( - dst_bot_left.rb_mut(), - lhs_bot_right, - rhs_bot_left, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - mat_x_lower_into_lower_impl_unchecked( - dst_bot_right, - skip_diag, - lhs_bot_right, - rhs_bot_right, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - - mat_x_lower_into_lower_impl_unchecked( - dst_top_left.rb_mut(), - skip_diag, - lhs_top_left, - rhs_top_left, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - mat_x_mat_into_lower_impl_unchecked( - dst_top_left, - skip_diag, - lhs_top_right, - rhs_bot_left, - Some(E::faer_one()), - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - mat_x_lower_impl_unchecked( - dst_bot_left, - lhs_bot_left, - rhs_top_left, - rhs_diag, - Some(E::faer_one()), - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - } - } - - unsafe fn mat_x_lower_impl_unchecked( - dst: MatMut<'_, E>, - lhs: MatRef<'_, E>, - rhs: MatRef<'_, E>, - rhs_diag: DiagonalKind, - alpha: Option, - beta: E, - conj_lhs: Conj, - conj_rhs: Conj, - parallelism: Parallelism, - ) { - let n = rhs.nrows(); - let m = lhs.nrows(); - debug_assert!(m == lhs.nrows()); - debug_assert!(n == lhs.ncols()); - debug_assert!(n == rhs.nrows()); - debug_assert!(n == rhs.ncols()); - debug_assert!(m == dst.nrows()); - debug_assert!(n == dst.ncols()); - - let join_parallelism = if n * n * m < 128 * 128 * 64 { - Parallelism::None - } else { - parallelism - }; - - if n <= 16 { - let op = { - #[inline(never)] - || { - stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); - - copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); - - mul( - dst, - lhs, - temp_rhs.rb(), - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - } - }; - op(); - } else { - // split rhs into 3 sections - // split lhs and dst into 2 sections - - let bs = n / 2; - - let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); - let (lhs_left, lhs_right) = lhs.split_at_col(bs); - let (mut dst_left, mut dst_right) = dst.split_at_col_mut(bs); - - join_raw( - |parallelism| { - mat_x_lower_impl_unchecked( - dst_left.rb_mut(), - lhs_left, - rhs_top_left, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - |parallelism| { - mat_x_lower_impl_unchecked( - dst_right.rb_mut(), - lhs_right, - rhs_bot_right, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - join_parallelism, - ); - mul( - dst_left, - lhs_right, - rhs_bot_left, - Some(E::faer_one()), - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - } - } - - unsafe fn lower_x_lower_into_lower_impl_unchecked( - dst: MatMut<'_, E>, - skip_diag: bool, - lhs: MatRef<'_, E>, - lhs_diag: DiagonalKind, - rhs: MatRef<'_, E>, - rhs_diag: DiagonalKind, - alpha: Option, - beta: E, - conj_lhs: Conj, - conj_rhs: Conj, - parallelism: Parallelism, - ) { - let n = dst.nrows(); - debug_assert!(n == lhs.nrows()); - debug_assert!(n == lhs.ncols()); - debug_assert!(n == rhs.nrows()); - debug_assert!(n == rhs.ncols()); - debug_assert!(n == dst.nrows()); - debug_assert!(n == dst.ncols()); - - if n <= 16 { - let op = { - #[inline(never)] - || { - stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); - stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E); - stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); - - copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag); - copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); - - mul( - temp_dst.rb_mut(), - temp_lhs.rb(), - temp_rhs.rb(), - None, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - accum_lower(dst, temp_dst.rb(), skip_diag, alpha); - } - }; - op(); - } else { - let bs = n / 2; - - let (dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); - let (lhs_top_left, _, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs); - let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); - - // lhs_top_left × rhs_top_left => dst_top_left | low × low => low | X - // lhs_bot_left × rhs_top_left => dst_bot_left | mat × low => mat | 1/2 - // lhs_bot_right × rhs_bot_left => dst_bot_left | low × mat => mat | 1/2 - // lhs_bot_right × rhs_bot_right => dst_bot_right | low × low => low | X - - lower_x_lower_into_lower_impl_unchecked( - dst_top_left, - skip_diag, - lhs_top_left, - lhs_diag, - rhs_top_left, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - mat_x_lower_impl_unchecked( - dst_bot_left.rb_mut(), - lhs_bot_left, - rhs_top_left, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - mat_x_lower_impl_unchecked( - dst_bot_left.reverse_rows_and_cols_mut().transpose_mut(), - rhs_bot_left.reverse_rows_and_cols().transpose(), - lhs_bot_right.reverse_rows_and_cols().transpose(), - lhs_diag, - Some(E::faer_one()), - beta, - conj_rhs, - conj_lhs, - parallelism, - ); - lower_x_lower_into_lower_impl_unchecked( - dst_bot_right, - skip_diag, - lhs_bot_right, - lhs_diag, - rhs_bot_right, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - } - } - - unsafe fn upper_x_lower_impl_unchecked( - dst: MatMut<'_, E>, - lhs: MatRef<'_, E>, - lhs_diag: DiagonalKind, - rhs: MatRef<'_, E>, - rhs_diag: DiagonalKind, - alpha: Option, - beta: E, - conj_lhs: Conj, - conj_rhs: Conj, - parallelism: Parallelism, - ) { - let n = dst.nrows(); - debug_assert!(n == lhs.nrows()); - debug_assert!(n == lhs.ncols()); - debug_assert!(n == rhs.nrows()); - debug_assert!(n == rhs.ncols()); - debug_assert!(n == dst.nrows()); - debug_assert!(n == dst.ncols()); - - if n <= 16 { - let op = { - #[inline(never)] - || { - stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E); - stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); - - copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag); - copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); - - mul( - dst, - temp_lhs.rb(), - temp_rhs.rb(), - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - } - }; - op(); - } else { - let bs = n / 2; - - let (mut dst_top_left, dst_top_right, dst_bot_left, dst_bot_right) = - dst.split_at_mut(bs, bs); - let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs); - let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); - - // lhs_top_right × rhs_bot_left => dst_top_left | mat × mat => mat | 1 - // lhs_top_left × rhs_top_left => dst_top_left | upp × low => mat | X - // - // lhs_top_right × rhs_bot_right => dst_top_right | mat × low => mat | 1/2 - // lhs_bot_right × rhs_bot_left => dst_bot_left | upp × mat => mat | 1/2 - // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => mat | X - - join_raw( - |_| { - mul( - dst_top_left.rb_mut(), - lhs_top_right, - rhs_bot_left, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - upper_x_lower_impl_unchecked( - dst_top_left, - lhs_top_left, - lhs_diag, - rhs_top_left, - rhs_diag, - Some(E::faer_one()), - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - |_| { - join_raw( - |_| { - mat_x_lower_impl_unchecked( - dst_top_right, - lhs_top_right, - rhs_bot_right, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - |_| { - mat_x_lower_impl_unchecked( - dst_bot_left.transpose_mut(), - rhs_bot_left.transpose(), - lhs_bot_right.transpose(), - lhs_diag, - alpha, - beta, - conj_rhs, - conj_lhs, - parallelism, - ) - }, - parallelism, - ); - - upper_x_lower_impl_unchecked( - dst_bot_right, - lhs_bot_right, - lhs_diag, - rhs_bot_right, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - parallelism, - ); - } - } - - unsafe fn upper_x_lower_into_lower_impl_unchecked( - dst: MatMut<'_, E>, - skip_diag: bool, - lhs: MatRef<'_, E>, - lhs_diag: DiagonalKind, - rhs: MatRef<'_, E>, - rhs_diag: DiagonalKind, - alpha: Option, - beta: E, - conj_lhs: Conj, - conj_rhs: Conj, - parallelism: Parallelism, - ) { - let n = dst.nrows(); - debug_assert!(n == lhs.nrows()); - debug_assert!(n == lhs.ncols()); - debug_assert!(n == rhs.nrows()); - debug_assert!(n == rhs.ncols()); - debug_assert!(n == dst.nrows()); - debug_assert!(n == dst.ncols()); - - if n <= 16 { - let op = { - #[inline(never)] - || { - stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); - stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E); - stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); - - copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag); - copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); - - mul( - temp_dst.rb_mut(), - temp_lhs.rb(), - temp_rhs.rb(), - None, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - - accum_lower(dst, temp_dst.rb(), skip_diag, alpha); - } - }; - op(); - } else { - let bs = n / 2; - - let (mut dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); - let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs); - let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); - - // lhs_top_left × rhs_top_left => dst_top_left | upp × low => low | X - // lhs_top_right × rhs_bot_left => dst_top_left | mat × mat => low | 1/2 - // - // lhs_bot_right × rhs_bot_left => dst_bot_left | upp × mat => mat | 1/2 - // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => low | X - - join_raw( - |_| { - mat_x_mat_into_lower_impl_unchecked( - dst_top_left.rb_mut(), - skip_diag, - lhs_top_right, - rhs_bot_left, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - upper_x_lower_into_lower_impl_unchecked( - dst_top_left, - skip_diag, - lhs_top_left, - lhs_diag, - rhs_top_left, - rhs_diag, - Some(E::faer_one()), - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - |_| { - mat_x_lower_impl_unchecked( - dst_bot_left.transpose_mut(), - rhs_bot_left.transpose(), - lhs_bot_right.transpose(), - lhs_diag, - alpha, - beta, - conj_rhs, - conj_lhs, - parallelism, - ); - upper_x_lower_into_lower_impl_unchecked( - dst_bot_right, - skip_diag, - lhs_bot_right, - lhs_diag, - rhs_bot_right, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - parallelism, - ); - } - } - - unsafe fn mat_x_mat_into_lower_impl_unchecked( - dst: MatMut<'_, E>, - skip_diag: bool, - lhs: MatRef<'_, E>, - rhs: MatRef<'_, E>, - alpha: Option, - beta: E, - conj_lhs: Conj, - conj_rhs: Conj, - parallelism: Parallelism, - ) { - debug_assert!(dst.nrows() == dst.ncols()); - debug_assert!(dst.nrows() == lhs.nrows()); - debug_assert!(dst.ncols() == rhs.ncols()); - debug_assert!(lhs.ncols() == rhs.nrows()); - - let n = dst.nrows(); - let k = lhs.ncols(); - - let join_parallelism = if n * n * k < 128 * 128 * 128 { - Parallelism::None - } else { - parallelism - }; - - if n <= 16 { - let op = { - #[inline(never)] - || { - stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); - - mul( - temp_dst.rb_mut(), - lhs, - rhs, - None, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - accum_lower(dst, temp_dst.rb(), skip_diag, alpha); - } - }; - op(); - } else { - let bs = n / 2; - let (dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); - let (lhs_top, lhs_bot) = lhs.split_at_row(bs); - let (rhs_left, rhs_right) = rhs.split_at_col(bs); - - join_raw( - |_| { - mul( - dst_bot_left, - lhs_bot, - rhs_left, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - |_| { - join_raw( - |_| { - mat_x_mat_into_lower_impl_unchecked( - dst_top_left, - skip_diag, - lhs_top, - rhs_left, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - |_| { - mat_x_mat_into_lower_impl_unchecked( - dst_bot_right, - skip_diag, - lhs_bot, - rhs_right, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - }, - join_parallelism, - ) - }, - join_parallelism, - ); - } - } - - /// Describes the parts of the matrix that must be accessed. - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - pub enum BlockStructure { - /// The full matrix is accessed. - Rectangular, - /// The lower triangular half (including the diagonal) is accessed. - TriangularLower, - /// The lower triangular half (excluding the diagonal) is accessed. - StrictTriangularLower, - /// The lower triangular half (excluding the diagonal, which is assumed to be equal to - /// `1.0`) is accessed. - UnitTriangularLower, - /// The upper triangular half (including the diagonal) is accessed. - TriangularUpper, - /// The upper triangular half (excluding the diagonal) is accessed. - StrictTriangularUpper, - /// The upper triangular half (excluding the diagonal, which is assumed to be equal to - /// `1.0`) is accessed. - UnitTriangularUpper, - } - - impl BlockStructure { - /// Checks if `self` is full. - #[inline] - pub fn is_dense(self) -> bool { - matches!(self, BlockStructure::Rectangular) - } - - /// Checks if `self` is triangular lower (either inclusive or exclusive). - #[inline] - pub fn is_lower(self) -> bool { - use BlockStructure::*; - matches!( - self, - TriangularLower | StrictTriangularLower | UnitTriangularLower - ) - } - - /// Checks if `self` is triangular upper (either inclusive or exclusive). - #[inline] - pub fn is_upper(self) -> bool { - use BlockStructure::*; - matches!( - self, - TriangularUpper | StrictTriangularUpper | UnitTriangularUpper - ) - } - - /// Returns the block structure corresponding to the transposed matrix. - #[inline] - pub fn transpose(self) -> Self { - use BlockStructure::*; - match self { - Rectangular => Rectangular, - TriangularLower => TriangularUpper, - StrictTriangularLower => StrictTriangularUpper, - UnitTriangularLower => UnitTriangularUpper, - TriangularUpper => TriangularLower, - StrictTriangularUpper => StrictTriangularLower, - UnitTriangularUpper => UnitTriangularLower, - } - } - - #[inline] - pub(crate) fn diag_kind(self) -> DiagonalKind { - use BlockStructure::*; - match self { - Rectangular | TriangularLower | TriangularUpper => DiagonalKind::Generic, - StrictTriangularLower | StrictTriangularUpper => DiagonalKind::Zero, - UnitTriangularLower | UnitTriangularUpper => DiagonalKind::Unit, - } - } - } - - /// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` (while optionally conjugating - /// either or both of the input matrices) and stores the result in `acc`. - /// - /// Performs the operation: - /// - `acc = beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `None` (in this case, the - /// preexisting values in `acc` are not read, so it is allowed to be a view over uninitialized - /// values if `E: Copy`), - /// - `acc = alpha * acc + beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `Some(_)`, - /// - /// The left hand side and right hand side may be interpreted as triangular depending on the - /// given corresponding matrix structure. - /// - /// For the destination matrix, the result is: - /// - fully computed if the structure is rectangular, - /// - only the triangular half (including the diagonal) is computed if the structure is - /// triangular, - /// - only the strict triangular half (excluding the diagonal) is computed if the structure is - /// strictly triangular or unit triangular. - /// - /// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is - /// `Conj::Yes`. - /// `Op_rhs` is the identity if `conj_rhs` is `Conj::No`, and the conjugation operation if it is - /// `Conj::Yes`. - /// - /// # Panics - /// - /// Panics if the matrix dimensions are not compatible for matrix multiplication. - /// i.e. - /// - `acc.nrows() == lhs.nrows()` - /// - `acc.ncols() == rhs.ncols()` - /// - `lhs.ncols() == rhs.nrows()` - /// - /// Additionally, matrices that are marked as triangular must be square, i.e., they must have - /// the same number of rows and columns. - /// - /// # Example - /// - /// ``` - /// use faer_core::{ - /// mat, - /// mul::triangular::{matmul_with_conj, BlockStructure}, - /// unzipped, zipped, Conj, Mat, Parallelism, - /// }; - /// - /// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; - /// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; - /// - /// let mut acc = Mat::::zeros(2, 2); - /// let target = mat![ - /// [ - /// 2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)), - /// 0.0, - /// ], - /// [ - /// 2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)), - /// 2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)), - /// ], - /// ]; - /// - /// matmul_with_conj( - /// acc.as_mut(), - /// BlockStructure::TriangularLower, - /// lhs.as_ref(), - /// BlockStructure::Rectangular, - /// Conj::No, - /// rhs.as_ref(), - /// BlockStructure::Rectangular, - /// Conj::No, - /// None, - /// 2.5, - /// Parallelism::None, - /// ); - /// - /// zipped!(acc.as_ref(), target.as_ref()) - /// .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10)); - /// ``` - #[track_caller] - #[inline] - pub fn matmul_with_conj( - acc: MatMut<'_, E>, - acc_structure: BlockStructure, - lhs: MatRef<'_, E>, - lhs_structure: BlockStructure, - conj_lhs: Conj, - rhs: MatRef<'_, E>, - rhs_structure: BlockStructure, - conj_rhs: Conj, - alpha: Option, - beta: E, - parallelism: Parallelism, - ) { - assert!(all( - acc.nrows() == lhs.nrows(), - acc.ncols() == rhs.ncols(), - lhs.ncols() == rhs.nrows(), - )); - - if !acc_structure.is_dense() { - assert!(acc.nrows() == acc.ncols()); - } - if !lhs_structure.is_dense() { - assert!(lhs.nrows() == lhs.ncols()); - } - if !rhs_structure.is_dense() { - assert!(rhs.nrows() == rhs.ncols()); - } - - unsafe { - matmul_unchecked( - acc, - acc_structure, - lhs, - lhs_structure, - conj_lhs, - rhs, - rhs_structure, - conj_rhs, - alpha, - beta, - parallelism, - ) - } - } - - /// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` and stores the result in - /// `acc`. - /// - /// Performs the operation: - /// - `acc = beta * lhs * rhs` if `alpha` is `None` (in this case, the preexisting values in - /// `acc` are not read, so it is allowed to be a view over uninitialized values if `E: Copy`), - /// - `acc = alpha * acc + beta * lhs * rhs` if `alpha` is `Some(_)`, - /// - /// The left hand side and right hand side may be interpreted as triangular depending on the - /// given corresponding matrix structure. - /// - /// For the destination matrix, the result is: - /// - fully computed if the structure is rectangular, - /// - only the triangular half (including the diagonal) is computed if the structure is - /// triangular, - /// - only the strict triangular half (excluding the diagonal) is computed if the structure is - /// strictly triangular or unit triangular. - /// - /// # Panics - /// - /// Panics if the matrix dimensions are not compatible for matrix multiplication. - /// i.e. - /// - `acc.nrows() == lhs.nrows()` - /// - `acc.ncols() == rhs.ncols()` - /// - `lhs.ncols() == rhs.nrows()` - /// - /// Additionally, matrices that are marked as triangular must be square, i.e., they must have - /// the same number of rows and columns. - /// - /// # Example - /// - /// ``` - /// use faer_core::{ - /// mat, - /// mul::triangular::{matmul, BlockStructure}, - /// unzipped, zipped, Conj, Mat, Parallelism, - /// }; - /// - /// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; - /// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; - /// - /// let mut acc = Mat::::zeros(2, 2); - /// let target = mat![ - /// [ - /// 2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)), - /// 0.0, - /// ], - /// [ - /// 2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)), - /// 2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)), - /// ], - /// ]; - /// - /// matmul( - /// acc.as_mut(), - /// BlockStructure::TriangularLower, - /// lhs.as_ref(), - /// BlockStructure::Rectangular, - /// rhs.as_ref(), - /// BlockStructure::Rectangular, - /// None, - /// 2.5, - /// Parallelism::None, - /// ); - /// - /// zipped!(acc.as_ref(), target.as_ref()) - /// .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10)); - /// ``` - #[track_caller] - #[inline] - pub fn matmul< - E: ComplexField, - LhsE: Conjugate, - RhsE: Conjugate, - >( - acc: MatMut<'_, E>, - acc_structure: BlockStructure, - lhs: MatRef<'_, LhsE>, - lhs_structure: BlockStructure, - rhs: MatRef<'_, RhsE>, - rhs_structure: BlockStructure, - alpha: Option, - beta: E, - parallelism: Parallelism, - ) { - let (lhs, conj_lhs) = lhs.canonicalize(); - let (rhs, conj_rhs) = rhs.canonicalize(); - matmul_with_conj( - acc, - acc_structure, - lhs, - lhs_structure, - conj_lhs, - rhs, - rhs_structure, - conj_rhs, - alpha, - beta, - parallelism, - ); - } - - unsafe fn matmul_unchecked( - acc: MatMut<'_, E>, - acc_structure: BlockStructure, - lhs: MatRef<'_, E>, - lhs_structure: BlockStructure, - conj_lhs: Conj, - rhs: MatRef<'_, E>, - rhs_structure: BlockStructure, - conj_rhs: Conj, - alpha: Option, - beta: E, - parallelism: Parallelism, - ) { - debug_assert!(acc.nrows() == lhs.nrows()); - debug_assert!(acc.ncols() == rhs.ncols()); - debug_assert!(lhs.ncols() == rhs.nrows()); - - if !acc_structure.is_dense() { - debug_assert!(acc.nrows() == acc.ncols()); - } - if !lhs_structure.is_dense() { - debug_assert!(lhs.nrows() == lhs.ncols()); - } - if !rhs_structure.is_dense() { - debug_assert!(rhs.nrows() == rhs.ncols()); - } - - let mut acc = acc; - let mut lhs = lhs; - let mut rhs = rhs; - - let mut acc_structure = acc_structure; - let mut lhs_structure = lhs_structure; - let mut rhs_structure = rhs_structure; - - let mut conj_lhs = conj_lhs; - let mut conj_rhs = conj_rhs; - - // if either the lhs or the rhs is triangular - if rhs_structure.is_lower() { - // do nothing - false - } else if rhs_structure.is_upper() { - // invert acc, lhs and rhs - acc = acc.reverse_rows_and_cols_mut(); - lhs = lhs.reverse_rows_and_cols(); - rhs = rhs.reverse_rows_and_cols(); - acc_structure = acc_structure.transpose(); - lhs_structure = lhs_structure.transpose(); - rhs_structure = rhs_structure.transpose(); - false - } else if lhs_structure.is_lower() { - // invert and transpose - acc = acc.reverse_rows_and_cols_mut().transpose_mut(); - (lhs, rhs) = ( - rhs.reverse_rows_and_cols().transpose(), - lhs.reverse_rows_and_cols().transpose(), - ); - (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs); - (lhs_structure, rhs_structure) = (rhs_structure, lhs_structure); - true - } else if lhs_structure.is_upper() { - // transpose - acc_structure = acc_structure.transpose(); - acc = acc.transpose_mut(); - (lhs, rhs) = (rhs.transpose(), lhs.transpose()); - (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs); - (lhs_structure, rhs_structure) = (rhs_structure.transpose(), lhs_structure.transpose()); - true - } else { - // do nothing - false - }; - - let clear_upper = |acc: MatMut<'_, E>, skip_diag: bool| match &alpha { - &Some(alpha) => zipped!(acc).for_each_triangular_upper( - if skip_diag { Diag::Skip } else { Diag::Include }, - |unzipped!(mut acc)| acc.write(alpha.faer_mul(acc.read())), - ), - - None => zipped!(acc).for_each_triangular_upper( - if skip_diag { Diag::Skip } else { Diag::Include }, - |unzipped!(mut acc)| acc.write(E::faer_zero()), - ), - }; - - let skip_diag = matches!( - acc_structure, - BlockStructure::StrictTriangularLower - | BlockStructure::StrictTriangularUpper - | BlockStructure::UnitTriangularLower - | BlockStructure::UnitTriangularUpper - ); - let lhs_diag = lhs_structure.diag_kind(); - let rhs_diag = rhs_structure.diag_kind(); - - if acc_structure.is_dense() { - if lhs_structure.is_dense() && rhs_structure.is_dense() { - mul(acc, lhs, rhs, alpha, beta, conj_lhs, conj_rhs, parallelism); - } else { - debug_assert!(rhs_structure.is_lower()); - - if lhs_structure.is_dense() { - mat_x_lower_impl_unchecked( - acc, - lhs, - rhs, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - } else if lhs_structure.is_lower() { - clear_upper(acc.rb_mut(), true); - lower_x_lower_into_lower_impl_unchecked( - acc, - false, - lhs, - lhs_diag, - rhs, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - } else { - debug_assert!(lhs_structure.is_upper()); - upper_x_lower_impl_unchecked( - acc, - lhs, - lhs_diag, - rhs, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - } - } - } else if acc_structure.is_lower() { - if lhs_structure.is_dense() && rhs_structure.is_dense() { - mat_x_mat_into_lower_impl_unchecked( - acc, - skip_diag, - lhs, - rhs, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - } else { - debug_assert!(rhs_structure.is_lower()); - if lhs_structure.is_dense() { - mat_x_lower_into_lower_impl_unchecked( - acc, - skip_diag, - lhs, - rhs, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ); - } else if lhs_structure.is_lower() { - lower_x_lower_into_lower_impl_unchecked( - acc, - skip_diag, - lhs, - lhs_diag, - rhs, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - } else { - upper_x_lower_into_lower_impl_unchecked( - acc, - skip_diag, - lhs, - lhs_diag, - rhs, - rhs_diag, - alpha, - beta, - conj_lhs, - conj_rhs, - parallelism, - ) - } - } - } else if lhs_structure.is_dense() && rhs_structure.is_dense() { - mat_x_mat_into_lower_impl_unchecked( - acc.transpose_mut(), - skip_diag, - rhs.transpose(), - lhs.transpose(), - alpha, - beta, - conj_rhs, - conj_lhs, - parallelism, - ) - } else { - debug_assert!(rhs_structure.is_lower()); - if lhs_structure.is_dense() { - // lower part of lhs does not contribute to result - upper_x_lower_into_lower_impl_unchecked( - acc.transpose_mut(), - skip_diag, - rhs.transpose(), - rhs_diag, - lhs.transpose(), - lhs_diag, - alpha, - beta, - conj_rhs, - conj_lhs, - parallelism, - ) - } else if lhs_structure.is_lower() { - if !skip_diag { - match &alpha { - &Some(alpha) => { - zipped!( - acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(), - lhs.diagonal().column_vector().as_2d(), - rhs.diagonal().column_vector().as_2d(), - ) - .for_each( - |unzipped!(mut acc, lhs, rhs)| { - acc.write( - (alpha.faer_mul(acc.read())).faer_add( - beta.faer_mul(lhs.read().faer_mul(rhs.read())), - ), - ) - }, - ); - } - None => { - zipped!( - acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(), - lhs.diagonal().column_vector().as_2d(), - rhs.diagonal().column_vector().as_2d(), - ) - .for_each( - |unzipped!(mut acc, lhs, rhs)| { - acc.write(beta.faer_mul(lhs.read().faer_mul(rhs.read()))) - }, - ); - } - } - } - clear_upper(acc.rb_mut(), true); - } else { - debug_assert!(lhs_structure.is_upper()); - upper_x_lower_into_lower_impl_unchecked( - acc.transpose_mut(), - skip_diag, - rhs.transpose(), - rhs_diag, - lhs.transpose(), - lhs_diag, - alpha, - beta, - conj_rhs, - conj_lhs, - parallelism, - ) - } - } - } -} +pub mod triangular; #[cfg(test)] mod tests { @@ -3337,7 +1984,7 @@ mod tests { triangular::{BlockStructure, DiagonalKind}, *, }; - use crate::{assert, Mat}; + use crate::{assert, mat::Mat}; use assert_approx_eq::assert_approx_eq; use num_complex::Complex32; @@ -3543,7 +2190,7 @@ mod tests { } }; - crate::for_each_raw(m * n, job, parallelism); + crate::utils::thread::for_each_raw(m * n, job, parallelism); } fn test_matmul_impl( diff --git a/src/linalg/matmul/triangular.rs b/src/linalg/matmul/triangular.rs new file mode 100644 index 0000000000000000000000000000000000000000..4457a529c0f1c3fea5deb6cb318d5b1b2ad5af7b --- /dev/null +++ b/src/linalg/matmul/triangular.rs @@ -0,0 +1,1346 @@ +use super::*; +use crate::{assert, debug_assert, linalg::zip::Diag, utils::thread::join_raw}; + +#[repr(u8)] +#[derive(Copy, Clone, Debug)] +pub(crate) enum DiagonalKind { + Zero, + Unit, + Generic, +} + +unsafe fn copy_lower( + mut dst: MatMut<'_, E>, + src: MatRef<'_, E>, + src_diag: DiagonalKind, +) { + let n = dst.nrows(); + debug_assert!(n == dst.nrows()); + debug_assert!(n == dst.ncols()); + debug_assert!(n == src.nrows()); + debug_assert!(n == src.ncols()); + + let strict = match src_diag { + DiagonalKind::Zero => { + for j in 0..n { + dst.write_unchecked(j, j, E::faer_zero()); + } + true + } + DiagonalKind::Unit => { + for j in 0..n { + dst.write_unchecked(j, j, E::faer_one()); + } + true + } + DiagonalKind::Generic => false, + }; + + zipped!(dst.rb_mut()) + .for_each_triangular_upper(Diag::Skip, |unzipped!(mut dst)| dst.write(E::faer_zero())); + zipped!(dst, src).for_each_triangular_lower( + if strict { Diag::Skip } else { Diag::Include }, + |unzipped!(mut dst, src)| dst.write(src.read()), + ); +} + +unsafe fn accum_lower( + dst: MatMut<'_, E>, + src: MatRef<'_, E>, + skip_diag: bool, + alpha: Option, +) { + let n = dst.nrows(); + debug_assert!(n == dst.nrows()); + debug_assert!(n == dst.ncols()); + debug_assert!(n == src.nrows()); + debug_assert!(n == src.ncols()); + + match alpha { + Some(alpha) => { + zipped!(dst, src).for_each_triangular_lower( + if skip_diag { Diag::Skip } else { Diag::Include }, + |unzipped!(mut dst, src)| { + dst.write(alpha.faer_mul(dst.read().faer_add(src.read()))) + }, + ); + } + None => { + zipped!(dst, src).for_each_triangular_lower( + if skip_diag { Diag::Skip } else { Diag::Include }, + |unzipped!(mut dst, src)| dst.write(src.read()), + ); + } + } +} + +#[inline] +unsafe fn copy_upper( + dst: MatMut<'_, E>, + src: MatRef<'_, E>, + src_diag: DiagonalKind, +) { + copy_lower(dst.transpose_mut(), src.transpose(), src_diag) +} + +#[inline] +unsafe fn mul( + dst: MatMut<'_, E>, + lhs: MatRef<'_, E>, + rhs: MatRef<'_, E>, + alpha: Option, + beta: E, + conj_lhs: Conj, + conj_rhs: Conj, + parallelism: Parallelism, +) { + super::matmul_with_conj(dst, lhs, conj_lhs, rhs, conj_rhs, alpha, beta, parallelism); +} + +unsafe fn mat_x_lower_into_lower_impl_unchecked( + dst: MatMut<'_, E>, + skip_diag: bool, + lhs: MatRef<'_, E>, + rhs: MatRef<'_, E>, + rhs_diag: DiagonalKind, + alpha: Option, + beta: E, + conj_lhs: Conj, + conj_rhs: Conj, + parallelism: Parallelism, +) { + let n = dst.nrows(); + debug_assert!(n == dst.nrows()); + debug_assert!(n == dst.ncols()); + debug_assert!(n == lhs.nrows()); + debug_assert!(n == lhs.ncols()); + debug_assert!(n == rhs.nrows()); + debug_assert!(n == rhs.ncols()); + + if n <= 16 { + let op = { + #[inline(never)] + || { + stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); + stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); + + copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); + mul( + temp_dst.rb_mut(), + lhs, + temp_rhs.rb(), + None, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + accum_lower(dst, temp_dst.rb(), skip_diag, alpha); + } + }; + op(); + } else { + let bs = n / 2; + + let (mut dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); + let (lhs_top_left, lhs_top_right, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs); + let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); + + // lhs_bot_right × rhs_bot_left => dst_bot_left | mat × mat => mat | 1 + // lhs_bot_right × rhs_bot_right => dst_bot_right | mat × low => low | X + // + // lhs_top_left × rhs_top_left => dst_top_left | mat × low => low | X + // lhs_top_right × rhs_bot_left => dst_top_left | mat × mat => low | 1/2 + // lhs_bot_left × rhs_top_left => dst_bot_left | mat × low => mat | 1/2 + + mul( + dst_bot_left.rb_mut(), + lhs_bot_right, + rhs_bot_left, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + mat_x_lower_into_lower_impl_unchecked( + dst_bot_right, + skip_diag, + lhs_bot_right, + rhs_bot_right, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + + mat_x_lower_into_lower_impl_unchecked( + dst_top_left.rb_mut(), + skip_diag, + lhs_top_left, + rhs_top_left, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + mat_x_mat_into_lower_impl_unchecked( + dst_top_left, + skip_diag, + lhs_top_right, + rhs_bot_left, + Some(E::faer_one()), + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + mat_x_lower_impl_unchecked( + dst_bot_left, + lhs_bot_left, + rhs_top_left, + rhs_diag, + Some(E::faer_one()), + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + } +} + +unsafe fn mat_x_lower_impl_unchecked( + dst: MatMut<'_, E>, + lhs: MatRef<'_, E>, + rhs: MatRef<'_, E>, + rhs_diag: DiagonalKind, + alpha: Option, + beta: E, + conj_lhs: Conj, + conj_rhs: Conj, + parallelism: Parallelism, +) { + let n = rhs.nrows(); + let m = lhs.nrows(); + debug_assert!(m == lhs.nrows()); + debug_assert!(n == lhs.ncols()); + debug_assert!(n == rhs.nrows()); + debug_assert!(n == rhs.ncols()); + debug_assert!(m == dst.nrows()); + debug_assert!(n == dst.ncols()); + + let join_parallelism = if n * n * m < 128 * 128 * 64 { + Parallelism::None + } else { + parallelism + }; + + if n <= 16 { + let op = { + #[inline(never)] + || { + stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); + + copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); + + mul( + dst, + lhs, + temp_rhs.rb(), + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + } + }; + op(); + } else { + // split rhs into 3 sections + // split lhs and dst into 2 sections + + let bs = n / 2; + + let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); + let (lhs_left, lhs_right) = lhs.split_at_col(bs); + let (mut dst_left, mut dst_right) = dst.split_at_col_mut(bs); + + join_raw( + |parallelism| { + mat_x_lower_impl_unchecked( + dst_left.rb_mut(), + lhs_left, + rhs_top_left, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + |parallelism| { + mat_x_lower_impl_unchecked( + dst_right.rb_mut(), + lhs_right, + rhs_bot_right, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + join_parallelism, + ); + mul( + dst_left, + lhs_right, + rhs_bot_left, + Some(E::faer_one()), + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + } +} + +unsafe fn lower_x_lower_into_lower_impl_unchecked( + dst: MatMut<'_, E>, + skip_diag: bool, + lhs: MatRef<'_, E>, + lhs_diag: DiagonalKind, + rhs: MatRef<'_, E>, + rhs_diag: DiagonalKind, + alpha: Option, + beta: E, + conj_lhs: Conj, + conj_rhs: Conj, + parallelism: Parallelism, +) { + let n = dst.nrows(); + debug_assert!(n == lhs.nrows()); + debug_assert!(n == lhs.ncols()); + debug_assert!(n == rhs.nrows()); + debug_assert!(n == rhs.ncols()); + debug_assert!(n == dst.nrows()); + debug_assert!(n == dst.ncols()); + + if n <= 16 { + let op = { + #[inline(never)] + || { + stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); + stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E); + stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); + + copy_lower(temp_lhs.rb_mut(), lhs, lhs_diag); + copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); + + mul( + temp_dst.rb_mut(), + temp_lhs.rb(), + temp_rhs.rb(), + None, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + accum_lower(dst, temp_dst.rb(), skip_diag, alpha); + } + }; + op(); + } else { + let bs = n / 2; + + let (dst_top_left, _, mut dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); + let (lhs_top_left, _, lhs_bot_left, lhs_bot_right) = lhs.split_at(bs, bs); + let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); + + // lhs_top_left × rhs_top_left => dst_top_left | low × low => low | X + // lhs_bot_left × rhs_top_left => dst_bot_left | mat × low => mat | 1/2 + // lhs_bot_right × rhs_bot_left => dst_bot_left | low × mat => mat | 1/2 + // lhs_bot_right × rhs_bot_right => dst_bot_right | low × low => low | X + + lower_x_lower_into_lower_impl_unchecked( + dst_top_left, + skip_diag, + lhs_top_left, + lhs_diag, + rhs_top_left, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + mat_x_lower_impl_unchecked( + dst_bot_left.rb_mut(), + lhs_bot_left, + rhs_top_left, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + mat_x_lower_impl_unchecked( + dst_bot_left.reverse_rows_and_cols_mut().transpose_mut(), + rhs_bot_left.reverse_rows_and_cols().transpose(), + lhs_bot_right.reverse_rows_and_cols().transpose(), + lhs_diag, + Some(E::faer_one()), + beta, + conj_rhs, + conj_lhs, + parallelism, + ); + lower_x_lower_into_lower_impl_unchecked( + dst_bot_right, + skip_diag, + lhs_bot_right, + lhs_diag, + rhs_bot_right, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + } +} + +unsafe fn upper_x_lower_impl_unchecked( + dst: MatMut<'_, E>, + lhs: MatRef<'_, E>, + lhs_diag: DiagonalKind, + rhs: MatRef<'_, E>, + rhs_diag: DiagonalKind, + alpha: Option, + beta: E, + conj_lhs: Conj, + conj_rhs: Conj, + parallelism: Parallelism, +) { + let n = dst.nrows(); + debug_assert!(n == lhs.nrows()); + debug_assert!(n == lhs.ncols()); + debug_assert!(n == rhs.nrows()); + debug_assert!(n == rhs.ncols()); + debug_assert!(n == dst.nrows()); + debug_assert!(n == dst.ncols()); + + if n <= 16 { + let op = { + #[inline(never)] + || { + stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E); + stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); + + copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag); + copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); + + mul( + dst, + temp_lhs.rb(), + temp_rhs.rb(), + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + } + }; + op(); + } else { + let bs = n / 2; + + let (mut dst_top_left, dst_top_right, dst_bot_left, dst_bot_right) = + dst.split_at_mut(bs, bs); + let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs); + let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); + + // lhs_top_right × rhs_bot_left => dst_top_left | mat × mat => mat | 1 + // lhs_top_left × rhs_top_left => dst_top_left | upp × low => mat | X + // + // lhs_top_right × rhs_bot_right => dst_top_right | mat × low => mat | 1/2 + // lhs_bot_right × rhs_bot_left => dst_bot_left | upp × mat => mat | 1/2 + // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => mat | X + + join_raw( + |_| { + mul( + dst_top_left.rb_mut(), + lhs_top_right, + rhs_bot_left, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + upper_x_lower_impl_unchecked( + dst_top_left, + lhs_top_left, + lhs_diag, + rhs_top_left, + rhs_diag, + Some(E::faer_one()), + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + |_| { + join_raw( + |_| { + mat_x_lower_impl_unchecked( + dst_top_right, + lhs_top_right, + rhs_bot_right, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + |_| { + mat_x_lower_impl_unchecked( + dst_bot_left.transpose_mut(), + rhs_bot_left.transpose(), + lhs_bot_right.transpose(), + lhs_diag, + alpha, + beta, + conj_rhs, + conj_lhs, + parallelism, + ) + }, + parallelism, + ); + + upper_x_lower_impl_unchecked( + dst_bot_right, + lhs_bot_right, + lhs_diag, + rhs_bot_right, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + parallelism, + ); + } +} + +unsafe fn upper_x_lower_into_lower_impl_unchecked( + dst: MatMut<'_, E>, + skip_diag: bool, + lhs: MatRef<'_, E>, + lhs_diag: DiagonalKind, + rhs: MatRef<'_, E>, + rhs_diag: DiagonalKind, + alpha: Option, + beta: E, + conj_lhs: Conj, + conj_rhs: Conj, + parallelism: Parallelism, +) { + let n = dst.nrows(); + debug_assert!(n == lhs.nrows()); + debug_assert!(n == lhs.ncols()); + debug_assert!(n == rhs.nrows()); + debug_assert!(n == rhs.ncols()); + debug_assert!(n == dst.nrows()); + debug_assert!(n == dst.ncols()); + + if n <= 16 { + let op = { + #[inline(never)] + || { + stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); + stack_mat_16x16_begin!(temp_lhs, n, n, lhs.row_stride(), lhs.col_stride(), E); + stack_mat_16x16_begin!(temp_rhs, n, n, rhs.row_stride(), rhs.col_stride(), E); + + copy_upper(temp_lhs.rb_mut(), lhs, lhs_diag); + copy_lower(temp_rhs.rb_mut(), rhs, rhs_diag); + + mul( + temp_dst.rb_mut(), + temp_lhs.rb(), + temp_rhs.rb(), + None, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + + accum_lower(dst, temp_dst.rb(), skip_diag, alpha); + } + }; + op(); + } else { + let bs = n / 2; + + let (mut dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); + let (lhs_top_left, lhs_top_right, _, lhs_bot_right) = lhs.split_at(bs, bs); + let (rhs_top_left, _, rhs_bot_left, rhs_bot_right) = rhs.split_at(bs, bs); + + // lhs_top_left × rhs_top_left => dst_top_left | upp × low => low | X + // lhs_top_right × rhs_bot_left => dst_top_left | mat × mat => low | 1/2 + // + // lhs_bot_right × rhs_bot_left => dst_bot_left | upp × mat => mat | 1/2 + // lhs_bot_right × rhs_bot_right => dst_bot_right | upp × low => low | X + + join_raw( + |_| { + mat_x_mat_into_lower_impl_unchecked( + dst_top_left.rb_mut(), + skip_diag, + lhs_top_right, + rhs_bot_left, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + upper_x_lower_into_lower_impl_unchecked( + dst_top_left, + skip_diag, + lhs_top_left, + lhs_diag, + rhs_top_left, + rhs_diag, + Some(E::faer_one()), + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + |_| { + mat_x_lower_impl_unchecked( + dst_bot_left.transpose_mut(), + rhs_bot_left.transpose(), + lhs_bot_right.transpose(), + lhs_diag, + alpha, + beta, + conj_rhs, + conj_lhs, + parallelism, + ); + upper_x_lower_into_lower_impl_unchecked( + dst_bot_right, + skip_diag, + lhs_bot_right, + lhs_diag, + rhs_bot_right, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + parallelism, + ); + } +} + +unsafe fn mat_x_mat_into_lower_impl_unchecked( + dst: MatMut<'_, E>, + skip_diag: bool, + lhs: MatRef<'_, E>, + rhs: MatRef<'_, E>, + alpha: Option, + beta: E, + conj_lhs: Conj, + conj_rhs: Conj, + parallelism: Parallelism, +) { + debug_assert!(dst.nrows() == dst.ncols()); + debug_assert!(dst.nrows() == lhs.nrows()); + debug_assert!(dst.ncols() == rhs.ncols()); + debug_assert!(lhs.ncols() == rhs.nrows()); + + let n = dst.nrows(); + let k = lhs.ncols(); + + let join_parallelism = if n * n * k < 128 * 128 * 128 { + Parallelism::None + } else { + parallelism + }; + + if n <= 16 { + let op = { + #[inline(never)] + || { + stack_mat_16x16_begin!(temp_dst, n, n, dst.row_stride(), dst.col_stride(), E); + + mul( + temp_dst.rb_mut(), + lhs, + rhs, + None, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + accum_lower(dst, temp_dst.rb(), skip_diag, alpha); + } + }; + op(); + } else { + let bs = n / 2; + let (dst_top_left, _, dst_bot_left, dst_bot_right) = dst.split_at_mut(bs, bs); + let (lhs_top, lhs_bot) = lhs.split_at_row(bs); + let (rhs_left, rhs_right) = rhs.split_at_col(bs); + + join_raw( + |_| { + mul( + dst_bot_left, + lhs_bot, + rhs_left, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + |_| { + join_raw( + |_| { + mat_x_mat_into_lower_impl_unchecked( + dst_top_left, + skip_diag, + lhs_top, + rhs_left, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + |_| { + mat_x_mat_into_lower_impl_unchecked( + dst_bot_right, + skip_diag, + lhs_bot, + rhs_right, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + }, + join_parallelism, + ) + }, + join_parallelism, + ); + } +} + +/// Describes the parts of the matrix that must be accessed. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BlockStructure { + /// The full matrix is accessed. + Rectangular, + /// The lower triangular half (including the diagonal) is accessed. + TriangularLower, + /// The lower triangular half (excluding the diagonal) is accessed. + StrictTriangularLower, + /// The lower triangular half (excluding the diagonal, which is assumed to be equal to + /// `1.0`) is accessed. + UnitTriangularLower, + /// The upper triangular half (including the diagonal) is accessed. + TriangularUpper, + /// The upper triangular half (excluding the diagonal) is accessed. + StrictTriangularUpper, + /// The upper triangular half (excluding the diagonal, which is assumed to be equal to + /// `1.0`) is accessed. + UnitTriangularUpper, +} + +impl BlockStructure { + /// Checks if `self` is full. + #[inline] + pub fn is_dense(self) -> bool { + matches!(self, BlockStructure::Rectangular) + } + + /// Checks if `self` is triangular lower (either inclusive or exclusive). + #[inline] + pub fn is_lower(self) -> bool { + use BlockStructure::*; + matches!( + self, + TriangularLower | StrictTriangularLower | UnitTriangularLower + ) + } + + /// Checks if `self` is triangular upper (either inclusive or exclusive). + #[inline] + pub fn is_upper(self) -> bool { + use BlockStructure::*; + matches!( + self, + TriangularUpper | StrictTriangularUpper | UnitTriangularUpper + ) + } + + /// Returns the block structure corresponding to the transposed matrix. + #[inline] + pub fn transpose(self) -> Self { + use BlockStructure::*; + match self { + Rectangular => Rectangular, + TriangularLower => TriangularUpper, + StrictTriangularLower => StrictTriangularUpper, + UnitTriangularLower => UnitTriangularUpper, + TriangularUpper => TriangularLower, + StrictTriangularUpper => StrictTriangularLower, + UnitTriangularUpper => UnitTriangularLower, + } + } + + #[inline] + pub(crate) fn diag_kind(self) -> DiagonalKind { + use BlockStructure::*; + match self { + Rectangular | TriangularLower | TriangularUpper => DiagonalKind::Generic, + StrictTriangularLower | StrictTriangularUpper => DiagonalKind::Zero, + UnitTriangularLower | UnitTriangularUpper => DiagonalKind::Unit, + } + } +} + +/// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` (while optionally conjugating +/// either or both of the input matrices) and stores the result in `acc`. +/// +/// Performs the operation: +/// - `acc = beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `None` (in this case, the preexisting +/// values in `acc` are not read, so it is allowed to be a view over uninitialized values if `E: +/// Copy`), +/// - `acc = alpha * acc + beta * Op_lhs(lhs) * Op_rhs(rhs)` if `alpha` is `Some(_)`, +/// +/// The left hand side and right hand side may be interpreted as triangular depending on the +/// given corresponding matrix structure. +/// +/// For the destination matrix, the result is: +/// - fully computed if the structure is rectangular, +/// - only the triangular half (including the diagonal) is computed if the structure is +/// triangular, +/// - only the strict triangular half (excluding the diagonal) is computed if the structure is +/// strictly triangular or unit triangular. +/// +/// `Op_lhs` is the identity if `conj_lhs` is `Conj::No`, and the conjugation operation if it is +/// `Conj::Yes`. +/// `Op_rhs` is the identity if `conj_rhs` is `Conj::No`, and the conjugation operation if it is +/// `Conj::Yes`. +/// +/// # Panics +/// +/// Panics if the matrix dimensions are not compatible for matrix multiplication. +/// i.e. +/// - `acc.nrows() == lhs.nrows()` +/// - `acc.ncols() == rhs.ncols()` +/// - `lhs.ncols() == rhs.nrows()` +/// +/// Additionally, matrices that are marked as triangular must be square, i.e., they must have +/// the same number of rows and columns. +/// +/// # Example +/// +/// ``` +/// use faer::{ +/// linalg::matmul::triangular::{matmul_with_conj, BlockStructure}, +/// mat, unzipped, zipped, Conj, Mat, Parallelism, +/// }; +/// +/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; +/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; +/// +/// let mut acc = Mat::::zeros(2, 2); +/// let target = mat![ +/// [ +/// 2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)), +/// 0.0, +/// ], +/// [ +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)), +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)), +/// ], +/// ]; +/// +/// matmul_with_conj( +/// acc.as_mut(), +/// BlockStructure::TriangularLower, +/// lhs.as_ref(), +/// BlockStructure::Rectangular, +/// Conj::No, +/// rhs.as_ref(), +/// BlockStructure::Rectangular, +/// Conj::No, +/// None, +/// 2.5, +/// Parallelism::None, +/// ); +/// +/// zipped!(acc.as_ref(), target.as_ref()) +/// .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10)); +/// ``` +#[track_caller] +#[inline] +pub fn matmul_with_conj( + acc: MatMut<'_, E>, + acc_structure: BlockStructure, + lhs: MatRef<'_, E>, + lhs_structure: BlockStructure, + conj_lhs: Conj, + rhs: MatRef<'_, E>, + rhs_structure: BlockStructure, + conj_rhs: Conj, + alpha: Option, + beta: E, + parallelism: Parallelism, +) { + assert!(all( + acc.nrows() == lhs.nrows(), + acc.ncols() == rhs.ncols(), + lhs.ncols() == rhs.nrows(), + )); + + if !acc_structure.is_dense() { + assert!(acc.nrows() == acc.ncols()); + } + if !lhs_structure.is_dense() { + assert!(lhs.nrows() == lhs.ncols()); + } + if !rhs_structure.is_dense() { + assert!(rhs.nrows() == rhs.ncols()); + } + + unsafe { + matmul_unchecked( + acc, + acc_structure, + lhs, + lhs_structure, + conj_lhs, + rhs, + rhs_structure, + conj_rhs, + alpha, + beta, + parallelism, + ) + } +} + +/// Computes the matrix product `[alpha * acc] + beta * lhs * rhs` and stores the result in +/// `acc`. +/// +/// Performs the operation: +/// - `acc = beta * lhs * rhs` if `alpha` is `None` (in this case, the preexisting values in `acc` +/// are not read, so it is allowed to be a view over uninitialized values if `E: Copy`), +/// - `acc = alpha * acc + beta * lhs * rhs` if `alpha` is `Some(_)`, +/// +/// The left hand side and right hand side may be interpreted as triangular depending on the +/// given corresponding matrix structure. +/// +/// For the destination matrix, the result is: +/// - fully computed if the structure is rectangular, +/// - only the triangular half (including the diagonal) is computed if the structure is +/// triangular, +/// - only the strict triangular half (excluding the diagonal) is computed if the structure is +/// strictly triangular or unit triangular. +/// +/// # Panics +/// +/// Panics if the matrix dimensions are not compatible for matrix multiplication. +/// i.e. +/// - `acc.nrows() == lhs.nrows()` +/// - `acc.ncols() == rhs.ncols()` +/// - `lhs.ncols() == rhs.nrows()` +/// +/// Additionally, matrices that are marked as triangular must be square, i.e., they must have +/// the same number of rows and columns. +/// +/// # Example +/// +/// ``` +/// use faer::{ +/// linalg::matmul::triangular::{matmul, BlockStructure}, +/// mat, unzipped, zipped, Conj, Mat, Parallelism, +/// }; +/// +/// let lhs = mat![[0.0, 2.0], [1.0, 3.0]]; +/// let rhs = mat![[4.0, 6.0], [5.0, 7.0]]; +/// +/// let mut acc = Mat::::zeros(2, 2); +/// let target = mat![ +/// [ +/// 2.5 * (lhs.read(0, 0) * rhs.read(0, 0) + lhs.read(0, 1) * rhs.read(1, 0)), +/// 0.0, +/// ], +/// [ +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 0) + lhs.read(1, 1) * rhs.read(1, 0)), +/// 2.5 * (lhs.read(1, 0) * rhs.read(0, 1) + lhs.read(1, 1) * rhs.read(1, 1)), +/// ], +/// ]; +/// +/// matmul( +/// acc.as_mut(), +/// BlockStructure::TriangularLower, +/// lhs.as_ref(), +/// BlockStructure::Rectangular, +/// rhs.as_ref(), +/// BlockStructure::Rectangular, +/// None, +/// 2.5, +/// Parallelism::None, +/// ); +/// +/// zipped!(acc.as_ref(), target.as_ref()) +/// .for_each(|unzipped!(acc, target)| assert!((acc.read() - target.read()).abs() < 1e-10)); +/// ``` +#[track_caller] +#[inline] +pub fn matmul, RhsE: Conjugate>( + acc: MatMut<'_, E>, + acc_structure: BlockStructure, + lhs: MatRef<'_, LhsE>, + lhs_structure: BlockStructure, + rhs: MatRef<'_, RhsE>, + rhs_structure: BlockStructure, + alpha: Option, + beta: E, + parallelism: Parallelism, +) { + let (lhs, conj_lhs) = lhs.canonicalize(); + let (rhs, conj_rhs) = rhs.canonicalize(); + matmul_with_conj( + acc, + acc_structure, + lhs, + lhs_structure, + conj_lhs, + rhs, + rhs_structure, + conj_rhs, + alpha, + beta, + parallelism, + ); +} + +unsafe fn matmul_unchecked( + acc: MatMut<'_, E>, + acc_structure: BlockStructure, + lhs: MatRef<'_, E>, + lhs_structure: BlockStructure, + conj_lhs: Conj, + rhs: MatRef<'_, E>, + rhs_structure: BlockStructure, + conj_rhs: Conj, + alpha: Option, + beta: E, + parallelism: Parallelism, +) { + debug_assert!(acc.nrows() == lhs.nrows()); + debug_assert!(acc.ncols() == rhs.ncols()); + debug_assert!(lhs.ncols() == rhs.nrows()); + + if !acc_structure.is_dense() { + debug_assert!(acc.nrows() == acc.ncols()); + } + if !lhs_structure.is_dense() { + debug_assert!(lhs.nrows() == lhs.ncols()); + } + if !rhs_structure.is_dense() { + debug_assert!(rhs.nrows() == rhs.ncols()); + } + + let mut acc = acc; + let mut lhs = lhs; + let mut rhs = rhs; + + let mut acc_structure = acc_structure; + let mut lhs_structure = lhs_structure; + let mut rhs_structure = rhs_structure; + + let mut conj_lhs = conj_lhs; + let mut conj_rhs = conj_rhs; + + // if either the lhs or the rhs is triangular + if rhs_structure.is_lower() { + // do nothing + false + } else if rhs_structure.is_upper() { + // invert acc, lhs and rhs + acc = acc.reverse_rows_and_cols_mut(); + lhs = lhs.reverse_rows_and_cols(); + rhs = rhs.reverse_rows_and_cols(); + acc_structure = acc_structure.transpose(); + lhs_structure = lhs_structure.transpose(); + rhs_structure = rhs_structure.transpose(); + false + } else if lhs_structure.is_lower() { + // invert and transpose + acc = acc.reverse_rows_and_cols_mut().transpose_mut(); + (lhs, rhs) = ( + rhs.reverse_rows_and_cols().transpose(), + lhs.reverse_rows_and_cols().transpose(), + ); + (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs); + (lhs_structure, rhs_structure) = (rhs_structure, lhs_structure); + true + } else if lhs_structure.is_upper() { + // transpose + acc_structure = acc_structure.transpose(); + acc = acc.transpose_mut(); + (lhs, rhs) = (rhs.transpose(), lhs.transpose()); + (conj_lhs, conj_rhs) = (conj_rhs, conj_lhs); + (lhs_structure, rhs_structure) = (rhs_structure.transpose(), lhs_structure.transpose()); + true + } else { + // do nothing + false + }; + + let clear_upper = |acc: MatMut<'_, E>, skip_diag: bool| match &alpha { + &Some(alpha) => zipped!(acc).for_each_triangular_upper( + if skip_diag { Diag::Skip } else { Diag::Include }, + |unzipped!(mut acc)| acc.write(alpha.faer_mul(acc.read())), + ), + + None => zipped!(acc).for_each_triangular_upper( + if skip_diag { Diag::Skip } else { Diag::Include }, + |unzipped!(mut acc)| acc.write(E::faer_zero()), + ), + }; + + let skip_diag = matches!( + acc_structure, + BlockStructure::StrictTriangularLower + | BlockStructure::StrictTriangularUpper + | BlockStructure::UnitTriangularLower + | BlockStructure::UnitTriangularUpper + ); + let lhs_diag = lhs_structure.diag_kind(); + let rhs_diag = rhs_structure.diag_kind(); + + if acc_structure.is_dense() { + if lhs_structure.is_dense() && rhs_structure.is_dense() { + mul(acc, lhs, rhs, alpha, beta, conj_lhs, conj_rhs, parallelism); + } else { + debug_assert!(rhs_structure.is_lower()); + + if lhs_structure.is_dense() { + mat_x_lower_impl_unchecked( + acc, + lhs, + rhs, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + } else if lhs_structure.is_lower() { + clear_upper(acc.rb_mut(), true); + lower_x_lower_into_lower_impl_unchecked( + acc, + false, + lhs, + lhs_diag, + rhs, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + } else { + debug_assert!(lhs_structure.is_upper()); + upper_x_lower_impl_unchecked( + acc, + lhs, + lhs_diag, + rhs, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + } + } + } else if acc_structure.is_lower() { + if lhs_structure.is_dense() && rhs_structure.is_dense() { + mat_x_mat_into_lower_impl_unchecked( + acc, + skip_diag, + lhs, + rhs, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + } else { + debug_assert!(rhs_structure.is_lower()); + if lhs_structure.is_dense() { + mat_x_lower_into_lower_impl_unchecked( + acc, + skip_diag, + lhs, + rhs, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ); + } else if lhs_structure.is_lower() { + lower_x_lower_into_lower_impl_unchecked( + acc, + skip_diag, + lhs, + lhs_diag, + rhs, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + } else { + upper_x_lower_into_lower_impl_unchecked( + acc, + skip_diag, + lhs, + lhs_diag, + rhs, + rhs_diag, + alpha, + beta, + conj_lhs, + conj_rhs, + parallelism, + ) + } + } + } else if lhs_structure.is_dense() && rhs_structure.is_dense() { + mat_x_mat_into_lower_impl_unchecked( + acc.transpose_mut(), + skip_diag, + rhs.transpose(), + lhs.transpose(), + alpha, + beta, + conj_rhs, + conj_lhs, + parallelism, + ) + } else { + debug_assert!(rhs_structure.is_lower()); + if lhs_structure.is_dense() { + // lower part of lhs does not contribute to result + upper_x_lower_into_lower_impl_unchecked( + acc.transpose_mut(), + skip_diag, + rhs.transpose(), + rhs_diag, + lhs.transpose(), + lhs_diag, + alpha, + beta, + conj_rhs, + conj_lhs, + parallelism, + ) + } else if lhs_structure.is_lower() { + if !skip_diag { + match &alpha { + &Some(alpha) => { + zipped!( + acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(), + lhs.diagonal().column_vector().as_2d(), + rhs.diagonal().column_vector().as_2d(), + ) + .for_each(|unzipped!(mut acc, lhs, rhs)| { + acc.write( + (alpha.faer_mul(acc.read())) + .faer_add(beta.faer_mul(lhs.read().faer_mul(rhs.read()))), + ) + }); + } + None => { + zipped!( + acc.rb_mut().diagonal_mut().column_vector_mut().as_2d_mut(), + lhs.diagonal().column_vector().as_2d(), + rhs.diagonal().column_vector().as_2d(), + ) + .for_each(|unzipped!(mut acc, lhs, rhs)| { + acc.write(beta.faer_mul(lhs.read().faer_mul(rhs.read()))) + }); + } + } + } + clear_upper(acc.rb_mut(), true); + } else { + debug_assert!(lhs_structure.is_upper()); + upper_x_lower_into_lower_impl_unchecked( + acc.transpose_mut(), + skip_diag, + rhs.transpose(), + rhs_diag, + lhs.transpose(), + lhs_diag, + alpha, + beta, + conj_rhs, + conj_lhs, + parallelism, + ) + } + } +} diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..c505b2f8f0da741670ded782096f1e47775a0286 --- /dev/null +++ b/src/linalg/mod.rs @@ -0,0 +1,171 @@ +//! Linear algebra module. +//! +//! Contains low level routines and the implementation of their corresponding high level +//! wrappers. +//! +//! # Memory allocation +//! Since most `faer` crates aim to expose a low level api for optimal performance, most algorithms +//! try to defer memory allocation to the user. +//! +//! However, since a lot of algorithms need some form of temporary space for intermediate +//! computations, they may ask for a slice of memory for that purpose, by taking a [`stack: +//! PodStack`](dyn_stack::PodStack) parameter. A [`PodStack`] is a thin wrapper over a slice of +//! memory bytes. This memory may come from any valid source (heap allocation, fixed-size array on +//! the stack, etc.). The functions taking a [`PodStack`] parameter have a corresponding function +//! with a similar name ending in `_req` that returns the memory requirements of the algorithm. For +//! example: +//! [`householder::apply_block_householder_on_the_left_in_place_with_conj`] and +//! [`householder::apply_block_householder_on_the_left_in_place_req`]. +//! +//! The memory stack may be reused in user-code to avoid repeated allocations, and it is also +//! possible to compute the sum ([`dyn_stack::StackReq::all_of`]) or union +//! ([`dyn_stack::StackReq::any_of`]) of multiple requirements, in order to optimally combine them +//! into a single allocation. +//! +//! After computing a [`dyn_stack::StackReq`], one can query its size and alignment to allocate the +//! required memory. The simplest way to do so is through [`dyn_stack::GlobalMemBuffer::new`]. +//! +//! # Entity trait +//! Matrices are built on top of the [`Entity`] trait, which describes the prefered memory +//! storage layout for a given type `E`. An entity can be decomposed into a group of units: for +//! a natively supported type ([`f32`], [`f64`], [`c32`](crate::complex_native::c32), +//! [`c64`](crate::complex_native::c64)), the unit is simply the type itself, and a group +//! contains a single element. On the other hand, for a type with a more specific preferred +//! layout, like an extended precision floating point type, or a dual number type, the unit +//! would be one of the natively supported types, and the group would be a structure holding +//! the components that build up the full value. +//! +//! To take a more specific example: [`num_complex::Complex`] has a storage memory layout +//! that differs from that of [`c64`](crate::complex_native::c64) (see +//! [`faer::complex_native`](crate::complex_native) for more details). Its real and complex +//! components are stored separately, so its unit type is `f64`, while its group type is +//! `Complex`. In practice, this means that for a `Mat`, methods such as +//! [`Mat::col_as_slice`] will return a `&[f64]`. Meanwhile, for a `Mat>`, +//! [`Mat::col_as_slice`] will return `Complex<&[f64]>`, which holds two slices, each pointing +//! respectively to a view over the real and the imaginary components. +//! +//! While the design of the entity trait is unconventional, it helps us achieve much higher +//! performance when targetting non native types, due to the design matching the typical +//! preffered CPU layout for SIMD operations. And for native types, since [`Group` is just +//! `T`](Entity#impl-Entity-for-f64), the entity layer is a no-op, and the matrix layout is +//! compatible with the classic contiguous layout that's commonly used by other libraries. + +use crate::{ + mat::{self, matalloc::align_for, *}, + utils::DivCeil, +}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; +use entity::{ComplexField, Entity}; + +pub use faer_entity as entity; + +pub mod zip; + +pub mod householder; +pub mod matmul; +pub mod triangular_inverse; +pub mod triangular_solve; + +pub mod cholesky; +pub mod lu; +pub mod qr; + +pub mod evd; +pub mod svd; + +/// High level linear system solvers. +pub mod solvers; + +pub(crate) mod kron_impl; +mod mat_ops; +pub(crate) mod reductions; + +pub use kron_impl::kron; + +#[inline] +pub(crate) fn col_stride(nrows: usize) -> usize { + if !crate::mat::matalloc::is_vectorizable::() || nrows >= isize::MAX as usize { + nrows + } else { + nrows + .msrv_checked_next_multiple_of(align_for::() / core::mem::size_of::()) + .unwrap() + } +} + +/// Returns the stack requirements for creating a temporary matrix with the given dimensions. +#[inline] +pub fn temp_mat_req(nrows: usize, ncols: usize) -> Result { + let col_stride = col_stride::(nrows); + let alloc_size = ncols.checked_mul(col_stride).ok_or(SizeOverflow)?; + let additional = StackReq::try_new_aligned::(alloc_size, align_for::())?; + + let req = Ok(StackReq::empty()); + let (req, _) = E::faer_map_with_context(req, E::UNIT, &mut { + #[inline(always)] + |req, ()| { + let req = match req { + Ok(req) => req.try_and(additional), + _ => Err(SizeOverflow), + }; + + (req, ()) + } + }); + + req +} + +/// Creates a temporary matrix of constant values, from the given memory stack. +pub fn temp_mat_constant( + nrows: usize, + ncols: usize, + value: E, + stack: PodStack<'_>, +) -> (MatMut<'_, E>, PodStack<'_>) { + let (mut mat, stack) = temp_mat_uninit::(nrows, ncols, stack); + mat.as_mut().fill(value); + (mat, stack) +} + +/// Creates a temporary matrix of zero values, from the given memory stack. +pub fn temp_mat_zeroed( + nrows: usize, + ncols: usize, + stack: PodStack<'_>, +) -> (MatMut<'_, E>, PodStack<'_>) { + let (mut mat, stack) = temp_mat_uninit::(nrows, ncols, stack); + mat.as_mut().fill_zero(); + (mat, stack) +} + +/// Creates a temporary matrix of untouched values, from the given memory stack. +pub fn temp_mat_uninit( + nrows: usize, + ncols: usize, + stack: PodStack<'_>, +) -> (MatMut<'_, E>, PodStack<'_>) { + let col_stride = col_stride::(nrows); + let alloc_size = ncols.checked_mul(col_stride).unwrap(); + + let (stack, alloc) = E::faer_map_with_context(stack, E::UNIT, &mut { + #[inline(always)] + |stack, ()| { + let (alloc, stack) = + stack.make_aligned_raw::(alloc_size, align_for::()); + (stack, alloc) + } + }); + ( + unsafe { + mat::from_raw_parts_mut( + E::faer_map(alloc, |alloc| alloc.as_mut_ptr()), + nrows, + ncols, + 1, + col_stride as isize, + ) + }, + stack, + ) +} diff --git a/faer-libs/faer-qr/src/col_pivoting/compute.rs b/src/linalg/qr/col_pivoting/compute.rs similarity index 96% rename from faer-libs/faer-qr/src/col_pivoting/compute.rs rename to src/linalg/qr/col_pivoting/compute.rs index 7a99d388e4f0709dfd7ecc73a99789deb8a3d859..13926243d9a68198ba77e995615080d049a351d3 100644 --- a/faer-libs/faer-qr/src/col_pivoting/compute.rs +++ b/src/linalg/qr/col_pivoting/compute.rs @@ -1,14 +1,18 @@ -pub use crate::no_pivoting::compute::recommended_blocksize; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, c32, c64, debug_assert, - group_helpers::*, - householder::upgrade_householder_factor, - mul::inner_prod::{self, inner_prod_with_conj_arch}, - permutation::{swap_cols, Index, PermutationMut, SignedIndex}, - transmute_unchecked, unzipped, zipped, ComplexField, Conj, DivCeil, Entity, MatMut, MatRef, - Parallelism, SimdCtx, +pub use crate::linalg::qr::no_pivoting::compute::recommended_blocksize; +use crate::{ + assert, + complex_native::{c32, c64}, + debug_assert, + linalg::{ + householder::upgrade_householder_factor, + matmul::inner_prod::{self, inner_prod_with_conj_arch}, + }, + perm::{swap_cols_idx as swap_cols, PermRef}, + unzipped, + utils::{simd::*, slice::*, DivCeil}, + zipped, ComplexField, Conj, Entity, Index, MatMut, MatRef, Parallelism, SignedIndex, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use pulp::Simd; use reborrow::*; @@ -446,14 +450,12 @@ fn qr_in_place_colmajor( let first_col = first_col.col_mut(0); let (mut first_head, mut first_tail) = first_col.split_at_mut(1); - let tail_squared_norm = norm2(arch, first_tail.rb().as_2d()); + let tail_norm = first_tail.norm_l2(); - // TODO: replace norm2 functions by non-squared norm, even in the simd kernel - #[allow(deprecated)] - let (tau, beta) = faer_core::householder::make_householder_in_place( + let (tau, beta) = crate::linalg::householder::make_householder_in_place( Some(first_tail.rb_mut().as_2d_mut()), first_head.read(0), - tail_squared_norm, + tail_norm, ); first_head.write(0, beta); let tau_inv = tau.faer_inv(); @@ -488,7 +490,9 @@ fn qr_in_place_colmajor( } #[cfg(feature = "rayon")] Parallelism::Rayon(_) => { - use faer_core::{for_each_raw, par_split_indices, parallelism_degree, Ptr}; + use crate::utils::thread::{ + for_each_raw, par_split_indices, parallelism_degree, Ptr, + }; let n_threads = parallelism_degree(parallelism); let mut biggest_col = vec![(E::Real::faer_zero(), 0_usize); n_threads]; @@ -648,6 +652,7 @@ fn default_disable_parallelism(m: usize, n: usize) -> bool { prod < 192 * 256 } +/// QR factorization tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] pub struct ColPivQrComputeParams { @@ -680,8 +685,11 @@ pub fn qr_in_place_req( Ok(StackReq::default()) } +/// Information about the resulting QR factorization. #[derive(Copy, Clone, Debug)] pub struct ColPivQrInfo { + /// Number of transpositions that were performed, can be used to compute the determinant of + /// $P$. pub transposition_count: usize, } @@ -720,7 +728,7 @@ pub fn qr_in_place<'out, I: Index, E: ComplexField>( parallelism: Parallelism, stack: PodStack<'_>, params: ColPivQrComputeParams, -) -> (ColPivQrInfo, PermutationMut<'out, I, E>) { +) -> (ColPivQrInfo, PermRef<'out, I>) { fn implementation<'out, I: Index, E: ComplexField>( matrix: MatMut<'_, E>, householder_factor: MatMut<'_, E>, @@ -729,7 +737,7 @@ pub fn qr_in_place<'out, I: Index, E: ComplexField>( parallelism: Parallelism, stack: PodStack<'_>, params: ColPivQrComputeParams, - ) -> (usize, PermutationMut<'out, I, E>) { + ) -> (usize, PermRef<'out, I>) { { let truncate = ::truncate; @@ -741,7 +749,7 @@ pub fn qr_in_place<'out, I: Index, E: ComplexField>( assert!(all(col_perm.len() == n, col_perm_inv.len() == n)); #[cfg(feature = "perf-warn")] - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "QR with column pivoting prefers column-major matrix. Found row-major matrix."); } else { @@ -804,7 +812,7 @@ pub fn qr_in_place<'out, I: Index, E: ComplexField>( } (n_transpositions, unsafe { - PermutationMut::new_unchecked(col_perm, col_perm_inv) + PermRef::new_unchecked(col_perm, col_perm_inv) }) } } @@ -822,24 +830,27 @@ pub fn qr_in_place<'out, I: Index, E: ComplexField>( ColPivQrInfo { transposition_count: n_transpositions, }, - perm.uncanonicalize::(), + perm.uncanonicalized::(), ) } #[cfg(test)] mod tests { use super::*; - use assert_approx_eq::assert_approx_eq; - use faer_core::{ - assert, c64, - householder::{ - apply_block_householder_sequence_on_the_left_in_place_req, - apply_block_householder_sequence_on_the_left_in_place_with_conj, + use crate::{ + assert, + complex_native::c64, + linalg::{ + householder::{ + apply_block_householder_sequence_on_the_left_in_place_req, + apply_block_householder_sequence_on_the_left_in_place_with_conj, + }, + matmul::matmul, + zip::Diag, }, - mul::matmul, - zip::Diag, Conj, Mat, MatRef, }; + use assert_approx_eq::assert_approx_eq; use matrixcompare::assert_matrix_eq; use rand::random; diff --git a/faer-libs/faer-qr/src/col_pivoting/inverse.rs b/src/linalg/qr/col_pivoting/inverse.rs similarity index 88% rename from faer-libs/faer-qr/src/col_pivoting/inverse.rs rename to src/linalg/qr/col_pivoting/inverse.rs index a701e5f2b0be1eb9dabb4ac10db89168f06b3ee7..1f6d18fb648a9d472965112b4aa1e76e71c7e8c8 100644 --- a/faer-libs/faer-qr/src/col_pivoting/inverse.rs +++ b/src/linalg/qr/col_pivoting/inverse.rs @@ -1,12 +1,13 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - householder::apply_block_householder_sequence_transpose_on_the_right_in_place_with_conj, - inverse::invert_upper_triangular, - permutation::{permute_cols_in_place_req, permute_rows_in_place, Index, PermutationRef}, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, - Parallelism, + linalg::{ + householder::apply_block_householder_sequence_transpose_on_the_right_in_place_with_conj, + temp_mat_req, temp_mat_uninit, triangular_inverse::invert_upper_triangular, + }, + perm::{permute_cols_in_place_req, permute_rows_in_place, PermRef}, + unzipped, zipped, ComplexField, Conj, Entity, Index, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the inverse of a matrix, given its QR decomposition with column pivoting, @@ -26,7 +27,7 @@ pub fn invert( dst: MatMut<'_, E>, qr_factors: MatRef<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -47,7 +48,7 @@ pub fn invert( // zero bottom part zipped!(dst.rb_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { dst.write(E::faer_zero()) }); @@ -78,7 +79,7 @@ pub fn invert( pub fn invert_in_place( qr_factors: MatMut<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -130,9 +131,16 @@ pub fn invert_in_place_req( #[cfg(test)] mod tests { use super::*; - use crate::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}; + use crate::{ + assert, + complex_native::c64, + linalg::{ + matmul::matmul, + qr::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}, + }, + Mat, + }; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c64, mul::matmul, Mat}; use rand::prelude::*; use std::cell::RefCell; @@ -164,7 +172,7 @@ mod tests { let mut qr = mat.clone(); let mut householder_factor = Mat::zeros(blocksize, n); - let parallelism = faer_core::Parallelism::Rayon(0); + let parallelism = crate::Parallelism::Rayon(0); let mut perm = vec![0usize; n]; let mut perm_inv = vec![0; n]; diff --git a/faer-libs/faer-qr/src/col_pivoting/mod.rs b/src/linalg/qr/col_pivoting/mod.rs similarity index 57% rename from faer-libs/faer-qr/src/col_pivoting/mod.rs rename to src/linalg/qr/col_pivoting/mod.rs index b21aeca74adab0e166f2db4ac621afce36f0e8e1..ff153b019978211a8b2347f46465e241d8251781 100644 --- a/faer-libs/faer-qr/src/col_pivoting/mod.rs +++ b/src/linalg/qr/col_pivoting/mod.rs @@ -3,7 +3,11 @@ //! where $P$ is a permutation matrix, $Q$ is a unitary matrix (represented as a block Householder //! sequence), and $R$ is an upper trapezoidal matrix. +/// Computing the decomposition. pub mod compute; +/// Reconstructing the inverse of the original matrix from the decomposition. pub mod inverse; +/// Reconstructing the original matrix from the decomposition. pub mod reconstruct; +/// Solving a linear system usin the decomposition. pub mod solve; diff --git a/faer-libs/faer-qr/src/col_pivoting/reconstruct.rs b/src/linalg/qr/col_pivoting/reconstruct.rs similarity index 86% rename from faer-libs/faer-qr/src/col_pivoting/reconstruct.rs rename to src/linalg/qr/col_pivoting/reconstruct.rs index 9485f03dbfb1b48883ac6d3bd7d1fb0a411098b3..cf068a9f766edecb3f043837a8474fc3b959a6be 100644 --- a/faer-libs/faer-qr/src/col_pivoting/reconstruct.rs +++ b/src/linalg/qr/col_pivoting/reconstruct.rs @@ -1,11 +1,13 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - householder::apply_block_householder_sequence_on_the_left_in_place_with_conj, - permutation::{permute_cols_in_place, permute_cols_in_place_req, Index, PermutationRef}, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, - Parallelism, + linalg::{ + householder::apply_block_householder_sequence_on_the_left_in_place_with_conj, temp_mat_req, + temp_mat_uninit, + }, + perm::{permute_cols_in_place, permute_cols_in_place_req, PermRef}, + unzipped, zipped, ComplexField, Conj, Entity, Index, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the reconstructed matrix, given its QR decomposition, and stores the @@ -24,7 +26,7 @@ pub fn reconstruct( dst: MatMut<'_, E>, qr_factors: MatRef<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -36,14 +38,14 @@ pub fn reconstruct( let mut stack = stack; // copy R - zipped!(dst.rb_mut(), qr_factors) - .for_each_triangular_upper(faer_core::zip::Diag::Include, |unzipped!(mut dst, src)| { - dst.write(src.read()) - }); + zipped!(dst.rb_mut(), qr_factors).for_each_triangular_upper( + crate::linalg::zip::Diag::Include, + |unzipped!(mut dst, src)| dst.write(src.read()), + ); // zero bottom part zipped!(dst.rb_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { dst.write(E::faer_zero()) }); @@ -73,7 +75,7 @@ pub fn reconstruct( pub fn reconstruct_in_place( qr_factors: MatMut<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, parallelism: Parallelism, stack: PodStack<'_>, ) { @@ -125,9 +127,13 @@ pub fn reconstruct_in_place_req( #[cfg(test)] mod tests { use super::*; - use crate::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}; + use crate::{ + assert, + complex_native::c64, + linalg::qr::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}, + Mat, + }; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c64, Mat}; use rand::prelude::*; use std::cell::RefCell; @@ -159,7 +165,7 @@ mod tests { let mut qr = mat.clone(); let mut householder_factor = Mat::zeros(blocksize, n); - let parallelism = faer_core::Parallelism::Rayon(0); + let parallelism = crate::Parallelism::Rayon(0); let mut perm = vec![0usize; n]; let mut perm_inv = vec![0; n]; diff --git a/faer-libs/faer-qr/src/col_pivoting/solve.rs b/src/linalg/qr/col_pivoting/solve.rs similarity index 95% rename from faer-libs/faer-qr/src/col_pivoting/solve.rs rename to src/linalg/qr/col_pivoting/solve.rs index 162cd70a26e5348e2af58049936ef9d7b98e4015..da373c869f3e142fcbefb07414b5195ab3c20b04 100644 --- a/faer-libs/faer-qr/src/col_pivoting/solve.rs +++ b/src/linalg/qr/col_pivoting/solve.rs @@ -1,11 +1,9 @@ -use crate::no_pivoting; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - permutation::{ - permute_rows, permute_rows_in_place, permute_rows_in_place_req, Index, PermutationRef, - }, - ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, +use crate::{ + linalg::qr::no_pivoting, + perm::{permute_rows, permute_rows_in_place, permute_rows_in_place_req, PermRef}, + ComplexField, Conj, Entity, Index, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the size and alignment of required workspace for solving a linear system defined by a @@ -85,7 +83,7 @@ pub fn solve_transpose_req( pub fn solve_in_place( qr_factors: MatRef<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatMut<'_, E>, parallelism: Parallelism, @@ -127,7 +125,7 @@ pub fn solve_in_place( pub fn solve_transpose_in_place( qr_factors: MatRef<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatMut<'_, E>, parallelism: Parallelism, @@ -169,7 +167,7 @@ pub fn solve( dst: MatMut<'_, E>, qr_factors: MatRef<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatRef<'_, E>, parallelism: Parallelism, @@ -212,7 +210,7 @@ pub fn solve_transpose( dst: MatMut<'_, E>, qr_factors: MatRef<'_, E>, householder_factor: MatRef<'_, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatRef<'_, E>, parallelism: Parallelism, @@ -234,8 +232,15 @@ pub fn solve_transpose( #[cfg(test)] mod tests { use super::*; - use crate::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}; - use faer_core::{assert, c32, c64, mul::matmul_with_conj, Mat}; + use crate::{ + assert, + complex_native::{c32, c64}, + linalg::{ + matmul::matmul_with_conj, + qr::col_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}, + }, + Mat, + }; use rand::random; macro_rules! make_stack { diff --git a/faer-libs/faer-qr/src/lib.rs b/src/linalg/qr/mod.rs similarity index 88% rename from faer-libs/faer-qr/src/lib.rs rename to src/linalg/qr/mod.rs index ff453f1910b4d4bfe1cb39e86d0504dcdeea541d..54e00fca37db27ce3dbdcd76c0661e99d4000ca1 100644 --- a/faer-libs/faer-qr/src/lib.rs +++ b/src/linalg/qr/mod.rs @@ -19,8 +19,11 @@ //! //! ``` //! use assert_approx_eq::assert_approx_eq; -//! use dyn_stack::{PodStack, GlobalPodBuffer, StackReq}; -//! use faer_core::{mat, solve, Conj, Mat, Parallelism}; +//! use dyn_stack::{GlobalPodBuffer, PodStack, StackReq}; +//! use faer::{ +//! linalg::{householder, qr::no_pivoting::compute, triangular_solve}, +//! mat, Conj, Mat, Parallelism, +//! }; //! use reborrow::*; //! //! // we start by defining matrices A and B that define our least-squares problem. @@ -59,13 +62,13 @@ //! let rank = a.nrows().min(a.ncols()); //! //! // we choose the recommended block size for the householder factors of our problem. -//! let blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::(a.nrows(), a.ncols()); +//! let blocksize = compute::recommended_blocksize::(a.nrows(), a.ncols()); //! //! // we allocate the memory for the operations that we perform //! let mut mem = //! GlobalPodBuffer::new(StackReq::any_of( //! [ -//! faer_qr::no_pivoting::compute::qr_in_place_req::( +//! compute::qr_in_place_req::( //! a.nrows(), //! a.ncols(), //! blocksize, @@ -73,7 +76,7 @@ //! Default::default(), //! ) //! .unwrap(), -//! faer_core::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::< +//! householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::< //! f64, //! >(a.nrows(), blocksize, b.ncols()) //! .unwrap(), @@ -81,9 +84,9 @@ //! )); //! let mut stack = PodStack::new(&mut mem); //! -//! let mut qr = a.clone(); +//! let mut qr = a; //! let mut h_factor = Mat::zeros(blocksize, rank); -//! faer_qr::no_pivoting::compute::qr_in_place( +//! compute::qr_in_place( //! qr.as_mut(), //! h_factor.as_mut(), //! Parallelism::None, @@ -97,7 +100,7 @@ //! let mut solution = b.clone(); //! //! // compute Q^H×B -//! faer_core::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj( +//! householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj( //! qr.as_ref(), //! h_factor.as_ref(), //! Conj::Yes, @@ -109,7 +112,7 @@ //! solution.resize_with(rank, b.ncols(), |_, _| unreachable!()); //! //! // compute R_rect^{-1} Q_thin^H×B -//! solve::solve_upper_triangular_in_place( +//! triangular_solve::solve_upper_triangular_in_place( //! qr.as_ref().split_at_row(rank).0, //! solution.as_mut(), //! Parallelism::None, @@ -134,10 +137,12 @@ mod tests { #[test] fn test_example() { - use crate::no_pivoting::compute; + use crate::{ + linalg::{householder, qr::no_pivoting::compute, triangular_solve}, + mat, Conj, Mat, Parallelism, + }; use assert_approx_eq::assert_approx_eq; use dyn_stack::{GlobalPodBuffer, PodStack, StackReq}; - use faer_core::{householder, mat, solve, Conj, Mat, Parallelism}; use reborrow::*; // we start by defining matrices A and B that define our least-squares problem. @@ -189,7 +194,7 @@ mod tests { Default::default(), ) .unwrap(), - faer_core::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::< + householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::< f64, >(a.nrows(), blocksize, b.ncols()) .unwrap(), @@ -224,7 +229,7 @@ mod tests { solution.resize_with(rank, b.ncols(), |_, _| unreachable!()); // compute R_rect^{-1} Q_thin^H×B - solve::solve_upper_triangular_in_place( + triangular_solve::solve_upper_triangular_in_place( qr.as_ref().split_at_row(rank).0, solution.as_mut(), Parallelism::None, diff --git a/faer-libs/faer-qr/src/no_pivoting/compute.rs b/src/linalg/qr/no_pivoting/compute.rs similarity index 95% rename from faer-libs/faer-qr/src/no_pivoting/compute.rs rename to src/linalg/qr/no_pivoting/compute.rs index d1a042b06a817e4c30366f82eb7f734a9ed53f92..b6cf8c7b184201a2f7e542ff855021cd44ef81a4 100644 --- a/faer-libs/faer-qr/src/no_pivoting/compute.rs +++ b/src/linalg/qr/no_pivoting/compute.rs @@ -1,15 +1,19 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - group_helpers::*, - householder::{ - apply_block_householder_transpose_on_the_left_in_place_with_conj, - upgrade_householder_factor, + linalg::{ + entity::SimdCtx, + householder::{ + self, apply_block_householder_transpose_on_the_left_in_place_with_conj, + upgrade_householder_factor, + }, + matmul::inner_prod::{self, inner_prod_with_conj_arch}, + temp_mat_req, }, - mul::inner_prod::{self, inner_prod_with_conj_arch}, - temp_mat_req, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, - SimdCtx, + unzipped, + utils::{simd::*, slice::*}, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use reborrow::*; @@ -34,7 +38,7 @@ fn qr_in_place_unblocked( let tail_norm = first_col_tail.norm_l2(); - let (tau, beta) = faer_core::householder::make_householder_in_place_v2( + let (tau, beta) = householder::make_householder_in_place( Some(first_col_tail.rb_mut().as_2d_mut()), first_col_head.read(0), tail_norm, @@ -183,6 +187,7 @@ fn default_disable_blocking(m: usize, n: usize) -> bool { prod < 48 * 48 } +/// QR factorization tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] pub struct QrComputeParams { @@ -324,7 +329,7 @@ pub fn qr_in_place( )); #[cfg(feature = "perf-warn")] - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "QR prefers column-major matrix. Found row-major matrix."); } else { @@ -365,17 +370,20 @@ pub fn qr_in_place_req( #[cfg(test)] mod tests { use super::*; - use assert_approx_eq::assert_approx_eq; - use faer_core::{ - assert, c64, - householder::{ - apply_block_householder_sequence_on_the_left_in_place_req, - apply_block_householder_sequence_on_the_left_in_place_with_conj, + use crate::{ + assert, + complex_native::c64, + linalg::{ + householder::{ + apply_block_householder_sequence_on_the_left_in_place_req, + apply_block_householder_sequence_on_the_left_in_place_with_conj, + }, + matmul::matmul, + zip::Diag, }, - mul::matmul, - zip::Diag, Conj, Mat, MatRef, Parallelism, }; + use assert_approx_eq::assert_approx_eq; use std::cell::RefCell; macro_rules! make_stack { diff --git a/faer-libs/faer-qr/src/no_pivoting/inverse.rs b/src/linalg/qr/no_pivoting/inverse.rs similarity index 89% rename from faer-libs/faer-qr/src/no_pivoting/inverse.rs rename to src/linalg/qr/no_pivoting/inverse.rs index 4ad396af0d7f41df7a26767c9f0e8776b58c0d9d..460fc7ff1a8d1c9584c1812d4c53f4a24aea42f6 100644 --- a/faer-libs/faer-qr/src/no_pivoting/inverse.rs +++ b/src/linalg/qr/no_pivoting/inverse.rs @@ -1,10 +1,12 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - householder::apply_block_householder_sequence_transpose_on_the_right_in_place_with_conj, - inverse::invert_upper_triangular, temp_mat_req, temp_mat_uninit, unzipped, zipped, - ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, + linalg::{ + householder::apply_block_householder_sequence_transpose_on_the_right_in_place_with_conj, + temp_mat_req, temp_mat_uninit, triangular_inverse::invert_upper_triangular, + }, + unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the inverse of a matrix, given its QR decomposition, @@ -41,7 +43,7 @@ pub fn invert( // zero bottom part zipped!(dst.rb_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { dst.write(E::faer_zero()) }); @@ -116,9 +118,16 @@ pub fn invert_in_place_req( #[cfg(test)] mod tests { use super::*; - use crate::no_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}; + use crate::{ + assert, + complex_native::c64, + linalg::{ + matmul::matmul, + qr::no_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}, + }, + Mat, + }; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c64, mul::matmul, Mat}; use rand::prelude::*; use std::cell::RefCell; @@ -153,7 +162,7 @@ mod tests { let mut qr = mat.clone(); let mut householder_factor = Mat::zeros(blocksize, n); - let parallelism = faer_core::Parallelism::Rayon(0); + let parallelism = crate::Parallelism::Rayon(0); qr_in_place( qr.as_mut(), diff --git a/faer-libs/faer-qr/src/no_pivoting/mod.rs b/src/linalg/qr/no_pivoting/mod.rs similarity index 55% rename from faer-libs/faer-qr/src/no_pivoting/mod.rs rename to src/linalg/qr/no_pivoting/mod.rs index 3b249818bc230d42b9f381ba1bab38756f5c746f..0cd8d353cf0f6945fc9bcf2bd30ff3df3094afcf 100644 --- a/faer-libs/faer-qr/src/no_pivoting/mod.rs +++ b/src/linalg/qr/no_pivoting/mod.rs @@ -3,7 +3,11 @@ //! where $Q$ is a unitary matrix (represented as a block Householder sequence), and $R$ is an upper //! trapezoidal matrix. +/// Computing the decomposition. pub mod compute; +/// Reconstructing the inverse of the original matrix from the decomposition. pub mod inverse; +/// Reconstructing the original matrix from the decomposition. pub mod reconstruct; +/// Solving a linear system usin the decomposition. pub mod solve; diff --git a/faer-libs/faer-qr/src/no_pivoting/reconstruct.rs b/src/linalg/qr/no_pivoting/reconstruct.rs similarity index 87% rename from faer-libs/faer-qr/src/no_pivoting/reconstruct.rs rename to src/linalg/qr/no_pivoting/reconstruct.rs index 2aa0f867ce3773040a9fa8b952b1c22f3639eeb1..06d371792af1d2776ae5b52db151507fa45e954d 100644 --- a/faer-libs/faer-qr/src/no_pivoting/reconstruct.rs +++ b/src/linalg/qr/no_pivoting/reconstruct.rs @@ -1,9 +1,12 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, householder::apply_block_householder_sequence_on_the_left_in_place_with_conj, - temp_mat_req, temp_mat_uninit, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, - Parallelism, +use crate::{ + assert, + linalg::{ + householder::apply_block_householder_sequence_on_the_left_in_place_with_conj, temp_mat_req, + temp_mat_uninit, + }, + unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the reconstructed matrix, given its QR decomposition, and stores the @@ -31,14 +34,14 @@ pub fn reconstruct( let mut dst = dst; // copy R - zipped!(dst.rb_mut(), qr_factors) - .for_each_triangular_upper(faer_core::zip::Diag::Include, |unzipped!(mut dst, src)| { - dst.write(src.read()) - }); + zipped!(dst.rb_mut(), qr_factors).for_each_triangular_upper( + crate::linalg::zip::Diag::Include, + |unzipped!(mut dst, src)| dst.write(src.read()), + ); // zero bottom part zipped!(dst.rb_mut()) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst)| { + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { dst.write(E::faer_zero()) }); @@ -112,9 +115,13 @@ pub fn reconstruct_in_place_req( #[cfg(test)] mod tests { use super::*; - use crate::no_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}; + use crate::{ + assert, + complex_native::c64, + linalg::qr::no_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}, + Mat, + }; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c64, Mat}; use rand::prelude::*; use std::cell::RefCell; @@ -149,7 +156,7 @@ mod tests { let mut qr = mat.clone(); let mut householder_factor = Mat::zeros(blocksize, n); - let parallelism = faer_core::Parallelism::Rayon(0); + let parallelism = crate::Parallelism::Rayon(0); qr_in_place( qr.as_mut(), diff --git a/faer-libs/faer-qr/src/no_pivoting/solve.rs b/src/linalg/qr/no_pivoting/solve.rs similarity index 95% rename from faer-libs/faer-qr/src/no_pivoting/solve.rs rename to src/linalg/qr/no_pivoting/solve.rs index fa8c29f9d7516637b9559d1113edb8562ad35be4..a9a2dca4f26b144f3189c9a5630ede8cc97b87b1 100644 --- a/faer-libs/faer-qr/src/no_pivoting/solve.rs +++ b/src/linalg/qr/no_pivoting/solve.rs @@ -1,12 +1,15 @@ -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - householder::{ - apply_block_householder_sequence_on_the_left_in_place_with_conj, - apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, + linalg::{ + householder::{ + apply_block_householder_sequence_on_the_left_in_place_with_conj, + apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, + }, + temp_mat_req, triangular_solve as solve, }, - solve, temp_mat_req, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, + unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, }; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use reborrow::*; /// Computes the size and alignment of required workspace for solving a linear system defined by a @@ -253,8 +256,15 @@ pub fn solve_transpose( #[cfg(test)] mod tests { use super::*; - use crate::no_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}; - use faer_core::{assert, c32, c64, mul::matmul_with_conj, Mat}; + use crate::{ + assert, + complex_native::{c32, c64}, + linalg::{ + matmul::matmul_with_conj, + qr::no_pivoting::compute::{qr_in_place, qr_in_place_req, recommended_blocksize}, + }, + Mat, + }; use rand::random; macro_rules! make_stack { diff --git a/src/linalg/reductions/mod.rs b/src/linalg/reductions/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..4729566083c1fe0f61e9d8dc32c9e375058e006f --- /dev/null +++ b/src/linalg/reductions/mod.rs @@ -0,0 +1,6 @@ +const LINEAR_IMPL_THRESHOLD: usize = 128; + +pub mod norm_l1; +pub mod norm_l2; +pub mod norm_max; +pub mod sum; diff --git a/src/linalg/reductions/norm_l1.rs b/src/linalg/reductions/norm_l1.rs new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/src/linalg/reductions/norm_l1.rs @@ -0,0 +1 @@ + diff --git a/src/linalg/reductions/norm_l2.rs b/src/linalg/reductions/norm_l2.rs new file mode 100644 index 0000000000000000000000000000000000000000..87187056c7b34944f0a10f7e9e609b8ce6c83016 --- /dev/null +++ b/src/linalg/reductions/norm_l2.rs @@ -0,0 +1,344 @@ +use super::LINEAR_IMPL_THRESHOLD; +use crate::{ + complex_native::*, + mat::MatRef, + utils::{simd::*, slice::*}, +}; +use faer_entity::*; +use pulp::{Read, Simd}; + +#[inline(always)] +fn norm_l2_with_simd_and_offset_prologue( + simd: S, + data: SliceGroup<'_, E>, + offset: pulp::Offset>, +) -> ( + SimdGroupFor, + SimdGroupFor, + SimdGroupFor, +) { + let simd_real = SimdFor::::new(simd); + let simd = SimdFor::::new(simd); + let half_big = simd_real.splat(E::Real::faer_min_positive_sqrt_inv()); + let half_small = simd_real.splat(E::Real::faer_min_positive_sqrt()); + let zero = simd.splat(E::faer_zero()); + let zero_real = simd_real.splat(E::Real::faer_zero()); + + let (head, body, tail) = simd.as_aligned_simd(data, offset); + let (body2, body1) = body.as_arrays::<2>(); + + let mut acc0 = simd.abs2(head.read_or(zero)); + let mut acc1 = zero_real; + + let mut acc_small0 = simd.abs2(simd.scale_real(half_small, head.read_or(zero))); + let mut acc_small1 = zero_real; + + let mut acc_big0 = simd.abs2(simd.scale_real(half_big, head.read_or(zero))); + let mut acc_big1 = zero_real; + + for [x0, x1] in body2.into_ref_iter().map(RefGroup::unzip) { + let x0 = x0.get(); + let x1 = x1.get(); + acc0 = simd.abs2_add_e(x0, acc0); + acc1 = simd.abs2_add_e(x1, acc1); + + acc_small0 = simd.abs2_add_e(simd.scale_real(half_small, x0), acc_small0); + acc_small1 = simd.abs2_add_e(simd.scale_real(half_small, x1), acc_small1); + + acc_big0 = simd.abs2_add_e(simd.scale_real(half_big, x0), acc_big0); + acc_big1 = simd.abs2_add_e(simd.scale_real(half_big, x1), acc_big1); + } + + for x0 in body1.into_ref_iter() { + let x0 = x0.get(); + acc0 = simd.abs2_add_e(x0, acc0); + acc_small0 = simd.abs2_add_e(simd.scale_real(half_small, x0), acc_small0); + acc_big0 = simd.abs2_add_e(simd.scale_real(half_big, x0), acc_big0); + } + + acc0 = simd.abs2_add_e(tail.read_or(zero), acc0); + acc_small0 = simd.abs2_add_e(simd.scale_real(half_small, tail.read_or(zero)), acc_small0); + acc_big0 = simd.abs2_add_e(simd.scale_real(half_big, tail.read_or(zero)), acc_big0); + + acc0 = simd_real.add(acc0, acc1); + acc_small0 = simd_real.add(acc_small0, acc_small1); + acc_big0 = simd_real.add(acc_big0, acc_big1); + + (acc_small0, acc0, acc_big0) +} + +#[inline(always)] +fn norm_l2_with_simd_and_offset_pairwise_rows( + simd: S, + data: SliceGroup<'_, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, +) -> ( + SimdGroupFor, + SimdGroupFor, + SimdGroupFor, +) { + struct Impl<'a, E: ComplexField, S: Simd> { + simd: S, + data: SliceGroup<'a, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, + } + + impl pulp::NullaryFnOnce for Impl<'_, E, S> { + type Output = ( + SimdGroupFor, + SimdGroupFor, + SimdGroupFor, + ); + + #[inline(always)] + fn call(self) -> Self::Output { + let Self { + simd, + data, + offset, + last_offset, + } = self; + + if data.len() == LINEAR_IMPL_THRESHOLD { + norm_l2_with_simd_and_offset_prologue(simd, data, offset) + } else if data.len() < LINEAR_IMPL_THRESHOLD { + norm_l2_with_simd_and_offset_prologue(simd, data, last_offset) + } else { + let split_point = ((data.len() + 1) / 2).next_power_of_two(); + let (head, tail) = data.split_at(split_point); + let (acc_small0, acc0, acc_big0) = + norm_l2_with_simd_and_offset_pairwise_rows(simd, head, offset, last_offset); + let (acc_small1, acc1, acc_big1) = + norm_l2_with_simd_and_offset_pairwise_rows(simd, tail, offset, last_offset); + + let simd = SimdFor::::new(simd); + ( + simd.add(acc_small0, acc_small1), + simd.add(acc0, acc1), + simd.add(acc_big0, acc_big1), + ) + } + } + } + + simd.vectorize(Impl { + simd, + data, + offset, + last_offset, + }) +} + +#[inline(always)] +fn norm_l2_with_simd_and_offset_pairwise_cols( + simd: S, + data: MatRef<'_, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, +) -> ( + SimdGroupFor, + SimdGroupFor, + SimdGroupFor, +) { + struct Impl<'a, E: ComplexField, S: Simd> { + simd: S, + data: MatRef<'a, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, + } + + impl pulp::NullaryFnOnce for Impl<'_, E, S> { + type Output = ( + SimdGroupFor, + SimdGroupFor, + SimdGroupFor, + ); + + #[inline(always)] + fn call(self) -> Self::Output { + let Self { + simd, + data, + offset, + last_offset, + } = self; + if data.ncols() == 1 { + norm_l2_with_simd_and_offset_pairwise_rows( + simd, + SliceGroup::<'_, E>::new(data.try_get_contiguous_col(0)), + offset, + last_offset, + ) + } else { + let split_point = (data.ncols() / 2).next_power_of_two(); + + let (head, tail) = data.split_at_col(split_point); + + let (acc_small0, acc0, acc_big0) = + norm_l2_with_simd_and_offset_pairwise_cols(simd, head, offset, last_offset); + let (acc_small1, acc1, acc_big1) = + norm_l2_with_simd_and_offset_pairwise_cols(simd, tail, offset, last_offset); + + let simd = SimdFor::::new(simd); + ( + simd.add(acc_small0, acc_small1), + simd.add(acc0, acc1), + simd.add(acc_big0, acc_big1), + ) + } + } + } + + simd.vectorize(Impl { + simd, + data, + offset, + last_offset, + }) +} + +#[inline(always)] + +fn norm_l2_contiguous(data: MatRef<'_, E>) -> (E::Real, E::Real, E::Real) { + struct Impl<'a, E: ComplexField> { + data: MatRef<'a, E>, + } + + impl pulp::WithSimd for Impl<'_, E> { + type Output = (E::Real, E::Real, E::Real); + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let Self { data } = self; + + let offset = + SimdFor::::new(simd).align_offset_ptr(data.as_ptr(), LINEAR_IMPL_THRESHOLD); + + let last_offset = SimdFor::::new(simd) + .align_offset_ptr(data.as_ptr(), data.nrows() % LINEAR_IMPL_THRESHOLD); + + let (acc_small, acc, acc_big) = + norm_l2_with_simd_and_offset_pairwise_cols(simd, data, offset, last_offset); + + let simd = SimdFor::::new(simd); + ( + simd.reduce_add(simd.rotate_left(acc_small, offset.rotate_left_amount())), + simd.reduce_add(simd.rotate_left(acc, offset.rotate_left_amount())), + simd.reduce_add(simd.rotate_left(acc_big, offset.rotate_left_amount())), + ) + } + } + + E::Simd::default().dispatch(Impl { data }) +} + +pub fn norm_l2(mut mat: MatRef<'_, E>) -> E::Real { + if mat.ncols() > 1 && mat.col_stride().unsigned_abs() < mat.row_stride().unsigned_abs() { + mat = mat.transpose(); + } + if mat.row_stride() < 0 { + mat = mat.reverse_rows(); + } + + if mat.nrows() == 0 || mat.ncols() == 0 { + E::Real::faer_zero() + } else { + let m = mat.nrows(); + let n = mat.ncols(); + + let half_small = E::Real::faer_min_positive_sqrt(); + let half_big = E::Real::faer_min_positive_sqrt_inv(); + + let mut acc_small = E::Real::faer_zero(); + let mut acc = E::Real::faer_zero(); + let mut acc_big = E::Real::faer_zero(); + + if mat.row_stride() == 1 { + if coe::is_same::() { + let mat: MatRef<'_, c32> = coe::coerce(mat); + let mat = unsafe { + crate::mat::from_raw_parts( + mat.as_ptr() as *const f32, + 2 * mat.nrows(), + mat.ncols(), + 1, + mat.col_stride().wrapping_mul(2), + ) + }; + let (acc_small_, acc_, acc_big_) = norm_l2_contiguous::(mat); + acc_small = coe::coerce_static(acc_small_); + acc = coe::coerce_static(acc_); + acc_big = coe::coerce_static(acc_big_); + } else if coe::is_same::() { + let mat: MatRef<'_, c64> = coe::coerce(mat); + let mat = unsafe { + crate::mat::from_raw_parts( + mat.as_ptr() as *const f64, + 2 * mat.nrows(), + mat.ncols(), + 1, + mat.col_stride().wrapping_mul(2), + ) + }; + let (acc_small_, acc_, acc_big_) = norm_l2_contiguous::(mat); + acc_small = coe::coerce_static(acc_small_); + acc = coe::coerce_static(acc_); + acc_big = coe::coerce_static(acc_big_); + } else { + (acc_small, acc, acc_big) = norm_l2_contiguous(mat); + } + } else { + for j in 0..n { + for i in 0..m { + let val = mat.read(i, j); + let val_small = val.faer_scale_power_of_two(half_small); + let val_big = val.faer_scale_power_of_two(half_big); + + acc_small = acc_small.faer_add(val_small.faer_abs2()); + acc = acc.faer_add(val.faer_abs2()); + acc_big = acc_big.faer_add(val_big.faer_abs2()); + } + } + } + + if acc_small >= E::Real::faer_one() { + acc_small.faer_sqrt().faer_mul(half_big) + } else if acc_big <= E::Real::faer_one() { + acc_big.faer_sqrt().faer_mul(half_small) + } else { + acc.faer_sqrt() + } + } +} + +#[cfg(test)] +mod tests { + use crate::{assert, prelude::*, unzipped, zipped}; + + #[test] + fn test_norm_l2() { + let relative_err = |a: f64, b: f64| (a - b).abs() / f64::max(a.abs(), b.abs()); + + for (m, n) in [(9, 10), (1023, 5), (42, 1)] { + for factor in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] { + let mat = Mat::from_fn(m, n, |i, j| factor * ((i + j) as f64)); + let mut target = 0.0; + zipped!(mat.as_ref()).for_each(|unzipped!(x)| { + target = f64::hypot(*x, target); + }); + + if factor == 0.0 { + assert!(mat.norm_l2() == target); + } else { + assert!(relative_err(mat.norm_l2(), target) < 1e-14); + } + } + } + + let mat = Col::from_fn(10000000, |_| 0.3); + let target = (0.3 * 0.3 * 10000000.0f64).sqrt(); + assert!(relative_err(mat.norm_l2(), target) < 1e-14); + } +} diff --git a/src/linalg/reductions/norm_max.rs b/src/linalg/reductions/norm_max.rs new file mode 100644 index 0000000000000000000000000000000000000000..37e0b45ed18a9718d93030c5b8d84452e899c327 --- /dev/null +++ b/src/linalg/reductions/norm_max.rs @@ -0,0 +1,149 @@ +use crate::{ + complex_native::*, + mat::MatRef, + utils::{simd::*, slice::*}, +}; +use faer_entity::*; +use pulp::Read; + +#[inline(always)] +fn norm_max_contiguous(data: MatRef<'_, E>) -> E { + struct Impl<'a, E: RealField> { + data: MatRef<'a, E>, + } + + impl pulp::WithSimd for Impl<'_, E> { + type Output = E; + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let Self { data } = self; + let m = data.nrows(); + let n = data.ncols(); + + let offset = SimdFor::::new(simd).align_offset_ptr(data.as_ptr(), m); + + let simd = SimdFor::::new(simd); + + let zero = simd.splat(E::faer_zero()); + + let mut acc0 = zero; + let mut acc1 = zero; + let mut acc2 = zero; + let mut acc3 = zero; + for j in 0..n { + let col = SliceGroup::<'_, E>::new(data.try_get_contiguous_col(j)); + let (head, body, tail) = simd.as_aligned_simd(col, offset); + let (body4, body1) = body.as_arrays::<4>(); + + let head = simd.abs(head.read_or(zero)); + acc0 = simd.select(simd.greater_than(head, acc0), head, acc0); + + for [x0, x1, x2, x3] in body4.into_ref_iter().map(RefGroup::unzip) { + let x0 = simd.abs(x0.get()); + let x1 = simd.abs(x1.get()); + let x2 = simd.abs(x2.get()); + let x3 = simd.abs(x3.get()); + acc0 = simd.select(simd.greater_than(x0, acc0), x0, acc0); + acc1 = simd.select(simd.greater_than(x1, acc1), x1, acc1); + acc2 = simd.select(simd.greater_than(x2, acc2), x2, acc2); + acc3 = simd.select(simd.greater_than(x3, acc3), x3, acc3); + } + + for x0 in body1.into_ref_iter() { + let x0 = simd.abs(x0.get()); + acc0 = simd.select(simd.greater_than(x0, acc0), x0, acc0); + } + + let tail = simd.abs(tail.read_or(zero)); + acc3 = simd.select(simd.greater_than(tail, acc3), tail, acc3); + } + acc0 = simd.select(simd.greater_than(acc0, acc1), acc0, acc1); + acc2 = simd.select(simd.greater_than(acc2, acc3), acc2, acc3); + acc0 = simd.select(simd.greater_than(acc0, acc2), acc0, acc2); + + let acc0 = from_copy::(simd.rotate_left(acc0, offset.rotate_left_amount())); + let acc = SliceGroup::<'_, E>::new(E::faer_map( + E::faer_as_ref(&acc0), + #[inline(always)] + |acc| bytemuck::cast_slice::<_, ::Unit>(core::slice::from_ref(acc)), + )); + let mut acc_scalar = E::faer_zero(); + for x in acc.into_ref_iter() { + let x = x.read(); + acc_scalar = if acc_scalar > x { acc_scalar } else { x }; + } + acc_scalar + } + } + + E::Simd::default().dispatch(Impl { data }) +} + +pub fn norm_max(mut mat: MatRef<'_, E>) -> E::Real { + if mat.ncols() > 1 && mat.col_stride().unsigned_abs() < mat.row_stride().unsigned_abs() { + mat = mat.transpose(); + } + if mat.row_stride() < 0 { + mat = mat.reverse_rows(); + } + + if mat.nrows() == 0 || mat.ncols() == 0 { + E::Real::faer_zero() + } else { + let m = mat.nrows(); + let n = mat.ncols(); + + if mat.row_stride() == 1 { + if coe::is_same::() { + let mat: MatRef<'_, c32> = coe::coerce(mat); + let mat = unsafe { + crate::mat::from_raw_parts( + mat.as_ptr() as *const f32, + 2 * mat.nrows(), + mat.ncols(), + 1, + 2 * mat.col_stride(), + ) + }; + return coe::coerce_static(norm_max_contiguous::(mat)); + } + if coe::is_same::() { + let mat: MatRef<'_, c64> = coe::coerce(mat); + let mat = unsafe { + crate::mat::from_raw_parts( + mat.as_ptr() as *const f64, + 2 * mat.nrows(), + mat.ncols(), + 1, + 2 * mat.col_stride(), + ) + }; + return coe::coerce_static(norm_max_contiguous::(mat)); + } + if coe::is_same::>() { + let mat: MatRef<'_, num_complex::Complex> = coe::coerce(mat); + let num_complex::Complex { re, im } = mat.real_imag(); + let re = norm_max_contiguous(re); + let im = norm_max_contiguous(im); + return if re > im { re } else { im }; + } + if coe::is_same::() { + let mat: MatRef<'_, E::Real> = coe::coerce(mat); + return norm_max_contiguous(mat); + } + } + + let mut acc = E::Real::faer_zero(); + for j in 0..n { + for i in 0..m { + let val = mat.read(i, j); + let re = val.faer_real(); + let im = val.faer_imag(); + acc = if re > acc { re } else { acc }; + acc = if im > acc { im } else { acc }; + } + } + acc + } +} diff --git a/src/linalg/reductions/sum.rs b/src/linalg/reductions/sum.rs new file mode 100644 index 0000000000000000000000000000000000000000..e6b52be297959a8818dd0d2a8d9c2151670fb297 --- /dev/null +++ b/src/linalg/reductions/sum.rs @@ -0,0 +1,242 @@ +use super::LINEAR_IMPL_THRESHOLD; +use crate::{ + mat::MatRef, + utils::{simd::*, slice::*}, +}; +use faer_entity::*; +use pulp::{Read, Simd}; + +#[inline(always)] +fn sum_with_simd_and_offset_prologue( + simd: S, + data: SliceGroup<'_, E>, + offset: pulp::Offset>, +) -> SimdGroupFor { + let simd = SimdFor::::new(simd); + + let zero = simd.splat(E::faer_zero()); + + let mut acc0 = zero; + let mut acc1 = zero; + let mut acc2 = zero; + let mut acc3 = zero; + let (head, body, tail) = simd.as_aligned_simd(data, offset); + let (body4, body1) = body.as_arrays::<4>(); + let head = head.read_or(zero); + acc0 = simd.add(acc0, head); + + for [x0, x1, x2, x3] in body4.into_ref_iter().map(RefGroup::unzip) { + let x0 = x0.get(); + let x1 = x1.get(); + let x2 = x2.get(); + let x3 = x3.get(); + acc0 = simd.add(acc0, x0); + acc1 = simd.add(acc1, x1); + acc2 = simd.add(acc2, x2); + acc3 = simd.add(acc3, x3); + } + + for x0 in body1.into_ref_iter() { + let x0 = x0.get(); + acc0 = simd.add(acc0, x0); + } + + let tail = tail.read_or(zero); + acc3 = simd.add(acc3, tail); + + acc0 = simd.add(acc0, acc1); + acc2 = simd.add(acc2, acc3); + simd.add(acc0, acc2) +} + +#[inline(always)] +fn sum_with_simd_and_offset_pairwise_rows( + simd: S, + data: SliceGroup<'_, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, +) -> SimdGroupFor { + struct Impl<'a, E: ComplexField, S: Simd> { + simd: S, + data: SliceGroup<'a, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, + } + + impl pulp::NullaryFnOnce for Impl<'_, E, S> { + type Output = SimdGroupFor; + + #[inline(always)] + fn call(self) -> Self::Output { + let Self { + simd, + data, + offset, + last_offset, + } = self; + + if data.len() == LINEAR_IMPL_THRESHOLD { + sum_with_simd_and_offset_prologue(simd, data, offset) + } else if data.len() < LINEAR_IMPL_THRESHOLD { + sum_with_simd_and_offset_prologue(simd, data, last_offset) + } else { + let split_point = ((data.len() + 1) / 2).next_power_of_two(); + let (head, tail) = data.split_at(split_point); + let acc0 = sum_with_simd_and_offset_pairwise_rows(simd, head, offset, last_offset); + let acc1 = sum_with_simd_and_offset_pairwise_rows(simd, tail, offset, last_offset); + + let simd = SimdFor::::new(simd); + simd.add(acc0, acc1) + } + } + } + + simd.vectorize(Impl { + simd, + data, + offset, + last_offset, + }) +} + +#[inline(always)] +fn sum_with_simd_and_offset_pairwise_cols( + simd: S, + data: MatRef<'_, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, +) -> SimdGroupFor { + struct Impl<'a, E: ComplexField, S: Simd> { + simd: S, + data: MatRef<'a, E>, + offset: pulp::Offset>, + last_offset: pulp::Offset>, + } + + impl pulp::NullaryFnOnce for Impl<'_, E, S> { + type Output = SimdGroupFor; + + #[inline(always)] + fn call(self) -> Self::Output { + let Self { + simd, + data, + offset, + last_offset, + } = self; + if data.ncols() == 1 { + sum_with_simd_and_offset_pairwise_rows( + simd, + SliceGroup::<'_, E>::new(data.try_get_contiguous_col(0)), + offset, + last_offset, + ) + } else { + let split_point = (data.ncols() / 2).next_power_of_two(); + + let (head, tail) = data.split_at_col(split_point); + + let acc0 = sum_with_simd_and_offset_pairwise_cols(simd, head, offset, last_offset); + let acc1 = sum_with_simd_and_offset_pairwise_cols(simd, tail, offset, last_offset); + + let simd = SimdFor::::new(simd); + simd.add(acc0, acc1) + } + } + } + + simd.vectorize(Impl { + simd, + data, + offset, + last_offset, + }) +} + +fn sum_contiguous(data: MatRef<'_, E>) -> E { + struct Impl<'a, E: ComplexField> { + data: MatRef<'a, E>, + } + + impl pulp::WithSimd for Impl<'_, E> { + type Output = E; + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let Self { data } = self; + + let offset = + SimdFor::::new(simd).align_offset_ptr(data.as_ptr(), LINEAR_IMPL_THRESHOLD); + + let last_offset = SimdFor::::new(simd) + .align_offset_ptr(data.as_ptr(), data.nrows() % LINEAR_IMPL_THRESHOLD); + + let acc = sum_with_simd_and_offset_pairwise_cols(simd, data, offset, last_offset); + + let simd = SimdFor::::new(simd); + simd.reduce_add(simd.rotate_left(acc, offset.rotate_left_amount())) + } + } + + E::Simd::default().dispatch(Impl { data }) +} + +pub fn sum(mut mat: MatRef<'_, E>) -> E { + if mat.ncols() > 1 && mat.col_stride().unsigned_abs() < mat.row_stride().unsigned_abs() { + mat = mat.transpose(); + } + if mat.row_stride() < 0 { + mat = mat.reverse_rows(); + } + + if mat.nrows() == 0 || mat.ncols() == 0 { + E::faer_zero() + } else { + let m = mat.nrows(); + let n = mat.ncols(); + + let mut acc = E::faer_zero(); + + if mat.row_stride() == 1 { + acc = sum_contiguous(mat); + } else { + for j in 0..n { + for i in 0..m { + acc = acc.faer_add(mat.read(i, j)); + } + } + } + + acc + } +} + +#[cfg(test)] +mod tests { + use crate::{assert, prelude::*, unzipped, zipped}; + + #[test] + fn test_sum() { + let relative_err = |a: f64, b: f64| (a - b).abs() / f64::max(a.abs(), b.abs()); + + for (m, n) in [(9, 10), (1023, 5), (42, 1)] { + for factor in [0.0, 1.0, 1e30, 1e250, 1e-30, 1e-250] { + let mat = Mat::from_fn(m, n, |i, j| factor * ((i + j) as f64)); + let mut target = 0.0; + zipped!(mat.as_ref()).for_each(|unzipped!(x)| { + target += *x; + }); + + if factor == 0.0 { + assert!(mat.sum() == target); + } else { + assert!(relative_err(mat.sum(), target) < 1e-14); + } + } + } + + let mat = Col::from_fn(10000000, |_| 0.3); + let target = 0.3 * 10000000.0f64; + assert!(relative_err(mat.sum(), target) < 1e-14); + } +} diff --git a/src/linalg/solvers.rs b/src/linalg/solvers.rs new file mode 100644 index 0000000000000000000000000000000000000000..c98b9273d834030f20b8f8d6d68ca8a5381ba1ae --- /dev/null +++ b/src/linalg/solvers.rs @@ -0,0 +1,2918 @@ +use crate::{ + assert, col::*, diag::DiagRef, linalg::matmul::triangular::BlockStructure, mat::*, + perm::PermRef, unzipped, zipped, Side, *, +}; +use dyn_stack::*; +use reborrow::*; + +pub use crate::{ + linalg::cholesky::llt::CholeskyError, + sparse::linalg::solvers::{SpSolver, SpSolverCore, SpSolverLstsq, SpSolverLstsqCore}, +}; + +/// Object-safe base for [`Solver`] +pub trait SolverCore: SpSolverCore { + /// Reconstructs the original matrix using the decomposition. + fn reconstruct(&self) -> Mat; + /// Computes the inverse of the original matrix using the decomposition. + /// + /// # Panics + /// Panics if the matrix is not square. + fn inverse(&self) -> Mat; +} +/// Object-safe base for [`SolverLstsq`] +pub trait SolverLstsqCore: SolverCore + SpSolverLstsqCore {} + +/// Solver that can compute solution of a linear system. +pub trait Solver: SolverCore + SpSolver {} +/// Dense solver that can compute the least squares solution of an overdetermined linear system. +pub trait SolverLstsq: SolverLstsqCore + SpSolverLstsq {} + +const _: () = { + fn __assert_object_safe() { + let _: Option<&dyn SolverCore> = None; + let _: Option<&dyn SolverLstsqCore> = None; + } +}; + +impl> SolverLstsq for Dec {} + +impl> Solver for Dec {} + +/// Cholesky decomposition. +pub struct Cholesky { + factors: Mat, +} + +/// Bunch-Kaufman decomposition. +pub struct Lblt { + factors: Mat, + subdiag: Mat, + perm: Vec, + perm_inv: Vec, +} + +/// LU decomposition with partial pivoting. +pub struct PartialPivLu { + pub(crate) factors: Mat, + row_perm: Vec, + row_perm_inv: Vec, + n_transpositions: usize, +} +/// LU decomposition with full pivoting. +pub struct FullPivLu { + factors: Mat, + row_perm: Vec, + row_perm_inv: Vec, + col_perm: Vec, + col_perm_inv: Vec, + n_transpositions: usize, +} + +/// QR decomposition. +pub struct Qr { + factors: Mat, + householder: Mat, +} +/// QR decomposition with column pivoting. +pub struct ColPivQr { + factors: Mat, + householder: Mat, + col_perm: Vec, + col_perm_inv: Vec, +} + +/// Singular value decomposition. +pub struct Svd { + s: Mat, + u: Mat, + v: Mat, +} +/// Thin singular value decomposition. +pub struct ThinSvd { + inner: Svd, +} + +/// Self-adjoint eigendecomposition. +pub struct SelfAdjointEigendecomposition { + s: Mat, + u: Mat, +} + +/// Complex eigendecomposition. +pub struct Eigendecomposition { + s: Col, + u: Mat, +} + +impl Cholesky { + /// Returns the Cholesky factorization of the input + /// matrix, or an error if the matrix is not positive definite. + /// + /// The factorization is such that $A = LL^H$, where $L$ is lower triangular. + /// + /// The matrix is interpreted as Hermitian, but only the provided side is accessed. + #[track_caller] + pub fn try_new>( + matrix: MatRef<'_, ViewE>, + side: Side, + ) -> Result { + assert!(matrix.nrows() == matrix.ncols()); + + let dim = matrix.nrows(); + let parallelism = get_global_parallelism(); + + let mut factors = Mat::::zeros(dim, dim); + match side { + Side::Lower => { + zipped!(factors.as_mut(), matrix).for_each_triangular_lower( + crate::linalg::zip::Diag::Include, + |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), + ); + } + Side::Upper => { + zipped!(factors.as_mut(), matrix.adjoint()).for_each_triangular_lower( + crate::linalg::zip::Diag::Include, + |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), + ); + } + } + + let params = Default::default(); + + crate::linalg::cholesky::llt::compute::cholesky_in_place( + factors.as_mut(), + Default::default(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::cholesky::llt::compute::cholesky_in_place_req::( + dim, + parallelism, + params, + ) + .unwrap(), + )), + params, + )?; + Ok(Self { factors }) + } + + fn dim(&self) -> usize { + self.factors.nrows() + } + + /// Returns the factor $L$ of the Cholesky decomposition. + pub fn compute_l(&self) -> Mat { + let mut factor = self.factors.to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_upper(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + } +} +impl SpSolverCore for Cholesky { + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::cholesky::llt::solve::solve_in_place_with_conj( + self.factors.as_ref(), + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::cholesky::llt::solve::solve_in_place_req::( + self.dim(), + rhs_ncols, + parallelism, + ) + .unwrap(), + )), + ); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + self.solve_in_place_with_conj_impl(rhs, conj.compose(Conj::Yes)) + } + + fn nrows(&self) -> usize { + self.factors.nrows() + } + + fn ncols(&self) -> usize { + self.factors.ncols() + } +} +impl SolverCore for Cholesky { + fn inverse(&self) -> Mat { + let mut inv = Mat::::zeros(self.dim(), self.dim()); + let parallelism = get_global_parallelism(); + + crate::linalg::cholesky::llt::inverse::invert_lower( + inv.as_mut(), + self.factors.as_ref(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::cholesky::llt::inverse::invert_lower_req::( + self.dim(), + parallelism, + ) + .unwrap(), + )), + ); + + for j in 0..self.dim() { + for i in 0..j { + inv.write(i, j, inv.read(j, i).faer_conj()); + } + } + + inv + } + + fn reconstruct(&self) -> Mat { + let mut rec = Mat::::zeros(self.dim(), self.dim()); + let parallelism = get_global_parallelism(); + + crate::linalg::cholesky::llt::reconstruct::reconstruct_lower( + rec.as_mut(), + self.factors.as_ref(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::cholesky::llt::reconstruct::reconstruct_lower_req::(self.dim()) + .unwrap(), + )), + ); + + for j in 0..self.dim() { + for i in 0..j { + rec.write(i, j, rec.read(j, i).faer_conj()); + } + } + + rec + } +} + +impl Lblt { + /// Returns the Bunch-Kaufman factorization of the input matrix. + /// + /// The matrix is interpreted as Hermitian, but only the provided side is accessed. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>, side: Side) -> Self { + assert!(matrix.nrows() == matrix.ncols()); + + let dim = matrix.nrows(); + let parallelism = get_global_parallelism(); + + let mut factors = Mat::::zeros(dim, dim); + let mut subdiag = Mat::::zeros(dim, 1); + let mut perm = vec![0; dim]; + let mut perm_inv = vec![0; dim]; + + match side { + Side::Lower => { + zipped!(factors.as_mut(), matrix).for_each_triangular_lower( + crate::linalg::zip::Diag::Include, + |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), + ); + } + Side::Upper => { + zipped!(factors.as_mut(), matrix.adjoint()).for_each_triangular_lower( + crate::linalg::zip::Diag::Include, + |unzipped!(mut dst, src)| dst.write(src.read().canonicalize()), + ); + } + } + + let params = Default::default(); + + crate::linalg::cholesky::bunch_kaufman::compute::cholesky_in_place( + factors.as_mut(), + subdiag.as_mut(), + Default::default(), + &mut perm, + &mut perm_inv, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::cholesky::bunch_kaufman::compute::cholesky_in_place_req::( + dim, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + Self { + factors, + subdiag, + perm, + perm_inv, + } + } + + fn dim(&self) -> usize { + self.factors.nrows() + } +} + +impl SpSolverCore for Lblt { + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::cholesky::bunch_kaufman::solve::solve_in_place_with_conj( + self.factors.as_ref(), + self.subdiag.as_ref(), + conj, + unsafe { PermRef::new_unchecked(&self.perm, &self.perm_inv) }, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::cholesky::bunch_kaufman::solve::solve_in_place_req::( + self.dim(), + rhs_ncols, + parallelism, + ) + .unwrap(), + )), + ); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + self.solve_in_place_with_conj_impl(rhs, conj.compose(Conj::Yes)) + } + + fn nrows(&self) -> usize { + self.factors.nrows() + } + + fn ncols(&self) -> usize { + self.factors.ncols() + } +} +impl SolverCore for Lblt { + fn inverse(&self) -> Mat { + let n = self.dim(); + let mut inv = Mat::identity(n, n); + self.solve_in_place_with_conj_impl(inv.as_mut(), Conj::No); + inv + } + + fn reconstruct(&self) -> Mat { + let parallelism = get_global_parallelism(); + let n = self.dim(); + let lbl = self.factors.as_ref(); + let subdiag = self.subdiag.as_ref(); + let mut mat = Mat::::identity(n, n); + let mut mat2 = Mat::::identity(n, n); + zipped!(mat.as_mut(), lbl).for_each_triangular_lower( + crate::linalg::zip::Diag::Skip, + |unzipped!(mut dst, src)| dst.write(src.read()), + ); + + let mut j = 0; + while j < n { + if subdiag.read(j, 0) == E::faer_zero() { + let d = lbl.read(j, j).faer_real().faer_inv(); + for i in 0..n { + mat.write(i, j, mat.read(i, j).faer_scale_real(d)); + } + j += 1; + } else { + let akp1k = subdiag.read(j, 0).faer_inv(); + let ak = akp1k.faer_scale_real(lbl.read(j, j).faer_real()); + let akp1 = akp1k + .faer_conj() + .faer_scale_real(lbl.read(j + 1, j + 1).faer_real()); + let denom = ak + .faer_mul(akp1) + .faer_sub(E::faer_one()) + .faer_real() + .faer_inv(); + + for i in 0..n { + let xk = mat.read(i, j).faer_mul(akp1k); + let xkp1 = mat.read(i, j + 1).faer_mul(akp1k.faer_conj()); + + mat.write( + i, + j, + (akp1.faer_mul(xk).faer_sub(xkp1)).faer_scale_real(denom), + ); + mat.write( + i, + j + 1, + (ak.faer_mul(xkp1).faer_sub(xk)).faer_scale_real(denom), + ); + } + j += 2; + } + } + crate::linalg::matmul::triangular::matmul( + mat2.as_mut(), + BlockStructure::TriangularLower, + lbl, + BlockStructure::UnitTriangularLower, + mat.as_ref().adjoint(), + BlockStructure::Rectangular, + None, + E::faer_one(), + parallelism, + ); + + for j in 0..n { + let pj = self.perm_inv[j]; + for i in j..n { + let pi = self.perm_inv[i]; + + mat.write( + i, + j, + if pi >= pj { + mat2.read(pi, pj) + } else { + mat2.read(pj, pi).faer_conj() + }, + ); + } + } + + for j in 0..n { + mat.write(j, j, E::faer_from_real(mat.read(j, j).faer_real())); + for i in 0..j { + mat.write(i, j, mat.read(j, i).faer_conj()); + } + } + + mat + } +} + +impl PartialPivLu { + /// Returns the LU decomposition of the input matrix with partial (row) pivoting. + /// + /// The factorization is such that $PA = LU$, where $L$ is lower triangular, $U$ is unit + /// upper triangular, and $P$ is the permutation arising from the pivoting. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { + assert!(matrix.nrows() == matrix.ncols()); + + let dim = matrix.nrows(); + let parallelism = get_global_parallelism(); + + let mut factors = matrix.to_owned(); + + let params = Default::default(); + + let mut row_perm = vec![0usize; dim]; + let mut row_perm_inv = vec![0usize; dim]; + + let (n_transpositions, _) = crate::linalg::lu::partial_pivoting::compute::lu_in_place( + factors.as_mut(), + &mut row_perm, + &mut row_perm_inv, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::partial_pivoting::compute::lu_in_place_req::( + dim, + dim, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + Self { + n_transpositions: n_transpositions.transposition_count, + factors, + row_perm, + row_perm_inv, + } + } + + fn dim(&self) -> usize { + self.factors.nrows() + } + + /// Returns the row permutation due to pivoting. + pub fn row_permutation(&self) -> PermRef<'_, usize> { + unsafe { PermRef::new_unchecked(&self.row_perm, &self.row_perm_inv) } + } + + /// Returns the number of transpositions that consitute the permutation. + pub fn transposition_count(&self) -> usize { + self.n_transpositions + } + + /// Returns the factor $L$ of the LU decomposition. + pub fn compute_l(&self) -> Mat { + let mut factor = self.factors.to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_upper(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + } + /// Returns the factor $U$ of the LU decomposition. + pub fn compute_u(&self) -> Mat { + let mut factor = self.factors.to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + .as_mut() + .diagonal_mut() + .column_vector_mut() + .fill(E::faer_one()); + factor + } +} +impl SpSolverCore for PartialPivLu { + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::lu::partial_pivoting::solve::solve_in_place( + self.factors.as_ref(), + conj, + self.row_permutation(), + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::partial_pivoting::solve::solve_in_place_req::( + self.dim(), + self.dim(), + rhs_ncols, + parallelism, + ) + .unwrap(), + )), + ); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::lu::partial_pivoting::solve::solve_transpose_in_place( + self.factors.as_ref(), + conj, + self.row_permutation(), + rhs, + parallelism, + PodStack::new( + &mut GlobalPodBuffer::new( + crate::linalg::lu::partial_pivoting::solve::solve_transpose_in_place_req::< + usize, + E, + >(self.dim(), self.dim(), rhs_ncols, parallelism) + .unwrap(), + ), + ), + ); + } + + fn nrows(&self) -> usize { + self.factors.nrows() + } + + fn ncols(&self) -> usize { + self.factors.ncols() + } +} +impl SolverCore for PartialPivLu { + fn inverse(&self) -> Mat { + let mut inv = Mat::::zeros(self.dim(), self.dim()); + let parallelism = get_global_parallelism(); + + crate::linalg::lu::partial_pivoting::inverse::invert( + inv.as_mut(), + self.factors.as_ref(), + self.row_permutation(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::partial_pivoting::inverse::invert_req::( + self.dim(), + self.dim(), + parallelism, + ) + .unwrap(), + )), + ); + + inv + } + + fn reconstruct(&self) -> Mat { + let mut rec = Mat::::zeros(self.dim(), self.dim()); + let parallelism = get_global_parallelism(); + + crate::linalg::lu::partial_pivoting::reconstruct::reconstruct( + rec.as_mut(), + self.factors.as_ref(), + self.row_permutation(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::partial_pivoting::reconstruct::reconstruct_req::( + self.dim(), + self.dim(), + parallelism, + ) + .unwrap(), + )), + ); + + rec + } +} + +impl FullPivLu { + /// Returns the LU decomposition of the input matrix with row and column pivoting. + /// + /// The factorization is such that $PAQ^\top = LU$, where $L$ is lower triangular, $U$ is unit + /// upper triangular, and $P$ is the permutation arising from row pivoting and $Q$ is the + /// permutation due to column pivoting. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { + let m = matrix.nrows(); + let n = matrix.ncols(); + let parallelism = get_global_parallelism(); + + let mut factors = matrix.to_owned(); + + let params = Default::default(); + + let mut row_perm = vec![0usize; m]; + let mut row_perm_inv = vec![0usize; m]; + let mut col_perm = vec![0usize; n]; + let mut col_perm_inv = vec![0usize; n]; + + let (n_transpositions, _, _) = crate::linalg::lu::full_pivoting::compute::lu_in_place( + factors.as_mut(), + &mut row_perm, + &mut row_perm_inv, + &mut col_perm, + &mut col_perm_inv, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::full_pivoting::compute::lu_in_place_req::( + m, + n, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + Self { + factors, + row_perm, + row_perm_inv, + col_perm, + col_perm_inv, + n_transpositions: n_transpositions.transposition_count, + } + } + + /// Returns the row permutation due to pivoting. + pub fn row_permutation(&self) -> PermRef<'_, usize> { + unsafe { PermRef::new_unchecked(&self.row_perm, &self.row_perm_inv) } + } + /// Returns the column permutation due to pivoting. + pub fn col_permutation(&self) -> PermRef<'_, usize> { + unsafe { PermRef::new_unchecked(&self.col_perm, &self.col_perm_inv) } + } + + /// Returns the number of transpositions that consitute the two permutations. + pub fn transposition_count(&self) -> usize { + self.n_transpositions + } + + /// Returns the factor $L$ of the LU decomposition. + pub fn compute_l(&self) -> Mat { + let size = Ord::min(self.nrows(), self.ncols()); + let mut factor = self + .factors + .as_ref() + .submatrix(0, 0, self.nrows(), size) + .to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_upper(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + } + /// Returns the factor $U$ of the LU decomposition. + pub fn compute_u(&self) -> Mat { + let size = Ord::min(self.nrows(), self.ncols()); + let mut factor = self + .factors + .as_ref() + .submatrix(0, 0, size, self.ncols()) + .to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + .as_mut() + .diagonal_mut() + .column_vector_mut() + .fill(E::faer_one()); + factor + } +} +impl SpSolverCore for FullPivLu { + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::lu::full_pivoting::solve::solve_in_place( + self.factors.as_ref(), + conj, + self.row_permutation(), + self.col_permutation(), + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::full_pivoting::solve::solve_in_place_req::( + self.nrows(), + self.ncols(), + rhs_ncols, + parallelism, + ) + .unwrap(), + )), + ); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::lu::full_pivoting::solve::solve_transpose_in_place( + self.factors.as_ref(), + conj, + self.row_permutation(), + self.col_permutation(), + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::full_pivoting::solve::solve_transpose_in_place_req::( + self.nrows(), + self.ncols(), + rhs_ncols, + parallelism, + ) + .unwrap(), + )), + ); + } + + fn nrows(&self) -> usize { + self.factors.nrows() + } + + fn ncols(&self) -> usize { + self.factors.ncols() + } +} +impl SolverCore for FullPivLu { + #[track_caller] + fn inverse(&self) -> Mat { + assert!(self.nrows() == self.ncols()); + + let dim = self.nrows(); + + let mut inv = Mat::::zeros(dim, dim); + let parallelism = get_global_parallelism(); + + crate::linalg::lu::full_pivoting::inverse::invert( + inv.as_mut(), + self.factors.as_ref(), + self.row_permutation(), + self.col_permutation(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::full_pivoting::inverse::invert_req::( + dim, + dim, + parallelism, + ) + .unwrap(), + )), + ); + + inv + } + + fn reconstruct(&self) -> Mat { + let mut rec = Mat::::zeros(self.nrows(), self.ncols()); + let parallelism = get_global_parallelism(); + + crate::linalg::lu::full_pivoting::reconstruct::reconstruct( + rec.as_mut(), + self.factors.as_ref(), + self.row_permutation(), + self.col_permutation(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::lu::full_pivoting::reconstruct::reconstruct_req::( + self.nrows(), + self.ncols(), + parallelism, + ) + .unwrap(), + )), + ); + + rec + } +} + +impl Qr { + /// Returns the QR decomposition of the input matrix without pivoting. + /// + /// The factorization is such that $A = QR$, where $R$ is upper trapezoidal and $Q$ is unitary. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { + let parallelism = get_global_parallelism(); + let nrows = matrix.nrows(); + let ncols = matrix.ncols(); + + let mut factors = matrix.to_owned(); + let size = Ord::min(nrows, ncols); + let blocksize = + crate::linalg::qr::no_pivoting::compute::recommended_blocksize::(nrows, ncols); + let mut householder = Mat::::zeros(blocksize, size); + + let params = Default::default(); + + crate::linalg::qr::no_pivoting::compute::qr_in_place( + factors.as_mut(), + householder.as_mut(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::no_pivoting::compute::qr_in_place_req::( + nrows, + ncols, + blocksize, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + Self { + factors, + householder, + } + } + + fn blocksize(&self) -> usize { + self.householder.nrows() + } + + /// Returns the factor $R$ of the QR decomposition. + pub fn compute_r(&self) -> Mat { + let mut factor = self.factors.to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + } + + /// Returns the factor $R$ of the QR decomposition. + pub fn compute_q(&self) -> Mat { + Self::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), false) + } + + /// Returns the top $r$ rows of the factor $R$ of the QR decomposition, where $r = + /// \min(\text{nrows}(A), \text{ncols}(A))$. + pub fn compute_thin_r(&self) -> Mat { + let m = self.nrows(); + let n = self.ncols(); + let mut factor = self.factors.as_ref().subrows(0, Ord::min(m, n)).to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + } + + /// Returns the leftmost $r$ columns of the factor $R$ of the QR decomposition, where $r = + /// \min(\text{nrows}(A), \text{ncols}(A))$. + pub fn compute_thin_q(&self) -> Mat { + Self::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), true) + } + + fn __compute_q_impl(factors: MatRef<'_, E>, householder: MatRef<'_, E>, thin: bool) -> Mat { + let parallelism = get_global_parallelism(); + let m = factors.nrows(); + let size = Ord::min(m, factors.ncols()); + + let mut q = Mat::::zeros(m, if thin { size } else { m }); + q.as_mut() + .diagonal_mut() + .column_vector_mut() + .fill(E::faer_one()); + + crate::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( + factors, + householder, + Conj::No, + q.as_mut(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_req::( + m, + householder.nrows(), + m, + ) + .unwrap(), + )), + ); + + q + } +} +impl SpSolverCore for Qr { + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + self.solve_lstsq_in_place_with_conj_impl(rhs, conj) + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::qr::no_pivoting::solve::solve_transpose_in_place( + self.factors.as_ref(), + self.householder.as_ref(), + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::no_pivoting::solve::solve_transpose_in_place_req::( + self.nrows(), + self.blocksize(), + rhs_ncols, + ) + .unwrap(), + )), + ); + } + + fn nrows(&self) -> usize { + self.factors.nrows() + } + + fn ncols(&self) -> usize { + self.factors.ncols() + } +} +impl SolverCore for Qr { + fn reconstruct(&self) -> Mat { + let mut rec = Mat::::zeros(self.nrows(), self.ncols()); + let parallelism = get_global_parallelism(); + + crate::linalg::qr::no_pivoting::reconstruct::reconstruct( + rec.as_mut(), + self.factors.as_ref(), + self.householder.as_ref(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::no_pivoting::reconstruct::reconstruct_req::( + self.nrows(), + self.ncols(), + self.blocksize(), + parallelism, + ) + .unwrap(), + )), + ); + + rec + } + + fn inverse(&self) -> Mat { + assert!(self.nrows() == self.ncols()); + + let mut inv = Mat::::zeros(self.nrows(), self.ncols()); + let parallelism = get_global_parallelism(); + + crate::linalg::qr::no_pivoting::inverse::invert( + inv.as_mut(), + self.factors.as_ref(), + self.householder.as_ref(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::no_pivoting::inverse::invert_req::( + self.nrows(), + self.ncols(), + self.blocksize(), + parallelism, + ) + .unwrap(), + )), + ); + + inv + } +} + +impl SpSolverLstsqCore for Qr { + #[track_caller] + fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::qr::no_pivoting::solve::solve_in_place( + self.factors.as_ref(), + self.householder.as_ref(), + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::no_pivoting::solve::solve_in_place_req::( + self.nrows(), + self.blocksize(), + rhs_ncols, + ) + .unwrap(), + )), + ); + } +} +impl SolverLstsqCore for Qr {} + +impl ColPivQr { + /// Returns the QR decomposition of the input matrix with column pivoting. + /// + /// The factorization is such that $AP^\top = QR$, where $R$ is upper trapezoidal, $Q$ is + /// unitary, and $P$ is a permutation matrix. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { + let parallelism = get_global_parallelism(); + let nrows = matrix.nrows(); + let ncols = matrix.ncols(); + + let mut factors = matrix.to_owned(); + let size = Ord::min(nrows, ncols); + let blocksize = + crate::linalg::qr::col_pivoting::compute::recommended_blocksize::(nrows, ncols); + let mut householder = Mat::::zeros(blocksize, size); + + let params = Default::default(); + + let mut col_perm = vec![0usize; ncols]; + let mut col_perm_inv = vec![0usize; ncols]; + + crate::linalg::qr::col_pivoting::compute::qr_in_place( + factors.as_mut(), + householder.as_mut(), + &mut col_perm, + &mut col_perm_inv, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::col_pivoting::compute::qr_in_place_req::( + nrows, + ncols, + blocksize, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + Self { + factors, + householder, + col_perm, + col_perm_inv, + } + } + + /// Returns the column permutation matrix $P$ of the QR decomposition. + pub fn col_permutation(&self) -> PermRef<'_, usize> { + unsafe { PermRef::new_unchecked(&self.col_perm, &self.col_perm_inv) } + } + + fn blocksize(&self) -> usize { + self.householder.nrows() + } + + /// Returns the factor $R$ of the QR decomposition. + pub fn compute_r(&self) -> Mat { + let mut factor = self.factors.to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + } + + /// Returns the factor $Q$ of the QR decomposition. + pub fn compute_q(&self) -> Mat { + Qr::::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), false) + } + + /// Returns the top $r$ rows of the factor $R$ of the QR decomposition, where $r = + /// \min(\text{nrows}(A), \text{ncols}(A))$. + pub fn compute_thin_r(&self) -> Mat { + let m = self.nrows(); + let n = self.ncols(); + let mut factor = self.factors.as_ref().subrows(0, Ord::min(m, n)).to_owned(); + zipped!(factor.as_mut()) + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst)| { + dst.write(E::faer_zero()) + }); + factor + } + + /// Returns the leftmost $r$ columns of the factor $R$ of the QR decomposition, where $r = + /// \min(\text{nrows}(A), \text{ncols}(A))$. + pub fn compute_thin_q(&self) -> Mat { + Qr::::__compute_q_impl(self.factors.as_ref(), self.householder.as_ref(), true) + } +} +impl SpSolverCore for ColPivQr { + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + self.solve_lstsq_in_place_with_conj_impl(rhs, conj); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::qr::col_pivoting::solve::solve_transpose_in_place( + self.factors.as_ref(), + self.householder.as_ref(), + self.col_permutation(), + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::col_pivoting::solve::solve_transpose_in_place_req::( + self.nrows(), + self.blocksize(), + rhs_ncols, + ) + .unwrap(), + )), + ); + } + + fn nrows(&self) -> usize { + self.factors.nrows() + } + + fn ncols(&self) -> usize { + self.factors.ncols() + } +} +impl SolverCore for ColPivQr { + fn reconstruct(&self) -> Mat { + let mut rec = Mat::::zeros(self.nrows(), self.ncols()); + let parallelism = get_global_parallelism(); + + crate::linalg::qr::col_pivoting::reconstruct::reconstruct( + rec.as_mut(), + self.factors.as_ref(), + self.householder.as_ref(), + self.col_permutation(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::col_pivoting::reconstruct::reconstruct_req::( + self.nrows(), + self.ncols(), + self.blocksize(), + parallelism, + ) + .unwrap(), + )), + ); + + rec + } + + fn inverse(&self) -> Mat { + assert!(self.nrows() == self.ncols()); + + let mut inv = Mat::::zeros(self.nrows(), self.ncols()); + let parallelism = get_global_parallelism(); + + crate::linalg::qr::col_pivoting::inverse::invert( + inv.as_mut(), + self.factors.as_ref(), + self.householder.as_ref(), + self.col_permutation(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::col_pivoting::inverse::invert_req::( + self.nrows(), + self.ncols(), + self.blocksize(), + parallelism, + ) + .unwrap(), + )), + ); + + inv + } +} + +impl SpSolverLstsqCore for ColPivQr { + #[track_caller] + fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + + crate::linalg::qr::col_pivoting::solve::solve_in_place( + self.factors.as_ref(), + self.householder.as_ref(), + self.col_permutation(), + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::qr::col_pivoting::solve::solve_in_place_req::( + self.nrows(), + self.blocksize(), + rhs_ncols, + ) + .unwrap(), + )), + ); + } +} +impl SolverLstsqCore for ColPivQr {} + +impl Svd { + #[track_caller] + fn __new_impl((matrix, conj): (MatRef<'_, E>, Conj), thin: bool) -> Self { + let parallelism = get_global_parallelism(); + let m = matrix.nrows(); + let n = matrix.ncols(); + let size = Ord::min(m, n); + + let mut s = Mat::::zeros(size, 1); + let mut u = Mat::::zeros(m, if thin { size } else { m }); + let mut v = Mat::::zeros(n, if thin { size } else { n }); + + let params = Default::default(); + + let compute_vecs = if thin { + crate::linalg::svd::ComputeVectors::Thin + } else { + crate::linalg::svd::ComputeVectors::Full + }; + + crate::linalg::svd::compute_svd( + matrix, + s.as_mut(), + Some(u.as_mut()), + Some(v.as_mut()), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::svd::compute_svd_req::( + m, + n, + compute_vecs, + compute_vecs, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + if matches!(conj, Conj::Yes) { + zipped!(u.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); + zipped!(v.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); + } + + Self { s, u, v } + } + + /// Returns the SVD of the input matrix. + /// + /// The factorization is such that $A = U S V^H$, where $U$ and $V$ are unitary and $S$ is a + /// rectangular diagonal matrix. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { + Self::__new_impl(matrix.canonicalize(), false) + } + + /// Returns the factor $U$ of the SVD. + pub fn u(&self) -> MatRef<'_, E> { + self.u.as_ref() + } + /// Returns the diagonal of the factor $S$ of the SVD as a column vector. + pub fn s_diagonal(&self) -> ColRef<'_, E> { + self.s.as_ref().col(0) + } + /// Returns the factor $V$ of the SVD. + pub fn v(&self) -> MatRef<'_, E> { + self.v.as_ref() + } +} +fn div_by_s(rhs: MatMut<'_, E>, s: MatRef<'_, E>) { + let mut rhs = rhs; + for j in 0..rhs.ncols() { + zipped!(rhs.rb_mut().col_mut(j).as_2d_mut(), s).for_each(|unzipped!(mut rhs, s)| { + rhs.write(rhs.read().faer_scale_real(s.read().faer_real().faer_inv())) + }); + } +} +impl SpSolverCore for Svd { + fn nrows(&self) -> usize { + self.u.nrows() + } + + fn ncols(&self) -> usize { + self.v.nrows() + } + + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + let mut rhs = rhs; + + let u = self.u.as_ref(); + let v = self.v.as_ref(); + let s = self.s.as_ref(); + + match conj { + Conj::Yes => { + rhs.copy_from((u.transpose() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((v.conjugate() * rhs.rb()).as_ref()); + } + Conj::No => { + rhs.copy_from((u.adjoint() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((v * rhs.rb()).as_ref()); + } + } + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + let mut rhs = rhs; + + let u = self.u.as_ref(); + let v = self.v.as_ref(); + let s = self.s.as_ref(); + + match conj { + Conj::No => { + rhs.copy_from((v.transpose() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((u.conjugate() * rhs.rb()).as_ref()); + } + Conj::Yes => { + rhs.copy_from((v.adjoint() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((u * rhs.rb()).as_ref()); + } + } + } +} +impl SolverCore for Svd { + fn reconstruct(&self) -> Mat { + let m = self.nrows(); + let n = self.ncols(); + let size = Ord::min(m, n); + + let thin_u = self.u.as_ref().submatrix(0, 0, m, size); + let s = self.s.as_ref(); + let us = Mat::::from_fn(m, size, |i, j| thin_u.read(i, j).faer_mul(s.read(j, 0))); + + us * self.v.adjoint() + } + + fn inverse(&self) -> Mat { + assert!(self.nrows() == self.ncols()); + let dim = self.nrows(); + + let u = self.u.as_ref(); + let v = self.v.as_ref(); + let s = self.s.as_ref(); + + let vs_inv = Mat::::from_fn(dim, dim, |i, j| { + v.read(i, j).faer_mul(s.read(j, 0).faer_inv()) + }); + + vs_inv * u.adjoint() + } +} + +impl ThinSvd { + /// Returns the thin SVD of the input matrix. + /// + /// This is the same as the SVD except that only the leftmost $r$ columns of $U$ and $V$ are + /// computed, where $r = \min(\text{nrows}(A), \text{ncols}(A))$. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>) -> Self { + Self { + inner: Svd::__new_impl(matrix.canonicalize(), true), + } + } + + /// Returns the factor $U$ of the SVD. + pub fn u(&self) -> MatRef<'_, E> { + self.inner.u.as_ref() + } + /// Returns the diagonal of the factor $S$ of the SVD as a column vector. + pub fn s_diagonal(&self) -> ColRef<'_, E> { + self.inner.s.as_ref().col(0) + } + /// Returns the factor $V$ of the SVD. + pub fn v(&self) -> MatRef<'_, E> { + self.inner.v.as_ref() + } +} +impl SpSolverCore for ThinSvd { + fn nrows(&self) -> usize { + self.inner.nrows() + } + + fn ncols(&self) -> usize { + self.inner.ncols() + } + + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + self.inner.solve_in_place_with_conj_impl(rhs, conj) + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + self.inner + .solve_transpose_in_place_with_conj_impl(rhs, conj) + } +} +impl SolverCore for ThinSvd { + fn reconstruct(&self) -> Mat { + self.inner.reconstruct() + } + + fn inverse(&self) -> Mat { + self.inner.inverse() + } +} + +impl SelfAdjointEigendecomposition { + #[track_caller] + fn __new_impl((matrix, conj): (MatRef<'_, E>, Conj), side: Side) -> Self { + assert!(matrix.nrows() == matrix.ncols()); + let parallelism = get_global_parallelism(); + + let dim = matrix.nrows(); + + let mut s = Mat::::zeros(dim, 1); + let mut u = Mat::::zeros(dim, dim); + + let matrix = match side { + Side::Lower => matrix, + Side::Upper => matrix.transpose(), + }; + let conj = conj.compose(match side { + Side::Lower => Conj::No, + Side::Upper => Conj::Yes, + }); + + let params = Default::default(); + crate::linalg::evd::compute_hermitian_evd( + matrix, + s.as_mut(), + Some(u.as_mut()), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::evd::compute_hermitian_evd_req::( + dim, + crate::linalg::evd::ComputeVectors::Yes, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + if matches!(conj, Conj::Yes) { + zipped!(u.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); + } + + Self { s, u } + } + + /// Returns the eigenvalue decomposition of the Hermitian input matrix. + /// + /// The factorization is such that $A = U S U^\H$, where $S$ is a diagonal matrix, and $U$ is + /// unitary. + /// + /// Only the provided side is accessed. + #[track_caller] + pub fn new>(matrix: MatRef<'_, ViewE>, side: Side) -> Self { + Self::__new_impl(matrix.canonicalize(), side) + } + + /// Returns the factor $U$ of the eigenvalue decomposition. + pub fn u(&self) -> MatRef<'_, E> { + self.u.as_ref() + } + /// Returns the factor $S$ of the eigenvalue decomposition. + pub fn s(&self) -> DiagRef<'_, E> { + self.s.as_ref().col(0).column_vector_as_diagonal() + } +} +impl SpSolverCore for SelfAdjointEigendecomposition { + fn nrows(&self) -> usize { + self.u.nrows() + } + + fn ncols(&self) -> usize { + self.u.nrows() + } + + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + let mut rhs = rhs; + + let u = self.u.as_ref(); + let s = self.s.as_ref(); + + match conj { + Conj::Yes => { + rhs.copy_from((u.transpose() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((u.conjugate() * rhs.rb()).as_ref()); + } + Conj::No => { + rhs.copy_from((u.adjoint() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((u * rhs.rb()).as_ref()); + } + } + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + assert!(self.nrows() == self.ncols()); + let mut rhs = rhs; + + let u = self.u.as_ref(); + let s = self.s.as_ref(); + + match conj { + Conj::No => { + rhs.copy_from((u.transpose() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((u.conjugate() * rhs.rb()).as_ref()); + } + Conj::Yes => { + rhs.copy_from((u.adjoint() * rhs.rb()).as_ref()); + div_by_s(rhs.rb_mut(), s); + rhs.copy_from((u * rhs.rb()).as_ref()); + } + } + } +} +impl SolverCore for SelfAdjointEigendecomposition { + fn reconstruct(&self) -> Mat { + let size = self.nrows(); + + let u = self.u.as_ref(); + let s = self.s.as_ref(); + let us = Mat::::from_fn(size, size, |i, j| u.read(i, j).faer_mul(s.read(j, 0))); + + us * u.adjoint() + } + + fn inverse(&self) -> Mat { + let dim = self.nrows(); + + let u = self.u.as_ref(); + let s = self.s.as_ref(); + + let us_inv = Mat::::from_fn(dim, dim, |i, j| { + u.read(i, j).faer_mul(s.read(j, 0).faer_inv()) + }); + + us_inv * u.adjoint() + } +} + +impl Eigendecomposition { + #[track_caller] + pub(crate) fn __values_from_real(matrix: MatRef<'_, E::Real>) -> Vec { + assert!(matrix.nrows() == matrix.ncols()); + if coe::is_same::() { + panic!( + "The type E ({}) must not be real-valued.", + core::any::type_name::(), + ); + } + + let parallelism = get_global_parallelism(); + + let dim = matrix.nrows(); + let mut s_re = Mat::::zeros(dim, 1); + let mut s_im = Mat::::zeros(dim, 1); + + let params = Default::default(); + + crate::linalg::evd::compute_evd_real( + matrix, + s_re.as_mut(), + s_im.as_mut(), + None, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::evd::compute_evd_req::( + dim, + crate::linalg::evd::ComputeVectors::Yes, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + let imag = E::faer_from_f64(-1.0).faer_sqrt(); + let cplx = |re: E::Real, im: E::Real| -> E { + E::faer_from_real(re).faer_add(imag.faer_mul(E::faer_from_real(im))) + }; + + (0..dim) + .map(|i| cplx(s_re.read(i, 0), s_im.read(i, 0))) + .collect() + } + + #[track_caller] + pub(crate) fn __values_from_complex_impl((matrix, conj): (MatRef<'_, E>, Conj)) -> Vec { + assert!(matrix.nrows() == matrix.ncols()); + if coe::is_same::() { + panic!( + "The type E ({}) must not be real-valued.", + core::any::type_name::(), + ); + } + + let parallelism = get_global_parallelism(); + let dim = matrix.nrows(); + + let mut s = Mat::::zeros(dim, 1); + + let params = Default::default(); + + crate::linalg::evd::compute_evd_complex( + matrix, + s.as_mut(), + None, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::evd::compute_evd_req::( + dim, + crate::linalg::evd::ComputeVectors::Yes, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + if matches!(conj, Conj::Yes) { + zipped!(s.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); + } + + (0..dim).map(|i| s.read(i, 0)).collect() + } + + /// Returns the eigendecomposition of the real-valued input matrix. + /// + /// The factorization is such that $A = U S U^\H$, where $S$ is a diagonal matrix, and $U$ is + /// unitary. + #[track_caller] + pub fn new_from_real(matrix: MatRef<'_, E::Real>) -> Self { + assert!(matrix.nrows() == matrix.ncols()); + if coe::is_same::() { + panic!( + "The type E ({}) must not be real-valued.", + core::any::type_name::(), + ); + } + + let parallelism = get_global_parallelism(); + + let dim = matrix.nrows(); + let mut s_re = Col::::zeros(dim); + let mut s_im = Col::::zeros(dim); + let mut u_real = Mat::::zeros(dim, dim); + + let params = Default::default(); + + crate::linalg::evd::compute_evd_real( + matrix, + s_re.as_mut().as_2d_mut(), + s_im.as_mut().as_2d_mut(), + Some(u_real.as_mut()), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::evd::compute_evd_req::( + dim, + crate::linalg::evd::ComputeVectors::Yes, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + let imag = E::faer_from_f64(-1.0).faer_sqrt(); + let cplx = |re: E::Real, im: E::Real| -> E { + E::faer_from_real(re).faer_add(imag.faer_mul(E::faer_from_real(im))) + }; + + let s = Col::::from_fn(dim, |i| cplx(s_re.read(i), s_im.read(i))); + let mut u = Mat::::zeros(dim, dim); + let u_real = u_real.as_ref(); + + let mut j = 0usize; + while j < dim { + if s_im.read(j) == E::Real::faer_zero() { + zipped!(u.as_mut().col_mut(j).as_2d_mut(), u_real.col(j).as_2d()) + .for_each(|unzipped!(mut dst, src)| dst.write(E::faer_from_real(src.read()))); + j += 1; + } else { + let (u_left, u_right) = u.as_mut().split_at_col_mut(j + 1); + + zipped!( + u_left.col_mut(j).as_2d_mut(), + u_right.col_mut(0).as_2d_mut(), + u_real.col(j).as_2d(), + u_real.col(j + 1).as_2d(), + ) + .for_each(|unzipped!(mut dst, mut dst_conj, re, im)| { + let re = re.read(); + let im = im.read(); + dst_conj.write(cplx(re, im.faer_neg())); + dst.write(cplx(re, im)); + }); + + j += 2; + } + } + + Self { s, u } + } + + #[track_caller] + pub(crate) fn __new_from_complex_impl((matrix, conj): (MatRef<'_, E>, Conj)) -> Self { + assert!(matrix.nrows() == matrix.ncols()); + if coe::is_same::() { + panic!( + "The type E ({}) must not be real-valued.", + core::any::type_name::(), + ); + } + + let parallelism = get_global_parallelism(); + let dim = matrix.nrows(); + + let mut s = Col::::zeros(dim); + let mut u = Mat::::zeros(dim, dim); + + let params = Default::default(); + + crate::linalg::evd::compute_evd_complex( + matrix, + s.as_mut().as_2d_mut(), + Some(u.as_mut()), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::evd::compute_evd_req::( + dim, + crate::linalg::evd::ComputeVectors::Yes, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + if matches!(conj, Conj::Yes) { + zipped!(s.as_mut().as_2d_mut()) + .for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); + zipped!(u.as_mut()).for_each(|unzipped!(mut x)| x.write(x.read().faer_conj())); + } + + Self { s, u } + } + + /// Returns the eigendecomposition of the complex-valued input matrix. + /// + /// The factorization is such that $A = U S U^\H$, where $S$ is a diagonal matrix, and $U$ is + /// unitary. + #[track_caller] + pub fn new_from_complex>(matrix: MatRef<'_, ViewE>) -> Self { + Self::__new_from_complex_impl(matrix.canonicalize()) + } + + /// Returns the factor $U$ of the eigenvalue decomposition. + pub fn u(&self) -> MatRef<'_, E> { + self.u.as_ref() + } + /// Returns the factor $S$ of the eigenvalue decomposition. + pub fn s(&self) -> DiagRef<'_, E> { + self.s.as_ref().column_vector_as_diagonal() + } +} + +impl MatRef<'_, E> +where + E::Canonical: ComplexField, +{ + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + #[track_caller] + pub fn solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + let parallelism = get_global_parallelism(); + let mut rhs = rhs; + crate::linalg::triangular_solve::solve_lower_triangular_in_place( + *self, + rhs.as_mat_mut(), + parallelism, + ); + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + #[track_caller] + pub fn solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + let parallelism = get_global_parallelism(); + let mut rhs = rhs; + crate::linalg::triangular_solve::solve_upper_triangular_in_place( + *self, + rhs.as_mat_mut(), + parallelism, + ); + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + let parallelism = get_global_parallelism(); + let mut rhs = rhs; + crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place( + *self, + rhs.as_mat_mut(), + parallelism, + ); + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs` + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + let parallelism = get_global_parallelism(); + let mut rhs = rhs; + crate::linalg::triangular_solve::solve_unit_upper_triangular_in_place( + *self, + rhs.as_mat_mut(), + parallelism, + ); + } + + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + #[track_caller] + pub fn solve_lower_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + let mut rhs = rhs.as_mat_ref().to_owned(); + self.solve_lower_triangular_in_place(rhs.as_mut()); + rhs + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + #[track_caller] + pub fn solve_upper_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + let mut rhs = rhs.as_mat_ref().to_owned(); + self.solve_upper_triangular_in_place(rhs.as_mut()); + rhs + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_lower_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + let mut rhs = rhs.as_mat_ref().to_owned(); + self.solve_unit_lower_triangular_in_place(rhs.as_mut()); + rhs + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_upper_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + let mut rhs = rhs.as_mat_ref().to_owned(); + self.solve_unit_upper_triangular_in_place(rhs.as_mut()); + rhs + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn cholesky(&self, side: Side) -> Result, CholeskyError> { + Cholesky::try_new(self.as_ref(), side) + } + /// Returns the Bunch-Kaufman decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn lblt(&self, side: Side) -> Lblt { + Lblt::new(self.as_ref(), side) + } + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn partial_piv_lu(&self) -> PartialPivLu { + PartialPivLu::::new(self.as_ref()) + } + /// Returns the LU decomposition of `self` with full pivoting. + #[track_caller] + pub fn full_piv_lu(&self) -> FullPivLu { + FullPivLu::::new(self.as_ref()) + } + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn qr(&self) -> Qr { + Qr::::new(self.as_ref()) + } + /// Returns the QR decomposition of `self` with column pivoting. + #[track_caller] + pub fn col_piv_qr(&self) -> ColPivQr { + ColPivQr::::new(self.as_ref()) + } + /// Returns the SVD of `self`. + #[track_caller] + pub fn svd(&self) -> Svd { + Svd::::new(self.as_ref()) + } + /// Returns the thin SVD of `self`. + #[track_caller] + pub fn thin_svd(&self) -> ThinSvd { + ThinSvd::::new(self.as_ref()) + } + /// Returns the eigendecomposition of `self`, assuming it is self-adjoint. Only the provided + /// side is accessed. + #[track_caller] + pub fn selfadjoint_eigendecomposition( + &self, + side: Side, + ) -> SelfAdjointEigendecomposition { + SelfAdjointEigendecomposition::::new(self.as_ref(), side) + } + + /// Returns the eigendecomposition of `self`, as a complex matrix. + #[track_caller] + pub fn eigendecomposition< + ComplexE: ComplexField::Real>, + >( + &self, + ) -> Eigendecomposition { + if coe::is_same::::Real>() { + let matrix: MatRef<'_, ::Real> = + coe::coerce(self.as_ref()); + Eigendecomposition::::new_from_real(matrix) + } else if coe::is_same::() { + let (matrix, conj) = self.as_ref().canonicalize(); + Eigendecomposition::::__new_from_complex_impl((coe::coerce(matrix), conj)) + } else { + panic!( + "The type ComplexE must be either E::Canonical ({}) or E::Canonical::Real ({})", + core::any::type_name::(), + core::any::type_name::<::Real>(), + ); + } + } + + /// Returns the eigendecomposition of `self`, when `E` is in the complex domain. + #[track_caller] + pub fn complex_eigendecomposition(&self) -> Eigendecomposition { + Eigendecomposition::::new_from_complex(self.as_ref()) + } + + /// Returns the determinant of `self`. + #[track_caller] + pub fn determinant(&self) -> E::Canonical { + assert!(self.nrows() == self.ncols()); + let lu = self.partial_piv_lu(); + let mut det = E::Canonical::faer_one(); + for i in 0..self.nrows() { + det = det.faer_mul(lu.factors.read(i, i)); + } + if lu.transposition_count() % 2 == 0 { + det + } else { + det.faer_neg() + } + } + + /// Returns the eigenvalues of `self`, assuming it is self-adjoint. Only the provided + /// side is accessed. The order of the eigenvalues is currently unspecified. + #[track_caller] + pub fn selfadjoint_eigenvalues(&self, side: Side) -> Vec<::Real> { + let matrix = match side { + Side::Lower => *self, + Side::Upper => self.transpose(), + }; + + assert!(matrix.nrows() == matrix.ncols()); + let dim = matrix.nrows(); + let parallelism = get_global_parallelism(); + + let mut s = Mat::::zeros(dim, 1); + let params = Default::default(); + crate::linalg::evd::compute_hermitian_evd( + matrix.canonicalize().0, + s.as_mut(), + None, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::evd::compute_hermitian_evd_req::( + dim, + crate::linalg::evd::ComputeVectors::No, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + (0..dim).map(|i| s.read(i, 0).faer_real()).collect() + } + + /// Returns the singular values of `self`, in nonincreasing order. + #[track_caller] + pub fn singular_values(&self) -> Vec<::Real> { + let dim = Ord::min(self.nrows(), self.ncols()); + let parallelism = get_global_parallelism(); + + let mut s = Mat::::zeros(dim, 1); + let params = Default::default(); + crate::linalg::svd::compute_svd( + self.canonicalize().0, + s.as_mut(), + None, + None, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + crate::linalg::svd::compute_svd_req::( + self.nrows(), + self.ncols(), + crate::linalg::svd::ComputeVectors::No, + crate::linalg::svd::ComputeVectors::No, + parallelism, + params, + ) + .unwrap(), + )), + params, + ); + + (0..dim).map(|i| s.read(i, 0).faer_real()).collect() + } + + /// Returns the eigenvalues of `self`, as complex values. The order of the eigenvalues is + /// currently unspecified. + #[track_caller] + pub fn eigenvalues::Real>>( + &self, + ) -> Vec { + if coe::is_same::::Real>() { + let matrix: MatRef<'_, ::Real> = + coe::coerce(self.as_ref()); + Eigendecomposition::::__values_from_real(matrix) + } else if coe::is_same::() { + let (matrix, conj) = self.as_ref().canonicalize(); + Eigendecomposition::::__values_from_complex_impl((coe::coerce(matrix), conj)) + } else { + panic!( + "The type ComplexE must be either E::Canonical ({}) or E::Canonical::Real ({})", + core::any::type_name::(), + core::any::type_name::<::Real>(), + ); + } + } + + /// Returns the eigenvalues of `self`, when `E` is in the complex domain. The order of the + /// eigenvalues is currently unspecified. + #[track_caller] + pub fn complex_eigenvalues(&self) -> Vec { + Eigendecomposition::::__values_from_complex_impl(self.canonicalize()) + } +} + +impl MatMut<'_, E> +where + E::Canonical: ComplexField, +{ + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + #[track_caller] + pub fn solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_lower_triangular_in_place(rhs) + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + #[track_caller] + pub fn solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_upper_triangular_in_place(rhs) + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_unit_lower_triangular_in_place(rhs) + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs` + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_unit_upper_triangular_in_place(rhs) + } + + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + #[track_caller] + pub fn solve_lower_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_lower_triangular(rhs.as_mat_ref()) + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + #[track_caller] + pub fn solve_upper_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_upper_triangular(rhs.as_mat_ref()) + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_lower_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_unit_lower_triangular(rhs.as_mat_ref()) + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_upper_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_unit_upper_triangular(rhs.as_mat_ref()) + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn cholesky(&self, side: Side) -> Result, CholeskyError> { + self.as_ref().cholesky(side) + } + /// Returns the Bunch-Kaufman decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn lblt(&self, side: Side) -> Lblt { + self.as_ref().lblt(side) + } + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn partial_piv_lu(&self) -> PartialPivLu { + self.as_ref().partial_piv_lu() + } + /// Returns the LU decomposition of `self` with full pivoting. + #[track_caller] + pub fn full_piv_lu(&self) -> FullPivLu { + self.as_ref().full_piv_lu() + } + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn qr(&self) -> Qr { + self.as_ref().qr() + } + /// Returns the QR decomposition of `self` with column pivoting. + #[track_caller] + pub fn col_piv_qr(&self) -> ColPivQr { + self.as_ref().col_piv_qr() + } + /// Returns the SVD of `self`. + #[track_caller] + pub fn svd(&self) -> Svd { + self.as_ref().svd() + } + /// Returns the thin SVD of `self`. + #[track_caller] + pub fn thin_svd(&self) -> ThinSvd { + self.as_ref().thin_svd() + } + /// Returns the eigendecomposition of `self`, assuming it is self-adjoint. Only the provided + /// side is accessed. + #[track_caller] + pub fn selfadjoint_eigendecomposition( + &self, + side: Side, + ) -> SelfAdjointEigendecomposition { + self.as_ref().selfadjoint_eigendecomposition(side) + } + + /// Returns the eigendecomposition of `self`, as a complex matrix. + #[track_caller] + pub fn eigendecomposition< + ComplexE: ComplexField::Real>, + >( + &self, + ) -> Eigendecomposition { + self.as_ref().eigendecomposition::() + } + + /// Returns the eigendecomposition of `self`, when `E` is in the complex domain. + #[track_caller] + pub fn complex_eigendecomposition(&self) -> Eigendecomposition { + self.as_ref().complex_eigendecomposition() + } + + /// Returns the determinant of `self`. + #[track_caller] + pub fn determinant(&self) -> E::Canonical { + self.as_ref().determinant() + } + + /// Returns the eigenvalues of `self`, assuming it is self-adjoint. Only the provided + /// side is accessed. The order of the eigenvalues is currently unspecified. + #[track_caller] + pub fn selfadjoint_eigenvalues(&self, side: Side) -> Vec<::Real> { + self.as_ref().selfadjoint_eigenvalues(side) + } + + /// Returns the singular values of `self`, in nonincreasing order. + #[track_caller] + pub fn singular_values(&self) -> Vec<::Real> { + self.as_ref().singular_values() + } + + /// Returns the eigenvalues of `self`, as complex values. The order of the eigenvalues is + /// currently unspecified. + #[track_caller] + pub fn eigenvalues::Real>>( + &self, + ) -> Vec { + self.as_ref().eigenvalues() + } + + /// Returns the eigenvalues of `self`, when `E` is in the complex domain. The order of the + /// eigenvalues is currently unspecified. + #[track_caller] + pub fn complex_eigenvalues(&self) -> Vec { + self.as_ref().complex_eigenvalues() + } +} + +impl Mat +where + E::Canonical: ComplexField, +{ + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + #[track_caller] + pub fn solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_lower_triangular_in_place(rhs) + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + #[track_caller] + pub fn solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_upper_triangular_in_place(rhs) + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_unit_lower_triangular_in_place(rhs) + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs` + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().solve_unit_upper_triangular_in_place(rhs) + } + + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + #[track_caller] + pub fn solve_lower_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_lower_triangular(rhs.as_mat_ref()) + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + #[track_caller] + pub fn solve_upper_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_upper_triangular(rhs.as_mat_ref()) + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_lower_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_unit_lower_triangular(rhs.as_mat_ref()) + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, and + /// returns the result. + /// + /// The diagonal of the matrix is not accessed. + #[track_caller] + pub fn solve_unit_upper_triangular>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + self.as_ref().solve_unit_upper_triangular(rhs.as_mat_ref()) + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn cholesky(&self, side: Side) -> Result, CholeskyError> { + self.as_ref().cholesky(side) + } + /// Returns the Bunch-Kaufman decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn lblt(&self, side: Side) -> Lblt { + self.as_ref().lblt(side) + } + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn partial_piv_lu(&self) -> PartialPivLu { + self.as_ref().partial_piv_lu() + } + /// Returns the LU decomposition of `self` with full pivoting. + #[track_caller] + pub fn full_piv_lu(&self) -> FullPivLu { + self.as_ref().full_piv_lu() + } + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn qr(&self) -> Qr { + self.as_ref().qr() + } + /// Returns the QR decomposition of `self` with column pivoting. + #[track_caller] + pub fn col_piv_qr(&self) -> ColPivQr { + self.as_ref().col_piv_qr() + } + /// Returns the SVD of `self`. + #[track_caller] + pub fn svd(&self) -> Svd { + self.as_ref().svd() + } + /// Returns the thin SVD of `self`. + #[track_caller] + pub fn thin_svd(&self) -> ThinSvd { + self.as_ref().thin_svd() + } + /// Returns the eigendecomposition of `self`, assuming it is self-adjoint. Only the provided + /// side is accessed. + #[track_caller] + pub fn selfadjoint_eigendecomposition( + &self, + side: Side, + ) -> SelfAdjointEigendecomposition { + self.as_ref().selfadjoint_eigendecomposition(side) + } + + /// Returns the eigendecomposition of `self`, as a complex matrix. + #[track_caller] + pub fn eigendecomposition< + ComplexE: ComplexField::Real>, + >( + &self, + ) -> Eigendecomposition { + self.as_ref().eigendecomposition::() + } + + /// Returns the eigendecomposition of `self`, when `E` is in the complex domain. + #[track_caller] + pub fn complex_eigendecomposition(&self) -> Eigendecomposition { + self.as_ref().complex_eigendecomposition() + } + + /// Returns the determinant of `self`. + #[track_caller] + pub fn determinant(&self) -> E::Canonical { + self.as_ref().determinant() + } + + /// Returns the eigenvalues of `self`, assuming it is self-adjoint. Only the provided + /// side is accessed. The order of the eigenvalues is currently unspecified. + #[track_caller] + pub fn selfadjoint_eigenvalues(&self, side: Side) -> Vec<::Real> { + self.as_ref().selfadjoint_eigenvalues(side) + } + + /// Returns the singular values of `self`, in nonincreasing order. + #[track_caller] + pub fn singular_values(&self) -> Vec<::Real> { + self.as_ref().singular_values() + } + + /// Returns the eigenvalues of `self`, as complex values. The order of the eigenvalues is + /// currently unspecified. + #[track_caller] + pub fn eigenvalues::Real>>( + &self, + ) -> Vec { + self.as_ref().eigenvalues() + } + + /// Returns the eigenvalues of `self`, when `E` is in the complex domain. The order of the + /// eigenvalues is currently unspecified. + #[track_caller] + pub fn complex_eigenvalues(&self) -> Vec { + self.as_ref().complex_eigenvalues() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{assert, RealField}; + use complex_native::*; + + #[track_caller] + fn assert_approx_eq(a: impl AsMatRef, b: impl AsMatRef) { + let a = a.as_mat_ref(); + let b = b.as_mat_ref(); + let eps = E::Real::faer_epsilon().unwrap().faer_sqrt(); + + assert!(a.nrows() == b.nrows()); + assert!(a.ncols() == b.ncols()); + + let m = a.nrows(); + let n = a.ncols(); + + for j in 0..n { + for i in 0..m { + assert!((a.read(i, j).faer_sub(b.read(i, j))).faer_abs() < eps); + } + } + } + + fn test_solver_real(H: impl AsMatRef, decomp: &dyn SolverCore) { + let H = H.as_mat_ref(); + let n = H.nrows(); + let k = 2; + + let random = |_, _| rand::random::(); + let rhs = Mat::from_fn(n, k, random); + + let I = Mat::from_fn(n, n, |i, j| { + if i == j { + f64::faer_one() + } else { + f64::faer_zero() + } + }); + + let sol = decomp.solve(&rhs); + assert_approx_eq(H * &sol, &rhs); + + let sol = decomp.solve_conj(&rhs); + assert_approx_eq(H.conjugate() * &sol, &rhs); + + let sol = decomp.solve_transpose(&rhs); + assert_approx_eq(H.transpose() * &sol, &rhs); + + let sol = decomp.solve_conj_transpose(&rhs); + assert_approx_eq(H.adjoint() * &sol, &rhs); + + assert_approx_eq(decomp.reconstruct(), H); + assert_approx_eq(H * decomp.inverse(), I); + } + + fn test_solver(H: impl AsMatRef, decomp: &dyn SolverCore) { + let H = H.as_mat_ref(); + let n = H.nrows(); + let k = 2; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let rhs = Mat::from_fn(n, k, random); + + let I = Mat::from_fn(n, n, |i, j| { + if i == j { + c64::faer_one() + } else { + c64::faer_zero() + } + }); + + let sol = decomp.solve(&rhs); + assert_approx_eq(H * &sol, &rhs); + + let sol = decomp.solve_conj(&rhs); + assert_approx_eq(H.conjugate() * &sol, &rhs); + + let sol = decomp.solve_transpose(&rhs); + assert_approx_eq(H.transpose() * &sol, &rhs); + + let sol = decomp.solve_conj_transpose(&rhs); + assert_approx_eq(H.adjoint() * &sol, &rhs); + + assert_approx_eq(decomp.reconstruct(), H); + assert_approx_eq(H * decomp.inverse(), I); + } + + fn test_solver_lstsq(H: impl AsMatRef, decomp: &dyn SolverLstsqCore) { + let H = H.as_mat_ref(); + + let m = H.nrows(); + let k = 2; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let rhs = Mat::from_fn(m, k, random); + + let sol = decomp.solve_lstsq(&rhs); + assert_approx_eq(H.adjoint() * H * &sol, H.adjoint() * &rhs); + + let sol = decomp.solve_lstsq_conj(&rhs); + assert_approx_eq(H.transpose() * H.conjugate() * &sol, H.transpose() * &rhs); + } + + #[test] + fn test_lblt_real() { + let n = 7; + + let random = |_, _| rand::random::(); + let H = Mat::from_fn(n, n, random); + let H = &H + H.adjoint(); + + test_solver_real(&H, &H.lblt(Side::Lower)); + test_solver_real(&H, &H.lblt(Side::Upper)); + } + + #[test] + fn test_lblt() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + let H = &H + H.adjoint(); + + test_solver(&H, &H.lblt(Side::Lower)); + test_solver(&H, &H.lblt(Side::Upper)); + } + + #[test] + fn test_cholesky() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + let H = &H * H.adjoint(); + + test_solver(&H, &H.cholesky(Side::Lower).unwrap()); + test_solver(&H, &H.cholesky(Side::Upper).unwrap()); + } + + #[test] + fn test_partial_piv_lu() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + + test_solver(&H, &H.partial_piv_lu()); + } + + #[test] + fn test_full_piv_lu() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + + test_solver(&H, &H.full_piv_lu()); + } + + #[test] + fn test_qr_real() { + let n = 7; + + let random = |_, _| rand::random::(); + let H = Mat::from_fn(n, n, random); + + let qr = H.qr(); + test_solver_real(&H, &qr); + + for (m, n) in [(7, 5), (5, 7), (7, 7)] { + let H = Mat::from_fn(m, n, random); + let qr = H.qr(); + assert_approx_eq(qr.compute_q() * qr.compute_r(), &H); + assert_approx_eq(qr.compute_thin_q() * qr.compute_thin_r(), &H); + } + } + + #[test] + fn test_qr() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + + let qr = H.qr(); + test_solver(&H, &qr); + + for (m, n) in [(7, 5), (5, 7), (7, 7)] { + let H = Mat::from_fn(m, n, random); + let qr = H.qr(); + assert_approx_eq(qr.compute_q() * qr.compute_r(), &H); + assert_approx_eq(qr.compute_thin_q() * qr.compute_thin_r(), &H); + if m >= n { + test_solver_lstsq(H, &qr) + } + } + } + + #[test] + fn test_col_piv_qr() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + + test_solver(&H, &H.col_piv_qr()); + + for (m, n) in [(7, 5), (5, 7), (7, 7)] { + let H = Mat::from_fn(m, n, random); + let qr = H.col_piv_qr(); + assert_approx_eq( + qr.compute_q() * qr.compute_r(), + &H * qr.col_permutation().inverse(), + ); + assert_approx_eq( + qr.compute_thin_q() * qr.compute_thin_r(), + &H * qr.col_permutation().inverse(), + ); + if m >= n { + test_solver_lstsq(H, &qr) + } + } + } + + #[test] + fn test_svd() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + + test_solver(&H, &H.svd()); + test_solver(H.adjoint().to_owned(), &H.adjoint().svd()); + + let svd = H.svd(); + for i in 0..n - 1 { + assert!(svd.s_diagonal()[i].re >= svd.s_diagonal()[i + 1].re); + } + let svd = H.singular_values(); + for i in 0..n - 1 { + assert!(svd[i] >= svd[i + 1]); + } + } + + #[test] + fn test_thin_svd() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + + test_solver(&H, &H.thin_svd()); + test_solver(H.adjoint().to_owned(), &H.adjoint().thin_svd()); + } + + #[test] + fn test_selfadjoint_eigendecomposition() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + let H = &H * H.adjoint(); + + test_solver(&H, &H.selfadjoint_eigendecomposition(Side::Lower)); + test_solver(&H, &H.selfadjoint_eigendecomposition(Side::Upper)); + test_solver( + H.adjoint().to_owned(), + &H.adjoint().selfadjoint_eigendecomposition(Side::Lower), + ); + test_solver( + H.adjoint().to_owned(), + &H.adjoint().selfadjoint_eigendecomposition(Side::Upper), + ); + + let evd = H.selfadjoint_eigendecomposition(Side::Lower); + for i in 0..n - 1 { + assert!(evd.s().column_vector()[i].re <= evd.s().column_vector()[i + 1].re); + } + let evd = H.selfadjoint_eigenvalues(Side::Lower); + for i in 0..n - 1 { + assert!(evd[i] <= evd[i + 1]); + } + } + + #[test] + fn test_eigendecomposition() { + let n = 7; + + let random = |_, _| c64::new(rand::random(), rand::random()); + let H = Mat::from_fn(n, n, random); + + { + let eigen = H.eigendecomposition::(); + let s = eigen.s(); + let u = eigen.u(); + assert_approx_eq(u * s, &H * u); + } + + { + let eigen = H.complex_eigendecomposition(); + let s = eigen.s(); + let u = eigen.u(); + assert_approx_eq(u * &s, &H * u); + } + + let det = H.determinant(); + let eigen_det = H + .complex_eigenvalues() + .into_iter() + .fold(c64::faer_one(), |a, b| a * b); + + assert!((det - eigen_det).faer_abs() < 1e-8); + } + + #[test] + fn test_real_eigendecomposition() { + let n = 7; + + let random = |_, _| rand::random::(); + let H_real = Mat::from_fn(n, n, random); + let H = Mat::from_fn(n, n, |i, j| c64::new(H_real.read(i, j), 0.0)); + + let eigen = H_real.eigendecomposition::(); + let s = eigen.s(); + let u = eigen.u(); + assert_approx_eq(u * &s, &H * u); + } + + #[test] + fn this_other_tree_has_correct_maximum_eigenvalue_20() { + let edges = [ + (3, 2), + (6, 1), + (7, 4), + (7, 6), + (8, 5), + (9, 4), + (11, 2), + (12, 2), + (13, 2), + (15, 6), + (16, 2), + (16, 4), + (17, 8), + (18, 0), + (18, 8), + (18, 2), + (19, 6), + (19, 10), + (19, 14), + ]; + let mut a = Mat::zeros(20, 20); + for (v, u) in edges.iter() { + a[(*v, *u)] = 1.0; + a[(*u, *v)] = 1.0; + } + let eigs_complex: Vec = a.eigenvalues(); + println!("{eigs_complex:?}"); + let eigs_real = eigs_complex.iter().map(|e| e.re).collect::>(); + println!("{eigs_real:?}"); + let lambda_1 = *eigs_real + .iter() + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + let correct_lamba_1 = 2.6148611139728866; + assert!( + (lambda_1 - correct_lamba_1).abs() < 1e-10, + "lambda_1 = {lambda_1}, correct_lamba_1 = {correct_lamba_1}", + ); + } + + #[test] + fn this_other_tree_has_correct_maximum_eigenvalue_3() { + let edges = [(1, 0), (0, 2)]; + let mut a = Mat::zeros(3, 3); + for (v, u) in edges.iter() { + a[(*v, *u)] = 1.0; + a[(*u, *v)] = 1.0; + } + let eigs_complex: Vec = a.eigenvalues(); + let eigs_real = eigs_complex.iter().map(|e| e.re).collect::>(); + let lambda_1 = *eigs_real + .iter() + .max_by(|a, b| a.partial_cmp(b).unwrap()) + .unwrap(); + let correct_lamba_1 = 1.414213562373095; + assert!( + (lambda_1 - correct_lamba_1).abs() < 1e-10, + "lambda_1 = {lambda_1}, correct_lamba_1 = {correct_lamba_1}", + ); + } +} diff --git a/faer-libs/faer-svd/src/bidiag.rs b/src/linalg/svd/bidiag.rs similarity index 96% rename from faer-libs/faer-svd/src/bidiag.rs rename to src/linalg/svd/bidiag.rs index e843d964ce2043906a51c366e18f37d641482eb8..0947eb53e0609376def78c70c8caf07ecd2d0137 100644 --- a/faer-libs/faer-svd/src/bidiag.rs +++ b/src/linalg/svd/bidiag.rs @@ -1,10 +1,12 @@ +use crate::{ + assert, + linalg::{matmul::matmul, temp_mat_req, temp_mat_uninit, temp_mat_zeroed}, + unzipped, + utils::thread::{for_each_raw, par_split_indices, parallelism_degree}, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, +}; use core::slice; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, for_each_raw, mul::matmul, par_split_indices, parallelism_degree, simd, temp_mat_req, - temp_mat_uninit, temp_mat_zeroed, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, - Parallelism, SimdCtx, -}; use faer_entity::*; use pulp::Simd; use reborrow::*; @@ -102,7 +104,7 @@ pub fn bidiagonalize_in_place( let head = a_col.read(0, 0); let essential = a_col.rb_mut().col_mut(0).subrows_mut(1, m - 1); let tail_norm = essential.norm_l2(); - faer_core::householder::make_householder_in_place_v2( + crate::linalg::householder::make_householder_in_place( Some(essential.as_2d_mut()), head, tail_norm, @@ -139,7 +141,7 @@ pub fn bidiagonalize_in_place( let head = a_row.read(0, 0); let essential = a_row.rb().row(0).subcols(1, n - 2).transpose(); let tail_norm = essential.norm_l2(); - faer_core::householder::make_householder_in_place_v2(None, head, tail_norm) + crate::linalg::householder::make_householder_in_place(None, head, tail_norm) }; householder_right.write(k, 0, tr); @@ -165,7 +167,7 @@ pub fn bidiagonalize_in_place( } a_row.write(0, 0, E::faer_one()); - let b = faer_core::mul::inner_prod::inner_prod_with_conj( + let b = crate::linalg::matmul::inner_prod::inner_prod_with_conj( y.rb().col(0).as_2d(), Conj::Yes, a_row.rb().row(0).transpose().as_2d(), @@ -302,10 +304,10 @@ fn bidiag_fused_op_step0( u, } = self; - let (a_j_head, a_j_tail) = simd::slice_as_mut_simd::(a_j); - let (z_head, z_tail) = simd::slice_as_simd::(z); - let (u_prev_head, u_prev_tail) = simd::slice_as_simd::(u_prev); - let (u_head, u_tail) = simd::slice_as_simd::(u); + let (a_j_head, a_j_tail) = faer_entity::slice_as_mut_simd::(a_j); + let (z_head, z_tail) = faer_entity::slice_as_simd::(z); + let (u_prev_head, u_prev_tail) = faer_entity::slice_as_simd::(u_prev); + let (u_head, u_tail) = faer_entity::slice_as_simd::(u); let (a_j_head4, a_j_head1) = E::faer_as_arrays_mut::<4, _>(a_j_head); let (z_head4, z_head1) = E::faer_as_arrays::<4, _>(z_head); @@ -480,8 +482,8 @@ fn bidiag_fused_op_step1<'a, E: ComplexField>( #[inline(always)] fn with_simd(self, simd: S) -> Self::Output { let Self { z, a_j, rhs } = self; - let (z_head, z_tail) = simd::slice_as_mut_simd::(z); - let (a_j_head, a_j_tail) = simd::slice_as_simd::(a_j); + let (z_head, z_tail) = faer_entity::slice_as_mut_simd::(z); + let (a_j_head, a_j_tail) = faer_entity::slice_as_simd::(a_j); let rhs_v = E::faer_simd_splat(simd, rhs); for (zi, aij) in E::faer_into_iter(z_head).zip(E::faer_into_iter(a_j_head)) { @@ -810,10 +812,10 @@ fn bidiag_fused_op( #[cfg(test)] mod tests { use super::*; - use assert_approx_eq::assert_approx_eq; - use faer_core::{ - assert, c64, - householder::{ + use crate::{ + assert, + complex_native::c64, + linalg::householder::{ apply_block_householder_sequence_on_the_right_in_place_with_conj, apply_block_householder_sequence_transpose_on_the_left_in_place_req, apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, @@ -821,6 +823,7 @@ mod tests { }, Mat, }; + use assert_approx_eq::assert_approx_eq; macro_rules! make_stack { ($req: expr $(,)?) => { diff --git a/faer-libs/faer-svd/src/bidiag_real_svd.rs b/src/linalg/svd/bidiag_real_svd.rs similarity index 99% rename from faer-libs/faer-svd/src/bidiag_real_svd.rs rename to src/linalg/svd/bidiag_real_svd.rs index 5895a884b266467c2c68bdf95e09456c6aa04d21..b8f8f0b7525cb89311ae0a51dfd0d8697c18166f 100644 --- a/faer-libs/faer-svd/src/bidiag_real_svd.rs +++ b/src/linalg/svd/bidiag_real_svd.rs @@ -12,20 +12,23 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -use crate::jacobi::{jacobi_svd, Skip}; +use super::jacobi::{jacobi_svd, JacobiRotation, Skip}; +use crate::{ + assert, + linalg::{temp_mat_req, temp_mat_uninit, temp_mat_zeroed}, + unzipped, + utils::{simd::SimdFor, thread::join_raw}, + zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, RealField, +}; use coe::Coerce; use core::{iter::zip, mem::swap}; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - assert, group_helpers::SimdFor, jacobi::JacobiRotation, join_raw, temp_mat_req, - temp_mat_uninit, temp_mat_zeroed, unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, - Parallelism, RealField, -}; +use faer_entity::*; use reborrow::*; #[allow(dead_code)] -fn bidiag_to_mat(diag: &[E], subdiag: &[E]) -> faer_core::Mat { - let mut mat = faer_core::Mat::::zeros(diag.len() + 1, diag.len()); +fn bidiag_to_mat(diag: &[E], subdiag: &[E]) -> crate::Mat { + let mut mat = crate::Mat::::zeros(diag.len() + 1, diag.len()); for (i, d) in diag.iter().enumerate() { mat.write(i, i, *d); @@ -38,7 +41,7 @@ fn bidiag_to_mat(diag: &[E], subdiag: &[E]) -> faer_core::Mat { } fn norm(v: MatRef<'_, E>) -> E { - faer_core::mul::inner_prod::inner_prod_with_conj(v, Conj::No, v, Conj::No).faer_sqrt() + crate::linalg::matmul::inner_prod::inner_prod_with_conj(v, Conj::No, v, Conj::No).faer_sqrt() } fn compute_svd_of_m( @@ -1446,7 +1449,6 @@ fn bidiag_svd_qr_algorithm_impl( } } - use faer_entity::SimdCtx; E::Simd::default().dispatch(Impl { epsilon, consider_zero_threshold, @@ -1481,10 +1483,10 @@ fn bidiag_svd_qr_algorithm_impl( if k != max_idx { diag.swap(k, max_idx); if let Some(u) = u.rb_mut() { - faer_core::permutation::swap_cols(u, k, max_idx); + crate::perm::swap_cols_idx(u, k, max_idx); } if let Some(v) = v.rb_mut() { - faer_core::permutation::swap_cols(v, k, max_idx); + crate::perm::swap_cols_idx(v, k, max_idx); } } } @@ -1844,7 +1846,7 @@ fn bidiag_svd_impl( if compact_u == 1 { // handle rotation of Q1, q1 for i in (0..k).rev() { - faer_core::permutation::swap_cols(u1.rb_mut(), i, i + 1); + crate::perm::swap_cols_idx(u1.rb_mut(), i, i + 1); } } @@ -2062,7 +2064,7 @@ fn bidiag_svd_impl( join_raw( |parallelism| { - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_v1.rb_mut(), v_lhs1, v_rhs1, @@ -2072,7 +2074,7 @@ fn bidiag_svd_impl( ) }, |parallelism| { - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_v2.rb_mut(), v_lhs2, v_rhs2, @@ -2084,7 +2086,7 @@ fn bidiag_svd_impl( parallelism, ); - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_v.rb_mut().submatrix_mut(k, 0, 1, n), v_lhs.submatrix(k, 0, 1, 1), v_rhs.submatrix(0, 0, 1, n), @@ -2115,7 +2117,7 @@ fn bidiag_svd_impl( join_raw( |parallelism| { // matrix matrix - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_u1.rb_mut(), u_lhs1, u_rhs1, @@ -2124,7 +2126,7 @@ fn bidiag_svd_impl( parallelism, ); // rank 1 update - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_u1.rb_mut(), u_lhs.col(n).subrows(0, k + 1).as_2d(), u_rhs2.row(rem).as_2d(), @@ -2135,7 +2137,7 @@ fn bidiag_svd_impl( }, |parallelism| { // matrix matrix - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_u2.rb_mut(), u_lhs2, u_rhs2, @@ -2144,7 +2146,7 @@ fn bidiag_svd_impl( parallelism, ); // rank 1 update - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_u2.rb_mut(), u_lhs.col(0).subrows(k + 1, rem + 1).as_2d(), u_rhs1.row(0).as_2d(), @@ -2166,7 +2168,7 @@ fn bidiag_svd_impl( if fill_u { let (mut combined_u, _) = temp_mat_uninit::(2, n + 1, stack); let mut combined_u = combined_u.as_mut(); - faer_core::mul::matmul( + crate::linalg::matmul::matmul( combined_u.rb_mut(), u.rb(), um.rb(), @@ -2181,11 +2183,11 @@ fn bidiag_svd_impl( match parallelism { #[cfg(feature = "rayon")] Parallelism::Rayon(_) if !_v_is_none => { - let req_v = faer_core::temp_mat_req::(n, n).unwrap(); + let req_v = crate::linalg::temp_mat_req::(n, n).unwrap(); let (mem_v, stack_u) = stack.make_aligned_raw::(req_v.size_bytes(), req_v.align_bytes()); let stack_v = PodStack::new(mem_v); - faer_core::join_raw( + crate::utils::thread::join_raw( |parallelism| update_v(parallelism, stack_v), |parallelism| update_u(parallelism, stack_u), parallelism, @@ -2245,8 +2247,8 @@ pub fn bidiag_real_svd_req( #[cfg(test)] mod tests { use super::*; + use crate::{assert, Mat}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, Mat}; macro_rules! make_stack { ($req: expr) => { diff --git a/faer-libs/faer-svd/src/jacobi.rs b/src/linalg/svd/jacobi.rs similarity index 53% rename from faer-libs/faer-svd/src/jacobi.rs rename to src/linalg/svd/jacobi.rs index 30f22e98543c72a5a8379927be42105ca651d50c..fd1a8611530219cd7a5cc2514f61eec52b3117b1 100644 --- a/faer-libs/faer-svd/src/jacobi.rs +++ b/src/linalg/svd/jacobi.rs @@ -8,9 +8,329 @@ // Public License v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. -use faer_core::{assert, jacobi::JacobiRotation, permutation::swap_cols, MatMut, RealField}; +use crate::{ + assert, + perm::swap_cols_idx as swap_cols, + unzipped, + utils::{simd::*, slice::*}, + zipped, MatMut, RealField, +}; +use faer_entity::{pulp, SimdCtx, SimdGroupFor}; use reborrow::*; +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub struct JacobiRotation { + pub c: T, + pub s: T, +} + +unsafe impl bytemuck::Zeroable for JacobiRotation {} +unsafe impl bytemuck::Pod for JacobiRotation {} + +impl JacobiRotation { + #[inline] + pub fn make_givens(p: E, q: E) -> Self { + if q == E::faer_zero() { + Self { + c: if p < E::faer_zero() { + E::faer_one().faer_neg() + } else { + E::faer_one() + }, + s: E::faer_zero(), + } + } else if p == E::faer_zero() { + Self { + c: E::faer_zero(), + s: if q < E::faer_zero() { + E::faer_one().faer_neg() + } else { + E::faer_one() + }, + } + } else if p.faer_abs() > q.faer_abs() { + let t = q.faer_div(p); + let mut u = E::faer_one().faer_add(t.faer_abs2()).faer_sqrt(); + if p < E::faer_zero() { + u = u.faer_neg(); + } + let c = u.faer_inv(); + let s = t.faer_neg().faer_mul(c); + + Self { c, s } + } else { + let t = p.faer_div(q); + let mut u = E::faer_one().faer_add(t.faer_abs2()).faer_sqrt(); + if q < E::faer_zero() { + u = u.faer_neg(); + } + let s = u.faer_inv().faer_neg(); + let c = t.faer_neg().faer_mul(s); + + Self { c, s } + } + } + + #[inline] + pub fn from_triplet(x: E, y: E, z: E) -> Self { + let abs_y = y.faer_abs(); + let two_abs_y = abs_y.faer_add(abs_y); + if two_abs_y == E::faer_zero() { + Self { + c: E::faer_one(), + s: E::faer_zero(), + } + } else { + let tau = (x.faer_sub(z)).faer_mul(two_abs_y.faer_inv()); + let w = ((tau.faer_mul(tau)).faer_add(E::faer_one())).faer_sqrt(); + let t = if tau > E::faer_zero() { + (tau.faer_add(w)).faer_inv() + } else { + (tau.faer_sub(w)).faer_inv() + }; + + let neg_sign_y = if y > E::faer_zero() { + E::faer_one().faer_neg() + } else { + E::faer_one() + }; + let n = (t.faer_mul(t).faer_add(E::faer_one())) + .faer_sqrt() + .faer_inv(); + + Self { + c: n, + s: neg_sign_y.faer_mul(t).faer_mul(n), + } + } + } + + #[inline] + pub fn apply_on_the_left_2x2(&self, m00: E, m01: E, m10: E, m11: E) -> (E, E, E, E) { + let Self { c, s } = *self; + ( + m00.faer_mul(c).faer_add(m10.faer_mul(s)), + m01.faer_mul(c).faer_add(m11.faer_mul(s)), + s.faer_neg().faer_mul(m00).faer_add(c.faer_mul(m10)), + s.faer_neg().faer_mul(m01).faer_add(c.faer_mul(m11)), + ) + } + + #[inline] + pub fn apply_on_the_right_2x2(&self, m00: E, m01: E, m10: E, m11: E) -> (E, E, E, E) { + let (r00, r01, r10, r11) = self.transpose().apply_on_the_left_2x2(m00, m10, m01, m11); + (r00, r10, r01, r11) + } + + #[inline] + pub fn apply_on_the_left_in_place(&self, x: MatMut<'_, E>, y: MatMut<'_, E>) { + self.apply_on_the_left_in_place_arch(E::Simd::default(), x, y); + } + + #[inline(never)] + fn apply_on_the_left_in_place_fallback(&self, x: MatMut<'_, E>, y: MatMut<'_, E>) { + let Self { c, s } = *self; + zipped!(x, y).for_each(move |unzipped!(mut x, mut y)| { + let x_ = x.read(); + let y_ = y.read(); + x.write(c.faer_mul(x_).faer_add(s.faer_mul(y_))); + y.write(s.faer_neg().faer_mul(x_).faer_add(c.faer_mul(y_))); + }); + } + + #[inline(always)] + pub fn apply_on_the_right_in_place_with_simd_and_offset( + &self, + simd: S, + offset: pulp::Offset>, + x: MatMut<'_, E>, + y: MatMut<'_, E>, + ) { + self.transpose() + .apply_on_the_left_in_place_with_simd_and_offset( + simd, + offset, + x.transpose_mut(), + y.transpose_mut(), + ); + } + + #[inline(always)] + pub fn apply_on_the_left_in_place_with_simd_and_offset( + &self, + simd: S, + offset: pulp::Offset>, + x: MatMut<'_, E>, + y: MatMut<'_, E>, + ) { + let Self { c, s } = *self; + assert!(all(x.nrows() == 1, y.nrows() == 1, x.ncols() == y.ncols())); + + if c == E::faer_one() && s == E::faer_zero() { + return; + } + + if x.col_stride() != 1 || y.col_stride() != 1 { + self.apply_on_the_left_in_place_fallback(x, y); + return; + } + + let simd = SimdFor::::new(simd); + + let x = SliceGroupMut::<'_, E>::new(x.transpose_mut().try_get_contiguous_col_mut(0)); + let y = SliceGroupMut::<'_, E>::new(y.transpose_mut().try_get_contiguous_col_mut(0)); + + let c = simd.splat(c); + let s = simd.splat(s); + + let (x_head, x_body, x_tail) = simd.as_aligned_simd_mut(x, offset); + let (y_head, y_body, y_tail) = simd.as_aligned_simd_mut(y, offset); + + #[inline(always)] + fn process( + simd: SimdFor, + mut x: impl Write>, + mut y: impl Write>, + c: SimdGroupFor, + s: SimdGroupFor, + ) { + let zero = simd.splat(E::faer_zero()); + let x_ = x.read_or(zero); + let y_ = y.read_or(zero); + x.write(simd.mul_add_e(c, x_, simd.mul(s, y_))); + y.write(simd.mul_add_e(c, y_, simd.neg(simd.mul(s, x_)))); + } + + process(simd, x_head, y_head, c, s); + for (x, y) in x_body.into_mut_iter().zip(y_body.into_mut_iter()) { + process(simd, x, y, c, s); + } + process(simd, x_tail, y_tail, c, s); + } + + #[inline] + pub fn apply_on_the_left_in_place_arch( + &self, + arch: E::Simd, + x: MatMut<'_, E>, + y: MatMut<'_, E>, + ) { + struct ApplyOnLeft<'a, E: RealField> { + c: E, + s: E, + x: MatMut<'a, E>, + y: MatMut<'a, E>, + } + + impl pulp::WithSimd for ApplyOnLeft<'_, E> { + type Output = (); + + #[inline(always)] + fn with_simd(self, simd: S) -> Self::Output { + let Self { x, y, c, s } = self; + assert!(all(x.nrows() == 1, y.nrows() == 1, x.ncols() == y.ncols())); + + if c == E::faer_one() && s == E::faer_zero() { + return; + } + + let simd = SimdFor::::new(simd); + + let x = + SliceGroupMut::<'_, E>::new(x.transpose_mut().try_get_contiguous_col_mut(0)); + let y = + SliceGroupMut::<'_, E>::new(y.transpose_mut().try_get_contiguous_col_mut(0)); + + let offset = simd.align_offset(x.rb()); + + let c = simd.splat(c); + let s = simd.splat(s); + + let (x_head, x_body, x_tail) = simd.as_aligned_simd_mut(x, offset); + let (y_head, y_body, y_tail) = simd.as_aligned_simd_mut(y, offset); + + #[inline(always)] + fn process( + simd: SimdFor, + mut x: impl Write>, + mut y: impl Write>, + c: SimdGroupFor, + s: SimdGroupFor, + ) { + let zero = simd.splat(E::faer_zero()); + let x_ = x.read_or(zero); + let y_ = y.read_or(zero); + x.write(simd.mul_add_e(c, x_, simd.mul(s, y_))); + y.write(simd.mul_add_e(c, y_, simd.neg(simd.mul(s, x_)))); + } + + process(simd, x_head, y_head, c, s); + for (x, y) in x_body.into_mut_iter().zip(y_body.into_mut_iter()) { + process(simd, x, y, c, s); + } + process(simd, x_tail, y_tail, c, s); + } + } + + let Self { c, s } = *self; + + let mut x = x; + let mut y = y; + + if x.col_stride() == 1 && y.col_stride() == 1 { + arch.dispatch(ApplyOnLeft::<'_, E> { c, s, x, y }); + } else { + zipped!(x, y).for_each(move |unzipped!(mut x, mut y)| { + let x_ = x.read(); + let y_ = y.read(); + x.write(c.faer_mul(x_).faer_add(s.faer_mul(y_))); + y.write(s.faer_neg().faer_mul(x_).faer_add(c.faer_mul(y_))); + }); + } + } + + #[inline] + pub fn apply_on_the_right_in_place(&self, x: MatMut<'_, E>, y: MatMut<'_, E>) { + self.transpose() + .apply_on_the_left_in_place(x.transpose_mut(), y.transpose_mut()); + } + + #[inline] + pub fn apply_on_the_right_in_place_arch( + &self, + arch: E::Simd, + x: MatMut<'_, E>, + y: MatMut<'_, E>, + ) { + self.transpose().apply_on_the_left_in_place_arch( + arch, + x.transpose_mut(), + y.transpose_mut(), + ); + } + + #[inline] + pub fn transpose(&self) -> Self { + Self { + c: self.c, + s: self.s.faer_neg(), + } + } +} + +impl core::ops::Mul for JacobiRotation { + type Output = Self; + + #[inline] + fn mul(self, rhs: Self) -> Self::Output { + Self { + c: self.c.faer_mul(rhs.c).faer_sub(self.s.faer_mul(rhs.s)), + s: self.c.faer_mul(rhs.s).faer_add(self.s.faer_mul(rhs.c)), + } + } +} + fn compute_2x2( m00: E, m01: E, @@ -248,8 +568,8 @@ pub fn jacobi_svd( #[cfg(test)] mod tests { use super::*; + use crate::{assert, Mat, MatRef}; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, Mat, MatRef}; #[track_caller] fn check_svd(mat: MatRef<'_, f64>, u: MatRef<'_, f64>, v: MatRef<'_, f64>, s: MatRef<'_, f64>) { @@ -392,7 +712,7 @@ mod tests { #[test] fn eigen_286() { - let mat = faer_core::mat![[-7.90884e-313, -4.94e-324], [0.0, 5.60844e-313]]; + let mat = crate::mat![[-7.90884e-313, -4.94e-324], [0.0, 5.60844e-313]]; let n = 2; let mut s = mat.clone(); let mut u = Mat::::zeros(n, n); diff --git a/faer-libs/faer-svd/src/lib.rs b/src/linalg/svd/mod.rs similarity index 97% rename from faer-libs/faer-svd/src/lib.rs rename to src/linalg/svd/mod.rs index 8a47a2f3c6639c9ab6ea8a0286ea64929dfe52d8..e205501f5a68ff323bcb2c783b78b1c45300c735 100644 --- a/faer-libs/faer-svd/src/lib.rs +++ b/src/linalg/svd/mod.rs @@ -1,3 +1,5 @@ +//! Low level implementation of the SVD of a matrix. +//! //! The SVD of a matrix $M$ of shape $(m, n)$ is a decomposition into three components $U$, $S$, //! and $V$, such that: //! @@ -8,29 +10,26 @@ //! //! $$M = U S V^H.$$ -#![allow(clippy::type_complexity)] -#![allow(clippy::too_many_arguments)] -#![cfg_attr(not(feature = "std"), no_std)] - -use bidiag_real_svd::bidiag_real_svd_req; -use coe::Coerce; -use core::mem::swap; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - householder::{ - apply_block_householder_sequence_on_the_left_in_place_req, - apply_block_householder_sequence_on_the_left_in_place_with_conj, - upgrade_householder_factor, + linalg::{ + householder::{ + apply_block_householder_sequence_on_the_left_in_place_req, + apply_block_householder_sequence_on_the_left_in_place_with_conj, + upgrade_householder_factor, + }, + qr as faer_qr, temp_mat_req, temp_mat_uninit, + zip::Diag, }, - temp_mat_req, temp_mat_uninit, unzipped, - zip::Diag, - zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, RealField, + unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, RealField, }; +use coe::Coerce; +use core::mem::swap; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use num_complex::Complex; use reborrow::*; -use crate::bidiag_real_svd::compute_bidiag_real_svd; +use bidiag_real_svd::{bidiag_real_svd_req, compute_bidiag_real_svd}; #[doc(hidden)] pub mod bidiag; @@ -45,8 +44,11 @@ const BIDIAG_QR_FALLBACK_THRESHOLD: usize = 128; /// Indicates whether the singular vectors are fully computed, partially computed, or skipped. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ComputeVectors { + /// Do not compute the singular vectors. No, + /// Only compute the first $\min(\text{nrows}(A), \text{ncols}(A))$ singular vectors. Thin, + /// Compute all the singular vectors. Full, } @@ -267,7 +269,7 @@ fn compute_real_svd_small( .for_each(|unzipped!(mut dst)| dst.write(E::faer_one())); } - faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( + crate::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( qr.rb(), householder.rb(), Conj::No, @@ -551,13 +553,13 @@ fn compute_svd_big( .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); let (mut bid_col_major, mut stack) = - faer_core::temp_mat_uninit::(n - 1, m, stack.rb_mut()); + crate::linalg::temp_mat_uninit::(n - 1, m, stack.rb_mut()); let mut bid_col_major = bid_col_major.as_mut(); zipped!( bid_col_major.rb_mut(), bid.submatrix(0, 1, m, n - 1).transpose() ) - .for_each_triangular_lower(faer_core::zip::Diag::Skip, |unzipped!(mut dst, src)| { + .for_each_triangular_lower(crate::linalg::zip::Diag::Skip, |unzipped!(mut dst, src)| { dst.write(src.read()) }); @@ -572,6 +574,7 @@ fn compute_svd_big( } } +/// SVD tuning parameters. #[derive(Default, Copy, Clone)] #[non_exhaustive] pub struct SvdParams {} @@ -741,7 +744,7 @@ pub fn compute_svd_custom_epsilon( #[cfg(feature = "perf-warn")] match (u.rb(), v.rb()) { (Some(matrix), _) | (_, Some(matrix)) => { - if matrix.row_stride().unsigned_abs() != 1 && faer_core::__perf_warn!(QR_WARN) { + if matrix.row_stride().unsigned_abs() != 1 && crate::__perf_warn!(QR_WARN) { if matrix.col_stride().unsigned_abs() == 1 { log::warn!(target: "faer_perf", "SVD prefers column-major singular vector matrices. Found row-major matrix."); } else { @@ -863,7 +866,7 @@ pub fn compute_svd_custom_epsilon( .for_each(|unzipped!(mut dst)| dst.write(E::faer_one())); } - faer_core::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( + crate::linalg::householder::apply_block_householder_sequence_on_the_left_in_place_with_conj( qr.rb(), householder.rb(), Conj::No, @@ -939,8 +942,12 @@ fn squareish_svd( #[cfg(test)] mod tests { use super::*; + use crate::{ + assert, + complex_native::{c32, c64}, + Mat, + }; use assert_approx_eq::assert_approx_eq; - use faer_core::{assert, c32, c64, Mat}; macro_rules! make_stack { ($req: expr) => { @@ -1608,13 +1615,13 @@ mod tests { s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(), Some(u.as_mut()), Some(v.as_mut()), - faer_core::Parallelism::None, + crate::Parallelism::None, make_stack!(compute_svd_req::( m, n, ComputeVectors::Full, ComputeVectors::Full, - faer_core::Parallelism::None, + crate::Parallelism::None, Default::default(), )), Default::default(), @@ -1646,13 +1653,13 @@ mod tests { s.as_mut().diagonal_mut().column_vector_mut().as_2d_mut(), Some(u.as_mut()), Some(v.as_mut()), - faer_core::Parallelism::None, + crate::Parallelism::None, make_stack!(compute_svd_req::( m, n, ComputeVectors::Full, ComputeVectors::Full, - faer_core::Parallelism::None, + crate::Parallelism::None, Default::default(), )), Default::default(), diff --git a/faer-libs/faer-core/src/inverse.rs b/src/linalg/triangular_inverse.rs similarity index 97% rename from faer-libs/faer-core/src/inverse.rs rename to src/linalg/triangular_inverse.rs index a3891ae42b9c6c10144e3db69331e5762a4b3470..27625a0da846051c9585eccd63e86f953ed96512 100644 --- a/faer-libs/faer-core/src/inverse.rs +++ b/src/linalg/triangular_inverse.rs @@ -1,9 +1,13 @@ //! Triangular matrix inversion. use crate::{ - assert, join_raw, - mul::triangular::{self, BlockStructure}, - solve, ComplexField, MatMut, MatRef, Parallelism, + assert, + linalg::{ + matmul::triangular::{self, BlockStructure}, + triangular_solve as solve, + }, + utils::thread::join_raw, + ComplexField, MatMut, MatRef, Parallelism, }; use reborrow::*; diff --git a/faer-libs/faer-core/src/solve.rs b/src/linalg/triangular_solve.rs similarity index 95% rename from faer-libs/faer-core/src/solve.rs rename to src/linalg/triangular_solve.rs index 2f3f043a7c602d687a3b5a02f8003cb4580d1ff8..2870b413fd37dd1f378615073f0a9c51dd1e4ad7 100644 --- a/faer-libs/faer-core/src/solve.rs +++ b/src/linalg/triangular_solve.rs @@ -1,8 +1,8 @@ //! Triangular solve module. use crate::{ - assert, debug_assert, join_raw, unzipped, zipped, ComplexField, Conj, Conjugate, MatMut, - MatRef, Parallelism, + assert, debug_assert, unzipped, utils::thread::join_raw, zipped, ComplexField, Conj, Conjugate, + MatMut, MatRef, Parallelism, }; use faer_entity::SimdCtx; use reborrow::*; @@ -259,11 +259,12 @@ fn recursion_threshold() -> usize { /// # Example /// /// ``` -/// use faer_core::{ -/// mat, -/// mul::triangular::{matmul, BlockStructure}, -/// solve::solve_lower_triangular_in_place_with_conj, -/// unzipped, zipped, Conj, Mat, Parallelism, +/// use faer::{ +/// linalg::{ +/// matmul::triangular::{matmul, BlockStructure}, +/// triangular_solve::solve_lower_triangular_in_place_with_conj, +/// }, +/// mat, unzipped, zipped, Conj, Mat, Parallelism, /// }; /// /// let m = mat![[1.0, 0.0], [2.0, 3.0]]; @@ -344,11 +345,12 @@ pub fn solve_lower_triangular_in_place( parallelism, ); - crate::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( rhs_bot.rb_mut(), tril_bot_left, conj_lhs, @@ -767,7 +771,7 @@ unsafe fn solve_lower_triangular_in_place_unchecked( parallelism, ); - crate::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( rhs_bot.rb_mut(), tril_bot_left, conj_lhs, diff --git a/src/linalg/zip.rs b/src/linalg/zip.rs new file mode 100644 index 0000000000000000000000000000000000000000..6aabf9a4d7e5ea6d955c887cdcbe92a995a1b66d --- /dev/null +++ b/src/linalg/zip.rs @@ -0,0 +1,1684 @@ +//! Implementation of [`zipped!`] structures. + +use self::{ + col::{Col, ColMut, ColRef}, + mat::{Mat, MatMut, MatRef}, + row::{Row, RowMut, RowRef}, +}; +use crate::{assert, debug_assert, *}; +use core::mem::MaybeUninit; +use faer_entity::*; +use reborrow::*; + +/// Read only view over a single matrix element. +pub struct Read<'a, E: Entity> { + ptr: GroupFor>, +} +/// Read-write view over a single matrix element. +pub struct ReadWrite<'a, E: Entity> { + ptr: GroupFor>, +} + +/// Type that can be converted to a view. +pub trait ViewMut { + /// View type. + type Target<'a> + where + Self: 'a; + + /// Returns the view over self. + fn view_mut(&mut self) -> Self::Target<'_>; +} + +impl ViewMut for Row { + type Target<'a> = RowRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + self.as_ref() + } +} +impl ViewMut for &Row { + type Target<'a> = RowRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).as_ref() + } +} +impl ViewMut for &mut Row { + type Target<'a> = RowMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).as_mut() + } +} + +impl ViewMut for RowRef<'_, E> { + type Target<'a> = RowRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + *self + } +} +impl ViewMut for RowMut<'_, E> { + type Target<'a> = RowMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).rb_mut() + } +} +impl ViewMut for &mut RowRef<'_, E> { + type Target<'a> = RowRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + **self + } +} +impl ViewMut for &mut RowMut<'_, E> { + type Target<'a> = RowMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (**self).rb_mut() + } +} +impl ViewMut for &RowRef<'_, E> { + type Target<'a> = RowRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + **self + } +} +impl ViewMut for &RowMut<'_, E> { + type Target<'a> = RowRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (**self).rb() + } +} + +impl ViewMut for Col { + type Target<'a> = ColRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + self.as_ref() + } +} +impl ViewMut for &Col { + type Target<'a> = ColRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).as_ref() + } +} +impl ViewMut for &mut Col { + type Target<'a> = ColMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).as_mut() + } +} + +impl ViewMut for ColRef<'_, E> { + type Target<'a> = ColRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + *self + } +} +impl ViewMut for ColMut<'_, E> { + type Target<'a> = ColMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).rb_mut() + } +} +impl ViewMut for &mut ColRef<'_, E> { + type Target<'a> = ColRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + **self + } +} +impl ViewMut for &mut ColMut<'_, E> { + type Target<'a> = ColMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (**self).rb_mut() + } +} +impl ViewMut for &ColRef<'_, E> { + type Target<'a> = ColRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + **self + } +} +impl ViewMut for &ColMut<'_, E> { + type Target<'a> = ColRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (**self).rb() + } +} + +impl ViewMut for Mat { + type Target<'a> = MatRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + self.as_ref() + } +} +impl ViewMut for &Mat { + type Target<'a> = MatRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).as_ref() + } +} +impl ViewMut for &mut Mat { + type Target<'a> = MatMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).as_mut() + } +} + +impl ViewMut for MatRef<'_, E> { + type Target<'a> = MatRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + *self + } +} +impl ViewMut for MatMut<'_, E> { + type Target<'a> = MatMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (*self).rb_mut() + } +} +impl ViewMut for &mut MatRef<'_, E> { + type Target<'a> = MatRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + **self + } +} +impl ViewMut for &mut MatMut<'_, E> { + type Target<'a> = MatMut<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (**self).rb_mut() + } +} +impl ViewMut for &MatRef<'_, E> { + type Target<'a> = MatRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + **self + } +} +impl ViewMut for &MatMut<'_, E> { + type Target<'a> = MatRef<'a, E> + where + Self: 'a; + + #[inline] + fn view_mut(&mut self) -> Self::Target<'_> { + (**self).rb() + } +} + +impl core::ops::Deref for Read<'_, E> { + type Target = E; + #[inline(always)] + fn deref(&self) -> &Self::Target { + unsafe { &*(self.ptr as *const _ as *const E::Unit) } + } +} +impl core::ops::Deref for ReadWrite<'_, E> { + type Target = E; + #[inline(always)] + fn deref(&self) -> &Self::Target { + unsafe { &*(self.ptr as *const _ as *const E::Unit) } + } +} +impl core::ops::DerefMut for ReadWrite<'_, E> { + #[inline(always)] + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *(self.ptr as *mut _ as *mut E::Unit) } + } +} + +impl Read<'_, E> { + /// Read the value of the element. + #[inline(always)] + pub fn read(&self) -> E { + E::faer_from_units(E::faer_map( + E::faer_as_ref(&self.ptr), + #[inline(always)] + |ptr| unsafe { ptr.assume_init_read() }, + )) + } +} +impl ReadWrite<'_, E> { + /// Read the value of the element. + #[inline(always)] + pub fn read(&self) -> E { + E::faer_from_units(E::faer_map( + E::faer_as_ref(&self.ptr), + #[inline(always)] + |ptr| unsafe { *ptr.assume_init_ref() }, + )) + } + + /// Write to the location of the element. + #[inline(always)] + pub fn write(&mut self, value: E) { + let value = E::faer_into_units(value); + E::faer_map( + E::faer_zip(E::faer_as_mut(&mut self.ptr), value), + #[inline(always)] + |(ptr, value)| unsafe { *ptr.assume_init_mut() = value }, + ); + } +} + +/// Specifies whether the main diagonal should be traversed, when iterating over a triangular +/// chunk of the matrix. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum Diag { + /// Do not include diagonal of matrix + Skip, + /// Include diagonal of matrix + Include, +} + +/// Matrix layout transformation. Used for zipping optimizations. +#[derive(Copy, Clone)] +pub enum MatLayoutTransform { + /// Matrix is used as-is. + None, + /// Matrix rows are reversed. + ReverseRows, + /// Matrix is transposed. + Transpose, + /// Matrix is transposed, then rows are reversed. + TransposeReverseRows, +} + +/// Vector layout transformation. Used for zipping optimizations. +#[derive(Copy, Clone)] +pub enum VecLayoutTransform { + /// Vector is used as-is. + None, + /// Vector is reversed. + Reverse, +} + +/// Type with a given matrix shape. +pub trait MatShape { + /// Type of rows. + type Rows: Copy + Eq; + /// Type of columns. + type Cols: Copy + Eq; + /// Returns the number of rows. + fn nrows(&self) -> Self::Rows; + /// Returns the number of columns. + fn ncols(&self) -> Self::Cols; +} + +/// Zipped matrix views. +pub unsafe trait MaybeContiguous: MatShape { + /// Indexing type. + type Index: Copy; + /// Contiguous slice type. + type Slice; + /// Layout transformation type. + type LayoutTransform: Copy; + /// Returns slice at index of length `n_elems`. + unsafe fn get_slice_unchecked(&mut self, idx: Self::Index, n_elems: usize) -> Self::Slice; +} + +/// Zipped matrix views. +pub unsafe trait MatIndex<'a, _Outlives = &'a Self>: MaybeContiguous { + /// Item produced by the zipped views. + type Item; + + /// Get the item at the given index, skipping bound checks. + unsafe fn get_unchecked(&'a mut self, index: Self::Index) -> Self::Item; + /// Get the item at the given slice position, skipping bound checks. + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item; + + /// Checks if the zipped matrices are contiguous. + fn is_contiguous(&self) -> bool; + /// Computes the preferred iteration layout of the matrices. + fn preferred_layout(&self) -> Self::LayoutTransform; + /// Applies the layout transformation to the matrices. + fn with_layout(self, layout: Self::LayoutTransform) -> Self; +} + +/// Single element. +#[derive(Copy, Clone, Debug)] +pub struct Last(pub Mat); + +/// Zipped elements. +#[derive(Copy, Clone, Debug)] +pub struct Zip(pub Head, pub Tail); + +/// Single matrix view. +#[derive(Copy, Clone, Debug)] +pub struct LastEq>(pub Mat); +/// Zipped matrix views. +#[derive(Copy, Clone, Debug)] +pub struct ZipEq< + Rows, + Cols, + Head: MatShape, + Tail: MatShape, +>(Head, Tail); + +impl< + Rows: Copy + Eq, + Cols: Copy + Eq, + Head: MatShape, + Tail: MatShape, + > ZipEq +{ + /// Creates a zipped matrix, after asserting that the dimensions match. + #[inline(always)] + pub fn new(head: Head, tail: Tail) -> Self { + assert!((head.nrows(), head.ncols()) == (tail.nrows(), tail.ncols())); + Self(head, tail) + } + + /// Creates a zipped matrix, assuming that the dimensions match. + #[inline(always)] + pub fn new_unchecked(head: Head, tail: Tail) -> Self { + debug_assert!((head.nrows(), head.ncols()) == (tail.nrows(), tail.ncols())); + Self(head, tail) + } +} + +impl> MatShape + for LastEq +{ + type Rows = Rows; + type Cols = Cols; + #[inline(always)] + fn nrows(&self) -> Self::Rows { + self.0.nrows() + } + #[inline(always)] + fn ncols(&self) -> Self::Cols { + self.0.ncols() + } +} + +impl< + Rows: Copy + Eq, + Cols: Copy + Eq, + Head: MatShape, + Tail: MatShape, + > MatShape for ZipEq +{ + type Rows = Rows; + type Cols = Cols; + #[inline(always)] + fn nrows(&self) -> Self::Rows { + self.0.nrows() + } + #[inline(always)] + fn ncols(&self) -> Self::Cols { + self.0.ncols() + } +} + +impl MatShape for ColRef<'_, E> { + type Rows = usize; + type Cols = (); + #[inline(always)] + fn nrows(&self) -> Self::Rows { + (*self).nrows() + } + #[inline(always)] + fn ncols(&self) -> Self::Cols {} +} + +impl MatShape for ColMut<'_, E> { + type Rows = usize; + type Cols = (); + #[inline(always)] + fn nrows(&self) -> Self::Rows { + (*self).nrows() + } + #[inline(always)] + fn ncols(&self) -> Self::Cols {} +} + +impl MatShape for RowRef<'_, E> { + type Rows = (); + type Cols = usize; + #[inline(always)] + fn nrows(&self) -> Self::Rows {} + #[inline(always)] + fn ncols(&self) -> Self::Cols { + (*self).ncols() + } +} +impl MatShape for RowMut<'_, E> { + type Rows = (); + type Cols = usize; + #[inline(always)] + fn nrows(&self) -> Self::Rows {} + #[inline(always)] + fn ncols(&self) -> Self::Cols { + (*self).ncols() + } +} + +impl MatShape for MatRef<'_, E> { + type Rows = usize; + type Cols = usize; + #[inline(always)] + fn nrows(&self) -> Self::Rows { + (*self).nrows() + } + #[inline(always)] + fn ncols(&self) -> Self::Cols { + (*self).ncols() + } +} + +impl MatShape for MatMut<'_, E> { + type Rows = usize; + type Cols = usize; + #[inline(always)] + fn nrows(&self) -> Self::Rows { + (*self).nrows() + } + #[inline(always)] + fn ncols(&self) -> Self::Cols { + (*self).ncols() + } +} + +unsafe impl> + MaybeContiguous for LastEq +{ + type Index = Mat::Index; + type Slice = Last; + type LayoutTransform = Mat::LayoutTransform; + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, idx: Self::Index, n_elems: usize) -> Self::Slice { + Last(self.0.get_slice_unchecked(idx, n_elems)) + } +} + +unsafe impl<'a, Rows: Copy + Eq, Cols: Copy + Eq, Mat: MatIndex<'a, Rows = Rows, Cols = Cols>> + MatIndex<'a> for LastEq +{ + type Item = Last; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, index: Self::Index) -> Self::Item { + Last(self.0.get_unchecked(index)) + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + Last(Mat::get_from_slice_unchecked(&mut slice.0, idx)) + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.0.is_contiguous() + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + self.0.preferred_layout() + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + Self(self.0.with_layout(layout)) + } +} + +unsafe impl< + Rows: Copy + Eq, + Cols: Copy + Eq, + Head: MaybeContiguous, + Tail: MaybeContiguous< + Rows = Rows, + Cols = Cols, + Index = Head::Index, + LayoutTransform = Head::LayoutTransform, + >, + > MaybeContiguous for ZipEq +{ + type Index = Head::Index; + type Slice = Zip; + type LayoutTransform = Head::LayoutTransform; + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, idx: Self::Index, n_elems: usize) -> Self::Slice { + Zip( + self.0.get_slice_unchecked(idx, n_elems), + self.1.get_slice_unchecked(idx, n_elems), + ) + } +} + +unsafe impl< + 'a, + Rows: Copy + Eq, + Cols: Copy + Eq, + Head: MatIndex<'a, Rows = Rows, Cols = Cols>, + Tail: MatIndex< + 'a, + Rows = Rows, + Cols = Cols, + Index = Head::Index, + LayoutTransform = Head::LayoutTransform, + >, + > MatIndex<'a> for ZipEq +{ + type Item = Zip; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, index: Self::Index) -> Self::Item { + Zip(self.0.get_unchecked(index), self.1.get_unchecked(index)) + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + Zip( + Head::get_from_slice_unchecked(&mut slice.0, idx), + Tail::get_from_slice_unchecked(&mut slice.1, idx), + ) + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.0.is_contiguous() && self.1.is_contiguous() + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + self.0.preferred_layout() + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + ZipEq(self.0.with_layout(layout), self.1.with_layout(layout)) + } +} + +unsafe impl MaybeContiguous for ColRef<'_, E> { + type Index = (usize, ()); + type Slice = GroupFor]>; + type LayoutTransform = VecLayoutTransform; + + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, (i, _): Self::Index, n_elems: usize) -> Self::Slice { + E::faer_map( + (*self).rb().ptr_at(i), + #[inline(always)] + |ptr| core::slice::from_raw_parts(ptr as *const MaybeUninit, n_elems), + ) + } +} +unsafe impl<'a, E: Entity> MatIndex<'a> for ColRef<'_, E> { + type Item = Read<'a, E>; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, (i, _): Self::Index) -> Self::Item { + Read { + ptr: E::faer_map( + self.rb().ptr_inbounds_at(i), + #[inline(always)] + |ptr| &*(ptr as *const MaybeUninit), + ), + } + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + let slice = E::faer_rb(E::faer_as_ref(slice)); + Read { + ptr: E::faer_map( + slice, + #[inline(always)] + |slice| slice.get_unchecked(idx), + ), + } + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.row_stride() == 1 + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + let rs = self.row_stride(); + if self.nrows() > 1 && rs == 1 { + VecLayoutTransform::None + } else if self.nrows() > 1 && rs == -1 { + VecLayoutTransform::Reverse + } else { + VecLayoutTransform::None + } + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + use VecLayoutTransform::*; + match layout { + None => self, + Reverse => self.reverse_rows(), + } + } +} + +unsafe impl MaybeContiguous for ColMut<'_, E> { + type Index = (usize, ()); + type Slice = GroupFor]>; + type LayoutTransform = VecLayoutTransform; + + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, (i, _): Self::Index, n_elems: usize) -> Self::Slice { + E::faer_map( + (*self).rb_mut().ptr_at_mut(i), + #[inline(always)] + |ptr| core::slice::from_raw_parts_mut(ptr as *mut MaybeUninit, n_elems), + ) + } +} +unsafe impl<'a, E: Entity> MatIndex<'a> for ColMut<'_, E> { + type Item = ReadWrite<'a, E>; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, (i, _): Self::Index) -> Self::Item { + ReadWrite { + ptr: E::faer_map( + self.rb_mut().ptr_inbounds_at_mut(i), + #[inline(always)] + |ptr| &mut *(ptr as *mut MaybeUninit), + ), + } + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + let slice = E::faer_rb_mut(E::faer_as_mut(slice)); + ReadWrite { + ptr: E::faer_map( + slice, + #[inline(always)] + |slice| slice.get_unchecked_mut(idx), + ), + } + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.row_stride() == 1 + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + let rs = self.row_stride(); + if self.nrows() > 1 && rs == 1 { + VecLayoutTransform::None + } else if self.nrows() > 1 && rs == -1 { + VecLayoutTransform::Reverse + } else { + VecLayoutTransform::None + } + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + use VecLayoutTransform::*; + match layout { + None => self, + Reverse => self.reverse_rows_mut(), + } + } +} + +unsafe impl MaybeContiguous for RowRef<'_, E> { + type Index = ((), usize); + type Slice = GroupFor]>; + type LayoutTransform = VecLayoutTransform; + + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, (_, j): Self::Index, n_elems: usize) -> Self::Slice { + E::faer_map( + (*self).rb().ptr_at(j), + #[inline(always)] + |ptr| core::slice::from_raw_parts(ptr as *const MaybeUninit, n_elems), + ) + } +} +unsafe impl<'a, E: Entity> MatIndex<'a> for RowRef<'_, E> { + type Item = Read<'a, E>; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, (_, j): Self::Index) -> Self::Item { + Read { + ptr: E::faer_map( + self.rb().ptr_inbounds_at(j), + #[inline(always)] + |ptr| &*(ptr as *const MaybeUninit), + ), + } + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + let slice = E::faer_rb(E::faer_as_ref(slice)); + Read { + ptr: E::faer_map( + slice, + #[inline(always)] + |slice| slice.get_unchecked(idx), + ), + } + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.col_stride() == 1 + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + let cs = self.col_stride(); + if self.ncols() > 1 && cs == 1 { + VecLayoutTransform::None + } else if self.ncols() > 1 && cs == -1 { + VecLayoutTransform::Reverse + } else { + VecLayoutTransform::None + } + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + use VecLayoutTransform::*; + match layout { + None => self, + Reverse => self.reverse_cols(), + } + } +} + +unsafe impl MaybeContiguous for RowMut<'_, E> { + type Index = ((), usize); + type Slice = GroupFor]>; + type LayoutTransform = VecLayoutTransform; + + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, (_, j): Self::Index, n_elems: usize) -> Self::Slice { + E::faer_map( + (*self).rb_mut().ptr_at_mut(j), + #[inline(always)] + |ptr| core::slice::from_raw_parts_mut(ptr as *mut MaybeUninit, n_elems), + ) + } +} +unsafe impl<'a, E: Entity> MatIndex<'a> for RowMut<'_, E> { + type Item = ReadWrite<'a, E>; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, (_, j): Self::Index) -> Self::Item { + ReadWrite { + ptr: E::faer_map( + self.rb_mut().ptr_inbounds_at_mut(j), + #[inline(always)] + |ptr| &mut *(ptr as *mut MaybeUninit), + ), + } + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + let slice = E::faer_rb_mut(E::faer_as_mut(slice)); + ReadWrite { + ptr: E::faer_map( + slice, + #[inline(always)] + |slice| slice.get_unchecked_mut(idx), + ), + } + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.col_stride() == 1 + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + let cs = self.col_stride(); + if self.ncols() > 1 && cs == 1 { + VecLayoutTransform::None + } else if self.ncols() > 1 && cs == -1 { + VecLayoutTransform::Reverse + } else { + VecLayoutTransform::None + } + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + use VecLayoutTransform::*; + match layout { + None => self, + Reverse => self.reverse_cols_mut(), + } + } +} + +unsafe impl MaybeContiguous for MatRef<'_, E> { + type Index = (usize, usize); + type Slice = GroupFor]>; + type LayoutTransform = MatLayoutTransform; + + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, (i, j): Self::Index, n_elems: usize) -> Self::Slice { + E::faer_map( + (*self).rb().overflowing_ptr_at(i, j), + #[inline(always)] + |ptr| core::slice::from_raw_parts(ptr as *const MaybeUninit, n_elems), + ) + } +} +unsafe impl<'a, E: Entity> MatIndex<'a> for MatRef<'_, E> { + type Item = Read<'a, E>; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, (i, j): Self::Index) -> Self::Item { + Read { + ptr: E::faer_map( + self.rb().ptr_inbounds_at(i, j), + #[inline(always)] + |ptr| &*(ptr as *const MaybeUninit), + ), + } + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + let slice = E::faer_rb(E::faer_as_ref(slice)); + Read { + ptr: E::faer_map( + slice, + #[inline(always)] + |slice| slice.get_unchecked(idx), + ), + } + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.row_stride() == 1 + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + let rs = self.row_stride(); + let cs = self.col_stride(); + if self.nrows() > 1 && rs == 1 { + MatLayoutTransform::None + } else if self.nrows() > 1 && rs == -1 { + MatLayoutTransform::ReverseRows + } else if self.ncols() > 1 && cs == 1 { + MatLayoutTransform::Transpose + } else if self.ncols() > 1 && cs == -1 { + MatLayoutTransform::TransposeReverseRows + } else { + MatLayoutTransform::None + } + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + use MatLayoutTransform::*; + match layout { + None => self, + ReverseRows => self.reverse_rows(), + Transpose => self.transpose(), + TransposeReverseRows => self.transpose().reverse_rows(), + } + } +} + +unsafe impl MaybeContiguous for MatMut<'_, E> { + type Index = (usize, usize); + type Slice = GroupFor]>; + type LayoutTransform = MatLayoutTransform; + + #[inline(always)] + unsafe fn get_slice_unchecked(&mut self, (i, j): Self::Index, n_elems: usize) -> Self::Slice { + E::faer_map( + (*self).rb().overflowing_ptr_at(i, j), + #[inline(always)] + |ptr| core::slice::from_raw_parts_mut(ptr as *mut MaybeUninit, n_elems), + ) + } +} + +unsafe impl<'a, E: Entity> MatIndex<'a> for MatMut<'_, E> { + type Item = ReadWrite<'a, E>; + + #[inline(always)] + unsafe fn get_unchecked(&'a mut self, (i, j): Self::Index) -> Self::Item { + ReadWrite { + ptr: E::faer_map( + self.rb_mut().ptr_inbounds_at_mut(i, j), + #[inline(always)] + |ptr| &mut *(ptr as *mut MaybeUninit), + ), + } + } + + #[inline(always)] + unsafe fn get_from_slice_unchecked(slice: &'a mut Self::Slice, idx: usize) -> Self::Item { + let slice = E::faer_rb_mut(E::faer_as_mut(slice)); + ReadWrite { + ptr: E::faer_map( + slice, + #[inline(always)] + |slice| slice.get_unchecked_mut(idx), + ), + } + } + + #[inline(always)] + fn is_contiguous(&self) -> bool { + self.row_stride() == 1 + } + #[inline(always)] + fn preferred_layout(&self) -> Self::LayoutTransform { + let rs = self.row_stride(); + let cs = self.col_stride(); + if self.nrows() > 1 && rs == 1 { + MatLayoutTransform::None + } else if self.nrows() > 1 && rs == -1 { + MatLayoutTransform::ReverseRows + } else if self.ncols() > 1 && cs == 1 { + MatLayoutTransform::Transpose + } else if self.ncols() > 1 && cs == -1 { + MatLayoutTransform::TransposeReverseRows + } else { + MatLayoutTransform::None + } + } + #[inline(always)] + fn with_layout(self, layout: Self::LayoutTransform) -> Self { + use MatLayoutTransform::*; + match layout { + None => self, + ReverseRows => self.reverse_rows_mut(), + Transpose => self.transpose_mut(), + TransposeReverseRows => self.transpose_mut().reverse_rows_mut(), + } + } +} + +#[inline(always)] +fn annotate_noalias_mat MatIndex<'a>>( + f: &mut impl for<'a> FnMut(>::Item), + mut slice: Z::Slice, + i_begin: usize, + i_end: usize, + _j: usize, +) { + for i in i_begin..i_end { + unsafe { f(Z::get_from_slice_unchecked(&mut slice, i - i_begin)) }; + } +} + +#[inline(always)] +fn annotate_noalias_col MatIndex<'a>>( + f: &mut impl for<'a> FnMut(>::Item), + mut slice: Z::Slice, + i_begin: usize, + i_end: usize, +) { + for i in i_begin..i_end { + unsafe { f(Z::get_from_slice_unchecked(&mut slice, i - i_begin)) }; + } +} + +#[inline(always)] +fn for_each_mat MatIndex<'a, Rows = usize, Cols = usize, Index = (usize, usize)>>( + z: Z, + mut f: impl for<'a> FnMut(>::Item), +) { + let layout = z.preferred_layout(); + let mut z = z.with_layout(layout); + + let m = z.nrows(); + let n = z.ncols(); + if m == 0 || n == 0 { + return; + } + + unsafe { + if z.is_contiguous() { + for j in 0..n { + annotate_noalias_mat::(&mut f, z.get_slice_unchecked((0, j), m), 0, m, j); + } + } else { + for j in 0..n { + for i in 0..m { + f(z.get_unchecked((i, j))) + } + } + } + } +} + +#[inline(always)] +fn for_each_mat_triangular_lower< + Z: for<'a> MatIndex< + 'a, + Rows = usize, + Cols = usize, + Index = (usize, usize), + LayoutTransform = MatLayoutTransform, + >, +>( + z: Z, + diag: Diag, + transpose: bool, + mut f: impl for<'a> FnMut(>::Item), +) { + use MatLayoutTransform::*; + + let z = if transpose { + z.with_layout(MatLayoutTransform::Transpose) + } else { + z + }; + let layout = z.preferred_layout(); + let mut z = z.with_layout(layout); + + let m = z.nrows(); + let n = z.ncols(); + let n = match layout { + None | ReverseRows => Ord::min(m, n), + Transpose | TransposeReverseRows => n, + }; + if m == 0 || n == 0 { + return; + } + + let strict = match diag { + Diag::Skip => true, + Diag::Include => false, + }; + + unsafe { + if z.is_contiguous() { + for j in 0..n { + let (start, end) = match layout { + None => (j + strict as usize, m), + ReverseRows => (0, (m - (j + strict as usize))), + Transpose => (0, (j + !strict as usize).min(m)), + TransposeReverseRows => (m - ((j + !strict as usize).min(m)), m), + }; + + let len = end - start; + + annotate_noalias_mat::( + &mut f, + z.get_slice_unchecked((start, j), len), + start, + end, + j, + ); + } + } else { + for j in 0..n { + let (start, end) = match layout { + None => (j + strict as usize, m), + ReverseRows => (0, (m - (j + strict as usize))), + Transpose => (0, (j + !strict as usize).min(m)), + TransposeReverseRows => (m - ((j + !strict as usize).min(m)), m), + }; + + for i in start..end { + f(z.get_unchecked((i, j))) + } + } + } + } +} + +#[inline(always)] +fn for_each_col MatIndex<'a, Rows = usize, Cols = (), Index = (usize, ())>>( + z: Z, + mut f: impl for<'a> FnMut(>::Item), +) { + let layout = z.preferred_layout(); + let mut z = z.with_layout(layout); + + let m = z.nrows(); + if m == 0 { + return; + } + + unsafe { + if z.is_contiguous() { + annotate_noalias_col::(&mut f, z.get_slice_unchecked((0, ()), m), 0, m); + } else { + for i in 0..m { + f(z.get_unchecked((i, ()))) + } + } + } +} + +#[inline(always)] +fn for_each_row MatIndex<'a, Rows = (), Cols = usize, Index = ((), usize)>>( + z: Z, + mut f: impl for<'a> FnMut(>::Item), +) { + let layout = z.preferred_layout(); + let mut z = z.with_layout(layout); + + let n = z.ncols(); + if n == 0 { + return; + } + + unsafe { + if z.is_contiguous() { + annotate_noalias_col::(&mut f, z.get_slice_unchecked(((), 0), n), 0, n); + } else { + for j in 0..n { + f(z.get_unchecked(((), j))) + } + } + } +} + +impl< + M: for<'a> MatIndex< + 'a, + Rows = usize, + Cols = usize, + Index = (usize, usize), + LayoutTransform = MatLayoutTransform, + >, + > LastEq +{ + /// Applies `f` to each element of `self`. + #[inline(always)] + pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { + for_each_mat(self, f); + } + + /// Applies `f` to each element of the lower triangular half of `self`. + /// + /// `diag` specifies whether the diagonal should be included or excluded. + #[inline(always)] + pub fn for_each_triangular_lower( + self, + diag: Diag, + f: impl for<'a> FnMut(>::Item), + ) { + for_each_mat_triangular_lower(self, diag, false, f); + } + + /// Applies `f` to each element of the upper triangular half of `self`. + /// + /// `diag` specifies whether the diagonal should be included or excluded. + #[inline(always)] + pub fn for_each_triangular_upper( + self, + diag: Diag, + f: impl for<'a> FnMut(>::Item), + ) { + for_each_mat_triangular_lower(self, diag, true, f); + } + + /// Applies `f` to each element of `self` and collect its result into a new matrix. + #[inline(always)] + pub fn map( + self, + f: impl for<'a> FnMut(>::Item) -> E, + ) -> Mat { + let (m, n) = (self.nrows(), self.ncols()); + let mut out = Mat::::with_capacity(m, n); + let rs = 1; + let cs = out.col_stride(); + let out_view = unsafe { mat::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, n, rs, cs) }; + let mut f = f; + ZipEq::new(out_view, self).for_each( + #[inline(always)] + |Zip(mut out, item)| out.write(f(item)), + ); + unsafe { out.set_dims(m, n) }; + out + } +} + +impl< + M: for<'a> MatIndex< + 'a, + Rows = (), + Cols = usize, + Index = ((), usize), + LayoutTransform = VecLayoutTransform, + >, + > LastEq<(), usize, M> +{ + /// Applies `f` to each element of `self`. + #[inline(always)] + pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { + for_each_row(self, f); + } + + /// Applies `f` to each element of `self` and collect its result into a new row. + #[inline(always)] + pub fn map( + self, + f: impl for<'a> FnMut(>::Item) -> E, + ) -> Row { + let (_, n) = (self.nrows(), self.ncols()); + let mut out = Row::::with_capacity(n); + let out_view = unsafe { row::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), n, 1) }; + let mut f = f; + ZipEq::new(out_view, self).for_each( + #[inline(always)] + |Zip(mut out, item)| out.write(f(item)), + ); + unsafe { out.set_ncols(n) }; + out + } +} + +impl< + M: for<'a> MatIndex< + 'a, + Rows = usize, + Cols = (), + Index = (usize, ()), + LayoutTransform = VecLayoutTransform, + >, + > LastEq +{ + /// Applies `f` to each element of `self`. + #[inline(always)] + pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { + for_each_col(self, f); + } + + /// Applies `f` to each element of `self` and collect its result into a new column. + #[inline(always)] + pub fn map( + self, + f: impl for<'a> FnMut(>::Item) -> E, + ) -> Col { + let (m, _) = (self.nrows(), self.ncols()); + let mut out = Col::::with_capacity(m); + let out_view = unsafe { col::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, 1) }; + let mut f = f; + ZipEq::new(out_view, self).for_each( + #[inline(always)] + |Zip(mut out, item)| out.write(f(item)), + ); + unsafe { out.set_nrows(m) }; + out + } +} + +impl< + Head: for<'a> MatIndex< + 'a, + Rows = (), + Cols = usize, + Index = ((), usize), + LayoutTransform = VecLayoutTransform, + >, + Tail: for<'a> MatIndex< + 'a, + Rows = (), + Cols = usize, + Index = ((), usize), + LayoutTransform = VecLayoutTransform, + >, + > ZipEq<(), usize, Head, Tail> +{ + /// Applies `f` to each element of `self`. + #[inline(always)] + pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { + for_each_row(self, f); + } + + /// Applies `f` to each element of `self` and collect its result into a new row. + #[inline(always)] + pub fn map( + self, + f: impl for<'a> FnMut(>::Item) -> E, + ) -> Row { + let (_, n) = (self.nrows(), self.ncols()); + let mut out = Row::::with_capacity(n); + let out_view = unsafe { row::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), n, 1) }; + let mut f = f; + ZipEq::new(out_view, self).for_each( + #[inline(always)] + |Zip(mut out, item)| out.write(f(item)), + ); + unsafe { out.set_ncols(n) }; + out + } +} + +impl< + Head: for<'a> MatIndex< + 'a, + Rows = usize, + Cols = (), + Index = (usize, ()), + LayoutTransform = VecLayoutTransform, + >, + Tail: for<'a> MatIndex< + 'a, + Rows = usize, + Cols = (), + Index = (usize, ()), + LayoutTransform = VecLayoutTransform, + >, + > ZipEq +{ + /// Applies `f` to each element of `self`. + #[inline(always)] + pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { + for_each_col(self, f); + } + + /// Applies `f` to each element of `self` and collect its result into a new column. + #[inline(always)] + pub fn map( + self, + f: impl for<'a> FnMut(>::Item) -> E, + ) -> Col { + let (m, _) = (self.nrows(), self.ncols()); + let mut out = Col::::with_capacity(m); + let out_view = unsafe { col::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, 1) }; + let mut f = f; + ZipEq::new(out_view, self).for_each( + #[inline(always)] + |Zip(mut out, item)| out.write(f(item)), + ); + unsafe { out.set_nrows(m) }; + out + } +} + +impl< + Head: for<'a> MatIndex< + 'a, + Rows = usize, + Cols = usize, + Index = (usize, usize), + LayoutTransform = MatLayoutTransform, + >, + Tail: for<'a> MatIndex< + 'a, + Rows = usize, + Cols = usize, + Index = (usize, usize), + LayoutTransform = MatLayoutTransform, + >, + > ZipEq +{ + /// Applies `f` to each element of `self`. + #[inline(always)] + pub fn for_each(self, f: impl for<'a> FnMut(>::Item)) { + for_each_mat(self, f); + } + + /// Applies `f` to each element of the lower triangular half of `self`. + /// + /// `diag` specifies whether the diagonal should be included or excluded. + #[inline(always)] + pub fn for_each_triangular_lower( + self, + diag: Diag, + f: impl for<'a> FnMut(>::Item), + ) { + for_each_mat_triangular_lower(self, diag, false, f); + } + + /// Applies `f` to each element of the upper triangular half of `self`. + /// + /// `diag` specifies whether the diagonal should be included or excluded. + #[inline(always)] + pub fn for_each_triangular_upper( + self, + diag: Diag, + f: impl for<'a> FnMut(>::Item), + ) { + for_each_mat_triangular_lower(self, diag, true, f); + } + + /// Applies `f` to each element of `self` and collect its result into a new matrix. + #[inline(always)] + pub fn map( + self, + f: impl for<'a> FnMut(>::Item) -> E, + ) -> Mat { + let (m, n) = (self.nrows(), self.ncols()); + let mut out = Mat::::with_capacity(m, n); + let rs = 1; + let cs = out.col_stride(); + let out_view = unsafe { mat::from_raw_parts_mut::<'_, E>(out.as_ptr_mut(), m, n, rs, cs) }; + let mut f = f; + ZipEq::new(out_view, self).for_each( + #[inline(always)] + |Zip(mut out, item)| out.write(f(item)), + ); + unsafe { out.set_dims(m, n) }; + out + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{assert, mat::Mat, unzipped, zipped, ComplexField}; + + #[test] + fn test_zip() { + for (m, n) in [(2, 2), (4, 2), (2, 4)] { + for rev_dst in [false, true] { + for rev_src in [false, true] { + for transpose_dst in [false, true] { + for transpose_src in [false, true] { + for diag in [Diag::Include, Diag::Skip] { + let mut dst = Mat::from_fn( + if transpose_dst { n } else { m }, + if transpose_dst { m } else { n }, + |_, _| f64::faer_zero(), + ); + let src = Mat::from_fn( + if transpose_src { n } else { m }, + if transpose_src { m } else { n }, + |_, _| f64::faer_one(), + ); + + let mut target = Mat::from_fn(m, n, |_, _| f64::faer_zero()); + let target_src = Mat::from_fn(m, n, |_, _| f64::faer_one()); + + zipped!(target.as_mut(), target_src.as_ref()) + .for_each_triangular_lower(diag, |unzipped!(mut dst, src)| { + dst.write(src.read()) + }); + + let mut dst = dst.as_mut(); + let mut src = src.as_ref(); + + if transpose_dst { + dst = dst.transpose_mut(); + } + if rev_dst { + dst = dst.reverse_rows_mut(); + } + + if transpose_src { + src = src.transpose(); + } + if rev_src { + src = src.reverse_rows(); + } + + zipped!(dst.rb_mut(), src) + .for_each_triangular_lower(diag, |unzipped!(mut dst, src)| { + dst.write(src.read()) + }); + + assert!(dst.rb() == target.as_ref()); + } + } + } + } + } + } + + { + let m = 3; + for rev_dst in [false, true] { + for rev_src in [false, true] { + let mut dst = Col::::zeros(m); + let src = Col::from_fn(m, |i| (i + 1) as f64); + + let mut target = Col::::zeros(m); + let target_src = + Col::from_fn(m, |i| if rev_src { m - i } else { i + 1 } as f64); + + zipped!(target.as_mut(), target_src.as_ref()) + .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); + + let mut dst = dst.as_mut(); + let mut src = src.as_ref(); + + if rev_dst { + dst = dst.reverse_rows_mut(); + } + if rev_src { + src = src.reverse_rows(); + } + + zipped!(dst.rb_mut(), src) + .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); + + assert!(dst.rb() == target.as_ref()); + } + } + } + + { + let m = 3; + for rev_dst in [false, true] { + for rev_src in [false, true] { + let mut dst = Row::::zeros(m); + let src = Row::from_fn(m, |i| (i + 1) as f64); + + let mut target = Row::::zeros(m); + let target_src = + Row::from_fn(m, |i| if rev_src { m - i } else { i + 1 } as f64); + + zipped!(target.as_mut(), target_src.as_ref()) + .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); + + let mut dst = dst.as_mut(); + let mut src = src.as_ref(); + + if rev_dst { + dst = dst.reverse_cols_mut(); + } + if rev_src { + src = src.reverse_cols(); + } + + zipped!(&mut dst, src) + .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); + + assert!(dst.rb() == target.as_ref()); + } + } + } + } +} diff --git a/src/mat/mat_index.rs b/src/mat/mat_index.rs new file mode 100644 index 0000000000000000000000000000000000000000..e70146ea24f291f6252220f203878f8da46f56ee --- /dev/null +++ b/src/mat/mat_index.rs @@ -0,0 +1,502 @@ +use super::*; +use crate::{ + assert, + col::{ColMut, ColRef}, + row::{RowMut, RowRef}, +}; + +// RangeFull +// Range +// RangeInclusive +// RangeTo +// RangeToInclusive +// usize + +use core::ops::RangeFull; +type Range = core::ops::Range; +type RangeInclusive = core::ops::RangeInclusive; +type RangeFrom = core::ops::RangeFrom; +type RangeTo = core::ops::RangeTo; +type RangeToInclusive = core::ops::RangeToInclusive; + +impl MatIndex for MatRef<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeFrom, + ) -> >::Target { + let ncols = this.ncols(); + >::get(this, row, col.start..ncols) + } +} +impl MatIndex for MatRef<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RowRange, col: RangeTo) -> >::Target { + >::get(this, row, 0..col.end) + } +} +impl MatIndex for MatRef<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeToInclusive, + ) -> >::Target { + assert!(col.end != usize::MAX); + >::get(this, row, 0..col.end + 1) + } +} +impl MatIndex for MatRef<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeInclusive, + ) -> >::Target { + assert!(*col.end() != usize::MAX); + >::get(this, row, *col.start()..*col.end() + 1) + } +} +impl MatIndex for MatRef<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeFull, + ) -> >::Target { + let _ = col; + let ncols = this.ncols(); + >::get(this, row, 0..ncols) + } +} + +impl MatIndex for MatRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFull, col: Range) -> Self { + let _ = row; + assert!(col.start <= col.end); + this.subcols(col.start, col.end - col.start) + } +} +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = ColRef<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFull, col: usize) -> Self::Target { + let _ = row; + this.col(col) + } +} + +impl MatIndex for MatRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: Range, col: Range) -> Self { + assert!(all(row.start <= row.end, col.start <= col.end)); + this.submatrix( + row.start, + col.start, + row.end - row.start, + col.end - col.start, + ) + } +} +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = ColRef<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: Range, col: usize) -> Self::Target { + assert!(row.start <= row.end); + this.submatrix(row.start, col, row.end - row.start, 1) + .col(0) + } +} + +impl MatIndex for MatRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeInclusive, col: Range) -> Self { + assert!(*row.end() != usize::MAX); + >::get(this, *row.start()..*row.end() + 1, col) + } +} +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = ColRef<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeInclusive, col: usize) -> Self::Target { + assert!(*row.end() != usize::MAX); + >::get(this, *row.start()..*row.end() + 1, col) + } +} + +impl MatIndex for MatRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFrom, col: Range) -> Self { + let nrows = this.nrows(); + >::get(this, row.start..nrows, col) + } +} +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = ColRef<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFrom, col: usize) -> Self::Target { + let nrows = this.nrows(); + >::get(this, row.start..nrows, col) + } +} +impl MatIndex for MatRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeTo, col: Range) -> Self { + >::get(this, 0..row.end, col) + } +} +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = ColRef<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeTo, col: usize) -> Self::Target { + >::get(this, 0..row.end, col) + } +} + +impl MatIndex for MatRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeToInclusive, col: Range) -> Self { + assert!(row.end != usize::MAX); + >::get(this, 0..row.end + 1, col) + } +} +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = ColRef<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeToInclusive, col: usize) -> Self::Target { + assert!(row.end != usize::MAX); + >::get(this, 0..row.end + 1, col) + } +} + +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = RowRef<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: usize, col: Range) -> Self::Target { + assert!(col.start <= col.end); + this.submatrix(row, col.start, 1, col.end - col.start) + .row(0) + } +} + +impl MatIndex for MatMut<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeFrom, + ) -> >::Target { + let ncols = this.ncols(); + >::get(this, row, col.start..ncols) + } +} +impl MatIndex for MatMut<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RowRange, col: RangeTo) -> >::Target { + >::get(this, row, 0..col.end) + } +} +impl MatIndex for MatMut<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeToInclusive, + ) -> >::Target { + assert!(col.end != usize::MAX); + >::get(this, row, 0..col.end + 1) + } +} +impl MatIndex for MatMut<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeInclusive, + ) -> >::Target { + assert!(*col.end() != usize::MAX); + >::get(this, row, *col.start()..*col.end() + 1) + } +} +impl MatIndex for MatMut<'_, E> +where + Self: MatIndex, +{ + type Target = >::Target; + + #[track_caller] + #[inline(always)] + fn get( + this: Self, + row: RowRange, + col: RangeFull, + ) -> >::Target { + let _ = col; + let ncols = this.ncols(); + >::get(this, row, 0..ncols) + } +} + +impl MatIndex for MatMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFull, col: Range) -> Self { + let _ = row; + assert!(col.start <= col.end); + this.subcols_mut(col.start, col.end - col.start) + } +} +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = ColMut<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFull, col: usize) -> Self::Target { + let _ = row; + this.col_mut(col) + } +} + +impl MatIndex for MatMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: Range, col: Range) -> Self { + assert!(all(row.start <= row.end, col.start <= col.end)); + this.submatrix_mut( + row.start, + col.start, + row.end - row.start, + col.end - col.start, + ) + } +} +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = ColMut<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: Range, col: usize) -> Self::Target { + assert!(row.start <= row.end); + this.submatrix_mut(row.start, col, row.end - row.start, 1) + .col_mut(0) + } +} + +impl MatIndex for MatMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeInclusive, col: Range) -> Self { + assert!(*row.end() != usize::MAX); + >::get(this, *row.start()..*row.end() + 1, col) + } +} +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = ColMut<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeInclusive, col: usize) -> Self::Target { + assert!(*row.end() != usize::MAX); + >::get(this, *row.start()..*row.end() + 1, col) + } +} + +impl MatIndex for MatMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFrom, col: Range) -> Self { + let nrows = this.nrows(); + >::get(this, row.start..nrows, col) + } +} +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = ColMut<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeFrom, col: usize) -> Self::Target { + let nrows = this.nrows(); + >::get(this, row.start..nrows, col) + } +} +impl MatIndex for MatMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeTo, col: Range) -> Self { + >::get(this, 0..row.end, col) + } +} +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = ColMut<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeTo, col: usize) -> Self::Target { + >::get(this, 0..row.end, col) + } +} + +impl MatIndex for MatMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeToInclusive, col: Range) -> Self { + assert!(row.end != usize::MAX); + >::get(this, 0..row.end + 1, col) + } +} +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = ColMut<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: RangeToInclusive, col: usize) -> Self::Target { + assert!(row.end != usize::MAX); + >::get(this, 0..row.end + 1, col) + } +} + +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = RowMut<'a, E>; + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: usize, col: Range) -> Self::Target { + assert!(col.start <= col.end); + this.submatrix_mut(row, col.start, 1, col.end - col.start) + .row_mut(0) + } +} + +impl<'a, E: Entity> MatIndex for MatRef<'a, E> { + type Target = GroupFor; + + #[track_caller] + #[inline(always)] + unsafe fn get_unchecked(this: Self, row: usize, col: usize) -> Self::Target { + unsafe { E::faer_map(this.ptr_inbounds_at(row, col), |ptr| &*ptr) } + } + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: usize, col: usize) -> Self::Target { + assert!(all(row < this.nrows(), col < this.ncols())); + unsafe { >::get_unchecked(this, row, col) } + } +} + +impl<'a, E: Entity> MatIndex for MatMut<'a, E> { + type Target = GroupFor; + + #[track_caller] + #[inline(always)] + unsafe fn get_unchecked(this: Self, row: usize, col: usize) -> Self::Target { + unsafe { E::faer_map(this.ptr_inbounds_at_mut(row, col), |ptr| &mut *ptr) } + } + + #[track_caller] + #[inline(always)] + fn get(this: Self, row: usize, col: usize) -> Self::Target { + assert!(all(row < this.nrows(), col < this.ncols())); + unsafe { >::get_unchecked(this, row, col) } + } +} diff --git a/src/mat/matalloc.rs b/src/mat/matalloc.rs new file mode 100644 index 0000000000000000000000000000000000000000..8d71185ab2d796af0ab681b9294fa9021437911e --- /dev/null +++ b/src/mat/matalloc.rs @@ -0,0 +1,285 @@ +use super::*; +use crate::complex_native::*; +use core::mem::ManuallyDrop; + +#[repr(C)] +pub struct RawMatUnit { + pub(crate) ptr: NonNull, + pub(crate) row_capacity: usize, + pub(crate) col_capacity: usize, +} + +#[repr(C)] +pub(crate) struct MatUnit { + pub(crate) raw: RawMatUnit, + pub(crate) nrows: usize, + pub(crate) ncols: usize, +} + +impl RawMatUnit { + pub fn new(row_capacity: usize, col_capacity: usize) -> Self { + let dangling = NonNull::::dangling(); + if core::mem::size_of::() == 0 { + Self { + ptr: dangling, + row_capacity, + col_capacity, + } + } else { + let cap = row_capacity + .checked_mul(col_capacity) + .unwrap_or_else(capacity_overflow); + let cap_bytes = cap + .checked_mul(core::mem::size_of::()) + .unwrap_or_else(capacity_overflow); + if cap_bytes > isize::MAX as usize { + capacity_overflow::<()>(); + } + + use alloc::alloc::{alloc, handle_alloc_error, Layout}; + + let layout = Layout::from_size_align(cap_bytes, align_for::()) + .ok() + .unwrap_or_else(capacity_overflow); + + let ptr = if layout.size() == 0 { + dangling + } else { + // SAFETY: we checked that layout has non zero size + let ptr = unsafe { alloc(layout) } as *mut T; + if ptr.is_null() { + handle_alloc_error(layout) + } + // SAFETY: we checked that the pointer is not null + unsafe { NonNull::::new_unchecked(ptr) } + }; + + Self { + ptr, + row_capacity, + col_capacity, + } + } + } +} + +impl Drop for RawMatUnit { + fn drop(&mut self) { + use alloc::alloc::{dealloc, Layout}; + // this cannot overflow because we already allocated this much memory + // self.row_capacity.wrapping_mul(self.col_capacity) may overflow if T is a zst + // but that's fine since we immediately multiply it by 0. + let alloc_size = + self.row_capacity.wrapping_mul(self.col_capacity) * core::mem::size_of::(); + if alloc_size != 0 { + // SAFETY: pointer was allocated with alloc::alloc::alloc + unsafe { + dealloc( + self.ptr.as_ptr() as *mut u8, + Layout::from_size_align_unchecked(alloc_size, align_for::()), + ); + } + } + } +} + +#[repr(C)] +pub struct RawMat { + pub(crate) ptr: GroupCopyFor>, + pub(crate) row_capacity: usize, + pub(crate) col_capacity: usize, +} + +impl RawMat { + pub fn new(row_capacity: usize, col_capacity: usize) -> Self { + // allocate the unit matrices + let group = E::faer_map(E::UNIT, |()| { + RawMatUnit::::new(row_capacity, col_capacity) + }); + + let group = E::faer_map(group, core::mem::ManuallyDrop::new); + + Self { + ptr: into_copy::(E::faer_map(group, |mat| mat.ptr)), + row_capacity, + col_capacity, + } + } +} + +impl Drop for RawMat { + fn drop(&mut self) { + drop(E::faer_map(from_copy::(self.ptr), |ptr| RawMatUnit { + ptr, + row_capacity: self.row_capacity, + col_capacity: self.col_capacity, + })); + } +} + +impl MatUnit { + #[cold] + pub fn do_reserve_exact(&mut self, mut new_row_capacity: usize, mut new_col_capacity: usize) { + new_row_capacity = self.raw.row_capacity.max(new_row_capacity); + new_col_capacity = self.raw.col_capacity.max(new_col_capacity); + + let new_ptr = if self.raw.row_capacity == new_row_capacity + && self.raw.row_capacity != 0 + && self.raw.col_capacity != 0 + { + // case 1: + // we have enough row capacity, and we've already allocated memory. + // use realloc to get extra column memory + + use alloc::alloc::{handle_alloc_error, realloc, Layout}; + + // this shouldn't overflow since we already hold this many bytes + let old_cap = self.raw.row_capacity * self.raw.col_capacity; + let old_cap_bytes = old_cap * core::mem::size_of::(); + + let new_cap = new_row_capacity + .checked_mul(new_col_capacity) + .unwrap_or_else(capacity_overflow); + let new_cap_bytes = new_cap + .checked_mul(core::mem::size_of::()) + .unwrap_or_else(capacity_overflow); + + if new_cap_bytes > isize::MAX as usize { + capacity_overflow::<()>(); + } + + // SAFETY: this shouldn't overflow since we already checked that it's valid during + // allocation + let old_layout = + unsafe { Layout::from_size_align_unchecked(old_cap_bytes, align_for::()) }; + let new_layout = Layout::from_size_align(new_cap_bytes, align_for::()) + .ok() + .unwrap_or_else(capacity_overflow); + + // SAFETY: + // * old_ptr is non null and is the return value of some previous call to alloc + // * old_layout is the same layout that was used to provide the old allocation + // * new_cap_bytes is non zero since new_row_capacity and new_col_capacity are larger + // than self.raw.row_capacity and self.raw.col_capacity respectively, and the computed + // product doesn't overflow. + // * new_cap_bytes, when rounded up to the nearest multiple of the alignment does not + // overflow, since we checked that we can create new_layout with it. + unsafe { + let old_ptr = self.raw.ptr.as_ptr(); + let new_ptr = realloc(old_ptr as *mut u8, old_layout, new_cap_bytes); + if new_ptr.is_null() { + handle_alloc_error(new_layout); + } + new_ptr as *mut T + } + } else { + // case 2: + // use alloc and move stuff manually. + + // allocate new memory region + let new_ptr = { + let m = ManuallyDrop::new(RawMatUnit::::new(new_row_capacity, new_col_capacity)); + m.ptr.as_ptr() + }; + + let old_ptr = self.raw.ptr.as_ptr(); + + // copy each column to new matrix + for j in 0..self.ncols { + // SAFETY: + // * pointer offsets can't overflow since they're within an already allocated + // memory region less than isize::MAX bytes in size. + // * new and old allocation can't overlap, so copy_nonoverlapping is fine here. + unsafe { + let old_ptr = old_ptr.add(j * self.raw.row_capacity); + let new_ptr = new_ptr.add(j * new_row_capacity); + core::ptr::copy_nonoverlapping(old_ptr, new_ptr, self.nrows); + } + } + + // deallocate old matrix memory + let _ = RawMatUnit:: { + // SAFETY: this ptr was checked to be non null, or was acquired from a NonNull + // pointer. + ptr: unsafe { NonNull::new_unchecked(old_ptr) }, + row_capacity: self.raw.row_capacity, + col_capacity: self.raw.col_capacity, + }; + + new_ptr + }; + self.raw.row_capacity = new_row_capacity; + self.raw.col_capacity = new_col_capacity; + self.raw.ptr = unsafe { NonNull::::new_unchecked(new_ptr) }; + } +} + +#[cold] +fn capacity_overflow_impl() -> ! { + panic!("capacity overflow") +} + +#[inline(always)] +fn capacity_overflow() -> T { + capacity_overflow_impl(); +} + +#[inline(always)] +pub fn is_vectorizable() -> bool { + coe::is_same::() + || coe::is_same::() + || coe::is_same::() + || coe::is_same::() + || coe::is_same::() + || coe::is_same::() +} + +// https://rust-lang.github.io/hashbrown/src/crossbeam_utils/cache_padded.rs.html#128-130 +pub const CACHELINE_ALIGN: usize = { + #[cfg(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "powerpc64", + ))] + { + 128 + } + #[cfg(any( + target_arch = "arm", + target_arch = "mips", + target_arch = "mips64", + target_arch = "riscv64", + ))] + { + 32 + } + #[cfg(target_arch = "s390x")] + { + 256 + } + #[cfg(not(any( + target_arch = "x86_64", + target_arch = "aarch64", + target_arch = "powerpc64", + target_arch = "arm", + target_arch = "mips", + target_arch = "mips64", + target_arch = "riscv64", + target_arch = "s390x", + )))] + { + 64 + } +}; + +#[inline(always)] +pub fn align_for() -> usize { + if is_vectorizable::() { + Ord::max( + core::mem::size_of::(), + Ord::max(core::mem::align_of::(), CACHELINE_ALIGN), + ) + } else { + core::mem::align_of::() + } +} diff --git a/src/mat/matmut.rs b/src/mat/matmut.rs new file mode 100644 index 0000000000000000000000000000000000000000..c4b8cf23bf279bb833928621635fe58f31dbb944 --- /dev/null +++ b/src/mat/matmut.rs @@ -0,0 +1,1289 @@ +use super::*; +use crate::{assert, debug_assert, diag::DiagMut, linalg::zip, unzipped, zipped}; + +/// Mutable view over a matrix, similar to a mutable reference to a 2D strided [prim@slice]. +/// +/// # Note +/// +/// Unlike a slice, the data pointed to by `MatMut<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions. In this case, care must be taken to not perform any +/// operations that read the uninitialized values, or form references to them, either directly +/// through [`MatMut::read`], or indirectly through any of the numerical library routines, unless +/// it is explicitly permitted. +/// +/// # Move semantics +/// Since `MatMut` mutably borrows data, it cannot be [`Copy`]. This means that if we pass a +/// `MatMut` to a function that takes it by value, or use a method that consumes `self` like +/// [`MatMut::transpose_mut`], this renders the original variable unusable. +/// ```compile_fail +/// use faer::{Mat, MatMut}; +/// +/// fn takes_matmut(view: MatMut<'_, f64>) {} +/// +/// let mut matrix = Mat::new(); +/// let view = matrix.as_mut(); +/// +/// takes_matmut(view); // `view` is moved (passed by value) +/// takes_matmut(view); // this fails to compile since `view` was moved +/// ``` +/// The way to get around it is to use the [`reborrow::ReborrowMut`] trait, which allows us to +/// mutably borrow a `MatMut` to obtain another `MatMut` for the lifetime of the borrow. +/// It's also similarly possible to immutably borrow a `MatMut` to obtain a `MatRef` for the +/// lifetime of the borrow, using [`reborrow::Reborrow`]. +/// ``` +/// use faer::{Mat, MatMut, MatRef}; +/// use reborrow::*; +/// +/// fn takes_matmut(view: MatMut<'_, f64>) {} +/// fn takes_matref(view: MatRef<'_, f64>) {} +/// +/// let mut matrix = Mat::new(); +/// let mut view = matrix.as_mut(); +/// +/// takes_matmut(view.rb_mut()); +/// takes_matmut(view.rb_mut()); +/// takes_matref(view.rb()); +/// // view is still usable here +/// ``` +#[repr(C)] +pub struct MatMut<'a, E: Entity> { + pub(super) inner: MatImpl, + pub(super) __marker: PhantomData<&'a E>, +} + +impl<'short, E: Entity> Reborrow<'short> for MatMut<'_, E> { + type Target = MatRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + MatRef { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for MatMut<'_, E> { + type Target = MatMut<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + MatMut { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'a, E: Entity> IntoConst for MatMut<'a, E> { + type Target = MatRef<'a, E>; + + #[inline] + fn into_const(self) -> Self::Target { + MatRef { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'a, E: Entity> MatMut<'a, E> { + #[inline] + pub(crate) unsafe fn __from_raw_parts( + ptr: GroupFor, + nrows: usize, + ncols: usize, + row_stride: isize, + col_stride: isize, + ) -> Self { + Self { + inner: MatImpl { + ptr: into_copy::(E::faer_map( + ptr, + #[inline] + |ptr| NonNull::new_unchecked(ptr), + )), + nrows, + ncols, + row_stride, + col_stride, + }, + __marker: PhantomData, + } + } + + #[track_caller] + #[inline(always)] + #[doc(hidden)] + pub fn try_get_contiguous_col_mut(self, j: usize) -> GroupFor { + assert!(self.row_stride() == 1); + let col = self.col_mut(j); + if col.nrows() == 0 { + E::faer_map( + E::UNIT, + #[inline(always)] + |()| &mut [] as &mut [E::Unit], + ) + } else { + let m = col.nrows(); + E::faer_map( + col.as_ptr_mut(), + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, m) }, + ) + } + } + + /// Returns the number of rows of the matrix. + #[inline(always)] + pub fn nrows(&self) -> usize { + self.inner.nrows + } + /// Returns the number of columns of the matrix. + #[inline(always)] + pub fn ncols(&self) -> usize { + self.inner.ncols + } + + /// Returns pointers to the matrix data. + #[inline(always)] + pub fn as_ptr_mut(self) -> GroupFor { + E::faer_map( + from_copy::(self.inner.ptr), + #[inline(always)] + |ptr| ptr.as_ptr(), + ) + } + + /// Returns the row stride of the matrix, specified in number of elements, not in bytes. + #[inline(always)] + pub fn row_stride(&self) -> isize { + self.inner.row_stride + } + + /// Returns the column stride of the matrix, specified in number of elements, not in bytes. + #[inline(always)] + pub fn col_stride(&self) -> isize { + self.inner.col_stride + } + + /// Returns raw pointers to the element at the given indices. + #[inline(always)] + pub fn ptr_at_mut(self, row: usize, col: usize) -> GroupFor { + let offset = ((row as isize).wrapping_mul(self.inner.row_stride)) + .wrapping_add((col as isize).wrapping_mul(self.inner.col_stride)); + E::faer_map( + self.as_ptr_mut(), + #[inline(always)] + |ptr| ptr.wrapping_offset(offset), + ) + } + + #[inline(always)] + unsafe fn ptr_at_mut_unchecked(self, row: usize, col: usize) -> GroupFor { + let offset = crate::utils::unchecked_add( + crate::utils::unchecked_mul(row, self.inner.row_stride), + crate::utils::unchecked_mul(col, self.inner.col_stride), + ); + E::faer_map( + self.as_ptr_mut(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + + /// Returns raw pointers to the element at the given indices, assuming the provided indices + /// are within the matrix dimensions. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn ptr_inbounds_at_mut(self, row: usize, col: usize) -> GroupFor { + debug_assert!(all(row < self.nrows(), col < self.ncols())); + self.ptr_at_mut_unchecked(row, col) + } + + /// Splits the matrix horizontally and vertically at the given indices into four corners and + /// returns an array of each submatrix, in the following order: + /// * top left. + /// * top right. + /// * bottom left. + /// * bottom right. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row <= self.nrows()`. + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_mut_unchecked(self, row: usize, col: usize) -> (Self, Self, Self, Self) { + let (top_left, top_right, bot_left, bot_right) = + self.into_const().split_at_unchecked(row, col); + ( + top_left.const_cast(), + top_right.const_cast(), + bot_left.const_cast(), + bot_right.const_cast(), + ) + } + + /// Splits the matrix horizontally and vertically at the given indices into four corners and + /// returns an array of each submatrix, in the following order: + /// * top left. + /// * top right. + /// * bottom left. + /// * bottom right. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row <= self.nrows()`. + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn split_at_mut(self, row: usize, col: usize) -> (Self, Self, Self, Self) { + let (top_left, top_right, bot_left, bot_right) = self.into_const().split_at(row, col); + unsafe { + ( + top_left.const_cast(), + top_right.const_cast(), + bot_left.const_cast(), + bot_right.const_cast(), + ) + } + } + + /// Splits the matrix horizontally at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * top. + /// * bottom. + /// + /// # Safety + /// The behavior is undefined if the following condition is violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_row_mut_unchecked(self, row: usize) -> (Self, Self) { + let (top, bot) = self.into_const().split_at_row_unchecked(row); + (top.const_cast(), bot.const_cast()) + } + + /// Splits the matrix horizontally at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * top. + /// * bottom. + /// + /// # Panics + /// The function panics if the following condition is violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn split_at_row_mut(self, row: usize) -> (Self, Self) { + let (top, bot) = self.into_const().split_at_row(row); + unsafe { (top.const_cast(), bot.const_cast()) } + } + + /// Splits the matrix vertically at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * left. + /// * right. + /// + /// # Safety + /// The behavior is undefined if the following condition is violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_col_mut_unchecked(self, col: usize) -> (Self, Self) { + let (left, right) = self.into_const().split_at_col_unchecked(col); + (left.const_cast(), right.const_cast()) + } + + /// Splits the matrix vertically at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * left. + /// * right. + /// + /// # Panics + /// The function panics if the following condition is violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn split_at_col_mut(self, col: usize) -> (Self, Self) { + let (left, right) = self.into_const().split_at_col(col); + unsafe { (left.const_cast(), right.const_cast()) } + } + + /// Returns mutable references to the element at the given indices, or submatrices if either + /// `row` or `col` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_mut_unchecked( + self, + row: RowRange, + col: ColRange, + ) -> >::Target + where + Self: MatIndex, + { + >::get_unchecked(self, row, col) + } + + /// Returns mutable references to the element at the given indices, or submatrices if either + /// `row` or `col` is a range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub fn get_mut( + self, + row: RowRange, + col: ColRange, + ) -> >::Target + where + Self: MatIndex, + { + >::get(self, row, col) + } + + /// Reads the value of the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { + self.rb().read_unchecked(row, col) + } + + /// Reads the value of the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, row: usize, col: usize) -> E { + self.rb().read(row, col) + } + + /// Writes the value to the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn write_unchecked(&mut self, row: usize, col: usize, value: E) { + let units = value.faer_into_units(); + let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(row, col)); + E::faer_map( + zipped, + #[inline(always)] + |(unit, ptr)| *ptr = unit, + ); + } + + /// Writes the value to the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn write(&mut self, row: usize, col: usize, value: E) { + assert!(all(row < self.nrows(), col < self.ncols())); + unsafe { self.write_unchecked(row, col, value) }; + } + + /// Copies the values from the lower triangular part of `other` into the lower triangular + /// part of `self`. The diagonal part is included. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.nrows() == other.nrows()`. + /// * `self.ncols() == other.ncols()`. + /// * `self.nrows() == self.ncols()`. + #[track_caller] + pub fn copy_from_triangular_lower(&mut self, other: impl AsMatRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: MatMut<'_, E>, other: MatRef<'_, E>) { + zipped!(this, other).for_each_triangular_lower( + zip::Diag::Include, + #[inline(always)] + |unzipped!(mut dst, src)| dst.write(src.read()), + ); + } + implementation(self.rb_mut(), other.as_mat_ref()) + } + + /// Copies the values from the lower triangular part of `other` into the lower triangular + /// part of `self`. The diagonal part is excluded. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.nrows() == other.nrows()`. + /// * `self.ncols() == other.ncols()`. + /// * `self.nrows() == self.ncols()`. + #[track_caller] + pub fn copy_from_strict_triangular_lower(&mut self, other: impl AsMatRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: MatMut<'_, E>, other: MatRef<'_, E>) { + zipped!(this, other).for_each_triangular_lower( + zip::Diag::Skip, + #[inline(always)] + |unzipped!(mut dst, src)| dst.write(src.read()), + ); + } + implementation(self.rb_mut(), other.as_mat_ref()) + } + + /// Copies the values from the upper triangular part of `other` into the upper triangular + /// part of `self`. The diagonal part is included. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.nrows() == other.nrows()`. + /// * `self.ncols() == other.ncols()`. + /// * `self.nrows() == self.ncols()`. + #[track_caller] + #[inline(always)] + pub fn copy_from_triangular_upper(&mut self, other: impl AsMatRef) { + (*self) + .rb_mut() + .transpose_mut() + .copy_from_triangular_lower(other.as_mat_ref().transpose()) + } + + /// Copies the values from the upper triangular part of `other` into the upper triangular + /// part of `self`. The diagonal part is excluded. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.nrows() == other.nrows()`. + /// * `self.ncols() == other.ncols()`. + /// * `self.nrows() == self.ncols()`. + #[track_caller] + #[inline(always)] + pub fn copy_from_strict_triangular_upper(&mut self, other: impl AsMatRef) { + (*self) + .rb_mut() + .transpose_mut() + .copy_from_strict_triangular_lower(other.as_mat_ref().transpose()) + } + + /// Copies the values from `other` into `self`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.nrows() == other.nrows()`. + /// * `self.ncols() == other.ncols()`. + #[track_caller] + pub fn copy_from(&mut self, other: impl AsMatRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: MatMut<'_, E>, other: MatRef<'_, E>) { + zipped!(this, other).for_each(|unzipped!(mut dst, src)| dst.write(src.read())); + } + implementation(self.rb_mut(), other.as_mat_ref()) + } + + /// Fills the elements of `self` with zeros. + #[track_caller] + pub fn fill_zero(&mut self) + where + E: ComplexField, + { + zipped!(self.rb_mut()).for_each( + #[inline(always)] + |unzipped!(mut x)| x.write(E::faer_zero()), + ); + } + + /// Fills the elements of `self` with copies of `constant`. + #[track_caller] + pub fn fill(&mut self, constant: E) { + zipped!((*self).rb_mut()).for_each( + #[inline(always)] + |unzipped!(mut x)| x.write(constant), + ); + } + + /// Returns a view over the transpose of `self`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let transpose = view.transpose_mut(); + /// + /// let mut expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; + /// assert_eq!(expected.as_mut(), transpose); + /// ``` + #[inline(always)] + #[must_use] + pub fn transpose_mut(self) -> Self { + unsafe { + super::from_raw_parts_mut( + E::faer_map( + from_copy::(self.inner.ptr), + #[inline(always)] + |ptr| ptr.as_ptr(), + ), + self.ncols(), + self.nrows(), + self.col_stride(), + self.row_stride(), + ) + } + } + + /// Returns a view over the conjugate of `self`. + #[inline(always)] + #[must_use] + pub fn conjugate_mut(self) -> MatMut<'a, E::Conj> + where + E: Conjugate, + { + unsafe { self.into_const().conjugate().const_cast() } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline(always)] + #[must_use] + pub fn adjoint_mut(self) -> MatMut<'a, E::Conj> + where + E: Conjugate, + { + self.transpose_mut().conjugate_mut() + } + + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. + #[inline(always)] + #[must_use] + pub fn canonicalize_mut(self) -> (MatMut<'a, E::Canonical>, Conj) + where + E: Conjugate, + { + let (canonical, conj) = self.into_const().canonicalize(); + unsafe { (canonical.const_cast(), conj) } + } + + /// Returns a view over the `self`, with the rows in reversed order. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let reversed_rows = view.reverse_rows_mut(); + /// + /// let mut expected = mat![[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]; + /// assert_eq!(expected.as_mut(), reversed_rows); + /// ``` + #[inline(always)] + #[must_use] + pub fn reverse_rows_mut(self) -> Self { + unsafe { self.into_const().reverse_rows().const_cast() } + } + + /// Returns a view over the `self`, with the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let reversed_cols = view.reverse_cols_mut(); + /// + /// let mut expected = mat![[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]; + /// assert_eq!(expected.as_mut(), reversed_cols); + /// ``` + #[inline(always)] + #[must_use] + pub fn reverse_cols_mut(self) -> Self { + unsafe { self.into_const().reverse_cols().const_cast() } + } + + /// Returns a view over the `self`, with the rows and the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let mut matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_mut(); + /// let reversed = view.reverse_rows_and_cols_mut(); + /// + /// let mut expected = mat![[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]; + /// assert_eq!(expected.as_mut(), reversed); + /// ``` + #[inline(always)] + #[must_use] + pub fn reverse_rows_and_cols_mut(self) -> Self { + unsafe { self.into_const().reverse_rows_and_cols().const_cast() } + } + + /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with + /// dimensions `(nrows, ncols)`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `col_start <= self.ncols()`. + /// * `nrows <= self.nrows() - row_start`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let mut matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_mut(); + /// let submatrix = view.submatrix_mut(2, 1, 2, 2); + /// + /// let mut expected = mat![[7.0, 11.0], [8.0, 12.0f64]]; + /// assert_eq!(expected.as_mut(), submatrix); + /// ``` + #[track_caller] + #[inline(always)] + pub fn submatrix_mut( + self, + row_start: usize, + col_start: usize, + nrows: usize, + ncols: usize, + ) -> Self { + unsafe { + self.into_const() + .submatrix(row_start, col_start, nrows, ncols) + .const_cast() + } + } + + /// Returns a view over the submatrix starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let mut matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_mut(); + /// let subrows = view.subrows_mut(1, 2); + /// + /// let mut expected = mat![[2.0, 6.0, 10.0], [3.0, 7.0, 11.0],]; + /// assert_eq!(expected.as_mut(), subrows); + /// ``` + #[track_caller] + #[inline(always)] + pub fn subrows_mut(self, row_start: usize, nrows: usize) -> Self { + unsafe { self.into_const().subrows(row_start, nrows).const_cast() } + } + + /// Returns a view over the submatrix starting at column `col_start`, and with number of + /// columns `ncols`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let mut matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_mut(); + /// let subcols = view.subcols_mut(2, 1); + /// + /// let mut expected = mat![[9.0], [10.0], [11.0], [12.0f64]]; + /// assert_eq!(expected.as_mut(), subcols); + /// ``` + #[track_caller] + #[inline(always)] + pub fn subcols_mut(self, col_start: usize, ncols: usize) -> Self { + unsafe { self.into_const().subcols(col_start, ncols).const_cast() } + } + + /// Returns a view over the row at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_idx < self.nrows()`. + #[track_caller] + #[inline(always)] + pub fn row_mut(self, row_idx: usize) -> RowMut<'a, E> { + unsafe { self.into_const().row(row_idx).const_cast() } + } + + /// Returns views over the rows at the given indices. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_idx0 < self.nrows()`. + /// * `row_idx1 < self.nrows()`. + /// * `row_idx0 == row_idx1`. + #[track_caller] + #[inline(always)] + pub fn two_rows_mut(self, row_idx0: usize, row_idx1: usize) -> (RowMut<'a, E>, RowMut<'a, E>) { + assert!(row_idx0 != row_idx1); + let this = self.into_const(); + unsafe { + ( + this.row(row_idx0).const_cast(), + this.row(row_idx1).const_cast(), + ) + } + } + + /// Returns a view over the column at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_idx < self.ncols()`. + #[track_caller] + #[inline(always)] + pub fn col_mut(self, col_idx: usize) -> ColMut<'a, E> { + unsafe { self.into_const().col(col_idx).const_cast() } + } + + /// Returns views over the columns at the given indices. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_idx0 < self.ncols()`. + /// * `col_idx1 < self.ncols()`. + /// * `col_idx0 == col_idx1`. + #[track_caller] + #[inline(always)] + pub fn two_cols_mut(self, col_idx0: usize, col_idx1: usize) -> (ColMut<'a, E>, ColMut<'a, E>) { + assert!(col_idx0 != col_idx1); + let this = self.into_const(); + unsafe { + ( + this.col(col_idx0).const_cast(), + this.col(col_idx1).const_cast(), + ) + } + } + + /// Given a matrix with a single column, returns an object that interprets + /// the column as a diagonal matrix, whoes diagonal elements are values in the column. + #[track_caller] + #[inline(always)] + pub fn column_vector_as_diagonal_mut(self) -> DiagMut<'a, E> { + assert!(self.ncols() == 1); + DiagMut { + inner: self.col_mut(0), + } + } + + /// Returns the diagonal of the matrix. + #[inline(always)] + pub fn diagonal_mut(self) -> DiagMut<'a, E> { + let size = self.nrows().min(self.ncols()); + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + unsafe { + DiagMut { + inner: crate::col::from_raw_parts_mut( + self.as_ptr_mut(), + size, + row_stride + col_stride, + ), + } + } + } + + /// Returns an owning [`Mat`] of the data + #[inline] + pub fn to_owned(&self) -> Mat + where + E: Conjugate, + { + self.rb().to_owned() + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + self.rb().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + self.rb().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + self.rb().norm_max() + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + self.rb().norm_l2() + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + self.rb().sum() + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.as_2d_ref().kron(rhs) + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> MatRef<'_, E> { + self.rb() + } + + /// Returns a mutable view over the matrix. + #[inline] + pub fn as_mut(&mut self) -> MatMut<'_, E> { + self.rb_mut() + } + + /// Returns an iterator that provides successive chunks of the columns of this matrix, with + /// each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have + /// `chunk_size` columns. + #[inline] + #[track_caller] + pub fn col_chunks_mut( + self, + chunk_size: usize, + ) -> impl 'a + DoubleEndedIterator> { + self.into_const() + .col_chunks(chunk_size) + .map(|chunk| unsafe { chunk.const_cast() }) + } + + /// Returns an iterator that provides successive chunks of the rows of this matrix, + /// with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + #[inline] + #[track_caller] + pub fn row_chunks_mut( + self, + chunk_size: usize, + ) -> impl 'a + DoubleEndedIterator> { + self.into_const() + .row_chunks(chunk_size) + .map(|chunk| unsafe { chunk.const_cast() }) + } + + /// Returns a parallel iterator that provides successive chunks of the columns of this + /// matrix, with each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have + /// `chunk_size` columns. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_col_chunks_mut( + self, + chunk_size: usize, + ) -> impl 'a + rayon::iter::IndexedParallelIterator> { + use rayon::prelude::*; + self.into_const() + .par_col_chunks(chunk_size) + .map(|chunk| unsafe { chunk.const_cast() }) + } + + /// Returns a parallel iterator that provides successive chunks of the rows of this matrix, + /// with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_row_chunks_mut( + self, + chunk_size: usize, + ) -> impl 'a + rayon::iter::IndexedParallelIterator> { + use rayon::prelude::*; + self.into_const() + .par_row_chunks(chunk_size) + .map(|chunk| unsafe { chunk.const_cast() }) + } +} + +impl<'a, E: RealField> MatMut<'a, num_complex::Complex> { + /// Returns the real and imaginary components of `self`. + #[inline(always)] + pub fn real_imag_mut(self) -> num_complex::Complex> { + let num_complex::Complex { re, im } = self.into_const().real_imag(); + unsafe { + num_complex::Complex { + re: re.const_cast(), + im: im.const_cast(), + } + } + } +} + +impl AsMatRef for MatMut<'_, E> { + #[inline] + fn as_mat_ref(&self) -> MatRef<'_, E> { + (*self).rb() + } +} +impl AsMatRef for &'_ MatMut<'_, E> { + #[inline] + fn as_mat_ref(&self) -> MatRef<'_, E> { + (**self).rb() + } +} + +impl AsMatMut for MatMut<'_, E> { + #[inline] + fn as_mat_mut(&mut self) -> MatMut<'_, E> { + (*self).rb_mut() + } +} + +impl AsMatMut for &'_ mut MatMut<'_, E> { + #[inline] + fn as_mat_mut(&mut self) -> MatMut<'_, E> { + (**self).rb_mut() + } +} + +impl As2D for &'_ MatMut<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).rb() + } +} + +impl As2D for MatMut<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).rb() + } +} + +impl As2DMut for &'_ mut MatMut<'_, E> { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (**self).rb_mut() + } +} + +impl As2DMut for MatMut<'_, E> { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (*self).rb_mut() + } +} + +/// Creates a `MatMut` from pointers to the matrix data, dimensions, and strides. +/// +/// The row (resp. column) stride is the offset from the memory address of a given matrix +/// element at indices `(row: i, col: j)`, to the memory address of the matrix element at +/// indices `(row: i + 1, col: 0)` (resp. `(row: 0, col: i + 1)`). This offset is specified in +/// number of elements, not in bytes. +/// +/// # Safety +/// The behavior is undefined if any of the following conditions are violated: +/// * For each matrix unit, the entire memory region addressed by the matrix must be contained +/// within a single allocation, accessible in its entirety by the corresponding pointer in +/// `ptr`. +/// * For each matrix unit, the corresponding pointer must be non null and properly aligned, +/// even for a zero-sized matrix. +/// * The values accessible by the matrix must be initialized at some point before they are read, or +/// references to them are formed. +/// * No aliasing (including self aliasing) is allowed. In other words, none of the elements +/// accessible by any matrix unit may be accessed for reads or writes by any other means for +/// the duration of the lifetime `'a`. No two elements within a single matrix unit may point to +/// the same address (such a thing can be achieved with a zero stride, for example), and no two +/// matrix units may point to the same address. +/// +/// # Example +/// +/// ``` +/// use faer::mat; +/// +/// // row major matrix with 2 rows, 3 columns, with a column at the end that we want to skip. +/// // the row stride is the pointer offset from the address of 1.0 to the address of 4.0, +/// // which is 4. +/// // the column stride is the pointer offset from the address of 1.0 to the address of 2.0, +/// // which is 1. +/// let mut data = [[1.0, 2.0, 3.0, f64::NAN], [4.0, 5.0, 6.0, f64::NAN]]; +/// let mut matrix = +/// unsafe { mat::from_raw_parts_mut::(data.as_mut_ptr() as *mut f64, 2, 3, 4, 1) }; +/// +/// let expected = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; +/// assert_eq!(expected.as_ref(), matrix); +/// ``` +#[inline(always)] +pub unsafe fn from_raw_parts_mut<'a, E: Entity>( + ptr: GroupFor, + nrows: usize, + ncols: usize, + row_stride: isize, + col_stride: isize, +) -> MatMut<'a, E> { + MatMut::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) +} + +/// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a column-major format, so that the first chunk of `nrows` +/// values from the slices goes in the first column of the matrix, the second chunk of `nrows` +/// values goes in the second column, and so on. +/// +/// # Panics +/// The function panics if any of the following conditions are violated: +/// * `nrows * ncols == slice.len()` +/// +/// # Example +/// ``` +/// use faer::mat; +/// +/// let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; +/// let view = mat::from_column_major_slice_mut::(&mut slice, 3, 2); +/// +/// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; +/// assert_eq!(expected, view); +/// ``` +#[track_caller] +pub fn from_column_major_slice_mut( + slice: GroupFor, + nrows: usize, + ncols: usize, +) -> MatMut<'_, E> { + from_slice_assert( + nrows, + ncols, + SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(), + ); + unsafe { + from_raw_parts_mut( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_mut_ptr(), + ), + nrows, + ncols, + 1, + nrows as isize, + ) + } +} + +/// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a row-major format, so that the first chunk of `ncols` +/// values from the slices goes in the first column of the matrix, the second chunk of `ncols` +/// values goes in the second column, and so on. +/// +/// # Panics +/// The function panics if any of the following conditions are violated: +/// * `nrows * ncols == slice.len()` +/// +/// # Example +/// ``` +/// use faer::mat; +/// +/// let mut slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; +/// let view = mat::from_row_major_slice_mut::(&mut slice, 3, 2); +/// +/// let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; +/// assert_eq!(expected, view); +/// ``` +#[inline(always)] +#[track_caller] +pub fn from_row_major_slice_mut( + slice: GroupFor, + nrows: usize, + ncols: usize, +) -> MatMut<'_, E> { + from_column_major_slice_mut(slice, ncols, nrows).transpose_mut() +} + +/// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a column-major format, where the beginnings of two consecutive +/// columns are separated by `col_stride` elements. +#[track_caller] +pub fn from_column_major_slice_with_stride_mut( + slice: GroupFor, + nrows: usize, + ncols: usize, + col_stride: usize, +) -> MatMut<'_, E> { + from_strided_column_major_slice_mut_assert( + nrows, + ncols, + col_stride, + SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(), + ); + unsafe { + from_raw_parts_mut( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_mut_ptr(), + ), + nrows, + ncols, + 1, + col_stride as isize, + ) + } +} + +/// Creates a `MatMut` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a row-major format, where the beginnings of two consecutive +/// rows are separated by `row_stride` elements. +#[track_caller] +pub fn from_row_major_slice_with_stride_mut( + slice: GroupFor, + nrows: usize, + ncols: usize, + row_stride: usize, +) -> MatMut<'_, E> { + from_column_major_slice_with_stride_mut::(slice, ncols, nrows, row_stride).transpose_mut() +} + +impl<'a, E: Entity> core::fmt::Debug for MatMut<'a, E> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.rb().fmt(f) + } +} + +impl core::ops::Index<(usize, usize)> for MatMut<'_, E> { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &E { + (*self).rb().get(row, col) + } +} + +impl core::ops::IndexMut<(usize, usize)> for MatMut<'_, E> { + #[inline] + #[track_caller] + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut E { + (*self).rb_mut().get_mut(row, col) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for MatMut<'_, E> { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Dense(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::DenseAccess for MatMut<'_, E> { + #[inline] + fn fetch_single(&self, row: usize, col: usize) -> E { + self.read(row, col) + } +} diff --git a/src/mat/matown.rs b/src/mat/matown.rs new file mode 100644 index 0000000000000000000000000000000000000000..da5adf4a649f377141d501b52aeee16d178bfc1b --- /dev/null +++ b/src/mat/matown.rs @@ -0,0 +1,950 @@ +use super::*; +use crate::{ + assert, debug_assert, + diag::DiagRef, + mat::matalloc::{align_for, is_vectorizable, MatUnit, RawMat, RawMatUnit}, + utils::DivCeil, +}; +use core::mem::ManuallyDrop; + +/// Heap allocated resizable matrix, similar to a 2D [`Vec`]. +/// +/// # Note +/// +/// The memory layout of `Mat` is guaranteed to be column-major, meaning that it has a row stride +/// of `1`, and an unspecified column stride that can be queried with [`Mat::col_stride`]. +/// +/// This implies that while each individual column is stored contiguously in memory, the matrix as +/// a whole may not necessarily be contiguous. The implementation may add padding at the end of +/// each column when overaligning each column can provide a performance gain. +/// +/// Let us consider a 3×4 matrix +/// +/// ```notcode +/// 0 │ 3 │ 6 │ 9 +/// ───┼───┼───┼─── +/// 1 │ 4 │ 7 │ 10 +/// ───┼───┼───┼─── +/// 2 │ 5 │ 8 │ 11 +/// ``` +/// The memory representation of the data held by such a matrix could look like the following: +/// +/// ```notcode +/// 0 1 2 X 3 4 5 X 6 7 8 X 9 10 11 X +/// ``` +/// +/// where X represents padding elements. +#[repr(C)] +pub struct Mat { + inner: MatOwnImpl, + row_capacity: usize, + col_capacity: usize, + __marker: PhantomData, +} + +impl Mat { + /// Returns an empty matrix of dimension `0×0`. + #[inline] + pub fn new() -> Self { + Self { + inner: MatOwnImpl { + ptr: into_copy::(E::faer_map(E::UNIT, |()| NonNull::::dangling())), + nrows: 0, + ncols: 0, + }, + row_capacity: 0, + col_capacity: 0, + __marker: PhantomData, + } + } + + /// Returns a new matrix with dimensions `(0, 0)`, with enough capacity to hold a maximum of + /// `row_capacity` rows and `col_capacity` columns without reallocating. If either is `0`, + /// the matrix will not allocate. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn with_capacity(row_capacity: usize, col_capacity: usize) -> Self { + let raw = ManuallyDrop::new(RawMat::::new(row_capacity, col_capacity)); + Self { + inner: MatOwnImpl { + ptr: raw.ptr, + nrows: 0, + ncols: 0, + }, + row_capacity: raw.row_capacity, + col_capacity: raw.col_capacity, + __marker: PhantomData, + } + } + + /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with the provided function. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn from_fn(nrows: usize, ncols: usize, f: impl FnMut(usize, usize) -> E) -> Self { + let mut this = Self::new(); + this.resize_with(nrows, ncols, f); + this + } + + /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with zeros. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn zeros(nrows: usize, ncols: usize) -> Self + where + E: ComplexField, + { + Self::from_fn(nrows, ncols, |_, _| E::faer_zero()) + } + + /// Returns a new matrix with dimensions `(nrows, ncols)`, filled with zeros, except the main + /// diagonal which is filled with ones. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn identity(nrows: usize, ncols: usize) -> Self + where + E: ComplexField, + { + let mut matrix = Self::zeros(nrows, ncols); + matrix + .as_mut() + .diagonal_mut() + .column_vector_mut() + .fill(E::faer_one()); + matrix + } + + /// Returns the number of rows of the matrix. + #[inline(always)] + pub fn nrows(&self) -> usize { + self.inner.nrows + } + /// Returns the number of columns of the matrix. + #[inline(always)] + pub fn ncols(&self) -> usize { + self.inner.ncols + } + + /// Set the dimensions of the matrix. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `nrows < self.row_capacity()`. + /// * `ncols < self.col_capacity()`. + /// * The elements that were previously out of bounds but are now in bounds must be + /// initialized. + #[inline] + pub unsafe fn set_dims(&mut self, nrows: usize, ncols: usize) { + self.inner.nrows = nrows; + self.inner.ncols = ncols; + } + + /// Returns a pointer to the data of the matrix. + #[inline] + pub fn as_ptr(&self) -> GroupFor { + E::faer_map(from_copy::(self.inner.ptr), |ptr| { + ptr.as_ptr() as *const E::Unit + }) + } + + /// Returns a mutable pointer to the data of the matrix. + #[inline] + pub fn as_ptr_mut(&mut self) -> GroupFor { + E::faer_map(from_copy::(self.inner.ptr), |ptr| ptr.as_ptr()) + } + + /// Returns the row capacity, that is, the number of rows that the matrix is able to hold + /// without needing to reallocate, excluding column insertions. + #[inline] + pub fn row_capacity(&self) -> usize { + self.row_capacity + } + + /// Returns the column capacity, that is, the number of columns that the matrix is able to hold + /// without needing to reallocate, excluding row insertions. + #[inline] + pub fn col_capacity(&self) -> usize { + self.col_capacity + } + + /// Returns the offset between the first elements of two successive rows in the matrix. + /// Always returns `1` since the matrix is column major. + #[inline] + pub fn row_stride(&self) -> isize { + 1 + } + + /// Returns the offset between the first elements of two successive columns in the matrix. + #[inline] + pub fn col_stride(&self) -> isize { + self.row_capacity() as isize + } + + #[cold] + fn do_reserve_exact(&mut self, mut new_row_capacity: usize, new_col_capacity: usize) { + if is_vectorizable::() { + let align_factor = align_for::() / core::mem::size_of::(); + new_row_capacity = new_row_capacity + .msrv_checked_next_multiple_of(align_factor) + .unwrap(); + } + + let nrows = self.inner.nrows; + let ncols = self.inner.ncols; + let old_row_capacity = self.row_capacity; + let old_col_capacity = self.col_capacity; + + let mut this = ManuallyDrop::new(core::mem::take(self)); + { + let mut this_group = E::faer_map(from_copy::(this.inner.ptr), |ptr| MatUnit { + raw: RawMatUnit { + ptr, + row_capacity: old_row_capacity, + col_capacity: old_col_capacity, + }, + nrows, + ncols, + }); + + E::faer_map(E::faer_as_mut(&mut this_group), |mat_unit| { + mat_unit.do_reserve_exact(new_row_capacity, new_col_capacity); + }); + + let this_group = E::faer_map(this_group, ManuallyDrop::new); + this.inner.ptr = + into_copy::(E::faer_map(this_group, |mat_unit| mat_unit.raw.ptr)); + this.row_capacity = new_row_capacity; + this.col_capacity = new_col_capacity; + } + *self = ManuallyDrop::into_inner(this); + } + + /// Reserves the minimum capacity for `row_capacity` rows and `col_capacity` + /// columns without reallocating. Does nothing if the capacity is already sufficient. + /// + /// # Panics + /// The function panics if the new total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn reserve_exact(&mut self, row_capacity: usize, col_capacity: usize) { + if self.row_capacity() >= row_capacity && self.col_capacity() >= col_capacity { + // do nothing + } else if core::mem::size_of::() == 0 { + self.row_capacity = self.row_capacity().max(row_capacity); + self.col_capacity = self.col_capacity().max(col_capacity); + } else { + self.do_reserve_exact(row_capacity, col_capacity); + } + } + + unsafe fn insert_block_with E>( + &mut self, + f: &mut F, + row_start: usize, + row_end: usize, + col_start: usize, + col_end: usize, + ) { + debug_assert!(all(row_start <= row_end, col_start <= col_end)); + + let ptr = self.as_ptr_mut(); + + for j in col_start..col_end { + let ptr_j = E::faer_map(E::faer_copy(&ptr), |ptr| { + ptr.wrapping_offset(j as isize * self.col_stride()) + }); + + for i in row_start..row_end { + // SAFETY: + // * pointer to element at index (i, j), which is within the + // allocation since we reserved enough space + // * writing to this memory region is sound since it is properly + // aligned and valid for writes + let ptr_ij = E::faer_map(E::faer_copy(&ptr_j), |ptr_j| ptr_j.add(i)); + let value = E::faer_into_units(f(i, j)); + + E::faer_map(E::faer_zip(ptr_ij, value), |(ptr_ij, value)| { + core::ptr::write(ptr_ij, value) + }); + } + } + } + + fn erase_last_cols(&mut self, new_ncols: usize) { + let old_ncols = self.ncols(); + debug_assert!(new_ncols <= old_ncols); + self.inner.ncols = new_ncols; + } + + fn erase_last_rows(&mut self, new_nrows: usize) { + let old_nrows = self.nrows(); + debug_assert!(new_nrows <= old_nrows); + self.inner.nrows = new_nrows; + } + + unsafe fn insert_last_cols_with E>( + &mut self, + f: &mut F, + new_ncols: usize, + ) { + let old_ncols = self.ncols(); + + debug_assert!(new_ncols > old_ncols); + + self.insert_block_with(f, 0, self.nrows(), old_ncols, new_ncols); + self.inner.ncols = new_ncols; + } + + unsafe fn insert_last_rows_with E>( + &mut self, + f: &mut F, + new_nrows: usize, + ) { + let old_nrows = self.nrows(); + + debug_assert!(new_nrows > old_nrows); + + self.insert_block_with(f, old_nrows, new_nrows, 0, self.ncols()); + self.inner.nrows = new_nrows; + } + + /// Resizes the matrix in-place so that the new dimensions are `(new_nrows, new_ncols)`. + /// New elements are created with the given function `f`, so that elements at indices `(i, j)` + /// are created by calling `f(i, j)`. + pub fn resize_with( + &mut self, + new_nrows: usize, + new_ncols: usize, + f: impl FnMut(usize, usize) -> E, + ) { + let mut f = f; + let old_nrows = self.nrows(); + let old_ncols = self.ncols(); + + if new_ncols <= old_ncols { + self.erase_last_cols(new_ncols); + if new_nrows <= old_nrows { + self.erase_last_rows(new_nrows); + } else { + self.reserve_exact(new_nrows, new_ncols); + unsafe { + self.insert_last_rows_with(&mut f, new_nrows); + } + } + } else { + if new_nrows <= old_nrows { + self.erase_last_rows(new_nrows); + } else { + self.reserve_exact(new_nrows, new_ncols); + unsafe { + self.insert_last_rows_with(&mut f, new_nrows); + } + } + self.reserve_exact(new_nrows, new_ncols); + unsafe { + self.insert_last_cols_with(&mut f, new_ncols); + } + } + } + + /// Returns a reference to a slice over the column at the given index. + #[inline] + #[track_caller] + pub fn col_as_slice(&self, col: usize) -> GroupFor { + assert!(col < self.ncols()); + let nrows = self.nrows(); + let ptr = self.as_ref().ptr_at(0, col); + E::faer_map( + ptr, + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts(ptr, nrows) }, + ) + } + + /// Returns a mutable reference to a slice over the column at the given index. + #[inline] + #[track_caller] + pub fn col_as_slice_mut(&mut self, col: usize) -> GroupFor { + assert!(col < self.ncols()); + let nrows = self.nrows(); + let ptr = self.as_mut().ptr_at_mut(0, col); + E::faer_map( + ptr, + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, nrows) }, + ) + } + + /// Returns a reference to a slice over the column at the given index. + #[inline] + #[track_caller] + #[deprecated = "replaced by `Mat::col_as_slice`"] + pub fn col_ref(&self, col: usize) -> GroupFor { + self.col_as_slice(col) + } + + /// Returns a mutable reference to a slice over the column at the given index. + #[inline] + #[track_caller] + #[deprecated = "replaced by `Mat::col_as_slice_mut`"] + pub fn col_mut(&mut self, col: usize) -> GroupFor { + self.col_as_slice_mut(col) + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> MatRef<'_, E> { + unsafe { + super::from_raw_parts( + self.as_ptr(), + self.nrows(), + self.ncols(), + 1, + self.col_stride(), + ) + } + } + + /// Returns a mutable view over the matrix. + #[inline] + pub fn as_mut(&mut self) -> MatMut<'_, E> { + unsafe { + super::from_raw_parts_mut( + self.as_ptr_mut(), + self.nrows(), + self.ncols(), + 1, + self.col_stride(), + ) + } + } + + /// Returns references to the element at the given indices, or submatrices if either `row` or + /// `col` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub unsafe fn get_unchecked( + &self, + row: RowRange, + col: ColRange, + ) -> as MatIndex>::Target + where + for<'a> MatRef<'a, E>: MatIndex, + { + self.as_ref().get_unchecked(row, col) + } + + /// Returns references to the element at the given indices, or submatrices if either `row` or + /// `col` is a range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub fn get( + &self, + row: RowRange, + col: ColRange, + ) -> as MatIndex>::Target + where + for<'a> MatRef<'a, E>: MatIndex, + { + self.as_ref().get(row, col) + } + + /// Returns mutable references to the element at the given indices, or submatrices if either + /// `row` or `col` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub unsafe fn get_mut_unchecked( + &mut self, + row: RowRange, + col: ColRange, + ) -> as MatIndex>::Target + where + for<'a> MatMut<'a, E>: MatIndex, + { + self.as_mut().get_mut_unchecked(row, col) + } + + /// Returns mutable references to the element at the given indices, or submatrices if either + /// `row` or `col` is a range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub fn get_mut( + &mut self, + row: RowRange, + col: ColRange, + ) -> as MatIndex>::Target + where + for<'a> MatMut<'a, E>: MatIndex, + { + self.as_mut().get_mut(row, col) + } + + /// Reads the value of the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { + self.as_ref().read_unchecked(row, col) + } + + /// Reads the value of the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, row: usize, col: usize) -> E { + self.as_ref().read(row, col) + } + + /// Writes the value to the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn write_unchecked(&mut self, row: usize, col: usize, value: E) { + self.as_mut().write_unchecked(row, col, value); + } + + /// Writes the value to the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn write(&mut self, row: usize, col: usize, value: E) { + self.as_mut().write(row, col, value); + } + + /// Copies the values from `other` into `self`. + #[inline(always)] + #[track_caller] + pub fn copy_from(&mut self, other: impl AsMatRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: &mut Mat, other: MatRef<'_, E>) { + let mut mat = Mat::::new(); + mat.resize_with( + other.nrows(), + other.ncols(), + #[inline(always)] + |row, col| unsafe { other.read_unchecked(row, col) }, + ); + *this = mat; + } + implementation(self, other.as_mat_ref()); + } + + /// Fills the elements of `self` with zeros. + #[inline(always)] + #[track_caller] + pub fn fill_zero(&mut self) + where + E: ComplexField, + { + self.as_mut().fill_zero() + } + + /// Fills the elements of `self` with copies of `constant`. + #[inline(always)] + #[track_caller] + pub fn fill(&mut self, constant: E) { + self.as_mut().fill(constant) + } + + /// Returns a view over the transpose of `self`. + #[inline] + pub fn transpose(&self) -> MatRef<'_, E> { + self.as_ref().transpose() + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn conjugate(&self) -> MatRef<'_, E::Conj> + where + E: Conjugate, + { + self.as_ref().conjugate() + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn adjoint(&self) -> MatRef<'_, E::Conj> + where + E: Conjugate, + { + self.as_ref().adjoint() + } + + /// Returns a view over the diagonal of the matrix. + #[inline] + pub fn diagonal(&self) -> DiagRef<'_, E> { + self.as_ref().diagonal() + } + + /// Returns an owning [`Mat`] of the data + #[inline] + pub fn to_owned(&self) -> Mat + where + E: Conjugate, + { + self.as_ref().to_owned() + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + self.as_ref().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + self.as_ref().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + crate::linalg::reductions::norm_max::norm_max((*self).as_ref()) + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + crate::linalg::reductions::norm_l2::norm_l2((*self).as_ref()) + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + crate::linalg::reductions::sum::sum((*self).as_ref()) + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.as_2d_ref().kron(rhs) + } + + /// Returns an iterator that provides successive chunks of the columns of a view over this + /// matrix, with each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// columns. + #[inline] + #[track_caller] + pub fn col_chunks( + &self, + chunk_size: usize, + ) -> impl '_ + DoubleEndedIterator> { + self.as_ref().col_chunks(chunk_size) + } + + /// Returns an iterator that provides successive chunks of the columns of a mutable view over + /// this matrix, with each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// columns. + #[inline] + #[track_caller] + pub fn col_chunks_mut( + &mut self, + chunk_size: usize, + ) -> impl '_ + DoubleEndedIterator> { + self.as_mut().col_chunks_mut(chunk_size) + } + + /// Returns a parallel iterator that provides successive chunks of the columns of a view over + /// this matrix, with each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// columns. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_col_chunks( + &self, + chunk_size: usize, + ) -> impl '_ + rayon::iter::IndexedParallelIterator> { + self.as_ref().par_col_chunks(chunk_size) + } + + /// Returns a parallel iterator that provides successive chunks of the columns of a mutable view + /// over this matrix, with each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// columns. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_col_chunks_mut( + &mut self, + chunk_size: usize, + ) -> impl '_ + rayon::iter::IndexedParallelIterator> { + self.as_mut().par_col_chunks_mut(chunk_size) + } + + /// Returns an iterator that provides successive chunks of the rows of a view over this + /// matrix, with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + #[inline] + #[track_caller] + pub fn row_chunks( + &self, + chunk_size: usize, + ) -> impl '_ + DoubleEndedIterator> { + self.as_ref().row_chunks(chunk_size) + } + + /// Returns an iterator that provides successive chunks of the rows of a mutable view over + /// this matrix, with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + #[inline] + #[track_caller] + pub fn row_chunks_mut( + &mut self, + chunk_size: usize, + ) -> impl '_ + DoubleEndedIterator> { + self.as_mut().row_chunks_mut(chunk_size) + } + + /// Returns a parallel iterator that provides successive chunks of the rows of a view over this + /// matrix, with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_row_chunks( + &self, + chunk_size: usize, + ) -> impl '_ + rayon::iter::IndexedParallelIterator> { + self.as_ref().par_row_chunks(chunk_size) + } + + /// Returns a parallel iterator that provides successive chunks of the rows of a mutable view + /// over this matrix, with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_row_chunks_mut( + &mut self, + chunk_size: usize, + ) -> impl '_ + rayon::iter::IndexedParallelIterator> { + self.as_mut().par_row_chunks_mut(chunk_size) + } +} + +impl Default for Mat { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Clone for Mat { + fn clone(&self) -> Self { + let this = self.as_ref(); + unsafe { + Self::from_fn(self.nrows(), self.ncols(), |i, j| { + E::faer_from_units(E::faer_deref(this.get_unchecked(i, j))) + }) + } + } +} + +impl AsMatRef for Mat { + #[inline] + fn as_mat_ref(&self) -> MatRef<'_, E> { + (*self).as_ref() + } +} +impl AsMatRef for &'_ Mat { + #[inline] + fn as_mat_ref(&self) -> MatRef<'_, E> { + (**self).as_ref() + } +} + +impl AsMatMut for Mat { + #[inline] + fn as_mat_mut(&mut self) -> MatMut<'_, E> { + (*self).as_mut() + } +} + +impl AsMatMut for &'_ mut Mat { + #[inline] + fn as_mat_mut(&mut self) -> MatMut<'_, E> { + (**self).as_mut() + } +} + +impl As2D for &'_ Mat { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).as_ref() + } +} + +impl As2D for Mat { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).as_ref() + } +} + +impl As2DMut for &'_ mut Mat { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (**self).as_mut() + } +} + +impl As2DMut for Mat { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (*self).as_mut() + } +} + +impl core::fmt::Debug for Mat { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_ref().fmt(f) + } +} + +impl core::ops::Index<(usize, usize)> for Mat { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &E { + self.as_ref().get(row, col) + } +} + +impl core::ops::IndexMut<(usize, usize)> for Mat { + #[inline] + #[track_caller] + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut E { + self.as_mut().get_mut(row, col) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for Mat { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Dense(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::DenseAccess for Mat { + #[inline] + fn fetch_single(&self, row: usize, col: usize) -> E { + self.read(row, col) + } +} diff --git a/src/mat/matref.rs b/src/mat/matref.rs new file mode 100644 index 0000000000000000000000000000000000000000..47edf6eed6600eaf78d7b5b8b47e9552f7b6293b --- /dev/null +++ b/src/mat/matref.rs @@ -0,0 +1,1345 @@ +use super::*; +use crate::{ + assert, col::ColRef, debug_assert, diag::DiagRef, row::RowRef, unzipped, utils::DivCeil, zipped, +}; + +/// Immutable view over a matrix, similar to an immutable reference to a 2D strided [prim@slice]. +/// +/// # Note +/// +/// Unlike a slice, the data pointed to by `MatRef<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions. In this case, care must be taken to not perform any +/// operations that read the uninitialized values, or form references to them, either directly +/// through [`MatRef::read`], or indirectly through any of the numerical library routines, unless +/// it is explicitly permitted. +#[repr(C)] +pub struct MatRef<'a, E: Entity> { + pub(super) inner: MatImpl, + pub(super) __marker: PhantomData<&'a E>, +} + +impl Clone for MatRef<'_, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl Copy for MatRef<'_, E> {} + +impl<'short, E: Entity> Reborrow<'short> for MatRef<'_, E> { + type Target = MatRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for MatRef<'_, E> { + type Target = MatRef<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl IntoConst for MatRef<'_, E> { + type Target = Self; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, E: Entity> MatRef<'a, E> { + #[inline] + pub(crate) unsafe fn __from_raw_parts( + ptr: GroupFor, + nrows: usize, + ncols: usize, + row_stride: isize, + col_stride: isize, + ) -> Self { + Self { + inner: MatImpl { + ptr: into_copy::(E::faer_map( + ptr, + #[inline] + |ptr| NonNull::new_unchecked(ptr as *mut E::Unit), + )), + nrows, + ncols, + row_stride, + col_stride, + }, + __marker: PhantomData, + } + } + + #[track_caller] + #[inline(always)] + #[doc(hidden)] + pub fn try_get_contiguous_col(self, j: usize) -> GroupFor { + assert!(self.row_stride() == 1); + let col = self.col(j); + if col.nrows() == 0 { + E::faer_map( + E::UNIT, + #[inline(always)] + |()| &[] as &[E::Unit], + ) + } else { + let m = col.nrows(); + E::faer_map( + col.as_ptr(), + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts(ptr, m) }, + ) + } + } + + /// Returns pointers to the matrix data. + #[inline(always)] + pub fn as_ptr(self) -> GroupFor { + E::faer_map( + from_copy::(self.inner.ptr), + #[inline] + |ptr| ptr.as_ptr() as *const E::Unit, + ) + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.inner.nrows + } + + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.inner.ncols + } + + /// Returns the row stride of the matrix, specified in number of elements, not in bytes. + #[inline] + pub fn row_stride(&self) -> isize { + self.inner.row_stride + } + + /// Returns the column stride of the matrix, specified in number of elements, not in bytes. + #[inline] + pub fn col_stride(&self) -> isize { + self.inner.col_stride + } + + /// Returns raw pointers to the element at the given indices. + #[inline(always)] + pub fn ptr_at(self, row: usize, col: usize) -> GroupFor { + let offset = ((row as isize).wrapping_mul(self.inner.row_stride)) + .wrapping_add((col as isize).wrapping_mul(self.inner.col_stride)); + + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.wrapping_offset(offset), + ) + } + + #[inline(always)] + unsafe fn unchecked_ptr_at(self, row: usize, col: usize) -> GroupFor { + let offset = crate::utils::unchecked_add( + crate::utils::unchecked_mul(row, self.inner.row_stride), + crate::utils::unchecked_mul(col, self.inner.col_stride), + ); + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + + #[inline(always)] + pub(crate) unsafe fn overflowing_ptr_at( + self, + row: usize, + col: usize, + ) -> GroupFor { + unsafe { + let cond = (row != self.nrows()) & (col != self.ncols()); + let offset = (cond as usize).wrapping_neg() as isize + & (isize::wrapping_add( + (row as isize).wrapping_mul(self.inner.row_stride), + (col as isize).wrapping_mul(self.inner.col_stride), + )); + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + } + + /// Returns raw pointers to the element at the given indices, assuming the provided indices + /// are within the matrix dimensions. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn ptr_inbounds_at(self, row: usize, col: usize) -> GroupFor { + debug_assert!(all(row < self.nrows(), col < self.ncols())); + self.unchecked_ptr_at(row, col) + } + + /// Splits the matrix horizontally and vertically at the given indices into four corners and + /// returns an array of each submatrix, in the following order: + /// * top left. + /// * top right. + /// * bottom left. + /// * bottom right. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row <= self.nrows()`. + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_unchecked(self, row: usize, col: usize) -> (Self, Self, Self, Self) { + debug_assert!(all(row <= self.nrows(), col <= self.ncols())); + + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + + let nrows = self.nrows(); + let ncols = self.ncols(); + + unsafe { + let top_left = self.overflowing_ptr_at(0, 0); + let top_right = self.overflowing_ptr_at(0, col); + let bot_left = self.overflowing_ptr_at(row, 0); + let bot_right = self.overflowing_ptr_at(row, col); + + ( + Self::__from_raw_parts(top_left, row, col, row_stride, col_stride), + Self::__from_raw_parts(top_right, row, ncols - col, row_stride, col_stride), + Self::__from_raw_parts(bot_left, nrows - row, col, row_stride, col_stride), + Self::__from_raw_parts(bot_right, nrows - row, ncols - col, row_stride, col_stride), + ) + } + } + + /// Splits the matrix horizontally and vertically at the given indices into four corners and + /// returns an array of each submatrix, in the following order: + /// * top left. + /// * top right. + /// * bottom left. + /// * bottom right. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row <= self.nrows()`. + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn split_at(self, row: usize, col: usize) -> (Self, Self, Self, Self) { + assert!(all(row <= self.nrows(), col <= self.ncols())); + unsafe { self.split_at_unchecked(row, col) } + } + + /// Splits the matrix horizontally at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * top. + /// * bottom. + /// + /// # Safety + /// The behavior is undefined if the following condition is violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_row_unchecked(self, row: usize) -> (Self, Self) { + debug_assert!(row <= self.nrows()); + + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + + let nrows = self.nrows(); + let ncols = self.ncols(); + + unsafe { + let top_right = self.overflowing_ptr_at(0, 0); + let bot_right = self.overflowing_ptr_at(row, 0); + + ( + Self::__from_raw_parts(top_right, row, ncols, row_stride, col_stride), + Self::__from_raw_parts(bot_right, nrows - row, ncols, row_stride, col_stride), + ) + } + } + + /// Splits the matrix horizontally at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * top. + /// * bottom. + /// + /// # Panics + /// The function panics if the following condition is violated: + /// * `row <= self.nrows()`. + #[inline(always)] + #[track_caller] + pub fn split_at_row(self, row: usize) -> (Self, Self) { + assert!(row <= self.nrows()); + unsafe { self.split_at_row_unchecked(row) } + } + + /// Splits the matrix vertically at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * left. + /// * right. + /// + /// # Safety + /// The behavior is undefined if the following condition is violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_col_unchecked(self, col: usize) -> (Self, Self) { + debug_assert!(col <= self.ncols()); + + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + + let nrows = self.nrows(); + let ncols = self.ncols(); + + unsafe { + let bot_left = self.overflowing_ptr_at(0, 0); + let bot_right = self.overflowing_ptr_at(0, col); + + ( + Self::__from_raw_parts(bot_left, nrows, col, row_stride, col_stride), + Self::__from_raw_parts(bot_right, nrows, ncols - col, row_stride, col_stride), + ) + } + } + + /// Splits the matrix vertically at the given row into two parts and returns an array of + /// each submatrix, in the following order: + /// * left. + /// * right. + /// + /// # Panics + /// The function panics if the following condition is violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn split_at_col(self, col: usize) -> (Self, Self) { + assert!(col <= self.ncols()); + unsafe { self.split_at_col_unchecked(col) } + } + + /// Returns references to the element at the given indices, or submatrices if either `row` + /// or `col` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_unchecked( + self, + row: RowRange, + col: ColRange, + ) -> >::Target + where + Self: MatIndex, + { + >::get_unchecked(self, row, col) + } + + /// Returns references to the element at the given indices, or submatrices if either `row` + /// or `col` is a range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row` must be contained in `[0, self.nrows())`. + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub fn get( + self, + row: RowRange, + col: ColRange, + ) -> >::Target + where + Self: MatIndex, + { + >::get(self, row, col) + } + + /// Reads the value of the element at the given indices. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, row: usize, col: usize) -> E { + E::faer_from_units(E::faer_map( + self.get_unchecked(row, col), + #[inline(always)] + |ptr| *ptr, + )) + } + + /// Reads the value of the element at the given indices, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row < self.nrows()`. + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, row: usize, col: usize) -> E { + E::faer_from_units(E::faer_map( + self.get(row, col), + #[inline(always)] + |ptr| *ptr, + )) + } + + /// Returns a view over the transpose of `self`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let transpose = view.transpose(); + /// + /// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; + /// assert_eq!(expected.as_ref(), transpose); + /// ``` + #[inline(always)] + #[must_use] + pub fn transpose(self) -> Self { + unsafe { + Self::__from_raw_parts( + self.as_ptr(), + self.ncols(), + self.nrows(), + self.col_stride(), + self.row_stride(), + ) + } + } + + /// Returns a view over the conjugate of `self`. + #[inline(always)] + #[must_use] + pub fn conjugate(self) -> MatRef<'a, E::Conj> + where + E: Conjugate, + { + unsafe { + // SAFETY: Conjugate requires that E::Unit and E::Conj::Unit have the same layout + // and that GroupCopyFor == E::Conj::GroupCopy + MatRef::<'_, E::Conj>::__from_raw_parts( + transmute_unchecked::< + GroupFor>, + GroupFor>, + >(self.as_ptr()), + self.nrows(), + self.ncols(), + self.row_stride(), + self.col_stride(), + ) + } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline(always)] + pub fn adjoint(self) -> MatRef<'a, E::Conj> + where + E: Conjugate, + { + self.transpose().conjugate() + } + + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. + #[inline(always)] + pub fn canonicalize(self) -> (MatRef<'a, E::Canonical>, Conj) + where + E: Conjugate, + { + ( + unsafe { + // SAFETY: see Self::conjugate + MatRef::<'_, E::Canonical>::__from_raw_parts( + transmute_unchecked::< + GroupFor, + GroupFor>, + >(self.as_ptr()), + self.nrows(), + self.ncols(), + self.row_stride(), + self.col_stride(), + ) + }, + if coe::is_same::() { + Conj::No + } else { + Conj::Yes + }, + ) + } + + /// Returns a view over the `self`, with the rows in reversed order. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let reversed_rows = view.reverse_rows(); + /// + /// let expected = mat![[4.0, 5.0, 6.0], [1.0, 2.0, 3.0]]; + /// assert_eq!(expected.as_ref(), reversed_rows); + /// ``` + #[inline(always)] + #[must_use] + pub fn reverse_rows(self) -> Self { + let nrows = self.nrows(); + let ncols = self.ncols(); + let row_stride = self.row_stride().wrapping_neg(); + let col_stride = self.col_stride(); + + let ptr = unsafe { self.unchecked_ptr_at(nrows.saturating_sub(1), 0) }; + unsafe { Self::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } + } + + /// Returns a view over the `self`, with the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let reversed_cols = view.reverse_cols(); + /// + /// let expected = mat![[3.0, 2.0, 1.0], [6.0, 5.0, 4.0]]; + /// assert_eq!(expected.as_ref(), reversed_cols); + /// ``` + #[inline(always)] + #[must_use] + pub fn reverse_cols(self) -> Self { + let nrows = self.nrows(); + let ncols = self.ncols(); + let row_stride = self.row_stride(); + let col_stride = self.col_stride().wrapping_neg(); + let ptr = unsafe { self.unchecked_ptr_at(0, ncols.saturating_sub(1)) }; + unsafe { Self::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } + } + + /// Returns a view over the `self`, with the rows and the columns in reversed order. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let matrix = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; + /// let view = matrix.as_ref(); + /// let reversed = view.reverse_rows_and_cols(); + /// + /// let expected = mat![[6.0, 5.0, 4.0], [3.0, 2.0, 1.0]]; + /// assert_eq!(expected.as_ref(), reversed); + /// ``` + #[inline(always)] + #[must_use] + pub fn reverse_rows_and_cols(self) -> Self { + let nrows = self.nrows(); + let ncols = self.ncols(); + let row_stride = -self.row_stride(); + let col_stride = -self.col_stride(); + + let ptr = + unsafe { self.unchecked_ptr_at(nrows.saturating_sub(1), ncols.saturating_sub(1)) }; + unsafe { Self::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) } + } + + /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with + /// dimensions `(nrows, ncols)`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `col_start <= self.ncols()`. + /// * `nrows <= self.nrows() - row_start`. + /// * `ncols <= self.ncols() - col_start`. + #[track_caller] + #[inline(always)] + pub unsafe fn submatrix_unchecked( + self, + row_start: usize, + col_start: usize, + nrows: usize, + ncols: usize, + ) -> Self { + debug_assert!(all(row_start <= self.nrows(), col_start <= self.ncols())); + debug_assert!(all( + nrows <= self.nrows() - row_start, + ncols <= self.ncols() - col_start, + )); + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + + unsafe { + Self::__from_raw_parts( + self.overflowing_ptr_at(row_start, col_start), + nrows, + ncols, + row_stride, + col_stride, + ) + } + } + + /// Returns a view over the submatrix starting at indices `(row_start, col_start)`, and with + /// dimensions `(nrows, ncols)`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `col_start <= self.ncols()`. + /// * `nrows <= self.nrows() - row_start`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_ref(); + /// let submatrix = view.submatrix(2, 1, 2, 2); + /// + /// let expected = mat![[7.0, 11.0], [8.0, 12.0f64]]; + /// assert_eq!(expected.as_ref(), submatrix); + /// ``` + #[track_caller] + #[inline(always)] + pub fn submatrix(self, row_start: usize, col_start: usize, nrows: usize, ncols: usize) -> Self { + assert!(all(row_start <= self.nrows(), col_start <= self.ncols())); + assert!(all( + nrows <= self.nrows() - row_start, + ncols <= self.ncols() - col_start, + )); + unsafe { self.submatrix_unchecked(row_start, col_start, nrows, ncols) } + } + + /// Returns a view over the submatrix starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + #[track_caller] + #[inline(always)] + pub unsafe fn subrows_unchecked(self, row_start: usize, nrows: usize) -> Self { + debug_assert!(row_start <= self.nrows()); + debug_assert!(nrows <= self.nrows() - row_start); + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + unsafe { + Self::__from_raw_parts( + self.overflowing_ptr_at(row_start, 0), + nrows, + self.ncols(), + row_stride, + col_stride, + ) + } + } + + /// Returns a view over the submatrix starting at row `row_start`, and with number of rows + /// `nrows`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_start <= self.nrows()`. + /// * `nrows <= self.nrows() - row_start`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_ref(); + /// let subrows = view.subrows(1, 2); + /// + /// let expected = mat![[2.0, 6.0, 10.0], [3.0, 7.0, 11.0],]; + /// assert_eq!(expected.as_ref(), subrows); + /// ``` + #[track_caller] + #[inline(always)] + pub fn subrows(self, row_start: usize, nrows: usize) -> Self { + assert!(row_start <= self.nrows()); + assert!(nrows <= self.nrows() - row_start); + unsafe { self.subrows_unchecked(row_start, nrows) } + } + + /// Returns a view over the submatrix starting at column `col_start`, and with number of + /// columns `ncols`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + #[track_caller] + #[inline(always)] + pub unsafe fn subcols_unchecked(self, col_start: usize, ncols: usize) -> Self { + debug_assert!(col_start <= self.ncols()); + debug_assert!(ncols <= self.ncols() - col_start); + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + unsafe { + Self::__from_raw_parts( + self.overflowing_ptr_at(0, col_start), + self.nrows(), + ncols, + row_stride, + col_stride, + ) + } + } + + /// Returns a view over the submatrix starting at column `col_start`, and with number of + /// columns `ncols`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + /// + /// # Example + /// ``` + /// use faer::mat; + /// + /// let matrix = mat![ + /// [1.0, 5.0, 9.0], + /// [2.0, 6.0, 10.0], + /// [3.0, 7.0, 11.0], + /// [4.0, 8.0, 12.0f64], + /// ]; + /// + /// let view = matrix.as_ref(); + /// let subcols = view.subcols(2, 1); + /// + /// let expected = mat![[9.0], [10.0], [11.0], [12.0f64]]; + /// assert_eq!(expected.as_ref(), subcols); + /// ``` + #[track_caller] + #[inline(always)] + pub fn subcols(self, col_start: usize, ncols: usize) -> Self { + debug_assert!(col_start <= self.ncols()); + debug_assert!(ncols <= self.ncols() - col_start); + unsafe { self.subcols_unchecked(col_start, ncols) } + } + + /// Returns a view over the row at the given index. + /// + /// # Safety + /// The function panics if any of the following conditions are violated: + /// * `row_idx < self.nrows()`. + #[track_caller] + #[inline(always)] + pub unsafe fn row_unchecked(self, row_idx: usize) -> RowRef<'a, E> { + debug_assert!(row_idx < self.nrows()); + unsafe { + crate::row::from_raw_parts( + self.overflowing_ptr_at(row_idx, 0), + self.ncols(), + self.col_stride(), + ) + } + } + + /// Returns a view over the row at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `row_idx < self.nrows()`. + #[track_caller] + #[inline(always)] + pub fn row(self, row_idx: usize) -> RowRef<'a, E> { + assert!(row_idx < self.nrows()); + unsafe { self.row_unchecked(row_idx) } + } + + /// Returns a view over the column at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col_idx < self.ncols()`. + #[track_caller] + #[inline(always)] + pub unsafe fn col_unchecked(self, col_idx: usize) -> ColRef<'a, E> { + debug_assert!(col_idx < self.ncols()); + unsafe { + crate::col::from_raw_parts( + self.overflowing_ptr_at(0, col_idx), + self.nrows(), + self.row_stride(), + ) + } + } + + /// Returns a view over the column at the given index. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_idx < self.ncols()`. + #[track_caller] + #[inline(always)] + pub fn col(self, col_idx: usize) -> ColRef<'a, E> { + assert!(col_idx < self.ncols()); + unsafe { self.col_unchecked(col_idx) } + } + + /// Given a matrix with a single column, returns an object that interprets + /// the column as a diagonal matrix, whoes diagonal elements are values in the column. + #[track_caller] + #[inline(always)] + pub fn column_vector_as_diagonal(self) -> DiagRef<'a, E> { + assert!(self.ncols() == 1); + DiagRef { inner: self.col(0) } + } + + /// Returns the diagonal of the matrix. + #[inline(always)] + pub fn diagonal(self) -> DiagRef<'a, E> { + let size = self.nrows().min(self.ncols()); + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + unsafe { + DiagRef { + inner: crate::col::from_raw_parts(self.as_ptr(), size, row_stride + col_stride), + } + } + } + + /// Returns an owning [`Mat`] of the data. + #[inline] + pub fn to_owned(&self) -> Mat + where + E: Conjugate, + { + let mut mat = Mat::new(); + mat.resize_with( + self.nrows(), + self.ncols(), + #[inline(always)] + |row, col| unsafe { self.read_unchecked(row, col).canonicalize() }, + ); + mat + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + let mut found_nan = false; + zipped!(*self).for_each(|unzipped!(x)| { + found_nan |= x.read().faer_is_nan(); + }); + found_nan + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + let mut all_finite = true; + zipped!(*self).for_each(|unzipped!(x)| { + all_finite &= x.read().faer_is_finite(); + }); + all_finite + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + crate::linalg::reductions::norm_max::norm_max((*self).rb()) + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + crate::linalg::reductions::norm_l2::norm_l2((*self).rb()) + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + crate::linalg::reductions::sum::sum((*self).rb()) + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + let lhs = (*self).rb(); + let rhs = rhs.as_2d_ref(); + let mut dst = Mat::new(); + dst.resize_with( + lhs.nrows() * rhs.nrows(), + lhs.ncols() * rhs.ncols(), + |_, _| E::zeroed(), + ); + crate::linalg::kron(dst.as_mut(), lhs, rhs); + dst + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> MatRef<'_, E> { + *self + } + + #[doc(hidden)] + #[inline(always)] + pub unsafe fn const_cast(self) -> MatMut<'a, E> { + MatMut { + inner: self.inner, + __marker: PhantomData, + } + } + + /// Returns an iterator that provides successive chunks of the columns of this matrix, with + /// each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have + /// `chunk_size` columns. + #[inline] + #[track_caller] + pub fn col_chunks( + self, + chunk_size: usize, + ) -> impl 'a + DoubleEndedIterator> { + assert!(chunk_size > 0); + let chunk_count = self.ncols().msrv_div_ceil(chunk_size); + (0..chunk_count).map(move |chunk_idx| { + let pos = chunk_size * chunk_idx; + self.subcols(pos, Ord::min(chunk_size, self.ncols() - pos)) + }) + } + + /// Returns an iterator that provides successive chunks of the rows of this matrix, with + /// each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + #[inline] + #[track_caller] + pub fn row_chunks( + self, + chunk_size: usize, + ) -> impl 'a + DoubleEndedIterator> { + self.transpose() + .col_chunks(chunk_size) + .map(|chunk| chunk.transpose()) + } + + /// Returns a parallel iterator that provides successive chunks of the columns of this + /// matrix, with each having at most `chunk_size` columns. + /// + /// If the number of columns is a multiple of `chunk_size`, then all chunks have + /// `chunk_size` columns. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_col_chunks( + self, + chunk_size: usize, + ) -> impl 'a + rayon::iter::IndexedParallelIterator> { + use rayon::prelude::*; + + assert!(chunk_size > 0); + let chunk_count = self.ncols().msrv_div_ceil(chunk_size); + (0..chunk_count).into_par_iter().map(move |chunk_idx| { + let pos = chunk_size * chunk_idx; + self.subcols(pos, Ord::min(chunk_size, self.ncols() - pos)) + }) + } + + /// Returns a parallel iterator that provides successive chunks of the rows of this matrix, + /// with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + pub fn par_row_chunks( + self, + chunk_size: usize, + ) -> impl 'a + rayon::iter::IndexedParallelIterator> { + use rayon::prelude::*; + + self.transpose() + .par_col_chunks(chunk_size) + .map(|chunk| chunk.transpose()) + } + + /// Returns a parallel iterator that provides successive chunks of the rows of this matrix, + /// with each having at most `chunk_size` rows. + /// + /// If the number of rows is a multiple of `chunk_size`, then all chunks have `chunk_size` + /// rows. + /// + /// Only available with the `rayon` feature. + #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] + #[inline] + #[track_caller] + #[deprecated = "replaced by `MatRef::par_row_chunks`"] + pub fn into_par_row_chunks( + self, + chunk_size: usize, + ) -> impl 'a + rayon::iter::IndexedParallelIterator> { + self.par_row_chunks(chunk_size) + } +} + +impl<'a, E: RealField> MatRef<'a, num_complex::Complex> { + /// Returns the real and imaginary components of `self`. + #[inline(always)] + pub fn real_imag(self) -> num_complex::Complex> { + let row_stride = self.row_stride(); + let col_stride = self.col_stride(); + let nrows = self.nrows(); + let ncols = self.ncols(); + let num_complex::Complex { re, im } = self.as_ptr(); + unsafe { + num_complex::Complex { + re: super::from_raw_parts(re, nrows, ncols, row_stride, col_stride), + im: super::from_raw_parts(im, nrows, ncols, row_stride, col_stride), + } + } + } +} + +impl AsMatRef for MatRef<'_, E> { + #[inline] + fn as_mat_ref(&self) -> MatRef<'_, E> { + *self + } +} +impl AsMatRef for &'_ MatRef<'_, E> { + #[inline] + fn as_mat_ref(&self) -> MatRef<'_, E> { + **self + } +} + +impl As2D for &'_ MatRef<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + **self + } +} + +impl As2D for MatRef<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + *self + } +} + +/// Creates a `MatRef` from pointers to the matrix data, dimensions, and strides. +/// +/// The row (resp. column) stride is the offset from the memory address of a given matrix +/// element at indices `(row: i, col: j)`, to the memory address of the matrix element at +/// indices `(row: i + 1, col: 0)` (resp. `(row: 0, col: i + 1)`). This offset is specified in +/// number of elements, not in bytes. +/// +/// # Safety +/// The behavior is undefined if any of the following conditions are violated: +/// * For each matrix unit, the entire memory region addressed by the matrix must be contained +/// within a single allocation, accessible in its entirety by the corresponding pointer in +/// `ptr`. +/// * For each matrix unit, the corresponding pointer must be properly aligned, +/// even for a zero-sized matrix. +/// * The values accessible by the matrix must be initialized at some point before they are +/// read, or references to them are formed. +/// * No mutable aliasing is allowed. In other words, none of the elements accessible by any +/// matrix unit may be accessed for writes by any other means for the duration of the lifetime +/// `'a`. +/// +/// # Example +/// +/// ``` +/// use faer::mat; +/// +/// // row major matrix with 2 rows, 3 columns, with a column at the end that we want to skip. +/// // the row stride is the pointer offset from the address of 1.0 to the address of 4.0, +/// // which is 4. +/// // the column stride is the pointer offset from the address of 1.0 to the address of 2.0, +/// // which is 1. +/// let data = [[1.0, 2.0, 3.0, f64::NAN], [4.0, 5.0, 6.0, f64::NAN]]; +/// let matrix = unsafe { mat::from_raw_parts::(data.as_ptr() as *const f64, 2, 3, 4, 1) }; +/// +/// let expected = mat![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; +/// assert_eq!(expected.as_ref(), matrix); +/// ``` +#[inline(always)] +pub unsafe fn from_raw_parts<'a, E: Entity>( + ptr: GroupFor, + nrows: usize, + ncols: usize, + row_stride: isize, + col_stride: isize, +) -> MatRef<'a, E> { + MatRef::__from_raw_parts(ptr, nrows, ncols, row_stride, col_stride) +} + +/// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a column-major format, so that the first chunk of `nrows` +/// values from the slices goes in the first column of the matrix, the second chunk of `nrows` +/// values goes in the second column, and so on. +/// +/// # Panics +/// The function panics if any of the following conditions are violated: +/// * `nrows * ncols == slice.len()` +/// +/// # Example +/// ``` +/// use faer::mat; +/// +/// let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; +/// let view = mat::from_column_major_slice::(&slice, 3, 2); +/// +/// let expected = mat![[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]; +/// assert_eq!(expected, view); +/// ``` +#[track_caller] +#[inline(always)] +pub fn from_column_major_slice( + slice: GroupFor, + nrows: usize, + ncols: usize, +) -> MatRef<'_, E> { + from_slice_assert( + nrows, + ncols, + SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(), + ); + + unsafe { + from_raw_parts( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_ptr(), + ), + nrows, + ncols, + 1, + nrows as isize, + ) + } +} + +/// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a row-major format, so that the first chunk of `ncols` +/// values from the slices goes in the first column of the matrix, the second chunk of `ncols` +/// values goes in the second column, and so on. +/// +/// # Panics +/// The function panics if any of the following conditions are violated: +/// * `nrows * ncols == slice.len()` +/// +/// # Example +/// ``` +/// use faer::mat; +/// +/// let slice = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0_f64]; +/// let view = mat::from_row_major_slice::(&slice, 3, 2); +/// +/// let expected = mat![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]; +/// assert_eq!(expected, view); +/// ``` +#[track_caller] +#[inline(always)] +pub fn from_row_major_slice( + slice: GroupFor, + nrows: usize, + ncols: usize, +) -> MatRef<'_, E> { + from_column_major_slice(slice, ncols, nrows).transpose() +} + +/// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a column-major format, where the beginnings of two consecutive +/// columns are separated by `col_stride` elements. +#[track_caller] +pub fn from_column_major_slice_with_stride( + slice: GroupFor, + nrows: usize, + ncols: usize, + col_stride: usize, +) -> MatRef<'_, E> { + from_strided_column_major_slice_assert( + nrows, + ncols, + col_stride, + SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(), + ); + + unsafe { + from_raw_parts( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_ptr(), + ), + nrows, + ncols, + 1, + col_stride as isize, + ) + } +} + +/// Creates a `MatRef` from slice views over the matrix data, and the matrix dimensions. +/// The data is interpreted in a row-major format, where the beginnings of two consecutive +/// rows are separated by `row_stride` elements. +#[track_caller] +pub fn from_row_major_slice_with_stride( + slice: GroupFor, + nrows: usize, + ncols: usize, + row_stride: usize, +) -> MatRef<'_, E> { + from_column_major_slice_with_stride::(slice, ncols, nrows, row_stride).transpose() +} + +impl<'a, E: Entity> core::fmt::Debug for MatRef<'a, E> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + struct DebugRow<'a, T: Entity>(MatRef<'a, T>); + + impl<'a, T: Entity> core::fmt::Debug for DebugRow<'a, T> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mut j = 0; + f.debug_list() + .entries(core::iter::from_fn(|| { + let ret = if j < self.0.ncols() { + Some(T::faer_from_units(T::faer_deref(self.0.get(0, j)))) + } else { + None + }; + j += 1; + ret + })) + .finish() + } + } + + writeln!(f, "[")?; + for i in 0..self.nrows() { + let row = self.subrows(i, 1); + DebugRow(row).fmt(f)?; + f.write_str(",\n")?; + } + write!(f, "]") + } +} + +impl core::ops::Index<(usize, usize)> for MatRef<'_, E> { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &E { + self.get(row, col) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for MatRef<'_, E> { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Dense(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::DenseAccess for MatRef<'_, E> { + #[inline] + fn fetch_single(&self, row: usize, col: usize) -> E { + self.read(row, col) + } +} diff --git a/src/mat/mod.rs b/src/mat/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..a159fa2b8b9ba8f765b6191a391698a30e70977f --- /dev/null +++ b/src/mat/mod.rs @@ -0,0 +1,165 @@ +use crate::{assert, col::ColMut, row::RowMut, utils::slice::*, Conj}; +use coe::Coerce; +use core::{marker::PhantomData, ptr::NonNull}; +use faer_entity::*; +use reborrow::*; + +#[repr(C)] +struct MatImpl { + ptr: GroupCopyFor>, + nrows: usize, + ncols: usize, + row_stride: isize, + col_stride: isize, +} +#[repr(C)] +struct MatOwnImpl { + ptr: GroupCopyFor>, + nrows: usize, + ncols: usize, +} + +unsafe impl Sync for MatImpl {} +unsafe impl Send for MatImpl {} +unsafe impl Sync for MatOwnImpl {} +unsafe impl Send for MatOwnImpl {} + +impl Copy for MatImpl {} +impl Clone for MatImpl { + #[inline(always)] + fn clone(&self) -> Self { + *self + } +} + +/// Represents a type that can be used to slice a matrix, such as an index or a range of indices. +pub trait MatIndex: crate::seal::Seal + Sized { + /// Resulting type of the indexing operation. + type Target; + + /// Index the matrix at `(row, col)`, without bound checks. + #[allow(clippy::missing_safety_doc)] + unsafe fn get_unchecked(this: Self, row: RowRange, col: ColRange) -> Self::Target { + >::get(this, row, col) + } + /// Index the matrix at `(row, col)`. + fn get(this: Self, row: RowRange, col: ColRange) -> Self::Target; +} + +/// Trait for types that can be converted to a matrix view. +/// +/// This trait is implemented for types of the matrix family, like [`Mat`], +/// [`MatRef`], and [`MatMut`], but not for types like [`Col`], [`Row`], or +/// their families. For a more general trait, see [`As2D`]. +pub trait AsMatRef { + /// Convert to a matrix view. + fn as_mat_ref(&self) -> MatRef<'_, E>; +} +/// Trait for types that can be converted to a mutable matrix view. +/// +/// This trait is implemented for types of the matrix family, like [`Mat`], +/// [`MatRef`], and [`MatMut`], but not for types like [`Col`], [`Row`], or +/// their families. For a more general trait, see [`As2D`]. +pub trait AsMatMut { + /// Convert to a mutable matrix view. + fn as_mat_mut(&mut self) -> MatMut<'_, E>; +} + +/// Trait for types that can be converted to a 2D matrix view. +/// +/// This trait is implemented for any type that can be represented as a +/// 2D matrix view, like [`Mat`], [`Row`], [`Col`], and their respective +/// references and mutable references. For a trait specific to the matrix +/// family, see [`AsMatRef`] or [`AsMatMut`]. +pub trait As2D { + /// Convert to a 2D matrix view. + fn as_2d_ref(&self) -> MatRef<'_, E>; +} +/// Trait for types that can be converted to a mutable 2D matrix view. +/// +/// This trait is implemented for any type that can be represented as a +/// 2D matrix view, like [`Mat`], [`Row`], [`Col`], and their respective +/// references and mutable references. For a trait specific to the matrix +/// family, see [`AsMatRef`] or [`AsMatMut`]. +pub trait As2DMut { + /// Convert to a mutable 2D matrix view. + fn as_2d_mut(&mut self) -> MatMut<'_, E>; +} + +impl<'a, FromE: Entity, ToE: Entity> Coerce> for MatRef<'a, FromE> { + #[inline(always)] + fn coerce(self) -> MatRef<'a, ToE> { + assert!(coe::is_same::()); + unsafe { transmute_unchecked::, MatRef<'a, ToE>>(self) } + } +} +impl<'a, FromE: Entity, ToE: Entity> Coerce> for MatMut<'a, FromE> { + #[inline(always)] + fn coerce(self) -> MatMut<'a, ToE> { + assert!(coe::is_same::()); + unsafe { transmute_unchecked::, MatMut<'a, ToE>>(self) } + } +} + +mod mat_index; + +mod matref; +pub use matref::{ + from_column_major_slice, from_column_major_slice_with_stride, from_raw_parts, + from_row_major_slice, from_row_major_slice_with_stride, MatRef, +}; + +mod matmut; +pub use matmut::{ + from_column_major_slice_mut, from_column_major_slice_with_stride_mut, from_raw_parts_mut, + from_row_major_slice_mut, from_row_major_slice_with_stride_mut, MatMut, +}; + +mod matown; +pub use matown::Mat; + +pub(crate) mod matalloc; + +#[track_caller] +#[inline] +fn from_slice_assert(nrows: usize, ncols: usize, len: usize) { + // we don't have to worry about size == usize::MAX == slice.len(), because the length of a + // slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case + // we don't care + let size = usize::checked_mul(nrows, ncols).unwrap_or(usize::MAX); + assert!(size == len); +} + +#[track_caller] +#[inline] +fn from_strided_column_major_slice_assert( + nrows: usize, + ncols: usize, + col_stride: usize, + len: usize, +) { + // we don't have to worry about size == usize::MAX == slice.len(), because the length of a + // slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case + // we don't care + let last = usize::checked_mul(col_stride, ncols - 1) + .and_then(|last_col| last_col.checked_add(nrows - 1)) + .unwrap_or(usize::MAX); + assert!(last < len); +} + +#[track_caller] +#[inline] +fn from_strided_column_major_slice_mut_assert( + nrows: usize, + ncols: usize, + col_stride: usize, + len: usize, +) { + // we don't have to worry about size == usize::MAX == slice.len(), because the length of a + // slice can never exceed isize::MAX in bytes, unless the type is zero sized, in which case + // we don't care + let last = usize::checked_mul(col_stride, ncols - 1) + .and_then(|last_col| last_col.checked_add(nrows - 1)) + .unwrap_or(usize::MAX); + assert!(all(col_stride >= nrows, last < len)); +} diff --git a/src/perm/mod.rs b/src/perm/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..f98a2cc5a66e318781a4d4d4ab2d1855a29970d1 --- /dev/null +++ b/src/perm/mod.rs @@ -0,0 +1,319 @@ +use crate::{assert, col::*, linalg::temp_mat_uninit, mat::*, row::*, utils::constrained, *}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; +use reborrow::*; + +/// Swaps the values in the columns `a` and `b`. +/// +/// # Panics +/// +/// Panics if `a` and `b` don't have the same number of columns. +/// +/// # Example +/// +/// ``` +/// use faer::{mat, perm::swap_cols}; +/// +/// let mut m = mat![ +/// [1.0, 2.0, 3.0], +/// [4.0, 5.0, 6.0], +/// [7.0, 8.0, 9.0], +/// [10.0, 14.0, 12.0], +/// ]; +/// +/// let (a, b) = m.as_mut().two_cols_mut(0, 2); +/// swap_cols(a, b); +/// +/// let swapped = mat![ +/// [3.0, 2.0, 1.0], +/// [6.0, 5.0, 4.0], +/// [9.0, 8.0, 7.0], +/// [12.0, 14.0, 10.0], +/// ]; +/// +/// assert_eq!(m, swapped); +/// ``` +#[track_caller] +#[inline] +pub fn swap_cols(a: ColMut<'_, E>, b: ColMut<'_, E>) { + zipped!(a, b).for_each(|unzipped!(mut a, mut b)| { + let (a_read, b_read) = (a.read(), b.read()); + a.write(b_read); + b.write(a_read); + }); +} + +/// Swaps the values in the rows `a` and `b`. +/// +/// # Panics +/// +/// Panics if `a` and `b` don't have the same number of columns. +/// +/// # Example +/// +/// ``` +/// use faer::{mat, perm::swap_rows}; +/// +/// let mut m = mat![ +/// [1.0, 2.0, 3.0], +/// [4.0, 5.0, 6.0], +/// [7.0, 8.0, 9.0], +/// [10.0, 14.0, 12.0], +/// ]; +/// +/// let (a, b) = m.as_mut().two_rows_mut(0, 2); +/// swap_rows(a, b); +/// +/// let swapped = mat![ +/// [7.0, 8.0, 9.0], +/// [4.0, 5.0, 6.0], +/// [1.0, 2.0, 3.0], +/// [10.0, 14.0, 12.0], +/// ]; +/// +/// assert_eq!(m, swapped); +/// ``` +#[track_caller] +#[inline] +pub fn swap_rows(a: RowMut<'_, E>, b: RowMut<'_, E>) { + swap_cols(a.transpose_mut(), b.transpose_mut()) +} + +/// Swaps the two rows at indices `a` and `b` in the given matrix. +/// +/// # Panics +/// +/// Panics if either `a` or `b` is out of bounds. +/// +/// # Example +/// +/// ``` +/// use faer::{mat, perm::swap_rows_idx}; +/// +/// let mut m = mat![ +/// [1.0, 2.0, 3.0], +/// [4.0, 5.0, 6.0], +/// [7.0, 8.0, 9.0], +/// [10.0, 14.0, 12.0], +/// ]; +/// +/// swap_rows_idx(m.as_mut(), 0, 2); +/// +/// let swapped = mat![ +/// [7.0, 8.0, 9.0], +/// [4.0, 5.0, 6.0], +/// [1.0, 2.0, 3.0], +/// [10.0, 14.0, 12.0], +/// ]; +/// +/// assert_eq!(m, swapped); +/// ``` +#[track_caller] +#[inline] +pub fn swap_rows_idx(mat: MatMut<'_, E>, a: usize, b: usize) { + if a != b { + let (a, b) = mat.two_rows_mut(a, b); + swap_rows(a, b); + } +} + +/// Swaps the two columns at indices `a` and `b` in the given matrix. +/// +/// # Panics +/// +/// Panics if either `a` or `b` is out of bounds. +/// +/// # Example +/// +/// ``` +/// use faer::{mat, perm::swap_cols_idx}; +/// +/// let mut m = mat![ +/// [1.0, 2.0, 3.0], +/// [4.0, 5.0, 6.0], +/// [7.0, 8.0, 9.0], +/// [10.0, 14.0, 12.0], +/// ]; +/// +/// swap_cols_idx(m.as_mut(), 0, 2); +/// +/// let swapped = mat![ +/// [3.0, 2.0, 1.0], +/// [6.0, 5.0, 4.0], +/// [9.0, 8.0, 7.0], +/// [12.0, 14.0, 10.0], +/// ]; +/// +/// assert_eq!(m, swapped); +/// ``` +#[track_caller] +#[inline] +pub fn swap_cols_idx(mat: MatMut<'_, E>, a: usize, b: usize) { + if a != b { + let (a, b) = mat.two_cols_mut(a, b); + swap_cols(a, b); + } +} + +mod permown; +mod permref; + +pub use permown::Perm; +pub use permref::PermRef; + +use self::linalg::temp_mat_req; + +/// Computes a permutation of the columns of the source matrix using the given permutation, and +/// stores the result in the destination matrix. +/// +/// # Panics +/// +/// - Panics if the matrices do not have the same shape. +/// - Panics if the size of the permutation doesn't match the number of columns of the matrices. +#[inline] +#[track_caller] +pub fn permute_cols( + dst: MatMut<'_, E>, + src: MatRef<'_, E>, + perm_indices: PermRef<'_, I>, +) { + assert!(all( + src.nrows() == dst.nrows(), + src.ncols() == dst.ncols(), + perm_indices.arrays().0.len() == src.ncols(), + )); + + permute_rows( + dst.transpose_mut(), + src.transpose(), + perm_indices.canonicalized(), + ); +} + +/// Computes a permutation of the rows of the source matrix using the given permutation, and +/// stores the result in the destination matrix. +/// +/// # Panics +/// +/// - Panics if the matrices do not have the same shape. +/// - Panics if the size of the permutation doesn't match the number of rows of the matrices. +#[inline] +#[track_caller] +pub fn permute_rows( + dst: MatMut<'_, E>, + src: MatRef<'_, E>, + perm_indices: PermRef<'_, I>, +) { + #[track_caller] + fn implementation( + dst: MatMut<'_, E>, + src: MatRef<'_, E>, + perm_indices: PermRef<'_, I>, + ) { + assert!(all( + src.nrows() == dst.nrows(), + src.ncols() == dst.ncols(), + perm_indices.len() == src.nrows(), + )); + + constrained::Size::with2(src.nrows(), src.ncols(), |m, n| { + let mut dst = constrained::mat::MatMut::new(dst, m, n); + let src = constrained::mat::MatRef::new(src, m, n); + let perm = constrained::perm::PermRef::new(perm_indices, m).arrays().0; + + if dst.rb().into_inner().row_stride().unsigned_abs() + < dst.rb().into_inner().col_stride().unsigned_abs() + { + for j in n.indices() { + for i in m.indices() { + dst.rb_mut().write(i, j, src.read(perm[i].zx(), j)); + } + } + } else { + for i in m.indices() { + let src_i = src.into_inner().row(perm[i].zx().into_inner()); + let mut dst_i = dst.rb_mut().into_inner().row_mut(i.into_inner()); + + dst_i.copy_from(src_i); + } + } + }); + } + + implementation(dst, src, perm_indices.canonicalized()) +} + +/// Computes the size and alignment of required workspace for applying a row permutation to a +/// matrix in place. +pub fn permute_rows_in_place_req( + nrows: usize, + ncols: usize, +) -> Result { + temp_mat_req::(nrows, ncols) +} + +/// Computes the size and alignment of required workspace for applying a column permutation to a +/// matrix in place. +pub fn permute_cols_in_place_req( + nrows: usize, + ncols: usize, +) -> Result { + temp_mat_req::(nrows, ncols) +} + +/// Computes a permutation of the rows of the matrix using the given permutation, and +/// stores the result in the same matrix. +/// +/// # Panics +/// +/// - Panics if the size of the permutation doesn't match the number of rows of the matrix. +#[inline] +#[track_caller] +pub fn permute_rows_in_place( + matrix: MatMut<'_, E>, + perm_indices: PermRef<'_, I>, + stack: PodStack<'_>, +) { + #[inline] + #[track_caller] + fn implementation( + matrix: MatMut<'_, E>, + perm_indices: PermRef<'_, I>, + stack: PodStack<'_>, + ) { + let mut matrix = matrix; + let (mut tmp, _) = temp_mat_uninit::(matrix.nrows(), matrix.ncols(), stack); + tmp.rb_mut().copy_from(matrix.rb()); + permute_rows(matrix.rb_mut(), tmp.rb(), perm_indices); + } + + implementation(matrix, perm_indices.canonicalized(), stack) +} + +/// Computes a permutation of the columns of the matrix using the given permutation, and +/// stores the result in the same matrix. +/// +/// # Panics +/// +/// - Panics if the size of the permutation doesn't match the number of columns of the matrix. +#[inline] +#[track_caller] +pub fn permute_cols_in_place( + matrix: MatMut<'_, E>, + perm_indices: PermRef<'_, I>, + stack: PodStack<'_>, +) { + #[inline] + #[track_caller] + fn implementation( + matrix: MatMut<'_, E>, + perm_indices: PermRef<'_, I>, + stack: PodStack<'_>, + ) { + let mut matrix = matrix; + let (mut tmp, _) = temp_mat_uninit::(matrix.nrows(), matrix.ncols(), stack); + tmp.rb_mut().copy_from(matrix.rb()); + permute_cols(matrix.rb_mut(), tmp.rb(), perm_indices); + } + + implementation(matrix, perm_indices.canonicalized(), stack) +} diff --git a/src/perm/permown.rs b/src/perm/permown.rs new file mode 100644 index 0000000000000000000000000000000000000000..cca52cd929a1ee5edb8d4861d2a51a45f7710182 --- /dev/null +++ b/src/perm/permown.rs @@ -0,0 +1,75 @@ +use super::*; +use crate::assert; + +/// Permutation matrix. +#[derive(Debug, Clone)] +pub struct Perm { + pub(super) forward: alloc::boxed::Box<[I]>, + pub(super) inverse: alloc::boxed::Box<[I]>, +} + +impl Perm { + /// Convert `self` to a permutation view. + #[inline] + pub fn as_ref(&self) -> PermRef<'_, I> { + PermRef { + forward: &self.forward, + inverse: &self.inverse, + } + } + + /// Creates a new permutation, by checking the validity of the inputs. + /// + /// # Panics + /// + /// The function panics if any of the following conditions are violated: + /// `forward` and `inverse` must have the same length which must be less than or equal to + /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. + #[inline] + #[track_caller] + pub fn new_checked(forward: alloc::boxed::Box<[I]>, inverse: alloc::boxed::Box<[I]>) -> Self { + PermRef::<'_, I>::new_checked(&forward, &inverse); + Self { forward, inverse } + } + + /// Creates a new permutation reference, without checking the validity of the inputs. + /// + /// # Safety + /// + /// `forward` and `inverse` must have the same length which must be less than or equal to + /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. + #[inline] + #[track_caller] + pub unsafe fn new_unchecked( + forward: alloc::boxed::Box<[I]>, + inverse: alloc::boxed::Box<[I]>, + ) -> Self { + let n = forward.len(); + assert!(all( + forward.len() == inverse.len(), + n <= I::Signed::MAX.zx(), + )); + Self { forward, inverse } + } + + /// Returns the permutation as an array. + #[inline] + pub fn into_arrays(self) -> (alloc::boxed::Box<[I]>, alloc::boxed::Box<[I]>) { + (self.forward, self.inverse) + } + + /// Returns the dimension of the permutation. + #[inline] + pub fn len(&self) -> usize { + self.forward.len() + } + + /// Returns the inverse permutation. + #[inline] + pub fn into_inverse(self) -> Self { + Self { + forward: self.inverse, + inverse: self.forward, + } + } +} diff --git a/src/perm/permref.rs b/src/perm/permref.rs new file mode 100644 index 0000000000000000000000000000000000000000..fc2d9cf8e6be2d638aa1c124c1e79ff84a84e3f6 --- /dev/null +++ b/src/perm/permref.rs @@ -0,0 +1,138 @@ +use super::*; +use crate::assert; + +/// Immutable permutation matrix view. +#[derive(Debug)] +pub struct PermRef<'a, I: Index> { + pub(super) forward: &'a [I], + pub(super) inverse: &'a [I], +} + +impl Copy for PermRef<'_, I> {} +impl Clone for PermRef<'_, I> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'short, I: Index> Reborrow<'short> for PermRef<'_, I> { + type Target = PermRef<'short, I>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} +impl<'short, I: Index> ReborrowMut<'short> for PermRef<'_, I> { + type Target = PermRef<'short, I>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} +impl<'a, I: Index> IntoConst for PermRef<'a, I> { + type Target = PermRef<'a, I>; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, I: Index> PermRef<'a, I> { + /// Convert `self` to a permutation view. + #[inline] + pub fn as_ref(&self) -> PermRef<'_, I> { + PermRef { + forward: self.forward, + inverse: self.inverse, + } + } + + /// Creates a new permutation, by checking the validity of the inputs. + /// + /// # Panics + /// + /// The function panics if any of the following conditions are violated: + /// `forward` and `inverse` must have the same length which must be less than or equal to + /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. + #[inline] + #[track_caller] + pub fn new_checked(forward: &'a [I], inverse: &'a [I]) -> Self { + #[track_caller] + fn check(forward: &[I], inverse: &[I]) { + let n = forward.len(); + assert!(all( + forward.len() == inverse.len(), + n <= I::Signed::MAX.zx() + )); + for (i, &p) in forward.iter().enumerate() { + let p = p.to_signed().zx(); + assert!(p < n); + assert!(inverse[p].to_signed().zx() == i); + } + } + + check(I::canonicalize(forward), I::canonicalize(inverse)); + Self { forward, inverse } + } + + /// Creates a new permutation reference, without checking the validity of the inputs. + /// + /// # Safety + /// + /// `forward` and `inverse` must have the same length which must be less than or equal to + /// `I::Signed::MAX`, be valid permutations, and be inverse permutations of each other. + #[inline] + #[track_caller] + pub unsafe fn new_unchecked(forward: &'a [I], inverse: &'a [I]) -> Self { + let n = forward.len(); + assert!(all( + forward.len() == inverse.len(), + n <= I::Signed::MAX.zx(), + )); + Self { forward, inverse } + } + + /// Returns the permutation as an array. + #[inline] + pub fn arrays(self) -> (&'a [I], &'a [I]) { + (self.forward, self.inverse) + } + + /// Returns the dimension of the permutation. + #[inline] + pub fn len(&self) -> usize { + self.forward.len() + } + + /// Returns the inverse permutation. + #[inline] + pub fn inverse(self) -> Self { + Self { + forward: self.inverse, + inverse: self.forward, + } + } + + /// Cast the permutation to the fixed width index type. + #[inline(always)] + pub fn canonicalized(self) -> PermRef<'a, I::FixedWidth> { + PermRef { + forward: I::canonicalize(self.forward), + inverse: I::canonicalize(self.inverse), + } + } + + /// Cast the permutation from the fixed width index type. + #[inline(always)] + pub fn uncanonicalized(self) -> PermRef<'a, J> { + assert!(core::mem::size_of::() == core::mem::size_of::()); + PermRef { + forward: bytemuck::cast_slice(self.forward), + inverse: bytemuck::cast_slice(self.inverse), + } + } +} diff --git a/src/row/mod.rs b/src/row/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..60ef9bfd3d06331f4141fe5c5f7948efee1a52ef --- /dev/null +++ b/src/row/mod.rs @@ -0,0 +1,44 @@ +use crate::{ + col::{VecImpl, VecOwnImpl}, + utils::slice::*, + Conj, +}; +use core::{marker::PhantomData, ptr::NonNull}; +use faer_entity::*; +use reborrow::*; + +/// Represents a type that can be used to slice a row, such as an index or a range of indices. +pub trait RowIndex: crate::seal::Seal + Sized { + /// Resulting type of the indexing operation. + type Target; + + /// Index the row at `col`, without bound checks. + #[allow(clippy::missing_safety_doc)] + unsafe fn get_unchecked(this: Self, col: ColRange) -> Self::Target { + >::get(this, col) + } + /// Index the row at `col`. + fn get(this: Self, col: ColRange) -> Self::Target; +} + +/// Trait for types that can be converted to a row view. +pub trait AsRowRef { + /// Convert to a row view. + fn as_row_ref(&self) -> RowRef<'_, E>; +} +/// Trait for types that can be converted to a mutable row view. +pub trait AsRowMut { + /// Convert to a mutable row view. + fn as_row_mut(&mut self) -> RowMut<'_, E>; +} + +mod row_index; + +mod rowref; +pub use rowref::{from_raw_parts, from_slice, RowRef}; + +mod rowmut; +pub use rowmut::{from_raw_parts_mut, from_slice_mut, RowMut}; + +mod rowown; +pub use rowown::Row; diff --git a/src/row/row_index.rs b/src/row/row_index.rs new file mode 100644 index 0000000000000000000000000000000000000000..8015f178e7da69550192ec1f54e49171bc805f5c --- /dev/null +++ b/src/row/row_index.rs @@ -0,0 +1,176 @@ +// RangeFull +// Range +// RangeInclusive +// RangeTo +// RangeToInclusive +// usize + +use super::*; +use core::ops::RangeFull; + +type Range = core::ops::Range; +type RangeInclusive = core::ops::RangeInclusive; +type RangeFrom = core::ops::RangeFrom; +type RangeTo = core::ops::RangeTo; +type RangeToInclusive = core::ops::RangeToInclusive; + +impl RowIndex for RowRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeFull) -> Self { + let _ = col; + this + } +} + +impl RowIndex for RowRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: Range) -> Self { + this.subcols(col.start, col.end - col.start) + } +} + +impl RowIndex for RowRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeInclusive) -> Self { + assert!(*col.end() != usize::MAX); + >::get(this, *col.start()..*col.end() + 1) + } +} + +impl RowIndex for RowRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeFrom) -> Self { + let ncols = this.ncols(); + >::get(this, col.start..ncols) + } +} +impl RowIndex for RowRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeTo) -> Self { + >::get(this, 0..col.end) + } +} + +impl RowIndex for RowRef<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeToInclusive) -> Self { + assert!(col.end != usize::MAX); + >::get(this, 0..col.end + 1) + } +} + +impl<'a, E: Entity> RowIndex for RowRef<'a, E> { + type Target = GroupFor; + + #[track_caller] + #[inline(always)] + unsafe fn get_unchecked(this: Self, col: usize) -> Self::Target { + unsafe { E::faer_map(this.ptr_inbounds_at(col), |ptr: *const _| &*ptr) } + } + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: usize) -> Self::Target { + assert!(col < this.ncols()); + unsafe { >::get_unchecked(this, col) } + } +} + +impl RowIndex for RowMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeFull) -> Self { + let _ = col; + this + } +} + +impl RowIndex for RowMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: Range) -> Self { + this.subcols_mut(col.start, col.end - col.start) + } +} + +impl RowIndex for RowMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeInclusive) -> Self { + assert!(*col.end() != usize::MAX); + >::get(this, *col.start()..*col.end() + 1) + } +} + +impl RowIndex for RowMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeFrom) -> Self { + let ncols = this.ncols(); + >::get(this, col.start..ncols) + } +} + +impl RowIndex for RowMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeTo) -> Self { + >::get(this, 0..col.end) + } +} + +impl RowIndex for RowMut<'_, E> { + type Target = Self; + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: RangeToInclusive) -> Self { + assert!(col.end != usize::MAX); + >::get(this, 0..col.end + 1) + } +} + +impl<'a, E: Entity> RowIndex for RowMut<'a, E> { + type Target = GroupFor; + + #[track_caller] + #[inline(always)] + unsafe fn get_unchecked(this: Self, col: usize) -> Self::Target { + unsafe { E::faer_map(this.ptr_inbounds_at_mut(col), |ptr: *mut _| &mut *ptr) } + } + + #[track_caller] + #[inline(always)] + fn get(this: Self, col: usize) -> Self::Target { + assert!(col < this.ncols()); + unsafe { >::get_unchecked(this, col) } + } +} diff --git a/src/row/rowmut.rs b/src/row/rowmut.rs new file mode 100644 index 0000000000000000000000000000000000000000..5fe92d0ee8dfb4954b747e2640e204e0fbb5b5ab --- /dev/null +++ b/src/row/rowmut.rs @@ -0,0 +1,574 @@ +use super::*; +use crate::{ + assert, + col::ColMut, + debug_assert, + mat::{self, As2D, As2DMut, Mat, MatMut, MatRef}, + unzipped, zipped, +}; + +/// Mutable view over a row vector, similar to a mutable reference to a strided [prim@slice]. +/// +/// # Note +/// +/// Unlike a slice, the data pointed to by `RowMut<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions. In this case, care must be taken to not perform any +/// operations that read the uninitialized values, or form references to them, either directly +/// through [`RowMut::read`], or indirectly through any of the numerical library routines, unless +/// it is explicitly permitted. +/// +/// # Move semantics +/// See [`faer::Mat`](crate::Mat) for information about reborrowing when using this type. +#[repr(C)] +pub struct RowMut<'a, E: Entity> { + pub(super) inner: VecImpl, + pub(super) __marker: PhantomData<&'a E>, +} + +impl<'short, E: Entity> Reborrow<'short> for RowMut<'_, E> { + type Target = RowRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + RowRef { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for RowMut<'_, E> { + type Target = RowMut<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + RowMut { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'a, E: Entity> IntoConst for RowMut<'a, E> { + type Target = RowRef<'a, E>; + + #[inline] + fn into_const(self) -> Self::Target { + RowRef { + inner: self.inner, + __marker: PhantomData, + } + } +} + +impl<'a, E: Entity> RowMut<'a, E> { + #[inline] + pub(crate) unsafe fn __from_raw_parts( + ptr: GroupFor, + ncols: usize, + col_stride: isize, + ) -> Self { + Self { + inner: VecImpl { + ptr: into_copy::(E::faer_map( + ptr, + #[inline] + |ptr| NonNull::new_unchecked(ptr), + )), + len: ncols, + stride: col_stride, + }, + __marker: PhantomData, + } + } + /// Returns the number of rows of the row. This is always equal to `1`. + #[inline(always)] + pub fn nrows(&self) -> usize { + 1 + } + /// Returns the number of columns of the row. + #[inline(always)] + pub fn ncols(&self) -> usize { + self.inner.len + } + + /// Returns pointers to the matrix data. + #[inline(always)] + pub fn as_ptr_mut(self) -> GroupFor { + E::faer_map( + from_copy::(self.inner.ptr), + #[inline(always)] + |ptr| ptr.as_ptr() as *mut E::Unit, + ) + } + + /// Returns the column stride of the matrix, specified in number of elements, not in bytes. + #[inline(always)] + pub fn col_stride(&self) -> isize { + self.inner.stride + } + + /// Returns `self` as a mutable matrix view. + #[inline(always)] + pub fn as_2d_mut(self) -> MatMut<'a, E> { + let ncols = self.ncols(); + let col_stride = self.col_stride(); + unsafe { mat::from_raw_parts_mut(self.as_ptr_mut(), 1, ncols, isize::MAX, col_stride) } + } + + /// Returns raw pointers to the element at the given index. + #[inline(always)] + pub fn ptr_at_mut(self, col: usize) -> GroupFor { + let offset = (col as isize).wrapping_mul(self.inner.stride); + + E::faer_map( + self.as_ptr_mut(), + #[inline(always)] + |ptr| ptr.wrapping_offset(offset), + ) + } + + #[inline(always)] + unsafe fn ptr_at_mut_unchecked(self, col: usize) -> GroupFor { + let offset = crate::utils::unchecked_mul(col, self.inner.stride); + E::faer_map( + self.as_ptr_mut(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + + /// Returns raw pointers to the element at the given index, assuming the provided index + /// is within the size of the vector. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn ptr_inbounds_at_mut(self, col: usize) -> GroupFor { + debug_assert!(col < self.ncols()); + self.ptr_at_mut_unchecked(col) + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * left. + /// * right. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_mut_unchecked(self, col: usize) -> (Self, Self) { + let (left, right) = self.into_const().split_at_unchecked(col); + unsafe { (left.const_cast(), right.const_cast()) } + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * top. + /// * bottom. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn split_at_mut(self, col: usize) -> (Self, Self) { + assert!(col <= self.ncols()); + unsafe { self.split_at_mut_unchecked(col) } + } + + /// Returns references to the element at the given index, or subvector if `col` is a + /// range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_mut_unchecked( + self, + col: ColRange, + ) -> >::Target + where + Self: RowIndex, + { + >::get_unchecked(self, col) + } + + /// Returns references to the element at the given index, or subvector if `col` is a + /// range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub fn get_mut(self, col: ColRange) -> >::Target + where + Self: RowIndex, + { + >::get(self, col) + } + + /// Reads the value of the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, col: usize) -> E { + self.rb().read_unchecked(col) + } + + /// Reads the value of the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, col: usize) -> E { + self.rb().read(col) + } + + /// Writes the value to the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn write_unchecked(&mut self, col: usize, value: E) { + let units = value.faer_into_units(); + let zipped = E::faer_zip(units, (*self).rb_mut().ptr_inbounds_at_mut(col)); + E::faer_map( + zipped, + #[inline(always)] + |(unit, ptr)| *ptr = unit, + ); + } + + /// Writes the value to the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn write(&mut self, col: usize, value: E) { + assert!(col < self.ncols()); + unsafe { self.write_unchecked(col, value) }; + } + + /// Copies the values from `other` into `self`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `self.ncols() == other.ncols()`. + #[track_caller] + pub fn copy_from(&mut self, other: impl AsRowRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: RowMut<'_, E>, other: RowRef<'_, E>) { + zipped!(this.as_2d_mut(), other.as_2d()) + .for_each(|unzipped!(mut dst, src)| dst.write(src.read())); + } + implementation(self.rb_mut(), other.as_row_ref()) + } + + /// Fills the elements of `self` with zeros. + #[track_caller] + pub fn fill_zero(&mut self) + where + E: ComplexField, + { + zipped!(self.rb_mut().as_2d_mut()).for_each( + #[inline(always)] + |unzipped!(mut x)| x.write(E::faer_zero()), + ); + } + + /// Fills the elements of `self` with copies of `constant`. + #[track_caller] + pub fn fill(&mut self, constant: E) { + zipped!((*self).rb_mut().as_2d_mut()).for_each( + #[inline(always)] + |unzipped!(mut x)| x.write(constant), + ); + } + + /// Returns a view over the transpose of `self`. + #[inline(always)] + #[must_use] + pub fn transpose_mut(self) -> ColMut<'a, E> { + unsafe { self.into_const().transpose().const_cast() } + } + + /// Returns a view over the conjugate of `self`. + #[inline(always)] + #[must_use] + pub fn conjugate_mut(self) -> RowMut<'a, E::Conj> + where + E: Conjugate, + { + unsafe { self.into_const().conjugate().const_cast() } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline(always)] + pub fn adjoint_mut(self) -> ColMut<'a, E::Conj> + where + E: Conjugate, + { + self.conjugate_mut().transpose_mut() + } + + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. + #[inline(always)] + pub fn canonicalize_mut(self) -> (RowMut<'a, E::Canonical>, Conj) + where + E: Conjugate, + { + let (canon, conj) = self.into_const().canonicalize(); + unsafe { (canon.const_cast(), conj) } + } + + /// Returns a view over the `self`, with the columnss in reversed order. + #[inline(always)] + #[must_use] + pub fn reverse_cols_mut(self) -> Self { + unsafe { self.into_const().reverse_cols().const_cast() } + } + + /// Returns a view over the subvector starting at col `col_start`, and with number of + /// columns `ncols`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + #[track_caller] + #[inline(always)] + pub unsafe fn subcols_mut_unchecked(self, col_start: usize, ncols: usize) -> Self { + self.into_const() + .subcols_unchecked(col_start, ncols) + .const_cast() + } + + /// Returns a view over the subvector starting at col `col_start`, and with number of + /// columns `ncols`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + #[track_caller] + #[inline(always)] + pub fn subcols_mut(self, col_start: usize, ncols: usize) -> Self { + unsafe { self.into_const().subcols(col_start, ncols).const_cast() } + } + + /// Returns an owning [`Row`] of the data. + #[inline] + pub fn to_owned(&self) -> Row + where + E: Conjugate, + { + (*self).rb().to_owned() + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + (*self).rb().as_2d().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + (*self).rb().as_2d().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + self.rb().as_2d().norm_max() + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + self.rb().as_2d().norm_l2() + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + self.rb().as_2d().sum() + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.rb().as_2d().kron(rhs) + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> RowRef<'_, E> { + (*self).rb() + } + + /// Returns a mutable view over the matrix. + #[inline] + pub fn as_mut(&mut self) -> RowMut<'_, E> { + (*self).rb_mut() + } +} + +/// Creates a `RowMut` from pointers to the row vector data, number of columns, and column +/// stride. +/// +/// # Safety: +/// This function has the same safety requirements as +/// [`mat::from_raw_parts_mut(ptr, 1, ncols, 0, col_stride)`] +#[inline(always)] +pub unsafe fn from_raw_parts_mut<'a, E: Entity>( + ptr: GroupFor, + ncols: usize, + col_stride: isize, +) -> RowMut<'a, E> { + RowMut::__from_raw_parts(ptr, ncols, col_stride) +} + +/// Creates a `RowMut` from slice views over the row vector data, The result has the same +/// number of columns as the length of the input slice. +#[inline(always)] +pub fn from_slice_mut(slice: GroupFor) -> RowMut<'_, E> { + let nrows = SliceGroup::<'_, E>::new(E::faer_rb(E::faer_as_ref(&slice))).len(); + + unsafe { + from_raw_parts_mut( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_mut_ptr(), + ), + nrows, + 1, + ) + } +} +impl As2D for &'_ RowMut<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).rb().as_2d() + } +} + +impl As2D for RowMut<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).rb().as_2d() + } +} + +impl As2DMut for &'_ mut RowMut<'_, E> { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (**self).rb_mut().as_2d_mut() + } +} + +impl As2DMut for RowMut<'_, E> { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (*self).rb_mut().as_2d_mut() + } +} + +impl<'a, E: Entity> core::fmt::Debug for RowMut<'a, E> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.rb().fmt(f) + } +} + +impl core::ops::Index for RowMut<'_, E> { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, col: usize) -> &E { + (*self).rb().get(col) + } +} + +impl core::ops::IndexMut for RowMut<'_, E> { + #[inline] + #[track_caller] + fn index_mut(&mut self, col: usize) -> &mut E { + (*self).rb_mut().get_mut(col) + } +} + +impl AsRowRef for RowMut<'_, E> { + #[inline] + fn as_row_ref(&self) -> RowRef<'_, E> { + (*self).rb() + } +} +impl AsRowRef for &'_ RowMut<'_, E> { + #[inline] + fn as_row_ref(&self) -> RowRef<'_, E> { + (**self).rb() + } +} + +impl AsRowMut for RowMut<'_, E> { + #[inline] + fn as_row_mut(&mut self) -> RowMut<'_, E> { + (*self).rb_mut() + } +} + +impl AsRowMut for &'_ mut RowMut<'_, E> { + #[inline] + fn as_row_mut(&mut self) -> RowMut<'_, E> { + (**self).rb_mut() + } +} diff --git a/src/row/rowown.rs b/src/row/rowown.rs new file mode 100644 index 0000000000000000000000000000000000000000..18553d2a09dc7197a4c922dd109bfba3da7bd7fe --- /dev/null +++ b/src/row/rowown.rs @@ -0,0 +1,626 @@ +use super::*; +use crate::{ + col::ColRef, + debug_assert, + mat::{ + matalloc::{align_for, is_vectorizable, MatUnit, RawMat, RawMatUnit}, + As2D, As2DMut, Mat, MatMut, MatRef, + }, + row::RowRef, + utils::DivCeil, +}; +use core::mem::ManuallyDrop; + +/// Heap allocated resizable row vector. +/// +/// # Note +/// +/// The memory layout of `Col` is guaranteed to be row-major, meaning that it has a column stride +/// of `1`. +#[repr(C)] +pub struct Row { + inner: VecOwnImpl, + col_capacity: usize, + __marker: PhantomData, +} + +impl Row { + /// Returns an empty row of dimension `0`. + #[inline] + pub fn new() -> Self { + Self { + inner: VecOwnImpl { + ptr: into_copy::(E::faer_map(E::UNIT, |()| NonNull::::dangling())), + len: 0, + }, + col_capacity: 0, + __marker: PhantomData, + } + } + + /// Returns a new column vector with 0 columns, with enough capacity to hold a maximum of + /// `col_capacity` columnss columns without reallocating. If `col_capacity` is `0`, + /// the function will not allocate. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn with_capacity(col_capacity: usize) -> Self { + let raw = ManuallyDrop::new(RawMat::::new(col_capacity, 1)); + Self { + inner: VecOwnImpl { + ptr: raw.ptr, + len: 0, + }, + col_capacity: raw.row_capacity, + __marker: PhantomData, + } + } + + /// Returns a new matrix with number of columns `ncols`, filled with the provided function. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn from_fn(ncols: usize, f: impl FnMut(usize) -> E) -> Self { + let mut this = Self::new(); + this.resize_with(ncols, f); + this + } + + /// Returns a new matrix with number of columns `ncols`, filled with zeros. + /// + /// # Panics + /// The function panics if the total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn zeros(ncols: usize) -> Self + where + E: ComplexField, + { + Self::from_fn(ncols, |_| E::faer_zero()) + } + + /// Returns the number of rows of the row. This is always equal to `1`. + #[inline(always)] + pub fn nrows(&self) -> usize { + 1 + } + /// Returns the number of columns of the row. + #[inline(always)] + pub fn ncols(&self) -> usize { + self.inner.len + } + + /// Set the dimensions of the matrix. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `ncols < self.col_capacity()`. + /// * The elements that were previously out of bounds but are now in bounds must be + /// initialized. + #[inline] + pub unsafe fn set_ncols(&mut self, ncols: usize) { + self.inner.len = ncols; + } + + /// Returns a pointer to the data of the matrix. + #[inline] + pub fn as_ptr(&self) -> GroupFor { + E::faer_map(from_copy::(self.inner.ptr), |ptr| { + ptr.as_ptr() as *const E::Unit + }) + } + + /// Returns a mutable pointer to the data of the matrix. + #[inline] + pub fn as_ptr_mut(&mut self) -> GroupFor { + E::faer_map(from_copy::(self.inner.ptr), |ptr| ptr.as_ptr()) + } + + /// Returns the col capacity, that is, the number of cols that the matrix is able to hold + /// without needing to reallocate, excluding column insertions. + #[inline] + pub fn col_capacity(&self) -> usize { + self.col_capacity + } + + /// Returns the offset between the first elements of two successive columns in the matrix. + /// Always returns `1` since the matrix is column major. + #[inline] + pub fn col_stride(&self) -> isize { + 1 + } + + #[cold] + fn do_reserve_exact(&mut self, mut new_col_capacity: usize) { + if is_vectorizable::() { + let align_factor = align_for::() / core::mem::size_of::(); + new_col_capacity = new_col_capacity + .msrv_checked_next_multiple_of(align_factor) + .unwrap(); + } + + let ncols = self.inner.len; + let old_col_capacity = self.col_capacity; + + let mut this = ManuallyDrop::new(core::mem::take(self)); + { + let mut this_group = E::faer_map(from_copy::(this.inner.ptr), |ptr| MatUnit { + raw: RawMatUnit { + ptr, + row_capacity: old_col_capacity, + col_capacity: 1, + }, + ncols, + nrows: 1, + }); + + E::faer_map(E::faer_as_mut(&mut this_group), |mat_unit| { + mat_unit.do_reserve_exact(new_col_capacity, 1); + }); + + let this_group = E::faer_map(this_group, ManuallyDrop::new); + this.inner.ptr = + into_copy::(E::faer_map(this_group, |mat_unit| mat_unit.raw.ptr)); + this.col_capacity = new_col_capacity; + } + *self = ManuallyDrop::into_inner(this); + } + + /// Reserves the minimum capacity for `col_capacity` columns without reallocating. Does nothing + /// if the capacity is already sufficient. + /// + /// # Panics + /// The function panics if the new total capacity in bytes exceeds `isize::MAX`. + #[inline] + pub fn reserve_exact(&mut self, col_capacity: usize) { + if self.col_capacity() >= col_capacity { + // do nothing + } else if core::mem::size_of::() == 0 { + self.col_capacity = self.col_capacity().max(col_capacity); + } else { + self.do_reserve_exact(col_capacity); + } + } + + unsafe fn insert_block_with E>( + &mut self, + f: &mut F, + col_start: usize, + col_end: usize, + ) { + debug_assert!(col_start <= col_end); + + let ptr = self.as_ptr_mut(); + + for j in col_start..col_end { + // SAFETY: + // * pointer to element at index (i, j), which is within the + // allocation since we reserved enough space + // * writing to this memory region is sound since it is properly + // aligned and valid for writes + let ptr_ij = E::faer_map(E::faer_copy(&ptr), |ptr| ptr.add(j)); + let value = E::faer_into_units(f(j)); + + E::faer_map(E::faer_zip(ptr_ij, value), |(ptr_ij, value)| { + core::ptr::write(ptr_ij, value) + }); + } + } + + fn erase_last_cols(&mut self, new_ncols: usize) { + let old_ncols = self.ncols(); + debug_assert!(new_ncols <= old_ncols); + self.inner.len = new_ncols; + } + + unsafe fn insert_last_cols_with E>(&mut self, f: &mut F, new_ncols: usize) { + let old_ncols = self.ncols(); + + debug_assert!(new_ncols > old_ncols); + + self.insert_block_with(f, old_ncols, new_ncols); + self.inner.len = new_ncols; + } + + /// Resizes the vector in-place so that the new number of columns is `new_ncols`. + /// New elements are created with the given function `f`, so that elements at index `i` + /// are created by calling `f(i)`. + pub fn resize_with(&mut self, new_ncols: usize, f: impl FnMut(usize) -> E) { + let mut f = f; + let old_ncols = self.ncols(); + + if new_ncols <= old_ncols { + self.erase_last_cols(new_ncols); + } else { + self.reserve_exact(new_ncols); + unsafe { + self.insert_last_cols_with(&mut f, new_ncols); + } + } + } + + /// Returns a reference to a slice over the row. + #[inline] + #[track_caller] + pub fn as_slice(&self) -> GroupFor { + let ncols = self.ncols(); + let ptr = self.as_ref().as_ptr(); + E::faer_map( + ptr, + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts(ptr, ncols) }, + ) + } + + /// Returns a mutable reference to a slice over the row. + #[inline] + #[track_caller] + pub fn as_slice_mut(&mut self) -> GroupFor { + let ncols = self.ncols(); + let ptr = self.as_ptr_mut(); + E::faer_map( + ptr, + #[inline(always)] + |ptr| unsafe { core::slice::from_raw_parts_mut(ptr, ncols) }, + ) + } + + /// Returns a view over the vector. + #[inline] + pub fn as_ref(&self) -> RowRef<'_, E> { + unsafe { super::from_raw_parts(self.as_ptr(), self.ncols(), 1) } + } + + /// Returns a mutable view over the vector. + #[inline] + pub fn as_mut(&mut self) -> RowMut<'_, E> { + unsafe { super::from_raw_parts_mut(self.as_ptr_mut(), self.ncols(), 1) } + } + + /// Returns references to the element at the given index, or submatrices if `col` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub unsafe fn get_unchecked( + &self, + col: ColRange, + ) -> as RowIndex>::Target + where + for<'a> RowRef<'a, E>: RowIndex, + { + self.as_ref().get_unchecked(col) + } + + /// Returns references to the element at the given index, or submatrices if `col` is a range, + /// with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub fn get(&self, col: ColRange) -> as RowIndex>::Target + where + for<'a> RowRef<'a, E>: RowIndex, + { + self.as_ref().get(col) + } + + /// Returns mutable references to the element at the given index, or submatrices if + /// `col` is a range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub unsafe fn get_mut_unchecked( + &mut self, + col: ColRange, + ) -> as RowIndex>::Target + where + for<'a> RowMut<'a, E>: RowIndex, + { + self.as_mut().get_mut_unchecked(col) + } + + /// Returns mutable references to the element at the given index, or submatrices if + /// `col` is a range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline] + pub fn get_mut( + &mut self, + col: ColRange, + ) -> as RowIndex>::Target + where + for<'a> RowMut<'a, E>: RowIndex, + { + self.as_mut().get_mut(col) + } + + /// Reads the value of the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, col: usize) -> E { + self.as_ref().read_unchecked(col) + } + + /// Reads the value of the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, col: usize) -> E { + self.as_ref().read(col) + } + + /// Writes the value to the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn write_unchecked(&mut self, col: usize, value: E) { + self.as_mut().write_unchecked(col, value); + } + + /// Writes the value to the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn write(&mut self, col: usize, value: E) { + self.as_mut().write(col, value); + } + + /// Copies the values from `other` into `self`. + #[inline(always)] + #[track_caller] + pub fn copy_from(&mut self, other: impl AsRowRef) { + #[track_caller] + #[inline(always)] + fn implementation(this: &mut Row, other: RowRef<'_, E>) { + let mut mat = Row::::new(); + mat.resize_with( + other.nrows(), + #[inline(always)] + |row| unsafe { other.read_unchecked(row) }, + ); + *this = mat; + } + implementation(self, other.as_row_ref()); + } + + /// Fills the elements of `self` with zeros. + #[inline(always)] + #[track_caller] + pub fn fill_zero(&mut self) + where + E: ComplexField, + { + self.as_mut().fill_zero() + } + + /// Fills the elements of `self` with copies of `constant`. + #[inline(always)] + #[track_caller] + pub fn fill(&mut self, constant: E) { + self.as_mut().fill(constant) + } + + /// Returns a view over the transpose of `self`. + #[inline] + pub fn transpose(&self) -> ColRef<'_, E> { + self.as_ref().transpose() + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn conjugate(&self) -> RowRef<'_, E::Conj> + where + E: Conjugate, + { + self.as_ref().conjugate() + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn adjoint(&self) -> ColRef<'_, E::Conj> + where + E: Conjugate, + { + self.as_ref().adjoint() + } + + /// Returns an owning [`Row`] of the data + #[inline] + pub fn to_owned(&self) -> Row + where + E: Conjugate, + { + self.as_ref().to_owned() + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + self.as_ref().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + self.as_ref().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + self.as_ref().as_2d().norm_max() + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + self.as_ref().as_2d().norm_l2() + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + self.as_ref().as_2d().sum() + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.as_2d_ref().kron(rhs) + } +} + +impl Default for Row { + #[inline] + fn default() -> Self { + Self::new() + } +} + +impl Clone for Row { + fn clone(&self) -> Self { + let this = self.as_ref(); + unsafe { + Self::from_fn(self.ncols(), |j| { + E::faer_from_units(E::faer_deref(this.get_unchecked(j))) + }) + } + } +} + +impl As2D for &'_ Row { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).as_ref().as_2d() + } +} + +impl As2D for Row { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).as_ref().as_2d() + } +} + +impl As2DMut for &'_ mut Row { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (**self).as_mut().as_2d_mut() + } +} + +impl As2DMut for Row { + #[inline] + fn as_2d_mut(&mut self) -> MatMut<'_, E> { + (*self).as_mut().as_2d_mut() + } +} + +impl AsRowRef for Row { + #[inline] + fn as_row_ref(&self) -> RowRef<'_, E> { + (*self).as_ref() + } +} +impl AsRowRef for &'_ Row { + #[inline] + fn as_row_ref(&self) -> RowRef<'_, E> { + (**self).as_ref() + } +} + +impl AsRowMut for Row { + #[inline] + fn as_row_mut(&mut self) -> RowMut<'_, E> { + (*self).as_mut() + } +} + +impl AsRowMut for &'_ mut Row { + #[inline] + fn as_row_mut(&mut self) -> RowMut<'_, E> { + (**self).as_mut() + } +} + +impl core::fmt::Debug for Row { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_ref().fmt(f) + } +} + +impl core::ops::Index for Row { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, col: usize) -> &E { + self.as_ref().get(col) + } +} + +impl core::ops::IndexMut for Row { + #[inline] + #[track_caller] + fn index_mut(&mut self, col: usize) -> &mut E { + self.as_mut().get_mut(col) + } +} diff --git a/src/row/rowref.rs b/src/row/rowref.rs new file mode 100644 index 0000000000000000000000000000000000000000..2279622521dee26b98d6ff10d8cdfaf21731b2b2 --- /dev/null +++ b/src/row/rowref.rs @@ -0,0 +1,547 @@ +use super::*; +use crate::{ + assert, + col::ColRef, + debug_assert, + mat::{As2D, Mat, MatRef}, +}; + +/// Immutable view over a row vector, similar to an immutable reference to a strided [prim@slice]. +/// +/// # Note +/// +/// Unlike a slice, the data pointed to by `RowRef<'_, E>` is allowed to be partially or fully +/// uninitialized under certain conditions. In this case, care must be taken to not perform any +/// operations that read the uninitialized values, or form references to them, either directly +/// through [`RowRef::read`], or indirectly through any of the numerical library routines, unless +/// it is explicitly permitted. +#[repr(C)] +pub struct RowRef<'a, E: Entity> { + pub(super) inner: VecImpl, + pub(super) __marker: PhantomData<&'a E>, +} + +impl Clone for RowRef<'_, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl Copy for RowRef<'_, E> {} + +impl<'short, E: Entity> Reborrow<'short> for RowRef<'_, E> { + type Target = RowRef<'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, E: Entity> ReborrowMut<'short> for RowRef<'_, E> { + type Target = RowRef<'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl IntoConst for RowRef<'_, E> { + type Target = Self; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, E: Entity> RowRef<'a, E> { + pub(crate) unsafe fn __from_raw_parts( + ptr: GroupFor, + ncols: usize, + col_stride: isize, + ) -> Self { + Self { + inner: VecImpl { + ptr: into_copy::(E::faer_map( + ptr, + #[inline] + |ptr| NonNull::new_unchecked(ptr as *mut E::Unit), + )), + len: ncols, + stride: col_stride, + }, + __marker: PhantomData, + } + } + + /// Returns the number of rows of the row. This is always equal to `1`. + #[inline(always)] + pub fn nrows(&self) -> usize { + 1 + } + /// Returns the number of columns of the row. + #[inline(always)] + pub fn ncols(&self) -> usize { + self.inner.len + } + + /// Returns pointers to the matrix data. + #[inline(always)] + pub fn as_ptr(self) -> GroupFor { + E::faer_map( + from_copy::(self.inner.ptr), + #[inline(always)] + |ptr| ptr.as_ptr() as *const E::Unit, + ) + } + + /// Returns the column stride of the matrix, specified in number of elements, not in bytes. + #[inline(always)] + pub fn col_stride(&self) -> isize { + self.inner.stride + } + + /// Returns `self` as a matrix view. + #[inline(always)] + pub fn as_2d(self) -> MatRef<'a, E> { + let ncols = self.ncols(); + let col_stride = self.col_stride(); + unsafe { crate::mat::from_raw_parts(self.as_ptr(), 1, ncols, isize::MAX, col_stride) } + } + + /// Returns raw pointers to the element at the given index. + #[inline(always)] + pub fn ptr_at(self, col: usize) -> GroupFor { + let offset = (col as isize).wrapping_mul(self.inner.stride); + + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.wrapping_offset(offset), + ) + } + + #[inline(always)] + unsafe fn unchecked_ptr_at(self, col: usize) -> GroupFor { + let offset = crate::utils::unchecked_mul(col, self.inner.stride); + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + + #[inline(always)] + unsafe fn overflowing_ptr_at(self, col: usize) -> GroupFor { + unsafe { + let cond = col != self.ncols(); + let offset = (cond as usize).wrapping_neg() as isize + & (col as isize).wrapping_mul(self.inner.stride); + E::faer_map( + self.as_ptr(), + #[inline(always)] + |ptr| ptr.offset(offset), + ) + } + } + + /// Returns raw pointers to the element at the given index, assuming the provided index + /// is within the size of the vector. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn ptr_inbounds_at(self, col: usize) -> GroupFor { + debug_assert!(col < self.ncols()); + self.unchecked_ptr_at(col) + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * left. + /// * right. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at_unchecked(self, col: usize) -> (Self, Self) { + debug_assert!(col <= self.ncols()); + + let col_stride = self.col_stride(); + + let ncols = self.ncols(); + + unsafe { + let top = self.as_ptr(); + let bot = self.overflowing_ptr_at(col); + + ( + Self::__from_raw_parts(top, col, col_stride), + Self::__from_raw_parts(bot, ncols - col, col_stride), + ) + } + } + + /// Splits the column vector at the given index into two parts and + /// returns an array of each subvector, in the following order: + /// * top. + /// * bottom. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col <= self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn split_at(self, col: usize) -> (Self, Self) { + assert!(col <= self.ncols()); + unsafe { self.split_at_unchecked(col) } + } + + /// Returns references to the element at the given index, or subvector if `row` is a + /// range. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_unchecked( + self, + col: ColRange, + ) -> >::Target + where + Self: RowIndex, + { + >::get_unchecked(self, col) + } + + /// Returns references to the element at the given index, or subvector if `col` is a + /// range, with bound checks. + /// + /// # Note + /// The values pointed to by the references are expected to be initialized, even if the + /// pointed-to value is not read, otherwise the behavior is undefined. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col` must be contained in `[0, self.ncols())`. + #[inline(always)] + #[track_caller] + pub fn get(self, col: ColRange) -> >::Target + where + Self: RowIndex, + { + >::get(self, col) + } + + /// Reads the value of the element at the given index. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, col: usize) -> E { + E::faer_from_units(E::faer_map( + self.get_unchecked(col), + #[inline(always)] + |ptr| *ptr, + )) + } + + /// Reads the value of the element at the given index, with bound checks. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col < self.ncols()`. + #[inline(always)] + #[track_caller] + pub fn read(&self, col: usize) -> E { + E::faer_from_units(E::faer_map( + self.get(col), + #[inline(always)] + |ptr| *ptr, + )) + } + + /// Returns a view over the transpose of `self`. + #[inline(always)] + #[must_use] + pub fn transpose(self) -> ColRef<'a, E> { + unsafe { ColRef::__from_raw_parts(self.as_ptr(), self.ncols(), self.col_stride()) } + } + + /// Returns a view over the conjugate of `self`. + #[inline(always)] + #[must_use] + pub fn conjugate(self) -> RowRef<'a, E::Conj> + where + E: Conjugate, + { + unsafe { + // SAFETY: Conjugate requires that E::Unit and E::Conj::Unit have the same layout + // and that GroupCopyFor == E::Conj::GroupCopy + super::from_raw_parts::<'_, E::Conj>( + transmute_unchecked::< + GroupFor>, + GroupFor>, + >(self.as_ptr()), + self.ncols(), + self.col_stride(), + ) + } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline(always)] + pub fn adjoint(self) -> ColRef<'a, E::Conj> + where + E: Conjugate, + { + self.conjugate().transpose() + } + + /// Returns a view over the canonical representation of `self`, as well as a flag declaring + /// whether `self` is implicitly conjugated or not. + #[inline(always)] + pub fn canonicalize(self) -> (RowRef<'a, E::Canonical>, Conj) + where + E: Conjugate, + { + ( + unsafe { + // SAFETY: see Self::conjugate + super::from_raw_parts::<'_, E::Canonical>( + transmute_unchecked::< + GroupFor, + GroupFor>, + >(self.as_ptr()), + self.ncols(), + self.col_stride(), + ) + }, + if coe::is_same::() { + Conj::No + } else { + Conj::Yes + }, + ) + } + + /// Returns a view over the `self`, with the columnss in reversed order. + #[inline(always)] + #[must_use] + pub fn reverse_cols(self) -> Self { + let ncols = self.ncols(); + let col_stride = self.col_stride().wrapping_neg(); + + let ptr = unsafe { self.unchecked_ptr_at(ncols.saturating_sub(1)) }; + unsafe { Self::__from_raw_parts(ptr, ncols, col_stride) } + } + + /// Returns a view over the subvector starting at column `col_start`, and with number of + /// columns `ncols`. + /// + /// # Safety + /// The behavior is undefined if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + #[track_caller] + #[inline(always)] + pub unsafe fn subcols_unchecked(self, col_start: usize, ncols: usize) -> Self { + debug_assert!(col_start <= self.ncols()); + debug_assert!(ncols <= self.ncols() - col_start); + let col_stride = self.col_stride(); + unsafe { Self::__from_raw_parts(self.overflowing_ptr_at(col_start), ncols, col_stride) } + } + + /// Returns a view over the subvector starting at col `col_start`, and with number of cols + /// `ncols`. + /// + /// # Panics + /// The function panics if any of the following conditions are violated: + /// * `col_start <= self.ncols()`. + /// * `ncols <= self.ncols() - col_start`. + #[track_caller] + #[inline(always)] + pub fn subcols(self, col_start: usize, ncols: usize) -> Self { + assert!(col_start <= self.ncols()); + assert!(ncols <= self.ncols() - col_start); + unsafe { self.subcols_unchecked(col_start, ncols) } + } + + /// Returns an owning [`Row`] of the data. + #[inline] + pub fn to_owned(&self) -> Row + where + E: Conjugate, + { + let mut mat = Row::new(); + mat.resize_with( + self.ncols(), + #[inline(always)] + |col| unsafe { self.read_unchecked(col).canonicalize() }, + ); + mat + } + + /// Returns `true` if any of the elements is NaN, otherwise returns `false`. + #[inline] + pub fn has_nan(&self) -> bool + where + E: ComplexField, + { + (*self).rb().as_2d().has_nan() + } + + /// Returns `true` if all of the elements are finite, otherwise returns `false`. + #[inline] + pub fn is_all_finite(&self) -> bool + where + E: ComplexField, + { + (*self).rb().as_2d().is_all_finite() + } + + /// Returns the maximum norm of `self`. + #[inline] + pub fn norm_max(&self) -> E::Real + where + E: ComplexField, + { + self.as_2d().norm_max() + } + /// Returns the L2 norm of `self`. + #[inline] + pub fn norm_l2(&self) -> E::Real + where + E: ComplexField, + { + self.as_2d().norm_l2() + } + + /// Returns the sum of `self`. + #[inline] + pub fn sum(&self) -> E + where + E: ComplexField, + { + self.as_2d().sum() + } + + /// Kroneckor product of `self` and `rhs`. + /// + /// This is an allocating operation; see [`faer::linalg::kron`](crate::linalg::kron) for the + /// allocation-free version or more info in general. + #[inline] + #[track_caller] + pub fn kron(&self, rhs: impl As2D) -> Mat + where + E: ComplexField, + { + self.as_2d_ref().kron(rhs) + } + + /// Returns a view over the matrix. + #[inline] + pub fn as_ref(&self) -> RowRef<'_, E> { + *self + } + + #[doc(hidden)] + #[inline(always)] + pub unsafe fn const_cast(self) -> RowMut<'a, E> { + RowMut { + inner: self.inner, + __marker: PhantomData, + } + } +} + +/// Creates a `RowRef` from pointers to the row vector data, number of columns, and column +/// stride. +/// +/// # Safety: +/// This function has the same safety requirements as +/// [`mat::from_raw_parts(ptr, 1, ncols, 0, col_stride)`] +#[inline(always)] +pub unsafe fn from_raw_parts<'a, E: Entity>( + ptr: GroupFor, + ncols: usize, + col_stride: isize, +) -> RowRef<'a, E> { + RowRef::__from_raw_parts(ptr, ncols, col_stride) +} + +/// Creates a `RowRef` from slice views over the row vector data, The result has the same +/// number of columns as the length of the input slice. +#[inline(always)] +pub fn from_slice(slice: GroupFor) -> RowRef<'_, E> { + let nrows = SliceGroup::<'_, E>::new(E::faer_copy(&slice)).len(); + + unsafe { + from_raw_parts( + E::faer_map( + slice, + #[inline(always)] + |slice| slice.as_ptr(), + ), + nrows, + 1, + ) + } +} + +impl As2D for &'_ RowRef<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (**self).as_2d() + } +} + +impl As2D for RowRef<'_, E> { + #[inline] + fn as_2d_ref(&self) -> MatRef<'_, E> { + (*self).as_2d() + } +} + +impl AsRowRef for RowRef<'_, E> { + #[inline] + fn as_row_ref(&self) -> RowRef<'_, E> { + *self + } +} +impl AsRowRef for &'_ RowRef<'_, E> { + #[inline] + fn as_row_ref(&self) -> RowRef<'_, E> { + **self + } +} + +impl<'a, E: Entity> core::fmt::Debug for RowRef<'a, E> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_2d().fmt(f) + } +} + +impl core::ops::Index for RowRef<'_, E> { + type Output = E; + + #[inline] + #[track_caller] + fn index(&self, col: usize) -> &E { + self.get(col) + } +} diff --git a/src/seal.rs b/src/seal.rs new file mode 100644 index 0000000000000000000000000000000000000000..019f18b3d72428c6fd64be6337663a3519e8abb1 --- /dev/null +++ b/src/seal.rs @@ -0,0 +1,21 @@ +use faer_entity::Entity; + +pub trait Seal {} + +impl Seal for crate::mat::MatRef<'_, E> {} +impl Seal for crate::mat::MatMut<'_, E> {} + +impl Seal for crate::col::ColRef<'_, E> {} +impl Seal for crate::col::ColMut<'_, E> {} + +impl Seal for crate::row::RowRef<'_, E> {} +impl Seal for crate::row::RowMut<'_, E> {} + +impl Seal for i32 {} +impl Seal for i64 {} +impl Seal for i128 {} +impl Seal for isize {} +impl Seal for u32 {} +impl Seal for u64 {} +impl Seal for u128 {} +impl Seal for usize {} diff --git a/faer-libs/faer-core/src/serde_impl.rs b/src/serde/mat.rs similarity index 95% rename from faer-libs/faer-core/src/serde_impl.rs rename to src/serde/mat.rs index 0087af10b4c640b69f4322ce338f2898c5410fea..ab939649b65ba8b52b8bff6de182c1d9027a86cd 100644 --- a/faer-libs/faer-core/src/serde_impl.rs +++ b/src/serde/mat.rs @@ -9,9 +9,9 @@ use serde::{ Deserialize, Serialize, Serializer, }; -use crate::Mat; +use crate::{Mat, MatMut, MatRef}; -impl Serialize for Mat +impl Serialize for MatRef<'_, E> where E: Serialize, { @@ -19,7 +19,7 @@ where where S: Serializer, { - struct MatSequenceSerializer<'a, E: Entity>(&'a Mat); + struct MatSequenceSerializer<'a, E: Entity>(MatRef<'a, E>); impl<'a, E: Entity> Serialize for MatSequenceSerializer<'a, E> where @@ -42,11 +42,35 @@ where let mut structure = s.serialize_struct("Mat", 3)?; structure.serialize_field("nrows", &self.nrows())?; structure.serialize_field("ncols", &self.ncols())?; - structure.serialize_field("data", &MatSequenceSerializer(self))?; + structure.serialize_field("data", &MatSequenceSerializer(*self))?; structure.end() } } +impl Serialize for MatMut<'_, E> +where + E: Serialize, +{ + fn serialize(&self, s: S) -> Result<::Ok, ::Error> + where + S: Serializer, + { + self.as_ref().serialize(s) + } +} + +impl Serialize for Mat +where + E: Serialize, +{ + fn serialize(&self, s: S) -> Result<::Ok, ::Error> + where + S: Serializer, + { + self.as_ref().serialize(s) + } +} + impl<'a, E: Entity> Deserialize<'a> for Mat where E: Deserialize<'a>, diff --git a/src/serde/mod.rs b/src/serde/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..8c8aacc34e8c12c3501e8a2ab650c566bcbca745 --- /dev/null +++ b/src/serde/mod.rs @@ -0,0 +1 @@ +mod mat; diff --git a/faer-libs/faer-core/src/sort.rs b/src/sort.rs similarity index 98% rename from faer-libs/faer-core/src/sort.rs rename to src/sort.rs index e136c2a3d9b7ecf6463315f58229ca29d97396cb..b94ecff6c6636136d85619e38fd6a0f1ee59bb03 100644 --- a/faer-libs/faer-core/src/sort.rs +++ b/src/sort.rs @@ -873,7 +873,7 @@ where let len = v_len; // Three indices near which we are going to choose a pivot. - let mut a = len / 4 * 1; + let mut a = len / 4; let mut b = len / 4 * 2; let mut c = len / 4 * 3; @@ -1007,7 +1007,7 @@ unsafe fn recurse bool>( // Continue sorting elements greater than the pivot. v = v.add(mid); - v_len = v_len - mid; + v_len -= mid; continue; } } @@ -1056,9 +1056,9 @@ pub unsafe fn sort_unstable_by( ); } -pub unsafe fn sort_indices( +pub unsafe fn sort_indices( indices: &mut [I], - values: crate::group_helpers::SliceGroupMut, + values: crate::utils::slice::SliceGroupMut, ) { #[derive(Copy, Clone)] struct Wrap(GroupCopyFor); @@ -1173,7 +1173,7 @@ pub unsafe fn sort_indices( ))), ), len, - |(i, _), (j, _)| (&*i).cmp(&*j), + |(i, _), (j, _)| (*i).cmp(&*j), ); } @@ -1200,19 +1200,13 @@ mod tests { fn test_quicksort_big() { let mut rng = StdRng::seed_from_u64(0); - let mut a = (0..1000) - .into_iter() - .map(|_| rng.gen::()) - .collect::>(); - let mut b = (0..1000) - .into_iter() - .map(|_| rng.gen::()) - .collect::>(); + let mut a = (0..1000).map(|_| rng.gen::()).collect::>(); + let mut b = (0..1000).map(|_| rng.gen::()).collect::>(); let a_old = a.clone(); let b_old = b.clone(); - let mut perm = (0..1000).into_iter().collect::>(); + let mut perm = (0..1000).collect::>(); perm.sort_unstable_by_key(|&i| a[i]); let len = a.len(); diff --git a/src/sparse/csc/matmut.rs b/src/sparse/csc/matmut.rs new file mode 100644 index 0000000000000000000000000000000000000000..777483e51da46e48b7661281cb36bf11686a8baa --- /dev/null +++ b/src/sparse/csc/matmut.rs @@ -0,0 +1,414 @@ +use super::*; +use crate::assert; + +/// Sparse matrix view in column-major format, either compressed or uncompressed. +pub struct SparseColMatMut<'a, I: Index, E: Entity> { + pub(crate) symbolic: SymbolicSparseColMatRef<'a, I>, + pub(crate) values: SliceGroupMut<'a, E>, +} + +impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseColMatMut<'_, I, E> { + type Target = SparseColMatRef<'short, I, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + SparseColMatRef { + symbolic: self.symbolic, + values: self.values.rb(), + } + } +} + +impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseColMatMut<'_, I, E> { + type Target = SparseColMatMut<'short, I, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + SparseColMatMut { + symbolic: self.symbolic, + values: self.values.rb_mut(), + } + } +} + +impl<'a, I: Index, E: Entity> IntoConst for SparseColMatMut<'a, I, E> { + type Target = SparseColMatRef<'a, I, E>; + + #[inline] + fn into_const(self) -> Self::Target { + SparseColMatRef { + symbolic: self.symbolic, + values: self.values.into_const(), + } + } +} + +impl<'a, I: Index, E: Entity> SparseColMatMut<'a, I, E> { + /// Creates a new sparse matrix view. + /// + /// # Panics + /// + /// Panics if the length of `values` is not equal to the length of + /// `symbolic.row_indices()`. + #[inline] + #[track_caller] + pub fn new( + symbolic: SymbolicSparseColMatRef<'a, I>, + values: GroupFor, + ) -> Self { + let values = SliceGroupMut::new(values); + assert!(symbolic.row_indices().len() == values.len()); + Self { symbolic, values } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.symbolic.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.symbolic.ncols + } + + /// Returns a view over `self`. + #[inline] + pub fn as_ref(&self) -> SparseColMatRef<'_, I, E> { + (*self).rb() + } + + /// Returns a mutable view over `self`. + /// + /// Note that the symbolic structure cannot be changed through this view. + #[inline] + pub fn as_mut(&mut self) -> SparseColMatMut<'_, I, E> { + (*self).rb_mut() + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.rb().to_owned() + } + + /// Copies the current matrix into a newly allocated matrix, with row-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_row_major(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.rb().to_row_major() + } + + /// Returns a view over the transpose of `self` in row-major format. + #[inline] + pub fn transpose_mut(self) -> SparseRowMatMut<'a, I, E> { + SparseRowMatMut { + symbolic: SymbolicSparseRowMatRef { + nrows: self.symbolic.ncols, + ncols: self.symbolic.nrows, + row_ptr: self.symbolic.col_ptr, + row_nnz: self.symbolic.col_nnz, + col_ind: self.symbolic.row_ind, + }, + values: self.values, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn conjugate_mut(self) -> SparseColMatMut<'a, I, E::Conj> + where + E: Conjugate, + { + SparseColMatMut { + symbolic: self.symbolic, + values: unsafe { + SliceGroupMut::<'a, E::Conj>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts_mut( + slice.as_ptr() as *mut UnitFor as *mut UnitFor, + len, + ) + }, + ))) + }, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn canonicalize_mut(self) -> (SparseColMatMut<'a, I, E::Canonical>, Conj) + where + E: Conjugate, + { + ( + SparseColMatMut { + symbolic: self.symbolic, + values: unsafe { + SliceGroupMut::<'a, E::Canonical>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts_mut( + slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor, + len, + ) + }, + ))) + }, + }, + if coe::is_same::() { + Conj::No + } else { + Conj::Yes + }, + ) + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn adjoint_mut(self) -> SparseRowMatMut<'a, I, E::Conj> + where + E: Conjugate, + { + self.transpose_mut().conjugate_mut() + } + + /// Returns the numerical values of the matrix. + #[inline] + pub fn values_mut(self) -> GroupFor { + self.values.into_inner() + } + + /// Returns the numerical values of column `j` of the matrix. + /// + /// # Panics: + /// + /// Panics if `j >= ncols`. + #[inline] + #[track_caller] + pub fn values_of_col_mut(self, j: usize) -> GroupFor { + let range = self.symbolic().col_range(j); + self.values.subslice(range).into_inner() + } + + /// Returns the symbolic structure of the matrix. + #[inline] + pub fn symbolic(&self) -> SymbolicSparseColMatRef<'a, I> { + self.symbolic + } + + /// Decomposes the matrix into the symbolic part and the numerical values. + #[inline] + pub fn into_parts_mut( + self, + ) -> ( + SymbolicSparseColMatRef<'a, I>, + GroupFor, + ) { + (self.symbolic, self.values.into_inner()) + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.symbolic.compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn col_ptrs(&self) -> &'a [I] { + self.symbolic.col_ptrs() + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_col(&self) -> Option<&'a [I]> { + self.symbolic.col_nnz + } + + /// Returns the row indices. + #[inline] + pub fn row_indices(&self) -> &'a [I] { + self.symbolic.row_ind + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw(&self, j: usize) -> &'a [I] { + self.symbolic.row_indices_of_col_raw(j) + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col( + &self, + j: usize, + ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { + self.symbolic.row_indices_of_col(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_range(&self, j: usize) -> Range { + self.symbolic.col_range(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub unsafe fn col_range_unchecked(&self, j: usize) -> Range { + self.symbolic.col_range_unchecked(j) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get(self, row: usize, col: usize) -> Option> { + self.into_const().get(row, col) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get_mut(self, row: usize, col: usize) -> Option> { + assert!(row < self.nrows()); + assert!(col < self.ncols()); + + let Ok(pos) = self + .row_indices_of_col_raw(col) + .binary_search(&I::truncate(row)) + else { + return None; + }; + + Some(E::faer_map(self.values_of_col_mut(col), |slice| { + &mut slice[pos] + })) + } +} + +impl SparseColMatMut<'_, I, E> { + /// Fill the matrix from a previously created value order. + /// The provided values must correspond to the same indices that were provided in the + /// function call from which the order was created. + /// + /// # Note + /// The symbolic structure is not changed. + pub fn fill_from_order_and_values( + &mut self, + order: &ValuesOrder, + values: GroupFor, + mode: FillMode, + ) { + let values = SliceGroup::<'_, E>::new(values); + + { + let nnz = order.argsort.len(); + assert!(values.len() == nnz); + assert!(order.nnz == self.values.len()); + } + let all_nnz = order.all_nnz; + let mut dst = self.values.rb_mut(); + + let mut pos = 0usize; + let mut pos_unique = usize::MAX; + let mut current_bit = TOP_BIT; + + match mode { + FillMode::Replace => { + while pos < all_nnz { + let argsort_pos = order.argsort[pos]; + let extracted_bit = argsort_pos & TOP_BIT; + let argsort_pos = argsort_pos & TOP_BIT_MASK; + + let val = values.read(argsort_pos); + if extracted_bit != current_bit { + pos_unique = pos_unique.wrapping_add(1); + dst.write(pos_unique, val); + } else { + let old_val = dst.read(pos_unique); + dst.write(pos_unique, old_val.faer_add(val)); + } + + current_bit = extracted_bit; + + pos += 1; + } + } + FillMode::Add => { + while pos < all_nnz { + let argsort_pos = order.argsort[pos]; + let extracted_bit = argsort_pos & TOP_BIT; + let argsort_pos = argsort_pos & TOP_BIT_MASK; + + let val = values.read(argsort_pos); + if extracted_bit != current_bit { + pos_unique = pos_unique.wrapping_add(1); + } + + let old_val = dst.read(pos_unique); + dst.write(pos_unique, old_val.faer_add(val)); + + current_bit = extracted_bit; + + pos += 1; + } + } + } + } +} diff --git a/src/sparse/csc/matown.rs b/src/sparse/csc/matown.rs new file mode 100644 index 0000000000000000000000000000000000000000..f4f3a64d88ecfccb132aa0ab332dedf2f91bee9e --- /dev/null +++ b/src/sparse/csc/matown.rs @@ -0,0 +1,362 @@ +use super::*; +use crate::assert; + +/// Sparse matrix in column-major format, either compressed or uncompressed. +pub struct SparseColMat { + pub(crate) symbolic: SymbolicSparseColMat, + pub(crate) values: VecGroup, +} + +impl SparseColMat { + /// Creates a new sparse matrix view. + /// + /// # Panics + /// + /// Panics if the length of `values` is not equal to the length of + /// `symbolic.row_indices()`. + #[inline] + #[track_caller] + pub fn new(symbolic: SymbolicSparseColMat, values: GroupFor>) -> Self { + let values = VecGroup::from_inner(values); + assert!(symbolic.row_indices().len() == values.len()); + Self { symbolic, values } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.symbolic.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.symbolic.ncols + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.as_ref().to_owned() + } + + /// Copies the current matrix into a newly allocated matrix, with row-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_row_major(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.as_ref().to_row_major() + } + + /// Decomposes the matrix into the symbolic part and the numerical values. + #[inline] + pub fn into_parts(self) -> (SymbolicSparseColMat, GroupFor>) { + (self.symbolic, self.values.into_inner()) + } + + /// Returns a view over `self`. + #[inline] + pub fn as_ref(&self) -> SparseColMatRef<'_, I, E> { + SparseColMatRef { + symbolic: self.symbolic.as_ref(), + values: self.values.as_slice(), + } + } + + /// Returns a mutable view over `self`. + /// + /// Note that the symbolic structure cannot be changed through this view. + #[inline] + pub fn as_mut(&mut self) -> SparseColMatMut<'_, I, E> { + SparseColMatMut { + symbolic: self.symbolic.as_ref(), + values: self.values.as_slice_mut(), + } + } + + /// Returns a slice over the numerical values of the matrix. + #[inline] + pub fn values(&self) -> GroupFor { + self.values.as_slice().into_inner() + } + + /// Returns a mutable slice over the numerical values of the matrix. + #[inline] + pub fn values_mut(&mut self) -> GroupFor { + self.values.as_slice_mut().into_inner() + } + + /// Returns a view over the transpose of `self` in row-major format. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn into_transpose(self) -> SparseRowMat { + SparseRowMat { + symbolic: SymbolicSparseRowMat { + nrows: self.symbolic.ncols, + ncols: self.symbolic.nrows, + row_ptr: self.symbolic.col_ptr, + row_nnz: self.symbolic.col_nnz, + col_ind: self.symbolic.row_ind, + }, + values: self.values, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn into_conjugate(self) -> SparseColMat + where + E: Conjugate, + { + SparseColMat { + symbolic: self.symbolic, + values: unsafe { + VecGroup::::from_inner(transmute_unchecked::< + GroupFor>>, + GroupFor>>, + >(E::faer_map( + self.values.into_inner(), + |mut slice| { + let len = slice.len(); + let cap = slice.capacity(); + let ptr = slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor; + + Vec::from_raw_parts(ptr, len, cap) + }, + ))) + }, + } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn into_adjoint(self) -> SparseRowMat + where + E: Conjugate, + { + self.into_transpose().into_conjugate() + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.symbolic.compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn col_ptrs(&self) -> &'_ [I] { + self.symbolic.col_ptrs() + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_col(&self) -> Option<&'_ [I]> { + self.symbolic.col_nnz.as_deref() + } + + /// Returns the row indices. + #[inline] + pub fn row_indices(&self) -> &'_ [I] { + &self.symbolic.row_ind + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw(&self, j: usize) -> &'_ [I] { + self.symbolic.row_indices_of_col_raw(j) + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col( + &self, + j: usize, + ) -> impl '_ + ExactSizeIterator + DoubleEndedIterator { + self.symbolic.row_indices_of_col(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_range(&self, j: usize) -> Range { + self.symbolic.col_range(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub unsafe fn col_range_unchecked(&self, j: usize) -> Range { + self.symbolic.col_range_unchecked(j) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get(&self, row: usize, col: usize) -> Option> { + self.as_ref().get(row, col) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get_mut(&mut self, row: usize, col: usize) -> Option> { + self.as_mut().get_mut(row, col) + } +} + +impl SparseColMat { + #[track_caller] + pub(crate) fn new_from_order_and_values_impl( + symbolic: SymbolicSparseColMat, + order: &ValuesOrder, + all_values: impl Fn(usize) -> E, + values_len: usize, + ) -> Result { + { + let nnz = order.argsort.len(); + assert!(values_len == nnz); + } + + let all_nnz = order.all_nnz; + + let mut values = VecGroup::::new(); + match values.try_reserve_exact(order.nnz) { + Ok(()) => {} + Err(_) => return Err(FaerError::OutOfMemory), + }; + + let mut pos = 0usize; + let mut pos_unique = usize::MAX; + let mut current_bit = TOP_BIT; + + while pos < all_nnz { + let argsort_pos = order.argsort[pos]; + let extracted_bit = argsort_pos & TOP_BIT; + let argsort_pos = argsort_pos & TOP_BIT_MASK; + + let val = all_values(argsort_pos); + if extracted_bit != current_bit { + values.push(val.faer_into_units()); + pos_unique = pos_unique.wrapping_add(1); + } else { + let old_val = values.as_slice().read(pos_unique); + values + .as_slice_mut() + .write(pos_unique, old_val.faer_add(val)); + } + + current_bit = extracted_bit; + + pos += 1; + } + + Ok(Self { symbolic, values }) + } + + /// Create a new matrix from a previously created symbolic structure and value order. + /// The provided values must correspond to the same indices that were provided in the + /// function call from which the order was created. + #[track_caller] + pub fn new_from_order_and_values( + symbolic: SymbolicSparseColMat, + order: &ValuesOrder, + values: GroupFor, + ) -> Result { + let values = SliceGroup::<'_, E>::new(values); + Self::new_from_order_and_values_impl(symbolic, order, |i| values.read(i), values.len()) + } + + /// Create a new matrix from triplets `(row, col, value)`. + #[track_caller] + pub fn try_new_from_triplets( + nrows: usize, + ncols: usize, + triplets: &[(I, I, E)], + ) -> Result { + let (symbolic, order) = SymbolicSparseColMat::try_new_from_indices_impl( + nrows, + ncols, + |i| { + let (row, col, _) = triplets[i]; + (row, col) + }, + triplets.len(), + )?; + Ok(Self::new_from_order_and_values_impl( + symbolic, + &order, + |i| triplets[i].2, + triplets.len(), + )?) + } + + /// Create a new matrix from triplets `(row, col, value)`. Negative indices are ignored. + #[track_caller] + pub fn try_new_from_nonnegative_triplets( + nrows: usize, + ncols: usize, + triplets: &[(I::Signed, I::Signed, E)], + ) -> Result { + let (symbolic, order) = SymbolicSparseColMat::::try_new_from_nonnegative_indices_impl( + nrows, + ncols, + |i| { + let (row, col, _) = triplets[i]; + (row, col) + }, + triplets.len(), + )?; + Ok(Self::new_from_order_and_values_impl( + symbolic, + &order, + |i| triplets[i].2, + triplets.len(), + )?) + } +} diff --git a/src/sparse/csc/matref.rs b/src/sparse/csc/matref.rs new file mode 100644 index 0000000000000000000000000000000000000000..65f1e20c8ae3ec25607f783582a709373f009231 --- /dev/null +++ b/src/sparse/csc/matref.rs @@ -0,0 +1,379 @@ +use super::*; +use crate::assert; + +/// Sparse matrix view in column-major format, either compressed or uncompressed. +pub struct SparseColMatRef<'a, I: Index, E: Entity> { + pub(crate) symbolic: SymbolicSparseColMatRef<'a, I>, + pub(crate) values: SliceGroup<'a, E>, +} + +impl Copy for SparseColMatRef<'_, I, E> {} +impl Clone for SparseColMatRef<'_, I, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseColMatRef<'_, I, E> { + type Target = SparseColMatRef<'short, I, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseColMatRef<'_, I, E> { + type Target = SparseColMatRef<'short, I, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'a, I: Index, E: Entity> IntoConst for SparseColMatRef<'a, I, E> { + type Target = SparseColMatRef<'a, I, E>; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, I: Index, E: Entity> SparseColMatRef<'a, I, E> { + /// Creates a new sparse matrix view. + /// + /// # Panics + /// + /// Panics if the length of `values` is not equal to the length of + /// `symbolic.row_indices()`. + #[inline] + #[track_caller] + pub fn new( + symbolic: SymbolicSparseColMatRef<'a, I>, + values: GroupFor, + ) -> Self { + let values = SliceGroup::new(values); + assert!(symbolic.row_indices().len() == values.len()); + Self { symbolic, values } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.symbolic.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.symbolic.ncols + } + + /// Returns a view over `self`. + #[inline] + pub fn as_ref(&self) -> SparseColMatRef<'_, I, E> { + *self + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + let symbolic = self.symbolic().to_owned()?; + let mut values = VecGroup::::new(); + + values + .try_reserve_exact(self.values.len()) + .map_err(|_| FaerError::OutOfMemory)?; + + values.resize( + self.values.len(), + E::Canonical::faer_zero().faer_into_units(), + ); + + let src = self.values; + let dst = values.as_slice_mut(); + + for (mut dst, src) in core::iter::zip(dst.into_mut_iter(), src.into_ref_iter()) { + dst.write(src.read().canonicalize()); + } + + Ok(SparseColMat { symbolic, values }) + } + + /// Copies the current matrix into a newly allocated matrix, with row-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_row_major(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + let mut col_ptr = try_zeroed::(self.nrows() + 1)?; + let nnz = self.compute_nnz(); + let mut row_ind = try_zeroed::(nnz)?; + let mut values = VecGroup::::new(); + values + .try_reserve_exact(nnz) + .map_err(|_| FaerError::OutOfMemory)?; + values.resize(nnz, E::Canonical::faer_zero().faer_into_units()); + + let mut mem = GlobalPodBuffer::try_new(StackReq::new::(self.nrows())) + .map_err(|_| FaerError::OutOfMemory)?; + + let (this, conj) = self.canonicalize(); + + if conj == Conj::No { + utils::transpose::( + &mut col_ptr, + &mut row_ind, + values.as_slice_mut().into_inner(), + this, + PodStack::new(&mut mem), + ); + } else { + utils::adjoint::( + &mut col_ptr, + &mut row_ind, + values.as_slice_mut().into_inner(), + this, + PodStack::new(&mut mem), + ); + } + + let transpose = unsafe { + SparseColMat::new( + SymbolicSparseColMat::new_unchecked( + self.ncols(), + self.nrows(), + col_ptr, + None, + row_ind, + ), + values.into_inner(), + ) + }; + + Ok(transpose.into_transpose()) + } + + /// Returns a view over the transpose of `self` in row-major format. + #[inline] + pub fn transpose(self) -> SparseRowMatRef<'a, I, E> { + SparseRowMatRef { + symbolic: SymbolicSparseRowMatRef { + nrows: self.symbolic.ncols, + ncols: self.symbolic.nrows, + row_ptr: self.symbolic.col_ptr, + row_nnz: self.symbolic.col_nnz, + col_ind: self.symbolic.row_ind, + }, + values: self.values, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn conjugate(self) -> SparseColMatRef<'a, I, E::Conj> + where + E: Conjugate, + { + SparseColMatRef { + symbolic: self.symbolic, + values: unsafe { + SliceGroup::<'a, E::Conj>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts( + slice.as_ptr() as *const UnitFor as *const UnitFor, + len, + ) + }, + ))) + }, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn canonicalize(self) -> (SparseColMatRef<'a, I, E::Canonical>, Conj) + where + E: Conjugate, + { + ( + SparseColMatRef { + symbolic: self.symbolic, + values: unsafe { + SliceGroup::<'a, E::Canonical>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts( + slice.as_ptr() as *const UnitFor as *const UnitFor, + len, + ) + }, + ))) + }, + }, + if coe::is_same::() { + Conj::No + } else { + Conj::Yes + }, + ) + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn adjoint(self) -> SparseRowMatRef<'a, I, E::Conj> + where + E: Conjugate, + { + self.transpose().conjugate() + } + + /// Returns the numerical values of the matrix. + #[inline] + pub fn values(self) -> GroupFor { + self.values.into_inner() + } + + /// Returns the numerical values of column `j` of the matrix. + /// + /// # Panics: + /// + /// Panics if `j >= ncols`. + #[inline] + #[track_caller] + pub fn values_of_col(self, j: usize) -> GroupFor { + self.values.subslice(self.col_range(j)).into_inner() + } + + /// Returns the symbolic structure of the matrix. + #[inline] + pub fn symbolic(&self) -> SymbolicSparseColMatRef<'a, I> { + self.symbolic + } + + /// Decomposes the matrix into the symbolic part and the numerical values. + #[inline] + pub fn into_parts(self) -> (SymbolicSparseColMatRef<'a, I>, GroupFor) { + (self.symbolic, self.values.into_inner()) + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.symbolic.compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn col_ptrs(&self) -> &'a [I] { + self.symbolic.col_ptrs() + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_col(&self) -> Option<&'a [I]> { + self.symbolic.col_nnz + } + + /// Returns the row indices. + #[inline] + pub fn row_indices(&self) -> &'a [I] { + self.symbolic.row_ind + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw(&self, j: usize) -> &'a [I] { + self.symbolic.row_indices_of_col_raw(j) + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col( + &self, + j: usize, + ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { + self.symbolic.row_indices_of_col(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_range(&self, j: usize) -> Range { + self.symbolic.col_range(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub unsafe fn col_range_unchecked(&self, j: usize) -> Range { + self.symbolic.col_range_unchecked(j) + } + + /// Returns a reference to the value at the given index, or None if the symbolic structure + /// doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get(self, row: usize, col: usize) -> Option> { + assert!(row < self.nrows()); + assert!(col < self.ncols()); + + let Ok(pos) = self + .row_indices_of_col_raw(col) + .binary_search(&I::truncate(row)) + else { + return None; + }; + + Some(E::faer_map(self.values_of_col(col), |slice| &slice[pos])) + } +} diff --git a/src/sparse/csc/mod.rs b/src/sparse/csc/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..20ca95ff8e8ae8e1b8a7b07d8260112bb9a11fe9 --- /dev/null +++ b/src/sparse/csc/mod.rs @@ -0,0 +1,15 @@ +use super::*; + +mod symbolic_own; +mod symbolic_ref; + +mod matmut; +mod matown; +mod matref; + +pub use symbolic_own::SymbolicSparseColMat; +pub use symbolic_ref::SymbolicSparseColMatRef; + +pub use matmut::SparseColMatMut; +pub use matown::SparseColMat; +pub use matref::SparseColMatRef; diff --git a/src/sparse/csc/symbolic_own.rs b/src/sparse/csc/symbolic_own.rs new file mode 100644 index 0000000000000000000000000000000000000000..d00934f907ebf746d9404eb98156a599f8c80cf2 --- /dev/null +++ b/src/sparse/csc/symbolic_own.rs @@ -0,0 +1,544 @@ +use super::*; +use crate::sparse::csr::*; +use core::marker::PhantomData; + +/// Symbolic structure of sparse matrix in column format, either compressed or uncompressed. +/// +/// Requires: +/// * `nrows <= I::Signed::MAX` (always checked) +/// * `ncols <= I::Signed::MAX` (always checked) +/// * `col_ptrs` has length `ncols + 1` (always checked) +/// * `col_ptrs` is non-decreasing +/// * `col_ptrs[0]..col_ptrs[ncols]` is a valid range in row_indices (always checked, assuming +/// non-decreasing) +/// * if `nnz_per_col` is `None`, elements of `row_indices[col_ptrs[j]..col_ptrs[j + 1]]` are less +/// than `nrows` +/// +/// * `nnz_per_col[j] <= col_ptrs[j+1] - col_ptrs[j]` +/// * if `nnz_per_col` is `Some(_)`, elements of `row_indices[col_ptrs[j]..][..nnz_per_col[j]]` are +/// less than `nrows` +/// +/// * Within each column, row indices are unique and sorted in increasing order. +/// +/// # Note +/// Some algorithms allow working with matrices containing duplicate and/or unsorted row +/// indicers per column. +/// +/// Passing such a matrix to an algorithm that does not explicitly permit this is unspecified +/// (though not undefined) behavior. +#[derive(Clone)] +pub struct SymbolicSparseColMat { + pub(crate) nrows: usize, + pub(crate) ncols: usize, + pub(crate) col_ptr: alloc::vec::Vec, + pub(crate) col_nnz: Option>, + pub(crate) row_ind: alloc::vec::Vec, +} + +impl SymbolicSparseColMat { + /// Creates a new symbolic matrix view after asserting its invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_checked( + nrows: usize, + ncols: usize, + col_ptrs: Vec, + nnz_per_col: Option>, + row_indices: Vec, + ) -> Self { + SymbolicSparseColMatRef::new_checked( + nrows, + ncols, + &col_ptrs, + nnz_per_col.as_deref(), + &row_indices, + ); + + Self { + nrows, + ncols, + col_ptr: col_ptrs, + col_nnz: nnz_per_col, + row_ind: row_indices, + } + } + + /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted row + /// indices per column, after asserting its other invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_unsorted_checked( + nrows: usize, + ncols: usize, + col_ptrs: Vec, + nnz_per_col: Option>, + row_indices: Vec, + ) -> Self { + SymbolicSparseColMatRef::new_unsorted_checked( + nrows, + ncols, + &col_ptrs, + nnz_per_col.as_deref(), + &row_indices, + ); + + Self { + nrows, + ncols, + col_ptr: col_ptrs, + col_nnz: nnz_per_col, + row_ind: row_indices, + } + } + + /// Creates a new symbolic matrix view without asserting its invariants. + /// + /// # Safety + /// + /// See type level documentation. + #[inline(always)] + #[track_caller] + pub unsafe fn new_unchecked( + nrows: usize, + ncols: usize, + col_ptrs: Vec, + nnz_per_col: Option>, + row_indices: Vec, + ) -> Self { + SymbolicSparseRowMatRef::new_unchecked( + nrows, + ncols, + &col_ptrs, + nnz_per_col.as_deref(), + &row_indices, + ); + + Self { + nrows, + ncols, + col_ptr: col_ptrs, + col_nnz: nnz_per_col, + row_ind: row_indices, + } + } + + /// Returns the components of the matrix in the order: + /// - row count, + /// - column count, + /// - column pointers, + /// - nonzeros per column, + /// - row indices. + #[inline] + pub fn into_parts(self) -> (usize, usize, Vec, Option>, Vec) { + ( + self.nrows, + self.ncols, + self.col_ptr, + self.col_nnz, + self.row_ind, + ) + } + + /// Returns a view over the symbolic structure of `self`. + #[inline] + pub fn as_ref(&self) -> SymbolicSparseColMatRef<'_, I> { + SymbolicSparseColMatRef { + nrows: self.nrows, + ncols: self.ncols, + col_ptr: &self.col_ptr, + col_nnz: self.col_nnz.as_deref(), + row_ind: &self.row_ind, + } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.ncols + } + + /// Consumes the matrix, and returns its transpose in row-major format without reallocating. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn into_transpose(self) -> SymbolicSparseRowMat { + SymbolicSparseRowMat { + nrows: self.ncols, + ncols: self.nrows, + row_ptr: self.col_ptr, + row_nnz: self.col_nnz, + col_ind: self.row_ind, + } + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> { + self.as_ref().to_owned() + } + + /// Copies the current matrix into a newly allocated matrix, with row-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_row_major(&self) -> Result, FaerError> { + self.as_ref().to_row_major() + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.as_ref().compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn col_ptrs(&self) -> &[I] { + &self.col_ptr + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_col(&self) -> Option<&[I]> { + self.col_nnz.as_deref() + } + + /// Returns the row indices. + #[inline] + pub fn row_indices(&self) -> &[I] { + &self.row_ind + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw(&self, j: usize) -> &[I] { + &self.row_ind[self.col_range(j)] + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col( + &self, + j: usize, + ) -> impl '_ + ExactSizeIterator + DoubleEndedIterator { + self.as_ref().row_indices_of_col(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_range(&self, j: usize) -> Range { + self.as_ref().col_range(j) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub unsafe fn col_range_unchecked(&self, j: usize) -> Range { + self.as_ref().col_range_unchecked(j) + } +} + +impl core::fmt::Debug for SymbolicSparseColMat { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_ref().fmt(f) + } +} + +impl SymbolicSparseColMat { + pub(crate) fn try_new_from_indices_impl( + nrows: usize, + ncols: usize, + indices: impl Fn(usize) -> (I, I), + all_nnz: usize, + ) -> Result<(Self, ValuesOrder), CreationError> { + if nrows > I::Signed::MAX.zx() || ncols > I::Signed::MAX.zx() { + return Err(CreationError::Generic(FaerError::IndexOverflow)); + } + + if all_nnz == 0 { + return Ok(( + Self { + nrows, + ncols, + col_ptr: try_zeroed(ncols + 1)?, + col_nnz: None, + row_ind: Vec::new(), + }, + ValuesOrder { + argsort: Vec::new(), + all_nnz, + nnz: 0, + __marker: PhantomData, + }, + )); + } + + let mut argsort = try_collect(0..all_nnz)?; + argsort.sort_unstable_by_key(|&i| { + let (row, col) = indices(i); + (col, row) + }); + + let mut n_duplicates = 0usize; + let mut current_bit = 0usize; + + let mut prev = indices(argsort[0]); + for i in 1..all_nnz { + let idx = indices(argsort[i]); + let same_as_prev = idx == prev; + prev = idx; + current_bit = ((current_bit == ((same_as_prev as usize) << (usize::BITS - 1))) + as usize) + << (usize::BITS - 1); + argsort[i] |= current_bit; + + n_duplicates += same_as_prev as usize; + } + + let nnz = all_nnz - n_duplicates; + if nnz > I::Signed::MAX.zx() { + return Err(CreationError::Generic(FaerError::IndexOverflow)); + } + + let mut col_ptr = try_zeroed::(ncols + 1)?; + let mut row_ind = try_zeroed::(nnz)?; + + let mut original_pos = 0usize; + let mut new_pos = 0usize; + + for j in 0..ncols { + let mut n_unique = 0usize; + + while original_pos < all_nnz { + let (row, col) = indices(argsort[original_pos] & TOP_BIT_MASK); + if row.zx() >= nrows || col.zx() >= ncols { + return Err(CreationError::OutOfBounds { + row: row.zx(), + col: col.zx(), + }); + } + + if col.zx() != j { + break; + } + + row_ind[new_pos] = row; + + n_unique += 1; + + new_pos += 1; + original_pos += 1; + + while original_pos < all_nnz + && indices(argsort[original_pos] & TOP_BIT_MASK) == (row, col) + { + original_pos += 1; + } + } + + col_ptr[j + 1] = col_ptr[j] + I::truncate(n_unique); + } + + Ok(( + Self { + nrows, + ncols, + col_ptr, + col_nnz: None, + row_ind, + }, + ValuesOrder { + argsort, + all_nnz, + nnz, + __marker: PhantomData, + }, + )) + } + + pub(crate) fn try_new_from_nonnegative_indices_impl( + nrows: usize, + ncols: usize, + indices: impl Fn(usize) -> (I::Signed, I::Signed), + all_nnz: usize, + ) -> Result<(Self, ValuesOrder), CreationError> { + if nrows > I::Signed::MAX.zx() || ncols > I::Signed::MAX.zx() { + return Err(CreationError::Generic(FaerError::IndexOverflow)); + } + + let mut argsort = try_collect(0..all_nnz)?; + argsort.sort_unstable_by_key(|&i| { + let (row, col) = indices(i); + let ignore = (row < I::Signed::truncate(0)) | (col < I::Signed::truncate(0)); + (ignore, col, row) + }); + + let all_nnz = argsort.partition_point(|&i| { + let (row, col) = indices(i); + let ignore = (row < I::Signed::truncate(0)) | (col < I::Signed::truncate(0)); + !ignore + }); + + if all_nnz == 0 { + return Ok(( + Self { + nrows, + ncols, + col_ptr: try_zeroed(ncols + 1)?, + col_nnz: None, + row_ind: Vec::new(), + }, + ValuesOrder { + argsort: Vec::new(), + all_nnz, + nnz: 0, + __marker: PhantomData, + }, + )); + } + + let mut n_duplicates = 0usize; + let mut current_bit = 0usize; + + let mut prev = indices(argsort[0]); + + for i in 1..all_nnz { + let idx = indices(argsort[i]); + let same_as_prev = idx == prev; + prev = idx; + current_bit = ((current_bit == ((same_as_prev as usize) << (usize::BITS - 1))) + as usize) + << (usize::BITS - 1); + argsort[i] |= current_bit; + + n_duplicates += same_as_prev as usize; + } + + let nnz = all_nnz - n_duplicates; + if nnz > I::Signed::MAX.zx() { + return Err(CreationError::Generic(FaerError::IndexOverflow)); + } + + let mut col_ptr = try_zeroed::(ncols + 1)?; + let mut row_ind = try_zeroed::(nnz)?; + + let mut original_pos = 0usize; + let mut new_pos = 0usize; + + for j in 0..ncols { + let mut n_unique = 0usize; + + while original_pos < all_nnz { + let (row, col) = indices(argsort[original_pos] & TOP_BIT_MASK); + if row.zx() >= nrows || col.zx() >= ncols { + return Err(CreationError::OutOfBounds { + row: row.zx(), + col: col.zx(), + }); + } + + if col.zx() != j { + break; + } + + row_ind[new_pos] = I::from_signed(row); + + n_unique += 1; + + new_pos += 1; + original_pos += 1; + + while original_pos < all_nnz + && indices(argsort[original_pos] & TOP_BIT_MASK) == (row, col) + { + original_pos += 1; + } + } + + col_ptr[j + 1] = col_ptr[j] + I::truncate(n_unique); + } + + Ok(( + Self { + nrows, + ncols, + col_ptr, + col_nnz: None, + row_ind, + }, + ValuesOrder { + argsort, + all_nnz, + nnz, + __marker: PhantomData, + }, + )) + } + + /// Create a new symbolic structure, and the corresponding order for the numerical values + /// from pairs of indices `(row, col)`. + #[inline] + pub fn try_new_from_indices( + nrows: usize, + ncols: usize, + indices: &[(I, I)], + ) -> Result<(Self, ValuesOrder), CreationError> { + Self::try_new_from_indices_impl(nrows, ncols, |i| indices[i], indices.len()) + } + + /// Create a new symbolic structure, and the corresponding order for the numerical values + /// from pairs of indices `(row, col)`. + /// + /// Negative indices are ignored. + #[inline] + pub fn try_new_from_nonnegative_indices( + nrows: usize, + ncols: usize, + indices: &[(I::Signed, I::Signed)], + ) -> Result<(Self, ValuesOrder), CreationError> { + Self::try_new_from_nonnegative_indices_impl(nrows, ncols, |i| indices[i], indices.len()) + } +} diff --git a/src/sparse/csc/symbolic_ref.rs b/src/sparse/csc/symbolic_ref.rs new file mode 100644 index 0000000000000000000000000000000000000000..d3e5ba9771035b50854dacdf939b0eb1f5b7cddd --- /dev/null +++ b/src/sparse/csc/symbolic_ref.rs @@ -0,0 +1,401 @@ +use super::*; +use crate::assert; + +/// Symbolic view structure of sparse matrix in column format, either compressed or uncompressed. +/// +/// Requires: +/// * `nrows <= I::Signed::MAX` (always checked) +/// * `ncols <= I::Signed::MAX` (always checked) +/// * `col_ptrs` has length `ncols + 1` (always checked) +/// * `col_ptrs` is non-decreasing +/// * `col_ptrs[0]..col_ptrs[ncols]` is a valid range in row_indices (always checked, assuming +/// non-decreasing) +/// * if `nnz_per_col` is `None`, elements of `row_indices[col_ptrs[j]..col_ptrs[j + 1]]` are less +/// than `nrows` +/// +/// * `nnz_per_col[j] <= col_ptrs[j+1] - col_ptrs[j]` +/// * if `nnz_per_col` is `Some(_)`, elements of `row_indices[col_ptrs[j]..][..nnz_per_col[j]]` are +/// less than `nrows` +/// +/// * Within each column, row indices are unique and sorted in increasing order. +/// +/// # Note +/// Some algorithms allow working with matrices containing duplicate and/or unsorted row +/// indicers per column. +/// +/// Passing such a matrix to an algorithm that does not explicitly permit this is unspecified +/// (though not undefined) behavior. +pub struct SymbolicSparseColMatRef<'a, I: Index> { + pub(crate) nrows: usize, + pub(crate) ncols: usize, + pub(crate) col_ptr: &'a [I], + pub(crate) col_nnz: Option<&'a [I]>, + pub(crate) row_ind: &'a [I], +} + +impl Copy for SymbolicSparseColMatRef<'_, I> {} +impl Clone for SymbolicSparseColMatRef<'_, I> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'short, I: Index> Reborrow<'short> for SymbolicSparseColMatRef<'_, I> { + type Target = SymbolicSparseColMatRef<'short, I>; + + #[inline] + fn rb(&self) -> Self::Target { + *self + } +} + +impl<'short, I: Index> ReborrowMut<'short> for SymbolicSparseColMatRef<'_, I> { + type Target = SymbolicSparseColMatRef<'short, I>; + + #[inline] + fn rb_mut(&mut self) -> Self::Target { + *self + } +} + +impl<'a, I: Index> IntoConst for SymbolicSparseColMatRef<'a, I> { + type Target = SymbolicSparseColMatRef<'a, I>; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, I: Index> SymbolicSparseColMatRef<'a, I> { + /// Creates a new symbolic matrix view after asserting its invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_checked( + nrows: usize, + ncols: usize, + col_ptrs: &'a [I], + nnz_per_col: Option<&'a [I]>, + row_indices: &'a [I], + ) -> Self { + assert!(all( + ncols <= I::Signed::MAX.zx(), + nrows <= I::Signed::MAX.zx(), + )); + assert!(col_ptrs.len() == ncols + 1); + for &[c, c_next] in windows2(col_ptrs) { + assert!(c <= c_next); + } + assert!(col_ptrs[ncols].zx() <= row_indices.len()); + + if let Some(nnz_per_col) = nnz_per_col { + for (&nnz_j, &[c, c_next]) in zip(nnz_per_col, windows2(col_ptrs)) { + assert!(nnz_j <= c_next - c); + let row_indices = &row_indices[c.zx()..c.zx() + nnz_j.zx()]; + if !row_indices.is_empty() { + let mut i_prev = row_indices[0]; + for &i in &row_indices[1..] { + assert!(i_prev < i); + i_prev = i; + } + let nrows = I::truncate(nrows); + assert!(i_prev < nrows); + } + } + } else { + for &[c, c_next] in windows2(col_ptrs) { + let row_indices = &row_indices[c.zx()..c_next.zx()]; + if !row_indices.is_empty() { + let mut i_prev = row_indices[0]; + for &i in &row_indices[1..] { + assert!(i_prev < i); + i_prev = i; + } + let nrows = I::truncate(nrows); + assert!(i_prev < nrows); + } + } + } + + Self { + nrows, + ncols, + col_ptr: col_ptrs, + col_nnz: nnz_per_col, + row_ind: row_indices, + } + } + + /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted row + /// indices per column, after asserting its other invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_unsorted_checked( + nrows: usize, + ncols: usize, + col_ptrs: &'a [I], + nnz_per_col: Option<&'a [I]>, + row_indices: &'a [I], + ) -> Self { + assert!(all( + ncols <= I::Signed::MAX.zx(), + nrows <= I::Signed::MAX.zx(), + )); + assert!(col_ptrs.len() == ncols + 1); + for &[c, c_next] in windows2(col_ptrs) { + assert!(c <= c_next); + } + assert!(col_ptrs[ncols].zx() <= row_indices.len()); + + if let Some(nnz_per_col) = nnz_per_col { + for (&nnz_j, &[c, c_next]) in zip(nnz_per_col, windows2(col_ptrs)) { + assert!(nnz_j <= c_next - c); + for &i in &row_indices[c.zx()..c.zx() + nnz_j.zx()] { + assert!(i < I::truncate(nrows)); + } + } + } else { + let c0 = col_ptrs[0].zx(); + let cn = col_ptrs[ncols].zx(); + for &i in &row_indices[c0..cn] { + assert!(i < I::truncate(nrows)); + } + } + + Self { + nrows, + ncols, + col_ptr: col_ptrs, + col_nnz: nnz_per_col, + row_ind: row_indices, + } + } + + /// Creates a new symbolic matrix view without asserting its invariants. + /// + /// # Safety + /// + /// See type level documentation. + #[inline(always)] + #[track_caller] + pub unsafe fn new_unchecked( + nrows: usize, + ncols: usize, + col_ptrs: &'a [I], + nnz_per_col: Option<&'a [I]>, + row_indices: &'a [I], + ) -> Self { + assert!(all( + ncols <= ::MAX.zx(), + nrows <= ::MAX.zx(), + )); + assert!(col_ptrs.len() == ncols + 1); + assert!(col_ptrs[ncols].zx() <= row_indices.len()); + + Self { + nrows, + ncols, + col_ptr: col_ptrs, + col_nnz: nnz_per_col, + row_ind: row_indices, + } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.ncols + } + + /// Returns a view over the transpose of `self` in row-major format. + #[inline] + pub fn transpose(self) -> SymbolicSparseRowMatRef<'a, I> { + SymbolicSparseRowMatRef { + nrows: self.ncols, + ncols: self.nrows, + row_ptr: self.col_ptr, + row_nnz: self.col_nnz, + col_ind: self.row_ind, + } + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> { + Ok(SymbolicSparseColMat { + nrows: self.nrows, + ncols: self.ncols, + col_ptr: try_collect(self.col_ptr.iter().copied())?, + col_nnz: self + .col_nnz + .map(|nnz| try_collect(nnz.iter().copied())) + .transpose()?, + row_ind: try_collect(self.row_ind.iter().copied())?, + }) + } + + /// Copies the current matrix into a newly allocated matrix, with row-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_row_major(&self) -> Result, FaerError> { + let mut col_ptr = try_zeroed::(self.nrows + 1)?; + let mut row_ind = try_zeroed::(self.compute_nnz())?; + + let mut mem = GlobalPodBuffer::try_new(dyn_stack::StackReq::new::(self.nrows)) + .map_err(|_| FaerError::OutOfMemory)?; + + utils::adjoint_symbolic( + &mut col_ptr, + &mut row_ind, + *self, + dyn_stack::PodStack::new(&mut mem), + ); + + let transpose = unsafe { + SymbolicSparseColMat::new_unchecked(self.ncols, self.nrows, col_ptr, None, row_ind) + }; + + Ok(transpose.into_transpose()) + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + match self.col_nnz { + Some(col_nnz) => { + let mut nnz = 0usize; + for &nnz_j in col_nnz { + // can't overflow + nnz += nnz_j.zx(); + } + nnz + } + None => self.col_ptr[self.ncols].zx() - self.col_ptr[0].zx(), + } + } + + /// Returns the column pointers. + #[inline] + pub fn col_ptrs(&self) -> &'a [I] { + self.col_ptr + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_col(&self) -> Option<&'a [I]> { + self.col_nnz + } + + /// Returns the row indices. + #[inline] + pub fn row_indices(&self) -> &'a [I] { + self.row_ind + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw(&self, j: usize) -> &'a [I] { + &self.row_ind[self.col_range(j)] + } + + /// Returns the row indices of column `j`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn row_indices_of_col( + &self, + j: usize, + ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { + self.row_indices_of_col_raw(j).iter().map( + #[inline(always)] + |&i| i.zx(), + ) + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Panics + /// + /// Panics if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_range(&self, j: usize) -> Range { + let start = self.col_ptr[j].zx(); + let end = self + .col_nnz + .map(|col_nnz| col_nnz[j].zx() + start) + .unwrap_or(self.col_ptr[j + 1].zx()); + + start..end + } + + /// Returns the range that the column `j` occupies in `self.row_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `j >= self.ncols()`. + #[inline] + #[track_caller] + pub unsafe fn col_range_unchecked(&self, j: usize) -> Range { + let start = __get_unchecked(self.col_ptr, j).zx(); + let end = self + .col_nnz + .map(|col_nnz| (__get_unchecked(col_nnz, j).zx() + start)) + .unwrap_or(__get_unchecked(self.col_ptr, j + 1).zx()); + + start..end + } +} + +impl core::fmt::Debug for SymbolicSparseColMatRef<'_, I> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mat = *self; + let mut iter = (0..mat.ncols()).flat_map(move |j| { + struct Wrapper(usize, usize); + impl core::fmt::Debug for Wrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let row = self.0; + let col = self.1; + write!(f, "({row}, {col}") + } + } + + mat.row_indices_of_col(j).map(move |i| Wrapper(i, j)) + }); + + f.debug_list().entries(&mut iter).finish() + } +} diff --git a/src/sparse/csr/matmut.rs b/src/sparse/csr/matmut.rs new file mode 100644 index 0000000000000000000000000000000000000000..b61fa381f0ac5242e9b8234649ef0e0bd161f2c6 --- /dev/null +++ b/src/sparse/csr/matmut.rs @@ -0,0 +1,351 @@ +use super::*; +use crate::assert; + +/// Sparse matrix view in column-major format, either compressed or uncompressed. +pub struct SparseRowMatMut<'a, I: Index, E: Entity> { + pub(crate) symbolic: SymbolicSparseRowMatRef<'a, I>, + pub(crate) values: SliceGroupMut<'a, E>, +} + +impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseRowMatMut<'_, I, E> { + type Target = SparseRowMatRef<'short, I, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + SparseRowMatRef { + symbolic: self.symbolic, + values: self.values.rb(), + } + } +} + +impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseRowMatMut<'_, I, E> { + type Target = SparseRowMatMut<'short, I, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + SparseRowMatMut { + symbolic: self.symbolic, + values: self.values.rb_mut(), + } + } +} + +impl<'a, I: Index, E: Entity> IntoConst for SparseRowMatMut<'a, I, E> { + type Target = SparseRowMatRef<'a, I, E>; + + #[inline] + fn into_const(self) -> Self::Target { + SparseRowMatRef { + symbolic: self.symbolic, + values: self.values.into_const(), + } + } +} + +impl<'a, I: Index, E: Entity> SparseRowMatMut<'a, I, E> { + /// Creates a new sparse matrix view. + /// + /// # Panics + /// + /// Panics if the length of `values` is not equal to the length of + /// `symbolic.col_indices()`. + #[inline] + #[track_caller] + pub fn new( + symbolic: SymbolicSparseRowMatRef<'a, I>, + values: GroupFor, + ) -> Self { + let values = SliceGroupMut::new(values); + assert!(symbolic.col_indices().len() == values.len()); + Self { symbolic, values } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.symbolic.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.symbolic.ncols + } + + /// Returns a view over `self`. + #[inline] + pub fn as_ref(&self) -> SparseRowMatRef<'_, I, E> { + (*self).rb() + } + + /// Returns a mutable view over `self`. + /// + /// Note that the symbolic structure cannot be changed through this view. + #[inline] + pub fn as_mut(&mut self) -> SparseRowMatMut<'_, I, E> { + (*self).rb_mut() + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.rb().to_owned() + } + + /// Copies the current matrix into a newly allocated matrix, with column-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_col_major(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.rb().to_col_major() + } + + /// Returns a view over the transpose of `self` in column-major format. + #[inline] + pub fn transpose_mut(self) -> SparseColMatMut<'a, I, E> { + SparseColMatMut { + symbolic: SymbolicSparseColMatRef { + nrows: self.symbolic.ncols, + ncols: self.symbolic.nrows, + col_ptr: self.symbolic.row_ptr, + col_nnz: self.symbolic.row_nnz, + row_ind: self.symbolic.col_ind, + }, + values: self.values, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn canonicalize_mut(self) -> (SparseRowMatMut<'a, I, E::Canonical>, Conj) + where + E: Conjugate, + { + ( + SparseRowMatMut { + symbolic: self.symbolic, + values: unsafe { + SliceGroupMut::<'a, E::Canonical>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts_mut( + slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor, + len, + ) + }, + ))) + }, + }, + if coe::is_same::() { + Conj::No + } else { + Conj::Yes + }, + ) + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn conjugate_mut(self) -> SparseRowMatMut<'a, I, E::Conj> + where + E: Conjugate, + { + SparseRowMatMut { + symbolic: self.symbolic, + values: unsafe { + SliceGroupMut::<'a, E::Conj>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts_mut( + slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor, + len, + ) + }, + ))) + }, + } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn adjoint_mut(self) -> SparseColMatMut<'a, I, E::Conj> + where + E: Conjugate, + { + self.transpose_mut().conjugate_mut() + } + + /// Returns the numerical values of the matrix. + #[inline] + pub fn values_mut(self) -> GroupFor { + self.values.into_inner() + } + + /// Returns the numerical values of row `i` of the matrix. + /// + /// # Panics: + /// + /// Panics if `i >= nrows`. + #[inline] + #[track_caller] + pub fn values_of_row_mut(self, i: usize) -> GroupFor { + let range = self.symbolic().row_range(i); + self.values.subslice(range).into_inner() + } + + /// Returns the symbolic structure of the matrix. + #[inline] + pub fn symbolic(&self) -> SymbolicSparseRowMatRef<'a, I> { + self.symbolic + } + + /// Decomposes the matrix into the symbolic part and the numerical values. + #[inline] + pub fn into_parts( + self, + ) -> ( + SymbolicSparseRowMatRef<'a, I>, + GroupFor, + ) { + (self.symbolic, self.values.into_inner()) + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.symbolic.compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn row_ptrs(&self) -> &'a [I] { + self.symbolic.row_ptrs() + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_row(&self) -> Option<&'a [I]> { + self.symbolic.nnz_per_row() + } + + /// Returns the column indices. + #[inline] + pub fn col_indices(&self) -> &'a [I] { + self.symbolic.col_indices() + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row_raw(&self, i: usize) -> &'a [I] { + self.symbolic.col_indices_of_row_raw(i) + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row( + &self, + i: usize, + ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { + self.symbolic.col_indices_of_row(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn row_range(&self, i: usize) -> Range { + self.symbolic.row_range(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub unsafe fn row_range_unchecked(&self, i: usize) -> Range { + self.symbolic.row_range_unchecked(i) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get_mut(self, row: usize, col: usize) -> Option> { + assert!(row < self.nrows()); + assert!(col < self.ncols()); + + let Ok(pos) = self + .col_indices_of_row_raw(row) + .binary_search(&I::truncate(col)) + else { + return None; + }; + + Some(E::faer_map(self.values_of_row_mut(row), |slice| { + &mut slice[pos] + })) + } +} + +impl SparseRowMatMut<'_, I, E> { + /// Fill the matrix from a previously created value order. + /// The provided values must correspond to the same indices that were provided in the + /// function call from which the order was created. + /// + /// # Note + /// The symbolic structure is not changed. + pub fn fill_from_order_and_values( + &mut self, + order: &ValuesOrder, + values: GroupFor, + mode: FillMode, + ) { + self.rb_mut() + .transpose_mut() + .fill_from_order_and_values(order, values, mode); + } +} diff --git a/src/sparse/csr/matown.rs b/src/sparse/csr/matown.rs new file mode 100644 index 0000000000000000000000000000000000000000..72b2c092b1c8a48874f05a56b300e3f84132412c --- /dev/null +++ b/src/sparse/csr/matown.rs @@ -0,0 +1,316 @@ +use super::*; +use crate::assert; + +/// Sparse matrix in column-major format, either compressed or uncompressed. +pub struct SparseRowMat { + pub(crate) symbolic: SymbolicSparseRowMat, + pub(crate) values: VecGroup, +} + +impl SparseRowMat { + /// Creates a new sparse matrix view. + /// + /// # Panics + /// + /// Panics if the length of `values` is not equal to the length of + /// `symbolic.col_indices()`. + #[inline] + #[track_caller] + pub fn new(symbolic: SymbolicSparseRowMat, values: GroupFor>) -> Self { + let values = VecGroup::from_inner(values); + assert!(symbolic.col_indices().len() == values.len()); + Self { symbolic, values } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.symbolic.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.symbolic.ncols + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.as_ref().to_owned() + } + + /// Copies the current matrix into a newly allocated matrix, with column-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_col_major(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.as_ref().to_col_major() + } + + /// Decomposes the matrix into the symbolic part and the numerical values. + #[inline] + pub fn into_parts(self) -> (SymbolicSparseRowMat, GroupFor>) { + (self.symbolic, self.values.into_inner()) + } + + /// Returns a view over `self`. + #[inline] + pub fn as_ref(&self) -> SparseRowMatRef<'_, I, E> { + SparseRowMatRef { + symbolic: self.symbolic.as_ref(), + values: self.values.as_slice(), + } + } + + /// Returns a mutable view over `self`. + /// + /// Note that the symbolic structure cannot be changed through this view. + #[inline] + pub fn as_mut(&mut self) -> SparseRowMatMut<'_, I, E> { + SparseRowMatMut { + symbolic: self.symbolic.as_ref(), + values: self.values.as_slice_mut(), + } + } + + /// Returns a slice over the numerical values of the matrix. + #[inline] + pub fn values(&self) -> GroupFor { + self.values.as_slice().into_inner() + } + + /// Returns a mutable slice over the numerical values of the matrix. + #[inline] + pub fn values_mut(&mut self) -> GroupFor { + self.values.as_slice_mut().into_inner() + } + + /// Returns a view over the transpose of `self` in column-major format. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn into_transpose(self) -> SparseColMat { + SparseColMat { + symbolic: SymbolicSparseColMat { + nrows: self.symbolic.ncols, + ncols: self.symbolic.nrows, + col_ptr: self.symbolic.row_ptr, + col_nnz: self.symbolic.row_nnz, + row_ind: self.symbolic.col_ind, + }, + values: self.values, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn into_conjugate(self) -> SparseRowMat + where + E: Conjugate, + { + SparseRowMat { + symbolic: self.symbolic, + values: unsafe { + VecGroup::::from_inner(transmute_unchecked::< + GroupFor>>, + GroupFor>>, + >(E::faer_map( + self.values.into_inner(), + |mut slice| { + let len = slice.len(); + let cap = slice.capacity(); + let ptr = slice.as_mut_ptr() as *mut UnitFor as *mut UnitFor; + + Vec::from_raw_parts(ptr, len, cap) + }, + ))) + }, + } + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn into_adjoint(self) -> SparseColMat + where + E: Conjugate, + { + self.into_transpose().into_conjugate() + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.symbolic.compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn row_ptrs(&self) -> &'_ [I] { + self.symbolic.row_ptrs() + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_row(&self) -> Option<&'_ [I]> { + self.symbolic.nnz_per_row() + } + + /// Returns the column indices. + #[inline] + pub fn col_indices(&self) -> &'_ [I] { + self.symbolic.col_indices() + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row_raw(&self, i: usize) -> &'_ [I] { + self.symbolic.col_indices_of_row_raw(i) + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row( + &self, + i: usize, + ) -> impl '_ + ExactSizeIterator + DoubleEndedIterator { + self.symbolic.col_indices_of_row(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn row_range(&self, i: usize) -> Range { + self.symbolic.row_range(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub unsafe fn row_range_unchecked(&self, i: usize) -> Range { + self.symbolic.row_range_unchecked(i) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get(&self, row: usize, col: usize) -> Option> { + self.as_ref().get(row, col) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get_mut(&mut self, row: usize, col: usize) -> Option> { + self.as_mut().get_mut(row, col) + } +} + +impl SparseRowMat { + /// Create a new matrix from a previously created symbolic structure and value order. + /// The provided values must correspond to the same indices that were provided in the + /// function call from which the order was created. + #[track_caller] + pub fn new_from_order_and_values( + symbolic: SymbolicSparseRowMat, + order: &ValuesOrder, + values: GroupFor, + ) -> Result { + SparseColMat::new_from_order_and_values(symbolic.into_transpose(), order, values) + .map(SparseColMat::into_transpose) + } + + /// Create a new matrix from triplets `(row, col, value)`. + #[track_caller] + pub fn try_new_from_triplets( + nrows: usize, + ncols: usize, + triplets: &[(I, I, E)], + ) -> Result { + let (symbolic, order) = SymbolicSparseColMat::try_new_from_indices_impl( + ncols, + nrows, + |i| { + let (row, col, _) = triplets[i]; + (col, row) + }, + triplets.len(), + )?; + Ok(SparseColMat::new_from_order_and_values_impl( + symbolic, + &order, + |i| triplets[i].2, + triplets.len(), + )? + .into_transpose()) + } + + /// Create a new matrix from triplets `(row, col, value)`. Negative indices are ignored. + #[track_caller] + pub fn try_new_from_nonnegative_triplets( + nrows: usize, + ncols: usize, + triplets: &[(I::Signed, I::Signed, E)], + ) -> Result { + let (symbolic, order) = SymbolicSparseColMat::::try_new_from_nonnegative_indices_impl( + ncols, + nrows, + |i| { + let (row, col, _) = triplets[i]; + (col, row) + }, + triplets.len(), + )?; + Ok(SparseColMat::new_from_order_and_values_impl( + symbolic, + &order, + |i| triplets[i].2, + triplets.len(), + )? + .into_transpose()) + } +} diff --git a/src/sparse/csr/matref.rs b/src/sparse/csr/matref.rs new file mode 100644 index 0000000000000000000000000000000000000000..261800625fab64c58de9af1e818defd32a825e16 --- /dev/null +++ b/src/sparse/csr/matref.rs @@ -0,0 +1,319 @@ +use super::*; +use crate::assert; + +/// Sparse matrix view in column-major format, either compressed or uncompressed. +pub struct SparseRowMatRef<'a, I: Index, E: Entity> { + pub(crate) symbolic: SymbolicSparseRowMatRef<'a, I>, + pub(crate) values: SliceGroup<'a, E>, +} + +impl Copy for SparseRowMatRef<'_, I, E> {} +impl Clone for SparseRowMatRef<'_, I, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'short, I: Index, E: Entity> Reborrow<'short> for SparseRowMatRef<'_, I, E> { + type Target = SparseRowMatRef<'short, I, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, I: Index, E: Entity> ReborrowMut<'short> for SparseRowMatRef<'_, I, E> { + type Target = SparseRowMatRef<'short, I, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'a, I: Index, E: Entity> IntoConst for SparseRowMatRef<'a, I, E> { + type Target = SparseRowMatRef<'a, I, E>; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, I: Index, E: Entity> SparseRowMatRef<'a, I, E> { + /// Creates a new sparse matrix view. + /// + /// # Panics + /// + /// Panics if the length of `values` is not equal to the length of + /// `symbolic.col_indices()`. + #[inline] + #[track_caller] + pub fn new( + symbolic: SymbolicSparseRowMatRef<'a, I>, + values: GroupFor, + ) -> Self { + let values = SliceGroup::new(values); + assert!(symbolic.col_indices().len() == values.len()); + Self { symbolic, values } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.symbolic.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.symbolic.ncols + } + + /// Returns a view over `self`. + #[inline] + pub fn as_ref(&self) -> SparseRowMatRef<'_, I, E> { + *self + } + + /// Returns the numerical values of the matrix. + #[inline] + pub fn values(self) -> GroupFor { + self.values.into_inner() + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.transpose() + .to_owned() + .map(SparseColMat::into_transpose) + } + + /// Copies the current matrix into a newly allocated matrix, with column-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_col_major(&self) -> Result, FaerError> + where + E: Conjugate, + E::Canonical: ComplexField, + { + self.transpose() + .to_row_major() + .map(SparseRowMat::into_transpose) + } + + /// Returns a view over the transpose of `self` in column-major format. + #[inline] + pub fn transpose(self) -> SparseColMatRef<'a, I, E> { + SparseColMatRef { + symbolic: SymbolicSparseColMatRef { + nrows: self.symbolic.ncols, + ncols: self.symbolic.nrows, + col_ptr: self.symbolic.row_ptr, + col_nnz: self.symbolic.row_nnz, + row_ind: self.symbolic.col_ind, + }, + values: self.values, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn conjugate(self) -> SparseRowMatRef<'a, I, E::Conj> + where + E: Conjugate, + { + SparseRowMatRef { + symbolic: self.symbolic, + values: unsafe { + SliceGroup::<'a, E::Conj>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts( + slice.as_ptr() as *const UnitFor as *const UnitFor, + len, + ) + }, + ))) + }, + } + } + + /// Returns a view over the conjugate of `self`. + #[inline] + pub fn canonicalize(self) -> (SparseRowMatRef<'a, I, E::Canonical>, Conj) + where + E: Conjugate, + { + ( + SparseRowMatRef { + symbolic: self.symbolic, + values: unsafe { + SliceGroup::<'a, E::Canonical>::new(transmute_unchecked::< + GroupFor]>, + GroupFor]>, + >(E::faer_map( + self.values.into_inner(), + |slice| { + let len = slice.len(); + core::slice::from_raw_parts( + slice.as_ptr() as *const UnitFor as *const UnitFor, + len, + ) + }, + ))) + }, + }, + if coe::is_same::() { + Conj::No + } else { + Conj::Yes + }, + ) + } + + /// Returns a view over the conjugate transpose of `self`. + #[inline] + pub fn adjoint(self) -> SparseColMatRef<'a, I, E::Conj> + where + E: Conjugate, + { + self.transpose().conjugate() + } + + /// Returns the numerical values of row `i` of the matrix. + /// + /// # Panics: + /// + /// Panics if `i >= nrows`. + #[inline] + #[track_caller] + pub fn values_of_row(self, i: usize) -> GroupFor { + self.values.subslice(self.row_range(i)).into_inner() + } + + /// Returns the symbolic structure of the matrix. + #[inline] + pub fn symbolic(&self) -> SymbolicSparseRowMatRef<'a, I> { + self.symbolic + } + + /// Decomposes the matrix into the symbolic part and the numerical values. + #[inline] + pub fn into_parts(self) -> (SymbolicSparseRowMatRef<'a, I>, GroupFor) { + (self.symbolic, self.values.into_inner()) + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.transpose().compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn row_ptrs(&self) -> &'a [I] { + self.symbolic.row_ptrs() + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_row(&self) -> Option<&'a [I]> { + self.symbolic.nnz_per_row() + } + + /// Returns the column indices. + #[inline] + pub fn col_indices(&self) -> &'a [I] { + self.symbolic.col_indices() + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row_raw(&self, i: usize) -> &'a [I] { + self.symbolic.col_indices_of_row_raw(i) + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row( + &self, + i: usize, + ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { + self.symbolic.col_indices_of_row(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn row_range(&self, i: usize) -> Range { + self.symbolic.row_range(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub unsafe fn row_range_unchecked(&self, i: usize) -> Range { + self.symbolic.row_range_unchecked(i) + } + + /// Returns a reference to the value at the given index using a binary search, or None if the + /// symbolic structure doesn't contain it + /// + /// # Panics + /// Panics if `row >= self.nrows()` + /// Panics if `col >= self.ncols()` + #[track_caller] + pub fn get(self, row: usize, col: usize) -> Option> { + assert!(row < self.nrows()); + assert!(col < self.ncols()); + + let Ok(pos) = self + .col_indices_of_row_raw(row) + .binary_search(&I::truncate(col)) + else { + return None; + }; + + Some(E::faer_map(self.values_of_row(row), |slice| &slice[pos])) + } +} diff --git a/src/sparse/csr/mod.rs b/src/sparse/csr/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..546f86f09c992fd70f386c1c916295b2e959c420 --- /dev/null +++ b/src/sparse/csr/mod.rs @@ -0,0 +1,15 @@ +use super::*; + +mod symbolic_own; +mod symbolic_ref; + +mod matmut; +mod matown; +mod matref; + +pub use symbolic_own::SymbolicSparseRowMat; +pub use symbolic_ref::SymbolicSparseRowMatRef; + +pub use matmut::SparseRowMatMut; +pub use matown::SparseRowMat; +pub use matref::SparseRowMatRef; diff --git a/src/sparse/csr/symbolic_own.rs b/src/sparse/csr/symbolic_own.rs new file mode 100644 index 0000000000000000000000000000000000000000..3011e66ccfc864a17f6435408130915048fc1403 --- /dev/null +++ b/src/sparse/csr/symbolic_own.rs @@ -0,0 +1,332 @@ +use super::*; +use crate::sparse::csc::*; + +/// Symbolic structure of sparse matrix in row format, either compressed or uncompressed. +/// +/// Requires: +/// * `nrows <= I::Signed::MAX` (always checked) +/// * `ncols <= I::Signed::MAX` (always checked) +/// * `row_ptrs` has length `nrows + 1` (always checked) +/// * `row_ptrs` is non-decreasing +/// * `row_ptrs[0]..row_ptrs[nrows]` is a valid range in row_indices (always checked, assuming +/// non-decreasing) +/// * if `nnz_per_row` is `None`, elements of `col_indices[row_ptrs[i]..row_ptrs[i + 1]]` are less +/// than `ncols` +/// +/// * `nnz_per_row[i] <= row_ptrs[i+1] - row_ptrs[i]` +/// * if `nnz_per_row` is `Some(_)`, elements of `col_indices[row_ptrs[i]..][..nnz_per_row[i]]` are +/// less than `ncols` +/// +/// * Within each row, column indices are unique and sorted in increasing order. +/// +/// # Note +/// Some algorithms allow working with matrices containing duplicate and/or unsorted column +/// indicers per row. +/// +/// Passing such a matrix to an algorithm that does not explicitly permit this is unspecified +/// (though not undefined) behavior. + +#[derive(Clone)] +pub struct SymbolicSparseRowMat { + pub(crate) nrows: usize, + pub(crate) ncols: usize, + pub(crate) row_ptr: alloc::vec::Vec, + pub(crate) row_nnz: Option>, + pub(crate) col_ind: alloc::vec::Vec, +} + +impl SymbolicSparseRowMat { + /// Creates a new symbolic matrix view after asserting its invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_checked( + nrows: usize, + ncols: usize, + row_ptrs: Vec, + nnz_per_row: Option>, + col_indices: Vec, + ) -> Self { + SymbolicSparseRowMatRef::new_checked( + nrows, + ncols, + &row_ptrs, + nnz_per_row.as_deref(), + &col_indices, + ); + + Self { + nrows, + ncols, + row_ptr: row_ptrs, + row_nnz: nnz_per_row, + col_ind: col_indices, + } + } + + /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted column + /// indices per row, after asserting its other invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_unsorted_checked( + nrows: usize, + ncols: usize, + row_ptrs: Vec, + nnz_per_row: Option>, + col_indices: Vec, + ) -> Self { + SymbolicSparseRowMatRef::new_unsorted_checked( + nrows, + ncols, + &row_ptrs, + nnz_per_row.as_deref(), + &col_indices, + ); + + Self { + nrows, + ncols, + row_ptr: row_ptrs, + row_nnz: nnz_per_row, + col_ind: col_indices, + } + } + + /// Creates a new symbolic matrix view without asserting its invariants. + /// + /// # Safety + /// + /// See type level documentation. + #[inline(always)] + #[track_caller] + pub unsafe fn new_unchecked( + nrows: usize, + ncols: usize, + row_ptrs: Vec, + nnz_per_row: Option>, + col_indices: Vec, + ) -> Self { + SymbolicSparseRowMatRef::new_unchecked( + nrows, + ncols, + &row_ptrs, + nnz_per_row.as_deref(), + &col_indices, + ); + + Self { + nrows, + ncols, + row_ptr: row_ptrs, + row_nnz: nnz_per_row, + col_ind: col_indices, + } + } + + /// Returns the components of the matrix in the order: + /// - row count, + /// - column count, + /// - row pointers, + /// - nonzeros per row, + /// - column indices. + #[inline] + pub fn into_parts(self) -> (usize, usize, Vec, Option>, Vec) { + ( + self.nrows, + self.ncols, + self.row_ptr, + self.row_nnz, + self.col_ind, + ) + } + + /// Returns a view over the symbolic structure of `self`. + #[inline] + pub fn as_ref(&self) -> SymbolicSparseRowMatRef<'_, I> { + SymbolicSparseRowMatRef { + nrows: self.nrows, + ncols: self.ncols, + row_ptr: &self.row_ptr, + row_nnz: self.row_nnz.as_deref(), + col_ind: &self.col_ind, + } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.ncols + } + + /// Consumes the matrix, and returns its transpose in column-major format without reallocating. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn into_transpose(self) -> SymbolicSparseColMat { + SymbolicSparseColMat { + nrows: self.ncols, + ncols: self.nrows, + col_ptr: self.row_ptr, + col_nnz: self.row_nnz, + row_ind: self.col_ind, + } + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> { + self.as_ref().to_owned() + } + + /// Copies the current matrix into a newly allocated matrix, with column-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_col_major(&self) -> Result, FaerError> { + self.as_ref().to_col_major() + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.as_ref().compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn row_ptrs(&self) -> &[I] { + &self.row_ptr + } + + /// Returns the count of non-zeros per row of the matrix. + #[inline] + pub fn nnz_per_row(&self) -> Option<&[I]> { + self.row_nnz.as_deref() + } + + /// Returns the column indices. + #[inline] + pub fn col_indices(&self) -> &[I] { + &self.col_ind + } + + /// Returns the column indices of row `i`. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row_raw(&self, i: usize) -> &[I] { + &self.col_ind[self.row_range(i)] + } + + /// Returns the column indices of row `i`. + /// + /// # Panics + /// + /// Panics if `i >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row( + &self, + i: usize, + ) -> impl '_ + ExactSizeIterator + DoubleEndedIterator { + self.as_ref().col_indices_of_row(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn row_range(&self, i: usize) -> Range { + self.as_ref().row_range(i) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub unsafe fn row_range_unchecked(&self, i: usize) -> Range { + self.as_ref().row_range_unchecked(i) + } +} + +impl core::fmt::Debug for SymbolicSparseRowMat { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_ref().fmt(f) + } +} + +impl SymbolicSparseRowMat { + /// Create a new symbolic structure, and the corresponding order for the numerical values + /// from pairs of indices `(row, col)`. + #[inline] + pub fn try_new_from_indices( + nrows: usize, + ncols: usize, + indices: &[(I, I)], + ) -> Result<(Self, ValuesOrder), CreationError> { + SymbolicSparseColMat::try_new_from_indices_impl( + ncols, + nrows, + |i| { + let (row, col) = indices[i]; + (col, row) + }, + indices.len(), + ) + .map(|(m, o)| (m.into_transpose(), o)) + } + + /// Create a new symbolic structure, and the corresponding order for the numerical values + /// from pairs of indices `(row, col)`. + /// + /// Negative indices are ignored. + #[inline] + pub fn try_new_from_nonnegative_indices( + nrows: usize, + ncols: usize, + indices: &[(I::Signed, I::Signed)], + ) -> Result<(Self, ValuesOrder), CreationError> { + SymbolicSparseColMat::try_new_from_nonnegative_indices_impl( + ncols, + nrows, + |i| { + let (row, col) = indices[i]; + (col, row) + }, + indices.len(), + ) + .map(|(m, o)| (m.into_transpose(), o)) + } +} diff --git a/src/sparse/csr/symbolic_ref.rs b/src/sparse/csr/symbolic_ref.rs new file mode 100644 index 0000000000000000000000000000000000000000..71e53a42a7045d41c279feb512e5489f0377aa77 --- /dev/null +++ b/src/sparse/csr/symbolic_ref.rs @@ -0,0 +1,367 @@ +use super::*; +use crate::{assert, sparse::csc::*}; + +/// Symbolic view structure of sparse matrix in row format, either compressed or uncompressed. +/// +/// Requires: +/// * `nrows <= I::Signed::MAX` (always checked) +/// * `ncols <= I::Signed::MAX` (always checked) +/// * `row_ptrs` has length `nrows + 1` (always checked) +/// * `row_ptrs` is non-decreasing +/// * `row_ptrs[0]..row_ptrs[nrows]` is a valid range in row_indices (always checked, assuming +/// non-decreasing) +/// * if `nnz_per_row` is `None`, elements of `col_indices[row_ptrs[i]..row_ptrs[i + 1]]` are less +/// than `ncols` +/// +/// * `nnz_per_row[i] <= row_ptrs[i+1] - row_ptrs[i]` +/// * if `nnz_per_row` is `Some(_)`, elements of `col_indices[row_ptrs[i]..][..nnz_per_row[i]]` are +/// less than `ncols` +/// +/// * Within each row, column indices are unique and sorted in increasing order. +/// +/// # Note +/// Some algorithms allow working with matrices containing duplicate and/or unsorted column +/// indicers per row. +/// +/// Passing such a matrix to an algorithm that does not explicitly permit this is unspecified +/// (though not undefined) behavior. +pub struct SymbolicSparseRowMatRef<'a, I: Index> { + pub(crate) nrows: usize, + pub(crate) ncols: usize, + pub(crate) row_ptr: &'a [I], + pub(crate) row_nnz: Option<&'a [I]>, + pub(crate) col_ind: &'a [I], +} + +impl Copy for SymbolicSparseRowMatRef<'_, I> {} +impl Clone for SymbolicSparseRowMatRef<'_, I> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'short, I: Index> Reborrow<'short> for SymbolicSparseRowMatRef<'_, I> { + type Target = SymbolicSparseRowMatRef<'short, I>; + + #[inline] + fn rb(&self) -> Self::Target { + *self + } +} + +impl<'short, I: Index> ReborrowMut<'short> for SymbolicSparseRowMatRef<'_, I> { + type Target = SymbolicSparseRowMatRef<'short, I>; + + #[inline] + fn rb_mut(&mut self) -> Self::Target { + *self + } +} + +impl<'a, I: Index> IntoConst for SymbolicSparseRowMatRef<'a, I> { + type Target = SymbolicSparseRowMatRef<'a, I>; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'a, I: Index> SymbolicSparseRowMatRef<'a, I> { + /// Creates a new symbolic matrix view after asserting its invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_checked( + nrows: usize, + ncols: usize, + row_ptrs: &'a [I], + nnz_per_row: Option<&'a [I]>, + col_indices: &'a [I], + ) -> Self { + assert!(all( + ncols <= I::Signed::MAX.zx(), + nrows <= I::Signed::MAX.zx(), + )); + assert!(row_ptrs.len() == nrows + 1); + for &[c, c_next] in windows2(row_ptrs) { + assert!(c <= c_next); + } + assert!(row_ptrs[ncols].zx() <= col_indices.len()); + + if let Some(nnz_per_row) = nnz_per_row { + for (&nnz_i, &[c, c_next]) in zip(nnz_per_row, windows2(row_ptrs)) { + assert!(nnz_i <= c_next - c); + let col_indices = &col_indices[c.zx()..c.zx() + nnz_i.zx()]; + if !col_indices.is_empty() { + let mut j_prev = col_indices[0]; + for &j in &col_indices[1..] { + assert!(j_prev < j); + j_prev = j; + } + let ncols = I::truncate(ncols); + assert!(j_prev < ncols); + } + } + } else { + for &[c, c_next] in windows2(row_ptrs) { + let col_indices = &col_indices[c.zx()..c_next.zx()]; + if !col_indices.is_empty() { + let mut j_prev = col_indices[0]; + for &j in &col_indices[1..] { + assert!(j_prev < j); + j_prev = j; + } + let ncols = I::truncate(ncols); + assert!(j_prev < ncols); + } + } + } + + Self { + nrows, + ncols, + row_ptr: row_ptrs, + row_nnz: nnz_per_row, + col_ind: col_indices, + } + } + + /// Creates a new symbolic matrix view from data containing duplicate and/or unsorted column + /// indices per row, after asserting its other invariants. + /// + /// # Panics + /// + /// See type level documentation. + #[inline] + #[track_caller] + pub fn new_unsorted_checked( + nrows: usize, + ncols: usize, + row_ptrs: &'a [I], + nnz_per_row: Option<&'a [I]>, + col_indices: &'a [I], + ) -> Self { + assert!(all( + ncols <= I::Signed::MAX.zx(), + nrows <= I::Signed::MAX.zx(), + )); + assert!(row_ptrs.len() == nrows + 1); + for &[c, c_next] in windows2(row_ptrs) { + assert!(c <= c_next); + } + assert!(row_ptrs[ncols].zx() <= col_indices.len()); + + if let Some(nnz_per_row) = nnz_per_row { + for (&nnz_i, &[c, c_next]) in zip(nnz_per_row, windows2(row_ptrs)) { + assert!(nnz_i <= c_next - c); + for &j in &col_indices[c.zx()..c.zx() + nnz_i.zx()] { + assert!(j < I::truncate(ncols)); + } + } + } else { + let c0 = row_ptrs[0].zx(); + let cn = row_ptrs[ncols].zx(); + for &j in &col_indices[c0..cn] { + assert!(j < I::truncate(ncols)); + } + } + + Self { + nrows, + ncols, + row_ptr: row_ptrs, + row_nnz: nnz_per_row, + col_ind: col_indices, + } + } + + /// Creates a new symbolic matrix view without asserting its invariants. + /// + /// # Safety + /// + /// See type level documentation. + #[inline(always)] + #[track_caller] + pub unsafe fn new_unchecked( + nrows: usize, + ncols: usize, + row_ptrs: &'a [I], + nnz_per_row: Option<&'a [I]>, + col_indices: &'a [I], + ) -> Self { + assert!(all( + ncols <= ::MAX.zx(), + nrows <= ::MAX.zx(), + )); + assert!(row_ptrs.len() == nrows + 1); + assert!(row_ptrs[nrows].zx() <= col_indices.len()); + + Self { + nrows, + ncols, + row_ptr: row_ptrs, + row_nnz: nnz_per_row, + col_ind: col_indices, + } + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> usize { + self.nrows + } + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> usize { + self.ncols + } + + /// Returns a view over the transpose of `self` in column-major format. + #[inline] + pub fn transpose(self) -> SymbolicSparseColMatRef<'a, I> { + SymbolicSparseColMatRef { + nrows: self.ncols, + ncols: self.nrows, + col_ptr: self.row_ptr, + col_nnz: self.row_nnz, + row_ind: self.col_ind, + } + } + + /// Copies the current matrix into a newly allocated matrix. + /// + /// # Note + /// Allows unsorted matrices, producing an unsorted output. + #[inline] + pub fn to_owned(&self) -> Result, FaerError> { + self.transpose() + .to_owned() + .map(SymbolicSparseColMat::into_transpose) + } + + /// Copies the current matrix into a newly allocated matrix, with column-major order. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + #[inline] + pub fn to_col_major(&self) -> Result, FaerError> { + self.transpose().to_row_major().map(|m| m.into_transpose()) + } + + /// Returns the number of symbolic non-zeros in the matrix. + /// + /// The value is guaranteed to be less than `I::Signed::MAX`. + /// + /// # Note + /// Allows unsorted matrices, but the output is a count of all the entries, including the + /// duplicate ones. + #[inline] + pub fn compute_nnz(&self) -> usize { + self.transpose().compute_nnz() + } + + /// Returns the column pointers. + #[inline] + pub fn row_ptrs(&self) -> &'a [I] { + self.row_ptr + } + + /// Returns the count of non-zeros per column of the matrix. + #[inline] + pub fn nnz_per_row(&self) -> Option<&'a [I]> { + self.row_nnz + } + + /// Returns the column indices. + #[inline] + pub fn col_indices(&self) -> &'a [I] { + self.col_ind + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row_raw(&self, i: usize) -> &'a [I] { + &self.col_ind[self.row_range(i)] + } + + /// Returns the column indices of row i. + /// + /// # Panics + /// + /// Panics if `i >= self.ncols()`. + #[inline] + #[track_caller] + pub fn col_indices_of_row( + &self, + i: usize, + ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator { + self.col_indices_of_row_raw(i).iter().map( + #[inline(always)] + |&i| i.zx(), + ) + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Panics + /// + /// Panics if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub fn row_range(&self, i: usize) -> Range { + let start = self.row_ptr[i].zx(); + let end = self + .row_nnz + .map(|row_nnz| row_nnz[i].zx() + start) + .unwrap_or(self.row_ptr[i + 1].zx()); + + start..end + } + + /// Returns the range that the row `i` occupies in `self.col_indices()`. + /// + /// # Safety + /// + /// The behavior is undefined if `i >= self.nrows()`. + #[inline] + #[track_caller] + pub unsafe fn row_range_unchecked(&self, i: usize) -> Range { + let start = __get_unchecked(self.row_ptr, i).zx(); + let end = self + .row_nnz + .map(|row_nnz| (__get_unchecked(row_nnz, i).zx() + start)) + .unwrap_or(__get_unchecked(self.row_ptr, i + 1).zx()); + + start..end + } +} + +impl core::fmt::Debug for SymbolicSparseRowMatRef<'_, I> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + let mat = *self; + let mut iter = (0..mat.nrows()).flat_map(move |i| { + struct Wrapper(usize, usize); + impl core::fmt::Debug for Wrapper { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let row = self.0; + let col = self.1; + write!(f, "({row}, {col}") + } + } + + mat.col_indices_of_row(i).map(move |j| Wrapper(i, j)) + }); + + f.debug_list().entries(&mut iter).finish() + } +} diff --git a/faer-libs/faer-sparse/src/amd.rs b/src/sparse/linalg/amd.rs similarity index 96% rename from faer-libs/faer-sparse/src/amd.rs rename to src/sparse/linalg/amd.rs index c84ed16a68042ec824ca88a7a87b1095979f94d0..40f471ca127b0f3e03b643e2762cbe3a1e03df78 100644 --- a/faer-libs/faer-sparse/src/amd.rs +++ b/src/sparse/linalg/amd.rs @@ -34,14 +34,14 @@ // OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH // DAMAGE. -use crate::{ +use super::{ ghost::{self, Array, Idx, MaybeIdx}, mem::{self, NONE}, windows2, FaerError, Index, SignedIndex, SymbolicSparseColMatRef, }; +use crate::{assert, ComplexField}; use core::{cell::Cell, iter::zip}; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{assert, ComplexField}; use reborrow::*; #[inline] @@ -817,13 +817,10 @@ fn amd_1( s_p[k] += one; seen += one; - } else if j == k { - // Skip the diagonal. - seen += one; - break; } else { - // j > k - // First entry below the diagonal. + if j == k { + seen += one; + } break; } @@ -842,13 +839,10 @@ fn amd_1( s_p[j] += one; seen_j += one; - } else if i == k { - // Entry A(k,j) in lower part and A(j,k) in upper. - seen_j += one; - break; } else { - // i > k - // Consider this entry later, when k advances to i. + if i == k { + seen_j += one; + } break; } } @@ -972,10 +966,10 @@ fn aat( seen += one; len[j] += one; len[k] += one; - } else if j == k { - seen += one; - break; } else { + if j == k { + seen += one; + } break; } @@ -986,10 +980,10 @@ fn aat( len[i] += one; len[j] += one; seen_j += one; - } else if i == k { - seen_j += one; - break; } else { + if i == k { + seen_j += one; + } break; } } @@ -1009,6 +1003,8 @@ fn aat( nzaat.ok_or(FaerError::IndexOverflow).map(I::zx) } +/// Computes the size and alignment of required workspace for computing the AMD ordering of a sorted +/// matrix. pub fn order_sorted_req(n: usize, nnz_upper: usize) -> Result { let n_req = StackReq::try_new::(n)?; let nzaat = nnz_upper.checked_mul(2).ok_or(SizeOverflow)?; @@ -1039,6 +1035,8 @@ pub fn order_sorted_req(n: usize, nnz_upper: usize) -> Result( n: usize, nnz_upper: usize, @@ -1050,7 +1048,9 @@ pub fn order_maybe_unsorted_req( ]) } -pub fn order_sorted( +/// Computes the approximate minimum degree ordering for reducing the fill-in during the sparse +/// Cholesky factorization of a matrix with the sparsity pattern of `A + A.T`. +pub fn order( perm: &mut [I], perm_inv: &mut [I], A: SymbolicSparseColMatRef<'_, I>, @@ -1084,6 +1084,9 @@ pub fn order_sorted( )) } +/// Computes the approximate minimum degree ordering for reducing the fill-in during the sparse +/// Cholesky factorization of a matrix with the sparsity pattern of `A + A.T`. +/// /// # Note /// Allows unsorted matrices. pub fn order_maybe_unsorted( @@ -1106,9 +1109,10 @@ pub fn order_maybe_unsorted( let (new_col_ptrs, stack) = stack.make_raw::(n + 1); let (new_row_indices, mut stack) = stack.make_raw::(nnz); let A = preprocess(new_col_ptrs, new_row_indices, A, stack.rb_mut()); - order_sorted(perm, perm_inv, A, control, stack) + order(perm, perm_inv, A, control, stack) } +/// Tuning parameters for the AMD implementation. #[derive(Debug, Copy, Clone, PartialEq)] pub struct Control { /// "dense" if degree > dense * sqrt(n) @@ -1127,9 +1131,13 @@ impl Default for Control { } } +/// Flop count of the LDLT and LU factorizations if the provided ordering is used. #[derive(Default, Debug, Copy, Clone, PartialEq)] pub struct FlopCount { + /// Number of division. pub n_div: f64, + /// Number of multiplications and subtractions for the LDLT factorization. pub n_mult_subs_ldl: f64, + /// Number of multiplications and subtractions for the LU factorization. pub n_mult_subs_lu: f64, } diff --git a/faer-libs/faer-sparse/src/cholesky.rs b/src/sparse/linalg/cholesky.rs similarity index 95% rename from faer-libs/faer-sparse/src/cholesky.rs rename to src/sparse/linalg/cholesky.rs index 85485df314cfc646d1ac1787df98c9fa894762dc..3ddd391d68b6d7cab64fee61c6ffe7185281c292 100644 --- a/faer-libs/faer-sparse/src/cholesky.rs +++ b/src/sparse/linalg/cholesky.rs @@ -1,40 +1,42 @@ //! Computes the Cholesky decomposition (either LLT, LDLT, or Bunch-Kaufman) of a given sparse -//! matrix. See [`faer_cholesky`] for more info. +//! matrix. See [`crate::linalg::cholesky`] for more info. //! //! The entry point in this module is [`SymbolicCholesky`] and [`factorize_symbolic_cholesky`]. //! //! # Note //! The functions in this module accept unsorted input, producing a sorted decomposition factor //! (simplicial). +#![allow(missing_docs)] // implementation inspired by https://gitlab.com/hodge_star/catamari -use crate::{ +use super::{ amd::{self, Control}, ghost::{self, Array, Idx, MaybeIdx}, ghost_permute_hermitian_unsorted, ghost_permute_hermitian_unsorted_symbolic, make_raw_req, mem, mem::NONE, - nomem, triangular_solve, try_collect, try_zeroed, windows2, FaerError, Index, PermutationRef, - Side, SliceGroup, SliceGroupMut, SparseColMatRef, SupernodalThreshold, SymbolicSparseColMatRef, + nomem, triangular_solve, try_collect, try_zeroed, windows2, FaerError, Index, PermRef, Side, + SliceGroup, SliceGroupMut, SparseColMatRef, SupernodalThreshold, SymbolicSparseColMatRef, SymbolicSupernodalParams, }; -use core::{cell::Cell, iter::zip}; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -pub use faer_cholesky::{ +pub use crate::linalg::cholesky::{ bunch_kaufman::compute::BunchKaufmanRegularization, ldlt_diagonal::compute::LdltRegularization, llt::{compute::LltRegularization, CholeskyError}, }; -use faer_core::{ - assert, permutation::SignedIndex, temp_mat_req, temp_mat_uninit, unzipped, zipped, - ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, +use crate::{ + assert, + linalg::{temp_mat_req, temp_mat_uninit}, + unzipped, zipped, ComplexField, Conj, Entity, MatMut, MatRef, Parallelism, SignedIndex, }; -use faer_entity::{GroupFor, Symbolic}; +use core::{cell::Cell, iter::zip}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; +use faer_entity::GroupFor; use reborrow::*; #[derive(Copy, Clone)] #[allow(dead_code)] -enum Ordering<'a, I> { +enum Ordering<'a, I: Index> { Identity, Custom(&'a [I]), Algorithm( @@ -49,7 +51,7 @@ enum Ordering<'a, I> { pub mod simplicial { use super::*; - use faer_core::assert; + use crate::assert; /// Computes the elimination tree and column counts of the Cholesky factorization of the matrix /// `A`. @@ -269,7 +271,7 @@ pub mod simplicial { let has_delta = delta > E::Real::faer_zero(); let mut dynamic_regularization_count = 0usize; - let (x, stack) = crate::make_raw::(n, stack); + let (x, stack) = crate::sparse::linalg::make_raw::(n, stack); let (current_row_index, stack) = stack.make_raw::(n); let (ereach_stack, stack) = stack.make_raw::(n); let (visited, _) = stack.make_raw::(n); @@ -441,7 +443,7 @@ pub mod simplicial { let has_delta = delta > E::Real::faer_zero(); let mut dynamic_regularization_count = 0usize; - let (x, stack) = crate::make_raw::(n, stack); + let (x, stack) = crate::sparse::linalg::make_raw::(n, stack); let (current_row_index, stack) = stack.make_raw::(n); let (ereach_stack, stack) = stack.make_raw::(n); let (visited, _) = stack.make_raw::(n); @@ -842,7 +844,7 @@ pub mod simplicial { } #[derive(Debug)] - pub struct SymbolicSimplicialCholesky { + pub struct SymbolicSimplicialCholesky { dimension: usize, col_ptrs: alloc::vec::Vec, row_indices: alloc::vec::Vec, @@ -871,7 +873,7 @@ pub mod simplicial { pub mod supernodal { use super::*; - use faer_core::{assert, debug_assert}; + use crate::{assert, debug_assert}; fn ereach_super<'n, 'nsuper, I: Index>( A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, @@ -905,7 +907,7 @@ pub mod supernodal { fn ereach_super_ata<'m, 'n, 'nsuper, I: Index>( A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - perm: Option>, + perm: Option>, min_col: &Array<'m, MaybeIdx<'n, I>>, super_etree: &Array<'nsuper, MaybeIdx<'nsuper, I>>, index_to_super: &Array<'n, Idx<'nsuper, I>>, @@ -917,7 +919,7 @@ pub mod supernodal { let k_: I = *k.truncate(); visited[index_to_super[k].zx()] = k_.to_signed(); - let fwd = perm.map(|perm| perm.into_arrays().0); + let fwd = perm.map(|perm| perm.arrays().0); let fwd = |i: Idx<'n, usize>| fwd.map(|fwd| fwd[k].zx()).unwrap_or(i); for i in A.row_indices_of_col(fwd(k)) { let Some(i) = min_col[i].idx() else { continue }; @@ -942,7 +944,7 @@ pub mod supernodal { } #[derive(Debug)] - pub struct SymbolicSupernodeRef<'a, I> { + pub struct SymbolicSupernodeRef<'a, I: Index> { start: usize, pattern: &'a [I], } @@ -974,7 +976,7 @@ pub mod supernodal { } #[derive(Debug)] - pub struct SupernodeRef<'a, I, E: Entity> { + pub struct SupernodeRef<'a, I: Index, E: Entity> { matrix: MatRef<'a, E>, symbolic: SymbolicSupernodeRef<'a, I>, } @@ -985,7 +987,7 @@ pub mod supernodal { symbolic: &'a SymbolicSupernodalCholesky, values: GroupFor, subdiag: GroupFor, - perm: PermutationRef<'a, I, E>, + perm: PermRef<'a, I>, ) -> Self { let values = SliceGroup::<'_, E>::new(values); let subdiag = SliceGroup::<'_, E>::new(subdiag); @@ -1020,7 +1022,7 @@ pub mod supernodal { let s_ncols = s_end - s_start; let s_nrows = s_pattern.len() + s_ncols; - let Ls = faer_core::mat::from_column_major_slice::<'_, E>( + let Ls = crate::mat::from_column_major_slice::<'_, E>( L_values .subslice( symbolic.col_ptrs_for_values()[s].zx() @@ -1063,7 +1065,7 @@ pub mod supernodal { let Ls = s.matrix; let (Ls_top, Ls_bot) = Ls.split_at_row(size); let mut x_top = x.rb_mut().subrows_mut(s.start(), size); - faer_core::solve::solve_unit_lower_triangular_in_place_with_conj( + crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj( Ls_top, conj, x_top.rb_mut(), @@ -1071,7 +1073,7 @@ pub mod supernodal { ); let (mut tmp, _) = temp_mat_uninit::(s.pattern().len(), k, stack.rb_mut()); - faer_core::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( tmp.rb_mut(), Ls_bot, conj, @@ -1082,7 +1084,7 @@ pub mod supernodal { parallelism, ); - let inv = self.perm.into_arrays().1; + let inv = self.perm.arrays().1; for j in 0..k { for (idx, i) in s.pattern().iter().enumerate() { let i = i.zx(); @@ -1154,7 +1156,7 @@ pub mod supernodal { let (Ls_top, Ls_bot) = Ls.split_at_row(size); let (mut tmp, _) = temp_mat_uninit::(s.pattern().len(), k, stack.rb_mut()); - let inv = self.perm.into_arrays().1; + let inv = self.perm.arrays().1; for j in 0..k { for (idx, i) in s.pattern().iter().enumerate() { let i = i.zx(); @@ -1164,7 +1166,7 @@ pub mod supernodal { } let mut x_top = x.rb_mut().subrows_mut(s.start(), size); - faer_core::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( x_top.rb_mut(), Ls_bot.transpose(), conj.compose(Conj::Yes), @@ -1174,7 +1176,7 @@ pub mod supernodal { E::faer_one().faer_neg(), parallelism, ); - faer_core::solve::solve_unit_upper_triangular_in_place_with_conj( + crate::linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj( Ls_top.transpose(), conj.compose(Conj::Yes), x_top.rb_mut(), @@ -1217,7 +1219,7 @@ pub mod supernodal { let s_ncols = s_end - s_start; let s_nrows = s_pattern.len() + s_ncols; - let Ls = faer_core::mat::from_column_major_slice::<'_, E>( + let Ls = crate::mat::from_column_major_slice::<'_, E>( L_values .subslice( symbolic.col_ptrs_for_values()[s].zx() @@ -1259,7 +1261,7 @@ pub mod supernodal { let Ls = s.matrix; let (Ls_top, Ls_bot) = Ls.split_at_row(size); let mut x_top = x.rb_mut().subrows_mut(s.start(), size); - faer_core::solve::solve_unit_lower_triangular_in_place_with_conj( + crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place_with_conj( Ls_top, conj, x_top.rb_mut(), @@ -1267,7 +1269,7 @@ pub mod supernodal { ); let (mut tmp, _) = temp_mat_uninit::(s.pattern().len(), k, stack.rb_mut()); - faer_core::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( tmp.rb_mut(), Ls_bot, conj, @@ -1312,7 +1314,7 @@ pub mod supernodal { } let mut x_top = x.rb_mut().subrows_mut(s.start(), size); - faer_core::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( x_top.rb_mut(), Ls_bot.transpose(), conj.compose(Conj::Yes), @@ -1322,7 +1324,7 @@ pub mod supernodal { E::faer_one().faer_neg(), parallelism, ); - faer_core::solve::solve_unit_upper_triangular_in_place_with_conj( + crate::linalg::triangular_solve::solve_unit_upper_triangular_in_place_with_conj( Ls_top.transpose(), conj.compose(Conj::Yes), x_top.rb_mut(), @@ -1365,7 +1367,7 @@ pub mod supernodal { let s_ncols = s_end - s_start; let s_nrows = s_pattern.len() + s_ncols; - let Ls = faer_core::mat::from_column_major_slice::<'_, E>( + let Ls = crate::mat::from_column_major_slice::<'_, E>( L_values .subslice( symbolic.col_ptrs_for_values()[s].zx() @@ -1407,7 +1409,7 @@ pub mod supernodal { let Ls = s.matrix; let (Ls_top, Ls_bot) = Ls.split_at_row(size); let mut x_top = x.rb_mut().subrows_mut(s.start(), size); - faer_core::solve::solve_lower_triangular_in_place_with_conj( + crate::linalg::triangular_solve::solve_lower_triangular_in_place_with_conj( Ls_top, conj, x_top.rb_mut(), @@ -1415,7 +1417,7 @@ pub mod supernodal { ); let (mut tmp, _) = temp_mat_uninit::(s.pattern().len(), k, stack.rb_mut()); - faer_core::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( tmp.rb_mut(), Ls_bot, conj, @@ -1448,7 +1450,7 @@ pub mod supernodal { } let mut x_top = x.rb_mut().subrows_mut(s.start(), size); - faer_core::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( x_top.rb_mut(), Ls_bot.transpose(), conj.compose(Conj::Yes), @@ -1458,7 +1460,7 @@ pub mod supernodal { E::faer_one().faer_neg(), parallelism, ); - faer_core::solve::solve_upper_triangular_in_place_with_conj( + crate::linalg::triangular_solve::solve_upper_triangular_in_place_with_conj( Ls_top.transpose(), conj.compose(Conj::Yes), x_top.rb_mut(), @@ -1575,7 +1577,7 @@ pub mod supernodal { pub(crate) fn ghost_factorize_supernodal_symbolic<'m, 'n, I: Index>( A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - col_perm: Option>, + col_perm: Option>, min_col: Option<&Array<'m, MaybeIdx<'n, I>>>, input: CholeskyInput, etree: &Array<'n, MaybeIdx<'n, I>>, @@ -2156,7 +2158,7 @@ pub mod supernodal { req = req.try_or(d_req)?; } req = req.try_or( - faer_cholesky::ldlt_diagonal::compute::raw_cholesky_in_place_req::( + crate::linalg::cholesky::ldlt_diagonal::compute::raw_cholesky_in_place_req::( s_ncols, parallelism, Default::default(), @@ -2212,7 +2214,7 @@ pub mod supernodal { req = req.try_or(d_req)?; } req = req.try_or( - faer_cholesky::ldlt_diagonal::compute::raw_cholesky_in_place_req::( + crate::linalg::cholesky::ldlt_diagonal::compute::raw_cholesky_in_place_req::( s_ncols, parallelism, Default::default(), @@ -2270,15 +2272,12 @@ pub mod supernodal { } req = StackReq::try_any_of([ req, - faer_cholesky::bunch_kaufman::compute::cholesky_in_place_req::( + crate::linalg::cholesky::bunch_kaufman::compute::cholesky_in_place_req::( s_ncols, parallelism, Default::default(), )?, - faer_core::permutation::permute_cols_in_place_req::( - s_pattern.len(), - s_ncols, - )?, + crate::perm::permute_cols_in_place_req::(s_pattern.len(), s_ncols)?, ])?; } req.try_and(StackReq::try_new::(n)?) @@ -2316,7 +2315,7 @@ pub mod supernodal { // mapping from global indices to local let (global_to_local, mut stack) = stack.make_raw::(n); - mem::fill_none(global_to_local.as_mut()); + mem::fill_none(global_to_local); for s in 0..n_supernodes { let s_start = symbolic.supernode_begin[s].zx(); @@ -2332,7 +2331,7 @@ pub mod supernodal { let (head, tail) = L_values.rb_mut().split_at(col_ptr_val[s].zx()); let head = head.rb(); - let mut Ls = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut Ls = crate::mat::from_column_major_slice_mut::<'_, E>( tail.subslice(0..(col_ptr_val[s + 1] - col_ptr_val[s]).zx()) .into_inner(), s_nrows, @@ -2366,7 +2365,7 @@ pub mod supernodal { let d_ncols = d_end - d_start; let d_nrows = d_pattern.len() + d_ncols; - let Ld = faer_core::mat::from_column_major_slice::<'_, E>( + let Ld = crate::mat::from_column_major_slice::<'_, E>( head.subslice(col_ptr_val[d].zx()..col_ptr_val[d + 1].zx()) .into_inner(), d_nrows, @@ -2388,7 +2387,7 @@ pub mod supernodal { let (mut tmp_top, mut tmp_bot) = tmp.split_at_row_mut(d_pattern_mid_len); - use faer_core::{mul, mul::triangular}; + use crate::linalg::{matmul, matmul::triangular}; triangular::matmul( tmp_top.rb_mut(), triangular::BlockStructure::TriangularLower, @@ -2400,7 +2399,7 @@ pub mod supernodal { E::faer_one(), parallelism, ); - mul::matmul( + matmul::matmul( tmp_bot.rb_mut(), Ld_bot, Ld_mid.rb().adjoint(), @@ -2454,22 +2453,23 @@ pub mod supernodal { let (mut Ls_top, mut Ls_bot) = Ls.rb_mut().split_at_row_mut(s_ncols); let params = Default::default(); - dynamic_regularization_count += match faer_cholesky::llt::compute::cholesky_in_place( - Ls_top.rb_mut(), - regularization, - parallelism, - stack.rb_mut(), - params, - ) { - Ok(count) => count, - Err(err) => { - return Err(CholeskyError { - non_positive_definite_minor: err.non_positive_definite_minor + s_start, - }) + dynamic_regularization_count += + match crate::linalg::cholesky::llt::compute::cholesky_in_place( + Ls_top.rb_mut(), + regularization, + parallelism, + stack.rb_mut(), + params, + ) { + Ok(count) => count, + Err(err) => { + return Err(CholeskyError { + non_positive_definite_minor: err.non_positive_definite_minor + s_start, + }) + } } - } - .dynamic_regularization_count; - faer_core::solve::solve_lower_triangular_in_place( + .dynamic_regularization_count; + crate::linalg::triangular_solve::solve_lower_triangular_in_place( Ls_top.rb().conjugate(), Ls_bot.rb_mut().transpose_mut(), parallelism, @@ -2514,7 +2514,7 @@ pub mod supernodal { // mapping from global indices to local let (global_to_local, mut stack) = stack.make_raw::(n); - mem::fill_none(global_to_local.as_mut()); + mem::fill_none(global_to_local); for s in 0..n_supernodes { let s_start = symbolic.supernode_begin[s].zx(); @@ -2530,7 +2530,7 @@ pub mod supernodal { let (head, tail) = L_values.rb_mut().split_at(col_ptr_val[s].zx()); let head = head.rb(); - let mut Ls = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut Ls = crate::mat::from_column_major_slice_mut::<'_, E>( tail.subslice(0..(col_ptr_val[s + 1] - col_ptr_val[s]).zx()) .into_inner(), s_nrows, @@ -2564,7 +2564,7 @@ pub mod supernodal { let d_ncols = d_end - d_start; let d_nrows = d_pattern.len() + d_ncols; - let Ld = faer_core::mat::from_column_major_slice::<'_, E>( + let Ld = crate::mat::from_column_major_slice::<'_, E>( head.subslice(col_ptr_val[d].zx()..col_ptr_val[d + 1].zx()) .into_inner(), d_nrows, @@ -2602,7 +2602,7 @@ pub mod supernodal { let (mut tmp_top, mut tmp_bot) = tmp.split_at_row_mut(d_pattern_mid_len); - use faer_core::{mul, mul::triangular}; + use crate::linalg::{matmul, matmul::triangular}; triangular::matmul( tmp_top.rb_mut(), triangular::BlockStructure::TriangularLower, @@ -2614,7 +2614,7 @@ pub mod supernodal { E::faer_one(), parallelism, ); - mul::matmul( + matmul::matmul( tmp_bot.rb_mut(), Ld_bot, Ld_mid_x_D.rb().adjoint(), @@ -2669,7 +2669,7 @@ pub mod supernodal { let params = Default::default(); dynamic_regularization_count += - faer_cholesky::ldlt_diagonal::compute::raw_cholesky_in_place( + crate::linalg::cholesky::ldlt_diagonal::compute::raw_cholesky_in_place( Ls_top.rb_mut(), LdltRegularization { dynamic_regularization_signs: regularization @@ -2683,10 +2683,10 @@ pub mod supernodal { ) .dynamic_regularization_count; zipped!(Ls_top.rb_mut()) - .for_each_triangular_upper(faer_core::zip::Diag::Skip, |unzipped!(mut x)| { + .for_each_triangular_upper(crate::linalg::zip::Diag::Skip, |unzipped!(mut x)| { x.write(E::faer_zero()) }); - faer_core::solve::solve_unit_lower_triangular_in_place( + crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place( Ls_top.rb().conjugate(), Ls_bot.rb_mut().transpose_mut(), parallelism, @@ -2745,7 +2745,7 @@ pub mod supernodal { // mapping from global indices to local let (global_to_local, mut stack) = stack.make_raw::(n); - mem::fill_none(global_to_local.as_mut()); + mem::fill_none(global_to_local); for s in 0..n_supernodes { let s_start = symbolic.supernode_begin[s].zx(); @@ -2761,7 +2761,7 @@ pub mod supernodal { let (head, tail) = L_values.rb_mut().split_at(col_ptr_val[s].zx()); let head = head.rb(); - let mut Ls = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut Ls = crate::mat::from_column_major_slice_mut::<'_, E>( tail.subslice(0..(col_ptr_val[s + 1] - col_ptr_val[s]).zx()) .into_inner(), s_nrows, @@ -2795,7 +2795,7 @@ pub mod supernodal { let d_ncols = d_end - d_start; let d_nrows = d_pattern.len() + d_ncols; - let Ld = faer_core::mat::from_column_major_slice::<'_, E>( + let Ld = crate::mat::from_column_major_slice::<'_, E>( head.subslice(col_ptr_val[d].zx()..col_ptr_val[d + 1].zx()) .into_inner(), d_nrows, @@ -2873,7 +2873,7 @@ pub mod supernodal { let (mut tmp_top, mut tmp_bot) = tmp.split_at_row_mut(d_pattern_mid_len); - use faer_core::{mul, mul::triangular}; + use crate::linalg::{matmul, matmul::triangular}; triangular::matmul( tmp_top.rb_mut(), triangular::BlockStructure::TriangularLower, @@ -2885,7 +2885,7 @@ pub mod supernodal { E::faer_one(), parallelism, ); - mul::matmul( + matmul::matmul( tmp_bot.rb_mut(), Ld_bot, Ld_mid_x_D.rb().adjoint(), @@ -2940,9 +2940,9 @@ pub mod supernodal { let mut s_subdiag = subdiag.rb_mut().subslice(s_start..s_end); let params = Default::default(); - let (info, perm) = faer_cholesky::bunch_kaufman::compute::cholesky_in_place( + let (info, perm) = crate::linalg::cholesky::bunch_kaufman::compute::cholesky_in_place( Ls_top.rb_mut(), - faer_core::mat::from_column_major_slice_mut::<'_, E>( + crate::mat::from_column_major_slice_mut::<'_, E>( s_subdiag.rb_mut().into_inner(), s_ncols, 1, @@ -2962,15 +2962,11 @@ pub mod supernodal { ); dynamic_regularization_count += info.dynamic_regularization_count; zipped!(Ls_top.rb_mut()) - .for_each_triangular_upper(faer_core::zip::Diag::Skip, |unzipped!(mut x)| { + .for_each_triangular_upper(crate::linalg::zip::Diag::Skip, |unzipped!(mut x)| { x.write(E::faer_zero()) }); - faer_core::permutation::permute_cols_in_place( - Ls_bot.rb_mut(), - perm.rb(), - stack.rb_mut(), - ); + crate::perm::permute_cols_in_place(Ls_bot.rb_mut(), perm.rb(), stack.rb_mut()); for p in &mut perm_forward[s_start..s_end] { *p += I::truncate(s_start); @@ -2979,7 +2975,7 @@ pub mod supernodal { *p += I::truncate(s_start); } - faer_core::solve::solve_unit_lower_triangular_in_place( + crate::linalg::triangular_solve::solve_unit_lower_triangular_in_place( Ls_top.rb().conjugate(), Ls_bot.rb_mut().transpose_mut(), parallelism, @@ -3022,27 +3018,27 @@ pub mod supernodal { } #[derive(Debug)] - pub struct SupernodalLltRef<'a, I, E: Entity> { + pub struct SupernodalLltRef<'a, I: Index, E: Entity> { symbolic: &'a SymbolicSupernodalCholesky, values: SliceGroup<'a, E>, } #[derive(Debug)] - pub struct SupernodalLdltRef<'a, I, E: Entity> { + pub struct SupernodalLdltRef<'a, I: Index, E: Entity> { symbolic: &'a SymbolicSupernodalCholesky, values: SliceGroup<'a, E>, } #[derive(Debug)] - pub struct SupernodalIntranodeBunchKaufmanRef<'a, I, E: Entity> { + pub struct SupernodalIntranodeBunchKaufmanRef<'a, I: Index, E: Entity> { symbolic: &'a SymbolicSupernodalCholesky, values: SliceGroup<'a, E>, subdiag: SliceGroup<'a, E>, - pub(super) perm: PermutationRef<'a, I, E>, + pub(super) perm: PermRef<'a, I>, } #[derive(Debug)] - pub struct SymbolicSupernodalCholesky { + pub struct SymbolicSupernodalCholesky { pub(crate) dimension: usize, pub(crate) supernode_postorder: alloc::vec::Vec, pub(crate) supernode_postorder_inv: alloc::vec::Vec, @@ -3166,14 +3162,14 @@ impl ComputationModel { /// The inner factorization used for the symbolic Cholesky, either simplicial or symbolic. #[derive(Debug)] -pub enum SymbolicCholeskyRaw { +pub enum SymbolicCholeskyRaw { Simplicial(simplicial::SymbolicSimplicialCholesky), Supernodal(supernodal::SymbolicSupernodalCholesky), } /// The symbolic structure of a sparse Cholesky decomposition. #[derive(Debug)] -pub struct SymbolicCholesky { +pub struct SymbolicCholesky { raw: SymbolicCholeskyRaw, perm_fwd: alloc::vec::Vec, perm_inv: alloc::vec::Vec, @@ -3204,8 +3200,8 @@ impl SymbolicCholesky { /// Returns the permutation that was computed during symbolic analysis. #[inline] - pub fn perm(&self) -> PermutationRef<'_, I, Symbolic> { - unsafe { PermutationRef::new_unchecked(&self.perm_fwd, &self.perm_inv) } + pub fn perm(&self) -> PermRef<'_, I> { + unsafe { PermRef::new_unchecked(&self.perm_fwd, &self.perm_inv) } } /// Returns the length of the slice needed to store the numerical values of the Cholesky @@ -3351,9 +3347,9 @@ impl SymbolicCholesky { let A_nnz = self.A_nnz; let A = ghost::SparseColMatRef::new(A, N, N); - let perm = ghost::PermutationRef::new(self.perm(), N); + let perm = ghost::PermRef::new(self.perm(), N); - let (mut new_values, stack) = crate::make_raw::(A_nnz, stack); + let (mut new_values, stack) = crate::sparse::linalg::make_raw::(A_nnz, stack); let (new_col_ptr, stack) = stack.make_raw::(n + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); @@ -3368,7 +3364,7 @@ impl SymbolicCholesky { new_col_ptr, new_row_ind, A, - perm.cast(), + perm, side, out_side, false, @@ -3431,8 +3427,8 @@ impl SymbolicCholesky { 0 }); - let perm = ghost::PermutationRef::new(self.perm(), N); - let fwd = perm.into_arrays().0; + let perm = ghost::PermRef::new(self.perm(), N); + let fwd = perm.arrays().0; let signs = regularization.dynamic_regularization_signs.map(|signs| { { let new_signs = Array::from_mut(new_signs, N); @@ -3448,7 +3444,7 @@ impl SymbolicCholesky { ..regularization }; - let (mut new_values, stack) = crate::make_raw::(A_nnz, stack); + let (mut new_values, stack) = crate::sparse::linalg::make_raw::(A_nnz, stack); let (new_col_ptr, stack) = stack.make_raw::(n + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); @@ -3463,7 +3459,7 @@ impl SymbolicCholesky { new_col_ptr, new_row_ind, A, - perm.cast(), + perm, side, out_side, false, @@ -3527,10 +3523,10 @@ impl SymbolicCholesky { 0 }); - let static_perm = ghost::PermutationRef::new(self.perm(), N); + let static_perm = ghost::PermRef::new(self.perm(), N); let signs = regularization.dynamic_regularization_signs.map(|signs| { { - let fwd = static_perm.into_arrays().0; + let fwd = static_perm.arrays().0; let new_signs = Array::from_mut(new_signs, N); let signs = Array::from_ref(signs, N); for i in N.indices() { @@ -3540,7 +3536,7 @@ impl SymbolicCholesky { &mut *new_signs }); - let (mut new_values, stack) = crate::make_raw::(A_nnz, stack); + let (mut new_values, stack) = crate::sparse::linalg::make_raw::(A_nnz, stack); let (new_col_ptr, stack) = stack.make_raw::(n + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); @@ -3555,7 +3551,7 @@ impl SymbolicCholesky { new_col_ptr, new_row_ind, A, - static_perm.cast(), + static_perm, side, out_side, false, @@ -3611,7 +3607,7 @@ impl SymbolicCholesky { self, E::faer_into_const(L_values), E::faer_into_const(subdiag), - unsafe { PermutationRef::<'out, I, E>::new_unchecked(perm_forward, perm_inverse) }, + unsafe { PermRef::<'out, I>::new_unchecked(perm_forward, perm_inverse) }, ) }) } @@ -3649,40 +3645,40 @@ pub struct IntranodeBunchKaufmanRef<'a, I: Index, E: Entity> { symbolic: &'a SymbolicCholesky, values: SliceGroup<'a, E>, subdiag: SliceGroup<'a, E>, - perm: PermutationRef<'a, I, E>, + perm: PermRef<'a, I>, } impl<'a, I: Index, E: Entity> core::ops::Deref for LltRef<'a, I, E> { type Target = SymbolicCholesky; #[inline] fn deref(&self) -> &Self::Target { - &self.symbolic + self.symbolic } } impl<'a, I: Index, E: Entity> core::ops::Deref for LdltRef<'a, I, E> { type Target = SymbolicCholesky; #[inline] fn deref(&self) -> &Self::Target { - &self.symbolic + self.symbolic } } impl<'a, I: Index, E: Entity> core::ops::Deref for IntranodeBunchKaufmanRef<'a, I, E> { type Target = SymbolicCholesky; #[inline] fn deref(&self) -> &Self::Target { - &self.symbolic + self.symbolic } } -impl_copy!(<'a>>); -impl_copy!(<'a>>); +impl_copy!(<'a>>); +impl_copy!(<'a>>); -impl_copy!(<'a>>); -impl_copy!(<'a>>); +impl_copy!(<'a>>); +impl_copy!(<'a>>); -impl_copy!(<'a>>); -impl_copy!(<'a>>); -impl_copy!(<'a>>); +impl_copy!(<'a>>); +impl_copy!(<'a>>); +impl_copy!(<'a>>); impl_copy!(<'a>>); impl_copy!(<'a>>); @@ -3694,7 +3690,7 @@ impl<'a, I: Index, E: Entity> IntranodeBunchKaufmanRef<'a, I, E> { symbolic: &'a SymbolicCholesky, values: GroupFor, subdiag: GroupFor, - perm: PermutationRef<'a, I, E>, + perm: PermRef<'a, I>, ) -> Self { let values = SliceGroup::<'_, E>::new(values); let subdiag = SliceGroup::<'_, E>::new(subdiag); @@ -3727,7 +3723,7 @@ impl<'a, I: Index, E: Entity> IntranodeBunchKaufmanRef<'a, I, E> { let mut rhs = rhs; let (mut x, stack) = temp_mat_uninit::(n, k, stack); - let (fwd, inv) = self.symbolic.perm().into_arrays(); + let (fwd, inv) = self.symbolic.perm().arrays(); match self.symbolic.raw() { SymbolicCholeskyRaw::Simplicial(symbolic) => { @@ -3746,7 +3742,7 @@ impl<'a, I: Index, E: Entity> IntranodeBunchKaufmanRef<'a, I, E> { } } SymbolicCholeskyRaw::Supernodal(symbolic) => { - let (dyn_fwd, dyn_inv) = self.perm.into_arrays(); + let (dyn_fwd, dyn_inv) = self.perm.arrays(); for j in 0..k { for (i, dyn_fwd) in dyn_fwd.iter().enumerate() { x.write(i, j, rhs.read(fwd[dyn_fwd.zx()].zx(), j)); @@ -3805,7 +3801,7 @@ impl<'a, I: Index, E: Entity> LltRef<'a, I, E> { let (mut x, stack) = temp_mat_uninit::(n, k, stack); - let (fwd, inv) = self.symbolic.perm().into_arrays(); + let (fwd, inv) = self.symbolic.perm().arrays(); for j in 0..k { for (i, fwd) in fwd.iter().enumerate() { x.write(i, j, rhs.read(fwd.zx(), j)); @@ -3860,7 +3856,7 @@ impl<'a, I: Index, E: Entity> LdltRef<'a, I, E> { let (mut x, stack) = temp_mat_uninit::(n, k, stack); - let (fwd, inv) = self.symbolic.perm().into_arrays(); + let (fwd, inv) = self.symbolic.perm().arrays(); for j in 0..k { for (i, fwd) in fwd.iter().enumerate() { x.write(i, j, rhs.read(fwd.zx(), j)); @@ -4023,8 +4019,7 @@ pub fn factorize_symbolic_cholesky( stack.rb_mut(), )?; let flops = flops.n_div + flops.n_mult_subs_ldl; - let perm_ = - ghost::PermutationRef::new(PermutationRef::new_checked(&perm_fwd, &perm_inv), N); + let perm_ = ghost::PermRef::new(PermRef::new_checked(&perm_fwd, &perm_inv), N); let (new_col_ptr, stack) = stack.make_raw::(n + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); @@ -4049,7 +4044,8 @@ pub fn factorize_symbolic_cholesky( let L_nnz = I::sum_nonnegative(col_counts.as_ref()).ok_or(FaerError::IndexOverflow)?; let raw = if (flops / L_nnz.zx() as f64) - > params.supernodal_flop_ratio_threshold.0 * crate::CHOLESKY_SUPERNODAL_RATIO_FACTOR + > params.supernodal_flop_ratio_threshold.0 + * crate::sparse::linalg::CHOLESKY_SUPERNODAL_RATIO_FACTOR { SymbolicCholeskyRaw::Supernodal(supernodal::ghost_factorize_supernodal_symbolic( A, @@ -4085,11 +4081,14 @@ pub fn factorize_symbolic_cholesky( pub(crate) mod tests { use super::{supernodal::SupernodalLdltRef, *}; use crate::{ - cholesky::supernodal::{CholeskyInput, SupernodalIntranodeBunchKaufmanRef}, - qd::Double, + assert, + sparse::linalg::{ + cholesky::supernodal::{CholeskyInput, SupernodalIntranodeBunchKaufmanRef}, + qd::Double, + }, + Mat, }; use dyn_stack::GlobalPodBuffer; - use faer_core::{assert, Mat}; use num_complex::Complex; use rand::{Rng, SeedableRng}; @@ -4145,7 +4144,7 @@ pub(crate) mod tests { assert_eq!(col_count, [3, 3, 4, 3, 3, 4, 4, 3, 3, 2, 1].map(truncate)); } - include!("../data.rs"); + include!("./data.rs"); fn test_amd() { for &(_, (_, col_ptr, row_ind, _)) in ALL { @@ -4163,13 +4162,14 @@ pub(crate) mod tests { let perm = &mut vec![I(0); n]; let perm_inv = &mut vec![I(0); n]; - crate::amd::order_maybe_unsorted( + crate::sparse::linalg::amd::order_maybe_unsorted( perm, perm_inv, A, Default::default(), PodStack::new(&mut GlobalPodBuffer::new( - crate::amd::order_maybe_unsorted_req::(n, row_ind.len()).unwrap(), + crate::sparse::linalg::amd::order_maybe_unsorted_req::(n, row_ind.len()) + .unwrap(), )), ) .unwrap(); @@ -4367,7 +4367,7 @@ pub(crate) mod tests { let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::>(); let row_ind = &*row_ind.iter().copied().map(truncate).collect::>(); let values_mat = - faer_core::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); + crate::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); let values = values_mat.col_as_slice(0); let A = SparseColMatRef::<'_, I, E>::new( @@ -4402,14 +4402,14 @@ pub(crate) mod tests { let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0)); - let A_lower = faer_core::sparse::util::ghost_adjoint( + let A_lower = crate::sparse::utils::ghost_adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values, A, PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(20 * n))), ); - let mut values = faer_core::Mat::::zeros(symbolic.len_values(), 1); + let mut values = crate::Mat::::zeros(symbolic.len_values(), 1); supernodal::factorize_supernodal_numeric_ldlt( values.col_as_slice_mut(0), @@ -4467,7 +4467,7 @@ pub(crate) mod tests { let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::>(); let row_ind = &*row_ind.iter().copied().map(truncate).collect::>(); let values_mat = - faer_core::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); + crate::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); let values = values_mat.col_as_slice(0); let A = SparseColMatRef::<'_, I, E>::new( @@ -4509,14 +4509,14 @@ pub(crate) mod tests { let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0)); - let A_lower = faer_core::sparse::util::ghost_adjoint( + let A_lower = crate::sparse::utils::ghost_adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values, A, PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(20 * n))), ); - let mut values = faer_core::Mat::::zeros(symbolic.len_values(), 1); + let mut values = crate::Mat::::zeros(symbolic.len_values(), 1); supernodal::factorize_supernodal_numeric_ldlt( values.col_as_slice_mut(0), @@ -4592,7 +4592,7 @@ pub(crate) mod tests { } }; let values_mat = - faer_core::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); + crate::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); let values = values_mat.col_as_slice(0); let A = SparseColMatRef::<'_, I, E>::new( @@ -4634,14 +4634,14 @@ pub(crate) mod tests { let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0)); - let A_lower = faer_core::sparse::util::ghost_adjoint( + let A_lower = crate::sparse::utils::ghost_adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values, A, PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(20 * n))), ); - let mut values = faer_core::Mat::::zeros(symbolic.len_values(), 1); + let mut values = crate::Mat::::zeros(symbolic.len_values(), 1); let mut fwd = vec![zero; n]; let mut inv = vec![zero; n]; @@ -4675,13 +4675,13 @@ pub(crate) mod tests { &symbolic, values.col_as_slice(0), subdiag.col_as_slice(0), - PermutationRef::new_checked(&fwd, &inv), + PermRef::new_checked(&fwd, &inv), ); - faer_core::permutation::permute_rows_in_place( + crate::perm::permute_rows_in_place( x.as_mut(), lblt.perm, PodStack::new(&mut GlobalPodBuffer::new( - faer_core::permutation::permute_rows_in_place_req::(n, k).unwrap(), + crate::perm::permute_rows_in_place_req::(n, k).unwrap(), )), ); lblt.solve_in_place_no_numeric_permute_with_conj( @@ -4692,11 +4692,11 @@ pub(crate) mod tests { symbolic.solve_in_place_req::(k).unwrap(), )), ); - faer_core::permutation::permute_rows_in_place( + crate::perm::permute_rows_in_place( x.as_mut(), lblt.perm.inverse(), PodStack::new(&mut GlobalPodBuffer::new( - faer_core::permutation::permute_rows_in_place_req::(n, k).unwrap(), + crate::perm::permute_rows_in_place_req::(n, k).unwrap(), )), ); @@ -4733,7 +4733,7 @@ pub(crate) mod tests { let nnz = values.len(); let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::>(); let row_ind = &*row_ind.iter().copied().map(truncate).collect::>(); - let values_mat = faer_core::Mat::::from_fn(nnz, 1, |i, _| values[i]); + let values_mat = crate::Mat::::from_fn(nnz, 1, |i, _| values[i]); let values = values_mat.col_as_slice(0); let A = SparseColMatRef::<'_, I, E>::new( @@ -4775,14 +4775,14 @@ pub(crate) mod tests { let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); let A_lower_values = SliceGroupMut::new(A_lower_values.col_as_slice_mut(0)); - let A_lower = faer_core::sparse::util::ghost_adjoint( + let A_lower = crate::sparse::utils::ghost_adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values, A, PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(20 * n))), ); - let mut values = faer_core::Mat::::zeros(symbolic.len_values(), 1); + let mut values = crate::Mat::::zeros(symbolic.len_values(), 1); let mut fwd = vec![zero; n]; let mut inv = vec![zero; n]; @@ -4816,13 +4816,13 @@ pub(crate) mod tests { &symbolic, values.col_as_slice(0), subdiag.col_as_slice(0), - PermutationRef::new_checked(&fwd, &inv), + PermRef::new_checked(&fwd, &inv), ); - faer_core::permutation::permute_rows_in_place( + crate::perm::permute_rows_in_place( x.as_mut(), lblt.perm, PodStack::new(&mut GlobalPodBuffer::new( - faer_core::permutation::permute_rows_in_place_req::(n, k).unwrap(), + crate::perm::permute_rows_in_place_req::(n, k).unwrap(), )), ); lblt.solve_in_place_no_numeric_permute_with_conj( @@ -4833,11 +4833,11 @@ pub(crate) mod tests { symbolic.solve_in_place_req::(k).unwrap(), )), ); - faer_core::permutation::permute_rows_in_place( + crate::perm::permute_rows_in_place( x.as_mut(), lblt.perm.inverse(), PodStack::new(&mut GlobalPodBuffer::new( - faer_core::permutation::permute_rows_in_place_req::(n, k).unwrap(), + crate::perm::permute_rows_in_place_req::(n, k).unwrap(), )), ); @@ -4882,7 +4882,7 @@ pub(crate) mod tests { let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::>(); let row_ind = &*row_ind.iter().copied().map(truncate).collect::>(); let values_mat = - faer_core::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); + crate::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); let values = values_mat.col_as_slice(0); let A = SparseColMatRef::<'_, I, E>::new( @@ -4909,7 +4909,7 @@ pub(crate) mod tests { ) .unwrap(); - let mut values = faer_core::Mat::::zeros(symbolic.len_values(), 1); + let mut values = crate::Mat::::zeros(symbolic.len_values(), 1); simplicial::factorize_simplicial_numeric_ldlt::( values.col_as_slice_mut(0), @@ -4960,9 +4960,8 @@ pub(crate) mod tests { let nnz = values.len(); let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::>(); let row_ind = &*row_ind.iter().copied().map(truncate).collect::>(); - let values_mat = faer_core::Mat::::from_fn(nnz, 1, |i, _| { - complexify(E::faer_from_f64(values[i])) - }); + let values_mat = + crate::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); let values = values_mat.col_as_slice(0); let A_upper = SparseColMatRef::<'_, I, E>::new( @@ -4973,7 +4972,7 @@ pub(crate) mod tests { let mut A_lower_col_ptr = col_ptr.to_vec(); let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); - let A_lower = faer_core::sparse::util::adjoint( + let A_lower = crate::sparse::utils::adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values.col_as_slice_mut(0), @@ -5052,7 +5051,7 @@ pub(crate) mod tests { } }; - let (perm_fwd, _) = symbolic.perm().into_arrays(); + let (perm_fwd, _) = symbolic.perm().arrays(); let mut max = ::Real::faer_zero(); for j in 0..n { @@ -5124,9 +5123,8 @@ pub(crate) mod tests { let nnz = values.len(); let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::>(); let row_ind = &*row_ind.iter().copied().map(truncate).collect::>(); - let values_mat = faer_core::Mat::::from_fn(nnz, 1, |i, _| { - complexify(E::faer_from_f64(values[i])) - }); + let values_mat = + crate::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); let values = values_mat.col_as_slice(0); let A_upper = SparseColMatRef::<'_, I, E>::new( @@ -5137,7 +5135,7 @@ pub(crate) mod tests { let mut A_lower_col_ptr = col_ptr.to_vec(); let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); - let A_lower = faer_core::sparse::util::adjoint( + let A_lower = crate::sparse::utils::adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values.col_as_slice_mut(0), @@ -5213,7 +5211,7 @@ pub(crate) mod tests { } }; - let (perm_fwd, _) = symbolic.perm().into_arrays(); + let (perm_fwd, _) = symbolic.perm().arrays(); let mut max = ::Real::faer_zero(); for j in 0..n { @@ -5286,9 +5284,8 @@ pub(crate) mod tests { let nnz = values.len(); let col_ptr = &*col_ptr.iter().copied().map(truncate).collect::>(); let row_ind = &*row_ind.iter().copied().map(truncate).collect::>(); - let values_mat = faer_core::Mat::::from_fn(nnz, 1, |i, _| { - complexify(E::faer_from_f64(values[i])) - }); + let values_mat = + crate::Mat::::from_fn(nnz, 1, |i, _| complexify(E::faer_from_f64(values[i]))); let values = values_mat.col_as_slice(0); let A_upper = SparseColMatRef::<'_, I, E>::new( @@ -5299,7 +5296,7 @@ pub(crate) mod tests { let mut A_lower_col_ptr = col_ptr.to_vec(); let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); - let A_lower = faer_core::sparse::util::adjoint( + let A_lower = crate::sparse::utils::adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values.col_as_slice_mut(0), @@ -5420,7 +5417,7 @@ pub(crate) mod tests { let col_ptr = &*col_ptr.iter().copied().map(I).collect::>(); let row_ind = &*row_ind.iter().copied().map(I).collect::>(); // artificial zeros - let values_mat = faer_core::Mat::::from_fn(nnz, 1, |_, _| 0.0); + let values_mat = crate::Mat::::from_fn(nnz, 1, |_, _| 0.0); let dynamic_regularization_epsilon = 1e-6; let dynamic_regularization_delta = 1e-2; @@ -5436,7 +5433,7 @@ pub(crate) mod tests { let mut A_lower_col_ptr = col_ptr.to_vec(); let mut A_lower_values = values_mat.clone(); let mut A_lower_row_ind = row_ind.to_vec(); - let A_lower = faer_core::sparse::util::adjoint( + let A_lower = crate::sparse::utils::adjoint( &mut A_lower_col_ptr, &mut A_lower_row_ind, A_lower_values.col_as_slice_mut(0), @@ -5519,7 +5516,7 @@ pub(crate) mod tests { } }; - let (perm_fwd, _) = symbolic.perm().into_arrays(); + let (perm_fwd, _) = symbolic.perm().arrays(); let mut max = ::Real::faer_zero(); for j in 0..n { for i in 0..n { diff --git a/faer-libs/faer-sparse/src/colamd.rs b/src/sparse/linalg/colamd.rs similarity index 98% rename from faer-libs/faer-sparse/src/colamd.rs rename to src/sparse/linalg/colamd.rs index 9610715fb0094bde19fff06550c3d598f45e6f02..faf41b92bca1ee4a1b7c0abd73dff55200ad2c69 100644 --- a/faer-libs/faer-sparse/src/colamd.rs +++ b/src/sparse/linalg/colamd.rs @@ -27,12 +27,12 @@ // OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH // DAMAGE. -use crate::{ +use super::{ mem::{self, NONE}, FaerError, Index, SignedIndex, SymbolicSparseColMatRef, }; +use crate::{assert, perm::PermRef}; use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{assert, permutation::PermutationRef}; use reborrow::*; impl ColamdCol { @@ -86,6 +86,8 @@ fn clear_mark(tag_mark: I, max_mark: I, row: &mut [ColamdRow]) -> I } } +/// Computes the size and alignment of required workspace for computing the COLAMD ordering of a +/// matrix. pub fn order_req( nrows: usize, ncols: usize, @@ -122,6 +124,9 @@ pub fn order_req( ) } +/// Computes the approximate minimum degree ordering for reducing the fill-in during the sparse +/// QR factorization of a matrix with the sparsity pattern of `A`. +/// /// # Note /// Allows unsorted matrices. pub fn order( @@ -651,15 +656,13 @@ pub fn order( let mut etree = alloc::vec![I(0); n]; let mut post = alloc::vec![I(0); n]; - let etree = crate::qr::col_etree::( + let etree = super::qr::col_etree::( A, - Some(PermutationRef::<'_, I, faer_entity::Symbolic>::new_checked( - perm, perm_inv, - )), + Some(PermRef::<'_, I>::new_checked(perm, perm_inv)), &mut etree, stack.rb_mut(), ); - crate::qr::postorder(&mut post, etree, stack.rb_mut()); + super::qr::postorder(&mut post, etree, stack.rb_mut()); for i in 0..n { perm[post[i].zx()] = I(i); } @@ -843,6 +846,7 @@ unsafe impl bytemuck::Pod for ColamdCol {} unsafe impl bytemuck::Zeroable for ColamdRow {} unsafe impl bytemuck::Pod for ColamdRow {} +/// Tuning parameters for the AMD implementation. #[derive(Debug, Copy, Clone, PartialEq)] pub struct Control { /// "dense" if degree > dense_row * sqrt(ncols) diff --git a/faer-libs/faer-sparse/data.rs b/src/sparse/linalg/data.rs similarity index 100% rename from faer-libs/faer-sparse/data.rs rename to src/sparse/linalg/data.rs diff --git a/faer-libs/faer-sparse/src/lu.rs b/src/sparse/linalg/lu.rs similarity index 91% rename from faer-libs/faer-sparse/src/lu.rs rename to src/sparse/linalg/lu.rs index c8ced6b20dd20e2df09e78364485c25ae08e31bc..60fcf0ad3368a7328974687e31b34b7ef2323bc6 100644 --- a/faer-libs/faer-sparse/src/lu.rs +++ b/src/sparse/linalg/lu.rs @@ -1,12 +1,14 @@ -//! Computes the LU decomposition of a given sparse matrix. See [`faer_lu`] for more info. +//! Computes the LU decomposition of a given sparse matrix. See +//! [`faer::linalg::lu`](crate::linalg::lu) for more info. //! //! The entry point in this module is [`SymbolicLu`] and [`factorize_symbolic_lu`]. //! //! # Warning //! The functions in this module accept unsorted input, and always produce unsorted decomposition //! factors. +#![allow(missing_docs)] -use crate::{ +use super::{ cholesky::simplicial::EliminationTreeRef, colamd::Control, ghost, @@ -14,21 +16,19 @@ use crate::{ mem::{ NONE, {self}, }, - nomem, try_zeroed, FaerError, Index, SupernodalThreshold, SymbolicSparseColMatRef, + nomem, try_zeroed, FaerError, Index, LuError, SupernodalThreshold, SymbolicSparseColMatRef, SymbolicSupernodalParams, }; -use core::{iter::zip, mem::MaybeUninit}; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - constrained::Size, - group_helpers::{SliceGroup, SliceGroupMut, VecGroup}, - mul, - permutation::{PermutationRef, SignedIndex}, - solve, + linalg::{matmul, temp_mat_req, temp_mat_uninit, triangular_solve as solve}, + perm::PermRef, sparse::SparseColMatRef, - temp_mat_req, temp_mat_uninit, Conj, MatMut, Parallelism, + utils::{constrained::Size, slice::*, vec::*}, + Conj, MatMut, Parallelism, SignedIndex, }; +use core::{iter::zip, mem::MaybeUninit}; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use reborrow::*; @@ -93,35 +93,9 @@ fn resize_index( Ok(()) } -/// Sparse LU error. -#[derive(Copy, Clone, Debug)] -pub enum LuError { - Generic(FaerError), - SymbolicSingular(usize), -} - -impl core::fmt::Display for LuError { - #[inline] - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - core::fmt::Debug::fmt(self, f) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for LuError {} - -impl From for LuError { - #[inline] - fn from(value: FaerError) -> Self { - Self::Generic(value) - } -} - pub mod supernodal { - use crate::try_collect; - use super::*; - use faer_core::assert; + use crate::{assert, sparse::linalg::try_collect}; #[derive(Debug, Clone)] pub struct SymbolicSupernodalLu { @@ -153,6 +127,12 @@ pub mod supernodal { ut_val: VecGroup, } + impl Default for SupernodalLu { + fn default() -> Self { + Self::new() + } + } + impl SupernodalLu { #[inline] pub fn new() -> Self { @@ -195,8 +175,8 @@ pub mod supernodal { #[track_caller] pub fn solve_in_place_with_conj( &self, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatMut<'_, E>, parallelism: Parallelism, @@ -211,17 +191,17 @@ pub mod supernodal { let mut X = rhs; let mut temp = work; - faer_core::permutation::permute_rows(temp.rb_mut(), X.rb(), row_perm); + crate::perm::permute_rows(temp.rb_mut(), X.rb(), row_perm); self.l_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), parallelism); self.u_solve_in_place_with_conj(conj_lhs, temp.rb_mut(), X.rb_mut(), parallelism); - faer_core::permutation::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse()); + crate::perm::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse()); } #[track_caller] pub fn solve_transpose_in_place_with_conj( &self, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatMut<'_, E>, parallelism: Parallelism, @@ -235,7 +215,7 @@ pub mod supernodal { )); let mut X = rhs; let mut temp = work; - faer_core::permutation::permute_rows(temp.rb_mut(), X.rb(), col_perm); + crate::perm::permute_rows(temp.rb_mut(), X.rb(), col_perm); self.u_solve_transpose_in_place_with_conj( conj_lhs, temp.rb_mut(), @@ -248,7 +228,7 @@ pub mod supernodal { X.rb_mut(), parallelism, ); - faer_core::permutation::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse()); + crate::perm::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse()); } #[track_caller] @@ -261,7 +241,7 @@ pub mod supernodal { ) where E: ComplexField, { - let lu = &*self; + let lu = self; assert!(lu.nrows() == lu.ncols()); assert!(lu.nrows() == rhs.nrows()); @@ -282,7 +262,7 @@ pub mod supernodal { .l_val .as_slice() .subslice(lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()); - let L = faer_core::mat::from_column_major_slice::<'_, E>( + let L = crate::mat::from_column_major_slice::<'_, E>( L.into_inner(), s_row_index_count, s_size, @@ -294,7 +274,7 @@ pub mod supernodal { X.rb_mut().subrows_mut(s_begin, s_size), parallelism, ); - mul::matmul_with_conj( + matmul::matmul_with_conj( work.rb_mut().subrows_mut(0, s_row_index_count - s_size), L_bot, conj_lhs, @@ -329,7 +309,7 @@ pub mod supernodal { ) where E: ComplexField, { - let lu = &*self; + let lu = self; assert!(lu.nrows() == lu.ncols()); assert!(lu.nrows() == rhs.nrows()); @@ -350,7 +330,7 @@ pub mod supernodal { .l_val .as_slice() .subslice(lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()); - let L = faer_core::mat::from_column_major_slice::<'_, E>( + let L = crate::mat::from_column_major_slice::<'_, E>( L.into_inner(), s_row_index_count, s_size, @@ -370,7 +350,7 @@ pub mod supernodal { } } - mul::matmul_with_conj( + matmul::matmul_with_conj( X.rb_mut().subrows_mut(s_begin, s_size), L_bot.transpose(), conj_lhs, @@ -399,7 +379,7 @@ pub mod supernodal { ) where E: ComplexField, { - let lu = &*self; + let lu = self; assert!(lu.nrows() == lu.ncols()); assert!(lu.nrows() == rhs.nrows()); @@ -422,7 +402,7 @@ pub mod supernodal { .l_val .as_slice() .subslice(lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()); - let L = faer_core::mat::from_column_major_slice::<'_, E>( + let L = crate::mat::from_column_major_slice::<'_, E>( L.into_inner(), s_row_index_count, s_size, @@ -431,7 +411,7 @@ pub mod supernodal { .ut_val .as_slice() .subslice(lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()); - let U_right = faer_core::mat::from_column_major_slice::<'_, E>( + let U_right = crate::mat::from_column_major_slice::<'_, E>( U.into_inner(), s_col_index_count, s_size, @@ -450,7 +430,7 @@ pub mod supernodal { } let (U_left, _) = L.split_at_row(s_size); - mul::matmul_with_conj( + matmul::matmul_with_conj( X.rb_mut().subrows_mut(s_begin, s_size), U_right, conj_lhs, @@ -479,7 +459,7 @@ pub mod supernodal { ) where E: ComplexField, { - let lu = &*self; + let lu = self; assert!(lu.nrows() == lu.ncols()); assert!(lu.nrows() == rhs.nrows()); @@ -502,7 +482,7 @@ pub mod supernodal { .l_val .as_slice() .subslice(lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()); - let L = faer_core::mat::from_column_major_slice::<'_, E>( + let L = crate::mat::from_column_major_slice::<'_, E>( L.into_inner(), s_row_index_count, s_size, @@ -511,7 +491,7 @@ pub mod supernodal { .ut_val .as_slice() .subslice(lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()); - let U_right = faer_core::mat::from_column_major_slice::<'_, E>( + let U_right = crate::mat::from_column_major_slice::<'_, E>( U.into_inner(), s_col_index_count, s_size, @@ -525,7 +505,7 @@ pub mod supernodal { X.rb_mut().subrows_mut(s_begin, s_size), parallelism, ); - mul::matmul_with_conj( + matmul::matmul_with_conj( work.rb_mut().subrows_mut(0, s_col_index_count), U_right.transpose(), conj_lhs, @@ -555,13 +535,15 @@ pub mod supernodal { ncols: usize, ) -> Result { let _ = nrows; - crate::cholesky::supernodal::factorize_supernodal_symbolic_cholesky_req::(ncols) + crate::sparse::linalg::cholesky::supernodal::factorize_supernodal_symbolic_cholesky_req::( + ncols, + ) } #[track_caller] pub fn factorize_supernodal_symbolic_lu( A: SymbolicSparseColMatRef<'_, I>, - col_perm: Option>, + col_perm: Option>, min_col: &[I], etree: EliminationTreeRef<'_, I>, col_counts: &[I], @@ -574,30 +556,30 @@ pub mod supernodal { let I = I::truncate; let A = ghost::SymbolicSparseColMatRef::new(A, M, N); let min_col = Array::from_ref( - MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(&min_col), N), + MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(min_col), N), M, ); let etree = etree.ghost_inner(N); let mut stack = stack; - let L = crate::cholesky::supernodal::ghost_factorize_supernodal_symbolic( - A, - col_perm.map(|perm| ghost::PermutationRef::new(perm, N)), - Some(min_col), - crate::cholesky::supernodal::CholeskyInput::ATA, - etree, - Array::from_ref(&col_counts, N), - stack.rb_mut(), - params, - )?; + let L = + crate::sparse::linalg::cholesky::supernodal::ghost_factorize_supernodal_symbolic( + A, + col_perm.map(|perm| ghost::PermRef::new(perm, N)), + Some(min_col), + crate::sparse::linalg::cholesky::supernodal::CholeskyInput::ATA, + etree, + Array::from_ref(col_counts, N), + stack.rb_mut(), + params, + )?; let n_supernodes = L.n_supernodes(); let mut super_etree = try_zeroed::(n_supernodes)?; let (index_to_super, _) = stack.make_raw::(*N); for s in 0..n_supernodes { - index_to_super.as_mut()[L.supernode_begin[s].zx()..L.supernode_begin[s + 1].zx()] - .fill(I(s)); + index_to_super[L.supernode_begin[s].zx()..L.supernode_begin[s + 1].zx()].fill(I(s)); } for s in 0..n_supernodes { let last = L.supernode_begin[s + 1].zx() - 1; @@ -634,7 +616,7 @@ pub mod supernodal { fn with_dims(nrows: usize, ncols: usize) -> Result { Ok(Self { - data: try_collect((0..(nrows * ncols)).into_iter().map(|_| 1u8))?, + data: try_collect((0..(nrows * ncols)).map(|_| 1u8))?, nrows, }) } @@ -677,13 +659,13 @@ pub mod supernodal { A: SparseColMatRef<'_, I, E>, AT: SparseColMatRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, symbolic: &SymbolicSupernodalLu, parallelism: Parallelism, stack: PodStack<'_>, ) -> Result<(), LuError> { - use crate::cholesky::supernodal::partition_fn; + use crate::sparse::linalg::cholesky::supernodal::partition_fn; let SymbolicSupernodalLu { supernode_ptr, super_etree, @@ -761,7 +743,7 @@ pub mod supernodal { row_perm_inv[i] = I(i); } - let (col_perm, col_perm_inv) = col_perm.into_arrays(); + let (col_perm, col_perm_inv) = col_perm.arrays(); let mut contrib_work = try_collect((0..n_supernodes).map(|_| { ( @@ -785,7 +767,7 @@ pub mod supernodal { let work_to_mat_mut = |v: &mut GroupFor>>, nrows: usize, ncols: usize| unsafe { - faer_core::mat::from_raw_parts_mut::<'_, E>( + crate::mat::from_raw_parts_mut::<'_, E>( E::faer_map(E::faer_as_mut(v), |v| v.as_mut_ptr() as *mut E::Unit), nrows, ncols, @@ -889,7 +871,7 @@ pub mod supernodal { .l_val .as_slice_mut() .subslice(lu.l_col_ptr_for_val[s].zx()..lu.l_col_ptr_for_val[s + 1].zx()); - let mut s_L = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut s_L = crate::mat::from_column_major_slice_mut::<'_, E>( s_L.into_inner(), s_row_index_count, s_size, @@ -990,7 +972,7 @@ pub mod supernodal { return Err(LuError::SymbolicSingular(s_begin + s_L.nrows())); } let transpositions = &mut transpositions[s_begin..s_end]; - faer_lu::partial_pivoting::compute::lu_in_place_impl( + crate::linalg::lu::partial_pivoting::compute::lu_in_place_impl( s_L.rb_mut(), 0, s_size, @@ -1090,7 +1072,7 @@ pub mod supernodal { .ut_val .as_slice_mut() .subslice(lu.ut_col_ptr_for_val[s].zx()..lu.ut_col_ptr_for_val[s + 1].zx()); - let mut s_U = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut s_U = crate::mat::from_column_major_slice_mut::<'_, E>( s_U.into_inner(), s_col_index_count, s_size, @@ -1188,7 +1170,7 @@ pub mod supernodal { } } }); - faer_core::solve::solve_unit_lower_triangular_in_place( + solve::solve_unit_lower_triangular_in_place( s_L.rb().subrows(0, s_size), s_U.rb_mut(), parallelism, @@ -1218,7 +1200,7 @@ pub mod supernodal { s_row_index_count - s_size, s_col_index_count, ); - mul::matmul( + matmul::matmul( s_LU.rb_mut(), s_L.rb().get(s_size.., ..), s_U.rb(), @@ -1352,10 +1334,8 @@ pub mod supernodal { } pub mod simplicial { - use crate::triangular_solve; - use super::*; - use faer_core::assert; + use crate::{assert, sparse::linalg::triangular_solve}; #[derive(Debug, Clone)] pub struct SimplicialLu { @@ -1371,6 +1351,12 @@ pub mod simplicial { u_val: VecGroup, } + impl Default for SimplicialLu { + fn default() -> Self { + Self::new() + } + } + impl SimplicialLu { #[inline] pub fn new() -> Self { @@ -1434,8 +1420,8 @@ pub mod simplicial { #[track_caller] pub fn solve_in_place_with_conj( &self, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatMut<'_, E>, parallelism: Parallelism, @@ -1451,7 +1437,7 @@ pub mod simplicial { let l = self.l_factor_unsorted(); let u = self.u_factor_unsorted(); - faer_core::permutation::permute_rows(temp.rb_mut(), X.rb(), row_perm); + crate::perm::permute_rows(temp.rb_mut(), X.rb(), row_perm); triangular_solve::solve_unit_lower_triangular_in_place( l, conj_lhs, @@ -1464,14 +1450,14 @@ pub mod simplicial { temp.rb_mut(), parallelism, ); - faer_core::permutation::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse()); + crate::perm::permute_rows(X.rb_mut(), temp.rb(), col_perm.inverse()); } #[track_caller] pub fn solve_transpose_in_place_with_conj( &self, - row_perm: PermutationRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + row_perm: PermRef<'_, I>, + col_perm: PermRef<'_, I>, conj_lhs: Conj, rhs: MatMut<'_, E>, parallelism: Parallelism, @@ -1489,7 +1475,7 @@ pub mod simplicial { let l = self.l_factor_unsorted(); let u = self.u_factor_unsorted(); - faer_core::permutation::permute_rows(temp.rb_mut(), X.rb(), col_perm); + crate::perm::permute_rows(temp.rb_mut(), X.rb(), col_perm); triangular_solve::solve_upper_triangular_transpose_in_place( u, conj_lhs, @@ -1502,7 +1488,7 @@ pub mod simplicial { temp.rb_mut(), parallelism, ); - faer_core::permutation::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse()); + crate::perm::permute_rows(X.rb_mut(), temp.rb(), row_perm.inverse()); } } @@ -1638,7 +1624,7 @@ pub mod simplicial { ncols: usize, ) -> Result { let idx = StackReq::try_new::(nrows)?; - let val = crate::make_raw_req::(nrows)?; + let val = crate::sparse::linalg::make_raw_req::(nrows)?; let _ = ncols; StackReq::try_all_of([val, idx, idx, idx]) } @@ -1649,7 +1635,7 @@ pub mod simplicial { lu: &mut SimplicialLu, A: SparseColMatRef<'_, I, E>, - col_perm: PermutationRef<'_, I, E>, + col_perm: PermRef<'_, I>, stack: PodStack<'_>, ) -> Result<(), LuError> { let I = I::truncate; @@ -1670,7 +1656,7 @@ pub mod simplicial { resize_index(&mut lu.l_col_ptr, n + 1, true, false)?; resize_index(&mut lu.u_col_ptr, n + 1, true, false)?; - let (mut x, stack) = crate::make_raw::(m, stack); + let (mut x, stack) = crate::sparse::linalg::make_raw::(m, stack); let (marked, stack) = stack.make_raw::(m); let (xj, stack) = stack.make_raw::(m); let (stack, _) = stack.make_raw::(m); @@ -1698,7 +1684,7 @@ pub mod simplicial { lu.l_val.as_slice().into_inner(), ); - let pj = col_perm.into_arrays().0[j].zx(); + let pj = col_perm.arrays().0[j].zx(); let tail_start = l_incomplete_solve_sparse( marked, I(j + 1), @@ -1825,6 +1811,12 @@ pub struct NumericLu { row_perm_inv: alloc::vec::Vec, } +impl Default for NumericLu { + fn default() -> Self { + Self::new() + } +} + impl NumericLu { #[inline] pub fn new() -> Self { @@ -1861,15 +1853,13 @@ impl<'a, I: Index, E: Entity> LuRef<'a, I, E> { } #[inline] - pub fn row_perm(self) -> PermutationRef<'a, I, E> { - unsafe { - PermutationRef::new_unchecked(&self.numeric.row_perm_fwd, &self.numeric.row_perm_inv) - } + pub fn row_perm(self) -> PermRef<'a, I> { + unsafe { PermRef::new_unchecked(&self.numeric.row_perm_fwd, &self.numeric.row_perm_inv) } } #[inline] - pub fn col_perm(self) -> PermutationRef<'a, I, E> { - self.symbolic.col_perm().cast() + pub fn col_perm(self) -> PermRef<'a, I> { + self.symbolic.col_perm() } #[track_caller] @@ -1960,8 +1950,8 @@ impl SymbolicLu { /// Returns the fill-reducing column permutation that was computed during symbolic analysis. #[inline] - pub fn col_perm(&self) -> PermutationRef<'_, I, Symbolic> { - unsafe { PermutationRef::new_unchecked(&self.col_perm_fwd, &self.col_perm_inv) } + pub fn col_perm(&self) -> PermRef<'_, I> { + unsafe { PermRef::new_unchecked(&self.col_perm_fwd, &self.col_perm_inv) } } pub fn factorize_numeric_lu_req( @@ -1978,7 +1968,7 @@ impl SymbolicLu { let A_nnz = self.A_nnz; let AT_req = StackReq::try_all_of([ - crate::make_raw_req::(A_nnz)?, + crate::sparse::linalg::make_raw_req::(A_nnz)?, StackReq::try_new::(m + 1)?, StackReq::try_new::(A_nnz)?, ])?; @@ -2049,7 +2039,7 @@ impl SymbolicLu { &mut numeric.row_perm_inv, lu, A, - self.col_perm().cast(), + self.col_perm(), stack, )?; } @@ -2057,8 +2047,9 @@ impl SymbolicLu { let m = symbolic.nrows; let (new_col_ptr, stack) = stack.make_raw::(m + 1); let (new_row_ind, stack) = stack.make_raw::(self.A_nnz); - let (new_values, mut stack) = crate::make_raw::(self.A_nnz, stack); - let AT = crate::transpose::( + let (new_values, mut stack) = + crate::sparse::linalg::make_raw::(self.A_nnz, stack); + let AT = crate::sparse::utils::transpose::( new_col_ptr, new_row_ind, new_values.into_inner(), @@ -2073,7 +2064,7 @@ impl SymbolicLu { lu, A, AT, - self.col_perm().cast(), + self.col_perm(), symbolic, parallelism, stack, @@ -2112,7 +2103,7 @@ pub fn factorize_symbolic_lu( )?; StackReq::try_or( - crate::colamd::order_req::(m, n, A_nnz)?, + crate::sparse::linalg::colamd::order_req::(m, n, A_nnz)?, StackReq::try_all_of([ n_req, n_req, @@ -2137,7 +2128,7 @@ pub fn factorize_symbolic_lu( let mut col_perm_inv = try_zeroed::(n)?; let mut min_row = try_zeroed::(m)?; - crate::colamd::order( + crate::sparse::linalg::colamd::order( &mut col_perm_fwd, &mut col_perm_inv, A.into_inner(), @@ -2145,31 +2136,42 @@ pub fn factorize_symbolic_lu( stack.rb_mut(), )?; - let col_perm = ghost::PermutationRef::new( - PermutationRef::new_checked(&col_perm_fwd, &col_perm_inv), - N, - ); + let col_perm = ghost::PermRef::new(PermRef::new_checked(&col_perm_fwd, &col_perm_inv), N); let (new_col_ptr, stack) = stack.make_raw::(m + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); - let AT = crate::ghost_adjoint_symbolic(new_col_ptr, new_row_ind, A, stack.rb_mut()); + let AT = crate::sparse::utils::ghost_adjoint_symbolic( + new_col_ptr, + new_row_ind, + A, + stack.rb_mut(), + ); let (etree, stack) = stack.make_raw::(n); let (post, stack) = stack.make_raw::(n); let (col_counts, stack) = stack.make_raw::(n); let (h_col_counts, mut stack) = stack.make_raw::(n); - crate::qr::ghost_col_etree(A, Some(col_perm), Array::from_mut(etree, N), stack.rb_mut()); - let etree_ = Array::from_ref(MaybeIdx::<'_, I>::from_slice_ref_checked(&etree, N), N); - crate::cholesky::ghost_postorder(Array::from_mut(post, N), etree_, stack.rb_mut()); + crate::sparse::linalg::qr::ghost_col_etree( + A, + Some(col_perm), + Array::from_mut(etree, N), + stack.rb_mut(), + ); + let etree_ = Array::from_ref(MaybeIdx::<'_, I>::from_slice_ref_checked(etree, N), N); + crate::sparse::linalg::cholesky::ghost_postorder( + Array::from_mut(post, N), + etree_, + stack.rb_mut(), + ); - crate::qr::ghost_column_counts_aat( + crate::sparse::linalg::qr::ghost_column_counts_aat( Array::from_mut(col_counts, N), Array::from_mut(bytemuck::cast_slice_mut(&mut min_row), M), AT, Some(col_perm), etree_, - Array::from_ref(Idx::from_slice_ref_checked(&post, N), N), + Array::from_ref(Idx::from_slice_ref_checked(post, N), N), stack.rb_mut(), ); let min_col = min_row; @@ -2203,7 +2205,7 @@ pub fn factorize_symbolic_lu( nnz += hj + rj; } - if flops / nnz > threshold.0 * crate::LU_SUPERNODAL_RATIO_FACTOR { + if flops / nnz > threshold.0 * crate::sparse::linalg::LU_SUPERNODAL_RATIO_FACTOR { threshold = SupernodalThreshold::FORCE_SUPERNODAL; } else { threshold = SupernodalThreshold::FORCE_SIMPLICIAL; @@ -2214,9 +2216,9 @@ pub fn factorize_symbolic_lu( let symbolic = supernodal::factorize_supernodal_symbolic_lu::( A.into_inner(), Some(col_perm.into_inner()), - &*min_col, - EliminationTreeRef::<'_, I> { inner: &etree }, - &col_counts, + &min_col, + EliminationTreeRef::<'_, I> { inner: etree }, + col_counts, stack.rb_mut(), params.supernodal_params, )?; @@ -2240,28 +2242,32 @@ pub fn factorize_symbolic_lu( #[cfg(test)] mod tests { use crate::{ - lu::{ - simplicial::{ - factorize_simplicial_numeric_lu, factorize_simplicial_numeric_lu_req, SimplicialLu, - }, - supernodal::{ - factorize_supernodal_numeric_lu, factorize_supernodal_numeric_lu_req, SupernodalLu, + assert, + perm::PermRef, + sparse::{ + linalg::{ + lu::{ + simplicial::{ + factorize_simplicial_numeric_lu, factorize_simplicial_numeric_lu_req, + SimplicialLu, + }, + supernodal::{ + factorize_supernodal_numeric_lu, factorize_supernodal_numeric_lu_req, + SupernodalLu, + }, + LuSymbolicParams, NumericLu, + }, + qr::col_etree, + SupernodalThreshold, SymbolicSparseColMatRef, }, - LuSymbolicParams, NumericLu, + SparseColMatRef, }, - qr::col_etree, - SupernodalThreshold, SymbolicSparseColMatRef, + utils::slice::*, + Conj, Index, Mat, Parallelism, }; use core::iter::zip; use dyn_stack::{GlobalPodBuffer, PodStack, StackReq}; - use faer_core::{ - assert, - group_helpers::SliceGroup, - permutation::{Index, PermutationRef}, - sparse::SparseColMatRef, - Conj, Mat, Parallelism, - }; - use faer_entity::{ComplexField, Symbolic}; + use faer_entity::ComplexField; use matrix_market_rs::MtxData; use rand::{rngs::StdRng, Rng, SeedableRng}; use reborrow::*; @@ -2346,7 +2352,7 @@ mod tests { #[test] fn test_numeric_lu_multifrontal() { - type E = faer_core::c64; + type E = crate::complex_native::c64; let (m, n, col_ptr, row_ind, val) = load_mtx::(MtxData::from_file("test_data/YAO.mtx").unwrap()); @@ -2367,7 +2373,7 @@ mod tests { col_perm[i] = i; col_perm_inv[i] = i; } - let col_perm = PermutationRef::<'_, usize, Symbolic>::new_checked(&col_perm, &col_perm_inv); + let col_perm = PermRef::<'_, usize>::new_checked(&col_perm, &col_perm_inv); let mut etree = vec![0usize; n]; let mut min_col = vec![0usize; m]; @@ -2377,7 +2383,7 @@ mod tests { let mut new_col_ptrs = vec![0usize; m + 1]; let mut new_row_ind = vec![0usize; nnz]; let mut new_values = vec![E::faer_zero(); nnz]; - let AT = crate::transpose::( + let AT = crate::sparse::utils::transpose::( &mut new_col_ptrs, &mut new_row_ind, &mut new_values, @@ -2390,20 +2396,20 @@ mod tests { let mut post = vec![0usize; n]; let etree = col_etree( - *A, + A.symbolic(), Some(col_perm), &mut etree, PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(m + n))), ); - crate::qr::postorder( + crate::sparse::linalg::qr::postorder( &mut post, etree, PodStack::new(&mut GlobalPodBuffer::new(StackReq::new::(3 * n))), ); - crate::qr::column_counts_aat( + crate::sparse::linalg::qr::column_counts_aat( &mut col_counts, &mut min_col, - *AT, + AT.symbolic(), Some(col_perm), etree, &post, @@ -2412,20 +2418,21 @@ mod tests { etree }; - let symbolic = crate::lu::supernodal::factorize_supernodal_symbolic_lu::( - *A, - Some(col_perm), - &min_col, - etree, - &col_counts, - PodStack::new(&mut GlobalPodBuffer::new( - super::supernodal::factorize_supernodal_symbolic_lu_req::(m, n).unwrap(), - )), - crate::SymbolicSupernodalParams { - relax: Some(&[(4, 1.0), (16, 0.8), (48, 0.1), (usize::MAX, 0.05)]), - }, - ) - .unwrap(); + let symbolic = + crate::sparse::linalg::lu::supernodal::factorize_supernodal_symbolic_lu::( + A.symbolic(), + Some(col_perm), + &min_col, + etree, + &col_counts, + PodStack::new(&mut GlobalPodBuffer::new( + super::supernodal::factorize_supernodal_symbolic_lu_req::(m, n).unwrap(), + )), + crate::sparse::linalg::SymbolicSupernodalParams { + relax: Some(&[(4, 1.0), (16, 0.8), (48, 0.1), (usize::MAX, 0.05)]), + }, + ) + .unwrap(); let mut lu = SupernodalLu::::new(); factorize_supernodal_numeric_lu( @@ -2434,7 +2441,7 @@ mod tests { &mut lu, A, AT, - col_perm.cast(), + col_perm, &symbolic, Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( @@ -2448,14 +2455,14 @@ mod tests { let mut work = rhs.clone(); let A_dense = sparse_to_dense(A); - let row_perm = PermutationRef::<'_, _, Symbolic>::new_checked(&row_perm, &row_perm_inv); + let row_perm = PermRef::<'_, _>::new_checked(&row_perm, &row_perm_inv); { let mut x = rhs.clone(); lu.solve_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::No, x.as_mut(), Parallelism::None, @@ -2467,8 +2474,8 @@ mod tests { let mut x = rhs.clone(); lu.solve_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::Yes, x.as_mut(), Parallelism::None, @@ -2480,8 +2487,8 @@ mod tests { let mut x = rhs.clone(); lu.solve_transpose_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::No, x.as_mut(), Parallelism::None, @@ -2493,8 +2500,8 @@ mod tests { let mut x = rhs.clone(); lu.solve_transpose_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::Yes, x.as_mut(), Parallelism::None, @@ -2506,7 +2513,7 @@ mod tests { #[test] fn test_numeric_lu_simplicial() { - type E = faer_core::c64; + type E = crate::complex_native::c64; let (m, n, col_ptr, row_ind, val) = load_mtx::(MtxData::from_file("test_data/YAO.mtx").unwrap()); @@ -2527,7 +2534,7 @@ mod tests { col_perm[i] = i; col_perm_inv[i] = i; } - let col_perm = PermutationRef::<'_, usize, Symbolic>::new_checked(&col_perm, &col_perm_inv); + let col_perm = PermRef::<'_, usize>::new_checked(&col_perm, &col_perm_inv); let mut lu = SimplicialLu::::new(); factorize_simplicial_numeric_lu( @@ -2535,7 +2542,7 @@ mod tests { &mut row_perm_inv, &mut lu, A, - col_perm.cast(), + col_perm, PodStack::new(&mut GlobalPodBuffer::new( factorize_simplicial_numeric_lu_req::(m, n).unwrap(), )), @@ -2547,14 +2554,14 @@ mod tests { let mut work = rhs.clone(); let A_dense = sparse_to_dense(A); - let row_perm = PermutationRef::<'_, _, Symbolic>::new_checked(&row_perm, &row_perm_inv); + let row_perm = PermRef::<'_, _>::new_checked(&row_perm, &row_perm_inv); { let mut x = rhs.clone(); lu.solve_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::No, x.as_mut(), Parallelism::None, @@ -2566,8 +2573,8 @@ mod tests { let mut x = rhs.clone(); lu.solve_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::Yes, x.as_mut(), Parallelism::None, @@ -2580,8 +2587,8 @@ mod tests { let mut x = rhs.clone(); lu.solve_transpose_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::No, x.as_mut(), Parallelism::None, @@ -2593,8 +2600,8 @@ mod tests { let mut x = rhs.clone(); lu.solve_transpose_in_place_with_conj( - row_perm.cast(), - col_perm.cast(), + row_perm, + col_perm, Conj::Yes, x.as_mut(), Parallelism::None, @@ -2606,7 +2613,7 @@ mod tests { #[test] fn test_solver_lu_simplicial() { - type E = faer_core::c64; + type E = crate::complex_native::c64; let (m, n, col_ptr, row_ind, val) = load_mtx::(MtxData::from_file("test_data/YAO.mtx").unwrap()); @@ -2652,7 +2659,7 @@ mod tests { { let mut x = rhs.clone(); lu.solve_in_place_with_conj( - faer_core::Conj::No, + crate::Conj::No, x.as_mut(), Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( @@ -2668,7 +2675,7 @@ mod tests { { let mut x = rhs.clone(); lu.solve_in_place_with_conj( - faer_core::Conj::Yes, + crate::Conj::Yes, x.as_mut(), Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( @@ -2685,7 +2692,7 @@ mod tests { { let mut x = rhs.clone(); lu.solve_transpose_in_place_with_conj( - faer_core::Conj::No, + crate::Conj::No, x.as_mut(), Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( @@ -2701,7 +2708,7 @@ mod tests { { let mut x = rhs.clone(); lu.solve_transpose_in_place_with_conj( - faer_core::Conj::Yes, + crate::Conj::Yes, x.as_mut(), Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( diff --git a/faer-libs/faer-sparse/src/lib.rs b/src/sparse/linalg/mod.rs similarity index 78% rename from faer-libs/faer-sparse/src/lib.rs rename to src/sparse/linalg/mod.rs index 1807e1359d49ba227bf221b998cacbfc126dbbd8..c9eeb9dedc5dd56b68b407f08ecdb17c6ae72d59 100644 --- a/faer-libs/faer-sparse/src/lib.rs +++ b/src/sparse/linalg/mod.rs @@ -1,25 +1,26 @@ -#![cfg_attr(not(feature = "std"), no_std)] -#![allow(clippy::missing_safety_doc)] -#![allow(clippy::type_complexity)] -#![allow(clippy::too_many_arguments)] -#![forbid(elided_lifetimes_in_paths)] -#![allow(non_snake_case)] - -use bytemuck::Pod; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ - group_helpers::*, - permutation::PermutationRef, +use crate::{ + perm::PermRef, sparse::{windows2, *}, Entity, Side, }; +use bytemuck::Pod; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; const CHOLESKY_SUPERNODAL_RATIO_FACTOR: f64 = 40.0; const QR_SUPERNODAL_RATIO_FACTOR: f64 = 40.0; const LU_SUPERNODAL_RATIO_FACTOR: f64 = 40.0; +/// Tuning parameters for the supernodal factorizations. #[derive(Copy, Clone, Debug)] pub struct SymbolicSupernodalParams<'a> { + /// Supernode relaxation thresholds. + /// + /// Let `n` be the total number of columns in two adjacent supernodes. + /// Let `z` be the fraction of zero entries in the two supernodes if they + /// are merged (z includes zero entries from prior amalgamations). The + /// two supernodes are merged if: + /// + /// `(n <= relax[0].0 && z < relax[0].1) || (n <= relax[1].0 && z < relax[1].1) || ...` pub relax: Option<&'a [(usize, f64)]>, } @@ -32,6 +33,12 @@ impl Default for SymbolicSupernodalParams<'_> { } } +/// Nonnegative threshold controlling when the supernodal factorization is used. +/// +/// Increasing it makes it more likely for the simplicial factorization to be used, +/// while decreasing it makes it more likely for the supernodal factorization to be used. +/// +/// A value of `1.0` is the default. #[derive(Copy, Clone, Debug, PartialEq)] pub struct SupernodalThreshold(pub f64); @@ -43,20 +50,17 @@ impl Default for SupernodalThreshold { } impl SupernodalThreshold { + /// Simplicial factorization is always selected. pub const FORCE_SIMPLICIAL: Self = Self(f64::INFINITY); + /// Supernodal factorization is always selected. pub const FORCE_SUPERNODAL: Self = Self(0.0); + /// Determine automatically which variant to select. pub const AUTO: Self = Self(1.0); } -use faer_core::sparse::util::{ - ghost_adjoint_symbolic, ghost_permute_hermitian_unsorted, - ghost_permute_hermitian_unsorted_symbolic, transpose, -}; +use super::utils::{ghost_permute_hermitian_unsorted, ghost_permute_hermitian_unsorted_symbolic}; -pub use faer_core::{ - permutation::{Index, SignedIndex}, - FaerError, -}; +pub use crate::{sparse::FaerError, Index, SignedIndex}; #[allow(unused_macros)] macro_rules! shadow { @@ -155,26 +159,295 @@ macro_rules! monomorphize_test { }; } -extern crate alloc; - +/// Solving sparse triangular linear systems with a dense right-hand-side. pub mod triangular_solve; pub mod amd; pub mod colamd; pub mod cholesky; - -#[doc(hidden)] pub mod lu; -#[doc(hidden)] pub mod qr; -#[doc(hidden)] -pub mod superlu; +/// Sparse LU error. +#[derive(Copy, Clone, Debug)] +pub enum LuError { + /// Generic sparse error. + Generic(FaerError), + /// Rank deficient symbolic structure. + /// + /// Contains the iteration at which a pivot could not be found. + SymbolicSingular(usize), +} + +impl core::fmt::Display for LuError { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Debug::fmt(self, f) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for LuError {} -mod ghost; +impl From for LuError { + #[inline] + fn from(value: FaerError) -> Self { + Self::Generic(value) + } +} -mod mem; +/// Sparse Cholesky error. +#[derive(Copy, Clone, Debug)] +pub enum CholeskyError { + /// Generic sparse error. + Generic(FaerError), + /// Rank deficient symbolic structure. + SymbolicSingular, + /// Matrix is not positive definite. + NotPositiveDefinite, +} + +impl core::fmt::Display for CholeskyError { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Debug::fmt(self, f) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for CholeskyError {} + +impl From for CholeskyError { + #[inline] + fn from(value: FaerError) -> Self { + Self::Generic(value) + } +} + +impl From for CholeskyError { + #[inline] + fn from(_: crate::linalg::solvers::CholeskyError) -> Self { + Self::NotPositiveDefinite + } +} + +/// High level sparse solvers. +pub mod solvers; + +mod ghost { + pub use crate::utils::constrained::{perm::*, sparse::*, *}; + use crate::Index; + + pub const NONE_BYTE: u8 = u8::MAX; + + #[inline] + pub fn with_size(n: usize, f: impl FnOnce(Size<'_>) -> R) -> R { + Size::with(n, f) + } + + #[inline] + pub fn fill_zero<'n, 'a, I: Index>(slice: &'a mut [I], size: Size<'n>) -> &'a mut [Idx<'n, I>] { + let len = slice.len(); + if len > 0 { + assert!(*size > 0); + } + unsafe { + core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len); + &mut *(slice as *mut _ as *mut _) + } + } + + #[inline] + pub fn fill_none<'n, 'a, I: Index>( + slice: &'a mut [I::Signed], + size: Size<'n>, + ) -> &'a mut [MaybeIdx<'n, I>] { + let _ = size; + let len = slice.len(); + unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), NONE_BYTE, len) }; + unsafe { &mut *(slice as *mut _ as *mut _) } + } + + #[inline] + pub fn copy_slice<'n, 'a, I: Index>( + dst: &'a mut [I], + src: &[Idx<'n, I>], + ) -> &'a mut [Idx<'n, I>] { + let dst: &mut [Idx<'_, I>] = unsafe { &mut *(dst as *mut _ as *mut _) }; + dst.copy_from_slice(src); + dst + } +} + +mod mem { + use crate::SignedIndex; + + pub const NONE_BYTE: u8 = 0xFF; + pub const NONE: usize = crate::sparse::repeat_byte(NONE_BYTE); + + #[inline] + pub fn fill_none(slice: &mut [I]) { + let len = slice.len(); + unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), NONE_BYTE, len) } + } + #[inline] + pub fn fill_zero(slice: &mut [I]) { + let len = slice.len(); + unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len) } + } +} + +/// Sparse matrix multiplication. +pub mod matmul { + // TODO: sparse_sparse_matmul + // + // PERF: optimize matmul + // - parallelization + // - simd(?) + + use super::*; + use crate::{ + assert, + utils::constrained::{self, Size}, + }; + + /// Multiplies a sparse matrix `lhs` by a dense matrix `rhs`, and stores the result in + /// `acc`. See [`faer::linalg::matmul::matmul`](crate::linalg::matmul::matmul) for more details. + /// + /// # Note + /// Allows unsorted matrices. + #[track_caller] + pub fn sparse_dense_matmul< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, + >( + acc: MatMut<'_, E>, + lhs: SparseColMatRef<'_, I, LhsE>, + rhs: MatRef<'_, RhsE>, + alpha: Option, + beta: E, + parallelism: Parallelism, + ) { + assert!(all( + acc.nrows() == lhs.nrows(), + acc.ncols() == rhs.ncols(), + lhs.ncols() == rhs.nrows(), + )); + + let _ = parallelism; + let m = acc.nrows(); + let n = acc.ncols(); + let k = lhs.ncols(); + + let mut acc = acc; + + match alpha { + Some(alpha) => { + if alpha != E::faer_one() { + zipped!(acc.rb_mut()) + .for_each(|unzipped!(mut dst)| dst.write(dst.read().faer_mul(alpha))) + } + } + None => acc.fill_zero(), + } + + Size::with2(m, n, |m, n| { + Size::with(k, |k| { + let mut acc = constrained::mat::MatMut::new(acc, m, n); + let lhs = constrained::sparse::SparseColMatRef::new(lhs, m, k); + let rhs = constrained::mat::MatRef::new(rhs, k, n); + + for j in n.indices() { + for depth in k.indices() { + let rhs_kj = rhs.read(depth, j).canonicalize().faer_mul(beta); + for (i, lhs_ik) in zip( + lhs.row_indices_of_col(depth), + SliceGroup::<'_, LhsE>::new(lhs.values_of_col(depth)).into_ref_iter(), + ) { + acc.write( + i, + j, + acc.read(i, j) + .faer_add(lhs_ik.read().canonicalize().faer_mul(rhs_kj)), + ); + } + } + } + }); + }); + } + + /// Multiplies a dense matrix `lhs` by a sparse matrix `rhs`, and stores the result in + /// `acc`. See [`faer::linalg::matmul::matmul`](crate::linalg::matmul::matmul) for more details. + /// + /// # Note + /// Allows unsorted matrices. + #[track_caller] + pub fn dense_sparse_matmul< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, + >( + acc: MatMut<'_, E>, + lhs: MatRef<'_, LhsE>, + rhs: SparseColMatRef<'_, I, RhsE>, + alpha: Option, + beta: E, + parallelism: Parallelism, + ) { + assert!(all( + acc.nrows() == lhs.nrows(), + acc.ncols() == rhs.ncols(), + lhs.ncols() == rhs.nrows(), + )); + + let _ = parallelism; + let m = acc.nrows(); + let n = acc.ncols(); + let k = lhs.ncols(); + + let mut acc = acc; + + match alpha { + Some(alpha) => { + if alpha != E::faer_one() { + zipped!(acc.rb_mut()) + .for_each(|unzipped!(mut dst)| dst.write(dst.read().faer_mul(alpha))) + } + } + None => acc.fill_zero(), + } + + Size::with2(m, n, |m, n| { + Size::with(k, |k| { + let mut acc = constrained::mat::MatMut::new(acc, m, n); + let lhs = constrained::mat::MatRef::new(lhs, m, k); + let rhs = constrained::sparse::SparseColMatRef::new(rhs, k, n); + + for i in m.indices() { + for j in n.indices() { + let mut acc_ij = E::faer_zero(); + for (depth, rhs_kj) in zip( + rhs.row_indices_of_col(j), + SliceGroup::<'_, RhsE>::new(rhs.values_of_col(j)).into_ref_iter(), + ) { + let lhs_ik = lhs.read(i, depth); + acc_ij = acc_ij.faer_add( + lhs_ik.canonicalize().faer_mul(rhs_kj.read().canonicalize()), + ); + } + + acc.write(i, j, acc.read(i, j).faer_add(beta.faer_mul(acc_ij))); + } + } + }); + }); + } +} #[cfg(test)] pub(crate) mod qd { @@ -566,7 +839,7 @@ pub(crate) mod qd { mod faer_impl { use super::*; - use faer_core::{ComplexField, Conjugate, Entity, RealField}; + use crate::{ComplexField, Conjugate, Entity, RealField}; unsafe impl Entity for Double { type Unit = f64; diff --git a/faer-libs/faer-sparse/src/qr.rs b/src/sparse/linalg/qr.rs similarity index 92% rename from faer-libs/faer-sparse/src/qr.rs rename to src/sparse/linalg/qr.rs index 35be044d3d8b18a056738606d9f76dfb8b8ebad4..0cdd630e629ada37b22d517a503e5e9450e5fdbe 100644 --- a/faer-libs/faer-sparse/src/qr.rs +++ b/src/sparse/linalg/qr.rs @@ -1,12 +1,13 @@ -//! Computes the QR decomposition of a given sparse matrix. See [`faer_qr`] for more info. +//! Computes the QR decomposition of a given sparse matrix. See [`crate::linalg::qr`] for more info. //! //! The entry point in this module is [`SymbolicQr`] and [`factorize_symbolic_qr`]. //! //! # Warning //! The functions in this module accept unsorted input, and always produce unsorted decomposition //! factors. +#![allow(missing_docs)] -use crate::{ +use super::{ cholesky::{ ghost_postorder, simplicial::EliminationTreeRef, @@ -17,24 +18,27 @@ use crate::{ mem::{self, NONE}, nomem, try_zeroed, FaerError, Index, SupernodalThreshold, SymbolicSupernodalParams, }; -use core::iter::zip; -use dyn_stack::{PodStack, SizeOverflow, StackReq}; -use faer_core::{ +use crate::{ assert, - constrained::Size, - group_helpers::{SliceGroup, SliceGroupMut}, - householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, - permutation::{PermutationRef, SignedIndex}, + linalg::{ + householder::apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj, + temp_mat_req, temp_mat_uninit, + }, + perm::PermRef, sparse::{SparseColMatRef, SymbolicSparseColMatRef}, - temp_mat_req, temp_mat_uninit, unzipped, zipped, Conj, MatMut, Parallelism, + unzipped, + utils::{constrained::Size, slice::*}, + zipped, Conj, MatMut, Parallelism, SignedIndex, }; +use core::iter::zip; +use dyn_stack::{PodStack, SizeOverflow, StackReq}; use faer_entity::*; use reborrow::*; #[inline] -pub(crate) fn ghost_col_etree<'m, 'n, I: Index>( - A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - col_perm: Option>, +pub(crate) fn ghost_col_etree<'n, I: Index>( + A: ghost::SymbolicSparseColMatRef<'_, 'n, '_, I>, + col_perm: Option>, etree: &mut Array<'n, I::Signed>, stack: PodStack<'_>, ) { @@ -51,9 +55,7 @@ pub(crate) fn ghost_col_etree<'m, 'n, I: Index>( mem::fill_none(etree.as_mut()); for j in N.indices() { - let pj = col_perm - .map(|perm| perm.into_arrays().0[j].zx()) - .unwrap_or(j); + let pj = col_perm.map(|perm| perm.arrays().0[j].zx()).unwrap_or(j); for i_ in A.row_indices_of_col(pj) { let mut i = prev[i_].sx(); while let Some(i_) = i.idx() { @@ -77,14 +79,14 @@ pub(crate) fn ghost_col_etree<'m, 'n, I: Index>( #[inline] pub fn col_etree<'out, I: Index>( A: SymbolicSparseColMatRef<'_, I>, - col_perm: Option>, + col_perm: Option>, etree: &'out mut [I], stack: PodStack<'_>, ) -> EliminationTreeRef<'out, I> { Size::with2(A.nrows(), A.ncols(), |M, N| { ghost_col_etree( ghost::SymbolicSparseColMatRef::new(A, M, N), - col_perm.map(|perm| ghost::PermutationRef::new(perm, N)), + col_perm.map(|perm| ghost::PermRef::new(perm, N)), Array::from_mut(bytemuck::cast_slice_mut(etree), N), stack, ); @@ -132,7 +134,7 @@ pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( col_counts: &mut Array<'m, I>, min_row: &mut Array<'n, I::Signed>, A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, - row_perm: Option>, + row_perm: Option>, etree: &Array<'m, MaybeIdx<'m, I>>, post: &Array<'m, Idx<'m, I>>, stack: PodStack<'_>, @@ -163,7 +165,7 @@ pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( for j in N.indices() { if let Some(perm) = row_perm { - let inv = perm.into_arrays().1; + let inv = perm.arrays().1; min_row[j] = match Iterator::min(A.row_indices_of_col(j).map(|j| inv[j].zx())) { Some(first_row) => I::Signed::truncate(*first_row), None => *MaybeIdx::<'_, I>::none(), @@ -176,7 +178,7 @@ pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( } let min_row = if let Some(perm) = row_perm { - let inv = perm.into_arrays().1; + let inv = perm.arrays().1; Iterator::min(A.row_indices_of_col(j).map(|row| post_inv[inv[row].zx()])) } else { Iterator::min(A.row_indices_of_col(j).map(|row| post_inv[row])) @@ -234,9 +236,7 @@ pub(crate) fn ghost_column_counts_aat<'m, 'n, I: Index>( while let Some(j_) = j.idx() { for i in A.row_indices_of_col(j_) { - let i = row_perm - .map(|perm| perm.into_arrays().1[i].zx()) - .unwrap_or(i); + let i = row_perm.map(|perm| perm.arrays().1[i].zx()).unwrap_or(i); let lca = ghost_least_common_ancestor::(i, pk, first, max_first, prev_leaf, ancestor); @@ -268,7 +268,7 @@ pub fn column_counts_aat<'m, 'n, I: Index>( col_counts: &mut [I], min_row: &mut [I], A: SymbolicSparseColMatRef<'_, I>, - row_perm: Option>, + row_perm: Option>, etree: EliminationTreeRef<'_, I>, post: &[I], stack: PodStack<'_>, @@ -279,7 +279,7 @@ pub fn column_counts_aat<'m, 'n, I: Index>( Array::from_mut(col_counts, M), Array::from_mut(bytemuck::cast_slice_mut(min_row), N), A, - row_perm.map(|perm| ghost::PermutationRef::new(perm, M)), + row_perm.map(|perm| ghost::PermRef::new(perm, M)), etree.ghost_inner(M), Array::from_ref(Idx::from_slice_ref_checked(post, M), M), stack, @@ -296,7 +296,7 @@ pub fn postorder(post: &mut [I], etree: EliminationTreeRef<'_, I>, sta pub mod supernodal { use super::*; - use faer_core::assert; + use crate::assert; #[derive(Debug)] pub struct SymbolicSupernodalHouseholder { @@ -349,7 +349,7 @@ pub mod supernodal { } #[derive(Debug)] - pub struct SymbolicSupernodalQr { + pub struct SymbolicSupernodalQr { L: SymbolicSupernodalCholesky, H: SymbolicSupernodalHouseholder, min_col: alloc::vec::Vec, @@ -386,7 +386,7 @@ pub mod supernodal { let s_h_row_full_end = H_symbolic.col_ptrs_for_row_indices[s + 1].zx(); let max_blocksize = H_symbolic.max_blocksize[s].zx(); - loop_req = loop_req.try_or(faer_core::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::(s_h_row_full_end - s_h_row_begin, max_blocksize, rhs_ncols)?)?; + loop_req = loop_req.try_or(crate::linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::(s_h_row_full_end - s_h_row_begin, max_blocksize, rhs_ncols)?)?; } Ok(loop_req) @@ -398,12 +398,14 @@ pub mod supernodal { ncols: usize, ) -> Result { let _ = nrows; - crate::cholesky::supernodal::factorize_supernodal_symbolic_cholesky_req::(ncols) + crate::sparse::linalg::cholesky::supernodal::factorize_supernodal_symbolic_cholesky_req::( + ncols, + ) } pub fn factorize_supernodal_symbolic_qr( A: SymbolicSparseColMatRef<'_, I>, - col_perm: Option>, + col_perm: Option>, min_col: alloc::vec::Vec, etree: EliminationTreeRef<'_, I>, col_counts: &[I], @@ -421,13 +423,13 @@ pub mod supernodal { MaybeIdx::from_slice_ref_checked(bytemuck::cast_slice(&min_col), N), M, ); - let L = crate::cholesky::supernodal::ghost_factorize_supernodal_symbolic( + let L = crate::sparse::linalg::cholesky::supernodal::ghost_factorize_supernodal_symbolic( A, - col_perm.map(|perm| ghost::PermutationRef::new(perm, N)), + col_perm.map(|perm| ghost::PermRef::new(perm, N)), Some(min_col), - crate::cholesky::supernodal::CholeskyInput::ATA, + crate::sparse::linalg::cholesky::supernodal::CholeskyInput::ATA, etree, - Array::from_ref(&col_counts, N), + Array::from_ref(col_counts, N), stack.rb_mut(), params, )?; @@ -508,7 +510,7 @@ pub mod supernodal { let (index_to_super, _) = stack.make_raw::(*N); for s in N_SUPERNODES.indices() { - index_to_super.as_mut()[supernode_begin[s].zx()..supernode_end[s].zx()] + index_to_super[supernode_begin[s].zx()..supernode_end[s].zx()] .fill(*s.truncate::()); } let index_to_super = @@ -567,10 +569,9 @@ pub mod supernodal { + (L_col_ptrs_for_row_indices[*s + 1] - L_col_ptrs_for_row_indices[*s]); val_count += to_wide(s_row_count) * to_wide(s_col_count); row_count += to_wide(s_row_count); - let blocksize = faer_qr::no_pivoting::compute::recommended_blocksize::( - s_row_count.zx(), - s_col_count.zx(), - ) as u128; + let blocksize = crate::linalg::qr::no_pivoting::compute::recommended_blocksize::< + Symbolic, + >(s_row_count.zx(), s_col_count.zx()) as u128; max_blocksize[*s] = from_wide(blocksize); tau_count += blocksize * to_wide(Ord::min(s_row_count, s_col_count)); *next_val_ptr = from_wide(val_count); @@ -593,7 +594,7 @@ pub mod supernodal { } #[derive(Debug)] - pub struct SupernodalQrRef<'a, I, E: Entity> { + pub struct SupernodalQrRef<'a, I: Index, E: Entity> { symbolic: &'a SymbolicSupernodalQr, rt_values: SliceGroup<'a, E>, householder_values: SliceGroup<'a, E>, @@ -604,8 +605,8 @@ pub mod supernodal { householder_ncols: &'a [I], } - impl Copy for SupernodalQrRef<'_, I, E> {} - impl Clone for SupernodalQrRef<'_, I, E> { + impl Copy for SupernodalQrRef<'_, I, E> {} + impl Clone for SupernodalQrRef<'_, I, E> { #[inline] fn clone(&self) -> Self { *self @@ -727,7 +728,7 @@ pub mod supernodal { ..H_symbolic.col_ptrs_for_values[s + 1].zx(), ); - let s_H = faer_core::mat::from_column_major_slice::<'_, E>( + let s_H = crate::mat::from_column_major_slice::<'_, E>( s_H.into_inner(), s_h_row_full_end - s_h_row_begin, s_ncols @@ -736,7 +737,7 @@ pub mod supernodal { ); let s_tau = tau.subslice(tau_begin..tau_end); let max_blocksize = H_symbolic.max_blocksize[s].zx(); - let s_tau = faer_core::mat::from_column_major_slice::<'_, E>( + let s_tau = crate::mat::from_column_major_slice::<'_, E>( s_tau.into_inner(), max_blocksize, Ord::min(s_H.ncols(), s_h_row_full_end - s_h_row_begin), @@ -755,7 +756,7 @@ pub mod supernodal { apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj( b_H.rb(), b_tau.rb(), - faer_core::Conj::Yes.compose(conj), + crate::Conj::Yes.compose(conj), tmp.rb_mut().subrows_mut(start, nrows), parallelism, stack.rb_mut(), @@ -803,7 +804,7 @@ pub mod supernodal { } let mut x_top = x.rb_mut().subrows_mut(s.start(), size); - faer_core::mul::matmul_with_conj( + crate::linalg::matmul::matmul_with_conj( x_top.rb_mut(), s_L_bot.transpose(), conj.compose(Conj::Yes), @@ -813,7 +814,7 @@ pub mod supernodal { E::faer_one().faer_neg(), parallelism, ); - faer_core::solve::solve_upper_triangular_in_place_with_conj( + crate::linalg::triangular_solve::solve_upper_triangular_in_place_with_conj( s_L_top.transpose(), conj.compose(Conj::Yes), x_top.rb_mut(), @@ -825,8 +826,8 @@ pub mod supernodal { } #[track_caller] - pub fn factorize_supernodal_numeric_qr_req<'a, I: Index, E: Entity>( - symbolic: &'a SymbolicSupernodalQr, + pub fn factorize_supernodal_numeric_qr_req( + symbolic: &SymbolicSupernodalQr, parallelism: Parallelism, ) -> Result { let n_supernodes = symbolic.L.n_supernodes(); @@ -852,15 +853,17 @@ pub mod supernodal { let s_pattern_len = symbolic.L.col_ptrs_for_row_indices()[s + 1].zx() - symbolic.L.col_ptrs_for_row_indices()[s].zx(); - loop_req = loop_req.try_or(faer_qr::no_pivoting::compute::qr_in_place_req::( - s_h_row_full_end - s_h_row_begin, - s_ncols + s_pattern_len, - max_blocksize, - parallelism, - Default::default(), - )?)?; + loop_req = loop_req.try_or( + crate::linalg::qr::no_pivoting::compute::qr_in_place_req::( + s_h_row_full_end - s_h_row_begin, + s_ncols + s_pattern_len, + max_blocksize, + parallelism, + Default::default(), + )?, + )?; - loop_req = loop_req.try_or(faer_core::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::(s_h_row_full_end - s_h_row_begin, max_blocksize, s_ncols + s_pattern_len)?)?; + loop_req = loop_req.try_or(crate::linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_req::(s_h_row_full_end - s_h_row_begin, max_blocksize, s_ncols + s_pattern_len)?)?; } init_req.try_and(loop_req) @@ -880,7 +883,7 @@ pub mod supernodal { tau_values: GroupFor, AT: SparseColMatRef<'_, I, E>, - col_perm: Option>, + col_perm: Option>, symbolic: &'a SymbolicSupernodalQr, parallelism: Parallelism, stack: PodStack<'_>, @@ -950,7 +953,7 @@ pub mod supernodal { tau_values: GroupFor, AT: SparseColMatRef<'_, I, E>, - col_perm: Option>, + col_perm: Option>, L_symbolic: &SymbolicSupernodalCholesky, H_symbolic: &SymbolicSupernodalHouseholder, min_col: &[I], @@ -1047,7 +1050,7 @@ pub mod supernodal { let (c_H, s_H) = s_H.split_at(H_symbolic.col_ptrs_for_values[s].zx()); let c_H = c_H.into_const(); - let mut s_H = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut s_H = crate::mat::from_column_major_slice_mut::<'_, E>( s_H.into_inner(), s_h_row_full_end - s_h_row_begin, s_ncols + s_pattern.len(), @@ -1094,9 +1097,7 @@ pub mod supernodal { AT.row_indices_of_col(i), SliceGroup::<'_, E>::new(AT.values_of_col(i)).into_ref_iter(), ) { - let pj = col_perm - .map(|perm| perm.into_arrays().1[j].zx()) - .unwrap_or(j); + let pj = col_perm.map(|perm| perm.arrays().1[j].zx()).unwrap_or(j); let ix = idx; let iy = col_global_to_local[pj].zx(); s_H.write(ix, iy, s_H.read(ix, iy).faer_add(value.read())); @@ -1123,7 +1124,7 @@ pub mod supernodal { H_symbolic.col_ptrs_for_values[child].zx() ..H_symbolic.col_ptrs_for_values[child + 1].zx(), ); - let c_H = faer_core::mat::from_column_major_slice::<'_, E>( + let c_H = crate::mat::from_column_major_slice::<'_, E>( c_H.into_inner(), H_symbolic.col_ptrs_for_row_indices[child + 1].zx() - c_h_row_begin, c_ncols + c_pattern.len(), @@ -1187,7 +1188,7 @@ pub mod supernodal { let s_L = L_values.rb_mut().subslice(L_begin..L_end); let max_blocksize = H_symbolic.max_blocksize[s].zx(); - let mut s_tau = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut s_tau = crate::mat::from_column_major_slice_mut::<'_, E>( s_tau.into_inner(), max_blocksize, Ord::min(s_H.ncols(), s_h_row_full_end - s_h_row_begin), @@ -1225,10 +1226,9 @@ pub mod supernodal { ); let (mut left, mut right) = s_H.split_at_col_mut(ncols); - let bs = faer_qr::no_pivoting::compute::recommended_blocksize::( - left.nrows(), - left.ncols(), - ); + let bs = crate::linalg::qr::no_pivoting::compute::recommended_blocksize::< + Symbolic, + >(left.nrows(), left.ncols()); let bs = Ord::min(max_blocksize, bs); tau_blocksize[block_count] = I::truncate(bs); householder_nrows[block_count] = I::truncate(nrows); @@ -1240,7 +1240,7 @@ pub mod supernodal { .subrows_mut(0, bs) .subcols_mut(current_start, ncols); - faer_qr::no_pivoting::compute::qr_in_place( + crate::linalg::qr::no_pivoting::compute::qr_in_place( left.rb_mut(), s_tau.rb_mut(), parallelism, @@ -1252,7 +1252,7 @@ pub mod supernodal { apply_block_householder_sequence_transpose_on_the_left_in_place_with_conj( left.rb(), s_tau.rb(), - faer_core::Conj::Yes, + crate::Conj::Yes, right.rb_mut(), parallelism, stack.rb_mut(), @@ -1265,7 +1265,7 @@ pub mod supernodal { } } - let mut s_L = faer_core::mat::from_column_major_slice_mut::<'_, E>( + let mut s_L = crate::mat::from_column_major_slice_mut::<'_, E>( s_L.into_inner(), s_pattern.len() + s_ncols, s_ncols, @@ -1276,7 +1276,7 @@ pub mod supernodal { s_H.rb().subrows(0, nrows) ) .for_each_triangular_upper( - faer_core::zip::Diag::Include, + crate::linalg::zip::Diag::Include, |unzipped!(mut dst, src)| dst.write(src.read().faer_conj()), ); } @@ -1341,10 +1341,8 @@ pub mod supernodal { } pub mod simplicial { - use crate::triangular_solve; - use super::*; - use faer_core::assert; + use crate::{assert, sparse::linalg::triangular_solve}; #[derive(Debug)] pub struct SymbolicSimplicialQr { @@ -1513,7 +1511,7 @@ pub mod simplicial { let hx = SliceGroup::<'_, E>::new(h.values_of_col(j)); let tau_inv = tau.read(j).faer_real().faer_inv(); - if hi.len() == 0 { + if hi.is_empty() { tmp.rb_mut().row_mut(j).fill_zero(); continue; } @@ -1596,10 +1594,10 @@ pub mod simplicial { h_non_zero_count[parent.zx()] += h_non_zero_count[j] - I::truncate(1); } - let h_nnz = I::sum_nonnegative(&h_non_zero_count) + let h_nnz = I::sum_nonnegative(h_non_zero_count) .ok_or(FaerError::IndexOverflow)? .zx(); - let l_nnz = I::sum_nonnegative(&col_counts) + let l_nnz = I::sum_nonnegative(col_counts) .ok_or(FaerError::IndexOverflow)? .zx(); @@ -1625,15 +1623,15 @@ pub mod simplicial { }) } - pub fn factorize_simplicial_numeric_qr_req<'a, I: Index, E: Entity>( - symbolic: &'a SymbolicSimplicialQr, + pub fn factorize_simplicial_numeric_qr_req( + symbolic: &SymbolicSimplicialQr, ) -> Result { let m = symbolic.nrows; StackReq::try_all_of([ StackReq::try_new::(m)?, StackReq::try_new::(m)?, StackReq::try_new::(m)?, - crate::make_raw_req::(m)?, + crate::sparse::linalg::make_raw_req::(m)?, ]) } @@ -1647,7 +1645,7 @@ pub mod simplicial { tau_values: GroupFor, A: SparseColMatRef<'_, I, E>, - col_perm: Option>, + col_perm: Option>, symbolic: &'a SymbolicSimplicialQr, stack: PodStack<'_>, ) -> SimplicialQrRef<'a, I, E> { @@ -1662,10 +1660,10 @@ pub mod simplicial { let (r_idx, stack) = stack.make_raw::(m); let (marked, stack) = stack.make_raw::(m); let (pattern, stack) = stack.make_raw::(m); - let (mut x, _) = crate::make_raw::(m, stack); + let (mut x, _) = crate::sparse::linalg::make_raw::(m, stack); x.fill_zero(); - crate::mem::fill_zero(marked); - crate::mem::fill_none(r_idx); + super::mem::fill_zero(marked); + super::mem::fill_none(r_idx); let mut r_values = SliceGroupMut::<'_, E>::new(r_values); let mut householder_values = SliceGroupMut::<'_, E>::new(householder_values); @@ -1675,9 +1673,7 @@ pub mod simplicial { let mut r_pos = 0usize; let mut h_pos = 0usize; for j in 0..n { - let pj = col_perm - .map(|perm| perm.into_arrays().0[j].zx()) - .unwrap_or(j); + let pj = col_perm.map(|perm| perm.arrays().0[j].zx()).unwrap_or(j); let mut pattern_len = 0usize; for (i, val) in zip( @@ -1766,7 +1762,7 @@ pub mod simplicial { continue; } - let mut h_col = faer_core::col::from_slice_mut::( + let mut h_col = crate::col::from_slice_mut::( householder_values .rb_mut() .subslice(h_begin..h_pos) @@ -1775,7 +1771,7 @@ pub mod simplicial { let (mut head, tail) = h_col.rb_mut().split_at_mut(1); let tail_norm = tail.norm_l2(); - let (tau, beta) = faer_core::householder::make_householder_in_place_v2( + let (tau, beta) = crate::linalg::householder::make_householder_in_place( Some(tail.as_2d_mut()), head.read(0), tail_norm, @@ -1821,14 +1817,14 @@ pub struct QrSymbolicParams<'a> { /// The inner factorization used for the symbolic QR, either simplicial or symbolic. #[derive(Debug)] -pub enum SymbolicQrRaw { +pub enum SymbolicQrRaw { Simplicial(simplicial::SymbolicSimplicialQr), Supernodal(supernodal::SymbolicSupernodalQr), } /// The symbolic structure of a sparse QR decomposition. #[derive(Debug)] -pub struct SymbolicQr { +pub struct SymbolicQr { raw: SymbolicQrRaw, col_perm_fwd: alloc::vec::Vec, col_perm_inv: alloc::vec::Vec, @@ -1890,7 +1886,7 @@ impl<'a, I: Index, E: Entity> QrRef<'a, I, E> { let (mut x, stack) = temp_mat_uninit::(m, k, stack); - let (_, inv) = self.symbolic.col_perm().into_arrays(); + let (_, inv) = self.symbolic.col_perm().arrays(); x.copy_from(rhs.rb()); let indices = self.indices; @@ -1995,8 +1991,8 @@ impl SymbolicQr { /// Returns the fill-reducing column permutation that was computed during symbolic analysis. #[inline] - pub fn col_perm(&self) -> PermutationRef<'_, I, Symbolic> { - unsafe { PermutationRef::new_unchecked(&self.col_perm_fwd, &self.col_perm_inv) } + pub fn col_perm(&self) -> PermRef<'_, I> { + unsafe { PermRef::new_unchecked(&self.col_perm_fwd, &self.col_perm_inv) } } /// Returns the length of the slice needed to store the numerical values of the QR @@ -2049,7 +2045,7 @@ impl SymbolicQr { let m = self.nrows(); let A_nnz = self.A_nnz; let AT_req = StackReq::try_all_of([ - crate::make_raw_req::(A_nnz)?, + crate::sparse::linalg::make_raw_req::(A_nnz)?, StackReq::try_new::(m + 1)?, StackReq::try_new::(A_nnz)?, ])?; @@ -2104,7 +2100,7 @@ impl SymbolicQr { householder_values.into_inner(), tau_values.into_inner(), A, - Some(self.col_perm().cast()), + Some(self.col_perm()), symbolic, stack, ); @@ -2124,9 +2120,10 @@ impl SymbolicQr { let (new_col_ptr, stack) = stack.make_raw::(m + 1); let (new_row_ind, stack) = stack.make_raw::(self.A_nnz); - let (new_values, mut stack) = crate::make_raw::(self.A_nnz, stack); + let (new_values, mut stack) = + crate::sparse::linalg::make_raw::(self.A_nnz, stack); - let AT = crate::transpose::( + let AT = crate::sparse::utils::transpose::( new_col_ptr, new_row_ind, new_values.into_inner(), @@ -2144,7 +2141,7 @@ impl SymbolicQr { householder_values.into_inner(), tau_values.into_inner(), AT, - Some(self.col_perm().cast()), + Some(self.col_perm()), symbolic, parallelism, stack, @@ -2216,14 +2213,16 @@ pub fn factorize_symbolic_qr( stack.rb_mut(), )?; - let col_perm = ghost::PermutationRef::new( - PermutationRef::new_checked(&col_perm_fwd, &col_perm_inv), - N, - ); + let col_perm = ghost::PermRef::new(PermRef::new_checked(&col_perm_fwd, &col_perm_inv), N); let (new_col_ptr, stack) = stack.make_raw::(m + 1); let (new_row_ind, mut stack) = stack.make_raw::(A_nnz); - let AT = crate::ghost_adjoint_symbolic(new_col_ptr, new_row_ind, A, stack.rb_mut()); + let AT = crate::sparse::utils::ghost_adjoint_symbolic( + new_col_ptr, + new_row_ind, + A, + stack.rb_mut(), + ); let (etree, stack) = stack.make_raw::(n); let (post, stack) = stack.make_raw::(n); @@ -2231,7 +2230,7 @@ pub fn factorize_symbolic_qr( let (h_col_counts, mut stack) = stack.make_raw::(n); ghost_col_etree(A, Some(col_perm), Array::from_mut(etree, N), stack.rb_mut()); - let etree_ = Array::from_ref(MaybeIdx::<'_, I>::from_slice_ref_checked(&etree, N), N); + let etree_ = Array::from_ref(MaybeIdx::<'_, I>::from_slice_ref_checked(etree, N), N); ghost_postorder(Array::from_mut(post, N), etree_, stack.rb_mut()); ghost_column_counts_aat( @@ -2240,7 +2239,7 @@ pub fn factorize_symbolic_qr( AT, Some(col_perm), etree_, - Array::from_ref(Idx::from_slice_ref_checked(&post, N), N), + Array::from_ref(Idx::from_slice_ref_checked(post, N), N), stack.rb_mut(), ); let min_col = min_row; @@ -2274,7 +2273,7 @@ pub fn factorize_symbolic_qr( nnz += hj + rj; } - if flops / nnz > threshold.0 * crate::QR_SUPERNODAL_RATIO_FACTOR { + if flops / nnz > threshold.0 * crate::sparse::linalg::QR_SUPERNODAL_RATIO_FACTOR { threshold = SupernodalThreshold::FORCE_SUPERNODAL; } else { threshold = SupernodalThreshold::FORCE_SIMPLICIAL; @@ -2286,8 +2285,8 @@ pub fn factorize_symbolic_qr( A.into_inner(), Some(col_perm.into_inner()), min_col, - EliminationTreeRef::<'_, I> { inner: &etree }, - &col_counts, + EliminationTreeRef::<'_, I> { inner: etree }, + col_counts, stack.rb_mut(), params.supernodal_params, )?; @@ -2300,8 +2299,8 @@ pub fn factorize_symbolic_qr( } else { let symbolic = simplicial::factorize_simplicial_symbolic_qr::( &min_col, - EliminationTreeRef::<'_, I> { inner: &etree }, - &col_counts, + EliminationTreeRef::<'_, I> { inner: etree }, + col_counts, stack.rb_mut(), )?; Ok(SymbolicQr { @@ -2318,32 +2317,35 @@ pub fn factorize_symbolic_qr( mod tests { use super::*; use crate::{ - cholesky::{ - ghost_postorder, - simplicial::EliminationTreeRef, - supernodal::{SupernodalLdltRef, SymbolicSupernodalCholesky}, - }, - ghost_adjoint_symbolic, - qr::{ - simplicial::{ - factorize_simplicial_numeric_qr_req, factorize_simplicial_numeric_qr_unsorted, - factorize_simplicial_symbolic_qr, - }, - supernodal::{ - factorize_supernodal_numeric_qr, factorize_supernodal_numeric_qr_req, - factorize_supernodal_symbolic_qr, + assert, + complex_native::c64, + sparse::{ + linalg::{ + cholesky::{ + ghost_postorder, + simplicial::EliminationTreeRef, + supernodal::{SupernodalLdltRef, SymbolicSupernodalCholesky}, + }, + qr::{ + simplicial::{ + factorize_simplicial_numeric_qr_req, + factorize_simplicial_numeric_qr_unsorted, factorize_simplicial_symbolic_qr, + }, + supernodal::{ + factorize_supernodal_numeric_qr, factorize_supernodal_numeric_qr_req, + factorize_supernodal_symbolic_qr, + }, + }, + SymbolicSparseColMatRef, }, + utils::ghost_adjoint_symbolic, + SparseColMatRef, }, - SymbolicSparseColMatRef, + utils::slice::{SliceGroup, SliceGroupMut}, + Mat, }; use core::iter::zip; use dyn_stack::{GlobalPodBuffer, StackReq}; - use faer_core::{ - assert, c64, - group_helpers::{SliceGroup, SliceGroupMut}, - sparse::SparseColMatRef, - Mat, - }; use matrix_market_rs::MtxData; use rand::{Rng, SeedableRng}; @@ -2532,7 +2534,7 @@ mod tests { let mut new_row_ind = vec![zero; nnz]; let mut new_values = vec![0.0; nnz]; - let AT = faer_core::sparse::util::ghost_adjoint( + let AT = crate::sparse::utils::ghost_adjoint( &mut new_col_ptrs, &mut new_row_ind, SliceGroupMut::<'_, f64>::new(&mut new_values), @@ -2572,7 +2574,7 @@ mod tests { let min_col = min_row; let symbolic = factorize_supernodal_symbolic_qr::( - *A.into_inner(), + A.symbolic().into_inner(), None, min_col, EliminationTreeRef::<'_, I> { inner: &etree }, @@ -2604,11 +2606,11 @@ mod tests { AT.into_inner(), None, &symbolic, - faer_core::Parallelism::None, + crate::Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( factorize_supernodal_numeric_qr_req::( &symbolic, - faer_core::Parallelism::None, + crate::Parallelism::None, ) .unwrap(), )), @@ -2651,7 +2653,7 @@ mod tests { let mut new_row_ind = vec![zero; nnz]; let mut new_values = vec![E::faer_zero(); nnz]; - let AT = faer_core::sparse::util::ghost_transpose( + let AT = crate::sparse::utils::ghost_transpose( &mut new_col_ptrs, &mut new_row_ind, SliceGroupMut::<'_, E>::new(&mut new_values), @@ -2695,7 +2697,7 @@ mod tests { let min_col = min_row; let symbolic = factorize_supernodal_symbolic_qr::( - *A.into_inner(), + A.symbolic().into_inner(), None, min_col, EliminationTreeRef::<'_, I> { inner: &etree }, @@ -2730,11 +2732,11 @@ mod tests { AT.into_inner(), None, &symbolic, - faer_core::Parallelism::None, + crate::Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( factorize_supernodal_numeric_qr_req::( &symbolic, - faer_core::Parallelism::None, + crate::Parallelism::None, ) .unwrap(), )), @@ -2746,13 +2748,13 @@ mod tests { let mut x = rhs.clone(); let mut work = rhs.clone(); qr.solve_in_place_with_conj( - faer_core::Conj::No, + crate::Conj::No, x.as_mut(), - faer_core::Parallelism::None, + crate::Parallelism::None, work.as_mut(), PodStack::new(&mut GlobalPodBuffer::new( symbolic - .solve_in_place_req::(2, faer_core::Parallelism::None) + .solve_in_place_req::(2, crate::Parallelism::None) .unwrap(), )), ); @@ -2797,7 +2799,7 @@ mod tests { let mut new_row_ind = vec![zero; nnz]; let mut new_values = vec![E::faer_zero(); nnz]; - let AT = faer_core::sparse::util::ghost_transpose( + let AT = crate::sparse::utils::ghost_transpose( &mut new_col_ptrs, &mut new_row_ind, SliceGroupMut::<'_, E>::new(&mut new_values), @@ -2878,9 +2880,9 @@ mod tests { let mut x = rhs.clone(); let mut work = rhs.clone(); qr.solve_in_place_with_conj( - faer_core::Conj::No, + crate::Conj::No, x.as_mut(), - faer_core::Parallelism::None, + crate::Parallelism::None, work.as_mut(), ); @@ -2892,9 +2894,9 @@ mod tests { let mut x = rhs.clone(); let mut work = rhs.clone(); qr.solve_in_place_with_conj( - faer_core::Conj::Yes, + crate::Conj::Yes, x.as_mut(), - faer_core::Parallelism::None, + crate::Parallelism::None, work.as_mut(), ); @@ -2946,7 +2948,7 @@ mod tests { let mut new_row_ind = vec![I(0); nnz]; let mut new_values = vec![E::faer_zero(); nnz]; - let AT = crate::transpose::( + let AT = crate::sparse::utils::transpose::( &mut new_col_ptrs, &mut new_row_ind, SliceGroupMut::<'_, E>::new(&mut new_values).into_inner(), @@ -2990,9 +2992,9 @@ mod tests { { let mut x = rhs.clone(); qr.solve_in_place_with_conj( - faer_core::Conj::No, + crate::Conj::No, x.as_mut(), - faer_core::Parallelism::None, + crate::Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( symbolic .solve_in_place_req::(2, Parallelism::None) @@ -3007,9 +3009,9 @@ mod tests { { let mut x = rhs.clone(); qr.solve_in_place_with_conj( - faer_core::Conj::Yes, + crate::Conj::Yes, x.as_mut(), - faer_core::Parallelism::None, + crate::Parallelism::None, PodStack::new(&mut GlobalPodBuffer::new( symbolic .solve_in_place_req::(2, Parallelism::None) diff --git a/src/sparse/linalg/solvers.rs b/src/sparse/linalg/solvers.rs new file mode 100644 index 0000000000000000000000000000000000000000..396c4d0a9d9811ef5e9f36dd5395a8f967998d2e --- /dev/null +++ b/src/sparse/linalg/solvers.rs @@ -0,0 +1,961 @@ +use super::*; +use crate::mat::{AsMatMut, AsMatRef}; + +/// Object-safe base for [`SpSolver`] +pub trait SpSolverCore { + /// Returns the number of rows of the matrix used to construct this decomposition. + fn nrows(&self) -> usize; + /// Returns the number of columns of the matrix used to construct this decomposition. + fn ncols(&self) -> usize; + + #[doc(hidden)] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj); + #[doc(hidden)] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj); +} + +/// Object-safe base for [`SpSolverLstsq`] +pub trait SpSolverLstsqCore: SpSolverCore { + #[doc(hidden)] + fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj); +} + +/// Solver that can compute solution of a linear system. +pub trait SpSolver: SpSolverCore { + /// Solves the equation `self * X = rhs`, and stores the result in `rhs`. + fn solve_in_place(&self, rhs: impl AsMatMut); + /// Solves the equation `conjugate(self) * X = rhs`, and stores the result in `rhs`. + fn solve_conj_in_place(&self, rhs: impl AsMatMut); + /// Solves the equation `transpose(self) * X = rhs`, and stores the result in `rhs`. + fn solve_transpose_in_place(&self, rhs: impl AsMatMut); + /// Solves the equation `adjoint(self) * X = rhs`, and stores the result in `rhs`. + fn solve_conj_transpose_in_place(&self, rhs: impl AsMatMut); + /// Solves the equation `self * X = rhs`, and returns the result. + fn solve>(&self, rhs: impl AsMatRef) -> Mat; + /// Solves the equation `conjugate(self) * X = rhs`, and returns the result. + fn solve_conj>(&self, rhs: impl AsMatRef) -> Mat; + /// Solves the equation `transpose(self) * X = rhs`, and returns the result. + fn solve_transpose>(&self, rhs: impl AsMatRef) + -> Mat; + /// Solves the equation `adjoint(self) * X = rhs`, and returns the result. + fn solve_conj_transpose>( + &self, + rhs: impl AsMatRef, + ) -> Mat; +} + +/// Solver that can compute the least squares solution of an overdetermined linear system. +pub trait SpSolverLstsq: SpSolverLstsqCore { + /// Solves the equation `self * X = rhs`, in the sense of least squares, and stores the + /// result in the top rows of `rhs`. + fn solve_lstsq_in_place(&self, rhs: impl AsMatMut); + /// Solves the equation `conjugate(self) * X = rhs`, in the sense of least squares, and + /// stores the result in the top rows of `rhs`. + fn solve_lstsq_conj_in_place(&self, rhs: impl AsMatMut); + /// Solves the equation `self * X = rhs`, and returns the result. + fn solve_lstsq>(&self, rhs: impl AsMatRef) -> Mat; + /// Solves the equation `conjugate(self) * X = rhs`, and returns the result. + fn solve_lstsq_conj>( + &self, + rhs: impl AsMatRef, + ) -> Mat; +} + +#[track_caller] +fn solve_with_conj_impl< + E: ComplexField, + D: ?Sized + SpSolverCore, + ViewE: Conjugate, +>( + d: &D, + rhs: MatRef<'_, ViewE>, + conj: Conj, +) -> Mat { + let mut rhs = rhs.to_owned(); + d.solve_in_place_with_conj_impl(rhs.as_mut(), conj); + rhs +} + +#[track_caller] +fn solve_transpose_with_conj_impl< + E: ComplexField, + D: ?Sized + SpSolverCore, + ViewE: Conjugate, +>( + d: &D, + rhs: MatRef<'_, ViewE>, + conj: Conj, +) -> Mat { + let mut rhs = rhs.to_owned(); + d.solve_transpose_in_place_with_conj_impl(rhs.as_mut(), conj); + rhs +} + +#[track_caller] +fn solve_lstsq_with_conj_impl< + E: ComplexField, + D: ?Sized + SpSolverLstsqCore, + ViewE: Conjugate, +>( + d: &D, + rhs: MatRef<'_, ViewE>, + conj: Conj, +) -> Mat { + let mut rhs = rhs.to_owned(); + let k = rhs.ncols(); + d.solve_lstsq_in_place_with_conj_impl(rhs.as_mut(), conj); + rhs.resize_with(d.ncols(), k, |_, _| unreachable!()); + rhs +} + +impl> SpSolver for Dec { + #[track_caller] + fn solve_in_place(&self, rhs: impl AsMatMut) { + let mut rhs = rhs; + self.solve_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::No) + } + + #[track_caller] + fn solve_conj_in_place(&self, rhs: impl AsMatMut) { + let mut rhs = rhs; + self.solve_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::Yes) + } + + #[track_caller] + fn solve_transpose_in_place(&self, rhs: impl AsMatMut) { + let mut rhs = rhs; + self.solve_transpose_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::No) + } + + #[track_caller] + fn solve_conj_transpose_in_place(&self, rhs: impl AsMatMut) { + let mut rhs = rhs; + self.solve_transpose_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::Yes) + } + + #[track_caller] + fn solve>(&self, rhs: impl AsMatRef) -> Mat { + solve_with_conj_impl::(self, rhs.as_mat_ref(), Conj::No) + } + + #[track_caller] + fn solve_conj>(&self, rhs: impl AsMatRef) -> Mat { + solve_with_conj_impl::(self, rhs.as_mat_ref(), Conj::Yes) + } + + #[track_caller] + fn solve_transpose>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + solve_transpose_with_conj_impl::(self, rhs.as_mat_ref(), Conj::No) + } + + #[track_caller] + fn solve_conj_transpose>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + solve_transpose_with_conj_impl::(self, rhs.as_mat_ref(), Conj::Yes) + } +} + +impl> SpSolverLstsq for Dec { + #[track_caller] + fn solve_lstsq_in_place(&self, rhs: impl AsMatMut) { + let mut rhs = rhs; + self.solve_lstsq_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::No) + } + + #[track_caller] + fn solve_lstsq_conj_in_place(&self, rhs: impl AsMatMut) { + let mut rhs = rhs; + self.solve_lstsq_in_place_with_conj_impl(rhs.as_mat_mut(), Conj::Yes) + } + + #[track_caller] + fn solve_lstsq>(&self, rhs: impl AsMatRef) -> Mat { + solve_lstsq_with_conj_impl::(self, rhs.as_mat_ref(), Conj::No) + } + + #[track_caller] + fn solve_lstsq_conj>( + &self, + rhs: impl AsMatRef, + ) -> Mat { + solve_lstsq_with_conj_impl::(self, rhs.as_mat_ref(), Conj::Yes) + } +} + +/// Reference-counted sparse symbolic Cholesky factorization. +#[derive(Debug)] +pub struct SymbolicCholesky { + inner: alloc::sync::Arc>, +} +/// Sparse Cholesky factorization. +#[derive(Clone, Debug)] +pub struct Cholesky { + symbolic: SymbolicCholesky, + values: VecGroup, +} + +/// Reference-counted sparse symbolic QR factorization. +#[derive(Debug)] +pub struct SymbolicQr { + inner: alloc::sync::Arc>, +} +/// Sparse QR factorization. +#[derive(Clone, Debug)] +pub struct Qr { + symbolic: SymbolicQr, + indices: alloc::vec::Vec, + values: VecGroup, +} + +/// Reference-counted sparse symbolic LU factorization. +#[derive(Debug)] +pub struct SymbolicLu { + inner: alloc::sync::Arc>, +} +/// Sparse LU factorization. +#[derive(Clone, Debug)] +pub struct Lu { + symbolic: SymbolicLu, + numeric: super::lu::NumericLu, +} + +impl Clone for SymbolicCholesky { + #[inline] + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} +impl Clone for SymbolicQr { + #[inline] + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} +impl Clone for SymbolicLu { + #[inline] + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl SymbolicCholesky { + /// Returns the symbolic Cholesky factorization of the input matrix. + /// + /// Only the provided side is accessed. + #[track_caller] + pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>, side: Side) -> Result { + Ok(Self { + inner: alloc::sync::Arc::new(super::cholesky::factorize_symbolic_cholesky( + mat, + side, + Default::default(), + )?), + }) + } +} +impl SymbolicQr { + /// Returns the symbolic QR factorization of the input matrix. + #[track_caller] + pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result { + Ok(Self { + inner: alloc::sync::Arc::new(super::qr::factorize_symbolic_qr( + mat, + Default::default(), + )?), + }) + } +} +impl SymbolicLu { + /// Returns the symbolic LU factorization of the input matrix. + #[track_caller] + pub fn try_new(mat: SymbolicSparseColMatRef<'_, I>) -> Result { + Ok(Self { + inner: alloc::sync::Arc::new(super::lu::factorize_symbolic_lu( + mat, + Default::default(), + )?), + }) + } +} + +impl Cholesky { + /// Returns the Cholesky factorization of the input matrix with the same sparsity pattern as the + /// original one used to construct the symbolic factorization. + /// + /// Only the provided side is accessed. + #[track_caller] + pub fn try_new_with_symbolic( + symbolic: SymbolicCholesky, + mat: SparseColMatRef<'_, I, E>, + side: Side, + ) -> Result { + let len_values = symbolic.inner.len_values(); + let mut values = VecGroup::new(); + values + .try_reserve_exact(len_values) + .map_err(|_| FaerError::OutOfMemory)?; + values.resize(len_values, E::faer_zero().faer_into_units()); + let parallelism = get_global_parallelism(); + symbolic.inner.factorize_numeric_llt::( + values.as_slice_mut().into_inner(), + mat, + side, + Default::default(), + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + symbolic + .inner + .factorize_numeric_llt_req::(parallelism) + .map_err(|_| FaerError::OutOfMemory)?, + )), + )?; + Ok(Self { symbolic, values }) + } +} + +impl Qr { + /// Returns the QR factorization of the input matrix with the same sparsity pattern as the + /// original one used to construct the symbolic factorization. + #[track_caller] + pub fn try_new_with_symbolic( + symbolic: SymbolicQr, + mat: SparseColMatRef<'_, I, E>, + ) -> Result { + let len_values = symbolic.inner.len_values(); + let len_indices = symbolic.inner.len_indices(); + let mut values = VecGroup::new(); + let mut indices = alloc::vec::Vec::new(); + values + .try_reserve_exact(len_values) + .map_err(|_| FaerError::OutOfMemory)?; + indices + .try_reserve_exact(len_indices) + .map_err(|_| FaerError::OutOfMemory)?; + values.resize(len_values, E::faer_zero().faer_into_units()); + indices.resize(len_indices, I::truncate(0)); + let parallelism = get_global_parallelism(); + symbolic.inner.factorize_numeric_qr::( + &mut indices, + values.as_slice_mut().into_inner(), + mat, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + symbolic + .inner + .factorize_numeric_qr_req::(parallelism) + .map_err(|_| FaerError::OutOfMemory)?, + )), + ); + Ok(Self { + symbolic, + indices, + values, + }) + } +} + +impl Lu { + /// Returns the LU factorization of the input matrix with the same sparsity pattern as the + /// original one used to construct the symbolic factorization. + #[track_caller] + pub fn try_new_with_symbolic( + symbolic: SymbolicLu, + mat: SparseColMatRef<'_, I, E>, + ) -> Result { + let mut numeric = super::lu::NumericLu::new(); + let parallelism = get_global_parallelism(); + symbolic.inner.factorize_numeric_lu::( + &mut numeric, + mat, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + symbolic + .inner + .factorize_numeric_lu_req::(parallelism) + .map_err(|_| FaerError::OutOfMemory)?, + )), + )?; + Ok(Self { symbolic, numeric }) + } +} + +impl SpSolverCore for Cholesky { + #[inline] + fn nrows(&self) -> usize { + self.symbolic.inner.nrows() + } + #[inline] + fn ncols(&self) -> usize { + self.symbolic.inner.ncols() + } + + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + super::cholesky::LltRef::<'_, I, E>::new( + &self.symbolic.inner, + self.values.as_slice().into_inner(), + ) + .solve_in_place_with_conj( + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + self.symbolic + .inner + .solve_in_place_req::(rhs_ncols) + .unwrap(), + )), + ); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + super::cholesky::LltRef::<'_, I, E>::new( + &self.symbolic.inner, + self.values.as_slice().into_inner(), + ) + .solve_in_place_with_conj( + conj.compose(Conj::Yes), + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + self.symbolic + .inner + .solve_in_place_req::(rhs_ncols) + .unwrap(), + )), + ); + } +} + +impl SpSolverCore for Qr { + #[inline] + fn nrows(&self) -> usize { + self.symbolic.inner.nrows() + } + #[inline] + fn ncols(&self) -> usize { + self.symbolic.inner.ncols() + } + + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + self.solve_lstsq_in_place_with_conj_impl(rhs, conj); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let _ = (&rhs, &conj); + unimplemented!( + "the sparse QR decomposition doesn't support solve_transpose.\n\ + consider using the sparse LU or Cholesky instead." + ) + } +} + +impl SpSolverLstsqCore for Qr { + #[track_caller] + fn solve_lstsq_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + unsafe { + super::qr::QrRef::<'_, I, E>::new_unchecked( + &self.symbolic.inner, + &self.indices, + self.values.as_slice().into_inner(), + ) + } + .solve_in_place_with_conj( + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + self.symbolic + .inner + .solve_in_place_req::(rhs_ncols, parallelism) + .unwrap(), + )), + ); + } +} + +impl SpSolverCore for Lu { + #[inline] + fn nrows(&self) -> usize { + self.symbolic.inner.nrows() + } + #[inline] + fn ncols(&self) -> usize { + self.symbolic.inner.ncols() + } + + #[track_caller] + fn solve_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + unsafe { super::lu::LuRef::<'_, I, E>::new_unchecked(&self.symbolic.inner, &self.numeric) } + .solve_in_place_with_conj( + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + self.symbolic + .inner + .solve_in_place_req::(rhs_ncols, parallelism) + .unwrap(), + )), + ); + } + + #[track_caller] + fn solve_transpose_in_place_with_conj_impl(&self, rhs: MatMut<'_, E>, conj: Conj) { + let parallelism = get_global_parallelism(); + let rhs_ncols = rhs.ncols(); + unsafe { super::lu::LuRef::<'_, I, E>::new_unchecked(&self.symbolic.inner, &self.numeric) } + .solve_transpose_in_place_with_conj( + conj, + rhs, + parallelism, + PodStack::new(&mut GlobalPodBuffer::new( + self.symbolic + .inner + .solve_in_place_req::(rhs_ncols, parallelism) + .unwrap(), + )), + ); + } +} + +impl SparseColMatRef<'_, I, E> { + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each column. + #[track_caller] + pub fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_lower_triangular_in_place( + *self, + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each column. + #[track_caller] + pub fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_upper_triangular_in_place( + *self, + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each column. + #[track_caller] + pub fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_unit_lower_triangular_in_place( + *self, + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each column. + #[track_caller] + pub fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_unit_upper_triangular_in_place( + *self, + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn sp_cholesky(&self, side: Side) -> Result, CholeskyError> { + Cholesky::try_new_with_symbolic( + SymbolicCholesky::try_new(self.symbolic(), side)?, + *self, + side, + ) + } + + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn sp_lu(&self) -> Result, LuError> { + Lu::try_new_with_symbolic(SymbolicLu::try_new(self.symbolic())?, *self) + } + + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn sp_qr(&self) -> Result, FaerError> { + Qr::try_new_with_symbolic(SymbolicQr::try_new(self.symbolic())?, *self) + } +} + +impl SparseRowMatRef<'_, I, E> { + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each row. + #[track_caller] + pub fn sp_solve_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_upper_triangular_in_place( + self.transpose(), + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each row. + #[track_caller] + pub fn sp_solve_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_lower_triangular_in_place( + self.transpose(), + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each row. + #[track_caller] + pub fn sp_solve_unit_lower_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_unit_upper_triangular_in_place( + self.transpose(), + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each row. + #[track_caller] + pub fn sp_solve_unit_upper_triangular_in_place(&self, mut rhs: impl AsMatMut) { + crate::sparse::linalg::triangular_solve::solve_unit_lower_triangular_in_place( + self.transpose(), + Conj::No, + rhs.as_mat_mut(), + get_global_parallelism(), + ); + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn sp_cholesky(&self, side: Side) -> Result, CholeskyError> { + let this = self.to_col_major()?; + let this = this.as_ref(); + Cholesky::try_new_with_symbolic( + SymbolicCholesky::try_new(this.symbolic(), side)?, + this, + side, + ) + } + + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn sp_lu(&self) -> Result, LuError> { + let this = self.to_col_major()?; + let this = this.as_ref(); + Lu::try_new_with_symbolic(SymbolicLu::try_new(this.symbolic())?, this) + } + + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn sp_qr(&self) -> Result, FaerError> { + let this = self.to_col_major()?; + let this = this.as_ref(); + Qr::try_new_with_symbolic(SymbolicQr::try_new(this.symbolic())?, this) + } +} + +impl SparseColMatMut<'_, I, E> { + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each column. + #[track_caller] + pub fn sp_solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_lower_triangular_in_place(rhs); + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each column. + #[track_caller] + pub fn sp_solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_upper_triangular_in_place(rhs); + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each column. + #[track_caller] + pub fn sp_solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_lower_triangular_in_place(rhs); + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each column. + #[track_caller] + pub fn sp_solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_upper_triangular_in_place(rhs); + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn sp_cholesky(&self, side: Side) -> Result, CholeskyError> { + self.as_ref().sp_cholesky(side) + } + + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn sp_lu(&self) -> Result, LuError> { + self.as_ref().sp_lu() + } + + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn sp_qr(&self) -> Result, FaerError> { + self.as_ref().sp_qr() + } +} + +impl SparseRowMatMut<'_, I, E> { + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each row. + #[track_caller] + pub fn sp_solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_lower_triangular_in_place(rhs); + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each row. + #[track_caller] + pub fn sp_solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_upper_triangular_in_place(rhs); + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each row. + #[track_caller] + pub fn sp_solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_lower_triangular_in_place(rhs); + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each row. + #[track_caller] + pub fn sp_solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_upper_triangular_in_place(rhs); + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn sp_cholesky(&self, side: Side) -> Result, CholeskyError> { + self.as_ref().sp_cholesky(side) + } + + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn sp_lu(&self) -> Result, LuError> { + self.as_ref().sp_lu() + } + + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn sp_qr(&self) -> Result, FaerError> { + self.as_ref().sp_qr() + } +} +impl SparseColMat { + /// Assuming `self` is a lower triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each column. + #[track_caller] + pub fn sp_solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_lower_triangular_in_place(rhs); + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each column. + #[track_caller] + pub fn sp_solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_upper_triangular_in_place(rhs); + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each column. + #[track_caller] + pub fn sp_solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_lower_triangular_in_place(rhs); + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each column. + #[track_caller] + pub fn sp_solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_upper_triangular_in_place(rhs); + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn sp_cholesky(&self, side: Side) -> Result, CholeskyError> { + self.as_ref().sp_cholesky(side) + } + + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn sp_lu(&self) -> Result, LuError> { + self.as_ref().sp_lu() + } + + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn sp_qr(&self) -> Result, FaerError> { + self.as_ref().sp_qr() + } +} + +impl SparseRowMat { + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each row. + #[track_caller] + pub fn sp_solve_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_lower_triangular_in_place(rhs); + } + /// Assuming `self` is an upper triangular matrix, solves the equation `self * X = rhs`, and + /// stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each row. + #[track_caller] + pub fn sp_solve_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_upper_triangular_in_place(rhs); + } + /// Assuming `self` is a unit lower triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the last stored element in each row. + #[track_caller] + pub fn sp_solve_unit_lower_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_lower_triangular_in_place(rhs); + } + /// Assuming `self` is a unit upper triangular matrix, solves the equation `self * X = rhs`, + /// and stores the result in `rhs`. + /// + /// # Note + /// The matrix indices need not be sorted, but + /// the diagonal element is assumed to be the first stored element in each row. + #[track_caller] + pub fn sp_solve_unit_upper_triangular_in_place(&self, rhs: impl AsMatMut) { + self.as_ref().sp_solve_unit_upper_triangular_in_place(rhs); + } + + /// Returns the Cholesky decomposition of `self`. Only the provided side is accessed. + #[track_caller] + pub fn sp_cholesky(&self, side: Side) -> Result, CholeskyError> { + self.as_ref().sp_cholesky(side) + } + + /// Returns the LU decomposition of `self` with partial (row) pivoting. + #[track_caller] + pub fn sp_lu(&self) -> Result, LuError> { + self.as_ref().sp_lu() + } + + /// Returns the QR decomposition of `self`. + #[track_caller] + pub fn sp_qr(&self) -> Result, FaerError> { + self.as_ref().sp_qr() + } +} diff --git a/faer-libs/faer-sparse/src/triangular_solve.rs b/src/sparse/linalg/triangular_solve.rs similarity index 93% rename from faer-libs/faer-sparse/src/triangular_solve.rs rename to src/sparse/linalg/triangular_solve.rs index b192be89a4e7148094f0f9d3f468c1245f3419e6..93477d864b00c6318329135921202e93eefec623 100644 --- a/faer-libs/faer-sparse/src/triangular_solve.rs +++ b/src/sparse/linalg/triangular_solve.rs @@ -1,10 +1,14 @@ +use crate::{assert, sparse::SparseColMatRef, utils::slice::*, Conj, Index, MatMut, Parallelism}; use core::iter::zip; -use faer_core::{ - assert, group_helpers::SliceGroup, permutation::Index, sparse::SparseColMatRef, Conj, MatMut, - Parallelism, -}; use faer_entity::ComplexField; +/// Assuming `self` is a lower triangular matrix, solves the equation `Op(self) * X = rhs`, and +/// stores the result in `rhs`, where `Op` is either the conjugate or the identity depending on the +/// value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the first stored element in each column. pub fn solve_lower_triangular_in_place( l: SparseColMatRef<'_, I, E>, conj: Conj, @@ -17,13 +21,13 @@ pub fn solve_lower_triangular_in_place( let slice_group = SliceGroup::<'_, E>::new; - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let l = faer_core::constrained::sparse::SparseColMatRef::new(l, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); let mut k = 0usize; while k < *K { @@ -162,6 +166,13 @@ pub fn solve_lower_triangular_in_place( ); } +/// Assuming `self` is a lower triangular matrix, solves the equation `Op(self).transpose() * X = +/// rhs`, and stores the result in `rhs`, where `Op` is either the conjugate or the identity +/// depending on the value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the first stored element in each column. pub fn solve_lower_triangular_transpose_in_place( l: SparseColMatRef<'_, I, E>, conj: Conj, @@ -174,13 +185,13 @@ pub fn solve_lower_triangular_transpose_in_place( let slice_group = SliceGroup::<'_, E>::new; - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let l = faer_core::constrained::sparse::SparseColMatRef::new(l, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); let mut k = 0usize; while k < *K { @@ -403,6 +414,13 @@ pub fn solve_lower_triangular_transpose_in_place( ); } +/// Assuming `self` is a unit lower triangular matrix, solves the equation `Op(self) * X = rhs`, and +/// stores the result in `rhs`, where `Op` is either the conjugate or the identity depending on the +/// value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the first stored element in each column. pub fn solve_unit_lower_triangular_in_place( l: SparseColMatRef<'_, I, E>, conj: Conj, @@ -415,13 +433,13 @@ pub fn solve_unit_lower_triangular_in_place( let slice_group = SliceGroup::<'_, E>::new; - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let l = faer_core::constrained::sparse::SparseColMatRef::new(l, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); let mut k = 0usize; while k < *K { @@ -539,6 +557,13 @@ pub fn solve_unit_lower_triangular_in_place( ); } +/// Assuming `self` is a unit lower triangular matrix, solves the equation `Op(self).transpose() * X +/// = rhs`, and stores the result in `rhs`, where `Op` is either the conjugate or the identity +/// depending on the value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the first stored element in each column. pub fn solve_unit_lower_triangular_transpose_in_place( l: SparseColMatRef<'_, I, E>, conj: Conj, @@ -551,13 +576,13 @@ pub fn solve_unit_lower_triangular_transpose_in_place let slice_group = SliceGroup::<'_, E>::new; - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let l = faer_core::constrained::sparse::SparseColMatRef::new(l, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let l = crate::utils::constrained::sparse::SparseColMatRef::new(l, N, N); let mut k = 0usize; while k < *K { @@ -764,6 +789,13 @@ pub fn solve_unit_lower_triangular_transpose_in_place ); } +/// Assuming `self` is an upper triangular matrix, solves the equation `Op(self) * X = rhs`, and +/// stores the result in `rhs`, where `Op` is either the conjugate or the identity +/// depending on the value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the last stored element in each column. pub fn solve_upper_triangular_in_place( u: SparseColMatRef<'_, I, E>, conj: Conj, @@ -774,13 +806,13 @@ pub fn solve_upper_triangular_in_place( assert!(u.nrows() == u.ncols()); assert!(rhs.nrows() == u.nrows()); - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let u = faer_core::constrained::sparse::SparseColMatRef::new(u, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); let mut k = 0usize; while k < *K { @@ -942,6 +974,13 @@ pub fn solve_upper_triangular_in_place( ); } +/// Assuming `self` is an upper triangular matrix, solves the equation `Op(self).transpose() * X = +/// rhs`, and stores the result in `rhs`, where `Op` is either the conjugate or the identity +/// depending on the value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the last stored element in each column. pub fn solve_upper_triangular_transpose_in_place( u: SparseColMatRef<'_, I, E>, conj: Conj, @@ -952,13 +991,13 @@ pub fn solve_upper_triangular_transpose_in_place( assert!(u.nrows() == u.ncols()); assert!(rhs.nrows() == u.nrows()); - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let u = faer_core::constrained::sparse::SparseColMatRef::new(u, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); let mut k = 0usize; while k < *K { @@ -1208,6 +1247,13 @@ pub fn solve_upper_triangular_transpose_in_place( ); } +/// Assuming `self` is a unit upper triangular matrix, solves the equation `Op(self) * X = +/// rhs`, and stores the result in `rhs`, where `Op` is either the conjugate or the identity +/// depending on the value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the last stored element in each column. pub fn solve_unit_upper_triangular_in_place( u: SparseColMatRef<'_, I, E>, conj: Conj, @@ -1218,13 +1264,13 @@ pub fn solve_unit_upper_triangular_in_place( assert!(u.nrows() == u.ncols()); assert!(rhs.nrows() == u.nrows()); - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let u = faer_core::constrained::sparse::SparseColMatRef::new(u, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); let mut k = 0usize; while k < *K { @@ -1352,6 +1398,13 @@ pub fn solve_unit_upper_triangular_in_place( ); } +/// Assuming `self` is a unit upper triangular matrix, solves the equation `Op(self).transpose() * X +/// = rhs`, and stores the result in `rhs`, where `Op` is either the conjugate or the identity +/// depending on the value of `conj`. +/// +/// # Note +/// The matrix indices need not be sorted, but +/// the diagonal element is assumed to be the last stored element in each column. pub fn solve_unit_upper_triangular_transpose_in_place( u: SparseColMatRef<'_, I, E>, conj: Conj, @@ -1362,13 +1415,13 @@ pub fn solve_unit_upper_triangular_transpose_in_place assert!(u.nrows() == u.ncols()); assert!(rhs.nrows() == u.nrows()); - faer_core::constrained::Size::with2( + crate::utils::constrained::Size::with2( rhs.nrows(), rhs.ncols(), #[inline(always)] |N, K| { - let mut x = faer_core::constrained::MatMut::new(rhs, N, K); - let u = faer_core::constrained::sparse::SparseColMatRef::new(u, N, N); + let mut x = crate::utils::constrained::mat::MatMut::new(rhs, N, K); + let u = crate::utils::constrained::sparse::SparseColMatRef::new(u, N, N); let mut k = 0usize; while k < *K { diff --git a/src/sparse/mod.rs b/src/sparse/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..9c42ee17e454ac14303cb87015c9794af374d731 --- /dev/null +++ b/src/sparse/mod.rs @@ -0,0 +1,1511 @@ +//! Sparse matrix data structures. +//! +//! Most sparse matrix algorithms accept matrices in sparse column-oriented format. +//! This format represents each column of the matrix by storing the row indices of its non-zero +//! elements, as well as their values. +//! +//! The indices and the values are each stored in a contiguous slice (or group of slices for +//! arbitrary values). In order to specify where each column starts and ends, a slice of size +//! `ncols + 1` stores the start of each column, with the last element being equal to the total +//! number of non-zeros (or the capacity in uncompressed mode). +//! +//! # Example +//! +//! Consider the 4-by-5 matrix: +//! ```notcode +//! [[10.0, 0.0, 12.0, -1.0, 13.0] +//! [ 0.0, 0.0, 25.0, -2.0, 0.0] +//! [ 1.0, 0.0, 0.0, 0.0, 0.0] +//! [ 4.0, 0.0, 0.0, 0.0, 5.0]] +//! ``` +//! +//! The matrix is stored as follows: +//! ```notcode +//! column pointers: 0 | 3 | 3 | 5 | 7 | 9 +//! +//! row indices: 0 | 2 | 3 | 0 | 1 | 0 | 1 | 0 | 3 +//! values : 10.0 | 1.0 | 4.0 | 12.0 | 25.0 | -1.0 | -2.0 | 13.0 | 5.0 +//! ``` + +use super::*; +use crate::utils::{slice::*, vec::VecGroup}; +use core::{cell::Cell, iter::zip, ops::Range, slice::SliceIndex}; +use dyn_stack::*; +use faer_entity::*; +use reborrow::*; + +mod ghost { + pub use crate::utils::constrained::{perm::*, sparse::*, *}; +} + +const TOP_BIT: usize = 1usize << (usize::BITS - 1); +const TOP_BIT_MASK: usize = TOP_BIT - 1; + +mod mem { + #[inline] + pub fn fill_zero(slice: &mut [I]) { + let len = slice.len(); + unsafe { core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len) } + } +} + +#[inline(always)] +#[track_caller] +#[doc(hidden)] +pub unsafe fn __get_unchecked>(slice: &[I], i: R) -> &R::Output { + #[cfg(debug_assertions)] + { + let _ = &slice[i.clone()]; + } + unsafe { slice.get_unchecked(i) } +} +#[inline(always)] +#[track_caller] +#[doc(hidden)] +pub unsafe fn __get_unchecked_mut>( + slice: &mut [I], + i: R, +) -> &mut R::Output { + #[cfg(debug_assertions)] + { + let _ = &slice[i.clone()]; + } + unsafe { slice.get_unchecked_mut(i) } +} + +#[inline(always)] +#[doc(hidden)] +pub fn windows2(slice: &[I]) -> impl DoubleEndedIterator { + slice + .windows(2) + .map(|window| unsafe { &*(window.as_ptr() as *const [I; 2]) }) +} + +#[inline] +#[doc(hidden)] +pub const fn repeat_byte(byte: u8) -> usize { + union Union { + bytes: [u8; 32], + value: usize, + } + + let data = Union { bytes: [byte; 32] }; + unsafe { data.value } +} + +/// Errors that can occur in sparse algorithms. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[non_exhaustive] +pub enum FaerError { + /// An index exceeding the maximum value (`I::Signed::MAX` for a given index type `I`). + IndexOverflow, + /// Memory allocation failed. + OutOfMemory, +} + +impl core::fmt::Display for FaerError { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Debug::fmt(self, f) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for FaerError {} + +/// Errors that can occur in sparse algorithms. +#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +#[non_exhaustive] +pub enum CreationError { + /// Generic error (allocation or index overflow). + Generic(FaerError), + /// Matrix index out-of-bounds error. + OutOfBounds { + /// Row of the out-of-bounds index. + row: usize, + /// Column of the out-of-bounds index. + col: usize, + }, +} + +impl From for CreationError { + #[inline] + fn from(value: FaerError) -> Self { + Self::Generic(value) + } +} +impl core::fmt::Display for CreationError { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + core::fmt::Debug::fmt(self, f) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for CreationError {} + +#[inline] +#[track_caller] +fn try_zeroed(n: usize) -> Result, FaerError> { + let mut v = alloc::vec::Vec::new(); + v.try_reserve_exact(n).map_err(|_| FaerError::OutOfMemory)?; + unsafe { + core::ptr::write_bytes::(v.as_mut_ptr(), 0u8, n); + v.set_len(n); + } + Ok(v) +} + +#[inline] +#[track_caller] +fn try_collect(iter: I) -> Result, FaerError> { + let iter = iter.into_iter(); + let mut v = alloc::vec::Vec::new(); + v.try_reserve_exact(iter.size_hint().0) + .map_err(|_| FaerError::OutOfMemory)?; + v.extend(iter); + Ok(v) +} + +/// The order values should be read in, when constructing/filling from indices and values. +/// +/// Allows separately creating the symbolic structure and filling the numerical values. +#[derive(Debug, Clone)] +pub struct ValuesOrder { + argsort: Vec, + all_nnz: usize, + nnz: usize, + __marker: core::marker::PhantomData, +} + +/// Whether the filled values should replace the current matrix values or be added to them. +#[derive(Debug, Copy, Clone)] +pub enum FillMode { + /// New filled values should replace the old values. + Replace, + /// New filled values should be added to the old values. + Add, +} + +mod csc; +mod csr; + +/// Sparse linear algebra module. +/// Contains low level routines and the implementation of their corresponding high level wrappers. +pub mod linalg; + +/// Sparse matrix binary and ternary operation implementations. +pub mod ops; + +pub use csc::*; +pub use csr::*; + +/// Useful sparse matrix primitives. +pub mod utils { + use super::*; + use crate::{assert, debug_assert}; + + /// Sorts `row_indices` and `values` simultaneously so that `row_indices` is nonincreasing. + pub fn sort_indices( + col_ptrs: &[I], + row_indices: &mut [I], + values: GroupFor, + ) { + assert!(col_ptrs.len() >= 1); + let mut values = SliceGroupMut::<'_, E>::new(values); + + let n = col_ptrs.len() - 1; + for j in 0..n { + let start = col_ptrs[j].zx(); + let end = col_ptrs[j + 1].zx(); + + unsafe { + crate::sort::sort_indices( + &mut row_indices[start..end], + values.rb_mut().subslice(start..end), + ); + } + } + } + + #[doc(hidden)] + pub unsafe fn ghost_permute_hermitian_unsorted<'n, 'out, I: Index, E: ComplexField>( + new_values: SliceGroupMut<'out, E>, + new_col_ptrs: &'out mut [I], + new_row_indices: &'out mut [I], + A: ghost::SparseColMatRef<'n, 'n, '_, I, E>, + perm: ghost::PermRef<'n, '_, I>, + in_side: Side, + out_side: Side, + sort: bool, + stack: PodStack<'_>, + ) -> ghost::SparseColMatMut<'n, 'n, 'out, I, E> { + let N = A.ncols(); + let n = *A.ncols(); + + // (1) + assert!(new_col_ptrs.len() == n + 1); + let (_, perm_inv) = perm.arrays(); + + let (current_row_position, _) = stack.make_raw::(n); + let current_row_position = ghost::Array::from_mut(current_row_position, N); + + mem::fill_zero(current_row_position.as_mut()); + let col_counts = &mut *current_row_position; + match (in_side, out_side) { + (Side::Lower, Side::Lower) => { + for old_j in N.indices() { + let new_j = perm_inv[old_j].zx(); + for old_i in A.row_indices_of_col(old_j) { + if old_i >= old_j { + let new_i = perm_inv[old_i].zx(); + let new_min = Ord::min(new_i, new_j); + // cannot overflow because A.compute_nnz() <= I::MAX + // col_counts[new_max] always >= 0 + col_counts[new_min] += I::truncate(1); + } + } + } + } + (Side::Lower, Side::Upper) => { + for old_j in N.indices() { + let new_j = perm_inv[old_j].zx(); + for old_i in A.row_indices_of_col(old_j) { + if old_i >= old_j { + let new_i = perm_inv[old_i].zx(); + let new_max = Ord::max(new_i, new_j); + // cannot overflow because A.compute_nnz() <= I::MAX + // col_counts[new_max] always >= 0 + col_counts[new_max] += I::truncate(1); + } + } + } + } + (Side::Upper, Side::Lower) => { + for old_j in N.indices() { + let new_j = perm_inv[old_j].zx(); + for old_i in A.row_indices_of_col(old_j) { + if old_i <= old_j { + let new_i = perm_inv[old_i].zx(); + let new_min = Ord::min(new_i, new_j); + // cannot overflow because A.compute_nnz() <= I::MAX + // col_counts[new_max] always >= 0 + col_counts[new_min] += I::truncate(1); + } + } + } + } + (Side::Upper, Side::Upper) => { + for old_j in N.indices() { + let new_j = perm_inv[old_j].zx(); + for old_i in A.row_indices_of_col(old_j) { + if old_i <= old_j { + let new_i = perm_inv[old_i].zx(); + let new_max = Ord::max(new_i, new_j); + // cannot overflow because A.compute_nnz() <= I::MAX + // col_counts[new_max] always >= 0 + col_counts[new_max] += I::truncate(1); + } + } + } + } + } + + // col_counts[_] >= 0 + // cumulative sum cannot overflow because it is <= A.compute_nnz() + + // SAFETY: new_col_ptrs.len() == n + 1 > 0 + new_col_ptrs[0] = I::truncate(0); + for (count, [ci0, ci1]) in zip( + col_counts.as_mut(), + windows2(Cell::as_slice_of_cells(Cell::from_mut(&mut *new_col_ptrs))), + ) { + let ci0 = ci0.get(); + ci1.set(ci0 + *count); + *count = ci0; + } + // new_col_ptrs is non-decreasing + + let nnz = new_col_ptrs[n].zx(); + let new_row_indices = &mut new_row_indices[..nnz]; + let mut new_values = new_values.subslice(0..nnz); + + ghost::Size::with( + nnz, + #[inline(always)] + |NNZ| { + let mut new_values = + ghost::ArrayGroupMut::new(new_values.rb_mut().into_inner(), NNZ); + let new_row_indices = ghost::Array::from_mut(new_row_indices, NNZ); + + let conj_if = |cond: bool, x: E| { + if !coe::is_same::() && cond { + x.faer_conj() + } else { + x + } + }; + + match (in_side, out_side) { + (Side::Lower, Side::Lower) => { + for old_j in N.indices() { + let new_j_ = perm_inv[old_j]; + let new_j = new_j_.zx(); + + for (old_i, val) in zip( + A.row_indices_of_col(old_j), + SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), + ) { + if old_i >= old_j { + let new_i_ = perm_inv[old_i]; + let new_i = new_i_.zx(); + + let new_max = Ord::max(new_i_, new_j_); + let new_min = Ord::min(new_i, new_j); + let current_row_pos: &mut I = + &mut current_row_position[new_min]; + // SAFETY: current_row_pos < NNZ + let row_pos = unsafe { + ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) + }; + *current_row_pos += I::truncate(1); + new_values + .write(row_pos, conj_if(new_min == new_i, val.read())); + // (2) + new_row_indices[row_pos] = *new_max; + } + } + } + } + (Side::Lower, Side::Upper) => { + for old_j in N.indices() { + let new_j_ = perm_inv[old_j]; + let new_j = new_j_.zx(); + + for (old_i, val) in zip( + A.row_indices_of_col(old_j), + SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), + ) { + if old_i >= old_j { + let new_i_ = perm_inv[old_i]; + let new_i = new_i_.zx(); + + let new_max = Ord::max(new_i, new_j); + let new_min = Ord::min(new_i_, new_j_); + let current_row_pos = &mut current_row_position[new_max]; + // SAFETY: current_row_pos < NNZ + let row_pos = unsafe { + ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) + }; + *current_row_pos += I::truncate(1); + new_values + .write(row_pos, conj_if(new_max == new_i, val.read())); + // (2) + new_row_indices[row_pos] = *new_min; + } + } + } + } + (Side::Upper, Side::Lower) => { + for old_j in N.indices() { + let new_j_ = perm_inv[old_j]; + let new_j = new_j_.zx(); + + for (old_i, val) in zip( + A.row_indices_of_col(old_j), + SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), + ) { + if old_i <= old_j { + let new_i_ = perm_inv[old_i]; + let new_i = new_i_.zx(); + + let new_max = Ord::max(new_i_, new_j_); + let new_min = Ord::min(new_i, new_j); + let current_row_pos = &mut current_row_position[new_min]; + // SAFETY: current_row_pos < NNZ + let row_pos = unsafe { + ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) + }; + *current_row_pos += I::truncate(1); + new_values + .write(row_pos, conj_if(new_min == new_i, val.read())); + // (2) + new_row_indices[row_pos] = *new_max; + } + } + } + } + (Side::Upper, Side::Upper) => { + for old_j in N.indices() { + let new_j_ = perm_inv[old_j]; + let new_j = new_j_.zx(); + + for (old_i, val) in zip( + A.row_indices_of_col(old_j), + SliceGroup::<'_, E>::new(A.values_of_col(old_j)).into_ref_iter(), + ) { + if old_i <= old_j { + let new_i_ = perm_inv[old_i]; + let new_i = new_i_.zx(); + + let new_max = Ord::max(new_i, new_j); + let new_min = Ord::min(new_i_, new_j_); + let current_row_pos = &mut current_row_position[new_max]; + // SAFETY: current_row_pos < NNZ + let row_pos = unsafe { + ghost::Idx::new_unchecked(current_row_pos.zx(), NNZ) + }; + *current_row_pos += I::truncate(1); + new_values + .write(row_pos, conj_if(new_max == new_i, val.read())); + // (2) + new_row_indices[row_pos] = *new_min; + } + } + } + } + } + debug_assert!(current_row_position.as_ref() == &new_col_ptrs[1..]); + }, + ); + + if sort { + sort_indices::( + new_col_ptrs, + new_row_indices, + new_values.rb_mut().into_inner(), + ); + } + + // SAFETY: + // 0. new_col_ptrs is non-decreasing + // 1. new_values.len() == new_row_indices.len() + // 2. all written row indices are less than n + unsafe { + ghost::SparseColMatMut::new( + SparseColMatMut::new( + SymbolicSparseColMatRef::new_unchecked( + n, + n, + new_col_ptrs, + None, + new_row_indices, + ), + new_values.into_inner(), + ), + N, + N, + ) + } + } + + #[doc(hidden)] + pub unsafe fn ghost_permute_hermitian_unsorted_symbolic<'n, 'out, I: Index>( + new_col_ptrs: &'out mut [I], + new_row_indices: &'out mut [I], + A: ghost::SymbolicSparseColMatRef<'n, 'n, '_, I>, + perm: ghost::PermRef<'n, '_, I>, + in_side: Side, + out_side: Side, + stack: PodStack<'_>, + ) -> ghost::SymbolicSparseColMatRef<'n, 'n, 'out, I> { + let old_values = &*Symbolic::materialize(A.into_inner().row_indices().len()); + let new_values = Symbolic::materialize(new_row_indices.len()); + *ghost_permute_hermitian_unsorted( + SliceGroupMut::<'_, Symbolic>::new(new_values), + new_col_ptrs, + new_row_indices, + ghost::SparseColMatRef::new( + SparseColMatRef::new(A.into_inner(), old_values), + A.nrows(), + A.ncols(), + ), + perm, + in_side, + out_side, + false, + stack, + ) + } + + /// Computes the self-adjoint permutation $P A P^\top$ of the matrix `A` without sorting the row + /// indices, and returns a view over it. + /// + /// The result is stored in `new_col_ptrs`, `new_row_indices`. + #[doc(hidden)] + pub unsafe fn permute_hermitian_unsorted<'out, I: Index, E: ComplexField>( + new_values: GroupFor, + new_col_ptrs: &'out mut [I], + new_row_indices: &'out mut [I], + A: SparseColMatRef<'_, I, E>, + perm: crate::perm::PermRef<'_, I>, + in_side: Side, + out_side: Side, + stack: PodStack<'_>, + ) -> SparseColMatMut<'out, I, E> { + ghost::Size::with(A.nrows(), |N| { + assert!(A.nrows() == A.ncols()); + ghost_permute_hermitian_unsorted( + SliceGroupMut::new(new_values), + new_col_ptrs, + new_row_indices, + ghost::SparseColMatRef::new(A, N, N), + ghost::PermRef::new(perm, N), + in_side, + out_side, + false, + stack, + ) + .into_inner() + }) + } + + /// Computes the self-adjoint permutation $P A P^\top$ of the matrix `A` and returns a view over + /// it. + /// + /// The result is stored in `new_col_ptrs`, `new_row_indices`. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + pub fn permute_hermitian<'out, I: Index, E: ComplexField>( + new_values: GroupFor, + new_col_ptrs: &'out mut [I], + new_row_indices: &'out mut [I], + A: SparseColMatRef<'_, I, E>, + perm: crate::perm::PermRef<'_, I>, + in_side: Side, + out_side: Side, + stack: PodStack<'_>, + ) -> SparseColMatMut<'out, I, E> { + ghost::Size::with(A.nrows(), |N| { + assert!(A.nrows() == A.ncols()); + unsafe { + ghost_permute_hermitian_unsorted( + SliceGroupMut::new(new_values), + new_col_ptrs, + new_row_indices, + ghost::SparseColMatRef::new(A, N, N), + ghost::PermRef::new(perm, N), + in_side, + out_side, + true, + stack, + ) + } + .into_inner() + }) + } + + #[doc(hidden)] + pub fn ghost_adjoint_symbolic<'m, 'n, 'a, I: Index>( + new_col_ptrs: &'a mut [I], + new_row_indices: &'a mut [I], + A: ghost::SymbolicSparseColMatRef<'m, 'n, '_, I>, + stack: PodStack<'_>, + ) -> ghost::SymbolicSparseColMatRef<'n, 'm, 'a, I> { + let old_values = &*Symbolic::materialize(A.into_inner().row_indices().len()); + let new_values = Symbolic::materialize(new_row_indices.len()); + *ghost_adjoint( + new_col_ptrs, + new_row_indices, + SliceGroupMut::<'_, Symbolic>::new(new_values), + ghost::SparseColMatRef::new( + SparseColMatRef::new(A.into_inner(), old_values), + A.nrows(), + A.ncols(), + ), + stack, + ) + } + + #[doc(hidden)] + pub fn ghost_adjoint<'m, 'n, 'a, I: Index, E: ComplexField>( + new_col_ptrs: &'a mut [I], + new_row_indices: &'a mut [I], + new_values: SliceGroupMut<'a, E>, + A: ghost::SparseColMatRef<'m, 'n, '_, I, E>, + stack: PodStack<'_>, + ) -> ghost::SparseColMatMut<'n, 'm, 'a, I, E> { + let M = A.nrows(); + let N = A.ncols(); + assert!(new_col_ptrs.len() == *M + 1); + + let (col_count, _) = stack.make_raw::(*M); + let col_count = ghost::Array::from_mut(col_count, M); + mem::fill_zero(col_count.as_mut()); + + // can't overflow because the total count is A.compute_nnz() <= I::MAX + for j in N.indices() { + for i in A.row_indices_of_col(j) { + col_count[i] += I::truncate(1); + } + } + + new_col_ptrs[0] = I::truncate(0); + // col_count elements are >= 0 + for (j, [pj0, pj1]) in zip( + M.indices(), + windows2(Cell::as_slice_of_cells(Cell::from_mut(new_col_ptrs))), + ) { + let cj = &mut col_count[j]; + let pj = pj0.get(); + // new_col_ptrs is non-decreasing + pj1.set(pj + *cj); + *cj = pj; + } + + let new_row_indices = &mut new_row_indices[..new_col_ptrs[*M].zx()]; + let mut new_values = new_values.subslice(0..new_col_ptrs[*M].zx()); + let current_row_position = &mut *col_count; + // current_row_position[i] == col_ptr[i] + for j in N.indices() { + let j_: ghost::Idx<'n, I> = j.truncate::(); + for (i, val) in zip( + A.row_indices_of_col(j), + SliceGroup::<'_, E>::new(A.values_of_col(j)).into_ref_iter(), + ) { + let ci = &mut current_row_position[i]; + + // SAFETY: see below + unsafe { + *new_row_indices.get_unchecked_mut(ci.zx()) = *j_; + new_values.write_unchecked(ci.zx(), val.read().faer_conj()) + }; + *ci += I::truncate(1); + } + } + // current_row_position[i] == col_ptr[i] + col_count[i] == col_ptr[i + 1] <= col_ptr[m] + // so all the unchecked accesses were valid and non-overlapping, which means the entire + // array is filled + debug_assert!(current_row_position.as_ref() == &new_col_ptrs[1..]); + + // SAFETY: + // 0. new_col_ptrs is non-decreasing + // 1. all written row indices are less than n + ghost::SparseColMatMut::new( + unsafe { + SparseColMatMut::new( + SymbolicSparseColMatRef::new_unchecked( + *N, + *M, + new_col_ptrs, + None, + new_row_indices, + ), + new_values.into_inner(), + ) + }, + N, + M, + ) + } + + #[doc(hidden)] + pub fn ghost_transpose<'m, 'n, 'a, I: Index, E: Entity>( + new_col_ptrs: &'a mut [I], + new_row_indices: &'a mut [I], + new_values: SliceGroupMut<'a, E>, + A: ghost::SparseColMatRef<'m, 'n, '_, I, E>, + stack: PodStack<'_>, + ) -> ghost::SparseColMatMut<'n, 'm, 'a, I, E> { + let M = A.nrows(); + let N = A.ncols(); + assert!(new_col_ptrs.len() == *M + 1); + + let (col_count, _) = stack.make_raw::(*M); + let col_count = ghost::Array::from_mut(col_count, M); + mem::fill_zero(col_count.as_mut()); + + // can't overflow because the total count is A.compute_nnz() <= I::MAX + for j in N.indices() { + for i in A.row_indices_of_col(j) { + col_count[i] += I::truncate(1); + } + } + + new_col_ptrs[0] = I::truncate(0); + // col_count elements are >= 0 + for (j, [pj0, pj1]) in zip( + M.indices(), + windows2(Cell::as_slice_of_cells(Cell::from_mut(new_col_ptrs))), + ) { + let cj = &mut col_count[j]; + let pj = pj0.get(); + // new_col_ptrs is non-decreasing + pj1.set(pj + *cj); + *cj = pj; + } + + let new_row_indices = &mut new_row_indices[..new_col_ptrs[*M].zx()]; + let mut new_values = new_values.subslice(0..new_col_ptrs[*M].zx()); + let current_row_position = &mut *col_count; + // current_row_position[i] == col_ptr[i] + for j in N.indices() { + let j_: ghost::Idx<'n, I> = j.truncate::(); + for (i, val) in zip( + A.row_indices_of_col(j), + SliceGroup::<'_, E>::new(A.values_of_col(j)).into_ref_iter(), + ) { + let ci = &mut current_row_position[i]; + + // SAFETY: see below + unsafe { + *new_row_indices.get_unchecked_mut(ci.zx()) = *j_; + new_values.write_unchecked(ci.zx(), val.read()) + }; + *ci += I::truncate(1); + } + } + // current_row_position[i] == col_ptr[i] + col_count[i] == col_ptr[i + 1] <= col_ptr[m] + // so all the unchecked accesses were valid and non-overlapping, which means the entire + // array is filled + debug_assert!(current_row_position.as_ref() == &new_col_ptrs[1..]); + + // SAFETY: + // 0. new_col_ptrs is non-decreasing + // 1. all written row indices are less than n + ghost::SparseColMatMut::new( + unsafe { + SparseColMatMut::new( + SymbolicSparseColMatRef::new_unchecked( + *N, + *M, + new_col_ptrs, + None, + new_row_indices, + ), + new_values.into_inner(), + ) + }, + N, + M, + ) + } + + /// Computes the transpose of the matrix `A` and returns a view over it. + /// + /// The result is stored in `new_col_ptrs`, `new_row_indices` and `new_values`. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + pub fn transpose<'a, I: Index, E: Entity>( + new_col_ptrs: &'a mut [I], + new_row_indices: &'a mut [I], + new_values: GroupFor, + A: SparseColMatRef<'_, I, E>, + stack: PodStack<'_>, + ) -> SparseColMatMut<'a, I, E> { + ghost::Size::with2(A.nrows(), A.ncols(), |M, N| { + ghost_transpose( + new_col_ptrs, + new_row_indices, + SliceGroupMut::new(new_values), + ghost::SparseColMatRef::new(A, M, N), + stack, + ) + .into_inner() + }) + } + + /// Computes the adjoint of the matrix `A` and returns a view over it. + /// + /// The result is stored in `new_col_ptrs`, `new_row_indices` and `new_values`. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + pub fn adjoint<'a, I: Index, E: ComplexField>( + new_col_ptrs: &'a mut [I], + new_row_indices: &'a mut [I], + new_values: GroupFor, + A: SparseColMatRef<'_, I, E>, + stack: PodStack<'_>, + ) -> SparseColMatMut<'a, I, E> { + ghost::Size::with2(A.nrows(), A.ncols(), |M, N| { + ghost_adjoint( + new_col_ptrs, + new_row_indices, + SliceGroupMut::new(new_values), + ghost::SparseColMatRef::new(A, M, N), + stack, + ) + .into_inner() + }) + } + + /// Computes the adjoint of the symbolic matrix `A` and returns a view over it. + /// + /// The result is stored in `new_col_ptrs`, `new_row_indices`. + /// + /// # Note + /// Allows unsorted matrices, producing a sorted output. Duplicate entries are kept, however. + pub fn adjoint_symbolic<'a, I: Index>( + new_col_ptrs: &'a mut [I], + new_row_indices: &'a mut [I], + A: SymbolicSparseColMatRef<'_, I>, + stack: PodStack<'_>, + ) -> SymbolicSparseColMatRef<'a, I> { + ghost::Size::with2(A.nrows(), A.ncols(), |M, N| { + ghost_adjoint_symbolic( + new_col_ptrs, + new_row_indices, + ghost::SymbolicSparseColMatRef::new(A, M, N), + stack, + ) + .into_inner() + }) + } +} + +impl core::ops::Index<(usize, usize)> for SparseColMatRef<'_, I, E> { + type Output = E; + + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + self.get(row, col).unwrap() + } +} + +impl core::ops::Index<(usize, usize)> for SparseRowMatRef<'_, I, E> { + type Output = E; + + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + self.get(row, col).unwrap() + } +} + +impl core::ops::Index<(usize, usize)> for SparseColMatMut<'_, I, E> { + type Output = E; + + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + self.rb().get(row, col).unwrap() + } +} + +impl core::ops::Index<(usize, usize)> for SparseRowMatMut<'_, I, E> { + type Output = E; + + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + self.rb().get(row, col).unwrap() + } +} + +impl core::ops::IndexMut<(usize, usize)> for SparseColMatMut<'_, I, E> { + #[track_caller] + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { + self.rb_mut().get_mut(row, col).unwrap() + } +} + +impl core::ops::IndexMut<(usize, usize)> for SparseRowMatMut<'_, I, E> { + #[track_caller] + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { + self.rb_mut().get_mut(row, col).unwrap() + } +} + +impl core::ops::Index<(usize, usize)> for SparseColMat { + type Output = E; + + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + self.as_ref().get(row, col).unwrap() + } +} + +impl core::ops::Index<(usize, usize)> for SparseRowMat { + type Output = E; + + #[track_caller] + fn index(&self, (row, col): (usize, usize)) -> &Self::Output { + self.as_ref().get(row, col).unwrap() + } +} + +impl core::ops::IndexMut<(usize, usize)> for SparseColMat { + #[track_caller] + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { + self.as_mut().get_mut(row, col).unwrap() + } +} + +impl core::ops::IndexMut<(usize, usize)> for SparseRowMat { + #[track_caller] + fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output { + self.as_mut().get_mut(row, col).unwrap() + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for SparseColMatRef<'_, I, E> { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Sparse(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::SparseAccess for SparseColMatRef<'_, I, E> { + #[inline] + fn nnz(&self) -> usize { + self.compute_nnz() + } + + #[inline] + fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { + let mut triplets = Vec::new(); + for j in 0..self.ncols() { + for (i, val) in self + .row_indices_of_col(j) + .zip(SliceGroup::<'_, E>::new(self.values_of_col(j)).into_ref_iter()) + { + triplets.push((i, j, val.read())) + } + } + triplets + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for SparseRowMatRef<'_, I, E> { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Sparse(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::SparseAccess for SparseRowMatRef<'_, I, E> { + #[inline] + fn nnz(&self) -> usize { + self.compute_nnz() + } + + #[inline] + fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { + let mut triplets = Vec::new(); + for i in 0..self.nrows() { + for (j, val) in self + .col_indices_of_row(i) + .zip(SliceGroup::<'_, E>::new(self.values_of_row(i)).into_ref_iter()) + { + triplets.push((i, j, val.read())) + } + } + triplets + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for SparseColMatMut<'_, I, E> { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Sparse(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::SparseAccess for SparseColMatMut<'_, I, E> { + #[inline] + fn nnz(&self) -> usize { + self.compute_nnz() + } + + #[inline] + fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { + self.rb().fetch_triplets() + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for SparseColMat { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Sparse(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::SparseAccess for SparseColMat { + #[inline] + fn nnz(&self) -> usize { + self.compute_nnz() + } + + #[inline] + fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { + self.as_ref().fetch_triplets() + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for SparseRowMatMut<'_, I, E> { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Sparse(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::SparseAccess for SparseRowMatMut<'_, I, E> { + #[inline] + fn nnz(&self) -> usize { + self.compute_nnz() + } + + #[inline] + fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { + self.rb().fetch_triplets() + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::Matrix for SparseRowMat { + #[inline] + fn rows(&self) -> usize { + self.nrows() + } + #[inline] + fn cols(&self) -> usize { + self.ncols() + } + #[inline] + fn access(&self) -> matrixcompare_core::Access<'_, E> { + matrixcompare_core::Access::Sparse(self) + } +} + +#[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] +impl matrixcompare_core::SparseAccess for SparseRowMat { + #[inline] + fn nnz(&self) -> usize { + self.compute_nnz() + } + + #[inline] + fn fetch_triplets(&self) -> Vec<(usize, usize, E)> { + self.as_ref().fetch_triplets() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::assert; + + #[test] + fn test_from_indices() { + let nrows = 5; + let ncols = 4; + + let indices = &[(0, 0), (1, 2), (0, 0), (1, 1), (0, 1), (3, 3), (3, 3usize)]; + let values = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0f64]; + + let triplets = &[ + (0, 0, 1.0), + (1, 2, 2.0), + (0, 0, 3.0), + (1, 1, 4.0), + (0, 1, 5.0), + (3, 3, 6.0), + (3, 3usize, 7.0), + ]; + + { + let mat = SymbolicSparseColMat::try_new_from_indices(nrows, ncols, indices); + assert!(mat.is_ok()); + + let (mat, order) = mat.unwrap(); + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); + assert!(mat.nnz_per_col() == None); + assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); + + let mat = + SparseColMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); + assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + + { + let mat = SparseColMat::try_new_from_triplets(nrows, ncols, triplets); + assert!(mat.is_ok()); + let mat = mat.unwrap(); + + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); + assert!(mat.nnz_per_col() == None); + assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); + assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + + { + let mat = SymbolicSparseRowMat::try_new_from_indices(nrows, ncols, indices); + assert!(mat.is_ok()); + + let (mat, order) = mat.unwrap(); + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); + assert!(mat.nnz_per_row() == None); + assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); + + let mat = + SparseRowMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); + assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + { + let mat = SparseRowMat::try_new_from_triplets(nrows, ncols, triplets); + assert!(mat.is_ok()); + + let mat = mat.unwrap(); + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); + assert!(mat.nnz_per_row() == None); + assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); + assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + } + + #[test] + fn test_from_nonnegative_indices() { + let nrows = 5; + let ncols = 4; + + let indices = &[ + (0, 0), + (1, 2), + (0, 0), + (1, 1), + (0, 1), + (-1, 2), + (-2, 1), + (-3, -4), + (3, 3), + (3, 3isize), + ]; + let values = &[ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + f64::NAN, + f64::NAN, + f64::NAN, + 6.0, + 7.0f64, + ]; + + let triplets = &[ + (0, 0, 1.0), + (1, 2, 2.0), + (0, 0, 3.0), + (1, 1, 4.0), + (0, 1, 5.0), + (-1, 2, f64::NAN), + (-2, 1, f64::NAN), + (-3, -4, f64::NAN), + (3, 3, 6.0), + (3, 3isize, 7.0), + ]; + + { + let mat = SymbolicSparseColMat::::try_new_from_nonnegative_indices( + nrows, ncols, indices, + ); + assert!(mat.is_ok()); + + let (mat, order) = mat.unwrap(); + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); + assert!(mat.nnz_per_col() == None); + assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); + + let mat = + SparseColMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); + assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + + { + let mat = + SparseColMat::::try_new_from_nonnegative_triplets(nrows, ncols, triplets); + assert!(mat.is_ok()); + let mat = mat.unwrap(); + + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.col_ptrs() == &[0, 1, 3, 4, 5]); + assert!(mat.nnz_per_col() == None); + assert!(mat.row_indices() == &[0, 0, 1, 1, 3]); + assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + + { + let mat = SymbolicSparseRowMat::::try_new_from_nonnegative_indices( + nrows, ncols, indices, + ); + assert!(mat.is_ok()); + + let (mat, order) = mat.unwrap(); + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); + assert!(mat.nnz_per_row() == None); + assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); + + let mat = + SparseRowMat::<_, f64>::new_from_order_and_values(mat, &order, values).unwrap(); + assert!(mat.values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + { + let mat = + SparseRowMat::::try_new_from_nonnegative_triplets(nrows, ncols, triplets); + assert!(mat.is_ok()); + + let mat = mat.unwrap(); + assert!(mat.nrows() == nrows); + assert!(mat.ncols() == ncols); + assert!(mat.row_ptrs() == &[0, 2, 4, 4, 5, 5]); + assert!(mat.nnz_per_row() == None); + assert!(mat.col_indices() == &[0, 1, 1, 2, 3]); + assert!(mat.as_ref().values() == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + { + let order = SymbolicSparseRowMat::::try_new_from_nonnegative_indices( + nrows, ncols, indices, + ) + .unwrap() + .1; + + let new_values = &mut [f64::NAN; 5]; + let mut mat = SparseRowMatMut::<'_, usize, f64>::new( + SymbolicSparseRowMatRef::new_checked( + nrows, + ncols, + &[0, 2, 4, 4, 5, 5], + None, + &[0, 1, 1, 2, 3], + ), + new_values, + ); + mat.fill_from_order_and_values(&order, values, FillMode::Replace); + + assert!(&*new_values == &[1.0 + 3.0, 5.0, 4.0, 2.0, 6.0 + 7.0]); + } + } + + #[test] + fn test_from_indices_oob_row() { + let nrows = 5; + let ncols = 4; + + let indices = &[ + (0, 0), + (1, 2), + (0, 0), + (1, 1), + (0, 1), + (3, 3), + (3, 3), + (5, 3usize), + ]; + let err = SymbolicSparseColMat::try_new_from_indices(nrows, ncols, indices); + assert!(err.is_err()); + let err = err.unwrap_err(); + assert!(err == CreationError::OutOfBounds { row: 5, col: 3 }); + } + + #[test] + fn test_from_indices_oob_col() { + let nrows = 5; + let ncols = 4; + + let indices = &[ + (0, 0), + (1, 2), + (0, 0), + (1, 1), + (0, 1), + (3, 3), + (3, 3), + (2, 4usize), + ]; + let err = SymbolicSparseColMat::try_new_from_indices(nrows, ncols, indices); + assert!(err.is_err()); + let err = err.unwrap_err(); + assert!(err == CreationError::OutOfBounds { row: 2, col: 4 }); + } + + #[test] + fn test_add_intersecting() { + let lhs = SparseColMat::::try_new_from_triplets( + 5, + 4, + &[ + (1, 0, 1.0), + (2, 1, 2.0), + (3, 2, 3.0), + (0, 0, 4.0), + (1, 1, 5.0), + (2, 2, 6.0), + (3, 3, 7.0), + (2, 0, 8.0), + (3, 1, 9.0), + (4, 2, 10.0), + (0, 2, 11.0), + (1, 3, 12.0), + (4, 0, 13.0), + ], + ) + .unwrap(); + + let rhs = SparseColMat::::try_new_from_triplets( + 5, + 4, + &[ + (1, 0, 10.0), + (2, 1, 14.0), + (3, 2, 15.0), + (4, 3, 16.0), + (0, 1, 17.0), + (1, 2, 18.0), + (2, 3, 19.0), + (3, 0, 20.0), + (4, 1, 21.0), + (0, 3, 22.0), + ], + ) + .unwrap(); + + let sum = ops::add(lhs.as_ref(), rhs.as_ref()).unwrap(); + assert!(sum.compute_nnz() == lhs.compute_nnz() + rhs.compute_nnz() - 3); + + for j in 0..4 { + for i in 0..5 { + assert!(sum.row_indices_of_col_raw(j)[i] == i); + } + } + + for j in 0..4 { + for i in 0..5 { + assert!( + sum[(i, j)] == lhs.get(i, j).unwrap_or(&0.0) + rhs.get(i, j).unwrap_or(&0.0) + ); + } + } + } + + #[test] + fn test_add_disjoint() { + let lhs = SparseColMat::::try_new_from_triplets( + 5, + 4, + &[ + (0, 0, 1.0), + (1, 1, 2.0), + (2, 2, 3.0), + (3, 3, 4.0), + (2, 0, 5.0), + (3, 1, 6.0), + (4, 2, 7.0), + (0, 2, 8.0), + (1, 3, 9.0), + (4, 0, 10.0), + ], + ) + .unwrap(); + + let rhs = SparseColMat::::try_new_from_triplets( + 5, + 4, + &[ + (1, 0, 11.0), + (2, 1, 12.0), + (3, 2, 13.0), + (4, 3, 14.0), + (0, 1, 15.0), + (1, 2, 16.0), + (2, 3, 17.0), + (3, 0, 18.0), + (4, 1, 19.0), + (0, 3, 20.0), + ], + ) + .unwrap(); + + let sum = ops::add(lhs.as_ref(), rhs.as_ref()).unwrap(); + assert!(sum.compute_nnz() == lhs.compute_nnz() + rhs.compute_nnz()); + + for j in 0..4 { + for i in 0..5 { + assert!(sum.row_indices_of_col_raw(j)[i] == i); + } + } + + for j in 0..4 { + for i in 0..5 { + assert!( + sum[(i, j)] == lhs.get(i, j).unwrap_or(&0.0) + rhs.get(i, j).unwrap_or(&0.0) + ); + } + } + } +} diff --git a/src/sparse/ops.rs b/src/sparse/ops.rs new file mode 100644 index 0000000000000000000000000000000000000000..0a8055ef25f2afb0c5e29e8c40bffe20d840de94 --- /dev/null +++ b/src/sparse/ops.rs @@ -0,0 +1,485 @@ +use super::*; +use crate::assert; + +/// Returns the resulting matrix obtained by applying `f` to the elements from `lhs` and `rhs`, +/// skipping entries that are unavailable in both of `lhs` and `rhs`. +/// +/// # Panics +/// Panics if `lhs` and `rhs` don't have matching dimensions. +#[track_caller] +pub fn binary_op( + lhs: SparseColMatRef<'_, I, LhsE>, + rhs: SparseColMatRef<'_, I, RhsE>, + f: impl FnMut(LhsE, RhsE) -> E, +) -> Result, FaerError> { + assert!(lhs.nrows() == rhs.nrows()); + assert!(lhs.ncols() == rhs.ncols()); + let mut f = f; + let m = lhs.nrows(); + let n = lhs.ncols(); + + let mut col_ptrs = try_zeroed::(n + 1)?; + + let mut nnz = 0usize; + for j in 0..n { + let lhs = lhs.row_indices_of_col_raw(j); + let rhs = rhs.row_indices_of_col_raw(j); + + let mut lhs_pos = 0usize; + let mut rhs_pos = 0usize; + while lhs_pos < lhs.len() && rhs_pos < rhs.len() { + let lhs = lhs[lhs_pos]; + let rhs = rhs[rhs_pos]; + + lhs_pos += (lhs <= rhs) as usize; + rhs_pos += (rhs <= lhs) as usize; + nnz += 1; + } + nnz += lhs.len() - lhs_pos; + nnz += rhs.len() - rhs_pos; + col_ptrs[j + 1] = I::truncate(nnz); + } + + if nnz > I::Signed::MAX.zx() { + return Err(FaerError::IndexOverflow); + } + + let mut row_indices = try_zeroed(nnz)?; + let mut values = VecGroup::::new(); + values + .try_reserve_exact(nnz) + .map_err(|_| FaerError::OutOfMemory)?; + values.resize(nnz, unsafe { core::mem::zeroed() }); + + let mut nnz = 0usize; + for j in 0..n { + let mut values = values.as_slice_mut(); + let lhs_values = SliceGroup::::new(lhs.values_of_col(j)); + let rhs_values = SliceGroup::::new(rhs.values_of_col(j)); + let lhs = lhs.row_indices_of_col_raw(j); + let rhs = rhs.row_indices_of_col_raw(j); + + let mut lhs_pos = 0usize; + let mut rhs_pos = 0usize; + while lhs_pos < lhs.len() && rhs_pos < rhs.len() { + let lhs = lhs[lhs_pos]; + let rhs = rhs[rhs_pos]; + + match lhs.cmp(&rhs) { + core::cmp::Ordering::Less => { + row_indices[nnz] = lhs; + values.write( + nnz, + f(lhs_values.read(lhs_pos), unsafe { core::mem::zeroed() }), + ); + } + core::cmp::Ordering::Equal => { + row_indices[nnz] = lhs; + values.write(nnz, f(lhs_values.read(lhs_pos), rhs_values.read(rhs_pos))); + } + core::cmp::Ordering::Greater => { + row_indices[nnz] = rhs; + values.write( + nnz, + f(unsafe { core::mem::zeroed() }, rhs_values.read(rhs_pos)), + ); + } + } + + lhs_pos += (lhs <= rhs) as usize; + rhs_pos += (rhs <= lhs) as usize; + nnz += 1; + } + row_indices[nnz..nnz + lhs.len() - lhs_pos].copy_from_slice(&lhs[lhs_pos..]); + for (mut dst, src) in values + .rb_mut() + .subslice(nnz..nnz + lhs.len() - lhs_pos) + .into_mut_iter() + .zip(lhs_values.subslice(lhs_pos..lhs.len()).into_ref_iter()) + { + dst.write(f(src.read(), unsafe { core::mem::zeroed() })); + } + nnz += lhs.len() - lhs_pos; + + row_indices[nnz..nnz + rhs.len() - rhs_pos].copy_from_slice(&rhs[rhs_pos..]); + for (mut dst, src) in values + .rb_mut() + .subslice(nnz..nnz + rhs.len() - rhs_pos) + .into_mut_iter() + .zip(rhs_values.subslice(rhs_pos..rhs.len()).into_ref_iter()) + { + dst.write(f(unsafe { core::mem::zeroed() }, src.read())); + } + nnz += rhs.len() - rhs_pos; + } + + Ok(SparseColMat::::new( + SymbolicSparseColMat::::new_checked(m, n, col_ptrs, None, row_indices), + values.into_inner(), + )) +} + +/// Returns the resulting matrix obtained by applying `f` to the elements from `dst` and `src` +/// skipping entries that are unavailable in both of them. +/// The sparsity patter of `dst` is unchanged. +/// +/// # Panics +/// Panics if `src` and `dst` don't have matching dimensions. +/// Panics if `src` contains an index that's unavailable in `dst`. +#[track_caller] +pub fn binary_op_assign_into( + dst: SparseColMatMut<'_, I, E>, + src: SparseColMatRef<'_, I, SrcE>, + f: impl FnMut(E, SrcE) -> E, +) { + { + assert!(dst.nrows() == src.nrows()); + assert!(dst.ncols() == src.ncols()); + + let n = dst.ncols(); + let mut dst = dst; + let mut f = f; + unsafe { + assert!(f(core::mem::zeroed(), core::mem::zeroed()) == core::mem::zeroed()); + } + + for j in 0..n { + let (dst, dst_val) = dst.rb_mut().into_parts_mut(); + + let mut dst_val = SliceGroupMut::::new(dst_val).subslice(dst.col_range(j)); + let src_val = SliceGroup::::new(src.values_of_col(j)); + + let dst = dst.row_indices_of_col_raw(j); + let src = src.row_indices_of_col_raw(j); + + let mut dst_pos = 0usize; + let mut src_pos = 0usize; + + while src_pos < src.len() { + let src = src[src_pos]; + + if dst[dst_pos] < src { + dst_val.write( + dst_pos, + f(dst_val.read(dst_pos), unsafe { core::mem::zeroed() }), + ); + dst_pos += 1; + continue; + } + + assert!(dst[dst_pos] == src); + + dst_val.write(dst_pos, f(dst_val.read(dst_pos), src_val.read(src_pos))); + + src_pos += 1; + dst_pos += 1; + } + while dst_pos < dst.len() { + dst_val.write( + dst_pos, + f(dst_val.read(dst_pos), unsafe { core::mem::zeroed() }), + ); + dst_pos += 1; + } + } + } +} + +/// Returns the resulting matrix obtained by applying `f` to the elements from `dst`, `lhs` and +/// `rhs`, skipping entries that are unavailable in all of `dst`, `lhs` and `rhs`. +/// The sparsity patter of `dst` is unchanged. +/// +/// # Panics +/// Panics if `lhs`, `rhs` and `dst` don't have matching dimensions. +/// Panics if `lhs` or `rhs` contains an index that's unavailable in `dst`. +#[track_caller] +pub fn ternary_op_assign_into( + dst: SparseColMatMut<'_, I, E>, + lhs: SparseColMatRef<'_, I, LhsE>, + rhs: SparseColMatRef<'_, I, RhsE>, + f: impl FnMut(E, LhsE, RhsE) -> E, +) { + { + assert!(dst.nrows() == lhs.nrows()); + assert!(dst.ncols() == lhs.ncols()); + assert!(dst.nrows() == rhs.nrows()); + assert!(dst.ncols() == rhs.ncols()); + + let n = dst.ncols(); + let mut dst = dst; + let mut f = f; + unsafe { + assert!( + f( + core::mem::zeroed(), + core::mem::zeroed(), + core::mem::zeroed() + ) == core::mem::zeroed() + ); + } + + for j in 0..n { + let (dst, dst_val) = dst.rb_mut().into_parts_mut(); + + let mut dst_val = SliceGroupMut::::new(dst_val); + let lhs_val = SliceGroup::::new(lhs.values_of_col(j)); + let rhs_val = SliceGroup::::new(rhs.values_of_col(j)); + + let dst = dst.row_indices_of_col_raw(j); + let rhs = rhs.row_indices_of_col_raw(j); + let lhs = lhs.row_indices_of_col_raw(j); + + let mut dst_pos = 0usize; + let mut lhs_pos = 0usize; + let mut rhs_pos = 0usize; + + while lhs_pos < lhs.len() && rhs_pos < rhs.len() { + let lhs = lhs[lhs_pos]; + let rhs = rhs[rhs_pos]; + + if dst[dst_pos] < Ord::min(lhs, rhs) { + dst_val.write( + dst_pos, + f( + dst_val.read(dst_pos), + unsafe { core::mem::zeroed() }, + unsafe { core::mem::zeroed() }, + ), + ); + dst_pos += 1; + continue; + } + + assert!(dst[dst_pos] == Ord::min(lhs, rhs)); + + match lhs.cmp(&rhs) { + core::cmp::Ordering::Less => { + dst_val.write( + dst_pos, + f(dst_val.read(dst_pos), lhs_val.read(lhs_pos), unsafe { + core::mem::zeroed() + }), + ); + } + core::cmp::Ordering::Equal => { + dst_val.write( + dst_pos, + f( + dst_val.read(dst_pos), + lhs_val.read(lhs_pos), + rhs_val.read(rhs_pos), + ), + ); + } + core::cmp::Ordering::Greater => { + dst_val.write( + dst_pos, + f( + dst_val.read(dst_pos), + unsafe { core::mem::zeroed() }, + rhs_val.read(rhs_pos), + ), + ); + } + } + + lhs_pos += (lhs <= rhs) as usize; + rhs_pos += (rhs <= lhs) as usize; + dst_pos += 1; + } + while lhs_pos < lhs.len() { + let lhs = lhs[lhs_pos]; + if dst[dst_pos] < lhs { + dst_val.write( + dst_pos, + f( + dst_val.read(dst_pos), + unsafe { core::mem::zeroed() }, + unsafe { core::mem::zeroed() }, + ), + ); + dst_pos += 1; + continue; + } + dst_val.write( + dst_pos, + f(dst_val.read(dst_pos), lhs_val.read(lhs_pos), unsafe { + core::mem::zeroed() + }), + ); + lhs_pos += 1; + dst_pos += 1; + } + while rhs_pos < rhs.len() { + let rhs = rhs[rhs_pos]; + if dst[dst_pos] < rhs { + dst_val.write( + dst_pos, + f( + dst_val.read(dst_pos), + unsafe { core::mem::zeroed() }, + unsafe { core::mem::zeroed() }, + ), + ); + dst_pos += 1; + continue; + } + dst_val.write( + dst_pos, + f( + dst_val.read(dst_pos), + unsafe { core::mem::zeroed() }, + rhs_val.read(rhs_pos), + ), + ); + rhs_pos += 1; + dst_pos += 1; + } + while rhs_pos < rhs.len() { + let rhs = rhs[rhs_pos]; + dst_pos += dst[dst_pos..].binary_search(&rhs).unwrap(); + dst_val.write( + dst_pos, + f( + dst_val.read(dst_pos), + unsafe { core::mem::zeroed() }, + rhs_val.read(rhs_pos), + ), + ); + rhs_pos += 1; + } + } + } +} + +/// Returns the sparsity pattern containing the union of those of `lhs` and `rhs`. +/// +/// # Panics +/// Panics if `lhs` and `rhs` don't have mathcing dimensions. +#[track_caller] +#[inline] +pub fn union_symbolic( + lhs: SymbolicSparseColMatRef<'_, I>, + rhs: SymbolicSparseColMatRef<'_, I>, +) -> Result, FaerError> { + Ok(binary_op( + SparseColMatRef::::new(lhs, Symbolic::materialize(lhs.compute_nnz())), + SparseColMatRef::::new(rhs, Symbolic::materialize(rhs.compute_nnz())), + #[inline(always)] + |_, _| Symbolic, + )? + .into_parts() + .0) +} + +/// Returns the sum of `lhs` and `rhs`. +/// +/// # Panics +/// Panics if `lhs` and `rhs` don't have mathcing dimensions. +#[track_caller] +#[inline] +pub fn add< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, +>( + lhs: SparseColMatRef<'_, I, LhsE>, + rhs: SparseColMatRef<'_, I, RhsE>, +) -> Result, FaerError> { + binary_op(lhs, rhs, |lhs, rhs| { + lhs.canonicalize().faer_add(rhs.canonicalize()) + }) +} + +/// Returns the difference of `lhs` and `rhs`. +/// +/// # Panics +/// Panics if `lhs` and `rhs` don't have matching dimensions. +#[track_caller] +#[inline] +pub fn sub< + I: Index, + LhsE: Conjugate, + RhsE: Conjugate, + E: ComplexField, +>( + lhs: SparseColMatRef<'_, I, LhsE>, + rhs: SparseColMatRef<'_, I, RhsE>, +) -> Result, FaerError> { + binary_op(lhs, rhs, |lhs, rhs| { + lhs.canonicalize().faer_sub(rhs.canonicalize()) + }) +} + +/// Computes the sum of `dst` and `src` and stores the result in `dst` without changing its +/// symbolic structure. +/// +/// # Panics +/// Panics if `dst` and `rhs` don't have matching dimensions. +/// Panics if `rhs` contains an index that's unavailable in `dst`. +pub fn add_assign>( + dst: SparseColMatMut<'_, I, E>, + rhs: SparseColMatRef<'_, I, RhsE>, +) { + binary_op_assign_into(dst, rhs, |dst, rhs| dst.faer_add(rhs.canonicalize())) +} + +/// Computes the difference of `dst` and `src` and stores the result in `dst` without changing +/// its symbolic structure. +/// +/// # Panics +/// Panics if `dst` and `rhs` don't have matching dimensions. +/// Panics if `rhs` contains an index that's unavailable in `dst`. +pub fn sub_assign>( + dst: SparseColMatMut<'_, I, E>, + rhs: SparseColMatRef<'_, I, RhsE>, +) { + binary_op_assign_into(dst, rhs, |dst, rhs| dst.faer_sub(rhs.canonicalize())) +} + +/// Computes the sum of `lhs` and `rhs`, storing the result in `dst` without changing its +/// symbolic structure. +/// +/// # Panics +/// Panics if `dst`, `lhs` and `rhs` don't have matching dimensions. +/// Panics if `lhs` or `rhs` contains an index that's unavailable in `dst`. +#[track_caller] +#[inline] +pub fn add_into< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, +>( + dst: SparseColMatMut<'_, I, E>, + lhs: SparseColMatRef<'_, I, LhsE>, + rhs: SparseColMatRef<'_, I, RhsE>, +) { + ternary_op_assign_into(dst, lhs, rhs, |_, lhs, rhs| { + lhs.canonicalize().faer_add(rhs.canonicalize()) + }) +} + +/// Computes the difference of `lhs` and `rhs`, storing the result in `dst` without changing its +/// symbolic structure. +/// +/// # Panics +/// Panics if `dst`, `lhs` and `rhs` don't have matching dimensions. +/// Panics if `lhs` or `rhs` contains an index that's unavailable in `dst`. +#[track_caller] +#[inline] +pub fn sub_into< + I: Index, + E: ComplexField, + LhsE: Conjugate, + RhsE: Conjugate, +>( + dst: SparseColMatMut<'_, I, E>, + lhs: SparseColMatRef<'_, I, LhsE>, + rhs: SparseColMatRef<'_, I, RhsE>, +) { + ternary_op_assign_into(dst, lhs, rhs, |_, lhs, rhs| { + lhs.canonicalize().faer_sub(rhs.canonicalize()) + }) +} diff --git a/src/utils/constrained/mat.rs b/src/utils/constrained/mat.rs new file mode 100644 index 0000000000000000000000000000000000000000..840bbde0e2e044c2fdd7de9e585897de96d00fdf --- /dev/null +++ b/src/utils/constrained/mat.rs @@ -0,0 +1,211 @@ +use super::*; +use crate::{assert, mat}; + +/// Immutable dense matrix view with dimensions equal to the values tied to `('nrows, 'ncols)`. +#[repr(transparent)] +pub struct MatRef<'nrows, 'ncols, 'a, E: Entity>( + Branded<'ncols, Branded<'nrows, mat::MatRef<'a, E>>>, +); +/// Mutable dense matrix view with dimensions equal to the values tied to `('nrows, 'ncols)`. +#[repr(transparent)] +pub struct MatMut<'nrows, 'ncols, 'a, E: Entity>( + Branded<'ncols, Branded<'nrows, mat::MatMut<'a, E>>>, +); + +impl<'nrows, 'ncols, 'a, E: Entity> MatRef<'nrows, 'ncols, 'a, E> { + /// Returns a new matrix view after checking that its dimensions match the + /// dimensions tied to `('nrows, 'ncols)`. + #[inline] + #[track_caller] + pub fn new(inner: mat::MatRef<'a, E>, nrows: Size<'nrows>, ncols: Size<'ncols>) -> Self { + assert!(all( + inner.nrows() == nrows.into_inner(), + inner.ncols() == ncols.into_inner(), + )); + Self(Branded { + __marker: PhantomData, + inner: Branded { + __marker: PhantomData, + inner, + }, + }) + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> Size<'nrows> { + unsafe { Size::new_raw_unchecked(self.0.inner.inner.nrows()) } + } + + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> Size<'ncols> { + unsafe { Size::new_raw_unchecked(self.0.inner.inner.ncols()) } + } + + /// Returns the unconstrained matrix. + #[inline] + pub fn into_inner(self) -> mat::MatRef<'a, E> { + self.0.inner.inner + } + + /// Returns the element at position `(i, j)`. + #[inline] + #[track_caller] + pub fn read(&self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>) -> E { + unsafe { + self.0 + .inner + .inner + .read_unchecked(i.into_inner(), j.into_inner()) + } + } +} + +impl<'nrows, 'ncols, 'a, E: Entity> MatMut<'nrows, 'ncols, 'a, E> { + /// Returns a new matrix view after checking that its dimensions match the + /// dimensions tied to `('nrows, 'ncols)`. + #[inline] + #[track_caller] + pub fn new(inner: mat::MatMut<'a, E>, nrows: Size<'nrows>, ncols: Size<'ncols>) -> Self { + assert!(all( + inner.nrows() == nrows.into_inner(), + inner.ncols() == ncols.into_inner(), + )); + Self(Branded { + __marker: PhantomData, + inner: Branded { + __marker: PhantomData, + inner, + }, + }) + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> Size<'nrows> { + unsafe { Size::new_raw_unchecked(self.0.inner.inner.nrows()) } + } + + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> Size<'ncols> { + unsafe { Size::new_raw_unchecked(self.0.inner.inner.ncols()) } + } + + /// Returns the unconstrained matrix. + #[inline] + pub fn into_inner(self) -> mat::MatMut<'a, E> { + self.0.inner.inner + } + + /// Returns the element at position `(i, j)`. + #[inline] + #[track_caller] + pub fn read(&self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>) -> E { + unsafe { + self.0 + .inner + .inner + .read_unchecked(i.into_inner(), j.into_inner()) + } + } + + /// Writes `value` to the location at position `(i, j)`. + #[inline] + #[track_caller] + pub fn write(&mut self, i: Idx<'nrows, usize>, j: Idx<'ncols, usize>, value: E) { + unsafe { + self.0 + .inner + .inner + .write_unchecked(i.into_inner(), j.into_inner(), value) + }; + } +} + +impl Clone for MatRef<'_, '_, '_, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl Copy for MatRef<'_, '_, '_, E> {} + +impl<'nrows, 'ncols, 'a, E: Entity> IntoConst for MatRef<'nrows, 'ncols, 'a, E> { + type Target = MatRef<'nrows, 'ncols, 'a, E>; + #[inline] + fn into_const(self) -> Self::Target { + self + } +} +impl<'nrows, 'ncols, 'a, 'short, E: Entity> Reborrow<'short> for MatRef<'nrows, 'ncols, 'a, E> { + type Target = MatRef<'nrows, 'ncols, 'short, E>; + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} +impl<'nrows, 'ncols, 'a, 'short, E: Entity> ReborrowMut<'short> for MatRef<'nrows, 'ncols, 'a, E> { + type Target = MatRef<'nrows, 'ncols, 'short, E>; + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'nrows, 'ncols, 'a, E: Entity> IntoConst for MatMut<'nrows, 'ncols, 'a, E> { + type Target = MatRef<'nrows, 'ncols, 'a, E>; + #[inline] + fn into_const(self) -> Self::Target { + let inner = self.0.inner.inner.into_const(); + MatRef(Branded { + __marker: PhantomData, + inner: Branded { + __marker: PhantomData, + inner, + }, + }) + } +} +impl<'nrows, 'ncols, 'a, 'short, E: Entity> Reborrow<'short> for MatMut<'nrows, 'ncols, 'a, E> { + type Target = MatRef<'nrows, 'ncols, 'short, E>; + #[inline] + fn rb(&'short self) -> Self::Target { + let inner = self.0.inner.inner.rb(); + MatRef(Branded { + __marker: PhantomData, + inner: Branded { + __marker: PhantomData, + inner, + }, + }) + } +} +impl<'nrows, 'ncols, 'a, 'short, E: Entity> ReborrowMut<'short> for MatMut<'nrows, 'ncols, 'a, E> { + type Target = MatMut<'nrows, 'ncols, 'short, E>; + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + let inner = self.0.inner.inner.rb_mut(); + MatMut(Branded { + __marker: PhantomData, + inner: Branded { + __marker: PhantomData, + inner, + }, + }) + } +} + +impl Debug for MatRef<'_, '_, '_, E> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.inner.fmt(f) + } +} +impl Debug for MatMut<'_, '_, '_, E> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.inner.fmt(f) + } +} diff --git a/src/utils/constrained/mod.rs b/src/utils/constrained/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..e3455f51c4a3bda29dcc373acafa65cd4dab4854 --- /dev/null +++ b/src/utils/constrained/mod.rs @@ -0,0 +1,701 @@ +use super::*; +use crate::{assert, debug_assert, Index, SignedIndex}; +use core::{fmt::Debug, marker::PhantomData, ops::Range}; +use faer_entity::*; +use reborrow::*; + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +struct Branded<'a, T: ?Sized> { + __marker: PhantomData &'a ()>, + inner: T, +} + +/// `usize` value tied to the lifetime `'n`. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct Size<'n>(Branded<'n, usize>); + +/// `I` value smaller than the size corresponding to the lifetime `'n`. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct Idx<'n, I>(Branded<'n, I>); + +/// `I` value smaller or equal to the size corresponding to the lifetime `'n`. +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct IdxInclusive<'n, I>(Branded<'n, I>); + +/// `I` value smaller than the size corresponding to the lifetime `'n`, or `None`. +#[derive(Copy, Clone, PartialEq, Eq)] +#[repr(transparent)] +pub struct MaybeIdx<'n, I: Index>(Branded<'n, I>); + +impl core::ops::Deref for Size<'_> { + type Target = usize; + #[inline] + fn deref(&self) -> &Self::Target { + &self.0.inner + } +} +impl core::ops::Deref for Idx<'_, I> { + type Target = I; + #[inline] + fn deref(&self) -> &Self::Target { + &self.0.inner + } +} +impl core::ops::Deref for MaybeIdx<'_, I> { + type Target = I::Signed; + #[inline] + fn deref(&self) -> &Self::Target { + bytemuck::cast_ref(&self.0.inner) + } +} +impl core::ops::Deref for IdxInclusive<'_, I> { + type Target = I; + #[inline] + fn deref(&self) -> &Self::Target { + &self.0.inner + } +} + +/// Array of length equal to the value tied to `'n`. +#[derive(PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct Array<'n, T>(Branded<'n, [T]>); + +/// Dense matrices with compile-time access checks. +pub mod mat; +/// Permutations with compile-time checks. +pub mod perm; +/// Sparse matrices with compile-time access checks. +pub mod sparse; + +/// Immutable array group of length equal to the value tied to `'n`. +pub struct ArrayGroup<'n, 'a, E: Entity>(Branded<'n, slice::SliceGroup<'a, E>>); +/// Mutable array group of length equal to the value tied to `'n`. +pub struct ArrayGroupMut<'n, 'a, E: Entity>(Branded<'n, slice::SliceGroupMut<'a, E>>); + +impl Debug for ArrayGroup<'_, '_, E> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.fmt(f) + } +} +impl Debug for ArrayGroupMut<'_, '_, E> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.fmt(f) + } +} + +impl Copy for ArrayGroup<'_, '_, E> {} +impl Clone for ArrayGroup<'_, '_, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroup<'n, 'a, E> { + type Target = ArrayGroup<'n, 'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroup<'n, 'a, E> { + type Target = ArrayGroup<'n, 'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::ReborrowMut<'short> for ArrayGroupMut<'n, 'a, E> { + type Target = ArrayGroupMut<'n, 'short, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + ArrayGroupMut(Branded { + __marker: PhantomData, + inner: self.0.inner.rb_mut(), + }) + } +} + +impl<'short, 'n, 'a, E: Entity> reborrow::Reborrow<'short> for ArrayGroupMut<'n, 'a, E> { + type Target = ArrayGroup<'n, 'short, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + ArrayGroup(Branded { + __marker: PhantomData, + inner: self.0.inner.rb(), + }) + } +} + +impl<'n, 'a, E: Entity> ArrayGroupMut<'n, 'a, E> { + /// Returns an array group with length after checking that its length matches + /// the value tied to `'n`. + #[inline] + pub fn new(slice: GroupFor, len: Size<'n>) -> Self { + let slice = slice::SliceGroupMut::<'_, E>::new(slice); + assert!(slice.rb().len() == len.into_inner()); + ArrayGroupMut(Branded { + __marker: PhantomData, + inner: slice, + }) + } + + /// Returns the unconstrained slice. + #[inline] + pub fn into_slice(self) -> GroupFor { + self.0.inner.into_inner() + } + + /// Returns a subslice at from the range start to its end. + #[inline] + pub fn subslice(self, range: Range>) -> GroupFor { + unsafe { + slice::SliceGroupMut::<'_, E>::new(self.into_slice()) + .subslice_unchecked(range.start.into_inner()..range.end.into_inner()) + .into_inner() + } + } + + /// Read the element at position `j`. + #[inline] + pub fn read(&self, j: Idx<'n, usize>) -> E { + self.rb().read(j) + } + + /// Write `value` to the location at position `j`. + #[inline] + pub fn write(&mut self, j: Idx<'n, usize>, value: E) { + unsafe { + slice::SliceGroupMut::new(self.rb_mut().into_slice()) + .write_unchecked(j.into_inner(), value) + } + } +} + +impl<'n, 'a, E: Entity> ArrayGroup<'n, 'a, E> { + /// Returns an array group with length after checking that its length matches + /// the value tied to `'n`. + #[inline] + pub fn new(slice: GroupFor, len: Size<'n>) -> Self { + let slice = slice::SliceGroup::<'_, E>::new(slice); + assert!(slice.rb().len() == len.into_inner()); + ArrayGroup(Branded { + __marker: PhantomData, + inner: slice, + }) + } + + /// Returns the unconstrained slice. + #[inline] + pub fn into_slice(self) -> GroupFor { + self.0.inner.into_inner() + } + + /// Returns a subslice at from the range start to its end. + #[inline] + pub fn subslice(self, range: Range>) -> GroupFor { + unsafe { + slice::SliceGroup::<'_, E>::new(self.into_slice()) + .subslice_unchecked(range.start.into_inner()..range.end.into_inner()) + .into_inner() + } + } + + /// Read the element at position `j`. + #[inline] + pub fn read(&self, j: Idx<'n, usize>) -> E { + unsafe { slice::SliceGroup::new(self.into_slice()).read_unchecked(j.into_inner()) } + } +} + +impl<'size> Size<'size> { + /// Create a new [`Size`] with a lifetime tied to `n`. + #[track_caller] + #[inline] + pub fn with(n: usize, f: impl for<'n> FnOnce(Size<'n>) -> R) -> R { + f(Size(Branded { + __marker: PhantomData, + inner: n, + })) + } + + /// Create two new [`Size`] with lifetimes tied to `m` and `n`. + #[track_caller] + #[inline] + pub fn with2(m: usize, n: usize, f: impl for<'m, 'n> FnOnce(Size<'m>, Size<'n>) -> R) -> R { + f( + Size(Branded { + __marker: PhantomData, + inner: m, + }), + Size(Branded { + __marker: PhantomData, + inner: n, + }), + ) + } + + /// Create a new [`Size`] tied to the lifetime `'n`. + #[inline] + pub unsafe fn new_raw_unchecked(n: usize) -> Self { + Size(Branded { + __marker: PhantomData, + inner: n, + }) + } + + /// Returns the unconstrained value. + #[inline] + pub fn into_inner(self) -> usize { + self.0.inner + } + + /// Returns an iterator of the indices smaller than `self`. + #[inline] + pub fn indices(self) -> impl DoubleEndedIterator> { + (0..self.0.inner).map(|i| unsafe { Idx::new_raw_unchecked(i) }) + } + + /// Check that the index is bounded by `self`, or panic otherwise. + #[track_caller] + #[inline] + pub fn check(self, idx: I) -> Idx<'size, I> { + Idx::new_checked(idx, self) + } + + /// Check that the index is bounded by `self`, or return `None` otherwise. + #[inline] + pub fn try_check(self, idx: I) -> Option> { + if idx.zx() < self.into_inner() { + Some(Idx(Branded { + __marker: PhantomData, + inner: idx, + })) + } else { + None + } + } +} + +impl<'n> Idx<'n, usize> { + /// Truncate `self` to a smaller type `I`. + pub fn truncate(self) -> Idx<'n, I> { + unsafe { Idx::new_raw_unchecked(I::truncate(self.into_inner())) } + } +} + +impl<'n, I: Index> Idx<'n, I> { + /// Returns a new index after asserting that it's bounded by `size`. + #[track_caller] + #[inline] + pub fn new_checked(idx: I, size: Size<'n>) -> Self { + assert!(idx.zx() < size.into_inner()); + Self(Branded { + __marker: PhantomData, + inner: idx, + }) + } + /// Returns a new index without asserting that it's bounded by `size`. + #[track_caller] + #[inline] + pub unsafe fn new_unchecked(idx: I, size: Size<'n>) -> Self { + debug_assert!(idx.zx() < size.into_inner()); + Self(Branded { + __marker: PhantomData, + inner: idx, + }) + } + + /// Returns a new index without asserting that it's bounded by the value tied to the + /// lifetime `'n`. + #[inline] + pub unsafe fn new_raw_unchecked(idx: I) -> Self { + Self(Branded { + __marker: PhantomData, + inner: idx, + }) + } + + /// Returns the unconstrained value. + #[inline] + pub fn into_inner(self) -> I { + self.0.inner + } + + /// Zero extend the value. + #[inline] + pub fn zx(self) -> Idx<'n, usize> { + unsafe { Idx::new_raw_unchecked(self.0.inner.zx()) } + } + + /// Unimplemented: Sign extend the value. + #[inline] + pub fn sx(self) -> ! { + unimplemented!() + } + + /// Returns the index, bounded inclusively by the value tied to `'n`. + #[inline] + pub fn to_inclusive(self) -> IdxInclusive<'n, I> { + unsafe { IdxInclusive::new_raw_unchecked(self.into_inner()) } + } + /// Returns the next index, bounded inclusively by the value tied to `'n`. + #[inline] + pub fn next(self) -> IdxInclusive<'n, I> { + unsafe { IdxInclusive::new_raw_unchecked(self.into_inner() + I::truncate(1)) } + } + + /// Assert that the values of `slice` are all bounded by `size`. + #[track_caller] + #[inline] + pub fn from_slice_mut_checked<'a>(slice: &'a mut [I], size: Size<'n>) -> &'a mut [Idx<'n, I>] { + Self::from_slice_ref_checked(slice, size); + unsafe { &mut *(slice as *mut _ as *mut _) } + } + + /// Assume that the values of `slice` are all bounded by the value tied to `'n`. + #[track_caller] + #[inline] + pub unsafe fn from_slice_mut_unchecked<'a>(slice: &'a mut [I]) -> &'a mut [Idx<'n, I>] { + unsafe { &mut *(slice as *mut _ as *mut _) } + } + + /// Assert that the values of `slice` are all bounded by `size`. + #[track_caller] + pub fn from_slice_ref_checked<'a>(slice: &'a [I], size: Size<'n>) -> &'a [Idx<'n, I>] { + for &idx in slice { + Self::new_checked(idx, size); + } + unsafe { &*(slice as *const _ as *const _) } + } + + /// Assume that the values of `slice` are all bounded by the value tied to `'n`. + #[track_caller] + #[inline] + pub unsafe fn from_slice_ref_unchecked<'a>(slice: &'a [I]) -> &'a [Idx<'n, I>] { + unsafe { &*(slice as *const _ as *const _) } + } +} + +impl<'n, I: Index> MaybeIdx<'n, I> { + /// Returns an index value. + #[inline] + pub fn from_index(idx: Idx<'n, I>) -> Self { + unsafe { Self::new_raw_unchecked(idx.into_inner()) } + } + /// Returns a `None` value. + #[inline] + pub fn none() -> Self { + unsafe { Self::new_raw_unchecked(I::truncate(usize::MAX)) } + } + + /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. + #[inline] + pub fn new_checked(idx: I::Signed, size: Size<'n>) -> Self { + assert!((idx.sx() as isize) < size.into_inner() as isize); + Self(Branded { + __marker: PhantomData, + inner: I::from_signed(idx), + }) + } + + /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. + #[inline] + pub unsafe fn new_unchecked(idx: I::Signed, size: Size<'n>) -> Self { + debug_assert!((idx.sx() as isize) < size.into_inner() as isize); + Self(Branded { + __marker: PhantomData, + inner: I::from_signed(idx), + }) + } + + /// Returns a constrained index value if `idx` is nonnegative, `None` otherwise. + #[inline] + pub unsafe fn new_raw_unchecked(idx: I) -> Self { + Self(Branded { + __marker: PhantomData, + inner: idx, + }) + } + + /// Returns the inner value. + #[inline] + pub fn into_inner(self) -> I { + self.0.inner + } + + /// Returns the index if available, or `None` otherwise. + #[inline] + pub fn idx(self) -> Option> { + if self.0.inner.to_signed() >= I::Signed::truncate(0) { + Some(unsafe { Idx::new_raw_unchecked(self.into_inner()) }) + } else { + None + } + } + + /// Unimplemented: Zero extend the value. + #[inline] + pub fn zx(self) -> ! { + unimplemented!() + } + + /// Sign extend the value. + #[inline] + pub fn sx(self) -> MaybeIdx<'n, usize> { + unsafe { MaybeIdx::new_raw_unchecked(self.0.inner.to_signed().sx()) } + } + + /// Assert that the values of `slice` are all bounded by `size`. + #[track_caller] + #[inline] + pub fn from_slice_mut_checked<'a>( + slice: &'a mut [I::Signed], + size: Size<'n>, + ) -> &'a mut [MaybeIdx<'n, I>] { + Self::from_slice_ref_checked(slice, size); + unsafe { &mut *(slice as *mut _ as *mut _) } + } + + /// Assume that the values of `slice` are all bounded by the value tied to `'n`. + #[track_caller] + #[inline] + pub unsafe fn from_slice_mut_unchecked<'a>( + slice: &'a mut [I::Signed], + ) -> &'a mut [MaybeIdx<'n, I>] { + unsafe { &mut *(slice as *mut _ as *mut _) } + } + + /// Assert that the values of `slice` are all bounded by `size`. + #[track_caller] + pub fn from_slice_ref_checked<'a>( + slice: &'a [I::Signed], + size: Size<'n>, + ) -> &'a [MaybeIdx<'n, I>] { + for &idx in slice { + Self::new_checked(idx, size); + } + unsafe { &*(slice as *const _ as *const _) } + } + + /// Convert a constrained slice to an unconstrained one. + #[track_caller] + pub fn as_slice_ref<'a>(slice: &'a [MaybeIdx<'n, I>]) -> &'a [I::Signed] { + unsafe { &*(slice as *const _ as *const _) } + } + + /// Assume that the values of `slice` are all bounded by the value tied to `'n`. + #[track_caller] + #[inline] + pub unsafe fn from_slice_ref_unchecked<'a>(slice: &'a [I::Signed]) -> &'a [MaybeIdx<'n, I>] { + unsafe { &*(slice as *const _ as *const _) } + } +} + +impl<'n> IdxInclusive<'n, usize> { + /// Returns an iterator over constrained indices from `0` to `self` (exclusive). + #[inline] + pub fn range_to(self, last: Self) -> impl DoubleEndedIterator> { + (*self..*last).map( + #[inline(always)] + |idx| unsafe { Idx::new_raw_unchecked(idx) }, + ) + } +} + +impl<'n, I: Index> IdxInclusive<'n, I> { + /// Returns a constrained inclusive index after checking that it's bounded (inclusively) by + /// `size`. + #[inline] + pub fn new_checked(idx: I, size: Size<'n>) -> Self { + assert!(idx.zx() <= size.into_inner()); + Self(Branded { + __marker: PhantomData, + inner: idx, + }) + } + /// Returns a constrained inclusive index, assuming that it's bounded (inclusively) by + /// `size`. + #[inline] + pub unsafe fn new_unchecked(idx: I, size: Size<'n>) -> Self { + debug_assert!(idx.zx() <= size.into_inner()); + Self(Branded { + __marker: PhantomData, + inner: idx, + }) + } + + /// Returns a constrained inclusive index, assuming that it's bounded (inclusively) by + /// the size tied to `'n`. + #[inline] + pub unsafe fn new_raw_unchecked(idx: I) -> Self { + Self(Branded { + __marker: PhantomData, + inner: idx, + }) + } + + /// Returns the unconstrained value. + #[inline] + pub fn into_inner(self) -> I { + self.0.inner + } + + /// Unimplemented: Sign extend the value. + #[inline] + pub fn sx(self) -> ! { + unimplemented!() + } + /// Unimplemented: Zero extend the value. + #[inline] + pub fn zx(self) -> ! { + unimplemented!() + } +} + +impl<'n, T> Array<'n, T> { + /// Returns a constrained array after checking that its length matches `size`. + #[inline] + #[track_caller] + pub fn from_ref<'a>(slice: &'a [T], size: Size<'n>) -> &'a Self { + assert!(slice.len() == size.into_inner()); + unsafe { &*(slice as *const [T] as *const Self) } + } + + /// Returns a constrained array after checking that its length matches `size`. + #[inline] + #[track_caller] + pub fn from_mut<'a>(slice: &'a mut [T], size: Size<'n>) -> &'a mut Self { + assert!(slice.len() == size.into_inner()); + unsafe { &mut *(slice as *mut [T] as *mut Self) } + } + + /// Returns the unconstrained slice. + #[inline] + #[track_caller] + pub fn as_ref(&self) -> &[T] { + unsafe { &*(self as *const _ as *const _) } + } + + /// Returns the unconstrained slice. + #[inline] + #[track_caller] + pub fn as_mut<'a>(&mut self) -> &'a mut [T] { + unsafe { &mut *(self as *mut _ as *mut _) } + } + + /// Returns the length of `self`. + #[inline] + pub fn len(&self) -> Size<'n> { + unsafe { Size::new_raw_unchecked(self.0.inner.len()) } + } +} + +impl Debug for Size<'_> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.fmt(f) + } +} +impl Debug for Idx<'_, I> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.fmt(f) + } +} +impl Debug for IdxInclusive<'_, I> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.fmt(f) + } +} +impl Debug for MaybeIdx<'_, I> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + #[derive(Debug)] + struct None; + + match self.idx() { + Some(idx) => idx.fmt(f), + Option::None => None.fmt(f), + } + } +} +impl Debug for Array<'_, T> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.fmt(f) + } +} + +impl<'n, T> core::ops::Index>> for Array<'n, T> { + type Output = [T]; + #[track_caller] + fn index(&self, idx: Range>) -> &Self::Output { + #[cfg(debug_assertions)] + { + &self.0.inner[idx.start.into_inner()..idx.end.into_inner()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.0 + .inner + .get_unchecked(idx.start.into_inner()..idx.end.into_inner()) + } + } +} +impl<'n, T> core::ops::IndexMut>> for Array<'n, T> { + #[track_caller] + fn index_mut(&mut self, idx: Range>) -> &mut Self::Output { + #[cfg(debug_assertions)] + { + &mut self.0.inner[idx.start.into_inner()..idx.end.into_inner()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.0 + .inner + .get_unchecked_mut(idx.start.into_inner()..idx.end.into_inner()) + } + } +} +impl<'n, T> core::ops::Index> for Array<'n, T> { + type Output = T; + #[track_caller] + fn index(&self, idx: Idx<'n, usize>) -> &Self::Output { + #[cfg(debug_assertions)] + { + &self.0.inner[idx.into_inner()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.0.inner.get_unchecked(idx.into_inner()) + } + } +} +impl<'n, T> core::ops::IndexMut> for Array<'n, T> { + #[track_caller] + fn index_mut(&mut self, idx: Idx<'n, usize>) -> &mut Self::Output { + #[cfg(debug_assertions)] + { + &mut self.0.inner[idx.into_inner()] + } + #[cfg(not(debug_assertions))] + unsafe { + self.0.inner.get_unchecked_mut(idx.into_inner()) + } + } +} diff --git a/src/utils/constrained/perm.rs b/src/utils/constrained/perm.rs new file mode 100644 index 0000000000000000000000000000000000000000..3732aecd4dbe08bcf635c2adefd4d63a5f713bbc --- /dev/null +++ b/src/utils/constrained/perm.rs @@ -0,0 +1,70 @@ +use super::*; +use crate::{assert, perm}; + +/// Permutation of length equal to the value tied to `'n`. +#[repr(transparent)] +pub struct PermRef<'n, 'a, I: Index>(Branded<'n, perm::PermRef<'a, I>>); + +impl<'n, 'a, I: Index> PermRef<'n, 'a, I> { + /// Returns a new permutation after checking that it matches the size tied to `'n`. + #[inline] + #[track_caller] + pub fn new(perm: perm::PermRef<'a, I>, size: Size<'n>) -> Self { + let (fwd, inv) = perm.arrays(); + assert!(all( + fwd.len() == size.into_inner(), + inv.len() == size.into_inner(), + )); + Self(Branded { + __marker: PhantomData, + inner: perm, + }) + } + + /// Returns the inverse permutation. + #[inline] + pub fn inverse(self) -> PermRef<'n, 'a, I> { + PermRef(Branded { + __marker: PhantomData, + inner: self.0.inner.inverse(), + }) + } + + /// Returns the forward and inverse permutation indices. + #[inline] + pub fn arrays(self) -> (&'a Array<'n, Idx<'n, I>>, &'a Array<'n, Idx<'n, I>>) { + unsafe { + let (fwd, inv) = self.0.inner.arrays(); + let fwd = &*(fwd as *const [I] as *const Array<'n, Idx<'n, I>>); + let inv = &*(inv as *const [I] as *const Array<'n, Idx<'n, I>>); + (fwd, inv) + } + } + + /// Returns the unconstrained permutation. + #[inline] + pub fn into_inner(self) -> perm::PermRef<'a, I> { + self.0.inner + } + + /// Returns the length of the permutation. + #[inline] + pub fn len(&self) -> Size<'n> { + unsafe { Size::new_raw_unchecked(self.into_inner().len()) } + } +} + +impl Clone for PermRef<'_, '_, I> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl Copy for PermRef<'_, '_, I> {} + +impl core::fmt::Debug for PermRef<'_, '_, I> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.0.inner.fmt(f) + } +} diff --git a/src/utils/constrained/sparse.rs b/src/utils/constrained/sparse.rs new file mode 100644 index 0000000000000000000000000000000000000000..5802ffdb8c5895b3a7d607239a9f56bde8d1cce0 --- /dev/null +++ b/src/utils/constrained/sparse.rs @@ -0,0 +1,293 @@ +use super::*; +use crate::{assert, sparse::__get_unchecked, utils::slice::*}; +use core::ops::Range; + +/// Symbolic structure view with dimensions equal to the values tied to `('nrows, 'ncols)`, +/// in column-major order. +#[repr(transparent)] +pub struct SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I: Index>( + Branded<'ncols, Branded<'nrows, crate::sparse::SymbolicSparseColMatRef<'a, I>>>, +); +/// Immutable sparse matrix view with dimensions equal to the values tied to `('nrows, +/// 'ncols)`, in column-major order. +pub struct SparseColMatRef<'nrows, 'ncols, 'a, I: Index, E: Entity> { + symbolic: SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>, + values: SliceGroup<'a, E>, +} +/// Mutable sparse matrix view with dimensions equal to the values tied to `('nrows, +/// 'ncols)`, in column-major order. +pub struct SparseColMatMut<'nrows, 'ncols, 'a, I: Index, E: Entity> { + symbolic: SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>, + values: SliceGroupMut<'a, E>, +} + +impl<'nrows, 'ncols, 'a, I: Index> SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I> { + /// Returns a new symbolic structure after checking that its dimensions match the + /// dimensions tied to `('nrows, 'ncols)`. + #[inline] + pub fn new( + inner: crate::sparse::SymbolicSparseColMatRef<'a, I>, + nrows: Size<'nrows>, + ncols: Size<'ncols>, + ) -> Self { + assert!(all( + inner.nrows() == nrows.into_inner(), + inner.ncols() == ncols.into_inner(), + )); + Self(Branded { + __marker: PhantomData, + inner: Branded { + __marker: PhantomData, + inner, + }, + }) + } + + /// Returns the unconstrained symbolic structure. + #[inline] + pub fn into_inner(self) -> crate::sparse::SymbolicSparseColMatRef<'a, I> { + self.0.inner.inner + } + + /// Returns the number of rows of the matrix. + #[inline] + pub fn nrows(&self) -> Size<'nrows> { + unsafe { Size::new_raw_unchecked(self.0.inner.inner.nrows()) } + } + + /// Returns the number of columns of the matrix. + #[inline] + pub fn ncols(&self) -> Size<'ncols> { + unsafe { Size::new_raw_unchecked(self.0.inner.inner.ncols()) } + } + + #[inline] + #[track_caller] + #[doc(hidden)] + pub fn col_range(&self, j: Idx<'ncols, usize>) -> Range { + unsafe { self.into_inner().col_range_unchecked(j.into_inner()) } + } + + /// Returns the row indices in column `j`. + #[inline] + #[track_caller] + pub fn row_indices_of_col_raw(&self, j: Idx<'ncols, usize>) -> &'a [Idx<'nrows, I>] { + unsafe { + &*(__get_unchecked(self.into_inner().row_indices(), self.col_range(j)) as *const [I] + as *const [Idx<'_, I>]) + } + } + + /// Returns the row indices in column `j`. + #[inline] + #[track_caller] + pub fn row_indices_of_col( + &self, + j: Idx<'ncols, usize>, + ) -> impl 'a + ExactSizeIterator + DoubleEndedIterator> { + unsafe { + __get_unchecked( + self.into_inner().row_indices(), + self.into_inner().col_range_unchecked(j.into_inner()), + ) + .iter() + .map( + #[inline(always)] + move |&row| Idx::new_raw_unchecked(row.zx()), + ) + } + } +} + +impl<'nrows, 'ncols, 'a, I: Index, E: Entity> SparseColMatRef<'nrows, 'ncols, 'a, I, E> { + /// Returns a new matrix view after checking that its dimensions match the + /// dimensions tied to `('nrows, 'ncols)`. + pub fn new( + inner: crate::sparse::SparseColMatRef<'a, I, E>, + nrows: Size<'nrows>, + ncols: Size<'ncols>, + ) -> Self { + assert!(all( + inner.nrows() == nrows.into_inner(), + inner.ncols() == ncols.into_inner(), + )); + Self { + symbolic: SymbolicSparseColMatRef::new(inner.symbolic(), nrows, ncols), + values: SliceGroup::new(inner.values()), + } + } + + /// Returns the unconstrained matrix. + #[inline] + pub fn into_inner(self) -> crate::sparse::SparseColMatRef<'a, I, E> { + crate::sparse::SparseColMatRef::new(self.symbolic.into_inner(), self.values.into_inner()) + } + + /// Returns the values in column `j`. + #[inline] + pub fn values_of_col(&self, j: Idx<'ncols, usize>) -> GroupFor { + unsafe { + self.values + .subslice_unchecked(self.col_range(j)) + .into_inner() + } + } + + /// Returns the symbolic structure of the matrix. + #[inline] + pub fn symbolic(&self) -> SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I> { + self.symbolic + } +} + +impl<'nrows, 'ncols, 'a, I: Index, E: Entity> SparseColMatMut<'nrows, 'ncols, 'a, I, E> { + /// Returns a new matrix view after checking that its dimensions match the + /// dimensions tied to `('nrows, 'ncols)`. + pub fn new( + inner: crate::sparse::SparseColMatMut<'a, I, E>, + nrows: Size<'nrows>, + ncols: Size<'ncols>, + ) -> Self { + assert!(all( + inner.nrows() == nrows.into_inner(), + inner.ncols() == ncols.into_inner(), + )); + Self { + symbolic: SymbolicSparseColMatRef::new(inner.symbolic(), nrows, ncols), + values: SliceGroupMut::new(inner.values_mut()), + } + } + + /// Returns the unconstrained matrix. + #[inline] + pub fn into_inner(self) -> crate::sparse::SparseColMatMut<'a, I, E> { + crate::sparse::SparseColMatMut::new(self.symbolic.into_inner(), self.values.into_inner()) + } + + /// Returns the values in column `j`. + #[inline] + pub fn values_of_col_mut(&mut self, j: Idx<'ncols, usize>) -> GroupFor { + unsafe { + let range = self.col_range(j); + self.values.rb_mut().subslice_unchecked(range).into_inner() + } + } + + /// Returns the symbolic structure of the matrix. + #[inline] + pub fn symbolic(&self) -> SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I> { + self.symbolic + } +} + +impl Copy for SparseColMatRef<'_, '_, '_, I, E> {} +impl Clone for SparseColMatRef<'_, '_, '_, I, E> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl Copy for SymbolicSparseColMatRef<'_, '_, '_, I> {} +impl Clone for SymbolicSparseColMatRef<'_, '_, '_, I> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'nrows, 'ncols, 'a, I: Index, E: Entity> core::ops::Deref + for SparseColMatRef<'nrows, 'ncols, 'a, I, E> +{ + type Target = SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.symbolic + } +} + +impl<'nrows, 'ncols, 'a, I: Index, E: Entity> core::ops::Deref + for SparseColMatMut<'nrows, 'ncols, 'a, I, E> +{ + type Target = SymbolicSparseColMatRef<'nrows, 'ncols, 'a, I>; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.symbolic + } +} + +impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> ReborrowMut<'short> + for SparseColMatRef<'nrows, 'ncols, 'a, I, E> +{ + type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> Reborrow<'short> + for SparseColMatRef<'nrows, 'ncols, 'a, I, E> +{ + type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'nrows, 'ncols, 'a, I: Index, E: Entity> IntoConst + for SparseColMatRef<'nrows, 'ncols, 'a, I, E> +{ + type Target = SparseColMatRef<'nrows, 'ncols, 'a, I, E>; + + #[inline] + fn into_const(self) -> Self::Target { + self + } +} + +impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> ReborrowMut<'short> + for SparseColMatMut<'nrows, 'ncols, 'a, I, E> +{ + type Target = SparseColMatMut<'nrows, 'ncols, 'short, I, E>; + + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + SparseColMatMut::<'nrows, 'ncols, 'short, I, E> { + symbolic: self.symbolic, + values: self.values.rb_mut(), + } + } +} + +impl<'short, 'nrows, 'ncols, 'a, I: Index, E: Entity> Reborrow<'short> + for SparseColMatMut<'nrows, 'ncols, 'a, I, E> +{ + type Target = SparseColMatRef<'nrows, 'ncols, 'short, I, E>; + + #[inline] + fn rb(&'short self) -> Self::Target { + SparseColMatRef::<'nrows, 'ncols, 'short, I, E> { + symbolic: self.symbolic, + values: self.values.rb(), + } + } +} + +impl<'nrows, 'ncols, 'a, I: Index, E: Entity> IntoConst + for SparseColMatMut<'nrows, 'ncols, 'a, I, E> +{ + type Target = SparseColMatRef<'nrows, 'ncols, 'a, I, E>; + + #[inline] + fn into_const(self) -> Self::Target { + SparseColMatRef::<'nrows, 'ncols, 'a, I, E> { + symbolic: self.symbolic, + values: self.values.into_const(), + } + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..dfefdb220adc691a96945568acdaa6ca3c2b95bb --- /dev/null +++ b/src/utils/mod.rs @@ -0,0 +1,57 @@ +#[inline(always)] +pub(crate) unsafe fn unchecked_mul(a: usize, b: isize) -> isize { + let (sum, overflow) = (a as isize).overflowing_mul(b); + if overflow { + core::hint::unreachable_unchecked(); + } + sum +} + +#[inline(always)] +pub(crate) unsafe fn unchecked_add(a: isize, b: isize) -> isize { + let (sum, overflow) = a.overflowing_add(b); + if overflow { + core::hint::unreachable_unchecked(); + } + sum +} + +#[doc(hidden)] +pub(crate) trait DivCeil: Sized { + fn msrv_div_ceil(self, rhs: Self) -> Self; + fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option; +} + +impl DivCeil for usize { + #[inline] + fn msrv_div_ceil(self, rhs: Self) -> Self { + let d = self / rhs; + let r = self % rhs; + if r > 0 { + d + 1 + } else { + d + } + } + + #[inline] + fn msrv_checked_next_multiple_of(self, rhs: Self) -> Option { + { + match self.checked_rem(rhs)? { + 0 => Some(self), + r => self.checked_add(rhs - r), + } + } + } +} + +/// Index and matrix types with compile time checks, whichh can replace bound checks at runtime. +pub mod constrained; +/// Simd operations for a specific type satisfying [`ComplexField`](crate::ComplexField). +pub mod simd; +/// Slice types for [entities](crate::Entity). +pub mod slice; +/// Utilities relating to threading and parallelism. +pub mod thread; +/// Vector type for [entities](crate::Entity). +pub mod vec; diff --git a/src/utils/simd.rs b/src/utils/simd.rs new file mode 100644 index 0000000000000000000000000000000000000000..c6dfda2fb0beed379e6ba987747ae747d0750a8d --- /dev/null +++ b/src/utils/simd.rs @@ -0,0 +1,664 @@ +use super::slice::*; +use crate::Conj; +use core::{fmt::Debug, marker::PhantomData}; +use faer_entity::*; +use reborrow::*; + +pub use faer_entity::pulp::{Read, Write}; + +/// Do conjugate. +#[derive(Copy, Clone, Debug)] +pub struct YesConj; +/// Do not conjugate. +#[derive(Copy, Clone, Debug)] +pub struct NoConj; + +/// Similar to [`Conj`], but determined at compile time instead of runtime. +pub trait ConjTy: Copy + Debug { + /// The corresponding [`Conj`] value. + const CONJ: Conj; + /// The opposing conjugation type. + type Flip: ConjTy; + + /// Returns an instance of the corresponding conjugation type. + fn flip(self) -> Self::Flip; +} + +impl ConjTy for YesConj { + const CONJ: Conj = Conj::Yes; + type Flip = NoConj; + #[inline(always)] + fn flip(self) -> Self::Flip { + NoConj + } +} +impl ConjTy for NoConj { + const CONJ: Conj = Conj::No; + type Flip = YesConj; + #[inline(always)] + fn flip(self) -> Self::Flip { + YesConj + } +} + +/// Wrapper for simd operations for type `E`. +pub struct SimdFor { + /// Simd token. + pub simd: S, + __marker: PhantomData, +} + +/// Simd prefix, contains the elements before the body. +pub struct Prefix<'a, E: Entity, S: pulp::Simd>( + GroupCopyFor>, + PhantomData<&'a ()>, +); +/// Simd suffix, contains the elements after the body. +pub struct Suffix<'a, E: Entity, S: pulp::Simd>( + GroupCopyFor>, + PhantomData<&'a mut ()>, +); +/// Simd prefix (mutable), contains the elements before the body. +pub struct PrefixMut<'a, E: Entity, S: pulp::Simd>( + GroupFor>, + PhantomData<&'a ()>, +); +/// Simd suffix (mutable), contains the elements after the body. +pub struct SuffixMut<'a, E: Entity, S: pulp::Simd>( + GroupFor>, + PhantomData<&'a mut ()>, +); + +impl Copy for SimdFor {} +impl Clone for SimdFor { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl SimdFor { + /// Create a new wrapper from a simd token. + #[inline(always)] + pub fn new(simd: S) -> Self { + Self { + simd, + __marker: PhantomData, + } + } + + /// Computes the alignment offset for subsequent aligned loads. + #[inline(always)] + pub fn align_offset(self, slice: SliceGroup<'_, E>) -> pulp::Offset> { + let slice = E::faer_first(slice.into_inner()); + E::faer_align_offset(self.simd, slice.as_ptr(), slice.len()) + } + + /// Computes the alignment offset for subsequent aligned loads from a pointer. + #[inline(always)] + pub fn align_offset_ptr( + self, + ptr: GroupFor, + len: usize, + ) -> pulp::Offset> { + E::faer_align_offset(self.simd, E::faer_first(ptr), len) + } + + /// Convert a slice to a slice over vector registers, and a scalar tail. + #[inline(always)] + pub fn as_simd( + self, + slice: SliceGroup<'_, E>, + ) -> (SliceGroup<'_, E, SimdUnitFor>, SliceGroup<'_, E>) { + let (head, tail) = slice_as_simd::(slice.into_inner()); + (SliceGroup::new(head), SliceGroup::new(tail)) + } + + /// Convert a mutable slice to a slice over vector registers, and a scalar tail. + #[inline(always)] + pub fn as_simd_mut( + self, + slice: SliceGroupMut<'_, E>, + ) -> ( + SliceGroupMut<'_, E, SimdUnitFor>, + SliceGroupMut<'_, E>, + ) { + let (head, tail) = slice_as_mut_simd::(slice.into_inner()); + (SliceGroupMut::new(head), SliceGroupMut::new(tail)) + } + + /// Convert a slice to a partial register prefix and suffix, and a vector register slice + /// (body). + #[inline(always)] + pub fn as_aligned_simd( + self, + slice: SliceGroup<'_, E>, + offset: pulp::Offset>, + ) -> ( + Prefix<'_, E, S>, + SliceGroup<'_, E, SimdUnitFor>, + Suffix<'_, E, S>, + ) { + let (head_tail, body) = E::faer_unzip(E::faer_map(slice.into_inner(), |slice| { + let (head, body, tail) = E::faer_slice_as_aligned_simd(self.simd, slice, offset); + ((head, tail), body) + })); + + let (head, tail) = E::faer_unzip(head_tail); + + unsafe { + ( + Prefix( + transmute_unchecked::< + GroupCopyFor>, + GroupCopyFor>, + >(into_copy::(head)), + PhantomData, + ), + SliceGroup::new(body), + Suffix( + transmute_unchecked::< + GroupCopyFor>, + GroupCopyFor>, + >(into_copy::(tail)), + PhantomData, + ), + ) + } + } + + /// Convert a mutable slice to a partial register prefix and suffix, and a vector register + /// slice (body). + #[inline(always)] + pub fn as_aligned_simd_mut( + self, + slice: SliceGroupMut<'_, E>, + offset: pulp::Offset>, + ) -> ( + PrefixMut<'_, E, S>, + SliceGroupMut<'_, E, SimdUnitFor>, + SuffixMut<'_, E, S>, + ) { + let (head_tail, body) = E::faer_unzip(E::faer_map(slice.into_inner(), |slice| { + let (head, body, tail) = E::faer_slice_as_aligned_simd_mut(self.simd, slice, offset); + ((head, tail), body) + })); + + let (head, tail) = E::faer_unzip(head_tail); + + ( + PrefixMut( + unsafe { + transmute_unchecked::< + GroupFor>, + GroupFor>, + >(head) + }, + PhantomData, + ), + SliceGroupMut::new(body), + SuffixMut( + unsafe { + transmute_unchecked::< + GroupFor>, + GroupFor>, + >(tail) + }, + PhantomData, + ), + ) + } + + /// Fill all the register lanes with the same value. + #[inline(always)] + pub fn splat(self, value: E) -> SimdGroupFor { + E::faer_simd_splat(self.simd, value) + } + + /// Returns `lhs * rhs`. + #[inline(always)] + pub fn scalar_mul(self, lhs: E, rhs: E) -> E { + E::faer_simd_scalar_mul(self.simd, lhs, rhs) + } + /// Returns `conj(lhs) * rhs`. + #[inline(always)] + pub fn scalar_conj_mul(self, lhs: E, rhs: E) -> E { + E::faer_simd_scalar_conj_mul(self.simd, lhs, rhs) + } + /// Returns an estimate of `lhs * rhs + acc`. + #[inline(always)] + pub fn scalar_mul_add_e(self, lhs: E, rhs: E, acc: E) -> E { + E::faer_simd_scalar_mul_adde(self.simd, lhs, rhs, acc) + } + /// Returns an estimate of `conj(lhs) * rhs + acc`. + #[inline(always)] + pub fn scalar_conj_mul_add_e(self, lhs: E, rhs: E, acc: E) -> E { + E::faer_simd_scalar_conj_mul_adde(self.simd, lhs, rhs, acc) + } + + /// Returns an estimate of `op(lhs) * rhs`, where `op` is either the conjugation + /// or the identity operation. + #[inline(always)] + pub fn scalar_conditional_conj_mul(self, conj: C, lhs: E, rhs: E) -> E { + let _ = conj; + if C::CONJ == Conj::Yes { + self.scalar_conj_mul(lhs, rhs) + } else { + self.scalar_mul(lhs, rhs) + } + } + /// Returns an estimate of `op(lhs) * rhs + acc`, where `op` is either the conjugation or + /// the identity operation. + #[inline(always)] + pub fn scalar_conditional_conj_mul_add_e( + self, + conj: C, + lhs: E, + rhs: E, + acc: E, + ) -> E { + let _ = conj; + if C::CONJ == Conj::Yes { + self.scalar_conj_mul_add_e(lhs, rhs, acc) + } else { + self.scalar_mul_add_e(lhs, rhs, acc) + } + } + + /// Returns `lhs + rhs`. + #[inline(always)] + pub fn add(self, lhs: SimdGroupFor, rhs: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_add(self.simd, lhs, rhs) + } + /// Returns `lhs - rhs`. + #[inline(always)] + pub fn sub(self, lhs: SimdGroupFor, rhs: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_sub(self.simd, lhs, rhs) + } + /// Returns `-a`. + #[inline(always)] + pub fn neg(self, a: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_neg(self.simd, a) + } + /// Returns `lhs * rhs`. + #[inline(always)] + pub fn scale_real( + self, + lhs: SimdGroupFor, + rhs: SimdGroupFor, + ) -> SimdGroupFor { + E::faer_simd_scale_real(self.simd, lhs, rhs) + } + /// Returns `lhs * rhs`. + #[inline(always)] + pub fn mul(self, lhs: SimdGroupFor, rhs: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_mul(self.simd, lhs, rhs) + } + /// Returns `conj(lhs) * rhs`. + #[inline(always)] + pub fn conj_mul(self, lhs: SimdGroupFor, rhs: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_conj_mul(self.simd, lhs, rhs) + } + /// Returns `op(lhs) * rhs`, where `op` is either the conjugation or the identity + /// operation. + #[inline(always)] + pub fn conditional_conj_mul( + self, + conj: C, + lhs: SimdGroupFor, + rhs: SimdGroupFor, + ) -> SimdGroupFor { + let _ = conj; + if C::CONJ == Conj::Yes { + self.conj_mul(lhs, rhs) + } else { + self.mul(lhs, rhs) + } + } + + /// Returns `lhs * rhs + acc`. + #[inline(always)] + pub fn mul_add_e( + self, + lhs: SimdGroupFor, + rhs: SimdGroupFor, + acc: SimdGroupFor, + ) -> SimdGroupFor { + E::faer_simd_mul_adde(self.simd, lhs, rhs, acc) + } + /// Returns `conj(lhs) * rhs + acc`. + #[inline(always)] + pub fn conj_mul_add_e( + self, + lhs: SimdGroupFor, + rhs: SimdGroupFor, + acc: SimdGroupFor, + ) -> SimdGroupFor { + E::faer_simd_conj_mul_adde(self.simd, lhs, rhs, acc) + } + /// Returns `op(lhs) * rhs + acc`, where `op` is either the conjugation or the identity + /// operation. + #[inline(always)] + pub fn conditional_conj_mul_add_e( + self, + conj: C, + lhs: SimdGroupFor, + rhs: SimdGroupFor, + acc: SimdGroupFor, + ) -> SimdGroupFor { + let _ = conj; + if C::CONJ == Conj::Yes { + self.conj_mul_add_e(lhs, rhs, acc) + } else { + self.mul_add_e(lhs, rhs, acc) + } + } + + /// Returns `abs(values) * abs(values) + acc`. + #[inline(always)] + pub fn abs2_add_e( + self, + values: SimdGroupFor, + acc: SimdGroupFor, + ) -> SimdGroupFor { + E::faer_simd_abs2_adde(self.simd, values, acc) + } + /// Returns `abs(values) * abs(values)`. + #[inline(always)] + pub fn abs2(self, values: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_abs2(self.simd, values) + } + /// Returns `abs(values)` or `abs(values) * abs(values)`, whichever is cheaper to compute. + #[inline(always)] + pub fn score(self, values: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_score(self.simd, values) + } + + /// Sum the components of a vector register into a single accumulator. + #[inline(always)] + pub fn reduce_add(self, values: SimdGroupFor) -> E { + E::faer_simd_reduce_add(self.simd, values) + } + + /// Rotate `values` to the left, with overflowing entries wrapping around to the right side + /// of the register. + #[inline(always)] + pub fn rotate_left(self, values: SimdGroupFor, amount: usize) -> SimdGroupFor { + E::faer_simd_rotate_left(self.simd, values, amount) + } +} + +impl SimdFor { + /// Returns `abs(values)`. + #[inline(always)] + pub fn abs(self, values: SimdGroupFor) -> SimdGroupFor { + E::faer_simd_abs(self.simd, values) + } + /// Returns `a < b`. + #[inline(always)] + pub fn less_than(self, a: SimdGroupFor, b: SimdGroupFor) -> SimdMaskFor { + E::faer_simd_less_than(self.simd, a, b) + } + /// Returns `a <= b`. + #[inline(always)] + pub fn less_than_or_equal( + self, + a: SimdGroupFor, + b: SimdGroupFor, + ) -> SimdMaskFor { + E::faer_simd_less_than_or_equal(self.simd, a, b) + } + /// Returns `a > b`. + #[inline(always)] + pub fn greater_than(self, a: SimdGroupFor, b: SimdGroupFor) -> SimdMaskFor { + E::faer_simd_greater_than(self.simd, a, b) + } + /// Returns `a >= b`. + #[inline(always)] + pub fn greater_than_or_equal( + self, + a: SimdGroupFor, + b: SimdGroupFor, + ) -> SimdMaskFor { + E::faer_simd_greater_than_or_equal(self.simd, a, b) + } + + /// Returns `if mask { if_true } else { if_false }` + #[inline(always)] + pub fn select( + self, + mask: SimdMaskFor, + if_true: SimdGroupFor, + if_false: SimdGroupFor, + ) -> SimdGroupFor { + E::faer_simd_select(self.simd, mask, if_true, if_false) + } + /// Returns `if mask { if_true } else { if_false }` + #[inline(always)] + pub fn index_select( + self, + mask: SimdMaskFor, + if_true: SimdIndexFor, + if_false: SimdIndexFor, + ) -> SimdIndexFor { + E::faer_simd_index_select(self.simd, mask, if_true, if_false) + } + /// Returns `[0, 1, 2, 3, ..., REGISTER_SIZE - 1]` + #[inline(always)] + pub fn index_seq(self) -> SimdIndexFor { + E::faer_simd_index_seq(self.simd) + } + /// Fill all the register lanes with the same value. + #[inline(always)] + pub fn index_splat(self, value: IndexFor) -> SimdIndexFor { + E::faer_simd_index_splat(self.simd, value) + } + /// Returns `a + b`. + #[inline(always)] + pub fn index_add(self, a: SimdIndexFor, b: SimdIndexFor) -> SimdIndexFor { + E::faer_simd_index_add(self.simd, a, b) + } +} +impl Read for Prefix<'_, E, S> { + type Output = SimdGroupFor; + #[inline(always)] + fn read_or(&self, or: Self::Output) -> Self::Output { + into_copy::(E::faer_map( + E::faer_zip(from_copy::(self.0), from_copy::(or)), + #[inline(always)] + |(prefix, or)| prefix.read_or(or), + )) + } +} +impl Read for PrefixMut<'_, E, S> { + type Output = SimdGroupFor; + #[inline(always)] + fn read_or(&self, or: Self::Output) -> Self::Output { + self.rb().read_or(or) + } +} +impl Write for PrefixMut<'_, E, S> { + #[inline(always)] + fn write(&mut self, values: Self::Output) { + E::faer_map( + E::faer_zip(self.rb_mut().0, from_copy::(values)), + #[inline(always)] + |(mut prefix, values)| prefix.write(values), + ); + } +} + +impl Read for Suffix<'_, E, S> { + type Output = SimdGroupFor; + #[inline(always)] + fn read_or(&self, or: Self::Output) -> Self::Output { + into_copy::(E::faer_map( + E::faer_zip(from_copy::(self.0), from_copy::(or)), + #[inline(always)] + |(suffix, or)| suffix.read_or(or), + )) + } +} +impl Read for SuffixMut<'_, E, S> { + type Output = SimdGroupFor; + #[inline(always)] + fn read_or(&self, or: Self::Output) -> Self::Output { + self.rb().read_or(or) + } +} +impl Write for SuffixMut<'_, E, S> { + #[inline(always)] + fn write(&mut self, values: Self::Output) { + E::faer_map( + E::faer_zip(self.rb_mut().0, from_copy::(values)), + #[inline(always)] + |(mut suffix, values)| suffix.write(values), + ); + } +} + +impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for PrefixMut<'_, E, S> { + type Target = Prefix<'short, E, S>; + #[inline] + fn rb(&'short self) -> Self::Target { + unsafe { + Prefix( + into_copy::(transmute_unchecked::< + GroupFor as Reborrow<'_>>::Target>, + GroupFor>, + >(E::faer_map( + E::faer_as_ref(&self.0), + |x| (*x).rb(), + ))), + PhantomData, + ) + } + } +} +impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for PrefixMut<'_, E, S> { + type Target = PrefixMut<'short, E, S>; + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + unsafe { + PrefixMut( + transmute_unchecked::< + GroupFor as ReborrowMut<'_>>::Target>, + GroupFor>, + >(E::faer_map(E::faer_as_mut(&mut self.0), |x| (*x).rb_mut())), + PhantomData, + ) + } + } +} +impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for SuffixMut<'_, E, S> { + type Target = Suffix<'short, E, S>; + #[inline] + fn rb(&'short self) -> Self::Target { + unsafe { + Suffix( + into_copy::(transmute_unchecked::< + GroupFor as Reborrow<'_>>::Target>, + GroupFor>, + >(E::faer_map( + E::faer_as_ref(&self.0), + |x| (*x).rb(), + ))), + PhantomData, + ) + } + } +} +impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for SuffixMut<'_, E, S> { + type Target = SuffixMut<'short, E, S>; + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + unsafe { + SuffixMut( + transmute_unchecked::< + GroupFor as ReborrowMut<'_>>::Target>, + GroupFor>, + >(E::faer_map(E::faer_as_mut(&mut self.0), |x| (*x).rb_mut())), + PhantomData, + ) + } + } +} + +impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for Prefix<'_, E, S> { + type Target = Prefix<'short, E, S>; + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} +impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for Prefix<'_, E, S> { + type Target = Prefix<'short, E, S>; + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} +impl<'short, E: Entity, S: pulp::Simd> Reborrow<'short> for Suffix<'_, E, S> { + type Target = Suffix<'short, E, S>; + #[inline] + fn rb(&'short self) -> Self::Target { + *self + } +} +impl<'short, E: Entity, S: pulp::Simd> ReborrowMut<'short> for Suffix<'_, E, S> { + type Target = Suffix<'short, E, S>; + #[inline] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl Copy for Prefix<'_, E, S> {} +impl Clone for Prefix<'_, E, S> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl Copy for Suffix<'_, E, S> {} +impl Clone for Suffix<'_, E, S> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl core::fmt::Debug for Prefix<'_, E, S> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + unsafe { + transmute_unchecked::, GroupDebugFor>>( + self.read_or(core::mem::zeroed()), + ) + .fmt(f) + } + } +} +impl core::fmt::Debug for PrefixMut<'_, E, S> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.rb().fmt(f) + } +} +impl core::fmt::Debug for Suffix<'_, E, S> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + unsafe { + transmute_unchecked::, GroupDebugFor>>( + self.read_or(core::mem::zeroed()), + ) + .fmt(f) + } + } +} +impl core::fmt::Debug for SuffixMut<'_, E, S> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.rb().fmt(f) + } +} diff --git a/src/utils/slice.rs b/src/utils/slice.rs new file mode 100644 index 0000000000000000000000000000000000000000..01b9ab2be318fba42a902bb7c52089422aefbd74 --- /dev/null +++ b/src/utils/slice.rs @@ -0,0 +1,729 @@ +use crate::{assert, debug_assert}; +use core::{marker::PhantomData, ops::Range}; +use faer_entity::*; +use reborrow::*; + +/// Wrapper around a group of references. +pub struct RefGroup<'a, E: Entity, T: 'a = ::Unit>( + GroupCopyFor, + PhantomData<&'a ()>, +); +/// Wrapper around a group of mutable references. +pub struct RefGroupMut<'a, E: Entity, T: 'a = ::Unit>( + GroupFor, + PhantomData<&'a mut ()>, +); + +/// Analogous to an immutable reference to a [prim@slice] for groups. +pub struct SliceGroup<'a, E: Entity, T: 'a = ::Unit>( + GroupCopyFor, + PhantomData<&'a ()>, +); +/// Analogous to a mutable reference to a [prim@slice] for groups. +pub struct SliceGroupMut<'a, E: Entity, T: 'a = ::Unit>( + GroupFor, + PhantomData<&'a mut ()>, +); + +impl core::fmt::Debug for RefGroup<'_, E, T> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + unsafe { + transmute_unchecked::, GroupDebugFor>(self.into_inner()).fmt(f) + } + } +} +impl core::fmt::Debug for RefGroupMut<'_, E, T> { + #[inline] + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.rb().fmt(f) + } +} +impl core::fmt::Debug for SliceGroup<'_, E, T> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_list().entries(self.into_ref_iter()).finish() + } +} +impl core::fmt::Debug for SliceGroupMut<'_, E, T> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.rb().fmt(f) + } +} + +unsafe impl Send for SliceGroup<'_, E, T> {} +unsafe impl Sync for SliceGroup<'_, E, T> {} +unsafe impl Send for SliceGroupMut<'_, E, T> {} +unsafe impl Sync for SliceGroupMut<'_, E, T> {} + +impl Copy for SliceGroup<'_, E, T> {} +impl Copy for RefGroup<'_, E, T> {} +impl Clone for SliceGroup<'_, E, T> { + #[inline] + fn clone(&self) -> Self { + *self + } +} +impl Clone for RefGroup<'_, E, T> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'a, E: Entity, T> RefGroup<'a, E, T> { + /// Create a new [`RefGroup`] from a group of references. + #[inline(always)] + pub fn new(reference: GroupFor) -> Self { + Self( + into_copy::(E::faer_map( + reference, + #[inline(always)] + |reference| reference as *const T, + )), + PhantomData, + ) + } + + /// Consume `self` to return the internally stored group of references. + #[inline(always)] + pub fn into_inner(self) -> GroupFor { + E::faer_map( + from_copy::(self.0), + #[inline(always)] + |ptr| unsafe { &*ptr }, + ) + } + + /// Copies and returns the value pointed to by the references. + #[inline(always)] + pub fn get(self) -> GroupCopyFor + where + T: Copy, + { + into_copy::(E::faer_deref(self.into_inner())) + } +} + +impl<'a, E: Entity, T, const N: usize> RefGroup<'a, E, [T; N]> { + /// Convert a reference to an array to an array of references. + #[inline(always)] + pub fn unzip(self) -> [RefGroup<'a, E, T>; N] { + unsafe { + let mut out = transmute_unchecked::< + core::mem::MaybeUninit<[RefGroup<'a, E, T>; N]>, + [core::mem::MaybeUninit>; N], + >(core::mem::MaybeUninit::<[RefGroup<'a, E, T>; N]>::uninit()); + for (out, inp) in core::iter::zip(out.iter_mut(), E::faer_into_iter(self.into_inner())) + { + out.write(RefGroup::new(inp)); + } + transmute_unchecked::< + [core::mem::MaybeUninit>; N], + [RefGroup<'a, E, T>; N], + >(out) + } + } +} + +impl<'a, E: Entity, T, const N: usize> RefGroupMut<'a, E, [T; N]> { + /// Convert a mutable reference to an array to an array of mutable references. + #[inline(always)] + pub fn unzip(self) -> [RefGroupMut<'a, E, T>; N] { + unsafe { + let mut out = + transmute_unchecked::< + core::mem::MaybeUninit<[RefGroupMut<'a, E, T>; N]>, + [core::mem::MaybeUninit>; N], + >(core::mem::MaybeUninit::<[RefGroupMut<'a, E, T>; N]>::uninit()); + for (out, inp) in core::iter::zip(out.iter_mut(), E::faer_into_iter(self.into_inner())) + { + out.write(RefGroupMut::new(inp)); + } + transmute_unchecked::< + [core::mem::MaybeUninit>; N], + [RefGroupMut<'a, E, T>; N], + >(out) + } + } +} + +impl<'a, E: Entity, T> RefGroupMut<'a, E, T> { + /// Create a new [`RefGroupMut`] from a group of mutable references. + #[inline(always)] + pub fn new(reference: GroupFor) -> Self { + Self( + E::faer_map( + reference, + #[inline(always)] + |reference| reference as *mut T, + ), + PhantomData, + ) + } + + /// Consume `self` to return the internally stored group of references. + #[inline(always)] + pub fn into_inner(self) -> GroupFor { + E::faer_map( + self.0, + #[inline(always)] + |ptr| unsafe { &mut *ptr }, + ) + } + + /// Copies and returns the value pointed to by the references. + #[inline(always)] + pub fn get(&self) -> GroupCopyFor + where + T: Copy, + { + self.rb().get() + } + + /// Writes `value` to the location pointed to by the references. + #[inline(always)] + pub fn set(&mut self, value: GroupCopyFor) + where + T: Copy, + { + E::faer_map( + E::faer_zip(self.rb_mut().into_inner(), from_copy::(value)), + #[inline(always)] + |(r, value)| *r = value, + ); + } +} + +impl<'a, E: Entity, T> IntoConst for SliceGroup<'a, E, T> { + type Target = SliceGroup<'a, E, T>; + + #[inline(always)] + fn into_const(self) -> Self::Target { + self + } +} +impl<'a, E: Entity, T> IntoConst for SliceGroupMut<'a, E, T> { + type Target = SliceGroup<'a, E, T>; + + #[inline(always)] + fn into_const(self) -> Self::Target { + SliceGroup::new(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| &*slice, + )) + } +} + +impl<'a, E: Entity, T> IntoConst for RefGroup<'a, E, T> { + type Target = RefGroup<'a, E, T>; + + #[inline(always)] + fn into_const(self) -> Self::Target { + self + } +} +impl<'a, E: Entity, T> IntoConst for RefGroupMut<'a, E, T> { + type Target = RefGroup<'a, E, T>; + + #[inline(always)] + fn into_const(self) -> Self::Target { + RefGroup::new(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| &*slice, + )) + } +} + +impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for RefGroup<'a, E, T> { + type Target = RefGroup<'short, E, T>; + + #[inline(always)] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'short, 'a, E: Entity, T> Reborrow<'short> for RefGroup<'a, E, T> { + type Target = RefGroup<'short, E, T>; + + #[inline(always)] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for RefGroupMut<'a, E, T> { + type Target = RefGroupMut<'short, E, T>; + + #[inline(always)] + fn rb_mut(&'short mut self) -> Self::Target { + RefGroupMut::new(E::faer_map( + E::faer_as_mut(&mut self.0), + #[inline(always)] + |this| unsafe { &mut **this }, + )) + } +} + +impl<'short, 'a, E: Entity, T> Reborrow<'short> for RefGroupMut<'a, E, T> { + type Target = RefGroup<'short, E, T>; + + #[inline(always)] + fn rb(&'short self) -> Self::Target { + RefGroup::new(E::faer_map( + E::faer_as_ref(&self.0), + #[inline(always)] + |this| unsafe { &**this }, + )) + } +} + +impl<'a, E: Entity, T> SliceGroup<'a, E, T> { + /// Create a new [`SliceGroup`] from a group of slice references. + #[inline(always)] + pub fn new(slice: GroupFor) -> Self { + Self( + into_copy::(E::faer_map(slice, |slice| slice as *const [T])), + PhantomData, + ) + } + + /// Consume `self` to return the internally stored group of slice references. + #[inline(always)] + pub fn into_inner(self) -> GroupFor { + unsafe { E::faer_map(from_copy::(self.0), |ptr| &*ptr) } + } + + /// Decompose `self` into a slice of arrays of size `N`, and a remainder part with length + /// `< N`. + #[inline(always)] + pub fn as_arrays(self) -> (SliceGroup<'a, E, [T; N]>, SliceGroup<'a, E, T>) { + let (head, tail) = E::faer_as_arrays::(self.into_inner()); + (SliceGroup::new(head), SliceGroup::new(tail)) + } +} + +impl<'a, E: Entity, T> SliceGroupMut<'a, E, T> { + /// Create a new [`SliceGroup`] from a group of mutable slice references. + #[inline(always)] + pub fn new(slice: GroupFor) -> Self { + Self(E::faer_map(slice, |slice| slice as *mut [T]), PhantomData) + } + + /// Consume `self` to return the internally stored group of mutable slice references. + #[inline(always)] + pub fn into_inner(self) -> GroupFor { + unsafe { E::faer_map(self.0, |ptr| &mut *ptr) } + } + + /// Decompose `self` into a mutable slice of arrays of size `N`, and a remainder part with + /// length `< N`. + #[inline(always)] + pub fn as_arrays_mut( + self, + ) -> (SliceGroupMut<'a, E, [T; N]>, SliceGroupMut<'a, E, T>) { + let (head, tail) = E::faer_as_arrays_mut::(self.into_inner()); + (SliceGroupMut::new(head), SliceGroupMut::new(tail)) + } +} + +impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for SliceGroup<'a, E, T> { + type Target = SliceGroup<'short, E, T>; + + #[inline(always)] + fn rb_mut(&'short mut self) -> Self::Target { + *self + } +} + +impl<'short, 'a, E: Entity, T> Reborrow<'short> for SliceGroup<'a, E, T> { + type Target = SliceGroup<'short, E, T>; + + #[inline(always)] + fn rb(&'short self) -> Self::Target { + *self + } +} + +impl<'short, 'a, E: Entity, T> ReborrowMut<'short> for SliceGroupMut<'a, E, T> { + type Target = SliceGroupMut<'short, E, T>; + + #[inline(always)] + fn rb_mut(&'short mut self) -> Self::Target { + SliceGroupMut::new(E::faer_map( + E::faer_as_mut(&mut self.0), + #[inline(always)] + |this| unsafe { &mut **this }, + )) + } +} + +impl<'short, 'a, E: Entity, T> Reborrow<'short> for SliceGroupMut<'a, E, T> { + type Target = SliceGroup<'short, E, T>; + + #[inline(always)] + fn rb(&'short self) -> Self::Target { + SliceGroup::new(E::faer_map( + E::faer_as_ref(&self.0), + #[inline(always)] + |this| unsafe { &**this }, + )) + } +} + +impl<'a, E: Entity> RefGroup<'a, E> { + /// Read the element pointed to by the references. + #[inline(always)] + pub fn read(&self) -> E { + E::faer_from_units(E::faer_deref(self.into_inner())) + } +} + +impl<'a, E: Entity> RefGroupMut<'a, E> { + /// Read the element pointed to by the references. + #[inline(always)] + pub fn read(&self) -> E { + self.rb().read() + } + + /// Write `value` to the location pointed to by the references. + #[inline(always)] + pub fn write(&mut self, value: E) { + E::faer_map( + E::faer_zip(self.rb_mut().into_inner(), value.faer_into_units()), + #[inline(always)] + |(r, value)| *r = value, + ); + } +} + +impl<'a, E: Entity> SliceGroup<'a, E> { + /// Read the element at position `idx`. + #[inline(always)] + #[track_caller] + pub fn read(&self, idx: usize) -> E { + assert!(idx < self.len()); + unsafe { self.read_unchecked(idx) } + } + + /// Read the element at position `idx`, without bound checks. + /// + /// # Safety + /// The behavior is undefined if `idx >= self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, idx: usize) -> E { + debug_assert!(idx < self.len()); + E::faer_from_units(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| *slice.get_unchecked(idx), + )) + } +} +impl<'a, E: Entity, T> SliceGroup<'a, E, T> { + /// Get a [`RefGroup`] pointing to the element at position `idx`. + #[inline(always)] + #[track_caller] + pub fn get(self, idx: usize) -> RefGroup<'a, E, T> { + assert!(idx < self.len()); + unsafe { self.get_unchecked(idx) } + } + + /// Get a [`RefGroup`] pointing to the element at position `idx`, without bound checks. + /// + /// # Safety + /// The behavior is undefined if `idx >= self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_unchecked(self, idx: usize) -> RefGroup<'a, E, T> { + debug_assert!(idx < self.len()); + RefGroup::new(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.get_unchecked(idx), + )) + } + + /// Checks whether the slice is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the length of the slice. + #[inline] + pub fn len(&self) -> usize { + let mut len = usize::MAX; + E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| len = Ord::min(len, slice.len()), + ); + len + } + + /// Returns the subslice of `self` from the start to the end of the provided range. + #[inline(always)] + #[track_caller] + pub fn subslice(self, range: Range) -> Self { + assert!(all(range.start <= range.end, range.end <= self.len())); + unsafe { self.subslice_unchecked(range) } + } + + /// Split `self` at the midpoint `idx`, and return the two parts. + #[inline(always)] + #[track_caller] + pub fn split_at(self, idx: usize) -> (Self, Self) { + assert!(idx <= self.len()); + let (head, tail) = E::faer_unzip(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.split_at(idx), + )); + (Self::new(head), Self::new(tail)) + } + + /// Returns the subslice of `self` from the start to the end of the provided range, without + /// bound checks. + /// + /// # Safety + /// The behavior is undefined if `range.start > range.end` or `range.end > self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn subslice_unchecked(self, range: Range) -> Self { + debug_assert!(all(range.start <= range.end, range.end <= self.len())); + Self::new(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.get_unchecked(range.start..range.end), + )) + } + + /// Returns an iterator of [`RefGroup`] over the elements of the slice. + #[inline(always)] + pub fn into_ref_iter(self) -> impl Iterator> { + E::faer_into_iter(self.into_inner()).map(RefGroup::new) + } + + /// Returns an iterator of slices over chunks of size `chunk_size`, and the remainder of + /// the slice. + #[inline(always)] + pub fn into_chunks_exact( + self, + chunk_size: usize, + ) -> (impl Iterator>, Self) { + let len = self.len(); + let mid = len / chunk_size * chunk_size; + let (head, tail) = E::faer_unzip(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.split_at(mid), + )); + let head = E::faer_map( + head, + #[inline(always)] + |head| head.chunks_exact(chunk_size), + ); + ( + E::faer_into_iter(head).map(SliceGroup::new), + SliceGroup::new(tail), + ) + } +} + +impl<'a, E: Entity> SliceGroupMut<'a, E> { + /// Read the element at position `idx`. + #[inline(always)] + #[track_caller] + pub fn read(&self, idx: usize) -> E { + self.rb().read(idx) + } + + /// Read the element at position `idx`, without bound checks. + /// + /// # Safety + /// The behavior is undefined if `idx >= self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn read_unchecked(&self, idx: usize) -> E { + self.rb().read_unchecked(idx) + } + + /// Write `value` to the location at position `idx`. + #[inline(always)] + #[track_caller] + pub fn write(&mut self, idx: usize, value: E) { + assert!(idx < self.len()); + unsafe { self.write_unchecked(idx, value) } + } + + /// Write `value` to the location at position `idx`, without bound checks. + /// + /// # Safety + /// The behavior is undefined if `idx >= self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn write_unchecked(&mut self, idx: usize, value: E) { + debug_assert!(idx < self.len()); + E::faer_map( + E::faer_zip(self.rb_mut().into_inner(), value.faer_into_units()), + #[inline(always)] + |(slice, value)| *slice.get_unchecked_mut(idx) = value, + ); + } + + /// Fill the slice with zeros. + #[inline] + pub fn fill_zero(&mut self) { + E::faer_map(self.rb_mut().into_inner(), |slice| unsafe { + let len = slice.len(); + core::ptr::write_bytes(slice.as_mut_ptr(), 0u8, len); + }); + } +} + +impl<'a, E: Entity, T> SliceGroupMut<'a, E, T> { + /// Get a [`RefGroupMut`] pointing to the element at position `idx`. + #[inline(always)] + #[track_caller] + pub fn get_mut(self, idx: usize) -> RefGroupMut<'a, E, T> { + assert!(idx < self.len()); + unsafe { self.get_unchecked_mut(idx) } + } + + /// Get a [`RefGroupMut`] pointing to the element at position `idx`. + /// + /// # Safety + /// The behavior is undefined if `idx >= self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_unchecked_mut(self, idx: usize) -> RefGroupMut<'a, E, T> { + debug_assert!(idx < self.len()); + RefGroupMut::new(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.get_unchecked_mut(idx), + )) + } + + /// Get a [`RefGroup`] pointing to the element at position `idx`. + #[inline(always)] + #[track_caller] + pub fn get(self, idx: usize) -> RefGroup<'a, E, T> { + self.into_const().get(idx) + } + + /// Get a [`RefGroup`] pointing to the element at position `idx`, without bound checks. + /// + /// # Safety + /// The behavior is undefined if `idx >= self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn get_unchecked(self, idx: usize) -> RefGroup<'a, E, T> { + self.into_const().get_unchecked(idx) + } + + /// Checks whether the slice is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.rb().is_empty() + } + + /// Returns the length of the slice. + #[inline] + pub fn len(&self) -> usize { + self.rb().len() + } + + /// Returns the subslice of `self` from the start to the end of the provided range. + #[inline(always)] + #[track_caller] + pub fn subslice(self, range: Range) -> Self { + assert!(all(range.start <= range.end, range.end <= self.len())); + unsafe { self.subslice_unchecked(range) } + } + + /// Returns the subslice of `self` from the start to the end of the provided range, without + /// bound checks. + /// + /// # Safety + /// The behavior is undefined if `range.start > range.end` or `range.end > self.len()`. + #[inline(always)] + #[track_caller] + pub unsafe fn subslice_unchecked(self, range: Range) -> Self { + debug_assert!(all(range.start <= range.end, range.end <= self.len())); + Self::new(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.get_unchecked_mut(range.start..range.end), + )) + } + + /// Returns an iterator of [`RefGroupMut`] over the elements of the slice. + #[inline(always)] + pub fn into_mut_iter(self) -> impl Iterator> { + E::faer_into_iter(self.into_inner()).map(RefGroupMut::new) + } + + /// Split `self` at the midpoint `idx`, and return the two parts. + #[inline(always)] + #[track_caller] + pub fn split_at(self, idx: usize) -> (Self, Self) { + assert!(idx <= self.len()); + let (head, tail) = E::faer_unzip(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.split_at_mut(idx), + )); + (Self::new(head), Self::new(tail)) + } + + /// Returns an iterator of slices over chunks of size `chunk_size`, and the remainder of + /// the slice. + #[inline(always)] + pub fn into_chunks_exact( + self, + chunk_size: usize, + ) -> (impl Iterator>, Self) { + let len = self.len(); + let mid = len % chunk_size * chunk_size; + let (head, tail) = E::faer_unzip(E::faer_map( + self.into_inner(), + #[inline(always)] + |slice| slice.split_at_mut(mid), + )); + let head = E::faer_map( + head, + #[inline(always)] + |head| head.chunks_exact_mut(chunk_size), + ); + ( + E::faer_into_iter(head).map(SliceGroupMut::new), + SliceGroupMut::new(tail), + ) + } +} + +impl pulp::Read for RefGroupMut<'_, E, T> { + type Output = GroupCopyFor; + #[inline(always)] + fn read_or(&self, _or: Self::Output) -> Self::Output { + self.get() + } +} +impl pulp::Write for RefGroupMut<'_, E, T> { + #[inline(always)] + fn write(&mut self, values: Self::Output) { + self.set(values) + } +} +impl pulp::Read for RefGroup<'_, E, T> { + type Output = GroupCopyFor; + #[inline(always)] + fn read_or(&self, _or: Self::Output) -> Self::Output { + self.get() + } +} diff --git a/src/utils/thread.rs b/src/utils/thread.rs new file mode 100644 index 0000000000000000000000000000000000000000..22ee7f181387a8bf5b17ab00fb1fb7b38b5249c3 --- /dev/null +++ b/src/utils/thread.rs @@ -0,0 +1,124 @@ +use crate::*; + +/// Executes the two operations, possibly in parallel, while splitting the amount of parallelism +/// between the two. +#[inline] +pub fn join_raw( + op_a: impl Send + FnOnce(Parallelism), + op_b: impl Send + FnOnce(Parallelism), + parallelism: Parallelism, +) { + fn implementation( + op_a: &mut (dyn Send + FnMut(Parallelism)), + op_b: &mut (dyn Send + FnMut(Parallelism)), + parallelism: Parallelism, + ) { + match parallelism { + Parallelism::None => (op_a(parallelism), op_b(parallelism)), + #[cfg(feature = "rayon")] + Parallelism::Rayon(n_threads) => { + if n_threads == 1 { + (op_a(Parallelism::None), op_b(Parallelism::None)) + } else { + let n_threads = if n_threads > 0 { + n_threads + } else { + rayon::current_num_threads() + }; + let parallelism = Parallelism::Rayon(n_threads - n_threads / 2); + rayon::join(|| op_a(parallelism), || op_b(parallelism)) + } + } + }; + } + let mut op_a = Some(op_a); + let mut op_b = Some(op_b); + implementation( + &mut |parallelism| (op_a.take().unwrap())(parallelism), + &mut |parallelism| (op_b.take().unwrap())(parallelism), + parallelism, + ) +} + +/// Executes the tasks by passing the values in `0..n_tasks` to `op`, possibly in parallel, while +/// splitting the amount of parallelism between the tasks. +#[inline] +pub fn for_each_raw(n_tasks: usize, op: impl Send + Sync + Fn(usize), parallelism: Parallelism) { + fn implementation( + n_tasks: usize, + op: &(dyn Send + Sync + Fn(usize)), + parallelism: Parallelism, + ) { + if n_tasks == 1 { + op(0); + return; + } + + match parallelism { + Parallelism::None => (0..n_tasks).for_each(op), + #[cfg(feature = "rayon")] + Parallelism::Rayon(n_threads) => { + let n_threads = if n_threads > 0 { + n_threads + } else { + rayon::current_num_threads() + }; + + use rayon::prelude::*; + let min_len = n_tasks / n_threads; + (0..n_tasks) + .into_par_iter() + .with_min_len(min_len) + .for_each(op); + } + } + } + implementation(n_tasks, &op, parallelism); +} + +/// Unsafe [`Send`] and [`Sync`] pointer type. +pub struct Ptr(pub *mut T); +unsafe impl Send for Ptr {} +unsafe impl Sync for Ptr {} +impl Copy for Ptr {} +impl Clone for Ptr { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +/// The amount of threads that should ideally execute an operation with the given parallelism. +#[inline] +pub fn parallelism_degree(parallelism: Parallelism) -> usize { + match parallelism { + Parallelism::None => 1, + #[cfg(feature = "rayon")] + Parallelism::Rayon(0) => rayon::current_num_threads(), + #[cfg(feature = "rayon")] + Parallelism::Rayon(n_threads) => n_threads, + } +} + +/// Returns the start and length of a subsegment of `0..n`, split between `chunk_count` consumers, +/// for the consumer at index `idx`. +/// +/// For the same `n` and `chunk_count`, different values of `idx` between in `0..chunk_count` will +/// represent distinct subsegments. +#[inline] +pub fn par_split_indices(n: usize, idx: usize, chunk_count: usize) -> (usize, usize) { + let chunk_size = n / chunk_count; + let rem = n % chunk_count; + + let idx_to_col_start = move |idx| { + if idx < rem { + idx * (chunk_size + 1) + } else { + rem + idx * chunk_size + } + }; + + let start = idx_to_col_start(idx); + let end = idx_to_col_start(idx + 1); + (start, end - start) +} diff --git a/src/utils/vec.rs b/src/utils/vec.rs new file mode 100644 index 0000000000000000000000000000000000000000..246a030a3c7fb072b254239e7ea9e3d7e44aa96e --- /dev/null +++ b/src/utils/vec.rs @@ -0,0 +1,219 @@ +use super::slice::*; +use core::fmt::Debug; +use faer_entity::*; + +/// Analogous to [`alloc::vec::Vec`] for groups. +pub struct VecGroup> { + inner: GroupFor>, +} + +impl Clone for VecGroup { + #[inline] + fn clone(&self) -> Self { + Self { + inner: E::faer_map(E::faer_as_ref(&self.inner), |v| (*v).clone()), + } + } +} + +impl Debug for VecGroup { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + self.as_slice().fmt(f) + } +} + +unsafe impl Sync for VecGroup {} +unsafe impl Send for VecGroup {} + +impl Default for VecGroup { + fn default() -> Self { + Self::new() + } +} + +impl VecGroup { + /// Create a new [`VecGroup`] from a group of [`alloc::vec::Vec`]. + #[inline] + pub fn from_inner(inner: GroupFor>) -> Self { + Self { inner } + } + + /// Consume `self` to return a group of [`alloc::vec::Vec`]. + #[inline] + pub fn into_inner(self) -> GroupFor> { + self.inner + } + + /// Return a reference to the inner group of [`alloc::vec::Vec`]. + #[inline] + pub fn as_inner_ref(&self) -> GroupFor> { + E::faer_as_ref(&self.inner) + } + + /// Return a mutable reference to the inner group of [`alloc::vec::Vec`]. + #[inline] + pub fn as_inner_mut(&mut self) -> GroupFor> { + E::faer_as_mut(&mut self.inner) + } + + /// Return a [`SliceGroup`] view over the elements of `self`. + #[inline] + pub fn as_slice(&self) -> SliceGroup<'_, E, T> { + SliceGroup::new(E::faer_map( + E::faer_as_ref(&self.inner), + #[inline] + |slice| &**slice, + )) + } + + /// Return a [`SliceGroupMut`] mutable view over the elements of `self`. + #[inline] + pub fn as_slice_mut(&mut self) -> SliceGroupMut<'_, E, T> { + SliceGroupMut::new(E::faer_map( + E::faer_as_mut(&mut self.inner), + #[inline] + |slice| &mut **slice, + )) + } + + /// Create an empty [`VecGroup`]. + #[inline] + pub fn new() -> Self { + Self { + inner: E::faer_map(E::UNIT, |()| alloc::vec::Vec::new()), + } + } + + /// Returns the length of the vector group. + #[inline] + pub fn len(&self) -> usize { + let mut len = usize::MAX; + E::faer_map( + E::faer_as_ref(&self.inner), + #[inline(always)] + |slice| len = Ord::min(len, slice.len()), + ); + len + } + + /// Returns the capacity of the vector group. + #[inline] + pub fn capacity(&self) -> usize { + let mut cap = usize::MAX; + E::faer_map( + E::faer_as_ref(&self.inner), + #[inline(always)] + |slice| cap = Ord::min(cap, slice.capacity()), + ); + cap + } + + /// Reserve enough capacity for extra `additional` elements. + pub fn reserve(&mut self, additional: usize) { + E::faer_map(E::faer_as_mut(&mut self.inner), |v| v.reserve(additional)); + } + + /// Reserve exactly enough capacity for extra `additional` elements. + pub fn reserve_exact(&mut self, additional: usize) { + E::faer_map(E::faer_as_mut(&mut self.inner), |v| { + v.reserve_exact(additional) + }); + } + + /// Try to reserve enough capacity for extra `additional` elements. + pub fn try_reserve( + &mut self, + additional: usize, + ) -> Result<(), alloc::collections::TryReserveError> { + let mut result = Ok(()); + E::faer_map(E::faer_as_mut(&mut self.inner), |v| match &result { + Ok(()) => result = v.try_reserve(additional), + Err(_) => {} + }); + result + } + + /// Try to reserve exactly enough capacity for extra `additional` elements. + pub fn try_reserve_exact( + &mut self, + additional: usize, + ) -> Result<(), alloc::collections::TryReserveError> { + let mut result = Ok(()); + E::faer_map(E::faer_as_mut(&mut self.inner), |v| match &result { + Ok(()) => result = v.try_reserve_exact(additional), + Err(_) => {} + }); + result + } + + /// Truncate the length of the vector to `len`. + pub fn truncate(&mut self, len: usize) { + E::faer_map(E::faer_as_mut(&mut self.inner), |v| v.truncate(len)); + } + + /// Clear the vector, making it empty. + pub fn clear(&mut self) { + E::faer_map(E::faer_as_mut(&mut self.inner), |v| v.clear()); + } + + /// Resize the vector to `new_len`, filling the new elements with + /// `value`. + pub fn resize(&mut self, new_len: usize, value: GroupFor) + where + T: Clone, + { + E::faer_map( + E::faer_zip(E::faer_as_mut(&mut self.inner), value), + |(v, value)| v.resize(new_len, value), + ); + } + + /// Resize the vector to `new_len`, filling the new elements with + /// the output of `f`. + pub fn resize_with(&mut self, new_len: usize, f: impl FnMut() -> GroupFor) { + let len = self.len(); + let mut f = f; + if new_len <= len { + self.truncate(new_len); + } else { + self.reserve(new_len - len); + for _ in len..new_len { + self.push(f()) + } + } + } + + /// Push a new element to the end of `self`. + #[inline] + pub fn push(&mut self, value: GroupFor) { + E::faer_map( + E::faer_zip(E::faer_as_mut(&mut self.inner), value), + #[inline] + |(v, value)| v.push(value), + ); + } + + /// Remove a new element from the end of `self`, and return it. + #[inline] + pub fn pop(&mut self) -> Option> { + if self.len() >= 1 { + Some(E::faer_map( + E::faer_as_mut(&mut self.inner), + #[inline] + |v| v.pop().unwrap(), + )) + } else { + None + } + } + + /// Remove a new element from position `index`, and return it. + #[inline] + pub fn remove(&mut self, index: usize) -> GroupFor { + E::faer_map( + E::faer_as_mut(&mut self.inner), + #[inline] + |v| v.remove(index), + ) + } +} diff --git a/faer-libs/faer-sparse/test_data/YAO.mtx b/test_data/YAO.mtx similarity index 100% rename from faer-libs/faer-sparse/test_data/YAO.mtx rename to test_data/YAO.mtx diff --git a/faer-libs/faer-sparse/test_data/lp_share2b.mtx b/test_data/lp_share2b.mtx similarity index 100% rename from faer-libs/faer-sparse/test_data/lp_share2b.mtx rename to test_data/lp_share2b.mtx