Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Index traits #410

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -19,7 +19,10 @@ ArrayInterfaceBandedMatricesExt = "BandedMatrices"
ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices"
ArrayInterfaceCUDAExt = "CUDA"
ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore"
ArrayInterfaceOffsetArraysExt = "OffsetArrays"
ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore"
ArrayInterfaceStaticArraysExt = "StaticArrays"
ArrayInterfaceStaticExt = "Static"
ArrayInterfaceTrackerExt = "Tracker"

[extras]
@@ -30,21 +33,26 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "StaticArrays", "StaticArraysCore", "Tracker"]
test = ["SafeTestsets", "Pkg", "Test", "Aqua", "Random", "SparseArrays", "SuiteSparse", "BandedMatrices", "BlockBandedMatrices", "GPUArraysCore", "OffsetArrays", "StaticArrays", "StaticArraysCore", "Static", "Tracker"]

[weakdeps]
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
8 changes: 7 additions & 1 deletion docs/src/indexing.md
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@ ArrayInterface.can_change_size
ArrayInterface.can_setindex
ArrayInterface.fast_scalar_indexing
ArrayInterface.ismutable
ArrayInterface.is_splat_index
ArrayInterface.ndims_index
ArrayInterface.ndims_shape
ArrayInterface.defines_strides
@@ -22,6 +23,11 @@ ArrayInterface.ensures_sorted
ArrayInterface.indices_do_not_alias
ArrayInterface.instances_do_not_alias
ArrayInterface.device
ArrayInterface.known_first
ArrayInterface.known_step
ArrayInterface.known_last
ArrayInterface.known_size
ArrayInterface.known_length
```

## Allowed Indexing Functions
@@ -46,4 +52,4 @@ and index translations.
ArrayInterface.ArrayIndex
ArrayInterface.GetIndex
ArrayInterface.SetIndex!
```
```
21 changes: 21 additions & 0 deletions ext/ArrayInterfaceOffsetArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module ArrayInterfaceOffsetArraysExt

if isdefined(Base, :get_extension)
using ArrayInterface
using OffsetArrays
else
using ..ArrayInterface
using ..OffsetArrays
end

ArrayInterface.parent_type(@nospecialize T::Type{<:OffsetArrays.IdOffsetRange}) = fieldtype(T, :parent)
ArrayInterface.parent_type(@nospecialize T::Type{<:OffsetArray}) = fieldtype(T, :parent)

function ArrayInterface.known_size(@nospecialize T::Type{<:OffsetArrays.IdOffsetRange})
ArrayInterface.known_size(ArrayInterface.parent_type(T))
end
function ArrayInterface.known_size(@nospecialize T::Type{<:OffsetArray})
ArrayInterface.known_size(ArrayInterface.parent_type(T))
end

end
9 changes: 9 additions & 0 deletions ext/ArrayInterfaceStaticArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -32,4 +32,13 @@ end

ArrayInterface.restructure(x::StaticArraysCore.SArray{S}, y) where {S} = StaticArraysCore.SArray{S}(y)

function ArrayInterface.known_size(::Type{<:StaticArraysCore.StaticArray{S}}) where {S}
@isdefined(S) ? tuple(S.parameters...) : ntuple(_-> nothing, ndims(T))
end

function ArrayInterface.known_length(T::Type{<:StaticArraysCore.StaticArray})
sz = ArrayInterface.known_size(T)
isa(sz, Tuple{Vararg{Nothing}}) ? nothing : prod(sz)
end

end
26 changes: 26 additions & 0 deletions ext/ArrayInterfaceStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module ArrayInterfaceStaticArraysExt

if isdefined(Base, :get_extension)
import ArrayInterface
import StaticArrays
else
import ..ArrayInterface
import ..StaticArrays
end

ArrayInterface.known_first(@nospecialize T::Type{<:StaticArrays.SOneTo}) = 1
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = @isdefined(N) ? N::Int : nothing

function ArrayInterface.known_first(::Type{<:StaticArrays.SUnitRange{S}}) where {S}
@isdefined(S) ? S::Int : nothing
end
function ArrayInterface.known_size(::Type{<:StaticArrays.SUnitRange{<:Any, L}}) where {L}
@isdefined(L) ? (L::Int,) : (nothing,)
end
function ArrayInterface.known_last(::Type{<:StaticArrays.SUnitRange{S, L}}) where {S, L}
start = @isdefined(S) ? S::Int : nothing
len = @isdefined(L) ? L::Int : nothing
(start === nothing || len === nothing) ? nothing : (start + len - 1)
end

end
19 changes: 19 additions & 0 deletions ext/ArrayInterfaceStaticExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module ArrayInterfaceStaticExt

if isdefined(Base, :get_extension)
import ArrayInterface
import Static
else
import ..ArrayInterface
import ..Static
end

ArrayInterface.known_first(::Type{<:Static.OptionallyStaticUnitRange{Static.StaticInt{F}}}) where {F} = F::Int
ArrayInterface.known_first(::Type{<:Static.OptionallyStaticStepRange{Static.StaticInt{F}}}) where {F} = F::Int

ArrayInterface.known_step(::Type{<:Static.OptionallyStaticStepRange{<:Any,Static.StaticInt{S}}}) where {S} = S::Int

ArrayInterface.known_last(::Type{<:Static.OptionallyStaticUnitRange{<:Any,Static.StaticInt{L}}}) where {L} = L::Int
ArrayInterface.known_last(::Type{<:Static.OptionallyStaticStepRange{<:Any,<:Any,Static.StaticInt{L}}}) where {L} = L::Int

end
248 changes: 244 additions & 4 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
@@ -17,6 +17,7 @@ else
end
end
end

@assume_effects :total __parameterless_type(T)=Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)
@@ -486,6 +487,7 @@ end
function cholesky_instance(A::Union{SparseMatrixCSC,Symmetric{<:Number,<:SparseMatrixCSC}}, pivot = DEFAULT_CHOLESKY_PIVOT)
cholesky(sparse(similar(A, 1, 1)), check = false)
end


"""
cholesky_instance(a::Number, pivot = LinearAlgebra.RowMaximum()) -> a
@@ -837,6 +839,13 @@ Base.@propagate_inbounds function Base.getindex(ind::TridiagonalIndex, i::Int)
end
end

"""
is_splat_index(::Type{T}) -> Bool
Returns `true` if `T` is a type that splats across multiple dimensions.
"""
is_splat_index(T::Type) = false
is_splat_index(@nospecialize(x)) = is_splat_index(typeof(x))

"""
ndims_index(::Type{I}) -> Int
@@ -866,7 +875,7 @@ ndims_index(::Type{CartesianIndices{0, Tuple{}}}) = 1
ndims_index(@nospecialize T::Type{<:AbstractArray{Bool}}) = ndims(T)
ndims_index(@nospecialize T::Type{<:AbstractArray}) = ndims_index(eltype(T))
ndims_index(@nospecialize T::Type{<:Base.LogicalIndex}) = ndims(fieldtype(T, :mask))
ndims_index(T::Type) = 1
ndims_index(@nospecialize(T::Type)) = 1
ndims_index(@nospecialize(i)) = ndims_index(typeof(i))

"""
@@ -887,16 +896,14 @@ julia> ndims(CartesianIndices((2,2))[[CartesianIndex(1, 1), CartesianIndex(1, 2)
1
"""
ndims_shape(T::DataType) = ndims_index(T)
ndims_shape(T::Type) = ndims_index(T)
ndims_shape(::Type{Colon}) = 1
ndims_shape(@nospecialize T::Type{<:CartesianIndices}) = ndims(T)
ndims_shape(@nospecialize T::Type{<:Union{Number, Base.AbstractCartesianIndex}}) = 0
ndims_shape(@nospecialize T::Type{<:AbstractArray{Bool}}) = 1
ndims_shape(@nospecialize T::Type{<:AbstractArray}) = ndims(T)
ndims_shape(x) = ndims_shape(typeof(x))



"""
instances_do_not_alias(::Type{T}) -> Bool
@@ -1030,6 +1037,237 @@ ensures_sorted(@nospecialize( T::Type{<:AbstractRange})) = true
ensures_sorted(T::Type) = is_forwarding_wrapper(T) ? ensures_sorted(parent_type(T)) : false
ensures_sorted(@nospecialize(x)) = ensures_sorted(typeof(x))

"""
known_first(I::Type) -> Union{Int, Nothing}
Return the first index in an index range of type `I` when known at compile time.
Otherwise, return `nothing`.
See also: [`ArrayInterface.known_last`](@ref), [`ArrayInterface.known_step`](@ref)
```julia
julia> known_first(typeof(1:4))
nothing
julia> known_first(typeof(Base.OneTo(4)))
1
```
"""
known_first(x) = known_first(typeof(x))
known_first(T::Type) = is_forwarding_wrapper(T) ? known_first(parent_type(T)) : nothing
known_first(::Type{<:Base.OneTo}) = 1
known_first(@nospecialize T::Type{<:LinearIndices}) = 1
known_first(@nospecialize T::Type{<:Base.IdentityUnitRange}) = known_first(parent_type(T))
@inline function known_first(::Type{<:CartesianIndices{N, R}}) where {N, R}
tup = ntuple(i -> known_first(fieldtype(R, i)), Val(N))
isa(tup, NTuple{N, Int}) ? CartesianIndex(tup) : nothing
end

"""
known_last(::Type{T}) -> Union{Int, Nothing}
Return the last index in an index range of type `I` when known at compile time.
Otherwise, return `nothing`.
See also: [`ArrayInterface.known_first`](@ref), [`ArrayInterface.known_step`](@ref)
```julia
julia> known_last(typeof(1:4))
nothing
julia> known_first(typeof(static(1):static(4)))
4
```
"""
known_last(x) = known_last(typeof(x))
known_last(T::Type) = is_forwarding_wrapper(T) ? known_last(parent_type(T)) : nothing
@inline function known_last(::Type{<:CartesianIndices{N, R}}) where {N, R}
tup = ntuple(i -> known_last(fieldtype(R, i)), Val(N))
isa(tup, NTuple{N, Int}) ? CartesianIndex(tup) : nothing
end

"""
known_step(I::Type) -> Union{Int, Nothing}
Return the step size for an index range of type `I` when known at compile time.
Otherwise, return `nothing`.
See also: [`ArrayInterface.known_first`](@ref), [`ArrayInterface.known_last`](@ref)
```julia
julia> known_step(typeof(1:2:8))
nothing
julia> known_step(typeof(1:4))
1
```
"""
known_step(x) = known_step(typeof(x))
known_step(T::Type) = is_forwarding_wrapper(T) ? known_step(parent_type(T)) : nothing
known_step(@nospecialize T::Type{<:AbstractUnitRange}) = 1

"""
known_size(::Type{T}) -> Tuple
known_size(::Type{T}, dim) -> Union{Int, Nothing}
Returns the size of each dimension of `A` or along dimension `dim` of `A` that is known at
compile time. If a dimension does not have a known size along a dimension then `nothing` is
returned in its position.
"""
@inline known_size(x, dim::Integer) = ndims(x) < dim ? 1 : known_size(x)[dim]
known_size(x) = known_size(typeof(x))
@inline function known_size(T::Type)
if is_forwarding_wrapper(T)
return known_size(parent_type(T))
elseif isa(Base.IteratorSize(T), Base.HasShape)
return ntuple(_ -> nothing, ndims(T))
else
return (known_length(T),)
end
end
@inline known_size(@nospecialize T::Type{<:Number}) = ()
@inline known_size(@nospecialize T::Type{<:VecAdjTrans}) = (1, known_length(parent_type(T)))
@inline function known_size(@nospecialize T::Type{<:MatAdjTrans})
s1, s2 = known_size(parent_type(T))
(s2, s1)
end
function known_size(::Type{<:PermutedDimsArray{<:Any, N, I1, I2, P}}) where {N, I1, I2, P}
psize = known_size(P)
ntuple(i -> getfield(psize, getfield(I1, i)), Val{N}())
end
function known_size(@nospecialize T::Type{<:Diagonal})
s = known_length(parent_type(T))
(s, s)
end
known_size(@nospecialize T::Type{<:Union{Symmetric,Hermitian}}) = known_size(parent_type(T))
@inline function known_size(::Type{<:Base.ReinterpretArray{T,N,S,A,IsReshaped}}) where {T,N,S,A,IsReshaped}
psize = known_size(A)
if IsReshaped
if sizeof(S) > sizeof(T)
return (div(sizeof(S), sizeof(T)), psize...)
elseif sizeof(S) < sizeof(T)
return Base.tail(psize)
else
return psize
end
else
if Base.issingletontype(T) || first(psize) === nothing
return psize
else
return (div(first(psize) * sizeof(S), sizeof(T)), Base.tail(psize)...)
end
end
end
known_size(::Type{<:Base.IdentityUnitRange{I}}) where {I} = known_size(I)
known_size(::Type{<:Base.Generator{I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Reverse{I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Enumerate{I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Accumulate{<:Any,I}}) where {I} = known_size(I)
known_size(::Type{<:Iterators.Pairs{<:Any,<:Any,I}}) where {I} = known_size(I)
@inline function known_size(::Type{<:Iterators.ProductIterator{T}}) where {T}
ntuple(i -> known_length(fieldtype(T, i)), Val(known_length(T)))
end
@inline function known_size(@nospecialize T::Type{<:AbstractRange})
if is_forwarding_wrapper(T)
return known_size(parent_type(T))
else
start = known_first(T)
s = known_step(T)
stop = known_last(T)
if isa(stop, Int) && isa(s, Int) && isa(start, Int)
if s > 0
return (stop < start ? 0 : div(stop - start, s) + 1,)
else
return (stop > start ? 0 : div(start - stop, -s) + 1,)
end
else
return (nothing,)
end
end
end

@inline function known_size(@nospecialize T::Type{<:Union{LinearIndices,CartesianIndices}})
I = fieldtype(T, :indices)
ntuple(i -> known_length(fieldtype(I, i)), Val(ndims(T)))
end

@inline function known_size(T::Type{<:SubArray})
I = fieldtype(T, :indices)
ninds = fieldcount(I)
if ninds === 1
I_1 = fieldtype(I, 1)
return I_1 <: Base.Slice ? (known_length(parent_type(T)),) : known_size(I_1)
else
psize = known_size(parent_type(T))
ndi_summed = cumsum(map_tuple_type(ndims_index, I))
sz = ntuple(Val{nfields(ndi_summed)}()) do i
I_i = fieldtype(I, i)
if I_i <: Base.Slice
getfield(psize, getfield(ndi_summed, i))
else
known_size(I_i)
end
end
return flatten_tuples(sz)
end
end

# 1. `Zip` doesn't check that its collections are compatible (same size) at construction,
# but we assume as much b/c otherwise it will error while iterating. So we promote to the
# known size if matching a `Nothing` and `Int` size.
# 2. `promote_shape(::Tuple{Vararg{IntType}}, ::Tuple{Vararg{IntType}})` promotes
# trailing dimensions (which must be of size 1), to `static(1)`. We want to stick to
# `Nothing` and `Int` types, so we do one last pass to ensure everything is dynamic
@inline function known_size(::Type{<:Iterators.Zip{T}}) where {T}
reduce(promote_known_shape, map_tuple_type(known_size, T))
end
function promote_known_shape(x::Tuple{Vararg{Union{Nothing,Int}, XN}}, y::Tuple{Vararg{Union{Nothing,Int}, YN}}) where {XN, YN}
if XN >= YN
ntuple(Val{XN}()) do i
x_i = getfield(x, i)
x_i === nothing ? i > YN ? 1 : getfield(y, i) : x_i
end
else
return promote_known_shape(y, x)
end
end

"""
known_length(::Type{T}) -> Union{Int, Nothing}
If `length` of an instance of type `T` is known at compile time, return it.
Otherwise, return `nothing`.
"""
known_length(x) = known_length(typeof(x))
function known_length(::Type{T}) where {T}
if isa(Base.IteratorSize(T), Base.HasShape)
# this is a multidimensional iterator so we assume that known_size is defined
sz = known_size(T)
len = 1
for sz_i in sz
isa(sz_i, Int) || return nothing
len *= sz_i
end
return len
else
# if it is an iterator with length it's compile time length is not provided
return nothing
end
end

known_length(::Type{<:NamedTuple{L}}) where {L} = nfields(L)
known_length(@nospecialize T::Type{<:Base.Slice}) = known_length(parent_type(T))
known_length(@nospecialize T::Type{<:Tuple}) = fieldcount(T)
known_length(@nospecialize T::Type{<:Number}) = 1
known_length(::Type{<:Base.AbstractCartesianIndex{N}}) where {N} = N::Int
function known_length(::Type{<:Iterators.Flatten{I}}) where {I}
lenitr = known_length(I)
lenelt = known_length(eltype(I))
(lenelt isa Int && lenitr isa Int) ? (lenitr * lenelt) : nothing
end

## Extensions

import Requires
@@ -1039,6 +1277,8 @@ import Requires
Requires.@require BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" begin include("../ext/ArrayInterfaceBlockBandedMatricesExt.jl") end
Requires.@require GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" begin include("../ext/ArrayInterfaceGPUArraysCoreExt.jl") end
Requires.@require StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticArraysCoreExt.jl") end
Requires.@require StaticArrays = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticArraysExt.jl") end
Requires.@require Static = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" begin include("../ext/ArrayInterfaceStaticExt.jl") end
Requires.@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" begin include("../ext/ArrayInterfaceCUDAExt.jl") end
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/ArrayInterfaceTrackerExt.jl") end
end
52 changes: 50 additions & 2 deletions test/core.jl
Original file line number Diff line number Diff line change
@@ -262,7 +262,6 @@ end

@testset "linearalgebra instances" begin
for A in [rand(2,2), rand(Float32,2,2), rand(BigFloat,2,2)]

@test ArrayInterface.lu_instance(A) isa typeof(lu(A))
@test ArrayInterface.qr_instance(A) isa typeof(qr(A))

@@ -282,4 +281,53 @@ end
end
@test ArrayInterface.ldlt_instance(SymTridiagonal(A' * A)) isa typeof(ldlt(SymTridiagonal(A' * A)))
end
end
end

@testset "known values" begin
CI = CartesianIndices((2, 2))

@test isnothing(@inferred(ArrayInterface.known_first(typeof(1:4))))
@test isone(@inferred(ArrayInterface.known_first(Base.OneTo(4))))
@test isone(@inferred(ArrayInterface.known_first(Base.IdentityUnitRange(Base.OneTo(4)))))
@test isone(@inferred(ArrayInterface.known_first(LinearIndices((1, 1, 1)))))
@test isone(@inferred(ArrayInterface.known_first(typeof(Base.OneTo(4)))))
@test @inferred(ArrayInterface.known_first(typeof(CI))) == CartesianIndex(1, 1)
@test @inferred(ArrayInterface.known_first(typeof(CI))) == CartesianIndex(1, 1)

@test isnothing(@inferred(ArrayInterface.known_last(1:4)))
@test isnothing(@inferred(ArrayInterface.known_last(typeof(1:4))))
@test @inferred(ArrayInterface.known_last(typeof(CI))) === nothing

@test isnothing(@inferred(ArrayInterface.known_step(typeof(1:0.2:4))))
@test isone(@inferred(ArrayInterface.known_step(1:4)))
@test isone(@inferred(ArrayInterface.known_step(typeof(1:4))))
@test isone(@inferred(ArrayInterface.known_step(typeof(Base.Slice(1:4)))))
@test isone(@inferred(ArrayInterface.known_step(typeof(view(1:4, 1:2)))))

A = zeros(3, 4, 5);
A[:] = 1:60
Ap = @view(PermutedDimsArray(A, (3, 1, 2))[:, 1:2, 1])';
Ar = reinterpret(Float32, A);
A_trailingdim = zeros(2, 3, 4, 1)
D = @view(A[:, 2:2:4, :]);
A2 = zeros(4, 3, 5)
A2r = reinterpret(ComplexF64, A2)

@test @inferred(ArrayInterface.known_size(1)) === ()
@test @inferred(ArrayInterface.known_size([1, 1]')) === (1, nothing)
@test @inferred(ArrayInterface.known_size(view([1, 1]', :, 1))) === (1, )
@test @inferred(ArrayInterface.known_size(Diagonal(view([1, 1]', :, 1)))) === (1, 1)
@test @inferred(ArrayInterface.known_size(view(rand(4), reshape(1:4, 2, 2)))) == (nothing, nothing)
@test @inferred(ArrayInterface.known_size(A)) === (nothing, nothing, nothing)
@test @inferred(ArrayInterface.known_size(Ap)) === (nothing, nothing)
@test @inferred(ArrayInterface.known_size(Ar)) === (nothing, nothing, nothing,)
@test ArrayInterface.known_size(Ar, 1) === nothing
@test ArrayInterface.known_size(Ar, 4) === 1
@test @inferred(ArrayInterface.known_size(A2)) === (nothing, nothing, nothing)
@test @inferred(ArrayInterface.known_size(A2r)) === (nothing, nothing, nothing)

@test @inferred(ArrayInterface.known_length(1)) === 1
@test @inferred(ArrayInterface.known_length(Base.Slice(1:2))) === nothing
@test @inferred(ArrayInterface.known_length(CartesianIndex(1, 2, 3))) === 3
@test @inferred(ArrayInterface.known_length((x = 1, y = 2))) === 2
end
15 changes: 15 additions & 0 deletions test/offsetarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@

using ArrayInterface
using OffsetArrays
using StaticArrays
using Test

oa = OffsetArray([1, 2]', 1, 1)
@test @inferred(ArrayInterface.known_size(oa)) == (1, nothing)
@test @inferred(ArrayInterface.known_length(oa)) === nothing


id = OffsetArrays.IdOffsetRange(SOneTo(10), 1)
@test @inferred(ArrayInterface.known_size(id)) == (10, )
@test @inferred(ArrayInterface.known_length(id)) == 10

4 changes: 3 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -14,10 +14,12 @@ end
@time @safetestset "BlockBandedMatrices" begin include("blockbandedmatrices.jl") end
@time @safetestset "Core" begin include("core.jl") end
@time @safetestset "StaticArraysCore" begin include("staticarrayscore.jl") end
@time @safetestset "StaticArrays" begin include("staticarrays.jl") end
@time @safetestset "Static" begin include("static.jl") end
end

if GROUP == "GPU"
activate_gpu_env()
@time @safetestset "CUDA" begin include("gpu/cuda.jl") end
end
end
end
11 changes: 11 additions & 0 deletions test/static.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@

using ArrayInterface
using Static
using Test

iprod = Iterators.product(static(1):static(2), static(1):static(3), static(1):static(4))
@test @inferred(ArrayInterface.known_size(iprod)) === (2, 3, 4)

iflat = Iterators.flatten(iprod)
@test @inferred(ArrayInterface.known_size(iflat)) === (72,)

52 changes: 52 additions & 0 deletions test/staticarrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@

using ArrayInterface
using StaticArrays
using Test

so = SOneTo(10)
@test ArrayInterface.known_first(typeof(so)) == first(so)
@test ArrayInterface.known_last(typeof(so)) == last(so)
@test ArrayInterface.known_length(typeof(so)) == length(so)

su = StaticArrays.SUnitRange(2, 10)
@test ArrayInterface.known_first(typeof(su)) == first(su)
@test ArrayInterface.known_last(typeof(su)) == last(su)
@test ArrayInterface.known_length(typeof(su)) == length(su)

S = @SArray(zeros(2, 3, 4))
Sp = @view(PermutedDimsArray(S, (3, 1, 2))[2:3, 1:2, :]);
Sp2 = @view(PermutedDimsArray(S, (3, 2, 1))[2:3, :, :]);
Mp = @view(PermutedDimsArray(S, (3, 1, 2))[:, 2, :])';
Mp2 = @view(PermutedDimsArray(S, (3, 1, 2))[2:3, :, 2])';


irev = Iterators.reverse(S)
igen = Iterators.map(identity, S)
iacc = Iterators.accumulate(+, S)

ienum = enumerate(S)
ipairs = pairs(S)
izip = zip(S, S)

@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(irev)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(igen)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(iacc)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(ienum)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(izip)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(ipairs)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(zip(S, zeros(2, 3, 4, 1)))) === (2, 3, 4, 1)
@test @inferred(ArrayInterface.known_size(zip(zeros(2, 3, 4, 1), S))) === (2, 3, 4, 1)
@test @inferred(ArrayInterface.known_length(Iterators.flatten(((x, y) for x in 0:1 for y in 'a':'c')))) === nothing
@test ArrayInterface.known_length(S) == length(S)


@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(Sp)) === (nothing, nothing, 3)
@test @inferred(ArrayInterface.known_size(Sp2)) === (nothing, 3, 2)
@test ArrayInterface.known_size(Sp2, 1) === nothing
@test ArrayInterface.known_size(Sp2, 2) === 3
@test ArrayInterface.known_size(Sp2, 3) === 2
@test @inferred(ArrayInterface.known_size(Mp)) === (3, 4)
@test @inferred(ArrayInterface.known_size(Mp2)) === (2, nothing)

2 changes: 2 additions & 0 deletions test/staticarrayscore.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

using StaticArrays, ArrayInterface, Test
using LinearAlgebra
using ArrayInterface: undefmatrix, zeromatrix
@@ -43,3 +44,4 @@ zr = ArrayInterface.restructure(x, z)
end
end
end