Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save nassarofficial/71e85ee1000ff178db129a692f29424d to your computer and use it in GitHub Desktop.
Save nassarofficial/71e85ee1000ff178db129a692f29424d to your computer and use it in GitHub Desktop.
Active Contour Model Greedy Implementation
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Active Contour Model: Greedy Algorithm of Williams and Shah"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Input Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"http://glimpglobe.com/proj/shark1.png\">"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import cv2\n",
"import numpy as np\n",
"from matplotlib import pyplot as plt\n",
"import math"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Contour Selection"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Pick using user clicks"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def contour_selection(image, winname=-1, distance=10):\n",
"    \"\"\"Let the user draw an initial snake on `image` with the mouse.\n",
"\n",
"    A left click toggles drawing mode; while drawing, moving the mouse adds a\n",
"    point whenever it is more than `distance` pixels from the previous one.\n",
"    Press any key to finish.\n",
"\n",
"    Args:\n",
"        image: image to draw on (grayscale or 3-channel).\n",
"        winname: OpenCV window name; -1 selects the default 'Test'.\n",
"        distance: minimum spacing, in pixels, between stored points.\n",
"\n",
"    Returns:\n",
"        List of np.array([x, y]) points in the order they were drawn.\n",
"    \"\"\"\n",
"    global flag_drawing\n",
"    if winname == -1:\n",
"        winname = 'Test'\n",
"    # on_mouse toggles this flag; every selection starts in the not-drawing state.\n",
"    flag_drawing = False\n",
"\n",
"    overlay_image = image.copy()\n",
"    if overlay_image.ndim == 2:\n",
"        # Grayscale input needs 3 channels for the coloured text to show;\n",
"        # 3-channel input is used as-is.\n",
"        overlay_image = cv2.cvtColor(overlay_image, cv2.COLOR_GRAY2RGB)\n",
"    point_list = []\n",
"    cv2.putText(overlay_image, 'Click on image to draw initial snake', (200, 30),\n",
"                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 0))\n",
"    cv2.imshow(winname, overlay_image)\n",
"\n",
"    cv2.setMouseCallback(winname, on_mouse, param=[point_list, image, winname, distance])\n",
"    cv2.waitKey()\n",
"    cv2.imshow(winname, image)\n",
"    # Detach the drawing callback now that the selection is finished.\n",
"    cv2.setMouseCallback(winname, lambda event, x, y, flags, param: None)\n",
"    return point_list\n",
"\n",
"def on_mouse(event, x, y, flag, param):\n",
"    \"\"\"OpenCV mouse callback that records and draws the snake being selected.\n",
"\n",
"    Args:\n",
"        event: OpenCV mouse event code.\n",
"        x, y: cursor position in pixels.\n",
"        flag: OpenCV event flags (unused).\n",
"        param: [contour, image, winname, distance] -- the list collecting\n",
"            np.array([x, y]) points, the image to draw on, the window to\n",
"            refresh and the minimum spacing between stored points.\n",
"    \"\"\"\n",
"    global flag_drawing\n",
"    if 'flag_drawing' not in globals():\n",
"        # Initialise the drawing toggle if no caller has defined it yet.\n",
"        flag_drawing = False\n",
"    contour, image, winname, distance = param[:4]\n",
"\n",
"    if event == cv2.EVENT_LBUTTONDOWN:\n",
"        flag_drawing = not flag_drawing\n",
"\n",
"    if event == cv2.EVENT_MOUSEMOVE and flag_drawing:\n",
"        xy = np.array([x, y])\n",
"        # The first point is always kept; later ones only if far enough away.\n",
"        if not contour or np.linalg.norm(xy - contour[-1]) > distance:\n",
"            contour.append(xy)\n",
"\n",
"    overlay_image = image.copy()\n",
"    if overlay_image.ndim == 2:\n",
"        overlay_image = cv2.cvtColor(overlay_image, cv2.COLOR_GRAY2RGB)\n",
"    for px, py in contour:\n",
"        # OpenCV drawing functions expect plain Python ints for coordinates.\n",
"        cv2.circle(overlay_image, (int(px), int(py)), 2, 255, 2)\n",
"    if contour:\n",
"        # polylines needs int32 vertices; skip it while the contour is empty.\n",
"        cv2.polylines(overlay_image, [np.array(contour, dtype=np.int32)], False, (0, 0, 255), 1)\n",
"    cv2.imshow(winname, overlay_image)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Image Energy"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def ImgEnrg(img, sigma):\n",
"    \"\"\"Image energy: gradient magnitude of a Gaussian-blurred copy of `img`.\n",
"\n",
"    Args:\n",
"        img: input image.\n",
"        sigma: blur scale; the square kernel is ceil(3 * sigma) pixels wide,\n",
"            raised to at least 1 and bumped to the next odd size because\n",
"            cv2.GaussianBlur only accepts positive odd kernel sizes. sigmaX is\n",
"            passed as 0, so OpenCV derives the Gaussian's standard deviation\n",
"            from the kernel size.\n",
"\n",
"    Returns:\n",
"        float64 array sqrt(Gx**2 + Gy**2) of 5x5 Sobel derivatives.\n",
"    \"\"\"\n",
"    ksize = max(1, int(math.ceil(3 * sigma)))\n",
"    if ksize % 2 == 0:\n",
"        ksize += 1\n",
"    blur = cv2.GaussianBlur(img, (ksize, ksize), 0)\n",
"\n",
"    sobelx = cv2.Sobel(blur, cv2.CV_64F, 1, 0, ksize=5)\n",
"    sobely = cv2.Sobel(blur, cv2.CV_64F, 0, 1, ksize=5)\n",
"    return np.sqrt(sobelx ** 2 + sobely ** 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Average pairwise distance between snake points"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def getAvgDist(points, n):\n",
"    \"\"\"Average Euclidean distance over unordered pairs of rows of `points`.\n",
"\n",
"    Sums the distances of all pairs (i, j) with i < n and j > i, using every\n",
"    column of `points`, and divides by N * (N - 1) / 2 where N = len(points).\n",
"    Returns 0.0 when there are fewer than two rows (no pairs to average).\n",
"    \"\"\"\n",
"    num_rows = points.shape[0]\n",
"    if num_rows < 2:\n",
"        return 0.0\n",
"    tot = 0.\n",
"    # range (not xrange) keeps this valid on both Python 2 and Python 3.\n",
"    for i in range(n):\n",
"        tot += np.sqrt(((points[i + 1:] - points[i]) ** 2).sum(1)).sum()\n",
"\n",
"    return tot / (num_rows * (num_rows - 1) / 2.)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Circular index arithmetic (1-based, wrapping around the closed snake)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def getModulo(i, n):\n",
"    \"\"\"Wrap index `i` onto the 1-based circular range 1..n.\n",
"\n",
"    Returns:\n",
"        (current, previous, following): current is i mod n with 0 mapped to n;\n",
"        previous and following are its neighbours, wrapping 1 -> n and n -> 1.\n",
"    \"\"\"\n",
"    current = np.remainder(i, n)\n",
"    if current == 0:\n",
"        current = n\n",
"    previous = n if current - 1 == 0 else current - 1\n",
"    following = 1 if current + 1 > n else current + 1\n",
"    return current, previous, following"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Greedy Algorithm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def _scaleByMax(values, mask):\n",
"    \"\"\"Return `values` divided by their maximum over `mask` (unchanged if that maximum is <= 0).\"\"\"\n",
"    peak = values[mask].max()\n",
"    return values / peak if peak > 0 else values\n",
"\n",
"\n",
"def GreedyAlgorithm(points, img, alpha, beta, gamma, s, sigma, maxIt,\n",
"                    cThreshold=0.3, imgEnrgT=120, minMoved=3):\n",
"    \"\"\"Evolve a closed snake with the greedy algorithm of Williams and Shah.\n",
"\n",
"    Every sweep visits the snake points in random order and moves each one to\n",
"    the position in its s x s neighbourhood that minimises\n",
"        alpha * Econt + beta * Ecurv + gamma * Eimage,\n",
"    where Econt = |avgDist - |v - v_prev|| and Ecurv = |v_prev - 2 v + v_next|**2\n",
"    (each divided by its maximum over the neighbourhood) and\n",
"    Eimage = (min - mag) / (max - min) of the gradient magnitude `mag`.\n",
"    After every sweep, beta is set to 0 at points whose curvature is a local\n",
"    maximum above `cThreshold` and whose gradient magnitude exceeds `imgEnrgT`.\n",
"\n",
"    Args:\n",
"        points: (n, >=2) array-like with n >= 3; the first two columns are the\n",
"            initial (x, y) = (column, row) pixel coordinates. The snake is\n",
"            closed: the last point is joined to the first.\n",
"        img: image passed to ImgEnrg to build the gradient-magnitude map.\n",
"        alpha, beta, gamma: initial continuity, curvature and image weights.\n",
"        s: neighbourhood width; offsets are range(s) - s // 2 on each axis.\n",
"        sigma: blur scale forwarded to ImgEnrg.\n",
"        maxIt: maximum number of sweeps.\n",
"        cThreshold: curvature threshold for relaxing beta.\n",
"        imgEnrgT: gradient-magnitude threshold for relaxing beta.\n",
"        minMoved: stop once fewer than this many points moved in a sweep.\n",
"\n",
"    Returns:\n",
"        (n, 5) float array with columns x, y, alpha, beta, gamma.\n",
"\n",
"    Raises:\n",
"        ValueError: if `points` is not (n, >=2) with n >= 3, or s < 1.\n",
"    \"\"\"\n",
"    coords = np.asarray(points, dtype=float)\n",
"    if coords.ndim != 2 or coords.shape[0] < 3 or coords.shape[1] < 2:\n",
"        raise ValueError('points must be an (n, 2) array with n >= 3')\n",
"    if s < 1:\n",
"        raise ValueError('s must be a positive neighbourhood width')\n",
"    n = coords.shape[0]\n",
"\n",
"    enrgImg = ImgEnrg(img, sigma)\n",
"    rows, cols = enrgImg.shape[:2]\n",
"\n",
"    # Per-point state: rounded in-image position plus the point's own weights.\n",
"    snake = np.zeros((n, 5))\n",
"    snake[:, 0] = np.clip(np.rint(coords[:, 0]), 0, cols - 1)\n",
"    snake[:, 1] = np.clip(np.rint(coords[:, 1]), 0, rows - 1)\n",
"    snake[:, 2] = alpha\n",
"    snake[:, 3] = beta\n",
"    snake[:, 4] = gamma\n",
"\n",
"    # Candidate moves (dx, dy) covering the s x s neighbourhood; `stay` is (0, 0).\n",
"    steps = np.arange(s) - s // 2\n",
"    dy, dx = np.meshgrid(steps, steps, indexing='ij')\n",
"    offsets = np.column_stack((dx.ravel(), dy.ravel()))\n",
"    stay = int(np.flatnonzero((offsets == 0).all(axis=1))[0])\n",
"\n",
"    magAtPoint = np.zeros(n)  # gradient magnitude at each point's chosen position\n",
"    avgDist = getAvgDist(snake[:, :2], n)\n",
"\n",
"    for _ in range(int(maxIt)):\n",
"        pointsMoved = 0\n",
"        for i in np.random.permutation(n):\n",
"            prev = snake[(i - 1) % n, :2]\n",
"            nxt = snake[(i + 1) % n, :2]\n",
"            cand = snake[i, :2] + offsets\n",
"            inside = ((cand[:, 0] >= 0) & (cand[:, 0] < cols) &\n",
"                      (cand[:, 1] >= 0) & (cand[:, 1] < rows))\n",
"            # Clipped lookups avoid IndexError; out-of-image candidates get inf energy below.\n",
"            mag = enrgImg[np.clip(cand[:, 1], 0, rows - 1).astype(int),\n",
"                          np.clip(cand[:, 0], 0, cols - 1).astype(int)]\n",
"\n",
"            # Image term: the strongest edge scores lowest; a nearly flat\n",
"            # neighbourhood (range < 5) is stretched so noise does not dominate.\n",
"            enrgMin, enrgMax = mag[inside].min(), mag[inside].max()\n",
"            if enrgMax - enrgMin < 5:\n",
"                enrgMin = enrgMax - 5\n",
"            eImg = (enrgMin - mag) / (enrgMax - enrgMin)\n",
"\n",
"            eCont = np.abs(avgDist - np.sqrt(((cand - prev) ** 2).sum(axis=1)))\n",
"            eCurv = ((prev - 2 * cand + nxt) ** 2).sum(axis=1)\n",
"\n",
"            eSnake = (snake[i, 2] * _scaleByMax(eCont, inside) +\n",
"                      snake[i, 3] * _scaleByMax(eCurv, inside) +\n",
"                      snake[i, 4] * eImg)\n",
"            eSnake[~inside] = np.inf\n",
"            best = int(np.argmin(eSnake))\n",
"            # Only move on a strict improvement so ties do not make points drift.\n",
"            if eSnake[best] < eSnake[stay]:\n",
"                snake[i, :2] = cand[best]\n",
"                pointsMoved += 1\n",
"            else:\n",
"                best = stay\n",
"            magAtPoint[i] = mag[best]\n",
"\n",
"        # Curvature estimate: squared difference of consecutive unit tangents.\n",
"        u = snake[:, :2] - np.roll(snake[:, :2], 1, axis=0)  # v_i - v_{i-1}\n",
"        lengths = np.sqrt((u ** 2).sum(axis=1))\n",
"        lengths[lengths == 0] = 1.0\n",
"        tangents = u / lengths[:, np.newaxis]\n",
"        curv = ((np.roll(tangents, -1, axis=0) - tangents) ** 2).sum(axis=1)\n",
"        for i in range(n):\n",
"            if (curv[i] > curv[(i - 1) % n] and curv[i] > curv[(i + 1) % n]\n",
"                    and curv[i] > cThreshold and magAtPoint[i] > imgEnrgT\n",
"                    and snake[i, 3] != 0):\n",
"                snake[i, 3] = 0\n",
"                print('Relaxed beta for point nr. %d' % i)\n",
"\n",
"        if pointsMoved < minMoved:\n",
"            break\n",
"        avgDist = getAvgDist(snake[:, :2], n)\n",
"\n",
"    return snake\n",
"\n",
"def displaypoints(img, points):\n",
"    \"\"\"Show `img` with the snake `points` overlaid as a closed red polygon.\n",
"\n",
"    Args:\n",
"        img: image array, or a path that plt.imread can load.\n",
"        points: (n, >=2) array whose first two columns are (x, y) pixel\n",
"            coordinates, e.g. the output of GreedyAlgorithm.\n",
"    \"\"\"\n",
"    if not isinstance(img, np.ndarray):\n",
"        img = plt.imread(img)\n",
"    pts = np.asarray(points, dtype=float)[:, :2]\n",
"    closed = np.vstack((pts, pts[:1]))  # repeat the first point to close the loop\n",
"\n",
"    _, ax = plt.subplots()\n",
"    ax.imshow(img, cmap='gray')\n",
"    ax.plot(closed[:, 0], closed[:, 1], 'r-', linewidth=1)\n",
"    ax.scatter(pts[:, 0], pts[:, 1], s=10, c='r')\n",
"    ax.set_title('Active contour')\n",
"    ax.set_axis_off()\n",
"    plt.show()\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Here (Main)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"if __name__ == '__main__':\n",
"    img = cv2.imread('shark1.png', 0)  # flag 0 -> load as grayscale\n",
"    if img is None:\n",
"        raise IOError('could not read shark1.png')\n",
"    np.random.seed(0)  # GreedyAlgorithm visits points in random order\n",
"\n",
"    pointselection = 'wd'  # set to 'user' to draw the initial snake with the mouse\n",
"    if pointselection == 'user':\n",
"        points = np.array(contour_selection(img, 'Selection of points'))\n",
"    else:\n",
"        # A 2 x 50 matrix copied from MATLAB's column-chunked display: columns\n",
"        # 1-21, 22-42 and 43-50, each chunk listing its first row then its\n",
"        # second row. The rows sample a circle of radius 70 centred on\n",
"        # (150, 110) and are used here as x and y respectively.\n",
"        raw = np.array([\n",
"            219, 218, 215, 211, 207, 201, 195, 188, 180, 172, 163, 154, 146, 137, 128, 120, 112, 105, 99, 93, 89,\n",
"            119, 127, 136, 144, 151, 158, 164, 169, 173, 177, 179, 180, 180, 179, 177, 173, 169, 164, 158, 151, 144,\n",
"            85, 82, 81, 80, 81, 82, 85, 89, 93, 99, 105, 112, 120, 128, 137, 146, 154, 163, 172, 180, 188,\n",
"            136, 127, 119, 110, 101, 93, 84, 76, 69, 62, 56, 51, 47, 43, 41, 40, 40, 41, 43, 47, 51,\n",
"            195, 201, 207, 211, 215, 218, 219, 220,\n",
"            56, 62, 69, 76, 84, 93, 101, 110])\n",
"        xs = np.concatenate((raw[0:21], raw[42:63], raw[84:92]))\n",
"        ys = np.concatenate((raw[21:42], raw[63:84], raw[92:100]))\n",
"        points = np.column_stack((xs, ys))\n",
"\n",
"    alpha = 0.05  # weight of the continuity term\n",
"    beta = 1      # weight of the curvature term\n",
"    gamma = 1.2   # weight of the image-energy term\n",
"    s = 5         # neighbourhood width (s x s candidate positions)\n",
"    sigma = 15    # Gaussian blur scale for the image energy\n",
"    maxIt = 200   # maximum number of snake iterations\n",
"\n",
"    C = GreedyAlgorithm(points, img, alpha, beta, gamma, s, sigma, maxIt)\n",
"    displaypoints(img, C)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Output Image"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"http://glimpglobe.com/proj/acm.gif\">"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment