Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] BlockIndices #356

Draft
wants to merge 8 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/BlockArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ end
include("blockindices.jl")
include("blockaxis.jl")
include("abstractblockarray.jl")
include("block_indices.jl")
include("blockarray.jl")
include("pseudo_blockarray.jl")
include("views.jl")
Expand Down
116 changes: 116 additions & 0 deletions src/block_indices.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
function to_blockindex(a::AbstractUnitRange, index::Integer)
axis_blockindex = findblockindex(only(axes(a)), index)
if !isone(first(a)) && block(axis_blockindex) == Block(1)
axis_blockindex = block(axis_blockindex)[blockindex(axis_blockindex) + first(a) - one(eltype(a))]

Check warning on line 4 in src/block_indices.jl

View check run for this annotation

Codecov / codecov/patch

src/block_indices.jl#L4

Added line #L4 was not covered by tests
end
return axis_blockindex
end

# https://github.com/JuliaArrays/BlockArrays.jl/issues/347
function blockedgetindex(a::BlockedUnitRange, indices::AbstractUnitRange)
first_block = block(to_blockindex(a, first(indices)))
last_block = block(to_blockindex(a, last(indices)))
lasts = [blocklasts(a)[Int(first_block):(Int(last_block) - 1)]; last(indices)]
return BlockArrays._BlockedUnitRange(first(indices), lasts)
end
function blockedgetindex(a::AbstractUnitRange, indices::AbstractUnitRange)
return a[indices]

Check warning on line 17 in src/block_indices.jl

View check run for this annotation

Codecov / codecov/patch

src/block_indices.jl#L16-L17

Added lines #L16 - L17 were not covered by tests
end

function _BlockIndices end

struct BlockIndices{N,R<:NTuple{N,AbstractUnitRange{Int}},BS<:NTuple{N,AbstractUnitRange{Int}}} <: AbstractBlockArray{BlockIndex{N},N}
first::NTuple{N,Int}
indices::R
axes::BS
global function _BlockIndices(first::NTuple{N,Int}, indices::R, axes::BS) where {N,R<:NTuple{N,AbstractUnitRange{Int}},BS<:NTuple{N,AbstractUnitRange{Int}}}
return new{N,R,BS}(first, indices, axes)
end
end
Base.axes(a::BlockIndices) = a.axes
function BlockIndices(indices::Tuple{Vararg{AbstractUnitRange{Int},N}}) where {N}
first = ntuple(_ -> 1, Val(N))
axes = map(Base.axes1, blockedgetindex.(indices, Base.OneTo.(last.(indices))))
return _BlockIndices(first, indices, axes)
end
BlockIndices(a::AbstractArray) = BlockIndices(axes(a))

function Base.getindex(a::BlockIndices{N}, index::Vararg{Integer,N}) where {N}
return BlockIndex(to_blockindex.(a.indices, index .+ a.first .- 1))
end

function Base.view(a::BlockIndices{N}, block::Block{N}) where {N}
return viewblock(a, block)
end

function Base.view(a::BlockIndices{1}, block::Block{1})
return viewblock(a, block)
end

function Base.view(a::BlockIndices{N}, block::Vararg{Block{1},N}) where {N}
return view(a, Block(block))
end

function viewblock(a::BlockIndices, block)
range = Base.OneTo.(getindex.(blocklengths.(axes(a)), Int.(Tuple(block))))
return block[range...]
end

function Base.view(a::BlockIndices{N}, indices::Vararg{BlockIndexRange{1},N}) where {N}
return view(a, BlockIndexRange(Block(block.(indices)), only.(getfield.(indices, :indices))))

Check warning on line 60 in src/block_indices.jl

View check run for this annotation

Codecov / codecov/patch

src/block_indices.jl#L59-L60

Added lines #L59 - L60 were not covered by tests
end

function Base.view(a::BlockIndices{N}, indices::BlockIndexRange{N}) where {N}
a_block = a[block(indices)]
return block(a_block)[getindex.(a_block.indices, indices.indices)...]
end

# Circumvent that this is getting hijacked to call `a[block(indices)][indices.indices...]`,
# which hits the bug https://github.com/JuliaArrays/BlockArrays.jl/issues/355.
function Base.getindex(a::BlockIndices{N}, indices::BlockIndexRange{N}) where {N}
return view(a, indices)
end

