Skip to content

Instantly share code, notes, and snippets.

@DanHickstein
Created April 15, 2016 22:58
Show Gist options
  • Save DanHickstein/4c5a61b13ee50d75ab0a21cdba188c47 to your computer and use it in GitHub Desktop.
Save DanHickstein/4c5a61b13ee50d75ab0a21cdba188c47 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
import pynlo
import matplotlib.cm as cm
# ---------------------------------------------------------------------------
# Simulation parameters
# ---------------------------------------------------------------------------

# -- Input pulse --
FWHM = 0.2          # pulse duration (ps)
pulseWL = 1550      # pulse central wavelength (nm)
EPP = 200e-12       # energy per pulse (J)
GDD = 0.0           # group delay dispersion (ps^2)
TOD = 0.0           # third order dispersion (ps^3)

# -- Numerical grid / integrator --
Window = 5.0        # simulation window (ps)
Steps = 100         # simulation steps
Points = 2**12      # simulation points
# Error desired by the integrator. Usually 0.001 is plenty good;
# use larger values for speed.
error = 0.005

# -- Fiber --
beta2 = -40         # (ps^2/km)
beta3 = 0.00        # (ps^3/km)
beta4 = 0.001       # (ps^4/km)
Length = 60         # length in mm
Alpha = 0.0         # attenuation coefficient (dB/cm)
Gamma = 1000        # Gamma (1/(W km))
fibWL = pulseWL     # center wavelength of the fiber (nm)

# -- Nonlinear effects --
Raman = True        # enable Raman effect?
Steep = True        # enable self steepening?

# Convert the attenuation from dB/cm to 1/m.
alpha = 100 * np.log(10 ** (0.1 * Alpha))
# Set up plots for the results: a 3x3 grid in which every column holds a
# one-row axes on top and a two-row axes underneath that shares its x axis.
fig = plt.figure(figsize=(13, 10))

# Column 0
ax0 = plt.subplot2grid(shape=(3, 3), loc=(0, 0), rowspan=1)
ax1 = plt.subplot2grid(shape=(3, 3), loc=(1, 0), rowspan=2, sharex=ax0)

# Column 1
ax2 = plt.subplot2grid(shape=(3, 3), loc=(0, 1), rowspan=1)
ax3 = plt.subplot2grid(shape=(3, 3), loc=(1, 1), rowspan=2, sharex=ax2)

# Column 2
ax4 = plt.subplot2grid(shape=(3, 3), loc=(0, 2), rowspan=1)
ax5 = plt.subplot2grid(shape=(3, 3), loc=(1, 2), rowspan=2, sharex=ax4)
# Create the fiber!
fiber1 = pynlo.media.fibers.fiber.FiberInstance()
fiber1.generate_fiber(
    Length * 1e-3,              # Length is specified in mm above
    center_wl_nm=fibWL,
    betas=(beta2, beta3, beta4),
    gvd_units='ps^n/km',
    gamma_W_m=Gamma * 1e-3,     # Gamma is specified in 1/(W km) above
    gain=-alpha,
)

# Propagation: split-step solver; the Raman and self-steepening switches
# are passed as "disable" flags, hence the negation of the user settings.
evol = pynlo.interactions.FourWaveMixing.SSFM.SSFM(
    local_error=error,
    USE_SIMPLE_RAMAN=True,
    disable_Raman=np.logical_not(Raman),
    disable_self_steepening=np.logical_not(Steep),
)

# Create the pulse!
original_pulse = pynlo.light.DerivedPulses.SechPulse(
    power=1,                    # Power will be scaled by set_epp
    T0_ps=FWHM / 1.76,
    center_wavelength_nm=pulseWL,
    time_window_ps=Window,
    NPTS=Points,
    GDD=GDD,
    TOD=TOD,
    frep_MHz=100,
    power_is_avg=False,
)
original_pulse.set_epp(EPP)     # set the pulse energy
def include_noise(Pulse):
    """Return a copy of ``Pulse`` with random noise added to its spectrum.

    The spectrum ``Pulse.AW`` is converted to an estimated photon count per
    frequency bin. A noise term is then built for every bin from a normally
    distributed amplitude scaled by the square root of that photon count,
    with a uniformly distributed random phase in [0, 2*pi).

    Parameters
    ----------
    Pulse : object
        Pulse-like object exposing ``W_THz``, ``AW`` and ``set_AW`` (a pynlo
        pulse in this script). ``W_THz`` is divided by 2*pi before being
        multiplied by Planck's constant, so it is presumably an angular
        frequency grid -- TODO confirm against pynlo.

    Returns
    -------
    object
        An independent deep copy of ``Pulse`` whose spectrum has been set to
        ``Pulse.AW + noise``. The input pulse is left untouched.

    Notes
    -----
    Prints the total energy, photon statistics and the noise array to stdout.
    """
    import copy

    W = Pulse.W_THz
    A = Pulse.AW

    size_of_bins = np.gradient(W)
    energy_per_bin = np.abs(A)**2 / size_of_bins * 1e-12

    h = 6.62607004e-34  # Planck constant (J s)
    photon_energy = h * W / (2 * np.pi) * 1e12
    photons_per_bin = energy_per_bin / photon_energy

    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print('Total energy: %.1f pJ' % (np.sum(energy_per_bin) * 1e12))
    # Values are passed in the order the label promises: min, max, avg.
    print('Photons per bin (min/max/avg): %.2e/%.2e/%.2e\nTotal photons: %.2e' % (
        np.min(photons_per_bin), np.max(photons_per_bin),
        np.mean(photons_per_bin), np.sum(photons_per_bin)))

    size = np.shape(A)[0]
    random_intensity = np.random.normal(size=size)
    random_phase = np.random.uniform(size=size) * 2 * np.pi

    # Negative counts would make the square root below return NaN.
    photons_per_bin[photons_per_bin < 0] = 0
    noise = (random_intensity * np.sqrt(photons_per_bin) * photon_energy
             * size_of_bins * 1e12 * np.exp(1j * random_phase))
    print(noise)

    # Deep copy so the returned pulse shares no buffers with the input. With
    # a shallow copy, a set_AW that writes in place would also modify the
    # caller's pulse and noise would accumulate over repeated calls.
    # NOTE(review): confirm set_AW's in-place behaviour against pynlo.
    output_pulse = copy.deepcopy(Pulse)
    output_pulse.set_AW(A + noise)
    return output_pulse
