diff --git a/README.md b/README.md index cb20f72..0dafd5b 100644 --- a/README.md +++ b/README.md @@ -45,24 +45,28 @@ It can be briefly described as 3. Using machine learning technology, the background of the image is removed 4. Image post-processing to improve the quality of the processed image ## 🎓 Implemented Neural Networks: -| Networks | Target | Accuracy | -|:-----------------------:|:-------------------------------------------:|:--------------------------------:| -| **Tracer-B7** (default) | **General** (objects, animals, etc) | **90%** (mean F1-Score, DUTS-TE) | -| U^2-net | **Hairs** (hairs, people, animals, objects) | 80.4% (mean F1-Score, DUTS-TE) | -| BASNet | **General** (people, objects) | 80.3% (mean F1-Score, DUTS-TE) | -| DeepLabV3 | People, Animals, Cars, etc | 67.4% (mean IoU, COCO val2017) | - +| Networks | Target | Accuracy | +|:-----------------------:|:-------------------------------------------:|:------------------------------------:| +| **Tracer-B7** (default) | **General** (objects, animals, etc) | **90%** (mean F1-Score, DUTS-TE, LR) | +| **ISNet** | **Hairs** (hairs, people, animals) | **79%** (max F1-Score, DIS5K, HR) | +| U^2-net | **Hairs** (hairs, people, animals, objects) | 80.4% (mean F1-Score, DUTS-TE, LR) | +| BASNet | **General** (people, objects) | 80.3% (mean F1-Score, DUTS-TE, LR) | +| DeepLabV3 | People, Animals, Cars, etc | 67.4% (mean IoU, COCO val2017, LR) | + +> HR - High resolution images. +> LR - Low resolution images. ### Recommended parameters for different models | Networks | Segmentation mask size | Trimap parameters (dilation, erosion) | |:-----------:|:-----------------------:|:-------------------------------------:| | `tracer_b7` | 640 | (30, 5) | +| `isnet` | 1024 | (30, 5) | | `u2net` | 320 | (30, 5) | | `basnet` | 320 | (30, 5) | | `deeplabv3` | 1024 | (40, 20) | > ### Notes: > 1. The final quality may depend on the resolution of your image, the type of scene or object. -> 2. 
Use **U2-Net for hairs** and **Tracer-B7 for general images** and correct parameters. \ +> 2. Use **ISNet or U2-Net for hairs** and **Tracer-B7 for general images**, with the correct parameters for each. \ > It is very important for final quality! Example images was taken by using U2-Net and FBA post-processing. ## 🖼️ Image pre-processing and post-processing methods: ### 🔍 Preprocessing methods: @@ -216,7 +220,7 @@ Options: processed by refining network --seg_mask_size 640 The size of the input image for the - segmentation neural network. Use 640 for Tracer B7 and 320 for U2Net + segmentation neural network. Use 640 for Tracer B7 and 1024 for ISNet --matting_mask_size 2048 The size of the input image for the matting neural network. diff --git a/carvekit/api/autointerface.py b/carvekit/api/autointerface.py index 60ba0c7..3030832 100644 --- a/carvekit/api/autointerface.py +++ b/carvekit/api/autointerface.py @@ -13,6 +13,7 @@ from carvekit.ml.wrap.basnet import BASNET from carvekit.ml.wrap.cascadepsp import CascadePSP from carvekit.ml.wrap.deeplab_v3 import DeepLabV3 +from carvekit.ml.wrap.isnet import ISNet from carvekit.ml.wrap.fba_matting import FBAMatting from carvekit.ml.wrap.scene_classifier import SceneClassifier from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 @@ -73,6 +74,8 @@ def select_params_for_net(net: Union[TracerUniversalB7, U2NET, DeepLabV3]): return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5} elif net == U2NET: return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5} + elif net == ISNet: + return {"prob_threshold": 231, "kernel_size": 30, "erosion_iters": 5} elif net == DeepLabV3: return {"prob_threshold": 231, "kernel_size": 40, "erosion_iters": 20} elif net == BASNET: @@ -111,7 +114,7 @@ def select_net(self, scene: str, images_info: List[dict]): image_info["net"] = TracerUniversalB7 elif obj_counter["animals"] > 0: # Animals case - image_info["net"] = U2NET # animals should be always in soft scenes + 
image_info["net"] = ISNet # animals should be always in soft scenes else: # We have no idea what is in the image, so we will try to process it with universal model image_info["net"] = TracerUniversalB7 @@ -134,16 +137,16 @@ def select_net(self, scene: str, images_info: List[dict]): if obj_counter["human"] > 0 and len(non_empty_classes) == 1: # Human only case. It may be a portrait - image_info["net"] = U2NET + image_info["net"] = ISNet elif obj_counter["human"] > 0 and len(non_empty_classes) > 1: # Okay, we have a human with hairs and something else - image_info["net"] = U2NET + image_info["net"] = ISNet elif obj_counter["cars"] > 0: # Cars case. image_info["net"] = TracerUniversalB7 elif obj_counter["animals"] > 0: # Animals case - image_info["net"] = U2NET # animals should be always in soft scenes + image_info["net"] = ISNet # animals should be always in soft scenes else: # We have no idea what is in the image, so we will try to process it with universal model image_info["net"] = TracerUniversalB7 diff --git a/carvekit/api/high.py b/carvekit/api/high.py index 4d86470..5834429 100644 --- a/carvekit/api/high.py +++ b/carvekit/api/high.py @@ -11,7 +11,7 @@ from carvekit.ml.wrap.cascadepsp import CascadePSP from carvekit.ml.wrap.scene_classifier import SceneClassifier from carvekit.pipelines.preprocessing import AutoScene -from carvekit.ml.wrap.u2net import U2NET +from carvekit.ml.wrap.isnet import ISNet from carvekit.pipelines.postprocessing import CasMattingMethod from carvekit.trimap.generator import TrimapGenerator @@ -71,7 +71,7 @@ def __init__( fp16=fp16, ) elif object_type == "hairs-like": - self._segnet = U2NET( + self._segnet = ISNet( device=device, batch_size=batch_size_seg, input_image_size=seg_mask_size, diff --git a/carvekit/api/interface.py b/carvekit/api/interface.py index 4f776de..df53801 100644 --- a/carvekit/api/interface.py +++ b/carvekit/api/interface.py @@ -11,6 +11,7 @@ from carvekit.ml.wrap.basnet import BASNET from carvekit.ml.wrap.deeplab_v3 
import DeepLabV3 from carvekit.ml.wrap.u2net import U2NET +from carvekit.ml.wrap.isnet import ISNet from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 from carvekit.pipelines.preprocessing import PreprocessingStub, AutoScene from carvekit.pipelines.postprocessing import MattingMethod, CasMattingMethod @@ -22,7 +23,7 @@ class Interface: def __init__( self, - seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7]], + seg_pipe: Optional[Union[U2NET, BASNET, DeepLabV3, TracerUniversalB7, ISNet]], pre_pipe: Optional[Union[PreprocessingStub, AutoScene]] = None, post_pipe: Optional[Union[MattingMethod, CasMattingMethod]] = None, device="cpu", diff --git a/carvekit/ml/wrap/fba_matting.py b/carvekit/ml/wrap/fba_matting.py index faec265..53eb674 100644 --- a/carvekit/ml/wrap/fba_matting.py +++ b/carvekit/ml/wrap/fba_matting.py @@ -33,7 +33,7 @@ class FBAMatting(FBA): def __init__( self, device="cpu", - input_tensor_size: Union[List[int], int] = 1500, + input_tensor_size: Union[List[int], int] = 2048, #1500, batch_size: int = 2, encoder="resnet50_GN_WS", load_pretrained: bool = True, diff --git a/carvekit/pipelines/preprocessing/autoscene.py b/carvekit/pipelines/preprocessing/autoscene.py index 04138fb..a270543 100644 --- a/carvekit/pipelines/preprocessing/autoscene.py +++ b/carvekit/pipelines/preprocessing/autoscene.py @@ -10,7 +10,7 @@ from carvekit.ml.wrap.scene_classifier import SceneClassifier from carvekit.ml.wrap.tracer_b7 import TracerUniversalB7 -from carvekit.ml.wrap.u2net import U2NET +from carvekit.ml.wrap.isnet import ISNet __all__ = ["AutoScene"] @@ -36,7 +36,7 @@ def select_net(scene: str): if scene == "hard": return TracerUniversalB7 elif scene == "soft": - return U2NET + return ISNet elif scene == "digital": return TracerUniversalB7 # TODO: not implemented yet diff --git a/carvekit/web/schemas/config.py b/carvekit/web/schemas/config.py index 8b12c02..6eabedb 100644 --- a/carvekit/web/schemas/config.py +++ b/carvekit/web/schemas/config.py 
@@ -21,7 +21,7 @@ class MLConfig(BaseModel): """Config for ml part of framework""" segmentation_network: Literal[ - "u2net", "deeplabv3", "basnet", "tracer_b7" + "u2net", "deeplabv3", "basnet", "tracer_b7", "isnet" ] = "tracer_b7" """Segmentation Network""" preprocessing_method: Literal["none", "stub", "autoscene", "auto"] = "autoscene" diff --git a/carvekit/web/utils/init_utils.py b/carvekit/web/utils/init_utils.py index d975e27..0e14edb 100644 --- a/carvekit/web/utils/init_utils.py +++ b/carvekit/web/utils/init_utils.py @@ -5,6 +5,7 @@ from loguru import logger from carvekit.ml.wrap.cascadepsp import CascadePSP +from carvekit.ml.wrap.isnet import ISNet from carvekit.ml.wrap.scene_classifier import SceneClassifier from carvekit.web.schemas.config import WebAPIConfig, MLConfig, AuthConfig @@ -146,6 +147,13 @@ def init_interface(config: Union[WebAPIConfig, MLConfig]) -> Interface: input_image_size=config.seg_mask_size, fp16=config.fp16, ) + elif config.segmentation_network == "isnet": + seg_net = ISNet( + device=config.device, + batch_size=config.batch_size_seg, + input_image_size=config.seg_mask_size, + fp16=config.fp16, + ) elif config.segmentation_network == "deeplabv3": seg_net = DeepLabV3( device=config.device, diff --git a/docker-compose.cpu.yml b/docker-compose.cpu.yml index bcaa647..2b12659 100644 --- a/docker-compose.cpu.yml +++ b/docker-compose.cpu.yml @@ -6,7 +6,7 @@ services: environment: - CARVEKIT_PORT=5000 - CARVEKIT_HOST=0.0.0.0 - - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 + - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3, isnet - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub, autoscene, auto - CARVEKIT_POSTPROCESSING_METHOD=cascade_fba # can be none, fba, cascade_fba - CARVEKIT_DEVICE=cpu # can be cuda (req. 
cuda docker image), cpu diff --git a/docker-compose.cuda.yml b/docker-compose.cuda.yml index f90d9a2..7b55655 100644 --- a/docker-compose.cuda.yml +++ b/docker-compose.cuda.yml @@ -6,7 +6,7 @@ services: environment: - CARVEKIT_PORT=5000 - CARVEKIT_HOST=0.0.0.0 - - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3 + - CARVEKIT_SEGMENTATION_NETWORK=tracer_b7 # can be u2net, tracer_b7, basnet, deeplabv3, isnet - CARVEKIT_PREPROCESSING_METHOD=none # can be none, stub, autoscene, auto - CARVEKIT_POSTPROCESSING_METHOD=cascade_fba # can be none, fba, cascade_fba - CARVEKIT_DEVICE=cuda # can be cuda (req. cuda docker image), cpu diff --git a/docs/CREDITS.md b/docs/CREDITS.md index 337f9d0..6ef99e3 100644 --- a/docs/CREDITS.md +++ b/docs/CREDITS.md @@ -13,16 +13,17 @@ All images are copyrighted by their authors. ## References: 1. https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/ -2. https://github.com/NathanUA/U-2-Net -3. https://github.com/NathanUA/BASNet -4. https://github.com/MarcoForte/FBA_Matting -5. https://arxiv.org/abs/1706.05587 -6. https://arxiv.org/pdf/2005.09007.pdf -7. http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html -8. https://arxiv.org/abs/2003.07711 -9. https://arxiv.org/abs/1506.01497 -10. https://arxiv.org/abs/1703.06870 -11. https://github.com/Karel911/TRACER -12. https://arxiv.org/abs/2112.07380 -13. https://github.com/hkchengrex/CascadePSP +2. https://github.com/xuebinqin/DIS +3. https://github.com/NathanUA/U-2-Net +4. https://github.com/NathanUA/BASNet +5. https://github.com/MarcoForte/FBA_Matting +6. https://arxiv.org/abs/1706.05587 +7. https://arxiv.org/pdf/2005.09007.pdf +8. http://openaccess.thecvf.com/content_CVPR_2019/html/Qin_BASNet_Boundary-Aware_Salient_Object_Detection_CVPR_2019_paper.html +9. https://arxiv.org/abs/2003.07711 +10. https://arxiv.org/abs/1506.01497 +11. https://arxiv.org/abs/1703.06870 +12. 
https://github.com/Karel911/TRACER +13. https://arxiv.org/abs/2112.07380 +14. https://github.com/hkchengrex/CascadePSP diff --git a/docs/other/carvekit_try.ipynb b/docs/other/carvekit_try.ipynb index 35989bd..4c64570 100644 --- a/docs/other/carvekit_try.ipynb +++ b/docs/other/carvekit_try.ipynb @@ -115,9 +115,9 @@ "#@markdown Description of parameters\n", "#@markdown - `SHOW_FULLSIZE` - Shows image in full size (may take a long time to load)\n", "#@markdown - `PREPROCESSING_METHOD` - Preprocessing method. `AutoScene` will automatically select needed model depends on your image. If you don't want, disable it.\n", - "#@markdown - `SEGMENTATION_NETWORK` - Segmentation network. Use `u2net` for hairs-like objects and `tracer_b7` for objects\n", + "#@markdown - `SEGMENTATION_NETWORK` - Segmentation network. Use `isnet` for hairs-like objects and `tracer_b7` for objects\n", "#@markdown - `POSTPROCESSING_METHOD` - Postprocessing method\n", - "#@markdown - `SEGMENTATION_MASK_SIZE` - Segmentation mask size. Use 640 for Tracer B7 and 320 for U2Net\n", + "#@markdown - `SEGMENTATION_MASK_SIZE` - Segmentation mask size. 
Use 640 for Tracer B7 and 1024 for ISNet\n", "#@markdown - `TRIMAP_DILATION` - The size of the offset radius from the object mask in pixels when forming an unknown area\n", "#@markdown - `TRIMAP_EROSION` - The number of iterations of erosion that the object's mask will be subjected to before forming an unknown area\n", "#@markdown > Look README.md and code for more details on networks and methods\n", @@ -131,7 +131,7 @@ "\n", "SHOW_FULLSIZE = False #@param {type:\"boolean\"}\n", "PREPROCESSING_METHOD = \"autoscene\" #@param [\"autoscene\", \"auto\", \"none\"]\n", - "SEGMENTATION_NETWORK = \"tracer_b7\" #@param [\"u2net\", \"deeplabv3\", \"basnet\", \"tracer_b7\"]\n", + "SEGMENTATION_NETWORK = \"tracer_b7\" #@param [\"u2net\", \"deeplabv3\", \"basnet\", \"tracer_b7\", \"isnet\"]\n", "POSTPROCESSING_METHOD = \"cascade_fba\" #@param [\"fba\", \"cascade_fba\", \"none\"]\n", "SEGMENTATION_MASK_SIZE = 640 #@param [\"640\", \"320\"] {type:\"raw\", allow-input: true}\n", "TRIMAP_DILATION = 30 #@param {type:\"integer\"}\n",