Quick start

MindNLP offers powerful functionalities for training and using AI models for various tasks. To get started, this tutorial will guide you through loading a pretrained model and fine-tuning it to fit your specific needs.

Using a pretrained model has great benefits: it saves computing time and resources. Fine-tuning allows you to adapt these models for enhanced performance on your unique dataset. Now that you're ready, let's get started!

We will use the BERT model as an example and fine-tune it to perform a classification task on the Large Movie Review Dataset.

MindNLP provides two approaches to fine-tuning. The first is the user-friendly Trainer API from MindNLP, which supports essential training functionalities. The second is native MindSpore, which gives you more customized control. We will guide you through both approaches in this tutorial.

For both approaches, you first need to prepare the dataset by running the Prepare a dataset part of this tutorial.

After the dataset is ready, choose one of the strategies below and start your journey!

Prepare a dataset

Before you can fine-tune a pretrained model, download a dataset and prepare it for training.

MindNLP includes a load_dataset API that loads any dataset from the Hugging Face dataset repository. Here, let's use it to load the Large Movie Review Dataset, which is named 'imdb', and split it into training, validation and test datasets.

from mindnlp import load_dataset

imdb_ds = load_dataset('imdb', split=['train', 'test'])
imdb_train = imdb_ds['train']
imdb_test = imdb_ds['test']

# Split train dataset further into training and validation datasets
imdb_train, imdb_val = imdb_train.split([0.7, 0.3])

Next, load the tokenizer for the model. The process of tokenization converts raw text into a format that machine learning models can process, which is crucial for natural language processing tasks.

In MindNLP, AutoTokenizer helps automatically fetch and instantiate the appropriate tokenizer for a pre-trained model.

from mindnlp.transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

Once the dataset and the tokenizer are ready, we can process the dataset.

This includes:

* Tokenize the text.
* Cast to correct datatype.
* Handle variable sequence lengths with padding or truncation.
* Shuffle the order of entries.
* Batch the dataset.

In the Data Preprocess tutorial, these steps will be elaborated.

Here, we define the following process_dataset function to prepare the dataset.

import mindspore
import numpy as np
from mindspore.dataset import transforms

def process_dataset(dataset, tokenizer, max_seq_len=256, batch_size=32, shuffle=False, take_len=None):
    # The tokenize function
    def tokenize(text):
        tokenized = tokenizer(text, truncation=True, max_length=max_seq_len)
        return tokenized['input_ids'], tokenized['token_type_ids'], tokenized['attention_mask']

    # Shuffle the order of the dataset
    if shuffle:
        dataset = dataset.shuffle(buffer_size=batch_size)

    # Select the first several entries of the dataset
    if take_len:
        dataset = dataset.take(take_len)

    # Apply the tokenize function, transforming the 'text' column into the three output columns generated by the tokenizer.
    dataset = dataset.map(operations=[tokenize], input_columns="text", output_columns=['input_ids', 'token_type_ids', 'attention_mask'])
    # Cast the datatype of the 'label' column to int32 and rename the column to 'labels'
    dataset = dataset.map(operations=transforms.TypeCast(mindspore.int32), input_columns="label", output_columns="labels")
    # Batch the dataset with padding.
    dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),
                                                         'token_type_ids': (None, 0),
                                                         'attention_mask': (None, 0)})
    return dataset
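
To make the truncation and padding steps concrete, here is a minimal plain-Python sketch of the idea. It is not how MindSpore implements padded_batch; the helper names truncate and pad_batch, the token ids, and pad_id=0 are assumptions made for illustration only.

```python
# Illustrative only: a plain-Python stand-in for truncation plus per-batch padding.
# pad_id=0 is an assumption for this sketch; the tutorial uses tokenizer.pad_token_id.

def truncate(ids, max_len):
    """Keep at most max_len token ids.

    Real tokenizers also preserve special tokens such as [SEP] when
    truncating; this sketch simply cuts the list.
    """
    return ids[:max_len]

def pad_batch(sequences, pad_id=0):
    """Right-pad every sequence to the length of the longest one in the batch."""
    longest = max(len(seq) for seq in sequences)
    padded = [seq + [pad_id] * (longest - len(seq)) for seq in sequences]
    # 1 marks a real token, 0 marks padding
    mask = [[1] * len(seq) + [0] * (longest - len(seq)) for seq in sequences]
    return padded, mask

# Made-up token ids of different lengths
raw = [[101, 7, 8, 102], [101, 9, 102], [101, 5, 6, 7, 8, 102]]
batch = [truncate(ids, max_len=4) for ids in raw]
input_ids, attention_mask = pad_batch(batch)

print(input_ids)
print(attention_mask)
```

Padding to the longest sequence in each batch, rather than to a fixed global length, keeps short batches small.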

Now process all splits of the dataset and create smaller subsets of the datasets to shorten the process of the fine-tuning:

batch_size = 4
take_len = batch_size * 200
small_dataset_train = process_dataset(imdb_train, tokenizer, batch_size=batch_size, shuffle=True, take_len=take_len)
small_dataset_val = process_dataset(imdb_val, tokenizer, batch_size=batch_size, shuffle=True, take_len=take_len)
small_dataset_test = process_dataset(imdb_test, tokenizer, batch_size=batch_size, shuffle=True, take_len=take_len)

Here take_len is an optional parameter, which helps to create a smaller subset of the dataset to shorten the process of the fine-tuning.

In practical fine-tuning jobs, however, the full dataset is normally used.

Train

At this stage, you can choose either the MindNLP Trainer API or the native MindSpore approach to fine-tune the model.

Let's start with the Trainer API approach.

Train with MindNLP Trainer

MindNLP comes with a Trainer class designed to simplify model training. With Trainer, you can avoid the need to manually write your own training loop.

Trainer supports a wide range of training options, which will be explained in the Use Trainer tutorial.

Initialize the model

In our example here, we will first instantiate the pretrained BERT model.

For this purpose, we use AutoModelForSequenceClassification. Supply the name of the pretrained model, i.e., 'bert-base-cased', to AutoModelForSequenceClassification. It will automatically infer the model architecture, instantiate the model and load the pretrained parameters. The model loaded here is a BERT model specialized in classification tasks, BertForSequenceClassification.

To supply additional arguments to the model initialization, you can add more keyword arguments. Here, since the classification task involves determining whether a movie review expresses a positive or negative sentiment, we supply num_labels=2 to the BERT model.

For different types of tasks, MindNLP offers a variety of AutoModel classes to choose from.

from mindnlp.transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=2)
print(type(model))

Training hyperparameters

Next, create a TrainingArguments instance where you can define the hyperparameters used in training.

from mindnlp.engine import TrainingArguments
training_args = TrainingArguments(
    "../../output",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    learning_rate=2e-5,
    num_train_epochs=3,
    logging_steps=200,
    evaluation_strategy='epoch',
    save_strategy='epoch'
)

For a comprehensive understanding of more parameters in TrainingArguments, please refer to the Use Trainer tutorial. Here, we specified the following parameters:

* output_dir: This is the directory where all outputs like model checkpoints and predictions will be saved. In this example, it is set to "../../output".
* per_device_train_batch_size: This controls the batch size used for training on each device. Since we already batched our dataset, here the batch size is set to 1. If you want to use Trainer's batch functionality, you can set your own batch size here.
* per_device_eval_batch_size: Similar to the training batch size, but used during the evaluation phase on the validation data. Since we already batched our dataset, here the batch size is set to 1.
* learning_rate: The rate at which the model learns. Smaller values mean slower learning, but they may lead to better model fine-tuning.
* num_train_epochs: Defines how many times the training loop will run over the entire training dataset.
* evaluation_strategy: Determines the strategy for performing evaluation. Setting it to 'epoch' means that the model is evaluated at the end of each training epoch.
* logging_steps: This setting controls how often to log training loss and other metrics into the console. It helps in monitoring the training progress.
* save_strategy: Determines the strategy for saving model checkpoints. Setting it to 'epoch' ensures that the model is saved at the end of every epoch.
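
As a rough sanity check, the settings in this tutorial imply the step counts computed in the sketch here. It assumes a single device, no gradient accumulation, and that each pre-built batch counts as one training step because per_device_train_batch_size is 1. The exact logging behavior of Trainer may differ.

```python
# Arithmetic sketch: how many training steps the tutorial's settings imply.
# Assumes one device, no gradient accumulation, one pre-built batch per step.
batch_size = 4                    # from the dataset preparation step
take_len = batch_size * 200       # reviews kept in the small training subset
num_train_epochs = 3
logging_steps = 200

steps_per_epoch = take_len // batch_size
total_steps = steps_per_epoch * num_train_epochs
# Steps at which a log line would be expected under these assumptions
logged_at = list(range(logging_steps, total_steps + 1, logging_steps))

print(steps_per_epoch, total_steps, logged_at)
```

Under these assumptions, the loss is logged about once per epoch, which lines up with the per-epoch evaluation and checkpointing.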

Evaluate

Evaluation is essential for understanding the model's performance and generalizability on new, unseen data.

To enable evaluation of your model's performance during training, it's necessary to supply a function for metric computation to Trainer.

Here, we write a compute_metrics function, which will take an EvalPrediction object as input, and compute the evaluation metrics between the predictions and ground-truth labels.

import evaluate
import numpy as np
from mindnlp.engine.utils import EvalPrediction

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred: EvalPrediction):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
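
For intuition, the accuracy metric can be sketched without the evaluate package: pick the index of the largest logit in each row, then count how often it matches the label. The helper names and the numbers here are made up for illustration, and this is not the evaluate package's implementation.

```python
# Dependency-free sketch of argmax-based accuracy (illustrative values only).

def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=lambda i: row[i])

def accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.4]]  # made-up scores
labels = [0, 1, 0, 0]
print(accuracy(logits, labels))  # {'accuracy': 0.75}
```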

Initialize the trainer

Once the TrainingArguments instance is configured, you can pass it to the Trainer class along with your model and datasets. This setup allows the Trainer to utilize these arguments throughout the training and evaluation phases.

from mindnlp.engine import Trainer
trainer = Trainer(
    model=model,
    train_dataset=small_dataset_train,
    eval_dataset=small_dataset_val,
    compute_metrics=compute_metrics,
    args=training_args,
)

Start training

Now that we are all set, let's start training!

trainer.train()

Use the trained model

You can now use the trained model to predict on a simple example. We define a text, tokenize it and use it as model input.

import numpy as np
from mindspore import Tensor, ops

text = "What an amusing movie!"

# Tokenize the text
inputs = tokenizer(text, padding=True, truncation=True, max_length=256)
ts_inputs = {key: Tensor(val).expand_dims(0) for key, val in inputs.items()}

# Predict
model.set_train(False)
outputs = model(**ts_inputs)
print(outputs)

The outputs are logits, which can be converted to the probability that the given text belongs to each category.

# Convert predictions to probabilities
predictions = ops.softmax(outputs.logits)
probabilities = predictions.asnumpy().flatten()

# Here first class is 'negative' and the second is 'positive'
print(f"Negative sentiment: {probabilities[0]:.4f}")
print(f"Positive sentiment: {probabilities[1]:.4f}")
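
The softmax conversion itself is simple enough to sketch in plain Python. The logits here are made up, and the function is an illustration of the formula rather than MindSpore's ops.softmax.

```python
import math

def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [-1.2, 2.3]  # made-up scores for [negative, positive]
probs = softmax(logits)
print(f"Negative sentiment: {probs[0]:.4f}")
print(f"Positive sentiment: {probs[1]:.4f}")
```

The larger logit always receives the larger probability, so the predicted class is the same before and after softmax.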

Train in native MindSpore

If you prefer to have more customized control over the training process, you can also fine-tune a model in native MindSpore.

If you went through the Train with MindNLP Trainer part, you may need to restart your notebook and re-run the Prepare a dataset part, or execute the following code to free some memory:

# Free up memory by deleting model and trainer used in the Train with MindNLP Trainer step
del model
del trainer

Load the model

Load your model with the number of expected labels:

from mindnlp.transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

Optimizer and loss function

Set up the optimizer, which updates the model parameters to minimize the loss function based on the computed gradients. Let's use the AdamW optimizer (Adam with weight decay) from MindSpore:

from mindspore.experimental.optim import AdamW

optimizer = AdamW(model.trainable_params(), lr=5e-5)

Define the loss function, which quantifies the difference between the model's predictions and the actual target values. Here we use the cross-entropy loss function:

from mindspore import ops
loss_fn = ops.cross_entropy
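
To see what this loss measures, here is a plain-Python sketch of cross-entropy computed from raw logits and averaged over the batch. It illustrates the formula only; the function and the example values are assumptions, not the MindSpore implementation.

```python
import math

def cross_entropy(logits, labels):
    """Mean negative log-probability of the true class, from raw logits."""
    losses = []
    for row, label in zip(logits, labels):
        shift = max(row)  # stabilizes the log-sum-exp
        log_sum_exp = shift + math.log(sum(math.exp(x - shift) for x in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)

# Made-up logits: the same prediction scored against a right and a wrong label
confident_right = cross_entropy([[0.0, 4.0]], [1])
confident_wrong = cross_entropy([[0.0, 4.0]], [0])
print(confident_right, confident_wrong)
```

A confident correct prediction yields a loss near zero, while a confident wrong one is penalized heavily.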

Forward and Gradient Functions

Define a forward function forward_fn to manage the forward pass of the model and compute the loss.

Then make use of MindSpore's value_and_grad, and define a gradient function grad_fn to automatically compute both the loss and the gradients of this loss with respect to the model's parameters.

from mindspore import value_and_grad
from tqdm import tqdm

def forward_fn(data, labels):
    logits = model(**data).logits
    loss = loss_fn(logits, labels)
    return loss

grad_fn = value_and_grad(forward_fn, None, optimizer.parameters)
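
The idea of a function that returns both a value and its gradients can be sketched with a toy stand-in. MindSpore computes gradients by automatic differentiation; the hypothetical helper value_and_grad_fd here uses finite differences instead, purely to show the shape of the result.

```python
# Toy stand-in for the "value and gradient" pattern (not MindSpore's method).

def value_and_grad_fd(fn, eps=1e-6):
    """Wrap fn so it returns (fn(params), numerical gradient w.r.t. params)."""
    def wrapped(params):
        value = fn(params)
        grads = []
        for i in range(len(params)):
            up = params[:i] + [params[i] + eps] + params[i + 1:]
            down = params[:i] + [params[i] - eps] + params[i + 1:]
            # Central difference approximates the partial derivative
            grads.append((fn(up) - fn(down)) / (2 * eps))
        return value, grads
    return wrapped

def toy_loss(params):
    """Squared error of the line w * x + b at x = 3 against target 7."""
    w, b = params
    return (w * 3.0 + b - 7.0) ** 2

toy_grad_fn = value_and_grad_fd(toy_loss)
value, grads = toy_grad_fn([1.0, 0.0])
print(value, grads)
```

Getting the loss and the gradients from one call is what lets a training step log the loss and update the parameters without running the forward pass twice.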

Training step

Implement a train_step function that will be executed in each step of the training.

This function processes a single batch of data, computes the loss and gradients, and updates the model parameters.

def train_step(batch):
    labels = batch.pop('labels')
    loss, grads = grad_fn(batch, labels)
    optimizer(grads)
    return loss

Training loop for one epoch

Implement a train_one_epoch function that trains the model for one epoch by iterating over all batches in the dataset:

from tqdm import tqdm

def train_one_epoch(model, train_dataset, epoch=0):
    model.set_train(True)
    total = train_dataset.get_dataset_size()
    loss_total = 0
    step_total = 0
    with tqdm(total=total) as progress_bar:
        progress_bar.set_description('Epoch %i' % epoch)
        for batch in train_dataset.create_dict_iterator():
            loss = train_step(batch)
            loss_total += loss.asnumpy()
            step_total += 1
            progress_bar.set_postfix(loss=loss_total/step_total)
            progress_bar.update(1)

Before the training loop starts, train_one_epoch puts the model in training mode by calling model.set_train(True).

In each iteration, the function calls train_step on the current batch of data.

To keep track of the training progress, it also accumulates and displays the average loss across batches in a progress bar, providing a real-time view of the training progress during the epoch.
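
The value shown in the progress bar is a running average: the sum of all batch losses so far divided by the number of steps so far. A small sketch with made-up losses (the helper name is hypothetical):

```python
def running_average(losses):
    """Average of the losses seen so far, reported after every step."""
    total = 0.0
    averages = []
    for step, loss in enumerate(losses, start=1):
        total += loss
        averages.append(total / step)
    return averages

batch_losses = [0.9, 0.7, 0.5, 0.3]  # made-up per-batch losses
for step, avg in enumerate(running_average(batch_losses), start=1):
    print(f"step {step}: average loss {avg:.2f}")
```

Because early batches stay in the average, the displayed number lags behind the most recent batch loss and changes more smoothly.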

Evaluation

Create a function to compute the accuracy of the model's predictions. As in training with the Trainer API, we make use of the evaluate package from Hugging Face.

import evaluate
import numpy as np

metric = evaluate.load("accuracy")

def compute_accuracy(logits, labels):
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

Implement a function evaluate_fn to evaluate the model on a validation dataset.

def evaluate_fn(model, test_dataset, criterion, epoch=0):
    total = test_dataset.get_dataset_size()
    epoch_loss = 0
    epoch_acc = 0
    step_total = 0
    model.set_train(False)

    with tqdm(total=total) as progress_bar:
        progress_bar.set_description('Epoch %i' % epoch)
        for batch in test_dataset.create_dict_iterator():
            label = batch.pop('labels')
            logits = model(**batch).logits
            loss = criterion(logits, label)
            epoch_loss += loss.asnumpy()

            # Convert MindSpore tensors to NumPy arrays before computing the metric
            acc = compute_accuracy(logits.asnumpy(), label.asnumpy())['accuracy']
            epoch_acc += acc

            step_total += 1
            progress_bar.set_postfix(loss=epoch_loss/step_total, acc=epoch_acc/step_total)
            progress_bar.update(1)

    return epoch_loss / total

At the start of evaluation, evaluate_fn disables training mode by calling model.set_train(False).

The function then iterates over all test batches. For each batch, it computes the logits, calculates the loss, and assesses the accuracy. These metrics are accumulated to provide average loss and accuracy for the epoch, which are displayed on a progress bar.
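
One detail worth keeping in mind: averaging per-batch accuracies equals the overall accuracy only when every batch holds the same number of examples. That holds in this tutorial because take_len is a multiple of batch_size. The sketch here uses made-up counts to show how the two can differ when the last batch is smaller.

```python
# Made-up (batch size, number correct) pairs; the last batch is smaller.
batches = [(4, 3), (4, 2), (2, 2)]

# What a per-batch average reports
mean_of_batch_acc = sum(correct / size for size, correct in batches) / len(batches)

# Accuracy over all examples pooled together
overall_acc = sum(correct for _, correct in batches) / sum(size for size, _ in batches)

print(f"mean of batch accuracies: {mean_of_batch_acc:.2f}")
print(f"overall accuracy:         {overall_acc:.2f}")
```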

Training loop for all epochs

Finally, we can execute the training, which loops through each epoch and evaluates the model's performance at the end of each one.

When the validation loss is lower than in all previous epochs, the model parameters will be saved as a checkpoint file for future use.

import mindspore as ms
num_epochs = 3
best_valid_loss = float('inf')

for epoch in range(num_epochs):
    train_one_epoch(model, small_dataset_train, epoch)
    valid_loss = evaluate_fn(model, small_dataset_val, loss_fn, epoch)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        ms.save_checkpoint(model, '../../sentiment_analysis.ckpt')
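
The checkpointing rule in this loop can be traced with a few made-up validation losses. The sketch records at which epochs a checkpoint would be written; the helper name and the loss values are hypothetical.

```python
def epochs_with_checkpoint(valid_losses):
    """Return the epochs at which the best-so-far rule would save a checkpoint."""
    best = float("inf")
    saved = []
    for epoch, loss in enumerate(valid_losses):
        if loss < best:  # strictly better than every earlier epoch
            best = loss
            saved.append(epoch)
    return saved

print(epochs_with_checkpoint([0.62, 0.48, 0.51]))  # [0, 1]
```

Since every save goes to the same file path, the checkpoint on disk always holds the parameters from the best epoch so far, not necessarily the last one.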

Use the trained model

If you are curious about how your trained model actually classifies text into a sentiment category, try the following code:

# Predict on example
import numpy as np
from mindspore import Tensor, ops

text = "I am pretty convinced that the movie depicted the future of AI in an elegant way."

# Encode the text to input IDs and attention masks
inputs = tokenizer(text, padding=True, truncation=True, max_length=256)
ts_inputs = {key: Tensor(val).expand_dims(0) for key, val in inputs.items()}

# Predict
model.set_train(False)
outputs = model(**ts_inputs)
print(outputs)

# Convert predictions to probabilities
predictions = ops.softmax(outputs.logits)
probabilities = predictions.asnumpy().flatten()

# Here first class is 'negative' and the second is 'positive'
print(f"Negative sentiment: {probabilities[0]:.4f}")
print(f"Positive sentiment: {probabilities[1]:.4f}")