
Torchserve docker fails to run on existing mar file #1010

Open
ayushch3 opened this issue Mar 12, 2021 · 16 comments
Labels: bug, triaged_wait

Context

I ran torch-model-archiver on a different machine to create a .mar file with a custom handler for a transformer model, using this command:

torch-model-archiver --model-name TranslationClassifier --version 1.0 --serialized-file /home/ayush/transformer_model/pytorch_model.bin --handler ./translation_model/text_handler.py --extra-files "./transformer_model/config.json,./transformer_model/special_tokens_map.json,./transformer_model/tokenizer_config.json,./transformer_model/sentencepiece.bpe.model"

It took about 20 minutes and the .mar file was created correctly. I was able to verify locally that TorchServe indeed works on that system, using the following command:

torchserve --start --model-store model_store --models my_tc=TranslationClassifier.mar

Expected Behavior

To run this on Kubernetes, I took the pre-existing pytorch/torchserve:latest-gpu image from Docker Hub, so that I can run in a different environment by using the .mar file directly with this command:

sudo docker run -p 8080:8080 -p 8081:8081 -p 8082:8082 -p 7070:7070 -p 7071:7071 --mount type=bind,source=/home/ayush,target=/home/ayush/model_store pytorch/torchserve:latest-gpu torchserve --model-store /home/ayush/model_store --models my_tc=TranslationClassifier.mar
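
(A quick, hedged sanity check of the setup above: with this bind mount, --model-store points at /home/ayush/model_store inside the container, so TranslationClassifier.mar must sit directly under /home/ayush on the host. Listing the store from inside the image confirms that; the --entrypoint override simply bypasses the image's default entrypoint.)

sudo docker run --rm --mount type=bind,source=/home/ayush,target=/home/ayush/model_store \
  --entrypoint ls pytorch/torchserve:latest-gpu -l /home/ayush/model_store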

Current Behavior

The execution fails when running that docker container with the following error logs:

2021-03-12 21:13:43,128 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 182, in <module>
2021-03-12 21:13:43,128 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     worker.run_server()
2021-03-12 21:13:43,128 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 154, in run_server
2021-03-12 21:13:43,128 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     self.handle_connection(cl_socket)
2021-03-12 21:13:43,128 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 116, in handle_connection
2021-03-12 21:13:43,129 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service, result, code = self.load_model(msg)
2021-03-12 21:13:43,129 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 89, in load_model
2021-03-12 21:13:43,129 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service = model_loader.load(model_name, model_dir, handler, gpu, batch_size, envelope)
2021-03-12 21:13:43,129 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_loader.py", line 83, in load
2021-03-12 21:13:43,130 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     module = self._load_default_handler(handler)
2021-03-12 21:13:43,130 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_loader.py", line 120, in _load_default_handler
2021-03-12 21:13:43,130 [INFO ] W-9002-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     module = importlib.import_module(module_name, 'ts.torch_handler')
2021-03-12 21:13:43,131 [INFO ] epollEventLoopGroup-5-7 org.pytorch.serve.wlm.WorkerThread - 9002 Worker disconnected. WORKER_STARTED

Steps to Reproduce

  1. Run torch-model-archiver on a model and copy the resulting .mar file to a different machine.
  2. Run the TorchServe Docker image with that existing .mar file. It looks like it is unable to find the custom handler that was used when running torch-model-archiver. My understanding is that the .mar file should have captured this information, so running torchserve --start --model-store model_store --models my_tc=TranslationClassifier.mar in a different environment should work out of the box instead of failing to recognize the custom handler (a quick way to verify what the archive actually captured is sketched right after this list).
    ...
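
A hedged way to verify what step 2 assumes: a .mar produced by torch-model-archiver is a zip archive, and its manifest records the handler entry (the manifest path and field names below may vary by archiver version).

# list the archive contents -- text_handler.py and the extra files should appear here
unzip -l TranslationClassifier.mar

# inspect the recorded handler entry (path/field names may differ across versions)
unzip -p TranslationClassifier.mar MAR-INF/MANIFEST.json
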
ayushch3 changed the title from "Torch server docker fails to run on existing mar file" to "Torchserve docker fails to run on existing mar file" on Mar 12, 2021
dhanainme (Collaborator) commented Mar 15, 2021

It's not clear from the logs what your custom handler is trying to do here. Adding the handler code would help us understand in more detail what's going on.

 module = importlib.import_module(module_name, 'ts.torch_handler')

Also, were you able to run your model in a standalone TorchServe?

dhanainme added the triaged_wait label on Mar 15, 2021
ayushch3 (Author) commented Mar 15, 2021

Handler file (text_handler.py):

from abc import ABC
import json
import logging
import os
import ast
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class TransformersClassifierHandler(BaseHandler, ABC):
    """
    Transformers text classifier handler class. This handler takes a text (string)
    as input and returns the output text based on the serialized transformers checkpoint.
    """
    def __init__(self):
        super(TransformersClassifierHandler, self).__init__()
        self.initialized = False
        self.mapping = None  # populated in initialize() if index_to_name.json is present

    def initialize(self, ctx):
        self.manifest = ctx.manifest

        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        # Read model serialize/pt file
        self.model = MBartForConditionalGeneration.from_pretrained(model_dir)
        self.tokenizer = MBart50TokenizerFast.from_pretrained(model_dir)

        self.model.to(self.device)
        self.model.eval()

        logger.debug('Transformer model from path {0} loaded successfully'.format(model_dir))

        # Read the mapping file, index to object name
        mapping_file_path = os.path.join(model_dir, "index_to_name.json")

        if os.path.isfile(mapping_file_path):
            with open(mapping_file_path) as f:
                self.mapping = json.load(f)
        else:
            logger.warning('Missing the index_to_name.json file. Inference output will not include class name.')

        self.initialized = True

    def preprocess(self, data):
        """ Very basic preprocessing code - only tokenizes. 
            Extend with your own preprocessing steps as needed.
        """
        text = data[0].get("data")
        if text is None:
            text = data[0].get("body")
        text = text.decode('utf-8')
        input_text = ast.literal_eval(text)
        contents = input_text['contents']
        source_language_code = input_text['source_language_code']
        target_language_code = input_text['target_language_code']
        #logger.info("Received text: '%s'", sentences)

        self.tokenizer.src_lang = source_language_code
        inputs = self.tokenizer(contents, return_tensors="pt")
        generated_tokens = self.model.generate(
            **inputs,
            forced_bos_token_id=self.tokenizer.lang_code_to_id[target_language_code]
        )
        return generated_tokens

    def inference(self, generated_tokens):
        """
        Predict the class of a text using a trained transformer model.
        """
        # NOTE: This makes the assumption that your model expects text to be tokenized  
        # with "input_ids" and "token_type_ids" - which is true for some popular transformer models, e.g. bert.
        # If your transformer model expects different tokenization, adapt this code to suit 
        # its expected input format.
        prediction = self.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

        if self.mapping:
            prediction = self.mapping[str(prediction)]

        return prediction

    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        return inference_output


_service = TransformersClassifierHandler()


def handle(data, context):
    try:
        if not _service.initialized:
            _service.initialize(context)

        if data is None:
            return None

        data = _service.preprocess(data)
        data = _service.inference(data)
        data = _service.postprocess(data)

        return data
    except Exception as e:
        raise e

Dockerfile:

FROM pytorch/torchserve:latest-gpu
COPY transformer_model/pytorch_model.bin transformer_model/config.json \
  transformer_model/special_tokens_map.json transformer_model/tokenizer_config.json \ 
    transformer_model/sentencepiece.bpe.model text_handler.py /home/model-server/

ENV DEBIAN_FRONTEND=noninteractive
RUN export USE_CUDA=1
USER model-server

RUN torch-model-archiver \
  --model-name=TranslationClassifier \
  --version=1.0 \
  --serialized-file=/home/model-server/pytorch_model.bin \
  --handler=/home/model-server/text_handler.py \
  --export-path=/home/model-server/model-store \
  --extra-files=/home/model-server/config.json,/home/model-server/special_tokens_map.json,/home/model-server/tokenizer_config.json,/home/model-server/sentencepiece.bpe.model

CMD ["torchserve", "--start", "--model-store", "/home/model-server/model-store", "--models",  "my_tc=TranslationClassifier.mar"]

ayushch3 (Author) commented Mar 15, 2021

2021-03-16 20:18:28,917 [INFO ] pool-2-thread-1 TS_METRICS - CPUUtilization.Percent:100.0|#Level:Host|#hostname:563d8708dab6,timestamp:1615925908
2021-03-16 20:18:28,918 [INFO ] pool-2-thread-1 TS_METRICS - DiskAvailable.Gigabytes:33.27151870727539|#Level:Host|#hostname:563d8708dab6,timestamp:1615925908
2021-03-16 20:18:28,919 [INFO ] pool-2-thread-1 TS_METRICS - DiskUsage.Gigabytes:18.78311538696289|#Level:Host|#hostname:563d8708dab6,timestamp:1615925908
2021-03-16 20:18:28,919 [INFO ] pool-2-thread-1 TS_METRICS - DiskUtilization.Percent:36.1|#Level:Host|#hostname:563d8708dab6,timestamp:1615925908
2021-03-16 20:18:28,919 [INFO ] pool-2-thread-1 TS_METRICS - MemoryAvailable.Megabytes:5738.9609375|#Level:Host|#hostname:563d8708dab6,timestamp:1615925908
2021-03-16 20:18:28,919 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUsed.Megabytes:457.17578125|#Level:Host|#hostname:563d8708dab6,timestamp:1615925908
2021-03-16 20:18:28,920 [INFO ] pool-2-thread-1 TS_METRICS - MemoryUtilization.Percent:11.1|#Level:Host|#hostname:563d8708dab6,timestamp:1615925908
2021-03-16 20:19:00,384 [INFO ] W-9001-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9001-my_tc_1.0-stdout
2021-03-16 20:19:00,384 [INFO ] W-9001-my_tc_1.0-stderr org.pytorch.serve.wlm.WorkerLifeCycle - Stopped Scanner - W-9001-my_tc_1.0-stderr
2021-03-16 20:19:02,008 [INFO ] epollEventLoopGroup-5-2 org.pytorch.serve.wlm.WorkerThread - 9001 Worker disconnected. WORKER_STARTED
2021-03-16 20:19:02,072 [DEBUG] W-9001-my_tc_1.0 org.pytorch.serve.wlm.WorkerThread - System state is : WORKER_STARTED
2021-03-16 20:19:02,105 [DEBUG] W-9001-my_tc_1.0 org.pytorch.serve.wlm.WorkerThread - Backend worker monitoring thread interrupted or backend worker process died.
java.lang.InterruptedException
	at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.reportInterruptAfterWait(AbstractQueuedSynchronizer.java:2056)
	at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2133)
	at java.base/java.util.concurrent.ArrayBlockingQueue.poll(ArrayBlockingQueue.java:432)
	at org.pytorch.serve.wlm.WorkerThread.run(WorkerThread.java:188)
	at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
	at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:834)
2021-03-16 20:19:02,523 [WARN ] W-9001-my_tc_1.0 org.pytorch.serve.wlm.BatchAggregator - Load model failed: my_tc, error: Worker died.

@ayushch3 (Author)

@dhanainme Just to clarify, all these issues occur only when trying to run TorchServe inside a Docker image; there are no issues when running on a standalone Ubuntu system. The Docker image is necessary to deploy this as a microservice, but TorchServe just fails without emitting any failure logs, so it's hard to debug what's going wrong.

@ayushch3 (Author)

@dhanainme I have tried everything possible to run this .mar file inside a Docker container using TorchServe; every single time it fails with absolutely no indication in the logs of what the underlying issue is.

This is the mar file: https://drive.google.com/file/d/18tiD5gLvbRvq6P9kHjqlSNBGO55TbXVs/view?usp=sharing

Can you try downloading it and running it inside a Docker container, and provide me with a Dockerfile that I can use to deploy it to a Kubernetes cluster?

maaquib (Collaborator) commented Mar 17, 2021

@ayushch3 It seems like an issue with loading the model file. Will try and reproduce this, but in the meantime do you mind getting a shell into the container and loading the model manually?
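
For reference, a minimal sketch of that manual check (container ID, paths, and package names are placeholders/assumptions):

docker exec -it <container_id> bash
# inside the container: check the handler's imports first, then retry the load
python3 -c "import torch, transformers"
torchserve --stop
torchserve --start --model-store /home/ayush/model_store --models my_tc=TranslationClassifier.mar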

maaquib self-assigned this on Mar 17, 2021
@ayushch3 (Author)

@maaquib The .mar file runs correctly on a standalone Ubuntu box, but the same .mar file fails to load when mounted into the TorchServe Docker image. When I use a shell to get into the container, it fails to load the model with the same logs as attached above.

I am certain there is no issue with the .mar file, since otherwise it wouldn't have run locally. The issue is with how TorchServe runs inside a Docker container, since neither mounting the .mar into the official image nor building an image from scratch with CUDA works:

FROM nvidia/cuda:11.0-cudnn8-runtime-ubuntu18.04
WORKDIR /usr/src/app

ENV DEBIAN_FRONTEND=noninteractive

RUN apt-get update
RUN apt-get install -y python3 python3-dev python3-pip openjdk-11-jre-headless git wget curl

RUN export USE_CUDA=1
RUN python3 -m pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
RUN python3 -m pip install torchserve
RUN python3 -m pip install --no-cache-dir torchtext transformers

#the mar file is downloaded from S3, that part is skipped here

EXPOSE 8080 8081 8082 7070 7071
COPY . .

CMD ["torchserve", "--start", "--model-store", "/usr/src/app/model_store", "--models",  "my_tc=TranslationClassifier.mar"]

@ayushch3 (Author)

@maaquib The underlying issue is that the Docker setup can't handle any custom handler. I wrote the entire Docker image from scratch, and it still failed with the following logs:

2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 116, in handle_connection
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service, result, code = self.load_model(msg)
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_service_worker.py", line 89, in load_model
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service = model_loader.load(model_name, model_dir, handler, gpu, batch_size, envelope)
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_loader.py", line 83, in load
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     module = self._load_default_handler(handler)
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/local/lib/python3.6/dist-packages/ts/model_loader.py", line 120, in _load_default_handler
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     module = importlib.import_module(module_name, 'ts.torch_handler')
2021-03-22 21:25:45,587 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/lib/python3.6/importlib/__init__.py", line 126, in import_module
2021-03-22 21:25:45,588 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     return _bootstrap._gcd_import(name[level:], package, level)
2021-03-22 21:25:45,588 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "<frozen importlib._bootstrap>", line 994, in _gcd_import
2021-03-22 21:25:45,588 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "<frozen importlib._bootstrap>", line 971, in _find_and_load
2021-03-22 21:25:45,588 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "<frozen importlib._bootstrap>", line 950, in _find_and_load_unlocked
2021-03-22 21:25:45,588 [INFO ] W-9009-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - ModuleNotFoundError: No module named 'ts.torch_handler.text_handler.py'; 'ts.torch_handler.text_handler' is not a package

I still don't understand why it is complaining about the custom handler when I am running torch-model-archiver inside the Docker container.

Tony-X commented Mar 23, 2021

I have a similar issue. In my setup, the goal is to invoke a model without any pre/post processing. I browsed the source and found that base_handler fits the need.

So I did the following:

torch-model-archiver --model-name <name> --version 1.0 --model-file <model_file> --serialized-file <serialized_torchscript> --export-path model_store --handler /home/ec2-user/.local/lib/python3.8/site-packages/ts/torch_handler/base_handler.py

I started the server and hit this error:

2021-03-23 23:29:43,691 [DEBUG] W-9000-dl_cap_1.0 org.pytorch.serve.wlm.WorkerThread - Backend worker monitoring thread interrupted or backend worker process died.
java.lang.InterruptedException
        at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.reportInterruptAfterWait(AbstractQueuedSynchronizer.java:2056)
        at java.base/java.util.concurrent.locks.AbstractQueuedSynchronizer$ConditionObject.awaitNanos(AbstractQueuedSynchronizer.java:2133)
        at java.base/java.util.concurrent.ArrayBlockingQueue.poll(ArrayBlockingQueue.java:432)
        at org.pytorch.serve.wlm.WorkerThread.run(WorkerThread.java:188)
        at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
        at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
        at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
        at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
        at java.base/java.lang.Thread.run(Thread.java:829)
2021-03-23 23:29:43,691 [INFO ] W-9000-dl_cap_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "<frozen importlib._bootstrap>", line 970, in _find_and_load_unlocked
2021-03-23 23:29:43,693 [INFO ] W-9000-dl_cap_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - ModuleNotFoundError: No module named 'ts.torch_handler.base_handler.py'; 'ts.torch_handler.base_handler' is not a package
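
A hedged suggestion for this particular case: torch-model-archiver also accepts the name of a built-in handler instead of a path to its .py file, which sidesteps the ts.torch_handler.base_handler.py lookup shown above (exact behavior may depend on the torchserve/model-archiver version; the bracketed values are the same placeholders as in the command above):

torch-model-archiver --model-name <name> --version 1.0 --model-file <model_file> \
  --serialized-file <serialized_torchscript> --export-path model_store --handler base_handler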

maaquib added this to the v0.4.0 milestone on Mar 24, 2021
@ayushch3 (Author)

@maaquib Even an existing .mar file that runs on my Ubuntu system fails to execute when mounted into the image provided by TorchServe:

docker run -p 8080:8080 -p 8081:8081 --mount type=bind,source=/home/ayush/model-store-trans,target=/home/ayush/model-store-trans pytorch/torchserve:latest-gpu torchserve --model-store /home/ayush/model-store-trans --models my_tc=TranslationClassifier.mar

I see the following errors when running the docker image:

2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/venv/lib/python3.6/site-packages/ts/model_service_worker.py", line 182, in <module>
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     worker.run_server()
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/venv/lib/python3.6/site-packages/ts/model_service_worker.py", line 154, in run_server
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     self.handle_connection(cl_socket)
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/venv/lib/python3.6/site-packages/ts/model_service_worker.py", line 116, in handle_connection
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service, result, code = self.load_model(msg)
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/venv/lib/python3.6/site-packages/ts/model_service_worker.py", line 89, in load_model
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     service = model_loader.load(model_name, model_dir, handler, gpu, batch_size, envelope)
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/venv/lib/python3.6/site-packages/ts/model_loader.py", line 83, in load
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     module = self._load_default_handler(handler)
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/home/venv/lib/python3.6/site-packages/ts/model_loader.py", line 120, in _load_default_handler
2021-03-24 23:34:58,973 [INFO ] epollEventLoopGroup-5-30 org.pytorch.serve.wlm.WorkerThread - 9012 Worker disconnected. WORKER_STARTED
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     module = importlib.import_module(module_name, 'ts.torch_handler')
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "/usr/lib/python3.6/importlib/__init__.py", line 126, in import_module
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -     return _bootstrap._gcd_import(name[level:], package, level)
2021-03-24 23:34:58,973 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "<frozen importlib._bootstrap>", line 994, in _gcd_import
2021-03-24 23:34:58,974 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "<frozen importlib._bootstrap>", line 971, in _find_and_load
2021-03-24 23:34:58,974 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle -   File "<frozen importlib._bootstrap>", line 950, in _find_and_load_unlocked
2021-03-24 23:34:58,974 [INFO ] W-9012-my_tc_1.0-stdout org.pytorch.serve.wlm.WorkerLifeCycle - ModuleNotFoundError: No module named 'ts.torch_handler.text_handler.py'; 'ts.torch_handler.text_handler' is not a package

Is there an ETA or workaround for this issue, so that I can deploy a TorchServe model in production? I have tried multiple ways to make it work in a Docker image, but nothing seems to work. Surprisingly, the same .mar file works as expected on the Ubuntu box.

@merryHunter

@ayushch3 how is your progress on this issue? I get a similar 'no module' error even when running the example image classification or object detection cases, also in an NVIDIA image container. Do you have any idea about the root cause?

lxning (Collaborator) commented Apr 9, 2021

@ayushch3 can you run "python ts_scripts/print_env_info.py" both locally and in the Docker container and compare the differences in the dependent packages?
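
If that script is not available inside the container, a rough substitute for the comparison is diffing the installed packages (a hedged sketch; the container ID is a placeholder and pip is assumed to be on PATH in the image):

pip freeze | sort > env_local.txt
docker exec <container_id> pip freeze | sort > env_docker.txt
diff env_local.txt env_docker.txt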

msaroufim added the bug label on Apr 30, 2021
@kqhuynguyen

@ayushch3 Can you check if there was any other import error above the torch_handler import error in model_log.log?
I had the same issue. I discovered that the Dockerfile was missing some required pip libraries, so I fixed that.
The exception ModuleNotFoundError: No module named 'ts.torch_handler.*' was raised by the following code in ts/model_loader.py:

    def load(self, model_name, model_dir, handler, gpu_id, batch_size, envelope=None):
        ...
        try:
            module, function_name = self._load_handler_file(handler)
        except ImportError:
            module = self._load_default_handler(handler)
        ...

This code first checks whether it can load a custom handler before falling back to a default handler.
So suppose I forgot to install the pip module foo in the Dockerfile. After the serve command, incoming requests would cause a ModuleNotFoundError (which is actually a kind of ImportError) to be raised in self._load_handler_file(handler), and self._load_default_handler(handler) would run. Because I was using a custom handler, self._load_default_handler(handler) would fail again, resulting in another exception that distracts from the main issue (failing to import foo). So I suggest looking further back in the log for any message that reads "During handling of the above exception, another exception occurred:".

lxning (Collaborator) commented Aug 25, 2021

@ayushch3 @kqhuynguyen please add install_py_dep_per_model=true to config.properties if your model needs to install packages, and then copy or mount that config.properties into your Docker container.
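
For concreteness, a hedged sketch of that setup (file names are placeholders; install_py_dep_per_model only takes effect for packages listed in a requirements file bundled into the .mar via the archiver's --requirements-file option, where your archiver version supports it):

printf "transformers\nsentencepiece\n" > model_requirements.txt

torch-model-archiver --model-name TranslationClassifier --version 1.0 \
  --serialized-file pytorch_model.bin --handler text_handler.py \
  --requirements-file model_requirements.txt --export-path model_store

echo "install_py_dep_per_model=true" >> config.properties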

@pjerryhu

I'm experiencing the same issue as this thread.

I have added install_py_dep_per_model=true to the config.properties file and built my image with all the necessary Python dependencies by adding pip3 install -r midas_requirements.txt to the Dockerfile.

Then I run my image with the following command to make sure all the volume mounts work correctly:
docker run --rm -it -p 8080:8080 -p 8081:8081 --name mar -v $(pwd)/model-store:/home/model-server/model-store -v $(pwd)/docker/config.properties:/home/model-server/config.properties -v $(pwd)/examples:/home/model-server/examples pytorch/torchserve:midas torchserve --model-store=/home/model-server/model-store --ts-config=/home/model-server/config.properties

However, I'm still getting the handler module-not-found error:

2022-06-10T23:44:16,772 [INFO ] W-9000-midas_small_v21_1.0-stdout MODEL_LOG - ModuleNotFoundError: No module named 'ts.torch_handler.MidasNetCustomHandler'

Can you shed some light on what else I can try?

Also, @lxning, are there examples of the "Allow model specific custom python packages" feature you linked? I would like to try it and get that "seamless model serving" experience.

msaroufim reopened this on Jun 11, 2022
@nataliameira

Any update on this? I'm still getting the same error.
