Skip to content
Snippets Groups Projects
Commit 482273cd authored by PIERRE Fabien's avatar PIERRE Fabien
Browse files

ajout

parent 7fc92711
No related branches found
No related tags found
No related merge requests found
Showing
with 516 additions and 0 deletions
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
File added
import torch
import torch.nn as nn
import torchvision
import sys
import numpy as np
from PIL import Image
import PIL
import numpy as np
import matplotlib.pyplot as plt
def crop_image(img, d=32):
'''Make dimensions divisible by `d`'''
new_size = (img.size[0] - img.size[0] % d,
img.size[1] - img.size[1] % d)
bbox = [
int((img.size[0] - new_size[0])/2),
int((img.size[1] - new_size[1])/2),
int((img.size[0] + new_size[0])/2),
int((img.size[1] + new_size[1])/2),
]
img_cropped = img.crop(bbox)
return img_cropped
def get_params(opt_over, net, net_input, downsampler=None):
'''Returns parameters that we want to optimize over.
Args:
opt_over: comma separated list, e.g. "net,input" or "net"
net: network
net_input: torch.Tensor that stores input `z`
'''
opt_over_list = opt_over.split(',')
params = []
for opt in opt_over_list:
if opt == 'net':
params += [x for x in net.parameters() ]
elif opt=='down':
assert downsampler is not None
params = [x for x in downsampler.parameters()]
elif opt == 'input':
net_input.requires_grad = True
params += [net_input]
else:
assert False, 'what is it?'
return params
def get_image_grid(images_np, nrow=8):
'''Creates a grid from a list of images by concatenating them.'''
images_torch = [torch.from_numpy(x) for x in images_np]
torch_grid = torchvision.utils.make_grid(images_torch, nrow)
return torch_grid.numpy()
def plot_image_grid(images_np, nrow =8, factor=1, interpolation='lanczos'):
"""Draws images in a grid
Args:
images_np: list of images, each image is np.array of size 3xHxW of 1xHxW
nrow: how many images will be in one row
factor: size if the plt.figure
interpolation: interpolation used in plt.imshow
"""
n_channels = max(x.shape[0] for x in images_np)
assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"
images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
grid = get_image_grid(images_np, nrow)
plt.figure(figsize=(len(images_np) + factor, 12 + factor))
if images_np[0].shape[0] == 1:
plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
else:
plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
plt.show()
return grid
def load(path):
"""Load PIL image."""
img = Image.open(path)
return img
def get_image(path, imsize=-1):
"""Load an image and resize to a cpecific size.
Args:
path: path to image
imsize: tuple or scalar with dimensions; -1 for `no resize`
"""
img = load(path)
if isinstance(imsize, int):
imsize = (imsize, imsize)
if imsize[0]!= -1 and img.size != imsize:
if imsize[0] > img.size[0]:
img = img.resize(imsize, Image.BICUBIC)
else:
img = img.resize(imsize, Image.ANTIALIAS)
img_np = pil_to_np(img)
return img, img_np
def fill_noise(x, noise_type):
"""Fills tensor `x` with noise of type `noise_type`."""
if noise_type == 'u':
x.uniform_()
elif noise_type == 'n':
x.normal_()
else:
assert False
def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10):
"""Returns a pytorch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`)
initialized in a specific way.
Args:
input_depth: number of channels in the tensor
method: `noise` for fillting tensor with noise; `meshgrid` for np.meshgrid
spatial_size: spatial size of the tensor to initialize
noise_type: 'u' for uniform; 'n' for normal
var: a factor, a noise will be multiplicated by. Basically it is standard deviation scaler.
"""
if isinstance(spatial_size, int):
spatial_size = (spatial_size, spatial_size)
if method == 'noise':
shape = [1, input_depth, spatial_size[0], spatial_size[1]]
net_input = torch.zeros(shape)
fill_noise(net_input, noise_type)
net_input *= var
elif method == 'meshgrid':
assert input_depth == 2
X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1))
meshgrid = np.concatenate([X[None,:], Y[None,:]])
net_input= np_to_torch(meshgrid)
else:
assert False
return net_input
def pil_to_np(img_PIL):
'''Converts image in PIL format to np.array.
From W x H x C [0...255] to C x W x H [0..1]
'''
ar = np.array(img_PIL)
if len(ar.shape) == 3:
ar = ar.transpose(2,0,1)
else:
ar = ar[None, ...]
return ar.astype(np.float32) / 255.
def np_to_pil(img_np):
'''Converts image in np.array format to PIL image.
From C x W x H [0..1] to W x H x C [0...255]
'''
ar = np.clip(img_np*255,0,255).astype(np.uint8)
if img_np.shape[0] == 1:
ar = ar[0]
else:
ar = ar.transpose(1, 2, 0)
return Image.fromarray(ar)
def np_to_torch(img_np):
'''Converts image in numpy.array to torch.Tensor.
From C x W x H [0..1] to C x W x H [0..1]
'''
return torch.from_numpy(img_np)[None, :]
def torch_to_np(img_var):
'''Converts an image in torch.Tensor format to np.array.
From 1 x C x W x H [0..1] to C x W x H [0..1]
'''
return img_var.detach().cpu().numpy()[0]
def optimize(optimizer_type, OPT , closure, num_iter):
"""Runs optimization loop.
Args:
optimizer_type: 'LBFGS' of 'adam'
parameters: list of Tensors to optimize over
closure: function, that returns loss variable
LR: learning rate
num_iter: number of iterations
"""
if optimizer_type == 'LBFGS':
# Do several steps with adam first
optimizer = torch.optim.Adam(parameters, lr=0.001)
for j in range(100):
optimizer.zero_grad()
closure()
optimizer.step()
print('Starting optimization with LBFGS')
def closure2():
optimizer.zero_grad()
return closure()
optimizer = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR, tolerance_grad=-1, tolerance_change=-1)
optimizer.step(closure2)
elif optimizer_type == 'adam':
print('Starting optimization with ADAM')
#optimizer = torch.optim.Adam(parameters, lr=LR)
for j in range(num_iter):
OPT.zero_grad()
closure()
OPT.step()
else:
assert False
\ No newline at end of file
import os
from .common_utils import *
def get_noisy_image(img_np, sigma):
"""Adds Gaussian noise to an image.
Args:
img_np: image, np.array with values from 0 to 1
sigma: std of the noise
"""
img_noisy_np = np.clip(img_np + np.random.normal(scale=sigma, size=img_np.shape), 0, 1).astype(np.float32)
img_noisy_pil = np_to_pil(img_noisy_np)
return img_noisy_pil, img_noisy_np
\ No newline at end of file
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from .matcher import Matcher
import os
from collections import OrderedDict
class View(nn.Module):
def __init__(self):
super(View, self).__init__()
def forward(self, x):
return x.view(-1)
def get_vanilla_vgg_features(cut_idx=-1):
if not os.path.exists('vgg_features.pth'):
os.system(
'wget --no-check-certificate -N https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth')
vgg_weights = torch.load('vgg19-d01eb7cb.pth')
# fix compatibility issues
map = {'classifier.6.weight':u'classifier.7.weight', 'classifier.6.bias':u'classifier.7.bias'}
vgg_weights = OrderedDict([(map[k] if k in map else k,v) for k,v in vgg_weights.iteritems()])
model = models.vgg19()
model.classifier = nn.Sequential(View(), *model.classifier._modules.values())
model.load_state_dict(vgg_weights)
torch.save(model.features, 'vgg_features.pth')
torch.save(model.classifier, 'vgg_classifier.pth')
vgg = torch.load('vgg_features.pth')
if cut_idx > 36:
vgg_classifier = torch.load('vgg_classifier.pth')
vgg = nn.Sequential(*(vgg._modules.values() + vgg_classifier._modules.values()))
vgg.eval()
return vgg
def get_matcher(net, opt):
idxs = [x for x in opt['layers'].split(',')]
matcher = Matcher(opt['what'])
def hook(module, input, output):
matcher(module, output)
for i in idxs:
net._modules[i].register_forward_hook(hook)
return matcher
def get_vgg(cut_idx=-1):
f = get_vanilla_vgg_features(cut_idx)
if cut_idx > 0:
num_modules = len(f._modules)
keys_to_delete = [f._modules.keys()[x] for x in range(cut_idx, num_modules)]
for k in keys_to_delete:
del f._modules[k]
return f
def vgg_preprocess_var(var):
(r, g, b) = torch.chunk(var, 3, dim=1)
bgr = torch.cat((b, g, r), 1)
out = bgr * 255 - torch.autograd.Variable(vgg_mean[None, ...]).type(var.type()).expand_as(bgr)
return out
vgg_mean = torch.FloatTensor([103.939, 116.779, 123.680]).view(3, 1, 1)
def get_preprocessor(imsize):
def vgg_preprocess(tensor):
(r, g, b) = torch.chunk(tensor, 3, dim=0)
bgr = torch.cat((b, g, r), 0)
out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr)
return out
preprocess = transforms.Compose([
transforms.Resize(imsize),
transforms.ToTensor(),
transforms.Lambda(vgg_preprocess)
])
return preprocess
def get_deprocessor():
def vgg_deprocess(tensor):
bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0
(b, g, r) = torch.chunk(bgr, 3, dim=0)
rgb = torch.cat((r, g, b), 0)
return rgb
deprocess = transforms.Compose([
transforms.Lambda(vgg_deprocess),
transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
transforms.ToPILImage()
])
return deprocess
import numpy as np
from PIL import Image
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
from .common_utils import *
def get_text_mask(for_image, sz=20):
font_fname = '/usr/share/fonts/truetype/freefont/FreeSansBold.ttf'
font_size = sz
font = ImageFont.truetype(font_fname, font_size)
img_mask = Image.fromarray(np.array(for_image)*0+255)
draw = ImageDraw.Draw(img_mask)
draw.text((128, 128), "hello world", font=font, fill='rgb(0, 0, 0)')
return img_mask
def get_bernoulli_mask(for_image, zero_fraction=0.95):
img_mask_np=(np.random.random_sample(size=pil_to_np(for_image).shape) > zero_fraction).astype(int)
img_mask = np_to_pil(img_mask_np)
return img_mask
import torch
import torch.nn as nn
class Matcher:
def __init__(self, how='gram_matrix', loss='mse'):
self.mode = 'store'
self.stored = {}
self.losses = {}
if how in all_features.keys():
self.get_statistics = all_features[how]
else:
assert False
pass
if loss in all_losses.keys():
self.loss = all_losses[loss]
else:
assert False
def __call__(self, module, features):
statistics = self.get_statistics(features)
self.statistics = statistics
if self.mode == 'store':
self.stored[module] = statistics.detach().clone()
elif self.mode == 'match':
self.losses[module] = self.loss(statistics, self.stored[module])
def clean(self):
self.losses = {}
def gram_matrix(x):
(b, ch, h, w) = x.size()
features = x.view(b, ch, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (ch * h * w)
return gram
def features(x):
return x
all_features = {
'gram_matrix': gram_matrix,
'features': features,
}
all_losses = {
'mse': nn.MSELoss(),
'smoothL1': nn.SmoothL1Loss(),
'L1': nn.L1Loss(),
}
import torch
import torch.nn as nn
class Matcher:
def __init__(self, how='gram_matrix', loss='mse', map_index=933):
self.mode = 'store'
self.stored = {}
self.losses = {}
if how in all_features.keys():
self.get_statistics = all_features[how]
else:
assert False
pass
if loss in all_losses.keys():
self.loss = all_losses[loss]
else:
assert False
self.map_index = map_index
self.method = 'match'
def __call__(self, module, features):
statistics = self.get_statistics(features)
self.statistics = statistics
if self.mode == 'store':
self.stored[module] = statistics.detach()
elif self.mode == 'match':
if statistics.ndimension() == 2:
if self.method == 'maximize':
self.losses[module] = - statistics[0, self.map_index]
else:
self.losses[module] = torch.abs(300 - statistics[0, self.map_index])
else:
ws = self.window_size
t = statistics.detach() * 0
s_cc = statistics[:1, :, t.shape[2] // 2 - ws:t.shape[2] // 2 + ws, t.shape[3] // 2 - ws:t.shape[3] // 2 + ws] #* 1.0
t_cc = t[:1, :, t.shape[2] // 2 - ws:t.shape[2] // 2 + ws, t.shape[3] // 2 - ws:t.shape[3] // 2 + ws] #* 1.0
t_cc[:, self.map_index,...] = 1
if self.method == 'maximize':
self.losses[module] = -(s_cc * t_cc.contiguous()).sum()
else:
self.losses[module] = torch.abs(200 -(s_cc * t_cc.contiguous())).sum()
def clean(self):
self.losses = {}
def gram_matrix(x):
(b, ch, h, w) = x.size()
features = x.view(b, ch, w * h)
features_t = features.transpose(1, 2)
gram = features.bmm(features_t) / (ch * h * w)
return gram
def features(x):
return x
all_features = {
'gram_matrix': gram_matrix,
'features': features,
}
all_losses = {
'mse': nn.MSELoss(),
'smoothL1': nn.SmoothL1Loss(),
'L1': nn.L1Loss(),
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment