import functools, itertools, random

import torch, tqdm


class TBasic(torch.nn.Module):
    V = 1

    def __init__(self, d_model, n_heads, d_ff, n_layers=1, d_in=None, d_out=None,
                 dropout=0, device='cpu', dtype=torch.float32):
        super().__init__()
        # round d_model up to the nearest multiple of n_heads
        mismatched_params = d_model % n_heads
        if mismatched_params:
            d_model += n_heads - mismatched_params
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_ff = d_ff
        self.n_layers = n_layers
        self.d_in = d_in
        self.d_out = d_out
        self.dtype = dtype
        self.device = device
        self.l_in = None if d_in is None else torch.nn.Linear(d_in, d_model, device=device, dtype=dtype)
        self.t = torch.nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_ff, dropout=dropout,
            device=device, dtype=dtype, batch_first=True)
        self.l_out = None if d_out is None else torch.nn.Linear(d_model, d_out, device=device, dtype=dtype)
        # snapshot of [fully-qualified name, owning module, attribute name] for every
        # parameter, so the weights can be read and written as one flat vector in get()/set()
        self.__params = [
            [fqn, functools.reduce(getattr, ns[:-1], self), ns[-1]]
            for fqn, v in super().named_parameters()
            for ns in [fqn.split('.')]
        ]

    def make_settable(self):
        # replace each registered Parameter with a plain tensor clone so set() can
        # later swap in externally computed tensors while keeping gradients
        for fqn, o, n in self.__params:
            v = getattr(o, n)
            v = v.clone()
            delattr(o, n)
            v.retain_grad()
            setattr(o, n, v)
        return self

    def make_trainable(self):
        # re-wrap the current tensors as Parameters for normal optimizer training
        for fqn, o, n in self.__params:
            v = getattr(o, n)
            v = torch.nn.Parameter(v)
            setattr(o, n, v)
        return self

    def get(self, include_grad=False, flatten=True):
        # feel free to add a condition block to provide some other form of these
        if flatten:
            if not include_grad:
                return torch.cat([
                    getattr(o, n).flatten()
                    for fqn, o, n in self.__params
                ])
            else:
                # one row per scalar parameter: [value, grad]
                return torch.cat([
                    torch.stack([
                        getattr(o, n).flatten(),
                        getattr(o, n).grad.flatten()
                    ])
                    for fqn, o, n in self.__params
                ], dim=-1).T

    def set(self, ps):
        # write a flat vector of values back into the model, in get() order
        of = 0
        for fqn, o, n in self.__params:
            v = getattr(o, n)
            of2 = of + len(v.flatten())
            v = ps[of:of2].view(v.shape)
            v.retain_grad()
            setattr(o, n, v)
            of = of2
        assert of == len(ps)

    def parameters(self):
        return [getattr(o, n) for fqn, o, n in self.__params]

    def named_parameters(self):
        return [[fqn, getattr(o, n)] for fqn, o, n in self.__params]

    def forward(self, data=None):
        if data is None:
            data = torch.empty([1, 0], dtype=self.dtype)
        else:
            data = data.to(self.dtype)
        if self.l_in:
            if self.l_in.weight.shape[-1] == 1:
                data = data[..., None]
            data = self.l_in(data)
        data = self.t(data)
        if self.l_out:
            data = self.l_out(data)
            if self.l_out.weight.shape[0] == 1:
                data = data[..., 0]
        return data
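
# Hedged usage sketch (not part of the original API): it illustrates how the
# flat-vector accessors above could be combined -- get() reads every weight as one
# 1-D tensor, set() writes a modified vector back, and make_settable() switches the
# weights to plain tensors so set() is allowed to replace them.  The function name
# _demo_flat_params and the hand-rolled SGD step are assumptions for illustration only.
def _demo_flat_params():
    t = TBasic(d_model=8, n_heads=2, d_ff=16, d_in=1, d_out=1)
    x = torch.linspace(0, 1, 32)[..., None]
    y = torch.sin(x)
    loss = torch.nn.functional.mse_loss(t(x), y)
    loss.backward()
    flat = t.get(include_grad=True)           # [n_params, 2] columns: value, grad
    updated = flat[:, 0] - 1e-2 * flat[:, 1]  # one manual gradient-descent step
    t.make_settable()
    t.set(updated)                            # weights are now externally supplied tensors
    return torch.nn.functional.mse_loss(t(x), y)
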
class TGroups(TBasic):
    # these are sequential groups.
    # each group describes a set of sequence items.
    # each item has some number of floats or ids, and there may be trained embeddings.
    # pass name=dict(trained_len=[0], trained_dim=[d_model], floats=[0], ids=[0], id_dim=[d_model], out=[1])

    def __init__(self, **kwparams):
        # any dict-valued kwarg is treated as a group spec; the rest go to TBasic
        groups = {k: kwparams.pop(k) for k, v in list(kwparams.items()) if type(v) is dict}
        super().__init__(**kwparams, d_in=None, d_out=None)
        self.groups = {}
        for name, kws in groups.items():
            traineds = kws.get('trained_len', 0)
            trained_embeds = kws.get('trained_dim', self.d_model)
            floats = kws.get('floats', 0)
            ids = kws.get('ids', 0)
            id_embeds = kws.get('id_dim', self.d_model)
            out = kws.get('out', 1)
            in_size = floats
            if ids:
                embedding = torch.nn.Embedding(ids, id_embeds, device=self.device, dtype=self.dtype)
                setattr(self, name + '_embed', embedding)
                in_size += id_embeds
            else:
                embedding = None
            if traineds:
                trained = torch.nn.Parameter(torch.rand([traineds, trained_embeds], device=self.device, dtype=self.dtype))
                setattr(self, name + '_trained', trained)
                in_size += trained_embeds
            else:
                trained = None
            l_in = torch.nn.Linear(in_size, self.d_model, device=self.device, dtype=self.dtype)
            setattr(self, name + '_in', l_in)
            l_out = torch.nn.Linear(self.d_model, out, device=self.device, dtype=self.dtype)
            setattr(self, name + '_out', l_out)
            # note: the flat-parameter snapshot taken in TBasic.__init__ predates these
            # group layers, so the overridden parameters()/get()/set() will not see them
            self.groups[name] = {
                'trained': trained,
                'n_floats': floats,
                'embed': embedding,
                'in': l_in,
                'out': l_out,
            }

    def forward(self, **kwparams):
        data = []
        off = 0
        groups = []
        for name, kws in self.groups.items():
            trained = kws['trained']
            embed = kws['embed']
            n_floats = kws['n_floats']
            l_in = kws['in']
            l_out = kws['out']
            # assemble each group's items from its trained embeddings, raw floats,
            # and/or id embeddings; all parts must agree on every dim but the last
            parts = [] if trained is None else [trained]
            if n_floats:
                floats = kwparams.get(name)
                if floats is None:
                    floats = kwparams[name + '_floats']
                parts.append(floats)
            if embed:
                ids = kwparams.get(name)
                if ids is None:
                    ids = kwparams[name + '_ids']
                parts.append(embed(ids))
            gdata = l_in(torch.cat(parts, dim=-1))
            data.append(gdata)
            off2 = off + gdata.shape[-2]
            groups.append([name, off, off2, l_out])
            off = off2
        # run all groups through the shared encoder layer as one concatenated sequence,
        # then slice each group back out along the sequence dimension
        data = self.t(torch.cat(data, dim=-2))
        return {
            name: out(data[..., off:off2, :]) if out.weight.shape[0] > 1 else out(data[..., off:off2, :])[..., 0]
            for name, off, off2, out in groups
        }
        #data = l_in(data)
        #data = self.t(data)
        #data = l_out(data)
        #if data.shape[-1] == 1:
        #    data = data[...,0]


def sin_cos_pos_embeds(seq_len, n_embeds, dtype, device):
    # fixed sinusoidal position embeddings: one sin and one cos component per
    # frequency, with the frequency halving at each step
    n_spectra = (n_embeds + 1) // 2
    #thetas = torch.linspace(0, 2*torch.pi, seq_len+1, dtype=dtype, device=device)[:-1]
    #scales = 2**torch.arange(n_spectra, dtype=dtype, device=device)
    thetas = torch.arange(seq_len, dtype=dtype, device=device)
    scales = 2**(-torch.arange(n_spectra, dtype=dtype, device=device))
    embeds = torch.outer(scales, thetas)
    embeds = torch.cat([torch.sin(embeds), torch.cos(embeds)], dim=-1).view(n_spectra*2, seq_len)
    return embeds.T[:, :n_embeds]


optims = [torch.optim.SGD, torch.optim.Adam, torch.optim.AdamW]
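
# Hedged usage sketch (assumption, not part of the original file): one way the group
# spec described above could be used.  The group names 'obs' and 'query', the tensor
# shapes, and the call pattern are illustrative guesses.
def _demo_groups():
    m = TGroups(
        d_model=8, n_heads=2, d_ff=16,
        obs=dict(floats=3, out=1),           # sequence items carrying 3 raw floats each
        query=dict(trained_len=2, out=4),    # 2 trained query slots, 4 outputs each
    )
    out = m(obs=torch.rand([5, 3]))          # inputs are passed keyed by group name
    return out['obs'].shape, out['query'].shape   # -> torch.Size([5]), torch.Size([2, 4])
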
def flat_grid_search(data, labels, **kw_min_max_steps_proc):
    # each kwarg is [min, max, steps, proc]: a linspace of candidate values, shuffled,
    # with proc mapping each raw value to the actual hyperparameter
    key_proc_lists = [
        [key, proc, torch.linspace(lo, hi, n)[torch.randperm(n)]]
        for key, min_max_steps_proc in kw_min_max_steps_proc.items()
        for lo, hi, n, proc in [min_max_steps_proc]
    ]
    # it would be nice to do a few steps of all of them, then continue
    # this i suppose would mean caching all the objects and states
    all_combinations = list(itertools.product(*[vals for key, proc, vals in key_proc_lists]))
    random.shuffle(all_combinations)
    #all_combinations = all_combinations[torch.randperm(len(all_combinations))]
    for values in all_combinations:
        kwparams = {
            key: proc(values[idx])
            for idx in range(len(values))
            for key, proc, vals in [key_proc_lists[idx]]
        }
        dtype = kwparams.get('dtype', data.dtype)
        data = data.to(dtype)
        labels = labels.to(dtype)
        lr = kwparams.pop('lr', 1e-7)
        steps = kwparams.pop('steps', 100)
        optim = kwparams.pop('optim', optims[0])
        if 'd_model_per_head' in kwparams:
            kwparams['d_model'] = kwparams.pop('d_model_per_head') * kwparams['n_heads']
        t = TBasic(d_in=data.shape[-1], d_out=labels.shape[-1], **kwparams)
        optim_obj = optim(t.parameters(), lr=lr)
        for s in range(steps):
            loss = torch.nn.functional.mse_loss(t(data), labels)
            loss.backward()
            optim_obj.step()
            optim_obj.zero_grad()
        kwparams['lr'] = lr
        kwparams['optim'] = optim
        kwparams['steps'] = steps
        yield [loss.item(), kwparams]


def make_sin_arch():
    # somewhat arbitrary
    #lr = 0.003162277629598975
    #return TBasic(d_in=1, d_out=1, d_ff=35, n_heads=4, n_layers=2, dtype=torch.float64, d_model=12, dropout=0)
    lr = 0.01
    return TBasic(d_in=1, d_out=1, d_ff=25, n_heads=4, n_layers=1, dtype=torch.float32, d_model=4, dropout=0)
    # d_ff=35 n_heads=4 n_layers=2 dtype=torch.float64 d_model=12
    # can it predict its loss?


if __name__ == '__main__':
    data = torch.linspace(0, 3, 256)[..., None]
    labels = torch.sin(data)
    results = []
    for result in flat_grid_search(
        data=data, labels=labels,
        #optim=[0,len(optims)-1,len(optims),lambda idx: optims[int(idx)]],
        optim=[2, 2, 1, lambda idx: optims[int(idx)]],  # adamw
        steps=[500, 933, 4, int],  #[50,200,3,int],
        lr=[-2, -3, 3, lambda x: float(10**x)],
        d_model_per_head=[1, 3, 3, int],
        d_ff=[25, 35, 3, int],
        n_heads=[3, 5, 3, int],
        n_layers=[1, 2, 2, int],
        dropout=[0, 0, 1, float],  #[0,0.025,2,float],
        dtype=[0, 1, 2, lambda idx: [torch.float32, torch.float64][int(idx)]],
    ):
        results.append(result)
        if len(results) % 4 != 0:
            continue
        # print the current best and worst few configurations
        results.sort(key=lambda r: r[0])
        numbered = list(enumerate(results))
        print()
        for idx, r in numbered[:3] + numbered[-3:]:
            loss, kwp = r
            print(idx, loss, end=' ')
            for k, v in kwp.items():
                print(f'{k}:{v} ; ', end='')
            print()
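
# Hedged sketch (assumption, not in the original file): a minimal standalone training
# loop for the make_sin_arch() model on the same sine-fitting task the grid search
# uses.  The optimizer choice (Adam) and step count are illustrative; the learning
# rate matches the note inside make_sin_arch().
def _train_sin_demo(steps=500, lr=0.01):
    data = torch.linspace(0, 3, 256)[..., None]
    labels = torch.sin(data)
    t = make_sin_arch()
    opt = torch.optim.Adam(t.parameters(), lr=lr)
    for _ in range(steps):
        loss = torch.nn.functional.mse_loss(t(data), labels)
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item()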