# -*- coding: utf-8 -*-
from threading import current_thread
from PIL import Image
from six import BytesIO
from six.moves import queue
from six.moves.urllib.parse import urlparse
from icrawler.utils import ThreadPool
class Downloader(ThreadPool):
    """Base class for downloader.

    A thread pool of downloader threads, in charge of downloading files and
    saving them in the corresponding paths.

    Attributes:
        task_queue (CachedQueue): A queue storing image downloading tasks,
            connecting :class:`Parser` and :class:`Downloader`.
        signal (Signal): A Signal object shared by all components.
        session (Session): A session object.
        logger: A logging.Logger object used for logging.
        workers (list): A list of downloader threads.
        thread_num (int): The number of downloader threads.
        lock (Lock): A threading.Lock object.
        storage (BaseStorage): storage backend.
    """

    def __init__(self, thread_num, signal, session, storage):
        """Init Downloader with some shared variables."""
        super(Downloader, self).__init__(
            thread_num, out_queue=None, name='downloader')
        self.signal = signal
        self.session = session
        self.storage = storage
        self.file_idx_offset = 0
        self.clear_status()

    def clear_status(self):
        """Reset fetched_num to 0."""
        self.fetched_num = 0

    def set_file_idx_offset(self, file_idx_offset=0):
        """Set offset of file index.

        Args:
            file_idx_offset: It can be either an integer or 'auto'. If set
                to an integer, the filename will start from
                ``file_idx_offset`` + 1. If set to ``'auto'``, the filename
                will start from existing max file index plus 1.

        Raises:
            ValueError: If ``file_idx_offset`` is neither an int nor 'auto'.
        """
        if isinstance(file_idx_offset, int):
            self.file_idx_offset = file_idx_offset
        elif file_idx_offset == 'auto':
            self.file_idx_offset = self.storage.max_file_idx()
        else:
            raise ValueError('"file_idx_offset" must be an integer or `auto`')

    def get_filename(self, task, default_ext):
        """Set the path where the image will be saved.

        The default strategy is to use an increasing 6-digit number as
        the filename. You can override this method if you want to set custom
        naming rules. The file extension is kept if it can be obtained from
        the url, otherwise ``default_ext`` is used as extension.

        Args:
            task (dict): The task dict got from ``task_queue``.
            default_ext (str): Extension used when none is found in the url.

        Returns:
            str: Filename with extension.
        """
        url_path = urlparse(task['file_url'])[2]
        extension = url_path.split('.')[-1] if '.' in url_path else default_ext
        file_idx = self.fetched_num + self.file_idx_offset
        return '{:06d}.{}'.format(file_idx, extension)

    def reach_max_num(self):
        """Check if downloaded images reached max num.

        Returns:
            bool: if downloaded images reached max num.
        """
        if self.signal.get('reach_max_num'):
            return True
        # max_num <= 0 means "no limit"
        return self.max_num > 0 and self.fetched_num >= self.max_num

    def keep_file(self, task, response, **kwargs):
        """Decide whether to keep the downloaded file.

        Hook for subclasses (always keeps the file by default).
        """
        return True

    def download(self,
                 task,
                 default_ext,
                 timeout=5,
                 max_retry=3,
                 overwrite=False,
                 **kwargs):
        """Download the image and save it to the corresponding path.

        Args:
            task (dict): The task dict got from ``task_queue``.
            default_ext (str): Extension used when none can be derived
                from the url.
            timeout (int): Timeout of making requests for downloading images.
            max_retry (int): the max retry times if the request fails.
            overwrite (bool): If False, skip the download when the target
                filename already exists in storage.
            **kwargs: reserved arguments for overriding, forwarded to
                :meth:`keep_file`.
        """
        file_url = task['file_url']
        task['success'] = False
        task['filename'] = None
        retry = max_retry

        if not overwrite:
            with self.lock:
                # Tentatively claim the next index so get_filename() sees the
                # would-be filename for this task.
                self.fetched_num += 1
                filename = self.get_filename(task, default_ext)
                if self.storage.exists(filename):
                    self.logger.info('skip downloading file %s', filename)
                    return
                # Nothing stored yet: roll the counter back; it is claimed
                # again below once the download actually succeeds.
                self.fetched_num -= 1

        while retry > 0 and not self.signal.get('reach_max_num'):
            try:
                response = self.session.get(file_url, timeout=timeout)
            except Exception as e:
                self.logger.error('Exception caught when downloading file %s, '
                                  'error: %s, remaining retry times: %d',
                                  file_url, e, retry - 1)
            else:
                if self.reach_max_num():
                    self.signal.set(reach_max_num=True)
                    break
                elif response.status_code != 200:
                    self.logger.error('Response status code %d, file %s',
                                      response.status_code, file_url)
                    break
                elif not self.keep_file(task, response, **kwargs):
                    break
                # Counter increment and filename generation must happen under
                # the same lock so concurrent workers get distinct indices.
                with self.lock:
                    self.fetched_num += 1
                    filename = self.get_filename(task, default_ext)
                self.logger.info('image #%s\t%s', self.fetched_num, file_url)
                self.storage.write(filename, response.content)
                task['success'] = True
                task['filename'] = filename
                break
            finally:
                retry -= 1

    def start(self, file_idx_offset=0, *args, **kwargs):
        """Reset counters and spawn the downloader worker threads."""
        self.clear_status()
        self.set_file_idx_offset(file_idx_offset)
        self.init_workers(*args, **kwargs)
        for worker in self.workers:
            worker.start()
            self.logger.debug('thread %s started', worker.name)

    def worker_exec(self,
                    max_num,
                    default_ext='',
                    queue_timeout=5,
                    req_timeout=5,
                    **kwargs):
        """Target method of workers.

        Get task from ``task_queue`` and then download files and process meta
        data. A downloader thread will exit in either of the following cases:

        1. All parser threads have exited and the task_queue is empty.
        2. Downloaded image number has reached required number(max_num).

        Args:
            max_num (int): Maximum number of images to download; <= 0 means
                no limit.
            default_ext (str): Default file extension.
            queue_timeout (int): Timeout of getting tasks from ``task_queue``.
            req_timeout (int): Timeout of making requests for downloading pages.
            **kwargs: Arguments passed to the :func:`download` method.
        """
        self.max_num = max_num
        while True:
            if self.signal.get('reach_max_num'):
                self.logger.info('downloaded images reach max num, thread %s'
                                 ' is ready to exit',
                                 current_thread().name)
                break
            try:
                task = self.in_queue.get(timeout=queue_timeout)
            except queue.Empty:
                if self.signal.get('parser_exited'):
                    self.logger.info('no more download task for thread %s',
                                     current_thread().name)
                    break
                else:
                    self.logger.info('%s is waiting for new download tasks',
                                     current_thread().name)
            except Exception:
                # Was a bare ``except:`` that discarded the traceback and
                # also trapped SystemExit/KeyboardInterrupt; logger.exception
                # records the full stack trace for debugging.
                self.logger.exception('exception in thread %s',
                                      current_thread().name)
            else:
                self.download(task, default_ext, req_timeout, **kwargs)
                self.process_meta(task)
                self.in_queue.task_done()
        self.logger.info('thread {} exit'.format(current_thread().name))

    def __exit__(self, exc_type, exc_val, exc_tb):
        # NOTE(review): no matching __enter__ visible here — presumably the
        # context-manager protocol is completed by ThreadPool; confirm.
        self.logger.info('all downloader threads exited')
class ImageDownloader(Downloader):
    """Downloader specified for images.
    """

    def _size_lt(self, sz1, sz2):
        # True when sz1 fits within sz2 on both the longer and shorter side.
        long_ok = max(sz1) <= max(sz2)
        short_ok = min(sz1) <= min(sz2)
        return long_ok and short_ok

    def _size_gt(self, sz1, sz2):
        # True when sz1 covers sz2 on both the longer and shorter side.
        long_ok = max(sz1) >= max(sz2)
        short_ok = min(sz1) >= min(sz2)
        return long_ok and short_ok

    def keep_file(self, task, response, min_size=None, max_size=None):
        """Decide whether to keep the image

        Compare image size with ``min_size`` and ``max_size`` to decide.

        Args:
            response (Response): response of requests.
            min_size (tuple or None): minimum size of required images.
            max_size (tuple or None): maximum size of required images.
        Returns:
            bool: whether to keep the image.
        """
        try:
            img = Image.open(BytesIO(response.content))
        except (IOError, OSError):
            # Not a decodable image; discard it.
            return False
        size = img.size
        task['img_size'] = size
        # Each bound is only enforced when it is truthy (a non-empty tuple).
        for bound, satisfies in ((min_size, self._size_gt),
                                 (max_size, self._size_lt)):
            if bound and not satisfies(size, bound):
                return False
        return True

    def get_filename(self, task, default_ext):
        """Build the target filename: 6-digit index plus an image extension.

        The url's extension is kept only when it is a recognized image
        format; otherwise ``default_ext`` is used.
        """
        path = urlparse(task['file_url'])[2]
        known_formats = ('jpg', 'jpeg', 'png', 'bmp', 'tiff', 'gif', 'ppm',
                         'pgm')
        extension = default_ext
        if '.' in path:
            candidate = path.rsplit('.', 1)[-1]
            if candidate.lower() in known_formats:
                extension = candidate
        index = self.fetched_num + self.file_idx_offset
        return '%06d.%s' % (index, extension)

    def worker_exec(self,
                    max_num,
                    default_ext='jpg',
                    queue_timeout=5,
                    req_timeout=5,
                    **kwargs):
        """Delegate to the base worker loop with 'jpg' as default extension."""
        super(ImageDownloader, self).worker_exec(
            max_num, default_ext, queue_timeout, req_timeout, **kwargs)