From 58e3de19797ac3c383fd483c7525635dc05b36a8 Mon Sep 17 00:00:00 2001 From: Sudipta Basak Date: Wed, 4 Oct 2023 13:39:46 +1100 Subject: [PATCH] plot feature correlation matrix --- configs/ref_rf.yaml | 1 + uncoverml/geoio.py | 18 ++++++++++++++++++ uncoverml/scripts/uncoverml.py | 1 + 3 files changed, 20 insertions(+) diff --git a/configs/ref_rf.yaml b/configs/ref_rf.yaml index a7baa983..eafd7313 100644 --- a/configs/ref_rf.yaml +++ b/configs/ref_rf.yaml @@ -83,6 +83,7 @@ prediction: validation: #- feature_rank - permutation_importance + # TODO: - feature_importance # only works for tree based models - parallel - k-fold: folds: 5 diff --git a/uncoverml/geoio.py b/uncoverml/geoio.py index e7f55ac2..1d51acaa 100644 --- a/uncoverml/geoio.py +++ b/uncoverml/geoio.py @@ -12,6 +12,7 @@ from typing import Union import pickle import matplotlib.pyplot as plt +import seaborn as sns import rasterio from rasterio.warp import reproject from rasterio.windows import Window @@ -736,6 +737,23 @@ def export_validation_scatter_plot_and_validation_csv(outfile_results, config: C plt.savefig(true_vs_pred_plot) +def plot_feature_correlation_matrix(config: Config, x_all): + fig, corr_ax = plt.subplots() + features = [Path(f).stem for f in feature_names(config)] + corr_df = pd.DataFrame(x_all) + corr_df.columns = features + sns.heatmap(corr_df.corr(), + vmin=-1, vmax=1, annot=True, + square=True, linewidths=.5, cbar_kws={"shrink": .5}, + cmap='BrBG', + ) + fig.suptitle('Feature Correlations') + fig.tight_layout() + save_path = Path(config.output_dir).joinpath(config.name + "_feature_correlation.png") \ + .as_posix() + fig.savefig(save_path) + + def resample(input_tif, output_tif, ratio, resampling="average"): """ Parameters diff --git a/uncoverml/scripts/uncoverml.py b/uncoverml/scripts/uncoverml.py index 6429eaf6..260c4413 100644 --- a/uncoverml/scripts/uncoverml.py +++ b/uncoverml/scripts/uncoverml.py @@ -55,6 +55,7 @@ def cli(verbosity): def run_crossval(x_all, targets_all, config): crossval_results = ls.validate.local_crossval(x_all, targets_all, config) ls.mpiops.run_once(ls.geoio.export_crossval, crossval_results, config) + ls.mpiops.run_once(ls.geoio.plot_feature_correlation_matrix, config, x_all) @cli.command()