Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Sagemath script for 3D plots
#!/usr/bin/env python
"""
From Sagemath console:
cd /path/to/file/
load("visual_all_in_one.py")
Author: Tommaso Soru <tsoru@informatik.uni-leipzig.de>
"""
import sys
import csv
##################### INSERT PARAMETERS HERE #####################
# Input file; the delimiter (comma or tab) is chosen from its extension.
inp = "data_2k.csv"
# Number of leading rows to skip (header lines).
hdr = 1
# 0-based indices of the columns joined with ";" to build each point's
# text label (if not needed or not present, use "[]").
lab = [0, 1]
# 0-based indices of the three columns holding the x, y, z coordinates.
cols = [2, 3, 4]
# 0-based index of the column holding the class name.
cl_col = 5
# Positive class name: rows whose class column equals this value are drawn
# in blue; every other row is treated as negative and drawn in red.
cl_name = "POS"
##################################################################
# Read the points from `inp`, split them into the positive class (rows whose
# column `cl_col` equals `cl_name`) and the negative class, and draw both in a
# single 3D scene with an optional text label next to each point.
#
# print(...) is called with a single argument so it behaves the same under
# Python 2 and Python 3 (current Sage releases run on Python 3).
print(" ".join(str(value) for value in (inp, hdr, lab, cols, cl_col, cl_name)))

# Choose the delimiter from the file extension; anything that is not ".csv"
# is assumed to be tab-separated (".tsv").
if inp.lower().endswith('.csv'):
    sep = ','
else:
    sep = '\t'

# Coordinates and labels are stored per class, with each label list aligned
# index-by-index with its vector list. A label therefore always stays with the
# row it came from, however positive and negative rows are interleaved.
vectors_p = list()
vectors_n = list()
labels_p = list()
labels_n = list()
labels = list()  # all labels, in file order
with open(inp) as f:
    for row_num, line in enumerate(csv.reader(f, delimiter=sep)):
        if row_num < hdr:
            # Skip the header rows.
            continue
        if not line:
            # csv.reader yields [] for blank lines (e.g. a trailing newline).
            continue
        v = [float(line[c]) for c in cols]
        # Joining over an empty `lab` gives "", i.e. no label text.
        label = ";".join(line[c] for c in lab)
        if line[cl_col] == cl_name:
            vectors_p.append(v)
            labels_p.append(label)
        else:
            vectors_n.append(v)
            labels_n.append(label)
        labels.append(label)

print("|P|={}\t|N|={}".format(len(vectors_p), len(vectors_n)))

p = point3d(vectors_p, size=10, color='blue')
p += point3d(vectors_n, size=10, color='red')
for v, label in zip(vectors_p + vectors_n, labels_p + labels_n):
    if not label:
        # Nothing to draw; avoid adding empty text objects to the scene.
        continue
    # Lift the label slightly along z so it does not sit on the marker.
    # Build a new list so the stored coordinates are left untouched.
    psn = list(v)
    psn[2] += 0.03
    p += text3d(label, psn, color='black')
p.show(xmin=-1, xmax=1, ymin=-1, ymax=1, zmin=-1, zmax=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.