
switch_transformers

mindnlp.transformers.models.switch_transformers.configuration_switch_transformers

Switch Transformers model configuration

mindnlp.transformers.models.switch_transformers.configuration_switch_transformers.SwitchTransformersConfig

Bases: PretrainedConfig

This is the configuration class to store the configuration of a [SwitchTransformersModel]. It is used to instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the SwitchTransformers google/switch-base-8 architecture.

Configuration objects inherit from [PretrainedConfig] and can be used to control the model outputs. Read the documentation from [PretrainedConfig] for more information.

PARAMETER DESCRIPTION
vocab_size

Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be represented by the input_ids passed when calling [SwitchTransformersModel].

TYPE: `int`, *optional*, defaults to 32128 DEFAULT: 32128

d_model

Size of the encoder layers and the pooler layer.

TYPE: `int`, *optional*, defaults to 768 DEFAULT: 768

d_kv

Size of the key, query, value projections per attention head. d_kv has to be equal to d_model // num_heads.

TYPE: `int`, *optional*, defaults to 64 DEFAULT: 64

d_ff

Size of the intermediate feed forward layer in each SwitchTransformersBlock.

TYPE: `int`, *optional*, defaults to 2048 DEFAULT: 2048

expert_capacity

Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular Transformer.

TYPE: `int`, *optional*, defaults to 64 DEFAULT: 64

num_layers

Number of dense hidden layers in the Transformer encoder.

TYPE: `int`, *optional*, defaults to 12 DEFAULT: 12

num_sparse_encoder_layers

Number of sparse (MoE) layers in the Transformer encoder.

TYPE: `int`, *optional*, defaults to 3 DEFAULT: 3

num_decoder_layers

Number of hidden layers in the Transformer decoder. Will use the same value as num_layers if not set.

TYPE: `int`, *optional*, defaults to 12 DEFAULT: 12

num_sparse_decoder_layers

Number of sparse (MoE) layers in the Transformer decoder.

TYPE: `int`, *optional*, defaults to 3 DEFAULT: 3

num_heads

Number of attention heads for each attention layer in the Transformer encoder.

TYPE: `int`, *optional*, defaults to 12 DEFAULT: 12

num_experts

Number of experts for each SwitchTransformer layer.

TYPE: `int`, *optional*, defaults to 8 DEFAULT: 8

router_bias

Whether to add a bias to the router.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

router_jitter_noise

Amount of noise to add to the router.

TYPE: `float`, *optional*, defaults to 0.01 DEFAULT: 0.01

router_dtype

The dtype used for the routers. It is preferable to keep the dtype to "float32" as specified in the selective precision discussion in the paper.

TYPE: `str`, *optional*, defaults to `"float32"` DEFAULT: 'float32'

router_ignore_padding_tokens

Whether to ignore padding tokens when routing.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

relative_attention_num_buckets

The number of buckets to use for each attention layer.

TYPE: `int`, *optional*, defaults to 32 DEFAULT: 32

relative_attention_max_distance

The maximum distance of the longer sequences for the bucket separation.

TYPE: `int`, *optional*, defaults to 128 DEFAULT: 128

dropout_rate

The ratio for all dropout layers.

TYPE: `float`, *optional*, defaults to 0.1 DEFAULT: 0.1

layer_norm_epsilon

The epsilon used by the layer normalization layers.

TYPE: `float`, *optional*, defaults to 1e-06 DEFAULT: 1e-06

router_z_loss_coef

The z loss factor for the total loss.

TYPE: `float`, *optional*, defaults to 0.001 DEFAULT: 0.001

router_aux_loss_coef

The aux loss factor for the total loss.

TYPE: `float`, *optional*, defaults to 0.001 DEFAULT: 0.001

initializer_factor

A factor for initializing all weight matrices (should be kept to 1, used internally for initialization testing).

TYPE: `float`, *optional*, defaults to 1.0 DEFAULT: 1.0

dense_act_fn

Type of feed forward layer to be used. Should be one of "relu" or "gated-gelu". SwitchTransformersv1.1 uses the "gated-gelu" feed forward projection. Original SwitchTransformers uses "relu".

TYPE: `string`, *optional*, defaults to `"relu"` DEFAULT: 'relu'

add_router_probs

Whether to output router probabilities to compute router auxiliary loss.

TYPE: `bool`, *optional*, defaults to `False` DEFAULT: False

use_cache

Whether or not the model should return the last key/values attentions (not used by all models).

TYPE: `bool`, *optional*, defaults to `True` DEFAULT: True
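
A minimal usage sketch of this configuration class (assuming `SwitchTransformersConfig` and `SwitchTransformersModel` are exposed under `mindnlp.transformers`, mirroring the Hugging Face-style API documented here):

```python
from mindnlp.transformers import SwitchTransformersConfig, SwitchTransformersModel

# Build a configuration matching the google/switch-base-8 defaults
configuration = SwitchTransformersConfig()

# Optionally override architecture hyperparameters
small_configuration = SwitchTransformersConfig(num_layers=6, num_sparse_encoder_layers=2, num_experts=4)

# Instantiate a randomly initialized model from the configuration
model = SwitchTransformersModel(configuration)

# The configuration is stored on the model and can be read back
configuration = model.config
```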

Source code in mindnlp/transformers/models/switch_transformers/configuration_switch_transformers.py
class SwitchTransformersConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`SwitchTransformersModel`]. It is used to
    instantiate a SwitchTransformers model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the
    SwitchTransformers [google/switch-base-8](https://huggingface.co/google/switch-base-8) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Arguments:
        vocab_size (`int`, *optional*, defaults to 32128):
            Vocabulary size of the SwitchTransformers model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`SwitchTransformersModel`].
        d_model (`int`, *optional*, defaults to 768):
            Size of the encoder layers and the pooler layer.
        d_kv (`int`, *optional*, defaults to 64):
            Size of the key, query, value projections per attention head. `d_kv` has to be equal to `d_model //
            num_heads`.
        d_ff (`int`, *optional*, defaults to 2048):
            Size of the intermediate feed forward layer in each `SwitchTransformersBlock`.
        expert_capacity (`int`, *optional*, defaults to 64):
            Number of tokens that can be stored in each expert. If set to 1, the model will behave like a regular
            Transformer.
        num_layers (`int`, *optional*, defaults to 12):
            Number of dense hidden layers in the Transformer encoder layer.
        num_sparse_encoder_layers (`int`, *optional*, defaults to 3):
            Number of sparse (MoE) dense hidden layers in the Transformer encoder layer.
        num_decoder_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
        num_sparse_decoder_layers (`int`, *optional*, defaults to 3):
            Number of sparse (MoE) dense hidden layers in the Transformer decoder layer.
        num_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_experts (`int`, *optional*, defaults to 8):
            Number of experts for each SwitchTransformer layer.
        router_bias (`bool`, *optional*, defaults to `False`):
            Whether to add a bias to the router.
        router_jitter_noise (`float`, *optional*, defaults to 0.01):
            Amount of noise to add to the router.
        router_dtype (`str`, *optional*, default to `"float32"`):
            The `dtype` used for the routers. It is preferable to keep the `dtype` to `"float32"` as specified in the
            *selective precision* discussion in [the paper](https://arxiv.org/abs/2101.03961).
        router_ignore_padding_tokens (`bool`, *optional*, defaults to `False`):
            Whether to ignore padding tokens when routing.
        relative_attention_num_buckets (`int`, *optional*, defaults to 32):
            The number of buckets to use for each attention layer.
        relative_attention_max_distance (`int`, *optional*, defaults to 128):
            The maximum distance of the longer sequences for the bucket separation.
        dropout_rate (`float`, *optional*, defaults to 0.1):
            The ratio for all dropout layers.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            The epsilon used by the layer normalization layers.
        router_z_loss_coef (`float`, *optional*, defaults to 0.001):
            The z loss factor for the total loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        dense_act_fn (`string`, *optional*, defaults to `"relu"`):
            Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. SwitchTransformersv1.1
            uses the `"gated-gelu"` feed forward projection. Original SwitchTransformers uses `"relu"`.
        add_router_probs (`bool`, *optional*, defaults to `False`):
            Whether to output router probabilities to compute router auxiliary loss.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
    """

    model_type = "switch_transformers"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}

    def __init__(
        self,
        vocab_size=32128,
        d_model=768,
        d_kv=64,
        d_ff=2048,
        expert_capacity=64,
        num_layers=12,
        num_sparse_encoder_layers=3,
        num_decoder_layers=12,
        num_sparse_decoder_layers=3,
        num_heads=12,
        num_experts=8,
        router_bias=False,
        router_jitter_noise=0.01,
        router_dtype="float32",
        router_ignore_padding_tokens=False,
        relative_attention_num_buckets=32,
        relative_attention_max_distance=128,
        dropout_rate=0.1,
        layer_norm_epsilon=1e-6,
        router_z_loss_coef=0.001,
        router_aux_loss_coef=0.001,
        initializer_factor=1.0,
        dense_act_fn="relu",
        is_encoder_decoder=True,
        add_router_probs=False,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_kv = d_kv
        self.d_ff = d_ff

        self.num_sparse_encoder_layers = num_sparse_encoder_layers

        self.num_layers = num_layers
        self.num_decoder_layers = (
            num_decoder_layers if num_decoder_layers is not None else self.num_layers
        )  # default = symmetry
        self.num_sparse_decoder_layers = num_sparse_decoder_layers

        # Determines the spacing of sparse (MoE) layers: one sparse layer every `encoder_sparse_step` encoder layers.
        if self.num_sparse_encoder_layers > 0:
            self.encoder_sparse_step = self.num_layers // self.num_sparse_encoder_layers
        else:
            self.encoder_sparse_step = self.num_layers  # HACK: this will create 0 sparse layers

        # Determines the spacing of sparse (MoE) layers: one sparse layer every `decoder_sparse_step` decoder layers.
        if self.num_sparse_decoder_layers > 0:
            self.decoder_sparse_step = self.num_decoder_layers // self.num_sparse_decoder_layers
        else:
            self.decoder_sparse_step = self.num_decoder_layers  # HACK: this will create 0 sparse layers

        self.num_heads = num_heads
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.router_bias = router_bias
        self.router_jitter_noise = router_jitter_noise
        if router_dtype not in ["float32", "float16", "bfloat16"]:
            raise ValueError(f"`router_dtype` must be one of 'float32', 'float16' or 'bfloat16', got {router_dtype}")
        self.router_dtype = router_dtype

        self.router_ignore_padding_tokens = router_ignore_padding_tokens
        self.relative_attention_num_buckets = relative_attention_num_buckets
        self.relative_attention_max_distance = relative_attention_max_distance

        self.dropout_rate = dropout_rate
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_factor = initializer_factor
        self.use_cache = use_cache
        self.add_router_probs = add_router_probs

        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef
        self.dense_act_fn = dense_act_fn

        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
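
The `encoder_sparse_step` / `decoder_sparse_step` arithmetic above determines how far apart sparse (MoE) layers are placed. A plain-Python illustration with the default values (a sketch for intuition, not mindnlp code):

```python
# Spacing of sparse (MoE) encoder layers implied by the defaults above
num_layers = 12                # total encoder layers
num_sparse_encoder_layers = 3  # how many of them are sparse

if num_sparse_encoder_layers > 0:
    encoder_sparse_step = num_layers // num_sparse_encoder_layers
else:
    encoder_sparse_step = num_layers  # produces zero sparse layers

print(encoder_sparse_step)  # 4 -> roughly one in every four encoder blocks is a sparse (MoE) block
```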

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers

MindSpore SwitchTransformers model.

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersAttention

Bases: Module

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersAttention(nn.Module):
    def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
        )
        # Prune linear layers
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
        This should allow for more graceful generalization to longer sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
        """
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).astype(mindspore.int64) * num_buckets
            relative_position = ops.abs(relative_position)
        else:
            relative_position = -ops.minimum(relative_position, ops.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
        relative_position_if_large = max_exact + (
            ops.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).astype(mindspore.int64)
        relative_position_if_large = ops.minimum(
            relative_position_if_large, ops.full_like(relative_position_if_large, num_buckets - 1)
        )

        relative_buckets += ops.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias"""
        context_position = ops.arange(query_length, dtype=mindspore.int64)[:, None]
        memory_position = ops.arange(key_length, dtype=mindspore.int64)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            if len(past_key_value) != 2:
                raise ValueError(
                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
                )
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)

        def unshape(states):
            """reshape"""
            return states.swapaxes(1, 2).view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = ops.cat([past_key_value, hidden_states], axis=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )

        # compute scores
        scores = ops.matmul(
            query_states, key_states.swapaxes(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = ops.zeros(
                    (1, self.n_heads, real_seq_length, key_length), dtype=scores.dtype
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)

            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.shape[1] :, :]

            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

        if self.pruned_heads:
            mask = ops.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked
        attn_weights = ops.softmax(scores.float(), axis=-1).astype(
            scores.dtype
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = ops.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = unshape(ops.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
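
For intuition about the relative-position bucketing performed by `_relative_position_bucket` above, here is a NumPy transcription (an illustrative sketch, not part of mindnlp; the `clip` only guards the otherwise unused `log(0)` branch):

```python
import numpy as np

def relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
    """NumPy transcription of SwitchTransformersAttention._relative_position_bucket."""
    relative_buckets = np.zeros_like(relative_position)
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position > 0).astype(np.int64) * num_buckets
        relative_position = np.abs(relative_position)
    else:
        relative_position = -np.minimum(relative_position, 0)
    # First half of the buckets covers exact offsets in [0, max_exact)
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # Second half covers logarithmically sized bins up to max_distance
    relative_position_if_large = max_exact + (
        np.log(relative_position.clip(min=1) / max_exact)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).astype(np.int64)
    relative_position_if_large = np.minimum(relative_position_if_large, num_buckets - 1)
    return relative_buckets + np.where(is_small, relative_position, relative_position_if_large)

# Relative positions for an 8-token bidirectional (encoder) sequence
positions = np.arange(8)
buckets = relative_position_bucket(positions[None, :] - positions[:, None])
print(buckets.shape)  # (8, 8) integer bucket ids in [0, num_buckets)
```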

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersAttention.compute_bias(query_length, key_length)

Compute binned relative position bias

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def compute_bias(self, query_length, key_length):
    """Compute binned relative position bias"""
    context_position = ops.arange(query_length, dtype=mindspore.int64)[:, None]
    memory_position = ops.arange(key_length, dtype=mindspore.int64)[None, :]
    relative_position = memory_position - context_position  # shape (query_length, key_length)
    relative_position_bucket = self._relative_position_bucket(
        relative_position,  # shape (query_length, key_length)
        bidirectional=(not self.is_decoder),
        num_buckets=self.relative_attention_num_buckets,
        max_distance=self.relative_attention_max_distance,
    )
    values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
    values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
    return values

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersAttention.forward(hidden_states, mask=None, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False)

Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def forward(
    self,
    hidden_states,
    mask=None,
    key_value_states=None,
    position_bias=None,
    past_key_value=None,
    layer_head_mask=None,
    query_length=None,
    use_cache=False,
    output_attentions=False,
):
    """
    Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
    """
    # Input is (batch_size, seq_length, dim)
    # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
    # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
    batch_size, seq_length = hidden_states.shape[:2]

    real_seq_length = seq_length

    if past_key_value is not None:
        if len(past_key_value) != 2:
            raise ValueError(
                f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
            )
        real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

    key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

    def shape(states):
        """projection"""
        return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)

    def unshape(states):
        """reshape"""
        return states.swapaxes(1, 2).view(batch_size, -1, self.inner_dim)

    def project(hidden_states, proj_layer, key_value_states, past_key_value):
        """projects hidden states correctly to key/query states"""
        if key_value_states is None:
            # self-attn
            # (batch_size, n_heads, seq_length, dim_per_head)
            hidden_states = shape(proj_layer(hidden_states))
        elif past_key_value is None:
            # cross-attn
            # (batch_size, n_heads, seq_length, dim_per_head)
            hidden_states = shape(proj_layer(key_value_states))

        if past_key_value is not None:
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, key_length, dim_per_head)
                hidden_states = ops.cat([past_key_value, hidden_states], axis=2)
            elif past_key_value.shape[2] != key_value_states.shape[1]:
                # checking that the `sequence_length` of the `past_key_value` is the same as
                # the provided `key_value_states` to support prefix tuning
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))
            else:
                # cross-attn
                hidden_states = past_key_value
        return hidden_states

    # get query states
    query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)

    # get key/value states
    key_states = project(
        hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
    )
    value_states = project(
        hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
    )

    # compute scores
    scores = ops.matmul(
        query_states, key_states.swapaxes(3, 2)
    )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

    if position_bias is None:
        if not self.has_relative_attention_bias:
            position_bias = ops.zeros(
                (1, self.n_heads, real_seq_length, key_length), dtype=scores.dtype
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True
        else:
            position_bias = self.compute_bias(real_seq_length, key_length)

        # if key and values are already calculated
        # we want only the last query position bias
        if past_key_value is not None:
            position_bias = position_bias[:, :, -hidden_states.shape[1] :, :]

        if mask is not None:
            position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)

    if self.pruned_heads:
        mask = ops.ones(position_bias.shape[1])
        mask[list(self.pruned_heads)] = 0
        position_bias_masked = position_bias[:, mask.bool()]
    else:
        position_bias_masked = position_bias

    scores += position_bias_masked
    attn_weights = ops.softmax(scores.float(), axis=-1).astype(
        scores.dtype
    )  # (batch_size, n_heads, seq_length, key_length)
    attn_weights = ops.dropout(
        attn_weights, p=self.dropout, training=self.training
    )  # (batch_size, n_heads, seq_length, key_length)

    # Mask heads if we want to
    if layer_head_mask is not None:
        attn_weights = attn_weights * layer_head_mask

    attn_output = unshape(ops.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
    attn_output = self.o(attn_output)

    present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
    outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

    if output_attentions:
        outputs = outputs + (attn_weights,)
    return outputs

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersEncoderModel

Bases: SwitchTransformersPreTrainedModel

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersEncoderModel(SwitchTransformersPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight"]

    def __init__(self, config: SwitchTransformersConfig):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = SwitchTransformersStack(encoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.device_map = None

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[mindspore.Tensor] = None,
        attention_mask: Optional[mindspore.Tensor] = None,
        head_mask: Optional[mindspore.Tensor] = None,
        inputs_embeds: Optional[mindspore.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[mindspore.Tensor], MoEModelOutput]:
        r"""

        Returns:
            `Union[Tuple[mindspore.Tensor], MoEModelOutput]`

        Example:
            ```python
            >>> from transformers import AutoTokenizer, SwitchTransformersEncoderModel
            ...
            >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
            >>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8")
            >>> input_ids = tokenizer(
            ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
            ... ).input_ids  # Batch size 1
            >>> outputs = model(input_ids=input_ids)
            >>> last_hidden_states = outputs.last_hidden_state
            ```
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )

        return encoder_outputs
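
The docstring example above is inherited from the upstream PyTorch implementation (`from transformers import ...`, `return_tensors="pt"`). An equivalent sketch for a MindSpore setting might look like the following, assuming these classes are exposed under `mindnlp.transformers` and the tokenizer accepts `return_tensors="ms"`:

```python
from mindnlp.transformers import AutoTokenizer, SwitchTransformersEncoderModel

tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8")

inputs = tokenizer(
    "Studies have been shown that owning a dog is good for you", return_tensors="ms"
)
outputs = model(input_ids=inputs.input_ids)
last_hidden_states = outputs.last_hidden_state  # (batch_size, seq_length, d_model)
```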

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersEncoderModel.forward(input_ids=None, attention_mask=None, head_mask=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, output_router_logits=True, return_dict=None)

RETURNS DESCRIPTION
Union[Tuple[Tensor], MoEModelOutput]

Union[Tuple[mindspore.Tensor], MoEModelOutput]

Example
>>> from transformers import AutoTokenizer, SwitchTransformersEncoderModel
...
>>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
>>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8")
>>> input_ids = tokenizer(
...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids  # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> last_hidden_states = outputs.last_hidden_state
Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def forward(
    self,
    input_ids: Optional[mindspore.Tensor] = None,
    attention_mask: Optional[mindspore.Tensor] = None,
    head_mask: Optional[mindspore.Tensor] = None,
    inputs_embeds: Optional[mindspore.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = True,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[mindspore.Tensor], MoEModelOutput]:
    r"""

    Returns:
        `Union[Tuple[mindspore.Tensor], MoEModelOutput]`

    Example:
        ```python
        >>> from transformers import AutoTokenizer, SwitchTransformersEncoderModel
        ...
        >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
        >>> model = SwitchTransformersEncoderModel.from_pretrained("google/switch-base-8")
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model(input_ids=input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    encoder_outputs = self.encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        head_mask=head_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        return_dict=return_dict,
    )

    return encoder_outputs

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersForConditionalGeneration

Bases: SwitchTransformersPreTrainedModel

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: SwitchTransformersConfig):
        super().__init__(config)
        self.model_dim = config.d_model

        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = SwitchTransformersStack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = config.num_decoder_layers
        self.decoder = SwitchTransformersStack(decoder_config, self.shared)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.router_z_loss_coef = config.router_z_loss_coef
        self.router_aux_loss_coef = config.router_aux_loss_coef

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.device_map = None

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_ids: Optional[mindspore.Tensor] = None,
        attention_mask: Optional[mindspore.Tensor] = None,
        decoder_input_ids: Optional[mindspore.Tensor] = None,
        decoder_attention_mask: Optional[mindspore.Tensor] = None,
        head_mask: Optional[mindspore.Tensor] = None,
        decoder_head_mask: Optional[mindspore.Tensor] = None,
        cross_attn_head_mask: Optional[mindspore.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
        inputs_embeds: Optional[mindspore.Tensor] = None,
        decoder_inputs_embeds: Optional[mindspore.Tensor] = None,
        labels: Optional[mindspore.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = True,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[mindspore.Tensor], Seq2SeqMoEOutput]:
        r"""
        Args:
            labels (`mindspore.Tensor` of shape `(batch_size,)`, *optional*):
                Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
                config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
                labels in `[0, ..., config.vocab_size]`

        Returns:
            `Union[Tuple[mindspore.Tensor], Seq2SeqMoEOutput]`

        Example:
            ```python
            >>> from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
            ...
            >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
            >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
            ...
            >>> # training
            >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
            >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits
            ...
            >>> # inference
            >>> input_ids = tokenizer(
            ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
            ... ).input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
            >>> # . To, let’s say you have a dog. To summarize:
            >>> # Since the model has been trained on MLM, this will output gibberish
            ```
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_router_logits=output_router_logits,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
            encoder_outputs = MoEModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
                router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        encoder_z_loss = None
        encoder_aux_loss = None
        decoder_z_loss = None
        decoder_aux_loss = None

        if output_router_logits:
            # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
            if self.encoder.config.encoder_sparse_step > 1:
                encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_outputs[-1])
                encoder_z_loss = router_z_loss_func(encoder_router_logits)
                encoder_router_probs = nn.Softmax(axis=-1)(encoder_router_logits)
                encoder_aux_loss = load_balancing_loss_func(encoder_router_probs, encoder_expert_indexes)
            else:
                encoder_z_loss = 0
                encoder_aux_loss = 0

            if self.decoder.config.decoder_sparse_step > 1:
                decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_outputs[-1])
                decoder_z_loss = router_z_loss_func(decoder_router_logits)
                decoder_router_probs = nn.Softmax(axis=-1)(decoder_router_logits)
                decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes)
            else:
                decoder_z_loss = 0
                decoder_aux_loss = 0

        if labels is not None:
            loss = ops.cross_entropy(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))

            if output_router_logits:
                z_loss = self.router_z_loss_coef * (encoder_z_loss + decoder_z_loss)
                aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
                loss = loss + z_loss + aux_loss

        if not return_dict:
            output = (lm_logits,)
            if output_router_logits:
                output += (encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss)
            output += (*decoder_outputs[1:], *encoder_outputs)

            return ((loss,) + output) if loss is not None else output

        return Seq2SeqMoEOutput(
            loss=loss,
            logits=lm_logits,
            encoder_z_loss=encoder_z_loss,
            encoder_aux_loss=encoder_aux_loss,
            decoder_z_loss=decoder_z_loss,
            decoder_aux_loss=decoder_aux_loss,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            decoder_router_logits=decoder_outputs.router_probs,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            encoder_router_logits=encoder_outputs.router_probs,
        )

    def _unpack_router_logits(self, router_outputs):
        total_router_logits = []
        total_expert_indexes = []
        for router_output in router_outputs:
            if len(router_output[0].shape) > 1:
                router_logits, expert_indexes = router_output
                total_router_logits.append(router_logits)
                total_expert_indexes.append(expert_indexes)
        return ops.cat(total_router_logits, axis=1), ops.cat(total_expert_indexes, axis=1)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        # cut decoder_input_ids if past_key_values is used
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]

            # Some generation methods already pass only the last input ID
            if input_ids.shape[1] > past_length:
                remove_prefix_length = past_length
            else:
                # Default to old behavior: keep only final ID
                remove_prefix_length = input_ids.shape[1] - 1

            input_ids = input_ids[:, remove_prefix_length:]

        output_router_logits = kwargs.get("output_router_logits", True)

        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past_key_values,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
            "output_router_logits": output_router_logits,
        }

    def prepare_decoder_input_ids_from_labels(self, labels: mindspore.Tensor):
        return self._shift_right(labels)

    def _reorder_cache(self, past_key_values, beam_idx):
        # if decoder past is not included in output
        # speedy decoding is disabled and no need to reorder
        if past_key_values is None:
            logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
            return past_key_values

        reordered_decoder_past = ()
        for layer_past_states in past_key_values:
            # get the correct batch idx from layer past batch dim
            # batch dim of `past` is at 2nd position
            reordered_layer_past_states = ()
            for layer_past_state in layer_past_states:
                # need to set correct `past` for each of the four key / value states
                reordered_layer_past_states = reordered_layer_past_states + (
                    layer_past_state.index_select(0, beam_idx),
                )

            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
                raise ValueError(
                    "expected reordered_layer_past_states to have the same shape than layer_past_states, "
                    f"but got {reordered_layer_past_states[0].shape} and {layer_past_states[0].shape}"
                )
            if len(reordered_layer_past_states) != len(layer_past_states):
                raise ValueError(
                    "expected layer_past_states to have the same length as reordered_layer_past_states, "
                    f"but got {len(layer_past_states)} and {len(reordered_layer_past_states)}"
                )

            reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
        return reordered_decoder_past
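
For reference, the training objective assembled at the end of `forward` above combines the cross-entropy loss with the coefficient-weighted router losses. A small helper restating that arithmetic (an illustrative sketch, not an additional mindnlp API):

```python
def total_switch_transformers_loss(
    ce_loss,
    encoder_z_loss, decoder_z_loss,
    encoder_aux_loss, decoder_aux_loss,
    router_z_loss_coef=0.001,
    router_aux_loss_coef=0.001,
):
    """Cross-entropy plus weighted router z-loss and load-balancing (aux) loss."""
    z_loss = router_z_loss_coef * (encoder_z_loss + decoder_z_loss)
    aux_loss = router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
    return ce_loss + z_loss + aux_loss

# Example with plain floats (the real code operates on mindspore tensors)
print(total_switch_transformers_loss(2.31, 0.8, 0.7, 1.1, 0.9))  # 2.3135
```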

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersForConditionalGeneration.forward(input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, decoder_inputs_embeds=None, labels=None, use_cache=None, output_attentions=None, output_hidden_states=None, output_router_logits=True, return_dict=None)

PARAMETER DESCRIPTION
labels

Labels for computing the sequence classification/regression loss. Indices should be in [-100, 0, ..., config.vocab_size - 1]. All labels set to -100 are ignored (masked), the loss is only computed for labels in [0, ..., config.vocab_size]

TYPE: `mindspore.Tensor` of shape `(batch_size,)`, *optional* DEFAULT: None

RETURNS DESCRIPTION
Union[Tuple[Tensor], Seq2SeqMoEOutput]

Union[Tuple[mindspore.Tensor], Seq2SeqMoEOutput]

Example
>>> from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
...
>>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
>>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
...
>>> # training
>>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
>>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
...
>>> # inference
>>> input_ids = tokenizer(
...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids  # Batch size 1
>>> outputs = model.generate(input_ids)
>>> # . To, let’s say you have a dog. To summarize:
>>> # Since the model has been trained on MLM, this will output gibberish
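
The snippet above keeps the upstream `transformers` imports and PyTorch tensors. Below is a minimal mindnlp-flavoured variant of the same inference call, as a hedged sketch: it assumes the classes are re-exported from `mindnlp.transformers` and that the tokenizer accepts `return_tensors="ms"` for MindSpore tensors.

```python
>>> from mindnlp.transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
...
>>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
>>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
...
>>> input_ids = tokenizer(
...     "summarize: studies have shown that owning a dog is good for you", return_tensors="ms"
... ).input_ids  # MindSpore tensors, batch size 1
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```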
Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def forward(
    self,
    input_ids: Optional[mindspore.Tensor] = None,
    attention_mask: Optional[mindspore.Tensor] = None,
    decoder_input_ids: Optional[mindspore.Tensor] = None,
    decoder_attention_mask: Optional[mindspore.Tensor] = None,
    head_mask: Optional[mindspore.Tensor] = None,
    decoder_head_mask: Optional[mindspore.Tensor] = None,
    cross_attn_head_mask: Optional[mindspore.Tensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
    inputs_embeds: Optional[mindspore.Tensor] = None,
    decoder_inputs_embeds: Optional[mindspore.Tensor] = None,
    labels: Optional[mindspore.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = True,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[mindspore.Tensor], Seq2SeqMoEOutput]:
    r"""
    Args:
        labels (`mindspore.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size]`

    Returns:
        `Union[Tuple[mindspore.Tensor], Seq2SeqMoEOutput]`

    Example:
        ```python
        >>> from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration
        ...
        >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
        >>> model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8")
        ...
        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits
        ...
        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> # . To, let’s say you have a dog. To summarize:
        >>> # Since the model has been trained on MLM, this will output gibberish
        ```
    """
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
    if head_mask is not None and decoder_head_mask is None:
        if self.config.num_layers == self.config.num_decoder_layers:
            warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
            decoder_head_mask = head_mask

    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        # Convert encoder inputs in embeddings if needed
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
        encoder_outputs = MoEModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
        )

    hidden_states = encoder_outputs[0]

    if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
        # get decoder inputs from shifting lm labels to the right
        decoder_input_ids = self._shift_right(labels)

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=attention_mask,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        return_dict=return_dict,
    )

    sequence_output = decoder_outputs[0]

    if self.config.tie_word_embeddings:
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim**-0.5)

    lm_logits = self.lm_head(sequence_output)

    loss = None
    encoder_z_loss = None
    encoder_aux_loss = None
    decoder_z_loss = None
    decoder_aux_loss = None

    if output_router_logits:
        # Compute the router loss (z_loss + auxiliary loss) for each router in the encoder and decoder
        if self.encoder.config.encoder_sparse_step > 1:
            encoder_router_logits, encoder_expert_indexes = self._unpack_router_logits(encoder_outputs[-1])
            encoder_z_loss = router_z_loss_func(encoder_router_logits)
            encoder_router_probs = nn.Softmax(axis=-1)(encoder_router_logits)
            encoder_aux_loss = load_balancing_loss_func(encoder_router_probs, encoder_expert_indexes)
        else:
            encoder_z_loss = 0
            encoder_aux_loss = 0

        if self.decoder.config.decoder_sparse_step > 1:
            decoder_router_logits, decoder_expert_indexes = self._unpack_router_logits(decoder_outputs[-1])
            decoder_z_loss = router_z_loss_func(decoder_router_logits)
            decoder_router_probs = nn.Softmax(axis=-1)(decoder_router_logits)
            decoder_aux_loss = load_balancing_loss_func(decoder_router_probs, decoder_expert_indexes)
        else:
            decoder_z_loss = 0
            decoder_aux_loss = 0

    if labels is not None:
        loss = ops.cross_entropy(lm_logits.view(-1, lm_logits.shape[-1]), labels.view(-1))

        if output_router_logits:
            z_loss = self.router_z_loss_coef * (encoder_z_loss + decoder_z_loss)
            aux_loss = self.router_aux_loss_coef * (encoder_aux_loss + decoder_aux_loss)
            loss = loss + z_loss + aux_loss

    if not return_dict:
        output = (lm_logits,)
        if output_router_logits:
            output += (encoder_z_loss, encoder_aux_loss, decoder_z_loss, decoder_aux_loss)
        output += (*decoder_outputs[1:], *encoder_outputs)

        return ((loss,) + output) if loss is not None else output

    return Seq2SeqMoEOutput(
        loss=loss,
        logits=lm_logits,
        encoder_z_loss=encoder_z_loss,
        encoder_aux_loss=encoder_aux_loss,
        decoder_z_loss=decoder_z_loss,
        decoder_aux_loss=decoder_aux_loss,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        decoder_router_logits=decoder_outputs.router_probs,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        encoder_router_logits=encoder_outputs.router_probs,
    )

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersLayerFF

Bases: Module

Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module.

PARAMETER DESCRIPTION
config

([SwitchTransformersConfig]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the [~PreTrainedModel.from_pretrained] method to load the model weights.

is_sparse

Whether the MLP layer is a Sparse layer (contains a Mixture of Experts) or not

TYPE: `bool` DEFAULT: False
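
A minimal usage sketch of the two variants, assuming a local MindSpore install and a deliberately tiny configuration (the sizes below are illustrative, not the pretrained defaults):

```python
import numpy as np
import mindspore
from mindnlp.transformers.models.switch_transformers.configuration_switch_transformers import (
    SwitchTransformersConfig,
)
from mindnlp.transformers.models.switch_transformers.modeling_switch_transformers import (
    SwitchTransformersLayerFF,
)

config = SwitchTransformersConfig(d_model=16, d_ff=32, num_experts=2, expert_capacity=4)
hidden_states = mindspore.Tensor(np.random.randn(1, 6, 16).astype(np.float32))

# Dense wrapper: returns only the residual-updated hidden states.
dense_ff = SwitchTransformersLayerFF(config, is_sparse=False)
dense_out = dense_ff(hidden_states, output_router_logits=False)

# Sparse wrapper: additionally returns (router_logits, expert_index) when requested.
sparse_ff = SwitchTransformersLayerFF(config, is_sparse=True)
sparse_out, (router_logits, expert_index) = sparse_ff(hidden_states, output_router_logits=True)
print(dense_out.shape, sparse_out.shape, router_logits.shape)  # (1, 6, 16) (1, 6, 16) (1, 6, 2)
```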

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersLayerFF(nn.Module):
    r"""
    Switch Transformers Feed Forward layer module. This is a wrapper around the Mixture of Experts module.

    Parameters:
        config : ([`SwitchTransformersConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
        is_sparse (`bool`):
            Whether the MLP layer is a `Sparse` layer (contains a Mixture of Experts) or not
    """

    def __init__(self, config: SwitchTransformersConfig, is_sparse=False):
        super().__init__()
        self.is_sparse = is_sparse

        # Check if it is a sparse layer, if not then it is a dense layer
        if not self.is_sparse:
            self.mlp = SwitchTransformersDenseActDense(config)
        else:
            self.mlp = SwitchTransformersSparseMLP(config)

        self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(p=config.dropout_rate)

    def forward(self, hidden_states, output_router_logits):
        forwarded_states = self.layer_norm(hidden_states)
        forwarded_states = self.mlp(forwarded_states)

        if isinstance(forwarded_states, tuple):
            forwarded_states, router_tuple = forwarded_states
        else:
            router_tuple = None

        output = hidden_states + self.dropout(forwarded_states)

        if output_router_logits and router_tuple is not None:
            output = (output, router_tuple)

        return output

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersLayerNorm

Bases: Module

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersLayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Construct a layernorm module in the SwitchTransformers style. No bias and no subtraction of mean.
        """
        super().__init__()
        self.weight = Parameter(ops.ones(hidden_size), 'weight')
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # SwitchTransformers uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
        # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus the variance is calculated
        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
        # half-precision inputs is done in fp32

        variance = hidden_states.astype(mindspore.float32).pow(2).mean(-1, keep_dims=True)
        hidden_states = hidden_states * ops.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [mindspore.float16, mindspore.bfloat16]:
            hidden_states = hidden_states.astype(self.weight.dtype)

        return self.weight * hidden_states
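
The scale-only normalization above is plain RMSNorm: divide each feature vector by its root mean square and multiply by a learned per-feature weight, with no mean subtraction and no bias. A quick numpy check of the same arithmetic (illustrative only, independent of MindSpore):

```python
import numpy as np

eps = 1e-6
hidden = np.random.randn(2, 4, 8).astype(np.float32)  # (batch, seq, d_model)
weight = np.ones(8, dtype=np.float32)                  # the module's learned scale

# Mean of squares over the feature axis, no mean subtraction.
variance = np.mean(hidden.astype(np.float32) ** 2, axis=-1, keepdims=True)
rms_normed = weight * (hidden / np.sqrt(variance + eps))

print(rms_normed.shape)  # (2, 4, 8)
```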

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersLayerNorm.__init__(hidden_size, eps=1e-06)

Construct a layernorm module in the SwitchTransformers style. No bias and no subtraction of mean.

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def __init__(self, hidden_size, eps=1e-6):
    """
    Construct a layernorm module in the SwitchTransformers style. No bias and no subtraction of mean.
    """
    super().__init__()
    self.weight = Parameter(ops.ones(hidden_size), 'weight')
    self.variance_epsilon = eps

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersModel

Bases: SwitchTransformersPreTrainedModel

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
    _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]

    def __init__(self, config: SwitchTransformersConfig):
        super().__init__(config)
        self.shared = nn.Embedding(config.vocab_size, config.d_model)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = SwitchTransformersStack(encoder_config, self.shared)

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        self.decoder = SwitchTransformersStack(decoder_config, self.shared)

        # Initialize weights and apply final processing
        self.post_init()

        # Model parallel
        self.device_map = None

    def get_input_embeddings(self):
        return self.shared

    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def forward(
        self,
        input_ids: Optional[mindspore.Tensor] = None,
        attention_mask: Optional[mindspore.Tensor] = None,
        decoder_input_ids: Optional[mindspore.Tensor] = None,
        decoder_attention_mask: Optional[mindspore.Tensor] = None,
        head_mask: Optional[mindspore.Tensor] = None,
        decoder_head_mask: Optional[mindspore.Tensor] = None,
        cross_attn_head_mask: Optional[mindspore.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
        inputs_embeds: Optional[mindspore.Tensor] = None,
        decoder_inputs_embeds: Optional[mindspore.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[mindspore.Tensor], Seq2SeqMoEModelOutput]:
        r"""

        Returns:
            `Union[Tuple[mindspore.Tensor], Seq2SeqMoEModelOutput]`

        Example:
            ```python
            >>> from transformers import AutoTokenizer, SwitchTransformersModel
            ...
            >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
            >>> model = SwitchTransformersModel.from_pretrained("google/switch-base-8")
            ...
            >>> input_ids = tokenizer(
            ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
            ... ).input_ids  # Batch size 1
            >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
            ...
            >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel.
            >>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg.
            >>> decoder_input_ids = model._shift_right(decoder_input_ids)
            ...
            >>> # forward pass
            >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
            >>> last_hidden_states = outputs.last_hidden_state
            ```
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        if (
            output_router_logits
            and self.config.num_sparse_encoder_layers == 0
            and self.config.num_sparse_decoder_layers == 0
        ):
            raise ValueError(
                "You asked to return `output_router_logits` but the transformer is dense, and does "
                "not contain any sparse MLP Layers. Set `output_router_logits = False` and restart"
            )
        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_router_logits=output_router_logits,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
            encoder_outputs = MoEModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
                router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
            )

        hidden_states = encoder_outputs[0]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return Seq2SeqMoEModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            decoder_router_logits=decoder_outputs.router_probs,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            encoder_router_logits=encoder_outputs.router_probs,
        )

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersModel.forward(input_ids=None, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, head_mask=None, decoder_head_mask=None, cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, inputs_embeds=None, decoder_inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, output_router_logits=None, return_dict=None)

RETURNS DESCRIPTION
Union[Tuple[Tensor], Seq2SeqMoEModelOutput]

Union[Tuple[mindspore.Tensor], Seq2SeqMoEModelOutput]

Example
>>> from transformers import AutoTokenizer, SwitchTransformersModel
...
>>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
>>> model = SwitchTransformersModel.from_pretrained("google/switch-base-8")
...
>>> input_ids = tokenizer(
...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
... ).input_ids  # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
...
>>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel.
>>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg.
>>> decoder_input_ids = model._shift_right(decoder_input_ids)
...
>>> # forward pass
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
>>> last_hidden_states = outputs.last_hidden_state
Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def forward(
    self,
    input_ids: Optional[mindspore.Tensor] = None,
    attention_mask: Optional[mindspore.Tensor] = None,
    decoder_input_ids: Optional[mindspore.Tensor] = None,
    decoder_attention_mask: Optional[mindspore.Tensor] = None,
    head_mask: Optional[mindspore.Tensor] = None,
    decoder_head_mask: Optional[mindspore.Tensor] = None,
    cross_attn_head_mask: Optional[mindspore.Tensor] = None,
    encoder_outputs: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
    past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None,
    inputs_embeds: Optional[mindspore.Tensor] = None,
    decoder_inputs_embeds: Optional[mindspore.Tensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_router_logits: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple[mindspore.Tensor], Seq2SeqMoEModelOutput]:
    r"""

    Returns:
        `Union[Tuple[mindspore.Tensor], Seq2SeqMoEModelOutput]`

    Example:
        ```python
        >>> from transformers import AutoTokenizer, SwitchTransformersModel
        ...
        >>> tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
        >>> model = SwitchTransformersModel.from_pretrained("google/switch-base-8")
        ...
        >>> input_ids = tokenizer(
        ...     "Studies have been shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids  # Batch size 1
        ...
        >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for SwitchTransformersModel.
        >>> # This is not needed for torch's SwitchTransformersForConditionalGeneration as it does this internally using labels arg.
        >>> decoder_input_ids = model._shift_right(decoder_input_ids)
        ...
        >>> # forward pass
        >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        >>> last_hidden_states = outputs.last_hidden_state
        ```
    """
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
    if head_mask is not None and decoder_head_mask is None:
        if self.config.num_layers == self.config.num_decoder_layers:
            warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
            decoder_head_mask = head_mask

    if (
        output_router_logits
        and self.config.num_sparse_encoder_layers == 0
        and self.config.num_sparse_decoder_layers == 0
    ):
        raise ValueError(
            "You asked to return `output_router_logits` but the transformer is dense, and does "
            "not contain any sparse MLP Layers. Set `output_router_logits = False` and restart"
        )
    # Encode if needed (training, first prediction pass)
    if encoder_outputs is None:
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )
    elif return_dict and not isinstance(encoder_outputs, MoEModelOutput):
        encoder_outputs = MoEModelOutput(
            last_hidden_state=encoder_outputs[0],
            hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
            attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
        )

    hidden_states = encoder_outputs[0]

    # Decode
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=attention_mask,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        output_router_logits=output_router_logits,
        return_dict=return_dict,
    )

    if not return_dict:
        return decoder_outputs + encoder_outputs

    return Seq2SeqMoEModelOutput(
        last_hidden_state=decoder_outputs.last_hidden_state,
        past_key_values=decoder_outputs.past_key_values,
        decoder_hidden_states=decoder_outputs.hidden_states,
        decoder_attentions=decoder_outputs.attentions,
        cross_attentions=decoder_outputs.cross_attentions,
        decoder_router_logits=decoder_outputs.router_probs,
        encoder_last_hidden_state=encoder_outputs.last_hidden_state,
        encoder_hidden_states=encoder_outputs.hidden_states,
        encoder_attentions=encoder_outputs.attentions,
        encoder_router_logits=encoder_outputs.router_probs,
    )

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersPreTrainedModel

Bases: PreTrainedModel

An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models.

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = SwitchTransformersConfig
    base_model_prefix = "switch_transformers"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SwitchTransformersBlock"]

    @property
    def dummy_inputs(self):
        input_ids = mindspore.Tensor(DUMMY_INPUTS)
        input_mask = mindspore.Tensor(DUMMY_MASK)
        dummy_inputs = {
            "decoder_input_ids": input_ids,
            "input_ids": input_ids,
            "decoder_attention_mask": input_mask,
        }
        return dummy_inputs

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_factor  # Used for testing weights initialization
        if isinstance(module, SwitchTransformersLayerNorm):
            module.weight.data.set_data(initializer(Normal(factor * 1.0), \
                                                    module.weight.data.shape, module.weight.data.dtype))
        elif isinstance(
            module,
            (SwitchTransformersModel, SwitchTransformersForConditionalGeneration, SwitchTransformersEncoderModel),
        ):
            # Mesh TensorFlow embeddings initialization
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
            module.shared.weight.data.set_data(initializer(Normal(factor * 1.0), \
                                                              module.shared.weight.data.shape, \
                                                              module.shared.weight.data.dtype))
            if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
                module.lm_head.weight.data.set_data(initializer(Normal(factor * 1.0), \
                                                    module.lm_head.weight.data.shape, \
                                                    module.lm_head.weight.data.dtype))
        elif isinstance(module, SwitchTransformersDenseActDense):
            # Mesh TensorFlow FF initialization
            # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
            # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
            module.wi.weight.data.set_data(initializer(Normal(factor * ((self.config.d_model) ** -0.5)), \
                                           module.wi.weight.data.shape, \
                                           module.wi.weight.data.dtype))
            if hasattr(module.wi, "bias") and module.wi.bias is not None:
                module.wi.bias.data.set_data(initializer("zero", module.wi.bias.data.shape, \
                                                         module.wi.bias.data.dtype))
            module.wo.weight.data.set_data(initializer(Normal(factor * ((self.config.d_ff) ** -0.5)), \
                                           module.wo.weight.data.shape, \
                                           module.wo.weight.data.dtype))
            if hasattr(module.wo, "bias") and module.wo.bias is not None:
                module.wo.bias.data.set_data(initializer("zero", module.wo.bias.data.shape, \
                                                         module.wo.bias.data.dtype))
        elif isinstance(module, SwitchTransformersAttention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.q.weight.data.set_data(initializer(Normal(factor * ((d_model * key_value_proj_dim) ** -0.5)), \
                                          module.q.weight.data.shape, \
                                          module.q.weight.data.dtype))
            module.k.weight.data.set_data(initializer(Normal(factor * (d_model**-0.5)), \
                                          module.k.weight.data.shape, \
                                          module.k.weight.data.dtype))
            module.v.weight.data.set_data(initializer(Normal(factor * (d_model**-0.5)), \
                                          module.v.weight.data.shape, \
                                          module.v.weight.data.dtype))
            module.o.weight.data.set_data(initializer(Normal(factor * ((n_heads * key_value_proj_dim) ** -0.5)), \
                                          module.o.weight.data.shape, \
                                          module.o.weight.data.dtype))
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.set_data(initializer(Normal(factor * ((d_model) ** -0.5)), \
                                                                    module.relative_attention_bias.weight.data.shape, \
                                                                    module.relative_attention_bias.weight.data.dtype))
        elif isinstance(module, SwitchTransformersSparseMLP):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
            d_model = self.config.d_model
            key_value_proj_dim = self.config.d_kv
            n_heads = self.config.num_heads
            module.router.classifier.weight.data.set_data(initializer(Normal(factor * 1.0), \
                                                              module.router.classifier.weight.data.shape, \
                                                              module.router.classifier.weight.data.dtype))
            for idx in range(self.config.num_experts):
                module.experts[f"expert_{idx}"].wi.weight.set_data(initializer(Normal(factor * (d_model**-0.5)), \
                                                              module.experts[f"expert_{idx}"].wi.weight.data.shape, \
                                                              module.experts[f"expert_{idx}"].wi.weight.data.dtype))
                module.experts[f"expert_{idx}"].wo.weight.set_data(initializer(Normal(factor * (d_model**-0.5)), \
                                                              module.experts[f"expert_{idx}"].wo.weight.data.shape, \
                                                              module.experts[f"expert_{idx}"].wo.weight.data.dtype))

    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id

        if decoder_start_token_id is None:
            raise ValueError(
                "self.model.config.decoder_start_token_id has to be defined. In SwitchTransformers it is usually set"
                " to the pad_token_id. See SwitchTransformers docs for more information"
            )

        # shift inputs to the right
        # if is_torch_fx_proxy(input_ids):
        #     # Item assignment is not supported natively for proxies.
        #     shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
        #     shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
        # else:
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].copy()
        shifted_input_ids[..., 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids = shifted_input_ids.masked_fill(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
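
Concretely, `_shift_right` builds decoder inputs from labels by prepending `decoder_start_token_id`, dropping the last position, and replacing any `-100` loss-mask values with `pad_token_id`. A numpy sketch of that transformation, assuming (as in T5-style checkpoints) that the start token and pad token are both id 0:

```python
import numpy as np

decoder_start_token_id = 0  # assumption: equal to pad_token_id, as in T5-style checkpoints
pad_token_id = 0

labels = np.array([[8774, 10, 5, -100, -100]])  # -100 marks positions ignored by the loss

shifted = np.zeros_like(labels)
shifted[..., 1:] = labels[..., :-1]              # shift everything right by one position
shifted[..., 0] = decoder_start_token_id         # prepend the decoder start token
shifted[shifted == -100] = pad_token_id          # -100 is only a loss mask, not a real token id

print(shifted)  # [[0 8774 10 5 0]]
```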

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP

Bases: Module

Implementation of the Switch Transformers Sparse MLP module.

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersSparseMLP(nn.Module):
    r"""
    Implementation of the Switch Transformers Sparse MLP module.
    """

    def __init__(self, config: SwitchTransformersConfig, expert_class: nn.Module = SwitchTransformersDenseActDense):
        super().__init__()
        # Step 1: Get the correct router according to its class
        self.router = SwitchTransformersTop1Router(config)

        # Step 2: Get the experts
        self.experts = nn.ModuleDict()
        for idx in range(config.num_experts):
            self.experts[f"expert_{idx}"] = expert_class(config)

    def forward(self, hidden_states):
        r"""
        Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

        1. Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
        and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
        hidden states: they are broadcast to the hidden states values (they can be interpreted as a scaling factor).
        2. Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign to
        each expert its corresponding hidden states.

        """
        # Step 1: Get the router_mask from the router as well as the probabilities
        router_mask, router_probs, router_logits = self.router(hidden_states)
        expert_index = ops.argmax(router_mask, dim=-1)

        # The router might not map every token to an expert, which means that some hidden states
        # can pass through a layer unchanged. That is why the hidden states are cloned before updating only the selected ones.

        next_states = hidden_states.copy()
        for idx, expert in enumerate(self.experts.values()):
            token_indices = router_mask[:, :, idx].bool()
            next_states[token_indices] = expert(hidden_states[token_indices]).astype(next_states.dtype)

        hidden_states = router_probs * next_states
        return hidden_states, (router_logits, expert_index)

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP.forward(hidden_states)

Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

  1. Gets the router_mask from the router. The shape of the mask is (batch_size, sequence_length, num_expert) and corresponds to the argmax of the router_probs. The probabilities are needed in the computation of the hidden states: they are broadcast to the hidden states values (they can be interpreted as a scaling factor).
  2. Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign to each expert its corresponding hidden states. A toy numpy sketch of this dispatch follows the list.
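
A toy numpy sketch of that dispatch loop, with made-up experts and router outputs (it mirrors the logic of the MindSpore implementation below without depending on it):

```python
import numpy as np

batch, seq, d_model, num_experts = 1, 4, 8, 2
hidden_states = np.random.randn(batch, seq, d_model).astype(np.float32)

# Pretend router output: a one-hot expert assignment per token plus its top-1 probability.
expert_choice = np.array([[0, 1, 1, 0]])                      # expert picked by each token
router_mask = np.eye(num_experts, dtype=bool)[expert_choice]  # (batch, seq, num_experts)
router_probs = np.random.uniform(0.5, 1.0, (batch, seq, 1)).astype(np.float32)

# Toy "experts": any per-token transformation works for the illustration.
experts = [lambda x: 2.0 * x, lambda x: -1.0 * x]

# Dispatch: copy the inputs, then overwrite only the tokens each expert was given.
next_states = hidden_states.copy()
for idx, expert in enumerate(experts):
    token_indices = router_mask[:, :, idx]                    # boolean (batch, seq) selector
    next_states[token_indices] = expert(hidden_states[token_indices])

# Scale each token's output by its routing probability, as the sparse MLP does.
output = router_probs * next_states
print(output.shape)  # (1, 4, 8)
```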
Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def forward(self, hidden_states):
    r"""
    Hold on, this will be slightly tricky to understand. In the correct order, a MoE layer does the following:

    1. Gets the `router_mask` from the router. The shape of the mask is `(batch_size, sequence_length, num_expert)`
    and corresponds to the argmax of the `router_probs`. The probabilities are needed in the computation of the
    hidden states: they are broadcast to the hidden states values (they can be interpreted as a scaling factor).
    2. Dispatches the tokens to their associated experts. We do a classic for loop over the experts and assign to
    each expert its corresponding hidden states.

    """
    # Step 1: Get the router_mask from the router as well as the probabilities
    router_mask, router_probs, router_logits = self.router(hidden_states)
    expert_index = ops.argmax(router_mask, dim=-1)

    # The router might not map every token to an expert, which means that some hidden states
    # can pass through a layer unchanged. That is why the hidden states are cloned before updating only the selected ones.

    next_states = hidden_states.copy()
    for idx, expert in enumerate(self.experts.values()):
        token_indices = router_mask[:, :, idx].bool()
        next_states[token_indices] = expert(hidden_states[token_indices]).astype(next_states.dtype)

    hidden_states = router_probs * next_states
    return hidden_states, (router_logits, expert_index)

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router

Bases: Module

Router using tokens choose top-1 experts assignment.

This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then routed to their choice of expert until the expert's expert_capacity is reached. There is no guarantee that each token is processed by an expert, or that each expert receives at least one token.
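
The capacity constraint is enforced with a cumulative sum over the one-hot expert assignments: once more than expert_capacity tokens in a sequence have picked the same expert, the surplus tokens lose their assignment and are not routed at all. A toy numpy illustration of that masking, using a capacity of 2:

```python
import numpy as np

expert_capacity = 2

# One-hot top-1 choices for 5 tokens and 2 experts: tokens 0, 1, 3 and 4 all want expert 0.
expert_index = np.array([[[1, 0], [1, 0], [0, 1], [1, 0], [1, 0]]], dtype=np.float32)

# Running count of how many tokens have picked each expert so far in the sequence.
token_priority = np.cumsum(expert_index, axis=-2)

# Tokens whose running count exceeds the capacity lose their assignment.
expert_capacity_mask = token_priority <= expert_capacity
expert_index = expert_index * expert_capacity_mask

print(expert_index[0, :, 0])  # [1. 1. 0. 0. 0.] -> only the first two tokens keep expert 0
```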

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
class SwitchTransformersTop1Router(nn.Module):
    """
    Router using tokens choose top-1 experts assignment.

    This router uses the same mechanism as in Switch Transformer (https://arxiv.org/abs/2101.03961) and V-MoE
    (https://arxiv.org/abs/2106.05974): tokens choose their top experts. Items are sorted by router_probs and then
    routed to their choice of expert until the expert's expert_capacity is reached. **There is no guarantee that each
    token is processed by an expert**, or that each expert receives at least one token.

    """

    def __init__(self, config: SwitchTransformersConfig):
        super().__init__()
        self.num_experts = config.num_experts
        self.expert_capacity = config.expert_capacity
        self.classifier = nn.Linear(config.hidden_size, self.num_experts, bias=config.router_bias)
        self.jitter_noise = config.router_jitter_noise
        self.ignore_padding_tokens = config.router_ignore_padding_tokens
        self.dtype = getattr(mindspore, config.router_dtype)

    def _compute_router_probabilities(self, hidden_states: mindspore.Tensor) -> Tuple[mindspore.Tensor, mindspore.Tensor]:
        r"""
        Computes router probabilities from input hidden states.

        Args:
            hidden_states (`mindspore.Tensor`):
                (batch_size, sequence_length, hidden_dim) from which router probabilities are computed.

        Returns:
            router_probabilities (`mindspore.Tensor`):
                Tensor of shape (batch_size, sequence_length, num_experts) corresponding to the probabilities for each
                token and expert. Used for routing tokens to experts.
            router_logits (`mindspore.Tensor`):
                Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding to raw router logits.
                This is used later for computing router z-loss.
        """
        # float32 is used to ensure stability. See the discussion of "selective precision" in
        # https://arxiv.org/abs/2101.03961.
        # We also store the previous dtype to cast back the output to the previous dtype
        self.input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype(self.dtype)

        if self.training and self.jitter_noise > 0:
            # Multiply the token inputs by the uniform distribution - adding some noise
            hidden_states *= ops.uniform(hidden_states.shape,
                                         mindspore.Tensor(1.0 - self.jitter_noise, hidden_states.dtype),
                                         mindspore.Tensor(1.0 + self.jitter_noise, hidden_states.dtype),
                                         seed=0,
                                         dtype=hidden_states.dtype)

        # Shape: [num_groups, tokens_per_group, num_experts]
        self._cast_classifier()
        router_logits = self.classifier(hidden_states)

        # Apply Softmax and cast back to the original `dtype`
        router_probabilities = ops.softmax(router_logits, axis=-1, dtype=self.dtype).astype(self.input_dtype)
        return router_probabilities, router_logits

    def _cast_classifier(self):
        r"""
        `bitsandbytes` `Linear8bitLt` layers do not support manual casting. Therefore we need to check whether they are an
        instance of the `Linear8bitLt` class by checking special attributes.
        """
        if not (hasattr(self.classifier, "SCB") or hasattr(self.classifier, "CB")):
            self.classifier = self.classifier.to_float(self.dtype)

    def forward(self, hidden_states: mindspore.Tensor) -> Tuple:
        r"""
        Generic forward function for every Router class. Each Router expects the same input hidden states
        (`hidden_states`), one per token, and an `expert_capacity` corresponding to the maximum number of tokens each
        expert can receive; some routers may therefore send only a few tokens to each expert.

        Each Router works as follows: it takes the hidden states for each token and computes `router_probs` and
        `router_logits` from the `router_weights`. This assigns, for each token, the raw probability of being routed
        to an expert. Each Router class then defines its own `_compute_routing_instructions`.

        Args:
            hidden_states (`mindspore.Tensor`) :
                [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.

        Returns:
            Tuple[`mindspore.Tensor`, `mindspore.Tensor`, `mindspore.Tensor`] Tuple containing the expert index,
                the router probs and the router logits. The router probabilities and logits are required to compute the loss.
        """
        router_probs, router_logits = self._compute_router_probabilities(hidden_states)

        expert_index = ops.argmax(router_probs, dim=-1)
        expert_index = ops.one_hot(expert_index, self.num_experts)

        # Mask tokens outside expert capacity. Sum over each sequence
        token_priority = ops.cumsum(expert_index, axis=-2)
        # mask if the token routed to the expert will overflow
        expert_capacity_mask = token_priority <= self.expert_capacity
        expert_index = expert_index * expert_capacity_mask

        router_probs = ops.max(router_probs, axis=-1)[0].unsqueeze(-1)
        return expert_index, router_probs, router_logits

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router.forward(hidden_states)

Generic forward function for every Router class. Each Router expects the same input hidden states (hidden_states), one per token, and an expert_capacity corresponding to the maximum number of tokens each expert can receive; some routers may therefore send only a few tokens to each expert.

Each Router works as follows: it takes the hidden states for each token and computes the router_probs and router_logits from the router_weights. This assigns, for each token, the raw probability of being routed to an expert. Each Router class then defines its own _compute_routing_instructions.

PARAMETER DESCRIPTION
hidden_states

[num_groups, tokens_per_group, hidden_dim] inputs to send to experts.

TYPE: `mindspore.Tensor`)

RETURNS DESCRIPTION
Tuple

Tuple[mindspore.Tensor, mindspore.Tensor, mindspore.Tensor] Tuple containing the expert index, the router probs and the router logits. The router probabilities and logits are required to compute the loss.

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def forward(self, hidden_states: mindspore.Tensor) -> Tuple:
    r"""
    Generic forward function for every Router class. Each Router expects the same input hidden states
    (`hidden_states`), one per token, and an `expert_capacity` corresponding to the maximum number of tokens each
    expert can receive; some routers may therefore send only a few tokens to each expert.

    Each Router works as follows: it takes the hidden states for each token and computes `router_probs` and
    `router_logits` from the `router_weights`. This assigns, for each token, the raw probability of being routed
    to an expert. Each Router class then defines its own `_compute_routing_instructions`.

    Args:
        hidden_states (`mindspore.Tensor`) :
            [num_groups, tokens_per_group, hidden_dim] inputs to send to experts.

    Returns:
        Tuple[`mindspore.Tensor`, `mindspore.Tensor`, `mindspore.Tensor`] Tuple containing the expert index,
            the router probs and the router logits. The router probabilities and logits are required to compute the loss.
    """
    router_probs, router_logits = self._compute_router_probabilities(hidden_states)

    expert_index = ops.argmax(router_probs, dim=-1)
    expert_index = ops.one_hot(expert_index, self.num_experts)

    # Mask tokens outside expert capacity. Sum over each sequence
    token_priority = ops.cumsum(expert_index, axis=-2)
    # mask if the token routed to the expert will overflow
    expert_capacity_mask = token_priority <= self.expert_capacity
    expert_index = expert_index * expert_capacity_mask

    router_probs = ops.max(router_probs, axis=-1)[0].unsqueeze(-1)
    return expert_index, router_probs, router_logits

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.load_balancing_loss_func(router_probs, expert_indices)

Computes auxiliary load balancing loss as in Switch Transformer - implemented in MindSpore.

See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between experts is too unbalanced.

PARAMETER DESCRIPTION
router_probs

Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].

TYPE: `mindspore.Tensor`

expert_indices

Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token.

TYPE: `mindspore.Tensor`

RETURNS DESCRIPTION
float

The auxiliary loss.

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def load_balancing_loss_func(router_probs: mindspore.Tensor, expert_indices: mindspore.Tensor) -> float:
    r"""
    Computes auxiliary load balancing loss as in Switch Transformer - implemented in MindSpore.

    See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
    function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
    experts is too unbalanced.

    Args:
        router_probs (`mindspore.Tensor`):
            Probability assigned to each expert per token. Shape: [batch_size, sequence_length, num_experts].
        expert_indices (`mindspore.Tensor`):
            Indices tensor of shape [batch_size, sequence_length] identifying the selected expert for a given token.

    Returns:
        The auxiliary loss.
    """
    num_experts = router_probs.shape[-1]

    # cast the expert indices to int64, otherwise one-hot encoding will fail
    if expert_indices.dtype != mindspore.int64:
        expert_indices = expert_indices.astype(mindspore.int64)

    if len(expert_indices.shape) == 2:
        expert_indices = expert_indices.unsqueeze(2)

    expert_mask = ops.one_hot(expert_indices, num_experts)

    # For a given token, determine if it was routed to a given expert.
    expert_mask = ops.max(expert_mask, axis=-2)[0]

    # cast to float32 otherwise mean will fail
    expert_mask = mindspore.Tensor(expert_mask, mindspore.float32)
    tokens_per_group_and_expert = ops.mean(expert_mask, axis=-2)

    router_prob_per_group_and_expert = ops.mean(router_probs, axis=-2)
    return ops.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
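
A hedged usage sketch with random inputs, assuming a working MindSpore install; it only demonstrates the expected shapes and the scalar result (the loss reaches its minimum of 1.0 when routing is perfectly balanced):

```python
import numpy as np
import mindspore
from mindnlp.transformers.models.switch_transformers.modeling_switch_transformers import (
    load_balancing_loss_func,
)

batch_size, seq_len, num_experts = 2, 6, 4

# Router probabilities per token and expert, plus the top-1 expert picked for each token.
logits = np.random.randn(batch_size, seq_len, num_experts).astype(np.float32)
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
router_probs = mindspore.Tensor(probs)
expert_indices = mindspore.Tensor(np.argmax(probs, axis=-1).astype(np.int64))

aux_loss = load_balancing_loss_func(router_probs, expert_indices)
print(aux_loss)  # scalar auxiliary loss
```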

mindnlp.transformers.models.switch_transformers.modeling_switch_transformers.router_z_loss_func(router_logits)

Compute the router z-loss, implemented in MindSpore.

The router z-loss was introduced in Designing Effective Sparse Expert Models. It encourages router logits to remain small in an effort to improve stability.

PARAMETER DESCRIPTION
router_logits

Input logits of shape [batch_size, sequence_length, num_experts]

TYPE: `mindspore.Tensor`

RETURNS DESCRIPTION
float

Scalar router z-loss.

Source code in mindnlp/transformers/models/switch_transformers/modeling_switch_transformers.py
def router_z_loss_func(router_logits: mindspore.Tensor) -> float:
    r"""
    Compute the router z-loss, implemented in MindSpore.

    The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
    It encourages router logits to remain small in an effort to improve stability.

    Args:
        router_logits (`mindspore.Tensor`):
            Input logits of shape [batch_size, sequence_length, num_experts]

    Returns:
        Scalar router z-loss.
    """
    num_groups, tokens_per_group, _ = router_logits.shape
    log_z = ops.logsumexp(router_logits, axis=-1)
    z_loss = log_z**2
    return ops.sum(z_loss) / (num_groups * tokens_per_group)
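
A matching sketch for the z-loss, again assuming a working MindSpore install; the loss is simply the mean, over all tokens, of the squared log-partition function (logsumexp) of the router logits:

```python
import numpy as np
import mindspore
from mindnlp.transformers.models.switch_transformers.modeling_switch_transformers import (
    router_z_loss_func,
)

router_logits = mindspore.Tensor(np.random.randn(2, 6, 4).astype(np.float32))
z_loss = router_z_loss_func(router_logits)

# Equivalent numpy arithmetic: mean over tokens of (logsumexp over experts) ** 2.
log_z = np.log(np.exp(router_logits.asnumpy()).sum(axis=-1))
print(z_loss, np.mean(log_z ** 2))  # the two values should agree
```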