Created February 22, 2024 00:01
Save calebrob6/658edaa59c68f0c0a510f8d9d7a41458 to your computer and use it in GitHub Desktop.
Verify how the `classes` argument of `smp.losses.JaccardLoss` (segmentation_models_pytorch) affects the multiclass Jaccard loss.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# How `classes` affects `smp.losses.JaccardLoss`\n",
"\n",
"Compares the multiclass Jaccard loss from `segmentation_models_pytorch` computed with `classes=None` against the loss restricted to single classes and to class subsets, and checks what happens when `classes` is passed as an int instead of a list. Inputs are random softmax probabilities (`from_logits=False`) and random integer labels."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from segmentation_models_pytorch.losses import JaccardLoss"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Seed the RNG so the random inputs, and therefore every loss value below, are reproducible.\n",
"torch.manual_seed(0)\n",
"\n",
"# y_pred: (N=1, C=4, H=256, W=256) class probabilities; y_true: (N=1, H=256, W=256) labels in 0..3.\n",
"y_pred = torch.rand(1, 4, 256, 256)\n",
"# torch.rand draws from [0, 1), so zeroing channel 1 makes its score no higher than any other\n",
"# class at every pixel; after the softmax its probability is the smallest but still non-zero.\n",
"y_pred[:,1,:,:] = 0\n",
"y_pred = nn.functional.softmax(y_pred, dim=1)\n",
"\n",
"y_true = torch.randint(0, 4, size=(1, 256, 256))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.8589)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Baseline: classes=None.\n",
"jaccard_all = JaccardLoss(\n",
"    mode=\"multiclass\",\n",
"    classes=None,\n",
"    log_loss=False,\n",
"    from_logits=False,\n",
")\n",
"jaccard_all(y_pred, y_true)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.8589)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# NOTE: `classes` is a list of class indices (e.g. classes=[0, 2, 3]), not a class count.\n",
"# Passing the int 4 gives the same value as classes=None rather than an error, i.e. the int\n",
"# appears to be silently ignored by the smp version used here -- re-check if smp is upgraded.\n",
"loss = JaccardLoss(mode=\"multiclass\", classes=4, log_loss=False, from_logits=False)\n",
"baseline = JaccardLoss(mode=\"multiclass\", classes=None, log_loss=False, from_logits=False)\n",
"assert torch.isclose(loss(y_pred, y_true), baseline(y_pred, y_true))\n",
"loss(y_pred, y_true)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.8490)\n",
"tensor(0.8893)\n",
"tensor(0.8479)\n",
"tensor(0.8493)\n"
]
}
],
"source": [
"# Loss restricted to one class at a time (one value per channel of y_pred).\n",
"per_class_losses = []\n",
"for i in range(y_pred.shape[1]):\n",
"    loss = JaccardLoss(mode=\"multiclass\", classes=[i], log_loss=False, from_logits=False)\n",
"    per_class_losses.append(loss(y_pred, y_true))\n",
"    print(per_class_losses[-1])\n",
"\n",
"# Check the relationship seen in the outputs: the classes=None loss is the plain mean\n",
"# of the single-class losses.\n",
"all_classes_loss = JaccardLoss(mode=\"multiclass\", classes=None, log_loss=False, from_logits=False)(y_pred, y_true)\n",
"assert torch.isclose(torch.stack(per_class_losses).mean(), all_classes_loss)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.8589)\n"
]
}
],
"source": [
"# Every class index listed explicitly.\n",
"all_listed = JaccardLoss(\n",
"    mode=\"multiclass\", classes=[0,1,2,3], log_loss=False, from_logits=False\n",
")\n",
"print(all_listed(y_pred, y_true))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(0.8487)\n"
]
}
],
"source": [
"# Leave out class 1 (the channel zeroed before the softmax).\n",
"subset_loss = JaccardLoss(\n",
"    mode=\"multiclass\", classes=[0,2,3], log_loss=False, from_logits=False\n",
")\n",
"print(subset_loss(y_pred, y_true))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Findings (from the outputs recorded above)\n",
"\n",
"- `classes=None` equals the mean of the single-class losses `classes=[0]` to `classes=[3]`: (0.8490 + 0.8893 + 0.8479 + 0.8493) / 4 = 0.8589 (rounded).\n",
"- `classes=[0,1,2,3]` gives the same value as `classes=None` (0.8589).\n",
"- `classes=[0,2,3]` is the mean over those three classes only: (0.8490 + 0.8479 + 0.8493) / 3 = 0.8487 (rounded).\n",
"- `classes=4` (an int, not a list) also gives 0.8589, i.e. it behaves like `classes=None` instead of raising an error or selecting a class.\n",
"\n",
"Exact values depend on the random inputs; the relationships between them are what this notebook checks."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "geo",
"language": "python",
"name": "geo"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.