Source code for smqtk_relevancy.impls.rank_relevancy.margin_sampling

from typing import Hashable, Sequence, Tuple, Dict, Any, TypeVar, Type

from numpy import ndarray

from smqtk_core.configuration import (
    from_config_dict,
    make_default_config,
    to_config_dict,
)
from smqtk_core.dict import merge_dict
from smqtk_relevancy.interfaces.rank_relevancy import RankRelevancy, RankRelevancyWithFeedback

T = TypeVar("T", bound="RankRelevancyWithMarginSampledFeedback")


class RankRelevancyWithMarginSampledFeedback(RankRelevancyWithFeedback):
    """
    Decorate a :class:`RankRelevancy` instance so that ranking also yields
    feedback requests chosen by margin sampling.

    The UIDs requested for feedback are those of the pool items whose
    relevancy score lies nearest to ``center``.

    :param rank_relevancy: :class:`RankRelevancy` that computes the
        relevancy scores.
    :param n: Upper bound on the number of UIDs returned for feedback.
    :param center: Score value that feedback candidates should be closest
        to (default: 0.5).

    :raises ValueError: ``n`` is negative.
    """

    def __init__(self, rank_relevancy: RankRelevancy, n: int, center: float = 0.5):
        self._rank_relevancy = rank_relevancy
        if n < 0:
            raise ValueError(f"n must be nonnegative but got {n}")
        self._center = center
        self._n = n

    def _rank_with_feedback(
        self,
        pos: Sequence[ndarray],
        neg: Sequence[ndarray],
        pool: Sequence[ndarray],
        pool_uids: Sequence[Hashable],
    ) -> Tuple[Sequence[float], Sequence[Hashable]]:
        """
        Score the pool with the wrapped ranker and pick feedback candidates.

        :param pos: Positive exemplars, forwarded to the wrapped ranker.
        :param neg: Negative exemplars, forwarded to the wrapped ranker.
        :param pool: Pool items to score, forwarded to the wrapped ranker.
        :param pool_uids: UIDs paired positionally with the scores.

        :return: The scores exactly as produced by the wrapped ranker, and
            at most ``n`` UIDs ordered from the score closest to ``center``
            outwards.
        """
        relevancy = self._rank_relevancy.rank(pos, neg, pool)
        center = self._center

        def distance_to_center(score_and_uid: Tuple[Any, Hashable]) -> Any:
            # Only the score (first element) decides the ordering.
            return abs(score_and_uid[0] - center)

        candidates = list(zip(relevancy, pool_uids))
        # The sort is stable, so equally-distant items keep their pool order.
        candidates.sort(key=distance_to_center)

        feedback_uids = [uid for _, uid in candidates[:self._n]]
        return relevancy, feedback_uids

    @classmethod
    def from_config(
        cls: Type[T],
        config_dict: Dict[str, Any],
        merge_default: bool = True,
    ) -> T:
        """
        Build an instance from a configuration dictionary.

        The ``rank_relevancy`` entry is first turned into an object via
        ``from_config_dict`` over the available :class:`RankRelevancy`
        implementations; the rest is delegated to the parent class.

        :param config_dict: Configuration to instantiate from. It is copied,
            not modified in place.
        :param merge_default: Passed through unchanged to the parent
            ``from_config``.

        :return: The instance produced by the parent ``from_config``.
        """
        wrapped = from_config_dict(
            config_dict['rank_relevancy'],
            RankRelevancy.get_impls(),
        )
        # Work on a shallow copy so the caller's dictionary stays untouched.
        resolved = dict(config_dict)
        resolved['rank_relevancy'] = wrapped
        return super().from_config(resolved, merge_default=merge_default)

    @classmethod
    def get_default_config(cls) -> Dict[str, Any]:
        """
        Return the parent's default configuration with ``rank_relevancy``
        replaced by the default multi-implementation configuration for
        :class:`RankRelevancy`.
        """
        defaults = dict(super().get_default_config())
        defaults['rank_relevancy'] = make_default_config(
            RankRelevancy.get_impls()
        )
        return defaults

    def get_config(self) -> Dict[str, Any]:
        """
        Return this instance's configuration: the default configuration
        overlaid with the wrapped ranker's configuration, ``n`` and
        ``center``.
        """
        base = self.get_default_config()
        current = {
            'rank_relevancy': to_config_dict(self._rank_relevancy),
            'n': self._n,
            'center': self._center,
        }
        return merge_dict(base, current)