feat: matting-model-selection (#62)

* feat: app support more matting model

* update README
This commit is contained in:
Ze-Yi LIN 2024-09-06 17:02:00 +08:00 committed by GitHub
parent 48923c4908
commit aceeb454a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 83 additions and 40 deletions

View File

@ -30,6 +30,7 @@
- 在线体验: [![SwanHub Demo](https://img.shields.io/static/v1?label=Demo&message=SwanHub%20Demo&color=blue)](https://swanhub.co/ZeYiLin/HivisionIDPhotos/demo)、[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/TheEeeeLin/HivisionIDPhotos)
- 2024.9.6: 增加新的抠图模型 [modnet_photographic_portrait_matting.onnx](https://github.com/ZHKKKe/MODNet)
- 2024.9.5: 更新 [Restful API 文档](docs/api_CN.md)
- 2024.9.2: 更新**调整照片 KB 大小**[DockerHub](https://hub.docker.com/r/linzeyi/hivision_idphotos/tags)
- 2023.12.1: 更新**API 部署(基于 fastapi**
@ -88,6 +89,10 @@ pip install -r requirements-app.txt
在我们的[Release](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/tag/pretrained-model)下载权重文件`hivision_modnet.onnx` (24.7MB),存到项目的`hivision/creator/weights`目录下。
拓展抠图模型权重(均放到`hivision/creator/weights`目录下):
- modnet_photographic_portrait_matting.onnx: [MODNet](https://github.com/ZHKKKe/MODNet)官方权重,[下载](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR)
<br>
# 🚀 运行 Gradio Demo

View File

@ -21,6 +21,7 @@ English / [中文](README.md) / [日本語](README_JP.md) / [한국어](README_K
- Online Demo: [![SwanHub Demo](https://img.shields.io/static/v1?label=Demo&message=SwanHub%20Demo&color=blue)](https://swanhub.co/ZeYiLin/HivisionIDPhotos/demo)、[![Spaces](https://img.shields.io/badge/🤗-Open%20in%20Spaces-blue)](https://huggingface.co/spaces/TheEeeeLin/HivisionIDPhotos)
- 2024.9.6: Add a new image matting model [modnet_photographic_portrait_matting.onnx](https://github.com/ZHKKKe/MODNet)
- 2024.9.2: Update **Adjusted photo KB size**[DockerHub](https://hub.docker.com/r/linzeyi/hivision_idphotos/tags)
- 2023.12.1: Update **API deployment (based on fastapi)**
- 2023.6.20: Update **Preset size menu**
@ -80,6 +81,10 @@ pip install -r requirements-app.txt
Download the weight file `hivision_modnet.onnx` from our [Release](https://github.com/Zeyi-Lin/HivisionIDPhotos/releases/tag/pretrained-model) and save it to the `hivision/creator/weights` directory.
Expand matting model weights (all in the `hivision/creator/weights` directory) :
- modnet_photographic_portrait_matting.onnx: by [MODNet](https://github.com/ZHKKKe/MODNet)[Download](https://drive.google.com/drive/folders/1umYmlCulvIFNaqPjwod1SayFmSRHziyR)
<br>
# 🚀 Gradio Demo

63
app.py
View File

@ -7,6 +7,10 @@ from hivision.creator.layout_calculator import (
generate_layout_photo,
generate_layout_image,
)
from hivision.creator.human_matting import (
extract_human_modnet_photographic_portrait_matting,
extract_human,
)
import pathlib
import numpy as np
from demo.utils import csv_to_size_list
@ -54,6 +58,7 @@ def idphoto_inference(
custom_size_width,
custom_image_kb,
language,
matting_model_option,
head_measure_ratio=0.2,
head_height_ratio=0.45,
top_distance_max=0.12,
@ -146,6 +151,11 @@ def idphoto_inference(
idphoto_json["custom_image_kb"] = None
creator = IDCreator()
if matting_model_option == "modnet_photographic_portrait_matting":
creator.matting_handler = extract_human_modnet_photographic_portrait_matting
else:
creator.matting_handler = extract_human
change_bg_only = idphoto_json["size_mode"] in ["只换底", "Only Change Background"]
# 生成证件照
try:
@ -267,7 +277,24 @@ def idphoto_inference(
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument(
"--port", type=int, default=7860, help="The port number of the server"
)
argparser.add_argument(
"--host", type=str, default="127.0.0.1", help="The host of the server"
)
args = argparser.parse_args()
language = ["中文", "English"]
matting_model_list = [
os.path.splitext(file)[0]
for file in os.listdir(os.path.join(root_dir, "hivision/creator/weights"))
if file.endswith(".onnx")
]
size_mode_CN = ["尺寸列表", "只换底", "自定义尺寸"]
size_mode_EN = ["Size List", "Only Change Background", "Custom Size"]
@ -283,14 +310,6 @@ if __name__ == "__main__":
image_kb_CN = ["不设置", "自定义"]
image_kb_EN = ["Not Set", "Custom"]
# title = "<h1 id='title'>HivisionIDPhotos</h1>"
# description = "<h3>😎9.2 Update: Add photo size KB adjustment</h3>"
# css = """
# h1#title, h3 {
# text-align: center;
# }
# """
css = """
#col-left {
margin: 0 auto;
@ -326,9 +345,20 @@ if __name__ == "__main__":
# ------------ 左半边 UI ----------------
with gr.Column():
img_input = gr.Image(height=400)
language_options = gr.Dropdown(
choices=language, label="Language", value="中文", elem_id="language"
)
with gr.Row():
language_options = gr.Dropdown(
choices=language,
label="Language",
value="中文",
elem_id="language",
)
matting_model_options = gr.Dropdown(
choices=matting_model_list,
label="Matting Model",
value="hivision_modnet",
elem_id="matting_model",
)
mode_options = gr.Radio(
choices=size_mode_CN,
@ -453,6 +483,7 @@ if __name__ == "__main__":
img_output_layout: gr.update(label="六寸排版照"),
file_download: gr.update(label="下载调整 KB 大小后的照片"),
}
elif language == "English":
return {
size_list_options: gr.update(
@ -576,6 +607,7 @@ if __name__ == "__main__":
custom_size_wdith,
custom_image_kb_size,
language_options,
matting_model_options,
],
outputs=[
img_output_standard,
@ -586,13 +618,4 @@ if __name__ == "__main__":
],
)
argparser = argparse.ArgumentParser()
argparser.add_argument(
"--port", type=int, default=7860, help="The port number of the server"
)
argparser.add_argument(
"--host", type=str, default="127.0.0.1", help="The host of the server"
)
args = argparser.parse_args()
demo.launch(server_name=args.host, server_port=args.port)

View File

@ -146,12 +146,5 @@ async def generate_layout_photos(
if __name__ == "__main__":
import uvicorn
# 加载权重文件
root_dir = os.path.dirname(os.path.abspath(__file__))
HY_HUMAN_MATTING_WEIGHTS_PATH = os.path.join(
root_dir, "hivision/creator/weights/hivision_modnet.onnx"
)
sess = onnxruntime.InferenceSession(HY_HUMAN_MATTING_WEIGHTS_PATH)
# 在8080端口运行推理服务
uvicorn.run(app, host="0.0.0.0", port=8080)

View File

@ -15,7 +15,17 @@ from .context import Context
import cv2
import os
weight_path = os.path.join(os.path.dirname(__file__), "weights", "hivision_modnet.onnx")
WEIGHTS = {
"hivision_modnet": os.path.join(
os.path.dirname(__file__), "weights", "hivision_modnet.onnx"
),
"modnet_photographic_portrait_matting": os.path.join(
os.path.dirname(__file__),
"weights",
"modnet_photographic_portrait_matting.onnx",
),
}
def extract_human(ctx: Context):
@ -24,7 +34,21 @@ def extract_human(ctx: Context):
:param ctx: 上下文
"""
# 抠图
matting_image = get_modnet_matting(ctx.processing_image, weight_path)
matting_image = get_modnet_matting(ctx.processing_image, WEIGHTS["hivision_modnet"])
# 修复抠图
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()
def extract_human_modnet_photographic_portrait_matting(ctx: Context):
"""
人像抠图
:param ctx: 上下文
"""
# 抠图
matting_image = get_modnet_matting(
ctx.processing_image, WEIGHTS["modnet_photographic_portrait_matting"]
)
# 修复抠图
ctx.processing_image = hollow_out_fix(matting_image)
ctx.matting_image = ctx.processing_image.copy()
@ -92,13 +116,13 @@ def read_modnet_image(input_image, ref_size=512):
return im, width, length
sess = None
# sess = None
def get_modnet_matting(input_image, checkpoint_path, ref_size=512):
global sess
if sess is None:
sess = onnxruntime.InferenceSession(checkpoint_path)
# global sess
# if sess is None:
sess = onnxruntime.InferenceSession(checkpoint_path)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name

View File

@ -34,13 +34,6 @@ args = parser.parse_args()
root_dir = os.path.dirname(os.path.abspath(__file__))
# 预加载 ONNX 模型
print("正在加载抠图模型...")
# HY_HUMAN_MATTING_WEIGHTS_PATH = os.path.join(
# root_dir, "hivision/creator/weights/hivision_modnet.onnx"
# )
# sess = onnxruntime.InferenceSession(HY_HUMAN_MATTING_WEIGHTS_PATH)
input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED)