Skip to content

Instantly share code, notes, and snippets.

@notedit
Created February 12, 2012 09:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save notedit/1807424 to your computer and use it in GitHub Desktop.
Save notedit/1807424 to your computer and use it in GitHub Desktop.
server.go
// Date: 2012-02-08
// Author: notedit <notedit@gmail.com>
// Package rpc implements a minimal RPC server that exchanges length-prefixed BSON messages over TCP.
package rpc
import (
"fmt"
"log"
"net"
"io"
"encoding/binary"
"bytes"
"sync"
"reflect"
"errors"
"strings"
"unicode"
"runtime"
"unicode/utf8"
"launchpad.net/mgo/bson"
)
// typeOfError is the reflect.Type of the built-in error interface; register
// compares each candidate method's result type against it.
var typeOfError = reflect.TypeOf(new(error)).Elem()

type (
	// methodType describes one registered RPC method: its reflected
	// signature plus the argument and reply types used to allocate fresh
	// values for each incoming call.
	methodType struct {
		method    reflect.Method // the method; its Func takes the receiver first
		ArgType   reflect.Type   // declared type of the argument parameter
		ReplyType reflect.Type   // declared type of the reply parameter (a pointer)
	}

	// service is a registered receiver together with its callable methods.
	service struct {
		name   string                 // name clients use to address the service
		rcvr   reflect.Value          // receiver the methods are invoked on
		typ    reflect.Type           // type of the receiver
		method map[string]*methodType // registered methods, keyed by method name
	}
)
// Server is an RPC server: a registry of services guarded by mu, plus the
// TCP listener that Serv accepts connections from.
type Server struct {
	mu         sync.Mutex          // guards serviceMap
	serviceMap map[string]*service // registered services, keyed by service name
	listener   *net.TCPListener    // accepting socket created by NewServer
	run        chan bool           // created by NewServer; not read anywhere in this file
}
// BackendError is the error payload sent back to clients. Message is a short
// category (sendResponse uses "InternalError" for errors that are not already
// a BackendError) and Detail carries the specifics.
type BackendError struct {
	Message string
	Detail  string
}

// Error implements the error interface, rendering the error as
// "Message:Detail".
func (be BackendError) Error() string {
	// All operands are strings, so Sprint inserts no separators between them.
	return fmt.Sprint(be.Message, ":", be.Detail)
}
// Wire messages. On the connection each message is a 4-byte big-endian
// length followed by a BSON document of that length. Operation identifies
// the message kind: 1 = call, 2 = reply, 3 = error.
type (
	// serverRequest is an incoming call. messageLength holds the length
	// prefix that is read ahead of the BSON body.
	serverRequest struct {
		messageLength uint32
		Operation     uint8
		Method        string   // "Service.Method"
		Argument      bson.Raw // decoded later into the method's argument type
	}

	// serverResponse is an outgoing reply or error.
	serverResponse struct {
		messageLength uint32
		Operation     uint8
		Reply         interface{}
	}
)
// ServerCodec reads requests from, and writes responses to, a single
// connection. Outgoing responses are staged in buf before being flushed.
type ServerCodec struct {
	rwc io.ReadWriteCloser // underlying connection
	buf *bytes.Buffer      // staging area for outgoing bytes
}
// ReadRequestHeader reads the 4-byte big-endian length prefix of the next
// request into req.messageLength.
//
// If the stream ends before any header byte is read, io.EOF is returned
// unchanged so callers can tell a closed connection from a malformed one.
// Any other failure (including a truncated header) is returned wrapped with
// context.
func (c *ServerCodec) ReadRequestHeader(req *serverRequest) (err error) {
	// binary.Read must be given a pointer to decode into; passing the field
	// by value makes it fail with "invalid type uint32" and leaves the
	// length unset.
	err = binary.Read(c.rwc, binary.BigEndian, &req.messageLength)
	if err != nil {
		if err == io.EOF {
			return
		}
		err = errors.New("rpc: server cannot decode requestheader: " + err.Error())
	}
	return
}
// ReadRequestBody reads exactly req.messageLength bytes from the connection
// and BSON-decodes them into req. It returns the first read or decode error.
//
// NOTE(review): messageLength comes straight from the peer and sizes the
// allocation below; consider capping it to bound memory per request.
func (c *ServerCodec) ReadRequestBody(req *serverRequest) (err error) {
	body := make([]byte, req.messageLength)
	if _, err = io.ReadFull(c.rwc, body); err == nil {
		err = bson.Unmarshal(body, req)
	}
	return
}
// WriteResponse marshals res to BSON and sends it as a 4-byte big-endian
// length prefix followed by the BSON document. As a side effect it sets
// res.messageLength to the encoded body length.
//
// Header and body are staged in c.buf so they reach the connection in a
// single WriteTo call. If that flush fails, c.buf is cleared so unsent bytes
// cannot be prefixed to a later response on the same codec.
func (c *ServerCodec) WriteResponse(res *serverResponse) (err error) {
	bys, err := bson.Marshal(res)
	if err != nil {
		log.Println("writeresponse error", err)
		return
	}
	res.messageLength = uint32(len(bys))

	var header [4]byte
	binary.BigEndian.PutUint32(header[:], res.messageLength)
	// bytes.Buffer.Write always returns a nil error (it panics if the buffer
	// grows too large), so only the final flush can fail.
	c.buf.Write(header[:])
	c.buf.Write(bys)

	if _, err = c.buf.WriteTo(c.rwc); err != nil {
		c.buf.Reset()
	}
	return
}
// Close closes the underlying connection and returns its error.
func (c *ServerCodec) Close() error {
	conn := c.rwc
	return conn.Close()
}
// isExported reports whether name begins with an upper-case letter, which is
// Go's rule for exported identifiers. The empty string reports false.
func isExported(name string) bool {
	if first, _ := utf8.DecodeRuneInString(name); unicode.IsUpper(first) {
		return true
	}
	return false
}
// isExportedOrBuiltinType reports whether t, after stripping any number of
// pointer indirections, either has no package path (a predeclared type such
// as int or string, or an unnamed type) or is an exported named type.
func isExportedOrBuiltinType(t reflect.Type) bool {
	base := t
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	if base.PkgPath() == "" {
		return true
	}
	return isExported(base.Name())
}
// Register publishes rcvr's suitable methods as a service named after rcvr's
// type (the pointed-to type when rcvr is a pointer). It returns the error
// reported by register, if any.
func (server *Server) Register(rcvr interface{}) error {
	const noName, useName = "", false
	return server.register(rcvr, noName, useName)
}
// RegisterName is like Register but publishes the service under the given
// name instead of the receiver's type name; that name need not be exported.
func (server *Server) RegisterName(name string, rcvr interface{}) error {
	useName := true
	return server.register(rcvr, name, useName)
}
// register installs the suitable methods of rcvr as a service. The service
// name is the receiver's type name (the pointed-to type for a pointer
// receiver) or, when useName is true, name.
//
// A method is suitable when it is exported and has the form
//
//	func (t T) MethodName(arg A, reply *R) error
//
// where A and *R are exported or builtin types; A may itself be a pointer.
// Unsuitable methods are skipped with a log message.
//
// register returns an error, registering nothing, if rcvr is nil, the name
// is empty, the type name is unexported (checked only when useName is
// false), a service with that name already exists, or no method is suitable.
func (server *Server) register(rcvr interface{}, name string, useName bool) error {
	server.mu.Lock()
	defer server.mu.Unlock()
	if server.serviceMap == nil {
		server.serviceMap = make(map[string]*service)
	}
	if rcvr == nil {
		msg := "rpc Register: nil receiver"
		log.Print(msg)
		return errors.New(msg)
	}
	s := new(service)
	s.typ = reflect.TypeOf(rcvr)
	s.rcvr = reflect.ValueOf(rcvr)
	// Name the service after the pointed-to type for pointer receivers.
	// Working on the type (not the value) avoids a panic for typed nil
	// pointers.
	baseType := s.typ
	if baseType.Kind() == reflect.Ptr {
		baseType = baseType.Elem()
	}
	sname := baseType.Name()
	if useName {
		sname = name
	}
	if sname == "" {
		// Report to the caller rather than log.Fatal, which would terminate
		// the whole process from inside library code.
		msg := "rpc: no service name for type " + s.typ.String()
		log.Print(msg)
		return errors.New(msg)
	}
	if !isExported(sname) && !useName {
		msg := "rpc Register: type " + sname + " is not exported"
		log.Print(msg)
		return errors.New(msg)
	}
	if _, present := server.serviceMap[sname]; present {
		return errors.New("rpc: service already defined: " + sname)
	}
	s.name = sname
	s.method = make(map[string]*methodType)

	// Install the methods.
	for m := 0; m < s.typ.NumMethod(); m++ {
		method := s.typ.Method(m)
		mtype := method.Type
		mname := method.Name
		// A non-empty PkgPath marks an unexported method.
		if method.PkgPath != "" {
			continue
		}
		// Receiver, argument, reply.
		if mtype.NumIn() != 3 {
			log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
			continue
		}
		// Exactly one result, which must be error (checked below).
		if mtype.NumOut() != 1 {
			log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
			continue
		}
		// The argument need not be a pointer.
		argType := mtype.In(1)
		if !isExportedOrBuiltinType(argType) {
			log.Println(mname, "argument type not exported or local", argType)
			continue
		}
		// The reply must be a pointer; its pointee is what gets sent back.
		replyType := mtype.In(2)
		if replyType.Kind() != reflect.Ptr {
			log.Println("method", mname, " reply type not a pointer:", replyType)
			continue
		}
		if !isExportedOrBuiltinType(replyType) {
			log.Println("method ", mname, "reply type not exported or local", replyType)
			continue
		}
		if returnType := mtype.Out(0); returnType != typeOfError {
			log.Println("method", mname, " returns", returnType.String(), "not error")
			continue
		}
		s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
	}
	if len(s.method) == 0 {
		msg := "rpc Register: type " + sname + " has no exported methods of suitable type"
		log.Print(msg)
		return errors.New(msg)
	}
	server.serviceMap[s.name] = s
	return nil
}
// NewServer opens a TCP listener on host:port and returns a Server with an
// empty service map and the listener that Serv accepts from.
//
// host may be a hostname, an IPv4 literal, or an IPv6 literal with or
// without brackets ("::1" or "[::1]").
//
// NewServer calls log.Fatal, terminating the process, if the address cannot
// be resolved or the listener cannot be opened.
func NewServer(host string, port uint) *Server {
	portStr := fmt.Sprint(port)
	hostPort := host + ":" + portStr
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		// A bare IPv6 literal must be bracketed; "::1:8080" is ambiguous and
		// ResolveTCPAddr rejects it.
		hostPort = net.JoinHostPort(host, portStr)
	}
	addr, err := net.ResolveTCPAddr("tcp", hostPort)
	if err != nil {
		log.Fatal("rpc error:", err.Error())
	}
	listener, err := net.ListenTCP("tcp", addr)
	if err != nil {
		log.Fatal("rpc error:", err.Error())
	}
	return &Server{
		serviceMap: make(map[string]*service),
		listener:   listener,
		run:        make(chan bool),
	}
}
// Serv accepts TCP connections on the server's listener and handles each one
// in its own goroutine via ServeConn.
//
// Accept errors are logged. Temporary ones are retried. On a permanent error,
// such as a closed listener, Serv returns instead of spinning forever.
func (server *Server) Serv() {
	for {
		conn, err := server.listener.AcceptTCP()
		if err != nil {
			log.Print("rpc error:", err.Error())
			// Same retry test net/http's Server.Serve uses for accept errors.
			if ne, ok := err.(net.Error); ok && ne.Temporary() {
				continue
			}
			return
		}
		go server.ServeConn(conn)
	}
}
// ServeConn wraps conn in a ServerCodec with a fresh staging buffer and
// serves it with ServeCodec, blocking until that returns. ServeCodec closes
// the codec, and with it conn, when it finishes normally.
func (server *Server) ServeConn(conn io.ReadWriteCloser) {
	server.ServeCodec(&ServerCodec{
		rwc: conn,
		buf: new(bytes.Buffer),
	})
}
// ServeCodec handles exactly one request on codec. It reads the request,
// invokes the target method, writes the response, and then closes codec.
//
// If the request cannot be read, the client receives a generic error
// response (the cause is logged unless it is io.EOF). A panic whose value is
// a runtime.Error is re-panicked. Any other panic is turned into an error
// response.
func (server *Server) ServeCodec(codec *ServerCodec) {
	defer func() {
		if r := recover(); r != nil {
			// Runtime errors (nil dereference, failed type assertion, ...)
			// are propagated rather than reported to the client.
			if _, isRuntime := r.(runtime.Error); isRuntime {
				panic(r)
			}
			server.sendResponse(nil, codec, fmt.Errorf("%s", r))
		}
		codec.Close()
	}()

	svc, mtype, req, argv, replyv, err := server.readRequest(codec)
	if err == nil {
		svc.call(server, mtype, req, argv, replyv, codec)
		return
	}
	if err != io.EOF {
		log.Println("rpc:", err)
	}
	server.sendResponse(nil, codec, errors.New("rpc: can not read the valid request"))
}
// readRequest reads one request from codec. It reads the length header
// first, then the body. Reading the body also resolves the target service
// and method and decodes the argument. The other results are meaningful only
// when err is nil.
func (server *Server) readRequest(codec *ServerCodec) (service *service, mtype *methodType, req *serverRequest, argv reflect.Value, replyv reflect.Value, err error) {
	if req, err = server.readRequestHeader(codec); err != nil {
		return
	}
	service, mtype, argv, replyv, err = server.readRequestBody(codec, req)
	return
}
// readRequestBody decodes the request body into req. It then resolves
// req.Method, which must have the form "Service.Method", against the
// registered services.
//
// On success it returns:
//   - the service and method;
//   - an argument value decoded from req.Argument, which is a pointer when
//     the method's argument type is a pointer and a plain value otherwise;
//   - a pointer to a freshly allocated zero reply.
func (server *Server) readRequestBody(codec *ServerCodec, req *serverRequest) (service *service, mtype *methodType, argv reflect.Value, replyv reflect.Value, err error) {
	if err = codec.ReadRequestBody(req); err != nil {
		return
	}

	parts := strings.Split(req.Method, ".")
	if len(parts) != 2 {
		err = errors.New("rpc: service/method request ill-formed: " + req.Method)
		return
	}
	serviceName, methodName := parts[0], parts[1]

	server.mu.Lock()
	service = server.serviceMap[serviceName]
	server.mu.Unlock()
	if service == nil {
		err = errors.New("rpc: can't find service " + serviceName)
		return
	}

	if mtype = service.method[methodName]; mtype == nil {
		err = errors.New("rpc: can't find method " + methodName)
		return
	}

	// Always decode through a pointer; unwrap afterwards for by-value args.
	argIsPtr := mtype.ArgType.Kind() == reflect.Ptr
	elemType := mtype.ArgType
	if argIsPtr {
		elemType = elemType.Elem()
	}
	argv = reflect.New(elemType)
	if err = req.Argument.Unmarshal(argv.Interface()); err != nil {
		return
	}
	if !argIsPtr {
		argv = argv.Elem()
	}

	replyv = reflect.New(mtype.ReplyType.Elem())
	return
}
// readRequestHeader allocates a serverRequest and fills in its length prefix
// from codec.
//
// On failure it returns a nil request and the codec's error unchanged. That
// error is io.EOF for a cleanly closed connection. Otherwise
// ServerCodec.ReadRequestHeader has already prefixed it with
// "rpc: server cannot decode requestheader", so it is not wrapped a second
// time here.
func (server *Server) readRequestHeader(codec *ServerCodec) (req *serverRequest, err error) {
	req = new(serverRequest)
	if err = codec.ReadRequestHeader(req); err != nil {
		// Don't hand back a partially populated request.
		return nil, err
	}
	return req, nil
}
// sendResponse encodes the outcome of a call and writes it through codec.
//   - A nil err yields a reply message (Operation 2) carrying reply.
//   - A BackendError is sent as-is in an error message (Operation 3).
//   - Any other error is wrapped as BackendError{Message: "InternalError",
//     Detail: err.Error()}; reply is ignored in both error cases.
//
// Write failures are logged, not returned.
func (server *Server) sendResponse(reply interface{}, codec *ServerCodec, err error) {
	var res *serverResponse
	switch e := err.(type) {
	case nil:
		res = &serverResponse{Operation: uint8(2), Reply: reply}
	case BackendError:
		res = &serverResponse{Operation: uint8(3), Reply: e}
	default:
		// err's static type is error, so every other non-nil value lands
		// here; no further case is reachable.
		res = &serverResponse{Operation: uint8(3), Reply: BackendError{Message: "InternalError", Detail: e.Error()}}
	}
	if rerr := codec.WriteResponse(res); rerr != nil {
		log.Println("rpc error:", rerr)
	}
}
// call invokes the method described by mtype on the service receiver with
// argv and replyv. It then sends the result through server.sendResponse:
// the reply on success, or the method's error. req is currently unused.
func (s *service) call(server *Server, mtype *methodType, req *serverRequest, argv, replyv reflect.Value, codec *ServerCodec) {
	function := mtype.method.Func
	returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv})
	// register only accepts methods whose single result is of type error.
	// Do not assert it directly to BackendError: a nil error (success) or any
	// other error type would panic with a runtime.Error, which ServeCodec
	// re-panics, crashing the process. sendResponse distinguishes
	// BackendError from other errors itself.
	var err error
	if errInter := returnValues[0].Interface(); errInter != nil {
		err = errInter.(error)
	}
	server.sendResponse(replyv.Interface(), codec, err)
}
//////////////////////////////////////////////////////////////////////
// some test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment