Skip to content

Commit

Permalink
Run pre-commit hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
OPHoperHPO committed Feb 8, 2023
1 parent 580a2ac commit e60e7fa
Show file tree
Hide file tree
Showing 9 changed files with 118 additions and 95 deletions.
80 changes: 48 additions & 32 deletions carvekit/api/autointerface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,19 @@
__all__ = ["AutoInterface"]




class AutoInterface(Interface):
def __init__(
self,
scene_classifier: SceneClassifier,
object_classifier: SimplifiedYoloV4,
segmentation_batch_size: int = 3,
refining_batch_size: int = 1,
refining_image_size: int = 900,
postprocessing_batch_size: int = 1,
postprocessing_image_size: int = 2048,
segmentation_device: str = "cpu",
postprocessing_device: str = "cpu",
fp16=False,
self,
scene_classifier: SceneClassifier,
object_classifier: SimplifiedYoloV4,
segmentation_batch_size: int = 3,
refining_batch_size: int = 1,
refining_image_size: int = 900,
postprocessing_batch_size: int = 1,
postprocessing_image_size: int = 2048,
segmentation_device: str = "cpu",
postprocessing_device: str = "cpu",
fp16=False,
):
"""
Args:
Expand Down Expand Up @@ -73,38 +71,53 @@ def select_params_for_net(net: Union[TracerUniversalB7, U2NET, DeepLabV3]):
"""
if net == TracerUniversalB7:
return {
"trimap_generator": {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5},
"trimap_generator": {
"prob_threshold": 231,
"kernel_size": 30,
"erosion_iters": 5,
},
"matting_module": {"disable_noise_filter": False},
"refining": {"enabled": True,
"mask_binary_threshold": 128}
"refining": {"enabled": True, "mask_binary_threshold": 128},
}
elif net == U2NET:
return {
"trimap_generator": {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5},
"trimap_generator": {
"prob_threshold": 231,
"kernel_size": 30,
"erosion_iters": 5,
},
"matting_module": {"disable_noise_filter": False},
"refining": {"enabled": True,
"mask_binary_threshold": 128}
"refining": {"enabled": True, "mask_binary_threshold": 128},
}
elif net == ISNet:
return {
"trimap_generator": {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5},
"trimap_generator": {
"prob_threshold": 231,
"kernel_size": 30,
"erosion_iters": 5,
},
"matting_module": {"disable_noise_filter": True},
"refining": {"enabled": False,
"mask_binary_threshold": 128}
"refining": {"enabled": False, "mask_binary_threshold": 128},
}
elif net == DeepLabV3:
return {
"trimap_generator": {"prob_threshold": 231, "kernel_size": 40, "erosion_iters": 20},
"trimap_generator": {
"prob_threshold": 231,
"kernel_size": 40,
"erosion_iters": 20,
},
"matting_module": {"disable_noise_filter": False},
"refining": {"enabled": True,
"mask_binary_threshold": 128}
"refining": {"enabled": True, "mask_binary_threshold": 128},
}
elif net == BASNET:
return {
"trimap_generator": {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5},
"trimap_generator": {
"prob_threshold": 231,
"kernel_size": 30,
"erosion_iters": 5,
},
"matting_module": {"disable_noise_filter": False},
"refining": {"enabled": True,
"mask_binary_threshold": 128}
"refining": {"enabled": True, "mask_binary_threshold": 128},
}
else:
raise ValueError("Unknown network type")
Expand Down Expand Up @@ -180,7 +193,6 @@ def select_net(self, scene: str, images_info: List[dict]):
"net"
] = TracerUniversalB7 # It seems that the image is empty, but we will try to process it


def __call__(self, images: List[Union[str, Path, Image.Image]]):
"""
Automatically detects the scene and selects the appropriate network for segmentation
Expand Down Expand Up @@ -254,9 +266,13 @@ def __call__(self, images: List[Union[str, Path, Image.Image]]):
# Configure custom pipeline for image group
config_params = self.select_params_for_net(net)
trimap_generator = TrimapGenerator(**config_params["trimap_generator"])
fba.disable_noise_filter = config_params["matting_module"]["disable_noise_filter"]
fba.disable_noise_filter = config_params["matting_module"][
"disable_noise_filter"
]
if config_params["refining"]["enabled"]:
cascadepsp.mask_binary_threshold = config_params["refining"]["mask_binary_threshold"]
cascadepsp.mask_binary_threshold = config_params["refining"][
"mask_binary_threshold"
]
matting_method = CasMattingMethod(
refining_module=cascadepsp,
matting_module=fba,
Expand All @@ -267,7 +283,7 @@ def __call__(self, images: List[Union[str, Path, Image.Image]]):
matting_method = MattingMethod(
matting_module=fba,
trimap_generator=trimap_generator,
device=self.postprocessing_device
device=self.postprocessing_device,
)

sc_images = [image_info["image"] for image_info in gimages_info]
Expand Down
4 changes: 1 addition & 3 deletions carvekit/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,7 @@ def __call__(
else:
images = list(
map(
lambda x: apply_mask(
image=images[x], mask=masks[x]
),
lambda x: apply_mask(image=images[x], mask=masks[x]),
range(len(images)),
)
)
Expand Down
56 changes: 31 additions & 25 deletions carvekit/ml/arch/isnet/isnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
super(REBNCONV, self).__init__()

self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)

Expand All @@ -25,13 +27,12 @@ def forward(self, x):

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src, tar):
src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
src = F.upsample(src, size=tar.shape[2:], mode="bilinear")

return src


class RSU7(nn.Module):

def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
super(RSU7, self).__init__()

Expand Down Expand Up @@ -113,7 +114,6 @@ def forward(self, x):


class RSU6(nn.Module):

def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()

Expand Down Expand Up @@ -180,7 +180,6 @@ def forward(self, x):


class RSU5(nn.Module):

def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()

Expand Down Expand Up @@ -237,7 +236,6 @@ def forward(self, x):


class RSU4(nn.Module):

def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()

Expand Down Expand Up @@ -284,7 +282,6 @@ def forward(self, x):


class RSU4F(nn.Module):

def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()

Expand Down Expand Up @@ -319,22 +316,27 @@ def forward(self, x):


class myrebnconv(nn.Module):
def __init__(self, in_ch=3,
out_ch=1,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1):
def __init__(
self,
in_ch=3,
out_ch=1,
kernel_size=3,
stride=1,
padding=1,
dilation=1,
groups=1,
):
super(myrebnconv, self).__init__()

self.conv = nn.Conv2d(in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups)
self.conv = nn.Conv2d(
in_ch,
out_ch,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
)
self.bn = nn.BatchNorm2d(out_ch)
self.rl = nn.ReLU(inplace=True)

Expand All @@ -343,7 +345,6 @@ def forward(self, x):


class ISNetDIS(nn.Module):

def __init__(self, in_ch=3, out_ch=1):
super(ISNetDIS, self).__init__()

Expand Down Expand Up @@ -447,6 +448,11 @@ def forward(self, x):

# d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [hx1d, hx2d,
hx3d, hx4d,
hx5d, hx6]
return [
F.sigmoid(d1),
F.sigmoid(d2),
F.sigmoid(d3),
F.sigmoid(d4),
F.sigmoid(d5),
F.sigmoid(d6),
], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
30 changes: 15 additions & 15 deletions carvekit/ml/wrap/fba_matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ class FBAMatting(FBA):
"""

