Implement more functionality
nilshg committed Nov 19, 2021
1 parent f95f5dc commit ad16780
Showing 6 changed files with 252 additions and 78 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "TreatmentPanels"
uuid = "7885c543-3ac4-48a3-abed-7a36d7ddb69f"
authors = ["Nils <[email protected]> and contributors"]
version = "0.1.0"
version = "0.1.1"

[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
70 changes: 65 additions & 5 deletions README.md
@@ -19,7 +19,67 @@ or in the Pkg REPL
```
(@v1.7) add TreatmentPanels
```
## Usage

## Quickstart

The basic idea of the package is to combine a `DataFrame` with a specification of treatment
assignment to construct an outcome matrix `Y` and an accompanying treatment matrix `W`. Both `Y`
and `W` follow the convention that a row is an observational unit and a column is a time period,
i.e. each can be seen as a "wide" panel data set.

```
julia> using DataFrames, Dates, TreatmentPanels
julia> data = DataFrame(region = repeat(["A", "B"], inner = 5), year = repeat(Date(2000):Year(1):Date(2004), 2), outcome = rand(10))
10×3 DataFrame
Row │ region year outcome
│ String Date Float64
─────┼───────────────────────────────
1 │ A 2000-01-01 0.0605538
2 │ A 2001-01-01 0.820218
3 │ A 2002-01-01 0.533732
4 │ A 2003-01-01 0.144979
5 │ A 2004-01-01 0.353885
6 │ B 2000-01-01 0.65294
7 │ B 2001-01-01 0.353973
8 │ B 2002-01-01 0.683144
9 │ B 2003-01-01 0.477427
10 │ B 2004-01-01 0.702888
julia> bp = BalancedPanel(data, "A" => Date(2003); id_var = "region", t_var = "year", outcome_var = "outcome")
Balanced Panel - single unit, single continuous treatment
Treated unit: A
Number of untreated units: 1
First treatment period: 2003-01-01
Number of pretreatment periods: 3
Number of treatment periods: 2
julia> bp.Y
2×5 Matrix{Float64}:
0.0605538 0.820218 0.533732 0.144979 0.353885
0.65294 0.353973 0.683144 0.477427 0.702888
julia> bp.W
2×5 Matrix{Bool}:
0 0 0 1 1
0 0 0 0 0
```
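
To make the wide layout concrete, here is a minimal sketch in plain Julia of how long-format rows could be pivoted into `Y` and `W`. It is not the package's implementation: the helper name `wide_panel` and the tuple-based row format are assumptions made for illustration, and only `Dates` from the standard library is used.

```julia
using Dates

# Hypothetical helper (not part of TreatmentPanels): pivot long-format rows of
# (id, period, outcome) into a units × periods outcome matrix Y, and build a
# Boolean treatment matrix W for one unit treated from `start` onwards.
function wide_panel(rows, treated, start)
    ids = sort(unique(r[1] for r in rows))
    periods = sort(unique(r[2] for r in rows))

    Y = zeros(length(ids), length(periods))
    for (id, period, outcome) in rows
        Y[findfirst(==(id), ids), findfirst(==(period), periods)] = outcome
    end

    W = falses(length(ids), length(periods))
    W[findfirst(==(treated), ids), findfirst(==(start), periods):end] .= true

    return Y, W, ids, periods
end

# Made-up data: two units observed yearly from 2000 to 2004
rows = [(id, Date(y), 10.0 * i + (y - 2000))
        for (i, id) in enumerate(["A", "B"]) for y in 2000:2004]

Y, W, ids, periods = wide_panel(rows, "A", Date(2003))
```

With this input, `Y` and `W` are both 2×5, and only the last two entries of the first row of `W` are `true`, matching the treatment pattern shown above.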

The package provides simple plotting functionality to visualise treatment assignment and outcomes:

```
julia> using Plots
julia> plot(bp, markersize = 10)
```
<img src="plot1.jpg" alt="Treatment plot" width="400"/>

```
julia> plot(bp; kind = "outcome")
```
<img src="plot2.jpg" alt="Outcome plot" width="400"/>

## Types provided

There are two basic types:

@@ -59,10 +119,10 @@ treatment observed. `UnitTreatmentType` currently has three concrete types:

The types of `T1` and `T2` are automatically chosen based on the `treatment_assignment` passed.

| | Only starting point | Start and end point | Multiple start & end points |
|----------- |------------------------ |------------------------- |------------------------------ |
| **one unit** | Pair{String, Date} | Pair{String, Tuple{Date, Date}} | Pair{String}, Vector{Tuple{Date, Date}}} |
| **multiple units** | Vector{Pair{String, Date}} | Vector{Pair{String, Tuple{Date, Date}}} | Vector{Pair{String}, Vector{Tuple{Date, Date}}}} |
| | Only starting point | Start and end point | Multiple start & end points |
|---------------------|-----------------------------|-------------------------------------------|---------------------------------------------------|
| **one unit** | Pair{String, Date} | Pair{String, Tuple{Date, Date}} | Pair{String, Vector{Tuple{Date, Date}}} |
| **multiple units** | Vector{Pair{String, Date}} | Vector{Pair{String, Tuple{Date, Date}}} | Vector{Pair{String, Vector{Tuple{Date, Date}}}} |
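
As a rough illustration of the table, the snippet below builds one value of each shape and checks its type. It uses only `Dates` from the standard library and does not call the package; the unit names and dates are made up.

```julia
using Dates

# One treated unit
start_only   = "A" => Date(2003)
start_end    = "A" => (Date(2001), Date(2003))
multi_spells = "A" => [(Date(2000), Date(2001)), (Date(2003), Date(2004))]

# Several treated units: a vector of pairs of the same shape
many_start_only = ["A" => Date(2003), "B" => Date(2002)]
many_start_end  = ["A" => (Date(2001), Date(2003)),
                   "B" => (Date(2002), Date(2004))]
many_multi      = ["A" => [(Date(2000), Date(2001))],
                   "B" => [(Date(2002), Date(2003)), (Date(2004), Date(2004))]]

# Each value has the type listed in the corresponding table cell
checks = [
    start_only      isa Pair{String, Date},
    start_end       isa Pair{String, Tuple{Date, Date}},
    multi_spells    isa Pair{String, Vector{Tuple{Date, Date}}},
    many_start_only isa Vector{Pair{String, Date}},
    many_start_end  isa Vector{Pair{String, Tuple{Date, Date}}},
    many_multi      isa Vector{Pair{String, Vector{Tuple{Date, Date}}}},
]
```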

As an example, calling `BalancedPanel(data, "unit1" => Date(2000))` will return an object of type
`BalancedPanel{SingleUnitTreatment, ContinuousTreatment}`, as only one unit is treated and only a
Binary file added plot1.jpg
Binary file added plot2.jpg
169 changes: 106 additions & 63 deletions src/TreatmentPanel.jl
@@ -31,12 +31,10 @@ multiple period treatments can be generalized to multiple treated units.
The following table provides an overview of the types of treatment pattern supported:
| | Only starting point | Start and end point | Multiple start & end points |
|----------- |------------------------ |------------------------- |------------------------------ |
| **one unit** | Pair{String, Date} | Pair{String, Tuple{Date, Date}} | Pair{String}, Vector{Tuple{Date, Date}}} |
| | Only starting point | Start and end point | Multiple start & end points |
|---------------------|-----------------------------|-------------------------------------------|-----------------------------------------------------|
| **one unit** | Pair{String, Date} | Pair{String, Tuple{Date, Date}} | Pair{String, Vector{Tuple{Date, Date}}} |
| **multiple units** | Vector{Pair{String, Date}} | Vector{Pair{String, Tuple{Date, Date}}} | Vector{Pair{String, Vector{Tuple{Date, Date}}}} |
Currently, only single treatment unit and continuous treatment is supported.
"""
@with_kw struct BalancedPanel{UTType, TDType} <: TreatmentPanel where UTType <: UnitTreatmentType where TDType <: TreatmentDurationType
N::Int64
@@ -47,15 +45,8 @@ Currently, only single treatment unit and continuous treatment is supported.
Y::Matrix{Float64}
end

# Constructor based on DataFrame and treatment assignment in pairs
function BalancedPanel(df::DataFrame, treatment_assignment::Vector{Pair{NType, TType}};
id_var = nothing,
t_var = nothing,
outcome_var = nothing,
sort_inplace = false) where NType where TType <: Union{Date, Int64}

### SANITY CHECKS ###

# Check that ID, time, and outcome variable are provided
function check_id_t_outcome(df, outcome_var, id_var, t_var)
# Check relevant info has been provided
!isnothing(outcome_var) || error(ArgumentError(
"Please specify outcome_var, the name of the column in your dataset holding the "*
@@ -69,47 +60,102 @@ function BalancedPanel(df::DataFrame, treatment_assignment::Vector{Pair{NType, T
"Please specify t_var, the name of the column in your dataset holding the "*
"time dimension."
))
!isnothing(treatment_assignment) || error(ArgumentError(
"Please specify treatment assignment, the identifier of the treated unit(s) in your "*
"dataset and associated start date(s) of treatment."
))

# Ensure columns exist in data
f = in(names(df))
f(string(id_var)) || throw("Error: ID variable $id_var is not present in the data.")
f(string(t_var)) || throw("Error: Time variable $t_var is not present in the data.")
f(string(outcome_var)) || throw("Error: Outcome variable $outcome_var is not present in the data.")
end

# Functions to get all treatment periods
function treatment_periods(ta::Pair{T1, S1}) where T1 where S1
[last(ta)]
end

function treatment_periods(ta::Pair{T1, S1}) where T1 where S1 <: Union{Pair{Int, Int}, Pair{Date, Date}}
collect(last(ta))
end

function treatment_periods(ta::Vector{Pair{T1, S1}}) where T1 where S1
last.(ta)
end

function treatment_periods(ta::Vector{Pair{T1, S1}}) where T1 where S1 <: Union{Pair{Int, Int}, Pair{Date, Date}}
unique(reduce(vcat, collect.(last.(ta))))
end

# Functions to construct treatment assignment matrix
function construct_W(ta::Pair{T1, S1}, N, T, is, ts) where T1 where S1
W = [false for i = 1:N, j = 1:T]
W[findfirst(==(ta[1]), is), findfirst(==(ta[2]), ts):end] .= true

return W
end

function construct_W(ta::Pair{T1, S1}, N, T, is, ts) where T1 where S1 <: Union{Pair{Int, Int}, Pair{Date, Date}}
W = [false for i = 1:N, j = 1:T]
W[findfirst(==(ta[1]), is), findfirst(==(ta[2][1]), ts):findfirst(==(ta[2][2]), ts)] .= true

return W
end

function construct_W(tas::Vector{Pair{T1, S1}}, N, T, is, ts) where T1 where S1
W = [false for i = 1:N, j = 1:T]
for ta ∈ tas
W[findfirst(==(ta[1]), is), findfirst(==(ta[2]), ts):end] .= true
end

return W
end

function construct_W(tas::Vector{Pair{T1, S1}}, N, T, is, ts) where T1 where S1 <: Union{Pair{Int, Int}, Pair{Date, Date}}
W = [false for i = 1:N, j = 1:T]
for ta ∈ tas
W[findfirst(==(ta[1]), is), findfirst(==(ta[2][1]), ts):findfirst(==(ta[2][2]), ts)] .= true
end

return W
end

# Constructor based on DataFrame and treatment assignment in pairs
function BalancedPanel(df::DataFrame, treatment_assignment;
id_var = nothing, t_var = nothing, outcome_var = nothing,
sort_inplace = false) where NType where TType

# Get all units and time periods
is = sort(unique(df[!, id_var])); i_set = Set(is)
ts = sort(unique(df[!, t_var])); t_set = Set(ts)

for tp ∈ treatment_assignment
in(tp[1], i_set) || throw("Error: Treatment assignment $tp provided, but $(tp[1]) is not in the list of unit identifiers $id_var")
in(tp[2], t_set) || throw("Error: Treatment assignment $tp provided, but $(tp[2]) is not in the list of time identifiers $t_var")
end

# Sort data if necessary, in place if required
df = ifelse(issorted(df, [id_var, t_var]), df,
ifelse(sort_inplace, sort!(df, [id_var, t_var]),
sort(df, [id_var, t_var])))
# Get all treatment units and treatment periods
treated_is = first.(treatment_assignment)
treated_is = typeof(treated_is) <: AbstractArray ? treated_is : [treated_is]
treated_ts = treatment_periods(treatment_assignment)

# Dimensions
N = length(is)
T = length(ts)

# Treatment matrix
W = [false for i ∈ 1:N, t ∈ 1:T]
### SANITY CHECKS ###
check_id_t_outcome(df, outcome_var, id_var, t_var)
for ti ∈ treated_is
in(ti, i_set) || throw("Error: Treatment unit $ti is not in the list of unit identifiers $id_var")
end

for tp ∈ treatment_assignment
i_id = findfirst(==(tp[1]), is)
t_id = findfirst(==(tp[2]), ts)
W[i_id, t_id:end] .= true
for tt ∈ treated_ts
in(tt, t_set) || throw("Error: Treatment period $tt is not in the list of time identifiers $t_var")
end

# Sort data if necessary, in place if required
df = ifelse(issorted(df, [id_var, t_var]), df,
ifelse(sort_inplace, sort!(df, [id_var, t_var]),
sort(df, [id_var, t_var])))

# Treatment matrix
W = construct_W(treatment_assignment, N, T, is, ts)

# Outcome matrix
Y = zeros(size(W))

for (row, i) ∈ enumerate(is), (col, t) ∈ enumerate(ts)
try
Y[row, col] = only(df[(df[!, id_var] .== i) .& (df[!, t_var] .== t), outcome_var])
@@ -119,7 +165,7 @@ function BalancedPanel(df::DataFrame, treatment_assignment::Vector{Pair{NType, T
end

# Determine UnitTreatmentType and TreatmentDurationType
uttype = if length(treatment_assignment) == 1
uttype = if length(treated_is) == 1
SingleUnitTreatment
else
if all(==(treatment_assignment[1][2]), last.(treatment_assignment))
@@ -129,51 +175,48 @@ function BalancedPanel(df::DataFrame, treatment_assignment::Vector{Pair{NType, T
end
end

tdtype = ContinuousTreatment
tdtype = if typeof(treatment_assignment) <: Pair
if typeof(treatment_assignment[2]) <: Pair
StartEndTreatment
else
ContinuousTreatment
end
else
if typeof(treatment_assignment[1][2]) <: Pair
StartEndTreatment
else
ContinuousTreatment
end
end

BalancedPanel{uttype, tdtype}(N, T, W, ts, is, Y)
end

# Constructor for single treatment
function BalancedPanel(df::DataFrame, treatment_assignment::Pair{NType, TType};
id_var = nothing, t_var= nothing, outcome_var = nothing, sort_inplace = false) where NType where TType

BalancedPanel(df, [treatment_assignment]; id_var = id_var,
t_var = t_var, outcome_var = outcome_var, sort_inplace = sort_inplace)

end

# UnbalancedPanel - N observations but not all of them for T periods
## UnbalancedPanel - N observations but not all of them for T periods

#!# Not yet implemented

# Utility functions
function treated_ids(x::BalancedPanel{SingleUnitTreatment, T}) where T
for i ∈ 1:x.N
for t ∈ 1:x.T
if x.W[i, t]
return i
end
end
end
## Utility functions
function treated_ids(x::BalancedPanel)
any.(eachrow(x.W))
end

function treated_labels(x::BalancedPanel{SingleUnitTreatment, T}) where T
function treated_labels(x::BalancedPanel)
x.is[treated_ids(x)]
end

function first_treated_period_ids(x::BalancedPanel{SingleUnitTreatment, T}) where T
findfirst(x.W[treated_ids(x), :])
function first_treated_period_ids(x::BalancedPanel)
findfirst.(eachrow(x.W[treated_ids(x), :]))
end

function first_treated_period_labels(x::BalancedPanel{SingleUnitTreatment, T}) where T
function first_treated_period_labels(x::BalancedPanel)
x.ts[first_treated_period_ids(x)]
end

function length_T₀(x::BalancedPanel{SingleUnitTreatment, T}) where T
first_treated_period_ids(x) - 1
function length_T₀(x::BalancedPanel)
first_treated_period_ids(x) .- 1
end

function length_T₁(x::BalancedPanel{SingleUnitTreatment, T}) where T
x.T - first_treated_period_ids(x) + 1
function length_T₁(x::BalancedPanel)
x.T .- first_treated_period_ids(x) .+ 1
end
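
The broadcast-based utility functions can be tried out on a small hand-written treatment matrix. This sketch applies the same expressions as the new method bodies to plain arrays rather than to a `BalancedPanel`; the matrix `W`, the labels, and the variable names are made up for illustration.

```julia
# Made-up treatment matrix: 3 units × 5 periods, units 1 and 3 treated
W = Bool[0 0 0 1 1;
         0 0 0 0 0;
         0 1 1 1 1]
unit_labels = ["A", "B", "C"]
n_periods = size(W, 2)

# Same expression as treated_ids: a Boolean mask over units, not a list of indices
treated = any.(eachrow(W))

# Same expression as treated_labels: logical indexing into the labels
labels = unit_labels[treated]

# Same expression as first_treated_period_ids: first treated column per treated unit
first_treated = findfirst.(eachrow(W[treated, :]))

# Same expressions as length_T₀ and length_T₁: one entry per treated unit
n_pre = first_treated .- 1
n_post = n_periods .- first_treated .+ 1
```

Because every step is broadcast over rows, the results are vectors with one entry per treated unit, which is what lets these functions drop the `SingleUnitTreatment` restriction.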