Explorar o código

rewritten the utils function for extraction of the number of classes

Dimitri Korsch %!s(int64=3) %!d(string=hai) anos
pai
achega
cd3cb473ad
Modificáronse 1 ficheiros con 29 adicións e 12 borrados
  1. 29 12
      cvfinetune/parser/utils.py

+ 29 - 12
cvfinetune/parser/utils.py

@@ -1,6 +1,7 @@
 import logging
 import numpy as np
 import os
+import typing as T
 import warnings
 import yaml
 
@@ -72,17 +73,33 @@ def populate_args(args, ignore = None, replace = {}, fc_params = []):
 		setattr(args, key, value)
 
 	# get the correct number of classes
-	args.n_classes = 1000
-	weights = np.load(args.load)
-	n_classes_found = False
-	for key in fc_params:
-		try:
-			args.n_classes = weights[key].shape[0]
-			n_classes_found = True
-			break
-		except KeyError as e:
-			pass
-
-	if not n_classes_found:
+	args.n_classes = get_n_classes(args.load, args.load_path,
+		fc_params=fc_params)
+
+	if args.n_classes <= 0:
 		raise KeyError("Could not find number of classes!")
 
+def get_n_classes(weights_file: str, load_path: str = "",
+	*, fc_params: T.Tuple[str] = ()) -> int:
+	"""
+		Searches for the classification layer identified by
+		names in fc_params.
+	"""
+
+	weights = np.load(weights_file)
+	prefixes = [""]
+
+	if load_path != "":
+		prefixes.append(load_path)
+
+	for prefix in prefixes:
+		for name in fc_params:
+			key = f"{prefix}{name}"
+			if key not in weights:
+				continue
+			return weights[key].shape[0]
+
+	return -1
+
+
+