Mirror of https://github.com/ultralytics/ultralytics.git
Synced 2025-09-15 15:48:41 +08:00

ultralytics 8.3.198 Improve Tuner with BLX-α gene crossover (#22038)

Signed-off-by: Glenn Jocher <glenn.jocher@ultralytics.com>
Co-authored-by: UltralyticsAssistant <web@ultralytics.com>

parent fd4acf7bd0
commit 9e372fc12a
@@ -136,6 +136,9 @@ Fruchtzwerg94@users.noreply.github.com:
 aaurelions@gmail.com:
   avatar: null
   username: null
+abi@ultralytics.com:
+  avatar: https://avatars.githubusercontent.com/u/224584378?v=4
+  username: UltralyticsAbi
 abirami.vina@gmail.com:
   avatar: https://avatars.githubusercontent.com/u/25847604?v=4
   username: abirami-vina
@@ -1,6 +1,6 @@
 # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
 
-__version__ = "8.3.197"
+__version__ = "8.3.198"
 
 import os
@@ -16,6 +16,7 @@ Examples:
 
 from __future__ import annotations
 
+import gc
 import random
 import shutil
 import subprocess
@@ -23,6 +24,7 @@ import time
 from datetime import datetime
 
 import numpy as np
+import torch
 
 from ultralytics.cfg import get_cfg, get_save_dir
 from ultralytics.utils import DEFAULT_CFG, LOGGER, YAML, callbacks, colorstr, remove_colorstr
@@ -97,7 +99,7 @@ class Tuner:
             "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
             "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
             "box": (1.0, 20.0),  # box loss gain
-            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
+            "cls": (0.1, 4.0),  # cls loss gain (scale with pixels)
             "dfl": (0.4, 6.0),  # dfl loss gain
             "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
             "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
@@ -114,6 +116,7 @@ class Tuner:
             "mixup": (0.0, 1.0),  # image mixup (probability)
             "cutmix": (0.0, 1.0),  # image cutmix (probability)
             "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
+            "close_mosaic": (0.0, 10.0),  # close dataloader mosaic (epochs)
         }
         mongodb_uri = args.pop("mongodb_uri", None)
         mongodb_db = args.pop("mongodb_db", "ultralytics")
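The `space` entries above follow a (lower, upper) tuple convention, optionally extended with a third element that acts as a per-gene mutation gain in `_mutate`. A minimal sketch of how such tuples can be consumed downstream; the `lr0` bounds and the `clamp_to_space` helper are illustrative, not taken from this diff:

    import numpy as np

    space = {
        "cls": (0.1, 4.0),            # (lower, upper)
        "close_mosaic": (0.0, 10.0),  # (lower, upper)
        "lr0": (1e-5, 1e-1, 0.5),     # illustrative: third element = mutation gain
    }

    def clamp_to_space(hyp: dict, space: dict) -> dict:
        """Clamp each value into its bounds and round to 5 digits, as _mutate does."""
        return {k: round(min(max(hyp[k], b[0]), b[1]), 5) for k, b in space.items()}

    gains = np.array([b[2] if len(b) == 3 else 1.0 for b in space.values()])  # [1.0, 1.0, 0.5]
    print(clamp_to_space({"cls": 5.2, "close_mosaic": -1.0, "lr0": 0.02}, space))
    # {'cls': 4.0, 'close_mosaic': 0.0, 'lr0': 0.02}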
@@ -266,19 +269,31 @@ class Tuner:
         except Exception as e:
             LOGGER.warning(f"{self.prefix}MongoDB to CSV sync failed: {e}")
 
+    def _crossover(self, x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
+        """BLX-α crossover from up to top-k parents (x[:,0]=fitness, rest=genes)."""
+        k = min(k, len(x))
+        # fitness weights (shifted to >0); fallback to uniform if degenerate
+        weights = x[:, 0] - x[:, 0].min() + 1e-6
+        if not np.isfinite(weights).all() or weights.sum() == 0:
+            weights = np.ones_like(weights)
+        idxs = random.choices(range(len(x)), weights=weights, k=k)
+        parents_mat = np.stack([x[i][1:] for i in idxs], 0)  # (k, ng) strip fitness
+        lo, hi = parents_mat.min(0), parents_mat.max(0)
+        span = hi - lo
+        return np.random.uniform(lo - alpha * span, hi + alpha * span)
+
     def _mutate(
         self,
         parent: str = "single",
-        n: int = 5,
-        mutation: float = 0.8,
+        n: int = 9,
+        mutation: float = 0.5,
         sigma: float = 0.2,
     ) -> dict[str, float]:
         """
         Mutate hyperparameters based on bounds and scaling factors specified in `self.space`.
 
         Args:
-            parent (str): Parent selection method: 'single' or 'weighted'.
-            n (int): Number of parents to consider.
+            parent (str): Parent selection method (kept for API compatibility, unused in BLX mode).
+            n (int): Number of top parents to consider.
             mutation (float): Probability of a parameter mutation in any given iteration.
             sigma (float): Standard deviation for Gaussian random number generator.
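For readers new to BLX-α (blend crossover): each child gene is sampled uniformly from the parents' per-gene interval [lo, hi] widened by α·(hi − lo) on both sides, so offspring can explore slightly beyond the parent hull. A self-contained sketch of the recipe on a toy population, following the method's contract that column 0 is fitness and the remaining columns are genes:

    import random

    import numpy as np

    def blx_alpha(x: np.ndarray, alpha: float = 0.2, k: int = 9) -> np.ndarray:
        """BLX-α crossover over up to k fitness-weighted parents."""
        k = min(k, len(x))
        weights = x[:, 0] - x[:, 0].min() + 1e-6  # shift fitness to > 0
        if not np.isfinite(weights).all() or weights.sum() == 0:
            weights = np.ones_like(weights)  # uniform fallback
        idxs = random.choices(range(len(x)), weights=weights, k=k)
        parents = np.stack([x[i][1:] for i in idxs])  # (k, ng), fitness stripped
        lo, hi = parents.min(0), parents.max(0)
        span = hi - lo
        return np.random.uniform(lo - alpha * span, hi + alpha * span)

    pop = np.array([[0.70, 0.01, 0.9], [0.65, 0.02, 0.8], [0.50, 0.05, 0.7]])  # [fitness, g1, g2]
    child = blx_alpha(pop, k=3)  # each gene in [lo - 0.2*span, hi + 0.2*span]

Note the parent sampling is with replacement, so high-fitness rows can appear multiple times, which biases the blend interval toward the better region of the search space.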
@@ -293,41 +308,40 @@ class Tuner:
         if results:
             # MongoDB already sorted by fitness DESC, so results[0] is best
             x = np.array([[r["fitness"]] + [r["hyperparameters"][k] for k in self.space.keys()] for r in results])
+            n = min(n, len(x))
 
         # Fall back to CSV if MongoDB unavailable or empty
         if x is None and self.tune_csv.exists():
             csv_data = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
             if len(csv_data) > 0:
                 fitness = csv_data[:, 0]  # first column
                 n = min(n, len(csv_data))
-                x = csv_data[np.argsort(-fitness)][:n]  # top n sorted by fitness DESC
+                order = np.argsort(-fitness)
+                x = csv_data[order][:n]  # top-n sorted by fitness DESC
 
         # Mutate if we have data, otherwise use defaults
         if x is not None:
-            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
-            if parent == "single" or len(x) <= 1:
-                x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
-            elif parent == "weighted":
-                x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination
-
-            # Mutate
-            r = np.random
-            r.seed(int(time.time()))
-            g = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()])  # gains 0-1
+            np.random.seed(int(time.time()))
             ng = len(self.space)
-            v = np.ones(ng)
-            while all(v == 1):  # mutate until a change occurs (prevent duplicates)
-                v = (g * (r.random(ng) < mutation) * r.randn(ng) * r.random() * sigma + 1).clip(0.3, 3.0)
-            hyp = {k: float(x[i + 1] * v[i]) for i, k in enumerate(self.space.keys())}
+
+            # Crossover
+            genes = self._crossover(x)
+
+            # Mutation
+            gains = np.array([v[2] if len(v) == 3 else 1.0 for v in self.space.values()])  # gains 0-1
+            factors = np.ones(ng)
+            while np.all(factors == 1):  # mutate until a change occurs (prevent duplicates)
+                mask = np.random.random(ng) < mutation
+                step = np.random.randn(ng) * (sigma * gains)
+                factors = np.where(mask, np.exp(step), 1.0).clip(0.25, 4.0)
+            hyp = {k: float(genes[i] * factors[i]) for i, k in enumerate(self.space.keys())}
         else:
             hyp = {k: getattr(self.args, k) for k in self.space.keys()}
 
         # Constrain to limits
         for k, bounds in self.space.items():
-            hyp[k] = max(hyp[k], bounds[0])  # lower limit
-            hyp[k] = min(hyp[k], bounds[1])  # upper limit
-            hyp[k] = round(hyp[k], 5)  # significant digits
+            hyp[k] = round(min(max(hyp[k], bounds[0]), bounds[1]), 5)
+
+        # Update types
+        hyp["close_mosaic"] = int(round(hyp["close_mosaic"]))
 
         return hyp
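The mutation rewrite is worth unpacking: the old scheme scaled each selected gene by 1 + ε with ε ~ N(0, σ·gain) clipped to [0.3, 3.0], which is asymmetric around 1 in log-space; the new scheme multiplies by exp(ε) clipped to [0.25, 4.0], so halving and doubling a gene are equally likely before clipping. A minimal sketch of the factor generation in isolation (the function name is illustrative):

    import numpy as np

    def lognormal_factors(ng: int, mutation: float = 0.5, sigma: float = 0.2,
                          gains: np.ndarray | None = None) -> np.ndarray:
        """Per-gene multiplicative factors; 1.0 means the gene is left unchanged."""
        gains = np.ones(ng) if gains is None else gains
        factors = np.ones(ng)
        while np.all(factors == 1):  # retry until at least one gene mutates
            mask = np.random.random(ng) < mutation  # ~`mutation` fraction of genes
            step = np.random.randn(ng) * (sigma * gains)  # N(0, sigma*gain) per gene
            factors = np.where(mask, np.exp(step), 1.0).clip(0.25, 4.0)
        return factors

    genes = np.array([0.01, 0.9, 7.5])
    mutated = genes * lognormal_factors(len(genes))  # E[log factor] = 0 for mutated genes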
@@ -361,8 +375,12 @@ class Tuner:
             start = x.shape[0]
             LOGGER.info(f"{self.prefix}Resuming tuning run {self.tune_dir} from iteration {start + 1}...")
         for i in range(start, iterations):
+            # Linearly decay sigma from 0.2 → 0.1 over first 300 iterations
+            frac = min(i / 300.0, 1.0)
+            sigma_i = 0.2 - 0.1 * frac
+
             # Mutate hyperparameters
-            mutated_hyp = self._mutate()
+            mutated_hyp = self._mutate(sigma=sigma_i)
             LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")
 
             metrics = {}
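The sigma annealing added above, shown in isolation: σ falls linearly from 0.2 to 0.1 over the first 300 iterations and then holds, so early iterations explore broadly while later ones refine. A small sketch (parameter names are illustrative):

    def decayed_sigma(i: int, sigma0: float = 0.2, sigma_min: float = 0.1, decay_iters: int = 300) -> float:
        """Linear annealing: sigma0 at i=0, sigma_min from i=decay_iters onward."""
        frac = min(i / decay_iters, 1.0)
        return sigma0 - (sigma0 - sigma_min) * frac

    for i, expected in [(0, 0.2), (150, 0.15), (300, 0.1), (450, 0.1)]:
        assert abs(decayed_sigma(i) - expected) < 1e-12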
@@ -378,6 +396,11 @@ class Tuner:
                 metrics = torch_load(ckpt_file)["train_metrics"]
                 assert return_code == 0, "training failed"
 
+                # Cleanup
+                time.sleep(1)
+                gc.collect()
+                torch.cuda.empty_cache()
+
             except Exception as e:
                 LOGGER.error(f"training failure for hyperparameter tuning iteration {i + 1}\n{e}")
@@ -403,14 +426,14 @@ class Tuner:
             x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
             fitness = x[:, 0]  # first column
             best_idx = fitness.argmax()
-            best_is_current = best_idx == i
+            best_is_current = best_idx == (i - start)
             if best_is_current:
-                best_save_dir = save_dir
+                best_save_dir = str(save_dir)
                 best_metrics = {k: round(v, 5) for k, v in metrics.items()}
                 for ckpt in weights_dir.glob("*.pt"):
                     shutil.copy2(ckpt, self.tune_dir / "weights")
             elif cleanup:
-                shutil.rmtree(weights_dir, ignore_errors=True)  # remove iteration weights/ dir to reduce storage space
+                shutil.rmtree(best_save_dir, ignore_errors=True)  # remove iteration dirs to reduce storage space
 
         # Plot tune results
         plot_tune_results(str(self.tune_csv))
@@ -421,8 +444,7 @@ class Tuner:
                 f"{self.prefix}Results saved to {colorstr('bold', self.tune_dir)}\n"
                 f"{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n"
                 f"{self.prefix}Best fitness metrics are {best_metrics}\n"
-                f"{self.prefix}Best fitness model is {best_save_dir}\n"
-                f"{self.prefix}Best fitness hyperparameters are printed below.\n"
+                f"{self.prefix}Best fitness model is {best_save_dir}"
             )
             LOGGER.info("\n" + header)
             data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
@@ -980,7 +980,7 @@ class Metric(SimpleClass):
 
     def fitness(self) -> float:
         """Return model fitness as a weighted combination of metrics."""
-        w = [0.0, 0.0, 0.1, 0.9]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
+        w = [0.0, 0.0, 0.0, 1.0]  # weights for [P, R, mAP@0.5, mAP@0.5:0.95]
         return (np.nan_to_num(np.array(self.mean_results())) * w).sum()
 
     def update(self, results: tuple):
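The weight change means tuner fitness is now mAP@0.5:0.95 alone instead of the previous 0.1·mAP@0.5 + 0.9·mAP@0.5:0.95 blend. A quick standalone check of what the weighted sum computes (the metric values are made up):

    import numpy as np

    def fitness(mean_results: list[float], w=(0.0, 0.0, 0.0, 1.0)) -> float:
        """Weighted combination of [P, R, mAP@0.5, mAP@0.5:0.95]."""
        return float((np.nan_to_num(np.array(mean_results)) * np.array(w)).sum())

    metrics = [0.81, 0.74, 0.79, 0.55]  # illustrative [P, R, mAP50, mAP50-95]
    assert fitness(metrics) == 0.55  # new weights: pure mAP50-95
    old = fitness(metrics, w=(0.0, 0.0, 0.1, 0.9))  # previous blend
    assert abs(old - (0.1 * 0.79 + 0.9 * 0.55)) < 1e-12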