Skip to content

Commit

Permalink
plot feature correlation matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
basaks committed Oct 4, 2023
1 parent 7b89ba4 commit 58e3de1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
1 change: 1 addition & 0 deletions configs/ref_rf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ prediction:
validation:
#- feature_rank
- permutation_importance
# TODO: - feature_importance # only works for tree based models
- parallel
- k-fold:
folds: 5
Expand Down
18 changes: 18 additions & 0 deletions uncoverml/geoio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Union
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
import rasterio
from rasterio.warp import reproject
from rasterio.windows import Window
Expand Down Expand Up @@ -736,6 +737,23 @@ def export_validation_scatter_plot_and_validation_csv(outfile_results, config: C
plt.savefig(true_vs_pred_plot)


def plot_feature_correlation_matrix(config: Config, x_all):
fig, corr_ax = plt.subplots()
features = [Path(f).stem for f in feature_names(config)]
corr_df = pd.DataFrame(x_all)
corr_df.columns = features
sns.heatmap(corr_df.corr(),
vmin=-1, vmax=1, annot=True,
square=True, linewidths=.5, cbar_kws={"shrink": .5},
cmap='BrBG',
)
fig.suptitle('Feature Correlations')
fig.tight_layout()
save_path = Path(config.output_dir).joinpath(config.name + "_feature_correlation.png") \
.as_posix()
fig.savefig(save_path)


def resample(input_tif, output_tif, ratio, resampling="average"):
"""
Parameters
Expand Down
1 change: 1 addition & 0 deletions uncoverml/scripts/uncoverml.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def cli(verbosity):
def run_crossval(x_all, targets_all, config):
crossval_results = ls.validate.local_crossval(x_all, targets_all, config)
ls.mpiops.run_once(ls.geoio.export_crossval, crossval_results, config)
ls.mpiops.run_once(ls.geoio.plot_feature_correlation_matrix, config, x_all)


@cli.command()
Expand Down

0 comments on commit 58e3de1

Please sign in to comment.