feat: beast mode (#160)

This commit is contained in:
Ze-Yi LIN 2024-09-22 22:17:57 +08:00 committed by GitHub
parent ecf028c1a3
commit 4c2ac524ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 33 additions and 3 deletions

View File

@ -55,6 +55,7 @@
- 在线体验: [![SwanHub Demo](https://img.shields.io/static/v1?label=Demo&message=SwanHub%20Demo&color=blue)](https://swanhub.co/ZeYiLin/HivisionIDPhotos/demo)、[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/TheEeeeLin/HivisionIDPhotos)、[![][modelscope-shield]][modelscope-link]
- 2024.09.22: Gradio Demo增加**野兽模式**,可设置内存加载策略
- 2024.09.18: Gradio Demo增加**分享模版照**功能、增加**美式证件照**背景选项
- 2024.09.17: Gradio Demo增加**自定义底色-HEX输入**功能 | **社区贡献C++版本** - [HivisionIDPhotos-cpp](https://github.com/zjkhahah/HivisionIDPhotos-cpp) 贡献 by [zjkhahah](https://github.com/zjkhahah)
- 2024.09.16: Gradio Demo增加**人脸旋转对齐**功能,自定义尺寸输入支持**毫米**单位
@ -62,7 +63,6 @@
- 2024.09.12: Gradio Demo增加**美白**功能 | API接口增加**加水印**、**设置照片KB值大小**、**证件照裁切**
- 2024.09.11: Gradio Demo增加**透明图显示与下载**功能
- 2024.09.10: 增加新的**人脸检测模型** Retinaface-resnet50以稍弱于mtcnn的速度换取更高的检测精度推荐使用
- 2024.09.09: 增加新的**抠图模型** [BiRefNet-v1-lite](https://github.com/ZhengPeng7/BiRefNet) | Gradio增加**高级参数设置**和**水印**选项卡
<br>
@ -319,13 +319,15 @@ docker compose up -d
|--|--|--|--|
| FACE_PLUS_API_KEY | 可选 | 这是你在 Face++ 控制台申请的 API 密钥 | `7-fZStDJ····` |
| FACE_PLUS_API_SECRET | 可选 | Face++ API密钥对应的Secret | `VTee824E····` |
| RUN_MODE | 可选 | 运行模式,可选值为`beast`(野兽模式)。野兽模式下人脸检测和抠图模型将不释放内存从而获得更快的二次推理速度。建议内存16GB以上尝试。 | `beast` |
docker使用环境变量示例
```bash
docker run -d -p 7860:7860 \
-e FACE_PLUS_API_KEY=7-fZStDJ···· \
-e FACE_PLUS_API_SECRET=VTee824E···· \
linzeyi/hivision_idphotos
-e RUN_MODE=beast \
linzeyi/hivision_idphotos
```
<br>

5
app.py
View File

@ -65,6 +65,11 @@ if __name__ == "__main__":
FACE_DETECT_MODELS_CHOICE,
LANGUAGE,
)
# 如果RUN_MODE是Beast打印已开启野兽模式
if os.getenv("RUN_MODE") == "beast":
print("[Beast mode activated.] 已开启野兽模式。")
demo.launch(
server_name=args.host,
server_port=args.port,

View File

@ -213,3 +213,7 @@ def detect_face_retinaface(ctx: Context):
dx = right_eye[0] - left_eye[0]
roll_angle = np.degrees(np.arctan2(dy, dx))
ctx.face["roll_angle"] = roll_angle
# 如果RUN_MODE不是野兽模式则释放模型
if os.getenv("RUN_MODE") == "beast":
RETINAFCE_SESS = None

View File

@ -201,6 +201,7 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
print(f"Checkpoint file not found: {checkpoint_path}")
return None
# 如果RUN_MODE不是野兽模式则不加载模型
if HIVISION_MODNET_SESS is None:
HIVISION_MODNET_SESS = load_onnx_model(checkpoint_path, set_cpu=True)
@ -216,6 +217,10 @@ def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
b, g, r = cv2.split(np.uint8(input_image))
output_image = cv2.merge((b, g, r, mask))
# 如果RUN_MODE不是野兽模式则释放模型
if os.getenv("RUN_MODE") != "beast":
HIVISION_MODNET_SESS = None
return output_image
@ -229,6 +234,7 @@ def get_modnet_matting_photographic_portrait_matting(
print(f"Checkpoint file not found: {checkpoint_path}")
return None
# 如果RUN_MODE不是野兽模式则不加载模型
if MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS is None:
MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = load_onnx_model(
checkpoint_path, set_cpu=True
@ -248,6 +254,10 @@ def get_modnet_matting_photographic_portrait_matting(
b, g, r = cv2.split(np.uint8(input_image))
output_image = cv2.merge((b, g, r, mask))
# 如果RUN_MODE不是野兽模式则释放模型
if os.getenv("RUN_MODE") != "beast":
MODNET_PHOTOGRAPHIC_PORTRAIT_MATTING_SESS = None
return output_image
@ -297,6 +307,10 @@ def get_rmbg_matting(input_image: np.ndarray, checkpoint_path, ref_size=1024):
# Paste the mask on the original image
new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
new_im.paste(orig_image, mask=pil_im)
# 如果RUN_MODE不是野兽模式则释放模型
if os.getenv("RUN_MODE") != "beast":
RMBG_SESS = None
return np.array(new_im)
@ -362,8 +376,9 @@ def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
# 记录加载onnx模型的开始时间
load_start_time = time()
# 如果RUN_MODE不是野兽模式则不加载模型
if BIREFNET_V1_LITE_SESS is None:
print("首次加载birefnet-v1-lite模型...")
# print("首次加载birefnet-v1-lite模型...")
if ONNX_DEVICE == "GPU":
print("onnxruntime-gpu已安装尝试使用CUDA加载模型")
try:
@ -405,5 +420,9 @@ def get_birefnet_portrait_matting(input_image, checkpoint_path, ref_size=512):
# Paste the mask on the original image
new_im = Image.new("RGBA", orig_image.size, (0, 0, 0, 0))
new_im.paste(orig_image, mask=pil_im)
# 如果RUN_MODE不是野兽模式则释放模型
if os.getenv("RUN_MODE") != "beast":
BIREFNET_V1_LITE_SESS = None
return np.array(new_im)