Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix import error and copy-on-read overhead (referred to as a memory leak in the repository), and slightly refactor dist_utils.py for improved readability #418

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion rtdetrv2_pytorch/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,7 @@ torch>=2.0.1
torchvision>=0.15.2
pycocotools
PyYAML
tensorboard
tensorboard
scipy
psutil
tabulate
3 changes: 2 additions & 1 deletion rtdetrv2_pytorch/src/data/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from .cifar_dataset import CIFAR10
from .coco_dataset import CocoDetection
from .coco_dataset import (
CocoDetection,
CocoDetection,
CocoDetection_share_memory,
mscoco_category2name,
mscoco_category2label,
mscoco_label2category,
Expand Down
93 changes: 92 additions & 1 deletion rtdetrv2_pytorch/src/data/dataset/coco_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import gc
import os
import torch
import torch.utils.data

Expand All @@ -17,8 +19,10 @@
from ._dataset import DetDataset
from .._misc import convert_to_tv_tensor
from ...core import register
from .coco_utils import TorchSerializedList
from pycocotools.coco import COCO

__all__ = ['CocoDetection']
__all__ = ['CocoDetection', 'CocoDetection_share_memory']


@register()
Expand Down Expand Up @@ -88,6 +92,93 @@ def label2category(self, ):
return {i: cat['id'] for i, cat in enumerate(self.categories)}


@register()
class CocoDetection_share_memory(torchvision.datasets.VisionDataset, DetDataset):
    """COCO detection dataset whose image/annotation metadata can be packed
    into serialized flat buffers (``TorchSerializedList``) so DataLoader
    worker processes share one read-only copy instead of paying
    copy-on-read overhead on large Python lists.
    """

    def __init__(self, img_folder, ann_file, transforms, return_masks=False, remap_mscoco_category=False, share_memory=True):
        super().__init__(img_folder)

        coco = COCO(ann_file)
        image_ids = sorted(coco.imgs.keys())

        self.imgs_info = [coco.imgs[image_id] for image_id in image_ids]
        self.anns = [coco.imgToAnns[image_id] for image_id in image_ids]
        self.categories = coco.dataset['categories']

        if share_memory:
            # Pack the metadata lists into flat serialized buffers.
            self.imgs_info = TorchSerializedList(self.imgs_info)
            self.anns = TorchSerializedList(self.anns)

        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def _load_image(self, idx: int) -> Image.Image:
        """Open the idx-th image from disk and force-convert it to RGB."""
        info = self.imgs_info[idx]
        full_path = os.path.join(self.img_folder, info["file_name"])
        return Image.open(full_path).convert("RGB")

    def _load_target(self, idx: int):
        """Return the raw COCO annotation list for the idx-th image."""
        return self.anns[idx]

    def load_item(self, idx):
        """Load image idx plus its prepared target dict (boxes/masks wrapped as tv tensors)."""
        image = self._load_image(idx)
        raw_annotations = self._load_target(idx)
        target = {'image_id': self.imgs_info[idx]['id'], 'annotations': raw_annotations}

        if self.remap_mscoco_category:
            image, target = self.prepare(image, target, category2label=mscoco_category2label)
        else:
            image, target = self.prepare(image, target)

        target['idx'] = torch.tensor([idx])

        # PIL size is (w, h); spatial_size wants (h, w), hence the reversal.
        if 'boxes' in target:
            target['boxes'] = convert_to_tv_tensor(target['boxes'], key='boxes', spatial_size=image.size[::-1])
        if 'masks' in target:
            target['masks'] = convert_to_tv_tensor(target['masks'], key='masks')

        return image, target

    def __getitem__(self, idx):
        img, target = self.load_item(idx)
        if self._transforms is not None:
            img, target, _ = self._transforms(img, target, self)
        return img, target

    def extra_repr(self) -> str:
        s = f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n'
        s += f' return_masks: {self.return_masks}\n'
        if getattr(self, '_transforms', None) is not None:
            s += f' transforms:\n {repr(self._transforms)}'
        if getattr(self, '_preset', None) is not None:
            s += f' preset:\n {repr(self._preset)}'
        return s

    def __len__(self) -> int:
        return len(self.imgs_info)

    @property
    def category2name(self):
        """Map COCO category id -> category name."""
        return {category['id']: category['name'] for category in self.categories}

    @property
    def category2label(self):
        """Map COCO category id -> contiguous training label."""
        return {category['id']: label for label, category in enumerate(self.categories)}

    @property
    def label2category(self):
        """Map contiguous training label -> COCO category id."""
        return {label: category['id'] for label, category in enumerate(self.categories)}


def convert_coco_poly_to_mask(segmentations, height, width):
masks = []
for polygons in segmentations:
Expand Down
41 changes: 40 additions & 1 deletion rtdetrv2_pytorch/src/data/dataset/coco_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""


import pickle
import numpy as np
import torch
import torch.utils.data
import torchvision
Expand Down Expand Up @@ -193,3 +194,41 @@ def get_coco_api_from_dataset(dataset):
return convert_to_coco_api(dataset)


class NumpySerializedList():
    """Store a list of picklable objects as one contiguous uint8 numpy array.

    Serializing every element into a single flat buffer avoids the
    copy-on-read (refcount-touching) overhead of sharing a large Python
    list across DataLoader worker processes.
    """

    def __init__(self, lst: list):
        def _serialize(data):
            # protocol=-1 selects the highest available pickle protocol.
            buffer = pickle.dumps(data, protocol=-1)
            return np.frombuffer(buffer, dtype=np.uint8)

        print(
            "Serializing {} elements to byte tensors and concatenating them all ...".format(
                len(lst)
            )
        )
        serialized = [_serialize(x) for x in lst]
        # _addr[i] is the exclusive end offset of element i inside _lst.
        self._addr = np.cumsum(np.asarray([len(x) for x in serialized], dtype=np.int64))
        # np.concatenate raises ValueError on an empty sequence, so guard the
        # empty-dataset case explicitly.
        self._lst = np.concatenate(serialized) if serialized else np.empty(0, dtype=np.uint8)
        print("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx):
        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
        end_addr = self._addr[idx].item()
        # memoryview avoids copying the byte slice before unpickling;
        # renamed from `bytes` to avoid shadowing the builtin.
        view = memoryview(self._lst[start_addr:end_addr])
        return pickle.loads(view)


class TorchSerializedList(NumpySerializedList):
    """``NumpySerializedList`` variant whose buffers are torch tensors.

    NOTE(review): presumably chosen because torch tensors can be moved to
    shared memory when the dataset is forked into DataLoader workers —
    confirm against the loader configuration.
    """

    def __init__(self, lst: list):
        super().__init__(lst)
        # from_numpy shares the underlying storage; no extra copy is made.
        self._addr = torch.from_numpy(self._addr)
        self._lst = torch.from_numpy(self._lst)

    def __getitem__(self, idx):
        start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
        end_addr = self._addr[idx].item()
        # Renamed from `bytes` to avoid shadowing the builtin; .numpy() is a
        # zero-copy view of the tensor slice.
        view = memoryview(self._lst[start_addr:end_addr].numpy())
        return pickle.loads(view)
7 changes: 1 addition & 6 deletions rtdetrv2_pytorch/src/misc/dist_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,7 @@ def print(*args, **kwargs):


def is_dist_available_and_initialized():
    """Return True iff the torch.distributed package is available AND the
    default process group has been initialized.

    Short-circuit `and` preserves the original guard order: `is_initialized`
    must not be queried when the distributed package is unavailable.
    """
    return torch.distributed.is_available() and torch.distributed.is_initialized()

@atexit.register
def cleanup():
Expand Down
169 changes: 169 additions & 0 deletions rtdetrv2_pytorch/tools/memory_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..'))


import time
import pickle
from collections import defaultdict
from multiprocessing import Manager


from src.data.dataloader import DataLoader, BatchImageCollateFuncion
from src.data import transforms as T
from src.data.dataset.coco_dataset import CocoDetection, CocoDetection_share_memory


import torch
import psutil
from tabulate import tabulate


"""
testing memory usage of dataloader.

requires psutil and tabulate

pip install psutil tabulate
"""


class MemoryMonitor():
    """Track rss/pss/uss/shared memory of a set of processes via psutil.

    The PID list is kept in a ``multiprocessing.Manager`` list so that
    DataLoader worker processes can register themselves from their own
    process (see ``add_pid`` / the ``worker_init_fn`` hook in this file).
    """

    def __init__(self, pids: list[int] = None):
        # NOTE(review): annotation should arguably be Optional[list[int]];
        # kept as-is to avoid touching the signature.
        initial_pids = [os.getpid()] if pids is None else pids
        self.pids = Manager().list(initial_pids)

    def add_pid(self, pid: int):
        """Register one more PID to monitor (must not already be tracked)."""
        assert pid not in self.pids
        self.pids.append(pid)

    def _refresh(self):
        # Cache the latest snapshot on self.data so table()/str() can reuse it.
        snapshot = {}
        for pid in self.pids:
            snapshot[pid] = self.get_mem_info(pid)
        self.data = snapshot
        return snapshot

    def table(self) -> str:
        """Render the current snapshot as a tabulate table, one row per PID."""
        self._refresh()
        keys = list(next(iter(self.data.values())).keys())
        now = str(int(time.perf_counter() % 1e5))
        rows = [
            (now, str(pid)) + tuple(self.format(info[k]) for k in keys)
            for pid, info in self.data.items()
        ]
        return tabulate(rows, headers=["time", "PID"] + keys)

    def str(self):
        """Render the current snapshot as one comma-separated line per PID."""
        self._refresh()
        keys = list(next(iter(self.data.values())).keys())
        lines = []
        for pid in self.pids:
            fields = [f"PID={pid}"]
            fields.extend(f"{k}={self.format(self.data[pid][k])}" for k in keys)
            lines.append(", ".join(fields))
        return "\n".join(lines)

    @staticmethod
    def format(size: int) -> str:
        """Human-readable byte count, e.g. 2048 -> '2.0K'."""
        for unit in ('', 'K', 'M', 'G'):
            if size < 1024:
                return "%.1f%s" % (size, unit)
            size /= 1024.0
        # Fell through every unit: value was divided at 'G' as well.
        return "%.1f%s" % (size, 'G')

    @staticmethod
    def get_mem_info(pid: int) -> dict[str, int]:
        """Sum per-mapping memory counters for *pid* via psutil.memory_maps()."""
        res = defaultdict(int)
        for mmap in psutil.Process(pid).memory_maps():
            private = mmap.private_clean + mmap.private_dirty
            shared = mmap.shared_clean + mmap.shared_dirty
            res['rss'] += mmap.rss
            res['pss'] += mmap.pss
            res['uss'] += private
            res['shared'] += shared
            # Path-backed mappings (files) start with '/'.
            if mmap.path.startswith('/'):
                res['shared_file'] += shared
        return res


def test_dataset(
        dataset_class,
        range_num=None,
        img_folder="./dataset/coco/train2017/",
        ann_file="./dataset/coco/annotations/instances_train2017.json",
        **kwargs):
    """Build the training dataset used by the memory benchmark.

    Extra **kwargs are forwarded to *dataset_class* (e.g. share_memory=...).
    When *range_num* is given, only the first range_num samples are kept.
    """
    pipeline = [
        T.RandomPhotometricDistort(p=0.5),
        T.RandomZoomOut(fill=0),
        T.RandomIoUCrop(p=0.8),
        T.SanitizeBoundingBoxes(min_size=1),
        T.RandomHorizontalFlip(),
        T.Resize(size=[640, 640]),
        T.SanitizeBoundingBoxes(min_size=1),
        T.ConvertPILImage(dtype='float32', scale=True),
        T.ConvertBoxes(fmt='cxcywh', normalize=True),
    ]
    policy = {'name': 'stop_epoch',
              'epoch': 71,
              'ops': ['RandomPhotometricDistort', 'RandomZoomOut', 'RandomIoUCrop']}

    dataset = dataset_class(
        img_folder=img_folder,
        ann_file=ann_file,
        transforms=T.Compose(pipeline, policy=policy),
        return_masks=False,
        remap_mscoco_category=True,
        **kwargs)

    if range_num is not None:
        dataset = torch.utils.data.Subset(dataset, range(range_num))

    return dataset


def test_dataloader(
        dataset,
        worker_init_fn,
        batch_size=4,
        shuffle=True,
        num_workers=4):
    """Wrap *dataset* in the project DataLoader configured for the benchmark."""
    scales = [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800]
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=BatchImageCollateFuncion(scales=scales, stop_epoch=71),
        drop_last=True,
        worker_init_fn=worker_init_fn)


def main(**kwargs):
    """Iterate the dataloader once while printing per-process memory stats.

    kwargs are forwarded to test_dataset (dataset_class, share_memory, ...).
    """
    def hook_pid(worker_id):
        # Register each DataLoader worker with the monitor from inside the
        # worker process, so its memory maps are included in the report.
        pid = os.getpid()
        monitor.pids.append(pid)
        print(f"tracking {worker_id} PID: {pid}")

    monitor = MemoryMonitor()

    dataloader = test_dataloader(
        dataset=test_dataset(**kwargs),
        worker_init_fn=hook_pid,
        batch_size=32,
        num_workers=2)

    t = time.time()

    for i, (samples, targets) in enumerate(dataloader):
        # Fake-read the data: pickling forces every page of the batch to be
        # touched, which is what triggers copy-on-read growth.
        samples = pickle.dumps(samples)
        targets = pickle.dumps(targets)

        if i % 10 == 0:
            print(monitor.table())
            # Only the per-PID dicts are needed, so iterate .values();
            # also fixes the 'totle' typo in the printed message.
            total_pss_gb = sum(info['pss'] for info in monitor.data.values()) / 1024 / 1024 / 1024
            print(f"total pss : {total_pss_gb:.3f}GB")
            print(f"iteration : {i} / {len(dataloader)}, time : {time.time() - t:.3f}")
            t = time.time()


if __name__ == '__main__':
    # Benchmark variants: uncomment exactly one at a time to compare memory
    # usage between the original dataset and the shared-memory version.
    # main(dataset_class=CocoDetection, range_num=30000)
    # main(dataset_class=CocoDetection_share_memory, share_memory=False, range_num=30000)
    main(dataset_class=CocoDetection_share_memory, share_memory=True, range_num=30000)