===================
== Thomas Pinder ==
===================
Bayesian ML, Causal Inference, and JAX


Bayesian Synthetic Difference-in-Differences via Cut Posteriors

A NumPyro implementation of Bayesian Synthetic Difference-in-Differences using modular inference and cut posteriors.

import jax
import jax.lax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpyro as npy
from numpyro import distributions as npy_dist
from numpyro.infer import MCMC, NUTS
import numpy as np
import arviz as az
import seaborn as sns
import pandas as pd

jax.config.update("jax_platform_name", "cpu")
npy.set_host_device_count(4)
key = jr.PRNGKey(123)
sns.set_theme(
    context="notebook",
    font="serif",
    style="whitegrid",
    palette="deep",
    rc={"figure.figsize": (6, 3), "axes.spines.top": False, "axes.spines.right": False},
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]

Introduction

Synthetic Difference-in-Differences (SDiD), introduced by Arkhangelsky et al. (2021), combines the strengths of Difference-in-Differences and the Synthetic Control Method to estimate causal effects in panel data. The frequentist SDiD estimator operates in three stages: it learns unit weights that balance control units against the treated unit, time weights that balance pre-treatment against post-treatment periods, and then estimates a treatment effect from a weighted two-way fixed effects regression.

A natural question is whether this estimator can be made Bayesian. The difficulty is that the three-stage structure is not just a computational convenience — it encodes a deliberate separation of concerns. The unit weights should be informed only by pre-treatment data, the time weights only by control-unit data, and neither should be distorted by the outcome likelihood that identifies the treatment effect. Naively placing all parameters into a single Bayesian model would violate this separation: the outcome likelihood would pull the weights away from their balancing objective.

The solution comes from modular Bayesian inference and the cut posterior. We partition the model into three modules, each with its own likelihood, and cut the feedback from the treatment-effect module to the weight parameters. In JAX the cut can be expressed in a single line, jax.lax.stop_gradient, though as we will see it is more robustly enforced by two-stage estimation. The result is a proper Bayesian SDiD that preserves the estimator’s three-stage logic while propagating weight uncertainty into the treatment effect posterior.

We apply this to the canonical Proposition 99 dataset — California’s 1988 cigarette tax increase — and estimate the causal effect on per-capita cigarette sales.

The SDiD estimator

We briefly recall the structure of the frequentist SDiD. Given a balanced panel $Y \in \mathbb{R}^{N \times T}$ with $N_{co}$ control units and $N_{tr}$ treated units observed over $T_{pre}$ pre-treatment and $T_{post}$ post-treatment periods, the estimator proceeds in three stages.

Stage 1: Unit weights. Find simplex weights $\hat{\omega} \in \Delta^{N_{co}}$ that balance control units against the treated unit in the pre-treatment period:

$$\hat{\omega} = \arg\min_{\omega_0 \in \mathbb{R},\; \omega \in \Delta^{N_{co}}} \sum_{t=1}^{T_{pre}} \Big( \omega_0 + \sum_{j \in co} \omega_j Y_{jt} - \bar{Y}_{tr,t} \Big)^2 + \zeta^2 T_{pre} \|\omega\|_2^2.$$

Stage 2: Time weights. Find simplex weights $\hat{\lambda} \in \Delta^{T_{pre}}$ that balance pre-treatment periods against post-treatment periods for control units:

$$\hat{\lambda} = \arg\min_{\lambda_0 \in \mathbb{R},\; \lambda \in \Delta^{T_{pre}}} \sum_{i \in co} \Big( \lambda_0 + \sum_{s=1}^{T_{pre}} \lambda_s Y_{is} - \bar{Y}_{i,post} \Big)^2 + \zeta^2 N_{co} \|\lambda\|_2^2.$$

Stage 3: Weighted regression. Estimate the treatment effect $\tau$ from a weighted two-way fixed effects regression with observation weights $w_{it}$:

$$\hat{\tau}^{sdid} = \arg\min_{\tau, \alpha, \beta} \sum_{i,t} w_{it} \big( Y_{it} - \alpha_i - \beta_t - \tau D_{it} \big)^2,$$

where the weight matrix has the structure:

              Pre-treatment                        Post-treatment
  Control     $\hat{\omega}_i \hat{\lambda}_t$     $\hat{\omega}_i / T_{post}$
  Treated     $\hat{\lambda}_t / N_{tr}$           $1 / (N_{tr} T_{post})$

The critical feature is the separation of concerns: the unit weights are informed only by pre-treatment data, the time weights only by control-unit data, and the treatment effect estimate uses the full panel with the weights fixed.
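As a concrete check on the weight table, the full observation-weight matrix can be assembled from given weight vectors. This is a standalone sketch with a hypothetical helper (`sdid_obs_weights` is not part of the implementation below); each of the four blocks sums to one, which is what makes the weighted regression a double difference.

```python
import numpy as np

def sdid_obs_weights(omega, lam, N_tr, T_post):
    """Assemble the (N, T) SDiD observation-weight matrix.

    omega : (N_co,) unit weights on the simplex
    lam   : (T_pre,) time weights on the simplex
    """
    N_co, T_pre = len(omega), len(lam)
    w = np.zeros((N_co + N_tr, T_pre + T_post))
    # Control block
    w[:N_co, :T_pre] = np.outer(omega, lam)        # omega_i * lambda_t
    w[:N_co, T_pre:] = omega[:, None] / T_post     # omega_i / T_post
    # Treated block
    w[N_co:, :T_pre] = lam[None, :] / N_tr         # lambda_t / N_tr
    w[N_co:, T_pre:] = 1.0 / (N_tr * T_post)       # 1 / (N_tr * T_post)
    return w

# Toy panel: 3 controls, 1 treated, 2 pre / 2 post periods
w = sdid_obs_weights(np.array([0.5, 0.3, 0.2]), np.array([0.4, 0.6]),
                     N_tr=1, T_post=2)
print(w.shape)  # (4, 4); each quadrant sums to 1
```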

Figure 1: Data partitioning in SDiD. Each stage uses a specific slice of the panel, preserving the separation of concerns.

A modular Bayesian formulation

Why naive joint estimation fails

A natural first attempt would be to place priors on ω, λ, and τ and write a single likelihood over all the data. The problem is that the outcome likelihood in Stage 3 would then update the weight parameters, pulling them away from their balancing objective. The unit weights would no longer reflect pre-treatment balance; they would be distorted by the post-treatment outcome data, which is exactly what the weights are supposed to be insulated from.

The cut posterior

The solution comes from modular Bayesian inference (Lunn et al., 2009; Plummer, 2015). The idea is to partition the model into modules, each with its own likelihood, and cut the feedback from downstream modules to upstream parameters. The posterior factorises as:

$$p_{\text{cut}}(\tau, \omega, \lambda \mid Y) = p(\tau \mid \omega, \lambda, Y)\; p(\omega \mid Y_{pre})\; p(\lambda \mid Y_{co}).$$

Each factor is a separate module:

  • Module 1, $p(\omega \mid Y_{pre})$: the unit-weight posterior, identified by pre-treatment matching.
  • Module 2, $p(\lambda \mid Y_{co})$: the time-weight posterior, identified by control-unit matching.
  • Module 3, $p(\tau \mid \omega, \lambda, Y)$: the treatment effect, determined by the SDiD double-difference formula given the weights.

The “cut” prevents the treatment-effect computation from feeding information back into the weight parameters. Weight uncertainty propagates into τ because different weight draws yield different double-difference estimates.
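Marginalising over the weights makes this propagation explicit: the cut posterior for τ averages the conditional treatment-effect distribution over weight posteriors that never see the post-treatment treated outcomes.

```latex
p_{\mathrm{cut}}(\tau \mid Y)
  = \int p(\tau \mid \omega, \lambda, Y)\,
         p(\omega \mid Y_{pre})\,
         p(\lambda \mid Y_{co})\,
    \mathrm{d}\omega\, \mathrm{d}\lambda .
```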

Figure 2: Information flow in the cut posterior. Solid arrows show forward propagation of weight draws into the treatment-effect computation. Dashed arrows with ✗ show the cut — Module 3 cannot update the weight posteriors.

Implementation via two-stage estimation

In principle, the cut could be implemented within a single model by applying jax.lax.stop_gradient to the weights where they enter Module 3’s likelihood. In practice, this creates a fundamental incompatibility with gradient-based samplers like NUTS: the posterior has strong correlation between ω and τ (different weights yield different gaps yield different treatment effects), but stop_gradient zeros out the gradient that NUTS needs to detect this correlation. The mass matrix cannot adapt to the true posterior geometry, and leapfrog trajectories follow the wrong Hamiltonian. No amount of tuning — priors, step size, warmup — can fix this structural mismatch.
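The gradient pathology is easy to see in miniature. The toy sketch below (not part of the model in this post) shows that the value of a downstream quantity still depends on the weights, while the gradient that NUTS would use to sense that dependence is exactly zero:

```python
import jax

def cut_potential(w):
    """Toy potential: the 'treatment effect' depends on the weight w,
    but the dependence is hidden behind stop_gradient."""
    tau = 2.0 * jax.lax.stop_gradient(w)  # Module 3 treats w as a constant
    return tau**2

# The value changes with w, but the gradient is zeroed out:
print(cut_potential(3.0), jax.grad(cut_potential)(3.0))
```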

The robust solution is to enforce modularity by separation: sample the weights from their matching likelihoods (Modules 1 and 2) via standard MCMC, then compute the treatment effect analytically for each posterior draw using the SDiD double-difference formula. This matches the frequentist SDiD’s three-stage structure and is the approach taken by CausalPy, CausalImpact, and related Bayesian synthetic control implementations.

Prior specification

Unit weights. We parameterise the simplex via softmax: $\omega = \mathrm{softmax}(\tilde{\omega})$ with $\tilde{\omega}_1 = 0$ (reference level) and $\tilde{\omega}_j \sim \mathcal{N}(0, 1/\zeta)$ for $j = 2, \dots, N_{co}$. Pinning one logit removes the shift invariance of softmax (since $\mathrm{softmax}(x + c) = \mathrm{softmax}(x)$), which would otherwise create a flat direction in the posterior. The prior scale $1/\zeta$ plays the role of the $\ell_2$ regularisation in the frequentist formulation: larger ζ pulls more strongly toward uniform weights (DiD-like), while smaller ζ allows the data to concentrate weight on a few good-match states (SC-like). The matching likelihood is:

$$\bar{Y}_{tr,t} \sim \mathcal{N}\big(\omega_0 + \omega^\top Y_{co,t},\, \sigma_\omega^2\big), \quad t = 1, \dots, T_{pre}.$$
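The shift invariance and the pinned-logit fix can be verified directly with a quick standalone check (not part of the model code):

```python
import jax
import jax.numpy as jnp

x = jnp.array([0.3, -1.2, 0.9])

# Softmax is shift-invariant: adding any constant to all logits yields
# identical weights, so unconstrained logits have a flat posterior direction.
assert jnp.allclose(jax.nn.softmax(x), jax.nn.softmax(x + 5.0), atol=1e-6)

# Pinning the first logit to zero removes the redundancy while still
# covering the whole simplex with N_co - 1 free parameters.
raw = jnp.array([-1.2, 0.9])  # the N_co - 1 sampled logits
omega = jax.nn.softmax(jnp.concatenate([jnp.zeros(1), raw]))
print(omega.sum())  # a valid simplex vector: weights sum to 1
```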

Time weights. Symmetric construction: $\lambda = \mathrm{softmax}(\tilde{\lambda})$ with $\tilde{\lambda}_1 = 0$ and $\tilde{\lambda}_s \sim \mathcal{N}(0, 1/\zeta)$ for $s = 2, \dots, T_{pre}$, and

$$\bar{Y}_{i,post} \sim \mathcal{N}\big(\lambda_0 + \lambda^\top Y_{i,pre},\, \sigma_\lambda^2\big), \quad i \in \text{control}.$$

Treatment effect. The treatment effect τ is not a sampled parameter. Instead, it is computed analytically for each posterior draw of $(\omega, \lambda)$ via the SDiD double-difference formula:

$$\hat{\tau}(\omega, \lambda) = \bar{g}_{post} - \lambda^\top g_{pre}, \qquad g_t = \bar{Y}_{tr,t} - \omega^\top Y_{co,t}.$$

This is algebraically equivalent to the weighted two-way fixed-effects regression (Arkhangelsky et al., 2021, Proposition 4.1) but avoids the $N + T - 1$ fixed-effect parameters entirely. The posterior distribution of τ is the pushforward of the weight posteriors through this formula: all uncertainty in τ comes from uncertainty in the weights. We standardise the outcome panel to zero mean and unit variance before fitting, so all weight-module priors are on the standardised scale, and back-transform τ afterwards.

Data

We use the Proposition 99 dataset, the canonical benchmark for synthetic control methods. In November 1988, California passed Proposition 99, which imposed a 25-cent-per-pack excise tax on cigarettes along with restrictions on smoking in public spaces. The dataset contains annual per-capita cigarette sales (in packs) for 39 US states from 1970 to 2000.

data = pd.read_csv("datasets/california_smoking.csv")

treatment_year = 1988
treated_state = "California"

controls = data.loc[data.state != treated_state].pivot(
    index="year", columns="state", values="cigsale"
)
treated = data.loc[data.state == treated_state, ["year", "cigsale"]].set_index("year")

control_states = list(controls.columns)
years = sorted(data["year"].unique().astype(int))
N_co = len(control_states)
T_pre = len([y for y in years if y < treatment_year])
T_post = len([y for y in years if y >= treatment_year])
N = N_co + 1
T = T_pre + T_post

print(f"Panel: {N} states x {T} years ({years[0]}-{years[-1]})")
print(f"  Control: {N_co} states, Treated: {treated_state}")
print(f"  Pre-treatment: {T_pre} years (1970-1987)")
print(f"  Post-treatment: {T_post} years (1988-2000)")
Panel: 39 states x 31 years (1970-2000)
  Control: 38 states, Treated: California
  Pre-treatment: 18 years (1970-1987)
  Post-treatment: 13 years (1988-2000)
# Construct (N, T) panel matrix: control states first, California last
Y_co = controls.values.T  # (N_co, T)
Y_tr = treated.values.T   # (1, T)
Y = jnp.array(np.vstack([Y_co, Y_tr]))
D = jnp.zeros((N, T))
D = D.at[-1, T_pre:].set(1.0)  # California, post-1988

# Standardise outcomes for better conditioning
Y_mean = float(Y.mean())
Y_std = float(Y.std())
Y_scaled = (Y - Y_mean) / Y_std
print(f"  Standardised: mean={Y_mean:.1f}, std={Y_std:.1f}")
  Standardised: mean=118.9, std=32.8
def clean_legend(ax):
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles, strict=False))
    ax.legend(by_label.values(), by_label.keys(), loc="best")
    return ax

fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(years, controls.values, color="grey", alpha=0.3, linewidth=0.8, label="Control states")
ax.plot(years, treated.values, color=cols[0], linewidth=2, label="California")
ax.axvline(treatment_year, color=cols[2], linestyle="--", label="Proposition 99")
clean_legend(ax)
ax.set(xlabel="Year", ylabel="Per-capita cigarette sales (packs)")
fig.tight_layout()
def plot_weight_posteriors(
    samples: dict,
    state_names: list[str],
    pre_years: list[int],
) -> plt.Figure:
    """Bar charts of unit and time weight posteriors with credible intervals."""
    N_co = len(state_names)
    T_pre = len(pre_years)
    omega = np.array(samples["omega"]).reshape(-1, N_co)
    lam = np.array(samples["lam"]).reshape(-1, T_pre)

    # Sort states by posterior median weight (descending)
    med_omega = np.median(omega, axis=0)
    sort_idx = np.argsort(med_omega)[::-1]

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Unit weights (sorted)
    ax = axes[0]
    sorted_med = med_omega[sort_idx]
    sorted_lo, sorted_hi = np.percentile(omega[:, sort_idx], [3, 97], axis=0)
    sorted_names = [state_names[i] for i in sort_idx]
    x = np.arange(N_co)
    ax.bar(x, sorted_med, color=cols[1], alpha=0.6, edgecolor="white")
    ax.vlines(x, sorted_lo, sorted_hi, color=cols[1], linewidth=1.2)
    ax.axhline(1.0 / N_co, color="grey", linestyle="--", linewidth=1,
               label=f"Uniform (1/{N_co})")
    ax.set_xticks(x[::3])
    ax.set_xticklabels([sorted_names[i] for i in range(0, N_co, 3)],
                       rotation=70, fontsize=7, ha="right")
    ax.set(ylabel="Weight", title="Unit Weights ($\\omega$)")
    ax.legend(fontsize=8)

    # Time weights
    ax = axes[1]
    med_lam = np.median(lam, axis=0)
    lo_lam, hi_lam = np.percentile(lam, [3, 97], axis=0)
    ax.bar(pre_years, med_lam, color=cols[3], alpha=0.6, edgecolor="white")
    ax.vlines(pre_years, lo_lam, hi_lam, color=cols[3], linewidth=1.5)
    ax.axhline(1.0 / T_pre, color="grey", linestyle="--", linewidth=1,
               label=f"Uniform (1/{T_pre})")
    ax.set(xlabel="Year", ylabel="Weight", title="Time Weights ($\\lambda$)")
    ax.legend(fontsize=8)

    fig.tight_layout()
    return fig


def plot_counterfactual(
    samples: dict,
    Y: np.ndarray,
    N_co: int,
    T_pre: int,
    years: list[int],
    treatment_year: int,
) -> plt.Figure:
    """Observed vs synthetic control trajectory and period-by-period effects."""
    omega = np.array(samples["omega"]).reshape(-1, N_co)

    N, T = Y.shape

    # Treated trajectory
    Y_co = Y[:N_co]
    Y_tr = Y[N_co:].mean(axis=0)

    # Synthetic control for each posterior draw
    sc_draws = omega @ Y_co  # (n_draws, T)
    sc_median = np.median(sc_draws, axis=0)
    sc_lo, sc_hi = np.percentile(sc_draws, [3, 97], axis=0)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4), gridspec_kw={"width_ratios": [2, 1]})

    # Left: trajectories
    ax = axes[0]
    ax.plot(years, Y_tr, "o-", color="black", markersize=3, linewidth=1.5,
            label="California (observed)")
    ax.plot(years, sc_median, "s--", color=cols[1], markersize=3, linewidth=1.5,
            label="Synthetic California")
    ax.fill_between(years, sc_lo, sc_hi, alpha=0.2, color=cols[1], label="94% CI")
    ax.axvline(treatment_year - 0.5, color="grey", linestyle=":", linewidth=1,
               label="Proposition 99")
    ax.set(xlabel="Year", ylabel="Per-capita cigarette sales (packs)",
           title="California vs. Synthetic Control")
    ax.legend(fontsize=8)

    # Right: period-by-period effect
    ax2 = axes[1]
    post_years = [y for y in years if y >= treatment_year]
    te_draws = Y_tr[T_pre:][np.newaxis, :] - sc_draws[:, T_pre:]
    te_median = np.median(te_draws, axis=0)
    te_lo, te_hi = np.percentile(te_draws, [3, 97], axis=0)

    ax2.bar(post_years, te_median, color=cols[1], alpha=0.6, edgecolor="white")
    ax2.vlines(post_years, te_lo, te_hi, color=cols[1], linewidth=2)
    ax2.axhline(0, color="grey", linestyle="--", linewidth=0.5)
    ax2.set(xlabel="Year", ylabel="$\\hat{\\tau}_t$ (packs)",
            title="Period-by-Period Effect")

    fig.tight_layout()
    return fig

Model

The NumPyro model encodes Modules 1 and 2. The treatment effect (Module 3) is computed analytically from the weight posterior draws — no jax.lax.stop_gradient is needed because the modularity is enforced by separating the weight estimation from the treatment-effect computation.

def bayesian_sdid(
    Y: jnp.ndarray,
    N_co: int,
    T_pre: int,
    zeta: float = 1.0,
):
    """Bayesian SDiD weight model (Modules 1 & 2).

    The treatment effect is computed analytically from the weight
    posterior draws via the SDiD double-difference formula.

    Parameters
    ----------
    Y : array (N, T)
        Standardised panel outcome matrix. First N_co rows are control units.
    N_co : int
        Number of control units.
    T_pre : int
        Number of pre-treatment periods.
    zeta : float
        Regularisation strength for the weight priors.
    """
    N, T = Y.shape
    Y_co_pre = Y[:N_co, :T_pre]                        # (N_co, T_pre)
    y_tr_pre_mean = Y[N_co:, :T_pre].mean(axis=0)      # (T_pre,)
    y_co_post_mean = Y[:N_co, T_pre:].mean(axis=1)     # (N_co,)

    # ── Module 1: Unit weights ──────────────────────────────────────────
    # Pin first logit to zero to remove softmax shift invariance
    omega_raw = npy.sample(
        "omega_raw", npy_dist.Normal(0, 1.0 / zeta).expand([N_co - 1])
    )
    omega_tilde = jnp.concatenate([jnp.zeros(1), omega_raw])
    omega = npy.deterministic("omega", jax.nn.softmax(omega_tilde))
    omega0 = npy.sample("omega0", npy_dist.Normal(0, 5.0))
    sigma_omega = npy.sample("sigma_omega", npy_dist.HalfNormal(1.0))

    # Matching: weighted control average ≈ treated average at each pre-period
    npy.sample(
        "omega_match",
        npy_dist.Normal(omega0 + Y_co_pre.T @ omega, sigma_omega),
        obs=y_tr_pre_mean,
    )

    # ── Module 2: Time weights ──────────────────────────────────────────
    # Pin first logit to zero to remove softmax shift invariance
    lambda_raw = npy.sample(
        "lambda_raw", npy_dist.Normal(0, 1.0 / zeta).expand([T_pre - 1])
    )
    lambda_tilde = jnp.concatenate([jnp.zeros(1), lambda_raw])
    lam = npy.deterministic("lam", jax.nn.softmax(lambda_tilde))
    lambda0 = npy.sample("lambda0", npy_dist.Normal(0, 5.0))
    sigma_lambda = npy.sample("sigma_lambda", npy_dist.HalfNormal(1.0))

    # Matching: weighted pre-period average ≈ post-treatment average for each control
    npy.sample(
        "lambda_match",
        npy_dist.Normal(lambda0 + Y_co_pre @ lam, sigma_lambda),
        obs=y_co_post_mean,
    )

The two modules each define their own likelihood (omega_match and lambda_match) over their respective data slices. The weights are sampled without any treatment-effect likelihood — the modularity is enforced by construction rather than by gradient manipulation.

Figure 3: The Bayesian SDiD model. Modules 1 and 2 are sampled jointly via MCMC. The treatment effect τ is computed analytically from weight posterior draws, enforcing the cut by construction.

Two implementation details are worth noting. The panel is standardised before fitting so that all priors have a natural unit scale; the treatment effect is back-transformed to the original scale (packs per capita) after computation. The regularisation parameter zeta is set to 1.0, which gives a prior standard deviation of 1 on the softmax logits. This is weak enough that the matching likelihood can push the weights away from uniform toward the sparse solutions that distinguish SDiD from simple DiD — stronger regularisation (e.g. $\zeta = 3$) over-shrinks the weights toward $1/N_{co}$ and collapses the estimator to DiD.
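The effect of ζ on the implied weights can be checked with a small prior-predictive sketch, pushing logit draws through the pinned softmax (`prior_max_weight` is a hypothetical helper, not part of the model above):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

def prior_max_weight(key, zeta, N_co=38, n_draws=2000):
    """Average largest unit weight under the pinned-softmax prior N(0, 1/zeta)."""
    raw = jr.normal(key, (n_draws, N_co - 1)) / zeta  # logit sd = 1/zeta
    logits = jnp.concatenate([jnp.zeros((n_draws, 1)), raw], axis=1)
    w = jax.nn.softmax(logits, axis=-1)
    return float(w.max(axis=-1).mean())

key = jr.PRNGKey(0)
for zeta in (0.5, 1.0, 3.0):
    # Larger zeta shrinks logits toward zero, pushing weights toward 1/N_co
    print(zeta, prior_max_weight(key, zeta))
```

Larger ζ drives the average maximum weight toward the uniform value $1/N_{co} \approx 0.026$, confirming the DiD-like limit.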

Posterior sampling

We sample the weight parameters via NUTS. The model has $(N_{co} - 1) + (T_{pre} - 1) = 54$ weight logit parameters plus 4 matching parameters ($\omega_0$, $\lambda_0$, $\sigma_\omega$, $\sigma_\lambda$), for a total of 58 sampled parameters. After sampling, we compute τ for each posterior draw via the SDiD double-difference formula.

kernel = NUTS(bayesian_sdid, target_accept_prob=0.95)
mcmc = MCMC(
    kernel,
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
    chain_method="parallel",
    progress_bar=False,
)
mcmc.run(
    key,
    Y=Y_scaled,
    N_co=N_co,
    T_pre=T_pre,
)
samples = mcmc.get_samples(group_by_chain=True)

# ── Module 3: Compute tau via SDiD double-difference ────────────────
omega_all = np.array(samples["omega"])      # (chains, draws, N_co)
lam_all = np.array(samples["lam"])          # (chains, draws, T_pre)
Y_co_np = np.array(Y_scaled[:N_co])         # (N_co, T)
y_tr_np = np.array(Y_scaled[N_co:].mean(axis=0))  # (T,)

# Gap series for each posterior draw: treated minus synthetic control
gaps = y_tr_np - omega_all @ Y_co_np        # (chains, draws, T)

# SDiD formula: tau = mean(gap_post) - lambda' gap_pre
tau_scaled = (
    gaps[..., T_pre:].mean(axis=-1)
    - (gaps[..., :T_pre] * lam_all).sum(axis=-1)
)
tau_packs = tau_scaled * Y_std

idata = az.from_dict(posterior={
    "tau": tau_packs,
    "omega": omega_all,
    "lam": lam_all,
    "sigma_omega": np.array(samples["sigma_omega"]),
    "sigma_lambda": np.array(samples["sigma_lambda"]),
})

print(az.summary(idata, var_names=["tau", "sigma_omega", "sigma_lambda"], hdi_prob=0.94))
                mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  \
tau          -14.213  1.988 -17.589  -10.422      0.049    0.019    1655.0   
sigma_omega    0.079  0.027   0.037    0.127      0.001    0.000    2098.0   
sigma_lambda   0.322  0.046   0.244    0.408      0.001    0.001    6868.0   

              ess_tail  r_hat  
tau             3384.0    1.0  
sigma_omega     3408.0    1.0  
sigma_lambda    5248.0    1.0  

Results

Treatment effect

The posterior of τ captures the average treatment effect of Proposition 99 on per-capita cigarette sales. A negative value indicates that the policy reduced cigarette consumption.

fig, ax = plt.subplots(figsize=(6, 3))
az.plot_posterior(idata, var_names=["tau"], hdi_prob=0.94, ax=ax)
ax.set_title("Posterior of $\\tau$: Effect of Proposition 99")
fig.tight_layout()

Weight posteriors

The weight posteriors reveal which control states and which pre-treatment years the model relies on most heavily. Sparse unit weights indicate that a few states dominate the synthetic California (SCM-like behaviour); near-uniform unit weights would indicate that all controls contribute equally (DiD-like behaviour). Time weights that concentrate on later pre-treatment years suggest that the period just before Proposition 99 is most informative for the post-treatment counterfactual.

pre_years = [int(y) for y in years if y < treatment_year]
plot_weight_posteriors(mcmc.get_samples(), control_states, pre_years)

Counterfactual trajectories

We construct the synthetic California trajectory as $\hat{Y}^{sc}_t = \sum_j \omega_j Y_{jt}$ for each posterior draw, yielding a full posterior distribution over the counterfactual. The gap between California’s observed cigarette sales and the synthetic control in the post-treatment period is the implied treatment effect.

plot_counterfactual(
    mcmc.get_samples(),
    np.array(Y),
    N_co,
    T_pre,
    [int(y) for y in years],
    treatment_year,
)

Comparison to simple DiD

To illustrate the value of the SDiD reweighting, we compare against the naive DiD estimator, which implicitly uses uniform unit and time weights:

$$\hat{\tau}^{did} = \big(\bar{Y}_{tr,post} - \bar{Y}_{tr,pre}\big) - \big(\bar{Y}_{co,post} - \bar{Y}_{co,pre}\big).$$
Y_np = np.array(Y)

y_tr_post = Y_np[N_co:, T_pre:].mean()
y_tr_pre = Y_np[N_co:, :T_pre].mean()
y_co_post = Y_np[:N_co, T_pre:].mean()
y_co_pre = Y_np[:N_co, :T_pre].mean()
tau_did = (y_tr_post - y_tr_pre) - (y_co_post - y_co_pre)

tau_draws = tau_packs.flatten()

fig, ax = plt.subplots(figsize=(6, 3))
ax.hist(tau_draws, bins=50, density=True, alpha=0.5, color=cols[1],
        edgecolor="white", label="Bayesian SDiD")
ax.axvline(tau_did, color=cols[2], linewidth=2, linestyle="--",
           label=f"DiD estimate = {tau_did:.1f}")
ax.axvline(tau_draws.mean(), color=cols[1], linewidth=2, linestyle="--",
           label=f"SDiD posterior mean = {tau_draws.mean():.1f}")
ax.legend(fontsize=8)
ax.set(xlabel="Treatment effect (packs per capita)", ylabel="Density",
       title="Bayesian SDiD vs. Naive DiD")
fig.tight_layout()

Discussion

The Bayesian SDiD via cut posteriors preserves the three-stage structure of the Arkhangelsky et al. estimator while providing full uncertainty quantification. The weight posteriors offer an interpretive advantage over the frequentist point estimates: they reveal not just which states and years receive weight, but how confident the model is in those allocations.

The role of ζ: regularisation as an estimator dial

The regularisation parameter ζ controls where the Bayesian SDiD sits on the spectrum between Difference-in-Differences and Synthetic Control. This is not merely a tuning knob — it determines the fundamental character of the estimator.

Large ζ (strong regularisation). The prior $\mathcal{N}(0, 1/\zeta)$ on the softmax logits concentrates them near zero, producing near-uniform weights $\omega_j \approx 1/N_{co}$. The synthetic control becomes the simple average of all control units, and the double-difference formula reduces to ordinary DiD. In the limit $\zeta \to \infty$, the estimator recovers exact DiD.

Small ζ (weak regularisation). The prior allows large logit deviations, so the matching likelihood can concentrate weight on the few states that closely track the treated unit’s pre-treatment trajectory. The estimator approaches the Synthetic Control Method. In the limit $\zeta \to 0$ with no intercept $\omega_0$, it recovers exact SC.

SDiD is designed to sit between these extremes, adapting to the data. The Bayesian analogue inherits this spectrum through ζ.

Diagnosing over-shrinkage. Two symptoms indicate that ζ is too large:

  1. The posterior mean of τ is close to the naive DiD estimate.
  2. The unit-weight posteriors are near-uniform, with the top-weighted states receiving less than 3–5× the uniform weight $1/N_{co}$.

If both symptoms are present, the weights have insufficient freedom to differentiate good-match states from poor ones, and the estimator is effectively computing DiD with Bayesian uncertainty.

Remedies for over-shrinkage, in increasing order of sophistication:

  1. Lower ζ. Start with ζ=1 and decrease toward 0.5 if the weight posteriors remain near-uniform. With the two-stage estimation approach, mixing is robust to weaker regularisation because the sampler only explores weight space (no treatment-effect parameters to couple with).

  2. Data-driven ζ. The synthdid R package computes ζ from the first-difference variance of the control panel, scaled by $(N_{tr} T_{pre})^{1/4}$. This can serve as a rough calibration target when translated to the logit scale.

  3. Sparsity-inducing priors. Replace the Gaussian prior on logits with a Laplace (double-exponential) prior to encourage sparse weight vectors, or use a Dirichlet prior with concentration α<1 directly on the simplex. These better match the frequentist SDiD’s ability to produce exact-zero weights for poor-match states, at the cost of more complex posterior geometry.

A useful sanity check after any of these adjustments is to plot the synthetic control trajectory $\hat{\omega}^\top Y_{co}$ against the treated unit: if the pre-treatment fit is poor, the weights need more freedom (lower ζ) or a different prior structure entirely.

Limitations

The cut posterior is not a standard Bayesian posterior — it does not correspond to conditioning on a single generative model. Theoretical guarantees (such as Bernstein–von Mises) require separate arguments for modular posteriors, and calibration of credible intervals cannot be taken for granted.

The matching variances σω2 and σλ2 are free parameters with no direct counterpart in the frequentist SDiD. They control how tightly the weights must satisfy the balancing condition: too small, and the weight posterior is near-degenerate (approaching the frequentist point estimate); too large, and the weights are diffuse and uninformative. In this implementation, we treat them as parameters with half-normal priors and let the data inform their scale, but sensitivity to this choice should be investigated in applied settings.

The softmax parameterisation ensures all weights are strictly positive, while the frequentist solution may have exact zeros (sparse weights). For applications where sparsity is desirable, a Dirichlet prior with concentration <1 or a spike-and-slab prior on the softmax logits would be more appropriate.

Alternative approaches

An alternative Bayesian route to the same problem is to specify a latent factor model with interactive fixed effects (Xu, 2017). Rather than learning explicit weights, the factor model captures unit heterogeneity through latent loadings and extrapolates counterfactuals from the estimated factor structure. This approach is more structurally ambitious and allows model comparison via LOO-CV, but requires committing to a parametric factor structure and choosing the number of latent factors K. The cut-posterior SDiD stays closer to the frequentist estimator and produces directly interpretable weight posteriors.

References

  • Arkhangelsky, D., Athey, S., Hirshberg, D. A., Imbens, G. W., & Wager, S. (2021). Synthetic Difference-in-Differences. American Economic Review, 111(12), 4088–4118.
  • Lunn, D., Best, N., Spiegelhalter, D., Graham, G., & Neuenschwander, B. (2009). Combining MCMC with ‘sequential’ PKPD modelling. Journal of Pharmacokinetics and Pharmacodynamics, 36, 19–38.
  • Plummer, M. (2015). Cuts in Bayesian graphical models. Statistics and Computing, 25, 37–43.
  • Xu, Y. (2017). Generalized Synthetic Control Method: Causal Inference with Interactive Fixed Effects Models. Political Analysis, 25(1), 57–76.