@@ -95,6 +95,7 @@ class TrainerOptions:
     use_pai: bool
     oss_bucket: Union[oss2.Bucket, None]
     batch_interval: int
+    bias_grad_times: float
 
 
 class Trainer:
     """Trainer having an optimizer.
@@ -546,8 +547,11 @@ class Trainer:
         no_forward_run = options.no_forward_run
         ngpu = options.ngpu
         use_wandb = options.use_wandb
+        bias_grad_times = options.bias_grad_times
         distributed = distributed_option.distributed
 
+        if bias_grad_times != 1.0:
+            logging.warning("Using bias_grad_times: {} for gradient scaling".format(bias_grad_times))
         if log_interval is None:
             try:
                 log_interval = max(len(iterator) // 20, 10)
@@ -690,6 +694,16 @@ class Trainer:
                         scale_factor=0.55,
                     )
 
+                # for contextual training: boost gradients of the biasing modules
+                if bias_grad_times != 1.0:
+                    # contextual-biasing parameter name patterns
+                    cr_pnames = ["bias_encoder", "bias_embed", "decoder.bias_decoder", "decoder.bias_output"]
+                    for name, param in model.named_parameters():
+                        for cr_pname in cr_pnames:
+                            if cr_pname in name and param.grad is not None:
+                                param.grad *= bias_grad_times
+                                break
+
                 # compute the gradient norm to check if it is normal or not
                 grad_norm = torch.nn.utils.clip_grad_norm_(
                     model.parameters(),
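
Note: the scaling runs immediately before `torch.nn.utils.clip_grad_norm_`, so the boosted bias gradients also count toward the global norm that gets clipped. Below is a minimal, self-contained sketch of the same pattern; the toy model, the `bias_proj` name, and `bias_grad_times = 2.0` are illustrative placeholders, not part of this patch:

```python
import torch

# Toy model: "bias_proj" stands in for the contextual-biasing modules
# that the patch matches by substring on parameter names.
model = torch.nn.ModuleDict(
    {"encoder": torch.nn.Linear(4, 4), "bias_proj": torch.nn.Linear(4, 4)}
)

bias_grad_times = 2.0      # illustrative; the patch reads this from TrainerOptions
cr_pnames = ["bias_proj"]  # substrings identifying contextual parameters

loss = model["bias_proj"](model["encoder"](torch.randn(2, 4))).sum()
loss.backward()

# Scale matching gradients in place, then clip globally, in that order:
# the scaled gradients contribute to the norm that clip_grad_norm_ sees.
if bias_grad_times != 1.0:
    for name, param in model.named_parameters():
        for cr_pname in cr_pnames:
            if cr_pname in name and param.grad is not None:
                param.grad *= bias_grad_times
                break  # stop after the first match to avoid double-scaling

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
print(f"grad norm after scaling and clipping: {grad_norm.item():.4f}")
```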