Skip to content

Commit

Permalink
add SubDitStr function to enable DitStr slicing (#54)
Browse files Browse the repository at this point in the history
* add `SubDitStr` function to enable `DitStr` slicing

* * fix doc and doctests
* new `DitStr` function to raise `SubDitStr` struct to `DitStr`
* tests added

* * add @views macro for `SubDitStr`
* add benchmark

* remove `using BenchmarkTools`

* comment bm()

* update

* update
  • Loading branch information
hz-xiaxz authored Jun 28, 2024
1 parent e264b10 commit 4c1e25a
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 20 deletions.
1 change: 1 addition & 0 deletions src/BitBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export bitarray, basis, packbits, bfloat, bfloat_r, bint, bint_r, flip
export anyone, allone, bmask, baddrs, readbit, setbit, controller
export swapbits, ismatch, neg, breflect, btruncate
export LongLongUInt
export SubDitStr

include("utils.jl")
include("longlonguint.jl")
Expand Down
156 changes: 139 additions & 17 deletions src/DitStr.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const UIntStorage = Union{UInt8,UInt16,UInt32,UInt64,UInt128,LongLongUInt}
const IntStorage = Union{Int8,Int16,Int32,Int64,Int128,BigInt,UIntStorage}
const UIntStorage = Union{UInt8, UInt16, UInt32, UInt64, UInt128, LongLongUInt}
const IntStorage = Union{Int8, Int16, Int32,Int64,Int128,BigInt,UIntStorage}

########## DitStr #########
"""
Expand Down Expand Up @@ -37,26 +37,26 @@ function DitStr{D,T}(vector::Union{AbstractVector,Tuple}) where {D,T}
val = zero(T)
D_power_k = one(T)
for k in 1:length(vector)
0 <= vector[k] <= D-1 || error("expect 0-$(D-1), got $(vector[k])")
0 <= vector[k] <= D - 1 || error("expect 0-$(D-1), got $(vector[k])")
val = accum(Val{D}(), val, vector[k], D_power_k)
D_power_k = _lshift(Val{D}(), D_power_k, 1)
end
return DitStr{D,length(vector),T}(val)
end
# val += x * y
accum(::Val{D}, val, x, y) where D = val + x * y
accum(::Val{D}, val, x, y) where {D} = val + x * y
accum(::Val{2}, val, x, y) = iszero(x) ? val : val y
DitStr{D}(vector::Tuple{T,Vararg{T,N}}) where {N,T,D} = DitStr{D,T}(vector)
DitStr{D}(vector::AbstractVector{T}) where {D,T} = DitStr{D,T}(vector)
DitStr{D,N,T}(val::DitStr) where {D,N,T<:Integer} = convert(DitStr{D,N,T}, val)
DitStr{D,N,T}(val::DitStr{D,N,T}) where {D,N,T<:Integer} = val

const DitStr64{D,N} = DitStr{D,N,Int64}
const LongDitStr{D,N} = DitStr{D,N,LongLongUInt{C}} where C
const LongDitStr{D,N} = DitStr{D,N,LongLongUInt{C}} where {C}
LongDitStr{D}(vector::AbstractVector{T}) where {D,T} = DitStr{D,longinttype(length(vector), D)}(vector)

Base.show(io::IO, ditstr::DitStr{D,N,<:Integer}) where {D,N} =
print(io, string(buffer(ditstr), base = D, pad = N), "$(''+D)")
print(io, string(buffer(ditstr), base=D, pad=N), "$(''+D)")
Base.show(io::IO, ditstr::DitStr{D,N,<:LongLongUInt}) where {D,N} =
print(io, join(map(string, [ditstr[end:-1:1]...])), "$(''+D)")

Expand Down Expand Up @@ -146,7 +146,7 @@ Read the dit config at given location.
"""
@inline @generated function readat(x::DitStr{D,N,T}, locs::Integer...) where {D,N,T}
length(locs) == 0 && return :(zero($T))
Expr(:call, :+, [:($_lshift($(Val(D)), mod($_rshift($(Val{D}()), buffer(x), locs[$i]-1), $D), $(i - 1))) for i=1:length(locs)]...)
Expr(:call, :+, [:($_lshift($(Val(D)), mod($_rshift($(Val{D}()), buffer(x), locs[$i] - 1), $D), $(i - 1))) for i = 1:length(locs)]...)
end

Base.@propagate_inbounds function Base.getindex(dit::DitStr{D,N}, index::Integer) where {D,N}
Expand All @@ -159,6 +159,128 @@ Base.@propagate_inbounds function Base.getindex(dit::DitStr{D,N,T}, itr::Abstrac
return map(x -> readat(dit, x), itr)
end


"""
SubDitStr{D,N,T<:Integer} <: Integer
The struct as a `SubString`-like object for `DitStr`(`SubString` is an official implementation of sliced strings, see [String](https://docs.julialang.org/en/v1/base/strings/#Base.SubString) for reference). This slicing returns a view into the parent `DitStr` instead of making a copy (similar to the `@views` macro for strings).
`SubDitStr` can be used to describe the qubit configuration within the subspace of the entire Hilbert space.It provides similar `getindex`, `length` functions as `DitStr`.
SubDitStr(dit::DitStr{D,N,T}, i::Int, j::Int)
SubDitStr(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer})
Or by `@views` macro for `DitStr` (this macro makes your life easier by supporting `begin` and `end` syntax):
@views dit[i:j]
Returns a `SubDitStr`.
### Examples
```jldoctest
julia> x = DitStr{3, 5}(71)
02122 ₍₃₎
julia> sx = SubDitStr(x, 2, 4)
SubDitStr{3, 5, Int64}(02122 ₍₃₎, 1, 3)
julia> @views x[2:end]
SubDitStr{3, 5, Int64}(02122 ₍₃₎, 1, 4)
julia> sx == dit"212;3"
true
```
"""
struct SubDitStr{D,N,T<:Integer} <: Integer
dit::DitStr{D,N,T}
offset::Int
ncodeunits::Int

function SubDitStr(dit::DitStr{D,N,T}, i::Int, j::Int) where {D,N,T}
i j || return new{D,N,T}(dit, 0, 0)
@boundscheck begin
1 i length(dit) || throw(BoundsError(dit, i))
1 j length(dit) || throw(BoundsError(dit, i))
end
return new{D,N,T}(dit, i - 1, j - i + 1)
end
end

Base.@propagate_inbounds Base.view(dit::DitStr{D,N,T}, i::Integer, j::Integer) where {D,N,T} = SubDitStr(dit, i, j)
Base.@propagate_inbounds Base.view(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer}) where {D,N,T} = SubDitStr(dit, first(r), last(r))
Base.@propagate_inbounds Base.maybeview(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer}) where {D,N,T} = view(dit,r)

"""
DitStr(dit::SubDitStr{D,N,T}) -> DitStr{D,N,T}
Raise type `SubDitStr` to `DitStr`.
```jldoctest
julia> x = DitStr{3, 5}(71)
02122 ₍₃₎
julia> sx = SubDitStr(x, 2, 4)
SubDitStr{3, 5, Int64}(02122 ₍₃₎, 1, 3)
julia> DitStr(sx)
212 ₍₃₎
```
"""
function DitStr(dit::SubDitStr{D,N,T}) where {D,N,T}
val = zero(T)
D_power_k = one(T)
len = ncodeunits(dit)
for k in 1:len
val = accum(Val{D}(), val, readat(dit.dit, dit.offset + k), D_power_k)
D_power_k = _lshift(Val{D}(), D_power_k, 1)
end
return DitStr{D,len,T}(val)
end

ncodeunits(dit::SubDitStr{D,N,T}) where {D,N,T} = dit.ncodeunits

## bounds checking ##
Base.checkbounds(::Type{Bool}, dit::SubDitStr{D,N,T}, i::Integer) where {D,N,T} =
1 i ncodeunits(dit)
Base.checkbounds(::Type{Bool}, dit::SubDitStr{D,N,T}, r::AbstractRange{<:Integer}) where {D,N,T} =
isempty(r) || (1 minimum(r) && maximum(r) ncodeunits(dit))
Base.checkbounds(::Type{Bool}, dit::SubDitStr{D,N,T}, I::AbstractArray{<:Integer}) where {D,N,T} =
all(i -> checkbounds(Bool, dit, i), I)
Base.checkbounds(dit::SubDitStr{D,N,T}, I::Union{Integer,AbstractArray}) where {D,N,T} = checkbounds(Bool, dit, I) ? nothing : throw(BoundsError(dit, I))

Base.@propagate_inbounds SubDitStr(dit::DitStr{D,N,T}, i::Integer, j::Integer) where {D,N,T} = SubDitStr{D,N,T}(dit, i, j)
Base.@propagate_inbounds SubDitStr(dit::DitStr{D,N,T}, r::AbstractUnitRange{<:Integer}) where {D,N,T} = SubDitStr{D,N,T}(dit, first(r), last(r))

Base.@propagate_inbounds function SubDitStr(dit::SubDitStr{D,N,T}, i::Int, j::Int) where {D,N,T}
@boundscheck i j && checkbounds(dit, i:j)
SubString(dit.dit, dit.offset + i, dit.offset + j)
end

Base.length(dit::SubDitStr{D,N,T}) where {D,N,T} = ncodeunits(dit)

"""
==(lhs::SubDitStr{D,N,T}, rhs::DitStr{D,N,T}) -> Bool
==(lhs::DitStr{D,N,T}, rhs::SubDitStr{D,N,T}) -> Bool
==(lhs::SubDitStr{D,N,T}, rhs::SubDitStr{D,N,T}) -> Bool
Compare the equality between `SubDitStr` and `DitStr`.
"""
function Base.:(==)(lhs::SubDitStr{D,N1}, rhs::DitStr{D,N2}) where {D,N1,N2}
length(lhs) == length(rhs) && @inbounds all(i -> lhs[i] == rhs[i], 1:length(lhs))
end

function Base.:(==)(lhs::SubDitStr{D,N1}, rhs::SubDitStr{D,N2}) where {D,N1,N2}
length(lhs) == length(rhs) && @inbounds all(i -> lhs[i] == rhs[i], 1:length(lhs))
end

function Base.:(==)(lhs::DitStr{D,N1}, rhs::SubDitStr{D,N2}) where {D,N1,N2}
length(lhs) == length(rhs) && @inbounds all(i -> lhs[i] == rhs[i], 1:length(lhs))
end

function Base.getindex(dit::SubDitStr{D,N,T}, i::Integer) where {D,N,T}
@boundscheck checkbounds(dit, i)
@inbounds return getindex(dit.dit, dit.offset + i)
end


# TODO: support AbstractArray, should return its corresponding shape

Base.@propagate_inbounds function Base.getindex(
Expand All @@ -178,7 +300,7 @@ end

Base.eltype(::DitStr{D,N,T}) where {D,N,T} = T

function Base.iterate(dit::DitStr, state::Integer = 1)
function Base.iterate(dit::DitStr, state::Integer=1)
if state > length(dit)
return nothing
else
Expand All @@ -201,8 +323,8 @@ function Base.rand(::Type{T}) where {D,N,Ti,T<:DitStr{D,N,Ti}}
end

######################### Operations #####################
_lshift(::Val{D}, x::Integer, i::Integer) where D = x * (D^i)
_rshift(::Val{D}, x::Integer, i::Integer) where D = x ÷ (D^i)
_lshift(::Val{D}, x::Integer, i::Integer) where {D} = x * (D^i)
_rshift(::Val{D}, x::Integer, i::Integer) where {D} = x ÷ (D^i)
_lshift(::Val{2}, x::Integer, i::Integer) = x << i
_rshift(::Val{2}, x::Integer, i::Integer) = x >> i

Expand Down Expand Up @@ -230,10 +352,10 @@ Base.repeat(s::DitStr, n::Integer) = join([s for i in 1:n]...)
Create an onehot vector in type `Vector{T}` or a batch of onehot vector in type `Matrix{T}`, where index `x + 1` is one.
One can specify the value of the nonzero entry by inputing a pair.
"""
onehot(::Type{T}, n::DitStr{D,N,T1}; nbatch=nothing) where {D,T, N,T1} = _onehot(T, D^N, buffer(n)+1; nbatch)
onehot(::Type{T}, n::DitStr{D,N,T1}; nbatch=nothing) where {D,T,N,T1} = _onehot(T, D^N, buffer(n) + 1; nbatch)
onehot(n::DitStr; nbatch=nothing) = onehot(ComplexF64, n; nbatch)

readbit(x::DitStr{D, N, LongLongUInt{C}}, loc::Int) where {D, N, C} = readbit(x.buf, loc)
readbit(x::DitStr{D,N,LongLongUInt{C}}, loc::Int) where {D,N,C} = readbit(x.buf, loc)

########## @dit_str macro ##############
"""
Expand Down Expand Up @@ -297,20 +419,20 @@ function parse_dit(::Type{T}, str::String) where {T<:Integer}
if res === nothing
error("Input string literal format error, should be e.g. `dit\"01121;3\"`")
end
return _parse_dit(Val(parse(Int,res[2])), T, res[1])
return _parse_dit(Val(parse(Int, res[2])), T, res[1])
end

function _parse_dit(::Val{D}, ::Type{T}, str::AbstractString) where {D, T<:Integer}
function _parse_dit(::Val{D}, ::Type{T}, str::AbstractString) where {D,T<:Integer}
TT = T <: LongLongUInt ? longinttype(count(isdigit, str), D) : T
_parse_dit_safe(Val(D), TT, str)
end

function _parse_dit_safe(::Val{D}, ::Type{T}, str::AbstractString) where {D, T<:Integer}
function _parse_dit_safe(::Val{D}, ::Type{T}, str::AbstractString) where {D,T<:Integer}
val = zero(T)
k = 0
maxk = max_num_elements(T, D)
for each in reverse(str)
k >= maxk-1 && error("string length is larger than $(maxk), use @ldit_str instead")
k >= maxk - 1 && error("string length is larger than $(maxk), use @ldit_str instead")
v = each - '0'
if 0 <= v < D
val += _lshift(Val(D), T(v), k)
Expand All @@ -324,6 +446,6 @@ function _parse_dit_safe(::Val{D}, ::Type{T}, str::AbstractString) where {D, T<:
return DitStr{D,k,T}(val)
end

max_num_elements(::Type{T}, D::Int) where T<:Integer = floor(Int, log(typemax(T))/log(D))
max_num_elements(::Type{T}, D::Int) where {T<:Integer} = floor(Int, log(typemax(T)) / log(D))
max_num_elements(::Type{BigInt}, D::Int) = typemax(Int)
max_num_elements(::Type{LongLongUInt{C}}, D::Int) where {C} = max_num_elements(UInt, D) * C
25 changes: 22 additions & 3 deletions test/DitStr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ using BitBasis, Test
@test x isa DitStr
@test x |> length == 6
println(x)
@test buffer(x) === Int64(3^5 + 3^4 + 2*3^3 + 3^2)
@test buffer(x) === Int64(3^5 + 3^4 + 2 * 3^3 + 3^2)
@test_throws ErrorException randn(100)[x]
@test BitBasis.readat(x, 2) == Int64(0)
@test x[2] == Int64(0)
@test x[3] == Int64(1)
@test x[4] == Int64(2)
@test [x...] == Int64[0,0,1,2,1,1]
@test [DitStr{3}(Int64[0,0,1,2,1,1])...] == Int64[0,0,1,2,1,1]
@test [x...] == Int64[0, 0, 1, 2, 1, 1]
@test [DitStr{3}(Int64[0, 0, 1, 2, 1, 1])...] == Int64[0, 0, 1, 2, 1, 1]
@test_throws ErrorException BitBasis.parse_dit(Int64, "112103;3")
@test_throws ErrorException BitBasis.parse_dit(Int64, "112101;")

Expand All @@ -24,4 +24,23 @@ using BitBasis, Test
@test_throws ErrorException BitBasis.parse_dit(Int64, "12341111111111111111111111111111111111111111111111111111111;5")

@test hash(x) isa UInt64
end

@testset "SubDitStr" begin
x = dit"112100;3"
sx = SubDitStr(x, 2, 4) # bit"210"
@test_throws BoundsError SubDitStr(x, 2, 7)
@test checkbounds(sx, 1) == nothing
@test getindex(sx, 1) == 0
@test getindex(sx, 2) == 1
@test getindex(sx, 3) == 2
@test_throws BoundsError getindex(sx, 4)
@test_throws BoundsError getindex(sx, 0)
@test length(sx) == 3
@test sx == dit"210;3"
@test dit"210;3" == sx
@test DitStr(sx) == dit"210;3"
@test SubDitStr(dit"210;3",1,length(dit"210;3")) == sx
@test (@views x[4:end]) == dit"112;3"
@test (@views x[begin:3]) == dit"100;3"
end

0 comments on commit 4c1e25a

Please sign in to comment.