Dimitri Korsch пре 1 година
родитељ
комит
e9b813ccec
2 измењених фајлова са 137 додато и 0 уклоњено
  1. 2 0
      .gitignore
  2. 135 0
      README.md

+ 2 - 0
.gitignore

@@ -111,3 +111,5 @@ log
 *.png
 
 build/
+
+README.html

+ 135 - 0
README.md

@@ -1 +1,136 @@
 # Fine-Tune Framework based on Chainer
+
+The [Chainer](https://docs.chainer.org/en/latest/glance.html) framework is an easy-to-use deep learning (DL) framework.
+It is designed in a hierarchical manner and provides useful implementations for all of the parts required to train a network:
+
+<img src="https://docs.chainer.org/en/latest/_images/trainer1.png">
+
+We developed [cvmodelz](https://git.inf-cv.uni-jena.de/ComputerVisionJena/cvmodelz) as a fast and easy way to initialize commonly used models (the "Model" box), and [cvdatasets](https://git.inf-cv.uni-jena.de/ComputerVisionJena/cvdatasets) provides methods to load dataset annotations and create a dataset object (the "Dataset" box) that can then be passed to the iterator.
+
+An example training script might look like this:
+
+```python
+#!/usr/bin/env python
+if __name__ != '__main__': raise Exception("Do not import me!")
+
+import chainer as ch
+
+from chainer import training
+from chainer.training import extensions
+
+from cvargparse import GPUParser
+from cvargparse import Arg
+
+from cvdatasets import FileListAnnotations
+from cvdatasets import Dataset as BaseDataset
+
+from cvmodelz.models import ModelFactory
+from cvmodelz.classifiers import Classifier
+
+
+class Dataset(BaseDataset):
+
+    def __init__(self, *args, prepare, **kw):
+        super().__init__(*args, **kw)
+        self._prepare = prepare
+
+    def get_example(self, i):
+        im, _, label = super().get_example(i)
+        im = self._prepare(im, size=(224, 224))
+        return im, label
+
+
+def main(args):
+
+    model = ModelFactory.new(args.model_type)
+    clf = Classifier(model)
+
+    device = ch.get_device(args.gpu[0])
+    device.use()
+
+    annot = FileListAnnotations(root_or_infofile=args.data_root)
+    train, test = annot.new_train_test_datasets(
+        dataset_cls=Dataset, prepare=model.prepare)
+
+    train_iter = ch.iterators.MultiprocessIterator(train, batch_size=32, n_processes=4)
+    test_iter = ch.iterators.MultiprocessIterator(test, batch_size=32, n_processes=4,
+        repeat=False, shuffle=False)
+
+    # Setup an optimizer
+    optimizer = ch.optimizers.AdamW(alpha=1e-3).setup(clf)
+
+    # Create the updater, using the optimizer
+    updater = training.StandardUpdater(train_iter, optimizer, device=device)
+
+    # Set up a trainer
+    trainer = training.Trainer(updater, (50, 'epoch'), out='result')
+    # Evaluate the model with the test dataset for each epoch
+    trainer.extend(extensions.Evaluator(test_iter, clf,
+        progress_bar=True,
+        device=device))
+
+    trainer.extend(extensions.ProgressBar(update_interval=1))
+    # Write a log of evaluation statistics for each epoch
+    trainer.extend(extensions.LogReport())
+
+    # Print selected entries of the log to stdout
+    trainer.extend(extensions.PrintReport(
+        ['epoch', 'main/loss', 'validation/main/loss',
+         'main/accuracy', 'validation/main/accuracy']))
+
+
+    #  Run the training
+    trainer.run()
+
+parser = GPUParser([
+    Arg("data_root"),
+    Arg("model_type", choices=ModelFactory.get_models(["cvmodelz"])),
+])
+
+main(parser.parse_args())
+
+# start it with 'python train.py path/to/dataset cvmodelz.ResNet50'
+```
+Everything after the first two lines of the `main` function is Chainer-related boilerplate that has to be repeated every time one writes an experiment.
+Hence, `cvfinetune` simplifies this and abstracts the initializations:
+
+```python
+from chainer.training.updaters import StandardUpdater
+
+from cvfinetune.finetuner import FinetunerFactory
+from cvfinetune.training.trainer import Trainer
+from cvfinetune.parser import default_factory
+from cvfinetune.parser import FineTuneParser
+
+from cvmodelz.classifiers import Classifier
+
+from cvdatasets.dataset import AnnotationsReadMixin
+from cvdatasets.dataset import TransformMixin
+
+parser = FineTuneParser(default_factory())
+
+class Dataset(TransformMixin, AnnotationsReadMixin):
+    def __init__(self, *args, prepare, center_crop_on_val: bool = True,  **kwargs):
+        super().__init__(*args, **kwargs)
+        self.prepare = prepare
+        self.center_crop_on_val = center_crop_on_val
+
+    def transform(self, im_obj):
+        im, parts, lab = im_obj.as_tuple()
+        return self.prepare(im), lab + self.label_shift
+
+def main(args):
+    factory = FinetunerFactory(mpi=False)
+
+    tuner = factory(args,
+        classifier_cls=Classifier,
+        dataset_cls=Dataset,
+        updater_cls=StandardUpdater,
+
+        no_sacred=True,
+    )
+    tuner.run(trainer_cls=Trainer, opts=args)
+
+
+main(parser.parse_args())
+```