common.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Common modules
  4. """
  5. import json
  6. import math
  7. import platform
  8. import warnings
  9. from collections import OrderedDict, namedtuple
  10. from copy import copy
  11. from pathlib import Path
  12. import cv2
  13. import numpy as np
  14. import pandas as pd
  15. import requests
  16. import torch
  17. import torch.nn as nn
  18. import yaml
  19. from PIL import Image
  20. from torch.cuda import amp
  21. from utils.datasets import exif_transpose, letterbox
  22. from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
  23. make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
  24. from utils.plots import Annotator, colors, save_one_box
  25. from utils.torch_utils import copy_attr, time_sync
  26. def autopad(k, p=None): # kernel, padding
  27. # Pad to 'same'
  28. if p is None:
  29. p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
  30. return p
  31. class Conv(nn.Module):
  32. # Standard convolution
  33. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  34. super().__init__()
  35. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  36. self.bn = nn.BatchNorm2d(c2)
  37. self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())
  38. def forward(self, x):
  39. return self.act(self.bn(self.conv(x)))
  40. def forward_fuse(self, x):
  41. return self.act(self.conv(x))
  42. class DWConv(Conv):
  43. # Depth-wise convolution class
  44. def __init__(self, c1, c2, k=1, s=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  45. super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), act=act)
  46. class TransformerLayer(nn.Module):
  47. # Transformer layer https://arxiv.org/abs/2010.11929 (LayerNorm layers removed for better performance)
  48. def __init__(self, c, num_heads):
  49. super().__init__()
  50. self.q = nn.Linear(c, c, bias=False)
  51. self.k = nn.Linear(c, c, bias=False)
  52. self.v = nn.Linear(c, c, bias=False)
  53. self.ma = nn.MultiheadAttention(embed_dim=c, num_heads=num_heads)
  54. self.fc1 = nn.Linear(c, c, bias=False)
  55. self.fc2 = nn.Linear(c, c, bias=False)
  56. def forward(self, x):
  57. x = self.ma(self.q(x), self.k(x), self.v(x))[0] + x
  58. x = self.fc2(self.fc1(x)) + x
  59. return x
  60. class TransformerBlock(nn.Module):
  61. # Vision Transformer https://arxiv.org/abs/2010.11929
  62. def __init__(self, c1, c2, num_heads, num_layers):
  63. super().__init__()
  64. self.conv = None
  65. if c1 != c2:
  66. self.conv = Conv(c1, c2)
  67. self.linear = nn.Linear(c2, c2) # learnable position embedding
  68. self.tr = nn.Sequential(*(TransformerLayer(c2, num_heads) for _ in range(num_layers)))
  69. self.c2 = c2
  70. def forward(self, x):
  71. if self.conv is not None:
  72. x = self.conv(x)
  73. b, _, w, h = x.shape
  74. p = x.flatten(2).permute(2, 0, 1)
  75. return self.tr(p + self.linear(p)).permute(1, 2, 0).reshape(b, self.c2, w, h)
  76. class Bottleneck(nn.Module):
  77. # Standard bottleneck
  78. def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansion
  79. super().__init__()
  80. c_ = int(c2 * e) # hidden channels
  81. self.cv1 = Conv(c1, c_, 1, 1)
  82. self.cv2 = Conv(c_, c2, 3, 1, g=g)
  83. self.add = shortcut and c1 == c2
  84. def forward(self, x):
  85. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))
  86. class BottleneckCSP(nn.Module):
  87. # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
  88. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  89. super().__init__()
  90. c_ = int(c2 * e) # hidden channels
  91. self.cv1 = Conv(c1, c_, 1, 1)
  92. self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
  93. self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
  94. self.cv4 = Conv(2 * c_, c2, 1, 1)
  95. self.bn = nn.BatchNorm2d(2 * c_) # applied to cat(cv2, cv3)
  96. self.act = nn.SiLU()
  97. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  98. def forward(self, x):
  99. y1 = self.cv3(self.m(self.cv1(x)))
  100. y2 = self.cv2(x)
  101. return self.cv4(self.act(self.bn(torch.cat((y1, y2), dim=1))))
  102. class C3(nn.Module):
  103. # CSP Bottleneck with 3 convolutions
  104. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  105. super().__init__()
  106. c_ = int(c2 * e) # hidden channels
  107. self.cv1 = Conv(c1, c_, 1, 1)
  108. self.cv2 = Conv(c1, c_, 1, 1)
  109. self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)
  110. self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  111. # self.m = nn.Sequential(*[CrossConv(c_, c_, 3, 1, g, 1.0, shortcut) for _ in range(n)])
  112. def forward(self, x):
  113. return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
  114. class C3TR(C3):
  115. # C3 module with TransformerBlock()
  116. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  117. super().__init__(c1, c2, n, shortcut, g, e)
  118. c_ = int(c2 * e)
  119. self.m = TransformerBlock(c_, c_, 4, n)
  120. class C3SPP(C3):
  121. # C3 module with SPP()
  122. def __init__(self, c1, c2, k=(5, 9, 13), n=1, shortcut=True, g=1, e=0.5):
  123. super().__init__(c1, c2, n, shortcut, g, e)
  124. c_ = int(c2 * e)
  125. self.m = SPP(c_, c_, k)
  126. class C3Ghost(C3):
  127. # C3 module with GhostBottleneck()
  128. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
  129. super().__init__(c1, c2, n, shortcut, g, e)
  130. c_ = int(c2 * e) # hidden channels
  131. self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))
  132. class SPP(nn.Module):
  133. # Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729
  134. def __init__(self, c1, c2, k=(5, 9, 13)):
  135. super().__init__()
  136. c_ = c1 // 2 # hidden channels
  137. self.cv1 = Conv(c1, c_, 1, 1)
  138. self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
  139. self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])
  140. def forward(self, x):
  141. x = self.cv1(x)
  142. with warnings.catch_warnings():
  143. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  144. return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))
  145. class SPPF(nn.Module):
  146. # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
  147. def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))
  148. super().__init__()
  149. c_ = c1 // 2 # hidden channels
  150. self.cv1 = Conv(c1, c_, 1, 1)
  151. self.cv2 = Conv(c_ * 4, c2, 1, 1)
  152. self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)
  153. def forward(self, x):
  154. x = self.cv1(x)
  155. with warnings.catch_warnings():
  156. warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning
  157. y1 = self.m(x)
  158. y2 = self.m(y1)
  159. return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
  160. class Focus(nn.Module):
  161. # Focus wh information into c-space
  162. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  163. super().__init__()
  164. self.conv = Conv(c1 * 4, c2, k, s, p, g, act)
  165. # self.contract = Contract(gain=2)
  166. def forward(self, x): # x(b,c,w,h) -> y(b,4c,w/2,h/2)
  167. return self.conv(torch.cat([x[..., ::2, ::2], x[..., 1::2, ::2], x[..., ::2, 1::2], x[..., 1::2, 1::2]], 1))
  168. # return self.conv(self.contract(x))
  169. class GhostConv(nn.Module):
  170. # Ghost Convolution https://github.com/huawei-noah/ghostnet
  171. def __init__(self, c1, c2, k=1, s=1, g=1, act=True): # ch_in, ch_out, kernel, stride, groups
  172. super().__init__()
  173. c_ = c2 // 2 # hidden channels
  174. self.cv1 = Conv(c1, c_, k, s, None, g, act)
  175. self.cv2 = Conv(c_, c_, 5, 1, None, c_, act)
  176. def forward(self, x):
  177. y = self.cv1(x)
  178. return torch.cat([y, self.cv2(y)], 1)
  179. class GhostBottleneck(nn.Module):
  180. # Ghost Bottleneck https://github.com/huawei-noah/ghostnet
  181. def __init__(self, c1, c2, k=3, s=1): # ch_in, ch_out, kernel, stride
  182. super().__init__()
  183. c_ = c2 // 2
  184. self.conv = nn.Sequential(GhostConv(c1, c_, 1, 1), # pw
  185. DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(), # dw
  186. GhostConv(c_, c2, 1, 1, act=False)) # pw-linear
  187. self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False),
  188. Conv(c1, c2, 1, 1, act=False)) if s == 2 else nn.Identity()
  189. def forward(self, x):
  190. return self.conv(x) + self.shortcut(x)
  191. class Contract(nn.Module):
  192. # Contract width-height into channels, i.e. x(1,64,80,80) to x(1,256,40,40)
  193. def __init__(self, gain=2):
  194. super().__init__()
  195. self.gain = gain
  196. def forward(self, x):
  197. b, c, h, w = x.size() # assert (h / s == 0) and (W / s == 0), 'Indivisible gain'
  198. s = self.gain
  199. x = x.view(b, c, h // s, s, w // s, s) # x(1,64,40,2,40,2)
  200. x = x.permute(0, 3, 5, 1, 2, 4).contiguous() # x(1,2,2,64,40,40)
  201. return x.view(b, c * s * s, h // s, w // s) # x(1,256,40,40)
  202. class Expand(nn.Module):
  203. # Expand channels into width-height, i.e. x(1,64,80,80) to x(1,16,160,160)
  204. def __init__(self, gain=2):
  205. super().__init__()
  206. self.gain = gain
  207. def forward(self, x):
  208. b, c, h, w = x.size() # assert C / s ** 2 == 0, 'Indivisible gain'
  209. s = self.gain
  210. x = x.view(b, s, s, c // s ** 2, h, w) # x(1,2,2,16,80,80)
  211. x = x.permute(0, 3, 4, 1, 5, 2).contiguous() # x(1,16,80,2,80,2)
  212. return x.view(b, c // s ** 2, h * s, w * s) # x(1,16,160,160)
  213. class Concat(nn.Module):
  214. # Concatenate a list of tensors along dimension
  215. def __init__(self, dimension=1):
  216. super().__init__()
  217. self.d = dimension
  218. def forward(self, x):
  219. return torch.cat(x, self.d)
  220. class DetectMultiBackend(nn.Module):
  221. # YOLOv5 MultiBackend class for python inference on various backends
  222. def __init__(self, weights='yolov5s.pt', device=None, dnn=False, data=None):
  223. # Usage:
  224. # PyTorch: weights = *.pt
  225. # TorchScript: *.torchscript
  226. # CoreML: *.mlmodel
  227. # OpenVINO: *.xml
  228. # TensorFlow: *_saved_model
  229. # TensorFlow: *.pb
  230. # TensorFlow Lite: *.tflite
  231. # TensorFlow Edge TPU: *_edgetpu.tflite
  232. # ONNX Runtime: *.onnx
  233. # OpenCV DNN: *.onnx with dnn=True
  234. # TensorRT: *.engine
  235. from models.experimental import attempt_download, attempt_load # scoped to avoid circular import
  236. super().__init__()
  237. w = str(weights[0] if isinstance(weights, list) else weights)
  238. suffix = Path(w).suffix.lower()
  239. suffixes = ['.pt', '.torchscript', '.onnx', '.engine', '.tflite', '.pb', '', '.mlmodel', '.xml']
  240. check_suffix(w, suffixes) # check weights have acceptable suffix
  241. pt, jit, onnx, engine, tflite, pb, saved_model, coreml, xml = (suffix == x for x in suffixes) # backends
  242. stride, names = 64, [f'class{i}' for i in range(1000)] # assign defaults
  243. w = attempt_download(w) # download if not local
  244. if data: # data.yaml path (optional)
  245. with open(data, errors='ignore') as f:
  246. names = yaml.safe_load(f)['names'] # class names
  247. if pt: # PyTorch
  248. model = attempt_load(weights if isinstance(weights, list) else w, map_location=device)
  249. stride = max(int(model.stride.max()), 32) # model stride
  250. names = model.module.names if hasattr(model, 'module') else model.names # get class names
  251. self.model = model # explicitly assign for to(), cpu(), cuda(), half()
  252. elif jit: # TorchScript
  253. LOGGER.info(f'Loading {w} for TorchScript inference...')
  254. extra_files = {'config.txt': ''} # model metadata
  255. model = torch.jit.load(w, _extra_files=extra_files)
  256. if extra_files['config.txt']:
  257. d = json.loads(extra_files['config.txt']) # extra_files dict
  258. stride, names = int(d['stride']), d['names']
  259. elif dnn: # ONNX OpenCV DNN
  260. LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
  261. check_requirements(('opencv-python>=4.5.4',))
  262. net = cv2.dnn.readNetFromONNX(w)
  263. elif onnx: # ONNX Runtime
  264. LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
  265. cuda = torch.cuda.is_available()
  266. check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
  267. import onnxruntime
  268. providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
  269. session = onnxruntime.InferenceSession(w, providers=providers)
  270. elif xml: # OpenVINO
  271. LOGGER.info(f'Loading {w} for OpenVINO inference...')
  272. check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
  273. import openvino.inference_engine as ie
  274. core = ie.IECore()
  275. network = core.read_network(model=w, weights=Path(w).with_suffix('.bin')) # *.xml, *.bin paths
  276. executable_network = core.load_network(network, device_name='CPU', num_requests=1)
  277. elif engine: # TensorRT
  278. LOGGER.info(f'Loading {w} for TensorRT inference...')
  279. import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
  280. check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
  281. Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
  282. logger = trt.Logger(trt.Logger.INFO)
  283. with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
  284. model = runtime.deserialize_cuda_engine(f.read())
  285. bindings = OrderedDict()
  286. for index in range(model.num_bindings):
  287. name = model.get_binding_name(index)
  288. dtype = trt.nptype(model.get_binding_dtype(index))
  289. shape = tuple(model.get_binding_shape(index))
  290. data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
  291. bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
  292. binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
  293. context = model.create_execution_context()
  294. batch_size = bindings['images'].shape[0]
  295. elif coreml: # CoreML
  296. LOGGER.info(f'Loading {w} for CoreML inference...')
  297. import coremltools as ct
  298. model = ct.models.MLModel(w)
  299. else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
  300. if saved_model: # SavedModel
  301. LOGGER.info(f'Loading {w} for TensorFlow SavedModel inference...')
  302. import tensorflow as tf
  303. model = tf.keras.models.load_model(w)
  304. elif pb: # GraphDef https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
  305. LOGGER.info(f'Loading {w} for TensorFlow GraphDef inference...')
  306. import tensorflow as tf
  307. def wrap_frozen_graph(gd, inputs, outputs):
  308. x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), []) # wrapped
  309. return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
  310. tf.nest.map_structure(x.graph.as_graph_element, outputs))
  311. graph_def = tf.Graph().as_graph_def()
  312. graph_def.ParseFromString(open(w, 'rb').read())
  313. frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
  314. elif tflite: # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
  315. try: # https://coral.ai/docs/edgetpu/tflite-python/#update-existing-tf-lite-code-for-the-edge-tpu
  316. from tflite_runtime.interpreter import Interpreter, load_delegate
  317. except ImportError:
  318. import tensorflow as tf
  319. Interpreter, load_delegate = tf.lite.Interpreter, tf.lite.experimental.load_delegate,
  320. if 'edgetpu' in w.lower(): # Edge TPU https://coral.ai/software/#edgetpu-runtime
  321. LOGGER.info(f'Loading {w} for TensorFlow Lite Edge TPU inference...')
  322. delegate = {'Linux': 'libedgetpu.so.1',
  323. 'Darwin': 'libedgetpu.1.dylib',
  324. 'Windows': 'edgetpu.dll'}[platform.system()]
  325. interpreter = Interpreter(model_path=w, experimental_delegates=[load_delegate(delegate)])
  326. else: # Lite
  327. LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
  328. interpreter = Interpreter(model_path=w) # load TFLite model
  329. interpreter.allocate_tensors() # allocate
  330. input_details = interpreter.get_input_details() # inputs
  331. output_details = interpreter.get_output_details() # outputs
  332. self.__dict__.update(locals()) # assign all variables to self
  333. def forward(self, im, augment=False, visualize=False, val=False):
  334. # YOLOv5 MultiBackend inference
  335. b, ch, h, w = im.shape # batch, channel, height, width
  336. if self.pt or self.jit: # PyTorch
  337. y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
  338. return y if val else y[0]
  339. elif self.dnn: # ONNX OpenCV DNN
  340. im = im.cpu().numpy() # torch to numpy
  341. self.net.setInput(im)
  342. y = self.net.forward()
  343. elif self.onnx: # ONNX Runtime
  344. im = im.cpu().numpy() # torch to numpy
  345. y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
  346. elif self.xml: # OpenVINO
  347. im = im.cpu().numpy() # FP32
  348. desc = self.ie.TensorDesc(precision='FP32', dims=im.shape, layout='NCHW') # Tensor Description
  349. request = self.executable_network.requests[0] # inference request
  350. request.set_blob(blob_name='images', blob=self.ie.Blob(desc, im)) # name=next(iter(request.input_blobs))
  351. request.infer()
  352. y = request.output_blobs['output'].buffer # name=next(iter(request.output_blobs))
  353. elif self.engine: # TensorRT
  354. assert im.shape == self.bindings['images'].shape, (im.shape, self.bindings['images'].shape)
  355. self.binding_addrs['images'] = int(im.data_ptr())
  356. self.context.execute_v2(list(self.binding_addrs.values()))
  357. y = self.bindings['output'].data
  358. elif self.coreml: # CoreML
  359. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  360. im = Image.fromarray((im[0] * 255).astype('uint8'))
  361. # im = im.resize((192, 320), Image.ANTIALIAS)
  362. y = self.model.predict({'image': im}) # coordinates are xywh normalized
  363. if 'confidence' in y:
  364. box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]]) # xyxy pixels
  365. conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
  366. y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
  367. else:
  368. y = y[sorted(y)[-1]] # last output
  369. else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
  370. im = im.permute(0, 2, 3, 1).cpu().numpy() # torch BCHW to numpy BHWC shape(1,320,192,3)
  371. if self.saved_model: # SavedModel
  372. y = self.model(im, training=False).numpy()
  373. elif self.pb: # GraphDef
  374. y = self.frozen_func(x=self.tf.constant(im)).numpy()
  375. elif self.tflite: # Lite
  376. input, output = self.input_details[0], self.output_details[0]
  377. int8 = input['dtype'] == np.uint8 # is TFLite quantized uint8 model
  378. if int8:
  379. scale, zero_point = input['quantization']
  380. im = (im / scale + zero_point).astype(np.uint8) # de-scale
  381. self.interpreter.set_tensor(input['index'], im)
  382. self.interpreter.invoke()
  383. y = self.interpreter.get_tensor(output['index'])
  384. if int8:
  385. scale, zero_point = output['quantization']
  386. y = (y.astype(np.float32) - zero_point) * scale # re-scale
  387. y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
  388. y = torch.tensor(y) if isinstance(y, np.ndarray) else y
  389. return (y, []) if val else y
  390. def warmup(self, imgsz=(1, 3, 640, 640), half=False):
  391. # Warmup model by running inference once
  392. if self.pt or self.jit or self.onnx or self.engine: # warmup types
  393. if isinstance(self.device, torch.device) and self.device.type != 'cpu': # only warmup GPU models
  394. im = torch.zeros(*imgsz).to(self.device).type(torch.half if half else torch.float) # input image
  395. self.forward(im) # warmup
  396. class AutoShape(nn.Module):
  397. # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
  398. conf = 0.25 # NMS confidence threshold
  399. iou = 0.45 # NMS IoU threshold
  400. agnostic = False # NMS class-agnostic
  401. multi_label = False # NMS multiple labels per box
  402. classes = None # (optional list) filter by class, i.e. = [0, 15, 16] for COCO persons, cats and dogs
  403. max_det = 1000 # maximum number of detections per image
  404. amp = False # Automatic Mixed Precision (AMP) inference
  405. def __init__(self, model):
  406. super().__init__()
  407. LOGGER.info('Adding AutoShape... ')
  408. copy_attr(self, model, include=('yaml', 'nc', 'hyp', 'names', 'stride', 'abc'), exclude=()) # copy attributes
  409. self.dmb = isinstance(model, DetectMultiBackend) # DetectMultiBackend() instance
  410. self.pt = not self.dmb or model.pt # PyTorch model
  411. self.model = model.eval()
  412. def _apply(self, fn):
  413. # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
  414. self = super()._apply(fn)
  415. if self.pt:
  416. m = self.model.model.model[-1] if self.dmb else self.model.model[-1] # Detect()
  417. m.stride = fn(m.stride)
  418. m.grid = list(map(fn, m.grid))
  419. if isinstance(m.anchor_grid, list):
  420. m.anchor_grid = list(map(fn, m.anchor_grid))
  421. return self
  422. @torch.no_grad()
  423. def forward(self, imgs, size=640, augment=False, profile=False):
  424. # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
  425. # file: imgs = 'data/images/zidane.jpg' # str or PosixPath
  426. # URI: = 'https://ultralytics.com/images/zidane.jpg'
  427. # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
  428. # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
  429. # numpy: = np.zeros((640,1280,3)) # HWC
  430. # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
  431. # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
  432. t = [time_sync()]
  433. p = next(self.model.parameters()) if self.pt else torch.zeros(1) # for device and type
  434. autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
  435. if isinstance(imgs, torch.Tensor): # torch
  436. with amp.autocast(enabled=autocast):
  437. return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
  438. # Pre-process
  439. n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
  440. shape0, shape1, files = [], [], [] # image and inference shapes, filenames
  441. for i, im in enumerate(imgs):
  442. f = f'image{i}' # filename
  443. if isinstance(im, (str, Path)): # filename or uri
  444. im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
  445. im = np.asarray(exif_transpose(im))
  446. elif isinstance(im, Image.Image): # PIL Image
  447. im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
  448. files.append(Path(f).with_suffix('.jpg').name)
  449. if im.shape[0] < 5: # image in CHW
  450. im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
  451. im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
  452. s = im.shape[:2] # HWC
  453. shape0.append(s) # image shape
  454. g = (size / max(s)) # gain
  455. shape1.append([y * g for y in s])
  456. imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
  457. shape1 = [make_divisible(x, self.stride) for x in np.stack(shape1, 0).max(0)] # inference shape
  458. x = [letterbox(im, new_shape=shape1 if self.pt else size, auto=False)[0] for im in imgs] # pad
  459. x = np.stack(x, 0) if n > 1 else x[0][None] # stack
  460. x = np.ascontiguousarray(x.transpose((0, 3, 1, 2))) # BHWC to BCHW
  461. x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
  462. t.append(time_sync())
  463. with amp.autocast(enabled=autocast):
  464. # Inference
  465. y = self.model(x, augment, profile) # forward
  466. t.append(time_sync())
  467. # Post-process
  468. y = non_max_suppression(y if self.dmb else y[0], self.conf, iou_thres=self.iou, classes=self.classes,
  469. agnostic=self.agnostic, multi_label=self.multi_label, max_det=self.max_det) # NMS
  470. for i in range(n):
  471. scale_coords(shape1, y[i][:, :4], shape0[i])
  472. t.append(time_sync())
  473. return Detections(imgs, y, files, t, self.names, x.shape)
  474. class Detections:
  475. # YOLOv5 detections class for inference results
  476. def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
  477. super().__init__()
  478. d = pred[0].device # device
  479. gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
  480. self.imgs = imgs # list of images as numpy arrays
  481. self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
  482. self.names = names # class names
  483. self.files = files # image filenames
  484. self.times = times # profiling times
  485. self.xyxy = pred # xyxy pixels
  486. self.xywh = [xyxy2xywh(x) for x in pred] # xywh pixels
  487. self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
  488. self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
  489. self.n = len(self.pred) # number of images (batch size)
  490. self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
  491. self.s = shape # inference BCHW shape
  492. def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')):
  493. crops = []
  494. for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
  495. s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
  496. if pred.shape[0]:
  497. for c in pred[:, -1].unique():
  498. n = (pred[:, -1] == c).sum() # detections per class
  499. s += f"{n} {self.names[int(c)]}{'s' * (n > 1)}, " # add to string
  500. if show or save or render or crop:
  501. annotator = Annotator(im, example=str(self.names))
  502. for *box, conf, cls in reversed(pred): # xyxy, confidence, class
  503. label = f'{self.names[int(cls)]} {conf:.2f}'
  504. if crop:
  505. file = save_dir / 'crops' / self.names[int(cls)] / self.files[i] if save else None
  506. crops.append({'box': box, 'conf': conf, 'cls': cls, 'label': label,
  507. 'im': save_one_box(box, im, file=file, save=save)})
  508. else: # all others
  509. annotator.box_label(box, label, color=colors(cls))
  510. im = annotator.im
  511. else:
  512. s += '(no detections)'
  513. im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np
  514. if pprint:
  515. LOGGER.info(s.rstrip(', '))
  516. if show:
  517. im.show(self.files[i]) # show
  518. if save:
  519. f = self.files[i]
  520. im.save(save_dir / f) # save
  521. if i == self.n - 1:
  522. LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
  523. if render:
  524. self.imgs[i] = np.asarray(im)
  525. if crop:
  526. if save:
  527. LOGGER.info(f'Saved results to {save_dir}\n')
  528. return crops
  529. def print(self):
  530. self.display(pprint=True) # print results
  531. LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {tuple(self.s)}' %
  532. self.t)
  533. def show(self):
  534. self.display(show=True) # show results
  535. def save(self, save_dir='runs/detect/exp'):
  536. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) # increment save_dir
  537. self.display(save=True, save_dir=save_dir) # save results
  538. def crop(self, save=True, save_dir='runs/detect/exp'):
  539. save_dir = increment_path(save_dir, exist_ok=save_dir != 'runs/detect/exp', mkdir=True) if save else None
  540. return self.display(crop=True, save=save, save_dir=save_dir) # crop results
  541. def render(self):
  542. self.display(render=True) # render results
  543. return self.imgs
  544. def pandas(self):
  545. # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
  546. new = copy(self) # return copy
  547. ca = 'xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name' # xyxy columns
  548. cb = 'xcenter', 'ycenter', 'width', 'height', 'confidence', 'class', 'name' # xywh columns
  549. for k, c in zip(['xyxy', 'xyxyn', 'xywh', 'xywhn'], [ca, ca, cb, cb]):
  550. a = [[x[:5] + [int(x[5]), self.names[int(x[5])]] for x in x.tolist()] for x in getattr(self, k)] # update
  551. setattr(new, k, [pd.DataFrame(x, columns=c) for x in a])
  552. return new
  553. def tolist(self):
  554. # return a list of Detections objects, i.e. 'for result in results.tolist():'
  555. r = range(self.n) # iterable
  556. x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
  557. # for d in x:
  558. # for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
  559. # setattr(d, k, getattr(d, k)[0]) # pop out of list
  560. return x
  561. def __len__(self):
  562. return self.n
  563. class Classify(nn.Module):
  564. # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
  565. def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
  566. super().__init__()
  567. self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
  568. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
  569. self.flat = nn.Flatten()
  570. def forward(self, x):
  571. z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
  572. return self.flat(self.conv(z)) # flatten to x(b,c2)