#! /usr/bin/python

import weakref, types

from twisted.python import components, failure
from twisted.internet import defer

import slicer, schema, tokens, banana
from tokens import BananaError, Violation
from slicer import UnbananaFailure, BaseUnslicer


class PendingRequest(object):
    """Bookkeeping for one outstanding remote method call.

    Holds the Deferred returned to the caller of callRemote() and an
    optional constraint that the eventual results must satisfy.
    """

    def __init__(self):
        self.deferred = defer.Deferred()
        self.constraint = None # this constrains the results

    def setConstraint(self, constraint):
        """Set the constraint that the answer to this request must meet."""
        self.constraint = constraint


class RemoteReference(object):
    """Local proxy for an object held by the far end of a Broker.

    Remote methods are invoked with callRemote(). When this proxy is
    garbage-collected, __del__ asks the broker to free the reference.
    """

    def __init__(self, broker, refID, interfaces=None, schema=None):
        self.broker = broker
        self.refID = refID
        self.interfaces = interfaces
        self.schema = schema

    def __del__(self):
        self.broker.freeRemoteReference(self.refID)

    def callRemote(self, _name, *args, **kwargs):
        """Invoke the remote method C{_name} with the given arguments.

        Positional arguments can only be used when a schema is available to
        map them onto argument names. The special keyword argument
        C{_resultConstraint} is consumed here (not sent to the far end) and
        constrains the results, overriding any schema-derived constraint.

        Returns a Deferred that fires with the results. For consistency,
        *all* failures -- including ones raised locally before anything is
        sent -- are delivered through that Deferred's errback, never raised.
        """
        # Build the request before the try so the except clause always has a
        # Deferred to errback, even if newRequestID() itself fails.
        req = PendingRequest()
        reqID = None
        try:
            # this implements an argument _resultConstraint=None, but keeps
            # it from interfering with any positional parameters
            _resultConstraint = kwargs.pop("_resultConstraint", None)

            # newRequestID() could fail with a StaleBrokerError
            reqID = self.broker.newRequestID()
            self.broker.waitingForAnswers[reqID] = req

            methodSchema = None
            if self.schema:
                # getMethodSchema() could raise KeyError for bad methodnames
                methodSchema = self.schema.getMethodSchema(_name)

            if methodSchema:
                # turn positional arguments into kwargs
                # mapArguments() could fail for bad argument names or
                # missing required parameters
                argsdict = methodSchema.mapArguments(args, kwargs)
                c = methodSchema.getArgsConstraint()
                if c:
                    # check the *mapped* arguments (positional ones
                    # included) against the arg constraint. This could fail
                    # if any arguments are of the wrong type
                    c.checkArgs(argsdict)
                # get return value constraint
                req.setConstraint(methodSchema.getResponseConstraint())
            else:
                # without a schema there is no way to name positional args
                assert not args
                argsdict = kwargs

            if _resultConstraint:
                req.setConstraint(_resultConstraint) # overrides schema

            child = CallSlicer(self.broker)
            # this could fail if any of the arguments (or their children)
            # are unsliceable
            self.broker.slice2(child, (reqID, self.refID, _name, argsdict))
        except:
            f = failure.Failure()
            # don't leave a dead entry in the table of requests awaiting
            # answers: this request will never be answered normally
            if reqID is not None:
                self.broker.waitingForAnswers.pop(reqID, None)
            req.deferred.errback(f)

        # the remote end could send back an error response for many reasons:
        #  bad method name
        #  bad argument types (violated their schema)
        #  exception during method execution
        #  method result violated the results schema
        # something else could occur to cause an errback:
        #  connection lost before response completely received
        #  exception during deserialization of the response
        #   [but only if it occurs after the reqID is received]
        #  method result violated our results schema
        # if none of those occurred, the callback will be run

        return req.deferred


class ReferenceUnslicer(BaseUnslicer):
    """Receive a 'remote' sequence and turn it into a RemoteReference.

    The sequence carries an INT reference ID, followed by an interface list
    only when this ID has not been seen before.
    """
    refID = None
    interfaces = None
    wantInterfaceList = False
    ilistConstraint = schema.ListOf(schema.TupleOf(str, int))

    def checkToken(self, typebyte):
        if self.refID is None:
            if typebyte != tokens.INT:
                raise BananaError("reference ID must be an INT")
        else:
            if self.wantInterfaceList:
                self.ilistConstraint.checkToken(typebyte)
            else:
                raise BananaError("interface list on non-initial receipt")

    def doOpen(self, opentype):
        # only for the interface list
        self.ilistConstraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer:
            unslicer.setConstraint(self.ilistConstraint)
        return unslicer

    def receiveChild(self, token):
        if isinstance(token, UnbananaFailure):
            self.abort(token)
            return
            # TODO: if possible, return an error to the other side

        if self.refID is None:
            self.refID = token
            # do we want an interface list? Only if this is the first time
            # this reference has been received
            if self.refID not in self.broker.remoteReferences:
                self.wantInterfaceList = True
        else:
            # must be the interface list
            assert self.wantInterfaceList
            assert type(token) == type([]) # TODO: perhaps a dict instead
            self.interfaces = token

    def receiveClose(self):
        if self.refID is None:
            raise BananaError("sequence ended too early")
        return self.broker.registerRemoteReference(self.refID,
                                                   self.interfaces)


class DecRefUnslicer(BaseUnslicer):
    """Receive a 'decref' sequence (a single INT reference ID) and pass it
    to broker.decref()."""
    refID = None

    def checkToken(self, typebyte):
        if self.refID is None:
            if typebyte != tokens.INT:
                raise BananaError("reference ID must be an INT")
        else:
            raise BananaError("stop talking already!")

    def receiveChild(self, token):
        if isinstance(token, UnbananaFailure):
            self.abort(token)
            return
            # TODO: log but otherwise ignore
        self.refID = token

    def receiveClose(self):
        if self.refID is None:
            raise BananaError("sequence ended too early")
        return self.broker.decref(self.refID)


class CallUnslicer(BaseUnslicer):
    """Receive a 'call' sequence and dispatch it through broker.doCall().

    Token order: reqID (INT), objID (INT), methodname (STRING/VOCAB), then
    alternating argument names and argument values.
    """
    stage = 0 # 0:reqID, 1:objID, 2:methodname, 3: [(argname/value)]..
    reqID = None
    obj = None
    methodname = None
    methodSchema = None # will be a MethodArgumentsConstraint
    argname = None
    argConstraint = None

    def start(self, count):
        self.args = {}

    def checkToken(self, typebyte):
        # TODO: limit strings by returning a number instead of None
        if self.stage == 0:
            if typebyte != tokens.INT:
                raise BananaError("request ID must be an INT")
        elif self.stage == 1:
            if typebyte != tokens.INT:
                raise BananaError("object ID must be an INT")
        elif self.stage == 2:
            if typebyte not in (tokens.STRING, tokens.VOCAB):
                raise BananaError("method name must be a STRING")
        elif self.stage == 3:
            if self.argname is None:
                if typebyte not in (tokens.STRING, tokens.VOCAB):
                    raise BananaError("argument name must be a STRING")
            else:
                if self.argConstraint:
                    self.argConstraint.checkToken(typebyte)

    def doOpen(self, opentype):
        # this can only happen when we're receiving an argument value, so
        # we don't have to bother checking self.stage or self.argname
        if self.argConstraint:
            self.argConstraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer:
            if self.argConstraint:
                unslicer.setConstraint(self.argConstraint)
        return unslicer

    def receiveChild(self, token):
        if isinstance(token, UnbananaFailure):
            self.abort(token)
            return
            # TODO: if possible, return an error to the other side

        if self.stage == 0:
            self.reqID = token
            self.stage += 1
        elif self.stage == 1:
            # this might raise an exception if objID is invalid
            self.obj = self.broker.getObj(token)
            self.stage += 1
        elif self.stage == 2:
            # validate the methodname, get the schema. This may raise an
            # exception for unknown methods
            methodname = token
            self.methodSchema = self.obj.getMethodSchema(methodname)
            self.methodname = methodname
            self.stage += 1
        elif self.stage == 3:
            if self.argname is None:
                argname = token
                if argname in self.args:
                    raise BananaError("duplicate argument '%s'" % argname)
                ms = self.methodSchema
                if ms:
                    # if the argname is invalid, this may raise Violation
                    accept, self.argConstraint = ms.getArgConstraint(argname)
                    assert accept # TODO: discard if not
                self.argname = argname
            else:
                argvalue = token
                self.args[self.argname] = argvalue
                self.argname = None

    def receiveClose(self):
        if self.stage != 3 or self.argname is not None:
            raise BananaError("sequence ended too early")
        if self.methodSchema:
            # ask them again so they can look for missing arguments
            self.methodSchema.checkArgs(self.args)
        # this is where we actually call the method. doCall will catch any
        # exceptions.
        self.broker.doCall(self.reqID, self.obj, self.methodname,
                           self.args, self.methodSchema)

    def describeSelf(self):
        """Describe how much of the call has been received so far.

        The format strings must consume exactly the values they are given;
        an empty format string with arguments raises TypeError.
        """
        if self.stage == 0:
            return "<methodcall>"
        elif self.stage == 1:
            return "<methodcall reqID=%d>" % self.reqID
        elif self.stage == 2:
            return "<methodcall reqID=%d obj=%s>" % (self.reqID, self.obj)
        elif self.stage == 3:
            base = "<methodcall reqID=%d obj=%s methodname=%s>" % \
                   (self.reqID, self.obj, self.methodname)
            if self.argname is not None:
                return base + "arg[%s]" % self.argname
            return base


class AnswerUnslicer(BaseUnslicer):
    """Receive an 'answer' sequence (reqID, results) and fire the matching
    request's Deferred via broker.gotAnswer()."""
    request = None
    resultConstraint = None
    haveResults = False

    def checkToken(self, typebyte):
        if self.request is None:
            if typebyte != tokens.INT:
                raise BananaError("request ID must be an INT")
        elif not self.haveResults:
            if self.resultConstraint:
                try:
                    self.resultConstraint.checkToken(typebyte)
                except Violation:
                    # since we know which request was being sent, we
                    # can errback the deferred
                    self.broker.gotError(self.request, failure.Failure())
                    raise
        else:
            raise BananaError("stop sending me stuff!")

    def doOpen(self, opentype):
        if self.resultConstraint:
            self.resultConstraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer:
            if self.resultConstraint:
                unslicer.setConstraint(self.resultConstraint)
        return unslicer

    def receiveChild(self, token):
        if isinstance(token, UnbananaFailure):
            if self.request is not None:
                self.broker.gotError(self.request, token)
            self.abort(token)
            return
        if self.request is None:
            reqID = token
            # may raise BananaError for bad reqIDs
            self.request = self.broker.getRequest(reqID)
            self.resultConstraint = self.request.constraint
        else:
            self.results = token
            self.haveResults = True

    def receiveClose(self):
        # Without both a reqID and a result there is nothing to deliver
        # (and self.results would not exist); report it like the other
        # unslicers instead of failing with AttributeError.
        if self.request is None or not self.haveResults:
            raise BananaError("sequence ended too early")
        self.broker.gotAnswer(self.request, self.results)


class ErrorUnslicer(BaseUnslicer):
    """Receive an 'error' sequence (reqID, failure) and errback the matching
    request's Deferred via broker.gotError()."""
    request = None
    fConstraint = schema.FailureConstraint()
    gotFailure = False

    def checkToken(self, typebyte):
        if self.request is None:
            if typebyte != tokens.INT:
                raise BananaError("request ID must be an INT")
        elif not self.gotFailure:
            self.fConstraint.checkToken(typebyte)
        else:
            raise BananaError("stop sending me stuff!")

    def doOpen(self, opentype):
        self.fConstraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer:
            unslicer.setConstraint(self.fConstraint)
        return unslicer

    def receiveChild(self, token):
        if isinstance(token, UnbananaFailure):
            # a failure while receiving the failure. A bit daft, really.
            if self.request is not None:
                self.broker.gotError(self.request, token)
            self.abort(token)
            return
        if self.request is None:
            reqID = token
            # may raise BananaError for bad reqIDs
            self.request = self.broker.getRequest(reqID)
        else:
            # TODO: need real failures
            #self.failure = token
            self.failure = failure.Failure(RuntimeError(token))
            self.gotFailure = True

    def receiveClose(self):
        # self.failure only exists once the failure token has arrived
        if self.request is None or not self.gotFailure:
            raise BananaError("sequence ended too early")
        self.broker.gotError(self.request, self.failure)


class PBRootUnslicer(slicer.RootUnslicer):
    """Root unslicer for a Broker connection: only OPEN tokens are accepted
    at the top level, and only the sequence types in topRegistry."""
    # topRegistry defines what objects are allowed at the top-level. All of
    # these accept a Broker in their __init__ call
    topRegistry = {
        ("remote",): ReferenceUnslicer,
        ("decref",): DecRefUnslicer,
        ("call",): CallUnslicer,
        ("answer",): AnswerUnslicer,
        ("error",): ErrorUnslicer,
        }

    # openRegistry defines what objects are allowed at the second level and
    # below
    openRegistry = slicer.UnslicerRegistry

    def checkToken(self, typebyte):
        if typebyte != tokens.OPEN:
            raise BananaError("top-level must be OPEN")

    def openTop(self, opentype):
        child = self.open(opentype, self.topRegistry)
        if child:
            child.broker = self.broker
        return child

    def receiveChild(self, obj):
        pass


class BaseSlicer(slicer.BaseSlicer):
    """Slicer base class that also remembers the owning Broker."""

    def __init__(self, broker):
        slicer.BaseSlicer.__init__(self)
        self.broker = broker


class AnswerSlicer(BaseSlicer):
    """Serialize an 'answer' sequence from a (reqID, results) pair."""
    opentype = "answer"

    def slice(self, obj):
        reqID, results = obj
        self.send(reqID)
        self.send(results)


class ErrorSlicer(AnswerSlicer):
    """Serialize an 'error' sequence from a (reqID, failure) pair."""
    opentype = "error"

    def slice(self, obj):
        reqID, f = obj
        self.send(reqID)
        self.send(f)


class ReferenceableSlicer(BaseSlicer):
    """I handle pb.Referenceable objects (things with remotely invokable
    methods, which are copied by reference).
    """
    opentype = "remote"

    def slice(self, obj):
        puid = obj.processUniqueID()
        # NOTE(review): the Broker class in this module keeps local objects
        # in putObj()/localObjects and does not itself define `luids` or
        # `registerReference`; confirm they exist on banana.Banana or
        # elsewhere -- TODO
        alreadySent = puid in self.broker.luids
        luid = self.broker.registerReference(obj)
        self.send(luid)
        if not alreadySent:
            # this is the first time the Referenceable has crossed this
            # wire. In addition to the luid, send the interface list to the
            # far end.
            interfaces = components.getInterfaces(obj)
            # TODO: versioned Interfaces!
            # NOTE(review): ReferenceUnslicer.ilistConstraint expects
            # (str, int) tuples; confirm getInterfaces() yields names here
            ilist = [(name, 0) for name in interfaces]
            self.send(ilist)


class DecRefSlicer(BaseSlicer):
    """Serialize a 'decref' sequence carrying a single reference ID."""
    opentype = "decref"

    def slice(self, refID):
        self.send(refID)


class CopyableSlicer(BaseSlicer):
    """I handle pb.Copyable objects (things which are copied by value)."""
    opentype = "instance"
    # ???
class CallSlicer(BaseSlicer):
    """Serialize a 'call' sequence from (reqID, refID, methodname, args).

    Wire order must match CallUnslicer: request ID, object ID, method name,
    then alternating argument names and values (sorted by name).
    """
    opentype = "call"

    def slice(self, obj):
        reqID, refID, methodname, args = obj
        # the request ID goes first (CallUnslicer stage 0), then the object
        # ID (stage 1); sending refID twice would mis-frame every call
        self.send(reqID)
        self.send(refID)
        self.send(methodname)
        for argname in sorted(args):
            self.send(argname)
            self.send(args[argname])


PBSlicerRegistry = {}
PBSlicerRegistry.update(slicer.BaseSlicerRegistry)
del PBSlicerRegistry[types.InstanceType]


class PBRootSlicer(slicer.RootSlicer):
    """Root slicer for a Broker: Referenceables go by reference, Copyables
    by value, and everything else falls back to the base registry."""
    SlicerRegistry = PBSlicerRegistry

    def slicerFactoryForObject(self, obj):
        # NOTE(review): Referenceable and Copyable are not imported or
        # defined in this module; confirm where they come from -- TODO
        if isinstance(obj, Referenceable):
            return ReferenceableSlicer
        if isinstance(obj, Copyable):
            return CopyableSlicer
        return slicer.RootSlicer.slicerFactoryForObject(self, obj)


class Broker(banana.Banana):
    """A Banana connection that carries remote method calls, their answers
    and errors, and remote-reference bookkeeping in both directions."""
    slicerClass = PBRootSlicer
    unslicerClass = PBRootUnslicer

    def __init__(self):
        banana.Banana.__init__(self)
        self.rootSlicer.broker = self
        self.rootUnslicer.broker = self
        self.remoteReferences = weakref.WeakValueDictionary()
        self.currentRequestID = 0
        self.waitingForAnswers = {}
        self.currentLocalID = 0
        # things which are available to our peer.
        # NOTE(review): intended to be reference counted, but putObj() never
        # dedupes and decref() deletes on the first decref message.
        self.localObjects = {}

    def newLocalID(self):
        """Generate a new LUID.
        """
        self.currentLocalID = self.currentLocalID + 1
        return self.currentLocalID

    def putObj(self, obj):
        """Make obj available to the peer and return its new objID."""
        # TODO: give duplicates the same objID
        objID = self.newLocalID()
        self.localObjects[objID] = obj
        return objID

    def getObj(self, objID):
        """objID is a number which refers to a object that the remote end is
        allowed to invoke methods upon. Raises KeyError for unknown IDs.
        """
        obj = self.localObjects[objID]
        # obj = tokens.IReferenceable(obj)
        #assert isinstance(obj, pb.Referenceable)
        # obj needs .getMethodSchema, which needs .getArgConstraint
        return obj

    # RemoteReference.callRemote, gotAnswer, and gotError are run on the
    # calling side

    def newRequestID(self):
        self.currentRequestID = self.currentRequestID + 1
        return self.currentRequestID

    def getRequest(self, reqID):
        """Remove and return the PendingRequest for reqID.

        Raises BananaError if no such request is outstanding.
        """
        try:
            return self.waitingForAnswers.pop(reqID)
        except KeyError:
            raise BananaError("non-existent reqID '%d'" % reqID)

    def gotAnswer(self, req, results):
        req.deferred.callback(results)

    def gotError(self, req, failure):
        req.deferred.errback(failure)

    # decref is invoked on the target side (the pb.Referenceable holder,
    # which owns localObjects) when the other side sends us a decref message

    def decref(self, refID):
        del self.localObjects[refID]

    # doCall, callFinished, sendError are run on the target side

    def doCall(self, reqID, obj, methodname, args, methodSchema):
        """Invoke remote_<methodname> on obj and send back an answer or an
        error for reqID. Exceptions are reported to the peer, not raised."""
        try:
            meth = getattr(obj, "remote_%s" % methodname)
            res = meth(**args)
        except:
            # TODO: implement CopyableFailure and FailureConstraint
            #f = failure.CopyableFailure()
            # sendError() turns the Failure into the message string
            self.sendError(failure.Failure(), reqID)
        else:
            if not isinstance(res, defer.Deferred):
                res = defer.succeed(res)
            res.addCallback(self.callFinished, reqID, methodSchema)
            # errbacks hand sendError a Failure object; it is converted to
            # the same message string as the synchronous path above
            res.addErrback(self.sendError, reqID)

    def callFinished(self, res, reqID, methodSchema):
        if methodSchema:
            methodSchema.checkResults(res) # may raise Violation
        child = AnswerSlicer(self)
        self.slice2(child, (reqID, res))

    def sendError(self, f, reqID):
        """Send an 'error' sequence for reqID.

        f may be a message string or a twisted Failure. A Failure is reduced
        to its error message plus brief traceback, so both the synchronous
        and the Deferred errback paths put the same kind of value on the
        wire.
        """
        if isinstance(f, failure.Failure):
            f = f.getErrorMessage() + f.getBriefTraceback()
        child = ErrorSlicer(self)
        self.slice2(child, (reqID, f))

    # registerRemoteReference and freeRemoteReference are run on the
    # calling side (the side that has the RemoteReference)

    def registerRemoteReference(self, refID, interfaces):
        """The far end holds a Referenceable and has just sent us a reference
        to it (expressed as a small integer). If this is a new reference,
        they will give us an interface list too. Obtain a RemoteReference
        object (creating it if necessary) to give to the local recipient.

        There is exactly one RemoteReference object for each refID. We hold
        a weakref to the RemoteReference so we can provide the same object
        later but so we can detect when the Broker is the only thing left
        that knows about it.

        The sender remembers that we hold a reference to their object. When
        our RemoteReference goes away, its __del__ method will tell us to
        send a decref message so they can possibly free their object.
        """
        obj = self.remoteReferences.get(refID)
        if obj is None:
            obj = RemoteReference(self, refID, interfaces)
            self.remoteReferences[refID] = obj
        return obj

    def freeRemoteReference(self, refID):
        # this is called by RemoteReference.__del__
        # the WeakValueDictionary means we don't have to explicitly remove it
        #del self.remoteReferences[refID]
        try:
            child = DecRefSlicer(self)
            self.slice2(child, refID)
        except:
            print("failure during freeRemoteReference")
            f = failure.Failure()
            print(f.getTraceback())
            raise