Skip to content

Commit

Permalink
Added ISNet segmentation neural network.
Browse files Browse the repository at this point in the history
  • Loading branch information
OPHoperHPO committed Jan 29, 2023
1 parent d06f360 commit 17b0587
Show file tree
Hide file tree
Showing 12 changed files with 54 additions and 37 deletions.
22 changes: 13 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,28 @@ It can be briefly described as
3. Using machine learning technology, the background of the image is removed
4. Image post-processing to improve the quality of the processed image
## 🎓 Implemented Neural Networks:
| Networks | Target | Accuracy |
|:-----------------------:|:-------------------------------------------:|:--------------------------------:|
| **Tracer-B7** (default) | **General** (objects, animals, etc) | **90%** (mean F1-Score, DUTS-TE) |
| U^2-net | **Hairs** (hairs, people, animals, objects) | 80.4% (mean F1-Score, DUTS-TE) |
| BASNet | **General** (people, objects) | 80.3% (mean F1-Score, DUTS-TE) |
| DeepLabV3 | People, Animals, Cars, etc | 67.4% (mean IoU, COCO val2017) |

| Networks | Target | Accuracy |
|:-----------------------:|:-------------------------------------------:|:------------------------------------:|
| **Tracer-B7** (default) | **General** (objects, animals, etc) | **90%** (mean F1-Score, DUTS-TE, LR) |
| **ISNet** | **Hairs** (hairs, people, animals) | **79%** (max F1-Score, DIS5K, HR) |
| U^2-net | **Hairs** (hairs, people, animals, objects) | 80.4% (mean F1-Score, DUTS-TE, LR) |
| BASNet | **General** (people, objects) | 80.3% (mean F1-Score, DUTS-TE, LR) |
| DeepLabV3 | People, Animals, Cars, etc | 67.4% (mean IoU, COCO val2017, LR) |

> HR - High resolution images.
> LR - Low resolution images.
### Recommended parameters for different models
| Networks | Segmentation mask size | Trimap parameters (dilation, erosion) |
|:-----------:|:-----------------------:|:-------------------------------------:|
| `tracer_b7` | 640 | (30, 5) |
| `isnet` | 1024 | (30, 5) |
| `u2net` | 320 | (30, 5) |
| `basnet` | 320 | (30, 5) |
| `deeplabv3` | 1024 | (40, 20) |

> ### Notes:
> 1. The final quality may depend on the resolution of your image, the type of scene or object.
> 2. Use **U2-Net for hairs** and **Tracer-B7 for general images** and correct parameters. \
> 2. Use **ISNet for hairs** or **U2-Net for hairs** and **Tracer-B7 for general images** and correct parameters. \
> It is very important for final quality! Example images was taken by using U2-Net and FBA post-processing.
## 🖼️ Image pre-processing and post-processing methods:
### 🔍 Preprocessing methods:
Expand Down Expand Up @@ -216,7 +220,7 @@ Options:
processed by refining network
--seg_mask_size 640 The size of the input image for the
segmentation neural network. Use 640 for Tracer B7 and 320 for U2Net
segmentation neural network. Use 640 for Tracer B7 and 1024 for ISNet
--matting_mask_size 2048 The size of the input image for the matting
neural network.
Expand Down
11 changes: 7 additions & 4 deletions carvekit/api/autointerface.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.cascadepsp import CascadePSP
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.isnet import ISNet
from carvekit.ml.wrap.fba_matting import FBAMatting
from carvekit.ml.wrap.scene_classifier import SceneClassifier
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
Expand Down Expand Up @@ -73,6 +74,8 @@ def select_params_for_net(net: Union[TracerUniversalB7, U2NET, DeepLabV3]):
return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
elif net == U2NET:
return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
elif net == ISNet:
return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5}
elif net == DeepLabV3:
return {"prob_threshold": 231, "kernel_size": 40, "erosion_iters": 20}
elif net == BASNET:
Expand Down Expand Up @@ -111,7 +114,7 @@ def select_net(self, scene: str, images_info: List[dict]):
image_info["net"] = TracerUniversalB7
elif obj_counter["animals"] > 0:
# Animals case
image_info["net"] = U2NET # animals should be always in soft scenes
image_info["net"] = ISNet # animals should be always in soft scenes
else:
# We have no idea what is in the image, so we will try to process it with universal model
image_info["net"] = TracerUniversalB7
Expand All @@ -134,16 +137,16 @@ def select_net(self, scene: str, images_info: List[dict]):

