본문 바로가기
A. Research/Machine Learning

PPL 원리와 구현

by IMCOMKING 2022. 5. 16.

PPL 원리와 구현

PPL이란

PPL의 정의는 기본적으로 target token seq에 대한 모델의 negative log-likelihood (NLL)의 평균을 exponential한 것이다.

Perplexity of fixed-length models

계산 방법

PPL은 target data와 model prediction 사이의 CrossEntropy Loss를 exponentiation하여 매우 쉽게 계산할 할 수 있다. 이는 자명한 것이, model prediction과 target data (일종의 label로 생각)와의 CE를 계산하게 되면, target token의 index에 대한 model predction의 확률만으로 NLL이 계산되기 때문이다.

$$ CrossEntropy(ModelPrediction,\ TargetSequence) = \\

-\sum{p_{label} * \log{p_{model}}} = \\

-\log{p_{model}}[targetTokenIndex] $$

  • This is also equivalent to the exponentiation of the cross-entropy between the data and model predictions. For more intuition about perplexity and its relationship to Bits Per Character (BPC) and data compression, check out this fantastic blog post on The Gradient.

Evaluation Metrics for Language Modeling

Train Loss to Train PPL

위의 정의에 따라서, Train PPL의 경우 Train Loss에 exponential을 취하면 바로 계산이 가능하다.

PPL의 의의

결과적으로 perplexity의 장점은 log가 들어가 값을 이해하기 어려운 cross_entropy에 비해서, log를 exponential로 캔슬링 하기 때문에 사람이 숫자를 직관적으로 이해하기 쉽다는 점이다.

 

Python code

  • input_text: dialouge context
  • labels: target sentence to be measured
  • model: GPT3 language model
  • tokenizer: GPT3 tokenizer
  • repetition_penalty: penalty score in labels with repetition on input_texts
def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modifed based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indecies where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indecies from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indecies:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask attention loss.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids

def repetition_penalty_logit_processor(logits, input_ids, labels, repetition_penalty=1.0):
    masked_input_ids = input_ids * (labels == -100).int()
    max_len = logits.size(1)
    logits_original = copy.deepcopy(logits)
    for i in range(max_len):
        scores = logits[:, i]
        score = torch.gather(scores, 1, masked_input_ids)
        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
        scores.scatter_(1, masked_input_ids, score)
    logits[:, :, 0] = logits_original[:, :, 0]
    return logits

def get_ppl(input_texts, labels, model, tokenizer, repetition_penalty=1.0):
    input_ids_list = []
    label_list = []
    label_lens = []
    for input_text, label in zip(input_texts, labels):
        input_ids = tokenizer.encode(input_text).ids
        label = tokenizer.encode(label).ids
        label_lens.append(len(label))

        input_ids = input_ids + label  # concatenate input_ids and label
        padded_label = [-100] * len(
            input_ids
        )  # so that input_ids not considered when calculating loss
        padded_label[-len(label) :] = label
        input_ids = input_ids[:-1]  # shift
        padded_label = padded_label[1:]
        input_ids_list.append(torch.tensor(input_ids, device="cuda"))
        label_list.append(torch.tensor(padded_label, device="cuda"))
    input_ids = pad_sequence(
        input_ids_list, batch_first=True, padding_value=EOD_ID,
    ).long()
    labels = pad_sequence(label_list, batch_first=True, padding_value=-100,).long()

    attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
        input_ids,
        EOD_ID,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False,
    )
    attention_mask = attention_mask.cuda()
    position_ids = position_ids.cuda()

    ppls = []
    with torch.no_grad():
        logits = model(input_ids, position_ids, attention_mask)
        logits = repetition_penalty_logit_processor(
            logits, input_ids, labels, repetition_penalty
        )
        bs, maxlen, _ = logits.size()
        for i in range(bs):
            ppl = (
                torch.exp(nn.functional.cross_entropy(logits[i], labels[i]))
                .cpu()
                .tolist()
            )
            ppls.append(ppl)
    return ppls

Details

위 코드의 동작 방식은 기본적으로 Fully-autogressive 방식의 구현이다. 그런데 model을 여러번 호출하는 것이 너무 느리기 때문에, 여기서는 time_dim으로 autogressive한 forward pass를 upper triangular attention_mask trick을 사용해서 batch_dim으로 재구성하여, model을 한 번만 호출하도록 구현되어있다.

triangular mask attention

그리고 input_texts에 대한 ppl계산을 방지하기 위해서, [-100]이라는 ignore_index를 input_texts의 길이만큼 padded_label의 앞부분에 붙여줘서 CrossEntorpy 계산 시 반영이 안되도록 해서. input_texts에 대한 ppl은 무시하고, labels에 대해서만 ppl을 측정하는것이다.

Heuristic 구현

위와 같이 autogressive한 forward pass를 upper triangular attention_mask trick을 사용할 수 없는 상황에서는 autogressive 계산을 포기하고, input_text로 부터 한 번에 label가 생성될 확률을 측정하는 방법이 있다.

예를 들어, hugging face에 구현된 GPT-J 모델의 경우에는, attention_mask를 전달하는게 아니고, seq_length만 전달하면 attention_mask 직접 내부에서 만들어서 사용하게 된다. 이러한 경우 위와 같은 upper triangular mask attention을 사용할 수가 없다. 그래서 이러한 경우에는 for loop을 돌면서 직접 RNNcjfja auto-regressive하게 masking을 시켜야하는데, 이러면 병렬성을 살릴 수 없어서 속도가 매우 느려진다.

그래서 이런 경우엔 차라리 auto-regressive한 속성을 포기하고, input_text로 부터 label이 생성될 확률에 대한 PPL을 측정하는 heuristic 구현을 사용하는 것이 옳다. 즉 1-token씩 늘어나면서 생성확률을 측정하는게 아니고, label 문장을 LM이 한 번에 쫙 생성시킬 확률을 측정하는 것이다. 아마도 이 방법은 기존의 auto-regressive방법과 teacher-forcing을 하고 안하고 정도의 차이만 있을 것으로 보이고, 그 오차는 허용가능한 수준일 것으로 보인다.

  • 이 방식은 GPT모델 학습 중에 test ppl을 측정할 때도 아마 사용된다고 알고 있다.(test loss를 계산하면서 동시에 ppl을 측정할 수 있으니까?) 이 부분 확실하지 않네. 어쩌면 test loss 계산에도 teacher forcing을 하는걸까?

'A. Research > Machine Learning' 카테고리의 다른 글

PPL 원리와 구현  (0) 2022.05.16
Advanced Gradient Descent Method  (0) 2015.10.24

댓글0