Skip to content

Instantly share code, notes, and snippets.

@prl900
Created April 21, 2023 00:36
Show Gist options
  • Save prl900/3af811bb2ecd56dd381a46cf58b19ded to your computer and use it in GitHub Desktop.
Save prl900/3af811bb2ecd56dd381a46cf58b19ded to your computer and use it in GitHub Desktop.
from glob import glob
from sklearn.ensemble import RandomForestRegressor
import pandas as pd
import rasterio as rio
import numpy as np
print("Starting!")

# Training table: the 'WCF' column is the regression target and every other
# column is used as a predictor.
path = 'WCF_traindata_14122019.csv'
df = pd.read_csv(path)
# Strip spaces from the header names so columns can be addressed reliably.
df.columns = df.columns.str.replace(' ', '', regex=False)
# Drop every row that has a missing value in any column.
df = df.dropna()

# Define dependent variable
Y = df['WCF']
# Define independent variables
X = df.drop(labels=['WCF'], axis=1)

model = RandomForestRegressor(
    n_estimators=100,
    # `criterion` is intentionally left at its default, which is mean squared
    # error in every scikit-learn release. The old spelling 'mse' was removed
    # in 1.2 and the new spelling 'squared_error' only exists from 1.0, so
    # naming either one explicitly breaks on some versions.
    min_samples_split=2,
    min_samples_leaf=50,
    min_weight_fraction_leaf=0.0,
    # 1.0 = consider all features at each split. This is what the removed
    # 'auto' option meant for regressors, and a float is accepted by both
    # old and new scikit-learn releases.
    max_features=1.0,
    min_impurity_decrease=0.0,
    bootstrap=True,
    oob_score=False,
    random_state=20,  # fixed seed so the fitted forest is reproducible
    verbose=0,
    warm_start=False,
)

# Fit model
model.fit(X, Y)
print("Trained!")
# Bands read for every tile, in the order they are handed to the model.
# NOTE(review): assumes the predictor columns of the training CSV are these
# names in this order - confirm, since recent scikit-learn versions check
# feature names at predict time.
bnames = ['blue', 'green', 'red', 'nir', 'swir1', 'swir2']
year = 2022


def tile_indices(fname):
    """Return the ``(x, y)`` tile indices parsed from a raster file path.

    The 7th underscore-separated token of the file name is taken as the tile
    id and its characters 1-2 and 4-5 are read as integers, e.g. a tile id of
    ``x40y13`` gives ``(40, 13)``.
    """
    tile_id = fname.split('/')[-1].split('_')[6]
    return int(tile_id[1:3]), int(tile_id[4:6])


def wcf_percent(wcf, red, nir):
    """Mask a predicted cover fraction and return it as ``uint8`` percent.

    ``wcf`` is clipped to [0, 1]; pixels with negative NDVI, or whose cover
    fraction exceeds twice the FPAR derived from NDVI, are set to 0. The
    result is multiplied by 100 and truncated to ``uint8``. ``wcf`` itself is
    not modified. ``red`` and ``nir`` must have the same shape as ``wcf``.
    """
    # A zero denominator yields NaN/inf; every comparison against NaN below is
    # False, so such pixels are simply left unmasked. Only the runtime
    # warnings are silenced here.
    with np.errstate(divide='ignore', invalid='ignore'):
        ndvi = (nir - red) / (nir + red)
    # Linear rescale of NDVI from [0.1, 0.9] to [0, 1].
    fpar = np.clip((ndvi - 0.1) / (0.9 - 0.1), 0, 1)

    masked = np.clip(wcf, 0, 1)
    # Negative NDVI is normally associated with water.
    masked[ndvi < 0] = 0
    # A cover fraction more than twice the FPAR is not realistic.
    masked[masked > 2 * fpar] = 0
    return (masked * 100).astype(np.uint8)


# One swir2 file per tile is used to discover the tiles; the sibling band
# files are derived from its name.
pattern = (
    f"/g/data/jw04/ga/ga_ls8c_nbart_gm_cyear_3/3-0-0/*/*/{year}--P1Y/"
    f"ga_ls8c_nbart_gm_cyear_*_{year}--P1Y_final_swir2.tif"
)
for fname in glob(pattern):
    x, y = tile_indices(fname)
    print(x, y)

    rast_stack = []
    for bname in bnames:
        fnameb = fname.replace("swir2", bname)
        # Close each dataset as soon as its first band has been read.
        with rio.open(fnameb) as ds:
            rast_stack.append(ds.read(1))
    stack = np.stack(rast_stack, axis=2).astype(np.float32)

    # One row per pixel, one column per band.
    # NOTE(review): the /10000 presumably converts stored integers to
    # reflectance in the units of the training data - confirm.
    features = pd.DataFrame(
        stack.reshape((-1, len(bnames))) / 10000,
        columns=bnames,
    )
    # Reshape back to the raster grid that was actually read rather than
    # assuming a fixed tile size.
    wcf = model.predict(features).reshape(stack.shape[:2])

    wcf = wcf_percent(
        wcf,
        red=stack[:, :, bnames.index('red')],
        nir=stack[:, :, bnames.index('nir')],
    )
    np.save(f"/g/data/ub8/au/blobs/wcf_{x:02d}_{y:02d}_{year}.npy", wcf)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment