diff --git a/src/nodes.jl b/src/nodes.jl index ee1633c5..481339e9 100644 --- a/src/nodes.jl +++ b/src/nodes.jl @@ -58,6 +58,7 @@ function NodeSet(nodes::AbstractVector{RealT}) where {RealT} @assert length(nodes) > 0 return NodeSet([[node] for node in nodes]) end +NodeSet(nodeset::NodeSet) = nodeset """ empty_nodeset(Dim, RealT = Float64) @@ -119,9 +120,11 @@ end function Base.similar(nodeset::NodeSet{Dim, RealT}, ::Type{T}, n::Int) where {Dim, RealT, T} NodeSet{Dim, T}(similar(nodeset.nodes, MVector{Dim, T}, n), Inf) end -Base.getindex(nodeset::NodeSet, i::Int) = getindex(nodeset.nodes, i) -Base.getindex(nodeset::NodeSet, is::UnitRange) = nodeset.nodes[is] +Base.getindex(nodeset::NodeSet, i::Int) = nodeset.nodes[i] +Base.getindex(nodeset::NodeSet, is::AbstractVector) = nodeset.nodes[is] +Base.firstindex(nodeset::NodeSet) = firstindex(nodeset.nodes) Base.lastindex(nodeset::NodeSet) = lastindex(nodeset.nodes) +Base.keys(nodeset::NodeSet) = keys(nodeset.nodes) function Base.setindex!(nodeset::NodeSet{Dim, RealT}, v::MVector{Dim, RealT}, i::Int) where {Dim, RealT} nodeset.nodes[i] = v @@ -190,11 +193,10 @@ Compute the distance matrix between two [`NodeSet`](@ref)s, which is a matrix `` `nodeset1` and ``\xi_j`` are the nodes on `nodeset2`. """ function distance_matrix(nodeset1::NodeSet, nodeset2::NodeSet) - n1 = length(nodeset1) - n2 = length(nodeset2) - D = zeros(n1, n2) - for i in 1:n1 - for j in 1:n2 + n1, n2 = length(nodeset1), length(nodeset2) + D = zeros(eltype(nodeset1), n1, n2) + for i in eachindex(nodeset1) + for j in eachindex(nodeset2) D[i, j] = norm(nodeset1[i] - nodeset2[j]) end end diff --git a/test/test_unit.jl b/test/test_unit.jl index 658ccbb9..b9c3d9b9 100644 --- a/test/test_unit.jl +++ b/test/test_unit.jl @@ -197,6 +197,7 @@ end 1.0 0.0 0.0 1.0 1.0 1.0]) + @test NodeSet(nodeset1) == nodeset1 @test_nowarn println(nodeset1) @test_nowarn display(nodeset1) @test eltype(nodeset1) == Float64 @@ -205,7 +206,8 @@ end @test length(nodeset1) == 4 @test size(nodeset1) == (4, 2) @test axes(nodeset1) == (1:4,) - @test eachindex(nodeset1) == 1:4 + @test eachindex(nodeset1) == firstindex(nodeset1):lastindex(nodeset1) + @test keys(nodeset1) == 1:4 for node in nodeset1 @test node isa MVector{2, Float64} end