Skip to content

Commit

Permalink
Bugfixes and changes (#3)
Browse files Browse the repository at this point in the history
Added some missing functions and fixed bugs so that JuliaFEM.jl tests
pass now. Minor changes to `fields.jl` break backward compatibility, changing new minor version.
  • Loading branch information
ahojukka5 authored Dec 27, 2017
1 parent 3a0f9f1 commit 3ab1f97
Show file tree
Hide file tree
Showing 7 changed files with 138 additions and 148 deletions.
11 changes: 11 additions & 0 deletions src/FEMBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,16 @@ export update!, add_elements!, add!, get_gdofs, group_by_element_type,
get_unknown_field_name, get_unknown_field_dimension,
get_integration_points, initialize!, assemble!
export DCTI, DVTI, DCTV, DVTV, CCTI, CVTI, CCTV, CVTV
export add_elements!
export SparseMatrixCOO, SparseVectorCOO, Node, BasisInfo,
IP, AbstractProblem, IntegrationPoint, AbstractField
export is_field_problem, is_boundary_problem, get_elements,
get_connectivity, assemble_prehook!, assemble_posthook!,
get_parent_field_name, get_reference_coordinates,
get_assembly, get_nonzero_rows, get_nonzero_columns,
eval_basis!, get_basis, get_dbasis, grad!,
assemble_mass_matrix!, get_local_coordinates, inside,
get_element_type, filter_by_element_type, get_element_id,
optimize!, resize_sparse, resize_sparsevec

end
121 changes: 10 additions & 111 deletions src/elements.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function (element::Element)(field_name::String, time::Float64)
field = element[field_name]
result = interpolate(field, time)
if isa(result, Dict)
return tuple((field[i] for i in get_connectivity(element))...)
return tuple((result[i] for i in get_connectivity(element))...)
else
return result
end
Expand Down Expand Up @@ -179,8 +179,8 @@ function (element::Element)(ip, time::Float64, ::Type{Val{:Grad}})
end

function (element::Element)(field_name::String, ip, time::Float64, ::Type{Val{:Grad}})
X = interpolate(element["geometry"], time)
u = interpolate(element[field_name], time)
X = element("geometry", time)
u = element(field_name, time)
return grad(element.properties, u, X, ip)
end

Expand Down Expand Up @@ -221,118 +221,17 @@ function size(element::Element, dim)
return size(element)[dim]
end

#=
""" Update element field based on a dictionary of nodal data and connectivity information.
Examples
--------
julia> data = Dict(1 => [0.0, 0.0], 2 => [1.0, 2.0])
julia> element = Seg2([1, 2])
julia> update!(element, "geometry", data)
As a result element now have time invariant (variable) vector field "geometry" with data ([0.0, 0.0], [1.0, 2.0]).
"""
function update!{E}(element::Element{E}, field_name::AbstractString, data::Dict)
#element[field_name] = Field(data)
element_id = element.id
local_connectivity = get_connectivity(element)
for i in local_connectivity
if !haskey(data, i)
ndata = length(data)
critical("Unable to set field data $field_name for element $E with
id $element_id and connectivity $local_connectivity: no data for
node id $i found. Length of data dict = $ndata")
end
end
local_data = [data[i] for i in local_connectivity]
element[field_name] = local_data
end
function update!{K,V}(element::Element, field_name, data::Pair{Float64, Dict{K, V}})
time, field_data = data
element_data = V[field_data[i] for i in get_connectivity(element)]
update!(element, field_name, time => element_data)
end
function update!(element::Element, field_name::AbstractString, data::Pair{Float64, Vector{Any}})
if haskey(element, field_name)
update!(element[field_name], data)
else
element[field_name] = data
end
end
function update!(element::Element, field_name::AbstractString, data::Pair{Float64, Vector{Int64}})
if haskey(element, field_name)
update!(element[field_name], data)
else
element[field_name] = data
end
end
function update!(element::Element, field_name::AbstractString, data::Pair{Float64, Vector{Float64}})
if haskey(element, field_name)
update!(element[field_name], data)
else
element[field_name] = data
end
end
function update!(element::Element, field_name::AbstractString, data::Pair{Float64, Vector{Vector{Float64}}})
if haskey(element, field_name)
update!(element[field_name], data)
else
element[field_name] = data
end
end
function update!(element::Element, field_name::AbstractString, data::Pair{Float64, Float64})
if haskey(element, field_name)
update!(element[field_name], data)
else
element[field_name] = data
end
end
function update!(element::Element, field_name::AbstractString, data::Union{Float64, Vector})
if haskey(element, field_name)
update!(element[field_name], data)
function update!(element::Element, field_name, data)
if haskey(element.fields, field_name)
update!(element.fields[field_name], data)
else
if length(data) != length(element)
update!(element, field_name, DCTI(data))
else
element[field_name] = data
end
end
end
function update!(element::Element, datas::Pair...)
for (field_name, data) in datas
if haskey(element, field_name)
update!(element[field_name], data)
else
element[field_name] = data
end
element.fields[field_name] = field(data)
end
end

function update!(element::Element, field_name::String, data::Function)
element[field_name] = data
end
function update!{F<:AbstractField}(element::Element, field_name::String, field::F)
element[field_name] = field
end
function update!{F<:AbstractField}(element::Element, field_name::String, data...)
update!(element.fields[field_name], data...)
end
=#

function update!(element::Element, field_name, data)
if haskey(element.fields, field_name)
update!(element.fields[field_name], data)
function update!(element::Element, field_name, data::Function)
if method_exists(data, Tuple{Element, Any, Any})
element.fields[field_name] = field((ip, time) -> data(element, ip, time))
else
element.fields[field_name] = field(data)
end
Expand Down
28 changes: 26 additions & 2 deletions src/fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,17 @@ function interpolate{N,T}(f::DVTI{N,T}, t)
return f.data
end

"""
interpolate(a, b)
A helper function for interpolate routines. Given iterables `a` and `b`,
calculate c = aᵢbᵢ. Length of `a` can be less than `b`, but not vice versa.
"""
function interpolate(a, b)
@assert length(a) <= length(b)
return sum(a[i]*b[i] for i=1:length(a))
end

"""
interpolate(f::DVTI, t, B)
Expand Down Expand Up @@ -263,8 +274,8 @@ type DVTVd{T} <: AbstractField
data :: Vector{Pair{Float64,Dict{Int64,T}}}
end

function DVTVd{T}(a::Pair{Float64,Dict{Int64,T}}, b::Pair{Float64,Dict{Int64,T}})
return DVTVd([a, b])
function DVTVd{T}(data::Pair{Float64,Dict{Int64,T}}...)
return DVTVd(collect(data))
end

"""
Expand Down Expand Up @@ -300,6 +311,19 @@ function interpolate{T}(field::DVTVd{T}, time)
end
end

"""
update!(f::DCTVd, time => data)
Update new value to dictionary field.
"""
function update!{T}(f::DVTVd, data::Pair{Float64,Dict{Int64,T}})
if isapprox(last(f.data).first, data.first)
f.data[end] = data
else
push!(f.data, data)
end
end

"""
field(x)
Expand Down
75 changes: 41 additions & 34 deletions src/problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,37 @@ function get_formulation_type(problem::Problem)
return :incremental
end

function get_unknown_field_name{P<:BoundaryProblem}(::Type{P})
"""
get_unknown_field_dimension(problem)
Return the dimension of the unknown field of this problem.
"""
function get_unknown_field_dimension(problem::Problem)
return problem.dimension
end

"""
get_unknown_field_name(problem)
Default function if unknown field name is not defined for some problem.
"""
function get_unknown_field_name{P<:AbstractProblem}(::P)
warn("The name of unknown field (e.g. displacement, temperature, ...) of the ",
"problem type must be given by defining function `get_unknown_field_name`")
return "N/A"
end

""" Return the name of the unknown field of this problem. """
function get_unknown_field_name{P}(problem::Problem{P})
return get_unknown_field_name(problem.properties)
end

""" Return the name of the parent field of this (boundary) problem. """
function get_parent_field_name{P<:BoundaryProblem}(problem::Problem{P})
return problem.parent_field_name
end

function get_unknown_field_name{P<:BoundaryProblem}(::P)
return "lambda"
end

Expand Down Expand Up @@ -188,25 +218,22 @@ function initialize!(problem::Problem, element::Element, time::Float64)
field_name = get_unknown_field_name(problem)
field_dim = get_unknown_field_dimension(problem)
nnodes = length(element)
if field_dim == 1 # scalar field
empty_field = tuple(zeros(nnodes)...)
else # vector field
empty_field = tuple([zeros(field_dim) for i=1:nnodes]...)
end

# initialize primary field
if !haskey(element, field_name)
if field_dim == 1
update!(element, field_name, time => zeros(nnodes))
else
update!(element, field_name, time => [zeros(field_dim) for i=1:nnodes])
end
update!(element, field_name, time => empty_field)
end

# if a boundary problem, initialize also a field for the main problem
is_boundary_problem(problem) || return
field_name = get_parent_field_name(problem)
if !haskey(element, field_name)
if field_dim == 1
update!(element, field_name, time => zeros(nnodes))
else
update!(element, field_name, time => [zeros(field_dim) for i=1:nnodes])
end
update!(element, field_name, time => empty_field)
end
end

Expand Down Expand Up @@ -302,7 +329,7 @@ function update!{P<:FieldProblem}(problem::Problem{P}, assembly::Assembly, eleme
# update solution u for elements
for element in elements
connectivity = get_connectivity(element)
update!(element, field_name, time => u[connectivity])
update!(element, field_name, time => tuple(u[connectivity]...))
end
end

Expand All @@ -313,8 +340,8 @@ function update!{P<:BoundaryProblem}(problem::Problem{P}, assembly::Assembly, el
# update solution and lagrange multipliers for boundary elements
for element in elements
connectivity = get_connectivity(element)
update!(element, parent_field_name, time => u[connectivity])
update!(element, field_name, time => la[connectivity])
update!(element, parent_field_name, time => tuple(u[connectivity]...))
update!(element, field_name, time => tuple(la[connectivity]...))
end
end

Expand Down Expand Up @@ -386,26 +413,6 @@ function (problem::Problem)(field_name::String, time::Float64)
return f
end

""" Return the dimension of the unknown field of this problem. """
function get_unknown_field_dimension(problem::Problem)
return problem.dimension
end

function get_unknown_field_name{P<:AbstractProblem}(::Type{P})
warn("The name of unknown field (e.g. displacement, temperature, ...) of the problem type must be given by defining function `get_unknown_field_name`")
return "N/A"
end

""" Return the name of the unknown field of this problem. """
function get_unknown_field_name{P}(problem::Problem{P})
return get_unknown_field_name(P)
end

""" Return the name of the parent field of this (boundary) problem. """
function get_parent_field_name{P<:BoundaryProblem}(problem::Problem{P})
return problem.parent_field_name
end

function push!(problem::Problem, elements...)
push!(problem.elements, elements...)
end
Expand Down
12 changes: 12 additions & 0 deletions test/test_elements.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,15 @@ end
update!(el, "geometry", X)
@test_throws Exception get_local_coordinates(el, [0.1, 0.1], 0.0; max_iterations=0)
end

@testset "analytical functions as fields" begin
f(xi, time) = xi[1]*time
g(element, xi, time) = element("geometry", xi, time)*time
X = Dict(1 => [0.0, 0.0], 2 => [1.0, 0.0], 3 => [1.0, 1.0], 4 => [0.0, 1.0])
el = Element(Quad4, [1, 2, 3, 4])
update!(el, "geometry", X)
update!(el, "f", f)
update!(el, "g", g)
@test isapprox(el("f", (0.5, 0.5), 2.0), 1.0)
@test isapprox(el("g", (0.0, 0.0), 2.0), [1.0, 1.0])
end
20 changes: 20 additions & 0 deletions test/test_fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ end
@test isapprox(interpolate(F, 0.5)[100000], [2.0,2.0])
end

@testset "update dictionary field" begin
f1 = Dict(1=>1.0, 2=>2.0, 3=>3.0)
f2 = Dict(1=>2.0, 2=>3.0, 3=>4.0)
fld = DVTVd(0.0 => f1)
update!(fld, 1.0 => f2)
@test isapprox(interpolate(fld, 0.5)[1], 1.5)
update!(fld, 1.0 => f1)
@test isapprox(interpolate(fld, 0.5)[1], 1.0)
end

@testset "use of common constructor field" begin
@test isa(field(1.0), DCTI)
@test isa(field(1.0 => 1.0), DCTV)
Expand All @@ -130,3 +140,13 @@ end
@test isa(X1, DVTId)
@test isa(X2, DVTVd)
end

@testset "general interpolation" begin
a = [1, 2, 3]
b = (2, 3, 4)
@test interpolate(a, b) == 2+6+12
a = (1, 2)
b = (2, 3, 4)
@test interpolate(a, b) == 2+6
@test_throws AssertionError interpolate(b, a)
end
Loading

0 comments on commit 3ab1f97

Please sign in to comment.