function Base.view(a::BlockIndices{N}, indices::Vararg{AbstractUnitRange,N}) where {N}
offsets = a.first .- ntuple(_ -> 1, Val(N))
firsts = first.(indices) .+ offsets
inds = blockedgetindex.(a.indices, Base.OneTo.(last.(indices) .+ offsets))
return _BlockIndices(firsts, inds, Base.axes1.(indices))
end

# Ranges that result in contiguous slices, and therefore preserve `BlockIndices`.
const BlockOrUnitRanges = Union{AbstractUnitRange,CartesianIndices{1},Block{1},BlockRange{1},BlockIndexRange{1}}

function Base.view(a::BlockIndices{N}, indices::Vararg{BlockOrUnitRanges,N}) where {N}
return view(a, to_indices(a, indices)...)
end

function Base.view(a::BlockIndices{N}, indices::CartesianIndices{N}) where {N}
return view(a, to_indices(a, (indices,))...)

Check warning on line 89 in src/block_indices.jl

View check run for this annotation

Codecov / codecov/patch

src/block_indices.jl#L88-L89

Added lines #L88 - L89 were not covered by tests
end

# For some reason this doesn't call `view` automatically.
function Base.getindex(a::BlockIndices{N}, indices::CartesianIndices{N}) where {N}
return view(a, to_indices(a, (indices,))...)
end

function Base.view(a::BlockIndices{N}, indices::BlockRange{N}) where {N}
return view(a, to_indices(a, (indices,))...)

Check warning on line 98 in src/block_indices.jl

View check run for this annotation

Codecov / codecov/patch

src/block_indices.jl#L97-L98

Added lines #L97 - L98 were not covered by tests
end

# For some reason this doesn't call `view` automatically.
function Base.getindex(a::BlockIndices{N}, indices::BlockRange{N}) where {N}
return view(a, to_indices(a, (indices,))...)
end

function Base.getindex(a::CartesianIndices{N}, indices::BlockIndices{N}) where {N}
# TODO: Is there a better way to write this?
new_axes = (:).(Tuple(a[first(indices)]), Tuple(a[last(indices)]))
firsts = first.(new_axes)
blocklasts = ntuple(i -> accumulate(+, blocklengths(axes(indices, i)); init=firsts[i] - one(firsts[i])), Val(N))
return CartesianIndices(_BlockedUnitRange.(firsts, blocklasts))
end

function Base.getindex(a::AbstractArray{<:Any,N}, indices::BlockIndices{N}) where {N}
return a[CartesianIndices(a)[indices]]
end
98 changes: 97 additions & 1 deletion test/test_blockindices.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using BlockArrays, FillArrays, Test, StaticArrays, ArrayLayouts
using OffsetArrays
import BlockArrays: BlockIndex, BlockIndexRange, BlockSlice
import BlockArrays: BlockIndex, BlockIndexRange, BlockIndices, BlockSlice

@testset "Blocks" begin
@test Int(Block(2)) === Integer(Block(2)) === Number(Block(2)) === 2
Expand Down Expand Up @@ -446,3 +446,99 @@ end
first(eachblock(B))[1,2] = 0
@test B[1,2] == 0
end

@testset "BlockIndices" begin
v = Array(reshape(1:35, (5, 7)))
A = BlockArray(v, [2, 3], [3, 4])
B = BlockIndices(A)
@test B == BlockIndices((blockedrange([2,3]), blockedrange([3,4])))
@test eltype(B) === BlockIndex{2}
@test size(B) == (5, 7)
@test axes(B) == (1:5, 1:7)
@test blocklengths.(axes(B)) == ([2, 3], [3, 4])
@test blocksize(B) == (2, 2)
@test blockaxes(B) == (Block.(1:2), Block.(1:2))
@test B[1, 1] == Block(1, 1)[1, 1]
@test B[4, 6] == Block(2, 2)[2, 3]

@test B[Block(1, 1)] isa BlockIndexRange{2}
@test B[Block(1, 1)] == Block(1, 1)[1:2, 1:3]
@test B[Block(2, 1)] == Block(2, 1)[1:3, 1:3]
@test B[Block(1, 2)] == Block(1, 2)[1:2, 1:4]
@test B[Block(2, 2)] == Block(2, 2)[1:3, 1:4]

@test view(B, Block(1, 2)) isa BlockIndexRange{2}
@test view(B, Block(1, 2)) == Block(1, 2)[1:2, 1:4]

@test B[Block(1), Block(2)] isa BlockIndexRange{2}
@test B[Block(1), Block(2)] == Block(1, 2)[1:2, 1:4]

@test view(B, Block(1), Block(2)) isa BlockIndexRange{2}
@test view(B, Block(1), Block(2)) == Block(1, 2)[1:2, 1:4]

B23_24 = mortar([[Block(1, 1)[2:2, 2:3]] [Block(1, 2)[2:2, 1:1]]
[Block(2, 1)[1:1, 2:3]] [Block(2, 2)[1:1, 1:1]]])
Br = B[2:3, 2:4]
@test Br == B23_24
@test blocksize(Br) == (1, 1)
@test Br isa AbstractMatrix{<:BlockIndex{2}}
@test Br isa BlockIndices{2}
Br = B[CartesianIndices((2:3,)), CartesianIndices((2:4,))]
@test Br == B23_24
@test blocksize(Br) == (1, 1)
@test Br isa AbstractMatrix{<:BlockIndex{2}}
@test Br isa BlockIndices{2}
Br = B[CartesianIndices((2:3, 2:4))]
@test Br == B23_24
@test Br isa AbstractMatrix{<:BlockIndex{2}}
@test Br isa BlockIndices{2}

@test B[Block(2, 2)[2:3, 2:3]] == Block(2, 2)[2:3, 2:3]
@test B[Block(2, 2)[2:3, 2:3]] isa BlockIndexRange{2}

@test B[Block.(1:2), Block.(1:2)] == B
@test B[Block.(1:2), Block.(1:2)] isa BlockIndices{2}

@test B[BlockRange(1:2, 1:2)] == B
@test B[BlockRange(1:2, 1:2)] isa BlockIndices{2}

@test B[Block.(2:2), Block.(1:2)] == mortar([[Block(2, 1)[1:3, 1:3]] [Block(2, 2)[1:3, 1:4]]])
@test B[Block.(2:2), Block.(1:2)] isa BlockIndices{2}

@test B[2:4, 2:5][2:3, 2:3] == mortar([[Block(2, 1)[1:2, 3:3]] [Block(2, 2)[1:2, 1:1]]])
@test B[2:4, 2:5][2:3, 2:3] isa BlockIndices{2}

B = BlockIndices(blockedrange([2, 3]))
@test B == [Block(1)[1], Block(1)[2], Block(2)[1], Block(2)[2], Block(2)[3]]
@test blocklengths(only(axes(B))) == [2, 3]
@test B[Block(1)] == Block(1)[1:2]
@test B[Block(1)] isa BlockIndexRange{1}
@test B[Block(2)] == Block(2)[1:3]
@test B[Block(2)] isa BlockIndexRange{1}
@test B[2:4] == [Block(1)[2], Block(2)[1], Block(2)[2]]
@test blocklength(only(axes(B[2:4]))) == 1
@test blocklengths(only(axes(B[2:4]))) == [3]
@test B[2:4] isa AbstractVector{<:BlockIndex{1}}
@test B[2:4] isa BlockIndices{1}

A = BlockArray(randn(5, 5), [2, 3], [2, 3])
BA = BlockIndices(A)
r = BlockArrays._BlockedUnitRange(2, [2, 4])
B = BA[r, r]
@test B == BA[2:4, 2:4]
@test size(B) == (3, 3)
@test blocksize(B) == (2, 2)
@test blocklengths.(axes(B)) == ([1, 2], [1, 2])
CB = CartesianIndices(A)[B]
@test CB == CartesianIndices((2:4, 2:4))
@test size(CB) == (3, 3)
@test blocksize(CB) == (2, 2)
@test blocklengths.(axes(CB)) == ([1, 2], [1, 2])
@test all(blockisequal.(axes(CB), axes(B)))
AB = A[B]
@test AB == A[2:4, 2:4]
@test size(AB) == (3, 3)
@test blocksize(AB) == (2, 2)
@test blocklengths.(axes(AB)) == ([1, 2], [1, 2])
@test all(blockisequal.(axes(AB), axes(B)))
end
Loading