|
@@ -1,6 +1,7 @@
|
|
|
import logging
|
|
import logging
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import os
|
|
import os
|
|
|
|
|
+import typing as T
|
|
|
import warnings
|
|
import warnings
|
|
|
import yaml
|
|
import yaml
|
|
|
|
|
|
|
@@ -72,17 +73,33 @@ def populate_args(args, ignore = None, replace = {}, fc_params = []):
|
|
|
setattr(args, key, value)
|
|
setattr(args, key, value)
|
|
|
|
|
|
|
|
# get the correct number of classes
|
|
# get the correct number of classes
|
|
|
- args.n_classes = 1000
|
|
|
|
|
- weights = np.load(args.load)
|
|
|
|
|
- n_classes_found = False
|
|
|
|
|
- for key in fc_params:
|
|
|
|
|
- try:
|
|
|
|
|
- args.n_classes = weights[key].shape[0]
|
|
|
|
|
- n_classes_found = True
|
|
|
|
|
- break
|
|
|
|
|
- except KeyError as e:
|
|
|
|
|
- pass
|
|
|
|
|
-
|
|
|
|
|
- if not n_classes_found:
|
|
|
|
|
|
|
+ args.n_classes = get_n_classes(args.load, args.load_path,
|
|
|
|
|
+ fc_params=fc_params)
|
|
|
|
|
+
|
|
|
|
|
+ if args.n_classes <= 0:
|
|
|
raise KeyError("Could not find number of classes!")
|
|
raise KeyError("Could not find number of classes!")
|
|
|
|
|
|
|
|
|
|
+def get_n_classes(weights_file: str, load_path: str = "",
|
|
|
|
|
+ *, fc_params: T.Tuple[str] = ()) -> int:
|
|
|
|
|
+ """
|
|
|
|
|
+ Searches for the classification layer identified by
|
|
|
|
|
+ names in fc_params.
|
|
|
|
|
+ """
|
|
|
|
|
+
|
|
|
|
|
+ weights = np.load(weights_file)
|
|
|
|
|
+ prefixes = [""]
|
|
|
|
|
+
|
|
|
|
|
+ if load_path != "":
|
|
|
|
|
+ prefixes.append(load_path)
|
|
|
|
|
+
|
|
|
|
|
+ for prefix in prefixes:
|
|
|
|
|
+ for name in fc_params:
|
|
|
|
|
+ key = f"{prefix}{name}"
|
|
|
|
|
+ if key not in weights:
|
|
|
|
|
+ continue
|
|
|
|
|
+ return weights[key].shape[0]
|
|
|
|
|
+
|
|
|
|
|
+ return -1
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+
|