Source code for icrawler.downloader

# -*- coding: utf-8 -*-

from threading import current_thread

from PIL import Image
from six import BytesIO
from six.moves import queue
from six.moves.urllib.parse import urlparse

from icrawler.utils import ThreadPool


class Downloader(ThreadPool):
    """Base class for downloader.

    A thread pool of downloader threads, in charge of downloading files and
    saving them in the corresponding paths.

    Attributes:
        task_queue (CachedQueue): A queue storing image downloading tasks,
            connecting :class:`Parser` and :class:`Downloader`.
        signal (Signal): A Signal object shared by all components.
        session (Session): A session object.
        logger: A logging.Logger object used for logging.
        workers (list): A list of downloader threads.
        thread_num (int): The number of downloader threads.
        lock (Lock): A threading.Lock object.
        storage (BaseStorage): storage backend.
        file_idx_offset (int): Offset added to ``fetched_num`` when building
            filenames.
        fetched_num (int): Counter of files handled so far.
        max_num (int): Upper bound on ``fetched_num``; values <= 0 disable
            the bound in :meth:`reach_max_num`.
    """

    def __init__(self, thread_num, signal, session, storage):
        """Init Downloader with some shared variables.

        Args:
            thread_num (int): The number of downloader threads.
            signal (Signal): A Signal object shared by all components.
            session (Session): A session object used to issue requests.
            storage (BaseStorage): storage backend.
        """
        super(Downloader, self).__init__(
            thread_num, out_queue=None, name='downloader')
        self.signal = signal
        self.session = session
        self.storage = storage
        self.file_idx_offset = 0
        # ``worker_exec`` overwrites this; the default keeps
        # ``reach_max_num`` usable when ``download`` is called directly.
        self.max_num = 0
        self.clear_status()

    def clear_status(self):
        """Reset fetched_num to 0."""
        self.fetched_num = 0

    def set_file_idx_offset(self, file_idx_offset=0):
        """Set offset of file index.

        Args:
            file_idx_offset: It can be either an integer or 'auto'. If set
                to an integer, the filename will start from
                ``file_idx_offset`` + 1. If set to ``'auto'``, the filename
                will start from existing max file index plus 1.

        Raises:
            ValueError: If ``file_idx_offset`` is neither an integer nor
                ``'auto'``.
        """
        if isinstance(file_idx_offset, int):
            self.file_idx_offset = file_idx_offset
        elif file_idx_offset == 'auto':
            self.file_idx_offset = self.storage.max_file_idx()
        else:
            raise ValueError('"file_idx_offset" must be an integer or `auto`')

    def get_filename(self, task, default_ext):
        """Set the path where the image will be saved.

        The default strategy is to use an increasing 6-digit number as
        the filename. You can override this method if you want to set custom
        naming rules. The file extension is kept if it can be obtained from
        the last segment of the url path, otherwise ``default_ext`` is used
        as extension.

        Args:
            task (dict): The task dict got from ``task_queue``.
            default_ext (str): Extension used when the url has none.

        Returns:
            str: Filename with extension.
        """
        url_path = urlparse(task['file_url'])[2]
        # Only the last path segment can carry a file extension. A dot in a
        # directory name (e.g. ``/v1.0/image``) must not be mistaken for one,
        # otherwise the "extension" would contain a path separator.
        basename = url_path.rsplit('/', 1)[-1]
        extension = basename.rsplit('.', 1)[-1] if '.' in basename else ''
        if not extension:
            extension = default_ext
        file_idx = self.fetched_num + self.file_idx_offset
        return '{:06d}.{}'.format(file_idx, extension)

    def reach_max_num(self):
        """Check if downloaded images reached max num.

        Returns:
            bool: if downloaded images reached max num.
        """
        if self.signal.get('reach_max_num'):
            return True
        return self.max_num > 0 and self.fetched_num >= self.max_num

    def keep_file(self, task, response, **kwargs):
        """Decide whether a downloaded file should be stored.

        The base implementation keeps everything; subclasses may override.

        Args:
            task (dict): The task dict got from ``task_queue``.
            response: The response returned by ``session.get``.
            **kwargs: Extra arguments forwarded from :meth:`download`.

        Returns:
            bool: Always ``True`` in the base class.
        """
        return True

    def download(self,
                 task,
                 default_ext,
                 timeout=5,
                 max_retry=3,
                 overwrite=False,
                 **kwargs):
        """Download the image and save it to the corresponding path.

        On return ``task['success']`` tells whether a file was written and
        ``task['filename']`` holds its name (``None`` otherwise).

        Args:
            task (dict): The task dict got from ``task_queue``.
            default_ext (str): Extension used when the url has none.
            timeout (int): Timeout of making requests for downloading images.
            max_retry (int): the max retry times if the request fails.
            overwrite (bool): If ``False``, skip the download when the file
                that would be written already exists in the storage.
            **kwargs: reserved arguments for overriding; forwarded to
                :meth:`keep_file`.
        """
        file_url = task['file_url']
        task['success'] = False
        task['filename'] = None
        retry = max_retry

        if not overwrite:
            with self.lock:
                # Tentatively take the next index to learn which filename
                # this task would get.
                self.fetched_num += 1
                filename = self.get_filename(task, default_ext)
                if self.storage.exists(filename):
                    # The counter stays incremented here, so the existing
                    # file's index is consumed.
                    self.logger.info('skip downloading file %s', filename)
                    return
                self.fetched_num -= 1

        while retry > 0 and not self.signal.get('reach_max_num'):
            try:
                response = self.session.get(file_url, timeout=timeout)
            except Exception as e:
                self.logger.error('Exception caught when downloading file %s, '
                                  'error: %s, remaining retry times: %d',
                                  file_url, e, retry - 1)
            else:
                if self.reach_max_num():
                    self.signal.set(reach_max_num=True)
                    break
                elif response.status_code != 200:
                    self.logger.error('Response status code %d, file %s',
                                      response.status_code, file_url)
                    break
                elif not self.keep_file(task, response, **kwargs):
                    break
                with self.lock:
                    self.fetched_num += 1
                    # Snapshot the counter while holding the lock so the log
                    # line below reports the number assigned to this file.
                    fetched_idx = self.fetched_num
                    filename = self.get_filename(task, default_ext)
                self.logger.info('image #%s\t%s', fetched_idx, file_url)
                self.storage.write(filename, response.content)
                task['success'] = True
                task['filename'] = filename
                break
            finally:
                # Runs on every path, including the ``break`` statements.
                retry -= 1

    def process_meta(self, task):
        """Process some meta data of the images.

        This method should be overridden by users if wanting to do more things
        other than just downloading the image, such as saving annotations.

        Args:
            task (dict): The task dict got from task_queue. This method will
                make use of fields other than ``file_url`` in the dict.
        """
        pass

    def start(self, file_idx_offset=0, *args, **kwargs):
        """Reset the status, set the index offset and start all workers.

        Args:
            file_idx_offset: An integer or ``'auto'``, see
                :meth:`set_file_idx_offset`.
            *args: Positional arguments passed to ``init_workers``.
            **kwargs: Keyword arguments passed to ``init_workers``.
        """
        self.clear_status()
        self.set_file_idx_offset(file_idx_offset)
        self.init_workers(*args, **kwargs)
        for worker in self.workers:
            worker.start()
            self.logger.debug('thread %s started', worker.name)

    def worker_exec(self,
                    max_num,
                    default_ext='',
                    queue_timeout=5,
                    req_timeout=5,
                    **kwargs):
        """Target method of workers.

        Get task from ``task_queue`` and then download files and process meta
        data. A downloader thread will exit in either of the following cases:

        1. All parser threads have exited and the task_queue is empty.
        2. Downloaded image number has reached required number(max_num).

        Args:
            max_num (int): Max number of files to fetch; stored in
                ``self.max_num``.
            default_ext (str): Extension used when the url has none.
            queue_timeout (int): Timeout of getting tasks from ``task_queue``.
            req_timeout (int): Timeout of making requests for downloading
                pages.
            **kwargs: Arguments passed to the :func:`download` method.
        """
        self.max_num = max_num
        while True:
            if self.signal.get('reach_max_num'):
                self.logger.info('downloaded images reach max num, thread %s'
                                 ' is ready to exit', current_thread().name)
                break
            try:
                task = self.in_queue.get(timeout=queue_timeout)
            except queue.Empty:
                if self.signal.get('parser_exited'):
                    self.logger.info('no more download task for thread %s',
                                     current_thread().name)
                    break
                else:
                    self.logger.info('%s is waiting for new download tasks',
                                     current_thread().name)
            except Exception:
                # Not a bare ``except``: SystemExit/KeyboardInterrupt must
                # propagate instead of being swallowed by the loop.
                self.logger.exception('exception in thread %s',
                                      current_thread().name)
            else:
                self.download(task, default_ext, req_timeout, **kwargs)
                self.process_meta(task)
                self.in_queue.task_done()
        self.logger.info('thread %s exit', current_thread().name)

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Log that all downloader threads have exited."""
        self.logger.info('all downloader threads exited')
class ImageDownloader(Downloader):
    """Downloader specified for images."""

    def _size_lt(self, sz1, sz2):
        """Return whether both the longer and shorter side of ``sz1`` are
        no greater than the corresponding sides of ``sz2``."""
        longer_fits = max(sz1) <= max(sz2)
        return longer_fits and min(sz1) <= min(sz2)

    def _size_gt(self, sz1, sz2):
        """Return whether both the longer and shorter side of ``sz1`` are
        no smaller than the corresponding sides of ``sz2``."""
        longer_covers = max(sz1) >= max(sz2)
        return longer_covers and min(sz1) >= min(sz2)

    def keep_file(self, task, response, min_size=None, max_size=None):
        """Decide whether to keep the image.

        The response body is opened as an image; its size is recorded in
        ``task['img_size']`` and compared with ``min_size`` and ``max_size``.

        Args:
            task (dict): The task dict; ``img_size`` is written into it.
            response (Response): response of requests.
            min_size (tuple or None): minimum size of required images.
            max_size (tuple or None): maximum size of required images.

        Returns:
            bool: whether to keep the image. ``False`` if the content cannot
            be opened as an image or violates a size bound.
        """
        raw = BytesIO(response.content)
        try:
            img = Image.open(raw)
        except (IOError, OSError):
            return False

        dimensions = img.size
        task['img_size'] = dimensions

        too_small = bool(min_size) and not self._size_gt(dimensions, min_size)
        if too_small:
            return False
        too_large = bool(max_size) and not self._size_lt(dimensions, max_size)
        return not too_large

    def get_filename(self, task, default_ext):
        """Build the filename for a task, restricted to image extensions.

        The text after the last dot of the url path is used as extension
        only if it is one of the recognised image extensions
        (case-insensitive); otherwise ``default_ext`` is used.

        Args:
            task (dict): The task dict got from ``task_queue``.
            default_ext (str): Fallback extension.

        Returns:
            str: Zero-padded 6-digit index followed by the extension.
        """
        recognised = ('jpg', 'jpeg', 'png', 'bmp', 'tiff', 'gif', 'ppm',
                      'pgm')
        url_path = urlparse(task['file_url'])[2]

        extension = default_ext
        if '.' in url_path:
            candidate = url_path.split('.')[-1]
            if candidate.lower() in recognised:
                extension = candidate

        file_idx = self.fetched_num + self.file_idx_offset
        return '{:06d}.{}'.format(file_idx, extension)

    def worker_exec(self,
                    max_num,
                    default_ext='jpg',
                    queue_timeout=5,
                    req_timeout=5,
                    **kwargs):
        """Run the parent worker loop with ``'jpg'`` as default extension.

        All arguments are passed through to the parent ``worker_exec``.
        """
        parent = super(ImageDownloader, self)
        parent.worker_exec(max_num, default_ext, queue_timeout, req_timeout,
                           **kwargs)