# plt.plot(W, A, label='before')
# plt.plot(W, Pulse.AW, label='with noise')
# ax0.plot(W, (A - Pulse.AW))
#
# plt.legend(frameon=False)
#
# plt.show()
trials = 4
# np.random.seed(0)

# The pairwise comparison below needs at least two independent runs.
if trials < 2:
    raise ValueError('trials must be at least 2 to compare spectra between runs')

# Propagate `trials` independently noised copies of the input pulse and keep
# the spectral evolution of each run. Results are collected in a list rather
# than probed via locals(), so re-running the script in the same interpreter
# session cannot pick up stale data from a previous run.
trial_spectra = []
for num in range(trials):
    pulse = include_noise(original_pulse)
    y, AW, AT, pulse_out = evol.propagate(pulse_in=pulse, fiber=fiber1, n_steps=Steps)
    trial_spectra.append(AW)

# One transposed spectral evolution per trial along the first axis, most
# recent trial first (the same ordering the original dstack/transpose gave).
AW_stack = np.array([np.transpose(spectrum) for spectrum in reversed(trial_spectra)])
print(np.angle(AW_stack) / np.pi)

# For every ordered pair of distinct trials, form the cross-spectrum
# normalised by the two magnitudes (a unit-modulus phase factor wherever the
# fields are non-zero).
pair_terms = []
for n1, E1 in enumerate(AW_stack):
    for n2, E2 in enumerate(AW_stack[np.arange(trials) != n1]):
        print('%d %d' % (n1, n2))
        pair_terms.append(np.conj(E1) * E2 / np.sqrt(np.abs(E1)**2 * np.abs(E2)**2))

# Most recent pair first, matching the original stacking order.
g12_stack = np.dstack(pair_terms[::-1])
# print g12_stack.shape, g12_stack.transpose().shape

# Magnitude of the pair-averaged phase factor.
g12 = np.abs(np.mean(g12_stack, axis=2))

F = pulse_out.F_THz  # Frequency grid of pulse (THz)
def dB(num):
    """Return the power level of ``num`` in decibels, 10*log10(|num|**2).

    Parameters
    ----------
    num : scalar or array_like
        Real or complex amplitude(s).

    Returns
    -------
    float or numpy.ndarray
        Level in dB, element-wise for array input. Zero amplitude gives -inf.
    """
    # 20*log10(|x|) is the same quantity as 10*log10(|x|**2), but never forms
    # the square, which under/overflows for very small or very large |x|.
    return 20 * np.log10(np.abs(num))
# --- Spectra (left column) --------------------------------------------------
# Only the positive-frequency part of the grid is shown.
positive_freqs = F > 0
for AW in AW_stack:
    zW = dB(AW[:, positive_freqs])
    # First row of each trial in blue, last row in red (presumably the
    # start and the end of the fiber -- verify against pynlo's output).
    ax0.plot(F[positive_freqs], zW[0], color='b')
    ax0.plot(F[positive_freqs], zW[-1], color='r')

# --- Time domain (right column) ---------------------------------------------
zT = dB(np.transpose(AT))
ax4.plot(pulse_out.T_ps, dB(pulse_out.AT), color='r')
ax4.plot(pulse.T_ps, dB(pulse.AT), color='b')

# Spectral evolution of the trial plotted last, 40 dB below its maximum.
zW_peak = np.max(zW)
extent = (np.min(F[positive_freqs]), np.max(F[positive_freqs]), 0, Length)
ax1.imshow(zW, extent=extent, vmin=zW_peak - 40.0, vmax=zW_peak,
           aspect='auto', origin='lower')

# --- g12 map (middle column) ------------------------------------------------
# g12 covers the full frequency grid, so the extent is not masked here.
extent = (np.min(F), np.max(F), 0, Length)
ax3.imshow(g12, extent=extent, clim=(0, 1), aspect='auto', origin='lower',
           cmap=cm.inferno)

# Temporal evolution, 40 dB below its maximum.
zT_peak = np.max(zT)
extent = (np.min(pulse.T_ps), np.max(pulse.T_ps), 0, Length)
ax5.imshow(zT, extent=extent, vmin=zT_peak - 40.0, vmax=zT_peak,
           aspect='auto', origin='lower')

# --- Labels and limits ------------------------------------------------------
ax0.set_ylabel('Intensity (dB)')
ax0.set_ylim(-80, 0)
ax4.set_ylim(-40, 40)
ax1.set_ylabel('Propagation distance (mm)')
for ax in (ax1, ax3):
    ax.set_xlabel('Frequency (THz)')
    ax.set_xlim(0, 400)
ax5.set_xlabel('Time (ps)')

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment