# xxxxxxxxxx  -- web-page residue (not code); commented out so it is not executed
#!python3 -m pip install torch==1.8.0 \
#--user --quiet --no-warn-script-location
#!python3 -m pip install torchvision==0.9.0 \
#--user --quiet --no-warn-script-location
#spath='/home/sc_work/.sage/local/lib/python3.9/site-packages'
#import sys; sys.path.append(spath)
from torch.utils.data import Dataset as tds
from torch.utils.data import DataLoader as tdl
import numpy as np,pylab as pl,torch,zipfile,h5py,urllib
from IPython.core.magic import register_line_magic
class TData(tds):
    """In-memory dataset pairing samples with integer labels.

    Both inputs are converted to tensors once, at construction time:
    samples to ``float32`` and labels to ``int32``.
    """

    def __init__(self, x, y):
        """Store ``x`` and ``y`` as tensors.

        Args:
            x: Array-like of samples, indexed along its first axis.
            y: Array-like of labels; its first dimension defines the
                length of the dataset.
        """
        self.x = torch.tensor(x, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.int32)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        img, lbl = self.x[index], self.y[index]
        return img, lbl

    def __len__(self):
        """Return the number of labels held by the dataset."""
        return self.y.shape[0]
@register_line_magic
def display_examples(data):
    """Print batch shapes and plot up to five images from one batch.

    Registered as an IPython line magic so it can be invoked as
    ``%display_examples <key>``; the text after the magic name arrives
    here as the string ``data``.

    Args:
        data: Key into the module-level ``dataloaders`` dict selecting
            which loader to draw the batch from.

    Only the first batch yielded by the loader is used. Titles are looked
    up in the module-level ``names`` list by integer label.
    """
    for images, labels in dataloaders[data]:
        print('Image dimensions: %s' % str(images.shape))
        print('Label dimensions: %s' % str(labels.shape))
        # Random starting offset (1 or 2) into the batch.
        n = np.random.randint(1, 3)
        # Clamp the window so a batch with fewer than n + 5 samples does
        # not raise IndexError; full-size batches behave as before.
        n = min(n, max(len(labels) - 5, 0))
        shown = min(5, len(labels) - n)
        fig = pl.figure(figsize=(6, 2))
        for i in range(n, n + shown):
            ax = fig.add_subplot(1, 5, i - n + 1, xticks=[], yticks=[])
            ax.set_title(names[labels[i].item()], fontsize=10)
            # Move the first axis last (presumably channels-first ->
            # channels-last for imshow; confirm against the data layout).
            ax.imshow(np.transpose(images[i], (1, 2, 0)))
        pl.tight_layout()
        pl.show()
        # One batch is enough for a preview.
        break
# xxxxxxxxxx  -- web-page residue (not code); commented out so it is not executed
# `import urllib` alone does not guarantee that the `request` submodule is
# loaded, so import it explicitly before using `urllib.request.urlopen`.
import urllib.request

# Download the zipped HDF5 file next to the notebook.
path = 'https://olgabelitskaya.github.io/'
zf = 'TomatoCultivarImages.h5.zip'
# Context managers close the response and the output file even if the
# download or the write fails part-way.
with urllib.request.urlopen(path + zf) as input_file, \
        open(zf, 'wb') as output_file:
    output_file.write(input_file.read())

# Unpack the archive relative to the current working directory.
with zipfile.ZipFile(zf, 'r') as zipf:
    zipf.extractall('')

# zf[:-4] drops the trailing '.zip' to get the extracted HDF5 file name.
# The arrays are copied into memory before the file is closed.
with h5py.File(zf[:-4], 'r') as f:
    keys = list(f.keys())
    # NOTE(review): relies on the key order returned by h5py -- first key
    # is taken as images, second as labels; confirm against the file.
    x_test = np.array(f[keys[0]])
    y_test = np.array(f[keys[1]])

# Class names; the position in the list is used as the integer label.
names = ['Kumato', 'Beefsteak', 'Tigerella',
         'Roma', 'Japanese Black Trifele',
         'Yellow Pear', 'Sun Gold', 'Green Zebra',
         'Cherokee Purple', 'Oxheart', 'Blue Berries',
         'San Marzano', 'Banana Legs',
         'German Orange Strawberry', 'Supersweet 100']
num_classes = len(names)
batch_size = 16

# Move axis 3 to position 1 (presumably NHWC -> NCHW; confirm the layout
# stored in the file).
x_test = np.transpose(x_test, (0, 3, 1, 2))
print(x_test.mean(), x_test.std())

test = TData(x_test, y_test)
dataloaders = {
    'test': tdl(dataset=test, shuffle=True, batch_size=batch_size),
}
%display_examples test
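# No comments:
# Post a Comment
# (the two lines above are blog-page footer residue, not code; commented out so the file parses)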