浏览代码

minor changes in the requirements and the model_info script

Dimitri Korsch 3 年之前
父节点
当前提交
ebec0fe27f
共有 2 个文件被更改,包括 8 次插入2 次删除
  1. 7 1
      cvmodelz/model_info.py
  2. 1 1
      requirements.txt

+ 7 - 1
cvmodelz/model_info.py

@@ -1,6 +1,8 @@
 #!/usr/bin/env python
 if __name__ != '__main__': raise Exception("Do not import me!")
 
+import chainer
+
 from cvargparse import Arg
 from cvargparse import BaseParser
 
@@ -10,6 +12,9 @@ from cvmodelz.models import ModelFactory
 def main(args):
 
 	model = ModelFactory.new(args.model_type, input_size=args.input_size)
+	device = chainer.get_device(args.device)
+	device.use()
+	model.to_device(device)
 	utils.print_model_info(model)
 
 parser = BaseParser()
@@ -17,7 +22,8 @@ parser = BaseParser()
 parser.add_args([
 	Arg("model_type", choices=ModelFactory.get_all_models()),
 
-	Arg("--input_size", "-size", type=int, default=None)
+	Arg("--input_size", "-size", type=int, default=None),
+	Arg("--device", "-dev", type=int, default=-1)
 ])
 
 main(parser.parse_args())

+ 1 - 1
requirements.txt

@@ -1,5 +1,5 @@
 
-chainer~=7.0
+chainer~=7.8
 chainercv~=0.13
 chainercv2~=0.0