Gating

BlurGate

Source code in video_sampler/gating.py
class BlurGate:
    def __init__(
        self, method: Literal["fft", "laplacian"] = "laplacian", threshold: float = 100
    ) -> None:
        """
        Initializes the Gating object.

        Args:
            method (str): The method to use for blur detection. Can be "fft" or "laplacian".
            threshold (float): The threshold for bluriness. The higher the threshold, the less
                blurry the image needs to be to be discarded.
                The default threshold values are:
                - 20 for the "fft" method
                - 100 for the "laplacian" method.

        Raises:
            ValueError: If an unknown blur method is provided.
        """
        self.is_blurry = None
        if method == "fft":
            self.is_blurry = self._is_blurry_fft
        elif method == "laplacian":
            self.is_blurry = self._is_blurry_laplacian
        else:
            raise ValueError(f"Unknown blur method {method}")
        self.threshold = threshold

    def __call__(self, frame: Image.Image, meta: dict, last=False) -> GatedObject:
        if self.is_blurry(frame) or last:
            return EMPTY_GATED_OBJECT
        return GatedObject([FrameObject(frame, meta)], 1)

    def _is_blurry_laplacian(self, frame: Image.Image) -> bool:
        """Check if the image is blurry with the Laplacian method."""
        # PIL frames are RGB, so convert with COLOR_RGB2GRAY before taking
        # the variance of the Laplacian.
        gray = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2GRAY)
        return cv2.Laplacian(gray, cv2.CV_64F).var() < self.threshold

    def _is_blurry_fft(self, frame: Image.Image) -> bool:
        """Check if the image is blurry with the FFT method."""
        # Work on a grayscale array; a low mean log-magnitude spectrum means
        # little high-frequency content, i.e. a blurry frame.
        f = np.fft.fft2(np.asarray(frame.convert("L")))
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-12)
        return magnitude_spectrum.mean() < self.threshold

    def flush(self):
        return EMPTY_GATED_OBJECT
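
Both detectors threshold a measure of high-frequency content. A standalone sketch of the Laplacian score on a synthetic sharp/blurred pair (the helper name and test image are illustrative, not part of the library):

import cv2
import numpy as np
from PIL import Image, ImageFilter

# Synthetic sharp image (noise) and a blurred copy of it.
rng = np.random.default_rng(0)
sharp = Image.fromarray(rng.integers(0, 255, (128, 128, 3), dtype=np.uint8))
blurred = sharp.filter(ImageFilter.GaussianBlur(radius=4))

def laplacian_score(img: Image.Image) -> float:
    gray = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
    return cv2.Laplacian(gray, cv2.CV_64F).var()

print(laplacian_score(sharp))    # large variance -> sharp, passes the gate
print(laplacian_score(blurred))  # small variance -> blurry, discarded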

__init__(method='laplacian', threshold=100)

Initializes the BlurGate object.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| method | str | The method to use for blur detection. Can be "fft" or "laplacian". | 'laplacian' |
| threshold | float | The blurriness threshold. The higher the threshold, the sharper a frame must be to pass the gate. Sensible defaults are 20 for the "fft" method and 100 for the "laplacian" method. | 100 |

Raises:

| Type | Description |
| ---- | ----------- |
| ValueError | If an unknown blur method is provided. |
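
A minimal usage sketch, assuming frames are already decoded to PIL images (the file path and metadata keys are illustrative):

from PIL import Image
from video_sampler.gating import BlurGate

gate = BlurGate(method="fft", threshold=20)  # ~20 suits the "fft" method
frame = Image.open("frame_0001.png")  # hypothetical frame on disk
gated = gate(frame, meta={"frame_idx": 1})
# A sharp frame comes back wrapped in a GatedObject with one FrameObject;
# a blurry frame (or a call with last=True) yields EMPTY_GATED_OBJECT.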

ClipGate

Source code in video_sampler/gating.py
class ClipGate:
    def __init__(
        self,
        pos_samples: list[str] = None,
        neg_samples: list[str] = None,
        model_name: str = "ViT-B-32",
        batch_size: int = 32,
        pos_margin: float = 0.2,
        neg_margin: float = 0.3,
    ) -> None:
        """
        Initializes the Clip Gating object.

        Args:
            pos_samples (list[str], optional): List of positive samples. Defaults to None.
            neg_samples (list[str], optional): List of negative samples. Defaults to None.
            model_name (str, optional): Name of the model. Defaults to "ViT-B-32".
            batch_size (int, optional): Batch size. Defaults to 32.
            pos_margin (float, optional): Positive margin. Defaults to 0.2.
            neg_margin (float, optional): Negative margin. Defaults to 0.3.
        """
        self.model, self.preprocess, self.tokenizer = create_model(
            model_name=model_name
        )
        self.pos_margin = pos_margin
        self.neg_margin = neg_margin
        self.batch_size = batch_size
        self.frame_accumulator = []
        self.metadata_accumulator = []
        if pos_samples is None:
            self.pos_samples = torch.zeros((1, 512))
        else:
            self.pos_samples = self._preproc_samples(pos_samples)
        if neg_samples is None:
            self.neg_samples = torch.zeros((1, 512))
        else:
            self.neg_samples = self._preproc_samples(neg_samples)

    def __call__(self, frame: Image.Image, meta: dict, last=False) -> Any:
        return self.flush() if last else self.add_frame(frame, meta)

    def _preproc_samples(self, sample_texts: list[str]):
        """Embed and L2-normalize the text prompts in batches."""
        inputs = self.tokenizer(sample_texts)
        # 512 is the embedding width of the default ViT-B-32 model.
        embeds = torch.zeros((len(sample_texts), 512))
        with torch.no_grad():
            for i, batch in enumerate(batched(inputs, n=self.batch_size)):
                batch = torch.stack(batch)
                text_embeds = self.model.encode_text(batch.to(DEVICE))
                embeds[i * self.batch_size : (i + 1) * self.batch_size] = (
                    text_embeds.cpu()
                )
        embeds /= embeds.norm(dim=-1, keepdim=True)
        return embeds

    def _embed_frames(self, frames: list[Image.Image]):
        """Compute the embeddings for each frame."""
        inputs = torch.stack([self.preprocess(frame) for frame in frames]).to(DEVICE)
        with torch.no_grad():
            image_embeds = self.model.encode_image(inputs).cpu()
            image_embeds /= image_embeds.norm(dim=-1, keepdim=True)
        return image_embeds

    def _get_margins(self, frame_embeds: "torch.Tensor"):
        """Compute the margins for each frame."""
        org_indx = np.arange(frame_embeds.shape[0])
        neg_distance = frame_embeds @ self.neg_samples.T
        pos_distance = frame_embeds @ self.pos_samples.T
        neg_margin, _ = neg_distance.max(axis=-1)
        pos_margin, _ = pos_distance.max(axis=-1)
        incl_samples = torch.argwhere(
            (neg_margin < self.neg_margin) & (pos_margin >= self.pos_margin)
        )
        return org_indx[incl_samples].ravel()

    def add_frame(self, frame: Image.Image, metadata: dict) -> GatedObject:
        self.frame_accumulator.append(frame)
        self.metadata_accumulator.append(metadata)
        if len(self.frame_accumulator) == self.batch_size:
            return self.__process_metadata()
        return EMPTY_GATED_OBJECT

    def flush(self):
        return self.__process_metadata()

    def __process_metadata(self) -> GatedObject:
        # Guard against an empty accumulator (e.g. a flush right after a full
        # batch was emitted); torch.stack would fail on an empty list.
        if not self.frame_accumulator:
            return EMPTY_GATED_OBJECT
        frame_embeds = self._embed_frames(self.frame_accumulator)
        selected_frames = self._get_margins(frame_embeds)
        to_return = [
            FrameObject(self.frame_accumulator[i], self.metadata_accumulator[i])
            for i in range(len(self.frame_accumulator))
            if i in selected_frames
        ]
        self.frame_accumulator.clear()
        self.metadata_accumulator.clear()
        return GatedObject(to_return, len(selected_frames))
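
A frame survives _get_margins only if its best similarity to every negative prompt stays below neg_margin while its best similarity to some positive prompt reaches pos_margin. A toy sketch of that keep rule with stand-in similarity scores (real scores come from the CLIP embeddings):

import torch

# Rows are frames; columns are prompts. Values are stand-in cosine
# similarities, not real CLIP outputs.
pos_sim = torch.tensor([[0.25, 0.10, 0.05],   # frame 0 vs 3 positive prompts
                        [0.15, 0.12, 0.08]])  # frame 1
neg_sim = torch.tensor([[0.10, 0.05],         # frame 0 vs 2 negative prompts
                        [0.35, 0.20]])        # frame 1

pos_margin, neg_margin = 0.2, 0.3
keep = (neg_sim.max(dim=-1).values < neg_margin) & (
    pos_sim.max(dim=-1).values >= pos_margin
)
print(keep)  # tensor([ True, False]): only frame 0 passes the gate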

__init__(pos_samples=None, neg_samples=None, model_name='ViT-B-32', batch_size=32, pos_margin=0.2, neg_margin=0.3)

Initializes the ClipGate object.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| pos_samples | list[str] | Text prompts describing content to keep. Defaults to None. | None |
| neg_samples | list[str] | Text prompts describing content to discard. Defaults to None. | None |
| model_name | str | Name of the CLIP model. Defaults to "ViT-B-32". | 'ViT-B-32' |
| batch_size | int | Batch size for embedding frames and prompts. Defaults to 32. | 32 |
| pos_margin | float | Minimum similarity to a positive prompt required to keep a frame. Defaults to 0.2. | 0.2 |
| neg_margin | float | Maximum similarity to a negative prompt allowed before a frame is discarded. Defaults to 0.3. | 0.3 |
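
A usage sketch (prompts are illustrative; frames are buffered internally and only released once batch_size frames have accumulated, or on the final last=True call):

from video_sampler.gating import ClipGate

gate = ClipGate(
    pos_samples=["a photo of a dog"],
    neg_samples=["a black screen", "a blurry frame"],
    batch_size=32,
)
# gated = gate(frame, meta)             # EMPTY_GATED_OBJECT while buffering
# final = gate(frame, meta, last=True)  # flushes whatever is buffered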

PassGate

Source code in video_sampler/gating.py
class PassGate:
    def __call__(self, frame: Image.Image, meta: dict, last=False) -> GatedObject:
        """
        Passes the frame through the gating mechanism.

        Args:
            frame (Image.Image): The frame to pass through.
            meta (dict): The metadata for the frame.
            last (bool): If this is the last frame in the video.

        Returns:
            GatedObject: The gated object containing the processed frame.
        """
        return self.flush() if last else GatedObject([FrameObject(frame, meta)], 1)

    def flush(self):
        return EMPTY_GATED_OBJECT

__call__(frame, meta, last=False)

Passes the frame through the gating mechanism.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| frame | Image | The frame to pass through. | required |
| meta | dict | The metadata for the frame. | required |
| last | bool | If this is the last frame in the video. | False |

Returns:

| Name | Type | Description |
| ---- | ---- | ----------- |
| GatedObject | GatedObject | The gated object containing the processed frame. |
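
PassGate is the no-op baseline: every frame is forwarded immediately and nothing is buffered. A minimal sketch (the placeholder frame is illustrative):

from PIL import Image
from video_sampler.gating import PassGate

gate = PassGate()
frame = Image.new("RGB", (64, 64))  # placeholder frame
gated = gate(frame, meta={"frame_idx": 0})
# -> GatedObject wrapping this single frame; last=True instead returns
#    EMPTY_GATED_OBJECT, since there is nothing to flush.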

create_gate(gate_config)

Create a gate from a configuration.

Source code in video_sampler/gating.py
def create_gate(gate_config: dict) -> BlurGate | ClipGate | PassGate:
    """Create a gate from a configuration."""
    # Note: pops the "type" key, so the passed-in config dict is mutated.
    gate_type = gate_config.pop("type")
    if gate_type == "pass":
        return PassGate()
    elif gate_type == "clip":
        return ClipGate(**gate_config)
    elif gate_type == "blur":
        return BlurGate(**gate_config)
    else:
        raise ValueError(f"Unknown gate type {gate_type}")
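
Example configurations for each gate type; every key other than "type" is forwarded as a keyword argument to the gate's constructor (values are illustrative):

from video_sampler.gating import create_gate

blur_gate = create_gate({"type": "blur", "method": "fft", "threshold": 20})
clip_gate = create_gate(
    {"type": "clip", "pos_samples": ["a photo of a dog"], "neg_margin": 0.3}
)
pass_gate = create_gate({"type": "pass"})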