Skip to content

Instantly share code, notes, and snippets.

@ShihengDuan
Last active September 14, 2023 03:32
Show Gist options
  • Save ShihengDuan/62883ff1b04f6b0b9325a0e808992bf4 to your computer and use it in GitHub Desktop.
Save ShihengDuan/62883ff1b04f6b0b9325a0e808992bf4 to your computer and use it in GitHub Desktop.
Calculate KGE and NSE for 3-D xarray data (time, lat, lon)
def calculate_kge(obs_data, model_data):
    """Return the Kling-Gupta efficiency (KGE) reduced over ``time``.

    KGE = 1 - sqrt((r - 1)^2 + (alpha - 1)^2 + (beta - 1)^2), where

    * ``r`` is the Pearson correlation between model and observations,
    * ``alpha`` is the variability ratio ``std(model) / std(obs)``,
    * ``beta`` is the bias ratio ``mean(model) / mean(obs)``.

    Args:
        obs_data: Observations. Must support ``.mean(dim='time')`` and
            ``.std(dim='time')`` (presumably an xarray DataArray laid out
            as (time, lat, lon) -- confirm against callers).
        model_data: Simulated values on the same grid as ``obs_data``.

    Returns:
        The KGE with the ``time`` dimension reduced away; 1 is a perfect
        match.
    """
    mean_obs = obs_data.mean(dim='time')
    mean_model = model_data.mean(dim='time')
    obs_std = obs_data.std(dim='time')
    model_std = model_data.std(dim='time')

    # Pearson correlation written out as covariance / sqrt(var_x * var_y).
    model_anom = model_data - mean_model
    obs_anom = obs_data - mean_obs
    covariance = (model_anom * obs_anom).sum(dim='time')
    normaliser = np.sqrt(
        np.square(model_anom).sum(dim='time')
        * np.square(obs_anom).sum(dim='time')
    )
    r = covariance / normaliser

    # Both ratios are simulated-over-observed, per the standard definition.
    # The ratios are not symmetric under inversion, so the order matters.
    alpha = model_std / obs_std
    beta = mean_model / mean_obs

    return 1 - np.sqrt((r - 1) ** 2 + (alpha - 1) ** 2 + (beta - 1) ** 2)
def calculate_r2(obs_data, model_data):
    """Return 1 - SSE/TSS (the Nash-Sutcliffe form of R^2) over ``time``.

    Args:
        obs_data: Observations. Must support ``.mean(dim='time')``
            (presumably an xarray DataArray laid out as (time, lat, lon)
            -- confirm against callers).
        model_data: Simulated values on the same grid as ``obs_data``.

    Returns:
        ``1 - mean((model - obs)^2) / mean((obs - mean(obs))^2)`` with the
        ``time`` dimension reduced away.
    """
    # Mean squared model error.
    sse = np.square(model_data - obs_data).mean(dim='time')

    # Mean squared deviation of the observations from their own time mean.
    obs_mean = obs_data.mean(dim='time')
    tss = np.square(obs_data - obs_mean).mean(dim='time')

    return 1 - sse / tss
# snowpack metrics
def calculate_metrics(swe, start_year, end_year):
    """Compute per-water-year snowpack metrics from ``swe``.

    For every ``year`` in ``start_year..end_year`` (inclusive) the slice
    from ``<year>-10-01`` through ``<year + 1>-09-30`` is analysed.

    Args:
        swe: Snow water equivalent. Must support
            ``.sel(time=slice(...))`` and reductions over ``dim='time'``
            (presumably an xarray DataArray -- confirm against callers).
        start_year: First starting year to process.
        end_year: Last starting year to process (inclusive).

    Returns:
        A tuple ``(acc_times, melt_times, peak_date, peak_SWEs)`` of lists
        with one entry per processed year:

        * ``acc_times``: index within the slice of the first time step
          whose value is at least 10% of that year's peak.
        * ``melt_times``: index within the *time-reversed* slice of the
          first such step, i.e. how many steps before the end of the
          slice the last qualifying step lies.
        * ``peak_date``: index within the slice of the maximum, with NaNs
          replaced by -1 before the search.
        * ``peak_SWEs``: the maximum over the slice, skipping NaNs.

        ``acc_times`` and ``melt_times`` are masked where no time step
        reaches the threshold.
    """
    peak_SWEs = []
    peak_date = []
    acc_times = []
    melt_times = []

    for year in range(start_year, end_year + 1):
        print(year)
        water_year = swe.sel(
            time=slice(str(year) + '-10-01', str(year + 1) + '-09-30')
        )

        max_SWE = water_year.max(dim='time', skipna=True)
        # Fill NaNs with -1 before argmax so all-NaN cells still yield an
        # index.
        max_indices = water_year.fillna(-1).argmax(dim='time', skipna=True)
        peak_SWEs.append(max_SWE)
        peak_date.append(max_indices)

        # Mask of time steps at or above 10% of the seasonal peak.
        exceeds_threshold = water_year >= 0.1 * max_SWE
        ever_exceeds = exceeds_threshold.any(dim='time')

        # argmax on a boolean mask gives the first True along time.
        acc_time = exceeds_threshold.argmax(dim='time').where(ever_exceeds)

        # Reverse explicitly along 'time' so the result does not depend on
        # which axis comes first.
        reversed_mask = exceeds_threshold.isel(time=slice(None, None, -1))
        melt_time = reversed_mask.argmax(dim='time').where(ever_exceeds)

        acc_times.append(acc_time)
        melt_times.append(melt_time)

    return acc_times, melt_times, peak_date, peak_SWEs
@ShihengDuan
Copy link
Author

Use matrix/xarray to avoid loops.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment