Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jan 28, 2025
1 parent 588fa61 commit 95e454b
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 76 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
run: |
using Pkg
Pkg.add([
PackageSpec(name="StarAlgebras", rev="mk/quadratic_form"),
PackageSpec(name="StarAlgebras", rev="bl/quad_form"),
PackageSpec(name="SymbolicWedderburn", rev="master"),
PackageSpec(name="MultivariateBases", rev="master"),
PackageSpec(name="MultivariateMoments", rev="master"),
Expand Down
6 changes: 5 additions & 1 deletion src/Bridges/Variable/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ function MOI.Bridges.Variable.bridge_constrained_variable(
gram, vars, con = SOS.add_gram_matrix(model, M, gram_basis, T)
push!(variables, vars)
push!(constraints, con)
MA.operate!(SA.UnsafeAddMul(*), acc, gram, weight)
if isone(weight)
MA.operate!(SA.UnsafeAdd(), acc, SA.QuadraticForm(gram))
else
MA.operate!(SA.UnsafeAddMul(*), acc, gram, weight)
end
end
MA.operate!(SA.canonical, SA.coeffs(acc))
return KernelBridge{T,M}(
Expand Down
20 changes: 14 additions & 6 deletions src/Certificate/ideal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,20 @@ function _combine_with_gram(
)
end
for (gram, weight) in zip(gram_bases, weights)
MA.operate!(
SA.UnsafeAddMul(*),
p,
GramMatrix{_NonZero}((_, _) -> _NonZero(), gram),
weight,
)
if isone(weight)
MA.operate!(
SA.UnsafeAdd(),
p,
SA.QuadraticForm(GramMatrix{_NonZero}((_, _) -> _NonZero(), gram)),
)
else
MA.operate!(
SA.UnsafeAddMul(*),
p,
GramMatrix{_NonZero}((_, _) -> _NonZero(), gram),
weight,
)
end
end
MA.operate!(SA.canonical, SA.coeffs(p))
return MB.SubBasis{B}(keys(SA.coeffs(p)))
Expand Down
123 changes: 61 additions & 62 deletions src/Certificate/newton_polytope.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,8 @@ Base.iszero(::SignChange) = false
MA.scaling_convert(::Type, s::SignChange) = s
Base.:*(s::SignChange, α::Real) = SignChange(s.sign * α, s.Δ)
Base.:*::Real, s::SignChange) = SignChange* s.sign, s.Δ)
#Base.convert(::Type{SignChange{T}}, s::SignChange) where {T} = SignChange{T}(s.sign, s.Δ)
#Base.:+(a::SignChange, b::SignChange) = convert(SignCount, a) + b

struct SignCount
unknown::Int
Expand All @@ -593,6 +595,18 @@ function _sign(c::SignCount)
end
end

function Base.:*(α, a::SignCount)
if α > 0
return a
elseif α < 0
return SignCount(a.unknown, a.negative, a.positive)
else
error("Cannot multiply `SignCount`` with ``")
end
end

Base.:*(a::SignCount, α) = α * a

function Base.:+(a::SignCount, b::SignCount)
return SignCount(
a.unknown + b.unknown,
Expand All @@ -602,16 +616,16 @@ function Base.:+(a::SignCount, b::SignCount)
end

function Base.:+(c::SignCount, a::SignChange{Missing})
@assert c.unknown >= -a.Δ
#@assert c.unknown >= -a.Δ
return SignCount(c.unknown + a.Δ, c.positive, c.negative)
end

function Base.:+(c::SignCount, a::SignChange{<:Number})
if a.sign > 0
@assert c.positive >= -a.Δ
#@assert c.positive >= -a.Δ
return SignCount(c.unknown, c.positive + a.Δ, c.negative)
elseif a.sign < 0
@assert c.negative >= -a.Δ
#@assert c.negative >= -a.Δ
return SignCount(c.unknown, c.positive, c.negative + a.Δ)
elseif iszero(a.sign)
error(
Expand All @@ -624,26 +638,16 @@ end

Base.convert(::Type{SignCount}, Δ::SignChange) = SignCount() + Δ

function increase(cache, counter, generator_sign, monos, mult)
for a in monos
for b in monos
MA.operate_to!(
cache,
*,
MB.algebra_element(mult),
MB.algebra_element(a),
MB.algebra_element(b),
)
MA.operate!(
SA.UnsafeAddMul(*),
counter,
_term_constant_monomial(
SignChange((a != b) ? missing : generator_sign, 1),
mult,
),
cache,
)
end
struct SignGram{T,B}
sign::T
basis::B
end
SA.basis(g::SignGram) = g.basis
function Base.getindex(g::SignGram, i, j)
if i == j
return SignChange(g.sign, 1)
else
return SignChange(missing, 2)
end
end

Expand Down Expand Up @@ -708,7 +712,8 @@ function post_filter(
_DictCoefficients(Dict{MP.monomial_type(typeof(poly)),SignCount}()),
MB.implicit_basis(SA.basis(poly)),
)
cache = zero(Float64, MB.algebra(MB.implicit_basis(SA.basis(poly))))
cache = zero(SignCount, MB.algebra(MB.implicit_basis(SA.basis(poly))))
cache2 = zero(SignCount, MB.algebra(MB.implicit_basis(SA.basis(poly))))
for (mono, v) in SA.nonzero_pairs(SA.coeffs(poly))
MA.operate!(
SA.UnsafeAdd(),
Expand All @@ -717,29 +722,21 @@ function post_filter(
)
end
for (mult, gram_monos) in zip(generators, multipliers_gram_monos)
for (mono, v) in SA.nonzero_pairs(SA.coeffs(mult))
increase(
cache,
counter,
-_sign(v),
gram_monos,
SA.basis(mult)[mono],
)
end
MA.operate_to!(cache, copy, SA.QuadraticForm(SignGram(-1, gram_monos)))
MA.operate!(SA.UnsafeAddMul(*), counter, mult, cache)
end
function decrease(sign, a, b, c)
function decrease(sign, a, b, generator)
MA.operate_to!(
cache,
*,
MB.algebra_element(a),
_term(SignChange(1, -1), a),
MB.algebra_element(b),
MB.algebra_element(c),
)
MA.operate!(
SA.UnsafeAddMul(*),
counter,
_term_constant_monomial(SignChange(sign, -1), a),
cache,
generator,
)
for mono in SA.supp(cache)
count = SA.coeffs(counter)[SA.basis(counter)[mono]]
Expand All @@ -765,36 +762,38 @@ function post_filter(
end
keep[i][j] = false
a = multipliers_gram_monos[i][j]
for (k, v) in SA.nonzero_pairs(SA.coeffs(generators[i]))
mono = SA.basis(generators[i])[k]
sign = -_sign(v)
decrease(sign, mono, a, a)
for (j, b) in enumerate(multipliers_gram_monos[i])
if keep[i][j]
decrease(missing, mono, a, b)
decrease(missing, mono, b, a)
end
decrease(-1, a, a, generators[i])
for (k, b) in enumerate(multipliers_gram_monos[i])
if keep[i][k]
decrease(missing, a, b, generators[i])
decrease(missing, b, a, generators[i])
end
end
end
for i in eachindex(generators)
for k in SA.supp(generators[i])
for (j, mono) in enumerate(multipliers_gram_monos[i])
MA.operate_to!(
cache,
*,
MB.algebra_element(k),
MB.algebra_element(mono),
MB.algebra_element(mono),
for (j, mono) in enumerate(multipliers_gram_monos[i])
MA.operate_to!(
cache,
*,
# Dummy coef to help convert to `SignCount` which is the `eltype` of `cache`
_term(SignChange(1, 1), mono),
MB.algebra_element(mono),
)
# The `eltype` of `cache` is `SignCount`
# so there is no risk of term cancellation
MA.operate_to!(
cache2,
*,
cache,
generators[i],
)
for w in SA.supp(cache)
if ismissing(
_sign(SA.coeffs(counter)[SA.basis(counter)[w]]),
)
for w in SA.supp(cache)
if ismissing(
_sign(SA.coeffs(counter)[SA.basis(counter)[w]]),
)
push!(get!(back, w, Tuple{Int,Int}[]), (i, j))
else
delete(i, j)
end
push!(get!(back, w, Tuple{Int,Int}[]), (i, j))
else
delete(i, j)
end
end
end
Expand Down
22 changes: 16 additions & 6 deletions src/gram_matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,21 @@ end
# convert(PT, MP.polynomial(p))
#end

function MB.algebra_element(
p::Union{GramMatrix{T,B,U},BlockDiagonalGramMatrix{T,B,U}},
) where {T,B,U}
return MB.algebra_element(p, U)
end

function MB.algebra_element(
g::Union{GramMatrix,BlockDiagonalGramMatrix},
::Type{T},
) where {T}
a = zero(T, MB.algebra(MB.implicit_basis(g)))
MA.operate_to!(a, copy, SA.QuadraticForm(g))
return a
end

function MP.polynomial(
p::Union{GramMatrix{T,B,U},BlockDiagonalGramMatrix{T,B,U}},
) where {T,B,U}
Expand All @@ -318,10 +333,5 @@ function MP.polynomial(
g::Union{GramMatrix,BlockDiagonalGramMatrix},
::Type{T},
) where {T}
p = zero(T, MB.algebra(MB.implicit_basis(g)))
MA.operate!(SA.UnsafeAddMul(*), p, g)
MA.operate!(SA.canonical, SA.coeffs(p))
return MP.polynomial(
SA.coeffs(p, MB.FullBasis{MB.Monomial,MP.monomial_type(g)}()),
)
return MP.polynomial(MB.algebra_element(g, T))
end

0 comments on commit 95e454b

Please sign in to comment.