Skip to content

Commit

Permalink
fix(visualization): 修复recognizer函数导入文件
Browse files Browse the repository at this point in the history
  • Loading branch information
zjykzj committed Nov 5, 2020
1 parent a8494f4 commit a320268
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
8 changes: 8 additions & 0 deletions demo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-

"""
@date: 2020/10/30 下午3:22
@file: __init__.py.py
@author: zj
@description:
"""
9 changes: 5 additions & 4 deletions demo/tsn_r50_ucf101_rgb_raw_dense_1x16x4.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ NUM_NODES: 1
RANK_ID: 0
DIST_BACKEND: "nccl"
RNG_SEED: 1
OUTPUT_DIR: 'outputs/tsn_r50_ucf101_rgb_raw_dense_1x16x4'
DATASETS:
MODALITY: 'RGB'
TYPE: 'RawFrame'
SAMPLE_STRATEGY: 'DenseSample'
CLIP_LEN: 1
FRAME_INTERVAL: 16
NUM_CLIPS: 4
NUM_SAMPLE_POSITIONS: 1
TRANSFORM:
MEAN: (0.485, 0.456, 0.406)
STD: (0.229, 0.224, 0.225)
Expand All @@ -21,9 +23,8 @@ MODEL:
NAME: 'TSN'
PRETRAINED: ''
SYNC_BN: True
INPUT_SIZE: (224, 224, 3)
BACKBONE:
NAME: 'resnet50'
NAME: 'ResNet50'
PARTIAL_BN: False
TORCHVISION_PRETRAINED: True
ZERO_INIT_RESIDUAL: True
Expand All @@ -48,4 +49,4 @@ VISUALIZATION:
DISPLAY_HEIGHT: 0
OUTPUT_FPS: -1
OUTPUT_FILE: ""
LABEL_FILE_PATH: 'data/ucf101/annotations/classInd.txt'
LABEL_FILE_PATH: 'data/ucf101/annotations/classInd.txt'
4 changes: 2 additions & 2 deletions tsn/visualization/predictor/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from tsn.model.build import build_model
from tsn.model.recognizers.build import build_recognizer
from tsn.data.transforms.build import build_transform
from .util import process_cv2_inputs
from tsn.util.distributed import get_device, get_local_rank
Expand All @@ -27,7 +27,7 @@ def __init__(self, cfg):
device = get_device()

# Build the video model and print model statistics.
self.model = build_model(cfg, device)
self.model = build_recognizer(cfg, device)
self.model.eval()
self.transform = build_transform(cfg, is_train=False)

Expand Down

0 comments on commit a320268

Please sign in to comment.