Model Training and Evaluation (Timor-Leste)

import sys
sys.path.append("../../")

import uuid
import numpy as np
import geopandas as gpd
import pandas as pd
from shapely import wkt

from geowrangler import dhs
from povertymapping import settings, osm, ookla, nightlights
from povertymapping.dhs import generate_dhs_cluster_level_data
from povertymapping.osm import OsmDataManager
from povertymapping.ookla import OoklaDataManager
import getpass


import pickle
import os
from sklearn.model_selection import train_test_split, KFold, RepeatedKFold
from sklearn.model_selection import GroupKFold, cross_val_predict, cross_val_score
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
import seaborn as sns

import shap
/home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/env/lib/python3.9/site-packages/geopandas/_compat.py:111: UserWarning: The Shapely GEOS version (3.10.3-CAPI-1.16.1) is incompatible with the GEOS version PyGEOS was compiled with (3.10.1-CAPI-1.16.0). Conversions between both will be slow.
  warnings.warn(
/home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/env/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
%reload_ext autoreload
%autoreload 2

Load Target Country From DHS data

# Country-specific settings: OSM region slug and the data years used for the
# Ookla and nightlights features
country_osm = "east-timor"
ookla_year = 2019
nightlights_year = 2016

# DHS household data (.DTA) and DHS geographic (cluster location) shapefile
dhs_household_dta_path = settings.DATA_DIR / "dhs/tl/TLHR71DT/TLHR71FL.DTA"
dhs_geographic_shp_path = settings.DATA_DIR / "dhs/tl/TLGE71FL/TLGE71FL.shp"

# Build one row per DHS cluster, converting each cluster geometry into a
# bounding box sized by bbox_size_km
dhs_gdf = generate_dhs_cluster_level_data(
    dhs_household_dta_path,
    dhs_geographic_shp_path,
    col_rename_config="tl",
    convert_geoms_to_bbox=True,
    bbox_size_km=2.4,
)
# Renumber rows 0..n-1 so later cells can rely on a clean positional index
dhs_gdf = dhs_gdf.reset_index(drop=True)
dhs_gdf.explore()
Make this Notebook Trusted to load map: File -> Trust Notebook
# Preview the first five clusters of the DHS cluster-level table
dhs_gdf.head(n=5)
DHSCLUST Wealth Index DHSID DHSCC DHSYEAR CCFIPS ADM1FIPS ADM1FIPSNA ADM1SALBNA ADM1SALBCO ... URBAN_RURA LATNUM LONGNUM ALT_GPS ALT_DEM DATUM F21 F22 F23 geometry
0 1 32166.600000 TL201600000001 TL 2016.0 TT NULL NULL NULL NULL ... R -8.712016 125.567381 9999.0 1005.0 WGS84 None None None POLYGON ((125.55646 -8.70122, 125.57830 -8.701...
1 2 -34063.923077 TL201600000002 TL 2016.0 TT NULL NULL NULL NULL ... R -8.730226 125.590219 9999.0 1342.0 WGS84 None None None POLYGON ((125.57930 -8.71943, 125.60114 -8.719...
2 3 39230.590909 TL201600000003 TL 2016.0 TT NULL NULL NULL NULL ... R -8.741340 125.556399 9999.0 1060.0 WGS84 None None None POLYGON ((125.54548 -8.73055, 125.56732 -8.730...
3 4 -82140.227273 TL201600000004 TL 2016.0 TT NULL NULL NULL NULL ... R -8.811291 125.535161 9999.0 1986.0 WGS84 None None None POLYGON ((125.52424 -8.80050, 125.54608 -8.800...
4 5 -56203.423077 TL201600000005 TL 2016.0 TT NULL NULL NULL NULL ... R -8.791590 125.473219 9999.0 1491.0 WGS84 None None None POLYGON ((125.46230 -8.78080, 125.48414 -8.780...

5 rows × 25 columns

Set up Data Access

# Instantiate data managers for OSM and Ookla. Each manager caches the data it
# fetches, so later requests for the same data are faster; both managers share
# one cache directory under the project root.
data_cache_dir = settings.ROOT_DIR / "data/data_cache"
osm_data_manager = OsmDataManager(cache_dir=data_cache_dir)
ookla_data_manager = OoklaDataManager(cache_dir=data_cache_dir)
# Log in using EOG credentials. Environment variables take precedence; only
# when a variable is unset is the user prompted (password input is hidden).
username = os.environ.get('EOG_USER')
if username is None:
    username = input('Username?')

password = os.environ.get('EOG_PASSWORD')
if password is None:
    password = getpass.getpass('Password?')

# save_token=True stores the access token in ~/.eog_creds/eog_access_token
access_token = nightlights.get_eog_access_token(username, password, save_token=True)
2023-01-31 14:00:38.706 | INFO     | povertymapping.nightlights:get_eog_access_token:48 - Saving access_token to ~/.eog_creds/eog_access_token
2023-01-31 14:00:38.707 | INFO     | povertymapping.nightlights:get_eog_access_token:56 - Adding access token to environmentt var EOG_ACCESS_TOKEN

Generate Base Features

If this is your first time running this notebook for this specific area, expect a long runtime for the following cell, as it will download and cache the following datasets from the internet.

  • OpenStreetMap Data from Geofabrik
  • Ookla Internet Speed Data
  • VIIRS nighttime lights data from the Earth Observation Group (EOG)

On subsequent runs, the runtime will be much faster as the data is already stored in your filesystem.

%%time
# Start from a copy so dhs_gdf keeps only the original DHS columns
country_data = dhs_gdf.copy()

# OSM features: POI counts / nearest-POI distances, then road counts
for add_osm_features in (osm.add_osm_poi_features, osm.add_osm_road_features):
    country_data = add_osm_features(country_data, country_osm, osm_data_manager)

# Ookla network features, once per connection type
for ookla_type in ('fixed', 'mobile'):
    country_data = ookla.add_ookla_features(country_data, ookla_type, ookla_year, ookla_data_manager)

# Nighttime lights features; the year is passed as a string
country_data = nightlights.generate_nightlights_feature(country_data, str(nightlights_year))
2023-01-31 14:00:38.851 | INFO     | povertymapping.osm:download_osm_country_data:187 - OSM Data: Cached data available for east-timor at /home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/notebooks/2023-01-17-initial-model-ph-mm-tl-kh/../../data/data_cache/osm/east-timor? True
2023-01-31 14:00:38.852 | DEBUG    | povertymapping.osm:load_pois:149 - OSM POIs for east-timor being loaded from /home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/notebooks/2023-01-17-initial-model-ph-mm-tl-kh/../../data/data_cache/osm/east-timor/gis_osm_pois_free_1.shp
2023-01-31 14:00:42.117 | INFO     | povertymapping.osm:download_osm_country_data:187 - OSM Data: Cached data available for east-timor at /home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/notebooks/2023-01-17-initial-model-ph-mm-tl-kh/../../data/data_cache/osm/east-timor? True
2023-01-31 14:00:42.118 | DEBUG    | povertymapping.osm:load_roads:168 - OSM Roads for east-timor being loaded from /home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/notebooks/2023-01-17-initial-model-ph-mm-tl-kh/../../data/data_cache/osm/east-timor/gis_osm_roads_free_1.shp
2023-01-31 14:00:43.655 | DEBUG    | povertymapping.ookla:load_type_year_data:68 - Contents of data cache: []
2023-01-31 14:00:43.657 | INFO     | povertymapping.ookla:load_type_year_data:83 - Cached data available at /home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/notebooks/2023-01-17-initial-model-ph-mm-tl-kh/../../data/data_cache/ookla/processed/206a0323fa0e80f82339b66d0c859b4a.csv? True
2023-01-31 14:00:43.658 | DEBUG    | povertymapping.ookla:load_type_year_data:88 - Processed Ookla data for aoi, fixed 2019 (key: 206a0323fa0e80f82339b66d0c859b4a) found in filesystem. Loading in cache.
2023-01-31 14:00:43.871 | DEBUG    | povertymapping.ookla:load_type_year_data:68 - Contents of data cache: ['206a0323fa0e80f82339b66d0c859b4a']
2023-01-31 14:00:43.873 | INFO     | povertymapping.ookla:load_type_year_data:83 - Cached data available at /home/jc_tm/project_repos/unicef-ai4d-poverty-mapping/notebooks/2023-01-17-initial-model-ph-mm-tl-kh/../../data/data_cache/ookla/processed/209c2544788b8e2bdf4db4685c50e26d.csv? True
2023-01-31 14:00:43.874 | DEBUG    | povertymapping.ookla:load_type_year_data:88 - Processed Ookla data for aoi, mobile 2019 (key: 209c2544788b8e2bdf4db4685c50e26d) found in filesystem. Loading in cache.
2023-01-31 14:00:44.047 | INFO     | povertymapping.nightlights:get_clipped_raster:414 - Retrieving clipped raster file /home/jc_tm/.geowrangler/nightlights/clip/b0d0551dd5a67c8eada595334f2655ed.tif
CPU times: user 8.07 s, sys: 127 ms, total: 8.2 s
Wall time: 8.27 s

Inspect the combined target country data

# Inspect the merged table: the DHS cluster columns plus the OSM, Ookla and
# nightlights feature columns, with per-column non-null counts and dtypes
country_data.info()
<class 'geopandas.geodataframe.GeoDataFrame'>
Int64Index: 455 entries, 0 to 454
Data columns (total 86 columns):
 #   Column                             Non-Null Count  Dtype   
---  ------                             --------------  -----   
 0   DHSCLUST                           455 non-null    int64   
 1   Wealth Index                       455 non-null    float64 
 2   DHSID                              455 non-null    object  
 3   DHSCC                              455 non-null    object  
 4   DHSYEAR                            455 non-null    float64 
 5   CCFIPS                             455 non-null    object  
 6   ADM1FIPS                           455 non-null    object  
 7   ADM1FIPSNA                         455 non-null    object  
 8   ADM1SALBNA                         455 non-null    object  
 9   ADM1SALBCO                         455 non-null    object  
 10  ADM1DHS                            455 non-null    float64 
 11  ADM1NAME                           455 non-null    object  
 12  DHSREGCO                           455 non-null    float64 
 13  DHSREGNA                           455 non-null    object  
 14  SOURCE                             455 non-null    object  
 15  URBAN_RURA                         455 non-null    object  
 16  LATNUM                             455 non-null    float64 
 17  LONGNUM                            455 non-null    float64 
 18  ALT_GPS                            455 non-null    float64 
 19  ALT_DEM                            455 non-null    float64 
 20  DATUM                              455 non-null    object  
 21  F21                                0 non-null      object  
 22  F22                                0 non-null      object  
 23  F23                                0 non-null      object  
 24  geometry                           455 non-null    geometry
 25  poi_count                          455 non-null    float64 
 26  atm_count                          455 non-null    float64 
 27  atm_nearest                        455 non-null    float64 
 28  bank_count                         455 non-null    float64 
 29  bank_nearest                       455 non-null    float64 
 30  bus_station_count                  455 non-null    float64 
 31  bus_station_nearest                455 non-null    float64 
 32  cafe_count                         455 non-null    float64 
 33  cafe_nearest                       455 non-null    float64 
 34  charging_station_count             455 non-null    float64 
 35  charging_station_nearest           455 non-null    float64 
 36  courthouse_count                   455 non-null    float64 
 37  courthouse_nearest                 455 non-null    float64 
 38  dentist_count                      455 non-null    float64 
 39  dentist_nearest                    455 non-null    float64 
 40  fast_food_count                    455 non-null    float64 
 41  fast_food_nearest                  455 non-null    float64 
 42  fire_station_count                 455 non-null    float64 
 43  fire_station_nearest               455 non-null    float64 
 44  food_court_count                   455 non-null    float64 
 45  food_court_nearest                 455 non-null    float64 
 46  fuel_count                         455 non-null    float64 
 47  fuel_nearest                       455 non-null    float64 
 48  hospital_count                     455 non-null    float64 
 49  hospital_nearest                   455 non-null    float64 
 50  library_count                      455 non-null    float64 
 51  library_nearest                    455 non-null    float64 
 52  marketplace_count                  455 non-null    float64 
 53  marketplace_nearest                455 non-null    float64 
 54  pharmacy_count                     455 non-null    float64 
 55  pharmacy_nearest                   455 non-null    float64 
 56  police_count                       455 non-null    float64 
 57  police_nearest                     455 non-null    float64 
 58  post_box_count                     455 non-null    float64 
 59  post_box_nearest                   455 non-null    float64 
 60  post_office_count                  455 non-null    float64 
 61  post_office_nearest                455 non-null    float64 
 62  restaurant_count                   455 non-null    float64 
 63  restaurant_nearest                 455 non-null    float64 
 64  social_facility_count              455 non-null    float64 
 65  social_facility_nearest            455 non-null    float64 
 66  supermarket_count                  455 non-null    float64 
 67  supermarket_nearest                455 non-null    float64 
 68  townhall_count                     455 non-null    float64 
 69  townhall_nearest                   455 non-null    float64 
 70  road_count                         455 non-null    float64 
 71  fixed_2019_mean_avg_d_kbps_mean    64 non-null     float64 
 72  fixed_2019_mean_avg_u_kbps_mean    64 non-null     float64 
 73  fixed_2019_mean_avg_lat_ms_mean    64 non-null     float64 
 74  fixed_2019_mean_num_tests_mean     64 non-null     float64 
 75  fixed_2019_mean_num_devices_mean   64 non-null     float64 
 76  mobile_2019_mean_avg_d_kbps_mean   173 non-null    float64 
 77  mobile_2019_mean_avg_u_kbps_mean   173 non-null    float64 
 78  mobile_2019_mean_avg_lat_ms_mean   173 non-null    float64 
 79  mobile_2019_mean_num_tests_mean    173 non-null    float64 
 80  mobile_2019_mean_num_devices_mean  173 non-null    float64 
 81  avg_rad_min                        455 non-null    float64 
 82  avg_rad_max                        455 non-null    float64 
 83  avg_rad_mean                       455 non-null    float64 
 84  avg_rad_std                        455 non-null    float64 
 85  avg_rad_median                     455 non-null    float64 
dtypes: float64(69), geometry(1), int64(1), object(15)
memory usage: 325.4+ KB
# Preview the first five rows of the combined DHS + feature table
country_data.head(n=5)
DHSCLUST Wealth Index DHSID DHSCC DHSYEAR CCFIPS ADM1FIPS ADM1FIPSNA ADM1SALBNA ADM1SALBCO ... mobile_2019_mean_avg_d_kbps_mean mobile_2019_mean_avg_u_kbps_mean mobile_2019_mean_avg_lat_ms_mean mobile_2019_mean_num_tests_mean mobile_2019_mean_num_devices_mean avg_rad_min avg_rad_max avg_rad_mean avg_rad_std avg_rad_median
0 1 32166.600000 TL201600000001 TL 2016.0 TT NULL NULL NULL NULL ... 120.487266 36.768028 0.16491 0.041227 0.013742 0.026005 1.205315 0.281820 0.357397 0.112378
1 2 -34063.923077 TL201600000002 TL 2016.0 TT NULL NULL NULL NULL ... NaN NaN NaN NaN NaN 0.009625 0.356664 0.104796 0.097928 0.052513
2 3 39230.590909 TL201600000003 TL 2016.0 TT NULL NULL NULL NULL ... 734.607003 147.058372 1.40893 0.049611 0.041888 0.010133 1.219182 0.328512 0.397882 0.118295
3 4 -82140.227273 TL201600000004 TL 2016.0 TT NULL NULL NULL NULL ... NaN NaN NaN NaN NaN -0.021887 0.284718 0.063929 0.089171 0.024247
4 5 -56203.423077 TL201600000005 TL 2016.0 TT NULL NULL NULL NULL ... NaN NaN NaN NaN NaN -0.039887 0.024366 0.002567 0.015578 0.005337

5 rows × 86 columns

Data Preparation

Split into labels and features

# Target column: the cluster-level wealth index
label_col = 'Wealth Index'

# Labels: a single-column frame holding only the target
labels = country_data[[label_col]]

# Features: every column added on top of the DHS cluster data (OSM, Ookla,
# nightlights). All original DHS columns -- including the label and the
# geometry -- are dropped; use country_data when cluster metadata is needed.
input_dhs_cols = dhs_gdf.columns
features = country_data.drop(columns=input_dhs_cols)

features.shape, labels.shape
((455, 61), (455, 1))
# Clean features.
# In the combined data only the Ookla columns contain NaNs; for now every
# missing value is imputed with 0.
# NOTE(review): a 0 fill treats "no Ookla data" like a measured value of 0
# (e.g. 0 ms latency) -- consider a missing-indicator or a different fill.
# TODO: Implement other cleaning steps
features = features.fillna(value=0)

Base Features List

The features can be subdivided by the source dataset

OSM

  • <poi_type>_count: number of points of interest (POI) of a specified type in the cluster
    • ex. atm_count: number of atms in cluster
    • poi_count: number of all POIs of all types in cluster
  • <poi_type>_nearest: distance of nearest POI of the specified type
    • ex. atm_nearest: distance of nearest ATM from that cluster
  • OSM POI types included: atm, bank, bus_station, cafe, charging_station, courthouse, dentist (clinic), fast_food, fire_station, food_court, fuel (gas station), hospital, library, marketplace, pharmacy, police, post_box, post_office, restaurant, social_facility, supermarket, townhall
  • road_count: computed from the OSM roads layer rather than the POI layer (there is no road_nearest feature)

Ookla

The network metrics features follow the following name convention:

<type>_<year>_<yearly aggregate>_<network variable>_<cluster aggregate>

  • type: kind of network connection measured
    • fixed: connection from fixed sources (landline, fiber, etc.)
    • mobile: connection from mobile devices
  • year: Year of source data
  • yearly aggregate: How data was aggregated into yearly data
    • Note: Ookla provides data per quarter, so a yearly mean takes the average across 4 quarters
    • For this model, we only aggregate by yearly mean
  • network variable: network characteristic described
    • avg_d_kbps: average download speed in kbps
    • avg_u_kbps: average upload speed in kbps
    • avg_lat_ms: average latency in ms
    • num_tests: number of speed tests taken
    • num_devices: number of devices measured
  • cluster aggregate: how the data was aggregated per cluster
    • Types: min, mean, max, median, std.
      • For this model: only mean is used
    • This is calculated using area zonal stats, which weights the average by the area of intersection of each Ookla tile with the cluster geometry.

Ex. fixed_2019_mean_avg_d_kbps_median would be the cluster median of the 2019 yearly average download speed (this model uses only the mean cluster aggregate).

Nightlights (VIIRS)

All nightlights features are taken as the zonal aggregate of the raster data per cluster

  • ex. avg_rad_mean: cluster mean of the average radiance
  • aggregations used: min, mean, max, median, std
# Confirm the feature matrix: only the engineered feature columns remain and,
# after the fillna(0) step, every column is fully non-null
features.info()
<class 'geopandas.geodataframe.GeoDataFrame'>
Int64Index: 455 entries, 0 to 454
Data columns (total 61 columns):
 #   Column                             Non-Null Count  Dtype  
---  ------                             --------------  -----  
 0   poi_count                          455 non-null    float64
 1   atm_count                          455 non-null    float64
 2   atm_nearest                        455 non-null    float64
 3   bank_count                         455 non-null    float64
 4   bank_nearest                       455 non-null    float64
 5   bus_station_count                  455 non-null    float64
 6   bus_station_nearest                455 non-null    float64
 7   cafe_count                         455 non-null    float64
 8   cafe_nearest                       455 non-null    float64
 9   charging_station_count             455 non-null    float64
 10  charging_station_nearest           455 non-null    float64
 11  courthouse_count                   455 non-null    float64
 12  courthouse_nearest                 455 non-null    float64
 13  dentist_count                      455 non-null    float64
 14  dentist_nearest                    455 non-null    float64
 15  fast_food_count                    455 non-null    float64
 16  fast_food_nearest                  455 non-null    float64
 17  fire_station_count                 455 non-null    float64
 18  fire_station_nearest               455 non-null    float64
 19  food_court_count                   455 non-null    float64
 20  food_court_nearest                 455 non-null    float64
 21  fuel_count                         455 non-null    float64
 22  fuel_nearest                       455 non-null    float64
 23  hospital_count                     455 non-null    float64
 24  hospital_nearest                   455 non-null    float64
 25  library_count                      455 non-null    float64
 26  library_nearest                    455 non-null    float64
 27  marketplace_count                  455 non-null    float64
 28  marketplace_nearest                455 non-null    float64
 29  pharmacy_count                     455 non-null    float64
 30  pharmacy_nearest                   455 non-null    float64
 31  police_count                       455 non-null    float64
 32  police_nearest                     455 non-null    float64
 33  post_box_count                     455 non-null    float64
 34  post_box_nearest                   455 non-null    float64
 35  post_office_count                  455 non-null    float64
 36  post_office_nearest                455 non-null    float64
 37  restaurant_count                   455 non-null    float64
 38  restaurant_nearest                 455 non-null    float64
 39  social_facility_count              455 non-null    float64
 40  social_facility_nearest            455 non-null    float64
 41  supermarket_count                  455 non-null    float64
 42  supermarket_nearest                455 non-null    float64
 43  townhall_count                     455 non-null    float64
 44  townhall_nearest                   455 non-null    float64
 45  road_count                         455 non-null    float64
 46  fixed_2019_mean_avg_d_kbps_mean    455 non-null    float64
 47  fixed_2019_mean_avg_u_kbps_mean    455 non-null    float64
 48  fixed_2019_mean_avg_lat_ms_mean    455 non-null    float64
 49  fixed_2019_mean_num_tests_mean     455 non-null    float64
 50  fixed_2019_mean_num_devices_mean   455 non-null    float64
 51  mobile_2019_mean_avg_d_kbps_mean   455 non-null    float64
 52  mobile_2019_mean_avg_u_kbps_mean   455 non-null    float64
 53  mobile_2019_mean_avg_lat_ms_mean   455 non-null    float64
 54  mobile_2019_mean_num_tests_mean    455 non-null    float64
 55  mobile_2019_mean_num_devices_mean  455 non-null    float64
 56  avg_rad_min                        455 non-null    float64
 57  avg_rad_max                        455 non-null    float64
 58  avg_rad_mean                       455 non-null    float64
 59  avg_rad_std                        455 non-null    float64
 60  avg_rad_median                     455 non-null    float64
dtypes: float64(61)
memory usage: 236.6 KB

Model Training

# Cross-validation parameters
train_test_seed = 42  # shared seed for the CV shuffling and the model
cv_num_splits, cv_num_repeats = 5, 5  # RepeatedKFold: 5 folds, repeated 5 times
test_size = 0.2
# NOTE(review): cv_col and test_size are not read by any cell shown here;
# presumably intended for region-grouped CV (GroupKFold is imported) and a
# hold-out split. Confirm before relying on them.
cv_col = 'ADM1NAME'

Create train/test cross-validation indices

# Cross validation: RepeatedKFold reruns K-fold CV with a fresh shuffle on each
# repeat, giving cv_num_splits * cv_num_repeats train/test splits in total.
# NOTE(review): cv_col / GroupKFold are not used, so folds are random across
# clusters rather than grouped by region; neighbouring clusters can end up on
# both sides of a split. Consider grouped CV if that leakage matters.
cv = RepeatedKFold(n_splits=cv_num_splits, n_repeats=cv_num_repeats, random_state=train_test_seed)

# Report the actual scheme (the split generator itself prints as an opaque object)
print(
    f"Performing {cv_num_splits}-fold CV repeated {cv_num_repeats} times "
    f"({cv.get_n_splits()} train/test splits)..."
)
Performing 5-fold CV...
<generator object _RepeatedSplits.split at 0x7fc72dca7cf0>

Instantiate model

For now, we will train a simple random forest model

from sklearn.ensemble import RandomForestRegressor

# Baseline regressor: a 100-tree random forest seeded with train_test_seed so
# reruns build the same forest; verbose=0 keeps training silent.
model = RandomForestRegressor(
    n_estimators=100,
    random_state=train_test_seed,
    verbose=0,
)
model
RandomForestRegressor(random_state=42)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Evaluate model training using cross-validation

We evaluate the model’s generalizability by training and testing it over different train/test splits.

Ideally, for R^2:
  • We want a high mean: the model performs well across the different train/test splits
  • We want a low standard deviation (std): the model performance is stable across training repetitions

# Fit and score the model on every split produced by `cv` (R^2 per split)
R_cv = cross_val_score(model, features.values, labels.values.ravel(), cv=cv)
print("Cross validation scores are: ", R_cv)

# Summarize across splits, rounded to two decimals
scores = np.asarray(R_cv)
cv_mean = round(scores.mean(), 2)
cv_std = round(scores.std(), 2)
print(f"Cross validation R^2 mean: {cv_mean}")
print(f"Cross validation R^2 std: {cv_std}")
Cross validation scores are:  [0.52314563 0.71589212 0.62702992 0.62876376 0.60727356 0.59854787
 0.64459262 0.5645657  0.6475496  0.58516898 0.64333412 0.57955347
 0.68957151 0.5678542  0.39258228 0.57938886 0.73362811 0.55649354
 0.62505678 0.55131979 0.58291504 0.51003125 0.66213206 0.57155403
 0.62023663]
Cross validation R^2 mean: 0.6
Cross validation R^2 std: 0.07

Train the final model

For training the final model, we train on all the available data.

# Final model: fit on all clusters (fit returns the estimator, which the notebook displays)
model.fit(X=features.values, y=labels.values.ravel())
RandomForestRegressor(random_state=42)
In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook.
On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.

Model Evaluation

SHAP Feature Importance

# Explain the fitted random forest with SHAP's tree explainer.
# The result is a 2-D array of attributions -- presumably (n_rows, n_features)
# aligned with `features`; confirm if indexing into it directly.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(features)
shap_values
array([[-2.78288444e+02,  2.54662269e+00, -3.61114762e+02, ...,
        -3.76359324e+03,  3.56991434e+02,  3.14029666e+02],
       [ 6.83641498e-01,  3.08155449e+00, -4.04342108e+02, ...,
        -5.44415663e+03, -2.67927065e+03, -1.10857191e+03],
       [-3.68998815e+02,  2.09687769e+00, -2.06384964e+02, ...,
        -3.49928320e+03, -7.77595317e+02, -2.66918649e+02],
       ...,
       [-7.92942075e+02,  1.30255281e+01, -2.02013077e+02, ...,
         3.55792366e+03,  5.90910492e+03, -7.51810625e+02],
       [ 6.99005836e+01,  3.12683522e+00, -2.89566630e+02, ...,
        -4.03058008e+03, -9.21508542e+02,  6.59822668e+02],
       [-2.77892341e+02,  2.80698684e+00, -4.20183284e+02, ...,
        -6.95179308e+03, -2.06168794e+03, -5.03960873e+03]])
# Bar-chart summary of SHAP feature importance, labelled with the feature column names
shap.summary_plot(shap_values, features, feature_names=features.columns, plot_type="bar")

# Default (non-bar) summary plot; the raw array is passed here, so feature_names
# supplies the column labels
shap.summary_plot(shap_values, features.values, feature_names=features.columns)
No data for colormapping provided via 'c'. Parameters 'vmin', 'vmax' will be ignored

Save Model

# Persist the trained model next to the notebook (path is relative to the CWD)
model_save_path = "./model_tl.pkl"
with open(model_save_path, mode="wb") as model_file:
    pickle.dump(model, model_file)