Skip to content

Instantly share code, notes, and snippets.

@vstrimaitis
Last active February 14, 2018 11:49
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vstrimaitis/bfc6d1e6ec019b09732f3061e0798c2d to your computer and use it in GitHub Desktop.
Save vstrimaitis/bfc6d1e6ec019b09732f3061e0798c2d to your computer and use it in GitHub Desktop.
Trumpas tutorial, parodantis, kaip susisetup'int CapsNet su Tensorflow (GPU version) minimal working example.

Prerequisites:

  • Python 3.6 64bit.
  • Tensorflow (pip3 install --upgrade tensorflow-gpu)
  • NVIDIA: CUDA® Toolkit 9.0 ir cuDNN v6.0 (daugiau paaiškinta čia).
  1. Nusiklonuot repo:

    cd /path/to/capsnet
    git clone https://github.com/naturomics/CapsNet-Tensorflow.git
    cd CapsNet-Tensorflow
    
  2. Atsisiųst MNIST dataset:

    python download_data.py
    
  3. Paleist apmokymą:

    python main.py
    

    Šitam step'ui prireikė kelių bandymų, nes vis sulūždavo, nes pritrūkdavo atminties. Sprendimas: config.py faile mažinti batch_size parametrą kol pradės veikt (man pradėjo veikt prie, berods, 32).

    Man asmeniškai užtruko parą kol pabaigė. Mano GPU: NVIDIA GeForce GT 750M, su normalia vaizdo korta turėtų būt gerokai greičiau :)

    Dar galima pažiūrėt visokius grafikus. Tam reikia atsidaryt dar vieną cmd langą ir rašyt:

    cd /path/to/capsnet/CapsNet-Tensorflow
    tensorboard --logdir=logdir
    

    Konsolėj turėtų parašyt, kokiam endpoint'e galima prieit duomenis (man buvo PC_VARDAS:6006). Įvedus to endpoint adresą browser'yje pasirodys ir periodiškai atsinaujins mokymosi duomenys.


Pabaigus mokymą galima išsibandyt ir ant savų paveiksliukų:

  1. config.py faile nustatyt batch_size į 1.
  2. Papildyti capsNet.py failą:
    # b). pick out the index of max softmax val of the 10 caps
    # [batch_size, 10, 1, 1] => [batch_size] (index)
    self.argmax_idx = tf.to_int32(tf.argmax(self.softmax_v, axis=1))
    self.prediction = tf.argmax(self.softmax_v, axis=1) # <<<<<<<<<<<< Šitą pridėti
    assert self.argmax_idx.get_shape() == [cfg.batch_size, 1, 1]
    self.argmax_idx = tf.reshape(self.argmax_idx, shape=(cfg.batch_size, ))
    
  3. Pakeist main.py failą:
  • Pridėt import cv2
  • Pridėt funkciją paveiksliuko nuskaitymui iš failo:
    def read_img(path):
      img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
      img = cv2.resize(img, (28, 28)) # Pakeičia paveiksliuko dydį į 28x28 (nes tokį naudoja CapsNet'as)
      # cv2.imwrite(path+".resized.png", img) # Dėl visa ko galima išsaugot pakeisto dydžio paveiksliuką
      img = img / 255.0 # Sunormalizuoja pikselius (padaro, kad reikšmės būtų intervale [0, 1])
      img = np.expand_dims(img, axis=0) # Formą pakeičia iš (28, 28) į (1, 28, 28)
      img = np.expand_dims(img, axis=3) # Formą pakeičia iš (1, 28, 28) į (1, 28, 28, 1)
      # Formą (1, 28, 28, 1) naudoja pats tinklas, kitokio formato nepriima (bent kiek kol kas bandžiau neradau kito workaround)
      return img
    
  • Pasirašyt funkciją paveiksliukų testavimui:
    def test_images(model, supervisor, num_label, img_path, img_nums):
      with supervisor.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
          supervisor.saver.restore(sess, tf.train.latest_checkpoint(cfg.logdir))
          print('Model restored!')
          for img_num in img_nums:
              full_img_path = img_path + str(img_num) + ".png"
              img = read_img(full_img_path)
              predictions = sess.run(model.prediction, {model.X: img})
              #acc = sess.run(model.accuracy, {model.X: img}) # this gives only the accuracy, not the actual result
              print(predictions)
    
  • Papildyt main() funkciją. else kodą pakeisti iš evaluation(model, sv, num_label) į test_images(model, sv, num_label, "/path/to/images/", [2, 3, 5, 8]). Šitas kodas paima paveiksliukus 2.png, 3.png, 5.png ir 8.png iš folder'io /path/to/images/
@vstrimaitis
Copy link
Author

Aš pasitestavimui naudojau šituos paveiksliukus (per paint galima nusipiešt kažką kito):

2
3
5
8

PS: 5.png paimta iš paties MNIST dataseto. Užsiliko nuo bandymų visokių.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment