Skip to content

Instantly share code, notes, and snippets.

@ydnar

ydnar/server.go Secret

Created July 16, 2021 19:29
Show Gist options
  • Save ydnar/f73e1806c50c83f90f9b46e13bf751fc to your computer and use it in GitHub Desktop.
Save ydnar/f73e1806c50c83f90f9b46e13bf751fc to your computer and use it in GitHub Desktop.
basic example WebTransport server in Go
package main
import (
"context"
"crypto/tls"
"embed"
"flag"
"fmt"
"io"
"io/fs"
"log"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/alta/insecure"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/http3"
"github.com/lucas-clemente/quic-go/logging"
"github.com/lucas-clemente/quic-go/qlog"
)
// main parses the -a flag (listen address, host:port, default
// "localhost:4433") and runs the server, exiting via log.Fatal if it fails.
func main() {
	var addr string
	flag.StringVar(&addr, "a", "localhost:4433", "address in host:port format")
	flag.Parse()

	if err := serverMain(addr); err != nil {
		log.Fatal(err)
	}
}
// serverMain builds the TLS and QUIC configuration plus the HTTP handler tree
// and serves it on addr over both HTTPS (TCP) and HTTP/3 (UDP) until one of
// the servers stops.
//
// Environment variables:
//   - QUIC_LOG_DIRECTORY: when non-empty, a qlog trace named
//     server-<connID hex>.qlog is written there for each QUIC connection.
//     A trace file that cannot be created is logged and skipped; it does not
//     stop the server.
//   - FS_LIVE: when non-empty, static files are served from the "static"
//     directory next to this source file (path taken from runtime.Caller)
//     instead of the copy embedded at build time, so edits show up without
//     rebuilding.
//
// Every response carries "Access-Control-Allow-Origin: *" and one
// Origin-Trial header per entry in originTrials.
func serverMain(addr string) error {
	cert, err := insecureLocalCert(addr)
	if err != nil {
		return fmt.Errorf("generating local certificate: %w", err)
	}
	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		NextProtos:   []string{"h3", "h3-29", "h2"},
		// MinVersion: tls.VersionTLS13,
	}

	quicConfig := &quic.Config{
		EnableDatagrams: true,
	}
	if qlogDir := os.Getenv("QUIC_LOG_DIRECTORY"); qlogDir != "" {
		quicConfig.Tracer = qlog.NewTracer(func(_ logging.Perspective, connID []byte) io.WriteCloser {
			fn := filepath.Join(qlogDir, fmt.Sprintf("server-%x.qlog", connID))
			f, err := os.Create(fn)
			if err != nil {
				// This callback runs per connection inside a live server;
				// tracing is diagnostic only, so skip it for this connection
				// instead of terminating the process with log.Fatal.
				// quic-go's qlog tracer treats a nil writer as "no trace for
				// this connection" (confirm for the pinned quic-go version).
				log.Printf("Failed to create qlog file %s: %v", fn, err)
				return nil
			}
			log.Printf("Created qlog file: %s", fn)
			return f
		})
	}

	fsStatic, err := fs.Sub(efsStatic, "static")
	if err != nil {
		return fmt.Errorf("opening embedded static files: %w", err)
	}
	if os.Getenv("FS_LIVE") != "" {
		// runtime.Caller(0) reports this source file's path, so "static" is
		// resolved relative to the source tree rather than the working dir.
		_, file, _, _ := runtime.Caller(0)
		fsStatic = os.DirFS(filepath.Join(filepath.Dir(file), "static"))
	}

	mux := http.NewServeMux()
	mux.Handle("/", http.FileServer(http.FS(fsStatic)))
	mux.HandleFunc("/counter", handleCounter)

	// Decorate every response with CORS and origin-trial headers before
	// dispatching through the mux.
	handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
		hdr := rw.Header()
		hdr.Add("Access-Control-Allow-Origin", "*")
		for _, tok := range originTrials {
			hdr.Add("Origin-Trial", tok)
		}
		mux.ServeHTTP(rw, req)
	})

	return listenAndServe(addr, tlsConfig, quicConfig, handler)
}
var (
//go:embed static
efsStatic embed.FS
// Chrome origin trial tokens for WebTransport
// https://developer.chrome.com/origintrials
// https://googlechrome.github.io/OriginTrials/check-token.html
originTrials = []string{
// https://127.0.0.1:4433
"ADD YOUR ORIGIN TRIAL TOKEN HERE",
// https://localhost:4433
"ADD YOUR ORIGIN TRIAL TOKEN HERE",
}
)
func handleCounter(rw http.ResponseWriter, req *http.Request) {
log.Printf("WebTransport %s %s", req.Method, req.URL.Path)
go io.ReadAll(req.Body)
rw.Header().Add("cache-controls", "no-store")
rw.WriteHeader(http.StatusOK)
if flusher, ok := rw.(http.Flusher); ok {
flusher.Flush() // Required to send headers
}
ctx := req.Context()
sess := rw.(http3.Session)
go handleIncomingUniStreams(ctx, sess)
for {
str, err := sess.AcceptStream(ctx)
if err != nil {
return
}
log.Printf("accepted WebTransport stream %d", str.StreamID())
go func(str quic.Stream) {
defer str.Close()
buf := make([]byte, 4096)
n, err := str.Read(buf)
if err != nil && err != io.EOF {
log.Printf("error reading from stream %d: %v", str.StreamID(), err)
return
}
log.Printf("received data on WebTransport stream %d: %v", str.StreamID(), string(buf))
s := strings.ToUpper(string(buf[:n]))
_, err = str.Write([]byte(s))
if err != nil {
log.Printf("error writing to stream %d: %v", str.StreamID(), err)
return
}
}(str)
}
}
// handleIncomingUniStreams accepts unidirectional streams opened by the peer
// on sess until AcceptUniStream returns an error (presumably when ctx is done
// or the session closes). In a separate goroutine per stream, it performs a
// single Read of up to 4096 bytes and logs exactly the bytes read; the rest
// of the stream is not consumed.
func handleIncomingUniStreams(ctx context.Context, sess http3.Session) {
	for {
		str, err := sess.AcceptUniStream(ctx)
		if err != nil {
			return
		}
		log.Printf("accepted WebTransport unidirectional stream %d", str.StreamID())
		go func(str quic.ReceiveStream) {
			buf := make([]byte, 4096)
			n, err := str.Read(buf)
			if err != nil && err != io.EOF {
				log.Printf("error reading from stream %d: %v", str.StreamID(), err)
				return
			}
			// Log only buf[:n]; the original logged the whole zero-padded buffer.
			log.Printf("received data on WebTransport stream %d: %q", str.StreamID(), buf[:n])
		}(str)
	}
}
// listenAndServe serves handler on addr over HTTPS (a TLS listener on TCP) and
// HTTP/3 (quic-go's http3.Server on UDP), both using tlsConfig. The HTTP/3
// server also uses quicConfig and enables datagrams and WebTransport only when
// quicConfig.EnableDatagrams is set; a nil quicConfig leaves both off.
// A nil handler means http.DefaultServeMux.
//
// Before calling handler, every request served through httpServer.Handler has
// quicServer.SetQuicHeaders applied to its response headers (presumably an
// Alt-Svc advertisement for HTTP/3 — confirm against the http3 version).
//
// It blocks until either server's Serve returns, and returns that error.
func listenAndServe(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, handler http.Handler) error {
	// Open the listeners.
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return err
	}
	udpConn, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		return err
	}
	defer udpConn.Close()

	tcpAddr, err := net.ResolveTCPAddr("tcp", addr)
	if err != nil {
		return err
	}
	tcpConn, err := net.ListenTCP("tcp", tcpAddr)
	if err != nil {
		return err
	}
	// The TLS listener's Close is the wrapped TCP listener's Close, so this
	// one deferred call releases the TCP socket as well.
	tlsConn := tls.NewListener(tcpConn, tlsConfig)
	defer tlsConn.Close()

	// Configure the servers.
	enableDatagrams := quicConfig != nil && quicConfig.EnableDatagrams
	httpServer := &http.Server{
		Addr:      addr,
		TLSConfig: tlsConfig,
	}
	quicServer := &http3.Server{
		Server:             httpServer,
		QuicConfig:         quicConfig,
		EnableDatagrams:    enableDatagrams,
		EnableWebTransport: enableDatagrams,
	}
	if handler == nil {
		handler = http.DefaultServeMux
	}
	httpServer.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		quicServer.SetQuicHeaders(w.Header())
		handler.ServeHTTP(w, r)
	})

	// Buffered so the server that stops second can still deliver its result
	// and exit after the select below has returned; with unbuffered channels
	// that goroutine would block on send forever.
	hErr := make(chan error, 1)
	qErr := make(chan error, 1)
	go func() {
		hErr <- httpServer.Serve(tlsConn)
	}()
	go func() {
		qErr <- quicServer.Serve(udpConn)
	}()

	log.Printf("Server listening at: https://%s", addr)
	select {
	case err := <-hErr:
		quicServer.Close()
		return err
	case err := <-qErr:
		// The HTTPS server is not shut down gracefully: the deferred
		// tlsConn.Close makes httpServer.Serve return, but in-flight HTTPS
		// requests are not waited for.
		return err
	}
}
// insecureLocalCert returns a self-signed certificate from the insecure
// package. Its SANs are insecure.LocalSANs() plus, when non-empty, the
// lower-cased machine hostname and the lower-cased host part of addr.
func insecureLocalCert(addr string) (tls.Certificate, error) {
	sans := insecure.LocalSANs()

	// Both lookups are best-effort: on error they yield "", which is skipped.
	hostname, _ := os.Hostname()
	listenHost, _, _ := net.SplitHostPort(addr)

	for _, extra := range []string{hostname, listenHost} {
		if extra == "" {
			continue
		}
		sans = append(sans, strings.ToLower(extra))
	}
	return insecure.Cert(sans...)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment