Skip to content

Commit

Permalink
Update type hierarchy - new version
Browse files Browse the repository at this point in the history
  • Loading branch information
nilshg committed Dec 10, 2021
1 parent ddd2ad7 commit 996e16d
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 117 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TreatmentPanels"
uuid = "7885c543-3ac4-48a3-abed-7a36d7ddb69f"
authors = ["Nils <[email protected]> and contributors"]
version = "0.2.0"
version = "0.3.0"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Expand Down
149 changes: 52 additions & 97 deletions src/TreatmentPanel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,25 @@ using DataFrames, Dates, Parameters
abstract type TreatmentPanel end

# Types for number of treatment units and periods
abstract type UnitTreatmentType end
# Treatment duration - this is either continuous or discontinuous
abstract type TreatmentDurationType end
struct SingleUnitTreatment <: UnitTreatmentType end
struct MultiUnitSimultaneousTreatment <: UnitTreatmentType end
struct MultiUnitStaggeredTreatment <: UnitTreatmentType end
struct ContinuousTreatment <: TreatmentDurationType end
struct StartEndTreatment <: TreatmentDurationType end
struct Continuous <: TreatmentDurationType end
struct Discontinuous <: TreatmentDurationType end

# Treatment timing - relevant only for MultiUnitTreatments, can be simultaneous or staggered
# Both timing types carry the treatment duration type as their parameter, e.g.
# Simultaneous{Continuous} or Staggered{Discontinuous}
abstract type TreatmentTimingType end
struct Simultaneous{T <: TreatmentDurationType} <: TreatmentTimingType end
struct Staggered{T <: TreatmentDurationType} <: TreatmentTimingType end

# Unit type - either single or multiple treated units
# SingleUnitTreatment is parameterised directly by the duration type, while
# MultiUnitTreatment is parameterised by a timing type (which in turn wraps the duration
# type), e.g. MultiUnitTreatment{Staggered{Continuous}}
abstract type TreatmentType end
struct SingleUnitTreatment{T <: TreatmentDurationType} <: TreatmentType end
struct MultiUnitTreatment{T <: TreatmentTimingType} <: TreatmentType end


# BalancedPanel will have an N×T matrix of treatment assignment and outcomes
"""
BalancedPanel{UnitTreatmentType, TreatmentDurationType}
BalancedPanel{TreatmentType}
A TreatmentPanel in which all N treatment units are observed for the same T periods.
Expand All @@ -36,7 +44,7 @@ The following table provides an overview of the types of treatment pattern suppo
| **one unit** | Pair{String, Date} | Pair{String, Tuple{Date, Date}} | Pair{String, Vector{Tuple{Date, Date}}} |
| **multiple units** | Vector{Pair{String, Date}} | Vector{Pair{String, Tuple{Date, Date}}} | Vector{Pair{String, Vector{Tuple{Date, Date}}}} |
"""
@with_kw struct BalancedPanel{UTType, TDType} <: TreatmentPanel where UTType <: UnitTreatmentType where TDType <: TreatmentDurationType
@with_kw struct BalancedPanel{UTType} <: TreatmentPanel where UTType <: TreatmentType
W::Union{Matrix{Bool}, Matrix{Union{Missing, Bool}}}
Y::Matrix{Float64}
df::DataFrame
Expand Down Expand Up @@ -120,7 +128,7 @@ function construct_W(tas::Vector{Pair{T1, S1}}, N, T, is, ts) where T1 where S1
return W
end

