How to compute the Jacobian of BertForMaskedLM with functorch.jacrev?

Asked on 2022-06-21 01:54:05

I want to use functorch.jacrev to compute the Jacobian of BertForMaskedLM. Here is what I tried:

Code related to the question
```python
import numpy as np
from transformers import BertTokenizer, BertForMaskedLM
import torch
import torch.nn as nn
from functorch import make_functional, make_functional_with_buffers, vmap, vjp, jvp, jacrev

device = 'cuda:2'
torch.cuda.empty_cache()
model_name = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertForMaskedLM.from_pretrained(model_name)

net = bert_model.to(device)
fnet, params, buffers = make_functional_with_buffers(net)

def fnet_single(params, x, y):
    # Note: x is 1-D (seq_len,), so two unsqueezes give (1, 1, seq_len),
    # while BERT expects (batch, seq_len). Also, the second positional argument
    # of BertForMaskedLM.forward is attention_mask, so y (the segment IDs)
    # is passed here as an attention mask, not as token_type_ids.
    result = fnet(params, buffers, x.unsqueeze(0).unsqueeze(0), y.unsqueeze(0).unsqueeze(0))['logits']
    return result.squeeze(0).squeeze(0)

# "Ma Yulan, a reader in Suzhou, Jiangsu Province, has a friend studying in another city."
text = u'江苏省苏州市读者马玉兰有一个在外地上学的朋友'
inputs = tokenizer.encode_plus(text)

segment_ids = inputs['token_type_ids']
token_ids = inputs['input_ids']
length = len(token_ids) - 2  # number of tokens, excluding [CLS] and [SEP]
# This line raises the RuntimeError below: token IDs are int64,
# and only floating-point/complex tensors can require gradients.
batch_token_ids = torch.tensor([token_ids] * (2 * length - 1), requires_grad=True).to(device)
batch_segment_ids = torch.zeros_like(batch_token_ids).to(device)

# 103 is the [MASK] token ID in the bert-base-chinese vocabulary
for i in range(length):
    if i > 0:
        batch_token_ids[2 * i - 1, i] = 103
        batch_token_ids[2 * i - 1, i + 1] = 103
    batch_token_ids[2 * i, i + 1] = 103
threshold = 100
word_token_ids = [[token_ids[1]]]
for i in range(1, length):
    x, y = batch_token_ids[2 * i], batch_segment_ids[2 * i]
    # argnums=1 differentiates with respect to x, i.e. the integer token IDs
    jacobian1 = jacrev(fnet_single, argnums=1)(params, x, y)
    x, y = batch_token_ids[2 * i - 1], batch_segment_ids[2 * i - 1]
    jacobian2 = jacrev(fnet_single, argnums=1)(params, x, y)
    print(jacobian1, end='-----------------jacobian1-----------------\n')
    print(jacobian2, end='-----------------jacobian2-----------------\n')
```

Output and error message

```
Traceback (most recent call last):
  File "study_jacrev.py", line 49
    batch_token_ids = torch.tensor([token_ids] * (2 * length - 1),requires_grad=True).to(device)
RuntimeError: Only Tensors of floating point and complex dtype can require gradients
```
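The error comes from passing `requires_grad=True` to a tensor built from token IDs: `torch.tensor` infers `int64` for a list of Python ints, and PyTorch only lets floating-point and complex tensors require gradients. A minimal reproduction with plain PyTorch (the token ID values here are hypothetical, chosen only for illustration):

```python
import torch

# Hypothetical token IDs; any list of Python ints behaves the same way.
token_ids = [101, 1001, 1002, 102]

# torch.tensor infers int64 for Python ints, and integer tensors cannot require gradients.
try:
    torch.tensor([token_ids], requires_grad=True)
    int_error = None
except RuntimeError as exc:
    int_error = str(exc)

# The same values stored as float32 are allowed to require gradients.
float_ids = torch.tensor([token_ids], dtype=torch.float32, requires_grad=True)

print(torch.tensor([token_ids]).dtype)  # torch.int64
print(int_error)
print(float_ids.requires_grad)  # True
```

Casting the IDs to float would avoid this particular error, but it would not make the model differentiable with respect to them: BERT's input layer is an `nn.Embedding`, which only accepts integer indices, and token IDs are discrete in any case. The same problem applies to `jacrev(fnet_single, argnums=1)`, which asks for derivatives with respect to `x`, the token IDs.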

What I want to achieve

Compute the Jacobian matrix (i.e., the gradients) of BertForMaskedLM.
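Because token IDs cannot be differentiated, one common approach is to look them up in the embedding table first and differentiate the logits with respect to those float embeddings. For BertForMaskedLM this would mean calling the model with `inputs_embeds` (built from `model.get_input_embeddings()`) instead of `input_ids`. The sketch below only shows the mechanics, on a hypothetical toy model (an `nn.Embedding` plus an `nn.Linear` head standing in for BERT), using `torch.func.jacrev`, which is where `functorch.jacrev` lives in PyTorch 2.x:

```python
import torch
import torch.nn as nn
from torch.func import jacrev

torch.manual_seed(0)
vocab_size, hidden_size = 50, 8

# Hypothetical toy stand-in for BertForMaskedLM: embedding table + linear LM head.
embedding = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size)

token_ids = torch.tensor([1, 7, 3, 9, 2])  # integer IDs, shape (seq_len,)

def logits_from_embeds(embeds):
    # embeds: (seq_len, hidden_size) float tensor -> logits: (seq_len, vocab_size)
    return lm_head(embeds)

# Look the IDs up once; jacrev then differentiates w.r.t. these float embeddings.
embeds = embedding(token_ids).detach()
jac = jacrev(logits_from_embeds)(embeds)

# Jacobian shape = output shape + input shape
print(jac.shape)  # torch.Size([5, 50, 5, 8])
```

With the real model, the analogous function would call `model(inputs_embeds=..., token_type_ids=...).logits` after `model.eval()` (so dropout does not make the logits random). The full Jacobian then has shape `(seq_len, vocab_size, seq_len, hidden_size)`, which is very large for BERT, so differentiating a single logit or a reduced quantity is often more practical.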

Looking forward to your reply!
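The code in the question builds a functional model with `functorch.make_functional_with_buffers`. In PyTorch 2.x, functorch was merged into `torch.func`, and `make_functional` is deprecated in favor of `torch.func.functional_call`, which runs an ordinary `nn.Module` with a dictionary of parameters. A minimal sketch on a hypothetical small `nn.Sequential` model, taking the Jacobian with respect to the parameters (the `argnums=0` case) rather than the inputs:

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev

torch.manual_seed(0)

# Hypothetical small model standing in for BERT.
model = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 2))
model.eval()

# Parameters as a dict of detached tensors (buffers, if any, could be added too).
params = {name: p.detach() for name, p in model.named_parameters()}
x = torch.randn(4)

def f(params, x):
    # Run the module with the given parameters instead of its own.
    return functional_call(model, params, (x,))

# argnums=0 -> one Jacobian per parameter tensor, returned as a dict.
jac = jacrev(f, argnums=0)(params, x)
for name, j in jac.items():
    print(name, tuple(j.shape))
```

Each entry has shape `output.shape + param.shape`, e.g. `(2, 3, 4)` for `0.weight`. For a model the size of BERT, the full parameter Jacobian of all logits is impractically large, so a scalar objective with `torch.func.grad` is usually the more realistic target.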
