Update modularStarEncoder.py
Browse files- modularStarEncoder.py +3 -1
    	
        modularStarEncoder.py
    CHANGED
    
    | @@ -205,11 +205,13 @@ def get_pooling_mask( | |
| 205 | 
             
                repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
         | 
| 206 |  | 
| 207 | 
             
                DEVICE = input_ids.get_device()
         | 
|  | |
|  | |
| 208 | 
             
                if DEVICE<0:
         | 
| 209 | 
             
                    DEVICE = "cpu"
         | 
| 210 | 
             
                ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
         | 
| 211 |  | 
| 212 | 
            -
             | 
| 213 | 
             
                pooling_mask = (repeated_idx <= ranges).long()
         | 
| 214 | 
             
                pooling_mask.to(DEVICE)
         | 
| 215 |  | 
|  | |
| 205 | 
             
                repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
         | 
| 206 |  | 
| 207 | 
             
                DEVICE = input_ids.get_device()
         | 
| 208 | 
            +
                print(DEVICE)
         | 
| 209 | 
            +
             | 
| 210 | 
             
                if DEVICE<0:
         | 
| 211 | 
             
                    DEVICE = "cpu"
         | 
| 212 | 
             
                ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
         | 
| 213 |  | 
| 214 | 
            +
                print(repeated_idx.get_device(),ranges.get_device())
         | 
| 215 | 
             
                pooling_mask = (repeated_idx <= ranges).long()
         | 
| 216 | 
             
                pooling_mask.to(DEVICE)
         | 
| 217 |  | 
