Source code for smqtk_relevancy.impls.rank_relevancy.wrap_classifier

from typing import Sequence, Dict, Any, TypeVar, Type

import numpy as np

from smqtk_classifier import ClassifyDescriptorSupervised
from smqtk_descriptors import DescriptorElement
from smqtk_descriptors.impls.descriptor_element.memory import DescriptorMemoryElement
from smqtk_core.configuration import (
    from_config_dict,
    make_default_config,
    cls_conf_to_config_dict,
)

from smqtk_relevancy.interfaces.rank_relevancy import RankRelevancy


# Type variable bound to this module's ranking class. It annotates
# ``from_config`` (``cls: Type[T]`` -> ``T``) so the alternate constructor is
# typed as returning an instance of the class it is invoked on.
T = TypeVar("T", bound="RankRelevancyWithSupervisedClassifier")


class RankRelevancyWithSupervisedClassifier(RankRelevancy):
    """
    Rank relevancy by training a supervised classifier on the fly and using
    it for inference.

    Although the class name only says "supervised classifier," the interface
    actually used is the one for descriptor classification
    (``ClassifyDescriptorSupervised``), not the interfaces for other
    modalities (like images).

    # Classifier "cloning"
    The classifier instance given to the constructor is never used directly.
    Only its type and configuration are recorded, and ``rank`` builds a new
    instance from them to train and classify with. The caveat is that runtime
    modifications to the input classifier that are not reflected in its
    configuration will not be seen by the classifier used in ``rank``.

    Working on a copy of the input classifier allows ``rank`` to be used in
    parallel without blocking other calls to ``rank``.

    :param classifier_inst:
        Supervised classifier instance to base the ephemeral ranking
        classifier on. Its type and configuration are used to create a clone
        at rank time. The input classifier instance is not modified.
    """

    def __init__(self, classifier_inst: ClassifyDescriptorSupervised):
        super().__init__()
        # Keep only what is needed to rebuild an equivalent classifier later;
        # no reference to ``classifier_inst`` itself is retained.
        self._classifier_config = classifier_inst.get_config()
        self._classifier_type = type(classifier_inst)

    @classmethod
    def get_default_config(cls) -> Dict[str, Any]:
        """
        Build the default configuration for this class.

        The parent's default configuration is taken as the base and its
        ``classifier_inst`` entry is overwritten with the default
        configuration generated from the available
        ``ClassifyDescriptorSupervised`` implementations.

        :return: Default configuration dictionary.
        """
        default_cfg = super().get_default_config()
        classifier_impls = ClassifyDescriptorSupervised.get_impls()
        default_cfg['classifier_inst'] = make_default_config(classifier_impls)
        return default_cfg

    @classmethod
    def from_config(
        cls: Type[T],
        config_dict: Dict[str, Any],
        merge_default: bool = True,
    ) -> T:
        """
        Create an instance of this class from a configuration dictionary.

        The ``classifier_inst`` entry of the configuration (an empty
        dictionary when the key is absent) is turned into a classifier
        instance via ``from_config_dict`` before construction is delegated
        to the parent implementation.

        :param config_dict: Configuration dictionary. It is not modified.
        :param merge_default: Forwarded unchanged to the parent
            ``from_config``.

        :return: New instance built from the configuration.
        """
        # Shallow copy so the caller's dictionary is not written to.
        local_cfg = dict(config_dict)
        classifier_cfg = local_cfg.get('classifier_inst', {})
        local_cfg['classifier_inst'] = from_config_dict(
            classifier_cfg,
            ClassifyDescriptorSupervised.get_impls(),
        )
        return super(RankRelevancyWithSupervisedClassifier, cls).from_config(
            local_cfg,
            merge_default=merge_default,
        )

    def get_config(self) -> Dict[str, Any]:
        """
        Return the configuration of this instance.

        :return: Dictionary whose single ``classifier_inst`` entry is built
            by ``cls_conf_to_config_dict`` from the classifier type and
            configuration recorded at construction.
        """
        classifier_entry = cls_conf_to_config_dict(
            self._classifier_type,
            self._classifier_config,
        )
        return {'classifier_inst': classifier_entry}

    def rank(
        self,
        pos: Sequence[np.ndarray],
        neg: Sequence[np.ndarray],
        pool: Sequence[np.ndarray],
    ) -> Sequence[float]:
        """
        Score every vector in ``pool`` with a classifier trained on the given
        positive and negative example vectors.

        A new classifier is created from the recorded type and configuration
        on each call, trained with ``pos`` under the label ``'pos'`` and
        ``neg`` under the label ``'neg'``, and then applied to ``pool``.

        :param pos: Example vectors used as the ``'pos'`` training class.
        :param neg: Example vectors used as the ``'neg'`` training class.
        :param pool: Vectors to score.

        :return: One score per classification result for ``pool``: the value
            stored under ``'pos'``, or ``0.0`` when that label is missing.
            An empty list is returned, without creating or training a
            classifier, when ``pool`` is empty.
        """
        # Explicit length check rather than truthiness of ``pool``.
        # NOTE(review): presumably kept so an ndarray-like pool cannot raise
        # on an ambiguous truth value -- confirm against callers.
        if len(pool) == 0:
            return []

        positive_label = 'pos'
        negative_label = 'neg'

        # Fresh classifier per call, rebuilt from the recorded type/config.
        ephemeral = self._classifier_type.from_config(self._classifier_config)

        # Running integer handed to each created element. It advances in the
        # order elements are actually produced, which happens lazily while
        # the classifier consumes the training iterables.
        next_id = 0

        def to_element(vector: np.ndarray) -> DescriptorElement:
            nonlocal next_id
            # The integer is the element's only constructor argument
            # (presumably its UID -- confirm against DescriptorMemoryElement).
            element = DescriptorMemoryElement(next_id)
            element.set_vector(vector)
            next_id += 1
            return element

        # Train supervised classifier with positive/negative examples.
        training_set = {
            positive_label: (to_element(vec) for vec in pos),
            negative_label: (to_element(vec) for vec in neg),
        }
        ephemeral.train(training_set)

        # Report the ``positive_label`` class value as the rank score.
        classifications = ephemeral.classify_arrays(pool)
        return [
            label_map.get(positive_label, 0.0)
            for label_map in classifications
        ]