Skip to content

Commit

Permalink
BlockArrays in module and more efficient Conv
Browse files Browse the repository at this point in the history
  • Loading branch information
nantonel committed Nov 27, 2017
1 parent af0bf85 commit 322a9d0
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 41 deletions.
9 changes: 4 additions & 5 deletions src/AbstractOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ __precompile__()

module AbstractOperators

# Block stuff
include("utilities/block.jl")
using AbstractOperators.BlockArrays

abstract type AbstractOperator end

abstract type LinearOperator <: AbstractOperator end
Expand All @@ -13,11 +17,6 @@ export LinearOperator,
NonLinearOperator,
AbstractOperator

# Block stuff

include("utilities/block.jl")
#include("utilities/deep.jl") # TODO: remove this eventually

# Predicates and properties

include("properties.jl")
Expand Down
57 changes: 48 additions & 9 deletions src/linearoperators/Conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,70 @@ Creates a `LinearOperator` which, when multiplied with an array `x::AbstractVect
"""

struct Conv{T,H <: AbstractVector{T}} <: LinearOperator
struct Conv{T,
H <: AbstractVector{T},
Hc <: AbstractVector{Complex{T}},
} <: LinearOperator
dim_in::Tuple{Int}
h::H
buf::H
buf_c1::Hc
buf_c2::Hc
R::Base.DFT.Plan
I::Base.DFT.Plan
end

# Constructors

###standard constructor
function Conv(DomainType::Type, DomainDim::NTuple{N,Int}, h::H) where {H<:AbstractVector, N}
function Conv(DomainType::Type, dim_in::Tuple{Int}, h::H) where {H<:AbstractVector}
eltype(h) != DomainType && error("eltype(h) is $(eltype(h)), should be $(DomainType)")
N != 1 && error("Conv treats only SISO, check Filt and MIMOFilt for MIMO")
Conv{DomainType,H}(DomainDim,h)

buf = zeros(dim_in[1]+length(h)-1)
R = plan_rfft(buf)
buf_c1 = zeros(Complex{eltype(h)}, div(dim_in[1]+length(h)-1,2)+1)
buf_c2 = zeros(Complex{eltype(h)}, div(dim_in[1]+length(h)-1,2)+1)
I = plan_irfft(buf_c1,dim_in[1]+length(h)-1)
Conv{DomainType, H, typeof(buf_c1)}(dim_in,h,buf,buf_c1,buf_c2,R,I)
end

Conv(DomainDim::NTuple{N,Int}, h::H) where {H<:AbstractVector, N} = Conv(eltype(h), DomainDim, h)
Conv(dim_in::NTuple{N,Int}, h::H) where {H<:AbstractVector, N} = Conv(eltype(h), dim_in, h)
Conv(x::H, h::H) where {H} = Conv(eltype(x), size(x), h)

# Mappings

function A_mul_B!(y::H,A::Conv{T,H},b::H) where {T,H}
y .= conv(A.h,b)
function A_mul_B!(y::H, A::Conv{T,H}, b::H) where {T, H}
#y .= conv(A.h,b) #naive implementation
for i in eachindex(A.buf)
A.buf[i] = i <= length(A.h) ? A.h[i] : zero(T)
end
A_mul_B!(A.buf_c1, A.R, A.buf)
for i in eachindex(A.buf)
A.buf[i] = i <= length(b) ? b[i] : zero(T)
end
A_mul_B!(A.buf_c2, A.R, A.buf)
A.buf_c2 .*= A.buf_c1
A_mul_B!(y,A.I,A.buf_c2)

end

function Ac_mul_B!(y::H,A::Conv{T,H},b::H) where {T,H}
y .= xcorr(b,A.h)[size(A,1)[1]:end-length(A.h)+1]
function Ac_mul_B!(y::H, A::Conv{T,H}, b::H) where {T, H}
#y .= xcorr(b,A.h)[size(A,1)[1]:end-length(A.h)+1] #naive implementation
for i in eachindex(A.buf)
ii = length(A.buf)-i+1
A.buf[ii] = i <= length(A.h) ? A.h[i] : zero(T)
end
A_mul_B!(A.buf_c1, A.R, A.buf)
for i in eachindex(A.buf)
A.buf[i] = b[i]
end
A_mul_B!(A.buf_c2, A.R, A.buf)
A.buf_c2 .*= A.buf_c1
A_mul_B!(A.buf,A.I,A.buf_c2)
y[1] = A.buf[end]
for i = 2:length(y)
y[i] = A.buf[i-1]
end
end

# Properties
Expand Down
19 changes: 19 additions & 0 deletions src/utilities/block.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,21 @@
# Define block-arrays
module BlockArrays

export RealOrComplex,
BlockArray,
blocksize,
blockeltype,
blocklength,
blockvecnorm,
blockmaxabs,
blocksimilar,
blockcopy,
blockcopy!,
blockset!,
blockvecdot,
blockzeros,
blockaxpy!


const RealOrComplex{R} = Union{R, Complex{R}}
const BlockArray{R} = Union{
Expand Down Expand Up @@ -86,3 +103,5 @@ function broadcast!(f::Any, dest::Tuple, op1::Tuple, coef::Number, op2::Tuple, o
end
return dest
end

end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ verb = true

@testset "AbstractOperators" begin

#@testset "Tuple operations" begin
# include("test_deep.jl")
#end
@testset "Block Arrays" begin
include("test_block.jl")
end

@testset "Linear operators" begin
include("test_linear_operators.jl")
Expand Down
79 changes: 79 additions & 0 deletions test/test_block.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
@printf("\nTesting BlockArrays\n")

using AbstractOperators:BlockArrays

T = Float64
x = [2.0; 3.0]
xb = ([2.0; 3.0], Complex{T}[4.0; 5.0; 6.0], [1.0 2.0 3.0; 4.0 5.0 6.0])

x2 = randn(2)
x2b = (randn(2), randn(3)+im.*randn(3), randn(2,3))

@test blocksize(x) == (2,)
@test blocksize(xb) == ((2,),(3,),(2,3))

@test blockeltype(x) == T
@test blockeltype(xb) == (T, Complex{T},T)

@test blocklength(x) == 2
@test blocklength(xb) == 2+3+6

@test blockvecnorm(x) == vecnorm(x)
@test blockvecnorm(xb) == sqrt(vecnorm(xb[1])^2+vecnorm(xb[2])^2+vecnorm(xb[3])^2)

@test blockmaxabs(x) == 3.0
@test blockmaxabs(xb) == 6.0

@test typeof(blocksimilar(x)) == typeof(x)
@test typeof(blocksimilar(xb)) == typeof(xb)

@test blockcopy(x) == x
@test blockcopy(xb) == xb

y = blocksimilar(x)
yb = blocksimilar(xb)
blockcopy!(y,x)
blockcopy!(yb,xb)

@test y == x
@test yb == xb

y = blocksimilar(x)
yb = blocksimilar(xb)
blockset!(y,x)
blockset!(yb,xb)

@test y == x
@test yb == xb

z = blockvecdot(x,x2)
zb = blockvecdot(xb,x2b)
@test z == vecdot(x,x2)
@test zb == vecdot(xb[1],x2b[1])+vecdot(xb[2],x2b[2])+vecdot(xb[3],x2b[3])

y = blockzeros(x)
yb = blockzeros(xb)
@test y == zeros(2)
@test yb == (zeros(2),zeros(3)+im*zeros(3),zeros(2,3))

y = blockzeros(blocksize(x))
yb = blockzeros(blocksize(xb))
@test y == zeros(2)
@test yb == (zeros(2),zeros(3)+im*zeros(3),zeros(2,3))

y = blockzeros(blockeltype(x),blocksize(x))
yb = blockzeros(blockeltype(xb),blocksize(xb))
@test y == zeros(2)
@test yb == (zeros(2),zeros(3)+im*zeros(3),zeros(2,3))

blockaxpy!(y,x,2,x2)
blockaxpy!(yb,xb,2,x2b)

@test y == x .+ 2.*x2
@test yb == (xb[1] .+ 2.*x2b[1], xb[2] .+ 2.*x2b[2], xb[3] .+ 2.*x2b[3])






24 changes: 0 additions & 24 deletions test/test_deep.jl

This file was deleted.

0 comments on commit 322a9d0

Please sign in to comment.