if obj_counter["human"] > 0 and len(non_empty_classes) == 1:
# Human only case. It may be a portrait
image_info["net"] = U2NET
image_info["net"] = ISNet
elif obj_counter["human"] > 0 and len(non_empty_classes) > 1:
# Okay, we have a human with hairs and something else
image_info["net"] = U2NET
image_info["net"] = ISNet
elif obj_counter["cars"] > 0:
# Cars case.
image_info["net"] = TracerUniversalB7
elif obj_counter["animals"] > 0:
# Animals case
image_info["net"] = U2NET # animals should be always in soft scenes
image_info["net"] = ISNet # animals should be always in soft scenes
else:
# We have no idea what is in the image, so we will try to process it with universal model
image_info["net"] = TracerUniversalB7
Expand Down
4 changes: 2 additions & 2 deletions carvekit/api/high.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from carvekit.ml.wrap.cascadepsp import CascadePSP
from carvekit.ml.wrap.scene_classifier import SceneClassifier
from carvekit.pipelines.preprocessing import AutoScene
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.isnet import ISNet
from carvekit.pipelines.postprocessing import CasMattingMethod
from carvekit.trimap.generator import TrimapGenerator

Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
fp16=fp16,
)
elif object_type == "hairs-like":
self._segnet = U2NET(
self._segnet = ISNet(
device=device,
batch_size=batch_size_seg,
input_image_size=seg_mask_size,
Expand Down
3 changes: 2 additions & 1 deletion carvekit/api/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from carvekit.ml.wrap.basnet import BASNET
from carvekit.ml.wrap.deeplab_v3 import DeepLabV3
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.isnet import ISNet
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.pipelines.preprocessing import PreprocessingStub, AutoScene
from carvekit.pipelines.postprocessing import MattingMethod, CasMattingMethod
Expand All @@ -22,7 +23,7 @@
class Interface:
def __init__(
self,
seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]],
seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7, ISNet]],
pre_pipe: Optional[Union[PreprocessingStub, AutoScene]] = None,
post_pipe: Optional[Union[MattingMethod, CasMattingMethod]] = None,
device="cpu",
Expand Down
2 changes: 1 addition & 1 deletion carvekit/ml/wrap/fba_matting.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class FBAMatting(FBA):
def __init__(
self,
device="cpu",
input_tensor_size: Union[List[int], int] = 1500,
input_tensor_size: Union[List[int], int] = 2048, #1500,
batch_size: int = 2,
encoder="resnet50_GN_WS",
load_pretrained: bool = True,
Expand Down
4 changes: 2 additions & 2 deletions carvekit/pipelines/preprocessing/autoscene.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from carvekit.ml.wrap.scene_classifier import SceneClassifier
from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7
from carvekit.ml.wrap.u2net import U2NET
from carvekit.ml.wrap.isnet import ISNet

__all__ = ["AutoScene"]

Expand All @@ -36,7 +36,7 @@ def select_net(scene: str):
if scene == "hard":
return TracerUniversalB7
elif scene == "soft":
return U2NET
return ISNet
elif scene == "digital":
return TracerUniversalB7 # TODO: not implemented yet

Expand Down
2 changes: 1 addition & 1 deletion carvekit/web/schemas/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class MLConfig(BaseModel):
"""Config for ml part of framework"""

segmentation_network: Literal[
"u2net", "deeplabv3", "basnet", "tracer_b7"
"u2net", "deeplabv3", "basnet", "tracer_b7", "isnet"
] = "tracer_b7"
"""Segmentation Network"""
preprocessing_method: Literal["none", "stub", "autoscene", "auto"] = "autoscene"
Expand Down
8 changes: 8 additions & 0 deletions carvekit/web/utils/init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from loguru import logger

from carvekit.ml.wrap.cascadepsp import CascadePSP
from carvekit.ml.wrap.isnet import ISNet
from carvekit.ml.wrap.scene_classifier import SceneClassifier
from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig

Expand Down Expand Up @@ -146,6 +147,13 @@ def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface:
input_image_size=config.seg_mask_size,
fp16=config.fp16,
)
elif config.segmentation_network == "isnet":
seg_net = ISNet(
device=config.device,
batch_size=config.batch_size_seg,
input_image_size=config.seg_mask_size,
fp16=config.fp16,
)
elif config.segmentation_network == "deeplabv3":
seg_net = DeepLabV3(
device=config.device,
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ services:
environment:
- CARVEKIT_PORT=5000
- CARVEKIT_HOST=0.0.0.0
- CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3
- CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3, isnet
- CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub, autoscene, auto
- CARVEKIT_POSTPROCESSING_METHOD=cascade_fba # can be none, fba, cascade_fba
- CARVEKIT_DEVICE=cpu # can be cuda (req. cuda docker image), cpu
Expand Down
2 changes: 1 addition & 1 deletion docker-compose.cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ services:
environment:
- CARVEKIT_PORT=5000
- CARVEKIT_HOST=0.0.0.0
- CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3
- CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3, isnet
- CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub, autoscene, auto
- CARVEKIT_POSTPROCESSING_METHOD=cascade_fba # can be none, fba, cascade_fba
- CARVEKIT_DEVICE=cuda # can be cuda (req. cuda docker image), cpu
Expand Down
25 changes: 13 additions & 12 deletions docs/CREDITS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ All images are copyrighted by their authors.

## References:
1. https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/
2. https://github.com/NathanUA/U-2-Net
3. https://github.com/NathanUA/BASNet
4. https://github.com/MarcoForte/FBA_Matting
5. https://arxiv.org/abs/1706.05587
6. https://arxiv.org/pdf/2005.09007.pdf
7. http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html
8. https://arxiv.org/abs/2003.07711
9. https://arxiv.org/abs/1506.01497
10. https://arxiv.org/abs/1703.06870
11. https://github.com/Karel911/TRACER
12. https://arxiv.org/abs/2112.07380
13. https://github.com/hkchengrex/CascadePSP
2. https://github.com/xuebinqin/DIS
3. https://github.com/NathanUA/U-2-Net
4. https://github.com/NathanUA/BASNet
5. https://github.com/MarcoForte/FBA_Matting
6. https://arxiv.org/abs/1706.05587
7. https://arxiv.org/pdf/2005.09007.pdf
8. http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html
9. https://arxiv.org/abs/2003.07711
10. https://arxiv.org/abs/1506.01497
11. https://arxiv.org/abs/1703.06870
12. https://github.com/Karel911/TRACER
13. https://arxiv.org/abs/2112.07380
14. https://github.com/hkchengrex/CascadePSP

6 changes: 3 additions & 3 deletions docs/other/carvekit_try.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@
"#@markdown Description of parameters\n",
"#@markdown - `SHOW_FULLSIZE` - Shows image in full size (may take a long time to load)\n",
"#@markdown - `PREPROCESSING_METHOD` - Preprocessing method. `AutoScene` will automatically select needed model depends on your image. If you don't want, disable it.\n",
"#@markdown - `SEGMENTATION_NETWORK` - Segmentation network. Use `u2net` for hairs-like objects and `tracer_b7` for objects\n",
"#@markdown - `SEGMENTATION_NETWORK` - Segmentation network. Use `isnet` for hairs-like objects and `tracer_b7` for objects\n",
"#@markdown - `POSTPROCESSING_METHOD` - Postprocessing method\n",
"#@markdown - `SEGMENTATION_MASK_SIZE` - Segmentation mask size. Use 640 for Tracer B7 and 320 for U2Net\n",
"#@markdown - `SEGMENTATION_MASK_SIZE` - Segmentation mask size. Use 640 for Tracer B7 and 1024 for ISNet\n",
"#@markdown - `TRIMAP_DILATION` - The size of the offset radius from the object mask in pixels when forming an unknown area\n",
"#@markdown - `TRIMAP_EROSION` - The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area\n",
"#@markdown > Look README.md and code for more details on networks and methods\n",
Expand All @@ -131,7 +131,7 @@
"\n",
"SHOW_FULLSIZE = False #@param {type:\"boolean\"}\n",
"PREPROCESSING_METHOD = \"autoscene\" #@param [\"autoscene\", \"auto\", \"none\"]\n",
"SEGMENTATION_NETWORK = \"tracer_b7\" #@param [\"u2net\", \"deeplabv3\", \"basnet\", \"tracer_b7\"]\n",
"SEGMENTATION_NETWORK = \"tracer_b7\" #@param [\"u2net\", \"deeplabv3\", \"basnet\", \"tracer_b7\", \"isnet\"]\n",
"POSTPROCESSING_METHOD = \"cascade_fba\" #@param [\"fba\", \"cascade_fba\", \"none\"]\n",
"SEGMENTATION_MASK_SIZE = 640 #@param [\"640\", \"320\"] {type:\"raw\", allow-input: true}\n",
"TRIMAP_DILATION = 30 #@param {type:\"integer\"}\n",
Expand Down

1 comment on commit 17b0587

@OPHoperHPO
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refer to #119

Please sign in to comment.