fiftyone.utils.torch¶
PyTorch utilities.
Functions:
load_torch_hub_image_model(repo_or_dir, model, hub_kwargs=None, **kwargs) – Loads an image model from PyTorch Hub as a TorchImageModel.
load_torch_hub_raw_model(*args, **kwargs) – Loads a raw model from PyTorch Hub as a torch.nn.Module.
find_torch_hub_requirements(repo_or_dir, source='github') – Locates the requirements.txt file on disk associated with a downloaded PyTorch Hub model.
load_torch_hub_requirements(repo_or_dir, source='github') – Loads the package requirements from the requirements.txt file on disk associated with a downloaded PyTorch Hub model.
install_torch_hub_requirements(repo_or_dir, source='github', error_level=None) – Installs the package requirements from the requirements.txt file on disk associated with a downloaded PyTorch Hub model.
ensure_torch_hub_requirements(repo_or_dir, source='github', error_level=None, log_success=False) – Verifies that the package requirements from the requirements.txt file on disk associated with a downloaded PyTorch Hub model are installed.
recommend_num_workers() – Recommends a number of workers for running a torch.utils.data.DataLoader.
from_image_classification_dir_tree(dataset_dir) – Creates a torch.utils.data.Dataset for the given image classification dataset directory tree.
get_local_size(local_process_group) – Gets the number of processes per-machine in the local process group.
get_world_size() – Returns the world size of the current operation.
get_local_rank(local_process_group) – Gets the rank of the current process within the local process group.
get_rank() – Gets the rank of the current process.
local_scatter(array, local_process_group) – Scatters the given array from the local leader to all local workers.
all_gather(data, group=None) – Gathers arbitrary picklable data (not necessarily tensors).
local_broadcast_process_authkey(local_process_group)
Classes:
GetItem – A class that defines how to load the input for a model.
TorchEmbeddingsMixin – Mixin for Torch models that can generate embeddings.
TorchImageModelConfig – Configuration for running a TorchImageModel.
ImageGetItem – A GetItem that loads images to feed to TorchImageModel instances.
TorchImageModel – Wrapper for evaluating a Torch model on images.
TorchSamplesMixin
ToPILImage – Transform that converts a tensor or ndarray to a PIL image, while also allowing PIL images to pass through.
MinResize – Transform that resizes the PIL image or torch Tensor, if necessary, so that its minimum dimensions are at least the specified size.
MaxResize – Transform that resizes the PIL image or torch Tensor, if necessary, so that its maximum dimensions are at most the specified size.
PatchSize – Transform that center crops the PIL image or torch Tensor, if necessary, so that its dimensions are multiples of the specified patch size.
SaveLayerTensor – Callback that saves the input/output tensor of the specified layer of a Torch model during each forward() call.
OutputProcessor – Interface for processing the outputs of Torch models.
ClassifierOutputProcessor – Output processor for single label classifiers.
DetectorOutputProcessor – Output processor for object detectors.
InstanceSegmenterOutputProcessor – Output processor for instance segmenters.
KeypointDetectorOutputProcessor – Output processor for keypoint detection models.
SemanticSegmenterOutputProcessor – Output processor for semantic segmenters.
FiftyOneTorchDataset – Constructs a torch.utils.data.Dataset that loads data from an arbitrary fiftyone.core.collections.SampleCollection via a GetItem instance.
TorchImageDataset – A torch.utils.data.Dataset of images.
TorchImageClassificationDataset – A torch.utils.data.Dataset for image classification.
TorchImagePatchesDataset – A torch.utils.data.Dataset of image patch tensors extracted from a list of images.
NumpySerializedList
TorchSerializedList
TorchShmSerializedList
-
fiftyone.utils.torch.
load_torch_hub_image_model
(repo_or_dir, model, hub_kwargs=None, **kwargs)¶ Loads an image model from PyTorch Hub as a
TorchImageModel
.Example usage:
import fiftyone.utils.torch as fout

model = fout.load_torch_hub_image_model(
    "facebookresearch/dinov2",
    "dinov2_vits14",
    image_patch_size=14,
    embeddings_layer="head",
)

assert model.has_embeddings is True
- Parameters
repo_or_dir – see
torch.hub.load
model – see
torch.hub.load
**kwargs – additional parameters for
TorchImageModelConfig
- Returns
a TorchImageModel
-
fiftyone.utils.torch.
load_torch_hub_raw_model
(*args, **kwargs)¶ Loads a raw model from PyTorch Hub as a
torch.nn.Module
.Example usage:
import fiftyone.utils.torch as fout

model = fout.load_torch_hub_raw_model(
    "facebookresearch/dinov2",
    "dinov2_vits14",
)

print(type(model))
# <class 'dinov2.models.vision_transformer.DinoVisionTransformer'>
- Parameters
*args – positional arguments for
torch.hub.load
**kwargs – keyword arguments for
torch.hub.load
- Returns
a torch.nn.Module
-
fiftyone.utils.torch.
find_torch_hub_requirements
(repo_or_dir, source='github')¶ Locates the
requirements.txt
file on disk associated with a downloaded PyTorch Hub model.Example usage:
import fiftyone.utils.torch as fout

req_path = fout.find_torch_hub_requirements("facebookresearch/dinov2")

print(req_path)
# '~/.cache/torch/hub/facebookresearch_dinov2_main/requirements.txt'
- Parameters
repo_or_dir – see
torch.hub.load
source ("github") – see
torch.hub.load
- Returns
the path to the requirements file on disk
-
fiftyone.utils.torch.
load_torch_hub_requirements
(repo_or_dir, source='github')¶ Loads the package requirements from the
requirements.txt
file on disk associated with a downloaded PyTorch Hub model.Example usage:
import fiftyone.utils.torch as fout

requirements = fout.load_torch_hub_requirements("facebookresearch/dinov2")

print(requirements)
# ['torch==2.0.0', 'torchvision==0.15.0', ...]
- Parameters
repo_or_dir – see
torch.hub.load
source ("github") – see
torch.hub.load
- Returns
a list of requirement strings
-
fiftyone.utils.torch.
install_torch_hub_requirements
(repo_or_dir, source='github', error_level=None)¶ Installs the package requirements from the
requirements.txt
file on disk associated with a downloaded PyTorch Hub model.Example usage:
import fiftyone.utils.torch as fout

fout.install_torch_hub_requirements("facebookresearch/dinov2")
- Parameters
repo_or_dir – see
torch.hub.load
source ("github") – see
torch.hub.load
error_level (None) –
the error level to use, defined as:
0: raise error if the install fails
1: log warning if the install fails
2: ignore install failures
By default,
fiftyone.config.requirement_error_level
is used
-
fiftyone.utils.torch.
ensure_torch_hub_requirements
(repo_or_dir, source='github', error_level=None, log_success=False)¶ Verifies that the package requirements from the
requirements.txt
file on disk associated with a downloaded PyTorch Hub model are installed.Example usage:
import fiftyone.utils.torch as fout

fout.ensure_torch_hub_requirements("facebookresearch/dinov2")
- Parameters
repo_or_dir – see
torch.hub.load
source ("github") – see
torch.hub.load
error_level (None) –
the error level to use, defined as:
0: raise error if requirement is not satisfied
1: log warning if requirement is not satisfied
2: ignore unsatisfied requirements
By default,
fiftyone.config.requirement_error_level
is used
log_success (False) – whether to generate a log message if a requirement is satisfied
-
class
fiftyone.utils.torch.
GetItem
(field_mapping=None, **kwargs)¶ Bases:
object
A class that defines how to load the input for a model.
Models that implement the
fiftyone.core.models.SupportsGetItem
mixin use this class to define how FiftyOneTorchDataset
should load their inputs. The
__call__()
method should accept a dictionary mapping the keys defined by required_keys
to values extracted from the input fiftyone.core.sample.Sample
instance according to the mapping defined by field_mapping
.- Parameters
field_mapping (None) – a user-supplied dict mapping keys in
required_keys
to field names of their dataset that contain the required values
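For illustration, here is a minimal sketch of a custom GetItem subclass that loads an image and a scalar label; the key names, field names, and return format are hypothetical placeholders rather than part of the documented API:

from PIL import Image

import fiftyone.utils.torch as fout


class MyGetItem(fout.GetItem):
    """Hypothetical GetItem that loads an image and a scalar label."""

    @property
    def required_keys(self):
        # keys that the dicts passed to __call__() must contain
        return ["image_path", "label"]

    def __call__(self, d):
        # `d` maps each required key to the value pulled from the sample
        img = Image.open(d["image_path"]).convert("RGB")
        return {"image": img, "label": d["label"]}


# map each required key to the dataset field that holds its value
get_item = MyGetItem(
    field_mapping={"image_path": "filepath", "label": "ground_truth.label"}
)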
Attributes:
The list of keys that must exist on the dicts provided to the
__call__()
method at runtime. A user-supplied dictionary mapping keys in
required_keys
to field names of their dataset that contain the required values.-
property
required_keys
¶ The list of keys that must exist on the dicts provided to the
__call__()
method at runtime. The user supplies the field names from which to extract these values from their samples via
field_mapping
.
-
property
field_mapping
¶ A user-supplied dictionary mapping keys in
required_keys
to field names of their dataset that contain the required values.
-
class
fiftyone.utils.torch.
TorchEmbeddingsMixin
(model, layer_name=None, as_feature_extractor=False)¶ Bases:
fiftyone.core.models.EmbeddingsMixin
Mixin for Torch models that can generate embeddings.
- Parameters
model – the Torch model, a
torch.nn.Module
layer_name (None) – the name of the embeddings layer whose output to save, or
None
if this model instance should not expose embeddings. Prepend"<"
to save the input tensor insteadas_feature_extractor (False) – whether to operate the model as a feature extractor. If
layer_name
is provided, this layer is passed to torchvision’screate_feature_extractor()
function. If nolayer_name
is provided, the model’s output is used as-is for feature extraction
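For illustration, a hedged sketch of generating embeddings through this mixin, reusing the PyTorch Hub example from load_torch_hub_image_model() above; the random image is a placeholder:

import numpy as np

import fiftyone.utils.torch as fout

model = fout.load_torch_hub_image_model(
    "facebookresearch/dinov2",
    "dinov2_vits14",
    image_patch_size=14,
    embeddings_layer="head",
)

img = np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)  # placeholder image

embedding = model.embed(img)          # numpy array for a single input
embeddings = model.embed_all([img])   # embeddings stacked along axis 0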
Attributes:
Whether this instance has embeddings.
Methods:
embed
(arg)Generates an embedding for the given data.
embed_all
(args)Generates embeddings for the given iterable of data.
Returns the embeddings generated by the last forward pass of the model.
-
property
has_embeddings
¶ Whether this instance has embeddings.
-
embed
(arg)¶ Generates an embedding for the given data.
Subclasses can override this method to increase efficiency, but, by default, this method simply calls
predict()
and then returnsget_embeddings()
.- Parameters
arg – the data. See
predict()
for details- Returns
a numpy array containing the embedding
-
embed_all
(args)¶ Generates embeddings for the given iterable of data.
Subclasses can override this method to increase efficiency, but, by default, this method simply iterates over the data and applies
embed()
to each.- Parameters
args – an iterable of data. See
predict_all()
for details- Returns
a numpy array containing the embeddings stacked along axis 0
-
get_embeddings
()¶ Returns the embeddings generated by the last forward pass of the model.
By convention, this method should always return an array whose first axis represents batch size (which will always be 1 when
predict()
was last used).- Returns
a numpy array containing the embedding(s)
-
class
fiftyone.utils.torch.
TorchImageModelConfig
(d)¶ Bases:
fiftyone.core.config.Config
Configuration for running a
TorchImageModel
.Models are represented by this class via the following three components:
Model:
# Directly specify a model
model

# Load model from an entrypoint
model = entrypoint_fcn(**entrypoint_args)
Transforms:
# Directly provide transforms
transforms

# Load transforms from a function
transforms = transforms_fcn(**transforms_args)

# Use the `image_XXX` parameters defined below to build a transform
transforms = build_transforms(image_XXX, ...)
OutputProcessor:
# Directly provide an OutputProcessor
output_processor

# Load an OutputProcessor from a function
output_processor = output_processor_cls(**output_processor_args)
Given these components, inference happens as follows:
def predict_all(imgs):
    imgs = [transforms(img) for img in imgs]
    if not raw_inputs:
        imgs = torch.stack(imgs)

    output = model(imgs)
    return output_processor(output, ...)
- Parameters
model (None) – a
torch.nn.Module
instance to useentrypoint_fcn (None) – a function or string like
"torchvision.models.inception_v3"
specifying the entrypoint function that loads the model
entrypoint_args (None) – a dictionary of arguments for
entrypoint_fcn
transforms (None) – a preprocessing transform to apply
transforms_fcn (None) – a function or string like
"torchvision.models.Inception_V3_Weights.DEFAULT.transforms"
specifying a function that returns a preprocessing transform function to apply
transforms_args (None) – a dictionary of arguments for
transforms_fcn
ragged_batches (None) – whether the provided
transforms
or transforms_fcn
may return tensors of different sizes. This must be set to False
to enable batch inference, if it is desired
raw_inputs (None) – whether to feed the raw list of images to the model rather than stacking them as a Torch tensor
output_processor (None) – an
OutputProcessor
instance to use
output_processor_cls (None) – a class or string like
"fiftyone.utils.torch.ClassifierOutputProcessor"
specifying the OutputProcessor
to use
output_processor_args (None) – a dictionary of arguments for
output_processor_cls(classes=classes, **kwargs)
confidence_thresh (None) – an optional confidence threshold to apply to any applicable predictions generated by the model
classes (None) – a list of class names for the model, if applicable
labels_string (None) – a comma-separated list of the class names for the model, if applicable
labels_path (None) – the path to the labels map for the model, if applicable
mask_targets (None) – a mask targets dict for the model, if applicable
mask_targets_path (None) – the path to a mask targets map for the model, if applicable
skeleton (None) – a keypoint skeleton dict for the model, if applicable
image_min_size (None) – resize the input images during preprocessing, if necessary, so that the image dimensions are at least this
(width, height)
image_min_dim (None) – resize input images during preprocessing, if necessary, so that the smaller image dimension is at least this value
image_max_size (None) – resize the input images during preprocessing, if necessary, so that the image dimensions are at most this
(width, height)
image_max_dim (None) – resize input images during preprocessing, if necessary, so that the largest image dimension is at most this value.
image_size (None) – a
(width, height)
to which to resize the input images during preprocessing
image_dim (None) – resize the smaller input dimension to this value during preprocessing
image_patch_size (None) – crop the input images during preprocessing, if necessary, so that the image dimensions are a multiple of this patch size
image_mean (None) – a 3-array of mean values in
[0, 1]
for preprocessing the input images
image_std (None) – a 3-array of std values in
[0, 1]
for preprocessing the input images
embeddings_layer (None) – the name of a layer whose output to expose as embeddings. Prepend
"<"
to save the input tensor instead
as_feature_extractor (False) – whether to operate the model as a feature extractor. If
embeddings_layer
is provided, this layer is passed to torchvision’s create_feature_extractor()
function. If no embeddings_layer
is provided, the model’s output is used as-is for feature extraction
use_half_precision (None) – whether to use half precision (only supported when using GPU)
cudnn_benchmark (None) – a value to use for
torch.backends.cudnn.benchmark
while the model is running
device (None) – a string specifying the device to use, e.g. one of
("cuda:0", "mps", "cpu")
. By default, CUDA is used if available, else CPU is used
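For illustration, a minimal sketch of building a config for a torchvision classifier using the entrypoint and transforms strings cited above; the weights argument and labels path are assumptions to adapt to your environment:

import fiftyone.utils.torch as fout

config = fout.TorchImageModelConfig(
    {
        "entrypoint_fcn": "torchvision.models.inception_v3",
        "entrypoint_args": {"weights": "IMAGENET1K_V1"},  # assumption: a torchvision weights spec
        "transforms_fcn": "torchvision.models.Inception_V3_Weights.DEFAULT.transforms",
        "output_processor_cls": "fiftyone.utils.torch.ClassifierOutputProcessor",
        "labels_path": "/path/to/imagenet-labels.txt",  # hypothetical labels file
    }
)

model = fout.TorchImageModel(config)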
Methods:
Returns a list of class attributes to be serialized.
builder
()Returns a ConfigBuilder instance for this class.
copy
()Returns a deep copy of the object.
custom_attributes
([dynamic, private])Returns a customizable list of class attributes.
default
()Returns the default config instance.
from_dict
(d)Constructs a Config object from a JSON dictionary.
from_json
(path, *args, **kwargs)Constructs a Serializable object from a JSON file.
from_kwargs
(**kwargs)Constructs a Config object from keyword arguments.
from_str
(s, *args, **kwargs)Constructs a Serializable object from a JSON string.
Returns the fully-qualified class name string of this object.
Loads the default config instance from file.
parse_array
(d, key[, default])Parses a raw array attribute.
parse_bool
(d, key[, default])Parses a boolean value.
parse_categorical
(d, key, choices[, default])Parses a categorical JSON field, which must take a value from among the given choices.
parse_dict
(d, key[, default])Parses a dictionary attribute.
parse_int
(d, key[, default])Parses an integer attribute.
parse_mutually_exclusive_fields
(fields)Parses a mutually exclusive dictionary of pre-parsed fields, which must contain exactly one field with a truthy value.
parse_number
(d, key[, default])Parses a number attribute.
parse_object
(d, key, cls[, default])Parses an object attribute.
parse_object_array
(d, key, cls[, default])Parses an array of objects.
parse_object_dict
(d, key, cls[, default])Parses a dictionary whose values are objects.
parse_path
(d, key[, default])Parses a path attribute.
parse_raw
(d, key[, default])Parses a raw (arbitrary) JSON field.
parse_string
(d, key[, default])Parses a string attribute.
serialize
([reflective])Serializes the object into a dictionary.
to_str
([pretty_print])Returns a string representation of this object.
validate_all_or_nothing_fields
(fields)Validates a dictionary of pre-parsed fields checking that either all or none of the fields have a truthy value.
write_json
(path[, pretty_print])Serializes the object and writes it to disk.
-
attributes
()¶ Returns a list of class attributes to be serialized.
This method is called internally by serialize() to determine the class attributes to serialize.
Subclasses can override this method, but, by default, all attributes in vars(self) are returned, minus private attributes, i.e., those starting with “_”. The order of the attributes in this list is preserved when serializing objects, so a common pattern is for subclasses to override this method if they want their JSON files to be organized in a particular way.
- Returns
a list of class attributes to be serialized
-
classmethod
builder
()¶ Returns a ConfigBuilder instance for this class.
-
copy
()¶ Returns a deep copy of the object.
- Returns
a Serializable instance
-
custom_attributes
(dynamic=False, private=False)¶ Returns a customizable list of class attributes.
By default, all attributes in vars(self) are returned, minus private attributes (those starting with “_”).
- Parameters
dynamic – whether to include dynamic properties, e.g., those defined by getter/setter methods or the @property decorator. By default, this is False
private – whether to include private properties, i.e., those starting with “_”. By default, this is False
- Returns
a list of class attributes
-
classmethod
default
()¶ Returns the default config instance.
By default, this method instantiates the class from an empty dictionary, which will only succeed if all attributes are optional. Otherwise, subclasses should override this method to provide the desired default configuration.
-
classmethod
from_dict
(d)¶ Constructs a Config object from a JSON dictionary.
Config subclass constructors accept JSON dictionaries, so this method simply passes the dictionary to cls().
- Parameters
d – a dict of fields expected by cls
- Returns
an instance of cls
-
classmethod
from_json
(path, *args, **kwargs)¶ Constructs a Serializable object from a JSON file.
Subclasses may override this method, but, by default, this method simply reads the JSON and calls from_dict(), which subclasses must implement.
- Parameters
path – the path to the JSON file on disk
*args – optional positional arguments for self.from_dict()
**kwargs – optional keyword arguments for self.from_dict()
- Returns
an instance of the Serializable class
-
classmethod
from_kwargs
(**kwargs)¶ Constructs a Config object from keyword arguments.
- Parameters
**kwargs – keyword arguments that define the fields expected by cls
- Returns
an instance of cls
-
classmethod
from_str
(s, *args, **kwargs)¶ Constructs a Serializable object from a JSON string.
Subclasses may override this method, but, by default, this method simply parses the string and calls from_dict(), which subclasses must implement.
- Parameters
s – a JSON string representation of a Serializable object
*args – optional positional arguments for self.from_dict()
**kwargs – optional keyword arguments for self.from_dict()
- Returns
an instance of the Serializable class
-
classmethod
get_class_name
()¶ Returns the fully-qualified class name string of this object.
-
classmethod
load_default
()¶ Loads the default config instance from file.
Subclasses must implement this method if they intend to support default instances.
-
static
parse_array
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses a raw array attribute.
- Parameters
d – a JSON dictionary
key – the key to parse
default – a default list to return if key is not present
- Returns
a list of raw (untouched) values
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_bool
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses a boolean value.
- Parameters
d – a JSON dictionary
key – the key to parse
default – a default bool to return if key is not present
- Returns
True/False
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_categorical
(d, key, choices, default=<eta.core.config.NoDefault object>)¶ Parses a categorical JSON field, which must take a value from among the given choices.
- Parameters
d – a JSON dictionary
key – the key to parse
choices – either an iterable of possible values or an enum-like class whose attributes define the possible values
default – a default value to return if key is not present
- Returns
the raw (untouched) value of the given field, which is equal to a value from choices
- Raises
ConfigError – if the key was present in the dictionary but its value was not an allowed choice, or if no default value was provided and the key was not found in the dictionary
-
static
parse_dict
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses a dictionary attribute.
- Parameters
d – a JSON dictionary
key – the key to parse
default – a default dict to return if key is not present
- Returns
a dictionary
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_int
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses an integer attribute.
- Parameters
d – a JSON dictionary
key – the key to parse
default – a default integer value to return if key is not present
- Returns
an int
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_mutually_exclusive_fields
(fields)¶ Parses a mutually exclusive dictionary of pre-parsed fields, which must contain exactly one field with a truthy value.
- Parameters
fields – a dictionary of pre-parsed fields
- Returns
the (field, value) that was set
- Raises
ConfigError – if zero or more than one truthy value was found
-
static
parse_number
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses a number attribute.
- Parameters
d – a JSON dictionary
key – the key to parse
default – a default numeric value to return if key is not present
- Returns
a number (e.g. int, float)
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_object
(d, key, cls, default=<eta.core.config.NoDefault object>)¶ Parses an object attribute.
The value of d[key] can be either an instance of cls or a serialized dict from an instance of cls.
- Parameters
d – a JSON dictionary
key – the key to parse
cls – the class of d[key]
default – a default cls instance to return if key is not present
- Returns
an instance of cls
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_object_array
(d, key, cls, default=<eta.core.config.NoDefault object>)¶ Parses an array of objects.
The values in d[key] can be either instances of cls or serialized dicts from instances of cls.
- Parameters
d – a JSON dictionary
key – the key to parse
cls – the class of the elements of list d[key]
default – the default list to return if key is not present
- Returns
a list of cls instances
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_object_dict
(d, key, cls, default=<eta.core.config.NoDefault object>)¶ Parses a dictionary whose values are objects.
The values in d[key] can be either instances of cls or serialized dicts from instances of cls.
- Parameters
d – a JSON dictionary
key – the key to parse
cls – the class of the values of dictionary d[key]
default – the default dict of cls instances to return if key is not present
- Returns
a dictionary whose values are cls instances
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_path
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses a path attribute.
The path is converted to an absolute path if necessary via
os.path.abspath(os.path.expanduser(value))
.- Parameters
d – a JSON dictionary
key – the key to parse
default – a default string to return if key is not present
- Returns
a path string
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
static
parse_raw
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses a raw (arbitrary) JSON field.
- Parameters
d – a JSON dictionary
key – the key to parse
default – a default value to return if key is not present
- Returns
the raw (untouched) value of the given field
- Raises
ConfigError – if no default value was provided and the key was not found in the dictionary
-
static
parse_string
(d, key, default=<eta.core.config.NoDefault object>)¶ Parses a string attribute.
- Parameters
d – a JSON dictionary
key – the key to parse
default – a default string to return if key is not present
- Returns
a string
- Raises
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
-
serialize
(reflective=False)¶ Serializes the object into a dictionary.
Serialization is applied recursively to all attributes in the object, including element-wise serialization of lists and dictionary values.
- Parameters
reflective – whether to include reflective attributes when serializing the object. By default, this is False
- Returns
a JSON dictionary representation of the object
-
to_str
(pretty_print=True, **kwargs)¶ Returns a string representation of this object.
- Parameters
pretty_print – whether to render the JSON in human readable format with newlines and indentations. By default, this is True
**kwargs – optional keyword arguments for self.serialize()
- Returns
a string representation of the object
-
static
validate_all_or_nothing_fields
(fields)¶ Validates a dictionary of pre-parsed fields checking that either all or none of the fields have a truthy value.
- Parameters
fields – a dictionary of pre-parsed fields
- Raises
ConfigError – if some values are truthy and some are not
-
write_json
(path, pretty_print=False, **kwargs)¶ Serializes the object and writes it to disk.
- Parameters
path – the output path
pretty_print – whether to render the JSON in human readable format with newlines and indentations. By default, this is False
**kwargs – optional keyword arguments for self.serialize()
-
class
fiftyone.utils.torch.
ImageGetItem
(field_mapping=None, transform=None, raw_inputs=False, using_half_precision=False, use_numpy=False, **kwargs)¶ Bases:
fiftyone.utils.torch.GetItem
A
GetItem
that loads images to feed to TorchImageModel
instances. By default, images are loaded from the
"filepath"
field of samples, but users can override this by providing field_mapping={"filepath": "another_field"}
.- Parameters
field_mapping (None) – the user-supplied dict mapping keys in
required_keys
to field names of their dataset that contain the required values
transform (None) – a
torchvision.transforms
function to apply
raw_inputs (False) – whether to feed the raw list of images to the model rather than stacking them as a Torch tensor
using_half_precision (False) – whether the model is using half precision
use_numpy (False) – whether to use numpy arrays rather than PIL images and Torch tensors when loading data
Attributes:
The list of keys that must exist on the dicts provided to the
__call__()
method at runtime. A user-supplied dictionary mapping keys in
required_keys
to field names of their dataset that contain the required values.-
property
required_keys
¶ The list of keys that must exist on the dicts provided to the
__call__()
method at runtime. The user supplies the field names from which to extract these values from their samples via
field_mapping
.
-
property
field_mapping
¶ A user-supplied dictionary mapping keys in
required_keys
to field names of their dataset that contain the required values.
-
class
fiftyone.utils.torch.
TorchImageModel
(config)¶ Bases:
fiftyone.core.models.SupportsGetItem
,fiftyone.utils.torch.TorchEmbeddingsMixin
,fiftyone.core.models.TorchModelMixin
,fiftyone.core.models.LogitsMixin
,fiftyone.core.models.Model
Wrapper for evaluating a Torch model on images.
See this page for example usage.
- Parameters
config – a
TorchImageModelConfig
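For illustration, a hedged sketch of running inference with a TorchImageModel built from a TorchImageModelConfig like the one sketched above; the image path is a placeholder:

import numpy as np
from PIL import Image

import fiftyone.utils.torch as fout

model = fout.TorchImageModel(config)  # `config` is a TorchImageModelConfig

# predict() accepts a PIL image, a uint8 numpy array (HWC), or a Torch tensor (CHW)
img = np.asarray(Image.open("/path/to/image.jpg").convert("RGB"))  # placeholder path
label = model.predict(img)

# predict_all() accepts a list of images (or a stacked numpy/Torch tensor)
labels = model.predict_all([img, img])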
Methods:
build_get_item
([field_mapping])Builds the
fiftyone.utils.torch.GetItem
instance that defines how the model’s data should be loaded by data loaders.collate_fn
(batch)The collate function to use when creating dataloaders for this model.
predict
(img)Performs prediction on the given image.
predict_all
(imgs)Performs prediction on the given batch of images.
embed
(arg)Generates an embedding for the given data.
embed_all
(args)Generates embeddings for the given iterable of data.
from_config
(config)Instantiates a Configurable class from a <cls>Config instance.
from_dict
(d)Instantiates a Configurable class from a <cls>Config dict.
from_json
(json_path)Instantiates a Configurable class from a <cls>Config JSON file.
from_kwargs
(**kwargs)Instantiates a Configurable class from keyword arguments defining the attributes of a <cls>Config.
Returns the embeddings generated by the last forward pass of the model.
parse
(class_name[, module_name])Parses a Configurable subclass name string.
validate
(config)Validates that the given config is an instance of <cls>Config.
Attributes:
The media type processed by the model.
Whether this instance can generate logits.
Whether
transforms()
may return tensors of different sizes.A
torchvision.transforms
function that will be applied to each input before prediction, if any.Whether this model has a custom collate function.
Whether to apply preprocessing transforms for inference, if any.
Whether the model is using GPU.
The
torch.device
that the model is using.Whether the model is using half precision.
The list of class labels for the model, if known.
The number of classes for the model, if known.
The mask targets for the model, if any.
The keypoint skeleton for the model, if any.
Whether this instance can generate prompt embeddings.
Whether this instance has embeddings.
The required keys that must be provided as parameters to methods like
apply_model()
and compute_embeddings()
at runtime.Whether the model should store logits in its predictions.
-
build_get_item
(field_mapping=None)¶ Builds the
fiftyone.utils.torch.GetItem
instance that defines how the model’s data should be loaded by data loaders.- Parameters
field_mapping (None) – a user-provided dict mapping required keys to dataset field names
- Returns
a
fiftyone.utils.torch.GetItem
instance
-
property
media_type
¶ The media type processed by the model.
-
property
has_logits
¶ Whether this instance can generate logits.
-
property
ragged_batches
¶ Whether
transforms()
may return tensors of different sizes. If True, then passing ragged lists of images to predict_all()
is not allowed.
-
property
transforms
¶ A
torchvision.transforms
function that will be applied to each input before prediction, if any.
-
property
has_collate_fn
¶ Whether this model has a custom collate function.
Set this to
True
if you want collate_fn()
to be used during inference.
-
static
collate_fn
(batch)¶ The collate function to use when creating dataloaders for this model.
In order to enable this functionality, the model’s
has_collate_fn()
property must return True
.By default, this is the default collate function for
torch.utils.data.DataLoader
, but subclasses can override this method as necessary.Note that this function must be serializable so it is compatible with multiprocessing for dataloaders.
- Parameters
batch – a list of items to collate
- Returns
the collated batch, which will be fed directly to the model
-
property
preprocess
¶ Whether to apply preprocessing transforms for inference, if any.
-
property
using_gpu
¶ Whether the model is using GPU.
-
property
device
¶ The
torch.device
that the model is using.
-
property
using_half_precision
¶ Whether the model is using half precision.
-
property
classes
¶ The list of class labels for the model, if known.
-
property
num_classes
¶ The number of classes for the model, if known.
-
property
mask_targets
¶ The mask targets for the model, if any.
-
property
skeleton
¶ The keypoint skeleton for the model, if any.
-
predict
(img)¶ Performs prediction on the given image.
- Parameters
img –
the image to process, which can be any of the following:
A PIL image
A uint8 numpy array (HWC)
A Torch tensor (CHW)
- Returns
a
fiftyone.core.labels.Label
instance or dict offiftyone.core.labels.Label
instances containing the predictions
-
predict_all
(imgs)¶ Performs prediction on the given batch of images.
- Parameters
imgs –
the batch of images to process, which can be any of the following:
A list of PIL images
A list of uint8 numpy arrays (HWC)
A list of Torch tensors (CHW)
A uint8 numpy tensor (NHWC)
A Torch tensor (NCHW)
- Returns
a list of
fiftyone.core.labels.Label
instances or a list of dicts offiftyone.core.labels.Label
instances containing the predictions
-
property
can_embed_prompts
¶ Whether this instance can generate prompt embeddings.
This method returns
False
by default. Methods that can generate prompt embeddings will override this by implementing the PromptMixin
interface.
-
embed
(arg)¶ Generates an embedding for the given data.
Subclasses can override this method to increase efficiency, but, by default, this method simply calls
predict()
and then returnsget_embeddings()
.- Parameters
arg – the data. See
predict()
for details- Returns
a numpy array containing the embedding
-
embed_all
(args)¶ Generates embeddings for the given iterable of data.
Subclasses can override this method to increase efficiency, but, by default, this method simply iterates over the data and applies
embed()
to each.- Parameters
args – an iterable of data. See
predict_all()
for details- Returns
a numpy array containing the embeddings stacked along axis 0
-
classmethod
from_config
(config)¶ Instantiates a Configurable class from a <cls>Config instance.
-
classmethod
from_dict
(d)¶ Instantiates a Configurable class from a <cls>Config dict.
- Parameters
d – a dict to construct a <cls>Config
- Returns
an instance of cls
-
classmethod
from_json
(json_path)¶ Instantiates a Configurable class from a <cls>Config JSON file.
- Parameters
json_path – path to a JSON file for type <cls>Config
- Returns
an instance of cls
-
classmethod
from_kwargs
(**kwargs)¶ Instantiates a Configurable class from keyword arguments defining the attributes of a <cls>Config.
- Parameters
**kwargs – keyword arguments that define the fields of a <cls>Config dict
- Returns
an instance of cls
-
get_embeddings
()¶ Returns the embeddings generated by the last forward pass of the model.
By convention, this method should always return an array whose first axis represents batch size (which will always be 1 when
predict()
was last used).- Returns
a numpy array containing the embedding(s)
-
property
has_embeddings
¶ Whether this instance has embeddings.
-
static
parse
(class_name, module_name=None)¶ Parses a Configurable subclass name string.
Assumes both the Configurable class and the Config class are defined in the same module. The module containing the classes will be loaded if necessary.
- Parameters
class_name – a string containing the name of the Configurable class, e.g. “ClassName”, or a fully-qualified class name, e.g. “eta.core.config.ClassName”
module_name – a string containing the fully-qualified module name, e.g. “eta.core.config”, or None if class_name includes the module name. Set module_name = __name__ to load a class from the calling module
- Returns
the Configurable class config_cls: the Config class associated with cls
- Return type
cls
-
property
required_keys
¶ The required keys that must be provided as parameters to methods like
apply_model()
andcompute_embeddings()
at runtime.
-
property
store_logits
¶ Whether the model should store logits in its predictions.
-
classmethod
validate
(config)¶ Validates that the given config is an instance of <cls>Config.
- Raises
ConfigurableError – if config is not an instance of <cls>Config
-
class
fiftyone.utils.torch.
TorchSamplesMixin
¶ Bases:
fiftyone.core.models.SamplesMixin
Methods:
predict
(img[, sample])Performs prediction on the given data.
predict_all
(args[, samples])Performs prediction on the given iterable of data.
Attributes:
A dict mapping model-specific keys to sample field names.
-
predict
(img, sample=None)¶ Performs prediction on the given data.
Image models should support, at minimum, processing
arg
values that are uint8 numpy arrays (HWC).Video models should support, at minimum, processing
arg
values that areeta.core.video.VideoReader
instances.- Parameters
arg – the data
sample (None) – the
fiftyone.core.sample.Sample
associated with the data
- Returns
a
fiftyone.core.labels.Label
instance or dict offiftyone.core.labels.Label
instances containing the predictions
-
property
needs_fields
¶ A dict mapping model-specific keys to sample field names.
-
predict_all
(args, samples=None)¶ Performs prediction on the given iterable of data.
Image models should support, at minimum, processing
args
values that are either lists of uint8 numpy arrays (HWC) or numpy array tensors (NHWC).Video models should support, at minimum, processing
args
values that are lists ofeta.core.video.VideoReader
instances.Subclasses can override this method to increase efficiency, but, by default, this method simply iterates over the data and applies
predict()
to each.- Parameters
args – an iterable of data
samples (None) – an iterable of
fiftyone.core.sample.Sample
instances associated with the data
- Returns
a list of
fiftyone.core.labels.Label
instances or a list of dicts offiftyone.core.labels.Label
instances containing the predictions
-
-
class
fiftyone.utils.torch.
ToPILImage
¶ Bases:
object
Transform that converts a tensor or ndarray to a PIL image, while also allowing PIL images to pass through.
-
class
fiftyone.utils.torch.
MinResize
(min_output_size, interpolation=None)¶ Bases:
object
Transform that resizes the PIL image or torch Tensor, if necessary, so that its minimum dimensions are at least the specified size.
- Parameters
min_output_size – desired minimum output dimensions. Can either be a
(min_height, min_width)
tuple or a singlemin_dim
interpolation (None) – optional interpolation mode. Passed directly to
torchvision.transforms.functional.resize()
-
class
fiftyone.utils.torch.
MaxResize
(max_output_size, interpolation=None)¶ Bases:
object
Transform that resizes the PIL image or torch Tensor, if necessary, so that its maximum dimensions are at most the specified size.
- Parameters
max_output_size – desired maximum output dimensions. Can either be a
(max_height, max_width)
tuple or a singlemax_dim
interpolation (None) – optional interpolation mode. Passed directly to
torchvision.transforms.functional.resize()
-
class
fiftyone.utils.torch.
PatchSize
(patch_size)¶ Bases:
object
Transform that center crops the PIL image or torch Tensor, if necessary, so that its dimensions are multiples of the specified patch size.
- Parameters
patch_size – the patch size
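For illustration, a hedged sketch of composing these transforms with torchvision; the size values are arbitrary placeholders:

import torchvision.transforms as T

import fiftyone.utils.torch as fout

transform = T.Compose(
    [
        fout.ToPILImage(),    # accepts tensors/ndarrays and passes PIL images through
        fout.MinResize(224),  # ensure the smaller dimension is at least 224
        fout.MaxResize(1024), # ensure the larger dimension is at most 1024
        T.ToTensor(),
        fout.PatchSize(14),   # center crop so dimensions are multiples of 14
    ]
)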
-
class
fiftyone.utils.torch.
SaveLayerTensor
(model, layer_name)¶ Bases:
object
Callback that saves the input/output tensor of the specified layer of a Torch model during each
forward()
call.- Parameters
model – the Torch model, a
torch.nn.Module
layer_name – the name of the layer whose output to save. Prepend
"<"
to save the input tensor instead
Attributes:
The tensor saved from the last
forward()
call.-
property
tensor
¶ The tensor saved from the last
forward()
call.
-
class
fiftyone.utils.torch.
OutputProcessor
(classes=None, **kwargs)¶ Bases:
object
Interface for processing the outputs of Torch models.
- Parameters
classes (None) – the list of class labels for the model. This may not be required or used by some models
-
class
fiftyone.utils.torch.
ClassifierOutputProcessor
(classes=None, store_logits=False, logits_key='logits')¶ Bases:
fiftyone.utils.torch.OutputProcessor
Output processor for single label classifiers.
- Parameters
classes (None) – the list of class labels for the model
store_logits (False) – whether to store logits in the model outputs
-
class
fiftyone.utils.torch.
DetectorOutputProcessor
(classes=None)¶ Bases:
fiftyone.utils.torch.OutputProcessor
Output processor for object detectors.
- Parameters
classes (None) – the list of class labels for the model
-
class
fiftyone.utils.torch.
InstanceSegmenterOutputProcessor
(classes=None, mask_thresh=0.5)¶ Bases:
fiftyone.utils.torch.OutputProcessor
Output processor for instance segmenters.
- Parameters
classes (None) – the list of class labels for the model
mask_thresh (0.5) – a threshold to use to convert soft masks to binary masks
-
class
fiftyone.utils.torch.
KeypointDetectorOutputProcessor
(classes=None)¶ Bases:
fiftyone.utils.torch.OutputProcessor
Output processor for keypoint detection models.
- Parameters
classes (None) – the list of class labels for the model
-
class
fiftyone.utils.torch.
SemanticSegmenterOutputProcessor
(classes=None)¶ Bases:
fiftyone.utils.torch.OutputProcessor
Output processor for semantic segmenters.
- Parameters
classes (None) – the list of class labels for the model. This parameter is not used
-
fiftyone.utils.torch.
recommend_num_workers
()¶ Recommends a number of workers for running a
torch.utils.data.DataLoader
.- Returns
the recommended number of workers
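For illustration, a minimal sketch of wiring the recommendation into a DataLoader; the toy dataset is a placeholder:

import torch
from torch.utils.data import DataLoader, TensorDataset

import fiftyone.utils.torch as fout

num_workers = fout.recommend_num_workers()

dataset = TensorDataset(torch.randn(100, 3, 32, 32))  # placeholder dataset
loader = DataLoader(dataset, batch_size=16, num_workers=num_workers)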
-
class
fiftyone.utils.torch.
FiftyOneTorchDataset
(samples, get_item, vectorize=False, skip_failures=False, local_process_group=None)¶ Bases:
Generic
[torch.utils.data.dataset._T_co
]Constructs a
torch.utils.data.Dataset
that loads data from an arbitraryfiftyone.core.collections.SampleCollection
via the providedGetItem
instance.- Parameters
samples – a
fiftyone.core.collections.SampleCollection
get_item – a
GetItem
vectorize (False) – whether to load and cache the required fields from the sample collection upfront (True) or lazily load the values from each sample when items are retrieved (False). Vectorizing gives faster data loading times, but you must have enough memory to store the required field values for the entire collection. When
vectorize=True
, all field values must be serializable; i.e., pickle.dumps(field_value)
must not raise an error
skip_failures (False) – whether to skip failures that occur when calling
get_item
. If True, the exception will be returned rather than the intended field values
local_process_group (None) – the local process group. Only used during distributed training
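For illustration, a hedged sketch of wrapping a sample collection in this dataset using ImageGetItem; the zoo dataset, transform, and batching values are placeholders, and in practice the transform should produce same-sized tensors so the default collate can stack them:

import torch
import torchvision.transforms as T

import fiftyone.utils.torch as fout
import fiftyone.zoo as foz

samples = foz.load_zoo_dataset("quickstart")  # any sample collection works

get_item = fout.ImageGetItem(
    transform=T.Compose([T.Resize((224, 224)), T.ToTensor()])
)
torch_dataset = fout.FiftyOneTorchDataset(samples, get_item)

loader = torch.utils.data.DataLoader(
    torch_dataset,
    batch_size=4,
    num_workers=2,
    worker_init_fn=fout.FiftyOneTorchDataset.worker_init,  # see worker_init() below
)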
Attributes:
Methods:
worker_init
(worker_id)Initializes a worker during inference/training.
distributed_init
(dataset_name, …[, view_name])Initializes a trainer process during distributed training.
-
property
samples
¶
-
static
worker_init
(worker_id)¶ Initializes a worker during inference/training.
This method is used as the
worker_init_fn
parameter fortorch.utils.data.DataLoader
.- Parameters
worker_id – the worker ID
-
static
distributed_init
(dataset_name, local_process_group, view_name=None)¶ Initializes a trainer process during distributed training.
This function should be called at the beginning of the training script. It facilitates communication between processes and safely creates a database connection for each trainer.
- Parameters
dataset_name – the name of the dataset to load
local_process_group – the process group with all the processes running the main training script
view_name (None) – the name of a saved view to load
- Returns
the loaded
fiftyone.core.dataset.Dataset
orfiftyone.core.view.DatasetView
-
class
fiftyone.utils.torch.
TorchImageDataset
(image_paths=None, samples=None, sample_ids=None, include_ids=False, transform=None, use_numpy=False, force_rgb=False, skip_failures=False)¶ Bases:
Generic
[torch.utils.data.dataset._T_co
]A
torch.utils.data.Dataset
of images.Instances of this dataset emit images for each sample, or
(img, sample_id)
pairs ifsample_ids
are provided orinclude_ids == True
.By default, this class will load images in PIL format and emit Torch tensors, but you can use numpy images/tensors instead by passing
use_numpy = True
.- Parameters
image_paths (None) – an iterable of image paths
samples (None) – a
fiftyone.core.collections.SampleCollection
from which to extract image pathssample_ids (None) – an iterable of sample IDs corresponding to each image
include_ids (False) – whether to include the IDs of the
samples
in the returned itemstransform (None) – an optional transform function to apply to each image patch. When
use_numpy == False
, this is typically a torchvision transformuse_numpy (False) – whether to use numpy arrays rather than PIL images and Torch tensors when loading data
force_rgb (False) – whether to force convert the images to RGB
skip_failures (False) – whether to return an
Exception
object rather than raising it if an error occurs while loading a sample
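For illustration, a hedged sketch that emits (img, sample_id) pairs so results can later be written back to the corresponding samples; the zoo dataset and transform are placeholders:

import torchvision.transforms as T
from torch.utils.data import DataLoader

import fiftyone.utils.torch as fout
import fiftyone.zoo as foz

samples = foz.load_zoo_dataset("quickstart")

torch_dataset = fout.TorchImageDataset(
    samples=samples,
    include_ids=True,
    transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]),
)

loader = DataLoader(torch_dataset, batch_size=8, num_workers=2)
for imgs, sample_ids in loader:
    pass  # run your model on `imgs` here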
Attributes:
Whether this dataset has sample IDs.
-
property
has_sample_ids
¶ Whether this dataset has sample IDs.
-
class
fiftyone.utils.torch.
TorchImageClassificationDataset
(image_paths=None, targets=None, samples=None, sample_ids=None, include_ids=False, transform=None, use_numpy=False, force_rgb=False, skip_failures=False)¶ Bases:
Generic
[torch.utils.data.dataset._T_co
]A
torch.utils.data.Dataset
for image classification.Instances of this dataset emit images and their associated targets for each sample, either directly as
(img, target)
pairs or as(img, target, sample_id)
pairs ifsample_ids
are provided orinclude_ids == True
.By default, this class will load images in PIL format and emit Torch tensors, but you can use numpy images/tensors instead by passing
use_numpy = True
.- Parameters
image_paths (None) – an iterable of image paths
targets (None) – an iterable of targets, or the name of a field or embedded field of
samples
to use as targetssamples (None) – a
fiftyone.core.collections.SampleCollection
from which to extract image paths and targetssample_ids (None) – an iterable of sample IDs corresponding to each image
include_ids (False) – whether to include the IDs of the
samples
in the returned itemstransform (None) – an optional transform function to apply to each image patch. When
use_numpy == False
, this is typically a torchvision transformuse_numpy (False) – whether to use numpy arrays rather than PIL images and Torch tensors when loading data
force_rgb (False) – whether to force convert the images to RGB
skip_failures (False) – whether to return an
Exception
object rather than raising it if an error occurs while loading a sample
Attributes:
Whether this dataset has sample IDs.
-
property
has_sample_ids
¶ Whether this dataset has sample IDs.
-
class
fiftyone.utils.torch.
TorchImagePatchesDataset
(image_paths=None, patches=None, samples=None, patches_field=None, handle_missing='skip', transform=None, sample_ids=None, include_ids=False, ragged_batches=False, use_numpy=False, force_rgb=False, force_square=False, alpha=None, skip_failures=False)¶ Bases:
Generic
[torch.utils.data.dataset._T_co
]A
torch.utils.data.Dataset
of image patch tensors extracted from a list of images.Provide either
image_paths
andpatches
orsamples
andpatches_field
in order to use this dataset.Instances of this dataset emit image patches for each sample, or
(patches, sample_id)
tuples ifsample_ids
are provided orinclude_ids == True
.By default, this class will load images in PIL format and emit Torch tensors, but you can use numpy images/tensors instead by passing
use_numpy = True
.If
ragged_batches = False
(the default), this class will emit tensors containing the stacked (along axis 0) patches from each image. In this case, the providedtransform
must ensure that all image patches are resized to the same shape so they can be stacked.If
ragged_batches = True
, lists of patch tensors will be returned.- Parameters
image_paths (None) – an iterable of image paths
patches (None) – a list of labels of type
fiftyone.core.labels.Detection
,fiftyone.core.labels.Detections
,fiftyone.core.labels.Polyline
, orfiftyone.core.labels.Polylines
specifying the image patch(es) to extract from each image. Elements can beNone
if an image has no patchessamples (None) – a
fiftyone.core.collections.SampleCollection
from which to extract patchespatches_field (None) – the name of the field defining the image patches in
samples
to extract. Must be of typefiftyone.core.labels.Detection
,fiftyone.core.labels.Detections
,fiftyone.core.labels.Polyline
, orfiftyone.core.labels.Polylines
handle_missing ("skip") –
how to handle images with no patches. The supported values are:
"skip": skip the image and assign its embedding as
None
"image": use the whole image as a single patch
"error": raise an error
transform (None) – an optional transform function to apply to each image patch. When
use_numpy == False
, this is typically a torchvision transformsample_ids (None) – an iterable of sample IDs corresponding to each image
include_ids (False) – whether to include the IDs of the
samples
in the returned itemsragged_batches (False) – whether the provided
transform
may return tensors of different dimensions and thus cannot be stackeduse_numpy (False) – whether to use numpy arrays rather than PIL images and Torch tensors when loading data
force_rgb (False) – whether to force convert the images to RGB
force_square (False) – whether to minimally manipulate the patch bounding boxes into squares prior to extraction
alpha (None) – an optional expansion/contraction to apply to the patches before extracting them, in
[-1, inf)
. If provided, the length and width of the box are expanded (or contracted, when alpha < 0
) by (100 * alpha)%
. For example, set alpha = 0.1
to expand the boxes by 10%, and set alpha = -0.1
to contract the boxes by 10%
skip_failures (False) – whether to return an
Exception
object rather than raising it if an error occurs while loading a sample
Attributes:
Whether this dataset has sample IDs.
-
property
has_sample_ids
¶ Whether this dataset has sample IDs.
-
fiftyone.utils.torch.
from_image_classification_dir_tree
(dataset_dir)¶ Creates a
torch.utils.data.Dataset
for the given image classification dataset directory tree.The directory should have the following format:
<dataset_dir>/
    <classA>/
        <image1>.<ext>
        <image2>.<ext>
        ...
    <classB>/
        <image1>.<ext>
        <image2>.<ext>
        ...
- Parameters
dataset_dir – the dataset directory
- Returns
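For illustration, a hedged sketch of loading such a directory tree; the path is a placeholder and the (image, target) item format is an assumption:

import fiftyone.utils.torch as fout

dataset = fout.from_image_classification_dir_tree("/path/to/dataset_dir")  # hypothetical path

print(len(dataset))
img, target = dataset[0]  # assumption: items are (image, class-index) pairs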
-
class
fiftyone.utils.torch.
NumpySerializedList
(lst: list)¶ Bases:
object
-
class
fiftyone.utils.torch.
TorchSerializedList
(lst: list)¶
-
class
fiftyone.utils.torch.
TorchShmSerializedList
(lst: list, local_process_group)¶
-
fiftyone.utils.torch.
get_local_size
(local_process_group)¶ Gets the number of processes per-machine in the local process group.
- Parameters
local_process_group – the local process group
- Returns
the number of processes per-machine
-
fiftyone.utils.torch.
get_world_size
()¶ Returns the world size of the current operation.
- Returns
the world size
-
fiftyone.utils.torch.
get_local_rank
(local_process_group)¶ Gets the rank of the current process within the local process group.
- Parameters
local_process_group – the local process group
- Returns
the rank of the current process
-
fiftyone.utils.torch.
get_rank
()¶ Gets the rank of the current process.
- Returns
the rank of the current process
-
fiftyone.utils.torch.
local_scatter
(array, local_process_group)¶ Scatters the given array from the local leader to all local workers.
The worker with rank
i
getsarray[i]
.- Parameters
array – an array with same size as the local process group
local_process_group – the local process group
- Returns
the array element for the current rank
-
fiftyone.utils.torch.
all_gather
(data, group=None)¶ Gathers arbitrary picklable data (not necessarily tensors).
- Parameters
data – any picklable object
group (None) – a torch process group. By default, uses a group which contains all ranks on gloo backend
- Returns
the list of data gathered from each rank
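For illustration, a hedged sketch of gathering per-rank results during distributed training; it assumes torch.distributed has already been initialized (e.g. by torchrun), and the metrics dict is a placeholder:

import torch.distributed as dist

import fiftyone.utils.torch as fout

# assumes dist.init_process_group() has already been called (e.g. by torchrun)
local_metrics = {"rank": dist.get_rank(), "num_samples": 128}  # placeholder data

# every rank receives the list of objects contributed by all ranks
all_metrics = fout.all_gather(local_metrics)

if dist.get_rank() == 0:
    print(all_metrics)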
-
fiftyone.utils.torch.
local_broadcast_process_authkey
(local_process_group)¶