fiftyone.utils.sam
Segment Anything wrapper for the FiftyOne Model Zoo.
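Classes:
SegmentAnythingModelConfig(cfg_dict) | Configuration for running a SegmentAnythingModel.
SAMPromptMode(value) | Enumeration of supported prompt modes for SAM.
SegmentAnythingImageGetItem([field_mapping, ...]) | A GetItem that loads images, bounding boxes and/or keypoints to feed to SegmentAnythingModel instances.
SegmentAnythingImageGetItemForVideo([field_mapping, ...]) | Workaround for applying an image model to video reader frames.
SAMSegmenterOutputProcessor([classes, mask_thresh]) | Converts SAM model outputs to FiftyOne format.
SegmentAnythingModel(config) | Wrapper for running Segment Anything inference.
Functions:
preprocess_detections_to_sam(detections, img_hw, box_transform) | Pre-processes boxes from fiftyone.core.labels.Detections.
preprocess_keypoints_to_sam(keypoints, img_hw, point_transform) | Pre-processes points from fiftyone.core.labels.Keypoints.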
- class fiftyone.utils.sam.SegmentAnythingModelConfig(cfg_dict)
Bases: TorchImageModelConfig, HasZooModel
Configuration for running a SegmentAnythingModel.
See fiftyone.utils.torch.TorchImageModelConfig for additional arguments.
- Parameters:
auto_kwargs (None) – a dictionary of keyword arguments to pass to segment_anything.SamAutomaticMaskGenerator(model, **auto_kwargs)
points_mask_index (None) – an optional mask index to use for each keypoint output
get_item_cls (None) – a string like "fiftyone.utils.sam.SegmentAnythingImageGetItem" specifying the GetItem to use for SAM
get_item_args (None) – a dictionary of arguments for get_item_cls(field_mapping=field_mapping, **kwargs)
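Here is a minimal sketch of how `auto_kwargs` is consumed. It assumes the dictionary is splatted into the generator constructor exactly as the signature above shows. The example forwards a config dict's `auto_kwargs` to a hypothetical `MaskGeneratorStub` class that stands in for `segment_anything.SamAutomaticMaskGenerator`. The stub's two keyword arguments share names with real generator arguments, but the stub, its defaults, and the helper `build_auto_generator` are illustrative only.

```python
# Hypothetical stand-in for segment_anything.SamAutomaticMaskGenerator;
# only two of the real generator's keyword arguments are mirrored here.
class MaskGeneratorStub:
    def __init__(self, model, points_per_side=32, pred_iou_thresh=0.88):
        self.model = model
        self.points_per_side = points_per_side
        self.pred_iou_thresh = pred_iou_thresh


def build_auto_generator(model, auto_kwargs=None):
    # auto_kwargs=None (the config default) means "use the generator defaults"
    auto_kwargs = auto_kwargs or {}
    return MaskGeneratorStub(model, **auto_kwargs)


# Example config dict shaped like the parameters documented above
config_dict = {
    "auto_kwargs": {"points_per_side": 16, "pred_iou_thresh": 0.8},
    "points_mask_index": None,
}

gen = build_auto_generator("sam-model", config_dict["auto_kwargs"])
print(gen.points_per_side, gen.pred_iou_thresh)
```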
Methods:
attributes() | Returns a list of class attributes to be serialized.
builder() | Returns a ConfigBuilder instance for this class.
copy() | Returns a deep copy of the object.
custom_attributes([dynamic, private]) | Returns a customizable list of class attributes.
default() | Returns the default config instance.
download_model_if_necessary() | Downloads the published model specified by the config, if necessary.
from_dict(d) | Constructs a Config object from a JSON dictionary.
from_json(path, *args, **kwargs) | Constructs a Serializable object from a JSON file.
from_kwargs(**kwargs) | Constructs a Config object from keyword arguments.
from_str(s, *args, **kwargs) | Constructs a Serializable object from a JSON string.
get_class_name() | Returns the fully-qualified class name string of this object.
init(d) | Initializes the published model config.
load_default() | Loads the default config instance from file.
parse_array(d, key[, default]) | Parses a raw array attribute.
parse_bool(d, key[, default]) | Parses a boolean value.
parse_categorical(d, key, choices[, default]) | Parses a categorical JSON field, which must take a value from among the given choices.
parse_dict(d, key[, default]) | Parses a dictionary attribute.
parse_int(d, key[, default]) | Parses an integer attribute.
parse_mutually_exclusive_fields(fields) | Parses a mutually exclusive dictionary of pre-parsed fields, which must contain exactly one field with a truthy value.
parse_number(d, key[, default]) | Parses a number attribute.
parse_object(d, key, cls[, default]) | Parses an object attribute.
parse_object_array(d, key, cls[, default]) | Parses an array of objects.
parse_object_dict(d, key, cls[, default]) | Parses a dictionary whose values are objects.
parse_path(d, key[, default]) | Parses a path attribute.
parse_raw(d, key[, default]) | Parses a raw (arbitrary) JSON field.
parse_string(d, key[, default]) | Parses a string attribute.
serialize([reflective]) | Serializes the object into a dictionary.
to_str([pretty_print]) | Returns a string representation of this object.
validate_all_or_nothing_fields(fields) | Validates a dictionary of pre-parsed fields checking that either all or none of the fields have a truthy value.
write_json(path[, pretty_print]) | Serializes the object and writes it to disk.
- attributes()
Returns a list of class attributes to be serialized.
This method is called internally by serialize() to determine the class attributes to serialize.
Subclasses can override this method, but, by default, all attributes in vars(self) are returned, minus private attributes, i.e., those starting with “_”. The order of the attributes in this list is preserved when serializing objects, so a common pattern is for subclasses to override this method if they want their JSON files to be organized in a particular way.
- Returns:
a list of class attributes to be serialized
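The default selection rule can be sketched in a few lines of plain Python. `ExampleConfig` below is a hypothetical class, not part of FiftyOne or ETA. It shows that private attributes are skipped and that insertion order is preserved. This is why overriding `attributes()` in a subclass controls the layout of the written JSON.

```python
class ExampleConfig:
    """Hypothetical config used only to illustrate attributes()."""

    def __init__(self):
        self.model_name = "segment-anything-vitb-torch"
        self.auto_kwargs = None
        self._cache = {}  # private: skipped by the default rule

    def attributes(self):
        # Default rule: keys of vars(self), minus those starting with "_",
        # in insertion order
        return [a for a in vars(self) if not a.startswith("_")]


class ReorderedConfig(ExampleConfig):
    def attributes(self):
        # A subclass override that puts auto_kwargs first when serializing
        return ["auto_kwargs", "model_name"]


print(ExampleConfig().attributes())
print(ReorderedConfig().attributes())
```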
- classmethod builder()
Returns a ConfigBuilder instance for this class.
- copy()
Returns a deep copy of the object.
- Returns:
a Serializable instance
- custom_attributes(dynamic=False, private=False)
Returns a customizable list of class attributes.
By default, all attributes in vars(self) are returned, minus private attributes (those starting with “_”).
- Parameters:
dynamic – whether to include dynamic properties, e.g., those defined by getter/setter methods or the @property decorator. By default, this is False
private – whether to include private properties, i.e., those starting with “_”. By default, this is False
- Returns:
a list of class attributes
- classmethod default()
Returns the default config instance.
By default, this method instantiates the class from an empty dictionary, which will only succeed if all attributes are optional. Otherwise, subclasses should override this method to provide the desired default configuration.
- download_model_if_necessary()
Downloads the published model specified by the config, if necessary.
After this method is called, the model_path attribute will always contain the path to the model on disk.
- classmethod from_dict(d)
Constructs a Config object from a JSON dictionary.
Config subclass constructors accept JSON dictionaries, so this method simply passes the dictionary to cls().
- Parameters:
d – a dict of fields expected by cls
- Returns:
an instance of cls
- classmethod from_json(path, *args, **kwargs)
Constructs a Serializable object from a JSON file.
Subclasses may override this method, but, by default, this method simply reads the JSON and calls from_dict(), which subclasses must implement.
- Parameters:
path – the path to the JSON file on disk
*args – optional positional arguments for self.from_dict()
**kwargs – optional keyword arguments for self.from_dict()
- Returns:
an instance of the Serializable class
- classmethod from_kwargs(**kwargs)
Constructs a Config object from keyword arguments.
- Parameters:
**kwargs – keyword arguments that define the fields expected by cls
- Returns:
an instance of cls
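The `from_dict()`, `from_kwargs()`, and `default()` trio reduces to passing a dict to the constructor. `MiniConfig` below is a hypothetical stand-in written in plain Python, not an ETA `Config` subclass. It makes every field optional so that `default()` can instantiate it from an empty dictionary, as described for `default()` above.

```python
class MiniConfig:
    """Hypothetical config whose constructor accepts a JSON dict."""

    def __init__(self, d):
        # Every field is optional here, so MiniConfig({}) succeeds
        self.mask_thresh = d.get("mask_thresh", 0.5)
        self.classes = d.get("classes", None)

    @classmethod
    def from_dict(cls, d):
        # Config constructors accept JSON dicts, so this is a pass-through
        return cls(d)

    @classmethod
    def from_kwargs(cls, **kwargs):
        # Keyword arguments are simply collected into a dict first
        return cls.from_dict(kwargs)

    @classmethod
    def default(cls):
        # Instantiating from an empty dict only works if all fields are optional
        return cls({})


cfg = MiniConfig.from_kwargs(mask_thresh=0.7)
print(cfg.mask_thresh, MiniConfig.default().mask_thresh)
```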
- classmethod from_str(s, *args, **kwargs)
Constructs a Serializable object from a JSON string.
Subclasses may override this method, but, by default, this method simply parses the string and calls from_dict(), which subclasses must implement.
- Parameters:
s – a JSON string representation of a Serializable object
*args – optional positional arguments for self.from_dict()
**kwargs – optional keyword arguments for self.from_dict()
- Returns:
an instance of the Serializable class
- classmethod get_class_name()
Returns the fully-qualified class name string of this object.
- init(d)
Initializes the published model config.
This method should be called by ModelConfig.__init__(), and it performs the following tasks:
Parses the model_name and model_path parameters
Populates any default parameters in the provided ModelConfig dict
- Parameters:
d – a ModelConfig dict
- Returns:
a ModelConfig dict with any default parameters populated
- classmethod load_default()
Loads the default config instance from file.
Subclasses must implement this method if they intend to support default instances.
- static parse_array(d, key, default=<eta.core.config.NoDefault object>)
Parses a raw array attribute.
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default list to return if key is not present
- Returns:
a list of raw (untouched) values
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- static parse_bool(d, key, default=<eta.core.config.NoDefault object>)
Parses a boolean value.
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default bool to return if key is not present
- Returns:
True/False
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- static parse_categorical(d, key, choices, default=<eta.core.config.NoDefault object>)
Parses a categorical JSON field, which must take a value from among the given choices.
- Parameters:
d – a JSON dictionary
key – the key to parse
choices – either an iterable of possible values or an enum-like class whose attributes define the possible values
default – a default value to return if key is not present
- Returns:
the raw (untouched) value of the given field, which is equal to a value from choices
- Raises:
ConfigError – if the key was present in the dictionary but its value was not an allowed choice, or if no default value was provided and the key was not found in the dictionary
- static parse_dict(d, key, default=<eta.core.config.NoDefault object>)
Parses a dictionary attribute.
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default dict to return if key is not present
- Returns:
a dictionary
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- static parse_int(d, key, default=<eta.core.config.NoDefault object>)
Parses an integer attribute.
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default integer value to return if key is not present
- Returns:
an int
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
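The typed parse_* methods above share one pattern. If the key is present, they check its value's type. If it is absent, they return the default or raise ConfigError. The sketch below is a hypothetical generic helper, `parse_typed`, with its own stand-in `ConfigError` and `NoDefault` sentinel. It illustrates the documented behavior and is not ETA's implementation. A sentinel is needed because `None` is itself a legitimate default, as in the `default=None` call below.

```python
class ConfigError(Exception):
    """Stand-in for the ConfigError raised by the parse_* helpers."""


class NoDefault:
    """Sentinel meaning 'no default given', so None stays a valid default."""


def parse_typed(d, key, expected_type, default=NoDefault):
    # Key present: validate the type of its value
    if key in d:
        value = d[key]
        if not isinstance(value, expected_type):
            raise ConfigError(
                f"Expected {key!r} to be {expected_type.__name__}, "
                f"found {type(value).__name__}"
            )
        return value
    # Key absent: return the default, or fail if none was provided
    if default is NoDefault:
        raise ConfigError(f"Missing required key {key!r}")
    return default


d = {"points_mask_index": 2, "auto_kwargs": {"points_per_side": 16}}
print(parse_typed(d, "points_mask_index", int))
print(parse_typed(d, "get_item_args", dict, default=None))
```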
- static parse_mutually_exclusive_fields(fields)
Parses a mutually exclusive dictionary of pre-parsed fields, which must contain exactly one field with a truthy value.
- Parameters:
fields – a dictionary of pre-parsed fields
- Returns:
the (field, value) that was set
- Raises:
ConfigError – if zero or more than one truthy value was found
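Here is a plain-Python sketch of the exactly-one-truthy check, using a local stand-in `ConfigError`. The `model_name`/`model_path` pair in the usage line is only an example of two alternative fields a config might accept.

```python
class ConfigError(Exception):
    """Stand-in for ETA's ConfigError."""


def parse_mutually_exclusive_fields(fields):
    # Collect the (field, value) pairs whose pre-parsed value is truthy
    truthy = [(field, value) for field, value in fields.items() if value]
    if len(truthy) != 1:
        raise ConfigError(
            f"Expected exactly one truthy field, found {len(truthy)}"
        )
    return truthy[0]


print(parse_mutually_exclusive_fields({"model_name": "sam", "model_path": None}))
```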
- static parse_number(d, key, default=<eta.core.config.NoDefault object>)
Parses a number attribute.
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default numeric value to return if key is not present
- Returns:
a number (e.g. int, float)
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- static parse_object(d, key, cls, default=<eta.core.config.NoDefault object>)
Parses an object attribute.
The value of d[key] can be either an instance of cls or a serialized dict from an instance of cls.
- Parameters:
d – a JSON dictionary
key – the key to parse
cls – the class of d[key]
default – a default cls instance to return if key is not present
- Returns:
an instance of cls
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- static parse_object_array(d, key, cls, default=<eta.core.config.NoDefault object>)
Parses an array of objects.
The values in d[key] can be either instances of cls or serialized dicts from instances of cls.
- Parameters:
d – a JSON dictionary
key – the key to parse
cls – the class of the elements of list d[key]
default – the default list to return if key is not present
- Returns:
a list of cls instances
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- static parse_object_dict(d, key, cls, default=<eta.core.config.NoDefault object>)
Parses a dictionary whose values are objects.
The values in d[key] can be either instances of cls or serialized dicts from instances of cls.
- Parameters:
d – a JSON dictionary
key – the key to parse
cls – the class of the values of dictionary d[key]
default – the default dict of cls instances to return if key is not present
- Returns:
a dictionary whose values are cls instances
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- static parse_path(d, key, default=<eta.core.config.NoDefault object>)
Parses a path attribute.
The path is converted to an absolute path if necessary via os.path.abspath(os.path.expanduser(value)).
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default string to return if key is not present
- Returns:
a path string
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
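The path conversion can be checked directly with the standard library. `normalize_path` is a hypothetical name for the one-line transformation quoted above, and the weights filename is illustrative.

```python
import os


def normalize_path(value):
    # The conversion documented above: expand "~" first, then make absolute
    return os.path.abspath(os.path.expanduser(value))


rel = normalize_path("weights/sam_vit_b.pth")
home = normalize_path("~/weights/sam_vit_b.pth")
print(rel)
print(home)
```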
- static parse_raw(d, key, default=<eta.core.config.NoDefault object>)
Parses a raw (arbitrary) JSON field.
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default value to return if key is not present
- Returns:
the raw (untouched) value of the given field
- Raises:
ConfigError – if no default value was provided and the key was not found in the dictionary
- static parse_string(d, key, default=<eta.core.config.NoDefault object>)
Parses a string attribute.
- Parameters:
d – a JSON dictionary
key – the key to parse
default – a default string to return if key is not present
- Returns:
a string
- Raises:
ConfigError – if the field value was the wrong type or no default value was provided and the key was not found in the dictionary
- serialize(reflective=False)
Serializes the object into a dictionary.
Serialization is applied recursively to all attributes in the object, including element-wise serialization of lists and dictionary values.
- Parameters:
reflective – whether to include reflective attributes when serializing the object. By default, this is False
- Returns:
a JSON dictionary representation of the object
- to_str(pretty_print=True, **kwargs)
Returns a string representation of this object.
- Parameters:
pretty_print – whether to render the JSON in human readable format with newlines and indentations. By default, this is True
**kwargs – optional keyword arguments for self.serialize()
- Returns:
a string representation of the object
- static validate_all_or_nothing_fields(fields)
Validates a dictionary of pre-parsed fields checking that either all or none of the fields have a truthy value.
- Parameters:
fields – a dictionary of pre-parsed fields
- Raises:
ConfigError – if some values are truthy and some are not
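The all-or-nothing rule is the complement of "some but not all". The sketch below uses a local stand-in `ConfigError`. The `host`/`port` field names are purely illustrative and do not come from any FiftyOne config.

```python
class ConfigError(Exception):
    """Stand-in for ETA's ConfigError."""


def validate_all_or_nothing_fields(fields):
    # Valid when every value is truthy or every value is falsy
    flags = [bool(value) for value in fields.values()]
    if any(flags) and not all(flags):
        raise ConfigError(
            f"Fields {list(fields)} must be set all together or not at all"
        )


validate_all_or_nothing_fields({"host": "localhost", "port": 8080})  # OK
validate_all_or_nothing_fields({"host": None, "port": None})  # OK
```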
- write_json(path, pretty_print=False, **kwargs)
Serializes the object and writes it to disk.
- Parameters:
path – the output path
pretty_print – whether to render the JSON in human readable format with newlines and indentations. By default, this is False
**kwargs – optional keyword arguments for self.serialize()
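The standard `json` module can illustrate the differing `pretty_print` defaults of `to_str()` (True) and `write_json()` (False). In this sketch, a plain dict stands in for the output of `serialize()`. The 4-space indentation and compact separators are assumptions, so the exact formatting ETA emits may differ.

```python
import json
import os
import tempfile


def to_str(d, pretty_print=True):
    """Sketch: render a serialized dict as a JSON string."""
    if pretty_print:
        # Human-readable: newlines and indentation
        return json.dumps(d, indent=4)
    # Compact: no newlines or extra whitespace
    return json.dumps(d, separators=(",", ":"))


def write_json(d, path, pretty_print=False):
    """Sketch: note the compact default, unlike to_str()."""
    with open(path, "w") as f:
        f.write(to_str(d, pretty_print=pretty_print))


# A plain dict standing in for the output of serialize()
config = {"model_name": "segment-anything-vitb-torch", "auto_kwargs": None}
print(to_str(config))

path = os.path.join(tempfile.mkdtemp(), "config.json")
write_json(config, path)
```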
- class fiftyone.utils.sam.SAMPromptMode(value)
Bases: Enum
Enumeration of supported prompt modes for SAM.
Attributes:
- auto = 1
- box_only = 2
- point_only = 3
- box_point_combo = 4
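The enumeration behaves like any Python `Enum`, so members can be looked up by value or by name. The sketch below redeclares the same four members locally so it runs without FiftyOne. It also adds a hypothetical helper, `infer_prompt_mode`, that picks a mode based on which prompts are available. That mapping is inferred from the member names and is an assumption, not documented FiftyOne behavior.

```python
from enum import Enum


class PromptMode(Enum):
    """Local copy of the SAMPromptMode members documented above."""

    auto = 1
    box_only = 2
    point_only = 3
    box_point_combo = 4


def infer_prompt_mode(has_boxes, has_points):
    # Hypothetical rule based on the member names (an assumption)
    if has_boxes and has_points:
        return PromptMode.box_point_combo
    if has_boxes:
        return PromptMode.box_only
    if has_points:
        return PromptMode.point_only
    return PromptMode.auto


print(PromptMode(2), PromptMode["auto"].value)
print(infer_prompt_mode(has_boxes=False, has_points=True))
```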
- class fiftyone.utils.sam.SegmentAnythingImageGetItem(field_mapping=None, transform=None, use_numpy=False, box_transform=None, point_transform=None, **kwargs)
Bases: GetItem
A GetItem that loads images, bounding boxes and/or keypoints to feed to SegmentAnythingModel instances.
- Parameters:
field_mapping (None) – the user-supplied dict mapping keys in required_keys to field names of their dataset that contain the required values
transform (None) – a SAM-specific image transform function to apply
use_numpy (False) – whether to use numpy arrays rather than PIL images and Torch tensors when loading data
box_transform (None) – a SAM-specific box transform function to apply
point_transform (None) – a SAM-specific point transform function to apply
Attributes:
required_keys | The list of keys that must exist on the dicts provided to the __call__() method at runtime.
field_mapping | A user-supplied dictionary mapping keys in required_keys to field names of their dataset that contain the required values.
- property required_keys
The list of keys that must exist on the dicts provided to the __call__() method at runtime.
- property field_mapping
A user-supplied dictionary mapping keys in required_keys to field names of their dataset that contain the required values.
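The role of `field_mapping` can be sketched as a lookup from each required key to a dataset field name. Several parts of the sketch are assumptions made for illustration: the required keys `filepath` and `prompt_field`, the rule that an unmapped key falls back to a field of the same name, and the helper names `resolve_fields` and `load_inputs`. Check `required_keys` on an actual instance for the real keys.

```python
def resolve_fields(required_keys, field_mapping=None):
    # Assumption: keys without an explicit mapping use a same-named field
    field_mapping = field_mapping or {}
    return {key: field_mapping.get(key, key) for key in required_keys}


def load_inputs(sample, required_keys, field_mapping=None):
    """Pulls the value for each required key out of a sample-like dict."""
    resolved = resolve_fields(required_keys, field_mapping)
    missing = [key for key, field in resolved.items() if field not in sample]
    if missing:
        raise KeyError(f"No field found for required keys {missing}")
    return {key: sample[field] for key, field in resolved.items()}


# Hypothetical required keys and a sample represented as a plain dict
required_keys = ["filepath", "prompt_field"]
sample = {"filepath": "/data/img_001.jpg", "ground_truth": ["box-1", "box-2"]}

inputs = load_inputs(sample, required_keys, {"prompt_field": "ground_truth"})
print(inputs)
```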
- class fiftyone.utils.sam.SegmentAnythingImageGetItemForVideo(field_mapping=None, transform=None, use_numpy=False, box_transform=None, point_transform=None, **kwargs)
Bases: SegmentAnythingImageGetItem
Workaround for applying an image model to video reader frames.
Frames are not stored on disk and therefore cannot be loaded.
Attributes:
required_keys | The list of keys that must exist on the dicts provided to the __call__() method at runtime.
field_mapping | A user-supplied dictionary mapping keys in required_keys to field names of their dataset that contain the required values.
- property required_keys
The list of keys that must exist on the dicts provided to the __call__() method at runtime.
- property field_mapping
A user-supplied dictionary mapping keys in required_keys to field names of their dataset that contain the required values.
- fiftyone.utils.sam.preprocess_detections_to_sam(detections, img_hw, box_transform)
Pre-processes boxes from fiftyone.core.labels.Detections.
- Parameters:
detections – a fiftyone.core.labels.Detections instance
img_hw – the original image height and width
box_transform – a SAM-specific box transform function to apply
- Returns:
a torch tensor of boxes for SAM model prompts
a numpy array of boxes in XYXY pixels in original image space
a list of class labels for the boxes
a list of positive/negative labels for the boxes
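The "XYXY pixels in original image space" output amounts to scaling FiftyOne's relative `[top-left-x, top-left-y, width, height]` boxes by the image size. The sketch below uses a hypothetical helper, `detections_to_xyxy`, and covers only that step on plain lists. It does not reproduce the SAM-specific `box_transform`, which would then rescale the boxes to the model's input resolution, or the conversion to torch/numpy arrays.

```python
def detections_to_xyxy(rel_boxes, img_hw):
    """Converts relative [x, y, w, h] boxes to absolute XYXY pixel boxes."""
    height, width = img_hw  # img_hw is (height, width), as documented above
    return [
        [x * width, y * height, (x + w) * width, (y + h) * height]
        for x, y, w, h in rel_boxes
    ]


# One box spanning x in [0.25, 0.75] and y in [0.5, 0.75] of a 640x480 image
boxes = detections_to_xyxy([[0.25, 0.5, 0.5, 0.25]], img_hw=(480, 640))
print(boxes)  # [[160.0, 240.0, 480.0, 360.0]]
```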
- fiftyone.utils.sam.preprocess_keypoints_to_sam(keypoints, img_hw, point_transform)
Pre-processes points from fiftyone.core.labels.Keypoints.
- Parameters:
keypoints – a fiftyone.core.labels.Keypoints instance
img_hw – the original image height and width
point_transform – a SAM-specific point transform function to apply
- Returns:
a list of torch tensors of points in XY pixels for SAM model prompts
a list of torch tensors of positive and negative labels for each point
a list of class labels for each set of points
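Similarly, keypoint prompts boil down to scaling relative `(x, y)` points to pixels and attaching a label to each point. In the sketch below, `keypoints_to_pixels` is a hypothetical helper. It assumes that missing keypoints are stored as NaN and that every kept point is a positive prompt; both assumptions are marked in the code. SAM's own convention is 1 for a foreground (positive) point and 0 for a background (negative) point.

```python
import math


def keypoints_to_pixels(rel_points, img_hw):
    """Scales relative (x, y) points to XY pixels and labels them for SAM."""
    height, width = img_hw
    coords, labels = [], []
    for x, y in rel_points:
        if math.isnan(x) or math.isnan(y):
            continue  # assumption: missing keypoints are stored as NaN
        coords.append([x * width, y * height])
        labels.append(1)  # assumption: every kept point is a positive prompt
    return coords, labels


nan = float("nan")
coords, labels = keypoints_to_pixels(
    [[0.5, 0.25], [nan, nan], [0.75, 0.5]], img_hw=(200, 400)
)
print(coords, labels)  # [[200.0, 50.0], [300.0, 100.0]] [1, 1]
```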
- class fiftyone.utils.sam.SAMSegmenterOutputProcessor(classes=None, mask_thresh=0.5, **kwargs)
Bases: OutputProcessor
Converts SAM model outputs to FiftyOne format.
- Parameters:
classes (None) – the list of class labels for the model
mask_thresh (0.5) – the threshold for converting float masks to boolean masks
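Converting a float mask to a boolean mask with `mask_thresh` is an element-wise comparison. The sketch uses nested lists so it runs without NumPy. It uses a strict `>` comparison; whether FiftyOne uses `>` or `>=` is an assumption.

```python
def threshold_mask(prob_mask, mask_thresh=0.5):
    # Assumption: strictly-greater comparison, so values equal to the
    # threshold map to False
    return [[p > mask_thresh for p in row] for row in prob_mask]


mask = threshold_mask([[0.2, 0.7], [0.5, 0.9]])
print(mask)  # [[False, True], [False, True]]
```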
- class fiftyone.utils.sam.SegmentAnythingModel(config)
Bases: TorchImageModelWithPrompts
Wrapper for running Segment Anything inference.
Box prompt example:
    import fiftyone as fo
    import fiftyone.zoo as foz

    dataset = foz.load_zoo_dataset(
        "quickstart", max_samples=25, shuffle=True, seed=51
    )

    model = foz.load_zoo_model("segment-anything-vitb-torch")

    # Prompt with boxes
    dataset.apply_model(
        model,
        label_field="segmentations",
        box_prompt_field="ground_truth",
    )

    session = fo.launch_app(dataset)
Keypoint prompt example:
    import fiftyone as fo
    import fiftyone.zoo as foz

    dataset = foz.load_zoo_dataset(
        "coco-2017",
        split="validation",
        label_types="detections",
        classes=["person"],
        max_samples=25,
        only_matching=True,
    )

    # Generate some keypoints
    model = foz.load_zoo_model("keypoint-rcnn-resnet50-fpn-coco-torch")
    dataset.default_skeleton = model.skeleton
    dataset.apply_model(model, label_field="gt")

    model = foz.load_zoo_model("segment-anything-vitb-torch")

    # Prompt with keypoints
    dataset.apply_model(
        model,
        label_field="segmentations",
        point_prompt_field="gt_keypoints",
    )

    session = fo.launch_app(dataset)
Automatic segmentation example:
    import fiftyone as fo
    import fiftyone.zoo as foz

    dataset = foz.load_zoo_dataset(
        "quickstart", max_samples=5, shuffle=True, seed=51
    )

    model = foz.load_zoo_model("segment-anything-vitb-torch")

    # Automatic segmentation
    dataset.apply_model(model, label_field="auto")

    session = fo.launch_app(dataset)
- Parameters:
config – a SegmentAnythingModelConfig
Methods:
predict_interactive([sample, boxes, points, ...]) | Generates predictions in interactive mode.
predict(img[, sample]) | Performs prediction on a single image.
predict_all(imgs[, samples]) | Performs prediction on multiple images.
build_get_item([field_mapping]) | Builds a SegmentAnythingImageGetItem for loading model input from samples.
collate_fn(batch) | Collates a batch of inputs where each input is generated from SegmentAnythingImageGetItem.
embed(arg) | Generates an embedding for the given data.
embed_all(args) | Generates embeddings for the given iterable of data.
from_config(config) | Instantiates a Configurable class from a <cls>Config instance.
from_dict(d) | Instantiates a Configurable class from a <cls>Config dict.
from_json(json_path) | Instantiates a Configurable class from a <cls>Config JSON file.
from_kwargs(**kwargs) | Instantiates a Configurable class from keyword arguments defining the attributes of a <cls>Config.
get_embeddings() | Returns the embeddings generated by the last forward pass of the model.
parse(class_name[, module_name]) | Parses a Configurable subclass name string.
validate(config) | Validates that the given config is an instance of <cls>Config.
Attributes:
ragged_batches | Whether transforms() may return tensors of different sizes.
has_collate_fn | Whether this model has a custom collate function.
can_embed_prompts | Whether this model can generate prompt embeddings.
classes | The list of class labels for the model, if known.
device | The torch.device that the model is using.
has_embeddings | Whether this model has embeddings.
has_logits | Whether this instance can generate logits.
mask_targets | The mask targets for the model, if any.
media_type | The media type processed by the model.
num_classes | The number of classes for the model, if known.
preprocess | Whether to apply preprocessing transforms for inference, if any.
required_keys | The required keys that must be provided as parameters to methods like apply_model() and compute_embeddings() at runtime.
skeleton | The keypoint skeleton for the model, if any.
store_logits | Whether the model should store logits in its predictions.
transforms | A torchvision.transforms function that will be applied to each input before prediction, if any.
using_gpu | Whether the model is using GPU.
using_half_precision | Whether the model is using half precision.
- predict_interactive(sample=None, boxes=None, points=None, point_labels=None, prompt_classes=None, boxes_xyxy=None)
Generates predictions in interactive mode. Image embedding is cached.
- Parameters:
sample (None) – a FiftyOne Sample with image media
boxes (None) – a tensor of Bx4 pre-processed SAM transformed boxes in XYXY pixels
points (None) – a tensor of BxNx2 or a list of B tensors with pre-processed points in XY pixels
point_labels (None) – a BxN tensor or a list of B tensors of labels for the point prompts
prompt_classes (None) – a list of B class labels
boxes_xyxy (None) – a list of Bx4 boxes in XYXY pixels in original image space
- Returns:
a fiftyone.core.labels.Detections instance or a dict containing the “masks”, “iou_predictions”, and “low_res_logits” from the SAM model output.
- predict(img, sample=None)
Performs prediction on a single image.
- Parameters:
img – a dictionary containing the image, its original size, and prompts. See fiftyone.utils.sam.SegmentAnythingImageGetItem for details.
sample (None) – no longer used; retained for backward compatibility.
- Returns:
a fiftyone.core.labels.Detections instance or a dict containing the “masks”, “iou_predictions”, and “low_res_logits” from the SAM model output.
- predict_all(imgs, samples=None)
Performs prediction on multiple images.
To generate the imgs dictionaries and run prediction:
    import fiftyone.utils.torch as fout

    # `model` is a loaded SegmentAnythingModel; `samples` is a sample collection
    field_mapping = {"box_prompt_field": "ground_truth"}
    get_item = model.build_get_item(field_mapping=field_mapping)
    model_inputs = fout.get_model_inputs_from_get_item(samples, get_item)
    outputs = model.predict_all(model_inputs)
- Parameters:
imgs – a list of dictionaries or a single dictionary containing images, original sizes, and prompts. See fiftyone.utils.sam.SegmentAnythingImageGetItem for details.
samples (None) – no longer used; retained for backward compatibility.
- Returns:
a list of fiftyone.core.labels.Detections instances or a list of dicts containing the “masks”, “iou_predictions”, and “low_res_logits” from the SAM model output.
- build_get_item(field_mapping=None)
Builds a SegmentAnythingImageGetItem for loading model input from samples.
- Parameters:
field_mapping (None) – a dict mapping required keys to sample fields
- Returns:
a SegmentAnythingImageGetItem instance
- property ragged_batches
Whether transforms() may return tensors of different sizes. If True, then passing ragged lists of images to predict_all() may not be allowed.
- property has_collate_fn
Whether this model has a custom collate function.
Set this to True if you want collate_fn() to be used during inference.
- static collate_fn(batch)
Collates a batch of inputs where each input is generated from SegmentAnythingImageGetItem.
- Parameters:
batch – a list of dicts containing model input from SegmentAnythingImageGetItem
- Returns:
a collated dictionary of model input for the batch. Expected keys are:
“image”: a list of torch tensors of shape (1 x C x H x W) or HWC numpy arrays
“boxes”: a list of B x 4 boxes for SAM model input
“boxes_xyxy”: a list of B x 4 boxes in XYXY pixels in original image space
“boxes_labels”: a list of B x N positive/negative labels for boxes
“point_coords”: a list of B x N x 2 point coordinates, padded as needed
“point_labels”: a list of B x N point positive/negative labels, padded as needed
“prompt_type”: the name of the prompt type for the batch
“classes”: a list of classes for each prompt
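"Padded as needed" means that prompts with fewer points are extended to a common length N, so the point coordinates form a B x N x 2 array. The sketch below, `pad_point_prompts`, is hypothetical. It pads coordinates with `(0, 0)` and labels with -1, an assumption borrowed from Segment Anything's prompt encoder, which treats label -1 as "not a point". The values `collate_fn()` actually uses may differ.

```python
def pad_point_prompts(point_coords, point_labels, pad_label=-1):
    """Pads per-prompt point lists (and their labels) to a common length N."""
    n = max((len(points) for points in point_coords), default=0)
    padded_coords, padded_labels = [], []
    for points, labels in zip(point_coords, point_labels):
        extra = n - len(points)
        # Fresh [0.0, 0.0] lists avoid aliasing one pad entry across rows
        padded_coords.append(list(points) + [[0.0, 0.0] for _ in range(extra)])
        padded_labels.append(list(labels) + [pad_label] * extra)
    return padded_coords, padded_labels


# Two prompts: the first has one point, the second has two
coords, labels = pad_point_prompts(
    [[[10.0, 20.0]], [[30.0, 40.0], [50.0, 60.0]]],
    [[1], [1, 0]],
)
print(coords)  # [[[10.0, 20.0], [0.0, 0.0]], [[30.0, 40.0], [50.0, 60.0]]]
print(labels)  # [[1, -1], [1, 0]]
```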
- property can_embed_prompts
Whether this model can generate prompt embeddings.
This property returns False by default. Models that can generate prompt embeddings should override it by implementing the PromptMixin interface.
- property classes
The list of class labels for the model, if known.
- property device
The torch.device that the model is using.
- embed(arg)
Generates an embedding for the given data.
Subclasses can override this method to increase efficiency, but, by default, this method simply calls predict() and then returns get_embeddings().
- Parameters:
arg – the data. See predict() for details
- Returns:
a numpy array containing the embedding
- embed_all(args)
Generates embeddings for the given iterable of data.
Subclasses can override this method to increase efficiency, but, by default, this method simply iterates over the data and applies embed() to each.
- Parameters:
args – an iterable of data. See predict_all() for details
- Returns:
a numpy array containing the embeddings stacked along axis 0
- classmethod from_config(config)
Instantiates a Configurable class from a <cls>Config instance.
- classmethod from_dict(d)
Instantiates a Configurable class from a <cls>Config dict.
- Parameters:
d – a dict to construct a <cls>Config
- Returns:
an instance of cls
- classmethod from_json(json_path)
Instantiates a Configurable class from a <cls>Config JSON file.
- Parameters:
json_path – path to a JSON file for type <cls>Config
- Returns:
an instance of cls
- classmethod from_kwargs(**kwargs)
Instantiates a Configurable class from keyword arguments defining the attributes of a <cls>Config.
- Parameters:
**kwargs – keyword arguments that define the fields of a <cls>Config dict
- Returns:
an instance of cls
- get_embeddings()
Returns the embeddings generated by the last forward pass of the model.
By convention, this method should always return an array whose first axis represents batch size (which will always be 1 when predict() was last used).
- Returns:
a numpy array containing the embedding(s)
- property has_embeddings
Whether this model has embeddings.
- property has_logits
Whether this instance can generate logits.
- property mask_targets
The mask targets for the model, if any.
- property media_type
The media type processed by the model.
- property num_classes
The number of classes for the model, if known.
- static parse(class_name, module_name=None)
Parses a Configurable subclass name string.
Assumes both the Configurable class and the Config class are defined in the same module. The module containing the classes will be loaded if necessary.
- Parameters:
class_name – a string containing the name of the Configurable class, e.g. “ClassName”, or a fully-qualified class name, e.g. “eta.core.config.ClassName”
module_name – a string containing the fully-qualified module name, e.g. “eta.core.config”, or None if class_name includes the module name. Set module_name = __name__ to load a class from the calling module
- Returns:
cls: the Configurable class
config_cls: the Config class associated with cls
Whether to apply preprocessing transforms for inference, if any.
- property required_keys
The required keys that must be provided as parameters to methods like apply_model() and compute_embeddings() at runtime.
- property skeleton
The keypoint skeleton for the model, if any.
- property store_logits
Whether the model should store logits in its predictions.
- property transforms
A torchvision.transforms function that will be applied to each input before prediction, if any.
- property using_gpu
Whether the model is using GPU.
- property using_half_precision
Whether the model is using half precision.
- classmethod validate(config)
Validates that the given config is an instance of <cls>Config.
- Raises:
ConfigurableError – if config is not an instance of <cls>Config