default_func

mindnlp.engine.trainer.default_func

Utility functions for the trainer.

mindnlp.engine.trainer.default_func.get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler)

Get the default forward function that pairs the network with an external loss function.

Source code in mindnlp/engine/trainer/default_func.py
def get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler):
    """Get the default forward function with an external loss function."""
    def forward_fn(labels, *args, **kwargs):
        logits_list = ()
        logits = network(*args, **kwargs)
        # Normalize the network output into a tuple of logits.
        if isinstance(logits, tuple):
            logits_list += logits
        elif isinstance(logits, ModelOutput):
            logits_list += (logits.logits,)
        else:
            logits_list += (logits,)

        # `labels` is expected to be a tuple; logits and labels are passed
        # to the loss function positionally.
        logits_list += labels
        loss = loss_fn(*logits_list)
        # Scale the loss for mixed-precision training.
        loss = loss_scaler.scale(loss)
        return loss

    return forward_fn
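
A minimal usage sketch (not from the source): the model, shapes, and StaticLossScaler choice are illustrative assumptions. Note that labels must be passed as a tuple, since forward_fn concatenates them onto the logits tuple.

import mindspore
from mindspore import nn, ops
from mindspore.amp import StaticLossScaler

from mindnlp.engine.trainer.default_func import get_default_forward_fn_with_loss_fn

network = nn.Dense(16, 4)                      # stand-in model (assumption)
loss_fn = nn.CrossEntropyLoss()
loss_scaler = StaticLossScaler(scale_value=2**10)

forward_fn = get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler)

inputs = ops.ones((8, 16), mindspore.float32)
labels = (ops.zeros(8, mindspore.int32),)      # a tuple of label tensors
loss = forward_fn(labels, inputs)              # returns the scaled loss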

mindnlp.engine.trainer.default_func.get_default_forward_fn_without_loss_fn(network, loss_scaler)

Get the default forward function for a network that computes its own loss.

Source code in mindnlp/engine/trainer/default_func.py
def get_default_forward_fn_without_loss_fn(network, loss_scaler):
    """Get the default forward function for a network that computes its own loss."""
    def forward_fn(*args, **kwargs):
        outputs_list = ()
        outputs = network(*args, **kwargs)
        # Normalize the network output; the loss is expected to come first.
        if isinstance(outputs, tuple):
            outputs_list += outputs
        elif isinstance(outputs, ModelOutput):
            outputs_list += (outputs.loss,)
        else:
            outputs_list += (outputs,)

        # The first element is the loss; scale it for mixed-precision training.
        loss = loss_scaler.scale(outputs_list[0])
        return loss

    return forward_fn
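
A minimal sketch for the no-loss-function case; NetWithLoss below is a hypothetical cell that computes and returns its own scalar loss.

import mindspore
from mindspore import nn, ops
from mindspore.amp import StaticLossScaler

from mindnlp.engine.trainer.default_func import get_default_forward_fn_without_loss_fn

class NetWithLoss(nn.Cell):
    """Hypothetical model that returns its own loss."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Dense(16, 4)
        self.loss = nn.CrossEntropyLoss()

    def construct(self, inputs, labels):
        logits = self.dense(inputs)
        return self.loss(logits, labels)       # scalar loss is the sole output

network = NetWithLoss()
loss_scaler = StaticLossScaler(scale_value=2**10)
forward_fn = get_default_forward_fn_without_loss_fn(network, loss_scaler)

inputs = ops.ones((8, 16), mindspore.float32)
labels = ops.zeros(8, mindspore.int32)
loss = forward_fn(inputs, labels)              # returns the scaled loss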

mindnlp.engine.trainer.default_func.get_default_train_step_fn(forward_fn, optimizer, loss_scaler, check_gradients, jit, for_object_net=False)

Get the default train step function.

Source code in mindnlp/engine/trainer/default_func.py
def get_default_train_step_fn(forward_fn, optimizer, loss_scaler, check_gradients, jit, for_object_net=False):
    """Get the default train step function."""
    # Differentiate forward_fn with respect to the optimizer's parameters only.
    grad_fn = value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

    def default_run_step(labels, *args, **kwargs):
        """Core process of each step: forward propagation and back propagation."""
        status = init_status()
        loss, grads = grad_fn(labels, *args, **kwargs)
        loss = loss_scaler.unscale(loss)
        if check_gradients:
            # Update parameters only when all gradients are finite (no overflow),
            # then let the scaler adjust its scale value accordingly.
            is_finite = all_finite(grads, status)
            if is_finite:
                grads = loss_scaler.unscale(grads)
                optimizer(grads)
            loss_scaler.adjust(is_finite)
        else:
            optimizer(grads)
        return loss

    def default_run_step_for_obj_net(*args, **kwargs):
        """Core process of each step for object networks, with explicit control dependencies."""
        status = init_status()
        # ops.depend enforces execution order in graph mode.
        args = ops.depend(args, status)
        loss, grads = grad_fn(*args, **kwargs)
        loss = loss_scaler.unscale(loss)
        if check_gradients:
            is_finite = all_finite(grads, status)
            if is_finite:
                grads = loss_scaler.unscale(grads)
                loss = ops.depend(loss, optimizer(grads))
            loss = ops.depend(loss, loss_scaler.adjust(is_finite))
        else:
            loss = ops.depend(loss, optimizer(grads))
        return loss

    run_step = default_run_step_for_obj_net if for_object_net else default_run_step

    # Optionally compile the step function for graph-mode execution.
    if jit:
        run_step = ms_jit(run_step)

    return run_step
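
A minimal end-to-end sketch (model, hyperparameters, and DynamicLossScaler settings are illustrative assumptions), combining get_default_forward_fn_with_loss_fn with the train step:

import mindspore
from mindspore import nn, ops
from mindspore.amp import DynamicLossScaler

from mindnlp.engine.trainer.default_func import (
    get_default_forward_fn_with_loss_fn,
    get_default_train_step_fn,
)

network = nn.Dense(16, 4)                      # stand-in model (assumption)
loss_fn = nn.CrossEntropyLoss()
loss_scaler = DynamicLossScaler(scale_value=2**10, scale_factor=2, scale_window=50)

forward_fn = get_default_forward_fn_with_loss_fn(network, loss_fn, loss_scaler)
optimizer = nn.Adam(network.trainable_params(), learning_rate=1e-3)

# check_gradients=True skips the update on overflow; jit=True would compile
# the step function for graph-mode execution.
train_step = get_default_train_step_fn(forward_fn, optimizer, loss_scaler,
                                       check_gradients=True, jit=False)

inputs = ops.ones((8, 16), mindspore.float32)
labels = (ops.zeros(8, mindspore.int32),)      # a tuple of label tensors
loss = train_step(labels, inputs)              # one optimization step

With for_object_net=True, the variant that threads ops.depend through the step is returned instead, which matters in graph mode where execution order is not otherwise guaranteed.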