Răsfoiți Sursa

minor changes in the requirements and the model_info script

Dimitri Korsch 3 ani în urmă
părinte
comite
ebec0fe27f
2 a modificat fișierele cu 8 adăugiri și 2 ștergeri
  1. 7 1
      cvmodelz/model_info.py
  2. 1 1
      requirements.txt

+ 7 - 1
cvmodelz/model_info.py

@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 #!/usr/bin/env python
 if __name__ != '__main__': raise Exception("Do not import me!")
 if __name__ != '__main__': raise Exception("Do not import me!")
 
 
+import chainer
+
 from cvargparse import Arg
 from cvargparse import Arg
 from cvargparse import BaseParser
 from cvargparse import BaseParser
 
 
@@ -10,6 +12,9 @@ from cvmodelz.models import ModelFactory
 def main(args):
 def main(args):
 
 
 	model = ModelFactory.new(args.model_type, input_size=args.input_size)
 	model = ModelFactory.new(args.model_type, input_size=args.input_size)
+	device = chainer.get_device(args.device)
+	device.use()
+	model.to_device(device)
 	utils.print_model_info(model)
 	utils.print_model_info(model)
 
 
 parser = BaseParser()
 parser = BaseParser()
@@ -17,7 +22,8 @@ parser = BaseParser()
 parser.add_args([
 parser.add_args([
 	Arg("model_type", choices=ModelFactory.get_all_models()),
 	Arg("model_type", choices=ModelFactory.get_all_models()),
 
 
-	Arg("--input_size", "-size", type=int, default=None)
+	Arg("--input_size", "-size", type=int, default=None),
+	Arg("--device", "-dev", type=int, default=-1)
 ])
 ])
 
 
 main(parser.parse_args())
 main(parser.parse_args())

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 
 
-chainer~=7.0
+chainer~=7.8
 chainercv~=0.13
 chainercv~=0.13
 chainercv2~=0.0
 chainercv2~=0.0