import pylab; from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neural_network import MLPClassifier
import warnings; warnings.filterwarnings('ignore')
# Fetch both splits of the 20 Newsgroups corpus (train first, then test) with
# headers, footers and quoted replies removed, as requested via `remove`.
train, test = (
    fetch_20newsgroups(subset=part, shuffle=True,
                       remove=('headers', 'footers', 'quotes'))
    for part in ('train', 'test')
)
y_train, y_test = train.target, test.target

# TF-IDF features: sublinear (log-scaled) term frequency, drop terms that
# appear in more than half of the documents, and drop English stop words.
# The vocabulary is learned on the training split only and reused for test.
vectorizer = TfidfVectorizer(sublinear_tf=True, max_df=0.5,
                             stop_words='english')
x_train = vectorizer.fit_transform(train.data)
x_test = vectorizer.transform(test.data)

# The raw text is no longer needed once it has been vectorized.
del train, test
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
# One hidden layer of 128 units, trained with Adam for only 3 iterations to
# keep the demo fast; any warning about stopping this early is hidden by the
# blanket warnings filter at the top of the file.
clf = MLPClassifier(
    hidden_layer_sizes=(128,),
    solver='adam',
    learning_rate_init=0.01,
    max_iter=3,
    random_state=1,
    verbose=2,
)
# fit() returns the estimator itself, so the test-set accuracy can be
# computed and printed directly from its result.
print(clf.fit(x_train, y_train).score(x_test, y_test))
# Plot true vs. predicted class labels for the first few test documents.
# Where the small (predicted) marker sits inside the large (true) marker,
# the prediction for that document is correct.
y_test_predictions = clf.predict(x_test)
n_show = min(100, len(y_test))  # never index past a short test set
pylab.figure(figsize=(6, 4))
pylab.scatter(range(n_show), y_test[:n_show], s=100, label='true label')
pylab.scatter(range(n_show), y_test_predictions[:n_show], s=25,
              label='predicted label')
pylab.xlabel('test document index')
pylab.ylabel('class label')
pylab.legend()
pylab.tight_layout()
pylab.show()
Monday, December 16, 2019
Text Classification
20newsgroups.ipynb
Labels:
bug,
instagram,
interactive,
Python,
SageMathCell,
sklearn
Subscribe to:
Post Comments (Atom)
No comments:
Post a Comment