zero_shot_classification

mindnlp.transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline

Bases: ChunkPipeline

NLI-based zero-shot classification pipeline using a ModelForSequenceClassification trained on NLI (natural language inference) tasks. It is the equivalent of the text-classification pipelines, except that these models don't require a hardcoded number of potential classes; the classes can be chosen at runtime. This usually makes the pipeline slower, but much more flexible.

Any combination of sequences and labels can be passed. Each combination is posed as a premise/hypothesis pair and passed to the pretrained model; the logit for entailment is then taken as the logit for the candidate label being valid. Any NLI model can be used, but the id of the entailment label must be included in the model config's label2id mapping (transformers.PretrainedConfig.label2id).
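As a rough illustration of the pairing described above, the sketch below builds one premise/hypothesis pair per candidate label. `build_pairs` is a hypothetical helper written for this page, not a mindnlp function; the only detail taken from the documentation is the default template `"This example is {}."`.

```python
from typing import List, Tuple


def build_pairs(
    sequence: str,
    labels: List[str],
    template: str = "This example is {}.",
) -> List[Tuple[str, str]]:
    """Pair the sequence (premise) with one hypothesis per candidate label.

    Hypothetical helper for illustration only; the real pipeline delegates
    this step to its argument handler.
    """
    return [(sequence, template.format(label)) for label in labels]


pairs = build_pairs("I need my phone fixed today", ["urgent", "phone"])
for premise, hypothesis in pairs:
    print(f"premise={premise!r} hypothesis={hypothesis!r}")
```

Each pair is then scored independently by the NLI model, which is why the cost grows with the number of candidate labels.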

Example
>>> from transformers import pipeline
...
>>> oracle = pipeline(model="facebook/bart-large-mnli")
>>> oracle(
...     "I have a problem with my iphone that needs to be resolved asap!!",
...     candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"],
... )
{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!',
    'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'],
    'scores': [0.504, 0.479, 0.013, 0.003, 0.002]}
...
>>> oracle(
...     "I have a problem with my iphone that needs to be resolved asap!!",
...     candidate_labels=["english", "german"],
... )
{'sequence': 'I have a problem with my iphone that needs to be resolved asap!!',
    'labels': ['english', 'german'], 'scores': [0.814, 0.186]}

Learn more about the basics of using a pipeline in the pipeline tutorial.

This NLI pipeline can currently be loaded from pipeline() using the following task identifier: "zero-shot-classification".

The models that this pipeline can use are models that have been fine-tuned on an NLI task. See the up-to-date list of available models on hf-mirror.com/models.

Source code in mindnlp/transformers/pipelines/zero_shot_classification.py
class ZeroShotClassificationPipeline(ChunkPipeline):
    """
    NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural
    language inference) tasks. Equivalent of `text-classification` pipelines, except that these models don't
    require a hardcoded number of potential classes; the classes can be chosen at runtime. This usually makes the
    pipeline slower, but **much** more flexible.

    Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
    pair and passed to the pretrained model. Then, the logit for *entailment* is taken as the logit for the candidate
    label being valid. Any NLI model can be used, but the id of the *entailment* label must be included in the model
    config's `label2id` mapping (`transformers.PretrainedConfig.label2id`).

    Example:
        ```python
        >>> from transformers import pipeline
        ...
        >>> oracle = pipeline(model="facebook/bart-large-mnli")
        >>> oracle(
        ...     "I have a problem with my iphone that needs to be resolved asap!!",
        ...     candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"],
        ... )
        {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!',
            'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'],
            'scores': [0.504, 0.479, 0.013, 0.003, 0.002]}
        ...
        >>> oracle(
        ...     "I have a problem with my iphone that needs to be resolved asap!!",
        ...     candidate_labels=["english", "german"],
        ... )
        {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!',
            'labels': ['english', 'german'], 'scores': [0.814, 0.186]}
        ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This NLI pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"zero-shot-classification"`.

    The models that this pipeline can use are models that have been fine-tuned on an NLI task.
    See the up-to-date list
    of available models on [hf-mirror.com/models](https://hf-mirror.com/models?search=nli).
    """

    def __init__(self, *args, args_parser=ZeroShotClassificationArgumentHandler(), **kwargs):
        """
        Initializes a new instance of the ZeroShotClassificationPipeline class.

        Args:
            self: The instance of the ZeroShotClassificationPipeline class.
            *args: Variable length argument list.
            args_parser:
                An instance of the ZeroShotClassificationArgumentHandler class that handles the arguments for
                zero-shot classification. Defaults to ZeroShotClassificationArgumentHandler().
            **kwargs: Keyword arguments.

        Returns:
            None.

        Raises:
            None.
        """
        self._args_parser = args_parser
        super().__init__(*args, **kwargs)
        if self.entailment_id == -1:
            logger.warning(
                "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
                "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs."
            )

    @property
    def entailment_id(self):
        """
        Returns the index of the 'entailment' label in the label-to-identifier mapping of the
        ZeroShotClassificationPipeline's model configuration.

        Args:
            self (ZeroShotClassificationPipeline): The current instance of the ZeroShotClassificationPipeline class.

        Returns:
            int: The index of the 'entailment' label in the label-to-identifier mapping. If the 'entailment' label is
                not found, -1 is returned.

        Raises:
            None.

        """
        for label, ind in self.model.config.label2id.items():
            if label.lower().startswith("entail"):
                return ind
        return -1

    def _parse_and_tokenize(
            self, sequence_pairs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.ONLY_FIRST
    ):
        """
        Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
        """
        return_tensors = 'ms'
        if self.tokenizer.pad_token is None:
            # Override for tokenizers not supporting padding
            logger.error(
                "Tokenizer was not supporting padding necessary for zero-shot, attempting to use "
                "`pad_token=eos_token`"
            )
            self.tokenizer.pad_token = self.tokenizer.eos_token
        try:
            inputs = self.tokenizer(
                sequence_pairs,
                add_special_tokens=add_special_tokens,
                return_tensors=return_tensors,
                padding=padding,
                truncation=truncation,
            )
        except Exception as exception:
            if "too short" in str(exception):
                # tokenizers might yell that we want to truncate
                # to a value that is not even reached by the input.
                # In that case we don't want to truncate.
                # It seems there's not a really better way to catch that
                # exception.

                inputs = self.tokenizer(
                    sequence_pairs,
                    add_special_tokens=add_special_tokens,
                    return_tensors=return_tensors,
                    padding=padding,
                    truncation=TruncationStrategy.DO_NOT_TRUNCATE,
                )
            else:
                raise exception

        return inputs

    def _sanitize_parameters(self, **kwargs):
        """
        Sanitizes the parameters for the ZeroShotClassificationPipeline.

        Args:
            self: An instance of the ZeroShotClassificationPipeline class.
            **kwargs: Keyword arguments passed to the pipeline call. The recognized keys are 'candidate_labels',
                'hypothesis_template', 'multi_label', and the deprecated 'multi_class'.

        Returns:
            tuple: A `(preprocess_params, forward_params, postprocess_params)` triple of dictionaries.
                `forward_params` is always empty for this pipeline.

        Raises:
            None.

        This method performs the following tasks:

        - Renames the deprecated 'multi_class' argument to 'multi_label' if provided and logs a warning.
        - Parses and sanitizes the 'candidate_labels' parameter if provided.
        - Retrieves the 'hypothesis_template' parameter if provided.
        - Collects the 'multi_label' parameter if provided.

        Note:
            - The 'multi_class' argument has been deprecated and renamed to 'multi_label'. 'multi_class' will be
            removed in a future version of Transformers.
            - The 'candidate_labels' parameter should be a list of strings representing the labels.
            - The 'hypothesis_template' parameter should be a string representing the template for the hypothesis.
            - The 'multi_label' parameter should be a boolean indicating whether multi-label classification should be used.

        Example:
            ```python
            >>> # `pipe` is assumed to be an already constructed ZeroShotClassificationPipeline
            >>> pipe._sanitize_parameters(multi_class=True, candidate_labels=['label1', 'label2'], hypothesis_template='This text is about {}.')
            ```
        """
        if kwargs.get("multi_class", None) is not None:
            kwargs["multi_label"] = kwargs["multi_class"]
            logger.warning(
                "The `multi_class` argument has been deprecated and renamed to `multi_label`. "
                "`multi_class` will be removed in a future version of Transformers."
            )
        preprocess_params = {}
        if "candidate_labels" in kwargs:
            preprocess_params["candidate_labels"] = self._args_parser._parse_labels(kwargs["candidate_labels"])
        if "hypothesis_template" in kwargs:
            preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]

        postprocess_params = {}
        if "multi_label" in kwargs:
            postprocess_params["multi_label"] = kwargs["multi_label"]
        return preprocess_params, {}, postprocess_params

    def __call__(
            self,
            sequences: Union[str, List[str]],
            *args,
            **kwargs,
    ):
        """
        Classify the sequence(s) given as inputs. See the [`ZeroShotClassificationPipeline`] documentation
        for more information.

        Args:
            sequences (`str` or `List[str]`):
                The sequence(s) to classify, will be truncated if the model input is too large.
            candidate_labels (`str` or `List[str]`):
                The set of possible class labels to classify each sequence into.
                Can be a single label, a string of
                comma-separated labels, or a list of labels.
            hypothesis_template (`str`, *optional*, defaults to `"This example is {}."`):
                The template used to turn each label into an NLI-style hypothesis.
                This template must include a {} or similar syntax for the candidate
                label to be inserted into the template. For example, the default
                template is `"This example is {}."` With the candidate label `"sports"`,
                this would be fed into the model like `"<cls> sequence to classify <sep> This example is sports . <sep>"`.
                The default template works well in many cases,
                but it may be worthwhile to experiment with different templates depending on the task setting.
            multi_label (`bool`, *optional*, defaults to `False`):
                Whether or not multiple candidate labels can be true. If `False`, the scores are normalized such that
                the sum of the label likelihoods for each sequence is 1. If `True`,
                the labels are considered independent and probabilities are normalized for each candidate
                by doing a softmax of the entailment score vs. the contradiction score.

        Returns:
            A `dict` or a list of `dict`:
                Each result comes as a dictionary with the following keys:

                - **sequence** (`str`) -- The sequence for which this is the output.
                - **labels** (`List[str]`) -- The labels sorted by order of likelihood.
                - **scores** (`List[float]`) -- The probabilities for each of the labels.
        """
        if len(args) == 0:
            pass
        elif len(args) == 1 and "candidate_labels" not in kwargs:
            kwargs["candidate_labels"] = args[0]
        else:
            raise ValueError(f"Unable to understand extra arguments {args}")

        return super().__call__(sequences, **kwargs)

    def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
        """
        This method preprocesses inputs for zero-shot classification and generates model inputs for each candidate label.

        Args:
            self: The instance of the ZeroShotClassificationPipeline class.
            inputs: The input sequences to be classified.
            candidate_labels: The list of candidate labels for classification. Defaults to None.
            hypothesis_template: The template string for the hypothesis. Defaults to 'This example is {}.'.

        Returns:
            Iterator[dict]: This method is a generator. It yields one dictionary of model inputs per candidate label.

        Raises:
            None.
        """
        sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
        for i, (candidate_label, sequence_pair) in enumerate(zip(candidate_labels, sequence_pairs)):
            model_input = self._parse_and_tokenize([sequence_pair])
            yield {
                "candidate_label": candidate_label,
                "sequence": sequences[0],
                "is_last": i == len(candidate_labels) - 1,
                **model_input,
            }

    def _forward(self, inputs):
        """
        Executes the forward pass for the ZeroShotClassificationPipeline.

        Args:
            self (ZeroShotClassificationPipeline): The instance of the ZeroShotClassificationPipeline class.
            inputs (dict): A dictionary containing the input data for the forward pass.
                - candidate_label (str): The candidate label for classification.
                - sequence (str): The sequence to classify.

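        Returns:
            dict: The model outputs (including 'logits') merged with the 'candidate_label', 'sequence', and
                'is_last' entries carried over from `inputs`.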

        Raises:
            None
        """
        candidate_label = inputs["candidate_label"]
        sequence = inputs["sequence"]
        model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}

        #`XXForSequenceClassification` models should not use `use_cache=True` even if it's supported
        model_forward = self.model.forward
        if "use_cache" in inspect.signature(model_forward).parameters.keys():
            model_inputs["use_cache"] = False
        outputs = model_forward(**model_inputs)
        model_outputs = {
            "candidate_label": candidate_label,
            "sequence": sequence,
            "is_last": inputs["is_last"],
            **outputs,
        }
        return model_outputs

    def postprocess(self, model_outputs, multi_label=False):
        """
        This method postprocesses the model outputs for a ZeroShotClassificationPipeline.

        Args:
            self (object): The instance of the ZeroShotClassificationPipeline class.
            model_outputs (list): A list of dictionaries containing the model outputs.
                Each dictionary must have the keys 'candidate_label', 'sequence', and 'logits'.
                The 'candidate_label' key represents the candidate label, 'sequence' key represents the sequence,
                and 'logits' key holds the logits values.
            multi_label (bool): A flag indicating whether the classification is multi-label or not.
                If set to True, the method processes the outputs accordingly.

        Returns:
            dict: A dictionary containing the processed information of the model outputs.
                It includes the 'sequence' key with the sequence value, 'labels' key with the list of candidate labels
                in descending order of their scores, and 'scores' key with the corresponding scores of the candidate
                labels.

        Raises:
            IndexError: If the indices accessed during processing are out of bounds.
            ValueError: If there are issues with the input data or calculations within the method.
        """
        candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
        sequences = [outputs["sequence"] for outputs in model_outputs]
        logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
        num_examples = logits.shape[0]
        num_candidates = len(candidate_labels)
        num_sequences = num_examples // num_candidates
        reshaped_outputs = logits.reshape((num_sequences, num_candidates, -1))

        if multi_label or len(candidate_labels) == 1:
            # softmax over the entailment vs. contradiction dim for each label independently
            entailment_id = self.entailment_id
            contradiction_id = -1 if entailment_id == 0 else 0
            entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
            scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
            scores = scores[..., 1]
        else:
            # softmax the "entailment" logits over all candidate labels
            entail_logits = reshaped_outputs[..., self.entailment_id]
            scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)

        top_inds = list(reversed(scores[0].argsort()))
        return {
            "sequence": sequences[0],
            "labels": [candidate_labels[i] for i in top_inds],
            "scores": scores[0, top_inds].tolist()}

mindnlp.transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline.entailment_id property

Returns the index of the 'entailment' label in the label-to-identifier mapping of the ZeroShotClassificationPipeline's model configuration.

PARAMETER DESCRIPTION
self

The current instance of the ZeroShotClassificationPipeline class.

TYPE: ZeroShotClassificationPipeline

RETURNS DESCRIPTION
int

The index of the 'entailment' label in the label-to-identifier mapping. If the 'entailment' label is not found, -1 is returned.
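The lookup can be sketched on a plain dictionary. `find_entailment_id` below is a hypothetical stand-alone function that mirrors the prefix match used by the property; the two `label2id` mappings are made-up examples, not read from any real model config.

```python
from typing import Dict


def find_entailment_id(label2id: Dict[str, int]) -> int:
    """Return the id of the first label starting with 'entail', or -1 if none does."""
    for label, idx in label2id.items():
        if label.lower().startswith("entail"):
            return idx
    return -1


# Illustrative mappings (assumptions, not taken from a specific checkpoint).
descriptive = {"contradiction": 0, "neutral": 1, "ENTAILMENT": 2}
generic = {"LABEL_0": 0, "LABEL_1": 1, "LABEL_2": 2}

print(find_entailment_id(descriptive))
print(find_entailment_id(generic))
```

A config with only generic names such as `LABEL_0` falls through to -1, which is the case the constructor warns about.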

mindnlp.transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline.__call__(sequences, *args, **kwargs)

Classify the sequence(s) given as inputs. See the ZeroShotClassificationPipeline documentation for more information.

PARAMETER DESCRIPTION
sequences

The sequence(s) to classify, will be truncated if the model input is too large.

TYPE: `str` or `List[str]`

candidate_labels

The set of possible class labels to classify each sequence into. Can be a single label, a string of comma-separated labels, or a list of labels.

TYPE: `str` or `List[str]`

hypothesis_template

The template used to turn each label into an NLI-style hypothesis. This template must include a {} or similar syntax for the candidate label to be inserted into the template. For example, the default template is "This example is {}." With the candidate label "sports", this would be fed into the model like "<cls> sequence to classify <sep> This example is sports . <sep>". The default template works well in many cases, but it may be worthwhile to experiment with different templates depending on the task setting.

TYPE: `str`, *optional*, defaults to `"This example is {}."`

multi_label

Whether or not multiple candidate labels can be true. If False, the scores are normalized such that the sum of the label likelihoods for each sequence is 1. If True, the labels are considered independent and probabilities are normalized for each candidate by doing a softmax of the entailment score vs. the contradiction score.

TYPE: `bool`, *optional*, defaults to `False`

RETURNS DESCRIPTION

A dict or a list of dict: Each result comes as a dictionary with the following keys:

  • sequence (str) -- The sequence for which this is the output.
  • labels (List[str]) -- The labels sorted by order of likelihood.
  • scores (List[float]) -- The probabilities for each of the labels.
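The difference between the two `multi_label` settings is easiest to see on numbers. The following is a minimal NumPy sketch with made-up logits; the column order `[contradiction, neutral, entailment]` and the helper `softmax` are assumptions for illustration, and the max-subtraction inside `softmax` is a common stability trick added here rather than something the pipeline source does.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax along the given axis."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Made-up logits: one row per candidate label,
# columns assumed to be [contradiction, neutral, entailment].
logits = np.array(
    [
        [-1.0, 0.0, 3.0],
        [-0.5, 0.2, 2.5],
        [2.0, 0.1, -1.5],
    ]
)
contradiction_id, entailment_id = 0, 2

# multi_label=False: softmax of the entailment logits across all labels.
single_label_scores = softmax(logits[:, entailment_id])

# multi_label=True: per label, softmax over (contradiction, entailment) only,
# keeping the entailment probability.
multi_label_scores = softmax(logits[:, [contradiction_id, entailment_id]])[:, 1]

print("single-label:", single_label_scores.round(3), "sum =", single_label_scores.sum().round(3))
print("multi-label: ", multi_label_scores.round(3), "sum =", multi_label_scores.sum().round(3))
```

In the single-label case the scores compete and sum to 1; in the multi-label case each score is an independent probability, so several labels can score close to 1 at once.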
Source code in mindnlp/transformers/pipelines/zero_shot_classification.py
def __call__(
        self,
        sequences: Union[str, List[str]],
        *args,
        **kwargs,
):
    """
    Classify the sequence(s) given as inputs. See the [`ZeroShotClassificationPipeline`] documentation
    for more information.

    Args:
        sequences (`str` or `List[str]`):
            The sequence(s) to classify, will be truncated if the model input is too large.
        candidate_labels (`str` or `List[str]`):
            The set of possible class labels to classify each sequence into.
            Can be a single label, a string of
            comma-separated labels, or a list of labels.
        hypothesis_template (`str`, *optional*, defaults to `"This example is {}."`):
            The template used to turn each label into an NLI-style hypothesis.
            This template must include a {} or similar syntax for the candidate
            label to be inserted into the template. For example, the default
            template is `"This example is {}."` With the candidate label `"sports"`,
            this would be fed into the model like `"<cls> sequence to classify <sep> This example is sports . <sep>"`.
            The default template works well in many cases,
            but it may be worthwhile to experiment with different templates depending on the task setting.
        multi_label (`bool`, *optional*, defaults to `False`):
            Whether or not multiple candidate labels can be true. If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`,
            the labels are considered independent and probabilities are normalized for each candidate
            by doing a softmax of the entailment score vs. the contradiction score.

    Returns:
        A `dict` or a list of `dict`:
            Each result comes as a dictionary with the following keys:

            - **sequence** (`str`) -- The sequence for which this is the output.
            - **labels** (`List[str]`) -- The labels sorted by order of likelihood.
            - **scores** (`List[float]`) -- The probabilities for each of the labels.
    """
    if len(args) == 0:
        pass
    elif len(args) == 1 and "candidate_labels" not in kwargs:
        kwargs["candidate_labels"] = args[0]
    else:
        raise ValueError(f"Unable to understand extra arguments {args}")

    return super().__call__(sequences, **kwargs)

mindnlp.transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline.__init__(*args, args_parser=ZeroShotClassificationArgumentHandler(), **kwargs)

Initializes a new instance of the ZeroShotClassificationPipeline class.

PARAMETER DESCRIPTION
self

The instance of the ZeroShotClassificationPipeline class.

*args

Variable length argument list.

DEFAULT: ()

args_parser

An instance of the ZeroShotClassificationArgumentHandler class that handles the arguments for zero-shot classification. Defaults to ZeroShotClassificationArgumentHandler().

DEFAULT: ZeroShotClassificationArgumentHandler()

**kwargs

Keyword arguments.

DEFAULT: {}

RETURNS DESCRIPTION

None.

Source code in mindnlp/transformers/pipelines/zero_shot_classification.py
def __init__(self, *args, args_parser=ZeroShotClassificationArgumentHandler(), **kwargs):
    """
    Initializes a new instance of the ZeroShotClassificationPipeline class.

    Args:
        self: The instance of the ZeroShotClassificationPipeline class.
        *args: Variable length argument list.
        args_parser:
            An instance of the ZeroShotClassificationArgumentHandler class that handles the arguments for
            zero-shot classification. Defaults to ZeroShotClassificationArgumentHandler().
        **kwargs: Keyword arguments.

    Returns:
        None.

    Raises:
        None.
    """
    self._args_parser = args_parser
    super().__init__(*args, **kwargs)
    if self.entailment_id == -1:
        logger.warning(
            "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
            "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs."
        )

mindnlp.transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline.postprocess(model_outputs, multi_label=False)

This method postprocesses the model outputs for a ZeroShotClassificationPipeline.

PARAMETER DESCRIPTION
self

The instance of the ZeroShotClassificationPipeline class.

TYPE: object

model_outputs

A list of dictionaries containing the model outputs. Each dictionary must have the keys 'candidate_label', 'sequence', and 'logits'. The 'candidate_label' key represents the candidate label, 'sequence' key represents the sequence, and 'logits' key holds the logits values.

TYPE: list

multi_label

A flag indicating whether the classification is multi-label or not. If set to True, the method processes the outputs accordingly.

TYPE: bool DEFAULT: False

RETURNS DESCRIPTION
dict

A dictionary containing the processed information of the model outputs. It includes the 'sequence' key with the sequence value, 'labels' key with the list of candidate labels in descending order of their scores, and 'scores' key with the corresponding scores of the candidate labels.

RAISES DESCRIPTION
IndexError

If the indices accessed during processing are out of bounds.

ValueError

If there are issues with the input data or calculations within the method.
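The shape handling and ranking in this method can be followed on a toy input. The sketch below assumes one sequence, three candidate labels, and an entailment id of 2; the per-label logits are invented, and plain NumPy arrays stand in for the tensors that the model would return.

```python
import numpy as np

candidate_labels = ["urgent", "phone", "tablet"]
entailment_id = 2  # assumed position of the entailment class

# One (1, num_classes) array per candidate label, as if collected from
# successive forward passes (values are made up).
per_label_logits = [
    np.array([[0.1, 0.2, 1.5]]),
    np.array([[0.3, 0.1, 2.0]]),
    np.array([[1.0, 0.5, -2.0]]),
]

logits = np.concatenate(per_label_logits)            # shape (3, 3)
num_sequences = logits.shape[0] // len(candidate_labels)
reshaped = logits.reshape((num_sequences, len(candidate_labels), -1))  # shape (1, 3, 3)

# Single-label case: softmax of the entailment logits over the candidates.
entail_logits = reshaped[..., entailment_id]         # shape (1, 3)
scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)

# argsort is ascending, so reverse it to rank labels from best to worst.
top_inds = list(reversed(scores[0].argsort()))
result = {
    "sequence": "example sequence",
    "labels": [candidate_labels[i] for i in top_inds],
    "scores": scores[0, top_inds].tolist(),
}
print(result["labels"])
```

With these invented logits the labels come back as `['phone', 'urgent', 'tablet']`, and the scores are listed in the same descending order.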

Source code in mindnlp/transformers/pipelines/zero_shot_classification.py
def postprocess(self, model_outputs, multi_label=False):
    """
    This method postprocesses the model outputs for a ZeroShotClassificationPipeline.

    Args:
        self (object): The instance of the ZeroShotClassificationPipeline class.
        model_outputs (list): A list of dictionaries containing the model outputs.
            Each dictionary must have the keys 'candidate_label', 'sequence', and 'logits'.
            The 'candidate_label' key represents the candidate label, 'sequence' key represents the sequence,
            and 'logits' key holds the logits values.
        multi_label (bool): A flag indicating whether the classification is multi-label or not.
            If set to True, the method processes the outputs accordingly.

    Returns:
        dict: A dictionary containing the processed information of the model outputs.
            It includes the 'sequence' key with the sequence value, 'labels' key with the list of candidate labels
            in descending order of their scores, and 'scores' key with the corresponding scores of the candidate
            labels.

    Raises:
        IndexError: If the indices accessed during processing are out of bounds.
        ValueError: If there are issues with the input data or calculations within the method.
    """
    candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
    sequences = [outputs["sequence"] for outputs in model_outputs]
    logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
    num_examples = logits.shape[0]
    num_candidates = len(candidate_labels)
    num_sequences = num_examples // num_candidates
    reshaped_outputs = logits.reshape((num_sequences, num_candidates, -1))

    if multi_label or len(candidate_labels) == 1:
        # softmax over the entailment vs. contradiction dim for each label independently
        entailment_id = self.entailment_id
        contradiction_id = -1 if entailment_id == 0 else 0
        entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
        scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
        scores = scores[..., 1]
    else:
        # softmax the "entailment" logits over all candidate labels
        entail_logits = reshaped_outputs[..., self.entailment_id]
        scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)

    top_inds = list(reversed(scores[0].argsort()))
    return {
        "sequence": sequences[0],
        "labels": [candidate_labels[i] for i in top_inds],
        "scores": scores[0, top_inds].tolist()}

mindnlp.transformers.pipelines.zero_shot_classification.ZeroShotClassificationPipeline.preprocess(inputs, candidate_labels=None, hypothesis_template='This example is {}.')

This method preprocesses inputs for zero-shot classification and generates model inputs for each candidate label.

PARAMETER DESCRIPTION
self

The instance of the ZeroShotClassificationPipeline class.

inputs

The input sequences to be classified.

candidate_labels

The list of candidate labels for classification. Defaults to None.

DEFAULT: None

hypothesis_template

The template string for the hypothesis. Defaults to 'This example is {}.'.

DEFAULT: 'This example is {}.'

RETURNS DESCRIPTION
Iterator[dict]

This method yields dictionaries with model inputs for each candidate label.
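The chunked, one-label-at-a-time behaviour can be sketched without a tokenizer. `iter_label_chunks` below is a hypothetical generator that mimics the shape of what is yielded; where the real method yields tokenized model inputs, this sketch simply carries the formatted hypothesis string.

```python
from typing import Dict, Iterator, List


def iter_label_chunks(
    sequence: str,
    labels: List[str],
    template: str = "This example is {}.",
) -> Iterator[Dict[str, object]]:
    """Yield one chunk per candidate label, flagging the final chunk.

    Hypothetical stand-in: the 'hypothesis' entry replaces the tokenized
    tensors that the real pipeline would include.
    """
    for i, label in enumerate(labels):
        yield {
            "candidate_label": label,
            "sequence": sequence,
            "hypothesis": template.format(label),
            "is_last": i == len(labels) - 1,
        }


chunks = list(iter_label_chunks("Battery drains fast", ["hardware", "billing"]))
for chunk in chunks:
    print(chunk)
```

The `is_last` flag is what lets a chunk-based pipeline know when all labels for a sequence have been processed and the outputs can be gathered for postprocessing.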

Source code in mindnlp/transformers/pipelines/zero_shot_classification.py
def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
    """
    This method preprocesses inputs for zero-shot classification and generates model inputs for each candidate label.

    Args:
        self: The instance of the ZeroShotClassificationPipeline class.
        inputs: The input sequences to be classified.
        candidate_labels: The list of candidate labels for classification. Defaults to None.
        hypothesis_template: The template string for the hypothesis. Defaults to 'This example is {}.'.

    Returns:
        Iterator[dict]: This method is a generator. It yields one dictionary of model inputs per candidate label.

    Raises:
        None.
    """
    sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
    for i, (candidate_label, sequence_pair) in enumerate(zip(candidate_labels, sequence_pairs)):
        model_input = self._parse_and_tokenize([sequence_pair])
        yield {
            "candidate_label": candidate_label,
            "sequence": sequences[0],
            "is_last": i == len(candidate_labels) - 1,
            **model_input,
        }