Source code for layoutparser.models.detectron2.layoutmodel

# Copyright 2021 The Layout Parser team. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
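# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.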
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from PIL import Image
import numpy as np
import warnings

from .catalog import MODEL_CATALOG, PathManager, LABEL_MAP_CATALOG
from ..base_layoutmodel import BaseLayoutModel
from ...elements import Rectangle, TextBlock, Layout
from ...file_utils import is_torch_cuda_available, is_detectron2_available

if is_detectron2_available():
    import detectron2.engine
    import detectron2.config

__all__ = ["Detectron2LayoutModel"]

class Detectron2LayoutModel(BaseLayoutModel):
    """Create a Detectron2-based Layout Detection Model

    Args:
        config_path (:obj:`str`):
            The path to the configuration file.
        model_path (:obj:`str`, None):
            The path to the saved weights of the model.
            If set, overwrite the weights in the configuration file.
            Defaults to `None`.
        label_map (:obj:`dict`, optional):
            The map from the model prediction (ids) to real-world labels
            (strings). If the config is from one of the supported datasets,
            Layout Parser will automatically initialize the label_map.
            Defaults to `None`.
        extra_config (:obj:`list`, optional):
            Extra configuration passed to the Detectron2 model
            configuration. The argument will be used in the `merge_from_list
            <https://detectron2.readthedocs.io/modules/config.html#detectron2.config.CfgNode.merge_from_list>`_
            function.
            Defaults to `[]`.
        enforce_cpu (:obj:`bool`, optional):
            Deprecated; set `device` instead. A truthy value is still
            honored as a request for the CPU when `device` is not set.
        device (:obj:`str`, optional):
            Whether to use cuda or cpu devices. If not set, LayoutParser
            will automatically determine the device to initialize the
            models on.

    Examples::
        >>> import layoutparser as lp
        >>> model = lp.Detectron2LayoutModel('lp://HJDataset/faster_rcnn_R_50_FPN_3x/config')
        >>> model.detect(image)

    """

    DEPENDENCIES = ["detectron2"]
    DETECTOR_NAME = "detectron2"
    MODEL_CATALOG = MODEL_CATALOG

    def __init__(
        self,
        config_path,
        model_path=None,
        label_map=None,
        extra_config=None,
        enforce_cpu=None,
        device=None,
    ):
        if enforce_cpu is not None:
            warnings.warn(
                "Setting enforce_cpu is deprecated. Please set `device` instead.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )

        if extra_config is None:
            extra_config = []

        config_path, model_path = self.config_parser(
            config_path, model_path, allow_empty_path=True
        )

        # The label map must be derived from the ``lp://`` identifier BEFORE
        # the path is handed to PathManager: once it has been replaced by a
        # resolved local path, the dataset name can no longer be read from it.
        if label_map is None:
            label_map = self._label_map_from_config_path(config_path)

        config_path = PathManager.get_local_path(config_path)

        cfg = detectron2.config.get_cfg()
        cfg.merge_from_file(config_path)
        cfg.merge_from_list(extra_config)

        if model_path is not None:
            model_path = PathManager.get_local_path(model_path)
            # Because it will be forwarded to the detectron2 paths
            cfg.MODEL.WEIGHTS = model_path

        if not is_torch_cuda_available():
            # Without CUDA the model is always placed on the CPU, whatever
            # device was requested.
            device = "cpu"
        elif device is None:
            # Keep the deprecated `enforce_cpu` flag working instead of
            # silently ignoring it; otherwise default to the GPU.
            device = "cpu" if enforce_cpu else "cuda"
        cfg.MODEL.DEVICE = device

        self.cfg = cfg

        self.label_map = label_map
        self._create_model()

    @staticmethod
    def _label_map_from_config_path(config_path):
        """Return the catalog label map for an ``lp://`` config path.

        The dataset name is taken from the second ``/``-separated segment
        after the ``lp://`` prefix. An empty dict is returned when the path
        is not an ``lp://`` path, has fewer than two segments, or names a
        dataset that is not present in ``LABEL_MAP_CATALOG``.
        """
        prefix = "lp://"
        if not (isinstance(config_path, str) and config_path.startswith(prefix)):
            return {}

        # Slice the prefix off explicitly: ``str.lstrip("lp://")`` strips a
        # *set of characters* and would also eat leading 'l', 'p', ':' or '/'
        # characters belonging to the next path segment.
        segments = config_path[len(prefix):].split("/")
        if len(segments) < 2:
            return {}
        return LABEL_MAP_CATALOG.get(segments[1], {})

    def _create_model(self):
        """Build the Detectron2 predictor from ``self.cfg``."""
        self.model = detectron2.engine.DefaultPredictor(self.cfg)

    def gather_output(self, outputs):
        """Convert the predictor output into a :obj:`~layoutparser.Layout`.

        Args:
            outputs (:obj:`dict`):
                The predictor output; its ``"instances"`` entry is read for
                ``scores``, ``pred_boxes`` and ``pred_classes``.

        Returns:
            :obj:`~layoutparser.Layout`:
                One :obj:`TextBlock` per predicted instance. Class ids are
                translated through ``self.label_map``; ids missing from the
                map are kept unchanged.
        """
        instance_pred = outputs["instances"].to("cpu")

        layout = Layout()
        scores = instance_pred.scores.tolist()
        boxes = instance_pred.pred_boxes.tensor.tolist()
        labels = instance_pred.pred_classes.tolist()

        for score, box, label in zip(scores, boxes, labels):
            x_1, y_1, x_2, y_2 = box
            # Fall back to the raw class id when no label is registered.
            label = self.label_map.get(label, label)
            cur_block = TextBlock(
                Rectangle(x_1, y_1, x_2, y_2), type=label, score=score
            )
            layout.append(cur_block)

        return layout

    def detect(self, image):
        """Detect the layout of a given image.

        Args:
            image (:obj:`np.ndarray` or `PIL.Image`): The input image to detect.

        Returns:
            :obj:`~layoutparser.Layout`: The detected layout of the input image
        """
        image = self.image_loader(image)
        outputs = self.model(image)
        layout = self.gather_output(outputs)
        return layout

    def image_loader(self, image: Union["np.ndarray", "Image.Image"]):
        """Return `image` as a numpy array when it is a PIL image.

        PIL images are converted to RGB mode first if necessary. Any other
        input is returned unchanged.
        """
        # Convert PIL Image Input
        if isinstance(image, Image.Image):
            if image.mode != "RGB":
                image = image.convert("RGB")
            image = np.array(image)

        return image