Skip to content

Instantly share code, notes, and snippets.

@lkevinzc
Last active October 16, 2020 15:27
Show Gist options
  • Save lkevinzc/ea5e7e1c6d5f99854ccaf6cc66f5e6df to your computer and use it in GitHub Desktop.
Save lkevinzc/ea5e7e1c6d5f99854ccaf6cc66f5e6df to your computer and use it in GitHub Desktop.
a glance at paddle framework
"""
[DS-CV Paddle Sharing]
16/09/2020
zichen.liu
"""
"""
Concepts (1): Variables
"""
from paddle import fluid
# 1) Learnable parameters: created explicitly with a name, shape and dtype.
learnable_w = fluid.layers.create_parameter(
    name="fc_w",
    shape=[100, 50],
    dtype='float32',
)
learnable_b = fluid.layers.create_parameter(
    name="fc_b",
    shape=[50],
    dtype='float32',
)

# 2) Placeholder (usually for feeding data); `None` leaves the first
#    dimension unspecified.
x = fluid.data(name="input", shape=[None, 100], dtype="float32")

# 3) Constant: a one-element int64 tensor filled with 7.
const_data = fluid.layers.fill_constant(shape=[1], value=7, dtype='int64')

# *1) Encapsulation: a fully connected layer applied to the placeholder `x`
#     with 50 output units and the bias enabled.
y = fluid.layers.fc(input=x, size=50, bias_attr=True)
# Notes:
# Other encapsulations together with their associated operator (like mult, conv)
# are in paddle.fluid.layers, paddle.fluid.nets
"""
Concepts (2): Operators
"""
# most are encapsulated inside paddle.fluid.layers, paddle.fluid.nets
# can either contain trainable params or not contain
import paddle.fluid as fluid
import numpy
# Two one-element float32 placeholders and the operator that adds them.
a = fluid.data(name="a", shape=[1], dtype='float32')
b = fluid.data(name="b", shape=[1], dtype='float32')
result = fluid.layers.elementwise_add(a, b)

# Create an executor whose execution place is the CPU, then run the
# startup program before running anything else.
cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_startup_program())

# Concrete values for the placeholders, keyed by placeholder name.
x = numpy.array([5], dtype="float32")
y = numpy.array([7], dtype="float32")
feed_values = {'a': x, 'b': y}

# Fetch the sum together with the two inputs.
outs = exe.run(feed=feed_values, fetch_list=[result, a, b])
"""
Concepts (3): Program & Executor [PCA as an example]
"""
# Paddle uses "program" to describe model construction & running.
# Operators will be put into program sequentially.
# The Executor accepts the defined program and converts it to the C++ backend,
# after which it executes the compiled "FluidProgram".
import paddle.fluid as fluid
import numpy
# 1) define a program
# Placeholders for the raw data and for the projection weights.
raw_data = fluid.data(
    name='raw_data_name',
    shape=[None, 2000, None],
    dtype='float32',
)
pca_weights = fluid.data(
    name='pca_weights_name',
    shape=[None, 100, 2000],
    dtype='float32',
)

# bmm = batched matrix multiplication: weights (x) times raw data (y).
embedding = fluid.layers.bmm(x=pca_weights, y=raw_data)

# 2) Define the executor with the CPU as its execution place.
cpu = fluid.core.CPUPlace()
exe = fluid.Executor(cpu)

# Initialize by running the startup program.
exe.run(fluid.default_startup_program())

# Random stand-ins for one batch of input and its projection matrix.
input_vector = numpy.random.rand(1, 2000, 5).astype("float32")
pre_computed_pca = numpy.random.rand(1, 100, 2000).astype("float32")

feed_values = {
    'raw_data_name': input_vector,
    'pca_weights_name': pre_computed_pca,
}
outs = exe.run(feed=feed_values, fetch_list=[embedding])
"""
1. Paddle supports both imperative & declarative paradigms.
2. Paddle defines a "program", which consists of user-defined operations on variables.
3. The Executor runs the program after converting it into the C++ backend FluidProgram.
4. One issue with static graphs is control flow (if/else, for loops). Two ways to handle it:
- Paddle offers control-flow OPs, e.g. fluid.layers.Switch, fluid.layers.While
- Use imperative programming (dynamic graph, like PyTorch)
"""
"""
Concepts (4): Reader
"""
# Reader is about getting data (from file, network, etc.)
# and generating the correct format of data for our model.
import paddle.fluid as fluid
import numpy as np
# 1) shuffled batch reader (mini-batch sgd)
def reader():
    """Yield the integers 0 through 9, one sample at a time."""
    yield from range(10)
