
@huangzhuolin
Last active June 2, 2021 12:37

Experiments using TorchServe to run the PULSE model

Install

https://github.com/pytorch/serve#install-torchserve-and-torch-model-archiver

Serve a model as a RESTful API

Archive the model

torch-model-archiver --model-name pulse --version 1.0 --model-file PULSE.py --serialized-file model.pt --export-path ./model_store --handler pulse_serve_handler:entry_point_function_name --extra-files PULSE.py,align_face.py,bicubic.py,drive.py,gaussian_fit.pt,loss.py,mapping.pt,SphericalOptimizer.py,shape_predictor.py,stylegan.py,synthesis.pt -f

Note the --extra-files parameter: every file the model depends on must be listed there so that it is packaged into the .mar file.
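A .mar file is a zip archive, so one way to confirm that every dependency actually made it in is to list its contents. This is a minimal sketch using only Python's zipfile module; the helper name missing_from_mar and the REQUIRED list are illustrative, not part of torch-model-archiver. Base names are compared so the check does not depend on where the archiver places each file.

```python
import os
import zipfile

# Files passed to --serialized-file and --extra-files in the archiver command above
REQUIRED = [
    "PULSE.py", "align_face.py", "bicubic.py", "drive.py", "gaussian_fit.pt",
    "loss.py", "mapping.pt", "SphericalOptimizer.py", "shape_predictor.py",
    "stylegan.py", "synthesis.pt", "model.pt",
]


def missing_from_mar(mar_path, required=REQUIRED):
    """Return required file names not packaged in the .mar (path or file-like object)."""
    with zipfile.ZipFile(mar_path) as zf:
        # Compare base names so the check ignores the directory layout inside the archive
        packaged = {os.path.basename(name) for name in zf.namelist()}
    return [name for name in required if name not in packaged]
```

For the command above, missing_from_mar("model_store/pulse.mar") should return an empty list when nothing was left out.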

model.pt

To have the mapping and synthesis networks available at run time, we save the entire model first. In our case, loading those dependent parameters (e.g., mapping and synthesis) separately while serving a request is hard.

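from PULSE import PULSE
import torch

model = PULSE(...)  # constructor arguments elided
# nn.Module has no save() method; torch.save pickles the whole model,
# including the mapping and synthesis sub-networks
torch.save(model, 'model.pt')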
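Saving the entire model works because torch.save pickles the module: the sub-network weights travel inside model.pt, but the model's class is recorded only as a reference to its module and name. That is why PULSE.py and the modules it imports still have to ship in --extra-files. The sketch below illustrates the mechanism with plain pickle and a hypothetical ToyGenerator class standing in for PULSE; it is an analogy, not TorchServe code.

```python
import pickle


class ToyGenerator:
    """Hypothetical stand-in for PULSE: owns two sub-networks as attributes."""

    def __init__(self):
        self.mapping = {"weights": [0.1, 0.2]}
        self.synthesis = {"weights": [0.3, 0.4]}


def save_entire_model(model) -> bytes:
    # Like torch.save(model, ...): attribute values are serialized,
    # but the class itself is stored only by module and qualified name.
    return pickle.dumps(model)


def try_load(blob: bytes) -> str:
    """Unpickle once normally, then again with the class definition hidden."""
    first = pickle.loads(blob)
    saved_class = globals().pop("ToyGenerator")  # simulate a missing PULSE.py
    try:
        pickle.loads(blob)
        second = "loaded"
    except (AttributeError, pickle.UnpicklingError):
        second = "class not found"
    finally:
        globals()["ToyGenerator"] = saved_class  # restore the class definition
    return f"{type(first).__name__} restored; without class: {second}"


blob = save_entire_model(ToyGenerator())
```

In the served setup, the counterpart of the "class not found" case would presumably be torch.load failing inside the worker because PULSE.py or one of its imports was not packaged.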

Custom handler

TorchServe ships only four default handlers, aimed at basic tasks such as classification. Our API must return a super-resolution image to the client, so a custom handler is needed.

import os
import torch

# Module-level state, created once per worker by the initialization call
model = None
device = None


def entry_point_function_name(data, context):
    """
    Works on data and context to create the model object or process an inference request.
    Called once with no data to load the pickled PULSE model, then once per batch of requests.
    :param data: Input data for prediction
    :param context: context contains model server system properties
    :return: prediction output
    """
    global model, device

    if not data:
        manifest = context.manifest

        properties = context.system_properties
        model_dir = properties.get("model_dir")
        device = torch.device(
            "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        serialized_file = manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        if not os.path.isfile(model_pt_path):
            raise RuntimeError("Missing the model.pt file")

        # Load the entire pickled model onto this worker's device
        # (PyTorch >= 2.6 additionally needs weights_only=False for a full model)
        model = torch.load(model_pt_path, map_location=device)
    else:
        # Experiment only: a random reference image stands in for the request payload
        ref_im = torch.randn(
            (1, 3, 1024, 1024), dtype=torch.float, requires_grad=True, device=device)
        kwargs = {'input_dir': './input',
                  'output_dir': './output',
                  'cache_dir': 'cache',
                  'duplicates': 1,
                  'batch_size': 1,
                  'seed': 23,
                  'loss_str': '100*L2+0.05*GEOCROSS',
                  'eps': 0.02,
                  'noise_type': 'zero',
                  'num_trainable_noise_layers': 5,
                  'tile_latent': False,
                  'bad_noise_layers': '17',
                  'opt_name': 'adam',
                  'learning_rate': 0.4,
                  'steps': 10,
                  'lr_schedule': 'linear1cycledrop',
                  'save_intermediate': True}
        # PULSE's forward is a generator; take its first yielded (HR, LR) pair
        (HR, LR) = next(model(ref_im, **kwargs))

        return [HR]

When the server starts, each worker invokes this entry point once with no data, which loads the model instance. After that, every incoming request invokes the entry point again with the request data, and the handler calls the model, i.e. its forward method.
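The two-phase call pattern can be exercised locally without TorchServe. In this sketch, SimpleNamespace stands in for TorchServe's context object (only the manifest and system_properties attributes the handler reads), load_model is a hypothetical placeholder for torch.load that just doubles its input, and each request is a dict carrying its payload under "data" or "body", the keys TorchServe's built-in handlers read.

```python
from types import SimpleNamespace

model = None  # module-level, created once per worker


def load_model(path):
    # Hypothetical stand-in for torch.load(path, map_location=device);
    # this "model" simply doubles every value in its input.
    return lambda values: [2 * v for v in values]


def entry_point(data, context):
    global model
    if not data:
        # Initialization call: the server passes no data when the worker starts.
        props = context.system_properties
        serialized = context.manifest["model"]["serializedFile"]
        model = load_model(f"{props['model_dir']}/{serialized}")
        return None
    # Inference call: return one output per request in the batch, in order.
    outputs = []
    for request in data:
        payload = request.get("data") or request.get("body")
        outputs.append(model(payload))
    return outputs


ctx = SimpleNamespace(
    manifest={"model": {"serializedFile": "model.pt"}},
    system_properties={"model_dir": "/tmp/pulse", "gpu_id": 0},
)
entry_point(None, ctx)                            # what the server does at startup
result = entry_point([{"body": [1, 2, 3]}], ctx)  # what a single request triggers
```

Keeping the model in a module-level variable is what lets the later calls reuse the instance built by the first one.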

Server configuration

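By default, TorchServe starts as many workers for each model as there are GPUs available (or logical CPU cores when no GPU is present); in our case that meant four. Since every worker loads its own copy of the model, we reduce this to one in the config.properties file: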

default_workers_per_model=1

https://pytorch.org/serve/configuration.html

Start the server

torchserve --start --ncs --model-store model_store --models pulse
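Once the server is up, the model is reachable through TorchServe's inference API, which by default listens on port 8080 at /predictions/<model_name>. A minimal stdlib client might look like the sketch below. The payload is arbitrary because the handler above ignores the request body, and the generous timeout is an assumption, since PULSE runs an optimization loop for every request.

```python
import urllib.request

INFERENCE_URL = "http://127.0.0.1:8080/predictions/{model}"


def build_request(model_name: str, payload: bytes) -> urllib.request.Request:
    """POST request for TorchServe's inference endpoint."""
    return urllib.request.Request(
        INFERENCE_URL.format(model=model_name),
        data=payload,
        method="POST",
        headers={"Content-Type": "application/octet-stream"},
    )


def predict(model_name: str, payload: bytes, timeout: float = 600.0) -> bytes:
    # Assumed timeout: each request runs PULSE's latent optimization.
    with urllib.request.urlopen(build_request(model_name, payload), timeout=timeout) as resp:
        return resp.read()
```

For example, predict("pulse", b"placeholder") sends one request to the model registered above and returns the raw response body.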

Issue: Message size exceed limit

In our case the server needs to return the generated super-resolution (SR) image to the client. However, TorchServe limits the size of a response, and the request fails with:

io.netty.handler.codec.CorruptedFrameException: Message size exceed limit: 12583737
        at org.pytorch.serve.util.codec.CodecUtils.readLength(CodecUtils.java:24)
        at org.pytorch.serve.util.codec.ModelResponseDecoder.decode(ModelResponseDecoder.java:72)
        at io.netty.handler.codec.ByteToMessageDecoder.decodeRemovalReentryProtection(ByteToMessageDecoder.java:501)
        at io.netty.handler.codec.ByteToMessageDecoder.callDecode(ByteToMessageDecoder.java:440)
        at io.netty.handler.codec.ByteToMessageDecoder.channelRead(ByteToMessageDecoder.java:276)
        at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379)
        at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365)
        at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:357)
        at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410)
        at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:379)
        at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:365)
        at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919)
        at io.netty.channel.kqueue.AbstractKQueueStreamChannel$KQueueStreamUnsafe.readReady(AbstractKQueueStreamChannel.java:544)
        at io.netty.channel.kqueue.KQueueDomainSocketChannel$KQueueDomainUnsafe.readReady(KQueueDomainSocketChannel.java:131)
        at io.netty.channel.kqueue.AbstractKQueueChannel$AbstractKQueueUnsafe.readReady(AbstractKQueueChannel.java:382)
        at io.netty.channel.kqueue.KQueueEventLoop.processReady(KQueueEventLoop.java:211)
        at io.netty.channel.kqueue.KQueueEventLoop.run(KQueueEventLoop.java:289)
        at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:989)
        at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
        at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
        at java.base/java.lang.Thread.run(Thread.java:829)
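The length in this error is consistent with sending the raw HR tensor back as float32: a 1×3×1024×1024 tensor holds 3,145,728 values, or 12,582,912 bytes at 4 bytes each, which leaves 825 bytes of the reported 12,583,737, presumably serialization overhead. The arithmetic below also shows the size of the same image as 8-bit RGB data; the shape is taken from the handler's ref_im, on the assumption that HR has the same shape.

```python
from math import prod

REPORTED_SIZE = 12_583_737      # length from the error message above
HR_SHAPE = (1, 3, 1024, 1024)   # assumed to match ref_im in the handler


def raw_bytes(shape, bytes_per_value):
    """Size of a dense tensor's data, ignoring serialization overhead."""
    return prod(shape) * bytes_per_value


float32_size = raw_bytes(HR_SHAPE, 4)     # 12_582_912
uint8_size = raw_bytes(HR_SHAPE, 1)       # 3_145_728: the same image as 8-bit RGB
overhead = REPORTED_SIZE - float32_size   # 825
```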

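TorchServe uses Netty to serve requests, and the exception is raised while the frontend decodes the worker's response: CodecUtils.readLength rejects any message longer than a maximum size. That maximum appears to come from the max_response_size property in config.properties (documented default: 6553500 bytes), so raising it may be worth trying; otherwise, a traditional server may be more suitable in our case.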
