bert4torch.snippets module¶
- class bert4torch.snippets.AdversarialTraining(mode, adversarial={})[source]¶
对抗训练Callback
- Parameters
mode – str, 对抗训练的模式,可选{‘fgm’, ‘pgd’, ‘vat’, ‘gradient_penalty’}
adversarial – dict, 对抗训练的参数配置,不同模式所需参数不同
- class bert4torch.snippets.AutoRegressiveDecoder(start_id, end_id, maxlen, minlen=1, device='cpu')[source]¶
通用自回归生成模型解码基类 包含beam search和random sample两种策略
- Parameters
start_id – int, 解码使用的起始token_id,不同预训练模型设置可能不一样
end_id – int, 解码使用的结束token_id,不同预训练模型设置可能不一样
maxlen – int, 最大解码长度
minlen – int, 最小解码长度, 默认为1
device – str, 默认为’cpu’
- beam_search(inputs_raw, topk, states=None, temperature=1, min_ends=1, add_btz_dim=True)[source]¶
beam search解码
- Parameters
inputs_raw – tensor、array、list、tuple, 解码的输入,一般为last_hidden_state, shape=[btz, seq_len, hdsz]
topk – int, 这里的topk即beam size
states –
temperature – 温度参数,默认为1
min_ends –
add_btz_dim – bool, 是否保留btz维度, 默认为True
- Returns
最优解码序列。
- predict(inputs, output_ids, states=None)[source]¶
用户需自定义递归预测函数; 说明: 定义的时候,需要用wraps方法进行装饰,传入default_rtype和use_states,其中default_rtype为字符串logits或probas,probas时返回归一化的概率, rtype=logits时则返回softmax前的结果或者概率对数。
- Returns
二元组 (得分或概率, states)
- random_sample(inputs_raw, n, topk=None, topp=None, states=None, temperature=1, min_ends=1, add_btz_dim=True)[source]¶
随机采样n个结果; 说明: 非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。
- Parameters
inputs_raw – tensor、array、list、tuple, 解码的输入,一般为last_hidden_state, shape=[btz, seq_len, hdsz]
topk – int, 这里的topk即beam size
topp – float, 这里的topp是token的概率阈值设置
states –
temperature – 温度参数,默认为1
min_ends –
- Returns
n个解码序列组成的list。
- class bert4torch.snippets.VAT(model, emb_name='word_embeddings', noise_var=1e-05, noise_gamma=1e-06, adv_step_size=0.001, adv_alpha=1, norm_type='l2', **kwargs)[source]¶
虚拟对抗训练 https://github.com/namisan/mt-dnn/blob/v0.2/alum/adv_masked_lm.py
- class bert4torch.snippets.WebServing(host='0.0.0.0', port=8000, server='paste')[source]¶
简单的Web接口,基于bottlepy简单封装,仅作为临时测试使用,不保证性能。
- Example:
>>> arguments = {'text': (None, True), 'n': (int, False)} >>> web = WebServing(port=8864) >>> web.route('/gen_synonyms', gen_synonyms, arguments) >>> web.start() >>> # 然后访问 http://127.0.0.1:8864/gen_synonyms?text=你好
- 依赖(如果不用 server=’paste’ 的话,可以不装paste库):
>>> pip install bottle >>> pip install paste
- bert4torch.snippets.create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0)[source]¶
生成padding_ids, 从padding_idx+1开始。忽略填充符号
- bert4torch.snippets.get_kw(cls, kwargs)[source]¶
保留排除cls的入参后的kwargs
- Parameters
cls – 类
kwargs – dict, 所有参数
- bert4torch.snippets.get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None)[source]¶
获取句向量
- Parameters
hidden_state – torch.Tensor/List(torch.Tensor),last_hidden_state/all_encoded_layers
pooler – torch.Tensor, bert的pool_output输出
attention_mask – torch.Tensor
pool_strategy – str, (‘cls’, ‘last-avg’, ‘mean’, ‘last-max’, ‘max’, ‘first-last-avg’, ‘custom’)
custom_layer – int/List[int],指定对某几层做average pooling
- bert4torch.snippets.get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None)[source]¶
sinusoid编码
- Parameters
n_position – int, 位置长度
d_hid – int, 位置编码长度
padding_idx – padding的token_ids
- Returns
[seq_len, d_hid]
- bert4torch.snippets.merge_segmentate(sequences, maxlen, sep='')[source]¶
把m个句子合并成不超过maxlen的n个句子, 主要用途是合并碎句子
- Parameters
sequences – List(str), 短句子列表
maxlen – int, 最大长度
sep – str, 合并使用的分隔符, 可以是,。等标点符号
- bert4torch.snippets.parallel_apply(func, iterable, workers, max_queue_size, callback=None, dummy=False, random_seeds=True, unordered=True)[source]¶
多进程或多线程地将func应用到iterable的每个元素中(直接从bert4keras中移植过来)。 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是输出可能是func(c), func(a), func(b)。
- Parameters
callback – 处理单个输出的回调函数;
dummy – False是多进程/线性,True则是多线程/线性;windows需设置dummy=True
random_seeds – 每个进程的随机种子;
unordered – 若为False,则按照输入顺序返回,仅当callback为None时生效。
- bert4torch.snippets.parallel_apply_generator(func, iterable, workers, max_queue_size, dummy=False, random_seeds=True)[source]¶
多进程或多线程地将func应用到iterable的每个元素中(直接从bert4keras中移植过来)。 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是输出可能是func(c), func(a), func(b)。结果将作为一个 generator返回,其中每个item是输入的序号以及该输入对应的处理结果。
- Parameters
dummy – False是多进程/线性,True则是多线程/线性;
random_seeds – 每个进程的随机种子。
- bert4torch.snippets.sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post')[source]¶
将序列padding到同一长度
- bert4torch.snippets.text_augmentation(texts, noise_dict=None, noise_len=0, noise_p=0.0, skip_words=None, strategy='random', allow_dup=True)[source]¶
简单的EDA策略, 增删改
- Parameters
texts – 需要增强的文本/文本list
noise_dict – 噪音数据, 元素为str的list, tuple, set
noise_len – 噪音长度, 优先试用
noise_p – 噪音比例
skip_words – 跳过的短语, string/list
strategy – 修改的策略, 包含增insert, 删delete, 改replace, 随机random
allow_dup – 是否允许同一个位置多次EDA