mirror of
https://github.com/ultralytics/ultralytics.git
synced 2025-09-15 15:48:41 +08:00
Merge 9708903567 into 1aa3688613
This commit is contained in:
commit
578337b959
@ -98,12 +98,12 @@ def export_engine(
|
||||
# Engine builder
|
||||
builder = trt.Builder(logger)
|
||||
config = builder.create_builder_config()
|
||||
workspace = int((workspace or 0) * (1 << 30))
|
||||
workspace_bytes = int((workspace or 0) * (1 << 30))
|
||||
is_trt10 = int(trt.__version__.split(".", 1)[0]) >= 10 # is TensorRT >= 10
|
||||
if is_trt10 and workspace > 0:
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
|
||||
elif workspace > 0: # TensorRT versions 7, 8
|
||||
config.max_workspace_size = workspace
|
||||
config.max_workspace_size = workspace_bytes
|
||||
flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
network = builder.create_network(flag)
|
||||
half = builder.platform_has_fast_fp16 and half
|
||||
|
||||
Loading…
Reference in New Issue
Block a user