# Wrap the sample reader: shuffle with a buffer of 4 samples, then group the
# shuffled samples into batches of 5.
batch_reader = fluid.io.batch(fluid.io.shuffle(reader, buf_size=4), batch_size=5)
for data in batch_reader():
    print(data)
# 2) multiprocess reader
num_workers = 2


class MyDataSet(object):
    """Toy dataset of the integers 0..9 that can be split across workers."""

    def __init__(self):
        self.DATASET = np.arange(10)  # [0, ..., 9]

    def __call__(self, process_id):
        """Build the reader for one worker.

        Args:
            process_id: index of the worker; it receives the samples whose
                position `i` satisfies `i % num_workers == process_id`.

        Returns:
            A zero-argument generator function yielding
            `(sample, process_id)` tuples.
        """

        def gen_func():
            # `num_workers` is the module-level setting, looked up when the
            # generator actually runs.
            for i, x in enumerate(self.DATASET):
                if i % num_workers == process_id:
                    yield x, process_id

        return gen_func
# Build one reader per worker, each covering its own shard of the dataset.
reader_function = MyDataSet()
readers = []
for process_id in range(num_workers):
    readers.append(reader_function(process_id))

# Combine the per-worker readers into a single reader and print what it yields.
decorated_reader = fluid.io.multiprocess_reader(readers, use_pipe=False)
for x in decorated_reader():
    print(x)
"""
Concepts (5): DataLoader
"""
# In imperative programming, DataLoader is asynchronous
# and accelerated by threading (high efficiency).
import paddle.fluid as fluid
import numpy as np
# Use the first GPU when this Paddle build has CUDA support; otherwise fall
# back to the CPU so the example still runs on CPU-only installations.
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
# Switch to imperative (dynamic graph) mode.
fluid.enable_imperative()
class MyDataSet(object):
    """Toy dataset of 10 random 5 x 5 images with random integer labels."""

    def __init__(self):
        self.IMG = np.random.rand(10, 5, 5)  # 10 images with size 5 x 5
        self.LABEL = np.random.randint(0, 10, (10,))  # their labels ranging from 0~9

    def __len__(self):
        """Return the number of images."""
        return len(self.IMG)

    def __call__(self):
        """Return a zero-argument generator function over the samples.

        The generator yields one `(image, label)` pair per sample, in
        storage order.
        """

        def gen_func():
            for img, label in zip(self.IMG, self.LABEL):
                yield img, label

        return gen_func
# reader
reader = MyDataSet()
# DataLoader; `capacity` acts like a buffer size
data_loader = fluid.io.DataLoader.from_generator(capacity=10)
# Alternative that uses multiple processes to get rid of the GIL:
# data_loader = fluid.io.DataLoader.from_generator(capacity=10, use_multiprocess=True)

# `reader()` returns the generator function over single samples; the loader
# groups them into batches of 2 on `place`.
data_loader.set_sample_generator(reader(), batch_size=2, places=place)

# Take just the first batch.
for x in data_loader():
    img, label = x
    break

# Convert the batch back to numpy (the value is only displayed in an
# interactive session; in a script this expression has no visible effect).
img.numpy(), label.numpy()
"""
Concepts (6): Model I/O
"""
# All input & output of an operator are [Variable]
# [Parameter] is sub-class of [Variable], learnable
# [Persistables] further includes lr, global step, etc.
# 8 APIs:
# fluid.io.save_vars <-> fluid.io.load_vars
# fluid.io.save_params <-> fluid.io.load_params
# fluid.io.save_persistables <-> fluid.io.load_persistables
# fluid.io.save_inference_model <-> fluid.io.load_inference_model
import paddle.fluid as fluid
import numpy as np
main_prog = fluid.Program()
# Use a dedicated startup program rather than the process-wide default one:
# parameters named "fc_w"/"fc_b" with different shapes are already created
# earlier in this file, and initializing both through one shared startup
# program would clash on those names.
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    # Inputs: a single 128-dim feature row and its integer class label.
    data = fluid.layers.data(
        name="img", shape=[1, 128], append_batch_size=False, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    # A hand-written fully connected layer: data @ w + b.
    w = fluid.layers.create_parameter(shape=[128, 10], dtype='float32', name='fc_w')
    b = fluid.layers.create_parameter(shape=[10], dtype='float32', name='fc_b')
    hidden_w = fluid.layers.matmul(x=data, y=w)
    logits = fluid.layers.elementwise_add(hidden_w, b)

    # cross_entropy expects class probabilities, so normalize the raw
    # scores with softmax before computing the loss.
    predict = fluid.layers.softmax(logits)
    loss = fluid.layers.cross_entropy(input=predict, label=label)
    avg_loss = fluid.layers.mean(loss)

    optimizer = fluid.optimizer.Adam(learning_rate=0.001)
    optimizer.minimize(avg_loss)

# Run the startup program on the CPU to initialize the parameters so that
# there is something to save.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup_prog)
# 1) save_vars: save an explicit list of variables.
var_list = [w, b]
path = "./01_paddle_vars"
fluid.io.save_vars(
    executor=exe,
    dirname=path,
    vars=var_list,
)

# 2) save_params: save the parameters of the main program.
params_path = "./02_paddle_model_params"
fluid.io.save_params(
    executor=exe,
    dirname=params_path,
    main_program=main_prog,
)

# 3) ** save_persistables ** -> allows finetuning later.
persis_path = "./03_paddle_checkpoint"
fluid.io.save_persistables(
    executor=exe,
    dirname=persis_path,
    main_program=main_prog,
)

# 4) ** save_inference_model ** -> allows deployment;
#    the compute graph is pruned automatically.
infer_path = "./04_paddle_inference"
fluid.io.save_inference_model(
    dirname=infer_path,
    feeded_var_names=['img'],
    target_vars=[predict],
    executor=exe,
    main_program=main_prog,
    model_filename='model',
    params_filename='params',
)
# 1) Prepare data
import paddle
import matplotlib.pylab as plt
# Define batch size
BATCH_SIZE = 64

# Create batched readers over the MNIST train/test splits; incomplete final
# batches are dropped.
train_reader = paddle.batch(
    paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
test_reader = paddle.batch(
    paddle.dataset.mnist.test(), batch_size=BATCH_SIZE, drop_last=True)

# Grab only the first training batch for a quick look at the data.
for tmp_data in train_reader():
    break

# Each sample is an (image, label) pair; show the first one of the batch.
print('Label:', tmp_data[0][1])
print('Image')
plt.imshow(tmp_data[0][0].reshape(28, 28), cmap='gray')
# 2) Build model
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D, Pool2D
# 2.1) Construct a conv-(relu)-pool block,
# must inherit [fluid.dygraph.Layer] class
class SimpleImgConvPool(fluid.dygraph.Layer):
    """A convolution followed by a pooling layer."""

    # define & initialize the network
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 pool_size,
                 pool_stride,
                 pool_padding=0,
                 pool_type='max',
                 global_pooling=False,
                 conv_stride=1,
                 conv_padding=0,
                 conv_dilation=1,
                 conv_groups=1,
                 act=None,
                 use_cudnn=False,
                 param_attr=None,
                 bias_attr=None):
        """Create the Conv2D and Pool2D sub-layers.

        Args:
            num_channels, num_filters, filter_size: forwarded to Conv2D.
            pool_size, pool_stride, pool_padding, pool_type, global_pooling:
                forwarded to Pool2D.
            conv_stride, conv_padding, conv_dilation, conv_groups: forwarded
                to Conv2D as stride, padding, dilation and groups.
            act: activation name forwarded to Conv2D (e.g. "relu").
            use_cudnn: forwarded to both Conv2D and Pool2D.
            param_attr, bias_attr: forwarded to Conv2D.
        """
        super(SimpleImgConvPool, self).__init__()
        # Conv2D initialization
        self._conv2d = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            # Pass the caller's attributes through; they used to be
            # silently replaced by None.
            param_attr=param_attr,
            bias_attr=bias_attr,
            act=act,
            use_cudnn=use_cudnn)
        # Pool2D initialization
        self._pool2d = Pool2D(
            pool_size=pool_size,
            pool_type=pool_type,
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            global_pooling=global_pooling,
            use_cudnn=use_cudnn)

    # define forward pass logic
    def forward(self, inputs):
        """Apply the convolution, then the pooling, and return the result."""
        x = self._conv2d(inputs)
        x = self._pool2d(x)
        return x
# 2.2) Construct a digit recognition network,
# must inherit [fluid.dygraph.Layer] class
class MnistNet(fluid.dygraph.Layer):
    """Digit classifier: two conv-pool blocks plus a linear softmax output."""

    def __init__(self, num_classes=10):
        """Create the two conv-pool blocks and the output weight matrix.

        Args:
            num_classes: number of output classes (10 digits by default).
        """
        super(MnistNet, self).__init__()
        self._simple_img_conv_pool_1 = SimpleImgConvPool(
            1, 20, 5, 2, 2, act="relu")
        self._simple_img_conv_pool_2 = SimpleImgConvPool(
            20, 50, 5, 2, 2, act="relu")
        # after passing through above two layers,
        # the feature dimension is {c=50, h=w=4} (for 28 x 28 inputs)
        self.hidden_size = 50 * 4 * 4
        self.output_weight = self.create_parameter(
            [self.hidden_size, num_classes])

    def forward(self, inputs, label=None):
        """Run the network.

        Args:
            inputs: batch of images fed to the first conv-pool block.
            label: optional labels; when given, accuracy is computed too.

        Returns:
            The softmax output alone when `label` is None, otherwise the
            tuple `(softmax output, accuracy)`.
        """
        x = self._simple_img_conv_pool_1(inputs)
        x = self._simple_img_conv_pool_2(x)
        # Flatten the feature maps before the final matrix multiplication.
        x = fluid.layers.reshape(x, shape=[-1, self.hidden_size])
        x = fluid.layers.matmul(x, self.output_weight)
        x = fluid.layers.softmax(x)
        if label is None:
            return x
        acc = fluid.layers.accuracy(input=x, label=label)
        return x, acc
# 3) Start training!
'''
while not converged:
1. get batched data, convert from `numpy.ndarray` to `Variable`
2. forward pass the data into the model, obtain its prediction
3. calculate the loss value according to the prediction and label
4. do error back propagation to get gradients
5. update model parameters following the optimizer's rules
6. clear gradients so that they do not accumulate
7. save the trained model
'''
import numpy as np
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.base import to_variable
# Train on the first GPU when this Paddle build has CUDA support, otherwise
# fall back to the CPU.
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
with fluid.dygraph.guard(place):
    # instantiate a network from previous definition
    my_hello_world_net = MnistNet()
    # define the optimizer, set lr = 0.001, give parameters over which
    # the optimizer's update rules will apply
    optimizer = AdamOptimizer(
        learning_rate=0.001,
        parameter_list=my_hello_world_net.parameters())
    # number of times we want to loop over the whole dataset
    epoch_num = 5
    for epoch in range(epoch_num):
        for batch_id, data in enumerate(train_reader()):
            # Each sample is an (image, label) pair: reshape every image to
            # (1, 28, 28) and collect the labels as an int64 column.
            dy_x_data = np.array(
                [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
            y_data = np.array(
                [x[1] for x in data]).astype('int64').reshape(-1, 1)

            # convert from `numpy.ndarray` to `Variable`
            img = to_variable(dy_x_data)
            label = to_variable(y_data)

            # forward pass; despite its name, `cost` is the network's
            # softmax output, not a loss value
            cost, acc = my_hello_world_net(img, label)

            # calculate loss
            loss = fluid.layers.cross_entropy(cost, label)
            avg_loss = fluid.layers.mean(loss)

            # error bp
            avg_loss.backward()
            # update parameters
            optimizer.minimize(avg_loss)
            # clear gradients manually
            my_hello_world_net.clear_gradients()

            # simple logging
            if batch_id % 300 == 0:
                print("Loss at epoch {} step {}: {:}".format(
                    epoch, batch_id, avg_loss.numpy()))

    # model i/o - saving (once, after all epochs have finished)
    model_dict = my_hello_world_net.state_dict()
    fluid.save_dygraph(model_dict, "save_temp")
    # side note: save optimizer info for future finetuning
    optim_dict = optimizer.state_dict()
    fluid.save_dygraph(optim_dict, "save_temp")
    # to resume
    # _, opt_state= fluid.load_dygraph("save_temp")
    # optimizer.set_dict(opt_state)
# 4) Evaluation
with fluid.dygraph.guard():
    # instantiate a network from previous definition
    my_hello_world_net_eval = MnistNet()
    # model i/o - loading, omitting opt_state here
    model_dict, _ = fluid.load_dygraph("save_temp")
    my_hello_world_net_eval.set_dict(model_dict)
    print("checkpoint loaded")

    # eval mode - will change the behaviour of some operators (e.g. BatchNorm)
    my_hello_world_net_eval.eval()

    # Per-batch metrics, averaged once the whole test set has been seen.
    acc_set = []
    avg_loss_set = []
    # use test set
    for batch_id, data in enumerate(test_reader()):
        dy_x_data = np.array(
            [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
        y_data = np.array(
            [x[1] for x in data]).astype('int64').reshape(-1, 1)

        # convert from `numpy.ndarray` to `Variable`
        img = to_variable(dy_x_data)
        label = to_variable(y_data)

        # forward pass
        prediction, acc = my_hello_world_net_eval(img, label)

        # calculate loss, accuracy
        loss = fluid.layers.cross_entropy(input=prediction, label=label)
        avg_loss = fluid.layers.mean(loss)
        acc_set.append(float(acc.numpy()))
        avg_loss_set.append(float(avg_loss.numpy()))

    acc_val_mean = np.array(acc_set).mean()
    avg_loss_val_mean = np.array(avg_loss_set).mean()
    print("Eval avg_loss is: {}, acc is: {}".format(avg_loss_val_mean, acc_val_mean))
# 5) Deployment
'''
- Dynamic graph is easy to program, but less efficient for deployment.
- Thus we need to first convert it into static graph via tracing the network,
then save it to inference model.
- The saved model can be deployed using C++ or Python to do online inference.
'''
from paddle.fluid.dygraph import TracedLayer
deploy_save_dirname = 'saved_infer_model'
with fluid.dygraph.guard():
    # instantiate a network from previous definition
    my_net_deploy = MnistNet()
    # model i/o - loading, omitting opt_state here
    model_dict, _ = fluid.load_dygraph("save_temp")
    my_net_deploy.set_dict(model_dict)
    print("checkpoint loaded")

    # the example input does not need to be a meaningful image; random
    # values with the right shape are enough for tracing
    in_np = np.random.random([10, 1, 28, 28]).astype('float32')
    # convert from `numpy.ndarray` to `Variable`
    input_var = fluid.dygraph.to_variable(in_np)

    # [conversion] dynamic -> static
    out_dygraph, static_layer = TracedLayer.trace(my_net_deploy, inputs=[input_var])

    # save the traced layer, can be used as model interface for deployment
    # `feed` and `fetch` are the index of model input/output
    static_layer.save_inference_model(deploy_save_dirname, feed=[0], fetch=[0])
    print("inference model exported")
# 6) Inference
import cv2
from glob import glob
import matplotlib.pylab as plt
import numpy as np
# Test images in the current directory, in sorted filename order.
img_fns = sorted(glob('mytest_*.jpg'))
# NOTE(review): the export step in this file writes to 'saved_infer_model'
# (no './work/' prefix) — confirm this directory is where the exported
# model actually lives before loading from it.
deploy_save_dirname = './work/saved_infer_model'
def read_image(im_fn):
    """Load an image file as an inverted, dilated 28 x 28 grayscale array.

    Args:
        im_fn: path of the image file to read.

    Returns:
        The processed image as a 28 x 28 array.

    Raises:
        IOError: if the file cannot be read as an image.
    """
    im = cv2.imread(im_fn, cv2.IMREAD_GRAYSCALE)
    if im is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError("cannot read image: {}".format(im_fn))
    # Invert the intensities (presumably so dark-on-light handwriting
    # becomes light-on-dark like the training digits — TODO confirm).
    im = 255 - im
    # Dilate with a 3 x 3 kernel before shrinking the image.
    kernel = np.ones((3, 3), np.uint8)
    im = cv2.dilate(im, kernel, iterations=1)
    # The interpolation flag must be passed by keyword: the third positional
    # argument of cv2.resize is `dst`, not `interpolation`.
    return cv2.resize(im, (28, 28), interpolation=cv2.INTER_CUBIC)
fig, axs = plt.subplots(1, len(img_fns), figsize=(7, 7))
# With a single image plt.subplots returns a bare Axes object instead of an
# array; normalize so that indexing works for any number of images.
axs = np.atleast_1d(axs)

np_imgs = []
for i, im_fn in enumerate(img_fns):
    this_im = read_image(im_fn)
    # normalization: map pixel values from [0, 255] to [-1, 1]
    this_im = ((this_im / 255. - 0.5) / 0.5).astype(np.float32)
    axs[i].imshow(this_im, cmap='gray')
    axs[i].axis('off')
    # add a leading channel axis: (28, 28) -> (1, 28, 28)
    this_im = this_im[np.newaxis, :]
    np_imgs.append(this_im)

# Stack the per-image arrays into one batch.
np_imgs = np.stack(np_imgs)
print(np_imgs.shape)
'''
For this simple task, write only 6 lines for model prediction in deployment code!!
- no network definition
- load a clean compute graph with weights for inference
'''
import paddle.fluid as fluid
# Run on the CPU.
place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load the saved inference program together with its feed/fetch targets.
program, feed_vars, fetch_vars = fluid.io.load_inference_model(
    deploy_save_dirname, exe)

# Feed the image batch to the first feed target and unpack the single
# fetched output.
[fetch] = exe.run(
    program,
    feed={feed_vars[0]: np_imgs},
    fetch_list=fetch_vars,
)

# Index of the highest-scoring class for each image (only displayed in an
# interactive session).
fetch.argmax(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment