Cox (semiparametric proportional hazard)

This demonstration requires two optional dependencies: pandas and lifelines. Install them if you want to reproduce the example outputs. Otherwise, the examples below can still help you understand how to use the Cox model with ReLife.
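Before running the cells, it can help to check which of the optional dependencies are importable in the current environment. The snippet below is a minimal sketch that uses only the standard library; `missing_packages` is a hypothetical helper name, not part of ReLife, pandas, or lifelines.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the names in `names` that cannot be imported in this environment."""
    return [name for name in names if find_spec(name) is None]


missing = missing_packages(["pandas", "lifelines"])
if missing:
    print("Missing optional dependencies:", ", ".join(missing))
    print("Install them with: pip install " + " ".join(missing))
else:
    print("pandas and lifelines are available.")
```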

Loading data

Let’s use the Canadian senators dataset bundled with lifelines.

[1]:
import lifelines
import pandas as pd
[2]:
df = lifelines.datasets.load_canadian_senators()

Clean up the dataset: keep the relevant columns, drop declined appointments, and standardize the province names.

[3]:
df = df[["Name", "Province / Territory", "reason", "diff_days", "observed"]]
# drop rows whose reason is "Appointment declined"; copy so the assignment below
# modifies an independent DataFrame (avoids SettingWithCopyWarning)
df = df[df["reason"] != "Appointment declined"].copy()
# standardize text: remove whitespace and other non-word characters, then lower-case
df["Province / Territory"] = (
    df["Province / Territory"].str.replace(r"\W+", "", regex=True).str.lower()
)
display(df.head())
   Name                          Province / Territory  reason       diff_days  observed
0  Abbott, John Joseph Caldwell  quebec                Death        2363       True
1  Adams, Michael                newbrunswick          Death        1090       True
2  Adams, Willie                 northwestterritories  Retirement   11766      True
3  Aikins, James Cox             ontario               Resignation  5333       True
4  Aikins, James Cox             ontario               Death        3133       True
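To see what the `\W+` substitution does on its own, here is a small standard-library sketch. The raw labels are illustrative guesses at how province names might be spelled before cleaning, not values read from the dataset, and `standardize` is a hypothetical helper name.

```python
import re


def standardize(label):
    """Drop every run of non-word characters, then lower-case the result."""
    return re.sub(r"\W+", "", label).lower()


# Illustrative raw labels (assumed spellings, not taken from the dataset).
for raw in ["Prince Edward Island", "Newfoundland and Labrador", "Quebec"]:
    print(f"{raw!r} -> {standardize(raw)!r}")
```

Note that `\W` matches anything that is not a letter, digit, or underscore, so underscores would survive this cleaning step.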
[4]:
df["Province / Territory"].value_counts()
[4]:
Province / Territory
quebec                      246
ontario                     242
novascotia                   98
newbrunswick                 93
britishcolumbia              44
manitoba                     44
alberta                      43
princeedwardisland           39
saskatchewan                 32
newfoundlandandlabrador      30
northwestterritories          7
yukon                         3
westernprovincesdivision      2
maritimesdivision             2
quebecdivision                2
ontariodivision               2
nunavut                       1
Name: count, dtype: int64

Dummy encoding of the “Province / Territory” covariate

[5]:
# dummy-encode "Province / Territory" into one 0/1 column per province
df = pd.get_dummies(df, columns=["Province / Territory"], prefix="covar", dtype=int)
# drop the dummy columns of provinces with fewer than 10 rows;
# the rows themselves are kept, with all remaining dummies equal to 0
covar_cols = [col for col in df.columns if col.startswith("covar_")]
province_count = df[covar_cols].sum()
province_with_low_count = province_count[province_count < 10]
df = df.drop(columns=province_with_low_count.index)
[6]:
print(df.columns)
Index(['Name', 'reason', 'diff_days', 'observed', 'covar_alberta',
       'covar_britishcolumbia', 'covar_manitoba', 'covar_newbrunswick',
       'covar_newfoundlandandlabrador', 'covar_novascotia', 'covar_ontario',
       'covar_princeedwardisland', 'covar_quebec', 'covar_saskatchewan'],
      dtype='object')
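Because the cell above drops low-count columns rather than rows, the rows of the dropped provinces end up with zeros in every remaining dummy column and so act as the reference group. The pure-Python sketch below mimics that logic on a toy list; the labels and the `min_count` threshold are hypothetical, and this is a stand-in for illustration, not how pandas implements `get_dummies`.

```python
from collections import Counter

# Hypothetical toy labels, not the senators data.
provinces = ["quebec", "ontario", "quebec", "yukon", "ontario", "quebec"]
min_count = 2  # keep a category only if it appears at least this many times

counts = Counter(provinces)
kept = sorted(name for name, n in counts.items() if n >= min_count)

# One 0/1 column per kept category; rows from dropped categories are all zeros.
rows = [{f"covar_{name}": int(p == name) for name in kept} for p in provinces]

for p, row in zip(provinces, rows):
    print(p, row)
```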

Fitting

Lifelines version

[7]:
from lifelines import CoxPHFitter

# fit the lifelines Cox model on the dummy covariates that remain in df
covar_cols = [col for col in df.columns if "covar" in col]
ll_model = CoxPHFitter()
ll_model.fit(
    df=df,
    duration_col="diff_days",
    event_col="observed",
    formula="~ " + " + ".join(covar_cols),
)
print(ll_model.params_)
covariate
covar_alberta                    0.345928
covar_britishcolumbia            0.173319
covar_manitoba                   0.292570
covar_newbrunswick               0.193955
covar_newfoundlandandlabrador    0.444993
covar_novascotia                 0.266590
covar_ontario                    0.319706
covar_princeedwardisland         0.397880
covar_quebec                     0.344936
covar_saskatchewan              -0.049557
Name: coef, dtype: float64
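Under the proportional hazards model, h(t | x) = h0(t) * exp(beta . x), so exp(beta_j) is the hazard ratio of covariate j relative to the reference group (here, the rows whose dummy columns are all zero). The sketch below converts the coefficients printed above into hazard ratios using only the standard library. It says nothing about statistical significance.

```python
import math

# Coefficients copied from the lifelines output above.
coefficients = {
    "covar_alberta": 0.345928,
    "covar_britishcolumbia": 0.173319,
    "covar_manitoba": 0.292570,
    "covar_newbrunswick": 0.193955,
    "covar_newfoundlandandlabrador": 0.444993,
    "covar_novascotia": 0.266590,
    "covar_ontario": 0.319706,
    "covar_princeedwardisland": 0.397880,
    "covar_quebec": 0.344936,
    "covar_saskatchewan": -0.049557,
}

# exp(beta) > 1 means a higher hazard than the reference group, < 1 a lower one.
hazard_ratios = {name: math.exp(beta) for name, beta in coefficients.items()}

for name, ratio in sorted(hazard_ratios.items(), key=lambda item: item[1]):
    print(f"{name:<32}{ratio:.3f}")
```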
[8]:
# survival function predicted by lifelines for the first two rows of the data
max_offset = 0

X = df.filter(regex="covar").iloc[:2]

sf_lifelines = ll_model.predict_survival_function(
    X=X,
    times=range(
        df["diff_days"].min().astype(int),
        df["diff_days"].max().astype(int) + max_offset,
    ),
)

sf_lifelines.columns = X.index
sf_lifelines.plot(xlabel="time", ylabel="sf")
[8]:
<Axes: xlabel='time', ylabel='sf'>
[Figure: survival function (sf) versus time for the first two rows of X, lifelines model]
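The two curves differ only through their covariates: in a proportional hazards model the survival function satisfies S(t | x) = S0(t) ** exp(beta . x), where S0 is the baseline survival function. The sketch below applies that general relation to made-up numbers. It is not the internal code of lifelines or ReLife, and the baseline values and coefficients are hypothetical.

```python
import math


def survival_given_covariates(baseline_sf, beta, x):
    """Apply S(t | x) = S0(t) ** exp(beta . x) to each baseline value."""
    risk = math.exp(sum(b * v for b, v in zip(beta, x)))
    return [s0 ** risk for s0 in baseline_sf]


# Hypothetical baseline survival values on some time grid (not fitted values).
baseline_sf = [1.0, 0.9, 0.7, 0.4]
beta = [0.3, -0.05]  # hypothetical coefficients

print(survival_given_covariates(baseline_sf, beta, x=[0, 0]))  # reference group
print(survival_given_covariates(baseline_sf, beta, x=[1, 0]))  # positive coefficient
print(survival_given_covariates(baseline_sf, beta, x=[0, 1]))  # negative coefficient
```

A positive coefficient pushes the curve below the baseline, and a negative one lifts it above.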

ReLife version

[9]:
from relife.lifetime_models import SemiParametricProportionalHazard

cox_model = SemiParametricProportionalHazard()
cox_model.fit(
    df["diff_days"].to_numpy(),
    covar=df.filter(regex="covar").to_numpy(),
    event=df["observed"].to_numpy(dtype=bool),
)
[9]:
<relife.lifetime_models._semi_parametric_regressions.SemiParametricProportionalHazard at 0x7cf61d876750>
[10]:
timeline, sf_values = cox_model.sf(covar=X, se=False)
[11]:
import matplotlib.pyplot as plt

plt.plot(timeline, sf_values[0])
plt.plot(timeline, sf_values[1])
[11]:
[<matplotlib.lines.Line2D at 0x7cf61d44a590>]
[Figure: survival function versus time for the same two rows of X, ReLife model]
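The two libraries may return their curves on different time grids, so a direct element-wise comparison is not always possible. One option, sketched below with the standard library only, is to treat each curve as a right-continuous step function and take the largest gap over the union of both grids. The helper names and toy curves are hypothetical. To apply this to the results above, the times and values would first need to be converted to plain sorted lists, and the exact array shapes returned by each library are an assumption to check.

```python
from bisect import bisect_right


def step_value(times, values, t):
    """Value at time t of a right-continuous step function that starts at 1.0."""
    i = bisect_right(times, t)
    return 1.0 if i == 0 else values[i - 1]


def max_abs_difference(times_a, sf_a, times_b, sf_b):
    """Largest gap between two step curves, checked at every time of either grid."""
    grid = sorted(set(times_a) | set(times_b))
    return max(
        abs(step_value(times_a, sf_a, t) - step_value(times_b, sf_b, t))
        for t in grid
    )


# Hypothetical curves on different grids (not the fitted results).
times_a, sf_a = [1, 2, 3, 4], [0.9, 0.9, 0.6, 0.6]
times_b, sf_b = [1, 3], [0.9, 0.6]
print(max_abs_difference(times_a, sf_a, times_b, sf_b))
```

Since step functions only change at their own jump times, checking the union of both grids is enough to find the largest gap.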