Skip to content

Instantly share code, notes, and snippets.

@esevan
Created March 30, 2020 07:54
Show Gist options
  • Save esevan/3769a99f0b24317f436015ee6dd44cd1 to your computer and use it in GitHub Desktop.
Save esevan/3769a99f0b24317f436015ee6dd44cd1 to your computer and use it in GitHub Desktop.
handson_ml
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'data': array([[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]),\n",
" 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object),\n",
" 'feature_names': ['pixel1',\n",
" 'pixel2',\n",
" 'pixel3',\n",
" 'pixel4',\n",
" 'pixel5',\n",
" 'pixel6',\n",
" 'pixel7',\n",
" 'pixel8',\n",
" 'pixel9',\n",
" 'pixel10',\n",
" 'pixel11',\n",
" 'pixel12',\n",
" 'pixel13',\n",
" 'pixel14',\n",
" 'pixel15',\n",
" 'pixel16',\n",
" 'pixel17',\n",
" 'pixel18',\n",
" 'pixel19',\n",
" 'pixel20',\n",
" 'pixel21',\n",
" 'pixel22',\n",
" 'pixel23',\n",
" 'pixel24',\n",
" 'pixel25',\n",
" 'pixel26',\n",
" 'pixel27',\n",
" 'pixel28',\n",
" 'pixel29',\n",
" 'pixel30',\n",
" 'pixel31',\n",
" 'pixel32',\n",
" 'pixel33',\n",
" 'pixel34',\n",
" 'pixel35',\n",
" 'pixel36',\n",
" 'pixel37',\n",
" 'pixel38',\n",
" 'pixel39',\n",
" 'pixel40',\n",
" 'pixel41',\n",
" 'pixel42',\n",
" 'pixel43',\n",
" 'pixel44',\n",
" 'pixel45',\n",
" 'pixel46',\n",
" 'pixel47',\n",
" 'pixel48',\n",
" 'pixel49',\n",
" 'pixel50',\n",
" 'pixel51',\n",
" 'pixel52',\n",
" 'pixel53',\n",
" 'pixel54',\n",
" 'pixel55',\n",
" 'pixel56',\n",
" 'pixel57',\n",
" 'pixel58',\n",
" 'pixel59',\n",
" 'pixel60',\n",
" 'pixel61',\n",
" 'pixel62',\n",
" 'pixel63',\n",
" 'pixel64',\n",
" 'pixel65',\n",
" 'pixel66',\n",
" 'pixel67',\n",
" 'pixel68',\n",
" 'pixel69',\n",
" 'pixel70',\n",
" 'pixel71',\n",
" 'pixel72',\n",
" 'pixel73',\n",
" 'pixel74',\n",
" 'pixel75',\n",
" 'pixel76',\n",
" 'pixel77',\n",
" 'pixel78',\n",
" 'pixel79',\n",
" 'pixel80',\n",
" 'pixel81',\n",
" 'pixel82',\n",
" 'pixel83',\n",
" 'pixel84',\n",
" 'pixel85',\n",
" 'pixel86',\n",
" 'pixel87',\n",
" 'pixel88',\n",
" 'pixel89',\n",
" 'pixel90',\n",
" 'pixel91',\n",
" 'pixel92',\n",
" 'pixel93',\n",
" 'pixel94',\n",
" 'pixel95',\n",
" 'pixel96',\n",
" 'pixel97',\n",
" 'pixel98',\n",
" 'pixel99',\n",
" 'pixel100',\n",
" 'pixel101',\n",
" 'pixel102',\n",
" 'pixel103',\n",
" 'pixel104',\n",
" 'pixel105',\n",
" 'pixel106',\n",
" 'pixel107',\n",
" 'pixel108',\n",
" 'pixel109',\n",
" 'pixel110',\n",
" 'pixel111',\n",
" 'pixel112',\n",
" 'pixel113',\n",
" 'pixel114',\n",
" 'pixel115',\n",
" 'pixel116',\n",
" 'pixel117',\n",
" 'pixel118',\n",
" 'pixel119',\n",
" 'pixel120',\n",
" 'pixel121',\n",
" 'pixel122',\n",
" 'pixel123',\n",
" 'pixel124',\n",
" 'pixel125',\n",
" 'pixel126',\n",
" 'pixel127',\n",
" 'pixel128',\n",
" 'pixel129',\n",
" 'pixel130',\n",
" 'pixel131',\n",
" 'pixel132',\n",
" 'pixel133',\n",
" 'pixel134',\n",
" 'pixel135',\n",
" 'pixel136',\n",
" 'pixel137',\n",
" 'pixel138',\n",
" 'pixel139',\n",
" 'pixel140',\n",
" 'pixel141',\n",
" 'pixel142',\n",
" 'pixel143',\n",
" 'pixel144',\n",
" 'pixel145',\n",
" 'pixel146',\n",
" 'pixel147',\n",
" 'pixel148',\n",
" 'pixel149',\n",
" 'pixel150',\n",
" 'pixel151',\n",
" 'pixel152',\n",
" 'pixel153',\n",
" 'pixel154',\n",
" 'pixel155',\n",
" 'pixel156',\n",
" 'pixel157',\n",
" 'pixel158',\n",
" 'pixel159',\n",
" 'pixel160',\n",
" 'pixel161',\n",
" 'pixel162',\n",
" 'pixel163',\n",
" 'pixel164',\n",
" 'pixel165',\n",
" 'pixel166',\n",
" 'pixel167',\n",
" 'pixel168',\n",
" 'pixel169',\n",
" 'pixel170',\n",
" 'pixel171',\n",
" 'pixel172',\n",
" 'pixel173',\n",
" 'pixel174',\n",
" 'pixel175',\n",
" 'pixel176',\n",
" 'pixel177',\n",
" 'pixel178',\n",
" 'pixel179',\n",
" 'pixel180',\n",
" 'pixel181',\n",
" 'pixel182',\n",
" 'pixel183',\n",
" 'pixel184',\n",
" 'pixel185',\n",
" 'pixel186',\n",
" 'pixel187',\n",
" 'pixel188',\n",
" 'pixel189',\n",
" 'pixel190',\n",
" 'pixel191',\n",
" 'pixel192',\n",
" 'pixel193',\n",
" 'pixel194',\n",
" 'pixel195',\n",
" 'pixel196',\n",
" 'pixel197',\n",
" 'pixel198',\n",
" 'pixel199',\n",
" 'pixel200',\n",
" 'pixel201',\n",
" 'pixel202',\n",
" 'pixel203',\n",
" 'pixel204',\n",
" 'pixel205',\n",
" 'pixel206',\n",
" 'pixel207',\n",
" 'pixel208',\n",
" 'pixel209',\n",
" 'pixel210',\n",
" 'pixel211',\n",
" 'pixel212',\n",
" 'pixel213',\n",
" 'pixel214',\n",
" 'pixel215',\n",
" 'pixel216',\n",
" 'pixel217',\n",
" 'pixel218',\n",
" 'pixel219',\n",
" 'pixel220',\n",
" 'pixel221',\n",
" 'pixel222',\n",
" 'pixel223',\n",
" 'pixel224',\n",
" 'pixel225',\n",
" 'pixel226',\n",
" 'pixel227',\n",
" 'pixel228',\n",
" 'pixel229',\n",
" 'pixel230',\n",
" 'pixel231',\n",
" 'pixel232',\n",
" 'pixel233',\n",
" 'pixel234',\n",
" 'pixel235',\n",
" 'pixel236',\n",
" 'pixel237',\n",
" 'pixel238',\n",
" 'pixel239',\n",
" 'pixel240',\n",
" 'pixel241',\n",
" 'pixel242',\n",
" 'pixel243',\n",
" 'pixel244',\n",
" 'pixel245',\n",
" 'pixel246',\n",
" 'pixel247',\n",
" 'pixel248',\n",
" 'pixel249',\n",
" 'pixel250',\n",
" 'pixel251',\n",
" 'pixel252',\n",
" 'pixel253',\n",
" 'pixel254',\n",
" 'pixel255',\n",
" 'pixel256',\n",
" 'pixel257',\n",
" 'pixel258',\n",
" 'pixel259',\n",
" 'pixel260',\n",
" 'pixel261',\n",
" 'pixel262',\n",
" 'pixel263',\n",
" 'pixel264',\n",
" 'pixel265',\n",
" 'pixel266',\n",
" 'pixel267',\n",
" 'pixel268',\n",
" 'pixel269',\n",
" 'pixel270',\n",
" 'pixel271',\n",
" 'pixel272',\n",
" 'pixel273',\n",
" 'pixel274',\n",
" 'pixel275',\n",
" 'pixel276',\n",
" 'pixel277',\n",
" 'pixel278',\n",
" 'pixel279',\n",
" 'pixel280',\n",
" 'pixel281',\n",
" 'pixel282',\n",
" 'pixel283',\n",
" 'pixel284',\n",
" 'pixel285',\n",
" 'pixel286',\n",
" 'pixel287',\n",
" 'pixel288',\n",
" 'pixel289',\n",
" 'pixel290',\n",
" 'pixel291',\n",
" 'pixel292',\n",
" 'pixel293',\n",
" 'pixel294',\n",
" 'pixel295',\n",
" 'pixel296',\n",
" 'pixel297',\n",
" 'pixel298',\n",
" 'pixel299',\n",
" 'pixel300',\n",
" 'pixel301',\n",
" 'pixel302',\n",
" 'pixel303',\n",
" 'pixel304',\n",
" 'pixel305',\n",
" 'pixel306',\n",
" 'pixel307',\n",
" 'pixel308',\n",
" 'pixel309',\n",
" 'pixel310',\n",
" 'pixel311',\n",
" 'pixel312',\n",
" 'pixel313',\n",
" 'pixel314',\n",
" 'pixel315',\n",
" 'pixel316',\n",
" 'pixel317',\n",
" 'pixel318',\n",
" 'pixel319',\n",
" 'pixel320',\n",
" 'pixel321',\n",
" 'pixel322',\n",
" 'pixel323',\n",
" 'pixel324',\n",
" 'pixel325',\n",
" 'pixel326',\n",
" 'pixel327',\n",
" 'pixel328',\n",
" 'pixel329',\n",
" 'pixel330',\n",
" 'pixel331',\n",
" 'pixel332',\n",
" 'pixel333',\n",
" 'pixel334',\n",
" 'pixel335',\n",
" 'pixel336',\n",
" 'pixel337',\n",
" 'pixel338',\n",
" 'pixel339',\n",
" 'pixel340',\n",
" 'pixel341',\n",
" 'pixel342',\n",
" 'pixel343',\n",
" 'pixel344',\n",
" 'pixel345',\n",
" 'pixel346',\n",
" 'pixel347',\n",
" 'pixel348',\n",
" 'pixel349',\n",
" 'pixel350',\n",
" 'pixel351',\n",
" 'pixel352',\n",
" 'pixel353',\n",
" 'pixel354',\n",
" 'pixel355',\n",
" 'pixel356',\n",
" 'pixel357',\n",
" 'pixel358',\n",
" 'pixel359',\n",
" 'pixel360',\n",
" 'pixel361',\n",
" 'pixel362',\n",
" 'pixel363',\n",
" 'pixel364',\n",
" 'pixel365',\n",
" 'pixel366',\n",
" 'pixel367',\n",
" 'pixel368',\n",
" 'pixel369',\n",
" 'pixel370',\n",
" 'pixel371',\n",
" 'pixel372',\n",
" 'pixel373',\n",
" 'pixel374',\n",
" 'pixel375',\n",
" 'pixel376',\n",
" 'pixel377',\n",
" 'pixel378',\n",
" 'pixel379',\n",
" 'pixel380',\n",
" 'pixel381',\n",
" 'pixel382',\n",
" 'pixel383',\n",
" 'pixel384',\n",
" 'pixel385',\n",
" 'pixel386',\n",
" 'pixel387',\n",
" 'pixel388',\n",
" 'pixel389',\n",
" 'pixel390',\n",
" 'pixel391',\n",
" 'pixel392',\n",
" 'pixel393',\n",
" 'pixel394',\n",
" 'pixel395',\n",
" 'pixel396',\n",
" 'pixel397',\n",
" 'pixel398',\n",
" 'pixel399',\n",
" 'pixel400',\n",
" 'pixel401',\n",
" 'pixel402',\n",
" 'pixel403',\n",
" 'pixel404',\n",
" 'pixel405',\n",
" 'pixel406',\n",
" 'pixel407',\n",
" 'pixel408',\n",
" 'pixel409',\n",
" 'pixel410',\n",
" 'pixel411',\n",
" 'pixel412',\n",
" 'pixel413',\n",
" 'pixel414',\n",
" 'pixel415',\n",
" 'pixel416',\n",
" 'pixel417',\n",
" 'pixel418',\n",
" 'pixel419',\n",
" 'pixel420',\n",
" 'pixel421',\n",
" 'pixel422',\n",
" 'pixel423',\n",
" 'pixel424',\n",
" 'pixel425',\n",
" 'pixel426',\n",
" 'pixel427',\n",
" 'pixel428',\n",
" 'pixel429',\n",
" 'pixel430',\n",
" 'pixel431',\n",
" 'pixel432',\n",
" 'pixel433',\n",
" 'pixel434',\n",
" 'pixel435',\n",
" 'pixel436',\n",
" 'pixel437',\n",
" 'pixel438',\n",
" 'pixel439',\n",
" 'pixel440',\n",
" 'pixel441',\n",
" 'pixel442',\n",
" 'pixel443',\n",
" 'pixel444',\n",
" 'pixel445',\n",
" 'pixel446',\n",
" 'pixel447',\n",
" 'pixel448',\n",
" 'pixel449',\n",
" 'pixel450',\n",
" 'pixel451',\n",
" 'pixel452',\n",
" 'pixel453',\n",
" 'pixel454',\n",
" 'pixel455',\n",
" 'pixel456',\n",
" 'pixel457',\n",
" 'pixel458',\n",
" 'pixel459',\n",
" 'pixel460',\n",
" 'pixel461',\n",
" 'pixel462',\n",
" 'pixel463',\n",
" 'pixel464',\n",
" 'pixel465',\n",
" 'pixel466',\n",
" 'pixel467',\n",
" 'pixel468',\n",
" 'pixel469',\n",
" 'pixel470',\n",
" 'pixel471',\n",
" 'pixel472',\n",
" 'pixel473',\n",
" 'pixel474',\n",
" 'pixel475',\n",
" 'pixel476',\n",
" 'pixel477',\n",
" 'pixel478',\n",
" 'pixel479',\n",
" 'pixel480',\n",
" 'pixel481',\n",
" 'pixel482',\n",
" 'pixel483',\n",
" 'pixel484',\n",
" 'pixel485',\n",
" 'pixel486',\n",
" 'pixel487',\n",
" 'pixel488',\n",
" 'pixel489',\n",
" 'pixel490',\n",
" 'pixel491',\n",
" 'pixel492',\n",
" 'pixel493',\n",
" 'pixel494',\n",
" 'pixel495',\n",
" 'pixel496',\n",
" 'pixel497',\n",
" 'pixel498',\n",
" 'pixel499',\n",
" 'pixel500',\n",
" 'pixel501',\n",
" 'pixel502',\n",
" 'pixel503',\n",
" 'pixel504',\n",
" 'pixel505',\n",
" 'pixel506',\n",
" 'pixel507',\n",
" 'pixel508',\n",
" 'pixel509',\n",
" 'pixel510',\n",
" 'pixel511',\n",
" 'pixel512',\n",
" 'pixel513',\n",
" 'pixel514',\n",
" 'pixel515',\n",
" 'pixel516',\n",
" 'pixel517',\n",
" 'pixel518',\n",
" 'pixel519',\n",
" 'pixel520',\n",
" 'pixel521',\n",
" 'pixel522',\n",
" 'pixel523',\n",
" 'pixel524',\n",
" 'pixel525',\n",
" 'pixel526',\n",
" 'pixel527',\n",
" 'pixel528',\n",
" 'pixel529',\n",
" 'pixel530',\n",
" 'pixel531',\n",
" 'pixel532',\n",
" 'pixel533',\n",
" 'pixel534',\n",
" 'pixel535',\n",
" 'pixel536',\n",
" 'pixel537',\n",
" 'pixel538',\n",
" 'pixel539',\n",
" 'pixel540',\n",
" 'pixel541',\n",
" 'pixel542',\n",
" 'pixel543',\n",
" 'pixel544',\n",
" 'pixel545',\n",
" 'pixel546',\n",
" 'pixel547',\n",
" 'pixel548',\n",
" 'pixel549',\n",
" 'pixel550',\n",
" 'pixel551',\n",
" 'pixel552',\n",
" 'pixel553',\n",
" 'pixel554',\n",
" 'pixel555',\n",
" 'pixel556',\n",
" 'pixel557',\n",
" 'pixel558',\n",
" 'pixel559',\n",
" 'pixel560',\n",
" 'pixel561',\n",
" 'pixel562',\n",
" 'pixel563',\n",
" 'pixel564',\n",
" 'pixel565',\n",
" 'pixel566',\n",
" 'pixel567',\n",
" 'pixel568',\n",
" 'pixel569',\n",
" 'pixel570',\n",
" 'pixel571',\n",
" 'pixel572',\n",
" 'pixel573',\n",
" 'pixel574',\n",
" 'pixel575',\n",
" 'pixel576',\n",
" 'pixel577',\n",
" 'pixel578',\n",
" 'pixel579',\n",
" 'pixel580',\n",
" 'pixel581',\n",
" 'pixel582',\n",
" 'pixel583',\n",
" 'pixel584',\n",
" 'pixel585',\n",
" 'pixel586',\n",
" 'pixel587',\n",
" 'pixel588',\n",
" 'pixel589',\n",
" 'pixel590',\n",
" 'pixel591',\n",
" 'pixel592',\n",
" 'pixel593',\n",
" 'pixel594',\n",
" 'pixel595',\n",
" 'pixel596',\n",
" 'pixel597',\n",
" 'pixel598',\n",
" 'pixel599',\n",
" 'pixel600',\n",
" 'pixel601',\n",
" 'pixel602',\n",
" 'pixel603',\n",
" 'pixel604',\n",
" 'pixel605',\n",
" 'pixel606',\n",
" 'pixel607',\n",
" 'pixel608',\n",
" 'pixel609',\n",
" 'pixel610',\n",
" 'pixel611',\n",
" 'pixel612',\n",
" 'pixel613',\n",
" 'pixel614',\n",
" 'pixel615',\n",
" 'pixel616',\n",
" 'pixel617',\n",
" 'pixel618',\n",
" 'pixel619',\n",
" 'pixel620',\n",
" 'pixel621',\n",
" 'pixel622',\n",
" 'pixel623',\n",
" 'pixel624',\n",
" 'pixel625',\n",
" 'pixel626',\n",
" 'pixel627',\n",
" 'pixel628',\n",
" 'pixel629',\n",
" 'pixel630',\n",
" 'pixel631',\n",
" 'pixel632',\n",
" 'pixel633',\n",
" 'pixel634',\n",
" 'pixel635',\n",
" 'pixel636',\n",
" 'pixel637',\n",
" 'pixel638',\n",
" 'pixel639',\n",
" 'pixel640',\n",
" 'pixel641',\n",
" 'pixel642',\n",
" 'pixel643',\n",
" 'pixel644',\n",
" 'pixel645',\n",
" 'pixel646',\n",
" 'pixel647',\n",
" 'pixel648',\n",
" 'pixel649',\n",
" 'pixel650',\n",
" 'pixel651',\n",
" 'pixel652',\n",
" 'pixel653',\n",
" 'pixel654',\n",
" 'pixel655',\n",
" 'pixel656',\n",
" 'pixel657',\n",
" 'pixel658',\n",
" 'pixel659',\n",
" 'pixel660',\n",
" 'pixel661',\n",
" 'pixel662',\n",
" 'pixel663',\n",
" 'pixel664',\n",
" 'pixel665',\n",
" 'pixel666',\n",
" 'pixel667',\n",
" 'pixel668',\n",
" 'pixel669',\n",
" 'pixel670',\n",
" 'pixel671',\n",
" 'pixel672',\n",
" 'pixel673',\n",
" 'pixel674',\n",
" 'pixel675',\n",
" 'pixel676',\n",
" 'pixel677',\n",
" 'pixel678',\n",
" 'pixel679',\n",
" 'pixel680',\n",
" 'pixel681',\n",
" 'pixel682',\n",
" 'pixel683',\n",
" 'pixel684',\n",
" 'pixel685',\n",
" 'pixel686',\n",
" 'pixel687',\n",
" 'pixel688',\n",
" 'pixel689',\n",
" 'pixel690',\n",
" 'pixel691',\n",
" 'pixel692',\n",
" 'pixel693',\n",
" 'pixel694',\n",
" 'pixel695',\n",
" 'pixel696',\n",
" 'pixel697',\n",
" 'pixel698',\n",
" 'pixel699',\n",
" 'pixel700',\n",
" 'pixel701',\n",
" 'pixel702',\n",
" 'pixel703',\n",
" 'pixel704',\n",
" 'pixel705',\n",
" 'pixel706',\n",
" 'pixel707',\n",
" 'pixel708',\n",
" 'pixel709',\n",
" 'pixel710',\n",
" 'pixel711',\n",
" 'pixel712',\n",
" 'pixel713',\n",
" 'pixel714',\n",
" 'pixel715',\n",
" 'pixel716',\n",
" 'pixel717',\n",
" 'pixel718',\n",
" 'pixel719',\n",
" 'pixel720',\n",
" 'pixel721',\n",
" 'pixel722',\n",
" 'pixel723',\n",
" 'pixel724',\n",
" 'pixel725',\n",
" 'pixel726',\n",
" 'pixel727',\n",
" 'pixel728',\n",
" 'pixel729',\n",
" 'pixel730',\n",
" 'pixel731',\n",
" 'pixel732',\n",
" 'pixel733',\n",
" 'pixel734',\n",
" 'pixel735',\n",
" 'pixel736',\n",
" 'pixel737',\n",
" 'pixel738',\n",
" 'pixel739',\n",
" 'pixel740',\n",
" 'pixel741',\n",
" 'pixel742',\n",
" 'pixel743',\n",
" 'pixel744',\n",
" 'pixel745',\n",
" 'pixel746',\n",
" 'pixel747',\n",
" 'pixel748',\n",
" 'pixel749',\n",
" 'pixel750',\n",
" 'pixel751',\n",
" 'pixel752',\n",
" 'pixel753',\n",
" 'pixel754',\n",
" 'pixel755',\n",
" 'pixel756',\n",
" 'pixel757',\n",
" 'pixel758',\n",
" 'pixel759',\n",
" 'pixel760',\n",
" 'pixel761',\n",
" 'pixel762',\n",
" 'pixel763',\n",
" 'pixel764',\n",
" 'pixel765',\n",
" 'pixel766',\n",
" 'pixel767',\n",
" 'pixel768',\n",
" 'pixel769',\n",
" 'pixel770',\n",
" 'pixel771',\n",
" 'pixel772',\n",
" 'pixel773',\n",
" 'pixel774',\n",
" 'pixel775',\n",
" 'pixel776',\n",
" 'pixel777',\n",
" 'pixel778',\n",
" 'pixel779',\n",
" 'pixel780',\n",
" 'pixel781',\n",
" 'pixel782',\n",
" 'pixel783',\n",
" 'pixel784'],\n",
" 'DESCR': \"**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges \\n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown \\n**Please cite**: \\n\\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples \\n\\nIt is a subset of a larger set available from NIST. The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field. \\n\\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. 
Therefore it was necessary to build a new database by mixing NIST's datasets. \\n\\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\\n\\nDownloaded from openml.org.\",\n",
" 'details': {'id': '554',\n",
" 'name': 'mnist_784',\n",
" 'version': '1',\n",
" 'format': 'ARFF',\n",
" 'upload_date': '2014-09-29T03:28:38',\n",
" 'licence': 'Public',\n",
" 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff',\n",
" 'file_id': '52667',\n",
" 'default_target_attribute': 'class',\n",
" 'tag': ['AzurePilot',\n",
" 'OpenML-CC18',\n",
" 'OpenML100',\n",
" 'study_1',\n",
" 'study_123',\n",
" 'study_41',\n",
" 'study_99',\n",
" 'vision'],\n",
" 'visibility': 'public',\n",
" 'status': 'active',\n",
" 'processing_date': '2018-10-03 21:23:30',\n",
" 'md5_checksum': '0298d579eb1b86163de7723944c7e495'},\n",
" 'categories': {},\n",
" 'url': 'https://www.openml.org/d/554'}"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.datasets import fetch_openml\n",
"\n",
"# Download the 'mnist_784' dataset (version 1) from OpenML and show what it contains.\n",
"dataset_name = 'mnist_784'\n",
"mnist = fetch_openml(name=dataset_name, version=1)\n",
"mnist"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Newer scikit-learn versions may return pandas objects from fetch_openml (as_frame='auto'),\n",
"# which would break the positional row indexing (X[2000]) and .reshape used below.\n",
"# np.asarray is a no-op for the plain ndarrays older versions return.\n",
"X = np.asarray(mnist['data'])\n",
"y = np.asarray(mnist['target'])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70000, 784)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(70000,)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y[1]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAABmxJREFUeJzt3UuIzf8fx/H/0WSUadSklA2SlWRjWMkkC1mhZiHKzk4WiqVSctuRhZ1kQzaKhGKWNkKUSS5ZUMICuSSa/+q3UHPeJ3PmnJk5r8djOa++53tinn3Lp3M0JiYm/gf0vnkz/QaA7hA7hBA7hBA7hBA7hOjr8v380z90XmOyH3qyQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQ4i+mX4DtO/ixYtNtzt37pTXjoyMlPuRI0fK/e3bt+U+PDzcdDt//nx57Y8fP8p94cKF5b527dpyT+PJDiHEDiHEDiHEDiHEDiHEDiHEDiEaExMT3bxfV2/WK0ZHR8v9xo0bTbefP39O99v5S6vfnxUrVjTdPn/+XF67adOmKb2n/+zbt6/pNjY2Vl574sSJtu49wxqT/dCTHUKIHUKIHUKIHUKIHUKIHUL4iOsccPXq1XJvNCY9aemKVh8j3b59e9Pt9u3b5bXXr18v9/7+/nJfvnx5021oaKi8thd5skMIsUMIsUMIsUMIsUMIsUMIsUMI5+xzwODgYLl//fp1yq+9YcOGcj98+HC5r1mzptwHBgaabjt27CivHR8fL/dWvn//3nRr589srvJkhxBihxBihxBihxBihxBihxBihxC+SnoOaHXWffr06Sm/9v3798t9/fr1U35tZoyvkoZkYocQYocQYocQYocQYocQYocQPs8+C1y6dKncT5061bF737t3r9yds/cOT3YIIXYIIXYIIXYIIXYIIXYIIXYI4Zy9C37//l3ud+/ebev1q/+fvdX3wh88eLCtezN3eLJDCLFDCLFDCLFDCLFDCLFDCEdvXXDz5s1yv3DhQluvv23btqbbuXPnymv7+vwKpPBkhxBihxBihxBihxBihxBihxBihxAOWafBnz9/yv3o0aMdvf/WrVubbsuWLevovZk7PNkhhNghhNghhNghhNghhNghhNghhHP2afDq1atyf/DgQUfvPzIy0tHXpzd4skMIsUMIsUMIsUMIsUMIsUMIsUMI5+zTYMmSJeW+cuXKcn/x4kVb91+6dGlb15PBkx1CiB1CiB1CiB1CiB1CiB1CiB1COGefBoODg+W+ePHicn/58mW5tzpHf/r0adOtv7+/vLZdr1+/Lvfh4eGm28ePH8trV69eXe4DAwPlzt882SGE2CGE2CGE2CGE2CGE2CFEY2Jiopv36+rNZouzZ8+W+4EDB8q91d9Ro9H45/c0XVq9t0WLFjXdvnz5Ul67ZcuWct+9e3e57927t9x72KS/EJ7sEELsEELsEELsEELsEELsEELsEMI5exd8+/at3NetW1fu4+Pj5T6bz9ln8r1duXKl6bZz587y2nnz5vRz0Dk7JBM7hBA7hBA7hBA7hBA7hBA7hHDOPgu8e/eu3Pfs2VPuY2Nj0/hu/k2rr7nevHlz0+3JkyfltY8fP57Se/pP9bv96dOn8tqhoaG27j3DnLNDMrFDCLFDCLFDCLFDCLFDCLFDCOfsPa7Vd7PfunWr3EdHR6fz7fyTVp/j37VrV7k/evSo6XbmzJny2v3795f7LOecHZKJHUKIHUKIHUKIHUKIHUL0zfQboLMGBwfLfSaP1lp5+PBhuT979mzKrz08PDzla+cqT3YIIXYIIXYIIXYIIXYIIXYIIXYI4Zydjvrw4UPT7dChQ+W1ly9fLv
dfv36Ve/URWOfsQM8SO4QQO4QQO4QQO4QQO4QQO4TwVdKUrl27Vu7Hjx8v9zdv3jTd3r9/X167YMGCcm/1dc8nT54s9x7mq6QhmdghhNghhNghhNghhNghhNghhM+zU3r+/Hm5r1q1qtznz58/5WuPHTtW7hs3bix3/ubJDiHEDiHEDiHEDiHEDiHEDiHEDiF8nh16j8+zQzKxQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQ4hu/5fNk37FLdB5nuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQ4v/GxfNRZYW5kQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Pick one sample and view its 784 pixel values as the 28x28 image they encode.\n",
"some_digit = X[2000]\n",
"some_digit_image = some_digit.reshape(28, 28)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation='nearest')\n",
"ax.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'9'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Label of the digit plotted above (some_digit = X[2000]); the index must match it.\n",
"# NOTE(review): the recorded output was produced by y[36000] and may be stale.\n",
"y[2000]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Seed the permutation so the train/test split is reproducible across kernel restarts.\n",
"# NOTE(review): this shuffles all samples together, mixing MNIST's official 10,000-image\n",
"# test set into training. The dataset description says the official split (first 60,000 /\n",
"# last 10,000) keeps writers disjoint; consider X[:60000] / X[60000:] and shuffling only\n",
"# the training part.\n",
"SEED = 42\n",
"shuffle_index = np.random.RandomState(SEED).permutation(len(X))\n",
"X_train, y_train = X[shuffle_index[:60000]], y[shuffle_index[:60000]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([65096, 18196, 6131, ..., 24452, 19983, 45538])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"shuffle_index"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# The last 10,000 shuffled indices form the held-out test set.\n",
"test_index = shuffle_index[60000:70000]\n",
"X_test, y_test = X[test_index], y[test_index]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'9'"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_test[100]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['2', '4', '3', ..., '4', '4', '2'], dtype=object)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Boolean target for the binary '5-detector': True where the label is the string '5'.\n",
"target_digit = '5'\n",
"y_train_5 = y_train == target_digit"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"5430"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of positive ('5') training samples, counted with a vectorized sum\n",
"# rather than a Python-level list comprehension over 60,000 booleans.\n",
"int(y_train_5.sum())"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Labels are strings, so compare against '5'; comparing with the int 5 is False for every sample.\n",
"y_test_5 = (y_test == '5')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\pes94\\Anaconda3\\lib\\site-packages\\sklearn\\linear_model\\stochastic_gradient.py:183: FutureWarning: max_iter and tol parameters have been added in SGDClassifier in 0.19. If max_iter is set but tol is left unset, the default value for tol in 0.19 and 0.20 will be None (which is equivalent to -infinity, so it has no effect) but will change in 0.21 to 1e-3. Specify tol to silence this warning.\n",
" FutureWarning)\n"
]
},
{
"data": {
"text/plain": [
"SGDClassifier(alpha=0.0001, average=False, class_weight=None,\n",
" early_stopping=False, epsilon=0.1, eta0=0.0, fit_intercept=True,\n",
" l1_ratio=0.15, learning_rate='optimal', loss='hinge', max_iter=5,\n",
" n_iter=None, n_iter_no_change=5, n_jobs=None, penalty='l2',\n",
" power_t=0.5, random_state=42, shuffle=True, tol=None,\n",
" validation_fraction=0.1, verbose=0, warm_start=False)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"# Binary '5 vs. not-5' linear classifier. The FutureWarning recorded for this cell is\n",
"# emitted because max_iter is set while tol is left unset.\n",
"sgd_clf = SGDClassifier(random_state=42, max_iter=5)\n",
"sgd_clf.fit(X=X_train, y=y_train_5)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAABmxJREFUeJzt3UuIzf8fx/H/0WSUadSklA2SlWRjWMkkC1mhZiHKzk4WiqVSctuRhZ1kQzaKhGKWNkKUSS5ZUMICuSSa/+q3UHPeJ3PmnJk5r8djOa++53tinn3Lp3M0JiYm/gf0vnkz/QaA7hA7hBA7hBA7hBA7hOjr8v380z90XmOyH3qyQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQ4i+mX4DtO/ixYtNtzt37pTXjoyMlPuRI0fK/e3bt+U+PDzcdDt//nx57Y8fP8p94cKF5b527dpyT+PJDiHEDiHEDiHEDiHEDiHEDiHEDiEaExMT3bxfV2/WK0ZHR8v9xo0bTbefP39O99v5S6vfnxUrVjTdPn/+XF67adOmKb2n/+zbt6/pNjY2Vl574sSJtu49wxqT/dCTHUKIHUKIHUKIHUKIHUKIHUL4iOsccPXq1XJvNCY9aemKVh8j3b59e9Pt9u3b5bXXr18v9/7+/nJfvnx5021oaKi8thd5skMIsUMIsUMIsUMIsUMIsUMIsUMI5+xzwODgYLl//fp1yq+9YcOGcj98+HC5r1mzptwHBgaabjt27CivHR8fL/dWvn//3nRr589srvJkhxBihxBihxBihxBihxBihxBihxC+SnoOaHXWffr06Sm/9v3798t9/fr1U35tZoyvkoZkYocQYocQYocQYocQYocQYocQPs8+C1y6dKncT5061bF737t3r9yds/cOT3YIIXYIIXYIIXYIIXYIIXYIIXYI4Zy9C37//l3ud+/ebev1q/+fvdX3wh88eLCtezN3eLJDCLFDCLFDCLFDCLFDCLFDCEdvXXDz5s1yv3DhQluvv23btqbbuXPnymv7+vwKpPBkhxBihxBihxBihxBihxBihxBihxAOWafBnz9/yv3o0aMdvf/WrVubbsuWLevovZk7PNkhhNghhNghhNghhNghhNghhNghhHP2afDq1atyf/DgQUfvPzIy0tHXpzd4skMIsUMIsUMIsUMIsUMIsUMIsUMI5+zTYMmSJeW+cuXKcn/x4kVb91+6dGlb15PBkx1CiB1CiB1CiB1CiB1CiB1CiB1COGefBoODg+W+ePHicn/58mW5tzpHf/r0adOtv7+/vLZdr1+/Lvfh4eGm28ePH8trV69eXe4DAwPlzt882SGE2CGE2CGE2CGE2CGE2CFEY2Jiopv36+rNZouzZ8+W+4EDB8q91d9Ro9H45/c0XVq9t0WLFjXdvnz5Ul67ZcuWct+9e3e57927t9x72KS/EJ7sEELsEELsEELsEELsEELsEELsEMI5exd8+/at3NetW1fu4+Pj5T6bz9ln8r1duXKl6bZz587y2nnz5vRz0Dk7JBM7hBA7hBA7hBA7hBA7hBA7hHDOPgu8e/eu3Pfs2VPuY2Nj0/hu/k2rr7nevHlz0+3JkyfltY8fP57Se/pP9bv96dOn8tqhoaG27j3DnLNDMrFDCLFDCLFDCLFDCLFDCLFDCOfsPa7Vd7PfunWr3EdHR6fz7fyTVp/j37VrV7k/evSo6XbmzJny2v3795f7LOecHZKJHUKIHUKIHUKIHUKIHUL0zfQboLMGBwfLfSaP1lp5+PBhuT979mzKrz08PDzla+cqT3YIIXYIIXYIIXYIIXYIIXYIIXYI4Zydjvrw4UPT7dChQ+W1ly9fLv
dfv36Ve/URWOfsQM8SO4QQO4QQO4QQO4QQO4QQO4TwVdKUrl27Vu7Hjx8v9zdv3jTd3r9/X167YMGCcm/1dc8nT54s9x7mq6QhmdghhNghhNghhNghhNghhNghhM+zU3r+/Hm5r1q1qtznz58/5WuPHTtW7hs3bix3/ubJDiHEDiHEDiHEDiHEDiHEDiHEDiF8nh16j8+zQzKxQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQ4hu/5fNk37FLdB5nuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQ4v/GxfNRZYW5kQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# NOTE(review): this repeats the earlier cell that plots X[2000]; consider deleting one copy.\n",
"some_digit = X[2000]\n",
"some_digit_image = some_digit.reshape(28, 28)\n",
"\n",
"fig, ax = plt.subplots()\n",
"ax.imshow(some_digit_image, cmap=matplotlib.cm.binary, interpolation='nearest')\n",
"ax.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([False])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sgd_clf.predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"\n",
"# Silence only FutureWarnings (e.g. the SGDClassifier max_iter/tol notice) instead of\n",
"# every warning, so genuinely new problems still surface.\n",
"warnings.filterwarnings(action='ignore', category=FutureWarning)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.9574 , 0.9643 , 0.96545])"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.model_selection import cross_val_score\n",
"\n",
"# 3-fold cross-validated accuracy of the binary 5-detector.\n",
"n_folds = 3\n",
"cross_val_score(estimator=sgd_clf, X=X_train, y=y_train_5, cv=n_folds, scoring='accuracy')"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.base import BaseEstimator\n",
"\n",
"\n",
"class Never5Classifier(BaseEstimator):\n",
"    \"\"\"Dumb baseline that predicts 'not 5' (False) for every sample.\"\"\"\n",
"\n",
"    def fit(self, X, y=None):\n",
"        \"\"\"Learn nothing; return self as scikit-learn estimators are expected to.\"\"\"\n",
"        return self\n",
"\n",
"    def predict(self, X):\n",
"        \"\"\"Return a 1-D boolean array of False, one entry per row of X.\"\"\"\n",
"        # 1-D output matches the shape of the target y that scikit-learn metrics compare against.\n",
"        return np.zeros(len(X), dtype=bool)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.91 , 0.9063, 0.9122])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"never_5_clf = Never5Classifier()\n",
"\n",
"# Accuracy is high even for a model that never predicts 5, because only about 9% of\n",
"# the training labels are 5 -- accuracy alone is misleading on skewed classes.\n",
"cross_val_score(estimator=never_5_clf, X=X_train, y=y_train_5, scoring='accuracy', cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import cross_val_predict\n",
"\n",
"# Out-of-fold predictions: each training sample is predicted by a fold model that did not train on it.\n",
"y_train_pred = cross_val_predict(estimator=sgd_clf, X=X_train, y=y_train_5, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[53702, 868],\n",
" [ 1389, 4041]], dtype=int64)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Rows are actual classes (not-5, 5); columns are predicted classes.\n",
"confusion_matrix(y_true=y_train_5, y_pred=y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import roc_curve\n",
"\n",
"# method='decision_function' returns a raw score per sample instead of a class;\n",
"# comparing that score with a threshold is what decides the class.\n",
"y_scores = cross_val_predict(\n",
"    sgd_clf, X_train, y_train_5, cv=3, method='decision_function'\n",
")\n",
"\n",
"# roc_curve returns the false/true positive rates and the thresholds that produce them.\n",
"fpr, tpr, thresholds = roc_curve(y_true=y_train_5, y_score=y_scores)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def plot_roc_curve(fpr, tpr, label=None):\n",
"    \"\"\"Plot an ROC curve on the current axes, with the chance diagonal for reference.\n",
"\n",
"    fpr, tpr: false/true positive rates, e.g. as returned by roc_curve.\n",
"    label: legend label for the curve (visible only if plt.legend() is called).\n",
"    \"\"\"\n",
"    plt.plot(fpr, tpr, label=label, linewidth=2)\n",
"    plt.plot([0, 1], [0, 1], 'k--')  # random-guess baseline\n",
"    plt.axis([0, 1, 0, 1])\n",
"    plt.xlabel('FPR')\n",
"    plt.ylabel('TPR')\n",
"\n",
"\n",
"plot_roc_curve(fpr, tpr)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"\n",
"# RandomForestClassifier has no decision_function, so collect class probabilities instead.\n",
"forest_clf = RandomForestClassifier(random_state=42)\n",
"y_probas_forest = cross_val_predict(\n",
"    forest_clf, X_train, y_train_5, cv=3, method='predict_proba'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"# predict_proba columns follow forest_clf.classes_ (presumably [False, True] here),\n",
"# so column 1 is the positive-class probability; use it as the ROC score.\n",
"y_scores_forest = y_probas_forest[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"# Same ROC computation as for SGD, scored by the forest's positive-class probabilities.\n",
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_true=y_train_5, y_score=y_scores_forest)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Overlay both ROC curves: SGD as a dotted blue line, Random Forest via plot_roc_curve.\n",
"plt.plot(fpr, tpr, 'b:', label='SGD')\n",
"plot_roc_curve(fpr_forest, tpr_forest, label='Random Forest')\n",
"plt.legend(loc='lower right')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['8'], dtype='<U1')"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Refit sgd_clf on the full digit labels (multiclass) and classify one sample.\n",
"# NOTE(review): some_digit must already exist; it is (re)assigned in a later cell.\n",
"sgd_clf.fit(X_train, y_train).predict([some_digit])"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAABmxJREFUeJzt3UuIzf8fx/H/0WSUadSklA2SlWRjWMkkC1mhZiHKzk4WiqVSctuRhZ1kQzaKhGKWNkKUSS5ZUMICuSSa/+q3UHPeJ3PmnJk5r8djOa++53tinn3Lp3M0JiYm/gf0vnkz/QaA7hA7hBA7hBA7hBA7hOjr8v380z90XmOyH3qyQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQ4i+mX4DtO/ixYtNtzt37pTXjoyMlPuRI0fK/e3bt+U+PDzcdDt//nx57Y8fP8p94cKF5b527dpyT+PJDiHEDiHEDiHEDiHEDiHEDiHEDiEaExMT3bxfV2/WK0ZHR8v9xo0bTbefP39O99v5S6vfnxUrVjTdPn/+XF67adOmKb2n/+zbt6/pNjY2Vl574sSJtu49wxqT/dCTHUKIHUKIHUKIHUKIHUKIHUL4iOsccPXq1XJvNCY9aemKVh8j3b59e9Pt9u3b5bXXr18v9/7+/nJfvnx5021oaKi8thd5skMIsUMIsUMIsUMIsUMIsUMIsUMI5+xzwODgYLl//fp1yq+9YcOGcj98+HC5r1mzptwHBgaabjt27CivHR8fL/dWvn//3nRr589srvJkhxBihxBihxBihxBihxBihxBihxC+SnoOaHXWffr06Sm/9v3798t9/fr1U35tZoyvkoZkYocQYocQYocQYocQYocQYocQPs8+C1y6dKncT5061bF737t3r9yds/cOT3YIIXYIIXYIIXYIIXYIIXYIIXYI4Zy9C37//l3ud+/ebev1q/+fvdX3wh88eLCtezN3eLJDCLFDCLFDCLFDCLFDCLFDCEdvXXDz5s1yv3DhQluvv23btqbbuXPnymv7+vwKpPBkhxBihxBihxBihxBihxBihxBihxAOWafBnz9/yv3o0aMdvf/WrVubbsuWLevovZk7PNkhhNghhNghhNghhNghhNghhNghhHP2afDq1atyf/DgQUfvPzIy0tHXpzd4skMIsUMIsUMIsUMIsUMIsUMIsUMI5+zTYMmSJeW+cuXKcn/x4kVb91+6dGlb15PBkx1CiB1CiB1CiB1CiB1CiB1CiB1COGefBoODg+W+ePHicn/58mW5tzpHf/r0adOtv7+/vLZdr1+/Lvfh4eGm28ePH8trV69eXe4DAwPlzt882SGE2CGE2CGE2CGE2CGE2CFEY2Jiopv36+rNZouzZ8+W+4EDB8q91d9Ro9H45/c0XVq9t0WLFjXdvnz5Ul67ZcuWct+9e3e57927t9x72KS/EJ7sEELsEELsEELsEELsEELsEELsEMI5exd8+/at3NetW1fu4+Pj5T6bz9ln8r1duXKl6bZz587y2nnz5vRz0Dk7JBM7hBA7hBA7hBA7hBA7hBA7hHDOPgu8e/eu3Pfs2VPuY2Nj0/hu/k2rr7nevHlz0+3JkyfltY8fP57Se/pP9bv96dOn8tqhoaG27j3DnLNDMrFDCLFDCLFDCLFDCLFDCLFDCOfsPa7Vd7PfunWr3EdHR6fz7fyTVp/j37VrV7k/evSo6XbmzJny2v3795f7LOecHZKJHUKIHUKIHUKIHUKIHUL0zfQboLMGBwfLfSaP1lp5+PBhuT979mzKrz08PDzla+cqT3YIIXYIIXYIIXYIIXYIIXYIIXYI4Zydjvrw4UPT7dChQ+W1ly9fLv
dfv36Ve/URWOfsQM8SO4QQO4QQO4QQO4QQO4QQO4TwVdKUrl27Vu7Hjx8v9zdv3jTd3r9/X167YMGCcm/1dc8nT54s9x7mq6QhmdghhNghhNghhNghhNghhNghhM+zU3r+/Hm5r1q1qtznz58/5WuPHTtW7hs3bix3/ubJDiHEDiHEDiHEDiHEDiHEDiHEDiF8nh16j8+zQzKxQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQwixQ4hu/5fNk37FLdB5nuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQQuwQ4v/GxfNRZYW5kQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Display one sample from X as a 28x28 image.\n",
"some_digit = X[2000]\n",
"some_digit_image = some_digit.reshape(28, 28)\n",
"plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,\n",
"           interpolation='nearest')\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.88071789, 0.87890605, 0.85382076])"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Multiclass accuracy of the SGD classifier, estimated with 3-fold cross-validation.\n",
"cross_val_score(sgd_clf, X_train, y_train, scoring='accuracy', cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# Out-of-fold multiclass predictions. NOTE: this overwrites the binary y_train_pred above.\n",
"y_train_pred = cross_val_predict(estimator=sgd_clf, X=X_train, y=y_train, cv=3)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# Rows = actual digit, columns = predicted digit.\n",
"conf_mx = confusion_matrix(y_true=y_train, y_pred=y_train_pred)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[5767, 3, 15, 14, 19, 24, 27, 2, 22, 2],\n",
" [ 1, 6493, 115, 10, 8, 20, 9, 18, 62, 39],\n",
" [ 91, 63, 5314, 93, 72, 28, 84, 78, 162, 13],\n",
" [ 77, 32, 245, 5182, 17, 309, 25, 69, 119, 84],\n",
" [ 32, 24, 40, 8, 5402, 27, 41, 26, 71, 175],\n",
" [ 145, 37, 50, 244, 82, 4458, 92, 39, 211, 72],\n",
" [ 64, 42, 102, 1, 63, 132, 5422, 3, 41, 3],\n",
" [ 44, 50, 119, 38, 115, 28, 7, 5573, 35, 212],\n",
" [ 117, 248, 185, 223, 108, 379, 48, 40, 4331, 158],\n",
" [ 76, 66, 37, 120, 493, 166, 1, 487, 193, 4327]],\n",
" dtype=int64)"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Raw counts; the diagonal holds the correctly classified samples.\n",
"conf_mx"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x15f3c819940>"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAECCAYAAADesWqHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAACypJREFUeJzt3c+LXfUZx/HPZ36YcZL6I7YbM6FRqLYqloShqAEXxkVbRTdZWIhQN9m0GkWQ2I3/gIguijDEutCgwphFEbEWVLCb0DEJjHEs/myMRkwJ1cSFM5N5upgJqEnnntH73DPX5/0CITPefHm4ue85594593sdEQJQy0DbAwDoPcIHCiJ8oCDCBwoifKAgwgcKai1827+2/S/b79re3dYcTdneaPtV2zO2D9ve1fZMTdgetH3Q9gttz9KE7YtsT9p+e+m+vr7tmTqxfd/SY+JN28/YHml7pk5aCd/2oKQ/S/qNpKsk/c72VW3MsgLzku6PiF9Iuk7SH/pgZknaJWmm7SFW4DFJL0XEzyX9Uqt8dtsbJN0jaTwirpE0KOmOdqfqrK0j/q8kvRsR70fErKRnJd3e0iyNRMSxiDiw9OeTWnxAbmh3quXZHpN0i6Q9bc/ShO0LJN0o6QlJiojZiPhvu1M1MiTpfNtDkkYlfdLyPB21Ff4GSR997eujWuURfZ3tTZI2S9rf7iQdPSrpAUkLbQ/S0OWSjkt6cunpyR7ba9seajkR8bGkhyUdkXRM0ucR8XK7U3XWVvg+x/f64tph2+skPS/p3oj4ou15/h/bt0r6LCLeaHuWFRiStEXS4xGxWdKXklb16z+2L9bi2eplki6VtNb2jnan6qyt8I9K2vi1r8fUB6dHtoe1GP3eiNjX9jwdbJV0m+0PtfhU6ibbT7c7UkdHJR2NiDNnUpNa/EGwmt0s6YOIOB4Rc5L2Sbqh5Zk6aiv8f0r6me3LbJ+nxRdD/trSLI3Ythafe85ExCNtz9NJRDwYEWMRsUmL9+8rEbGqj0QR8amkj2xfufStbZLeanGkJo5Ius726NJjZJtW+QuS0uKpVc9FxLztP0r6mxZfBf1LRBxuY5YV2CrpTknTtg8tfe9PEfFiizP9EN0tae/SAeF9SXe1PM+yImK/7UlJB7T4m5+Dkibanaoz87ZcoB6u3AMKInygIMIHCiJ8oCDCBwpqPXzbO9ueYSX6bV6JmXuh3+ZtPXxJfXWHqf/mlZi5F/pq3tUQPoAeS7mAZ/369TE2NtbotidOnND69esb3XZ6evr7jAWUEBHnehPcN6Rcsjs2NqYXX+z+lawbN27sfCOsWouXsufgCtSV4VQfKIjwgYIIHyiI8IGCCB8oqFH4/bYHPoDldQy/T/fAB7CMJkf8vtsDH8DymoTf13vgAzhbk/Ab7YFve6ftKdtTJ06c+P6TAUjTJPxGe+BHxEREjEfEeNNr7wG0o0n4fbcHPoDldXyTTp/ugQ9gGY3enbf0oRF8cATwA8GVe0BBhA8URPhAQYQPFET4QEEpm23aTtkALXNfteHh4ZR15+fnU9ZFbwwM5B0bMx7PEdFos02O+EBBhA8URPhAQYQPFET4QEGEDxRE+EBBhA8URPhAQYQPFET4QEGEDxRE+EBBhA8URPhAQYQPFET4QEGEDxRE+EBBhA8URPhAQYQPFNToQzO/i8HBwb5Y84yDBw+mrLtly5aUdaW87caz1s3cqjpL5mNudnY2be1O+u9fAsD3RvhAQYQPFET4QEGEDxRE+EBBhA8U1DF82xttv2p7xvZh27t6MRiAPE0u4JmXdH9EHLD9I0lv2P57RLyVPBuAJB2P+BFxLCIOLP35pKQZSRuyBwOQZ0XP8W1vkrRZ0v6MYQD0RuNr9W2vk/S8pHsj4otz/P+dknZ2cTYASRqFb3tYi9HvjYh957pNRExImli6fc67PAB0RZNX9S3pCUkzEfFI/kgAsjV5jr9V0p
2SbrJ9aOm/3ybPBSBRx1P9iPiHJPdgFgA9wpV7QEGEDxRE+EBBhA8URPhAQc7YUdV2ZOyomrX7qySNjIykrPv666+nrCtJ4+PjKeuuW7cuZd1Tp06lrCvl7eCbuTNwxuP59OnTioiOv4XjiA8URPhAQYQPFET4QEGEDxRE+EBBhA8URPhAQYQPFET4QEGEDxRE+EBBhA8URPhAQYQPFET4QEGEDxRE+EBBhA8URPhAQYQPFET4QEFp22svfrp2d2Vur50xr5Q78/T0dMq61157bcq6mbLu58zttYeGOn5m7YrNzc1pYWGB7bUBnI3wgYIIHyiI8IGCCB8oiPCBgggfKKhx+LYHbR+0/ULmQADyreSIv0vSTNYgAHqnUfi2xyTdImlP7jgAeqHpEf9RSQ9IWkicBUCPdAzf9q2SPouINzrcbqftKdtTXZsOQIomR/ytkm6z/aGkZyXdZPvpb98oIiYiYjwixrs8I4Au6xh+RDwYEWMRsUnSHZJeiYgd6ZMBSMPv8YGCVvSG4Ih4TdJrKZMA6BmO+EBBhA8URPhAQYQPFET4QEEpu+wODAxExg6i8/PzXV8z28jISNras7OzKetOTk6mrLt9+/aUdSVpYSHnavI1a9akrCvl/PstLCwoIthlF8DZCB8oiPCBgggfKIjwgYIIHyiI8IGCCB8oiPCBgggfKIjwgYIIHyiI8IGCCB8oiPCBgggfKIjwgYIIHyiI8IGCCB8oiPCBglJ22bUdAwPd/5mSMesZGfNK0unTp1PWlfJmHhwcTFn30KFDKetK0tVXX52yrt1xw9rvLOvxzC67AM6J8IGCCB8oiPCBgggfKIjwgYIIHyioUfi2L7I9aftt2zO2r88eDECepp9l/ZiklyJiu+3zJI0mzgQgWcfwbV8g6UZJv5ekiJiVlPPB7AB6osmp/uWSjkt60vZB23tsr02eC0CiJuEPSdoi6fGI2CzpS0m7v30j2zttT9me6vKMALqsSfhHJR2NiP1LX09q8QfBN0TERESMR8R4NwcE0H0dw4+ITyV9ZPvKpW9tk/RW6lQAUjV9Vf9uSXuXXtF/X9JdeSMByNYo/Ig4JIlTeOAHgiv3gIIIHyiI8IGCCB8oiPCBgggfKChte+3MbYkzZG0pnbm9dr/NnLk9+jvvvJOy7hVXXJGyriStWbOm62t+9dVXWlhYYHttAGcjfKAgwgcKInygIMIHCiJ8oCDCBwoifKAgwgcKInygIMIHCiJ8oCDCBwoifKAgwgcKInygIMIHCiJ8oCDCBwoifKAgwgcKSttld2io6QfxNjc8PNz1Nc9YWFhIWTdzt+H5+fmUdUdHR1PWPXnyZMq6mZ577rm0tXfs2NH1Nefm5thlF8C5ET5QEOEDBRE+UBDhAwURPlAQ4QMFNQrf9n22D9t+0/YztkeyBwOQp2P4tjdIukfSeERcI2lQ0h3ZgwHI0/RUf0jS+baHJI1K+iRvJADZOoYfER9LeljSEUnHJH0eES9nDwYgT5NT/Ysl3S7pMkmXSlpr+6yLjG3vtD1le6r7YwLopian+jdL+iAijkfEnKR9km749o0iYiIixiNivNtDAuiuJuEfkXSd7VEvvtVsm6SZ3LEAZGryHH+/pElJByRNL/2dieS5ACRq9Kb5iHhI0kPJswDoEa7cAwoifKAgwgcKInygIMIHCiJ8oKC07bUHBvrrZ0rG/SBJg4ODKetK0oUXXpiybtY22LOzsynrStIll1ySsu6pU6dS1pWkp556qutr7t69W++99x7bawM4G+EDBRE+UBDhAwURPlAQ4QMFET5QEOEDBRE+UBDhAwURPlAQ4QMFET5QEOEDBRE+UBDhAwURPlAQ4QMFET5QEOEDBRE+UFDWLrvHJf274c1/LOk/XR8iT7/NKzFzL6yWeX8aET/pdKOU8FfC9lREjLc6xAr027wSM/dCv83LqT5QEOEDBa2G8CfaHmCF+m1eiZl7oa/mbf05PoDeWw1HfAA9RvhAQY
QPFET4QEGEDxT0P6R8uVZSmUT9AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 288x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Brighter cell = more samples; a bright diagonal means most digits are classified correctly.\n",
"plt.matshow(conf_mx, cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"# Zero the diagonal in place so only misclassifications remain.\n",
"# NOTE: conf_mx itself is modified; later cells that reuse it see no correct predictions.\n",
"conf_mx[np.diag_indices_from(conf_mx)] = 0"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x15f3cea2898>"
]
},
"execution_count": 67,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAECCAYAAADesWqHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADA9JREFUeJzt3V+InfWZwPHvk5lJJjGGFhUhf3Asmu7WgFrG1TbQC83Fui3tzYoWLGxvgrCmthRK603wvpT2YqmEdHtTaZDUC6lr7UpbYQWTHaPQmnSh2qxJakgiJI0JIZnM04uZgHXdnHfk/Oad0+f7ASEzvj48DvOd98zJe94TmYmkWlb0vYCkpWf4UkGGLxVk+FJBhi8VZPhSQb2FHxH/GBH/ExF/iIhv97VHVxGxKSJ+HRGHIuKNiHis7526iIixiHgtIn7e9y5dRMTHImJvRPx+4Wv9mb53GiQivrHwPfG7iPhpREz2vdMgvYQfEWPAvwH3A58CvhwRn+pjl0WYBb6ZmX8P3AP86wjsDPAYcKjvJRbhB8AvMvPvgNtZ5rtHxAbga8B0Zm4BxoCH+t1qsL7O+P8A/CEz38rMi8Ae4Es97dJJZr6TmQcW/nyW+W/IDf1udXURsRH4PLC77126iIh1wOeAHwFk5sXMPN3vVp2MA6sjYhxYA/yp530G6iv8DcCR9318lGUe0ftFxBRwJ7Cv300G+j7wLWCu70U6+gRwEvjxwq8nuyPimr6XuprMPAZ8F3gbeAc4k5m/7HerwfoKPz7kcyNx7XBErAV+Bnw9M//c9z7/n4j4AnAiM1/te5dFGAc+DfwwM+8EzgHL+vmfiPg4849WbwbWA9dExMP9bjVYX+EfBTa97+ONjMDDo4iYYD76pzLzmb73GWAr8MWIOMz8r1L3RsRP+l1poKPA0cy88khqL/M/CJazbcAfM/NkZl4CngE+2/NOA/UV/n8Dt0bEzRGxkvknQ57taZdOIiKY/93zUGZ+r+99BsnM72TmxsycYv7r+6vMXNZnosw8DhyJiE8ufOo+4GCPK3XxNnBPRKxZ+B65j2X+hCTMP7Racpk5GxGPAi8w/yzov2fmG33ssghbga8Av42I1xc+93hm/kePO/0t2gE8tXBCeAv4as/7XFVm7ouIvcAB5v/m5zVgV79bDRa+LFeqxyv3pIIMXyrI8KWCDF8qyPClgnoPPyK2973DYozavuDOS2HU9u09fGCkvmCM3r7gzkthpPZdDuFLWmJNLuCJiJG7Kmj+asvBMrPzsaNqOfz/LfbrPDfX5gWIK1Z0Ozd+lO+LVjtn5sBFerlk96Nq+Q25cuXKZrNb6fpNuVhjY2NN5rbaF+Ds2bNN5q5evbrJXIDz588PfWbXE7kP9aWCDF8qyPClggxfKsjwpYI6hT9q98CXdHUDwx/Re+BLuoouZ/yRuwe+pKvrEv5I3wNf0v/V5cq9TvfAX3h10ki9UEGqqkv4ne6Bn5m7WLi76Cheqy9V0uWh/sjdA1/S1Q0844/oPfAlXUWnV+ctvGmEbxwh/Y3wyj2pIMOXCjJ8qSDDlwoyfKmgkbrnXst39p2enm4yd//+/U3mAly+fLnJ3Fb3Nrxw4UKTuQCTk5NN5t5www1N5gK89957Q595+vTpTsd5xpcKMnypIMOXCjJ8qSDDlwoyfKkgw5cKMnypIMOXCjJ8qSDDlwoyfKkgw5cKMnypIMOXCjJ8qSDDlwoyfKkgw5cKMnypIMOXCjJ8qaBoccvqNWvW5ObNm4c+98yZM0OfecXhw4ebzL3jjjuazIU2t2cGOHHiRJO5mzZtajIX2t1qfMuWLU3mAjz77PDfbf7SpUvMzc0NvD+6Z3ypIMOXCjJ8qSDDlwoyfKkgw5cKMnypoIHhR8SmiPh1RByKiDci4rGlWExSO+MdjpkFvpmZByLiWuDViPjPzDzYeDdJjQw842fmO5l5YOHPZ4FDwIbWi0
lqZ1G/40fEFHAnsK/FMpKWRpeH+gBExFrgZ8DXM/PPH/LvtwPbASYmJoa2oKTh63TGj4gJ5qN/KjOf+bBjMnNXZk5n5vT4eOefJ5J60OVZ/QB+BBzKzO+1X0lSa13O+FuBrwD3RsTrC//8U+O9JDU08DF5Zv4XMPD1vZJGh1fuSQUZvlSQ4UsFGb5UkOFLBTW50mZubo4LFy4Mfe7p06eHPvOKnTt3Npn7xBNPNJkLMDk52WTuk08+2WTuo48+2mQuwI033thk7vHjx5vMBbj11luHPvPNN9/sdJxnfKkgw5cKMnypIMOXCjJ8qSDDlwoyfKkgw5cKMnypIMOXCjJ8qSDDlwoyfKkgw5cKMnypIMOXCjJ8qSDDlwoyfKkgw5cKMnypIMOXCorMHPrQ8fHxvPbaa4c+d3Z2dugzr2h1q+qWtwRv9fVYt25dk7mrVq1qMhdocjt3aHfbboC77rpr6DNfeOEF3n333YFvcusZXyrI8KWCDF8qyPClggxfKsjwpYIMXyqoc/gRMRYRr0XEz1suJKm9xZzxHwMOtVpE0tLpFH5EbAQ+D+xuu46kpdD1jP994FvAXMNdJC2RgeFHxBeAE5n56oDjtkfETETMzM3580Fazrqc8bcCX4yIw8Ae4N6I+MkHD8rMXZk5nZnTK1b4lwXScjaw0Mz8TmZuzMwp4CHgV5n5cPPNJDXjqVkqaHwxB2fmb4DfNNlE0pLxjC8VZPhSQYYvFWT4UkGGLxW0qGf1u1q7di1bt24d+txXXnll6DOvmJiYaDL38ccfbzIXYO/evU3mHjx4sMnc22+/vclcgGPHjjWZe//99zeZC/D8888Pfea5c+c6HecZXyrI8KWCDF8qyPClggxfKsjwpYIMXyrI8KWCDF8qyPClggxfKsjwpYIMXyrI8KWCDF8qyPClggxfKsjwpYIMXyrI8KWCDF8qKDJz6ENXrVqV69evH/rcrncQ/ShuuummJnNnZmaazAWYmppqMveWW25pMvfFF19sMhcgIprMvf7665vMBTh16tTQZ2YmmTnwi+EZXyrI8KWCDF8qyPClggxfKsjwpYIMXyqoU/gR8bGI2BsRv4+IQxHxmdaLSWqn69tk/wD4RWb+c0SsBNY03ElSYwPDj4h1wOeAfwHIzIvAxbZrSWqpy0P9TwAngR9HxGsRsTsirmm8l6SGuoQ/Dnwa+GFm3gmcA779wYMiYntEzETEzOXLl4e8pqRh6hL+UeBoZu5b+Hgv8z8I/kpm7srM6cycHhsbG+aOkoZsYPiZeRw4EhGfXPjUfcDBpltJaqrrs/o7gKcWntF/C/hqu5UktdYp/Mx8HZhuvIukJeKVe1JBhi8VZPhSQYYvFWT4UkGGLxXU9e/xF2Vubo7z588Pfe74eJN1AZicnGwy97rrrmsyF2DDhg1N5u7fv7/J3JZXdLa6TLzl5ec7duwY+sw9e/Z0Os4zvlSQ4UsFGb5UkOFLBRm+VJDhSwUZvlSQ4UsFGb5UkOFLBRm+VJDhSwUZvlSQ4UsFGb5UkOFLBRm+VJDhSwUZvlSQ4UsFGb5UUJPb1k5OTrJ58+ahz922bdvQZ15x+PDhJnPvvvvuJnMBXn755SZzn3vuuSZzH3jggSZzAWZnZ5vMPXXqVJO5AEeOHBn6zIsXL3Y6zjO+VJDhSwUZvlSQ4UsFGb5UkOFLBRm+VFCn8CPiGxHxRkT8LiJ+GhFt3lpW0pIYGH5EbAC+Bkxn5hZgDHio9WKS2un6UH8cWB0R48Aa4E/tVpLU2sDwM/MY8F3gbeAd4Exm/rL1YpLa6fJQ/+PAl4CbgfXANRHx8Icctz0iZiJi5tKlS8PfVNLQdHmovw34Y2aezMxLwDPAZz94UGbuyszpzJyemJgY9p6ShqhL+G8D90TEmogI4D7gUNu1JLXU5Xf8fcBe4ADw24X/ZlfjvSQ11On1+Jm5E9jZeBdJS8Qr96SCDF8qyPClggxfKsjwpYIMXyooMnPoQycnJ3Nqam
roc1teCnzmzJkmc2+77bYmcwFeeumlJnMffPDBJnOffvrpJnOh3e21H3nkkSZzAXbv3t1kbmbGoGM840sFGb5UkOFLBRm+VJDhSwUZvlSQ4UsFGb5UkOFLBRm+VJDhSwUZvlSQ4UsFGb5UkOFLBRm+VJDhSwUZvlSQ4UsFGb5UkOFLBTW5y25EnAT+t+Ph1wOnhr5EO6O2L7jzUlgu+96UmTcMOqhJ+IsRETOZOd3rEoswavuCOy+FUdvXh/pSQYYvFbQcwt/V9wKLNGr7gjsvhZHat/ff8SUtveVwxpe0xAxfKsjwpYIMXyrI8KWC/gLRCdLsjvAymQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 288x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# With the diagonal zeroed, bright cells mark frequent confusions (raw counts, not per-class rates).\n",
"plt.matshow(conf_mx, cmap='gray')"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x15f3cef9208>"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP4AAAECCAYAAADesWqHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAADGhJREFUeJzt3VGInfWZx/Hfz5lMkjPZmILFsYlurC6xpVBShsZ0oKDpxWpLe+EaLKRgVXKztTYUatubXgQEodRUKIFo7U2lBVORpa7dKG2FvQkdY6Q1k0XRrk7MYAaxKWPiZMzTizkD1qbzvid9/+ed0+f7ASEzvj48jPPNe86Z97zjiBCAXC5pewEA/Uf4QEKEDyRE+EBChA8kRPhAQq2Fb/vfbf+f7Zdtf7utPeqyfaXt39iesv2i7Xva3qkO20O2n7f9y7Z3qcP2BtsHbR/vfq23t71TFdt7ut8Tf7D9M9tr2t6pSivh2x6S9CNJN0n6uKQv2/54G7v0YEHSNyPiY5Kul/SfA7CzJN0jaartJXrwQ0m/iojrJH1SK3x32xslfV3SeER8QtKQpNva3apaW2f8T0t6OSJeiYh5ST+X9KWWdqklIk5GxJHun/+sxW/Ije1utTzbmyR9XtLDbe9Sh+31kj4r6ceSFBHzEfF2u1vVMixpre1hSR1Jb7S8T6W2wt8o6fX3fTytFR7R+9neLGmrpMPtblJpn6RvSTrf9iI1fVTSKUk/6T49edj2aNtLLSciTkj6vqTXJJ2U9KeIONTuVtXaCt8X+NxAXDtse52kX0j6RkScbnufv8f2FyS9GRHPtb1LD4YlfUrS/ojYKmlO0op+/cf2h7T4aPVqSR+RNGp7V7tbVWsr/GlJV77v400agIdHtldpMfpHI+LxtvepMCHpi7b/qMWnUjfa/mm7K1WaljQdEUuPpA5q8S+Clexzkl6NiFMRcU7S45I+0/JOldoK/3eS/s321bZHtPhiyH+1tEsttq3F555TEfGDtvepEhHfiYhNEbFZi1/fX0fEij4TRcSMpNdtb+l+aoekYy2uVMdrkq633el+j+zQCn9BUlp8aNV3EbFg+2uS/keLr4I+EhEvtrFLDyYkfUXS720f7X7uuxHx3y3u9M/obkmPdk8Ir0j6asv7LCsiDts+KOmIFn/y87ykA+1uVc28LRfIhyv3gIQIH0iI8IGECB9IiPCBhFoP3/butnfoxaDtK7FzPwzavq2HL2mgvmAavH0ldu6Hgdp3JYQPoM+KXMBju8hVQZdcUu7vqXXr1tU6bn5+XiMjI7XnlrxAatWqVbWOe/fdd7V69erG5/ZqaGio9rFzc3MaHa3/xryTJ09ezEqVLrvsslrHnTlzRmvXru1p9uzs7MWstKyIUERc6E1wf6WVS3YvVt04L8bExESRuefOnSsyV5KuuOKKInPHxsaKzF2/fn2RuZJ03333FZm7c+fOInMl6ZFHHml85tmzZ2sdx0N9ICHCBxIifCAhwgcSInwgoVrhD9o98AEsrzL8Ab0HPoBl1DnjD9w98AEsr074A30PfAB/q86Ve7Xugd99d9JAvVEByKpO+LXugR8RB9S9u2ipa/UBNKPOQ/2Buwc+gOVVnvEH9B74AJZR69153V8awS+OAP5JcOUekBDhAwkRPpAQ4QMJET6QULGbbZa4Meb58+cbn7mk1E0xL7/88iJzpXL389u4scwV2TMzM0XmSvVvitmrO+64o8hcSZqcnGx85tNPP6233nqr8mabnPGBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQIH0io1i/N7NWGDRu0Y8eOxueeOHGi8ZlL7Mo7El+UW2+9tchcSZqdnS0y9/jx40Xmbt++vchcSZqfny8yd+3atUXmStKTTz7Z+MwzZ87UOo4zPp
AQ4QMJET6QEOEDCRE+kBDhAwkRPpBQZfi2r7T9G9tTtl+0fU8/FgNQTp0LeBYkfTMijtj+F0nP2X46Io4V3g1AIZVn/Ig4GRFHun/+s6QpSRtLLwagnJ6e49veLGmrpMMllgHQH7Wv1be9TtIvJH0jIk5f4N/vlrRbKnt9M4B/XK0zvu1VWoz+0Yh4/ELHRMSBiBiPiPHVq1c3uSOAhtV5Vd+SfixpKiJ+UH4lAKXVOeNPSPqKpBttH+3+c3PhvQAUVPkcPyL+V1KZN6sDaAVX7gEJET6QEOEDCRE+kBDhAwk5Ihof2ul0YsuWLY3PnZ6ebnzmkvvvv7/I3DvvvLPIXEnqdDpF5j7zzDNF5t58c7mfAm/atKnI3M2bNxeZK0kzMzONzzx27Jjm5uYqfwrHGR9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iI8IGECB9IiPCBhAgfSIjwgYQqf2nmxYgILSwsND533bp1jc9c8tBDDxWZOzo6WmSuJM3NzRWZe/vttxeZe9111xWZK0mzs7NF5o6NjRWZK0mPPfZY4zMnJiZqHccZH0iI8IGECB9IiPCBhAgfSIjwgYQIH0iodvi2h2w/b/uXJRcCUF4vZ/x7JE2VWgRA/9QK3/YmSZ+X9HDZdQD0Q90z/j5J35J0vuAuAPqkMnzbX5D0ZkQ8V3HcbtuTtiffe++9xhYE0Lw6Z/wJSV+0/UdJP5d0o+2ffvCgiDgQEeMRMT40NNTwmgCaVBl+RHwnIjZFxGZJt0n6dUTsKr4ZgGL4OT6QUE/vx4+I30r6bZFNAPQNZ3wgIcIHEiJ8ICHCBxIifCChInfZ7XQ62rp1a+NzDx061PjMJTMzM0XmPvjgg0XmStL+/fuLzJ2cnCwyt+4dYC9GqTsOX3vttUXmStK2bdsan/nSSy/VOo4zPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJET6QUJG77ErSwsJC4zOvueaaxmcuueWWW4rMveuuu4rMlaSdO3cWmbt3794ic2+66aYicyVpeLjMt/ITTzxRZK4kHTt2rPGZdbvjjA8kRPhAQoQPJET4QEKEDyRE+EBChA8kVCt82xtsH7R93PaU7e2lFwNQTt2rHn4o6VcR8R+2RyR1Cu4EoLDK8G2vl/RZSbdLUkTMS5ovuxaAkuo81P+opFOSfmL7edsP2x4tvBeAguqEPyzpU5L2R8RWSXOSvv3Bg2zvtj1pe/Ls2bMNrwmgSXXCn5Y0HRGHux8f1OJfBH8lIg5ExHhEjK9Zs6bJHQE0rDL8iJiR9LrtLd1P7ZDU/NuKAPRN3Vf175b0aPcV/VckfbXcSgBKqxV+RByVNF54FwB9wpV7QEKEDyRE+EBChA8kRPhAQoQPJFTknsTvvPOOXnjhhcbnjoyMND5zyRtvvFFkbslbgo+NjRWZW+qW4CX//83Pl3nf2MzMTJG5knT48OHqg3q0a9euWsdxxgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEiJ8ICHCBxIifCAhwgcSInwgIcIHEnJEND600+nEli1bqg/s0bZt2xqfuaTUXXYvvfTSInMl6ejRo0Xm7tmzp8jce++9t8hcSVpYWCgy9+233y4yV5JuuOGGxmdOTk7q9OnTrjqOMz6QEOEDCRE+kBDhAwkRPpAQ4QMJET6QUK3wbe+x/aLtP9j+me01pRcDUE5l+LY3Svq6pPGI+ISkIUm3lV4MQDl1H+oPS1pre1hSR1KZy9wA9EVl+BFxQtL3Jb0m6aSkP0XEodKLASinzkP9D0n6kqSrJX1E0qjtXRc4brftSduTpa6bBtCMOg/1Pyfp1Yg4FRHnJD0u6TMfPCgiDkTEeESMDw8PN7
0ngAbVCf81Sdfb7ti2pB2SpsquBaCkOs/xD0s6KOmIpN93/5sDhfcCUFCtx+QR8T1J3yu8C4A+4co9ICHCBxIifCAhwgcSInwgIcIHEipye+2RkZEYGxtrfG7JS4Hn5uaKzL3qqquKzJWkffv2FZm7d+/eInOfffbZInMl6dChMm8feeCBB4rMlaSnnnqqyNyI4PbaAP4W4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJET6QEOEDCRE+kBDhAwkRPpAQ4QMJET6QUJG77No+Jen/ax5+maTZxpcoZ9D2ldi5H1bKvv8aER+uOqhI+L2wPRkR460u0YNB21di534YtH15qA8kRPhAQish/ANtL9CjQdtXYud+GKh9W3+OD6D/VsIZH0CfET6QEOEDCRE+kBDhAwn9BRhO2ZHHs4JQAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 288x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Normalize each row by the number of images of that actual class, so classes with more\n",
"# samples don't look worse. The full matrix is rebuilt here because an earlier cell zeroes\n",
"# conf_mx's diagonal in place, which would drop the correct predictions from the row sums.\n",
"full_conf_mx = confusion_matrix(y_train, y_train_pred)\n",
"row_sums = full_conf_mx.sum(axis=1, keepdims=True)\n",
"norm_conf_mx = full_conf_mx / row_sums\n",
"# Hide the correct predictions so only the per-class error rates are visible.\n",
"np.fill_diagonal(norm_conf_mx, 0)\n",
"plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"# Multilabel targets: two boolean labels per image.\n",
"y_train_large = (y_train.astype('uint32') >= 7)     # digit is 7 or larger\n",
"y_train_odd = (y_train.astype('uint32') % 2 == 1)   # digit is odd\n",
"y_multilabel = np.column_stack((y_train_large, y_train_odd))\n",
"\n",
"knn_clf = KNeighborsClassifier()\n",
"knn_clf.fit(X_train, y_multilabel)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import f1_score\n",
"\n",
"# Out-of-fold multilabel predictions, then the unweighted mean of the per-label F1 scores.\n",
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multilabel, cv=3, n_jobs=-1)\n",
"f1_score(y_multilabel, y_train_knn_pred, average='macro')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import numpy.random as rnd\n",
"\n",
"# Denoising task: inputs are images plus random integer noise in [0, 100),\n",
"# targets are the clean images. A seeded generator makes the noise reproducible.\n",
"noise_rng = rnd.RandomState(42)\n",
"X_train_mod = X_train + noise_rng.randint(0, 100, X_train.shape)\n",
"X_test_mod = X_test + noise_rng.randint(0, 100, X_test.shape)\n",
"\n",
"y_train_mod = X_train\n",
"y_test_mod = X_test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Index of the noisy test image to clean up; defined here so the cell does not rely on hidden state.\n",
"some_index = 0\n",
"\n",
"# The targets are clean pixel vectors, so predict() returns a denoised image.\n",
"knn_clf.fit(X_train_mod, y_train_mod)\n",
"clean_digit = knn_clf.predict([X_test_mod[some_index]])\n",
"\n",
"plt.imshow(clean_digit.reshape(28, 28), cmap=matplotlib.cm.binary, interpolation='nearest')\n",
"plt.axis('off')\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment