A few weeks ago, a randomized trial of remdesivir showed that it seemed to reduce mortality risk for coronavirus patients, from 11.6% to 8%. The p-value for this comparison was p=0.059. But what does that really mean? Does remdesivir work, or not? How likely is it to work? 94.1%?
Pyro is a powerful probabilistic programming library based on PyTorch that can answer questions like this, and much, much more. To start off, consider a hypothetical disease, the Pikachu virus. We can create a disease severity scale to measure how bad each case of Pikachu is, from 0 (no symptoms at all) to 100 (severely ill). Let’s assume the patients we get follow a beta distribution rescaled to 0-100, centered at 50 with a standard deviation of about 16.5. We can then simulate groups of patients getting sick:
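```python
import pyro

import common  # the author's helper module (make_tensor, beta, zero_dist); not shown in this post


def pikachu_patients(n):
    # Severity distribution on a 0-1 scale, rescaled to 0-100 at the end
    mode0 = common.make_tensor(0.5)
    k0 = common.make_tensor(8.1269)
    scaling = common.make_tensor(100.0)
    # beta reparameterized by mode and concentration
    # https://bit.ly/2ZqjILG
    patient_dist = common.beta(mode0, k0)
    # Draw n independent patients in one vectorized sample statement
    # (avoids reusing the site name "severity" once per patient)
    with pyro.plate("patients", n):
        patients = pyro.sample("severity", patient_dist)
    return patients * scaling
```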
(This is, of course, a very simplified model! More elaborate and realistic models will be explored in later posts.)
This gives us a handy distribution of data points to play with (n = 50):

Suppose someone runs a two-arm randomized controlled trial (n = 100, 50 patients per arm) of rubber gloves as a treatment for Pikachu. The data shown above is the non-treatment group, with a nice mean severity of 49.052, while the treatment group looks like this:

The mean severity of this group is 42.6573, and the difference between the groups gives p = 0.0423, below the standard 0.05 threshold. The treatment works! (In fact, xkcd-style, this group was made by just re-running the sampler until something suitable came out.) But how likely is it to work, really? And with what effect size?
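The xkcd-style trick is easy to reproduce: draw both groups from the same no-effect distribution and keep resampling until a test says the “treatment” group is significantly better. A minimal sketch, assuming a two-sample Student’s t-test (the post does not say which test produced p = 0.0423) and using NumPy and SciPy rather than Pyro. The beta parameters a = b ≈ 4.06 assume that mode 0.5 and concentration 8.1269 map to a beta distribution via the mode/concentration reparameterization, and `resample_until_significant` is a hypothetical helper:

```python
import numpy as np
from scipy import stats


def resample_until_significant(n=50, alpha=0.05, seed=0, max_tries=10_000):
    """Resample two groups from the SAME distribution until the "treatment"
    group looks significantly better (lower severity) than the control."""
    rng = np.random.default_rng(seed)
    # Assumed mapping: mode 0.5, concentration 8.1269 -> a = b = 0.5 * (k - 2) + 1
    a = b = 0.5 * (8.1269 - 2) + 1
    for tries in range(1, max_tries + 1):
        control = rng.beta(a, b, size=n) * 100
        treatment = rng.beta(a, b, size=n) * 100  # no real effect at all
        p = stats.ttest_ind(control, treatment).pvalue  # Student's t-test
        if p < alpha and treatment.mean() < control.mean():
            return tries, control.mean(), treatment.mean(), p
    raise RuntimeError("no 'significant' result found")


tries, control_mean, treatment_mean, p = resample_until_significant()
print(f"after {tries} tries: control {control_mean:.2f}, "
      f"treatment {treatment_mean:.2f}, p = {p:.4f}")
```

Under the null, about 5% of draws reach p < 0.05 and roughly half of those favor the treatment group, so a “significant” benefit turns up after about 40 tries on average.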
Before the trial starts, we already know that most treatments don’t work. Let’s assume that there’s an 80% chance of no effect (a point mass at zero), a 10% chance of a small effect (beta distribution with mode = 0, i.e. a = 1, b = 9, rescaled to the range 0-30), and a 10% chance of a significant effect of unknown size (beta distribution with a = b = 1, i.e. uniform, likewise rescaled to 0-30):
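```python
PRIORS = {
    'p_effect_sizes': [0.8, 0.1, 0.1],  # P(no effect), P(small), P(significant)
    'mode_effect_size': 0.5,  # before scaling
    'effect_k': 2.0,  # uniform distribution
    'small_effect_k': 10.0,  # with mode 0: Beta(a=1, b=9)
    'scale': 30.0  # effect sizes are rescaled to 0-30 severity points
}
```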
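The `(mode, k)` pairs above map onto ordinary beta parameters. The post's `common.beta` helper is not shown; judging from the stated correspondences (mode 0 with k = 10 giving a = 1, b = 9, and mode 0.5 with k = 2 giving a = b = 1), it appears to use the mode/concentration reparameterization a = mode · (k − 2) + 1, b = (1 − mode) · (k − 2) + 1. A minimal sketch of that conversion with PyTorch's `Beta`, using a hypothetical `beta_from_mode` function:

```python
import torch
from torch.distributions import Beta


def beta_from_mode(mode, k):
    """Hypothetical stand-in for common.beta: a Beta distribution given by
    its mode and concentration k (k >= 2 keeps both a and b >= 1)."""
    mode = torch.as_tensor(mode, dtype=torch.float32)
    k = torch.as_tensor(k, dtype=torch.float32)
    a = mode * (k - 2) + 1
    b = (1 - mode) * (k - 2) + 1
    return Beta(a, b)  # torch's Beta(concentration1=a, concentration0=b)


small = beta_from_mode(0.0, 10.0)        # small-effect prior
significant = beta_from_mode(0.5, 2.0)   # significant-effect prior
patients = beta_from_mode(0.5, 8.1269)   # patient severity: a = b ≈ 4.06

print(small.concentration1.item(), small.concentration0.item())              # 1.0 9.0
print(significant.concentration1.item(), significant.concentration0.item())  # 1.0 1.0
```

Beta(1, 1) is the uniform distribution, which is why the `effect_k` entry is commented as “uniform distribution”.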
Our prior distribution looks like this:

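One way to reproduce a plot like this is plain Monte Carlo: pick a world with the prior probabilities, then draw an effect size from that world's distribution and rescale it to severity points. A minimal PyTorch sketch, reusing the beta parameters stated above (Beta(1, 9) for a small effect, Beta(1, 1) for a significant one) and a point mass at zero for no effect; `sample_prior_effects` is a hypothetical name:

```python
import torch
from torch.distributions import Beta, Categorical


def sample_prior_effects(num_samples, p_worlds=(0.8, 0.1, 0.1), scale=30.0, seed=0):
    """Draw (world, effect size in severity points) pairs from the three-world prior."""
    torch.manual_seed(seed)
    worlds = Categorical(torch.tensor(p_worlds)).sample((num_samples,))
    small = Beta(torch.tensor(1.0), torch.tensor(9.0)).sample((num_samples,))
    significant = Beta(torch.tensor(1.0), torch.tensor(1.0)).sample((num_samples,))
    # World 0: no effect; world 1: small effect; world 2: significant effect
    raw = torch.where(worlds == 1, small,
                      torch.where(worlds == 2, significant, torch.zeros(num_samples)))
    return worlds, raw * scale


worlds, effects = sample_prior_effects(100_000)
print((worlds == 0).float().mean().item())  # close to 0.8
print(effects[worlds == 1].mean().item())   # close to 30 * 1/10 = 3
print(effects[worlds == 2].mean().item())   # close to 30 * 1/2 = 15
```

A histogram of `effects` then shows a tall spike at zero, a small bump near zero, and a flat shelf out to 30.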
Now we write a generative model of what might happen to patients under each possible outcome, and condition it on the observed “trial data”:
```python
import pyro
import pyro.distributions as dist

import common  # the author's helper module; not shown in this post


def make_distributions(mode_effect_size, effect_k, small_effect_k):
    # One effect-size distribution per "world": none, small, significant
    return [common.zero_dist(),
            common.beta(0, small_effect_k),
            common.beta(mode_effect_size, effect_k)]


def model(patients):
    # patients[0] holds the control group, patients[1] the treatment group
    patients = patients.clone()
    scaling = common.make_tensor(100.0)
    patients /= scaling  # back to the 0-1 scale of the beta distribution

    # Fixed prior hyperparameters
    p_effect_sizes = common.make_tensor(PRIORS['p_effect_sizes'])
    mode_effect_size = common.make_tensor(PRIORS['mode_effect_size'])
    effect_k = common.make_tensor(PRIORS['effect_k'])
    small_effect_k = common.make_tensor(PRIORS['small_effect_k'])
    scale = common.make_tensor(PRIORS['scale']) / scaling
    distributions = make_distributions(mode_effect_size, effect_k, small_effect_k)

    # Pick a world, then an effect size within that world
    which_world = pyro.sample("which_world", dist.Categorical(p_effect_sizes))
    effect_size_raw = pyro.sample("effect_size", distributions[which_world])
    effect_size = effect_size_raw * scale

    mode0 = common.make_tensor(0.5)
    k0 = common.make_tensor(8.1269)
    # beta reparameterized
    # https://bit.ly/2ZqjILG
    original_dist = common.beta(mode0, k0)
    # Treatment shifts the mode of the severity distribution down by effect_size
    new_dist = common.beta(mode0 - effect_size, k0)

    with pyro.plate("trial", len(patients[1])):
        pyro.sample("severity_trial", new_dist,
                    obs=patients[1])
    with pyro.plate("test", len(patients[0])):
        pyro.sample("severity_test", original_dist,
                    obs=patients[0])
```
(Note that in this simple model, since we already know the underlying distribution given no intervention, the “control” data has no effect on the findings. In real life, of course, control data would be used to find other parameters.)
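The note about the control data follows from how the model’s joint distribution factorizes. In notation introduced here, with w for `which_world`, e for `effect_size`, and x_T and x_C for the treatment and control severities:

```latex
p(w, e, x_T, x_C) = p(w)\, p(e \mid w)\, p(x_T \mid e)\, p(x_C)
\;\Longrightarrow\;
p(w, e \mid x_T, x_C)
  = \frac{p(w)\, p(e \mid w)\, p(x_T \mid e)}
         {\sum_{w'} \int p(w')\, p(e' \mid w')\, p(x_T \mid e')\, \mathrm{d}e'}
```

Because p(x_C) involves neither w nor e, it appears in both the numerator and the normalizing constant and cancels, so the posterior is the same with or without the control group.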
We then create a “variational distribution”, or “guide”, to approximate the posterior:
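```python
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

import common  # the author's helper module; not shown in this post


def guide(patients):
    # Learnable probability of each world (constrained to sum to 1)
    p_effect_sizes = pyro.param(
        "p_effect_sizes", common.make_tensor(PRIORS['p_effect_sizes']),
        constraint=constraints.simplex)
    # Learnable mode and concentration of the "significant effect" distribution
    mode_effect_size = pyro.param(
        "mode_effect_size", common.make_tensor(PRIORS['mode_effect_size']),
        constraint=constraints.interval(0, 1))
    effect_k = pyro.param(
        "effect_k", common.make_tensor(PRIORS['effect_k']),
        constraint=constraints.positive)
    # Learnable concentration of the "small effect" distribution (mode stays at 0)
    small_effect_k = pyro.param(
        "small_effect_k", common.make_tensor(PRIORS['small_effect_k']),
        constraint=constraints.positive)
    distributions = make_distributions(mode_effect_size, effect_k, small_effect_k)
    # Same latent sample sites as the model, so SVI can pair them up
    which_world = pyro.sample("which_world", dist.Categorical(p_effect_sizes))
    pyro.sample("effect_size", distributions[which_world])
```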
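In symbols (notation introduced here: π for `p_effect_sizes`, m and κ for `mode_effect_size` and `effect_k`, κ_s for `small_effect_k`, x_T and x_C for the treatment and control data, and Beta_mode(m, κ) for the mode/concentration beta that `common.beta` appears to construct), the guide defines the variational family and objective

```latex
\begin{aligned}
q_\phi(w, e) &= \mathrm{Categorical}(w; \pi)\, q_\phi(e \mid w), \qquad \phi = (\pi, m, \kappa, \kappa_s), \\
q_\phi(e \mid 0) &= \delta_0(e), \quad
q_\phi(e \mid 1) = \mathrm{Beta}_{\mathrm{mode}}(e; 0, \kappa_s), \quad
q_\phi(e \mid 2) = \mathrm{Beta}_{\mathrm{mode}}(e; m, \kappa), \\
\mathrm{ELBO}(\phi) &= \mathbb{E}_{q_\phi}\!\left[\log p(w, e, x_T, x_C) - \log q_\phi(w, e)\right] \le \log p(x_T, x_C).
\end{aligned}
```

`SVI` with `Trace_ELBO` adjusts φ to maximize this bound, so the learned π plays the role of the approximate posterior probability of each world. Since `which_world` is discrete, Pyro estimates its part of the gradient with a score-function term rather than by reparameterization.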
And run some inference:
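```python
import logging

import progressbar
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import MultiStepLR
from torch.optim import Adam  # Pyro's scheduler wrappers take a torch.optim class

logger = logging.getLogger(__name__)


def find_params(model, guide, data, steps, callable_log):
    # Set up the optimizer: Adam wrapped in a step-wise learning-rate schedule
    adam_params = {"lr": 0.002, "betas": (0.90, 0.999)}
    scheduler = MultiStepLR(
        {'optimizer': Adam, 'optim_args': adam_params,
         'milestones': [100, 200], 'gamma': 0.2})
    # Set up the inference algorithm
    svi = SVI(model, guide, scheduler, loss=Trace_ELBO())
    logger.info("Begin gradient descent")
    # Do gradient steps
    for step in progressbar.progressbar(range(steps)):
        loss = svi.step(data)
        if step % 1000 == 0:
            callable_log()
        # Advance the schedule once per SVI step, so milestones count steps
        scheduler.step()
```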
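Because `scheduler.step()` is called once per SVI step, the milestones [100, 200] count gradient steps: the learning rate drops from 0.002 to 0.0004 after 100 steps and to 0.00008 after 200, and then stays there for the rest of the run. Pyro's `MultiStepLR` wraps PyTorch's scheduler of the same name, so the schedule can be checked with plain PyTorch. A minimal sketch, using a dummy parameter in place of the guide's parameters (`lr_schedule` is a hypothetical helper):

```python
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR


def lr_schedule(steps, lr=0.002, milestones=(100, 200), gamma=0.2):
    """Return the learning rate in effect at each step of a MultiStepLR schedule."""
    dummy = torch.zeros(1, requires_grad=True)  # stand-in for the guide's params
    optimizer = Adam([dummy], lr=lr, betas=(0.90, 0.999))
    scheduler = MultiStepLR(optimizer, milestones=list(milestones), gamma=gamma)
    lrs = []
    for _ in range(steps):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()   # the SVI step would go here (no grads, so a no-op)
        scheduler.step()   # advance the schedule once per step, as in find_params
    return lrs


lrs = lr_schedule(300)
print(lrs[0], lrs[150], lrs[250])  # approximately 0.002, 0.0004, 0.00008
```

The `step % 1000` logging interval suggests runs of thousands of steps, so most of the optimization happens at the smallest rate; if the fit looks under-converged, moving the milestones later is one knob to try.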
After inference, we get an 18.74% posterior chance of no effect, a 26.45% chance of a small effect, and a 54.82% chance of a significant effect. We also have a much clearer idea of how large a “significant” effect would be. Here’s the graph:

So we have a reasonable guess at the size of the effect, but with a nearly 19% chance of no effect at all, it could still be a coincidence. This matches the intuition from eyeballing the data pretty well.
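Because this model has only one continuous latent variable, its posterior over worlds can also be approximated without SVI, by numerical integration over a grid of effect sizes. This is a swapped-in technique (grid approximation rather than Pyro's variational inference), run here on freshly simulated data, so it will not reproduce the 18.74% / 26.45% / 54.82% figures; the beta parameters again assume the mode/concentration reparameterization, and all function names are hypothetical. It is mainly useful as a sanity check on a fitted guide:

```python
import math

import torch
from torch.distributions import Beta


def beta_from_mode(mode, k):
    # Assumed mode/concentration reparameterization: a = m(k-2)+1, b = (1-m)(k-2)+1
    mode = torch.as_tensor(mode, dtype=torch.float64)
    return Beta(mode * (k - 2) + 1, (1 - mode) * (k - 2) + 1)


def world_posterior(treatment, p_worlds=(0.8, 0.1, 0.1), scale=0.3,
                    k0=8.1269, grid_size=2000):
    """Posterior probability of each world given treatment severities (0-1 scale)."""
    treatment = treatment.to(torch.float64)

    def log_lik(effect):
        # Log-likelihood of all treatment patients, one value per effect size
        d = beta_from_mode(0.5 - effect.unsqueeze(-1), k0)
        return d.log_prob(treatment).sum(-1)

    # Midpoint grid over the raw effect size in (0, 1); integral ~ mean over grid
    e = (torch.arange(grid_size, dtype=torch.float64) + 0.5) / grid_size
    ll = log_lik(e * scale)
    log_marg = torch.stack([
        log_lik(torch.zeros(1, dtype=torch.float64))[0],  # no effect
        torch.logsumexp(ll + beta_from_mode(0.0, 10.0).log_prob(e), 0) - math.log(grid_size),
        torch.logsumexp(ll + beta_from_mode(0.5, 2.0).log_prob(e), 0) - math.log(grid_size),
    ])
    log_post = torch.log(torch.tensor(p_worlds, dtype=torch.float64)) + log_marg
    return torch.softmax(log_post, 0)


torch.manual_seed(0)
# Treatment group simulated with NO real effect, on the 0-1 scale
treatment = beta_from_mode(0.5, 8.1269).sample((50,))
post_null = world_posterior(treatment)
print(post_null)
```

Feeding in the real trial data (divided by 100) instead of simulated data would give the exact posterior that the SVI guide is approximating.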
It would be interesting to extend this simplified model, but one issue to be aware of with Pyro (as with Bayesian inference in general) is that any outcome the prior excludes is effectively treated as impossible, no matter what the data shows. The prior here has no term for adverse effects, for example. So even if the data were very negative, the posterior would never say that the treatment did harm; it would instead concentrate on effects at or near zero, giving beneficial effects a very low probability. On the flip side, because Pyro is built on PyTorch, it should be straightforward to scale up to much more complex models with very large numbers of variables.
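The point about excluded outcomes shows up clearly in a deliberately tiny discrete example. This is a toy, not the post's model: it assumes normally distributed severities with a known standard deviation of 20 points, a control mean of 50, and only three candidate effects, all beneficial; `discrete_posterior` is a hypothetical helper. Data simulated from a harmful treatment still yields a posterior over those three candidates only:

```python
import torch
from torch.distributions import Normal


def discrete_posterior(treatment, reductions, prior, control_mean=50.0, sigma=20.0):
    """Posterior over candidate severity reductions, assuming Normal data."""
    reductions = torch.tensor(reductions)
    means = control_mean - reductions  # every candidate is a reduction (>= 0)
    # Log-likelihood of the whole treatment group under each candidate
    log_lik = Normal(means.unsqueeze(-1), sigma).log_prob(treatment).sum(-1)
    log_post = torch.log(torch.tensor(prior)) + log_lik
    return torch.softmax(log_post, 0)


torch.manual_seed(0)
# Simulate a HARMFUL treatment: severity goes UP by 10 points on average
treatment = Normal(60.0, 20.0).sample((50,))
post = discrete_posterior(treatment, reductions=[0.0, 5.0, 10.0], prior=[0.8, 0.1, 0.1])
print(post)  # nearly all mass on a reduction of 0; "harm" is not an option
```

Adding negative reductions (harm) to the candidate list, with some prior mass, is what would let the posterior report harm.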