
i've started vectorizing the operations in my for loop that compute page-sized fetches around sparse file holes. it can be done by rote, despite the control flow and so on. interesting that loops and vector operations can be transformed into each other. it looks workable: if i can vectorize this big long function correctly, it would condense the millions of scalar requests into far fewer, larger page requests, each surrounding a group of them. one thing still missing is consolidating adjacent pages, but that also seems reasonable to vectorize.

this is the start of my vectorized read_many, which i've only just begun:

```python
# round each (offset, length) request out to whole pages of self.blksize bytes
tails = (offset_lengths[:,0] + offset_lengths[:,1]).clamp(max=len(self.mmap))
aligned_offsets = offset_lengths[:,0] // self.blksize  # round starts down
aligned_offsets *= self.blksize
aligned_tails = tails - 1  # round tails up: ((tails - 1) // blksize + 1) * blksize
aligned_tails //= self.blksize
aligned_tails += 1
aligned_tails *= self.blksize
torch.clamp(aligned_tails, max=self.size(), out=aligned_tails)
```

it's called by this code:

```python
def fetch_scalars(self, offsets, progress='', validate_usage=True):
    n = offsets.numel()  # offsets can be 2-D, so len(offsets) would undercount
    if validate_usage:
        bytes_avail_cpu = psutil.virtual_memory().available * self.safeslice.statedict.usage_frac
        assert n * self.dtype.itemsize < bytes_avail_cpu
    readsize = self.safeslice.tensor.element_size()
    # one (byte offset, byte length) row per requested element
    offset_lengths = torch.empty([n, 2], dtype=torch.long)
    offset_lengths[:,0] = offsets.reshape(-1)
    offset_lengths[:,0] += self.storage_offset()
    offset_lengths[:,0] *= readsize
    offset_lengths[:,0] += self.safeslice.offset
    offset_lengths[:,1] = readsize
    datas = b''.join(self.safeslice.fetcher.read_many(offset_lengths, progress=progress, validate_sorted=False))
    return torch.frombuffer(
        datas,
        dtype=self.safeslice.tensor.dtype,
        count=n,
    ).view(offsets.shape).to(self.device, self.dtype)
```

which is called by something like this code in F_linear:

```python
major_stride, minor_stride = weight.stride()
# element offset of weight[row, top_k_indices[j]] for every row and every selected column
offsets = torch.add(*torch.meshgrid(
    torch.arange(weight.shape[0]) * major_stride,
    top_k_indices * minor_stride,
    indexing='ij',
))
...
# [below simplified, removing local unfixed bugs, to try to match the new code here]
product = torch.matmul(
    input,
    weight.fetch_scalars(offsets, progress=name, validate_usage=False).T
)
```
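a minimal sketch of the missing consolidation step plus the gather back out, assuming requests are (byte offset, byte length) rows like the ones fetch_scalars builds, that a plain bytes object stands in for self.mmap, and that pages are blksize bytes. coalesce_pages, read_many_coalesced and gather_scalars are hypothetical names, not the actual read_many. the idea: sort the page-aligned ranges by start, take a running max (torch.cummax) of their ends, and mark a new run wherever a start lands past that max; a cumsum over the marks then gives every request its run id without a per-scalar python loop.

```python
import array

import torch


def coalesce_pages(offset_lengths: torch.Tensor, blksize: int, file_size: int):
    """Round (offset, length) rows out to whole pages, then merge pages that
    overlap or touch into runs. Returns (run_starts, run_ends, run_of_request),
    where run_of_request[i] is the index of the run covering request i."""
    offsets = offset_lengths[:, 0]
    tails = (offsets + offset_lengths[:, 1]).clamp(max=file_size)
    starts = offsets // blksize * blksize  # round down to a page boundary
    ends = (((tails - 1) // blksize + 1) * blksize).clamp(max=file_size)  # round up

    # sort by aligned start so ranges that overlap or touch become neighbours
    order = torch.argsort(starts)
    s, e = starts[order], ends[order]

    # covered[i] = furthest end seen so far; a run starts where s jumps past it
    covered = torch.cummax(e, dim=0).values
    new_run = torch.ones_like(s, dtype=torch.bool)
    new_run[1:] = s[1:] > covered[:-1]  # '>' rather than '>=' also merges adjacent pages

    run_id_sorted = torch.cumsum(new_run.long(), dim=0) - 1
    is_last = torch.ones_like(new_run)
    is_last[:-1] = new_run[1:]  # the element before each new run ends the previous run
    run_starts = s[new_run]
    run_ends = covered[is_last]

    run_of_request = torch.empty_like(run_id_sorted)
    run_of_request[order] = run_id_sorted  # undo the sort
    return run_starts, run_ends, run_of_request


def read_many_coalesced(buf: bytes, offset_lengths: torch.Tensor, blksize: int):
    """Hypothetical read_many: one slice of buf per merged run, not per request.
    Returns a uint8 staging tensor and each request's byte offset inside it."""
    run_starts, run_ends, run_of_request = coalesce_pages(offset_lengths, blksize, len(buf))
    # the only python loop left is over merged runs, not scalar requests
    pieces = [buf[a:b] for a, b in zip(run_starts.tolist(), run_ends.tolist())]
    staging = torch.frombuffer(bytearray(b"".join(pieces)), dtype=torch.uint8)
    run_lengths = run_ends - run_starts
    run_base = torch.cumsum(run_lengths, dim=0) - run_lengths  # where each run lands in staging
    local = offset_lengths[:, 0] - run_starts[run_of_request] + run_base[run_of_request]
    return staging, local


def gather_scalars(buf: bytes, byte_offsets: torch.Tensor, itemsize: int, blksize: int, dtype):
    """Fixed-size reads, as in fetch_scalars: every request is itemsize bytes."""
    offset_lengths = torch.stack([byte_offsets, torch.full_like(byte_offsets, itemsize)], dim=1)
    staging, local = read_many_coalesced(buf, offset_lengths, blksize)
    byte_idx = local[:, None] + torch.arange(itemsize)  # [N, itemsize] indices into staging
    return staging[byte_idx].view(dtype).reshape(-1)    # reinterpret the bytes as dtype


# usage: a 16 KiB in-memory "file" of float32 values 0..4095, with 4 KiB pages
buf = array.array("f", range(4096)).tobytes()
element_idx = torch.tensor([5, 7, 3000])  # pages 0, 0 and 2 -> two runs
values = gather_scalars(buf, element_idx * 4, 4, 4096, torch.float32)
print(values.tolist())  # [5.0, 7.0, 3000.0]
```

only the final slice-and-join walks a python list, and it has one entry per merged run rather than one per scalar, which is where the millions of requests would collapse. the staging buffer is then indexed once with a broadcast [N, itemsize] index and reinterpreted with view(dtype), so results come back in the original request order even though the runs were built from sorted pages.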