Reference

correlation_plot_strata(df, name_biomarkers, strata='status')

Generate a heatmap and pairplot of biomarkers for each stratum.

Parameters:
  • df (DataFrame) –

    Dataframe with biomarkers and strata information.

  • name_biomarkers (list[str]) –

    List of biomarker names to plot.

  • strata (str, default: 'status' ) –

    The name of the strata column. Default is 'status'.

Returns:
  • None

Source code in sreftml\plots.py
def correlation_plot_strata(
    df: pd.DataFrame, name_biomarkers: list[str], strata: str = "status"
) -> None:
    """
    Generate a heatmap and pairplot of biomarkers for each stratum.

    Args:
        df (pd.DataFrame): Dataframe with biomarkers and strata information.
        name_biomarkers (list[str]): List of biomarker names to plot.
        strata (str, optional): The name of the strata column. Default is 'status'.

    Returns:
        None
    """
    for i in range(2):
        plt.figure(figsize=(10, 10))
        sns.heatmap(
            df[df[strata] == i][name_biomarkers].corr(),
            cmap="coolwarm",
            vmin=-1,
            vmax=1,
            annot=True,
            fmt="1.2f",
        )

        plt.figure(figsize=(10, 10))
        sns.pairplot(df[df[strata] == i][name_biomarkers].reset_index(drop=True))

    sns.pairplot(df[name_biomarkers + [strata]], hue=strata, diag_kind="hist")
    return None
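
A minimal usage sketch with synthetic data (the biomarker names and import path are assumptions based on the file layout). Note that the function loops over strata values 0 and 1, so the strata column is expected to be binary-coded:

import numpy as np
import pandas as pd

from sreftml.plots import correlation_plot_strata

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "bm1": rng.normal(size=100),
    "bm2": rng.normal(size=100),
    "status": rng.integers(0, 2, size=100),  # binary strata: 0 or 1
})
correlation_plot_strata(df, ["bm1", "bm2"], strata="status")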

get_regression_line_label(x, y)

Generate a label for a line fitted to the given x and y data using linear regression.

Parameters:
  • x (Series) –

    Series of x-axis data.

  • y (Series) –

    Series of y-axis data.

Returns:
  • str

    Label for the fitted line, including slope, intercept, and R-squared value.

Source code in sreftml\plots.py
def get_regression_line_label(x: pd.Series, y: pd.Series) -> str:
    """
    Generate a label for a line fitted to the given x and y data using linear regression.

    Parameters:
        x (pd.Series): Series of x-axis data.
        y (pd.Series): Series of y-axis data.

    Returns:
        str: Label for the fitted line, including slope, intercept, and R-squared value.

    """
    slope, intercept, r_value, p_value, std_err = linregress(x, y)
    label_line_1 = rf"$y={slope:.3f}x{'' if intercept < 0 else '+'}{intercept:.3f}$"
    label_line_2 = rf"$R^2:{r_value**2:.2f}$"
    label_line = label_line_1 + "\n" + label_line_2

    return label_line
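
A quick sketch with hypothetical data; the label combines the fitted slope, intercept, and R-squared:

import pandas as pd

from sreftml.plots import get_regression_line_label

x = pd.Series([0.0, 1.0, 2.0, 3.0])
y = pd.Series([0.1, 1.9, 4.2, 5.8])
print(get_regression_line_label(x, y))
# roughly: $y=1.940x+0.090$ on the first line, $R^2:1.00$ on the second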

histogram_plot(df, col_name, hue=None, sharex=True, sharey=True, ncol_max=4, save_file_path=None)

Plot a stratified histogram by column.

Parameters:
  • df (DataFrame) –

    Input DataFrame.

  • col_name (list[str] | str) –

    Column name or list of column names in df.

  • hue (str | None, default: None ) –

    Column to stratify the plot. Defaults to None.

  • sharex, sharey (bool | "col" | "row", default: True ) –

    Passed directly to seaborn.FacetGrid.

  • ncol_max (int, default: 4 ) –

    Maximum number of columns. Defaults to 4.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • FacetGrid

    sns.axisgrid.FacetGrid: FacetGrid object with the distribution plot.

Source code in sreftml\plots.py
def histogram_plot(
    df: pd.DataFrame,
    col_name: list[str] | str,
    hue: str | None = None,
    sharex: bool = True,
    sharey: bool = True,
    ncol_max: int = 4,
    save_file_path: str | None = None,
) -> sns.axisgrid.FacetGrid:
    """
    Plot a stratified histogram by column.

    Args:
        df (pd.DataFrame): Input DataFrame.
        col_name (list[str] | str): Column name or list of column names in df.
        hue (str | None, optional): Column to stratify the plot. Defaults to None.
        sharex, sharey (bool | "col" | "row", optional): Passed directly to seaborn.FacetGrid. Defaults to True.
        ncol_max (int, optional): Maximum number of columns. Defaults to 4.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        sns.axisgrid.FacetGrid: FacetGrid object with the distribution plot.
    """
    if isinstance(col_name, str):
        col_name = [col_name]
    col_wrap = n2mfrow(len(col_name), ncol_max=ncol_max)[1]

    if hue is None:
        df_melt = pd.melt(df[col_name])
    else:
        df_melt = pd.melt(df[col_name + [hue]], hue)
    g = sns.FacetGrid(
        df_melt,
        col="variable",
        hue=hue,
        col_wrap=col_wrap,
        sharex=sharex,
        sharey=sharey,
        height=3.5,
    )
    g.map(plt.hist, "value", alpha=0.4)
    g.add_legend()
    g.set_titles("{col_name}")

    if save_file_path:
        g.savefig(save_file_path, transparent=True)

    return g
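
A short sketch with synthetic data (all column names hypothetical); one panel is drawn per entry of col_name:

import numpy as np
import pandas as pd

from sreftml.plots import histogram_plot

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "age": rng.normal(60, 10, size=200),
    "score": rng.normal(0, 1, size=200),
    "group": rng.choice(["A", "B"], size=200),
})
g = histogram_plot(df, ["age", "score"], hue="group")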

hp_search_plot(df_grid, eval_col='score', save_file_path=None)

Plot the results of a hyperparameter search.

Parameters:
  • df_grid (DataFrame) –

    DataFrame containing the grid of hyperparameters.

  • eval_col (str, default: 'score' ) –

    The column to use for evaluation. Defaults to "score".

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    plt.Figure: The plotted figure.

Source code in sreftml\plots.py
def hp_search_plot(
    df_grid: pd.DataFrame,
    eval_col: str = "score",
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Plot the results of a hyperparameter search.

    Args:
        df_grid (pd.DataFrame): DataFrame containing the grid of hyperparameters.
        eval_col (str, optional): The column to use for evaluation. Defaults to "score".
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        plt.Figure: The plotted figure.
    """
    df_grid = df_grid.sort_values(eval_col, ascending=False).reset_index(drop=True)
    cols = df_grid.columns.tolist()
    cols.remove(eval_col)
    cols.append(eval_col)
    df_grid_sorted = df_grid[cols].copy()  # copy so the label encoding below does not touch df_grid

    string_columns = df_grid_sorted.select_dtypes(include="object").columns
    numeric_columns = df_grid_sorted.select_dtypes(exclude="object").columns

    for i in string_columns:
        df_grid_sorted[i] = sp.LabelEncoder().fit_transform(df_grid_sorted[i])

    scaler = sp.MinMaxScaler()
    df_grid_scaled = pd.DataFrame(
        scaler.fit_transform(df_grid_sorted), columns=df_grid_sorted.columns
    )

    cm = plt.get_cmap("seismic", 2)
    fig = plt.figure(tight_layout=True, dpi=300)
    for i, row in df_grid_scaled.iterrows():
        if i == len(df_grid_sorted) - 1:
            plt.plot(df_grid_scaled.columns, row.values, color=cm(1), lw=4)
        else:
            plt.plot(df_grid_scaled.columns, row.values, color=cm(0))

    for i in string_columns:
        label_unique = df_grid[i].unique()
        scaled_unique = df_grid_scaled[i].unique()
        for label_, scaled_ in zip(label_unique, scaled_unique):
            plt.text(
                i, scaled_, label_, ha="center", va="center", backgroundcolor="white"
            )

    for i in numeric_columns:
        min_val = df_grid_sorted[i].min()
        max_val = df_grid_sorted[i].max()
        plt.text(i, 0, min_val, ha="center", va="center", backgroundcolor="white")
        if min_val != max_val:
            plt.text(i, 1, max_val, ha="center", va="center", backgroundcolor="white")

    for i, val in enumerate(df_grid_scaled.iloc[-1, :]):
        col_name = df_grid_scaled.columns[i]
        if val not in [0, 1] and col_name in numeric_columns:
            plt.text(
                col_name,
                val,
                df_grid_sorted.iloc[-1, i],
                ha="center",
                va="center",
                backgroundcolor="white",
            )

    plt.xticks(rotation=45)
    plt.ylabel("Min-Max Normalized value")

    if save_file_path is not None:
        plt.savefig(save_file_path, transparent=True)

    return fig
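
A sketch of the expected input: one row per hyperparameter combination plus an evaluation column (here the default "score"; all values hypothetical). The row with the lowest score ends up last after sorting and is drawn as a thick line:

import pandas as pd

from sreftml.plots import hp_search_plot

df_grid = pd.DataFrame({
    "learning_rate": [1e-3, 1e-2, 1e-3, 1e-2],
    "activation": ["tanh", "tanh", "relu", "relu"],
    "score": [0.82, 0.78, 0.85, 0.80],
})
fig = hp_search_plot(df_grid)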

learning_history_plot(df_loss, save_file_path=None)

Plot learning history.

Parameters:
  • df_loss (DataFrame) –

    Data frame converted from tf.keras.callbacks.History.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    plt.Figure: The plotted figure.

Source code in sreftml\plots.py
def learning_history_plot(
    df_loss: pd.DataFrame, save_file_path: str | None = None
) -> plt.Figure:
    """
    Plot learning history.

    Args:
        df_loss (pd.DataFrame): Data frame converted from tf.keras.callbacks.History.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        plt.Figure: The plotted figure.
    """
    fig = plt.figure(tight_layout=True, dpi=300)
    plt.plot(df_loss["loss"], label="training")
    plt.plot(df_loss["val_loss"], label="validation")
    plt.xlabel("Epoch")
    plt.ylabel("loss")
    plt.legend()

    if save_file_path is not None:
        plt.savefig(save_file_path, transparent=True)

    return fig
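
df_loss is typically built from a Keras History object. A minimal sketch with made-up loss values:

import pandas as pd

from sreftml.plots import learning_history_plot

# In practice: df_loss = pd.DataFrame(history.history) after model.fit(..., validation_data=...)
df_loss = pd.DataFrame({
    "loss": [1.00, 0.62, 0.45, 0.38],
    "val_loss": [1.05, 0.70, 0.55, 0.50],
})
fig = learning_history_plot(df_loss)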

merged_permutation_importance_plot(mean_pi, name_biomarkers, name_covariates, y_axis_log=False, save_file_path=None)

Generate a permutation importance plot, merging each biomarker's slope and intercept importances into a single bar.

Parameters:
  • mean_pi (ndarray) –

    Array of mean permutation importance values.

  • name_biomarkers (list[str]) –

    List of biomarker names.

  • name_covariates (list[str]) –

    The names of the covariates.

  • y_axis_log (bool, default: False ) –

    Whether to use log scale for y-axis. Default is False.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    plt.Figure: The plotted figure.

Source code in sreftml\plots.py
def merged_permutation_importance_plot(
    mean_pi: np.ndarray,
    name_biomarkers: list[str],
    name_covariates: list[str],
    y_axis_log: bool = False,
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Generate a permutation importance plot, merging each biomarker's slope and intercept importances into a single bar.

    Args:
        mean_pi (np.ndarray): Array of mean permutation importance values.
        name_biomarkers (list[str]): List of biomarker names.
        name_covariates (list[str]): The names of the covariates.
        y_axis_log (bool, optional): Whether to use log scale for y-axis. Default is False.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        plt.Figure: The plotted figure.
    """
    bar = pd.DataFrame(
        {
            "labels": [i + j for j in ["_slope", "intercept"] for i in name_biomarkers]
            + name_covariates,
            "values": mean_pi.round(3),
        }
    )

    result_data = {"labels": [], "mean_pi": []}

    for i in name_biomarkers:
        slope_label = i + "_slope"
        intercept_label = i + "_intercept"
        slope_value = bar.loc[bar["labels"] == slope_label, "mean_pi"].values[0]
        intercept_value = bar.loc[bar["labels"] == intercept_label, "mean_pi"].values[0]
        total_value = slope_value + intercept_value
        result_data["labels"].append(i)
        result_data["mean_pi"].append(total_value)

    for i in name_covariates:
        cov_value = bar.loc[bar["labels"] == i, "mean_pi"].values[0]
        result_data["labels"].append(i)
        result_data["mean_pi"].append(cov_value)

    result_data = pd.DataFrame(result_data)

    mean_pi = result_data.mean_pi.values
    rank = np.argsort(mean_pi)
    fig = plt.figure(figsize=(len(rank) / 4, 10), dpi=300, tight_layout=True)
    plt.bar([result_data.labels.values[i] for i in rank], mean_pi[rank])
    plt.xticks(rotation=45, ha="right")
    if y_axis_log:
        plt.ylabel("Permutation Importance (log scale)")
        plt.yscale("log")
    else:
        plt.ylabel("Permutation Importance")

    if save_file_path is not None:
        plt.savefig(save_file_path, transparent=True)

    return fig
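
The mean_pi vector must follow the label order built above: all biomarker slopes, then all biomarker intercepts, then the covariates. A hypothetical example with two biomarkers and one covariate:

import numpy as np

from sreftml.plots import merged_permutation_importance_plot

# order: bm1_slope, bm2_slope, bm1_intercept, bm2_intercept, sex
mean_pi = np.array([0.12, 0.05, 0.03, 0.08, 0.02])
fig = merged_permutation_importance_plot(mean_pi, ["bm1", "bm2"], ["sex"])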

merged_shap_bar_plot(shap_exp_model_1, name_biomarkers, name_covariates, save_file_path=None)

Plot the SHAP values of model 1.

Parameters:
  • shap_exp_model_1 (Explanation) –

    The SHAP explanation for model 1.

  • name_biomarkers (list[str]) –

    List of biomarker names.

  • name_covariates (list[str]) –

    The names of the covariates.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    Matplotlib figure object representing the generated plot.

Source code in sreftml\plots.py
def merged_shap_bar_plot(
    shap_exp_model_1: shap.Explanation,
    name_biomarkers: list[str],
    name_covariates: list[str],
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Plot the SHAP values of model 1.

    Args:
        shap_exp_model_1 (shap.Explanation): The SHAP explanation for model 1.
        name_biomarkers (list[str]): List of biomarker names.
        name_covariates (list[str]): The names of the covariates.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        fig (plt.Figure): Matplotlib figure object representing the generated plot.
    """
    bar = pd.DataFrame(
        {
            "labels": [i + j for j in ["_slope", "intercept"] for i in name_biomarkers]
            + name_covariates,
            "shap": np.mean(abs(shap_exp_model_1.values), axis=0).round(3),
        }
    )

    result_data = {"labels": [], "shap": []}

    for i in name_biomarkers:
        slope_label = i + "_slope"
        intercept_label = i + "_intercept"
        slope_value = bar.loc[bar["labels"] == slope_label, "shap"].values[0]
        intercept_value = bar.loc[bar["labels"] == intercept_label, "shap"].values[0]
        total_value = slope_value + intercept_value
        result_data["labels"].append(i)
        result_data["shap"].append(total_value)

    for i in name_covariates:
        cov_value = bar.loc[bar["labels"] == i, "shap"].values[0]
        result_data["labels"].append(i)
        result_data["shap"].append(cov_value)

    result_data = pd.DataFrame(result_data)

    shap_exp = result_data.shap.values
    rank = np.argsort(shap_exp)
    fig = plt.figure(figsize=(len(rank) / 4, 10), dpi=300, tight_layout=True)
    plt.bar([result_data.labels.values[i] for i in rank], shap_exp[rank])
    plt.xticks(rotation=45, ha="right")
    plt.ylabel("mean(|SHAP value|)")

    if save_file_path is not None:
        plt.savefig(save_file_path, transparent=True)

    return fig
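
Analogous to merged_permutation_importance_plot: the SHAP columns are assumed to be ordered as slopes, then intercepts, then covariates. A sketch with a synthetic Explanation (random values, hypothetical names):

import numpy as np
import shap

from sreftml.plots import merged_shap_bar_plot

rng = np.random.default_rng(0)
# 50 samples x 5 columns: bm1_slope, bm2_slope, bm1_intercept, bm2_intercept, sex
exp = shap.Explanation(values=rng.normal(size=(50, 5)))
fig = merged_shap_bar_plot(exp, ["bm1", "bm2"], ["sex"])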

multi_panel_scatter_plot(df, x_col, y_col, hue, duplicate_key=None, ncol_max=4, density=False, identity=False, save_file_path=None)

Draw scatter plots with multiple panels based on stratification factors.

Parameters:
  • df (DataFrame) –

    Input DataFrame.

  • x_col (str) –

    X-axis column in df.

  • y_col (str) –

    Y-axis column in df.

  • hue (list[str] | str) –

    Columns to stratify the plot.

  • duplicate_key (list[str] | str | None, default: None ) –

    Specify the column name(s) from which duplicates are to be removed. Defaults to None.

  • ncol_max (int, default: 4 ) –

    Maximum number of columns. Defaults to 4.

  • density (bool, default: False ) –

    Whether to plot density. Defaults to False.

  • identity (bool, default: False ) –

    Whether to plot identity line. Defaults to False.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • FacetGrid

    sns.axisgrid.FacetGrid: FacetGrid object with the scatter plot.

Source code in sreftml\plots.py
def multi_panel_scatter_plot(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    hue: list[str] | str,
    duplicate_key: list[str] | str | None = None,
    ncol_max: int = 4,
    density: bool = False,
    identity: bool = False,
    save_file_path: str | None = None,
) -> sns.axisgrid.FacetGrid:
    """
    Draw scatter plots with multiple panels based on stratification factors.

    Args:
        df (pd.DataFrame): Input DataFrame.
        x_col (str): X-axis column in df.
        y_col (str): Y-axis column in df.
        hue (list[str] | str): Columns to stratify the plot.
        duplicate_key (list[str] | str | None, optional): Specify the column name(s) from which duplicates are to be removed. Defaults to None.
        ncol_max (int, optional): Maximum number of columns. Defaults to 4.
        density (bool, optional): Whether to plot density. Defaults to False.
        identity (bool, optional): Whether to plot identity line. Defaults to False.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        sns.axisgrid.FacetGrid: FacetGrid object with the scatter plot.
    """
    if isinstance(hue, str):
        hue = [hue]

    df_ = clean_duplicate(df, [x_col, y_col] + hue, duplicate_key)
    hue_ = ", ".join(hue)
    if len(hue) > 1:
        df_[hue_] = df[hue].apply(lambda x: ", ".join(x.astype(str)), axis=1)
    unique_hue = np.sort(df_[hue_].unique())

    col_wrap = n2mfrow(len(df_[hue_].unique()), ncol_max=ncol_max)[1]
    g = sns.lmplot(
        data=df_,
        x=x_col,
        y=y_col,
        col=hue_,
        col_wrap=col_wrap,
        col_order=unique_hue,
        scatter=not density,
        scatter_kws={"alpha": 0.5, "s": 20, "edgecolor": "none"},
        line_kws={"color": "red", "label": "lines"},
    )
    g.figure.set_dpi(300)

    for idx, s in enumerate(unique_hue):
        df_hue = df_.loc[df_[hue_] == s]
        label_line = get_regression_line_label(df_hue[x_col], df_hue[y_col])

        if density:
            xy = df_hue[[x_col, y_col]].values.T
            z = gaussian_kde(xy)(xy)
            x_ = xy.T[:, :1]
            y_ = xy.T[:, 1:]
            g.axes[idx].scatter(x_, y_, c=z, s=20, edgecolor="none", cmap="viridis")
            g.axes[idx].legend([label_line])
        else:
            g.axes[idx].legend(["_nolegend_", label_line])

        if identity:
            if df[y_col].max() < df[x_col].min() or df[x_col].max() < df[y_col].min():
                warnings.warn(
                    f"The data range of {x_col} and {y_col} is not covered, although idenntity=True. Skip drawing of identity line."
                )
            else:
                min_ = df[[x_col, y_col]].min().max()
                max_ = df[[x_col, y_col]].max().min()
                g.axes[idx].plot([min_, max_], [min_, max_], "k--")

    if save_file_path is not None:
        plt.savefig(save_file_path, transparent=True, dpi=300)

    return g
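
A sketch with synthetic data, faceting on a single hypothetical stratification column:

import numpy as np
import pandas as pd

from sreftml.plots import multi_panel_scatter_plot

rng = np.random.default_rng(0)
n = 300
x = rng.normal(size=n)
df = pd.DataFrame({
    "x": x,
    "y": 2 * x + rng.normal(scale=0.5, size=n),
    "group": rng.choice(["A", "B"], size=n),
})
g = multi_panel_scatter_plot(df, "x", "y", hue="group")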

permutation_importance_plot(mean_pi, std_pi, feature_label, y_axis_log=False, save_file_path=None)

Generate a permutation importance plot.

Parameters:
  • mean_pi (ndarray) –

    Array of mean permutation importance values.

  • std_pi (ndarray) –

    Array of standard deviation permutation importance values.

  • feature_label (list[str]) –

    List of feature names for which PI was measured.

  • y_axis_log (bool, default: False ) –

    Whether to use log scale for y-axis. Default is False.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    plt.Figure: The plotted figure.

Source code in sreftml\plots.py
def permutation_importance_plot(
    mean_pi: np.ndarray,
    std_pi: np.ndarray,
    feature_label: list[str],
    y_axis_log: bool = False,
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Generate a permutation importance plot.

    Args:
        mean_pi (np.ndarray): Array of mean permutation importance values.
        std_pi (np.ndarray): Array of standard deviation permutation importance values.
        feature_label (list[str]): List of feature names for which PI was measured.
        y_axis_log (bool, optional): Whether to use log scale for y-axis. Default is False.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        plt.Figure: The plotted figure.
    """
    rank = np.argsort(mean_pi)
    fig = plt.figure(figsize=(len(rank) / 4, 10), dpi=300, tight_layout=True)
    plt.bar([feature_label[i] for i in rank], mean_pi[rank], yerr=std_pi[rank])
    plt.xticks(rotation=45, ha="right")
    if y_axis_log:
        plt.ylabel("Permutation Importance (log scale)")
        plt.yscale("log")
    else:
        plt.ylabel("Permutation Importance")

    if save_file_path is not None:
        plt.savefig(save_file_path, transparent=True)

    return fig
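
A hypothetical example; mean_pi and std_pi come from repeated permutation runs and must align with feature_label:

import numpy as np

from sreftml.plots import permutation_importance_plot

mean_pi = np.array([0.20, 0.05, 0.11])
std_pi = np.array([0.03, 0.01, 0.02])
fig = permutation_importance_plot(mean_pi, std_pi, ["age", "sex", "bm1"])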

prediction_plot(sreft, df, name_biomarkers, name_covariates, scaler_y, scaler_cov, res=100, density=False, useOffsetT=True, ncol_max=4, save_file_path=None)

Plot the predictions of the SReFT model.

Parameters:
  • sreft (Model) –

    The SReFT model.

  • df (DataFrame) –

    DataFrame with the data.

  • name_biomarkers (list[str]) –

    The names of the biomarkers.

  • name_covariates (list[str]) –

    The names of the covariates.

  • scaler_y (StandardScaler) –

    The scaler for the y values.

  • scaler_cov (StandardScaler) –

    The scaler for the covariate values.

  • res (int, default: 100 ) –

    Resolution of the plot. Defaults to 100.

  • density (bool, default: False ) –

    Whether to plot density or not. Defaults to False.

  • useOffsetT (bool, default: True ) –

    Whether to use offsetT or not. Defaults to True.

  • ncol_max (int, default: 4 ) –

    Maximum number of columns for subplots. Defaults to 4.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    plt.Figure: The plotted figure.

Source code in sreftml\plots.py
def prediction_plot(
    sreft: tf.keras.Model,
    df: pd.DataFrame,
    name_biomarkers: list[str],
    name_covariates: list[str],
    scaler_y: sp.StandardScaler,
    scaler_cov: sp.StandardScaler,
    res: int = 100,
    density: bool = False,
    useOffsetT: bool = True,
    ncol_max: int = 4,
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Plot the predictions of the SReFT model.

    Args:
        sreft (tf.keras.Model): The SReFT model.
        df (pd.DataFrame): DataFrame with the data.
        name_biomarkers (list[str]): The names of the biomarkers.
        name_covariates (list[str]): The names of the covariates.
        scaler_y (sp.StandardScaler): The scaler for the y values.
        scaler_cov (sp.StandardScaler): The scaler for the covariate values.
        res (int, optional): Resolution of the plot. Defaults to 100.
        density (bool, optional): Whether to plot density or not. Defaults to False.
        useOffsetT (bool, optional): Whether to use offsetT or not. Defaults to True.
        ncol_max (int, optional): Maximum number of columns for subplots. Defaults to 4.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        plt.Figure: The plotted figure.
    """
    n_biomarker = len(name_biomarkers)
    n_covariate = len(name_covariates)
    n_row, n_col = n2mfrow(n_biomarker, ncol_max)
    cm = plt.colormaps["Set1"]

    y_data = df[name_biomarkers].values
    if useOffsetT:
        x_data = df.TIME.values + df.offsetT.values
        cov_dummy = np.array([i for i in itertools.product([0, 1], repeat=n_covariate)])
        cov_dummy = np.repeat(cov_dummy, res, axis=0)
        cov_dummy_scaled = scaler_cov.transform(cov_dummy)
        x_model = np.linspace(x_data.min(), x_data.max(), res)
        x_model = np.tile(x_model, 2**n_covariate).reshape(-1, 1)
        x_model = np.concatenate((x_model, cov_dummy_scaled), axis=1)
        y_model = scaler_y.inverse_transform(sreft.model_y(x_model))
    else:
        x_data = df.TIME.values

    fig, axs = plt.subplots(
        n_row,
        n_col,
        figsize=(n_col * 3, n_row * 3),
        tight_layout=True,
        dpi=300,
        sharex="row",
    )
    for k, ax in enumerate(axs.flat):
        if k >= n_biomarker:
            ax.axis("off")
            continue

        if density:
            x_ = x_data[~np.isnan(y_data[:, k])]
            y_ = y_data[~np.isnan(y_data[:, k]), k]
            if np.var(x_) == 0:
                z = gaussian_kde(y_)(y_)
            else:
                xy = np.vstack([x_, y_])
                z = gaussian_kde(xy)(xy)
            idx = z.argsort()
            ax.scatter(x_[idx], y_[idx], c=z[idx], s=2, label="_nolegend_")
        else:
            ax.scatter(x_data, y_data[:, k], c="silver", s=2, label="_nolegend_")

        if useOffsetT:
            for i in range(2**n_covariate):
                ax.plot(
                    x_model[res * i : res * (i + 1), 0],
                    y_model[res * i : res * (i + 1), k],
                    c=cm(i),
                    lw=4,
                )
            ax.set_xlabel("Disease Time (year)")
        else:
            ax.set_xlabel("Observation Period (year)")

        ax.set_title(name_biomarkers[k], fontsize=15)

    if n_covariate > 0:
        legend_labels = [
            ", ".join(format(i, f"0{n_covariate}b")) for i in range(2**n_covariate)
        ]
        fig.legend(
            loc="center",
            framealpha=0,
            bbox_to_anchor=(1.1, 0.5),
            ncol=1,
            title=", ".join(name_covariates),
            labels=legend_labels,
        )

    if save_file_path is not None:
        fig.savefig(save_file_path, transparent=True, bbox_inches="tight")

    return fig

prediction_sim_plot(df, sreft, params_true, name_biomarkers, name_covariates, scaler_cov, scaler_y, res=100, density=False, ncol_max=4, save_file_path=None)

Generate a prediction simulation plot.

Parameters:
  • df (DataFrame) –

    Dataframe with biomarkers and other information.

  • sreft (Model) –

    The trained SReFT model.

  • params_true (DataFrame) –

    Dataframe with true parameters for the model.

  • name_biomarkers (list[str]) –

    List of biomarker names.

  • name_covariates (list[str]) –

    List of covariate names.

  • scaler_cov (StandardScaler) –

    Scaler for the covariate values.

  • scaler_y (StandardScaler) –

    Scaler for the y values.

  • res (int, default: 100 ) –

    Resolution for the plot. Default is 100.

  • density (bool, default: False ) –

    Whether to use density or not. Default is False.

  • ncol_max (int, default: 4 ) –

    Maximum number of columns for the plot. Default is 4.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    The created matplotlib figure.

Source code in sreftml\plots.py
def prediction_sim_plot(
    df: pd.DataFrame,
    sreft: tf.keras.Model,
    params_true: pd.DataFrame,
    name_biomarkers: list[str],
    name_covariates: list[str],
    scaler_cov: sp.StandardScaler,
    scaler_y: sp.StandardScaler,
    res: int = 100,
    density: bool = False,
    ncol_max: int = 4,
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Generate a prediction simulation plot.

    Args:
        df (pd.DataFrame): Dataframe with biomarkers and other information.
        sreft (tf.keras.Model): The trained SReFT model.
        params_true (pd.DataFrame): Dataframe with true parameters for the model.
        name_biomarkers (list[str]): List of biomarker names.
        name_covariates (list[str]): List of covariate names.
        scaler_cov (sp.StandardScaler): Scaler for the covariate values.
        scaler_y (sp.StandardScaler): Scaler for the y values.
        res (int, optional): Resolution for the plot. Default is 100.
        density (bool, optional): Whether to use density or not. Default is False.
        ncol_max (int, optional): Maximum number of columns for the plot. Default is 4.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        Figure: The created matplotlib figure.
    """
    n_biomarker = len(name_biomarkers)
    n_covariate = len(name_covariates)
    n_row, n_col = n2mfrow(n_biomarker, ncol_max)
    cm = plt.colormaps["Set1"]

    y_data = df[name_biomarkers].values
    x_data = df.TIME.values + df.offsetT.values

    cov_dummy = np.array(list(itertools.product([0, 1], repeat=n_covariate)))
    cov_dummy = np.repeat(cov_dummy, res, axis=0)
    cov_dummy_scaled = scaler_cov.transform(cov_dummy)
    x_model = np.linspace(x_data.min(), x_data.max(), res)
    input2 = np.tile(x_model, 2**n_covariate).reshape(-1, 1)
    input2 = np.concatenate((input2, cov_dummy_scaled), axis=1)
    y_model = scaler_y.inverse_transform(sreft.model_y(input2))

    name_covariates_true = [i for i in params_true.columns if "Covariate" in i]
    n_covariate_true = len(name_covariates_true)
    cov_dummy_true = np.array(list(itertools.product([0, 1], repeat=n_covariate_true)))
    cov_dummy_true = np.repeat(cov_dummy_true, res, axis=0)

    fig, axs = plt.subplots(
        n_row,
        n_col,
        figsize=(n_col * 3, n_row * 3),
        tight_layout=True,
        dpi=300,
        sharex="row",
    )
    for k, ax in enumerate(axs.flat):
        if k >= n_biomarker:
            ax.axis("off")
            continue

        if density:
            x_ = x_data[~np.isnan(y_data[:, k])]
            y_ = y_data[~np.isnan(y_data[:, k]), k]
            if np.var(x_) == 0:
                z = gaussian_kde(y_)(y_)
            else:
                xy = np.vstack([x_, y_])
                z = gaussian_kde(xy)(xy)
            idx = z.argsort()
            ax.scatter(x_[idx], y_[idx], c=z[idx], s=2, label="_nolegend_")
        else:
            ax.scatter(x_data, y_data[:, k], c="silver", s=2, label="_nolegend_")

        pred_line = []
        for i in range(2**n_covariate):
            pred_line.extend(
                ax.plot(
                    x_model,
                    y_model[res * i : (res * i + res), k],
                    c=cm(i),
                    lw=3,
                )
            )

        true_line = []
        for i in range(2**n_covariate_true):
            y_true = model_sigmoid(
                x_model,
                cov_dummy_true[res * i : (res * i + res)],
                params_true.loc[k],
            )
            true_line.extend(
                ax.plot(
                    x_model,
                    y_true,
                    c=cm(i),
                    lw=3,
                    ls="dashed",
                )
            )

        ax.set_xlabel("Disease Time (year)")
        ax.set_title(name_biomarkers[k], fontsize=15)

    if n_covariate > 0:
        legend_labels = [
            ", ".join(format(i, f"0{n_covariate}b")) for i in range(2**n_covariate)
        ]
        fig.legend(
            handles=pred_line,
            loc="center",
            framealpha=0,
            bbox_to_anchor=(1.1, 0.7),
            title="Pred\n" + ", ".join(name_covariates),
            labels=legend_labels,
        )

        legend_labels_true = [
            ", ".join(format(i, f"0{n_covariate_true}b"))
            for i in range(2**n_covariate_true)
        ]
        fig.legend(
            handles=true_line,
            loc="center",
            framealpha=0,
            bbox_to_anchor=(1.1, 0.3),
            title="True\n" + ", ".join(name_covariates_true),
            labels=legend_labels_true,
        )

    if save_file_path:
        fig.savefig(save_file_path, transparent=True, bbox_inches="tight")

    return fig

r_squared_plot(df, name_biomarkers, isSort=True, cutoff=0.1, save_file_path=None)

Generate a horizontal bar plot displaying the R-squared values of biomarkers.

Parameters:
  • df (DataFrame) –

    DataFrame containing the biomarker data.

  • name_biomarkers (list[str]) –

    List of column names representing the biomarkers.

  • isSort (bool, default: True ) –

    If True, sort biomarkers by R-squared values. Default is True.

  • cutoff (float, default: 0.1 ) –

    Cutoff value for highlighting specific R-squared values. Biomarkers with R-squared values greater than or equal to cutoff will be highlighted. Default is 0.1.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    Matplotlib figure object representing the generated plot.

Source code in sreftml\plots.py
def r_squared_plot(
    df: pd.DataFrame,
    name_biomarkers: list[str],
    isSort: bool = True,
    cutoff: float = 0.1,
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Generate a horizontal bar plot displaying the R-squared values of biomarkers.

    Args:
        df (pd.DataFrame): DataFrame containing the biomarker data.
        name_biomarkers (list[str]): List of column names representing the biomarkers.
        isSort (bool, optional): If True, sort biomarkers by R-squared values. Default is True.
        cutoff (float, optional): Cutoff value for highlighting specific R-squared values. Biomarkers with R-squared values greater than or equal to cutoff will be highlighted. Default is 0.1.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        fig (plt.Figure): Matplotlib figure object representing the generated plot.
    """
    # align the prediction columns with the order of name_biomarkers
    res = (
        df[name_biomarkers].values
        - df[[f"{biomarker}_pred" for biomarker in name_biomarkers]].values
    )
    res_var = np.nanvar(res, axis=0)
    df_var = np.nanvar(df[name_biomarkers].values, axis=0)
    r_squared = 1 - res_var / df_var

    cm = plt.get_cmap("tab10")
    fig = plt.figure(dpi=300, tight_layout=True)
    if cutoff > 0:
        plt.axvline(x=cutoff, ls="--", c="black")
        colors = [cm(1) if x >= cutoff else cm(0) for x in r_squared]
    else:
        colors = [cm(0) for _ in range(len(r_squared))]

    if isSort:
        rank = np.argsort(r_squared)
        plt.barh(
            [name_biomarkers[i] for i in rank],
            r_squared[rank],
            color=[colors[i] for i in rank],
        )
    else:
        plt.barh(name_biomarkers, r_squared, color=colors)

    plt.xlabel("r_squared")

    if save_file_path:
        plt.savefig(save_file_path, transparent=True)

    return fig
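
The DataFrame is expected to hold, for each biomarker, an observed column and a matching "_pred" column. A synthetic sketch with one well-predicted and one poorly predicted biomarker:

import numpy as np
import pandas as pd

from sreftml.plots import r_squared_plot

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(100, 2)), columns=["bm1", "bm2"])
df["bm1_pred"] = df["bm1"] + rng.normal(scale=0.3, size=100)  # good fit
df["bm2_pred"] = rng.normal(size=100)  # poor fit
fig = r_squared_plot(df, ["bm1", "bm2"], cutoff=0.1)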

residual_plot(df, name_biomarkers, ncol_max=4, save_file_path=None)

Generate a plot of residuals.

Parameters:
  • df (DataFrame) –

    Input DataFrame. It must contain TIME, offsetT, the observed biomarker values, and the predicted values (columns suffixed with "_pred").

  • name_biomarkers (list[str]) –

    List of biomarker names.

  • ncol_max (int, default: 4 ) –

    Maximum number of columns. Default is 4.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    The created matplotlib figure.

Source code in sreftml\plots.py
def residual_plot(
    df: pd.DataFrame,
    name_biomarkers: list[str],
    ncol_max: int = 4,
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Generate a plot of residuals.

    Args:
        df (pd.DataFrame): Input DataFrame. It must contain TIME, offsetT, the observed biomarker values, and the predicted values (columns suffixed with "_pred").
        name_biomarkers (list[str]): List of biomarker names.
        ncol_max (int, optional): Maximum number of columns. Default is 4.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        Figure: The created matplotlib figure.
    """
    if not "offsetT" in df.columns:
        warnings.warn(
            "offsetT does not exist in df. df must contain offsetT. Skip drawing residual plot."
        )
        return None
    if not all([f"{biomarker}_pred" in df.columns for biomarker in name_biomarkers]):
        warnings.warn(
            "Some of the prediction values are missing in df. df must contain prediction values of biomarkers. Skip drawing residual plot."
        )
        return None
    n_biomarker = len(name_biomarkers)
    n_row, n_col = n2mfrow(n_biomarker, ncol_max)
    x_data = df.TIME.values + df.offsetT.values

    y_res = (
        df[[f"{biomarker}_pred" for biomarker in name_biomarkers]].values
        - df[name_biomarkers].values
    )

    fig, axs = plt.subplots(
        n_row, n_col, figsize=(n_col * 3, n_row * 3), tight_layout=True, dpi=300
    )
    for k, ax in enumerate(axs.flat):
        if k >= n_biomarker:
            ax.axis("off")
            continue

        ax.scatter(x_data, y_res[:, k], s=2)
        ax.axhline(0, c="black", ls="--")
        ax.set_title(name_biomarkers[k], fontsize=15)
        ax.set_xlabel("Disease Time (year)")
        ax.set_ylabel("y_pred - y_obs")

    if save_file_path is not None:
        fig.savefig(save_file_path, transparent=True)

    return fig
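
A synthetic sketch; df must carry TIME, offsetT, the observed biomarkers, and their "_pred" columns (all values hypothetical):

import numpy as np
import pandas as pd

from sreftml.plots import residual_plot

rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({
    "TIME": rng.uniform(0, 5, size=n),
    "offsetT": rng.uniform(0, 10, size=n),
    "bm1": rng.normal(size=n),
    "bm2": rng.normal(size=n),
})
df["bm1_pred"] = df["bm1"] + rng.normal(scale=0.2, size=n)
df["bm2_pred"] = df["bm2"] + rng.normal(scale=0.2, size=n)
fig = residual_plot(df, ["bm1", "bm2"])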

scatter_matrix_plot(df, save_file_path=None)

Plot a scatter matrix with pairwise Pearson correlations.

Parameters:
  • df (DataFrame) –

    Input DataFrame.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • PairGrid

    sns.axisgrid.PairGrid: PairGrid object with the correlation plot.

Source code in sreftml\plots.py
def scatter_matrix_plot(
    df: pd.DataFrame, save_file_path: str | None = None
) -> sns.axisgrid.PairGrid:
    """
    Plot a scatter matrix with pairwise Pearson correlations.

    Args:
        df (pd.DataFrame): Input DataFrame.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        sns.axisgrid.PairGrid: PairGrid object with the correlation plot.
    """

    def corrfunc(x, y, **kwds):
        ax = plt.gca()
        ax.tick_params(bottom=False, top=False, left=False, right=False)
        sns.despine(ax=ax, bottom=True, top=True, left=True, right=True)
        r = x.corr(y, method="pearson")
        norm = plt.Normalize(-1, 1)
        facecolor = plt.get_cmap("seismic")(norm(r))
        ax.set_facecolor(facecolor)
        ax.set_alpha(0)
        lightness = (max(facecolor[:3]) + min(facecolor[:3])) / 2
        ax.annotate(
            f"{r:.2f}",
            xy=(0.5, 0.5),
            xycoords=ax.transAxes,
            color="white" if lightness < 0.7 else "black",
            size=26,
            ha="center",
            va="center",
        )

    g = sns.PairGrid(df)
    g.map_diag(sns.histplot, kde=False)
    g.map_lower(plt.scatter, s=2)
    g.map_upper(corrfunc)
    g.figure.tight_layout()

    if save_file_path:
        g.savefig(save_file_path)

    return g
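
A quick sketch; every column of df is included in the grid, so pass only the columns of interest:

import numpy as np
import pandas as pd

from sreftml.plots import scatter_matrix_plot

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(100, 3)), columns=["bm1", "bm2", "bm3"])
df["bm2"] += df["bm1"]  # induce some correlation
g = scatter_matrix_plot(df)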

scatter_matrix_plot_extra(df, save_file_path=None)

Plot a scatter matrix with pairwise Pearson correlations.

Parameters:
  • df (DataFrame) –

    Input DataFrame.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • PairGrid

    sns.axisgrid.PairGrid: PairGrid object with the correlation plot.

Source code in sreftml\plots.py
def scatter_matrix_plot_extra(
    df: pd.DataFrame, save_file_path: str | None = None
) -> sns.axisgrid.PairGrid:
    """
    Plot a scatter matrix with pairwise Pearson correlations.

    Args:
        df (pd.DataFrame): Input DataFrame.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        sns.axisgrid.PairGrid: PairGrid object with the correlation plot.
    """

    def corrfunc(x, y, **kwds):
        ax = plt.gca()
        ax.tick_params(bottom=False, top=False, left=False, right=False)
        sns.despine(ax=ax, bottom=True, top=True, left=True, right=True)
        r = x.corr(y, method="pearson")
        norm = plt.Normalize(-1, 1)
        facecolor = plt.get_cmap("seismic")(norm(r))
        ax.set_facecolor(facecolor)
        ax.set_alpha(0)
        lightness = (max(facecolor[:3]) + min(facecolor[:3])) / 2
        ax.annotate(
            f"{r:.2f}",
            xy=(0.5, 0.5),
            xycoords=ax.transAxes,
            color="white" if lightness < 0.7 else "black",
            size=26,
            ha="center",
            va="center",
        )

    g = sns.PairGrid(df)
    g.map_diag(sns.histplot, kde=False)
    g.map_lower(plt.scatter, s=2)
    g.map_upper(corrfunc)
    g.figure.tight_layout()

    if save_file_path:
        g.savefig(save_file_path)

    return g

shap_plots(shap_exp_model_1, ncol_max=4, save_dir_path=None)

Plot the SHAP values of model 1.

Parameters:
  • shap_exp_model_1 (Explanation) –

    The SHAP explanation for model 1.

  • ncol_max (int, default: 4 ) –

    Maximum number of columns for subplots. Defaults to 4.

  • save_dir_path (str, default: None ) –

    The directory where the plots will be saved. Defaults to None.

Returns:
  • tuple[Figure, Figure, Figure]

    tuple[plt.Figure, plt.Figure, plt.Figure]: Plot objects for shap bar, beeswarm and dependence plot.

Source code in sreftml\plots.py
def shap_plots(
    shap_exp_model_1: shap.Explanation,
    ncol_max: int = 4,
    save_dir_path: str | None = None,
) -> tuple[plt.Figure, plt.Figure, plt.Figure]:
    """
    Plot the SHAP values of model 1.

    Args:
        shap_exp_model_1 (shap.Explanation): The SHAP explanation for model 1.
        ncol_max (int, optional): Maximum number of columns for subplots. Defaults to 4.
        save_dir_path (str, optional): The directory where the plots will be saved. Defaults to None.

    Returns:
        tuple[plt.Figure, plt.Figure, plt.Figure]: Plot objects for shap bar, beeswarm and dependence plot.
    """
    bar_plot = plt.figure(figsize=(5, 5), dpi=300, tight_layout=True)
    shap.plots.bar(shap_exp_model_1, show=False)
    plt.title("model 1")
    if save_dir_path:
        plt.savefig(save_dir_path + "shap_bar_model_1.png", transparent=True)

    beeswarm_plot = plt.figure(figsize=(5, 5), dpi=300, tight_layout=True)
    shap.plots.beeswarm(shap_exp_model_1, show=False)
    if save_dir_path:
        plt.savefig(save_dir_path + "shap_beeswarm_model_1.png", transparent=True)

    n_row, n_col = n2mfrow(shap_exp_model_1.shape[1], ncol_max=ncol_max)
    fig, axs = plt.subplots(
        n_row,
        n_col,
        figsize=(n_col * 4, n_row * 3),
        tight_layout=True,
        dpi=300,
    )
    for k, ax in enumerate(axs.flat):
        if k >= shap_exp_model_1.shape[1]:
            ax.axis("off")
            continue
        shap.plots.scatter(
            shap_exp_model_1[:, k],
            color=shap_exp_model_1,
            x_jitter=0.01,
            ax=ax,
            show=False,
        )
    fig.suptitle("model 1")
    if save_dir_path:
        fig.savefig(save_dir_path + "shap_dependence_model_1.png", transparent=True)

    return bar_plot, beeswarm_plot, fig

single_panel_scatter_plot(df, x_col, y_col, hue=None, duplicate_key=None, density=False, identity=False, save_file_path=None)

Draw a scatter plot using a single panel.

Parameters:
  • df (DataFrame) –

    Input DataFrame.

  • x_col (str) –

    X-axis column in df.

  • y_col (str) –

    Y-axis column in df.

  • hue (str | None, default: None ) –

    Column to stratify the plot. Defaults to None.

  • duplicate_key (list[str] | str | None, default: None ) –

    Specify the column name(s) from which duplicates are to be removed. Defaults to None.

  • density (bool, default: False ) –

    Whether to plot density. Defaults to False.

  • identity (bool, default: False ) –

    Whether to plot identity line. Defaults to False.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • FacetGrid

    sns.axisgrid.FacetGrid: FacetGrid object with the scatter plot.

Source code in sreftml\plots.py
def single_panel_scatter_plot(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    hue: str | None = None,
    duplicate_key: list[str] | str | None = None,
    density: bool = False,
    identity: bool = False,
    save_file_path: str | None = None,
) -> sns.axisgrid.FacetGrid:
    """
    Draw a scatter plot using a single panel.

    Args:
        df (pd.DataFrame): Input DataFrame.
        x_col (str): X-axis column in df.
        y_col (str): Y-axis column in df.
        hue (str | None, optional): Column to stratify the plot. Defaults to None.
        duplicate_key (list[str] | str | None, optional): Specify the column name(s) from which duplicates are to be removed. Defaults to None.
        density (bool, optional): Whether to plot density. Defaults to False.
        identity (bool, optional): Whether to plot identity line. Defaults to False.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        sns.axisgrid.FacetGrid: FacetGrid object with the scatter plot.
    """
    if density:
        hue = None
        warnings.warn("Since density is True, the hue option is ignored.")

    if hue:
        df_ = clean_duplicate(df, [x_col, y_col, hue], duplicate_key)
        unique_hues = np.sort(df_[hue].unique())
        line_kws_ = None
    else:
        df_ = clean_duplicate(df, [x_col, y_col], duplicate_key)
        unique_hues = [None]
        line_kws_ = {"color": "red"}

    scatter_kws_ = {"alpha": 0.5, "s": 20, "edgecolor": "none"}
    if density:
        xy = df_[[x_col, y_col]].values.T
        z = gaussian_kde(xy)(xy)
        scatter_kws_.update({"c": z, "color": None, "cmap": "viridis"})

    g = sns.lmplot(
        data=df_,
        x=x_col,
        y=y_col,
        hue=hue,
        hue_order=unique_hues,
        scatter_kws=scatter_kws_,
        line_kws=line_kws_,
    )
    g.figure.set_dpi(300)

    if identity:
        if df[y_col].max() < df[x_col].min() or df[x_col].max() < df[y_col].min():
            warnings.warn(
                f"The data range of {x_col} and {y_col} is not covered, although idenntity=True. Skip drawing of identity line."
            )
        else:
            min_ = df[[x_col, y_col]].min().max()
            max_ = df[[x_col, y_col]].max().min()
            g.axes[0, 0].plot([min_, max_], [min_, max_], "k--")

    if hue:
        g.axes[0, 0].legend(
            ["_nolegend_", "dummy text", "_nolegned_"] * len(unique_hues)
        )
        for idx, h in enumerate(unique_hues):
            df_hue = df_.loc[df_[hue] == h]
            label_line = get_regression_line_label(df_hue[x_col], df_hue[y_col])
            g.axes[0, 0].get_legend().get_texts()[idx].set_text(label_line)
    else:
        label_line = get_regression_line_label(df_[x_col], df_[y_col])
        g.axes[0, 0].legend(labels=["_nolegend_", label_line])

    if save_file_path is not None:
        plt.savefig(save_file_path, transparent=True, dpi=300)

    return g

surv_analysis_plot(fit_model, ci_show=True, only_best=True, title=None, xlabel='Disease Time (year)', save_dir_path=None)

Generate survival analysis plots.

Parameters:
  • fit_model (dict) –

    A dictionary of survival analysis objects.

  • ci_show (bool, default: True ) –

    Whether to show confidence intervals. Defaults to True.

  • only_best (bool, default: True ) –

    Whether to plot only the best model. Defaults to True.

  • title (str | None, default: None ) –

    Title for each plot. If None, the title from fit_model is used.

  • xlabel (str, default: 'Disease Time (year)' ) –

    X-axis label for each plot. Defaults to "Disease Time (year)".

  • save_dir_path (str, default: None ) –

    The directory where the plots will be saved. Defaults to None.

Returns:
  • tuple[Figure, Figure, Figure]

    tuple[plt.Figure, plt.Figure, plt.Figure]: Plot objects for the survival function, cumulative hazard function, and hazard function.

Source code in sreftml\plots.py
def surv_analysis_plot(
    fit_model: dict,
    ci_show: bool = True,
    only_best: bool = True,
    title: str | None = None,
    xlabel: str = "Disease Time (year)",
    save_dir_path: str | None = None,
) -> tuple[plt.Figure, plt.Figure, plt.Figure]:
    """
    Generate survival analysis plots.

    Args:
        fit_model (dict): A dictionary of survival analysis objects.
        ci_show (bool, optional): Whether to show confidence intervals. Defaults to True.
        only_best (bool, optional): Whether to plot only the best model. Defaults to True.
        title (str | None, optional): Title for each plot. If None, the title from fit_model is used.
        xlabel (str, optional): X-axis label for each plot. Defaults to "Disease Time (year)".
        save_dir_path (str, optional): The directory where the plots will be saved. Defaults to None.

    Returns:
        tuple[plt.Figure, plt.Figure, plt.Figure]: Plot objects for the survival function, cumulative hazard function, and hazard function.
    """
    if title is None:
        title = fit_model["title"]

    fit_model_parametric = {
        key: value
        for key, value in fit_model.items()
        if key not in ["title", "kmf", "naf"]
    }
    if only_best:
        aics = [i.AIC_ for i in fit_model_parametric.values()]
        best_model = list(fit_model_parametric.keys())[aics.index(min(aics))]
        fit_model_parametric = {best_model: fit_model_parametric[best_model]}

    surv_plot = plt.figure(figsize=(5, 5), dpi=300)
    fit_model["kmf"].plot_survival_function(ci_show=ci_show, lw=2)
    for k in fit_model_parametric.values():
        k.plot_survival_function(ci_show=ci_show, lw=2)
    plt.xlabel(xlabel)
    plt.ylabel("Survival Function")
    plt.title(title)
    if save_dir_path:
        plt.savefig(save_dir_path + "surv_func.png", transparent=True)

    cumhaz_plot = plt.figure(figsize=(5, 5), dpi=300)
    fit_model["naf"].plot_cumulative_hazard(ci_show=ci_show, lw=2)
    for k in fit_model_parametric.values():
        k.plot_cumulative_hazard(ci_show=ci_show, lw=2)
    plt.xlabel(xlabel)
    plt.ylabel("Cumlative Hazard Function")
    plt.title(title)
    if save_dir_path:
        plt.savefig(save_dir_path + "cumhaz_func.png", transparent=True)

    haz_plot = plt.figure(figsize=(5, 5), dpi=300)
    fit_model["naf"].plot_hazard(bandwidth=2, ci_show=ci_show, lw=2)
    for k in fit_model_parametric.values():
        k.plot_hazard(ci_show=ci_show, lw=2)
    plt.xlabel(xlabel)
    plt.ylabel("Hazard Function")
    plt.title(title)
    if save_dir_path:
        plt.savefig(save_dir_path + "haz_func.png", transparent=True)

    return surv_plot, cumhaz_plot, haz_plot
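
The fit_model dictionary is assumed to hold a "title" string, a Kaplan-Meier fitter under "kmf", a Nelson-Aalen fitter under "naf", and fitted parametric lifelines models under the remaining keys (those must expose AIC_). A sketch using lifelines with synthetic data:

import numpy as np
from lifelines import KaplanMeierFitter, NelsonAalenFitter, WeibullFitter

from sreftml.plots import surv_analysis_plot

rng = np.random.default_rng(0)
durations = rng.exponential(10, size=200)
events = rng.integers(0, 2, size=200)

fit_model = {
    "title": "example",
    "kmf": KaplanMeierFitter().fit(durations, events),
    "naf": NelsonAalenFitter().fit(durations, events),
    "weibull": WeibullFitter().fit(durations, events),
}
surv_plot, cumhaz_plot, haz_plot = surv_analysis_plot(fit_model)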

var_y_plot(sreft, name_biomarkers, save_file_path=None)

Generate a plot of var_y.

Parameters:
  • sreft (Model) –

    The trained SReFT model.

  • name_biomarkers (list[str]) –

    List of biomarker names.

  • save_file_path (str, default: None ) –

    The path where the plot will be saved. Defaults to None.

Returns:
  • Figure

    plt.Figure: The plotted figure.

Source code in sreftml\plots.py
def var_y_plot(
    sreft: tf.keras.Model,
    name_biomarkers: list[str],
    save_file_path: str | None = None,
) -> plt.Figure:
    """
    Generate a plot of var_y.

    Args:
        sreft (tf.keras.Model): The trained SReFT model.
        name_biomarkers (list[str]): List of biomarker names.
        save_file_path (str, optional): The path where the plot will be saved. Defaults to None.

    Returns:
        plt.Figure: The plotted figure.
    """
    rank = np.argsort(np.exp(sreft.lnvar_y))
    fig = plt.figure(dpi=300, tight_layout=True)
    plt.barh([name_biomarkers[i] for i in rank], np.exp(sreft.lnvar_y)[rank])
    plt.gca().invert_yaxis()
    plt.xlabel("var_y")

    if save_file_path:
        plt.savefig(save_file_path, transparent=True)

    return fig
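
A sketch using a freshly constructed SReFT instance (import path assumed from the file layout). lnvar_y is initialized to zeros, so every bar equals 1.0 here; with a trained model the bars show the estimated residual variance of each biomarker:

from sreftml.plots import var_y_plot
from sreftml.sreftml_model import SReFT

sreft = SReFT(output_dim=3, latent_dim_model_1=8, latent_dim_model_y=8)
fig = var_y_plot(sreft, ["bm1", "bm2", "bm3"])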

SReFT

Bases: Model

A model class that extends tf.keras.Model for SReFT_ML.

Attributes:
  • output_dim (int) –

    The dimension of the output.

  • latent_dim_model_1 (int) –

    The latent dimension of model_1.

  • latent_dim_model_y (int) –

    The latent dimension of model_y.

  • activation_model_1_mid (str) –

    The activation function for the hidden layer of model_1.

  • activation_model_1_out (str) –

    The activation function for the output layer of model_1.

  • activation_model_y_mid (str) –

    The activation function for the hidden layer of model_y.

  • offsetT_min (float) –

    The minimum value of offsetT.

  • offsetT_max (float) –

    The maximum value of offsetT.

  • lnvar_y (Variable) –

    The log observation variance of y.

  • model_1 (Sequential) –

    A keras model for estimating offsetT.

  • model_y (Sequential) –

    A keras model for estimating prediction.

Source code in sreftml\sreftml_model.py
class SReFT(tf.keras.Model):
    """
    A model class that extends tf.keras.Model for SReFT_ML.

    Attributes:
        output_dim (int): The dimension of the output.
        latent_dim_model_1 (int): The latent dimension of model_1.
        latent_dim_model_y (int): The latent dimension of model_y.
        activation_model_1_mid (str): The activation function for the hidden layer of model_1.
        activation_model_1_out (str): The activation function for the output layer of model_1.
        activation_model_y_mid (str): The activation function for the hidden layer of model_y.
        offsetT_min (float): The minimum value of offsetT.
        offsetT_max (float): The maximum value of offsetT.
        lnvar_y (tf.Variable): The log observation variance of y.
        model_1 (tf.keras.Sequential): A keras model for estimating offsetT.
        model_y (tf.keras.Sequential): A keras model for estimating prediction.

    def __init__(
        self,
        output_dim: int,
        latent_dim_model_1: int,
        latent_dim_model_y: int,
        activation_model_1_mid: str = "sigmoid",
        activation_model_1_out: str = "softplus",
        activation_model_y_mid: str = "tanh",
        offsetT_min: float = -np.inf,
        offsetT_max: float = np.inf,
        random_state: int | None = None,
    ) -> None:
        """
        Initialize a new instance of SReFT_ML.

        Args:
            output_dim (int): The dimension of the output.
            latent_dim_model_1 (int): The dimension of the latent layer of model_1.
            latent_dim_model_y (int): The dimension of the latent layer of model_y.
            activation_model_1_mid (str, optional): The activation function to use. Defaults to "sigmoid".
            activation_model_1_out (str, optional): The activation function to use. Defaults to "softplus".
            activation_model_y_mid (str, optional): The activation function to use. Defaults to "tanh".
            offsetT_min (float, optional): The minimum value of offsetT. Defaults to -np.inf.
            offsetT_max (float, optional): The maximum value of offsetT. Defaults to np.inf.
            random_state (int | None, optional): The seed for random number generation. Defaults to None.
        """
        super(SReFT, self).__init__()

        initializer = tf.keras.initializers.GlorotUniform(seed=random_state)
        tf.random.set_seed(random_state)

        self.output_dim = int(output_dim)
        self.latent_dim_model_1 = int(latent_dim_model_1)
        self.latent_dim_model_y = int(latent_dim_model_y)
        self.activation_model_1_mid = activation_model_1_mid
        self.activation_model_1_out = activation_model_1_out
        self.activation_model_y_mid = activation_model_y_mid

        self.offsetT_min = offsetT_min
        self.offsetT_max = offsetT_max

        self.lnvar_y = tf.Variable(tf.zeros(self.output_dim))

        self.model_1 = tf.keras.Sequential(name="estimate_offsetT")
        self.model_1.add(
            tf.keras.layers.Dense(
                self.latent_dim_model_1,
                activation=self.activation_model_1_mid,
                kernel_initializer=initializer,
            )
        )
        self.model_1.add(
            tf.keras.layers.Dense(
                1,
                activation=self.activation_model_1_out,
                kernel_initializer=initializer,
            )
        )

        self.model_y = tf.keras.Sequential(name="estimate_prediction")
        self.model_y.add(
            tf.keras.layers.Dense(
                self.latent_dim_model_y,
                activation=self.activation_model_y_mid,
                kernel_initializer=initializer,
            )
        )
        self.model_y.add(
            tf.keras.layers.Dense(
                self.output_dim, activation=None, kernel_initializer=initializer
            )
        )

    def call(
        self,
        inputs: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
        training: bool = False,
        **kwargs,
    ) -> tf.Tensor:
        """
        Call the model with the given inputs.

        Args:
            inputs (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): The inputs for the model.
            training (bool, optional): Whether the model is in training mode. Defaults to False.

        Returns:
            tf.Tensor: The predicted y values.
        """
        (input_x, input_cov, input_m, input_y) = inputs
        input1 = tf.concat((input_m, input_cov), axis=-1, name="concat")
        offset = self.model_1(input1, training=training)
        offset = tf.clip_by_value(
            offset, self.offsetT_min, self.offsetT_max, name="clip"
        )
        dis_time = tf.add(input_x, offset, name="add")

        input2 = tf.concat((dis_time, input_cov), axis=-1, name="concat")
        y_pred = self.model_y(input2, training=training)

        obj = utilities.tf_compute_negative_log_likelihood(
            input_y, y_pred, self.lnvar_y
        )
        self.add_loss(tf.reduce_sum(obj))
        self.add_metric(tf.reduce_mean(obj), name="loss")

        return y_pred

    def build_graph(self, shapes: tuple[int, int, int, int]) -> tf.keras.Model:
        """
        Build the computational graph for the model.

        Args:
            shapes (tuple[int, int, int, int]): The shapes of the inputs.

        Returns:
            tf.keras.Model: The model with the built computational graph.
        """
        input_x = tf.keras.layers.Input(shape=shapes[0], name="time")
        input_cov = tf.keras.layers.Input(shape=shapes[1], name="covariate")
        input_m = tf.keras.layers.Input(shape=shapes[2], name="feature")
        input_y = tf.keras.layers.Input(shape=shapes[3], name="observation")

        return tf.keras.Model(
            inputs=[input_x, input_cov, input_m],
            outputs=self.call((input_x, input_cov, input_m, input_y)),
        )
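
A minimal construction sketch; all dimensions and bounds below are illustrative assumptions, not recommended values. Because the loss is registered inside call via add_loss, compile needs no loss argument:

import tensorflow as tf

sreft = SReFT(
    output_dim=4,           # one output per biomarker (assumption)
    latent_dim_model_1=8,   # hidden units of model_1 (assumption)
    latent_dim_model_y=8,   # hidden units of model_y (assumption)
    offsetT_min=0.0,        # clip offsetT to be non-negative
    offsetT_max=40.0,
    random_state=42,
)
sreft.compile(optimizer=tf.keras.optimizers.Adam(1e-3))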

__init__(output_dim, latent_dim_model_1, latent_dim_model_y, activation_model_1_mid='sigmoid', activation_model_1_out='softplus', activation_model_y_mid='tanh', offsetT_min=-np.inf, offsetT_max=np.inf, random_state=None)

Initialize a new instance of SReFT_ML.

Parameters:
  • output_dim (int) –

    The dimension of the output.

  • latent_dim_model_1 (int) –

    The dimension of the latent layer of model_1.

  • latent_dim_model_y (int) –

    The dimension of the latent layer of model_y.

  • activation_model_1_mid (str, default: 'sigmoid' ) –

    The activation function to use. Defaults to "sigmoid".

  • activation_model_1_out (str, default: 'softplus' ) –

    The activation function to use. Defaults to "softplus".

  • activation_model_y_mid (str, default: 'tanh' ) –

    The activation function to use. Defaults to "tanh".

  • offsetT_min (float, default: -inf ) –

    The minimum value of offsetT. Defaults to -np.inf.

  • offsetT_max (float, default: inf ) –

    The maximum value of offsetT. Defaults to np.inf.

  • random_state (int | None, default: None ) –

    The seed for random number generation. Defaults to None.

sreftml\sreftml_model.py
def __init__(
    self,
    output_dim: int,
    latent_dim_model_1: int,
    latent_dim_model_y: int,
    activation_model_1_mid: str = "sigmoid",
    activation_model_1_out: str = "softplus",
    activation_model_y_mid: str = "tanh",
    offsetT_min: float = -np.inf,
    offsetT_max: float = np.inf,
    random_state: int | None = None,
) -> None:
    """
    Initialize a new instance of SReFT_ML.

    Args:
        output_dim (int): The dimension of the output.
        latent_dim_model_1 (int): The dimension of the latent layer of model_1.
        latent_dim_model_y (int): The dimension of the latent layer of model_y.
        activation_model_1_mid (str, optional): The activation function to use. Defaults to "sigmoid".
        activation_model_1_out (str, optional): The activation function to use. Defaults to "softplus".
        activation_model_y_mid (str, optional): The activation function to use. Defaults to "tanh".
        offsetT_min (float, optional): The minimum value of offsetT. Defaults to -np.inf.
        offsetT_max (float, optional): The maximum value of offsetT. Defaults to np.inf.
        random_state (int | None, optional): The seed for random number generation. Defaults to None.
    """
    super(SReFT, self).__init__()

    initializer = tf.keras.initializers.GlorotUniform(seed=random_state)
    tf.random.set_seed(random_state)

    self.output_dim = int(output_dim)
    self.latent_dim_model_1 = int(latent_dim_model_1)
    self.latent_dim_model_y = int(latent_dim_model_y)
    self.activation_model_1_mid = activation_model_1_mid
    self.activation_model_1_out = activation_model_1_out
    self.activation_model_y_mid = activation_model_y_mid

    self.offsetT_min = offsetT_min
    self.offsetT_max = offsetT_max

    self.lnvar_y = tf.Variable(tf.zeros(self.output_dim))

    self.model_1 = tf.keras.Sequential(name="estimate_offsetT")
    self.model_1.add(
        tf.keras.layers.Dense(
            self.latent_dim_model_1,
            activation=self.activation_model_1_mid,
            kernel_initializer=initializer,
        )
    )
    self.model_1.add(
        tf.keras.layers.Dense(
            1,
            activation=self.activation_model_1_out,
            kernel_initializer=initializer,
        )
    )

    self.model_y = tf.keras.Sequential(name="estimate_prediction")
    self.model_y.add(
        tf.keras.layers.Dense(
            self.latent_dim_model_y,
            activation=self.activation_model_y_mid,
            kernel_initializer=initializer,
        )
    )
    self.model_y.add(
        tf.keras.layers.Dense(
            self.output_dim, activation=None, kernel_initializer=initializer
        )
    )

build_graph(shapes)

Build the computational graph for the model.

Parameters:
  • shapes (tuple[int, int, int, int]) –

    The shapes of the inputs.

Returns:
  • Model

    tf.keras.Model: The model with the built computational graph.

sreftml\sreftml_model.py
def build_graph(self, shapes: tuple[int, int, int, int]) -> tf.keras.Model:
    """
    Build the computational graph for the model.

    Args:
        shapes (tuple[int, int, int, int]): The shapes of the inputs.

    Returns:
        tf.keras.Model: The model with the built computational graph.
    """
    input_x = tf.keras.layers.Input(shape=shapes[0], name="time")
    input_cov = tf.keras.layers.Input(shape=shapes[1], name="covariate")
    input_m = tf.keras.layers.Input(shape=shapes[2], name="feature")
    input_y = tf.keras.layers.Input(shape=shapes[3], name="observation")

    return tf.keras.Model(
        inputs=[input_x, input_cov, input_m],
        outputs=self.call((input_x, input_cov, input_m, input_y)),
    )

call(inputs, training=False, **kwargs)

Call the model with the given inputs.

Parameters:
  • inputs (tuple[ndarray, ndarray, ndarray, ndarray]) –

    The inputs for the model.

  • training (bool, default: False ) –

    Whether the model is in training mode. Defaults to False.

Returns:
  • Tensor

    tf.Tensor: The predicted y values.


sreftml\sreftml_model.py
def call(
    self,
    inputs: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    training: bool = False,
    **kwargs,
) -> tf.Tensor:
    """
    Call the model with the given inputs.

    Args:
        inputs (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): The inputs for the model.
        training (bool, optional): Whether the model is in training mode. Defaults to False.

    Returns:
        tf.Tensor: The predicted y values.
    """
    (input_x, input_cov, input_m, input_y) = inputs
    input1 = tf.concat((input_m, input_cov), axis=-1, name="concat")
    offset = self.model_1(input1, training=training)
    offset = tf.clip_by_value(
        offset, self.offsetT_min, self.offsetT_max, name="clip"
    )
    dis_time = tf.add(input_x, offset, name="add")

    input2 = tf.concat((dis_time, input_cov), axis=-1, name="concat")
    y_pred = self.model_y(input2, training=training)

    obj = utilities.tf_compute_negative_log_likelihood(
        input_y, y_pred, self.lnvar_y
    )
    self.add_loss(tf.reduce_sum(obj))
    self.add_metric(tf.reduce_mean(obj), name="loss")

    return y_pred
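
A minimal call sketch, reusing the sreft instance from the construction sketch above; all shapes are assumptions (100 records, 1 time column, 2 covariates, 8 model_1 features, 4 biomarkers), and the arrays are random placeholders:

import numpy as np

x = np.random.rand(100, 1).astype("float32")
cov = np.random.rand(100, 2).astype("float32")
m = np.random.rand(100, 8).astype("float32")
y = np.random.rand(100, 4).astype("float32")
y_pred = sreft((x, cov, m, y), training=False)  # tf.Tensor of shape (100, 4)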

hp_search_for_sreftml(df, scaled_features, grid_dict, n_grid_sample=0, n_splits=3, random_seed=42, callbacks=None, epochs=9999, batch_size=256)

Perform hyperparameter search for the SReFT_ML.

Parameters:
  • df (DataFrame) –

    Input dataframe containing the data.

  • scaled_features (tuple) –

    Tuple of scaled features. Pass x, cov, m, and y in that order.

  • grid_dict (dict[str, list]) –

    Dictionary of hyperparameter names and the corresponding values to be tested.

  • n_grid_sample (int, default: 0 ) –

    Number of samples to select randomly from the grid. Greater than 0 for random search and 0 or less for grid search. Defaults to 0.

  • n_splits (int, default: 3 ) –

    Number of splits for cross-validation; must be at least 2. Defaults to 3.

  • random_seed (int, default: 42 ) –

    Random seed for reproducibility. Defaults to 42.

  • callbacks (list[any] | None, default: None ) –

    Callbacks to be used during model training. Defaults to None.

  • epochs (int, default: 9999 ) –

    The number of epochs passed to model fitting. Defaults to 9999.

  • batch_size (int, default: 256 ) –

    The batch size used during training. Defaults to 256.

Returns:
  • DataFrame

    pd.DataFrame: Dataframe containing the hyperparameters and corresponding scores.

sreftml\sreftml_model.py
def hp_search_for_sreftml(
    df: pd.DataFrame,
    scaled_features: tuple,
    grid_dict: dict[str, list],
    n_grid_sample: int = 0,
    n_splits: int = 3,
    random_seed: int = 42,
    callbacks: list[any] | None = None,
    epochs: int = 9999,
    batch_size: int = 256,
) -> pd.DataFrame:
    """
    Perform hyperparameter search for the SReFT_ML.

    Args:
        df (pd.DataFrame): Input dataframe containing the data.
        scaled_features (tuple): Tuple of scaled features. Pass x, cov, m, and y in that order.
        grid_dict (dict[str, list]): Dictionary of hyperparameter names and the corresponding values to be tested.
        n_grid_sample (int, optional): Number of samples to select randomly from the grid. Greater than 0 for random search and 0 or less for grid search. Defaults to 0.
        n_splits (int, optional): Number of splits for cross-validation; must be at least 2. Defaults to 3.
        random_seed (int, optional): Random seed for reproducibility. Defaults to 42.
        callbacks (list[any] | None, optional): Callbacks to be used during model training. Defaults to None.
        epochs (int, optional): The number of epochs passed to model fitting. Defaults to 9999.
        batch_size (int, optional): The batch size used during training. Defaults to 256.

    Returns:
        pd.DataFrame: Dataframe containing the hyperparameters and corresponding scores.

    """
    grid_prms = [i for i in itertools.product(*grid_dict.values())]
    df_grid = pd.DataFrame(grid_prms, columns=grid_dict.keys())
    if n_grid_sample > 0:
        df_grid = df_grid.sample(min(int(n_grid_sample), len(df_grid)))

    x_scaled, cov_scaled, m_scaled, y_scaled = scaled_features

    scores = []
    gkf = GroupKFold(n_splits=n_splits)
    for i, (tmp_train_idx, tmp_vali_idx) in enumerate(gkf.split(X=df, groups=df.ID)):
        for tmp_grid, grid_items in df_grid.iterrows():
            current_iter = i * len(df_grid) + (tmp_grid + 1)
            current_hp = ", ".join([f"{j}: {grid_items[j]}" for j in grid_items.keys()])
            print(f"\n({current_iter}/{n_splits * len(df_grid)}) {current_hp} -----")

            tmp_sreft = SReFT(
                output_dim=y_scaled.shape[1],
                latent_dim_model_1=m_scaled.shape[1],
                latent_dim_model_y=y_scaled.shape[1],
                activation_model_1_mid=grid_items["activation_model_1_mid"],
                activation_model_1_out=grid_items["activation_model_1_out"],
                activation_model_y_mid=grid_items["activation_model_y_mid"],
                random_state=random_seed,
            )
            tmp_sreft.compile(optimizer=tf.keras.optimizers.Adam(grid_items["adam_lr"]))
            tmp_sreft.fit(
                (
                    x_scaled[tmp_train_idx, :],
                    cov_scaled[tmp_train_idx, :],
                    m_scaled[tmp_train_idx, :],
                    y_scaled[tmp_train_idx, :],
                ),
                y_scaled[tmp_train_idx, :],
                validation_data=(
                    (
                        x_scaled[tmp_vali_idx, :],
                        cov_scaled[tmp_vali_idx, :],
                        m_scaled[tmp_vali_idx, :],
                        y_scaled[tmp_vali_idx, :],
                    ),
                    y_scaled[tmp_vali_idx, :],
                ),
                batch_size=batch_size,
                epochs=epochs,
                verbose=0,
                callbacks=callbacks,
            )

            y_pred = tmp_sreft(
                (
                    x_scaled[tmp_vali_idx, :],
                    cov_scaled[tmp_vali_idx, :],
                    m_scaled[tmp_vali_idx, :],
                    y_scaled[tmp_vali_idx, :],
                )
            )
            temp_score = utilities.np_compute_negative_log_likelihood(
                y_scaled[tmp_vali_idx, :], y_pred, tmp_sreft.lnvar_y
            )
            scores.append(np.nanmean(temp_score))

    df_grid["score"] = np.array(scores).reshape(n_splits, -1).mean(axis=0).round(3)

    return df_grid
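
A minimal search sketch; df (which must contain an ID column for GroupKFold) and the scaled arrays are assumptions. The grid keys follow the names read from grid_items in the source, so all four must be present:

grid_dict = {
    "activation_model_1_mid": ["sigmoid", "tanh"],
    "activation_model_1_out": ["softplus"],
    "activation_model_y_mid": ["tanh"],
    "adam_lr": [1e-3, 1e-4],
}
early_stop = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=50)
df_scores = hp_search_for_sreftml(
    df,
    (x_scaled, cov_scaled, m_scaled, y_scaled),
    grid_dict,
    n_splits=3,
    callbacks=[early_stop],
)
print(df_scores.sort_values("score").head())  # lower score (mean NLL) is better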

calc_shap_explanation(sreft, feature_names, cov_scaled, m_scaled)

Calculate the SHAP values for model 1.

Parameters:
  • sreft (Model) –

    The model for which to calculate SHAP values.

  • feature_names (list[str]) –

    Provide the column names for 'm' and 'cov'. 'm' comes first, followed by 'cov'.

  • cov_scaled (ndarray) –

    The scaled covariate values.

  • m_scaled (ndarray) –

    The scaled m values.

Returns:
  • Explanation

    shap.Explanation: The explanation of SHAP values.

sreftml\utilities.py
def calc_shap_explanation(
    sreft: tf.keras.Model,
    feature_names: list[str],
    cov_scaled: np.ndarray,
    m_scaled: np.ndarray,
) -> shap.Explanation:
    """
    Calculate the SHAP values for model 1.

    Args:
        sreft (tf.keras.Model): The model for which to calculate SHAP values.
        feature_names (list[str]): Provide the column names for 'm' and 'cov'. 'm' comes first, followed by 'cov'.
        cov_scaled (np.ndarray): The scaled covariate values.
        m_scaled (np.ndarray): The scaled m values.

    Returns:
        shap.Explanation: The explanation of SHAP values.
    """
    input1 = np.concatenate((m_scaled, cov_scaled), axis=-1)
    explainer_model_1 = shap.Explainer(
        sreft.model_1,
        input1,
        algorithm="permutation",
        seed=42,
        feature_names=feature_names,
    )
    shap_value_model_1 = explainer_model_1(input1)
    shap_exp_model_1 = shap.Explanation(
        shap_value_model_1.values,
        shap_value_model_1.base_values[0][0],
        shap_value_model_1.data,
        feature_names=feature_names,
    )

    return shap_exp_model_1
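
A minimal usage sketch, assuming m and cov are the feature and covariate DataFrames (e.g. from split_data_for_sreftml) and sreft is a trained model:

feature_names = list(m.columns) + list(cov.columns)  # 'm' first, then 'cov'
shap_exp = calc_shap_explanation(sreft, feature_names, cov_scaled, m_scaled)
shap.plots.bar(shap_exp)  # global feature importance for the offsetT prediction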

calculate_offsetT_prediction(sreft, df, scaled_features, scaler_y, name_biomarkers)

Calculate offsetT and prediction value of biomarkers.

Parameters:
  • sreft (Model) –

    The trained SReFT model.

  • df (DataFrame) –

    The input DataFrame.

  • scaled_features (tuple[ndarray, ndarray, ndarray, ndarray]) –

    The scaled features. Pass x, cov, m, and y in that order.

  • scaler_y (StandardScaler) –

    The scaler for y.

  • name_biomarkers (list[str]) –

    List of biomarker names.

Returns:
  • DataFrame

    pd.DataFrame: The DataFrame including the columns of the input DataFrame, offsetT and the prediction values.

sreftml\utilities.py
def calculate_offsetT_prediction(
    sreft: tf.keras.Model,
    df: pd.DataFrame,
    scaled_features: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
    scaler_y: sp.StandardScaler,
    name_biomarkers: list[str],
) -> pd.DataFrame:
    """
    Calculate offsetT and prediction value of biomarkers.

    Args:
        sreft (tf.keras.Model): The trained SReFT model.
        df (pd.DataFrame): The input DataFrame.
        scaled_features (tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]): The scaled features. Pass x, cov, m, and y in that order.
        scaler_y (sp.StandardScaler): The scaler for y.
        name_biomarkers (list[str]): List of biomarker names.

    Returns:
        pd.DataFrame: The DataFrame including the columns of the input DataFrame, offsetT and the prediction values.
    """
    df_ = df.copy()
    x_scaled, cov_scaled, m_scaled, y_scaled = scaled_features
    offsetT = sreft.model_1(np.concatenate((m_scaled, cov_scaled), axis=-1))
    y_pred = pd.DataFrame(
        scaler_y.inverse_transform(sreft(scaled_features)),
        columns=[f"{biomarker}_pred" for biomarker in name_biomarkers],
    )
    df_ = df_.reset_index().assign(offsetT=offsetT, **y_pred)
    return df_
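
A minimal usage sketch; sreft, df, the scaled arrays, scaler_y, and name_biomarkers are assumptions carried over from the training pipeline:

df_pred = calculate_offsetT_prediction(
    sreft, df, (x_scaled, cov_scaled, m_scaled, y_scaled), scaler_y, name_biomarkers
)
df_pred[["offsetT"] + [f"{b}_pred" for b in name_biomarkers]].head()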

clean_duplicate(df, cols, duplicate_key)

Checks for duplicate entries in the DataFrame based on the specified columns and removes NaNs; duplicate rows are also removed if duplicate_key is specified.

Parameters:
  • df (DataFrame) –

    The DataFrame to check and drop duplicates from.

  • cols (list[str]) –

    List of column names to check (and remove) for duplicates.

  • duplicate_key (list[str] | str | None) –

    If specified, duplicate rows are dropped and remaining duplicates are checked within the specified columns.

Returns:
  • DataFrame

    pd.DataFrame: DataFrame with duplicates removed. It includes only the columns specified in cols and duplicate_key.

Warns:
  • If duplicates remain after cleaning, a warning is issued; the message depends on the `duplicate_key` parameter
sreftml\utilities.py
def clean_duplicate(
    df: pd.DataFrame, cols: list[str], duplicate_key: list[str] | str | None
) -> pd.DataFrame:
    """
    Checks for duplicate entries in the DataFrame based on the specified columns and removes NaNs; duplicate rows are also removed if duplicate_key is specified.

    Parameters:
        df (pd.DataFrame): The DataFrame to check and drop duplicates from.
        cols (list[str]): List of column names to check (and remove) for duplicates.
        duplicate_key (list[str] | str | None): If specified, duplicate rows are dropped and remaining duplicates are checked within the specified columns.

    Returns:
        pd.DataFrame: DataFrame with duplicates removed. It includes only the columns specified in cols and duplicate_key.

    Warnings:
        If any duplicates are found in the DataFrame after cleaning, a warning message is displayed.
        The warning message depends on the `duplicate_key` parameter:
        - If `duplicate_key` is None, the warning message indicates that some records are duplicates across all columns in `cols`.
        - If `duplicate_key` is not None, the warning message indicates that some records are duplicates within the same duplicate_key.
    """
    if type(duplicate_key) is str:
        duplicate_key = [duplicate_key]

    if duplicate_key is None:
        df_ = df[cols].dropna()
        if df_.duplicated().any():
            warnings.warn(
                "Some records are duplicates. Set duplicate_key if necessary."
            )
    else:
        df_ = df[cols + duplicate_key].dropna().drop_duplicates()
        if df_.duplicated(subset=duplicate_key).any():
            warnings.warn(
                "Duplicate records remain in some duplicate_keys. Add duplicate_key if necessary."
            )

    return df_
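
A minimal worked sketch with hypothetical columns:

df_raw = pd.DataFrame(
    {"ID": [1, 1, 2], "TIME": [0.0, 0.0, 1.0], "SEX": [0, 0, 1]}
)
# Exact duplicate rows are dropped; a warning is raised only if an ID
# still maps to conflicting rows afterwards.
df_clean = clean_duplicate(df_raw, cols=["TIME", "SEX"], duplicate_key="ID")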

compute_permutation_importance(random_seed, sreft, cov_test, m_test, n_sample)

Compute permutation importance of the model.

Parameters:
  • random_seed (int) –

    The seed for the random number generator.

  • sreft (Model) –

    The model for which to calculate permutation importance.

  • cov_test (ndarray) –

    The covariates test data.

  • m_test (ndarray) –

    The m test data.

  • n_sample (int) –

    The number of samples.

Returns:
  • tuple[ndarray, ndarray]

    tuple[np.ndarray, np.ndarray]: The mean and standard deviation of the permutation importance.

sreftml\utilities.py
def compute_permutation_importance(
    random_seed: int,
    sreft: tf.keras.Model,
    cov_test: np.ndarray,
    m_test: np.ndarray,
    n_sample: int,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Compute permutation importance of the model.

    Args:
        random_seed (int): The seed for the random number generator.
        sreft (tf.keras.Model): The model for which to calculate permutation importance.
        cov_test (np.ndarray): The covariates test data.
        m_test (np.ndarray): The m test data.
        n_sample (int): The number of samples.

    Returns:
        tuple[np.ndarray, np.ndarray]: The mean and standard deviation of the permutation importance.
    """
    rng = np.random.default_rng(random_seed)
    offsetT_pred = sreft.model_1(np.concatenate((m_test, cov_test), axis=-1)).numpy()

    mean_pi = []
    std_pi = []
    n_pi = m_test.shape[1] + cov_test.shape[1]

    for i in range(n_pi):
        pis = []
        for j in range(n_sample):
            if i < m_test.shape[1]:
                m_test_rand = np.copy(m_test)
                rng.shuffle(m_test_rand[:, i])
                y_pred_rand = sreft.model_1(
                    np.concatenate((m_test_rand, cov_test), axis=-1)
                ).numpy()
            else:
                cov_test_rand = np.copy(cov_test)
                rng.shuffle(cov_test_rand[:, i - m_test.shape[1]])
                y_pred_rand = sreft.model_1(
                    np.concatenate((m_test, cov_test_rand), axis=-1)
                ).numpy()

            sq_diff = (offsetT_pred - y_pred_rand) ** 2
            temp_pi = np.nanmean(sq_diff)
            pis.append(temp_pi)

        mean_pi.append(np.mean(pis))
        std_pi.append(np.std(pis))

    return np.array(mean_pi), np.array(std_pi)
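
A minimal usage sketch; sreft, the scaled test arrays, and feature_names (m columns first, then cov columns, matching the shuffling order in the source) are assumptions:

mean_pi, std_pi = compute_permutation_importance(
    random_seed=42, sreft=sreft, cov_test=cov_scaled, m_test=m_scaled, n_sample=20
)
for idx in np.argsort(mean_pi)[::-1]:
    print(feature_names[idx], f"{mean_pi[idx]:.4f} +/- {std_pi[idx]:.4f}")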

compute_permutation_importance_(random_seed, sreft, x_test, cov_test, m_test, y_test, n_sample)

[Superseded] Compute permutation importance of the model.

Parameters:
  • random_seed (int) –

    The seed for the random number generator.

  • sreft (Model) –

    The model for which to calculate permutation importance.

  • x_test (ndarray) –

    The x test data.

  • cov_test (ndarray) –

    The covariates test data.

  • m_test (ndarray) –

    The m test data.

  • y_test (ndarray) –

    The y test data.

  • n_sample (int) –

    The number of samples.

Returns:
  • tuple[ndarray, ndarray]

    tuple[np.ndarray, np.ndarray]: The mean and standard deviation of the permutation importance.

sreftml\utilities.py
def compute_permutation_importance_(
    random_seed: int,
    sreft: tf.keras.Model,
    x_test: np.ndarray,
    cov_test: np.ndarray,
    m_test: np.ndarray,
    y_test: np.ndarray,
    n_sample: int,
) -> tuple[np.ndarray, np.ndarray]:
    """
    [Superseded] Compute permutation importance of the model.

    Args:
        random_seed (int): The seed for the random number generator.
        sreft (tf.keras.Model): The model for which to calculate permutation importance.
        x_test (np.ndarray): The x test data.
        cov_test (np.ndarray): The covariates test data.
        m_test (np.ndarray): The m test data.
        y_test (np.ndarray): The y test data.
        n_sample (int): The number of samples.

    Returns:
        tuple[np.ndarray, np.ndarray]: The mean and standard deviation of the permutation importance.
    """
    rng = np.random.default_rng(random_seed)
    y_pred = sreft((x_test, cov_test, m_test, y_test)).numpy()
    neglls_orig = np_compute_negative_log_likelihood(y_test, y_pred, sreft.lnvar_y)

    mean_pi = []
    std_pi = []
    n_pi = m_test.shape[1] + cov_test.shape[1]

    for i in range(n_pi):
        pis = []
        for j in range(n_sample):
            if i < m_test.shape[1]:
                m_test_rand = np.copy(m_test)
                rng.shuffle(m_test_rand[:, i])
                y_pred_rand = sreft((x_test, cov_test, m_test_rand, y_test)).numpy()
            else:
                cov_test_rand = np.copy(cov_test)
                rng.shuffle(cov_test_rand[:, i - m_test.shape[1]])
                y_pred_rand = sreft((x_test, cov_test_rand, m_test, y_test)).numpy()

            neglls_rand = np_compute_negative_log_likelihood(
                y_test, y_pred_rand, sreft.lnvar_y
            )
            neglls_diff = neglls_rand - neglls_orig
            temp_pi = np.nanmean(neglls_diff)
            pis.append(temp_pi)

        mean_pi.append(np.mean(pis))
        std_pi.append(np.std(pis))

    return np.array(mean_pi), np.array(std_pi)

get_current_commit_hash()

Retrieves the current commit hash of the git repository.

Returns:
  • str( str ) –

    The current commit hash or a placeholder string if an error occurs.

sreftml\utilities.py
def get_current_commit_hash() -> str:
    """
    Retrieves the current commit hash of the git repository.

    Returns:
        str: The current commit hash or a placeholder string if an error occurs.
    """
    try:
        commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"])
        return commit_hash.strip().decode("utf-8")
    except subprocess.CalledProcessError:
        warnings.warn("Could not get the current commit hash.", UserWarning)
        return "commit_hash_not_available"

linear_regression_each_subject(df, y_columns)

Perform linear regression for each subject (ID) in the given DataFrame.

Parameters:
  • df (DataFrame) –

    The input DataFrame containing the data for the regression. It must include columns for 'ID', 'TIME', and the target variables specified in 'y_columns'.

  • y_columns (list[str]) –

    A list of column names (strings) representing the target variables to be regressed.

Returns:
  • DataFrame

    pd.DataFrame: A DataFrame with the regression results for each subject.

sreftml\utilities.py
def linear_regression_each_subject(
    df: pd.DataFrame, y_columns: list[str]
) -> pd.DataFrame:
    """
    Perform linear regression for each subject (ID) in the given DataFrame.

    Args:
        df (pd.DataFrame): The input DataFrame containing the data for the regression. It must include columns for 'ID', 'TIME', and the target variables specified in 'y_columns'.
        y_columns (list[str]): A list of column names (strings) representing the target variables to be regressed.

    Returns:
        pd.DataFrame: A DataFrame with the regression results for each subject.
    """
    model = LinearRegression()
    results = {"ID": df.ID.unique()}

    for y in y_columns:
        slopes = []
        intercepts = []

        for _, group in df.groupby("ID"):
            x_values = group["TIME"].values.reshape(-1, 1)
            y_values = group[y].values

            valid_mask = ~np.isnan(y_values)
            valid_sample_count = valid_mask.sum()

            if valid_sample_count == 0:
                slopes.append(np.nan)
                intercepts.append(np.nan)
                continue

            model.fit(x_values[valid_mask], y_values[valid_mask])

            if valid_sample_count == 1:
                slopes.append(np.nan)
            else:
                slopes.append(model.coef_[0])
            intercepts.append(model.intercept_)

        results[f"{y}_slope"] = slopes
        results[f"{y}_intercept"] = intercepts

    result = pd.DataFrame(results)
    result = result[
        ["ID"] + [i + j for j in ["_slope", "_intercept"] for i in y_columns]
    ]

    return result
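
A minimal worked sketch with a hypothetical biomarker bm1:

df_demo = pd.DataFrame(
    {
        "ID": [1, 1, 1, 2, 2],
        "TIME": [0.0, 1.0, 2.0, 0.0, 1.0],
        "bm1": [1.0, 2.0, 3.0, 5.0, np.nan],
    }
)
res = linear_regression_each_subject(df_demo, ["bm1"])
# ID 1 -> slope 1.0, intercept 1.0; ID 2 has one valid point,
# so its slope is NaN and its intercept is 5.0.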

load_shap(path_to_shap_file)

Load the specified SHAP binary file and return the SHAP explanations.

Parameters:
  • path_to_shap_file (str) –

    The path to the SHAP file.

Returns:
  • Explanation( Explanation ) –

    The explanation of SHAP values.

sreftml\utilities.py
def load_shap(
    path_to_shap_file: str,
) -> shap.Explanation:
    """
    Load the specified SHAP binary file and return the SHAP explanations.

    Args:
        path_to_shap_file (str): The path to the SHAP file.

    Returns:
        Explanation: The explanation of SHAP values.
    """
    with open(path_to_shap_file, "rb") as p:
        shap_exp = pickle.load(p)

    return shap_exp

mixed_effect_linear_regression(df, y_columns)

Perform mixed-effects linear regression on the given DataFrame.

Parameters:
  • df (DataFrame) –

    The input DataFrame containing the data for the regression. It must include columns for 'ID', 'TIME', and the target variables specified in 'y_columns'.

  • y_columns (list[str]) –

    A list of column names (strings) representing the target variables to be regressed.

Returns:
  • tuple( tuple[DataFrame, list] ) –

    A tuple containing two elements:
    - result (pd.DataFrame): The DataFrame with the fitted regression parameters for each individual.
    - models (list): A list of fitted mixed-effects regression models for each target variable.

sreftml\utilities.py
def mixed_effect_linear_regression(
    df: pd.DataFrame, y_columns: list[str]
) -> tuple[pd.DataFrame, list]:
    """
    Perform mixed-effects linear regression on the given DataFrame.

    Args:
        df (pd.DataFrame): The input DataFrame containing the data for the regression.
            It must include columns for 'ID', 'TIME', and the target variables specified in 'y_columns'.
        y_columns (list[str]): A list of column names (strings) representing the target variables to be regressed.

    Returns:
        tuple: A tuple containing two elements:
            - result (pd.DataFrame): The DataFrame with the fitted regression parameters for each individual.
            - models (list): A list of fitted mixed-effects regression models for each target variable.
    """
    result = pd.DataFrame(df.ID.unique()).set_axis(["ID"], axis=1)
    models = []

    for y in y_columns:
        df_ = (
            df[["ID", "TIME", y]]
            .dropna()
            .reset_index(drop=True)
            .set_axis(["ID", "TIME", "TARGET"], axis=1)
        )
        if df_["TIME"].nunique() == 1:
            warnings.warn(
                f"Only one time point is available for {y}. The slope cannot be calculated."
            )
            tmp = pd.DataFrame(
                {
                    "ID": df_.ID.unique(),
                    f"{y}_slope": np.nan,
                    f"{y}_intercept": df_.groupby("ID")["TARGET"].mean().values,
                }
            )
            result = result.merge(tmp, how="outer")
            models.append(NullModel(df_.groupby("ID")["TARGET"].mean().mean(), np.nan))
            continue

        full_model = smf.mixedlm(
            "TARGET ~ TIME", data=df_, groups="ID", re_formula="~TIME"
        ).fit()
        random_effects = pd.DataFrame(full_model.random_effects).T.values
        params_pop = full_model.params[0:2].values.T
        params_ind = pd.DataFrame(params_pop + random_effects).set_axis(
            [f"{y}_intercept", f"{y}_slope"], axis=1
        )
        params_ind["ID"] = pd.DataFrame(full_model.random_effects).T.index.values
        result = result.merge(params_ind, how="outer")
        models.append(full_model)

    result = result[
        ["ID"] + [i + j for j in ["_slope", "_intercept"] for i in y_columns]
    ]

    return result, models
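
A minimal usage sketch; df (with ID, TIME, and the biomarker columns) is an assumption:

params_ind, fitted_models = mixed_effect_linear_regression(df, ["bm1", "bm2"])
params_ind.head()  # columns: ID, bm1_slope, bm2_slope, bm1_intercept, bm2_intercept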

multi_column_filter(df, upper_lim=None, lower_lim=None, IQR_filter=None)

Applies limits and IQR filtering on DataFrame columns.

Operations:
  • NaN substitution for values outside the specified upper and lower limits.
  • IQR-based outlier removal in specified columns.

Parameters:
  • df (DataFrame) –

    The DataFrame to be filtered.

  • upper_lim (dict[str, float], default: None ) –

    Upper limits per column.

  • lower_lim (dict[str, float], default: None ) –

    Lower limits per column.

  • IQR_filter (list, default: None ) –

    Columns for IQR outlier detection.

Returns:
  • DataFrame –

    pd.DataFrame: DataFrame after applying the defined filters.

Notes

If a column appears in both upper_lim/lower_lim and IQR_filter, a warning is issued and the upper_lim/lower_lim values take precedence.

sreftml\utilities.py
def multi_column_filter(
    df: pd.DataFrame,
    upper_lim: dict[str, float] | None = None,
    lower_lim: dict[str, float] | None = None,
    IQR_filter: list | None = None,
) -> pd.DataFrame:
    """
    Applies limits and IQR filtering on DataFrame columns.

    Operations:
        NaN substitution for values outside the specified upper and lower limits.
        IQR-based outlier removal in specified columns.

    Args:
        df (pd.DataFrame): The DataFrame to be filtered.
        upper_lim (dict[str, float], optional): Upper limits per column.
        lower_lim (dict[str, float], optional): Lower limits per column.
        IQR_filter (list, optional): Columns for IQR outlier detection.

    Returns:
        pd.DataFrame: DataFrame after applying the defined filters.

    Notes:
        If a column appears in both `upper_lim`/`lower_lim` and `IQR_filter`, a warning
        is issued and the `upper_lim`/`lower_lim` values take precedence.
    """
    df_filtered = df.copy()
    if upper_lim is None:
        upper_lim = {}
    if lower_lim is None:
        lower_lim = {}
    if IQR_filter is None:
        IQR_filter = []

    if upper_lim:
        for k, v in upper_lim.items():
            df_filtered.loc[df_filtered[k] > v, k] = np.nan
        overlap_upper_IQR = set(upper_lim.keys()) & set(IQR_filter)
        if overlap_upper_IQR:
            warnings.warn(
                f"The columns {overlap_upper_IQR} were present in both upper_lim and IQR_filter, therefore they were filtered using the values from upper_lim."
            )

    if lower_lim:
        for k, v in lower_lim.items():
            df_filtered.loc[df_filtered[k] < v, k] = np.nan
        overlap_lower_IQR = set(lower_lim.keys()) & set(IQR_filter)
        if overlap_lower_IQR:
            warnings.warn(
                f"The columns {overlap_lower_IQR} were present in both lower_lim and IQR_filter, therefore they were filtered using the values from lower_lim."
            )

    if IQR_filter:
        IQR_exclusive = list(
            set(IQR_filter) - set(upper_lim.keys()) - set(lower_lim.keys())
        )
        # Tukey fences computed only on the IQR-filtered columns.
        q1 = df_filtered[IQR_exclusive].quantile(0.25)
        q3 = df_filtered[IQR_exclusive].quantile(0.75)
        iqr = q3 - q1
        df_filtered[IQR_exclusive] = df_filtered[IQR_exclusive].mask(
            (df_filtered[IQR_exclusive] < q1 - 1.5 * iqr)
            | (df_filtered[IQR_exclusive] > q3 + 1.5 * iqr),
            np.nan,
        )

    return df_filtered
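
A minimal usage sketch with hypothetical columns and limits:

df_f = multi_column_filter(
    df,
    upper_lim={"ALT": 300.0},  # ALT values above 300 become NaN
    lower_lim={"ALT": 0.0},    # negative ALT values become NaN
    IQR_filter=["AST"],        # AST filtered by Tukey fences (1.5 * IQR)
)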

n2mfrow(n_plots, ncol_max=4)

Determines the number of rows and columns required to plot a given number of subplots.

Parameters:
  • n_plots (int) –

    Total number of subplots.

  • ncol_max (int, default: 4 ) –

    Maximum number of columns for subplots. Defaults to 4.

Returns:
  • tuple( tuple[int, int] ) –

    (number of rows, number of columns)

sreftml\utilities.py
def n2mfrow(n_plots: int, ncol_max: int = 4) -> tuple[int, int]:
    """
    Determines the number of rows and columns required to plot a given number of subplots.

    Args:
        n_plots (int): Total number of subplots.
        ncol_max (int, optional): Maximum number of columns for subplots. Defaults to 4.

    Returns:
        tuple: (number of rows, number of columns)
    """
    n_plots = int(n_plots)
    nrow = math.ceil(n_plots / ncol_max)
    ncol = math.ceil(n_plots / nrow)
    return nrow, ncol
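
Two worked examples:

n2mfrow(10)             # -> (3, 4): 10 plots fit on a 3-row by 4-column grid
n2mfrow(5, ncol_max=3)  # -> (2, 3)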

np_compute_negative_log_likelihood(y_true, y_pred, lnvar_y)

Computes the negative log likelihood between true and predicted values using numpy.

Parameters:
  • y_true (array) –

    True target values.

  • y_pred (array) –

    Predicted target values.

  • lnvar_y (array) –

    Natural logarithm of the variance.

Returns:
  • ndarray

    np.array: The negative log likelihood for each instance.

sreftml\utilities.py
def np_compute_negative_log_likelihood(
    y_true: np.ndarray, y_pred: np.ndarray, lnvar_y: np.ndarray
) -> np.ndarray:
    """
    Computes the negative log likelihood between true and predicted values using numpy.

    Args:
        y_true (np.array): True target values.
        y_pred (np.array): Predicted target values.
        lnvar_y (np.array): Natural logarithm of the variance.

    Returns:
        np.array: The negative log likelihood for each instance.
    """
    neg_ll = lnvar_y + np.power(y_true - y_pred, 2) / np.exp(lnvar_y)
    return np.nansum(neg_ll, axis=1)
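
A minimal worked sketch; with lnvar_y = 0 (unit variance) the per-record value reduces to the sum of squared errors over the outputs:

y_true = np.array([[1.0, 2.0]])
y_pred = np.array([[1.5, 2.0]])
lnvar_y = np.zeros(2)
np_compute_negative_log_likelihood(y_true, y_pred, lnvar_y)
# -> array([0.25]); per record, sum over outputs of lnvar_y + error**2 / exp(lnvar_y)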

save_shap(path_to_shap_file, shap_exp)

Save the SHAP explanations to the specified file.

Parameters:
  • path_to_shap_file (str) –

    The path to save the SHAP file.

  • shap_exp (Explanation) –

    The SHAP explanations to be saved.

Returns:
  • None

    None

sreftml\utilities.py
def save_shap(path_to_shap_file: str, shap_exp: shap.Explanation) -> None:
    """
    Save the SHAP explanations to the specified file.

    Parameters:
        path_to_shap_file (str): The path to save the SHAP file.
        shap_exp (shap.Explanation): The SHAP explanations to be saved.

    Returns:
        None
    """
    with open(path_to_shap_file, "wb") as p:
        pickle.dump(shap_exp, p)

    return None
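
A minimal round-trip sketch with load_shap above; the file name is an assumption:

save_shap("shap_model_1.pkl", shap_exp)
shap_exp_reloaded = load_shap("shap_model_1.pkl")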

split_data_for_sreftml(df, name_biomarkers, name_covariates, isMixedlm=True)

Split data for sreftml.

Parameters:
  • df (DataFrame) –

    Input DataFrame.

  • name_biomarkers (list[str]) –

    List of biomarker names.

  • name_covariates (list[str]) –

    List of covariate names.

  • isMixedlm (bool, default: True ) –

    Select whether to use a mixed-effects model when computing model_1 features. Defaults to True.

Returns:
  • tuple( tuple[Series, DataFrame, DataFrame, DataFrame] ) –

    A tuple containing the following arrays:
    - x (pd.Series): Time values.
    - cov (pd.DataFrame): Covariate values.
    - m (pd.DataFrame): Slope and intercept from regression by biomarker.
    - y (pd.DataFrame): Biomarker values.

sreftml\utilities.py
def split_data_for_sreftml(
    df: pd.DataFrame,
    name_biomarkers: list[str],
    name_covariates: list[str],
    isMixedlm: bool = True,
) -> tuple[pd.Series, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split data for sreftml.

    Args:
        df (pd.DataFrame): Input DataFrame.
        name_biomarkers (list[str]): List of biomarker names.
        name_covariates (list[str]): List of covariate names.
        isMixedlm (bool, optional): Select whether to use a mixed-effects model when computing model_1 features. Defaults to True.

    Returns:
        tuple: A tuple containing the following arrays:
            - x (pd.Series): Time values.
            - cov (pd.DataFrame): Covariate values.
            - m (pd.DataFrame): Slope and intercept from regression by biomarker.
            - y (pd.DataFrame): Biomarker values.
    """
    df_ = df.copy()
    if len(name_covariates) > 0 and pd.isna(df[name_covariates]).any().any():
        warnings.warn("Missing value imputation was performed for some covariates.")
        df_[name_covariates] = df_[name_covariates].fillna(
            df.loc[:, name_covariates].mean()
        )

    if isMixedlm:
        linreg, models = mixed_effect_linear_regression(df_, name_biomarkers)
        if pd.isna(linreg).any().any():
            warnings.warn("Missing value imputation was performed for some features.")
            prms = [i.params[0] for i in models] + [i.params[1] for i in models]
            labels = [i + j for j in ["_intercept", "_slope"] for i in name_biomarkers]
            dict_slope = dict(zip(labels, prms))
            linreg = linreg.fillna(dict_slope)
    else:
        linreg = linear_regression_each_subject(df_, name_biomarkers)
        if pd.isna(linreg).any().any():
            warnings.warn("Missing value imputation was performed for some features.")
            linreg = linreg.fillna(linreg.mean())

    df_ = df_.merge(linreg)

    x = df_.TIME
    cov = df_[name_covariates]
    m = df_.loc[:, df_.columns.str.contains("_slope|_intercept")].dropna(
        axis=1, how="all"
    )
    y = df_[name_biomarkers]

    return x, cov, m, y
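
A minimal usage sketch; df and the column names are assumptions:

x, cov, m, y = split_data_for_sreftml(
    df, name_biomarkers=["bm1", "bm2"], name_covariates=["AGE", "SEX"]
)
# m holds per-subject slopes/intercepts and, together with cov, feeds model_1.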

survival_analysis(df, surv_time, event, useOffsetT=True, gompertz_init_params=[0.1, 0.1])

Perform survival analysis and return a dictionary of survival analysis objects.

If the survival time contains values of 0 or less, the survival times are shifted so that the minimum value is 0.00001.

Parameters:
  • df (DataFrame) –

    Input DataFrame.

  • surv_time (str) –

    Column name of the survival time in df.

  • event (str) –

    Column name of the event in df.

  • useOffsetT (bool, default: True ) –

    Determines whether to use offsetT for the analysis. Defaults to True.

  • gompertz_init_params (list, default: [0.1, 0.1] ) –

    Initial point passed to the Gompertz fitter. Defaults to [0.1, 0.1].

Returns:
  • dict( dict ) –

    A dictionary of survival analysis objects.

sreftml\utilities.py
def survival_analysis(
    df: pd.DataFrame,
    surv_time: str,
    event: str,
    useOffsetT: bool = True,
    gompertz_init_params: list = [0.1, 0.1],
) -> dict:
    """
    Perform survival analysis and return a dictionary of survival analysis objects.

    If the survival time contains values of 0 or less, the survival times are shifted so that the minimum value is 0.00001.

    Args:
        df (pd.DataFrame): Input DataFrame.
        surv_time (str): Column name of the survival time in df.
        event (str): Column name of the event in df.
        useOffsetT (bool, optional): Determines whether to use offsetT for the analysis. Defaults to True.
        gompertz_init_params (list, optional): Initial point passed to the Gompertz fitter. Defaults to [0.1, 0.1].

    Returns:
        dict: A dictionary of survival analysis objects.
    """
    fitters = [
        (lifelines.KaplanMeierFitter, "kmf", "KaplanMeier"),
        (lifelines.NelsonAalenFitter, "naf", "NelsonAalen"),
        (lifelines.ExponentialFitter, "epf", "Exponential"),
        (lifelines.WeibullFitter, "wbf", "Weibull"),
        (GompertzFitter, "gpf", "Gompertz"),
        (lifelines.LogLogisticFitter, "llf", "LogLogistic"),
        (lifelines.LogNormalFitter, "lnf", "LogNormal"),
    ]
    fit_model = {"title": event}
    if useOffsetT:
        df_surv = df[["ID", "offsetT", surv_time, event]].dropna().drop_duplicates()

        if df_surv["offsetT"].min() < 0:
            raise ValueError("offsetT must be greater than or equal to 0.")

        for fitter_class, key, label in fitters:
            if key == "gpf":
                fit_model[key] = fitter_class(label=label).fit(
                    durations=df_surv["offsetT"] + df_surv[surv_time],
                    event_observed=df_surv[event],
                    entry=df_surv["offsetT"],
                    initial_point=gompertz_init_params,
                )
            else:
                fit_model[key] = fitter_class(label=label).fit(
                    durations=df_surv["offsetT"] + df_surv[surv_time],
                    event_observed=df_surv[event],
                    entry=df_surv["offsetT"],
                )
    else:
        df_surv = df[["ID", surv_time, event]].dropna().drop_duplicates()
        for fitter_class, key, label in fitters:
            fit_model[key] = fitter_class(label=label).fit(
                durations=df_surv[surv_time], event_observed=df_surv[event]
            )

    return fit_model
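
A minimal usage sketch; the column names are assumptions, and the input must contain ID and offsetT columns (e.g. df_pred from calculate_offsetT_prediction) when useOffsetT is True:

fit_model = survival_analysis(df_pred, surv_time="T_EVENT", event="EVENT_FLAG")
fit_model["kmf"].plot_survival_function()  # Kaplan-Meier estimate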

tf_compute_negative_log_likelihood(y_true, y_pred, lnvar_y)

Computes the negative log likelihood between true and predicted values using tensorflow.

Parameters:
  • y_true (ndarray) –

    True target values.

  • y_pred (ndarray) –

    Predicted target values.

  • lnvar_y (Variable) –

    Natural logarithm of the variance.

Returns:
  • Tensor

    tf.Tensor: The negative log likelihood for each instance.

sreftml\utilities.py
def tf_compute_negative_log_likelihood(
    y_true: np.ndarray, y_pred: np.ndarray, lnvar_y: tf.Variable
) -> tf.Tensor:
    """
    Computes the negative log likelihood between true and predicted values using tensorflow.

    Args:
        y_true (np.ndarray): True target values.
        y_pred (np.ndarray): Predicted target values.
        lnvar_y (tf.Variable): Natural logarithm of the variance.

    Returns:
        tf.Tensor: The negative log likelihood for each instance.
    """
    is_nan = tf.math.is_nan(y_true)
    y_true = tf.where(is_nan, tf.zeros_like(y_true), y_true)
    y_pred = tf.where(is_nan, tf.zeros_like(y_pred), y_pred)
    neg_ll = lnvar_y + tf.pow(y_true - y_pred, 2) / tf.exp(lnvar_y)
    neg_ll = tf.where(is_nan, tf.zeros_like(neg_ll), neg_ll)

    return tf.reduce_sum(neg_ll, axis=1)
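
A minimal worked sketch showing the NaN handling; missing observations contribute zero to the sum:

y_true = tf.constant([[1.0, np.nan]])
y_pred = tf.constant([[1.5, 2.0]])
lnvar_y = tf.Variable(tf.zeros(2))
tf_compute_negative_log_likelihood(y_true, y_pred, lnvar_y)
# -> [0.25]; the NaN observation is masked out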