Metadata-Version: 2.1
Name: tkitAutoMask
Version: 0.0.0.316483919
Summary: Terry toolkit tkitAutoMask
Home-page: https://docs.terrychan.org/tkit-automask/
Author: Terry Chan
Author-email: napoler2008@gmail.com
License: UNKNOWN
Description: # tkitAutoMask
        
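        Automatically builds masks.
        It bundles several dynamic masking strategies: upper- and lower-triangular masks, dynamic spans, and the default probability-based masking.
        
        - Upper-triangular: gives left-to-right style prediction, i.e. unidirectional attention, useful for text continuation.
        - Span: several consecutive masked tokens, better suited to infilling.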
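        
        As a rough illustration of the two patterns listed above, the sketch below builds a triangular mask with `torch.triu` and masks one contiguous span per row. It is not taken from tkitAutoMask's source: the helper `span_mask`, its `span_len` argument, and the mask id `103` are hypothetical names and values chosen for the example.
        
        ```python
        import torch
        
        seq_len = 5
        
        # Upper-triangular pattern: True above the diagonal marks "future" positions,
        # which a left-to-right (unidirectional) model must not attend to.
        future = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        # The complement is the lower-triangular pattern of visible positions.
        visible = ~future
        
        
        def span_mask(seq, mask_token_id, span_len=3, ignore_index=-100):
            """Mask one run of span_len consecutive tokens per row (hypothetical helper)."""
            batch, length = seq.shape
            # One random start per row, chosen so the whole span fits in the row.
            starts = torch.randint(0, length - span_len + 1, (batch, 1))
            positions = torch.arange(length).unsqueeze(0)
            in_span = (positions >= starts) & (positions < starts + span_len)
            inputs = seq.masked_fill(in_span, mask_token_id)  # span replaced by the mask id
            labels = seq.masked_fill(~in_span, ignore_index)  # original ids kept only inside the span
            return inputs, labels
        
        
        x = torch.randint(0, 20000, (4, 10))
        inputs, labels = span_mask(x, mask_token_id=103)  # 103 is a placeholder id
        print(visible.int())
        print(labels)
        ```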
        
        
        Planned for the future: template-based prediction masks.
        
        
        ```
        pip install tkitAutoMask
        ```
        
        
        
        ```python
        import torch
        from transformers import BertTokenizer
        from tkitAutoMask import autoMask
        
        tokenizer = BertTokenizer.from_pretrained("uer/chinese_roberta_L-2_H-128")
        
        tomask = autoMask(
            # transformer,
            mask_token_id = tokenizer.mask_token_id,  # the token id reserved for masking
            pad_token_id = -100,    # the token id for padding
            mask_prob = 0.05,       # masking probability for regular masked language modeling only
            replace_prob = 0.90,    # ~10% probability that token will not be masked, but included in loss, as detailed in the paper
            # other tokens to exclude from masking, include the [cls] and [sep] here
            # (BertTokenizer exposes [SEP] as sep_token_id; its eos_token_id is None by default)
            mask_ignore_token_ids = [tokenizer.cls_token_id, tokenizer.sep_token_id]
        )
        
        # x = torch.ones(5, 5)
        x = torch.randint(0, 20000, (10, 10))  # random token ids: 10 sequences of length 10
        for i in range(10):
            a, b = tomask(x)
            print(b)  # b holds the labels described below
        ```
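        
        The `mask_prob`, `replace_prob`, and `mask_ignore_token_ids` arguments above match the usual probability-based MLM recipe. The sketch below shows one common way that recipe is written, assuming each position is selected independently; it is not tkitAutoMask's own code, and the function `prob_mask` and its parameter names are hypothetical.
        
        ```python
        import torch
        
        
        def prob_mask(seq, mask_token_id, mask_prob=0.15, replace_prob=0.9,
                      ignore_token_ids=(), ignore_index=-100):
            """Probability-based masking sketch (hypothetical helper)."""
            # Positions that may never be selected, e.g. [CLS] and [SEP].
            protected = torch.zeros_like(seq, dtype=torch.bool)
            for token_id in ignore_token_ids:
                protected |= seq == token_id
            # Select each remaining position independently with probability mask_prob.
            selected = (torch.rand(seq.shape) < mask_prob) & ~protected
            # Labels keep the original id at selected positions and ignore_index elsewhere.
            labels = seq.masked_fill(~selected, ignore_index)
            # Only a replace_prob share of the selected positions is swapped for the mask id;
            # the rest stay unchanged in the input but still count toward the loss.
            replaced = selected & (torch.rand(seq.shape) < replace_prob)
            inputs = seq.masked_fill(replaced, mask_token_id)
            return inputs, labels
        
        
        x = torch.randint(0, 20000, (10, 10))
        inputs, labels = prob_mask(x, mask_token_id=103, mask_prob=0.05)  # 103 is a placeholder id
        print(labels)
        ```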
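        labels: shape [batch_size, seq_length], the labels for the MLM task. Note that positions that were not masked are set to -100 and only masked positions keep their original token id, which is the reverse of the input, where the masked positions are the ones replaced.
        For example, take the sentence "I want to eat an apple" with the word "eat" masked, so the model input is "I want to [MASK] an apple"; the corresponding labels are [-100, -100, -100, <id of "eat">, -100, -100].
        Why -100 rather than some other number? Because torch.nn.CrossEntropyLoss uses ignore_index=-100 by default, so positions labeled -100 do not contribute to the loss.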
        
        ```
        tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  6238,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  7321,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 11728,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  3641,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
                [ -100,  8332,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 11952,  -100],
                [ -100,  -100,  -100,  -100, 12768,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,    77],
                [ -100,  -100, 16031,  -100,  -100,  -100,  -100,  -100,  -100,  -100]])
        tensor([[ -100,  -100,  1312,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7849],
                [ 9007,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1822],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 17593],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 13736,  -100,  -100],
                [ -100,  -100,  -100, 16620,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 18083,  -100,  -100],
                [ -100,  -100,  -100, 15338,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100, 12984,  -100,  -100,  -100,  -100,  -100,  -100]])
        tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4867],
                [ -100, 15820,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ 9007,  1684,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ 4373, 13507,  -100,  -100,  -100,  -100,  -100, 19849,  -100,  -100],
                [19143, 19690, 16235,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
                [18837,  8332, 13231, 16312,  -100,  -100,  8517,  -100,  -100,  -100],
                [ 1567,   928,   268, 16620, 16337,  2932,  -100,  -100,  -100,  -100],
                [ 9537,  1362, 16203, 10865, 12768, 10351,  -100,  -100,  -100,  4658],
                [12488, 17234,  4130, 15338,  4766,  6458, 15765,  -100,  -100,  -100],
                [19972,   457, 16031, 12984, 14118,  4127, 13889, 13456,  -100,  -100]])
        tensor([[ 2649,  3837,  1312, 12421, 15558,  -100,  -100,  -100,  -100,  -100],
                [ -100, 15820,  2654,  3647, 13259,  6178,  -100,  -100,  -100,  7849],
                [ 9007,  -100, 17864,   360,  4748, 10698,  3624,  -100,  -100,  -100],
                [ -100, 13507,  -100,  5198,  4845, 18414,  3641, 19849,  -100,  -100],
                [ -100,  -100,  -100, 17247,  7694, 14913,  4696,  3476,  7539,  -100],
                [ -100,  -100,  -100,  -100,  -100,  5739,  8517, 13736,  8122, 16682],
                [ -100,  -100,  -100,  -100, 16337,  -100, 12610,  6181, 11952,  4669],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 18083, 14632,  4658],
                [ -100,  -100,  -100, 15338,  -100,  -100,  -100,  -100, 10558,    77],
                [ -100,  -100, 16031,  -100,  -100,  -100,  -100,  -100,  -100, 12816]])
        tensor([[ -100,  -100,  -100,  -100, 15558,  -100,  -100,  -100,  -100,  -100],
                [ -100, 15820,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100, 17864,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  4845,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  7694,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100, 16312,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100, 12610,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4658],
                [12488,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 13456,  -100,  -100]])
        tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4867],
                [ -100,  -100,  2654,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ 9007,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  3641,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
                [ -100,  -100, 13231,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,   268,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 14632,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100, 15765,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100, 14118,  -100,  -100,  -100,  -100,  -100]])
        tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  7519,  -100,  -100,  -100],
                [15670,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  1684,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  1822],
                [ -100, 19690,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100, 13231,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4669],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  4658],
                [12488,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100, 16031,  -100,  -100,  -100,  -100,  -100,  -100,  -100]])
        tensor([[ 2649,  3837,  1312,  -100,  -100,   976,  -100,  -100,  -100,  -100],
                [ -100, 15820,  2654,  3647,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100, 17864,   360,  4748,  -100,  3624,  -100,  -100,  -100],
                [ 4373,  -100,  -100,  5198,  4845, 18414,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  7694, 14913,  4696,  -100,  7539,  -100],
                [ -100,  -100,  -100,  -100,  -100,  5739,  8517, 13736,  -100,  -100],
                [ -100,   928,  -100,  -100,  -100,  -100, 12610,  6181, 11952,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 18083, 14632,  4658],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100, 19026, 10558,    77],
                [ -100,   457,  -100,  -100,  -100,  -100,  -100,  -100,  -100, 12816]])
        tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  7519,  -100,  -100,  -100],
                [ -100,  -100,  2654,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  4748,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7381,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7539,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  8122,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100, 12610,  -100,  -100,  -100],
                [ -100,  -100, 16203,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  6458,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  4127,  -100,  -100,  -100,  -100]])
        tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  6238,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  7849],
                [ -100,  -100,  -100,  -100,  4748,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100, 18414,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100, 14913,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100, 16312,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,   928,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  -100, 19242,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  6458,  -100,  -100,  -100,  -100],
                [ -100,  -100,  -100,  -100,  -100,  4127,  -100,  -100,  -100,  -100]])
        
        ```
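        
        The role of -100 in the labels can be checked directly. The following is a minimal sketch with random logits and a made-up vocabulary size and token id; it only demonstrates the default `ignore_index` of `torch.nn.CrossEntropyLoss` and does not use tkitAutoMask.
        
        ```python
        import torch
        
        vocab_size = 10                                  # made-up vocabulary size
        logits = torch.randn(1, 6, vocab_size)           # (batch, seq_len, vocab)
        labels = torch.tensor([[-100, -100, -100, 7, -100, -100]])  # 7 is a made-up token id
        
        loss_fn = torch.nn.CrossEntropyLoss()            # ignore_index defaults to -100
        
        # Flatten to (batch * seq_len, vocab) and (batch * seq_len,) as the loss expects.
        loss = loss_fn(logits.view(-1, vocab_size), labels.view(-1))
        
        # Scoring only the masked position gives the same value,
        # because every position labeled -100 is skipped.
        loss_masked_only = loss_fn(logits[0, 3:4], labels[0, 3:4])
        print(loss.item(), loss_masked_only.item())
        ```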
        
        
        Other tests
        
        https://colab.research.google.com/drive/1CvkoJ1pZQDRWGPA-5IzJufvocBM-RVT2#scrollTo=UwkociF5ZF-d
        
        https://colab.research.google.com/drive/1kNHD0I0wH3WBpJXPdgZqs0MZTRnGD-ok#scrollTo=6M1ZXRsuxZAa
        
        Attention implementation for unilm_mask
        https://colab.research.google.com/drive/11IDalP2xNYWzF4gIz6T3yTjp53UqzkOe#scrollTo=gFeycxpykrCx
        
        For details, see
        
        > dev.md
        
        
        
Platform: UNKNOWN
Description-Content-Type: text/markdown