# Constructor for single continuous treatment - returns BalancedPanel{SingleUnitTreatment, ContinuousTreatment}
# Constructor for single continuous treatment - returns BalancedPanel{SingleUnitTreatment{Continuous}}
function BalancedPanel(df::DataFrame, treatment_assignment::Pair{T1, T2};
id_var = nothing, t_var = nothing, outcome_var = nothing,
sort_inplace = false) where T1 where T2 <: Union{Date, Int}
Expand Down Expand Up @@ -151,16 +159,12 @@ function BalancedPanel(df::DataFrame, treatment_assignment::Pair{T1, T2};
W = construct_W(treatment_assignment, N, T, is, ts)

# Outcome matrix
Y = zeros(size(W))
Y = zeros(eltype(df[!, outcome_var]), size(W))
for (row, i) enumerate(is), (col, t) enumerate(ts)
try
Y[row, col] = only(df[(df[!, id_var] .== i) .& (df[!, t_var] .== t), outcome_var])
catch ArgumentError
throw("$(nrow(df[(df[!, id_var] .== i) .& (df[!, t_var] .== t), :])) outcomes present in the data for unit $i in period $t")
end
Y[row, col] = only(df[(df[!, id_var] .== i) .& (df[!, t_var] .== t), outcome_var])
end

BalancedPanel{SingleUnitTreatment, ContinuousTreatment}(W, Y, df, id_var, t_var, outcome_var, ts, is)
BalancedPanel{SingleUnitTreatment{Continuous}}(W, Y, df, id_var, t_var, outcome_var, ts, is)
end

# Getter functions
Expand All @@ -170,7 +174,7 @@ end
Returns the indices of treated units in the panel, so that Y[treated_ids(x), :] returns a
(Nₜᵣ×T) matrix of outcomes for treated units in all periods.
"""
function treated_ids(x::BalancedPanel{SingleUnitTreatment, T2}) where T2
function treated_ids(x::BalancedPanel{SingleUnitTreatment{T}}) where T
for i 1:size(x.Y, 1)
for t 1:size(x.Y, 2)
if x.W[i, t]
Expand All @@ -180,12 +184,16 @@ function treated_ids(x::BalancedPanel{SingleUnitTreatment, T2}) where T2
end
end

function treated_ids(x::BalancedPanel{MultiUnitTreatment{T}}) where T
    # Return the row indices of all units that are treated in at least one period.
    # The treatment indicator lives in the panel's `W` field; the previous body referred
    # to an undefined variable `Y`. `isequal(true)` is used so that `missing` entries
    # (W may be a Matrix{Union{Missing, Bool}}) are simply not counted as treated.
    return [i for i in axes(x.W, 1) if any(isequal(true), view(x.W, i, :))]
end

"""
treated_labels(x <: BalancedPanel)
Returns the labels of treated units as given by the `id_var` column in the underlying data set.
"""
function treated_labels(x::BalancedPanel{SingleUnitTreatment, T2}) where T2
function treated_labels(x::BalancedPanel{SingleUnitTreatment{T}}) where T
    # Index the stored unit identifiers (`x.is`) with the treated unit indices
    ids = treated_ids(x)
    return x.is[ids]
end

Expand All @@ -196,7 +204,7 @@ end
of length Nₜᵣ, where each element is the index of the first 1 in the row of treatment matrix W
corresponding to the treatment unit.
"""
function first_treated_period_ids(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function first_treated_period_ids(x::BalancedPanel{SingleUnitTreatment{T}}) where T
    # Extract the treatment indicators of the treated unit and flatten them to a vector
    treated_row = vec(x.W[treated_ids(x), :])
    # Position of the first `true` entry, i.e. the first treated period
    return findfirst(treated_row)
end

Expand All @@ -206,7 +214,7 @@ end
Returns the labels of the first treated period for each treated units, that is, a Vector{T}
of length Nₜᵣ, where T is the eltype of the `t_var` column in the underlying data.
"""
function first_treated_period_labels(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function first_treated_period_labels(x::BalancedPanel{SingleUnitTreatment{T}}) where T
    # Translate the period index into the corresponding entry of the stored periods `x.ts`
    period_id = first_treated_period_ids(x)
    return x.ts[period_id]
end

Expand All @@ -215,7 +223,7 @@ end
Returns the number of pre-treatment periods for each treated unit.
"""
function length_T₀(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function length_T₀(x::BalancedPanel{SingleUnitTreatment{Continuous}})
    # All periods before the first treated one count as pre-treatment
    first_treated = first_treated_period_ids(x)
    return first_treated - 1
end

Expand All @@ -224,7 +232,7 @@ end
Returns the number of treatment periods for each treated unit.
"""
function length_T₁(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function length_T₁(x::BalancedPanel{SingleUnitTreatment{Continuous}})
    # Total number of periods is the number of columns of the outcome matrix
    n_periods = size(x.Y, 2)
    # Periods from the first treated one (inclusive) up to the last period
    return n_periods .- first_treated_period_ids(x) + 1
end

Expand All @@ -235,7 +243,7 @@ end
Returns the pre-treatment outcomes for the treated unit(s). For SingleUnitTreatment designs,
this is a vector of length T₀, while for MultiUnitTreatment designs, it is a (Nₜᵣ×T₀) matrix
"""
function get_y₁₀(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function get_y₁₀(x::BalancedPanel{SingleUnitTreatment{Continuous}})
    # Rows: treated unit; columns: every period before the first treated period
    rows = treated_ids(x)
    pre_periods = 1:(first_treated_period_ids(x) - 1)
    return x.Y[rows, pre_periods]
end

Expand All @@ -245,16 +253,16 @@ end
Returns the post-treatment outcomes for the treated unit(s). For SingleUnitTreatment designs,
this is a vector of length T₁, while for MultiUnitTreatment designs, it is a (Nₜᵣ×T₁) matrix
"""
function get_y₁₁(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function get_y₁₁(x::BalancedPanel{SingleUnitTreatment{Continuous}})
    # Use the treatment matrix as a logical mask on the outcome matrix
    treated_mask = x.W
    return x.Y[treated_mask]
end

"""
""" sc
get_y₀₀(x <: BalancedPanel)
Returns the pre-treatment outcomes for the untreated units, an (Nₖₒ×T₀) matrix
"""
function get_y₀₀(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function get_y₀₀(x::BalancedPanel{SingleUnitTreatment{Continuous}})
    # Rows: every unit except the treated one; columns: periods before first treatment
    control_rows = Not(treated_ids(x))
    pre_periods = 1:(first_treated_period_ids(x) - 1)
    return x.Y[control_rows, pre_periods]
end

Expand All @@ -263,27 +271,29 @@ end
Returns the post-treatment outcomes for the untreated units, an (Nₖₒ×T₁) matrix
"""
function get_y₀₁(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: ContinuousTreatment
function get_y₀₁(x::BalancedPanel{SingleUnitTreatment{Continuous}})
    # Rows: every unit except the treated one; columns: first treated period onwards
    control_rows = Not(treated_ids(x))
    post_periods = first_treated_period_ids(x):lastindex(x.Y, 2)
    return x.Y[control_rows, post_periods]
end

"""
get_y₀₀(x <: BalancedPanel)
decompose_y(x <: BalancedPanel)
Decomposes the outcome matrix Y into four elements:
* Pre-treatment outcomes for treated units (y₁₀)
* Post-treatment outcomes for treated units (y₁)
* Post-treatment outcomes for treated units (y₁₁)
* Pre-treatment outcomes for control units (y₀₀)
* Post-treatment outcomes for control units (y₀₁)
and returns a tuple (y₁₀, y₁₁, y₀₀, y₀₁)
"""
function decompose_y(x)
    # Collect the four blocks in the documented order: treated pre/post, control pre/post
    y₁₀ = get_y₁₀(x)
    y₁₁ = get_y₁₁(x)
    y₀₀ = get_y₀₀(x)
    y₀₁ = get_y₀₁(x)
    return (y₁₀, y₁₁, y₀₀, y₀₁)
end

####################################################################################################

# Constructor for single start/end treatment - returns BalancedPanel{SingleUnitTreatment, StartEndTreatment}
# Constructor for single start/end treatment - returns BalancedPanel{SingleUnitTreatment{Discontinuous}}
function BalancedPanel(df::DataFrame, treatment_assignment::Pair{T1, T2};
id_var = nothing, t_var = nothing, outcome_var = nothing,
sort_inplace = false) where T1 where T2 <: Union{Pair{Date, Date}, Pair{Int, Int}}
Expand Down Expand Up @@ -324,35 +334,9 @@ function BalancedPanel(df::DataFrame, treatment_assignment::Pair{T1, T2};
end
end

BalancedPanel{SingleUnitTreatment, StartEndTreatment}(N, T, W, ts, is, Y)
end

function first_treated_period_ids(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: StartEndTreatment
ids = Int64[]
treated_row = vec(@view x.W[treated_ids(x), :])
for t 2:x.T
if treated_row[t] == 1 && treated_row[t-1] == 0
push!(ids, t)
end
end

return ids
end

function first_treated_period_labels(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: StartEndTreatment
x.ts[first_treated_period_ids(x)]
end

function length_T₀(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: StartEndTreatment
first(first_treated_period_ids(x)) - 1
BalancedPanel{SingleUnitTreatment{Discontinuous}}(W, Y, df, id_var, t_var, outcome_var, ts, is)
end

function length_T₁(x::BalancedPanel{SingleUnitTreatment, T2}) where T2 <: StartEndTreatment
x.T .- last(first_treated_period_ids(x)) + 1
end



# Fallback method - if the length of treatment assignment is one use single treatment method above
function BalancedPanel(df::DataFrame, treatment_assignment;
id_var = nothing, t_var = nothing, outcome_var = nothing,
Expand Down Expand Up @@ -395,64 +379,35 @@ function BalancedPanel(df::DataFrame, treatment_assignment;
W = construct_W(treatment_assignment, N, T, is, ts)

# Outcome matrix
Y = zeros(size(W))
Y = zeros(eltype(df[!, outcome_var]), size(W))
for (row, i) enumerate(is), (col, t) enumerate(ts)
try
Y[row, col] = only(df[(df[!, id_var] .== i) .& (df[!, t_var] .== t), outcome_var])
catch ArgumentError
throw("$(nrow(df[(df[!, id_var] .== i) .& (df[!, t_var] .== t), :])) outcomes present in the data for unit $i in period $t")
end
Y[row, col] = only(df[(df[!, id_var] .== i) .& (df[!, t_var] .== t), outcome_var])
end

# Determine UnitTreatmentType and TreatmentDurationType
# Determine TreatmentType and TreatmentDurationType
uttype = if all(==(treatment_assignment[1][2]), last.(treatment_assignment))
MultiUnitSimultaneousTreatment
Simultaneous
else
MultiUnitStaggeredTreatment
Staggered
end

tdtype = if typeof(treatment_assignment) <: Pair
if typeof(treatment_assignment[2]) <: Pair
StartEndTreatment
Discontinuous
else
ContinuousTreatment
Continuous
end
else
if typeof(treatment_assignment[1][2]) <: Pair
StartEndTreatment
Discontinuous
else
ContinuousTreatment
Continuous
end
end

BalancedPanel{uttype, tdtype}(N, T, W, ts, is, Y)
BalancedPanel{MultiUnitTreatment{uttype{tdtype}}}(W, Y, df, id_var, t_var, outcome_var, ts, is)
end

## UnbalancedPanel - N observations but not all of them for T periods

#!# Not yet implemented

## Utility functions
function treated_ids(x::BalancedPanel)
any.(eachrow(x.W))
end

function treated_labels(x::BalancedPanel)
x.is[treated_ids(x)]
end

function first_treated_period_ids(x::BalancedPanel)
findfirst.(eachrow(x.W[treated_ids(x), :]))
end

function first_treated_period_labels(x::BalancedPanel)
x.ts[first_treated_period_ids(x)]
end

function length_T₀(x::BalancedPanel)
first_treated_period_ids(x) .- 1
end

function length_T₁(x::BalancedPanel)
x.T .- first_treated_period_ids(x) .+ 1
end
#!# Not yet implemented
6 changes: 4 additions & 2 deletions src/TreatmentPanels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ export BalancedPanel, UnbalancedPanel

# Export treatment description types
# The abstract supertype of the unit treatment types is now called `TreatmentType`
export TreatmentType
export SingleUnitTreatment, MultiUnitSimultaneousTreatment, MultiUnitStaggeredTreatment
export SingleUnitTreatment, MultiUnitTreatment
export TreatmentTimingType
export Staggered, Simultaneous
export TreatmentDurationType
export ContinuousTreatment, StartEndTreatment
export Continuous, Discontinuous

# Export utility functions
export treated_ids, treated_labels, first_treated_period_ids, first_treated_period_labels, length_T₀, length_T₁
Expand Down
12 changes: 10 additions & 2 deletions src/show_and_plot.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
using RecipesBase

# Custom show methods
function Base.show(io::IO, mime::MIME"text/plain", x::BalancedPanel{SingleUnitTreatment, ContinuousTreatment})
println("Balanced Panel - single unit, single continuous treatment")
# Human-readable description of the treatment pattern, selected by dispatch on the
# panel's type parameter
function title(::BalancedPanel{SingleUnitTreatment{Continuous}})
    return "single treated unit, continuous treatment"
end

function title(::BalancedPanel{SingleUnitTreatment{Discontinuous}})
    return "single treated unit, discontinuous treatment"
end

function title(::BalancedPanel{MultiUnitTreatment{Simultaneous{Continuous}}})
    return "multiple treated units, simultaneous continuous treatment"
end

function title(::BalancedPanel{MultiUnitTreatment{Simultaneous{Discontinuous}}})
    return "multiple treated units, simultaneous discontinuous treatment"
end

function title(::BalancedPanel{MultiUnitTreatment{Staggered{Continuous}}})
    return "multiple treated units, staggered continuous treatment"
end

function title(::BalancedPanel{MultiUnitTreatment{Staggered{Discontinuous}}})
    return "multiple treated units, staggered discontinuous treatment"
end

function Base.show(io::IO, mime::MIME"text/plain", x::BalancedPanel{SingleUnitTreatment{Continuous}})
    # Multi-line summary of a single-unit, continuous-treatment panel.
    # Every line is written to the `io` stream supplied by the caller; the previous
    # version called `println` without `io` and therefore always printed to stdout,
    # which breaks capturing the output (e.g. `sprint`, `repr`, or display to a file).
    println(io, "Balanced Panel - $(title(x))")
    println(io, " Treated unit: $(treated_labels(x))")
    # Exactly one unit is treated in this design, so all remaining rows are controls
    println(io, " Number of untreated units: $(size(x.Y, 1) - 1)")
    println(io, " First treatment period: $(first_treated_period_labels(x))")
    println(io, " Number of pretreatment periods: $(length_T₀(x))")
    println(io, " Number of treatment periods: $(length_T₁(x))")
end

#!# TO DO - add show methods for other panel types

# Plotting recipe
@recipe function f(bp::BalancedPanel; kind = "treatment")
Expand Down
Loading

0 comments on commit 996e16d

Please sign in to comment.