A NumPyro implementation of Bayesian Synthetic Difference-in-Differences using
modular inference and cut posteriors.
```python
import jax
import jax.lax
import jax.random as jr
import jax.numpy as jnp
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpyro as npy
from numpyro import distributions as npy_dist
from numpyro.infer import MCMC, NUTS
import numpy as np
import arviz as az
import seaborn as sns
import pandas as pd

jax.config.update("jax_platform_name", "cpu")
npy.set_host_device_count(4)
key = jr.PRNGKey(123)

sns.set_theme(
    context="notebook",
    font="serif",
    style="whitegrid",
    palette="deep",
    rc={"figure.figsize": (6, 3), "axes.spines.top": False, "axes.spines.right": False},
)
cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
```
Introduction
Synthetic Difference-in-Differences (SDiD), introduced by
Arkhangelsky et al. (2021), combines the
strengths of Difference-in-Differences and the Synthetic Control Method to
estimate causal effects in panel data. The frequentist SDiD estimator operates
in three stages: it learns unit weights that balance control units against the
treated unit, time weights that balance pre-treatment against post-treatment
periods, and then estimates a treatment effect from a weighted two-way fixed
effects regression.
A natural question is whether this estimator can be made Bayesian. The
difficulty is that the three-stage structure is not just a computational
convenience — it encodes a deliberate separation of concerns. The unit weights
should be informed only by pre-treatment data, the time weights only by
control-unit data, and neither should be distorted by the outcome likelihood
that identifies the treatment effect. Naively placing all parameters into a
single Bayesian model would violate this separation: the outcome likelihood
would pull the weights away from their balancing objective.
The solution comes from modular Bayesian inference and the cut posterior.
We partition the model into three modules, each with its own likelihood, and
cut the gradient feedback from the treatment-effect module to the weight
parameters. In JAX, this is a single line: jax.lax.stop_gradient. The result
is a proper Bayesian SDiD that preserves the estimator’s three-stage logic while
propagating weight uncertainty into the treatment effect posterior.
We apply this to the canonical Proposition 99 dataset — California’s 1988
cigarette tax increase — and estimate the causal effect on per-capita cigarette
sales.
The SDiD estimator
We briefly recall the structure of the frequentist SDiD. Given a balanced panel
with $N_{co}$ control units and $N_{tr}$ treated units observed over $T_{pre}$
pre-treatment and $T_{post}$ post-treatment periods, the estimator proceeds in
three stages.

Stage 1: Unit weights. Find simplex weights $\omega \in \Delta^{N_{co}}$ (and an
intercept $\omega_0$) that balance control units against treated units in the
pre-treatment period:

$$(\hat\omega_0, \hat\omega) = \arg\min_{\omega_0,\; \omega \in \Delta^{N_{co}}} \sum_{t=1}^{T_{pre}} \Big( \omega_0 + \sum_{i=1}^{N_{co}} \omega_i Y_{it} - \bar{Y}_{\mathrm{tr},t} \Big)^2 + \zeta^2 T_{pre} \lVert \omega \rVert_2^2 .$$

Stage 2: Time weights. Find simplex weights $\lambda \in \Delta^{T_{pre}}$ (and
an intercept $\lambda_0$) that balance pre-treatment periods against
post-treatment periods for control units:

$$(\hat\lambda_0, \hat\lambda) = \arg\min_{\lambda_0,\; \lambda \in \Delta^{T_{pre}}} \sum_{i=1}^{N_{co}} \Big( \lambda_0 + \sum_{t=1}^{T_{pre}} \lambda_t Y_{it} - \bar{Y}_{i,\mathrm{post}} \Big)^2 .$$

Stage 3: Weighted regression. Estimate the treatment effect from a
weighted two-way fixed effects regression with observation weights $\hat{w}_{it}$:

$$(\hat\tau, \hat\mu, \hat\alpha, \hat\beta) = \arg\min_{\tau, \mu, \alpha, \beta} \sum_{i=1}^{N} \sum_{t=1}^{T} \big( Y_{it} - \mu - \alpha_i - \beta_t - D_{it}\tau \big)^2 \, \hat{w}_{it},$$

where the weight matrix has the product structure:

| | Pre-treatment | Post-treatment |
|---|---|---|
| Control | $\hat\omega_i \hat\lambda_t$ | $\hat\omega_i / T_{post}$ |
| Treated | $\hat\lambda_t / N_{tr}$ | $1 / (N_{tr} T_{post})$ |
The critical feature is the separation of concerns: the unit weights are
informed only by pre-treatment data, the time weights only by control-unit data,
and the treatment effect estimate uses the full panel with the weights fixed.
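As a concrete sketch (with a hypothetical helper name, assuming the standard product structure of the SDiD observation weights), the Stage 3 weight matrix can be assembled as an outer product of per-unit and per-period weights:

```python
import numpy as np

def sdid_weight_matrix(omega, lam, N_tr, T_post):
    """Sketch: Stage 3 observation weights as the outer product of per-unit
    weights (omega for controls, 1/N_tr for treated units) and per-period
    weights (lam for pre-treatment, 1/T_post for post-treatment periods)."""
    unit_w = np.concatenate([omega, np.full(N_tr, 1.0 / N_tr)])
    time_w = np.concatenate([lam, np.full(T_post, 1.0 / T_post)])
    return np.outer(unit_w, time_w)  # (N_co + N_tr, T_pre + T_post)

# Toy panel: 3 controls + 1 treated unit, 2 pre + 2 post periods
W = sdid_weight_matrix(np.full(3, 1 / 3), np.full(2, 1 / 2), N_tr=1, T_post=2)
print(W.shape)  # (4, 4)
```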

Figure 1: Data partitioning in SDiD. Each stage uses a specific slice of the panel, preserving the separation of concerns.
Why naive joint estimation fails
A natural first attempt would be to place priors on $\omega$, $\lambda$, and $\tau$
and write a single likelihood over all the data. The problem is that the
outcome likelihood in Stage 3 would then update the weight parameters, pulling
them away from their balancing objective. The unit weights would no longer
reflect pre-treatment balance; they would be distorted by the post-treatment
outcome data, which is exactly what the weights are supposed to be insulated
from.
The cut posterior
The solution comes from modular Bayesian inference
(Lunn et al., 2009;
Plummer, 2015).
The idea is to partition the model into modules, each with its own likelihood,
and cut the feedback from downstream modules to upstream parameters. The
posterior factorises as:

$$p_{\mathrm{cut}}(\omega, \lambda, \tau \mid Y) = \underbrace{p(\omega \mid Y_{\mathrm{pre}})}_{\text{Module 1}} \; \underbrace{p(\lambda \mid Y_{\mathrm{co}})}_{\text{Module 2}} \; \underbrace{p(\tau \mid \omega, \lambda, Y)}_{\text{Module 3}} .$$

Each factor is a separate module:
- Module 1 — $p(\omega \mid Y_{\mathrm{pre}})$: the unit-weight posterior,
identified by pre-treatment matching.
- Module 2 — $p(\lambda \mid Y_{\mathrm{co}})$: the time-weight posterior,
identified by control-unit matching.
- Module 3 — $p(\tau \mid \omega, \lambda, Y)$: the treatment effect,
determined by the SDiD double-difference formula given the weights.

The “cut” prevents the treatment-effect computation from feeding information
back into the weight parameters. Weight uncertainty still propagates into $\tau$,
because different weight draws yield different double-difference estimates.

Figure 2: Information flow in the cut posterior. Solid arrows show forward propagation of weight draws into the treatment-effect computation. Dashed arrows with ✗ show the cut — Module 3 cannot update the weight posteriors.
Implementation via two-stage estimation
In principle, the cut could be implemented within a single model by applying
jax.lax.stop_gradient to the weights where they enter Module 3’s likelihood.
In practice, this creates a fundamental incompatibility with gradient-based
samplers like NUTS: the posterior has strong correlation between $\tau$ and
$(\omega, \lambda)$ (different weights yield different gaps, which yield different
treatment effects), but stop_gradient zeros out the gradient that NUTS needs to detect
this correlation. The mass matrix cannot adapt to the true posterior geometry,
and leapfrog trajectories follow the wrong Hamiltonian. No amount of
tuning — priors, step size, warmup — can fix this structural mismatch.
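The failure mode can be seen in miniature: through a cut, a downstream function’s gradient with respect to the weights is identically zero, so a gradient-based sampler sees no coupling at all. A minimal sketch:

```python
import jax
import jax.numpy as jnp

def downstream(w):
    # A stand-in for Module 3: any scalar function of the cut weights
    return (jax.lax.stop_gradient(w) * jnp.array([1.0, 2.0, 3.0])).sum()

# The gradient through the cut is identically zero
g = jax.grad(downstream)(jnp.array([0.2, 0.3, 0.5]))
print(g)  # [0. 0. 0.]
```

The function value still depends on `w`, but NUTS only sees the (zero) gradient, so its adapted mass matrix and trajectories ignore the dependence entirely.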
The robust solution is to enforce modularity by separation: sample the
weights from their matching likelihoods (Modules 1 and 2) via standard MCMC,
then compute the treatment effect analytically for each posterior draw using
the SDiD double-difference formula. This matches the frequentist SDiD’s
three-stage structure and is the approach taken by CausalPy, CausalImpact,
and related Bayesian synthetic control implementations.
Prior specification
Unit weights. We parameterise the simplex via softmax:

$$\omega = \operatorname{softmax}(\tilde\omega),$$

with $\tilde\omega_1 = 0$ (reference
level) and $\tilde\omega_j \sim \mathcal{N}(0, 1/\zeta)$ for
$j = 2, \dots, N_{co}$. Pinning one logit removes the shift invariance
of softmax (since $\operatorname{softmax}(\tilde\omega + c) = \operatorname{softmax}(\tilde\omega)$ for any constant $c$), which would
otherwise create a flat direction in the posterior. The prior scale $1/\zeta$
plays the role of the $\zeta$ regularisation in the frequentist formulation:
larger $\zeta$ pulls more strongly toward uniform weights (DiD-like), while
smaller $\zeta$ allows the data to concentrate weight on a few good-match
states (SC-like). The matching likelihood is:

$$\bar{Y}_{\mathrm{tr},t} \sim \mathcal{N}\Big( \omega_0 + \sum_{i=1}^{N_{co}} \omega_i Y_{it},\; \sigma_\omega \Big), \qquad t = 1, \dots, T_{pre}.$$

Time weights. Symmetric construction:

$$\lambda = \operatorname{softmax}(\tilde\lambda),$$

with $\tilde\lambda_1 = 0$ and $\tilde\lambda_t \sim \mathcal{N}(0, 1/\zeta)$ for
$t = 2, \dots, T_{pre}$, and

$$\bar{Y}_{i,\mathrm{post}} \sim \mathcal{N}\Big( \lambda_0 + \sum_{t=1}^{T_{pre}} \lambda_t Y_{it},\; \sigma_\lambda \Big), \qquad i = 1, \dots, N_{co}.$$
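The pinned-logit construction can be checked directly (a minimal sketch, separate from the model code):

```python
import jax.nn
import jax.numpy as jnp

v = jnp.array([0.3, -1.2, 0.9])
# softmax is shift-invariant: adding any constant to the logits is a no-op
shifted = jax.nn.softmax(v + 5.0)
print(jnp.allclose(jax.nn.softmax(v), shifted, atol=1e-6))  # True

# Pinning the first logit to 0 selects one representative per equivalence
# class, removing the flat direction while still covering the simplex
raw = jnp.array([-1.5, 0.9])  # N_co - 1 free logits
w = jax.nn.softmax(jnp.concatenate([jnp.zeros(1), raw]))
print(w.sum())  # weights lie on the simplex
```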
Treatment effect. The treatment effect is not a sampled parameter —
it is computed analytically for each posterior draw of $(\omega, \lambda)$ via
the SDiD double-difference formula:

$$\tau = \frac{1}{T_{post}} \sum_{t = T_{pre}+1}^{T} \delta_t \;-\; \sum_{t=1}^{T_{pre}} \lambda_t \, \delta_t, \qquad \delta_t = \bar{Y}_{\mathrm{tr},t} - \sum_{i=1}^{N_{co}} \omega_i Y_{it}.$$

This is algebraically equivalent to the weighted two-way fixed-effects
regression (Arkhangelsky et al., 2021, Proposition 4.1) but avoids the
fixed-effect parameters entirely. The posterior distribution of $\tau$
is the pushforward of the weight posteriors through this formula:
all uncertainty in $\tau$ comes from uncertainty in the weights. We
standardise the outcome panel to zero mean and unit variance before fitting,
so all weight-module priors are on the standardised scale, and
back-transform afterwards.
Data
We use the Proposition 99 dataset, the canonical benchmark for synthetic control
methods. In November 1988, California passed Proposition 99, which imposed a
25-cent-per-pack excise tax on cigarettes along with restrictions on smoking in
public spaces. The dataset contains annual per-capita cigarette sales (in packs)
for 39 US states from 1970 to 2000.
```python
data = pd.read_csv("datasets/california_smoking.csv")

treatment_year = 1988
treated_state = "California"

controls = data.loc[data.state != treated_state].pivot(
    index="year", columns="state", values="cigsale"
)
treated = data.loc[data.state == treated_state, ["year", "cigsale"]].set_index("year")

control_states = list(controls.columns)
years = sorted(data["year"].unique().astype(int))
N_co = len(control_states)
T_pre = len([y for y in years if y < treatment_year])
T_post = len([y for y in years if y >= treatment_year])
N = N_co + 1
T = T_pre + T_post

print(f"Panel: {N} states x {T} years ({years[0]}-{years[-1]})")
print(f"  Control: {N_co} states, Treated: {treated_state}")
print(f"  Pre-treatment: {T_pre} years (1970-1987)")
print(f"  Post-treatment: {T_post} years (1988-2000)")
```

```
Panel: 39 states x 31 years (1970-2000)
  Control: 38 states, Treated: California
  Pre-treatment: 18 years (1970-1987)
  Post-treatment: 13 years (1988-2000)
```
```python
# Construct (N, T) panel matrix: control states first, California last
Y_co = controls.values.T  # (N_co, T)
Y_tr = treated.values.T   # (1, T)
Y = jnp.array(np.vstack([Y_co, Y_tr]))

D = jnp.zeros((N, T))
D = D.at[-1, T_pre:].set(1.0)  # California, post-1988

# Standardise outcomes for better conditioning
Y_mean = float(Y.mean())
Y_std = float(Y.std())
Y_scaled = (Y - Y_mean) / Y_std
print(f"  Standardised: mean={Y_mean:.1f}, std={Y_std:.1f}")
```

```
  Standardised: mean=118.9, std=32.8
```
Show panel plot
```python
def clean_legend(ax):
    handles, labels = ax.get_legend_handles_labels()
    by_label = dict(zip(labels, handles, strict=False))
    ax.legend(by_label.values(), by_label.keys(), loc="best")
    return ax

fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(years, controls.values, color="grey", alpha=0.3, linewidth=0.8, label="Control states")
ax.plot(years, treated.values, color=cols[0], linewidth=2, label="California")
ax.axvline(treatment_year, color=cols[2], linestyle="--", label="Proposition 99")
clean_legend(ax)
ax.set(xlabel="Year", ylabel="Per-capita cigarette sales (packs)")
fig.tight_layout()
```

Show plotting helpers
```python
def plot_weight_posteriors(
    samples: dict,
    state_names: list[str],
    pre_years: list[int],
) -> plt.Figure:
    """Bar charts of unit and time weight posteriors with credible intervals."""
    N_co = len(state_names)
    T_pre = len(pre_years)
    omega = np.array(samples["omega"]).reshape(-1, N_co)
    lam = np.array(samples["lam"]).reshape(-1, T_pre)

    # Sort states by posterior median weight (descending)
    med_omega = np.median(omega, axis=0)
    sort_idx = np.argsort(med_omega)[::-1]

    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # Unit weights (sorted)
    ax = axes[0]
    sorted_med = med_omega[sort_idx]
    sorted_lo, sorted_hi = np.percentile(omega[:, sort_idx], [3, 97], axis=0)
    sorted_names = [state_names[i] for i in sort_idx]
    x = np.arange(N_co)
    ax.bar(x, sorted_med, color=cols[1], alpha=0.6, edgecolor="white")
    ax.vlines(x, sorted_lo, sorted_hi, color=cols[1], linewidth=1.2)
    ax.axhline(1.0 / N_co, color="grey", linestyle="--", linewidth=1,
               label=f"Uniform (1/{N_co})")
    ax.set_xticks(x[::3])
    ax.set_xticklabels([sorted_names[i] for i in range(0, N_co, 3)],
                       rotation=70, fontsize=7, ha="right")
    ax.set(ylabel="Weight", title="Unit Weights ($\\omega$)")
    ax.legend(fontsize=8)

    # Time weights
    ax = axes[1]
    med_lam = np.median(lam, axis=0)
    lo_lam, hi_lam = np.percentile(lam, [3, 97], axis=0)
    ax.bar(pre_years, med_lam, color=cols[3], alpha=0.6, edgecolor="white")
    ax.vlines(pre_years, lo_lam, hi_lam, color=cols[3], linewidth=1.5)
    ax.axhline(1.0 / T_pre, color="grey", linestyle="--", linewidth=1,
               label=f"Uniform (1/{T_pre})")
    ax.set(xlabel="Year", ylabel="Weight", title="Time Weights ($\\lambda$)")
    ax.legend(fontsize=8)

    fig.tight_layout()
    return fig


def plot_counterfactual(
    samples: dict,
    Y: np.ndarray,
    N_co: int,
    T_pre: int,
    years: list[int],
    treatment_year: int,
) -> plt.Figure:
    """Observed vs synthetic control trajectory and period-by-period effects."""
    omega = np.array(samples["omega"]).reshape(-1, N_co)
    N, T = Y.shape

    # Treated trajectory
    Y_co = Y[:N_co]
    Y_tr = Y[N_co:].mean(axis=0)

    # Synthetic control for each posterior draw
    sc_draws = omega @ Y_co  # (n_draws, T)
    sc_median = np.median(sc_draws, axis=0)
    sc_lo, sc_hi = np.percentile(sc_draws, [3, 97], axis=0)

    fig, axes = plt.subplots(1, 2, figsize=(12, 4), gridspec_kw={"width_ratios": [2, 1]})

    # Left: trajectories
    ax = axes[0]
    ax.plot(years, Y_tr, "o-", color="black", markersize=3, linewidth=1.5,
            label="California (observed)")
    ax.plot(years, sc_median, "s--", color=cols[1], markersize=3, linewidth=1.5,
            label="Synthetic California")
    ax.fill_between(years, sc_lo, sc_hi, alpha=0.2, color=cols[1], label="94% CI")
    ax.axvline(treatment_year - 0.5, color="grey", linestyle=":", linewidth=1,
               label="Proposition 99")
    ax.set(xlabel="Year", ylabel="Per-capita cigarette sales (packs)",
           title="California vs. Synthetic Control")
    ax.legend(fontsize=8)

    # Right: period-by-period effect
    ax2 = axes[1]
    post_years = [y for y in years if y >= treatment_year]
    te_draws = Y_tr[T_pre:][np.newaxis, :] - sc_draws[:, T_pre:]
    te_median = np.median(te_draws, axis=0)
    te_lo, te_hi = np.percentile(te_draws, [3, 97], axis=0)
    ax2.bar(post_years, te_median, color=cols[1], alpha=0.6, edgecolor="white")
    ax2.vlines(post_years, te_lo, te_hi, color=cols[1], linewidth=2)
    ax2.axhline(0, color="grey", linestyle="--", linewidth=0.5)
    ax2.set(xlabel="Year", ylabel="$\\hat{\\tau}_t$ (packs)",
            title="Period-by-Period Effect")

    fig.tight_layout()
    return fig
```
Model
The NumPyro model encodes Modules 1 and 2. The treatment effect (Module 3)
is computed analytically from the weight posterior draws — no
jax.lax.stop_gradient is needed because the modularity is enforced by
separating the weight estimation from the treatment-effect computation.
```python
def bayesian_sdid(
    Y: jnp.ndarray,
    N_co: int,
    T_pre: int,
    zeta: float = 1.0,
):
    """Bayesian SDiD weight model (Modules 1 & 2).

    The treatment effect is computed analytically from the weight
    posterior draws via the SDiD double-difference formula.

    Parameters
    ----------
    Y : array (N, T)
        Standardised panel outcome matrix. First N_co rows are control units.
    N_co : int
        Number of control units.
    T_pre : int
        Number of pre-treatment periods.
    zeta : float
        Regularisation strength for the weight priors.
    """
    N, T = Y.shape
    Y_co_pre = Y[:N_co, :T_pre]                     # (N_co, T_pre)
    y_tr_pre_mean = Y[N_co:, :T_pre].mean(axis=0)   # (T_pre,)
    y_co_post_mean = Y[:N_co, T_pre:].mean(axis=1)  # (N_co,)

    # ── Module 1: Unit weights ──────────────────────────────────────────
    # Pin first logit to zero to remove softmax shift invariance
    omega_raw = npy.sample(
        "omega_raw", npy_dist.Normal(0, 1.0 / zeta).expand([N_co - 1])
    )
    omega_tilde = jnp.concatenate([jnp.zeros(1), omega_raw])
    omega = npy.deterministic("omega", jax.nn.softmax(omega_tilde))
    omega0 = npy.sample("omega0", npy_dist.Normal(0, 5.0))
    sigma_omega = npy.sample("sigma_omega", npy_dist.HalfNormal(1.0))
    # Matching: weighted control average ≈ treated average at each pre-period
    npy.sample(
        "omega_match",
        npy_dist.Normal(omega0 + Y_co_pre.T @ omega, sigma_omega),
        obs=y_tr_pre_mean,
    )

    # ── Module 2: Time weights ──────────────────────────────────────────
    # Pin first logit to zero to remove softmax shift invariance
    lambda_raw = npy.sample(
        "lambda_raw", npy_dist.Normal(0, 1.0 / zeta).expand([T_pre - 1])
    )
    lambda_tilde = jnp.concatenate([jnp.zeros(1), lambda_raw])
    lam = npy.deterministic("lam", jax.nn.softmax(lambda_tilde))
    lambda0 = npy.sample("lambda0", npy_dist.Normal(0, 5.0))
    sigma_lambda = npy.sample("sigma_lambda", npy_dist.HalfNormal(1.0))
    # Matching: weighted pre-period average ≈ post-treatment average for each control
    npy.sample(
        "lambda_match",
        npy_dist.Normal(lambda0 + Y_co_pre @ lam, sigma_lambda),
        obs=y_co_post_mean,
    )
```
The two modules each define their own likelihood (omega_match and
lambda_match) over their respective data slices. The weights are sampled
without any treatment-effect likelihood — the modularity is enforced by
construction rather than by gradient manipulation.

Figure 3: The Bayesian SDiD model. Modules 1 and 2 are sampled jointly via MCMC. The treatment effect τ is computed analytically from weight posterior draws, enforcing the cut by construction.
Two implementation details are worth noting. The panel is standardised
before fitting so that all priors have a natural unit scale; the treatment
effect is back-transformed to the original scale (packs per capita) after
computation. The regularisation parameter zeta is set to 1.0, which
gives a prior standard deviation of 1 on the softmax logits. This is
weak enough that the matching likelihood can push the weights away from
uniform toward the sparse solutions that distinguish SDiD from simple
DiD — stronger regularisation (larger $\zeta$) over-shrinks the
weights toward the uniform $1/N_{co}$ and collapses the estimator to DiD.
Posterior sampling
We sample the weight parameters via NUTS. The model has
$(N_{co} - 1) + (T_{pre} - 1) = 37 + 17 = 54$ weight logit parameters
plus 4 matching parameters ($\omega_0$, $\lambda_0$, $\sigma_\omega$,
$\sigma_\lambda$), for a total of 58 sampled parameters. After sampling,
we compute $\tau$ for each posterior draw via the SDiD double-difference
formula.
```python
kernel = NUTS(bayesian_sdid, target_accept_prob=0.95)
mcmc = MCMC(
    kernel,
    num_warmup=2000,
    num_samples=2000,
    num_chains=4,
    chain_method="parallel",
    progress_bar=False,
)
mcmc.run(
    key,
    Y=Y_scaled,
    N_co=N_co,
    T_pre=T_pre,
)
```
```python
samples = mcmc.get_samples(group_by_chain=True)

# ── Module 3: Compute tau via SDiD double-difference ────────────────
omega_all = np.array(samples["omega"])            # (chains, draws, N_co)
lam_all = np.array(samples["lam"])                # (chains, draws, T_pre)
Y_co_np = np.array(Y_scaled[:N_co])               # (N_co, T)
y_tr_np = np.array(Y_scaled[N_co:].mean(axis=0))  # (T,)

# Gap series for each posterior draw: treated minus synthetic control
gaps = y_tr_np - omega_all @ Y_co_np  # (chains, draws, T)

# SDiD formula: tau = mean(gap_post) - lambda' gap_pre
tau_scaled = (
    gaps[..., T_pre:].mean(axis=-1)
    - (gaps[..., :T_pre] * lam_all).sum(axis=-1)
)
tau_packs = tau_scaled * Y_std

idata = az.from_dict(posterior={
    "tau": tau_packs,
    "omega": omega_all,
    "lam": lam_all,
    "sigma_omega": np.array(samples["sigma_omega"]),
    "sigma_lambda": np.array(samples["sigma_lambda"]),
})
print(az.summary(idata, var_names=["tau", "sigma_omega", "sigma_lambda"], hdi_prob=0.94))
```

```
                mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
tau          -14.213  1.988 -17.589  -10.422      0.049    0.019    1655.0    3384.0    1.0
sigma_omega    0.079  0.027   0.037    0.127      0.001    0.000    2098.0    3408.0    1.0
sigma_lambda   0.322  0.046   0.244    0.408      0.001    0.001    6868.0    5248.0    1.0
```
Results
Treatment effect
The posterior of $\tau$ captures the average treatment effect of Proposition 99
on per-capita cigarette sales. A negative value indicates that the policy
reduced cigarette consumption.
```python
fig, ax = plt.subplots(figsize=(6, 3))
az.plot_posterior(idata, var_names=["tau"], hdi_prob=0.94, ax=ax)
ax.set_title("Posterior of $\\tau$: Effect of Proposition 99")
fig.tight_layout()
```

Weight posteriors
The weight posteriors reveal which control states and which pre-treatment years
the model relies on most heavily. Sparse unit weights indicate that a few
states dominate the synthetic California (SCM-like behaviour); near-uniform
unit weights would indicate that all controls contribute equally (DiD-like
behaviour). Time weights that concentrate on later pre-treatment years suggest
that the period just before Proposition 99 is most informative for the
post-treatment counterfactual.
```python
pre_years = [int(y) for y in years if y < treatment_year]
plot_weight_posteriors(mcmc.get_samples(), control_states, pre_years)
```

Counterfactual trajectories
We construct the synthetic California trajectory as
$\hat{y}^{\mathrm{sc}}_t = \sum_{i=1}^{N_{co}} \omega_i Y_{it}$ for each posterior draw,
yielding a full posterior distribution over the counterfactual. The gap between
California’s observed cigarette sales and the synthetic control in the
post-treatment period is the implied treatment effect.
```python
plot_counterfactual(
    mcmc.get_samples(),
    np.array(Y),
    N_co,
    T_pre,
    [int(y) for y in years],
    treatment_year,
)
```

Comparison to simple DiD
To illustrate the value of the SDiD reweighting, we compare against the naive
DiD estimator, which implicitly uses uniform unit and time weights:
```python
Y_np = np.array(Y)
y_tr_post = Y_np[N_co:, T_pre:].mean()
y_tr_pre = Y_np[N_co:, :T_pre].mean()
y_co_post = Y_np[:N_co, T_pre:].mean()
y_co_pre = Y_np[:N_co, :T_pre].mean()
tau_did = (y_tr_post - y_tr_pre) - (y_co_post - y_co_pre)

tau_draws = tau_packs.flatten()

fig, ax = plt.subplots(figsize=(6, 3))
ax.hist(tau_draws, bins=50, density=True, alpha=0.5, color=cols[1],
        edgecolor="white", label="Bayesian SDiD")
ax.axvline(tau_did, color=cols[2], linewidth=2, linestyle="--",
           label=f"DiD estimate = {tau_did:.1f}")
ax.axvline(tau_draws.mean(), color=cols[1], linewidth=2, linestyle="--",
           label=f"SDiD posterior mean = {tau_draws.mean():.1f}")
ax.legend(fontsize=8)
ax.set(xlabel="Treatment effect (packs per capita)", ylabel="Density",
       title="Bayesian SDiD vs. Naive DiD")
fig.tight_layout()
```

Discussion
The Bayesian SDiD via cut posteriors preserves the three-stage structure of the
Arkhangelsky et al. estimator while providing full uncertainty quantification.
The weight posteriors offer an interpretive advantage over the frequentist point
estimates: they reveal not just which states and years receive weight, but how
confident the model is in those allocations.
The role of $\zeta$: regularisation as an estimator dial
The regularisation parameter $\zeta$ controls where the Bayesian SDiD sits
on the spectrum between Difference-in-Differences and Synthetic Control. This
is not merely a tuning knob — it determines the fundamental character of
the estimator.
Large $\zeta$ (strong regularisation). The $\mathcal{N}(0, 1/\zeta)$ prior
on the softmax logits concentrates them near zero,
producing near-uniform weights $\omega_i \approx 1/N_{co}$ and
$\lambda_t \approx 1/T_{pre}$. The
synthetic control becomes the simple average of all control units, and the
double-difference formula reduces to ordinary DiD. In the limit
$\zeta \to \infty$, the estimator recovers exact DiD.
Small $\zeta$ (weak regularisation). The prior allows large logit
deviations, so the matching likelihood can concentrate weight on the few
states that closely track the treated unit’s pre-treatment trajectory. The
estimator approaches the Synthetic Control Method. In the limit
$\zeta \to 0$ with no intercept ($\omega_0 = 0$), it recovers exact SC.
SDiD is designed to sit between these extremes, adapting to the data. The
Bayesian analogue inherits this spectrum through $\zeta$.
Diagnosing over-shrinkage. Two symptoms indicate that $\zeta$ is too
large:
- The posterior mean of $\tau$ is close to the naive DiD estimate.
- The unit-weight posteriors are near-uniform, with even the top-weighted
states receiving little more than the uniform weight $1/N_{co}$.
If both symptoms are present, the weights have insufficient freedom to
differentiate good-match states from poor ones, and the estimator is
effectively computing DiD with Bayesian uncertainty.
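These two symptoms can be checked mechanically. The helper below is a hypothetical diagnostic sketch (the function name and the `rel_tol`/`ratio` thresholds are illustrative choices, not canonical values):

```python
import numpy as np

def overshrinkage_flags(tau_draws, tau_did, omega_draws, rel_tol=0.1, ratio=2.0):
    """Flag the two over-shrinkage symptoms: (1) the SDiD posterior mean sits
    within rel_tol (relative) of the DiD estimate, and (2) the largest median
    unit weight is below ratio times the uniform weight 1/N_co."""
    n_co = omega_draws.shape[-1]
    close_to_did = abs(np.mean(tau_draws) - tau_did) <= rel_tol * abs(tau_did)
    top_weight = np.median(omega_draws.reshape(-1, n_co), axis=0).max()
    near_uniform = top_weight < ratio / n_co
    return close_to_did, near_uniform

# Toy inputs: exactly uniform weights and a tau matching DiD trigger both flags
flags = overshrinkage_flags(
    tau_draws=np.full(100, -10.0),
    tau_did=-10.2,
    omega_draws=np.full((100, 38), 1.0 / 38),
)
print(flags)  # (True, True)
```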
Remedies for over-shrinkage, in increasing order of sophistication:
1. Lower $\zeta$. Start from the default $\zeta = 1$ and decrease it if
the weight posteriors remain near-uniform. With the two-stage estimation
approach, mixing is robust to weaker regularisation because the sampler
only explores weight space (no treatment-effect parameters to couple
with).
2. Data-driven $\zeta$. The synthdid R package computes its regularisation from
the first-difference variance of the control panel, scaled by
$(N_{tr} T_{post})^{1/4}$. This can serve as a
rough calibration target when translated to the logit scale.
3. Sparsity-inducing priors. Replace the Gaussian prior on logits with
a Laplace (double-exponential) prior to encourage sparse weight vectors,
or use a Dirichlet prior with concentration $\alpha < 1$ directly on the
simplex. These better match the frequentist SDiD’s ability to produce
exact-zero weights for poor-match states, at the cost of more complex
posterior geometry.
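The data-driven calibration can be sketched as follows (a hypothetical helper translating the described scaling; mapping the result onto the logit-scale prior remains a separate step):

```python
import numpy as np

def synthdid_zeta(Y_co_pre: np.ndarray, N_tr: int, T_post: int) -> float:
    """Sketch of a synthdid-style regularisation target: the standard
    deviation of first differences of the pre-treatment control panel,
    scaled by (N_tr * T_post) ** 0.25."""
    diffs = np.diff(Y_co_pre, axis=1)  # first differences over time
    return float((N_tr * T_post) ** 0.25 * diffs.std())

# Toy panel: first differences are [1, 1] and [2, 2], so std = 0.5,
# and the scale factor is 16 ** 0.25 = 2
zeta_hat = synthdid_zeta(np.array([[0.0, 1.0, 2.0], [0.0, 2.0, 4.0]]), N_tr=1, T_post=16)
print(zeta_hat)  # 1.0
```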
A useful sanity check after any of these adjustments is to plot the
synthetic control trajectory $\sum_i \omega_i Y_{it}$
against the treated unit: if the pre-treatment fit is poor, the weights
need more freedom (lower $\zeta$) or a different prior structure entirely.
Limitations
The cut posterior is not a standard Bayesian posterior — it does not correspond
to conditioning on a single generative model. Theoretical guarantees (such as
Bernstein–von Mises) require separate arguments for modular posteriors, and
calibration of credible intervals cannot be taken for granted.
The matching variances $\sigma_\omega$ and $\sigma_\lambda$ are free
parameters with no direct counterpart in the frequentist SDiD. They control how
tightly the weights must satisfy the balancing condition: too small, and the
weight posterior is near-degenerate (approaching the frequentist point estimate);
too large, and the weights are diffuse and uninformative. In this implementation,
we treat them as parameters with half-normal priors and let the data inform
their scale, but sensitivity to this choice should be investigated in applied
settings.
The softmax parameterisation ensures all weights are strictly positive, while
the frequentist solution may have exact zeros (sparse weights). For applications
where sparsity is desirable, a Dirichlet prior with concentration $\alpha < 1$ or a
spike-and-slab prior on the softmax logits would be more appropriate.
Alternative approaches
An alternative Bayesian route to the same problem is to specify a latent factor
model with interactive fixed effects
(Xu, 2017). Rather than learning explicit
weights, the factor model captures unit heterogeneity through latent loadings
and extrapolates counterfactuals from the estimated factor structure. This
approach is more structurally ambitious and allows model comparison via LOO-CV,
but requires committing to a parametric factor structure and choosing the number
of latent factors. The cut-posterior SDiD stays closer to the frequentist
estimator and produces directly interpretable weight posteriors.
References
- Arkhangelsky, D., Athey, S., Hirshberg, D. A., Imbens, G. W., & Wager, S.
(2021). Synthetic Difference-in-Differences. American Economic Review,
111(12), 4088–4118.
- Lunn, D., Best, N., Spiegelhalter, D., Graham, G., & Neuenschwander, B.
(2009). Combining MCMC with ‘sequential’ PKPD modelling. Journal of
Pharmacokinetics and Pharmacodynamics, 36, 19–38.
- Plummer, M. (2015). Cuts in Bayesian graphical models. Statistics and
Computing, 25, 37–43.
- Xu, Y. (2017). Generalized Synthetic Control Method: Causal Inference with
Interactive Fixed Effects Models. Political Analysis, 25(1), 57–76.