Dimitri Korsch 5 years ago
commit
be7c8b4792
11 changed files with 497 additions and 0 deletions
  1. .gitignore (+103, -0)
  2. create_annotations.py (+96, -0)
  3. postprocess_folders.py (+49, -0)
  4. postprocess_image_names.py (+37, -0)
  5. remove_duplicates.py (+64, -0)
  6. requirements.txt (+6, -0)
  7. run.py (+45, -0)
  8. test_reading.py (+34, -0)
  9. utils/__init__.py (+0, -0)
  10. utils/imaging.py (+51, -0)
  11. utils/parser.py (+12, -0)

+ 103 - 0
.gitignore

@@ -0,0 +1,103 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+
+download
+*.txt
+*.csv

+ 96 - 0
create_annotations.py

@@ -0,0 +1,96 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import re
+import logging
+import numpy as np
+
+from tqdm import tqdm
+from os.path import basename, join
+
+from cvargparse import Arg
+from utils import parser
+from utils import imaging
+
+def _match(content, regex):
+	match = regex.match(content)
+	if match is not None:
+		return int(match.group(1))
+
+
+def _label_from_foldername(folder_name, label_regex=re.compile(r"^(\d+)\..*$")):
+	return _match(basename(folder_name), label_regex)
+
+def folder_sorting_key(entry):
+	folder_name, _ = entry
+	label = _label_from_foldername(folder_name)
+	return -1 if label is None else label
+
+def file_sorting_key(fname, im_id_regex=re.compile(r"^.*\_(\d+)\..*$")):
+	idx = _match(fname, im_id_regex)
+	return -1 if idx is None else idx
+
+
+def main(args):
+
+	scores = np.loadtxt(args.scores, delimiter=",", skiprows=1, dtype=[
+		("class", int),
+		("filename", np.dtype("U255")),
+		("score", float),
+	])
+
+	scores_above = scores["score"] > args.min_score
+
+	logging.info(f"{scores_above.sum()} of {len(scores)} are above the score {args.min_score}")
+
+	content = imaging.get_content(args.folder, args.extensions)
+	content = sorted(content, key=folder_sorting_key)
+
+	i = 0
+
+	rnd = np.random.RandomState(args.seed)
+	labels = []
+	split = []
+	with open(args.images_file, "w") as ims_f:
+
+		for folder_name, img_files in tqdm(content):
+			label = _label_from_foldername(folder_name)
+			img_files = sorted(img_files, key=file_sorting_key)
+
+			cls_mask = scores["class"] == label
+
+			assert cls_mask.any(), f"Could not find scores for class {label}"
+
+			ims_above_score = scores["filename"][np.logical_and(cls_mask, scores_above)]
+			ims_above_score = set([basename(fname) for fname in ims_above_score])
+
+			for fname in img_files:
+				if fname not in ims_above_score:
+					continue
+
+				print(i, join(basename(folder_name), fname), file=ims_f)
+				labels.append(label)
+				i += 1
+
+			# one split entry per written sample, so labels and split stay aligned
+			n_files = sum(fname in ims_above_score for fname in img_files)
+			split.extend(rnd.choice(2, size=n_files, p=[args.ratio, 1-args.ratio]))
+
+	np.savetxt(args.labels_file, labels, fmt="%d")
+	np.savetxt(args.split_file, split, fmt="%d")
+
+	logging.info(f"Created annotations for {i} samples")
+
+
+
+main(parser.parse_args([
+
+	Arg("scores"),
+	Arg("--min_score", type=float, default=0.6),
+
+	Arg("--images_file", default="images.txt"),
+	Arg("--labels_file", default="labels.txt"),
+	Arg("--split_file", default="tr_ID.txt"),
+	Arg("--split_ratio", dest="ratio",  type=float, default=.1),
+
+	Arg("--seed", type=int, default=42),
+]))
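
For reference, a minimal sketch of the scores CSV that create_annotations.py expects: a header row (skipped via skiprows=1) followed by class id, image filename and score columns. The concrete filenames and values below are assumptions made up for illustration, not taken from this repository:

    import numpy as np

    rows = [
        "class,filename,score",            # header row, skipped by skiprows=1
        "1,001.Some_Class_0000.jpg,0.93",
        "1,001.Some_Class_0001.jpg,0.41",
    ]
    with open("scores.csv", "w") as f:
        f.write("\n".join(rows) + "\n")

    scores = np.loadtxt("scores.csv", delimiter=",", skiprows=1, dtype=[
        ("class", int),
        ("filename", np.dtype("U255")),
        ("score", float),
    ])
    print(scores["score"] > 0.6)           # -> [ True False ]

The script then writes three annotation files: images.txt (index and relative image path per line), labels.txt (one class id per line) and tr_ID.txt (one split flag per line).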

+ 49 - 0
postprocess_folders.py

@@ -0,0 +1,49 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import os
+import numpy as np
+import re
+import hashlib
+import logging
+
+from cvargparse import Arg
+from os.path import join, basename, dirname, isdir
+from tqdm import tqdm
+
+from collections import OrderedDict
+
+from utils import parser
+from utils import imaging
+
+def rename(path, new_folder):
+	parent = dirname(path)
+	new_path = join(parent, new_folder)
+
+	assert not isdir(new_path), f"Folder already exists: {new_path}!"
+	# print(path, "->", new_path)
+	os.rename(path, new_path)
+
+	return new_path
+
+
+def main(args):
+
+	class_names = imaging.get_classnames(args.classnames)
+	content = imaging.get_content(args.folder, args.extensions)
+
+	for root, fnames in tqdm(content):
+		folder = basename(root)
+
+		if folder not in class_names:
+			# the folder already carries its final class name, e.g. after a previous run
+			assert folder in class_names.values(), f"\"{folder}\" is not a valid class name!"
+			continue
+
+		rename(root, class_names[folder])
+
+main(parser.parse_args([
+	Arg("classnames")
+]))

+ 37 - 0
postprocess_image_names.py

@@ -0,0 +1,37 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import os
+import re
+
+from cvargparse import Arg
+from os.path import isfile, join, basename, splitext
+from tqdm import tqdm
+
+from utils import parser
+from utils import imaging
+
+def main(args):
+
+	# class_names = imaging.get_classnames(args.classnames)
+
+	content = imaging.get_content(args.folder, args.extensions)
+
+	for root, fnames in tqdm(content):
+		folder = basename(root)
+		if not re.match(r"^\d+\..*$", folder): continue
+		fnames = sorted(fnames, key=lambda name: int(re.match(r"^(\d+)\..*$", name).group(1)))
+		for i, fname in enumerate(fnames):
+			ext = splitext(fname)[1][1:].strip().lower()
+
+			if ext == "jpeg":
+				ext = "jpg"
+
+			new_name = f"{folder}_{i:04d}.{ext}"
+			new_path = join(root, new_name)
+			assert not isfile(new_path), f"File exists: {new_path}"
+
+			os.rename(join(root, fname), join(root, new_name))
+
+main(parser.parse_args())

+ 64 - 0
remove_duplicates.py

@@ -0,0 +1,64 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import os
+import numpy as np
+import re
+import hashlib
+
+from cvargparse import BaseParser, Arg
+from os.path import isfile, join
+from tqdm import tqdm
+
+from collections import defaultdict
+
+from utils import parser
+from utils import imaging
+
+def remove_duplicates(fpaths):
+	assert len(fpaths) >= 2, f"There should be at least two paths, but were {len(fpaths)}!"
+
+	for fpath in fpaths[1:]:
+		os.remove(fpath)
+
+
+def main(args):
+
+	fname_regex = re.compile(r"^\d+.(.+)\..*$")
+
+	content = imaging.get_content(args.folder, args.extensions)
+
+	for root, fnames in tqdm(content):
+		names = [(name, fname_regex.match(name).group(1)) for name in fnames]
+
+		counts = defaultdict(int)
+		name_to_fname = defaultdict(list)
+		for fname, name in names:
+			counts[name] += 1
+			name_to_fname[name].append(fname)
+
+
+		for name, count in counts.items():
+			if count == 1: continue
+
+			md5sums = defaultdict(list)
+			md5counts = defaultdict(int)
+
+			for fname in name_to_fname[name]:
+				fpath = join(root, fname)
+				assert isfile(fpath), f"Could not find {fpath}"
+
+				with open(fpath, "rb") as f:
+					md5sum = hashlib.md5(f.read()).hexdigest()
+
+					md5sums[md5sum].append(fpath)
+					md5counts[md5sum] += 1
+
+			for md5sum, count in md5counts.items():
+				if count == 1: continue
+
+				remove_duplicates(md5sums[md5sum])
+
+
+
+main(parser.parse_args())

+ 6 - 0
requirements.txt

@@ -0,0 +1,6 @@
+google_images_download
+piexif
+pillow
+numpy
+tqdm
+
+cvargparse

+ 45 - 0
run.py

@@ -0,0 +1,45 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+from argparse import ArgumentParser
+from google_images_download import google_images_download   #importing the library
+
+parser = ArgumentParser()
+
+parser.add_argument("classnames",
+	help="file containing a class name in each line")
+
+parser.add_argument("--output_directory", "-o", default="download",
+	help="output folder")
+
+parser.add_argument("--limit", "-l", default=60, type=int,
+	help="number of images to download")
+
+
+def main(args):
+
+	with open(args.classnames, "r") as f:
+		names = f.readlines()
+
+	# query_names = [name.strip().partition(".")[-1].replace("_", " ").lower() for name in names if not name.startswith("#")]
+	query_names = [name.strip() for name in names if not name.startswith("#")]
+
+	response = google_images_download.googleimagesdownload()   #class instantiation
+
+	print(f"Found {len(query_names)} query names")
+	paths = response.download(dict(
+			keywords=",".join(query_names),
+			limit=args.limit,
+			print_urls=False,
+			size=">800*600",
+			output_directory=args.output_directory,
+			chromedriver="chromedriver"
+		))   #passing the arguments to the function
+
+	# print(paths)   #printing absolute paths of the downloaded images
+
+
+
+
+
+main(parser.parse_args())
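
For reference, a sketch of the classnames file that run.py expects: one search query per line, with lines starting with "#" skipped. The file name and entries below are made up for illustration:

    # writes an example query list for run.py
    queries = [
        "# commented lines like this one are skipped",
        "black footed albatross",
        "laysan albatross",
    ]
    with open("classnames.txt", "w") as f:
        f.write("\n".join(queries) + "\n")

    # afterwards: python run.py classnames.txt -o download -l 60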

+ 34 - 0
test_reading.py

@@ -0,0 +1,34 @@
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import os
+import numpy as np
+import re
+import hashlib
+
+from cvargparse import BaseParser, Arg
+from os.path import isfile, join
+from tqdm import tqdm
+
+
+from collections import defaultdict
+
+from multiprocessing.pool import Pool
+
+from utils import parser
+from utils import imaging
+
+def main(args):
+
+	i = 0
+	content = imaging.get_content(args.folder, args.extensions)
+
+	# with Pool(6) as pool:
+	for root, fnames in tqdm(content):
+
+		paths = [join(root, fname) for fname in fnames]
+		i += sum(map(imaging.check_readability, paths))
+
+	print(f"unable to read {i} images")
+
+main(parser.parse_args())

+ 0 - 0
utils/__init__.py


+ 51 - 0
utils/imaging.py

@@ -0,0 +1,51 @@
+import os
+import piexif
+import logging
+import numpy as np
+
+from collections import OrderedDict
+from functools import partial
+from PIL import Image
+
+def check_readability(fpath):
+	# returns 0 if the image can be opened, otherwise deletes the file and returns 1
+
+	try:
+		# strip (possibly broken) EXIF data first; files without EXIF are simply skipped
+		piexif.remove(fpath)
+	except Exception:
+		pass
+
+	try:
+		with Image.open(fpath) as im:
+			im.verify()
+		return 0
+	except Exception:
+		os.remove(fpath)
+		return 1
+
+def is_accepted_image(fname, extensions):
+	return any([fname.lower().endswith(ext) for ext in extensions])
+
+def get_content(folder, extensions):
+
+	_check = partial(is_accepted_image, extensions=extensions)
+	content = []
+
+	for root, _, fnames in os.walk(folder):
+		fnames = list(filter(_check, fnames))
+		if not fnames: continue
+		content.append([root, fnames])
+
+	logging.info(f"Found {len(content)} sub-folders")
+	return content
+
+def get_classnames(classnames):
+
+	_names = np.loadtxt(classnames, dtype=str)
+
+	class_names = OrderedDict()
+	for name in _names:
+		key = name.strip().partition(".")[-1].replace("_", " ").lower()
+		class_names[key] = name
+
+	logging.info(f"Parsed {len(class_names)} class names")
+
+	return class_names

+ 12 - 0
utils/parser.py

@@ -0,0 +1,12 @@
+from cvargparse import Arg, BaseParser
+
+def parse_args(extra_args=[]):
+	parser = BaseParser([
+		Arg("folder"),
+
+		Arg("--extensions", "-ext", nargs="*", default=["jpg", "jpeg", "png"]),
+	] + extra_args)
+
+	parser.init_logger()
+
+	return parser.parse_args()
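
Most of the scripts in this commit share this small CLI (run.py defines its own parser): a positional folder argument plus an --extensions filter, optionally extended with script-specific Arg entries, as create_annotations.py does. A hypothetical skeleton, with the extra option made up only for illustration:

    from cvargparse import Arg
    from utils import parser

    def main(args):
        print(args.folder, args.extensions, args.limit)

    main(parser.parse_args([
        # script-specific option, named here only for illustration
        Arg("--limit", type=int, default=10),
    ]))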