def __init__(
self,
device="cpu",
input_tensor_size: Union[List[int], int] = 2048, # 1500,
batch_size: int = 2,
encoder="resnet50_GN_WS",
load_pretrained: bool = True,
fp16: bool = False,
disable_noise_filter=False
self,
device="cpu",
input_tensor_size: Union[List[int], int] = 2048, # 1500,
batch_size: int = 2,
encoder="resnet50_GN_WS",
load_pretrained: bool = True,
fp16: bool = False,
disable_noise_filter=False,
):
"""
Initialize the FBAMatting model
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
self.eval()

def data_preprocessing(
self, data: Union[PIL.Image.Image, np.ndarray]
self, data: Union[PIL.Image.Image, np.ndarray]
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""
Transform input image to suitable data format for neural network
Expand Down Expand Up @@ -115,9 +115,9 @@ def data_preprocessing(
.float(),
)

def data_postprocessing(self,
data: torch.tensor, trimap: PIL.Image.Image
) -> PIL.Image.Image:
def data_postprocessing(
self, data: torch.tensor, trimap: PIL.Image.Image
) -> PIL.Image.Image:
"""
Transforms output data from neural network to suitable data
format for using with other components of this framework.
Expand All @@ -144,9 +144,9 @@ def data_postprocessing(self,
return Image.fromarray(pred * 255).convert("L")

def __call__(
self,
images: List[Union[str, pathlib.Path, PIL.Image.Image]],
trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
self,
images: List[Union[str, pathlib.Path, PIL.Image.Image]],
trimaps: List[Union[str, pathlib.Path, PIL.Image.Image]],
) -> List[PIL.Image.Image]:
"""
Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances
Expand Down
24 changes: 14 additions & 10 deletions carvekit/ml/wrap/isnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,12 @@ class ISNet(ISNetDIS):
"""ISNet model interface"""

def __init__(
self,
device="cpu",
input_image_size: Union[List[int], int] = 1024,
batch_size: int = 1,
load_pretrained: bool = True,
fp16: bool = False,
self,
device="cpu",
input_image_size: Union[List[int], int] = 1024,
batch_size: int = 1,
load_pretrained: bool = True,
fp16: bool = False,
):
"""
Initialize the ISNet model
Expand Down Expand Up @@ -80,7 +80,7 @@ def data_preprocessing(self, data: PIL.Image.Image) -> torch.FloatTensor:

@staticmethod
def data_postprocessing(
data: torch.tensor, original_image: PIL.Image.Image
data: torch.tensor, original_image: PIL.Image.Image
) -> PIL.Image.Image:
"""
Transforms output data from neural network to suitable data
Expand All @@ -98,12 +98,14 @@ def data_postprocessing(
ma = torch.max(data)
mi = torch.min(data)
data = (data - mi) / (ma - mi)
mask = Image.fromarray((data * 255).cpu().data.numpy().astype(np.uint8)).convert("L")
mask = Image.fromarray(
(data * 255).cpu().data.numpy().astype(np.uint8)
).convert("L")
mask = mask.resize(original_image.size, resample=3)
return mask

def __call__(
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
self, images: List[Union[str, pathlib.Path, PIL.Image.Image]]
) -> List[PIL.Image.Image]:
"""
Passes input images though neural network and returns segmentation masks as PIL.Image.Image instances
Expand Down Expand Up @@ -132,7 +134,9 @@ def __call__(
masks_cpu = masks.cpu()
del batches, masks
masks = thread_pool_processing(
lambda x: self.data_postprocessing(masks_cpu[x], converted_images[x]),
lambda x: self.data_postprocessing(
masks_cpu[x], converted_images[x]
),
range(len(converted_images)),
)
collect_masks += masks
Expand Down
4 changes: 1 addition & 3 deletions carvekit/pipelines/postprocessing/casmatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,7 @@ def __call__(
alpha = self.matting_module(images=images, trimaps=trimaps)
return list(
map(
lambda x: apply_mask(
image=images[x], mask=alpha[x]
),
lambda x: apply_mask(image=images[x], mask=alpha[x]),
range(len(images)),
)
)
Loading

0 comments on commit e60e7fa

Please sign in to comment.