Skip to content

Commit

Permalink
Changes for SDID implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
nilshg committed Dec 3, 2021
1 parent 8e417b1 commit ddd2ad7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 8 deletions.
83 changes: 76 additions & 7 deletions src/TreatmentPanel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ The following table provides an overview of the types of treatment pattern suppo
| **multiple units** | Vector{Pair{String, Date}} | Vector{Pair{String, Tuple{Date, Date}}} | Vector{Pair{String}, Vector{Tuple{Date, Date}}}} |
"""
@with_kw struct BalancedPanel{UTType, TDType} <: TreatmentPanel where UTType <: UnitTreatmentType where TDType <: TreatmentDurationType
N::Int64
T::Int64
W::Union{Matrix{Bool}, Matrix{Union{Missing, Bool}}}
Y::Matrix{Float64}
df::DataFrame
id_var::Union{String, Symbol}
t_var::Union{String, Symbol}
outcome_var::Union{String, Symbol}
ts::Vector{T1} where T1 <: Union{Date, Int64}
is::Vector{T2} where T2 <: Union{Symbol, String, Int64}
Y::Matrix{Float64}
end

# Check that ID, time, and outcome variable are provided
Expand Down Expand Up @@ -158,56 +160,123 @@ function BalancedPanel(df::DataFrame, treatment_assignment::Pair{T1, T2};
end
end

BalancedPanel{SingleUnitTreatment, ContinuousTreatment}(N, T, W, ts, is, Y)
BalancedPanel{SingleUnitTreatment, ContinuousTreatment}(W, Y, df, id_var, t_var, outcome_var, ts, is)
end

# Getter functions
"""
treated_ids(x <: BalancedPanel)
Returns the indices of treated units in the panel, so that Y[treated_ids(x), :] returns a
(Nₜᵣ×T) matrix of outcomes for treated units in all periods.
"""
function treated_ids(x::BalancedPanel{SingleUnitTreatment, T2}) where T2
for i 1:x.N
for t 1:x.T
for i 1:size(x.Y, 1)
for t 1:size(x.Y, 2)
if x.W[i, t]
return i
end
end
end
end

"""
treated_labels(x <: BalancedPanel)
Returns the labels of treated units as given by the `id_var` column in the underlying data set.
"""
function treated_labels(x::BalancedPanel{SingleUnitTreatment, T2}) where T2
x.is[treated_ids(x)]
end

"""
first_treated_period_ids(x <: BalancedPanel)
Returns the indices of the first treated period for each treated units, that is, a Vector{Int}
of length Nₜᵣ, where each element is the index of the first 1 in the row of treatment matrix W
corresonding to the treatment unit.
"""
function first_treated_period_ids(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
findfirst(vec(x.W[treated_ids(x), :]))
end

"""
first_treated_period_labels(x <: BalancedPanel)
Returns the labels of the first treated period for each treated units, that is, a Vector{T}
of length Nₜᵣ, where T is the eltype of the `t_var` column in the underlying data.
"""
function first_treated_period_labels(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
x.ts[first_treated_period_ids(x)]
end

"""
length_T₀(x <: BalancedPanel)
Returns the number of pre-treatment periods for each treated unit.
"""
function length_T₀(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
first_treated_period_ids(x) - 1
end

"""
length_T₁(x <: BalancedPanel)
Returns the number of treatment periods for each treated unit.
"""
function length_T₁(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
x.T .- first_treated_period_ids(x) + 1
size(x.Y, 2) .- first_treated_period_ids(x) + 1
end


"""
get_y₁₀(x <: BalancedPanel)
Returns the pre-treatment outcomes for the treated unit(s). For SingleUnitTreatment designs,
this is a vector of length T₀, while for MultiUnitTreatment designs, it is a (Nₜᵣ×T₀) matrix
"""
function get_y₁₀(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
x.Y[treated_ids(x), 1:first_treated_period_ids(x)-1]
end

"""
get_y₁₁(x <: BalancedPanel)
Returns the post-treatment outcomes for the treated unit(s). For SingleUnitTreatment designs,
this is a vector of length T₁, while for MultiUnitTreatment designs, it is a (Nₜᵣ×T₁) matrix
"""
function get_y₁₁(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
x.Y[x.W]
end

"""
get_y₀₀(x <: BalancedPanel)
Returns the pre-treatment outcomes for the untreated units, an (Nₖₒ×T₀) matrix
"""
function get_y₀₀(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
x.Y[Not(treated_ids(x)), 1:first_treated_period_ids(x)-1]
end

"""
get_y₀₀(x <: BalancedPanel)
Returns the post-treatment outcomes for the untreated units, an (Nₖₒ×T₁) matrix
"""
function get_y₀₁(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
x.Y[Not(treated_ids(x)), first_treated_period_ids(x):end]
end

"""
get_y₀₀(x <: BalancedPanel)
Decomposes the outcome matrix Y into four elements:
* Pre-treatment outcomes for treated units (y₁₀)
* Post-treatment outcomes for treated units (y₁₀)
* Pre-treatment outcomes for control units (y₀₀)
* Post-treatment outcomes for treated units (y₀₁)
"""
function decompose_y(x)
get_y₁₀(x), get_y₁₁(x), get_y₀₀(x), get_y₀₁(x)
end
Expand Down
2 changes: 1 addition & 1 deletion src/show_and_plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using RecipesBase
function Base.show(io::IO, mime::MIME"text/plain", x::BalancedPanel{SingleUnitTreatment, ContinuousTreatment})
println("Balanced Panel - single unit, single continuous treatment")
println(" Treated unit: $(treated_labels(x))")
println(" Number of untreated units: $(x.N - 1)")
println(" Number of untreated units: $(size(x.Y, 1) - 1)")
println(" First treatment period: $(first_treated_period_labels(x))")
println(" Number of pretreatment periods: $(length_T₀(x))")
println(" Number of treatment periods: $(length_T₁(x))")
Expand Down

0 comments on commit ddd2ad7

Please sign in to comment.