# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Feature extractor class for Perceiver."""
from typing import Optional, Union
import numpy as np
from PIL import Image
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
from ...file_utils import TensorType
from ...image_utils import (
IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD,
ImageFeatureExtractionMixin,
ImageInput,
is_torch_tensor,
)
from ...utils import logging
logger = logging.get_logger(__name__)
class PerceiverFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
r"""
Constructs a Perceiver feature extractor.
This feature extractor inherits from [`ImageFeatureExtractionMixin`] which contains most of the
main methods. Users should refer to this superclass for more information regarding those methods.
Args:
do_center_crop (`bool`, *optional*, defaults to `True`):
Whether to crop the input at the center. If the input size is smaller than `crop_size` along any edge,
the image is padded with 0's and then center cropped.
crop_size (`int`, *optional*, defaults to 256):
Desired output size when applying center-cropping. Only has an effect if `do_center_crop` is set to
`True`.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the input to a certain `size`.
size (`int` or `Tuple(int)`, *optional*, defaults to 224):
Resize the input to the given size. If a tuple is provided, it should be (width, height). If only an
integer is provided, then the input will be resized to (size, size). Only has an effect if `do_resize`
is set to `True`.
resample (`int`, *optional*, defaults to `PIL.Image.BICUBIC`):
An optional resampling filter. This can be one of `PIL.Image.NEAREST`, `PIL.Image.BOX`,
`PIL.Image.BILINEAR`, `PIL.Image.HAMMING`, `PIL.Image.BICUBIC` or `PIL.Image.LANCZOS`.
Only has an effect if `do_resize` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether or not to normalize the input with `image_mean` and `image_std`.
image_mean (`List[int]`, defaults to `[0.485, 0.456, 0.406]`):
The sequence of means for each channel, to be used when normalizing images.
image_std (`List[int]`, defaults to `[0.229, 0.224, 0.225]`):
The sequence of standard deviations for each channel, to be used when normalizing images.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_center_crop=True,
crop_size=256,
do_resize=True,
size=224,
resample=Image.BICUBIC,
do_normalize=True,
image_mean=None,
image_std=None,
**kwargs
):
super().__init__(**kwargs)
self.do_center_crop = do_center_crop
self.crop_size = crop_size
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
def center_crop(self, image):
"""
Crops `image` to *self.crop_size* using a center crop. Note that if the image is too small to be cropped
to the size given, it will be padded (so the returned result has the size asked).
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to resize.
"""
if isinstance(image, Image.Image):
image = self.to_numpy_array(image)
image_height, image_width = image.shape[-2:]
padded_center_crop_size = (
(self.size / (self.crop_size)) * np.minimum(image_height, image_width).astype(np.float32)
).astype(np.int32)
offset_height = ((image_height - padded_center_crop_size) + 1) // 2
offset_width = ((image_width - padded_center_crop_size) + 1) // 2
crop_window = [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size]
image = image[
:, crop_window[0] : crop_window[0] + crop_window[2], crop_window[1] : crop_window[1] + crop_window[3]
]
return image
def __call__(
self, images: ImageInput, return_tensors: Optional[Union[str, TensorType]] = None, **kwargs
) -> BatchFeature:
"""
Main method to prepare for the model one or several image(s).
<Tip warning={true}>
NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass
PIL images.
</Tip>
Args:
images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a
number of channels, H and W are image height and width.
return_tensors (`str` or [`~file_utils.TensorType`], *optional*, defaults to `'np'`):
If set, will return tensors of a particular framework. Acceptable values are:
- `'tf'`: Return TensorFlow `tf.constant` objects.
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
Returns:
[`BatchFeature`]: A [`BatchFeature`] with the following fields:
- **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height,
width).
"""
# Input type checking for clearer error
valid_images = False
# Check that images has a valid type
if isinstance(images, (Image.Image, np.ndarray)) or is_torch_tensor(images):
valid_images = True
elif isinstance(images, (list, tuple)):
if len(images) == 0 or isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]):
valid_images = True
if not valid_images:
raise ValueError(
"Images must of type `PIL.Image.Image`, `np.ndarray` or `torch.Tensor` (single example),"
"`List[PIL.Image.Image]`, `List[np.ndarray]` or `List[torch.Tensor]` (batch of examples)."
)
is_batched = bool(
isinstance(images, (list, tuple))
and (isinstance(images[0], (Image.Image, np.ndarray)) or is_torch_tensor(images[0]))
)
if not is_batched:
images = [images]
# transformations (center cropping + resizing + normalization)
if self.do_center_crop and self.crop_size is not None:
images = [self.center_crop(image) for image in images]
if self.do_resize and self.size is not None and self.resample is not None:
images = [self.resize(image=image, size=self.size, resample=self.resample) for image in images]
if self.do_normalize:
images = [self.normalize(image=image, mean=self.image_mean, std=self.image_std) for image in images]
# return as BatchFeature
data = {"pixel_values": images}
encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors)
return encoded_inputs