bert4torch.snippets module

class bert4torch.snippets.AdversarialTraining(mode, adversarial={})[source]

对抗训练Callback

Parameters
  • mode – str, 对抗训练的模式,可选{‘fgm’, ‘pgd’, ‘vat’, ‘gradient_penalty’}

  • adversarial – dict, 对抗训练的参数配置,不同模式所需参数不同

class bert4torch.snippets.AutoRegressiveDecoder(start_id, end_id, maxlen, minlen=1, device='cpu')[source]

通用自回归生成模型解码基类 包含beam search和random sample两种策略

Parameters
  • start_id – int, 解码使用的起始token_id,不同预训练模型设置可能不一样

  • end_id – int, 解码使用的结束token_id,不同预训练模型设置可能不一样

  • maxlen – int, 最大解码长度

  • minlen – int, 最小解码长度, 默认为1

  • device – str, 默认为’cpu’

beam search解码

Parameters
  • inputs_raw – tensor、array、list、tuple, 解码的输入,一般为last_hidden_state, shape=[btz, seq_len, hdsz]

  • topk – int, 这里的topk即beam size

  • states

  • temperature – 温度参数,默认为1

  • min_ends

  • add_btz_dim – bool, 是否保留btz维度, 默认为True

Returns

最优解码序列。

predict(inputs, output_ids, states=None)[source]

用户需自定义递归预测函数; 说明: 定义的时候,需要用wraps方法进行装饰,传入default_rtype和use_states,其中default_rtype为字符串logits或probas,probas时返回归一化的概率, rtype=logits时则返回softmax前的结果或者概率对数。

Returns

二元组 (得分或概率, states)

random_sample(inputs_raw, n, topk=None, topp=None, states=None, temperature=1, min_ends=1, add_btz_dim=True)[source]

随机采样n个结果; 说明: 非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。

Parameters
  • inputs_raw – tensor、array、list、tuple, 解码的输入,一般为last_hidden_state, shape=[btz, seq_len, hdsz]

  • topk – int, 这里的topk即beam size

  • topp – float, 这里的topp是token的概率阈值设置

  • states

  • temperature – 温度参数,默认为1

  • min_ends

Returns

n个解码序列组成的list。

static wraps(default_rtype='probas', use_states=False)[source]

用来进一步完善predict函数

目前包含:
  1. 设置rtype参数,并做相应处理;

  2. 确定states的使用,并做相应处理;

  3. 设置温度参数,并做相应处理。

class bert4torch.snippets.DottableDict(*args, **kwargs)[source]

支持点操作符的字典

class bert4torch.snippets.FGM(model)[source]

FGM对抗训练

class bert4torch.snippets.PGD(model)[source]

PGD对抗训练

class bert4torch.snippets.VAT(model, emb_name='word_embeddings', noise_var=1e-05, noise_gamma=1e-06, adv_step_size=0.001, adv_alpha=1, norm_type='l2', **kwargs)[source]

虚拟对抗训练 https://github.com/namisan/mt-dnn/blob/v0.2/alum/adv_masked_lm.py

static adv_project(grad, norm_type='inf', eps=1e-06)[source]

L0,L1,L2正则,对于扰动计算

static kl(inputs, targets, reduction='sum')[source]

计算kl散度

:param inputs:tensor,logits :param targets:tensor,logits

class bert4torch.snippets.WebServing(host='0.0.0.0', port=8000, server='paste')[source]

简单的Web接口,基于bottlepy简单封装,仅作为临时测试使用,不保证性能。

Example:
>>> arguments = {'text': (None, True), 'n': (int, False)}
>>> web = WebServing(port=8864)
>>> web.route('/gen_synonyms', gen_synonyms, arguments)
>>> web.start()
>>> # 然后访问 http://127.0.0.1:8864/gen_synonyms?text=你好
依赖(如果不用 server=’paste’ 的话,可以不装paste库):
>>> pip install bottle
>>> pip install paste
route(path, func, arguments, method='GET')[source]

添加接口

start()[source]

启动服务

wraps(func, arguments, method='GET')[source]

封装为接口函数

Parameters
  • func – 要转换为接口的函数,需要保证输出可以json化,即需要保证 json.dumps(func(inputs)) 能被执行成功;

  • arguments – 声明func所需参数,其中key为参数名,value[0]为对应的转换函数(接口获取到的参数值都是字符串型),value[1]为该参数是否必须;

  • method – ‘GET’或者’POST’。

bert4torch.snippets.cal_ts_num(tensor_shape)[source]

查看某个tensor在gc中的数量

bert4torch.snippets.create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0)[source]

生成padding_ids, 从padding_idx+1开始。忽略填充符号

bert4torch.snippets.delete_arguments(*arguments)[source]

装饰器,为类方法删除参数(主要用于类的__init__方法)

bert4torch.snippets.get_kw(cls, kwargs)[source]

保留排除cls的入参后的kwargs

Parameters
  • cls – 类

  • kwargs – dict, 所有参数

bert4torch.snippets.get_pool_emb(hidden_state=None, pooler=None, attention_mask=None, pool_strategy='cls', custom_layer=None)[source]

获取句向量

Parameters
  • hidden_state – torch.Tensor/List(torch.Tensor),last_hidden_state/all_encoded_layers

  • pooler – torch.Tensor, bert的pool_output输出

  • attention_mask – torch.Tensor

  • pool_strategy – str, (‘cls’, ‘last-avg’, ‘mean’, ‘last-max’, ‘max’, ‘first-last-avg’, ‘custom’)

  • custom_layer – int/List[int],指定对某几层做average pooling

bert4torch.snippets.get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None)[source]

sinusoid编码

Parameters
  • n_position – int, 位置长度

  • d_hid – int, 位置编码长度

  • padding_idx – padding的token_ids

Returns

[seq_len, d_hid]

bert4torch.snippets.insert_arguments(**arguments)[source]

装饰器,为类方法增加参数(主要用于类的__init__方法)

bert4torch.snippets.is_string(s)[source]

判断是否是字符串

bert4torch.snippets.lowercase_and_normalize(text, never_split=())[source]

转小写,并进行简单的标准化

bert4torch.snippets.merge_segmentate(sequences, maxlen, sep='')[source]

把m个句子合并成不超过maxlen的n个句子, 主要用途是合并碎句子

Parameters
  • sequences – List(str), 短句子列表

  • maxlen – int, 最大长度

  • sep – str, 合并使用的分隔符, 可以是,。等标点符号

bert4torch.snippets.parallel_apply(func, iterable, workers, max_queue_size, callback=None, dummy=False, random_seeds=True, unordered=True)[source]

多进程或多线程地将func应用到iterable的每个元素中(直接从bert4keras中移植过来)。 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是输出可能是func(c), func(a), func(b)。

Parameters
  • callback – 处理单个输出的回调函数;

  • dummy – False是多进程/线性,True则是多线程/线性;windows需设置dummy=True

  • random_seeds – 每个进程的随机种子;

  • unordered – 若为False,则按照输入顺序返回,仅当callback为None时生效。

bert4torch.snippets.parallel_apply_generator(func, iterable, workers, max_queue_size, dummy=False, random_seeds=True)[source]

多进程或多线程地将func应用到iterable的每个元素中(直接从bert4keras中移植过来)。 注意这个apply是异步且无序的,也就是说依次输入a,b,c,但是输出可能是func(c), func(a), func(b)。结果将作为一个 generator返回,其中每个item是输入的序号以及该输入对应的处理结果。

Parameters
  • dummy – False是多进程/线性,True则是多线程/线性;

  • random_seeds – 每个进程的随机种子。

bert4torch.snippets.sequence_padding(inputs, length=None, value=0, seq_dims=1, mode='post')[source]

将序列padding到同一长度

bert4torch.snippets.text_augmentation(texts, noise_dict=None, noise_len=0, noise_p=0.0, skip_words=None, strategy='random', allow_dup=True)[source]

简单的EDA策略, 增删改

Parameters
  • texts – 需要增强的文本/文本list

  • noise_dict – 噪音数据, 元素为str的list, tuple, set

  • noise_len – 噪音长度, 优先试用

  • noise_p – 噪音比例

  • skip_words – 跳过的短语, string/list

  • strategy – 修改的策略, 包含增insert, 删delete, 改replace, 随机random

  • allow_dup – 是否允许同一个位置多次EDA

bert4torch.snippets.text_segmentate(text, maxlen, seps='\n', strips=None, truncate=True)[source]

将文本按照标点符号划分为若干个短句

Parameters
  • text – 待划分的句子

  • maxlen – int, 截断长度

  • seps – 分隔符

  • strips – ‘’.strip()

  • truncate – True表示标点符号切分后仍然超长时, 按照maxlen硬截断分成若干个短句

Returns

List[str], 划分后的句子列表

bert4torch.snippets.truncate_sequences(maxlen, indices, *sequences)[source]

截断总长度至不超过maxlen